try_catch/
lib.rs

//! This crate provides a macro that enables the familiar `try`-`catch` syntax of other programming languages.
//! It can be used to easily group errors and handle them dynamically by type rather than by value.
3//!
4//! ```rust
5//! use try_catch::catch;
//! use std::{fs, io};
7//! use serde_json::Value;
8//!
9//! catch! {
10//!     try {
11//!         let number: i32 = "10".parse()?;
12//!         let data = fs::read_to_string("data.json")?;
13//!         let json: Value = serde_json::from_str(&data)?;
14//!     }
15//!     catch error: io::Error {
16//!         println!("Failed to open the file: {}", error)
17//!     }
18//!     catch json_err: serde_json::Error {
//!         println!("Failed to parse the JSON data: {}", json_err)
20//!     }
21//!     catch err {
22//!         println!("Error of unknown type: {}", err)
23//!     }
24//! };
25//!
26//! ```
//! If no catch-all clause (a `catch` without an error type) is present, the macro evaluates to
//! a `Result` carrying any unhandled error, so the compiler warns when that result is unused.
//! It can also be used as an expression:
//! ```rust
//! # use try_catch::catch;
//! # use std::io;
//! // All errors are guaranteed to be caught,
//! // so the type of this expression is `i32`.
//! // This is guaranteed because a `catch` clause
//! // does not specify an error type.
34//! let number: i32 = catch! {
35//!     try {
36//!         let number: i32 = "10".parse()?;
37//!         number
38//!     } catch error {
39//!         0
40//!     }
41//! };
//! // We can't know for sure whether all possible errors are
//! // handled, so the type of this expression
//! // is still `Result`.
45//! let result: Result<i32, _> = catch! {
46//!     try {
47//!         let number: i32 = "invalid number".parse()?;
48//!         number
49//!     } catch error: io::Error {
50//!         0
51//!     }
52//! };
53//! ```
54
55mod prelude;
56
57use crate::prelude::*;
58use proc_macro2::Span;
59
60use quote::ToTokens;
61use syn::{parse::Parse, spanned::Spanned};
62
63#[proc_macro]
64pub fn catch(input: TokenStream) -> TokenStream {
65    let try_catch = parse_macro_input!(input as TryCatch);
66
67    template(try_catch).into()
68}
69
70struct TryCatch {
71    try_block: ExprBlock,
72    catches: Vec<Catch>,
73    is_async: bool,
74}
75struct Catch {
76    error: Ident,
77    err_type: Option<Type>,
78    block: ExprBlock,
79}
80
81fn parse_block(input: &parse::ParseStream) -> Result<ExprBlock> {
82    let out = input.parse().map(|block| match block {
83        Expr::Block(block) => Ok(block),
84        span => Err(Error::new(span.span(), "Expected a block `{ /* ... */ }`.")),
85    })??;
86    Ok(out)
87}
88
89impl Parse for TryCatch {
90    fn parse(input: parse::ParseStream) -> Result<Self> {
91        let _try_kw: Token![try] = input.parse()?;
92        let try_block = parse_block(&input)?;
93        let ts = try_block.to_token_stream();
94        let is_async = is_async(ts);
95        let mut catches = vec![];
96        while let Ok(catch) = input.parse() {
97            catches.push(catch)
98        }
99
100        Ok(TryCatch {
101            try_block,
102            catches,
103            is_async,
104        })
105    }
106}
107
108impl Parse for Catch {
109    fn parse(input: parse::ParseStream) -> Result<Self> {
110        let catch_kw: Ident = input.parse()?;
111        if catch_kw != "catch" {
112            return Err(Error::new(catch_kw.span(), "Expected `catch`"));
113        }
114        let error: Ident = input.parse()?;
115        let err_type = if input.peek(Token![:]) {
116            let _colon: Token![:] = input.parse()?;
117            Some(input.parse()?)
118        } else {
119            None
120        };
121        let block = parse_block(&input)?;
122        Ok(Catch {
123            error,
124            err_type,
125            block,
126        })
127    }
128}
129use syn::ExprBlock;
130
131fn template(try_catch: TryCatch) -> TokenStream2 {
132    let try_block = try_catch.try_block;
133    let result = Ident::new("__try_catch_block", Span::mixed_site());
134    let result_err = Ident::new("__try_catch_error", Span::mixed_site());
135
136    let mut template = if try_catch.is_async {
137        quote![
138            let #result: ::std::result::Result<_, Box<dyn ::std::error::Error>> = (|| async {Ok(#try_block)})().await;
139        ]
140    } else {
141        quote![
142            let #result: ::std::result::Result<_, Box<dyn ::std::error::Error>> = (|| Ok(#try_block))();
143        ]
144    };
145
146    let mut catch_template = quote!();
147    let mut warn_unused_must_use = true;
148    for catch in try_catch.catches {
149        let block = catch.block;
150        let error_name = catch.error;
151        if let Some(err_type) = catch.err_type {
152            catch_template.extend(quote![
153                _ if  #result_err.is::<#err_type>() => {
154                    let #error_name = #result_err.downcast::<#err_type>().unwrap();
155                    ::std::result::Result::Ok(#block)
156                }
157            ]);
158        } else {
159            warn_unused_must_use = false;
160            catch_template.extend(quote![
161                _ => {
162                    let #error_name = #result_err;
163                    ::std::result::Result::Ok(#block)
164                }
165            ]);
166        }
167    }
168
169    catch_template.extend(quote![
170        _ => {
171            ::std::result::Result::Err(#result_err)
172        }
173    ]);
174
175    template.extend(quote![
176        if let ::std::result::Result::Err(#result_err) = #result {
177           match () { #catch_template }
178        } else {
179            #result
180        }
181    ]);
182
183    if warn_unused_must_use {
184        quote!({#template})
185    } else {
186        quote!({#template.ok().unwrap()})
187    }
188}
189
190fn is_async(input: TokenStream2) -> bool {
191    let mut out = false;
192    for token in input {
193        match token {
194            proc_macro2::TokenTree::Ident(ident) => {
195                if ident == "await" {
196                    out = true;
197                }
198            }
199            proc_macro2::TokenTree::Group(group) => {
200                out |= is_async(group.stream());
201            }
202            _ => (),
203        }
204    }
205    out
206}