1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10 entity: &ParsedEntity,
11 db_kind: DatabaseKind,
12 entity_module_path: &str,
13 methods: &Methods,
14 query_macro: bool,
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17
18 let entity_ident = format_ident!("{}", entity.struct_name);
19 let repo_name = format!("{}Repository", entity.struct_name);
20 let repo_ident = format_ident!("{}", repo_name);
21
22 let table_name = &entity.table_name;
23
24 let pool_type = pool_type_tokens(db_kind);
26
27 let has_sql_array = entity.fields.iter().any(|f| f.is_sql_array);
31 let use_macro = query_macro && !has_sql_array;
32
33 imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
35
36 let entity_parent = entity_module_path
40 .rsplit_once("::")
41 .map(|(parent, _)| parent)
42 .unwrap_or(entity_module_path);
43 for imp in &entity.imports {
44 if let Some(rest) = imp.strip_prefix("use super::") {
45 imports.insert(format!("use {}::{}", entity_parent, rest));
46 } else {
47 imports.insert(imp.clone());
48 }
49 }
50
51 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
53
54 let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
56
57 let is_view = entity.is_view;
58
59 let mut method_tokens = Vec::new();
61 let mut param_structs = Vec::new();
62
63 if methods.get_all {
65 let sql = format!("SELECT * FROM {}", table_name);
66 let method = if use_macro {
67 quote! {
68 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
69 sqlx::query_as!(#entity_ident, #sql)
70 .fetch_all(&self.pool)
71 .await
72 }
73 }
74 } else {
75 quote! {
76 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
77 sqlx::query_as::<_, #entity_ident>(#sql)
78 .fetch_all(&self.pool)
79 .await
80 }
81 }
82 };
83 method_tokens.push(method);
84 }
85
86 if methods.paginate {
88 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
89 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
90 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
91 let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
92 let sql = match db_kind {
93 DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
94 DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
95 };
96 let method = if use_macro {
97 quote! {
98 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
99 let total: i64 = sqlx::query_scalar!(#count_sql)
100 .fetch_one(&self.pool)
101 .await?
102 .unwrap_or(0);
103 let per_page = params.per_page;
104 let current_page = params.page;
105 let last_page = (total + per_page - 1) / per_page;
106 let offset = (current_page - 1) * per_page;
107 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
108 .fetch_all(&self.pool)
109 .await?;
110 Ok(#paginated_ident {
111 meta: #pagination_meta_ident {
112 total,
113 per_page,
114 current_page,
115 last_page,
116 first_page: 1,
117 },
118 data,
119 })
120 }
121 }
122 } else {
123 quote! {
124 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
125 let total: i64 = sqlx::query_scalar(#count_sql)
126 .fetch_one(&self.pool)
127 .await?;
128 let per_page = params.per_page;
129 let current_page = params.page;
130 let last_page = (total + per_page - 1) / per_page;
131 let offset = (current_page - 1) * per_page;
132 let data = sqlx::query_as::<_, #entity_ident>(#sql)
133 .bind(per_page)
134 .bind(offset)
135 .fetch_all(&self.pool)
136 .await?;
137 Ok(#paginated_ident {
138 meta: #pagination_meta_ident {
139 total,
140 per_page,
141 current_page,
142 last_page,
143 first_page: 1,
144 },
145 data,
146 })
147 }
148 }
149 };
150 method_tokens.push(method);
151 param_structs.push(quote! {
152 #[derive(Debug, Clone, Default)]
153 pub struct #paginate_params_ident {
154 pub page: i64,
155 pub per_page: i64,
156 }
157 });
158 param_structs.push(quote! {
159 #[derive(Debug, Clone)]
160 pub struct #pagination_meta_ident {
161 pub total: i64,
162 pub per_page: i64,
163 pub current_page: i64,
164 pub last_page: i64,
165 pub first_page: i64,
166 }
167 });
168 param_structs.push(quote! {
169 #[derive(Debug, Clone)]
170 pub struct #paginated_ident {
171 pub meta: #pagination_meta_ident,
172 pub data: Vec<#entity_ident>,
173 }
174 });
175 }
176
177 if methods.get && !pk_fields.is_empty() {
179 let pk_params: Vec<TokenStream> = pk_fields
180 .iter()
181 .map(|f| {
182 let name = format_ident!("{}", f.rust_name);
183 let ty: TokenStream = f.inner_type.parse().unwrap();
184 quote! { #name: &#ty }
185 })
186 .collect();
187
188 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
189 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
190 let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
191 let sql_macro = format!("SELECT * FROM {} WHERE {}", table_name, where_clause_cast);
192
193 let binds: Vec<TokenStream> = pk_fields
194 .iter()
195 .map(|f| {
196 let name = format_ident!("{}", f.rust_name);
197 quote! { .bind(#name) }
198 })
199 .collect();
200
201 let method = if use_macro {
202 let pk_arg_names: Vec<TokenStream> = pk_fields
203 .iter()
204 .map(|f| {
205 let name = format_ident!("{}", f.rust_name);
206 quote! { #name }
207 })
208 .collect();
209 quote! {
210 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
211 sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
212 .fetch_optional(&self.pool)
213 .await
214 }
215 }
216 } else {
217 quote! {
218 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
219 sqlx::query_as::<_, #entity_ident>(#sql)
220 #(#binds)*
221 .fetch_optional(&self.pool)
222 .await
223 }
224 }
225 };
226 method_tokens.push(method);
227 }
228
229 if !is_view && methods.insert && !non_pk_fields.is_empty() {
231 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
232
233 let insert_fields: Vec<TokenStream> = non_pk_fields
234 .iter()
235 .map(|f| {
236 let name = format_ident!("{}", f.rust_name);
237 let ty: TokenStream = f.rust_type.parse().unwrap();
238 quote! { pub #name: #ty, }
239 })
240 .collect();
241
242 let col_names: Vec<&str> = non_pk_fields.iter().map(|f| f.column_name.as_str()).collect();
243 let col_list = col_names.join(", ");
244 let placeholders = build_placeholders(non_pk_fields.len(), db_kind, 1);
246 let placeholders_cast = build_placeholders_with_cast(&non_pk_fields, db_kind, 1, true);
247
248 let build_insert_sql = |ph: &str| match db_kind {
249 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
250 format!(
251 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
252 table_name, col_list, ph
253 )
254 }
255 DatabaseKind::Mysql => {
256 format!(
257 "INSERT INTO {} ({}) VALUES ({})",
258 table_name, col_list, ph
259 )
260 }
261 };
262 let sql = build_insert_sql(&placeholders);
263 let sql_macro = build_insert_sql(&placeholders_cast);
264
265 let binds: Vec<TokenStream> = non_pk_fields
266 .iter()
267 .map(|f| {
268 let name = format_ident!("{}", f.rust_name);
269 quote! { .bind(¶ms.#name) }
270 })
271 .collect();
272
273 let insert_method = build_insert_method_parsed(
274 &entity_ident,
275 &insert_params_ident,
276 &sql,
277 &sql_macro,
278 &binds,
279 db_kind,
280 table_name,
281 &pk_fields,
282 &non_pk_fields,
283 use_macro,
284 );
285 method_tokens.push(insert_method);
286
287 param_structs.push(quote! {
288 #[derive(Debug, Clone, Default)]
289 pub struct #insert_params_ident {
290 #(#insert_fields)*
291 }
292 });
293 }
294
295 if !is_view && methods.update && !pk_fields.is_empty() {
297 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
298
299 let update_fields: Vec<TokenStream> = entity
300 .fields
301 .iter()
302 .map(|f| {
303 let name = format_ident!("{}", f.rust_name);
304 let ty: TokenStream = f.rust_type.parse().unwrap();
305 quote! { pub #name: #ty, }
306 })
307 .collect();
308
309 let set_cols: Vec<String> = non_pk_fields
310 .iter()
311 .enumerate()
312 .map(|(i, f)| {
313 let p = placeholder(db_kind, i + 1);
314 format!("{} = {}", f.column_name, p)
315 })
316 .collect();
317 let set_clause = set_cols.join(", ");
318
319 let set_cols_cast: Vec<String> = non_pk_fields
321 .iter()
322 .enumerate()
323 .map(|(i, f)| {
324 let p = placeholder_with_cast(db_kind, i + 1, f);
325 format!("{} = {}", f.column_name, p)
326 })
327 .collect();
328 let set_clause_cast = set_cols_cast.join(", ");
329
330 let pk_start = non_pk_fields.len() + 1;
331 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
332 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
333
334 let build_update_sql = |sc: &str, wc: &str| match db_kind {
335 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
336 format!(
337 "UPDATE {} SET {} WHERE {} RETURNING *",
338 table_name, sc, wc
339 )
340 }
341 DatabaseKind::Mysql => {
342 format!(
343 "UPDATE {} SET {} WHERE {}",
344 table_name, sc, wc
345 )
346 }
347 };
348 let sql = build_update_sql(&set_clause, &where_clause);
349 let sql_macro = build_update_sql(&set_clause_cast, &where_clause_cast);
350
351 let mut all_binds: Vec<TokenStream> = non_pk_fields
353 .iter()
354 .map(|f| {
355 let name = format_ident!("{}", f.rust_name);
356 quote! { .bind(¶ms.#name) }
357 })
358 .collect();
359 for f in &pk_fields {
360 let name = format_ident!("{}", f.rust_name);
361 all_binds.push(quote! { .bind(¶ms.#name) });
362 }
363
364 let update_macro_args: Vec<TokenStream> = non_pk_fields
366 .iter()
367 .chain(pk_fields.iter())
368 .map(|f| {
369 let name = format_ident!("{}", f.rust_name);
370 quote! { params.#name }
371 })
372 .collect();
373
374 let update_method = if use_macro {
375 match db_kind {
376 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
377 quote! {
378 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
379 sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
380 .fetch_one(&self.pool)
381 .await
382 }
383 }
384 }
385 DatabaseKind::Mysql => {
386 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
387 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
388 let pk_macro_args: Vec<TokenStream> = pk_fields
389 .iter()
390 .map(|f| {
391 let name = format_ident!("{}", f.rust_name);
392 quote! { params.#name }
393 })
394 .collect();
395 quote! {
396 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
397 sqlx::query!(#sql_macro, #(#update_macro_args),*)
398 .execute(&self.pool)
399 .await?;
400 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
401 .fetch_one(&self.pool)
402 .await
403 }
404 }
405 }
406 }
407 } else {
408 match db_kind {
409 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
410 quote! {
411 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
412 sqlx::query_as::<_, #entity_ident>(#sql)
413 #(#all_binds)*
414 .fetch_one(&self.pool)
415 .await
416 }
417 }
418 }
419 DatabaseKind::Mysql => {
420 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
421 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
422 let pk_binds: Vec<TokenStream> = pk_fields
423 .iter()
424 .map(|f| {
425 let name = format_ident!("{}", f.rust_name);
426 quote! { .bind(¶ms.#name) }
427 })
428 .collect();
429 quote! {
430 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
431 sqlx::query(#sql)
432 #(#all_binds)*
433 .execute(&self.pool)
434 .await?;
435 sqlx::query_as::<_, #entity_ident>(#select_sql)
436 #(#pk_binds)*
437 .fetch_one(&self.pool)
438 .await
439 }
440 }
441 }
442 }
443 };
444 method_tokens.push(update_method);
445
446 param_structs.push(quote! {
447 #[derive(Debug, Clone, Default)]
448 pub struct #update_params_ident {
449 #(#update_fields)*
450 }
451 });
452 }
453
454 if !is_view && methods.delete && !pk_fields.is_empty() {
456 let pk_params: Vec<TokenStream> = pk_fields
457 .iter()
458 .map(|f| {
459 let name = format_ident!("{}", f.rust_name);
460 let ty: TokenStream = f.inner_type.parse().unwrap();
461 quote! { #name: &#ty }
462 })
463 .collect();
464
465 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
466 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
467 let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
468 let sql_macro = format!("DELETE FROM {} WHERE {}", table_name, where_clause_cast);
469
470 let binds: Vec<TokenStream> = pk_fields
471 .iter()
472 .map(|f| {
473 let name = format_ident!("{}", f.rust_name);
474 quote! { .bind(#name) }
475 })
476 .collect();
477
478 let method = if query_macro {
479 let pk_arg_names: Vec<TokenStream> = pk_fields
480 .iter()
481 .map(|f| {
482 let name = format_ident!("{}", f.rust_name);
483 quote! { #name }
484 })
485 .collect();
486 quote! {
487 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
488 sqlx::query!(#sql_macro, #(#pk_arg_names),*)
489 .execute(&self.pool)
490 .await?;
491 Ok(())
492 }
493 }
494 } else {
495 quote! {
496 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
497 sqlx::query(#sql)
498 #(#binds)*
499 .execute(&self.pool)
500 .await?;
501 Ok(())
502 }
503 }
504 };
505 method_tokens.push(method);
506 }
507
508 let tokens = quote! {
509 #(#param_structs)*
510
511 pub struct #repo_ident {
512 pool: #pool_type,
513 }
514
515 impl #repo_ident {
516 pub fn new(pool: #pool_type) -> Self {
517 Self { pool }
518 }
519
520 #(#method_tokens)*
521 }
522 };
523
524 (tokens, imports)
525}
526
527fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
528 match db_kind {
529 DatabaseKind::Postgres => quote! { sqlx::PgPool },
530 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
531 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
532 }
533}
534
535fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
536 match db_kind {
537 DatabaseKind::Postgres => format!("${}", index),
538 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
539 }
540}
541
542fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
543 let base = placeholder(db_kind, index);
544 match (&field.sql_type, field.is_sql_array) {
545 (Some(t), true) => format!("{} as {}[]", base, t),
546 (Some(t), false) => format!("{} as {}", base, t),
547 (None, _) => base,
548 }
549}
550
551fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
552 (0..count)
553 .map(|i| placeholder(db_kind, start + i))
554 .collect::<Vec<_>>()
555 .join(", ")
556}
557
558fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
559 fields
560 .iter()
561 .enumerate()
562 .map(|(i, f)| {
563 if use_cast {
564 placeholder_with_cast(db_kind, start + i, f)
565 } else {
566 placeholder(db_kind, start + i)
567 }
568 })
569 .collect::<Vec<_>>()
570 .join(", ")
571}
572
573fn build_where_clause_parsed(
574 pk_fields: &[&ParsedField],
575 db_kind: DatabaseKind,
576 start_index: usize,
577) -> String {
578 pk_fields
579 .iter()
580 .enumerate()
581 .map(|(i, f)| {
582 let p = placeholder(db_kind, start_index + i);
583 format!("{} = {}", f.column_name, p)
584 })
585 .collect::<Vec<_>>()
586 .join(" AND ")
587}
588
589fn build_where_clause_cast(
590 pk_fields: &[&ParsedField],
591 db_kind: DatabaseKind,
592 start_index: usize,
593) -> String {
594 pk_fields
595 .iter()
596 .enumerate()
597 .map(|(i, f)| {
598 let p = placeholder_with_cast(db_kind, start_index + i, f);
599 format!("{} = {}", f.column_name, p)
600 })
601 .collect::<Vec<_>>()
602 .join(" AND ")
603}
604
605#[allow(clippy::too_many_arguments)]
606fn build_insert_method_parsed(
607 entity_ident: &proc_macro2::Ident,
608 insert_params_ident: &proc_macro2::Ident,
609 sql: &str,
610 sql_macro: &str,
611 binds: &[TokenStream],
612 db_kind: DatabaseKind,
613 table_name: &str,
614 pk_fields: &[&ParsedField],
615 non_pk_fields: &[&ParsedField],
616 use_macro: bool,
617) -> TokenStream {
618 if use_macro {
619 let macro_args: Vec<TokenStream> = non_pk_fields
620 .iter()
621 .map(|f| {
622 let name = format_ident!("{}", f.rust_name);
623 quote! { params.#name }
624 })
625 .collect();
626
627 match db_kind {
628 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
629 quote! {
630 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
631 sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
632 .fetch_one(&self.pool)
633 .await
634 }
635 }
636 }
637 DatabaseKind::Mysql => {
638 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
639 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
640 quote! {
641 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
642 sqlx::query!(#sql_macro, #(#macro_args),*)
643 .execute(&self.pool)
644 .await?;
645 let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
646 .fetch_one(&self.pool)
647 .await?;
648 sqlx::query_as!(#entity_ident, #select_sql, id)
649 .fetch_one(&self.pool)
650 .await
651 }
652 }
653 }
654 }
655 } else {
656 match db_kind {
657 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
658 quote! {
659 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
660 sqlx::query_as::<_, #entity_ident>(#sql)
661 #(#binds)*
662 .fetch_one(&self.pool)
663 .await
664 }
665 }
666 }
667 DatabaseKind::Mysql => {
668 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
669 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
670 quote! {
671 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
672 sqlx::query(#sql)
673 #(#binds)*
674 .execute(&self.pool)
675 .await?;
676 let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
677 .fetch_one(&self.pool)
678 .await?;
679 sqlx::query_as::<_, #entity_ident>(#select_sql)
680 .bind(id)
681 .fetch_one(&self.pool)
682 .await
683 }
684 }
685 }
686 }
687 }
688}
689
690#[cfg(test)]
691mod tests {
692 use super::*;
693 use crate::codegen::parse_and_format;
694 use crate::cli::Methods;
695
696 fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
697 let inner_type = if nullable {
698 rust_type
700 .strip_prefix("Option<")
701 .and_then(|s| s.strip_suffix('>'))
702 .unwrap_or(rust_type)
703 .to_string()
704 } else {
705 rust_type.to_string()
706 };
707 ParsedField {
708 rust_name: rust_name.to_string(),
709 column_name: column_name.to_string(),
710 rust_type: rust_type.to_string(),
711 is_nullable: nullable,
712 inner_type,
713 is_primary_key: is_pk,
714 sql_type: None,
715 is_sql_array: false,
716 }
717 }
718
719 fn standard_entity() -> ParsedEntity {
720 ParsedEntity {
721 struct_name: "Users".to_string(),
722 table_name: "users".to_string(),
723 schema_name: None,
724 is_view: false,
725 fields: vec![
726 make_field("id", "id", "i32", false, true),
727 make_field("name", "name", "String", false, false),
728 make_field("email", "email", "Option<String>", true, false),
729 ],
730 imports: vec![],
731 }
732 }
733
734 fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
735 let skip = Methods::all();
736 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false);
737 parse_and_format(&tokens)
738 }
739
740 fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
741 let skip = Methods::all();
742 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true);
743 parse_and_format(&tokens)
744 }
745
746 fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
747 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false);
748 parse_and_format(&tokens)
749 }
750
751 #[test]
754 fn test_repo_struct_name() {
755 let code = gen(&standard_entity(), DatabaseKind::Postgres);
756 assert!(code.contains("pub struct UsersRepository"));
757 }
758
759 #[test]
760 fn test_repo_new_method() {
761 let code = gen(&standard_entity(), DatabaseKind::Postgres);
762 assert!(code.contains("pub fn new("));
763 }
764
765 #[test]
766 fn test_repo_pool_field_pg() {
767 let code = gen(&standard_entity(), DatabaseKind::Postgres);
768 assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
769 }
770
771 #[test]
772 fn test_repo_pool_field_mysql() {
773 let code = gen(&standard_entity(), DatabaseKind::Mysql);
774 assert!(code.contains("MySqlPool") || code.contains("MySql"));
775 }
776
777 #[test]
778 fn test_repo_pool_field_sqlite() {
779 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
780 assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
781 }
782
783 #[test]
786 fn test_get_all_method() {
787 let code = gen(&standard_entity(), DatabaseKind::Postgres);
788 assert!(code.contains("pub async fn get_all"));
789 }
790
791 #[test]
792 fn test_get_all_returns_vec() {
793 let code = gen(&standard_entity(), DatabaseKind::Postgres);
794 assert!(code.contains("Vec<Users>"));
795 }
796
797 #[test]
798 fn test_get_all_sql() {
799 let code = gen(&standard_entity(), DatabaseKind::Postgres);
800 assert!(code.contains("SELECT * FROM users"));
801 }
802
803 #[test]
806 fn test_paginate_method() {
807 let code = gen(&standard_entity(), DatabaseKind::Postgres);
808 assert!(code.contains("pub async fn paginate"));
809 }
810
811 #[test]
812 fn test_paginate_params_struct() {
813 let code = gen(&standard_entity(), DatabaseKind::Postgres);
814 assert!(code.contains("pub struct PaginateUsersParams"));
815 }
816
817 #[test]
818 fn test_paginate_params_fields() {
819 let code = gen(&standard_entity(), DatabaseKind::Postgres);
820 assert!(code.contains("pub page: i64"));
821 assert!(code.contains("pub per_page: i64"));
822 }
823
824 #[test]
825 fn test_paginate_returns_paginated() {
826 let code = gen(&standard_entity(), DatabaseKind::Postgres);
827 assert!(code.contains("PaginatedUsers"));
828 assert!(code.contains("PaginationUsersMeta"));
829 }
830
831 #[test]
832 fn test_paginate_meta_struct() {
833 let code = gen(&standard_entity(), DatabaseKind::Postgres);
834 assert!(code.contains("pub struct PaginationUsersMeta"));
835 assert!(code.contains("pub total: i64"));
836 assert!(code.contains("pub last_page: i64"));
837 assert!(code.contains("pub first_page: i64"));
838 assert!(code.contains("pub current_page: i64"));
839 }
840
841 #[test]
842 fn test_paginate_data_struct() {
843 let code = gen(&standard_entity(), DatabaseKind::Postgres);
844 assert!(code.contains("pub struct PaginatedUsers"));
845 assert!(code.contains("pub meta: PaginationUsersMeta"));
846 assert!(code.contains("pub data: Vec<Users>"));
847 }
848
849 #[test]
850 fn test_paginate_count_sql() {
851 let code = gen(&standard_entity(), DatabaseKind::Postgres);
852 assert!(code.contains("SELECT COUNT(*) FROM users"));
853 }
854
855 #[test]
856 fn test_paginate_sql_pg() {
857 let code = gen(&standard_entity(), DatabaseKind::Postgres);
858 assert!(code.contains("LIMIT $1 OFFSET $2"));
859 }
860
861 #[test]
862 fn test_paginate_sql_mysql() {
863 let code = gen(&standard_entity(), DatabaseKind::Mysql);
864 assert!(code.contains("LIMIT ? OFFSET ?"));
865 }
866
867 #[test]
870 fn test_get_method() {
871 let code = gen(&standard_entity(), DatabaseKind::Postgres);
872 assert!(code.contains("pub async fn get"));
873 }
874
875 #[test]
876 fn test_get_returns_option() {
877 let code = gen(&standard_entity(), DatabaseKind::Postgres);
878 assert!(code.contains("Option<Users>"));
879 }
880
881 #[test]
882 fn test_get_where_pk_pg() {
883 let code = gen(&standard_entity(), DatabaseKind::Postgres);
884 assert!(code.contains("WHERE id = $1"));
885 }
886
887 #[test]
888 fn test_get_where_pk_mysql() {
889 let code = gen(&standard_entity(), DatabaseKind::Mysql);
890 assert!(code.contains("WHERE id = ?"));
891 }
892
893 #[test]
896 fn test_insert_method() {
897 let code = gen(&standard_entity(), DatabaseKind::Postgres);
898 assert!(code.contains("pub async fn insert"));
899 }
900
901 #[test]
902 fn test_insert_params_struct() {
903 let code = gen(&standard_entity(), DatabaseKind::Postgres);
904 assert!(code.contains("pub struct InsertUsersParams"));
905 }
906
907 #[test]
908 fn test_insert_params_no_pk() {
909 let code = gen(&standard_entity(), DatabaseKind::Postgres);
910 assert!(code.contains("pub name: String"));
911 assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
912 }
913
914 #[test]
915 fn test_insert_returning_pg() {
916 let code = gen(&standard_entity(), DatabaseKind::Postgres);
917 assert!(code.contains("RETURNING *"));
918 }
919
920 #[test]
921 fn test_insert_returning_sqlite() {
922 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
923 assert!(code.contains("RETURNING *"));
924 }
925
926 #[test]
927 fn test_insert_mysql_last_insert_id() {
928 let code = gen(&standard_entity(), DatabaseKind::Mysql);
929 assert!(code.contains("LAST_INSERT_ID"));
930 }
931
932 #[test]
935 fn test_update_method() {
936 let code = gen(&standard_entity(), DatabaseKind::Postgres);
937 assert!(code.contains("pub async fn update"));
938 }
939
940 #[test]
941 fn test_update_params_struct() {
942 let code = gen(&standard_entity(), DatabaseKind::Postgres);
943 assert!(code.contains("pub struct UpdateUsersParams"));
944 }
945
946 #[test]
947 fn test_update_params_all_cols() {
948 let code = gen(&standard_entity(), DatabaseKind::Postgres);
949 assert!(code.contains("pub id: i32"));
950 assert!(code.contains("pub name: String"));
951 }
952
953 #[test]
954 fn test_update_set_clause_pg() {
955 let code = gen(&standard_entity(), DatabaseKind::Postgres);
956 assert!(code.contains("SET name = $1"));
957 assert!(code.contains("WHERE id = $3"));
958 }
959
960 #[test]
961 fn test_update_returning_pg() {
962 let code = gen(&standard_entity(), DatabaseKind::Postgres);
963 assert!(code.contains("UPDATE users SET"));
964 assert!(code.contains("RETURNING *"));
965 }
966
967 #[test]
970 fn test_delete_method() {
971 let code = gen(&standard_entity(), DatabaseKind::Postgres);
972 assert!(code.contains("pub async fn delete"));
973 }
974
975 #[test]
976 fn test_delete_where_pk() {
977 let code = gen(&standard_entity(), DatabaseKind::Postgres);
978 assert!(code.contains("DELETE FROM users WHERE id = $1"));
979 }
980
981 #[test]
982 fn test_delete_returns_unit() {
983 let code = gen(&standard_entity(), DatabaseKind::Postgres);
984 assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
985 }
986
987 #[test]
990 fn test_view_no_insert() {
991 let mut entity = standard_entity();
992 entity.is_view = true;
993 let code = gen(&entity, DatabaseKind::Postgres);
994 assert!(!code.contains("pub async fn insert"));
995 }
996
997 #[test]
998 fn test_view_no_update() {
999 let mut entity = standard_entity();
1000 entity.is_view = true;
1001 let code = gen(&entity, DatabaseKind::Postgres);
1002 assert!(!code.contains("pub async fn update"));
1003 }
1004
1005 #[test]
1006 fn test_view_no_delete() {
1007 let mut entity = standard_entity();
1008 entity.is_view = true;
1009 let code = gen(&entity, DatabaseKind::Postgres);
1010 assert!(!code.contains("pub async fn delete"));
1011 }
1012
1013 #[test]
1014 fn test_view_has_get_all() {
1015 let mut entity = standard_entity();
1016 entity.is_view = true;
1017 let code = gen(&entity, DatabaseKind::Postgres);
1018 assert!(code.contains("pub async fn get_all"));
1019 }
1020
1021 #[test]
1022 fn test_view_has_paginate() {
1023 let mut entity = standard_entity();
1024 entity.is_view = true;
1025 let code = gen(&entity, DatabaseKind::Postgres);
1026 assert!(code.contains("pub async fn paginate"));
1027 }
1028
1029 #[test]
1030 fn test_view_has_get() {
1031 let mut entity = standard_entity();
1032 entity.is_view = true;
1033 let code = gen(&entity, DatabaseKind::Postgres);
1034 assert!(code.contains("pub async fn get"));
1035 }
1036
1037 #[test]
1040 fn test_only_get_all() {
1041 let m = Methods { get_all: true, ..Default::default() };
1042 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1043 assert!(code.contains("pub async fn get_all"));
1044 assert!(!code.contains("pub async fn paginate"));
1045 assert!(!code.contains("pub async fn insert"));
1046 }
1047
1048 #[test]
1049 fn test_without_get_all() {
1050 let m = Methods { get_all: false, ..Methods::all() };
1051 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1052 assert!(!code.contains("pub async fn get_all"));
1053 }
1054
1055 #[test]
1056 fn test_without_paginate() {
1057 let m = Methods { paginate: false, ..Methods::all() };
1058 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1059 assert!(!code.contains("pub async fn paginate"));
1060 assert!(!code.contains("PaginateUsersParams"));
1061 }
1062
1063 #[test]
1064 fn test_without_get() {
1065 let m = Methods { get: false, ..Methods::all() };
1066 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1067 assert!(code.contains("pub async fn get_all"));
1068 let without_get_all = code.replace("get_all", "XXX");
1069 assert!(!without_get_all.contains("fn get("));
1070 }
1071
1072 #[test]
1073 fn test_without_insert() {
1074 let m = Methods { insert: false, ..Methods::all() };
1075 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1076 assert!(!code.contains("pub async fn insert"));
1077 assert!(!code.contains("InsertUsersParams"));
1078 }
1079
1080 #[test]
1081 fn test_without_update() {
1082 let m = Methods { update: false, ..Methods::all() };
1083 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1084 assert!(!code.contains("pub async fn update"));
1085 assert!(!code.contains("UpdateUsersParams"));
1086 }
1087
1088 #[test]
1089 fn test_without_delete() {
1090 let m = Methods { delete: false, ..Methods::all() };
1091 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1092 assert!(!code.contains("pub async fn delete"));
1093 }
1094
1095 #[test]
1096 fn test_empty_methods_no_methods() {
1097 let m = Methods::default();
1098 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1099 assert!(!code.contains("pub async fn get_all"));
1100 assert!(!code.contains("pub async fn paginate"));
1101 assert!(!code.contains("pub async fn insert"));
1102 assert!(!code.contains("pub async fn update"));
1103 assert!(!code.contains("pub async fn delete"));
1104 }
1105
1106 #[test]
1109 fn test_no_pool_import() {
1110 let skip = Methods::all();
1111 let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false);
1112 assert!(!imports.iter().any(|i| i.contains("PgPool")));
1113 }
1114
1115 #[test]
1116 fn test_imports_contain_entity() {
1117 let skip = Methods::all();
1118 let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false);
1119 assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
1120 }
1121
1122 #[test]
1125 fn test_renamed_column_in_sql() {
1126 let entity = ParsedEntity {
1127 struct_name: "Connector".to_string(),
1128 table_name: "connector".to_string(),
1129 schema_name: None,
1130 is_view: false,
1131 fields: vec![
1132 make_field("id", "id", "i32", false, true),
1133 make_field("connector_type", "type", "String", false, false),
1134 ],
1135 imports: vec![],
1136 };
1137 let code = gen(&entity, DatabaseKind::Postgres);
1138 assert!(code.contains("type"));
1140 assert!(code.contains("pub connector_type: String"));
1142 }
1143
1144 #[test]
1147 fn test_no_pk_no_get() {
1148 let entity = ParsedEntity {
1149 struct_name: "Logs".to_string(),
1150 table_name: "logs".to_string(),
1151 schema_name: None,
1152 is_view: false,
1153 fields: vec![
1154 make_field("message", "message", "String", false, false),
1155 make_field("ts", "ts", "String", false, false),
1156 ],
1157 imports: vec![],
1158 };
1159 let code = gen(&entity, DatabaseKind::Postgres);
1160 assert!(code.contains("pub async fn get_all"));
1161 let without_get_all = code.replace("get_all", "XXX");
1162 assert!(!without_get_all.contains("fn get("));
1163 }
1164
1165 #[test]
1166 fn test_no_pk_no_delete() {
1167 let entity = ParsedEntity {
1168 struct_name: "Logs".to_string(),
1169 table_name: "logs".to_string(),
1170 schema_name: None,
1171 is_view: false,
1172 fields: vec![
1173 make_field("message", "message", "String", false, false),
1174 ],
1175 imports: vec![],
1176 };
1177 let code = gen(&entity, DatabaseKind::Postgres);
1178 assert!(!code.contains("pub async fn delete"));
1179 }
1180
1181 #[test]
1184 fn test_param_structs_have_default() {
1185 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1186 assert!(code.contains("Default"));
1187 }
1188
1189 #[test]
1192 fn test_entity_imports_forwarded() {
1193 let entity = ParsedEntity {
1194 struct_name: "Users".to_string(),
1195 table_name: "users".to_string(),
1196 schema_name: None,
1197 is_view: false,
1198 fields: vec![
1199 make_field("id", "id", "Uuid", false, true),
1200 make_field("created_at", "created_at", "DateTime<Utc>", false, false),
1201 ],
1202 imports: vec![
1203 "use chrono::{DateTime, Utc};".to_string(),
1204 "use uuid::Uuid;".to_string(),
1205 ],
1206 };
1207 let skip = Methods::all();
1208 let (_, imports) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::users", &skip, false);
1209 assert!(imports.iter().any(|i| i.contains("chrono")));
1210 assert!(imports.iter().any(|i| i.contains("uuid")));
1211 }
1212
1213 #[test]
1214 fn test_entity_imports_empty_when_no_imports() {
1215 let skip = Methods::all();
1216 let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false);
1217 assert!(!imports.iter().any(|i| i.contains("chrono")));
1219 assert!(!imports.iter().any(|i| i.contains("uuid")));
1220 }
1221
1222 #[test]
1225 fn test_macro_get_all() {
1226 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1227 assert!(code.contains("query_as!"));
1228 assert!(!code.contains("query_as::<"));
1229 }
1230
1231 #[test]
1232 fn test_macro_paginate() {
1233 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1234 assert!(code.contains("query_as!"));
1235 assert!(code.contains("per_page, offset"));
1236 }
1237
1238 #[test]
1239 fn test_macro_get() {
1240 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1241 assert!(code.contains("query_as!(Users"));
1243 }
1244
1245 #[test]
1246 fn test_macro_insert_pg() {
1247 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1248 assert!(code.contains("query_as!(Users"));
1249 assert!(code.contains("params.name"));
1250 assert!(code.contains("params.email"));
1251 }
1252
1253 #[test]
1254 fn test_macro_insert_mysql() {
1255 let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
1256 assert!(code.contains("query!"));
1258 assert!(code.contains("query_scalar!"));
1259 }
1260
1261 #[test]
1262 fn test_macro_update() {
1263 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1264 assert!(code.contains("query_as!(Users"));
1265 assert!(code.contains("params.name"));
1267 assert!(code.contains("params.id"));
1268 }
1269
1270 #[test]
1271 fn test_macro_delete() {
1272 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1273 assert!(code.contains("query!"));
1275 }
1276
1277 #[test]
1278 fn test_macro_no_bind_calls() {
1279 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1280 assert!(!code.contains(".bind("));
1281 }
1282
1283 #[test]
1284 fn test_function_style_uses_bind() {
1285 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1286 assert!(code.contains(".bind("));
1287 assert!(!code.contains("query_as!("));
1288 assert!(!code.contains("query!("));
1289 }
1290
1291 fn entity_with_sql_array() -> ParsedEntity {
1294 ParsedEntity {
1295 struct_name: "AgentConnector".to_string(),
1296 table_name: "agent.agent_connector".to_string(),
1297 schema_name: Some("agent".to_string()),
1298 is_view: false,
1299 fields: vec![
1300 ParsedField {
1301 rust_name: "connector_id".to_string(),
1302 column_name: "connector_id".to_string(),
1303 rust_type: "Uuid".to_string(),
1304 inner_type: "Uuid".to_string(),
1305 is_nullable: false,
1306 is_primary_key: true,
1307 sql_type: None,
1308 is_sql_array: false,
1309 },
1310 ParsedField {
1311 rust_name: "agent_id".to_string(),
1312 column_name: "agent_id".to_string(),
1313 rust_type: "Uuid".to_string(),
1314 inner_type: "Uuid".to_string(),
1315 is_nullable: false,
1316 is_primary_key: false,
1317 sql_type: None,
1318 is_sql_array: false,
1319 },
1320 ParsedField {
1321 rust_name: "usages".to_string(),
1322 column_name: "usages".to_string(),
1323 rust_type: "Vec<ConnectorUsages>".to_string(),
1324 inner_type: "Vec<ConnectorUsages>".to_string(),
1325 is_nullable: false,
1326 is_primary_key: false,
1327 sql_type: Some("agent.connector_usages".to_string()),
1328 is_sql_array: true,
1329 },
1330 ],
1331 imports: vec!["use uuid::Uuid;".to_string()],
1332 }
1333 }
1334
1335 fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
1336 let skip = Methods::all();
1337 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true);
1338 parse_and_format(&tokens)
1339 }
1340
1341 #[test]
1342 fn test_sql_array_macro_get_all_uses_runtime() {
1343 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1344 assert!(code.contains("query_as::<"));
1346 }
1347
1348 #[test]
1349 fn test_sql_array_macro_get_uses_runtime() {
1350 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1351 assert!(code.contains(".bind("));
1353 }
1354
1355 #[test]
1356 fn test_sql_array_macro_insert_uses_runtime() {
1357 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1358 assert!(code.contains("query_as::<_ , AgentConnector>") || code.contains("query_as::<_, AgentConnector>"));
1360 }
1361
1362 #[test]
1363 fn test_sql_array_macro_update_uses_runtime() {
1364 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1365 assert!(code.contains("query_as::<"));
1367 }
1368
1369 #[test]
1370 fn test_sql_array_macro_delete_still_uses_macro() {
1371 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1372 assert!(code.contains("query!"));
1374 }
1375
1376 #[test]
1377 fn test_sql_array_no_query_as_macro() {
1378 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1379 assert!(!code.contains("query_as!("));
1381 }
1382}