Skip to main content

spacetimedb_bindings_macro/
lib.rs

1//! Defines procedural macros like `#[spacetimedb::table]`,
2//! simplifying writing SpacetimeDB modules in Rust.
3
4// DO NOT WRITE (public) DOCS IN THIS MODULE.
5// Docs should be written in the `spacetimedb` crate (i.e. `bindings/`) at reexport sites
6// using `#[doc(inline)]`.
7// We do this so that links to library traits, structs, etc can resolve correctly.
8//
9// (private documentation for the macro authors is totally fine here and you SHOULD write that!)
10
11mod procedure;
12
13#[proc_macro_attribute]
14pub fn procedure(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
15    cvt_attr::<ItemFn>(args, item, quote!(), |args, original_function| {
16        let args = procedure::ProcedureArgs::parse(args)?;
17        procedure::procedure_impl(args, original_function)
18    })
19}
20mod reducer;
21
22#[proc_macro_attribute]
23pub fn reducer(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
24    cvt_attr::<ItemFn>(args, item, quote!(), |args, original_function| {
25        let args = reducer::ReducerArgs::parse(args)?;
26        reducer::reducer_impl(args, original_function)
27    })
28}
29mod sats;
30mod table;
31
32#[proc_macro_attribute]
33pub fn table(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
34    // put this on the struct so we don't get unknown attribute errors
35    let derive_table_helper: syn::Attribute = derive_table_helper_attr();
36
37    ok_or_compile_error(|| {
38        let item = TokenStream::from(item);
39        let mut derive_input: syn::DeriveInput = syn::parse2(item.clone())?;
40
41        // Add `derive(__TableHelper)` only if it's not already in the attributes of the `derive_input.`
42        // If multiple `#[table]` attributes are applied to the same `struct` item,
43        // this will ensure that we don't emit multiple conflicting implementations
44        // for traits like `SpacetimeType`, `Serialize` and `Deserialize`.
45        //
46        // We need to push at the end, rather than the beginning,
47        // because rustc expands attribute macros (including derives) top-to-bottom,
48        // and we need *all* `#[table]` attributes *before* the `derive(__TableHelper)`.
49        // This way, the first `table` will insert a `derive(__TableHelper)`,
50        // and all subsequent `#[table]`s on the same `struct` will see it,
51        // and not add another.
52        //
53        // Note, thank goodness, that `syn`'s `PartialEq` impls (provided with the `extra-traits` feature)
54        // skip any [`Span`]s contained in the items,
55        // thereby comparing for syntactic rather than structural equality. This shouldn't matter,
56        // since we expect that the `derive_table_helper` will always have the same [`Span`]s,
57        // but it's nice to know.
58        if !derive_input.attrs.contains(&derive_table_helper) {
59            derive_input.attrs.push(derive_table_helper);
60        }
61
62        let args = table::TableArgs::parse(args.into(), &derive_input.ident)?;
63        let generated = table::table_impl(args, &derive_input)?;
64        Ok(TokenStream::from_iter([quote!(#derive_input), generated]))
65    })
66}
67mod util;
68mod view;
69
70#[proc_macro_attribute]
71pub fn view(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
72    let item_ts: TokenStream = item.into();
73    let original_function = match syn::parse2::<ItemFn>(item_ts.clone()) {
74        Ok(f) => f,
75        Err(e) => return TokenStream::from_iter([item_ts, e.into_compile_error()]).into(),
76    };
77    let args = match view::ViewArgs::parse(args.into(), &original_function.sig.ident) {
78        Ok(a) => a,
79        Err(e) => return TokenStream::from_iter([item_ts, e.into_compile_error()]).into(),
80    };
81    match view::view_impl(args, &original_function) {
82        Ok(ts) => ts.into(),
83        Err(e) => TokenStream::from_iter([item_ts, e.into_compile_error()]).into(),
84    }
85}
86
87use proc_macro::TokenStream as StdTokenStream;
88use proc_macro2::TokenStream;
89use quote::quote;
90use std::time::Duration;
91use syn::{parse::ParseStream, Attribute};
92use syn::{ItemConst, ItemFn};
93use util::{cvt_attr, ok_or_compile_error};
94
95mod sym {
96    /// A symbol known at compile-time against
97    /// which identifiers and paths may be matched.
98    pub struct Symbol(&'static str);
99
100    macro_rules! symbol {
101        ($ident:ident) => {
102            symbol!($ident, $ident);
103        };
104        ($const:ident, $ident:ident) => {
105            #[allow(non_upper_case_globals)]
106            #[doc = concat!("Matches `", stringify!($ident), "`.")]
107            pub const $const: Symbol = Symbol(stringify!($ident));
108        };
109    }
110
111    symbol!(accessor);
112    symbol!(at);
113    symbol!(auto_inc);
114    symbol!(btree);
115    symbol!(client_connected);
116    symbol!(client_disconnected);
117    symbol!(column);
118    symbol!(columns);
119    symbol!(crate_, crate);
120    symbol!(direct);
121    symbol!(hash);
122    symbol!(index);
123    symbol!(init);
124    symbol!(name);
125    symbol!(primary_key);
126    symbol!(private);
127    symbol!(public);
128    symbol!(repr);
129    symbol!(sats);
130    symbol!(scheduled);
131    symbol!(unique);
132    symbol!(update);
133    symbol!(default);
134    symbol!(event);
135
136    symbol!(u8);
137    symbol!(i8);
138    symbol!(u16);
139    symbol!(i16);
140    symbol!(u32);
141    symbol!(i32);
142    symbol!(u64);
143    symbol!(i64);
144    symbol!(u128);
145    symbol!(i128);
146    symbol!(f32);
147    symbol!(f64);
148
149    impl PartialEq<Symbol> for syn::Ident {
150        fn eq(&self, sym: &Symbol) -> bool {
151            self == sym.0
152        }
153    }
154    impl PartialEq<Symbol> for &syn::Ident {
155        fn eq(&self, sym: &Symbol) -> bool {
156            *self == sym.0
157        }
158    }
159    impl PartialEq<Symbol> for syn::Path {
160        fn eq(&self, sym: &Symbol) -> bool {
161            self.is_ident(sym)
162        }
163    }
164    impl PartialEq<Symbol> for &syn::Path {
165        fn eq(&self, sym: &Symbol) -> bool {
166            self.is_ident(sym)
167        }
168    }
169    impl std::fmt::Display for Symbol {
170        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171            f.write_str(self.0)
172        }
173    }
174    impl std::borrow::Borrow<str> for Symbol {
175        fn borrow(&self) -> &str {
176            self.0
177        }
178    }
179}
180
181/// It turns out to be shockingly difficult to construct an [`Attribute`].
182/// That type is not [`Parse`], instead having two distinct methods
183/// for parsing "inner" vs "outer" attributes.
184///
185/// We need this [`Attribute`] in [`table`] so that we can "pushnew" it
186/// onto the end of a list of attributes. See comments within [`table`].
187fn derive_table_helper_attr() -> Attribute {
188    let source = quote!(#[derive(spacetimedb::__TableHelper)]);
189
190    syn::parse::Parser::parse2(Attribute::parse_outer, source)
191        .unwrap()
192        .into_iter()
193        .next()
194        .unwrap()
195}
196
197/// Special alias for `derive(SpacetimeType)`, aka [`schema_type`], for use by [`table`].
198///
199/// Provides helper attributes for `#[spacetimedb::table]`, so that we don't get unknown attribute errors.
200#[doc(hidden)]
201#[proc_macro_derive(__TableHelper, attributes(sats, unique, auto_inc, primary_key, index, default))]
202pub fn table_helper(input: StdTokenStream) -> StdTokenStream {
203    schema_type(input)
204}
205
206#[proc_macro]
207pub fn duration(input: StdTokenStream) -> StdTokenStream {
208    let dur = syn::parse_macro_input!(input with parse_duration);
209    let (secs, nanos) = (dur.as_secs(), dur.subsec_nanos());
210    quote!({
211        const DUR: ::core::time::Duration = ::core::time::Duration::new(#secs, #nanos);
212        DUR
213    })
214    .into()
215}
216
217fn parse_duration(input: ParseStream) -> syn::Result<Duration> {
218    let lookahead = input.lookahead1();
219    let (s, span) = if lookahead.peek(syn::LitStr) {
220        let s = input.parse::<syn::LitStr>()?;
221        (s.value(), s.span())
222    } else if lookahead.peek(syn::LitInt) {
223        let i = input.parse::<syn::LitInt>()?;
224        (i.to_string(), i.span())
225    } else {
226        return Err(lookahead.error());
227    };
228    humantime::parse_duration(&s).map_err(|e| syn::Error::new(span, format_args!("can't parse as duration: {e}")))
229}
230
231/// A helper for the common bits of the derive macros.
232fn sats_derive(
233    input: StdTokenStream,
234    assume_in_module: bool,
235    logic: impl FnOnce(&sats::SatsType) -> TokenStream,
236) -> StdTokenStream {
237    let input = syn::parse_macro_input!(input as syn::DeriveInput);
238    let crate_fallback = if assume_in_module {
239        quote!(spacetimedb::spacetimedb_lib)
240    } else {
241        quote!(spacetimedb_lib)
242    };
243    sats::sats_type_from_derive(&input, crate_fallback)
244        .map(|ty| logic(&ty))
245        .unwrap_or_else(syn::Error::into_compile_error)
246        .into()
247}
248
249#[proc_macro_derive(Deserialize, attributes(sats))]
250pub fn deserialize(input: StdTokenStream) -> StdTokenStream {
251    sats_derive(input, false, sats::derive_deserialize)
252}
253
254#[proc_macro_derive(Serialize, attributes(sats))]
255pub fn serialize(input: StdTokenStream) -> StdTokenStream {
256    sats_derive(input, false, sats::derive_serialize)
257}
258
259#[proc_macro_derive(SpacetimeType, attributes(sats))]
260pub fn schema_type(input: StdTokenStream) -> StdTokenStream {
261    sats_derive(input, true, |ty| {
262        let ident = ty.ident;
263        let name = &ty.name;
264
265        let krate = &ty.krate;
266        TokenStream::from_iter([
267            sats::derive_satstype(ty),
268            sats::derive_deserialize(ty),
269            sats::derive_serialize(ty),
270            // unfortunately, generic types don't work in modules at the moment.
271            quote!(#krate::__make_register_reftype!(#ident, #name);),
272        ])
273    })
274}
275
276#[proc_macro_attribute]
277pub fn client_visibility_filter(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
278    ok_or_compile_error(|| {
279        if !args.is_empty() {
280            return Err(syn::Error::new_spanned(
281                TokenStream::from(args),
282                "The `client_visibility_filter` attribute does not accept arguments",
283            ));
284        }
285
286        let item: ItemConst = syn::parse(item)?;
287        let rls_ident = item.ident.clone();
288        let register_rls_symbol = format!("__preinit__20_register_row_level_security_{rls_ident}");
289
290        Ok(quote! {
291            #item
292
293            const _: () = {
294                #[unsafe(export_name = #register_rls_symbol)]
295                extern "C" fn __register_client_visibility_filter() {
296                    spacetimedb::rt::register_row_level_security(#rls_ident.sql_text())
297                }
298            };
299        })
300    })
301}
302
/// Names accepted by [`settings`]: a `#[spacetimedb::settings]` const must be named
/// exactly one of these, and the list is quoted in the "unknown setting" error.
///
/// This holds names only. The registration code for each name is selected by the
/// `match` inside [`settings`], which must be kept in sync with this list.
const KNOWN_SETTINGS: &[&str] = &["CASE_CONVERSION_POLICY"];
305
306#[proc_macro_attribute]
307pub fn settings(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
308    ok_or_compile_error(|| {
309        if !args.is_empty() {
310            return Err(syn::Error::new_spanned(
311                TokenStream::from(args),
312                "The `settings` attribute does not accept arguments",
313            ));
314        }
315
316        let item: ItemConst = syn::parse(item)?;
317        let ident = &item.ident;
318        let ident_str = ident.to_string();
319
320        if !KNOWN_SETTINGS.contains(&ident_str.as_str()) {
321            return Err(syn::Error::new_spanned(
322                ident,
323                format!(
324                    "unknown setting `{ident_str}`. Known settings: {}",
325                    KNOWN_SETTINGS.join(", ")
326                ),
327            ));
328        }
329
330        // Use a fixed export name so that two `#[spacetimedb::settings]` consts
331        // for the same setting produce a linker error (duplicate symbol).
332        let register_symbol = format!("__preinit__05_setting_{ident_str}");
333
334        // Generate the registration call based on the setting name.
335        let register_call = match ident_str.as_str() {
336            "CASE_CONVERSION_POLICY" => quote! {
337                spacetimedb::rt::register_case_conversion_policy(#ident)
338            },
339            _ => unreachable!("validated above"),
340        };
341
342        Ok(quote! {
343            #item
344
345            const _: () = {
346                #[unsafe(export_name = #register_symbol)]
347                extern "C" fn __register_setting() {
348                    #register_call
349                }
350            };
351        })
352    })
353}