ts_sql_helper_derive/
lib.rs

//! Derive macros for `ts_sql_helper_lib`: `FromRow` for structs with named
//! fields and `FromSql` for enums.
3
4use quote::{quote, quote_spanned};
5use syn::{
6    Data, DeriveInput, Fields, GenericParam, Generics, Type, TypeParamBound, parse_macro_input,
7    parse_quote, spanned::Spanned,
8};
9
10/// Derive `FromRow`.
11#[proc_macro_derive(FromRow)]
12pub fn derive_from_row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13    // Parse the input tokens into a syntax tree.
14    let input = parse_macro_input!(input as DeriveInput);
15
16    let name = input.ident;
17
18    // Add required trait bounds depending on type.
19    let generics = add_trait_bounds(
20        input.generics,
21        parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
22    );
23    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
24
25    let Data::Struct(data_struct) = input.data else {
26        panic!("FromRow can only be derived on a struct")
27    };
28
29    let Fields::Named(fields) = data_struct.fields else {
30        panic!("FromRow can only be derived on a struct with named fields")
31    };
32
33    let each_field_from_row = fields.named.iter().filter_map(|f| {
34        let name = f.ident.as_ref()?;
35
36        let field_type = &f.ty;
37
38        let name_lit = name.to_string();
39
40        Some(quote_spanned! {f.span()=>
41            let #name: #field_type = row.try_get(#name_lit).ok()?;
42        })
43    });
44
45    let struct_fields = fields.named.iter().map(|f| {
46        let name = &f.ident;
47        quote_spanned! {f.span() => #name}
48    });
49
50    let expanded = quote! {
51        // The generated impl.
52        impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
53            fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Option<Self> {
54                #( #each_field_from_row )*
55
56                Some(Self {
57                    #( #struct_fields ),*
58                })
59            }
60        }
61    };
62
63    // Hand the output tokens back to the compiler.
64    proc_macro::TokenStream::from(expanded)
65}
66
67/// Derive `FromSql`
68#[proc_macro_derive(FromSql)]
69pub fn derive_from_sql(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
70    // Parse the input tokens into a syntax tree.
71    let input = parse_macro_input!(input as DeriveInput);
72
73    if !matches!(input.data, Data::Enum(_)) {
74        panic!("FromSql can only be derived on an enum")
75    }
76
77    let name = input.ident;
78
79    let (repr, accepts, from_sql) = {
80        let mut repr_type = parse_quote!(&str);
81        let mut accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(TEXT));
82        let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
83            raw
84        )?);
85
86        for attr in input.attrs {
87            if !attr.path().is_ident("repr") {
88                continue;
89            }
90
91            let Ok(arg) = attr.parse_args::<Type>() else {
92                continue;
93            };
94
95            if arg == parse_quote!(i8) {
96                accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(CHAR));
97                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
98                    raw
99                )?);
100            } else if arg == parse_quote!(i16) {
101                accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT2));
102                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
103                    raw
104                )?);
105            } else if arg == parse_quote!(i32) {
106                accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT4));
107                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
108                    raw
109                )?);
110            } else if arg == parse_quote!(i64) {
111                accepts = quote!(ts_sql_helper_lib::postgres::types::accepts!(INT8));
112                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
113                    raw
114                )?);
115            } else {
116                continue;
117            }
118
119            repr_type = arg;
120            break;
121        }
122
123        (repr_type, accepts, from_sql)
124    };
125
126    let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
127    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
128
129    let expanded = quote! {
130        impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
131            fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
132                let raw_value = #from_sql;
133                let value = Self::try_from(raw_value)?;
134                Ok(value)
135            }
136
137            #accepts;
138
139        }
140    };
141
142    proc_macro::TokenStream::from(expanded)
143}
144
145// Add a bound to every type parameter T.
146fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
147    for param in &mut generics.params {
148        if let GenericParam::Type(ref mut type_param) = *param {
149            type_param.bounds.push(bounds.clone());
150        }
151    }
152    generics
153}