1use proc_macro::TokenStream;
2use syn::{TraitItemFn, parse_macro_input};
3
4mod alias_system;
5mod code_generator;
6mod constants;
7mod dml;
8mod error;
9mod fetch;
10mod method_variants;
11mod repo_system;
12mod scope_system;
13mod type_analyzer;
14mod type_system;
15
16#[cfg(test)]
17mod test_framework;
18
19use code_generator::CodeGenerator;
20use dml::DmlParser;
21
22#[proc_macro_attribute]
79pub fn dml(args: TokenStream, input: TokenStream) -> TokenStream {
80 let method = parse_macro_input!(input as TraitItemFn);
81
82 let parsed_method = match DmlParser::parse_dml_method_with_args(method, args, false) {
84 Ok(method) => method,
85 Err(error) => return error.to_compile_error().into(),
86 };
87
88 match CodeGenerator::generate_dml_methods(&parsed_method) {
90 Ok(tokens) => tokens.into(),
91 Err(error) => error.to_compile_error().into(),
92 }
93}
94
95#[proc_macro_attribute]
97pub fn repo(args: TokenStream, input: TokenStream) -> TokenStream {
98 let input_trait = parse_macro_input!(input as syn::ItemTrait);
99
100 match repo_system::RepoProcessor::process_trait_with_args(input_trait, args) {
101 Ok(tokens) => tokens.into(),
102 Err(error) => error.to_compile_error().into(),
103 }
104}
105
106#[proc_macro_attribute]
145pub fn generate_versions(args: TokenStream, input: TokenStream) -> TokenStream {
146 let input_method = parse_macro_input!(input as TraitItemFn);
147 let args_tokens = proc_macro2::TokenStream::from(args);
148
149 match method_variants::expand_method_variants(input_method, args_tokens) {
150 Ok(tokens) => TokenStream::from(tokens),
151 Err(error) => error.to_compile_error().into(),
152 }
153}
154
#[cfg(test)]
mod tests {
    use crate::code_generator::CodeGenerator;
    use crate::dml::{DmlMethod, DmlParameter};
    use syn::parse_quote;

    /// Builds a plain bound parameter: not a pool, not dynamic, not generic.
    fn param(name: &str, type_: syn::Type) -> DmlParameter {
        DmlParameter {
            name: name.to_string(),
            type_,
            is_pool: false,
            is_dynamic_param: false,
            is_generic: false,
        }
    }

    /// Runs the DML code generator on `method` and returns the expansion rendered as a string.
    ///
    /// Uses `expect` rather than `assert!(result.is_ok())` so a failing test reports the
    /// generator's error instead of a bare assertion failure.
    ///
    /// NOTE(review): the rendering is token-spaced (the documentation test below matches
    /// `sqlx :: query_as !` and `# [doc = `), so unspaced needles such as `sqlx::query_as!`
    /// presumably only match text inside string literals (e.g. generated doc examples).
    fn generate(method: &DmlMethod) -> String {
        CodeGenerator::generate_dml_methods(method)
            .expect("DML code generation should succeed")
            .to_string()
    }

    /// Builds a `DmlMethod` fixture equivalent to applying `#[dml]` to
    /// `[async] fn <method_name>(&self, <params>) -> <return_type>;`.
    ///
    /// The method is declared `async` unless the return type is an `impl` type with a
    /// bound whose last path segment is `Stream`. The statement kind is always `Select`
    /// and every flag is `false`.
    ///
    /// # Panics
    /// Panics if `sql` cannot be parsed by `sqlx_data_parser::parse_sql`.
    fn create_test_dml_method(
        method_name: &str,
        sql: &str,
        parameters: Vec<DmlParameter>,
        return_type: syn::Type,
    ) -> DmlMethod {
        let span = proc_macro2::Span::call_site();

        // Stream-returning methods (`impl ... Stream<...>`) are declared without `async`.
        let is_stream_type = matches!(
            &return_type,
            syn::Type::ImplTrait(impl_trait)
                if impl_trait.bounds.iter().any(|bound| matches!(
                    bound,
                    syn::TypeParamBound::Trait(trait_bound)
                        if trait_bound
                            .path
                            .segments
                            .last()
                            .is_some_and(|seg| seg.ident == "Stream")
                ))
        );
        let asyncness: Option<syn::Token![async]> = (!is_stream_type).then(Default::default);

        // Parse `&self` instead of assembling a `Receiver` by hand, so the fixture matches
        // what syn produces for real input (including `ty: &Self`).
        let mut inputs: syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]> =
            syn::punctuated::Punctuated::new();
        inputs.push(parse_quote! { &self });
        for param in &parameters {
            let name = syn::Ident::new(&param.name, span);
            let ty = &param.type_;
            inputs.push(parse_quote! { #name: #ty });
        }

        let ident = syn::Ident::new(method_name, span);
        let method: syn::TraitItemFn = parse_quote! {
            #asyncness fn #ident(#inputs) -> #return_type;
        };

        DmlMethod {
            method,
            sql_content: sql.to_string(),
            parameters,
            statement: sqlx_data_parser::parse_sql(sql).expect("test SQL should parse"),
            kind: sqlx_data_parser::SqlStatementType::Select,
            is_json_query: false,
            is_multi_insert: false,
            is_unchecked: false,
            has_explicit_instrument: false,
            trait_instrument: false,
            return_info_cache: std::sync::OnceLock::new(),
        }
    }

    #[test]
    fn test_dml_macro_basic() {
        let method = create_test_dml_method(
            "find_by_id",
            "SELECT * FROM users WHERE id = $1",
            vec![param("id", parse_quote! { i64 })],
            parse_quote! { Result<User> },
        );

        let generated_code = generate(&method);
        assert!(generated_code.contains("find_by_id_query"));
        assert!(generated_code.contains("find_by_id"));
        assert!(generated_code.contains("sqlx::query_as!"));
    }

    #[test]
    fn test_dml_macro_with_flatten() {
        let method = create_test_dml_method(
            "get_birth_year",
            "SELECT birth_year FROM users WHERE id = $1",
            vec![param("id", parse_quote! { i64 })],
            parse_quote! { Result<Option<i64>> },
        );

        let generated_code = generate(&method);
        assert!(generated_code.contains("get_birth_year_query"));
        assert!(generated_code.contains("get_birth_year"));
        assert!(generated_code.contains("sqlx::query_scalar!"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_f32_casting() {
        let method = create_test_dml_method(
            "group_avg",
            "SELECT birth_year, AVG(age) as avg_age FROM users GROUP BY birth_year",
            vec![],
            parse_quote! { Result<Vec<(Option<u16>, f32)>> },
        );

        let generated_code = generate(&method);
        eprintln!("Generated Code for tuple casting:\n{generated_code}");

        assert!(generated_code.contains("as f32"));
        assert!(generated_code.contains("as u16"));
        assert!(generated_code.contains("group_avg_query"));
        assert!(generated_code.contains("QueryTuple"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_f64_casting() {
        let method = create_test_dml_method(
            "group_having_avg",
            "SELECT birth_year, AVG(age) as avg_age FROM users WHERE birth_year IS NOT NULL GROUP BY birth_year HAVING AVG(age) > $1",
            vec![param("min_avg", parse_quote! { f32 })],
            parse_quote! { Result<Vec<(Option<u16>, f64)>> },
        );

        let generated_code = generate(&method);
        // `\n` (not `\\n`) so the dump starts on its own line, like the other tests.
        eprintln!("Generated Code for f64 casting:\n{generated_code}");

        assert!(generated_code.contains("as u16"));
        assert!(generated_code.contains("group_having_avg_query"));
        assert!(generated_code.contains("QueryTuple"));
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_tuple_i64_usize_casting() {
        let method = create_test_dml_method(
            "count_by_year",
            "SELECT birth_year, COUNT(*) as count FROM users GROUP BY birth_year",
            vec![],
            parse_quote! { Result<Vec<(Option<i64>, usize)>> },
        );

        let generated_code = generate(&method);
        eprintln!("Generated Code for i64/usize casting:\n{generated_code}");

        assert!(generated_code.contains("as usize"));
        assert!(generated_code.contains("count_by_year_query"));
    }

    #[test]
    fn test_documentation_generation() {
        // NOTE(review): the SQL references `$1` but no parameter is supplied here; this test
        // only inspects the generated documentation, not parameter binding.
        let method = create_test_dml_method(
            "find_by_id",
            "SELECT * FROM users WHERE id = $1",
            vec![],
            parse_quote! { Result<User> },
        );

        let generated_code = generate(&method);
        eprintln!("Generated Code:\n{generated_code}");

        // Token-spaced needles: `TokenStream::to_string` separates tokens with spaces.
        assert!(generated_code.contains("# [doc = "));
        assert!(generated_code.contains("Generated by #[dml] macro:"));
        assert!(generated_code.contains("```rust"));
        assert!(generated_code.contains("find_by_id_query"));
        assert!(generated_code.contains("sqlx :: query_as !"));
    }
}