1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10 entity: &ParsedEntity,
11 db_kind: DatabaseKind,
12 entity_module_path: &str,
13 methods: &Methods,
14 query_macro: bool,
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17
18 let entity_ident = format_ident!("{}", entity.struct_name);
19 let repo_name = format!("{}Repository", entity.struct_name);
20 let repo_ident = format_ident!("{}", repo_name);
21
22 let table_name = &entity.table_name;
23
24 let pool_type = pool_type_tokens(db_kind);
26
27 imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
29
30 for imp in &entity.imports {
32 imports.insert(imp.clone());
33 }
34
35 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
37
38 let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
40
41 let is_view = entity.is_view;
42
43 let mut method_tokens = Vec::new();
45 let mut param_structs = Vec::new();
46
47 if methods.get_all {
49 let sql = format!("SELECT * FROM {}", table_name);
50 let method = if query_macro {
51 quote! {
52 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
53 sqlx::query_as!(#entity_ident, #sql)
54 .fetch_all(&self.pool)
55 .await
56 }
57 }
58 } else {
59 quote! {
60 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
61 sqlx::query_as::<_, #entity_ident>(#sql)
62 .fetch_all(&self.pool)
63 .await
64 }
65 }
66 };
67 method_tokens.push(method);
68 }
69
70 if methods.paginate {
72 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
73 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
74 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
75 let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
76 let sql = match db_kind {
77 DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
78 DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
79 };
80 let method = if query_macro {
81 quote! {
82 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
83 let total: i64 = sqlx::query_scalar!(#count_sql)
84 .fetch_one(&self.pool)
85 .await?
86 .unwrap_or(0);
87 let per_page = params.per_page;
88 let current_page = params.page;
89 let last_page = (total + per_page - 1) / per_page;
90 let offset = (current_page - 1) * per_page;
91 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
92 .fetch_all(&self.pool)
93 .await?;
94 Ok(#paginated_ident {
95 meta: #pagination_meta_ident {
96 total,
97 per_page,
98 current_page,
99 last_page,
100 first_page: 1,
101 },
102 data,
103 })
104 }
105 }
106 } else {
107 quote! {
108 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
109 let total: i64 = sqlx::query_scalar(#count_sql)
110 .fetch_one(&self.pool)
111 .await?;
112 let per_page = params.per_page;
113 let current_page = params.page;
114 let last_page = (total + per_page - 1) / per_page;
115 let offset = (current_page - 1) * per_page;
116 let data = sqlx::query_as::<_, #entity_ident>(#sql)
117 .bind(per_page)
118 .bind(offset)
119 .fetch_all(&self.pool)
120 .await?;
121 Ok(#paginated_ident {
122 meta: #pagination_meta_ident {
123 total,
124 per_page,
125 current_page,
126 last_page,
127 first_page: 1,
128 },
129 data,
130 })
131 }
132 }
133 };
134 method_tokens.push(method);
135 param_structs.push(quote! {
136 #[derive(Debug, Clone, Default)]
137 pub struct #paginate_params_ident {
138 pub page: i64,
139 pub per_page: i64,
140 }
141 });
142 param_structs.push(quote! {
143 #[derive(Debug, Clone)]
144 pub struct #pagination_meta_ident {
145 pub total: i64,
146 pub per_page: i64,
147 pub current_page: i64,
148 pub last_page: i64,
149 pub first_page: i64,
150 }
151 });
152 param_structs.push(quote! {
153 #[derive(Debug, Clone)]
154 pub struct #paginated_ident {
155 pub meta: #pagination_meta_ident,
156 pub data: Vec<#entity_ident>,
157 }
158 });
159 }
160
161 if methods.get && !pk_fields.is_empty() {
163 let pk_params: Vec<TokenStream> = pk_fields
164 .iter()
165 .map(|f| {
166 let name = format_ident!("{}", f.rust_name);
167 let ty: TokenStream = f.inner_type.parse().unwrap();
168 quote! { #name: &#ty }
169 })
170 .collect();
171
172 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
173 let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
174
175 let binds: Vec<TokenStream> = pk_fields
176 .iter()
177 .map(|f| {
178 let name = format_ident!("{}", f.rust_name);
179 quote! { .bind(#name) }
180 })
181 .collect();
182
183 let method = if query_macro {
184 let pk_arg_names: Vec<TokenStream> = pk_fields
185 .iter()
186 .map(|f| {
187 let name = format_ident!("{}", f.rust_name);
188 quote! { #name }
189 })
190 .collect();
191 quote! {
192 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
193 sqlx::query_as!(#entity_ident, #sql, #(#pk_arg_names),*)
194 .fetch_optional(&self.pool)
195 .await
196 }
197 }
198 } else {
199 quote! {
200 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
201 sqlx::query_as::<_, #entity_ident>(#sql)
202 #(#binds)*
203 .fetch_optional(&self.pool)
204 .await
205 }
206 }
207 };
208 method_tokens.push(method);
209 }
210
211 if !is_view && methods.insert && !non_pk_fields.is_empty() {
213 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
214
215 let insert_fields: Vec<TokenStream> = non_pk_fields
216 .iter()
217 .map(|f| {
218 let name = format_ident!("{}", f.rust_name);
219 let ty: TokenStream = f.rust_type.parse().unwrap();
220 quote! { pub #name: #ty, }
221 })
222 .collect();
223
224 let col_names: Vec<&str> = non_pk_fields.iter().map(|f| f.column_name.as_str()).collect();
225 let col_list = col_names.join(", ");
226 let placeholders = build_placeholders(non_pk_fields.len(), db_kind, 1);
227
228 let sql = match db_kind {
229 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
230 format!(
231 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
232 table_name, col_list, placeholders
233 )
234 }
235 DatabaseKind::Mysql => {
236 format!(
237 "INSERT INTO {} ({}) VALUES ({})",
238 table_name, col_list, placeholders
239 )
240 }
241 };
242
243 let binds: Vec<TokenStream> = non_pk_fields
244 .iter()
245 .map(|f| {
246 let name = format_ident!("{}", f.rust_name);
247 quote! { .bind(¶ms.#name) }
248 })
249 .collect();
250
251 let insert_method = build_insert_method_parsed(
252 &entity_ident,
253 &insert_params_ident,
254 &sql,
255 &binds,
256 db_kind,
257 table_name,
258 &pk_fields,
259 &non_pk_fields,
260 query_macro,
261 );
262 method_tokens.push(insert_method);
263
264 param_structs.push(quote! {
265 #[derive(Debug, Clone, Default)]
266 pub struct #insert_params_ident {
267 #(#insert_fields)*
268 }
269 });
270 }
271
272 if !is_view && methods.update && !pk_fields.is_empty() {
274 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
275
276 let update_fields: Vec<TokenStream> = entity
277 .fields
278 .iter()
279 .map(|f| {
280 let name = format_ident!("{}", f.rust_name);
281 let ty: TokenStream = f.rust_type.parse().unwrap();
282 quote! { pub #name: #ty, }
283 })
284 .collect();
285
286 let set_cols: Vec<String> = non_pk_fields
287 .iter()
288 .enumerate()
289 .map(|(i, f)| {
290 let p = placeholder(db_kind, i + 1);
291 format!("{} = {}", f.column_name, p)
292 })
293 .collect();
294 let set_clause = set_cols.join(", ");
295
296 let pk_start = non_pk_fields.len() + 1;
297 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
298
299 let sql = match db_kind {
300 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
301 format!(
302 "UPDATE {} SET {} WHERE {} RETURNING *",
303 table_name, set_clause, where_clause
304 )
305 }
306 DatabaseKind::Mysql => {
307 format!(
308 "UPDATE {} SET {} WHERE {}",
309 table_name, set_clause, where_clause
310 )
311 }
312 };
313
314 let mut all_binds: Vec<TokenStream> = non_pk_fields
316 .iter()
317 .map(|f| {
318 let name = format_ident!("{}", f.rust_name);
319 quote! { .bind(¶ms.#name) }
320 })
321 .collect();
322 for f in &pk_fields {
323 let name = format_ident!("{}", f.rust_name);
324 all_binds.push(quote! { .bind(¶ms.#name) });
325 }
326
327 let update_macro_args: Vec<TokenStream> = non_pk_fields
329 .iter()
330 .chain(pk_fields.iter())
331 .map(|f| {
332 let name = format_ident!("{}", f.rust_name);
333 quote! { params.#name }
334 })
335 .collect();
336
337 let update_method = if query_macro {
338 match db_kind {
339 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
340 quote! {
341 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
342 sqlx::query_as!(#entity_ident, #sql, #(#update_macro_args),*)
343 .fetch_one(&self.pool)
344 .await
345 }
346 }
347 }
348 DatabaseKind::Mysql => {
349 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
350 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
351 let pk_macro_args: Vec<TokenStream> = pk_fields
352 .iter()
353 .map(|f| {
354 let name = format_ident!("{}", f.rust_name);
355 quote! { params.#name }
356 })
357 .collect();
358 quote! {
359 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
360 sqlx::query!(#sql, #(#update_macro_args),*)
361 .execute(&self.pool)
362 .await?;
363 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
364 .fetch_one(&self.pool)
365 .await
366 }
367 }
368 }
369 }
370 } else {
371 match db_kind {
372 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
373 quote! {
374 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
375 sqlx::query_as::<_, #entity_ident>(#sql)
376 #(#all_binds)*
377 .fetch_one(&self.pool)
378 .await
379 }
380 }
381 }
382 DatabaseKind::Mysql => {
383 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
384 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
385 let pk_binds: Vec<TokenStream> = pk_fields
386 .iter()
387 .map(|f| {
388 let name = format_ident!("{}", f.rust_name);
389 quote! { .bind(¶ms.#name) }
390 })
391 .collect();
392 quote! {
393 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
394 sqlx::query(#sql)
395 #(#all_binds)*
396 .execute(&self.pool)
397 .await?;
398 sqlx::query_as::<_, #entity_ident>(#select_sql)
399 #(#pk_binds)*
400 .fetch_one(&self.pool)
401 .await
402 }
403 }
404 }
405 }
406 };
407 method_tokens.push(update_method);
408
409 param_structs.push(quote! {
410 #[derive(Debug, Clone, Default)]
411 pub struct #update_params_ident {
412 #(#update_fields)*
413 }
414 });
415 }
416
417 if !is_view && methods.delete && !pk_fields.is_empty() {
419 let pk_params: Vec<TokenStream> = pk_fields
420 .iter()
421 .map(|f| {
422 let name = format_ident!("{}", f.rust_name);
423 let ty: TokenStream = f.inner_type.parse().unwrap();
424 quote! { #name: &#ty }
425 })
426 .collect();
427
428 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
429 let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
430
431 let binds: Vec<TokenStream> = pk_fields
432 .iter()
433 .map(|f| {
434 let name = format_ident!("{}", f.rust_name);
435 quote! { .bind(#name) }
436 })
437 .collect();
438
439 let method = if query_macro {
440 let pk_arg_names: Vec<TokenStream> = pk_fields
441 .iter()
442 .map(|f| {
443 let name = format_ident!("{}", f.rust_name);
444 quote! { #name }
445 })
446 .collect();
447 quote! {
448 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
449 sqlx::query!(#sql, #(#pk_arg_names),*)
450 .execute(&self.pool)
451 .await?;
452 Ok(())
453 }
454 }
455 } else {
456 quote! {
457 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
458 sqlx::query(#sql)
459 #(#binds)*
460 .execute(&self.pool)
461 .await?;
462 Ok(())
463 }
464 }
465 };
466 method_tokens.push(method);
467 }
468
469 let tokens = quote! {
470 #(#param_structs)*
471
472 pub struct #repo_ident {
473 pool: #pool_type,
474 }
475
476 impl #repo_ident {
477 pub fn new(pool: #pool_type) -> Self {
478 Self { pool }
479 }
480
481 #(#method_tokens)*
482 }
483 };
484
485 (tokens, imports)
486}
487
488fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
489 match db_kind {
490 DatabaseKind::Postgres => quote! { sqlx::PgPool },
491 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
492 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
493 }
494}
495
496fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
497 match db_kind {
498 DatabaseKind::Postgres => format!("${}", index),
499 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
500 }
501}
502
503fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
504 (0..count)
505 .map(|i| placeholder(db_kind, start + i))
506 .collect::<Vec<_>>()
507 .join(", ")
508}
509
510fn build_where_clause_parsed(
511 pk_fields: &[&ParsedField],
512 db_kind: DatabaseKind,
513 start_index: usize,
514) -> String {
515 pk_fields
516 .iter()
517 .enumerate()
518 .map(|(i, f)| {
519 let p = placeholder(db_kind, start_index + i);
520 format!("{} = {}", f.column_name, p)
521 })
522 .collect::<Vec<_>>()
523 .join(" AND ")
524}
525
526#[allow(clippy::too_many_arguments)]
527fn build_insert_method_parsed(
528 entity_ident: &proc_macro2::Ident,
529 insert_params_ident: &proc_macro2::Ident,
530 sql: &str,
531 binds: &[TokenStream],
532 db_kind: DatabaseKind,
533 table_name: &str,
534 pk_fields: &[&ParsedField],
535 non_pk_fields: &[&ParsedField],
536 query_macro: bool,
537) -> TokenStream {
538 if query_macro {
539 let macro_args: Vec<TokenStream> = non_pk_fields
540 .iter()
541 .map(|f| {
542 let name = format_ident!("{}", f.rust_name);
543 quote! { params.#name }
544 })
545 .collect();
546
547 match db_kind {
548 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
549 quote! {
550 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
551 sqlx::query_as!(#entity_ident, #sql, #(#macro_args),*)
552 .fetch_one(&self.pool)
553 .await
554 }
555 }
556 }
557 DatabaseKind::Mysql => {
558 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
559 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
560 quote! {
561 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
562 sqlx::query!(#sql, #(#macro_args),*)
563 .execute(&self.pool)
564 .await?;
565 let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
566 .fetch_one(&self.pool)
567 .await?;
568 sqlx::query_as!(#entity_ident, #select_sql, id)
569 .fetch_one(&self.pool)
570 .await
571 }
572 }
573 }
574 }
575 } else {
576 match db_kind {
577 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
578 quote! {
579 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
580 sqlx::query_as::<_, #entity_ident>(#sql)
581 #(#binds)*
582 .fetch_one(&self.pool)
583 .await
584 }
585 }
586 }
587 DatabaseKind::Mysql => {
588 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
589 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
590 quote! {
591 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
592 sqlx::query(#sql)
593 #(#binds)*
594 .execute(&self.pool)
595 .await?;
596 let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
597 .fetch_one(&self.pool)
598 .await?;
599 sqlx::query_as::<_, #entity_ident>(#select_sql)
600 .bind(id)
601 .fetch_one(&self.pool)
602 .await
603 }
604 }
605 }
606 }
607 }
608}
609
#[cfg(test)]
mod tests {
    // Substring checks against the formatted output of
    // `generate_crud_from_parsed`.

    use super::*;
    use crate::codegen::parse_and_format;
    use crate::cli::Methods;

    const MODULE_PATH: &str = "crate::models::users";

    /// Builds a field; for nullable fields `inner_type` is `rust_type` with an
    /// outer `Option<..>` stripped.
    fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
        let inner_type = if nullable {
            rust_type
                .strip_prefix("Option<")
                .and_then(|s| s.strip_suffix('>'))
                .unwrap_or(rust_type)
                .to_string()
        } else {
            rust_type.to_string()
        };
        ParsedField {
            rust_name: rust_name.to_string(),
            column_name: column_name.to_string(),
            rust_type: rust_type.to_string(),
            is_nullable: nullable,
            inner_type,
            is_primary_key: is_pk,
        }
    }

    /// Table `users(id i32 PK, name String, email Option<String>)`.
    fn standard_entity() -> ParsedEntity {
        ParsedEntity {
            struct_name: "Users".to_string(),
            table_name: "users".to_string(),
            schema_name: None,
            is_view: false,
            fields: vec![
                make_field("id", "id", "i32", false, true),
                make_field("name", "name", "String", false, false),
                make_field("email", "email", "Option<String>", true, false),
            ],
            imports: vec![],
        }
    }

    /// Renders and formats the generated code.
    fn render_with(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods, query_macro: bool) -> String {
        let (tokens, _) = generate_crud_from_parsed(entity, db, MODULE_PATH, methods, query_macro);
        parse_and_format(&tokens)
    }

    /// All methods, runtime `sqlx::query*` functions.
    fn render(entity: &ParsedEntity, db: DatabaseKind) -> String {
        render_with(entity, db, &Methods::all(), false)
    }

    /// All methods, compile-time `sqlx::query*!` macros.
    fn render_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
        render_with(entity, db, &Methods::all(), true)
    }

    /// Selected methods, runtime `sqlx::query*` functions.
    fn render_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
        render_with(entity, db, methods, false)
    }

    /// Import set produced for `entity` (Postgres, all methods).
    fn imports_of(entity: &ParsedEntity) -> BTreeSet<String> {
        let (_, imports) =
            generate_crud_from_parsed(entity, DatabaseKind::Postgres, MODULE_PATH, &Methods::all(), false);
        imports
    }

    // --- repository struct ---

    #[test]
    fn test_repo_struct_name() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct UsersRepository"));
    }

    #[test]
    fn test_repo_new_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub fn new("));
    }

    #[test]
    fn test_repo_pool_field_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
    }

    #[test]
    fn test_repo_pool_field_mysql() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("MySqlPool"));
    }

    #[test]
    fn test_repo_pool_field_sqlite() {
        let code = render(&standard_entity(), DatabaseKind::Sqlite);
        assert!(code.contains("SqlitePool"));
    }

    // --- get_all ---

    #[test]
    fn test_get_all_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn get_all"));
    }

    #[test]
    fn test_get_all_returns_vec() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("Vec<Users>"));
    }

    #[test]
    fn test_get_all_sql() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("SELECT * FROM users"));
    }

    // --- paginate ---

    #[test]
    fn test_paginate_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn paginate"));
    }

    #[test]
    fn test_paginate_params_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct PaginateUsersParams"));
    }

    #[test]
    fn test_paginate_params_fields() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub page: i64"));
        assert!(code.contains("pub per_page: i64"));
    }

    #[test]
    fn test_paginate_returns_paginated() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("PaginatedUsers"));
        assert!(code.contains("PaginationUsersMeta"));
    }

    #[test]
    fn test_paginate_meta_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct PaginationUsersMeta"));
        assert!(code.contains("pub total: i64"));
        assert!(code.contains("pub last_page: i64"));
        assert!(code.contains("pub first_page: i64"));
        assert!(code.contains("pub current_page: i64"));
    }

    #[test]
    fn test_paginate_data_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct PaginatedUsers"));
        assert!(code.contains("pub meta: PaginationUsersMeta"));
        assert!(code.contains("pub data: Vec<Users>"));
    }

    #[test]
    fn test_paginate_count_sql() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("SELECT COUNT(*) FROM users"));
    }

    #[test]
    fn test_paginate_sql_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("LIMIT $1 OFFSET $2"));
    }

    #[test]
    fn test_paginate_sql_mysql() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("LIMIT ? OFFSET ?"));
    }

    // --- get ---

    #[test]
    fn test_get_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        // "pub async fn get" alone would also match `get_all`.
        assert!(code.contains("pub async fn get("));
    }

    #[test]
    fn test_get_returns_option() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("Option<Users>"));
    }

    #[test]
    fn test_get_where_pk_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("WHERE id = $1"));
    }

    #[test]
    fn test_get_where_pk_mysql() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("WHERE id = ?"));
    }

    // --- insert ---

    #[test]
    fn test_insert_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn insert"));
    }

    #[test]
    fn test_insert_params_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct InsertUsersParams"));
    }

    #[test]
    fn test_insert_params_no_pk() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub name: String"));
        assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
    }

    #[test]
    fn test_insert_returning_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("INSERT INTO users (name, email) VALUES ($1, $2) RETURNING *"));
    }

    #[test]
    fn test_insert_returning_sqlite() {
        let code = render(&standard_entity(), DatabaseKind::Sqlite);
        assert!(code.contains("RETURNING *"));
    }

    #[test]
    fn test_insert_mysql_last_insert_id() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("INSERT INTO users (name, email) VALUES (?, ?)"));
        assert!(code.contains("LAST_INSERT_ID"));
    }

    // --- update ---

    #[test]
    fn test_update_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn update"));
    }

    #[test]
    fn test_update_params_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct UpdateUsersParams"));
    }

    #[test]
    fn test_update_params_all_cols() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_update_set_clause_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        // SET placeholders come first, then the key placeholders continue.
        assert!(code.contains("SET name = $1, email = $2 WHERE id = $3"));
    }

    #[test]
    fn test_update_returning_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("UPDATE users SET"));
        assert!(code.contains("RETURNING *"));
    }

    // --- delete ---

    #[test]
    fn test_delete_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn delete"));
    }

    #[test]
    fn test_delete_where_pk() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("DELETE FROM users WHERE id = $1"));
    }

    #[test]
    fn test_delete_returns_unit() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
    }

    // --- views are read-only ---

    fn view_entity() -> ParsedEntity {
        ParsedEntity { is_view: true, ..standard_entity() }
    }

    #[test]
    fn test_view_no_insert() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(!code.contains("pub async fn insert"));
    }

    #[test]
    fn test_view_no_update() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(!code.contains("pub async fn update"));
    }

    #[test]
    fn test_view_no_delete() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(!code.contains("pub async fn delete"));
    }

    #[test]
    fn test_view_has_get_all() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn get_all"));
    }

    #[test]
    fn test_view_has_paginate() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn paginate"));
    }

    #[test]
    fn test_view_has_get() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn get("));
    }

    // --- method selection ---

    #[test]
    fn test_only_get_all() {
        let m = Methods { get_all: true, ..Default::default() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(code.contains("pub async fn get_all"));
        assert!(!code.contains("pub async fn paginate"));
        assert!(!code.contains("pub async fn insert"));
    }

    #[test]
    fn test_without_get_all() {
        let m = Methods { get_all: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn get_all"));
    }

    #[test]
    fn test_without_paginate() {
        let m = Methods { paginate: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn paginate"));
        assert!(!code.contains("PaginateUsersParams"));
    }

    #[test]
    fn test_without_get() {
        let m = Methods { get: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(code.contains("pub async fn get_all"));
        let without_get_all = code.replace("get_all", "XXX");
        assert!(!without_get_all.contains("fn get("));
    }

    #[test]
    fn test_without_insert() {
        let m = Methods { insert: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn insert"));
        assert!(!code.contains("InsertUsersParams"));
    }

    #[test]
    fn test_without_update() {
        let m = Methods { update: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn update"));
        assert!(!code.contains("UpdateUsersParams"));
    }

    #[test]
    fn test_without_delete() {
        let m = Methods { delete: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn delete"));
    }

    #[test]
    fn test_empty_methods_no_methods() {
        let m = Methods::default();
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn get_all"));
        assert!(!code.contains("pub async fn paginate"));
        assert!(!code.contains("pub async fn insert"));
        assert!(!code.contains("pub async fn update"));
        assert!(!code.contains("pub async fn delete"));
    }

    // --- imports ---

    #[test]
    fn test_no_pool_import() {
        let imports = imports_of(&standard_entity());
        assert!(!imports.iter().any(|i| i.contains("PgPool")));
    }

    #[test]
    fn test_imports_contain_entity() {
        let imports = imports_of(&standard_entity());
        assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
    }

    // --- column renames ---

    #[test]
    fn test_renamed_column_in_sql() {
        let entity = ParsedEntity {
            struct_name: "Connector".to_string(),
            table_name: "connector".to_string(),
            schema_name: None,
            is_view: false,
            fields: vec![
                make_field("id", "id", "i32", false, true),
                make_field("connector_type", "type", "String", false, false),
            ],
            imports: vec![],
        };
        let code = render(&entity, DatabaseKind::Postgres);
        // SQL must use the column name `type`; a bare `contains("type")` would
        // also match the Rust field name `connector_type`.
        assert!(code.contains("INSERT INTO connector (type) VALUES ($1) RETURNING *"));
        assert!(code.contains("UPDATE connector SET type = $1 WHERE id = $2 RETURNING *"));
        assert!(code.contains("pub connector_type: String"));
    }

    // --- tables without a primary key ---

    #[test]
    fn test_no_pk_no_get() {
        let entity = ParsedEntity {
            struct_name: "Logs".to_string(),
            table_name: "logs".to_string(),
            schema_name: None,
            is_view: false,
            fields: vec![
                make_field("message", "message", "String", false, false),
                make_field("ts", "ts", "String", false, false),
            ],
            imports: vec![],
        };
        let code = render(&entity, DatabaseKind::Postgres);
        assert!(code.contains("pub async fn get_all"));
        let without_get_all = code.replace("get_all", "XXX");
        assert!(!without_get_all.contains("fn get("));
    }

    #[test]
    fn test_no_pk_no_delete() {
        let entity = ParsedEntity {
            struct_name: "Logs".to_string(),
            table_name: "logs".to_string(),
            schema_name: None,
            is_view: false,
            fields: vec![
                make_field("message", "message", "String", false, false),
            ],
            imports: vec![],
        };
        let code = render(&entity, DatabaseKind::Postgres);
        assert!(!code.contains("pub async fn delete"));
    }

    // --- derives ---

    #[test]
    fn test_param_structs_have_default() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("Default"));
    }

    // --- entity imports ---

    #[test]
    fn test_entity_imports_forwarded() {
        let entity = ParsedEntity {
            struct_name: "Users".to_string(),
            table_name: "users".to_string(),
            schema_name: None,
            is_view: false,
            fields: vec![
                make_field("id", "id", "Uuid", false, true),
                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
            ],
            imports: vec![
                "use chrono::{DateTime, Utc};".to_string(),
                "use uuid::Uuid;".to_string(),
            ],
        };
        let imports = imports_of(&entity);
        assert!(imports.iter().any(|i| i.contains("chrono")));
        assert!(imports.iter().any(|i| i.contains("uuid")));
    }

    #[test]
    fn test_entity_imports_empty_when_no_imports() {
        let imports = imports_of(&standard_entity());
        assert!(!imports.iter().any(|i| i.contains("chrono")));
        assert!(!imports.iter().any(|i| i.contains("uuid")));
    }

    // --- query macro mode ---

    #[test]
    fn test_macro_get_all() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!"));
        assert!(!code.contains("query_as::<"));
    }

    #[test]
    fn test_macro_paginate() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!"));
        assert!(code.contains("per_page, offset"));
    }

    #[test]
    fn test_macro_get() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!(Users"));
        // Only `get` uses this exact statement in Postgres macro mode.
        assert!(code.contains("\"SELECT * FROM users WHERE id = $1\""));
    }

    #[test]
    fn test_macro_insert_pg() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!(Users"));
        assert!(code.contains("params.name"));
        assert!(code.contains("params.email"));
    }

    #[test]
    fn test_macro_insert_mysql() {
        let code = render_macro(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("query!"));
        assert!(code.contains("query_scalar!"));
    }

    #[test]
    fn test_macro_update() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!(Users"));
        assert!(code.contains("params.name"));
        assert!(code.contains("params.id"));
    }

    #[test]
    fn test_macro_delete() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query!"));
    }

    #[test]
    fn test_macro_no_bind_calls() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(!code.contains(".bind("));
    }

    #[test]
    fn test_function_style_uses_bind() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains(".bind("));
        assert!(!code.contains("query_as!("));
        assert!(!code.contains("query!("));
    }
}