1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::Parser;
4use syn::{parse_macro_input, ItemEnum, Meta, Expr, Lit, ExprLit, Token};
5use syn::punctuated::Punctuated;
6
7#[proc_macro_attribute]
8pub fn context_error(args: TokenStream, input: TokenStream) -> TokenStream {
9 let input_enum = parse_macro_input!(input as ItemEnum);
10 let enum_name = &input_enum.ident;
11 let vis = &input_enum.vis;
12
13 let args_parsed = {
15 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
16 parser.parse(args).expect("Failed to parse macro arguments")
17 };
18
19 let mut custom_ctx_name = None;
20 let mut custom_res_name = None;
21
22 for meta in args_parsed {
23 if let Meta::NameValue(nv) = meta {
24 if nv.path.is_ident("ctxerror") {
25 if let Expr::Lit(ExprLit { lit: Lit::Str(lit), .. }) = nv.value {
26 custom_ctx_name = Some(format_ident!("{}", lit.value()));
27 }
28 }
29 else if nv.path.is_ident("result") {
30 if let Expr::Lit(ExprLit { lit: Lit::Str(lit), .. }) = nv.value {
31 custom_res_name = Some(format_ident!("{}", lit.value()));
32 }
33 }
34 }
35 }
36
37 let enum_name_str = enum_name.to_string();
38 if (custom_ctx_name.is_none() || custom_res_name.is_none()) && !enum_name_str.ends_with("Error") {
39 panic!(
40 "#[context_error]: Enum name '{}' must end with 'Error' to use auto-generated names (e.g., MyError). \
41 Either rename your Enum or explicitly provide both 'name' and 'result' attributes.",
42 enum_name_str
43 );
44 }
45
46 let (ctx_err_name, result_name) = if enum_name_str.ends_with("Error") {
47 let stem = &enum_name_str[..enum_name_str.len() - 5];
48 (
49 custom_ctx_name.unwrap_or_else(|| format_ident!("{}CtxError", stem)),
50 custom_res_name.unwrap_or_else(|| format_ident!("{}Result", stem)),
51 )
52 } else {
53 (
54 custom_ctx_name.unwrap(),
55 custom_res_name.unwrap()
56 )
57 };
58
59 let expanded = quote! {
61 #[derive(Debug, ::thiserrorctx::thiserror::Error)]
62 #input_enum
63
64 #[derive(Debug)]
65 #vis struct #ctx_err_name {
66 pub contexts: Vec<String>,
67 pub error: #enum_name,
68 }
69
70 impl #ctx_err_name {
71 pub fn wrap<T, E, C>(res: std::result::Result<T, E>, ctx: C) -> #result_name<T>
72 where
73 E: Into<#enum_name>,
74 C: std::fmt::Display + Send + Sync + 'static
75 {
76 res.map_err(|e| {
77 let mut me = Self::from(e);
78 me.contexts.push(ctx.to_string());
79 me
80 })
81 }
82
83 pub fn with_wrap<T, E, C, F>(res: std::result::Result<T, E>, f: F) -> #result_name<T>
84 where
85 E: Into<#enum_name>,
86 C: std::fmt::Display + Send + Sync + 'static,
87 F: FnOnce() -> C
88 {
89 res.map_err(|e| {
90 let mut me = Self::from(e);
91 me.contexts.push(f().to_string());
92 me
93 })
94 }
95 }
96
97 impl std::fmt::Display for #ctx_err_name {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 if self.contexts.is_empty() {
100 return write!(f, "{}", self.error);
101 }
102
103 for ctx in self.contexts.iter().rev() {
104 writeln!(f, "{}", ctx)?;
105 write!(f, " ↳ ")?;
106 }
107
108 write!(f, "{}", self.error)
109 }
110 }
111
112 impl std::error::Error for #ctx_err_name {
113 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
114 Some(&self.error)
115 }
116 }
117
118 impl<T> From<T> for #ctx_err_name
119 where T: Into<#enum_name>
120 {
121 fn from(value: T) -> Self {
122 Self { contexts: vec![], error: value.into() }
123 }
124 }
125
126 #vis type #result_name<T> = std::result::Result<T, #ctx_err_name>;
127
128 impl<T, E> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, E>
129 where
130 E: Into<#enum_name>,
131 {
132 type Ok = T;
133
134 fn context<C>(self, ctx: C) -> #result_name<T>
135 where
136 C: std::fmt::Display + Send + Sync + 'static
137 {
138 self.map_err(|e| {
139 let mut err_ctx = #ctx_err_name::from(e.into());
140 err_ctx.contexts.push(ctx.to_string());
141 err_ctx
142 })
143 }
144
145 fn with_context<C, F>(self, f: F) -> #result_name<T>
146 where
147 C: std::fmt::Display + Send + Sync + 'static,
148 F: FnOnce() -> C
149 {
150 self.map_err(|e| {
151 let mut err_ctx = #ctx_err_name::from(e.into());
152 err_ctx.contexts.push(f().to_string());
153 err_ctx
154 })
155 }
156 }
157
158 impl<T> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, #ctx_err_name> {
159 type Ok = T;
160
161 fn context<C>(self, ctx: C) -> #result_name<T>
162 where
163 C: std::fmt::Display + Send + Sync + 'static
164 {
165 self.map_err(|mut e| {
166 e.contexts.push(ctx.to_string());
167 e
168 })
169 }
170
171 fn with_context<C, F>(self, f: F) -> #result_name<T>
172 where
173 C: std::fmt::Display + Send + Sync + 'static,
174 F: FnOnce() -> C
175 {
176 self.map_err(|mut e| {
177 e.contexts.push(f().to_string());
178 e
179 })
180 }
181 }
182 };
183
184 TokenStream::from(expanded)
185}