Skip to main content

sqlx_gen/codegen/
crud_gen.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10    entity: &ParsedEntity,
11    db_kind: DatabaseKind,
12    entity_module_path: &str,
13    methods: &Methods,
14    query_macro: bool,
15    pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17    let mut imports = BTreeSet::new();
18
19    let entity_ident = format_ident!("{}", entity.struct_name);
20    let repo_name = format!("{}Repository", entity.struct_name);
21    let repo_ident = format_ident!("{}", repo_name);
22
23    let table_name = match &entity.schema_name {
24        Some(schema) => format!("{}.{}", schema, entity.table_name),
25        None => entity.table_name.clone(),
26    };
27
28    // Pool type (used via full path sqlx::PgPool etc., no import needed)
29    let pool_type = pool_type_tokens(db_kind);
30
31    // When the entity has custom SQL types (enums, composites, arrays), the query_as! macro
32    // can't resolve the column type at compile time, so row-returning queries fall back to the
33    // runtime query_as::<_, T>(). Views use the same fallback. DELETE (no rows returned) can still use the macro.
34    let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35    let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37    // Entity import
38    imports.insert(format!(
39        "use {}::{};",
40        entity_module_path, entity.struct_name
41    ));
42
43    // Forward type imports from the entity file (chrono, uuid, etc.)
44    // Rewrite `use super::X` imports to absolute paths based on entity_module_path,
45    // since the repository lives in a different module where `super` has a different meaning.
46    let entity_parent = entity_module_path
47        .rsplit_once("::")
48        .map(|(parent, _)| parent)
49        .unwrap_or(entity_module_path);
50    for imp in &entity.imports {
51        if let Some(rest) = imp.strip_prefix("use super::") {
52            imports.insert(format!("use {}::{}", entity_parent, rest));
53        } else {
54            imports.insert(imp.clone());
55        }
56    }
57
58    // Primary key fields
59    let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
60
61    // Non-PK fields (insert columns, and the SET columns of overwrite/update)
62    let non_pk_fields: Vec<&ParsedField> =
63        entity.fields.iter().filter(|f| !f.is_primary_key).collect();
64
65    let is_view = entity.is_view;
66
67    // Build method tokens
68    let mut method_tokens = Vec::new();
69    let mut param_structs = Vec::new();
70
71    // --- get_all ---
72    if methods.get_all {
73        let sql = raw_sql_lit(&format!("SELECT * FROM {}", table_name));
74        let method = if use_macro {
75            quote! {
76                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
77                    sqlx::query_as!(#entity_ident, #sql)
78                        .fetch_all(&self.pool)
79                        .await
80                }
81            }
82        } else {
83            quote! {
84                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
85                    sqlx::query_as::<_, #entity_ident>(#sql)
86                        .fetch_all(&self.pool)
87                        .await
88                }
89            }
90        };
91        method_tokens.push(method);
92    }
93
94    // --- paginate ---
95    if methods.paginate {
96        let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
97        let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
98        let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
99        let count_sql = raw_sql_lit(&format!("SELECT COUNT(*) FROM {}", table_name));
100        let sql = raw_sql_lit(&match db_kind {
101            DatabaseKind::Postgres => format!("SELECT *\nFROM {}\nLIMIT $1 OFFSET $2", table_name),
102            DatabaseKind::Mysql | DatabaseKind::Sqlite => {
103                format!("SELECT *\nFROM {}\nLIMIT ? OFFSET ?", table_name)
104            }
105        });
106        let method = if use_macro {
107            quote! {
108                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
109                    let total: i64 = sqlx::query_scalar!(#count_sql)
110                        .fetch_one(&self.pool)
111                        .await?
112                        .unwrap_or(0);
113                    let per_page = params.per_page;
114                    let current_page = params.page;
115                    let last_page = (total + per_page - 1) / per_page;
116                    let offset = (current_page - 1) * per_page;
117                    let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
118                        .fetch_all(&self.pool)
119                        .await?;
120                    Ok(#paginated_ident {
121                        meta: #pagination_meta_ident {
122                            total,
123                            per_page,
124                            current_page,
125                            last_page,
126                            first_page: 1,
127                        },
128                        data,
129                    })
130                }
131            }
132        } else {
133            quote! {
134                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
135                    let total: i64 = sqlx::query_scalar(#count_sql)
136                        .fetch_one(&self.pool)
137                        .await?;
138                    let per_page = params.per_page;
139                    let current_page = params.page;
140                    let last_page = (total + per_page - 1) / per_page;
141                    let offset = (current_page - 1) * per_page;
142                    let data = sqlx::query_as::<_, #entity_ident>(#sql)
143                        .bind(per_page)
144                        .bind(offset)
145                        .fetch_all(&self.pool)
146                        .await?;
147                    Ok(#paginated_ident {
148                        meta: #pagination_meta_ident {
149                            total,
150                            per_page,
151                            current_page,
152                            last_page,
153                            first_page: 1,
154                        },
155                        data,
156                    })
157                }
158            }
159        };
160        method_tokens.push(method);
161        param_structs.push(quote! {
162            #[derive(Debug, Clone, Default)]
163            pub struct #paginate_params_ident {
164                pub page: i64,
165                pub per_page: i64,
166            }
167        });
168        param_structs.push(quote! {
169            #[derive(Debug, Clone)]
170            pub struct #pagination_meta_ident {
171                pub total: i64,
172                pub per_page: i64,
173                pub current_page: i64,
174                pub last_page: i64,
175                pub first_page: i64,
176            }
177        });
178        param_structs.push(quote! {
179            #[derive(Debug, Clone)]
180            pub struct #paginated_ident {
181                pub meta: #pagination_meta_ident,
182                pub data: Vec<#entity_ident>,
183            }
184        });
185    }
186
187    // --- get (by PK) ---
188    if methods.get && !pk_fields.is_empty() {
189        let pk_params: Vec<TokenStream> = pk_fields
190            .iter()
191            .map(|f| {
192                let name = format_ident!("{}", f.rust_name);
193                let ty: TokenStream = f.inner_type.parse().unwrap();
194                quote! { #name: #ty }
195            })
196            .collect();
197
198        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
199        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
200        let sql = raw_sql_lit(&format!(
201            "SELECT *\nFROM {}\nWHERE {}",
202            table_name, where_clause
203        ));
204        let sql_macro = raw_sql_lit(&format!(
205            "SELECT *\nFROM {}\nWHERE {}",
206            table_name, where_clause_cast
207        ));
208
209        let binds: Vec<TokenStream> = pk_fields
210            .iter()
211            .map(|f| {
212                let name = format_ident!("{}", f.rust_name);
213                quote! { .bind(#name) }
214            })
215            .collect();
216
217        let method = if use_macro {
218            let pk_arg_names: Vec<TokenStream> = pk_fields
219                .iter()
220                .map(|f| {
221                    let name = format_ident!("{}", f.rust_name);
222                    quote! { #name }
223                })
224                .collect();
225            quote! {
226                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
227                    sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
228                        .fetch_optional(&self.pool)
229                        .await
230                }
231            }
232        } else {
233            quote! {
234                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
235                    sqlx::query_as::<_, #entity_ident>(#sql)
236                        #(#binds)*
237                        .fetch_optional(&self.pool)
238                        .await
239                }
240            }
241        };
242        method_tokens.push(method);
243    }
244
245    // --- insert (skip for views) ---
246    if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
247        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
248
249        // When all columns are PKs (e.g. junction tables), use pk_fields for insert
250        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
251            pk_fields.clone()
252        } else {
253            non_pk_fields.clone()
254        };
255
256        // Fields with column_default and not already nullable → Option<T>
257        let insert_fields: Vec<TokenStream> = insert_source_fields
258            .iter()
259            .map(|f| {
260                let name = format_ident!("{}", f.rust_name);
261                if f.column_default.is_some() && !f.is_nullable {
262                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
263                    quote! { pub #name: #ty, }
264                } else {
265                    let ty: TokenStream = f.rust_type.parse().unwrap();
266                    quote! { pub #name: #ty, }
267                }
268            })
269            .collect();
270
271        let col_names: Vec<&str> = insert_source_fields
272            .iter()
273            .map(|f| f.column_name.as_str())
274            .collect();
275        let col_list = col_names.join(", ");
276
277        // Build placeholders with COALESCE for fields that have a column_default
278        let placeholders: String = insert_source_fields
279            .iter()
280            .enumerate()
281            .map(|(i, f)| {
282                let p = placeholder(db_kind, i + 1);
283                match &f.column_default {
284                    Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
285                    None => p,
286                }
287            })
288            .collect::<Vec<_>>()
289            .join(", ");
290
291        let placeholders_cast: String = insert_source_fields
292            .iter()
293            .enumerate()
294            .map(|(i, f)| {
295                let p = placeholder_with_cast(db_kind, i + 1, f);
296                match &f.column_default {
297                    Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
298                    None => p,
299                }
300            })
301            .collect::<Vec<_>>()
302            .join(", ");
303
304        let build_insert_sql = |ph: &str| match db_kind {
305            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
306                format!(
307                    "INSERT INTO {} ({})\nVALUES ({})\nRETURNING *",
308                    table_name, col_list, ph
309                )
310            }
311            DatabaseKind::Mysql => {
312                format!("INSERT INTO {} ({})\nVALUES ({})", table_name, col_list, ph)
313            }
314        };
315        let sql = build_insert_sql(&placeholders);
316        let sql_macro = build_insert_sql(&placeholders_cast);
317
318        let binds: Vec<TokenStream> = insert_source_fields
319            .iter()
320            .map(|f| {
321                let name = format_ident!("{}", f.rust_name);
322                quote! { .bind(&params.#name) }
323            })
324            .collect();
325
326        let insert_method = build_insert_method_parsed(
327            &entity_ident,
328            &insert_params_ident,
329            &sql,
330            &sql_macro,
331            &binds,
332            db_kind,
333            &table_name,
334            &pk_fields,
335            &insert_source_fields,
336            use_macro,
337        );
338        method_tokens.push(insert_method);
339
340        param_structs.push(quote! {
341            #[derive(Debug, Clone, Default)]
342            pub struct #insert_params_ident {
343                #(#insert_fields)*
344            }
345        });
346    }
347
348    // --- insert_many_transactionally (skip for views) ---
349    if !is_view && methods.insert_many && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
350        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
351
352        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
353            pk_fields.clone()
354        } else {
355            non_pk_fields.clone()
356        };
357
358        let col_names: Vec<&str> = insert_source_fields
359            .iter()
360            .map(|f| f.column_name.as_str())
361            .collect();
362        let col_list = col_names.join(", ");
363        let num_cols = insert_source_fields.len();
364
365        let binds_loop: Vec<TokenStream> = insert_source_fields
366            .iter()
367            .map(|f| {
368                let name = format_ident!("{}", f.rust_name);
369                quote! { query = query.bind(&params.#name); }
370            })
371            .collect();
372
373        let insert_many_method = build_insert_many_transactionally_method(
374            &entity_ident,
375            &insert_params_ident,
376            &col_list,
377            num_cols,
378            &insert_source_fields,
379            &binds_loop,
380            db_kind,
381            &table_name,
382            &pk_fields,
383        );
384        method_tokens.push(insert_many_method);
385
386        // Only generate InsertParams if we haven't generated it from the insert method
387        if !methods.insert {
388            let insert_fields: Vec<TokenStream> = insert_source_fields
389                .iter()
390                .map(|f| {
391                    let name = format_ident!("{}", f.rust_name);
392                    if f.column_default.is_some() && !f.is_nullable {
393                        let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
394                        quote! { pub #name: #ty, }
395                    } else {
396                        let ty: TokenStream = f.rust_type.parse().unwrap();
397                        quote! { pub #name: #ty, }
398                    }
399                })
400                .collect();
401
402            param_structs.push(quote! {
403                #[derive(Debug, Clone, Default)]
404                pub struct #insert_params_ident {
405                    #(#insert_fields)*
406                }
407            });
408        }
409    }
410
411    // --- overwrite (full replacement — skip for views, skip when all columns are PKs) ---
412    if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
413        let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
414
415        // PK as function parameters (like get/delete)
416        let pk_fn_params: Vec<TokenStream> = pk_fields
417            .iter()
418            .map(|f| {
419                let name = format_ident!("{}", f.rust_name);
420                let ty: TokenStream = f.inner_type.parse().unwrap();
421                quote! { #name: #ty }
422            })
423            .collect();
424
425        // Non-PK fields keep original types (required)
426        let overwrite_fields: Vec<TokenStream> = non_pk_fields
427            .iter()
428            .map(|f| {
429                let name = format_ident!("{}", f.rust_name);
430                let ty: TokenStream = f.rust_type.parse().unwrap();
431                quote! { pub #name: #ty, }
432            })
433            .collect();
434
435        let set_cols: Vec<String> = non_pk_fields
436            .iter()
437            .enumerate()
438            .map(|(i, f)| {
439                let p = placeholder(db_kind, i + 1);
440                format!("{} = {}", f.column_name, p)
441            })
442            .collect();
443        let set_clause = set_cols.join(",\n  ");
444
445        let set_cols_cast: Vec<String> = non_pk_fields
446            .iter()
447            .enumerate()
448            .map(|(i, f)| {
449                let p = placeholder_with_cast(db_kind, i + 1, f);
450                format!("{} = {}", f.column_name, p)
451            })
452            .collect();
453        let set_clause_cast = set_cols_cast.join(",\n  ");
454
455        let pk_start = non_pk_fields.len() + 1;
456        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
457        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
458
459        let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
460            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
461                format!(
462                    "UPDATE {}\nSET\n  {}\nWHERE {}\nRETURNING *",
463                    table_name, sc, wc
464                )
465            }
466            DatabaseKind::Mysql => {
467                format!("UPDATE {}\nSET\n  {}\nWHERE {}", table_name, sc, wc)
468            }
469        };
470        let sql = raw_sql_lit(&build_overwrite_sql(&set_clause, &where_clause));
471        let sql_macro = raw_sql_lit(&build_overwrite_sql(&set_clause_cast, &where_clause_cast));
472
473        // Bind non-PK first (from params), then PK (from function args)
474        let mut all_binds: Vec<TokenStream> = non_pk_fields
475            .iter()
476            .map(|f| {
477                let name = format_ident!("{}", f.rust_name);
478                quote! { .bind(&params.#name) }
479            })
480            .collect();
481        for f in &pk_fields {
482            let name = format_ident!("{}", f.rust_name);
483            all_binds.push(quote! { .bind(#name) });
484        }
485
486        // Macro args: non-PK from params, then PK from function args
487        let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
488            .iter()
489            .map(|f| macro_arg_for_field(f))
490            .chain(pk_fields.iter().map(|f| {
491                let name = format_ident!("{}", f.rust_name);
492                quote! { #name }
493            }))
494            .collect();
495
496        let overwrite_method = if use_macro {
497            match db_kind {
498                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
499                    quote! {
500                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
501                            sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
502                                .fetch_one(&self.pool)
503                                .await
504                        }
505                    }
506                }
507                DatabaseKind::Mysql => {
508                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
509                    let select_sql = raw_sql_lit(&format!(
510                        "SELECT *\nFROM {}\nWHERE {}",
511                        table_name, pk_where_select
512                    ));
513                    let pk_macro_args: Vec<TokenStream> = pk_fields
514                        .iter()
515                        .map(|f| {
516                            let name = format_ident!("{}", f.rust_name);
517                            quote! { #name }
518                        })
519                        .collect();
520                    quote! {
521                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
522                            sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
523                                .execute(&self.pool)
524                                .await?;
525                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
526                                .fetch_one(&self.pool)
527                                .await
528                        }
529                    }
530                }
531            }
532        } else {
533            match db_kind {
534                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
535                    quote! {
536                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
537                            sqlx::query_as::<_, #entity_ident>(#sql)
538                                #(#all_binds)*
539                                .fetch_one(&self.pool)
540                                .await
541                        }
542                    }
543                }
544                DatabaseKind::Mysql => {
545                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
546                    let select_sql = raw_sql_lit(&format!(
547                        "SELECT *\nFROM {}\nWHERE {}",
548                        table_name, pk_where_select
549                    ));
550                    let pk_binds: Vec<TokenStream> = pk_fields
551                        .iter()
552                        .map(|f| {
553                            let name = format_ident!("{}", f.rust_name);
554                            quote! { .bind(#name) }
555                        })
556                        .collect();
557                    quote! {
558                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
559                            sqlx::query(#sql)
560                                #(#all_binds)*
561                                .execute(&self.pool)
562                                .await?;
563                            sqlx::query_as::<_, #entity_ident>(#select_sql)
564                                #(#pk_binds)*
565                                .fetch_one(&self.pool)
566                                .await
567                        }
568                    }
569                }
570            }
571        };
572        method_tokens.push(overwrite_method);
573
574        param_structs.push(quote! {
575            #[derive(Debug, Clone, Default)]
576            pub struct #overwrite_params_ident {
577                #(#overwrite_fields)*
578            }
579        });
580    }
581
582    // --- update / patch (COALESCE — skip for views, skip when all columns are PKs) ---
583    if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
584        let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
585
586        // PK as function parameters (like get/delete)
587        let pk_fn_params: Vec<TokenStream> = pk_fields
588            .iter()
589            .map(|f| {
590                let name = format_ident!("{}", f.rust_name);
591                let ty: TokenStream = f.inner_type.parse().unwrap();
592                quote! { #name: #ty }
593            })
594            .collect();
595
596        // Non-PK fields become Option<T> (no double Option for already nullable)
597        let update_fields: Vec<TokenStream> = non_pk_fields
598            .iter()
599            .map(|f| {
600                let name = format_ident!("{}", f.rust_name);
601                if f.is_nullable {
602                    // Already Option<T> — keep as-is to avoid Option<Option<T>>
603                    let ty: TokenStream = f.rust_type.parse().unwrap();
604                    quote! { pub #name: #ty, }
605                } else {
606                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
607                    quote! { pub #name: #ty, }
608                }
609            })
610            .collect();
611
612        // SET clause with COALESCE for runtime mode
613        let set_cols: Vec<String> = non_pk_fields
614            .iter()
615            .enumerate()
616            .map(|(i, f)| {
617                let p = placeholder(db_kind, i + 1);
618                format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
619            })
620            .collect();
621        let set_clause = set_cols.join(",\n  ");
622
623        // SET clause with COALESCE and casts for macro mode
624        let set_cols_cast: Vec<String> = non_pk_fields
625            .iter()
626            .enumerate()
627            .map(|(i, f)| {
628                let p = placeholder_with_cast(db_kind, i + 1, f);
629                format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
630            })
631            .collect();
632        let set_clause_cast = set_cols_cast.join(",\n  ");
633
634        let pk_start = non_pk_fields.len() + 1;
635        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
636        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
637
638        let build_update_sql = |sc: &str, wc: &str| match db_kind {
639            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
640                format!(
641                    "UPDATE {}\nSET\n  {}\nWHERE {}\nRETURNING *",
642                    table_name, sc, wc
643                )
644            }
645            DatabaseKind::Mysql => {
646                format!("UPDATE {}\nSET\n  {}\nWHERE {}", table_name, sc, wc)
647            }
648        };
649        let sql = raw_sql_lit(&build_update_sql(&set_clause, &where_clause));
650        let sql_macro = raw_sql_lit(&build_update_sql(&set_clause_cast, &where_clause_cast));
651
652        // Bind non-PK first (from params), then PK (from function args)
653        let mut all_binds: Vec<TokenStream> = non_pk_fields
654            .iter()
655            .map(|f| {
656                let name = format_ident!("{}", f.rust_name);
657                quote! { .bind(&params.#name) }
658            })
659            .collect();
660        for f in &pk_fields {
661            let name = format_ident!("{}", f.rust_name);
662            all_binds.push(quote! { .bind(#name) });
663        }
664
665        // Macro args: non-PK from params, then PK from function args
666        let update_macro_args: Vec<TokenStream> = non_pk_fields
667            .iter()
668            .map(|f| macro_arg_for_field(f))
669            .chain(pk_fields.iter().map(|f| {
670                let name = format_ident!("{}", f.rust_name);
671                quote! { #name }
672            }))
673            .collect();
674
675        let update_method = if use_macro {
676            match db_kind {
677                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
678                    quote! {
679                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
680                            sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
681                                .fetch_one(&self.pool)
682                                .await
683                        }
684                    }
685                }
686                DatabaseKind::Mysql => {
687                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
688                    let select_sql = raw_sql_lit(&format!(
689                        "SELECT *\nFROM {}\nWHERE {}",
690                        table_name, pk_where_select
691                    ));
692                    let pk_macro_args: Vec<TokenStream> = pk_fields
693                        .iter()
694                        .map(|f| {
695                            let name = format_ident!("{}", f.rust_name);
696                            quote! { #name }
697                        })
698                        .collect();
699                    quote! {
700                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
701                            sqlx::query!(#sql_macro, #(#update_macro_args),*)
702                                .execute(&self.pool)
703                                .await?;
704                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
705                                .fetch_one(&self.pool)
706                                .await
707                        }
708                    }
709                }
710            }
711        } else {
712            match db_kind {
713                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
714                    quote! {
715                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
716                            sqlx::query_as::<_, #entity_ident>(#sql)
717                                #(#all_binds)*
718                                .fetch_one(&self.pool)
719                                .await
720                        }
721                    }
722                }
723                DatabaseKind::Mysql => {
724                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
725                    let select_sql = raw_sql_lit(&format!(
726                        "SELECT *\nFROM {}\nWHERE {}",
727                        table_name, pk_where_select
728                    ));
729                    let pk_binds: Vec<TokenStream> = pk_fields
730                        .iter()
731                        .map(|f| {
732                            let name = format_ident!("{}", f.rust_name);
733                            quote! { .bind(#name) }
734                        })
735                        .collect();
736                    quote! {
737                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
738                            sqlx::query(#sql)
739                                #(#all_binds)*
740                                .execute(&self.pool)
741                                .await?;
742                            sqlx::query_as::<_, #entity_ident>(#select_sql)
743                                #(#pk_binds)*
744                                .fetch_one(&self.pool)
745                                .await
746                        }
747                    }
748                }
749            }
750        };
751        method_tokens.push(update_method);
752
753        param_structs.push(quote! {
754            #[derive(Debug, Clone, Default)]
755            pub struct #update_params_ident {
756                #(#update_fields)*
757            }
758        });
759    }
760
761    // --- delete (skip for views) ---
762    if !is_view && methods.delete && !pk_fields.is_empty() {
763        let pk_params: Vec<TokenStream> = pk_fields
764            .iter()
765            .map(|f| {
766                let name = format_ident!("{}", f.rust_name);
767                let ty: TokenStream = f.inner_type.parse().unwrap();
768                quote! { #name: #ty }
769            })
770            .collect();
771
772        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
773        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
774        let sql = raw_sql_lit(&format!(
775            "DELETE FROM {}\nWHERE {}",
776            table_name, where_clause
777        ));
778        let sql_macro = raw_sql_lit(&format!(
779            "DELETE FROM {}\nWHERE {}",
780            table_name, where_clause_cast
781        ));
782
783        let binds: Vec<TokenStream> = pk_fields
784            .iter()
785            .map(|f| {
786                let name = format_ident!("{}", f.rust_name);
787                quote! { .bind(#name) }
788            })
789            .collect();
790
791        let method = if query_macro {
792            let pk_arg_names: Vec<TokenStream> = pk_fields
793                .iter()
794                .map(|f| {
795                    let name = format_ident!("{}", f.rust_name);
796                    quote! { #name }
797                })
798                .collect();
799            quote! {
800                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
801                    sqlx::query!(#sql_macro, #(#pk_arg_names),*)
802                        .execute(&self.pool)
803                        .await?;
804                    Ok(())
805                }
806            }
807        } else {
808            quote! {
809                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
810                    sqlx::query(#sql)
811                        #(#binds)*
812                        .execute(&self.pool)
813                        .await?;
814                    Ok(())
815                }
816            }
817        };
818        method_tokens.push(method);
819    }
820
821    let pool_vis: TokenStream = match pool_visibility {
822        PoolVisibility::Private => quote! {},
823        PoolVisibility::Pub => quote! { pub },
824        PoolVisibility::PubCrate => quote! { pub(crate) },
825    };
826
827    let tokens = quote! {
828        #(#param_structs)*
829
830        pub struct #repo_ident {
831            #pool_vis pool: #pool_type,
832        }
833
834        impl #repo_ident {
835            pub fn new(pool: #pool_type) -> Self {
836                Self { pool }
837            }
838
839            #(#method_tokens)*
840        }
841    };
842
843    (tokens, imports)
844}
845
846fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
847    match db_kind {
848        DatabaseKind::Postgres => quote! { sqlx::PgPool },
849        DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
850        DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
851    }
852}
853
854/// Wraps a SQL string as a raw string literal `r#"..."#` in the generated code.
855/// Multi-line SQL gets a leading newline so each clause starts on its own line.
856fn raw_sql_lit(s: &str) -> TokenStream {
857    if s.contains('\n') {
858        format!("r#\"\n{}\n\"#", s).parse().unwrap()
859    } else {
860        format!("r#\"{}\"#", s).parse().unwrap()
861    }
862}
863
864fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
865    match db_kind {
866        DatabaseKind::Postgres => format!("${}", index),
867        DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
868    }
869}
870
871fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
872    let base = placeholder(db_kind, index);
873    match (&field.sql_type, field.is_sql_array) {
874        (Some(t), true) => format!("{} as {}[]", base, t),
875        (Some(t), false) => format!("{} as {}", base, t),
876        (None, _) => base,
877    }
878}
879
880fn build_where_clause_parsed(
881    pk_fields: &[&ParsedField],
882    db_kind: DatabaseKind,
883    start_index: usize,
884) -> String {
885    pk_fields
886        .iter()
887        .enumerate()
888        .map(|(i, f)| {
889            let p = placeholder(db_kind, start_index + i);
890            format!("{} = {}", f.column_name, p)
891        })
892        .collect::<Vec<_>>()
893        .join(" AND ")
894}
895
896fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
897    let name = format_ident!("{}", field.rust_name);
898    let check_type = if field.is_nullable {
899        &field.inner_type
900    } else {
901        &field.rust_type
902    };
903    let normalized = check_type.replace(' ', "");
904    if normalized.starts_with("Vec<") {
905        quote! { params.#name.as_slice() }
906    } else {
907        quote! { params.#name }
908    }
909}
910
911fn build_where_clause_cast(
912    pk_fields: &[&ParsedField],
913    db_kind: DatabaseKind,
914    start_index: usize,
915) -> String {
916    pk_fields
917        .iter()
918        .enumerate()
919        .map(|(i, f)| {
920            let p = placeholder_with_cast(db_kind, start_index + i, f);
921            format!("{} = {}", f.column_name, p)
922        })
923        .collect::<Vec<_>>()
924        .join(" AND ")
925}
926
927#[allow(clippy::too_many_arguments)]
928fn build_insert_method_parsed(
929    entity_ident: &proc_macro2::Ident,
930    insert_params_ident: &proc_macro2::Ident,
931    sql: &str,
932    sql_macro: &str,
933    binds: &[TokenStream],
934    db_kind: DatabaseKind,
935    table_name: &str,
936    pk_fields: &[&ParsedField],
937    non_pk_fields: &[&ParsedField],
938    use_macro: bool,
939) -> TokenStream {
940    let sql = raw_sql_lit(sql);
941    let sql_macro = raw_sql_lit(sql_macro);
942
943    if use_macro {
944        let macro_args: Vec<TokenStream> = non_pk_fields
945            .iter()
946            .map(|f| macro_arg_for_field(f))
947            .collect();
948
949        match db_kind {
950            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
951                quote! {
952                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
953                        sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
954                            .fetch_one(&self.pool)
955                            .await
956                    }
957                }
958            }
959            DatabaseKind::Mysql => {
960                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
961                let select_sql = raw_sql_lit(&format!(
962                    "SELECT *\nFROM {}\nWHERE {}",
963                    table_name, pk_where
964                ));
965                let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID() as id");
966                quote! {
967                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
968                        sqlx::query!(#sql_macro, #(#macro_args),*)
969                            .execute(&self.pool)
970                            .await?;
971                        let id = sqlx::query_scalar!(#last_insert_id_sql)
972                            .fetch_one(&self.pool)
973                            .await?;
974                        sqlx::query_as!(#entity_ident, #select_sql, id)
975                            .fetch_one(&self.pool)
976                            .await
977                    }
978                }
979            }
980        }
981    } else {
982        match db_kind {
983            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
984                quote! {
985                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
986                        sqlx::query_as::<_, #entity_ident>(#sql)
987                            #(#binds)*
988                            .fetch_one(&self.pool)
989                            .await
990                    }
991                }
992            }
993            DatabaseKind::Mysql => {
994                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
995                let select_sql = raw_sql_lit(&format!(
996                    "SELECT *\nFROM {}\nWHERE {}",
997                    table_name, pk_where
998                ));
999                let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1000                quote! {
1001                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1002                        sqlx::query(#sql)
1003                            #(#binds)*
1004                            .execute(&self.pool)
1005                            .await?;
1006                        let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1007                            .fetch_one(&self.pool)
1008                            .await?;
1009                        sqlx::query_as::<_, #entity_ident>(#select_sql)
1010                            .bind(id)
1011                            .fetch_one(&self.pool)
1012                            .await
1013                    }
1014                }
1015            }
1016        }
1017    }
1018}
1019
1020#[allow(clippy::too_many_arguments)]
1021fn build_insert_many_transactionally_method(
1022    entity_ident: &proc_macro2::Ident,
1023    insert_params_ident: &proc_macro2::Ident,
1024    col_list: &str,
1025    num_cols: usize,
1026    insert_source_fields: &[&ParsedField],
1027    binds_loop: &[TokenStream],
1028    db_kind: DatabaseKind,
1029    table_name: &str,
1030    pk_fields: &[&ParsedField],
1031) -> TokenStream {
1032    let body = match db_kind {
1033        DatabaseKind::Postgres | DatabaseKind::Sqlite => {
1034            let col_list_str = col_list.to_string();
1035            let table_name_str = table_name.to_string();
1036
1037            let row_placeholder_exprs: Vec<TokenStream> = insert_source_fields
1038                .iter()
1039                .enumerate()
1040                .map(|(i, f)| {
1041                    let offset = i;
1042                    match &f.column_default {
1043                        Some(default_expr) => {
1044                            let def = default_expr.as_str();
1045                            match db_kind {
1046                                DatabaseKind::Postgres => quote! {
1047                                    format!("COALESCE(${}, {})", base + #offset + 1, #def)
1048                                },
1049                                _ => quote! {
1050                                    format!("COALESCE(?, {})", #def)
1051                                },
1052                            }
1053                        }
1054                        None => match db_kind {
1055                            DatabaseKind::Postgres => quote! {
1056                                format!("${}", base + #offset + 1)
1057                            },
1058                            _ => quote! {
1059                                "?".to_string()
1060                            },
1061                        },
1062                    }
1063                })
1064                .collect();
1065
1066            quote! {
1067                let mut tx = self.pool.begin().await?;
1068                let mut all_results = Vec::with_capacity(entries.len());
1069                let max_per_chunk = 65535 / #num_cols;
1070                for chunk in entries.chunks(max_per_chunk) {
1071                    let mut values_parts = Vec::with_capacity(chunk.len());
1072                    for (row_idx, _) in chunk.iter().enumerate() {
1073                        let base = row_idx * #num_cols;
1074                        let placeholders = vec![#(#row_placeholder_exprs),*];
1075                        values_parts.push(format!("({})", placeholders.join(", ")));
1076                    }
1077                    let sql = format!(
1078                        "INSERT INTO {} ({})\nVALUES {}\nRETURNING *",
1079                        #table_name_str,
1080                        #col_list_str,
1081                        values_parts.join(", ")
1082                    );
1083                    let mut query = sqlx::query_as::<_, #entity_ident>(&sql);
1084                    for params in chunk {
1085                        #(#binds_loop)*
1086                    }
1087                    let rows = query.fetch_all(&mut *tx).await?;
1088                    all_results.extend(rows);
1089                }
1090                tx.commit().await?;
1091                Ok(all_results)
1092            }
1093        }
1094        DatabaseKind::Mysql => {
1095            let single_placeholders: String = insert_source_fields
1096                .iter()
1097                .enumerate()
1098                .map(|(i, f)| {
1099                    let p = placeholder(db_kind, i + 1);
1100                    match &f.column_default {
1101                        Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
1102                        None => p,
1103                    }
1104                })
1105                .collect::<Vec<_>>()
1106                .join(", ");
1107
1108            let single_insert_sql = raw_sql_lit(&format!(
1109                "INSERT INTO {} ({})\nVALUES ({})",
1110                table_name, col_list, single_placeholders
1111            ));
1112
1113            let single_binds: Vec<TokenStream> = insert_source_fields
1114                .iter()
1115                .map(|f| {
1116                    let name = format_ident!("{}", f.rust_name);
1117                    quote! { .bind(&params.#name) }
1118                })
1119                .collect();
1120
1121            let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
1122            let select_sql = raw_sql_lit(&format!(
1123                "SELECT *\nFROM {}\nWHERE {}",
1124                table_name, pk_where
1125            ));
1126            let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1127
1128            quote! {
1129                let mut tx = self.pool.begin().await?;
1130                let mut results = Vec::with_capacity(entries.len());
1131                for params in &entries {
1132                    sqlx::query(#single_insert_sql)
1133                        #(#single_binds)*
1134                        .execute(&mut *tx)
1135                        .await?;
1136                    let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1137                        .fetch_one(&mut *tx)
1138                        .await?;
1139                    let row = sqlx::query_as::<_, #entity_ident>(#select_sql)
1140                        .bind(id)
1141                        .fetch_one(&mut *tx)
1142                        .await?;
1143                    results.push(row);
1144                }
1145                tx.commit().await?;
1146                Ok(results)
1147            }
1148        }
1149    };
1150
1151    quote! {
1152        pub async fn insert_many_transactionally(
1153            &self,
1154            entries: Vec<#insert_params_ident>,
1155        ) -> Result<Vec<#entity_ident>, sqlx::Error> {
1156            #body
1157        }
1158    }
1159}
1160
1161#[cfg(test)]
1162mod tests {
1163    use super::*;
1164    use crate::cli::Methods;
1165    use crate::codegen::parse_and_format;
1166    use crate::codegen::parse_and_format_with_tab_spaces;
1167
1168    fn make_field(
1169        rust_name: &str,
1170        column_name: &str,
1171        rust_type: &str,
1172        nullable: bool,
1173        is_pk: bool,
1174    ) -> ParsedField {
1175        let inner_type = if nullable {
1176            // Strip "Option<" prefix and ">" suffix
1177            rust_type
1178                .strip_prefix("Option<")
1179                .and_then(|s| s.strip_suffix('>'))
1180                .unwrap_or(rust_type)
1181                .to_string()
1182        } else {
1183            rust_type.to_string()
1184        };
1185        ParsedField {
1186            rust_name: rust_name.to_string(),
1187            column_name: column_name.to_string(),
1188            rust_type: rust_type.to_string(),
1189            is_nullable: nullable,
1190            inner_type,
1191            is_primary_key: is_pk,
1192            sql_type: None,
1193            is_sql_array: false,
1194            column_default: None,
1195        }
1196    }
1197
1198    fn make_field_with_default(
1199        rust_name: &str,
1200        column_name: &str,
1201        rust_type: &str,
1202        nullable: bool,
1203        is_pk: bool,
1204        default: &str,
1205    ) -> ParsedField {
1206        let mut f = make_field(rust_name, column_name, rust_type, nullable, is_pk);
1207        f.column_default = Some(default.to_string());
1208        f
1209    }
1210
1211    fn entity_with_defaults() -> ParsedEntity {
1212        ParsedEntity {
1213            struct_name: "Tasks".to_string(),
1214            table_name: "tasks".to_string(),
1215            schema_name: None,
1216            is_view: false,
1217            fields: vec![
1218                make_field("id", "id", "i32", false, true),
1219                make_field("title", "title", "String", false, false),
1220                make_field_with_default(
1221                    "status",
1222                    "status",
1223                    "String",
1224                    false,
1225                    false,
1226                    "'idle'::task_status",
1227                ),
1228                make_field_with_default("priority", "priority", "i32", false, false, "0"),
1229                make_field_with_default(
1230                    "created_at",
1231                    "created_at",
1232                    "DateTime<Utc>",
1233                    false,
1234                    false,
1235                    "now()",
1236                ),
1237                make_field("description", "description", "Option<String>", true, false),
1238                make_field_with_default(
1239                    "deleted_at",
1240                    "deleted_at",
1241                    "Option<DateTime<Utc>>",
1242                    true,
1243                    false,
1244                    "NULL",
1245                ),
1246            ],
1247            imports: vec![],
1248        }
1249    }
1250
1251    fn standard_entity() -> ParsedEntity {
1252        ParsedEntity {
1253            struct_name: "Users".to_string(),
1254            table_name: "users".to_string(),
1255            schema_name: None,
1256            is_view: false,
1257            fields: vec![
1258                make_field("id", "id", "i32", false, true),
1259                make_field("name", "name", "String", false, false),
1260                make_field("email", "email", "Option<String>", true, false),
1261            ],
1262            imports: vec![],
1263        }
1264    }
1265
1266    fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
1267        let skip = Methods::all();
1268        let (tokens, _) = generate_crud_from_parsed(
1269            entity,
1270            db,
1271            "crate::models::users",
1272            &skip,
1273            false,
1274            PoolVisibility::Private,
1275        );
1276        parse_and_format(&tokens)
1277    }
1278
1279    fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
1280        let skip = Methods::all();
1281        let (tokens, _) = generate_crud_from_parsed(
1282            entity,
1283            db,
1284            "crate::models::users",
1285            &skip,
1286            true,
1287            PoolVisibility::Private,
1288        );
1289        parse_and_format(&tokens)
1290    }
1291
1292    fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
1293        let (tokens, _) = generate_crud_from_parsed(
1294            entity,
1295            db,
1296            "crate::models::users",
1297            methods,
1298            false,
1299            PoolVisibility::Private,
1300        );
1301        parse_and_format(&tokens)
1302    }
1303
1304    fn gen_with_tab_spaces(entity: &ParsedEntity, db: DatabaseKind, tab_spaces: usize) -> String {
1305        let skip = Methods::all();
1306        let (tokens, _) = generate_crud_from_parsed(
1307            entity,
1308            db,
1309            "crate::models::users",
1310            &skip,
1311            false,
1312            PoolVisibility::Private,
1313        );
1314        parse_and_format_with_tab_spaces(&tokens, tab_spaces)
1315    }
1316
1317    // --- basic structure ---
1318
1319    #[test]
1320    fn test_repo_struct_name() {
1321        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1322        assert!(code.contains("pub struct UsersRepository"));
1323    }
1324
1325    #[test]
1326    fn test_repo_new_method() {
1327        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1328        assert!(code.contains("pub fn new("));
1329    }
1330
1331    #[test]
1332    fn test_repo_pool_field_pg() {
1333        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1334        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
1335    }
1336
1337    #[test]
1338    fn test_repo_pool_field_pub() {
1339        let skip = Methods::all();
1340        let (tokens, _) = generate_crud_from_parsed(
1341            &standard_entity(),
1342            DatabaseKind::Postgres,
1343            "crate::models::users",
1344            &skip,
1345            false,
1346            PoolVisibility::Pub,
1347        );
1348        let code = parse_and_format(&tokens);
1349        assert!(
1350            code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool")
1351        );
1352    }
1353
1354    #[test]
1355    fn test_repo_pool_field_pub_crate() {
1356        let skip = Methods::all();
1357        let (tokens, _) = generate_crud_from_parsed(
1358            &standard_entity(),
1359            DatabaseKind::Postgres,
1360            "crate::models::users",
1361            &skip,
1362            false,
1363            PoolVisibility::PubCrate,
1364        );
1365        let code = parse_and_format(&tokens);
1366        assert!(
1367            code.contains("pub(crate) pool: sqlx::PgPool")
1368                || code.contains("pub(crate) pool: sqlx :: PgPool")
1369        );
1370    }
1371
1372    #[test]
1373    fn test_repo_pool_field_private() {
1374        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1375        // Should NOT have `pub pool` or `pub(crate) pool`
1376        assert!(!code.contains("pub pool"));
1377        assert!(!code.contains("pub(crate) pool"));
1378    }
1379
1380    #[test]
1381    fn test_repo_pool_field_mysql() {
1382        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1383        assert!(code.contains("MySqlPool") || code.contains("MySql"));
1384    }
1385
1386    #[test]
1387    fn test_repo_pool_field_sqlite() {
1388        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1389        assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
1390    }
1391
1392    // --- get_all ---
1393
1394    #[test]
1395    fn test_get_all_method() {
1396        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1397        assert!(code.contains("pub async fn get_all"));
1398    }
1399
1400    #[test]
1401    fn test_get_all_returns_vec() {
1402        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1403        assert!(code.contains("Vec<Users>"));
1404    }
1405
1406    #[test]
1407    fn test_get_all_sql() {
1408        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1409        assert!(code.contains("SELECT * FROM users"));
1410    }
1411
1412    // --- paginate ---
1413
1414    #[test]
1415    fn test_paginate_method() {
1416        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1417        assert!(code.contains("pub async fn paginate"));
1418    }
1419
1420    #[test]
1421    fn test_paginate_params_struct() {
1422        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1423        assert!(code.contains("pub struct PaginateUsersParams"));
1424    }
1425
1426    #[test]
1427    fn test_paginate_params_fields() {
1428        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1429        assert!(code.contains("pub page: i64"));
1430        assert!(code.contains("pub per_page: i64"));
1431    }
1432
1433    #[test]
1434    fn test_paginate_returns_paginated() {
1435        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1436        assert!(code.contains("PaginatedUsers"));
1437        assert!(code.contains("PaginationUsersMeta"));
1438    }
1439
1440    #[test]
1441    fn test_paginate_meta_struct() {
1442        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1443        assert!(code.contains("pub struct PaginationUsersMeta"));
1444        assert!(code.contains("pub total: i64"));
1445        assert!(code.contains("pub last_page: i64"));
1446        assert!(code.contains("pub first_page: i64"));
1447        assert!(code.contains("pub current_page: i64"));
1448    }
1449
1450    #[test]
1451    fn test_paginate_data_struct() {
1452        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1453        assert!(code.contains("pub struct PaginatedUsers"));
1454        assert!(code.contains("pub meta: PaginationUsersMeta"));
1455        assert!(code.contains("pub data: Vec<Users>"));
1456    }
1457
1458    #[test]
1459    fn test_paginate_count_sql() {
1460        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1461        assert!(code.contains("SELECT COUNT(*) FROM users"));
1462    }
1463
1464    #[test]
1465    fn test_paginate_sql_pg() {
1466        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1467        assert!(code.contains("LIMIT $1 OFFSET $2"));
1468    }
1469
1470    #[test]
1471    fn test_paginate_sql_mysql() {
1472        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1473        assert!(code.contains("LIMIT ? OFFSET ?"));
1474    }
1475
1476    // --- get ---
1477
1478    #[test]
1479    fn test_get_method() {
1480        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1481        assert!(code.contains("pub async fn get"));
1482    }
1483
1484    #[test]
1485    fn test_get_returns_option() {
1486        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1487        assert!(code.contains("Option<Users>"));
1488    }
1489
1490    #[test]
1491    fn test_get_where_pk_pg() {
1492        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1493        assert!(code.contains("WHERE id = $1"));
1494    }
1495
1496    #[test]
1497    fn test_get_where_pk_mysql() {
1498        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1499        assert!(code.contains("WHERE id = ?"));
1500    }
1501
1502    // --- insert ---
1503
1504    #[test]
1505    fn test_insert_method() {
1506        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1507        assert!(code.contains("pub async fn insert"));
1508    }
1509
1510    #[test]
1511    fn test_insert_params_struct() {
1512        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1513        assert!(code.contains("pub struct InsertUsersParams"));
1514    }
1515
1516    #[test]
1517    fn test_insert_params_no_pk() {
1518        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1519        assert!(code.contains("pub name: String"));
1520        assert!(
1521            code.contains("pub email: Option<String>")
1522                || code.contains("pub email: Option < String >")
1523        );
1524    }
1525
1526    #[test]
1527    fn test_insert_returning_pg() {
1528        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1529        assert!(code.contains("RETURNING *"));
1530    }
1531
1532    #[test]
1533    fn test_insert_returning_sqlite() {
1534        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1535        assert!(code.contains("RETURNING *"));
1536    }
1537
1538    #[test]
1539    fn test_insert_mysql_last_insert_id() {
1540        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1541        assert!(code.contains("LAST_INSERT_ID"));
1542    }
1543
1544    // --- insert with column_default ---
1545
1546    #[test]
1547    fn test_insert_default_col_is_optional() {
1548        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1549        // Fields with column_default and not nullable → Option<T>
1550        let struct_start = code
1551            .find("pub struct InsertTasksParams")
1552            .expect("InsertTasksParams not found");
1553        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1554        let struct_body = &code[struct_start..struct_end];
1555        assert!(
1556            struct_body.contains("Option") && struct_body.contains("status"),
1557            "Expected status as Option in InsertTasksParams: {}",
1558            struct_body
1559        );
1560    }
1561
1562    #[test]
1563    fn test_insert_non_default_col_required() {
1564        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1565        // 'title' has no default → required type (String)
1566        let struct_start = code
1567            .find("pub struct InsertTasksParams")
1568            .expect("InsertTasksParams not found");
1569        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1570        let struct_body = &code[struct_start..struct_end];
1571        assert!(
1572            struct_body.contains("title") && struct_body.contains("String"),
1573            "Expected title as String: {}",
1574            struct_body
1575        );
1576    }
1577
1578    #[test]
1579    fn test_insert_default_col_coalesce_sql() {
1580        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1581        assert!(
1582            code.contains("COALESCE($2, 'idle'::task_status)"),
1583            "Expected COALESCE for status:\n{}",
1584            code
1585        );
1586        assert!(
1587            code.contains("COALESCE($3, 0)"),
1588            "Expected COALESCE for priority:\n{}",
1589            code
1590        );
1591        assert!(
1592            code.contains("COALESCE($4, now())"),
1593            "Expected COALESCE for created_at:\n{}",
1594            code
1595        );
1596    }
1597
1598    #[test]
1599    fn test_insert_default_col_coalesce_json() {
1600        let mut entity = entity_with_defaults();
1601        entity.fields.push(make_field_with_default(
1602            "metadata",
1603            "metadata",
1604            "serde_json::Value",
1605            false,
1606            false,
1607            r#"'{"key": "value"}'::jsonb"#,
1608        ));
1609        let code = gen(&entity, DatabaseKind::Postgres);
1610        assert!(
1611            code.contains(r#"COALESCE($7, '{"key": "value"}'::jsonb)"#),
1612            "Expected COALESCE with JSON default:\n{}",
1613            code
1614        );
1615    }
1616
1617    #[test]
1618    fn test_insert_no_coalesce_for_non_default() {
1619        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1620        // title has no default, so its placeholder should be plain $1 not COALESCE
1621        assert!(
1622            code.contains("VALUES ($1, COALESCE"),
1623            "Expected $1 without COALESCE for title:\n{}",
1624            code
1625        );
1626    }
1627
1628    #[test]
1629    fn test_insert_nullable_with_default_no_double_option() {
1630        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1631        assert!(
1632            !code.contains("Option < Option") && !code.contains("Option<Option"),
1633            "Should not have Option<Option>:\n{}",
1634            code
1635        );
1636    }
1637
1638    #[test]
1639    fn test_insert_derive_default() {
1640        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1641        let struct_start = code
1642            .find("pub struct InsertTasksParams")
1643            .expect("InsertTasksParams not found");
1644        let before_struct = &code[..struct_start];
1645        assert!(
1646            before_struct.ends_with("Default)]\n") || before_struct.contains("Default)]"),
1647            "Expected #[derive(Default)] on InsertTasksParams"
1648        );
1649    }
1650
1651    // --- update (patch with COALESCE) ---
1652
1653    #[test]
1654    fn test_update_method() {
1655        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1656        assert!(code.contains("pub async fn update"));
1657    }
1658
1659    #[test]
1660    fn test_update_params_struct() {
1661        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1662        assert!(code.contains("pub struct UpdateUsersParams"));
1663    }
1664
1665    #[test]
1666    fn test_update_pk_in_fn_signature() {
1667        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1668        // PK 'id: i32' should appear between "fn update" and "UpdateUsersParams"
1669        let update_pos = code.find("fn update").expect("fn update not found");
1670        let params_pos = code[update_pos..]
1671            .find("UpdateUsersParams")
1672            .expect("UpdateUsersParams not found in update fn");
1673        let signature = &code[update_pos..update_pos + params_pos];
1674        assert!(
1675            signature.contains("id"),
1676            "Expected 'id' PK in update fn signature: {}",
1677            signature
1678        );
1679    }
1680
1681    #[test]
1682    fn test_update_pk_not_in_struct() {
1683        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1684        // UpdateUsersParams should NOT contain id field
1685        // Extract the struct definition and check it doesn't have id
1686        let struct_start = code
1687            .find("pub struct UpdateUsersParams")
1688            .expect("UpdateUsersParams not found");
1689        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1690        let struct_body = &code[struct_start..struct_end];
1691        assert!(
1692            !struct_body.contains("pub id"),
1693            "PK 'id' should not be in UpdateUsersParams:\n{}",
1694            struct_body
1695        );
1696    }
1697
1698    #[test]
1699    fn test_update_params_non_nullable_wrapped_in_option() {
1700        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1701        // `name: String` becomes `name: Option<String>` in patch params
1702        assert!(
1703            code.contains("pub name: Option<String>")
1704                || code.contains("pub name : Option < String >")
1705        );
1706    }
1707
1708    #[test]
1709    fn test_update_params_already_nullable_no_double_option() {
1710        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1711        // `email: Option<String>` stays `Option<String>`, NOT `Option<Option<String>>`
1712        assert!(!code.contains("Option<Option") && !code.contains("Option < Option"));
1713    }
1714
1715    #[test]
1716    fn test_update_set_clause_uses_coalesce_pg() {
1717        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1718        assert!(
1719            code.contains("COALESCE($1, name)"),
1720            "Expected COALESCE for name:\n{}",
1721            code
1722        );
1723        assert!(
1724            code.contains("COALESCE($2, email)"),
1725            "Expected COALESCE for email:\n{}",
1726            code
1727        );
1728    }
1729
1730    #[test]
1731    fn test_update_where_clause_pg() {
1732        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1733        assert!(code.contains("WHERE id = $3"));
1734    }
1735
1736    #[test]
1737    fn test_update_returning_pg() {
1738        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1739        assert!(code.contains("COALESCE"));
1740        assert!(code.contains("RETURNING *"));
1741    }
1742
1743    #[test]
1744    fn test_update_set_clause_mysql() {
1745        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1746        assert!(
1747            code.contains("COALESCE(?, name)"),
1748            "Expected COALESCE for MySQL:\n{}",
1749            code
1750        );
1751        assert!(
1752            code.contains("COALESCE(?, email)"),
1753            "Expected COALESCE for email in MySQL:\n{}",
1754            code
1755        );
1756    }
1757
1758    #[test]
1759    fn test_update_set_clause_sqlite() {
1760        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1761        assert!(
1762            code.contains("COALESCE(?, name)"),
1763            "Expected COALESCE for SQLite:\n{}",
1764            code
1765        );
1766    }
1767
1768    // --- overwrite (full replacement, PK as fn param) ---
1769
1770    #[test]
1771    fn test_overwrite_method() {
1772        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1773        assert!(code.contains("pub async fn overwrite"));
1774    }
1775
1776    #[test]
1777    fn test_overwrite_params_struct() {
1778        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1779        assert!(code.contains("pub struct OverwriteUsersParams"));
1780    }
1781
1782    #[test]
1783    fn test_overwrite_pk_in_fn_signature() {
1784        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1785        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1786        let params_pos = code[pos..]
1787            .find("OverwriteUsersParams")
1788            .expect("OverwriteUsersParams not found");
1789        let signature = &code[pos..pos + params_pos];
1790        assert!(
1791            signature.contains("id"),
1792            "Expected PK in overwrite fn signature: {}",
1793            signature
1794        );
1795    }
1796
1797    #[test]
1798    fn test_overwrite_pk_not_in_struct() {
1799        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1800        let struct_start = code
1801            .find("pub struct OverwriteUsersParams")
1802            .expect("OverwriteUsersParams not found");
1803        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1804        let struct_body = &code[struct_start..struct_end];
1805        assert!(
1806            !struct_body.contains("pub id"),
1807            "PK should not be in OverwriteUsersParams: {}",
1808            struct_body
1809        );
1810    }
1811
1812    #[test]
1813    fn test_overwrite_no_coalesce() {
1814        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1815        // Find the overwrite SQL — should have direct SET, no COALESCE
1816        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1817        let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1818        assert!(
1819            !method_body.contains("COALESCE"),
1820            "Overwrite should not use COALESCE: {}",
1821            method_body
1822        );
1823    }
1824
1825    #[test]
1826    fn test_overwrite_set_clause_pg() {
1827        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1828        assert!(code.contains("name = $1,"));
1829        assert!(code.contains("email = $2"));
1830        assert!(code.contains("WHERE id = $3"));
1831    }
1832
1833    #[test]
1834    fn test_overwrite_returning_pg() {
1835        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1836        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1837        let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1838        assert!(
1839            method_body.contains("RETURNING *"),
1840            "Expected RETURNING * in overwrite"
1841        );
1842    }
1843
1844    #[test]
1845    fn test_view_no_overwrite() {
1846        let mut entity = standard_entity();
1847        entity.is_view = true;
1848        let code = gen(&entity, DatabaseKind::Postgres);
1849        assert!(!code.contains("pub async fn overwrite"));
1850    }
1851
1852    #[test]
1853    fn test_without_overwrite() {
1854        let m = Methods {
1855            overwrite: false,
1856            ..Methods::all()
1857        };
1858        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1859        assert!(!code.contains("pub async fn overwrite"));
1860        assert!(!code.contains("OverwriteUsersParams"));
1861    }
1862
1863    #[test]
1864    fn test_update_and_overwrite_coexist() {
1865        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1866        assert!(
1867            code.contains("pub async fn update"),
1868            "Expected update method"
1869        );
1870        assert!(
1871            code.contains("pub async fn overwrite"),
1872            "Expected overwrite method"
1873        );
1874        assert!(
1875            code.contains("UpdateUsersParams"),
1876            "Expected UpdateUsersParams"
1877        );
1878        assert!(
1879            code.contains("OverwriteUsersParams"),
1880            "Expected OverwriteUsersParams"
1881        );
1882    }
1883
1884    // --- delete ---
1885
1886    #[test]
1887    fn test_delete_method() {
1888        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1889        assert!(code.contains("pub async fn delete"));
1890    }
1891
1892    #[test]
1893    fn test_delete_where_pk() {
1894        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1895        assert!(code.contains("DELETE FROM users"));
1896        assert!(code.contains("WHERE id = $1"));
1897    }
1898
1899    #[test]
1900    fn test_tab_spaces_2_sql_indent() {
1901        let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 2);
1902        // SQL content at 8 spaces (4 + 2×2), "# at 6 spaces (4 + 2)
1903        assert!(
1904            code.contains("        SELECT *"),
1905            "Expected SQL at 8-space indent:\n{}",
1906            code
1907        );
1908        assert!(
1909            code.contains("      \"#"),
1910            "Expected closing tag at 6-space indent:\n{}",
1911            code
1912        );
1913    }
1914
1915    #[test]
1916    fn test_tab_spaces_4_sql_indent() {
1917        let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 4);
1918        // SQL content at 12 spaces (4 + 2×4), "# at 8 spaces (4 + 4)
1919        assert!(
1920            code.contains("            SELECT *"),
1921            "Expected SQL at 12-space indent:\n{}",
1922            code
1923        );
1924        assert!(
1925            code.contains("        \"#"),
1926            "Expected closing tag at 8-space indent:\n{}",
1927            code
1928        );
1929    }
1930
1931    #[test]
1932    fn test_delete_returns_unit() {
1933        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1934        assert!(
1935            code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>")
1936        );
1937    }
1938
1939    // --- views (read-only) ---
1940
1941    #[test]
1942    fn test_view_no_insert() {
1943        let mut entity = standard_entity();
1944        entity.is_view = true;
1945        let code = gen(&entity, DatabaseKind::Postgres);
1946        assert!(!code.contains("pub async fn insert"));
1947    }
1948
1949    #[test]
1950    fn test_view_no_update() {
1951        let mut entity = standard_entity();
1952        entity.is_view = true;
1953        let code = gen(&entity, DatabaseKind::Postgres);
1954        assert!(!code.contains("pub async fn update"));
1955    }
1956
1957    #[test]
1958    fn test_view_no_delete() {
1959        let mut entity = standard_entity();
1960        entity.is_view = true;
1961        let code = gen(&entity, DatabaseKind::Postgres);
1962        assert!(!code.contains("pub async fn delete"));
1963    }
1964
1965    #[test]
1966    fn test_view_has_get_all() {
1967        let mut entity = standard_entity();
1968        entity.is_view = true;
1969        let code = gen(&entity, DatabaseKind::Postgres);
1970        assert!(code.contains("pub async fn get_all"));
1971    }
1972
1973    #[test]
1974    fn test_view_has_paginate() {
1975        let mut entity = standard_entity();
1976        entity.is_view = true;
1977        let code = gen(&entity, DatabaseKind::Postgres);
1978        assert!(code.contains("pub async fn paginate"));
1979    }
1980
1981    #[test]
1982    fn test_view_has_get() {
1983        let mut entity = standard_entity();
1984        entity.is_view = true;
1985        let code = gen(&entity, DatabaseKind::Postgres);
1986        assert!(code.contains("pub async fn get"));
1987    }
1988
1989    // --- selective methods ---
1990
1991    #[test]
1992    fn test_only_get_all() {
1993        let m = Methods {
1994            get_all: true,
1995            ..Default::default()
1996        };
1997        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1998        assert!(code.contains("pub async fn get_all"));
1999        assert!(!code.contains("pub async fn paginate"));
2000        assert!(!code.contains("pub async fn insert"));
2001    }
2002
2003    #[test]
2004    fn test_without_get_all() {
2005        let m = Methods {
2006            get_all: false,
2007            ..Methods::all()
2008        };
2009        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2010        assert!(!code.contains("pub async fn get_all"));
2011    }
2012
2013    #[test]
2014    fn test_without_paginate() {
2015        let m = Methods {
2016            paginate: false,
2017            ..Methods::all()
2018        };
2019        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2020        assert!(!code.contains("pub async fn paginate"));
2021        assert!(!code.contains("PaginateUsersParams"));
2022    }
2023
2024    #[test]
2025    fn test_without_get() {
2026        let m = Methods {
2027            get: false,
2028            ..Methods::all()
2029        };
2030        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2031        assert!(code.contains("pub async fn get_all"));
2032        let without_get_all = code.replace("get_all", "XXX");
2033        assert!(!without_get_all.contains("fn get("));
2034    }
2035
2036    #[test]
2037    fn test_without_insert() {
2038        let m = Methods {
2039            insert: false,
2040            insert_many: false,
2041            ..Methods::all()
2042        };
2043        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2044        assert!(!code.contains("pub async fn insert"));
2045        assert!(!code.contains("InsertUsersParams"));
2046    }
2047
2048    #[test]
2049    fn test_without_update() {
2050        let m = Methods {
2051            update: false,
2052            ..Methods::all()
2053        };
2054        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2055        assert!(!code.contains("pub async fn update"));
2056        assert!(!code.contains("UpdateUsersParams"));
2057    }
2058
2059    #[test]
2060    fn test_without_delete() {
2061        let m = Methods {
2062            delete: false,
2063            ..Methods::all()
2064        };
2065        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2066        assert!(!code.contains("pub async fn delete"));
2067    }
2068
2069    #[test]
2070    fn test_empty_methods_no_methods() {
2071        let m = Methods::default();
2072        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2073        assert!(!code.contains("pub async fn get_all"));
2074        assert!(!code.contains("pub async fn paginate"));
2075        assert!(!code.contains("pub async fn insert"));
2076        assert!(!code.contains("pub async fn update"));
2077        assert!(!code.contains("pub async fn overwrite"));
2078        assert!(!code.contains("pub async fn delete"));
2079        assert!(!code.contains("pub async fn insert_many"));
2080    }
2081
2082    // --- imports ---
2083
2084    #[test]
2085    fn test_no_pool_import() {
2086        let skip = Methods::all();
2087        let (_, imports) = generate_crud_from_parsed(
2088            &standard_entity(),
2089            DatabaseKind::Postgres,
2090            "crate::models::users",
2091            &skip,
2092            false,
2093            PoolVisibility::Private,
2094        );
2095        assert!(!imports.iter().any(|i| i.contains("PgPool")));
2096    }
2097
2098    #[test]
2099    fn test_imports_contain_entity() {
2100        let skip = Methods::all();
2101        let (_, imports) = generate_crud_from_parsed(
2102            &standard_entity(),
2103            DatabaseKind::Postgres,
2104            "crate::models::users",
2105            &skip,
2106            false,
2107            PoolVisibility::Private,
2108        );
2109        assert!(imports
2110            .iter()
2111            .any(|i| i.contains("crate::models::users::Users")));
2112    }
2113
2114    // --- renamed columns ---
2115
2116    #[test]
2117    fn test_renamed_column_in_sql() {
2118        let entity = ParsedEntity {
2119            struct_name: "Connector".to_string(),
2120            table_name: "connector".to_string(),
2121            schema_name: None,
2122            is_view: false,
2123            fields: vec![
2124                make_field("id", "id", "i32", false, true),
2125                make_field("connector_type", "type", "String", false, false),
2126            ],
2127            imports: vec![],
2128        };
2129        let code = gen(&entity, DatabaseKind::Postgres);
2130        // INSERT should use the DB column name "type", not "connector_type"
2131        assert!(code.contains("type"));
2132        // The Rust param field should be connector_type
2133        assert!(code.contains("pub connector_type: String"));
2134    }
2135
2136    // --- no PK edge cases ---
2137
2138    #[test]
2139    fn test_no_pk_no_get() {
2140        let entity = ParsedEntity {
2141            struct_name: "Logs".to_string(),
2142            table_name: "logs".to_string(),
2143            schema_name: None,
2144            is_view: false,
2145            fields: vec![
2146                make_field("message", "message", "String", false, false),
2147                make_field("ts", "ts", "String", false, false),
2148            ],
2149            imports: vec![],
2150        };
2151        let code = gen(&entity, DatabaseKind::Postgres);
2152        assert!(code.contains("pub async fn get_all"));
2153        let without_get_all = code.replace("get_all", "XXX");
2154        assert!(!without_get_all.contains("fn get("));
2155    }
2156
2157    #[test]
2158    fn test_no_pk_no_delete() {
2159        let entity = ParsedEntity {
2160            struct_name: "Logs".to_string(),
2161            table_name: "logs".to_string(),
2162            schema_name: None,
2163            is_view: false,
2164            fields: vec![make_field("message", "message", "String", false, false)],
2165            imports: vec![],
2166        };
2167        let code = gen(&entity, DatabaseKind::Postgres);
2168        assert!(!code.contains("pub async fn delete"));
2169    }
2170
2171    // --- Default derive on param structs ---
2172
2173    #[test]
2174    fn test_param_structs_have_default() {
2175        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2176        assert!(code.contains("Default"));
2177    }
2178
2179    // --- entity imports forwarded ---
2180
2181    #[test]
2182    fn test_entity_imports_forwarded() {
2183        let entity = ParsedEntity {
2184            struct_name: "Users".to_string(),
2185            table_name: "users".to_string(),
2186            schema_name: None,
2187            is_view: false,
2188            fields: vec![
2189                make_field("id", "id", "Uuid", false, true),
2190                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
2191            ],
2192            imports: vec![
2193                "use chrono::{DateTime, Utc};".to_string(),
2194                "use uuid::Uuid;".to_string(),
2195            ],
2196        };
2197        let skip = Methods::all();
2198        let (_, imports) = generate_crud_from_parsed(
2199            &entity,
2200            DatabaseKind::Postgres,
2201            "crate::models::users",
2202            &skip,
2203            false,
2204            PoolVisibility::Private,
2205        );
2206        assert!(imports.iter().any(|i| i.contains("chrono")));
2207        assert!(imports.iter().any(|i| i.contains("uuid")));
2208    }
2209
2210    #[test]
2211    fn test_entity_imports_empty_when_no_imports() {
2212        let skip = Methods::all();
2213        let (_, imports) = generate_crud_from_parsed(
2214            &standard_entity(),
2215            DatabaseKind::Postgres,
2216            "crate::models::users",
2217            &skip,
2218            false,
2219            PoolVisibility::Private,
2220        );
2221        // Should only have pool + entity imports, no chrono/uuid
2222        assert!(!imports.iter().any(|i| i.contains("chrono")));
2223        assert!(!imports.iter().any(|i| i.contains("uuid")));
2224    }
2225
2226    // --- query_macro mode ---
2227
2228    #[test]
2229    fn test_macro_get_all() {
2230        let m = Methods {
2231            get_all: true,
2232            ..Default::default()
2233        };
2234        let (tokens, _) = generate_crud_from_parsed(
2235            &standard_entity(),
2236            DatabaseKind::Postgres,
2237            "crate::models::users",
2238            &m,
2239            true,
2240            PoolVisibility::Private,
2241        );
2242        let code = parse_and_format(&tokens);
2243        assert!(code.contains("query_as!"));
2244        assert!(!code.contains("query_as::<"));
2245    }
2246
2247    #[test]
2248    fn test_macro_paginate() {
2249        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2250        assert!(code.contains("query_as!"));
2251        assert!(code.contains("per_page, offset"));
2252    }
2253
2254    #[test]
2255    fn test_macro_get() {
2256        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2257        // The get method should use query_as! with the PK as arg
2258        assert!(code.contains("query_as!(Users"));
2259    }
2260
2261    #[test]
2262    fn test_macro_insert_pg() {
2263        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2264        assert!(code.contains("query_as!(Users"));
2265        assert!(code.contains("params.name"));
2266        assert!(code.contains("params.email"));
2267    }
2268
2269    #[test]
2270    fn test_macro_insert_mysql() {
2271        let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
2272        // MySQL insert uses query! (not query_as!) for the INSERT
2273        assert!(code.contains("query!"));
2274        assert!(code.contains("query_scalar!"));
2275    }
2276
2277    #[test]
2278    fn test_macro_update() {
2279        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2280        assert!(code.contains("query_as!(Users"));
2281        assert!(
2282            code.contains("COALESCE"),
2283            "Expected COALESCE in macro update:\n{}",
2284            code
2285        );
2286        assert!(code.contains("pub async fn update"));
2287        assert!(code.contains("UpdateUsersParams"));
2288    }
2289
2290    #[test]
2291    fn test_macro_delete() {
2292        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2293        // delete uses query! (no return type)
2294        assert!(code.contains("query!"));
2295    }
2296
2297    #[test]
2298    fn test_macro_no_bind_calls() {
2299        // insert_many always uses runtime mode, so exclude it for this test
2300        let m = Methods {
2301            insert_many: false,
2302            ..Methods::all()
2303        };
2304        let (tokens, _) = generate_crud_from_parsed(
2305            &standard_entity(),
2306            DatabaseKind::Postgres,
2307            "crate::models::users",
2308            &m,
2309            true,
2310            PoolVisibility::Private,
2311        );
2312        let code = parse_and_format(&tokens);
2313        assert!(!code.contains(".bind("));
2314    }
2315
2316    #[test]
2317    fn test_function_style_uses_bind() {
2318        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2319        assert!(code.contains(".bind("));
2320        assert!(!code.contains("query_as!("));
2321        assert!(!code.contains("query!("));
2322    }
2323
2324    // --- custom sql_type fallback: macro mode + custom type → runtime for SELECT, macro for DELETE ---
2325
2326    fn entity_with_sql_array() -> ParsedEntity {
2327        ParsedEntity {
2328            struct_name: "AgentConnector".to_string(),
2329            table_name: "agent.agent_connector".to_string(),
2330            schema_name: Some("agent".to_string()),
2331            is_view: false,
2332            fields: vec![
2333                ParsedField {
2334                    rust_name: "connector_id".to_string(),
2335                    column_name: "connector_id".to_string(),
2336                    rust_type: "Uuid".to_string(),
2337                    inner_type: "Uuid".to_string(),
2338                    is_nullable: false,
2339                    is_primary_key: true,
2340                    sql_type: None,
2341                    is_sql_array: false,
2342                    column_default: None,
2343                },
2344                ParsedField {
2345                    rust_name: "agent_id".to_string(),
2346                    column_name: "agent_id".to_string(),
2347                    rust_type: "Uuid".to_string(),
2348                    inner_type: "Uuid".to_string(),
2349                    is_nullable: false,
2350                    is_primary_key: false,
2351                    sql_type: None,
2352                    is_sql_array: false,
2353                    column_default: None,
2354                },
2355                ParsedField {
2356                    rust_name: "usages".to_string(),
2357                    column_name: "usages".to_string(),
2358                    rust_type: "Vec<ConnectorUsages>".to_string(),
2359                    inner_type: "Vec<ConnectorUsages>".to_string(),
2360                    is_nullable: false,
2361                    is_primary_key: false,
2362                    sql_type: Some("agent.connector_usages".to_string()),
2363                    is_sql_array: true,
2364                    column_default: None,
2365                },
2366            ],
2367            imports: vec!["use uuid::Uuid;".to_string()],
2368        }
2369    }
2370
2371    fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
2372        let skip = Methods::all();
2373        let (tokens, _) = generate_crud_from_parsed(
2374            entity,
2375            db,
2376            "crate::models::agent_connector",
2377            &skip,
2378            true,
2379            PoolVisibility::Private,
2380        );
2381        parse_and_format(&tokens)
2382    }
2383
2384    #[test]
2385    fn test_sql_array_macro_get_all_uses_runtime() {
2386        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2387        // get_all should use runtime query_as, not macro
2388        assert!(code.contains("query_as::<"));
2389    }
2390
2391    #[test]
2392    fn test_sql_array_macro_get_uses_runtime() {
2393        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2394        // get should use .bind( since it's runtime
2395        assert!(code.contains(".bind("));
2396    }
2397
2398    #[test]
2399    fn test_sql_array_macro_insert_uses_runtime() {
2400        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2401        // insert RETURNING should use runtime query_as
2402        assert!(
2403            code.contains("query_as::<_ , AgentConnector>")
2404                || code.contains("query_as::<_, AgentConnector>")
2405        );
2406    }
2407
2408    #[test]
2409    fn test_sql_array_macro_delete_still_uses_macro() {
2410        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2411        // delete uses query! macro (no rows returned, no array issue)
2412        assert!(code.contains("query!"));
2413    }
2414
2415    #[test]
2416    fn test_sql_array_no_query_as_macro() {
2417        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2418        // Should NOT contain query_as! macro (only query_as::<_ for runtime)
2419        assert!(!code.contains("query_as!("));
2420    }
2421
2422    // --- custom enum (non-array) also triggers runtime fallback ---
2423
2424    fn entity_with_sql_enum() -> ParsedEntity {
2425        ParsedEntity {
2426            struct_name: "Task".to_string(),
2427            table_name: "tasks".to_string(),
2428            schema_name: None,
2429            is_view: false,
2430            fields: vec![
2431                ParsedField {
2432                    rust_name: "id".to_string(),
2433                    column_name: "id".to_string(),
2434                    rust_type: "i32".to_string(),
2435                    inner_type: "i32".to_string(),
2436                    is_nullable: false,
2437                    is_primary_key: true,
2438                    sql_type: None,
2439                    is_sql_array: false,
2440                    column_default: None,
2441                },
2442                ParsedField {
2443                    rust_name: "status".to_string(),
2444                    column_name: "status".to_string(),
2445                    rust_type: "TaskStatus".to_string(),
2446                    inner_type: "TaskStatus".to_string(),
2447                    is_nullable: false,
2448                    is_primary_key: false,
2449                    sql_type: Some("task_status".to_string()),
2450                    is_sql_array: false,
2451                    column_default: None,
2452                },
2453            ],
2454            imports: vec![],
2455        }
2456    }
2457
2458    #[test]
2459    fn test_sql_enum_macro_uses_runtime() {
2460        let skip = Methods::all();
2461        let (tokens, _) = generate_crud_from_parsed(
2462            &entity_with_sql_enum(),
2463            DatabaseKind::Postgres,
2464            "crate::models::task",
2465            &skip,
2466            true,
2467            PoolVisibility::Private,
2468        );
2469        let code = parse_and_format(&tokens);
2470        // SELECT queries should use runtime query_as, not macro
2471        assert!(code.contains("query_as::<"));
2472        assert!(!code.contains("query_as!("));
2473    }
2474
2475    #[test]
2476    fn test_sql_enum_macro_delete_still_uses_macro() {
2477        let skip = Methods::all();
2478        let (tokens, _) = generate_crud_from_parsed(
2479            &entity_with_sql_enum(),
2480            DatabaseKind::Postgres,
2481            "crate::models::task",
2482            &skip,
2483            true,
2484            PoolVisibility::Private,
2485        );
2486        let code = parse_and_format(&tokens);
2487        // DELETE still uses query! macro
2488        assert!(code.contains("query!"));
2489    }
2490
2491    // --- Vec<String> native array uses .as_slice() in macro mode ---
2492
2493    fn entity_with_vec_string() -> ParsedEntity {
2494        ParsedEntity {
2495            struct_name: "PromptHistory".to_string(),
2496            table_name: "prompt_history".to_string(),
2497            schema_name: None,
2498            is_view: false,
2499            fields: vec![
2500                ParsedField {
2501                    rust_name: "id".to_string(),
2502                    column_name: "id".to_string(),
2503                    rust_type: "Uuid".to_string(),
2504                    inner_type: "Uuid".to_string(),
2505                    is_nullable: false,
2506                    is_primary_key: true,
2507                    sql_type: None,
2508                    is_sql_array: false,
2509                    column_default: None,
2510                },
2511                ParsedField {
2512                    rust_name: "content".to_string(),
2513                    column_name: "content".to_string(),
2514                    rust_type: "String".to_string(),
2515                    inner_type: "String".to_string(),
2516                    is_nullable: false,
2517                    is_primary_key: false,
2518                    sql_type: None,
2519                    is_sql_array: false,
2520                    column_default: None,
2521                },
2522                ParsedField {
2523                    rust_name: "tags".to_string(),
2524                    column_name: "tags".to_string(),
2525                    rust_type: "Vec<String>".to_string(),
2526                    inner_type: "Vec<String>".to_string(),
2527                    is_nullable: false,
2528                    is_primary_key: false,
2529                    sql_type: None,
2530                    is_sql_array: false,
2531                    column_default: None,
2532                },
2533            ],
2534            imports: vec!["use uuid::Uuid;".to_string()],
2535        }
2536    }
2537
2538    #[test]
2539    fn test_vec_string_macro_insert_uses_as_slice() {
2540        let skip = Methods::all();
2541        let (tokens, _) = generate_crud_from_parsed(
2542            &entity_with_vec_string(),
2543            DatabaseKind::Postgres,
2544            "crate::models::prompt_history",
2545            &skip,
2546            true,
2547            PoolVisibility::Private,
2548        );
2549        let code = parse_and_format(&tokens);
2550        assert!(code.contains("as_slice()"));
2551    }
2552
2553    #[test]
2554    fn test_vec_string_macro_update_uses_as_slice() {
2555        let skip = Methods::all();
2556        let (tokens, _) = generate_crud_from_parsed(
2557            &entity_with_vec_string(),
2558            DatabaseKind::Postgres,
2559            "crate::models::prompt_history",
2560            &skip,
2561            true,
2562            PoolVisibility::Private,
2563        );
2564        let code = parse_and_format(&tokens);
2565        // Should have as_slice() for insert and update
2566        let count = code.matches("as_slice()").count();
2567        assert!(
2568            count >= 2,
2569            "expected at least 2 as_slice() calls (insert + update), found {}",
2570            count
2571        );
2572    }
2573
2574    #[test]
2575    fn test_vec_string_non_macro_no_as_slice() {
2576        let skip = Methods::all();
2577        let (tokens, _) = generate_crud_from_parsed(
2578            &entity_with_vec_string(),
2579            DatabaseKind::Postgres,
2580            "crate::models::prompt_history",
2581            &skip,
2582            false,
2583            PoolVisibility::Private,
2584        );
2585        let code = parse_and_format(&tokens);
2586        // Runtime mode uses .bind() so no as_slice needed
2587        assert!(!code.contains("as_slice()"));
2588    }
2589
2590    #[test]
2591    fn test_vec_string_parsed_from_source_uses_as_slice() {
2592        use crate::codegen::entity_parser::parse_entity_source;
2593        let source = r#"
2594            use uuid::Uuid;
2595
2596            #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
2597            #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
2598            pub struct PromptHistory {
2599                #[sqlx_gen(primary_key)]
2600                pub id: Uuid,
2601                pub content: String,
2602                pub tags: Vec<String>,
2603            }
2604        "#;
2605        let entity = parse_entity_source(source).unwrap();
2606        let skip = Methods::all();
2607        let (tokens, _) = generate_crud_from_parsed(
2608            &entity,
2609            DatabaseKind::Postgres,
2610            "crate::models::prompt_history",
2611            &skip,
2612            true,
2613            PoolVisibility::Private,
2614        );
2615        let code = parse_and_format(&tokens);
2616        assert!(
2617            code.contains("as_slice()"),
2618            "Expected as_slice() in generated code:\n{}",
2619            code
2620        );
2621    }
2622
2623    // --- composite PK only (junction table) ---
2624
2625    fn junction_entity() -> ParsedEntity {
2626        ParsedEntity {
2627            struct_name: "AnalysisRecord".to_string(),
2628            table_name: "analysis.analysis__record".to_string(),
2629            schema_name: None,
2630            is_view: false,
2631            fields: vec![
2632                make_field("record_id", "record_id", "uuid::Uuid", false, true),
2633                make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
2634            ],
2635            imports: vec![],
2636        }
2637    }
2638
2639    #[test]
2640    fn test_composite_pk_only_insert_generated() {
2641        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2642        assert!(
2643            code.contains("pub struct InsertAnalysisRecordParams"),
2644            "Expected InsertAnalysisRecordParams struct:\n{}",
2645            code
2646        );
2647        assert!(
2648            code.contains("pub record_id"),
2649            "Expected record_id field in insert params:\n{}",
2650            code
2651        );
2652        assert!(
2653            code.contains("pub analysis_id"),
2654            "Expected analysis_id field in insert params:\n{}",
2655            code
2656        );
2657        assert!(
2658            code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id)"),
2659            "Expected INSERT INTO clause:\n{}",
2660            code
2661        );
2662        assert!(
2663            code.contains("VALUES ($1, $2)"),
2664            "Expected VALUES clause:\n{}",
2665            code
2666        );
2667        assert!(
2668            code.contains("RETURNING *"),
2669            "Expected RETURNING clause:\n{}",
2670            code
2671        );
2672        assert!(
2673            code.contains("pub async fn insert"),
2674            "Expected insert method:\n{}",
2675            code
2676        );
2677    }
2678
2679    #[test]
2680    fn test_composite_pk_only_no_update() {
2681        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2682        assert!(
2683            !code.contains("UpdateAnalysisRecordParams"),
2684            "Expected no UpdateAnalysisRecordParams struct:\n{}",
2685            code
2686        );
2687        assert!(
2688            !code.contains("pub async fn update"),
2689            "Expected no update method:\n{}",
2690            code
2691        );
2692    }
2693
2694    #[test]
2695    fn test_composite_pk_only_delete_generated() {
2696        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2697        assert!(
2698            code.contains("pub async fn delete"),
2699            "Expected delete method:\n{}",
2700            code
2701        );
2702        assert!(
2703            code.contains("DELETE FROM analysis.analysis__record"),
2704            "Expected DELETE clause:\n{}",
2705            code
2706        );
2707        assert!(
2708            code.contains("WHERE record_id = $1 AND analysis_id = $2"),
2709            "Expected WHERE clause:\n{}",
2710            code
2711        );
2712    }
2713
2714    #[test]
2715    fn test_composite_pk_only_get_generated() {
2716        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2717        assert!(
2718            code.contains("pub async fn get"),
2719            "Expected get method:\n{}",
2720            code
2721        );
2722        assert!(
2723            code.contains("WHERE record_id = $1 AND analysis_id = $2"),
2724            "Expected WHERE clause with both PK columns:\n{}",
2725            code
2726        );
2727    }
2728
2729    // --- insert_many_transactionally ---
2730
2731    #[test]
2732    fn test_insert_many_transactionally_method_generated() {
2733        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2734        assert!(
2735            code.contains("pub async fn insert_many_transactionally"),
2736            "Expected insert_many_transactionally method:\n{}",
2737            code
2738        );
2739    }
2740
2741    #[test]
2742    fn test_insert_many_transactionally_signature() {
2743        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2744        assert!(
2745            code.contains("entries: Vec<InsertUsersParams>"),
2746            "Expected Vec<InsertUsersParams> param:\n{}",
2747            code
2748        );
2749        assert!(
2750            code.contains("Result<Vec<Users>"),
2751            "Expected Result<Vec<Users>> return type:\n{}",
2752            code
2753        );
2754    }
2755
2756    #[test]
2757    fn test_insert_many_transactionally_no_strategy_enum() {
2758        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2759        assert!(
2760            !code.contains("TransactionStrategy"),
2761            "TransactionStrategy should not be generated:\n{}",
2762            code
2763        );
2764        assert!(
2765            !code.contains("InsertManyUsersResult"),
2766            "InsertManyUsersResult should not be generated:\n{}",
2767            code
2768        );
2769    }
2770
2771    #[test]
2772    fn test_insert_many_transactionally_uses_transaction_pg() {
2773        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2774        let method_start = code
2775            .find("fn insert_many_transactionally")
2776            .expect("insert_many_transactionally not found");
2777        let method_body = &code[method_start..];
2778        assert!(
2779            method_body.contains("self.pool.begin()"),
2780            "Expected begin():\n{}",
2781            method_body
2782        );
2783        assert!(
2784            method_body.contains("tx.commit()"),
2785            "Expected commit():\n{}",
2786            method_body
2787        );
2788    }
2789
2790    #[test]
2791    fn test_insert_many_transactionally_multi_row_pg() {
2792        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2793        let method_start = code
2794            .find("fn insert_many_transactionally")
2795            .expect("not found");
2796        let method_body = &code[method_start..];
2797        assert!(
2798            method_body.contains("RETURNING *"),
2799            "Expected RETURNING * in multi-row SQL:\n{}",
2800            method_body
2801        );
2802        assert!(
2803            method_body.contains("values_parts"),
2804            "Expected multi-row VALUES building:\n{}",
2805            method_body
2806        );
2807        assert!(
2808            method_body.contains("65535"),
2809            "Expected chunk size limit:\n{}",
2810            method_body
2811        );
2812    }
2813
2814    #[test]
2815    fn test_insert_many_transactionally_multi_row_sqlite() {
2816        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
2817        let method_start = code
2818            .find("fn insert_many_transactionally")
2819            .expect("not found");
2820        let method_body = &code[method_start..];
2821        assert!(
2822            method_body.contains("values_parts"),
2823            "Expected multi-row VALUES building for SQLite:\n{}",
2824            method_body
2825        );
2826        assert!(
2827            method_body.contains("RETURNING *"),
2828            "Expected RETURNING * for SQLite:\n{}",
2829            method_body
2830        );
2831    }
2832
2833    #[test]
2834    fn test_insert_many_transactionally_mysql_individual_inserts() {
2835        let code = gen(&standard_entity(), DatabaseKind::Mysql);
2836        let method_start = code
2837            .find("fn insert_many_transactionally")
2838            .expect("not found");
2839        let method_body = &code[method_start..];
2840        assert!(
2841            method_body.contains("LAST_INSERT_ID"),
2842            "Expected LAST_INSERT_ID for MySQL:\n{}",
2843            method_body
2844        );
2845        assert!(
2846            method_body.contains("self.pool.begin()"),
2847            "Expected begin() for MySQL:\n{}",
2848            method_body
2849        );
2850    }
2851
2852    #[test]
2853    fn test_insert_many_transactionally_view_not_generated() {
2854        let mut entity = standard_entity();
2855        entity.is_view = true;
2856        let code = gen(&entity, DatabaseKind::Postgres);
2857        assert!(
2858            !code.contains("pub async fn insert_many_transactionally"),
2859            "should not be generated for views"
2860        );
2861    }
2862
2863    #[test]
2864    fn test_insert_many_transactionally_without_method_not_generated() {
2865        let m = Methods {
2866            insert_many: false,
2867            ..Methods::all()
2868        };
2869        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2870        assert!(
2871            !code.contains("pub async fn insert_many_transactionally"),
2872            "should not be generated when disabled"
2873        );
2874    }
2875
2876    #[test]
2877    fn test_insert_many_transactionally_generates_params_when_insert_disabled() {
2878        let m = Methods {
2879            insert: false,
2880            insert_many: true,
2881            ..Default::default()
2882        };
2883        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2884        assert!(
2885            code.contains("pub struct InsertUsersParams"),
2886            "Expected InsertUsersParams:\n{}",
2887            code
2888        );
2889        assert!(
2890            code.contains("pub async fn insert_many_transactionally"),
2891            "Expected method:\n{}",
2892            code
2893        );
2894        assert!(
2895            !code.contains("pub async fn insert("),
2896            "insert should not be present:\n{}",
2897            code
2898        );
2899    }
2900
2901    #[test]
2902    fn test_insert_many_transactionally_with_column_defaults_coalesce() {
2903        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
2904        let method_start = code
2905            .find("fn insert_many_transactionally")
2906            .expect("not found");
2907        let method_body = &code[method_start..];
2908        assert!(
2909            method_body.contains("COALESCE"),
2910            "Expected COALESCE for fields with defaults:\n{}",
2911            method_body
2912        );
2913    }
2914
2915    #[test]
2916    fn test_insert_many_transactionally_junction_table() {
2917        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2918        assert!(
2919            code.contains("pub async fn insert_many_transactionally"),
2920            "Expected method for junction table:\n{}",
2921            code
2922        );
2923    }
2924
2925    #[test]
2926    fn test_insert_many_transactionally_all_three_backends_compile() {
2927        for db in [
2928            DatabaseKind::Postgres,
2929            DatabaseKind::Mysql,
2930            DatabaseKind::Sqlite,
2931        ] {
2932            let code = gen(&standard_entity(), db);
2933            assert!(
2934                code.contains("pub async fn insert_many_transactionally"),
2935                "Expected method for {:?}",
2936                db
2937            );
2938        }
2939    }
2940}