1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10 entity: &ParsedEntity,
11 db_kind: DatabaseKind,
12 entity_module_path: &str,
13 methods: &Methods,
14 query_macro: bool,
15 pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17 let mut imports = BTreeSet::new();
18
19 let entity_ident = format_ident!("{}", entity.struct_name);
20 let repo_name = format!("{}Repository", entity.struct_name);
21 let repo_ident = format_ident!("{}", repo_name);
22
23 let table_name = match &entity.schema_name {
24 Some(schema) => format!("{}.{}", schema, entity.table_name),
25 None => entity.table_name.clone(),
26 };
27
28 let pool_type = pool_type_tokens(db_kind);
30
31 let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35 let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37 imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40 let entity_parent = entity_module_path
44 .rsplit_once("::")
45 .map(|(parent, _)| parent)
46 .unwrap_or(entity_module_path);
47 for imp in &entity.imports {
48 if let Some(rest) = imp.strip_prefix("use super::") {
49 imports.insert(format!("use {}::{}", entity_parent, rest));
50 } else {
51 imports.insert(imp.clone());
52 }
53 }
54
55 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58 let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61 let is_view = entity.is_view;
62
63 let mut method_tokens = Vec::new();
65 let mut param_structs = Vec::new();
66
67 if methods.get_all {
69 let sql = format!("SELECT * FROM {}", table_name);
70 let method = if use_macro {
71 quote! {
72 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73 sqlx::query_as!(#entity_ident, #sql)
74 .fetch_all(&self.pool)
75 .await
76 }
77 }
78 } else {
79 quote! {
80 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81 sqlx::query_as::<_, #entity_ident>(#sql)
82 .fetch_all(&self.pool)
83 .await
84 }
85 }
86 };
87 method_tokens.push(method);
88 }
89
90 if methods.paginate {
92 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95 let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
96 let sql = match db_kind {
97 DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
98 DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
99 };
100 let method = if use_macro {
101 quote! {
102 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103 let total: i64 = sqlx::query_scalar!(#count_sql)
104 .fetch_one(&self.pool)
105 .await?
106 .unwrap_or(0);
107 let per_page = params.per_page;
108 let current_page = params.page;
109 let last_page = (total + per_page - 1) / per_page;
110 let offset = (current_page - 1) * per_page;
111 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112 .fetch_all(&self.pool)
113 .await?;
114 Ok(#paginated_ident {
115 meta: #pagination_meta_ident {
116 total,
117 per_page,
118 current_page,
119 last_page,
120 first_page: 1,
121 },
122 data,
123 })
124 }
125 }
126 } else {
127 quote! {
128 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129 let total: i64 = sqlx::query_scalar(#count_sql)
130 .fetch_one(&self.pool)
131 .await?;
132 let per_page = params.per_page;
133 let current_page = params.page;
134 let last_page = (total + per_page - 1) / per_page;
135 let offset = (current_page - 1) * per_page;
136 let data = sqlx::query_as::<_, #entity_ident>(#sql)
137 .bind(per_page)
138 .bind(offset)
139 .fetch_all(&self.pool)
140 .await?;
141 Ok(#paginated_ident {
142 meta: #pagination_meta_ident {
143 total,
144 per_page,
145 current_page,
146 last_page,
147 first_page: 1,
148 },
149 data,
150 })
151 }
152 }
153 };
154 method_tokens.push(method);
155 param_structs.push(quote! {
156 #[derive(Debug, Clone, Default)]
157 pub struct #paginate_params_ident {
158 pub page: i64,
159 pub per_page: i64,
160 }
161 });
162 param_structs.push(quote! {
163 #[derive(Debug, Clone)]
164 pub struct #pagination_meta_ident {
165 pub total: i64,
166 pub per_page: i64,
167 pub current_page: i64,
168 pub last_page: i64,
169 pub first_page: i64,
170 }
171 });
172 param_structs.push(quote! {
173 #[derive(Debug, Clone)]
174 pub struct #paginated_ident {
175 pub meta: #pagination_meta_ident,
176 pub data: Vec<#entity_ident>,
177 }
178 });
179 }
180
181 if methods.get && !pk_fields.is_empty() {
183 let pk_params: Vec<TokenStream> = pk_fields
184 .iter()
185 .map(|f| {
186 let name = format_ident!("{}", f.rust_name);
187 let ty: TokenStream = f.inner_type.parse().unwrap();
188 quote! { #name: #ty }
189 })
190 .collect();
191
192 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194 let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
195 let sql_macro = format!("SELECT * FROM {} WHERE {}", table_name, where_clause_cast);
196
197 let binds: Vec<TokenStream> = pk_fields
198 .iter()
199 .map(|f| {
200 let name = format_ident!("{}", f.rust_name);
201 quote! { .bind(#name) }
202 })
203 .collect();
204
205 let method = if use_macro {
206 let pk_arg_names: Vec<TokenStream> = pk_fields
207 .iter()
208 .map(|f| {
209 let name = format_ident!("{}", f.rust_name);
210 quote! { #name }
211 })
212 .collect();
213 quote! {
214 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215 sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216 .fetch_optional(&self.pool)
217 .await
218 }
219 }
220 } else {
221 quote! {
222 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223 sqlx::query_as::<_, #entity_ident>(#sql)
224 #(#binds)*
225 .fetch_optional(&self.pool)
226 .await
227 }
228 }
229 };
230 method_tokens.push(method);
231 }
232
233 if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
235 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
239 pk_fields.clone()
240 } else {
241 non_pk_fields.clone()
242 };
243
244 let insert_fields: Vec<TokenStream> = insert_source_fields
246 .iter()
247 .map(|f| {
248 let name = format_ident!("{}", f.rust_name);
249 if f.column_default.is_some() && !f.is_nullable {
250 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
251 quote! { pub #name: #ty, }
252 } else {
253 let ty: TokenStream = f.rust_type.parse().unwrap();
254 quote! { pub #name: #ty, }
255 }
256 })
257 .collect();
258
259 let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
260 let col_list = col_names.join(", ");
261
262 let placeholders: String = insert_source_fields
264 .iter()
265 .enumerate()
266 .map(|(i, f)| {
267 let p = placeholder(db_kind, i + 1);
268 match &f.column_default {
269 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
270 None => p,
271 }
272 })
273 .collect::<Vec<_>>()
274 .join(", ");
275
276 let placeholders_cast: String = insert_source_fields
277 .iter()
278 .enumerate()
279 .map(|(i, f)| {
280 let p = placeholder_with_cast(db_kind, i + 1, f);
281 match &f.column_default {
282 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
283 None => p,
284 }
285 })
286 .collect::<Vec<_>>()
287 .join(", ");
288
289 let build_insert_sql = |ph: &str| match db_kind {
290 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
291 format!(
292 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
293 table_name, col_list, ph
294 )
295 }
296 DatabaseKind::Mysql => {
297 format!(
298 "INSERT INTO {} ({}) VALUES ({})",
299 table_name, col_list, ph
300 )
301 }
302 };
303 let sql = build_insert_sql(&placeholders);
304 let sql_macro = build_insert_sql(&placeholders_cast);
305
306 let binds: Vec<TokenStream> = insert_source_fields
307 .iter()
308 .map(|f| {
309 let name = format_ident!("{}", f.rust_name);
310 quote! { .bind(¶ms.#name) }
311 })
312 .collect();
313
314 let insert_method = build_insert_method_parsed(
315 &entity_ident,
316 &insert_params_ident,
317 &sql,
318 &sql_macro,
319 &binds,
320 db_kind,
321 &table_name,
322 &pk_fields,
323 &insert_source_fields,
324 use_macro,
325 );
326 method_tokens.push(insert_method);
327
328 param_structs.push(quote! {
329 #[derive(Debug, Clone, Default)]
330 pub struct #insert_params_ident {
331 #(#insert_fields)*
332 }
333 });
334 }
335
336 if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
338 let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
339
340 let pk_fn_params: Vec<TokenStream> = pk_fields
342 .iter()
343 .map(|f| {
344 let name = format_ident!("{}", f.rust_name);
345 let ty: TokenStream = f.inner_type.parse().unwrap();
346 quote! { #name: #ty }
347 })
348 .collect();
349
350 let overwrite_fields: Vec<TokenStream> = non_pk_fields
352 .iter()
353 .map(|f| {
354 let name = format_ident!("{}", f.rust_name);
355 let ty: TokenStream = f.rust_type.parse().unwrap();
356 quote! { pub #name: #ty, }
357 })
358 .collect();
359
360 let set_cols: Vec<String> = non_pk_fields
361 .iter()
362 .enumerate()
363 .map(|(i, f)| {
364 let p = placeholder(db_kind, i + 1);
365 format!("{} = {}", f.column_name, p)
366 })
367 .collect();
368 let set_clause = set_cols.join(", ");
369
370 let set_cols_cast: Vec<String> = non_pk_fields
371 .iter()
372 .enumerate()
373 .map(|(i, f)| {
374 let p = placeholder_with_cast(db_kind, i + 1, f);
375 format!("{} = {}", f.column_name, p)
376 })
377 .collect();
378 let set_clause_cast = set_cols_cast.join(", ");
379
380 let pk_start = non_pk_fields.len() + 1;
381 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
382 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
383
384 let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
385 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
386 format!("UPDATE {} SET {} WHERE {} RETURNING *", table_name, sc, wc)
387 }
388 DatabaseKind::Mysql => {
389 format!("UPDATE {} SET {} WHERE {}", table_name, sc, wc)
390 }
391 };
392 let sql = build_overwrite_sql(&set_clause, &where_clause);
393 let sql_macro = build_overwrite_sql(&set_clause_cast, &where_clause_cast);
394
395 let mut all_binds: Vec<TokenStream> = non_pk_fields
397 .iter()
398 .map(|f| {
399 let name = format_ident!("{}", f.rust_name);
400 quote! { .bind(¶ms.#name) }
401 })
402 .collect();
403 for f in &pk_fields {
404 let name = format_ident!("{}", f.rust_name);
405 all_binds.push(quote! { .bind(#name) });
406 }
407
408 let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
410 .iter()
411 .map(|f| macro_arg_for_field(f))
412 .chain(pk_fields.iter().map(|f| {
413 let name = format_ident!("{}", f.rust_name);
414 quote! { #name }
415 }))
416 .collect();
417
418 let overwrite_method = if use_macro {
419 match db_kind {
420 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
421 quote! {
422 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
423 sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
424 .fetch_one(&self.pool)
425 .await
426 }
427 }
428 }
429 DatabaseKind::Mysql => {
430 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
431 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
432 let pk_macro_args: Vec<TokenStream> = pk_fields
433 .iter()
434 .map(|f| {
435 let name = format_ident!("{}", f.rust_name);
436 quote! { #name }
437 })
438 .collect();
439 quote! {
440 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
441 sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
442 .execute(&self.pool)
443 .await?;
444 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
445 .fetch_one(&self.pool)
446 .await
447 }
448 }
449 }
450 }
451 } else {
452 match db_kind {
453 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
454 quote! {
455 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
456 sqlx::query_as::<_, #entity_ident>(#sql)
457 #(#all_binds)*
458 .fetch_one(&self.pool)
459 .await
460 }
461 }
462 }
463 DatabaseKind::Mysql => {
464 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
465 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
466 let pk_binds: Vec<TokenStream> = pk_fields
467 .iter()
468 .map(|f| {
469 let name = format_ident!("{}", f.rust_name);
470 quote! { .bind(#name) }
471 })
472 .collect();
473 quote! {
474 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
475 sqlx::query(#sql)
476 #(#all_binds)*
477 .execute(&self.pool)
478 .await?;
479 sqlx::query_as::<_, #entity_ident>(#select_sql)
480 #(#pk_binds)*
481 .fetch_one(&self.pool)
482 .await
483 }
484 }
485 }
486 }
487 };
488 method_tokens.push(overwrite_method);
489
490 param_structs.push(quote! {
491 #[derive(Debug, Clone, Default)]
492 pub struct #overwrite_params_ident {
493 #(#overwrite_fields)*
494 }
495 });
496 }
497
498 if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
500 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
501
502 let pk_fn_params: Vec<TokenStream> = pk_fields
504 .iter()
505 .map(|f| {
506 let name = format_ident!("{}", f.rust_name);
507 let ty: TokenStream = f.inner_type.parse().unwrap();
508 quote! { #name: #ty }
509 })
510 .collect();
511
512 let update_fields: Vec<TokenStream> = non_pk_fields
514 .iter()
515 .map(|f| {
516 let name = format_ident!("{}", f.rust_name);
517 if f.is_nullable {
518 let ty: TokenStream = f.rust_type.parse().unwrap();
520 quote! { pub #name: #ty, }
521 } else {
522 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
523 quote! { pub #name: #ty, }
524 }
525 })
526 .collect();
527
528 let set_cols: Vec<String> = non_pk_fields
530 .iter()
531 .enumerate()
532 .map(|(i, f)| {
533 let p = placeholder(db_kind, i + 1);
534 format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
535 })
536 .collect();
537 let set_clause = set_cols.join(", ");
538
539 let set_cols_cast: Vec<String> = non_pk_fields
541 .iter()
542 .enumerate()
543 .map(|(i, f)| {
544 let p = placeholder_with_cast(db_kind, i + 1, f);
545 format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
546 })
547 .collect();
548 let set_clause_cast = set_cols_cast.join(", ");
549
550 let pk_start = non_pk_fields.len() + 1;
551 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
552 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
553
554 let build_update_sql = |sc: &str, wc: &str| match db_kind {
555 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
556 format!(
557 "UPDATE {} SET {} WHERE {} RETURNING *",
558 table_name, sc, wc
559 )
560 }
561 DatabaseKind::Mysql => {
562 format!(
563 "UPDATE {} SET {} WHERE {}",
564 table_name, sc, wc
565 )
566 }
567 };
568 let sql = build_update_sql(&set_clause, &where_clause);
569 let sql_macro = build_update_sql(&set_clause_cast, &where_clause_cast);
570
571 let mut all_binds: Vec<TokenStream> = non_pk_fields
573 .iter()
574 .map(|f| {
575 let name = format_ident!("{}", f.rust_name);
576 quote! { .bind(¶ms.#name) }
577 })
578 .collect();
579 for f in &pk_fields {
580 let name = format_ident!("{}", f.rust_name);
581 all_binds.push(quote! { .bind(#name) });
582 }
583
584 let update_macro_args: Vec<TokenStream> = non_pk_fields
586 .iter()
587 .map(|f| macro_arg_for_field(f))
588 .chain(pk_fields.iter().map(|f| {
589 let name = format_ident!("{}", f.rust_name);
590 quote! { #name }
591 }))
592 .collect();
593
594 let update_method = if use_macro {
595 match db_kind {
596 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
597 quote! {
598 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
599 sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
600 .fetch_one(&self.pool)
601 .await
602 }
603 }
604 }
605 DatabaseKind::Mysql => {
606 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
607 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
608 let pk_macro_args: Vec<TokenStream> = pk_fields
609 .iter()
610 .map(|f| {
611 let name = format_ident!("{}", f.rust_name);
612 quote! { #name }
613 })
614 .collect();
615 quote! {
616 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
617 sqlx::query!(#sql_macro, #(#update_macro_args),*)
618 .execute(&self.pool)
619 .await?;
620 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
621 .fetch_one(&self.pool)
622 .await
623 }
624 }
625 }
626 }
627 } else {
628 match db_kind {
629 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
630 quote! {
631 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
632 sqlx::query_as::<_, #entity_ident>(#sql)
633 #(#all_binds)*
634 .fetch_one(&self.pool)
635 .await
636 }
637 }
638 }
639 DatabaseKind::Mysql => {
640 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
641 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
642 let pk_binds: Vec<TokenStream> = pk_fields
643 .iter()
644 .map(|f| {
645 let name = format_ident!("{}", f.rust_name);
646 quote! { .bind(#name) }
647 })
648 .collect();
649 quote! {
650 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
651 sqlx::query(#sql)
652 #(#all_binds)*
653 .execute(&self.pool)
654 .await?;
655 sqlx::query_as::<_, #entity_ident>(#select_sql)
656 #(#pk_binds)*
657 .fetch_one(&self.pool)
658 .await
659 }
660 }
661 }
662 }
663 };
664 method_tokens.push(update_method);
665
666 param_structs.push(quote! {
667 #[derive(Debug, Clone, Default)]
668 pub struct #update_params_ident {
669 #(#update_fields)*
670 }
671 });
672 }
673
674 if !is_view && methods.delete && !pk_fields.is_empty() {
676 let pk_params: Vec<TokenStream> = pk_fields
677 .iter()
678 .map(|f| {
679 let name = format_ident!("{}", f.rust_name);
680 let ty: TokenStream = f.inner_type.parse().unwrap();
681 quote! { #name: #ty }
682 })
683 .collect();
684
685 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
686 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
687 let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
688 let sql_macro = format!("DELETE FROM {} WHERE {}", table_name, where_clause_cast);
689
690 let binds: Vec<TokenStream> = pk_fields
691 .iter()
692 .map(|f| {
693 let name = format_ident!("{}", f.rust_name);
694 quote! { .bind(#name) }
695 })
696 .collect();
697
698 let method = if query_macro {
699 let pk_arg_names: Vec<TokenStream> = pk_fields
700 .iter()
701 .map(|f| {
702 let name = format_ident!("{}", f.rust_name);
703 quote! { #name }
704 })
705 .collect();
706 quote! {
707 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
708 sqlx::query!(#sql_macro, #(#pk_arg_names),*)
709 .execute(&self.pool)
710 .await?;
711 Ok(())
712 }
713 }
714 } else {
715 quote! {
716 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
717 sqlx::query(#sql)
718 #(#binds)*
719 .execute(&self.pool)
720 .await?;
721 Ok(())
722 }
723 }
724 };
725 method_tokens.push(method);
726 }
727
728 let pool_vis: TokenStream = match pool_visibility {
729 PoolVisibility::Private => quote! {},
730 PoolVisibility::Pub => quote! { pub },
731 PoolVisibility::PubCrate => quote! { pub(crate) },
732 };
733
734 let tokens = quote! {
735 #(#param_structs)*
736
737 pub struct #repo_ident {
738 #pool_vis pool: #pool_type,
739 }
740
741 impl #repo_ident {
742 pub fn new(pool: #pool_type) -> Self {
743 Self { pool }
744 }
745
746 #(#method_tokens)*
747 }
748 };
749
750 (tokens, imports)
751}
752
753fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
754 match db_kind {
755 DatabaseKind::Postgres => quote! { sqlx::PgPool },
756 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
757 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
758 }
759}
760
761fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
762 match db_kind {
763 DatabaseKind::Postgres => format!("${}", index),
764 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
765 }
766}
767
768fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
769 let base = placeholder(db_kind, index);
770 match (&field.sql_type, field.is_sql_array) {
771 (Some(t), true) => format!("{} as {}[]", base, t),
772 (Some(t), false) => format!("{} as {}", base, t),
773 (None, _) => base,
774 }
775}
776
777fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
778 (0..count)
779 .map(|i| placeholder(db_kind, start + i))
780 .collect::<Vec<_>>()
781 .join(", ")
782}
783
784fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
785 fields
786 .iter()
787 .enumerate()
788 .map(|(i, f)| {
789 if use_cast {
790 placeholder_with_cast(db_kind, start + i, f)
791 } else {
792 placeholder(db_kind, start + i)
793 }
794 })
795 .collect::<Vec<_>>()
796 .join(", ")
797}
798
799fn build_where_clause_parsed(
800 pk_fields: &[&ParsedField],
801 db_kind: DatabaseKind,
802 start_index: usize,
803) -> String {
804 pk_fields
805 .iter()
806 .enumerate()
807 .map(|(i, f)| {
808 let p = placeholder(db_kind, start_index + i);
809 format!("{} = {}", f.column_name, p)
810 })
811 .collect::<Vec<_>>()
812 .join(" AND ")
813}
814
815fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
816 let name = format_ident!("{}", field.rust_name);
817 let check_type = if field.is_nullable {
818 &field.inner_type
819 } else {
820 &field.rust_type
821 };
822 let normalized = check_type.replace(' ', "");
823 if normalized.starts_with("Vec<") {
824 quote! { params.#name.as_slice() }
825 } else {
826 quote! { params.#name }
827 }
828}
829
830fn build_where_clause_cast(
831 pk_fields: &[&ParsedField],
832 db_kind: DatabaseKind,
833 start_index: usize,
834) -> String {
835 pk_fields
836 .iter()
837 .enumerate()
838 .map(|(i, f)| {
839 let p = placeholder_with_cast(db_kind, start_index + i, f);
840 format!("{} = {}", f.column_name, p)
841 })
842 .collect::<Vec<_>>()
843 .join(" AND ")
844}
845
846#[allow(clippy::too_many_arguments)]
847fn build_insert_method_parsed(
848 entity_ident: &proc_macro2::Ident,
849 insert_params_ident: &proc_macro2::Ident,
850 sql: &str,
851 sql_macro: &str,
852 binds: &[TokenStream],
853 db_kind: DatabaseKind,
854 table_name: &str,
855 pk_fields: &[&ParsedField],
856 non_pk_fields: &[&ParsedField],
857 use_macro: bool,
858) -> TokenStream {
859 if use_macro {
860 let macro_args: Vec<TokenStream> = non_pk_fields
861 .iter()
862 .map(|f| macro_arg_for_field(f))
863 .collect();
864
865 match db_kind {
866 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
867 quote! {
868 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
869 sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
870 .fetch_one(&self.pool)
871 .await
872 }
873 }
874 }
875 DatabaseKind::Mysql => {
876 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
877 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
878 quote! {
879 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
880 sqlx::query!(#sql_macro, #(#macro_args),*)
881 .execute(&self.pool)
882 .await?;
883 let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
884 .fetch_one(&self.pool)
885 .await?;
886 sqlx::query_as!(#entity_ident, #select_sql, id)
887 .fetch_one(&self.pool)
888 .await
889 }
890 }
891 }
892 }
893 } else {
894 match db_kind {
895 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
896 quote! {
897 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
898 sqlx::query_as::<_, #entity_ident>(#sql)
899 #(#binds)*
900 .fetch_one(&self.pool)
901 .await
902 }
903 }
904 }
905 DatabaseKind::Mysql => {
906 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
907 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
908 quote! {
909 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
910 sqlx::query(#sql)
911 #(#binds)*
912 .execute(&self.pool)
913 .await?;
914 let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
915 .fetch_one(&self.pool)
916 .await?;
917 sqlx::query_as::<_, #entity_ident>(#select_sql)
918 .bind(id)
919 .fetch_one(&self.pool)
920 .await
921 }
922 }
923 }
924 }
925 }
926}
927
928#[cfg(test)]
929mod tests {
930 use super::*;
931 use crate::codegen::parse_and_format;
932 use crate::cli::Methods;
933
934 fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
935 let inner_type = if nullable {
936 rust_type
938 .strip_prefix("Option<")
939 .and_then(|s| s.strip_suffix('>'))
940 .unwrap_or(rust_type)
941 .to_string()
942 } else {
943 rust_type.to_string()
944 };
945 ParsedField {
946 rust_name: rust_name.to_string(),
947 column_name: column_name.to_string(),
948 rust_type: rust_type.to_string(),
949 is_nullable: nullable,
950 inner_type,
951 is_primary_key: is_pk,
952 sql_type: None,
953 is_sql_array: false,
954 column_default: None,
955 }
956 }
957
958 fn make_field_with_default(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool, default: &str) -> ParsedField {
959 let mut f = make_field(rust_name, column_name, rust_type, nullable, is_pk);
960 f.column_default = Some(default.to_string());
961 f
962 }
963
964 fn entity_with_defaults() -> ParsedEntity {
965 ParsedEntity {
966 struct_name: "Tasks".to_string(),
967 table_name: "tasks".to_string(),
968 schema_name: None,
969 is_view: false,
970 fields: vec![
971 make_field("id", "id", "i32", false, true),
972 make_field("title", "title", "String", false, false),
973 make_field_with_default("status", "status", "String", false, false, "'idle'::task_status"),
974 make_field_with_default("priority", "priority", "i32", false, false, "0"),
975 make_field_with_default("created_at", "created_at", "DateTime<Utc>", false, false, "now()"),
976 make_field("description", "description", "Option<String>", true, false),
977 make_field_with_default("deleted_at", "deleted_at", "Option<DateTime<Utc>>", true, false, "NULL"),
978 ],
979 imports: vec![],
980 }
981 }
982
983 fn standard_entity() -> ParsedEntity {
984 ParsedEntity {
985 struct_name: "Users".to_string(),
986 table_name: "users".to_string(),
987 schema_name: None,
988 is_view: false,
989 fields: vec![
990 make_field("id", "id", "i32", false, true),
991 make_field("name", "name", "String", false, false),
992 make_field("email", "email", "Option<String>", true, false),
993 ],
994 imports: vec![],
995 }
996 }
997
998 fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
999 let skip = Methods::all();
1000 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
1001 parse_and_format(&tokens)
1002 }
1003
1004 fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
1005 let skip = Methods::all();
1006 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
1007 parse_and_format(&tokens)
1008 }
1009
1010 fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
1011 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
1012 parse_and_format(&tokens)
1013 }
1014
1015 #[test]
1018 fn test_repo_struct_name() {
1019 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1020 assert!(code.contains("pub struct UsersRepository"));
1021 }
1022
1023 #[test]
1024 fn test_repo_new_method() {
1025 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1026 assert!(code.contains("pub fn new("));
1027 }
1028
1029 #[test]
1030 fn test_repo_pool_field_pg() {
1031 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1032 assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
1033 }
1034
1035 #[test]
1036 fn test_repo_pool_field_pub() {
1037 let skip = Methods::all();
1038 let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Pub);
1039 let code = parse_and_format(&tokens);
1040 assert!(code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool"));
1041 }
1042
1043 #[test]
1044 fn test_repo_pool_field_pub_crate() {
1045 let skip = Methods::all();
1046 let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::PubCrate);
1047 let code = parse_and_format(&tokens);
1048 assert!(code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool"));
1049 }
1050
1051 #[test]
1052 fn test_repo_pool_field_private() {
1053 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1054 assert!(!code.contains("pub pool"));
1056 assert!(!code.contains("pub(crate) pool"));
1057 }
1058
1059 #[test]
1060 fn test_repo_pool_field_mysql() {
1061 let code = gen(&standard_entity(), DatabaseKind::Mysql);
1062 assert!(code.contains("MySqlPool") || code.contains("MySql"));
1063 }
1064
1065 #[test]
1066 fn test_repo_pool_field_sqlite() {
1067 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
1068 assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
1069 }
1070
1071 #[test]
1074 fn test_get_all_method() {
1075 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1076 assert!(code.contains("pub async fn get_all"));
1077 }
1078
1079 #[test]
1080 fn test_get_all_returns_vec() {
1081 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1082 assert!(code.contains("Vec<Users>"));
1083 }
1084
1085 #[test]
1086 fn test_get_all_sql() {
1087 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1088 assert!(code.contains("SELECT * FROM users"));
1089 }
1090
/// `paginate` is emitted for a standard table.
#[test]
fn test_paginate_method() {
    let src = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(src.contains("pub async fn paginate"));
}

/// A dedicated params struct carries the pagination inputs.
#[test]
fn test_paginate_params_struct() {
    let src = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(src.contains("pub struct PaginateUsersParams"));
}

/// Page number and page size are both `i64`.
#[test]
fn test_paginate_params_fields() {
    let src = gen(&standard_entity(), DatabaseKind::Postgres);
    for field in ["pub page: i64", "pub per_page: i64"] {
        assert!(src.contains(field));
    }
}

/// The paginated result and its metadata types are both referenced.
#[test]
fn test_paginate_returns_paginated() {
    let src = gen(&standard_entity(), DatabaseKind::Postgres);
    for type_name in ["PaginatedUsers", "PaginationUsersMeta"] {
        assert!(src.contains(type_name));
    }
}

/// The metadata struct exposes total/last/first/current page counters.
#[test]
fn test_paginate_meta_struct() {
    let src = gen(&standard_entity(), DatabaseKind::Postgres);
    let expected = [
        "pub struct PaginationUsersMeta",
        "pub total: i64",
        "pub last_page: i64",
        "pub first_page: i64",
        "pub current_page: i64",
    ];
    for needle in expected {
        assert!(src.contains(needle));
    }
}

/// The data wrapper pairs the metadata with the page of rows.
#[test]
fn test_paginate_data_struct() {
    let src = gen(&standard_entity(), DatabaseKind::Postgres);
    let expected = [
        "pub struct PaginatedUsers",
        "pub meta: PaginationUsersMeta",
        "pub data: Vec<Users>",
    ];
    for needle in expected {
        assert!(src.contains(needle));
    }
}

/// A row count is queried to build the metadata.
#[test]
fn test_paginate_count_sql() {
    let src = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(src.contains("SELECT COUNT(*) FROM users"));
}

/// Postgres uses numbered placeholders for LIMIT/OFFSET.
#[test]
fn test_paginate_sql_pg() {
    let src = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(src.contains("LIMIT $1 OFFSET $2"));
}

/// MySQL uses positional `?` placeholders for LIMIT/OFFSET.
#[test]
fn test_paginate_sql_mysql() {
    let src = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(src.contains("LIMIT ? OFFSET ?"));
}
1154
/// The single-row `get` method is emitted for a table with a primary key.
#[test]
fn test_get_method() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // `contains("pub async fn get")` is also satisfied by `get_all`, so it could never fail
    // while `get_all` exists. Anchor on the opening parenthesis to target `get` itself.
    assert!(code.contains("pub async fn get("), "Expected `get` method:\n{}", code);
}
1162
/// `get` returns `None` when no row matches, so its result is an `Option`.
#[test]
fn test_get_returns_option() {
    let pg = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(pg.contains("Option<Users>"));
}

/// Postgres filters on the primary key with a numbered placeholder.
#[test]
fn test_get_where_pk_pg() {
    let pg = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(pg.contains("WHERE id = $1"));
}

/// MySQL filters on the primary key with a `?` placeholder.
#[test]
fn test_get_where_pk_mysql() {
    let mysql = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(mysql.contains("WHERE id = ?"));
}
1180
/// `insert` is emitted for a standard table.
#[test]
fn test_insert_method() {
    assert!(gen(&standard_entity(), DatabaseKind::Postgres).contains("pub async fn insert"));
}

/// `insert` takes its column values through a dedicated params struct.
#[test]
fn test_insert_params_struct() {
    assert!(gen(&standard_entity(), DatabaseKind::Postgres).contains("pub struct InsertUsersParams"));
}
1194
/// InsertUsersParams carries the non-key columns and omits the primary key.
#[test]
fn test_insert_params_no_pk() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Scope every check to InsertUsersParams. The same field declarations may appear in
    // other params structs (e.g. OverwriteUsersParams), and the test's core claim (no PK)
    // can only be verified against this struct's body.
    let struct_start = code.find("pub struct InsertUsersParams").expect("InsertUsersParams not found");
    let struct_end = code[struct_start..]
        .find('}')
        .expect("InsertUsersParams has no closing brace")
        + struct_start;
    let struct_body = &code[struct_start..struct_end];

    assert!(
        struct_body.contains("pub name: String"),
        "Expected `pub name: String` in InsertUsersParams:\n{}",
        struct_body
    );
    assert!(
        struct_body.contains("pub email: Option<String>") || struct_body.contains("pub email: Option < String >"),
        "Expected `pub email: Option<String>` in InsertUsersParams:\n{}",
        struct_body
    );
    assert!(
        !struct_body.contains("pub id"),
        "PK 'id' should not be in InsertUsersParams:\n{}",
        struct_body
    );
}
1201
/// Postgres inserts hand the new row back via `RETURNING *`.
#[test]
fn test_insert_returning_pg() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("RETURNING *"));
}

/// SQLite inserts also rely on `RETURNING *`.
#[test]
fn test_insert_returning_sqlite() {
    let generated = gen(&standard_entity(), DatabaseKind::Sqlite);
    assert!(generated.contains("RETURNING *"));
}

/// MySQL has no `RETURNING`; the generated code goes through `LAST_INSERT_ID`.
#[test]
fn test_insert_mysql_last_insert_id() {
    let generated = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(generated.contains("LAST_INSERT_ID"));
}
1219
/// A column with a database default becomes optional in InsertTasksParams.
#[test]
fn test_insert_default_col_is_optional() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
    let struct_end = code[struct_start..]
        .find('}')
        .expect("InsertTasksParams has no closing brace")
        + struct_start;
    let struct_body = &code[struct_start..struct_end];
    // Check the `status` field's own line. Looking for "Option" and "status" anywhere in
    // the struct would also pass if some other field were the Option.
    let status_line = struct_body
        .lines()
        .find(|line| line.contains("pub status"))
        .unwrap_or_else(|| panic!("status field missing from InsertTasksParams:\n{}", struct_body));
    assert!(
        status_line.contains("Option"),
        "Expected status as Option in InsertTasksParams: {}",
        struct_body
    );
}
1231
/// A column without a database default stays a required (non-Option) field in InsertTasksParams.
#[test]
fn test_insert_non_default_col_required() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
    let struct_end = code[struct_start..]
        .find('}')
        .expect("InsertTasksParams has no closing brace")
        + struct_start;
    let struct_body = &code[struct_start..struct_end];
    // Inspect the `title` field's own line. Matching "title" and "String" anywhere in the
    // struct would also pass if `title` were `Option<String>`, or if another field were a String.
    let title_line = struct_body
        .lines()
        .find(|line| line.contains("pub title"))
        .unwrap_or_else(|| panic!("title field missing from InsertTasksParams:\n{}", struct_body));
    assert!(
        title_line.contains("String") && !title_line.contains("Option"),
        "Expected title as String: {}",
        struct_body
    );
}
1241
/// Defaulted columns fall back to their default via COALESCE when the param is `None`.
#[test]
fn test_insert_default_col_coalesce_sql() {
    let generated = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    assert!(generated.contains("COALESCE($2, 'idle'::task_status)"), "Expected COALESCE for status:\n{}", generated);
    assert!(generated.contains("COALESCE($3, 0)"), "Expected COALESCE for priority:\n{}", generated);
    assert!(generated.contains("COALESCE($4, now())"), "Expected COALESCE for created_at:\n{}", generated);
}

/// The first VALUES slot (the non-defaulted `title`) is bound directly, without COALESCE.
#[test]
fn test_insert_no_coalesce_for_non_default() {
    let generated = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    assert!(generated.contains("VALUES ($1, COALESCE"), "Expected $1 without COALESCE for title:\n{}", generated);
}

/// A nullable column that also has a default must not be wrapped in `Option` twice.
#[test]
fn test_insert_nullable_with_default_no_double_option() {
    let generated = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    let doubled = generated.contains("Option < Option") || generated.contains("Option<Option");
    assert!(!doubled, "Should not have Option<Option>:\n{}", generated);
}
1262
/// InsertTasksParams derives `Default`, so callers can leave defaulted columns unset.
#[test]
fn test_insert_derive_default() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
    let before_struct = &code[..struct_start];
    // Inspect only the derive attached to InsertTasksParams. Searching all of
    // `before_struct` for "Default)]" would also match a derive on any earlier item.
    let derive_start = before_struct
        .rfind("#[derive(")
        .expect("no #[derive(...)] found before InsertTasksParams");
    let attributes = &before_struct[derive_start..];
    // A closing brace here would mean the derive belongs to an earlier item with a body,
    // i.e. InsertTasksParams itself has no derive.
    assert!(
        !attributes.contains('}'),
        "Nearest #[derive(...)] is not attached to InsertTasksParams:\n{}",
        attributes
    );
    assert!(
        attributes.contains("Default"),
        "Expected #[derive(Default)] on InsertTasksParams, found:\n{}",
        attributes
    );
}
1270
/// `update` is emitted for a standard table.
#[test]
fn test_update_method() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("pub async fn update"));
}

/// `update` takes its new values through a dedicated params struct.
#[test]
fn test_update_params_struct() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("pub struct UpdateUsersParams"));
}

/// The primary key is a separate argument of `update`, ahead of the params struct.
#[test]
fn test_update_pk_in_fn_signature() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let fn_at = generated.find("fn update").expect("fn update not found");
    let from_fn = &generated[fn_at..];
    let params_at = from_fn
        .find("UpdateUsersParams")
        .expect("UpdateUsersParams not found in update fn");
    // Everything between the fn name and the params type is the leading part of the signature.
    let signature = &from_fn[..params_at];
    assert!(signature.contains("id"), "Expected 'id' PK in update fn signature: {}", signature);
}

/// The primary key is not duplicated inside UpdateUsersParams.
#[test]
fn test_update_pk_not_in_struct() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let start = generated.find("pub struct UpdateUsersParams").expect("UpdateUsersParams not found");
    let end = start + generated[start..].find('}').unwrap();
    let struct_body = &generated[start..end];
    assert!(!struct_body.contains("pub id"), "PK 'id' should not be in UpdateUsersParams:\n{}", struct_body);
}

/// Non-nullable columns become `Option` in the update params (None = keep current value).
#[test]
fn test_update_params_non_nullable_wrapped_in_option() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let accepted = ["pub name: Option<String>", "pub name : Option < String >"];
    assert!(accepted.iter().any(|&form| generated.contains(form)));
}

/// Columns that are already `Option` are not wrapped a second time.
#[test]
fn test_update_params_already_nullable_no_double_option() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let nested = ["Option<Option", "Option < Option"];
    assert!(nested.iter().all(|&form| !generated.contains(form)));
}

/// Postgres SET clauses keep the existing value when the param is NULL.
#[test]
fn test_update_set_clause_uses_coalesce_pg() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("COALESCE($1, name)"), "Expected COALESCE for name:\n{}", generated);
    assert!(generated.contains("COALESCE($2, email)"), "Expected COALESCE for email:\n{}", generated);
}

/// The primary key placeholder comes after the SET placeholders.
#[test]
fn test_update_where_clause_pg() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("WHERE id = $3"));
}
1332
/// The Postgres `update` SQL itself uses COALESCE and returns the updated row.
#[test]
fn test_update_returning_pg() {
    // Text of the generated method starting at `signature`, ending just before the next
    // `pub async fn` (or at the end of `code` if it is the last method).
    fn method_text<'a>(code: &'a str, signature: &str) -> &'a str {
        let start = code
            .find(signature)
            .unwrap_or_else(|| panic!("{} not found", signature));
        let after = start + signature.len();
        let end = code[after..]
            .find("pub async fn")
            .map_or(code.len(), |offset| after + offset);
        &code[start..end]
    }

    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Checking the whole file would be satisfied by `insert` (which also has RETURNING *),
    // so restrict the checks to the update method.
    let update = method_text(&code, "pub async fn update");
    assert!(update.contains("COALESCE"), "Expected COALESCE in update:\n{}", update);
    assert!(update.contains("RETURNING *"), "Expected RETURNING * in update:\n{}", update);
}
1339
/// MySQL SET clauses use COALESCE with `?` placeholders for every updatable column.
#[test]
fn test_update_set_clause_mysql() {
    let mysql = gen(&standard_entity(), DatabaseKind::Mysql);
    let has_name = mysql.contains("COALESCE(?, name)");
    assert!(has_name, "Expected COALESCE for MySQL:\n{}", mysql);
    let has_email = mysql.contains("COALESCE(?, email)");
    assert!(has_email, "Expected COALESCE for email in MySQL:\n{}", mysql);
}
1346
/// SQLite SET clauses use COALESCE with `?` placeholders for every updatable column.
#[test]
fn test_update_set_clause_sqlite() {
    let code = gen(&standard_entity(), DatabaseKind::Sqlite);
    assert!(code.contains("COALESCE(?, name)"), "Expected COALESCE for SQLite:\n{}", code);
    // Mirror the MySQL test: the second column must be COALESCE'd as well, not just the first.
    assert!(code.contains("COALESCE(?, email)"), "Expected COALESCE for email in SQLite:\n{}", code);
}
1352
/// `overwrite` is emitted for a standard table.
#[test]
fn test_overwrite_method() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("pub async fn overwrite"));
}

/// `overwrite` takes its full set of new values through a dedicated params struct.
#[test]
fn test_overwrite_params_struct() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("pub struct OverwriteUsersParams"));
}

/// The primary key is a separate argument of `overwrite`, ahead of the params struct.
#[test]
fn test_overwrite_pk_in_fn_signature() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let fn_at = generated.find("fn overwrite").expect("fn overwrite not found");
    let from_fn = &generated[fn_at..];
    let params_at = from_fn.find("OverwriteUsersParams").expect("OverwriteUsersParams not found");
    let signature = &from_fn[..params_at];
    assert!(signature.contains("id"), "Expected PK in overwrite fn signature: {}", signature);
}

/// The primary key is not duplicated inside OverwriteUsersParams.
#[test]
fn test_overwrite_pk_not_in_struct() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let start = generated
        .find("pub struct OverwriteUsersParams")
        .expect("OverwriteUsersParams not found");
    let end = start + generated[start..].find('}').unwrap();
    let struct_body = &generated[start..end];
    assert!(!struct_body.contains("pub id"), "PK should not be in OverwriteUsersParams: {}", struct_body);
}
1384
/// `overwrite` replaces every column verbatim, so its SQL must not use COALESCE.
#[test]
fn test_overwrite_no_coalesce() {
    // Text of the generated method starting at `signature`, ending just before the next
    // `pub async fn` (or at the end of `code` if it is the last method).
    fn method_text<'a>(code: &'a str, signature: &str) -> &'a str {
        let start = code
            .find(signature)
            .unwrap_or_else(|| panic!("{} not found", signature));
        let after = start + signature.len();
        let end = code[after..]
            .find("pub async fn")
            .map_or(code.len(), |offset| after + offset);
        &code[start..end]
    }

    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Bound the search to the overwrite method rather than a fixed 500-byte window. Such a
    // window can spill into the following method or cut a long method short, and slicing
    // at a raw byte offset may panic on a non-char boundary.
    let method_body = method_text(&code, "fn overwrite");
    assert!(!method_body.contains("COALESCE"), "Overwrite should not use COALESCE: {}", method_body);
}
1393
/// Postgres `overwrite` assigns each column directly, then filters on the primary key.
#[test]
fn test_overwrite_set_clause_pg() {
    let expected_clause = "SET name = $1, email = $2 WHERE id = $3";
    assert!(gen(&standard_entity(), DatabaseKind::Postgres).contains(expected_clause));
}
1399
/// The Postgres `overwrite` SQL returns the rewritten row.
#[test]
fn test_overwrite_returning_pg() {
    // Text of the generated method starting at `signature`, ending just before the next
    // `pub async fn` (or at the end of `code` if it is the last method).
    fn method_text<'a>(code: &'a str, signature: &str) -> &'a str {
        let start = code
            .find(signature)
            .unwrap_or_else(|| panic!("{} not found", signature));
        let after = start + signature.len();
        let end = code[after..]
            .find("pub async fn")
            .map_or(code.len(), |offset| after + offset);
        &code[start..end]
    }

    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    // Scope to the overwrite method itself. A fixed 500-byte window could pick up
    // RETURNING * from a neighbouring method, or miss it in a long method.
    let method_body = method_text(&code, "fn overwrite");
    assert!(method_body.contains("RETURNING *"), "Expected RETURNING * in overwrite");
}
1407
/// Views are read-only: no `overwrite`.
#[test]
fn test_view_no_overwrite() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    assert!(!gen(&view, DatabaseKind::Postgres).contains("pub async fn overwrite"));
}

/// Disabling `overwrite` drops both the method and its params struct.
#[test]
fn test_without_overwrite() {
    let m = Methods { overwrite: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    for needle in ["pub async fn overwrite", "OverwriteUsersParams"] {
        assert!(!generated.contains(needle));
    }
}

/// `update` (partial) and `overwrite` (full) are generated side by side.
#[test]
fn test_update_and_overwrite_coexist() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let expectations = [
        ("pub async fn update", "Expected update method"),
        ("pub async fn overwrite", "Expected overwrite method"),
        ("UpdateUsersParams", "Expected UpdateUsersParams"),
        ("OverwriteUsersParams", "Expected OverwriteUsersParams"),
    ];
    for (needle, message) in expectations {
        assert!(generated.contains(needle), "{}", message);
    }
}

/// `delete` is emitted for a standard table.
#[test]
fn test_delete_method() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("pub async fn delete"));
}

/// `delete` filters on the primary key.
#[test]
fn test_delete_where_pk() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("DELETE FROM users WHERE id = $1"));
}

/// `delete` returns no row, only success or a `sqlx::Error`.
#[test]
fn test_delete_returns_unit() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    let unit_result = ["Result<(), sqlx::Error>", "Result<(), sqlx :: Error>"];
    assert!(unit_result.iter().any(|&form| generated.contains(form)));
}

/// Views are read-only: no `insert`.
#[test]
fn test_view_no_insert() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let generated = gen(&view, DatabaseKind::Postgres);
    assert!(!generated.contains("pub async fn insert"));
}

/// Views are read-only: no `update`.
#[test]
fn test_view_no_update() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let generated = gen(&view, DatabaseKind::Postgres);
    assert!(!generated.contains("pub async fn update"));
}

/// Views are read-only: no `delete`.
#[test]
fn test_view_no_delete() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let generated = gen(&view, DatabaseKind::Postgres);
    assert!(!generated.contains("pub async fn delete"));
}

/// Views still support `get_all`.
#[test]
fn test_view_has_get_all() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let generated = gen(&view, DatabaseKind::Postgres);
    assert!(generated.contains("pub async fn get_all"));
}

/// Views still support `paginate`.
#[test]
fn test_view_has_paginate() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let generated = gen(&view, DatabaseKind::Postgres);
    assert!(generated.contains("pub async fn paginate"));
}
1494
/// Views still support the single-row `get` lookup.
#[test]
fn test_view_has_get() {
    let mut entity = standard_entity();
    entity.is_view = true;
    let code = gen(&entity, DatabaseKind::Postgres);
    // `contains("pub async fn get")` is already satisfied by `get_all` (asserted separately),
    // so anchor on the opening parenthesis to target `get` itself.
    assert!(code.contains("pub async fn get("), "Expected `get` method on view:\n{}", code);
}
1502
/// Enabling only `get_all` suppresses every other method.
#[test]
fn test_only_get_all() {
    let m = Methods { get_all: true, ..Default::default() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    assert!(generated.contains("pub async fn get_all"));
    for absent in ["pub async fn paginate", "pub async fn insert"] {
        assert!(!generated.contains(absent));
    }
}

/// Disabling `get_all` removes it.
#[test]
fn test_without_get_all() {
    let m = Methods { get_all: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    assert!(!generated.contains("pub async fn get_all"));
}

/// Disabling `paginate` removes the method and its params struct.
#[test]
fn test_without_paginate() {
    let m = Methods { paginate: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    for absent in ["pub async fn paginate", "PaginateUsersParams"] {
        assert!(!generated.contains(absent));
    }
}

/// Disabling `get` removes the single-row lookup but keeps `get_all`.
#[test]
fn test_without_get() {
    let m = Methods { get: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    assert!(generated.contains("pub async fn get_all"));
    // Mask `get_all` so that a leftover single-row `get` is the only possible match.
    let masked = generated.replace("get_all", "XXX");
    assert!(!masked.contains("fn get("));
}

/// Disabling `insert` removes the method and its params struct.
#[test]
fn test_without_insert() {
    let m = Methods { insert: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    for absent in ["pub async fn insert", "InsertUsersParams"] {
        assert!(!generated.contains(absent));
    }
}

/// Disabling `update` removes the method and its params struct.
#[test]
fn test_without_update() {
    let m = Methods { update: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    for absent in ["pub async fn update", "UpdateUsersParams"] {
        assert!(!generated.contains(absent));
    }
}

/// Disabling `delete` removes it.
#[test]
fn test_without_delete() {
    let m = Methods { delete: false, ..Methods::all() };
    let generated = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    assert!(!generated.contains("pub async fn delete"));
}
1560
/// With no methods selected, the repository has no CRUD methods at all.
#[test]
fn test_empty_methods_no_methods() {
    let m = Methods::default();
    let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
    for method in ["get_all", "paginate", "insert", "update", "overwrite", "delete"] {
        let signature = format!("pub async fn {}", method);
        assert!(!code.contains(&signature), "Unexpected `{}` with empty Methods:\n{}", method, code);
    }
    // The single-row `get` was previously unchecked. Its name is a prefix of `get_all`,
    // so anchor on the opening parenthesis, as the other `get` tests do.
    assert!(!code.contains("fn get("), "Unexpected `get` with empty Methods:\n{}", code);
}
1572
/// The generator does not add a pool import to the returned import set.
#[test]
fn test_no_pool_import() {
    let methods = Methods::all();
    let imports = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &methods,
        false,
        PoolVisibility::Private,
    )
    .1;
    assert!(imports.iter().all(|line| !line.contains("PgPool")));
}

/// The entity struct is imported from the module path passed to the generator.
#[test]
fn test_imports_contain_entity() {
    let methods = Methods::all();
    let imports = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &methods,
        false,
        PoolVisibility::Private,
    )
    .1;
    assert!(imports.iter().any(|line| line.contains("crate::models::users::Users")));
}
1588
/// A field whose Rust name differs from its column name (`connector_type` vs `type`)
/// uses the column name in SQL and the Rust name in the params struct.
#[test]
fn test_renamed_column_in_sql() {
    let entity = ParsedEntity {
        struct_name: "Connector".to_string(),
        table_name: "connector".to_string(),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("id", "id", "i32", false, true),
            make_field("connector_type", "type", "String", false, false),
        ],
        imports: vec![],
    };
    let code = gen(&entity, DatabaseKind::Postgres);
    // A bare `contains("type")` is satisfied by the Rust name `connector_type` itself, so it
    // proves nothing. Pin an SQL fragment that can only come from the database column name.
    assert!(
        code.contains("COALESCE($1, type)"),
        "Expected the column name `type` in the update SQL:\n{}",
        code
    );
    assert!(code.contains("pub connector_type: String"));
}
1610
/// Without a primary key there is nothing to look a single row up by, so only `get_all` exists.
#[test]
fn test_no_pk_no_get() {
    let logs = ParsedEntity {
        struct_name: String::from("Logs"),
        table_name: String::from("logs"),
        schema_name: None,
        is_view: false,
        fields: ["message", "ts"]
            .iter()
            .map(|&column| make_field(column, column, "String", false, false))
            .collect(),
        imports: Vec::new(),
    };
    let generated = gen(&logs, DatabaseKind::Postgres);
    assert!(generated.contains("pub async fn get_all"));
    // Mask `get_all` so only a stray single-row `get` could match.
    assert!(!generated.replace("get_all", "XXX").contains("fn get("));
}

/// Without a primary key there is no row identity to delete by.
#[test]
fn test_no_pk_no_delete() {
    let logs = ParsedEntity {
        struct_name: String::from("Logs"),
        table_name: String::from("logs"),
        schema_name: None,
        is_view: false,
        fields: vec![make_field("message", "message", "String", false, false)],
        imports: Vec::new(),
    };
    assert!(!gen(&logs, DatabaseKind::Postgres).contains("pub async fn delete"));
}
1647
/// Generated params structs mention `Default`.
#[test]
fn test_param_structs_have_default() {
    // NOTE(review): a bare "Default" match only shows the token appears somewhere in the output.
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("Default"));
}

/// `use` lines parsed from the entity source are passed through to the repository imports.
#[test]
fn test_entity_imports_forwarded() {
    let entity = ParsedEntity {
        struct_name: String::from("Users"),
        table_name: String::from("users"),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("id", "id", "Uuid", false, true),
            make_field("created_at", "created_at", "DateTime<Utc>", false, false),
        ],
        imports: ["use chrono::{DateTime, Utc};", "use uuid::Uuid;"]
            .iter()
            .map(|&line| String::from(line))
            .collect(),
    };
    let methods = Methods::all();
    let imports = generate_crud_from_parsed(
        &entity,
        DatabaseKind::Postgres,
        "crate::models::users",
        &methods,
        false,
        PoolVisibility::Private,
    )
    .1;
    for krate in ["chrono", "uuid"] {
        assert!(imports.iter().any(|line| line.contains(krate)));
    }
}

/// An entity without its own imports does not pull in unrelated crates.
#[test]
fn test_entity_imports_empty_when_no_imports() {
    let methods = Methods::all();
    let imports = generate_crud_from_parsed(
        &standard_entity(),
        DatabaseKind::Postgres,
        "crate::models::users",
        &methods,
        false,
        PoolVisibility::Private,
    )
    .1;
    assert!(imports
        .iter()
        .all(|line| !line.contains("chrono") && !line.contains("uuid")));
}
1688
/// With macros enabled, `get_all` uses `query_as!` instead of the runtime `query_as::<..>`.
#[test]
fn test_macro_get_all() {
    let generated = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    let uses_macro = generated.contains("query_as!");
    let uses_runtime = generated.contains("query_as::<");
    assert!(uses_macro);
    assert!(!uses_runtime);
}

/// `paginate` passes its limit/offset values straight into the macro.
#[test]
fn test_macro_paginate() {
    let generated = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for needle in ["query_as!", "per_page, offset"] {
        assert!(generated.contains(needle));
    }
}

/// `get` maps rows into the entity through `query_as!`.
#[test]
fn test_macro_get() {
    assert!(gen_macro(&standard_entity(), DatabaseKind::Postgres).contains("query_as!(Users"));
}

/// Postgres `insert` feeds the params fields into `query_as!`.
#[test]
fn test_macro_insert_pg() {
    let generated = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for needle in ["query_as!(Users", "params.name", "params.email"] {
        assert!(generated.contains(needle));
    }
}

/// MySQL `insert` runs the statement with `query!`, then reads back with `query_scalar!`.
#[test]
fn test_macro_insert_mysql() {
    let generated = gen_macro(&standard_entity(), DatabaseKind::Mysql);
    for needle in ["query!", "query_scalar!"] {
        assert!(generated.contains(needle));
    }
}

/// The macro `update` keeps COALESCE semantics and its params struct.
#[test]
fn test_macro_update() {
    let generated = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains("query_as!(Users"));
    assert!(generated.contains("COALESCE"), "Expected COALESCE in macro update:\n{}", generated);
    assert!(generated.contains("pub async fn update"));
    assert!(generated.contains("UpdateUsersParams"));
}

/// `delete` returns no row, so it uses plain `query!`.
#[test]
fn test_macro_delete() {
    assert!(gen_macro(&standard_entity(), DatabaseKind::Postgres).contains("query!"));
}

/// Macro-generated code binds arguments inline, never via `.bind(...)`.
#[test]
fn test_macro_no_bind_calls() {
    assert!(!gen_macro(&standard_entity(), DatabaseKind::Postgres).contains(".bind("));
}

/// Without macros, the runtime API with `.bind(...)` is used and no macros appear.
#[test]
fn test_function_style_uses_bind() {
    let generated = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(generated.contains(".bind("));
    for macro_call in ["query_as!(", "query!("] {
        assert!(!generated.contains(macro_call));
    }
}
1757
1758 fn entity_with_sql_array() -> ParsedEntity {
1761 ParsedEntity {
1762 struct_name: "AgentConnector".to_string(),
1763 table_name: "agent.agent_connector".to_string(),
1764 schema_name: Some("agent".to_string()),
1765 is_view: false,
1766 fields: vec![
1767 ParsedField {
1768 rust_name: "connector_id".to_string(),
1769 column_name: "connector_id".to_string(),
1770 rust_type: "Uuid".to_string(),
1771 inner_type: "Uuid".to_string(),
1772 is_nullable: false,
1773 is_primary_key: true,
1774 sql_type: None,
1775 is_sql_array: false,
1776 column_default: None,
1777 },
1778 ParsedField {
1779 rust_name: "agent_id".to_string(),
1780 column_name: "agent_id".to_string(),
1781 rust_type: "Uuid".to_string(),
1782 inner_type: "Uuid".to_string(),
1783 is_nullable: false,
1784 is_primary_key: false,
1785 sql_type: None,
1786 is_sql_array: false,
1787 column_default: None,
1788 },
1789 ParsedField {
1790 rust_name: "usages".to_string(),
1791 column_name: "usages".to_string(),
1792 rust_type: "Vec<ConnectorUsages>".to_string(),
1793 inner_type: "Vec<ConnectorUsages>".to_string(),
1794 is_nullable: false,
1795 is_primary_key: false,
1796 sql_type: Some("agent.connector_usages".to_string()),
1797 is_sql_array: true,
1798 column_default: None,
1799 },
1800 ],
1801 imports: vec!["use uuid::Uuid;".to_string()],
1802 }
1803 }
1804
1805 fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
1806 let skip = Methods::all();
1807 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
1808 parse_and_format(&tokens)
1809 }
1810
1811 #[test]
1812 fn test_sql_array_macro_get_all_uses_runtime() {
1813 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1814 assert!(code.contains("query_as::<"));
1816 }
1817
1818 #[test]
1819 fn test_sql_array_macro_get_uses_runtime() {
1820 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1821 assert!(code.contains(".bind("));
1823 }
1824
1825 #[test]
1826 fn test_sql_array_macro_insert_uses_runtime() {
1827 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1828 assert!(code.contains("query_as::<_ , AgentConnector>") || code.contains("query_as::<_, AgentConnector>"));
1830 }
1831
1832
1833 #[test]
1834 fn test_sql_array_macro_delete_still_uses_macro() {
1835 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1836 assert!(code.contains("query!"));
1838 }
1839
1840 #[test]
1841 fn test_sql_array_no_query_as_macro() {
1842 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1843 assert!(!code.contains("query_as!("));
1845 }
1846
1847 fn entity_with_sql_enum() -> ParsedEntity {
1850 ParsedEntity {
1851 struct_name: "Task".to_string(),
1852 table_name: "tasks".to_string(),
1853 schema_name: None,
1854 is_view: false,
1855 fields: vec![
1856 ParsedField {
1857 rust_name: "id".to_string(),
1858 column_name: "id".to_string(),
1859 rust_type: "i32".to_string(),
1860 inner_type: "i32".to_string(),
1861 is_nullable: false,
1862 is_primary_key: true,
1863 sql_type: None,
1864 is_sql_array: false,
1865 column_default: None,
1866 },
1867 ParsedField {
1868 rust_name: "status".to_string(),
1869 column_name: "status".to_string(),
1870 rust_type: "TaskStatus".to_string(),
1871 inner_type: "TaskStatus".to_string(),
1872 is_nullable: false,
1873 is_primary_key: false,
1874 sql_type: Some("task_status".to_string()),
1875 is_sql_array: false,
1876 column_default: None,
1877 },
1878 ],
1879 imports: vec![],
1880 }
1881 }
1882
1883 #[test]
1884 fn test_sql_enum_macro_uses_runtime() {
1885 let skip = Methods::all();
1886 let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1887 let code = parse_and_format(&tokens);
1888 assert!(code.contains("query_as::<"));
1890 assert!(!code.contains("query_as!("));
1891 }
1892
1893 #[test]
1894 fn test_sql_enum_macro_delete_still_uses_macro() {
1895 let skip = Methods::all();
1896 let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1897 let code = parse_and_format(&tokens);
1898 assert!(code.contains("query!"));
1900 }
1901
1902 fn entity_with_vec_string() -> ParsedEntity {
1905 ParsedEntity {
1906 struct_name: "PromptHistory".to_string(),
1907 table_name: "prompt_history".to_string(),
1908 schema_name: None,
1909 is_view: false,
1910 fields: vec![
1911 ParsedField {
1912 rust_name: "id".to_string(),
1913 column_name: "id".to_string(),
1914 rust_type: "Uuid".to_string(),
1915 inner_type: "Uuid".to_string(),
1916 is_nullable: false,
1917 is_primary_key: true,
1918 sql_type: None,
1919 is_sql_array: false,
1920 column_default: None,
1921 },
1922 ParsedField {
1923 rust_name: "content".to_string(),
1924 column_name: "content".to_string(),
1925 rust_type: "String".to_string(),
1926 inner_type: "String".to_string(),
1927 is_nullable: false,
1928 is_primary_key: false,
1929 sql_type: None,
1930 is_sql_array: false,
1931 column_default: None,
1932 },
1933 ParsedField {
1934 rust_name: "tags".to_string(),
1935 column_name: "tags".to_string(),
1936 rust_type: "Vec<String>".to_string(),
1937 inner_type: "Vec<String>".to_string(),
1938 is_nullable: false,
1939 is_primary_key: false,
1940 sql_type: None,
1941 is_sql_array: false,
1942 column_default: None,
1943 },
1944 ],
1945 imports: vec!["use uuid::Uuid;".to_string()],
1946 }
1947 }
1948
1949 #[test]
1950 fn test_vec_string_macro_insert_uses_as_slice() {
1951 let skip = Methods::all();
1952 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1953 let code = parse_and_format(&tokens);
1954 assert!(code.contains("as_slice()"));
1955 }
1956
1957 #[test]
1958 fn test_vec_string_macro_update_uses_as_slice() {
1959 let skip = Methods::all();
1960 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1961 let code = parse_and_format(&tokens);
1962 let count = code.matches("as_slice()").count();
1964 assert!(count >= 2, "expected at least 2 as_slice() calls (insert + update), found {}", count);
1965 }
1966
1967 #[test]
1968 fn test_vec_string_non_macro_no_as_slice() {
1969 let skip = Methods::all();
1970 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false, PoolVisibility::Private);
1971 let code = parse_and_format(&tokens);
1972 assert!(!code.contains("as_slice()"));
1974 }
1975
1976 #[test]
1977 fn test_vec_string_parsed_from_source_uses_as_slice() {
1978 use crate::codegen::entity_parser::parse_entity_source;
1979 let source = r#"
1980 use uuid::Uuid;
1981
1982 #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
1983 #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
1984 pub struct PromptHistory {
1985 #[sqlx_gen(primary_key)]
1986 pub id: Uuid,
1987 pub content: String,
1988 pub tags: Vec<String>,
1989 }
1990 "#;
1991 let entity = parse_entity_source(source).unwrap();
1992 let skip = Methods::all();
1993 let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1994 let code = parse_and_format(&tokens);
1995 assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
1996 }
1997
1998 fn junction_entity() -> ParsedEntity {
2001 ParsedEntity {
2002 struct_name: "AnalysisRecord".to_string(),
2003 table_name: "analysis.analysis__record".to_string(),
2004 schema_name: None,
2005 is_view: false,
2006 fields: vec![
2007 make_field("record_id", "record_id", "uuid::Uuid", false, true),
2008 make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
2009 ],
2010 imports: vec![],
2011 }
2012 }
2013
2014 #[test]
2015 fn test_composite_pk_only_insert_generated() {
2016 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2017 assert!(code.contains("pub struct InsertAnalysisRecordParams"), "Expected InsertAnalysisRecordParams struct:\n{}", code);
2018 assert!(code.contains("pub record_id"), "Expected record_id field in insert params:\n{}", code);
2019 assert!(code.contains("pub analysis_id"), "Expected analysis_id field in insert params:\n{}", code);
2020 assert!(code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id) VALUES ($1, $2) RETURNING *"), "Expected valid INSERT SQL:\n{}", code);
2021 assert!(code.contains("pub async fn insert"), "Expected insert method:\n{}", code);
2022 }
2023
2024 #[test]
2025 fn test_composite_pk_only_no_update() {
2026 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2027 assert!(!code.contains("UpdateAnalysisRecordParams"), "Expected no UpdateAnalysisRecordParams struct:\n{}", code);
2028 assert!(!code.contains("pub async fn update"), "Expected no update method:\n{}", code);
2029 }
2030
2031
2032 #[test]
2033 fn test_composite_pk_only_delete_generated() {
2034 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2035 assert!(code.contains("pub async fn delete"), "Expected delete method:\n{}", code);
2036 assert!(code.contains("DELETE FROM analysis.analysis__record WHERE record_id = $1 AND analysis_id = $2"), "Expected valid DELETE SQL:\n{}", code);
2037 }
2038
2039 #[test]
2040 fn test_composite_pk_only_get_generated() {
2041 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2042 assert!(code.contains("pub async fn get"), "Expected get method:\n{}", code);
2043 assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause with both PK columns:\n{}", code);
2044 }
2045}