oasis_runtime_sdk_macros/module_derive/
method_handler.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::parse_quote;
4
5use crate::emit_compile_error;
6
7/// Deriver for the `MethodHandler` trait.
8pub struct DeriveMethodHandler {
9    handlers: Vec<ParsedImplItem>,
10}
11
12impl DeriveMethodHandler {
13    pub fn new() -> Box<Self> {
14        Box::new(Self { handlers: vec![] })
15    }
16}
17
18impl super::Deriver for DeriveMethodHandler {
19    fn preprocess(&mut self, item: syn::ImplItem) -> Option<syn::ImplItem> {
20        let method = match item {
21            syn::ImplItem::Fn(ref f) => f,
22            _ => return Some(item),
23        };
24
25        let attrs = if let Some(attrs) = parse_attrs(&method.attrs) {
26            attrs
27        } else {
28            return Some(item);
29        };
30
31        self.handlers.push(ParsedImplItem {
32            handler: Some(HandlerInfo {
33                attrs,
34                ident: method.sig.ident.clone(),
35            }),
36            item,
37        });
38
39        None // Take the item.
40    }
41
42    fn derive(&mut self, generics: &syn::Generics, ty: &Box<syn::Type>) -> TokenStream {
43        let handlers = &self.handlers;
44        let handler_items = handlers
45            .iter()
46            .map(|ParsedImplItem { item, .. }| item)
47            .collect::<Vec<_>>();
48
49        /// Generates parallel vectors of rpc names and handler idents for all handlers of kind `kind`.
50        fn filter_by_kind(
51            handlers: &[ParsedImplItem],
52            kind: HandlerKind,
53        ) -> (Vec<syn::Expr>, Vec<syn::Ident>) {
54            handlers
55                .iter()
56                .filter_map(|h| h.handler.as_ref())
57                .filter(|h| h.attrs.kind == kind)
58                .map(|h| (h.attrs.rpc_name.clone(), h.ident.clone()))
59                .unzip()
60        }
61
62        let prefetch_impl = {
63            let (handler_names, handler_idents) = filter_by_kind(handlers, HandlerKind::Prefetch);
64
65            // Find call handlers; for every call handler without a corresponding prefetch handler, we'll
66            // generate a dummy prefetch handler.
67            let (call_handler_names, _) = filter_by_kind(handlers, HandlerKind::Call);
68            let handler_names_without_impl: Vec<&syn::Expr> = call_handler_names
69                .iter()
70                .filter(|n| !handler_names.contains(n))
71                .collect();
72
73            if handler_names.is_empty() {
74                quote! {}
75            } else {
76                quote! {
77                    fn prefetch(
78                        prefixes: &mut BTreeSet<Prefix>,
79                        method: &str,
80                        body: cbor::Value,
81                        auth_info: &AuthInfo,
82                    ) -> module::DispatchResult<cbor::Value, Result<(), sdk::error::RuntimeError>> {
83                        let mut add_prefix = |p| {prefixes.insert(p);};
84                        match method {
85                            // "Real", user-defined prefetch handlers.
86                            #(
87                              #handler_names => module::DispatchResult::Handled(
88                                Self::#handler_idents(&mut add_prefix, body, auth_info)
89                              ),
90                            )*
91                            // No-op prefetch handlers.
92                            #(
93                              #handler_names_without_impl => module::DispatchResult::Handled(Ok(())),
94                            )*
95                            _ => module::DispatchResult::Unhandled(body),
96                        }
97                    }
98                }
99            }
100        };
101
102        let dispatch_call_impl = {
103            let (handler_names, handler_fns): (Vec<_>, Vec<_>) = handlers
104                .iter()
105                .filter_map(|h| h.handler.as_ref())
106                .filter(|h| h.attrs.kind == HandlerKind::Call)
107                .map(|h| {
108                    (h.attrs.rpc_name.clone(), {
109                        let ident = &h.ident;
110
111                        if h.attrs.is_internal {
112                            quote! {
113                                |ctx, body| {
114                                    if !sdk::state::CurrentState::with_env(|env| env.is_internal()) {
115                                        return Err(sdk::modules::core::Error::Forbidden.into());
116                                    }
117                                    Self::#ident(ctx, body)
118                                }
119                            }
120                        } else {
121                            quote! { Self::#ident }
122                        }
123                    })
124                })
125                .unzip();
126
127            if handler_names.is_empty() {
128                quote! {}
129            } else {
130                quote! {
131                    fn dispatch_call<C: Context>(
132                        ctx: &C,
133                        method: &str,
134                        body: cbor::Value,
135                    ) -> DispatchResult<cbor::Value, CallResult> {
136                        match method {
137                            #(
138                              #handler_names => module::dispatch_call(ctx, body, #handler_fns),
139                            )*
140                            _ => DispatchResult::Unhandled(body),
141                        }
142                    }
143                }
144            }
145        };
146
147        let query_parameters_impl = {
148            quote! {
149                fn query_parameters<C: Context>(_ctx: &C, _args: ()) -> Result<<Self as module::Module>::Parameters, <Self as module::Module>::Error> {
150                    Ok(Self::params())
151                }
152            }
153        };
154
155        let dispatch_query_impl = {
156            let (handler_names, handler_idents) = filter_by_kind(handlers, HandlerKind::Query);
157
158            if handler_names.is_empty() {
159                quote! {
160                    fn dispatch_query<C: Context>(
161                        ctx: &C,
162                        method: &str,
163                        args: cbor::Value,
164                    ) -> DispatchResult<cbor::Value, Result<cbor::Value, sdk::error::RuntimeError>> {
165                        match method {
166                            q if q == format!("{}.Parameters", Self::NAME) => module::dispatch_query(ctx, args, Self::query_parameters),
167                            _ => DispatchResult::Unhandled(args),
168                        }
169                    }
170                }
171            } else {
172                quote! {
173                    fn dispatch_query<C: Context>(
174                        ctx: &C,
175                        method: &str,
176                        args: cbor::Value,
177                    ) -> DispatchResult<cbor::Value, Result<cbor::Value, sdk::error::RuntimeError>> {
178                        match method {
179                            #(
180                              #handler_names => module::dispatch_query(ctx, args, Self::#handler_idents),
181                            )*
182                            q if q == format!("{}.Parameters", Self::NAME) => module::dispatch_query(ctx, args, Self::query_parameters),
183                            _ => DispatchResult::Unhandled(args),
184                        }
185                    }
186                }
187            }
188        };
189
190        let dispatch_message_result_impl = {
191            let (handler_names, handler_idents) =
192                filter_by_kind(handlers, HandlerKind::MessageResult);
193
194            if handler_names.is_empty() {
195                quote! {}
196            } else {
197                quote! {
198                    fn dispatch_message_result<C: Context>(
199                        ctx: &C,
200                        handler_name: &str,
201                        result: MessageResult,
202                    ) -> DispatchResult<MessageResult, ()> {
203                        match handler_name {
204                            #(
205                              #handler_names => {
206                                  Self::#handler_idents(
207                                      ctx,
208                                      result.event,
209                                      cbor::from_value(result.context).expect("invalid message handler context"),
210                                  );
211                                  DispatchResult::Handled(())
212                              }
213                            )*
214                            _ => DispatchResult::Unhandled(result),
215                        }
216                    }
217                }
218            }
219        };
220
221        let supported_methods_impl = {
222            let (handler_names, handler_kinds): (Vec<syn::Expr>, Vec<syn::Path>) = handlers
223                .iter()
224                .filter_map(|h| h.handler.as_ref())
225                // `prefetch` is an implementation detail of `call` handlers, so we don't list them
226                .filter(|h| h.attrs.kind != HandlerKind::Prefetch)
227                .map(|h| (h.attrs.rpc_name.clone(), h.attrs.kind.as_sdk_ident()))
228                .unzip();
229            if handler_names.is_empty() {
230                quote! {}
231            } else {
232                quote! {
233                    fn supported_methods() -> Vec<core_types::MethodHandlerInfo> {
234                        vec![ #(
235                            core_types::MethodHandlerInfo {
236                                kind: #handler_kinds,
237                                name: #handler_names.to_string(),
238                            },
239                        )* ]
240                    }
241                }
242            }
243        };
244
245        let expensive_queries_impl = {
246            let handler_names: Vec<syn::Expr> = handlers
247                .iter()
248                .filter_map(|h| h.handler.as_ref())
249                .filter(|h| h.attrs.kind == HandlerKind::Query && h.attrs.is_expensive)
250                .map(|h| h.attrs.rpc_name.clone())
251                .collect();
252            if handler_names.is_empty() {
253                quote! {}
254            } else {
255                quote! {
256                    fn is_expensive_query(method: &str) -> bool {
257                        [ #( #handler_names, )* ].contains(&method)
258                    }
259                }
260            }
261        };
262
263        let allowed_private_km_queries_impl = {
264            let handler_names: Vec<syn::Expr> = handlers
265                .iter()
266                .filter_map(|h| h.handler.as_ref())
267                .filter(|h| h.attrs.kind == HandlerKind::Query && h.attrs.allow_private_km)
268                .map(|h| h.attrs.rpc_name.clone())
269                .collect();
270            if handler_names.is_empty() {
271                quote! {}
272            } else {
273                quote! {
274                    fn is_allowed_private_km_query(method: &str) -> bool {
275                        [ #( #handler_names, )* ].contains(&method)
276                    }
277                }
278            }
279        };
280
281        let allowed_interactive_calls_impl = {
282            let handler_names: Vec<syn::Expr> = handlers
283                .iter()
284                .filter_map(|h| h.handler.as_ref())
285                .filter(|h| h.attrs.kind == HandlerKind::Call && h.attrs.allow_interactive)
286                .map(|h| h.attrs.rpc_name.clone())
287                .collect();
288            if handler_names.is_empty() {
289                quote! {}
290            } else {
291                quote! {
292                    fn is_allowed_interactive_call(method: &str) -> bool {
293                        [ #( #handler_names, )* ].contains(&method)
294                    }
295                }
296            }
297        };
298
299        quote! {
300            #[automatically_derived]
301            impl #generics sdk::module::MethodHandler for #ty {
302                #prefetch_impl
303                #dispatch_call_impl
304                #dispatch_query_impl
305                #dispatch_message_result_impl
306                #supported_methods_impl
307                #expensive_queries_impl
308                #allowed_private_km_queries_impl
309                #allowed_interactive_calls_impl
310            }
311
312            #[automatically_derived]
313            impl #generics #ty {
314                #query_parameters_impl
315
316                #(#handler_items)*
317            }
318        }
319    }
320}
321
322/// An item (in the `syn` sense, i.e. a fn, type, comment, etc) in an `impl` block,
323/// plus parsed data about its #[handler] attribute, if any.
324#[derive(Clone)]
325struct ParsedImplItem {
326    item: syn::ImplItem,
327    handler: Option<HandlerInfo>,
328}
329
330#[derive(Clone, Debug)]
331struct HandlerInfo {
332    attrs: MethodHandlerAttr,
333    /// Name of the handler function.
334    ident: syn::Ident,
335}
336
/// The kind of a method handler, as written before the `=` in `#[handler(kind = ...)]`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum HandlerKind {
    /// `#[handler(call = ...)]`: handles transaction calls.
    Call,
    /// `#[handler(query = ...)]`: handles queries.
    Query,
    /// `#[handler(message_result = ...)]`: handles message results.
    MessageResult,
    /// `#[handler(prefetch = ...)]`: prefetch companion of a call handler.
    Prefetch,
}
344
345impl HandlerKind {
346    fn as_sdk_ident(&self) -> syn::Path {
347        match self {
348            HandlerKind::Call => parse_quote!(core_types::MethodHandlerKind::Call),
349            HandlerKind::Query => parse_quote!(core_types::MethodHandlerKind::Query),
350            HandlerKind::MessageResult => {
351                parse_quote!(core_types::MethodHandlerKind::MessageResult)
352            }
353            HandlerKind::Prefetch => {
354                unimplemented!("prefetch cannot be expressed in core::types::MethodHandlerKind")
355            }
356        }
357    }
358}
359
360#[derive(Debug, Clone, PartialEq)]
361struct MethodHandlerAttr {
362    kind: HandlerKind,
363    /// Name of the RPC that this handler handles, e.g. "my_module.MyQuery".
364    rpc_name: syn::Expr,
365    /// Whether this handler is tagged as expensive. Only applies to query handlers.
366    is_expensive: bool,
367    /// Whether this handler is tagged as allowing access to private key manager state. Only applies
368    /// to query handlers.
369    allow_private_km: bool,
370    /// Whether this handler is tagged as allowing interactive calls. Only applies to call handlers.
371    allow_interactive: bool,
372    /// Whether this handler is tagged as internal.
373    is_internal: bool,
374}
375impl syn::parse::Parse for MethodHandlerAttr {
376    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
377        let kind: syn::Ident = input.parse()?;
378        let kind = match kind.to_string().as_str() {
379            "call" => HandlerKind::Call,
380            "query" => HandlerKind::Query,
381            "message_result" => HandlerKind::MessageResult,
382            "prefetch" => HandlerKind::Prefetch,
383            _ => return Err(syn::Error::new(kind.span(), "invalid handler kind")),
384        };
385        let _: syn::token::Eq = input.parse()?;
386        let rpc_name: syn::Expr = input.parse()?;
387
388        // Parse optional comma-separated tags.
389        let mut is_expensive = false;
390        let mut allow_private_km = false;
391        let mut allow_interactive = false;
392        let mut is_internal = false;
393        while input.peek(syn::token::Comma) {
394            let _: syn::token::Comma = input.parse()?;
395            let tag: syn::Ident = input.parse()?;
396
397            if tag == "expensive" {
398                if kind != HandlerKind::Query {
399                    return Err(syn::Error::new(
400                        tag.span(),
401                        "`expensive` tag is only allowed on `query` handlers",
402                    ));
403                }
404                is_expensive = true;
405            } else if tag == "allow_private_km" {
406                if kind != HandlerKind::Query {
407                    return Err(syn::Error::new(
408                        tag.span(),
409                        "`allow_private_km` tag is only allowed on `query` handlers",
410                    ));
411                }
412                allow_private_km = true;
413            } else if tag == "allow_interactive" {
414                if kind != HandlerKind::Call {
415                    return Err(syn::Error::new(
416                        tag.span(),
417                        "`allow_interactive` tag is only allowed on `call` handlers",
418                    ));
419                }
420                allow_interactive = true;
421            } else if tag == "internal" {
422                if kind != HandlerKind::Call {
423                    return Err(syn::Error::new(
424                        tag.span(),
425                        "`internal` tag is only allowed on `call` handlers",
426                    ));
427                }
428                is_internal = true;
429            } else {
430                return Err(syn::Error::new(
431                    tag.span(),
432                    "invalid handler tag; supported: `expensive`, `allow_private_km`, `allow_interactive`, `internal`",
433                ));
434            }
435        }
436
437        if !input.is_empty() {
438            return Err(syn::Error::new(input.span(), "unexpected extra tokens"));
439        }
440        Ok(Self {
441            kind,
442            rpc_name,
443            is_expensive,
444            allow_private_km,
445            allow_interactive,
446            is_internal,
447        })
448    }
449}
450
451fn parse_attrs(attrs: &[syn::Attribute]) -> Option<MethodHandlerAttr> {
452    let handler_meta = attrs.iter().find(|attr| attr.path().is_ident("handler"))?;
453    handler_meta
454        .parse_args()
455        .map_err(|err| {
456            emit_compile_error(format!(
457                "Unsupported format of #[handler(...)] attribute: {err}"
458            ))
459        })
460        .ok()
461}