// soliddb_derive/lib.rs

1use darling::{
2    ast::{Data, Fields},
3    FromDeriveInput, FromField, FromVariant,
4};
5use proc_macro::TokenStream;
6use quote::{format_ident, quote};
7use syn::parse_macro_input;
8
9#[derive(FromDeriveInput)]
10#[darling(attributes(solid), supports(enum_any, struct_named))]
11struct ItemOpts {
12    ident: syn::Ident,
13    data: Data<VariantOpts, FieldOpts>,
14    table: u32,
15}
16
17#[derive(FromVariant)]
18#[darling(attributes(solid))]
19struct VariantOpts {
20    fields: Fields<FieldOpts>,
21}
22
23#[derive(Clone, FromField)]
24#[darling(attributes(solid))]
25struct FieldOpts {
26    ident: Option<syn::Ident>,
27    ty: syn::Type,
28    #[darling(default)]
29    unique: bool,
30    #[darling(default)]
31    indexed: bool,
32}
33
34#[derive(FromDeriveInput)]
35#[darling(attributes(solid), supports(any))]
36struct SingleOpts {
37    ident: syn::Ident,
38    single: u32,
39}
40
41#[proc_macro_derive(Single, attributes(solid))]
42pub fn derive_single(input: TokenStream) -> TokenStream {
43    let input = parse_macro_input!(input);
44    let SingleOpts { ident, single } = match SingleOpts::from_derive_input(&input) {
45        Ok(v) => v,
46        Err(e) => return TokenStream::from(e.write_errors()),
47    };
48
49    let output = quote! {
50        impl ::soliddb::Single for #ident {
51            const SINGLE: u32 = #single;
52        }
53    };
54    output.into()
55}
56
57#[proc_macro_derive(Table, attributes(solid))]
58pub fn derive_item(input: TokenStream) -> TokenStream {
59    let input = parse_macro_input!(input);
60    let ItemOpts { ident, table, data } = match ItemOpts::from_derive_input(&input) {
61        Ok(v) => v,
62        Err(e) => return TokenStream::from(e.write_errors()),
63    };
64
65    if table == 0 {
66        panic!("table 0 is reserved");
67    }
68
69    match data {
70        Data::Struct(fields) => gen_struct(ident, table, fields),
71        Data::Enum(variants) => gen_enum(ident, table, variants),
72    }
73}
74
75fn gen_struct(ident: syn::Ident, table: u32, fields: Fields<FieldOpts>) -> TokenStream {
76    let unique_fields = find_unique_fields(&fields);
77    let unique_field_names: Vec<_> = unique_fields
78        .iter()
79        .cloned()
80        .map(|field| field.ident.unwrap())
81        .collect();
82
83    let indexed_fields = find_indexed_fields(&fields);
84    let indexed_field_names: Vec<_> = indexed_fields
85        .iter()
86        .cloned()
87        .map(|field| field.ident.unwrap())
88        .collect();
89
90    if unique_fields.len() > 126 {
91        panic!("only 126 unique indices per table are allowed");
92    }
93    if indexed_fields.len() > 126 {
94        panic!("only 126 non unique indices per table are allowed");
95    }
96
97    let unique_keys: Vec<_> = (1u8..).take(unique_fields.len()).collect();
98    let indexed_keys: Vec<_> = (128u8..).take(indexed_fields.len()).collect();
99
100    let unique_getters =
101        unique_keys
102            .iter()
103            .copied()
104            .zip(unique_fields.iter())
105            .map(|(index, field)| {
106                unique_getter_method(index, field.ident.as_ref().unwrap(), &field.ty)
107            });
108
109    let indexed_getters =
110        indexed_keys
111            .iter()
112            .copied()
113            .zip(indexed_fields.iter())
114            .map(|(index, field)| {
115                indexed_getter_method(index, field.ident.as_ref().unwrap(), &field.ty)
116            });
117
118    let unique_value_func = if unique_keys.is_empty() {
119        quote! {}
120    } else {
121        quote! {
122            fn unique_value(&self, index: u8) -> Vec<u8> {
123                match index {
124                    #(#unique_keys => ::soliddb::IndexValue::as_bytes(&self.#unique_field_names),)*
125                    _ => unreachable!("no unique value for index {}", index),
126                }
127            }
128        }
129    };
130
131    let non_unique_value_func = if indexed_fields.is_empty() {
132        quote! {}
133    } else {
134        quote! {
135            fn non_unique_value(&self, index: u8) -> Vec<u8> {
136                match index {
137                    #(#indexed_keys => ::soliddb::IndexValue::as_bytes(&self.#indexed_field_names),)*
138                    _ => unreachable!("no non unique value for index {}", index),
139                }
140            }
141        }
142    };
143
144    let output = quote! {
145        impl ::soliddb::Table for #ident {
146            const TABLE: u32 = #table;
147            const UNIQUE_INDICES: &'static [u8] = &[#(#unique_keys),*];
148            const NON_UNIQUE_INDICES: &'static [u8] = &[#(#indexed_keys),*];
149
150            #unique_value_func
151            #non_unique_value_func
152        }
153
154        impl #ident {
155            #(#unique_getters)*
156            #(#indexed_getters)*
157        }
158    };
159    output.into()
160}
161
162fn gen_enum(ident: syn::Ident, table: u32, variants: Vec<VariantOpts>) -> TokenStream {
163    for variant in variants {
164        if !find_unique_fields(&variant.fields).is_empty() {
165            panic!("unique fields are not allowed for enums");
166        }
167
168        if !find_indexed_fields(&variant.fields).is_empty() {
169            panic!("indexed fields are not allowed for enums");
170        }
171    }
172
173    let output = quote! {
174        impl ::soliddb::Table for #ident {
175            const TABLE: u32 = #table;
176        }
177    };
178    output.into()
179}
180
181fn find_unique_fields(fields: &Fields<FieldOpts>) -> Vec<FieldOpts> {
182    fields
183        .iter()
184        .cloned()
185        .filter(|field| field.unique)
186        .collect()
187}
188
189fn find_indexed_fields(fields: &Fields<FieldOpts>) -> Vec<FieldOpts> {
190    fields
191        .iter()
192        .cloned()
193        .filter(|field| field.indexed)
194        .collect()
195}
196
197fn unique_getter_method(index: u8, field: &syn::Ident, ty: &syn::Type) -> proc_macro2::TokenStream {
198    let method = format_ident!("get_by_{field}");
199
200    quote! {
201        pub fn #method(db: &::soliddb::DB, value: &#ty) -> ::soliddb::Result<::soliddb::WithId<Self>> {
202            let value = <#ty as ::soliddb::IndexValue>::as_bytes(value);
203            Self::get_by_unique_index(db, #index, &value)
204        }
205    }
206}
207
208fn indexed_getter_method(
209    index: u8,
210    field: &syn::Ident,
211    ty: &syn::Type,
212) -> proc_macro2::TokenStream {
213    let method = format_ident!("get_by_{field}");
214
215    quote! {
216        pub fn #method(db: &::soliddb::DB, value: &#ty) -> ::soliddb::Result<Vec<::soliddb::WithId<Self>>> {
217            let value = <#ty as ::soliddb::IndexValue>::as_bytes(value);
218            Self::get_by_non_unique_index(db, #index, &value)
219        }
220    }
221}