1mod prelude;
56
57use crate::prelude::*;
58use proc_macro2::Span;
59
60use quote::ToTokens;
61use syn::{parse::Parse, spanned::Spanned};
62
63#[proc_macro]
64pub fn catch(input: TokenStream) -> TokenStream {
65 let try_catch = parse_macro_input!(input as TryCatch);
66
67 template(try_catch).into()
68}
69
/// Parsed input of the `catch!` macro: a `try` block followed by its `catch` clauses.
struct TryCatch {
    /// The block written after the `try` keyword.
    try_block: ExprBlock,
    /// The `catch` clauses, in the order they appear in the macro input.
    catches: Vec<Catch>,
    /// Set when the try block appears to use `.await` (decided by `is_async` while parsing);
    /// the expansion then evaluates the try block in an async context and awaits it.
    is_async: bool,
}
/// One `catch <error> [: <Type>] { ... }` clause of the `catch!` macro.
struct Catch {
    /// Name the caught error is bound to inside `block`.
    error: Ident,
    /// Optional error type written after `:`. `Some` restricts the clause to errors of that
    /// concrete type; `None` makes the clause handle any error.
    err_type: Option<Type>,
    /// Handler block evaluated when this clause matches.
    block: ExprBlock,
}
80
81fn parse_block(input: &parse::ParseStream) -> Result<ExprBlock> {
82 let out = input.parse().map(|block| match block {
83 Expr::Block(block) => Ok(block),
84 span => Err(Error::new(span.span(), "Expected a block `{ /* ... */ }`.")),
85 })??;
86 Ok(out)
87}
88
89impl Parse for TryCatch {
90 fn parse(input: parse::ParseStream) -> Result<Self> {
91 let _try_kw: Token![try] = input.parse()?;
92 let try_block = parse_block(&input)?;
93 let ts = try_block.to_token_stream();
94 let is_async = is_async(ts);
95 let mut catches = vec![];
96 while let Ok(catch) = input.parse() {
97 catches.push(catch)
98 }
99
100 Ok(TryCatch {
101 try_block,
102 catches,
103 is_async,
104 })
105 }
106}
107
108impl Parse for Catch {
109 fn parse(input: parse::ParseStream) -> Result<Self> {
110 let catch_kw: Ident = input.parse()?;
111 if catch_kw != "catch" {
112 return Err(Error::new(catch_kw.span(), "Expected `catch`"));
113 }
114 let error: Ident = input.parse()?;
115 let err_type = if input.peek(Token![:]) {
116 let _colon: Token![:] = input.parse()?;
117 Some(input.parse()?)
118 } else {
119 None
120 };
121 let block = parse_block(&input)?;
122 Ok(Catch {
123 error,
124 err_type,
125 block,
126 })
127 }
128}
129use syn::ExprBlock;
130
131fn template(try_catch: TryCatch) -> TokenStream2 {
132 let try_block = try_catch.try_block;
133 let result = Ident::new("__try_catch_block", Span::mixed_site());
134 let result_err = Ident::new("__try_catch_error", Span::mixed_site());
135
136 let mut template = if try_catch.is_async {
137 quote![
138 let #result: ::std::result::Result<_, Box<dyn ::std::error::Error>> = (|| async {Ok(#try_block)})().await;
139 ]
140 } else {
141 quote![
142 let #result: ::std::result::Result<_, Box<dyn ::std::error::Error>> = (|| Ok(#try_block))();
143 ]
144 };
145
146 let mut catch_template = quote!();
147 let mut warn_unused_must_use = true;
148 for catch in try_catch.catches {
149 let block = catch.block;
150 let error_name = catch.error;
151 if let Some(err_type) = catch.err_type {
152 catch_template.extend(quote![
153 _ if #result_err.is::<#err_type>() => {
154 let #error_name = #result_err.downcast::<#err_type>().unwrap();
155 ::std::result::Result::Ok(#block)
156 }
157 ]);
158 } else {
159 warn_unused_must_use = false;
160 catch_template.extend(quote![
161 _ => {
162 let #error_name = #result_err;
163 ::std::result::Result::Ok(#block)
164 }
165 ]);
166 }
167 }
168
169 catch_template.extend(quote![
170 _ => {
171 ::std::result::Result::Err(#result_err)
172 }
173 ]);
174
175 template.extend(quote![
176 if let ::std::result::Result::Err(#result_err) = #result {
177 match () { #catch_template }
178 } else {
179 #result
180 }
181 ]);
182
183 if warn_unused_must_use {
184 quote!({#template})
185 } else {
186 quote!({#template.ok().unwrap()})
187 }
188}
189
190fn is_async(input: TokenStream2) -> bool {
191 let mut out = false;
192 for token in input {
193 match token {
194 proc_macro2::TokenTree::Ident(ident) => {
195 if ident == "await" {
196 out = true;
197 }
198 }
199 proc_macro2::TokenTree::Group(group) => {
200 out |= is_async(group.stream());
201 }
202 _ => (),
203 }
204 }
205 out
206}