thiserrorctx_impl/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, ItemEnum};
4
5#[proc_macro_attribute]
6pub fn error(_args: TokenStream, input: TokenStream) -> TokenStream {
7    let input_enum = parse_macro_input!(input as ItemEnum);
8
9    let enum_name = &input_enum.ident;
10    let vis = &input_enum.vis;
11
12    // new ctx err name
13    let ctx_err_name = format_ident!("{}Ctx", enum_name);
14    
15    // Result type name
16    let enum_name_str = enum_name.to_string();
17    let result_name = if enum_name_str.ends_with("Error") {
18        format_ident!("{}Result", &enum_name_str[..enum_name_str.len() - 5])
19    } else {
20        format_ident!("{}Result", enum_name)
21    };
22
23    let expanded = quote! {
24        #[derive(Debug, ::thiserrorctx::thiserror::Error)] 
25        #input_enum
26
27        #[derive(Debug)]
28        #vis struct #ctx_err_name {
29            pub context: Vec<String>,
30            pub error: #enum_name,
31        }
32
33        impl std::fmt::Display for #ctx_err_name {
34            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35                for ctx in &self.context {
36                    write!(f, "{}: ", ctx)?;
37                }
38                write!(f, "{}", self.error)
39            }
40        }
41
42        impl std::error::Error for #ctx_err_name {
43            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44                self.error.source()
45            }
46        }
47
48        impl<T> From<T> for #ctx_err_name 
49        where T: Into<#enum_name>
50        {
51            fn from(value: T) -> Self {
52                Self { context: vec![], error: value.into() }
53            }
54        }
55
56        #vis type #result_name<T> = std::result::Result<T, #ctx_err_name>;
57
58        impl<T, E> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, E>
59        where
60            E: Into<#enum_name>,
61        {
62            type Ok = T;
63
64            fn context<C>(self, ctx: C) -> #result_name<T>
65            where
66                C: std::fmt::Display + Send + Sync + 'static
67            {
68                self.map_err(|e| {
69                    let mut err_ctx = #ctx_err_name::from(e.into());
70                    err_ctx.context.push(ctx.to_string());
71                    err_ctx
72                })
73            }
74
75            fn with_context<C, F>(self, f: F) -> #result_name<T>
76            where
77                C: std::fmt::Display + Send + Sync + 'static,
78                F: FnOnce() -> C
79            {
80                self.map_err(|e| {
81                    let mut err_ctx = #ctx_err_name::from(e.into());
82                    err_ctx.context.push(f().to_string());
83                    err_ctx
84                })
85            }
86        }
87
88        impl<T> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, #ctx_err_name> {
89            type Ok = T;
90
91            fn context<C>(self, ctx: C) -> #result_name<T>
92            where
93                C: std::fmt::Display + Send + Sync + 'static
94            {
95                self.map_err(|mut e| {
96                    e.context.push(ctx.to_string());
97                    e
98                })
99            }
100
101            fn with_context<C, F>(self, f: F) -> #result_name<T>
102            where
103                C: std::fmt::Display + Send + Sync + 'static,
104                F: FnOnce() -> C
105            {
106                self.map_err(|mut e| {
107                    e.context.push(f().to_string());
108                    e
109                })
110            }
111        }
112    };
113
114    let s = TokenStream::from(expanded);
115
116    // panic!("{}", s);
117
118    s
119}