1use std::sync::LazyLock;
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote, quote_spanned};
8use regex::Regex;
9use syn::{
10 Data, DeriveInput, Fields, GenericParam, Generics, Ident, LitInt, LitStr, Token, Type,
11 TypeParamBound, bracketed,
12 parse::{Parse, ParseStream},
13 parse_macro_input, parse_quote,
14 spanned::Spanned,
15};
16
17struct QueryMacroInput {
18 name: Ident,
19 query: LitStr,
20 optional_params: Vec<usize>,
21}
22impl Parse for QueryMacroInput {
23 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
24 if input.parse::<Ident>()? != Ident::new("name", input.span()) {
25 return Err(input.error("expected `name`"));
26 }
27 input.parse::<Token![:]>()?;
28 let name: Ident = input.parse()?;
29 input.parse::<Token![,]>()?;
30
31 let mut ident = input.parse::<Ident>()?;
32 let optional_params = if ident == Ident::new("optional_params", input.span()) {
33 input.parse::<Token![:]>()?;
34
35 let content;
36 bracketed![content in input];
37 let optional_params: Vec<_> = content
38 .parse_terminated(LitInt::parse, Token![,])?
39 .iter()
40 .map(|v| v.base10_parse().unwrap())
41 .collect();
42
43 input.parse::<Token![,]>()?;
44
45 ident = input.parse::<Ident>()?;
46 optional_params
47 } else {
48 Vec::new()
49 };
50
51 if ident != Ident::new("query", input.span()) {
52 return Err(input.error("expected `query`"));
53 }
54 input.parse::<Token![:]>()?;
55 let query: LitStr = input.parse()?;
56
57 Ok(Self {
58 name,
59 query,
60 optional_params,
61 })
62 }
63}
64
65#[proc_macro]
67pub fn query(input: TokenStream) -> TokenStream {
68 let input = parse_macro_input!(input as QueryMacroInput);
69
70 pub enum State {
71 Neutral,
72 ConsumingNumber { has_consumed_a_digit: bool },
73 ConsumingTypeSeparator,
74 ConsumingType { type_string: String },
75 }
76
77 let query = input.query.value();
78 static REGEX: LazyLock<Regex> =
79 LazyLock::new(|| Regex::new(r"(?m)(\r\n|\r|\n| ){2,}").unwrap());
80 let query = REGEX.replace_all(query.trim(), " ");
81
82 let mut parameter_types = vec![];
83 let mut state = State::Neutral;
84 for character in query.chars() {
85 match &mut state {
86 State::Neutral => {
87 if character == '$' {
88 state = State::ConsumingNumber {
89 has_consumed_a_digit: false,
90 };
91 }
92 }
93 State::ConsumingNumber {
94 has_consumed_a_digit,
95 } => {
96 if character.is_ascii_digit() {
97 *has_consumed_a_digit = true;
98 } else if character == ':' {
99 state = State::ConsumingTypeSeparator;
100 } else {
101 if *has_consumed_a_digit {
102 parameter_types.push("unknown".to_string());
103 }
104 state = State::Neutral;
105 }
106 }
107 State::ConsumingTypeSeparator => {
108 if character.is_ascii_alphabetic() {
109 state = State::ConsumingType {
110 type_string: character.to_string(),
111 };
112 } else if character != ':' {
113 parameter_types.push("unknown".to_string());
114 state = State::Neutral;
115 }
116 }
117 State::ConsumingType { type_string } => {
118 if character.is_ascii_alphabetic() || character == '[' || character == ']' {
119 type_string.push(character);
120 } else {
121 parameter_types.push(type_string.to_uppercase());
122 state = State::Neutral;
123 }
124 }
125 }
126 }
127 match state {
128 State::Neutral => {}
129 State::ConsumingNumber {
130 has_consumed_a_digit,
131 } => {
132 if has_consumed_a_digit {
133 parameter_types.push("unknown".to_string());
134 }
135 }
136 State::ConsumingTypeSeparator => {
137 parameter_types.push("unknown".to_string());
138 }
139 State::ConsumingType { type_string } => {
140 parameter_types.push(type_string.to_uppercase());
141 }
142 }
143
144 let struct_name = input.name;
145 let param_struct_name = format_ident!("{struct_name}Params");
146 let param_count = parameter_types.len();
147
148 const KNOWN_TYPES: [&str; 30] = [
149 "BOOL",
150 "BOOL[]",
151 "BYTEA",
152 "BYTEA[]",
153 "CHAR",
154 "CHAR[]",
155 "INT8",
156 "INT8[]",
157 "INT4",
158 "INT4[]",
159 "INT2",
160 "INT2[]",
161 "FLOAT8",
162 "FLOAT8[]",
163 "FLOAT4",
164 "FLOAT4[]",
165 "UUID",
166 "UUID[]",
167 "TEXT",
168 "VARCHAR",
169 "VARCHAR[]",
170 "TEXT[]",
171 "TIMESTAMP",
172 "TIMESTAMP[]",
173 "TIMESTAMPTZ",
174 "TIMESTAMPTZ[]",
175 "DATE",
176 "DATE[]",
177 "TIME",
178 "TIME[]",
179 ];
180 let param_types: Vec<Type> = parameter_types
181 .iter()
182 .enumerate()
183 .map(|(index, name)| {
184 let param_number = index + 1;
185 let param_type = match name.as_str() {
186 "BOOL" => parse_quote!(&'a bool),
187 "BOOL[]" => parse_quote!(&'a [bool]),
188 "BYTEA" => parse_quote!(&'a [u8]),
189 "BYTEA[]" => parse_quote!(&'a [Vec<u8>]),
190 "CHAR" => parse_quote!(&'a i8),
191 "CHAR[]" => parse_quote!(&'a [i8]),
192 "INT8" => parse_quote!(&'a i64),
193 "INT8[]" => parse_quote!(&'a [i64]),
194 "INT4" => parse_quote!(&'a i32),
195 "INT4[]" => parse_quote!(&'a [i32]),
196 "INT2" => parse_quote!(&'a i16),
197 "INT2[]" => parse_quote!(&'a [i16]),
198 "FLOAT8" => parse_quote!(&'a f64),
199 "FLOAT8[]" => parse_quote!(&'a [f64]),
200 "FLOAT4" => parse_quote!(&'a f32),
201 "FLOAT4[]" => parse_quote!(&'a [f32]),
202 "UUID" => parse_quote!(&'a uuid::Uuid),
203 "UUID[]" => parse_quote!(&'a [uuid::Uuid]),
204 "TEXT" | "VARCHAR" => parse_quote!(&'a str),
205 "VARCHAR[]" | "TEXT[]" => parse_quote!(&'a [String]),
206 "TIMESTAMP" => parse_quote!(&'a ts_sql_helper_lib::SqlDateTime),
207 "TIMESTAMP[]" => parse_quote!(&'a [ts_sql_helper_lib::SqlDateTime]),
208 "TIMESTAMPTZ" => parse_quote!(&'a ts_sql_helper_lib::SqlTimestamp),
209 "TIMESTAMPTZ[]" => parse_quote!(&'a [ts_sql_helper_lib::SqlTimestamp]),
210 "DATE" => parse_quote!(&'a ts_sql_helper_lib::SqlDate),
211 "DATE[]" => parse_quote!(&'a [ts_sql_helper_lib::SqlDate]),
212 "TIME" => parse_quote!(&'a ts_sql_helper_lib::SqlTime),
213 "TIME[]" => parse_quote!(&'a [ts_sql_helper_lib::SqlTime]),
214
215 _ => parse_quote!(&'a (dyn ts_sql_helper_lib::postgres::types::ToSql + Sync)),
216 };
217 if input.optional_params.contains(¶m_number) {
218 parse_quote!(Option<#param_type>)
219 } else {
220 param_type
221 }
222 })
223 .collect();
224 let param_names: Vec<Ident> = (1..param_count + 1)
225 .map(|number| format_ident!("p{number}"))
226 .collect();
227
228 let params: Vec<_> = param_types
229 .iter()
230 .enumerate()
231 .map(|(index, field_type)| {
232 let name = ¶m_names[index];
233 quote! {
234 #name: #field_type
235 }
236 })
237 .collect();
238
239 let pub_params = params.iter().map(|param| quote! {pub #param});
240 let self_params = param_names.iter().enumerate().map(|(index, param)| {
241 let type_string = ¶meter_types[index];
242 if KNOWN_TYPES.contains(&type_string.as_str()) {
243 quote!(&self.#param)
244 } else {
245 quote!(self.#param)
246 }
247 });
248
249 let test_name = format_ident!("test_{struct_name}");
250 let test = quote! {
251 #[cfg(test)]
252 #[allow(non_snake_case)]
253 #[test]
254 fn #test_name() {
255 use ts_sql_helper_lib::test::get_test_database;
256
257 let (mut client, _container) = get_test_database();
258 let statement = client.prepare(#struct_name::QUERY);
259 assert!(statement.is_ok(), "invalid query `{}`: {}", #struct_name::QUERY, statement.unwrap_err());
260 let statement = statement.unwrap();
261
262 let mut data: Vec<Box<dyn ts_sql_helper_lib::postgres_types::ToSql + Sync>> = Vec::new();
263 let params = statement.params();
264 for param in params.iter() {
265 match ts_sql_helper_lib::test::data_for_type(param) {
266 Some(param_data) => data.push(param_data),
267 None => panic!("unsupported parameter type `{}`", param.name()),
268 }
269 }
270
271 let borrowed_data: Vec<&(dyn ts_sql_helper_lib::postgres_types::ToSql + Sync)> =
272 data.iter().map(|data| data.as_ref()).collect();
273
274 let result = client.execute(&statement, borrowed_data.as_slice());
275 if let Err(error) = result {
276 use ts_sql_helper_lib::postgres::error::SqlState;
277
278 assert!(
279 matches!(
280 error.code(),
281 Some(&SqlState::FOREIGN_KEY_VIOLATION) | Some(&SqlState::CHECK_VIOLATION)
282 ),
283 "invalid query `{}`: {error}",
284 #struct_name::QUERY
285 );
286 }
287 }
288 };
289 quote! {
290 struct #struct_name;
291 impl #struct_name {
292 pub const QUERY: &str = #query;
293 pub fn params<'a>(#( #params ),*) -> #param_struct_name<'a> {
294 #param_struct_name {
295 #( #param_names , )*
296 phantom_data: core::marker::PhantomData,
297 }
298 }
299 }
300 struct #param_struct_name<'a> {
301 #( #pub_params , )*
302 pub phantom_data: core::marker::PhantomData<&'a ()>,
303 }
304 impl<'a> #param_struct_name<'a> {
305 pub fn as_array(&'a self) -> [&'a (dyn ts_sql_helper_lib::postgres::types::ToSql + Sync); #param_count] {
306 [
307 #( #self_params , )*
308 ]
309 }
310 }
311 #test
312 }
313 .into()
314}
315
316#[proc_macro_derive(FromRow)]
318pub fn derive_from_row(input: TokenStream) -> TokenStream {
319 let input = parse_macro_input!(input as DeriveInput);
321
322 let name = input.ident;
323
324 let generics = add_trait_bounds(
326 input.generics,
327 parse_quote!(ts_sql_helper_lib::postgres::types::FromSql),
328 );
329 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
330
331 let Data::Struct(data_struct) = input.data else {
332 panic!("FromRow can only be derived on a struct")
333 };
334
335 let Fields::Named(fields) = data_struct.fields else {
336 panic!("FromRow can only be derived on a struct with named fields")
337 };
338
339 let each_field_from_row = fields.named.iter().filter_map(|f| {
340 let name = f.ident.as_ref()?;
341 let name_lit = name.to_string();
342 let field_type = &f.ty;
343
344 Some(quote_spanned! {f.span()=>
345 let #name: #field_type = row.try_get(#name_lit)?;
346 })
347 });
348
349 let struct_fields = fields.named.iter().map(|f| {
350 let name = &f.ident;
351 quote_spanned! {f.span() => #name}
352 });
353
354 let expanded = quote! {
355 impl #impl_generics ts_sql_helper_lib::FromRow for #name #ty_generics #where_clause {
357 fn from_row(row: &ts_sql_helper_lib::postgres::Row) -> Result<Self, ts_sql_helper_lib::postgres::Error> {
358 #( #each_field_from_row )*
359
360 Ok(Self {
361 #( #struct_fields ),*
362 })
363 }
364 }
365 };
366
367 TokenStream::from(expanded)
369}
370
371#[proc_macro_derive(FromSql)]
373pub fn derive_from_sql(input: TokenStream) -> TokenStream {
374 let input = parse_macro_input!(input as DeriveInput);
376
377 if !matches!(input.data, Data::Enum(_)) {
378 panic!("FromSql can only be derived on an enum")
379 }
380
381 let name = input.ident;
382
383 let (repr, accepts, from_sql) = {
384 let mut repr_type = parse_quote!(&str);
385 let mut accepts: Vec<Type> = vec![
386 parse_quote!(ts_sql_helper_lib::postgres_types::Type::TEXT),
387 parse_quote!(ts_sql_helper_lib::postgres_types::Type::VARCHAR),
388 ];
389 let mut from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::text_from_sql(
390 raw
391 )?);
392
393 for attr in input.attrs {
394 if !attr.path().is_ident("repr") {
395 continue;
396 }
397
398 let Ok(arg) = attr.parse_args::<Type>() else {
399 continue;
400 };
401
402 if arg == parse_quote!(i8) {
403 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::CHAR)];
404 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::char_from_sql(
405 raw
406 )?);
407 } else if arg == parse_quote!(i16) {
408 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT2)];
409 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int2_from_sql(
410 raw
411 )?);
412 } else if arg == parse_quote!(i32) {
413 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT4)];
414 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int4_from_sql(
415 raw
416 )?);
417 } else if arg == parse_quote!(i64) {
418 accepts = vec![parse_quote!(ts_sql_helper_lib::postgres_types::Type::INT8)];
419 from_sql = quote!(ts_sql_helper_lib::postgres_protocol::types::int8_from_sql(
420 raw
421 )?);
422 } else {
423 continue;
424 }
425
426 repr_type = arg;
427 break;
428 }
429
430 (repr_type, accepts, from_sql)
431 };
432
433 let generics = add_trait_bounds(input.generics, parse_quote!(TryFrom<#repr>));
434 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
435
436 let expanded = quote! {
437 impl<'a> #impl_generics ts_sql_helper_lib::postgres::types::FromSql<'a> for #name #ty_generics #where_clause {
438 fn from_sql(_: &ts_sql_helper_lib::postgres::types::Type, raw: &[u8]) -> Result<Self, Box<dyn core::error::Error + Sync + Send>> {
439 let raw_value = #from_sql;
440 let value = Self::try_from(raw_value)?;
441 Ok(value)
442 }
443
444 fn accepts(ty: &ts_sql_helper_lib::postgres_types::Type) -> bool {
445 match (*ty) {
446 #(#accepts)|* => true,
447 _ => false,
448 }
449 }
450 }
451 };
452
453 TokenStream::from(expanded)
454}
455
456fn add_trait_bounds(mut generics: Generics, bounds: TypeParamBound) -> Generics {
458 for param in &mut generics.params {
459 if let GenericParam::Type(ref mut type_param) = *param {
460 type_param.bounds.push(bounds.clone());
461 }
462 }
463 generics
464}