ts_sql_helper_derive/
lib.rs

//! Procedural macros for the SQL helper library: the function-like `query!` macro
//! and the `FromRow` / `FromSql` derives.
2//!
3
4use std::sync::LazyLock;
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote, quote_spanned};
8use regex::Regex;
9use syn::{
10    Data, DeriveInput, Fields, GenericParam, Generics, Ident, LitInt, LitStr, Token, Type,
11    TypeParamBound, bracketed,
12    parse::{Parse, ParseStream},
13    parse_macro_input, parse_quote,
14    spanned::Spanned,
15};
16
17struct QueryMacroInput {
18    name: Ident,
19    query: LitStr,
20    optional_params: Vec<usize>,
21}
22impl Parse for QueryMacroInput {
23    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
24        if input.parse::<Ident>()? != Ident::new("name", input.span()) {
25            return Err(input.error("expected `name`"));
26        }
27        input.parse::<Token![:]>()?;
28        let name: Ident = input.parse()?;
29        input.parse::<Token![,]>()?;
30
31        let mut ident = input.parse::<Ident>()?;
32        let optional_params = if ident == Ident::new("optional_params", input.span()) {
33            input.parse::<Token![:]>()?;
34
35            let content;
36            bracketed![content in input];
37            let optional_params: Vec<_> = content
38                .parse_terminated(LitInt::parse, Token![,])?
39                .iter()
40                .map(|v| v.base10_parse().unwrap())
41                .collect();
42
43            input.parse::<Token![,]>()?;
44
45            ident = input.parse::<Ident>()?;
46            optional_params
47        } else {
48            Vec::new()
49        };
50
51        if ident != Ident::new("query", input.span()) {
52            return Err(input.error("expected `query`"));
53        }
54        input.parse::<Token![:]>()?;
55        let query: LitStr = input.parse()?;
56
57        Ok(Self {
58            name,
59            query,
60            optional_params,
61        })
62    }
63}
64
65/// Macro for creating and test SQL.
66#[proc_macro]
67pub fn query(input: TokenStream) -> TokenStream {
68    let input = parse_macro_input!(input as QueryMacroInput);
69
70    pub enum State {
71        Neutral,
72        ConsumingNumber { has_consumed_a_digit: bool },
73        ConsumingTypeSeparator,
74        ConsumingType { type_string: String },
75    }
76
77    let query = input.query.value();
78    static REGEX: LazyLock<Regex> =
79        LazyLock::new(|| Regex::new(r"(?m)(\r\n|\r|\n| ){2,}").unwrap());
80    let query = REGEX.replace_all(query.trim(), " ");
81
82    let mut parameter_types = vec![];
83    let mut state = State::Neutral;
84    for character in query.chars() {
85        match &mut state {
86            State::Neutral => {
87                if character == '$' {
88                    state = State::ConsumingNumber {
89                        has_consumed_a_digit: false,
90                    };
91                }
92            }
93            State::ConsumingNumber {
94                has_consumed_a_digit,
95            } => {
96                if character.is_ascii_digit() {
97                    *has_consumed_a_digit = true;
98                } else if character == ':' {
99                    state = State::ConsumingTypeSeparator;
100                } else {
101                    if *has_consumed_a_digit {
102                        parameter_types.push("unknown".to_string());
103                    }
104                    state = State::Neutral;
105                }
106            }
107            State::ConsumingTypeSeparator => {
108                if character.is_ascii_alphabetic() {
109                    state = State::ConsumingType {
110                        type_string: character.to_string(),
111                    };
112                } else if character != ':' {
113                    parameter_types.push("unknown".to_string());
114                    state = State::Neutral;
115                }
116            }
117            State::ConsumingType { type_string } => {
118                if character.is_ascii_alphabetic() || character == '[' || character == ']' {
119                    type_string.push(character);
120                } else {
121                    parameter_types.push(type_string.to_uppercase());
122                    state = State::Neutral;
123                }
124            }
125        }
126    }
127    match state {
128        State::Neutral => {}
129        State::ConsumingNumber {
130            has_consumed_a_digit,
131        } => {
132            if has_consumed_a_digit {
133                parameter_types.push("unknown".to_string());
134            }
135        }
136        State::ConsumingTypeSeparator => {
137            parameter_types.push("unknown".to_string());
138        }
139        State::ConsumingType { type_string } => {
140            parameter_types.push(type_string.to_uppercase());
141        }
142    }
143
144    let struct_name = input.name;
145    let param_struct_name = format_ident!("{struct_name}Params");
146    let param_count = parameter_types.len();
147
148    const KNOWN_TYPES: [&str; 30] = [
149        "BOOL",
150        "BOOL[]",
151        "BYTEA",
152        "BYTEA[]",
153        "CHAR",
154        "CHAR[]",
155        "INT8",
156        "INT8[]",
157        "INT4",
158        "INT4[]",
159        "INT2",
160        "INT2[]",
161        "FLOAT8",
162        "FLOAT8[]",
163        "FLOAT4",
164        "FLOAT4[]",
165        "UUID",
166        "UUID[]",
167        "TEXT",
168        "VARCHAR",
169        "VARCHAR[]",
170        "TEXT[]",
171        "TIMESTAMP",
172        "TIMESTAMP[]",
173        "TIMESTAMPTZ",
174        "TIMESTAMPTZ[]",
175        "DATE",
176        "DATE[]",
177        "TIME",
178        "TIME[]",
179    ];
180    let param_types: Vec<Type> = parameter_types
181        .iter()
182        .enumerate()
183        .map(|(index, name)| {
184            let param_number = index + 1;
185            let param_type = match name.as_str() {
186                "BOOL" => parse_quote!(&'a bool),
187                "BOOL[]" => parse_quote!(&'a [bool]),
188                "BYTEA" => parse_quote!(&'a [u8]),
189                "BYTEA[]" => parse_quote!(&'a [Vec<u8>]),
190                "CHAR" => parse_quote!(&'a i8),
191                "CHAR[]" => parse_quote!(&'a [i8]),
192                "INT8" => parse_quote!(&'a i64),
193                "INT8[]" => parse_quote!(&'a [i64]),
194                "INT4" => parse_quote!(&'a i32),
195                "INT4[]" => parse_quote!(&'a [i32]),
196                "INT2" => parse_quote!(&'a i16),
197                "INT2[]" => parse_quote!(&'a [i16]),
198                "FLOAT8" => parse_quote!(&'a f64),
199                "FLOAT8[]" => parse_quote!(&'a [f64]),
200                "FLOAT4" => parse_quote!(&'a f32),
201                "FLOAT4[]" => parse_quote!(&'a [f32]),
202                "UUID" => parse_quote!(&'a uuid::Uuid),
203                "UUID[]" => parse_quote!(&'a [uuid::Uuid]),
204                "TEXT" | "VARCHAR" => parse_quote!(&'a str),
205                "VARCHAR[]" | "TEXT[]" => parse_quote!(&'a [String]),
206                "TIMESTAMP" => parse_quote!(&'a ts_sql_helper_lib::SqlDateTime),
207                "TIMESTAMP[]" => parse_quote!(&'a [ts_sql_helper_lib::SqlDateTime]),
208                "TIMESTAMPTZ" => parse_quote!(&'a ts_sql_helper_lib::SqlTimestamp),
209                "TIMESTAMPTZ[]" => parse_quote!(&'a [ts_sql_helper_lib::SqlTimestamp]),
210                "DATE" => parse_quote!(&'a ts_sql_helper_lib::SqlDate),
211                "DATE[]" => parse_quote!(&'a [ts_sql_helper_lib::SqlDate]),
212                "TIME" => parse_quote!(&'a ts_sql_helper_lib::SqlTime),
213                "TIME[]" => parse_quote!(&'a [ts_sql_helper_lib::SqlTime]),
214
215                _ => parse_quote!(&'a (dyn ts_sql_helper_lib::postgres::types::ToSql + Sync)),
216            };
217            if input.optional_params.contains(&param_number) {
218                parse_quote!(Option<#param_type>)
219            } else {
220                param_type
221            }
222        })
223        .collect();
224    let param_names: Vec<Ident> = (1..param_count + 1)
225        .map(|number| format_ident!("p{number}"))
226        .collect();
227
228    let params: Vec<_> = param_types
229        .iter()
230        .enumerate()
231        .map(|(index, field_type)| {
232            let name = &param_names[index];
233            quote! {
234                #name: #field_type
235            }
236        })
237        .collect();
238
239    let pub_params = params.iter().map(|param| quote! {pub #param});
240    let self_params = param_names.iter().enumerate().map(|(index, param)| {
241        let type_string = &parameter_types[index];
242        if KNOWN_TYPES.contains(&type_string.as_str()) {
243            quote!(&self.#param)
244        } else {
245            quote!(self.#param)
246        }
247    });
248
249    let test_name = format_ident!("test_{struct_name}");
250    let test = quote! {
251        #[cfg(test)]
252        #[allow(non_snake_case)]
253        #[test]
254        fn #test_name() {
255            use ts_sql_helper_lib::test::get_test_database;
256
257            let (mut client, _container) = get_test_database();
258            let statement = client.prepare(#struct_name::QUERY);
259            assert!(statement.is_ok(), "invalid query `{}`: {}", #struct_name::QUERY, statement.unwrap_err());
260            let statement = statement.unwrap();
261
262            let mut data: Vec<Box<dyn ts_sql_helper_lib::postgres_types::ToSql + Sync>> = Vec::new();
263            let params = statement.params();
264            for param in params.iter() {
265                match ts_sql_helper_lib::test::data_for_type(param) {
266                    Some(param_data) => data.push(param_data),
267                    None => panic!("unsupported parameter type `{}`", param.name()),
268                }
269            }
270
271            let borrowed_data: Vec<&(dyn ts_sql_helper_lib::postgres_types::ToSql + Sync)> =
272                data.iter().map(|data| data.as_ref()).collect();
273
274            let result = client.execute(&statement, borrowed_data.as_slice());
275            if let Err(error) = result {
276                use ts_sql_helper_lib::postgres::error::SqlState;
277
278                assert!(
279                    matches!(
280                        error.code(),
281                        Some(&SqlState::FOREIGN_KEY_VIOLATION) | Some(&SqlState::CHECK_VIOLATION)
282                    ),
283                    "invalid query `{}`: {error}",
284                    #struct_name::QUERY
285                );
286            }
287        }
288    };
289    quote! {
290        struct #struct_name;
291        impl #struct_name {
292            pub const QUERY: &str = #query;
293            pub fn params<'a>(#( #params ),*) -> #param_struct_name<'a> {
294                #param_struct_name {
295                    #( #param_names , )*
296                    phantom_data: core::marker::PhantomData,
297                }
298            }
299        }
300        struct #param_struct_name<'a> {
301            #( #pub_params , )*
302            pub phantom_data: core::marker::PhantomData<&'a ()>,
303        }
304        impl<'a>  #param_struct_name<'a> {
305            pub fn as_array(&'a self) -> [&'a (dyn ts_sql_helper_lib::postgres::types::ToSql + Sync); #param_count] {
306                [
307                    #( #self_params , )*
308                ]
309            }
310        }
311        #test
312    }
313    .into()
314}
315
316/// Derive `FromRow`.
317#[proc_macro_derive(FromRow)]
318pub fn derive_from_row(input: TokenStream) -> TokenStream {
319    // Parse the input tokens into a syntax tree.
320    let input = parse_macro_input!(input as DeriveInput);
321
322    let name = input.ident;
323
324    // Add required trait bounds depending on type.
325    let generics = add_trait_bounds(
326        input.generics,
327        parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
328    );
329    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
330
331    let Data::Struct(data_struct) = input.data else {
332        panic!("FromRow can only be derived on a struct")
333    };
334
335    let Fields::Named(fields) = data_struct.fields else {
336        panic!("FromRow can only be derived on a struct with named fields")
337    };
338
339    let each_field_from_row = fields.named.iter().filter_map(|f| {
340        let name = f.ident.as_ref()?;
341        let name_lit = name.to_string();
342        let field_type = &f.ty;
343
344        Some(quote_spanned! {f.span()=>
345            let #name: #field_type = row.try_get(#name_lit)?;
346        })
347    });
348
349    let struct_fields = fields.named.iter().map(|f| {
350        let name = &f.ident;
351        quote_spanned! {f.span() => #name}
352    });
353
354    let expanded = quote! {
355        // The generated impl.
356        impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
357            fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Result<Self, ts_sql_helper_lib::postgres::Error> {
358                #( #each_field_from_row )*
359
360                Ok(Self {
361                    #( #struct_fields ),*
362                })
363            }
364        }
365    };
366
367    // Hand the output tokens back to the compiler.
368    TokenStream::from(expanded)
369}
370
371/// Derive `FromSql`
372#[proc_macro_derive(FromSql)]
373pub fn derive_from_sql(input: TokenStream) -> TokenStream {
374    // Parse the input tokens into a syntax tree.
375    let input = parse_macro_input!(input as DeriveInput);
376
377    if !matches!(input.data, Data::Enum(_)) {
378        panic!("FromSql can only be derived on an enum")
379    }
380
381    let name = input.ident;
382
383    let (repr, accepts, from_sql) = {
384        let mut repr_type = parse_quote!(&str);
385        let mut accepts: Vec<Type> = vec![
386            parse_quote!(ts_sql_helper_lib::postgres_types::Type::TEXT),
387            parse_quote!(ts_sql_helper_lib::postgres_types::Type::VARCHAR),
388        ];
389        let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
390            raw
391        )?);
392
393        for attr in input.attrs {
394            if !attr.path().is_ident("repr") {
395                continue;
396            }
397
398            let Ok(arg) = attr.parse_args::<Type>() else {
399                continue;
400            };
401
402            if arg == parse_quote!(i8) {
403                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::CHAR)];
404                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
405                    raw
406                )?);
407            } else if arg == parse_quote!(i16) {
408                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT2)];
409                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
410                    raw
411                )?);
412            } else if arg == parse_quote!(i32) {
413                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT4)];
414                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
415                    raw
416                )?);
417            } else if arg == parse_quote!(i64) {
418                accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT8)];
419                from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
420                    raw
421                )?);
422            } else {
423                continue;
424            }
425
426            repr_type = arg;
427            break;
428        }
429
430        (repr_type, accepts, from_sql)
431    };
432
433    let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
434    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
435
436    let expanded = quote! {
437        impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
438            fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
439                let raw_value = #from_sql;
440                let value = Self::try_from(raw_value)?;
441                Ok(value)
442            }
443
444            fn accepts(ty: &ts_sql_helper_lib::postgres_types::Type) -> bool {
445                match (*ty) {
446                    #(#accepts)|* => true,
447                    _ => false,
448                }
449            }
450        }
451    };
452
453    TokenStream::from(expanded)
454}
455
456// Add a bound to every type parameter T.
457fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
458    for param in &mut generics.params {
459        if let GenericParam::Type(ref mut type_param) = *param {
460            type_param.bounds.push(bounds.clone());
461        }
462    }
463    generics
464}