Skip to main content

sqlx_gen/codegen/
crud_gen.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10    entity: &ParsedEntity,
11    db_kind: DatabaseKind,
12    entity_module_path: &str,
13    methods: &Methods,
14    query_macro: bool,
15    pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17    let mut imports = BTreeSet::new();
18
19    let entity_ident = format_ident!("{}", entity.struct_name);
20    let repo_name = format!("{}Repository", entity.struct_name);
21    let repo_ident = format_ident!("{}", repo_name);
22
23    let table_name = match &entity.schema_name {
24        Some(schema) => format!("{}.{}", schema, entity.table_name),
25        None => entity.table_name.clone(),
26    };
27
28    // Pool type (used via full path sqlx::PgPool etc., no import needed)
29    let pool_type = pool_type_tokens(db_kind);
30
31    // When the entity has custom SQL types (enums, composites, arrays),
32    // query_as! macro can't resolve the column type at compile time. Fall back to runtime query_as::<_, T>()
33    // for queries that return rows. DELETE (no rows returned) can still use macro.
34    let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35    let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37    // Entity import
38    imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40    // Forward type imports from the entity file (chrono, uuid, etc.)
41    // Rewrite `use super::X` imports to absolute paths based on entity_module_path,
42    // since the repository lives in a different module where `super` has a different meaning.
43    let entity_parent = entity_module_path
44        .rsplit_once("::")
45        .map(|(parent, _)| parent)
46        .unwrap_or(entity_module_path);
47    for imp in &entity.imports {
48        if let Some(rest) = imp.strip_prefix("use super::") {
49            imports.insert(format!("use {}::{}", entity_parent, rest));
50        } else {
51            imports.insert(imp.clone());
52        }
53    }
54
55    // Primary key fields
56    let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58    // Non-PK fields (for insert)
59    let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61    let is_view = entity.is_view;
62
63    // Build method tokens
64    let mut method_tokens = Vec::new();
65    let mut param_structs = Vec::new();
66
67    // --- get_all ---
68    if methods.get_all {
69        let sql = format!("SELECT * FROM {}", table_name);
70        let method = if use_macro {
71            quote! {
72                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73                    sqlx::query_as!(#entity_ident, #sql)
74                        .fetch_all(&self.pool)
75                        .await
76                }
77            }
78        } else {
79            quote! {
80                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81                    sqlx::query_as::<_, #entity_ident>(#sql)
82                        .fetch_all(&self.pool)
83                        .await
84                }
85            }
86        };
87        method_tokens.push(method);
88    }
89
90    // --- paginate ---
91    if methods.paginate {
92        let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93        let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94        let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95        let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
96        let sql = match db_kind {
97            DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
98            DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
99        };
100        let method = if use_macro {
101            quote! {
102                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103                    let total: i64 = sqlx::query_scalar!(#count_sql)
104                        .fetch_one(&self.pool)
105                        .await?
106                        .unwrap_or(0);
107                    let per_page = params.per_page;
108                    let current_page = params.page;
109                    let last_page = (total + per_page - 1) / per_page;
110                    let offset = (current_page - 1) * per_page;
111                    let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112                        .fetch_all(&self.pool)
113                        .await?;
114                    Ok(#paginated_ident {
115                        meta: #pagination_meta_ident {
116                            total,
117                            per_page,
118                            current_page,
119                            last_page,
120                            first_page: 1,
121                        },
122                        data,
123                    })
124                }
125            }
126        } else {
127            quote! {
128                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129                    let total: i64 = sqlx::query_scalar(#count_sql)
130                        .fetch_one(&self.pool)
131                        .await?;
132                    let per_page = params.per_page;
133                    let current_page = params.page;
134                    let last_page = (total + per_page - 1) / per_page;
135                    let offset = (current_page - 1) * per_page;
136                    let data = sqlx::query_as::<_, #entity_ident>(#sql)
137                        .bind(per_page)
138                        .bind(offset)
139                        .fetch_all(&self.pool)
140                        .await?;
141                    Ok(#paginated_ident {
142                        meta: #pagination_meta_ident {
143                            total,
144                            per_page,
145                            current_page,
146                            last_page,
147                            first_page: 1,
148                        },
149                        data,
150                    })
151                }
152            }
153        };
154        method_tokens.push(method);
155        param_structs.push(quote! {
156            #[derive(Debug, Clone, Default)]
157            pub struct #paginate_params_ident {
158                pub page: i64,
159                pub per_page: i64,
160            }
161        });
162        param_structs.push(quote! {
163            #[derive(Debug, Clone)]
164            pub struct #pagination_meta_ident {
165                pub total: i64,
166                pub per_page: i64,
167                pub current_page: i64,
168                pub last_page: i64,
169                pub first_page: i64,
170            }
171        });
172        param_structs.push(quote! {
173            #[derive(Debug, Clone)]
174            pub struct #paginated_ident {
175                pub meta: #pagination_meta_ident,
176                pub data: Vec<#entity_ident>,
177            }
178        });
179    }
180
181    // --- get (by PK) ---
182    if methods.get && !pk_fields.is_empty() {
183        let pk_params: Vec<TokenStream> = pk_fields
184            .iter()
185            .map(|f| {
186                let name = format_ident!("{}", f.rust_name);
187                let ty: TokenStream = f.inner_type.parse().unwrap();
188                quote! { #name: #ty }
189            })
190            .collect();
191
192        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194        let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
195        let sql_macro = format!("SELECT * FROM {} WHERE {}", table_name, where_clause_cast);
196
197        let binds: Vec<TokenStream> = pk_fields
198            .iter()
199            .map(|f| {
200                let name = format_ident!("{}", f.rust_name);
201                quote! { .bind(#name) }
202            })
203            .collect();
204
205        let method = if use_macro {
206            let pk_arg_names: Vec<TokenStream> = pk_fields
207                .iter()
208                .map(|f| {
209                    let name = format_ident!("{}", f.rust_name);
210                    quote! { #name }
211                })
212                .collect();
213            quote! {
214                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215                    sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216                        .fetch_optional(&self.pool)
217                        .await
218                }
219            }
220        } else {
221            quote! {
222                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223                    sqlx::query_as::<_, #entity_ident>(#sql)
224                        #(#binds)*
225                        .fetch_optional(&self.pool)
226                        .await
227                }
228            }
229        };
230        method_tokens.push(method);
231    }
232
233    // --- insert (skip for views) ---
234    if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
235        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237        // When all columns are PKs (e.g. junction tables), use pk_fields for insert
238        let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
239            pk_fields.clone()
240        } else {
241            non_pk_fields.clone()
242        };
243
244        let insert_fields: Vec<TokenStream> = insert_source_fields
245            .iter()
246            .map(|f| {
247                let name = format_ident!("{}", f.rust_name);
248                let ty: TokenStream = f.rust_type.parse().unwrap();
249                quote! { pub #name: #ty, }
250            })
251            .collect();
252
253        let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
254        let col_list = col_names.join(", ");
255        // Use casted placeholders for macro mode, plain for runtime
256        let placeholders = build_placeholders(insert_source_fields.len(), db_kind, 1);
257        let placeholders_cast = build_placeholders_with_cast(&insert_source_fields, db_kind, 1, true);
258
259        let build_insert_sql = |ph: &str| match db_kind {
260            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
261                format!(
262                    "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
263                    table_name, col_list, ph
264                )
265            }
266            DatabaseKind::Mysql => {
267                format!(
268                    "INSERT INTO {} ({}) VALUES ({})",
269                    table_name, col_list, ph
270                )
271            }
272        };
273        let sql = build_insert_sql(&placeholders);
274        let sql_macro = build_insert_sql(&placeholders_cast);
275
276        let binds: Vec<TokenStream> = insert_source_fields
277            .iter()
278            .map(|f| {
279                let name = format_ident!("{}", f.rust_name);
280                quote! { .bind(&params.#name) }
281            })
282            .collect();
283
284        let insert_method = build_insert_method_parsed(
285            &entity_ident,
286            &insert_params_ident,
287            &sql,
288            &sql_macro,
289            &binds,
290            db_kind,
291            &table_name,
292            &pk_fields,
293            &insert_source_fields,
294            use_macro,
295        );
296        method_tokens.push(insert_method);
297
298        param_structs.push(quote! {
299            #[derive(Debug, Clone, Default)]
300            pub struct #insert_params_ident {
301                #(#insert_fields)*
302            }
303        });
304    }
305
306    // --- update (skip for views, skip when all columns are PKs — nothing to SET) ---
307    if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
308        let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
309
310        let update_fields: Vec<TokenStream> = entity
311            .fields
312            .iter()
313            .map(|f| {
314                let name = format_ident!("{}", f.rust_name);
315                let ty: TokenStream = f.rust_type.parse().unwrap();
316                quote! { pub #name: #ty, }
317            })
318            .collect();
319
320        let set_cols: Vec<String> = non_pk_fields
321            .iter()
322            .enumerate()
323            .map(|(i, f)| {
324                let p = placeholder(db_kind, i + 1);
325                format!("{} = {}", f.column_name, p)
326            })
327            .collect();
328        let set_clause = set_cols.join(", ");
329
330        // SET clause with casts for macro mode
331        let set_cols_cast: Vec<String> = non_pk_fields
332            .iter()
333            .enumerate()
334            .map(|(i, f)| {
335                let p = placeholder_with_cast(db_kind, i + 1, f);
336                format!("{} = {}", f.column_name, p)
337            })
338            .collect();
339        let set_clause_cast = set_cols_cast.join(", ");
340
341        let pk_start = non_pk_fields.len() + 1;
342        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
343        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
344
345        let build_update_sql = |sc: &str, wc: &str| match db_kind {
346            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
347                format!(
348                    "UPDATE {} SET {} WHERE {} RETURNING *",
349                    table_name, sc, wc
350                )
351            }
352            DatabaseKind::Mysql => {
353                format!(
354                    "UPDATE {} SET {} WHERE {}",
355                    table_name, sc, wc
356                )
357            }
358        };
359        let sql = build_update_sql(&set_clause, &where_clause);
360        let sql_macro = build_update_sql(&set_clause_cast, &where_clause_cast);
361
362        // Bind non-PK first, then PK
363        let mut all_binds: Vec<TokenStream> = non_pk_fields
364            .iter()
365            .map(|f| {
366                let name = format_ident!("{}", f.rust_name);
367                quote! { .bind(&params.#name) }
368            })
369            .collect();
370        for f in &pk_fields {
371            let name = format_ident!("{}", f.rust_name);
372            all_binds.push(quote! { .bind(&params.#name) });
373        }
374
375        // Macro args: non-PK fields first, then PK fields
376        let update_macro_args: Vec<TokenStream> = non_pk_fields
377            .iter()
378            .chain(pk_fields.iter())
379            .map(|f| macro_arg_for_field(f))
380            .collect();
381
382        let update_method = if use_macro {
383            match db_kind {
384                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
385                    quote! {
386                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
387                            sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
388                                .fetch_one(&self.pool)
389                                .await
390                        }
391                    }
392                }
393                DatabaseKind::Mysql => {
394                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
395                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
396                    let pk_macro_args: Vec<TokenStream> = pk_fields
397                        .iter()
398                        .map(|f| {
399                            let name = format_ident!("{}", f.rust_name);
400                            quote! { params.#name }
401                        })
402                        .collect();
403                    quote! {
404                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
405                            sqlx::query!(#sql_macro, #(#update_macro_args),*)
406                                .execute(&self.pool)
407                                .await?;
408                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
409                                .fetch_one(&self.pool)
410                                .await
411                        }
412                    }
413                }
414            }
415        } else {
416            match db_kind {
417                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
418                    quote! {
419                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
420                            sqlx::query_as::<_, #entity_ident>(#sql)
421                                #(#all_binds)*
422                                .fetch_one(&self.pool)
423                                .await
424                        }
425                    }
426                }
427                DatabaseKind::Mysql => {
428                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
429                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
430                    let pk_binds: Vec<TokenStream> = pk_fields
431                        .iter()
432                        .map(|f| {
433                            let name = format_ident!("{}", f.rust_name);
434                            quote! { .bind(&params.#name) }
435                        })
436                        .collect();
437                    quote! {
438                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
439                            sqlx::query(#sql)
440                                #(#all_binds)*
441                                .execute(&self.pool)
442                                .await?;
443                            sqlx::query_as::<_, #entity_ident>(#select_sql)
444                                #(#pk_binds)*
445                                .fetch_one(&self.pool)
446                                .await
447                        }
448                    }
449                }
450            }
451        };
452        method_tokens.push(update_method);
453
454        param_structs.push(quote! {
455            #[derive(Debug, Clone, Default)]
456            pub struct #update_params_ident {
457                #(#update_fields)*
458            }
459        });
460    }
461
462    // --- delete (skip for views) ---
463    if !is_view && methods.delete && !pk_fields.is_empty() {
464        let pk_params: Vec<TokenStream> = pk_fields
465            .iter()
466            .map(|f| {
467                let name = format_ident!("{}", f.rust_name);
468                let ty: TokenStream = f.inner_type.parse().unwrap();
469                quote! { #name: #ty }
470            })
471            .collect();
472
473        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
474        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
475        let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
476        let sql_macro = format!("DELETE FROM {} WHERE {}", table_name, where_clause_cast);
477
478        let binds: Vec<TokenStream> = pk_fields
479            .iter()
480            .map(|f| {
481                let name = format_ident!("{}", f.rust_name);
482                quote! { .bind(#name) }
483            })
484            .collect();
485
486        let method = if query_macro {
487            let pk_arg_names: Vec<TokenStream> = pk_fields
488                .iter()
489                .map(|f| {
490                    let name = format_ident!("{}", f.rust_name);
491                    quote! { #name }
492                })
493                .collect();
494            quote! {
495                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
496                    sqlx::query!(#sql_macro, #(#pk_arg_names),*)
497                        .execute(&self.pool)
498                        .await?;
499                    Ok(())
500                }
501            }
502        } else {
503            quote! {
504                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
505                    sqlx::query(#sql)
506                        #(#binds)*
507                        .execute(&self.pool)
508                        .await?;
509                    Ok(())
510                }
511            }
512        };
513        method_tokens.push(method);
514    }
515
516    let pool_vis: TokenStream = match pool_visibility {
517        PoolVisibility::Private => quote! {},
518        PoolVisibility::Pub => quote! { pub },
519        PoolVisibility::PubCrate => quote! { pub(crate) },
520    };
521
522    let tokens = quote! {
523        #(#param_structs)*
524
525        pub struct #repo_ident {
526            #pool_vis pool: #pool_type,
527        }
528
529        impl #repo_ident {
530            pub fn new(pool: #pool_type) -> Self {
531                Self { pool }
532            }
533
534            #(#method_tokens)*
535        }
536    };
537
538    (tokens, imports)
539}
540
541fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
542    match db_kind {
543        DatabaseKind::Postgres => quote! { sqlx::PgPool },
544        DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
545        DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
546    }
547}
548
549fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
550    match db_kind {
551        DatabaseKind::Postgres => format!("${}", index),
552        DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
553    }
554}
555
556fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
557    let base = placeholder(db_kind, index);
558    match (&field.sql_type, field.is_sql_array) {
559        (Some(t), true) => format!("{} as {}[]", base, t),
560        (Some(t), false) => format!("{} as {}", base, t),
561        (None, _) => base,
562    }
563}
564
565fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
566    (0..count)
567        .map(|i| placeholder(db_kind, start + i))
568        .collect::<Vec<_>>()
569        .join(", ")
570}
571
572fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
573    fields
574        .iter()
575        .enumerate()
576        .map(|(i, f)| {
577            if use_cast {
578                placeholder_with_cast(db_kind, start + i, f)
579            } else {
580                placeholder(db_kind, start + i)
581            }
582        })
583        .collect::<Vec<_>>()
584        .join(", ")
585}
586
587fn build_where_clause_parsed(
588    pk_fields: &[&ParsedField],
589    db_kind: DatabaseKind,
590    start_index: usize,
591) -> String {
592    pk_fields
593        .iter()
594        .enumerate()
595        .map(|(i, f)| {
596            let p = placeholder(db_kind, start_index + i);
597            format!("{} = {}", f.column_name, p)
598        })
599        .collect::<Vec<_>>()
600        .join(" AND ")
601}
602
603fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
604    let name = format_ident!("{}", field.rust_name);
605    let check_type = if field.is_nullable {
606        &field.inner_type
607    } else {
608        &field.rust_type
609    };
610    let normalized = check_type.replace(' ', "");
611    if normalized.starts_with("Vec<") {
612        quote! { params.#name.as_slice() }
613    } else {
614        quote! { params.#name }
615    }
616}
617
618fn build_where_clause_cast(
619    pk_fields: &[&ParsedField],
620    db_kind: DatabaseKind,
621    start_index: usize,
622) -> String {
623    pk_fields
624        .iter()
625        .enumerate()
626        .map(|(i, f)| {
627            let p = placeholder_with_cast(db_kind, start_index + i, f);
628            format!("{} = {}", f.column_name, p)
629        })
630        .collect::<Vec<_>>()
631        .join(" AND ")
632}
633
634#[allow(clippy::too_many_arguments)]
635fn build_insert_method_parsed(
636    entity_ident: &proc_macro2::Ident,
637    insert_params_ident: &proc_macro2::Ident,
638    sql: &str,
639    sql_macro: &str,
640    binds: &[TokenStream],
641    db_kind: DatabaseKind,
642    table_name: &str,
643    pk_fields: &[&ParsedField],
644    non_pk_fields: &[&ParsedField],
645    use_macro: bool,
646) -> TokenStream {
647    if use_macro {
648        let macro_args: Vec<TokenStream> = non_pk_fields
649            .iter()
650            .map(|f| macro_arg_for_field(f))
651            .collect();
652
653        match db_kind {
654            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
655                quote! {
656                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
657                        sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
658                            .fetch_one(&self.pool)
659                            .await
660                    }
661                }
662            }
663            DatabaseKind::Mysql => {
664                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
665                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
666                quote! {
667                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
668                        sqlx::query!(#sql_macro, #(#macro_args),*)
669                            .execute(&self.pool)
670                            .await?;
671                        let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
672                            .fetch_one(&self.pool)
673                            .await?;
674                        sqlx::query_as!(#entity_ident, #select_sql, id)
675                            .fetch_one(&self.pool)
676                            .await
677                    }
678                }
679            }
680        }
681    } else {
682        match db_kind {
683            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
684                quote! {
685                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
686                        sqlx::query_as::<_, #entity_ident>(#sql)
687                            #(#binds)*
688                            .fetch_one(&self.pool)
689                            .await
690                    }
691                }
692            }
693            DatabaseKind::Mysql => {
694                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
695                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
696                quote! {
697                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
698                        sqlx::query(#sql)
699                            #(#binds)*
700                            .execute(&self.pool)
701                            .await?;
702                        let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
703                            .fetch_one(&self.pool)
704                            .await?;
705                        sqlx::query_as::<_, #entity_ident>(#select_sql)
706                            .bind(id)
707                            .fetch_one(&self.pool)
708                            .await
709                    }
710                }
711            }
712        }
713    }
714}
715
716#[cfg(test)]
717mod tests {
718    use super::*;
719    use crate::codegen::parse_and_format;
720    use crate::cli::Methods;
721
722    fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
723        let inner_type = if nullable {
724            // Strip "Option<" prefix and ">" suffix
725            rust_type
726                .strip_prefix("Option<")
727                .and_then(|s| s.strip_suffix('>'))
728                .unwrap_or(rust_type)
729                .to_string()
730        } else {
731            rust_type.to_string()
732        };
733        ParsedField {
734            rust_name: rust_name.to_string(),
735            column_name: column_name.to_string(),
736            rust_type: rust_type.to_string(),
737            is_nullable: nullable,
738            inner_type,
739            is_primary_key: is_pk,
740            sql_type: None,
741            is_sql_array: false,
742        }
743    }
744
745    fn standard_entity() -> ParsedEntity {
746        ParsedEntity {
747            struct_name: "Users".to_string(),
748            table_name: "users".to_string(),
749            schema_name: None,
750            is_view: false,
751            fields: vec![
752                make_field("id", "id", "i32", false, true),
753                make_field("name", "name", "String", false, false),
754                make_field("email", "email", "Option<String>", true, false),
755            ],
756            imports: vec![],
757        }
758    }
759
760    fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
761        let skip = Methods::all();
762        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
763        parse_and_format(&tokens)
764    }
765
766    fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
767        let skip = Methods::all();
768        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
769        parse_and_format(&tokens)
770    }
771
772    fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
773        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
774        parse_and_format(&tokens)
775    }
776
777    // --- basic structure ---
778
779    #[test]
780    fn test_repo_struct_name() {
781        let code = gen(&standard_entity(), DatabaseKind::Postgres);
782        assert!(code.contains("pub struct UsersRepository"));
783    }
784
785    #[test]
786    fn test_repo_new_method() {
787        let code = gen(&standard_entity(), DatabaseKind::Postgres);
788        assert!(code.contains("pub fn new("));
789    }
790
791    #[test]
792    fn test_repo_pool_field_pg() {
793        let code = gen(&standard_entity(), DatabaseKind::Postgres);
794        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
795    }
796
797    #[test]
798    fn test_repo_pool_field_pub() {
799        let skip = Methods::all();
800        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Pub);
801        let code = parse_and_format(&tokens);
802        assert!(code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool"));
803    }
804
805    #[test]
806    fn test_repo_pool_field_pub_crate() {
807        let skip = Methods::all();
808        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::PubCrate);
809        let code = parse_and_format(&tokens);
810        assert!(code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool"));
811    }
812
813    #[test]
814    fn test_repo_pool_field_private() {
815        let code = gen(&standard_entity(), DatabaseKind::Postgres);
816        // Should NOT have `pub pool` or `pub(crate) pool`
817        assert!(!code.contains("pub pool"));
818        assert!(!code.contains("pub(crate) pool"));
819    }
820
821    #[test]
822    fn test_repo_pool_field_mysql() {
823        let code = gen(&standard_entity(), DatabaseKind::Mysql);
824        assert!(code.contains("MySqlPool") || code.contains("MySql"));
825    }
826
827    #[test]
828    fn test_repo_pool_field_sqlite() {
829        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
830        assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
831    }
832
833    // --- get_all ---
834
835    #[test]
836    fn test_get_all_method() {
837        let code = gen(&standard_entity(), DatabaseKind::Postgres);
838        assert!(code.contains("pub async fn get_all"));
839    }
840
841    #[test]
842    fn test_get_all_returns_vec() {
843        let code = gen(&standard_entity(), DatabaseKind::Postgres);
844        assert!(code.contains("Vec<Users>"));
845    }
846
847    #[test]
848    fn test_get_all_sql() {
849        let code = gen(&standard_entity(), DatabaseKind::Postgres);
850        assert!(code.contains("SELECT * FROM users"));
851    }
852
853    // --- paginate ---
854
855    #[test]
856    fn test_paginate_method() {
857        let code = gen(&standard_entity(), DatabaseKind::Postgres);
858        assert!(code.contains("pub async fn paginate"));
859    }
860
861    #[test]
862    fn test_paginate_params_struct() {
863        let code = gen(&standard_entity(), DatabaseKind::Postgres);
864        assert!(code.contains("pub struct PaginateUsersParams"));
865    }
866
867    #[test]
868    fn test_paginate_params_fields() {
869        let code = gen(&standard_entity(), DatabaseKind::Postgres);
870        assert!(code.contains("pub page: i64"));
871        assert!(code.contains("pub per_page: i64"));
872    }
873
874    #[test]
875    fn test_paginate_returns_paginated() {
876        let code = gen(&standard_entity(), DatabaseKind::Postgres);
877        assert!(code.contains("PaginatedUsers"));
878        assert!(code.contains("PaginationUsersMeta"));
879    }
880
881    #[test]
882    fn test_paginate_meta_struct() {
883        let code = gen(&standard_entity(), DatabaseKind::Postgres);
884        assert!(code.contains("pub struct PaginationUsersMeta"));
885        assert!(code.contains("pub total: i64"));
886        assert!(code.contains("pub last_page: i64"));
887        assert!(code.contains("pub first_page: i64"));
888        assert!(code.contains("pub current_page: i64"));
889    }
890
891    #[test]
892    fn test_paginate_data_struct() {
893        let code = gen(&standard_entity(), DatabaseKind::Postgres);
894        assert!(code.contains("pub struct PaginatedUsers"));
895        assert!(code.contains("pub meta: PaginationUsersMeta"));
896        assert!(code.contains("pub data: Vec<Users>"));
897    }
898
899    #[test]
900    fn test_paginate_count_sql() {
901        let code = gen(&standard_entity(), DatabaseKind::Postgres);
902        assert!(code.contains("SELECT COUNT(*) FROM users"));
903    }
904
905    #[test]
906    fn test_paginate_sql_pg() {
907        let code = gen(&standard_entity(), DatabaseKind::Postgres);
908        assert!(code.contains("LIMIT $1 OFFSET $2"));
909    }
910
911    #[test]
912    fn test_paginate_sql_mysql() {
913        let code = gen(&standard_entity(), DatabaseKind::Mysql);
914        assert!(code.contains("LIMIT ? OFFSET ?"));
915    }
916
917    // --- get ---
918
919    #[test]
920    fn test_get_method() {
921        let code = gen(&standard_entity(), DatabaseKind::Postgres);
922        assert!(code.contains("pub async fn get"));
923    }
924
925    #[test]
926    fn test_get_returns_option() {
927        let code = gen(&standard_entity(), DatabaseKind::Postgres);
928        assert!(code.contains("Option<Users>"));
929    }
930
931    #[test]
932    fn test_get_where_pk_pg() {
933        let code = gen(&standard_entity(), DatabaseKind::Postgres);
934        assert!(code.contains("WHERE id = $1"));
935    }
936
937    #[test]
938    fn test_get_where_pk_mysql() {
939        let code = gen(&standard_entity(), DatabaseKind::Mysql);
940        assert!(code.contains("WHERE id = ?"));
941    }
942
943    // --- insert ---
944
945    #[test]
946    fn test_insert_method() {
947        let code = gen(&standard_entity(), DatabaseKind::Postgres);
948        assert!(code.contains("pub async fn insert"));
949    }
950
951    #[test]
952    fn test_insert_params_struct() {
953        let code = gen(&standard_entity(), DatabaseKind::Postgres);
954        assert!(code.contains("pub struct InsertUsersParams"));
955    }
956
957    #[test]
958    fn test_insert_params_no_pk() {
959        let code = gen(&standard_entity(), DatabaseKind::Postgres);
960        assert!(code.contains("pub name: String"));
961        assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
962    }
963
964    #[test]
965    fn test_insert_returning_pg() {
966        let code = gen(&standard_entity(), DatabaseKind::Postgres);
967        assert!(code.contains("RETURNING *"));
968    }
969
970    #[test]
971    fn test_insert_returning_sqlite() {
972        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
973        assert!(code.contains("RETURNING *"));
974    }
975
976    #[test]
977    fn test_insert_mysql_last_insert_id() {
978        let code = gen(&standard_entity(), DatabaseKind::Mysql);
979        assert!(code.contains("LAST_INSERT_ID"));
980    }
981
982    // --- update ---
983
984    #[test]
985    fn test_update_method() {
986        let code = gen(&standard_entity(), DatabaseKind::Postgres);
987        assert!(code.contains("pub async fn update"));
988    }
989
990    #[test]
991    fn test_update_params_struct() {
992        let code = gen(&standard_entity(), DatabaseKind::Postgres);
993        assert!(code.contains("pub struct UpdateUsersParams"));
994    }
995
996    #[test]
997    fn test_update_params_all_cols() {
998        let code = gen(&standard_entity(), DatabaseKind::Postgres);
999        assert!(code.contains("pub id: i32"));
1000        assert!(code.contains("pub name: String"));
1001    }
1002
1003    #[test]
1004    fn test_update_set_clause_pg() {
1005        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1006        assert!(code.contains("SET name = $1"));
1007        assert!(code.contains("WHERE id = $3"));
1008    }
1009
1010    #[test]
1011    fn test_update_returning_pg() {
1012        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1013        assert!(code.contains("UPDATE users SET"));
1014        assert!(code.contains("RETURNING *"));
1015    }
1016
1017    // --- delete ---
1018
1019    #[test]
1020    fn test_delete_method() {
1021        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1022        assert!(code.contains("pub async fn delete"));
1023    }
1024
1025    #[test]
1026    fn test_delete_where_pk() {
1027        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1028        assert!(code.contains("DELETE FROM users WHERE id = $1"));
1029    }
1030
1031    #[test]
1032    fn test_delete_returns_unit() {
1033        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1034        assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
1035    }
1036
1037    // --- views (read-only) ---
1038
1039    #[test]
1040    fn test_view_no_insert() {
1041        let mut entity = standard_entity();
1042        entity.is_view = true;
1043        let code = gen(&entity, DatabaseKind::Postgres);
1044        assert!(!code.contains("pub async fn insert"));
1045    }
1046
1047    #[test]
1048    fn test_view_no_update() {
1049        let mut entity = standard_entity();
1050        entity.is_view = true;
1051        let code = gen(&entity, DatabaseKind::Postgres);
1052        assert!(!code.contains("pub async fn update"));
1053    }
1054
1055    #[test]
1056    fn test_view_no_delete() {
1057        let mut entity = standard_entity();
1058        entity.is_view = true;
1059        let code = gen(&entity, DatabaseKind::Postgres);
1060        assert!(!code.contains("pub async fn delete"));
1061    }
1062
1063    #[test]
1064    fn test_view_has_get_all() {
1065        let mut entity = standard_entity();
1066        entity.is_view = true;
1067        let code = gen(&entity, DatabaseKind::Postgres);
1068        assert!(code.contains("pub async fn get_all"));
1069    }
1070
1071    #[test]
1072    fn test_view_has_paginate() {
1073        let mut entity = standard_entity();
1074        entity.is_view = true;
1075        let code = gen(&entity, DatabaseKind::Postgres);
1076        assert!(code.contains("pub async fn paginate"));
1077    }
1078
1079    #[test]
1080    fn test_view_has_get() {
1081        let mut entity = standard_entity();
1082        entity.is_view = true;
1083        let code = gen(&entity, DatabaseKind::Postgres);
1084        assert!(code.contains("pub async fn get"));
1085    }
1086
1087    // --- selective methods ---
1088
1089    #[test]
1090    fn test_only_get_all() {
1091        let m = Methods { get_all: true, ..Default::default() };
1092        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1093        assert!(code.contains("pub async fn get_all"));
1094        assert!(!code.contains("pub async fn paginate"));
1095        assert!(!code.contains("pub async fn insert"));
1096    }
1097
1098    #[test]
1099    fn test_without_get_all() {
1100        let m = Methods { get_all: false, ..Methods::all() };
1101        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1102        assert!(!code.contains("pub async fn get_all"));
1103    }
1104
1105    #[test]
1106    fn test_without_paginate() {
1107        let m = Methods { paginate: false, ..Methods::all() };
1108        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1109        assert!(!code.contains("pub async fn paginate"));
1110        assert!(!code.contains("PaginateUsersParams"));
1111    }
1112
1113    #[test]
1114    fn test_without_get() {
1115        let m = Methods { get: false, ..Methods::all() };
1116        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1117        assert!(code.contains("pub async fn get_all"));
1118        let without_get_all = code.replace("get_all", "XXX");
1119        assert!(!without_get_all.contains("fn get("));
1120    }
1121
1122    #[test]
1123    fn test_without_insert() {
1124        let m = Methods { insert: false, ..Methods::all() };
1125        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1126        assert!(!code.contains("pub async fn insert"));
1127        assert!(!code.contains("InsertUsersParams"));
1128    }
1129
1130    #[test]
1131    fn test_without_update() {
1132        let m = Methods { update: false, ..Methods::all() };
1133        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1134        assert!(!code.contains("pub async fn update"));
1135        assert!(!code.contains("UpdateUsersParams"));
1136    }
1137
1138    #[test]
1139    fn test_without_delete() {
1140        let m = Methods { delete: false, ..Methods::all() };
1141        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1142        assert!(!code.contains("pub async fn delete"));
1143    }
1144
1145    #[test]
1146    fn test_empty_methods_no_methods() {
1147        let m = Methods::default();
1148        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1149        assert!(!code.contains("pub async fn get_all"));
1150        assert!(!code.contains("pub async fn paginate"));
1151        assert!(!code.contains("pub async fn insert"));
1152        assert!(!code.contains("pub async fn update"));
1153        assert!(!code.contains("pub async fn delete"));
1154    }
1155
1156    // --- imports ---
1157
1158    #[test]
1159    fn test_no_pool_import() {
1160        let skip = Methods::all();
1161        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1162        assert!(!imports.iter().any(|i| i.contains("PgPool")));
1163    }
1164
1165    #[test]
1166    fn test_imports_contain_entity() {
1167        let skip = Methods::all();
1168        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1169        assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
1170    }
1171
1172    // --- renamed columns ---
1173
1174    #[test]
1175    fn test_renamed_column_in_sql() {
1176        let entity = ParsedEntity {
1177            struct_name: "Connector".to_string(),
1178            table_name: "connector".to_string(),
1179            schema_name: None,
1180            is_view: false,
1181            fields: vec![
1182                make_field("id", "id", "i32", false, true),
1183                make_field("connector_type", "type", "String", false, false),
1184            ],
1185            imports: vec![],
1186        };
1187        let code = gen(&entity, DatabaseKind::Postgres);
1188        // INSERT should use the DB column name "type", not "connector_type"
1189        assert!(code.contains("type"));
1190        // The Rust param field should be connector_type
1191        assert!(code.contains("pub connector_type: String"));
1192    }
1193
1194    // --- no PK edge cases ---
1195
1196    #[test]
1197    fn test_no_pk_no_get() {
1198        let entity = ParsedEntity {
1199            struct_name: "Logs".to_string(),
1200            table_name: "logs".to_string(),
1201            schema_name: None,
1202            is_view: false,
1203            fields: vec![
1204                make_field("message", "message", "String", false, false),
1205                make_field("ts", "ts", "String", false, false),
1206            ],
1207            imports: vec![],
1208        };
1209        let code = gen(&entity, DatabaseKind::Postgres);
1210        assert!(code.contains("pub async fn get_all"));
1211        let without_get_all = code.replace("get_all", "XXX");
1212        assert!(!without_get_all.contains("fn get("));
1213    }
1214
1215    #[test]
1216    fn test_no_pk_no_delete() {
1217        let entity = ParsedEntity {
1218            struct_name: "Logs".to_string(),
1219            table_name: "logs".to_string(),
1220            schema_name: None,
1221            is_view: false,
1222            fields: vec![
1223                make_field("message", "message", "String", false, false),
1224            ],
1225            imports: vec![],
1226        };
1227        let code = gen(&entity, DatabaseKind::Postgres);
1228        assert!(!code.contains("pub async fn delete"));
1229    }
1230
1231    // --- Default derive on param structs ---
1232
1233    #[test]
1234    fn test_param_structs_have_default() {
1235        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1236        assert!(code.contains("Default"));
1237    }
1238
1239    // --- entity imports forwarded ---
1240
1241    #[test]
1242    fn test_entity_imports_forwarded() {
1243        let entity = ParsedEntity {
1244            struct_name: "Users".to_string(),
1245            table_name: "users".to_string(),
1246            schema_name: None,
1247            is_view: false,
1248            fields: vec![
1249                make_field("id", "id", "Uuid", false, true),
1250                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
1251            ],
1252            imports: vec![
1253                "use chrono::{DateTime, Utc};".to_string(),
1254                "use uuid::Uuid;".to_string(),
1255            ],
1256        };
1257        let skip = Methods::all();
1258        let (_, imports) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1259        assert!(imports.iter().any(|i| i.contains("chrono")));
1260        assert!(imports.iter().any(|i| i.contains("uuid")));
1261    }
1262
1263    #[test]
1264    fn test_entity_imports_empty_when_no_imports() {
1265        let skip = Methods::all();
1266        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1267        // Should only have pool + entity imports, no chrono/uuid
1268        assert!(!imports.iter().any(|i| i.contains("chrono")));
1269        assert!(!imports.iter().any(|i| i.contains("uuid")));
1270    }
1271
1272    // --- query_macro mode ---
1273
1274    #[test]
1275    fn test_macro_get_all() {
1276        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1277        assert!(code.contains("query_as!"));
1278        assert!(!code.contains("query_as::<"));
1279    }
1280
1281    #[test]
1282    fn test_macro_paginate() {
1283        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1284        assert!(code.contains("query_as!"));
1285        assert!(code.contains("per_page, offset"));
1286    }
1287
1288    #[test]
1289    fn test_macro_get() {
1290        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1291        // The get method should use query_as! with the PK as arg
1292        assert!(code.contains("query_as!(Users"));
1293    }
1294
1295    #[test]
1296    fn test_macro_insert_pg() {
1297        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1298        assert!(code.contains("query_as!(Users"));
1299        assert!(code.contains("params.name"));
1300        assert!(code.contains("params.email"));
1301    }
1302
1303    #[test]
1304    fn test_macro_insert_mysql() {
1305        let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
1306        // MySQL insert uses query! (not query_as!) for the INSERT
1307        assert!(code.contains("query!"));
1308        assert!(code.contains("query_scalar!"));
1309    }
1310
1311    #[test]
1312    fn test_macro_update() {
1313        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1314        assert!(code.contains("query_as!(Users"));
1315        // Should contain params.name, params.email, params.id as args
1316        assert!(code.contains("params.name"));
1317        assert!(code.contains("params.id"));
1318    }
1319
1320    #[test]
1321    fn test_macro_delete() {
1322        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1323        // delete uses query! (no return type)
1324        assert!(code.contains("query!"));
1325    }
1326
1327    #[test]
1328    fn test_macro_no_bind_calls() {
1329        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1330        assert!(!code.contains(".bind("));
1331    }
1332
1333    #[test]
1334    fn test_function_style_uses_bind() {
1335        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1336        assert!(code.contains(".bind("));
1337        assert!(!code.contains("query_as!("));
1338        assert!(!code.contains("query!("));
1339    }
1340
1341    // --- custom sql_type fallback: macro mode + custom type → runtime for SELECT, macro for DELETE ---
1342
1343    fn entity_with_sql_array() -> ParsedEntity {
1344        ParsedEntity {
1345            struct_name: "AgentConnector".to_string(),
1346            table_name: "agent.agent_connector".to_string(),
1347            schema_name: Some("agent".to_string()),
1348            is_view: false,
1349            fields: vec![
1350                ParsedField {
1351                    rust_name: "connector_id".to_string(),
1352                    column_name: "connector_id".to_string(),
1353                    rust_type: "Uuid".to_string(),
1354                    inner_type: "Uuid".to_string(),
1355                    is_nullable: false,
1356                    is_primary_key: true,
1357                    sql_type: None,
1358                    is_sql_array: false,
1359                },
1360                ParsedField {
1361                    rust_name: "agent_id".to_string(),
1362                    column_name: "agent_id".to_string(),
1363                    rust_type: "Uuid".to_string(),
1364                    inner_type: "Uuid".to_string(),
1365                    is_nullable: false,
1366                    is_primary_key: false,
1367                    sql_type: None,
1368                    is_sql_array: false,
1369                },
1370                ParsedField {
1371                    rust_name: "usages".to_string(),
1372                    column_name: "usages".to_string(),
1373                    rust_type: "Vec<ConnectorUsages>".to_string(),
1374                    inner_type: "Vec<ConnectorUsages>".to_string(),
1375                    is_nullable: false,
1376                    is_primary_key: false,
1377                    sql_type: Some("agent.connector_usages".to_string()),
1378                    is_sql_array: true,
1379                },
1380            ],
1381            imports: vec!["use uuid::Uuid;".to_string()],
1382        }
1383    }
1384
1385    fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
1386        let skip = Methods::all();
1387        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
1388        parse_and_format(&tokens)
1389    }
1390
1391    #[test]
1392    fn test_sql_array_macro_get_all_uses_runtime() {
1393        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1394        // get_all should use runtime query_as, not macro
1395        assert!(code.contains("query_as::<"));
1396    }
1397
1398    #[test]
1399    fn test_sql_array_macro_get_uses_runtime() {
1400        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1401        // get should use .bind( since it's runtime
1402        assert!(code.contains(".bind("));
1403    }
1404
1405    #[test]
1406    fn test_sql_array_macro_insert_uses_runtime() {
1407        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1408        // insert RETURNING should use runtime query_as
1409        assert!(code.contains("query_as::<_ , AgentConnector>") || code.contains("query_as::<_, AgentConnector>"));
1410    }
1411
1412    #[test]
1413    fn test_sql_array_macro_update_uses_runtime() {
1414        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1415        // update RETURNING should use runtime query_as
1416        assert!(code.contains("query_as::<"));
1417    }
1418
1419    #[test]
1420    fn test_sql_array_macro_delete_still_uses_macro() {
1421        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1422        // delete uses query! macro (no rows returned, no array issue)
1423        assert!(code.contains("query!"));
1424    }
1425
1426    #[test]
1427    fn test_sql_array_no_query_as_macro() {
1428        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1429        // Should NOT contain query_as! macro (only query_as::<_ for runtime)
1430        assert!(!code.contains("query_as!("));
1431    }
1432
1433    // --- custom enum (non-array) also triggers runtime fallback ---
1434
1435    fn entity_with_sql_enum() -> ParsedEntity {
1436        ParsedEntity {
1437            struct_name: "Task".to_string(),
1438            table_name: "tasks".to_string(),
1439            schema_name: None,
1440            is_view: false,
1441            fields: vec![
1442                ParsedField {
1443                    rust_name: "id".to_string(),
1444                    column_name: "id".to_string(),
1445                    rust_type: "i32".to_string(),
1446                    inner_type: "i32".to_string(),
1447                    is_nullable: false,
1448                    is_primary_key: true,
1449                    sql_type: None,
1450                    is_sql_array: false,
1451                },
1452                ParsedField {
1453                    rust_name: "status".to_string(),
1454                    column_name: "status".to_string(),
1455                    rust_type: "TaskStatus".to_string(),
1456                    inner_type: "TaskStatus".to_string(),
1457                    is_nullable: false,
1458                    is_primary_key: false,
1459                    sql_type: Some("task_status".to_string()),
1460                    is_sql_array: false,
1461                },
1462            ],
1463            imports: vec![],
1464        }
1465    }
1466
1467    #[test]
1468    fn test_sql_enum_macro_uses_runtime() {
1469        let skip = Methods::all();
1470        let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1471        let code = parse_and_format(&tokens);
1472        // SELECT queries should use runtime query_as, not macro
1473        assert!(code.contains("query_as::<"));
1474        assert!(!code.contains("query_as!("));
1475    }
1476
1477    #[test]
1478    fn test_sql_enum_macro_delete_still_uses_macro() {
1479        let skip = Methods::all();
1480        let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1481        let code = parse_and_format(&tokens);
1482        // DELETE still uses query! macro
1483        assert!(code.contains("query!"));
1484    }
1485
1486    // --- Vec<String> native array uses .as_slice() in macro mode ---
1487
1488    fn entity_with_vec_string() -> ParsedEntity {
1489        ParsedEntity {
1490            struct_name: "PromptHistory".to_string(),
1491            table_name: "prompt_history".to_string(),
1492            schema_name: None,
1493            is_view: false,
1494            fields: vec![
1495                ParsedField {
1496                    rust_name: "id".to_string(),
1497                    column_name: "id".to_string(),
1498                    rust_type: "Uuid".to_string(),
1499                    inner_type: "Uuid".to_string(),
1500                    is_nullable: false,
1501                    is_primary_key: true,
1502                    sql_type: None,
1503                    is_sql_array: false,
1504                },
1505                ParsedField {
1506                    rust_name: "content".to_string(),
1507                    column_name: "content".to_string(),
1508                    rust_type: "String".to_string(),
1509                    inner_type: "String".to_string(),
1510                    is_nullable: false,
1511                    is_primary_key: false,
1512                    sql_type: None,
1513                    is_sql_array: false,
1514                },
1515                ParsedField {
1516                    rust_name: "tags".to_string(),
1517                    column_name: "tags".to_string(),
1518                    rust_type: "Vec<String>".to_string(),
1519                    inner_type: "Vec<String>".to_string(),
1520                    is_nullable: false,
1521                    is_primary_key: false,
1522                    sql_type: None,
1523                    is_sql_array: false,
1524                },
1525            ],
1526            imports: vec!["use uuid::Uuid;".to_string()],
1527        }
1528    }
1529
1530    #[test]
1531    fn test_vec_string_macro_insert_uses_as_slice() {
1532        let skip = Methods::all();
1533        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1534        let code = parse_and_format(&tokens);
1535        assert!(code.contains("as_slice()"));
1536    }
1537
1538    #[test]
1539    fn test_vec_string_macro_update_uses_as_slice() {
1540        let skip = Methods::all();
1541        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1542        let code = parse_and_format(&tokens);
1543        // Should have as_slice() for both insert and update
1544        let count = code.matches("as_slice()").count();
1545        assert!(count >= 2, "expected at least 2 as_slice() calls (insert + update), found {}", count);
1546    }
1547
1548    #[test]
1549    fn test_vec_string_non_macro_no_as_slice() {
1550        let skip = Methods::all();
1551        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false, PoolVisibility::Private);
1552        let code = parse_and_format(&tokens);
1553        // Runtime mode uses .bind() so no as_slice needed
1554        assert!(!code.contains("as_slice()"));
1555    }
1556
1557    #[test]
1558    fn test_vec_string_parsed_from_source_uses_as_slice() {
1559        use crate::codegen::entity_parser::parse_entity_source;
1560        let source = r#"
1561            use uuid::Uuid;
1562
1563            #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
1564            #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
1565            pub struct PromptHistory {
1566                #[sqlx_gen(primary_key)]
1567                pub id: Uuid,
1568                pub content: String,
1569                pub tags: Vec<String>,
1570            }
1571        "#;
1572        let entity = parse_entity_source(source).unwrap();
1573        let skip = Methods::all();
1574        let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1575        let code = parse_and_format(&tokens);
1576        assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
1577    }
1578
1579    // --- composite PK only (junction table) ---
1580
1581    fn junction_entity() -> ParsedEntity {
1582        ParsedEntity {
1583            struct_name: "AnalysisRecord".to_string(),
1584            table_name: "analysis.analysis__record".to_string(),
1585            schema_name: None,
1586            is_view: false,
1587            fields: vec![
1588                make_field("record_id", "record_id", "uuid::Uuid", false, true),
1589                make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
1590            ],
1591            imports: vec![],
1592        }
1593    }
1594
1595    #[test]
1596    fn test_composite_pk_only_insert_generated() {
1597        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1598        assert!(code.contains("pub struct InsertAnalysisRecordParams"), "Expected InsertAnalysisRecordParams struct:\n{}", code);
1599        assert!(code.contains("pub record_id"), "Expected record_id field in insert params:\n{}", code);
1600        assert!(code.contains("pub analysis_id"), "Expected analysis_id field in insert params:\n{}", code);
1601        assert!(code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id) VALUES ($1, $2) RETURNING *"), "Expected valid INSERT SQL:\n{}", code);
1602        assert!(code.contains("pub async fn insert"), "Expected insert method:\n{}", code);
1603    }
1604
1605    #[test]
1606    fn test_composite_pk_only_no_update() {
1607        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1608        assert!(!code.contains("UpdateAnalysisRecordParams"), "Expected no UpdateAnalysisRecordParams struct:\n{}", code);
1609        assert!(!code.contains("pub async fn update"), "Expected no update method:\n{}", code);
1610    }
1611
1612    #[test]
1613    fn test_composite_pk_only_delete_generated() {
1614        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1615        assert!(code.contains("pub async fn delete"), "Expected delete method:\n{}", code);
1616        assert!(code.contains("DELETE FROM analysis.analysis__record WHERE record_id = $1 AND analysis_id = $2"), "Expected valid DELETE SQL:\n{}", code);
1617    }
1618
1619    #[test]
1620    fn test_composite_pk_only_get_generated() {
1621        let code = gen(&junction_entity(), DatabaseKind::Postgres);
1622        assert!(code.contains("pub async fn get"), "Expected get method:\n{}", code);
1623        assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause with both PK columns:\n{}", code);
1624    }
1625}