Skip to main content

thiserrorctx_impl/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::Parser;
4use syn::{parse_macro_input, ItemEnum, Meta, Expr, Lit, ExprLit, Token};
5use syn::punctuated::Punctuated;
6
7#[proc_macro_attribute]
8pub fn context_error(args: TokenStream, input: TokenStream) -> TokenStream {    
9    let mut input_enum = parse_macro_input!(input as ItemEnum);
10    assert!(input_enum.generics.params.is_empty(), "Generics are not supported yet");
11
12    let enum_name = &input_enum.ident;
13    let vis = &input_enum.vis;
14
15    // 提取宏参数信息 (result 名称自定义)
16    let args_parsed = {
17        let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
18        parser.parse(args).expect("Failed to parse macro arguments")
19    };
20
21    let mut custom_res_name = None;
22    for meta in args_parsed {
23        if let Meta::NameValue(nv) = meta {
24            if nv.path.is_ident("result") {
25                if let Expr::Lit(ExprLit { lit: Lit::Str(lit), .. }) = nv.value {
26                    custom_res_name = Some(format_ident!("{}", lit.value()));
27                }
28            }
29        }
30    }
31
32    let enum_name_str = enum_name.to_string();
33    let result_name = if enum_name_str.ends_with("Error") {
34        let stem = &enum_name_str[..enum_name_str.len() - 5];
35        custom_res_name.unwrap_or_else(|| format_ident!("{}Result", stem))
36    } else {
37        custom_res_name.unwrap_or_else(|| format_ident!("{}Result", enum_name_str))
38    };
39
40    let display_struct_name = format_ident!("__{}ContextDisplay", enum_name);
41
42    // 插入 Context 变体 (移除了 backtrace 字段)
43    let context_variant: syn::Variant = syn::parse_quote! {
44        #[error("{}", #display_struct_name(.err, .contexts))]
45        Context {
46            #[source]
47            err: Box<Self>,
48            contexts: Vec<String>,
49        }
50    };
51    input_enum.variants.push(context_variant);
52
53    let expanded = quote! {
54        #[derive(Debug, ::thiserror::Error)]
55        #input_enum
56
57        #[doc(hidden)]
58        #vis struct #display_struct_name<'a>(&'a Box<#enum_name>, &'a Vec<String>);
59
60        #vis type #result_name<T> = std::result::Result<T, #enum_name>;
61
62        impl<'a> std::fmt::Display for #display_struct_name<'a> {
63            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64                if self.1.is_empty() {
65                    return write!(f, "{}", self.0);
66                }
67                for ctx in self.1.iter().rev() {
68                    writeln!(f, "{}", ctx)?;
69                    write!(f, "  ↳ ")?;
70                }
71                write!(f, "{}", self.0)
72            }
73        }
74
75        impl<T, E> thiserrorctx::Context<#enum_name> for std::result::Result<T, E>
76        where
77            E: Into<#enum_name>,
78        {
79            type Ok = T;
80
81            fn context<C>(self, ctx: C) -> #result_name<T>
82            where
83                C: std::fmt::Display + Send + Sync + 'static,
84            {
85                self.map_err(|e| {
86                    let err: #enum_name = e.into();
87                    match err {
88                        #enum_name::Context { mut contexts, err } => {
89                            contexts.push(ctx.to_string());
90                            #enum_name::Context { contexts, err }
91                        }
92                        other => {
93                            #enum_name::Context {
94                                contexts: vec![ctx.to_string()],
95                                err: Box::new(other),
96                            }
97                        }
98                    }
99                })
100            }
101
102            fn with_context<C, F>(self, f: F) -> #result_name<T>
103            where
104                C: std::fmt::Display + Send + Sync + 'static,
105                F: FnOnce() -> C,
106            {
107                self.map_err(|e| {
108                    let err: #enum_name = e.into();
109                    match err {
110                        #enum_name::Context { mut contexts, err } => {
111                            contexts.push(f().to_string());
112                            #enum_name::Context { contexts, err }
113                        }
114                        other => {
115                            #enum_name::Context {
116                                contexts: vec![f().to_string()],
117                                err: Box::new(other),
118                            }
119                        }
120                    }
121                })
122            }
123        }
124
125        impl #enum_name {
126            pub fn context<C>(self, ctx: C) -> Self
127            where
128                C: std::fmt::Display + Send + Sync + 'static,
129            {
130                match self {
131                    Self::Context { mut contexts, err } => {
132                        contexts.push(ctx.to_string());
133                        Self::Context { contexts, err }
134                    }
135                    other => Self::Context {
136                        contexts: vec![ctx.to_string()],
137                        err: Box::new(other),
138                    },
139                }
140            }
141        
142            pub fn with_context<C, F>(self, f: F) -> Self
143            where
144                C: std::fmt::Display + Send + Sync + 'static,
145                F: FnOnce() -> C,
146            {
147                self.context(f())
148            }
149
150            pub fn root_cause(&self) -> &Self {
151                match self {
152                    Self::Context { err, .. } => err.as_ref(),
153                    _ => self
154                }
155            }
156
157            pub fn into_root_cause(self) -> Self {
158                match self {
159                    Self::Context { err, .. } => *err,
160                    _ => self,
161                }
162            }
163
164            pub fn contexts(&self) -> &[String] {
165                match self {
166                    Self::Context { contexts, .. } => contexts.as_slice(),
167                    _ => &[], 
168                }
169            }
170        }
171    };
172
173    TokenStream::from(expanded)
174}