// tinybase_derive/lib.rs

1mod utils;
2use core::panic;
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, FieldsNamed, Ident};
7use utils::{get_list_attr, has_attribute, validate_attributes};
8
9#[proc_macro_derive(Repository, attributes(index, unique, check))]
10pub fn repository(input: TokenStream) -> TokenStream {
11    let ast = parse_macro_input!(input as DeriveInput);
12    let name = ast.ident;
13
14    let fields = match ast.data {
15        Data::Struct(syn::DataStruct {
16            fields: Fields::Named(FieldsNamed { ref named, .. }),
17            ..
18        }) => named,
19        _ => panic!("can only derive on a struct"),
20    };
21
22    let (index_names, index_members, by_index, index_initializers) =
23        match process_fields(&name, fields.iter()) {
24            Ok(v) => v,
25            Err(e) => return e,
26        };
27
28    if let Err(tokens) =
29        validate_attributes(&ast.attrs, None, &[("check", true)], &["unique", "index"])
30    {
31        return tokens.into();
32    }
33
34    let checks: Vec<proc_macro2::TokenStream> = match get_list_attr(&ast.attrs, "check") {
35        Ok(v) => v,
36        Err(err) => return err.into(),
37    }
38    .iter()
39    .map(|check_fn| {
40        return quote! {
41            _table.constraint(tinybase::Constraint::check(#check_fn))?;
42        };
43    })
44    .collect();
45
46    let vis = ast.vis.clone();
47    let wrapper_name = syn::Ident::new(&format!("{}Repository", name.to_string()), name.span());
48
49    let expanded = quote! {
50        #[derive(Clone)]
51        #vis struct #wrapper_name {
52            _table: tinybase::Table<#name>,
53            #(#index_members)*
54        }
55
56        impl std::ops::Deref for #wrapper_name {
57            type Target = tinybase::Table<#name>;
58
59            fn deref(&self) -> &Self::Target {
60                &self._table
61            }
62        }
63
64        impl #wrapper_name {
65            #(#by_index)*
66        }
67
68        impl #name {
69            pub fn init(db: &tinybase::TinyBase, name: &str) -> tinybase::DbResult<#wrapper_name> {
70                let _table: tinybase::Table<#name> = db.open_table(name)?;
71                #(#index_initializers);*
72                #(#checks)*
73
74                Ok(#wrapper_name {
75                    _table, #(#index_names),*
76                })
77            }
78        }
79    };
80
81    expanded.into()
82}
83
84/// Process fields and decide what should be generated for each field.
85fn process_fields<'a>(
86    struct_name: &proc_macro2::Ident,
87    fields: impl Iterator<Item = &'a Field>,
88) -> Result<
89    (
90        Vec<Ident>,
91        Vec<proc_macro2::TokenStream>,
92        Vec<proc_macro2::TokenStream>,
93        Vec<proc_macro2::TokenStream>,
94    ),
95    TokenStream,
96> {
97    let mut index_names = vec![];
98    let mut index_members = vec![];
99
100    let mut by_index = vec![];
101    let mut index_initializers = vec![];
102
103    for field in fields {
104        validate_attributes(
105            &field.attrs,
106            Some("index"),
107            &[("unique", false), ("index", false)], // index is here as a hack to prevent allowing list.
108            &["check"],
109        )?;
110
111        if has_attribute(&field.attrs, "index").is_some() {
112            let (field_name, type_name) = (field.ident.as_ref().unwrap(), &field.ty);
113
114            index_names.push(field_name.clone());
115
116            index_members.push(quote! {
117                pub #field_name: tinybase::Index<#struct_name, #type_name>,
118            });
119
120            let methods = create_methods(field_name, type_name, struct_name);
121
122            by_index.push(methods);
123
124            let field_str = format!("{}", field_name);
125
126            index_initializers.push(quote! {
127                let #field_name = _table.create_index(#field_str, |record| record.#field_name.clone())?;
128            });
129
130            if has_attribute(&field.attrs, "unique").is_some() {
131                index_initializers.push(quote! {
132                    _table.constraint(tinybase::Constraint::unique(&#field_name))?;
133                })
134            }
135        }
136    }
137
138    Ok((index_names, index_members, by_index, index_initializers))
139}
140
141/// Create methods for an index.
142fn create_methods(
143    field_name: &Ident,
144    type_name: &syn::Type,
145    name: &Ident,
146) -> proc_macro2::TokenStream {
147    let find_method = syn::Ident::new(&format!("find_by_{}", field_name), field_name.span());
148    let delete_method = syn::Ident::new(&format!("delete_by_{}", field_name), field_name.span());
149    let update_method = syn::Ident::new(&format!("update_by_{}", field_name), field_name.span());
150
151    quote! {
152        pub fn #find_method(&self, #field_name: #type_name) -> tinybase::result::DbResult<Vec<tinybase::Record<#name>>> {
153            self.#field_name.select(&#field_name)
154        }
155
156        pub fn #delete_method(&self, #field_name: #type_name) -> tinybase::result::DbResult<Vec<tinybase::Record<#name>>> {
157            self.#field_name.delete(&#field_name)
158        }
159
160        pub fn #update_method(&self, #field_name: #type_name, updater: fn(#name) -> #name) -> tinybase::result::DbResult<Vec<tinybase::Record<#name>>> {
161            let records: Vec<u64> = self.#field_name.select(&#field_name)?.iter().map(|r| r.id).collect();
162            self._table.update(&records, updater)
163        }
164    }
165}