oasis_runtime_sdk_macros/
error_derive.rs

1use darling::{util::Flag, FromDeriveInput, FromField, FromVariant};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote, quote_spanned};
4use syn::{DeriveInput, Ident, Index, Member, Path};
5
6use crate::generators::{self as gen, CodedVariant};
7
8#[derive(FromDeriveInput)]
9#[darling(supports(enum_any), attributes(sdk_error))]
10struct Error {
11    ident: Ident,
12
13    data: darling::ast::Data<ErrorVariant, darling::util::Ignored>,
14
15    /// The path to a const set to the module name.
16    module_name: Option<syn::Path>,
17
18    /// Whether to sequentially autonumber the error codes.
19    /// This option exists as a convenience for runtimes that
20    /// only append errors or release only breaking changes.
21    #[darling(rename = "autonumber")]
22    autonumber: Flag,
23
24    /// Whether the `into_abort` function should return itself. This can only be used when the type
25    /// being annotated is the dispatcher error type so it is only for internal use.
26    #[darling(rename = "abort_self")]
27    abort_self: Flag,
28}
29
30#[derive(FromVariant)]
31#[darling(attributes(sdk_error))]
32struct ErrorVariant {
33    ident: Ident,
34
35    fields: darling::ast::Fields<ErrorField>,
36
37    /// The explicit ID of the error code. Overrides any autonumber set on the error enum.
38    #[darling(rename = "code")]
39    code: Option<u32>,
40
41    #[darling(rename = "transparent")]
42    transparent: Flag,
43
44    #[darling(rename = "abort")]
45    abort: Flag,
46}
47
48impl CodedVariant for ErrorVariant {
49    const FIELD_NAME: &'static str = "code";
50
51    fn ident(&self) -> &Ident {
52        &self.ident
53    }
54
55    fn code(&self) -> Option<u32> {
56        self.code
57    }
58}
59
60#[derive(FromField)]
61#[darling(forward_attrs(source, from))]
62struct ErrorField {
63    ident: Option<Ident>,
64
65    attrs: Vec<syn::Attribute>,
66}
67
68pub fn derive_error(input: DeriveInput) -> TokenStream {
69    let error = match Error::from_derive_input(&input) {
70        Ok(error) => error,
71        Err(e) => return e.write_errors(),
72    };
73
74    let error_ty_ident = &error.ident;
75
76    let module_name = error
77        .module_name
78        .unwrap_or_else(|| syn::parse_quote!(MODULE_NAME));
79
80    let (module_name_body, code_body, abort_body) = convert_variants(
81        &format_ident!("self"),
82        module_name,
83        &error.data.as_ref().take_enum().unwrap(),
84        error.autonumber.is_present(),
85        error.abort_self.is_present(),
86    );
87
88    let sdk_crate = gen::sdk_crate_path();
89
90    gen::wrap_in_const(quote! {
91        use #sdk_crate::{self as __sdk, error::Error as _};
92
93        #[automatically_derived]
94        impl __sdk::error::Error for #error_ty_ident {
95            fn module_name(&self) -> &str {
96                #module_name_body
97            }
98
99            fn code(&self) -> u32 {
100                #code_body
101            }
102
103            fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
104                #abort_body
105            }
106        }
107
108        #[automatically_derived]
109        impl From<#error_ty_ident> for __sdk::error::RuntimeError {
110            fn from(err: #error_ty_ident) -> Self {
111                Self::new(err.module_name(), err.code(), &err.to_string())
112            }
113        }
114    })
115}
116
117fn convert_variants(
118    enum_binding: &Ident,
119    module_name: Path,
120    variants: &[&ErrorVariant],
121    autonumber: bool,
122    abort_self: bool,
123) -> (TokenStream, TokenStream, TokenStream) {
124    if variants.is_empty() {
125        return (quote!(#module_name), quote!(0), quote!(Err(#enum_binding)));
126    }
127
128    let mut next_autonumber = 0u32;
129    let mut reserved_numbers = std::collections::BTreeSet::new();
130
131    let (module_name_matches, (code_matches, abort_matches)): (Vec<_>, (Vec<_>, Vec<_>)) = variants
132        .iter()
133        .map(|variant| {
134            let variant_ident = &variant.ident;
135
136            if variant.transparent.is_present() {
137                // Transparently forward everything to the source.
138                let mut maybe_sources = variant
139                    .fields
140                    .iter()
141                    .enumerate()
142                    .filter(|(_, f)| (!f.attrs.is_empty()))
143                    .map(|(i, f)| (i, f.ident.clone()));
144                let source = maybe_sources.next();
145                if maybe_sources.count() != 0 {
146                    variant_ident
147                        .span()
148                        .unwrap()
149                        .error("multiple error sources specified for variant")
150                        .emit();
151                    return (quote!(), (quote!(), quote!()));
152                }
153                if source.is_none() {
154                    variant_ident
155                        .span()
156                        .unwrap()
157                        .error("no source error specified for variant")
158                        .emit();
159                    return (quote!(), (quote!(), quote!()));
160                }
161                let (field_index, field_ident) = source.unwrap();
162
163                let field = match field_ident {
164                    Some(ident) => Member::Named(ident),
165                    None => Member::Unnamed(Index {
166                        index: field_index as u32,
167                        span: variant_ident.span(),
168                    }),
169                };
170
171                // Get all other fields that are needed for forwarding in abort variants.
172                let non_source_fields = variant
173                    .fields
174                    .iter()
175                    .enumerate()
176                    .filter(|(i, _)| i != &field_index)
177                    .map(|(i, f)| {
178                        let pat = match f.ident {
179                            Some(ref ident) => Member::Named(ident.clone()),
180                            None => Member::Unnamed(Index {
181                                index: i as u32,
182                                span: variant_ident.span(),
183                            }),
184                        };
185                        let ident = Ident::new(&format!("__a{i}"), variant_ident.span());
186                        let binding = quote!( #pat: #ident, );
187
188                        binding
189                    });
190                let non_source_field_bindings = non_source_fields.clone();
191
192                let source = quote!(source);
193                let module_name = quote_spanned!(variant_ident.span()=> #source.module_name());
194                let code = quote_spanned!(variant_ident.span()=> #source.code());
195                let abort_reclaim = quote!(Self::#variant_ident { #field: e, #(#non_source_fields)* });
196                let abort = quote_spanned!(variant_ident.span()=> #source.into_abort().map_err(|e| #abort_reclaim));
197
198                (
199                    quote! {
200                        Self::#variant_ident { #field: #source, .. } => #module_name,
201                    },
202                    (
203                        quote! {
204                            Self::#variant_ident { #field: #source, .. } => #code,
205                        },
206                        quote! {
207                            Self::#variant_ident { #field: #source, #(#non_source_field_bindings)* } => #abort,
208                        },
209                    ),
210                )
211            } else {
212                // Regular case without forwarding.
213                let code = match variant.code {
214                    Some(code) => {
215                        if reserved_numbers.contains(&code) {
216                            variant_ident
217                                .span()
218                                .unwrap()
219                                .error(format!("code {code} already used"))
220                                .emit();
221                            return (quote!(), (quote!(), quote!()));
222                        }
223                        reserved_numbers.insert(code);
224                        code
225                    }
226                    None if autonumber => {
227                        let mut reserved_successors = reserved_numbers.range(next_autonumber..);
228                        while reserved_successors.next() == Some(&next_autonumber) {
229                            next_autonumber += 1;
230                        }
231                        let code = next_autonumber;
232                        reserved_numbers.insert(code);
233                        next_autonumber += 1;
234                        code
235                    }
236                    None => {
237                        variant_ident
238                            .span()
239                            .unwrap()
240                            .error("missing `code` for variant")
241                            .emit();
242                        return (quote!(), (quote!(), quote!()));
243                    }
244                };
245
246                let abort = if variant.abort.is_present() {
247                    quote!{
248                        Self::#variant_ident(err) => Ok(err),
249                    }
250                } else {
251                    quote!{
252                        Self::#variant_ident { .. } => Err(#enum_binding),
253                    }
254                };
255
256                (
257                    quote! {
258                        Self::#variant_ident { .. } => #module_name,
259                    },
260                    (
261                        quote! {
262                            Self::#variant_ident { .. } => #code,
263                        },
264                        abort,
265                    ),
266                )
267            }
268        })
269        .unzip();
270
271    let abort_body = if abort_self {
272        quote!(Ok(self))
273    } else {
274        quote! {
275            match #enum_binding {
276                #(#abort_matches)*
277            }
278        }
279    };
280
281    (
282        quote! {
283            match #enum_binding {
284                #(#module_name_matches)*
285            }
286        },
287        quote! {
288            match #enum_binding {
289                #(#code_matches)*
290            }
291        },
292        abort_body,
293    )
294}
295
/// Expansion tests: each test derives `Error` for a sample enum and compares the generated
/// tokens against an expected token tree via the crate's `assert_empty_diff!` helper.
#[cfg(test)]
mod tests {
    /// With `autonumber`, variants without an explicit code get 0, 1, 3, 4 in declaration
    /// order (2 is claimed by `Error2`), and the `abort` variant's `into_abort` arm returns
    /// the wrapped error as `Ok`.
    #[test]
    fn generate_error_impl_auto_abort() {
        let expected: syn::Stmt = syn::parse_quote!(
            const _: () = {
                use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
                #[automatically_derived]
                impl __sdk::error::Error for Error {
                    fn module_name(&self) -> &str {
                        match self {
                            Self::Error0 { .. } => MODULE_NAME,
                            Self::Error2 { .. } => MODULE_NAME,
                            Self::Error1 { .. } => MODULE_NAME,
                            Self::Error3 { .. } => MODULE_NAME,
                            Self::ErrorAbort { .. } => MODULE_NAME,
                        }
                    }
                    fn code(&self) -> u32 {
                        match self {
                            Self::Error0 { .. } => 0u32,
                            Self::Error2 { .. } => 2u32,
                            Self::Error1 { .. } => 1u32,
                            Self::Error3 { .. } => 3u32,
                            Self::ErrorAbort { .. } => 4u32,
                        }
                    }
                    fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
                        match self {
                            Self::Error0 { .. } => Err(self),
                            Self::Error2 { .. } => Err(self),
                            Self::Error1 { .. } => Err(self),
                            Self::Error3 { .. } => Err(self),
                            Self::ErrorAbort(err) => Ok(err),
                        }
                    }
                }
                #[automatically_derived]
                impl From<Error> for __sdk::error::RuntimeError {
                    fn from(err: Error) -> Self {
                        Self::new(err.module_name(), err.code(), &err.to_string())
                    }
                }
            };
        );

        let input: syn::DeriveInput = syn::parse_quote!(
            #[derive(Error)]
            #[sdk_error(autonumber)]
            pub enum Error {
                Error0,
                #[sdk_error(code = 2)]
                Error2 {
                    payload: Vec<u8>,
                },
                Error1(String),
                Error3,
                #[sdk_error(abort)]
                ErrorAbort(sdk::dispatcher::Error),
            }
        );
        let error_derivation = super::derive_error(input);
        let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();

        crate::assert_empty_diff!(actual, expected);
    }

    /// An empty enum gets constant method bodies, and `module_name` replaces the default
    /// `MODULE_NAME` path.
    #[test]
    fn generate_error_impl_manual() {
        let expected: syn::Stmt = syn::parse_quote!(
            const _: () = {
                use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
                #[automatically_derived]
                impl __sdk::error::Error for Error {
                    fn module_name(&self) -> &str {
                        THE_MODULE_NAME
                    }
                    fn code(&self) -> u32 {
                        0
                    }
                    fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
                        Err(self)
                    }
                }
                #[automatically_derived]
                impl From<Error> for __sdk::error::RuntimeError {
                    fn from(err: Error) -> Self {
                        Self::new(err.module_name(), err.code(), &err.to_string())
                    }
                }
            };
        );

        let input: syn::DeriveInput = syn::parse_quote!(
            #[derive(Error)]
            #[sdk_error(autonumber, module_name = "THE_MODULE_NAME")]
            pub enum Error {}
        );
        let error_derivation = super::derive_error(input);
        let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();

        crate::assert_empty_diff!(actual, expected);
    }

    /// A `transparent` variant forwards every method to its `#[from]` field (tuple field
    /// `0`), and `into_abort` rebuilds the variant when the source is not an abort.
    #[test]
    fn generate_error_impl_from() {
        let expected: syn::Stmt = syn::parse_quote!(
            const _: () = {
                use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
                #[automatically_derived]
                impl __sdk::error::Error for Error {
                    fn module_name(&self) -> &str {
                        match self {
                            Self::Foo { 0: source, .. } => source.module_name(),
                        }
                    }
                    fn code(&self) -> u32 {
                        match self {
                            Self::Foo { 0: source, .. } => source.code(),
                        }
                    }
                    fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
                        match self {
                            Self::Foo { 0: source } => {
                                source.into_abort().map_err(|e| Self::Foo { 0: e })
                            }
                        }
                    }
                }
                #[automatically_derived]
                impl From<Error> for __sdk::error::RuntimeError {
                    fn from(err: Error) -> Self {
                        Self::new(err.module_name(), err.code(), &err.to_string())
                    }
                }
            };
        );

        let input: syn::DeriveInput = syn::parse_quote!(
            #[derive(Error)]
            #[sdk_error(module_name = "THE_MODULE_NAME")]
            pub enum Error {
                #[sdk_error(transparent)]
                Foo(#[from] AnotherError),
            }
        );
        let error_derivation = super::derive_error(input);
        let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();

        crate::assert_empty_diff!(actual, expected);
    }

    /// With `abort_self`, `into_abort` returns `Ok(self)` instead of matching on variants.
    #[test]
    fn generate_error_impl_abort_self() {
        let expected: syn::Stmt = syn::parse_quote!(
            const _: () = {
                use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
                #[automatically_derived]
                impl __sdk::error::Error for Error {
                    fn module_name(&self) -> &str {
                        match self {
                            Self::Foo { .. } => THE_MODULE_NAME,
                            Self::Bar { .. } => THE_MODULE_NAME,
                        }
                    }
                    fn code(&self) -> u32 {
                        match self {
                            Self::Foo { .. } => 1u32,
                            Self::Bar { .. } => 2u32,
                        }
                    }
                    fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
                        Ok(self)
                    }
                }
                #[automatically_derived]
                impl From<Error> for __sdk::error::RuntimeError {
                    fn from(err: Error) -> Self {
                        Self::new(err.module_name(), err.code(), &err.to_string())
                    }
                }
            };
        );

        let input: syn::DeriveInput = syn::parse_quote!(
            #[derive(Error)]
            #[sdk_error(module_name = "THE_MODULE_NAME", abort_self)]
            pub enum Error {
                #[sdk_error(code = 1)]
                Foo,
                #[sdk_error(code = 2)]
                Bar,
            }
        );
        let error_derivation = super::derive_error(input);
        let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();

        crate::assert_empty_diff!(actual, expected);
    }
}
493        crate::assert_empty_diff!(actual, expected);
494    }
495}