thiserrorctx_impl/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::Parser;
4use syn::{parse_macro_input, ItemEnum, Meta, Expr, Lit, ExprLit, Token};
5use syn::punctuated::Punctuated;
6
7#[proc_macro_attribute]
8pub fn context_error(args: TokenStream, input: TokenStream) -> TokenStream {
9    let input_enum = parse_macro_input!(input as ItemEnum);
10    let enum_name = &input_enum.ident;
11    let vis = &input_enum.vis;
12
13    // (ctxerror = "...", result = "...")
14    let args_parsed = {
15        let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
16        parser.parse(args).expect("Failed to parse macro arguments")
17    };
18
19    let mut custom_ctx_name = None;
20    let mut custom_res_name = None;
21
22    for meta in args_parsed {
23        if let Meta::NameValue(nv) = meta {
24            if nv.path.is_ident("ctxerror") {
25                if let Expr::Lit(ExprLit { lit: Lit::Str(lit), .. }) = nv.value {
26                    custom_ctx_name = Some(format_ident!("{}", lit.value()));
27                }
28            } 
29            else if nv.path.is_ident("result") {
30                if let Expr::Lit(ExprLit { lit: Lit::Str(lit), .. }) = nv.value {
31                    custom_res_name = Some(format_ident!("{}", lit.value()));
32                }
33            }
34        }
35    }
36
37    let enum_name_str = enum_name.to_string();
38    if (custom_ctx_name.is_none() || custom_res_name.is_none()) && !enum_name_str.ends_with("Error") {
39        panic!(
40            "#[context_error]: Enum name '{}' must end with 'Error' to use auto-generated names (e.g., MyError). \
41             Either rename your Enum or explicitly provide both 'name' and 'result' attributes.",
42            enum_name_str
43        );
44    }
45
46    let (ctx_err_name, result_name) = if enum_name_str.ends_with("Error") {
47        let stem = &enum_name_str[..enum_name_str.len() - 5];
48        (
49            custom_ctx_name.unwrap_or_else(|| format_ident!("{}CtxError", stem)),
50            custom_res_name.unwrap_or_else(|| format_ident!("{}Result", stem)),
51        )
52    } else {
53        (
54            custom_ctx_name.unwrap(), 
55            custom_res_name.unwrap()
56        )
57    };
58
59    // generate code
60    let expanded = quote! {
61        #[derive(Debug, ::thiserrorctx::thiserror::Error)] 
62        #input_enum
63
64        #[derive(Debug)]
65        #vis struct #ctx_err_name {
66            pub contexts: Vec<String>,
67            pub error: #enum_name,
68        }
69
70        impl #ctx_err_name {
71            pub fn wrap<T, E, C>(res: std::result::Result<T, E>, ctx: C) -> #result_name<T>
72            where 
73                E: Into<#enum_name>,
74                C: std::fmt::Display + Send + Sync + 'static 
75            {
76                res.map_err(|e| {
77                    let mut me = Self::from(e);
78                    me.contexts.push(ctx.to_string());
79                    me
80                })
81            }
82
83            pub fn with_wrap<T, E, C, F>(res: std::result::Result<T, E>, f: F) -> #result_name<T>
84            where 
85                E: Into<#enum_name>,
86                C: std::fmt::Display + Send + Sync + 'static,
87                F: FnOnce() -> C
88            {
89                res.map_err(|e| {
90                    let mut me = Self::from(e);
91                    me.contexts.push(f().to_string());
92                    me
93                })
94            }
95        }
96
97        impl std::fmt::Display for #ctx_err_name {
98            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99                if self.contexts.is_empty() {
100                    return write!(f, "{}", self.error);
101                }
102
103                for ctx in self.contexts.iter().rev() {
104                    writeln!(f, "{}", ctx)?;
105                    write!(f, "  ↳ ")?;
106                }
107
108                write!(f, "{}", self.error)
109            }
110        }
111
112        impl std::error::Error for #ctx_err_name {
113            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
114                Some(&self.error) 
115            }
116        }
117
118        impl<T> From<T> for #ctx_err_name 
119        where T: Into<#enum_name>
120        {
121            fn from(value: T) -> Self {
122                Self { contexts: vec![], error: value.into() }
123            }
124        }
125
126        #vis type #result_name<T> = std::result::Result<T, #ctx_err_name>;
127
128        impl<T, E> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, E>
129        where
130            E: Into<#enum_name>,
131        {
132            type Ok = T;
133
134            fn context<C>(self, ctx: C) -> #result_name<T>
135            where
136                C: std::fmt::Display + Send + Sync + 'static
137            {
138                self.map_err(|e| {
139                    let mut err_ctx = #ctx_err_name::from(e.into());
140                    err_ctx.contexts.push(ctx.to_string());
141                    err_ctx
142                })
143            }
144
145            fn with_context<C, F>(self, f: F) -> #result_name<T>
146            where
147                C: std::fmt::Display + Send + Sync + 'static,
148                F: FnOnce() -> C
149            {
150                self.map_err(|e| {
151                    let mut err_ctx = #ctx_err_name::from(e.into());
152                    err_ctx.contexts.push(f().to_string());
153                    err_ctx
154                })
155            }
156        }
157
158        impl<T> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, #ctx_err_name> {
159            type Ok = T;
160
161            fn context<C>(self, ctx: C) -> #result_name<T>
162            where
163                C: std::fmt::Display + Send + Sync + 'static
164            {
165                self.map_err(|mut e| {
166                    e.contexts.push(ctx.to_string());
167                    e
168                })
169            }
170
171            fn with_context<C, F>(self, f: F) -> #result_name<T>
172            where
173                C: std::fmt::Display + Send + Sync + 'static,
174                F: FnOnce() -> C
175            {
176                self.map_err(|mut e| {
177                    e.contexts.push(f().to_string());
178                    e
179                })
180            }
181        }
182    };
183
184    TokenStream::from(expanded)
185}