// thiserrorctx_impl/lib.rs

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse::Nothing, parse_macro_input, ItemEnum};

5#[proc_macro_attribute]
6pub fn error(_args: TokenStream, input: TokenStream) -> TokenStream {
7    let input_enum = parse_macro_input!(input as ItemEnum);
8
9    let enum_name = &input_enum.ident;
10    let vis = &input_enum.vis;
11
12    // new ctx err name
13    let ctx_err_name = format_ident!("{}Ctx", enum_name);
14    
15    // Result type name
16    let enum_name_str = enum_name.to_string();
17    let result_name = if enum_name_str.ends_with("Error") {
18        format_ident!("{}Result", &enum_name_str[..enum_name_str.len() - 5])
19    } else {
20        format_ident!("{}Result", enum_name)
21    };
22
23    let expanded = quote! {
24        #[derive(Debug, thiserror::Error)] 
25        #input_enum
26
27        #[derive(Debug)]
28        #vis struct #ctx_err_name {
29            pub context: Vec<String>,
30            pub error: #enum_name,
31        }
32
33        impl std::fmt::Display for #ctx_err_name {
34            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35                for ctx in &self.context {
36                    write!(f, "{}: ", ctx)?;
37                }
38                write!(f, "{}", self.error)
39            }
40        }
41
42        impl std::error::Error for #ctx_err_name {
43            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44                self.error.source()
45            }
46        }
47
48        impl From<#enum_name> for #ctx_err_name {
49            fn from(error: #enum_name) -> Self {
50                Self {
51                    context: vec![],
52                    error,
53                }
54            }
55        }
56
57        #vis type #result_name<T> = std::result::Result<T, #ctx_err_name>;
58
59        impl<T, E> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, E>
60        where
61            E: Into<#enum_name>,
62        {
63            type Ok = T;
64
65            fn context<C>(self, ctx: C) -> #result_name<T>
66            where
67                C: std::fmt::Display + Send + Sync + 'static
68            {
69                self.map_err(|e| {
70                    let mut err_ctx = #ctx_err_name::from(e.into());
71                    err_ctx.context.push(ctx.to_string());
72                    err_ctx
73                })
74            }
75
76            fn with_context<C, F>(self, f: F) -> #result_name<T>
77            where
78                C: std::fmt::Display + Send + Sync + 'static,
79                F: FnOnce() -> C
80            {
81                self.map_err(|e| {
82                    let mut err_ctx = #ctx_err_name::from(e.into());
83                    err_ctx.context.push(f().to_string());
84                    err_ctx
85                })
86            }
87        }
88
89        impl<T> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, #ctx_err_name> {
90            type Ok = T;
91
92            fn context<C>(self, ctx: C) -> #result_name<T>
93            where
94                C: std::fmt::Display + Send + Sync + 'static
95            {
96                self.map_err(|mut e| {
97                    e.context.push(ctx.to_string());
98                    e
99                })
100            }
101
102            fn with_context<C, F>(self, f: F) -> #result_name<T>
103            where
104                C: std::fmt::Display + Send + Sync + 'static,
105                F: FnOnce() -> C
106            {
107                self.map_err(|mut e| {
108                    e.context.push(f().to_string());
109                    e
110                })
111            }
112        }
113    };
114
115    let s = TokenStream::from(expanded);
116
117    // panic!("{}", s);
118
119    s
120}