1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, ItemEnum};
4
5#[proc_macro_attribute]
6pub fn error(_args: TokenStream, input: TokenStream) -> TokenStream {
7 let input_enum = parse_macro_input!(input as ItemEnum);
8
9 let enum_name = &input_enum.ident;
10 let vis = &input_enum.vis;
11
12 let ctx_err_name = format_ident!("{}Ctx", enum_name);
14
15 let enum_name_str = enum_name.to_string();
17 let result_name = if enum_name_str.ends_with("Error") {
18 format_ident!("{}Result", &enum_name_str[..enum_name_str.len() - 5])
19 } else {
20 format_ident!("{}Result", enum_name)
21 };
22
23 let expanded = quote! {
24 #[derive(Debug, ::thiserrorctx::thiserror::Error)]
25 #input_enum
26
27 #[derive(Debug)]
28 #vis struct #ctx_err_name {
29 pub context: Vec<String>,
30 pub error: #enum_name,
31 }
32
33 impl std::fmt::Display for #ctx_err_name {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 for ctx in &self.context {
36 write!(f, "{}: ", ctx)?;
37 }
38 write!(f, "{}", self.error)
39 }
40 }
41
42 impl std::error::Error for #ctx_err_name {
43 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44 self.error.source()
45 }
46 }
47
48 impl<T> From<T> for #ctx_err_name
49 where T: Into<#enum_name>
50 {
51 fn from(value: T) -> Self {
52 Self { context: vec![], error: value.into() }
53 }
54 }
55
56 #vis type #result_name<T> = std::result::Result<T, #ctx_err_name>;
57
58 impl<T, E> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, E>
59 where
60 E: Into<#enum_name>,
61 {
62 type Ok = T;
63
64 fn context<C>(self, ctx: C) -> #result_name<T>
65 where
66 C: std::fmt::Display + Send + Sync + 'static
67 {
68 self.map_err(|e| {
69 let mut err_ctx = #ctx_err_name::from(e.into());
70 err_ctx.context.push(ctx.to_string());
71 err_ctx
72 })
73 }
74
75 fn with_context<C, F>(self, f: F) -> #result_name<T>
76 where
77 C: std::fmt::Display + Send + Sync + 'static,
78 F: FnOnce() -> C
79 {
80 self.map_err(|e| {
81 let mut err_ctx = #ctx_err_name::from(e.into());
82 err_ctx.context.push(f().to_string());
83 err_ctx
84 })
85 }
86 }
87
88 impl<T> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, #ctx_err_name> {
89 type Ok = T;
90
91 fn context<C>(self, ctx: C) -> #result_name<T>
92 where
93 C: std::fmt::Display + Send + Sync + 'static
94 {
95 self.map_err(|mut e| {
96 e.context.push(ctx.to_string());
97 e
98 })
99 }
100
101 fn with_context<C, F>(self, f: F) -> #result_name<T>
102 where
103 C: std::fmt::Display + Send + Sync + 'static,
104 F: FnOnce() -> C
105 {
106 self.map_err(|mut e| {
107 e.context.push(f().to_string());
108 e
109 })
110 }
111 }
112 };
113
114 let s = TokenStream::from(expanded);
115
116 s
119}