wherror_impl/
attr.rs

1use proc_macro2::{Delimiter, Group, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
2use quote::{format_ident, quote, quote_spanned, ToTokens};
3use std::collections::BTreeSet as Set;
4use syn::parse::discouraged::Speculative;
5use syn::parse::{End, ParseStream};
6use syn::{
7    braced, bracketed, parenthesized, token, Attribute, Error, ExprPath, Ident, Index, LitFloat,
8    LitInt, LitStr, Meta, Result, Token,
9};
10
11pub struct Attrs<'a> {
12    pub display: Option<Display<'a>>,
13    pub source: Option<Source<'a>>,
14    pub backtrace: Option<&'a Attribute>,
15    pub location: Option<&'a Attribute>,
16    pub from: Option<From<'a>>,
17    pub transparent: Option<Transparent<'a>>,
18    pub fmt: Option<Fmt<'a>>,
19    pub debug: Option<DebugFallback<'a>>,
20}
21
22#[derive(Clone)]
23pub struct Display<'a> {
24    pub original: &'a Attribute,
25    pub fmt: LitStr,
26    pub args: TokenStream,
27    pub requires_fmt_machinery: bool,
28    pub has_bonus_display: bool,
29    pub infinite_recursive: bool,
30    pub implied_bounds: Set<(usize, Trait)>,
31    pub bindings: Vec<(Ident, TokenStream)>,
32}
33
34#[derive(Copy, Clone)]
35pub struct Source<'a> {
36    pub original: &'a Attribute,
37    pub span: Span,
38}
39
40#[derive(Copy, Clone)]
41pub struct From<'a> {
42    pub original: &'a Attribute,
43    pub span: Span,
44    pub no_source: bool,
45}
46
47#[derive(Copy, Clone)]
48pub struct Transparent<'a> {
49    pub original: &'a Attribute,
50    pub span: Span,
51}
52
53#[derive(Copy, Clone)]
54pub struct DebugFallback<'a> {
55    pub original: &'a Attribute,
56    pub span: Span,
57}
58
59#[derive(Clone)]
60pub struct Fmt<'a> {
61    pub original: &'a Attribute,
62    pub path: ExprPath,
63}
64
/// One of the `core::fmt` formatting traits; its `ToTokens` impl emits
/// `::core::fmt::<Name>`.
///
/// The derived `Ord` follows declaration order, so reordering the variants changes the
/// iteration order of ordered collections of them (such as `Display::implied_bounds`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Trait {
    Debug,
    Display,
    Octal,
    LowerHex,
    UpperHex,
    Pointer,
    Binary,
    LowerExp,
    UpperExp,
}
77
78pub fn get(input: &[Attribute]) -> Result<Attrs> {
79    let mut attrs = Attrs {
80        display: None,
81        source: None,
82        backtrace: None,
83        location: None,
84        from: None,
85        transparent: None,
86        fmt: None,
87        debug: None,
88    };
89
90    for attr in input {
91        if attr.path().is_ident("error") {
92            parse_error_attribute(&mut attrs, attr)?;
93        } else if attr.path().is_ident("source") {
94            attr.meta.require_path_only()?;
95            if attrs.source.is_some() {
96                return Err(Error::new_spanned(attr, "duplicate #[source] attribute"));
97            }
98            let span = (attr.pound_token.span)
99                .join(attr.bracket_token.span.join())
100                .unwrap_or(attr.path().get_ident().unwrap().span());
101            attrs.source = Some(Source {
102                original: attr,
103                span,
104            });
105        } else if attr.path().is_ident("backtrace") {
106            attr.meta.require_path_only()?;
107            if attrs.backtrace.is_some() {
108                return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute"));
109            }
110            attrs.backtrace = Some(attr);
111        } else if attr.path().is_ident("location") {
112            attr.meta.require_path_only()?;
113            if attrs.location.is_some() {
114                return Err(Error::new_spanned(attr, "duplicate #[location] attribute"));
115            }
116            attrs.location = Some(attr);
117        } else if attr.path().is_ident("from") {
118            if attrs.from.is_some() {
119                return Err(Error::new_spanned(attr, "duplicate #[from] attribute"));
120            }
121            // Support optional flags: #[from(no_source)]
122            let mut no_source = false;
123            match &attr.meta {
124                Meta::Path(_) => {}
125                Meta::List(_) => {
126                    mod kw {
127                        syn::custom_keyword!(no_source);
128                    }
129                    attr.parse_args_with(|input: ParseStream| {
130                        while !input.is_empty() {
131                            if input.peek(kw::no_source) {
132                                let _ = input.parse::<kw::no_source>()?;
133                                no_source = true;
134                            } else {
135                                return Err(input.error("unsupported option in #[from(...)]"));
136                            }
137                            let _ = input.parse::<Option<Token![,]>>()?;
138                        }
139                        Ok(())
140                    })?;
141                }
142                Meta::NameValue(_) => {
143                    // Assume meant for other crates; ignore.
144                }
145            }
146            let span = (attr.pound_token.span)
147                .join(attr.bracket_token.span.join())
148                .unwrap_or(attr.path().get_ident().unwrap().span());
149            attrs.from = Some(From {
150                original: attr,
151                span,
152                no_source,
153            });
154        }
155    }
156
157    Ok(attrs)
158}
159
160fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
161    mod kw {
162        syn::custom_keyword!(transparent);
163        syn::custom_keyword!(fmt);
164        syn::custom_keyword!(debug);
165    }
166
167    attr.parse_args_with(|input: ParseStream| {
168        let lookahead = input.lookahead1();
169        let fmt = if lookahead.peek(LitStr) {
170            input.parse::<LitStr>()?
171        } else if lookahead.peek(kw::transparent) {
172            let kw: kw::transparent = input.parse()?;
173            if attrs.transparent.is_some() {
174                return Err(Error::new_spanned(
175                    attr,
176                    "duplicate #[error(transparent)] attribute",
177                ));
178            }
179            attrs.transparent = Some(Transparent {
180                original: attr,
181                span: kw.span,
182            });
183            return Ok(());
184        } else if lookahead.peek(kw::fmt) {
185            input.parse::<kw::fmt>()?;
186            input.parse::<Token![=]>()?;
187            let path: ExprPath = input.parse()?;
188            if attrs.fmt.is_some() {
189                return Err(Error::new_spanned(
190                    attr,
191                    "duplicate #[error(fmt = ...)] attribute",
192                ));
193            }
194            attrs.fmt = Some(Fmt {
195                original: attr,
196                path,
197            });
198            return Ok(());
199        } else if lookahead.peek(kw::debug) {
200            let kw: kw::debug = input.parse()?;
201            if attrs.debug.is_some() {
202                return Err(Error::new_spanned(
203                    attr,
204                    "duplicate #[error(debug)] attribute",
205                ));
206            }
207            attrs.debug = Some(DebugFallback {
208                original: attr,
209                span: kw.span,
210            });
211            return Ok(());
212        } else {
213            return Err(lookahead.error());
214        };
215
216        let args = if input.is_empty() || input.peek(Token![,]) && input.peek2(End) {
217            input.parse::<Option<Token![,]>>()?;
218            TokenStream::new()
219        } else {
220            parse_token_expr(input, false)?
221        };
222
223        let requires_fmt_machinery = !args.is_empty();
224
225        let display = Display {
226            original: attr,
227            fmt,
228            args,
229            requires_fmt_machinery,
230            has_bonus_display: false,
231            infinite_recursive: false,
232            implied_bounds: Set::new(),
233            bindings: Vec::new(),
234        };
235        if attrs.display.is_some() {
236            return Err(Error::new_spanned(
237                attr,
238                "only one #[error(...)] attribute is allowed",
239            ));
240        }
241        attrs.display = Some(display);
242        Ok(())
243    })
244}
245
/// Copies the tokens remaining in `input` into a new `TokenStream`, rewriting
/// field-access shorthand that appears where an expression may begin:
///
/// - `.name` becomes `name` (the leading dot is dropped);
/// - `.0` becomes the identifier `_0`;
/// - `.0.1`, which the lexer produces as `.` followed by the float literal `0.1`,
///   becomes `_0 . 1`.
///
/// NOTE(review): the `_N` names presumably match locals bound for tuple fields by the
/// generated code — confirm in the expansion code.
///
/// Parenthesized, braced and bracketed groups are rewritten recursively, with their
/// contents treated as starting an expression. Invisible (`None`-delimited) groups are
/// copied verbatim without rewriting.
///
/// `begin_expr` says whether the first token of `input` is in a position where an
/// expression may start.
fn parse_token_expr(input: ParseStream, mut begin_expr: bool) -> Result<TokenStream> {
    let mut tokens = Vec::new();
    while !input.is_empty() {
        // Invisible group: copy it as a single token tree, untouched.
        if input.peek(token::Group) {
            let group: TokenTree = input.parse()?;
            tokens.push(group);
            begin_expr = false;
            continue;
        }

        if begin_expr && input.peek(Token![.]) {
            if input.peek2(Ident) {
                // `.name`: drop the dot; the identifier is copied on the next iteration.
                input.parse::<Token![.]>()?;
                begin_expr = false;
                continue;
            } else if input.peek2(LitInt) {
                // `.N`: replace with the identifier `_N`, spanned at the integer.
                input.parse::<Token![.]>()?;
                let int: Index = input.parse()?;
                tokens.push({
                    let ident = format_ident!("_{}", int.index, span = int.span);
                    TokenTree::Ident(ident)
                });
                begin_expr = false;
                continue;
            } else if input.peek2(LitFloat) {
                // `.N.M` lexes as `.` + float `N.M`. Try on a fork so nothing is consumed
                // unless the float splits into exactly two valid indices; otherwise fall
                // through and the `.` is copied as an ordinary token.
                let ahead = input.fork();
                ahead.parse::<Token![.]>()?;
                let float: LitFloat = ahead.parse()?;
                let repr = float.to_string();
                let mut indices = repr.split('.').map(syn::parse_str::<Index>);
                if let (Some(Ok(first)), Some(Ok(second)), None) =
                    (indices.next(), indices.next(), indices.next())
                {
                    input.advance_to(&ahead);
                    // Emit `_N . M`, every token spanned at the original float literal.
                    tokens.push({
                        let ident = format_ident!("_{}", first, span = float.span());
                        TokenTree::Ident(ident)
                    });
                    tokens.push({
                        let mut punct = Punct::new('.', Spacing::Alone);
                        punct.set_span(float.span());
                        TokenTree::Punct(punct)
                    });
                    tokens.push({
                        let mut literal = Literal::u32_unsuffixed(second.index);
                        literal.set_span(float.span());
                        TokenTree::Literal(literal)
                    });
                    begin_expr = false;
                    continue;
                }
            }
        }

        // An expression may start right after these keywords and operators, so a `.`
        // following the token about to be copied is treated as shorthand.
        begin_expr = input.peek(Token![break])
            || input.peek(Token![continue])
            || input.peek(Token![if])
            || input.peek(Token![in])
            || input.peek(Token![match])
            || input.peek(Token![mut])
            || input.peek(Token![return])
            || input.peek(Token![while])
            || input.peek(Token![+])
            || input.peek(Token![&])
            || input.peek(Token![!])
            || input.peek(Token![^])
            || input.peek(Token![,])
            || input.peek(Token![/])
            || input.peek(Token![=])
            || input.peek(Token![>])
            || input.peek(Token![<])
            || input.peek(Token![|])
            || input.peek(Token![%])
            || input.peek(Token![;])
            || input.peek(Token![*])
            || input.peek(Token![-]);

        // Delimited groups are rebuilt from their recursively rewritten contents, keeping
        // the original delimiter span; any other token is copied as-is.
        let token: TokenTree = if input.peek(token::Paren) {
            let content;
            let delimiter = parenthesized!(content in input);
            let nested = parse_token_expr(&content, true)?;
            let mut group = Group::new(Delimiter::Parenthesis, nested);
            group.set_span(delimiter.span.join());
            TokenTree::Group(group)
        } else if input.peek(token::Brace) {
            let content;
            let delimiter = braced!(content in input);
            let nested = parse_token_expr(&content, true)?;
            let mut group = Group::new(Delimiter::Brace, nested);
            group.set_span(delimiter.span.join());
            TokenTree::Group(group)
        } else if input.peek(token::Bracket) {
            let content;
            let delimiter = bracketed!(content in input);
            let nested = parse_token_expr(&content, true)?;
            let mut group = Group::new(Delimiter::Bracket, nested);
            group.set_span(delimiter.span.join());
            TokenTree::Group(group)
        } else {
            input.parse()?
        };
        tokens.push(token);
    }
    Ok(TokenStream::from_iter(tokens))
}
351
352impl ToTokens for Display<'_> {
353    fn to_tokens(&self, tokens: &mut TokenStream) {
354        if self.infinite_recursive {
355            let span = self.fmt.span();
356            tokens.extend(quote_spanned! {span=>
357                #[warn(unconditional_recursion)]
358                fn _fmt() { _fmt() }
359            });
360        }
361
362        let fmt = &self.fmt;
363        let args = &self.args;
364
365        // Currently `write!(f, "text")` produces less efficient code than
366        // `f.write_str("text")`. We recognize the case when the format string
367        // has no braces and no interpolated values, and generate simpler code.
368        let write = if self.requires_fmt_machinery {
369            quote! {
370                ::core::write!(__formatter, #fmt #args)
371            }
372        } else {
373            quote! {
374                __formatter.write_str(#fmt)
375            }
376        };
377
378        tokens.extend(if self.bindings.is_empty() {
379            write
380        } else {
381            let locals = self.bindings.iter().map(|(local, _value)| local);
382            let values = self.bindings.iter().map(|(_local, value)| value);
383            quote! {
384                match (#(#values,)*) {
385                    (#(#locals,)*) => #write
386                }
387            }
388        });
389    }
390}
391
392impl ToTokens for Trait {
393    fn to_tokens(&self, tokens: &mut TokenStream) {
394        let trait_name = match self {
395            Trait::Debug => "Debug",
396            Trait::Display => "Display",
397            Trait::Octal => "Octal",
398            Trait::LowerHex => "LowerHex",
399            Trait::UpperHex => "UpperHex",
400            Trait::Pointer => "Pointer",
401            Trait::Binary => "Binary",
402            Trait::LowerExp => "LowerExp",
403            Trait::UpperExp => "UpperExp",
404        };
405        let ident = Ident::new(trait_name, Span::call_site());
406        tokens.extend(quote!(::core::fmt::#ident));
407    }
408}