Skip to main content

rust_rel8_derive/
lib.rs

1use darling::{FromDeriveInput, FromField, FromTypeParam, ast, util::Flag};
2use proc_macro2::TokenStream;
3use quote::quote;
4use syn::{DeriveInput, parse_macro_input};
5
6#[derive(FromTypeParam, Debug)]
7#[darling(attributes(table))]
8struct GenericOpts {
9    ident: syn::Ident,
10    bounds: Vec<syn::TypeParamBound>,
11
12    /// Use this to mark the type param as being kept in the derived instances.
13    ///
14    /// These should come after the `'scope` lifetime and the `Mode` type param.
15    proxy: Flag,
16}
17
18#[derive(FromField, Debug)]
19#[darling(attributes(table))]
20struct FieldOpts {
21    ident: Option<syn::Ident>,
22
23    /// Marks this field as a nested table, and not a column.
24    ///
25    /// Use this when a field is another table struct.
26    nested: Flag,
27}
28
29#[derive(FromDeriveInput, Debug)]
30#[darling(attributes(table), supports(struct_named))]
31struct TableStructOpts {
32    ident: syn::Ident,
33    data: ast::Data<darling::util::Ignored, FieldOpts>,
34    generics: darling::ast::Generics<darling::ast::GenericParam<GenericOpts>>,
35
36    #[darling(rename = "crate", default)]
37    crate_name: Option<syn::Path>,
38}
39
40/// Derive the necessary traits for Table structs
41///
42/// This must be used on a struct with named fields (for now).
43/// This must be used on a struct which has type parameters in the form: `<'scope, Mode: TableMode, ...>`
44///
45/// This derive macro has the following attributes:
46///
47/// - `#[table(proxy)]`: Used on any type parameter coming after `Mode`, this
48///     tells us to wire them up. (This will be improved in the future)
49///
50/// - `#[table(nested)]`: Used on a field which has the type `OtherTable<'scope, Mode>`,
51///     instructs the derive macro to treat this field as nested.
52#[proc_macro_derive(TableStruct, attributes(table))]
53pub fn derive_table_struct(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54    let a = match TableStructOpts::from_derive_input(&parse_macro_input!(input as DeriveInput)) {
55        Ok(a) => a,
56        Err(e) => {
57            return proc_macro::TokenStream::from(e.write_errors());
58        }
59    };
60
61    let r = match do_all(&a) {
62        Ok(r) => r,
63        Err(e) => {
64            return proc_macro::TokenStream::from(e.write_errors());
65        }
66    };
67
68    r.into()
69}
70
71fn do_all(a: &TableStructOpts) -> darling::Result<TokenStream> {
72    #[cfg(feature = "sqlx")]
73    let sqlx_table_loader = do_sqlx_table_loader(a)?;
74    #[cfg(not(feature = "sqlx"))]
75    let sqlx_table_loader = quote!();
76
77    let rest = do_rest(a)?;
78
79    Ok(quote! {
80        #sqlx_table_loader
81        #rest
82    })
83}
84
/// Generates the sqlx loader impls (`TableLoaderSqlx` and `TableLoaderManySqlx`) for the
/// `'static`, `ExprMode` instantiation of the struct.
///
/// Every method delegates to `TableUsingMapper`. Only `#[table(proxy)]` type parameters are
/// carried over, with each of their inline bounds emitted as a where-predicate.
#[cfg(feature = "sqlx")]
fn do_sqlx_table_loader(t: &TableStructOpts) -> darling::Result<TokenStream> {
    use proc_macro2::Span;

    // Path to the runtime crate: `#[table(crate = "...")]` if given, otherwise `::rust_rel8`.
    let crate_ = match &t.crate_name {
        Some(custom) => custom.clone(),
        None => {
            let mut default_path =
                syn::Path::from(syn::Ident::new("rust_rel8", Span::call_site()));
            default_path.leading_colon = Some(syn::Token![::](Span::call_site()));
            default_path
        }
    };
    let ident = &t.ident;

    // Gather the proxied type parameters and, in the same pass, one `Param: Bound`
    // predicate per inline bound (in declaration order).
    let mut proxied_type_params = Vec::new();
    let mut proxied_type_param_bounds = Vec::new();
    for param in t.generics.type_params().filter(|p| p.proxy.is_present()) {
        let name = &param.ident;
        proxied_type_params.push(name.clone());
        proxied_type_param_bounds
            .extend(param.bounds.iter().map(|bound| quote! { #name: #bound }));
    }

    // Pieces shared by both impl headers.
    let impl_generics = quote! { impl<#(#proxied_type_params,)*> };
    let self_ty = quote! {
        #ident<'static, #crate_::table_modes::ExprMode, #(#proxied_type_params,)*>
    };
    let where_clause = quote! {
        where
            #(#proxied_type_param_bounds,)*
    };
    let value_iter = quote! { impl Iterator<Item = ::sqlx::postgres::PgValueRef<'a>> };

    Ok(quote! {
        #impl_generics #crate_::TableLoaderSqlx for #self_ty
        #where_clause
        {
            fn load<'a>(&self, values: &mut #value_iter) -> Self::Result {
                #crate_::TableUsingMapper::wrap_ref(self).load(values)
            }

            fn skip<'a>(&self, values: &mut #value_iter) {
                #crate_::TableUsingMapper::wrap_ref(self).skip(values)
            }
        }

        #impl_generics #crate_::TableLoaderManySqlx for #self_ty
        #where_clause
        {
            fn load_many<'a>(&self, values: &mut #value_iter) -> Vec<Self::Result> {
                #crate_::TableUsingMapper::wrap_ref(self).load_many(values)
            }
        }
    })
}
144
145fn do_rest(t: &TableStructOpts) -> darling::Result<TokenStream> {
146    use proc_macro2::Span;
147
148    let crate_ = t.crate_name.clone().unwrap_or_else(|| {
149        let mut path = syn::Path::from(syn::Ident::new("rust_rel8", Span::call_site()));
150        path.leading_colon = Some(syn::Token![::](Span::call_site()));
151        path
152    });
153    let ident = &t.ident;
154    let proxied_type_params = t
155        .generics
156        .type_params()
157        .filter(|p| p.proxy.is_present())
158        .map(|p| p.ident.clone())
159        .collect::<Vec<_>>();
160    let proxied_type_param_bounds = t
161        .generics
162        .type_params()
163        .filter(|p| p.proxy.is_present())
164        .flat_map(|p| {
165            p.bounds.iter().map(|b| {
166                let ident = &p.ident;
167                quote! { #ident: #b}
168            })
169        })
170        .collect::<Vec<_>>();
171
172    let fields = t.data.as_ref().take_struct().unwrap();
173
174    let fields = fields
175        .iter()
176        .map(|f| {
177            let ident = f.ident.as_ref().unwrap();
178            let nested = f.nested.is_present();
179
180            (ident, nested)
181        })
182        .collect::<Vec<_>>();
183
184    let map_modes_final_fields = fields.iter().map(|(ident, _)| quote! { #ident });
185    let map_modes_final = quote! {
186        #ident {
187            #(#map_modes_final_fields,)*
188        }
189    };
190
191    let map_modes_fields = fields.iter().map(|(ident, is_nested)| {
192        if *is_nested {
193            quote! { let #ident = self.#ident.map_modes(mapper); }
194        } else {
195            quote! { let #ident = mapper.map_mode(self.#ident); }
196        }
197    });
198
199    let map_modes_ref_fields = fields.iter().map(|(ident, is_nested)| {
200        if *is_nested {
201            quote! { let #ident = self.#ident.map_modes_ref(mapper); }
202        } else {
203            quote! { let #ident = mapper.map_mode_ref(&self.#ident); }
204        }
205    });
206
207    let map_modes_mut_fields = fields.iter().map(|(ident, is_nested)| {
208        if *is_nested {
209            quote! { let #ident = self.#ident.map_modes_mut(mapper); }
210        } else {
211            quote! { let #ident = mapper.map_mode_mut(&mut self.#ident); }
212        }
213    });
214
215    let with_lt_fields_with_lt = fields
216        .iter()
217        .map(|(ident, _is_nested)| {
218            quote! { #ident: self.#ident.with_lt(marker) }
219        })
220        .collect::<Vec<_>>();
221
222    let unwith_lt_fields_with_lt = fields
223        .iter()
224        .map(|(ident, _is_nested)| {
225            quote! { #ident: #crate_::ForLifetimeTable::unwith_lt(with_lt.#ident, marker) }
226        })
227        .collect::<Vec<_>>();
228
229    let shorten_lifetime_fields_shorten = fields
230        .iter()
231        .map(|(ident, _is_nested)| {
232            quote! { #ident: self.#ident.shorten_lifetime() }
233        })
234        .collect::<Vec<_>>();
235
236    let shorten_lifetime_fields_noop = fields
237        .iter()
238        .map(|(ident, is_nested)| {
239            if *is_nested {
240                quote! { #ident: self.#ident.shorten_lifetime() }
241            } else {
242                quote! { #ident: self.#ident }
243            }
244        })
245        .collect::<Vec<_>>();
246
247    let tokens = quote! {
248        impl<'scope, #(#proxied_type_params,)*> #crate_::ForLifetimeTable for #ident<'scope, #crate_::ExprMode, #(#proxied_type_params,)*>
249        where
250            for<'a> #ident<'a, #crate_::ExprMode, #(#proxied_type_params,)*>: #crate_::Table,
251            #(#proxied_type_param_bounds,)*
252        {
253            type WithLt<'lt> = #ident<'lt, #crate_::ExprMode, #(#proxied_type_params,)*>;
254
255            fn with_lt<'lt>(self, marker: &mut #crate_::WithLtMarker) -> Self::WithLt<'lt> {
256                #ident {
257                    #(#with_lt_fields_with_lt,)*
258                }
259            }
260
261            fn unwith_lt<'lt>(with_lt: Self::WithLt<'lt>, marker: &mut #crate_::WithLtMarker) -> Self {
262                #ident {
263                    #(#unwith_lt_fields_with_lt,)*
264                }
265            }
266        }
267
268        impl<'scope, #(#proxied_type_params,)*> #crate_::ShortenLifetime for #ident<'scope, #crate_::NameMode, #(#proxied_type_params,)*>
269        where
270            #(#proxied_type_param_bounds,)*
271        {
272            type Shortened<'small> = #ident<'small, #crate_::NameMode, #(#proxied_type_params,)*>
273            where
274                Self: 'small
275                ;
276
277            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
278            where
279                Self: 'large,
280            {
281                #ident {
282                    #(#shorten_lifetime_fields_noop,)*
283                }
284            }
285        }
286
287        impl<'scope, #(#proxied_type_params,)*> #crate_::ShortenLifetime for #ident<'scope, #crate_::ValueMode, #(#proxied_type_params,)*>
288        where
289            #(#proxied_type_param_bounds,)*
290        {
291            type Shortened<'small> = #ident<'small, #crate_::ValueMode, #(#proxied_type_params,)*>
292            where
293                Self: 'small
294                ;
295
296            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
297            where
298                Self: 'large,
299            {
300                #ident {
301                    #(#shorten_lifetime_fields_noop,)*
302                }
303            }
304        }
305
306        impl<'scope, #(#proxied_type_params,)*> #crate_::ShortenLifetime for #ident<'scope, #crate_::EmptyMode, #(#proxied_type_params,)*>
307        where
308            #(#proxied_type_param_bounds,)*
309        {
310            type Shortened<'small> = #ident<'small, #crate_::EmptyMode, #(#proxied_type_params,)*>
311            where
312                Self: 'small
313                ;
314
315            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
316            where
317                Self: 'large,
318            {
319                #ident {
320                    #(#shorten_lifetime_fields_noop,)*
321                }
322            }
323        }
324
325        impl<'scope, #(#proxied_type_params,)*> #crate_::ShortenLifetime for #ident<'scope, #crate_::ExprMode, #(#proxied_type_params,)*>
326        where
327            #(#proxied_type_param_bounds,)*
328        {
329            type Shortened<'small> = #ident<'small, #crate_::ExprMode, #(#proxied_type_params,)*>
330            where
331                Self: 'small
332                ;
333
334            fn shorten_lifetime<'small, 'large: 'small>(self) -> Self::Shortened<'small>
335            where
336                Self: 'large,
337            {
338                #ident {
339                    #(#shorten_lifetime_fields_shorten,)*
340                }
341            }
342        }
343
344        impl<'scope, T: #crate_::TableMode, #(#proxied_type_params,)*> #crate_::TableHKT for #ident<'scope, T, #(#proxied_type_params,)*>
345        where
346            #(#proxied_type_param_bounds,)*
347        {
348            type InMode<Mode: #crate_::TableMode> = #ident<'scope, Mode, #(#proxied_type_params,)*>;
349
350            type Mode = T;
351        }
352
353        impl<'scope, Mode: #crate_::TableMode, #(#proxied_type_params,)*> #crate_::MapTable<'scope> for #ident<'scope, Mode, #(#proxied_type_params,)*>
354        where
355            #(#proxied_type_param_bounds,)*
356        {
357            fn map_modes<Mapper, DestMode>(self, mapper: &mut Mapper) -> Self::InMode<DestMode>
358            where
359                Mapper: #crate_::ModeMapper<'scope, Self::Mode, DestMode>,
360                DestMode: #crate_::TableMode,
361            {
362                #(#map_modes_fields)*
363
364                #map_modes_final
365            }
366
367            fn map_modes_ref<Mapper, DestMode>(&self, mapper: &mut Mapper) -> Self::InMode<DestMode>
368            where
369                Mapper: #crate_::ModeMapperRef<'scope, Self::Mode, DestMode>,
370                DestMode: #crate_::TableMode,
371            {
372                #(#map_modes_ref_fields)*
373
374                #map_modes_final
375            }
376
377            fn map_modes_mut<Mapper, DestMode>(&mut self, mapper: &mut Mapper) -> Self::InMode<DestMode>
378            where
379                Mapper: #crate_::ModeMapperMut<'scope, Self::Mode, DestMode>,
380                DestMode: #crate_::TableMode,
381            {
382                #(#map_modes_mut_fields)*
383
384                #map_modes_final
385            }
386        }
387
388        impl<'scope, #(#proxied_type_params,)*> #crate_::Table for #ident<'scope, #crate_::table_modes::ExprMode, #(#proxied_type_params,)*>
389        where
390            #(#proxied_type_param_bounds,)*
391        {
392            type Result = <Self as #crate_::TableHKT>::InMode<#crate_::table_modes::ValueMode>;
393
394            fn visit(&self, f: &mut impl FnMut(&#crate_::ErasedExpr), mode: VisitTableMode) {
395                #crate_::TableUsingMapper::wrap_ref(self).visit(f, mode)
396            }
397
398            fn visit_mut(&mut self, f: &mut impl FnMut(&mut #crate_::ErasedExpr), mode: VisitTableMode) {
399                #crate_::TableUsingMapper::wrap_mut(self).visit_mut(f, mode)
400            }
401        }
402    };
403
404    Ok(tokens)
405}