Skip to main content

saola_user_facing_error_macros/
lib.rs

1extern crate proc_macro;
2
3#[proc_macro_derive(SimpleUserFacingError, attributes(user_facing))]
4pub fn derive_simple_user_facing_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
5    let input = syn::parse_macro_input!(input as syn::DeriveInput);
6
7    let data = match &input.data {
8        syn::Data::Struct(data) => data,
9        _ => {
10            return syn::Error::new_spanned(input, "derive works only on structs")
11                .to_compile_error()
12                .into();
13        }
14    };
15
16    if !data.fields.is_empty() {
17        return syn::Error::new_spanned(&data.fields, "SimpleUserFacingError implementors cannot have fields")
18            .to_compile_error()
19            .into();
20    }
21
22    let UserErrorDeriveInput { ident, code, message } = match UserErrorDeriveInput::new(&input) {
23        Ok(input) => input,
24        Err(err) => return err.into_compile_error().into(),
25    };
26
27    proc_macro::TokenStream::from(quote::quote! {
28        impl serde::Serialize for #ident {
29            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30                where S: serde::Serializer
31            {
32                serializer.serialize_none()
33            }
34        }
35
36        impl crate::SimpleUserFacingError for #ident {
37            const ERROR_CODE: &'static str = #code;
38            const MESSAGE: &'static str = #message;
39        }
40    })
41}
42
43#[proc_macro_derive(UserFacingError, attributes(user_facing))]
44pub fn derive_user_facing_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
45    let input = syn::parse_macro_input!(input as syn::DeriveInput);
46
47    let data = match &input.data {
48        syn::Data::Struct(data) => data,
49        _ => {
50            return syn::Error::new_spanned(input, "derive works only on structs")
51                .to_compile_error()
52                .into();
53        }
54    };
55
56    let UserErrorDeriveInput { ident, code, message } = match UserErrorDeriveInput::new(&input) {
57        Ok(input) => input,
58        Err(err) => return err.into_compile_error().into(),
59    };
60
61    let template_variables: Box<dyn Iterator<Item = _>> = match &data.fields {
62        syn::Fields::Named(named) => Box::new(named.named.iter().map(|field| field.ident.as_ref().unwrap())),
63        syn::Fields::Unit => Box::new(std::iter::empty()),
64        syn::Fields::Unnamed(unnamed) => {
65            return syn::Error::new_spanned(unnamed, "The error fields must be named")
66                .to_compile_error()
67                .into();
68        }
69    };
70
71    proc_macro::TokenStream::from(quote::quote! {
72        impl crate::UserFacingError for #ident {
73            const ERROR_CODE: &'static str = #code;
74
75            fn message(&self) -> String {
76                format!(
77                    #message,
78                    #(
79                        #template_variables = self.#template_variables
80                    ),*
81                )
82            }
83        }
84    })
85}
86
87struct UserErrorDeriveInput<'a> {
88    /// The name of the struct.
89    ident: &'a syn::Ident,
90    /// The error code.
91    code: syn::LitStr,
92    /// The error message format string.
93    message: syn::LitStr,
94}
95
96impl<'a> UserErrorDeriveInput<'a> {
97    fn new(input: &'a syn::DeriveInput) -> Result<Self, syn::Error> {
98        let mut code = None;
99        let mut message = None;
100
101        for attr in &input.attrs {
102            if !attr
103                .path()
104                .get_ident()
105                .map(|ident| ident == "user_facing")
106                .unwrap_or(false)
107            {
108                continue;
109            }
110
111            for namevalue in attr.parse_args_with(|stream: &'_ syn::parse::ParseBuffer| {
112                syn::punctuated::Punctuated::<syn::MetaNameValue, syn::Token![,]>::parse_terminated(stream)
113            })? {
114                let litstr = match namevalue.value {
115                    syn::Expr::Lit(syn::ExprLit {
116                        lit: syn::Lit::Str(litstr),
117                        ..
118                    }) => litstr,
119                    other => {
120                        return Err(syn::Error::new_spanned(
121                            other,
122                            "Expected attribute of the form `#[user_facing(code = \"...\", message = \"...\")]`",
123                        ));
124                    }
125                };
126
127                match namevalue.path.get_ident() {
128                    Some(ident) if ident == "code" => {
129                        code = Some(litstr);
130                    }
131                    Some(ident) if ident == "message" => {
132                        message = Some(litstr);
133                    }
134                    other => {
135                        return Err(syn::Error::new_spanned(
136                            other,
137                            "Expected attribute of the form `#[user_facing(code = \"...\", message = \"...\")]`",
138                        ));
139                    }
140                }
141            }
142        }
143
144        match (message, code) {
145            (Some(message), Some(code)) => Ok(UserErrorDeriveInput {
146                ident: &input.ident,
147                message,
148                code,
149            }),
150            _ => Err(syn::Error::new_spanned(
151                input,
152                "Expected attribute of the form `#[user_facing(code = \"...\", message = \"...\")]`",
153            )),
154        }
155    }
156}