1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8use crate::codegen::identifiers::{quote_ident, quote_qualified};
9
10pub fn generate_crud_from_parsed(
11 entity: &ParsedEntity,
12 db_kind: DatabaseKind,
13 entity_module_path: &str,
14 methods: &Methods,
15 query_macro: bool,
16 pool_visibility: PoolVisibility,
17) -> (TokenStream, BTreeSet<String>) {
18 let mut imports = BTreeSet::new();
19
20 let entity_ident = format_ident!("{}", entity.struct_name);
21 let repo_name = format!("{}Repository", entity.struct_name);
22 let repo_ident = format_ident!("{}", repo_name);
23
24 let schema_for_sql = entity
29 .schema_name
30 .as_deref()
31 .filter(|s| !crate::codegen::is_default_schema(s));
32 let table_name = quote_qualified(schema_for_sql, &entity.table_name, db_kind);
33
34 let pool_type = pool_type_tokens(db_kind);
36
37 let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
41 let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
42
43 imports.insert(format!(
45 "use {}::{};",
46 entity_module_path, entity.struct_name
47 ));
48
49 let entity_parent = entity_module_path
53 .rsplit_once("::")
54 .map(|(parent, _)| parent)
55 .unwrap_or(entity_module_path);
56 for imp in &entity.imports {
57 if let Some(rest) = imp.strip_prefix("use super::") {
58 imports.insert(format!("use {}::{}", entity_parent, rest));
59 } else {
60 imports.insert(imp.clone());
61 }
62 }
63
64 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
66
67 let non_pk_fields: Vec<&ParsedField> =
69 entity.fields.iter().filter(|f| !f.is_primary_key).collect();
70
71 let is_view = entity.is_view;
72
73 let mut method_tokens = Vec::new();
75 let mut param_structs = Vec::new();
76
77 if methods.get_all {
79 let sql = raw_sql_lit(&format!("SELECT * FROM {}", table_name));
80 let method = if use_macro {
81 quote! {
82 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
83 sqlx::query_as!(#entity_ident, #sql)
84 .fetch_all(&self.pool)
85 .await
86 }
87 }
88 } else {
89 quote! {
90 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
91 sqlx::query_as::<_, #entity_ident>(#sql)
92 .fetch_all(&self.pool)
93 .await
94 }
95 }
96 };
97 method_tokens.push(method);
98 }
99
100 if methods.paginate {
102 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
103 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
104 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
105 let count_sql = raw_sql_lit(&format!("SELECT COUNT(*) FROM {}", table_name));
106 let sql = raw_sql_lit(&match db_kind {
107 DatabaseKind::Postgres => format!("SELECT *\nFROM {}\nLIMIT $1 OFFSET $2", table_name),
108 DatabaseKind::Mysql | DatabaseKind::Sqlite => {
109 format!("SELECT *\nFROM {}\nLIMIT ? OFFSET ?", table_name)
110 }
111 });
112 let method = if use_macro {
113 quote! {
114 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
115 let total: i64 = sqlx::query_scalar!(#count_sql)
116 .fetch_one(&self.pool)
117 .await?
118 .unwrap_or(0);
119 let per_page = params.per_page;
120 let current_page = params.page;
121 let last_page = (total + per_page - 1) / per_page;
122 let offset = (current_page - 1) * per_page;
123 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
124 .fetch_all(&self.pool)
125 .await?;
126 Ok(#paginated_ident {
127 meta: #pagination_meta_ident {
128 total,
129 per_page,
130 current_page,
131 last_page,
132 first_page: 1,
133 },
134 data,
135 })
136 }
137 }
138 } else {
139 quote! {
140 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
141 let total: i64 = sqlx::query_scalar(#count_sql)
142 .fetch_one(&self.pool)
143 .await?;
144 let per_page = params.per_page;
145 let current_page = params.page;
146 let last_page = (total + per_page - 1) / per_page;
147 let offset = (current_page - 1) * per_page;
148 let data = sqlx::query_as::<_, #entity_ident>(#sql)
149 .bind(per_page)
150 .bind(offset)
151 .fetch_all(&self.pool)
152 .await?;
153 Ok(#paginated_ident {
154 meta: #pagination_meta_ident {
155 total,
156 per_page,
157 current_page,
158 last_page,
159 first_page: 1,
160 },
161 data,
162 })
163 }
164 }
165 };
166 method_tokens.push(method);
167 param_structs.push(quote! {
168 #[derive(Debug, Clone, Default)]
169 pub struct #paginate_params_ident {
170 pub page: i64,
171 pub per_page: i64,
172 }
173 });
174 param_structs.push(quote! {
175 #[derive(Debug, Clone)]
176 pub struct #pagination_meta_ident {
177 pub total: i64,
178 pub per_page: i64,
179 pub current_page: i64,
180 pub last_page: i64,
181 pub first_page: i64,
182 }
183 });
184 param_structs.push(quote! {
185 #[derive(Debug, Clone)]
186 pub struct #paginated_ident {
187 pub meta: #pagination_meta_ident,
188 pub data: Vec<#entity_ident>,
189 }
190 });
191 }
192
193 if methods.get && !pk_fields.is_empty() {
195 let pk_params: Vec<TokenStream> = pk_fields
196 .iter()
197 .map(|f| {
198 let name = format_ident!("{}", f.rust_name);
199 let ty: TokenStream = f.inner_type.parse().unwrap();
200 quote! { #name: #ty }
201 })
202 .collect();
203
204 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
205 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
206 let sql = raw_sql_lit(&format!(
207 "SELECT *\nFROM {}\nWHERE {}",
208 table_name, where_clause
209 ));
210 let sql_macro = raw_sql_lit(&format!(
211 "SELECT *\nFROM {}\nWHERE {}",
212 table_name, where_clause_cast
213 ));
214
215 let binds: Vec<TokenStream> = pk_fields
216 .iter()
217 .map(|f| {
218 let name = format_ident!("{}", f.rust_name);
219 quote! { .bind(#name) }
220 })
221 .collect();
222
223 let method = if use_macro {
224 let pk_arg_names: Vec<TokenStream> = pk_fields
225 .iter()
226 .map(|f| {
227 let name = format_ident!("{}", f.rust_name);
228 quote! { #name }
229 })
230 .collect();
231 quote! {
232 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
233 sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
234 .fetch_optional(&self.pool)
235 .await
236 }
237 }
238 } else {
239 quote! {
240 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
241 sqlx::query_as::<_, #entity_ident>(#sql)
242 #(#binds)*
243 .fetch_optional(&self.pool)
244 .await
245 }
246 }
247 };
248 method_tokens.push(method);
249 }
250
251 if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
253 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
254
255 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
264 pk_fields.clone()
265 } else if pk_fields.len() > 1 {
266 let mut combined: Vec<&ParsedField> = pk_fields.clone();
267 combined.extend(non_pk_fields.iter().copied());
268 combined
269 } else {
270 non_pk_fields.clone()
271 };
272
273 let insert_fields: Vec<TokenStream> = insert_source_fields
275 .iter()
276 .map(|f| {
277 let name = format_ident!("{}", f.rust_name);
278 if f.column_default.is_some() && !f.is_nullable {
279 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
280 quote! { pub #name: #ty, }
281 } else {
282 let ty: TokenStream = f.rust_type.parse().unwrap();
283 quote! { pub #name: #ty, }
284 }
285 })
286 .collect();
287
288 let col_names: Vec<String> = insert_source_fields
289 .iter()
290 .map(|f| quote_ident(&f.column_name, db_kind))
291 .collect();
292 let col_list = col_names.join(", ");
293
294 let placeholders: String = insert_source_fields
296 .iter()
297 .enumerate()
298 .map(|(i, f)| {
299 let p = placeholder(db_kind, i + 1);
300 match &f.column_default {
301 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
302 None => p,
303 }
304 })
305 .collect::<Vec<_>>()
306 .join(", ");
307
308 let placeholders_cast: String = insert_source_fields
309 .iter()
310 .enumerate()
311 .map(|(i, f)| {
312 let p = placeholder_with_cast(db_kind, i + 1, f);
313 match &f.column_default {
314 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
315 None => p,
316 }
317 })
318 .collect::<Vec<_>>()
319 .join(", ");
320
321 let build_insert_sql = |ph: &str| match db_kind {
322 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
323 format!(
324 "INSERT INTO {} ({})\nVALUES ({})\nRETURNING *",
325 table_name, col_list, ph
326 )
327 }
328 DatabaseKind::Mysql => {
329 format!("INSERT INTO {} ({})\nVALUES ({})", table_name, col_list, ph)
330 }
331 };
332 let sql = build_insert_sql(&placeholders);
333 let sql_macro = build_insert_sql(&placeholders_cast);
334
335 let binds: Vec<TokenStream> = insert_source_fields
336 .iter()
337 .map(|f| {
338 let name = format_ident!("{}", f.rust_name);
339 quote! { .bind(¶ms.#name) }
340 })
341 .collect();
342
343 let insert_method = build_insert_method_parsed(
344 &entity_ident,
345 &insert_params_ident,
346 &sql,
347 &sql_macro,
348 &binds,
349 db_kind,
350 &table_name,
351 &pk_fields,
352 &insert_source_fields,
353 use_macro,
354 );
355 method_tokens.push(insert_method);
356
357 param_structs.push(quote! {
358 #[derive(Debug, Clone, Default)]
359 pub struct #insert_params_ident {
360 #(#insert_fields)*
361 }
362 });
363 }
364
365 if !is_view && methods.insert_many && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
367 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
368
369 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
370 pk_fields.clone()
371 } else {
372 non_pk_fields.clone()
373 };
374
375 let col_names: Vec<String> = insert_source_fields
376 .iter()
377 .map(|f| quote_ident(&f.column_name, db_kind))
378 .collect();
379 let col_list = col_names.join(", ");
380 let num_cols = insert_source_fields.len();
381
382 let binds_loop: Vec<TokenStream> = insert_source_fields
383 .iter()
384 .map(|f| {
385 let name = format_ident!("{}", f.rust_name);
386 quote! { query = query.bind(¶ms.#name); }
387 })
388 .collect();
389
390 let insert_many_method = build_insert_many_transactionally_method(
391 &entity_ident,
392 &insert_params_ident,
393 &col_list,
394 num_cols,
395 &insert_source_fields,
396 &binds_loop,
397 db_kind,
398 &table_name,
399 &pk_fields,
400 );
401 method_tokens.push(insert_many_method);
402
403 if !methods.insert {
405 let insert_fields: Vec<TokenStream> = insert_source_fields
406 .iter()
407 .map(|f| {
408 let name = format_ident!("{}", f.rust_name);
409 if f.column_default.is_some() && !f.is_nullable {
410 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
411 quote! { pub #name: #ty, }
412 } else {
413 let ty: TokenStream = f.rust_type.parse().unwrap();
414 quote! { pub #name: #ty, }
415 }
416 })
417 .collect();
418
419 param_structs.push(quote! {
420 #[derive(Debug, Clone, Default)]
421 pub struct #insert_params_ident {
422 #(#insert_fields)*
423 }
424 });
425 }
426 }
427
428 if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
430 let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
431
432 let pk_fn_params: Vec<TokenStream> = pk_fields
434 .iter()
435 .map(|f| {
436 let name = format_ident!("{}", f.rust_name);
437 let ty: TokenStream = f.inner_type.parse().unwrap();
438 quote! { #name: #ty }
439 })
440 .collect();
441
442 let overwrite_fields: Vec<TokenStream> = non_pk_fields
444 .iter()
445 .map(|f| {
446 let name = format_ident!("{}", f.rust_name);
447 let ty: TokenStream = f.rust_type.parse().unwrap();
448 quote! { pub #name: #ty, }
449 })
450 .collect();
451
452 let set_cols: Vec<String> = non_pk_fields
453 .iter()
454 .enumerate()
455 .map(|(i, f)| {
456 let p = placeholder(db_kind, i + 1);
457 format!("{} = {}", quote_ident(&f.column_name, db_kind), p)
458 })
459 .collect();
460 let set_clause = set_cols.join(",\n ");
461
462 let set_cols_cast: Vec<String> = non_pk_fields
463 .iter()
464 .enumerate()
465 .map(|(i, f)| {
466 let p = placeholder_with_cast(db_kind, i + 1, f);
467 format!("{} = {}", quote_ident(&f.column_name, db_kind), p)
468 })
469 .collect();
470 let set_clause_cast = set_cols_cast.join(",\n ");
471
472 let pk_start = non_pk_fields.len() + 1;
473 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
474 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
475
476 let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
477 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
478 format!(
479 "UPDATE {}\nSET\n {}\nWHERE {}\nRETURNING *",
480 table_name, sc, wc
481 )
482 }
483 DatabaseKind::Mysql => {
484 format!("UPDATE {}\nSET\n {}\nWHERE {}", table_name, sc, wc)
485 }
486 };
487
488 let sql = raw_sql_lit(&build_overwrite_sql(&set_clause, &where_clause));
489 let sql_macro = raw_sql_lit(&build_overwrite_sql(&set_clause_cast, &where_clause_cast));
490
491 let mut all_binds: Vec<TokenStream> = non_pk_fields
493 .iter()
494 .map(|f| {
495 let name = format_ident!("{}", f.rust_name);
496 quote! { .bind(¶ms.#name) }
497 })
498 .collect();
499 for f in &pk_fields {
500 let name = format_ident!("{}", f.rust_name);
501 all_binds.push(quote! { .bind(#name) });
502 }
503
504 let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
506 .iter()
507 .map(|f| macro_arg_for_field(f))
508 .chain(pk_fields.iter().map(|f| {
509 let name = format_ident!("{}", f.rust_name);
510 quote! { #name }
511 }))
512 .collect();
513
514 let overwrite_method = if use_macro {
515 match db_kind {
516 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
517 quote! {
518 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
519 sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
520 .fetch_one(&self.pool)
521 .await
522 }
523 }
524 }
525 DatabaseKind::Mysql => {
526 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
527 let select_sql = raw_sql_lit(&format!(
528 "SELECT *\nFROM {}\nWHERE {}",
529 table_name, pk_where_select
530 ));
531 let pk_macro_args: Vec<TokenStream> = pk_fields
532 .iter()
533 .map(|f| {
534 let name = format_ident!("{}", f.rust_name);
535 quote! { #name }
536 })
537 .collect();
538 quote! {
539 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
540 sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
541 .execute(&self.pool)
542 .await?;
543 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
544 .fetch_one(&self.pool)
545 .await
546 }
547 }
548 }
549 }
550 } else {
551 match db_kind {
552 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
553 quote! {
554 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
555 sqlx::query_as::<_, #entity_ident>(#sql)
556 #(#all_binds)*
557 .fetch_one(&self.pool)
558 .await
559 }
560 }
561 }
562 DatabaseKind::Mysql => {
563 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
564 let select_sql = raw_sql_lit(&format!(
565 "SELECT *\nFROM {}\nWHERE {}",
566 table_name, pk_where_select
567 ));
568 let pk_binds: Vec<TokenStream> = pk_fields
569 .iter()
570 .map(|f| {
571 let name = format_ident!("{}", f.rust_name);
572 quote! { .bind(#name) }
573 })
574 .collect();
575 quote! {
576 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
577 sqlx::query(#sql)
578 #(#all_binds)*
579 .execute(&self.pool)
580 .await?;
581 sqlx::query_as::<_, #entity_ident>(#select_sql)
582 #(#pk_binds)*
583 .fetch_one(&self.pool)
584 .await
585 }
586 }
587 }
588 }
589 };
590 method_tokens.push(overwrite_method);
591
592 param_structs.push(quote! {
593 #[derive(Debug, Clone, Default)]
594 pub struct #overwrite_params_ident {
595 #(#overwrite_fields)*
596 }
597 });
598 }
599
600 if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
602 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
603
604 let pk_fn_params: Vec<TokenStream> = pk_fields
606 .iter()
607 .map(|f| {
608 let name = format_ident!("{}", f.rust_name);
609 let ty: TokenStream = f.inner_type.parse().unwrap();
610 quote! { #name: #ty }
611 })
612 .collect();
613
614 let update_fields: Vec<TokenStream> = non_pk_fields
616 .iter()
617 .map(|f| {
618 let name = format_ident!("{}", f.rust_name);
619 if f.is_nullable {
620 let ty: TokenStream = f.rust_type.parse().unwrap();
622 quote! { pub #name: #ty, }
623 } else {
624 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
625 quote! { pub #name: #ty, }
626 }
627 })
628 .collect();
629
630 let set_cols: Vec<String> = non_pk_fields
632 .iter()
633 .enumerate()
634 .map(|(i, f)| {
635 let p = placeholder(db_kind, i + 1);
636 let col = quote_ident(&f.column_name, db_kind);
637 format!("{col} = COALESCE({p}, {col})", col = col, p = p)
638 })
639 .collect();
640 let set_clause = set_cols.join(",\n ");
641
642 let set_cols_cast: Vec<String> = non_pk_fields
644 .iter()
645 .enumerate()
646 .map(|(i, f)| {
647 let p = placeholder_with_cast(db_kind, i + 1, f);
648 let col = quote_ident(&f.column_name, db_kind);
649 format!("{col} = COALESCE({p}, {col})", col = col, p = p)
650 })
651 .collect();
652 let set_clause_cast = set_cols_cast.join(",\n ");
653
654 let pk_start = non_pk_fields.len() + 1;
655 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
656 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
657
658 let build_update_sql = |sc: &str, wc: &str| match db_kind {
659 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
660 format!(
661 "UPDATE {}\nSET\n {}\nWHERE {}\nRETURNING *",
662 table_name, sc, wc
663 )
664 }
665 DatabaseKind::Mysql => {
666 format!("UPDATE {}\nSET\n {}\nWHERE {}", table_name, sc, wc)
667 }
668 };
669 let sql = raw_sql_lit(&build_update_sql(&set_clause, &where_clause));
670 let sql_macro = raw_sql_lit(&build_update_sql(&set_clause_cast, &where_clause_cast));
671
672 let mut all_binds: Vec<TokenStream> = non_pk_fields
674 .iter()
675 .map(|f| {
676 let name = format_ident!("{}", f.rust_name);
677 quote! { .bind(¶ms.#name) }
678 })
679 .collect();
680 for f in &pk_fields {
681 let name = format_ident!("{}", f.rust_name);
682 all_binds.push(quote! { .bind(#name) });
683 }
684
685 let update_macro_args: Vec<TokenStream> = non_pk_fields
687 .iter()
688 .map(|f| macro_arg_for_field(f))
689 .chain(pk_fields.iter().map(|f| {
690 let name = format_ident!("{}", f.rust_name);
691 quote! { #name }
692 }))
693 .collect();
694
695 let update_method = if use_macro {
696 match db_kind {
697 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
698 quote! {
699 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
700 sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
701 .fetch_one(&self.pool)
702 .await
703 }
704 }
705 }
706 DatabaseKind::Mysql => {
707 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
708 let select_sql = raw_sql_lit(&format!(
709 "SELECT *\nFROM {}\nWHERE {}",
710 table_name, pk_where_select
711 ));
712 let pk_macro_args: Vec<TokenStream> = pk_fields
713 .iter()
714 .map(|f| {
715 let name = format_ident!("{}", f.rust_name);
716 quote! { #name }
717 })
718 .collect();
719 quote! {
720 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
721 sqlx::query!(#sql_macro, #(#update_macro_args),*)
722 .execute(&self.pool)
723 .await?;
724 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
725 .fetch_one(&self.pool)
726 .await
727 }
728 }
729 }
730 }
731 } else {
732 match db_kind {
733 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
734 quote! {
735 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
736 sqlx::query_as::<_, #entity_ident>(#sql)
737 #(#all_binds)*
738 .fetch_one(&self.pool)
739 .await
740 }
741 }
742 }
743 DatabaseKind::Mysql => {
744 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
745 let select_sql = raw_sql_lit(&format!(
746 "SELECT *\nFROM {}\nWHERE {}",
747 table_name, pk_where_select
748 ));
749 let pk_binds: Vec<TokenStream> = pk_fields
750 .iter()
751 .map(|f| {
752 let name = format_ident!("{}", f.rust_name);
753 quote! { .bind(#name) }
754 })
755 .collect();
756 quote! {
757 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
758 sqlx::query(#sql)
759 #(#all_binds)*
760 .execute(&self.pool)
761 .await?;
762 sqlx::query_as::<_, #entity_ident>(#select_sql)
763 #(#pk_binds)*
764 .fetch_one(&self.pool)
765 .await
766 }
767 }
768 }
769 }
770 };
771 method_tokens.push(update_method);
772
773 param_structs.push(quote! {
774 #[derive(Debug, Clone, Default)]
775 pub struct #update_params_ident {
776 #(#update_fields)*
777 }
778 });
779 }
780
781 if !is_view && methods.delete && !pk_fields.is_empty() {
783 let pk_params: Vec<TokenStream> = pk_fields
784 .iter()
785 .map(|f| {
786 let name = format_ident!("{}", f.rust_name);
787 let ty: TokenStream = f.inner_type.parse().unwrap();
788 quote! { #name: #ty }
789 })
790 .collect();
791
792 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
793 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
794 let sql = raw_sql_lit(&format!(
795 "DELETE FROM {}\nWHERE {}",
796 table_name, where_clause
797 ));
798 let sql_macro = raw_sql_lit(&format!(
799 "DELETE FROM {}\nWHERE {}",
800 table_name, where_clause_cast
801 ));
802
803 let binds: Vec<TokenStream> = pk_fields
804 .iter()
805 .map(|f| {
806 let name = format_ident!("{}", f.rust_name);
807 quote! { .bind(#name) }
808 })
809 .collect();
810
811 let method = if query_macro {
812 let pk_arg_names: Vec<TokenStream> = pk_fields
813 .iter()
814 .map(|f| {
815 let name = format_ident!("{}", f.rust_name);
816 quote! { #name }
817 })
818 .collect();
819 quote! {
820 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
821 sqlx::query!(#sql_macro, #(#pk_arg_names),*)
822 .execute(&self.pool)
823 .await?;
824 Ok(())
825 }
826 }
827 } else {
828 quote! {
829 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
830 sqlx::query(#sql)
831 #(#binds)*
832 .execute(&self.pool)
833 .await?;
834 Ok(())
835 }
836 }
837 };
838 method_tokens.push(method);
839 }
840
841 let pool_vis: TokenStream = match pool_visibility {
842 PoolVisibility::Private => quote! {},
843 PoolVisibility::Pub => quote! { pub },
844 PoolVisibility::PubCrate => quote! { pub(crate) },
845 };
846
847 let tokens = quote! {
848 #(#param_structs)*
849
850 pub struct #repo_ident {
851 #pool_vis pool: #pool_type,
852 }
853
854 impl #repo_ident {
855 pub fn new(pool: #pool_type) -> Self {
856 Self { pool }
857 }
858
859 #(#method_tokens)*
860 }
861 };
862
863 (tokens, imports)
864}
865
866fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
867 match db_kind {
868 DatabaseKind::Postgres => quote! { sqlx::PgPool },
869 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
870 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
871 }
872}
873
874fn raw_sql_lit(s: &str) -> TokenStream {
878 let mut hashes = 1usize;
882 while s.contains(&format!("\"{}", "#".repeat(hashes))) {
883 hashes += 1;
884 }
885 let fence = "#".repeat(hashes);
886 let body = if s.contains('\n') {
887 format!("\n{}\n", s)
888 } else {
889 s.to_string()
890 };
891 format!("r{fence}\"{body}\"{fence}", fence = fence, body = body)
892 .parse()
893 .expect("raw_sql_lit must produce a valid Rust raw string literal")
894}
895
896fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
897 match db_kind {
898 DatabaseKind::Postgres => format!("${}", index),
899 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
900 }
901}
902
903fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
904 let base = placeholder(db_kind, index);
905 match (&field.sql_type, field.is_sql_array) {
906 (Some(t), true) => format!("{} as {}[]", base, t),
907 (Some(t), false) => format!("{} as {}", base, t),
908 (None, _) => base,
909 }
910}
911
912fn build_where_clause_parsed(
913 pk_fields: &[&ParsedField],
914 db_kind: DatabaseKind,
915 start_index: usize,
916) -> String {
917 pk_fields
918 .iter()
919 .enumerate()
920 .map(|(i, f)| {
921 let p = placeholder(db_kind, start_index + i);
922 format!("{} = {}", quote_ident(&f.column_name, db_kind), p)
923 })
924 .collect::<Vec<_>>()
925 .join(" AND ")
926}
927
928fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
929 let name = format_ident!("{}", field.rust_name);
930 let check_type = if field.is_nullable {
931 &field.inner_type
932 } else {
933 &field.rust_type
934 };
935 let normalized = check_type.replace(' ', "");
936 if normalized.starts_with("Vec<") {
937 quote! { params.#name.as_slice() }
938 } else {
939 quote! { params.#name }
940 }
941}
942
943fn build_where_clause_cast(
944 pk_fields: &[&ParsedField],
945 db_kind: DatabaseKind,
946 start_index: usize,
947) -> String {
948 pk_fields
949 .iter()
950 .enumerate()
951 .map(|(i, f)| {
952 let p = placeholder_with_cast(db_kind, start_index + i, f);
953 format!("{} = {}", quote_ident(&f.column_name, db_kind), p)
954 })
955 .collect::<Vec<_>>()
956 .join(" AND ")
957}
958
959#[allow(clippy::too_many_arguments)]
960fn build_insert_method_parsed(
961 entity_ident: &proc_macro2::Ident,
962 insert_params_ident: &proc_macro2::Ident,
963 sql: &str,
964 sql_macro: &str,
965 binds: &[TokenStream],
966 db_kind: DatabaseKind,
967 table_name: &str,
968 pk_fields: &[&ParsedField],
969 non_pk_fields: &[&ParsedField],
970 use_macro: bool,
971) -> TokenStream {
972 let sql = raw_sql_lit(sql);
973 let sql_macro = raw_sql_lit(sql_macro);
974
975 if use_macro {
976 let macro_args: Vec<TokenStream> = non_pk_fields
977 .iter()
978 .map(|f| macro_arg_for_field(f))
979 .collect();
980
981 match db_kind {
982 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
983 quote! {
984 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
985 sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
986 .fetch_one(&self.pool)
987 .await
988 }
989 }
990 }
991 DatabaseKind::Mysql => {
992 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
993 let select_sql = raw_sql_lit(&format!(
994 "SELECT *\nFROM {}\nWHERE {}",
995 table_name, pk_where
996 ));
997 if pk_fields.len() > 1 {
998 let pk_macro_args: Vec<TokenStream> = pk_fields
1003 .iter()
1004 .map(|f| {
1005 let name = format_ident!("{}", f.rust_name);
1006 quote! { params.#name }
1007 })
1008 .collect();
1009 quote! {
1010 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1011 sqlx::query!(#sql_macro, #(#macro_args),*)
1012 .execute(&self.pool)
1013 .await?;
1014 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
1015 .fetch_one(&self.pool)
1016 .await
1017 }
1018 }
1019 } else {
1020 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID() as id");
1021 quote! {
1022 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1023 sqlx::query!(#sql_macro, #(#macro_args),*)
1024 .execute(&self.pool)
1025 .await?;
1026 let id = sqlx::query_scalar!(#last_insert_id_sql)
1027 .fetch_one(&self.pool)
1028 .await?;
1029 sqlx::query_as!(#entity_ident, #select_sql, id)
1030 .fetch_one(&self.pool)
1031 .await
1032 }
1033 }
1034 }
1035 }
1036 }
1037 } else {
1038 match db_kind {
1039 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
1040 quote! {
1041 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1042 sqlx::query_as::<_, #entity_ident>(#sql)
1043 #(#binds)*
1044 .fetch_one(&self.pool)
1045 .await
1046 }
1047 }
1048 }
1049 DatabaseKind::Mysql => {
1050 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
1051 let select_sql = raw_sql_lit(&format!(
1052 "SELECT *\nFROM {}\nWHERE {}",
1053 table_name, pk_where
1054 ));
1055 if pk_fields.len() > 1 {
1056 let pk_binds: Vec<TokenStream> = pk_fields
1057 .iter()
1058 .map(|f| {
1059 let name = format_ident!("{}", f.rust_name);
1060 quote! { .bind(¶ms.#name) }
1061 })
1062 .collect();
1063 quote! {
1064 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1065 sqlx::query(#sql)
1066 #(#binds)*
1067 .execute(&self.pool)
1068 .await?;
1069 sqlx::query_as::<_, #entity_ident>(#select_sql)
1070 #(#pk_binds)*
1071 .fetch_one(&self.pool)
1072 .await
1073 }
1074 }
1075 } else {
1076 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1077 quote! {
1078 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1079 sqlx::query(#sql)
1080 #(#binds)*
1081 .execute(&self.pool)
1082 .await?;
1083 let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1084 .fetch_one(&self.pool)
1085 .await?;
1086 sqlx::query_as::<_, #entity_ident>(#select_sql)
1087 .bind(id)
1088 .fetch_one(&self.pool)
1089 .await
1090 }
1091 }
1092 }
1093 }
1094 }
1095 }
1096}
1097
1098#[allow(clippy::too_many_arguments)]
1099fn build_insert_many_transactionally_method(
1100 entity_ident: &proc_macro2::Ident,
1101 insert_params_ident: &proc_macro2::Ident,
1102 col_list: &str,
1103 num_cols: usize,
1104 insert_source_fields: &[&ParsedField],
1105 binds_loop: &[TokenStream],
1106 db_kind: DatabaseKind,
1107 table_name: &str,
1108 pk_fields: &[&ParsedField],
1109) -> TokenStream {
1110 let body = match db_kind {
1111 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
1112 let col_list_str = col_list.to_string();
1113 let table_name_str = table_name.to_string();
1114
1115 let row_placeholder_exprs: Vec<TokenStream> = insert_source_fields
1116 .iter()
1117 .enumerate()
1118 .map(|(i, f)| {
1119 let offset = i;
1120 match &f.column_default {
1121 Some(default_expr) => {
1122 let def = default_expr.as_str();
1123 match db_kind {
1124 DatabaseKind::Postgres => quote! {
1125 format!("COALESCE(${}, {})", base + #offset + 1, #def)
1126 },
1127 _ => quote! {
1128 format!("COALESCE(?, {})", #def)
1129 },
1130 }
1131 }
1132 None => match db_kind {
1133 DatabaseKind::Postgres => quote! {
1134 format!("${}", base + #offset + 1)
1135 },
1136 _ => quote! {
1137 "?".to_string()
1138 },
1139 },
1140 }
1141 })
1142 .collect();
1143
1144 quote! {
1145 let mut tx = self.pool.begin().await?;
1146 let mut all_results = Vec::with_capacity(entries.len());
1147 let max_per_chunk = 65535 / #num_cols;
1148 for chunk in entries.chunks(max_per_chunk) {
1149 let mut values_parts = Vec::with_capacity(chunk.len());
1150 for (row_idx, _) in chunk.iter().enumerate() {
1151 let base = row_idx * #num_cols;
1152 let placeholders = vec![#(#row_placeholder_exprs),*];
1153 values_parts.push(format!("({})", placeholders.join(", ")));
1154 }
1155 let sql = format!(
1156 "INSERT INTO {} ({})\nVALUES {}\nRETURNING *",
1157 #table_name_str,
1158 #col_list_str,
1159 values_parts.join(", ")
1160 );
1161 let mut query = sqlx::query_as::<_, #entity_ident>(&sql);
1162 for params in chunk {
1163 #(#binds_loop)*
1164 }
1165 let rows = query.fetch_all(&mut *tx).await?;
1166 all_results.extend(rows);
1167 }
1168 tx.commit().await?;
1169 Ok(all_results)
1170 }
1171 }
1172 DatabaseKind::Mysql => {
1173 let single_placeholders: String = insert_source_fields
1174 .iter()
1175 .enumerate()
1176 .map(|(i, f)| {
1177 let p = placeholder(db_kind, i + 1);
1178 match &f.column_default {
1179 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
1180 None => p,
1181 }
1182 })
1183 .collect::<Vec<_>>()
1184 .join(", ");
1185
1186 let single_insert_sql = raw_sql_lit(&format!(
1187 "INSERT INTO {} ({})\nVALUES ({})",
1188 table_name, col_list, single_placeholders
1189 ));
1190
1191 let single_binds: Vec<TokenStream> = insert_source_fields
1192 .iter()
1193 .map(|f| {
1194 let name = format_ident!("{}", f.rust_name);
1195 quote! { .bind(¶ms.#name) }
1196 })
1197 .collect();
1198
1199 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
1200 let select_sql = raw_sql_lit(&format!(
1201 "SELECT *\nFROM {}\nWHERE {}",
1202 table_name, pk_where
1203 ));
1204
1205 if pk_fields.len() > 1 {
1206 let pk_binds: Vec<TokenStream> = pk_fields
1207 .iter()
1208 .map(|f| {
1209 let name = format_ident!("{}", f.rust_name);
1210 quote! { .bind(¶ms.#name) }
1211 })
1212 .collect();
1213 quote! {
1214 let mut tx = self.pool.begin().await?;
1215 let mut results = Vec::with_capacity(entries.len());
1216 for params in &entries {
1217 sqlx::query(#single_insert_sql)
1218 #(#single_binds)*
1219 .execute(&mut *tx)
1220 .await?;
1221 let row = sqlx::query_as::<_, #entity_ident>(#select_sql)
1222 #(#pk_binds)*
1223 .fetch_one(&mut *tx)
1224 .await?;
1225 results.push(row);
1226 }
1227 tx.commit().await?;
1228 Ok(results)
1229 }
1230 } else {
1231 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1232 quote! {
1233 let mut tx = self.pool.begin().await?;
1234 let mut results = Vec::with_capacity(entries.len());
1235 for params in &entries {
1236 sqlx::query(#single_insert_sql)
1237 #(#single_binds)*
1238 .execute(&mut *tx)
1239 .await?;
1240 let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1241 .fetch_one(&mut *tx)
1242 .await?;
1243 let row = sqlx::query_as::<_, #entity_ident>(#select_sql)
1244 .bind(id)
1245 .fetch_one(&mut *tx)
1246 .await?;
1247 results.push(row);
1248 }
1249 tx.commit().await?;
1250 Ok(results)
1251 }
1252 }
1253 }
1254 };
1255
1256 quote! {
1257 pub async fn insert_many_transactionally(
1258 &self,
1259 entries: Vec<#insert_params_ident>,
1260 ) -> Result<Vec<#entity_ident>, sqlx::Error> {
1261 if entries.is_empty() {
1265 return Ok(Vec::new());
1266 }
1267 #body
1268 }
1269 }
1270}
1271
1272#[cfg(test)]
1273mod tests {
1274 use super::*;
1275 use crate::cli::Methods;
1276 use crate::codegen::parse_and_format;
1277 use crate::codegen::parse_and_format_with_tab_spaces;
1278
1279 fn make_field(
1280 rust_name: &str,
1281 column_name: &str,
1282 rust_type: &str,
1283 nullable: bool,
1284 is_pk: bool,
1285 ) -> ParsedField {
1286 let inner_type = if nullable {
1287 rust_type
1289 .strip_prefix("Option<")
1290 .and_then(|s| s.strip_suffix('>'))
1291 .unwrap_or(rust_type)
1292 .to_string()
1293 } else {
1294 rust_type.to_string()
1295 };
1296 ParsedField {
1297 rust_name: rust_name.to_string(),
1298 column_name: column_name.to_string(),
1299 rust_type: rust_type.to_string(),
1300 is_nullable: nullable,
1301 inner_type,
1302 is_primary_key: is_pk,
1303 sql_type: None,
1304 is_sql_array: false,
1305 column_default: None,
1306 }
1307 }
1308
1309 fn make_field_with_default(
1310 rust_name: &str,
1311 column_name: &str,
1312 rust_type: &str,
1313 nullable: bool,
1314 is_pk: bool,
1315 default: &str,
1316 ) -> ParsedField {
1317 let mut f = make_field(rust_name, column_name, rust_type, nullable, is_pk);
1318 f.column_default = Some(default.to_string());
1319 f
1320 }
1321
1322 fn entity_with_defaults() -> ParsedEntity {
1323 ParsedEntity {
1324 struct_name: "Tasks".to_string(),
1325 table_name: "tasks".to_string(),
1326 schema_name: None,
1327 is_view: false,
1328 fields: vec![
1329 make_field("id", "id", "i32", false, true),
1330 make_field("title", "title", "String", false, false),
1331 make_field_with_default(
1332 "status",
1333 "status",
1334 "String",
1335 false,
1336 false,
1337 "'idle'::task_status",
1338 ),
1339 make_field_with_default("priority", "priority", "i32", false, false, "0"),
1340 make_field_with_default(
1341 "created_at",
1342 "created_at",
1343 "DateTime<Utc>",
1344 false,
1345 false,
1346 "now()",
1347 ),
1348 make_field("description", "description", "Option<String>", true, false),
1349 make_field_with_default(
1350 "deleted_at",
1351 "deleted_at",
1352 "Option<DateTime<Utc>>",
1353 true,
1354 false,
1355 "NULL",
1356 ),
1357 ],
1358 imports: vec![],
1359 }
1360 }
1361
1362 fn standard_entity() -> ParsedEntity {
1363 ParsedEntity {
1364 struct_name: "Users".to_string(),
1365 table_name: "users".to_string(),
1366 schema_name: None,
1367 is_view: false,
1368 fields: vec![
1369 make_field("id", "id", "i32", false, true),
1370 make_field("name", "name", "String", false, false),
1371 make_field("email", "email", "Option<String>", true, false),
1372 ],
1373 imports: vec![],
1374 }
1375 }
1376
1377 fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
1378 let skip = Methods::all();
1379 let (tokens, _) = generate_crud_from_parsed(
1380 entity,
1381 db,
1382 "crate::models::users",
1383 &skip,
1384 false,
1385 PoolVisibility::Private,
1386 );
1387 parse_and_format(&tokens).unwrap()
1388 }
1389
1390 fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
1391 let skip = Methods::all();
1392 let (tokens, _) = generate_crud_from_parsed(
1393 entity,
1394 db,
1395 "crate::models::users",
1396 &skip,
1397 true,
1398 PoolVisibility::Private,
1399 );
1400 parse_and_format(&tokens).unwrap()
1401 }
1402
1403 fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
1404 let (tokens, _) = generate_crud_from_parsed(
1405 entity,
1406 db,
1407 "crate::models::users",
1408 methods,
1409 false,
1410 PoolVisibility::Private,
1411 );
1412 parse_and_format(&tokens).unwrap()
1413 }
1414
1415 fn gen_with_tab_spaces(entity: &ParsedEntity, db: DatabaseKind, tab_spaces: usize) -> String {
1416 let skip = Methods::all();
1417 let (tokens, _) = generate_crud_from_parsed(
1418 entity,
1419 db,
1420 "crate::models::users",
1421 &skip,
1422 false,
1423 PoolVisibility::Private,
1424 );
1425 parse_and_format_with_tab_spaces(&tokens, tab_spaces).unwrap()
1426 }
1427
1428 #[test]
1431 fn test_repo_struct_name() {
1432 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1433 assert!(code.contains("pub struct UsersRepository"));
1434 }
1435
1436 #[test]
1437 fn test_repo_new_method() {
1438 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1439 assert!(code.contains("pub fn new("));
1440 }
1441
1442 #[test]
1443 fn test_repo_pool_field_pg() {
1444 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1445 assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
1446 }
1447
1448 #[test]
1449 fn test_repo_pool_field_pub() {
1450 let skip = Methods::all();
1451 let (tokens, _) = generate_crud_from_parsed(
1452 &standard_entity(),
1453 DatabaseKind::Postgres,
1454 "crate::models::users",
1455 &skip,
1456 false,
1457 PoolVisibility::Pub,
1458 );
1459 let code = parse_and_format(&tokens).unwrap();
1460 assert!(
1461 code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool")
1462 );
1463 }
1464
1465 #[test]
1466 fn test_repo_pool_field_pub_crate() {
1467 let skip = Methods::all();
1468 let (tokens, _) = generate_crud_from_parsed(
1469 &standard_entity(),
1470 DatabaseKind::Postgres,
1471 "crate::models::users",
1472 &skip,
1473 false,
1474 PoolVisibility::PubCrate,
1475 );
1476 let code = parse_and_format(&tokens).unwrap();
1477 assert!(
1478 code.contains("pub(crate) pool: sqlx::PgPool")
1479 || code.contains("pub(crate) pool: sqlx :: PgPool")
1480 );
1481 }
1482
1483 #[test]
1484 fn test_repo_pool_field_private() {
1485 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1486 assert!(!code.contains("pub pool"));
1488 assert!(!code.contains("pub(crate) pool"));
1489 }
1490
1491 #[test]
1492 fn test_repo_pool_field_mysql() {
1493 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1494 assert!(code.contains("MySqlPool") || code.contains("MySql"));
1495 }
1496
1497 #[test]
1498 fn test_repo_pool_field_sqlite() {
1499 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1500 assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
1501 }
1502
1503 #[test]
1506 fn test_get_all_method() {
1507 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1508 assert!(code.contains("pub async fn get_all"));
1509 }
1510
1511 #[test]
1512 fn test_get_all_returns_vec() {
1513 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1514 assert!(code.contains("Vec<Users>"));
1515 }
1516
1517 #[test]
1518 fn test_get_all_sql() {
1519 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1520 assert!(code.contains("SELECT * FROM users"));
1521 }
1522
1523 #[test]
1526 fn test_paginate_method() {
1527 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1528 assert!(code.contains("pub async fn paginate"));
1529 }
1530
1531 #[test]
1532 fn test_paginate_params_struct() {
1533 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1534 assert!(code.contains("pub struct PaginateUsersParams"));
1535 }
1536
1537 #[test]
1538 fn test_paginate_params_fields() {
1539 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1540 assert!(code.contains("pub page: i64"));
1541 assert!(code.contains("pub per_page: i64"));
1542 }
1543
1544 #[test]
1545 fn test_paginate_returns_paginated() {
1546 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1547 assert!(code.contains("PaginatedUsers"));
1548 assert!(code.contains("PaginationUsersMeta"));
1549 }
1550
1551 #[test]
1552 fn test_paginate_meta_struct() {
1553 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1554 assert!(code.contains("pub struct PaginationUsersMeta"));
1555 assert!(code.contains("pub total: i64"));
1556 assert!(code.contains("pub last_page: i64"));
1557 assert!(code.contains("pub first_page: i64"));
1558 assert!(code.contains("pub current_page: i64"));
1559 }
1560
1561 #[test]
1562 fn test_paginate_data_struct() {
1563 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1564 assert!(code.contains("pub struct PaginatedUsers"));
1565 assert!(code.contains("pub meta: PaginationUsersMeta"));
1566 assert!(code.contains("pub data: Vec<Users>"));
1567 }
1568
1569 #[test]
1570 fn test_paginate_count_sql() {
1571 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1572 assert!(code.contains("SELECT COUNT(*) FROM users"));
1573 }
1574
1575 #[test]
1576 fn test_paginate_sql_pg() {
1577 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1578 assert!(code.contains("LIMIT $1 OFFSET $2"));
1579 }
1580
1581 #[test]
1582 fn test_paginate_sql_mysql() {
1583 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1584 assert!(code.contains("LIMIT ? OFFSET ?"));
1585 }
1586
1587 #[test]
1590 fn test_get_method() {
1591 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1592 assert!(code.contains("pub async fn get"));
1593 }
1594
1595 #[test]
1596 fn test_get_returns_option() {
1597 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1598 assert!(code.contains("Option<Users>"));
1599 }
1600
1601 #[test]
1602 fn test_get_where_pk_pg() {
1603 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1604 assert!(code.contains("WHERE id = $1"));
1605 }
1606
1607 #[test]
1608 fn test_get_where_pk_mysql() {
1609 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1610 assert!(code.contains("WHERE id = ?"));
1611 }
1612
1613 #[test]
1616 fn test_insert_method() {
1617 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1618 assert!(code.contains("pub async fn insert"));
1619 }
1620
1621 #[test]
1622 fn test_insert_params_struct() {
1623 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1624 assert!(code.contains("pub struct InsertUsersParams"));
1625 }
1626
1627 #[test]
1628 fn test_insert_params_no_pk() {
1629 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1630 assert!(code.contains("pub name: String"));
1631 assert!(
1632 code.contains("pub email: Option<String>")
1633 || code.contains("pub email: Option < String >")
1634 );
1635 }
1636
1637 #[test]
1638 fn test_insert_returning_pg() {
1639 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1640 assert!(code.contains("RETURNING *"));
1641 }
1642
1643 #[test]
1644 fn test_insert_returning_sqlite() {
1645 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1646 assert!(code.contains("RETURNING *"));
1647 }
1648
1649 #[test]
1650 fn test_insert_mysql_last_insert_id() {
1651 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1652 assert!(code.contains("LAST_INSERT_ID"));
1653 }
1654
1655 #[test]
1658 fn test_insert_default_col_is_optional() {
1659 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1660 let struct_start = code
1662 .find("pub struct InsertTasksParams")
1663 .expect("InsertTasksParams not found");
1664 let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1665 let struct_body = &code[struct_start..struct_end];
1666 assert!(
1667 struct_body.contains("Option") && struct_body.contains("status"),
1668 "Expected status as Option in InsertTasksParams: {}",
1669 struct_body
1670 );
1671 }
1672
1673 #[test]
1674 fn test_insert_non_default_col_required() {
1675 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1676 let struct_start = code
1678 .find("pub struct InsertTasksParams")
1679 .expect("InsertTasksParams not found");
1680 let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1681 let struct_body = &code[struct_start..struct_end];
1682 assert!(
1683 struct_body.contains("title") && struct_body.contains("String"),
1684 "Expected title as String: {}",
1685 struct_body
1686 );
1687 }
1688
1689 #[test]
1690 fn test_insert_default_col_coalesce_sql() {
1691 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1692 assert!(
1693 code.contains("COALESCE($2, 'idle'::task_status)"),
1694 "Expected COALESCE for status:\n{}",
1695 code
1696 );
1697 assert!(
1698 code.contains("COALESCE($3, 0)"),
1699 "Expected COALESCE for priority:\n{}",
1700 code
1701 );
1702 assert!(
1703 code.contains("COALESCE($4, now())"),
1704 "Expected COALESCE for created_at:\n{}",
1705 code
1706 );
1707 }
1708
1709 #[test]
1710 fn test_insert_default_col_coalesce_json() {
1711 let mut entity = entity_with_defaults();
1712 entity.fields.push(make_field_with_default(
1713 "metadata",
1714 "metadata",
1715 "serde_json::Value",
1716 false,
1717 false,
1718 r#"'{"key": "value"}'::jsonb"#,
1719 ));
1720 let code = gen(&entity, DatabaseKind::Postgres);
1721 assert!(
1722 code.contains(r#"COALESCE($7, '{"key": "value"}'::jsonb)"#),
1723 "Expected COALESCE with JSON default:\n{}",
1724 code
1725 );
1726 }
1727
1728 #[test]
1729 fn test_insert_no_coalesce_for_non_default() {
1730 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1731 assert!(
1733 code.contains("VALUES ($1, COALESCE"),
1734 "Expected $1 without COALESCE for title:\n{}",
1735 code
1736 );
1737 }
1738
1739 #[test]
1740 fn test_insert_nullable_with_default_no_double_option() {
1741 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1742 assert!(
1743 !code.contains("Option < Option") && !code.contains("Option<Option"),
1744 "Should not have Option<Option>:\n{}",
1745 code
1746 );
1747 }
1748
1749 #[test]
1750 fn test_insert_derive_default() {
1751 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1752 let struct_start = code
1753 .find("pub struct InsertTasksParams")
1754 .expect("InsertTasksParams not found");
1755 let before_struct = &code[..struct_start];
1756 assert!(
1757 before_struct.ends_with("Default)]\n") || before_struct.contains("Default)]"),
1758 "Expected #[derive(Default)] on InsertTasksParams"
1759 );
1760 }
1761
1762 #[test]
1765 fn test_update_method() {
1766 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1767 assert!(code.contains("pub async fn update"));
1768 }
1769
1770 #[test]
1771 fn test_update_params_struct() {
1772 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1773 assert!(code.contains("pub struct UpdateUsersParams"));
1774 }
1775
1776 #[test]
1777 fn test_update_pk_in_fn_signature() {
1778 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1779 let update_pos = code.find("fn update").expect("fn update not found");
1781 let params_pos = code[update_pos..]
1782 .find("UpdateUsersParams")
1783 .expect("UpdateUsersParams not found in update fn");
1784 let signature = &code[update_pos..update_pos + params_pos];
1785 assert!(
1786 signature.contains("id"),
1787 "Expected 'id' PK in update fn signature: {}",
1788 signature
1789 );
1790 }
1791
1792 #[test]
1793 fn test_update_pk_not_in_struct() {
1794 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1795 let struct_start = code
1798 .find("pub struct UpdateUsersParams")
1799 .expect("UpdateUsersParams not found");
1800 let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1801 let struct_body = &code[struct_start..struct_end];
1802 assert!(
1803 !struct_body.contains("pub id"),
1804 "PK 'id' should not be in UpdateUsersParams:\n{}",
1805 struct_body
1806 );
1807 }
1808
1809 #[test]
1810 fn test_update_params_non_nullable_wrapped_in_option() {
1811 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1812 assert!(
1814 code.contains("pub name: Option<String>")
1815 || code.contains("pub name : Option < String >")
1816 );
1817 }
1818
1819 #[test]
1820 fn test_update_params_already_nullable_no_double_option() {
1821 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1822 assert!(!code.contains("Option<Option") && !code.contains("Option < Option"));
1824 }
1825
1826 #[test]
1827 fn test_update_set_clause_uses_coalesce_pg() {
1828 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1829 assert!(
1830 code.contains("COALESCE($1, name)"),
1831 "Expected COALESCE for name:\n{}",
1832 code
1833 );
1834 assert!(
1835 code.contains("COALESCE($2, email)"),
1836 "Expected COALESCE for email:\n{}",
1837 code
1838 );
1839 }
1840
1841 #[test]
1842 fn test_update_where_clause_pg() {
1843 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1844 assert!(code.contains("WHERE id = $3"));
1845 }
1846
1847 #[test]
1848 fn test_update_returning_pg() {
1849 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1850 assert!(code.contains("COALESCE"));
1851 assert!(code.contains("RETURNING *"));
1852 }
1853
1854 #[test]
1855 fn test_update_set_clause_mysql() {
1856 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1857 assert!(
1858 code.contains("COALESCE(?, name)"),
1859 "Expected COALESCE for MySQL:\n{}",
1860 code
1861 );
1862 assert!(
1863 code.contains("COALESCE(?, email)"),
1864 "Expected COALESCE for email in MySQL:\n{}",
1865 code
1866 );
1867 }
1868
1869 #[test]
1870 fn test_update_set_clause_sqlite() {
1871 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1872 assert!(
1873 code.contains("COALESCE(?, name)"),
1874 "Expected COALESCE for SQLite:\n{}",
1875 code
1876 );
1877 }
1878
1879 #[test]
1882 fn test_overwrite_method() {
1883 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1884 assert!(code.contains("pub async fn overwrite"));
1885 }
1886
1887 #[test]
1888 fn test_overwrite_params_struct() {
1889 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1890 assert!(code.contains("pub struct OverwriteUsersParams"));
1891 }
1892
1893 #[test]
1894 fn test_overwrite_pk_in_fn_signature() {
1895 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1896 let pos = code.find("fn overwrite").expect("fn overwrite not found");
1897 let params_pos = code[pos..]
1898 .find("OverwriteUsersParams")
1899 .expect("OverwriteUsersParams not found");
1900 let signature = &code[pos..pos + params_pos];
1901 assert!(
1902 signature.contains("id"),
1903 "Expected PK in overwrite fn signature: {}",
1904 signature
1905 );
1906 }
1907
1908 #[test]
1909 fn test_overwrite_pk_not_in_struct() {
1910 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1911 let struct_start = code
1912 .find("pub struct OverwriteUsersParams")
1913 .expect("OverwriteUsersParams not found");
1914 let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1915 let struct_body = &code[struct_start..struct_end];
1916 assert!(
1917 !struct_body.contains("pub id"),
1918 "PK should not be in OverwriteUsersParams: {}",
1919 struct_body
1920 );
1921 }
1922
1923 #[test]
1924 fn test_overwrite_no_coalesce() {
1925 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1926 let pos = code.find("fn overwrite").expect("fn overwrite not found");
1928 let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1929 assert!(
1930 !method_body.contains("COALESCE"),
1931 "Overwrite should not use COALESCE: {}",
1932 method_body
1933 );
1934 }
1935
1936 #[test]
1937 fn test_overwrite_set_clause_pg() {
1938 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1939 assert!(code.contains("name = $1,"));
1940 assert!(code.contains("email = $2"));
1941 assert!(code.contains("WHERE id = $3"));
1942 }
1943
1944 #[test]
1945 fn test_overwrite_returning_pg() {
1946 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1947 let pos = code.find("fn overwrite").expect("fn overwrite not found");
1948 let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1949 assert!(
1950 method_body.contains("RETURNING *"),
1951 "Expected RETURNING * in overwrite"
1952 );
1953 }
1954
1955 #[test]
1956 fn test_view_no_overwrite() {
1957 let mut entity = standard_entity();
1958 entity.is_view = true;
1959 let code = gen(&entity, DatabaseKind::Postgres);
1960 assert!(!code.contains("pub async fn overwrite"));
1961 }
1962
1963 #[test]
1964 fn test_without_overwrite() {
1965 let m = Methods {
1966 overwrite: false,
1967 ..Methods::all()
1968 };
1969 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1970 assert!(!code.contains("pub async fn overwrite"));
1971 assert!(!code.contains("OverwriteUsersParams"));
1972 }
1973
1974 #[test]
1975 fn test_update_and_overwrite_coexist() {
1976 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1977 assert!(
1978 code.contains("pub async fn update"),
1979 "Expected update method"
1980 );
1981 assert!(
1982 code.contains("pub async fn overwrite"),
1983 "Expected overwrite method"
1984 );
1985 assert!(
1986 code.contains("UpdateUsersParams"),
1987 "Expected UpdateUsersParams"
1988 );
1989 assert!(
1990 code.contains("OverwriteUsersParams"),
1991 "Expected OverwriteUsersParams"
1992 );
1993 }
1994
1995 #[test]
1998 fn test_delete_method() {
1999 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2000 assert!(code.contains("pub async fn delete"));
2001 }
2002
2003 #[test]
2004 fn test_delete_where_pk() {
2005 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2006 assert!(code.contains("DELETE FROM users"));
2007 assert!(code.contains("WHERE id = $1"));
2008 }
2009
    /// Formats the generated repository with `tab_spaces = 2` and checks the
    /// indentation of the raw SQL literals: the `SELECT *` line and the closing
    /// `"#` delimiter.
    #[test]
    fn test_tab_spaces_2_sql_indent() {
        let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 2);
        // NOTE(review): the messages describe an 8-space SQL indent and a 6-space
        // closing-delimiter indent; confirm the needles' leading whitespace matches.
        assert!(
            code.contains(" SELECT *"),
            "Expected SQL at 8-space indent:\n{}",
            code
        );
        assert!(
            code.contains(" \"#"),
            "Expected closing tag at 6-space indent:\n{}",
            code
        );
    }

    /// Same check as the 2-space variant, formatted with `tab_spaces = 4`.
    #[test]
    fn test_tab_spaces_4_sql_indent() {
        let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 4);
        // NOTE(review): the messages describe a 12-space SQL indent and an 8-space
        // closing-delimiter indent; confirm the needles' leading whitespace matches.
        assert!(
            code.contains(" SELECT *"),
            "Expected SQL at 12-space indent:\n{}",
            code
        );
        assert!(
            code.contains(" \"#"),
            "Expected closing tag at 8-space indent:\n{}",
            code
        );
    }
2041
2042 #[test]
2043 fn test_delete_returns_unit() {
2044 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2045 assert!(
2046 code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>")
2047 );
2048 }
2049
2050 #[test]
2053 fn test_view_no_insert() {
2054 let mut entity = standard_entity();
2055 entity.is_view = true;
2056 let code = gen(&entity, DatabaseKind::Postgres);
2057 assert!(!code.contains("pub async fn insert"));
2058 }
2059
2060 #[test]
2061 fn test_view_no_update() {
2062 let mut entity = standard_entity();
2063 entity.is_view = true;
2064 let code = gen(&entity, DatabaseKind::Postgres);
2065 assert!(!code.contains("pub async fn update"));
2066 }
2067
2068 #[test]
2069 fn test_view_no_delete() {
2070 let mut entity = standard_entity();
2071 entity.is_view = true;
2072 let code = gen(&entity, DatabaseKind::Postgres);
2073 assert!(!code.contains("pub async fn delete"));
2074 }
2075
2076 #[test]
2077 fn test_view_has_get_all() {
2078 let mut entity = standard_entity();
2079 entity.is_view = true;
2080 let code = gen(&entity, DatabaseKind::Postgres);
2081 assert!(code.contains("pub async fn get_all"));
2082 }
2083
2084 #[test]
2085 fn test_view_has_paginate() {
2086 let mut entity = standard_entity();
2087 entity.is_view = true;
2088 let code = gen(&entity, DatabaseKind::Postgres);
2089 assert!(code.contains("pub async fn paginate"));
2090 }
2091
2092 #[test]
2093 fn test_view_has_get() {
2094 let mut entity = standard_entity();
2095 entity.is_view = true;
2096 let code = gen(&entity, DatabaseKind::Postgres);
2097 assert!(code.contains("pub async fn get"));
2098 }
2099
2100 #[test]
2103 fn test_only_get_all() {
2104 let m = Methods {
2105 get_all: true,
2106 ..Default::default()
2107 };
2108 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2109 assert!(code.contains("pub async fn get_all"));
2110 assert!(!code.contains("pub async fn paginate"));
2111 assert!(!code.contains("pub async fn insert"));
2112 }
2113
2114 #[test]
2115 fn test_without_get_all() {
2116 let m = Methods {
2117 get_all: false,
2118 ..Methods::all()
2119 };
2120 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2121 assert!(!code.contains("pub async fn get_all"));
2122 }
2123
2124 #[test]
2125 fn test_without_paginate() {
2126 let m = Methods {
2127 paginate: false,
2128 ..Methods::all()
2129 };
2130 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2131 assert!(!code.contains("pub async fn paginate"));
2132 assert!(!code.contains("PaginateUsersParams"));
2133 }
2134
2135 #[test]
2136 fn test_without_get() {
2137 let m = Methods {
2138 get: false,
2139 ..Methods::all()
2140 };
2141 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2142 assert!(code.contains("pub async fn get_all"));
2143 let without_get_all = code.replace("get_all", "XXX");
2144 assert!(!without_get_all.contains("fn get("));
2145 }
2146
2147 #[test]
2148 fn test_without_insert() {
2149 let m = Methods {
2150 insert: false,
2151 insert_many: false,
2152 ..Methods::all()
2153 };
2154 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2155 assert!(!code.contains("pub async fn insert"));
2156 assert!(!code.contains("InsertUsersParams"));
2157 }
2158
2159 #[test]
2160 fn test_without_update() {
2161 let m = Methods {
2162 update: false,
2163 ..Methods::all()
2164 };
2165 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2166 assert!(!code.contains("pub async fn update"));
2167 assert!(!code.contains("UpdateUsersParams"));
2168 }
2169
2170 #[test]
2171 fn test_without_delete() {
2172 let m = Methods {
2173 delete: false,
2174 ..Methods::all()
2175 };
2176 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2177 assert!(!code.contains("pub async fn delete"));
2178 }
2179
2180 #[test]
2181 fn test_empty_methods_no_methods() {
2182 let m = Methods::default();
2183 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2184 assert!(!code.contains("pub async fn get_all"));
2185 assert!(!code.contains("pub async fn paginate"));
2186 assert!(!code.contains("pub async fn insert"));
2187 assert!(!code.contains("pub async fn update"));
2188 assert!(!code.contains("pub async fn overwrite"));
2189 assert!(!code.contains("pub async fn delete"));
2190 assert!(!code.contains("pub async fn insert_many"));
2191 }
2192
2193 #[test]
2196 fn test_no_pool_import() {
2197 let skip = Methods::all();
2198 let (_, imports) = generate_crud_from_parsed(
2199 &standard_entity(),
2200 DatabaseKind::Postgres,
2201 "crate::models::users",
2202 &skip,
2203 false,
2204 PoolVisibility::Private,
2205 );
2206 assert!(!imports.iter().any(|i| i.contains("PgPool")));
2207 }
2208
2209 #[test]
2210 fn test_imports_contain_entity() {
2211 let skip = Methods::all();
2212 let (_, imports) = generate_crud_from_parsed(
2213 &standard_entity(),
2214 DatabaseKind::Postgres,
2215 "crate::models::users",
2216 &skip,
2217 false,
2218 PoolVisibility::Private,
2219 );
2220 assert!(imports
2221 .iter()
2222 .any(|i| i.contains("crate::models::users::Users")));
2223 }
2224
2225 #[test]
2228 fn test_renamed_column_in_sql() {
2229 let entity = ParsedEntity {
2230 struct_name: "Connector".to_string(),
2231 table_name: "connector".to_string(),
2232 schema_name: None,
2233 is_view: false,
2234 fields: vec![
2235 make_field("id", "id", "i32", false, true),
2236 make_field("connector_type", "type", "String", false, false),
2237 ],
2238 imports: vec![],
2239 };
2240 let code = gen(&entity, DatabaseKind::Postgres);
2241 assert!(code.contains("type"));
2243 assert!(code.contains("pub connector_type: String"));
2245 }
2246
2247 #[test]
2250 fn test_no_pk_no_get() {
2251 let entity = ParsedEntity {
2252 struct_name: "Logs".to_string(),
2253 table_name: "logs".to_string(),
2254 schema_name: None,
2255 is_view: false,
2256 fields: vec![
2257 make_field("message", "message", "String", false, false),
2258 make_field("ts", "ts", "String", false, false),
2259 ],
2260 imports: vec![],
2261 };
2262 let code = gen(&entity, DatabaseKind::Postgres);
2263 assert!(code.contains("pub async fn get_all"));
2264 let without_get_all = code.replace("get_all", "XXX");
2265 assert!(!without_get_all.contains("fn get("));
2266 }
2267
2268 #[test]
2269 fn test_no_pk_no_delete() {
2270 let entity = ParsedEntity {
2271 struct_name: "Logs".to_string(),
2272 table_name: "logs".to_string(),
2273 schema_name: None,
2274 is_view: false,
2275 fields: vec![make_field("message", "message", "String", false, false)],
2276 imports: vec![],
2277 };
2278 let code = gen(&entity, DatabaseKind::Postgres);
2279 assert!(!code.contains("pub async fn delete"));
2280 }
2281
2282 #[test]
2285 fn test_param_structs_have_default() {
2286 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2287 assert!(code.contains("Default"));
2288 }
2289
2290 #[test]
2293 fn test_entity_imports_forwarded() {
2294 let entity = ParsedEntity {
2295 struct_name: "Users".to_string(),
2296 table_name: "users".to_string(),
2297 schema_name: None,
2298 is_view: false,
2299 fields: vec![
2300 make_field("id", "id", "Uuid", false, true),
2301 make_field("created_at", "created_at", "DateTime<Utc>", false, false),
2302 ],
2303 imports: vec![
2304 "use chrono::{DateTime, Utc};".to_string(),
2305 "use uuid::Uuid;".to_string(),
2306 ],
2307 };
2308 let skip = Methods::all();
2309 let (_, imports) = generate_crud_from_parsed(
2310 &entity,
2311 DatabaseKind::Postgres,
2312 "crate::models::users",
2313 &skip,
2314 false,
2315 PoolVisibility::Private,
2316 );
2317 assert!(imports.iter().any(|i| i.contains("chrono")));
2318 assert!(imports.iter().any(|i| i.contains("uuid")));
2319 }
2320
2321 #[test]
2322 fn test_entity_imports_empty_when_no_imports() {
2323 let skip = Methods::all();
2324 let (_, imports) = generate_crud_from_parsed(
2325 &standard_entity(),
2326 DatabaseKind::Postgres,
2327 "crate::models::users",
2328 &skip,
2329 false,
2330 PoolVisibility::Private,
2331 );
2332 assert!(!imports.iter().any(|i| i.contains("chrono")));
2334 assert!(!imports.iter().any(|i| i.contains("uuid")));
2335 }
2336
2337 #[test]
2340 fn test_macro_get_all() {
2341 let m = Methods {
2342 get_all: true,
2343 ..Default::default()
2344 };
2345 let (tokens, _) = generate_crud_from_parsed(
2346 &standard_entity(),
2347 DatabaseKind::Postgres,
2348 "crate::models::users",
2349 &m,
2350 true,
2351 PoolVisibility::Private,
2352 );
2353 let code = parse_and_format(&tokens).unwrap();
2354 assert!(code.contains("query_as!"));
2355 assert!(!code.contains("query_as::<"));
2356 }
2357
2358 #[test]
2359 fn test_macro_paginate() {
2360 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2361 assert!(code.contains("query_as!"));
2362 assert!(code.contains("per_page, offset"));
2363 }
2364
2365 #[test]
2366 fn test_macro_get() {
2367 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2368 assert!(code.contains("query_as!(Users"));
2370 }
2371
2372 #[test]
2373 fn test_macro_insert_pg() {
2374 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2375 assert!(code.contains("query_as!(Users"));
2376 assert!(code.contains("params.name"));
2377 assert!(code.contains("params.email"));
2378 }
2379
2380 #[test]
2381 fn test_macro_insert_mysql() {
2382 let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
2383 assert!(code.contains("query!"));
2385 assert!(code.contains("query_scalar!"));
2386 }
2387
2388 #[test]
2389 fn test_macro_update() {
2390 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2391 assert!(code.contains("query_as!(Users"));
2392 assert!(
2393 code.contains("COALESCE"),
2394 "Expected COALESCE in macro update:\n{}",
2395 code
2396 );
2397 assert!(code.contains("pub async fn update"));
2398 assert!(code.contains("UpdateUsersParams"));
2399 }
2400
2401 #[test]
2402 fn test_macro_delete() {
2403 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2404 assert!(code.contains("query!"));
2406 }
2407
2408 #[test]
2409 fn test_macro_no_bind_calls() {
2410 let m = Methods {
2412 insert_many: false,
2413 ..Methods::all()
2414 };
2415 let (tokens, _) = generate_crud_from_parsed(
2416 &standard_entity(),
2417 DatabaseKind::Postgres,
2418 "crate::models::users",
2419 &m,
2420 true,
2421 PoolVisibility::Private,
2422 );
2423 let code = parse_and_format(&tokens).unwrap();
2424 assert!(!code.contains(".bind("));
2425 }
2426
2427 #[test]
2428 fn test_function_style_uses_bind() {
2429 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2430 assert!(code.contains(".bind("));
2431 assert!(!code.contains("query_as!("));
2432 assert!(!code.contains("query!("));
2433 }
2434
2435 fn entity_with_sql_array() -> ParsedEntity {
2438 ParsedEntity {
2439 struct_name: "AgentConnector".to_string(),
2440 table_name: "agent.agent_connector".to_string(),
2441 schema_name: Some("agent".to_string()),
2442 is_view: false,
2443 fields: vec![
2444 ParsedField {
2445 rust_name: "connector_id".to_string(),
2446 column_name: "connector_id".to_string(),
2447 rust_type: "Uuid".to_string(),
2448 inner_type: "Uuid".to_string(),
2449 is_nullable: false,
2450 is_primary_key: true,
2451 sql_type: None,
2452 is_sql_array: false,
2453 column_default: None,
2454 },
2455 ParsedField {
2456 rust_name: "agent_id".to_string(),
2457 column_name: "agent_id".to_string(),
2458 rust_type: "Uuid".to_string(),
2459 inner_type: "Uuid".to_string(),
2460 is_nullable: false,
2461 is_primary_key: false,
2462 sql_type: None,
2463 is_sql_array: false,
2464 column_default: None,
2465 },
2466 ParsedField {
2467 rust_name: "usages".to_string(),
2468 column_name: "usages".to_string(),
2469 rust_type: "Vec<ConnectorUsages>".to_string(),
2470 inner_type: "Vec<ConnectorUsages>".to_string(),
2471 is_nullable: false,
2472 is_primary_key: false,
2473 sql_type: Some("agent.connector_usages".to_string()),
2474 is_sql_array: true,
2475 column_default: None,
2476 },
2477 ],
2478 imports: vec!["use uuid::Uuid;".to_string()],
2479 }
2480 }
2481
2482 fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
2483 let skip = Methods::all();
2484 let (tokens, _) = generate_crud_from_parsed(
2485 entity,
2486 db,
2487 "crate::models::agent_connector",
2488 &skip,
2489 true,
2490 PoolVisibility::Private,
2491 );
2492 parse_and_format(&tokens).unwrap()
2493 }
2494
2495 #[test]
2496 fn test_sql_array_macro_get_all_uses_runtime() {
2497 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2498 assert!(code.contains("query_as::<"));
2500 }
2501
2502 #[test]
2503 fn test_sql_array_macro_get_uses_runtime() {
2504 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2505 assert!(code.contains(".bind("));
2507 }
2508
2509 #[test]
2510 fn test_sql_array_macro_insert_uses_runtime() {
2511 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2512 assert!(
2514 code.contains("query_as::<_ , AgentConnector>")
2515 || code.contains("query_as::<_, AgentConnector>")
2516 );
2517 }
2518
2519 #[test]
2520 fn test_sql_array_macro_delete_still_uses_macro() {
2521 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2522 assert!(code.contains("query!"));
2524 }
2525
2526 #[test]
2527 fn test_sql_array_no_query_as_macro() {
2528 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2529 assert!(!code.contains("query_as!("));
2531 }
2532
2533 fn entity_with_sql_enum() -> ParsedEntity {
2536 ParsedEntity {
2537 struct_name: "Task".to_string(),
2538 table_name: "tasks".to_string(),
2539 schema_name: None,
2540 is_view: false,
2541 fields: vec![
2542 ParsedField {
2543 rust_name: "id".to_string(),
2544 column_name: "id".to_string(),
2545 rust_type: "i32".to_string(),
2546 inner_type: "i32".to_string(),
2547 is_nullable: false,
2548 is_primary_key: true,
2549 sql_type: None,
2550 is_sql_array: false,
2551 column_default: None,
2552 },
2553 ParsedField {
2554 rust_name: "status".to_string(),
2555 column_name: "status".to_string(),
2556 rust_type: "TaskStatus".to_string(),
2557 inner_type: "TaskStatus".to_string(),
2558 is_nullable: false,
2559 is_primary_key: false,
2560 sql_type: Some("task_status".to_string()),
2561 is_sql_array: false,
2562 column_default: None,
2563 },
2564 ],
2565 imports: vec![],
2566 }
2567 }
2568
2569 #[test]
2570 fn test_sql_enum_macro_uses_runtime() {
2571 let skip = Methods::all();
2572 let (tokens, _) = generate_crud_from_parsed(
2573 &entity_with_sql_enum(),
2574 DatabaseKind::Postgres,
2575 "crate::models::task",
2576 &skip,
2577 true,
2578 PoolVisibility::Private,
2579 );
2580 let code = parse_and_format(&tokens).unwrap();
2581 assert!(code.contains("query_as::<"));
2583 assert!(!code.contains("query_as!("));
2584 }
2585
2586 #[test]
2587 fn test_sql_enum_macro_delete_still_uses_macro() {
2588 let skip = Methods::all();
2589 let (tokens, _) = generate_crud_from_parsed(
2590 &entity_with_sql_enum(),
2591 DatabaseKind::Postgres,
2592 "crate::models::task",
2593 &skip,
2594 true,
2595 PoolVisibility::Private,
2596 );
2597 let code = parse_and_format(&tokens).unwrap();
2598 assert!(code.contains("query!"));
2600 }
2601
2602 fn entity_with_vec_string() -> ParsedEntity {
2605 ParsedEntity {
2606 struct_name: "PromptHistory".to_string(),
2607 table_name: "prompt_history".to_string(),
2608 schema_name: None,
2609 is_view: false,
2610 fields: vec![
2611 ParsedField {
2612 rust_name: "id".to_string(),
2613 column_name: "id".to_string(),
2614 rust_type: "Uuid".to_string(),
2615 inner_type: "Uuid".to_string(),
2616 is_nullable: false,
2617 is_primary_key: true,
2618 sql_type: None,
2619 is_sql_array: false,
2620 column_default: None,
2621 },
2622 ParsedField {
2623 rust_name: "content".to_string(),
2624 column_name: "content".to_string(),
2625 rust_type: "String".to_string(),
2626 inner_type: "String".to_string(),
2627 is_nullable: false,
2628 is_primary_key: false,
2629 sql_type: None,
2630 is_sql_array: false,
2631 column_default: None,
2632 },
2633 ParsedField {
2634 rust_name: "tags".to_string(),
2635 column_name: "tags".to_string(),
2636 rust_type: "Vec<String>".to_string(),
2637 inner_type: "Vec<String>".to_string(),
2638 is_nullable: false,
2639 is_primary_key: false,
2640 sql_type: None,
2641 is_sql_array: false,
2642 column_default: None,
2643 },
2644 ],
2645 imports: vec!["use uuid::Uuid;".to_string()],
2646 }
2647 }
2648
2649 #[test]
2650 fn test_vec_string_macro_insert_uses_as_slice() {
2651 let skip = Methods::all();
2652 let (tokens, _) = generate_crud_from_parsed(
2653 &entity_with_vec_string(),
2654 DatabaseKind::Postgres,
2655 "crate::models::prompt_history",
2656 &skip,
2657 true,
2658 PoolVisibility::Private,
2659 );
2660 let code = parse_and_format(&tokens).unwrap();
2661 assert!(code.contains("as_slice()"));
2662 }
2663
2664 #[test]
2665 fn test_vec_string_macro_update_uses_as_slice() {
2666 let skip = Methods::all();
2667 let (tokens, _) = generate_crud_from_parsed(
2668 &entity_with_vec_string(),
2669 DatabaseKind::Postgres,
2670 "crate::models::prompt_history",
2671 &skip,
2672 true,
2673 PoolVisibility::Private,
2674 );
2675 let code = parse_and_format(&tokens).unwrap();
2676 let count = code.matches("as_slice()").count();
2678 assert!(
2679 count >= 2,
2680 "expected at least 2 as_slice() calls (insert + update), found {}",
2681 count
2682 );
2683 }
2684
2685 #[test]
2686 fn test_vec_string_non_macro_no_as_slice() {
2687 let skip = Methods::all();
2688 let (tokens, _) = generate_crud_from_parsed(
2689 &entity_with_vec_string(),
2690 DatabaseKind::Postgres,
2691 "crate::models::prompt_history",
2692 &skip,
2693 false,
2694 PoolVisibility::Private,
2695 );
2696 let code = parse_and_format(&tokens).unwrap();
2697 assert!(!code.contains("as_slice()"));
2699 }
2700
2701 #[test]
2702 fn test_vec_string_parsed_from_source_uses_as_slice() {
2703 use crate::codegen::entity_parser::parse_entity_source;
2704 let source = r#"
2705 use uuid::Uuid;
2706
2707 #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
2708 #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
2709 pub struct PromptHistory {
2710 #[sqlx_gen(primary_key)]
2711 pub id: Uuid,
2712 pub content: String,
2713 pub tags: Vec<String>,
2714 }
2715 "#;
2716 let entity = parse_entity_source(source).unwrap();
2717 let skip = Methods::all();
2718 let (tokens, _) = generate_crud_from_parsed(
2719 &entity,
2720 DatabaseKind::Postgres,
2721 "crate::models::prompt_history",
2722 &skip,
2723 true,
2724 PoolVisibility::Private,
2725 );
2726 let code = parse_and_format(&tokens).unwrap();
2727 assert!(
2728 code.contains("as_slice()"),
2729 "Expected as_slice() in generated code:\n{}",
2730 code
2731 );
2732 }
2733
2734 fn junction_entity() -> ParsedEntity {
2737 ParsedEntity {
2738 struct_name: "AnalysisRecord".to_string(),
2739 table_name: "analysis__record".to_string(),
2740 schema_name: Some("analysis".to_string()),
2741 is_view: false,
2742 fields: vec![
2743 make_field("record_id", "record_id", "uuid::Uuid", false, true),
2744 make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
2745 ],
2746 imports: vec![],
2747 }
2748 }
2749
2750 #[test]
2751 fn test_composite_pk_only_insert_generated() {
2752 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2753 assert!(
2754 code.contains("pub struct InsertAnalysisRecordParams"),
2755 "Expected InsertAnalysisRecordParams struct:\n{}",
2756 code
2757 );
2758 assert!(
2759 code.contains("pub record_id"),
2760 "Expected record_id field in insert params:\n{}",
2761 code
2762 );
2763 assert!(
2764 code.contains("pub analysis_id"),
2765 "Expected analysis_id field in insert params:\n{}",
2766 code
2767 );
2768 assert!(
2769 code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id)"),
2770 "Expected quoted INSERT INTO clause:\n{}",
2771 code
2772 );
2773 assert!(
2774 code.contains("VALUES ($1, $2)"),
2775 "Expected VALUES clause:\n{}",
2776 code
2777 );
2778 assert!(
2779 code.contains("RETURNING *"),
2780 "Expected RETURNING clause:\n{}",
2781 code
2782 );
2783 assert!(
2784 code.contains("pub async fn insert"),
2785 "Expected insert method:\n{}",
2786 code
2787 );
2788 }
2789
2790 #[test]
2791 fn test_composite_pk_only_no_update() {
2792 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2793 assert!(
2794 !code.contains("UpdateAnalysisRecordParams"),
2795 "Expected no UpdateAnalysisRecordParams struct:\n{}",
2796 code
2797 );
2798 assert!(
2799 !code.contains("pub async fn update"),
2800 "Expected no update method:\n{}",
2801 code
2802 );
2803 }
2804
2805 #[test]
2806 fn test_composite_pk_only_delete_generated() {
2807 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2808 assert!(
2809 code.contains("pub async fn delete"),
2810 "Expected delete method:\n{}",
2811 code
2812 );
2813 assert!(
2814 code.contains("DELETE FROM analysis.analysis__record"),
2815 "Expected DELETE clause:\n{}",
2816 code
2817 );
2818 assert!(
2819 code.contains("WHERE record_id = $1 AND analysis_id = $2"),
2820 "Expected WHERE clause:\n{}",
2821 code
2822 );
2823 }
2824
2825 #[test]
2826 fn test_composite_pk_only_get_generated() {
2827 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2828 assert!(
2829 code.contains("pub async fn get"),
2830 "Expected get method:\n{}",
2831 code
2832 );
2833 assert!(
2834 code.contains("WHERE record_id = $1 AND analysis_id = $2"),
2835 "Expected WHERE clause with both PK columns:\n{}",
2836 code
2837 );
2838 }
2839
2840 fn composite_pk_with_extra() -> ParsedEntity {
2843 ParsedEntity {
2844 struct_name: "OrderItems".to_string(),
2845 table_name: "order_items".to_string(),
2846 schema_name: None,
2847 is_view: false,
2848 fields: vec![
2849 make_field("order_id", "order_id", "i32", false, true),
2850 make_field("product_id", "product_id", "i32", false, true),
2851 make_field("qty", "qty", "i32", false, false),
2852 ],
2853 imports: vec![],
2854 }
2855 }
2856
2857 #[test]
2858 fn test_mysql_composite_pk_insert_uses_select_not_last_insert_id() {
2859 let code = gen(&composite_pk_with_extra(), DatabaseKind::Mysql);
2860 assert!(
2861 !code.contains("LAST_INSERT_ID"),
2862 "composite PK insert must not use LAST_INSERT_ID(), got:\n{}",
2863 code
2864 );
2865 assert!(
2866 code.contains("SELECT *"),
2867 "must SELECT the row back after INSERT, got:\n{}",
2868 code
2869 );
2870 assert!(
2871 code.contains("WHERE order_id = ? AND product_id = ?"),
2872 "SELECT must use bound composite PK values, got:\n{}",
2873 code
2874 );
2875 }
2876
2877 #[test]
2878 fn test_mysql_composite_pk_includes_pks_in_insert_params() {
2879 let code = gen(&composite_pk_with_extra(), DatabaseKind::Mysql);
2880 assert!(
2881 code.contains("pub order_id"),
2882 "InsertParams must expose composite PK column order_id, got:\n{}",
2883 code
2884 );
2885 assert!(code.contains("pub product_id"));
2886 assert!(code.contains("pub qty"));
2887 }
2888
2889 #[test]
2890 fn test_mysql_single_pk_insert_still_uses_last_insert_id() {
2891 let code = gen(&standard_entity(), DatabaseKind::Mysql);
2892 assert!(
2893 code.contains("LAST_INSERT_ID"),
2894 "single-PK MySQL insert should still rely on LAST_INSERT_ID()"
2895 );
2896 }
2897
2898 #[test]
2901 fn test_insert_many_transactionally_method_generated() {
2902 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2903 assert!(
2904 code.contains("pub async fn insert_many_transactionally"),
2905 "Expected insert_many_transactionally method:\n{}",
2906 code
2907 );
2908 }
2909
2910 #[test]
2911 fn test_insert_many_transactionally_signature() {
2912 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2913 assert!(
2914 code.contains("entries: Vec<InsertUsersParams>"),
2915 "Expected Vec<InsertUsersParams> param:\n{}",
2916 code
2917 );
2918 assert!(
2919 code.contains("Result<Vec<Users>"),
2920 "Expected Result<Vec<Users>> return type:\n{}",
2921 code
2922 );
2923 }
2924
2925 #[test]
2926 fn test_insert_many_transactionally_no_strategy_enum() {
2927 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2928 assert!(
2929 !code.contains("TransactionStrategy"),
2930 "TransactionStrategy should not be generated:\n{}",
2931 code
2932 );
2933 assert!(
2934 !code.contains("InsertManyUsersResult"),
2935 "InsertManyUsersResult should not be generated:\n{}",
2936 code
2937 );
2938 }
2939
2940 #[test]
2941 fn test_insert_many_transactionally_uses_transaction_pg() {
2942 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2943 let method_start = code
2944 .find("fn insert_many_transactionally")
2945 .expect("insert_many_transactionally not found");
2946 let method_body = &code[method_start..];
2947 assert!(
2948 method_body.contains("self.pool.begin()"),
2949 "Expected begin():\n{}",
2950 method_body
2951 );
2952 assert!(
2953 method_body.contains("tx.commit()"),
2954 "Expected commit():\n{}",
2955 method_body
2956 );
2957 }
2958
2959 #[test]
2960 fn test_insert_many_transactionally_multi_row_pg() {
2961 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2962 let method_start = code
2963 .find("fn insert_many_transactionally")
2964 .expect("not found");
2965 let method_body = &code[method_start..];
2966 assert!(
2967 method_body.contains("RETURNING *"),
2968 "Expected RETURNING * in multi-row SQL:\n{}",
2969 method_body
2970 );
2971 assert!(
2972 method_body.contains("values_parts"),
2973 "Expected multi-row VALUES building:\n{}",
2974 method_body
2975 );
2976 assert!(
2977 method_body.contains("65535"),
2978 "Expected chunk size limit:\n{}",
2979 method_body
2980 );
2981 }
2982
2983 #[test]
2984 fn test_insert_many_transactionally_multi_row_sqlite() {
2985 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
2986 let method_start = code
2987 .find("fn insert_many_transactionally")
2988 .expect("not found");
2989 let method_body = &code[method_start..];
2990 assert!(
2991 method_body.contains("values_parts"),
2992 "Expected multi-row VALUES building for SQLite:\n{}",
2993 method_body
2994 );
2995 assert!(
2996 method_body.contains("RETURNING *"),
2997 "Expected RETURNING * for SQLite:\n{}",
2998 method_body
2999 );
3000 }
3001
3002 #[test]
3003 fn test_insert_many_transactionally_mysql_individual_inserts() {
3004 let code = gen(&standard_entity(), DatabaseKind::Mysql);
3005 let method_start = code
3006 .find("fn insert_many_transactionally")
3007 .expect("not found");
3008 let method_body = &code[method_start..];
3009 assert!(
3010 method_body.contains("LAST_INSERT_ID"),
3011 "Expected LAST_INSERT_ID for MySQL:\n{}",
3012 method_body
3013 );
3014 assert!(
3015 method_body.contains("self.pool.begin()"),
3016 "Expected begin() for MySQL:\n{}",
3017 method_body
3018 );
3019 }
3020
3021 #[test]
3022 fn test_insert_many_transactionally_view_not_generated() {
3023 let mut entity = standard_entity();
3024 entity.is_view = true;
3025 let code = gen(&entity, DatabaseKind::Postgres);
3026 assert!(
3027 !code.contains("pub async fn insert_many_transactionally"),
3028 "should not be generated for views"
3029 );
3030 }
3031
3032 #[test]
3033 fn test_insert_many_transactionally_without_method_not_generated() {
3034 let m = Methods {
3035 insert_many: false,
3036 ..Methods::all()
3037 };
3038 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
3039 assert!(
3040 !code.contains("pub async fn insert_many_transactionally"),
3041 "should not be generated when disabled"
3042 );
3043 }
3044
3045 #[test]
3046 fn test_insert_many_transactionally_generates_params_when_insert_disabled() {
3047 let m = Methods {
3048 insert: false,
3049 insert_many: true,
3050 ..Default::default()
3051 };
3052 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
3053 assert!(
3054 code.contains("pub struct InsertUsersParams"),
3055 "Expected InsertUsersParams:\n{}",
3056 code
3057 );
3058 assert!(
3059 code.contains("pub async fn insert_many_transactionally"),
3060 "Expected method:\n{}",
3061 code
3062 );
3063 assert!(
3064 !code.contains("pub async fn insert("),
3065 "insert should not be present:\n{}",
3066 code
3067 );
3068 }
3069
3070 #[test]
3071 fn test_insert_many_transactionally_with_column_defaults_coalesce() {
3072 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
3073 let method_start = code
3074 .find("fn insert_many_transactionally")
3075 .expect("not found");
3076 let method_body = &code[method_start..];
3077 assert!(
3078 method_body.contains("COALESCE"),
3079 "Expected COALESCE for fields with defaults:\n{}",
3080 method_body
3081 );
3082 }
3083
3084 #[test]
3085 fn test_insert_many_transactionally_junction_table() {
3086 let code = gen(&junction_entity(), DatabaseKind::Postgres);
3087 assert!(
3088 code.contains("pub async fn insert_many_transactionally"),
3089 "Expected method for junction table:\n{}",
3090 code
3091 );
3092 }
3093
3094 #[test]
3095 fn test_insert_many_transactionally_all_three_backends_compile() {
3096 for db in [
3097 DatabaseKind::Postgres,
3098 DatabaseKind::Mysql,
3099 DatabaseKind::Sqlite,
3100 ] {
3101 let code = gen(&standard_entity(), db);
3102 assert!(
3103 code.contains("pub async fn insert_many_transactionally"),
3104 "Expected method for {:?}",
3105 db
3106 );
3107 }
3108 }
3109}