// With the `doc` feature the README becomes the crate docs. The
// `extended_key_value_attributes` gate is only needed on nightly toolchains
// from before `doc = include_str!(..)` was stabilised.
#![cfg_attr(feature = "doc", feature(extended_key_value_attributes))]
#![cfg_attr(feature = "doc", doc = include_str!("../README.md"))]
3
4use std::cmp::Ordering;
5use std::collections::HashMap;
6
7use itertools::EitherOrBoth;
8use itertools::Itertools;
9use proc_macro::TokenStream;
10use proc_macro_error::proc_macro_error;
11use proc_macro_error::Diagnostic;
12use proc_macro_error::Level;
13use quote::quote;
14use syn::braced;
15use syn::parenthesized;
16use syn::parse;
17use syn::parse::Parse;
18use syn::parse::ParseStream;
19use syn::parse::Parser;
20use syn::parse_str;
21use syn::punctuated::Punctuated;
22use syn::token::Brace;
23use syn::token::Paren;
24use syn::Attribute;
25use syn::GenericParam;
26use syn::Generics;
27use syn::Ident;
28use syn::Lifetime;
29use syn::Path;
30use syn::PathSegment;
31use syn::Token;
32use syn::TraitBound;
33use syn::Visibility;
34
35struct Field {
36 pub paren_token: Paren,
37 pub path: Path,
38}
39
40impl Parse for Field {
41 #[allow(clippy::eval_order_dependence)]
42 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
43 let content;
44 Ok(Field {
45 paren_token: parenthesized!(content in input),
46 path: content.parse()?,
47 })
48 }
49}
50
51struct Variant {
52 pub ident: Ident,
53 pub field: Option<Field>,
54}
55
56impl Parse for Variant {
57 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
58 Ok(Variant {
59 ident: input.parse()?,
60 field: {
61 if input.peek(Paren) {
62 let field: Field = input.parse()?;
63 Some(field)
64 } else {
65 None
66 }
67 },
68 })
69 }
70}
71
72fn ident_to_path(ident: &Ident) -> Path {
73 let mut punctuated = Punctuated::new();
74 punctuated.push_value(syn::PathSegment {
75 ident: ident.clone(),
76 arguments: syn::PathArguments::None,
77 });
78 Path {
79 leading_colon: None,
80 segments: punctuated,
81 }
82}
83
84fn path_segment_cmp(path_segment_lhs: &PathSegment, path_segment_rhs: &PathSegment) -> Ordering {
85 path_segment_lhs
86 .ident
87 .cmp(&path_segment_rhs.ident)
88 .then_with(|| Ordering::Less)
89}
90
91fn path_cmp(path_lhs: &Path, path_rhs: &Path) -> Ordering {
92 if path_lhs.leading_colon.is_some() {
93 if path_rhs.leading_colon.is_none() {
94 return Ordering::Less;
95 }
96 } else if path_rhs.leading_colon.is_some() {
97 return Ordering::Greater;
98 }
99
100 path_lhs
101 .segments
102 .iter()
103 .zip_longest(path_rhs.segments.iter())
104 .map(|x| {
105 match x {
106 EitherOrBoth::Both(path_segment_lhs, path_segment_rhs) => {
107 path_segment_cmp(path_segment_lhs, path_segment_rhs)
108 }
109 EitherOrBoth::Left(_) => Ordering::Less,
110 EitherOrBoth::Right(_) => Ordering::Greater,
111 }
112 })
113 .find(|ordering| !matches!(ordering, Ordering::Equal))
114 .unwrap_or(Ordering::Equal)
115}
116
117struct VariantEnum {
118 pub attrs: Vec<Attribute>,
119 pub vis: Visibility,
120 pub enum_token: Token![enum],
121 pub ident: Ident,
122 pub generics: Generics,
123 pub brace_token: Brace,
124 pub variants: Punctuated<Variant, Token![,]>,
125}
126
127impl Parse for VariantEnum {
128 #[allow(clippy::eval_order_dependence)]
129 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
130 let content;
131 Ok(VariantEnum {
132 attrs: input.call(Attribute::parse_outer)?,
133 vis: input.parse()?,
134 enum_token: input.parse()?,
135 ident: input.parse()?,
136 generics: input.parse()?,
137 brace_token: braced!(content in input),
138 variants: content.parse_terminated(Variant::parse)?,
139 })
140 }
141}
142
143#[proc_macro_error]
146#[proc_macro_attribute]
147pub fn struct_variant(metadata: TokenStream, input: TokenStream) -> TokenStream {
148 let parser = Punctuated::<TraitBound, Token![+]>::parse_terminated;
149 let bound_item = match parser.parse(metadata) {
150 Ok(item) => item,
151 Err(e) => {
152 Diagnostic::spanned(
153 e.span(),
154 Level::Error,
155 format!("Unable to parse struct variant attribute: {} ", e),
156 )
157 .abort()
158 }
159 };
160
161 let enum_item: VariantEnum = match parse(input) {
162 Ok(item) => item,
163 Err(e) => {
164 Diagnostic::spanned(
165 e.span(),
166 Level::Error,
167 format!("Failed to parse struct variant input: {}", e),
168 )
169 .abort()
170 }
171 };
172
173 let mut struct_map = HashMap::new();
174 for variant in &enum_item.variants {
175 let ident = &variant.ident;
176 if let Some(variant_duplicate) = struct_map.insert(ident.clone(), variant) {
177 Diagnostic::spanned(
178 variant.ident.span(),
179 Level::Error,
180 format!("Duplicate variant name: {}", &ident),
181 )
182 .span_note(
183 variant_duplicate.ident.span(),
184 "Duplicate variant name first found here".to_string(),
185 )
186 .emit()
187 }
188 }
189
190 let attrs = &enum_item.attrs;
191 let vis = &enum_item.vis;
192 let ident = &enum_item.ident;
193 let generics = &enum_item.generics;
194 let generic_params = &generics.params;
195 let generics_params_types = generics.params.iter().filter_map(|param| {
196 match param {
197 GenericParam::Type(t) => Some(t.ident.clone()),
198 _ => None,
199 }
200 });
201 let lifetime_ident: Lifetime = parse_str("'struct_variant_lifetime").unwrap();
202 let generics_params_types_lifetimes = quote! {
203 #( #generics_params_types: #lifetime_ident ),*
204 };
205
206 let enum_list: Vec<_> = struct_map
207 .values()
208 .map(|variant| {
209 let struct_ident = variant
210 .field
211 .as_ref()
212 .map(|field| field.path.clone())
213 .unwrap_or_else(|| ident_to_path(&variant.ident));
214 (&variant.ident, struct_ident)
215 })
216 .sorted_by(
217 |(lhs_variant_ident, lhs_struct_ident), (rhs_variant_ident, rhs_struct_ident)| {
218 lhs_variant_ident
219 .cmp(rhs_variant_ident)
220 .then_with(|| path_cmp(lhs_struct_ident, rhs_struct_ident))
221 },
222 )
223 .collect();
224 let bound_list: Vec<&Ident> = bound_item
225 .iter()
226 .map(|trait_bound| trait_bound.path.get_ident())
227 .map(Option::unwrap)
228 .collect();
229
230 let enum_field = enum_list.iter().map(|(variant_ident, struct_ident)| {
231 quote! {
232 #variant_ident(#struct_ident)
233 }
234 });
235
236 let from_impl = enum_list.iter().map(|(variant_ident, struct_ident)| {
237 quote! {
238 impl#generics From<#struct_ident> for #ident#generics {
239 fn from(value: #struct_ident) -> Self {
240 Self::#variant_ident(value)
241 }
242 }
243 }
244 });
245
246 let variant_list: Vec<_> = enum_list.iter().map(|(id, _)| id).collect();
247 let as_ref_match_arm = quote! {
248 #( #ident::#variant_list(ref value) => value ),*
249 };
250
251 let result = quote! {
253 #(#attrs)*
254 #vis enum #ident#generics {
255 #(#enum_field),*
256 }
257
258 #(#from_impl)*
259
260 #(
261 impl<
262 #lifetime_ident,
263 #generic_params
264 > AsRef<dyn #bound_list + #lifetime_ident> for #ident#generics
265 where #generics_params_types_lifetimes {
266 fn as_ref(&self) -> &(dyn #bound_list + #lifetime_ident) {
267 match self {
268 #as_ref_match_arm
269 }
270 }
271 }
272 )*
273 };
274 result.into()
275}
276
/// trybuild UI suite: the `tests/fail` inputs must be rejected by the
/// compiler and the `tests/pass` inputs must compile.
#[test]
fn ui() {
    let cases = trybuild::TestCases::new();

    let must_fail = [
        "tests/fail/missing-struct.rs",
        "tests/fail/enum-syntax.rs",
        "tests/fail/not-enum.rs",
    ];
    for path in must_fail.iter() {
        cases.compile_fail(path);
    }

    let must_pass = [
        "tests/pass/bound-single.rs",
        "tests/pass/bound-none.rs",
        "tests/pass/bound-multi.rs",
        "tests/pass/rename.rs",
        "tests/pass/generic.rs",
    ];
    for path in must_pass.iter() {
        cases.pass(path);
    }
}