ts_sql_helper_derive/
lib.rs

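//! Derive macros for `ts_sql_helper_lib`.
//!
//! - `FromRow` builds a struct with named fields from a `postgres::Row`, one column per field.
//! - `FromSql` decodes an enum from a single Postgres value (text, or an integer chosen by the
//!   enum's `#[repr]`).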
3
4use quote::{quote, quote_spanned};
5use syn::{
6    Data, DeriveInput, Fields, GenericParam, Generics, Type, TypeParamBound, parse_macro_input,
7    parse_quote, spanned::Spanned,
8};
9
10/// Derive `FromRow`.
11#[proc_macro_derive(FromRow)]
12pub fn derive_from_row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13    // Parse the input tokens into a syntax tree.
14    let input = parse_macro_input!(input as DeriveInput);
15
16    let name = input.ident;
17
18    // Add required trait bounds depending on type.
19    let generics = add_trait_bounds(
20        input.generics,
21        parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
22    );
23    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
24
25    let Data::Struct(data_struct) = input.data else {
26        panic!("FromRow can only be derived on a struct")
27    };
28
29    let Fields::Named(fields) = data_struct.fields else {
30        panic!("FromRow can only be derived on a struct with named fields")
31    };
32
33    let each_field_from_row = fields.named.iter().filter_map(|f| {
34        let name = f.ident.as_ref()?;
35        let name_lit = name.to_string();
36        let field_type = &f.ty;
37
38        Some(quote_spanned! {f.span()=>
39            let #name: #field_type = row.try_get(#name_lit)?;
40        })
41    });
42
43    let struct_fields = fields.named.iter().map(|f| {
44        let name = &f.ident;
45        quote_spanned! {f.span() => #name}
46    });
47
48    let expanded = quote! {
49        // The generated impl.
50        impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
51            fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Result<Self, ts_sql_helper_lib::postgres::Error> {
52                #( #each_field_from_row )*
53
54                Ok(Self {
55                    #( #struct_fields ),*
56                })
57            }
58        }
59    };
60
61    // Hand the output tokens back to the compiler.
62    proc_macro::TokenStream::from(expanded)
63}
64
65/// Derive `FromSql`
66#[proc_macro_derive(FromSql)]
67pub fn derive_from_sql(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
68    // Parse the input tokens into a syntax tree.
69    let input = parse_macro_input!(input as DeriveInput);
70
71    if !matches!(input.data, Data::Enum(_)) {
72        panic!("FromSql can only be derived on an enum")
73    }
74
75    let name = input.ident;
76
77    let (repr, accepts, from_sql) = {
78        let mut repr_type = parse_quote!(&str);
79        let mut accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(TEXT));
80        let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
81            raw
82        )?);
83
84        for attr in input.attrs {
85            if !attr.path().is_ident("repr") {
86                continue;
87            }
88
89            let Ok(arg) = attr.parse_args::<Type>() else {
90                continue;
91            };
92
93            if arg == parse_quote!(i8) {
94                accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(CHAR));
95                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
96                    raw
97                )?);
98            } else if arg == parse_quote!(i16) {
99                accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT2));
100                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
101                    raw
102                )?);
103            } else if arg == parse_quote!(i32) {
104                accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT4));
105                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
106                    raw
107                )?);
108            } else if arg == parse_quote!(i64) {
109                accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT8));
110                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
111                    raw
112                )?);
113            } else {
114                continue;
115            }
116
117            repr_type = arg;
118            break;
119        }
120
121        (repr_type, accepts, from_sql)
122    };
123
124    let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
125    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
126
127    let expanded = quote! {
128        impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
129            fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
130                let raw_value = #from_sql;
131                let value = Self::try_from(raw_value)?;
132                Ok(value)
133            }
134
135            #accepts;
136
137        }
138    };
139
140    proc_macro::TokenStream::from(expanded)
141}
142
143// Add a bound to every type parameter T.
144fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
145    for param in &mut generics.params {
146        if let GenericParam::Type(ref mut type_param) = *param {
147            type_param.bounds.push(bounds.clone());
148        }
149    }
150    generics
151}