1#![allow(clippy::eval_order_dependence)]
2extern crate proc_macro;
3
4use linked_hash_set::LinkedHashSet;
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::quote;
9use std::iter::FromIterator;
10use syn::visit::Visit;
11use syn::{
12 braced, bracketed, parenthesized,
13 parse::{Parse, ParseStream},
14 parse_macro_input,
15 punctuated::Punctuated,
16 token, Attribute, Field, GenericArgument, Ident, Result, Token, Type, Visibility, WhereClause,
17 WherePredicate,
18};
19
20#[derive(Default)]
21struct TypeArgumentsCollectorVisitor {
22 items: LinkedHashSet<String>,
23}
24
25impl<'ast> Visit<'ast> for TypeArgumentsCollectorVisitor {
26 fn visit_ident(&mut self, id: &'ast Ident) {
27 self.items.insert(id.to_string());
28 }
29}
30
31struct TypeArgumentsCheckVisitor<'ast> {
32 args: &'ast Vec<TypeArgumentConfiguration<'ast>>,
33 matched: Vec<&'ast TypeArgumentConfiguration<'ast>>,
34}
35
36impl<'ast> Visit<'ast> for TypeArgumentsCheckVisitor<'ast> {
37 fn visit_ident(&mut self, id: &'ast Ident) {
38 let name = &id.to_string();
39 for arg in self.args.iter() {
40 for id in arg.identifiers.iter() {
41 if id == name {
42 self.matched.push(arg);
43 }
44 }
45 }
46 }
47}
48
49struct Generics {
50 #[allow(dead_code)]
51 start: Token![<],
52 args: Punctuated<GenericArgument, Token![,]>,
53 #[allow(dead_code)]
54 end: Token![>],
55}
56
57impl Parse for Generics {
58 fn parse(input: ParseStream) -> Result<Self> {
59 Ok(Generics {
60 start: input.parse()?,
61 args: {
62 let mut args = Punctuated::new();
63 loop {
64 if input.peek(Token![>]) {
65 break;
66 }
67 let value = input.parse()?;
68 args.push_value(value);
69 if input.peek(Token![>]) {
70 break;
71 }
72 let punct = input.parse()?;
73 args.push_punct(punct);
74 }
75 args
76 },
77 end: input.parse()?,
78 })
79 }
80}
81
82enum ActionVariant {
83 Omit(Punctuated<Ident, Token![,]>),
84 Include(Punctuated<Ident, Token![,]>),
85 Attr(Punctuated<Attribute, Token![,]>),
86 Upsert(Punctuated<Field, Token![,]>),
87 AsTuple,
88}
89
90struct Action {
91 #[allow(dead_code)]
92 parens: token::Paren,
93 fields: ActionVariant,
94}
95
96impl Parse for Action {
97 fn parse(input: ParseStream) -> Result<Self> {
98 let content;
99 let name: Ident = input.parse()?;
100 let name_str = &name.to_string();
101
102 Ok(Action {
103 parens: parenthesized!(content in input),
104 fields: {
105 if name_str == "omit" {
106 ActionVariant::Omit(content.parse_terminated(Ident::parse)?)
107 } else if name_str == "include" {
108 ActionVariant::Include(content.parse_terminated(Ident::parse)?)
109 } else if name_str == "as_tuple" {
110 ActionVariant::AsTuple
111 } else if name_str == "attr" {
112 use syn::parse_quote::ParseQuote;
113 ActionVariant::Attr(content.parse_terminated(Attribute::parse)?)
114 } else if name_str == "upsert" {
115 ActionVariant::Upsert(content.parse_terminated(Field::parse_named)?)
116 } else {
117 panic!("{} is not a valid action", name_str)
118 }
119 },
120 })
121 }
122}
123
124struct ConfigurationExpr {
125 struct_name: Ident,
126 #[allow(dead_code)]
127 arrow: Token![=>],
128 #[allow(dead_code)]
129 bracket: token::Bracket,
130 actions: Punctuated<Action, Token![,]>,
131}
132
133impl Parse for ConfigurationExpr {
134 fn parse(input: ParseStream) -> Result<Self> {
135 let struct_content;
136
137 Ok(ConfigurationExpr {
138 struct_name: input.parse()?,
139 arrow: input.parse::<Token![=>]>()?,
140 bracket: bracketed!(struct_content in input),
141 actions: struct_content.parse_terminated(Action::parse)?,
142 })
143 }
144}
145
146struct StructGen {
147 attrs: Vec<Attribute>,
148 visibility: Option<Visibility>,
149 generics: Option<Generics>,
150 where_clause: Option<WhereClause>,
151 #[allow(dead_code)]
152 brace: token::Brace,
153 fields: Punctuated<Field, Token![,]>,
154 #[allow(dead_code)]
155 arrow: token::FatArrow,
156 #[allow(dead_code)]
157 conf_brace: token::Brace,
158 conf: Punctuated<ConfigurationExpr, Token![,]>,
159}
160
161impl Parse for StructGen {
162 fn parse(input: ParseStream) -> Result<Self> {
163 let struct_content;
164 let conf_content;
165
166 Ok(StructGen {
167 attrs: input.call(Attribute::parse_outer)?,
168 visibility: {
169 if input.lookahead1().peek(Token![pub]) {
170 Some(input.parse()?)
171 } else {
172 None
173 }
174 },
175 generics: {
176 if input.lookahead1().peek(Token![<]) {
177 Some(input.parse()?)
178 } else {
179 None
180 }
181 },
182 where_clause: {
183 if input.lookahead1().peek(Token![where]) {
184 Some(input.parse()?)
185 } else {
186 None
187 }
188 },
189 brace: braced!(struct_content in input),
190 fields: struct_content.parse_terminated(Field::parse_named)?,
191 arrow: input.parse()?,
192 conf_brace: braced!(conf_content in input),
193 conf: conf_content.parse_terminated(ConfigurationExpr::parse)?,
194 })
195 }
196}
197
/// Flattened result of applying one configuration's actions: what a single generated
/// struct should contain.
struct StructOutputConfiguration<'ast> {
    /// Template field names listed in `omit(...)`.
    omitted_fields: LinkedHashSet<String>,
    /// Template field names listed in `include(...)`; when non-empty, only these
    /// template fields are kept and `omitted_fields` is not consulted.
    included_fields: LinkedHashSet<String>,
    /// Names of the `upsert(...)` fields; template fields with these names are skipped.
    upsert_fields_names: LinkedHashSet<String>,
    /// The `upsert(...)` fields themselves, emitted after the kept template fields.
    upsert_fields: Vec<&'ast Field>,
    /// Top-level attributes followed by this configuration's `attr(...)` attributes.
    attributes: Vec<&'ast Attribute>,
    /// Set by `as_tuple()`: emit a tuple struct made of the kept field types.
    is_tuple: bool,
}

/// A generic argument of the template together with every identifier found inside it.
struct TypeArgumentConfiguration<'ast> {
    /// The argument as written in the template's `<...>` list.
    arg: &'ast GenericArgument,
    /// Identifiers collected from `arg`; a field type or `where` predicate mentioning
    /// any of them is treated as depending on `arg`.
    identifiers: LinkedHashSet<String>,
}
211
212#[proc_macro]
213pub fn generate(input: TokenStream) -> TokenStream {
214 let StructGen {
215 attrs: top_level_attrs,
216 generics: parsed_generics,
217 where_clause,
218 fields: parsed_fields,
219 conf,
220 visibility,
221 ..
222 } = parse_macro_input!(input as StructGen);
223
224 let structs: Vec<(String, StructOutputConfiguration)> = conf
225 .iter()
226 .map(|c| {
227 let mut omitted_fields = LinkedHashSet::<String>::new();
228 let mut included_fields = LinkedHashSet::<String>::new();
229 let mut upsert_fields = Vec::<&Field>::new();
230 let mut upsert_fields_names = LinkedHashSet::<String>::new();
231 let mut attributes = Vec::<&Attribute>::new();
232 attributes.extend(top_level_attrs.iter());
233 let mut is_tuple = false;
234
235 for a in c.actions.iter() {
236 match &a.fields {
237 ActionVariant::Omit(fields) => {
238 omitted_fields.extend(fields.iter().map(|f| f.to_string()));
239 }
240 ActionVariant::Include(fields) => {
241 included_fields.extend(fields.iter().map(|f| f.to_string()));
242 }
243 ActionVariant::Attr(attrs) => {
244 attributes.extend(attrs.iter());
245 }
246 ActionVariant::Upsert(fields) => {
247 upsert_fields_names
248 .extend(fields.iter().map(|f| f.ident.as_ref().unwrap().to_string()));
249 upsert_fields.extend(fields);
250 }
251 ActionVariant::AsTuple => {
252 is_tuple = true;
253 }
254 }
255 }
256
257 (
258 c.struct_name.to_string(),
259 StructOutputConfiguration {
260 omitted_fields,
261 included_fields,
262 upsert_fields,
263 upsert_fields_names,
264 attributes,
265 is_tuple,
266 },
267 )
268 })
269 .collect();
270
271 let generics: Vec<TypeArgumentConfiguration> = if parsed_generics.is_some() {
272 parsed_generics
273 .as_ref()
274 .unwrap()
275 .args
276 .iter()
277 .map(|arg| {
278 let mut collector = TypeArgumentsCollectorVisitor {
279 ..Default::default()
280 };
281 collector.visit_generic_argument(arg);
282
283 TypeArgumentConfiguration {
284 arg,
285 identifiers: collector.items,
286 }
287 })
288 .collect()
289 } else {
290 Vec::new()
291 };
292
293 let wheres: Vec<(&WherePredicate, Vec<&TypeArgumentConfiguration>)> = if where_clause.is_some()
294 {
295 where_clause
296 .as_ref()
297 .unwrap()
298 .predicates
299 .iter()
300 .map(|p| {
301 let mut collector = TypeArgumentsCheckVisitor {
302 args: &generics,
303 matched: Vec::new(),
304 };
305 collector.visit_where_predicate(&p);
306
307 (p, collector.matched)
308 })
309 .collect()
310 } else {
311 Vec::new()
312 };
313
314 let fields: Vec<(&Field, Vec<&TypeArgumentConfiguration>)> = parsed_fields
315 .iter()
316 .map(|f| {
317 let mut collector = TypeArgumentsCheckVisitor {
318 args: &generics,
319 matched: Vec::new(),
320 };
321 collector.visit_type(&f.ty);
322
323 (f, collector.matched)
324 })
325 .collect();
326
327 let token_streams = structs.iter().map(
328 |(
329 struct_name,
330 StructOutputConfiguration {
331 omitted_fields,
332 attributes,
333 included_fields,
334 upsert_fields,
335 upsert_fields_names,
336 is_tuple
337 },
338 )| {
339 let mut used_fields = LinkedHashSet::<&Field>::new();
340 let mut used_types = LinkedHashSet::<&Type>::new();
341 let mut used_generics = LinkedHashSet::<&GenericArgument>::new();
342 let mut used_wheres = LinkedHashSet::<&WherePredicate>::new();
343
344 let test_skip_predicate: Box<dyn Fn(&Field) -> bool> = if included_fields.is_empty() {
345 Box::new(|f: &Field| {
346 let name = &f.ident.as_ref().unwrap().to_string();
347 upsert_fields_names.contains(name) || omitted_fields.contains(name)
348 })
349 } else {
350 Box::new(|f: &Field| {
351 let name = &f.ident.as_ref().unwrap().to_string();
352 upsert_fields_names.contains(name) || !included_fields.contains(name)
353 })
354 };
355
356 for (f, type_args) in fields.iter() {
357 if test_skip_predicate(f) {
358 continue;
359 }
360
361 if *is_tuple {
362 used_types.insert(&f.ty);
363 } else {
364 used_fields.insert(f);
365 }
366
367 for type_arg in type_args.iter() {
368 used_generics.insert(type_arg.arg);
369
370 for w in wheres.iter() {
371 for w_type_arg in w.1.iter() {
372 if w_type_arg.arg == type_arg.arg {
373 used_wheres.insert(w.0);
374 }
375 }
376 }
377 }
378 }
379 if *is_tuple {
380 used_types.extend(upsert_fields.iter().map(|f| &f.ty));
381 } else {
382 used_fields.extend(upsert_fields.iter());
383 }
384
385 let field_items = Vec::from_iter(used_fields);
386 let type_items = Vec::from_iter(used_types);
387 let generic_items = Vec::from_iter(used_generics);
388 let where_items = Vec::from_iter(used_wheres);
389 let struct_name_ident = Ident::new(struct_name, Span::call_site());
390 if *is_tuple {
391 if where_items.is_empty() {
392 quote! {
393 #(#attributes)*
394 #visibility struct #struct_name_ident <#(#generic_items),*> (#(#type_items),*);
395 }
396 } else {
397 quote! {
398 #(#attributes)*
399 #visibility struct #struct_name_ident <#(#generic_items),*> (#(#type_items),*) where #(#where_items),*;
400 }
401 }
402 } else if where_items.is_empty() {
403 quote! {
404 #(#attributes)*
405 #visibility struct #struct_name_ident <#(#generic_items),*> {
406 #(#field_items),*
407 }
408 }
409 } else {
410 quote! {
411 #(#attributes)*
412 #visibility struct #struct_name_ident <#(#generic_items),*> where #(#where_items),* {
413 #(#field_items),*
414 }
415 }
416 }
417 },
418 );
419
420 (quote! {
421 #(#token_streams)*
422 })
423 .into()
424}
425
#[cfg(test)]
mod tests {
    use path_clean::PathClean;
    use std::env;
    use std::path::{Path, PathBuf};
    use std::process::Command;

    /// Resolves `path` against the current working directory (when relative) and
    /// lexically normalises it.
    pub fn absolute_path(path: impl AsRef<Path>) -> std::io::Result<PathBuf> {
        let path = path.as_ref();

        let absolute_path = if path.is_absolute() {
            path.to_path_buf()
        } else {
            env::current_dir()?.join(path)
        }
        .clean();

        Ok(absolute_path)
    }

    /// Runs `cargo expand <fixture>` against the test-bed crate and returns the
    /// expanded source printed on stdout.
    ///
    /// # Panics
    /// Panics if the manifest path cannot be resolved, if `cargo` cannot be spawned,
    /// or if `cargo expand` exits unsuccessfully (its stderr is included in the
    /// message, rather than surfacing later as an empty-snapshot diff).
    fn run_for_fixture(fixture: &str) -> String {
        let manifest = absolute_path("./test_fixtures/testbed/Cargo.toml")
            .expect("failed to resolve the test-bed manifest path");

        let output = Command::new("cargo")
            .arg("expand")
            .arg(fixture)
            .arg("--manifest-path")
            .arg(&manifest)
            .output()
            .expect("Failed to spawn process");

        assert!(
            output.status.success(),
            "`cargo expand {}` failed:\n{}",
            fixture,
            String::from_utf8_lossy(&output.stderr)
        );

        String::from_utf8_lossy(&output.stdout).into_owned()
    }

    #[test]
    fn generics() {
        insta::assert_snapshot!(run_for_fixture("generics"), @r###"
        pub mod generics {
            use structout::generate;
            struct OnlyBar<T> {
                bar: T,
            }
            struct OnlyFoo {
                foo: u32,
            }
        }
        "###);
    }

    #[test]
    fn wheres() {
        insta::assert_snapshot!(run_for_fixture("wheres"), @r###"
        pub mod wheres {
            use structout::generate;
            struct OnlyBar<C>
            where
                C: Copy,
            {
                bar: C,
            }
            struct OnlyFoo<S>
            where
                S: Sized,
            {
                foo: S,
            }
        }
        "###);
    }

    #[test]
    fn simple() {
        insta::assert_snapshot!(run_for_fixture("simple"), @r###"
        pub mod simple {
            use structout::generate;
            struct WithoutFoo {
                bar: u64,
                baz: String,
            }
            struct WithoutBar {
                foo: u32,
                baz: String,
            }
            # [object (context = Database)]
            #[object(config = "latest")]
            struct WithAttrs {
                foo: u32,
                bar: u64,
                baz: String,
            }
        }
        "###);
    }

    #[test]
    fn visibility() {
        insta::assert_snapshot!(run_for_fixture("visibility"), @r###"
        pub mod visibility {
            use structout::generate;
            pub(crate) struct Everything {
                foo: u32,
            }
        }
        "###);
    }

    #[test]
    fn include() {
        insta::assert_snapshot!(run_for_fixture("include"), @r###"
        pub mod include {
            use structout::generate;
            struct WithoutFoo {
                bar: u64,
            }
            struct WithoutBar {
                foo: u32,
            }
        }
        "###);
    }

    #[test]
    fn as_tuple() {
        insta::assert_snapshot!(run_for_fixture("as_tuple"), @r###"
        pub mod as_tuple {
            use structout::generate;
            struct OnlyBar<C>(C, i32)
            where
                C: Copy;
            struct OnlyFoo<S>(S, i32)
            where
                S: Sized;
        }
        "###);
    }

    #[test]
    fn upsert() {
        insta::assert_snapshot!(run_for_fixture("upsert"), @r###"
        pub mod upsert {
            use structout::generate;
            struct NewFields {
                foo: u32,
                bar: i32,
                baz: i64,
            }
            struct OverriddenField {
                foo: u64,
            }
            struct Tupled(u64);
        }

        "###);
    }

    #[test]
    fn shared_attrs() {
        insta::assert_snapshot!(run_for_fixture("shared_attrs"), @r###"
        pub mod shared_attrs {
            use structout::generate;
            struct InheritsAttributes {
                foo: u32,
            }
            #[automatically_derived]
            #[allow(unused_qualifications)]
            impl ::core::fmt::Debug for InheritsAttributes {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    match *self {
                        InheritsAttributes {
                            foo: ref __self_0_0,
                        } => {
                            let mut debug_trait_builder = f.debug_struct("InheritsAttributes");
                            let _ = debug_trait_builder.field("foo", &&(*__self_0_0));
                            debug_trait_builder.finish()
                        }
                    }
                }
            }
            struct InheritsAttributesTwo {
                foo: u32,
            }
            #[automatically_derived]
            #[allow(unused_qualifications)]
            impl ::core::fmt::Debug for InheritsAttributesTwo {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    match *self {
                        InheritsAttributesTwo {
                            foo: ref __self_0_0,
                        } => {
                            let mut debug_trait_builder = f.debug_struct("InheritsAttributesTwo");
                            let _ = debug_trait_builder.field("foo", &&(*__self_0_0));
                            debug_trait_builder.finish()
                        }
                    }
                }
            }
        }
        "###);
    }
}