oasis_runtime_sdk_macros/module_derive/migration_handler.rs

1use std::collections::HashSet;
2
3use proc_macro2::TokenStream;
4use quote::quote;
5
6use crate::emit_compile_error;
7
8/// Deriver for the `MigrationHandler` trait.
9pub struct DeriveMigrationHandler {
10    /// Item defining the `MigrationHandler::Genesis` associated type.
11    genesis_ty: Option<syn::ImplItem>,
12    /// Migration functions.
13    migrate_fns: Vec<MigrateFn>,
14}
15
16struct MigrateFn {
17    item: syn::ImplItem,
18    ident: syn::Ident,
19    from_version: u32,
20}
21
22impl DeriveMigrationHandler {
23    pub fn new() -> Box<Self> {
24        Box::new(Self {
25            genesis_ty: None,
26            migrate_fns: vec![],
27        })
28    }
29}
30
31impl super::Deriver for DeriveMigrationHandler {
32    fn preprocess(&mut self, item: syn::ImplItem) -> Option<syn::ImplItem> {
33        match item {
34            // We are looking for a `type Genesis = ...;` item.
35            syn::ImplItem::Type(ref ty) if &ty.ident.to_string() == "Genesis" => {
36                self.genesis_ty = Some(item);
37
38                None // Take the item.
39            }
40            syn::ImplItem::Fn(ref f) => {
41                // Check whether a `migration` attribute is set for the method.
42                if let Some(attrs) = parse_attrs(&f.attrs) {
43                    self.migrate_fns.push(MigrateFn {
44                        ident: f.sig.ident.clone(),
45                        from_version: attrs.from_version,
46                        item,
47                    });
48
49                    None // Take the item.
50                } else {
51                    Some(item) // Return the item.
52                }
53            }
54            _ => Some(item), // Return the item.
55        }
56    }
57
58    fn derive(&mut self, generics: &syn::Generics, ty: &Box<syn::Type>) -> TokenStream {
59        let genesis_ty = if let Some(genesis_ty) = &self.genesis_ty {
60            genesis_ty
61        } else {
62            return quote! {};
63        };
64
65        // Sort by version to ensure migrations are processed in the right order.
66        self.migrate_fns.sort_by_key(|f| f.from_version);
67
68        let mut seen_versions = HashSet::new();
69        let (migrate_fns, mut migrate_arms): (Vec<_>, Vec<_>) = self.migrate_fns.iter().map(|f| {
70            let MigrateFn { item, ident, from_version } = f;
71            if seen_versions.contains(from_version) {
72                emit_compile_error(format!(
73                    "Duplicate migration for version: {from_version}"
74                ));
75            }
76            seen_versions.insert(from_version);
77
78            (
79                item,
80                if from_version == &0 {
81                    // Version zero is special as initializing from genesis always gets us latest.
82                    quote! { if version == #from_version { Self::#ident(genesis); version = Self::VERSION; } }
83                } else {
84                    // For other versions, each migration brings us from V to V+1.
85                    // TODO: Add a compile-time assert that version < Self::VERSION.
86                    quote! { if version == #from_version && version < Self::VERSION { Self::#ident(); version += 1; } }
87                }
88            )
89        }).unzip();
90
91        // Ensure there is a genesis migration, at least an empty one that bumps the version.
92        if !seen_versions.contains(&0) {
93            migrate_arms.push(quote! {
94                if version == 0u32 { version = Self::VERSION; }
95            });
96        }
97
98        quote! {
99            #[automatically_derived]
100            impl #generics sdk::module::MigrationHandler for #ty {
101                #genesis_ty
102
103                fn init_or_migrate<C: Context>(
104                    _ctx: &C,
105                    meta: &mut sdk::modules::core::types::Metadata,
106                    genesis: Self::Genesis,
107                ) -> bool {
108                    let mut version = meta.versions.get(Self::NAME).copied().unwrap_or_default();
109                    if version == Self::VERSION {
110                        return false; // Already the latest version.
111                    }
112
113                    #(#migrate_arms)*
114
115                    if version != Self::VERSION {
116                        panic!("no migration for module state from version {version} to {}", Self::VERSION)
117                    }
118
119                    // Update version information.
120                    meta.versions.insert(Self::NAME.to_owned(), Self::VERSION);
121                    return true;
122                }
123            }
124
125            #[automatically_derived]
126            impl #generics #ty {
127                #(#migrate_fns)*
128            }
129        }
130    }
131}
132
/// Arguments parsed from a `#[migration(...)]` attribute.
#[derive(Debug, Clone, PartialEq)]
struct MigrationHandlerAttr {
    /// State version this handler migrates from; `0` stands for initialization from genesis.
    from_version: u32,
}
138impl syn::parse::Parse for MigrationHandlerAttr {
139    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
140        let kind: syn::Ident = input.parse()?;
141        let from_version = match kind.to_string().as_str() {
142            "init" => 0,
143            "from" => {
144                let _: syn::token::Eq = input.parse()?;
145                let version: syn::LitInt = input.parse()?;
146
147                version.base10_parse()?
148            }
149            _ => return Err(syn::Error::new(kind.span(), "invalid migration kind")),
150        };
151
152        if !input.is_empty() {
153            return Err(syn::Error::new(input.span(), "unexpected extra tokens"));
154        }
155        Ok(Self { from_version })
156    }
157}
158
159fn parse_attrs(attrs: &[syn::Attribute]) -> Option<MigrationHandlerAttr> {
160    let migration_meta = attrs
161        .iter()
162        .find(|attr| attr.path().is_ident("migration"))?;
163    migration_meta
164        .parse_args()
165        .map_err(|err| {
166            emit_compile_error(format!(
167                "Unsupported format of #[migration(...)] attribute: {err}"
168            ))
169        })
170        .ok()
171}