1use itertools::Itertools;
2use proc_macro2::Span;
3use proc_macro_crate::FoundCrate;
4use quote::{format_ident, quote, quote_spanned, ToTokens};
5use std::{
6 collections::{BTreeMap, BTreeSet},
7 env,
8};
9use syn::{
10 braced, parenthesized,
11 parse::{Parse, ParseStream},
12 parse_quote,
13 spanned::Spanned,
14 token::{Brace, Paren, Underscore},
15 Arm, Attribute, Error, Expr, Ident, LitInt, Pat, PatWild, Path, Token,
16};
17
18pub fn vesta_path() -> Path {
21 match proc_macro_crate::crate_name("vesta") {
22 Ok(FoundCrate::Itself) if env::var("CARGO_CRATE_NAME").as_deref() == Ok("vesta") => {
23 parse_quote!(crate::vesta)
24 }
25 Ok(FoundCrate::Itself) | Err(_) => parse_quote!(::vesta),
26 Ok(FoundCrate::Name(name)) => {
27 let name_ident = format_ident!("{}", name);
28 parse_quote!(::#name_ident)
29 }
30 }
31}
32
33#[derive(Clone)]
35pub struct CaseInput {
36 pub scrutinee: Expr,
38 pub brace_token: Brace,
40 pub arms: Vec<CaseArm>,
42}
43
44impl Parse for CaseInput {
45 fn parse(input: ParseStream) -> syn::Result<Self> {
46 let scrutinee = Expr::parse_without_eager_brace(input)?;
47 let content;
48 let brace_token = braced!(content in input);
49 let mut arms = Vec::new();
50 while !content.is_empty() {
51 arms.push(content.call(CaseArm::parse)?);
52 }
53 Ok(CaseInput {
54 scrutinee,
55 arms,
56 brace_token,
57 })
58 }
59}
60
61#[derive(Clone)]
63pub struct CaseArm {
64 pub tag: Option<usize>,
66 pub tag_span: Span,
68 pub arm: Arm,
70}
71
72impl Parse for CaseArm {
73 fn parse(input: ParseStream) -> syn::Result<Self> {
74 let tag;
76 let tag_span;
77 let mut arm;
78
79 let attrs = input.call(Attribute::parse_outer)?;
81
82 if input.peek(Token![_]) {
83 tag = None;
85 tag_span = input.fork().parse::<Token![_]>()?.span();
86 arm = input.parse()?;
87 } else if input.peek2(Paren) {
88 let lit = input.parse::<LitInt>()?;
93 tag = Some(lit.base10_parse::<usize>()?);
94 tag_span = lit.span();
95 let pat;
96 parenthesized!(pat in input.fork());
97 if pat.is_empty() {
98 return Err(pat.error("expected pattern"));
99 }
100 arm = input.parse::<Arm>()?;
101 } else {
102 let lit = input.fork().parse::<LitInt>()?;
107 tag = Some(lit.base10_parse::<usize>()?);
108 tag_span = lit.span();
109 arm = input.parse::<Arm>()?;
110 arm.pat = Pat::Wild(PatWild {
113 attrs: vec![],
114 underscore_token: Underscore { spans: [tag_span] },
115 });
116 };
117
118 arm.attrs.extend(attrs);
120
121 Ok(CaseArm { tag, tag_span, arm })
122 }
123}
124
125impl CaseInput {
126 pub fn compile(self) -> Result<CaseOutput, Error> {
129 let CaseInput {
130 scrutinee,
131 arms,
132 brace_token,
133 } = self;
134
135 let mut cases: BTreeMap<usize, Vec<(Span, Arm)>> = BTreeMap::new();
136 let mut default: Option<(Span, Arm)> = None;
137 let mut unreachable: Vec<CaseArm> = Vec::new();
138 let mut all_tags = BTreeSet::new();
139
140 for case_arm in arms {
142 if default.is_none() {
143 if let Some(tag) = case_arm.tag {
144 all_tags.insert(tag);
145 cases
146 .entry(tag)
147 .or_insert_with(Vec::new)
148 .push((case_arm.tag_span, case_arm.arm));
149 } else {
150 default = Some((case_arm.tag_span, case_arm.arm));
151 }
152 } else {
153 unreachable.push(case_arm);
154 }
155 }
156
157 let max_tag: Option<usize> = all_tags.iter().rev().next().cloned();
159 let missing_cases = if let Some(max_tag) = max_tag {
160 if default.is_none() {
161 (0..=max_tag)
162 .filter(|tag| !all_tags.contains(tag))
163 .collect()
164 } else {
165 Vec::new()
166 }
167 } else {
168 Vec::new()
169 };
170
171 if missing_cases.is_empty() {
172 Ok(CaseOutput {
173 scrutinee,
174 brace_token,
175 cases,
176 default,
177 unreachable,
178 })
179 } else {
180 let mut patterns = String::new();
182 let max = missing_cases.len().saturating_sub(1);
183 let mut previous = false;
184 for (n, tag) in missing_cases.iter().enumerate() {
185 if previous {
186 if n == max {
187 if max > 1 {
188 patterns.push(',');
189 }
190 patterns.push_str(" and ");
191 } else {
192 patterns.push_str(", ");
193 }
194 }
195 patterns.push_str(&format!("`{}`", tag));
196 previous = true;
197 }
198 let message = format!("non-exhaustive patterns: {} not covered", patterns);
199 Err(Error::new(scrutinee.span(), message))
200 }
201 }
202}
203
204#[derive(Clone)]
207pub struct CaseOutput {
208 pub scrutinee: Expr,
210 pub brace_token: Brace,
212 pub cases: BTreeMap<usize, Vec<(Span, Arm)>>,
215 pub default: Option<(Span, Arm)>,
217 pub unreachable: Vec<CaseArm>,
219}
220
221impl ToTokens for CaseOutput {
222 fn to_tokens(&self, stream: &mut proc_macro2::TokenStream) {
223 let vesta_path = crate::vesta_path();
224
225 let value_ident = Ident::new("value", Span::mixed_site());
227 let tag_ident = Ident::new("tag", Span::mixed_site());
228
229 let CaseOutput {
230 scrutinee,
231 brace_token,
232 cases,
233 default,
234 unreachable,
235 } = self;
236
237 let cases_span = brace_token.span;
239
240 let mut max_tag = None;
242 cases
243 .keys()
244 .chain(
245 unreachable
246 .iter()
247 .filter_map(|case_arm| case_arm.tag.as_ref()),
248 )
249 .for_each(|tag| {
250 max_tag = match max_tag {
251 None => Some(tag),
252 Some(max_tag) => Some(max_tag.max(tag)),
253 }
254 });
255
256 let exhaustive_cases = if default.is_some() {
259 None
260 } else {
261 Some(max_tag.map(|t| t + 1).unwrap_or(0))
262 };
263
264 let active_arms = cases.iter().map(|(tag, inner_cases)| {
266 let inner_arms = inner_cases.iter().map(|(_, arm)| arm);
267
268 let tag_span: Span = inner_cases
270 .iter()
271 .map(|(span, _)| span)
272 .cloned()
273 .fold1(|s, t| s.join(t).unwrap_or(s))
274 .unwrap_or_else(Span::call_site);
275 let pat = quote_spanned!(tag_span=> ::std::option::Option::Some(#tag));
276
277 let default_arm = default.iter().map(|(_, arm)| {
280 quote! {
281 #[allow(unreachable_patterns)]
282 #arm
283 }
284 });
285
286 quote! {
287 #pat => match unsafe {
288 #vesta_path::Case::<#tag>::case(#value_ident)
289 } {
290 #(#inner_arms)*
291 #(#default_arm)*
292 }
293 }
294 });
295
296 let exhaustive_arm = exhaustive_cases.iter().map(|num_cases| {
298 quote! {
299 _ => {
300 #vesta_path::assert_exhaustive::<_, #num_cases>(&#value_ident);
301 unsafe { #vesta_path::unreachable() }
302 }
303 }
304 });
305
306 let unreachable_arms = unreachable
308 .iter()
309 .map(|CaseArm { tag, arm, tag_span }| match tag {
310 Some(tag) => quote_spanned! { *tag_span=>
311 ::std::option::Option::Some(#tag) => match unsafe {
312 #vesta_path::Case::<#tag>::case(#value_ident)
313 } {
314 #arm
315 _ => unsafe { #vesta_path::unreachable() }
321 }
322 },
323 None => quote!(#arm),
324 });
325
326 let arms = active_arms.chain(
328 exhaustive_arm.chain(
329 default
330 .iter()
331 .map(|(_, arm)| quote!(#arm))
333 .chain(unreachable_arms),
334 ),
335 );
336
337 stream.extend(quote_spanned!(cases_span=> {
338 let #value_ident = #scrutinee;
339 let #tag_ident = #vesta_path::Match::tag(&#value_ident);
340 #[allow(unused_parens)]
341 match #tag_ident {
342 #(#arms)*
343 }
344 }))
345 }
346}