1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10 entity: &ParsedEntity,
11 db_kind: DatabaseKind,
12 entity_module_path: &str,
13 methods: &Methods,
14 query_macro: bool,
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17
18 let entity_ident = format_ident!("{}", entity.struct_name);
19 let repo_name = format!("{}Repository", entity.struct_name);
20 let repo_ident = format_ident!("{}", repo_name);
21
22 let table_name = &entity.table_name;
23
24 let pool_type = pool_type_tokens(db_kind);
26
27 imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
29
30 for imp in &entity.imports {
32 imports.insert(imp.clone());
33 }
34
35 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
37
38 let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
40
41 let is_view = entity.is_view;
42
43 let mut method_tokens = Vec::new();
45 let mut param_structs = Vec::new();
46
47 if methods.get_all {
49 let sql = format!("SELECT * FROM {}", table_name);
50 let method = if query_macro {
51 quote! {
52 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
53 sqlx::query_as!(#entity_ident, #sql)
54 .fetch_all(&self.pool)
55 .await
56 }
57 }
58 } else {
59 quote! {
60 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
61 sqlx::query_as::<_, #entity_ident>(#sql)
62 .fetch_all(&self.pool)
63 .await
64 }
65 }
66 };
67 method_tokens.push(method);
68 }
69
70 if methods.paginate {
72 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
73 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
74 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
75 let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
76 let sql = match db_kind {
77 DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
78 DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
79 };
80 let method = if query_macro {
81 quote! {
82 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
83 let total: i64 = sqlx::query_scalar!(#count_sql)
84 .fetch_one(&self.pool)
85 .await?
86 .unwrap_or(0);
87 let per_page = params.per_page;
88 let current_page = params.page;
89 let last_page = (total + per_page - 1) / per_page;
90 let offset = (current_page - 1) * per_page;
91 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
92 .fetch_all(&self.pool)
93 .await?;
94 Ok(#paginated_ident {
95 meta: #pagination_meta_ident {
96 total,
97 per_page,
98 current_page,
99 last_page,
100 first_page: 1,
101 },
102 data,
103 })
104 }
105 }
106 } else {
107 quote! {
108 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
109 let total: i64 = sqlx::query_scalar(#count_sql)
110 .fetch_one(&self.pool)
111 .await?;
112 let per_page = params.per_page;
113 let current_page = params.page;
114 let last_page = (total + per_page - 1) / per_page;
115 let offset = (current_page - 1) * per_page;
116 let data = sqlx::query_as::<_, #entity_ident>(#sql)
117 .bind(per_page)
118 .bind(offset)
119 .fetch_all(&self.pool)
120 .await?;
121 Ok(#paginated_ident {
122 meta: #pagination_meta_ident {
123 total,
124 per_page,
125 current_page,
126 last_page,
127 first_page: 1,
128 },
129 data,
130 })
131 }
132 }
133 };
134 method_tokens.push(method);
135 param_structs.push(quote! {
136 #[derive(Debug, Clone, Default)]
137 pub struct #paginate_params_ident {
138 pub page: i64,
139 pub per_page: i64,
140 }
141 });
142 param_structs.push(quote! {
143 #[derive(Debug, Clone)]
144 pub struct #pagination_meta_ident {
145 pub total: i64,
146 pub per_page: i64,
147 pub current_page: i64,
148 pub last_page: i64,
149 pub first_page: i64,
150 }
151 });
152 param_structs.push(quote! {
153 #[derive(Debug, Clone)]
154 pub struct #paginated_ident {
155 pub meta: #pagination_meta_ident,
156 pub data: Vec<#entity_ident>,
157 }
158 });
159 }
160
161 if methods.get && !pk_fields.is_empty() {
163 let pk_params: Vec<TokenStream> = pk_fields
164 .iter()
165 .map(|f| {
166 let name = format_ident!("{}", f.rust_name);
167 let ty: TokenStream = f.inner_type.parse().unwrap();
168 quote! { #name: &#ty }
169 })
170 .collect();
171
172 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
173 let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
174
175 let binds: Vec<TokenStream> = pk_fields
176 .iter()
177 .map(|f| {
178 let name = format_ident!("{}", f.rust_name);
179 quote! { .bind(#name) }
180 })
181 .collect();
182
183 let method = if query_macro {
184 let pk_arg_names: Vec<TokenStream> = pk_fields
185 .iter()
186 .map(|f| {
187 let name = format_ident!("{}", f.rust_name);
188 quote! { #name }
189 })
190 .collect();
191 quote! {
192 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
193 sqlx::query_as!(#entity_ident, #sql, #(#pk_arg_names),*)
194 .fetch_optional(&self.pool)
195 .await
196 }
197 }
198 } else {
199 quote! {
200 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
201 sqlx::query_as::<_, #entity_ident>(#sql)
202 #(#binds)*
203 .fetch_optional(&self.pool)
204 .await
205 }
206 }
207 };
208 method_tokens.push(method);
209 }
210
211 if !is_view && methods.insert && !non_pk_fields.is_empty() {
213 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
214
215 let insert_fields: Vec<TokenStream> = non_pk_fields
216 .iter()
217 .map(|f| {
218 let name = format_ident!("{}", f.rust_name);
219 let ty: TokenStream = f.rust_type.parse().unwrap();
220 quote! { pub #name: #ty, }
221 })
222 .collect();
223
224 let col_names: Vec<&str> = non_pk_fields.iter().map(|f| f.column_name.as_str()).collect();
225 let col_list = col_names.join(", ");
226 let placeholders = build_placeholders(non_pk_fields.len(), db_kind, 1);
227
228 let sql = match db_kind {
229 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
230 format!(
231 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
232 table_name, col_list, placeholders
233 )
234 }
235 DatabaseKind::Mysql => {
236 format!(
237 "INSERT INTO {} ({}) VALUES ({})",
238 table_name, col_list, placeholders
239 )
240 }
241 };
242
243 let binds: Vec<TokenStream> = non_pk_fields
244 .iter()
245 .map(|f| {
246 let name = format_ident!("{}", f.rust_name);
247 quote! { .bind(¶ms.#name) }
248 })
249 .collect();
250
251 let insert_method = build_insert_method_parsed(
252 &entity_ident,
253 &insert_params_ident,
254 &sql,
255 &binds,
256 db_kind,
257 table_name,
258 &pk_fields,
259 &non_pk_fields,
260 query_macro,
261 );
262 method_tokens.push(insert_method);
263
264 param_structs.push(quote! {
265 #[derive(Debug, Clone, Default)]
266 pub struct #insert_params_ident {
267 #(#insert_fields)*
268 }
269 });
270 }
271
272 if !is_view && methods.update && !pk_fields.is_empty() {
274 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
275
276 let update_fields: Vec<TokenStream> = entity
277 .fields
278 .iter()
279 .map(|f| {
280 let name = format_ident!("{}", f.rust_name);
281 let ty: TokenStream = f.rust_type.parse().unwrap();
282 quote! { pub #name: #ty, }
283 })
284 .collect();
285
286 let set_cols: Vec<String> = non_pk_fields
287 .iter()
288 .enumerate()
289 .map(|(i, f)| {
290 let p = placeholder(db_kind, i + 1);
291 format!("{} = {}", f.column_name, p)
292 })
293 .collect();
294 let set_clause = set_cols.join(", ");
295
296 let pk_start = non_pk_fields.len() + 1;
297 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
298
299 let sql = match db_kind {
300 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
301 format!(
302 "UPDATE {} SET {} WHERE {} RETURNING *",
303 table_name, set_clause, where_clause
304 )
305 }
306 DatabaseKind::Mysql => {
307 format!(
308 "UPDATE {} SET {} WHERE {}",
309 table_name, set_clause, where_clause
310 )
311 }
312 };
313
314 let mut all_binds: Vec<TokenStream> = non_pk_fields
316 .iter()
317 .map(|f| {
318 let name = format_ident!("{}", f.rust_name);
319 quote! { .bind(¶ms.#name) }
320 })
321 .collect();
322 for f in &pk_fields {
323 let name = format_ident!("{}", f.rust_name);
324 all_binds.push(quote! { .bind(¶ms.#name) });
325 }
326
327 let update_macro_args: Vec<TokenStream> = non_pk_fields
329 .iter()
330 .chain(pk_fields.iter())
331 .map(|f| {
332 let name = format_ident!("{}", f.rust_name);
333 quote! { params.#name }
334 })
335 .collect();
336
337 let update_method = if query_macro {
338 match db_kind {
339 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
340 quote! {
341 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
342 sqlx::query_as!(#entity_ident, #sql, #(#update_macro_args),*)
343 .fetch_one(&self.pool)
344 .await
345 }
346 }
347 }
348 DatabaseKind::Mysql => {
349 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
350 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
351 let pk_macro_args: Vec<TokenStream> = pk_fields
352 .iter()
353 .map(|f| {
354 let name = format_ident!("{}", f.rust_name);
355 quote! { params.#name }
356 })
357 .collect();
358 quote! {
359 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
360 sqlx::query!(#sql, #(#update_macro_args),*)
361 .execute(&self.pool)
362 .await?;
363 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
364 .fetch_one(&self.pool)
365 .await
366 }
367 }
368 }
369 }
370 } else {
371 match db_kind {
372 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
373 quote! {
374 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
375 sqlx::query_as::<_, #entity_ident>(#sql)
376 #(#all_binds)*
377 .fetch_one(&self.pool)
378 .await
379 }
380 }
381 }
382 DatabaseKind::Mysql => {
383 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
384 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
385 let pk_binds: Vec<TokenStream> = pk_fields
386 .iter()
387 .map(|f| {
388 let name = format_ident!("{}", f.rust_name);
389 quote! { .bind(¶ms.#name) }
390 })
391 .collect();
392 quote! {
393 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
394 sqlx::query(#sql)
395 #(#all_binds)*
396 .execute(&self.pool)
397 .await?;
398 sqlx::query_as::<_, #entity_ident>(#select_sql)
399 #(#pk_binds)*
400 .fetch_one(&self.pool)
401 .await
402 }
403 }
404 }
405 }
406 };
407 method_tokens.push(update_method);
408
409 param_structs.push(quote! {
410 #[derive(Debug, Clone, Default)]
411 pub struct #update_params_ident {
412 #(#update_fields)*
413 }
414 });
415 }
416
417 if !is_view && methods.delete && !pk_fields.is_empty() {
419 let pk_params: Vec<TokenStream> = pk_fields
420 .iter()
421 .map(|f| {
422 let name = format_ident!("{}", f.rust_name);
423 let ty: TokenStream = f.inner_type.parse().unwrap();
424 quote! { #name: &#ty }
425 })
426 .collect();
427
428 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
429 let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
430
431 let binds: Vec<TokenStream> = pk_fields
432 .iter()
433 .map(|f| {
434 let name = format_ident!("{}", f.rust_name);
435 quote! { .bind(#name) }
436 })
437 .collect();
438
439 let method = if query_macro {
440 let pk_arg_names: Vec<TokenStream> = pk_fields
441 .iter()
442 .map(|f| {
443 let name = format_ident!("{}", f.rust_name);
444 quote! { #name }
445 })
446 .collect();
447 quote! {
448 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
449 sqlx::query!(#sql, #(#pk_arg_names),*)
450 .execute(&self.pool)
451 .await?;
452 Ok(())
453 }
454 }
455 } else {
456 quote! {
457 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
458 sqlx::query(#sql)
459 #(#binds)*
460 .execute(&self.pool)
461 .await?;
462 Ok(())
463 }
464 }
465 };
466 method_tokens.push(method);
467 }
468
469 let tokens = quote! {
470 #(#param_structs)*
471
472 pub struct #repo_ident {
473 pool: #pool_type,
474 }
475
476 impl #repo_ident {
477 pub fn new(pool: #pool_type) -> Self {
478 Self { pool }
479 }
480
481 #(#method_tokens)*
482 }
483 };
484
485 (tokens, imports)
486}
487
488fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
489 match db_kind {
490 DatabaseKind::Postgres => quote! { sqlx::PgPool },
491 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
492 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
493 }
494}
495
496fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
497 match db_kind {
498 DatabaseKind::Postgres => format!("${}", index),
499 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
500 }
501}
502
503fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
504 (0..count)
505 .map(|i| placeholder(db_kind, start + i))
506 .collect::<Vec<_>>()
507 .join(", ")
508}
509
510fn build_where_clause_parsed(
511 pk_fields: &[&ParsedField],
512 db_kind: DatabaseKind,
513 start_index: usize,
514) -> String {
515 pk_fields
516 .iter()
517 .enumerate()
518 .map(|(i, f)| {
519 let p = placeholder(db_kind, start_index + i);
520 format!("{} = {}", f.column_name, p)
521 })
522 .collect::<Vec<_>>()
523 .join(" AND ")
524}
525
526#[allow(clippy::too_many_arguments)]
527fn build_insert_method_parsed(
528 entity_ident: &proc_macro2::Ident,
529 insert_params_ident: &proc_macro2::Ident,
530 sql: &str,
531 binds: &[TokenStream],
532 db_kind: DatabaseKind,
533 table_name: &str,
534 pk_fields: &[&ParsedField],
535 non_pk_fields: &[&ParsedField],
536 query_macro: bool,
537) -> TokenStream {
538 if query_macro {
539 let macro_args: Vec<TokenStream> = non_pk_fields
540 .iter()
541 .map(|f| {
542 let name = format_ident!("{}", f.rust_name);
543 quote! { params.#name }
544 })
545 .collect();
546
547 match db_kind {
548 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
549 quote! {
550 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
551 sqlx::query_as!(#entity_ident, #sql, #(#macro_args),*)
552 .fetch_one(&self.pool)
553 .await
554 }
555 }
556 }
557 DatabaseKind::Mysql => {
558 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
559 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
560 quote! {
561 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
562 sqlx::query!(#sql, #(#macro_args),*)
563 .execute(&self.pool)
564 .await?;
565 let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
566 .fetch_one(&self.pool)
567 .await?;
568 sqlx::query_as!(#entity_ident, #select_sql, id)
569 .fetch_one(&self.pool)
570 .await
571 }
572 }
573 }
574 }
575 } else {
576 match db_kind {
577 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
578 quote! {
579 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
580 sqlx::query_as::<_, #entity_ident>(#sql)
581 #(#binds)*
582 .fetch_one(&self.pool)
583 .await
584 }
585 }
586 }
587 DatabaseKind::Mysql => {
588 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
589 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
590 quote! {
591 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
592 sqlx::query(#sql)
593 #(#binds)*
594 .execute(&self.pool)
595 .await?;
596 let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
597 .fetch_one(&self.pool)
598 .await?;
599 sqlx::query_as::<_, #entity_ident>(#select_sql)
600 .bind(id)
601 .fetch_one(&self.pool)
602 .await
603 }
604 }
605 }
606 }
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::Methods;
    use crate::codegen::parse_and_format;

    /// Builds a `ParsedField`. For nullable fields `inner_type` is `rust_type` with an outer
    /// `Option<...>` stripped, mirroring what the entity parser is expected to produce.
    fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
        let inner_type = if nullable {
            rust_type
                .strip_prefix("Option<")
                .and_then(|s| s.strip_suffix('>'))
                .unwrap_or(rust_type)
                .to_string()
        } else {
            rust_type.to_string()
        };
        ParsedField {
            rust_name: rust_name.to_string(),
            column_name: column_name.to_string(),
            rust_type: rust_type.to_string(),
            is_nullable: nullable,
            inner_type,
            is_primary_key: is_pk,
        }
    }

    /// `users(id PK, name, email NULL)` — the fixture most tests run against.
    fn standard_entity() -> ParsedEntity {
        ParsedEntity {
            struct_name: "Users".to_string(),
            table_name: "users".to_string(),
            is_view: false,
            fields: vec![
                make_field("id", "id", "i32", false, true),
                make_field("name", "name", "String", false, false),
                make_field("email", "email", "Option<String>", true, false),
            ],
            imports: vec![],
        }
    }

    /// `memberships(tenant_id PK, user_id PK, role)` — a composite-key fixture.
    fn composite_pk_entity() -> ParsedEntity {
        ParsedEntity {
            struct_name: "Memberships".to_string(),
            table_name: "memberships".to_string(),
            is_view: false,
            fields: vec![
                make_field("tenant_id", "tenant_id", "i32", false, true),
                make_field("user_id", "user_id", "i32", false, true),
                make_field("role", "role", "String", false, false),
            ],
            imports: vec![],
        }
    }

    /// Generates and pretty-prints the repository for `entity`.
    fn render_with(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods, query_macro: bool) -> String {
        let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, query_macro);
        parse_and_format(&tokens)
    }

    /// All methods, runtime `sqlx::query*` functions. (Not named `gen`: reserved in Rust 2024.)
    fn render(entity: &ParsedEntity, db: DatabaseKind) -> String {
        render_with(entity, db, &Methods::all(), false)
    }

    /// All methods, compile-time `sqlx::query*!` macros.
    fn render_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
        render_with(entity, db, &Methods::all(), true)
    }

    /// Selected methods, runtime `sqlx::query*` functions.
    fn render_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
        render_with(entity, db, methods, false)
    }

    /// Import set produced for `entity` (Postgres, all methods).
    fn imports_for(entity: &ParsedEntity) -> BTreeSet<String> {
        let (_, imports) = generate_crud_from_parsed(entity, DatabaseKind::Postgres, "crate::models::users", &Methods::all(), false);
        imports
    }

    // --- repository struct ---

    #[test]
    fn test_repo_struct_name() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct UsersRepository"));
    }

    #[test]
    fn test_repo_new_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub fn new("));
    }

    #[test]
    fn test_repo_pool_field_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
    }

    #[test]
    fn test_repo_pool_field_mysql() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("MySqlPool"));
    }

    #[test]
    fn test_repo_pool_field_sqlite() {
        let code = render(&standard_entity(), DatabaseKind::Sqlite);
        assert!(code.contains("SqlitePool"));
    }

    // --- get_all ---

    #[test]
    fn test_get_all_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn get_all"));
    }

    #[test]
    fn test_get_all_returns_vec() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("Vec<Users>"));
    }

    #[test]
    fn test_get_all_sql() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("SELECT * FROM users"));
    }

    // --- paginate ---

    #[test]
    fn test_paginate_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn paginate"));
    }

    #[test]
    fn test_paginate_params_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct PaginateUsersParams"));
    }

    #[test]
    fn test_paginate_params_fields() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub page: i64"));
        assert!(code.contains("pub per_page: i64"));
    }

    #[test]
    fn test_paginate_returns_paginated() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("PaginatedUsers"));
        assert!(code.contains("PaginationUsersMeta"));
    }

    #[test]
    fn test_paginate_meta_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct PaginationUsersMeta"));
        assert!(code.contains("pub total: i64"));
        assert!(code.contains("pub last_page: i64"));
        assert!(code.contains("pub first_page: i64"));
        assert!(code.contains("pub current_page: i64"));
    }

    #[test]
    fn test_paginate_data_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct PaginatedUsers"));
        assert!(code.contains("pub meta: PaginationUsersMeta"));
        assert!(code.contains("pub data: Vec<Users>"));
    }

    #[test]
    fn test_paginate_count_sql() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("SELECT COUNT(*) FROM users"));
    }

    #[test]
    fn test_paginate_sql_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("LIMIT $1 OFFSET $2"));
    }

    #[test]
    fn test_paginate_sql_mysql() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("LIMIT ? OFFSET ?"));
    }

    #[test]
    fn test_paginate_sql_sqlite() {
        let code = render(&standard_entity(), DatabaseKind::Sqlite);
        assert!(code.contains("LIMIT ? OFFSET ?"));
    }

    // --- get ---

    #[test]
    fn test_get_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        // `fn get(` — a bare `fn get` prefix would also match `fn get_all`.
        assert!(code.contains("pub async fn get("));
    }

    #[test]
    fn test_get_returns_option() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("Option<Users>"));
    }

    #[test]
    fn test_get_where_pk_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("WHERE id = $1"));
    }

    #[test]
    fn test_get_where_pk_mysql() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("WHERE id = ?"));
    }

    #[test]
    fn test_get_where_composite_pk_pg() {
        let code = render(&composite_pk_entity(), DatabaseKind::Postgres);
        assert!(code.contains("SELECT * FROM memberships WHERE tenant_id = $1 AND user_id = $2"));
        assert!(code.contains("DELETE FROM memberships WHERE tenant_id = $1 AND user_id = $2"));
        assert!(code.contains("UPDATE memberships SET role = $1 WHERE tenant_id = $2 AND user_id = $3"));
    }

    // --- insert ---

    #[test]
    fn test_insert_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn insert"));
    }

    #[test]
    fn test_insert_params_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct InsertUsersParams"));
    }

    #[test]
    fn test_insert_params_no_pk() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub name: String"));
        assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
    }

    #[test]
    fn test_insert_returning_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("RETURNING *"));
    }

    #[test]
    fn test_insert_returning_sqlite() {
        let code = render(&standard_entity(), DatabaseKind::Sqlite);
        assert!(code.contains("RETURNING *"));
    }

    #[test]
    fn test_insert_mysql_last_insert_id() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("LAST_INSERT_ID"));
    }

    // --- update ---

    #[test]
    fn test_update_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn update"));
    }

    #[test]
    fn test_update_params_struct() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub struct UpdateUsersParams"));
    }

    #[test]
    fn test_update_params_all_cols() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_update_set_clause_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("SET name = $1"));
        assert!(code.contains("WHERE id = $3"));
    }

    #[test]
    fn test_update_returning_pg() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("UPDATE users SET"));
        assert!(code.contains("RETURNING *"));
    }

    #[test]
    fn test_update_mysql_no_returning() {
        let code = render(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("\"UPDATE users SET name = ?, email = ? WHERE id = ?\""));
        assert!(!code.contains("RETURNING"));
    }

    // --- delete ---

    #[test]
    fn test_delete_method() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn delete"));
    }

    #[test]
    fn test_delete_where_pk() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("DELETE FROM users WHERE id = $1"));
    }

    #[test]
    fn test_delete_returns_unit() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
    }

    // --- views are read-only ---

    fn view_entity() -> ParsedEntity {
        ParsedEntity { is_view: true, ..standard_entity() }
    }

    #[test]
    fn test_view_no_insert() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(!code.contains("pub async fn insert"));
    }

    #[test]
    fn test_view_no_update() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(!code.contains("pub async fn update"));
    }

    #[test]
    fn test_view_no_delete() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(!code.contains("pub async fn delete"));
    }

    #[test]
    fn test_view_no_write_param_structs() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(!code.contains("InsertUsersParams"));
        assert!(!code.contains("UpdateUsersParams"));
    }

    #[test]
    fn test_view_has_get_all() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn get_all"));
    }

    #[test]
    fn test_view_has_paginate() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn paginate"));
    }

    #[test]
    fn test_view_has_get() {
        let code = render(&view_entity(), DatabaseKind::Postgres);
        assert!(code.contains("pub async fn get("));
    }

    // --- method selection ---

    #[test]
    fn test_only_get_all() {
        let m = Methods { get_all: true, ..Default::default() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(code.contains("pub async fn get_all"));
        assert!(!code.contains("pub async fn paginate"));
        assert!(!code.contains("pub async fn insert"));
    }

    #[test]
    fn test_without_get_all() {
        let m = Methods { get_all: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn get_all"));
    }

    #[test]
    fn test_without_paginate() {
        let m = Methods { paginate: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn paginate"));
        assert!(!code.contains("PaginateUsersParams"));
    }

    #[test]
    fn test_without_get() {
        let m = Methods { get: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(code.contains("pub async fn get_all"));
        assert!(!code.contains("fn get("));
    }

    #[test]
    fn test_without_insert() {
        let m = Methods { insert: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn insert"));
        assert!(!code.contains("InsertUsersParams"));
    }

    #[test]
    fn test_without_update() {
        let m = Methods { update: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn update"));
        assert!(!code.contains("UpdateUsersParams"));
    }

    #[test]
    fn test_without_delete() {
        let m = Methods { delete: false, ..Methods::all() };
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn delete"));
    }

    #[test]
    fn test_empty_methods_no_methods() {
        let m = Methods::default();
        let code = render_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
        assert!(!code.contains("pub async fn get_all"));
        assert!(!code.contains("pub async fn paginate"));
        assert!(!code.contains("pub async fn insert"));
        assert!(!code.contains("pub async fn update"));
        assert!(!code.contains("pub async fn delete"));
    }

    // --- imports ---

    #[test]
    fn test_no_pool_import() {
        let imports = imports_for(&standard_entity());
        assert!(!imports.iter().any(|i| i.contains("PgPool")));
    }

    #[test]
    fn test_imports_contain_entity() {
        let imports = imports_for(&standard_entity());
        assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
    }

    #[test]
    fn test_entity_imports_forwarded() {
        let entity = ParsedEntity {
            struct_name: "Users".to_string(),
            table_name: "users".to_string(),
            is_view: false,
            fields: vec![
                make_field("id", "id", "Uuid", false, true),
                make_field("created_at", "created_at", "DateTime<Utc>", false, false),
            ],
            imports: vec![
                "use chrono::{DateTime, Utc};".to_string(),
                "use uuid::Uuid;".to_string(),
            ],
        };
        let imports = imports_for(&entity);
        assert!(imports.iter().any(|i| i.contains("chrono")));
        assert!(imports.iter().any(|i| i.contains("uuid")));
    }

    #[test]
    fn test_entity_imports_empty_when_no_imports() {
        let imports = imports_for(&standard_entity());
        assert!(!imports.iter().any(|i| i.contains("chrono")));
        assert!(!imports.iter().any(|i| i.contains("uuid")));
    }

    // --- column renames / missing keys ---

    #[test]
    fn test_renamed_column_in_sql() {
        let entity = ParsedEntity {
            struct_name: "Connector".to_string(),
            table_name: "connector".to_string(),
            is_view: false,
            fields: vec![
                make_field("id", "id", "i32", false, true),
                make_field("connector_type", "type", "String", false, false),
            ],
            imports: vec![],
        };
        let code = render(&entity, DatabaseKind::Postgres);
        // SQL uses the column name, the params struct uses the Rust field name.
        assert!(code.contains("INSERT INTO connector (type) VALUES ($1) RETURNING *"));
        assert!(code.contains("UPDATE connector SET type = $1 WHERE id = $2 RETURNING *"));
        assert!(!code.contains("(connector_type)"));
        assert!(code.contains("pub connector_type: String"));
    }

    #[test]
    fn test_no_pk_no_get() {
        let entity = ParsedEntity {
            struct_name: "Logs".to_string(),
            table_name: "logs".to_string(),
            is_view: false,
            fields: vec![
                make_field("message", "message", "String", false, false),
                make_field("ts", "ts", "String", false, false),
            ],
            imports: vec![],
        };
        let code = render(&entity, DatabaseKind::Postgres);
        assert!(code.contains("pub async fn get_all"));
        assert!(!code.contains("fn get("));
    }

    #[test]
    fn test_no_pk_no_delete() {
        let entity = ParsedEntity {
            struct_name: "Logs".to_string(),
            table_name: "logs".to_string(),
            is_view: false,
            fields: vec![make_field("message", "message", "String", false, false)],
            imports: vec![],
        };
        let code = render(&entity, DatabaseKind::Postgres);
        assert!(!code.contains("pub async fn delete"));
    }

    #[test]
    fn test_param_structs_have_default() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("#[derive(Debug, Clone, Default)]"));
    }

    // --- query! macro style ---

    #[test]
    fn test_macro_get_all() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!"));
        assert!(!code.contains("query_as::<"));
    }

    #[test]
    fn test_macro_paginate() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!"));
        assert!(code.contains("per_page, offset"));
    }

    #[test]
    fn test_macro_get() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!(Users"));
    }

    #[test]
    fn test_macro_insert_pg() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!(Users"));
        assert!(code.contains("params.name"));
        assert!(code.contains("params.email"));
    }

    #[test]
    fn test_macro_insert_mysql() {
        let code = render_macro(&standard_entity(), DatabaseKind::Mysql);
        assert!(code.contains("query!"));
        assert!(code.contains("query_scalar!"));
    }

    #[test]
    fn test_macro_update() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains("query_as!(Users"));
        assert!(code.contains("params.name"));
        assert!(code.contains("params.id"));
    }

    #[test]
    fn test_macro_delete() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        // On Postgres only `delete` uses the plain `query!` macro.
        assert!(code.contains("query!(\"DELETE FROM users"));
    }

    #[test]
    fn test_macro_no_bind_calls() {
        let code = render_macro(&standard_entity(), DatabaseKind::Postgres);
        assert!(!code.contains(".bind("));
    }

    #[test]
    fn test_function_style_uses_bind() {
        let code = render(&standard_entity(), DatabaseKind::Postgres);
        assert!(code.contains(".bind("));
        assert!(!code.contains("query_as!("));
        assert!(!code.contains("query!("));
    }
}