Skip to main content

sqlx_gen/codegen/
crud_gen.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10    entity: &ParsedEntity,
11    db_kind: DatabaseKind,
12    entity_module_path: &str,
13    methods: &Methods,
14    query_macro: bool,
15    pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17    let mut imports = BTreeSet::new();
18
19    let entity_ident = format_ident!("{}", entity.struct_name);
20    let repo_name = format!("{}Repository", entity.struct_name);
21    let repo_ident = format_ident!("{}", repo_name);
22
23    let table_name = match &entity.schema_name {
24        Some(schema) => format!("{}.{}", schema, entity.table_name),
25        None => entity.table_name.clone(),
26    };
27
28    // Pool type (used via full path sqlx::PgPool etc., no import needed)
29    let pool_type = pool_type_tokens(db_kind);
30
31    // When the entity has custom SQL types (enums, composites, arrays),
32    // query_as! macro can't resolve the column type at compile time. Fall back to runtime query_as::<_, T>()
33    // for queries that return rows. DELETE (no rows returned) can still use macro.
34    let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35    let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37    // Entity import
38    imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40    // Forward type imports from the entity file (chrono, uuid, etc.)
41    // Rewrite `use super::X` imports to absolute paths based on entity_module_path,
42    // since the repository lives in a different module where `super` has a different meaning.
43    let entity_parent = entity_module_path
44        .rsplit_once("::")
45        .map(|(parent, _)| parent)
46        .unwrap_or(entity_module_path);
47    for imp in &entity.imports {
48        if let Some(rest) = imp.strip_prefix("use super::") {
49            imports.insert(format!("use {}::{}", entity_parent, rest));
50        } else {
51            imports.insert(imp.clone());
52        }
53    }
54
55    // Primary key fields
56    let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58    // Non-PK fields (for insert)
59    let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61    let is_view = entity.is_view;
62
63    // Build method tokens
64    let mut method_tokens = Vec::new();
65    let mut param_structs = Vec::new();
66
67    // --- get_all ---
68    if methods.get_all {
69        let sql = format!("SELECT * FROM {}", table_name);
70        let method = if use_macro {
71            quote! {
72                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73                    sqlx::query_as!(#entity_ident, #sql)
74                        .fetch_all(&self.pool)
75                        .await
76                }
77            }
78        } else {
79            quote! {
80                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81                    sqlx::query_as::<_, #entity_ident>(#sql)
82                        .fetch_all(&self.pool)
83                        .await
84                }
85            }
86        };
87        method_tokens.push(method);
88    }
89
90    // --- paginate ---
91    if methods.paginate {
92        let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93        let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94        let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95        let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
96        let sql = match db_kind {
97            DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
98            DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
99        };
100        let method = if use_macro {
101            quote! {
102                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103                    let total: i64 = sqlx::query_scalar!(#count_sql)
104                        .fetch_one(&self.pool)
105                        .await?
106                        .unwrap_or(0);
107                    let per_page = params.per_page;
108                    let current_page = params.page;
109                    let last_page = (total + per_page - 1) / per_page;
110                    let offset = (current_page - 1) * per_page;
111                    let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112                        .fetch_all(&self.pool)
113                        .await?;
114                    Ok(#paginated_ident {
115                        meta: #pagination_meta_ident {
116                            total,
117                            per_page,
118                            current_page,
119                            last_page,
120                            first_page: 1,
121                        },
122                        data,
123                    })
124                }
125            }
126        } else {
127            quote! {
128                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129                    let total: i64 = sqlx::query_scalar(#count_sql)
130                        .fetch_one(&self.pool)
131                        .await?;
132                    let per_page = params.per_page;
133                    let current_page = params.page;
134                    let last_page = (total + per_page - 1) / per_page;
135                    let offset = (current_page - 1) * per_page;
136                    let data = sqlx::query_as::<_, #entity_ident>(#sql)
137                        .bind(per_page)
138                        .bind(offset)
139                        .fetch_all(&self.pool)
140                        .await?;
141                    Ok(#paginated_ident {
142                        meta: #pagination_meta_ident {
143                            total,
144                            per_page,
145                            current_page,
146                            last_page,
147                            first_page: 1,
148                        },
149                        data,
150                    })
151                }
152            }
153        };
154        method_tokens.push(method);
155        param_structs.push(quote! {
156            #[derive(Debug, Clone, Default)]
157            pub struct #paginate_params_ident {
158                pub page: i64,
159                pub per_page: i64,
160            }
161        });
162        param_structs.push(quote! {
163            #[derive(Debug, Clone)]
164            pub struct #pagination_meta_ident {
165                pub total: i64,
166                pub per_page: i64,
167                pub current_page: i64,
168                pub last_page: i64,
169                pub first_page: i64,
170            }
171        });
172        param_structs.push(quote! {
173            #[derive(Debug, Clone)]
174            pub struct #paginated_ident {
175                pub meta: #pagination_meta_ident,
176                pub data: Vec<#entity_ident>,
177            }
178        });
179    }
180
181    // --- get (by PK) ---
182    if methods.get && !pk_fields.is_empty() {
183        let pk_params: Vec<TokenStream> = pk_fields
184            .iter()
185            .map(|f| {
186                let name = format_ident!("{}", f.rust_name);
187                let ty: TokenStream = f.inner_type.parse().unwrap();
188                quote! { #name: #ty }
189            })
190            .collect();
191
192        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194        let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
195        let sql_macro = format!("SELECT * FROM {} WHERE {}", table_name, where_clause_cast);
196
197        let binds: Vec<TokenStream> = pk_fields
198            .iter()
199            .map(|f| {
200                let name = format_ident!("{}", f.rust_name);
201                quote! { .bind(#name) }
202            })
203            .collect();
204
205        let method = if use_macro {
206            let pk_arg_names: Vec<TokenStream> = pk_fields
207                .iter()
208                .map(|f| {
209                    let name = format_ident!("{}", f.rust_name);
210                    quote! { #name }
211                })
212                .collect();
213            quote! {
214                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215                    sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216                        .fetch_optional(&self.pool)
217                        .await
218                }
219            }
220        } else {
221            quote! {
222                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223                    sqlx::query_as::<_, #entity_ident>(#sql)
224                        #(#binds)*
225                        .fetch_optional(&self.pool)
226                        .await
227                }
228            }
229        };
230        method_tokens.push(method);
231    }
232
233    // --- insert (skip for views) ---
234    if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
235        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237        // When all columns are PKs (e.g. junction tables), use pk_fields for insert
238        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
239            pk_fields.clone()
240        } else {
241            non_pk_fields.clone()
242        };
243
244        let insert_fields: Vec<TokenStream> = insert_source_fields
245            .iter()
246            .map(|f| {
247                let name = format_ident!("{}", f.rust_name);
248                let ty: TokenStream = f.rust_type.parse().unwrap();
249                quote! { pub #name: #ty, }
250            })
251            .collect();
252
253        let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
254        let col_list = col_names.join(", ");
255        // Use casted placeholders for macro mode, plain for runtime
256        let placeholders = build_placeholders(insert_source_fields.len(), db_kind, 1);
257        let placeholders_cast = build_placeholders_with_cast(&insert_source_fields, db_kind, 1, true);
258
259        let build_insert_sql = |ph: &str| match db_kind {
260            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
261                format!(
262                    "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
263                    table_name, col_list, ph
264                )
265            }
266            DatabaseKind::Mysql => {
267                format!(
268                    "INSERT INTO {} ({}) VALUES ({})",
269                    table_name, col_list, ph
270                )
271            }
272        };
273        let sql = build_insert_sql(&placeholders);
274        let sql_macro = build_insert_sql(&placeholders_cast);
275
276        let binds: Vec<TokenStream> = insert_source_fields
277            .iter()
278            .map(|f| {
279                let name = format_ident!("{}", f.rust_name);
280                quote! { .bind(&params.#name) }
281            })
282            .collect();
283
284        let insert_method = build_insert_method_parsed(
285            &entity_ident,
286            &insert_params_ident,
287            &sql,
288            &sql_macro,
289            &binds,
290            db_kind,
291            &table_name,
292            &pk_fields,
293            &insert_source_fields,
294            use_macro,
295        );
296        method_tokens.push(insert_method);
297
298        param_structs.push(quote! {
299            #[derive(Debug, Clone, Default)]
300            pub struct #insert_params_ident {
301                #(#insert_fields)*
302            }
303        });
304    }
305
306    // --- overwrite (full replacement — skip for views, skip when all columns are PKs) ---
307    if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
308        let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
309
310        let overwrite_fields: Vec<TokenStream> = entity
311            .fields
312            .iter()
313            .map(|f| {
314                let name = format_ident!("{}", f.rust_name);
315                let ty: TokenStream = f.rust_type.parse().unwrap();
316                quote! { pub #name: #ty, }
317            })
318            .collect();
319
320        let set_cols: Vec<String> = non_pk_fields
321            .iter()
322            .enumerate()
323            .map(|(i, f)| {
324                let p = placeholder(db_kind, i + 1);
325                format!("{} = {}", f.column_name, p)
326            })
327            .collect();
328        let set_clause = set_cols.join(", ");
329
330        // SET clause with casts for macro mode
331        let set_cols_cast: Vec<String> = non_pk_fields
332            .iter()
333            .enumerate()
334            .map(|(i, f)| {
335                let p = placeholder_with_cast(db_kind, i + 1, f);
336                format!("{} = {}", f.column_name, p)
337            })
338            .collect();
339        let set_clause_cast = set_cols_cast.join(", ");
340
341        let pk_start = non_pk_fields.len() + 1;
342        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
343        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
344
345        let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
346            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
347                format!(
348                    "UPDATE {} SET {} WHERE {} RETURNING *",
349                    table_name, sc, wc
350                )
351            }
352            DatabaseKind::Mysql => {
353                format!(
354                    "UPDATE {} SET {} WHERE {}",
355                    table_name, sc, wc
356                )
357            }
358        };
359        let sql = build_overwrite_sql(&set_clause, &where_clause);
360        let sql_macro = build_overwrite_sql(&set_clause_cast, &where_clause_cast);
361
362        // Bind non-PK first, then PK
363        let mut all_binds: Vec<TokenStream> = non_pk_fields
364            .iter()
365            .map(|f| {
366                let name = format_ident!("{}", f.rust_name);
367                quote! { .bind(&params.#name) }
368            })
369            .collect();
370        for f in &pk_fields {
371            let name = format_ident!("{}", f.rust_name);
372            all_binds.push(quote! { .bind(&params.#name) });
373        }
374
375        // Macro args: non-PK fields first, then PK fields
376        let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
377            .iter()
378            .chain(pk_fields.iter())
379            .map(|f| macro_arg_for_field(f))
380            .collect();
381
382        let overwrite_method = if use_macro {
383            match db_kind {
384                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
385                    quote! {
386                        pub async fn overwrite(&self, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
387                            sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
388                                .fetch_one(&self.pool)
389                                .await
390                        }
391                    }
392                }
393                DatabaseKind::Mysql => {
394                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
395                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
396                    let pk_macro_args: Vec<TokenStream> = pk_fields
397                        .iter()
398                        .map(|f| {
399                            let name = format_ident!("{}", f.rust_name);
400                            quote! { params.#name }
401                        })
402                        .collect();
403                    quote! {
404                        pub async fn overwrite(&self, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
405                            sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
406                                .execute(&self.pool)
407                                .await?;
408                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
409                                .fetch_one(&self.pool)
410                                .await
411                        }
412                    }
413                }
414            }
415        } else {
416            match db_kind {
417                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
418                    quote! {
419                        pub async fn overwrite(&self, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
420                            sqlx::query_as::<_, #entity_ident>(#sql)
421                                #(#all_binds)*
422                                .fetch_one(&self.pool)
423                                .await
424                        }
425                    }
426                }
427                DatabaseKind::Mysql => {
428                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
429                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
430                    let pk_binds: Vec<TokenStream> = pk_fields
431                        .iter()
432                        .map(|f| {
433                            let name = format_ident!("{}", f.rust_name);
434                            quote! { .bind(&params.#name) }
435                        })
436                        .collect();
437                    quote! {
438                        pub async fn overwrite(&self, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
439                            sqlx::query(#sql)
440                                #(#all_binds)*
441                                .execute(&self.pool)
442                                .await?;
443                            sqlx::query_as::<_, #entity_ident>(#select_sql)
444                                #(#pk_binds)*
445                                .fetch_one(&self.pool)
446                                .await
447                        }
448                    }
449                }
450            }
451        };
452        method_tokens.push(overwrite_method);
453
454        param_structs.push(quote! {
455            #[derive(Debug, Clone, Default)]
456            pub struct #overwrite_params_ident {
457                #(#overwrite_fields)*
458            }
459        });
460    }
461
462    // --- update / patch (COALESCE — skip for views, skip when all columns are PKs) ---
463    if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
464        let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
465
466        // PK fields keep original type; non-PK fields become Option<T> (no double Option)
467        let update_fields: Vec<TokenStream> = entity
468            .fields
469            .iter()
470            .map(|f| {
471                let name = format_ident!("{}", f.rust_name);
472                if f.is_primary_key {
473                    let ty: TokenStream = f.rust_type.parse().unwrap();
474                    quote! { pub #name: #ty, }
475                } else if f.is_nullable {
476                    // Already Option<T> — keep as-is to avoid Option<Option<T>>
477                    let ty: TokenStream = f.rust_type.parse().unwrap();
478                    quote! { pub #name: #ty, }
479                } else {
480                    let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
481                    quote! { pub #name: #ty, }
482                }
483            })
484            .collect();
485
486        // SET clause with COALESCE for runtime mode
487        let set_cols: Vec<String> = non_pk_fields
488            .iter()
489            .enumerate()
490            .map(|(i, f)| {
491                let p = placeholder(db_kind, i + 1);
492                format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
493            })
494            .collect();
495        let set_clause = set_cols.join(", ");
496
497        // SET clause with COALESCE and casts for macro mode
498        let set_cols_cast: Vec<String> = non_pk_fields
499            .iter()
500            .enumerate()
501            .map(|(i, f)| {
502                let p = placeholder_with_cast(db_kind, i + 1, f);
503                format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
504            })
505            .collect();
506        let set_clause_cast = set_cols_cast.join(", ");
507
508        let pk_start = non_pk_fields.len() + 1;
509        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
510        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
511
512        let build_update_sql = |sc: &str, wc: &str| match db_kind {
513            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
514                format!(
515                    "UPDATE {} SET {} WHERE {} RETURNING *",
516                    table_name, sc, wc
517                )
518            }
519            DatabaseKind::Mysql => {
520                format!(
521                    "UPDATE {} SET {} WHERE {}",
522                    table_name, sc, wc
523                )
524            }
525        };
526        let sql = build_update_sql(&set_clause, &where_clause);
527        let sql_macro = build_update_sql(&set_clause_cast, &where_clause_cast);
528
529        // Bind non-PK first, then PK
530        let mut all_binds: Vec<TokenStream> = non_pk_fields
531            .iter()
532            .map(|f| {
533                let name = format_ident!("{}", f.rust_name);
534                quote! { .bind(&params.#name) }
535            })
536            .collect();
537        for f in &pk_fields {
538            let name = format_ident!("{}", f.rust_name);
539            all_binds.push(quote! { .bind(&params.#name) });
540        }
541
542        // Macro args: non-PK fields first, then PK fields
543        // All non-PK fields are Option<T> in the params struct, sqlx handles None → NULL
544        let update_macro_args: Vec<TokenStream> = non_pk_fields
545            .iter()
546            .chain(pk_fields.iter())
547            .map(|f| macro_arg_for_field(f))
548            .collect();
549
550        let update_method = if use_macro {
551            match db_kind {
552                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
553                    quote! {
554                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
555                            sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
556                                .fetch_one(&self.pool)
557                                .await
558                        }
559                    }
560                }
561                DatabaseKind::Mysql => {
562                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
563                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
564                    let pk_macro_args: Vec<TokenStream> = pk_fields
565                        .iter()
566                        .map(|f| {
567                            let name = format_ident!("{}", f.rust_name);
568                            quote! { params.#name }
569                        })
570                        .collect();
571                    quote! {
572                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
573                            sqlx::query!(#sql_macro, #(#update_macro_args),*)
574                                .execute(&self.pool)
575                                .await?;
576                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
577                                .fetch_one(&self.pool)
578                                .await
579                        }
580                    }
581                }
582            }
583        } else {
584            match db_kind {
585                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
586                    quote! {
587                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
588                            sqlx::query_as::<_, #entity_ident>(#sql)
589                                #(#all_binds)*
590                                .fetch_one(&self.pool)
591                                .await
592                        }
593                    }
594                }
595                DatabaseKind::Mysql => {
596                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
597                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
598                    let pk_binds: Vec<TokenStream> = pk_fields
599                        .iter()
600                        .map(|f| {
601                            let name = format_ident!("{}", f.rust_name);
602                            quote! { .bind(&params.#name) }
603                        })
604                        .collect();
605                    quote! {
606                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
607                            sqlx::query(#sql)
608                                #(#all_binds)*
609                                .execute(&self.pool)
610                                .await?;
611                            sqlx::query_as::<_, #entity_ident>(#select_sql)
612                                #(#pk_binds)*
613                                .fetch_one(&self.pool)
614                                .await
615                        }
616                    }
617                }
618            }
619        };
620        method_tokens.push(update_method);
621
622        param_structs.push(quote! {
623            #[derive(Debug, Clone, Default)]
624            pub struct #update_params_ident {
625                #(#update_fields)*
626            }
627        });
628    }
629
630    // --- delete (skip for views) ---
631    if !is_view && methods.delete && !pk_fields.is_empty() {
632        let pk_params: Vec<TokenStream> = pk_fields
633            .iter()
634            .map(|f| {
635                let name = format_ident!("{}", f.rust_name);
636                let ty: TokenStream = f.inner_type.parse().unwrap();
637                quote! { #name: #ty }
638            })
639            .collect();
640
641        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
642        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
643        let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
644        let sql_macro = format!("DELETE FROM {} WHERE {}", table_name, where_clause_cast);
645
646        let binds: Vec<TokenStream> = pk_fields
647            .iter()
648            .map(|f| {
649                let name = format_ident!("{}", f.rust_name);
650                quote! { .bind(#name) }
651            })
652            .collect();
653
654        let method = if query_macro {
655            let pk_arg_names: Vec<TokenStream> = pk_fields
656                .iter()
657                .map(|f| {
658                    let name = format_ident!("{}", f.rust_name);
659                    quote! { #name }
660                })
661                .collect();
662            quote! {
663                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
664                    sqlx::query!(#sql_macro, #(#pk_arg_names),*)
665                        .execute(&self.pool)
666                        .await?;
667                    Ok(())
668                }
669            }
670        } else {
671            quote! {
672                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
673                    sqlx::query(#sql)
674                        #(#binds)*
675                        .execute(&self.pool)
676                        .await?;
677                    Ok(())
678                }
679            }
680        };
681        method_tokens.push(method);
682    }
683
684    let pool_vis: TokenStream = match pool_visibility {
685        PoolVisibility::Private => quote! {},
686        PoolVisibility::Pub => quote! { pub },
687        PoolVisibility::PubCrate => quote! { pub(crate) },
688    };
689
690    let tokens = quote! {
691        #(#param_structs)*
692
693        pub struct #repo_ident {
694            #pool_vis pool: #pool_type,
695        }
696
697        impl #repo_ident {
698            pub fn new(pool: #pool_type) -> Self {
699                Self { pool }
700            }
701
702            #(#method_tokens)*
703        }
704    };
705
706    (tokens, imports)
707}
708
709fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
710    match db_kind {
711        DatabaseKind::Postgres => quote! { sqlx::PgPool },
712        DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
713        DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
714    }
715}
716
717fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
718    match db_kind {
719        DatabaseKind::Postgres => format!("${}", index),
720        DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
721    }
722}
723
724fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
725    let base = placeholder(db_kind, index);
726    match (&field.sql_type, field.is_sql_array) {
727        (Some(t), true) => format!("{} as {}[]", base, t),
728        (Some(t), false) => format!("{} as {}", base, t),
729        (None, _) => base,
730    }
731}
732
733fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
734    (0..count)
735        .map(|i| placeholder(db_kind, start + i))
736        .collect::<Vec<_>>()
737        .join(", ")
738}
739
740fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
741    fields
742        .iter()
743        .enumerate()
744        .map(|(i, f)| {
745            if use_cast {
746                placeholder_with_cast(db_kind, start + i, f)
747            } else {
748                placeholder(db_kind, start + i)
749            }
750        })
751        .collect::<Vec<_>>()
752        .join(", ")
753}
754
755fn build_where_clause_parsed(
756    pk_fields: &[&ParsedField],
757    db_kind: DatabaseKind,
758    start_index: usize,
759) -> String {
760    pk_fields
761        .iter()
762        .enumerate()
763        .map(|(i, f)| {
764            let p = placeholder(db_kind, start_index + i);
765            format!("{} = {}", f.column_name, p)
766        })
767        .collect::<Vec<_>>()
768        .join(" AND ")
769}
770
771fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
772    let name = format_ident!("{}", field.rust_name);
773    let check_type = if field.is_nullable {
774        &field.inner_type
775    } else {
776        &field.rust_type
777    };
778    let normalized = check_type.replace(' ', "");
779    if normalized.starts_with("Vec<") {
780        quote! { params.#name.as_slice() }
781    } else {
782        quote! { params.#name }
783    }
784}
785
786fn build_where_clause_cast(
787    pk_fields: &[&ParsedField],
788    db_kind: DatabaseKind,
789    start_index: usize,
790) -> String {
791    pk_fields
792        .iter()
793        .enumerate()
794        .map(|(i, f)| {
795            let p = placeholder_with_cast(db_kind, start_index + i, f);
796            format!("{} = {}", f.column_name, p)
797        })
798        .collect::<Vec<_>>()
799        .join(" AND ")
800}
801
802#[allow(clippy::too_many_arguments)]
803fn build_insert_method_parsed(
804    entity_ident: &proc_macro2::Ident,
805    insert_params_ident: &proc_macro2::Ident,
806    sql: &str,
807    sql_macro: &str,
808    binds: &[TokenStream],
809    db_kind: DatabaseKind,
810    table_name: &str,
811    pk_fields: &[&ParsedField],
812    non_pk_fields: &[&ParsedField],
813    use_macro: bool,
814) -> TokenStream {
815    if use_macro {
816        let macro_args: Vec<TokenStream> = non_pk_fields
817            .iter()
818            .map(|f| macro_arg_for_field(f))
819            .collect();
820
821        match db_kind {
822            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
823                quote! {
824                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
825                        sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
826                            .fetch_one(&self.pool)
827                            .await
828                    }
829                }
830            }
831            DatabaseKind::Mysql => {
832                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
833                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
834                quote! {
835                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
836                        sqlx::query!(#sql_macro, #(#macro_args),*)
837                            .execute(&self.pool)
838                            .await?;
839                        let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
840                            .fetch_one(&self.pool)
841                            .await?;
842                        sqlx::query_as!(#entity_ident, #select_sql, id)
843                            .fetch_one(&self.pool)
844                            .await
845                    }
846                }
847            }
848        }
849    } else {
850        match db_kind {
851            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
852                quote! {
853                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
854                        sqlx::query_as::<_, #entity_ident>(#sql)
855                            #(#binds)*
856                            .fetch_one(&self.pool)
857                            .await
858                    }
859                }
860            }
861            DatabaseKind::Mysql => {
862                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
863                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
864                quote! {
865                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
866                        sqlx::query(#sql)
867                            #(#binds)*
868                            .execute(&self.pool)
869                            .await?;
870                        let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
871                            .fetch_one(&self.pool)
872                            .await?;
873                        sqlx::query_as::<_, #entity_ident>(#select_sql)
874                            .bind(id)
875                            .fetch_one(&self.pool)
876                            .await
877                    }
878                }
879            }
880        }
881    }
882}
883
884#[cfg(test)]
885mod tests {
886    use super::*;
887    use crate::codegen::parse_and_format;
888    use crate::cli::Methods;
889
890    fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
891        let inner_type = if nullable {
892            // Strip "Option<" prefix and ">" suffix
893            rust_type
894                .strip_prefix("Option<")
895                .and_then(|s| s.strip_suffix('>'))
896                .unwrap_or(rust_type)
897                .to_string()
898        } else {
899            rust_type.to_string()
900        };
901        ParsedField {
902            rust_name: rust_name.to_string(),
903            column_name: column_name.to_string(),
904            rust_type: rust_type.to_string(),
905            is_nullable: nullable,
906            inner_type,
907            is_primary_key: is_pk,
908            sql_type: None,
909            is_sql_array: false,
910        }
911    }
912
913    fn standard_entity() -> ParsedEntity {
914        ParsedEntity {
915            struct_name: "Users".to_string(),
916            table_name: "users".to_string(),
917            schema_name: None,
918            is_view: false,
919            fields: vec![
920                make_field("id", "id", "i32", false, true),
921                make_field("name", "name", "String", false, false),
922                make_field("email", "email", "Option<String>", true, false),
923            ],
924            imports: vec![],
925        }
926    }
927
928    fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
929        let skip = Methods::all();
930        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
931        parse_and_format(&tokens)
932    }
933
934    fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
935        let skip = Methods::all();
936        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
937        parse_and_format(&tokens)
938    }
939
940    fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
941        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
942        parse_and_format(&tokens)
943    }
944
945    // --- basic structure ---
946
947    #[test]
948    fn test_repo_struct_name() {
949        let code = gen(&standard_entity(), DatabaseKind::Postgres);
950        assert!(code.contains("pub struct UsersRepository"));
951    }
952
953    #[test]
954    fn test_repo_new_method() {
955        let code = gen(&standard_entity(), DatabaseKind::Postgres);
956        assert!(code.contains("pub fn new("));
957    }
958
959    #[test]
960    fn test_repo_pool_field_pg() {
961        let code = gen(&standard_entity(), DatabaseKind::Postgres);
962        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
963    }
964
965    #[test]
966    fn test_repo_pool_field_pub() {
967        let skip = Methods::all();
968        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Pub);
969        let code = parse_and_format(&tokens);
970        assert!(code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool"));
971    }
972
973    #[test]
974    fn test_repo_pool_field_pub_crate() {
975        let skip = Methods::all();
976        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::PubCrate);
977        let code = parse_and_format(&tokens);
978        assert!(code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool"));
979    }
980
981    #[test]
982    fn test_repo_pool_field_private() {
983        let code = gen(&standard_entity(), DatabaseKind::Postgres);
984        // Should NOT have `pub pool` or `pub(crate) pool`
985        assert!(!code.contains("pub pool"));
986        assert!(!code.contains("pub(crate) pool"));
987    }
988
989    #[test]
990    fn test_repo_pool_field_mysql() {
991        let code = gen(&standard_entity(), DatabaseKind::Mysql);
992        assert!(code.contains("MySqlPool") || code.contains("MySql"));
993    }
994
995    #[test]
996    fn test_repo_pool_field_sqlite() {
997        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
998        assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
999    }
1000
1001    // --- get_all ---
1002
1003    #[test]
1004    fn test_get_all_method() {
1005        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1006        assert!(code.contains("pub async fn get_all"));
1007    }
1008
1009    #[test]
1010    fn test_get_all_returns_vec() {
1011        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1012        assert!(code.contains("Vec<Users>"));
1013    }
1014
1015    #[test]
1016    fn test_get_all_sql() {
1017        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1018        assert!(code.contains("SELECT * FROM users"));
1019    }
1020
1021    // --- paginate ---
1022
1023    #[test]
1024    fn test_paginate_method() {
1025        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1026        assert!(code.contains("pub async fn paginate"));
1027    }
1028
1029    #[test]
1030    fn test_paginate_params_struct() {
1031        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1032        assert!(code.contains("pub struct PaginateUsersParams"));
1033    }
1034
1035    #[test]
1036    fn test_paginate_params_fields() {
1037        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1038        assert!(code.contains("pub page: i64"));
1039        assert!(code.contains("pub per_page: i64"));
1040    }
1041
1042    #[test]
1043    fn test_paginate_returns_paginated() {
1044        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1045        assert!(code.contains("PaginatedUsers"));
1046        assert!(code.contains("PaginationUsersMeta"));
1047    }
1048
1049    #[test]
1050    fn test_paginate_meta_struct() {
1051        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1052        assert!(code.contains("pub struct PaginationUsersMeta"));
1053        assert!(code.contains("pub total: i64"));
1054        assert!(code.contains("pub last_page: i64"));
1055        assert!(code.contains("pub first_page: i64"));
1056        assert!(code.contains("pub current_page: i64"));
1057    }
1058
1059    #[test]
1060    fn test_paginate_data_struct() {
1061        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1062        assert!(code.contains("pub struct PaginatedUsers"));
1063        assert!(code.contains("pub meta: PaginationUsersMeta"));
1064        assert!(code.contains("pub data: Vec<Users>"));
1065    }
1066
1067    #[test]
1068    fn test_paginate_count_sql() {
1069        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1070        assert!(code.contains("SELECT COUNT(*) FROM users"));
1071    }
1072
1073    #[test]
1074    fn test_paginate_sql_pg() {
1075        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1076        assert!(code.contains("LIMIT $1 OFFSET $2"));
1077    }
1078
1079    #[test]
1080    fn test_paginate_sql_mysql() {
1081        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1082        assert!(code.contains("LIMIT ? OFFSET ?"));
1083    }
1084
1085    // --- get ---
1086
1087    #[test]
1088    fn test_get_method() {
1089        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1090        assert!(code.contains("pub async fn get"));
1091    }
1092
1093    #[test]
1094    fn test_get_returns_option() {
1095        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1096        assert!(code.contains("Option<Users>"));
1097    }
1098
1099    #[test]
1100    fn test_get_where_pk_pg() {
1101        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1102        assert!(code.contains("WHERE id = $1"));
1103    }
1104
1105    #[test]
1106    fn test_get_where_pk_mysql() {
1107        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1108        assert!(code.contains("WHERE id = ?"));
1109    }
1110
1111    // --- insert ---
1112
1113    #[test]
1114    fn test_insert_method() {
1115        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1116        assert!(code.contains("pub async fn insert"));
1117    }
1118
1119    #[test]
1120    fn test_insert_params_struct() {
1121        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1122        assert!(code.contains("pub struct InsertUsersParams"));
1123    }
1124
1125    #[test]
1126    fn test_insert_params_no_pk() {
1127        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1128        assert!(code.contains("pub name: String"));
1129        assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
1130    }
1131
1132    #[test]
1133    fn test_insert_returning_pg() {
1134        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1135        assert!(code.contains("RETURNING *"));
1136    }
1137
1138    #[test]
1139    fn test_insert_returning_sqlite() {
1140        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1141        assert!(code.contains("RETURNING *"));
1142    }
1143
1144    #[test]
1145    fn test_insert_mysql_last_insert_id() {
1146        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1147        assert!(code.contains("LAST_INSERT_ID"));
1148    }
1149
1150    // --- overwrite (full replacement, formerly "update") ---
1151
1152    #[test]
1153    fn test_overwrite_method() {
1154        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1155        assert!(code.contains("pub async fn overwrite"));
1156    }
1157
1158    #[test]
1159    fn test_overwrite_params_struct() {
1160        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1161        assert!(code.contains("pub struct OverwriteUsersParams"));
1162    }
1163
1164    #[test]
1165    fn test_overwrite_params_all_cols() {
1166        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1167        assert!(code.contains("OverwriteUsersParams"));
1168        assert!(code.contains("pub id: i32"));
1169        assert!(code.contains("pub name: String"));
1170    }
1171
1172    #[test]
1173    fn test_overwrite_set_clause_pg() {
1174        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1175        // Overwrite uses direct assignment, no COALESCE
1176        assert!(code.contains("SET name = $1, email = $2 WHERE id = $3"));
1177    }
1178
1179    #[test]
1180    fn test_overwrite_returning_pg() {
1181        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1182        assert!(code.contains("UPDATE users SET"));
1183        assert!(code.contains("RETURNING *"));
1184    }
1185
1186    // --- update (patch with COALESCE) ---
1187
1188    #[test]
1189    fn test_update_method() {
1190        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1191        assert!(code.contains("pub async fn update"));
1192    }
1193
1194    #[test]
1195    fn test_update_params_struct() {
1196        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1197        assert!(code.contains("pub struct UpdateUsersParams"));
1198    }
1199
1200    #[test]
1201    fn test_update_params_pk_keeps_original_type() {
1202        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1203        // PK field `id` stays i32 (not Option<i32>) in UpdateUsersParams
1204        assert!(code.contains("pub id: i32") || code.contains("pub id : i32"));
1205    }
1206
1207    #[test]
1208    fn test_update_params_non_nullable_wrapped_in_option() {
1209        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1210        // `name: String` becomes `name: Option<String>` in patch params
1211        assert!(code.contains("pub name: Option<String>") || code.contains("pub name : Option < String >"));
1212    }
1213
1214    #[test]
1215    fn test_update_params_already_nullable_no_double_option() {
1216        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1217        // `email: Option<String>` stays `Option<String>`, NOT `Option<Option<String>>`
1218        assert!(!code.contains("Option<Option") && !code.contains("Option < Option"));
1219    }
1220
1221    #[test]
1222    fn test_update_set_clause_uses_coalesce_pg() {
1223        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1224        assert!(code.contains("COALESCE($1, name)"), "Expected COALESCE for name:\n{}", code);
1225        assert!(code.contains("COALESCE($2, email)"), "Expected COALESCE for email:\n{}", code);
1226    }
1227
1228    #[test]
1229    fn test_update_where_clause_pg() {
1230        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1231        assert!(code.contains("WHERE id = $3"));
1232    }
1233
1234    #[test]
1235    fn test_update_returning_pg() {
1236        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1237        assert!(code.contains("COALESCE"));
1238        assert!(code.contains("RETURNING *"));
1239    }
1240
1241    #[test]
1242    fn test_update_set_clause_mysql() {
1243        let code = gen(&standard_entity(), DatabaseKind::Mysql);
1244        assert!(code.contains("COALESCE(?, name)"), "Expected COALESCE for MySQL:\n{}", code);
1245        assert!(code.contains("COALESCE(?, email)"), "Expected COALESCE for email in MySQL:\n{}", code);
1246    }
1247
1248    #[test]
1249    fn test_update_set_clause_sqlite() {
1250        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1251        assert!(code.contains("COALESCE(?, name)"), "Expected COALESCE for SQLite:\n{}", code);
1252    }
1253
1254    #[test]
1255    fn test_update_and_overwrite_coexist() {
1256        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1257        assert!(code.contains("pub async fn update"), "Expected update method:\n{}", code);
1258        assert!(code.contains("pub async fn overwrite"), "Expected overwrite method:\n{}", code);
1259        assert!(code.contains("UpdateUsersParams"), "Expected UpdateUsersParams struct:\n{}", code);
1260        assert!(code.contains("OverwriteUsersParams"), "Expected OverwriteUsersParams struct:\n{}", code);
1261    }
1262
1263    // --- delete ---
1264
1265    #[test]
1266    fn test_delete_method() {
1267        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1268        assert!(code.contains("pub async fn delete"));
1269    }
1270
1271    #[test]
1272    fn test_delete_where_pk() {
1273        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1274        assert!(code.contains("DELETE FROM users WHERE id = $1"));
1275    }
1276
1277    #[test]
1278    fn test_delete_returns_unit() {
1279        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1280        assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
1281    }
1282
1283    // --- views (read-only) ---
1284
1285    #[test]
1286    fn test_view_no_insert() {
1287        let mut entity = standard_entity();
1288        entity.is_view = true;
1289        let code = gen(&entity, DatabaseKind::Postgres);
1290        assert!(!code.contains("pub async fn insert"));
1291    }
1292
1293    #[test]
1294    fn test_view_no_update() {
1295        let mut entity = standard_entity();
1296        entity.is_view = true;
1297        let code = gen(&entity, DatabaseKind::Postgres);
1298        assert!(!code.contains("pub async fn update"));
1299    }
1300
1301    #[test]
1302    fn test_view_no_overwrite() {
1303        let mut entity = standard_entity();
1304        entity.is_view = true;
1305        let code = gen(&entity, DatabaseKind::Postgres);
1306        assert!(!code.contains("pub async fn overwrite"));
1307    }
1308
1309    #[test]
1310    fn test_view_no_delete() {
1311        let mut entity = standard_entity();
1312        entity.is_view = true;
1313        let code = gen(&entity, DatabaseKind::Postgres);
1314        assert!(!code.contains("pub async fn delete"));
1315    }
1316
1317    #[test]
1318    fn test_view_has_get_all() {
1319        let mut entity = standard_entity();
1320        entity.is_view = true;
1321        let code = gen(&entity, DatabaseKind::Postgres);
1322        assert!(code.contains("pub async fn get_all"));
1323    }
1324
1325    #[test]
1326    fn test_view_has_paginate() {
1327        let mut entity = standard_entity();
1328        entity.is_view = true;
1329        let code = gen(&entity, DatabaseKind::Postgres);
1330        assert!(code.contains("pub async fn paginate"));
1331    }
1332
1333    #[test]
1334    fn test_view_has_get() {
1335        let mut entity = standard_entity();
1336        entity.is_view = true;
1337        let code = gen(&entity, DatabaseKind::Postgres);
1338        assert!(code.contains("pub async fn get"));
1339    }
1340
1341    // --- selective methods ---
1342
1343    #[test]
1344    fn test_only_get_all() {
1345        let m = Methods { get_all: true, ..Default::default() };
1346        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1347        assert!(code.contains("pub async fn get_all"));
1348        assert!(!code.contains("pub async fn paginate"));
1349        assert!(!code.contains("pub async fn insert"));
1350    }
1351
1352    #[test]
1353    fn test_without_get_all() {
1354        let m = Methods { get_all: false, ..Methods::all() };
1355        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1356        assert!(!code.contains("pub async fn get_all"));
1357    }
1358
1359    #[test]
1360    fn test_without_paginate() {
1361        let m = Methods { paginate: false, ..Methods::all() };
1362        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1363        assert!(!code.contains("pub async fn paginate"));
1364        assert!(!code.contains("PaginateUsersParams"));
1365    }
1366
1367    #[test]
1368    fn test_without_get() {
1369        let m = Methods { get: false, ..Methods::all() };
1370        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1371        assert!(code.contains("pub async fn get_all"));
1372        let without_get_all = code.replace("get_all", "XXX");
1373        assert!(!without_get_all.contains("fn get("));
1374    }
1375
1376    #[test]
1377    fn test_without_insert() {
1378        let m = Methods { insert: false, ..Methods::all() };
1379        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1380        assert!(!code.contains("pub async fn insert"));
1381        assert!(!code.contains("InsertUsersParams"));
1382    }
1383
1384    #[test]
1385    fn test_without_update() {
1386        let m = Methods { update: false, ..Methods::all() };
1387        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1388        assert!(!code.contains("pub async fn update"));
1389        assert!(!code.contains("UpdateUsersParams"));
1390    }
1391
1392    #[test]
1393    fn test_without_overwrite() {
1394        let m = Methods { overwrite: false, ..Methods::all() };
1395        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1396        assert!(!code.contains("pub async fn overwrite"));
1397        assert!(!code.contains("OverwriteUsersParams"));
1398    }
1399
1400    #[test]
1401    fn test_without_delete() {
1402        let m = Methods { delete: false, ..Methods::all() };
1403        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1404        assert!(!code.contains("pub async fn delete"));
1405    }
1406
1407    #[test]
1408    fn test_empty_methods_no_methods() {
1409        let m = Methods::default();
1410        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1411        assert!(!code.contains("pub async fn get_all"));
1412        assert!(!code.contains("pub async fn paginate"));
1413        assert!(!code.contains("pub async fn insert"));
1414        assert!(!code.contains("pub async fn update"));
1415        assert!(!code.contains("pub async fn overwrite"));
1416        assert!(!code.contains("pub async fn delete"));
1417    }
1418
1419    // --- imports ---
1420
1421    #[test]
1422    fn test_no_pool_import() {
1423        let skip = Methods::all();
1424        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1425        assert!(!imports.iter().any(|i| i.contains("PgPool")));
1426    }
1427
1428    #[test]
1429    fn test_imports_contain_entity() {
1430        let skip = Methods::all();
1431        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1432        assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
1433    }
1434
1435    // --- renamed columns ---
1436
1437    #[test]
1438    fn test_renamed_column_in_sql() {
1439        let entity = ParsedEntity {
1440            struct_name: "Connector".to_string(),
1441            table_name: "connector".to_string(),
1442            schema_name: None,
1443            is_view: false,
1444            fields: vec![
1445                make_field("id", "id", "i32", false, true),
1446                make_field("connector_type", "type", "String", false, false),
1447            ],
1448            imports: vec![],
1449        };
1450        let code = gen(&entity, DatabaseKind::Postgres);
1451        // INSERT should use the DB column name "type", not "connector_type"
1452        assert!(code.contains("type"));
1453        // The Rust param field should be connector_type
1454        assert!(code.contains("pub connector_type: String"));
1455    }
1456
1457    // --- no PK edge cases ---
1458
1459    #[test]
1460    fn test_no_pk_no_get() {
1461        let entity = ParsedEntity {
1462            struct_name: "Logs".to_string(),
1463            table_name: "logs".to_string(),
1464            schema_name: None,
1465            is_view: false,
1466            fields: vec![
1467                make_field("message", "message", "String", false, false),
1468                make_field("ts", "ts", "String", false, false),
1469            ],
1470            imports: vec![],
1471        };
1472        let code = gen(&entity, DatabaseKind::Postgres);
1473        assert!(code.contains("pub async fn get_all"));
1474        let without_get_all = code.replace("get_all", "XXX");
1475        assert!(!without_get_all.contains("fn get("));
1476    }
1477
1478    #[test]
1479    fn test_no_pk_no_delete() {
1480        let entity = ParsedEntity {
1481            struct_name: "Logs".to_string(),
1482            table_name: "logs".to_string(),
1483            schema_name: None,
1484            is_view: false,
1485            fields: vec![
1486                make_field("message", "message", "String", false, false),
1487            ],
1488            imports: vec![],
1489        };
1490        let code = gen(&entity, DatabaseKind::Postgres);
1491        assert!(!code.contains("pub async fn delete"));
1492    }
1493
1494    // --- Default derive on param structs ---
1495
1496    #[test]
1497    fn test_param_structs_have_default() {
1498        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1499        assert!(code.contains("Default"));
1500    }
1501
1502    // --- entity imports forwarded ---
1503
1504    #[test]
1505    fn test_entity_imports_forwarded() {
1506        let entity = ParsedEntity {
1507            struct_name: "Users".to_string(),
1508            table_name: "users".to_string(),
1509            schema_name: None,
1510            is_view: false,
1511            fields: vec![
1512                make_field("id", "id", "Uuid", false, true),
1513                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
1514            ],
1515            imports: vec![
1516                "use chrono::{DateTime, Utc};".to_string(),
1517                "use uuid::Uuid;".to_string(),
1518            ],
1519        };
1520        let skip = Methods::all();
1521        let (_, imports) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1522        assert!(imports.iter().any(|i| i.contains("chrono")));
1523        assert!(imports.iter().any(|i| i.contains("uuid")));
1524    }
1525
1526    #[test]
1527    fn test_entity_imports_empty_when_no_imports() {
1528        let skip = Methods::all();
1529        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1530        // Should only have pool + entity imports, no chrono/uuid
1531        assert!(!imports.iter().any(|i| i.contains("chrono")));
1532        assert!(!imports.iter().any(|i| i.contains("uuid")));
1533    }
1534
1535    // --- query_macro mode ---
1536
1537    #[test]
1538    fn test_macro_get_all() {
1539        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1540        assert!(code.contains("query_as!"));
1541        assert!(!code.contains("query_as::<"));
1542    }
1543
1544    #[test]
1545    fn test_macro_paginate() {
1546        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1547        assert!(code.contains("query_as!"));
1548        assert!(code.contains("per_page, offset"));
1549    }
1550
1551    #[test]
1552    fn test_macro_get() {
1553        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1554        // The get method should use query_as! with the PK as arg
1555        assert!(code.contains("query_as!(Users"));
1556    }
1557
1558    #[test]
1559    fn test_macro_insert_pg() {
1560        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1561        assert!(code.contains("query_as!(Users"));
1562        assert!(code.contains("params.name"));
1563        assert!(code.contains("params.email"));
1564    }
1565
1566    #[test]
1567    fn test_macro_insert_mysql() {
1568        let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
1569        // MySQL insert uses query! (not query_as!) for the INSERT
1570        assert!(code.contains("query!"));
1571        assert!(code.contains("query_scalar!"));
1572    }
1573
1574    #[test]
1575    fn test_macro_overwrite() {
1576        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1577        assert!(code.contains("query_as!(Users"));
1578        // Should contain params.name, params.email, params.id as args
1579        assert!(code.contains("pub async fn overwrite"));
1580        assert!(code.contains("OverwriteUsersParams"));
1581    }
1582
1583    #[test]
1584    fn test_macro_update() {
1585        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1586        assert!(code.contains("query_as!(Users"));
1587        assert!(code.contains("COALESCE"), "Expected COALESCE in macro update:\n{}", code);
1588        assert!(code.contains("pub async fn update"));
1589        assert!(code.contains("UpdateUsersParams"));
1590    }
1591
1592    #[test]
1593    fn test_macro_delete() {
1594        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1595        // delete uses query! (no return type)
1596        assert!(code.contains("query!"));
1597    }
1598
1599    #[test]
1600    fn test_macro_no_bind_calls() {
1601        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1602        assert!(!code.contains(".bind("));
1603    }
1604
1605    #[test]
1606    fn test_function_style_uses_bind() {
1607        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1608        assert!(code.contains(".bind("));
1609        assert!(!code.contains("query_as!("));
1610        assert!(!code.contains("query!("));
1611    }
1612
1613    // --- custom sql_type fallback: macro mode + custom type → runtime for SELECT, macro for DELETE ---
1614
1615    fn entity_with_sql_array() -> ParsedEntity {
1616        ParsedEntity {
1617            struct_name: "AgentConnector".to_string(),
1618            table_name: "agent.agent_connector".to_string(),
1619            schema_name: Some("agent".to_string()),
1620            is_view: false,
1621            fields: vec![
1622                ParsedField {
1623                    rust_name: "connector_id".to_string(),
1624                    column_name: "connector_id".to_string(),
1625                    rust_type: "Uuid".to_string(),
1626                    inner_type: "Uuid".to_string(),
1627                    is_nullable: false,
1628                    is_primary_key: true,
1629                    sql_type: None,
1630                    is_sql_array: false,
1631                },
1632                ParsedField {
1633                    rust_name: "agent_id".to_string(),
1634                    column_name: "agent_id".to_string(),
1635                    rust_type: "Uuid".to_string(),
1636                    inner_type: "Uuid".to_string(),
1637                    is_nullable: false,
1638                    is_primary_key: false,
1639                    sql_type: None,
1640                    is_sql_array: false,
1641                },
1642                ParsedField {
1643                    rust_name: "usages".to_string(),
1644                    column_name: "usages".to_string(),
1645                    rust_type: "Vec<ConnectorUsages>".to_string(),
1646                    inner_type: "Vec<ConnectorUsages>".to_string(),
1647                    is_nullable: false,
1648                    is_primary_key: false,
1649                    sql_type: Some("agent.connector_usages".to_string()),
1650                    is_sql_array: true,
1651                },
1652            ],
1653            imports: vec!["use uuid::Uuid;".to_string()],
1654        }
1655    }
1656
1657    fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
1658        let skip = Methods::all();
1659        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
1660        parse_and_format(&tokens)
1661    }
1662
1663    #[test]
1664    fn test_sql_array_macro_get_all_uses_runtime() {
1665        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1666        // get_all should use runtime query_as, not macro
1667        assert!(code.contains("query_as::<"));
1668    }
1669
1670    #[test]
1671    fn test_sql_array_macro_get_uses_runtime() {
1672        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1673        // get should use .bind( since it's runtime
1674        assert!(code.contains(".bind("));
1675    }
1676
1677    #[test]
1678    fn test_sql_array_macro_insert_uses_runtime() {
1679        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1680        // insert RETURNING should use runtime query_as
1681        assert!(code.contains("query_as::<_ , AgentConnector>") || code.contains("query_as::<_, AgentConnector>"));
1682    }
1683
1684    #[test]
1685    fn test_sql_array_macro_overwrite_uses_runtime() {
1686        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1687        // overwrite RETURNING should use runtime query_as
1688        assert!(code.contains("query_as::<"));
1689    }
1690
1691    #[test]
1692    fn test_sql_array_macro_delete_still_uses_macro() {
1693        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1694        // delete uses query! macro (no rows returned, no array issue)
1695        assert!(code.contains("query!"));
1696    }
1697
1698    #[test]
1699    fn test_sql_array_no_query_as_macro() {
1700        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1701        // Should NOT contain query_as! macro (only query_as::<_ for runtime)
1702        assert!(!code.contains("query_as!("));
1703    }
1704
1705    // --- custom enum (non-array) also triggers runtime fallback ---
1706
1707    fn entity_with_sql_enum() -> ParsedEntity {
1708        ParsedEntity {
1709            struct_name: "Task".to_string(),
1710            table_name: "tasks".to_string(),
1711            schema_name: None,
1712            is_view: false,
1713            fields: vec![
1714                ParsedField {
1715                    rust_name: "id".to_string(),
1716                    column_name: "id".to_string(),
1717                    rust_type: "i32".to_string(),
1718                    inner_type: "i32".to_string(),
1719                    is_nullable: false,
1720                    is_primary_key: true,
1721                    sql_type: None,
1722                    is_sql_array: false,
1723                },
1724                ParsedField {
1725                    rust_name: "status".to_string(),
1726                    column_name: "status".to_string(),
1727                    rust_type: "TaskStatus".to_string(),
1728                    inner_type: "TaskStatus".to_string(),
1729                    is_nullable: false,
1730                    is_primary_key: false,
1731                    sql_type: Some("task_status".to_string()),
1732                    is_sql_array: false,
1733                },
1734            ],
1735            imports: vec![],
1736        }
1737    }
1738
1739    #[test]
1740    fn test_sql_enum_macro_uses_runtime() {
1741        let skip = Methods::all();
1742        let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1743        let code = parse_and_format(&tokens);
1744        // SELECT queries should use runtime query_as, not macro
1745        assert!(code.contains("query_as::<"));
1746        assert!(!code.contains("query_as!("));
1747    }
1748
1749    #[test]
1750    fn test_sql_enum_macro_delete_still_uses_macro() {
1751        let skip = Methods::all();
1752        let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1753        let code = parse_and_format(&tokens);
1754        // DELETE still uses query! macro
1755        assert!(code.contains("query!"));
1756    }
1757
1758    // --- Vec<String> native array uses .as_slice() in macro mode ---
1759
1760    fn entity_with_vec_string() -> ParsedEntity {
1761        ParsedEntity {
1762            struct_name: "PromptHistory".to_string(),
1763            table_name: "prompt_history".to_string(),
1764            schema_name: None,
1765            is_view: false,
1766            fields: vec![
1767                ParsedField {
1768                    rust_name: "id".to_string(),
1769                    column_name: "id".to_string(),
1770                    rust_type: "Uuid".to_string(),
1771                    inner_type: "Uuid".to_string(),
1772                    is_nullable: false,
1773                    is_primary_key: true,
1774                    sql_type: None,
1775                    is_sql_array: false,
1776                },
1777                ParsedField {
1778                    rust_name: "content".to_string(),
1779                    column_name: "content".to_string(),
1780                    rust_type: "String".to_string(),
1781                    inner_type: "String".to_string(),
1782                    is_nullable: false,
1783                    is_primary_key: false,
1784                    sql_type: None,
1785                    is_sql_array: false,
1786                },
1787                ParsedField {
1788                    rust_name: "tags".to_string(),
1789                    column_name: "tags".to_string(),
1790                    rust_type: "Vec<String>".to_string(),
1791                    inner_type: "Vec<String>".to_string(),
1792                    is_nullable: false,
1793                    is_primary_key: false,
1794                    sql_type: None,
1795                    is_sql_array: false,
1796                },
1797            ],
1798            imports: vec!["use uuid::Uuid;".to_string()],
1799        }
1800    }
1801
1802    #[test]
1803    fn test_vec_string_macro_insert_uses_as_slice() {
1804        let skip = Methods::all();
1805        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1806        let code = parse_and_format(&tokens);
1807        assert!(code.contains("as_slice()"));
1808    }
1809
1810    #[test]
1811    fn test_vec_string_macro_overwrite_uses_as_slice() {
1812        let skip = Methods::all();
1813        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1814        let code = parse_and_format(&tokens);
1815        // Should have as_slice() for insert, overwrite, and update
1816        let count = code.matches("as_slice()").count();
1817        assert!(count >= 3, "expected at least 3 as_slice() calls (insert + overwrite + update), found {}", count);
1818    }
1819
1820    #[test]
1821    fn test_vec_string_non_macro_no_as_slice() {
1822        let skip = Methods::all();
1823        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false, PoolVisibility::Private);
1824        let code = parse_and_format(&tokens);
1825        // Runtime mode uses .bind() so no as_slice needed
1826        assert!(!code.contains("as_slice()"));
1827    }
1828
1829    #[test]
1830    fn test_vec_string_parsed_from_source_uses_as_slice() {
1831        use crate::codegen::entity_parser::parse_entity_source;
1832        let source = r#"
1833            use uuid::Uuid;
1834
1835            #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
1836            #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
1837            pub struct PromptHistory {
1838                #[sqlx_gen(primary_key)]
1839                pub id: Uuid,
1840                pub content: String,
1841                pub tags: Vec<String>,
1842            }
1843        "#;
1844        let entity = parse_entity_source(source).unwrap();
1845        let skip = Methods::all();
1846        let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1847        let code = parse_and_format(&tokens);
1848        assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
1849    }
1850
1851    // --- composite PK only (junction table) ---
1852
1853    fn junction_entity() -> ParsedEntity {
1854        ParsedEntity {
1855            struct_name: "AnalysisRecord".to_string(),
1856            table_name: "analysis.analysis__record".to_string(),
1857            schema_name: None,
1858            is_view: false,
1859            fields: vec![
1860                make_field("record_id", "record_id", "uuid::Uuid", false, true),
1861                make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
1862            ],
1863            imports: vec![],
1864        }
1865    }
1866
1867    #[test]
1868    fn test_composite_pk_only_insert_generated() {
1869        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1870        assert!(code.contains("pub struct InsertAnalysisRecordParams"), "Expected InsertAnalysisRecordParams struct:\n{}", code);
1871        assert!(code.contains("pub record_id"), "Expected record_id field in insert params:\n{}", code);
1872        assert!(code.contains("pub analysis_id"), "Expected analysis_id field in insert params:\n{}", code);
1873        assert!(code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id) VALUES ($1, $2) RETURNING *"), "Expected valid INSERT SQL:\n{}", code);
1874        assert!(code.contains("pub async fn insert"), "Expected insert method:\n{}", code);
1875    }
1876
1877    #[test]
1878    fn test_composite_pk_only_no_update() {
1879        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1880        assert!(!code.contains("UpdateAnalysisRecordParams"), "Expected no UpdateAnalysisRecordParams struct:\n{}", code);
1881        assert!(!code.contains("pub async fn update"), "Expected no update method:\n{}", code);
1882    }
1883
1884    #[test]
1885    fn test_composite_pk_only_no_overwrite() {
1886        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1887        assert!(!code.contains("OverwriteAnalysisRecordParams"), "Expected no OverwriteAnalysisRecordParams struct:\n{}", code);
1888        assert!(!code.contains("pub async fn overwrite"), "Expected no overwrite method:\n{}", code);
1889    }
1890
1891    #[test]
1892    fn test_composite_pk_only_delete_generated() {
1893        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1894        assert!(code.contains("pub async fn delete"), "Expected delete method:\n{}", code);
1895        assert!(code.contains("DELETE FROM analysis.analysis__record WHERE record_id = $1 AND analysis_id = $2"), "Expected valid DELETE SQL:\n{}", code);
1896    }
1897
1898    #[test]
1899    fn test_composite_pk_only_get_generated() {
1900        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1901        assert!(code.contains("pub async fn get"), "Expected get method:\n{}", code);
1902        assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause with both PK columns:\n{}", code);
1903    }
1904}