some_error/
lib.rs

//! A library that allows an anonymous sum type to be used as the error type of a function's returned `Result`.
//! Attach the `#[some_error]` attribute to a function whose return type has the form `Result<T, E1 + E2 + ...>`.
3//!
4//! ## Example:
5//!
6//! ```rust
7//! use std::io;
8//! use some_error::*;
9//! 
10//! #[derive(Debug, Clone, Copy)]
11//! struct NotZeroError(u32);
12//! 
13//! #[some_error]
14//! fn my_func() -> Result<(), io::Error + NotZeroError>{
15//!     let x = 3;
16//!     if x != 0 {
17//!         Err(NotZeroError(x))?;
18//!     }
19//! 
20//!     Ok(())
21//! }
22//! 
23//! fn main() {
24//!     match my_func() {
25//!         Ok(_) => {
26//!             println!("Worked ok!");
27//!         }
28//!         Err(my_func::NotZeroError(NotZeroError(x))) => {
29//!             println!("{} is not zero!!", x);
30//!         }
31//!         Err(my_func::io::Error(io_err)) => {
32//!             println!("io error: {:?}", io_err);
33//!         }
34//!     }
35//! }
36//! ```
37use quote::quote;
38use proc_macro::TokenStream;
39
/// Selects one of the two generic arguments of a `Result<T, E>`.
///
/// The explicit discriminants are the argument indices: `Ok` names `T` (index 0)
/// and `Err` names `E` (index 1), so `pos as usize` can be used to pick the argument.
#[derive(Clone, Copy)]
enum ResultPos {
    /// The success type `T`.
    Ok = 0,
    /// The error type `E`.
    Err = 1,
}
45
46fn get_result_type(ret: &mut syn::ReturnType, pos: ResultPos) -> &mut syn::Type {
47    match ret {
48        syn::ReturnType::Type(_, ty) => match ty.as_mut() {
49            syn::Type::Path(syn::TypePath { path, .. }) => {
50                let end = path.segments.iter_mut().last().unwrap();
51
52                if end.ident.to_string() == "Result" {
53                    match &mut end.arguments {
54                        syn::PathArguments::AngleBracketed(args) => {
55                            if args.args.len() != 2 {
56                                panic!("Return type must be `Result<T, E>`")
57                            }
58
59                            let err = args.args.iter_mut().skip(pos as usize).next().unwrap();
60
61                            match err {
62                                syn::GenericArgument::Type(ref mut err) => err,
63                                _ => panic!("Return type must be `Result<T, E>`")
64                            }
65                        }
66                        _ => panic!("Return type must be `Result<T, E>`")
67                    }
68                } else {
69                    panic!("Return type must be `Result<T, E>`")
70                }
71            }
72            _ => panic!("Return type must be `Result<T, E>`")
73        }
74        syn::ReturnType::Default => panic!("Return type must be `Result<T, E>`")
75    }
76}
77
78fn get_ok_type(ret: &mut syn::ReturnType) -> syn::Type {
79    get_result_type(ret, ResultPos::Ok).clone()
80}
81
82fn get_err_type(ret: &mut syn::ReturnType) -> &mut syn::Type {
83    get_result_type(ret, ResultPos::Err)
84}
85
/// Returns `ident` with its first character converted to ASCII uppercase.
///
/// Only the first character is affected. If it is not an ASCII lowercase letter (a digit,
/// `_`, an uppercase letter or any non-ASCII character), the string is returned unchanged.
fn with_first_letter_uppercase(ident: String) -> String {
    let mut name = ident;
    // `get_mut(0..1)` is `None` for an empty string or when the first character spans
    // several bytes; ASCII uppercasing would leave such a character unchanged anyway.
    if let Some(head) = name.get_mut(0..1) {
        head.make_ascii_uppercase();
    }
    name
}
92
93fn path_to_normalized_path(path: &syn::Path) -> syn::Ident {
94    let normal_path: String = path.segments.iter()
95        .map(|seg| seg.ident.to_string())
96        .map(with_first_letter_uppercase)
97        .collect();
98
99    syn::Ident::new(&normal_path, proc_macro2::Span::call_site())
100}
101
102fn generate_error_modules(variants: &[syn::Ident], types: &[&syn::Path]) -> impl quote::ToTokens {
103    let modules = types.iter()
104        .zip(variants.iter())
105        .map(|(path, variant)| {
106            let segs: Vec<&_> = path.segments.iter().collect();
107            let (last, segs) = segs.split_last().unwrap();
108
109            let supers = std::iter::repeat_with(|| quote!(super)).take(segs.len());
110            let supers = quote!(
111                #(#supers::)*
112            );
113
114            let mut modules = quote!(
115                pub use #supers Error::#variant as #last;
116            );
117
118            for seg in segs.iter().rev() {
119                modules = quote!(
120                    pub mod #seg {
121                        #modules
122                    }
123                );
124            }
125
126            modules
127        });
128
129    quote!(
130        #(
131            #modules
132        )*
133    )
134}
135
136fn generate_into_impls(variants: &[syn::Ident], types: &[&syn::Path]) -> impl quote::ToTokens {
137    let modules = types.iter()
138        .zip(variants.iter())
139        .map(|(path, variant)| quote!(
140            impl ::core::convert::From<super::#path> for Error {
141                fn from(val: super::#path) -> Error {
142                    Error::#variant(val)
143                }
144            }
145        ));
146
147    quote!(
148        #(
149            #modules
150        )*
151    )
152}
153
154fn generate_error_enum(err: &mut syn::Type) -> impl quote::ToTokens {
155    let is_lifetime_bound = |x: &syn::TypeParamBound| if let syn::TypeParamBound::Lifetime(_) = x { true } else { false };
156    let (variant_names, types): (Vec<_>, Vec<_>) = match err {
157        syn::Type::TraitObject(trait_obj) => {
158            if trait_obj.bounds.iter().any(is_lifetime_bound) {
159                panic!("Lifetime bounds are not allowed in anonymous sum type")
160            } else {
161                trait_obj.bounds.iter()
162                    .filter_map(|x| match x {
163                        syn::TypeParamBound::Trait(syn::TraitBound { path, .. }) => Some(path),
164                        _ => None
165                    })
166                    .map(|path| (path_to_normalized_path(&path), path))
167                    .unzip()
168            }
169        }
170        _ => panic!("Return type must be in form of `Result<T, E1 + E2 + ...>`")
171    };
172
173    let error_modules = generate_error_modules(&variant_names, &types);
174    let into_impls = generate_into_impls(&variant_names, &types);
175
176    quote!(
177        pub enum Error {
178            #(
179                #variant_names(super::#types)
180            ),*
181        }
182
183        #error_modules
184
185        #into_impls
186    )
187}
188
189/// A Macro for allowing Anonymous Sum Types in the error of the returned result.
190/// The function this is attached to must return a type in the form of `Result<T, E1 + E2 + ...>`.
191///
192/// ## Example:
193///
194/// ```rust
195/// use std::io;
196/// use some_error::*;
197/// 
198/// #[derive(Debug, Clone, Copy)]
199/// struct NotZeroError(u32);
200/// 
201/// #[some_error]
202/// fn my_func() -> Result<(), io::Error + NotZeroError>{
203///     let x = 3;
204///     if x != 0 {
205///         Err(NotZeroError(x))?;
206///     }
207/// 
208///     Ok(())
209/// }
210/// 
211/// fn main() {
212///     match my_func() {
213///         Ok(_) => {
214///             println!("Worked ok!");
215///         }
216///         Err(my_func::NotZeroError(NotZeroError(x))) => {
217///             println!("{} is not zero!!", x);
218///         }
219///         Err(my_func::io::Error(io_err)) => {
220///             println!("io error: {:?}", io_err);
221///         }
222///     }
223/// }
224/// ```
225///
226/// ## More Info
227///
228/// * The type of the anonymous sum type can be referenced via `function_name::Error`.
229/// * All variants of the anonymous sum type can be accessed via `function_name::<path>`
230///     * For example if you use the return type `Result<(), fmt::Error + i32>` the variants will
231///     be named `function_name::fmt::Error` and `function_name::i32` and can be used as patterns
232///     to match against (see the above example)
233#[proc_macro_attribute]
234pub fn some_error(_: TokenStream, contents: TokenStream) -> TokenStream {
235    let mut function = syn::parse_macro_input!(contents as syn::ItemFn);
236
237    let vis = function.vis.clone();
238    let ident = function.sig.ident.clone();
239    let ok_type = get_ok_type(&mut function.sig.output);
240    let err_type = get_err_type(&mut function.sig.output);
241
242    let error_enum = generate_error_enum(err_type);
243
244    function.sig.output = syn::parse_quote!(
245        -> Result<#ok_type, #ident::Error>
246    );
247
248    quote!(
249        #vis mod #ident {
250            #error_enum
251        }
252
253        #function
254    ).into()
255}