Skip to main content

rust_rel8_derive/
lib.rs

1use darling::{FromDeriveInput, FromField, FromTypeParam, ast, util::Flag};
2use proc_macro2::TokenStream;
3use quote::quote;
4use syn::{DeriveInput, parse_macro_input};
5
/// Per-type-parameter options, parsed by darling from the derived struct's generics.
///
/// `ident` and `bounds` are filled in by darling from the type parameter itself;
/// `proxy` comes from a `#[table(proxy)]` attribute written on the parameter.
#[derive(FromTypeParam, Debug)]
#[darling(attributes(table))]
struct GenericOpts {
    /// The type parameter's name.
    ident: syn::Ident,
    /// Bounds written on the parameter itself (`T: Bound`); for proxied parameters
    /// the code generators re-emit each one as a `T: Bound` where-predicate.
    bounds: Vec<syn::TypeParamBound>,

    /// Use this to mark the type param as being kept in the derived instances.
    ///
    /// These should come after the `'scope` lifetime and the `Mode` type param.
    proxy: Flag,
}
17
/// Per-field options for a table struct, parsed by darling.
#[derive(FromField, Debug)]
#[darling(attributes(table))]
struct FieldOpts {
    /// The field's name. Always `Some` in practice, because `TableStructOpts`
    /// only accepts structs with named fields (`supports(struct_named)`).
    ident: Option<syn::Ident>,

    /// Marks this field as a nested table, and not a column.
    ///
    /// Use this when a field is another table struct.
    nested: Flag,
}
28
/// Options for the whole `#[derive(TableStruct)]` input, parsed by darling.
///
/// `supports(struct_named)` makes darling reject any input that is not a struct
/// with named fields, so `data` is always the struct variant.
#[derive(FromDeriveInput, Debug)]
#[darling(attributes(table), supports(struct_named))]
struct TableStructOpts {
    /// Name of the struct being derived.
    ident: syn::Ident,
    /// The struct's fields (enum variants are ignored and never present here).
    data: ast::Data<darling::util::Ignored, FieldOpts>,
    /// The struct's generic parameters; type parameters are parsed as `GenericOpts`.
    generics: darling::ast::Generics<darling::ast::GenericParam<GenericOpts>>,

    /// Path of the runtime crate, from `#[table(crate = ...)]` on the struct.
    ///
    /// When absent, the code generators fall back to `::rust_rel8`.
    #[darling(rename = "crate", default)]
    crate_name: Option<syn::Path>,
}
39
40/// Derive the necessary traits for Table structs
41///
42/// This must be used on a struct with named fields (for now).
43/// This must be used on a struct which has type parameters in the form: `<'scope, Mode: [TableMode], ...>`
44///
45/// This derive macro has the following attributes:
46///
47/// - `#[table(proxy)]`: Used on any type parameter coming after `Mode`, this
48///     tells us to wire them up. (This will be improved in the future)
49///
50/// - `#[table(nested)]`: Used on a field which has the type `OtherTable<'scope, Mode>`,
51///     instructs the derive macro to treat this field as nested.
52#[proc_macro_derive(TableStruct, attributes(table))]
53pub fn derive_table_struct(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54    let a = match TableStructOpts::from_derive_input(&parse_macro_input!(input as DeriveInput)) {
55        Ok(a) => a,
56        Err(e) => {
57            return proc_macro::TokenStream::from(e.write_errors());
58        }
59    };
60
61    let r = match do_all(&a) {
62        Ok(r) => r,
63        Err(e) => {
64            return proc_macro::TokenStream::from(e.write_errors());
65        }
66    };
67
68    r.into()
69}
70
71fn do_all(a: &TableStructOpts) -> darling::Result<TokenStream> {
72    #[cfg(feature = "sqlx")]
73    let sqlx_table_loader = do_sqlx_table_loader(a)?;
74    #[cfg(not(feature = "sqlx"))]
75    let sqlx_table_loader = quote!();
76
77    let rest = do_rest(a)?;
78
79    Ok(quote! {
80        #sqlx_table_loader
81        #rest
82    })
83}
84
85#[cfg(feature = "sqlx")]
86fn do_sqlx_table_loader(t: &TableStructOpts) -> darling::Result<TokenStream> {
87    use proc_macro2::Span;
88
89    let crate_ = t.crate_name.clone().unwrap_or_else(|| {
90        let mut path = syn::Path::from(syn::Ident::new("rust_rel8", Span::call_site()));
91        path.leading_colon = Some(syn::Token![::](Span::call_site()));
92        path
93    });
94    let ident = &t.ident;
95    let proxied_type_params = t
96        .generics
97        .type_params()
98        .filter(|p| p.proxy.is_present())
99        .map(|p| p.ident.clone())
100        .collect::<Vec<_>>();
101    let proxied_type_param_bounds = t
102        .generics
103        .type_params()
104        .filter(|p| p.proxy.is_present())
105        .flat_map(|p| {
106            p.bounds.iter().map(|b| {
107                let ident = &p.ident;
108                quote! { #ident: #b}
109            })
110        })
111        .collect::<Vec<_>>();
112    let tokens = quote! {
113        impl<#(#proxied_type_params,)*> #crate_::TableLoaderSqlx for #ident<'static, #crate_::table_modes::ExprMode, #(#proxied_type_params,)*>
114        where
115            #(#proxied_type_param_bounds,)*
116        {
117            fn load<'a>(
118                &self,
119                values: &mut impl Iterator<Item = ::sqlx::postgres::PgValueRef<'a>>,
120            ) -> Self::Result {
121                #crate_::TableUsingMapper::wrap_ref(self).load(values)
122            }
123
124            fn skip<'a>(&self, values: &mut impl Iterator<Item = ::sqlx::postgres::PgValueRef<'a>>) {
125                #crate_::TableUsingMapper::wrap_ref(self).skip(values)
126            }
127        }
128
129        impl<#(#proxied_type_params,)*> #crate_::TableLoaderManySqlx for #ident<'static, #crate_::table_modes::ExprMode, #(#proxied_type_params,)*>
130        where
131            #(#proxied_type_param_bounds,)*
132        {
133            fn load_many<'a>(
134                &self,
135                values: &mut impl Iterator<Item = ::sqlx::postgres::PgValueRef<'a>>,
136            ) -> Vec<Self::Result> {
137                #crate_::TableUsingMapper::wrap_ref(self).load_many(values)
138            }
139        }
140    };
141
142    Ok(tokens)
143}
144
145fn do_rest(t: &TableStructOpts) -> darling::Result<TokenStream> {
146    use proc_macro2::Span;
147
148    let crate_ = t.crate_name.clone().unwrap_or_else(|| {
149        let mut path = syn::Path::from(syn::Ident::new("rust_rel8", Span::call_site()));
150        path.leading_colon = Some(syn::Token![::](Span::call_site()));
151        path
152    });
153    let ident = &t.ident;
154    let proxied_type_params = t
155        .generics
156        .type_params()
157        .filter(|p| p.proxy.is_present())
158        .map(|p| p.ident.clone())
159        .collect::<Vec<_>>();
160    let proxied_type_param_bounds = t
161        .generics
162        .type_params()
163        .filter(|p| p.proxy.is_present())
164        .flat_map(|p| {
165            p.bounds.iter().map(|b| {
166                let ident = &p.ident;
167                quote! { #ident: #b}
168            })
169        })
170        .collect::<Vec<_>>();
171
172    let fields = t.data.as_ref().take_struct().unwrap();
173
174    let fields = fields
175        .iter()
176        .map(|f| {
177            let ident = f.ident.as_ref().unwrap();
178            let nested = f.nested.is_present();
179
180            (ident, nested)
181        })
182        .collect::<Vec<_>>();
183
184    let map_modes_final_fields = fields.iter().map(|(ident, _)| quote! { #ident });
185    let map_modes_final = quote! {
186        #ident {
187            #(#map_modes_final_fields,)*
188        }
189    };
190
191    let map_modes_fields = fields.iter().map(|(ident, is_nested)| {
192        if *is_nested {
193            quote! { let #ident = self.#ident.map_modes(mapper); }
194        } else {
195            quote! { let #ident = mapper.map_mode(self.#ident); }
196        }
197    });
198
199    let map_modes_ref_fields = fields.iter().map(|(ident, is_nested)| {
200        if *is_nested {
201            quote! { let #ident = self.#ident.map_modes_ref(mapper); }
202        } else {
203            quote! { let #ident = mapper.map_mode_ref(&self.#ident); }
204        }
205    });
206
207    let map_modes_mut_fields = fields.iter().map(|(ident, is_nested)| {
208        if *is_nested {
209            quote! { let #ident = self.#ident.map_modes_mut(mapper); }
210        } else {
211            quote! { let #ident = mapper.map_mode_mut(&mut self.#ident); }
212        }
213    });
214
215    let with_lt_fields_with_lt = fields
216        .iter()
217        .map(|(ident, _is_nested)| {
218            quote! { #ident: self.#ident.with_lt(marker) }
219        })
220        .collect::<Vec<_>>();
221
222    let shorten_lifetime_fields_shorten = fields
223        .iter()
224        .map(|(ident, _is_nested)| {
225            quote! { #ident: self.#ident.shorten_lifetime() }
226        })
227        .collect::<Vec<_>>();
228
229    let shorten_lifetime_fields_noop = fields
230        .iter()
231        .map(|(ident, is_nested)| {
232            if *is_nested {
233                quote! { #ident: self.#ident.shorten_lifetime() }
234            } else {
235                quote! { #ident: self.#ident }
236            }
237        })
238        .collect::<Vec<_>>();
239
240    let tokens = quote! {
241        impl<'scope, #(#proxied_type_params,)*> #crate_::ForLifetimeTable for #ident<'scope, #crate_::ExprMode, #(#proxied_type_params,)*>
242        where
243            for<'a> #ident<'a, #crate_::ExprMode, #(#proxied_type_params,)*>: #crate_::Table<'a>,
244            #(#proxied_type_param_bounds,)*
245        {
246            type WithLt<'lt> = #ident<'lt, #crate_::ExprMode, #(#proxied_type_params,)*>;
247
248            fn with_lt<'lt>(self, marker: &mut #crate_::WithLtMarker) -> Self::WithLt<'lt> {
249                #ident {
250                    #(#with_lt_fields_with_lt,)*
251                }
252            }
253        }
254
255        impl<'scope, #(#proxied_type_params,)*> #crate_::ShortenLifetime for #ident<'scope, #crate_::NameMode, #(#proxied_type_params,)*>
256        where
257            #(#proxied_type_param_bounds,)*
258        {
259            type Shortened<'small> = #ident<'small, #crate_::NameMode, #(#proxied_type_params,)*>
260            where
261                Self: 'small
262                ;
263
264            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
265            where
266                Self: 'large,
267            {
268                #ident {
269                    #(#shorten_lifetime_fields_noop,)*
270                }
271            }
272        }
273
274        impl<'scope, #(#proxied_type_params,)*> #crate_::ShortenLifetime for #ident<'scope, #crate_::ValueMode, #(#proxied_type_params,)*>
275        where
276            #(#proxied_type_param_bounds,)*
277        {
278            type Shortened<'small> = #ident<'small, #crate_::ValueMode, #(#proxied_type_params,)*>
279            where
280                Self: 'small
281                ;
282
283            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
284            where
285                Self: 'large,
286            {
287                #ident {
288                    #(#shorten_lifetime_fields_noop,)*
289                }
290            }
291        }
292
293        impl<'scope, #(#proxied_type_params,)*> #crate_::ShortenLifetime for #ident<'scope, #crate_::EmptyMode, #(#proxied_type_params,)*>
294        where
295            #(#proxied_type_param_bounds,)*
296        {
297            type Shortened<'small> = #ident<'small, #crate_::EmptyMode, #(#proxied_type_params,)*>
298            where
299                Self: 'small
300                ;
301
302            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
303            where
304                Self: 'large,
305            {
306                #ident {
307                    #(#shorten_lifetime_fields_noop,)*
308                }
309            }
310        }
311
312        impl<'scope, #(#proxied_type_params,)*> #crate_::ShortenLifetime for #ident<'scope, #crate_::ExprMode, #(#proxied_type_params,)*>
313        where
314            #(#proxied_type_param_bounds,)*
315        {
316            type Shortened<'small> = #ident<'small, #crate_::ExprMode, #(#proxied_type_params,)*>
317            where
318                Self: 'small
319                ;
320
321            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
322            where
323                Self: 'large,
324            {
325                #ident {
326                    #(#shorten_lifetime_fields_shorten,)*
327                }
328            }
329        }
330
331        impl<'scope, T: #crate_::TableMode, #(#proxied_type_params,)*> #crate_::TableHKT for #ident<'scope, T, #(#proxied_type_params,)*>
332        where
333            #(#proxied_type_param_bounds,)*
334        {
335            type InMode<Mode: #crate_::TableMode> = #ident<'scope, Mode, #(#proxied_type_params,)*>;
336
337            type Mode = T;
338        }
339
340        impl<'scope, Mode: #crate_::TableMode, #(#proxied_type_params,)*> #crate_::MapTable<'scope> for #ident<'scope, Mode, #(#proxied_type_params,)*>
341        where
342            #(#proxied_type_param_bounds,)*
343        {
344            fn map_modes<Mapper, DestMode>(self, mapper: &mut Mapper) -> Self::InMode<DestMode>
345            where
346                Mapper: #crate_::ModeMapper<'scope, Self::Mode, DestMode>,
347                DestMode: #crate_::TableMode,
348            {
349                #(#map_modes_fields)*
350
351                #map_modes_final
352            }
353
354            fn map_modes_ref<Mapper, DestMode>(&self, mapper: &mut Mapper) -> Self::InMode<DestMode>
355            where
356                Mapper: #crate_::ModeMapperRef<'scope, Self::Mode, DestMode>,
357                DestMode: #crate_::TableMode,
358            {
359                #(#map_modes_ref_fields)*
360
361                #map_modes_final
362            }
363
364            fn map_modes_mut<Mapper, DestMode>(&mut self, mapper: &mut Mapper) -> Self::InMode<DestMode>
365            where
366                Mapper: #crate_::ModeMapperMut<'scope, Self::Mode, DestMode>,
367                DestMode: #crate_::TableMode,
368            {
369                #(#map_modes_mut_fields)*
370
371                #map_modes_final
372            }
373        }
374
375        impl<'scope, #(#proxied_type_params,)*> #crate_::Table<'scope> for #ident<'scope, #crate_::table_modes::ExprMode, #(#proxied_type_params,)*>
376        where
377            #(#proxied_type_param_bounds,)*
378        {
379            type Result = <Self as #crate_::TableHKT>::InMode<#crate_::table_modes::ValueMode>;
380
381            fn visit(&self, f: &mut impl FnMut(&#crate_::ErasedExpr), mode: VisitTableMode) {
382                #crate_::TableUsingMapper::wrap_ref(self).visit(f, mode)
383            }
384
385            fn visit_mut(&mut self, f: &mut impl FnMut(&mut #crate_::ErasedExpr), mode: VisitTableMode) {
386                #crate_::TableUsingMapper::wrap_mut(self).visit_mut(f, mode)
387            }
388        }
389    };
390
391    Ok(tokens)
392}