spacetimedsl_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::{ToTokens, quote};
3use spacetimedsl_derive_input::api::Table;
4mod output;
5
6/// Add `#[dsl]` to your structs with `#[table]`
7/// to interact in a more ergonomic way than SpacetimeDB allows you by default.
8#[proc_macro_attribute]
9pub fn dsl(args: TokenStream, item: TokenStream) -> TokenStream {
10    // put this on the struct so we don't get unknown attribute errors
11    let derive_table_helper = derive_table_helper_attr();
12
13    ok_or_compile_error(|| {
14        // Parse the input tokens into a syntax tree
15        let args = proc_macro2::TokenStream::from(args);
16        let mut derive_input: syn::DeriveInput = syn::parse(item)?;
17
18        // Add `derive(SpacetimeDSL)` only if it's not already in the attributes of the item.
19        // If multiple `#[dsl]` attributes are applied to the same `struct` item,
20        // this will ensure that we don't emit multiple conflicting implementations.
21        let first_dsl_attribute = if !derive_input.attrs.contains(&derive_table_helper) {
22            derive_input.attrs.push(derive_table_helper);
23            true
24        } else {
25            false
26        };
27
28        let input = Table::try_parse(args, &derive_input)?;
29
30        // Build the output, possibly using quasi-quotation
31        let output = output::output(&input, first_dsl_attribute)?;
32
33        // Check if this is the last #[dsl] attribute by counting remaining ones
34        let is_last_dsl_attribute = is_last_dsl_attribute(&derive_input);
35
36        // If this is the last #[dsl] attribute, make all struct fields private
37        // We do this AFTER parsing and generating methods so the setter logic works correctly
38        if is_last_dsl_attribute {
39            make_struct_fields_private(&mut derive_input);
40        }
41
42        Ok(proc_macro2::TokenStream::from_iter([
43            quote!(#derive_input),
44            output,
45        ]))
46    })
47}
48
49fn derive_table_helper_attr() -> syn::Attribute {
50    let source = quote!(#[derive(Clone, Debug, PartialEq, spacetimedsl::SpacetimeDSL)]); // TODO: Add PartialOrd if ScheduledAt has implemented it
51
52    syn::parse::Parser::parse2(syn::Attribute::parse_outer, source)
53        .unwrap()
54        .into_iter()
55        .next()
56        .unwrap()
57}
58
59/// Provides helper attributes for `#[dsl]` because proc_macro_attribute's currently don't support them.
60// TODO: Remove if https://github.com/rust-lang/rust/issues/65823 is implemented.
61#[proc_macro_derive(
62    SpacetimeDSL,
63    attributes(create_wrapper, use_wrapper, foreign_key, referenced_by)
64)]
65pub fn table_helper(_input: proc_macro::TokenStream) -> proc_macro::TokenStream {
66    proc_macro::TokenStream::default()
67}
68
69fn ok_or_compile_error<Res: Into<proc_macro::TokenStream>>(
70    f: impl FnOnce() -> syn::Result<Res>,
71) -> proc_macro::TokenStream {
72    match f() {
73        Ok(ok) => ok.into(),
74        Err(e) => e.into_compile_error().into(),
75    }
76}
77
78/// Check if this is the last #[dsl] attribute on the struct.
79/// Each attribute removes itself before the macro function runs, so the last one
80/// will see 0 remaining DSL attributes in the attributes list.
81fn is_last_dsl_attribute(derive_input: &syn::DeriveInput) -> bool {
82    // Find all remaining dsl attributes similar to how integration.rs finds table attributes
83    let mut dsl_attr_count = 0;
84
85    for attr in &derive_input.attrs {
86        // Check for #[dsl(...)] attributes with require_list()
87        if let Ok(list) = attr.meta.require_list() {
88            let path_string = list.path.to_token_stream().to_string();
89            if path_string == "dsl" || path_string == "spacetimedsl :: dsl" {
90                dsl_attr_count += 1;
91            }
92        }
93    }
94
95    // If there are 0 dsl attributes left, this is the last one being processed
96    dsl_attr_count == 0
97}
98
99/// Make all struct fields private by setting their visibility to Inherited
100fn make_struct_fields_private(derive_input: &mut syn::DeriveInput) {
101    if let syn::Data::Struct(data_struct) = &mut derive_input.data {
102        if let syn::Fields::Named(fields) = &mut data_struct.fields {
103            for field in &mut fields.named {
104                field.vis = syn::Visibility::Inherited;
105            }
106        }
107    }
108}