wherror_impl/
attr.rs

1use proc_macro2::{Delimiter, Group, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
2use quote::{format_ident, quote, quote_spanned, ToTokens};
3use std::collections::BTreeSet as Set;
4use syn::parse::discouraged::Speculative;
5use syn::parse::{End, ParseStream};
6use syn::{
7    braced, bracketed, parenthesized, token, Attribute, Error, ExprPath, Ident, Index, LitFloat,
8    LitInt, LitStr, Meta, Result, Token,
9};
10
11pub struct Attrs<'a> {
12    pub display: Option<Display<'a>>,
13    pub source: Option<Source<'a>>,
14    pub backtrace: Option<&'a Attribute>,
15    pub location: Option<&'a Attribute>,
16    pub from: Option<From<'a>>,
17    pub transparent: Option<Transparent<'a>>,
18    pub fmt: Option<Fmt<'a>>,
19    pub debug: Option<DebugFallback<'a>>,
20}
21
22#[derive(Clone)]
23pub struct Display<'a> {
24    pub original: &'a Attribute,
25    pub fmt: LitStr,
26    pub args: TokenStream,
27    pub requires_fmt_machinery: bool,
28    pub has_bonus_display: bool,
29    pub infinite_recursive: bool,
30    pub implied_bounds: Set<(usize, Trait)>,
31    pub bindings: Vec<(Ident, TokenStream)>,
32}
33
34#[derive(Copy, Clone)]
35pub struct Source<'a> {
36    pub original: &'a Attribute,
37    pub span: Span,
38}
39
40#[derive(Copy, Clone)]
41pub struct From<'a> {
42    pub original: &'a Attribute,
43    pub span: Span,
44}
45
46#[derive(Copy, Clone)]
47pub struct Transparent<'a> {
48    pub original: &'a Attribute,
49    pub span: Span,
50}
51
52#[derive(Copy, Clone)]
53pub struct DebugFallback<'a> {
54    pub original: &'a Attribute,
55    pub span: Span,
56}
57
58#[derive(Clone)]
59pub struct Fmt<'a> {
60    pub original: &'a Attribute,
61    pub path: ExprPath,
62}
63
/// One of the formatting traits in `::core::fmt`.
///
/// Each variant tokenises to the path `::core::fmt::<Variant>`. The
/// declaration order of the variants drives the derived ordering, so it must
/// not be rearranged casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Trait {
    Debug,
    Display,
    Octal,
    LowerHex,
    UpperHex,
    Pointer,
    Binary,
    LowerExp,
    UpperExp,
}
76
77pub fn get(input: &[Attribute]) -> Result<Attrs> {
78    let mut attrs = Attrs {
79        display: None,
80        source: None,
81        backtrace: None,
82        location: None,
83        from: None,
84        transparent: None,
85        fmt: None,
86        debug: None,
87    };
88
89    for attr in input {
90        if attr.path().is_ident("error") {
91            parse_error_attribute(&mut attrs, attr)?;
92        } else if attr.path().is_ident("source") {
93            attr.meta.require_path_only()?;
94            if attrs.source.is_some() {
95                return Err(Error::new_spanned(attr, "duplicate #[source] attribute"));
96            }
97            let span = (attr.pound_token.span)
98                .join(attr.bracket_token.span.join())
99                .unwrap_or(attr.path().get_ident().unwrap().span());
100            attrs.source = Some(Source {
101                original: attr,
102                span,
103            });
104        } else if attr.path().is_ident("backtrace") {
105            attr.meta.require_path_only()?;
106            if attrs.backtrace.is_some() {
107                return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute"));
108            }
109            attrs.backtrace = Some(attr);
110        } else if attr.path().is_ident("location") {
111            attr.meta.require_path_only()?;
112            if attrs.location.is_some() {
113                return Err(Error::new_spanned(attr, "duplicate #[location] attribute"));
114            }
115            attrs.location = Some(attr);
116        } else if attr.path().is_ident("from") {
117            match attr.meta {
118                Meta::Path(_) => {}
119                Meta::List(_) | Meta::NameValue(_) => {
120                    // Assume this is meant for derive_more crate or something.
121                    continue;
122                }
123            }
124            if attrs.from.is_some() {
125                return Err(Error::new_spanned(attr, "duplicate #[from] attribute"));
126            }
127            let span = (attr.pound_token.span)
128                .join(attr.bracket_token.span.join())
129                .unwrap_or(attr.path().get_ident().unwrap().span());
130            attrs.from = Some(From {
131                original: attr,
132                span,
133            });
134        }
135    }
136
137    Ok(attrs)
138}
139
140fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
141    mod kw {
142        syn::custom_keyword!(transparent);
143        syn::custom_keyword!(fmt);
144        syn::custom_keyword!(debug);
145    }
146
147    attr.parse_args_with(|input: ParseStream| {
148        let lookahead = input.lookahead1();
149        let fmt = if lookahead.peek(LitStr) {
150            input.parse::<LitStr>()?
151        } else if lookahead.peek(kw::transparent) {
152            let kw: kw::transparent = input.parse()?;
153            if attrs.transparent.is_some() {
154                return Err(Error::new_spanned(
155                    attr,
156                    "duplicate #[error(transparent)] attribute",
157                ));
158            }
159            attrs.transparent = Some(Transparent {
160                original: attr,
161                span: kw.span,
162            });
163            return Ok(());
164        } else if lookahead.peek(kw::fmt) {
165            input.parse::<kw::fmt>()?;
166            input.parse::<Token![=]>()?;
167            let path: ExprPath = input.parse()?;
168            if attrs.fmt.is_some() {
169                return Err(Error::new_spanned(
170                    attr,
171                    "duplicate #[error(fmt = ...)] attribute",
172                ));
173            }
174            attrs.fmt = Some(Fmt {
175                original: attr,
176                path,
177            });
178            return Ok(());
179        } else if lookahead.peek(kw::debug) {
180            let kw: kw::debug = input.parse()?;
181            if attrs.debug.is_some() {
182                return Err(Error::new_spanned(
183                    attr,
184                    "duplicate #[error(debug)] attribute",
185                ));
186            }
187            attrs.debug = Some(DebugFallback {
188                original: attr,
189                span: kw.span,
190            });
191            return Ok(());
192        } else {
193            return Err(lookahead.error());
194        };
195
196        let args = if input.is_empty() || input.peek(Token![,]) && input.peek2(End) {
197            input.parse::<Option<Token![,]>>()?;
198            TokenStream::new()
199        } else {
200            parse_token_expr(input, false)?
201        };
202
203        let requires_fmt_machinery = !args.is_empty();
204
205        let display = Display {
206            original: attr,
207            fmt,
208            args,
209            requires_fmt_machinery,
210            has_bonus_display: false,
211            infinite_recursive: false,
212            implied_bounds: Set::new(),
213            bindings: Vec::new(),
214        };
215        if attrs.display.is_some() {
216            return Err(Error::new_spanned(
217                attr,
218                "only one #[error(...)] attribute is allowed",
219            ));
220        }
221        attrs.display = Some(display);
222        Ok(())
223    })
224}
225
226fn parse_token_expr(input: ParseStream, mut begin_expr: bool) -> Result<TokenStream> {
227    let mut tokens = Vec::new();
228    while !input.is_empty() {
229        if input.peek(token::Group) {
230            let group: TokenTree = input.parse()?;
231            tokens.push(group);
232            begin_expr = false;
233            continue;
234        }
235
236        if begin_expr && input.peek(Token![.]) {
237            if input.peek2(Ident) {
238                input.parse::<Token![.]>()?;
239                begin_expr = false;
240                continue;
241            } else if input.peek2(LitInt) {
242                input.parse::<Token![.]>()?;
243                let int: Index = input.parse()?;
244                tokens.push({
245                    let ident = format_ident!("_{}", int.index, span = int.span);
246                    TokenTree::Ident(ident)
247                });
248                begin_expr = false;
249                continue;
250            } else if input.peek2(LitFloat) {
251                let ahead = input.fork();
252                ahead.parse::<Token![.]>()?;
253                let float: LitFloat = ahead.parse()?;
254                let repr = float.to_string();
255                let mut indices = repr.split('.').map(syn::parse_str::<Index>);
256                if let (Some(Ok(first)), Some(Ok(second)), None) =
257                    (indices.next(), indices.next(), indices.next())
258                {
259                    input.advance_to(&ahead);
260                    tokens.push({
261                        let ident = format_ident!("_{}", first, span = float.span());
262                        TokenTree::Ident(ident)
263                    });
264                    tokens.push({
265                        let mut punct = Punct::new('.', Spacing::Alone);
266                        punct.set_span(float.span());
267                        TokenTree::Punct(punct)
268                    });
269                    tokens.push({
270                        let mut literal = Literal::u32_unsuffixed(second.index);
271                        literal.set_span(float.span());
272                        TokenTree::Literal(literal)
273                    });
274                    begin_expr = false;
275                    continue;
276                }
277            }
278        }
279
280        begin_expr = input.peek(Token![break])
281            || input.peek(Token![continue])
282            || input.peek(Token![if])
283            || input.peek(Token![in])
284            || input.peek(Token![match])
285            || input.peek(Token![mut])
286            || input.peek(Token![return])
287            || input.peek(Token![while])
288            || input.peek(Token![+])
289            || input.peek(Token![&])
290            || input.peek(Token![!])
291            || input.peek(Token![^])
292            || input.peek(Token![,])
293            || input.peek(Token![/])
294            || input.peek(Token![=])
295            || input.peek(Token![>])
296            || input.peek(Token![<])
297            || input.peek(Token![|])
298            || input.peek(Token![%])
299            || input.peek(Token![;])
300            || input.peek(Token![*])
301            || input.peek(Token![-]);
302
303        let token: TokenTree = if input.peek(token::Paren) {
304            let content;
305            let delimiter = parenthesized!(content in input);
306            let nested = parse_token_expr(&content, true)?;
307            let mut group = Group::new(Delimiter::Parenthesis, nested);
308            group.set_span(delimiter.span.join());
309            TokenTree::Group(group)
310        } else if input.peek(token::Brace) {
311            let content;
312            let delimiter = braced!(content in input);
313            let nested = parse_token_expr(&content, true)?;
314            let mut group = Group::new(Delimiter::Brace, nested);
315            group.set_span(delimiter.span.join());
316            TokenTree::Group(group)
317        } else if input.peek(token::Bracket) {
318            let content;
319            let delimiter = bracketed!(content in input);
320            let nested = parse_token_expr(&content, true)?;
321            let mut group = Group::new(Delimiter::Bracket, nested);
322            group.set_span(delimiter.span.join());
323            TokenTree::Group(group)
324        } else {
325            input.parse()?
326        };
327        tokens.push(token);
328    }
329    Ok(TokenStream::from_iter(tokens))
330}
331
332impl ToTokens for Display<'_> {
333    fn to_tokens(&self, tokens: &mut TokenStream) {
334        if self.infinite_recursive {
335            let span = self.fmt.span();
336            tokens.extend(quote_spanned! {span=>
337                #[warn(unconditional_recursion)]
338                fn _fmt() { _fmt() }
339            });
340        }
341
342        let fmt = &self.fmt;
343        let args = &self.args;
344
345        // Currently `write!(f, "text")` produces less efficient code than
346        // `f.write_str("text")`. We recognize the case when the format string
347        // has no braces and no interpolated values, and generate simpler code.
348        let write = if self.requires_fmt_machinery {
349            quote! {
350                ::core::write!(__formatter, #fmt #args)
351            }
352        } else {
353            quote! {
354                __formatter.write_str(#fmt)
355            }
356        };
357
358        tokens.extend(if self.bindings.is_empty() {
359            write
360        } else {
361            let locals = self.bindings.iter().map(|(local, _value)| local);
362            let values = self.bindings.iter().map(|(_local, value)| value);
363            quote! {
364                match (#(#values,)*) {
365                    (#(#locals,)*) => #write
366                }
367            }
368        });
369    }
370}
371
372impl ToTokens for Trait {
373    fn to_tokens(&self, tokens: &mut TokenStream) {
374        let trait_name = match self {
375            Trait::Debug => "Debug",
376            Trait::Display => "Display",
377            Trait::Octal => "Octal",
378            Trait::LowerHex => "LowerHex",
379            Trait::UpperHex => "UpperHex",
380            Trait::Pointer => "Pointer",
381            Trait::Binary => "Binary",
382            Trait::LowerExp => "LowerExp",
383            Trait::UpperExp => "UpperExp",
384        };
385        let ident = Ident::new(trait_name, Span::call_site());
386        tokens.extend(quote!(::core::fmt::#ident));
387    }
388}