spacetimedsl_derive/
lib.rs

1use ident_case::RenameRule;
2use proc_macro::TokenStream;
3use quote::{ToTokens, format_ident, quote};
4use spacetimedsl_derive_input::api::Table;
5mod output;
6
7/// Add `#[dsl]` to your structs with `#[table]`
8/// to interact in a more ergonomic way than SpacetimeDB allows you by default.
9#[proc_macro_attribute]
10pub fn dsl(args: TokenStream, item: TokenStream) -> TokenStream {
11    // put this on the struct so we don't get unknown attribute errors
12    let derive_table_helper = derive_table_helper_attr();
13
14    ok_or_compile_error(|| {
15        // Parse the input tokens into a syntax tree
16        let args = proc_macro2::TokenStream::from(args);
17        let mut derive_input: syn::DeriveInput = syn::parse(item)?;
18
19        // Add `derive(SpacetimeDSL)` only if it's not already in the attributes of the item.
20        // If multiple `#[dsl]` attributes are applied to the same `struct` item,
21        // this will ensure that we don't emit multiple conflicting implementations.
22        let first_dsl_attribute = if !derive_input.attrs.contains(&derive_table_helper) {
23            derive_input.attrs.push(derive_table_helper);
24            true
25        } else {
26            false
27        };
28
29        let input = Table::try_parse(args, &derive_input)?;
30
31        // Build the output, possibly using quasi-quotation
32        let output = output::output(&input, first_dsl_attribute)?;
33
34        // Check if this is the last #[dsl] attribute by counting remaining ones
35        let is_last_dsl_attribute = is_last_dsl_attribute(&derive_input);
36
37        // If this is the last #[dsl] attribute, make all struct fields private
38        // We do this AFTER parsing and generating methods so the setter logic works correctly
39        if is_last_dsl_attribute {
40            make_struct_fields_private(&mut derive_input);
41        }
42
43        Ok(proc_macro2::TokenStream::from_iter([
44            quote!(#derive_input),
45            output,
46        ]))
47    })
48}
49
50fn derive_table_helper_attr() -> syn::Attribute {
51    let source = quote!(#[derive(Clone, Debug, PartialEq, spacetimedsl::SpacetimeDSL)]); // TODO: Add PartialOrd if ScheduledAt has implemented it
52
53    syn::parse::Parser::parse2(syn::Attribute::parse_outer, source)
54        .unwrap()
55        .into_iter()
56        .next()
57        .unwrap()
58}
59
60/// Provides helper attributes for `#[dsl]` because proc_macro_attribute's currently don't support them.
61// TODO: Remove if https://github.com/rust-lang/rust/issues/65823 is implemented.
62#[proc_macro_derive(
63    SpacetimeDSL,
64    attributes(create_wrapper, use_wrapper, foreign_key, referenced_by)
65)]
66pub fn table_helper(_input: proc_macro::TokenStream) -> proc_macro::TokenStream {
67    proc_macro::TokenStream::default()
68}
69
70fn ok_or_compile_error<Res: Into<proc_macro::TokenStream>>(
71    f: impl FnOnce() -> syn::Result<Res>,
72) -> proc_macro::TokenStream {
73    match f() {
74        Ok(ok) => ok.into(),
75        Err(e) => e.into_compile_error().into(),
76    }
77}
78
79/// Check if this is the last #[dsl] attribute on the struct.
80/// Each attribute removes itself before the macro function runs, so the last one
81/// will see 0 remaining DSL attributes in the attributes list.
82fn is_last_dsl_attribute(derive_input: &syn::DeriveInput) -> bool {
83    // Find all remaining dsl attributes similar to how integration.rs finds table attributes
84    let mut dsl_attr_count = 0;
85
86    for attr in &derive_input.attrs {
87        // Check for #[dsl(...)] attributes with require_list()
88        if let Ok(list) = attr.meta.require_list() {
89            let path_string = list.path.to_token_stream().to_string();
90            if path_string == "dsl" || path_string == "spacetimedsl :: dsl" {
91                dsl_attr_count += 1;
92            }
93        }
94    }
95
96    // If there are 0 dsl attributes left, this is the last one being processed
97    dsl_attr_count == 0
98}
99
100/// Make all struct fields private by setting their visibility to Inherited
101fn make_struct_fields_private(derive_input: &mut syn::DeriveInput) {
102    if let syn::Data::Struct(data_struct) = &mut derive_input.data
103        && let syn::Fields::Named(fields) = &mut data_struct.fields
104    {
105        for field in &mut fields.named {
106            field.vis = syn::Visibility::Inherited;
107        }
108    }
109}
110
111//region Hooks
112
113/// Add `#[hook]` to your functions to add the trait implementation line required for SpacetimeDSL hooks to work.
114#[proc_macro_attribute]
115pub fn hook(_args: TokenStream, item: TokenStream) -> TokenStream {
116    ok_or_compile_error(|| {
117        let function_input: syn::ItemFn = syn::parse(item)?;
118
119        let trait_name = format_ident!(
120            "{}Hook",
121            RenameRule::PascalCase.apply_to_field(function_input.sig.ident.to_string())
122        );
123
124        Ok(quote! {
125            impl #trait_name for spacetimedsl::DSLMethodHooks {
126                #function_input
127            }
128        })
129    })
130}
131
132//endregion Hooks