ts_sql_helper_derive/
lib.rs1use std::sync::LazyLock;
5
6use proc_macro::TokenStream;
7use quote::{quote, quote_spanned};
8use regex::Regex;
9use syn::{
10 Data, DeriveInput, Fields, GenericParam, Generics, Type, TypeParamBound, parse_macro_input,
11 parse_quote, spanned::Spanned,
12};
13
14use crate::query::{
15 QueryMacroInput,
16 main_struct::create_main_struct,
17 parameters::{Parameter, get_parameters},
18 row_struct::create_row_struct,
19 test::create_test,
20};
21
22mod query;
23
24#[proc_macro]
26pub fn query(input: TokenStream) -> TokenStream {
27 let input = parse_macro_input!(input as QueryMacroInput);
28
29 let query = input.query.value();
30 static REGEX: LazyLock<Regex> =
31 LazyLock::new(|| Regex::new(r"(?m)(\r\n|\r|\n| ){2,}").unwrap());
32 let query = REGEX.replace_all(query.trim(), " ");
33
34 let parameters: Vec<Parameter> =
35 get_parameters(&query, input.optional_params.unwrap_or_default());
36
37 let struct_name = input.name;
38
39 let main_struct = create_main_struct(&struct_name, &query, ¶meters);
40 let test = create_test(&struct_name);
41 let row_struct = if let Some(row_fields) = input.row {
42 create_row_struct(&struct_name, &row_fields)
43 } else {
44 proc_macro2::TokenStream::new()
45 };
46
47 quote! {
48 #main_struct
49 #row_struct
50 #test
51 }
52 .into()
53}
54
55#[proc_macro_derive(FromRow)]
57pub fn derive_from_row(input: TokenStream) -> TokenStream {
58 let input = parse_macro_input!(input as DeriveInput);
60
61 let name = input.ident;
62
63 let generics = add_trait_bounds(
65 input.generics,
66 parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
67 );
68 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
69
70 let Data::Struct(data_struct) = input.data else {
71 panic!("FromRow can only be derived on a struct")
72 };
73
74 let Fields::Named(fields) = data_struct.fields else {
75 panic!("FromRow can only be derived on a struct with named fields")
76 };
77
78 let each_field_from_row = fields.named.iter().filter_map(|f| {
79 let name = f.ident.as_ref()?;
80 let name_lit = name.to_string();
81 let field_type = &f.ty;
82
83 Some(quote_spanned! {f.span()=>
84 let #name: #field_type = row.try_get(#name_lit)?;
85 })
86 });
87
88 let struct_fields = fields.named.iter().map(|f| {
89 let name = &f.ident;
90 quote_spanned! {f.span() => #name}
91 });
92
93 let expanded = quote! {
94 impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
96 fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Result<Self, ts_sql_helper_lib::postgres::Error> {
97 #( #each_field_from_row )*
98
99 Ok(Self {
100 #( #struct_fields ),*
101 })
102 }
103 }
104 };
105
106 TokenStream::from(expanded)
108}
109
110#[proc_macro_derive(FromSql)]
112pub fn derive_from_sql(input: TokenStream) -> TokenStream {
113 let input = parse_macro_input!(input as DeriveInput);
115
116 if !matches!(input.data, Data::Enum(_)) {
117 panic!("FromSql can only be derived on an enum")
118 }
119
120 let name = input.ident;
121
122 let (repr, accepts, from_sql) = {
123 let mut repr_type = parse_quote!(&str);
124 let mut accepts: Vec<Type> = vec![
125 parse_quote!(ts_sql_helper_lib::postgres_types::Type::TEXT),
126 parse_quote!(ts_sql_helper_lib::postgres_types::Type::VARCHAR),
127 ];
128 let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
129 raw
130 )?);
131
132 for attr in input.attrs {
133 if !attr.path().is_ident("repr") {
134 continue;
135 }
136
137 let Ok(arg) = attr.parse_args::<Type>() else {
138 continue;
139 };
140
141 if arg == parse_quote!(i8) {
142 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::CHAR)];
143 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
144 raw
145 )?);
146 } else if arg == parse_quote!(i16) {
147 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT2)];
148 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
149 raw
150 )?);
151 } else if arg == parse_quote!(i32) {
152 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT4)];
153 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
154 raw
155 )?);
156 } else if arg == parse_quote!(i64) {
157 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT8)];
158 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
159 raw
160 )?);
161 } else {
162 continue;
163 }
164
165 repr_type = arg;
166 break;
167 }
168
169 (repr_type, accepts, from_sql)
170 };
171
172 let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
173 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
174
175 let expanded = quote! {
176 impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
177 fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
178 let raw_value = #from_sql;
179 let value = Self::try_from(raw_value)?;
180 Ok(value)
181 }
182
183 fn accepts(ty: &ts_sql_helper_lib::postgres_types::Type) -> bool {
184 match (*ty) {
185 #(#accepts)|* => true,
186 _ => false,
187 }
188 }
189 }
190 };
191
192 TokenStream::from(expanded)
193}
194
195fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
197 for param in &mut generics.params {
198 if let GenericParam::Type(ref mut type_param) = *param {
199 type_param.bounds.push(bounds.clone());
200 }
201 }
202 generics
203}