wherror_impl/
attr.rs

1use proc_macro2::{Delimiter, Group, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
2use quote::{format_ident, quote, quote_spanned, ToTokens};
3use std::collections::BTreeSet as Set;
4use syn::parse::discouraged::Speculative;
5use syn::parse::{End, ParseStream};
6use syn::{
7    braced, bracketed, parenthesized, token, Attribute, Error, ExprPath, Ident, Index, LitFloat,
8    LitInt, LitStr, Meta, Result, Token,
9};
10
/// Error-related attributes collected from one attribute list by [`get`].
///
/// `get` starts with every field set to `None` and fills in a field only when
/// the corresponding attribute is present.
pub struct Attrs<'a> {
    /// Parsed from `#[error("format string", args...)]`.
    pub display: Option<Display<'a>>,
    /// Set by a bare `#[source]` attribute.
    pub source: Option<Source<'a>>,
    /// The bare `#[backtrace]` attribute, if present.
    pub backtrace: Option<&'a Attribute>,
    /// The bare `#[location]` attribute, if present.
    pub location: Option<&'a Attribute>,
    /// Set by a bare `#[from]`; `#[from(...)]` and `#[from = ...]` are skipped.
    pub from: Option<From<'a>>,
    /// Set by `#[error(transparent)]`.
    pub transparent: Option<Transparent<'a>>,
    /// Set by `#[error(fmt = path)]`.
    pub fmt: Option<Fmt<'a>>,
}
20
/// Parsed form of an `#[error("format string", args...)]` attribute, together
/// with state that starts out empty/false when parsed.
#[derive(Clone)]
pub struct Display<'a> {
    /// The `#[error(...)]` attribute this was parsed from.
    pub original: &'a Attribute,
    /// The format string literal.
    pub fmt: LitStr,
    /// Tokens following the format string as produced by `parse_token_expr`
    /// (leading `.field` / `.0` shorthand rewritten); empty when the attribute
    /// has no arguments.
    pub args: TokenStream,
    /// When `false`, `to_tokens` emits `__formatter.write_str(fmt)` instead of
    /// `::core::write!(__formatter, fmt args)`. The parser initializes it to
    /// `!args.is_empty()`. NOTE(review): presumably also set elsewhere once the
    /// format string itself is inspected for braces — confirm.
    pub requires_fmt_machinery: bool,
    /// Initialized to `false` by the parser and not read in the code shown
    /// here. NOTE(review): meaning is defined by code elsewhere — confirm.
    pub has_bonus_display: bool,
    /// When `true`, `to_tokens` additionally emits a self-recursive dummy
    /// `_fmt` function under `#[warn(unconditional_recursion)]`, spanned at the
    /// format string. Initialized to `false` by the parser.
    pub infinite_recursive: bool,
    /// Starts empty when parsed. NOTE(review): presumably (field index,
    /// formatting trait) pairs that generated bounds must cover — confirm
    /// against the code that populates it.
    pub implied_bounds: Set<(usize, Trait)>,
    /// `(local, value)` pairs. When non-empty, `to_tokens` wraps the write in
    /// `match (values...,) { (locals...,) => write }`. Starts empty when parsed.
    pub bindings: Vec<(Ident, TokenStream)>,
}
32
/// A bare `#[source]` attribute.
#[derive(Copy, Clone)]
pub struct Source<'a> {
    /// The attribute itself.
    pub original: &'a Attribute,
    /// Span of the whole attribute (`#` through `]`) when the spans can be
    /// joined; otherwise the span of the `source` identifier.
    pub span: Span,
}
38
/// A bare `#[from]` attribute (the list and name-value forms are skipped by
/// `get` and never produce this).
#[derive(Copy, Clone)]
pub struct From<'a> {
    /// The attribute itself.
    pub original: &'a Attribute,
    /// Span of the whole attribute (`#` through `]`) when the spans can be
    /// joined; otherwise the span of the `from` identifier.
    pub span: Span,
}
44
/// An `#[error(transparent)]` attribute.
#[derive(Copy, Clone)]
pub struct Transparent<'a> {
    /// The `#[error(...)]` attribute containing the `transparent` keyword.
    pub original: &'a Attribute,
    /// Span of the `transparent` keyword.
    pub span: Span,
}
50
/// An `#[error(fmt = path)]` attribute.
#[derive(Clone)]
pub struct Fmt<'a> {
    /// The `#[error(...)]` attribute this was parsed from.
    pub original: &'a Attribute,
    /// The path given after `fmt =`.
    pub path: ExprPath,
}
56
/// A `core::fmt` formatting trait; its `ToTokens` impl emits the path
/// `::core::fmt::<VariantName>`.
///
/// Variant order is significant: it defines the derived `Ord`, which is what
/// orders the `(usize, Trait)` entries of `Display::implied_bounds` (a
/// `BTreeSet`).
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
pub enum Trait {
    Debug,
    Display,
    Octal,
    LowerHex,
    UpperHex,
    Pointer,
    Binary,
    LowerExp,
    UpperExp,
}
69
70pub fn get(input: &[Attribute]) -> Result<Attrs> {
71    let mut attrs = Attrs {
72        display: None,
73        source: None,
74        backtrace: None,
75        location: None,
76        from: None,
77        transparent: None,
78        fmt: None,
79    };
80
81    for attr in input {
82        if attr.path().is_ident("error") {
83            parse_error_attribute(&mut attrs, attr)?;
84        } else if attr.path().is_ident("source") {
85            attr.meta.require_path_only()?;
86            if attrs.source.is_some() {
87                return Err(Error::new_spanned(attr, "duplicate #[source] attribute"));
88            }
89            let span = (attr.pound_token.span)
90                .join(attr.bracket_token.span.join())
91                .unwrap_or(attr.path().get_ident().unwrap().span());
92            attrs.source = Some(Source {
93                original: attr,
94                span,
95            });
96        } else if attr.path().is_ident("backtrace") {
97            attr.meta.require_path_only()?;
98            if attrs.backtrace.is_some() {
99                return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute"));
100            }
101            attrs.backtrace = Some(attr);
102        } else if attr.path().is_ident("location") {
103            attr.meta.require_path_only()?;
104            if attrs.location.is_some() {
105                return Err(Error::new_spanned(attr, "duplicate #[location] attribute"));
106            }
107            attrs.location = Some(attr);
108        } else if attr.path().is_ident("from") {
109            match attr.meta {
110                Meta::Path(_) => {}
111                Meta::List(_) | Meta::NameValue(_) => {
112                    // Assume this is meant for derive_more crate or something.
113                    continue;
114                }
115            }
116            if attrs.from.is_some() {
117                return Err(Error::new_spanned(attr, "duplicate #[from] attribute"));
118            }
119            let span = (attr.pound_token.span)
120                .join(attr.bracket_token.span.join())
121                .unwrap_or(attr.path().get_ident().unwrap().span());
122            attrs.from = Some(From {
123                original: attr,
124                span,
125            });
126        }
127    }
128
129    Ok(attrs)
130}
131
132fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
133    mod kw {
134        syn::custom_keyword!(transparent);
135        syn::custom_keyword!(fmt);
136    }
137
138    attr.parse_args_with(|input: ParseStream| {
139        let lookahead = input.lookahead1();
140        let fmt = if lookahead.peek(LitStr) {
141            input.parse::<LitStr>()?
142        } else if lookahead.peek(kw::transparent) {
143            let kw: kw::transparent = input.parse()?;
144            if attrs.transparent.is_some() {
145                return Err(Error::new_spanned(
146                    attr,
147                    "duplicate #[error(transparent)] attribute",
148                ));
149            }
150            attrs.transparent = Some(Transparent {
151                original: attr,
152                span: kw.span,
153            });
154            return Ok(());
155        } else if lookahead.peek(kw::fmt) {
156            input.parse::<kw::fmt>()?;
157            input.parse::<Token![=]>()?;
158            let path: ExprPath = input.parse()?;
159            if attrs.fmt.is_some() {
160                return Err(Error::new_spanned(
161                    attr,
162                    "duplicate #[error(fmt = ...)] attribute",
163                ));
164            }
165            attrs.fmt = Some(Fmt {
166                original: attr,
167                path,
168            });
169            return Ok(());
170        } else {
171            return Err(lookahead.error());
172        };
173
174        let args = if input.is_empty() || input.peek(Token![,]) && input.peek2(End) {
175            input.parse::<Option<Token![,]>>()?;
176            TokenStream::new()
177        } else {
178            parse_token_expr(input, false)?
179        };
180
181        let requires_fmt_machinery = !args.is_empty();
182
183        let display = Display {
184            original: attr,
185            fmt,
186            args,
187            requires_fmt_machinery,
188            has_bonus_display: false,
189            infinite_recursive: false,
190            implied_bounds: Set::new(),
191            bindings: Vec::new(),
192        };
193        if attrs.display.is_some() {
194            return Err(Error::new_spanned(
195                attr,
196                "only one #[error(...)] attribute is allowed",
197            ));
198        }
199        attrs.display = Some(display);
200        Ok(())
201    })
202}
203
/// Collects the remaining tokens of an `#[error(...)]` argument list into a
/// `TokenStream`, rewriting field shorthand found where an expression begins:
///
/// - `.ident` loses its leading dot (`.field` becomes `field`);
/// - `.N` becomes the identifier `_N`;
/// - `.N.M`, which arrives as `.` followed by the float literal `N.M`, becomes
///   `_N`, `.`, `M` (all three tokens carry the float's span).
///
/// In every other case a leading `.` is passed through as an ordinary token.
///
/// `begin_expr` says whether the next token starts an expression. Before each
/// ordinary token it is recomputed from that token: it becomes `true` after
/// `,`, the listed operators, and keywords such as `if`/`match`/`return`.
/// The contents of `(...)`, `{...}` and `[...]` are processed recursively with
/// `begin_expr = true`. Invisible (`None`-delimited) groups are copied through
/// unchanged.
///
/// # Errors
///
/// Propagates any `syn` parse error, e.g. when the integer after a leading `.`
/// cannot be parsed as a tuple index.
fn parse_token_expr(input: ParseStream, mut begin_expr: bool) -> Result<TokenStream> {
    let mut tokens = Vec::new();
    while !input.is_empty() {
        // Invisible groups are passed through verbatim, without rewriting.
        if input.peek(token::Group) {
            let group: TokenTree = input.parse()?;
            tokens.push(group);
            begin_expr = false;
            continue;
        }

        if begin_expr && input.peek(Token![.]) {
            if input.peek2(Ident) {
                // `.field` -> `field`: drop the dot, the ident follows as-is.
                input.parse::<Token![.]>()?;
                begin_expr = false;
                continue;
            } else if input.peek2(LitInt) {
                // `.N` -> `_N`.
                input.parse::<Token![.]>()?;
                let int: Index = input.parse()?;
                tokens.push({
                    let ident = format_ident!("_{}", int.index, span = int.span);
                    TokenTree::Ident(ident)
                });
                begin_expr = false;
                continue;
            } else if input.peek2(LitFloat) {
                // `.N.M` -> `_N.M`. Parse on a fork so nothing is consumed
                // unless the float splits into exactly two tuple indices;
                // otherwise fall through and treat `.` as an ordinary token.
                let ahead = input.fork();
                ahead.parse::<Token![.]>()?;
                let float: LitFloat = ahead.parse()?;
                let repr = float.to_string();
                let mut indices = repr.split('.').map(syn::parse_str::<Index>);
                if let (Some(Ok(first)), Some(Ok(second)), None) =
                    (indices.next(), indices.next(), indices.next())
                {
                    input.advance_to(&ahead);
                    tokens.push({
                        let ident = format_ident!("_{}", first, span = float.span());
                        TokenTree::Ident(ident)
                    });
                    tokens.push({
                        let mut punct = Punct::new('.', Spacing::Alone);
                        punct.set_span(float.span());
                        TokenTree::Punct(punct)
                    });
                    tokens.push({
                        let mut literal = Literal::u32_unsuffixed(second.index);
                        literal.set_span(float.span());
                        TokenTree::Literal(literal)
                    });
                    begin_expr = false;
                    continue;
                }
            }
        }

        // Decide, from the token about to be consumed, whether the token
        // after it starts a new expression.
        begin_expr = input.peek(Token![break])
            || input.peek(Token![continue])
            || input.peek(Token![if])
            || input.peek(Token![in])
            || input.peek(Token![match])
            || input.peek(Token![mut])
            || input.peek(Token![return])
            || input.peek(Token![while])
            || input.peek(Token![+])
            || input.peek(Token![&])
            || input.peek(Token![!])
            || input.peek(Token![^])
            || input.peek(Token![,])
            || input.peek(Token![/])
            || input.peek(Token![=])
            || input.peek(Token![>])
            || input.peek(Token![<])
            || input.peek(Token![|])
            || input.peek(Token![%])
            || input.peek(Token![;])
            || input.peek(Token![*])
            || input.peek(Token![-]);

        // Delimited groups are rebuilt from their recursively rewritten
        // contents, keeping the delimiter and the group's span.
        let token: TokenTree = if input.peek(token::Paren) {
            let content;
            let delimiter = parenthesized!(content in input);
            let nested = parse_token_expr(&content, true)?;
            let mut group = Group::new(Delimiter::Parenthesis, nested);
            group.set_span(delimiter.span.join());
            TokenTree::Group(group)
        } else if input.peek(token::Brace) {
            let content;
            let delimiter = braced!(content in input);
            let nested = parse_token_expr(&content, true)?;
            let mut group = Group::new(Delimiter::Brace, nested);
            group.set_span(delimiter.span.join());
            TokenTree::Group(group)
        } else if input.peek(token::Bracket) {
            let content;
            let delimiter = bracketed!(content in input);
            let nested = parse_token_expr(&content, true)?;
            let mut group = Group::new(Delimiter::Bracket, nested);
            group.set_span(delimiter.span.join());
            TokenTree::Group(group)
        } else {
            input.parse()?
        };
        tokens.push(token);
    }
    Ok(TokenStream::from_iter(tokens))
}
309
310impl ToTokens for Display<'_> {
311    fn to_tokens(&self, tokens: &mut TokenStream) {
312        if self.infinite_recursive {
313            let span = self.fmt.span();
314            tokens.extend(quote_spanned! {span=>
315                #[warn(unconditional_recursion)]
316                fn _fmt() { _fmt() }
317            });
318        }
319
320        let fmt = &self.fmt;
321        let args = &self.args;
322
323        // Currently `write!(f, "text")` produces less efficient code than
324        // `f.write_str("text")`. We recognize the case when the format string
325        // has no braces and no interpolated values, and generate simpler code.
326        let write = if self.requires_fmt_machinery {
327            quote! {
328                ::core::write!(__formatter, #fmt #args)
329            }
330        } else {
331            quote! {
332                __formatter.write_str(#fmt)
333            }
334        };
335
336        tokens.extend(if self.bindings.is_empty() {
337            write
338        } else {
339            let locals = self.bindings.iter().map(|(local, _value)| local);
340            let values = self.bindings.iter().map(|(_local, value)| value);
341            quote! {
342                match (#(#values,)*) {
343                    (#(#locals,)*) => #write
344                }
345            }
346        });
347    }
348}
349
350impl ToTokens for Trait {
351    fn to_tokens(&self, tokens: &mut TokenStream) {
352        let trait_name = match self {
353            Trait::Debug => "Debug",
354            Trait::Display => "Display",
355            Trait::Octal => "Octal",
356            Trait::LowerHex => "LowerHex",
357            Trait::UpperHex => "UpperHex",
358            Trait::Pointer => "Pointer",
359            Trait::Binary => "Binary",
360            Trait::LowerExp => "LowerExp",
361            Trait::UpperExp => "UpperExp",
362        };
363        let ident = Ident::new(trait_name, Span::call_site());
364        tokens.extend(quote!(::core::fmt::#ident));
365    }
366}