Skip to main content

sqlx_gen/codegen/
crud_gen.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10    entity: &ParsedEntity,
11    db_kind: DatabaseKind,
12    entity_module_path: &str,
13    methods: &Methods,
14    query_macro: bool,
15    pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17    let mut imports = BTreeSet::new();
18
19    let entity_ident = format_ident!("{}", entity.struct_name);
20    let repo_name = format!("{}Repository", entity.struct_name);
21    let repo_ident = format_ident!("{}", repo_name);
22
23    let table_name = match &entity.schema_name {
24        Some(schema) => format!("{}.{}", schema, entity.table_name),
25        None => entity.table_name.clone(),
26    };
27
28    // Pool type (used via full path sqlx::PgPool etc., no import needed)
29    let pool_type = pool_type_tokens(db_kind);
30
31    // When the entity has custom SQL types (enums, composites, arrays),
32    // query_as! macro can't resolve the column type at compile time. Fall back to runtime query_as::<_, T>()
33    // for queries that return rows. DELETE (no rows returned) can still use macro.
34    let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35    let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37    // Entity import
38    imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40    // Forward type imports from the entity file (chrono, uuid, etc.)
41    // Rewrite `use super::X` imports to absolute paths based on entity_module_path,
42    // since the repository lives in a different module where `super` has a different meaning.
43    let entity_parent = entity_module_path
44        .rsplit_once("::")
45        .map(|(parent, _)| parent)
46        .unwrap_or(entity_module_path);
47    for imp in &entity.imports {
48        if let Some(rest) = imp.strip_prefix("use super::") {
49            imports.insert(format!("use {}::{}", entity_parent, rest));
50        } else {
51            imports.insert(imp.clone());
52        }
53    }
54
55    // Primary key fields
56    let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58    // Non-PK fields (for insert)
59    let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61    let is_view = entity.is_view;
62
63    // Build method tokens
64    let mut method_tokens = Vec::new();
65    let mut param_structs = Vec::new();
66
67    // --- get_all ---
68    if methods.get_all {
69        let sql = format!("SELECT * FROM {}", table_name);
70        let method = if use_macro {
71            quote! {
72                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73                    sqlx::query_as!(#entity_ident, #sql)
74                        .fetch_all(&self.pool)
75                        .await
76                }
77            }
78        } else {
79            quote! {
80                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81                    sqlx::query_as::<_, #entity_ident>(#sql)
82                        .fetch_all(&self.pool)
83                        .await
84                }
85            }
86        };
87        method_tokens.push(method);
88    }
89
90    // --- paginate ---
91    if methods.paginate {
92        let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93        let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94        let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95        let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
96        let sql = match db_kind {
97            DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
98            DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
99        };
100        let method = if use_macro {
101            quote! {
102                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103                    let total: i64 = sqlx::query_scalar!(#count_sql)
104                        .fetch_one(&self.pool)
105                        .await?
106                        .unwrap_or(0);
107                    let per_page = params.per_page;
108                    let current_page = params.page;
109                    let last_page = (total + per_page - 1) / per_page;
110                    let offset = (current_page - 1) * per_page;
111                    let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112                        .fetch_all(&self.pool)
113                        .await?;
114                    Ok(#paginated_ident {
115                        meta: #pagination_meta_ident {
116                            total,
117                            per_page,
118                            current_page,
119                            last_page,
120                            first_page: 1,
121                        },
122                        data,
123                    })
124                }
125            }
126        } else {
127            quote! {
128                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129                    let total: i64 = sqlx::query_scalar(#count_sql)
130                        .fetch_one(&self.pool)
131                        .await?;
132                    let per_page = params.per_page;
133                    let current_page = params.page;
134                    let last_page = (total + per_page - 1) / per_page;
135                    let offset = (current_page - 1) * per_page;
136                    let data = sqlx::query_as::<_, #entity_ident>(#sql)
137                        .bind(per_page)
138                        .bind(offset)
139                        .fetch_all(&self.pool)
140                        .await?;
141                    Ok(#paginated_ident {
142                        meta: #pagination_meta_ident {
143                            total,
144                            per_page,
145                            current_page,
146                            last_page,
147                            first_page: 1,
148                        },
149                        data,
150                    })
151                }
152            }
153        };
154        method_tokens.push(method);
155        param_structs.push(quote! {
156            #[derive(Debug, Clone, Default)]
157            pub struct #paginate_params_ident {
158                pub page: i64,
159                pub per_page: i64,
160            }
161        });
162        param_structs.push(quote! {
163            #[derive(Debug, Clone)]
164            pub struct #pagination_meta_ident {
165                pub total: i64,
166                pub per_page: i64,
167                pub current_page: i64,
168                pub last_page: i64,
169                pub first_page: i64,
170            }
171        });
172        param_structs.push(quote! {
173            #[derive(Debug, Clone)]
174            pub struct #paginated_ident {
175                pub meta: #pagination_meta_ident,
176                pub data: Vec<#entity_ident>,
177            }
178        });
179    }
180
181    // --- get (by PK) ---
182    if methods.get && !pk_fields.is_empty() {
183        let pk_params: Vec<TokenStream> = pk_fields
184            .iter()
185            .map(|f| {
186                let name = format_ident!("{}", f.rust_name);
187                let ty: TokenStream = f.inner_type.parse().unwrap();
188                quote! { #name: #ty }
189            })
190            .collect();
191
192        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194        let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
195        let sql_macro = format!("SELECT * FROM {} WHERE {}", table_name, where_clause_cast);
196
197        let binds: Vec<TokenStream> = pk_fields
198            .iter()
199            .map(|f| {
200                let name = format_ident!("{}", f.rust_name);
201                quote! { .bind(#name) }
202            })
203            .collect();
204
205        let method = if use_macro {
206            let pk_arg_names: Vec<TokenStream> = pk_fields
207                .iter()
208                .map(|f| {
209                    let name = format_ident!("{}", f.rust_name);
210                    quote! { #name }
211                })
212                .collect();
213            quote! {
214                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215                    sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216                        .fetch_optional(&self.pool)
217                        .await
218                }
219            }
220        } else {
221            quote! {
222                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223                    sqlx::query_as::<_, #entity_ident>(#sql)
224                        #(#binds)*
225                        .fetch_optional(&self.pool)
226                        .await
227                }
228            }
229        };
230        method_tokens.push(method);
231    }
232
233    // --- insert (skip for views) ---
234    if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
235        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237        // When all columns are PKs (e.g. junction tables), use pk_fields for insert
238        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
239            pk_fields.clone()
240        } else {
241            non_pk_fields.clone()
242        };
243
244        // Fields with column_default and not already nullable → Option<T>
245        let insert_fields: Vec<TokenStream> = insert_source_fields
246            .iter()
247            .map(|f| {
248                let name = format_ident!("{}", f.rust_name);
249                if f.column_default.is_some() && !f.is_nullable {
250                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
251                    quote! { pub #name: #ty, }
252                } else {
253                    let ty: TokenStream = f.rust_type.parse().unwrap();
254                    quote! { pub #name: #ty, }
255                }
256            })
257            .collect();
258
259        let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
260        let col_list = col_names.join(", ");
261
262        // Build placeholders with COALESCE for fields that have a column_default
263        let placeholders: String = insert_source_fields
264            .iter()
265            .enumerate()
266            .map(|(i, f)| {
267                let p = placeholder(db_kind, i + 1);
268                match &f.column_default {
269                    Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
270                    None => p,
271                }
272            })
273            .collect::<Vec<_>>()
274            .join(", ");
275
276        let placeholders_cast: String = insert_source_fields
277            .iter()
278            .enumerate()
279            .map(|(i, f)| {
280                let p = placeholder_with_cast(db_kind, i + 1, f);
281                match &f.column_default {
282                    Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
283                    None => p,
284                }
285            })
286            .collect::<Vec<_>>()
287            .join(", ");
288
289        let build_insert_sql = |ph: &str| match db_kind {
290            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
291                format!(
292                    "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
293                    table_name, col_list, ph
294                )
295            }
296            DatabaseKind::Mysql => {
297                format!(
298                    "INSERT INTO {} ({}) VALUES ({})",
299                    table_name, col_list, ph
300                )
301            }
302        };
303        let sql = build_insert_sql(&placeholders);
304        let sql_macro = build_insert_sql(&placeholders_cast);
305
306        let binds: Vec<TokenStream> = insert_source_fields
307            .iter()
308            .map(|f| {
309                let name = format_ident!("{}", f.rust_name);
310                quote! { .bind(&params.#name) }
311            })
312            .collect();
313
314        let insert_method = build_insert_method_parsed(
315            &entity_ident,
316            &insert_params_ident,
317            &sql,
318            &sql_macro,
319            &binds,
320            db_kind,
321            &table_name,
322            &pk_fields,
323            &insert_source_fields,
324            use_macro,
325        );
326        method_tokens.push(insert_method);
327
328        param_structs.push(quote! {
329            #[derive(Debug, Clone, Default)]
330            pub struct #insert_params_ident {
331                #(#insert_fields)*
332            }
333        });
334    }
335
336    // --- overwrite (full replacement — skip for views, skip when all columns are PKs) ---
337    if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
338        let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
339
340        // PK as function parameters (like get/delete)
341        let pk_fn_params: Vec<TokenStream> = pk_fields
342            .iter()
343            .map(|f| {
344                let name = format_ident!("{}", f.rust_name);
345                let ty: TokenStream = f.inner_type.parse().unwrap();
346                quote! { #name: #ty }
347            })
348            .collect();
349
350        // Non-PK fields keep original types (required)
351        let overwrite_fields: Vec<TokenStream> = non_pk_fields
352            .iter()
353            .map(|f| {
354                let name = format_ident!("{}", f.rust_name);
355                let ty: TokenStream = f.rust_type.parse().unwrap();
356                quote! { pub #name: #ty, }
357            })
358            .collect();
359
360        let set_cols: Vec<String> = non_pk_fields
361            .iter()
362            .enumerate()
363            .map(|(i, f)| {
364                let p = placeholder(db_kind, i + 1);
365                format!("{} = {}", f.column_name, p)
366            })
367            .collect();
368        let set_clause = set_cols.join(", ");
369
370        let set_cols_cast: Vec<String> = non_pk_fields
371            .iter()
372            .enumerate()
373            .map(|(i, f)| {
374                let p = placeholder_with_cast(db_kind, i + 1, f);
375                format!("{} = {}", f.column_name, p)
376            })
377            .collect();
378        let set_clause_cast = set_cols_cast.join(", ");
379
380        let pk_start = non_pk_fields.len() + 1;
381        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
382        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
383
384        let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
385            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
386                format!("UPDATE {} SET {} WHERE {} RETURNING *", table_name, sc, wc)
387            }
388            DatabaseKind::Mysql => {
389                format!("UPDATE {} SET {} WHERE {}", table_name, sc, wc)
390            }
391        };
392        let sql = build_overwrite_sql(&set_clause, &where_clause);
393        let sql_macro = build_overwrite_sql(&set_clause_cast, &where_clause_cast);
394
395        // Bind non-PK first (from params), then PK (from function args)
396        let mut all_binds: Vec<TokenStream> = non_pk_fields
397            .iter()
398            .map(|f| {
399                let name = format_ident!("{}", f.rust_name);
400                quote! { .bind(&params.#name) }
401            })
402            .collect();
403        for f in &pk_fields {
404            let name = format_ident!("{}", f.rust_name);
405            all_binds.push(quote! { .bind(#name) });
406        }
407
408        // Macro args: non-PK from params, then PK from function args
409        let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
410            .iter()
411            .map(|f| macro_arg_for_field(f))
412            .chain(pk_fields.iter().map(|f| {
413                let name = format_ident!("{}", f.rust_name);
414                quote! { #name }
415            }))
416            .collect();
417
418        let overwrite_method = if use_macro {
419            match db_kind {
420                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
421                    quote! {
422                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
423                            sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
424                                .fetch_one(&self.pool)
425                                .await
426                        }
427                    }
428                }
429                DatabaseKind::Mysql => {
430                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
431                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
432                    let pk_macro_args: Vec<TokenStream> = pk_fields
433                        .iter()
434                        .map(|f| {
435                            let name = format_ident!("{}", f.rust_name);
436                            quote! { #name }
437                        })
438                        .collect();
439                    quote! {
440                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
441                            sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
442                                .execute(&self.pool)
443                                .await?;
444                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
445                                .fetch_one(&self.pool)
446                                .await
447                        }
448                    }
449                }
450            }
451        } else {
452            match db_kind {
453                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
454                    quote! {
455                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
456                            sqlx::query_as::<_, #entity_ident>(#sql)
457                                #(#all_binds)*
458                                .fetch_one(&self.pool)
459                                .await
460                        }
461                    }
462                }
463                DatabaseKind::Mysql => {
464                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
465                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
466                    let pk_binds: Vec<TokenStream> = pk_fields
467                        .iter()
468                        .map(|f| {
469                            let name = format_ident!("{}", f.rust_name);
470                            quote! { .bind(#name) }
471                        })
472                        .collect();
473                    quote! {
474                        pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
475                            sqlx::query(#sql)
476                                #(#all_binds)*
477                                .execute(&self.pool)
478                                .await?;
479                            sqlx::query_as::<_, #entity_ident>(#select_sql)
480                                #(#pk_binds)*
481                                .fetch_one(&self.pool)
482                                .await
483                        }
484                    }
485                }
486            }
487        };
488        method_tokens.push(overwrite_method);
489
490        param_structs.push(quote! {
491            #[derive(Debug, Clone, Default)]
492            pub struct #overwrite_params_ident {
493                #(#overwrite_fields)*
494            }
495        });
496    }
497
498    // --- update / patch (COALESCE — skip for views, skip when all columns are PKs) ---
499    if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
500        let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
501
502        // PK as function parameters (like get/delete)
503        let pk_fn_params: Vec<TokenStream> = pk_fields
504            .iter()
505            .map(|f| {
506                let name = format_ident!("{}", f.rust_name);
507                let ty: TokenStream = f.inner_type.parse().unwrap();
508                quote! { #name: #ty }
509            })
510            .collect();
511
512        // Non-PK fields become Option<T> (no double Option for already nullable)
513        let update_fields: Vec<TokenStream> = non_pk_fields
514            .iter()
515            .map(|f| {
516                let name = format_ident!("{}", f.rust_name);
517                if f.is_nullable {
518                    // Already Option<T> — keep as-is to avoid Option<Option<T>>
519                    let ty: TokenStream = f.rust_type.parse().unwrap();
520                    quote! { pub #name: #ty, }
521                } else {
522                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
523                    quote! { pub #name: #ty, }
524                }
525            })
526            .collect();
527
528        // SET clause with COALESCE for runtime mode
529        let set_cols: Vec<String> = non_pk_fields
530            .iter()
531            .enumerate()
532            .map(|(i, f)| {
533                let p = placeholder(db_kind, i + 1);
534                format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
535            })
536            .collect();
537        let set_clause = set_cols.join(", ");
538
539        // SET clause with COALESCE and casts for macro mode
540        let set_cols_cast: Vec<String> = non_pk_fields
541            .iter()
542            .enumerate()
543            .map(|(i, f)| {
544                let p = placeholder_with_cast(db_kind, i + 1, f);
545                format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
546            })
547            .collect();
548        let set_clause_cast = set_cols_cast.join(", ");
549
550        let pk_start = non_pk_fields.len() + 1;
551        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
552        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
553
554        let build_update_sql = |sc: &str, wc: &str| match db_kind {
555            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
556                format!(
557                    "UPDATE {} SET {} WHERE {} RETURNING *",
558                    table_name, sc, wc
559                )
560            }
561            DatabaseKind::Mysql => {
562                format!(
563                    "UPDATE {} SET {} WHERE {}",
564                    table_name, sc, wc
565                )
566            }
567        };
568        let sql = build_update_sql(&set_clause, &where_clause);
569        let sql_macro = build_update_sql(&set_clause_cast, &where_clause_cast);
570
571        // Bind non-PK first (from params), then PK (from function args)
572        let mut all_binds: Vec<TokenStream> = non_pk_fields
573            .iter()
574            .map(|f| {
575                let name = format_ident!("{}", f.rust_name);
576                quote! { .bind(&params.#name) }
577            })
578            .collect();
579        for f in &pk_fields {
580            let name = format_ident!("{}", f.rust_name);
581            all_binds.push(quote! { .bind(#name) });
582        }
583
584        // Macro args: non-PK from params, then PK from function args
585        let update_macro_args: Vec<TokenStream> = non_pk_fields
586            .iter()
587            .map(|f| macro_arg_for_field(f))
588            .chain(pk_fields.iter().map(|f| {
589                let name = format_ident!("{}", f.rust_name);
590                quote! { #name }
591            }))
592            .collect();
593
594        let update_method = if use_macro {
595            match db_kind {
596                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
597                    quote! {
598                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
599                            sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
600                                .fetch_one(&self.pool)
601                                .await
602                        }
603                    }
604                }
605                DatabaseKind::Mysql => {
606                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
607                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
608                    let pk_macro_args: Vec<TokenStream> = pk_fields
609                        .iter()
610                        .map(|f| {
611                            let name = format_ident!("{}", f.rust_name);
612                            quote! { #name }
613                        })
614                        .collect();
615                    quote! {
616                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
617                            sqlx::query!(#sql_macro, #(#update_macro_args),*)
618                                .execute(&self.pool)
619                                .await?;
620                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
621                                .fetch_one(&self.pool)
622                                .await
623                        }
624                    }
625                }
626            }
627        } else {
628            match db_kind {
629                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
630                    quote! {
631                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
632                            sqlx::query_as::<_, #entity_ident>(#sql)
633                                #(#all_binds)*
634                                .fetch_one(&self.pool)
635                                .await
636                        }
637                    }
638                }
639                DatabaseKind::Mysql => {
640                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
641                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
642                    let pk_binds: Vec<TokenStream> = pk_fields
643                        .iter()
644                        .map(|f| {
645                            let name = format_ident!("{}", f.rust_name);
646                            quote! { .bind(#name) }
647                        })
648                        .collect();
649                    quote! {
650                        pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
651                            sqlx::query(#sql)
652                                #(#all_binds)*
653                                .execute(&self.pool)
654                                .await?;
655                            sqlx::query_as::<_, #entity_ident>(#select_sql)
656                                #(#pk_binds)*
657                                .fetch_one(&self.pool)
658                                .await
659                        }
660                    }
661                }
662            }
663        };
664        method_tokens.push(update_method);
665
666        param_structs.push(quote! {
667            #[derive(Debug, Clone, Default)]
668            pub struct #update_params_ident {
669                #(#update_fields)*
670            }
671        });
672    }
673
674    // --- delete (skip for views) ---
675    if !is_view && methods.delete && !pk_fields.is_empty() {
676        let pk_params: Vec<TokenStream> = pk_fields
677            .iter()
678            .map(|f| {
679                let name = format_ident!("{}", f.rust_name);
680                let ty: TokenStream = f.inner_type.parse().unwrap();
681                quote! { #name: #ty }
682            })
683            .collect();
684
685        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
686        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
687        let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
688        let sql_macro = format!("DELETE FROM {} WHERE {}", table_name, where_clause_cast);
689
690        let binds: Vec<TokenStream> = pk_fields
691            .iter()
692            .map(|f| {
693                let name = format_ident!("{}", f.rust_name);
694                quote! { .bind(#name) }
695            })
696            .collect();
697
698        let method = if query_macro {
699            let pk_arg_names: Vec<TokenStream> = pk_fields
700                .iter()
701                .map(|f| {
702                    let name = format_ident!("{}", f.rust_name);
703                    quote! { #name }
704                })
705                .collect();
706            quote! {
707                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
708                    sqlx::query!(#sql_macro, #(#pk_arg_names),*)
709                        .execute(&self.pool)
710                        .await?;
711                    Ok(())
712                }
713            }
714        } else {
715            quote! {
716                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
717                    sqlx::query(#sql)
718                        #(#binds)*
719                        .execute(&self.pool)
720                        .await?;
721                    Ok(())
722                }
723            }
724        };
725        method_tokens.push(method);
726    }
727
728    let pool_vis: TokenStream = match pool_visibility {
729        PoolVisibility::Private => quote! {},
730        PoolVisibility::Pub => quote! { pub },
731        PoolVisibility::PubCrate => quote! { pub(crate) },
732    };
733
734    let tokens = quote! {
735        #(#param_structs)*
736
737        pub struct #repo_ident {
738            #pool_vis pool: #pool_type,
739        }
740
741        impl #repo_ident {
742            pub fn new(pool: #pool_type) -> Self {
743                Self { pool }
744            }
745
746            #(#method_tokens)*
747        }
748    };
749
750    (tokens, imports)
751}
752
753fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
754    match db_kind {
755        DatabaseKind::Postgres => quote! { sqlx::PgPool },
756        DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
757        DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
758    }
759}
760
761fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
762    match db_kind {
763        DatabaseKind::Postgres => format!("${}", index),
764        DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
765    }
766}
767
768fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
769    let base = placeholder(db_kind, index);
770    match (&field.sql_type, field.is_sql_array) {
771        (Some(t), true) => format!("{} as {}[]", base, t),
772        (Some(t), false) => format!("{} as {}", base, t),
773        (None, _) => base,
774    }
775}
776
777fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
778    (0..count)
779        .map(|i| placeholder(db_kind, start + i))
780        .collect::<Vec<_>>()
781        .join(", ")
782}
783
784fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
785    fields
786        .iter()
787        .enumerate()
788        .map(|(i, f)| {
789            if use_cast {
790                placeholder_with_cast(db_kind, start + i, f)
791            } else {
792                placeholder(db_kind, start + i)
793            }
794        })
795        .collect::<Vec<_>>()
796        .join(", ")
797}
798
799fn build_where_clause_parsed(
800    pk_fields: &[&ParsedField],
801    db_kind: DatabaseKind,
802    start_index: usize,
803) -> String {
804    pk_fields
805        .iter()
806        .enumerate()
807        .map(|(i, f)| {
808            let p = placeholder(db_kind, start_index + i);
809            format!("{} = {}", f.column_name, p)
810        })
811        .collect::<Vec<_>>()
812        .join(" AND ")
813}
814
815fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
816    let name = format_ident!("{}", field.rust_name);
817    let check_type = if field.is_nullable {
818        &field.inner_type
819    } else {
820        &field.rust_type
821    };
822    let normalized = check_type.replace(' ', "");
823    if normalized.starts_with("Vec<") {
824        quote! { params.#name.as_slice() }
825    } else {
826        quote! { params.#name }
827    }
828}
829
830fn build_where_clause_cast(
831    pk_fields: &[&ParsedField],
832    db_kind: DatabaseKind,
833    start_index: usize,
834) -> String {
835    pk_fields
836        .iter()
837        .enumerate()
838        .map(|(i, f)| {
839            let p = placeholder_with_cast(db_kind, start_index + i, f);
840            format!("{} = {}", f.column_name, p)
841        })
842        .collect::<Vec<_>>()
843        .join(" AND ")
844}
845
846#[allow(clippy::too_many_arguments)]
847fn build_insert_method_parsed(
848    entity_ident: &proc_macro2::Ident,
849    insert_params_ident: &proc_macro2::Ident,
850    sql: &str,
851    sql_macro: &str,
852    binds: &[TokenStream],
853    db_kind: DatabaseKind,
854    table_name: &str,
855    pk_fields: &[&ParsedField],
856    non_pk_fields: &[&ParsedField],
857    use_macro: bool,
858) -> TokenStream {
859    if use_macro {
860        let macro_args: Vec<TokenStream> = non_pk_fields
861            .iter()
862            .map(|f| macro_arg_for_field(f))
863            .collect();
864
865        match db_kind {
866            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
867                quote! {
868                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
869                        sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
870                            .fetch_one(&self.pool)
871                            .await
872                    }
873                }
874            }
875            DatabaseKind::Mysql => {
876                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
877                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
878                quote! {
879                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
880                        sqlx::query!(#sql_macro, #(#macro_args),*)
881                            .execute(&self.pool)
882                            .await?;
883                        let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
884                            .fetch_one(&self.pool)
885                            .await?;
886                        sqlx::query_as!(#entity_ident, #select_sql, id)
887                            .fetch_one(&self.pool)
888                            .await
889                    }
890                }
891            }
892        }
893    } else {
894        match db_kind {
895            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
896                quote! {
897                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
898                        sqlx::query_as::<_, #entity_ident>(#sql)
899                            #(#binds)*
900                            .fetch_one(&self.pool)
901                            .await
902                    }
903                }
904            }
905            DatabaseKind::Mysql => {
906                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
907                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
908                quote! {
909                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
910                        sqlx::query(#sql)
911                            #(#binds)*
912                            .execute(&self.pool)
913                            .await?;
914                        let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
915                            .fetch_one(&self.pool)
916                            .await?;
917                        sqlx::query_as::<_, #entity_ident>(#select_sql)
918                            .bind(id)
919                            .fetch_one(&self.pool)
920                            .await
921                    }
922                }
923            }
924        }
925    }
926}
927
928#[cfg(test)]
929mod tests {
930    use super::*;
931    use crate::codegen::parse_and_format;
932    use crate::cli::Methods;
933
934    fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
935        let inner_type = if nullable {
936            // Strip "Option<" prefix and ">" suffix
937            rust_type
938                .strip_prefix("Option<")
939                .and_then(|s| s.strip_suffix('>'))
940                .unwrap_or(rust_type)
941                .to_string()
942        } else {
943            rust_type.to_string()
944        };
945        ParsedField {
946            rust_name: rust_name.to_string(),
947            column_name: column_name.to_string(),
948            rust_type: rust_type.to_string(),
949            is_nullable: nullable,
950            inner_type,
951            is_primary_key: is_pk,
952            sql_type: None,
953            is_sql_array: false,
954            column_default: None,
955        }
956    }
957
958    fn make_field_with_default(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool, default: &str) -> ParsedField {
959        let mut f = make_field(rust_name, column_name, rust_type, nullable, is_pk);
960        f.column_default = Some(default.to_string());
961        f
962    }
963
964    fn entity_with_defaults() -> ParsedEntity {
965        ParsedEntity {
966            struct_name: "Tasks".to_string(),
967            table_name: "tasks".to_string(),
968            schema_name: None,
969            is_view: false,
970            fields: vec![
971                make_field("id", "id", "i32", false, true),
972                make_field("title", "title", "String", false, false),
973                make_field_with_default("status", "status", "String", false, false, "'idle'::task_status"),
974                make_field_with_default("priority", "priority", "i32", false, false, "0"),
975                make_field_with_default("created_at", "created_at", "DateTime<Utc>", false, false, "now()"),
976                make_field("description", "description", "Option<String>", true, false),
977                make_field_with_default("deleted_at", "deleted_at", "Option<DateTime<Utc>>", true, false, "NULL"),
978            ],
979            imports: vec![],
980        }
981    }
982
983    fn standard_entity() -> ParsedEntity {
984        ParsedEntity {
985            struct_name: "Users".to_string(),
986            table_name: "users".to_string(),
987            schema_name: None,
988            is_view: false,
989            fields: vec![
990                make_field("id", "id", "i32", false, true),
991                make_field("name", "name", "String", false, false),
992                make_field("email", "email", "Option<String>", true, false),
993            ],
994            imports: vec![],
995        }
996    }
997
998    fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
999        let skip = Methods::all();
1000        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
1001        parse_and_format(&tokens)
1002    }
1003
1004    fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
1005        let skip = Methods::all();
1006        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
1007        parse_and_format(&tokens)
1008    }
1009
1010    fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
1011        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
1012        parse_and_format(&tokens)
1013    }
1014
1015    // --- basic structure ---
1016
1017    #[test]
1018    fn test_repo_struct_name() {
1019        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1020        assert!(code.contains("pub struct UsersRepository"));
1021    }
1022
1023    #[test]
1024    fn test_repo_new_method() {
1025        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1026        assert!(code.contains("pub fn new("));
1027    }
1028
1029    #[test]
1030    fn test_repo_pool_field_pg() {
1031        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1032        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
1033    }
1034
1035    #[test]
1036    fn test_repo_pool_field_pub() {
1037        let skip = Methods::all();
1038        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Pub);
1039        let code = parse_and_format(&tokens);
1040        assert!(code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool"));
1041    }
1042
1043    #[test]
1044    fn test_repo_pool_field_pub_crate() {
1045        let skip = Methods::all();
1046        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::PubCrate);
1047        let code = parse_and_format(&tokens);
1048        assert!(code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool"));
1049    }
1050
1051    #[test]
1052    fn test_repo_pool_field_private() {
1053        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1054        // Should NOT have `pub pool` or `pub(crate) pool`
1055        assert!(!code.contains("pub pool"));
1056        assert!(!code.contains("pub(crate) pool"));
1057    }
1058
1059    #[test]
1060    fn test_repo_pool_field_mysql() {
1061        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1062        assert!(code.contains("MySqlPool") || code.contains("MySql"));
1063    }
1064
1065    #[test]
1066    fn test_repo_pool_field_sqlite() {
1067        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1068        assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
1069    }
1070
1071    // --- get_all ---
1072
1073    #[test]
1074    fn test_get_all_method() {
1075        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1076        assert!(code.contains("pub async fn get_all"));
1077    }
1078
1079    #[test]
1080    fn test_get_all_returns_vec() {
1081        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1082        assert!(code.contains("Vec<Users>"));
1083    }
1084
1085    #[test]
1086    fn test_get_all_sql() {
1087        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1088        assert!(code.contains("SELECT * FROM users"));
1089    }
1090
1091    // --- paginate ---
1092
1093    #[test]
1094    fn test_paginate_method() {
1095        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1096        assert!(code.contains("pub async fn paginate"));
1097    }
1098
1099    #[test]
1100    fn test_paginate_params_struct() {
1101        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1102        assert!(code.contains("pub struct PaginateUsersParams"));
1103    }
1104
1105    #[test]
1106    fn test_paginate_params_fields() {
1107        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1108        assert!(code.contains("pub page: i64"));
1109        assert!(code.contains("pub per_page: i64"));
1110    }
1111
1112    #[test]
1113    fn test_paginate_returns_paginated() {
1114        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1115        assert!(code.contains("PaginatedUsers"));
1116        assert!(code.contains("PaginationUsersMeta"));
1117    }
1118
1119    #[test]
1120    fn test_paginate_meta_struct() {
1121        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1122        assert!(code.contains("pub struct PaginationUsersMeta"));
1123        assert!(code.contains("pub total: i64"));
1124        assert!(code.contains("pub last_page: i64"));
1125        assert!(code.contains("pub first_page: i64"));
1126        assert!(code.contains("pub current_page: i64"));
1127    }
1128
1129    #[test]
1130    fn test_paginate_data_struct() {
1131        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1132        assert!(code.contains("pub struct PaginatedUsers"));
1133        assert!(code.contains("pub meta: PaginationUsersMeta"));
1134        assert!(code.contains("pub data: Vec<Users>"));
1135    }
1136
1137    #[test]
1138    fn test_paginate_count_sql() {
1139        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1140        assert!(code.contains("SELECT COUNT(*) FROM users"));
1141    }
1142
1143    #[test]
1144    fn test_paginate_sql_pg() {
1145        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1146        assert!(code.contains("LIMIT $1 OFFSET $2"));
1147    }
1148
1149    #[test]
1150    fn test_paginate_sql_mysql() {
1151        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1152        assert!(code.contains("LIMIT ? OFFSET ?"));
1153    }
1154
1155    // --- get ---
1156
1157    #[test]
1158    fn test_get_method() {
1159        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1160        assert!(code.contains("pub async fn get"));
1161    }
1162
1163    #[test]
1164    fn test_get_returns_option() {
1165        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1166        assert!(code.contains("Option<Users>"));
1167    }
1168
1169    #[test]
1170    fn test_get_where_pk_pg() {
1171        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1172        assert!(code.contains("WHERE id = $1"));
1173    }
1174
1175    #[test]
1176    fn test_get_where_pk_mysql() {
1177        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1178        assert!(code.contains("WHERE id = ?"));
1179    }
1180
1181    // --- insert ---
1182
1183    #[test]
1184    fn test_insert_method() {
1185        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1186        assert!(code.contains("pub async fn insert"));
1187    }
1188
1189    #[test]
1190    fn test_insert_params_struct() {
1191        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1192        assert!(code.contains("pub struct InsertUsersParams"));
1193    }
1194
1195    #[test]
1196    fn test_insert_params_no_pk() {
1197        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1198        assert!(code.contains("pub name: String"));
1199        assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
1200    }
1201
1202    #[test]
1203    fn test_insert_returning_pg() {
1204        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1205        assert!(code.contains("RETURNING *"));
1206    }
1207
1208    #[test]
1209    fn test_insert_returning_sqlite() {
1210        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1211        assert!(code.contains("RETURNING *"));
1212    }
1213
1214    #[test]
1215    fn test_insert_mysql_last_insert_id() {
1216        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1217        assert!(code.contains("LAST_INSERT_ID"));
1218    }
1219
1220    // --- insert with column_default ---
1221
1222    #[test]
1223    fn test_insert_default_col_is_optional() {
1224        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1225        // Fields with column_default and not nullable → Option<T>
1226        let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
1227        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1228        let struct_body = &code[struct_start..struct_end];
1229        assert!(struct_body.contains("Option") && struct_body.contains("status"), "Expected status as Option in InsertTasksParams: {}", struct_body);
1230    }
1231
1232    #[test]
1233    fn test_insert_non_default_col_required() {
1234        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1235        // 'title' has no default → required type (String)
1236        let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
1237        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1238        let struct_body = &code[struct_start..struct_end];
1239        assert!(struct_body.contains("title") && struct_body.contains("String"), "Expected title as String: {}", struct_body);
1240    }
1241
1242    #[test]
1243    fn test_insert_default_col_coalesce_sql() {
1244        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1245        assert!(code.contains("COALESCE($2, 'idle'::task_status)"), "Expected COALESCE for status:\n{}", code);
1246        assert!(code.contains("COALESCE($3, 0)"), "Expected COALESCE for priority:\n{}", code);
1247        assert!(code.contains("COALESCE($4, now())"), "Expected COALESCE for created_at:\n{}", code);
1248    }
1249
1250    #[test]
1251    fn test_insert_no_coalesce_for_non_default() {
1252        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1253        // title has no default, so its placeholder should be plain $1 not COALESCE
1254        assert!(code.contains("VALUES ($1, COALESCE"), "Expected $1 without COALESCE for title:\n{}", code);
1255    }
1256
1257    #[test]
1258    fn test_insert_nullable_with_default_no_double_option() {
1259        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1260        assert!(!code.contains("Option < Option") && !code.contains("Option<Option"), "Should not have Option<Option>:\n{}", code);
1261    }
1262
1263    #[test]
1264    fn test_insert_derive_default() {
1265        let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
1266        let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
1267        let before_struct = &code[..struct_start];
1268        assert!(before_struct.ends_with("Default)]\n") || before_struct.contains("Default)]"), "Expected #[derive(Default)] on InsertTasksParams");
1269    }
1270
1271    // --- update (patch with COALESCE) ---
1272
1273    #[test]
1274    fn test_update_method() {
1275        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1276        assert!(code.contains("pub async fn update"));
1277    }
1278
1279    #[test]
1280    fn test_update_params_struct() {
1281        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1282        assert!(code.contains("pub struct UpdateUsersParams"));
1283    }
1284
1285    #[test]
1286    fn test_update_pk_in_fn_signature() {
1287        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1288        // PK 'id: i32' should appear between "fn update" and "UpdateUsersParams"
1289        let update_pos = code.find("fn update").expect("fn update not found");
1290        let params_pos = code[update_pos..].find("UpdateUsersParams").expect("UpdateUsersParams not found in update fn");
1291        let signature = &code[update_pos..update_pos + params_pos];
1292        assert!(signature.contains("id"), "Expected 'id' PK in update fn signature: {}", signature);
1293    }
1294
1295    #[test]
1296    fn test_update_pk_not_in_struct() {
1297        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1298        // UpdateUsersParams should NOT contain id field
1299        // Extract the struct definition and check it doesn't have id
1300        let struct_start = code.find("pub struct UpdateUsersParams").expect("UpdateUsersParams not found");
1301        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1302        let struct_body = &code[struct_start..struct_end];
1303        assert!(!struct_body.contains("pub id"), "PK 'id' should not be in UpdateUsersParams:\n{}", struct_body);
1304    }
1305
1306    #[test]
1307    fn test_update_params_non_nullable_wrapped_in_option() {
1308        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1309        // `name: String` becomes `name: Option<String>` in patch params
1310        assert!(code.contains("pub name: Option<String>") || code.contains("pub name : Option < String >"));
1311    }
1312
1313    #[test]
1314    fn test_update_params_already_nullable_no_double_option() {
1315        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1316        // `email: Option<String>` stays `Option<String>`, NOT `Option<Option<String>>`
1317        assert!(!code.contains("Option<Option") && !code.contains("Option < Option"));
1318    }
1319
1320    #[test]
1321    fn test_update_set_clause_uses_coalesce_pg() {
1322        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1323        assert!(code.contains("COALESCE($1, name)"), "Expected COALESCE for name:\n{}", code);
1324        assert!(code.contains("COALESCE($2, email)"), "Expected COALESCE for email:\n{}", code);
1325    }
1326
1327    #[test]
1328    fn test_update_where_clause_pg() {
1329        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1330        assert!(code.contains("WHERE id = $3"));
1331    }
1332
1333    #[test]
1334    fn test_update_returning_pg() {
1335        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1336        assert!(code.contains("COALESCE"));
1337        assert!(code.contains("RETURNING *"));
1338    }
1339
1340    #[test]
1341    fn test_update_set_clause_mysql() {
1342        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1343        assert!(code.contains("COALESCE(?, name)"), "Expected COALESCE for MySQL:\n{}", code);
1344        assert!(code.contains("COALESCE(?, email)"), "Expected COALESCE for email in MySQL:\n{}", code);
1345    }
1346
1347    #[test]
1348    fn test_update_set_clause_sqlite() {
1349        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1350        assert!(code.contains("COALESCE(?, name)"), "Expected COALESCE for SQLite:\n{}", code);
1351    }
1352
1353    // --- overwrite (full replacement, PK as fn param) ---
1354
1355    #[test]
1356    fn test_overwrite_method() {
1357        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1358        assert!(code.contains("pub async fn overwrite"));
1359    }
1360
1361    #[test]
1362    fn test_overwrite_params_struct() {
1363        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1364        assert!(code.contains("pub struct OverwriteUsersParams"));
1365    }
1366
1367    #[test]
1368    fn test_overwrite_pk_in_fn_signature() {
1369        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1370        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1371        let params_pos = code[pos..].find("OverwriteUsersParams").expect("OverwriteUsersParams not found");
1372        let signature = &code[pos..pos + params_pos];
1373        assert!(signature.contains("id"), "Expected PK in overwrite fn signature: {}", signature);
1374    }
1375
1376    #[test]
1377    fn test_overwrite_pk_not_in_struct() {
1378        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1379        let struct_start = code.find("pub struct OverwriteUsersParams").expect("OverwriteUsersParams not found");
1380        let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
1381        let struct_body = &code[struct_start..struct_end];
1382        assert!(!struct_body.contains("pub id"), "PK should not be in OverwriteUsersParams: {}", struct_body);
1383    }
1384
1385    #[test]
1386    fn test_overwrite_no_coalesce() {
1387        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1388        // Find the overwrite SQL — should have direct SET, no COALESCE
1389        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1390        let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1391        assert!(!method_body.contains("COALESCE"), "Overwrite should not use COALESCE: {}", method_body);
1392    }
1393
1394    #[test]
1395    fn test_overwrite_set_clause_pg() {
1396        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1397        assert!(code.contains("SET name = $1, email = $2 WHERE id = $3"));
1398    }
1399
1400    #[test]
1401    fn test_overwrite_returning_pg() {
1402        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1403        let pos = code.find("fn overwrite").expect("fn overwrite not found");
1404        let method_body = &code[pos..pos + 500.min(code.len() - pos)];
1405        assert!(method_body.contains("RETURNING *"), "Expected RETURNING * in overwrite");
1406    }
1407
1408    #[test]
1409    fn test_view_no_overwrite() {
1410        let mut entity = standard_entity();
1411        entity.is_view = true;
1412        let code = gen(&entity, DatabaseKind::Postgres);
1413        assert!(!code.contains("pub async fn overwrite"));
1414    }
1415
1416    #[test]
1417    fn test_without_overwrite() {
1418        let m = Methods { overwrite: false, ..Methods::all() };
1419        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1420        assert!(!code.contains("pub async fn overwrite"));
1421        assert!(!code.contains("OverwriteUsersParams"));
1422    }
1423
1424    #[test]
1425    fn test_update_and_overwrite_coexist() {
1426        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1427        assert!(code.contains("pub async fn update"), "Expected update method");
1428        assert!(code.contains("pub async fn overwrite"), "Expected overwrite method");
1429        assert!(code.contains("UpdateUsersParams"), "Expected UpdateUsersParams");
1430        assert!(code.contains("OverwriteUsersParams"), "Expected OverwriteUsersParams");
1431    }
1432
1433    // --- delete ---
1434
1435    #[test]
1436    fn test_delete_method() {
1437        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1438        assert!(code.contains("pub async fn delete"));
1439    }
1440
1441    #[test]
1442    fn test_delete_where_pk() {
1443        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1444        assert!(code.contains("DELETE FROM users WHERE id = $1"));
1445    }
1446
1447    #[test]
1448    fn test_delete_returns_unit() {
1449        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1450        assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
1451    }
1452
1453    // --- views (read-only) ---
1454
1455    #[test]
1456    fn test_view_no_insert() {
1457        let mut entity = standard_entity();
1458        entity.is_view = true;
1459        let code = gen(&entity, DatabaseKind::Postgres);
1460        assert!(!code.contains("pub async fn insert"));
1461    }
1462
1463    #[test]
1464    fn test_view_no_update() {
1465        let mut entity = standard_entity();
1466        entity.is_view = true;
1467        let code = gen(&entity, DatabaseKind::Postgres);
1468        assert!(!code.contains("pub async fn update"));
1469    }
1470
1471    #[test]
1472    fn test_view_no_delete() {
1473        let mut entity = standard_entity();
1474        entity.is_view = true;
1475        let code = gen(&entity, DatabaseKind::Postgres);
1476        assert!(!code.contains("pub async fn delete"));
1477    }
1478
1479    #[test]
1480    fn test_view_has_get_all() {
1481        let mut entity = standard_entity();
1482        entity.is_view = true;
1483        let code = gen(&entity, DatabaseKind::Postgres);
1484        assert!(code.contains("pub async fn get_all"));
1485    }
1486
1487    #[test]
1488    fn test_view_has_paginate() {
1489        let mut entity = standard_entity();
1490        entity.is_view = true;
1491        let code = gen(&entity, DatabaseKind::Postgres);
1492        assert!(code.contains("pub async fn paginate"));
1493    }
1494
1495    #[test]
1496    fn test_view_has_get() {
1497        let mut entity = standard_entity();
1498        entity.is_view = true;
1499        let code = gen(&entity, DatabaseKind::Postgres);
1500        assert!(code.contains("pub async fn get"));
1501    }
1502
1503    // --- selective methods ---
1504
1505    #[test]
1506    fn test_only_get_all() {
1507        let m = Methods { get_all: true, ..Default::default() };
1508        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1509        assert!(code.contains("pub async fn get_all"));
1510        assert!(!code.contains("pub async fn paginate"));
1511        assert!(!code.contains("pub async fn insert"));
1512    }
1513
1514    #[test]
1515    fn test_without_get_all() {
1516        let m = Methods { get_all: false, ..Methods::all() };
1517        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1518        assert!(!code.contains("pub async fn get_all"));
1519    }
1520
1521    #[test]
1522    fn test_without_paginate() {
1523        let m = Methods { paginate: false, ..Methods::all() };
1524        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1525        assert!(!code.contains("pub async fn paginate"));
1526        assert!(!code.contains("PaginateUsersParams"));
1527    }
1528
1529    #[test]
1530    fn test_without_get() {
1531        let m = Methods { get: false, ..Methods::all() };
1532        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1533        assert!(code.contains("pub async fn get_all"));
1534        let without_get_all = code.replace("get_all", "XXX");
1535        assert!(!without_get_all.contains("fn get("));
1536    }
1537
1538    #[test]
1539    fn test_without_insert() {
1540        let m = Methods { insert: false, ..Methods::all() };
1541        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1542        assert!(!code.contains("pub async fn insert"));
1543        assert!(!code.contains("InsertUsersParams"));
1544    }
1545
1546    #[test]
1547    fn test_without_update() {
1548        let m = Methods { update: false, ..Methods::all() };
1549        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1550        assert!(!code.contains("pub async fn update"));
1551        assert!(!code.contains("UpdateUsersParams"));
1552    }
1553
1554    #[test]
1555    fn test_without_delete() {
1556        let m = Methods { delete: false, ..Methods::all() };
1557        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1558        assert!(!code.contains("pub async fn delete"));
1559    }
1560
1561    #[test]
1562    fn test_empty_methods_no_methods() {
1563        let m = Methods::default();
1564        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1565        assert!(!code.contains("pub async fn get_all"));
1566        assert!(!code.contains("pub async fn paginate"));
1567        assert!(!code.contains("pub async fn insert"));
1568        assert!(!code.contains("pub async fn update"));
1569        assert!(!code.contains("pub async fn overwrite"));
1570        assert!(!code.contains("pub async fn delete"));
1571    }
1572
1573    // --- imports ---
1574
1575    #[test]
1576    fn test_no_pool_import() {
1577        let skip = Methods::all();
1578        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1579        assert!(!imports.iter().any(|i| i.contains("PgPool")));
1580    }
1581
1582    #[test]
1583    fn test_imports_contain_entity() {
1584        let skip = Methods::all();
1585        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1586        assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
1587    }
1588
1589    // --- renamed columns ---
1590
1591    #[test]
1592    fn test_renamed_column_in_sql() {
1593        let entity = ParsedEntity {
1594            struct_name: "Connector".to_string(),
1595            table_name: "connector".to_string(),
1596            schema_name: None,
1597            is_view: false,
1598            fields: vec![
1599                make_field("id", "id", "i32", false, true),
1600                make_field("connector_type", "type", "String", false, false),
1601            ],
1602            imports: vec![],
1603        };
1604        let code = gen(&entity, DatabaseKind::Postgres);
1605        // INSERT should use the DB column name "type", not "connector_type"
1606        assert!(code.contains("type"));
1607        // The Rust param field should be connector_type
1608        assert!(code.contains("pub connector_type: String"));
1609    }
1610
1611    // --- no PK edge cases ---
1612
1613    #[test]
1614    fn test_no_pk_no_get() {
1615        let entity = ParsedEntity {
1616            struct_name: "Logs".to_string(),
1617            table_name: "logs".to_string(),
1618            schema_name: None,
1619            is_view: false,
1620            fields: vec![
1621                make_field("message", "message", "String", false, false),
1622                make_field("ts", "ts", "String", false, false),
1623            ],
1624            imports: vec![],
1625        };
1626        let code = gen(&entity, DatabaseKind::Postgres);
1627        assert!(code.contains("pub async fn get_all"));
1628        let without_get_all = code.replace("get_all", "XXX");
1629        assert!(!without_get_all.contains("fn get("));
1630    }
1631
1632    #[test]
1633    fn test_no_pk_no_delete() {
1634        let entity = ParsedEntity {
1635            struct_name: "Logs".to_string(),
1636            table_name: "logs".to_string(),
1637            schema_name: None,
1638            is_view: false,
1639            fields: vec![
1640                make_field("message", "message", "String", false, false),
1641            ],
1642            imports: vec![],
1643        };
1644        let code = gen(&entity, DatabaseKind::Postgres);
1645        assert!(!code.contains("pub async fn delete"));
1646    }
1647
1648    // --- Default derive on param structs ---
1649
1650    #[test]
1651    fn test_param_structs_have_default() {
1652        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1653        assert!(code.contains("Default"));
1654    }
1655
1656    // --- entity imports forwarded ---
1657
1658    #[test]
1659    fn test_entity_imports_forwarded() {
1660        let entity = ParsedEntity {
1661            struct_name: "Users".to_string(),
1662            table_name: "users".to_string(),
1663            schema_name: None,
1664            is_view: false,
1665            fields: vec![
1666                make_field("id", "id", "Uuid", false, true),
1667                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
1668            ],
1669            imports: vec![
1670                "use chrono::{DateTime, Utc};".to_string(),
1671                "use uuid::Uuid;".to_string(),
1672            ],
1673        };
1674        let skip = Methods::all();
1675        let (_, imports) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1676        assert!(imports.iter().any(|i| i.contains("chrono")));
1677        assert!(imports.iter().any(|i| i.contains("uuid")));
1678    }
1679
1680    #[test]
1681    fn test_entity_imports_empty_when_no_imports() {
1682        let skip = Methods::all();
1683        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1684        // Should only have pool + entity imports, no chrono/uuid
1685        assert!(!imports.iter().any(|i| i.contains("chrono")));
1686        assert!(!imports.iter().any(|i| i.contains("uuid")));
1687    }
1688
1689    // --- query_macro mode ---
1690
1691    #[test]
1692    fn test_macro_get_all() {
1693        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1694        assert!(code.contains("query_as!"));
1695        assert!(!code.contains("query_as::<"));
1696    }
1697
1698    #[test]
1699    fn test_macro_paginate() {
1700        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1701        assert!(code.contains("query_as!"));
1702        assert!(code.contains("per_page, offset"));
1703    }
1704
1705    #[test]
1706    fn test_macro_get() {
1707        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1708        // The get method should use query_as! with the PK as arg
1709        assert!(code.contains("query_as!(Users"));
1710    }
1711
1712    #[test]
1713    fn test_macro_insert_pg() {
1714        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1715        assert!(code.contains("query_as!(Users"));
1716        assert!(code.contains("params.name"));
1717        assert!(code.contains("params.email"));
1718    }
1719
1720    #[test]
1721    fn test_macro_insert_mysql() {
1722        let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
1723        // MySQL insert uses query! (not query_as!) for the INSERT
1724        assert!(code.contains("query!"));
1725        assert!(code.contains("query_scalar!"));
1726    }
1727
1728    #[test]
1729    fn test_macro_update() {
1730        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1731        assert!(code.contains("query_as!(Users"));
1732        assert!(code.contains("COALESCE"), "Expected COALESCE in macro update:\n{}", code);
1733        assert!(code.contains("pub async fn update"));
1734        assert!(code.contains("UpdateUsersParams"));
1735    }
1736
1737    #[test]
1738    fn test_macro_delete() {
1739        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1740        // delete uses query! (no return type)
1741        assert!(code.contains("query!"));
1742    }
1743
1744    #[test]
1745    fn test_macro_no_bind_calls() {
1746        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1747        assert!(!code.contains(".bind("));
1748    }
1749
1750    #[test]
1751    fn test_function_style_uses_bind() {
1752        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1753        assert!(code.contains(".bind("));
1754        assert!(!code.contains("query_as!("));
1755        assert!(!code.contains("query!("));
1756    }
1757
1758    // --- custom sql_type fallback: macro mode + custom type → runtime for SELECT, macro for DELETE ---
1759
1760    fn entity_with_sql_array() -> ParsedEntity {
1761        ParsedEntity {
1762            struct_name: "AgentConnector".to_string(),
1763            table_name: "agent.agent_connector".to_string(),
1764            schema_name: Some("agent".to_string()),
1765            is_view: false,
1766            fields: vec![
1767                ParsedField {
1768                    rust_name: "connector_id".to_string(),
1769                    column_name: "connector_id".to_string(),
1770                    rust_type: "Uuid".to_string(),
1771                    inner_type: "Uuid".to_string(),
1772                    is_nullable: false,
1773                    is_primary_key: true,
1774                    sql_type: None,
1775                    is_sql_array: false,
1776                    column_default: None,
1777                },
1778                ParsedField {
1779                    rust_name: "agent_id".to_string(),
1780                    column_name: "agent_id".to_string(),
1781                    rust_type: "Uuid".to_string(),
1782                    inner_type: "Uuid".to_string(),
1783                    is_nullable: false,
1784                    is_primary_key: false,
1785                    sql_type: None,
1786                    is_sql_array: false,
1787                    column_default: None,
1788                },
1789                ParsedField {
1790                    rust_name: "usages".to_string(),
1791                    column_name: "usages".to_string(),
1792                    rust_type: "Vec<ConnectorUsages>".to_string(),
1793                    inner_type: "Vec<ConnectorUsages>".to_string(),
1794                    is_nullable: false,
1795                    is_primary_key: false,
1796                    sql_type: Some("agent.connector_usages".to_string()),
1797                    is_sql_array: true,
1798                    column_default: None,
1799                },
1800            ],
1801            imports: vec!["use uuid::Uuid;".to_string()],
1802        }
1803    }
1804
1805    fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
1806        let skip = Methods::all();
1807        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
1808        parse_and_format(&tokens)
1809    }
1810
1811    #[test]
1812    fn test_sql_array_macro_get_all_uses_runtime() {
1813        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1814        // get_all should use runtime query_as, not macro
1815        assert!(code.contains("query_as::<"));
1816    }
1817
1818    #[test]
1819    fn test_sql_array_macro_get_uses_runtime() {
1820        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1821        // get should use .bind( since it's runtime
1822        assert!(code.contains(".bind("));
1823    }
1824
1825    #[test]
1826    fn test_sql_array_macro_insert_uses_runtime() {
1827        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1828        // insert RETURNING should use runtime query_as
1829        assert!(code.contains("query_as::<_ , AgentConnector>") || code.contains("query_as::<_, AgentConnector>"));
1830    }
1831
1832
1833    #[test]
1834    fn test_sql_array_macro_delete_still_uses_macro() {
1835        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1836        // delete uses query! macro (no rows returned, no array issue)
1837        assert!(code.contains("query!"));
1838    }
1839
1840    #[test]
1841    fn test_sql_array_no_query_as_macro() {
1842        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1843        // Should NOT contain query_as! macro (only query_as::<_ for runtime)
1844        assert!(!code.contains("query_as!("));
1845    }
1846
1847    // --- custom enum (non-array) also triggers runtime fallback ---
1848
1849    fn entity_with_sql_enum() -> ParsedEntity {
1850        ParsedEntity {
1851            struct_name: "Task".to_string(),
1852            table_name: "tasks".to_string(),
1853            schema_name: None,
1854            is_view: false,
1855            fields: vec![
1856                ParsedField {
1857                    rust_name: "id".to_string(),
1858                    column_name: "id".to_string(),
1859                    rust_type: "i32".to_string(),
1860                    inner_type: "i32".to_string(),
1861                    is_nullable: false,
1862                    is_primary_key: true,
1863                    sql_type: None,
1864                    is_sql_array: false,
1865                    column_default: None,
1866                },
1867                ParsedField {
1868                    rust_name: "status".to_string(),
1869                    column_name: "status".to_string(),
1870                    rust_type: "TaskStatus".to_string(),
1871                    inner_type: "TaskStatus".to_string(),
1872                    is_nullable: false,
1873                    is_primary_key: false,
1874                    sql_type: Some("task_status".to_string()),
1875                    is_sql_array: false,
1876                    column_default: None,
1877                },
1878            ],
1879            imports: vec![],
1880        }
1881    }
1882
1883    #[test]
1884    fn test_sql_enum_macro_uses_runtime() {
1885        let skip = Methods::all();
1886        let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1887        let code = parse_and_format(&tokens);
1888        // SELECT queries should use runtime query_as, not macro
1889        assert!(code.contains("query_as::<"));
1890        assert!(!code.contains("query_as!("));
1891    }
1892
1893    #[test]
1894    fn test_sql_enum_macro_delete_still_uses_macro() {
1895        let skip = Methods::all();
1896        let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1897        let code = parse_and_format(&tokens);
1898        // DELETE still uses query! macro
1899        assert!(code.contains("query!"));
1900    }
1901
1902    // --- Vec<String> native array uses .as_slice() in macro mode ---
1903
1904    fn entity_with_vec_string() -> ParsedEntity {
1905        ParsedEntity {
1906            struct_name: "PromptHistory".to_string(),
1907            table_name: "prompt_history".to_string(),
1908            schema_name: None,
1909            is_view: false,
1910            fields: vec![
1911                ParsedField {
1912                    rust_name: "id".to_string(),
1913                    column_name: "id".to_string(),
1914                    rust_type: "Uuid".to_string(),
1915                    inner_type: "Uuid".to_string(),
1916                    is_nullable: false,
1917                    is_primary_key: true,
1918                    sql_type: None,
1919                    is_sql_array: false,
1920                    column_default: None,
1921                },
1922                ParsedField {
1923                    rust_name: "content".to_string(),
1924                    column_name: "content".to_string(),
1925                    rust_type: "String".to_string(),
1926                    inner_type: "String".to_string(),
1927                    is_nullable: false,
1928                    is_primary_key: false,
1929                    sql_type: None,
1930                    is_sql_array: false,
1931                    column_default: None,
1932                },
1933                ParsedField {
1934                    rust_name: "tags".to_string(),
1935                    column_name: "tags".to_string(),
1936                    rust_type: "Vec<String>".to_string(),
1937                    inner_type: "Vec<String>".to_string(),
1938                    is_nullable: false,
1939                    is_primary_key: false,
1940                    sql_type: None,
1941                    is_sql_array: false,
1942                    column_default: None,
1943                },
1944            ],
1945            imports: vec!["use uuid::Uuid;".to_string()],
1946        }
1947    }
1948
1949    #[test]
1950    fn test_vec_string_macro_insert_uses_as_slice() {
1951        let skip = Methods::all();
1952        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1953        let code = parse_and_format(&tokens);
1954        assert!(code.contains("as_slice()"));
1955    }
1956
1957    #[test]
1958    fn test_vec_string_macro_update_uses_as_slice() {
1959        let skip = Methods::all();
1960        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1961        let code = parse_and_format(&tokens);
1962        // Should have as_slice() for insert and update
1963        let count = code.matches("as_slice()").count();
1964        assert!(count >= 2, "expected at least 2 as_slice() calls (insert + update), found {}", count);
1965    }
1966
1967    #[test]
1968    fn test_vec_string_non_macro_no_as_slice() {
1969        let skip = Methods::all();
1970        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false, PoolVisibility::Private);
1971        let code = parse_and_format(&tokens);
1972        // Runtime mode uses .bind() so no as_slice needed
1973        assert!(!code.contains("as_slice()"));
1974    }
1975
1976    #[test]
1977    fn test_vec_string_parsed_from_source_uses_as_slice() {
1978        use crate::codegen::entity_parser::parse_entity_source;
1979        let source = r#"
1980            use uuid::Uuid;
1981
1982            #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
1983            #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
1984            pub struct PromptHistory {
1985                #[sqlx_gen(primary_key)]
1986                pub id: Uuid,
1987                pub content: String,
1988                pub tags: Vec<String>,
1989            }
1990        "#;
1991        let entity = parse_entity_source(source).unwrap();
1992        let skip = Methods::all();
1993        let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1994        let code = parse_and_format(&tokens);
1995        assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
1996    }
1997
1998    // --- composite PK only (junction table) ---
1999
2000    fn junction_entity() -> ParsedEntity {
2001        ParsedEntity {
2002            struct_name: "AnalysisRecord".to_string(),
2003            table_name: "analysis.analysis__record".to_string(),
2004            schema_name: None,
2005            is_view: false,
2006            fields: vec![
2007                make_field("record_id", "record_id", "uuid::Uuid", false, true),
2008                make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
2009            ],
2010            imports: vec![],
2011        }
2012    }
2013
2014    #[test]
2015    fn test_composite_pk_only_insert_generated() {
2016        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2017        assert!(code.contains("pub struct InsertAnalysisRecordParams"), "Expected InsertAnalysisRecordParams struct:\n{}", code);
2018        assert!(code.contains("pub record_id"), "Expected record_id field in insert params:\n{}", code);
2019        assert!(code.contains("pub analysis_id"), "Expected analysis_id field in insert params:\n{}", code);
2020        assert!(code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id) VALUES ($1, $2) RETURNING *"), "Expected valid INSERT SQL:\n{}", code);
2021        assert!(code.contains("pub async fn insert"), "Expected insert method:\n{}", code);
2022    }
2023
2024    #[test]
2025    fn test_composite_pk_only_no_update() {
2026        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2027        assert!(!code.contains("UpdateAnalysisRecordParams"), "Expected no UpdateAnalysisRecordParams struct:\n{}", code);
2028        assert!(!code.contains("pub async fn update"), "Expected no update method:\n{}", code);
2029    }
2030
2031
2032    #[test]
2033    fn test_composite_pk_only_delete_generated() {
2034        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2035        assert!(code.contains("pub async fn delete"), "Expected delete method:\n{}", code);
2036        assert!(code.contains("DELETE FROM analysis.analysis__record WHERE record_id = $1 AND analysis_id = $2"), "Expected valid DELETE SQL:\n{}", code);
2037    }
2038
2039    #[test]
2040    fn test_composite_pk_only_get_generated() {
2041        let code = gen(&junction_entity(), DatabaseKind::Postgres);
2042        assert!(code.contains("pub async fn get"), "Expected get method:\n{}", code);
2043        assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause with both PK columns:\n{}", code);
2044    }
2045}