1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10 entity: &ParsedEntity,
11 db_kind: DatabaseKind,
12 entity_module_path: &str,
13 methods: &Methods,
14 query_macro: bool,
15 pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17 let mut imports = BTreeSet::new();
18
19 let entity_ident = format_ident!("{}", entity.struct_name);
20 let repo_name = format!("{}Repository", entity.struct_name);
21 let repo_ident = format_ident!("{}", repo_name);
22
23 let table_name = match &entity.schema_name {
24 Some(schema) => format!("{}.{}", schema, entity.table_name),
25 None => entity.table_name.clone(),
26 };
27
28 let pool_type = pool_type_tokens(db_kind);
30
31 let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35 let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37 imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40 let entity_parent = entity_module_path
44 .rsplit_once("::")
45 .map(|(parent, _)| parent)
46 .unwrap_or(entity_module_path);
47 for imp in &entity.imports {
48 if let Some(rest) = imp.strip_prefix("use super::") {
49 imports.insert(format!("use {}::{}", entity_parent, rest));
50 } else {
51 imports.insert(imp.clone());
52 }
53 }
54
55 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58 let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61 let is_view = entity.is_view;
62
63 let mut method_tokens = Vec::new();
65 let mut param_structs = Vec::new();
66
67 if methods.get_all {
69 let sql = format!("SELECT * FROM {}", table_name);
70 let method = if use_macro {
71 quote! {
72 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73 sqlx::query_as!(#entity_ident, #sql)
74 .fetch_all(&self.pool)
75 .await
76 }
77 }
78 } else {
79 quote! {
80 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81 sqlx::query_as::<_, #entity_ident>(#sql)
82 .fetch_all(&self.pool)
83 .await
84 }
85 }
86 };
87 method_tokens.push(method);
88 }
89
90 if methods.paginate {
92 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95 let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
96 let sql = match db_kind {
97 DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
98 DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
99 };
100 let method = if use_macro {
101 quote! {
102 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103 let total: i64 = sqlx::query_scalar!(#count_sql)
104 .fetch_one(&self.pool)
105 .await?
106 .unwrap_or(0);
107 let per_page = params.per_page;
108 let current_page = params.page;
109 let last_page = (total + per_page - 1) / per_page;
110 let offset = (current_page - 1) * per_page;
111 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112 .fetch_all(&self.pool)
113 .await?;
114 Ok(#paginated_ident {
115 meta: #pagination_meta_ident {
116 total,
117 per_page,
118 current_page,
119 last_page,
120 first_page: 1,
121 },
122 data,
123 })
124 }
125 }
126 } else {
127 quote! {
128 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129 let total: i64 = sqlx::query_scalar(#count_sql)
130 .fetch_one(&self.pool)
131 .await?;
132 let per_page = params.per_page;
133 let current_page = params.page;
134 let last_page = (total + per_page - 1) / per_page;
135 let offset = (current_page - 1) * per_page;
136 let data = sqlx::query_as::<_, #entity_ident>(#sql)
137 .bind(per_page)
138 .bind(offset)
139 .fetch_all(&self.pool)
140 .await?;
141 Ok(#paginated_ident {
142 meta: #pagination_meta_ident {
143 total,
144 per_page,
145 current_page,
146 last_page,
147 first_page: 1,
148 },
149 data,
150 })
151 }
152 }
153 };
154 method_tokens.push(method);
155 param_structs.push(quote! {
156 #[derive(Debug, Clone, Default)]
157 pub struct #paginate_params_ident {
158 pub page: i64,
159 pub per_page: i64,
160 }
161 });
162 param_structs.push(quote! {
163 #[derive(Debug, Clone)]
164 pub struct #pagination_meta_ident {
165 pub total: i64,
166 pub per_page: i64,
167 pub current_page: i64,
168 pub last_page: i64,
169 pub first_page: i64,
170 }
171 });
172 param_structs.push(quote! {
173 #[derive(Debug, Clone)]
174 pub struct #paginated_ident {
175 pub meta: #pagination_meta_ident,
176 pub data: Vec<#entity_ident>,
177 }
178 });
179 }
180
181 if methods.get && !pk_fields.is_empty() {
183 let pk_params: Vec<TokenStream> = pk_fields
184 .iter()
185 .map(|f| {
186 let name = format_ident!("{}", f.rust_name);
187 let ty: TokenStream = f.inner_type.parse().unwrap();
188 quote! { #name: #ty }
189 })
190 .collect();
191
192 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194 let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
195 let sql_macro = format!("SELECT * FROM {} WHERE {}", table_name, where_clause_cast);
196
197 let binds: Vec<TokenStream> = pk_fields
198 .iter()
199 .map(|f| {
200 let name = format_ident!("{}", f.rust_name);
201 quote! { .bind(#name) }
202 })
203 .collect();
204
205 let method = if use_macro {
206 let pk_arg_names: Vec<TokenStream> = pk_fields
207 .iter()
208 .map(|f| {
209 let name = format_ident!("{}", f.rust_name);
210 quote! { #name }
211 })
212 .collect();
213 quote! {
214 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215 sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216 .fetch_optional(&self.pool)
217 .await
218 }
219 }
220 } else {
221 quote! {
222 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223 sqlx::query_as::<_, #entity_ident>(#sql)
224 #(#binds)*
225 .fetch_optional(&self.pool)
226 .await
227 }
228 }
229 };
230 method_tokens.push(method);
231 }
232
233 if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
235 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
239 pk_fields.clone()
240 } else {
241 non_pk_fields.clone()
242 };
243
244 let insert_fields: Vec<TokenStream> = insert_source_fields
245 .iter()
246 .map(|f| {
247 let name = format_ident!("{}", f.rust_name);
248 let ty: TokenStream = f.rust_type.parse().unwrap();
249 quote! { pub #name: #ty, }
250 })
251 .collect();
252
253 let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
254 let col_list = col_names.join(", ");
255 let placeholders = build_placeholders(insert_source_fields.len(), db_kind, 1);
257 let placeholders_cast = build_placeholders_with_cast(&insert_source_fields, db_kind, 1, true);
258
259 let build_insert_sql = |ph: &str| match db_kind {
260 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
261 format!(
262 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
263 table_name, col_list, ph
264 )
265 }
266 DatabaseKind::Mysql => {
267 format!(
268 "INSERT INTO {} ({}) VALUES ({})",
269 table_name, col_list, ph
270 )
271 }
272 };
273 let sql = build_insert_sql(&placeholders);
274 let sql_macro = build_insert_sql(&placeholders_cast);
275
276 let binds: Vec<TokenStream> = insert_source_fields
277 .iter()
278 .map(|f| {
279 let name = format_ident!("{}", f.rust_name);
280 quote! { .bind(¶ms.#name) }
281 })
282 .collect();
283
284 let insert_method = build_insert_method_parsed(
285 &entity_ident,
286 &insert_params_ident,
287 &sql,
288 &sql_macro,
289 &binds,
290 db_kind,
291 &table_name,
292 &pk_fields,
293 &insert_source_fields,
294 use_macro,
295 );
296 method_tokens.push(insert_method);
297
298 param_structs.push(quote! {
299 #[derive(Debug, Clone, Default)]
300 pub struct #insert_params_ident {
301 #(#insert_fields)*
302 }
303 });
304 }
305
306 if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
308 let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
309
310 let overwrite_fields: Vec<TokenStream> = entity
311 .fields
312 .iter()
313 .map(|f| {
314 let name = format_ident!("{}", f.rust_name);
315 let ty: TokenStream = f.rust_type.parse().unwrap();
316 quote! { pub #name: #ty, }
317 })
318 .collect();
319
320 let set_cols: Vec<String> = non_pk_fields
321 .iter()
322 .enumerate()
323 .map(|(i, f)| {
324 let p = placeholder(db_kind, i + 1);
325 format!("{} = {}", f.column_name, p)
326 })
327 .collect();
328 let set_clause = set_cols.join(", ");
329
330 let set_cols_cast: Vec<String> = non_pk_fields
332 .iter()
333 .enumerate()
334 .map(|(i, f)| {
335 let p = placeholder_with_cast(db_kind, i + 1, f);
336 format!("{} = {}", f.column_name, p)
337 })
338 .collect();
339 let set_clause_cast = set_cols_cast.join(", ");
340
341 let pk_start = non_pk_fields.len() + 1;
342 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
343 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
344
345 let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
346 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
347 format!(
348 "UPDATE {} SET {} WHERE {} RETURNING *",
349 table_name, sc, wc
350 )
351 }
352 DatabaseKind::Mysql => {
353 format!(
354 "UPDATE {} SET {} WHERE {}",
355 table_name, sc, wc
356 )
357 }
358 };
359 let sql = build_overwrite_sql(&set_clause, &where_clause);
360 let sql_macro = build_overwrite_sql(&set_clause_cast, &where_clause_cast);
361
362 let mut all_binds: Vec<TokenStream> = non_pk_fields
364 .iter()
365 .map(|f| {
366 let name = format_ident!("{}", f.rust_name);
367 quote! { .bind(¶ms.#name) }
368 })
369 .collect();
370 for f in &pk_fields {
371 let name = format_ident!("{}", f.rust_name);
372 all_binds.push(quote! { .bind(¶ms.#name) });
373 }
374
375 let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
377 .iter()
378 .chain(pk_fields.iter())
379 .map(|f| macro_arg_for_field(f))
380 .collect();
381
382 let overwrite_method = if use_macro {
383 match db_kind {
384 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
385 quote! {
386 pub async fn overwrite(&self, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
387 sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
388 .fetch_one(&self.pool)
389 .await
390 }
391 }
392 }
393 DatabaseKind::Mysql => {
394 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
395 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
396 let pk_macro_args: Vec<TokenStream> = pk_fields
397 .iter()
398 .map(|f| {
399 let name = format_ident!("{}", f.rust_name);
400 quote! { params.#name }
401 })
402 .collect();
403 quote! {
404 pub async fn overwrite(&self, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
405 sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
406 .execute(&self.pool)
407 .await?;
408 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
409 .fetch_one(&self.pool)
410 .await
411 }
412 }
413 }
414 }
415 } else {
416 match db_kind {
417 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
418 quote! {
419 pub async fn overwrite(&self, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
420 sqlx::query_as::<_, #entity_ident>(#sql)
421 #(#all_binds)*
422 .fetch_one(&self.pool)
423 .await
424 }
425 }
426 }
427 DatabaseKind::Mysql => {
428 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
429 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
430 let pk_binds: Vec<TokenStream> = pk_fields
431 .iter()
432 .map(|f| {
433 let name = format_ident!("{}", f.rust_name);
434 quote! { .bind(¶ms.#name) }
435 })
436 .collect();
437 quote! {
438 pub async fn overwrite(&self, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
439 sqlx::query(#sql)
440 #(#all_binds)*
441 .execute(&self.pool)
442 .await?;
443 sqlx::query_as::<_, #entity_ident>(#select_sql)
444 #(#pk_binds)*
445 .fetch_one(&self.pool)
446 .await
447 }
448 }
449 }
450 }
451 };
452 method_tokens.push(overwrite_method);
453
454 param_structs.push(quote! {
455 #[derive(Debug, Clone, Default)]
456 pub struct #overwrite_params_ident {
457 #(#overwrite_fields)*
458 }
459 });
460 }
461
462 if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
464 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
465
466 let update_fields: Vec<TokenStream> = entity
468 .fields
469 .iter()
470 .map(|f| {
471 let name = format_ident!("{}", f.rust_name);
472 if f.is_primary_key {
473 let ty: TokenStream = f.rust_type.parse().unwrap();
474 quote! { pub #name: #ty, }
475 } else if f.is_nullable {
476 let ty: TokenStream = f.rust_type.parse().unwrap();
478 quote! { pub #name: #ty, }
479 } else {
480 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
481 quote! { pub #name: #ty, }
482 }
483 })
484 .collect();
485
486 let set_cols: Vec<String> = non_pk_fields
488 .iter()
489 .enumerate()
490 .map(|(i, f)| {
491 let p = placeholder(db_kind, i + 1);
492 format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
493 })
494 .collect();
495 let set_clause = set_cols.join(", ");
496
497 let set_cols_cast: Vec<String> = non_pk_fields
499 .iter()
500 .enumerate()
501 .map(|(i, f)| {
502 let p = placeholder_with_cast(db_kind, i + 1, f);
503 format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
504 })
505 .collect();
506 let set_clause_cast = set_cols_cast.join(", ");
507
508 let pk_start = non_pk_fields.len() + 1;
509 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
510 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
511
512 let build_update_sql = |sc: &str, wc: &str| match db_kind {
513 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
514 format!(
515 "UPDATE {} SET {} WHERE {} RETURNING *",
516 table_name, sc, wc
517 )
518 }
519 DatabaseKind::Mysql => {
520 format!(
521 "UPDATE {} SET {} WHERE {}",
522 table_name, sc, wc
523 )
524 }
525 };
526 let sql = build_update_sql(&set_clause, &where_clause);
527 let sql_macro = build_update_sql(&set_clause_cast, &where_clause_cast);
528
529 let mut all_binds: Vec<TokenStream> = non_pk_fields
531 .iter()
532 .map(|f| {
533 let name = format_ident!("{}", f.rust_name);
534 quote! { .bind(¶ms.#name) }
535 })
536 .collect();
537 for f in &pk_fields {
538 let name = format_ident!("{}", f.rust_name);
539 all_binds.push(quote! { .bind(¶ms.#name) });
540 }
541
542 let update_macro_args: Vec<TokenStream> = non_pk_fields
545 .iter()
546 .chain(pk_fields.iter())
547 .map(|f| macro_arg_for_field(f))
548 .collect();
549
550 let update_method = if use_macro {
551 match db_kind {
552 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
553 quote! {
554 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
555 sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
556 .fetch_one(&self.pool)
557 .await
558 }
559 }
560 }
561 DatabaseKind::Mysql => {
562 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
563 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
564 let pk_macro_args: Vec<TokenStream> = pk_fields
565 .iter()
566 .map(|f| {
567 let name = format_ident!("{}", f.rust_name);
568 quote! { params.#name }
569 })
570 .collect();
571 quote! {
572 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
573 sqlx::query!(#sql_macro, #(#update_macro_args),*)
574 .execute(&self.pool)
575 .await?;
576 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
577 .fetch_one(&self.pool)
578 .await
579 }
580 }
581 }
582 }
583 } else {
584 match db_kind {
585 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
586 quote! {
587 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
588 sqlx::query_as::<_, #entity_ident>(#sql)
589 #(#all_binds)*
590 .fetch_one(&self.pool)
591 .await
592 }
593 }
594 }
595 DatabaseKind::Mysql => {
596 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
597 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
598 let pk_binds: Vec<TokenStream> = pk_fields
599 .iter()
600 .map(|f| {
601 let name = format_ident!("{}", f.rust_name);
602 quote! { .bind(¶ms.#name) }
603 })
604 .collect();
605 quote! {
606 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
607 sqlx::query(#sql)
608 #(#all_binds)*
609 .execute(&self.pool)
610 .await?;
611 sqlx::query_as::<_, #entity_ident>(#select_sql)
612 #(#pk_binds)*
613 .fetch_one(&self.pool)
614 .await
615 }
616 }
617 }
618 }
619 };
620 method_tokens.push(update_method);
621
622 param_structs.push(quote! {
623 #[derive(Debug, Clone, Default)]
624 pub struct #update_params_ident {
625 #(#update_fields)*
626 }
627 });
628 }
629
630 if !is_view && methods.delete && !pk_fields.is_empty() {
632 let pk_params: Vec<TokenStream> = pk_fields
633 .iter()
634 .map(|f| {
635 let name = format_ident!("{}", f.rust_name);
636 let ty: TokenStream = f.inner_type.parse().unwrap();
637 quote! { #name: #ty }
638 })
639 .collect();
640
641 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
642 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
643 let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
644 let sql_macro = format!("DELETE FROM {} WHERE {}", table_name, where_clause_cast);
645
646 let binds: Vec<TokenStream> = pk_fields
647 .iter()
648 .map(|f| {
649 let name = format_ident!("{}", f.rust_name);
650 quote! { .bind(#name) }
651 })
652 .collect();
653
654 let method = if query_macro {
655 let pk_arg_names: Vec<TokenStream> = pk_fields
656 .iter()
657 .map(|f| {
658 let name = format_ident!("{}", f.rust_name);
659 quote! { #name }
660 })
661 .collect();
662 quote! {
663 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
664 sqlx::query!(#sql_macro, #(#pk_arg_names),*)
665 .execute(&self.pool)
666 .await?;
667 Ok(())
668 }
669 }
670 } else {
671 quote! {
672 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
673 sqlx::query(#sql)
674 #(#binds)*
675 .execute(&self.pool)
676 .await?;
677 Ok(())
678 }
679 }
680 };
681 method_tokens.push(method);
682 }
683
684 let pool_vis: TokenStream = match pool_visibility {
685 PoolVisibility::Private => quote! {},
686 PoolVisibility::Pub => quote! { pub },
687 PoolVisibility::PubCrate => quote! { pub(crate) },
688 };
689
690 let tokens = quote! {
691 #(#param_structs)*
692
693 pub struct #repo_ident {
694 #pool_vis pool: #pool_type,
695 }
696
697 impl #repo_ident {
698 pub fn new(pool: #pool_type) -> Self {
699 Self { pool }
700 }
701
702 #(#method_tokens)*
703 }
704 };
705
706 (tokens, imports)
707}
708
709fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
710 match db_kind {
711 DatabaseKind::Postgres => quote! { sqlx::PgPool },
712 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
713 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
714 }
715}
716
717fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
718 match db_kind {
719 DatabaseKind::Postgres => format!("${}", index),
720 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
721 }
722}
723
724fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
725 let base = placeholder(db_kind, index);
726 match (&field.sql_type, field.is_sql_array) {
727 (Some(t), true) => format!("{} as {}[]", base, t),
728 (Some(t), false) => format!("{} as {}", base, t),
729 (None, _) => base,
730 }
731}
732
733fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
734 (0..count)
735 .map(|i| placeholder(db_kind, start + i))
736 .collect::<Vec<_>>()
737 .join(", ")
738}
739
740fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
741 fields
742 .iter()
743 .enumerate()
744 .map(|(i, f)| {
745 if use_cast {
746 placeholder_with_cast(db_kind, start + i, f)
747 } else {
748 placeholder(db_kind, start + i)
749 }
750 })
751 .collect::<Vec<_>>()
752 .join(", ")
753}
754
755fn build_where_clause_parsed(
756 pk_fields: &[&ParsedField],
757 db_kind: DatabaseKind,
758 start_index: usize,
759) -> String {
760 pk_fields
761 .iter()
762 .enumerate()
763 .map(|(i, f)| {
764 let p = placeholder(db_kind, start_index + i);
765 format!("{} = {}", f.column_name, p)
766 })
767 .collect::<Vec<_>>()
768 .join(" AND ")
769}
770
771fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
772 let name = format_ident!("{}", field.rust_name);
773 let check_type = if field.is_nullable {
774 &field.inner_type
775 } else {
776 &field.rust_type
777 };
778 let normalized = check_type.replace(' ', "");
779 if normalized.starts_with("Vec<") {
780 quote! { params.#name.as_slice() }
781 } else {
782 quote! { params.#name }
783 }
784}
785
786fn build_where_clause_cast(
787 pk_fields: &[&ParsedField],
788 db_kind: DatabaseKind,
789 start_index: usize,
790) -> String {
791 pk_fields
792 .iter()
793 .enumerate()
794 .map(|(i, f)| {
795 let p = placeholder_with_cast(db_kind, start_index + i, f);
796 format!("{} = {}", f.column_name, p)
797 })
798 .collect::<Vec<_>>()
799 .join(" AND ")
800}
801
802#[allow(clippy::too_many_arguments)]
803fn build_insert_method_parsed(
804 entity_ident: &proc_macro2::Ident,
805 insert_params_ident: &proc_macro2::Ident,
806 sql: &str,
807 sql_macro: &str,
808 binds: &[TokenStream],
809 db_kind: DatabaseKind,
810 table_name: &str,
811 pk_fields: &[&ParsedField],
812 non_pk_fields: &[&ParsedField],
813 use_macro: bool,
814) -> TokenStream {
815 if use_macro {
816 let macro_args: Vec<TokenStream> = non_pk_fields
817 .iter()
818 .map(|f| macro_arg_for_field(f))
819 .collect();
820
821 match db_kind {
822 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
823 quote! {
824 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
825 sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
826 .fetch_one(&self.pool)
827 .await
828 }
829 }
830 }
831 DatabaseKind::Mysql => {
832 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
833 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
834 quote! {
835 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
836 sqlx::query!(#sql_macro, #(#macro_args),*)
837 .execute(&self.pool)
838 .await?;
839 let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
840 .fetch_one(&self.pool)
841 .await?;
842 sqlx::query_as!(#entity_ident, #select_sql, id)
843 .fetch_one(&self.pool)
844 .await
845 }
846 }
847 }
848 }
849 } else {
850 match db_kind {
851 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
852 quote! {
853 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
854 sqlx::query_as::<_, #entity_ident>(#sql)
855 #(#binds)*
856 .fetch_one(&self.pool)
857 .await
858 }
859 }
860 }
861 DatabaseKind::Mysql => {
862 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
863 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
864 quote! {
865 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
866 sqlx::query(#sql)
867 #(#binds)*
868 .execute(&self.pool)
869 .await?;
870 let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
871 .fetch_one(&self.pool)
872 .await?;
873 sqlx::query_as::<_, #entity_ident>(#select_sql)
874 .bind(id)
875 .fetch_one(&self.pool)
876 .await
877 }
878 }
879 }
880 }
881 }
882}
883
884#[cfg(test)]
885mod tests {
886 use super::*;
887 use crate::codegen::parse_and_format;
888 use crate::cli::Methods;
889
890 fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
891 let inner_type = if nullable {
892 rust_type
894 .strip_prefix("Option<")
895 .and_then(|s| s.strip_suffix('>'))
896 .unwrap_or(rust_type)
897 .to_string()
898 } else {
899 rust_type.to_string()
900 };
901 ParsedField {
902 rust_name: rust_name.to_string(),
903 column_name: column_name.to_string(),
904 rust_type: rust_type.to_string(),
905 is_nullable: nullable,
906 inner_type,
907 is_primary_key: is_pk,
908 sql_type: None,
909 is_sql_array: false,
910 }
911 }
912
913 fn standard_entity() -> ParsedEntity {
914 ParsedEntity {
915 struct_name: "Users".to_string(),
916 table_name: "users".to_string(),
917 schema_name: None,
918 is_view: false,
919 fields: vec![
920 make_field("id", "id", "i32", false, true),
921 make_field("name", "name", "String", false, false),
922 make_field("email", "email", "Option<String>", true, false),
923 ],
924 imports: vec![],
925 }
926 }
927
928 fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
929 let skip = Methods::all();
930 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
931 parse_and_format(&tokens)
932 }
933
934 fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
935 let skip = Methods::all();
936 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
937 parse_and_format(&tokens)
938 }
939
940 fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
941 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
942 parse_and_format(&tokens)
943 }
944
945 #[test]
948 fn test_repo_struct_name() {
949 let code = gen(&standard_entity(), DatabaseKind::Postgres);
950 assert!(code.contains("pub struct UsersRepository"));
951 }
952
953 #[test]
954 fn test_repo_new_method() {
955 let code = gen(&standard_entity(), DatabaseKind::Postgres);
956 assert!(code.contains("pub fn new("));
957 }
958
959 #[test]
960 fn test_repo_pool_field_pg() {
961 let code = gen(&standard_entity(), DatabaseKind::Postgres);
962 assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
963 }
964
965 #[test]
966 fn test_repo_pool_field_pub() {
967 let skip = Methods::all();
968 let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Pub);
969 let code = parse_and_format(&tokens);
970 assert!(code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool"));
971 }
972
973 #[test]
974 fn test_repo_pool_field_pub_crate() {
975 let skip = Methods::all();
976 let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::PubCrate);
977 let code = parse_and_format(&tokens);
978 assert!(code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool"));
979 }
980
981 #[test]
982 fn test_repo_pool_field_private() {
983 let code = gen(&standard_entity(), DatabaseKind::Postgres);
984 assert!(!code.contains("pub pool"));
986 assert!(!code.contains("pub(crate) pool"));
987 }
988
989 #[test]
990 fn test_repo_pool_field_mysql() {
991 let code = gen(&standard_entity(), DatabaseKind::Mysql);
992 assert!(code.contains("MySqlPool") || code.contains("MySql"));
993 }
994
995 #[test]
996 fn test_repo_pool_field_sqlite() {
997 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
998 assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
999 }
1000
1001 #[test]
1004 fn test_get_all_method() {
1005 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1006 assert!(code.contains("pub async fn get_all"));
1007 }
1008
1009 #[test]
1010 fn test_get_all_returns_vec() {
1011 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1012 assert!(code.contains("Vec<Users>"));
1013 }
1014
1015 #[test]
1016 fn test_get_all_sql() {
1017 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1018 assert!(code.contains("SELECT * FROM users"));
1019 }
1020
1021 #[test]
1024 fn test_paginate_method() {
1025 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1026 assert!(code.contains("pub async fn paginate"));
1027 }
1028
1029 #[test]
1030 fn test_paginate_params_struct() {
1031 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1032 assert!(code.contains("pub struct PaginateUsersParams"));
1033 }
1034
1035 #[test]
1036 fn test_paginate_params_fields() {
1037 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1038 assert!(code.contains("pub page: i64"));
1039 assert!(code.contains("pub per_page: i64"));
1040 }
1041
1042 #[test]
1043 fn test_paginate_returns_paginated() {
1044 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1045 assert!(code.contains("PaginatedUsers"));
1046 assert!(code.contains("PaginationUsersMeta"));
1047 }
1048
1049 #[test]
1050 fn test_paginate_meta_struct() {
1051 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1052 assert!(code.contains("pub struct PaginationUsersMeta"));
1053 assert!(code.contains("pub total: i64"));
1054 assert!(code.contains("pub last_page: i64"));
1055 assert!(code.contains("pub first_page: i64"));
1056 assert!(code.contains("pub current_page: i64"));
1057 }
1058
1059 #[test]
1060 fn test_paginate_data_struct() {
1061 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1062 assert!(code.contains("pub struct PaginatedUsers"));
1063 assert!(code.contains("pub meta: PaginationUsersMeta"));
1064 assert!(code.contains("pub data: Vec<Users>"));
1065 }
1066
1067 #[test]
1068 fn test_paginate_count_sql() {
1069 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1070 assert!(code.contains("SELECT COUNT(*) FROM users"));
1071 }
1072
1073 #[test]
1074 fn test_paginate_sql_pg() {
1075 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1076 assert!(code.contains("LIMIT $1 OFFSET $2"));
1077 }
1078
1079 #[test]
1080 fn test_paginate_sql_mysql() {
1081 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1082 assert!(code.contains("LIMIT ? OFFSET ?"));
1083 }
1084
1085 #[test]
1088 fn test_get_method() {
1089 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1090 assert!(code.contains("pub async fn get"));
1091 }
1092
1093 #[test]
1094 fn test_get_returns_option() {
1095 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1096 assert!(code.contains("Option<Users>"));
1097 }
1098
1099 #[test]
1100 fn test_get_where_pk_pg() {
1101 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1102 assert!(code.contains("WHERE id = $1"));
1103 }
1104
1105 #[test]
1106 fn test_get_where_pk_mysql() {
1107 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1108 assert!(code.contains("WHERE id = ?"));
1109 }
1110
// Insert: method and params struct, RETURNING on Postgres/SQLite, and the
// LAST_INSERT_ID path on MySQL.
#[test]
fn test_insert_method() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub async fn insert"));
}

#[test]
fn test_insert_params_struct() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub struct InsertUsersParams"));
}

#[test]
fn test_insert_params_no_pk() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub name: String"));
    // Accept both compact and token-spaced renderings of the generic type.
    let email_is_optional = ["pub email: Option<String>", "pub email: Option < String >"]
        .iter()
        .any(|variant| out.contains(variant));
    assert!(email_is_optional);
}

#[test]
fn test_insert_returning_pg() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("RETURNING *"));
}

#[test]
fn test_insert_returning_sqlite() {
    let out = gen(&standard_entity(), DatabaseKind::Sqlite);
    assert!(out.contains("RETURNING *"));
}

#[test]
fn test_insert_mysql_last_insert_id() {
    let out = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(out.contains("LAST_INSERT_ID"));
}
1149
// Overwrite: full-row UPDATE that sets every non-PK column and filters on the PK.
#[test]
fn test_overwrite_method() {
    assert!(gen(&standard_entity(), DatabaseKind::Postgres).contains("pub async fn overwrite"));
}

#[test]
fn test_overwrite_params_struct() {
    assert!(gen(&standard_entity(), DatabaseKind::Postgres).contains("pub struct OverwriteUsersParams"));
}

#[test]
fn test_overwrite_params_all_cols() {
    let source = gen(&standard_entity(), DatabaseKind::Postgres);
    for expected in ["OverwriteUsersParams", "pub id: i32", "pub name: String"] {
        assert!(source.contains(expected));
    }
}

#[test]
fn test_overwrite_set_clause_pg() {
    let source = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(source.contains("SET name = $1, email = $2 WHERE id = $3"));
}

#[test]
fn test_overwrite_returning_pg() {
    let source = gen(&standard_entity(), DatabaseKind::Postgres);
    for expected in ["UPDATE users SET", "RETURNING *"] {
        assert!(source.contains(expected));
    }
}
1185
// Update: partial UPDATE where each non-PK column is optional and merged in
// SQL via COALESCE, so unset fields keep their stored value.
#[test]
fn test_update_method() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("pub async fn update"));
}

#[test]
fn test_update_params_struct() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("pub struct UpdateUsersParams"));
}

#[test]
fn test_update_params_pk_keeps_original_type() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let pk_unwrapped = ["pub id: i32", "pub id : i32"]
        .iter()
        .any(|form| generated.contains(form));
    assert!(pk_unwrapped);
}

#[test]
fn test_update_params_non_nullable_wrapped_in_option() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let wrapped = ["pub name: Option<String>", "pub name : Option < String >"]
        .iter()
        .any(|form| generated.contains(form));
    assert!(wrapped);
}

#[test]
fn test_update_params_already_nullable_no_double_option() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let double_wrapped =
        generated.contains("Option<Option") || generated.contains("Option < Option");
    assert!(!double_wrapped);
}

#[test]
fn test_update_set_clause_uses_coalesce_pg() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("COALESCE($1, name)"), "Expected COALESCE for name:\n{}", code);
    assert!(code.contains("COALESCE($2, email)"), "Expected COALESCE for email:\n{}", code);
}

#[test]
fn test_update_where_clause_pg() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("WHERE id = $3"));
}

#[test]
fn test_update_returning_pg() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    for fragment in ["COALESCE", "RETURNING *"] {
        assert!(generated.contains(fragment));
    }
}

#[test]
fn test_update_set_clause_mysql() {
    let code = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(code.contains("COALESCE(?, name)"), "Expected COALESCE for MySQL:\n{}", code);
    assert!(code.contains("COALESCE(?, email)"), "Expected COALESCE for email in MySQL:\n{}", code);
}

#[test]
fn test_update_set_clause_sqlite() {
    let code = gen(&standard_entity(), DatabaseKind::Sqlite);
    assert!(code.contains("COALESCE(?, name)"), "Expected COALESCE for SQLite:\n{}", code);
}

#[test]
fn test_update_and_overwrite_coexist() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    let checks = [
        ("pub async fn update", "Expected update method:\n{}"),
        ("pub async fn overwrite", "Expected overwrite method:\n{}"),
        ("UpdateUsersParams", "Expected UpdateUsersParams struct:\n{}"),
        ("OverwriteUsersParams", "Expected OverwriteUsersParams struct:\n{}"),
    ];
    for (needle, _message) in checks {
        // The message template is only used on failure; spelled out per case
        // below to keep the exact panic text.
        if !code.contains(needle) {
            match needle {
                "pub async fn update" => panic!("Expected update method:\n{}", code),
                "pub async fn overwrite" => panic!("Expected overwrite method:\n{}", code),
                "UpdateUsersParams" => panic!("Expected UpdateUsersParams struct:\n{}", code),
                _ => panic!("Expected OverwriteUsersParams struct:\n{}", code),
            }
        }
    }
}
1262
// Delete: PK-filtered DELETE that yields no row back.
#[test]
fn test_delete_method() {
    let rendered = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(rendered.contains("pub async fn delete"));
}

#[test]
fn test_delete_where_pk() {
    let rendered = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(rendered.contains("DELETE FROM users WHERE id = $1"));
}

#[test]
fn test_delete_returns_unit() {
    let rendered = gen(&standard_entity(), DatabaseKind::Postgres);
    let unit_result = ["Result<(), sqlx::Error>", "Result<(), sqlx :: Error>"]
        .iter()
        .any(|form| rendered.contains(form));
    assert!(unit_result);
}
1282
// Views are read-only: write methods are suppressed while read methods remain.
#[test]
fn test_view_no_insert() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn insert"));
}

#[test]
fn test_view_no_update() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn update"));
}

#[test]
fn test_view_no_overwrite() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn overwrite"));
}

#[test]
fn test_view_no_delete() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn delete"));
}

#[test]
fn test_view_has_get_all() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn get_all"));
}

#[test]
fn test_view_has_paginate() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&view, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn paginate"));
}
1332
// A view with a primary key still gets the single-row `get` method. Require
// `get(` so the assertion cannot be satisfied by `get_all` alone.
#[test]
fn test_view_has_get() {
    let mut entity = standard_entity();
    entity.is_view = true;
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn get("), "Expected get method on view:\n{}", code);
}
1340
// Method selection: each `Methods` flag independently controls whether its
// method (and its params struct, where applicable) is generated.
#[test]
fn test_only_get_all() {
    let selection = Methods { get_all: true, ..Default::default() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    assert!(out.contains("pub async fn get_all"));
    for absent in ["pub async fn paginate", "pub async fn insert"] {
        assert!(!out.contains(absent));
    }
}

#[test]
fn test_without_get_all() {
    let selection = Methods { get_all: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    assert!(!out.contains("pub async fn get_all"));
}

#[test]
fn test_without_paginate() {
    let selection = Methods { paginate: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    for absent in ["pub async fn paginate", "PaginateUsersParams"] {
        assert!(!out.contains(absent));
    }
}

#[test]
fn test_without_get() {
    let selection = Methods { get: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    assert!(out.contains("pub async fn get_all"));
    // Mask `get_all` so a leftover `get(` can only come from the `get` method.
    let masked = out.replace("get_all", "XXX");
    assert!(!masked.contains("fn get("));
}

#[test]
fn test_without_insert() {
    let selection = Methods { insert: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    for absent in ["pub async fn insert", "InsertUsersParams"] {
        assert!(!out.contains(absent));
    }
}

#[test]
fn test_without_update() {
    let selection = Methods { update: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    for absent in ["pub async fn update", "UpdateUsersParams"] {
        assert!(!out.contains(absent));
    }
}

#[test]
fn test_without_overwrite() {
    let selection = Methods { overwrite: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    for absent in ["pub async fn overwrite", "OverwriteUsersParams"] {
        assert!(!out.contains(absent));
    }
}

#[test]
fn test_without_delete() {
    let selection = Methods { delete: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    assert!(!out.contains("pub async fn delete"));
}
1406
// With every method flag off (`Methods::default()`), no repository method may
// be generated, including the single-row `get` that was previously unchecked.
#[test]
fn test_empty_methods_no_methods() {
    let m = Methods::default();
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    assert!(!code.contains("pub async fn get_all"));
    assert!(!code.contains("pub async fn get("));
    assert!(!code.contains("pub async fn paginate"));
    assert!(!code.contains("pub async fn insert"));
    assert!(!code.contains("pub async fn update"));
    assert!(!code.contains("pub async fn overwrite"));
    assert!(!code.contains("pub async fn delete"));
}
1418
// Import set returned alongside the tokens: the entity type is imported, and
// the pool type is not added as a separate import.
#[test]
fn test_no_pool_import() {
    let every_method = Methods::all();
    let (_, imports) = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &every_method,
        false,
        PoolVisibility::Private,
    );
    assert!(imports.iter().all(|line| !line.contains("PgPool")));
}

#[test]
fn test_imports_contain_entity() {
    let every_method = Methods::all();
    let (_, imports) = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &every_method,
        false,
        PoolVisibility::Private,
    );
    let has_entity = imports
        .iter()
        .any(|line| line.contains("crate::models::users::Users"));
    assert!(has_entity);
}
1434
// A field whose Rust name differs from its column name (`connector_type` ->
// column `type`) must use the column name in SQL and the Rust name in structs.
// A bare `contains("type")` check would be vacuous, since "connector_type"
// itself contains "type", so the column is pinned inside the overwrite SET clause.
#[test]
fn test_renamed_column_in_sql() {
    let entity = ParsedEntity {
        struct_name: "Connector".to_string(),
        table_name: "connector".to_string(),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("id", "id", "i32", false, true),
            make_field("connector_type", "type", "String", false, false),
        ],
        imports: vec![],
    };
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(code.contains("SET type = $1"), "Expected SQL to reference column `type`:\n{}", code);
    assert!(!code.contains("connector_type = $"), "Rust field name leaked into SQL:\n{}", code);
    assert!(code.contains("pub connector_type: String"));
}
1456
// Tables without a primary key can be listed but have no single-row lookup
// and no delete.
#[test]
fn test_no_pk_no_get() {
    let fields = ["message", "ts"]
        .iter()
        .map(|&col| make_field(col, col, "String", false, false))
        .collect();
    let entity = ParsedEntity {
        struct_name: String::from("Logs"),
        table_name: String::from("logs"),
        schema_name: None,
        is_view: false,
        fields,
        imports: Vec::new(),
    };
    let out = gen(&entity, DatabaseKind::Postgres);
    assert!(out.contains("pub async fn get_all"));
    // Mask `get_all` so a leftover `get(` can only come from the `get` method.
    let masked = out.replace("get_all", "XXX");
    assert!(!masked.contains("fn get("));
}

#[test]
fn test_no_pk_no_delete() {
    let entity = ParsedEntity {
        struct_name: String::from("Logs"),
        table_name: String::from("logs"),
        schema_name: None,
        is_view: false,
        fields: vec![make_field("message", "message", "String", false, false)],
        imports: Vec::new(),
    };
    let out = gen(&entity, DatabaseKind::Postgres);
    assert!(!out.contains("pub async fn delete"));
}
1493
// Generated params structs are expected to mention `Default` (e.g. in a derive).
#[test]
fn test_param_structs_have_default() {
    let mentions_default = gen(&standard_entity(), DatabaseKind::Postgres).contains("Default");
    assert!(mentions_default);
}
1501
// `use` lines declared alongside the entity are forwarded to the generated
// repository's import set, and nothing extra is invented when there are none.
#[test]
fn test_entity_imports_forwarded() {
    let entity = ParsedEntity {
        struct_name: String::from("Users"),
        table_name: String::from("users"),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("id", "id", "Uuid", false, true),
            make_field("created_at", "created_at", "DateTime<Utc>", false, false),
        ],
        imports: vec![
            String::from("use chrono::{DateTime, Utc};"),
            String::from("use uuid::Uuid;"),
        ],
    };
    let all_methods = Methods::all();
    let (_, imports) = generate_crud_from_parsed(
        &entity,
        DatabaseKind::Postgres,
        "crate::models::users",
        &all_methods,
        false,
        PoolVisibility::Private,
    );
    let mentions = |needle: &str| imports.iter().any(|line| line.contains(needle));
    assert!(mentions("chrono"));
    assert!(mentions("uuid"));
}

#[test]
fn test_entity_imports_empty_when_no_imports() {
    let all_methods = Methods::all();
    let (_, imports) = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &all_methods,
        false,
        PoolVisibility::Private,
    );
    let mentions = |needle: &str| imports.iter().any(|line| line.contains(needle));
    assert!(!mentions("chrono"));
    assert!(!mentions("uuid"));
}
1534
// Compile-time-checked macro mode (`query_as!` / `query!`) versus the default
// function style (`query_as::<..>` with `.bind(..)` calls).
#[test]
fn test_macro_get_all() {
    let macro_code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(macro_code.contains("query_as!"));
    assert!(!macro_code.contains("query_as::<"));
}

#[test]
fn test_macro_paginate() {
    let macro_code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for fragment in ["query_as!", "per_page, offset"] {
        assert!(macro_code.contains(fragment));
    }
}

#[test]
fn test_macro_get() {
    let macro_code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(macro_code.contains("query_as!(Users"));
}

#[test]
fn test_macro_insert_pg() {
    let macro_code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for fragment in ["query_as!(Users", "params.name", "params.email"] {
        assert!(macro_code.contains(fragment));
    }
}

#[test]
fn test_macro_insert_mysql() {
    let macro_code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
    for fragment in ["query!", "query_scalar!"] {
        assert!(macro_code.contains(fragment));
    }
}

#[test]
fn test_macro_overwrite() {
    let macro_code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for fragment in ["query_as!(Users", "pub async fn overwrite", "OverwriteUsersParams"] {
        assert!(macro_code.contains(fragment));
    }
}

#[test]
fn test_macro_update() {
    let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("query_as!(Users"));
    assert!(code.contains("COALESCE"), "Expected COALESCE in macro update:\n{}", code);
    for fragment in ["pub async fn update", "UpdateUsersParams"] {
        assert!(code.contains(fragment));
    }
}

#[test]
fn test_macro_delete() {
    let macro_code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(macro_code.contains("query!"));
}

#[test]
fn test_macro_no_bind_calls() {
    let macro_code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(!macro_code.contains(".bind("));
}

#[test]
fn test_function_style_uses_bind() {
    let runtime_code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(runtime_code.contains(".bind("));
    for macro_call in ["query_as!(", "query!("] {
        assert!(!runtime_code.contains(macro_call));
    }
}
1612
1613 fn entity_with_sql_array() -> ParsedEntity {
1616 ParsedEntity {
1617 struct_name: "AgentConnector".to_string(),
1618 table_name: "agent.agent_connector".to_string(),
1619 schema_name: Some("agent".to_string()),
1620 is_view: false,
1621 fields: vec![
1622 ParsedField {
1623 rust_name: "connector_id".to_string(),
1624 column_name: "connector_id".to_string(),
1625 rust_type: "Uuid".to_string(),
1626 inner_type: "Uuid".to_string(),
1627 is_nullable: false,
1628 is_primary_key: true,
1629 sql_type: None,
1630 is_sql_array: false,
1631 },
1632 ParsedField {
1633 rust_name: "agent_id".to_string(),
1634 column_name: "agent_id".to_string(),
1635 rust_type: "Uuid".to_string(),
1636 inner_type: "Uuid".to_string(),
1637 is_nullable: false,
1638 is_primary_key: false,
1639 sql_type: None,
1640 is_sql_array: false,
1641 },
1642 ParsedField {
1643 rust_name: "usages".to_string(),
1644 column_name: "usages".to_string(),
1645 rust_type: "Vec<ConnectorUsages>".to_string(),
1646 inner_type: "Vec<ConnectorUsages>".to_string(),
1647 is_nullable: false,
1648 is_primary_key: false,
1649 sql_type: Some("agent.connector_usages".to_string()),
1650 is_sql_array: true,
1651 },
1652 ],
1653 imports: vec!["use uuid::Uuid;".to_string()],
1654 }
1655 }
1656
1657 fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
1658 let skip = Methods::all();
1659 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
1660 parse_and_format(&tokens)
1661 }
1662
1663 #[test]
1664 fn test_sql_array_macro_get_all_uses_runtime() {
1665 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1666 assert!(code.contains("query_as::<"));
1668 }
1669
1670 #[test]
1671 fn test_sql_array_macro_get_uses_runtime() {
1672 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1673 assert!(code.contains(".bind("));
1675 }
1676
1677 #[test]
1678 fn test_sql_array_macro_insert_uses_runtime() {
1679 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1680 assert!(code.contains("query_as::<_ , AgentConnector>") || code.contains("query_as::<_, AgentConnector>"));
1682 }
1683
1684 #[test]
1685 fn test_sql_array_macro_overwrite_uses_runtime() {
1686 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1687 assert!(code.contains("query_as::<"));
1689 }
1690
1691 #[test]
1692 fn test_sql_array_macro_delete_still_uses_macro() {
1693 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1694 assert!(code.contains("query!"));
1696 }
1697
1698 #[test]
1699 fn test_sql_array_no_query_as_macro() {
1700 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1701 assert!(!code.contains("query_as!("));
1703 }
1704
1705 fn entity_with_sql_enum() -> ParsedEntity {
1708 ParsedEntity {
1709 struct_name: "Task".to_string(),
1710 table_name: "tasks".to_string(),
1711 schema_name: None,
1712 is_view: false,
1713 fields: vec![
1714 ParsedField {
1715 rust_name: "id".to_string(),
1716 column_name: "id".to_string(),
1717 rust_type: "i32".to_string(),
1718 inner_type: "i32".to_string(),
1719 is_nullable: false,
1720 is_primary_key: true,
1721 sql_type: None,
1722 is_sql_array: false,
1723 },
1724 ParsedField {
1725 rust_name: "status".to_string(),
1726 column_name: "status".to_string(),
1727 rust_type: "TaskStatus".to_string(),
1728 inner_type: "TaskStatus".to_string(),
1729 is_nullable: false,
1730 is_primary_key: false,
1731 sql_type: Some("task_status".to_string()),
1732 is_sql_array: false,
1733 },
1734 ],
1735 imports: vec![],
1736 }
1737 }
1738
1739 #[test]
1740 fn test_sql_enum_macro_uses_runtime() {
1741 let skip = Methods::all();
1742 let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1743 let code = parse_and_format(&tokens);
1744 assert!(code.contains("query_as::<"));
1746 assert!(!code.contains("query_as!("));
1747 }
1748
1749 #[test]
1750 fn test_sql_enum_macro_delete_still_uses_macro() {
1751 let skip = Methods::all();
1752 let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1753 let code = parse_and_format(&tokens);
1754 assert!(code.contains("query!"));
1756 }
1757
1758 fn entity_with_vec_string() -> ParsedEntity {
1761 ParsedEntity {
1762 struct_name: "PromptHistory".to_string(),
1763 table_name: "prompt_history".to_string(),
1764 schema_name: None,
1765 is_view: false,
1766 fields: vec![
1767 ParsedField {
1768 rust_name: "id".to_string(),
1769 column_name: "id".to_string(),
1770 rust_type: "Uuid".to_string(),
1771 inner_type: "Uuid".to_string(),
1772 is_nullable: false,
1773 is_primary_key: true,
1774 sql_type: None,
1775 is_sql_array: false,
1776 },
1777 ParsedField {
1778 rust_name: "content".to_string(),
1779 column_name: "content".to_string(),
1780 rust_type: "String".to_string(),
1781 inner_type: "String".to_string(),
1782 is_nullable: false,
1783 is_primary_key: false,
1784 sql_type: None,
1785 is_sql_array: false,
1786 },
1787 ParsedField {
1788 rust_name: "tags".to_string(),
1789 column_name: "tags".to_string(),
1790 rust_type: "Vec<String>".to_string(),
1791 inner_type: "Vec<String>".to_string(),
1792 is_nullable: false,
1793 is_primary_key: false,
1794 sql_type: None,
1795 is_sql_array: false,
1796 },
1797 ],
1798 imports: vec!["use uuid::Uuid;".to_string()],
1799 }
1800 }
1801
1802 #[test]
1803 fn test_vec_string_macro_insert_uses_as_slice() {
1804 let skip = Methods::all();
1805 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1806 let code = parse_and_format(&tokens);
1807 assert!(code.contains("as_slice()"));
1808 }
1809
1810 #[test]
1811 fn test_vec_string_macro_overwrite_uses_as_slice() {
1812 let skip = Methods::all();
1813 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1814 let code = parse_and_format(&tokens);
1815 let count = code.matches("as_slice()").count();
1817 assert!(count >= 3, "expected at least 3 as_slice() calls (insert + overwrite + update), found {}", count);
1818 }
1819
1820 #[test]
1821 fn test_vec_string_non_macro_no_as_slice() {
1822 let skip = Methods::all();
1823 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false, PoolVisibility::Private);
1824 let code = parse_and_format(&tokens);
1825 assert!(!code.contains("as_slice()"));
1827 }
1828
1829 #[test]
1830 fn test_vec_string_parsed_from_source_uses_as_slice() {
1831 use crate::codegen::entity_parser::parse_entity_source;
1832 let source = r#"
1833 use uuid::Uuid;
1834
1835 #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
1836 #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
1837 pub struct PromptHistory {
1838 #[sqlx_gen(primary_key)]
1839 pub id: Uuid,
1840 pub content: String,
1841 pub tags: Vec<String>,
1842 }
1843 "#;
1844 let entity = parse_entity_source(source).unwrap();
1845 let skip = Methods::all();
1846 let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1847 let code = parse_and_format(&tokens);
1848 assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
1849 }
1850
1851 fn junction_entity() -> ParsedEntity {
1854 ParsedEntity {
1855 struct_name: "AnalysisRecord".to_string(),
1856 table_name: "analysis.analysis__record".to_string(),
1857 schema_name: None,
1858 is_view: false,
1859 fields: vec![
1860 make_field("record_id", "record_id", "uuid::Uuid", false, true),
1861 make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
1862 ],
1863 imports: vec![],
1864 }
1865 }
1866
1867 #[test]
1868 fn test_composite_pk_only_insert_generated() {
1869 let code = gen(&junction_entity(), DatabaseKind::Postgres);
1870 assert!(code.contains("pub struct InsertAnalysisRecordParams"), "Expected InsertAnalysisRecordParams struct:\n{}", code);
1871 assert!(code.contains("pub record_id"), "Expected record_id field in insert params:\n{}", code);
1872 assert!(code.contains("pub analysis_id"), "Expected analysis_id field in insert params:\n{}", code);
1873 assert!(code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id) VALUES ($1, $2) RETURNING *"), "Expected valid INSERT SQL:\n{}", code);
1874 assert!(code.contains("pub async fn insert"), "Expected insert method:\n{}", code);
1875 }
1876
1877 #[test]
1878 fn test_composite_pk_only_no_update() {
1879 let code = gen(&junction_entity(), DatabaseKind::Postgres);
1880 assert!(!code.contains("UpdateAnalysisRecordParams"), "Expected no UpdateAnalysisRecordParams struct:\n{}", code);
1881 assert!(!code.contains("pub async fn update"), "Expected no update method:\n{}", code);
1882 }
1883
1884 #[test]
1885 fn test_composite_pk_only_no_overwrite() {
1886 let code = gen(&junction_entity(), DatabaseKind::Postgres);
1887 assert!(!code.contains("OverwriteAnalysisRecordParams"), "Expected no OverwriteAnalysisRecordParams struct:\n{}", code);
1888 assert!(!code.contains("pub async fn overwrite"), "Expected no overwrite method:\n{}", code);
1889 }
1890
1891 #[test]
1892 fn test_composite_pk_only_delete_generated() {
1893 let code = gen(&junction_entity(), DatabaseKind::Postgres);
1894 assert!(code.contains("pub async fn delete"), "Expected delete method:\n{}", code);
1895 assert!(code.contains("DELETE FROM analysis.analysis__record WHERE record_id = $1 AND analysis_id = $2"), "Expected valid DELETE SQL:\n{}", code);
1896 }
1897
// A composite-PK table gets a `get` that filters on every key column.
// Require `get(` so the method check cannot be satisfied by `get_all`.
#[test]
fn test_composite_pk_only_get_generated() {
    let code = gen(&junction_entity(), DatabaseKind::Postgres);
    assert!(code.contains("pub async fn get("), "Expected get method:\n{}", code);
    assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause with both PK columns:\n{}", code);
}
1904}