Skip to main content

spacetimedsl_derive/
lib.rs

1use ident_case::RenameRule;
2use proc_macro::TokenStream;
3use quote::{ToTokens, format_ident, quote};
4use spacetimedsl_derive_input::api::Table;
5mod output;
6
7/// Add `#[dsl]` to your structs with `#[table]`
8/// to interact in a more ergonomic way than SpacetimeDB allows you by default.
9#[proc_macro_attribute]
10pub fn dsl(args: TokenStream, item: TokenStream) -> TokenStream {
11    // put this on the struct so we don't get unknown attribute errors
12    let derive_table_helper = derive_table_helper_attr();
13
14    ok_or_compile_error(|| {
15        // Parse the input tokens into a syntax tree
16        let args = proc_macro2::TokenStream::from(args);
17        let mut derive_input: syn::DeriveInput = syn::parse(item)?;
18
19        // Check if this is a singleton table by scanning args for the `singleton` keyword
20        let is_singleton = args.clone().into_iter().any(|token| {
21            if let proc_macro2::TokenTree::Ident(ident) = token {
22                ident == "singleton"
23            } else {
24                false
25            }
26        });
27
28        // For singletons, inject `#[primary_key] id: u8` into the struct
29        if is_singleton {
30            inject_singleton_primary_key(&mut derive_input)?;
31        }
32
33        // Add `derive(SpacetimeDSL)` only if it's not already in the attributes of the item.
34        // If multiple `#[dsl]` attributes are applied to the same `struct` item,
35        // this will ensure that we don't emit multiple conflicting implementations.
36        let first_dsl_attribute = if !derive_input.attrs.contains(&derive_table_helper) {
37            derive_input.attrs.push(derive_table_helper);
38            true
39        } else {
40            false
41        };
42
43        let input = Table::try_parse(args, &derive_input)?;
44
45        // Build the output, possibly using quasi-quotation
46        let output = output::output(&input, first_dsl_attribute)?;
47
48        // Check if this is the last #[dsl] attribute by counting remaining ones
49        let _is_last_dsl_attribute = is_last_dsl_attribute(&derive_input);
50
51        // If this is the last #[dsl] attribute, make all struct fields private
52        // We do this AFTER parsing and generating methods so the setter logic works correctly
53        // TODO: Temporarily disabled to allow public primary key columns
54        // if is_last_dsl_attribute {
55        //     make_struct_fields_private(&mut derive_input);
56        // }
57
58        Ok(proc_macro2::TokenStream::from_iter([
59            quote!(#derive_input),
60            output,
61        ]))
62    })
63}
64
65fn derive_table_helper_attr() -> syn::Attribute {
66    let source = quote!(#[derive(Clone, Debug, PartialEq, spacetimedsl::SpacetimeDSL)]); // TODO: Add PartialOrd if ScheduledAt has implemented it
67
68    syn::parse::Parser::parse2(syn::Attribute::parse_outer, source)
69        .unwrap()
70        .into_iter()
71        .next()
72        .unwrap()
73}
74
75/// Provides helper attributes for `#[dsl]` because proc_macro_attribute's currently don't support them.
76// TODO: Remove if https://github.com/rust-lang/rust/issues/65823 is implemented.
77#[proc_macro_derive(
78    SpacetimeDSL,
79    attributes(create_wrapper, use_wrapper, foreign_key, referenced_by)
80)]
81pub fn table_helper(_input: proc_macro::TokenStream) -> proc_macro::TokenStream {
82    proc_macro::TokenStream::default()
83}
84
85fn ok_or_compile_error<Res: Into<proc_macro::TokenStream>>(
86    f: impl FnOnce() -> syn::Result<Res>,
87) -> proc_macro::TokenStream {
88    match f() {
89        Ok(ok) => ok.into(),
90        Err(e) => e.into_compile_error().into(),
91    }
92}
93
94/// Check if this is the last #[dsl] attribute on the struct.
95/// Each attribute removes itself before the macro function runs, so the last one
96/// will see 0 remaining DSL attributes in the attributes list.
97fn is_last_dsl_attribute(derive_input: &syn::DeriveInput) -> bool {
98    // Find all remaining dsl attributes similar to how integration.rs finds table attributes
99    let mut dsl_attr_count = 0;
100
101    for attr in &derive_input.attrs {
102        // Check for #[dsl(...)] attributes with require_list()
103        if let Ok(list) = attr.meta.require_list() {
104            let path_string = list.path.to_token_stream().to_string();
105            if path_string == "dsl" || path_string == "spacetimedsl :: dsl" {
106                dsl_attr_count += 1;
107            }
108        }
109    }
110
111    // If there are 0 dsl attributes left, this is the last one being processed
112    dsl_attr_count == 0
113}
114
115// TODO: Temporarily disabled to allow public primary key columns
116// /// Make all struct fields private by setting their visibility to Inherited,
117// /// except for fields with #[primary_key] which preserve their original visibility
118// fn make_struct_fields_private(derive_input: &mut syn::DeriveInput) {
119//     if let syn::Data::Struct(data_struct) = &mut derive_input.data
120//         && let syn::Fields::Named(fields) = &mut data_struct.fields
121//     {
122//         for field in &mut fields.named {
123//             // Check if this field has the #[primary_key] attribute
124//             let is_primary_key = field.attrs.iter().any(|attr| {
125//                 attr.path().is_ident("primary_key")
126//             });
127//
128//             // Only make non-primary-key fields private
129//             if !is_primary_key {
130//                 field.vis = syn::Visibility::Inherited;
131//             }
132//         }
133//     }
134// }
135
136/// For singleton tables, inject `#[primary_key] id: u8` as the first field.
137/// Errors if the user already has a field named `id`.
138fn inject_singleton_primary_key(derive_input: &mut syn::DeriveInput) -> syn::Result<()> {
139    if let syn::Data::Struct(data_struct) = &mut derive_input.data
140        && let syn::Fields::Named(fields) = &mut data_struct.fields
141    {
142        // Check if user manually defined an `id` field
143        for field in fields.named.iter() {
144            if let Some(ident) = &field.ident
145                && ident == "id"
146            {
147                return Err(syn::Error::new_spanned(
148                    field,
149                    "Singleton tables automatically add `#[primary_key] id: u8`. Do not define an `id` field manually!",
150                ));
151            }
152        }
153
154        // Create the field: `#[primary_key] id: u8`
155        let pk_field = syn::Field {
156            attrs: vec![syn::parse_quote!(#[primary_key])],
157            vis: syn::Visibility::Inherited,
158            mutability: syn::FieldMutability::None,
159            ident: Some(syn::Ident::new("id", proc_macro2::Span::call_site())),
160            colon_token: Some(syn::token::Colon::default()),
161            ty: syn::parse_quote!(u8),
162        };
163
164        // Insert as the first field
165        fields.named.insert(0, pk_field);
166    } else {
167        return Err(syn::Error::new(
168            proc_macro2::Span::call_site(),
169            "Singleton tables must be structs with named fields!",
170        ));
171    }
172
173    Ok(())
174}
175
176//region Hooks
177
178/// Add `#[hook]` to your functions to add the trait implementation line required for SpacetimeDSL hooks to work.
179#[proc_macro_attribute]
180pub fn hook(_args: TokenStream, item: TokenStream) -> TokenStream {
181    ok_or_compile_error(|| {
182        let function_input: syn::ItemFn = syn::parse(item)?;
183
184        let trait_name = format_ident!(
185            "{}Hook",
186            RenameRule::PascalCase.apply_to_field(function_input.sig.ident.to_string())
187        );
188
189        Ok(quote! {
190            impl<T: spacetimedsl::WriteContext> #trait_name<T> for spacetimedsl::DSLMethodHooks {
191                #function_input
192            }
193        })
194    })
195}
196
197//endregion Hooks