1use quote::quote;
38use proc_macro::TokenStream;
39
/// Position of a generic argument inside `Result<T, E>`.
///
/// The explicit discriminants double as argument indices: `get_result_type`
/// converts the value with `pos as usize` to pick the argument.
#[derive(Clone, Copy)]
enum ResultPos {
    /// The success type `T` (argument 0).
    Ok = 0,
    /// The error type `E` (argument 1).
    Err = 1,
}
45
46fn get_result_type(ret: &mut syn::ReturnType, pos: ResultPos) -> &mut syn::Type {
47 match ret {
48 syn::ReturnType::Type(_, ty) => match ty.as_mut() {
49 syn::Type::Path(syn::TypePath { path, .. }) => {
50 let end = path.segments.iter_mut().last().unwrap();
51
52 if end.ident.to_string() == "Result" {
53 match &mut end.arguments {
54 syn::PathArguments::AngleBracketed(args) => {
55 if args.args.len() != 2 {
56 panic!("Return type must be `Result<T, E>`")
57 }
58
59 let err = args.args.iter_mut().skip(pos as usize).next().unwrap();
60
61 match err {
62 syn::GenericArgument::Type(ref mut err) => err,
63 _ => panic!("Return type must be `Result<T, E>`")
64 }
65 }
66 _ => panic!("Return type must be `Result<T, E>`")
67 }
68 } else {
69 panic!("Return type must be `Result<T, E>`")
70 }
71 }
72 _ => panic!("Return type must be `Result<T, E>`")
73 }
74 syn::ReturnType::Default => panic!("Return type must be `Result<T, E>`")
75 }
76}
77
78fn get_ok_type(ret: &mut syn::ReturnType) -> syn::Type {
79 get_result_type(ret, ResultPos::Ok).clone()
80}
81
82fn get_err_type(ret: &mut syn::ReturnType) -> &mut syn::Type {
83 get_result_type(ret, ResultPos::Err)
84}
85
/// Upper-cases the first character of `ident` when it is an ASCII letter and
/// returns the string; everything after the first character is untouched
/// (e.g. `"io"` becomes `"Io"`).
///
/// An empty string, or one whose first character is not ASCII, is returned as-is.
fn with_first_letter_uppercase(mut ident: String) -> String {
    // `get_mut(0..1)` is `None` for an empty string or when the first char is
    // multi-byte (non-ASCII); `to_ascii_uppercase` would leave such a char
    // unchanged anyway, so skipping it preserves behaviour without reallocating.
    if let Some(first) = ident.get_mut(0..1) {
        first.make_ascii_uppercase();
    }
    ident
}
92
93fn path_to_normalized_path(path: &syn::Path) -> syn::Ident {
94 let normal_path: String = path.segments.iter()
95 .map(|seg| seg.ident.to_string())
96 .map(with_first_letter_uppercase)
97 .collect();
98
99 syn::Ident::new(&normal_path, proc_macro2::Span::call_site())
100}
101
102fn generate_error_modules(variants: &[syn::Ident], types: &[&syn::Path]) -> impl quote::ToTokens {
103 let modules = types.iter()
104 .zip(variants.iter())
105 .map(|(path, variant)| {
106 let segs: Vec<&_> = path.segments.iter().collect();
107 let (last, segs) = segs.split_last().unwrap();
108
109 let supers = std::iter::repeat_with(|| quote!(super)).take(segs.len());
110 let supers = quote!(
111 #(#supers::)*
112 );
113
114 let mut modules = quote!(
115 pub use #supers Error::#variant as #last;
116 );
117
118 for seg in segs.iter().rev() {
119 modules = quote!(
120 pub mod #seg {
121 #modules
122 }
123 );
124 }
125
126 modules
127 });
128
129 quote!(
130 #(
131 #modules
132 )*
133 )
134}
135
136fn generate_into_impls(variants: &[syn::Ident], types: &[&syn::Path]) -> impl quote::ToTokens {
137 let modules = types.iter()
138 .zip(variants.iter())
139 .map(|(path, variant)| quote!(
140 impl ::core::convert::From<super::#path> for Error {
141 fn from(val: super::#path) -> Error {
142 Error::#variant(val)
143 }
144 }
145 ));
146
147 quote!(
148 #(
149 #modules
150 )*
151 )
152}
153
154fn generate_error_enum(err: &mut syn::Type) -> impl quote::ToTokens {
155 let is_lifetime_bound = |x: &syn::TypeParamBound| if let syn::TypeParamBound::Lifetime(_) = x { true } else { false };
156 let (variant_names, types): (Vec<_>, Vec<_>) = match err {
157 syn::Type::TraitObject(trait_obj) => {
158 if trait_obj.bounds.iter().any(is_lifetime_bound) {
159 panic!("Lifetime bounds are not allowed in anonymous sum type")
160 } else {
161 trait_obj.bounds.iter()
162 .filter_map(|x| match x {
163 syn::TypeParamBound::Trait(syn::TraitBound { path, .. }) => Some(path),
164 _ => None
165 })
166 .map(|path| (path_to_normalized_path(&path), path))
167 .unzip()
168 }
169 }
170 _ => panic!("Return type must be in form of `Result<T, E1 + E2 + ...>`")
171 };
172
173 let error_modules = generate_error_modules(&variant_names, &types);
174 let into_impls = generate_into_impls(&variant_names, &types);
175
176 quote!(
177 pub enum Error {
178 #(
179 #variant_names(super::#types)
180 ),*
181 }
182
183 #error_modules
184
185 #into_impls
186 )
187}
188
189#[proc_macro_attribute]
234pub fn some_error(_: TokenStream, contents: TokenStream) -> TokenStream {
235 let mut function = syn::parse_macro_input!(contents as syn::ItemFn);
236
237 let vis = function.vis.clone();
238 let ident = function.sig.ident.clone();
239 let ok_type = get_ok_type(&mut function.sig.output);
240 let err_type = get_err_type(&mut function.sig.output);
241
242 let error_enum = generate_error_enum(err_type);
243
244 function.sig.output = syn::parse_quote!(
245 -> Result<#ok_type, #ident::Error>
246 );
247
248 quote!(
249 #vis mod #ident {
250 #error_enum
251 }
252
253 #function
254 ).into()
255}