ts_sql_helper_derive/
lib.rs

//! Procedural macros for the SQL helper: the `query!` function-like macro and the
//! `FromRow` / `FromSql` derives.
//!
3
4use std::sync::LazyLock;
5
6use proc_macro::TokenStream;
7use quote::{quote, quote_spanned};
8use regex::Regex;
9use syn::{
10    Data, DeriveInput, Fields, GenericParam, Generics, Type, TypeParamBound, parse_macro_input,
11    parse_quote, spanned::Spanned,
12};
13
14use crate::query::{
15    QueryMacroInput,
16    main_struct::create_main_struct,
17    parameters::{get_param_types, parameter_to_type},
18    row_struct::create_row_struct,
19    test::create_test,
20};
21
22mod query;
23
24/// Macro for creating and test SQL.
25#[proc_macro]
26pub fn query(input: TokenStream) -> TokenStream {
27    let input = parse_macro_input!(input as QueryMacroInput);
28
29    let query = input.query.value();
30    static REGEX: LazyLock<Regex> =
31        LazyLock::new(|| Regex::new(r"(?m)(\r\n|\r|\n| ){2,}").unwrap());
32    let query = REGEX.replace_all(query.trim(), " ");
33
34    let parameters: Vec<Type> = get_param_types(&query)
35        .into_iter()
36        .enumerate()
37        .map(|(index, parameter)| {
38            let r#type = parameter_to_type(&parameter);
39            if input
40                .optional_params
41                .as_ref()
42                .is_some_and(|params| params.contains(&(index + 1)))
43            {
44                parse_quote!(Option<#r#type>)
45            } else {
46                r#type
47            }
48        })
49        .collect();
50
51    let struct_name = input.name;
52
53    let main_struct = create_main_struct(&struct_name, &query, &parameters);
54    let test = create_test(&struct_name);
55    let row_struct = if let Some(row_fields) = input.row {
56        create_row_struct(&struct_name, &row_fields)
57    } else {
58        proc_macro2::TokenStream::new()
59    };
60
61    quote! {
62        #main_struct
63        #row_struct
64        #test
65    }
66    .into()
67}
68
69/// Derive `FromRow`.
70#[proc_macro_derive(FromRow)]
71pub fn derive_from_row(input: TokenStream) -> TokenStream {
72    // Parse the input tokens into a syntax tree.
73    let input = parse_macro_input!(input as DeriveInput);
74
75    let name = input.ident;
76
77    // Add required trait bounds depending on type.
78    let generics = add_trait_bounds(
79        input.generics,
80        parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
81    );
82    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
83
84    let Data::Struct(data_struct) = input.data else {
85        panic!("FromRow can only be derived on a struct")
86    };
87
88    let Fields::Named(fields) = data_struct.fields else {
89        panic!("FromRow can only be derived on a struct with named fields")
90    };
91
92    let each_field_from_row = fields.named.iter().filter_map(|f| {
93        let name = f.ident.as_ref()?;
94        let name_lit = name.to_string();
95        let field_type = &f.ty;
96
97        Some(quote_spanned! {f.span()=>
98            let #name: #field_type = row.try_get(#name_lit)?;
99        })
100    });
101
102    let struct_fields = fields.named.iter().map(|f| {
103        let name = &f.ident;
104        quote_spanned! {f.span() => #name}
105    });
106
107    let expanded = quote! {
108        // The generated impl.
109        impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
110            fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Result<Self, ts_sql_helper_lib::postgres::Error> {
111                #( #each_field_from_row )*
112
113                Ok(Self {
114                    #( #struct_fields ),*
115                })
116            }
117        }
118    };
119
120    // Hand the output tokens back to the compiler.
121    TokenStream::from(expanded)
122}
123
124/// Derive `FromSql`
125#[proc_macro_derive(FromSql)]
126pub fn derive_from_sql(input: TokenStream) -> TokenStream {
127    // Parse the input tokens into a syntax tree.
128    let input = parse_macro_input!(input as DeriveInput);
129
130    if !matches!(input.data, Data::Enum(_)) {
131        panic!("FromSql can only be derived on an enum")
132    }
133
134    let name = input.ident;
135
136    let (repr, accepts, from_sql) = {
137        let mut repr_type = parse_quote!(&str);
138        let mut accepts: Vec<Type> = vec![
139            parse_quote!(ts_sql_helper_lib::postgres_types::Type::TEXT),
140            parse_quote!(ts_sql_helper_lib::postgres_types::Type::VARCHAR),
141        ];
142        let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
143            raw
144        )?);
145
146        for attr in input.attrs {
147            if !attr.path().is_ident("repr") {
148                continue;
149            }
150
151            let Ok(arg) = attr.parse_args::<Type>() else {
152                continue;
153            };
154
155            if arg == parse_quote!(i8) {
156                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::CHAR)];
157                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
158                    raw
159                )?);
160            } else if arg == parse_quote!(i16) {
161                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT2)];
162                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
163                    raw
164                )?);
165            } else if arg == parse_quote!(i32) {
166                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT4)];
167                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
168                    raw
169                )?);
170            } else if arg == parse_quote!(i64) {
171                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT8)];
172                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
173                    raw
174                )?);
175            } else {
176                continue;
177            }
178
179            repr_type = arg;
180            break;
181        }
182
183        (repr_type, accepts, from_sql)
184    };
185
186    let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
187    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
188
189    let expanded = quote! {
190        impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
191            fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
192                let raw_value = #from_sql;
193                let value = Self::try_from(raw_value)?;
194                Ok(value)
195            }
196
197            fn accepts(ty: &ts_sql_helper_lib::postgres_types::Type) -> bool {
198                match (*ty) {
199                    #(#accepts)|* => true,
200                    _ => false,
201                }
202            }
203        }
204    };
205
206    TokenStream::from(expanded)
207}
208
209// Add a bound to every type parameter T.
210fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
211    for param in &mut generics.params {
212        if let GenericParam::Type(ref mut type_param) = *param {
213            type_param.bounds.push(bounds.clone());
214        }
215    }
216    generics
217}