oasis_runtime_sdk_macros/
error_derive.rs1use darling::{util::Flag, FromDeriveInput, FromField, FromVariant};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote, quote_spanned};
4use syn::{DeriveInput, Ident, Index, Member, Path};
5
6use crate::generators::{self as gen, CodedVariant};
7
8#[derive(FromDeriveInput)]
9#[darling(supports(enum_any), attributes(sdk_error))]
10struct Error {
11 ident: Ident,
12
13 data: darling::ast::Data<ErrorVariant, darling::util::Ignored>,
14
15 module_name: Option<syn::Path>,
17
18 #[darling(rename = "autonumber")]
22 autonumber: Flag,
23
24 #[darling(rename = "abort_self")]
27 abort_self: Flag,
28}
29
30#[derive(FromVariant)]
31#[darling(attributes(sdk_error))]
32struct ErrorVariant {
33 ident: Ident,
34
35 fields: darling::ast::Fields<ErrorField>,
36
37 #[darling(rename = "code")]
39 code: Option<u32>,
40
41 #[darling(rename = "transparent")]
42 transparent: Flag,
43
44 #[darling(rename = "abort")]
45 abort: Flag,
46}
47
48impl CodedVariant for ErrorVariant {
49 const FIELD_NAME: &'static str = "code";
50
51 fn ident(&self) -> &Ident {
52 &self.ident
53 }
54
55 fn code(&self) -> Option<u32> {
56 self.code
57 }
58}
59
60#[derive(FromField)]
61#[darling(forward_attrs(source, from))]
62struct ErrorField {
63 ident: Option<Ident>,
64
65 attrs: Vec<syn::Attribute>,
66}
67
68pub fn derive_error(input: DeriveInput) -> TokenStream {
69 let error = match Error::from_derive_input(&input) {
70 Ok(error) => error,
71 Err(e) => return e.write_errors(),
72 };
73
74 let error_ty_ident = &error.ident;
75
76 let module_name = error
77 .module_name
78 .unwrap_or_else(|| syn::parse_quote!(MODULE_NAME));
79
80 let (module_name_body, code_body, abort_body) = convert_variants(
81 &format_ident!("self"),
82 module_name,
83 &error.data.as_ref().take_enum().unwrap(),
84 error.autonumber.is_present(),
85 error.abort_self.is_present(),
86 );
87
88 let sdk_crate = gen::sdk_crate_path();
89
90 gen::wrap_in_const(quote! {
91 use #sdk_crate::{self as __sdk, error::Error as _};
92
93 #[automatically_derived]
94 impl __sdk::error::Error for #error_ty_ident {
95 fn module_name(&self) -> &str {
96 #module_name_body
97 }
98
99 fn code(&self) -> u32 {
100 #code_body
101 }
102
103 fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
104 #abort_body
105 }
106 }
107
108 #[automatically_derived]
109 impl From<#error_ty_ident> for __sdk::error::RuntimeError {
110 fn from(err: #error_ty_ident) -> Self {
111 Self::new(err.module_name(), err.code(), &err.to_string())
112 }
113 }
114 })
115}
116
117fn convert_variants(
118 enum_binding: &Ident,
119 module_name: Path,
120 variants: &[&ErrorVariant],
121 autonumber: bool,
122 abort_self: bool,
123) -> (TokenStream, TokenStream, TokenStream) {
124 if variants.is_empty() {
125 return (quote!(#module_name), quote!(0), quote!(Err(#enum_binding)));
126 }
127
128 let mut next_autonumber = 0u32;
129 let mut reserved_numbers = std::collections::BTreeSet::new();
130
131 let (module_name_matches, (code_matches, abort_matches)): (Vec<_>, (Vec<_>, Vec<_>)) = variants
132 .iter()
133 .map(|variant| {
134 let variant_ident = &variant.ident;
135
136 if variant.transparent.is_present() {
137 let mut maybe_sources = variant
139 .fields
140 .iter()
141 .enumerate()
142 .filter(|(_, f)| (!f.attrs.is_empty()))
143 .map(|(i, f)| (i, f.ident.clone()));
144 let source = maybe_sources.next();
145 if maybe_sources.count() != 0 {
146 variant_ident
147 .span()
148 .unwrap()
149 .error("multiple error sources specified for variant")
150 .emit();
151 return (quote!(), (quote!(), quote!()));
152 }
153 if source.is_none() {
154 variant_ident
155 .span()
156 .unwrap()
157 .error("no source error specified for variant")
158 .emit();
159 return (quote!(), (quote!(), quote!()));
160 }
161 let (field_index, field_ident) = source.unwrap();
162
163 let field = match field_ident {
164 Some(ident) => Member::Named(ident),
165 None => Member::Unnamed(Index {
166 index: field_index as u32,
167 span: variant_ident.span(),
168 }),
169 };
170
171 let non_source_fields = variant
173 .fields
174 .iter()
175 .enumerate()
176 .filter(|(i, _)| i != &field_index)
177 .map(|(i, f)| {
178 let pat = match f.ident {
179 Some(ref ident) => Member::Named(ident.clone()),
180 None => Member::Unnamed(Index {
181 index: i as u32,
182 span: variant_ident.span(),
183 }),
184 };
185 let ident = Ident::new(&format!("__a{i}"), variant_ident.span());
186 let binding = quote!( #pat: #ident, );
187
188 binding
189 });
190 let non_source_field_bindings = non_source_fields.clone();
191
192 let source = quote!(source);
193 let module_name = quote_spanned!(variant_ident.span()=> #source.module_name());
194 let code = quote_spanned!(variant_ident.span()=> #source.code());
195 let abort_reclaim = quote!(Self::#variant_ident { #field: e, #(#non_source_fields)* });
196 let abort = quote_spanned!(variant_ident.span()=> #source.into_abort().map_err(|e| #abort_reclaim));
197
198 (
199 quote! {
200 Self::#variant_ident { #field: #source, .. } => #module_name,
201 },
202 (
203 quote! {
204 Self::#variant_ident { #field: #source, .. } => #code,
205 },
206 quote! {
207 Self::#variant_ident { #field: #source, #(#non_source_field_bindings)* } => #abort,
208 },
209 ),
210 )
211 } else {
212 let code = match variant.code {
214 Some(code) => {
215 if reserved_numbers.contains(&code) {
216 variant_ident
217 .span()
218 .unwrap()
219 .error(format!("code {code} already used"))
220 .emit();
221 return (quote!(), (quote!(), quote!()));
222 }
223 reserved_numbers.insert(code);
224 code
225 }
226 None if autonumber => {
227 let mut reserved_successors = reserved_numbers.range(next_autonumber..);
228 while reserved_successors.next() == Some(&next_autonumber) {
229 next_autonumber += 1;
230 }
231 let code = next_autonumber;
232 reserved_numbers.insert(code);
233 next_autonumber += 1;
234 code
235 }
236 None => {
237 variant_ident
238 .span()
239 .unwrap()
240 .error("missing `code` for variant")
241 .emit();
242 return (quote!(), (quote!(), quote!()));
243 }
244 };
245
246 let abort = if variant.abort.is_present() {
247 quote!{
248 Self::#variant_ident(err) => Ok(err),
249 }
250 } else {
251 quote!{
252 Self::#variant_ident { .. } => Err(#enum_binding),
253 }
254 };
255
256 (
257 quote! {
258 Self::#variant_ident { .. } => #module_name,
259 },
260 (
261 quote! {
262 Self::#variant_ident { .. } => #code,
263 },
264 abort,
265 ),
266 )
267 }
268 })
269 .unzip();
270
271 let abort_body = if abort_self {
272 quote!(Ok(self))
273 } else {
274 quote! {
275 match #enum_binding {
276 #(#abort_matches)*
277 }
278 }
279 };
280
281 (
282 quote! {
283 match #enum_binding {
284 #(#module_name_matches)*
285 }
286 },
287 quote! {
288 match #enum_binding {
289 #(#code_matches)*
290 }
291 },
292 abort_body,
293 )
294}
295
// Expansion snapshot tests for `derive_error`: each test derives on a sample enum, parses the
// generated `const _: () = { ... };` item as a `syn::Stmt`, and compares it against a
// hand-written expectation with `crate::assert_empty_diff!`.
#[cfg(test)]
mod tests {
    // `autonumber` fills codes 0, 1, 3, 4 around the explicit `code = 2`, and the
    // `abort` variant unwraps its payload in `into_abort` while the others return `Err(self)`.
    #[test]
    fn generate_error_impl_auto_abort() {
        let expected: syn::Stmt = syn::parse_quote!(
            const _: () = {
                use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
                #[automatically_derived]
                impl __sdk::error::Error for Error {
                    fn module_name(&self) -> &str {
                        match self {
                            Self::Error0 { .. } => MODULE_NAME,
                            Self::Error2 { .. } => MODULE_NAME,
                            Self::Error1 { .. } => MODULE_NAME,
                            Self::Error3 { .. } => MODULE_NAME,
                            Self::ErrorAbort { .. } => MODULE_NAME,
                        }
                    }
                    fn code(&self) -> u32 {
                        match self {
                            Self::Error0 { .. } => 0u32,
                            Self::Error2 { .. } => 2u32,
                            Self::Error1 { .. } => 1u32,
                            Self::Error3 { .. } => 3u32,
                            Self::ErrorAbort { .. } => 4u32,
                        }
                    }
                    fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
                        match self {
                            Self::Error0 { .. } => Err(self),
                            Self::Error2 { .. } => Err(self),
                            Self::Error1 { .. } => Err(self),
                            Self::Error3 { .. } => Err(self),
                            Self::ErrorAbort(err) => Ok(err),
                        }
                    }
                }
                #[automatically_derived]
                impl From<Error> for __sdk::error::RuntimeError {
                    fn from(err: Error) -> Self {
                        Self::new(err.module_name(), err.code(), &err.to_string())
                    }
                }
            };
        );

        let input: syn::DeriveInput = syn::parse_quote!(
            #[derive(Error)]
            #[sdk_error(autonumber)]
            pub enum Error {
                Error0,
                #[sdk_error(code = 2)]
                Error2 {
                    payload: Vec<u8>,
                },
                Error1(String),
                Error3,
                #[sdk_error(abort)]
                ErrorAbort(sdk::dispatcher::Error),
            }
        );
        let error_derivation = super::derive_error(input);
        let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();

        crate::assert_empty_diff!(actual, expected);
    }

    // An enum without variants gets constant method bodies and uses the given `module_name`.
    #[test]
    fn generate_error_impl_manual() {
        let expected: syn::Stmt = syn::parse_quote!(
            const _: () = {
                use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
                #[automatically_derived]
                impl __sdk::error::Error for Error {
                    fn module_name(&self) -> &str {
                        THE_MODULE_NAME
                    }
                    fn code(&self) -> u32 {
                        0
                    }
                    fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
                        Err(self)
                    }
                }
                #[automatically_derived]
                impl From<Error> for __sdk::error::RuntimeError {
                    fn from(err: Error) -> Self {
                        Self::new(err.module_name(), err.code(), &err.to_string())
                    }
                }
            };
        );

        let input: syn::DeriveInput = syn::parse_quote!(
            #[derive(Error)]
            #[sdk_error(autonumber, module_name = "THE_MODULE_NAME")]
            pub enum Error {}
        );
        let error_derivation = super::derive_error(input);
        let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();

        crate::assert_empty_diff!(actual, expected);
    }

    // A `transparent` variant delegates `module_name`, `code` and `into_abort` to its `#[from]`
    // field (addressed positionally as `0`), rebuilding the variant if the source does not abort.
    #[test]
    fn generate_error_impl_from() {
        let expected: syn::Stmt = syn::parse_quote!(
            const _: () = {
                use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
                #[automatically_derived]
                impl __sdk::error::Error for Error {
                    fn module_name(&self) -> &str {
                        match self {
                            Self::Foo { 0: source, .. } => source.module_name(),
                        }
                    }
                    fn code(&self) -> u32 {
                        match self {
                            Self::Foo { 0: source, .. } => source.code(),
                        }
                    }
                    fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
                        match self {
                            Self::Foo { 0: source } => {
                                source.into_abort().map_err(|e| Self::Foo { 0: e })
                            }
                        }
                    }
                }
                #[automatically_derived]
                impl From<Error> for __sdk::error::RuntimeError {
                    fn from(err: Error) -> Self {
                        Self::new(err.module_name(), err.code(), &err.to_string())
                    }
                }
            };
        );

        let input: syn::DeriveInput = syn::parse_quote!(
            #[derive(Error)]
            #[sdk_error(module_name = "THE_MODULE_NAME")]
            pub enum Error {
                #[sdk_error(transparent)]
                Foo(#[from] AnotherError),
            }
        );
        let error_derivation = super::derive_error(input);
        let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();

        crate::assert_empty_diff!(actual, expected);
    }

    // `abort_self` replaces the per-variant `into_abort` match with `Ok(self)`.
    #[test]
    fn generate_error_impl_abort_self() {
        let expected: syn::Stmt = syn::parse_quote!(
            const _: () = {
                use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
                #[automatically_derived]
                impl __sdk::error::Error for Error {
                    fn module_name(&self) -> &str {
                        match self {
                            Self::Foo { .. } => THE_MODULE_NAME,
                            Self::Bar { .. } => THE_MODULE_NAME,
                        }
                    }
                    fn code(&self) -> u32 {
                        match self {
                            Self::Foo { .. } => 1u32,
                            Self::Bar { .. } => 2u32,
                        }
                    }
                    fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
                        Ok(self)
                    }
                }
                #[automatically_derived]
                impl From<Error> for __sdk::error::RuntimeError {
                    fn from(err: Error) -> Self {
                        Self::new(err.module_name(), err.code(), &err.to_string())
                    }
                }
            };
        );

        let input: syn::DeriveInput = syn::parse_quote!(
            #[derive(Error)]
            #[sdk_error(module_name = "THE_MODULE_NAME", abort_self)]
            pub enum Error {
                #[sdk_error(code = 1)]
                Foo,
                #[sdk_error(code = 2)]
                Bar,
            }
        );
        let error_derivation = super::derive_error(input);
        let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();

        crate::assert_empty_diff!(actual, expected);
    }
}