oasis_runtime_sdk_macros/module_derive/
method_handler.rs1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::parse_quote;
4
5use crate::emit_compile_error;
6
7pub struct DeriveMethodHandler {
9 handlers: Vec<ParsedImplItem>,
10}
11
12impl DeriveMethodHandler {
13 pub fn new() -> Box<Self> {
14 Box::new(Self { handlers: vec![] })
15 }
16}
17
18impl super::Deriver for DeriveMethodHandler {
19 fn preprocess(&mut self, item: syn::ImplItem) -> Option<syn::ImplItem> {
20 let method = match item {
21 syn::ImplItem::Fn(ref f) => f,
22 _ => return Some(item),
23 };
24
25 let attrs = if let Some(attrs) = parse_attrs(&method.attrs) {
26 attrs
27 } else {
28 return Some(item);
29 };
30
31 self.handlers.push(ParsedImplItem {
32 handler: Some(HandlerInfo {
33 attrs,
34 ident: method.sig.ident.clone(),
35 }),
36 item,
37 });
38
39 None }
41
42 fn derive(&mut self, generics: &syn::Generics, ty: &Box<syn::Type>) -> TokenStream {
43 let handlers = &self.handlers;
44 let handler_items = handlers
45 .iter()
46 .map(|ParsedImplItem { item, .. }| item)
47 .collect::<Vec<_>>();
48
49 fn filter_by_kind(
51 handlers: &[ParsedImplItem],
52 kind: HandlerKind,
53 ) -> (Vec<syn::Expr>, Vec<syn::Ident>) {
54 handlers
55 .iter()
56 .filter_map(|h| h.handler.as_ref())
57 .filter(|h| h.attrs.kind == kind)
58 .map(|h| (h.attrs.rpc_name.clone(), h.ident.clone()))
59 .unzip()
60 }
61
62 let prefetch_impl = {
63 let (handler_names, handler_idents) = filter_by_kind(handlers, HandlerKind::Prefetch);
64
65 let (call_handler_names, _) = filter_by_kind(handlers, HandlerKind::Call);
68 let handler_names_without_impl: Vec<&syn::Expr> = call_handler_names
69 .iter()
70 .filter(|n| !handler_names.contains(n))
71 .collect();
72
73 if handler_names.is_empty() {
74 quote! {}
75 } else {
76 quote! {
77 fn prefetch(
78 prefixes: &mut BTreeSet<Prefix>,
79 method: &str,
80 body: cbor::Value,
81 auth_info: &AuthInfo,
82 ) -> module::DispatchResult<cbor::Value, Result<(), sdk::error::RuntimeError>> {
83 let mut add_prefix = |p| {prefixes.insert(p);};
84 match method {
85 #(
87 #handler_names => module::DispatchResult::Handled(
88 Self::#handler_idents(&mut add_prefix, body, auth_info)
89 ),
90 )*
91 #(
93 #handler_names_without_impl => module::DispatchResult::Handled(Ok(())),
94 )*
95 _ => module::DispatchResult::Unhandled(body),
96 }
97 }
98 }
99 }
100 };
101
102 let dispatch_call_impl = {
103 let (handler_names, handler_fns): (Vec<_>, Vec<_>) = handlers
104 .iter()
105 .filter_map(|h| h.handler.as_ref())
106 .filter(|h| h.attrs.kind == HandlerKind::Call)
107 .map(|h| {
108 (h.attrs.rpc_name.clone(), {
109 let ident = &h.ident;
110
111 if h.attrs.is_internal {
112 quote! {
113 |ctx, body| {
114 if !sdk::state::CurrentState::with_env(|env| env.is_internal()) {
115 return Err(sdk::modules::core::Error::Forbidden.into());
116 }
117 Self::#ident(ctx, body)
118 }
119 }
120 } else {
121 quote! { Self::#ident }
122 }
123 })
124 })
125 .unzip();
126
127 if handler_names.is_empty() {
128 quote! {}
129 } else {
130 quote! {
131 fn dispatch_call<C: Context>(
132 ctx: &C,
133 method: &str,
134 body: cbor::Value,
135 ) -> DispatchResult<cbor::Value, CallResult> {
136 match method {
137 #(
138 #handler_names => module::dispatch_call(ctx, body, #handler_fns),
139 )*
140 _ => DispatchResult::Unhandled(body),
141 }
142 }
143 }
144 }
145 };
146
147 let query_parameters_impl = {
148 quote! {
149 fn query_parameters<C: Context>(_ctx: &C, _args: ()) -> Result<<Self as module::Module>::Parameters, <Self as module::Module>::Error> {
150 Ok(Self::params())
151 }
152 }
153 };
154
155 let dispatch_query_impl = {
156 let (handler_names, handler_idents) = filter_by_kind(handlers, HandlerKind::Query);
157
158 if handler_names.is_empty() {
159 quote! {
160 fn dispatch_query<C: Context>(
161 ctx: &C,
162 method: &str,
163 args: cbor::Value,
164 ) -> DispatchResult<cbor::Value, Result<cbor::Value, sdk::error::RuntimeError>> {
165 match method {
166 q if q == format!("{}.Parameters", Self::NAME) => module::dispatch_query(ctx, args, Self::query_parameters),
167 _ => DispatchResult::Unhandled(args),
168 }
169 }
170 }
171 } else {
172 quote! {
173 fn dispatch_query<C: Context>(
174 ctx: &C,
175 method: &str,
176 args: cbor::Value,
177 ) -> DispatchResult<cbor::Value, Result<cbor::Value, sdk::error::RuntimeError>> {
178 match method {
179 #(
180 #handler_names => module::dispatch_query(ctx, args, Self::#handler_idents),
181 )*
182 q if q == format!("{}.Parameters", Self::NAME) => module::dispatch_query(ctx, args, Self::query_parameters),
183 _ => DispatchResult::Unhandled(args),
184 }
185 }
186 }
187 }
188 };
189
190 let dispatch_message_result_impl = {
191 let (handler_names, handler_idents) =
192 filter_by_kind(handlers, HandlerKind::MessageResult);
193
194 if handler_names.is_empty() {
195 quote! {}
196 } else {
197 quote! {
198 fn dispatch_message_result<C: Context>(
199 ctx: &C,
200 handler_name: &str,
201 result: MessageResult,
202 ) -> DispatchResult<MessageResult, ()> {
203 match handler_name {
204 #(
205 #handler_names => {
206 Self::#handler_idents(
207 ctx,
208 result.event,
209 cbor::from_value(result.context).expect("invalid message handler context"),
210 );
211 DispatchResult::Handled(())
212 }
213 )*
214 _ => DispatchResult::Unhandled(result),
215 }
216 }
217 }
218 }
219 };
220
221 let supported_methods_impl = {
222 let (handler_names, handler_kinds): (Vec<syn::Expr>, Vec<syn::Path>) = handlers
223 .iter()
224 .filter_map(|h| h.handler.as_ref())
225 .filter(|h| h.attrs.kind != HandlerKind::Prefetch)
227 .map(|h| (h.attrs.rpc_name.clone(), h.attrs.kind.as_sdk_ident()))
228 .unzip();
229 if handler_names.is_empty() {
230 quote! {}
231 } else {
232 quote! {
233 fn supported_methods() -> Vec<core_types::MethodHandlerInfo> {
234 vec![ #(
235 core_types::MethodHandlerInfo {
236 kind: #handler_kinds,
237 name: #handler_names.to_string(),
238 },
239 )* ]
240 }
241 }
242 }
243 };
244
245 let expensive_queries_impl = {
246 let handler_names: Vec<syn::Expr> = handlers
247 .iter()
248 .filter_map(|h| h.handler.as_ref())
249 .filter(|h| h.attrs.kind == HandlerKind::Query && h.attrs.is_expensive)
250 .map(|h| h.attrs.rpc_name.clone())
251 .collect();
252 if handler_names.is_empty() {
253 quote! {}
254 } else {
255 quote! {
256 fn is_expensive_query(method: &str) -> bool {
257 [ #( #handler_names, )* ].contains(&method)
258 }
259 }
260 }
261 };
262
263 let allowed_private_km_queries_impl = {
264 let handler_names: Vec<syn::Expr> = handlers
265 .iter()
266 .filter_map(|h| h.handler.as_ref())
267 .filter(|h| h.attrs.kind == HandlerKind::Query && h.attrs.allow_private_km)
268 .map(|h| h.attrs.rpc_name.clone())
269 .collect();
270 if handler_names.is_empty() {
271 quote! {}
272 } else {
273 quote! {
274 fn is_allowed_private_km_query(method: &str) -> bool {
275 [ #( #handler_names, )* ].contains(&method)
276 }
277 }
278 }
279 };
280
281 let allowed_interactive_calls_impl = {
282 let handler_names: Vec<syn::Expr> = handlers
283 .iter()
284 .filter_map(|h| h.handler.as_ref())
285 .filter(|h| h.attrs.kind == HandlerKind::Call && h.attrs.allow_interactive)
286 .map(|h| h.attrs.rpc_name.clone())
287 .collect();
288 if handler_names.is_empty() {
289 quote! {}
290 } else {
291 quote! {
292 fn is_allowed_interactive_call(method: &str) -> bool {
293 [ #( #handler_names, )* ].contains(&method)
294 }
295 }
296 }
297 };
298
299 quote! {
300 #[automatically_derived]
301 impl #generics sdk::module::MethodHandler for #ty {
302 #prefetch_impl
303 #dispatch_call_impl
304 #dispatch_query_impl
305 #dispatch_message_result_impl
306 #supported_methods_impl
307 #expensive_queries_impl
308 #allowed_private_km_queries_impl
309 #allowed_interactive_calls_impl
310 }
311
312 #[automatically_derived]
313 impl #generics #ty {
314 #query_parameters_impl
315
316 #(#handler_items)*
317 }
318 }
319 }
320}
321
322#[derive(Clone)]
325struct ParsedImplItem {
326 item: syn::ImplItem,
327 handler: Option<HandlerInfo>,
328}
329
330#[derive(Clone, Debug)]
331struct HandlerInfo {
332 attrs: MethodHandlerAttr,
333 ident: syn::Ident,
335}
336
/// Kind of a handler, selected by the key of its `#[handler(<kind> = ...)]` attribute.
///
/// Equality is total on this field-less enum, so it derives `Eq` alongside `PartialEq`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum HandlerKind {
    /// Declared with `call = ...`; dispatched by the generated `dispatch_call`.
    Call,
    /// Declared with `query = ...`; dispatched by the generated `dispatch_query`.
    Query,
    /// Declared with `message_result = ...`; dispatched by the generated
    /// `dispatch_message_result`.
    MessageResult,
    /// Declared with `prefetch = ...`; matched against `call` handlers of the same RPC name by
    /// the generated `prefetch`.
    Prefetch,
}
344
345impl HandlerKind {
346 fn as_sdk_ident(&self) -> syn::Path {
347 match self {
348 HandlerKind::Call => parse_quote!(core_types::MethodHandlerKind::Call),
349 HandlerKind::Query => parse_quote!(core_types::MethodHandlerKind::Query),
350 HandlerKind::MessageResult => {
351 parse_quote!(core_types::MethodHandlerKind::MessageResult)
352 }
353 HandlerKind::Prefetch => {
354 unimplemented!("prefetch cannot be expressed in core::types::MethodHandlerKind")
355 }
356 }
357 }
358}
359
360#[derive(Debug, Clone, PartialEq)]
361struct MethodHandlerAttr {
362 kind: HandlerKind,
363 rpc_name: syn::Expr,
365 is_expensive: bool,
367 allow_private_km: bool,
370 allow_interactive: bool,
372 is_internal: bool,
374}
375impl syn::parse::Parse for MethodHandlerAttr {
376 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
377 let kind: syn::Ident = input.parse()?;
378 let kind = match kind.to_string().as_str() {
379 "call" => HandlerKind::Call,
380 "query" => HandlerKind::Query,
381 "message_result" => HandlerKind::MessageResult,
382 "prefetch" => HandlerKind::Prefetch,
383 _ => return Err(syn::Error::new(kind.span(), "invalid handler kind")),
384 };
385 let _: syn::token::Eq = input.parse()?;
386 let rpc_name: syn::Expr = input.parse()?;
387
388 let mut is_expensive = false;
390 let mut allow_private_km = false;
391 let mut allow_interactive = false;
392 let mut is_internal = false;
393 while input.peek(syn::token::Comma) {
394 let _: syn::token::Comma = input.parse()?;
395 let tag: syn::Ident = input.parse()?;
396
397 if tag == "expensive" {
398 if kind != HandlerKind::Query {
399 return Err(syn::Error::new(
400 tag.span(),
401 "`expensive` tag is only allowed on `query` handlers",
402 ));
403 }
404 is_expensive = true;
405 } else if tag == "allow_private_km" {
406 if kind != HandlerKind::Query {
407 return Err(syn::Error::new(
408 tag.span(),
409 "`allow_private_km` tag is only allowed on `query` handlers",
410 ));
411 }
412 allow_private_km = true;
413 } else if tag == "allow_interactive" {
414 if kind != HandlerKind::Call {
415 return Err(syn::Error::new(
416 tag.span(),
417 "`allow_interactive` tag is only allowed on `call` handlers",
418 ));
419 }
420 allow_interactive = true;
421 } else if tag == "internal" {
422 if kind != HandlerKind::Call {
423 return Err(syn::Error::new(
424 tag.span(),
425 "`internal` tag is only allowed on `call` handlers",
426 ));
427 }
428 is_internal = true;
429 } else {
430 return Err(syn::Error::new(
431 tag.span(),
432 "invalid handler tag; supported: `expensive`, `allow_private_km`, `allow_interactive`, `internal`",
433 ));
434 }
435 }
436
437 if !input.is_empty() {
438 return Err(syn::Error::new(input.span(), "unexpected extra tokens"));
439 }
440 Ok(Self {
441 kind,
442 rpc_name,
443 is_expensive,
444 allow_private_km,
445 allow_interactive,
446 is_internal,
447 })
448 }
449}
450
451fn parse_attrs(attrs: &[syn::Attribute]) -> Option<MethodHandlerAttr> {
452 let handler_meta = attrs.iter().find(|attr| attr.path().is_ident("handler"))?;
453 handler_meta
454 .parse_args()
455 .map_err(|err| {
456 emit_compile_error(format!(
457 "Unsupported format of #[handler(...)] attribute: {err}"
458 ))
459 })
460 .ok()
461}