1use darling::{
2 ast::{Data, Fields},
3 FromDeriveInput, FromField, FromVariant,
4};
5use proc_macro::TokenStream;
6use quote::{format_ident, quote};
7use syn::parse_macro_input;
8
/// Options parsed from the input of `#[derive(Table)]`.
///
/// Container options are read from `#[solid(...)]` attributes. darling only
/// accepts enums and structs with named fields here (`supports(enum_any, struct_named)`).
#[derive(FromDeriveInput)]
#[darling(attributes(solid), supports(enum_any, struct_named))]
struct ItemOpts {
    /// Name of the type the derive is applied to.
    ident: syn::Ident,
    /// Enum variants or struct fields, each carrying its own `#[solid(...)]` options.
    data: Data<VariantOpts, FieldOpts>,
    /// Table id from the required `#[solid(table = N)]` attribute
    /// (`derive_item` rejects `0`).
    table: u32,
}
16
/// Per-variant options for an enum deriving `Table`.
#[derive(FromVariant)]
#[darling(attributes(solid))]
struct VariantOpts {
    /// The variant's fields. `gen_enum` only inspects them to reject
    /// `unique`/`indexed` attributes, which enums do not support.
    fields: Fields<FieldOpts>,
}
22
/// Per-field options read from `#[solid(...)]` field attributes.
#[derive(Clone, FromField)]
#[darling(attributes(solid))]
struct FieldOpts {
    /// Field name; `None` for tuple (unnamed) fields.
    ident: Option<syn::Ident>,
    /// Declared field type; used as the `value` parameter type of generated getters.
    ty: syn::Type,
    /// Set by `#[solid(unique)]`: the field gets a unique index id and a
    /// `get_by_<field>` lookup returning a single record.
    #[darling(default)]
    unique: bool,
    /// Set by `#[solid(indexed)]`: the field gets a non-unique index id and a
    /// `get_by_<field>` lookup returning a `Vec` of records.
    #[darling(default)]
    indexed: bool,
}
33
/// Options parsed from the input of `#[derive(Single)]`; any type shape is accepted.
#[derive(FromDeriveInput)]
#[darling(attributes(solid), supports(any))]
struct SingleOpts {
    /// Name of the type the derive is applied to.
    ident: syn::Ident,
    /// Value emitted as `Single::SINGLE`, from the required `#[solid(single = N)]` attribute.
    single: u32,
}
40
41#[proc_macro_derive(Single, attributes(solid))]
42pub fn derive_single(input: TokenStream) -> TokenStream {
43 let input = parse_macro_input!(input);
44 let SingleOpts { ident, single } = match SingleOpts::from_derive_input(&input) {
45 Ok(v) => v,
46 Err(e) => return TokenStream::from(e.write_errors()),
47 };
48
49 let output = quote! {
50 impl ::soliddb::Single for #ident {
51 const SINGLE: u32 = #single;
52 }
53 };
54 output.into()
55}
56
57#[proc_macro_derive(Table, attributes(solid))]
58pub fn derive_item(input: TokenStream) -> TokenStream {
59 let input = parse_macro_input!(input);
60 let ItemOpts { ident, table, data } = match ItemOpts::from_derive_input(&input) {
61 Ok(v) => v,
62 Err(e) => return TokenStream::from(e.write_errors()),
63 };
64
65 if table == 0 {
66 panic!("table 0 is reserved");
67 }
68
69 match data {
70 Data::Struct(fields) => gen_struct(ident, table, fields),
71 Data::Enum(variants) => gen_enum(ident, table, variants),
72 }
73}
74
75fn gen_struct(ident: syn::Ident, table: u32, fields: Fields<FieldOpts>) -> TokenStream {
76 let unique_fields = find_unique_fields(&fields);
77 let unique_field_names: Vec<_> = unique_fields
78 .iter()
79 .cloned()
80 .map(|field| field.ident.unwrap())
81 .collect();
82
83 let indexed_fields = find_indexed_fields(&fields);
84 let indexed_field_names: Vec<_> = indexed_fields
85 .iter()
86 .cloned()
87 .map(|field| field.ident.unwrap())
88 .collect();
89
90 if unique_fields.len() > 126 {
91 panic!("only 126 unique indices per table are allowed");
92 }
93 if indexed_fields.len() > 126 {
94 panic!("only 126 non unique indices per table are allowed");
95 }
96
97 let unique_keys: Vec<_> = (1u8..).take(unique_fields.len()).collect();
98 let indexed_keys: Vec<_> = (128u8..).take(indexed_fields.len()).collect();
99
100 let unique_getters =
101 unique_keys
102 .iter()
103 .copied()
104 .zip(unique_fields.iter())
105 .map(|(index, field)| {
106 unique_getter_method(index, field.ident.as_ref().unwrap(), &field.ty)
107 });
108
109 let indexed_getters =
110 indexed_keys
111 .iter()
112 .copied()
113 .zip(indexed_fields.iter())
114 .map(|(index, field)| {
115 indexed_getter_method(index, field.ident.as_ref().unwrap(), &field.ty)
116 });
117
118 let unique_value_func = if unique_keys.is_empty() {
119 quote! {}
120 } else {
121 quote! {
122 fn unique_value(&self, index: u8) -> Vec<u8> {
123 match index {
124 #(#unique_keys => ::soliddb::IndexValue::as_bytes(&self.#unique_field_names),)*
125 _ => unreachable!("no unique value for index {}", index),
126 }
127 }
128 }
129 };
130
131 let non_unique_value_func = if indexed_fields.is_empty() {
132 quote! {}
133 } else {
134 quote! {
135 fn non_unique_value(&self, index: u8) -> Vec<u8> {
136 match index {
137 #(#indexed_keys => ::soliddb::IndexValue::as_bytes(&self.#indexed_field_names),)*
138 _ => unreachable!("no non unique value for index {}", index),
139 }
140 }
141 }
142 };
143
144 let output = quote! {
145 impl ::soliddb::Table for #ident {
146 const TABLE: u32 = #table;
147 const UNIQUE_INDICES: &'static [u8] = &[#(#unique_keys),*];
148 const NON_UNIQUE_INDICES: &'static [u8] = &[#(#indexed_keys),*];
149
150 #unique_value_func
151 #non_unique_value_func
152 }
153
154 impl #ident {
155 #(#unique_getters)*
156 #(#indexed_getters)*
157 }
158 };
159 output.into()
160}
161
162fn gen_enum(ident: syn::Ident, table: u32, variants: Vec<VariantOpts>) -> TokenStream {
163 for variant in variants {
164 if !find_unique_fields(&variant.fields).is_empty() {
165 panic!("unique fields are not allowed for enums");
166 }
167
168 if !find_indexed_fields(&variant.fields).is_empty() {
169 panic!("indexed fields are not allowed for enums");
170 }
171 }
172
173 let output = quote! {
174 impl ::soliddb::Table for #ident {
175 const TABLE: u32 = #table;
176 }
177 };
178 output.into()
179}
180
181fn find_unique_fields(fields: &Fields<FieldOpts>) -> Vec<FieldOpts> {
182 fields
183 .iter()
184 .cloned()
185 .filter(|field| field.unique)
186 .collect()
187}
188
189fn find_indexed_fields(fields: &Fields<FieldOpts>) -> Vec<FieldOpts> {
190 fields
191 .iter()
192 .cloned()
193 .filter(|field| field.indexed)
194 .collect()
195}
196
197fn unique_getter_method(index: u8, field: &syn::Ident, ty: &syn::Type) -> proc_macro2::TokenStream {
198 let method = format_ident!("get_by_{field}");
199
200 quote! {
201 pub fn #method(db: &::soliddb::DB, value: &#ty) -> ::soliddb::Result<::soliddb::WithId<Self>> {
202 let value = <#ty as ::soliddb::IndexValue>::as_bytes(value);
203 Self::get_by_unique_index(db, #index, &value)
204 }
205 }
206}
207
208fn indexed_getter_method(
209 index: u8,
210 field: &syn::Ident,
211 ty: &syn::Type,
212) -> proc_macro2::TokenStream {
213 let method = format_ident!("get_by_{field}");
214
215 quote! {
216 pub fn #method(db: &::soliddb::DB, value: &#ty) -> ::soliddb::Result<Vec<::soliddb::WithId<Self>>> {
217 let value = <#ty as ::soliddb::IndexValue>::as_bytes(value);
218 Self::get_by_non_unique_index(db, #index, &value)
219 }
220 }
221}