spacetimedb_bindings_macro/
lib.rs

1//! Defines procedural macros like `#[spacetimedb::table]`,
2//! simplifying writing SpacetimeDB modules in Rust.
3
4// DO NOT WRITE (public) DOCS IN THIS MODULE.
5// Docs should be written in the `spacetimedb` crate (i.e. `bindings/`) at reexport sites
6// using `#[doc(inline)]`.
7// We do this so that links to library traits, structs, etc can resolve correctly.
8//
9// (private documentation for the macro authors is totally fine here and you SHOULD write that!)
10
11mod procedure;
12mod reducer;
13mod sats;
14mod table;
15mod util;
16mod view;
17
18use proc_macro::TokenStream as StdTokenStream;
19use proc_macro2::TokenStream;
20use quote::quote;
21use std::time::Duration;
22use syn::{parse::ParseStream, Attribute};
23use syn::{ItemConst, ItemFn};
24use util::{cvt_attr, ok_or_compile_error};
25
26mod sym {
27    /// A symbol known at compile-time against
28    /// which identifiers and paths may be matched.
29    pub struct Symbol(&'static str);
30
31    macro_rules! symbol {
32        ($ident:ident) => {
33            symbol!($ident, $ident);
34        };
35        ($const:ident, $ident:ident) => {
36            #[allow(non_upper_case_globals)]
37            #[doc = concat!("Matches `", stringify!($ident), "`.")]
38            pub const $const: Symbol = Symbol(stringify!($ident));
39        };
40    }
41
42    symbol!(at);
43    symbol!(auto_inc);
44    symbol!(btree);
45    symbol!(client_connected);
46    symbol!(client_disconnected);
47    symbol!(column);
48    symbol!(columns);
49    symbol!(crate_, crate);
50    symbol!(direct);
51    symbol!(hash);
52    symbol!(index);
53    symbol!(init);
54    symbol!(name);
55    symbol!(primary_key);
56    symbol!(private);
57    symbol!(public);
58    symbol!(repr);
59    symbol!(sats);
60    symbol!(scheduled);
61    symbol!(unique);
62    symbol!(update);
63    symbol!(default);
64
65    symbol!(u8);
66    symbol!(i8);
67    symbol!(u16);
68    symbol!(i16);
69    symbol!(u32);
70    symbol!(i32);
71    symbol!(u64);
72    symbol!(i64);
73    symbol!(u128);
74    symbol!(i128);
75    symbol!(f32);
76    symbol!(f64);
77
78    impl PartialEq<Symbol> for syn::Ident {
79        fn eq(&self, sym: &Symbol) -> bool {
80            self == sym.0
81        }
82    }
83    impl PartialEq<Symbol> for &syn::Ident {
84        fn eq(&self, sym: &Symbol) -> bool {
85            *self == sym.0
86        }
87    }
88    impl PartialEq<Symbol> for syn::Path {
89        fn eq(&self, sym: &Symbol) -> bool {
90            self.is_ident(sym)
91        }
92    }
93    impl PartialEq<Symbol> for &syn::Path {
94        fn eq(&self, sym: &Symbol) -> bool {
95            self.is_ident(sym)
96        }
97    }
98    impl std::fmt::Display for Symbol {
99        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100            f.write_str(self.0)
101        }
102    }
103    impl std::borrow::Borrow<str> for Symbol {
104        fn borrow(&self) -> &str {
105            self.0
106        }
107    }
108}
109
110#[proc_macro_attribute]
111pub fn procedure(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
112    cvt_attr::<ItemFn>(args, item, quote!(), |args, original_function| {
113        let args = procedure::ProcedureArgs::parse(args)?;
114        procedure::procedure_impl(args, original_function)
115    })
116}
117
118#[proc_macro_attribute]
119pub fn reducer(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
120    cvt_attr::<ItemFn>(args, item, quote!(), |args, original_function| {
121        let args = reducer::ReducerArgs::parse(args)?;
122        reducer::reducer_impl(args, original_function)
123    })
124}
125
126#[proc_macro_attribute]
127pub fn view(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
128    cvt_attr::<ItemFn>(args, item, quote!(), |args, original_function| {
129        let args = view::ViewArgs::parse(args, &original_function.sig.ident)?;
130        view::view_impl(args, original_function)
131    })
132}
133
134/// It turns out to be shockingly difficult to construct an [`Attribute`].
135/// That type is not [`Parse`], instead having two distinct methods
136/// for parsing "inner" vs "outer" attributes.
137///
138/// We need this [`Attribute`] in [`table`] so that we can "pushnew" it
139/// onto the end of a list of attributes. See comments within [`table`].
140fn derive_table_helper_attr() -> Attribute {
141    let source = quote!(#[derive(spacetimedb::__TableHelper)]);
142
143    syn::parse::Parser::parse2(Attribute::parse_outer, source)
144        .unwrap()
145        .into_iter()
146        .next()
147        .unwrap()
148}
149
150#[proc_macro_attribute]
151pub fn table(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
152    // put this on the struct so we don't get unknown attribute errors
153    let derive_table_helper: syn::Attribute = derive_table_helper_attr();
154
155    ok_or_compile_error(|| {
156        let item = TokenStream::from(item);
157        let mut derive_input: syn::DeriveInput = syn::parse2(item.clone())?;
158
159        // Add `derive(__TableHelper)` only if it's not already in the attributes of the `derive_input.`
160        // If multiple `#[table]` attributes are applied to the same `struct` item,
161        // this will ensure that we don't emit multiple conflicting implementations
162        // for traits like `SpacetimeType`, `Serialize` and `Deserialize`.
163        //
164        // We need to push at the end, rather than the beginning,
165        // because rustc expands attribute macros (including derives) top-to-bottom,
166        // and we need *all* `#[table]` attributes *before* the `derive(__TableHelper)`.
167        // This way, the first `table` will insert a `derive(__TableHelper)`,
168        // and all subsequent `#[table]`s on the same `struct` will see it,
169        // and not add another.
170        //
171        // Note, thank goodness, that `syn`'s `PartialEq` impls (provided with the `extra-traits` feature)
172        // skip any [`Span`]s contained in the items,
173        // thereby comparing for syntactic rather than structural equality. This shouldn't matter,
174        // since we expect that the `derive_table_helper` will always have the same [`Span`]s,
175        // but it's nice to know.
176        if !derive_input.attrs.contains(&derive_table_helper) {
177            derive_input.attrs.push(derive_table_helper);
178        }
179
180        let args = table::TableArgs::parse(args.into(), &derive_input.ident)?;
181        let generated = table::table_impl(args, &derive_input)?;
182        Ok(TokenStream::from_iter([quote!(#derive_input), generated]))
183    })
184}
185
186/// Special alias for `derive(SpacetimeType)`, aka [`schema_type`], for use by [`table`].
187///
188/// Provides helper attributes for `#[spacetimedb::table]`, so that we don't get unknown attribute errors.
189#[doc(hidden)]
190#[proc_macro_derive(__TableHelper, attributes(sats, unique, auto_inc, primary_key, index, default))]
191pub fn table_helper(input: StdTokenStream) -> StdTokenStream {
192    schema_type(input)
193}
194
195#[proc_macro]
196pub fn duration(input: StdTokenStream) -> StdTokenStream {
197    let dur = syn::parse_macro_input!(input with parse_duration);
198    let (secs, nanos) = (dur.as_secs(), dur.subsec_nanos());
199    quote!({
200        const DUR: ::core::time::Duration = ::core::time::Duration::new(#secs, #nanos);
201        DUR
202    })
203    .into()
204}
205
206fn parse_duration(input: ParseStream) -> syn::Result<Duration> {
207    let lookahead = input.lookahead1();
208    let (s, span) = if lookahead.peek(syn::LitStr) {
209        let s = input.parse::<syn::LitStr>()?;
210        (s.value(), s.span())
211    } else if lookahead.peek(syn::LitInt) {
212        let i = input.parse::<syn::LitInt>()?;
213        (i.to_string(), i.span())
214    } else {
215        return Err(lookahead.error());
216    };
217    humantime::parse_duration(&s).map_err(|e| syn::Error::new(span, format_args!("can't parse as duration: {e}")))
218}
219
220/// A helper for the common bits of the derive macros.
221fn sats_derive(
222    input: StdTokenStream,
223    assume_in_module: bool,
224    logic: impl FnOnce(&sats::SatsType) -> TokenStream,
225) -> StdTokenStream {
226    let input = syn::parse_macro_input!(input as syn::DeriveInput);
227    let crate_fallback = if assume_in_module {
228        quote!(spacetimedb::spacetimedb_lib)
229    } else {
230        quote!(spacetimedb_lib)
231    };
232    sats::sats_type_from_derive(&input, crate_fallback)
233        .map(|ty| logic(&ty))
234        .unwrap_or_else(syn::Error::into_compile_error)
235        .into()
236}
237
238#[proc_macro_derive(Deserialize, attributes(sats))]
239pub fn deserialize(input: StdTokenStream) -> StdTokenStream {
240    sats_derive(input, false, sats::derive_deserialize)
241}
242
243#[proc_macro_derive(Serialize, attributes(sats))]
244pub fn serialize(input: StdTokenStream) -> StdTokenStream {
245    sats_derive(input, false, sats::derive_serialize)
246}
247
248#[proc_macro_derive(SpacetimeType, attributes(sats))]
249pub fn schema_type(input: StdTokenStream) -> StdTokenStream {
250    sats_derive(input, true, |ty| {
251        let ident = ty.ident;
252        let name = &ty.name;
253
254        let krate = &ty.krate;
255        TokenStream::from_iter([
256            sats::derive_satstype(ty),
257            sats::derive_deserialize(ty),
258            sats::derive_serialize(ty),
259            // unfortunately, generic types don't work in modules at the moment.
260            quote!(#krate::__make_register_reftype!(#ident, #name);),
261        ])
262    })
263}
264
265#[proc_macro_attribute]
266pub fn client_visibility_filter(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
267    ok_or_compile_error(|| {
268        if !args.is_empty() {
269            return Err(syn::Error::new_spanned(
270                TokenStream::from(args),
271                "The `client_visibility_filter` attribute does not accept arguments",
272            ));
273        }
274
275        let item: ItemConst = syn::parse(item)?;
276        let rls_ident = item.ident.clone();
277        let register_rls_symbol = format!("__preinit__20_register_row_level_security_{rls_ident}");
278
279        Ok(quote! {
280            #item
281
282            const _: () = {
283                #[export_name = #register_rls_symbol]
284                extern "C" fn __register_client_visibility_filter() {
285                    spacetimedb::rt::register_row_level_security(#rls_ident.sql_text())
286                }
287            };
288        })
289    })
290}