spacetimedb_bindgen/
lib.rs

1#![crate_type = "proc-macro"]
2
3// mod csharp;
4mod module;
5
6extern crate core;
7extern crate proc_macro;
8
9// use crate::csharp::{autogen_csharp_reducer, autogen_csharp_tuple};
10use crate::module::{
11    args_to_tuple_schema, autogen_module_struct_to_schema, autogen_module_struct_to_tuple,
12    autogen_module_tuple_to_struct,
13};
14use proc_macro::TokenStream;
15use proc_macro2::Ident;
16use quote::{format_ident, quote, ToTokens};
17use std::time::Duration;
18use syn::Fields::{Named, Unit, Unnamed};
19use syn::{parse_macro_input, AttributeArgs, FnArg, ItemFn, ItemStruct};
20
// When we add support for more than one language, uncomment this. For now it's just cumbersome.
// enum Lang {
//     CS
// }
25
26#[proc_macro_attribute]
27pub fn spacetimedb(macro_args: TokenStream, item: TokenStream) -> TokenStream {
28    let attribute_args = parse_macro_input!(macro_args as AttributeArgs);
29    let attribute_str = attribute_args[0].to_token_stream().to_string();
30    let attribute_str = attribute_str.as_str();
31
32    match attribute_str {
33        "table" => spacetimedb_table(attribute_args, item),
34        "reducer" => spacetimedb_reducer(attribute_args, item),
35        "connect" => spacetimedb_connect_disconnect(attribute_args, item, true),
36        "disconnect" => spacetimedb_connect_disconnect(attribute_args, item, false),
37        "migrate" => spacetimedb_migrate(attribute_args, item),
38        "tuple" => spacetimedb_tuple(attribute_args, item),
39        "index(btree)" => spacetimedb_index(attribute_args, item),
40        "index(hash)" => spacetimedb_index(attribute_args, item),
41        _ => proc_macro::TokenStream::from(quote! {
42            compile_error!("Please pass a valid attribute to the spacetimedb macro: reducer, table, connect, disconnect, migrate, tuple, index, ...");
43        }),
44    }
45}
46
47fn spacetimedb_reducer(args: AttributeArgs, item: TokenStream) -> TokenStream {
48    if *(&args.len()) > 1 {
49        let arg = args[1].to_token_stream();
50        let arg_components = arg.into_iter().collect::<Vec<_>>();
51        let arg_name = &arg_components[0];
52        let repeat = match arg_name {
53            proc_macro2::TokenTree::Group(_) => false,
54            proc_macro2::TokenTree::Ident(ident) => {
55                if ident.to_string() != "repeat" {
56                    false
57                } else {
58                    true
59                }
60            }
61            proc_macro2::TokenTree::Punct(_) => false,
62            proc_macro2::TokenTree::Literal(_) => false,
63        };
64        if !repeat {
65            let str = format!("Unexpected macro argument name: {}", arg_name.to_string());
66            return proc_macro::TokenStream::from(quote! {
67                compile_error!(#str);
68            });
69        }
70        let arg_value = &arg_components[2];
71        let res = parse_duration::parse(&arg_value.to_string());
72        if let Err(_) = res {
73            let str = format!("Can't parse repeat time: {}", arg_value.to_string());
74            return proc_macro::TokenStream::from(quote! {
75                compile_error!(#str);
76            });
77        }
78        let repeat_duration = res.unwrap();
79
80        return spacetimedb_repeating_reducer(args, item, repeat_duration);
81    }
82
83    let original_function = parse_macro_input!(item as ItemFn);
84    let func_name = &original_function.sig.ident;
85    let reducer_func_name = format_ident!("__reducer__{}", &func_name);
86    let descriptor_func_name = format_ident!("__describe_reducer__{}", &func_name);
87
88    let mut parse_json_to_args = Vec::new();
89    let mut function_call_arg_names = Vec::new();
90    let mut arg_num: usize = 0;
91    let mut json_arg_num: usize = 0;
92    let function_arguments = &original_function.sig.inputs;
93
94    let function_call_arg_types = args_to_tuple_schema(function_arguments.iter().skip(2));
95
96    for function_argument in function_arguments {
97        match function_argument {
98            FnArg::Receiver(_) => {
99                return proc_macro::TokenStream::from(quote! {
100                    compile_error!("Receiver types in reducer parameters not supported!");
101                });
102            }
103            FnArg::Typed(typed) => {
104                let arg_type = &typed.ty;
105                let arg_token = arg_type.to_token_stream();
106                let arg_type_str = arg_token.to_string();
107                let var_name = format_ident!("arg_{}", arg_num);
108
109                // First argument must be Hash (sender)
110                if arg_num == 0 {
111                    if arg_type_str != "spacetimedb::spacetimedb_lib::hash::Hash" && arg_type_str != "Hash" {
112                        let error_str = format!(
113                            "Parameter 1 of reducer {} must be of type \'Hash\'.",
114                            func_name.to_string()
115                        );
116                        return proc_macro::TokenStream::from(quote! {
117                            compile_error!(#error_str);
118                        });
119                    }
120                    arg_num += 1;
121                    continue;
122                }
123
124                // Second argument must be a u64 (timestamp)
125                if arg_num == 1 {
126                    if arg_type_str != "u64" {
127                        let error_str = format!(
128                            "Parameter 2 of reducer {} must be of type \'u64\'.",
129                            func_name.to_string()
130                        );
131                        return proc_macro::TokenStream::from(quote! {
132                            compile_error!(#error_str);
133                        });
134                    }
135                    arg_num += 1;
136                    continue;
137                }
138
139                // Stash the function
140                parse_json_to_args.push(quote! {
141                    let #var_name : #arg_token = serde_json::from_value(args[#json_arg_num].clone()).unwrap();
142                });
143
144                function_call_arg_names.push(var_name);
145                json_arg_num += 1;
146            }
147        }
148
149        arg_num = arg_num + 1;
150    }
151
152    let unwrap_args = match arg_num > 2 {
153        true => {
154            quote! {
155                let arg_json: serde_json::Value = serde_json::from_slice(
156                    arguments.argument_bytes.as_slice()).
157                expect(format!("Unable to parse arguments as JSON: {} bytes/arg_size: {}: {:?}",
158                    arguments.argument_bytes.len(), arg_size, arguments.argument_bytes).as_str());
159                let args = arg_json.as_array().expect("Unable to extract reducer arguments list");
160            }
161        }
162        false => {
163            quote! {}
164        }
165    };
166
167    let generated_function = quote! {
168        #[no_mangle]
169        #[allow(non_snake_case)]
170        pub extern "C" fn #reducer_func_name(arg_ptr: usize, arg_size: usize) {
171            let arguments = spacetimedb::spacetimedb_lib::args::ReducerArguments::decode_mem(
172                unsafe { arg_ptr as *mut u8 }, arg_size).expect("Unable to decode module arguments");
173
174            // Unwrap extra arguments, conditional on whether or not there are extra args.
175            #unwrap_args
176
177            // Deserialize the json argument list
178            #(#parse_json_to_args);*
179
180            // Invoke the function with the deserialized args
181            #func_name(arguments.identity, arguments.timestamp, #(#function_call_arg_names),*);
182        }
183    };
184
185    let reducer_name = func_name.to_string();
186    let generated_describe_function = quote! {
187        #[no_mangle]
188        #[allow(non_snake_case)]
189        // u64 is offset << 32 | length
190        pub extern "C" fn #descriptor_func_name() -> u64 {
191            let tupledef = spacetimedb::spacetimedb_lib::ReducerDef {
192                name: Some(#reducer_name.into()),
193                args: vec![
194                    #(#function_call_arg_types),*
195                ],
196            };
197            let mut bytes = vec![];
198            tupledef.encode(&mut bytes);
199            let offset = bytes.as_ptr() as u64;
200            let length = bytes.len() as u64;
201            std::mem::forget(bytes);
202            return offset << 32 | length;
203        }
204    };
205
206    // autogen_csharp_reducer(original_function.clone());
207
208    proc_macro::TokenStream::from(quote! {
209        #generated_function
210        #generated_describe_function
211        #original_function
212    })
213}
214
215fn spacetimedb_repeating_reducer(_args: AttributeArgs, item: TokenStream, repeat_duration: Duration) -> TokenStream {
216    let original_function = parse_macro_input!(item as ItemFn);
217    let func_name = &original_function.sig.ident;
218    let reducer_func_name = format_ident!("__repeating_reducer__{}", &func_name);
219
220    let mut arg_num: usize = 0;
221    let function_arguments = &original_function.sig.inputs;
222    if function_arguments.len() != 2 {
223        return proc_macro::TokenStream::from(quote! {
224            compile_error!("Expected 2 arguments (timestamp: u64, delta_time: u64) for repeating reducer.");
225        });
226    }
227    for function_argument in function_arguments {
228        match function_argument {
229            FnArg::Receiver(_) => {
230                return proc_macro::TokenStream::from(quote! {
231                    compile_error!("Receiver types in reducer parameters not supported!");
232                });
233            }
234            FnArg::Typed(typed) => {
235                let arg_type = &typed.ty;
236                let arg_token = arg_type.to_token_stream();
237                let arg_type_str = arg_token.to_string();
238
239                // First argument must be a u64 (timestamp)
240                if arg_num == 0 {
241                    if arg_type_str != "u64" {
242                        let error_str = format!(
243                            "Parameter 1 of reducer {} must be of type \'u64\'.",
244                            func_name.to_string()
245                        );
246                        return proc_macro::TokenStream::from(quote! {
247                            compile_error!(#error_str);
248                        });
249                    }
250                    arg_num += 1;
251                    continue;
252                }
253
254                // Second argument must be an u64 (delta_time)
255                if arg_num == 1 {
256                    if arg_type_str != "u64" {
257                        let error_str = format!(
258                            "Parameter 2 of reducer {} must be of type \'u64\'.",
259                            func_name.to_string()
260                        );
261                        return proc_macro::TokenStream::from(quote! {
262                            compile_error!(#error_str);
263                        });
264                    }
265                    arg_num += 1;
266                    continue;
267                }
268            }
269        }
270        arg_num = arg_num + 1;
271    }
272
273    let duration_as_millis = repeat_duration.as_millis() as u64;
274    let generated_function = quote! {
275        #[no_mangle]
276        #[allow(non_snake_case)]
277        pub extern "C" fn #reducer_func_name(arg_ptr: usize, arg_size: usize) -> u64 {
278            // Deserialize the arguments
279            let arguments = spacetimedb::spacetimedb_lib::args::RepeatingReducerArguments::decode_mem(
280                unsafe { arg_ptr as *mut u8 }, arg_size).expect("Unable to decode module arguments");
281
282            // Invoke the function with the deserialized args
283            #func_name(arguments.timestamp, arguments.delta_time);
284
285            return #duration_as_millis;
286        }
287    };
288
289    proc_macro::TokenStream::from(quote! {
290        #generated_function
291        #original_function
292    })
293}
294
295// TODO: We actually need to add a constraint that requires this column to be unique!
296struct Column {
297    ty: syn::Type,
298    ident: Ident,
299    index: u8,
300    convert_to_typevalue: proc_macro2::TokenStream,
301}
302
303fn spacetimedb_table(args: AttributeArgs, item: TokenStream) -> TokenStream {
304    if *(&args.len()) > 1 {
305        let str = format!("Unexpected macro argument: {}", args[1].to_token_stream().to_string());
306        return proc_macro::TokenStream::from(quote! {
307            compile_error!(#str);
308        });
309    }
310
311    let original_struct = parse_macro_input!(item as ItemStruct);
312    let original_struct_ident = &original_struct.clone().ident;
313
314    match original_struct.clone().fields {
315        Named(_) => {
316            // let table_id_field: Field = Field {
317            //     attrs: Vec::new(),
318            //     vis: Visibility::Public(VisPublic { pub_token: Default::default() }),
319            //     ident: Some(format_ident!("{}", "table_id")),
320            //     colon_token: Some(Colon::default()),
321            //     ty: syn::Type::Verbatim(format_ident!("{}", "u32").to_token_stream()),
322            // };
323            //
324            // fields.named.push(table_id_field);
325        }
326        Unnamed(_) => {
327            let str = format!("spacetimedb tables must have named fields.");
328            return proc_macro::TokenStream::from(quote! {
329                compile_error!(#str);
330            });
331        }
332        Unit => {
333            let str = format!("spacetimedb tables must have named fields (unit struct forbidden).");
334            return proc_macro::TokenStream::from(quote! {
335                compile_error!(#str);
336            });
337        }
338    }
339
340    let mut unique_columns = Vec::<Column>::new();
341    let mut filterable_columns = Vec::<Column>::new();
342
343    let table_id_static_var_name = format_ident!("__table_id__{}", original_struct.ident);
344    let get_table_id_func = quote! {
345        pub fn table_id() -> u32 {
346            *#table_id_static_var_name.get_or_init(|| {
347                spacetimedb::get_table_id(<Self as spacetimedb::TableType>::TABLE_NAME)
348            })
349        }
350    };
351
352    for (col_num, field) in original_struct.fields.iter().enumerate() {
353        let col_num: u8 = col_num.try_into().expect("too many columns");
354        let col_name = &field.ident.clone().unwrap();
355
356        // The TypeValue representation of this type
357        let convert_to_typevalue: proc_macro2::TokenStream;
358
359        match rust_to_spacetimedb_ident(field.ty.clone().to_token_stream().to_string().as_str()) {
360            Some(ident) => {
361                convert_to_typevalue = quote!(
362                    let value = spacetimedb::spacetimedb_lib::TypeValue::#ident(value);
363                );
364            }
365            None => match field.ty.clone().to_token_stream().to_string().as_str() {
366                "Hash" => {
367                    convert_to_typevalue = quote!(
368                        let value = spacetimedb::spacetimedb_lib::TypeValue::Hash(Box::new(value));
369                    );
370                }
371                "Vec < u8 >" => {
372                    // TODO: We are aliasing Vec<u8> to Bytes for now, we should deconstruct the vec here.
373                    convert_to_typevalue = quote!(
374                        let value = spacetimedb::spacetimedb_lib::TypeValue::Bytes(value);
375                    );
376                }
377                _custom_type => {
378                    convert_to_typevalue = quote!(
379                        let value = spacetimedb::spacetimedb_lib::TypeValue::Tuple(value);
380                    );
381                }
382            },
383        }
384        // // The simple name for the type, e.g. Hash
385        // let col_type: proc_macro2::TokenStream;
386        // // The fully qualified name for this type, e.g. spacetimedb::spacetimedb_lib::Hash
387        // let col_type_full: proc_macro2::TokenStream;
388        // // The TypeValue representation of this type
389        // let col_type_value: proc_macro2::TokenStream;
390        // let col_value_insert: proc_macro2::TokenStream;
391
392        // col_value_insert = format!("{}({})", col_type_value.clone(), format!("ins.{}", col_name))
393        //     .parse()
394        //     .unwrap();
395
396        let mut is_unique = false;
397        let mut is_filterable = false;
398        for attr in &field.attrs {
399            if attr.path.is_ident("unique") {
400                if is_filterable {
401                    panic!("can't be both") // TODO: better error
402                }
403                is_unique = true;
404            } else if attr.path.is_ident("filterable_by") {
405                if is_unique {
406                    panic!("can't be both") // TODO: better error
407                }
408                is_filterable = true;
409            }
410        }
411        let column = || Column {
412            ty: field.ty.clone(),
413            ident: col_name.clone(),
414            index: col_num,
415            convert_to_typevalue,
416        };
417
418        if is_unique {
419            unique_columns.push(column());
420        } else if is_filterable {
421            filterable_columns.push(column());
422        }
423    }
424
425    let mut unique_filter_funcs = Vec::with_capacity(unique_columns.len());
426    let mut unique_update_funcs = Vec::with_capacity(unique_columns.len());
427    let mut unique_delete_funcs = Vec::with_capacity(unique_columns.len());
428    let mut unique_fields = Vec::with_capacity(unique_columns.len());
429    for unique in unique_columns {
430        let filter_func_ident = format_ident!("filter_{}_eq", unique.ident);
431        let update_func_ident = format_ident!("update_{}_eq", unique.ident);
432        let delete_func_ident = format_ident!("delete_{}_eq", unique.ident);
433        let comparison_block = tuple_field_comparison_block(&original_struct.ident, &unique.ident, true);
434
435        let Column {
436            ty: column_type,
437            ident: column_ident,
438            index: column_index,
439            convert_to_typevalue,
440        } = unique;
441        let column_index_usize: usize = column_index.into();
442
443        unique_fields.push(column_index);
444
445        unique_filter_funcs.push(quote! {
446            #[allow(unused_variables)]
447            #[allow(non_snake_case)]
448            pub fn #filter_func_ident(#column_ident: #column_type) -> Option<Self> {
449                let table_iter = #original_struct_ident::iter_tuples();
450                for row in table_iter {
451                    let column_data = row.elements[#column_index_usize].clone();
452                    #comparison_block
453                }
454
455                return None;
456            }
457        });
458
459        unique_update_funcs.push(quote! {
460            #[allow(unused_variables)]
461            #[allow(non_snake_case)]
462            pub fn #update_func_ident(value: #column_type, new_value: Self) -> bool {
463                #original_struct_ident::#delete_func_ident(value);
464                #original_struct_ident::insert(new_value);
465
466                // For now this is always successful
467                true
468            }
469        });
470
471        unique_delete_funcs.push(quote! {
472            #[allow(unused_variables)]
473            #[allow(non_snake_case)]
474            pub fn #delete_func_ident(value: #column_type) -> bool {
475                #convert_to_typevalue
476                let result = spacetimedb::delete_eq(Self::table_id(), #column_index, value);
477                match result {
478                    None => {
479                        //TODO: Returning here was supposed to signify an error, but it can also return none when there is nothing to delete.
480                        //spacetimedb::println!("Internal server error on equatable type: {}", #primary_key_tuple_type_str);
481                        false
482                    },
483                    Some(count) => {
484                        count > 0
485                    }
486                }
487            }
488        });
489    }
490
491    let mut non_primary_filter_func = Vec::with_capacity(filterable_columns.len());
492    for column in filterable_columns {
493        let filter_func_ident: proc_macro2::TokenStream = format!("filter_{}_eq", column.ident).parse().unwrap();
494
495        let comparison_block = tuple_field_comparison_block(&original_struct_ident, &column.ident, false);
496
497        let column_ident = column.ident;
498        let column_type = column.ty;
499        let row_index: usize = column.index.into();
500
501        non_primary_filter_func.push(quote! {
502            #[allow(non_snake_case)]
503            #[allow(unused_variables)]
504            pub fn #filter_func_ident(#column_ident: #column_type) -> Vec<Self> {
505                let mut result = Vec::<Self>::new();
506                let table_iter = Self::iter_tuples();
507                for row in table_iter {
508                    let column_data = row.elements[#row_index].clone();
509                    #comparison_block
510                }
511
512                result
513            }
514        });
515    }
516
517    let db_insert: proc_macro2::TokenStream;
518    match parse_generated_func(quote! {
519        #[allow(unused_variables)]
520        pub fn insert(ins: #original_struct_ident) {
521            spacetimedb::insert(Self::table_id(), spacetimedb::IntoTuple::into_tuple(ins));
522        }
523    }) {
524        Ok(func) => db_insert = func,
525        Err(err) => {
526            return proc_macro::TokenStream::from(err);
527        }
528    }
529
530    let db_delete: proc_macro2::TokenStream;
531    match parse_generated_func(quote! {
532    #[allow(unused_variables)]
533    pub fn delete(f: fn (#original_struct_ident) -> bool) -> usize {
534        panic!("Delete using a function is not supported yet!");
535    }}) {
536        Ok(func) => db_delete = func,
537        Err(err) => {
538            return proc_macro::TokenStream::from(err);
539        }
540    }
541
542    let db_update: proc_macro2::TokenStream;
543    match parse_generated_func(quote! {
544    #[allow(unused_variables)]
545    pub fn update(value: #original_struct_ident) -> bool {
546        panic!("Update using a value is not supported yet!");
547    }}) {
548        Ok(func) => db_update = func,
549        Err(err) => {
550            return proc_macro::TokenStream::from(err);
551        }
552    }
553
554    let db_iter_tuples: proc_macro2::TokenStream;
555    match parse_generated_func(quote! {
556        #[allow(unused_variables)]
557        pub fn iter_tuples() -> spacetimedb::TableIter {
558            spacetimedb::__iter__(Self::table_id()).expect("Failed to get iterator from table.")
559        }
560    }) {
561        Ok(func) => db_iter_tuples = func,
562        Err(err) => {
563            return proc_macro::TokenStream::from(err);
564        }
565    }
566
567    let db_iter_ident = format_ident!("{}{}", original_struct_ident, "Iter");
568    let db_iter_struct = quote! {
569        pub struct #db_iter_ident {
570            iter: spacetimedb::TableIter,
571        }
572
573        impl Iterator for #db_iter_ident {
574            type Item = #original_struct_ident;
575
576            fn next(&mut self) -> Option<Self::Item> {
577                if let Some(tuple) = self.iter.next() {
578                    Some(spacetimedb::FromTuple::from_tuple(tuple).expect("Failed to convert tuple to struct."))
579                } else {
580                    None
581                }
582            }
583        }
584    };
585
586    let db_iter: proc_macro2::TokenStream;
587    match parse_generated_func(quote! {
588        #[allow(unused_variables)]
589        pub fn iter() -> #db_iter_ident {
590            #db_iter_ident {
591                iter: Self::iter_tuples()
592            }
593        }
594    }) {
595        Ok(func) => db_iter = func,
596        Err(err) => {
597            return proc_macro::TokenStream::from(err);
598        }
599    }
600
601    let from_value_impl = match autogen_module_tuple_to_struct(&original_struct) {
602        Ok(func) => func,
603        Err(err) => {
604            return TokenStream::from(err);
605        }
606    };
607    let into_value_impl = match autogen_module_struct_to_tuple(&original_struct) {
608        Ok(func) => func,
609        Err(err) => {
610            return TokenStream::from(err);
611        }
612    };
613    let schema_impl = match autogen_module_struct_to_schema(&original_struct) {
614        Ok(func) => func,
615        Err(err) => {
616            return TokenStream::from(err);
617        }
618    };
619    let table_name = original_struct_ident.to_string();
620    let tabletype_impl = quote! {
621        impl spacetimedb::TableType for #original_struct_ident {
622            const TABLE_NAME: &'static str = #table_name;
623            const UNIQUE_COLUMNS: &'static [u8] = &[#(#unique_fields),*];
624        }
625    };
626
627    // let csharp_output = autogen_csharp_tuple(original_struct.clone(), Some(original_struct_ident.to_string()));
628
629    let create_table_func_name = format_ident!("__create_table__{}", original_struct_ident);
630    let describe_table_func_name = format_ident!("__describe_table__{}", original_struct_ident);
631
632    let table_id_static_var = quote! {
633        #[allow(non_upper_case_globals)]
634        static #table_id_static_var_name: spacetimedb::__private::OnceCell<u32> = spacetimedb::__private::OnceCell::new();
635    };
636
637    let create_table_func = quote! {
638        #[allow(non_snake_case)]
639        #[no_mangle]
640        pub extern "C" fn #create_table_func_name(arg_ptr: usize, arg_size: usize) {
641            let table_id = <#original_struct_ident as spacetimedb::TableType>::create_table();
642            #table_id_static_var_name.set(table_id).unwrap_or_else(|_| {
643                // TODO: this is okay? or should we panic? can this even happen?
644            });
645        }
646    };
647
648    let describe_table_func = quote! {
649        #[allow(non_snake_case)]
650        #[no_mangle]
651        pub extern "C" fn #describe_table_func_name() -> u64 {
652            <#original_struct_ident as spacetimedb::TableType>::describe_table()
653        }
654    };
655
656    // Output all macro data
657    let emission = quote! {
658        #table_id_static_var
659
660        #create_table_func
661        #describe_table_func
662        // #csharp_output
663
664        #[derive(spacetimedb::Unique, spacetimedb::Index)]
665        #[derive(serde::Serialize, serde::Deserialize)]
666        #original_struct
667
668        #db_iter_struct
669        impl #original_struct_ident {
670            #db_insert
671            #db_delete
672            #db_update
673            #(#unique_filter_funcs)*
674            #(#unique_update_funcs)*
675            #(#unique_delete_funcs)*
676
677            #db_iter
678            #db_iter_tuples
679            #(#non_primary_filter_func)*
680
681            #get_table_id_func
682        }
683
684        #schema_impl
685        #from_value_impl
686        #into_value_impl
687        #tabletype_impl
688    };
689
690    if std::env::var("PROC_MACRO_DEBUG").is_ok() {
691        println!("{}", emission.to_string());
692    }
693
694    proc_macro::TokenStream::from(emission)
695}
696
697fn spacetimedb_index(args: AttributeArgs, item: TokenStream) -> TokenStream {
698    let mut index_name: String = "default_index".to_string();
699    let mut index_fields = Vec::<u32>::new();
700    let mut all_fields = Vec::<Ident>::new();
701    let index_type: u8; // default index is a btree
702
703    match args[0].to_token_stream().to_string().as_str() {
704        "index(btree)" => {
705            index_type = 0;
706        }
707        "index(hash)" => {
708            index_type = 1;
709        }
710        _ => {
711            let invalid_index = format!(
712                "Invalid index type: {}\nValid options are: index(btree), index(hash)",
713                args[0].to_token_stream().to_string()
714            );
715            return proc_macro::TokenStream::from(quote! {
716                compile_error!(#invalid_index);
717            });
718        }
719    }
720
721    let original_struct = parse_macro_input!(item as ItemStruct);
722    for field in original_struct.clone().fields {
723        all_fields.push(field.ident.unwrap());
724    }
725
726    for x in 1..args.len() {
727        let arg = &args[x];
728        let arg_str = arg.to_token_stream().to_string();
729        let name_prefix = "name = ";
730        if arg_str.starts_with(name_prefix) {
731            index_name = arg_str
732                .chars()
733                .skip(name_prefix.len() + 1)
734                .take(arg_str.len() - name_prefix.len() - 2)
735                .collect();
736        } else {
737            let field_index = all_fields
738                .iter()
739                .position(|a| a.to_token_stream().to_string() == arg_str);
740            match field_index {
741                Some(field_index) => {
742                    index_fields.push(field_index as u32);
743                }
744                None => {
745                    let invalid_index = format!("Invalid field for index: {}", arg_str);
746                    return proc_macro::TokenStream::from(quote! {
747                        compile_error!(#invalid_index);
748                    });
749                }
750            }
751        }
752    }
753
754    let original_struct_name = &original_struct.ident;
755    let function_name: Ident = format_ident!("__create_index__{}", format_ident!("{}", index_name.as_str()));
756
757    let output = quote! {
758        #original_struct
759
760        impl #original_struct_name {
761            #[allow(non_snake_case)]
762            fn #function_name(arg_ptr: u32, arg_size: u32) {
763                spacetimedb::create_index(Self::table_id(), #index_type, vec!(#(#index_fields),*));
764            }
765        }
766    };
767
768    if std::env::var("PROC_MACRO_DEBUG").is_ok() {
769        println!("{}", output.to_string());
770    }
771
772    proc_macro::TokenStream::from(output)
773}
774
775fn spacetimedb_tuple(_: AttributeArgs, item: TokenStream) -> TokenStream {
776    let original_struct = parse_macro_input!(item as ItemStruct);
777    let original_struct_ident = original_struct.clone().ident;
778
779    match original_struct.clone().fields {
780        Named(_) => {}
781        Unnamed(_) => {
782            let str = format!("spacetimedb tables and types must have named fields.");
783            return TokenStream::from(quote! {
784                compile_error!(#str);
785            });
786        }
787        Unit => {
788            let str = format!("Unit structure not supported.");
789            return TokenStream::from(quote! {
790                compile_error!(#str);
791            });
792        }
793    }
794
795    // let csharp_output = autogen_csharp_tuple(original_struct.clone(), None);
796    let schema_impl = match autogen_module_struct_to_schema(&original_struct) {
797        Ok(func) => func,
798        Err(err) => {
799            return TokenStream::from(err);
800        }
801    };
802    let from_value_impl = match autogen_module_tuple_to_struct(&original_struct) {
803        Ok(func) => func,
804        Err(err) => {
805            return TokenStream::from(err);
806        }
807    };
808    let into_value_impl = match autogen_module_struct_to_tuple(&original_struct) {
809        Ok(func) => func,
810        Err(err) => {
811            return TokenStream::from(err);
812        }
813    };
814
815    let create_tuple_func_name = format_ident!("__create_type__{}", original_struct_ident);
816    let create_tuple_func = quote! {
817        #[no_mangle]
818        #[allow(non_snake_case)]
819        pub extern "C" fn #create_tuple_func_name(ptr: *mut u8, arg_size: usize) {
820            let def = <#original_struct_ident as spacetimedb::SchemaType>::get_schema();
821            let mut bytes = unsafe { Vec::from_raw_parts(ptr, 0, arg_size) };
822            def.encode(&mut bytes);
823        }
824    };
825
826    let describe_tuple_func_name = format_ident!("__describe_tuple__{}", original_struct_ident);
827
828    let emission = quote! {
829        #[derive(serde::Serialize, serde::Deserialize)]
830        #original_struct
831        #schema_impl
832        #from_value_impl
833        #into_value_impl
834        #create_tuple_func
835
836        #[allow(non_snake_case)]
837        #[no_mangle]
838        pub extern "C" fn #describe_tuple_func_name() -> u64 {
839            <#original_struct_ident as spacetimedb::TupleType>::describe_tuple()
840        }
841    };
842
843    if std::env::var("PROC_MACRO_DEBUG").is_ok() {
844        println!("{}", emission.to_string());
845    }
846
847    return TokenStream::from(emission);
848}
849
850fn spacetimedb_migrate(_: AttributeArgs, item: TokenStream) -> TokenStream {
851    let original_func = parse_macro_input!(item as ItemFn);
852    let func_name = &original_func.sig.ident;
853
854    let emission = match parse_generated_func(quote! {
855    #[allow(non_snake_case)]
856    pub extern "C" fn __migrate__(arg_ptr: u32, arg_size: u32) {
857        #func_name();
858    }}) {
859        Ok(func) => {
860            quote! {
861                #func
862                #original_func
863            }
864        }
865        Err(err) => err,
866    };
867
868    if std::env::var("PROC_MACRO_DEBUG").is_ok() {
869        println!("{}", emission.to_string());
870    }
871
872    proc_macro::TokenStream::from(emission)
873}
874
875fn spacetimedb_connect_disconnect(args: AttributeArgs, item: TokenStream, connect: bool) -> TokenStream {
876    if *(&args.len()) > 1 {
877        let str = format!("Unexpected macro argument: {}", args[1].to_token_stream().to_string());
878        return proc_macro::TokenStream::from(quote! {
879            compile_error!(#str);
880        });
881    }
882
883    let original_function = parse_macro_input!(item as ItemFn);
884    let func_name = &original_function.sig.ident;
885    let connect_disconnect_func_name = if connect {
886        "__identity_connected__"
887    } else {
888        "__identity_disconnected__"
889    };
890    let connect_disconnect_ident = format_ident!("{}", connect_disconnect_func_name);
891
892    let mut arg_num: usize = 0;
893    for function_argument in original_function.sig.inputs.iter() {
894        if arg_num > 1 {
895            return proc_macro::TokenStream::from(quote! {
896                compile_error!("Client connect/disconnect can only have one argument (identity: Hash)");
897            });
898        }
899
900        match function_argument {
901            FnArg::Receiver(_) => {
902                return proc_macro::TokenStream::from(quote! {
903                    compile_error!("Receiver types in reducer parameters not supported!");
904                });
905            }
906            FnArg::Typed(typed) => {
907                let arg_type = &typed.ty;
908                let arg_token = arg_type.to_token_stream();
909                let arg_type_str = arg_token.to_string();
910
911                // First argument must be Hash (sender)
912                if arg_num == 0 {
913                    if arg_type_str != "spacetimedb::spacetimedb_lib::hash::Hash" && arg_type_str != "Hash" {
914                        let error_str = format!(
915                            "Parameter 1 of connect/disconnect {} must be of type \'Hash\'.",
916                            func_name.to_string()
917                        );
918                        return proc_macro::TokenStream::from(quote! {
919                            compile_error!(#error_str);
920                        });
921                    }
922                    arg_num += 1;
923                    continue;
924                }
925
926                // Second argument must be a u64 (timestamp)
927                if arg_num == 1 {
928                    if arg_type_str != "u64" {
929                        let error_str = format!(
930                            "Parameter 1 of connect/disconnect {} must be of type \'Hash\'.",
931                            func_name.to_string()
932                        );
933                        return proc_macro::TokenStream::from(quote! {
934                            compile_error!(#error_str);
935                        });
936                    }
937                    arg_num += 1;
938                    continue;
939                }
940            }
941        }
942
943        arg_num = arg_num + 1;
944    }
945
946    let emission = match parse_generated_func(quote! {
947        #[no_mangle]
948        #[allow(non_snake_case)]
949        pub extern "C" fn #connect_disconnect_ident(arg_ptr: usize, arg_size: usize) {
950            let arguments = spacetimedb::spacetimedb_lib::args::ConnectDisconnectArguments::decode_mem(
951                unsafe { arg_ptr as *mut u8 }, arg_size).expect("Unable to decode module arguments");
952
953            // Invoke the function with the deserialized args
954            #func_name(arguments.identity, arguments.timestamp,);
955        }
956    }) {
957        Ok(func) => quote! {
958            #func
959            #original_function
960        },
961        Err(err) => err,
962    };
963
964    if std::env::var("PROC_MACRO_DEBUG").is_ok() {
965        println!("{}", emission.to_string());
966    }
967
968    proc_macro::TokenStream::from(emission)
969}
970
971// This derive is actually a no-op, we need the helper attribute for spacetimedb
972#[proc_macro_derive(Unique, attributes(unique))]
973pub fn derive_unique(_: TokenStream) -> TokenStream {
974    TokenStream::new()
975}
976
977#[proc_macro_derive(Index, attributes(index))]
978pub fn derive_index(_item: TokenStream) -> TokenStream {
979    TokenStream::new()
980}
981
982pub(crate) fn rust_to_spacetimedb_ident(input_type: &str) -> Option<Ident> {
983    return match input_type {
984        // These are typically prefixed with spacetimedb::spacetimedb_lib::TypeDef::
985        "bool" => Some(format_ident!("Bool")),
986        "i8" => Some(format_ident!("I8")),
987        "u8" => Some(format_ident!("U8")),
988        "i16" => Some(format_ident!("I16")),
989        "u16" => Some(format_ident!("U16")),
990        "i32" => Some(format_ident!("I32")),
991        "u32" => Some(format_ident!("U32")),
992        "i64" => Some(format_ident!("I64")),
993        "u64" => Some(format_ident!("U64")),
994        "i128" => Some(format_ident!("I128")),
995        "u128" => Some(format_ident!("U128")),
996        "String" => Some(format_ident!("String")),
997        "&str" => Some(format_ident!("String")),
998        "f32" => Some(format_ident!("F32")),
999        "f64" => Some(format_ident!("F64")),
1000        _ => None,
1001    };
1002}
1003
1004fn tuple_field_comparison_block(
1005    tuple_type: &Ident,
1006    filter_field_name: &Ident,
1007    is_unique: bool,
1008) -> proc_macro2::TokenStream {
1009    let err_string = format!(
1010        "Internal stdb error: Can't convert from tuple to struct (wrong version?) {}",
1011        tuple_type
1012    );
1013
1014    let result_statement = if is_unique {
1015        quote! {
1016            let tuple = <Self as spacetimedb::FromTuple>::from_tuple(row);
1017            if tuple.is_none() {
1018                spacetimedb::println!(#err_string);
1019            }
1020            return tuple;
1021        }
1022    } else {
1023        quote! {
1024            let tuple = <Self as spacetimedb::FromTuple>::from_tuple(row);
1025            match tuple {
1026                Some(value) => result.push(value),
1027                None => {
1028                    spacetimedb::println!(#err_string);
1029                    continue;
1030                }
1031            }
1032        }
1033    };
1034
1035    quote! {
1036        if spacetimedb::FilterableValue::equals(&#filter_field_name, &column_data) {
1037            #result_statement
1038        }
1039    }
1040}
1041
1042fn parse_generated_func(
1043    func_stream: proc_macro2::TokenStream,
1044) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
1045    if !syn::parse2::<ItemFn>(func_stream.clone()).is_ok() {
1046        println!(
1047            "This function has an invalid generation:\n{}",
1048            func_stream.clone().to_string()
1049        );
1050        return Err(quote! {
1051            compile_error!("Invalid function produced by spacetimedb macro.");
1052        });
1053    }
1054    Ok(func_stream)
1055}