vesta_syntax/
lib.rs

1use itertools::Itertools;
2use proc_macro2::Span;
3use proc_macro_crate::FoundCrate;
4use quote::{format_ident, quote, quote_spanned, ToTokens};
5use std::{
6    collections::{BTreeMap, BTreeSet},
7    env,
8};
9use syn::{
10    braced, parenthesized,
11    parse::{Parse, ParseStream},
12    parse_quote,
13    spanned::Spanned,
14    token::{Brace, Paren, Underscore},
15    Arm, Attribute, Error, Expr, Ident, LitInt, Pat, PatWild, Path, Token,
16};
17
18/// Get the absolute path to `vesta`, from within the package itself, the doc tests, or any other
19/// package. This means we can use these proc macros from inside `vesta` with no issue.
20pub fn vesta_path() -> Path {
21    match proc_macro_crate::crate_name("vesta") {
22        Ok(FoundCrate::Itself) if env::var("CARGO_CRATE_NAME").as_deref() == Ok("vesta") => {
23            parse_quote!(crate::vesta)
24        }
25        Ok(FoundCrate::Itself) | Err(_) => parse_quote!(::vesta),
26        Ok(FoundCrate::Name(name)) => {
27            let name_ident = format_ident!("{}", name);
28            parse_quote!(::#name_ident)
29        }
30    }
31}
32
33/// The input syntax to `vesta`'s `case!` macro. This implements [`Parse`].
34#[derive(Clone)]
35pub struct CaseInput {
36    /// The scrutinee of the `case!` macro: the thing upon which we are matching.
37    pub scrutinee: Expr,
38    /// The brace token wrapping all the cases.
39    pub brace_token: Brace,
40    /// The cases, as input by the user.
41    pub arms: Vec<CaseArm>,
42}
43
44impl Parse for CaseInput {
45    fn parse(input: ParseStream) -> syn::Result<Self> {
46        let scrutinee = Expr::parse_without_eager_brace(input)?;
47        let content;
48        let brace_token = braced!(content in input);
49        let mut arms = Vec::new();
50        while !content.is_empty() {
51            arms.push(content.call(CaseArm::parse)?);
52        }
53        Ok(CaseInput {
54            scrutinee,
55            arms,
56            brace_token,
57        })
58    }
59}
60
61/// A single arm of a `case!`, i.e. `1(x, Some(y)) => x + y,`. This implements [`Parse`].
62#[derive(Clone)]
63pub struct CaseArm {
64    /// The tag for this case, or `None` if the case was a catch-all `_` case.
65    pub tag: Option<usize>,
66    /// The span for the tag.
67    pub tag_span: Span,
68    /// The [`Arm`] for the case, i.e. the pattern following the tag, its `=>`, and its body.
69    pub arm: Arm,
70}
71
72impl Parse for CaseArm {
73    fn parse(input: ParseStream) -> syn::Result<Self> {
74        // We will fill in these fields:
75        let tag;
76        let tag_span;
77        let mut arm;
78
79        // Parse outer attributes
80        let attrs = input.call(Attribute::parse_outer)?;
81
82        if input.peek(Token![_]) {
83            // If wildcard pattern, the tag is `None`, parse an arm also with a wildcard pattern
84            tag = None;
85            tag_span = input.fork().parse::<Token![_]>()?.span();
86            arm = input.parse()?;
87        } else if input.peek2(Paren) {
88            // If of the form `N(...) => ...`, we *consume* the `N` token, then parse an `Arm` with
89            // the given pattern (after verifying that the thing *inside* the parentheses is
90            // non-empty, so as to make sure you can't write `N()`: you have to do either `N(())` or
91            // `N` alone)
92            let lit = input.parse::<LitInt>()?;
93            tag = Some(lit.base10_parse::<usize>()?);
94            tag_span = lit.span();
95            let pat;
96            parenthesized!(pat in input.fork());
97            if pat.is_empty() {
98                return Err(pat.error("expected pattern"));
99            }
100            arm = input.parse::<Arm>()?;
101        } else {
102            // If of the form `N => ...`, we parse the `N` token but do *not* consume it, then parse
103            // an `Arm` which will use that `N` token as its pattern, allowing us to re-use the
104            // `Arm`-parsing built into `syn`, then replace the pattern in the `Arm` itself with
105            // `_`, which is what we wanted in the first place
106            let lit = input.fork().parse::<LitInt>()?;
107            tag = Some(lit.base10_parse::<usize>()?);
108            tag_span = lit.span();
109            arm = input.parse::<Arm>()?;
110            // Explicitly construct a `_` pattern with the right span, so unreachable pattern
111            // warnings get displayed nicely
112            arm.pat = Pat::Wild(PatWild {
113                attrs: vec![],
114                underscore_token: Underscore { spans: [tag_span] },
115            });
116        };
117
118        // Add the previously-parsed outer attributes to the arm
119        arm.attrs.extend(attrs);
120
121        Ok(CaseArm { tag, tag_span, arm })
122    }
123}
124
125impl CaseInput {
126    /// Compile a [`CaseInput`] into a [`CaseOutput`], if it is valid input, or return an [`Error`]
127    /// if it is missing cases.
128    pub fn compile(self) -> Result<CaseOutput, Error> {
129        let CaseInput {
130            scrutinee,
131            arms,
132            brace_token,
133        } = self;
134
135        let mut cases: BTreeMap<usize, Vec<(Span, Arm)>> = BTreeMap::new();
136        let mut default: Option<(Span, Arm)> = None;
137        let mut unreachable: Vec<CaseArm> = Vec::new();
138        let mut all_tags = BTreeSet::new();
139
140        // Read each case arm into the appropriate location
141        for case_arm in arms {
142            if default.is_none() {
143                if let Some(tag) = case_arm.tag {
144                    all_tags.insert(tag);
145                    cases
146                        .entry(tag)
147                        .or_insert_with(Vec::new)
148                        .push((case_arm.tag_span, case_arm.arm));
149                } else {
150                    default = Some((case_arm.tag_span, case_arm.arm));
151                }
152            } else {
153                unreachable.push(case_arm);
154            }
155        }
156
157        // Compute the missing cases, if any were skipped when there was not a default
158        let max_tag: Option<usize> = all_tags.iter().rev().next().cloned();
159        let missing_cases = if let Some(max_tag) = max_tag {
160            if default.is_none() {
161                (0..=max_tag)
162                    .filter(|tag| !all_tags.contains(tag))
163                    .collect()
164            } else {
165                Vec::new()
166            }
167        } else {
168            Vec::new()
169        };
170
171        if missing_cases.is_empty() {
172            Ok(CaseOutput {
173                scrutinee,
174                brace_token,
175                cases,
176                default,
177                unreachable,
178            })
179        } else {
180            // Construct the list of missing cases as a nice string
181            let mut patterns = String::new();
182            let max = missing_cases.len().saturating_sub(1);
183            let mut previous = false;
184            for (n, tag) in missing_cases.iter().enumerate() {
185                if previous {
186                    if n == max {
187                        if max > 1 {
188                            patterns.push(',');
189                        }
190                        patterns.push_str(" and ");
191                    } else {
192                        patterns.push_str(", ");
193                    }
194                }
195                patterns.push_str(&format!("`{}`", tag));
196                previous = true;
197            }
198            let message = format!("non-exhaustive patterns: {} not covered", patterns);
199            Err(Error::new(scrutinee.span(), message))
200        }
201    }
202}
203
204/// The output of `vesta`'s `case!` macro, in a representation suitable for turning back into tokens
205/// via [`ToTokens`].
206#[derive(Clone)]
207pub struct CaseOutput {
208    /// The scrutinee of the `case!`.
209    pub scrutinee: Expr,
210    /// The brace token wrapping the whole of the cases.
211    pub brace_token: Brace,
212    /// The reachable cases, organized by which tag they belong to, ordered within each tag by the
213    /// order they were listed in the original input.
214    pub cases: BTreeMap<usize, Vec<(Span, Arm)>>,
215    /// The default case `_ => ...`, if there was any.
216    pub default: Option<(Span, Arm)>,
217    /// All the unreachable arms, for which we emit code so as to generate warnings.
218    pub unreachable: Vec<CaseArm>,
219}
220
221impl ToTokens for CaseOutput {
222    fn to_tokens(&self, stream: &mut proc_macro2::TokenStream) {
223        let vesta_path = crate::vesta_path();
224
225        // Generate hygienic idents named "value" and "tag"
226        let value_ident = Ident::new("value", Span::mixed_site());
227        let tag_ident = Ident::new("tag", Span::mixed_site());
228
229        let CaseOutput {
230            scrutinee,
231            brace_token,
232            cases,
233            default,
234            unreachable,
235        } = self;
236
237        // Get the span for all the cases
238        let cases_span = brace_token.span;
239
240        // Compute the max tag ever mentioned
241        let mut max_tag = None;
242        cases
243            .keys()
244            .chain(
245                unreachable
246                    .iter()
247                    .filter_map(|case_arm| case_arm.tag.as_ref()),
248            )
249            .for_each(|tag| {
250                max_tag = match max_tag {
251                    None => Some(tag),
252                    Some(max_tag) => Some(max_tag.max(tag)),
253                }
254            });
255
256        // Determine whether all the combined cases should have been exhaustive, and if so, what
257        // their bound should be
258        let exhaustive_cases = if default.is_some() {
259            None
260        } else {
261            Some(max_tag.map(|t| t + 1).unwrap_or(0))
262        };
263
264        // Generate all the reachable outer arms
265        let active_arms = cases.iter().map(|(tag, inner_cases)| {
266            let inner_arms = inner_cases.iter().map(|(_, arm)| arm);
267
268            // The pattern for the outer match on the tag, with a good span
269            let tag_span: Span = inner_cases
270                .iter()
271                .map(|(span, _)| span)
272                .cloned()
273                .fold1(|s, t| s.join(t).unwrap_or(s))
274                .unwrap_or_else(Span::call_site);
275            let pat = quote_spanned!(tag_span=> ::std::option::Option::Some(#tag));
276
277            // The default arm, if one exists, is allowed to be unreachable but always inserted in
278            // the inner match if it exists
279            let default_arm = default.iter().map(|(_, arm)| {
280                quote! {
281                    #[allow(unreachable_patterns)]
282                    #arm
283                }
284            });
285
286            quote! {
287                #pat => match unsafe {
288                    #vesta_path::Case::<#tag>::case(#value_ident)
289                } {
290                    #(#inner_arms)*
291                    #(#default_arm)*
292                }
293            }
294        });
295
296        // Generate the exhaustive fall-through case, if one is necessary
297        let exhaustive_arm = exhaustive_cases.iter().map(|num_cases| {
298            quote! {
299                _ => {
300                    #vesta_path::assert_exhaustive::<_, #num_cases>(&#value_ident);
301                    unsafe { #vesta_path::unreachable() }
302                }
303            }
304        });
305
306        // Generate all the unreachable arms, for maximum warning reporting
307        let unreachable_arms = unreachable
308            .iter()
309            .map(|CaseArm { tag, arm, tag_span }| match tag {
310                Some(tag) => quote_spanned! { *tag_span=>
311                    ::std::option::Option::Some(#tag) => match unsafe {
312                        #vesta_path::Case::<#tag>::case(#value_ident)
313                    } {
314                        #arm
315                        // We need to make this pattern match complete so that this type-checks, but
316                        // the only reason we're generating code at all is for warnings, so here we
317                        // say the next arm is unreachable: it *is* unreachable, because this whole
318                        // match expression is unreachable. This is only a valid assumption because
319                        // all the arms for which this is generated are unreachable.
320                        _ => unsafe { #vesta_path::unreachable() }
321                    }
322                },
323                None => quote!(#arm),
324            });
325
326        // Glue all the arms together
327        let arms = active_arms.chain(
328            exhaustive_arm.chain(
329                default
330                    .iter()
331                    // Unlike in the inner matches, we don't `#[allow(unreachable)]` the default
332                    .map(|(_, arm)| quote!(#arm))
333                    .chain(unreachable_arms),
334            ),
335        );
336
337        stream.extend(quote_spanned!(cases_span=> {
338            let #value_ident = #scrutinee;
339            let #tag_ident = #vesta_path::Match::tag(&#value_ident);
340            #[allow(unused_parens)]
341            match #tag_ident {
342                #(#arms)*
343            }
344        }))
345    }
346}