1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10 entity: &ParsedEntity,
11 db_kind: DatabaseKind,
12 entity_module_path: &str,
13 methods: &Methods,
14 query_macro: bool,
15 pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17 let mut imports = BTreeSet::new();
18
19 let entity_ident = format_ident!("{}", entity.struct_name);
20 let repo_name = format!("{}Repository", entity.struct_name);
21 let repo_ident = format_ident!("{}", repo_name);
22
23 let table_name = match &entity.schema_name {
24 Some(schema) => format!("{}.{}", schema, entity.table_name),
25 None => entity.table_name.clone(),
26 };
27
28 let pool_type = pool_type_tokens(db_kind);
30
31 let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35 let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37 imports.insert(format!(
39 "use {}::{};",
40 entity_module_path, entity.struct_name
41 ));
42
43 let entity_parent = entity_module_path
47 .rsplit_once("::")
48 .map(|(parent, _)| parent)
49 .unwrap_or(entity_module_path);
50 for imp in &entity.imports {
51 if let Some(rest) = imp.strip_prefix("use super::") {
52 imports.insert(format!("use {}::{}", entity_parent, rest));
53 } else {
54 imports.insert(imp.clone());
55 }
56 }
57
58 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
60
61 let non_pk_fields: Vec<&ParsedField> =
63 entity.fields.iter().filter(|f| !f.is_primary_key).collect();
64
65 let is_view = entity.is_view;
66
67 let mut method_tokens = Vec::new();
69 let mut param_structs = Vec::new();
70
71 if methods.get_all {
73 let sql = raw_sql_lit(&format!("SELECT * FROM {}", table_name));
74 let method = if use_macro {
75 quote! {
76 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
77 sqlx::query_as!(#entity_ident, #sql)
78 .fetch_all(&self.pool)
79 .await
80 }
81 }
82 } else {
83 quote! {
84 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
85 sqlx::query_as::<_, #entity_ident>(#sql)
86 .fetch_all(&self.pool)
87 .await
88 }
89 }
90 };
91 method_tokens.push(method);
92 }
93
94 if methods.paginate {
96 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
97 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
98 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
99 let count_sql = raw_sql_lit(&format!("SELECT COUNT(*) FROM {}", table_name));
100 let sql = raw_sql_lit(&match db_kind {
101 DatabaseKind::Postgres => format!("SELECT *\nFROM {}\nLIMIT $1 OFFSET $2", table_name),
102 DatabaseKind::Mysql | DatabaseKind::Sqlite => {
103 format!("SELECT *\nFROM {}\nLIMIT ? OFFSET ?", table_name)
104 }
105 });
106 let method = if use_macro {
107 quote! {
108 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
109 let total: i64 = sqlx::query_scalar!(#count_sql)
110 .fetch_one(&self.pool)
111 .await?
112 .unwrap_or(0);
113 let per_page = params.per_page;
114 let current_page = params.page;
115 let last_page = (total + per_page - 1) / per_page;
116 let offset = (current_page - 1) * per_page;
117 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
118 .fetch_all(&self.pool)
119 .await?;
120 Ok(#paginated_ident {
121 meta: #pagination_meta_ident {
122 total,
123 per_page,
124 current_page,
125 last_page,
126 first_page: 1,
127 },
128 data,
129 })
130 }
131 }
132 } else {
133 quote! {
134 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
135 let total: i64 = sqlx::query_scalar(#count_sql)
136 .fetch_one(&self.pool)
137 .await?;
138 let per_page = params.per_page;
139 let current_page = params.page;
140 let last_page = (total + per_page - 1) / per_page;
141 let offset = (current_page - 1) * per_page;
142 let data = sqlx::query_as::<_, #entity_ident>(#sql)
143 .bind(per_page)
144 .bind(offset)
145 .fetch_all(&self.pool)
146 .await?;
147 Ok(#paginated_ident {
148 meta: #pagination_meta_ident {
149 total,
150 per_page,
151 current_page,
152 last_page,
153 first_page: 1,
154 },
155 data,
156 })
157 }
158 }
159 };
160 method_tokens.push(method);
161 param_structs.push(quote! {
162 #[derive(Debug, Clone, Default)]
163 pub struct #paginate_params_ident {
164 pub page: i64,
165 pub per_page: i64,
166 }
167 });
168 param_structs.push(quote! {
169 #[derive(Debug, Clone)]
170 pub struct #pagination_meta_ident {
171 pub total: i64,
172 pub per_page: i64,
173 pub current_page: i64,
174 pub last_page: i64,
175 pub first_page: i64,
176 }
177 });
178 param_structs.push(quote! {
179 #[derive(Debug, Clone)]
180 pub struct #paginated_ident {
181 pub meta: #pagination_meta_ident,
182 pub data: Vec<#entity_ident>,
183 }
184 });
185 }
186
187 if methods.get && !pk_fields.is_empty() {
189 let pk_params: Vec<TokenStream> = pk_fields
190 .iter()
191 .map(|f| {
192 let name = format_ident!("{}", f.rust_name);
193 let ty: TokenStream = f.inner_type.parse().unwrap();
194 quote! { #name: #ty }
195 })
196 .collect();
197
198 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
199 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
200 let sql = raw_sql_lit(&format!(
201 "SELECT *\nFROM {}\nWHERE {}",
202 table_name, where_clause
203 ));
204 let sql_macro = raw_sql_lit(&format!(
205 "SELECT *\nFROM {}\nWHERE {}",
206 table_name, where_clause_cast
207 ));
208
209 let binds: Vec<TokenStream> = pk_fields
210 .iter()
211 .map(|f| {
212 let name = format_ident!("{}", f.rust_name);
213 quote! { .bind(#name) }
214 })
215 .collect();
216
217 let method = if use_macro {
218 let pk_arg_names: Vec<TokenStream> = pk_fields
219 .iter()
220 .map(|f| {
221 let name = format_ident!("{}", f.rust_name);
222 quote! { #name }
223 })
224 .collect();
225 quote! {
226 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
227 sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
228 .fetch_optional(&self.pool)
229 .await
230 }
231 }
232 } else {
233 quote! {
234 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
235 sqlx::query_as::<_, #entity_ident>(#sql)
236 #(#binds)*
237 .fetch_optional(&self.pool)
238 .await
239 }
240 }
241 };
242 method_tokens.push(method);
243 }
244
245 if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
247 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
248
249 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
251 pk_fields.clone()
252 } else {
253 non_pk_fields.clone()
254 };
255
256 let insert_fields: Vec<TokenStream> = insert_source_fields
258 .iter()
259 .map(|f| {
260 let name = format_ident!("{}", f.rust_name);
261 if f.column_default.is_some() && !f.is_nullable {
262 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
263 quote! { pub #name: #ty, }
264 } else {
265 let ty: TokenStream = f.rust_type.parse().unwrap();
266 quote! { pub #name: #ty, }
267 }
268 })
269 .collect();
270
271 let col_names: Vec<&str> = insert_source_fields
272 .iter()
273 .map(|f| f.column_name.as_str())
274 .collect();
275 let col_list = col_names.join(", ");
276
277 let placeholders: String = insert_source_fields
279 .iter()
280 .enumerate()
281 .map(|(i, f)| {
282 let p = placeholder(db_kind, i + 1);
283 match &f.column_default {
284 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
285 None => p,
286 }
287 })
288 .collect::<Vec<_>>()
289 .join(", ");
290
291 let placeholders_cast: String = insert_source_fields
292 .iter()
293 .enumerate()
294 .map(|(i, f)| {
295 let p = placeholder_with_cast(db_kind, i + 1, f);
296 match &f.column_default {
297 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
298 None => p,
299 }
300 })
301 .collect::<Vec<_>>()
302 .join(", ");
303
304 let build_insert_sql = |ph: &str| match db_kind {
305 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
306 format!(
307 "INSERT INTO {} ({})\nVALUES ({})\nRETURNING *",
308 table_name, col_list, ph
309 )
310 }
311 DatabaseKind::Mysql => {
312 format!("INSERT INTO {} ({})\nVALUES ({})", table_name, col_list, ph)
313 }
314 };
315 let sql = build_insert_sql(&placeholders);
316 let sql_macro = build_insert_sql(&placeholders_cast);
317
318 let binds: Vec<TokenStream> = insert_source_fields
319 .iter()
320 .map(|f| {
321 let name = format_ident!("{}", f.rust_name);
322 quote! { .bind(¶ms.#name) }
323 })
324 .collect();
325
326 let insert_method = build_insert_method_parsed(
327 &entity_ident,
328 &insert_params_ident,
329 &sql,
330 &sql_macro,
331 &binds,
332 db_kind,
333 &table_name,
334 &pk_fields,
335 &insert_source_fields,
336 use_macro,
337 );
338 method_tokens.push(insert_method);
339
340 param_structs.push(quote! {
341 #[derive(Debug, Clone, Default)]
342 pub struct #insert_params_ident {
343 #(#insert_fields)*
344 }
345 });
346 }
347
348 if !is_view && methods.insert_many && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
350 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
351
352 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
353 pk_fields.clone()
354 } else {
355 non_pk_fields.clone()
356 };
357
358 let col_names: Vec<&str> = insert_source_fields
359 .iter()
360 .map(|f| f.column_name.as_str())
361 .collect();
362 let col_list = col_names.join(", ");
363 let num_cols = insert_source_fields.len();
364
365 let binds_loop: Vec<TokenStream> = insert_source_fields
366 .iter()
367 .map(|f| {
368 let name = format_ident!("{}", f.rust_name);
369 quote! { query = query.bind(¶ms.#name); }
370 })
371 .collect();
372
373 let insert_many_method = build_insert_many_transactionally_method(
374 &entity_ident,
375 &insert_params_ident,
376 &col_list,
377 num_cols,
378 &insert_source_fields,
379 &binds_loop,
380 db_kind,
381 &table_name,
382 &pk_fields,
383 );
384 method_tokens.push(insert_many_method);
385
386 if !methods.insert {
388 let insert_fields: Vec<TokenStream> = insert_source_fields
389 .iter()
390 .map(|f| {
391 let name = format_ident!("{}", f.rust_name);
392 if f.column_default.is_some() && !f.is_nullable {
393 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
394 quote! { pub #name: #ty, }
395 } else {
396 let ty: TokenStream = f.rust_type.parse().unwrap();
397 quote! { pub #name: #ty, }
398 }
399 })
400 .collect();
401
402 param_structs.push(quote! {
403 #[derive(Debug, Clone, Default)]
404 pub struct #insert_params_ident {
405 #(#insert_fields)*
406 }
407 });
408 }
409 }
410
411 if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
413 let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
414
415 let pk_fn_params: Vec<TokenStream> = pk_fields
417 .iter()
418 .map(|f| {
419 let name = format_ident!("{}", f.rust_name);
420 let ty: TokenStream = f.inner_type.parse().unwrap();
421 quote! { #name: #ty }
422 })
423 .collect();
424
425 let overwrite_fields: Vec<TokenStream> = non_pk_fields
427 .iter()
428 .map(|f| {
429 let name = format_ident!("{}", f.rust_name);
430 let ty: TokenStream = f.rust_type.parse().unwrap();
431 quote! { pub #name: #ty, }
432 })
433 .collect();
434
435 let set_cols: Vec<String> = non_pk_fields
436 .iter()
437 .enumerate()
438 .map(|(i, f)| {
439 let p = placeholder(db_kind, i + 1);
440 format!("{} = {}", f.column_name, p)
441 })
442 .collect();
443 let set_clause = set_cols.join(",\n ");
444
445 let set_cols_cast: Vec<String> = non_pk_fields
446 .iter()
447 .enumerate()
448 .map(|(i, f)| {
449 let p = placeholder_with_cast(db_kind, i + 1, f);
450 format!("{} = {}", f.column_name, p)
451 })
452 .collect();
453 let set_clause_cast = set_cols_cast.join(",\n ");
454
455 let pk_start = non_pk_fields.len() + 1;
456 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
457 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
458
459 let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
460 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
461 format!(
462 "UPDATE {}\nSET\n {}\nWHERE {}\nRETURNING *",
463 table_name, sc, wc
464 )
465 }
466 DatabaseKind::Mysql => {
467 format!("UPDATE {}\nSET\n {}\nWHERE {}", table_name, sc, wc)
468 }
469 };
470 let sql = raw_sql_lit(&build_overwrite_sql(&set_clause, &where_clause));
471 let sql_macro = raw_sql_lit(&build_overwrite_sql(&set_clause_cast, &where_clause_cast));
472
473 let mut all_binds: Vec<TokenStream> = non_pk_fields
475 .iter()
476 .map(|f| {
477 let name = format_ident!("{}", f.rust_name);
478 quote! { .bind(¶ms.#name) }
479 })
480 .collect();
481 for f in &pk_fields {
482 let name = format_ident!("{}", f.rust_name);
483 all_binds.push(quote! { .bind(#name) });
484 }
485
486 let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
488 .iter()
489 .map(|f| macro_arg_for_field(f))
490 .chain(pk_fields.iter().map(|f| {
491 let name = format_ident!("{}", f.rust_name);
492 quote! { #name }
493 }))
494 .collect();
495
496 let overwrite_method = if use_macro {
497 match db_kind {
498 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
499 quote! {
500 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
501 sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
502 .fetch_one(&self.pool)
503 .await
504 }
505 }
506 }
507 DatabaseKind::Mysql => {
508 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
509 let select_sql = raw_sql_lit(&format!(
510 "SELECT *\nFROM {}\nWHERE {}",
511 table_name, pk_where_select
512 ));
513 let pk_macro_args: Vec<TokenStream> = pk_fields
514 .iter()
515 .map(|f| {
516 let name = format_ident!("{}", f.rust_name);
517 quote! { #name }
518 })
519 .collect();
520 quote! {
521 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
522 sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
523 .execute(&self.pool)
524 .await?;
525 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
526 .fetch_one(&self.pool)
527 .await
528 }
529 }
530 }
531 }
532 } else {
533 match db_kind {
534 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
535 quote! {
536 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
537 sqlx::query_as::<_, #entity_ident>(#sql)
538 #(#all_binds)*
539 .fetch_one(&self.pool)
540 .await
541 }
542 }
543 }
544 DatabaseKind::Mysql => {
545 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
546 let select_sql = raw_sql_lit(&format!(
547 "SELECT *\nFROM {}\nWHERE {}",
548 table_name, pk_where_select
549 ));
550 let pk_binds: Vec<TokenStream> = pk_fields
551 .iter()
552 .map(|f| {
553 let name = format_ident!("{}", f.rust_name);
554 quote! { .bind(#name) }
555 })
556 .collect();
557 quote! {
558 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
559 sqlx::query(#sql)
560 #(#all_binds)*
561 .execute(&self.pool)
562 .await?;
563 sqlx::query_as::<_, #entity_ident>(#select_sql)
564 #(#pk_binds)*
565 .fetch_one(&self.pool)
566 .await
567 }
568 }
569 }
570 }
571 };
572 method_tokens.push(overwrite_method);
573
574 param_structs.push(quote! {
575 #[derive(Debug, Clone, Default)]
576 pub struct #overwrite_params_ident {
577 #(#overwrite_fields)*
578 }
579 });
580 }
581
582 if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
584 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
585
586 let pk_fn_params: Vec<TokenStream> = pk_fields
588 .iter()
589 .map(|f| {
590 let name = format_ident!("{}", f.rust_name);
591 let ty: TokenStream = f.inner_type.parse().unwrap();
592 quote! { #name: #ty }
593 })
594 .collect();
595
596 let update_fields: Vec<TokenStream> = non_pk_fields
598 .iter()
599 .map(|f| {
600 let name = format_ident!("{}", f.rust_name);
601 if f.is_nullable {
602 let ty: TokenStream = f.rust_type.parse().unwrap();
604 quote! { pub #name: #ty, }
605 } else {
606 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
607 quote! { pub #name: #ty, }
608 }
609 })
610 .collect();
611
612 let set_cols: Vec<String> = non_pk_fields
614 .iter()
615 .enumerate()
616 .map(|(i, f)| {
617 let p = placeholder(db_kind, i + 1);
618 format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
619 })
620 .collect();
621 let set_clause = set_cols.join(",\n ");
622
623 let set_cols_cast: Vec<String> = non_pk_fields
625 .iter()
626 .enumerate()
627 .map(|(i, f)| {
628 let p = placeholder_with_cast(db_kind, i + 1, f);
629 format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
630 })
631 .collect();
632 let set_clause_cast = set_cols_cast.join(",\n ");
633
634 let pk_start = non_pk_fields.len() + 1;
635 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
636 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
637
638 let build_update_sql = |sc: &str, wc: &str| match db_kind {
639 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
640 format!(
641 "UPDATE {}\nSET\n {}\nWHERE {}\nRETURNING *",
642 table_name, sc, wc
643 )
644 }
645 DatabaseKind::Mysql => {
646 format!("UPDATE {}\nSET\n {}\nWHERE {}", table_name, sc, wc)
647 }
648 };
649 let sql = raw_sql_lit(&build_update_sql(&set_clause, &where_clause));
650 let sql_macro = raw_sql_lit(&build_update_sql(&set_clause_cast, &where_clause_cast));
651
652 let mut all_binds: Vec<TokenStream> = non_pk_fields
654 .iter()
655 .map(|f| {
656 let name = format_ident!("{}", f.rust_name);
657 quote! { .bind(¶ms.#name) }
658 })
659 .collect();
660 for f in &pk_fields {
661 let name = format_ident!("{}", f.rust_name);
662 all_binds.push(quote! { .bind(#name) });
663 }
664
665 let update_macro_args: Vec<TokenStream> = non_pk_fields
667 .iter()
668 .map(|f| macro_arg_for_field(f))
669 .chain(pk_fields.iter().map(|f| {
670 let name = format_ident!("{}", f.rust_name);
671 quote! { #name }
672 }))
673 .collect();
674
675 let update_method = if use_macro {
676 match db_kind {
677 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
678 quote! {
679 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
680 sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
681 .fetch_one(&self.pool)
682 .await
683 }
684 }
685 }
686 DatabaseKind::Mysql => {
687 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
688 let select_sql = raw_sql_lit(&format!(
689 "SELECT *\nFROM {}\nWHERE {}",
690 table_name, pk_where_select
691 ));
692 let pk_macro_args: Vec<TokenStream> = pk_fields
693 .iter()
694 .map(|f| {
695 let name = format_ident!("{}", f.rust_name);
696 quote! { #name }
697 })
698 .collect();
699 quote! {
700 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
701 sqlx::query!(#sql_macro, #(#update_macro_args),*)
702 .execute(&self.pool)
703 .await?;
704 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
705 .fetch_one(&self.pool)
706 .await
707 }
708 }
709 }
710 }
711 } else {
712 match db_kind {
713 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
714 quote! {
715 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
716 sqlx::query_as::<_, #entity_ident>(#sql)
717 #(#all_binds)*
718 .fetch_one(&self.pool)
719 .await
720 }
721 }
722 }
723 DatabaseKind::Mysql => {
724 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
725 let select_sql = raw_sql_lit(&format!(
726 "SELECT *\nFROM {}\nWHERE {}",
727 table_name, pk_where_select
728 ));
729 let pk_binds: Vec<TokenStream> = pk_fields
730 .iter()
731 .map(|f| {
732 let name = format_ident!("{}", f.rust_name);
733 quote! { .bind(#name) }
734 })
735 .collect();
736 quote! {
737 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
738 sqlx::query(#sql)
739 #(#all_binds)*
740 .execute(&self.pool)
741 .await?;
742 sqlx::query_as::<_, #entity_ident>(#select_sql)
743 #(#pk_binds)*
744 .fetch_one(&self.pool)
745 .await
746 }
747 }
748 }
749 }
750 };
751 method_tokens.push(update_method);
752
753 param_structs.push(quote! {
754 #[derive(Debug, Clone, Default)]
755 pub struct #update_params_ident {
756 #(#update_fields)*
757 }
758 });
759 }
760
761 if !is_view && methods.delete && !pk_fields.is_empty() {
763 let pk_params: Vec<TokenStream> = pk_fields
764 .iter()
765 .map(|f| {
766 let name = format_ident!("{}", f.rust_name);
767 let ty: TokenStream = f.inner_type.parse().unwrap();
768 quote! { #name: #ty }
769 })
770 .collect();
771
772 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
773 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
774 let sql = raw_sql_lit(&format!(
775 "DELETE FROM {}\nWHERE {}",
776 table_name, where_clause
777 ));
778 let sql_macro = raw_sql_lit(&format!(
779 "DELETE FROM {}\nWHERE {}",
780 table_name, where_clause_cast
781 ));
782
783 let binds: Vec<TokenStream> = pk_fields
784 .iter()
785 .map(|f| {
786 let name = format_ident!("{}", f.rust_name);
787 quote! { .bind(#name) }
788 })
789 .collect();
790
791 let method = if query_macro {
792 let pk_arg_names: Vec<TokenStream> = pk_fields
793 .iter()
794 .map(|f| {
795 let name = format_ident!("{}", f.rust_name);
796 quote! { #name }
797 })
798 .collect();
799 quote! {
800 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
801 sqlx::query!(#sql_macro, #(#pk_arg_names),*)
802 .execute(&self.pool)
803 .await?;
804 Ok(())
805 }
806 }
807 } else {
808 quote! {
809 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
810 sqlx::query(#sql)
811 #(#binds)*
812 .execute(&self.pool)
813 .await?;
814 Ok(())
815 }
816 }
817 };
818 method_tokens.push(method);
819 }
820
821 let pool_vis: TokenStream = match pool_visibility {
822 PoolVisibility::Private => quote! {},
823 PoolVisibility::Pub => quote! { pub },
824 PoolVisibility::PubCrate => quote! { pub(crate) },
825 };
826
827 let tokens = quote! {
828 #(#param_structs)*
829
830 pub struct #repo_ident {
831 #pool_vis pool: #pool_type,
832 }
833
834 impl #repo_ident {
835 pub fn new(pool: #pool_type) -> Self {
836 Self { pool }
837 }
838
839 #(#method_tokens)*
840 }
841 };
842
843 (tokens, imports)
844}
845
846fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
847 match db_kind {
848 DatabaseKind::Postgres => quote! { sqlx::PgPool },
849 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
850 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
851 }
852}
853
854fn raw_sql_lit(s: &str) -> TokenStream {
857 if s.contains('\n') {
858 format!("r#\"\n{}\n\"#", s).parse().unwrap()
859 } else {
860 format!("r#\"{}\"#", s).parse().unwrap()
861 }
862}
863
864fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
865 match db_kind {
866 DatabaseKind::Postgres => format!("${}", index),
867 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
868 }
869}
870
871fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
872 let base = placeholder(db_kind, index);
873 match (&field.sql_type, field.is_sql_array) {
874 (Some(t), true) => format!("{} as {}[]", base, t),
875 (Some(t), false) => format!("{} as {}", base, t),
876 (None, _) => base,
877 }
878}
879
880fn build_where_clause_parsed(
881 pk_fields: &[&ParsedField],
882 db_kind: DatabaseKind,
883 start_index: usize,
884) -> String {
885 pk_fields
886 .iter()
887 .enumerate()
888 .map(|(i, f)| {
889 let p = placeholder(db_kind, start_index + i);
890 format!("{} = {}", f.column_name, p)
891 })
892 .collect::<Vec<_>>()
893 .join(" AND ")
894}
895
896fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
897 let name = format_ident!("{}", field.rust_name);
898 let check_type = if field.is_nullable {
899 &field.inner_type
900 } else {
901 &field.rust_type
902 };
903 let normalized = check_type.replace(' ', "");
904 if normalized.starts_with("Vec<") {
905 quote! { params.#name.as_slice() }
906 } else {
907 quote! { params.#name }
908 }
909}
910
911fn build_where_clause_cast(
912 pk_fields: &[&ParsedField],
913 db_kind: DatabaseKind,
914 start_index: usize,
915) -> String {
916 pk_fields
917 .iter()
918 .enumerate()
919 .map(|(i, f)| {
920 let p = placeholder_with_cast(db_kind, start_index + i, f);
921 format!("{} = {}", f.column_name, p)
922 })
923 .collect::<Vec<_>>()
924 .join(" AND ")
925}
926
927#[allow(clippy::too_many_arguments)]
928fn build_insert_method_parsed(
929 entity_ident: &proc_macro2::Ident,
930 insert_params_ident: &proc_macro2::Ident,
931 sql: &str,
932 sql_macro: &str,
933 binds: &[TokenStream],
934 db_kind: DatabaseKind,
935 table_name: &str,
936 pk_fields: &[&ParsedField],
937 non_pk_fields: &[&ParsedField],
938 use_macro: bool,
939) -> TokenStream {
940 let sql = raw_sql_lit(sql);
941 let sql_macro = raw_sql_lit(sql_macro);
942
943 if use_macro {
944 let macro_args: Vec<TokenStream> = non_pk_fields
945 .iter()
946 .map(|f| macro_arg_for_field(f))
947 .collect();
948
949 match db_kind {
950 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
951 quote! {
952 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
953 sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
954 .fetch_one(&self.pool)
955 .await
956 }
957 }
958 }
959 DatabaseKind::Mysql => {
960 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
961 let select_sql = raw_sql_lit(&format!(
962 "SELECT *\nFROM {}\nWHERE {}",
963 table_name, pk_where
964 ));
965 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID() as id");
966 quote! {
967 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
968 sqlx::query!(#sql_macro, #(#macro_args),*)
969 .execute(&self.pool)
970 .await?;
971 let id = sqlx::query_scalar!(#last_insert_id_sql)
972 .fetch_one(&self.pool)
973 .await?;
974 sqlx::query_as!(#entity_ident, #select_sql, id)
975 .fetch_one(&self.pool)
976 .await
977 }
978 }
979 }
980 }
981 } else {
982 match db_kind {
983 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
984 quote! {
985 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
986 sqlx::query_as::<_, #entity_ident>(#sql)
987 #(#binds)*
988 .fetch_one(&self.pool)
989 .await
990 }
991 }
992 }
993 DatabaseKind::Mysql => {
994 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
995 let select_sql = raw_sql_lit(&format!(
996 "SELECT *\nFROM {}\nWHERE {}",
997 table_name, pk_where
998 ));
999 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1000 quote! {
1001 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1002 sqlx::query(#sql)
1003 #(#binds)*
1004 .execute(&self.pool)
1005 .await?;
1006 let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1007 .fetch_one(&self.pool)
1008 .await?;
1009 sqlx::query_as::<_, #entity_ident>(#select_sql)
1010 .bind(id)
1011 .fetch_one(&self.pool)
1012 .await
1013 }
1014 }
1015 }
1016 }
1017 }
1018}
1019
1020#[allow(clippy::too_many_arguments)]
1021fn build_insert_many_transactionally_method(
1022 entity_ident: &proc_macro2::Ident,
1023 insert_params_ident: &proc_macro2::Ident,
1024 col_list: &str,
1025 num_cols: usize,
1026 insert_source_fields: &[&ParsedField],
1027 binds_loop: &[TokenStream],
1028 db_kind: DatabaseKind,
1029 table_name: &str,
1030 pk_fields: &[&ParsedField],
1031) -> TokenStream {
1032 let body = match db_kind {
1033 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
1034 let col_list_str = col_list.to_string();
1035 let table_name_str = table_name.to_string();
1036
1037 let row_placeholder_exprs: Vec<TokenStream> = insert_source_fields
1038 .iter()
1039 .enumerate()
1040 .map(|(i, f)| {
1041 let offset = i;
1042 match &f.column_default {
1043 Some(default_expr) => {
1044 let def = default_expr.as_str();
1045 match db_kind {
1046 DatabaseKind::Postgres => quote! {
1047 format!("COALESCE(${}, {})", base + #offset + 1, #def)
1048 },
1049 _ => quote! {
1050 format!("COALESCE(?, {})", #def)
1051 },
1052 }
1053 }
1054 None => match db_kind {
1055 DatabaseKind::Postgres => quote! {
1056 format!("${}", base + #offset + 1)
1057 },
1058 _ => quote! {
1059 "?".to_string()
1060 },
1061 },
1062 }
1063 })
1064 .collect();
1065
1066 quote! {
1067 let mut tx = self.pool.begin().await?;
1068 let mut all_results = Vec::with_capacity(entries.len());
1069 let max_per_chunk = 65535 / #num_cols;
1070 for chunk in entries.chunks(max_per_chunk) {
1071 let mut values_parts = Vec::with_capacity(chunk.len());
1072 for (row_idx, _) in chunk.iter().enumerate() {
1073 let base = row_idx * #num_cols;
1074 let placeholders = vec![#(#row_placeholder_exprs),*];
1075 values_parts.push(format!("({})", placeholders.join(", ")));
1076 }
1077 let sql = format!(
1078 "INSERT INTO {} ({})\nVALUES {}\nRETURNING *",
1079 #table_name_str,
1080 #col_list_str,
1081 values_parts.join(", ")
1082 );
1083 let mut query = sqlx::query_as::<_, #entity_ident>(&sql);
1084 for params in chunk {
1085 #(#binds_loop)*
1086 }
1087 let rows = query.fetch_all(&mut *tx).await?;
1088 all_results.extend(rows);
1089 }
1090 tx.commit().await?;
1091 Ok(all_results)
1092 }
1093 }
1094 DatabaseKind::Mysql => {
1095 let single_placeholders: String = insert_source_fields
1096 .iter()
1097 .enumerate()
1098 .map(|(i, f)| {
1099 let p = placeholder(db_kind, i + 1);
1100 match &f.column_default {
1101 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
1102 None => p,
1103 }
1104 })
1105 .collect::<Vec<_>>()
1106 .join(", ");
1107
1108 let single_insert_sql = raw_sql_lit(&format!(
1109 "INSERT INTO {} ({})\nVALUES ({})",
1110 table_name, col_list, single_placeholders
1111 ));
1112
1113 let single_binds: Vec<TokenStream> = insert_source_fields
1114 .iter()
1115 .map(|f| {
1116 let name = format_ident!("{}", f.rust_name);
1117 quote! { .bind(¶ms.#name) }
1118 })
1119 .collect();
1120
1121 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
1122 let select_sql = raw_sql_lit(&format!(
1123 "SELECT *\nFROM {}\nWHERE {}",
1124 table_name, pk_where
1125 ));
1126 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1127
1128 quote! {
1129 let mut tx = self.pool.begin().await?;
1130 let mut results = Vec::with_capacity(entries.len());
1131 for params in &entries {
1132 sqlx::query(#single_insert_sql)
1133 #(#single_binds)*
1134 .execute(&mut *tx)
1135 .await?;
1136 let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1137 .fetch_one(&mut *tx)
1138 .await?;
1139 let row = sqlx::query_as::<_, #entity_ident>(#select_sql)
1140 .bind(id)
1141 .fetch_one(&mut *tx)
1142 .await?;
1143 results.push(row);
1144 }
1145 tx.commit().await?;
1146 Ok(results)
1147 }
1148 }
1149 };
1150
1151 quote! {
1152 pub async fn insert_many_transactionally(
1153 &self,
1154 entries: Vec<#insert_params_ident>,
1155 ) -> Result<Vec<#entity_ident>, sqlx::Error> {
1156 #body
1157 }
1158 }
1159}
1160
1161#[cfg(test)]
1162mod tests {
1163 use super::*;
1164 use crate::cli::Methods;
1165 use crate::codegen::parse_and_format;
1166 use crate::codegen::parse_and_format_with_tab_spaces;
1167
1168 fn make_field(
1169 rust_name: &str,
1170 column_name: &str,
1171 rust_type: &str,
1172 nullable: bool,
1173 is_pk: bool,
1174 ) -> ParsedField {
1175 let inner_type = if nullable {
1176 rust_type
1178 .strip_prefix("Option<")
1179 .and_then(|s| s.strip_suffix('>'))
1180 .unwrap_or(rust_type)
1181 .to_string()
1182 } else {
1183 rust_type.to_string()
1184 };
1185 ParsedField {
1186 rust_name: rust_name.to_string(),
1187 column_name: column_name.to_string(),
1188 rust_type: rust_type.to_string(),
1189 is_nullable: nullable,
1190 inner_type,
1191 is_primary_key: is_pk,
1192 sql_type: None,
1193 is_sql_array: false,
1194 column_default: None,
1195 }
1196 }
1197
1198 fn make_field_with_default(
1199 rust_name: &str,
1200 column_name: &str,
1201 rust_type: &str,
1202 nullable: bool,
1203 is_pk: bool,
1204 default: &str,
1205 ) -> ParsedField {
1206 let mut f = make_field(rust_name, column_name, rust_type, nullable, is_pk);
1207 f.column_default = Some(default.to_string());
1208 f
1209 }
1210
1211 fn entity_with_defaults() -> ParsedEntity {
1212 ParsedEntity {
1213 struct_name: "Tasks".to_string(),
1214 table_name: "tasks".to_string(),
1215 schema_name: None,
1216 is_view: false,
1217 fields: vec![
1218 make_field("id", "id", "i32", false, true),
1219 make_field("title", "title", "String", false, false),
1220 make_field_with_default(
1221 "status",
1222 "status",
1223 "String",
1224 false,
1225 false,
1226 "'idle'::task_status",
1227 ),
1228 make_field_with_default("priority", "priority", "i32", false, false, "0"),
1229 make_field_with_default(
1230 "created_at",
1231 "created_at",
1232 "DateTime<Utc>",
1233 false,
1234 false,
1235 "now()",
1236 ),
1237 make_field("description", "description", "Option<String>", true, false),
1238 make_field_with_default(
1239 "deleted_at",
1240 "deleted_at",
1241 "Option<DateTime<Utc>>",
1242 true,
1243 false,
1244 "NULL",
1245 ),
1246 ],
1247 imports: vec![],
1248 }
1249 }
1250
1251 fn standard_entity() -> ParsedEntity {
1252 ParsedEntity {
1253 struct_name: "Users".to_string(),
1254 table_name: "users".to_string(),
1255 schema_name: None,
1256 is_view: false,
1257 fields: vec![
1258 make_field("id", "id", "i32", false, true),
1259 make_field("name", "name", "String", false, false),
1260 make_field("email", "email", "Option<String>", true, false),
1261 ],
1262 imports: vec![],
1263 }
1264 }
1265
1266 fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
1267 let skip = Methods::all();
1268 let (tokens, _) = generate_crud_from_parsed(
1269 entity,
1270 db,
1271 "crate::models::users",
1272 &skip,
1273 false,
1274 PoolVisibility::Private,
1275 );
1276 parse_and_format(&tokens)
1277 }
1278
1279 fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
1280 let skip = Methods::all();
1281 let (tokens, _) = generate_crud_from_parsed(
1282 entity,
1283 db,
1284 "crate::models::users",
1285 &skip,
1286 true,
1287 PoolVisibility::Private,
1288 );
1289 parse_and_format(&tokens)
1290 }
1291
1292 fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
1293 let (tokens, _) = generate_crud_from_parsed(
1294 entity,
1295 db,
1296 "crate::models::users",
1297 methods,
1298 false,
1299 PoolVisibility::Private,
1300 );
1301 parse_and_format(&tokens)
1302 }
1303
1304 fn gen_with_tab_spaces(entity: &ParsedEntity, db: DatabaseKind, tab_spaces: usize) -> String {
1305 let skip = Methods::all();
1306 let (tokens, _) = generate_crud_from_parsed(
1307 entity,
1308 db,
1309 "crate::models::users",
1310 &skip,
1311 false,
1312 PoolVisibility::Private,
1313 );
1314 parse_and_format_with_tab_spaces(&tokens, tab_spaces)
1315 }
1316
1317 #[test]
1320 fn test_repo_struct_name() {
1321 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1322 assert!(code.contains("pub struct UsersRepository"));
1323 }
1324
1325 #[test]
1326 fn test_repo_new_method() {
1327 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1328 assert!(code.contains("pub fn new("));
1329 }
1330
1331 #[test]
1332 fn test_repo_pool_field_pg() {
1333 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1334 assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
1335 }
1336
1337 #[test]
1338 fn test_repo_pool_field_pub() {
1339 let skip = Methods::all();
1340 let (tokens, _) = generate_crud_from_parsed(
1341 &standard_entity(),
1342 DatabaseKind::Postgres,
1343 "crate::models::users",
1344 &skip,
1345 false,
1346 PoolVisibility::Pub,
1347 );
1348 let code = parse_and_format(&tokens);
1349 assert!(
1350 code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool")
1351 );
1352 }
1353
1354 #[test]
1355 fn test_repo_pool_field_pub_crate() {
1356 let skip = Methods::all();
1357 let (tokens, _) = generate_crud_from_parsed(
1358 &standard_entity(),
1359 DatabaseKind::Postgres,
1360 "crate::models::users",
1361 &skip,
1362 false,
1363 PoolVisibility::PubCrate,
1364 );
1365 let code = parse_and_format(&tokens);
1366 assert!(
1367 code.contains("pub(crate) pool: sqlx::PgPool")
1368 || code.contains("pub(crate) pool: sqlx :: PgPool")
1369 );
1370 }
1371
1372 #[test]
1373 fn test_repo_pool_field_private() {
1374 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1375 assert!(!code.contains("pub pool"));
1377 assert!(!code.contains("pub(crate) pool"));
1378 }
1379
1380 #[test]
1381 fn test_repo_pool_field_mysql() {
1382 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1383 assert!(code.contains("MySqlPool") || code.contains("MySql"));
1384 }
1385
1386 #[test]
1387 fn test_repo_pool_field_sqlite() {
1388 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1389 assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
1390 }
1391
1392 #[test]
1395 fn test_get_all_method() {
1396 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1397 assert!(code.contains("pub async fn get_all"));
1398 }
1399
1400 #[test]
1401 fn test_get_all_returns_vec() {
1402 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1403 assert!(code.contains("Vec<Users>"));
1404 }
1405
1406 #[test]
1407 fn test_get_all_sql() {
1408 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1409 assert!(code.contains("SELECT * FROM users"));
1410 }
1411
1412 #[test]
1415 fn test_paginate_method() {
1416 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1417 assert!(code.contains("pub async fn paginate"));
1418 }
1419
1420 #[test]
1421 fn test_paginate_params_struct() {
1422 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1423 assert!(code.contains("pub struct PaginateUsersParams"));
1424 }
1425
1426 #[test]
1427 fn test_paginate_params_fields() {
1428 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1429 assert!(code.contains("pub page: i64"));
1430 assert!(code.contains("pub per_page: i64"));
1431 }
1432
1433 #[test]
1434 fn test_paginate_returns_paginated() {
1435 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1436 assert!(code.contains("PaginatedUsers"));
1437 assert!(code.contains("PaginationUsersMeta"));
1438 }
1439
1440 #[test]
1441 fn test_paginate_meta_struct() {
1442 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1443 assert!(code.contains("pub struct PaginationUsersMeta"));
1444 assert!(code.contains("pub total: i64"));
1445 assert!(code.contains("pub last_page: i64"));
1446 assert!(code.contains("pub first_page: i64"));
1447 assert!(code.contains("pub current_page: i64"));
1448 }
1449
1450 #[test]
1451 fn test_paginate_data_struct() {
1452 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1453 assert!(code.contains("pub struct PaginatedUsers"));
1454 assert!(code.contains("pub meta: PaginationUsersMeta"));
1455 assert!(code.contains("pub data: Vec<Users>"));
1456 }
1457
1458 #[test]
1459 fn test_paginate_count_sql() {
1460 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1461 assert!(code.contains("SELECT COUNT(*) FROM users"));
1462 }
1463
1464 #[test]
1465 fn test_paginate_sql_pg() {
1466 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1467 assert!(code.contains("LIMIT $1 OFFSET $2"));
1468 }
1469
1470 #[test]
1471 fn test_paginate_sql_mysql() {
1472 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1473 assert!(code.contains("LIMIT ? OFFSET ?"));
1474 }
1475
1476 #[test]
1479 fn test_get_method() {
1480 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1481 assert!(code.contains("pub async fn get"));
1482 }
1483
1484 #[test]
1485 fn test_get_returns_option() {
1486 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1487 assert!(code.contains("Option<Users>"));
1488 }
1489
1490 #[test]
1491 fn test_get_where_pk_pg() {
1492 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1493 assert!(code.contains("WHERE id = $1"));
1494 }
1495
1496 #[test]
1497 fn test_get_where_pk_mysql() {
1498 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1499 assert!(code.contains("WHERE id = ?"));
1500 }
1501
1502 #[test]
1505 fn test_insert_method() {
1506 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1507 assert!(code.contains("pub async fn insert"));
1508 }
1509
1510 #[test]
1511 fn test_insert_params_struct() {
1512 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1513 assert!(code.contains("pub struct InsertUsersParams"));
1514 }
1515
1516 #[test]
1517 fn test_insert_params_no_pk() {
1518 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1519 assert!(code.contains("pub name: String"));
1520 assert!(
1521 code.contains("pub email: Option<String>")
1522 || code.contains("pub email: Option < String >")
1523 );
1524 }
1525
1526 #[test]
1527 fn test_insert_returning_pg() {
1528 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1529 assert!(code.contains("RETURNING *"));
1530 }
1531
1532 #[test]
1533 fn test_insert_returning_sqlite() {
1534 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1535 assert!(code.contains("RETURNING *"));
1536 }
1537
1538 #[test]
1539 fn test_insert_mysql_last_insert_id() {
1540 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1541 assert!(code.contains("LAST_INSERT_ID"));
1542 }
1543
1544 #[test]
1547 fn test_insert_default_col_is_optional() {
1548 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1549 let struct_start = code
1551 .find("pub struct InsertTasksParams")
1552 .expect("InsertTasksParams not found");
1553 let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1554 let struct_body = &code[struct_start..struct_end];
1555 assert!(
1556 struct_body.contains("Option") && struct_body.contains("status"),
1557 "Expected status as Option in InsertTasksParams: {}",
1558 struct_body
1559 );
1560 }
1561
1562 #[test]
1563 fn test_insert_non_default_col_required() {
1564 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1565 let struct_start = code
1567 .find("pub struct InsertTasksParams")
1568 .expect("InsertTasksParams not found");
1569 let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1570 let struct_body = &code[struct_start..struct_end];
1571 assert!(
1572 struct_body.contains("title") && struct_body.contains("String"),
1573 "Expected title as String: {}",
1574 struct_body
1575 );
1576 }
1577
1578 #[test]
1579 fn test_insert_default_col_coalesce_sql() {
1580 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1581 assert!(
1582 code.contains("COALESCE($2, 'idle'::task_status)"),
1583 "Expected COALESCE for status:\n{}",
1584 code
1585 );
1586 assert!(
1587 code.contains("COALESCE($3, 0)"),
1588 "Expected COALESCE for priority:\n{}",
1589 code
1590 );
1591 assert!(
1592 code.contains("COALESCE($4, now())"),
1593 "Expected COALESCE for created_at:\n{}",
1594 code
1595 );
1596 }
1597
1598 #[test]
1599 fn test_insert_default_col_coalesce_json() {
1600 let mut entity = entity_with_defaults();
1601 entity.fields.push(make_field_with_default(
1602 "metadata",
1603 "metadata",
1604 "serde_json::Value",
1605 false,
1606 false,
1607 r#"'{"key": "value"}'::jsonb"#,
1608 ));
1609 let code = gen(&entity, DatabaseKind::Postgres);
1610 assert!(
1611 code.contains(r#"COALESCE($7, '{"key": "value"}'::jsonb)"#),
1612 "Expected COALESCE with JSON default:\n{}",
1613 code
1614 );
1615 }
1616
1617 #[test]
1618 fn test_insert_no_coalesce_for_non_default() {
1619 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1620 assert!(
1622 code.contains("VALUES ($1, COALESCE"),
1623 "Expected $1 without COALESCE for title:\n{}",
1624 code
1625 );
1626 }
1627
1628 #[test]
1629 fn test_insert_nullable_with_default_no_double_option() {
1630 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1631 assert!(
1632 !code.contains("Option < Option") && !code.contains("Option<Option"),
1633 "Should not have Option<Option>:\n{}",
1634 code
1635 );
1636 }
1637
1638 #[test]
1639 fn test_insert_derive_default() {
1640 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1641 let struct_start = code
1642 .find("pub struct InsertTasksParams")
1643 .expect("InsertTasksParams not found");
1644 let before_struct = &code[..struct_start];
1645 assert!(
1646 before_struct.ends_with("Default)]\n") || before_struct.contains("Default)]"),
1647 "Expected #[derive(Default)] on InsertTasksParams"
1648 );
1649 }
1650
1651 #[test]
1654 fn test_update_method() {
1655 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1656 assert!(code.contains("pub async fn update"));
1657 }
1658
1659 #[test]
1660 fn test_update_params_struct() {
1661 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1662 assert!(code.contains("pub struct UpdateUsersParams"));
1663 }
1664
1665 #[test]
1666 fn test_update_pk_in_fn_signature() {
1667 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1668 let update_pos = code.find("fn update").expect("fn update not found");
1670 let params_pos = code[update_pos..]
1671 .find("UpdateUsersParams")
1672 .expect("UpdateUsersParams not found in update fn");
1673 let signature = &code[update_pos..update_pos + params_pos];
1674 assert!(
1675 signature.contains("id"),
1676 "Expected 'id' PK in update fn signature: {}",
1677 signature
1678 );
1679 }
1680
1681 #[test]
1682 fn test_update_pk_not_in_struct() {
1683 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1684 let struct_start = code
1687 .find("pub struct UpdateUsersParams")
1688 .expect("UpdateUsersParams not found");
1689 let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1690 let struct_body = &code[struct_start..struct_end];
1691 assert!(
1692 !struct_body.contains("pub id"),
1693 "PK 'id' should not be in UpdateUsersParams:\n{}",
1694 struct_body
1695 );
1696 }
1697
1698 #[test]
1699 fn test_update_params_non_nullable_wrapped_in_option() {
1700 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1701 assert!(
1703 code.contains("pub name: Option<String>")
1704 || code.contains("pub name : Option < String >")
1705 );
1706 }
1707
1708 #[test]
1709 fn test_update_params_already_nullable_no_double_option() {
1710 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1711 assert!(!code.contains("Option<Option") && !code.contains("Option < Option"));
1713 }
1714
1715 #[test]
1716 fn test_update_set_clause_uses_coalesce_pg() {
1717 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1718 assert!(
1719 code.contains("COALESCE($1, name)"),
1720 "Expected COALESCE for name:\n{}",
1721 code
1722 );
1723 assert!(
1724 code.contains("COALESCE($2, email)"),
1725 "Expected COALESCE for email:\n{}",
1726 code
1727 );
1728 }
1729
1730 #[test]
1731 fn test_update_where_clause_pg() {
1732 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1733 assert!(code.contains("WHERE id = $3"));
1734 }
1735
1736 #[test]
1737 fn test_update_returning_pg() {
1738 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1739 assert!(code.contains("COALESCE"));
1740 assert!(code.contains("RETURNING *"));
1741 }
1742
1743 #[test]
1744 fn test_update_set_clause_mysql() {
1745 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1746 assert!(
1747 code.contains("COALESCE(?, name)"),
1748 "Expected COALESCE for MySQL:\n{}",
1749 code
1750 );
1751 assert!(
1752 code.contains("COALESCE(?, email)"),
1753 "Expected COALESCE for email in MySQL:\n{}",
1754 code
1755 );
1756 }
1757
1758 #[test]
1759 fn test_update_set_clause_sqlite() {
1760 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1761 assert!(
1762 code.contains("COALESCE(?, name)"),
1763 "Expected COALESCE for SQLite:\n{}",
1764 code
1765 );
1766 }
1767
1768 #[test]
1771 fn test_overwrite_method() {
1772 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1773 assert!(code.contains("pub async fn overwrite"));
1774 }
1775
1776 #[test]
1777 fn test_overwrite_params_struct() {
1778 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1779 assert!(code.contains("pub struct OverwriteUsersParams"));
1780 }
1781
1782 #[test]
1783 fn test_overwrite_pk_in_fn_signature() {
1784 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1785 let pos = code.find("fn overwrite").expect("fn overwrite not found");
1786 let params_pos = code[pos..]
1787 .find("OverwriteUsersParams")
1788 .expect("OverwriteUsersParams not found");
1789 let signature = &code[pos..pos + params_pos];
1790 assert!(
1791 signature.contains("id"),
1792 "Expected PK in overwrite fn signature: {}",
1793 signature
1794 );
1795 }
1796
1797 #[test]
1798 fn test_overwrite_pk_not_in_struct() {
1799 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1800 let struct_start = code
1801 .find("pub struct OverwriteUsersParams")
1802 .expect("OverwriteUsersParams not found");
1803 let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1804 let struct_body = &code[struct_start..struct_end];
1805 assert!(
1806 !struct_body.contains("pub id"),
1807 "PK should not be in OverwriteUsersParams: {}",
1808 struct_body
1809 );
1810 }
1811
1812 #[test]
1813 fn test_overwrite_no_coalesce() {
1814 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1815 let pos = code.find("fn overwrite").expect("fn overwrite not found");
1817 let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1818 assert!(
1819 !method_body.contains("COALESCE"),
1820 "Overwrite should not use COALESCE: {}",
1821 method_body
1822 );
1823 }
1824
1825 #[test]
1826 fn test_overwrite_set_clause_pg() {
1827 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1828 assert!(code.contains("name = $1,"));
1829 assert!(code.contains("email = $2"));
1830 assert!(code.contains("WHERE id = $3"));
1831 }
1832
1833 #[test]
1834 fn test_overwrite_returning_pg() {
1835 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1836 let pos = code.find("fn overwrite").expect("fn overwrite not found");
1837 let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1838 assert!(
1839 method_body.contains("RETURNING *"),
1840 "Expected RETURNING * in overwrite"
1841 );
1842 }
1843
1844 #[test]
1845 fn test_view_no_overwrite() {
1846 let mut entity = standard_entity();
1847 entity.is_view = true;
1848 let code = gen(&entity, DatabaseKind::Postgres);
1849 assert!(!code.contains("pub async fn overwrite"));
1850 }
1851
1852 #[test]
1853 fn test_without_overwrite() {
1854 let m = Methods {
1855 overwrite: false,
1856 ..Methods::all()
1857 };
1858 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1859 assert!(!code.contains("pub async fn overwrite"));
1860 assert!(!code.contains("OverwriteUsersParams"));
1861 }
1862
1863 #[test]
1864 fn test_update_and_overwrite_coexist() {
1865 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1866 assert!(
1867 code.contains("pub async fn update"),
1868 "Expected update method"
1869 );
1870 assert!(
1871 code.contains("pub async fn overwrite"),
1872 "Expected overwrite method"
1873 );
1874 assert!(
1875 code.contains("UpdateUsersParams"),
1876 "Expected UpdateUsersParams"
1877 );
1878 assert!(
1879 code.contains("OverwriteUsersParams"),
1880 "Expected OverwriteUsersParams"
1881 );
1882 }
1883
1884 #[test]
1887 fn test_delete_method() {
1888 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1889 assert!(code.contains("pub async fn delete"));
1890 }
1891
1892 #[test]
1893 fn test_delete_where_pk() {
1894 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1895 assert!(code.contains("DELETE FROM users"));
1896 assert!(code.contains("WHERE id = $1"));
1897 }
1898
    /// Formats the generated code with a tab width of 2 and checks the layout of the
    /// embedded SQL literal: an indented `SELECT *` line and an indented closing `"#`.
    #[test]
    fn test_tab_spaces_2_sql_indent() {
        let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 2);
        // NOTE(review): the failure messages name exact indent widths (8 and 6 spaces);
        // confirm the needles below carry that many leading spaces.
        assert!(
            code.contains(" SELECT *"),
            "Expected SQL at 8-space indent:\n{}",
            code
        );
        assert!(
            code.contains(" \"#"),
            "Expected closing tag at 6-space indent:\n{}",
            code
        );
    }
1914
    /// Formats the generated code with a tab width of 4 and checks the layout of the
    /// embedded SQL literal: an indented `SELECT *` line and an indented closing `"#`.
    #[test]
    fn test_tab_spaces_4_sql_indent() {
        let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 4);
        // NOTE(review): the failure messages name exact indent widths (12 and 8 spaces);
        // confirm the needles below carry that many leading spaces.
        assert!(
            code.contains(" SELECT *"),
            "Expected SQL at 12-space indent:\n{}",
            code
        );
        assert!(
            code.contains(" \"#"),
            "Expected closing tag at 8-space indent:\n{}",
            code
        );
    }
1930
1931 #[test]
1932 fn test_delete_returns_unit() {
1933 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1934 assert!(
1935 code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>")
1936 );
1937 }
1938
1939 #[test]
1942 fn test_view_no_insert() {
1943 let mut entity = standard_entity();
1944 entity.is_view = true;
1945 let code = gen(&entity, DatabaseKind::Postgres);
1946 assert!(!code.contains("pub async fn insert"));
1947 }
1948
1949 #[test]
1950 fn test_view_no_update() {
1951 let mut entity = standard_entity();
1952 entity.is_view = true;
1953 let code = gen(&entity, DatabaseKind::Postgres);
1954 assert!(!code.contains("pub async fn update"));
1955 }
1956
1957 #[test]
1958 fn test_view_no_delete() {
1959 let mut entity = standard_entity();
1960 entity.is_view = true;
1961 let code = gen(&entity, DatabaseKind::Postgres);
1962 assert!(!code.contains("pub async fn delete"));
1963 }
1964
1965 #[test]
1966 fn test_view_has_get_all() {
1967 let mut entity = standard_entity();
1968 entity.is_view = true;
1969 let code = gen(&entity, DatabaseKind::Postgres);
1970 assert!(code.contains("pub async fn get_all"));
1971 }
1972
1973 #[test]
1974 fn test_view_has_paginate() {
1975 let mut entity = standard_entity();
1976 entity.is_view = true;
1977 let code = gen(&entity, DatabaseKind::Postgres);
1978 assert!(code.contains("pub async fn paginate"));
1979 }
1980
1981 #[test]
1982 fn test_view_has_get() {
1983 let mut entity = standard_entity();
1984 entity.is_view = true;
1985 let code = gen(&entity, DatabaseKind::Postgres);
1986 assert!(code.contains("pub async fn get"));
1987 }
1988
1989 #[test]
1992 fn test_only_get_all() {
1993 let m = Methods {
1994 get_all: true,
1995 ..Default::default()
1996 };
1997 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1998 assert!(code.contains("pub async fn get_all"));
1999 assert!(!code.contains("pub async fn paginate"));
2000 assert!(!code.contains("pub async fn insert"));
2001 }
2002
2003 #[test]
2004 fn test_without_get_all() {
2005 let m = Methods {
2006 get_all: false,
2007 ..Methods::all()
2008 };
2009 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2010 assert!(!code.contains("pub async fn get_all"));
2011 }
2012
2013 #[test]
2014 fn test_without_paginate() {
2015 let m = Methods {
2016 paginate: false,
2017 ..Methods::all()
2018 };
2019 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2020 assert!(!code.contains("pub async fn paginate"));
2021 assert!(!code.contains("PaginateUsersParams"));
2022 }
2023
2024 #[test]
2025 fn test_without_get() {
2026 let m = Methods {
2027 get: false,
2028 ..Methods::all()
2029 };
2030 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2031 assert!(code.contains("pub async fn get_all"));
2032 let without_get_all = code.replace("get_all", "XXX");
2033 assert!(!without_get_all.contains("fn get("));
2034 }
2035
2036 #[test]
2037 fn test_without_insert() {
2038 let m = Methods {
2039 insert: false,
2040 insert_many: false,
2041 ..Methods::all()
2042 };
2043 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2044 assert!(!code.contains("pub async fn insert"));
2045 assert!(!code.contains("InsertUsersParams"));
2046 }
2047
2048 #[test]
2049 fn test_without_update() {
2050 let m = Methods {
2051 update: false,
2052 ..Methods::all()
2053 };
2054 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2055 assert!(!code.contains("pub async fn update"));
2056 assert!(!code.contains("UpdateUsersParams"));
2057 }
2058
2059 #[test]
2060 fn test_without_delete() {
2061 let m = Methods {
2062 delete: false,
2063 ..Methods::all()
2064 };
2065 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2066 assert!(!code.contains("pub async fn delete"));
2067 }
2068
2069 #[test]
2070 fn test_empty_methods_no_methods() {
2071 let m = Methods::default();
2072 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2073 assert!(!code.contains("pub async fn get_all"));
2074 assert!(!code.contains("pub async fn paginate"));
2075 assert!(!code.contains("pub async fn insert"));
2076 assert!(!code.contains("pub async fn update"));
2077 assert!(!code.contains("pub async fn overwrite"));
2078 assert!(!code.contains("pub async fn delete"));
2079 assert!(!code.contains("pub async fn insert_many"));
2080 }
2081
2082 #[test]
2085 fn test_no_pool_import() {
2086 let skip = Methods::all();
2087 let (_, imports) = generate_crud_from_parsed(
2088 &standard_entity(),
2089 DatabaseKind::Postgres,
2090 "crate::models::users",
2091 &skip,
2092 false,
2093 PoolVisibility::Private,
2094 );
2095 assert!(!imports.iter().any(|i| i.contains("PgPool")));
2096 }
2097
2098 #[test]
2099 fn test_imports_contain_entity() {
2100 let skip = Methods::all();
2101 let (_, imports) = generate_crud_from_parsed(
2102 &standard_entity(),
2103 DatabaseKind::Postgres,
2104 "crate::models::users",
2105 &skip,
2106 false,
2107 PoolVisibility::Private,
2108 );
2109 assert!(imports
2110 .iter()
2111 .any(|i| i.contains("crate::models::users::Users")));
2112 }
2113
2114 #[test]
2117 fn test_renamed_column_in_sql() {
2118 let entity = ParsedEntity {
2119 struct_name: "Connector".to_string(),
2120 table_name: "connector".to_string(),
2121 schema_name: None,
2122 is_view: false,
2123 fields: vec![
2124 make_field("id", "id", "i32", false, true),
2125 make_field("connector_type", "type", "String", false, false),
2126 ],
2127 imports: vec![],
2128 };
2129 let code = gen(&entity, DatabaseKind::Postgres);
2130 assert!(code.contains("type"));
2132 assert!(code.contains("pub connector_type: String"));
2134 }
2135
2136 #[test]
2139 fn test_no_pk_no_get() {
2140 let entity = ParsedEntity {
2141 struct_name: "Logs".to_string(),
2142 table_name: "logs".to_string(),
2143 schema_name: None,
2144 is_view: false,
2145 fields: vec![
2146 make_field("message", "message", "String", false, false),
2147 make_field("ts", "ts", "String", false, false),
2148 ],
2149 imports: vec![],
2150 };
2151 let code = gen(&entity, DatabaseKind::Postgres);
2152 assert!(code.contains("pub async fn get_all"));
2153 let without_get_all = code.replace("get_all", "XXX");
2154 assert!(!without_get_all.contains("fn get("));
2155 }
2156
2157 #[test]
2158 fn test_no_pk_no_delete() {
2159 let entity = ParsedEntity {
2160 struct_name: "Logs".to_string(),
2161 table_name: "logs".to_string(),
2162 schema_name: None,
2163 is_view: false,
2164 fields: vec![make_field("message", "message", "String", false, false)],
2165 imports: vec![],
2166 };
2167 let code = gen(&entity, DatabaseKind::Postgres);
2168 assert!(!code.contains("pub async fn delete"));
2169 }
2170
2171 #[test]
2174 fn test_param_structs_have_default() {
2175 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2176 assert!(code.contains("Default"));
2177 }
2178
2179 #[test]
2182 fn test_entity_imports_forwarded() {
2183 let entity = ParsedEntity {
2184 struct_name: "Users".to_string(),
2185 table_name: "users".to_string(),
2186 schema_name: None,
2187 is_view: false,
2188 fields: vec![
2189 make_field("id", "id", "Uuid", false, true),
2190 make_field("created_at", "created_at", "DateTime<Utc>", false, false),
2191 ],
2192 imports: vec![
2193 "use chrono::{DateTime, Utc};".to_string(),
2194 "use uuid::Uuid;".to_string(),
2195 ],
2196 };
2197 let skip = Methods::all();
2198 let (_, imports) = generate_crud_from_parsed(
2199 &entity,
2200 DatabaseKind::Postgres,
2201 "crate::models::users",
2202 &skip,
2203 false,
2204 PoolVisibility::Private,
2205 );
2206 assert!(imports.iter().any(|i| i.contains("chrono")));
2207 assert!(imports.iter().any(|i| i.contains("uuid")));
2208 }
2209
2210 #[test]
2211 fn test_entity_imports_empty_when_no_imports() {
2212 let skip = Methods::all();
2213 let (_, imports) = generate_crud_from_parsed(
2214 &standard_entity(),
2215 DatabaseKind::Postgres,
2216 "crate::models::users",
2217 &skip,
2218 false,
2219 PoolVisibility::Private,
2220 );
2221 assert!(!imports.iter().any(|i| i.contains("chrono")));
2223 assert!(!imports.iter().any(|i| i.contains("uuid")));
2224 }
2225
2226 #[test]
2229 fn test_macro_get_all() {
2230 let m = Methods {
2231 get_all: true,
2232 ..Default::default()
2233 };
2234 let (tokens, _) = generate_crud_from_parsed(
2235 &standard_entity(),
2236 DatabaseKind::Postgres,
2237 "crate::models::users",
2238 &m,
2239 true,
2240 PoolVisibility::Private,
2241 );
2242 let code = parse_and_format(&tokens);
2243 assert!(code.contains("query_as!"));
2244 assert!(!code.contains("query_as::<"));
2245 }
2246
2247 #[test]
2248 fn test_macro_paginate() {
2249 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2250 assert!(code.contains("query_as!"));
2251 assert!(code.contains("per_page, offset"));
2252 }
2253
2254 #[test]
2255 fn test_macro_get() {
2256 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2257 assert!(code.contains("query_as!(Users"));
2259 }
2260
2261 #[test]
2262 fn test_macro_insert_pg() {
2263 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2264 assert!(code.contains("query_as!(Users"));
2265 assert!(code.contains("params.name"));
2266 assert!(code.contains("params.email"));
2267 }
2268
2269 #[test]
2270 fn test_macro_insert_mysql() {
2271 let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
2272 assert!(code.contains("query!"));
2274 assert!(code.contains("query_scalar!"));
2275 }
2276
2277 #[test]
2278 fn test_macro_update() {
2279 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2280 assert!(code.contains("query_as!(Users"));
2281 assert!(
2282 code.contains("COALESCE"),
2283 "Expected COALESCE in macro update:\n{}",
2284 code
2285 );
2286 assert!(code.contains("pub async fn update"));
2287 assert!(code.contains("UpdateUsersParams"));
2288 }
2289
2290 #[test]
2291 fn test_macro_delete() {
2292 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2293 assert!(code.contains("query!"));
2295 }
2296
2297 #[test]
2298 fn test_macro_no_bind_calls() {
2299 let m = Methods {
2301 insert_many: false,
2302 ..Methods::all()
2303 };
2304 let (tokens, _) = generate_crud_from_parsed(
2305 &standard_entity(),
2306 DatabaseKind::Postgres,
2307 "crate::models::users",
2308 &m,
2309 true,
2310 PoolVisibility::Private,
2311 );
2312 let code = parse_and_format(&tokens);
2313 assert!(!code.contains(".bind("));
2314 }
2315
2316 #[test]
2317 fn test_function_style_uses_bind() {
2318 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2319 assert!(code.contains(".bind("));
2320 assert!(!code.contains("query_as!("));
2321 assert!(!code.contains("query!("));
2322 }
2323
2324 fn entity_with_sql_array() -> ParsedEntity {
2327 ParsedEntity {
2328 struct_name: "AgentConnector".to_string(),
2329 table_name: "agent.agent_connector".to_string(),
2330 schema_name: Some("agent".to_string()),
2331 is_view: false,
2332 fields: vec![
2333 ParsedField {
2334 rust_name: "connector_id".to_string(),
2335 column_name: "connector_id".to_string(),
2336 rust_type: "Uuid".to_string(),
2337 inner_type: "Uuid".to_string(),
2338 is_nullable: false,
2339 is_primary_key: true,
2340 sql_type: None,
2341 is_sql_array: false,
2342 column_default: None,
2343 },
2344 ParsedField {
2345 rust_name: "agent_id".to_string(),
2346 column_name: "agent_id".to_string(),
2347 rust_type: "Uuid".to_string(),
2348 inner_type: "Uuid".to_string(),
2349 is_nullable: false,
2350 is_primary_key: false,
2351 sql_type: None,
2352 is_sql_array: false,
2353 column_default: None,
2354 },
2355 ParsedField {
2356 rust_name: "usages".to_string(),
2357 column_name: "usages".to_string(),
2358 rust_type: "Vec<ConnectorUsages>".to_string(),
2359 inner_type: "Vec<ConnectorUsages>".to_string(),
2360 is_nullable: false,
2361 is_primary_key: false,
2362 sql_type: Some("agent.connector_usages".to_string()),
2363 is_sql_array: true,
2364 column_default: None,
2365 },
2366 ],
2367 imports: vec!["use uuid::Uuid;".to_string()],
2368 }
2369 }
2370
2371 fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
2372 let skip = Methods::all();
2373 let (tokens, _) = generate_crud_from_parsed(
2374 entity,
2375 db,
2376 "crate::models::agent_connector",
2377 &skip,
2378 true,
2379 PoolVisibility::Private,
2380 );
2381 parse_and_format(&tokens)
2382 }
2383
2384 #[test]
2385 fn test_sql_array_macro_get_all_uses_runtime() {
2386 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2387 assert!(code.contains("query_as::<"));
2389 }
2390
2391 #[test]
2392 fn test_sql_array_macro_get_uses_runtime() {
2393 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2394 assert!(code.contains(".bind("));
2396 }
2397
2398 #[test]
2399 fn test_sql_array_macro_insert_uses_runtime() {
2400 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2401 assert!(
2403 code.contains("query_as::<_ , AgentConnector>")
2404 || code.contains("query_as::<_, AgentConnector>")
2405 );
2406 }
2407
2408 #[test]
2409 fn test_sql_array_macro_delete_still_uses_macro() {
2410 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2411 assert!(code.contains("query!"));
2413 }
2414
2415 #[test]
2416 fn test_sql_array_no_query_as_macro() {
2417 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2418 assert!(!code.contains("query_as!("));
2420 }
2421
2422 fn entity_with_sql_enum() -> ParsedEntity {
2425 ParsedEntity {
2426 struct_name: "Task".to_string(),
2427 table_name: "tasks".to_string(),
2428 schema_name: None,
2429 is_view: false,
2430 fields: vec![
2431 ParsedField {
2432 rust_name: "id".to_string(),
2433 column_name: "id".to_string(),
2434 rust_type: "i32".to_string(),
2435 inner_type: "i32".to_string(),
2436 is_nullable: false,
2437 is_primary_key: true,
2438 sql_type: None,
2439 is_sql_array: false,
2440 column_default: None,
2441 },
2442 ParsedField {
2443 rust_name: "status".to_string(),
2444 column_name: "status".to_string(),
2445 rust_type: "TaskStatus".to_string(),
2446 inner_type: "TaskStatus".to_string(),
2447 is_nullable: false,
2448 is_primary_key: false,
2449 sql_type: Some("task_status".to_string()),
2450 is_sql_array: false,
2451 column_default: None,
2452 },
2453 ],
2454 imports: vec![],
2455 }
2456 }
2457
2458 #[test]
2459 fn test_sql_enum_macro_uses_runtime() {
2460 let skip = Methods::all();
2461 let (tokens, _) = generate_crud_from_parsed(
2462 &entity_with_sql_enum(),
2463 DatabaseKind::Postgres,
2464 "crate::models::task",
2465 &skip,
2466 true,
2467 PoolVisibility::Private,
2468 );
2469 let code = parse_and_format(&tokens);
2470 assert!(code.contains("query_as::<"));
2472 assert!(!code.contains("query_as!("));
2473 }
2474
2475 #[test]
2476 fn test_sql_enum_macro_delete_still_uses_macro() {
2477 let skip = Methods::all();
2478 let (tokens, _) = generate_crud_from_parsed(
2479 &entity_with_sql_enum(),
2480 DatabaseKind::Postgres,
2481 "crate::models::task",
2482 &skip,
2483 true,
2484 PoolVisibility::Private,
2485 );
2486 let code = parse_and_format(&tokens);
2487 assert!(code.contains("query!"));
2489 }
2490
2491 fn entity_with_vec_string() -> ParsedEntity {
2494 ParsedEntity {
2495 struct_name: "PromptHistory".to_string(),
2496 table_name: "prompt_history".to_string(),
2497 schema_name: None,
2498 is_view: false,
2499 fields: vec![
2500 ParsedField {
2501 rust_name: "id".to_string(),
2502 column_name: "id".to_string(),
2503 rust_type: "Uuid".to_string(),
2504 inner_type: "Uuid".to_string(),
2505 is_nullable: false,
2506 is_primary_key: true,
2507 sql_type: None,
2508 is_sql_array: false,
2509 column_default: None,
2510 },
2511 ParsedField {
2512 rust_name: "content".to_string(),
2513 column_name: "content".to_string(),
2514 rust_type: "String".to_string(),
2515 inner_type: "String".to_string(),
2516 is_nullable: false,
2517 is_primary_key: false,
2518 sql_type: None,
2519 is_sql_array: false,
2520 column_default: None,
2521 },
2522 ParsedField {
2523 rust_name: "tags".to_string(),
2524 column_name: "tags".to_string(),
2525 rust_type: "Vec<String>".to_string(),
2526 inner_type: "Vec<String>".to_string(),
2527 is_nullable: false,
2528 is_primary_key: false,
2529 sql_type: None,
2530 is_sql_array: false,
2531 column_default: None,
2532 },
2533 ],
2534 imports: vec!["use uuid::Uuid;".to_string()],
2535 }
2536 }
2537
2538 #[test]
2539 fn test_vec_string_macro_insert_uses_as_slice() {
2540 let skip = Methods::all();
2541 let (tokens, _) = generate_crud_from_parsed(
2542 &entity_with_vec_string(),
2543 DatabaseKind::Postgres,
2544 "crate::models::prompt_history",
2545 &skip,
2546 true,
2547 PoolVisibility::Private,
2548 );
2549 let code = parse_and_format(&tokens);
2550 assert!(code.contains("as_slice()"));
2551 }
2552
2553 #[test]
2554 fn test_vec_string_macro_update_uses_as_slice() {
2555 let skip = Methods::all();
2556 let (tokens, _) = generate_crud_from_parsed(
2557 &entity_with_vec_string(),
2558 DatabaseKind::Postgres,
2559 "crate::models::prompt_history",
2560 &skip,
2561 true,
2562 PoolVisibility::Private,
2563 );
2564 let code = parse_and_format(&tokens);
2565 let count = code.matches("as_slice()").count();
2567 assert!(
2568 count >= 2,
2569 "expected at least 2 as_slice() calls (insert + update), found {}",
2570 count
2571 );
2572 }
2573
2574 #[test]
2575 fn test_vec_string_non_macro_no_as_slice() {
2576 let skip = Methods::all();
2577 let (tokens, _) = generate_crud_from_parsed(
2578 &entity_with_vec_string(),
2579 DatabaseKind::Postgres,
2580 "crate::models::prompt_history",
2581 &skip,
2582 false,
2583 PoolVisibility::Private,
2584 );
2585 let code = parse_and_format(&tokens);
2586 assert!(!code.contains("as_slice()"));
2588 }
2589
2590 #[test]
2591 fn test_vec_string_parsed_from_source_uses_as_slice() {
2592 use crate::codegen::entity_parser::parse_entity_source;
2593 let source = r#"
2594 use uuid::Uuid;
2595
2596 #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
2597 #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
2598 pub struct PromptHistory {
2599 #[sqlx_gen(primary_key)]
2600 pub id: Uuid,
2601 pub content: String,
2602 pub tags: Vec<String>,
2603 }
2604 "#;
2605 let entity = parse_entity_source(source).unwrap();
2606 let skip = Methods::all();
2607 let (tokens, _) = generate_crud_from_parsed(
2608 &entity,
2609 DatabaseKind::Postgres,
2610 "crate::models::prompt_history",
2611 &skip,
2612 true,
2613 PoolVisibility::Private,
2614 );
2615 let code = parse_and_format(&tokens);
2616 assert!(
2617 code.contains("as_slice()"),
2618 "Expected as_slice() in generated code:\n{}",
2619 code
2620 );
2621 }
2622
2623 fn junction_entity() -> ParsedEntity {
2626 ParsedEntity {
2627 struct_name: "AnalysisRecord".to_string(),
2628 table_name: "analysis.analysis__record".to_string(),
2629 schema_name: None,
2630 is_view: false,
2631 fields: vec![
2632 make_field("record_id", "record_id", "uuid::Uuid", false, true),
2633 make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
2634 ],
2635 imports: vec![],
2636 }
2637 }
2638
2639 #[test]
2640 fn test_composite_pk_only_insert_generated() {
2641 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2642 assert!(
2643 code.contains("pub struct InsertAnalysisRecordParams"),
2644 "Expected InsertAnalysisRecordParams struct:\n{}",
2645 code
2646 );
2647 assert!(
2648 code.contains("pub record_id"),
2649 "Expected record_id field in insert params:\n{}",
2650 code
2651 );
2652 assert!(
2653 code.contains("pub analysis_id"),
2654 "Expected analysis_id field in insert params:\n{}",
2655 code
2656 );
2657 assert!(
2658 code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id)"),
2659 "Expected INSERT INTO clause:\n{}",
2660 code
2661 );
2662 assert!(
2663 code.contains("VALUES ($1, $2)"),
2664 "Expected VALUES clause:\n{}",
2665 code
2666 );
2667 assert!(
2668 code.contains("RETURNING *"),
2669 "Expected RETURNING clause:\n{}",
2670 code
2671 );
2672 assert!(
2673 code.contains("pub async fn insert"),
2674 "Expected insert method:\n{}",
2675 code
2676 );
2677 }
2678
2679 #[test]
2680 fn test_composite_pk_only_no_update() {
2681 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2682 assert!(
2683 !code.contains("UpdateAnalysisRecordParams"),
2684 "Expected no UpdateAnalysisRecordParams struct:\n{}",
2685 code
2686 );
2687 assert!(
2688 !code.contains("pub async fn update"),
2689 "Expected no update method:\n{}",
2690 code
2691 );
2692 }
2693
2694 #[test]
2695 fn test_composite_pk_only_delete_generated() {
2696 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2697 assert!(
2698 code.contains("pub async fn delete"),
2699 "Expected delete method:\n{}",
2700 code
2701 );
2702 assert!(
2703 code.contains("DELETE FROM analysis.analysis__record"),
2704 "Expected DELETE clause:\n{}",
2705 code
2706 );
2707 assert!(
2708 code.contains("WHERE record_id = $1 AND analysis_id = $2"),
2709 "Expected WHERE clause:\n{}",
2710 code
2711 );
2712 }
2713
2714 #[test]
2715 fn test_composite_pk_only_get_generated() {
2716 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2717 assert!(
2718 code.contains("pub async fn get"),
2719 "Expected get method:\n{}",
2720 code
2721 );
2722 assert!(
2723 code.contains("WHERE record_id = $1 AND analysis_id = $2"),
2724 "Expected WHERE clause with both PK columns:\n{}",
2725 code
2726 );
2727 }
2728
2729 #[test]
2732 fn test_insert_many_transactionally_method_generated() {
2733 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2734 assert!(
2735 code.contains("pub async fn insert_many_transactionally"),
2736 "Expected insert_many_transactionally method:\n{}",
2737 code
2738 );
2739 }
2740
2741 #[test]
2742 fn test_insert_many_transactionally_signature() {
2743 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2744 assert!(
2745 code.contains("entries: Vec<InsertUsersParams>"),
2746 "Expected Vec<InsertUsersParams> param:\n{}",
2747 code
2748 );
2749 assert!(
2750 code.contains("Result<Vec<Users>"),
2751 "Expected Result<Vec<Users>> return type:\n{}",
2752 code
2753 );
2754 }
2755
2756 #[test]
2757 fn test_insert_many_transactionally_no_strategy_enum() {
2758 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2759 assert!(
2760 !code.contains("TransactionStrategy"),
2761 "TransactionStrategy should not be generated:\n{}",
2762 code
2763 );
2764 assert!(
2765 !code.contains("InsertManyUsersResult"),
2766 "InsertManyUsersResult should not be generated:\n{}",
2767 code
2768 );
2769 }
2770
2771 #[test]
2772 fn test_insert_many_transactionally_uses_transaction_pg() {
2773 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2774 let method_start = code
2775 .find("fn insert_many_transactionally")
2776 .expect("insert_many_transactionally not found");
2777 let method_body = &code[method_start..];
2778 assert!(
2779 method_body.contains("self.pool.begin()"),
2780 "Expected begin():\n{}",
2781 method_body
2782 );
2783 assert!(
2784 method_body.contains("tx.commit()"),
2785 "Expected commit():\n{}",
2786 method_body
2787 );
2788 }
2789
2790 #[test]
2791 fn test_insert_many_transactionally_multi_row_pg() {
2792 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2793 let method_start = code
2794 .find("fn insert_many_transactionally")
2795 .expect("not found");
2796 let method_body = &code[method_start..];
2797 assert!(
2798 method_body.contains("RETURNING *"),
2799 "Expected RETURNING * in multi-row SQL:\n{}",
2800 method_body
2801 );
2802 assert!(
2803 method_body.contains("values_parts"),
2804 "Expected multi-row VALUES building:\n{}",
2805 method_body
2806 );
2807 assert!(
2808 method_body.contains("65535"),
2809 "Expected chunk size limit:\n{}",
2810 method_body
2811 );
2812 }
2813
2814 #[test]
2815 fn test_insert_many_transactionally_multi_row_sqlite() {
2816 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
2817 let method_start = code
2818 .find("fn insert_many_transactionally")
2819 .expect("not found");
2820 let method_body = &code[method_start..];
2821 assert!(
2822 method_body.contains("values_parts"),
2823 "Expected multi-row VALUES building for SQLite:\n{}",
2824 method_body
2825 );
2826 assert!(
2827 method_body.contains("RETURNING *"),
2828 "Expected RETURNING * for SQLite:\n{}",
2829 method_body
2830 );
2831 }
2832
2833 #[test]
2834 fn test_insert_many_transactionally_mysql_individual_inserts() {
2835 let code = gen(&standard_entity(), DatabaseKind::Mysql);
2836 let method_start = code
2837 .find("fn insert_many_transactionally")
2838 .expect("not found");
2839 let method_body = &code[method_start..];
2840 assert!(
2841 method_body.contains("LAST_INSERT_ID"),
2842 "Expected LAST_INSERT_ID for MySQL:\n{}",
2843 method_body
2844 );
2845 assert!(
2846 method_body.contains("self.pool.begin()"),
2847 "Expected begin() for MySQL:\n{}",
2848 method_body
2849 );
2850 }
2851
2852 #[test]
2853 fn test_insert_many_transactionally_view_not_generated() {
2854 let mut entity = standard_entity();
2855 entity.is_view = true;
2856 let code = gen(&entity, DatabaseKind::Postgres);
2857 assert!(
2858 !code.contains("pub async fn insert_many_transactionally"),
2859 "should not be generated for views"
2860 );
2861 }
2862
2863 #[test]
2864 fn test_insert_many_transactionally_without_method_not_generated() {
2865 let m = Methods {
2866 insert_many: false,
2867 ..Methods::all()
2868 };
2869 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2870 assert!(
2871 !code.contains("pub async fn insert_many_transactionally"),
2872 "should not be generated when disabled"
2873 );
2874 }
2875
2876 #[test]
2877 fn test_insert_many_transactionally_generates_params_when_insert_disabled() {
2878 let m = Methods {
2879 insert: false,
2880 insert_many: true,
2881 ..Default::default()
2882 };
2883 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2884 assert!(
2885 code.contains("pub struct InsertUsersParams"),
2886 "Expected InsertUsersParams:\n{}",
2887 code
2888 );
2889 assert!(
2890 code.contains("pub async fn insert_many_transactionally"),
2891 "Expected method:\n{}",
2892 code
2893 );
2894 assert!(
2895 !code.contains("pub async fn insert("),
2896 "insert should not be present:\n{}",
2897 code
2898 );
2899 }
2900
2901 #[test]
2902 fn test_insert_many_transactionally_with_column_defaults_coalesce() {
2903 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
2904 let method_start = code
2905 .find("fn insert_many_transactionally")
2906 .expect("not found");
2907 let method_body = &code[method_start..];
2908 assert!(
2909 method_body.contains("COALESCE"),
2910 "Expected COALESCE for fields with defaults:\n{}",
2911 method_body
2912 );
2913 }
2914
2915 #[test]
2916 fn test_insert_many_transactionally_junction_table() {
2917 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2918 assert!(
2919 code.contains("pub async fn insert_many_transactionally"),
2920 "Expected method for junction table:\n{}",
2921 code
2922 );
2923 }
2924
2925 #[test]
2926 fn test_insert_many_transactionally_all_three_backends_compile() {
2927 for db in [
2928 DatabaseKind::Postgres,
2929 DatabaseKind::Mysql,
2930 DatabaseKind::Sqlite,
2931 ] {
2932 let code = gen(&standard_entity(), db);
2933 assert!(
2934 code.contains("pub async fn insert_many_transactionally"),
2935 "Expected method for {:?}",
2936 db
2937 );
2938 }
2939 }
2940}