timeout_macro_parse/
inject.rs

1use crate::Error;
2#[cfg(not(feature = "test"))]
3use proc_macro::{Delimiter, Span, TokenStream, TokenTree};
4#[cfg(feature = "test")]
5use proc_macro2::{Delimiter, Span, TokenStream, TokenTree};
6
/// Strategy for wrapping the body of a function with generated code.
///
/// `try_inject` calls [`Injector::inject`] once with the name and the inner
/// body tokens of the function it parsed, and appends the returned stream
/// after the tokens that preceded the original body.
pub trait Injector {
    /// Produces the tokens that take the place of the original function body.
    ///
    /// * `fn_name` - name of the function being processed.
    /// * `inner_code` - the tokens found inside the original body's braces.
    fn inject(self, fn_name: &str, inner_code: TokenStream) -> TokenStream;
}
10
11pub(crate) fn try_inject(
12    injector: impl Injector,
13    source: TokenStream,
14) -> crate::Result<TokenStream> {
15    let mut it = source.into_iter();
16    let mut pre = TokenStream::new();
17    let (fn_name, inner_body) = extract_inner_body(&mut pre, &mut it)?;
18    let res = injector.inject(&fn_name, inner_body);
19    pre.extend([res]);
20    Ok(pre)
21}
22
23fn extract_inner_body(
24    pre: &mut TokenStream,
25    source: &mut impl Iterator<Item = TokenTree>,
26) -> crate::Result<(String, TokenStream)> {
27    let mut seen_async = false;
28    let mut seen_fn_decl = false;
29    let mut fn_name = None;
30    let mut last = None;
31    let mut peek = source.peekable();
32    while let Some(token) = peek.next() {
33        match &token {
34            TokenTree::Ident(id) => {
35                let id = id.to_string();
36                match id.as_str() {
37                    "async" => seen_async = true,
38                    "fn" => seen_fn_decl = true,
39                    maybe_fn_name if seen_async && seen_fn_decl && fn_name.is_none() => {
40                        fn_name = Some(maybe_fn_name.to_string());
41                    }
42                    _ => {}
43                }
44            }
45            t if seen_async && seen_fn_decl && fn_name.is_none() => {
46                return Err(Error::with_span(
47                    t.span(),
48                    "unexpected token, expected fn name".to_string(),
49                ))
50            }
51            t => {
52                last = Some(t.clone());
53            }
54        }
55        if peek.peek().is_some() {
56            pre.extend([token]);
57        }
58    }
59    if !seen_fn_decl {
60        return Err(Error::missing_span(
61            "'timeout' macro used on something without a 'fn' declaration".to_string(),
62        ));
63    }
64    if !seen_async {
65        return Err(Error::missing_span(
66            "'timeout' macro only allowed on async functions".to_string(),
67        ));
68    }
69    let Some(TokenTree::Group(group)) = last else {
70        return Err(Error::missing_span(
71            "'timeout' macro used on something without a body".to_string(),
72        ));
73    };
74    if !matches!(group.delimiter(), Delimiter::Brace) {
75        return Err(Error::with_span(
76            group.span(),
77            "'timeout' macro used on something without a body (last group not a brace)".to_string(),
78        ));
79    }
80    let Some(fn_name) = fn_name else {
81        return Err(Error::with_span(
82            Span::call_site(),
83            "'timeout' macro unable to find fn name",
84        ));
85    };
86    Ok((fn_name, group.stream()))
87}