spacetimedsl_derive/
lib.rs1use ident_case::RenameRule;
2use proc_macro::TokenStream;
3use quote::{ToTokens, format_ident, quote};
4use spacetimedsl_derive_input::api::Table;
5mod output;
6
7#[proc_macro_attribute]
10pub fn dsl(args: TokenStream, item: TokenStream) -> TokenStream {
11 let derive_table_helper = derive_table_helper_attr();
13
14 ok_or_compile_error(|| {
15 let args = proc_macro2::TokenStream::from(args);
17 let mut derive_input: syn::DeriveInput = syn::parse(item)?;
18
19 let first_dsl_attribute = if !derive_input.attrs.contains(&derive_table_helper) {
23 derive_input.attrs.push(derive_table_helper);
24 true
25 } else {
26 false
27 };
28
29 let input = Table::try_parse(args, &derive_input)?;
30
31 let output = output::output(&input, first_dsl_attribute)?;
33
34 let is_last_dsl_attribute = is_last_dsl_attribute(&derive_input);
36
37 if is_last_dsl_attribute {
40 make_struct_fields_private(&mut derive_input);
41 }
42
43 Ok(proc_macro2::TokenStream::from_iter([
44 quote!(#derive_input),
45 output,
46 ]))
47 })
48}
49
50fn derive_table_helper_attr() -> syn::Attribute {
51 let source = quote!(#[derive(Clone, Debug, PartialEq, spacetimedsl::SpacetimeDSL)]); syn::parse::Parser::parse2(syn::Attribute::parse_outer, source)
54 .unwrap()
55 .into_iter()
56 .next()
57 .unwrap()
58}
59
60#[proc_macro_derive(
63 SpacetimeDSL,
64 attributes(create_wrapper, use_wrapper, foreign_key, referenced_by)
65)]
66pub fn table_helper(_input: proc_macro::TokenStream) -> proc_macro::TokenStream {
67 proc_macro::TokenStream::default()
68}
69
70fn ok_or_compile_error<Res: Into<proc_macro::TokenStream>>(
71 f: impl FnOnce() -> syn::Result<Res>,
72) -> proc_macro::TokenStream {
73 match f() {
74 Ok(ok) => ok.into(),
75 Err(e) => e.into_compile_error().into(),
76 }
77}
78
79fn is_last_dsl_attribute(derive_input: &syn::DeriveInput) -> bool {
83 let mut dsl_attr_count = 0;
85
86 for attr in &derive_input.attrs {
87 if let Ok(list) = attr.meta.require_list() {
89 let path_string = list.path.to_token_stream().to_string();
90 if path_string == "dsl" || path_string == "spacetimedsl :: dsl" {
91 dsl_attr_count += 1;
92 }
93 }
94 }
95
96 dsl_attr_count == 0
98}
99
100fn make_struct_fields_private(derive_input: &mut syn::DeriveInput) {
102 if let syn::Data::Struct(data_struct) = &mut derive_input.data
103 && let syn::Fields::Named(fields) = &mut data_struct.fields
104 {
105 for field in &mut fields.named {
106 field.vis = syn::Visibility::Inherited;
107 }
108 }
109}
110
111#[proc_macro_attribute]
115pub fn hook(_args: TokenStream, item: TokenStream) -> TokenStream {
116 ok_or_compile_error(|| {
117 let function_input: syn::ItemFn = syn::parse(item)?;
118
119 let trait_name = format_ident!(
120 "{}Hook",
121 RenameRule::PascalCase.apply_to_field(function_input.sig.ident.to_string())
122 );
123
124 Ok(quote! {
125 impl #trait_name for spacetimedsl::DSLMethodHooks {
126 #function_input
127 }
128 })
129 })
130}
131
132