Skip to main content

sqlx_gen/codegen/
crud_gen.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8use crate::codegen::identifiers::{quote_ident, quote_qualified};
9
10pub fn generate_crud_from_parsed(
11    entity: &ParsedEntity,
12    db_kind: DatabaseKind,
13    entity_module_path: &str,
14    methods: &Methods,
15    query_macro: bool,
16    pool_visibility: PoolVisibility,
17) -> (TokenStream, BTreeSet<String>) {
18    let mut imports = BTreeSet::new();
19
20    let entity_ident = format_ident!("{}", entity.struct_name);
21    let repo_name = format!("{}Repository", entity.struct_name);
22    let repo_ident = format_ident!("{}", repo_name);
23
24    // Skip schema qualification for the well-known default schema of each
25    // backend ("public" for Postgres, "main" for SQLite, "dbo" for SQL Server-
26    // adjacent flows). Avoids verbose "public"."users" everywhere when the
27    // user has no other schema in play.
28    let schema_for_sql = entity
29        .schema_name
30        .as_deref()
31        .filter(|s| !crate::codegen::is_default_schema(s));
32    let table_name = quote_qualified(schema_for_sql, &entity.table_name, db_kind);
33
34    // Pool type (used via full path sqlx::PgPool etc., no import needed)
35    let pool_type = pool_type_tokens(db_kind);
36
37    // When the entity has custom SQL types (enums, composites, arrays),
38    // query_as! macro can't resolve the column type at compile time. Fall back to runtime query_as::<_, T>()
39    // for queries that return rows. DELETE (no rows returned) can still use macro.
40    let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
41    let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
42
43    // Entity import
44    imports.insert(format!(
45        "use {}::{};",
46        entity_module_path, entity.struct_name
47    ));
48
49    // Forward type imports from the entity file (chrono, uuid, etc.)
50    // Rewrite `use super::X` imports to absolute paths based on entity_module_path,
51    // since the repository lives in a different module where `super` has a different meaning.
52    let entity_parent = entity_module_path
53        .rsplit_once("::")
54        .map(|(parent, _)| parent)
55        .unwrap_or(entity_module_path);
56    for imp in &entity.imports {
57        if let Some(rest) = imp.strip_prefix("use super::") {
58            imports.insert(format!("use {}::{}", entity_parent, rest));
59        } else {
60            imports.insert(imp.clone());
61        }
62    }
63
64    // Primary key fields
65    let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
66
67    // Non-PK fields (for insert)
68    let non_pk_fields: Vec<&ParsedField> =
69        entity.fields.iter().filter(|f| !f.is_primary_key).collect();
70
71    let is_view = entity.is_view;
72
73    // Build method tokens
74    let mut method_tokens = Vec::new();
75    let mut param_structs = Vec::new();
76
77    // --- get_all ---
78    if methods.get_all {
79        let sql = raw_sql_lit(&format!("SELECT * FROM {}", table_name));
80        let method = if use_macro {
81            quote! {
82                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
83                    sqlx::query_as!(#entity_ident, #sql)
84                        .fetch_all(&self.pool)
85                        .await
86                }
87            }
88        } else {
89            quote! {
90                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
91                    sqlx::query_as::<_, #entity_ident>(#sql)
92                        .fetch_all(&self.pool)
93                        .await
94                }
95            }
96        };
97        method_tokens.push(method);
98    }
99
100    // --- paginate ---
101    if methods.paginate {
102        let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
103        let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
104        let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
105        let count_sql = raw_sql_lit(&format!("SELECT COUNT(*) FROM {}", table_name));
106        let sql = raw_sql_lit(&match db_kind {
107            DatabaseKind::Postgres => format!("SELECT *\nFROM {}\nLIMIT $1 OFFSET $2", table_name),
108            DatabaseKind::Mysql | DatabaseKind::Sqlite => {
109                format!("SELECT *\nFROM {}\nLIMIT ? OFFSET ?", table_name)
110            }
111        });
112        let method = if use_macro {
113            quote! {
114                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
115                    let total: i64 = sqlx::query_scalar!(#count_sql)
116                        .fetch_one(&self.pool)
117                        .await?
118                        .unwrap_or(0);
119                    let per_page = params.per_page;
120                    let current_page = params.page;
121                    let last_page = (total + per_page - 1) / per_page;
122                    let offset = (current_page - 1) * per_page;
123                    let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
124                        .fetch_all(&self.pool)
125                        .await?;
126                    Ok(#paginated_ident {
127                        meta: #pagination_meta_ident {
128                            total,
129                            per_page,
130                            current_page,
131                            last_page,
132                            first_page: 1,
133                        },
134                        data,
135                    })
136                }
137            }
138        } else {
139            quote! {
140                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
141                    let total: i64 = sqlx::query_scalar(#count_sql)
142                        .fetch_one(&self.pool)
143                        .await?;
144                    let per_page = params.per_page;
145                    let current_page = params.page;
146                    let last_page = (total + per_page - 1) / per_page;
147                    let offset = (current_page - 1) * per_page;
148                    let data = sqlx::query_as::<_, #entity_ident>(#sql)
149                        .bind(per_page)
150                        .bind(offset)
151                        .fetch_all(&self.pool)
152                        .await?;
153                    Ok(#paginated_ident {
154                        meta: #pagination_meta_ident {
155                            total,
156                            per_page,
157                            current_page,
158                            last_page,
159                            first_page: 1,
160                        },
161                        data,
162                    })
163                }
164            }
165        };
166        method_tokens.push(method);
167        param_structs.push(quote! {
168            #[derive(Debug, Clone, Default)]
169            pub struct #paginate_params_ident {
170                pub page: i64,
171                pub per_page: i64,
172            }
173        });
174        param_structs.push(quote! {
175            #[derive(Debug, Clone)]
176            pub struct #pagination_meta_ident {
177                pub total: i64,
178                pub per_page: i64,
179                pub current_page: i64,
180                pub last_page: i64,
181                pub first_page: i64,
182            }
183        });
184        param_structs.push(quote! {
185            #[derive(Debug, Clone)]
186            pub struct #paginated_ident {
187                pub meta: #pagination_meta_ident,
188                pub data: Vec<#entity_ident>,
189            }
190        });
191    }
192
193    // --- get (by PK) ---
194    if methods.get && !pk_fields.is_empty() {
195        let pk_params: Vec<TokenStream> = pk_fields
196            .iter()
197            .map(|f| {
198                let name = format_ident!("{}", f.rust_name);
199                let ty: TokenStream = f.inner_type.parse().unwrap();
200                quote! { #name: #ty }
201            })
202            .collect();
203
204        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
205        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
206        let sql = raw_sql_lit(&format!(
207            "SELECT *\nFROM {}\nWHERE {}",
208            table_name, where_clause
209        ));
210        let sql_macro = raw_sql_lit(&format!(
211            "SELECT *\nFROM {}\nWHERE {}",
212            table_name, where_clause_cast
213        ));
214
215        let binds: Vec<TokenStream> = pk_fields
216            .iter()
217            .map(|f| {
218                let name = format_ident!("{}", f.rust_name);
219                quote! { .bind(#name) }
220            })
221            .collect();
222
223        let method = if use_macro {
224            let pk_arg_names: Vec<TokenStream> = pk_fields
225                .iter()
226                .map(|f| {
227                    let name = format_ident!("{}", f.rust_name);
228                    quote! { #name }
229                })
230                .collect();
231            quote! {
232                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
233                    sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
234                        .fetch_optional(&self.pool)
235                        .await
236                }
237            }
238        } else {
239            quote! {
240                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
241                    sqlx::query_as::<_, #entity_ident>(#sql)
242                        #(#binds)*
243                        .fetch_optional(&self.pool)
244                        .await
245                }
246            }
247        };
248        method_tokens.push(method);
249    }
250
251    // --- insert (skip for views) ---
252    if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
253        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
254
255        // Insert source fields:
256        // - Junction table (all-PK): use the PKs themselves so something gets inserted.
257        // - Composite PK (>1) with extra columns: include the PKs alongside non-PK
258        //   columns. Composite PKs are typically NOT auto-generated, so omitting
259        //   them would produce a NOT NULL violation. The MySQL branch below also
260        //   relies on the bound PK values when LAST_INSERT_ID is not applicable.
261        // - Single PK + extras: keep the legacy behaviour (exclude the PK and
262        //   assume it's SERIAL/AUTO_INCREMENT).
263        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
264            pk_fields.clone()
265        } else if pk_fields.len() > 1 {
266            let mut combined: Vec<&ParsedField> = pk_fields.clone();
267            combined.extend(non_pk_fields.iter().copied());
268            combined
269        } else {
270            non_pk_fields.clone()
271        };
272
273        // Fields with column_default and not already nullable → Option<T>
274        let insert_fields: Vec<TokenStream> = insert_source_fields
275            .iter()
276            .map(|f| {
277                let name = format_ident!("{}", f.rust_name);
278                if f.column_default.is_some() && !f.is_nullable {
279                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
280                    quote! { pub #name: #ty, }
281                } else {
282                    let ty: TokenStream = f.rust_type.parse().unwrap();
283                    quote! { pub #name: #ty, }
284                }
285            })
286            .collect();
287
288        let col_names: Vec<String> = insert_source_fields
289            .iter()
290            .map(|f| quote_ident(&f.column_name, db_kind))
291            .collect();
292        let col_list = col_names.join(", ");
293
294        // Build placeholders with COALESCE for fields that have a column_default
295        let placeholders: String = insert_source_fields
296            .iter()
297            .enumerate()
298            .map(|(i, f)| {
299                let p = placeholder(db_kind, i + 1);
300                match &f.column_default {
301                    Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
302                    None => p,
303                }
304            })
305            .collect::<Vec<_>>()
306            .join(", ");
307
308        let placeholders_cast: String = insert_source_fields
309            .iter()
310            .enumerate()
311            .map(|(i, f)| {
312                let p = placeholder_with_cast(db_kind, i + 1, f);
313                match &f.column_default {
314                    Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
315                    None => p,
316                }
317            })
318            .collect::<Vec<_>>()
319            .join(", ");
320
321        let build_insert_sql = |ph: &str| match db_kind {
322            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
323                format!(
324                    "INSERT INTO {} ({})\nVALUES ({})\nRETURNING *",
325                    table_name, col_list, ph
326                )
327            }
328            DatabaseKind::Mysql => {
329                format!("INSERT INTO {} ({})\nVALUES ({})", table_name, col_list, ph)
330            }
331        };
332        let sql = build_insert_sql(&placeholders);
333        let sql_macro = build_insert_sql(&placeholders_cast);
334
335        let binds: Vec<TokenStream> = insert_source_fields
336            .iter()
337            .map(|f| {
338                let name = format_ident!("{}", f.rust_name);
339                quote! { .bind(&params.#name) }
340            })
341            .collect();
342
343        let insert_method = build_insert_method_parsed(
344            &entity_ident,
345            &insert_params_ident,
346            &sql,
347            &sql_macro,
348            &binds,
349            db_kind,
350            &table_name,
351            &pk_fields,
352            &insert_source_fields,
353            use_macro,
354        );
355        method_tokens.push(insert_method);
356
357        param_structs.push(quote! {
358            #[derive(Debug, Clone, Default)]
359            pub struct #insert_params_ident {
360                #(#insert_fields)*
361            }
362        });
363    }
364
365    // --- insert_many_transactionally (skip for views) ---
366    if !is_view && methods.insert_many && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
367        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
368
369        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
370            pk_fields.clone()
371        } else {
372            non_pk_fields.clone()
373        };
374
375        let col_names: Vec<String> = insert_source_fields
376            .iter()
377            .map(|f| quote_ident(&f.column_name, db_kind))
378            .collect();
379        let col_list = col_names.join(", ");
380        let num_cols = insert_source_fields.len();
381
382        let binds_loop: Vec<TokenStream> = insert_source_fields
383            .iter()
384            .map(|f| {
385                let name = format_ident!("{}", f.rust_name);
386                quote! { query = query.bind(&params.#name); }
387            })
388            .collect();
389
390        let insert_many_method = build_insert_many_transactionally_method(
391            &entity_ident,
392            &insert_params_ident,
393            &col_list,
394            num_cols,
395            &insert_source_fields,
396            &binds_loop,
397            db_kind,
398            &table_name,
399            &pk_fields,
400        );
401        method_tokens.push(insert_many_method);
402
403        // Only generate InsertParams if we haven't generated it from the insert method
404        if !methods.insert {
405            let insert_fields: Vec<TokenStream> = insert_source_fields
406                .iter()
407                .map(|f| {
408                    let name = format_ident!("{}", f.rust_name);
409                    if f.column_default.is_some() && !f.is_nullable {
410                        let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
411                        quote! { pub #name: #ty, }
412                    } else {
413                        let ty: TokenStream = f.rust_type.parse().unwrap();
414                        quote! { pub #name: #ty, }
415                    }
416                })
417                .collect();
418
419            param_structs.push(quote! {
420                #[derive(Debug, Clone, Default)]
421                pub struct #insert_params_ident {
422                    #(#insert_fields)*
423                }
424            });
425        }
426    }
427
428    // --- overwrite (full replacement — skip for views, skip when all columns are PKs) ---
429    if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
430        let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
431
432        // PK as function parameters (like get/delete)
433        let pk_fn_params: Vec<TokenStream> = pk_fields
434            .iter()
435            .map(|f| {
436                let name = format_ident!("{}", f.rust_name);
437                let ty: TokenStream = f.inner_type.parse().unwrap();
438                quote! { #name: #ty }
439            })
440            .collect();
441
442        // Non-PK fields keep original types (required)
443        let overwrite_fields: Vec<TokenStream> = non_pk_fields
444            .iter()
445            .map(|f| {
446                let name = format_ident!("{}", f.rust_name);
447                let ty: TokenStream = f.rust_type.parse().unwrap();
448                quote! { pub #name: #ty, }
449            })
450            .collect();
451
452        let set_cols: Vec<String> = non_pk_fields
453            .iter()
454            .enumerate()
455            .map(|(i, f)| {
456                let p = placeholder(db_kind, i + 1);
457                format!("{} = {}", quote_ident(&f.column_name, db_kind), p)
458            })
459            .collect();
460        let set_clause = set_cols.join(",\n  ");
461
462        let set_cols_cast: Vec<String> = non_pk_fields
463            .iter()
464            .enumerate()
465            .map(|(i, f)| {
466                let p = placeholder_with_cast(db_kind, i + 1, f);
467                format!("{} = {}", quote_ident(&f.column_name, db_kind), p)
468            })
469            .collect();
470        let set_clause_cast = set_cols_cast.join(",\n  ");
471
472        let pk_start = non_pk_fields.len() + 1;
473        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
474        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
475
476        let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
477            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
478                format!(
479                    "UPDATE {}\nSET\n  {}\nWHERE {}\nRETURNING *",
480                    table_name, sc, wc
481                )
482            }
483            DatabaseKind::Mysql => {
484                format!("UPDATE {}\nSET\n  {}\nWHERE {}", table_name, sc, wc)
485            }
486        };
487
488        let sql = raw_sql_lit(&build_overwrite_sql(&set_clause, &where_clause));
489        let sql_macro = raw_sql_lit(&build_overwrite_sql(&set_clause_cast, &where_clause_cast));
490
491        // Bind non-PK first (from params), then PK (from function args)
492        let mut all_binds: Vec<TokenStream> = non_pk_fields
493            .iter()
494            .map(|f| {
495                let name = format_ident!("{}", f.rust_name);
496                quote! { .bind(&params.#name) }
497            })
498            .collect();
499        for f in &pk_fields {
500            let name = format_ident!("{}", f.rust_name);
501            all_binds.push(quote! { .bind(#name) });
502        }
503
504        // Macro args: non-PK from params, then PK from function args
505        let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
506            .iter()
507            .map(|f| macro_arg_for_field(f))
508            .chain(pk_fields.iter().map(|f| {
509                let name = format_ident!("{}", f.rust_name);
510                quote! { #name }
511            }))
512            .collect();
513
514        let overwrite_method = if use_macro {
515            match db_kind {
516                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
517                    quote! {
518                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
519                            sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
520                                .fetch_one(&self.pool)
521                                .await
522                        }
523                    }
524                }
525                DatabaseKind::Mysql => {
526                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
527                    let select_sql = raw_sql_lit(&format!(
528                        "SELECT *\nFROM {}\nWHERE {}",
529                        table_name, pk_where_select
530                    ));
531                    let pk_macro_args: Vec<TokenStream> = pk_fields
532                        .iter()
533                        .map(|f| {
534                            let name = format_ident!("{}", f.rust_name);
535                            quote! { #name }
536                        })
537                        .collect();
538                    quote! {
539                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
540                            sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
541                                .execute(&self.pool)
542                                .await?;
543                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
544                                .fetch_one(&self.pool)
545                                .await
546                        }
547                    }
548                }
549            }
550        } else {
551            match db_kind {
552                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
553                    quote! {
554                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
555                            sqlx::query_as::<_, #entity_ident>(#sql)
556                                #(#all_binds)*
557                                .fetch_one(&self.pool)
558                                .await
559                        }
560                    }
561                }
562                DatabaseKind::Mysql => {
563                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
564                    let select_sql = raw_sql_lit(&format!(
565                        "SELECT *\nFROM {}\nWHERE {}",
566                        table_name, pk_where_select
567                    ));
568                    let pk_binds: Vec<TokenStream> = pk_fields
569                        .iter()
570                        .map(|f| {
571                            let name = format_ident!("{}", f.rust_name);
572                            quote! { .bind(#name) }
573                        })
574                        .collect();
575                    quote! {
576                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
577                            sqlx::query(#sql)
578                                #(#all_binds)*
579                                .execute(&self.pool)
580                                .await?;
581                            sqlx::query_as::<_, #entity_ident>(#select_sql)
582                                #(#pk_binds)*
583                                .fetch_one(&self.pool)
584                                .await
585                        }
586                    }
587                }
588            }
589        };
590        method_tokens.push(overwrite_method);
591
592        param_structs.push(quote! {
593            #[derive(Debug, Clone, Default)]
594            pub struct #overwrite_params_ident {
595                #(#overwrite_fields)*
596            }
597        });
598    }
599
600    // --- update / patch (COALESCE — skip for views, skip when all columns are PKs) ---
601    if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
602        let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
603
604        // PK as function parameters (like get/delete)
605        let pk_fn_params: Vec<TokenStream> = pk_fields
606            .iter()
607            .map(|f| {
608                let name = format_ident!("{}", f.rust_name);
609                let ty: TokenStream = f.inner_type.parse().unwrap();
610                quote! { #name: #ty }
611            })
612            .collect();
613
614        // Non-PK fields become Option<T> (no double Option for already nullable)
615        let update_fields: Vec<TokenStream> = non_pk_fields
616            .iter()
617            .map(|f| {
618                let name = format_ident!("{}", f.rust_name);
619                if f.is_nullable {
620                    // Already Option<T> — keep as-is to avoid Option<Option<T>>
621                    let ty: TokenStream = f.rust_type.parse().unwrap();
622                    quote! { pub #name: #ty, }
623                } else {
624                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
625                    quote! { pub #name: #ty, }
626                }
627            })
628            .collect();
629
630        // SET clause with COALESCE for runtime mode
631        let set_cols: Vec<String> = non_pk_fields
632            .iter()
633            .enumerate()
634            .map(|(i, f)| {
635                let p = placeholder(db_kind, i + 1);
636                let col = quote_ident(&f.column_name, db_kind);
637                format!("{col} = COALESCE({p}, {col})", col = col, p = p)
638            })
639            .collect();
640        let set_clause = set_cols.join(",\n  ");
641
642        // SET clause with COALESCE and casts for macro mode
643        let set_cols_cast: Vec<String> = non_pk_fields
644            .iter()
645            .enumerate()
646            .map(|(i, f)| {
647                let p = placeholder_with_cast(db_kind, i + 1, f);
648                let col = quote_ident(&f.column_name, db_kind);
649                format!("{col} = COALESCE({p}, {col})", col = col, p = p)
650            })
651            .collect();
652        let set_clause_cast = set_cols_cast.join(",\n  ");
653
654        let pk_start = non_pk_fields.len() + 1;
655        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
656        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
657
658        let build_update_sql = |sc: &str, wc: &str| match db_kind {
659            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
660                format!(
661                    "UPDATE {}\nSET\n  {}\nWHERE {}\nRETURNING *",
662                    table_name, sc, wc
663                )
664            }
665            DatabaseKind::Mysql => {
666                format!("UPDATE {}\nSET\n  {}\nWHERE {}", table_name, sc, wc)
667            }
668        };
669        let sql = raw_sql_lit(&build_update_sql(&set_clause, &where_clause));
670        let sql_macro = raw_sql_lit(&build_update_sql(&set_clause_cast, &where_clause_cast));
671
672        // Bind non-PK first (from params), then PK (from function args)
673        let mut all_binds: Vec<TokenStream> = non_pk_fields
674            .iter()
675            .map(|f| {
676                let name = format_ident!("{}", f.rust_name);
677                quote! { .bind(&params.#name) }
678            })
679            .collect();
680        for f in &pk_fields {
681            let name = format_ident!("{}", f.rust_name);
682            all_binds.push(quote! { .bind(#name) });
683        }
684
685        // Macro args: non-PK from params, then PK from function args
686        let update_macro_args: Vec<TokenStream> = non_pk_fields
687            .iter()
688            .map(|f| macro_arg_for_field(f))
689            .chain(pk_fields.iter().map(|f| {
690                let name = format_ident!("{}", f.rust_name);
691                quote! { #name }
692            }))
693            .collect();
694
695        let update_method = if use_macro {
696            match db_kind {
697                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
698                    quote! {
699                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
700                            sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
701                                .fetch_one(&self.pool)
702                                .await
703                        }
704                    }
705                }
706                DatabaseKind::Mysql => {
707                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
708                    let select_sql = raw_sql_lit(&format!(
709                        "SELECT *\nFROM {}\nWHERE {}",
710                        table_name, pk_where_select
711                    ));
712                    let pk_macro_args: Vec<TokenStream> = pk_fields
713                        .iter()
714                        .map(|f| {
715                            let name = format_ident!("{}", f.rust_name);
716                            quote! { #name }
717                        })
718                        .collect();
719                    quote! {
720                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
721                            sqlx::query!(#sql_macro, #(#update_macro_args),*)
722                                .execute(&self.pool)
723                                .await?;
724                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
725                                .fetch_one(&self.pool)
726                                .await
727                        }
728                    }
729                }
730            }
731        } else {
732            match db_kind {
733                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
734                    quote! {
735                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
736                            sqlx::query_as::<_, #entity_ident>(#sql)
737                                #(#all_binds)*
738                                .fetch_one(&self.pool)
739                                .await
740                        }
741                    }
742                }
743                DatabaseKind::Mysql => {
744                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
745                    let select_sql = raw_sql_lit(&format!(
746                        "SELECT *\nFROM {}\nWHERE {}",
747                        table_name, pk_where_select
748                    ));
749                    let pk_binds: Vec<TokenStream> = pk_fields
750                        .iter()
751                        .map(|f| {
752                            let name = format_ident!("{}", f.rust_name);
753                            quote! { .bind(#name) }
754                        })
755                        .collect();
756                    quote! {
757                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
758                            sqlx::query(#sql)
759                                #(#all_binds)*
760                                .execute(&self.pool)
761                                .await?;
762                            sqlx::query_as::<_, #entity_ident>(#select_sql)
763                                #(#pk_binds)*
764                                .fetch_one(&self.pool)
765                                .await
766                        }
767                    }
768                }
769            }
770        };
771        method_tokens.push(update_method);
772
773        param_structs.push(quote! {
774            #[derive(Debug, Clone, Default)]
775            pub struct #update_params_ident {
776                #(#update_fields)*
777            }
778        });
779    }
780
781    // --- delete (skip for views) ---
782    if !is_view && methods.delete && !pk_fields.is_empty() {
783        let pk_params: Vec<TokenStream> = pk_fields
784            .iter()
785            .map(|f| {
786                let name = format_ident!("{}", f.rust_name);
787                let ty: TokenStream = f.inner_type.parse().unwrap();
788                quote! { #name: #ty }
789            })
790            .collect();
791
792        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
793        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
794        let sql = raw_sql_lit(&format!(
795            "DELETE FROM {}\nWHERE {}",
796            table_name, where_clause
797        ));
798        let sql_macro = raw_sql_lit(&format!(
799            "DELETE FROM {}\nWHERE {}",
800            table_name, where_clause_cast
801        ));
802
803        let binds: Vec<TokenStream> = pk_fields
804            .iter()
805            .map(|f| {
806                let name = format_ident!("{}", f.rust_name);
807                quote! { .bind(#name) }
808            })
809            .collect();
810
811        let method = if query_macro {
812            let pk_arg_names: Vec<TokenStream> = pk_fields
813                .iter()
814                .map(|f| {
815                    let name = format_ident!("{}", f.rust_name);
816                    quote! { #name }
817                })
818                .collect();
819            quote! {
820                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
821                    sqlx::query!(#sql_macro, #(#pk_arg_names),*)
822                        .execute(&self.pool)
823                        .await?;
824                    Ok(())
825                }
826            }
827        } else {
828            quote! {
829                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
830                    sqlx::query(#sql)
831                        #(#binds)*
832                        .execute(&self.pool)
833                        .await?;
834                    Ok(())
835                }
836            }
837        };
838        method_tokens.push(method);
839    }
840
841    let pool_vis: TokenStream = match pool_visibility {
842        PoolVisibility::Private => quote! {},
843        PoolVisibility::Pub => quote! { pub },
844        PoolVisibility::PubCrate => quote! { pub(crate) },
845    };
846
847    let tokens = quote! {
848        #(#param_structs)*
849
850        pub struct #repo_ident {
851            #pool_vis pool: #pool_type,
852        }
853
854        impl #repo_ident {
855            pub fn new(pool: #pool_type) -> Self {
856                Self { pool }
857            }
858
859            #(#method_tokens)*
860        }
861    };
862
863    (tokens, imports)
864}
865
866fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
867    match db_kind {
868        DatabaseKind::Postgres => quote! { sqlx::PgPool },
869        DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
870        DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
871    }
872}
873
874/// Wraps a SQL string as a raw string literal `r#"..."#` (or `r##"..."##`
875/// when the body contains `"#`) in the generated code. Multi-line SQL gets
876/// a leading newline so each clause starts on its own line.
877fn raw_sql_lit(s: &str) -> TokenStream {
878    // Pick the smallest number of `#` characters whose fence isn't present in
879    // the body. Quoted SQL identifiers (`"users"`) followed by `#` are rare in
880    // practice but a malicious or quirky DB name could trip the default fence.
881    let mut hashes = 1usize;
882    while s.contains(&format!("\"{}", "#".repeat(hashes))) {
883        hashes += 1;
884    }
885    let fence = "#".repeat(hashes);
886    let body = if s.contains('\n') {
887        format!("\n{}\n", s)
888    } else {
889        s.to_string()
890    };
891    format!("r{fence}\"{body}\"{fence}", fence = fence, body = body)
892        .parse()
893        .expect("raw_sql_lit must produce a valid Rust raw string literal")
894}
895
896fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
897    match db_kind {
898        DatabaseKind::Postgres => format!("${}", index),
899        DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
900    }
901}
902
903fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
904    let base = placeholder(db_kind, index);
905    match (&field.sql_type, field.is_sql_array) {
906        (Some(t), true) => format!("{} as {}[]", base, t),
907        (Some(t), false) => format!("{} as {}", base, t),
908        (None, _) => base,
909    }
910}
911
912fn build_where_clause_parsed(
913    pk_fields: &[&ParsedField],
914    db_kind: DatabaseKind,
915    start_index: usize,
916) -> String {
917    pk_fields
918        .iter()
919        .enumerate()
920        .map(|(i, f)| {
921            let p = placeholder(db_kind, start_index + i);
922            format!("{} = {}", quote_ident(&f.column_name, db_kind), p)
923        })
924        .collect::<Vec<_>>()
925        .join(" AND ")
926}
927
928fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
929    let name = format_ident!("{}", field.rust_name);
930    let check_type = if field.is_nullable {
931        &field.inner_type
932    } else {
933        &field.rust_type
934    };
935    let normalized = check_type.replace(' ', "");
936    if normalized.starts_with("Vec<") {
937        quote! { params.#name.as_slice() }
938    } else {
939        quote! { params.#name }
940    }
941}
942
943fn build_where_clause_cast(
944    pk_fields: &[&ParsedField],
945    db_kind: DatabaseKind,
946    start_index: usize,
947) -> String {
948    pk_fields
949        .iter()
950        .enumerate()
951        .map(|(i, f)| {
952            let p = placeholder_with_cast(db_kind, start_index + i, f);
953            format!("{} = {}", quote_ident(&f.column_name, db_kind), p)
954        })
955        .collect::<Vec<_>>()
956        .join(" AND ")
957}
958
959#[allow(clippy::too_many_arguments)]
960fn build_insert_method_parsed(
961    entity_ident: &proc_macro2::Ident,
962    insert_params_ident: &proc_macro2::Ident,
963    sql: &str,
964    sql_macro: &str,
965    binds: &[TokenStream],
966    db_kind: DatabaseKind,
967    table_name: &str,
968    pk_fields: &[&ParsedField],
969    non_pk_fields: &[&ParsedField],
970    use_macro: bool,
971) -> TokenStream {
972    let sql = raw_sql_lit(sql);
973    let sql_macro = raw_sql_lit(sql_macro);
974
975    if use_macro {
976        let macro_args: Vec<TokenStream> = non_pk_fields
977            .iter()
978            .map(|f| macro_arg_for_field(f))
979            .collect();
980
981        match db_kind {
982            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
983                quote! {
984                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
985                        sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
986                            .fetch_one(&self.pool)
987                            .await
988                    }
989                }
990            }
991            DatabaseKind::Mysql => {
992                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
993                let select_sql = raw_sql_lit(&format!(
994                    "SELECT *\nFROM {}\nWHERE {}",
995                    table_name, pk_where
996                ));
997                if pk_fields.len() > 1 {
998                    // Composite PK → the user supplied every PK column in
999                    // params, so SELECT by them directly. LAST_INSERT_ID is
1000                    // meaningless for composite keys (only the first
1001                    // auto-increment column populates it, if any).
1002                    let pk_macro_args: Vec<TokenStream> = pk_fields
1003                        .iter()
1004                        .map(|f| {
1005                            let name = format_ident!("{}", f.rust_name);
1006                            quote! { params.#name }
1007                        })
1008                        .collect();
1009                    quote! {
1010                        pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1011                            sqlx::query!(#sql_macro, #(#macro_args),*)
1012                                .execute(&self.pool)
1013                                .await?;
1014                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
1015                                .fetch_one(&self.pool)
1016                                .await
1017                        }
1018                    }
1019                } else {
1020                    let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID() as id");
1021                    quote! {
1022                        pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1023                            sqlx::query!(#sql_macro, #(#macro_args),*)
1024                                .execute(&self.pool)
1025                                .await?;
1026                            let id = sqlx::query_scalar!(#last_insert_id_sql)
1027                                .fetch_one(&self.pool)
1028                                .await?;
1029                            sqlx::query_as!(#entity_ident, #select_sql, id)
1030                                .fetch_one(&self.pool)
1031                                .await
1032                        }
1033                    }
1034                }
1035            }
1036        }
1037    } else {
1038        match db_kind {
1039            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
1040                quote! {
1041                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1042                        sqlx::query_as::<_, #entity_ident>(#sql)
1043                            #(#binds)*
1044                            .fetch_one(&self.pool)
1045                            .await
1046                    }
1047                }
1048            }
1049            DatabaseKind::Mysql => {
1050                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
1051                let select_sql = raw_sql_lit(&format!(
1052                    "SELECT *\nFROM {}\nWHERE {}",
1053                    table_name, pk_where
1054                ));
1055                if pk_fields.len() > 1 {
1056                    let pk_binds: Vec<TokenStream> = pk_fields
1057                        .iter()
1058                        .map(|f| {
1059                            let name = format_ident!("{}", f.rust_name);
1060                            quote! { .bind(&params.#name) }
1061                        })
1062                        .collect();
1063                    quote! {
1064                        pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1065                            sqlx::query(#sql)
1066                                #(#binds)*
1067                                .execute(&self.pool)
1068                                .await?;
1069                            sqlx::query_as::<_, #entity_ident>(#select_sql)
1070                                #(#pk_binds)*
1071                                .fetch_one(&self.pool)
1072                                .await
1073                        }
1074                    }
1075                } else {
1076                    let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1077                    quote! {
1078                        pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
1079                            sqlx::query(#sql)
1080                                #(#binds)*
1081                                .execute(&self.pool)
1082                                .await?;
1083                            let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1084                                .fetch_one(&self.pool)
1085                                .await?;
1086                            sqlx::query_as::<_, #entity_ident>(#select_sql)
1087                                .bind(id)
1088                                .fetch_one(&self.pool)
1089                                .await
1090                        }
1091                    }
1092                }
1093            }
1094        }
1095    }
1096}
1097
1098#[allow(clippy::too_many_arguments)]
1099fn build_insert_many_transactionally_method(
1100    entity_ident: &proc_macro2::Ident,
1101    insert_params_ident: &proc_macro2::Ident,
1102    col_list: &str,
1103    num_cols: usize,
1104    insert_source_fields: &[&ParsedField],
1105    binds_loop: &[TokenStream],
1106    db_kind: DatabaseKind,
1107    table_name: &str,
1108    pk_fields: &[&ParsedField],
1109) -> TokenStream {
1110    let body = match db_kind {
1111        DatabaseKind::Postgres | DatabaseKind::Sqlite => {
1112            let col_list_str = col_list.to_string();
1113            let table_name_str = table_name.to_string();
1114
1115            let row_placeholder_exprs: Vec<TokenStream> = insert_source_fields
1116                .iter()
1117                .enumerate()
1118                .map(|(i, f)| {
1119                    let offset = i;
1120                    match &f.column_default {
1121                        Some(default_expr) => {
1122                            let def = default_expr.as_str();
1123                            match db_kind {
1124                                DatabaseKind::Postgres => quote! {
1125                                    format!("COALESCE(${}, {})", base + #offset + 1, #def)
1126                                },
1127                                _ => quote! {
1128                                    format!("COALESCE(?, {})", #def)
1129                                },
1130                            }
1131                        }
1132                        None => match db_kind {
1133                            DatabaseKind::Postgres => quote! {
1134                                format!("${}", base + #offset + 1)
1135                            },
1136                            _ => quote! {
1137                                "?".to_string()
1138                            },
1139                        },
1140                    }
1141                })
1142                .collect();
1143
1144            quote! {
1145                let mut tx = self.pool.begin().await?;
1146                let mut all_results = Vec::with_capacity(entries.len());
1147                let max_per_chunk = 65535 / #num_cols;
1148                for chunk in entries.chunks(max_per_chunk) {
1149                    let mut values_parts = Vec::with_capacity(chunk.len());
1150                    for (row_idx, _) in chunk.iter().enumerate() {
1151                        let base = row_idx * #num_cols;
1152                        let placeholders = vec![#(#row_placeholder_exprs),*];
1153                        values_parts.push(format!("({})", placeholders.join(", ")));
1154                    }
1155                    let sql = format!(
1156                        "INSERT INTO {} ({})\nVALUES {}\nRETURNING *",
1157                        #table_name_str,
1158                        #col_list_str,
1159                        values_parts.join(", ")
1160                    );
1161                    let mut query = sqlx::query_as::<_, #entity_ident>(&sql);
1162                    for params in chunk {
1163                        #(#binds_loop)*
1164                    }
1165                    let rows = query.fetch_all(&mut *tx).await?;
1166                    all_results.extend(rows);
1167                }
1168                tx.commit().await?;
1169                Ok(all_results)
1170            }
1171        }
1172        DatabaseKind::Mysql => {
1173            let single_placeholders: String = insert_source_fields
1174                .iter()
1175                .enumerate()
1176                .map(|(i, f)| {
1177                    let p = placeholder(db_kind, i + 1);
1178                    match &f.column_default {
1179                        Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
1180                        None => p,
1181                    }
1182                })
1183                .collect::<Vec<_>>()
1184                .join(", ");
1185
1186            let single_insert_sql = raw_sql_lit(&format!(
1187                "INSERT INTO {} ({})\nVALUES ({})",
1188                table_name, col_list, single_placeholders
1189            ));
1190
1191            let single_binds: Vec<TokenStream> = insert_source_fields
1192                .iter()
1193                .map(|f| {
1194                    let name = format_ident!("{}", f.rust_name);
1195                    quote! { .bind(&params.#name) }
1196                })
1197                .collect();
1198
1199            let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
1200            let select_sql = raw_sql_lit(&format!(
1201                "SELECT *\nFROM {}\nWHERE {}",
1202                table_name, pk_where
1203            ));
1204
1205            if pk_fields.len() > 1 {
1206                let pk_binds: Vec<TokenStream> = pk_fields
1207                    .iter()
1208                    .map(|f| {
1209                        let name = format_ident!("{}", f.rust_name);
1210                        quote! { .bind(&params.#name) }
1211                    })
1212                    .collect();
1213                quote! {
1214                    let mut tx = self.pool.begin().await?;
1215                    let mut results = Vec::with_capacity(entries.len());
1216                    for params in &entries {
1217                        sqlx::query(#single_insert_sql)
1218                            #(#single_binds)*
1219                            .execute(&mut *tx)
1220                            .await?;
1221                        let row = sqlx::query_as::<_, #entity_ident>(#select_sql)
1222                            #(#pk_binds)*
1223                            .fetch_one(&mut *tx)
1224                            .await?;
1225                        results.push(row);
1226                    }
1227                    tx.commit().await?;
1228                    Ok(results)
1229                }
1230            } else {
1231                let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1232                quote! {
1233                    let mut tx = self.pool.begin().await?;
1234                    let mut results = Vec::with_capacity(entries.len());
1235                    for params in &entries {
1236                        sqlx::query(#single_insert_sql)
1237                            #(#single_binds)*
1238                            .execute(&mut *tx)
1239                            .await?;
1240                        let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1241                            .fetch_one(&mut *tx)
1242                            .await?;
1243                        let row = sqlx::query_as::<_, #entity_ident>(#select_sql)
1244                            .bind(id)
1245                            .fetch_one(&mut *tx)
1246                            .await?;
1247                        results.push(row);
1248                    }
1249                    tx.commit().await?;
1250                    Ok(results)
1251                }
1252            }
1253        }
1254    };
1255
1256    quote! {
1257        pub async fn insert_many_transactionally(
1258            &self,
1259            entries: Vec<#insert_params_ident>,
1260        ) -> Result<Vec<#entity_ident>, sqlx::Error> {
1261            // Short-circuit empty batches: avoids opening a transaction and
1262            // sending a zero-row INSERT (or a "VALUES " with no tuples, which
1263            // Postgres rejects as a syntax error).
1264            if entries.is_empty() {
1265                return Ok(Vec::new());
1266            }
1267            #body
1268        }
1269    }
1270}
1271
1272#[cfg(test)]
1273mod tests {
1274    use super::*;
1275    use crate::cli::Methods;
1276    use crate::codegen::parse_and_format;
1277    use crate::codegen::parse_and_format_with_tab_spaces;
1278
1279    fn make_field(
1280        rust_name: &str,
1281        column_name: &str,
1282        rust_type: &str,
1283        nullable: bool,
1284        is_pk: bool,
1285    ) -> ParsedField {
1286        let inner_type = if nullable {
1287            // Strip "Option<" prefix and ">" suffix
1288            rust_type
1289                .strip_prefix("Option<")
1290                .and_then(|s| s.strip_suffix('>'))
1291                .unwrap_or(rust_type)
1292                .to_string()
1293        } else {
1294            rust_type.to_string()
1295        };
1296        ParsedField {
1297            rust_name: rust_name.to_string(),
1298            column_name: column_name.to_string(),
1299            rust_type: rust_type.to_string(),
1300            is_nullable: nullable,
1301            inner_type,
1302            is_primary_key: is_pk,
1303            sql_type: None,
1304            is_sql_array: false,
1305            column_default: None,
1306        }
1307    }
1308
1309    fn make_field_with_default(
1310        rust_name: &str,
1311        column_name: &str,
1312        rust_type: &str,
1313        nullable: bool,
1314        is_pk: bool,
1315        default: &str,
1316    ) -> ParsedField {
1317        let mut f = make_field(rust_name, column_name, rust_type, nullable, is_pk);
1318        f.column_default = Some(default.to_string());
1319        f
1320    }
1321
1322    fn entity_with_defaults() -> ParsedEntity {
1323        ParsedEntity {
1324            struct_name: "Tasks".to_string(),
1325            table_name: "tasks".to_string(),
1326            schema_name: None,
1327            is_view: false,
1328            fields: vec![
1329                make_field("id", "id", "i32", false, true),
1330                make_field("title", "title", "String", false, false),
1331                make_field_with_default(
1332                    "status",
1333                    "status",
1334                    "String",
1335                    false,
1336                    false,
1337                    "'idle'::task_status",
1338                ),
1339                make_field_with_default("priority", "priority", "i32", false, false, "0"),
1340                make_field_with_default(
1341                    "created_at",
1342                    "created_at",
1343                    "DateTime<Utc>",
1344                    false,
1345                    false,
1346                    "now()",
1347                ),
1348                make_field("description", "description", "Option<String>", true, false),
1349                make_field_with_default(
1350                    "deleted_at",
1351                    "deleted_at",
1352                    "Option<DateTime<Utc>>",
1353                    true,
1354                    false,
1355                    "NULL",
1356                ),
1357            ],
1358            imports: vec![],
1359        }
1360    }
1361
1362    fn standard_entity() -> ParsedEntity {
1363        ParsedEntity {
1364            struct_name: "Users".to_string(),
1365            table_name: "users".to_string(),
1366            schema_name: None,
1367            is_view: false,
1368            fields: vec![
1369                make_field("id", "id", "i32", false, true),
1370                make_field("name", "name", "String", false, false),
1371                make_field("email", "email", "Option<String>", true, false),
1372            ],
1373            imports: vec![],
1374        }
1375    }
1376
1377    fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
1378        let skip = Methods::all();
1379        let (tokens, _) = generate_crud_from_parsed(
1380            entity,
1381            db,
1382            "crate::models::users",
1383            &skip,
1384            false,
1385            PoolVisibility::Private,
1386        );
1387        parse_and_format(&tokens).unwrap()
1388    }
1389
1390    fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
1391        let skip = Methods::all();
1392        let (tokens, _) = generate_crud_from_parsed(
1393            entity,
1394            db,
1395            "crate::models::users",
1396            &skip,
1397            true,
1398            PoolVisibility::Private,
1399        );
1400        parse_and_format(&tokens).unwrap()
1401    }
1402
1403    fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
1404        let (tokens, _) = generate_crud_from_parsed(
1405            entity,
1406            db,
1407            "crate::models::users",
1408            methods,
1409            false,
1410            PoolVisibility::Private,
1411        );
1412        parse_and_format(&tokens).unwrap()
1413    }
1414
1415    fn gen_with_tab_spaces(entity: &ParsedEntity, db: DatabaseKind, tab_spaces: usize) -> String {
1416        let skip = Methods::all();
1417        let (tokens, _) = generate_crud_from_parsed(
1418            entity,
1419            db,
1420            "crate::models::users",
1421            &skip,
1422            false,
1423            PoolVisibility::Private,
1424        );
1425        parse_and_format_with_tab_spaces(&tokens, tab_spaces).unwrap()
1426    }
1427
1428    // --- basic structure ---
1429
1430    #[test]
1431    fn test_repo_struct_name() {
1432        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1433        assert!(code.contains("pub struct UsersRepository"));
1434    }
1435
1436    #[test]
1437    fn test_repo_new_method() {
1438        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1439        assert!(code.contains("pub fn new("));
1440    }
1441
1442    #[test]
1443    fn test_repo_pool_field_pg() {
1444        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1445        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
1446    }
1447
1448    #[test]
1449    fn test_repo_pool_field_pub() {
1450        let skip = Methods::all();
1451        let (tokens, _) = generate_crud_from_parsed(
1452            &standard_entity(),
1453            DatabaseKind::Postgres,
1454            "crate::models::users",
1455            &skip,
1456            false,
1457            PoolVisibility::Pub,
1458        );
1459        let code = parse_and_format(&tokens).unwrap();
1460        assert!(
1461            code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool")
1462        );
1463    }
1464
1465    #[test]
1466    fn test_repo_pool_field_pub_crate() {
1467        let skip = Methods::all();
1468        let (tokens, _) = generate_crud_from_parsed(
1469            &standard_entity(),
1470            DatabaseKind::Postgres,
1471            "crate::models::users",
1472            &skip,
1473            false,
1474            PoolVisibility::PubCrate,
1475        );
1476        let code = parse_and_format(&tokens).unwrap();
1477        assert!(
1478            code.contains("pub(crate) pool: sqlx::PgPool")
1479                || code.contains("pub(crate) pool: sqlx :: PgPool")
1480        );
1481    }
1482
1483    #[test]
1484    fn test_repo_pool_field_private() {
1485        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1486        // Should NOT have `pub pool` or `pub(crate) pool`
1487        assert!(!code.contains("pub pool"));
1488        assert!(!code.contains("pub(crate) pool"));
1489    }
1490
1491    #[test]
1492    fn test_repo_pool_field_mysql() {
1493        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1494        assert!(code.contains("MySqlPool") || code.contains("MySql"));
1495    }
1496
1497    #[test]
1498    fn test_repo_pool_field_sqlite() {
1499        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1500        assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
1501    }
1502
1503    // --- get_all ---
1504
1505    #[test]
1506    fn test_get_all_method() {
1507        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1508        assert!(code.contains("pub async fn get_all"));
1509    }
1510
1511    #[test]
1512    fn test_get_all_returns_vec() {
1513        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1514        assert!(code.contains("Vec<Users>"));
1515    }
1516
1517    #[test]
1518    fn test_get_all_sql() {
1519        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1520        assert!(code.contains("SELECT * FROM users"));
1521    }
1522
1523    // --- paginate ---
1524
1525    #[test]
1526    fn test_paginate_method() {
1527        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1528        assert!(code.contains("pub async fn paginate"));
1529    }
1530
1531    #[test]
1532    fn test_paginate_params_struct() {
1533        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1534        assert!(code.contains("pub struct PaginateUsersParams"));
1535    }
1536
1537    #[test]
1538    fn test_paginate_params_fields() {
1539        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1540        assert!(code.contains("pub page: i64"));
1541        assert!(code.contains("pub per_page: i64"));
1542    }
1543
1544    #[test]
1545    fn test_paginate_returns_paginated() {
1546        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1547        assert!(code.contains("PaginatedUsers"));
1548        assert!(code.contains("PaginationUsersMeta"));
1549    }
1550
1551    #[test]
1552    fn test_paginate_meta_struct() {
1553        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1554        assert!(code.contains("pub struct PaginationUsersMeta"));
1555        assert!(code.contains("pub total: i64"));
1556        assert!(code.contains("pub last_page: i64"));
1557        assert!(code.contains("pub first_page: i64"));
1558        assert!(code.contains("pub current_page: i64"));
1559    }
1560
1561    #[test]
1562    fn test_paginate_data_struct() {
1563        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1564        assert!(code.contains("pub struct PaginatedUsers"));
1565        assert!(code.contains("pub meta: PaginationUsersMeta"));
1566        assert!(code.contains("pub data: Vec<Users>"));
1567    }
1568
1569    #[test]
1570    fn test_paginate_count_sql() {
1571        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1572        assert!(code.contains("SELECT COUNT(*) FROM users"));
1573    }
1574
1575    #[test]
1576    fn test_paginate_sql_pg() {
1577        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1578        assert!(code.contains("LIMIT $1 OFFSET $2"));
1579    }
1580
1581    #[test]
1582    fn test_paginate_sql_mysql() {
1583        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1584        assert!(code.contains("LIMIT ? OFFSET ?"));
1585    }
1586
1587    // --- get ---
1588
1589    #[test]
1590    fn test_get_method() {
1591        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1592        assert!(code.contains("pub async fn get"));
1593    }
1594
1595    #[test]
1596    fn test_get_returns_option() {
1597        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1598        assert!(code.contains("Option<Users>"));
1599    }
1600
1601    #[test]
1602    fn test_get_where_pk_pg() {
1603        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1604        assert!(code.contains("WHERE id = $1"));
1605    }
1606
1607    #[test]
1608    fn test_get_where_pk_mysql() {
1609        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1610        assert!(code.contains("WHERE id = ?"));
1611    }
1612
1613    // --- insert ---
1614
1615    #[test]
1616    fn test_insert_method() {
1617        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1618        assert!(code.contains("pub async fn insert"));
1619    }
1620
1621    #[test]
1622    fn test_insert_params_struct() {
1623        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1624        assert!(code.contains("pub struct InsertUsersParams"));
1625    }
1626
1627    #[test]
1628    fn test_insert_params_no_pk() {
1629        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1630        assert!(code.contains("pub name: String"));
1631        assert!(
1632            code.contains("pub email: Option<String>")
1633                || code.contains("pub email: Option < String >")
1634        );
1635    }
1636
1637    #[test]
1638    fn test_insert_returning_pg() {
1639        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1640        assert!(code.contains("RETURNING *"));
1641    }
1642
1643    #[test]
1644    fn test_insert_returning_sqlite() {
1645        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1646        assert!(code.contains("RETURNING *"));
1647    }
1648
1649    #[test]
1650    fn test_insert_mysql_last_insert_id() {
1651        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1652        assert!(code.contains("LAST_INSERT_ID"));
1653    }
1654
1655    // --- insert with column_default ---
1656
1657    #[test]
1658    fn test_insert_default_col_is_optional() {
1659        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1660        // Fields with column_default and not nullable → Option<T>
1661        let struct_start = code
1662            .find("pub struct InsertTasksParams")
1663            .expect("InsertTasksParams not found");
1664        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1665        let struct_body = &code[struct_start..struct_end];
1666        assert!(
1667            struct_body.contains("Option") && struct_body.contains("status"),
1668            "Expected status as Option in InsertTasksParams: {}",
1669            struct_body
1670        );
1671    }
1672
1673    #[test]
1674    fn test_insert_non_default_col_required() {
1675        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1676        // 'title' has no default → required type (String)
1677        let struct_start = code
1678            .find("pub struct InsertTasksParams")
1679            .expect("InsertTasksParams not found");
1680        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1681        let struct_body = &code[struct_start..struct_end];
1682        assert!(
1683            struct_body.contains("title") && struct_body.contains("String"),
1684            "Expected title as String: {}",
1685            struct_body
1686        );
1687    }
1688
1689    #[test]
1690    fn test_insert_default_col_coalesce_sql() {
1691        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1692        assert!(
1693            code.contains("COALESCE($2, 'idle'::task_status)"),
1694            "Expected COALESCE for status:\n{}",
1695            code
1696        );
1697        assert!(
1698            code.contains("COALESCE($3, 0)"),
1699            "Expected COALESCE for priority:\n{}",
1700            code
1701        );
1702        assert!(
1703            code.contains("COALESCE($4, now())"),
1704            "Expected COALESCE for created_at:\n{}",
1705            code
1706        );
1707    }
1708
1709    #[test]
1710    fn test_insert_default_col_coalesce_json() {
1711        let mut entity = entity_with_defaults();
1712        entity.fields.push(make_field_with_default(
1713            "metadata",
1714            "metadata",
1715            "serde_json::Value",
1716            false,
1717            false,
1718            r#"'{"key": "value"}'::jsonb"#,
1719        ));
1720        let code = gen(&entity, DatabaseKind::Postgres);
1721        assert!(
1722            code.contains(r#"COALESCE($7, '{"key": "value"}'::jsonb)"#),
1723            "Expected COALESCE with JSON default:\n{}",
1724            code
1725        );
1726    }
1727
1728    #[test]
1729    fn test_insert_no_coalesce_for_non_default() {
1730        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1731        // title has no default, so its placeholder should be plain $1 not COALESCE
1732        assert!(
1733            code.contains("VALUES ($1, COALESCE"),
1734            "Expected $1 without COALESCE for title:\n{}",
1735            code
1736        );
1737    }
1738
1739    #[test]
1740    fn test_insert_nullable_with_default_no_double_option() {
1741        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1742        assert!(
1743            !code.contains("Option < Option") && !code.contains("Option<Option"),
1744            "Should not have Option<Option>:\n{}",
1745            code
1746        );
1747    }
1748
1749    #[test]
1750    fn test_insert_derive_default() {
1751        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1752        let struct_start = code
1753            .find("pub struct InsertTasksParams")
1754            .expect("InsertTasksParams not found");
1755        let before_struct = &code[..struct_start];
1756        assert!(
1757            before_struct.ends_with("Default)]\n") || before_struct.contains("Default)]"),
1758            "Expected #[derive(Default)] on InsertTasksParams"
1759        );
1760    }
1761
1762    // --- update (patch with COALESCE) ---
1763
1764    #[test]
1765    fn test_update_method() {
1766        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1767        assert!(code.contains("pub async fn update"));
1768    }
1769
1770    #[test]
1771    fn test_update_params_struct() {
1772        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1773        assert!(code.contains("pub struct UpdateUsersParams"));
1774    }
1775
1776    #[test]
1777    fn test_update_pk_in_fn_signature() {
1778        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1779        // PK 'id: i32' should appear between "fn update" and "UpdateUsersParams"
1780        let update_pos = code.find("fn update").expect("fn update not found");
1781        let params_pos = code[update_pos..]
1782            .find("UpdateUsersParams")
1783            .expect("UpdateUsersParams not found in update fn");
1784        let signature = &code[update_pos..update_pos + params_pos];
1785        assert!(
1786            signature.contains("id"),
1787            "Expected 'id' PK in update fn signature: {}",
1788            signature
1789        );
1790    }
1791
1792    #[test]
1793    fn test_update_pk_not_in_struct() {
1794        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1795        // UpdateUsersParams should NOT contain id field
1796        // Extract the struct definition and check it doesn't have id
1797        let struct_start = code
1798            .find("pub struct UpdateUsersParams")
1799            .expect("UpdateUsersParams not found");
1800        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1801        let struct_body = &code[struct_start..struct_end];
1802        assert!(
1803            !struct_body.contains("pub id"),
1804            "PK 'id' should not be in UpdateUsersParams:\n{}",
1805            struct_body
1806        );
1807    }
1808
1809    #[test]
1810    fn test_update_params_non_nullable_wrapped_in_option() {
1811        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1812        // `name: String` becomes `name: Option<String>` in patch params
1813        assert!(
1814            code.contains("pub name: Option<String>")
1815                || code.contains("pub name : Option < String >")
1816        );
1817    }
1818
1819    #[test]
1820    fn test_update_params_already_nullable_no_double_option() {
1821        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1822        // `email: Option<String>` stays `Option<String>`, NOT `Option<Option<String>>`
1823        assert!(!code.contains("Option<Option") && !code.contains("Option < Option"));
1824    }
1825
1826    #[test]
1827    fn test_update_set_clause_uses_coalesce_pg() {
1828        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1829        assert!(
1830            code.contains("COALESCE($1, name)"),
1831            "Expected COALESCE for name:\n{}",
1832            code
1833        );
1834        assert!(
1835            code.contains("COALESCE($2, email)"),
1836            "Expected COALESCE for email:\n{}",
1837            code
1838        );
1839    }
1840
1841    #[test]
1842    fn test_update_where_clause_pg() {
1843        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1844        assert!(code.contains("WHERE id = $3"));
1845    }
1846
1847    #[test]
1848    fn test_update_returning_pg() {
1849        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1850        assert!(code.contains("COALESCE"));
1851        assert!(code.contains("RETURNING *"));
1852    }
1853
1854    #[test]
1855    fn test_update_set_clause_mysql() {
1856        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1857        assert!(
1858            code.contains("COALESCE(?, name)"),
1859            "Expected COALESCE for MySQL:\n{}",
1860            code
1861        );
1862        assert!(
1863            code.contains("COALESCE(?, email)"),
1864            "Expected COALESCE for email in MySQL:\n{}",
1865            code
1866        );
1867    }
1868
1869    #[test]
1870    fn test_update_set_clause_sqlite() {
1871        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1872        assert!(
1873            code.contains("COALESCE(?, name)"),
1874            "Expected COALESCE for SQLite:\n{}",
1875            code
1876        );
1877    }
1878
1879    // --- overwrite (full replacement, PK as fn param) ---
1880
1881    #[test]
1882    fn test_overwrite_method() {
1883        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1884        assert!(code.contains("pub async fn overwrite"));
1885    }
1886
1887    #[test]
1888    fn test_overwrite_params_struct() {
1889        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1890        assert!(code.contains("pub struct OverwriteUsersParams"));
1891    }
1892
1893    #[test]
1894    fn test_overwrite_pk_in_fn_signature() {
1895        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1896        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1897        let params_pos = code[pos..]
1898            .find("OverwriteUsersParams")
1899            .expect("OverwriteUsersParams not found");
1900        let signature = &code[pos..pos + params_pos];
1901        assert!(
1902            signature.contains("id"),
1903            "Expected PK in overwrite fn signature: {}",
1904            signature
1905        );
1906    }
1907
1908    #[test]
1909    fn test_overwrite_pk_not_in_struct() {
1910        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1911        let struct_start = code
1912            .find("pub struct OverwriteUsersParams")
1913            .expect("OverwriteUsersParams not found");
1914        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1915        let struct_body = &code[struct_start..struct_end];
1916        assert!(
1917            !struct_body.contains("pub id"),
1918            "PK should not be in OverwriteUsersParams: {}",
1919            struct_body
1920        );
1921    }
1922
1923    #[test]
1924    fn test_overwrite_no_coalesce() {
1925        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1926        // Find the overwrite SQL — should have direct SET, no COALESCE
1927        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1928        let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1929        assert!(
1930            !method_body.contains("COALESCE"),
1931            "Overwrite should not use COALESCE: {}",
1932            method_body
1933        );
1934    }
1935
1936    #[test]
1937    fn test_overwrite_set_clause_pg() {
1938        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1939        assert!(code.contains("name = $1,"));
1940        assert!(code.contains("email = $2"));
1941        assert!(code.contains("WHERE id = $3"));
1942    }
1943
1944    #[test]
1945    fn test_overwrite_returning_pg() {
1946        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1947        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1948        let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1949        assert!(
1950            method_body.contains("RETURNING *"),
1951            "Expected RETURNING * in overwrite"
1952        );
1953    }
1954
1955    #[test]
1956    fn test_view_no_overwrite() {
1957        let mut entity = standard_entity();
1958        entity.is_view = true;
1959        let code = gen(&entity, DatabaseKind::Postgres);
1960        assert!(!code.contains("pub async fn overwrite"));
1961    }
1962
1963    #[test]
1964    fn test_without_overwrite() {
1965        let m = Methods {
1966            overwrite: false,
1967            ..Methods::all()
1968        };
1969        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1970        assert!(!code.contains("pub async fn overwrite"));
1971        assert!(!code.contains("OverwriteUsersParams"));
1972    }
1973
1974    #[test]
1975    fn test_update_and_overwrite_coexist() {
1976        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1977        assert!(
1978            code.contains("pub async fn update"),
1979            "Expected update method"
1980        );
1981        assert!(
1982            code.contains("pub async fn overwrite"),
1983            "Expected overwrite method"
1984        );
1985        assert!(
1986            code.contains("UpdateUsersParams"),
1987            "Expected UpdateUsersParams"
1988        );
1989        assert!(
1990            code.contains("OverwriteUsersParams"),
1991            "Expected OverwriteUsersParams"
1992        );
1993    }
1994
1995    // --- delete ---
1996
1997    #[test]
1998    fn test_delete_method() {
1999        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2000        assert!(code.contains("pub async fn delete"));
2001    }
2002
2003    #[test]
2004    fn test_delete_where_pk() {
2005        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2006        assert!(code.contains("DELETE FROM users"));
2007        assert!(code.contains("WHERE id = $1"));
2008    }
2009
2010    #[test]
2011    fn test_tab_spaces_2_sql_indent() {
2012        let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 2);
2013        // SQL content at 8 spaces (4 + 2×2), "# at 6 spaces (4 + 2)
2014        assert!(
2015            code.contains("        SELECT *"),
2016            "Expected SQL at 8-space indent:\n{}",
2017            code
2018        );
2019        assert!(
2020            code.contains("      \"#"),
2021            "Expected closing tag at 6-space indent:\n{}",
2022            code
2023        );
2024    }
2025
2026    #[test]
2027    fn test_tab_spaces_4_sql_indent() {
2028        let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 4);
2029        // SQL content at 12 spaces (4 + 2×4), "# at 8 spaces (4 + 4)
2030        assert!(
2031            code.contains("            SELECT *"),
2032            "Expected SQL at 12-space indent:\n{}",
2033            code
2034        );
2035        assert!(
2036            code.contains("        \"#"),
2037            "Expected closing tag at 8-space indent:\n{}",
2038            code
2039        );
2040    }
2041
2042    #[test]
2043    fn test_delete_returns_unit() {
2044        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2045        assert!(
2046            code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>")
2047        );
2048    }
2049
2050    // --- views (read-only) ---
2051
2052    #[test]
2053    fn test_view_no_insert() {
2054        let mut entity = standard_entity();
2055        entity.is_view = true;
2056        let code = gen(&entity, DatabaseKind::Postgres);
2057        assert!(!code.contains("pub async fn insert"));
2058    }
2059
2060    #[test]
2061    fn test_view_no_update() {
2062        let mut entity = standard_entity();
2063        entity.is_view = true;
2064        let code = gen(&entity, DatabaseKind::Postgres);
2065        assert!(!code.contains("pub async fn update"));
2066    }
2067
2068    #[test]
2069    fn test_view_no_delete() {
2070        let mut entity = standard_entity();
2071        entity.is_view = true;
2072        let code = gen(&entity, DatabaseKind::Postgres);
2073        assert!(!code.contains("pub async fn delete"));
2074    }
2075
2076    #[test]
2077    fn test_view_has_get_all() {
2078        let mut entity = standard_entity();
2079        entity.is_view = true;
2080        let code = gen(&entity, DatabaseKind::Postgres);
2081        assert!(code.contains("pub async fn get_all"));
2082    }
2083
2084    #[test]
2085    fn test_view_has_paginate() {
2086        let mut entity = standard_entity();
2087        entity.is_view = true;
2088        let code = gen(&entity, DatabaseKind::Postgres);
2089        assert!(code.contains("pub async fn paginate"));
2090    }
2091
2092    #[test]
2093    fn test_view_has_get() {
2094        let mut entity = standard_entity();
2095        entity.is_view = true;
2096        let code = gen(&entity, DatabaseKind::Postgres);
2097        assert!(code.contains("pub async fn get"));
2098    }
2099
2100    // --- selective methods ---
2101
2102    #[test]
2103    fn test_only_get_all() {
2104        let m = Methods {
2105            get_all: true,
2106            ..Default::default()
2107        };
2108        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2109        assert!(code.contains("pub async fn get_all"));
2110        assert!(!code.contains("pub async fn paginate"));
2111        assert!(!code.contains("pub async fn insert"));
2112    }
2113
2114    #[test]
2115    fn test_without_get_all() {
2116        let m = Methods {
2117            get_all: false,
2118            ..Methods::all()
2119        };
2120        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2121        assert!(!code.contains("pub async fn get_all"));
2122    }
2123
2124    #[test]
2125    fn test_without_paginate() {
2126        let m = Methods {
2127            paginate: false,
2128            ..Methods::all()
2129        };
2130        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2131        assert!(!code.contains("pub async fn paginate"));
2132        assert!(!code.contains("PaginateUsersParams"));
2133    }
2134
2135    #[test]
2136    fn test_without_get() {
2137        let m = Methods {
2138            get: false,
2139            ..Methods::all()
2140        };
2141        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2142        assert!(code.contains("pub async fn get_all"));
2143        let without_get_all = code.replace("get_all", "XXX");
2144        assert!(!without_get_all.contains("fn get("));
2145    }
2146
2147    #[test]
2148    fn test_without_insert() {
2149        let m = Methods {
2150            insert: false,
2151            insert_many: false,
2152            ..Methods::all()
2153        };
2154        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2155        assert!(!code.contains("pub async fn insert"));
2156        assert!(!code.contains("InsertUsersParams"));
2157    }
2158
2159    #[test]
2160    fn test_without_update() {
2161        let m = Methods {
2162            update: false,
2163            ..Methods::all()
2164        };
2165        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2166        assert!(!code.contains("pub async fn update"));
2167        assert!(!code.contains("UpdateUsersParams"));
2168    }
2169
2170    #[test]
2171    fn test_without_delete() {
2172        let m = Methods {
2173            delete: false,
2174            ..Methods::all()
2175        };
2176        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2177        assert!(!code.contains("pub async fn delete"));
2178    }
2179
2180    #[test]
2181    fn test_empty_methods_no_methods() {
2182        let m = Methods::default();
2183        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2184        assert!(!code.contains("pub async fn get_all"));
2185        assert!(!code.contains("pub async fn paginate"));
2186        assert!(!code.contains("pub async fn insert"));
2187        assert!(!code.contains("pub async fn update"));
2188        assert!(!code.contains("pub async fn overwrite"));
2189        assert!(!code.contains("pub async fn delete"));
2190        assert!(!code.contains("pub async fn insert_many"));
2191    }
2192
2193    // --- imports ---
2194
2195    #[test]
2196    fn test_no_pool_import() {
2197        let skip = Methods::all();
2198        let (_, imports) = generate_crud_from_parsed(
2199            &standard_entity(),
2200            DatabaseKind::Postgres,
2201            "crate::models::users",
2202            &skip,
2203            false,
2204            PoolVisibility::Private,
2205        );
2206        assert!(!imports.iter().any(|i| i.contains("PgPool")));
2207    }
2208
2209    #[test]
2210    fn test_imports_contain_entity() {
2211        let skip = Methods::all();
2212        let (_, imports) = generate_crud_from_parsed(
2213            &standard_entity(),
2214            DatabaseKind::Postgres,
2215            "crate::models::users",
2216            &skip,
2217            false,
2218            PoolVisibility::Private,
2219        );
2220        assert!(imports
2221            .iter()
2222            .any(|i| i.contains("crate::models::users::Users")));
2223    }
2224
2225    // --- renamed columns ---
2226
2227    #[test]
2228    fn test_renamed_column_in_sql() {
2229        let entity = ParsedEntity {
2230            struct_name: "Connector".to_string(),
2231            table_name: "connector".to_string(),
2232            schema_name: None,
2233            is_view: false,
2234            fields: vec![
2235                make_field("id", "id", "i32", false, true),
2236                make_field("connector_type", "type", "String", false, false),
2237            ],
2238            imports: vec![],
2239        };
2240        let code = gen(&entity, DatabaseKind::Postgres);
2241        // INSERT should use the DB column name "type", not "connector_type"
2242        assert!(code.contains("type"));
2243        // The Rust param field should be connector_type
2244        assert!(code.contains("pub connector_type: String"));
2245    }
2246
2247    // --- no PK edge cases ---
2248
2249    #[test]
2250    fn test_no_pk_no_get() {
2251        let entity = ParsedEntity {
2252            struct_name: "Logs".to_string(),
2253            table_name: "logs".to_string(),
2254            schema_name: None,
2255            is_view: false,
2256            fields: vec![
2257                make_field("message", "message", "String", false, false),
2258                make_field("ts", "ts", "String", false, false),
2259            ],
2260            imports: vec![],
2261        };
2262        let code = gen(&entity, DatabaseKind::Postgres);
2263        assert!(code.contains("pub async fn get_all"));
2264        let without_get_all = code.replace("get_all", "XXX");
2265        assert!(!without_get_all.contains("fn get("));
2266    }
2267
2268    #[test]
2269    fn test_no_pk_no_delete() {
2270        let entity = ParsedEntity {
2271            struct_name: "Logs".to_string(),
2272            table_name: "logs".to_string(),
2273            schema_name: None,
2274            is_view: false,
2275            fields: vec![make_field("message", "message", "String", false, false)],
2276            imports: vec![],
2277        };
2278        let code = gen(&entity, DatabaseKind::Postgres);
2279        assert!(!code.contains("pub async fn delete"));
2280    }
2281
2282    // --- Default derive on param structs ---
2283
2284    #[test]
2285    fn test_param_structs_have_default() {
2286        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2287        assert!(code.contains("Default"));
2288    }
2289
2290    // --- entity imports forwarded ---
2291
2292    #[test]
2293    fn test_entity_imports_forwarded() {
2294        let entity = ParsedEntity {
2295            struct_name: "Users".to_string(),
2296            table_name: "users".to_string(),
2297            schema_name: None,
2298            is_view: false,
2299            fields: vec![
2300                make_field("id", "id", "Uuid", false, true),
2301                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
2302            ],
2303            imports: vec![
2304                "use chrono::{DateTime, Utc};".to_string(),
2305                "use uuid::Uuid;".to_string(),
2306            ],
2307        };
2308        let skip = Methods::all();
2309        let (_, imports) = generate_crud_from_parsed(
2310            &entity,
2311            DatabaseKind::Postgres,
2312            "crate::models::users",
2313            &skip,
2314            false,
2315            PoolVisibility::Private,
2316        );
2317        assert!(imports.iter().any(|i| i.contains("chrono")));
2318        assert!(imports.iter().any(|i| i.contains("uuid")));
2319    }
2320
2321    #[test]
2322    fn test_entity_imports_empty_when_no_imports() {
2323        let skip = Methods::all();
2324        let (_, imports) = generate_crud_from_parsed(
2325            &standard_entity(),
2326            DatabaseKind::Postgres,
2327            "crate::models::users",
2328            &skip,
2329            false,
2330            PoolVisibility::Private,
2331        );
2332        // Should only have pool + entity imports, no chrono/uuid
2333        assert!(!imports.iter().any(|i| i.contains("chrono")));
2334        assert!(!imports.iter().any(|i| i.contains("uuid")));
2335    }
2336
2337    // --- query_macro mode ---
2338
2339    #[test]
2340    fn test_macro_get_all() {
2341        let m = Methods {
2342            get_all: true,
2343            ..Default::default()
2344        };
2345        let (tokens, _) = generate_crud_from_parsed(
2346            &standard_entity(),
2347            DatabaseKind::Postgres,
2348            "crate::models::users",
2349            &m,
2350            true,
2351            PoolVisibility::Private,
2352        );
2353        let code = parse_and_format(&tokens).unwrap();
2354        assert!(code.contains("query_as!"));
2355        assert!(!code.contains("query_as::<"));
2356    }
2357
2358    #[test]
2359    fn test_macro_paginate() {
2360        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2361        assert!(code.contains("query_as!"));
2362        assert!(code.contains("per_page, offset"));
2363    }
2364
2365    #[test]
2366    fn test_macro_get() {
2367        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2368        // The get method should use query_as! with the PK as arg
2369        assert!(code.contains("query_as!(Users"));
2370    }
2371
2372    #[test]
2373    fn test_macro_insert_pg() {
2374        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2375        assert!(code.contains("query_as!(Users"));
2376        assert!(code.contains("params.name"));
2377        assert!(code.contains("params.email"));
2378    }
2379
2380    #[test]
2381    fn test_macro_insert_mysql() {
2382        let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
2383        // MySQL insert uses query! (not query_as!) for the INSERT
2384        assert!(code.contains("query!"));
2385        assert!(code.contains("query_scalar!"));
2386    }
2387
2388    #[test]
2389    fn test_macro_update() {
2390        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2391        assert!(code.contains("query_as!(Users"));
2392        assert!(
2393            code.contains("COALESCE"),
2394            "Expected COALESCE in macro update:\n{}",
2395            code
2396        );
2397        assert!(code.contains("pub async fn update"));
2398        assert!(code.contains("UpdateUsersParams"));
2399    }
2400
2401    #[test]
2402    fn test_macro_delete() {
2403        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
2404        // delete uses query! (no return type)
2405        assert!(code.contains("query!"));
2406    }
2407
2408    #[test]
2409    fn test_macro_no_bind_calls() {
2410        // insert_many always uses runtime mode, so exclude it for this test
2411        let m = Methods {
2412            insert_many: false,
2413            ..Methods::all()
2414        };
2415        let (tokens, _) = generate_crud_from_parsed(
2416            &standard_entity(),
2417            DatabaseKind::Postgres,
2418            "crate::models::users",
2419            &m,
2420            true,
2421            PoolVisibility::Private,
2422        );
2423        let code = parse_and_format(&tokens).unwrap();
2424        assert!(!code.contains(".bind("));
2425    }
2426
2427    #[test]
2428    fn test_function_style_uses_bind() {
2429        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2430        assert!(code.contains(".bind("));
2431        assert!(!code.contains("query_as!("));
2432        assert!(!code.contains("query!("));
2433    }
2434
2435    // --- custom sql_type fallback: macro mode + custom type → runtime for SELECT, macro for DELETE ---
2436
2437    fn entity_with_sql_array() -> ParsedEntity {
2438        ParsedEntity {
2439            struct_name: "AgentConnector".to_string(),
2440            table_name: "agent.agent_connector".to_string(),
2441            schema_name: Some("agent".to_string()),
2442            is_view: false,
2443            fields: vec![
2444                ParsedField {
2445                    rust_name: "connector_id".to_string(),
2446                    column_name: "connector_id".to_string(),
2447                    rust_type: "Uuid".to_string(),
2448                    inner_type: "Uuid".to_string(),
2449                    is_nullable: false,
2450                    is_primary_key: true,
2451                    sql_type: None,
2452                    is_sql_array: false,
2453                    column_default: None,
2454                },
2455                ParsedField {
2456                    rust_name: "agent_id".to_string(),
2457                    column_name: "agent_id".to_string(),
2458                    rust_type: "Uuid".to_string(),
2459                    inner_type: "Uuid".to_string(),
2460                    is_nullable: false,
2461                    is_primary_key: false,
2462                    sql_type: None,
2463                    is_sql_array: false,
2464                    column_default: None,
2465                },
2466                ParsedField {
2467                    rust_name: "usages".to_string(),
2468                    column_name: "usages".to_string(),
2469                    rust_type: "Vec<ConnectorUsages>".to_string(),
2470                    inner_type: "Vec<ConnectorUsages>".to_string(),
2471                    is_nullable: false,
2472                    is_primary_key: false,
2473                    sql_type: Some("agent.connector_usages".to_string()),
2474                    is_sql_array: true,
2475                    column_default: None,
2476                },
2477            ],
2478            imports: vec!["use uuid::Uuid;".to_string()],
2479        }
2480    }
2481
2482    fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
2483        let skip = Methods::all();
2484        let (tokens, _) = generate_crud_from_parsed(
2485            entity,
2486            db,
2487            "crate::models::agent_connector",
2488            &skip,
2489            true,
2490            PoolVisibility::Private,
2491        );
2492        parse_and_format(&tokens).unwrap()
2493    }
2494
2495    #[test]
2496    fn test_sql_array_macro_get_all_uses_runtime() {
2497        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2498        // get_all should use runtime query_as, not macro
2499        assert!(code.contains("query_as::<"));
2500    }
2501
2502    #[test]
2503    fn test_sql_array_macro_get_uses_runtime() {
2504        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2505        // get should use .bind( since it's runtime
2506        assert!(code.contains(".bind("));
2507    }
2508
2509    #[test]
2510    fn test_sql_array_macro_insert_uses_runtime() {
2511        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2512        // insert RETURNING should use runtime query_as
2513        assert!(
2514            code.contains("query_as::<_ , AgentConnector>")
2515                || code.contains("query_as::<_, AgentConnector>")
2516        );
2517    }
2518
2519    #[test]
2520    fn test_sql_array_macro_delete_still_uses_macro() {
2521        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2522        // delete uses query! macro (no rows returned, no array issue)
2523        assert!(code.contains("query!"));
2524    }
2525
2526    #[test]
2527    fn test_sql_array_no_query_as_macro() {
2528        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
2529        // Should NOT contain query_as! macro (only query_as::<_ for runtime)
2530        assert!(!code.contains("query_as!("));
2531    }
2532
2533    // --- custom enum (non-array) also triggers runtime fallback ---
2534
2535    fn entity_with_sql_enum() -> ParsedEntity {
2536        ParsedEntity {
2537            struct_name: "Task".to_string(),
2538            table_name: "tasks".to_string(),
2539            schema_name: None,
2540            is_view: false,
2541            fields: vec![
2542                ParsedField {
2543                    rust_name: "id".to_string(),
2544                    column_name: "id".to_string(),
2545                    rust_type: "i32".to_string(),
2546                    inner_type: "i32".to_string(),
2547                    is_nullable: false,
2548                    is_primary_key: true,
2549                    sql_type: None,
2550                    is_sql_array: false,
2551                    column_default: None,
2552                },
2553                ParsedField {
2554                    rust_name: "status".to_string(),
2555                    column_name: "status".to_string(),
2556                    rust_type: "TaskStatus".to_string(),
2557                    inner_type: "TaskStatus".to_string(),
2558                    is_nullable: false,
2559                    is_primary_key: false,
2560                    sql_type: Some("task_status".to_string()),
2561                    is_sql_array: false,
2562                    column_default: None,
2563                },
2564            ],
2565            imports: vec![],
2566        }
2567    }
2568
2569    #[test]
2570    fn test_sql_enum_macro_uses_runtime() {
2571        let skip = Methods::all();
2572        let (tokens, _) = generate_crud_from_parsed(
2573            &entity_with_sql_enum(),
2574            DatabaseKind::Postgres,
2575            "crate::models::task",
2576            &skip,
2577            true,
2578            PoolVisibility::Private,
2579        );
2580        let code = parse_and_format(&tokens).unwrap();
2581        // SELECT queries should use runtime query_as, not macro
2582        assert!(code.contains("query_as::<"));
2583        assert!(!code.contains("query_as!("));
2584    }
2585
2586    #[test]
2587    fn test_sql_enum_macro_delete_still_uses_macro() {
2588        let skip = Methods::all();
2589        let (tokens, _) = generate_crud_from_parsed(
2590            &entity_with_sql_enum(),
2591            DatabaseKind::Postgres,
2592            "crate::models::task",
2593            &skip,
2594            true,
2595            PoolVisibility::Private,
2596        );
2597        let code = parse_and_format(&tokens).unwrap();
2598        // DELETE still uses query! macro
2599        assert!(code.contains("query!"));
2600    }
2601
2602    // --- Vec<String> native array uses .as_slice() in macro mode ---
2603
2604    fn entity_with_vec_string() -> ParsedEntity {
2605        ParsedEntity {
2606            struct_name: "PromptHistory".to_string(),
2607            table_name: "prompt_history".to_string(),
2608            schema_name: None,
2609            is_view: false,
2610            fields: vec![
2611                ParsedField {
2612                    rust_name: "id".to_string(),
2613                    column_name: "id".to_string(),
2614                    rust_type: "Uuid".to_string(),
2615                    inner_type: "Uuid".to_string(),
2616                    is_nullable: false,
2617                    is_primary_key: true,
2618                    sql_type: None,
2619                    is_sql_array: false,
2620                    column_default: None,
2621                },
2622                ParsedField {
2623                    rust_name: "content".to_string(),
2624                    column_name: "content".to_string(),
2625                    rust_type: "String".to_string(),
2626                    inner_type: "String".to_string(),
2627                    is_nullable: false,
2628                    is_primary_key: false,
2629                    sql_type: None,
2630                    is_sql_array: false,
2631                    column_default: None,
2632                },
2633                ParsedField {
2634                    rust_name: "tags".to_string(),
2635                    column_name: "tags".to_string(),
2636                    rust_type: "Vec<String>".to_string(),
2637                    inner_type: "Vec<String>".to_string(),
2638                    is_nullable: false,
2639                    is_primary_key: false,
2640                    sql_type: None,
2641                    is_sql_array: false,
2642                    column_default: None,
2643                },
2644            ],
2645            imports: vec!["use uuid::Uuid;".to_string()],
2646        }
2647    }
2648
2649    #[test]
2650    fn test_vec_string_macro_insert_uses_as_slice() {
2651        let skip = Methods::all();
2652        let (tokens, _) = generate_crud_from_parsed(
2653            &entity_with_vec_string(),
2654            DatabaseKind::Postgres,
2655            "crate::models::prompt_history",
2656            &skip,
2657            true,
2658            PoolVisibility::Private,
2659        );
2660        let code = parse_and_format(&tokens).unwrap();
2661        assert!(code.contains("as_slice()"));
2662    }
2663
2664    #[test]
2665    fn test_vec_string_macro_update_uses_as_slice() {
2666        let skip = Methods::all();
2667        let (tokens, _) = generate_crud_from_parsed(
2668            &entity_with_vec_string(),
2669            DatabaseKind::Postgres,
2670            "crate::models::prompt_history",
2671            &skip,
2672            true,
2673            PoolVisibility::Private,
2674        );
2675        let code = parse_and_format(&tokens).unwrap();
2676        // Should have as_slice() for insert and update
2677        let count = code.matches("as_slice()").count();
2678        assert!(
2679            count >= 2,
2680            "expected at least 2 as_slice() calls (insert + update), found {}",
2681            count
2682        );
2683    }
2684
2685    #[test]
2686    fn test_vec_string_non_macro_no_as_slice() {
2687        let skip = Methods::all();
2688        let (tokens, _) = generate_crud_from_parsed(
2689            &entity_with_vec_string(),
2690            DatabaseKind::Postgres,
2691            "crate::models::prompt_history",
2692            &skip,
2693            false,
2694            PoolVisibility::Private,
2695        );
2696        let code = parse_and_format(&tokens).unwrap();
2697        // Runtime mode uses .bind() so no as_slice needed
2698        assert!(!code.contains("as_slice()"));
2699    }
2700
2701    #[test]
2702    fn test_vec_string_parsed_from_source_uses_as_slice() {
2703        use crate::codegen::entity_parser::parse_entity_source;
2704        let source = r#"
2705            use uuid::Uuid;
2706
2707            #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
2708            #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
2709            pub struct PromptHistory {
2710                #[sqlx_gen(primary_key)]
2711                pub id: Uuid,
2712                pub content: String,
2713                pub tags: Vec<String>,
2714            }
2715        "#;
2716        let entity = parse_entity_source(source).unwrap();
2717        let skip = Methods::all();
2718        let (tokens, _) = generate_crud_from_parsed(
2719            &entity,
2720            DatabaseKind::Postgres,
2721            "crate::models::prompt_history",
2722            &skip,
2723            true,
2724            PoolVisibility::Private,
2725        );
2726        let code = parse_and_format(&tokens).unwrap();
2727        assert!(
2728            code.contains("as_slice()"),
2729            "Expected as_slice() in generated code:\n{}",
2730            code
2731        );
2732    }
2733
2734    // --- composite PK only (junction table) ---
2735
2736    fn junction_entity() -> ParsedEntity {
2737        ParsedEntity {
2738            struct_name: "AnalysisRecord".to_string(),
2739            table_name: "analysis__record".to_string(),
2740            schema_name: Some("analysis".to_string()),
2741            is_view: false,
2742            fields: vec![
2743                make_field("record_id", "record_id", "uuid::Uuid", false, true),
2744                make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
2745            ],
2746            imports: vec![],
2747        }
2748    }
2749
2750    #[test]
2751    fn test_composite_pk_only_insert_generated() {
2752        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2753        assert!(
2754            code.contains("pub struct InsertAnalysisRecordParams"),
2755            "Expected InsertAnalysisRecordParams struct:\n{}",
2756            code
2757        );
2758        assert!(
2759            code.contains("pub record_id"),
2760            "Expected record_id field in insert params:\n{}",
2761            code
2762        );
2763        assert!(
2764            code.contains("pub analysis_id"),
2765            "Expected analysis_id field in insert params:\n{}",
2766            code
2767        );
2768        assert!(
2769            code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id)"),
2770            "Expected quoted INSERT INTO clause:\n{}",
2771            code
2772        );
2773        assert!(
2774            code.contains("VALUES ($1, $2)"),
2775            "Expected VALUES clause:\n{}",
2776            code
2777        );
2778        assert!(
2779            code.contains("RETURNING *"),
2780            "Expected RETURNING clause:\n{}",
2781            code
2782        );
2783        assert!(
2784            code.contains("pub async fn insert"),
2785            "Expected insert method:\n{}",
2786            code
2787        );
2788    }
2789
2790    #[test]
2791    fn test_composite_pk_only_no_update() {
2792        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2793        assert!(
2794            !code.contains("UpdateAnalysisRecordParams"),
2795            "Expected no UpdateAnalysisRecordParams struct:\n{}",
2796            code
2797        );
2798        assert!(
2799            !code.contains("pub async fn update"),
2800            "Expected no update method:\n{}",
2801            code
2802        );
2803    }
2804
2805    #[test]
2806    fn test_composite_pk_only_delete_generated() {
2807        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2808        assert!(
2809            code.contains("pub async fn delete"),
2810            "Expected delete method:\n{}",
2811            code
2812        );
2813        assert!(
2814            code.contains("DELETE FROM analysis.analysis__record"),
2815            "Expected DELETE clause:\n{}",
2816            code
2817        );
2818        assert!(
2819            code.contains("WHERE record_id = $1 AND analysis_id = $2"),
2820            "Expected WHERE clause:\n{}",
2821            code
2822        );
2823    }
2824
2825    #[test]
2826    fn test_composite_pk_only_get_generated() {
2827        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2828        assert!(
2829            code.contains("pub async fn get"),
2830            "Expected get method:\n{}",
2831            code
2832        );
2833        assert!(
2834            code.contains("WHERE record_id = $1 AND analysis_id = $2"),
2835            "Expected WHERE clause with both PK columns:\n{}",
2836            code
2837        );
2838    }
2839
2840    // --- composite PK + non-PK columns (MySQL) ---
2841
2842    fn composite_pk_with_extra() -> ParsedEntity {
2843        ParsedEntity {
2844            struct_name: "OrderItems".to_string(),
2845            table_name: "order_items".to_string(),
2846            schema_name: None,
2847            is_view: false,
2848            fields: vec![
2849                make_field("order_id", "order_id", "i32", false, true),
2850                make_field("product_id", "product_id", "i32", false, true),
2851                make_field("qty", "qty", "i32", false, false),
2852            ],
2853            imports: vec![],
2854        }
2855    }
2856
2857    #[test]
2858    fn test_mysql_composite_pk_insert_uses_select_not_last_insert_id() {
2859        let code = gen(&composite_pk_with_extra(), DatabaseKind::Mysql);
2860        assert!(
2861            !code.contains("LAST_INSERT_ID"),
2862            "composite PK insert must not use LAST_INSERT_ID(), got:\n{}",
2863            code
2864        );
2865        assert!(
2866            code.contains("SELECT *"),
2867            "must SELECT the row back after INSERT, got:\n{}",
2868            code
2869        );
2870        assert!(
2871            code.contains("WHERE order_id = ? AND product_id = ?"),
2872            "SELECT must use bound composite PK values, got:\n{}",
2873            code
2874        );
2875    }
2876
2877    #[test]
2878    fn test_mysql_composite_pk_includes_pks_in_insert_params() {
2879        let code = gen(&composite_pk_with_extra(), DatabaseKind::Mysql);
2880        assert!(
2881            code.contains("pub order_id"),
2882            "InsertParams must expose composite PK column order_id, got:\n{}",
2883            code
2884        );
2885        assert!(code.contains("pub product_id"));
2886        assert!(code.contains("pub qty"));
2887    }
2888
2889    #[test]
2890    fn test_mysql_single_pk_insert_still_uses_last_insert_id() {
2891        let code = gen(&standard_entity(), DatabaseKind::Mysql);
2892        assert!(
2893            code.contains("LAST_INSERT_ID"),
2894            "single-PK MySQL insert should still rely on LAST_INSERT_ID()"
2895        );
2896    }
2897
2898    // --- insert_many_transactionally ---
2899
2900    #[test]
2901    fn test_insert_many_transactionally_method_generated() {
2902        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2903        assert!(
2904            code.contains("pub async fn insert_many_transactionally"),
2905            "Expected insert_many_transactionally method:\n{}",
2906            code
2907        );
2908    }
2909
2910    #[test]
2911    fn test_insert_many_transactionally_signature() {
2912        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2913        assert!(
2914            code.contains("entries: Vec<InsertUsersParams>"),
2915            "Expected Vec<InsertUsersParams> param:\n{}",
2916            code
2917        );
2918        assert!(
2919            code.contains("Result<Vec<Users>"),
2920            "Expected Result<Vec<Users>> return type:\n{}",
2921            code
2922        );
2923    }
2924
2925    #[test]
2926    fn test_insert_many_transactionally_no_strategy_enum() {
2927        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2928        assert!(
2929            !code.contains("TransactionStrategy"),
2930            "TransactionStrategy should not be generated:\n{}",
2931            code
2932        );
2933        assert!(
2934            !code.contains("InsertManyUsersResult"),
2935            "InsertManyUsersResult should not be generated:\n{}",
2936            code
2937        );
2938    }
2939
2940    #[test]
2941    fn test_insert_many_transactionally_uses_transaction_pg() {
2942        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2943        let method_start = code
2944            .find("fn insert_many_transactionally")
2945            .expect("insert_many_transactionally not found");
2946        let method_body = &code[method_start..];
2947        assert!(
2948            method_body.contains("self.pool.begin()"),
2949            "Expected begin():\n{}",
2950            method_body
2951        );
2952        assert!(
2953            method_body.contains("tx.commit()"),
2954            "Expected commit():\n{}",
2955            method_body
2956        );
2957    }
2958
2959    #[test]
2960    fn test_insert_many_transactionally_multi_row_pg() {
2961        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2962        let method_start = code
2963            .find("fn insert_many_transactionally")
2964            .expect("not found");
2965        let method_body = &code[method_start..];
2966        assert!(
2967            method_body.contains("RETURNING *"),
2968            "Expected RETURNING * in multi-row SQL:\n{}",
2969            method_body
2970        );
2971        assert!(
2972            method_body.contains("values_parts"),
2973            "Expected multi-row VALUES building:\n{}",
2974            method_body
2975        );
2976        assert!(
2977            method_body.contains("65535"),
2978            "Expected chunk size limit:\n{}",
2979            method_body
2980        );
2981    }
2982
2983    #[test]
2984    fn test_insert_many_transactionally_multi_row_sqlite() {
2985        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
2986        let method_start = code
2987            .find("fn insert_many_transactionally")
2988            .expect("not found");
2989        let method_body = &code[method_start..];
2990        assert!(
2991            method_body.contains("values_parts"),
2992            "Expected multi-row VALUES building for SQLite:\n{}",
2993            method_body
2994        );
2995        assert!(
2996            method_body.contains("RETURNING *"),
2997            "Expected RETURNING * for SQLite:\n{}",
2998            method_body
2999        );
3000    }
3001
3002    #[test]
3003    fn test_insert_many_transactionally_mysql_individual_inserts() {
3004        let code = gen(&standard_entity(), DatabaseKind::Mysql);
3005        let method_start = code
3006            .find("fn insert_many_transactionally")
3007            .expect("not found");
3008        let method_body = &code[method_start..];
3009        assert!(
3010            method_body.contains("LAST_INSERT_ID"),
3011            "Expected LAST_INSERT_ID for MySQL:\n{}",
3012            method_body
3013        );
3014        assert!(
3015            method_body.contains("self.pool.begin()"),
3016            "Expected begin() for MySQL:\n{}",
3017            method_body
3018        );
3019    }
3020
3021    #[test]
3022    fn test_insert_many_transactionally_view_not_generated() {
3023        let mut entity = standard_entity();
3024        entity.is_view = true;
3025        let code = gen(&entity, DatabaseKind::Postgres);
3026        assert!(
3027            !code.contains("pub async fn insert_many_transactionally"),
3028            "should not be generated for views"
3029        );
3030    }
3031
3032    #[test]
3033    fn test_insert_many_transactionally_without_method_not_generated() {
3034        let m = Methods {
3035            insert_many: false,
3036            ..Methods::all()
3037        };
3038        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
3039        assert!(
3040            !code.contains("pub async fn insert_many_transactionally"),
3041            "should not be generated when disabled"
3042        );
3043    }
3044
3045    #[test]
3046    fn test_insert_many_transactionally_generates_params_when_insert_disabled() {
3047        let m = Methods {
3048            insert: false,
3049            insert_many: true,
3050            ..Default::default()
3051        };
3052        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
3053        assert!(
3054            code.contains("pub struct InsertUsersParams"),
3055            "Expected InsertUsersParams:\n{}",
3056            code
3057        );
3058        assert!(
3059            code.contains("pub async fn insert_many_transactionally"),
3060            "Expected method:\n{}",
3061            code
3062        );
3063        assert!(
3064            !code.contains("pub async fn insert("),
3065            "insert should not be present:\n{}",
3066            code
3067        );
3068    }
3069
3070    #[test]
3071    fn test_insert_many_transactionally_with_column_defaults_coalesce() {
3072        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
3073        let method_start = code
3074            .find("fn insert_many_transactionally")
3075            .expect("not found");
3076        let method_body = &code[method_start..];
3077        assert!(
3078            method_body.contains("COALESCE"),
3079            "Expected COALESCE for fields with defaults:\n{}",
3080            method_body
3081        );
3082    }
3083
3084    #[test]
3085    fn test_insert_many_transactionally_junction_table() {
3086        let code = gen(&junction_entity(), DatabaseKind::Postgres);
3087        assert!(
3088            code.contains("pub async fn insert_many_transactionally"),
3089            "Expected method for junction table:\n{}",
3090            code
3091        );
3092    }
3093
3094    #[test]
3095    fn test_insert_many_transactionally_all_three_backends_compile() {
3096        for db in [
3097            DatabaseKind::Postgres,
3098            DatabaseKind::Mysql,
3099            DatabaseKind::Sqlite,
3100        ] {
3101            let code = gen(&standard_entity(), db);
3102            assert!(
3103                code.contains("pub async fn insert_many_transactionally"),
3104                "Expected method for {:?}",
3105                db
3106            );
3107        }
3108    }
3109}