ts_sql_helper_derive/
lib.rs

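//! Derive macros for the SQL helper library (`ts_sql_helper_lib`).
//!
//! Provides `#[derive(FromRow)]`, which builds a struct from a `postgres::Row`.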
4use proc_macro2::TokenStream;
5use quote::{quote, quote_spanned};
6use syn::{
7    Data, DeriveInput, Fields, GenericParam, Generics, parse_macro_input, parse_quote,
8    spanned::Spanned,
9};
10
11/// Derive `FromRow`.
12#[proc_macro_derive(FromRow)]
13pub fn derive_from_row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
14    // Parse the input tokens into a syntax tree.
15    let input = parse_macro_input!(input as DeriveInput);
16
17    let name = input.ident;
18
19    // Add a bound `T: FromSql` to every type parameter T.
20    let generics = add_trait_bounds(input.generics);
21    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
22
23    let implementation = all_from_row(&input.data);
24
25    let expanded = quote! {
26        // The generated impl.
27        impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
28            fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Option<Self> {
29                #implementation
30            }
31        }
32    };
33
34    // Hand the output tokens back to the compiler.
35    proc_macro::TokenStream::from(expanded)
36}
37
38// Add a bound `T: FromSql` to every type parameter T.
39fn add_trait_bounds(mut generics: Generics) -> Generics {
40    for param in &mut generics.params {
41        if let GenericParam::Type(ref mut type_param) = *param {
42            type_param
43                .bounds
44                .push(parse_quote!(ts_sql_helper_lib::postgres::types::FromSql));
45        }
46    }
47    generics
48}
49
50fn all_from_row(data: &Data) -> TokenStream {
51    match *data {
52        Data::Struct(ref data) => match data.fields {
53            Fields::Named(ref fields) => {
54                let try_get = fields.named.iter().filter_map(|f| {
55                    let Some(name) = &f.ident else {
56                        return None;
57                    };
58
59                    let field_type = &f.ty;
60
61                    let name_lit = name.to_string();
62
63                    let tokens = quote_spanned! {f.span()=>
64                        let #name: #field_type = row.try_get(#name_lit).ok()?;
65                    };
66
67                    Some(tokens)
68                });
69
70                let build_self = fields.named.iter().map(|f| {
71                    let name = &f.ident;
72
73                    quote_spanned! {f.span() => #name}
74                });
75
76                quote! {
77                    #( #try_get )*
78
79                    Some(Self {
80                        #( #build_self ),*
81                    })
82                }
83            }
84            Fields::Unnamed(ref fields) => {
85                panic!("`#[derive(FromRow)]` is only implemented for structs with named fields")
86            }
87            Fields::Unit => {
88                panic!("`#[derive(FromRow)]` is only implemented for structs with named fields")
89            }
90        },
91        Data::Enum(_) | Data::Union(_) => {
92            panic!("`#[derive(FromRow)]` is only implemented for structs with named fields")
93        }
94    }
95}