1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, ItemEnum};
4
5#[proc_macro_attribute]
6pub fn error(_args: TokenStream, input: TokenStream) -> TokenStream {
7 let input_enum = parse_macro_input!(input as ItemEnum);
8
9 let enum_name = &input_enum.ident;
10 let vis = &input_enum.vis;
11
12 let ctx_err_name = format_ident!("{}Ctx", enum_name);
14
15 let enum_name_str = enum_name.to_string();
17 let result_name = if enum_name_str.ends_with("Error") {
18 format_ident!("{}Result", &enum_name_str[..enum_name_str.len() - 5])
19 } else {
20 format_ident!("{}Result", enum_name)
21 };
22
23 let expanded = quote! {
24 #[derive(Debug, thiserror::Error)]
25 #input_enum
26
27 #[derive(Debug)]
28 #vis struct #ctx_err_name {
29 pub context: Vec<String>,
30 pub error: #enum_name,
31 }
32
33 impl std::fmt::Display for #ctx_err_name {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 for ctx in &self.context {
36 write!(f, "{}: ", ctx)?;
37 }
38 write!(f, "{}", self.error)
39 }
40 }
41
42 impl std::error::Error for #ctx_err_name {
43 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44 self.error.source()
45 }
46 }
47
48 impl From<#enum_name> for #ctx_err_name {
49 fn from(error: #enum_name) -> Self {
50 Self {
51 context: vec![],
52 error,
53 }
54 }
55 }
56
57 #vis type #result_name<T> = std::result::Result<T, #ctx_err_name>;
58
59 impl<T, E> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, E>
60 where
61 E: Into<#enum_name>,
62 {
63 type Ok = T;
64
65 fn context<C>(self, ctx: C) -> #result_name<T>
66 where
67 C: std::fmt::Display + Send + Sync + 'static
68 {
69 self.map_err(|e| {
70 let mut err_ctx = #ctx_err_name::from(e.into());
71 err_ctx.context.push(ctx.to_string());
72 err_ctx
73 })
74 }
75
76 fn with_context<C, F>(self, f: F) -> #result_name<T>
77 where
78 C: std::fmt::Display + Send + Sync + 'static,
79 F: FnOnce() -> C
80 {
81 self.map_err(|e| {
82 let mut err_ctx = #ctx_err_name::from(e.into());
83 err_ctx.context.push(f().to_string());
84 err_ctx
85 })
86 }
87 }
88
89 impl<T> thiserrorctx::Context<#ctx_err_name> for std::result::Result<T, #ctx_err_name> {
90 type Ok = T;
91
92 fn context<C>(self, ctx: C) -> #result_name<T>
93 where
94 C: std::fmt::Display + Send + Sync + 'static
95 {
96 self.map_err(|mut e| {
97 e.context.push(ctx.to_string());
98 e
99 })
100 }
101
102 fn with_context<C, F>(self, f: F) -> #result_name<T>
103 where
104 C: std::fmt::Display + Send + Sync + 'static,
105 F: FnOnce() -> C
106 {
107 self.map_err(|mut e| {
108 e.context.push(f().to_string());
109 e
110 })
111 }
112 }
113 };
114
115 let s = TokenStream::from(expanded);
116
117 s
120}