ts_sql_helper_derive/
lib.rs1use std::sync::LazyLock;
5
6use proc_macro::TokenStream;
7use quote::{quote, quote_spanned};
8use regex::Regex;
9use syn::{
10 Data, DeriveInput, Fields, GenericParam, Generics, Type, TypeParamBound, parse_macro_input,
11 parse_quote, spanned::Spanned,
12};
13
14use crate::query::{
15 QueryMacroInput,
16 main_struct::create_main_struct,
17 parameters::{get_param_types, parameter_to_type},
18 row_struct::create_row_struct,
19 test::create_test,
20};
21
22mod query;
23
24#[proc_macro]
26pub fn query(input: TokenStream) -> TokenStream {
27 let input = parse_macro_input!(input as QueryMacroInput);
28
29 let query = input.query.value();
30 static REGEX: LazyLock<Regex> =
31 LazyLock::new(|| Regex::new(r"(?m)(\r\n|\r|\n| ){2,}").unwrap());
32 let query = REGEX.replace_all(query.trim(), " ");
33
34 let parameters: Vec<Type> = get_param_types(&query)
35 .into_iter()
36 .enumerate()
37 .map(|(index, parameter)| {
38 let r#type = parameter_to_type(¶meter);
39 if input
40 .optional_params
41 .as_ref()
42 .is_some_and(|params| params.contains(&(index + 1)))
43 {
44 parse_quote!(Option<#r#type>)
45 } else {
46 r#type
47 }
48 })
49 .collect();
50
51 let struct_name = input.name;
52
53 let main_struct = create_main_struct(&struct_name, &query, ¶meters);
54 let test = create_test(&struct_name);
55 let row_struct = if let Some(row_fields) = input.row {
56 create_row_struct(&struct_name, &row_fields)
57 } else {
58 proc_macro2::TokenStream::new()
59 };
60
61 quote! {
62 #main_struct
63 #row_struct
64 #test
65 }
66 .into()
67}
68
69#[proc_macro_derive(FromRow)]
71pub fn derive_from_row(input: TokenStream) -> TokenStream {
72 let input = parse_macro_input!(input as DeriveInput);
74
75 let name = input.ident;
76
77 let generics = add_trait_bounds(
79 input.generics,
80 parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
81 );
82 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
83
84 let Data::Struct(data_struct) = input.data else {
85 panic!("FromRow can only be derived on a struct")
86 };
87
88 let Fields::Named(fields) = data_struct.fields else {
89 panic!("FromRow can only be derived on a struct with named fields")
90 };
91
92 let each_field_from_row = fields.named.iter().filter_map(|f| {
93 let name = f.ident.as_ref()?;
94 let name_lit = name.to_string();
95 let field_type = &f.ty;
96
97 Some(quote_spanned! {f.span()=>
98 let #name: #field_type = row.try_get(#name_lit)?;
99 })
100 });
101
102 let struct_fields = fields.named.iter().map(|f| {
103 let name = &f.ident;
104 quote_spanned! {f.span() => #name}
105 });
106
107 let expanded = quote! {
108 impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
110 fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Result<Self, ts_sql_helper_lib::postgres::Error> {
111 #( #each_field_from_row )*
112
113 Ok(Self {
114 #( #struct_fields ),*
115 })
116 }
117 }
118 };
119
120 TokenStream::from(expanded)
122}
123
124#[proc_macro_derive(FromSql)]
126pub fn derive_from_sql(input: TokenStream) -> TokenStream {
127 let input = parse_macro_input!(input as DeriveInput);
129
130 if !matches!(input.data, Data::Enum(_)) {
131 panic!("FromSql can only be derived on an enum")
132 }
133
134 let name = input.ident;
135
136 let (repr, accepts, from_sql) = {
137 let mut repr_type = parse_quote!(&str);
138 let mut accepts: Vec<Type> = vec![
139 parse_quote!(ts_sql_helper_lib::postgres_types::Type::TEXT),
140 parse_quote!(ts_sql_helper_lib::postgres_types::Type::VARCHAR),
141 ];
142 let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
143 raw
144 )?);
145
146 for attr in input.attrs {
147 if !attr.path().is_ident("repr") {
148 continue;
149 }
150
151 let Ok(arg) = attr.parse_args::<Type>() else {
152 continue;
153 };
154
155 if arg == parse_quote!(i8) {
156 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::CHAR)];
157 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
158 raw
159 )?);
160 } else if arg == parse_quote!(i16) {
161 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT2)];
162 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
163 raw
164 )?);
165 } else if arg == parse_quote!(i32) {
166 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT4)];
167 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
168 raw
169 )?);
170 } else if arg == parse_quote!(i64) {
171 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT8)];
172 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
173 raw
174 )?);
175 } else {
176 continue;
177 }
178
179 repr_type = arg;
180 break;
181 }
182
183 (repr_type, accepts, from_sql)
184 };
185
186 let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
187 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
188
189 let expanded = quote! {
190 impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
191 fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
192 let raw_value = #from_sql;
193 let value = Self::try_from(raw_value)?;
194 Ok(value)
195 }
196
197 fn accepts(ty: &ts_sql_helper_lib::postgres_types::Type) -> bool {
198 match (*ty) {
199 #(#accepts)|* => true,
200 _ => false,
201 }
202 }
203 }
204 };
205
206 TokenStream::from(expanded)
207}
208
209fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
211 for param in &mut generics.params {
212 if let GenericParam::Type(ref mut type_param) = *param {
213 type_param.bounds.push(bounds.clone());
214 }
215 }
216 generics
217}