structout/
lib.rs

1#![allow(clippy::eval_order_dependence)]
2extern crate proc_macro;
3
// LinkedHashSet is used instead of HashSet in order to preserve insertion order across the board
5use linked_hash_set::LinkedHashSet;
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use quote::quote;
9use std::iter::FromIterator;
10use syn::visit::Visit;
11use syn::{
12    braced, bracketed, parenthesized,
13    parse::{Parse, ParseStream},
14    parse_macro_input,
15    punctuated::Punctuated,
16    token, Attribute, Field, GenericArgument, Ident, Result, Token, Type, Visibility, WhereClause,
17    WherePredicate,
18};
19
20#[derive(Default)]
21struct TypeArgumentsCollectorVisitor {
22    items: LinkedHashSet<String>,
23}
24
25impl<'ast> Visit<'ast> for TypeArgumentsCollectorVisitor {
26    fn visit_ident(&mut self, id: &'ast Ident) {
27        self.items.insert(id.to_string());
28    }
29}
30
31struct TypeArgumentsCheckVisitor<'ast> {
32    args: &'ast Vec<TypeArgumentConfiguration<'ast>>,
33    matched: Vec<&'ast TypeArgumentConfiguration<'ast>>,
34}
35
36impl<'ast> Visit<'ast> for TypeArgumentsCheckVisitor<'ast> {
37    fn visit_ident(&mut self, id: &'ast Ident) {
38        let name = &id.to_string();
39        for arg in self.args.iter() {
40            for id in arg.identifiers.iter() {
41                if id == name {
42                    self.matched.push(arg);
43                }
44            }
45        }
46    }
47}
48
49struct Generics {
50    #[allow(dead_code)]
51    start: Token![<],
52    args: Punctuated<GenericArgument, Token![,]>,
53    #[allow(dead_code)]
54    end: Token![>],
55}
56
57impl Parse for Generics {
58    fn parse(input: ParseStream) -> Result<Self> {
59        Ok(Generics {
60            start: input.parse()?,
61            args: {
62                let mut args = Punctuated::new();
63                loop {
64                    if input.peek(Token![>]) {
65                        break;
66                    }
67                    let value = input.parse()?;
68                    args.push_value(value);
69                    if input.peek(Token![>]) {
70                        break;
71                    }
72                    let punct = input.parse()?;
73                    args.push_punct(punct);
74                }
75                args
76            },
77            end: input.parse()?,
78        })
79    }
80}
81
82enum ActionVariant {
83    Omit(Punctuated<Ident, Token![,]>),
84    Include(Punctuated<Ident, Token![,]>),
85    Attr(Punctuated<Attribute, Token![,]>),
86    Upsert(Punctuated<Field, Token![,]>),
87    AsTuple,
88}
89
90struct Action {
91    #[allow(dead_code)]
92    parens: token::Paren,
93    fields: ActionVariant,
94}
95
96impl Parse for Action {
97    fn parse(input: ParseStream) -> Result<Self> {
98        let content;
99        let name: Ident = input.parse()?;
100        let name_str = &name.to_string();
101
102        Ok(Action {
103            parens: parenthesized!(content in input),
104            fields: {
105                if name_str == "omit" {
106                    ActionVariant::Omit(content.parse_terminated(Ident::parse)?)
107                } else if name_str == "include" {
108                    ActionVariant::Include(content.parse_terminated(Ident::parse)?)
109                } else if name_str == "as_tuple" {
110                    ActionVariant::AsTuple
111                } else if name_str == "attr" {
112                    use syn::parse_quote::ParseQuote;
113                    ActionVariant::Attr(content.parse_terminated(Attribute::parse)?)
114                } else if name_str == "upsert" {
115                    ActionVariant::Upsert(content.parse_terminated(Field::parse_named)?)
116                } else {
117                    panic!("{} is not a valid action", name_str)
118                }
119            },
120        })
121    }
122}
123
124struct ConfigurationExpr {
125    struct_name: Ident,
126    #[allow(dead_code)]
127    arrow: Token![=>],
128    #[allow(dead_code)]
129    bracket: token::Bracket,
130    actions: Punctuated<Action, Token![,]>,
131}
132
133impl Parse for ConfigurationExpr {
134    fn parse(input: ParseStream) -> Result<Self> {
135        let struct_content;
136
137        Ok(ConfigurationExpr {
138            struct_name: input.parse()?,
139            arrow: input.parse::<Token![=>]>()?,
140            bracket: bracketed!(struct_content in input),
141            actions: struct_content.parse_terminated(Action::parse)?,
142        })
143    }
144}
145
146struct StructGen {
147    attrs: Vec<Attribute>,
148    visibility: Option<Visibility>,
149    generics: Option<Generics>,
150    where_clause: Option<WhereClause>,
151    #[allow(dead_code)]
152    brace: token::Brace,
153    fields: Punctuated<Field, Token![,]>,
154    #[allow(dead_code)]
155    arrow: token::FatArrow,
156    #[allow(dead_code)]
157    conf_brace: token::Brace,
158    conf: Punctuated<ConfigurationExpr, Token![,]>,
159}
160
161impl Parse for StructGen {
162    fn parse(input: ParseStream) -> Result<Self> {
163        let struct_content;
164        let conf_content;
165
166        Ok(StructGen {
167            attrs: input.call(Attribute::parse_outer)?,
168            visibility: {
169                if input.lookahead1().peek(Token![pub]) {
170                    Some(input.parse()?)
171                } else {
172                    None
173                }
174            },
175            generics: {
176                if input.lookahead1().peek(Token![<]) {
177                    Some(input.parse()?)
178                } else {
179                    None
180                }
181            },
182            where_clause: {
183                if input.lookahead1().peek(Token![where]) {
184                    Some(input.parse()?)
185                } else {
186                    None
187                }
188            },
189            brace: braced!(struct_content in input),
190            fields: struct_content.parse_terminated(Field::parse_named)?,
191            arrow: input.parse()?,
192            conf_brace: braced!(conf_content in input),
193            conf: conf_content.parse_terminated(ConfigurationExpr::parse)?,
194        })
195    }
196}
197
198struct StructOutputConfiguration<'ast> {
199    omitted_fields: LinkedHashSet<String>,
200    included_fields: LinkedHashSet<String>,
201    upsert_fields_names: LinkedHashSet<String>,
202    upsert_fields: Vec<&'ast Field>,
203    attributes: Vec<&'ast Attribute>,
204    is_tuple: bool,
205}
206
207struct TypeArgumentConfiguration<'ast> {
208    arg: &'ast GenericArgument,
209    identifiers: LinkedHashSet<String>,
210}
211
212#[proc_macro]
213pub fn generate(input: TokenStream) -> TokenStream {
214    let StructGen {
215        attrs: top_level_attrs,
216        generics: parsed_generics,
217        where_clause,
218        fields: parsed_fields,
219        conf,
220        visibility,
221        ..
222    } = parse_macro_input!(input as StructGen);
223
224    let structs: Vec<(String, StructOutputConfiguration)> = conf
225        .iter()
226        .map(|c| {
227            let mut omitted_fields = LinkedHashSet::<String>::new();
228            let mut included_fields = LinkedHashSet::<String>::new();
229            let mut upsert_fields = Vec::<&Field>::new();
230            let mut upsert_fields_names = LinkedHashSet::<String>::new();
231            let mut attributes = Vec::<&Attribute>::new();
232            attributes.extend(top_level_attrs.iter());
233            let mut is_tuple = false;
234
235            for a in c.actions.iter() {
236                match &a.fields {
237                    ActionVariant::Omit(fields) => {
238                        omitted_fields.extend(fields.iter().map(|f| f.to_string()));
239                    }
240                    ActionVariant::Include(fields) => {
241                        included_fields.extend(fields.iter().map(|f| f.to_string()));
242                    }
243                    ActionVariant::Attr(attrs) => {
244                        attributes.extend(attrs.iter());
245                    }
246                    ActionVariant::Upsert(fields) => {
247                        upsert_fields_names
248                            .extend(fields.iter().map(|f| f.ident.as_ref().unwrap().to_string()));
249                        upsert_fields.extend(fields);
250                    }
251                    ActionVariant::AsTuple => {
252                        is_tuple = true;
253                    }
254                }
255            }
256
257            (
258                c.struct_name.to_string(),
259                StructOutputConfiguration {
260                    omitted_fields,
261                    included_fields,
262                    upsert_fields,
263                    upsert_fields_names,
264                    attributes,
265                    is_tuple,
266                },
267            )
268        })
269        .collect();
270
271    let generics: Vec<TypeArgumentConfiguration> = if parsed_generics.is_some() {
272        parsed_generics
273            .as_ref()
274            .unwrap()
275            .args
276            .iter()
277            .map(|arg| {
278                let mut collector = TypeArgumentsCollectorVisitor {
279                    ..Default::default()
280                };
281                collector.visit_generic_argument(arg);
282
283                TypeArgumentConfiguration {
284                    arg,
285                    identifiers: collector.items,
286                }
287            })
288            .collect()
289    } else {
290        Vec::new()
291    };
292
293    let wheres: Vec<(&WherePredicate, Vec<&TypeArgumentConfiguration>)> = if where_clause.is_some()
294    {
295        where_clause
296            .as_ref()
297            .unwrap()
298            .predicates
299            .iter()
300            .map(|p| {
301                let mut collector = TypeArgumentsCheckVisitor {
302                    args: &generics,
303                    matched: Vec::new(),
304                };
305                collector.visit_where_predicate(&p);
306
307                (p, collector.matched)
308            })
309            .collect()
310    } else {
311        Vec::new()
312    };
313
314    let fields: Vec<(&Field, Vec<&TypeArgumentConfiguration>)> = parsed_fields
315        .iter()
316        .map(|f| {
317            let mut collector = TypeArgumentsCheckVisitor {
318                args: &generics,
319                matched: Vec::new(),
320            };
321            collector.visit_type(&f.ty);
322
323            (f, collector.matched)
324        })
325        .collect();
326
327    let token_streams = structs.iter().map(
328        |(
329            struct_name,
330            StructOutputConfiguration {
331                omitted_fields,
332                attributes,
333                included_fields,
334                upsert_fields,
335                upsert_fields_names,
336                is_tuple
337            },
338        )| {
339            let mut used_fields = LinkedHashSet::<&Field>::new();
340            let mut used_types = LinkedHashSet::<&Type>::new();
341            let mut used_generics = LinkedHashSet::<&GenericArgument>::new();
342            let mut used_wheres = LinkedHashSet::<&WherePredicate>::new();
343
344            let test_skip_predicate: Box<dyn Fn(&Field) -> bool> = if included_fields.is_empty() {
345                Box::new(|f: &Field| {
346                    let name = &f.ident.as_ref().unwrap().to_string();
347                    upsert_fields_names.contains(name) || omitted_fields.contains(name)
348                })
349            } else {
350                Box::new(|f: &Field| {
351                    let name = &f.ident.as_ref().unwrap().to_string();
352                    upsert_fields_names.contains(name) || !included_fields.contains(name)
353                })
354            };
355
356            for (f, type_args) in fields.iter() {
357                if test_skip_predicate(f) {
358                    continue;
359                }
360
361                if *is_tuple {
362                    used_types.insert(&f.ty);
363                } else {
364                    used_fields.insert(f);
365                }
366
367                for type_arg in type_args.iter() {
368                    used_generics.insert(type_arg.arg);
369
370                    for w in wheres.iter() {
371                        for w_type_arg in w.1.iter() {
372                            if w_type_arg.arg == type_arg.arg {
373                                used_wheres.insert(w.0);
374                            }
375                        }
376                    }
377                }
378            }
379            if *is_tuple {
380                used_types.extend(upsert_fields.iter().map(|f| &f.ty));
381            } else {
382                used_fields.extend(upsert_fields.iter());
383            }
384
385            let field_items = Vec::from_iter(used_fields);
386            let type_items = Vec::from_iter(used_types);
387            let generic_items = Vec::from_iter(used_generics);
388            let where_items = Vec::from_iter(used_wheres);
389            let struct_name_ident = Ident::new(struct_name, Span::call_site());
390            if *is_tuple {
391                if where_items.is_empty() {
392                    quote! {
393                        #(#attributes)*
394                        #visibility struct #struct_name_ident <#(#generic_items),*> (#(#type_items),*);
395                    }
396                } else {
397                    quote! {
398                        #(#attributes)*
399                        #visibility struct #struct_name_ident <#(#generic_items),*> (#(#type_items),*) where #(#where_items),*;
400                    }
401                }
402            } else if where_items.is_empty() {
403                quote! {
404                    #(#attributes)*
405                    #visibility struct #struct_name_ident <#(#generic_items),*> {
406                        #(#field_items),*
407                    }
408                }
409            } else {
410                quote! {
411                    #(#attributes)*
412                    #visibility struct #struct_name_ident <#(#generic_items),*> where #(#where_items),* {
413                        #(#field_items),*
414                    }
415                }
416            }
417        },
418    );
419
420    (quote! {
421       #(#token_streams)*
422    })
423    .into()
424}
425
#[cfg(test)]
mod tests {
    use path_clean::PathClean;
    use std::env;
    use std::path::{Path, PathBuf};
    use std::process::Command;

    /// Resolves `path` against the current directory when it is relative,
    /// then lexically normalises it with `path_clean`.
    pub fn absolute_path(path: impl AsRef<Path>) -> std::io::Result<PathBuf> {
        let path = path.as_ref();

        let absolute_path = if path.is_absolute() {
            path.to_path_buf()
        } else {
            env::current_dir()?.join(path)
        }
        .clean();

        Ok(absolute_path)
    }

    /// Runs `cargo expand <fixture>` against the testbed crate and returns
    /// the expanded source printed on stdout.
    ///
    /// # Panics
    /// Panics if `cargo` cannot be spawned, or if `cargo expand` exits with a
    /// failure status. In that case the message includes its stderr, so a
    /// broken fixture is reported directly instead of as an empty snapshot diff.
    fn run_for_fixture(fixture: &str) -> String {
        let manifest = absolute_path("./test_fixtures/testbed/Cargo.toml")
            .expect("failed to resolve the testbed manifest path");

        let output = Command::new("cargo")
            .arg("expand")
            .arg(fixture)
            .arg("--manifest-path")
            .arg(&manifest)
            .output()
            .expect("Failed to spawn process");

        assert!(
            output.status.success(),
            "`cargo expand {}` failed:\n{}",
            fixture,
            String::from_utf8_lossy(&output.stderr)
        );

        String::from_utf8_lossy(&output.stdout).into_owned()
    }

    #[test]
    fn generics() {
        insta::assert_snapshot!(run_for_fixture("generics"), @r###"
        pub mod generics {
            use structout::generate;
            struct OnlyBar<T> {
                bar: T,
            }
            struct OnlyFoo {
                foo: u32,
            }
        }
        "###);
    }

    #[test]
    fn wheres() {
        insta::assert_snapshot!(run_for_fixture("wheres"), @r###"
        pub mod wheres {
            use structout::generate;
            struct OnlyBar<C>
            where
                C: Copy,
            {
                bar: C,
            }
            struct OnlyFoo<S>
            where
                S: Sized,
            {
                foo: S,
            }
        }
        "###);
    }

    #[test]
    fn simple() {
        insta::assert_snapshot!(run_for_fixture("simple"), @r###"
        pub mod simple {
            use structout::generate;
            struct WithoutFoo {
                bar: u64,
                baz: String,
            }
            struct WithoutBar {
                foo: u32,
                baz: String,
            }
            # [object (context = Database)]
            #[object(config = "latest")]
            struct WithAttrs {
                foo: u32,
                bar: u64,
                baz: String,
            }
        }
        "###);
    }

    #[test]
    fn visibility() {
        insta::assert_snapshot!(run_for_fixture("visibility"), @r###"
        pub mod visibility {
            use structout::generate;
            pub(crate) struct Everything {
                foo: u32,
            }
        }
        "###);
    }

    #[test]
    fn include() {
        insta::assert_snapshot!(run_for_fixture("include"), @r###"
        pub mod include {
            use structout::generate;
            struct WithoutFoo {
                bar: u64,
            }
            struct WithoutBar {
                foo: u32,
            }
        }
        "###);
    }

    #[test]
    fn as_tuple() {
        insta::assert_snapshot!(run_for_fixture("as_tuple"), @r###"
        pub mod as_tuple {
            use structout::generate;
            struct OnlyBar<C>(C, i32)
            where
                C: Copy;
            struct OnlyFoo<S>(S, i32)
            where
                S: Sized;
        }
        "###);
    }

    #[test]
    fn upsert() {
        insta::assert_snapshot!(run_for_fixture("upsert"), @r###"
        pub mod upsert {
            use structout::generate;
            struct NewFields {
                foo: u32,
                bar: i32,
                baz: i64,
            }
            struct OverriddenField {
                foo: u64,
            }
            struct Tupled(u64);
        }
        
        "###);
    }

    #[test]
    fn shared_attrs() {
        insta::assert_snapshot!(run_for_fixture("shared_attrs"), @r###"
        pub mod shared_attrs {
            use structout::generate;
            struct InheritsAttributes {
                foo: u32,
            }
            #[automatically_derived]
            #[allow(unused_qualifications)]
            impl ::core::fmt::Debug for InheritsAttributes {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    match *self {
                        InheritsAttributes {
                            foo: ref __self_0_0,
                        } => {
                            let mut debug_trait_builder = f.debug_struct("InheritsAttributes");
                            let _ = debug_trait_builder.field("foo", &&(*__self_0_0));
                            debug_trait_builder.finish()
                        }
                    }
                }
            }
            struct InheritsAttributesTwo {
                foo: u32,
            }
            #[automatically_derived]
            #[allow(unused_qualifications)]
            impl ::core::fmt::Debug for InheritsAttributesTwo {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    match *self {
                        InheritsAttributesTwo {
                            foo: ref __self_0_0,
                        } => {
                            let mut debug_trait_builder = f.debug_struct("InheritsAttributesTwo");
                            let _ = debug_trait_builder.field("foo", &&(*__self_0_0));
                            debug_trait_builder.finish()
                        }
                    }
                }
            }
        }
        "###);
    }
}