1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10 entity: &ParsedEntity,
11 db_kind: DatabaseKind,
12 entity_module_path: &str,
13 methods: &Methods,
14 query_macro: bool,
15 pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17 let mut imports = BTreeSet::new();
18
19 let entity_ident = format_ident!("{}", entity.struct_name);
20 let repo_name = format!("{}Repository", entity.struct_name);
21 let repo_ident = format_ident!("{}", repo_name);
22
23 let table_name = match &entity.schema_name {
24 Some(schema) => format!("{}.{}", schema, entity.table_name),
25 None => entity.table_name.clone(),
26 };
27
28 let pool_type = pool_type_tokens(db_kind);
30
31 let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35 let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37 imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40 let entity_parent = entity_module_path
44 .rsplit_once("::")
45 .map(|(parent, _)| parent)
46 .unwrap_or(entity_module_path);
47 for imp in &entity.imports {
48 if let Some(rest) = imp.strip_prefix("use super::") {
49 imports.insert(format!("use {}::{}", entity_parent, rest));
50 } else {
51 imports.insert(imp.clone());
52 }
53 }
54
55 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58 let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61 let is_view = entity.is_view;
62
63 let mut method_tokens = Vec::new();
65 let mut param_structs = Vec::new();
66
67 if methods.get_all {
69 let sql = raw_sql_lit(&format!("SELECT * FROM {}", table_name));
70 let method = if use_macro {
71 quote! {
72 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73 sqlx::query_as!(#entity_ident, #sql)
74 .fetch_all(&self.pool)
75 .await
76 }
77 }
78 } else {
79 quote! {
80 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81 sqlx::query_as::<_, #entity_ident>(#sql)
82 .fetch_all(&self.pool)
83 .await
84 }
85 }
86 };
87 method_tokens.push(method);
88 }
89
90 if methods.paginate {
92 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95 let count_sql = raw_sql_lit(&format!("SELECT COUNT(*) FROM {}", table_name));
96 let sql = raw_sql_lit(&match db_kind {
97 DatabaseKind::Postgres => format!("SELECT *\nFROM {}\nLIMIT $1 OFFSET $2", table_name),
98 DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT *\nFROM {}\nLIMIT ? OFFSET ?", table_name),
99 });
100 let method = if use_macro {
101 quote! {
102 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103 let total: i64 = sqlx::query_scalar!(#count_sql)
104 .fetch_one(&self.pool)
105 .await?
106 .unwrap_or(0);
107 let per_page = params.per_page;
108 let current_page = params.page;
109 let last_page = (total + per_page - 1) / per_page;
110 let offset = (current_page - 1) * per_page;
111 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112 .fetch_all(&self.pool)
113 .await?;
114 Ok(#paginated_ident {
115 meta: #pagination_meta_ident {
116 total,
117 per_page,
118 current_page,
119 last_page,
120 first_page: 1,
121 },
122 data,
123 })
124 }
125 }
126 } else {
127 quote! {
128 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129 let total: i64 = sqlx::query_scalar(#count_sql)
130 .fetch_one(&self.pool)
131 .await?;
132 let per_page = params.per_page;
133 let current_page = params.page;
134 let last_page = (total + per_page - 1) / per_page;
135 let offset = (current_page - 1) * per_page;
136 let data = sqlx::query_as::<_, #entity_ident>(#sql)
137 .bind(per_page)
138 .bind(offset)
139 .fetch_all(&self.pool)
140 .await?;
141 Ok(#paginated_ident {
142 meta: #pagination_meta_ident {
143 total,
144 per_page,
145 current_page,
146 last_page,
147 first_page: 1,
148 },
149 data,
150 })
151 }
152 }
153 };
154 method_tokens.push(method);
155 param_structs.push(quote! {
156 #[derive(Debug, Clone, Default)]
157 pub struct #paginate_params_ident {
158 pub page: i64,
159 pub per_page: i64,
160 }
161 });
162 param_structs.push(quote! {
163 #[derive(Debug, Clone)]
164 pub struct #pagination_meta_ident {
165 pub total: i64,
166 pub per_page: i64,
167 pub current_page: i64,
168 pub last_page: i64,
169 pub first_page: i64,
170 }
171 });
172 param_structs.push(quote! {
173 #[derive(Debug, Clone)]
174 pub struct #paginated_ident {
175 pub meta: #pagination_meta_ident,
176 pub data: Vec<#entity_ident>,
177 }
178 });
179 }
180
181 if methods.get && !pk_fields.is_empty() {
183 let pk_params: Vec<TokenStream> = pk_fields
184 .iter()
185 .map(|f| {
186 let name = format_ident!("{}", f.rust_name);
187 let ty: TokenStream = f.inner_type.parse().unwrap();
188 quote! { #name: #ty }
189 })
190 .collect();
191
192 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194 let sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, where_clause));
195 let sql_macro = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, where_clause_cast));
196
197 let binds: Vec<TokenStream> = pk_fields
198 .iter()
199 .map(|f| {
200 let name = format_ident!("{}", f.rust_name);
201 quote! { .bind(#name) }
202 })
203 .collect();
204
205 let method = if use_macro {
206 let pk_arg_names: Vec<TokenStream> = pk_fields
207 .iter()
208 .map(|f| {
209 let name = format_ident!("{}", f.rust_name);
210 quote! { #name }
211 })
212 .collect();
213 quote! {
214 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215 sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216 .fetch_optional(&self.pool)
217 .await
218 }
219 }
220 } else {
221 quote! {
222 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223 sqlx::query_as::<_, #entity_ident>(#sql)
224 #(#binds)*
225 .fetch_optional(&self.pool)
226 .await
227 }
228 }
229 };
230 method_tokens.push(method);
231 }
232
233 if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
235 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
239 pk_fields.clone()
240 } else {
241 non_pk_fields.clone()
242 };
243
244 let insert_fields: Vec<TokenStream> = insert_source_fields
246 .iter()
247 .map(|f| {
248 let name = format_ident!("{}", f.rust_name);
249 if f.column_default.is_some() && !f.is_nullable {
250 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
251 quote! { pub #name: #ty, }
252 } else {
253 let ty: TokenStream = f.rust_type.parse().unwrap();
254 quote! { pub #name: #ty, }
255 }
256 })
257 .collect();
258
259 let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
260 let col_list = col_names.join(", ");
261
262 let placeholders: String = insert_source_fields
264 .iter()
265 .enumerate()
266 .map(|(i, f)| {
267 let p = placeholder(db_kind, i + 1);
268 match &f.column_default {
269 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
270 None => p,
271 }
272 })
273 .collect::<Vec<_>>()
274 .join(", ");
275
276 let placeholders_cast: String = insert_source_fields
277 .iter()
278 .enumerate()
279 .map(|(i, f)| {
280 let p = placeholder_with_cast(db_kind, i + 1, f);
281 match &f.column_default {
282 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
283 None => p,
284 }
285 })
286 .collect::<Vec<_>>()
287 .join(", ");
288
289 let build_insert_sql = |ph: &str| match db_kind {
290 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
291 format!(
292 "INSERT INTO {} ({})\nVALUES ({})\nRETURNING *",
293 table_name, col_list, ph
294 )
295 }
296 DatabaseKind::Mysql => {
297 format!(
298 "INSERT INTO {} ({})\nVALUES ({})",
299 table_name, col_list, ph
300 )
301 }
302 };
303 let sql = build_insert_sql(&placeholders);
304 let sql_macro = build_insert_sql(&placeholders_cast);
305
306 let binds: Vec<TokenStream> = insert_source_fields
307 .iter()
308 .map(|f| {
309 let name = format_ident!("{}", f.rust_name);
310 quote! { .bind(¶ms.#name) }
311 })
312 .collect();
313
314 let insert_method = build_insert_method_parsed(
315 &entity_ident,
316 &insert_params_ident,
317 &sql,
318 &sql_macro,
319 &binds,
320 db_kind,
321 &table_name,
322 &pk_fields,
323 &insert_source_fields,
324 use_macro,
325 );
326 method_tokens.push(insert_method);
327
328 param_structs.push(quote! {
329 #[derive(Debug, Clone, Default)]
330 pub struct #insert_params_ident {
331 #(#insert_fields)*
332 }
333 });
334 }
335
336 if !is_view && methods.insert_many && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
338 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
339
340 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
341 pk_fields.clone()
342 } else {
343 non_pk_fields.clone()
344 };
345
346 let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
347 let col_list = col_names.join(", ");
348 let num_cols = insert_source_fields.len();
349
350 let binds_loop: Vec<TokenStream> = insert_source_fields
351 .iter()
352 .map(|f| {
353 let name = format_ident!("{}", f.rust_name);
354 quote! { query = query.bind(¶ms.#name); }
355 })
356 .collect();
357
358 let insert_many_method = build_insert_many_transactionally_method(
359 &entity_ident,
360 &insert_params_ident,
361 &col_list,
362 num_cols,
363 &insert_source_fields,
364 &binds_loop,
365 db_kind,
366 &table_name,
367 &pk_fields,
368 );
369 method_tokens.push(insert_many_method);
370
371 if !methods.insert {
373 let insert_fields: Vec<TokenStream> = insert_source_fields
374 .iter()
375 .map(|f| {
376 let name = format_ident!("{}", f.rust_name);
377 if f.column_default.is_some() && !f.is_nullable {
378 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
379 quote! { pub #name: #ty, }
380 } else {
381 let ty: TokenStream = f.rust_type.parse().unwrap();
382 quote! { pub #name: #ty, }
383 }
384 })
385 .collect();
386
387 param_structs.push(quote! {
388 #[derive(Debug, Clone, Default)]
389 pub struct #insert_params_ident {
390 #(#insert_fields)*
391 }
392 });
393 }
394 }
395
396 if !is_view && methods.overwrite && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
398 let overwrite_params_ident = format_ident!("Overwrite{}Params", entity.struct_name);
399
400 let pk_fn_params: Vec<TokenStream> = pk_fields
402 .iter()
403 .map(|f| {
404 let name = format_ident!("{}", f.rust_name);
405 let ty: TokenStream = f.inner_type.parse().unwrap();
406 quote! { #name: #ty }
407 })
408 .collect();
409
410 let overwrite_fields: Vec<TokenStream> = non_pk_fields
412 .iter()
413 .map(|f| {
414 let name = format_ident!("{}", f.rust_name);
415 let ty: TokenStream = f.rust_type.parse().unwrap();
416 quote! { pub #name: #ty, }
417 })
418 .collect();
419
420 let set_cols: Vec<String> = non_pk_fields
421 .iter()
422 .enumerate()
423 .map(|(i, f)| {
424 let p = placeholder(db_kind, i + 1);
425 format!("{} = {}", f.column_name, p)
426 })
427 .collect();
428 let set_clause = set_cols.join(",\n ");
429
430 let set_cols_cast: Vec<String> = non_pk_fields
431 .iter()
432 .enumerate()
433 .map(|(i, f)| {
434 let p = placeholder_with_cast(db_kind, i + 1, f);
435 format!("{} = {}", f.column_name, p)
436 })
437 .collect();
438 let set_clause_cast = set_cols_cast.join(",\n ");
439
440 let pk_start = non_pk_fields.len() + 1;
441 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
442 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
443
444 let build_overwrite_sql = |sc: &str, wc: &str| match db_kind {
445 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
446 format!("UPDATE {}\nSET\n {}\nWHERE {}\nRETURNING *", table_name, sc, wc)
447 }
448 DatabaseKind::Mysql => {
449 format!("UPDATE {}\nSET\n {}\nWHERE {}", table_name, sc, wc)
450 }
451 };
452 let sql = raw_sql_lit(&build_overwrite_sql(&set_clause, &where_clause));
453 let sql_macro = raw_sql_lit(&build_overwrite_sql(&set_clause_cast, &where_clause_cast));
454
455 let mut all_binds: Vec<TokenStream> = non_pk_fields
457 .iter()
458 .map(|f| {
459 let name = format_ident!("{}", f.rust_name);
460 quote! { .bind(¶ms.#name) }
461 })
462 .collect();
463 for f in &pk_fields {
464 let name = format_ident!("{}", f.rust_name);
465 all_binds.push(quote! { .bind(#name) });
466 }
467
468 let overwrite_macro_args: Vec<TokenStream> = non_pk_fields
470 .iter()
471 .map(|f| macro_arg_for_field(f))
472 .chain(pk_fields.iter().map(|f| {
473 let name = format_ident!("{}", f.rust_name);
474 quote! { #name }
475 }))
476 .collect();
477
478 let overwrite_method = if use_macro {
479 match db_kind {
480 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
481 quote! {
482 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
483 sqlx::query_as!(#entity_ident, #sql_macro, #(#overwrite_macro_args),*)
484 .fetch_one(&self.pool)
485 .await
486 }
487 }
488 }
489 DatabaseKind::Mysql => {
490 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
491 let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where_select));
492 let pk_macro_args: Vec<TokenStream> = pk_fields
493 .iter()
494 .map(|f| {
495 let name = format_ident!("{}", f.rust_name);
496 quote! { #name }
497 })
498 .collect();
499 quote! {
500 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
501 sqlx::query!(#sql_macro, #(#overwrite_macro_args),*)
502 .execute(&self.pool)
503 .await?;
504 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
505 .fetch_one(&self.pool)
506 .await
507 }
508 }
509 }
510 }
511 } else {
512 match db_kind {
513 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
514 quote! {
515 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
516 sqlx::query_as::<_, #entity_ident>(#sql)
517 #(#all_binds)*
518 .fetch_one(&self.pool)
519 .await
520 }
521 }
522 }
523 DatabaseKind::Mysql => {
524 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
525 let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where_select));
526 let pk_binds: Vec<TokenStream> = pk_fields
527 .iter()
528 .map(|f| {
529 let name = format_ident!("{}", f.rust_name);
530 quote! { .bind(#name) }
531 })
532 .collect();
533 quote! {
534 pub async fn overwrite(&self, #(#pk_fn_params),*, params: &#overwrite_params_ident) -> Result<#entity_ident, sqlx::Error> {
535 sqlx::query(#sql)
536 #(#all_binds)*
537 .execute(&self.pool)
538 .await?;
539 sqlx::query_as::<_, #entity_ident>(#select_sql)
540 #(#pk_binds)*
541 .fetch_one(&self.pool)
542 .await
543 }
544 }
545 }
546 }
547 };
548 method_tokens.push(overwrite_method);
549
550 param_structs.push(quote! {
551 #[derive(Debug, Clone, Default)]
552 pub struct #overwrite_params_ident {
553 #(#overwrite_fields)*
554 }
555 });
556 }
557
558 if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
560 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
561
562 let pk_fn_params: Vec<TokenStream> = pk_fields
564 .iter()
565 .map(|f| {
566 let name = format_ident!("{}", f.rust_name);
567 let ty: TokenStream = f.inner_type.parse().unwrap();
568 quote! { #name: #ty }
569 })
570 .collect();
571
572 let update_fields: Vec<TokenStream> = non_pk_fields
574 .iter()
575 .map(|f| {
576 let name = format_ident!("{}", f.rust_name);
577 if f.is_nullable {
578 let ty: TokenStream = f.rust_type.parse().unwrap();
580 quote! { pub #name: #ty, }
581 } else {
582 let ty: TokenStream = format!("Option<{}>", f.rust_type).parse().unwrap();
583 quote! { pub #name: #ty, }
584 }
585 })
586 .collect();
587
588 let set_cols: Vec<String> = non_pk_fields
590 .iter()
591 .enumerate()
592 .map(|(i, f)| {
593 let p = placeholder(db_kind, i + 1);
594 format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
595 })
596 .collect();
597 let set_clause = set_cols.join(",\n ");
598
599 let set_cols_cast: Vec<String> = non_pk_fields
601 .iter()
602 .enumerate()
603 .map(|(i, f)| {
604 let p = placeholder_with_cast(db_kind, i + 1, f);
605 format!("{col} = COALESCE({p}, {col})", col = f.column_name, p = p)
606 })
607 .collect();
608 let set_clause_cast = set_cols_cast.join(",\n ");
609
610 let pk_start = non_pk_fields.len() + 1;
611 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
612 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
613
614 let build_update_sql = |sc: &str, wc: &str| match db_kind {
615 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
616 format!(
617 "UPDATE {}\nSET\n {}\nWHERE {}\nRETURNING *",
618 table_name, sc, wc
619 )
620 }
621 DatabaseKind::Mysql => {
622 format!(
623 "UPDATE {}\nSET\n {}\nWHERE {}",
624 table_name, sc, wc
625 )
626 }
627 };
628 let sql = raw_sql_lit(&build_update_sql(&set_clause, &where_clause));
629 let sql_macro = raw_sql_lit(&build_update_sql(&set_clause_cast, &where_clause_cast));
630
631 let mut all_binds: Vec<TokenStream> = non_pk_fields
633 .iter()
634 .map(|f| {
635 let name = format_ident!("{}", f.rust_name);
636 quote! { .bind(¶ms.#name) }
637 })
638 .collect();
639 for f in &pk_fields {
640 let name = format_ident!("{}", f.rust_name);
641 all_binds.push(quote! { .bind(#name) });
642 }
643
644 let update_macro_args: Vec<TokenStream> = non_pk_fields
646 .iter()
647 .map(|f| macro_arg_for_field(f))
648 .chain(pk_fields.iter().map(|f| {
649 let name = format_ident!("{}", f.rust_name);
650 quote! { #name }
651 }))
652 .collect();
653
654 let update_method = if use_macro {
655 match db_kind {
656 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
657 quote! {
658 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
659 sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
660 .fetch_one(&self.pool)
661 .await
662 }
663 }
664 }
665 DatabaseKind::Mysql => {
666 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
667 let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where_select));
668 let pk_macro_args: Vec<TokenStream> = pk_fields
669 .iter()
670 .map(|f| {
671 let name = format_ident!("{}", f.rust_name);
672 quote! { #name }
673 })
674 .collect();
675 quote! {
676 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
677 sqlx::query!(#sql_macro, #(#update_macro_args),*)
678 .execute(&self.pool)
679 .await?;
680 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
681 .fetch_one(&self.pool)
682 .await
683 }
684 }
685 }
686 }
687 } else {
688 match db_kind {
689 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
690 quote! {
691 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
692 sqlx::query_as::<_, #entity_ident>(#sql)
693 #(#all_binds)*
694 .fetch_one(&self.pool)
695 .await
696 }
697 }
698 }
699 DatabaseKind::Mysql => {
700 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
701 let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where_select));
702 let pk_binds: Vec<TokenStream> = pk_fields
703 .iter()
704 .map(|f| {
705 let name = format_ident!("{}", f.rust_name);
706 quote! { .bind(#name) }
707 })
708 .collect();
709 quote! {
710 pub async fn update(&self, #(#pk_fn_params),*, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
711 sqlx::query(#sql)
712 #(#all_binds)*
713 .execute(&self.pool)
714 .await?;
715 sqlx::query_as::<_, #entity_ident>(#select_sql)
716 #(#pk_binds)*
717 .fetch_one(&self.pool)
718 .await
719 }
720 }
721 }
722 }
723 };
724 method_tokens.push(update_method);
725
726 param_structs.push(quote! {
727 #[derive(Debug, Clone, Default)]
728 pub struct #update_params_ident {
729 #(#update_fields)*
730 }
731 });
732 }
733
734 if !is_view && methods.delete && !pk_fields.is_empty() {
736 let pk_params: Vec<TokenStream> = pk_fields
737 .iter()
738 .map(|f| {
739 let name = format_ident!("{}", f.rust_name);
740 let ty: TokenStream = f.inner_type.parse().unwrap();
741 quote! { #name: #ty }
742 })
743 .collect();
744
745 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
746 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
747 let sql = raw_sql_lit(&format!("DELETE FROM {}\nWHERE {}", table_name, where_clause));
748 let sql_macro = raw_sql_lit(&format!("DELETE FROM {}\nWHERE {}", table_name, where_clause_cast));
749
750 let binds: Vec<TokenStream> = pk_fields
751 .iter()
752 .map(|f| {
753 let name = format_ident!("{}", f.rust_name);
754 quote! { .bind(#name) }
755 })
756 .collect();
757
758 let method = if query_macro {
759 let pk_arg_names: Vec<TokenStream> = pk_fields
760 .iter()
761 .map(|f| {
762 let name = format_ident!("{}", f.rust_name);
763 quote! { #name }
764 })
765 .collect();
766 quote! {
767 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
768 sqlx::query!(#sql_macro, #(#pk_arg_names),*)
769 .execute(&self.pool)
770 .await?;
771 Ok(())
772 }
773 }
774 } else {
775 quote! {
776 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
777 sqlx::query(#sql)
778 #(#binds)*
779 .execute(&self.pool)
780 .await?;
781 Ok(())
782 }
783 }
784 };
785 method_tokens.push(method);
786 }
787
788 let pool_vis: TokenStream = match pool_visibility {
789 PoolVisibility::Private => quote! {},
790 PoolVisibility::Pub => quote! { pub },
791 PoolVisibility::PubCrate => quote! { pub(crate) },
792 };
793
794 let tokens = quote! {
795 #(#param_structs)*
796
797 pub struct #repo_ident {
798 #pool_vis pool: #pool_type,
799 }
800
801 impl #repo_ident {
802 pub fn new(pool: #pool_type) -> Self {
803 Self { pool }
804 }
805
806 #(#method_tokens)*
807 }
808 };
809
810 (tokens, imports)
811}
812
813fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
814 match db_kind {
815 DatabaseKind::Postgres => quote! { sqlx::PgPool },
816 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
817 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
818 }
819}
820
821fn raw_sql_lit(s: &str) -> TokenStream {
824 if s.contains('\n') {
825 format!("r#\"\n{}\n\"#", s).parse().unwrap()
826 } else {
827 format!("r#\"{}\"#", s).parse().unwrap()
828 }
829}
830
831fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
832 match db_kind {
833 DatabaseKind::Postgres => format!("${}", index),
834 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
835 }
836}
837
838fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
839 let base = placeholder(db_kind, index);
840 match (&field.sql_type, field.is_sql_array) {
841 (Some(t), true) => format!("{} as {}[]", base, t),
842 (Some(t), false) => format!("{} as {}", base, t),
843 (None, _) => base,
844 }
845}
846
847fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
848 (0..count)
849 .map(|i| placeholder(db_kind, start + i))
850 .collect::<Vec<_>>()
851 .join(", ")
852}
853
854fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
855 fields
856 .iter()
857 .enumerate()
858 .map(|(i, f)| {
859 if use_cast {
860 placeholder_with_cast(db_kind, start + i, f)
861 } else {
862 placeholder(db_kind, start + i)
863 }
864 })
865 .collect::<Vec<_>>()
866 .join(", ")
867}
868
869fn build_where_clause_parsed(
870 pk_fields: &[&ParsedField],
871 db_kind: DatabaseKind,
872 start_index: usize,
873) -> String {
874 pk_fields
875 .iter()
876 .enumerate()
877 .map(|(i, f)| {
878 let p = placeholder(db_kind, start_index + i);
879 format!("{} = {}", f.column_name, p)
880 })
881 .collect::<Vec<_>>()
882 .join(" AND ")
883}
884
885fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
886 let name = format_ident!("{}", field.rust_name);
887 let check_type = if field.is_nullable {
888 &field.inner_type
889 } else {
890 &field.rust_type
891 };
892 let normalized = check_type.replace(' ', "");
893 if normalized.starts_with("Vec<") {
894 quote! { params.#name.as_slice() }
895 } else {
896 quote! { params.#name }
897 }
898}
899
900fn build_where_clause_cast(
901 pk_fields: &[&ParsedField],
902 db_kind: DatabaseKind,
903 start_index: usize,
904) -> String {
905 pk_fields
906 .iter()
907 .enumerate()
908 .map(|(i, f)| {
909 let p = placeholder_with_cast(db_kind, start_index + i, f);
910 format!("{} = {}", f.column_name, p)
911 })
912 .collect::<Vec<_>>()
913 .join(" AND ")
914}
915
916#[allow(clippy::too_many_arguments)]
917fn build_insert_method_parsed(
918 entity_ident: &proc_macro2::Ident,
919 insert_params_ident: &proc_macro2::Ident,
920 sql: &str,
921 sql_macro: &str,
922 binds: &[TokenStream],
923 db_kind: DatabaseKind,
924 table_name: &str,
925 pk_fields: &[&ParsedField],
926 non_pk_fields: &[&ParsedField],
927 use_macro: bool,
928) -> TokenStream {
929 let sql = raw_sql_lit(sql);
930 let sql_macro = raw_sql_lit(sql_macro);
931
932 if use_macro {
933 let macro_args: Vec<TokenStream> = non_pk_fields
934 .iter()
935 .map(|f| macro_arg_for_field(f))
936 .collect();
937
938 match db_kind {
939 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
940 quote! {
941 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
942 sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
943 .fetch_one(&self.pool)
944 .await
945 }
946 }
947 }
948 DatabaseKind::Mysql => {
949 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
950 let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where));
951 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID() as id");
952 quote! {
953 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
954 sqlx::query!(#sql_macro, #(#macro_args),*)
955 .execute(&self.pool)
956 .await?;
957 let id = sqlx::query_scalar!(#last_insert_id_sql)
958 .fetch_one(&self.pool)
959 .await?;
960 sqlx::query_as!(#entity_ident, #select_sql, id)
961 .fetch_one(&self.pool)
962 .await
963 }
964 }
965 }
966 }
967 } else {
968 match db_kind {
969 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
970 quote! {
971 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
972 sqlx::query_as::<_, #entity_ident>(#sql)
973 #(#binds)*
974 .fetch_one(&self.pool)
975 .await
976 }
977 }
978 }
979 DatabaseKind::Mysql => {
980 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
981 let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where));
982 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
983 quote! {
984 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
985 sqlx::query(#sql)
986 #(#binds)*
987 .execute(&self.pool)
988 .await?;
989 let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
990 .fetch_one(&self.pool)
991 .await?;
992 sqlx::query_as::<_, #entity_ident>(#select_sql)
993 .bind(id)
994 .fetch_one(&self.pool)
995 .await
996 }
997 }
998 }
999 }
1000 }
1001}
1002
1003#[allow(clippy::too_many_arguments)]
1004fn build_insert_many_transactionally_method(
1005 entity_ident: &proc_macro2::Ident,
1006 insert_params_ident: &proc_macro2::Ident,
1007 col_list: &str,
1008 num_cols: usize,
1009 insert_source_fields: &[&ParsedField],
1010 binds_loop: &[TokenStream],
1011 db_kind: DatabaseKind,
1012 table_name: &str,
1013 pk_fields: &[&ParsedField],
1014) -> TokenStream {
1015 let body = match db_kind {
1016 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
1017 let col_list_str = col_list.to_string();
1018 let table_name_str = table_name.to_string();
1019
1020 let row_placeholder_exprs: Vec<TokenStream> = insert_source_fields
1021 .iter()
1022 .enumerate()
1023 .map(|(i, f)| {
1024 let offset = i;
1025 match &f.column_default {
1026 Some(default_expr) => {
1027 let def = default_expr.as_str();
1028 match db_kind {
1029 DatabaseKind::Postgres => quote! {
1030 format!("COALESCE(${}, {})", base + #offset + 1, #def)
1031 },
1032 _ => quote! {
1033 format!("COALESCE(?, {})", #def)
1034 },
1035 }
1036 }
1037 None => {
1038 match db_kind {
1039 DatabaseKind::Postgres => quote! {
1040 format!("${}", base + #offset + 1)
1041 },
1042 _ => quote! {
1043 "?".to_string()
1044 },
1045 }
1046 }
1047 }
1048 })
1049 .collect();
1050
1051 quote! {
1052 let mut tx = self.pool.begin().await?;
1053 let mut all_results = Vec::with_capacity(entries.len());
1054 let max_per_chunk = 65535 / #num_cols;
1055 for chunk in entries.chunks(max_per_chunk) {
1056 let mut values_parts = Vec::with_capacity(chunk.len());
1057 for (row_idx, _) in chunk.iter().enumerate() {
1058 let base = row_idx * #num_cols;
1059 let placeholders = vec![#(#row_placeholder_exprs),*];
1060 values_parts.push(format!("({})", placeholders.join(", ")));
1061 }
1062 let sql = format!(
1063 "INSERT INTO {} ({})\nVALUES {}\nRETURNING *",
1064 #table_name_str,
1065 #col_list_str,
1066 values_parts.join(", ")
1067 );
1068 let mut query = sqlx::query_as::<_, #entity_ident>(&sql);
1069 for params in chunk {
1070 #(#binds_loop)*
1071 }
1072 let rows = query.fetch_all(&mut *tx).await?;
1073 all_results.extend(rows);
1074 }
1075 tx.commit().await?;
1076 Ok(all_results)
1077 }
1078 }
1079 DatabaseKind::Mysql => {
1080 let single_placeholders: String = insert_source_fields
1081 .iter()
1082 .enumerate()
1083 .map(|(i, f)| {
1084 let p = placeholder(db_kind, i + 1);
1085 match &f.column_default {
1086 Some(default_expr) => format!("COALESCE({}, {})", p, default_expr),
1087 None => p,
1088 }
1089 })
1090 .collect::<Vec<_>>()
1091 .join(", ");
1092
1093 let single_insert_sql = raw_sql_lit(&format!(
1094 "INSERT INTO {} ({})\nVALUES ({})",
1095 table_name, col_list, single_placeholders
1096 ));
1097
1098 let single_binds: Vec<TokenStream> = insert_source_fields
1099 .iter()
1100 .map(|f| {
1101 let name = format_ident!("{}", f.rust_name);
1102 quote! { .bind(¶ms.#name) }
1103 })
1104 .collect();
1105
1106 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
1107 let select_sql = raw_sql_lit(&format!("SELECT *\nFROM {}\nWHERE {}", table_name, pk_where));
1108 let last_insert_id_sql = raw_sql_lit("SELECT LAST_INSERT_ID()");
1109
1110 quote! {
1111 let mut tx = self.pool.begin().await?;
1112 let mut results = Vec::with_capacity(entries.len());
1113 for params in &entries {
1114 sqlx::query(#single_insert_sql)
1115 #(#single_binds)*
1116 .execute(&mut *tx)
1117 .await?;
1118 let id = sqlx::query_scalar::<_, i64>(#last_insert_id_sql)
1119 .fetch_one(&mut *tx)
1120 .await?;
1121 let row = sqlx::query_as::<_, #entity_ident>(#select_sql)
1122 .bind(id)
1123 .fetch_one(&mut *tx)
1124 .await?;
1125 results.push(row);
1126 }
1127 tx.commit().await?;
1128 Ok(results)
1129 }
1130 }
1131 };
1132
1133 quote! {
1134 pub async fn insert_many_transactionally(
1135 &self,
1136 entries: Vec<#insert_params_ident>,
1137 ) -> Result<Vec<#entity_ident>, sqlx::Error> {
1138 #body
1139 }
1140 }
1141}
1142
1143#[cfg(test)]
1144mod tests {
1145 use super::*;
1146 use crate::codegen::parse_and_format_with_tab_spaces;
1147 use crate::codegen::parse_and_format;
1148 use crate::cli::Methods;
1149
1150 fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
1151 let inner_type = if nullable {
1152 rust_type
1154 .strip_prefix("Option<")
1155 .and_then(|s| s.strip_suffix('>'))
1156 .unwrap_or(rust_type)
1157 .to_string()
1158 } else {
1159 rust_type.to_string()
1160 };
1161 ParsedField {
1162 rust_name: rust_name.to_string(),
1163 column_name: column_name.to_string(),
1164 rust_type: rust_type.to_string(),
1165 is_nullable: nullable,
1166 inner_type,
1167 is_primary_key: is_pk,
1168 sql_type: None,
1169 is_sql_array: false,
1170 column_default: None,
1171 }
1172 }
1173
1174 fn make_field_with_default(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool, default: &str) -> ParsedField {
1175 let mut f = make_field(rust_name, column_name, rust_type, nullable, is_pk);
1176 f.column_default = Some(default.to_string());
1177 f
1178 }
1179
1180 fn entity_with_defaults() -> ParsedEntity {
1181 ParsedEntity {
1182 struct_name: "Tasks".to_string(),
1183 table_name: "tasks".to_string(),
1184 schema_name: None,
1185 is_view: false,
1186 fields: vec![
1187 make_field("id", "id", "i32", false, true),
1188 make_field("title", "title", "String", false, false),
1189 make_field_with_default("status", "status", "String", false, false, "'idle'::task_status"),
1190 make_field_with_default("priority", "priority", "i32", false, false, "0"),
1191 make_field_with_default("created_at", "created_at", "DateTime<Utc>", false, false, "now()"),
1192 make_field("description", "description", "Option<String>", true, false),
1193 make_field_with_default("deleted_at", "deleted_at", "Option<DateTime<Utc>>", true, false, "NULL"),
1194 ],
1195 imports: vec![],
1196 }
1197 }
1198
1199 fn standard_entity() -> ParsedEntity {
1200 ParsedEntity {
1201 struct_name: "Users".to_string(),
1202 table_name: "users".to_string(),
1203 schema_name: None,
1204 is_view: false,
1205 fields: vec![
1206 make_field("id", "id", "i32", false, true),
1207 make_field("name", "name", "String", false, false),
1208 make_field("email", "email", "Option<String>", true, false),
1209 ],
1210 imports: vec![],
1211 }
1212 }
1213
1214 fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
1215 let skip = Methods::all();
1216 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
1217 parse_and_format(&tokens)
1218 }
1219
1220 fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
1221 let skip = Methods::all();
1222 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
1223 parse_and_format(&tokens)
1224 }
1225
1226 fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
1227 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
1228 parse_and_format(&tokens)
1229 }
1230
1231 fn gen_with_tab_spaces(entity: &ParsedEntity, db: DatabaseKind, tab_spaces: usize) -> String {
1232 let skip = Methods::all();
1233 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
1234 parse_and_format_with_tab_spaces(&tokens, tab_spaces)
1235 }
1236
/// The repository struct is named `<Entity>Repository`.
#[test]
fn test_repo_struct_name() {
    let entity = standard_entity();
    let out = gen(&entity, DatabaseKind::Postgres);
    assert!(out.contains("pub struct UsersRepository"));
}

/// A `new` constructor is generated.
#[test]
fn test_repo_new_method() {
    let entity = standard_entity();
    let out = gen(&entity, DatabaseKind::Postgres);
    assert!(out.contains("pub fn new("));
}

/// Postgres repositories hold a `sqlx::PgPool`.
#[test]
fn test_repo_pool_field_pg() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let accepted = ["pool: sqlx::PgPool", "pool: sqlx :: PgPool"];
    assert!(accepted.iter().any(|needle| out.contains(*needle)));
}

/// `PoolVisibility::Pub` makes the pool field `pub`.
#[test]
fn test_repo_pool_field_pub() {
    let every_method = Methods::all();
    let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &every_method, false, PoolVisibility::Pub);
    let out = parse_and_format(&tokens);
    let accepted = ["pub pool: sqlx::PgPool", "pub pool: sqlx :: PgPool"];
    assert!(accepted.iter().any(|needle| out.contains(*needle)));
}

/// `PoolVisibility::PubCrate` makes the pool field `pub(crate)`.
#[test]
fn test_repo_pool_field_pub_crate() {
    let every_method = Methods::all();
    let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &every_method, false, PoolVisibility::PubCrate);
    let out = parse_and_format(&tokens);
    let accepted = ["pub(crate) pool: sqlx::PgPool", "pub(crate) pool: sqlx :: PgPool"];
    assert!(accepted.iter().any(|needle| out.contains(*needle)));
}

/// `PoolVisibility::Private` emits no visibility modifier on the pool field.
#[test]
fn test_repo_pool_field_private() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let forbidden = ["pub pool", "pub(crate) pool"];
    assert!(forbidden.iter().all(|needle| !out.contains(*needle)));
}

/// MySQL repositories reference a MySQL pool type.
#[test]
fn test_repo_pool_field_mysql() {
    let out = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(["MySqlPool", "MySql"].iter().any(|needle| out.contains(*needle)));
}

/// SQLite repositories reference a SQLite pool type.
#[test]
fn test_repo_pool_field_sqlite() {
    let out = gen(&standard_entity(), DatabaseKind::Sqlite);
    assert!(["SqlitePool", "Sqlite"].iter().any(|needle| out.contains(*needle)));
}
1292
/// `get_all` is generated.
#[test]
fn test_get_all_method() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub async fn get_all"));
}

/// `get_all` returns a vector of entities.
#[test]
fn test_get_all_returns_vec() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("Vec<Users>"));
}

/// `get_all` selects every row of the table.
#[test]
fn test_get_all_sql() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("SELECT * FROM users"));
}
1312
/// `paginate` is generated.
#[test]
fn test_paginate_method() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub async fn paginate"));
}

/// A dedicated params struct is generated for pagination.
#[test]
fn test_paginate_params_struct() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub struct PaginateUsersParams"));
}

/// The pagination params carry `page` and `per_page` as `i64`.
#[test]
fn test_paginate_params_fields() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    for field in ["pub page: i64", "pub per_page: i64"].iter() {
        assert!(out.contains(*field));
    }
}

/// `paginate` returns the paginated wrapper, which references the meta struct.
#[test]
fn test_paginate_returns_paginated() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    for ident in ["PaginatedUsers", "PaginationUsersMeta"].iter() {
        assert!(out.contains(*ident));
    }
}

/// The pagination meta struct exposes totals and page bounds.
#[test]
fn test_paginate_meta_struct() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let expected = [
        "pub struct PaginationUsersMeta",
        "pub total: i64",
        "pub last_page: i64",
        "pub first_page: i64",
        "pub current_page: i64",
    ];
    for needle in expected.iter() {
        assert!(out.contains(*needle));
    }
}

/// The paginated wrapper holds `meta` and the page's `data`.
#[test]
fn test_paginate_data_struct() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let expected = [
        "pub struct PaginatedUsers",
        "pub meta: PaginationUsersMeta",
        "pub data: Vec<Users>",
    ];
    for needle in expected.iter() {
        assert!(out.contains(*needle));
    }
}

/// Pagination counts the table's rows.
#[test]
fn test_paginate_count_sql() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("SELECT COUNT(*) FROM users"));
}

/// Postgres pagination uses numbered placeholders.
#[test]
fn test_paginate_sql_pg() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("LIMIT $1 OFFSET $2"));
}

/// MySQL pagination uses `?` placeholders.
#[test]
fn test_paginate_sql_mysql() {
    let out = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(out.contains("LIMIT ? OFFSET ?"));
}
1376
/// A `get` method is generated.
///
/// The needle includes the opening parenthesis: plain `"pub async fn get"` is also a
/// prefix of `get_all`, so without it the test passes even when `get` is missing.
#[test]
fn test_get_method() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("pub async fn get("), "Expected a `get` method (not just `get_all`):\n{}", code);
}

/// `get` returns an optional entity.
#[test]
fn test_get_returns_option() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("Option<Users>"));
}

/// Postgres `get` filters on the primary key with a numbered placeholder.
#[test]
fn test_get_where_pk_pg() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("WHERE id = $1"));
}

/// MySQL `get` filters on the primary key with a `?` placeholder.
#[test]
fn test_get_where_pk_mysql() {
    let code = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(code.contains("WHERE id = ?"));
}
1402
/// A single-row `insert` method is generated.
///
/// The needle includes the opening parenthesis: plain `"pub async fn insert"` is also a
/// prefix of `insert_many_transactionally`, so without it the test passes even when
/// `insert` is missing.
#[test]
fn test_insert_method() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("pub async fn insert("), "Expected a single-row `insert` method:\n{}", code);
}

/// A dedicated params struct is generated for inserts.
#[test]
fn test_insert_params_struct() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("pub struct InsertUsersParams"));
}

/// Insert params contain the non-PK columns with their declared types.
#[test]
fn test_insert_params_no_pk() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("pub name: String"));
    assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
}

/// Postgres inserts return the inserted row.
#[test]
fn test_insert_returning_pg() {
    let code = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(code.contains("RETURNING *"));
}

/// SQLite inserts return the inserted row.
#[test]
fn test_insert_returning_sqlite() {
    let code = gen(&standard_entity(), DatabaseKind::Sqlite);
    assert!(code.contains("RETURNING *"));
}

/// MySQL has no `RETURNING`; the generated code uses `LAST_INSERT_ID`.
#[test]
fn test_insert_mysql_last_insert_id() {
    let code = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(code.contains("LAST_INSERT_ID"));
}
1441
/// A non-nullable column with an SQL default becomes `Option<T>` in the insert params.
///
/// The field declaration is checked as a whole: `description` and `deleted_at` already
/// put "Option" in the struct body, so checking "Option" and "status" separately proves
/// nothing about `status`.
#[test]
fn test_insert_default_col_is_optional() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
    let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
    let struct_body = &code[struct_start..struct_end];
    assert!(
        struct_body.contains("pub status: Option<String>") || struct_body.contains("pub status: Option < String >"),
        "Expected status as Option in InsertTasksParams: {}",
        struct_body
    );
}

/// A non-nullable column without a default stays a required, non-`Option` field.
///
/// `pub title: String` cannot match `pub title: Option<String>`, unlike separate checks
/// for "title" and "String".
#[test]
fn test_insert_non_default_col_required() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
    let struct_end = code[struct_start..].find('}').unwrap() + struct_start;
    let struct_body = &code[struct_start..struct_end];
    assert!(struct_body.contains("pub title: String"), "Expected title as String: {}", struct_body);
}

/// Defaulted columns are inserted as `COALESCE(<param>, <default expr>)`.
#[test]
fn test_insert_default_col_coalesce_sql() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    assert!(code.contains("COALESCE($2, 'idle'::task_status)"), "Expected COALESCE for status:\n{}", code);
    assert!(code.contains("COALESCE($3, 0)"), "Expected COALESCE for priority:\n{}", code);
    assert!(code.contains("COALESCE($4, now())"), "Expected COALESCE for created_at:\n{}", code);
}

/// Columns without a default are bound directly, without `COALESCE`.
#[test]
fn test_insert_no_coalesce_for_non_default() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    assert!(code.contains("VALUES ($1, COALESCE"), "Expected $1 without COALESCE for title:\n{}", code);
}

/// Nullable columns with a default are not wrapped in a second `Option`.
#[test]
fn test_insert_nullable_with_default_no_double_option() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    assert!(!code.contains("Option < Option") && !code.contains("Option<Option"), "Should not have Option<Option>:\n{}", code);
}

/// `InsertTasksParams` derives `Default`, so callers can omit defaulted columns.
#[test]
fn test_insert_derive_default() {
    let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
    let struct_start = code.find("pub struct InsertTasksParams").expect("InsertTasksParams not found");
    let before_struct = &code[..struct_start];
    // Inspect only the attributes attached to this struct, i.e. the text after the end of
    // the previous item (its closing `}` or `;`). Searching the whole prefix would also
    // match `Default` derives on other structs emitted earlier.
    let attrs_start = before_struct.rfind(|c: char| c == '}' || c == ';').map_or(0, |i| i + 1);
    let attrs = &before_struct[attrs_start..];
    assert!(attrs.contains("Default"), "Expected #[derive(Default)] on InsertTasksParams: {}", attrs);
}
1492
/// `update` is generated.
#[test]
fn test_update_method() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub async fn update"));
}

/// A dedicated params struct is generated for partial updates.
#[test]
fn test_update_params_struct() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub struct UpdateUsersParams"));
}

/// The primary key is a separate argument, appearing before the params type in the signature.
#[test]
fn test_update_pk_in_fn_signature() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let fn_at = out.find("fn update").expect("fn update not found");
    let tail = &out[fn_at..];
    let params_at = tail.find("UpdateUsersParams").expect("UpdateUsersParams not found in update fn");
    let signature = &tail[..params_at];
    assert!(signature.contains("id"), "Expected 'id' PK in update fn signature: {}", signature);
}

/// The primary key is not a field of the update params.
#[test]
fn test_update_pk_not_in_struct() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let start = out.find("pub struct UpdateUsersParams").expect("UpdateUsersParams not found");
    let rest = &out[start..];
    let struct_body = &rest[..rest.find('}').unwrap()];
    assert!(!struct_body.contains("pub id"), "PK 'id' should not be in UpdateUsersParams:\n{}", struct_body);
}

/// Non-nullable columns become `Option<T>` so they can be left unchanged.
#[test]
fn test_update_params_non_nullable_wrapped_in_option() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let accepted = ["pub name: Option<String>", "pub name : Option < String >"];
    assert!(accepted.iter().any(|needle| out.contains(*needle)));
}

/// Already-nullable columns are not wrapped a second time.
#[test]
fn test_update_params_already_nullable_no_double_option() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let forbidden = ["Option<Option", "Option < Option"];
    assert!(forbidden.iter().all(|needle| !out.contains(*needle)));
}

/// Postgres partial update keeps the current value via `COALESCE(param, column)`.
#[test]
fn test_update_set_clause_uses_coalesce_pg() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("COALESCE($1, name)"), "Expected COALESCE for name:\n{}", out);
    assert!(out.contains("COALESCE($2, email)"), "Expected COALESCE for email:\n{}", out);
}

/// The PK placeholder follows the SET placeholders.
#[test]
fn test_update_where_clause_pg() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("WHERE id = $3"));
}

/// Postgres update uses COALESCE and returns the row.
#[test]
fn test_update_returning_pg() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    for needle in ["COALESCE", "RETURNING *"].iter() {
        assert!(out.contains(*needle));
    }
}

/// MySQL partial update uses `COALESCE(?, column)`.
#[test]
fn test_update_set_clause_mysql() {
    let out = gen(&standard_entity(), DatabaseKind::Mysql);
    assert!(out.contains("COALESCE(?, name)"), "Expected COALESCE for MySQL:\n{}", out);
    assert!(out.contains("COALESCE(?, email)"), "Expected COALESCE for email in MySQL:\n{}", out);
}

/// SQLite partial update uses `COALESCE(?, column)`.
#[test]
fn test_update_set_clause_sqlite() {
    let out = gen(&standard_entity(), DatabaseKind::Sqlite);
    assert!(out.contains("COALESCE(?, name)"), "Expected COALESCE for SQLite:\n{}", out);
}
1574
/// `overwrite` is generated for tables.
#[test]
fn test_overwrite_method() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub async fn overwrite"));
}

/// A dedicated params struct is generated for full overwrites.
#[test]
fn test_overwrite_params_struct() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub struct OverwriteUsersParams"));
}

/// The primary key is a separate argument, appearing before the params type in the signature.
#[test]
fn test_overwrite_pk_in_fn_signature() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let fn_at = out.find("fn overwrite").expect("fn overwrite not found");
    let tail = &out[fn_at..];
    let params_at = tail.find("OverwriteUsersParams").expect("OverwriteUsersParams not found");
    let signature = &tail[..params_at];
    assert!(signature.contains("id"), "Expected PK in overwrite fn signature: {}", signature);
}

/// The primary key is not a field of the overwrite params.
#[test]
fn test_overwrite_pk_not_in_struct() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let start = out.find("pub struct OverwriteUsersParams").expect("OverwriteUsersParams not found");
    let rest = &out[start..];
    let struct_body = &rest[..rest.find('}').unwrap()];
    assert!(!struct_body.contains("pub id"), "PK should not be in OverwriteUsersParams: {}", struct_body);
}

/// Overwrite sets every column directly, without COALESCE (checked within 500 bytes of the fn).
#[test]
fn test_overwrite_no_coalesce() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let pos = out.find("fn overwrite").expect("fn overwrite not found");
    let method_body = &out[pos..out.len().min(pos + 500)];
    assert!(!method_body.contains("COALESCE"), "Overwrite should not use COALESCE: {}", method_body);
}

/// Postgres overwrite SET/WHERE placeholders are numbered in column order, PK last.
#[test]
fn test_overwrite_set_clause_pg() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    for needle in ["name = $1,", "email = $2", "WHERE id = $3"].iter() {
        assert!(out.contains(*needle));
    }
}

/// Postgres overwrite returns the row (checked within 500 bytes of the fn).
#[test]
fn test_overwrite_returning_pg() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let pos = out.find("fn overwrite").expect("fn overwrite not found");
    let method_body = &out[pos..out.len().min(pos + 500)];
    assert!(method_body.contains("RETURNING *"), "Expected RETURNING * in overwrite");
}

/// Views get no overwrite method.
#[test]
fn test_view_no_overwrite() {
    let view = ParsedEntity { is_view: true, ..standard_entity() };
    let out = gen(&view, DatabaseKind::Postgres);
    assert!(!out.contains("pub async fn overwrite"));
}

/// Disabling `overwrite` drops both the method and its params struct.
#[test]
fn test_without_overwrite() {
    let selection = Methods { overwrite: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    for needle in ["pub async fn overwrite", "OverwriteUsersParams"].iter() {
        assert!(!out.contains(*needle));
    }
}

/// `update` and `overwrite` are generated side by side with distinct params structs.
#[test]
fn test_update_and_overwrite_coexist() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub async fn update"), "Expected update method");
    assert!(out.contains("pub async fn overwrite"), "Expected overwrite method");
    assert!(out.contains("UpdateUsersParams"), "Expected UpdateUsersParams");
    assert!(out.contains("OverwriteUsersParams"), "Expected OverwriteUsersParams");
}
1656
/// `delete` is generated.
#[test]
fn test_delete_method() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("pub async fn delete"));
}

/// `delete` removes rows from the table by primary key.
#[test]
fn test_delete_where_pk() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    for needle in ["DELETE FROM users", "WHERE id = $1"].iter() {
        assert!(out.contains(*needle));
    }
}
1671
/// Checks how raw SQL literals are indented when the output is formatted with `tab_spaces = 2`.
// NOTE(review): in this copy the needles' leading whitespace reads as a single space, while
// the messages describe an 8-space SQL indent and a 6-space closing `"#`. A single space would
// make these checks much weaker than the messages claim — confirm the literals carry the
// intended indent widths.
#[test]
fn test_tab_spaces_2_sql_indent() {
    let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 2);
    assert!(code.contains(" SELECT *"), "Expected SQL at 8-space indent:\n{}", code);
    assert!(code.contains(" \"#"), "Expected closing tag at 6-space indent:\n{}", code);
}

/// Checks how raw SQL literals are indented when the output is formatted with `tab_spaces = 4`.
// NOTE(review): as above, the needles here read as a single leading space, while the messages
// describe a 12-space SQL indent and an 8-space closing `"#` — confirm the literals.
#[test]
fn test_tab_spaces_4_sql_indent() {
    let code = gen_with_tab_spaces(&standard_entity(), DatabaseKind::Postgres, 4);
    assert!(code.contains(" SELECT *"), "Expected SQL at 12-space indent:\n{}", code);
    assert!(code.contains(" \"#"), "Expected closing tag at 8-space indent:\n{}", code);
}
1687
/// `delete` returns `Result<(), sqlx::Error>`.
#[test]
fn test_delete_returns_unit() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    let accepted = ["Result<(), sqlx::Error>", "Result<(), sqlx :: Error>"];
    assert!(accepted.iter().any(|needle| out.contains(*needle)));
}
1693
/// Views get no insert methods (neither `insert` nor `insert_many_*`).
#[test]
fn test_view_no_insert() {
    let entity = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn insert"));
}

/// Views get no update method.
#[test]
fn test_view_no_update() {
    let entity = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn update"));
}

/// Views get no delete method.
#[test]
fn test_view_no_delete() {
    let entity = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(!code.contains("pub async fn delete"));
}

/// Views keep `get_all`.
#[test]
fn test_view_has_get_all() {
    let entity = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn get_all"));
}

/// Views keep `paginate`.
#[test]
fn test_view_has_paginate() {
    let entity = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn paginate"));
}

/// Views keep the single-row `get`.
///
/// The needle includes the opening parenthesis: `"pub async fn get"` alone is also a
/// prefix of `get_all` (asserted just above), so without it this test is vacuous.
#[test]
fn test_view_has_get() {
    let entity = ParsedEntity { is_view: true, ..standard_entity() };
    let code = gen(&entity, DatabaseKind::Postgres);
    assert!(code.contains("pub async fn get("), "Expected a `get` method on the view:\n{}", code);
}
1743
/// Enabling only `get_all` generates nothing else.
#[test]
fn test_only_get_all() {
    let selection = Methods { get_all: true, ..Default::default() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    assert!(out.contains("pub async fn get_all"));
    for absent in ["pub async fn paginate", "pub async fn insert"].iter() {
        assert!(!out.contains(*absent));
    }
}

/// Disabling `get_all` removes it.
#[test]
fn test_without_get_all() {
    let selection = Methods { get_all: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    assert!(!out.contains("pub async fn get_all"));
}

/// Disabling `paginate` removes the method and its params struct.
#[test]
fn test_without_paginate() {
    let selection = Methods { paginate: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    for absent in ["pub async fn paginate", "PaginateUsersParams"].iter() {
        assert!(!out.contains(*absent));
    }
}

/// Disabling `get` removes the single-row getter but keeps `get_all`.
#[test]
fn test_without_get() {
    let selection = Methods { get: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    assert!(out.contains("pub async fn get_all"));
    // Mask `get_all` so it cannot be mistaken for `get(`.
    let masked = out.replace("get_all", "XXX");
    assert!(!masked.contains("fn get("));
}

/// Disabling both insert variants removes the methods and their params struct.
#[test]
fn test_without_insert() {
    let selection = Methods { insert: false, insert_many: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    for absent in ["pub async fn insert", "InsertUsersParams"].iter() {
        assert!(!out.contains(*absent));
    }
}

/// Disabling `update` removes the method and its params struct.
#[test]
fn test_without_update() {
    let selection = Methods { update: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    for absent in ["pub async fn update", "UpdateUsersParams"].iter() {
        assert!(!out.contains(*absent));
    }
}

/// Disabling `delete` removes it.
#[test]
fn test_without_delete() {
    let selection = Methods { delete: false, ..Methods::all() };
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    assert!(!out.contains("pub async fn delete"));
}

/// The default (empty) selection generates no query methods at all.
#[test]
fn test_empty_methods_no_methods() {
    let selection = Methods::default();
    let out = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &selection);
    let absent = [
        "pub async fn get_all",
        "pub async fn paginate",
        "pub async fn insert",
        "pub async fn update",
        "pub async fn overwrite",
        "pub async fn delete",
        "pub async fn insert_many",
    ];
    for signature in absent.iter() {
        assert!(!out.contains(*signature));
    }
}
1814
/// The pool type is referenced by full path, so no pool import is emitted.
#[test]
fn test_no_pool_import() {
    let every_method = Methods::all();
    let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &every_method, false, PoolVisibility::Private);
    assert!(imports.iter().all(|line| !line.contains("PgPool")));
}

/// The entity struct is imported from the given module path.
#[test]
fn test_imports_contain_entity() {
    let every_method = Methods::all();
    let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &every_method, false, PoolVisibility::Private);
    let wanted = "crate::models::users::Users";
    assert!(imports.iter().any(|line| line.contains(wanted)));
}
1830
/// A field whose Rust name differs from its column name (`connector_type` vs `type`) uses
/// the Rust name in generated structs but the column name in generated SQL.
#[test]
fn test_renamed_column_in_sql() {
    let entity = ParsedEntity {
        struct_name: "Connector".to_string(),
        table_name: "connector".to_string(),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("id", "id", "i32", false, true),
            make_field("connector_type", "type", "String", false, false),
        ],
        imports: vec![],
    };
    let code = gen(&entity, DatabaseKind::Postgres);
    // `code.contains("type")` alone is vacuous: the Rust name `connector_type` already
    // contains "type". The update statement references columns as `COALESCE($n, <column>)`,
    // and `connector_type` is the only non-PK field, so it binds `$1`. Quoting of the
    // column is tolerated.
    assert!(
        ["COALESCE($1, type)", "COALESCE($1, \"type\")"].iter().any(|sql| code.contains(*sql)),
        "Expected the SQL to reference column `type`:\n{}",
        code
    );
    assert!(
        !code.contains("COALESCE($1, connector_type)"),
        "SQL must use the column name, not the Rust field name:\n{}",
        code
    );
    assert!(code.contains("pub connector_type: String"));
}
1852
/// Without a primary key there is no single-row `get`, but `get_all` remains.
#[test]
fn test_no_pk_no_get() {
    let logs = ParsedEntity {
        struct_name: "Logs".into(),
        table_name: "logs".into(),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("message", "message", "String", false, false),
            make_field("ts", "ts", "String", false, false),
        ],
        imports: Vec::new(),
    };
    let out = gen(&logs, DatabaseKind::Postgres);
    assert!(out.contains("pub async fn get_all"));
    // Mask `get_all` so it cannot be mistaken for `get(`.
    let masked = out.replace("get_all", "XXX");
    assert!(!masked.contains("fn get("));
}

/// Without a primary key there is no `delete`.
#[test]
fn test_no_pk_no_delete() {
    let logs = ParsedEntity {
        struct_name: "Logs".into(),
        table_name: "logs".into(),
        schema_name: None,
        is_view: false,
        fields: vec![make_field("message", "message", "String", false, false)],
        imports: Vec::new(),
    };
    let out = gen(&logs, DatabaseKind::Postgres);
    assert!(!out.contains("pub async fn delete"));
}
1889
/// The generated output mentions `Default` (the params structs derive it).
#[test]
fn test_param_structs_have_default() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("Default"));
}
1897
/// `use` lines declared by the entity module are forwarded to the repository imports.
#[test]
fn test_entity_imports_forwarded() {
    let entity = ParsedEntity {
        struct_name: "Users".into(),
        table_name: "users".into(),
        schema_name: None,
        is_view: false,
        fields: vec![
            make_field("id", "id", "Uuid", false, true),
            make_field("created_at", "created_at", "DateTime<Utc>", false, false),
        ],
        imports: vec![
            "use chrono::{DateTime, Utc};".into(),
            "use uuid::Uuid;".into(),
        ],
    };
    let every_method = Methods::all();
    let (_, imports) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::users", &every_method, false, PoolVisibility::Private);
    for krate in ["chrono", "uuid"].iter() {
        assert!(imports.iter().any(|line| line.contains(*krate)));
    }
}

/// No extra crate imports appear when the entity declares none.
#[test]
fn test_entity_imports_empty_when_no_imports() {
    let every_method = Methods::all();
    let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &every_method, false, PoolVisibility::Private);
    for krate in ["chrono", "uuid"].iter() {
        assert!(imports.iter().all(|line| !line.contains(*krate)));
    }
}
1930
/// With `query_macro`, `get_all` uses `query_as!` instead of the runtime function.
#[test]
fn test_macro_get_all() {
    let selection = Methods { get_all: true, ..Default::default() };
    let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &selection, true, PoolVisibility::Private);
    let out = parse_and_format(&tokens);
    assert!(out.contains("query_as!"));
    assert!(!out.contains("query_as::<"));
}

/// Macro pagination passes `per_page, offset` as macro arguments.
#[test]
fn test_macro_paginate() {
    let out = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for needle in ["query_as!", "per_page, offset"].iter() {
        assert!(out.contains(*needle));
    }
}

/// Macro `get` targets the entity type.
#[test]
fn test_macro_get() {
    let out = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("query_as!(Users"));
}

/// Postgres macro insert passes params fields as macro arguments.
#[test]
fn test_macro_insert_pg() {
    let out = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    for needle in ["query_as!(Users", "params.name", "params.email"].iter() {
        assert!(out.contains(*needle));
    }
}

/// MySQL macro insert uses `query!` plus `query_scalar!` (for the last insert id).
#[test]
fn test_macro_insert_mysql() {
    let out = gen_macro(&standard_entity(), DatabaseKind::Mysql);
    for needle in ["query!", "query_scalar!"].iter() {
        assert!(out.contains(*needle));
    }
}

/// Macro update keeps the COALESCE partial-update semantics.
#[test]
fn test_macro_update() {
    let out = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("query_as!(Users"));
    assert!(out.contains("COALESCE"), "Expected COALESCE in macro update:\n{}", out);
    assert!(out.contains("pub async fn update"));
    assert!(out.contains("UpdateUsersParams"));
}

/// Macro delete uses `query!`.
#[test]
fn test_macro_delete() {
    let out = gen_macro(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains("query!"));
}

/// With macros (and `insert_many` disabled), no runtime `.bind(` calls are emitted.
#[test]
fn test_macro_no_bind_calls() {
    let selection = Methods { insert_many: false, ..Methods::all() };
    let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &selection, true, PoolVisibility::Private);
    let out = parse_and_format(&tokens);
    assert!(!out.contains(".bind("));
}

/// Without macros, queries are built with `.bind(` and no query macros appear.
#[test]
fn test_function_style_uses_bind() {
    let out = gen(&standard_entity(), DatabaseKind::Postgres);
    assert!(out.contains(".bind("));
    for macro_call in ["query_as!(", "query!("].iter() {
        assert!(!out.contains(*macro_call));
    }
}
2004
2005 fn entity_with_sql_array() -> ParsedEntity {
2008 ParsedEntity {
2009 struct_name: "AgentConnector".to_string(),
2010 table_name: "agent.agent_connector".to_string(),
2011 schema_name: Some("agent".to_string()),
2012 is_view: false,
2013 fields: vec![
2014 ParsedField {
2015 rust_name: "connector_id".to_string(),
2016 column_name: "connector_id".to_string(),
2017 rust_type: "Uuid".to_string(),
2018 inner_type: "Uuid".to_string(),
2019 is_nullable: false,
2020 is_primary_key: true,
2021 sql_type: None,
2022 is_sql_array: false,
2023 column_default: None,
2024 },
2025 ParsedField {
2026 rust_name: "agent_id".to_string(),
2027 column_name: "agent_id".to_string(),
2028 rust_type: "Uuid".to_string(),
2029 inner_type: "Uuid".to_string(),
2030 is_nullable: false,
2031 is_primary_key: false,
2032 sql_type: None,
2033 is_sql_array: false,
2034 column_default: None,
2035 },
2036 ParsedField {
2037 rust_name: "usages".to_string(),
2038 column_name: "usages".to_string(),
2039 rust_type: "Vec<ConnectorUsages>".to_string(),
2040 inner_type: "Vec<ConnectorUsages>".to_string(),
2041 is_nullable: false,
2042 is_primary_key: false,
2043 sql_type: Some("agent.connector_usages".to_string()),
2044 is_sql_array: true,
2045 column_default: None,
2046 },
2047 ],
2048 imports: vec!["use uuid::Uuid;".to_string()],
2049 }
2050 }
2051
2052 fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
2053 let skip = Methods::all();
2054 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
2055 parse_and_format(&tokens)
2056 }
2057
/// A custom SQL type forces runtime `query_as::<…>` even when macros are requested.
#[test]
fn test_sql_array_macro_get_all_uses_runtime() {
    let out = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    assert!(out.contains("query_as::<"));
}

/// With a custom SQL type, `get` binds parameters at runtime.
#[test]
fn test_sql_array_macro_get_uses_runtime() {
    let out = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    assert!(out.contains(".bind("));
}

/// With a custom SQL type, inserts use the runtime `query_as` for the entity.
#[test]
fn test_sql_array_macro_insert_uses_runtime() {
    let out = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    let accepted = ["query_as::<_ , AgentConnector>", "query_as::<_, AgentConnector>"];
    assert!(accepted.iter().any(|needle| out.contains(*needle)));
}

/// `query!` still appears in the output (per the test name, from `delete`).
#[test]
fn test_sql_array_macro_delete_still_uses_macro() {
    let out = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    assert!(out.contains("query!"));
}

/// No `query_as!` macro is emitted when the entity has a custom SQL type.
#[test]
fn test_sql_array_no_query_as_macro() {
    let out = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
    assert!(!out.contains("query_as!("));
}
2093
2094 fn entity_with_sql_enum() -> ParsedEntity {
2097 ParsedEntity {
2098 struct_name: "Task".to_string(),
2099 table_name: "tasks".to_string(),
2100 schema_name: None,
2101 is_view: false,
2102 fields: vec![
2103 ParsedField {
2104 rust_name: "id".to_string(),
2105 column_name: "id".to_string(),
2106 rust_type: "i32".to_string(),
2107 inner_type: "i32".to_string(),
2108 is_nullable: false,
2109 is_primary_key: true,
2110 sql_type: None,
2111 is_sql_array: false,
2112 column_default: None,
2113 },
2114 ParsedField {
2115 rust_name: "status".to_string(),
2116 column_name: "status".to_string(),
2117 rust_type: "TaskStatus".to_string(),
2118 inner_type: "TaskStatus".to_string(),
2119 is_nullable: false,
2120 is_primary_key: false,
2121 sql_type: Some("task_status".to_string()),
2122 is_sql_array: false,
2123 column_default: None,
2124 },
2125 ],
2126 imports: vec![],
2127 }
2128 }
2129
2130 #[test]
2131 fn test_sql_enum_macro_uses_runtime() {
2132 let skip = Methods::all();
2133 let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
2134 let code = parse_and_format(&tokens);
2135 assert!(code.contains("query_as::<"));
2137 assert!(!code.contains("query_as!("));
2138 }
2139
2140 #[test]
2141 fn test_sql_enum_macro_delete_still_uses_macro() {
2142 let skip = Methods::all();
2143 let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
2144 let code = parse_and_format(&tokens);
2145 assert!(code.contains("query!"));
2147 }
2148
2149 fn entity_with_vec_string() -> ParsedEntity {
2152 ParsedEntity {
2153 struct_name: "PromptHistory".to_string(),
2154 table_name: "prompt_history".to_string(),
2155 schema_name: None,
2156 is_view: false,
2157 fields: vec![
2158 ParsedField {
2159 rust_name: "id".to_string(),
2160 column_name: "id".to_string(),
2161 rust_type: "Uuid".to_string(),
2162 inner_type: "Uuid".to_string(),
2163 is_nullable: false,
2164 is_primary_key: true,
2165 sql_type: None,
2166 is_sql_array: false,
2167 column_default: None,
2168 },
2169 ParsedField {
2170 rust_name: "content".to_string(),
2171 column_name: "content".to_string(),
2172 rust_type: "String".to_string(),
2173 inner_type: "String".to_string(),
2174 is_nullable: false,
2175 is_primary_key: false,
2176 sql_type: None,
2177 is_sql_array: false,
2178 column_default: None,
2179 },
2180 ParsedField {
2181 rust_name: "tags".to_string(),
2182 column_name: "tags".to_string(),
2183 rust_type: "Vec<String>".to_string(),
2184 inner_type: "Vec<String>".to_string(),
2185 is_nullable: false,
2186 is_primary_key: false,
2187 sql_type: None,
2188 is_sql_array: false,
2189 column_default: None,
2190 },
2191 ],
2192 imports: vec!["use uuid::Uuid;".to_string()],
2193 }
2194 }
2195
2196 #[test]
2197 fn test_vec_string_macro_insert_uses_as_slice() {
2198 let skip = Methods::all();
2199 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
2200 let code = parse_and_format(&tokens);
2201 assert!(code.contains("as_slice()"));
2202 }
2203
2204 #[test]
2205 fn test_vec_string_macro_update_uses_as_slice() {
2206 let skip = Methods::all();
2207 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
2208 let code = parse_and_format(&tokens);
2209 let count = code.matches("as_slice()").count();
2211 assert!(count >= 2, "expected at least 2 as_slice() calls (insert + update), found {}", count);
2212 }
2213
2214 #[test]
2215 fn test_vec_string_non_macro_no_as_slice() {
2216 let skip = Methods::all();
2217 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false, PoolVisibility::Private);
2218 let code = parse_and_format(&tokens);
2219 assert!(!code.contains("as_slice()"));
2221 }
2222
2223 #[test]
2224 fn test_vec_string_parsed_from_source_uses_as_slice() {
2225 use crate::codegen::entity_parser::parse_entity_source;
2226 let source = r#"
2227 use uuid::Uuid;
2228
2229 #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
2230 #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
2231 pub struct PromptHistory {
2232 #[sqlx_gen(primary_key)]
2233 pub id: Uuid,
2234 pub content: String,
2235 pub tags: Vec<String>,
2236 }
2237 "#;
2238 let entity = parse_entity_source(source).unwrap();
2239 let skip = Methods::all();
2240 let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
2241 let code = parse_and_format(&tokens);
2242 assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
2243 }
2244
2245 fn junction_entity() -> ParsedEntity {
2248 ParsedEntity {
2249 struct_name: "AnalysisRecord".to_string(),
2250 table_name: "analysis.analysis__record".to_string(),
2251 schema_name: None,
2252 is_view: false,
2253 fields: vec![
2254 make_field("record_id", "record_id", "uuid::Uuid", false, true),
2255 make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
2256 ],
2257 imports: vec![],
2258 }
2259 }
2260
2261 #[test]
2262 fn test_composite_pk_only_insert_generated() {
2263 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2264 assert!(code.contains("pub struct InsertAnalysisRecordParams"), "Expected InsertAnalysisRecordParams struct:\n{}", code);
2265 assert!(code.contains("pub record_id"), "Expected record_id field in insert params:\n{}", code);
2266 assert!(code.contains("pub analysis_id"), "Expected analysis_id field in insert params:\n{}", code);
2267 assert!(code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id)"), "Expected INSERT INTO clause:\n{}", code);
2268 assert!(code.contains("VALUES ($1, $2)"), "Expected VALUES clause:\n{}", code);
2269 assert!(code.contains("RETURNING *"), "Expected RETURNING clause:\n{}", code);
2270 assert!(code.contains("pub async fn insert"), "Expected insert method:\n{}", code);
2271 }
2272
2273 #[test]
2274 fn test_composite_pk_only_no_update() {
2275 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2276 assert!(!code.contains("UpdateAnalysisRecordParams"), "Expected no UpdateAnalysisRecordParams struct:\n{}", code);
2277 assert!(!code.contains("pub async fn update"), "Expected no update method:\n{}", code);
2278 }
2279
2280
2281 #[test]
2282 fn test_composite_pk_only_delete_generated() {
2283 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2284 assert!(code.contains("pub async fn delete"), "Expected delete method:\n{}", code);
2285 assert!(code.contains("DELETE FROM analysis.analysis__record"), "Expected DELETE clause:\n{}", code);
2286 assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause:\n{}", code);
2287 }
2288
2289 #[test]
2290 fn test_composite_pk_only_get_generated() {
2291 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2292 assert!(code.contains("pub async fn get"), "Expected get method:\n{}", code);
2293 assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause with both PK columns:\n{}", code);
2294 }
2295
2296 #[test]
2299 fn test_insert_many_transactionally_method_generated() {
2300 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2301 assert!(code.contains("pub async fn insert_many_transactionally"), "Expected insert_many_transactionally method:\n{}", code);
2302 }
2303
2304 #[test]
2305 fn test_insert_many_transactionally_signature() {
2306 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2307 assert!(code.contains("entries: Vec<InsertUsersParams>"), "Expected Vec<InsertUsersParams> param:\n{}", code);
2308 assert!(code.contains("Result<Vec<Users>"), "Expected Result<Vec<Users>> return type:\n{}", code);
2309 }
2310
2311 #[test]
2312 fn test_insert_many_transactionally_no_strategy_enum() {
2313 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2314 assert!(!code.contains("TransactionStrategy"), "TransactionStrategy should not be generated:\n{}", code);
2315 assert!(!code.contains("InsertManyUsersResult"), "InsertManyUsersResult should not be generated:\n{}", code);
2316 }
2317
2318 #[test]
2319 fn test_insert_many_transactionally_uses_transaction_pg() {
2320 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2321 let method_start = code.find("fn insert_many_transactionally").expect("insert_many_transactionally not found");
2322 let method_body = &code[method_start..];
2323 assert!(method_body.contains("self.pool.begin()"), "Expected begin():\n{}", method_body);
2324 assert!(method_body.contains("tx.commit()"), "Expected commit():\n{}", method_body);
2325 }
2326
2327 #[test]
2328 fn test_insert_many_transactionally_multi_row_pg() {
2329 let code = gen(&standard_entity(), DatabaseKind::Postgres);
2330 let method_start = code.find("fn insert_many_transactionally").expect("not found");
2331 let method_body = &code[method_start..];
2332 assert!(method_body.contains("RETURNING *"), "Expected RETURNING * in multi-row SQL:\n{}", method_body);
2333 assert!(method_body.contains("values_parts"), "Expected multi-row VALUES building:\n{}", method_body);
2334 assert!(method_body.contains("65535"), "Expected chunk size limit:\n{}", method_body);
2335 }
2336
2337 #[test]
2338 fn test_insert_many_transactionally_multi_row_sqlite() {
2339 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
2340 let method_start = code.find("fn insert_many_transactionally").expect("not found");
2341 let method_body = &code[method_start..];
2342 assert!(method_body.contains("values_parts"), "Expected multi-row VALUES building for SQLite:\n{}", method_body);
2343 assert!(method_body.contains("RETURNING *"), "Expected RETURNING * for SQLite:\n{}", method_body);
2344 }
2345
2346 #[test]
2347 fn test_insert_many_transactionally_mysql_individual_inserts() {
2348 let code = gen(&standard_entity(), DatabaseKind::Mysql);
2349 let method_start = code.find("fn insert_many_transactionally").expect("not found");
2350 let method_body = &code[method_start..];
2351 assert!(method_body.contains("LAST_INSERT_ID"), "Expected LAST_INSERT_ID for MySQL:\n{}", method_body);
2352 assert!(method_body.contains("self.pool.begin()"), "Expected begin() for MySQL:\n{}", method_body);
2353 }
2354
2355 #[test]
2356 fn test_insert_many_transactionally_view_not_generated() {
2357 let mut entity = standard_entity();
2358 entity.is_view = true;
2359 let code = gen(&entity, DatabaseKind::Postgres);
2360 assert!(!code.contains("pub async fn insert_many_transactionally"), "should not be generated for views");
2361 }
2362
2363 #[test]
2364 fn test_insert_many_transactionally_without_method_not_generated() {
2365 let m = Methods { insert_many: false, ..Methods::all() };
2366 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2367 assert!(!code.contains("pub async fn insert_many_transactionally"), "should not be generated when disabled");
2368 }
2369
2370 #[test]
2371 fn test_insert_many_transactionally_generates_params_when_insert_disabled() {
2372 let m = Methods { insert: false, insert_many: true, ..Default::default() };
2373 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
2374 assert!(code.contains("pub struct InsertUsersParams"), "Expected InsertUsersParams:\n{}", code);
2375 assert!(code.contains("pub async fn insert_many_transactionally"), "Expected method:\n{}", code);
2376 assert!(!code.contains("pub async fn insert("), "insert should not be present:\n{}", code);
2377 }
2378
2379 #[test]
2380 fn test_insert_many_transactionally_with_column_defaults_coalesce() {
2381 let code = gen(&entity_with_defaults(), DatabaseKind::Postgres);
2382 let method_start = code.find("fn insert_many_transactionally").expect("not found");
2383 let method_body = &code[method_start..];
2384 assert!(method_body.contains("COALESCE"), "Expected COALESCE for fields with defaults:\n{}", method_body);
2385 }
2386
2387 #[test]
2388 fn test_insert_many_transactionally_junction_table() {
2389 let code = gen(&junction_entity(), DatabaseKind::Postgres);
2390 assert!(code.contains("pub async fn insert_many_transactionally"), "Expected method for junction table:\n{}", code);
2391 }
2392
2393 #[test]
2394 fn test_insert_many_transactionally_all_three_backends_compile() {
2395 for db in [DatabaseKind::Postgres, DatabaseKind::Mysql, DatabaseKind::Sqlite] {
2396 let code = gen(&standard_entity(), db);
2397 assert!(code.contains("pub async fn insert_many_transactionally"), "Expected method for {:?}", db);
2398 }
2399 }
2400}