struct_variant/
lib.rs

1#![cfg_attr(feature = "doc", feature(extended_key_value_attributes))]
2#![cfg_attr(feature = "doc", cfg_attr(feature = "doc", doc = include_str!("../README.md")))]
3
4use std::cmp::Ordering;
5use std::collections::HashMap;
6
7use itertools::EitherOrBoth;
8use itertools::Itertools;
9use proc_macro::TokenStream;
10use proc_macro_error::proc_macro_error;
11use proc_macro_error::Diagnostic;
12use proc_macro_error::Level;
13use quote::quote;
14use syn::braced;
15use syn::parenthesized;
16use syn::parse;
17use syn::parse::Parse;
18use syn::parse::ParseStream;
19use syn::parse::Parser;
20use syn::parse_str;
21use syn::punctuated::Punctuated;
22use syn::token::Brace;
23use syn::token::Paren;
24use syn::Attribute;
25use syn::GenericParam;
26use syn::Generics;
27use syn::Ident;
28use syn::Lifetime;
29use syn::Path;
30use syn::PathSegment;
31use syn::Token;
32use syn::TraitBound;
33use syn::Visibility;
34
35struct Field {
36	pub paren_token: Paren,
37	pub path: Path,
38}
39
40impl Parse for Field {
41	#[allow(clippy::eval_order_dependence)]
42	fn parse(input: ParseStream) -> Result<Self, syn::Error> {
43		let content;
44		Ok(Field {
45			paren_token: parenthesized!(content in input),
46			path: content.parse()?,
47		})
48	}
49}
50
51struct Variant {
52	pub ident: Ident,
53	pub field: Option<Field>,
54}
55
56impl Parse for Variant {
57	fn parse(input: ParseStream) -> Result<Self, syn::Error> {
58		Ok(Variant {
59			ident: input.parse()?,
60			field: {
61				if input.peek(Paren) {
62					let field: Field = input.parse()?;
63					Some(field)
64				} else {
65					None
66				}
67			},
68		})
69	}
70}
71
72fn ident_to_path(ident: &Ident) -> Path {
73	let mut punctuated = Punctuated::new();
74	punctuated.push_value(syn::PathSegment {
75		ident: ident.clone(),
76		arguments: syn::PathArguments::None,
77	});
78	Path {
79		leading_colon: None,
80		segments: punctuated,
81	}
82}
83
84fn path_segment_cmp(path_segment_lhs: &PathSegment, path_segment_rhs: &PathSegment) -> Ordering {
85	path_segment_lhs
86		.ident
87		.cmp(&path_segment_rhs.ident)
88		.then_with(|| Ordering::Less)
89}
90
91fn path_cmp(path_lhs: &Path, path_rhs: &Path) -> Ordering {
92	if path_lhs.leading_colon.is_some() {
93		if path_rhs.leading_colon.is_none() {
94			return Ordering::Less;
95		}
96	} else if path_rhs.leading_colon.is_some() {
97		return Ordering::Greater;
98	}
99
100	path_lhs
101		.segments
102		.iter()
103		.zip_longest(path_rhs.segments.iter())
104		.map(|x| {
105			match x {
106				EitherOrBoth::Both(path_segment_lhs, path_segment_rhs) => {
107					path_segment_cmp(path_segment_lhs, path_segment_rhs)
108				}
109				EitherOrBoth::Left(_) => Ordering::Less,
110				EitherOrBoth::Right(_) => Ordering::Greater,
111			}
112		})
113		.find(|ordering| !matches!(ordering, Ordering::Equal))
114		.unwrap_or(Ordering::Equal)
115}
116
117struct VariantEnum {
118	pub attrs: Vec<Attribute>,
119	pub vis: Visibility,
120	pub enum_token: Token![enum],
121	pub ident: Ident,
122	pub generics: Generics,
123	pub brace_token: Brace,
124	pub variants: Punctuated<Variant, Token![,]>,
125}
126
127impl Parse for VariantEnum {
128	#[allow(clippy::eval_order_dependence)]
129	fn parse(input: ParseStream) -> Result<Self, syn::Error> {
130		let content;
131		Ok(VariantEnum {
132			attrs: input.call(Attribute::parse_outer)?,
133			vis: input.parse()?,
134			enum_token: input.parse()?,
135			ident: input.parse()?,
136			generics: input.parse()?,
137			brace_token: braced!(content in input),
138			variants: content.parse_terminated(Variant::parse)?,
139		})
140	}
141}
142
143// TODO: https://github.com/rust-lang/rust/issues/54722
144// TODO: https://github.com/rust-lang/rust/issues/54140
145#[proc_macro_error]
146#[proc_macro_attribute]
147pub fn struct_variant(metadata: TokenStream, input: TokenStream) -> TokenStream {
148	let parser = Punctuated::<TraitBound, Token![+]>::parse_terminated;
149	let bound_item = match parser.parse(metadata) {
150		Ok(item) => item,
151		Err(e) => {
152			Diagnostic::spanned(
153				e.span(),
154				Level::Error,
155				format!("Unable to parse struct variant attribute: {} ", e),
156			)
157			.abort()
158		}
159	};
160
161	let enum_item: VariantEnum = match parse(input) {
162		Ok(item) => item,
163		Err(e) => {
164			Diagnostic::spanned(
165				e.span(),
166				Level::Error,
167				format!("Failed to parse struct variant input: {}", e),
168			)
169			.abort()
170		}
171	};
172
173	let mut struct_map = HashMap::new();
174	for variant in &enum_item.variants {
175		let ident = &variant.ident;
176		if let Some(variant_duplicate) = struct_map.insert(ident.clone(), variant) {
177			Diagnostic::spanned(
178				variant.ident.span(),
179				Level::Error,
180				format!("Duplicate variant name: {}", &ident),
181			)
182			.span_note(
183				variant_duplicate.ident.span(),
184				"Duplicate variant name first found here".to_string(),
185			)
186			.emit()
187		}
188	}
189
190	let attrs = &enum_item.attrs;
191	let vis = &enum_item.vis;
192	let ident = &enum_item.ident;
193	let generics = &enum_item.generics;
194	let generic_params = &generics.params;
195	let generics_params_types = generics.params.iter().filter_map(|param| {
196		match param {
197			GenericParam::Type(t) => Some(t.ident.clone()),
198			_ => None,
199		}
200	});
201	let lifetime_ident: Lifetime = parse_str("'struct_variant_lifetime").unwrap();
202	let generics_params_types_lifetimes = quote! {
203		#( #generics_params_types: #lifetime_ident ),*
204	};
205
206	let enum_list: Vec<_> = struct_map
207		.values()
208		.map(|variant| {
209			let struct_ident = variant
210				.field
211				.as_ref()
212				.map(|field| field.path.clone())
213				.unwrap_or_else(|| ident_to_path(&variant.ident));
214			(&variant.ident, struct_ident)
215		})
216		.sorted_by(
217			|(lhs_variant_ident, lhs_struct_ident), (rhs_variant_ident, rhs_struct_ident)| {
218				lhs_variant_ident
219					.cmp(rhs_variant_ident)
220					.then_with(|| path_cmp(lhs_struct_ident, rhs_struct_ident))
221			},
222		)
223		.collect();
224	let bound_list: Vec<&Ident> = bound_item
225		.iter()
226		.map(|trait_bound| trait_bound.path.get_ident())
227		.map(Option::unwrap)
228		.collect();
229
230	let enum_field = enum_list.iter().map(|(variant_ident, struct_ident)| {
231		quote! {
232			#variant_ident(#struct_ident)
233		}
234	});
235
236	let from_impl = enum_list.iter().map(|(variant_ident, struct_ident)| {
237		quote! {
238			impl#generics From<#struct_ident> for #ident#generics {
239				fn from(value: #struct_ident) -> Self {
240					Self::#variant_ident(value)
241				}
242			}
243		}
244	});
245
246	let variant_list: Vec<_> = enum_list.iter().map(|(id, _)| id).collect();
247	let as_ref_match_arm = quote! {
248		#( #ident::#variant_list(ref value) => value ),*
249	};
250
251	// TODO: https://github.com/rust-lang/rust/issues/75294
252	let result = quote! {
253		#(#attrs)*
254		#vis enum #ident#generics {
255			#(#enum_field),*
256		}
257
258		#(#from_impl)*
259
260		#(
261			impl<
262				#lifetime_ident,
263				#generic_params
264			> AsRef<dyn #bound_list + #lifetime_ident> for #ident#generics
265			where #generics_params_types_lifetimes {
266				fn as_ref(&self) -> &(dyn #bound_list + #lifetime_ident) {
267					match self {
268						#as_ref_match_arm
269					}
270				}
271			}
272		)*
273	};
274	result.into()
275}
276
/// trybuild UI suite: each listed `tests/fail` case must fail to compile and
/// each listed `tests/pass` case must compile. Cases are registered in order.
#[test]
fn ui() {
	let cases = trybuild::TestCases::new();

	let must_fail = [
		"tests/fail/missing-struct.rs",
		"tests/fail/enum-syntax.rs",
		"tests/fail/not-enum.rs",
	];
	for path in must_fail.iter() {
		cases.compile_fail(path);
	}

	let must_pass = [
		"tests/pass/bound-single.rs",
		"tests/pass/bound-none.rs",
		"tests/pass/bound-multi.rs",
		"tests/pass/rename.rs",
		"tests/pass/generic.rs",
	];
	for path in must_pass.iter() {
		cases.pass(path);
	}
}