Skip to main content

sqlx_gen/codegen/
crud_gen.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10    entity: &ParsedEntity,
11    db_kind: DatabaseKind,
12    entity_module_path: &str,
13    methods: &Methods,
14    query_macro: bool,
15    pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17    let mut imports = BTreeSet::new();
18
19    let entity_ident = format_ident!("{}", entity.struct_name);
20    let repo_name = format!("{}Repository", entity.struct_name);
21    let repo_ident = format_ident!("{}", repo_name);
22
23    let table_name = match &entity.schema_name {
24        Some(schema) => format!("{}.{}", schema, entity.table_name),
25        None => entity.table_name.clone(),
26    };
27
28    // Pool type (used via full path sqlx::PgPool etc., no import needed)
29    let pool_type = pool_type_tokens(db_kind);
30
31    // When the entity has custom SQL types (enums, composites, arrays),
32    // query_as! macro can't resolve the column type at compile time. Fall back to runtime query_as::<_, T>()
33    // for queries that return rows. DELETE (no rows returned) can still use macro.
34    let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35    let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37    // Entity import
38    imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40    // Forward type imports from the entity file (chrono, uuid, etc.)
41    // Rewrite `use super::X` imports to absolute paths based on entity_module_path,
42    // since the repository lives in a different module where `super` has a different meaning.
43    let entity_parent = entity_module_path
44        .rsplit_once("::")
45        .map(|(parent, _)| parent)
46        .unwrap_or(entity_module_path);
47    for imp in &entity.imports {
48        if let Some(rest) = imp.strip_prefix("use super::") {
49            imports.insert(format!("use {}::{}", entity_parent, rest));
50        } else {
51            imports.insert(imp.clone());
52        }
53    }
54
55    // Primary key fields
56    let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58    // Non-PK fields (for insert)
59    let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61    let is_view = entity.is_view;
62
63    // Build method tokens
64    let mut method_tokens = Vec::new();
65    let mut param_structs = Vec::new();
66
67    // --- get_all ---
68    if methods.get_all {
69        let sql = format!("SELECT * FROM {}", table_name);
70        let method = if use_macro {
71            quote! {
72                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73                    sqlx::query_as!(#entity_ident, #sql)
74                        .fetch_all(&self.pool)
75                        .await
76                }
77            }
78        } else {
79            quote! {
80                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81                    sqlx::query_as::<_, #entity_ident>(#sql)
82                        .fetch_all(&self.pool)
83                        .await
84                }
85            }
86        };
87        method_tokens.push(method);
88    }
89
90    // --- paginate ---
91    if methods.paginate {
92        let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93        let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94        let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95        let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
96        let sql = match db_kind {
97            DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
98            DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
99        };
100        let method = if use_macro {
101            quote! {
102                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103                    let total: i64 = sqlx::query_scalar!(#count_sql)
104                        .fetch_one(&self.pool)
105                        .await?
106                        .unwrap_or(0);
107                    let per_page = params.per_page;
108                    let current_page = params.page;
109                    let last_page = (total + per_page - 1) / per_page;
110                    let offset = (current_page - 1) * per_page;
111                    let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112                        .fetch_all(&self.pool)
113                        .await?;
114                    Ok(#paginated_ident {
115                        meta: #pagination_meta_ident {
116                            total,
117                            per_page,
118                            current_page,
119                            last_page,
120                            first_page: 1,
121                        },
122                        data,
123                    })
124                }
125            }
126        } else {
127            quote! {
128                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129                    let total: i64 = sqlx::query_scalar(#count_sql)
130                        .fetch_one(&self.pool)
131                        .await?;
132                    let per_page = params.per_page;
133                    let current_page = params.page;
134                    let last_page = (total + per_page - 1) / per_page;
135                    let offset = (current_page - 1) * per_page;
136                    let data = sqlx::query_as::<_, #entity_ident>(#sql)
137                        .bind(per_page)
138                        .bind(offset)
139                        .fetch_all(&self.pool)
140                        .await?;
141                    Ok(#paginated_ident {
142                        meta: #pagination_meta_ident {
143                            total,
144                            per_page,
145                            current_page,
146                            last_page,
147                            first_page: 1,
148                        },
149                        data,
150                    })
151                }
152            }
153        };
154        method_tokens.push(method);
155        param_structs.push(quote! {
156            #[derive(Debug, Clone, Default)]
157            pub struct #paginate_params_ident {
158                pub page: i64,
159                pub per_page: i64,
160            }
161        });
162        param_structs.push(quote! {
163            #[derive(Debug, Clone)]
164            pub struct #pagination_meta_ident {
165                pub total: i64,
166                pub per_page: i64,
167                pub current_page: i64,
168                pub last_page: i64,
169                pub first_page: i64,
170            }
171        });
172        param_structs.push(quote! {
173            #[derive(Debug, Clone)]
174            pub struct #paginated_ident {
175                pub meta: #pagination_meta_ident,
176                pub data: Vec<#entity_ident>,
177            }
178        });
179    }
180
181    // --- get (by PK) ---
182    if methods.get && !pk_fields.is_empty() {
183        let pk_params: Vec<TokenStream> = pk_fields
184            .iter()
185            .map(|f| {
186                let name = format_ident!("{}", f.rust_name);
187                let ty: TokenStream = f.inner_type.parse().unwrap();
188                quote! { #name: #ty }
189            })
190            .collect();
191
192        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194        let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
195        let sql_macro = format!("SELECT * FROM {} WHERE {}", table_name, where_clause_cast);
196
197        let binds: Vec<TokenStream> = pk_fields
198            .iter()
199            .map(|f| {
200                let name = format_ident!("{}", f.rust_name);
201                quote! { .bind(#name) }
202            })
203            .collect();
204
205        let method = if use_macro {
206            let pk_arg_names: Vec<TokenStream> = pk_fields
207                .iter()
208                .map(|f| {
209                    let name = format_ident!("{}", f.rust_name);
210                    quote! { #name }
211                })
212                .collect();
213            quote! {
214                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215                    sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216                        .fetch_optional(&self.pool)
217                        .await
218                }
219            }
220        } else {
221            quote! {
222                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223                    sqlx::query_as::<_, #entity_ident>(#sql)
224                        #(#binds)*
225                        .fetch_optional(&self.pool)
226                        .await
227                }
228            }
229        };
230        method_tokens.push(method);
231    }
232
233    // --- insert (skip for views) ---
234    if !is_view && methods.insert && !non_pk_fields.is_empty() {
235        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237        let insert_fields: Vec<TokenStream> = non_pk_fields
238            .iter()
239            .map(|f| {
240                let name = format_ident!("{}", f.rust_name);
241                let ty: TokenStream = f.rust_type.parse().unwrap();
242                quote! { pub #name: #ty, }
243            })
244            .collect();
245
246        let col_names: Vec<&str> = non_pk_fields.iter().map(|f| f.column_name.as_str()).collect();
247        let col_list = col_names.join(", ");
248        // Use casted placeholders for macro mode, plain for runtime
249        let placeholders = build_placeholders(non_pk_fields.len(), db_kind, 1);
250        let placeholders_cast = build_placeholders_with_cast(&non_pk_fields, db_kind, 1, true);
251
252        let build_insert_sql = |ph: &str| match db_kind {
253            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
254                format!(
255                    "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
256                    table_name, col_list, ph
257                )
258            }
259            DatabaseKind::Mysql => {
260                format!(
261                    "INSERT INTO {} ({}) VALUES ({})",
262                    table_name, col_list, ph
263                )
264            }
265        };
266        let sql = build_insert_sql(&placeholders);
267        let sql_macro = build_insert_sql(&placeholders_cast);
268
269        let binds: Vec<TokenStream> = non_pk_fields
270            .iter()
271            .map(|f| {
272                let name = format_ident!("{}", f.rust_name);
273                quote! { .bind(&params.#name) }
274            })
275            .collect();
276
277        let insert_method = build_insert_method_parsed(
278            &entity_ident,
279            &insert_params_ident,
280            &sql,
281            &sql_macro,
282            &binds,
283            db_kind,
284            &table_name,
285            &pk_fields,
286            &non_pk_fields,
287            use_macro,
288        );
289        method_tokens.push(insert_method);
290
291        param_structs.push(quote! {
292            #[derive(Debug, Clone, Default)]
293            pub struct #insert_params_ident {
294                #(#insert_fields)*
295            }
296        });
297    }
298
299    // --- update (skip for views) ---
300    if !is_view && methods.update && !pk_fields.is_empty() {
301        let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
302
303        let update_fields: Vec<TokenStream> = entity
304            .fields
305            .iter()
306            .map(|f| {
307                let name = format_ident!("{}", f.rust_name);
308                let ty: TokenStream = f.rust_type.parse().unwrap();
309                quote! { pub #name: #ty, }
310            })
311            .collect();
312
313        let set_cols: Vec<String> = non_pk_fields
314            .iter()
315            .enumerate()
316            .map(|(i, f)| {
317                let p = placeholder(db_kind, i + 1);
318                format!("{} = {}", f.column_name, p)
319            })
320            .collect();
321        let set_clause = set_cols.join(", ");
322
323        // SET clause with casts for macro mode
324        let set_cols_cast: Vec<String> = non_pk_fields
325            .iter()
326            .enumerate()
327            .map(|(i, f)| {
328                let p = placeholder_with_cast(db_kind, i + 1, f);
329                format!("{} = {}", f.column_name, p)
330            })
331            .collect();
332        let set_clause_cast = set_cols_cast.join(", ");
333
334        let pk_start = non_pk_fields.len() + 1;
335        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
336        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
337
338        let build_update_sql = |sc: &str, wc: &str| match db_kind {
339            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
340                format!(
341                    "UPDATE {} SET {} WHERE {} RETURNING *",
342                    table_name, sc, wc
343                )
344            }
345            DatabaseKind::Mysql => {
346                format!(
347                    "UPDATE {} SET {} WHERE {}",
348                    table_name, sc, wc
349                )
350            }
351        };
352        let sql = build_update_sql(&set_clause, &where_clause);
353        let sql_macro = build_update_sql(&set_clause_cast, &where_clause_cast);
354
355        // Bind non-PK first, then PK
356        let mut all_binds: Vec<TokenStream> = non_pk_fields
357            .iter()
358            .map(|f| {
359                let name = format_ident!("{}", f.rust_name);
360                quote! { .bind(&params.#name) }
361            })
362            .collect();
363        for f in &pk_fields {
364            let name = format_ident!("{}", f.rust_name);
365            all_binds.push(quote! { .bind(&params.#name) });
366        }
367
368        // Macro args: non-PK fields first, then PK fields
369        let update_macro_args: Vec<TokenStream> = non_pk_fields
370            .iter()
371            .chain(pk_fields.iter())
372            .map(|f| macro_arg_for_field(f))
373            .collect();
374
375        let update_method = if use_macro {
376            match db_kind {
377                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
378                    quote! {
379                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
380                            sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
381                                .fetch_one(&self.pool)
382                                .await
383                        }
384                    }
385                }
386                DatabaseKind::Mysql => {
387                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
388                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
389                    let pk_macro_args: Vec<TokenStream> = pk_fields
390                        .iter()
391                        .map(|f| {
392                            let name = format_ident!("{}", f.rust_name);
393                            quote! { params.#name }
394                        })
395                        .collect();
396                    quote! {
397                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
398                            sqlx::query!(#sql_macro, #(#update_macro_args),*)
399                                .execute(&self.pool)
400                                .await?;
401                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
402                                .fetch_one(&self.pool)
403                                .await
404                        }
405                    }
406                }
407            }
408        } else {
409            match db_kind {
410                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
411                    quote! {
412                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
413                            sqlx::query_as::<_, #entity_ident>(#sql)
414                                #(#all_binds)*
415                                .fetch_one(&self.pool)
416                                .await
417                        }
418                    }
419                }
420                DatabaseKind::Mysql => {
421                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
422                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
423                    let pk_binds: Vec<TokenStream> = pk_fields
424                        .iter()
425                        .map(|f| {
426                            let name = format_ident!("{}", f.rust_name);
427                            quote! { .bind(&params.#name) }
428                        })
429                        .collect();
430                    quote! {
431                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
432                            sqlx::query(#sql)
433                                #(#all_binds)*
434                                .execute(&self.pool)
435                                .await?;
436                            sqlx::query_as::<_, #entity_ident>(#select_sql)
437                                #(#pk_binds)*
438                                .fetch_one(&self.pool)
439                                .await
440                        }
441                    }
442                }
443            }
444        };
445        method_tokens.push(update_method);
446
447        param_structs.push(quote! {
448            #[derive(Debug, Clone, Default)]
449            pub struct #update_params_ident {
450                #(#update_fields)*
451            }
452        });
453    }
454
455    // --- delete (skip for views) ---
456    if !is_view && methods.delete && !pk_fields.is_empty() {
457        let pk_params: Vec<TokenStream> = pk_fields
458            .iter()
459            .map(|f| {
460                let name = format_ident!("{}", f.rust_name);
461                let ty: TokenStream = f.inner_type.parse().unwrap();
462                quote! { #name: #ty }
463            })
464            .collect();
465
466        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
467        let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
468        let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
469        let sql_macro = format!("DELETE FROM {} WHERE {}", table_name, where_clause_cast);
470
471        let binds: Vec<TokenStream> = pk_fields
472            .iter()
473            .map(|f| {
474                let name = format_ident!("{}", f.rust_name);
475                quote! { .bind(#name) }
476            })
477            .collect();
478
479        let method = if query_macro {
480            let pk_arg_names: Vec<TokenStream> = pk_fields
481                .iter()
482                .map(|f| {
483                    let name = format_ident!("{}", f.rust_name);
484                    quote! { #name }
485                })
486                .collect();
487            quote! {
488                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
489                    sqlx::query!(#sql_macro, #(#pk_arg_names),*)
490                        .execute(&self.pool)
491                        .await?;
492                    Ok(())
493                }
494            }
495        } else {
496            quote! {
497                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
498                    sqlx::query(#sql)
499                        #(#binds)*
500                        .execute(&self.pool)
501                        .await?;
502                    Ok(())
503                }
504            }
505        };
506        method_tokens.push(method);
507    }
508
509    let pool_vis: TokenStream = match pool_visibility {
510        PoolVisibility::Private => quote! {},
511        PoolVisibility::Pub => quote! { pub },
512        PoolVisibility::PubCrate => quote! { pub(crate) },
513    };
514
515    let tokens = quote! {
516        #(#param_structs)*
517
518        pub struct #repo_ident {
519            #pool_vis pool: #pool_type,
520        }
521
522        impl #repo_ident {
523            pub fn new(pool: #pool_type) -> Self {
524                Self { pool }
525            }
526
527            #(#method_tokens)*
528        }
529    };
530
531    (tokens, imports)
532}
533
534fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
535    match db_kind {
536        DatabaseKind::Postgres => quote! { sqlx::PgPool },
537        DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
538        DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
539    }
540}
541
542fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
543    match db_kind {
544        DatabaseKind::Postgres => format!("${}", index),
545        DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
546    }
547}
548
549fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
550    let base = placeholder(db_kind, index);
551    match (&field.sql_type, field.is_sql_array) {
552        (Some(t), true) => format!("{} as {}[]", base, t),
553        (Some(t), false) => format!("{} as {}", base, t),
554        (None, _) => base,
555    }
556}
557
558fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
559    (0..count)
560        .map(|i| placeholder(db_kind, start + i))
561        .collect::<Vec<_>>()
562        .join(", ")
563}
564
565fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
566    fields
567        .iter()
568        .enumerate()
569        .map(|(i, f)| {
570            if use_cast {
571                placeholder_with_cast(db_kind, start + i, f)
572            } else {
573                placeholder(db_kind, start + i)
574            }
575        })
576        .collect::<Vec<_>>()
577        .join(", ")
578}
579
580fn build_where_clause_parsed(
581    pk_fields: &[&ParsedField],
582    db_kind: DatabaseKind,
583    start_index: usize,
584) -> String {
585    pk_fields
586        .iter()
587        .enumerate()
588        .map(|(i, f)| {
589            let p = placeholder(db_kind, start_index + i);
590            format!("{} = {}", f.column_name, p)
591        })
592        .collect::<Vec<_>>()
593        .join(" AND ")
594}
595
596fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
597    let name = format_ident!("{}", field.rust_name);
598    let check_type = if field.is_nullable {
599        &field.inner_type
600    } else {
601        &field.rust_type
602    };
603    let normalized = check_type.replace(' ', "");
604    if normalized.starts_with("Vec<") {
605        quote! { params.#name.as_slice() }
606    } else {
607        quote! { params.#name }
608    }
609}
610
611fn build_where_clause_cast(
612    pk_fields: &[&ParsedField],
613    db_kind: DatabaseKind,
614    start_index: usize,
615) -> String {
616    pk_fields
617        .iter()
618        .enumerate()
619        .map(|(i, f)| {
620            let p = placeholder_with_cast(db_kind, start_index + i, f);
621            format!("{} = {}", f.column_name, p)
622        })
623        .collect::<Vec<_>>()
624        .join(" AND ")
625}
626
627#[allow(clippy::too_many_arguments)]
628fn build_insert_method_parsed(
629    entity_ident: &proc_macro2::Ident,
630    insert_params_ident: &proc_macro2::Ident,
631    sql: &str,
632    sql_macro: &str,
633    binds: &[TokenStream],
634    db_kind: DatabaseKind,
635    table_name: &str,
636    pk_fields: &[&ParsedField],
637    non_pk_fields: &[&ParsedField],
638    use_macro: bool,
639) -> TokenStream {
640    if use_macro {
641        let macro_args: Vec<TokenStream> = non_pk_fields
642            .iter()
643            .map(|f| macro_arg_for_field(f))
644            .collect();
645
646        match db_kind {
647            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
648                quote! {
649                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
650                        sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
651                            .fetch_one(&self.pool)
652                            .await
653                    }
654                }
655            }
656            DatabaseKind::Mysql => {
657                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
658                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
659                quote! {
660                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
661                        sqlx::query!(#sql_macro, #(#macro_args),*)
662                            .execute(&self.pool)
663                            .await?;
664                        let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
665                            .fetch_one(&self.pool)
666                            .await?;
667                        sqlx::query_as!(#entity_ident, #select_sql, id)
668                            .fetch_one(&self.pool)
669                            .await
670                    }
671                }
672            }
673        }
674    } else {
675        match db_kind {
676            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
677                quote! {
678                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
679                        sqlx::query_as::<_, #entity_ident>(#sql)
680                            #(#binds)*
681                            .fetch_one(&self.pool)
682                            .await
683                    }
684                }
685            }
686            DatabaseKind::Mysql => {
687                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
688                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
689                quote! {
690                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
691                        sqlx::query(#sql)
692                            #(#binds)*
693                            .execute(&self.pool)
694                            .await?;
695                        let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
696                            .fetch_one(&self.pool)
697                            .await?;
698                        sqlx::query_as::<_, #entity_ident>(#select_sql)
699                            .bind(id)
700                            .fetch_one(&self.pool)
701                            .await
702                    }
703                }
704            }
705        }
706    }
707}
708
709#[cfg(test)]
710mod tests {
711    use super::*;
712    use crate::codegen::parse_and_format;
713    use crate::cli::Methods;
714
715    fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
716        let inner_type = if nullable {
717            // Strip "Option<" prefix and ">" suffix
718            rust_type
719                .strip_prefix("Option<")
720                .and_then(|s| s.strip_suffix('>'))
721                .unwrap_or(rust_type)
722                .to_string()
723        } else {
724            rust_type.to_string()
725        };
726        ParsedField {
727            rust_name: rust_name.to_string(),
728            column_name: column_name.to_string(),
729            rust_type: rust_type.to_string(),
730            is_nullable: nullable,
731            inner_type,
732            is_primary_key: is_pk,
733            sql_type: None,
734            is_sql_array: false,
735        }
736    }
737
738    fn standard_entity() -> ParsedEntity {
739        ParsedEntity {
740            struct_name: "Users".to_string(),
741            table_name: "users".to_string(),
742            schema_name: None,
743            is_view: false,
744            fields: vec![
745                make_field("id", "id", "i32", false, true),
746                make_field("name", "name", "String", false, false),
747                make_field("email", "email", "Option<String>", true, false),
748            ],
749            imports: vec![],
750        }
751    }
752
753    fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
754        let skip = Methods::all();
755        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
756        parse_and_format(&tokens)
757    }
758
759    fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
760        let skip = Methods::all();
761        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
762        parse_and_format(&tokens)
763    }
764
765    fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
766        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
767        parse_and_format(&tokens)
768    }
769
770    // --- basic structure ---
771
772    #[test]
773    fn test_repo_struct_name() {
774        let code = gen(&standard_entity(), DatabaseKind::Postgres);
775        assert!(code.contains("pub struct UsersRepository"));
776    }
777
778    #[test]
779    fn test_repo_new_method() {
780        let code = gen(&standard_entity(), DatabaseKind::Postgres);
781        assert!(code.contains("pub fn new("));
782    }
783
784    #[test]
785    fn test_repo_pool_field_pg() {
786        let code = gen(&standard_entity(), DatabaseKind::Postgres);
787        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
788    }
789
790    #[test]
791    fn test_repo_pool_field_pub() {
792        let skip = Methods::all();
793        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Pub);
794        let code = parse_and_format(&tokens);
795        assert!(code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool"));
796    }
797
798    #[test]
799    fn test_repo_pool_field_pub_crate() {
800        let skip = Methods::all();
801        let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::PubCrate);
802        let code = parse_and_format(&tokens);
803        assert!(code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool"));
804    }
805
806    #[test]
807    fn test_repo_pool_field_private() {
808        let code = gen(&standard_entity(), DatabaseKind::Postgres);
809        // Should NOT have `pub pool` or `pub(crate) pool`
810        assert!(!code.contains("pub pool"));
811        assert!(!code.contains("pub(crate) pool"));
812    }
813
814    #[test]
815    fn test_repo_pool_field_mysql() {
816        let code = gen(&standard_entity(), DatabaseKind::Mysql);
817        assert!(code.contains("MySqlPool") || code.contains("MySql"));
818    }
819
820    #[test]
821    fn test_repo_pool_field_sqlite() {
822        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
823        assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
824    }
825
826    // --- get_all ---
827
828    #[test]
829    fn test_get_all_method() {
830        let code = gen(&standard_entity(), DatabaseKind::Postgres);
831        assert!(code.contains("pub async fn get_all"));
832    }
833
834    #[test]
835    fn test_get_all_returns_vec() {
836        let code = gen(&standard_entity(), DatabaseKind::Postgres);
837        assert!(code.contains("Vec<Users>"));
838    }
839
840    #[test]
841    fn test_get_all_sql() {
842        let code = gen(&standard_entity(), DatabaseKind::Postgres);
843        assert!(code.contains("SELECT * FROM users"));
844    }
845
846    // --- paginate ---
847
848    #[test]
849    fn test_paginate_method() {
850        let code = gen(&standard_entity(), DatabaseKind::Postgres);
851        assert!(code.contains("pub async fn paginate"));
852    }
853
854    #[test]
855    fn test_paginate_params_struct() {
856        let code = gen(&standard_entity(), DatabaseKind::Postgres);
857        assert!(code.contains("pub struct PaginateUsersParams"));
858    }
859
860    #[test]
861    fn test_paginate_params_fields() {
862        let code = gen(&standard_entity(), DatabaseKind::Postgres);
863        assert!(code.contains("pub page: i64"));
864        assert!(code.contains("pub per_page: i64"));
865    }
866
867    #[test]
868    fn test_paginate_returns_paginated() {
869        let code = gen(&standard_entity(), DatabaseKind::Postgres);
870        assert!(code.contains("PaginatedUsers"));
871        assert!(code.contains("PaginationUsersMeta"));
872    }
873
874    #[test]
875    fn test_paginate_meta_struct() {
876        let code = gen(&standard_entity(), DatabaseKind::Postgres);
877        assert!(code.contains("pub struct PaginationUsersMeta"));
878        assert!(code.contains("pub total: i64"));
879        assert!(code.contains("pub last_page: i64"));
880        assert!(code.contains("pub first_page: i64"));
881        assert!(code.contains("pub current_page: i64"));
882    }
883
884    #[test]
885    fn test_paginate_data_struct() {
886        let code = gen(&standard_entity(), DatabaseKind::Postgres);
887        assert!(code.contains("pub struct PaginatedUsers"));
888        assert!(code.contains("pub meta: PaginationUsersMeta"));
889        assert!(code.contains("pub data: Vec<Users>"));
890    }
891
892    #[test]
893    fn test_paginate_count_sql() {
894        let code = gen(&standard_entity(), DatabaseKind::Postgres);
895        assert!(code.contains("SELECT COUNT(*) FROM users"));
896    }
897
898    #[test]
899    fn test_paginate_sql_pg() {
900        let code = gen(&standard_entity(), DatabaseKind::Postgres);
901        assert!(code.contains("LIMIT $1 OFFSET $2"));
902    }
903
904    #[test]
905    fn test_paginate_sql_mysql() {
906        let code = gen(&standard_entity(), DatabaseKind::Mysql);
907        assert!(code.contains("LIMIT ? OFFSET ?"));
908    }
909
910    // --- get ---
911
912    #[test]
913    fn test_get_method() {
914        let code = gen(&standard_entity(), DatabaseKind::Postgres);
915        assert!(code.contains("pub async fn get"));
916    }
917
918    #[test]
919    fn test_get_returns_option() {
920        let code = gen(&standard_entity(), DatabaseKind::Postgres);
921        assert!(code.contains("Option<Users>"));
922    }
923
924    #[test]
925    fn test_get_where_pk_pg() {
926        let code = gen(&standard_entity(), DatabaseKind::Postgres);
927        assert!(code.contains("WHERE id = $1"));
928    }
929
930    #[test]
931    fn test_get_where_pk_mysql() {
932        let code = gen(&standard_entity(), DatabaseKind::Mysql);
933        assert!(code.contains("WHERE id = ?"));
934    }
935
936    // --- insert ---
937
938    #[test]
939    fn test_insert_method() {
940        let code = gen(&standard_entity(), DatabaseKind::Postgres);
941        assert!(code.contains("pub async fn insert"));
942    }
943
944    #[test]
945    fn test_insert_params_struct() {
946        let code = gen(&standard_entity(), DatabaseKind::Postgres);
947        assert!(code.contains("pub struct InsertUsersParams"));
948    }
949
950    #[test]
951    fn test_insert_params_no_pk() {
952        let code = gen(&standard_entity(), DatabaseKind::Postgres);
953        assert!(code.contains("pub name: String"));
954        assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
955    }
956
957    #[test]
958    fn test_insert_returning_pg() {
959        let code = gen(&standard_entity(), DatabaseKind::Postgres);
960        assert!(code.contains("RETURNING *"));
961    }
962
963    #[test]
964    fn test_insert_returning_sqlite() {
965        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
966        assert!(code.contains("RETURNING *"));
967    }
968
969    #[test]
970    fn test_insert_mysql_last_insert_id() {
971        let code = gen(&standard_entity(), DatabaseKind::Mysql);
972        assert!(code.contains("LAST_INSERT_ID"));
973    }
974
975    // --- update ---
976
977    #[test]
978    fn test_update_method() {
979        let code = gen(&standard_entity(), DatabaseKind::Postgres);
980        assert!(code.contains("pub async fn update"));
981    }
982
983    #[test]
984    fn test_update_params_struct() {
985        let code = gen(&standard_entity(), DatabaseKind::Postgres);
986        assert!(code.contains("pub struct UpdateUsersParams"));
987    }
988
989    #[test]
990    fn test_update_params_all_cols() {
991        let code = gen(&standard_entity(), DatabaseKind::Postgres);
992        assert!(code.contains("pub id: i32"));
993        assert!(code.contains("pub name: String"));
994    }
995
996    #[test]
997    fn test_update_set_clause_pg() {
998        let code = gen(&standard_entity(), DatabaseKind::Postgres);
999        assert!(code.contains("SET name = $1"));
1000        assert!(code.contains("WHERE id = $3"));
1001    }
1002
1003    #[test]
1004    fn test_update_returning_pg() {
1005        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1006        assert!(code.contains("UPDATE users SET"));
1007        assert!(code.contains("RETURNING *"));
1008    }
1009
1010    // --- delete ---
1011
1012    #[test]
1013    fn test_delete_method() {
1014        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1015        assert!(code.contains("pub async fn delete"));
1016    }
1017
1018    #[test]
1019    fn test_delete_where_pk() {
1020        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1021        assert!(code.contains("DELETE FROM users WHERE id = $1"));
1022    }
1023
1024    #[test]
1025    fn test_delete_returns_unit() {
1026        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1027        assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
1028    }
1029
1030    // --- views (read-only) ---
1031
1032    #[test]
1033    fn test_view_no_insert() {
1034        let mut entity = standard_entity();
1035        entity.is_view = true;
1036        let code = gen(&entity, DatabaseKind::Postgres);
1037        assert!(!code.contains("pub async fn insert"));
1038    }
1039
1040    #[test]
1041    fn test_view_no_update() {
1042        let mut entity = standard_entity();
1043        entity.is_view = true;
1044        let code = gen(&entity, DatabaseKind::Postgres);
1045        assert!(!code.contains("pub async fn update"));
1046    }
1047
1048    #[test]
1049    fn test_view_no_delete() {
1050        let mut entity = standard_entity();
1051        entity.is_view = true;
1052        let code = gen(&entity, DatabaseKind::Postgres);
1053        assert!(!code.contains("pub async fn delete"));
1054    }
1055
1056    #[test]
1057    fn test_view_has_get_all() {
1058        let mut entity = standard_entity();
1059        entity.is_view = true;
1060        let code = gen(&entity, DatabaseKind::Postgres);
1061        assert!(code.contains("pub async fn get_all"));
1062    }
1063
1064    #[test]
1065    fn test_view_has_paginate() {
1066        let mut entity = standard_entity();
1067        entity.is_view = true;
1068        let code = gen(&entity, DatabaseKind::Postgres);
1069        assert!(code.contains("pub async fn paginate"));
1070    }
1071
1072    #[test]
1073    fn test_view_has_get() {
1074        let mut entity = standard_entity();
1075        entity.is_view = true;
1076        let code = gen(&entity, DatabaseKind::Postgres);
1077        assert!(code.contains("pub async fn get"));
1078    }
1079
1080    // --- selective methods ---
1081
1082    #[test]
1083    fn test_only_get_all() {
1084        let m = Methods { get_all: true, ..Default::default() };
1085        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1086        assert!(code.contains("pub async fn get_all"));
1087        assert!(!code.contains("pub async fn paginate"));
1088        assert!(!code.contains("pub async fn insert"));
1089    }
1090
1091    #[test]
1092    fn test_without_get_all() {
1093        let m = Methods { get_all: false, ..Methods::all() };
1094        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1095        assert!(!code.contains("pub async fn get_all"));
1096    }
1097
1098    #[test]
1099    fn test_without_paginate() {
1100        let m = Methods { paginate: false, ..Methods::all() };
1101        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1102        assert!(!code.contains("pub async fn paginate"));
1103        assert!(!code.contains("PaginateUsersParams"));
1104    }
1105
1106    #[test]
1107    fn test_without_get() {
1108        let m = Methods { get: false, ..Methods::all() };
1109        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1110        assert!(code.contains("pub async fn get_all"));
1111        let without_get_all = code.replace("get_all", "XXX");
1112        assert!(!without_get_all.contains("fn get("));
1113    }
1114
1115    #[test]
1116    fn test_without_insert() {
1117        let m = Methods { insert: false, ..Methods::all() };
1118        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1119        assert!(!code.contains("pub async fn insert"));
1120        assert!(!code.contains("InsertUsersParams"));
1121    }
1122
1123    #[test]
1124    fn test_without_update() {
1125        let m = Methods { update: false, ..Methods::all() };
1126        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1127        assert!(!code.contains("pub async fn update"));
1128        assert!(!code.contains("UpdateUsersParams"));
1129    }
1130
1131    #[test]
1132    fn test_without_delete() {
1133        let m = Methods { delete: false, ..Methods::all() };
1134        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1135        assert!(!code.contains("pub async fn delete"));
1136    }
1137
1138    #[test]
1139    fn test_empty_methods_no_methods() {
1140        let m = Methods::default();
1141        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1142        assert!(!code.contains("pub async fn get_all"));
1143        assert!(!code.contains("pub async fn paginate"));
1144        assert!(!code.contains("pub async fn insert"));
1145        assert!(!code.contains("pub async fn update"));
1146        assert!(!code.contains("pub async fn delete"));
1147    }
1148
1149    // --- imports ---
1150
1151    #[test]
1152    fn test_no_pool_import() {
1153        let skip = Methods::all();
1154        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1155        assert!(!imports.iter().any(|i| i.contains("PgPool")));
1156    }
1157
1158    #[test]
1159    fn test_imports_contain_entity() {
1160        let skip = Methods::all();
1161        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1162        assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
1163    }
1164
1165    // --- renamed columns ---
1166
1167    #[test]
1168    fn test_renamed_column_in_sql() {
1169        let entity = ParsedEntity {
1170            struct_name: "Connector".to_string(),
1171            table_name: "connector".to_string(),
1172            schema_name: None,
1173            is_view: false,
1174            fields: vec![
1175                make_field("id", "id", "i32", false, true),
1176                make_field("connector_type", "type", "String", false, false),
1177            ],
1178            imports: vec![],
1179        };
1180        let code = gen(&entity, DatabaseKind::Postgres);
1181        // INSERT should use the DB column name "type", not "connector_type"
1182        assert!(code.contains("type"));
1183        // The Rust param field should be connector_type
1184        assert!(code.contains("pub connector_type: String"));
1185    }
1186
1187    // --- no PK edge cases ---
1188
1189    #[test]
1190    fn test_no_pk_no_get() {
1191        let entity = ParsedEntity {
1192            struct_name: "Logs".to_string(),
1193            table_name: "logs".to_string(),
1194            schema_name: None,
1195            is_view: false,
1196            fields: vec![
1197                make_field("message", "message", "String", false, false),
1198                make_field("ts", "ts", "String", false, false),
1199            ],
1200            imports: vec![],
1201        };
1202        let code = gen(&entity, DatabaseKind::Postgres);
1203        assert!(code.contains("pub async fn get_all"));
1204        let without_get_all = code.replace("get_all", "XXX");
1205        assert!(!without_get_all.contains("fn get("));
1206    }
1207
1208    #[test]
1209    fn test_no_pk_no_delete() {
1210        let entity = ParsedEntity {
1211            struct_name: "Logs".to_string(),
1212            table_name: "logs".to_string(),
1213            schema_name: None,
1214            is_view: false,
1215            fields: vec![
1216                make_field("message", "message", "String", false, false),
1217            ],
1218            imports: vec![],
1219        };
1220        let code = gen(&entity, DatabaseKind::Postgres);
1221        assert!(!code.contains("pub async fn delete"));
1222    }
1223
1224    // --- Default derive on param structs ---
1225
1226    #[test]
1227    fn test_param_structs_have_default() {
1228        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1229        assert!(code.contains("Default"));
1230    }
1231
1232    // --- entity imports forwarded ---
1233
1234    #[test]
1235    fn test_entity_imports_forwarded() {
1236        let entity = ParsedEntity {
1237            struct_name: "Users".to_string(),
1238            table_name: "users".to_string(),
1239            schema_name: None,
1240            is_view: false,
1241            fields: vec![
1242                make_field("id", "id", "Uuid", false, true),
1243                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
1244            ],
1245            imports: vec![
1246                "use chrono::{DateTime, Utc};".to_string(),
1247                "use uuid::Uuid;".to_string(),
1248            ],
1249        };
1250        let skip = Methods::all();
1251        let (_, imports) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1252        assert!(imports.iter().any(|i| i.contains("chrono")));
1253        assert!(imports.iter().any(|i| i.contains("uuid")));
1254    }
1255
1256    #[test]
1257    fn test_entity_imports_empty_when_no_imports() {
1258        let skip = Methods::all();
1259        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1260        // Should only have pool + entity imports, no chrono/uuid
1261        assert!(!imports.iter().any(|i| i.contains("chrono")));
1262        assert!(!imports.iter().any(|i| i.contains("uuid")));
1263    }
1264
1265    // --- query_macro mode ---
1266
1267    #[test]
1268    fn test_macro_get_all() {
1269        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1270        assert!(code.contains("query_as!"));
1271        assert!(!code.contains("query_as::<"));
1272    }
1273
1274    #[test]
1275    fn test_macro_paginate() {
1276        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1277        assert!(code.contains("query_as!"));
1278        assert!(code.contains("per_page, offset"));
1279    }
1280
1281    #[test]
1282    fn test_macro_get() {
1283        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1284        // The get method should use query_as! with the PK as arg
1285        assert!(code.contains("query_as!(Users"));
1286    }
1287
1288    #[test]
1289    fn test_macro_insert_pg() {
1290        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1291        assert!(code.contains("query_as!(Users"));
1292        assert!(code.contains("params.name"));
1293        assert!(code.contains("params.email"));
1294    }
1295
1296    #[test]
1297    fn test_macro_insert_mysql() {
1298        let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
1299        // MySQL insert uses query! (not query_as!) for the INSERT
1300        assert!(code.contains("query!"));
1301        assert!(code.contains("query_scalar!"));
1302    }
1303
1304    #[test]
1305    fn test_macro_update() {
1306        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1307        assert!(code.contains("query_as!(Users"));
1308        // Should contain params.name, params.email, params.id as args
1309        assert!(code.contains("params.name"));
1310        assert!(code.contains("params.id"));
1311    }
1312
1313    #[test]
1314    fn test_macro_delete() {
1315        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1316        // delete uses query! (no return type)
1317        assert!(code.contains("query!"));
1318    }
1319
1320    #[test]
1321    fn test_macro_no_bind_calls() {
1322        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1323        assert!(!code.contains(".bind("));
1324    }
1325
1326    #[test]
1327    fn test_function_style_uses_bind() {
1328        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1329        assert!(code.contains(".bind("));
1330        assert!(!code.contains("query_as!("));
1331        assert!(!code.contains("query!("));
1332    }
1333
1334    // --- custom sql_type fallback: macro mode + custom type → runtime for SELECT, macro for DELETE ---
1335
1336    fn entity_with_sql_array() -> ParsedEntity {
1337        ParsedEntity {
1338            struct_name: "AgentConnector".to_string(),
1339            table_name: "agent.agent_connector".to_string(),
1340            schema_name: Some("agent".to_string()),
1341            is_view: false,
1342            fields: vec![
1343                ParsedField {
1344                    rust_name: "connector_id".to_string(),
1345                    column_name: "connector_id".to_string(),
1346                    rust_type: "Uuid".to_string(),
1347                    inner_type: "Uuid".to_string(),
1348                    is_nullable: false,
1349                    is_primary_key: true,
1350                    sql_type: None,
1351                    is_sql_array: false,
1352                },
1353                ParsedField {
1354                    rust_name: "agent_id".to_string(),
1355                    column_name: "agent_id".to_string(),
1356                    rust_type: "Uuid".to_string(),
1357                    inner_type: "Uuid".to_string(),
1358                    is_nullable: false,
1359                    is_primary_key: false,
1360                    sql_type: None,
1361                    is_sql_array: false,
1362                },
1363                ParsedField {
1364                    rust_name: "usages".to_string(),
1365                    column_name: "usages".to_string(),
1366                    rust_type: "Vec<ConnectorUsages>".to_string(),
1367                    inner_type: "Vec<ConnectorUsages>".to_string(),
1368                    is_nullable: false,
1369                    is_primary_key: false,
1370                    sql_type: Some("agent.connector_usages".to_string()),
1371                    is_sql_array: true,
1372                },
1373            ],
1374            imports: vec!["use uuid::Uuid;".to_string()],
1375        }
1376    }
1377
1378    fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
1379        let skip = Methods::all();
1380        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
1381        parse_and_format(&tokens)
1382    }
1383
1384    #[test]
1385    fn test_sql_array_macro_get_all_uses_runtime() {
1386        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1387        // get_all should use runtime query_as, not macro
1388        assert!(code.contains("query_as::<"));
1389    }
1390
1391    #[test]
1392    fn test_sql_array_macro_get_uses_runtime() {
1393        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1394        // get should use .bind( since it's runtime
1395        assert!(code.contains(".bind("));
1396    }
1397
1398    #[test]
1399    fn test_sql_array_macro_insert_uses_runtime() {
1400        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1401        // insert RETURNING should use runtime query_as
1402        assert!(code.contains("query_as::<_ , AgentConnector>") || code.contains("query_as::<_, AgentConnector>"));
1403    }
1404
1405    #[test]
1406    fn test_sql_array_macro_update_uses_runtime() {
1407        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1408        // update RETURNING should use runtime query_as
1409        assert!(code.contains("query_as::<"));
1410    }
1411
1412    #[test]
1413    fn test_sql_array_macro_delete_still_uses_macro() {
1414        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1415        // delete uses query! macro (no rows returned, no array issue)
1416        assert!(code.contains("query!"));
1417    }
1418
1419    #[test]
1420    fn test_sql_array_no_query_as_macro() {
1421        let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1422        // Should NOT contain query_as! macro (only query_as::<_ for runtime)
1423        assert!(!code.contains("query_as!("));
1424    }
1425
1426    // --- custom enum (non-array) also triggers runtime fallback ---
1427
1428    fn entity_with_sql_enum() -> ParsedEntity {
1429        ParsedEntity {
1430            struct_name: "Task".to_string(),
1431            table_name: "tasks".to_string(),
1432            schema_name: None,
1433            is_view: false,
1434            fields: vec![
1435                ParsedField {
1436                    rust_name: "id".to_string(),
1437                    column_name: "id".to_string(),
1438                    rust_type: "i32".to_string(),
1439                    inner_type: "i32".to_string(),
1440                    is_nullable: false,
1441                    is_primary_key: true,
1442                    sql_type: None,
1443                    is_sql_array: false,
1444                },
1445                ParsedField {
1446                    rust_name: "status".to_string(),
1447                    column_name: "status".to_string(),
1448                    rust_type: "TaskStatus".to_string(),
1449                    inner_type: "TaskStatus".to_string(),
1450                    is_nullable: false,
1451                    is_primary_key: false,
1452                    sql_type: Some("task_status".to_string()),
1453                    is_sql_array: false,
1454                },
1455            ],
1456            imports: vec![],
1457        }
1458    }
1459
1460    #[test]
1461    fn test_sql_enum_macro_uses_runtime() {
1462        let skip = Methods::all();
1463        let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1464        let code = parse_and_format(&tokens);
1465        // SELECT queries should use runtime query_as, not macro
1466        assert!(code.contains("query_as::<"));
1467        assert!(!code.contains("query_as!("));
1468    }
1469
1470    #[test]
1471    fn test_sql_enum_macro_delete_still_uses_macro() {
1472        let skip = Methods::all();
1473        let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1474        let code = parse_and_format(&tokens);
1475        // DELETE still uses query! macro
1476        assert!(code.contains("query!"));
1477    }
1478
1479    // --- Vec<String> native array uses .as_slice() in macro mode ---
1480
1481    fn entity_with_vec_string() -> ParsedEntity {
1482        ParsedEntity {
1483            struct_name: "PromptHistory".to_string(),
1484            table_name: "prompt_history".to_string(),
1485            schema_name: None,
1486            is_view: false,
1487            fields: vec![
1488                ParsedField {
1489                    rust_name: "id".to_string(),
1490                    column_name: "id".to_string(),
1491                    rust_type: "Uuid".to_string(),
1492                    inner_type: "Uuid".to_string(),
1493                    is_nullable: false,
1494                    is_primary_key: true,
1495                    sql_type: None,
1496                    is_sql_array: false,
1497                },
1498                ParsedField {
1499                    rust_name: "content".to_string(),
1500                    column_name: "content".to_string(),
1501                    rust_type: "String".to_string(),
1502                    inner_type: "String".to_string(),
1503                    is_nullable: false,
1504                    is_primary_key: false,
1505                    sql_type: None,
1506                    is_sql_array: false,
1507                },
1508                ParsedField {
1509                    rust_name: "tags".to_string(),
1510                    column_name: "tags".to_string(),
1511                    rust_type: "Vec<String>".to_string(),
1512                    inner_type: "Vec<String>".to_string(),
1513                    is_nullable: false,
1514                    is_primary_key: false,
1515                    sql_type: None,
1516                    is_sql_array: false,
1517                },
1518            ],
1519            imports: vec!["use uuid::Uuid;".to_string()],
1520        }
1521    }
1522
1523    #[test]
1524    fn test_vec_string_macro_insert_uses_as_slice() {
1525        let skip = Methods::all();
1526        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1527        let code = parse_and_format(&tokens);
1528        assert!(code.contains("as_slice()"));
1529    }
1530
1531    #[test]
1532    fn test_vec_string_macro_update_uses_as_slice() {
1533        let skip = Methods::all();
1534        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1535        let code = parse_and_format(&tokens);
1536        // Should have as_slice() for both insert and update
1537        let count = code.matches("as_slice()").count();
1538        assert!(count >= 2, "expected at least 2 as_slice() calls (insert + update), found {}", count);
1539    }
1540
1541    #[test]
1542    fn test_vec_string_non_macro_no_as_slice() {
1543        let skip = Methods::all();
1544        let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false, PoolVisibility::Private);
1545        let code = parse_and_format(&tokens);
1546        // Runtime mode uses .bind() so no as_slice needed
1547        assert!(!code.contains("as_slice()"));
1548    }
1549
1550    #[test]
1551    fn test_vec_string_parsed_from_source_uses_as_slice() {
1552        use crate::codegen::entity_parser::parse_entity_source;
1553        let source = r#"
1554            use uuid::Uuid;
1555
1556            #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
1557            #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
1558            pub struct PromptHistory {
1559                #[sqlx_gen(primary_key)]
1560                pub id: Uuid,
1561                pub content: String,
1562                pub tags: Vec<String>,
1563            }
1564        "#;
1565        let entity = parse_entity_source(source).unwrap();
1566        let skip = Methods::all();
1567        let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1568        let code = parse_and_format(&tokens);
1569        assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
1570    }
1571}