spacetimedb_bindings_macro/
lib.rs1mod procedure;
12
13#[proc_macro_attribute]
14pub fn procedure(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
15 cvt_attr::<ItemFn>(args, item, quote!(), |args, original_function| {
16 let args = procedure::ProcedureArgs::parse(args)?;
17 procedure::procedure_impl(args, original_function)
18 })
19}
20mod reducer;
21
22#[proc_macro_attribute]
23pub fn reducer(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
24 cvt_attr::<ItemFn>(args, item, quote!(), |args, original_function| {
25 let args = reducer::ReducerArgs::parse(args)?;
26 reducer::reducer_impl(args, original_function)
27 })
28}
29mod sats;
30mod table;
31
32#[proc_macro_attribute]
33pub fn table(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
34 let derive_table_helper: syn::Attribute = derive_table_helper_attr();
36
37 ok_or_compile_error(|| {
38 let item = TokenStream::from(item);
39 let mut derive_input: syn::DeriveInput = syn::parse2(item.clone())?;
40
41 if !derive_input.attrs.contains(&derive_table_helper) {
59 derive_input.attrs.push(derive_table_helper);
60 }
61
62 let args = table::TableArgs::parse(args.into(), &derive_input.ident)?;
63 let generated = table::table_impl(args, &derive_input)?;
64 Ok(TokenStream::from_iter([quote!(#derive_input), generated]))
65 })
66}
67mod util;
68mod view;
69
70#[proc_macro_attribute]
71pub fn view(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
72 let item_ts: TokenStream = item.into();
73 let original_function = match syn::parse2::<ItemFn>(item_ts.clone()) {
74 Ok(f) => f,
75 Err(e) => return TokenStream::from_iter([item_ts, e.into_compile_error()]).into(),
76 };
77 let args = match view::ViewArgs::parse(args.into(), &original_function.sig.ident) {
78 Ok(a) => a,
79 Err(e) => return TokenStream::from_iter([item_ts, e.into_compile_error()]).into(),
80 };
81 match view::view_impl(args, &original_function) {
82 Ok(ts) => ts.into(),
83 Err(e) => TokenStream::from_iter([item_ts, e.into_compile_error()]).into(),
84 }
85}
86
87use proc_macro::TokenStream as StdTokenStream;
88use proc_macro2::TokenStream;
89use quote::quote;
90use std::time::Duration;
91use syn::{parse::ParseStream, Attribute};
92use syn::{ItemConst, ItemFn};
93use util::{cvt_attr, ok_or_compile_error};
94
95mod sym {
96 pub struct Symbol(&'static str);
99
100 macro_rules! symbol {
101 ($ident:ident) => {
102 symbol!($ident, $ident);
103 };
104 ($const:ident, $ident:ident) => {
105 #[allow(non_upper_case_globals)]
106 #[doc = concat!("Matches `", stringify!($ident), "`.")]
107 pub const $const: Symbol = Symbol(stringify!($ident));
108 };
109 }
110
111 symbol!(accessor);
112 symbol!(at);
113 symbol!(auto_inc);
114 symbol!(btree);
115 symbol!(client_connected);
116 symbol!(client_disconnected);
117 symbol!(column);
118 symbol!(columns);
119 symbol!(crate_, crate);
120 symbol!(direct);
121 symbol!(hash);
122 symbol!(index);
123 symbol!(init);
124 symbol!(name);
125 symbol!(primary_key);
126 symbol!(private);
127 symbol!(public);
128 symbol!(repr);
129 symbol!(sats);
130 symbol!(scheduled);
131 symbol!(unique);
132 symbol!(update);
133 symbol!(default);
134 symbol!(event);
135
136 symbol!(u8);
137 symbol!(i8);
138 symbol!(u16);
139 symbol!(i16);
140 symbol!(u32);
141 symbol!(i32);
142 symbol!(u64);
143 symbol!(i64);
144 symbol!(u128);
145 symbol!(i128);
146 symbol!(f32);
147 symbol!(f64);
148
149 impl PartialEq<Symbol> for syn::Ident {
150 fn eq(&self, sym: &Symbol) -> bool {
151 self == sym.0
152 }
153 }
154 impl PartialEq<Symbol> for &syn::Ident {
155 fn eq(&self, sym: &Symbol) -> bool {
156 *self == sym.0
157 }
158 }
159 impl PartialEq<Symbol> for syn::Path {
160 fn eq(&self, sym: &Symbol) -> bool {
161 self.is_ident(sym)
162 }
163 }
164 impl PartialEq<Symbol> for &syn::Path {
165 fn eq(&self, sym: &Symbol) -> bool {
166 self.is_ident(sym)
167 }
168 }
169 impl std::fmt::Display for Symbol {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.write_str(self.0)
172 }
173 }
174 impl std::borrow::Borrow<str> for Symbol {
175 fn borrow(&self) -> &str {
176 self.0
177 }
178 }
179}
180
181fn derive_table_helper_attr() -> Attribute {
188 let source = quote!(#[derive(spacetimedb::__TableHelper)]);
189
190 syn::parse::Parser::parse2(Attribute::parse_outer, source)
191 .unwrap()
192 .into_iter()
193 .next()
194 .unwrap()
195}
196
197#[doc(hidden)]
201#[proc_macro_derive(__TableHelper, attributes(sats, unique, auto_inc, primary_key, index, default))]
202pub fn table_helper(input: StdTokenStream) -> StdTokenStream {
203 schema_type(input)
204}
205
206#[proc_macro]
207pub fn duration(input: StdTokenStream) -> StdTokenStream {
208 let dur = syn::parse_macro_input!(input with parse_duration);
209 let (secs, nanos) = (dur.as_secs(), dur.subsec_nanos());
210 quote!({
211 const DUR: ::core::time::Duration = ::core::time::Duration::new(#secs, #nanos);
212 DUR
213 })
214 .into()
215}
216
217fn parse_duration(input: ParseStream) -> syn::Result<Duration> {
218 let lookahead = input.lookahead1();
219 let (s, span) = if lookahead.peek(syn::LitStr) {
220 let s = input.parse::<syn::LitStr>()?;
221 (s.value(), s.span())
222 } else if lookahead.peek(syn::LitInt) {
223 let i = input.parse::<syn::LitInt>()?;
224 (i.to_string(), i.span())
225 } else {
226 return Err(lookahead.error());
227 };
228 humantime::parse_duration(&s).map_err(|e| syn::Error::new(span, format_args!("can't parse as duration: {e}")))
229}
230
231fn sats_derive(
233 input: StdTokenStream,
234 assume_in_module: bool,
235 logic: impl FnOnce(&sats::SatsType) -> TokenStream,
236) -> StdTokenStream {
237 let input = syn::parse_macro_input!(input as syn::DeriveInput);
238 let crate_fallback = if assume_in_module {
239 quote!(spacetimedb::spacetimedb_lib)
240 } else {
241 quote!(spacetimedb_lib)
242 };
243 sats::sats_type_from_derive(&input, crate_fallback)
244 .map(|ty| logic(&ty))
245 .unwrap_or_else(syn::Error::into_compile_error)
246 .into()
247}
248
249#[proc_macro_derive(Deserialize, attributes(sats))]
250pub fn deserialize(input: StdTokenStream) -> StdTokenStream {
251 sats_derive(input, false, sats::derive_deserialize)
252}
253
254#[proc_macro_derive(Serialize, attributes(sats))]
255pub fn serialize(input: StdTokenStream) -> StdTokenStream {
256 sats_derive(input, false, sats::derive_serialize)
257}
258
259#[proc_macro_derive(SpacetimeType, attributes(sats))]
260pub fn schema_type(input: StdTokenStream) -> StdTokenStream {
261 sats_derive(input, true, |ty| {
262 let ident = ty.ident;
263 let name = &ty.name;
264
265 let krate = &ty.krate;
266 TokenStream::from_iter([
267 sats::derive_satstype(ty),
268 sats::derive_deserialize(ty),
269 sats::derive_serialize(ty),
270 quote!(#krate::__make_register_reftype!(#ident, #name);),
272 ])
273 })
274}
275
276#[proc_macro_attribute]
277pub fn client_visibility_filter(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
278 ok_or_compile_error(|| {
279 if !args.is_empty() {
280 return Err(syn::Error::new_spanned(
281 TokenStream::from(args),
282 "The `client_visibility_filter` attribute does not accept arguments",
283 ));
284 }
285
286 let item: ItemConst = syn::parse(item)?;
287 let rls_ident = item.ident.clone();
288 let register_rls_symbol = format!("__preinit__20_register_row_level_security_{rls_ident}");
289
290 Ok(quote! {
291 #item
292
293 const _: () = {
294 #[unsafe(export_name = #register_rls_symbol)]
295 extern "C" fn __register_client_visibility_filter() {
296 spacetimedb::rt::register_row_level_security(#rls_ident.sql_text())
297 }
298 };
299 })
300 })
301}
302
/// Names of the `const` items that the `#[settings]` attribute accepts.
///
/// Each entry also needs a matching arm in the `match` inside `settings` that
/// produces its registration call. The list is also shown in the
/// "unknown setting" error message.
const KNOWN_SETTINGS: &[&str] = &[
    "CASE_CONVERSION_POLICY",
];
305
306#[proc_macro_attribute]
307pub fn settings(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream {
308 ok_or_compile_error(|| {
309 if !args.is_empty() {
310 return Err(syn::Error::new_spanned(
311 TokenStream::from(args),
312 "The `settings` attribute does not accept arguments",
313 ));
314 }
315
316 let item: ItemConst = syn::parse(item)?;
317 let ident = &item.ident;
318 let ident_str = ident.to_string();
319
320 if !KNOWN_SETTINGS.contains(&ident_str.as_str()) {
321 return Err(syn::Error::new_spanned(
322 ident,
323 format!(
324 "unknown setting `{ident_str}`. Known settings: {}",
325 KNOWN_SETTINGS.join(", ")
326 ),
327 ));
328 }
329
330 let register_symbol = format!("__preinit__05_setting_{ident_str}");
333
334 let register_call = match ident_str.as_str() {
336 "CASE_CONVERSION_POLICY" => quote! {
337 spacetimedb::rt::register_case_conversion_policy(#ident)
338 },
339 _ => unreachable!("validated above"),
340 };
341
342 Ok(quote! {
343 #item
344
345 const _: () = {
346 #[unsafe(export_name = #register_symbol)]
347 extern "C" fn __register_setting() {
348 #register_call
349 }
350 };
351 })
352 })
353}