ts_sql_helper_derive/
lib.rs

1//! Derives for SQL helper
2//!
3
4use std::sync::LazyLock;
5
6use proc_macro::TokenStream;
7use quote::{quote, quote_spanned};
8use regex::Regex;
9use syn::{
10    Data, DeriveInput, Fields, GenericParam, Generics, Type, TypeParamBound, parse_macro_input,
11    parse_quote, spanned::Spanned,
12};
13
14use crate::query::{
15    QueryMacroInput,
16    main_struct::create_main_struct,
17    parameters::{Parameter, get_parameters},
18    row_struct::create_row_struct,
19    test::create_test,
20};
21
22mod query;
23
24/// Macro for creating and test SQL.
25#[proc_macro]
26pub fn query(input: TokenStream) -> TokenStream {
27    let input = parse_macro_input!(input as QueryMacroInput);
28
29    let query = input.query.value();
30    static REGEX: LazyLock<Regex> =
31        LazyLock::new(|| Regex::new(r"(?m)(\r\n|\r|\n| ){2,}").unwrap());
32    let query = REGEX.replace_all(query.trim(), " ");
33
34    let parameters: Vec<Parameter> =
35        get_parameters(&query, input.optional_params.unwrap_or_default());
36
37    let struct_name = input.name;
38
39    let main_struct = create_main_struct(&struct_name, &query, &parameters);
40    let test = create_test(&struct_name);
41    let row_struct = if let Some(row_fields) = input.row {
42        create_row_struct(&struct_name, &row_fields)
43    } else {
44        proc_macro2::TokenStream::new()
45    };
46
47    quote! {
48        #main_struct
49        #row_struct
50        #test
51    }
52    .into()
53}
54
55/// Derive `FromRow`.
56#[proc_macro_derive(FromRow)]
57pub fn derive_from_row(input: TokenStream) -> TokenStream {
58    // Parse the input tokens into a syntax tree.
59    let input = parse_macro_input!(input as DeriveInput);
60
61    let name = input.ident;
62
63    // Add required trait bounds depending on type.
64    let generics = add_trait_bounds(
65        input.generics,
66        parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
67    );
68    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
69
70    let Data::Struct(data_struct) = input.data else {
71        panic!("FromRow can only be derived on a struct")
72    };
73
74    let Fields::Named(fields) = data_struct.fields else {
75        panic!("FromRow can only be derived on a struct with named fields")
76    };
77
78    let each_field_from_row = fields.named.iter().filter_map(|f| {
79        let name = f.ident.as_ref()?;
80        let name_lit = name.to_string();
81        let field_type = &f.ty;
82
83        Some(quote_spanned! {f.span()=>
84            let #name: #field_type = row.try_get(#name_lit)?;
85        })
86    });
87
88    let struct_fields = fields.named.iter().map(|f| {
89        let name = &f.ident;
90        quote_spanned! {f.span() => #name}
91    });
92
93    let expanded = quote! {
94        // The generated impl.
95        impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
96            fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Result<Self, ts_sql_helper_lib::postgres::Error> {
97                #( #each_field_from_row )*
98
99                Ok(Self {
100                    #( #struct_fields ),*
101                })
102            }
103        }
104    };
105
106    // Hand the output tokens back to the compiler.
107    TokenStream::from(expanded)
108}
109
110/// Derive `FromSql`
111#[proc_macro_derive(FromSql)]
112pub fn derive_from_sql(input: TokenStream) -> TokenStream {
113    // Parse the input tokens into a syntax tree.
114    let input = parse_macro_input!(input as DeriveInput);
115
116    if !matches!(input.data, Data::Enum(_)) {
117        panic!("FromSql can only be derived on an enum")
118    }
119
120    let name = input.ident;
121
122    let (repr, accepts, from_sql) = {
123        let mut repr_type = parse_quote!(&str);
124        let mut accepts: Vec<Type> = vec![
125            parse_quote!(ts_sql_helper_lib::postgres_types::Type::TEXT),
126            parse_quote!(ts_sql_helper_lib::postgres_types::Type::VARCHAR),
127        ];
128        let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
129            raw
130        )?);
131
132        for attr in input.attrs {
133            if !attr.path().is_ident("repr") {
134                continue;
135            }
136
137            let Ok(arg) = attr.parse_args::<Type>() else {
138                continue;
139            };
140
141            if arg == parse_quote!(i8) {
142                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::CHAR)];
143                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
144                    raw
145                )?);
146            } else if arg == parse_quote!(i16) {
147                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT2)];
148                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
149                    raw
150                )?);
151            } else if arg == parse_quote!(i32) {
152                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT4)];
153                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
154                    raw
155                )?);
156            } else if arg == parse_quote!(i64) {
157                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT8)];
158                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
159                    raw
160                )?);
161            } else {
162                continue;
163            }
164
165            repr_type = arg;
166            break;
167        }
168
169        (repr_type, accepts, from_sql)
170    };
171
172    let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
173    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
174
175    let expanded = quote! {
176        impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
177            fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
178                let raw_value = #from_sql;
179                let value = Self::try_from(raw_value)?;
180                Ok(value)
181            }
182
183            fn accepts(ty: &ts_sql_helper_lib::postgres_types::Type) -> bool {
184                match (*ty) {
185                    #(#accepts)|* => true,
186                    _ => false,
187                }
188            }
189        }
190    };
191
192    TokenStream::from(expanded)
193}
194
195// Add a bound to every type parameter T.
196fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
197    for param in &mut generics.params {
198        if let GenericParam::Type(ref mut type_param) = *param {
199            type_param.bounds.push(bounds.clone());
200        }
201    }
202    generics
203}