Skip to main content

sqlx_gen/codegen/
crud_gen.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10    entity: &ParsedEntity,
11    db_kind: DatabaseKind,
12    entity_module_path: &str,
13    methods: &Methods,
14    query_macro: bool,
15) -> (TokenStream, BTreeSet<String>) {
16    let mut imports = BTreeSet::new();
17
18    let entity_ident = format_ident!("{}", entity.struct_name);
19    let repo_name = format!("{}Repository", entity.struct_name);
20    let repo_ident = format_ident!("{}", repo_name);
21
22    let table_name = &entity.table_name;
23
24    // Pool type (used via full path sqlx::PgPool etc., no import needed)
25    let pool_type = pool_type_tokens(db_kind);
26
27    // Entity import
28    imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
29
30    // Forward type imports from the entity file (chrono, uuid, etc.)
31    for imp in &entity.imports {
32        imports.insert(imp.clone());
33    }
34
35    // Primary key fields
36    let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
37
38    // Non-PK fields (for insert)
39    let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
40
41    let is_view = entity.is_view;
42
43    // Build method tokens
44    let mut method_tokens = Vec::new();
45    let mut param_structs = Vec::new();
46
47    // --- get_all ---
48    if methods.get_all {
49        let sql = format!("SELECT * FROM {}", table_name);
50        let method = if query_macro {
51            quote! {
52                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
53                    sqlx::query_as!(#entity_ident, #sql)
54                        .fetch_all(&self.pool)
55                        .await
56                }
57            }
58        } else {
59            quote! {
60                pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
61                    sqlx::query_as::<_, #entity_ident>(#sql)
62                        .fetch_all(&self.pool)
63                        .await
64                }
65            }
66        };
67        method_tokens.push(method);
68    }
69
70    // --- paginate ---
71    if methods.paginate {
72        let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
73        let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
74        let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
75        let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
76        let sql = match db_kind {
77            DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
78            DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
79        };
80        let method = if query_macro {
81            quote! {
82                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
83                    let total: i64 = sqlx::query_scalar!(#count_sql)
84                        .fetch_one(&self.pool)
85                        .await?
86                        .unwrap_or(0);
87                    let per_page = params.per_page;
88                    let current_page = params.page;
89                    let last_page = (total + per_page - 1) / per_page;
90                    let offset = (current_page - 1) * per_page;
91                    let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
92                        .fetch_all(&self.pool)
93                        .await?;
94                    Ok(#paginated_ident {
95                        meta: #pagination_meta_ident {
96                            total,
97                            per_page,
98                            current_page,
99                            last_page,
100                            first_page: 1,
101                        },
102                        data,
103                    })
104                }
105            }
106        } else {
107            quote! {
108                pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
109                    let total: i64 = sqlx::query_scalar(#count_sql)
110                        .fetch_one(&self.pool)
111                        .await?;
112                    let per_page = params.per_page;
113                    let current_page = params.page;
114                    let last_page = (total + per_page - 1) / per_page;
115                    let offset = (current_page - 1) * per_page;
116                    let data = sqlx::query_as::<_, #entity_ident>(#sql)
117                        .bind(per_page)
118                        .bind(offset)
119                        .fetch_all(&self.pool)
120                        .await?;
121                    Ok(#paginated_ident {
122                        meta: #pagination_meta_ident {
123                            total,
124                            per_page,
125                            current_page,
126                            last_page,
127                            first_page: 1,
128                        },
129                        data,
130                    })
131                }
132            }
133        };
134        method_tokens.push(method);
135        param_structs.push(quote! {
136            #[derive(Debug, Clone, Default)]
137            pub struct #paginate_params_ident {
138                pub page: i64,
139                pub per_page: i64,
140            }
141        });
142        param_structs.push(quote! {
143            #[derive(Debug, Clone)]
144            pub struct #pagination_meta_ident {
145                pub total: i64,
146                pub per_page: i64,
147                pub current_page: i64,
148                pub last_page: i64,
149                pub first_page: i64,
150            }
151        });
152        param_structs.push(quote! {
153            #[derive(Debug, Clone)]
154            pub struct #paginated_ident {
155                pub meta: #pagination_meta_ident,
156                pub data: Vec<#entity_ident>,
157            }
158        });
159    }
160
161    // --- get (by PK) ---
162    if methods.get && !pk_fields.is_empty() {
163        let pk_params: Vec<TokenStream> = pk_fields
164            .iter()
165            .map(|f| {
166                let name = format_ident!("{}", f.rust_name);
167                let ty: TokenStream = f.inner_type.parse().unwrap();
168                quote! { #name: &#ty }
169            })
170            .collect();
171
172        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
173        let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
174
175        let binds: Vec<TokenStream> = pk_fields
176            .iter()
177            .map(|f| {
178                let name = format_ident!("{}", f.rust_name);
179                quote! { .bind(#name) }
180            })
181            .collect();
182
183        let method = if query_macro {
184            let pk_arg_names: Vec<TokenStream> = pk_fields
185                .iter()
186                .map(|f| {
187                    let name = format_ident!("{}", f.rust_name);
188                    quote! { #name }
189                })
190                .collect();
191            quote! {
192                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
193                    sqlx::query_as!(#entity_ident, #sql, #(#pk_arg_names),*)
194                        .fetch_optional(&self.pool)
195                        .await
196                }
197            }
198        } else {
199            quote! {
200                pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
201                    sqlx::query_as::<_, #entity_ident>(#sql)
202                        #(#binds)*
203                        .fetch_optional(&self.pool)
204                        .await
205                }
206            }
207        };
208        method_tokens.push(method);
209    }
210
211    // --- insert (skip for views) ---
212    if !is_view && methods.insert && !non_pk_fields.is_empty() {
213        let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
214
215        let insert_fields: Vec<TokenStream> = non_pk_fields
216            .iter()
217            .map(|f| {
218                let name = format_ident!("{}", f.rust_name);
219                let ty: TokenStream = f.rust_type.parse().unwrap();
220                quote! { pub #name: #ty, }
221            })
222            .collect();
223
224        let col_names: Vec<&str> = non_pk_fields.iter().map(|f| f.column_name.as_str()).collect();
225        let col_list = col_names.join(", ");
226        let placeholders = build_placeholders(non_pk_fields.len(), db_kind, 1);
227
228        let sql = match db_kind {
229            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
230                format!(
231                    "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
232                    table_name, col_list, placeholders
233                )
234            }
235            DatabaseKind::Mysql => {
236                format!(
237                    "INSERT INTO {} ({}) VALUES ({})",
238                    table_name, col_list, placeholders
239                )
240            }
241        };
242
243        let binds: Vec<TokenStream> = non_pk_fields
244            .iter()
245            .map(|f| {
246                let name = format_ident!("{}", f.rust_name);
247                quote! { .bind(&params.#name) }
248            })
249            .collect();
250
251        let insert_method = build_insert_method_parsed(
252            &entity_ident,
253            &insert_params_ident,
254            &sql,
255            &binds,
256            db_kind,
257            table_name,
258            &pk_fields,
259            &non_pk_fields,
260            query_macro,
261        );
262        method_tokens.push(insert_method);
263
264        param_structs.push(quote! {
265            #[derive(Debug, Clone, Default)]
266            pub struct #insert_params_ident {
267                #(#insert_fields)*
268            }
269        });
270    }
271
272    // --- update (skip for views) ---
273    if !is_view && methods.update && !pk_fields.is_empty() {
274        let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
275
276        let update_fields: Vec<TokenStream> = entity
277            .fields
278            .iter()
279            .map(|f| {
280                let name = format_ident!("{}", f.rust_name);
281                let ty: TokenStream = f.rust_type.parse().unwrap();
282                quote! { pub #name: #ty, }
283            })
284            .collect();
285
286        let set_cols: Vec<String> = non_pk_fields
287            .iter()
288            .enumerate()
289            .map(|(i, f)| {
290                let p = placeholder(db_kind, i + 1);
291                format!("{} = {}", f.column_name, p)
292            })
293            .collect();
294        let set_clause = set_cols.join(", ");
295
296        let pk_start = non_pk_fields.len() + 1;
297        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
298
299        let sql = match db_kind {
300            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
301                format!(
302                    "UPDATE {} SET {} WHERE {} RETURNING *",
303                    table_name, set_clause, where_clause
304                )
305            }
306            DatabaseKind::Mysql => {
307                format!(
308                    "UPDATE {} SET {} WHERE {}",
309                    table_name, set_clause, where_clause
310                )
311            }
312        };
313
314        // Bind non-PK first, then PK
315        let mut all_binds: Vec<TokenStream> = non_pk_fields
316            .iter()
317            .map(|f| {
318                let name = format_ident!("{}", f.rust_name);
319                quote! { .bind(&params.#name) }
320            })
321            .collect();
322        for f in &pk_fields {
323            let name = format_ident!("{}", f.rust_name);
324            all_binds.push(quote! { .bind(&params.#name) });
325        }
326
327        // Macro args: non-PK fields first, then PK fields
328        let update_macro_args: Vec<TokenStream> = non_pk_fields
329            .iter()
330            .chain(pk_fields.iter())
331            .map(|f| {
332                let name = format_ident!("{}", f.rust_name);
333                quote! { params.#name }
334            })
335            .collect();
336
337        let update_method = if query_macro {
338            match db_kind {
339                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
340                    quote! {
341                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
342                            sqlx::query_as!(#entity_ident, #sql, #(#update_macro_args),*)
343                                .fetch_one(&self.pool)
344                                .await
345                        }
346                    }
347                }
348                DatabaseKind::Mysql => {
349                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
350                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
351                    let pk_macro_args: Vec<TokenStream> = pk_fields
352                        .iter()
353                        .map(|f| {
354                            let name = format_ident!("{}", f.rust_name);
355                            quote! { params.#name }
356                        })
357                        .collect();
358                    quote! {
359                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
360                            sqlx::query!(#sql, #(#update_macro_args),*)
361                                .execute(&self.pool)
362                                .await?;
363                            sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
364                                .fetch_one(&self.pool)
365                                .await
366                        }
367                    }
368                }
369            }
370        } else {
371            match db_kind {
372                DatabaseKind::Postgres | DatabaseKind::Sqlite => {
373                    quote! {
374                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
375                            sqlx::query_as::<_, #entity_ident>(#sql)
376                                #(#all_binds)*
377                                .fetch_one(&self.pool)
378                                .await
379                        }
380                    }
381                }
382                DatabaseKind::Mysql => {
383                    let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
384                    let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
385                    let pk_binds: Vec<TokenStream> = pk_fields
386                        .iter()
387                        .map(|f| {
388                            let name = format_ident!("{}", f.rust_name);
389                            quote! { .bind(&params.#name) }
390                        })
391                        .collect();
392                    quote! {
393                        pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
394                            sqlx::query(#sql)
395                                #(#all_binds)*
396                                .execute(&self.pool)
397                                .await?;
398                            sqlx::query_as::<_, #entity_ident>(#select_sql)
399                                #(#pk_binds)*
400                                .fetch_one(&self.pool)
401                                .await
402                        }
403                    }
404                }
405            }
406        };
407        method_tokens.push(update_method);
408
409        param_structs.push(quote! {
410            #[derive(Debug, Clone, Default)]
411            pub struct #update_params_ident {
412                #(#update_fields)*
413            }
414        });
415    }
416
417    // --- delete (skip for views) ---
418    if !is_view && methods.delete && !pk_fields.is_empty() {
419        let pk_params: Vec<TokenStream> = pk_fields
420            .iter()
421            .map(|f| {
422                let name = format_ident!("{}", f.rust_name);
423                let ty: TokenStream = f.inner_type.parse().unwrap();
424                quote! { #name: &#ty }
425            })
426            .collect();
427
428        let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
429        let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
430
431        let binds: Vec<TokenStream> = pk_fields
432            .iter()
433            .map(|f| {
434                let name = format_ident!("{}", f.rust_name);
435                quote! { .bind(#name) }
436            })
437            .collect();
438
439        let method = if query_macro {
440            let pk_arg_names: Vec<TokenStream> = pk_fields
441                .iter()
442                .map(|f| {
443                    let name = format_ident!("{}", f.rust_name);
444                    quote! { #name }
445                })
446                .collect();
447            quote! {
448                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
449                    sqlx::query!(#sql, #(#pk_arg_names),*)
450                        .execute(&self.pool)
451                        .await?;
452                    Ok(())
453                }
454            }
455        } else {
456            quote! {
457                pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
458                    sqlx::query(#sql)
459                        #(#binds)*
460                        .execute(&self.pool)
461                        .await?;
462                    Ok(())
463                }
464            }
465        };
466        method_tokens.push(method);
467    }
468
469    let tokens = quote! {
470        #(#param_structs)*
471
472        pub struct #repo_ident {
473            pool: #pool_type,
474        }
475
476        impl #repo_ident {
477            pub fn new(pool: #pool_type) -> Self {
478                Self { pool }
479            }
480
481            #(#method_tokens)*
482        }
483    };
484
485    (tokens, imports)
486}
487
488fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
489    match db_kind {
490        DatabaseKind::Postgres => quote! { sqlx::PgPool },
491        DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
492        DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
493    }
494}
495
496fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
497    match db_kind {
498        DatabaseKind::Postgres => format!("${}", index),
499        DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
500    }
501}
502
503fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
504    (0..count)
505        .map(|i| placeholder(db_kind, start + i))
506        .collect::<Vec<_>>()
507        .join(", ")
508}
509
510fn build_where_clause_parsed(
511    pk_fields: &[&ParsedField],
512    db_kind: DatabaseKind,
513    start_index: usize,
514) -> String {
515    pk_fields
516        .iter()
517        .enumerate()
518        .map(|(i, f)| {
519            let p = placeholder(db_kind, start_index + i);
520            format!("{} = {}", f.column_name, p)
521        })
522        .collect::<Vec<_>>()
523        .join(" AND ")
524}
525
526#[allow(clippy::too_many_arguments)]
527fn build_insert_method_parsed(
528    entity_ident: &proc_macro2::Ident,
529    insert_params_ident: &proc_macro2::Ident,
530    sql: &str,
531    binds: &[TokenStream],
532    db_kind: DatabaseKind,
533    table_name: &str,
534    pk_fields: &[&ParsedField],
535    non_pk_fields: &[&ParsedField],
536    query_macro: bool,
537) -> TokenStream {
538    if query_macro {
539        let macro_args: Vec<TokenStream> = non_pk_fields
540            .iter()
541            .map(|f| {
542                let name = format_ident!("{}", f.rust_name);
543                quote! { params.#name }
544            })
545            .collect();
546
547        match db_kind {
548            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
549                quote! {
550                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
551                        sqlx::query_as!(#entity_ident, #sql, #(#macro_args),*)
552                            .fetch_one(&self.pool)
553                            .await
554                    }
555                }
556            }
557            DatabaseKind::Mysql => {
558                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
559                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
560                quote! {
561                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
562                        sqlx::query!(#sql, #(#macro_args),*)
563                            .execute(&self.pool)
564                            .await?;
565                        let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
566                            .fetch_one(&self.pool)
567                            .await?;
568                        sqlx::query_as!(#entity_ident, #select_sql, id)
569                            .fetch_one(&self.pool)
570                            .await
571                    }
572                }
573            }
574        }
575    } else {
576        match db_kind {
577            DatabaseKind::Postgres | DatabaseKind::Sqlite => {
578                quote! {
579                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
580                        sqlx::query_as::<_, #entity_ident>(#sql)
581                            #(#binds)*
582                            .fetch_one(&self.pool)
583                            .await
584                    }
585                }
586            }
587            DatabaseKind::Mysql => {
588                let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
589                let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
590                quote! {
591                    pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
592                        sqlx::query(#sql)
593                            #(#binds)*
594                            .execute(&self.pool)
595                            .await?;
596                        let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
597                            .fetch_one(&self.pool)
598                            .await?;
599                        sqlx::query_as::<_, #entity_ident>(#select_sql)
600                            .bind(id)
601                            .fetch_one(&self.pool)
602                            .await
603                    }
604                }
605            }
606        }
607    }
608}
609
610#[cfg(test)]
611mod tests {
612    use super::*;
613    use crate::codegen::parse_and_format;
614    use crate::cli::Methods;
615
616    fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
617        let inner_type = if nullable {
618            // Strip "Option<" prefix and ">" suffix
619            rust_type
620                .strip_prefix("Option<")
621                .and_then(|s| s.strip_suffix('>'))
622                .unwrap_or(rust_type)
623                .to_string()
624        } else {
625            rust_type.to_string()
626        };
627        ParsedField {
628            rust_name: rust_name.to_string(),
629            column_name: column_name.to_string(),
630            rust_type: rust_type.to_string(),
631            is_nullable: nullable,
632            inner_type,
633            is_primary_key: is_pk,
634        }
635    }
636
637    fn standard_entity() -> ParsedEntity {
638        ParsedEntity {
639            struct_name: "Users".to_string(),
640            table_name: "users".to_string(),
641            schema_name: None,
642            is_view: false,
643            fields: vec![
644                make_field("id", "id", "i32", false, true),
645                make_field("name", "name", "String", false, false),
646                make_field("email", "email", "Option<String>", true, false),
647            ],
648            imports: vec![],
649        }
650    }
651
652    fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
653        let skip = Methods::all();
654        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false);
655        parse_and_format(&tokens)
656    }
657
658    fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
659        let skip = Methods::all();
660        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true);
661        parse_and_format(&tokens)
662    }
663
664    fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
665        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false);
666        parse_and_format(&tokens)
667    }
668
669    // --- basic structure ---
670
671    #[test]
672    fn test_repo_struct_name() {
673        let code = gen(&standard_entity(), DatabaseKind::Postgres);
674        assert!(code.contains("pub struct UsersRepository"));
675    }
676
677    #[test]
678    fn test_repo_new_method() {
679        let code = gen(&standard_entity(), DatabaseKind::Postgres);
680        assert!(code.contains("pub fn new("));
681    }
682
683    #[test]
684    fn test_repo_pool_field_pg() {
685        let code = gen(&standard_entity(), DatabaseKind::Postgres);
686        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
687    }
688
689    #[test]
690    fn test_repo_pool_field_mysql() {
691        let code = gen(&standard_entity(), DatabaseKind::Mysql);
692        assert!(code.contains("MySqlPool") || code.contains("MySql"));
693    }
694
695    #[test]
696    fn test_repo_pool_field_sqlite() {
697        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
698        assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
699    }
700
701    // --- get_all ---
702
703    #[test]
704    fn test_get_all_method() {
705        let code = gen(&standard_entity(), DatabaseKind::Postgres);
706        assert!(code.contains("pub async fn get_all"));
707    }
708
709    #[test]
710    fn test_get_all_returns_vec() {
711        let code = gen(&standard_entity(), DatabaseKind::Postgres);
712        assert!(code.contains("Vec<Users>"));
713    }
714
715    #[test]
716    fn test_get_all_sql() {
717        let code = gen(&standard_entity(), DatabaseKind::Postgres);
718        assert!(code.contains("SELECT * FROM users"));
719    }
720
721    // --- paginate ---
722
723    #[test]
724    fn test_paginate_method() {
725        let code = gen(&standard_entity(), DatabaseKind::Postgres);
726        assert!(code.contains("pub async fn paginate"));
727    }
728
729    #[test]
730    fn test_paginate_params_struct() {
731        let code = gen(&standard_entity(), DatabaseKind::Postgres);
732        assert!(code.contains("pub struct PaginateUsersParams"));
733    }
734
735    #[test]
736    fn test_paginate_params_fields() {
737        let code = gen(&standard_entity(), DatabaseKind::Postgres);
738        assert!(code.contains("pub page: i64"));
739        assert!(code.contains("pub per_page: i64"));
740    }
741
742    #[test]
743    fn test_paginate_returns_paginated() {
744        let code = gen(&standard_entity(), DatabaseKind::Postgres);
745        assert!(code.contains("PaginatedUsers"));
746        assert!(code.contains("PaginationUsersMeta"));
747    }
748
749    #[test]
750    fn test_paginate_meta_struct() {
751        let code = gen(&standard_entity(), DatabaseKind::Postgres);
752        assert!(code.contains("pub struct PaginationUsersMeta"));
753        assert!(code.contains("pub total: i64"));
754        assert!(code.contains("pub last_page: i64"));
755        assert!(code.contains("pub first_page: i64"));
756        assert!(code.contains("pub current_page: i64"));
757    }
758
759    #[test]
760    fn test_paginate_data_struct() {
761        let code = gen(&standard_entity(), DatabaseKind::Postgres);
762        assert!(code.contains("pub struct PaginatedUsers"));
763        assert!(code.contains("pub meta: PaginationUsersMeta"));
764        assert!(code.contains("pub data: Vec<Users>"));
765    }
766
767    #[test]
768    fn test_paginate_count_sql() {
769        let code = gen(&standard_entity(), DatabaseKind::Postgres);
770        assert!(code.contains("SELECT COUNT(*) FROM users"));
771    }
772
773    #[test]
774    fn test_paginate_sql_pg() {
775        let code = gen(&standard_entity(), DatabaseKind::Postgres);
776        assert!(code.contains("LIMIT $1 OFFSET $2"));
777    }
778
779    #[test]
780    fn test_paginate_sql_mysql() {
781        let code = gen(&standard_entity(), DatabaseKind::Mysql);
782        assert!(code.contains("LIMIT ? OFFSET ?"));
783    }
784
785    // --- get ---
786
787    #[test]
788    fn test_get_method() {
789        let code = gen(&standard_entity(), DatabaseKind::Postgres);
790        assert!(code.contains("pub async fn get"));
791    }
792
793    #[test]
794    fn test_get_returns_option() {
795        let code = gen(&standard_entity(), DatabaseKind::Postgres);
796        assert!(code.contains("Option<Users>"));
797    }
798
799    #[test]
800    fn test_get_where_pk_pg() {
801        let code = gen(&standard_entity(), DatabaseKind::Postgres);
802        assert!(code.contains("WHERE id = $1"));
803    }
804
805    #[test]
806    fn test_get_where_pk_mysql() {
807        let code = gen(&standard_entity(), DatabaseKind::Mysql);
808        assert!(code.contains("WHERE id = ?"));
809    }
810
811    // --- insert ---
812
813    #[test]
814    fn test_insert_method() {
815        let code = gen(&standard_entity(), DatabaseKind::Postgres);
816        assert!(code.contains("pub async fn insert"));
817    }
818
819    #[test]
820    fn test_insert_params_struct() {
821        let code = gen(&standard_entity(), DatabaseKind::Postgres);
822        assert!(code.contains("pub struct InsertUsersParams"));
823    }
824
825    #[test]
826    fn test_insert_params_no_pk() {
827        let code = gen(&standard_entity(), DatabaseKind::Postgres);
828        assert!(code.contains("pub name: String"));
829        assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
830    }
831
832    #[test]
833    fn test_insert_returning_pg() {
834        let code = gen(&standard_entity(), DatabaseKind::Postgres);
835        assert!(code.contains("RETURNING *"));
836    }
837
838    #[test]
839    fn test_insert_returning_sqlite() {
840        let code = gen(&standard_entity(), DatabaseKind::Sqlite);
841        assert!(code.contains("RETURNING *"));
842    }
843
844    #[test]
845    fn test_insert_mysql_last_insert_id() {
846        let code = gen(&standard_entity(), DatabaseKind::Mysql);
847        assert!(code.contains("LAST_INSERT_ID"));
848    }
849
850    // --- update ---
851
852    #[test]
853    fn test_update_method() {
854        let code = gen(&standard_entity(), DatabaseKind::Postgres);
855        assert!(code.contains("pub async fn update"));
856    }
857
858    #[test]
859    fn test_update_params_struct() {
860        let code = gen(&standard_entity(), DatabaseKind::Postgres);
861        assert!(code.contains("pub struct UpdateUsersParams"));
862    }
863
864    #[test]
865    fn test_update_params_all_cols() {
866        let code = gen(&standard_entity(), DatabaseKind::Postgres);
867        assert!(code.contains("pub id: i32"));
868        assert!(code.contains("pub name: String"));
869    }
870
871    #[test]
872    fn test_update_set_clause_pg() {
873        let code = gen(&standard_entity(), DatabaseKind::Postgres);
874        assert!(code.contains("SET name = $1"));
875        assert!(code.contains("WHERE id = $3"));
876    }
877
878    #[test]
879    fn test_update_returning_pg() {
880        let code = gen(&standard_entity(), DatabaseKind::Postgres);
881        assert!(code.contains("UPDATE users SET"));
882        assert!(code.contains("RETURNING *"));
883    }
884
885    // --- delete ---
886
887    #[test]
888    fn test_delete_method() {
889        let code = gen(&standard_entity(), DatabaseKind::Postgres);
890        assert!(code.contains("pub async fn delete"));
891    }
892
893    #[test]
894    fn test_delete_where_pk() {
895        let code = gen(&standard_entity(), DatabaseKind::Postgres);
896        assert!(code.contains("DELETE FROM users WHERE id = $1"));
897    }
898
899    #[test]
900    fn test_delete_returns_unit() {
901        let code = gen(&standard_entity(), DatabaseKind::Postgres);
902        assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
903    }
904
905    // --- views (read-only) ---
906
907    #[test]
908    fn test_view_no_insert() {
909        let mut entity = standard_entity();
910        entity.is_view = true;
911        let code = gen(&entity, DatabaseKind::Postgres);
912        assert!(!code.contains("pub async fn insert"));
913    }
914
915    #[test]
916    fn test_view_no_update() {
917        let mut entity = standard_entity();
918        entity.is_view = true;
919        let code = gen(&entity, DatabaseKind::Postgres);
920        assert!(!code.contains("pub async fn update"));
921    }
922
923    #[test]
924    fn test_view_no_delete() {
925        let mut entity = standard_entity();
926        entity.is_view = true;
927        let code = gen(&entity, DatabaseKind::Postgres);
928        assert!(!code.contains("pub async fn delete"));
929    }
930
931    #[test]
932    fn test_view_has_get_all() {
933        let mut entity = standard_entity();
934        entity.is_view = true;
935        let code = gen(&entity, DatabaseKind::Postgres);
936        assert!(code.contains("pub async fn get_all"));
937    }
938
939    #[test]
940    fn test_view_has_paginate() {
941        let mut entity = standard_entity();
942        entity.is_view = true;
943        let code = gen(&entity, DatabaseKind::Postgres);
944        assert!(code.contains("pub async fn paginate"));
945    }
946
947    #[test]
948    fn test_view_has_get() {
949        let mut entity = standard_entity();
950        entity.is_view = true;
951        let code = gen(&entity, DatabaseKind::Postgres);
952        assert!(code.contains("pub async fn get"));
953    }
954
955    // --- selective methods ---
956
957    #[test]
958    fn test_only_get_all() {
959        let m = Methods { get_all: true, ..Default::default() };
960        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
961        assert!(code.contains("pub async fn get_all"));
962        assert!(!code.contains("pub async fn paginate"));
963        assert!(!code.contains("pub async fn insert"));
964    }
965
966    #[test]
967    fn test_without_get_all() {
968        let m = Methods { get_all: false, ..Methods::all() };
969        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
970        assert!(!code.contains("pub async fn get_all"));
971    }
972
973    #[test]
974    fn test_without_paginate() {
975        let m = Methods { paginate: false, ..Methods::all() };
976        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
977        assert!(!code.contains("pub async fn paginate"));
978        assert!(!code.contains("PaginateUsersParams"));
979    }
980
981    #[test]
982    fn test_without_get() {
983        let m = Methods { get: false, ..Methods::all() };
984        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
985        assert!(code.contains("pub async fn get_all"));
986        let without_get_all = code.replace("get_all", "XXX");
987        assert!(!without_get_all.contains("fn get("));
988    }
989
990    #[test]
991    fn test_without_insert() {
992        let m = Methods { insert: false, ..Methods::all() };
993        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
994        assert!(!code.contains("pub async fn insert"));
995        assert!(!code.contains("InsertUsersParams"));
996    }
997
998    #[test]
999    fn test_without_update() {
1000        let m = Methods { update: false, ..Methods::all() };
1001        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1002        assert!(!code.contains("pub async fn update"));
1003        assert!(!code.contains("UpdateUsersParams"));
1004    }
1005
1006    #[test]
1007    fn test_without_delete() {
1008        let m = Methods { delete: false, ..Methods::all() };
1009        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1010        assert!(!code.contains("pub async fn delete"));
1011    }
1012
1013    #[test]
1014    fn test_empty_methods_no_methods() {
1015        let m = Methods::default();
1016        let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1017        assert!(!code.contains("pub async fn get_all"));
1018        assert!(!code.contains("pub async fn paginate"));
1019        assert!(!code.contains("pub async fn insert"));
1020        assert!(!code.contains("pub async fn update"));
1021        assert!(!code.contains("pub async fn delete"));
1022    }
1023
1024    // --- imports ---
1025
1026    #[test]
1027    fn test_no_pool_import() {
1028        let skip = Methods::all();
1029        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false);
1030        assert!(!imports.iter().any(|i| i.contains("PgPool")));
1031    }
1032
1033    #[test]
1034    fn test_imports_contain_entity() {
1035        let skip = Methods::all();
1036        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false);
1037        assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
1038    }
1039
1040    // --- renamed columns ---
1041
1042    #[test]
1043    fn test_renamed_column_in_sql() {
1044        let entity = ParsedEntity {
1045            struct_name: "Connector".to_string(),
1046            table_name: "connector".to_string(),
1047            schema_name: None,
1048            is_view: false,
1049            fields: vec![
1050                make_field("id", "id", "i32", false, true),
1051                make_field("connector_type", "type", "String", false, false),
1052            ],
1053            imports: vec![],
1054        };
1055        let code = gen(&entity, DatabaseKind::Postgres);
1056        // INSERT should use the DB column name "type", not "connector_type"
1057        assert!(code.contains("type"));
1058        // The Rust param field should be connector_type
1059        assert!(code.contains("pub connector_type: String"));
1060    }
1061
1062    // --- no PK edge cases ---
1063
1064    #[test]
1065    fn test_no_pk_no_get() {
1066        let entity = ParsedEntity {
1067            struct_name: "Logs".to_string(),
1068            table_name: "logs".to_string(),
1069            schema_name: None,
1070            is_view: false,
1071            fields: vec![
1072                make_field("message", "message", "String", false, false),
1073                make_field("ts", "ts", "String", false, false),
1074            ],
1075            imports: vec![],
1076        };
1077        let code = gen(&entity, DatabaseKind::Postgres);
1078        assert!(code.contains("pub async fn get_all"));
1079        let without_get_all = code.replace("get_all", "XXX");
1080        assert!(!without_get_all.contains("fn get("));
1081    }
1082
1083    #[test]
1084    fn test_no_pk_no_delete() {
1085        let entity = ParsedEntity {
1086            struct_name: "Logs".to_string(),
1087            table_name: "logs".to_string(),
1088            schema_name: None,
1089            is_view: false,
1090            fields: vec![
1091                make_field("message", "message", "String", false, false),
1092            ],
1093            imports: vec![],
1094        };
1095        let code = gen(&entity, DatabaseKind::Postgres);
1096        assert!(!code.contains("pub async fn delete"));
1097    }
1098
1099    // --- Default derive on param structs ---
1100
1101    #[test]
1102    fn test_param_structs_have_default() {
1103        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1104        assert!(code.contains("Default"));
1105    }
1106
1107    // --- entity imports forwarded ---
1108
1109    #[test]
1110    fn test_entity_imports_forwarded() {
1111        let entity = ParsedEntity {
1112            struct_name: "Users".to_string(),
1113            table_name: "users".to_string(),
1114            schema_name: None,
1115            is_view: false,
1116            fields: vec![
1117                make_field("id", "id", "Uuid", false, true),
1118                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
1119            ],
1120            imports: vec![
1121                "use chrono::{DateTime, Utc};".to_string(),
1122                "use uuid::Uuid;".to_string(),
1123            ],
1124        };
1125        let skip = Methods::all();
1126        let (_, imports) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::users", &skip, false);
1127        assert!(imports.iter().any(|i| i.contains("chrono")));
1128        assert!(imports.iter().any(|i| i.contains("uuid")));
1129    }
1130
1131    #[test]
1132    fn test_entity_imports_empty_when_no_imports() {
1133        let skip = Methods::all();
1134        let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false);
1135        // Should only have pool + entity imports, no chrono/uuid
1136        assert!(!imports.iter().any(|i| i.contains("chrono")));
1137        assert!(!imports.iter().any(|i| i.contains("uuid")));
1138    }
1139
1140    // --- query_macro mode ---
1141
1142    #[test]
1143    fn test_macro_get_all() {
1144        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1145        assert!(code.contains("query_as!"));
1146        assert!(!code.contains("query_as::<"));
1147    }
1148
1149    #[test]
1150    fn test_macro_paginate() {
1151        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1152        assert!(code.contains("query_as!"));
1153        assert!(code.contains("per_page, offset"));
1154    }
1155
1156    #[test]
1157    fn test_macro_get() {
1158        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1159        // The get method should use query_as! with the PK as arg
1160        assert!(code.contains("query_as!(Users"));
1161    }
1162
1163    #[test]
1164    fn test_macro_insert_pg() {
1165        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1166        assert!(code.contains("query_as!(Users"));
1167        assert!(code.contains("params.name"));
1168        assert!(code.contains("params.email"));
1169    }
1170
1171    #[test]
1172    fn test_macro_insert_mysql() {
1173        let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
1174        // MySQL insert uses query! (not query_as!) for the INSERT
1175        assert!(code.contains("query!"));
1176        assert!(code.contains("query_scalar!"));
1177    }
1178
1179    #[test]
1180    fn test_macro_update() {
1181        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1182        assert!(code.contains("query_as!(Users"));
1183        // Should contain params.name, params.email, params.id as args
1184        assert!(code.contains("params.name"));
1185        assert!(code.contains("params.id"));
1186    }
1187
1188    #[test]
1189    fn test_macro_delete() {
1190        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1191        // delete uses query! (no return type)
1192        assert!(code.contains("query!"));
1193    }
1194
1195    #[test]
1196    fn test_macro_no_bind_calls() {
1197        let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1198        assert!(!code.contains(".bind("));
1199    }
1200
1201    #[test]
1202    fn test_function_style_uses_bind() {
1203        let code = gen(&standard_entity(), DatabaseKind::Postgres);
1204        assert!(code.contains(".bind("));
1205        assert!(!code.contains("query_as!("));
1206        assert!(!code.contains("query!("));
1207    }
1208}