Skip to main content

sqlx_gen/codegen/crud_gen.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10    entity: &ParsedEntity,
11    db_kind: DatabaseKind,
12    entity_module_path: &str,
13    methods: &Methods,
14    query_macro: bool,
15    pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17    let mut imports = BTreeSet::new();
18
19    let entity_ident = format_ident!("{}", entity.struct_name);
20    let repo_name = format!("{}Repository", entity.struct_name);
21    let repo_ident = format_ident!("{}", repo_name);
22
23    let table_name = match &entity.schema_name {
24        Some(schema) => format!("{}.{}", schema, entity.table_name),
25        None => entity.table_name.clone(),
26    };
27
28    // Pool type (used via full path sqlx::PgPool etc., no import needed)
29    let pool_type = pool_type_tokens(db_kind);
30
31    // When the entity has custom SQL types (enums, composites, arrays),
32    // query_as! macro can't resolve the column type at compile time. Fall back to runtime query_as::<_, T>()
33    // for queries that return rows. DELETE (no rows returned) can still use macro.
34    let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35    let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37    // Entity import
38    imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40    // Forward type imports from the entity file (chrono, uuid, etc.)
41    // Rewrite `use super::X` imports to absolute paths based on entity_module_path,
42    // since the repository lives in a different module where `super` has a different meaning.
43    let entity_parent = entity_module_path
44        .rsplit_once("::")
45        .map(|(parent, _)| parent)
46        .unwrap_or(entity_module_path);
47    for imp in &entity.imports {
48        if let Some(rest) = imp.strip_prefix("use super::") {
49            imports.insert(format!("use {}::{}", entity_parent, rest));
50        } else {
51            imports.insert(imp.clone());
52        }
53    }
54
55    // Primary key fields
56    let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58    // Non-PK fields (for insert)
59    let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61    let is_view = entity.is_view;
62
63    // Build method tokens
64    let mut method_tokens = Vec::new();
65    let mut param_structs = Vec::new();
66
67    // --- get_all ---
68    if methods.get_all {
69        let sql = raw_sql_lit(&format!("SELECT * FROM {}", table_name));
70        let method = if use_macro {
71            quote! {
72                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73                    sqlx::query_as!(#entity_ident, #sql)
74                        .fetch_all(&self.pool)
75                        .await
76                }
77            }
78        } else {
79            quote! {
80                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81                    sqlx::query_as::<_, #entity_ident>(#sql)
82                        .fetch_all(&self.pool)
83                        .await
84                }
85            }
86        };
87        method_tokens.push(method);
88    }
89
90    // --- paginate ---
91    if methods.paginate {
92        let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93        let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94        let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95        let count_sql = raw_sql_lit(&format!("SELECT COUNT(*) FROM {}", table_name));
96        let sql = raw_sql_lit(&match db_kind {
97            DatabaseKind::Postgres => format!("SELECT *\nFROM {}\nLIMIT $1 OFFSET $2", table_name),
98            DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT *\nFROM {}\nLIMIT ? OFFSET ?", table_name),
99        });
100        let method = if use_macro {
101            quote! {
102                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103                    let total: i64 = sqlx::query_scalar!(#count_sql)
104                        .fetch_one(&self.pool)
105                        .await?
106                        .unwrap_or(0);
107                    let per_page = params.per_page;
108                    let current_page = params.page;
109                    let last_page = (total + per_page - 1) / per_page;
110                    let offset = (current_page - 1) * per_page;
111                    let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112                        .fetch_all(&self.pool)
113                        .await?;
114                    Ok(#paginated_ident {
115                        meta: #pagination_meta_ident {
116                            total,
117                            per_page,
118                            current_page,
119                            last_page,
120                            first_page: 1,
121                        },
122                        data,
123                    })
124                }
125            }
126        } else {
127            quote! {
128                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129                    let total: i64 = sqlx::query_scalar(#count_sql)
130                        .fetch_one(&self.pool)
131                        .await?;
132                    let per_page = params.per_page;
133                    let current_page = params.page;
134                    let last_page = (total + per_page - 1) / per_page;
135                    let offset = (current_page - 1) * per_page;
136                    let data = sqlx::query_as::<_, #entity_ident>(#sql)
137                        .bind(per_page)
138                        .bind(offset)
139                        .fetch_all(&self.pool)
140                        .await?;
141                    Ok(#paginated_ident {
142                        meta: #pagination_meta_ident {
143                            total,
144                            per_page,
145                            current_page,
146                            last_page,
147                            first_page: 1,
148                        },
149                        data,
150                    })
151                }
152            }
153        };
154        method_tokens.push(method);
155        param_structs.push(quote! {
156            #[derive(Debug, Clone, Default)]
157            pub struct #paginate_params_ident {
158                pub page: i64,
159                pub per_page: i64,
160            }
161        });
162        param_structs.push(quote! {
163            #[derive(Debug, Clone)]
164            pub struct #pagination_meta_ident {
165                pub total: i64,
166                pub per_page: i64,
167                pub current_page: i64,
168                pub last_page: i64,
169                pub first_page: i64,
170            }
171        });
172        param_structs.push(quote! {
173            #[derive(Debug, Clone)]
174            pub struct #paginated_ident {
175                pub meta: #pagination_meta_ident,
176                pub data: Vec<#entity_ident>,
177            }
178        });
179    }
180
181    // --- get (by PK) ---
182    if methods.get && !pk_fields.is_empty() {
183        let pk_params: Vec<TokenStream> = pk_fields
184            .iter()
185            .map(|f| {
186                let name = format_ident!("{}", f.rust_name);
187                let ty: TokenStream = f.inner_type.parse().unwrap();
188                quote! { #name: #ty }
189            })
190            .collect();
191
192        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194        let sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, where_clause));
195        let sql_macro = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, where_clause_cast));
196
197        let binds: Vec<TokenStream> = pk_fields
198            .iter()
199            .map(|f| {
200                let name = format_ident!("{}", f.rust_name);
201                quote! { .bind(#name) }
202            })
203            .collect();
204
205        let method = if use_macro {
206            let pk_arg_names: Vec<TokenStream> = pk_fields
207                .iter()
208                .map(|f| {
209                    let name = format_ident!("{}", f.rust_name);
210                    quote! { #name }
211                })
212                .collect();
213            quote! {
214                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215                    sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216                        .fetch_optional(&self.pool)
217                        .await
218                }
219            }
220        } else {
221            quote! {
222                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223                    sqlx::query_as::<_, #entity_ident>(#sql)
224                        #(#binds)*
225                        .fetch_optional(&self.pool)
226                        .await
227                }
228            }
229        };
230        method_tokens.push(method);
231    }
232
233    // --- insert (skip for views) ---
234    if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
235        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237        // When all columns are PKs (e.g. junction tables), use pk_fields for insert
238        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
239            pk_fields.clone()
240        } else {
241            non_pk_fields.clone()
242        };
243
244        // Fields with column_default and not already nullable → Option<T>
245        let insert_fields: Vec<TokenStream> = insert_source_fields
246            .iter()
247            .map(|f| {
248                let name = format_ident!("{}", f.rust_name);
249                if f.column_default.is_some() && !f.is_nullable {
250                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
251                    quote! { pub #name: #ty, }
252                } else {
253                    let ty: TokenStream = f.rust_type.parse().unwrap();
254                    quote! { pub #name: #ty, }
255                }
256            })
257            .collect();
258
259        let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
260        let col_list = col_names.join(", ");
261
262        // Build placeholders with COALESCE for fields that have a column_default
263        let placeholders: String = insert_source_fields
264            .iter()
265            .enumerate()
266            .map(|(i, f)| {
267                let p = placeholder(db_kind, i + 1);
268                match &f.column_default {
269                    Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
270                    None => p,
271                }
272            })
273            .collect::<Vec<_>>()
274            .join(", ");
275
276        let placeholders_cast: String = insert_source_fields
277            .iter()
278            .enumerate()
279            .map(|(i, f)| {
280                let p = placeholder_with_cast(db_kind, i + 1, f);
281                match &f.column_default {
282                    Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
283                    None => p,
284                }
285            })
286            .collect::<Vec<_>>()
287            .join(", ");
288
289        let build_insert_sql = |ph: &str| match db_kind {
290            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
291                format!(
292                    "INSERT INTO {} ({})\nVALUES ({})\nRETURNING *",
293                    table_name, col_list, ph
294                )
295            }
296            DatabaseKind::Mysql => {
297                format!(
298                    "INSERT INTO {} ({})\nVALUES ({})",
299                    table_name, col_list, ph
300                )
301            }
302        };
303        let sql = build_insert_sql(&placeholders);
304        let sql_macro = build_insert_sql(&placeholders_cast);
305
306        let binds: Vec<TokenStream> = insert_source_fields
307            .iter()
308            .map(|f| {
309                let name = format_ident!("{}", f.rust_name);
310                quote! { .bind(&params.#name) }
311            })
312            .collect();
313
314        let insert_method = build_insert_method_parsed(
315            &entity_ident,
316            &insert_params_ident,
317            &sql,
318            &sql_macro,
319            &binds,
320            db_kind,
321            &table_name,
322            &pk_fields,
323            &insert_source_fields,
324            use_macro,
325        );
326        method_tokens.push(insert_method);
327
328        param_structs.push(quote! {
329            #[derive(Debug, Clone, Default)]
330            pub struct #insert_params_ident {
331                #(#insert_fields)*
332            }
333        });
334    }
335
336    // --- insert_many_transactionally (skip for views) ---
337    if !is_view && methods.insert_many && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
338        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
339
340        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
341            pk_fields.clone()
342        } else {
343            non_pk_fields.clone()
344        };
345
346        let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
347        let col_list = col_names.join(", ");
348        let num_cols = insert_source_fields.len();
349
350        let binds_loop: Vec<TokenStream> = insert_source_fields
351            .iter()
352            .map(|f| {
353                let name = format_ident!("{}", f.rust_name);
354                quote! { query = query.bind(&params.#name); }
355            })
356            .collect();
357
358        let insert_many_method = build_insert_many_transactionally_method(
359            &entity_ident,
360            &insert_params_ident,
361            &col_list,
362            num_cols,
363            &insert_source_fields,
364            &binds_loop,
365            db_kind,
366            &table_name,
367            &pk_fields,
368        );
369        method_tokens.push(insert_many_method);
370
371        // Only generate InsertParams if we haven't generated it from the insert method
372        if !methods.insert {
373            let insert_fields: Vec<TokenStream> = insert_source_fields
374                .iter()
375                .map(|f| {
376                    let name = format_ident!("{}", f.rust_name);
377                    if f.column_default.is_some() && !f.is_nullable {
378                        let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
379                        quote! { pub #name: #ty, }
380                    } else {
381                        let ty: TokenStream = f.rust_type.parse().unwrap();
382                        quote! { pub #name: #ty, }
383                    }
384                })
385                .collect();
386
387            param_structs.push(quote! {
388                #[derive(Debug, Clone, Default)]
389                pub struct #insert_params_ident {
390                    #(#insert_fields)*
391                }
392            });
393        }
394    }
395
396    // --- overwrite (full replacement — skip for views, skip when all columns are PKs) ---
397    if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
398        let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
399
400        // PK as function parameters (like get/delete)
401        let pk_fn_params: Vec<TokenStream> = pk_fields
402            .iter()
403            .map(|f| {
404                let name = format_ident!("{}", f.rust_name);
405                let ty: TokenStream = f.inner_type.parse().unwrap();
406                quote! { #name: #ty }
407            })
408            .collect();
409
410        // Non-PK fields keep original types (required)
411        let overwrite_fields: Vec<TokenStream> = non_pk_fields
412            .iter()
413            .map(|f| {
414                let name = format_ident!("{}", f.rust_name);
415                let ty: TokenStream = f.rust_type.parse().unwrap();
416                quote! { pub #name: #ty, }
417            })
418            .collect();
419
420        let set_cols: Vec<String> = non_pk_fields
421            .iter()
422            .enumerate()
423            .map(|(i, f)| {
424                let p = placeholder(db_kind, i + 1);
425                format!("{} = {}", f.column_name, p)
426            })
427            .collect();
428        let set_clause = set_cols.join(",\n  ");
429
430        let set_cols_cast: Vec<String> = non_pk_fields
431            .iter()
432            .enumerate()
433            .map(|(i, f)| {
434                let p = placeholder_with_cast(db_kind, i + 1, f);
435                format!("{} = {}", f.column_name, p)
436            })
437            .collect();
438        let set_clause_cast = set_cols_cast.join(",\n  ");
439
440        let pk_start = non_pk_fields.len() + 1;
441        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
442        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
443
444        let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
445            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
446                format!("UPDATE {}\nSET\n  {}\nWHERE {}\nRETURNING *", table_name, sc, wc)
447            }
448            DatabaseKind::Mysql => {
449                format!("UPDATE {}\nSET\n  {}\nWHERE {}", table_name, sc, wc)
450            }
451        };
452        let sql = raw_sql_lit(&build_overwrite_sql(&set_clause, &where_clause));
453        let sql_macro = raw_sql_lit(&build_overwrite_sql(&set_clause_cast, &where_clause_cast));
454
455        // Bind non-PK first (from params), then PK (from function args)
456        let mut all_binds: Vec<TokenStream> = non_pk_fields
457            .iter()
458            .map(|f| {
459                let name = format_ident!("{}", f.rust_name);
460                quote! { .bind(&params.#name) }
461            })
462            .collect();
463        for f in &pk_fields {
464            let name = format_ident!("{}", f.rust_name);
465            all_binds.push(quote! { .bind(#name) });
466        }
467
468        // Macro args: non-PK from params, then PK from function args
469        let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
470            .iter()
471            .map(|f| macro_arg_for_field(f))
472            .chain(pk_fields.iter().map(|f| {
473                let name = format_ident!("{}", f.rust_name);
474                quote! { #name }
475            }))
476            .collect();
477
478        let overwrite_method = if use_macro {
479            match db_kind {
480                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
481                    quote! {
482                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
483                            sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
484                                .fetch_one(&self.pool)
485                                .await
486                        }
487                    }
488                }
489                DatabaseKind::Mysql => {
490                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
491                    let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where_select));
492                    let pk_macro_args: Vec<TokenStream> = pk_fields
493                        .iter()
494                        .map(|f| {
495                            let name = format_ident!("{}", f.rust_name);
496                            quote! { #name }
497                        })
498                        .collect();
499                    quote! {
500                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
501                            sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
502                                .execute(&self.pool)
503                                .await?;
504                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
505                                .fetch_one(&self.pool)
506                                .await
507                        }
508                    }
509                }
510            }
511        } else {
512            match db_kind {
513                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
514                    quote! {
515                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
516                            sqlx::query_as::<_, #entity_ident>(#sql)
517                                #(#all_binds)*
518                                .fetch_one(&self.pool)
519                                .await
520                        }
521                    }
522                }
523                DatabaseKind::Mysql => {
524                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
525                    let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where_select));
526                    let pk_binds: Vec<TokenStream> = pk_fields
527                        .iter()
528                        .map(|f| {
529                            let name = format_ident!("{}", f.rust_name);
530                            quote! { .bind(#name) }
531                        })
532                        .collect();
533                    quote! {
534                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
535                            sqlx::query(#sql)
536                                #(#all_binds)*
537                                .execute(&self.pool)
538                                .await?;
539                            sqlx::query_as::<_, #entity_ident>(#select_sql)
540                                #(#pk_binds)*
541                                .fetch_one(&self.pool)
542                                .await
543                        }
544                    }
545                }
546            }
547        };
548        method_tokens.push(overwrite_method);
549
550        param_structs.push(quote! {
551            #[derive(Debug, Clone, Default)]
552            pub struct #overwrite_params_ident {
553                #(#overwrite_fields)*
554            }
555        });
556    }
557
558    // --- update / patch (COALESCE — skip for views, skip when all columns are PKs) ---
559    if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
560        let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
561
562        // PK as function parameters (like get/delete)
563        let pk_fn_params: Vec<TokenStream> = pk_fields
564            .iter()
565            .map(|f| {
566                let name = format_ident!("{}", f.rust_name);
567                let ty: TokenStream = f.inner_type.parse().unwrap();
568                quote! { #name: #ty }
569            })
570            .collect();
571
572        // Non-PK fields become Option<T> (no double Option for already nullable)
573        let update_fields: Vec<TokenStream> = non_pk_fields
574            .iter()
575            .map(|f| {
576                let name = format_ident!("{}", f.rust_name);
577                if f.is_nullable {
578                    // Already Option<T> — keep as-is to avoid Option<Option<T>>
579                    let ty: TokenStream = f.rust_type.parse().unwrap();
580                    quote! { pub #name: #ty, }
581                } else {
582                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
583                    quote! { pub #name: #ty, }
584                }
585            })
586            .collect();
587
588        // SET clause with COALESCE for runtime mode
589        let set_cols: Vec<String> = non_pk_fields
590            .iter()
591            .enumerate()
592            .map(|(i, f)| {
593                let p = placeholder(db_kind, i + 1);
594                format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
595            })
596            .collect();
597        let set_clause = set_cols.join(",\n  ");
598
599        // SET clause with COALESCE and casts for macro mode
600        let set_cols_cast: Vec<String> = non_pk_fields
601            .iter()
602            .enumerate()
603            .map(|(i, f)| {
604                let p = placeholder_with_cast(db_kind, i + 1, f);
605                format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
606            })
607            .collect();
608        let set_clause_cast = set_cols_cast.join(",\n  ");
609
610        let pk_start = non_pk_fields.len() + 1;
611        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
612        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
613
614        let build_update_sql = |sc: &str, wc: &str| match db_kind {
615            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
616                format!(
617                    "UPDATE {}\nSET\n  {}\nWHERE {}\nRETURNING *",
618                    table_name, sc, wc
619                )
620            }
621            DatabaseKind::Mysql => {
622                format!(
623                    "UPDATE {}\nSET\n  {}\nWHERE {}",
624                    table_name, sc, wc
625                )
626            }
627        };
628        let sql = raw_sql_lit(&build_update_sql(&set_clause, &where_clause));
629        let sql_macro = raw_sql_lit(&build_update_sql(&set_clause_cast, &where_clause_cast));
630
631        // Bind non-PK first (from params), then PK (from function args)
632        let mut all_binds: Vec<TokenStream> = non_pk_fields
633            .iter()
634            .map(|f| {
635                let name = format_ident!("{}", f.rust_name);
636                quote! { .bind(&params.#name) }
637            })
638            .collect();
639        for f in &pk_fields {
640            let name = format_ident!("{}", f.rust_name);
641            all_binds.push(quote! { .bind(#name) });
642        }
643
644        // Macro args: non-PK from params, then PK from function args
645        let update_macro_args: Vec<TokenStream> = non_pk_fields
646            .iter()
647            .map(|f| macro_arg_for_field(f))
648            .chain(pk_fields.iter().map(|f| {
649                let name = format_ident!("{}", f.rust_name);
650                quote! { #name }
651            }))
652            .collect();
653
654        let update_method = if use_macro {
655            match db_kind {
656                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
657                    quote! {
658                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
659                            sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
660                                .fetch_one(&self.pool)
661                                .await
662                        }
663                    }
664                }
665                DatabaseKind::Mysql => {
666                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
667                    let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where_select));
668                    let pk_macro_args: Vec<TokenStream> = pk_fields
669                        .iter()
670                        .map(|f| {
671                            let name = format_ident!("{}", f.rust_name);
672                            quote! { #name }
673                        })
674                        .collect();
675                    quote! {
676                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
677                            sqlx::query!(#sql_macro, #(#update_macro_args),*)
678                                .execute(&self.pool)
679                                .await?;
680                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
681                                .fetch_one(&self.pool)
682                                .await
683                        }
684                    }
685                }
686            }
687        } else {
688            match db_kind {
689                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
690                    quote! {
691                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
692                            sqlx::query_as::<_, #entity_ident>(#sql)
693                                #(#all_binds)*
694                                .fetch_one(&self.pool)
695                                .await
696                        }
697                    }
698                }
699                DatabaseKind::Mysql => {
700                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
701                    let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where_select));
702                    let pk_binds: Vec<TokenStream> = pk_fields
703                        .iter()
704                        .map(|f| {
705                            let name = format_ident!("{}", f.rust_name);
706                            quote! { .bind(#name) }
707                        })
708                        .collect();
709                    quote! {
710                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
711                            sqlx::query(#sql)
712                                #(#all_binds)*
713                                .execute(&self.pool)
714                                .await?;
715                            sqlx::query_as::<_, #entity_ident>(#select_sql)
716                                #(#pk_binds)*
717                                .fetch_one(&self.pool)
718                                .await
719                        }
720                    }
721                }
722            }
723        };
724        method_tokens.push(update_method);
725
726        param_structs.push(quote! {
727            #[derive(Debug, Clone, Default)]
728            pub struct #update_params_ident {
729                #(#update_fields)*
730            }
731        });
732    }
733
734    // --- delete (skip for views) ---
735    if !is_view && methods.delete && !pk_fields.is_empty() {
736        let pk_params: Vec<TokenStream> = pk_fields
737            .iter()
738            .map(|f| {
739                let name = format_ident!("{}", f.rust_name);
740                let ty: TokenStream = f.inner_type.parse().unwrap();
741                quote! { #name: #ty }
742            })
743            .collect();
744
745        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
746        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
747        let sql = raw_sql_lit(&format!("DELETE FROM {}\nWHERE {}", table_name, where_clause));
748        let sql_macro = raw_sql_lit(&format!("DELETE FROM {}\nWHERE {}", table_name, where_clause_cast));
749
750        let binds: Vec<TokenStream> = pk_fields
751            .iter()
752            .map(|f| {
753                let name = format_ident!("{}", f.rust_name);
754                quote! { .bind(#name) }
755            })
756            .collect();
757
758        let method = if query_macro {
759            let pk_arg_names: Vec<TokenStream> = pk_fields
760                .iter()
761                .map(|f| {
762                    let name = format_ident!("{}", f.rust_name);
763                    quote! { #name }
764                })
765                .collect();
766            quote! {
767                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
768                    sqlx::query!(#sql_macro, #(#pk_arg_names),*)
769                        .execute(&self.pool)
770                        .await?;
771                    Ok(())
772                }
773            }
774        } else {
775            quote! {
776                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
777                    sqlx::query(#sql)
778                        #(#binds)*
779                        .execute(&self.pool)
780                        .await?;
781                    Ok(())
782                }
783            }
784        };
785        method_tokens.push(method);
786    }
787
788    let pool_vis: TokenStream = match pool_visibility {
789        PoolVisibility::Private => quote! {},
790        PoolVisibility::Pub => quote! { pub },
791        PoolVisibility::PubCrate => quote! { pub(crate) },
792    };
793
794    let tokens = quote! {
795        #(#param_structs)*
796
797        pub struct #repo_ident {
798            #pool_vis pool: #pool_type,
799        }
800
801        impl #repo_ident {
802            pub fn new(pool: #pool_type) -> Self {
803                Self { pool }
804            }
805
806            #(#method_tokens)*
807        }
808    };
809
810    (tokens, imports)
811}
812
813fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
814    match db_kind {
815        DatabaseKind::Postgres => quote! { sqlx::PgPool },
816        DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
817        DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
818    }
819}
820
821/// Wraps a SQL string as a raw string literal `r#"..."#` in the generated code.
822/// Multi-line SQL gets a leading newline so each clause starts on its own line.
823fn raw_sql_lit(s: &str) -> TokenStream {
824    if s.contains('\n') {
825        format!("r#\"\n{}\n\"#", s).parse().unwrap()
826    } else {
827        format!("r#\"{}\"#", s).parse().unwrap()
828    }
829}
830
831fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
832    match db_kind {
833        DatabaseKind::Postgres => format!("${}", index),
834        DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
835    }
836}
837
838fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
839    let base = placeholder(db_kind, index);
840    match (&field.sql_type, field.is_sql_array) {
841        (Some(t), true) => format!("{} as {}[]", base, t),
842        (Some(t), false) => format!("{} as {}", base, t),
843        (None, _) => base,
844    }
845}
846
847fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
848    (0..count)
849        .map(|i| placeholder(db_kind, start + i))
850        .collect::<Vec<_>>()
851        .join(", ")
852}
853
854fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
855    fields
856        .iter()
857        .enumerate()
858        .map(|(i, f)| {
859            if use_cast {
860                placeholder_with_cast(db_kind, start + i, f)
861            } else {
862                placeholder(db_kind, start + i)
863            }
864        })
865        .collect::<Vec<_>>()
866        .join(", ")
867}
868
869fn build_where_clause_parsed(
870    pk_fields: &[&ParsedField],
871    db_kind: DatabaseKind,
872    start_index: usize,
873) -> String {
874    pk_fields
875        .iter()
876        .enumerate()
877        .map(|(i, f)| {
878            let p = placeholder(db_kind, start_index + i);
879            format!("{} = {}", f.column_name, p)
880        })
881        .collect::<Vec<_>>()
882        .join(" AND ")
883}
884
885fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
886    let name = format_ident!("{}", field.rust_name);
887    let check_type = if field.is_nullable {
888        &field.inner_type
889    } else {
890        &field.rust_type
891    };
892    let normalized = check_type.replace(' ', "");
893    if normalized.starts_with("Vec<") {
894        quote! { params.#name.as_slice() }
895    } else {
896        quote! { params.#name }
897    }
898}
899
900fn build_where_clause_cast(
901    pk_fields: &[&ParsedField],
902    db_kind: DatabaseKind,
903    start_index: usize,
904) -> String {
905    pk_fields
906        .iter()
907        .enumerate()
908        .map(|(i, f)| {
909            let p = placeholder_with_cast(db_kind, start_index + i, f);
910            format!("{} = {}", f.column_name, p)
911        })
912        .collect::<Vec<_>>()
913        .join(" AND ")
914}
915
916#[allow(clippy::too_many_arguments)]
917fn build_insert_method_parsed(
918    entity_ident: &proc_macro2::Ident,
919    insert_params_ident: &proc_macro2::Ident,
920    sql: &str,
921    sql_macro: &str,
922    binds: &[TokenStream],
923    db_kind: DatabaseKind,
924    table_name: &str,
925    pk_fields: &[&ParsedField],
926    non_pk_fields: &[&ParsedField],
927    use_macro: bool,
928) -> TokenStream {
929    let sql = raw_sql_lit(sql);
930    let sql_macro = raw_sql_lit(sql_macro);
931
932    if use_macro {
933        let macro_args: Vec<TokenStream> = non_pk_fields
934            .iter()
935            .map(|f| macro_arg_for_field(f))
936            .collect();
937
938        match db_kind {
939            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
940                quote! {
941                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
942                        sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
943                            .fetch_one(&self.pool)
944                            .await
945                    }
946                }
947            }
948            DatabaseKind::Mysql => {
949                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
950                let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where));
951                let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID() as id");
952                quote! {
953                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
954                        sqlx::query!(#sql_macro, #(#macro_args),*)
955                            .execute(&self.pool)
956                            .await?;
957                        let id = sqlx::query_scalar!(#last_insert_id_sql)
958                            .fetch_one(&self.pool)
959                            .await?;
960                        sqlx::query_as!(#entity_ident, #select_sql, id)
961                            .fetch_one(&self.pool)
962                            .await
963                    }
964                }
965            }
966        }
967    } else {
968        match db_kind {
969            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
970                quote! {
971                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
972                        sqlx::query_as::<_, #entity_ident>(#sql)
973                            #(#binds)*
974                            .fetch_one(&self.pool)
975                            .await
976                    }
977                }
978            }
979            DatabaseKind::Mysql => {
980                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
981                let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where));
982                let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
983                quote! {
984                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
985                        sqlx::query(#sql)
986                            #(#binds)*
987                            .execute(&self.pool)
988                            .await?;
989                        let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
990                            .fetch_one(&self.pool)
991                            .await?;
992                        sqlx::query_as::<_, #entity_ident>(#select_sql)
993                            .bind(id)
994                            .fetch_one(&self.pool)
995                            .await
996                    }
997                }
998            }
999        }
1000    }
1001}
1002
1003#[allow(clippy::too_many_arguments)]
1004fn build_insert_many_transactionally_method(
1005    entity_ident: &proc_macro2::Ident,
1006    insert_params_ident: &proc_macro2::Ident,
1007    col_list: &str,
1008    num_cols: usize,
1009    insert_source_fields: &[&ParsedField],
1010    binds_loop: &[TokenStream],
1011    db_kind: DatabaseKind,
1012    table_name: &str,
1013    pk_fields: &[&ParsedField],
1014) -> TokenStream {
1015    let body = match db_kind {
1016        DatabaseKind::Postgres | DatabaseKind::Sqlite => {
1017            let col_list_str = col_list.to_string();
1018            let table_name_str = table_name.to_string();
1019
1020            let row_placeholder_exprs: Vec<TokenStream> = insert_source_fields
1021                .iter()
1022                .enumerate()
1023                .map(|(i, f)| {
1024                    let offset = i;
1025                    match &f.column_default {
1026                        Some(default_expr) => {
1027                            let def = default_expr.as_str();
1028                            match db_kind {
1029                                DatabaseKind::Postgres => quote! {
1030                                    format!("COALESCE(${}, {})", base + #offset + 1, #def)
1031                                },
1032                                _ => quote! {
1033                                    format!("COALESCE(?, {})", #def)
1034                                },
1035                            }
1036                        }
1037                        None => {
1038                            match db_kind {
1039                                DatabaseKind::Postgres => quote! {
1040                                    format!("${}", base + #offset + 1)
1041                                },
1042                                _ => quote! {
1043                                    "?".to_string()
1044                                },
1045                            }
1046                        }
1047                    }
1048                })
1049                .collect();
1050
1051            quote! {
1052                let mut tx = self.pool.begin().await?;
1053                let mut all_results = Vec::with_capacity(entries.len());
1054                let max_per_chunk = 65535 / #num_cols;
1055                for chunk in entries.chunks(max_per_chunk) {
1056                    let mut values_parts = Vec::with_capacity(chunk.len());
1057                    for (row_idx, _) in chunk.iter().enumerate() {
1058                        let base = row_idx * #num_cols;
1059                        let placeholders = vec![#(#row_placeholder_exprs),*];
1060                        values_parts.push(format!("({})", placeholders.join(", ")));
1061                    }
1062                    let sql = format!(
1063                        "INSERT INTO {} ({})\nVALUES {}\nRETURNING *",
1064                        #table_name_str,
1065                        #col_list_str,
1066                        values_parts.join(", ")
1067                    );
1068                    let mut query = sqlx::query_as::<_, #entity_ident>(&sql);
1069                    for params in chunk {
1070                        #(#binds_loop)*
1071                    }
1072                    let rows = query.fetch_all(&mut *tx).await?;
1073                    all_results.extend(rows);
1074                }
1075                tx.commit().await?;
1076                Ok(all_results)
1077            }
1078        }
1079        DatabaseKind::Mysql => {
1080            let single_placeholders: String = insert_source_fields
1081                .iter()
1082                .enumerate()
1083                .map(|(i, f)| {
1084                    let p = placeholder(db_kind, i + 1);
1085                    match &f.column_default {
1086                        Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
1087                        None => p,
1088                    }
1089                })
1090                .collect::<Vec<_>>()
1091                .join(", ");
1092
1093            let single_insert_sql = raw_sql_lit(&format!(
1094                "INSERT INTO {} ({})\nVALUES ({})",
1095                table_name, col_list, single_placeholders
1096            ));
1097
1098            let single_binds: Vec<TokenStream> = insert_source_fields
1099                .iter()
1100                .map(|f| {
1101                    let name = format_ident!("{}", f.rust_name);
1102                    quote! { .bind(&params.#name) }
1103                })
1104                .collect();
1105
1106            let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
1107            let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where));
1108            let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1109
1110            quote! {
1111                let mut tx = self.pool.begin().await?;
1112                let mut results = Vec::with_capacity(entries.len());
1113                for params in &entries {
1114                    sqlx::query(#single_insert_sql)
1115                        #(#single_binds)*
1116                        .execute(&mut *tx)
1117                        .await?;
1118                    let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1119                        .fetch_one(&mut *tx)
1120                        .await?;
1121                    let row = sqlx::query_as::<_, #entity_ident>(#select_sql)
1122                        .bind(id)
1123                        .fetch_one(&mut *tx)
1124                        .await?;
1125                    results.push(row);
1126                }
1127                tx.commit().await?;
1128                Ok(results)
1129            }
1130        }
1131    };
1132
1133    quote! {
1134        pub async fn insert_many_transactionally(
1135            &self,
1136            entries: Vec<#insert_params_ident>,
1137        ) -> Result<Vec<#entity_ident>, sqlx::Error> {
1138            #body
1139        }
1140    }
1141}
1142
1143#[cfg(test)]
1144mod tests {
1145    use super::*;
1146    use crate::codegen::parse_and_format_with_tab_spaces;
1147    use crate::codegen::parse_and_format;
1148    use crate::cli::Methods;
1149
1150    fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
1151        let inner_type = if nullable {
1152            // Strip "Option<" prefix and ">" suffix
1153            rust_type
1154                .strip_prefix("Option<")
1155                .and_then(|s| s.strip_suffix('>'))
1156                .unwrap_or(rust_type)
1157                .to_string()
1158        } else {
1159            rust_type.to_string()
1160        };
1161        ParsedField {
1162            rust_name: rust_name.to_string(),
1163            column_name: column_name.to_string(),
1164            rust_type: rust_type.to_string(),
1165            is_nullable: nullable,
1166            inner_type,
1167            is_primary_key: is_pk,
1168            sql_type: None,
1169            is_sql_array: false,
1170            column_default: None,
1171        }
1172    }
1173
1174    fn make_field_with_default(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool, default: &str) -> ParsedField {
1175        let mut f = make_field(rust_name, column_name, rust_type, nullable, is_pk);
1176        f.column_default = Some(default.to_string());
1177        f
1178    }
1179
1180    fn entity_with_defaults() -> ParsedEntity {
1181        ParsedEntity {
1182            struct_name: "Tasks".to_string(),
1183            table_name: "tasks".to_string(),
1184            schema_name: None,
1185            is_view: false,
1186            fields: vec![
1187                make_field("id", "id", "i32", false, true),
1188                make_field("title", "title", "String", false, false),
1189                make_field_with_default("status", "status", "String", false, false, "'idle'::task_status"),
1190                make_field_with_default("priority", "priority", "i32", false, false, "0"),
1191                make_field_with_default("created_at", "created_at", "DateTime<Utc>", false, false, "now()"),
1192                make_field("description", "description", "Option<String>", true, false),
1193                make_field_with_default("deleted_at", "deleted_at", "Option<DateTime<Utc>>", true, false, "NULL"),
1194            ],
1195            imports: vec![],
1196        }
1197    }
1198
1199    fn standard_entity() -> ParsedEntity {
1200        ParsedEntity {
1201            struct_name: "Users".to_string(),
1202            table_name: "users".to_string(),
1203            schema_name: None,
1204            is_view: false,
1205            fields: vec![
1206                make_field("id", "id", "i32", false, true),
1207                make_field("name", "name", "String", false, false),
1208                make_field("email", "email", "Option<String>", true, false),
1209            ],
1210            imports: vec![],
1211        }
1212    }
1213
1214    fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
1215        let skip = Methods::all();
1216        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
1217        parse_and_format(&tokens)
1218    }
1219
1220    fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
1221        let skip = Methods::all();
1222        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
1223        parse_and_format(&tokens)
1224    }
1225
1226    fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
1227        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
1228        parse_and_format(&tokens)
1229    }
1230
1231    fn gen_with_tab_spaces(entity: &ParsedEntity, db: DatabaseKind, tab_spaces: usize) -> String {
1232        let skip = Methods::all();
1233        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
1234        parse_and_format_with_tab_spaces(&tokens, tab_spaces)
1235    }
1236
1237    // --- basic structure ---
1238
1239    #[test]
1240    fn test_repo_struct_name() {
1241        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1242        assert!(code.contains("pub struct UsersRepository"));
1243    }
1244
1245    #[test]
1246    fn test_repo_new_method() {
1247        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1248        assert!(code.contains("pub fn new("));
1249    }
1250
1251    #[test]
1252    fn test_repo_pool_field_pg() {
1253        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1254        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
1255    }
1256
1257    #[test]
1258    fn test_repo_pool_field_pub() {
1259        let skip = Methods::all();
1260        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Pub);
1261        let code = parse_and_format(&tokens);
1262        assert!(code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool"));
1263    }
1264
1265    #[test]
1266    fn test_repo_pool_field_pub_crate() {
1267        let skip = Methods::all();
1268        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::PubCrate);
1269        let code = parse_and_format(&tokens);
1270        assert!(code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool"));
1271    }
1272
1273    #[test]
1274    fn test_repo_pool_field_private() {
1275        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1276        // Should NOT have `pub pool` or `pub(crate) pool`
1277        assert!(!code.contains("pub pool"));
1278        assert!(!code.contains("pub(crate) pool"));
1279    }
1280
1281    #[test]
1282    fn test_repo_pool_field_mysql() {
1283        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1284        assert!(code.contains("MySqlPool") || code.contains("MySql"));
1285    }
1286
1287    #[test]
1288    fn test_repo_pool_field_sqlite() {
1289        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1290        assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
1291    }
1292
1293    // --- get_all ---
1294
1295    #[test]
1296    fn test_get_all_method() {
1297        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1298        assert!(code.contains("pub async fn get_all"));
1299    }
1300
1301    #[test]
1302    fn test_get_all_returns_vec() {
1303        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1304        assert!(code.contains("Vec<Users>"));
1305    }
1306
1307    #[test]
1308    fn test_get_all_sql() {
1309        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1310        assert!(code.contains("SELECT * FROM users"));
1311    }
1312
1313    // --- paginate ---
1314
1315    #[test]
1316    fn test_paginate_method() {
1317        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1318        assert!(code.contains("pub async fn paginate"));
1319    }
1320
1321    #[test]
1322    fn test_paginate_params_struct() {
1323        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1324        assert!(code.contains("pub struct PaginateUsersParams"));
1325    }
1326
1327    #[test]
1328    fn test_paginate_params_fields() {
1329        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1330        assert!(code.contains("pub page: i64"));
1331        assert!(code.contains("pub per_page: i64"));
1332    }
1333
1334    #[test]
1335    fn test_paginate_returns_paginated() {
1336        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1337        assert!(code.contains("PaginatedUsers"));
1338        assert!(code.contains("PaginationUsersMeta"));
1339    }
1340
1341    #[test]
1342    fn test_paginate_meta_struct() {
1343        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1344        assert!(code.contains("pub struct PaginationUsersMeta"));
1345        assert!(code.contains("pub total: i64"));
1346        assert!(code.contains("pub last_page: i64"));
1347        assert!(code.contains("pub first_page: i64"));
1348        assert!(code.contains("pub current_page: i64"));
1349    }
1350
1351    #[test]
1352    fn test_paginate_data_struct() {
1353        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1354        assert!(code.contains("pub struct PaginatedUsers"));
1355        assert!(code.contains("pub meta: PaginationUsersMeta"));
1356        assert!(code.contains("pub data: Vec<Users>"));
1357    }
1358
1359    #[test]
1360    fn test_paginate_count_sql() {
1361        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1362        assert!(code.contains("SELECT COUNT(*) FROM users"));
1363    }
1364
1365    #[test]
1366    fn test_paginate_sql_pg() {
1367        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1368        assert!(code.contains("LIMIT $1 OFFSET $2"));
1369    }
1370
1371    #[test]
1372    fn test_paginate_sql_mysql() {
1373        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1374        assert!(code.contains("LIMIT ? OFFSET ?"));
1375    }
1376
1377    // --- get ---
1378
1379    #[test]
1380    fn test_get_method() {
1381        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1382        assert!(code.contains("pub async fn get"));
1383    }
1384
1385    #[test]
1386    fn test_get_returns_option() {
1387        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1388        assert!(code.contains("Option<Users>"));
1389    }
1390
1391    #[test]
1392    fn test_get_where_pk_pg() {
1393        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1394        assert!(code.contains("WHERE id = $1"));
1395    }
1396
1397    #[test]
1398    fn test_get_where_pk_mysql() {
1399        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1400        assert!(code.contains("WHERE id = ?"));
1401    }
1402
1403    // --- insert ---
1404
1405    #[test]
1406    fn test_insert_method() {
1407        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1408        assert!(code.contains("pub async fn insert"));
1409    }
1410
1411    #[test]
1412    fn test_insert_params_struct() {
1413        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1414        assert!(code.contains("pub struct InsertUsersParams"));
1415    }
1416
1417    #[test]
1418    fn test_insert_params_no_pk() {
1419        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1420        assert!(code.contains("pub name: String"));
1421        assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
1422    }
1423
1424    #[test]
1425    fn test_insert_returning_pg() {
1426        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1427        assert!(code.contains("RETURNING *"));
1428    }
1429
1430    #[test]
1431    fn test_insert_returning_sqlite() {
1432        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1433        assert!(code.contains("RETURNING *"));
1434    }
1435
1436    #[test]
1437    fn test_insert_mysql_last_insert_id() {
1438        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1439        assert!(code.contains("LAST_INSERT_ID"));
1440    }
1441
1442    // --- insert with column_default ---
1443
1444    #[test]
1445    fn test_insert_default_col_is_optional() {
1446        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1447        // Fields with column_default and not nullable → Option<T>
1448        let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
1449        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1450        let struct_body = &code[struct_start..struct_end];
1451        assert!(struct_body.contains("Option") && struct_body.contains("status"), "Expected status as Option in InsertTasksParams: {}", struct_body);
1452    }
1453
1454    #[test]
1455    fn test_insert_non_default_col_required() {
1456        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1457        // 'title' has no default → required type (String)
1458        let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
1459        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1460        let struct_body = &code[struct_start..struct_end];
1461        assert!(struct_body.contains("title") && struct_body.contains("String"), "Expected title as String: {}", struct_body);
1462    }
1463
1464    #[test]
1465    fn test_insert_default_col_coalesce_sql() {
1466        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1467        assert!(code.contains("COALESCE($2, 'idle'::task_status)"), "Expected COALESCE for status:\n{}", code);
1468        assert!(code.contains("COALESCE($3, 0)"), "Expected COALESCE for priority:\n{}", code);
1469        assert!(code.contains("COALESCE($4, now())"), "Expected COALESCE for created_at:\n{}", code);
1470    }
1471
1472    #[test]
1473    fn test_insert_no_coalesce_for_non_default() {
1474        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1475        // title has no default, so its placeholder should be plain $1 not COALESCE
1476        assert!(code.contains("VALUES ($1, COALESCE"), "Expected $1 without COALESCE for title:\n{}", code);
1477    }
1478
1479    #[test]
1480    fn test_insert_nullable_with_default_no_double_option() {
1481        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1482        assert!(!code.contains("Option < Option") && !code.contains("Option<Option"), "Should not have Option<Option>:\n{}", code);
1483    }
1484
1485    #[test]
1486    fn test_insert_derive_default() {
1487        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1488        let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
1489        let before_struct = &code[..struct_start];
1490        assert!(before_struct.ends_with("Default)]\n") || before_struct.contains("Default)]"), "Expected #[derive(Default)] on InsertTasksParams");
1491    }
1492
1493    // --- update (patch with COALESCE) ---
1494
1495    #[test]
1496    fn test_update_method() {
1497        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1498        assert!(code.contains("pub async fn update"));
1499    }
1500
1501    #[test]
1502    fn test_update_params_struct() {
1503        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1504        assert!(code.contains("pub struct UpdateUsersParams"));
1505    }
1506
1507    #[test]
1508    fn test_update_pk_in_fn_signature() {
1509        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1510        // PK 'id: i32' should appear between "fn update" and "UpdateUsersParams"
1511        let update_pos = code.find("fn update").expect("fn update not found");
1512        let params_pos = code[update_pos..].find("UpdateUsersParams").expect("UpdateUsersParams not found in update fn");
1513        let signature = &code[update_pos..update_pos + params_pos];
1514        assert!(signature.contains("id"), "Expected 'id' PK in update fn signature: {}", signature);
1515    }
1516
/// The PK lives in the `update` signature, so the patch struct must not repeat it.
#[test]
fn test_update_pk_not_in_struct() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Slice from the struct header up to its first closing brace.
    let tail = &code[code.find("pub struct UpdateUsersParams").expect("UpdateUsersParams not found")..];
    let struct_body = &tail[..tail.find('}').unwrap()];
    assert!(!struct_body.contains("pub id"), "PK 'id' should not be in UpdateUsersParams:\n{}", struct_body);
}
1527
/// Non-nullable columns become optional in the patch struct so callers can skip them.
#[test]
fn test_update_params_non_nullable_wrapped_in_option() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // `name: String` must surface as `Option<String>`; accept either token spacing.
    let formatted = code.contains("pub name: Option<String>");
    let spaced = code.contains("pub name : Option < String >");
    assert!(formatted || spaced);
}
1534
/// Columns that are already nullable keep a single `Option` layer in the patch struct.
#[test]
fn test_update_params_already_nullable_no_double_option() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // `email: Option<String>` must not turn into `Option<Option<String>>`.
    assert!(!(code.contains("Option<Option") || code.contains("Option < Option")));
}
1541
/// Each patched column falls back to its current value via `COALESCE($n, col)`.
#[test]
fn test_update_set_clause_uses_coalesce_pg() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    for (placeholder, column) in [("$1", "name"), ("$2", "email")] {
        let expected = format!("COALESCE({}, {})", placeholder, column);
        assert!(code.contains(&expected), "Expected COALESCE for {}:\n{}", column, code);
    }
}
1548
/// The PK placeholder follows the two SET placeholders in the Postgres update.
#[test]
fn test_update_where_clause_pg() {
    let sql = gen(&standard_entity(), DatabaseKind::Postgres);
    let has_pk_filter = sql.contains("WHERE id = $3");
    assert!(has_pk_filter);
}
1554
/// The patch-style `update` must hand back the updated row via `RETURNING *`.
#[test]
fn test_update_returning_pg() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Scope the checks to the update method. `insert` and `overwrite` also emit
    // `RETURNING *`, so a whole-file search cannot tell whether `update` does.
    let pos = code.find("fn update").expect("fn update not found");
    let rest = &code[pos..];
    let method_body = &rest[..rest.find("async fn").unwrap_or(rest.len())];
    assert!(method_body.contains("COALESCE"), "Expected COALESCE in update:\n{}", method_body);
    assert!(method_body.contains("RETURNING *"), "Expected RETURNING * in update:\n{}", method_body);
}
1561
/// MySQL uses positional `?` placeholders inside the COALESCE patch expressions.
#[test]
fn test_update_set_clause_mysql() {
    let code = gen(&standard_entity(), DatabaseKind::Mysql);
    let name_patched = code.contains("COALESCE(?, name)");
    assert!(name_patched, "Expected COALESCE for MySQL:\n{}", code);
    let email_patched = code.contains("COALESCE(?, email)");
    assert!(email_patched, "Expected COALESCE for email in MySQL:\n{}", code);
}
1568
/// SQLite shares MySQL's `?` placeholder style in the COALESCE patch expressions.
#[test]
fn test_update_set_clause_sqlite() {
    let code = gen(&standard_entity(), DatabaseKind::Sqlite);
    let name_patched = code.contains("COALESCE(?, name)");
    assert!(name_patched, "Expected COALESCE for SQLite:\n{}", code);
}
1574
1575    // --- overwrite (full replacement, PK as fn param) ---
1576
/// Tables get a full-replacement `overwrite` method.
#[test]
fn test_overwrite_method() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let has_overwrite = generated.contains("pub async fn overwrite");
    assert!(has_overwrite);
}
1582
/// `overwrite` takes its new column values through a dedicated params struct.
#[test]
fn test_overwrite_params_struct() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let has_struct = generated.contains("pub struct OverwriteUsersParams");
    assert!(has_struct);
}
1588
/// The PK is an explicit argument of `overwrite`, preceding the params struct.
#[test]
fn test_overwrite_pk_in_fn_signature() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    let tail = &code[code.find("fn overwrite").expect("fn overwrite not found")..];
    let signature = tail
        .find("OverwriteUsersParams")
        .map(|end| &tail[..end])
        .expect("OverwriteUsersParams not found");
    assert!(signature.contains("id"), "Expected PK in overwrite fn signature: {}", signature);
}
1597
/// The overwrite params struct carries only non-PK columns.
#[test]
fn test_overwrite_pk_not_in_struct() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Slice from the struct header up to its first closing brace.
    let tail = &code[code.find("pub struct OverwriteUsersParams").expect("OverwriteUsersParams not found")..];
    let struct_body = &tail[..tail.find('}').unwrap()];
    assert!(!struct_body.contains("pub id"), "PK should not be in OverwriteUsersParams: {}", struct_body);
}
1606
/// `overwrite` replaces every column directly, so its SQL must not use COALESCE.
#[test]
fn test_overwrite_no_coalesce() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Bound the search by the next generated method instead of a fixed 500-byte
    // window. A fixed window can cut the SQL off (making the negative check
    // vacuous), spill into a neighbouring method's COALESCE, or panic when it ends
    // inside a multi-byte character.
    let pos = code.find("fn overwrite").expect("fn overwrite not found");
    let rest = &code[pos..];
    let method_body = &rest[..rest.find("async fn").unwrap_or(rest.len())];
    // Make sure the SET clause really is inside the inspected span.
    assert!(method_body.contains("name = $1"), "Expected direct SET in overwrite: {}", method_body);
    assert!(!method_body.contains("COALESCE"), "Overwrite should not use COALESCE: {}", method_body);
}
1615
/// Full replacement assigns placeholders directly: name, email, then the PK filter.
#[test]
fn test_overwrite_set_clause_pg() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    for fragment in ["name = $1,", "email = $2", "WHERE id = $3"] {
        assert!(code.contains(fragment));
    }
}
1623
/// `overwrite` must return the replaced row via `RETURNING *`.
#[test]
fn test_overwrite_returning_pg() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Scope to the overwrite method, ending at the next generated method. The old
    // fixed 500-byte window could truncate the SQL or run into a neighbour.
    let pos = code.find("fn overwrite").expect("fn overwrite not found");
    let rest = &code[pos..];
    let method_body = &rest[..rest.find("async fn").unwrap_or(rest.len())];
    assert!(method_body.contains("RETURNING *"), "Expected RETURNING * in overwrite:\n{}", method_body);
}
1631
/// Views are read-only, so no `overwrite` is generated for them.
#[test]
fn test_view_no_overwrite() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn overwrite"));
}
1639
/// Disabling `overwrite` drops both the method and its params struct.
#[test]
fn test_without_overwrite() {
    let methods = Methods { overwrite: false, ..Methods::all() };
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &methods);
    for absent in ["pub async fn overwrite", "OverwriteUsersParams"] {
        assert!(!code.contains(absent));
    }
}
1647
/// Patch (`update`) and full replacement (`overwrite`) are generated side by side.
#[test]
fn test_update_and_overwrite_coexist() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    let expectations = [
        ("pub async fn update", "Expected update method"),
        ("pub async fn overwrite", "Expected overwrite method"),
        ("UpdateUsersParams", "Expected UpdateUsersParams"),
        ("OverwriteUsersParams", "Expected OverwriteUsersParams"),
    ];
    for (needle, message) in expectations {
        assert!(code.contains(needle), "{}", message);
    }
}
1656
1657    // --- delete ---
1658
/// Tables with a PK get a `delete` method.
#[test]
fn test_delete_method() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let has_delete = generated.contains("pub async fn delete");
    assert!(has_delete);
}
1664
/// `delete` targets the table and filters on the PK as the only placeholder.
#[test]
fn test_delete_where_pk() {
    let sql = gen(&standard_entity(), DatabaseKind::Postgres);
    for fragment in ["DELETE FROM users", "WHERE id = $1"] {
        assert!(sql.contains(fragment));
    }
}
1671
/// With a tab width of 2, the SQL body sits at 4 + 2*2 columns and the closing `"#` at 4 + 2.
#[test]
fn test_tab_spaces_2_sql_indent() {
    let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 2);
    let sql_line = format!("{}SELECT *", " ".repeat(4 + 2 * 2));
    let closing_tag = format!("{}\"#", " ".repeat(4 + 2));
    assert!(code.contains(&sql_line), "Expected SQL at 8-space indent:\n{}", code);
    assert!(code.contains(&closing_tag), "Expected closing tag at 6-space indent:\n{}", code);
}
1679
/// With a tab width of 4, the SQL body sits at 4 + 2*4 columns and the closing `"#` at 4 + 4.
#[test]
fn test_tab_spaces_4_sql_indent() {
    let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 4);
    let sql_line = format!("{}SELECT *", " ".repeat(4 + 2 * 4));
    let closing_tag = format!("{}\"#", " ".repeat(4 + 4));
    assert!(code.contains(&sql_line), "Expected SQL at 12-space indent:\n{}", code);
    assert!(code.contains(&closing_tag), "Expected closing tag at 8-space indent:\n{}", code);
}
1687
/// The generated code contains a `Result<(), sqlx::Error>` return type (the one used by `delete`).
#[test]
fn test_delete_returns_unit() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    let unit_results = ["Result<(), sqlx::Error>", "Result<(), sqlx :: Error>"];
    assert!(unit_results.iter().any(|signature| code.contains(signature)));
}
1693
1694    // --- views (read-only) ---
1695
/// Views are read-only: no insert method.
#[test]
fn test_view_no_insert() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn insert"));
}
1703
/// Views are read-only: no update method.
#[test]
fn test_view_no_update() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn update"));
}
1711
/// Views are read-only: no delete method.
#[test]
fn test_view_no_delete() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn delete"));
}
1719
/// Read methods stay available on views: `get_all`.
#[test]
fn test_view_has_get_all() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn get_all"));
}
1727
/// Read methods stay available on views: `paginate`.
#[test]
fn test_view_has_paginate() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn paginate"));
}
1735
/// Read methods stay available on views: the by-PK `get`.
#[test]
fn test_view_has_get() {
    let mut entity = standard_entity();
    entity.is_view = true;
    let code = gen(&entity, DatabaseKind::Postgres);
    // `contains("pub async fn get")` also matches `get_all`, so it passed even
    // without a by-PK `get`. Mask `get_all` first, as the other `get` tests do.
    let without_get_all = code.replace("get_all", "XXX");
    assert!(without_get_all.contains("fn get("), "Expected a by-PK get method on the view:\n{}", code);
}
1743
1744    // --- selective methods ---
1745
/// Enabling only `get_all` emits nothing else from the selectable method set.
#[test]
fn test_only_get_all() {
    let only_get_all = Methods { get_all: true, ..Default::default() };
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &only_get_all);
    assert!(code.contains("pub async fn get_all"));
    for absent in ["pub async fn paginate", "pub async fn insert"] {
        assert!(!code.contains(absent));
    }
}
1754
/// Turning off `get_all` removes just that method.
#[test]
fn test_without_get_all() {
    let methods = Methods { get_all: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &methods);
    let has_get_all = generated.contains("pub async fn get_all");
    assert!(!has_get_all);
}
1761
/// Turning off `paginate` removes the method and its params struct.
#[test]
fn test_without_paginate() {
    let methods = Methods { paginate: false, ..Methods::all() };
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &methods);
    for absent in ["pub async fn paginate", "PaginateUsersParams"] {
        assert!(!code.contains(absent));
    }
}
1769
/// Turning off `get` removes the by-PK lookup but keeps `get_all`.
#[test]
fn test_without_get() {
    let methods = Methods { get: false, ..Methods::all() };
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &methods);
    assert!(code.contains("pub async fn get_all"));
    // Mask `get_all` so only a standalone `get` could still match.
    let masked = code.replace("get_all", "XXX");
    let has_get = masked.contains("fn get(");
    assert!(!has_get);
}
1778
/// With both insert variants disabled, no insert method or params struct remains.
#[test]
fn test_without_insert() {
    let methods = Methods { insert: false, insert_many: false, ..Methods::all() };
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &methods);
    for absent in ["pub async fn insert", "InsertUsersParams"] {
        assert!(!code.contains(absent));
    }
}
1786
/// Turning off `update` removes the method and its params struct.
#[test]
fn test_without_update() {
    let methods = Methods { update: false, ..Methods::all() };
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &methods);
    for absent in ["pub async fn update", "UpdateUsersParams"] {
        assert!(!code.contains(absent));
    }
}
1794
/// Turning off `delete` removes the delete method.
#[test]
fn test_without_delete() {
    let methods = Methods { delete: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &methods);
    let has_delete = generated.contains("pub async fn delete");
    assert!(!has_delete);
}
1801
/// An empty method selection generates none of the async repository methods.
#[test]
fn test_empty_methods_no_methods() {
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &Methods::default());
    let method_names = ["get_all", "paginate", "insert", "update", "overwrite", "delete", "insert_many"];
    for name in method_names {
        assert!(!code.contains(&format!("pub async fn {}", name)));
    }
}
1814
1815    // --- imports ---
1816
/// The pool type is referenced by full path (e.g. `sqlx::PgPool`), so it is never imported.
#[test]
fn test_no_pool_import() {
    let (_, imports) = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &Methods::all(),
        false,
        PoolVisibility::Private,
    );
    let pool_imported = imports.iter().any(|line| line.contains("PgPool"));
    assert!(!pool_imported);
}
1823
/// The repository imports its entity from the given module path.
#[test]
fn test_imports_contain_entity() {
    let (_, imports) = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &Methods::all(),
        false,
        PoolVisibility::Private,
    );
    let entity_imported = imports.iter().any(|line| line.contains("crate::models::users::Users"));
    assert!(entity_imported);
}
1830
1831    // --- renamed columns ---
1832
/// A field whose Rust name differs from its column (`connector_type` -> `type`) must
/// use the column name in SQL and the Rust name in the params struct.
#[test]
fn test_renamed_column_in_sql() {
    let entity = ParsedEntity {
        struct_name: "Connector".to_string(),
        table_name: "connector".to_string(),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("id", "id", "i32", false, true),
            make_field("connector_type", "type", "String", false, false),
        ],
        imports: vec![],
    };
    let code = gen(&entity, DatabaseKind::Postgres);
    // INSERT should use the DB column name "type", not "connector_type". A bare
    // `code.contains("type")` is vacuous because `connector_type` contains "type",
    // so inspect the INSERT column list itself (tolerating optional quoting).
    let insert_prefix = "INSERT INTO connector (";
    let cols_start = code.find(insert_prefix).expect("INSERT INTO connector not found") + insert_prefix.len();
    let cols_len = code[cols_start..].find(')').expect("unterminated INSERT column list");
    let columns: Vec<&str> = code[cols_start..cols_start + cols_len]
        .split(',')
        .map(|c| c.trim().trim_matches('"'))
        .collect();
    assert!(columns.contains(&"type"), "Expected column `type` in INSERT: {:?}\n{}", columns, code);
    assert!(!columns.contains(&"connector_type"), "Rust field name leaked into INSERT: {:?}", columns);
    // The Rust param field should be connector_type
    assert!(code.contains("pub connector_type: String"));
}
1852
1853    // --- no PK edge cases ---
1854
/// Without a primary key there is no single-row lookup, only `get_all`.
#[test]
fn test_no_pk_no_get() {
    let entity = ParsedEntity {
        struct_name: "Logs".to_string(),
        table_name: "logs".to_string(),
        schema_name: None,
        is_view: false,
        fields: ["message", "ts"]
            .iter()
            .map(|&column| make_field(column, column, "String", false, false))
            .collect(),
        imports: vec![],
    };
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn get_all"));
    let masked = code.replace("get_all", "XXX");
    assert!(!masked.contains("fn get("));
}
1873
/// Without a primary key there is nothing to delete by, so no `delete` is generated.
#[test]
fn test_no_pk_no_delete() {
    let message = make_field("message", "message", "String", false, false);
    let entity = ParsedEntity {
        struct_name: "Logs".to_string(),
        table_name: "logs".to_string(),
        schema_name: None,
        is_view: false,
        fields: vec![message],
        imports: Vec::new(),
    };
    let generated = gen(&entity, DatabaseKind::Postgres);
    let has_delete = generated.contains("pub async fn delete");
    assert!(!has_delete);
}
1889
1890    // --- Default derive on param structs ---
1891
/// Smoke check that `Default` appears somewhere in the generated params code.
#[test]
fn test_param_structs_have_default() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let mentions_default = generated.contains("Default");
    assert!(mentions_default);
}
1897
1898    // --- entity imports forwarded ---
1899
/// `use` lines from the entity file (chrono, uuid) are carried into the repository's imports.
#[test]
fn test_entity_imports_forwarded() {
    let entity = ParsedEntity {
        struct_name: "Users".to_string(),
        table_name: "users".to_string(),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("id", "id", "Uuid", false, true),
            make_field("created_at", "created_at", "DateTime<Utc>", false, false),
        ],
        imports: ["use chrono::{DateTime, Utc};", "use uuid::Uuid;"]
            .iter()
            .map(|&line| line.to_string())
            .collect(),
    };
    let (_, imports) = generate_crud_from_parsed(
        &entity,
        DatabaseKind::Postgres,
        "crate::models::users",
        &Methods::all(),
        false,
        PoolVisibility::Private,
    );
    for needle in ["chrono", "uuid"] {
        assert!(imports.iter().any(|i| i.contains(needle)));
    }
}
1921
/// An entity without `use` lines forwards no third-party imports.
#[test]
fn test_entity_imports_empty_when_no_imports() {
    let (_, imports) = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &Methods::all(),
        false,
        PoolVisibility::Private,
    );
    // Only the entity import is expected (the pool type is used by full path, not
    // imported), so neither chrono nor uuid may show up.
    for forbidden in ["chrono", "uuid"] {
        assert!(!imports.iter().any(|i| i.contains(forbidden)));
    }
}
1930
1931    // --- query_macro mode ---
1932
/// In macro mode a plain entity's `get_all` uses compile-time `query_as!`, not the runtime turbofish.
#[test]
fn test_macro_get_all() {
    let methods = Methods { get_all: true, ..Default::default() };
    let (tokens, _) = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &methods,
        true,
        PoolVisibility::Private,
    );
    let code = parse_and_format(&tokens);
    assert!(code.contains("query_as!") && !code.contains("query_as::<"));
}
1941
/// Macro-mode pagination passes `per_page, offset` as macro arguments.
#[test]
fn test_macro_paginate() {
    let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for fragment in ["query_as!", "per_page, offset"] {
        assert!(code.contains(fragment));
    }
}
1948
/// In macro mode the by-PK `get` uses `query_as!(Users, ...)` with the PK as a macro argument.
#[test]
fn test_macro_get() {
    let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    // Scope to the `get` method itself. `get_all`, `insert` and `update` also emit
    // `query_as!(Users`, so a whole-file search cannot tell whether `get` does.
    let start = code.find("fn get(").expect("fn get not found");
    let rest = &code[start..];
    let get_body = &rest[..rest.find("async fn").unwrap_or(rest.len())];
    assert!(get_body.contains("query_as!(Users"), "Expected query_as! in get:\n{}", get_body);
    assert!(!get_body.contains(".bind("), "get should pass the PK as a macro arg, not via .bind():\n{}", get_body);
}
1955
/// Postgres macro insert uses `query_as!` and reads values from `params.*`.
#[test]
fn test_macro_insert_pg() {
    let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for fragment in ["query_as!(Users", "params.name", "params.email"] {
        assert!(code.contains(fragment));
    }
}
1963
/// MySQL has no RETURNING: the macro insert runs `query!`, then `query_scalar!`.
#[test]
fn test_macro_insert_mysql() {
    let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
    let uses_query = code.contains("query!");
    let uses_scalar = code.contains("query_scalar!");
    assert!(uses_query);
    assert!(uses_scalar);
}
1971
/// Macro-mode update keeps the COALESCE patch semantics.
#[test]
fn test_macro_update() {
    let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("query_as!(Users"));
    let patched = code.contains("COALESCE");
    assert!(patched, "Expected COALESCE in macro update:\n{}", code);
    for fragment in ["pub async fn update", "UpdateUsersParams"] {
        assert!(code.contains(fragment));
    }
}
1980
/// Macro-mode delete returns no rows, so it uses plain `query!`.
#[test]
fn test_macro_delete() {
    let generated = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    let uses_query_macro = generated.contains("query!");
    assert!(uses_query_macro);
}
1987
/// Macro mode passes arguments inside the macro, so no `.bind(` calls remain.
#[test]
fn test_macro_no_bind_calls() {
    // insert_many is always generated in runtime mode, so leave it out here.
    let methods = Methods { insert_many: false, ..Methods::all() };
    let (tokens, _) = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &methods,
        true,
        PoolVisibility::Private,
    );
    let has_bind = parse_and_format(&tokens).contains(".bind(");
    assert!(!has_bind);
}
1996
/// Runtime (non-macro) mode binds arguments with `.bind(` and uses no query macros.
#[test]
fn test_function_style_uses_bind() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains(".bind("));
    for macro_call in ["query_as!(", "query!("] {
        assert!(!code.contains(macro_call));
    }
}
2004
2005    // --- custom sql_type fallback: macro mode + custom type → runtime for SELECT, macro for DELETE ---
2006
2007    fn entity_with_sql_array() -> ParsedEntity {
2008        ParsedEntity {
2009            struct_name: "AgentConnector".to_string(),
2010            table_name: "agent.agent_connector".to_string(),
2011            schema_name: Some("agent".to_string()),
2012            is_view: false,
2013            fields: vec![
2014                ParsedField {
2015                    rust_name: "connector_id".to_string(),
2016                    column_name: "connector_id".to_string(),
2017                    rust_type: "Uuid".to_string(),
2018                    inner_type: "Uuid".to_string(),
2019                    is_nullable: false,
2020                    is_primary_key: true,
2021                    sql_type: None,
2022                    is_sql_array: false,
2023                    column_default: None,
2024                },
2025                ParsedField {
2026                    rust_name: "agent_id".to_string(),
2027                    column_name: "agent_id".to_string(),
2028                    rust_type: "Uuid".to_string(),
2029                    inner_type: "Uuid".to_string(),
2030                    is_nullable: false,
2031                    is_primary_key: false,
2032                    sql_type: None,
2033                    is_sql_array: false,
2034                    column_default: None,
2035                },
2036                ParsedField {
2037                    rust_name: "usages".to_string(),
2038                    column_name: "usages".to_string(),
2039                    rust_type: "Vec<ConnectorUsages>".to_string(),
2040                    inner_type: "Vec<ConnectorUsages>".to_string(),
2041                    is_nullable: false,
2042                    is_primary_key: false,
2043                    sql_type: Some("agent.connector_usages".to_string()),
2044                    is_sql_array: true,
2045                    column_default: None,
2046                },
2047            ],
2048            imports: vec!["use uuid::Uuid;".to_string()],
2049        }
2050    }
2051
2052    fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
2053        let skip = Methods::all();
2054        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
2055        parse_and_format(&tokens)
2056    }
2057
/// A custom array column forces `get_all` onto runtime `query_as::<...>` even in macro mode.
#[test]
fn test_sql_array_macro_get_all_uses_runtime() {
    let generated = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    let runtime_query = generated.contains("query_as::<");
    assert!(runtime_query);
}
2064
/// With a custom array column, the by-PK `get` falls back to runtime `.bind(` even in macro mode.
#[test]
fn test_sql_array_macro_get_uses_runtime() {
    let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    // Check `get` itself rather than the whole file: `insert`/`update` also bind in
    // runtime mode, so a file-wide `.bind(` says nothing about `get`.
    let start = code.find("fn get(").expect("fn get not found");
    let rest = &code[start..];
    let get_body = &rest[..rest.find("async fn").unwrap_or(rest.len())];
    assert!(get_body.contains(".bind("), "Expected runtime .bind( in get:\n{}", get_body);
    assert!(!get_body.contains("query_as!("), "get should not use query_as! here:\n{}", get_body);
}
2071
/// The RETURNING insert also falls back to runtime `query_as::<_, AgentConnector>`.
#[test]
fn test_sql_array_macro_insert_uses_runtime() {
    let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    let spellings = ["query_as::<_ , AgentConnector>", "query_as::<_, AgentConnector>"];
    assert!(spellings.iter().any(|s| code.contains(s)));
}
2078
2079
/// DELETE returns no rows, so the array column does not matter and `query!` is kept.
#[test]
fn test_sql_array_macro_delete_still_uses_macro() {
    let generated = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    let uses_query_macro = generated.contains("query!");
    assert!(uses_query_macro);
}
2086
/// No row-returning query keeps `query_as!` once a custom array column is present.
#[test]
fn test_sql_array_no_query_as_macro() {
    let generated = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    let uses_query_as_macro = generated.contains("query_as!(");
    assert!(!uses_query_as_macro);
}
2093
2094    // --- custom enum (non-array) also triggers runtime fallback ---
2095
2096    fn entity_with_sql_enum() -> ParsedEntity {
2097        ParsedEntity {
2098            struct_name: "Task".to_string(),
2099            table_name: "tasks".to_string(),
2100            schema_name: None,
2101            is_view: false,
2102            fields: vec![
2103                ParsedField {
2104                    rust_name: "id".to_string(),
2105                    column_name: "id".to_string(),
2106                    rust_type: "i32".to_string(),
2107                    inner_type: "i32".to_string(),
2108                    is_nullable: false,
2109                    is_primary_key: true,
2110                    sql_type: None,
2111                    is_sql_array: false,
2112                    column_default: None,
2113                },
2114                ParsedField {
2115                    rust_name: "status".to_string(),
2116                    column_name: "status".to_string(),
2117                    rust_type: "TaskStatus".to_string(),
2118                    inner_type: "TaskStatus".to_string(),
2119                    is_nullable: false,
2120                    is_primary_key: false,
2121                    sql_type: Some("task_status".to_string()),
2122                    is_sql_array: false,
2123                    column_default: None,
2124                },
2125            ],
2126            imports: vec![],
2127        }
2128    }
2129
/// A custom enum column also pushes row-returning queries to runtime `query_as::<...>`.
#[test]
fn test_sql_enum_macro_uses_runtime() {
    let (tokens, _) = generate_crud_from_parsed(
        &entity_with_sql_enum(),
        DatabaseKind::Postgres,
        "crate::models::task",
        &Methods::all(),
        true,
        PoolVisibility::Private,
    );
    let code = parse_and_format(&tokens);
    assert!(code.contains("query_as::<") && !code.contains("query_as!("));
}
2139
/// DELETE keeps the `query!` macro even when the entity has a custom enum column.
#[test]
fn test_sql_enum_macro_delete_still_uses_macro() {
    let (tokens, _) = generate_crud_from_parsed(
        &entity_with_sql_enum(),
        DatabaseKind::Postgres,
        "crate::models::task",
        &Methods::all(),
        true,
        PoolVisibility::Private,
    );
    let uses_query_macro = parse_and_format(&tokens).contains("query!");
    assert!(uses_query_macro);
}
2148
2149    // --- Vec<String> native array uses .as_slice() in macro mode ---
2150
2151    fn entity_with_vec_string() -> ParsedEntity {
2152        ParsedEntity {
2153            struct_name: "PromptHistory".to_string(),
2154            table_name: "prompt_history".to_string(),
2155            schema_name: None,
2156            is_view: false,
2157            fields: vec![
2158                ParsedField {
2159                    rust_name: "id".to_string(),
2160                    column_name: "id".to_string(),
2161                    rust_type: "Uuid".to_string(),
2162                    inner_type: "Uuid".to_string(),
2163                    is_nullable: false,
2164                    is_primary_key: true,
2165                    sql_type: None,
2166                    is_sql_array: false,
2167                    column_default: None,
2168                },
2169                ParsedField {
2170                    rust_name: "content".to_string(),
2171                    column_name: "content".to_string(),
2172                    rust_type: "String".to_string(),
2173                    inner_type: "String".to_string(),
2174                    is_nullable: false,
2175                    is_primary_key: false,
2176                    sql_type: None,
2177                    is_sql_array: false,
2178                    column_default: None,
2179                },
2180                ParsedField {
2181                    rust_name: "tags".to_string(),
2182                    column_name: "tags".to_string(),
2183                    rust_type: "Vec<String>".to_string(),
2184                    inner_type: "Vec<String>".to_string(),
2185                    is_nullable: false,
2186                    is_primary_key: false,
2187                    sql_type: None,
2188                    is_sql_array: false,
2189                    column_default: None,
2190                },
2191            ],
2192            imports: vec!["use uuid::Uuid;".to_string()],
2193        }
2194    }
2195
/// In macro mode a `Vec<String>` value is passed as `.as_slice()`.
#[test]
fn test_vec_string_macro_insert_uses_as_slice() {
    let (tokens, _) = generate_crud_from_parsed(
        &entity_with_vec_string(),
        DatabaseKind::Postgres,
        "crate::models::prompt_history",
        &Methods::all(),
        true,
        PoolVisibility::Private,
    );
    let sliced = parse_and_format(&tokens).contains("as_slice()");
    assert!(sliced);
}
2203
/// The `Vec<String>` column needs `.as_slice()` in more than one macro call (insert and update).
#[test]
fn test_vec_string_macro_update_uses_as_slice() {
    let (tokens, _) = generate_crud_from_parsed(
        &entity_with_vec_string(),
        DatabaseKind::Postgres,
        "crate::models::prompt_history",
        &Methods::all(),
        true,
        PoolVisibility::Private,
    );
    let code = parse_and_format(&tokens);
    let count = code.matches("as_slice()").count();
    assert!(count >= 2, "expected at least 2 as_slice() calls (insert + update), found {}", count);
}
2213
/// Runtime mode binds the vector directly via `.bind()`, so no `.as_slice()` is emitted.
#[test]
fn test_vec_string_non_macro_no_as_slice() {
    let (tokens, _) = generate_crud_from_parsed(
        &entity_with_vec_string(),
        DatabaseKind::Postgres,
        "crate::models::prompt_history",
        &Methods::all(),
        false,
        PoolVisibility::Private,
    );
    let sliced = parse_and_format(&tokens).contains("as_slice()");
    assert!(!sliced);
}
2222
/// End-to-end: a `Vec<String>` field parsed from entity source gets `.as_slice()` in macro mode.
#[test]
fn test_vec_string_parsed_from_source_uses_as_slice() {
    use crate::codegen::entity_parser::parse_entity_source;
    let source = r#"
            use uuid::Uuid;

            #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
            #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
            pub struct PromptHistory {
                #[sqlx_gen(primary_key)]
                pub id: Uuid,
                pub content: String,
                pub tags: Vec<String>,
            }
        "#;
    let entity = parse_entity_source(source).unwrap();
    let (tokens, _) = generate_crud_from_parsed(
        &entity,
        DatabaseKind::Postgres,
        "crate::models::prompt_history",
        &Methods::all(),
        true,
        PoolVisibility::Private,
    );
    let code = parse_and_format(&tokens);
    assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
}
2244
2245    // --- composite PK only (junction table) ---
2246
2247    fn junction_entity() -> ParsedEntity {
2248        ParsedEntity {
2249            struct_name: "AnalysisRecord".to_string(),
2250            table_name: "analysis.analysis__record".to_string(),
2251            schema_name: None,
2252            is_view: false,
2253            fields: vec![
2254                make_field("record_id", "record_id", "uuid::Uuid", false, true),
2255                make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
2256            ],
2257            imports: vec![],
2258        }
2259    }
2260
2261    #[test]
2262    fn test_composite_pk_only_insert_generated() {
2263        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2264        assert!(code.contains("pub struct InsertAnalysisRecordParams"), "Expected InsertAnalysisRecordParams struct:\n{}", code);
2265        assert!(code.contains("pub record_id"), "Expected record_id field in insert params:\n{}", code);
2266        assert!(code.contains("pub analysis_id"), "Expected analysis_id field in insert params:\n{}", code);
2267        assert!(code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id)"), "Expected INSERT INTO clause:\n{}", code);
2268        assert!(code.contains("VALUES ($1, $2)"), "Expected VALUES clause:\n{}", code);
2269        assert!(code.contains("RETURNING *"), "Expected RETURNING clause:\n{}", code);
2270        assert!(code.contains("pub async fn insert"), "Expected insert method:\n{}", code);
2271    }
2272
2273    #[test]
2274    fn test_composite_pk_only_no_update() {
2275        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2276        assert!(!code.contains("UpdateAnalysisRecordParams"), "Expected no UpdateAnalysisRecordParams struct:\n{}", code);
2277        assert!(!code.contains("pub async fn update"), "Expected no update method:\n{}", code);
2278    }
2279
2280
2281    #[test]
2282    fn test_composite_pk_only_delete_generated() {
2283        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2284        assert!(code.contains("pub async fn delete"), "Expected delete method:\n{}", code);
2285        assert!(code.contains("DELETE FROM analysis.analysis__record"), "Expected DELETE clause:\n{}", code);
2286        assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause:\n{}", code);
2287    }
2288
2289    #[test]
2290    fn test_composite_pk_only_get_generated() {
2291        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2292        assert!(code.contains("pub async fn get"), "Expected get method:\n{}", code);
2293        assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause with both PK columns:\n{}", code);
2294    }
2295
2296    // --- insert_many_transactionally ---
2297
2298    #[test]
2299    fn test_insert_many_transactionally_method_generated() {
2300        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2301        assert!(code.contains("pub async fn insert_many_transactionally"), "Expected insert_many_transactionally method:\n{}", code);
2302    }
2303
2304    #[test]
2305    fn test_insert_many_transactionally_signature() {
2306        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2307        assert!(code.contains("entries: Vec<InsertUsersParams>"), "Expected Vec<InsertUsersParams> param:\n{}", code);
2308        assert!(code.contains("Result<Vec<Users>"), "Expected Result<Vec<Users>> return type:\n{}", code);
2309    }
2310
2311    #[test]
2312    fn test_insert_many_transactionally_no_strategy_enum() {
2313        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2314        assert!(!code.contains("TransactionStrategy"), "TransactionStrategy should not be generated:\n{}", code);
2315        assert!(!code.contains("InsertManyUsersResult"), "InsertManyUsersResult should not be generated:\n{}", code);
2316    }
2317
2318    #[test]
2319    fn test_insert_many_transactionally_uses_transaction_pg() {
2320        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2321        let method_start = code.find("fn insert_many_transactionally").expect("insert_many_transactionally not found");
2322        let method_body = &code[method_start..];
2323        assert!(method_body.contains("self.pool.begin()"), "Expected begin():\n{}", method_body);
2324        assert!(method_body.contains("tx.commit()"), "Expected commit():\n{}", method_body);
2325    }
2326
2327    #[test]
2328    fn test_insert_many_transactionally_multi_row_pg() {
2329        let code = gen(&standard_entity(), DatabaseKind::Postgres);
2330        let method_start = code.find("fn insert_many_transactionally").expect("not found");
2331        let method_body = &code[method_start..];
2332        assert!(method_body.contains("RETURNING *"), "Expected RETURNING * in multi-row SQL:\n{}", method_body);
2333        assert!(method_body.contains("values_parts"), "Expected multi-row VALUES building:\n{}", method_body);
2334        assert!(method_body.contains("65535"), "Expected chunk size limit:\n{}", method_body);
2335    }
2336
2337    #[test]
2338    fn test_insert_many_transactionally_multi_row_sqlite() {
2339        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
2340        let method_start = code.find("fn insert_many_transactionally").expect("not found");
2341        let method_body = &code[method_start..];
2342        assert!(method_body.contains("values_parts"), "Expected multi-row VALUES building for SQLite:\n{}", method_body);
2343        assert!(method_body.contains("RETURNING *"), "Expected RETURNING * for SQLite:\n{}", method_body);
2344    }
2345
2346    #[test]
2347    fn test_insert_many_transactionally_mysql_individual_inserts() {
2348        let code = gen(&standard_entity(), DatabaseKind::Mysql);
2349        let method_start = code.find("fn insert_many_transactionally").expect("not found");
2350        let method_body = &code[method_start..];
2351        assert!(method_body.contains("LAST_INSERT_ID"), "Expected LAST_INSERT_ID for MySQL:\n{}", method_body);
2352        assert!(method_body.contains("self.pool.begin()"), "Expected begin() for MySQL:\n{}", method_body);
2353    }
2354
2355    #[test]
2356    fn test_insert_many_transactionally_view_not_generated() {
2357        let mut entity = standard_entity();
2358        entity.is_view = true;
2359        let code = gen(&entity, DatabaseKind::Postgres);
2360        assert!(!code.contains("pub async fn insert_many_transactionally"), "should not be generated for views");
2361    }
2362
2363    #[test]
2364    fn test_insert_many_transactionally_without_method_not_generated() {
2365        let m = Methods { insert_many: false, ..Methods::all() };
2366        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2367        assert!(!code.contains("pub async fn insert_many_transactionally"), "should not be generated when disabled");
2368    }
2369
2370    #[test]
2371    fn test_insert_many_transactionally_generates_params_when_insert_disabled() {
2372        let m = Methods { insert: false, insert_many: true, ..Default::default() };
2373        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2374        assert!(code.contains("pub struct InsertUsersParams"), "Expected InsertUsersParams:\n{}", code);
2375        assert!(code.contains("pub async fn insert_many_transactionally"), "Expected method:\n{}", code);
2376        assert!(!code.contains("pub async fn insert("), "insert should not be present:\n{}", code);
2377    }
2378
2379    #[test]
2380    fn test_insert_many_transactionally_with_column_defaults_coalesce() {
2381        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
2382        let method_start = code.find("fn insert_many_transactionally").expect("not found");
2383        let method_body = &code[method_start..];
2384        assert!(method_body.contains("COALESCE"), "Expected COALESCE for fields with defaults:\n{}", method_body);
2385    }
2386
2387    #[test]
2388    fn test_insert_many_transactionally_junction_table() {
2389        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2390        assert!(code.contains("pub async fn insert_many_transactionally"), "Expected method for junction table:\n{}", code);
2391    }
2392
2393    #[test]
2394    fn test_insert_many_transactionally_all_three_backends_compile() {
2395        for db in [DatabaseKind::Postgres, DatabaseKind::Mysql, DatabaseKind::Sqlite] {
2396            let code = gen(&standard_entity(), db);
2397            assert!(code.contains("pub async fn insert_many_transactionally"), "Expected method for {:?}", db);
2398        }
2399    }
2400}