use darling::{util::Flag, FromDeriveInput, FromField, FromVariant};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned};
use syn::{DeriveInput, Ident, Index, Member, Path};
use crate::generators::{self as gen, CodedVariant};
/// Container-level view of the enum the error derive is applied to, parsed by darling
/// from the derive input and its `#[sdk_error(...)]` attributes. Only enums are accepted
/// (`supports(enum_any)`).
#[derive(FromDeriveInput)]
#[darling(supports(enum_any), attributes(sdk_error))]
struct Error {
/// Name of the enum; used as the target type of the generated impls.
ident: Ident,
/// The enum's variants. Struct-shaped input is ignored (`darling::util::Ignored`).
data: darling::ast::Data<ErrorVariant, darling::util::Ignored>,
/// Optional path emitted as the expression returned by the generated `module_name`.
/// `derive_error` falls back to `MODULE_NAME` when this is absent.
module_name: Option<syn::Path>,
/// When present, variants without an explicit `code` are assigned the next free code
/// in declaration order.
#[darling(rename = "autonumber")]
autonumber: Flag,
/// When present, the generated `into_abort` body is simply `Ok(self)` instead of a
/// per-variant match.
#[darling(rename = "abort_self")]
abort_self: Flag,
}
/// A single enum variant together with its `#[sdk_error(...)]` options.
#[derive(FromVariant)]
#[darling(attributes(sdk_error))]
struct ErrorVariant {
/// Name of the variant.
ident: Ident,
/// The variant's fields (named, tuple or unit).
fields: darling::ast::Fields<ErrorField>,
/// Explicitly assigned error code, if any. Not consulted for `transparent` variants.
#[darling(rename = "code")]
code: Option<u32>,
/// When present, `module_name`, `code` and `into_abort` are forwarded to the variant's
/// single source field instead of being generated from this variant's own settings.
#[darling(rename = "transparent")]
transparent: Flag,
/// When present (and the variant is not `transparent`), the generated `into_abort` arm
/// is `Self::Variant(err) => Ok(err)`, i.e. the variant's sole tuple field is returned
/// as the abort error.
#[darling(rename = "abort")]
abort: Flag,
}
/// Exposes the variant's identifier and explicit code through the shared `CodedVariant`
/// trait from `crate::generators`.
impl CodedVariant for ErrorVariant {
// NOTE(review): presumably the name of the attribute key that carries the code; the
// trait is defined outside this file, so confirm against `crate::generators`.
const FIELD_NAME: &'static str = "code";
/// Returns the variant's name.
fn ident(&self) -> &Ident {
&self.ident
}
/// Returns the explicitly assigned code, or `None` when the attribute was not given.
fn code(&self) -> Option<u32> {
self.code
}
}
/// A field of an error variant. Only the `#[source]` and `#[from]` attributes are
/// forwarded by darling into `attrs`.
#[derive(FromField)]
#[darling(forward_attrs(source, from))]
struct ErrorField {
/// Field name; `None` is treated as a positional (tuple) field by `convert_variants`.
ident: Option<Ident>,
/// The forwarded `source`/`from` attributes. A non-empty list marks this field as the
/// error source of a `transparent` variant.
attrs: Vec<syn::Attribute>,
}
pub fn derive_error(input: DeriveInput) -> TokenStream {
let error = match Error::from_derive_input(&input) {
Ok(error) => error,
Err(e) => return e.write_errors(),
};
let error_ty_ident = &error.ident;
let module_name = error
.module_name
.unwrap_or_else(|| syn::parse_quote!(MODULE_NAME));
let (module_name_body, code_body, abort_body) = convert_variants(
&format_ident!("self"),
module_name,
&error.data.as_ref().take_enum().unwrap(),
error.autonumber.is_present(),
error.abort_self.is_present(),
);
let sdk_crate = gen::sdk_crate_path();
gen::wrap_in_const(quote! {
use #sdk_crate::{self as __sdk, error::Error as _};
#[automatically_derived]
impl __sdk::error::Error for #error_ty_ident {
fn module_name(&self) -> &str {
#module_name_body
}
fn code(&self) -> u32 {
#code_body
}
fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
#abort_body
}
}
#[automatically_derived]
impl From<#error_ty_ident> for __sdk::error::RuntimeError {
fn from(err: #error_ty_ident) -> Self {
Self::new(err.module_name(), err.code(), &err.to_string())
}
}
})
}
fn convert_variants(
enum_binding: &Ident,
module_name: Path,
variants: &[&ErrorVariant],
autonumber: bool,
abort_self: bool,
) -> (TokenStream, TokenStream, TokenStream) {
if variants.is_empty() {
return (quote!(#module_name), quote!(0), quote!(Err(#enum_binding)));
}
let mut next_autonumber = 0u32;
let mut reserved_numbers = std::collections::BTreeSet::new();
let (module_name_matches, (code_matches, abort_matches)): (Vec<_>, (Vec<_>, Vec<_>)) = variants
.iter()
.map(|variant| {
let variant_ident = &variant.ident;
if variant.transparent.is_present() {
let mut maybe_sources = variant
.fields
.iter()
.enumerate()
.filter(|(_, f)| (!f.attrs.is_empty()))
.map(|(i, f)| (i, f.ident.clone()));
let source = maybe_sources.next();
if maybe_sources.count() != 0 {
variant_ident
.span()
.unwrap()
.error("multiple error sources specified for variant")
.emit();
return (quote!(), (quote!(), quote!()));
}
if source.is_none() {
variant_ident
.span()
.unwrap()
.error("no source error specified for variant")
.emit();
return (quote!(), (quote!(), quote!()));
}
let (field_index, field_ident) = source.unwrap();
let field = match field_ident {
Some(ident) => Member::Named(ident),
None => Member::Unnamed(Index {
index: field_index as u32,
span: variant_ident.span(),
}),
};
let non_source_fields = variant
.fields
.iter()
.enumerate()
.filter(|(i, _)| i != &field_index)
.map(|(i, f)| {
let pat = match f.ident {
Some(ref ident) => Member::Named(ident.clone()),
None => Member::Unnamed(Index {
index: i as u32,
span: variant_ident.span(),
}),
};
let ident = Ident::new(&format!("__a{i}"), variant_ident.span());
let binding = quote!( #pat: #ident, );
binding
});
let non_source_field_bindings = non_source_fields.clone();
let source = quote!(source);
let module_name = quote_spanned!(variant_ident.span()=> #source.module_name());
let code = quote_spanned!(variant_ident.span()=> #source.code());
let abort_reclaim = quote!(Self::#variant_ident { #field: e, #(#non_source_fields)* });
let abort = quote_spanned!(variant_ident.span()=> #source.into_abort().map_err(|e| #abort_reclaim));
(
quote! {
Self::#variant_ident { #field: #source, .. } => #module_name,
},
(
quote! {
Self::#variant_ident { #field: #source, .. } => #code,
},
quote! {
Self::#variant_ident { #field: #source, #(#non_source_field_bindings)* } => #abort,
},
),
)
} else {
let code = match variant.code {
Some(code) => {
if reserved_numbers.contains(&code) {
variant_ident
.span()
.unwrap()
.error(format!("code {code} already used"))
.emit();
return (quote!(), (quote!(), quote!()));
}
reserved_numbers.insert(code);
code
}
None if autonumber => {
let mut reserved_successors = reserved_numbers.range(next_autonumber..);
while reserved_successors.next() == Some(&next_autonumber) {
next_autonumber += 1;
}
let code = next_autonumber;
reserved_numbers.insert(code);
next_autonumber += 1;
code
}
None => {
variant_ident
.span()
.unwrap()
.error("missing `code` for variant")
.emit();
return (quote!(), (quote!(), quote!()));
}
};
let abort = if variant.abort.is_present() {
quote!{
Self::#variant_ident(err) => Ok(err),
}
} else {
quote!{
Self::#variant_ident { .. } => Err(#enum_binding),
}
};
(
quote! {
Self::#variant_ident { .. } => #module_name,
},
(
quote! {
Self::#variant_ident { .. } => #code,
},
abort,
),
)
}
})
.unzip();
let abort_body = if abort_self {
quote!(Ok(self))
} else {
quote! {
match #enum_binding {
#(#abort_matches)*
}
}
};
(
quote! {
match #enum_binding {
#(#module_name_matches)*
}
},
quote! {
match #enum_binding {
#(#code_matches)*
}
},
abort_body,
)
}
// Golden-output tests: each one feeds an enum definition to `derive_error` and compares
// the expansion, parsed as a statement, against the expected token tree.
#[cfg(test)]
mod tests {
// Autonumbering with one explicit code: `Error2` reserves 2, so the automatically
// numbered variants receive 0, 1, 3 and 4 in declaration order. The `abort` variant
// yields `Ok(err)` from `into_abort`, while all other variants yield `Err(self)`.
#[test]
fn generate_error_impl_auto_abort() {
let expected: syn::Stmt = syn::parse_quote!(
const _: () = {
use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
#[automatically_derived]
impl __sdk::error::Error for Error {
fn module_name(&self) -> &str {
match self {
Self::Error0 { .. } => MODULE_NAME,
Self::Error2 { .. } => MODULE_NAME,
Self::Error1 { .. } => MODULE_NAME,
Self::Error3 { .. } => MODULE_NAME,
Self::ErrorAbort { .. } => MODULE_NAME,
}
}
fn code(&self) -> u32 {
match self {
Self::Error0 { .. } => 0u32,
Self::Error2 { .. } => 2u32,
Self::Error1 { .. } => 1u32,
Self::Error3 { .. } => 3u32,
Self::ErrorAbort { .. } => 4u32,
}
}
fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
match self {
Self::Error0 { .. } => Err(self),
Self::Error2 { .. } => Err(self),
Self::Error1 { .. } => Err(self),
Self::Error3 { .. } => Err(self),
Self::ErrorAbort(err) => Ok(err),
}
}
}
#[automatically_derived]
impl From<Error> for __sdk::error::RuntimeError {
fn from(err: Error) -> Self {
Self::new(err.module_name(), err.code(), &err.to_string())
}
}
};
);
let input: syn::DeriveInput = syn::parse_quote!(
#[derive(Error)]
#[sdk_error(autonumber)]
pub enum Error {
Error0,
#[sdk_error(code = 2)]
Error2 {
payload: Vec<u8>,
},
Error1(String),
Error3,
#[sdk_error(abort)]
ErrorAbort(sdk::dispatcher::Error),
}
);
let error_derivation = super::derive_error(input);
let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();
crate::assert_empty_diff!(actual, expected);
}
// An enum without variants and with an explicit `module_name`: the generated methods
// have constant bodies (the given module name, code 0, `Err(self)`) rather than matches.
#[test]
fn generate_error_impl_manual() {
let expected: syn::Stmt = syn::parse_quote!(
const _: () = {
use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
#[automatically_derived]
impl __sdk::error::Error for Error {
fn module_name(&self) -> &str {
THE_MODULE_NAME
}
fn code(&self) -> u32 {
0
}
fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
Err(self)
}
}
#[automatically_derived]
impl From<Error> for __sdk::error::RuntimeError {
fn from(err: Error) -> Self {
Self::new(err.module_name(), err.code(), &err.to_string())
}
}
};
);
let input: syn::DeriveInput = syn::parse_quote!(
#[derive(Error)]
#[sdk_error(autonumber, module_name = "THE_MODULE_NAME")]
pub enum Error {}
);
let error_derivation = super::derive_error(input);
let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();
crate::assert_empty_diff!(actual, expected);
}
// A `transparent` variant with a `#[from]` field: module name, code and abort handling
// are all forwarded to the wrapped error, and a declined abort is re-wrapped in `Foo`.
#[test]
fn generate_error_impl_from() {
let expected: syn::Stmt = syn::parse_quote!(
const _: () = {
use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
#[automatically_derived]
impl __sdk::error::Error for Error {
fn module_name(&self) -> &str {
match self {
Self::Foo { 0: source, .. } => source.module_name(),
}
}
fn code(&self) -> u32 {
match self {
Self::Foo { 0: source, .. } => source.code(),
}
}
fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
match self {
Self::Foo { 0: source } => {
source.into_abort().map_err(|e| Self::Foo { 0: e })
}
}
}
}
#[automatically_derived]
impl From<Error> for __sdk::error::RuntimeError {
fn from(err: Error) -> Self {
Self::new(err.module_name(), err.code(), &err.to_string())
}
}
};
);
let input: syn::DeriveInput = syn::parse_quote!(
#[derive(Error)]
#[sdk_error(module_name = "THE_MODULE_NAME")]
pub enum Error {
#[sdk_error(transparent)]
Foo(#[from] AnotherError),
}
);
let error_derivation = super::derive_error(input);
let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();
crate::assert_empty_diff!(actual, expected);
}
// The container-level `abort_self` option: `into_abort` becomes `Ok(self)`, while
// `module_name` and `code` are still generated per variant from the explicit codes.
#[test]
fn generate_error_impl_abort_self() {
let expected: syn::Stmt = syn::parse_quote!(
const _: () = {
use ::oasis_runtime_sdk::{self as __sdk, error::Error as _};
#[automatically_derived]
impl __sdk::error::Error for Error {
fn module_name(&self) -> &str {
match self {
Self::Foo { .. } => THE_MODULE_NAME,
Self::Bar { .. } => THE_MODULE_NAME,
}
}
fn code(&self) -> u32 {
match self {
Self::Foo { .. } => 1u32,
Self::Bar { .. } => 2u32,
}
}
fn into_abort(self) -> Result<__sdk::dispatcher::Error, Self> {
Ok(self)
}
}
#[automatically_derived]
impl From<Error> for __sdk::error::RuntimeError {
fn from(err: Error) -> Self {
Self::new(err.module_name(), err.code(), &err.to_string())
}
}
};
);
let input: syn::DeriveInput = syn::parse_quote!(
#[derive(Error)]
#[sdk_error(module_name = "THE_MODULE_NAME", abort_self)]
pub enum Error {
#[sdk_error(code = 1)]
Foo,
#[sdk_error(code = 2)]
Bar,
}
);
let error_derivation = super::derive_error(input);
let actual: syn::Stmt = syn::parse2(error_derivation).unwrap();
crate::assert_empty_diff!(actual, expected);
}
}