1mod utils;
2use core::panic;
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, FieldsNamed, Ident};
7use utils::{get_list_attr, has_attribute, validate_attributes};
8
9#[proc_macro_derive(Repository, attributes(index, unique, check))]
10pub fn repository(input: TokenStream) -> TokenStream {
11 let ast = parse_macro_input!(input as DeriveInput);
12 let name = ast.ident;
13
14 let fields = match ast.data {
15 Data::Struct(syn::DataStruct {
16 fields: Fields::Named(FieldsNamed { ref named, .. }),
17 ..
18 }) => named,
19 _ => panic!("can only derive on a struct"),
20 };
21
22 let (index_names, index_members, by_index, index_initializers) =
23 match process_fields(&name, fields.iter()) {
24 Ok(v) => v,
25 Err(e) => return e,
26 };
27
28 if let Err(tokens) =
29 validate_attributes(&ast.attrs, None, &[("check", true)], &["unique", "index"])
30 {
31 return tokens.into();
32 }
33
34 let checks: Vec<proc_macro2::TokenStream> = match get_list_attr(&ast.attrs, "check") {
35 Ok(v) => v,
36 Err(err) => return err.into(),
37 }
38 .iter()
39 .map(|check_fn| {
40 return quote! {
41 _table.constraint(tinybase::Constraint::check(#check_fn))?;
42 };
43 })
44 .collect();
45
46 let vis = ast.vis.clone();
47 let wrapper_name = syn::Ident::new(&format!("{}Repository", name.to_string()), name.span());
48
49 let expanded = quote! {
50 #[derive(Clone)]
51 #vis struct #wrapper_name {
52 _table: tinybase::Table<#name>,
53 #(#index_members)*
54 }
55
56 impl std::ops::Deref for #wrapper_name {
57 type Target = tinybase::Table<#name>;
58
59 fn deref(&self) -> &Self::Target {
60 &self._table
61 }
62 }
63
64 impl #wrapper_name {
65 #(#by_index)*
66 }
67
68 impl #name {
69 pub fn init(db: &tinybase::TinyBase, name: &str) -> tinybase::DbResult<#wrapper_name> {
70 let _table: tinybase::Table<#name> = db.open_table(name)?;
71 #(#index_initializers);*
72 #(#checks)*
73
74 Ok(#wrapper_name {
75 _table, #(#index_names),*
76 })
77 }
78 }
79 };
80
81 expanded.into()
82}
83
84fn process_fields<'a>(
86 struct_name: &proc_macro2::Ident,
87 fields: impl Iterator<Item = &'a Field>,
88) -> Result<
89 (
90 Vec<Ident>,
91 Vec<proc_macro2::TokenStream>,
92 Vec<proc_macro2::TokenStream>,
93 Vec<proc_macro2::TokenStream>,
94 ),
95 TokenStream,
96> {
97 let mut index_names = vec![];
98 let mut index_members = vec![];
99
100 let mut by_index = vec![];
101 let mut index_initializers = vec![];
102
103 for field in fields {
104 validate_attributes(
105 &field.attrs,
106 Some("index"),
107 &[("unique", false), ("index", false)], &["check"],
109 )?;
110
111 if has_attribute(&field.attrs, "index").is_some() {
112 let (field_name, type_name) = (field.ident.as_ref().unwrap(), &field.ty);
113
114 index_names.push(field_name.clone());
115
116 index_members.push(quote! {
117 pub #field_name: tinybase::Index<#struct_name, #type_name>,
118 });
119
120 let methods = create_methods(field_name, type_name, struct_name);
121
122 by_index.push(methods);
123
124 let field_str = format!("{}", field_name);
125
126 index_initializers.push(quote! {
127 let #field_name = _table.create_index(#field_str, |record| record.#field_name.clone())?;
128 });
129
130 if has_attribute(&field.attrs, "unique").is_some() {
131 index_initializers.push(quote! {
132 _table.constraint(tinybase::Constraint::unique(&#field_name))?;
133 })
134 }
135 }
136 }
137
138 Ok((index_names, index_members, by_index, index_initializers))
139}
140
141fn create_methods(
143 field_name: &Ident,
144 type_name: &syn::Type,
145 name: &Ident,
146) -> proc_macro2::TokenStream {
147 let find_method = syn::Ident::new(&format!("find_by_{}", field_name), field_name.span());
148 let delete_method = syn::Ident::new(&format!("delete_by_{}", field_name), field_name.span());
149 let update_method = syn::Ident::new(&format!("update_by_{}", field_name), field_name.span());
150
151 quote! {
152 pub fn #find_method(&self, #field_name: #type_name) -> tinybase::result::DbResult<Vec<tinybase::Record<#name>>> {
153 self.#field_name.select(&#field_name)
154 }
155
156 pub fn #delete_method(&self, #field_name: #type_name) -> tinybase::result::DbResult<Vec<tinybase::Record<#name>>> {
157 self.#field_name.delete(&#field_name)
158 }
159
160 pub fn #update_method(&self, #field_name: #type_name, updater: fn(#name) -> #name) -> tinybase::result::DbResult<Vec<tinybase::Record<#name>>> {
161 let records: Vec<u64> = self.#field_name.select(&#field_name)?.iter().map(|r| r.id).collect();
162 self._table.update(&records, updater)
163 }
164 }
165}