typhoon_errors_macro/
lib.rs

1use {
2    proc_macro::TokenStream,
3    quote::{quote, ToTokens},
4    syn::{
5        parse::Parse, parse_macro_input, Attribute, Data, DeriveInput, Expr, ExprLit, Ident, Lit,
6        LitStr,
7    },
8};
9
10/// Derive macro for generating error implementations
11///
12/// Usage:
13/// ```rust
14/// # use {
15/// #    pinocchio::program_error::{ProgramError, ToStr},
16/// #    typhoon_errors::Error,
17/// #    typhoon_errors_macro::TyphoonError
18/// # };
19/// #[derive(TyphoonError)]
20/// pub enum MyError {
21///     #[msg("Error: Invalid owner")]
22///     InvalidOwner = 200,
23///     #[msg("Error: Insufficient funds")]
24///     InsufficientFunds,
25/// }
26/// ```
27#[proc_macro_derive(TyphoonError, attributes(msg))]
28pub fn typhoon_error(input: TokenStream) -> TokenStream {
29    let errors_token = parse_macro_input!(input as Errors);
30
31    errors_token.to_token_stream().into()
32}
33
34fn parse_attribute(attributes: &[Attribute]) -> Option<String> {
35    attributes.iter().find_map(|attr| {
36        if !attr.path().is_ident("msg") {
37            return None;
38        }
39
40        let lit: LitStr = attr.parse_args().ok()?;
41        Some(lit.value())
42    })
43}
44
45struct Variant {
46    discriminant: u32,
47    name: Ident,
48    msg: String,
49}
50
51struct Errors {
52    name: Ident,
53    variants: Vec<Variant>,
54}
55
56impl Parse for Errors {
57    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
58        let derive_input: DeriveInput = input.parse()?;
59
60        let Data::Enum(data) = &derive_input.data else {
61            return Err(syn::Error::new_spanned(
62                &derive_input,
63                "TyphoonDerive can only be used on enums",
64            ));
65        };
66
67        let mut variants = Vec::with_capacity(data.variants.len());
68        let mut latest_dis: isize = -1;
69
70        for variant in &data.variants {
71            let variant_name = &variant.ident;
72            let msg = parse_attribute(&variant.attrs)
73                .ok_or(syn::Error::new_spanned(variant, "No error msg set."))?;
74
75            if let Some((_, ref expr)) = variant.discriminant {
76                if let Expr::Lit(ExprLit {
77                    lit: Lit::Int(val), ..
78                }) = expr
79                {
80                    latest_dis = val.base10_parse::<isize>()?
81                } else {
82                    return Err(syn::Error::new_spanned(expr, "Invalid discriminant."));
83                }
84            } else {
85                latest_dis += 1;
86            }
87
88            variants.push(Variant {
89                name: variant_name.to_owned(),
90                msg,
91                discriminant: latest_dis as u32,
92            });
93        }
94
95        Ok(Errors {
96            name: derive_input.ident,
97            variants,
98        })
99    }
100}
101
102impl ToTokens for Errors {
103    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
104        let name = &self.name;
105
106        let (to_str_arms, try_from_arms) = self
107            .variants
108            .iter()
109            .map(|v| {
110                let variant_name = &v.name;
111                let msg = &v.msg;
112                let discriminant = &v.discriminant;
113                (
114                    quote!(#name::#variant_name => #msg,),
115                    quote!(#discriminant => Ok(#name::#variant_name),),
116                )
117            })
118            .collect::<(Vec<_>, Vec<_>)>();
119
120        quote! {
121            impl TryFrom<u32> for #name {
122                type Error = ProgramError;
123
124                fn try_from(value: u32) -> Result<Self, Self::Error> {
125                    match value {
126                        #(#try_from_arms)*
127                        _ => Err(ProgramError::InvalidArgument),
128                    }
129                }
130            }
131
132            impl ToStr for #name {
133                fn to_str<E>(&self) -> &'static str
134                where
135                    E: 'static + ToStr + TryFrom<u32>,
136                {
137                    match self {
138                        #(#to_str_arms)*
139                    }
140                }
141            }
142
143            impl From<#name> for Error {
144                fn from(value: #name) -> Self {
145                    Error::new(ProgramError::Custom(value as u32))
146                }
147            }
148        }
149        .to_tokens(tokens);
150    }
151}