1use std::collections::BTreeSet;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, Methods, PoolVisibility};
7use crate::codegen::entity_parser::{ParsedEntity, ParsedField};
8
9pub fn generate_crud_from_parsed(
10 entity: &ParsedEntity,
11 db_kind: DatabaseKind,
12 entity_module_path: &str,
13 methods: &Methods,
14 query_macro: bool,
15 pool_visibility: PoolVisibility,
16) -> (TokenStream, BTreeSet<String>) {
17 let mut imports = BTreeSet::new();
18
19 let entity_ident = format_ident!("{}", entity.struct_name);
20 let repo_name = format!("{}Repository", entity.struct_name);
21 let repo_ident = format_ident!("{}", repo_name);
22
23 let table_name = match &entity.schema_name {
24 Some(schema) => format!("{}.{}", schema, entity.table_name),
25 None => entity.table_name.clone(),
26 };
27
28 let pool_type = pool_type_tokens(db_kind);
30
31 let has_custom_sql_type = entity.fields.iter().any(|f| f.sql_type.is_some());
35 let use_macro = query_macro && !has_custom_sql_type && !entity.is_view;
36
37 imports.insert(format!("use {}::{};", entity_module_path, entity.struct_name));
39
40 let entity_parent = entity_module_path
44 .rsplit_once("::")
45 .map(|(parent, _)| parent)
46 .unwrap_or(entity_module_path);
47 for imp in &entity.imports {
48 if let Some(rest) = imp.strip_prefix("use super::") {
49 imports.insert(format!("use {}::{}", entity_parent, rest));
50 } else {
51 imports.insert(imp.clone());
52 }
53 }
54
55 let pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| f.is_primary_key).collect();
57
58 let non_pk_fields: Vec<&ParsedField> = entity.fields.iter().filter(|f| !f.is_primary_key).collect();
60
61 let is_view = entity.is_view;
62
63 let mut method_tokens = Vec::new();
65 let mut param_structs = Vec::new();
66
67 if methods.get_all {
69 let sql = format!("SELECT * FROM {}", table_name);
70 let method = if use_macro {
71 quote! {
72 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
73 sqlx::query_as!(#entity_ident, #sql)
74 .fetch_all(&self.pool)
75 .await
76 }
77 }
78 } else {
79 quote! {
80 pub async fn get_all(&self) -> Result<Vec<#entity_ident>, sqlx::Error> {
81 sqlx::query_as::<_, #entity_ident>(#sql)
82 .fetch_all(&self.pool)
83 .await
84 }
85 }
86 };
87 method_tokens.push(method);
88 }
89
90 if methods.paginate {
92 let paginate_params_ident = format_ident!("Paginate{}Params", entity.struct_name);
93 let paginated_ident = format_ident!("Paginated{}", entity.struct_name);
94 let pagination_meta_ident = format_ident!("Pagination{}Meta", entity.struct_name);
95 let count_sql = format!("SELECT COUNT(*) FROM {}", table_name);
96 let sql = match db_kind {
97 DatabaseKind::Postgres => format!("SELECT * FROM {} LIMIT $1 OFFSET $2", table_name),
98 DatabaseKind::Mysql | DatabaseKind::Sqlite => format!("SELECT * FROM {} LIMIT ? OFFSET ?", table_name),
99 };
100 let method = if use_macro {
101 quote! {
102 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
103 let total: i64 = sqlx::query_scalar!(#count_sql)
104 .fetch_one(&self.pool)
105 .await?
106 .unwrap_or(0);
107 let per_page = params.per_page;
108 let current_page = params.page;
109 let last_page = (total + per_page - 1) / per_page;
110 let offset = (current_page - 1) * per_page;
111 let data = sqlx::query_as!(#entity_ident, #sql, per_page, offset)
112 .fetch_all(&self.pool)
113 .await?;
114 Ok(#paginated_ident {
115 meta: #pagination_meta_ident {
116 total,
117 per_page,
118 current_page,
119 last_page,
120 first_page: 1,
121 },
122 data,
123 })
124 }
125 }
126 } else {
127 quote! {
128 pub async fn paginate(&self, params: &#paginate_params_ident) -> Result<#paginated_ident, sqlx::Error> {
129 let total: i64 = sqlx::query_scalar(#count_sql)
130 .fetch_one(&self.pool)
131 .await?;
132 let per_page = params.per_page;
133 let current_page = params.page;
134 let last_page = (total + per_page - 1) / per_page;
135 let offset = (current_page - 1) * per_page;
136 let data = sqlx::query_as::<_, #entity_ident>(#sql)
137 .bind(per_page)
138 .bind(offset)
139 .fetch_all(&self.pool)
140 .await?;
141 Ok(#paginated_ident {
142 meta: #pagination_meta_ident {
143 total,
144 per_page,
145 current_page,
146 last_page,
147 first_page: 1,
148 },
149 data,
150 })
151 }
152 }
153 };
154 method_tokens.push(method);
155 param_structs.push(quote! {
156 #[derive(Debug, Clone, Default)]
157 pub struct #paginate_params_ident {
158 pub page: i64,
159 pub per_page: i64,
160 }
161 });
162 param_structs.push(quote! {
163 #[derive(Debug, Clone)]
164 pub struct #pagination_meta_ident {
165 pub total: i64,
166 pub per_page: i64,
167 pub current_page: i64,
168 pub last_page: i64,
169 pub first_page: i64,
170 }
171 });
172 param_structs.push(quote! {
173 #[derive(Debug, Clone)]
174 pub struct #paginated_ident {
175 pub meta: #pagination_meta_ident,
176 pub data: Vec<#entity_ident>,
177 }
178 });
179 }
180
181 if methods.get && !pk_fields.is_empty() {
183 let pk_params: Vec<TokenStream> = pk_fields
184 .iter()
185 .map(|f| {
186 let name = format_ident!("{}", f.rust_name);
187 let ty: TokenStream = f.inner_type.parse().unwrap();
188 quote! { #name: #ty }
189 })
190 .collect();
191
192 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
193 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
194 let sql = format!("SELECT * FROM {} WHERE {}", table_name, where_clause);
195 let sql_macro = format!("SELECT * FROM {} WHERE {}", table_name, where_clause_cast);
196
197 let binds: Vec<TokenStream> = pk_fields
198 .iter()
199 .map(|f| {
200 let name = format_ident!("{}", f.rust_name);
201 quote! { .bind(#name) }
202 })
203 .collect();
204
205 let method = if use_macro {
206 let pk_arg_names: Vec<TokenStream> = pk_fields
207 .iter()
208 .map(|f| {
209 let name = format_ident!("{}", f.rust_name);
210 quote! { #name }
211 })
212 .collect();
213 quote! {
214 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
215 sqlx::query_as!(#entity_ident, #sql_macro, #(#pk_arg_names),*)
216 .fetch_optional(&self.pool)
217 .await
218 }
219 }
220 } else {
221 quote! {
222 pub async fn get(&self, #(#pk_params),*) -> Result<Option<#entity_ident>, sqlx::Error> {
223 sqlx::query_as::<_, #entity_ident>(#sql)
224 #(#binds)*
225 .fetch_optional(&self.pool)
226 .await
227 }
228 }
229 };
230 method_tokens.push(method);
231 }
232
233 if !is_view && methods.insert && (!non_pk_fields.is_empty() || !pk_fields.is_empty()) {
235 let insert_params_ident = format_ident!("Insert{}Params", entity.struct_name);
236
237 let insert_source_fields: Vec<&ParsedField> = if non_pk_fields.is_empty() {
239 pk_fields.clone()
240 } else {
241 non_pk_fields.clone()
242 };
243
244 let insert_fields: Vec<TokenStream> = insert_source_fields
245 .iter()
246 .map(|f| {
247 let name = format_ident!("{}", f.rust_name);
248 let ty: TokenStream = f.rust_type.parse().unwrap();
249 quote! { pub #name: #ty, }
250 })
251 .collect();
252
253 let col_names: Vec<&str> = insert_source_fields.iter().map(|f| f.column_name.as_str()).collect();
254 let col_list = col_names.join(", ");
255 let placeholders = build_placeholders(insert_source_fields.len(), db_kind, 1);
257 let placeholders_cast = build_placeholders_with_cast(&insert_source_fields, db_kind, 1, true);
258
259 let build_insert_sql = |ph: &str| match db_kind {
260 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
261 format!(
262 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
263 table_name, col_list, ph
264 )
265 }
266 DatabaseKind::Mysql => {
267 format!(
268 "INSERT INTO {} ({}) VALUES ({})",
269 table_name, col_list, ph
270 )
271 }
272 };
273 let sql = build_insert_sql(&placeholders);
274 let sql_macro = build_insert_sql(&placeholders_cast);
275
276 let binds: Vec<TokenStream> = insert_source_fields
277 .iter()
278 .map(|f| {
279 let name = format_ident!("{}", f.rust_name);
280 quote! { .bind(¶ms.#name) }
281 })
282 .collect();
283
284 let insert_method = build_insert_method_parsed(
285 &entity_ident,
286 &insert_params_ident,
287 &sql,
288 &sql_macro,
289 &binds,
290 db_kind,
291 &table_name,
292 &pk_fields,
293 &insert_source_fields,
294 use_macro,
295 );
296 method_tokens.push(insert_method);
297
298 param_structs.push(quote! {
299 #[derive(Debug, Clone, Default)]
300 pub struct #insert_params_ident {
301 #(#insert_fields)*
302 }
303 });
304 }
305
306 if !is_view && methods.update && !pk_fields.is_empty() && !non_pk_fields.is_empty() {
308 let update_params_ident = format_ident!("Update{}Params", entity.struct_name);
309
310 let update_fields: Vec<TokenStream> = entity
311 .fields
312 .iter()
313 .map(|f| {
314 let name = format_ident!("{}", f.rust_name);
315 let ty: TokenStream = f.rust_type.parse().unwrap();
316 quote! { pub #name: #ty, }
317 })
318 .collect();
319
320 let set_cols: Vec<String> = non_pk_fields
321 .iter()
322 .enumerate()
323 .map(|(i, f)| {
324 let p = placeholder(db_kind, i + 1);
325 format!("{} = {}", f.column_name, p)
326 })
327 .collect();
328 let set_clause = set_cols.join(", ");
329
330 let set_cols_cast: Vec<String> = non_pk_fields
332 .iter()
333 .enumerate()
334 .map(|(i, f)| {
335 let p = placeholder_with_cast(db_kind, i + 1, f);
336 format!("{} = {}", f.column_name, p)
337 })
338 .collect();
339 let set_clause_cast = set_cols_cast.join(", ");
340
341 let pk_start = non_pk_fields.len() + 1;
342 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, pk_start);
343 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, pk_start);
344
345 let build_update_sql = |sc: &str, wc: &str| match db_kind {
346 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
347 format!(
348 "UPDATE {} SET {} WHERE {} RETURNING *",
349 table_name, sc, wc
350 )
351 }
352 DatabaseKind::Mysql => {
353 format!(
354 "UPDATE {} SET {} WHERE {}",
355 table_name, sc, wc
356 )
357 }
358 };
359 let sql = build_update_sql(&set_clause, &where_clause);
360 let sql_macro = build_update_sql(&set_clause_cast, &where_clause_cast);
361
362 let mut all_binds: Vec<TokenStream> = non_pk_fields
364 .iter()
365 .map(|f| {
366 let name = format_ident!("{}", f.rust_name);
367 quote! { .bind(¶ms.#name) }
368 })
369 .collect();
370 for f in &pk_fields {
371 let name = format_ident!("{}", f.rust_name);
372 all_binds.push(quote! { .bind(¶ms.#name) });
373 }
374
375 let update_macro_args: Vec<TokenStream> = non_pk_fields
377 .iter()
378 .chain(pk_fields.iter())
379 .map(|f| macro_arg_for_field(f))
380 .collect();
381
382 let update_method = if use_macro {
383 match db_kind {
384 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
385 quote! {
386 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
387 sqlx::query_as!(#entity_ident, #sql_macro, #(#update_macro_args),*)
388 .fetch_one(&self.pool)
389 .await
390 }
391 }
392 }
393 DatabaseKind::Mysql => {
394 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
395 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
396 let pk_macro_args: Vec<TokenStream> = pk_fields
397 .iter()
398 .map(|f| {
399 let name = format_ident!("{}", f.rust_name);
400 quote! { params.#name }
401 })
402 .collect();
403 quote! {
404 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
405 sqlx::query!(#sql_macro, #(#update_macro_args),*)
406 .execute(&self.pool)
407 .await?;
408 sqlx::query_as!(#entity_ident, #select_sql, #(#pk_macro_args),*)
409 .fetch_one(&self.pool)
410 .await
411 }
412 }
413 }
414 }
415 } else {
416 match db_kind {
417 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
418 quote! {
419 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
420 sqlx::query_as::<_, #entity_ident>(#sql)
421 #(#all_binds)*
422 .fetch_one(&self.pool)
423 .await
424 }
425 }
426 }
427 DatabaseKind::Mysql => {
428 let pk_where_select = build_where_clause_parsed(&pk_fields, db_kind, 1);
429 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where_select);
430 let pk_binds: Vec<TokenStream> = pk_fields
431 .iter()
432 .map(|f| {
433 let name = format_ident!("{}", f.rust_name);
434 quote! { .bind(¶ms.#name) }
435 })
436 .collect();
437 quote! {
438 pub async fn update(&self, params: &#update_params_ident) -> Result<#entity_ident, sqlx::Error> {
439 sqlx::query(#sql)
440 #(#all_binds)*
441 .execute(&self.pool)
442 .await?;
443 sqlx::query_as::<_, #entity_ident>(#select_sql)
444 #(#pk_binds)*
445 .fetch_one(&self.pool)
446 .await
447 }
448 }
449 }
450 }
451 };
452 method_tokens.push(update_method);
453
454 param_structs.push(quote! {
455 #[derive(Debug, Clone, Default)]
456 pub struct #update_params_ident {
457 #(#update_fields)*
458 }
459 });
460 }
461
462 if !is_view && methods.delete && !pk_fields.is_empty() {
464 let pk_params: Vec<TokenStream> = pk_fields
465 .iter()
466 .map(|f| {
467 let name = format_ident!("{}", f.rust_name);
468 let ty: TokenStream = f.inner_type.parse().unwrap();
469 quote! { #name: #ty }
470 })
471 .collect();
472
473 let where_clause = build_where_clause_parsed(&pk_fields, db_kind, 1);
474 let where_clause_cast = build_where_clause_cast(&pk_fields, db_kind, 1);
475 let sql = format!("DELETE FROM {} WHERE {}", table_name, where_clause);
476 let sql_macro = format!("DELETE FROM {} WHERE {}", table_name, where_clause_cast);
477
478 let binds: Vec<TokenStream> = pk_fields
479 .iter()
480 .map(|f| {
481 let name = format_ident!("{}", f.rust_name);
482 quote! { .bind(#name) }
483 })
484 .collect();
485
486 let method = if query_macro {
487 let pk_arg_names: Vec<TokenStream> = pk_fields
488 .iter()
489 .map(|f| {
490 let name = format_ident!("{}", f.rust_name);
491 quote! { #name }
492 })
493 .collect();
494 quote! {
495 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
496 sqlx::query!(#sql_macro, #(#pk_arg_names),*)
497 .execute(&self.pool)
498 .await?;
499 Ok(())
500 }
501 }
502 } else {
503 quote! {
504 pub async fn delete(&self, #(#pk_params),*) -> Result<(), sqlx::Error> {
505 sqlx::query(#sql)
506 #(#binds)*
507 .execute(&self.pool)
508 .await?;
509 Ok(())
510 }
511 }
512 };
513 method_tokens.push(method);
514 }
515
516 let pool_vis: TokenStream = match pool_visibility {
517 PoolVisibility::Private => quote! {},
518 PoolVisibility::Pub => quote! { pub },
519 PoolVisibility::PubCrate => quote! { pub(crate) },
520 };
521
522 let tokens = quote! {
523 #(#param_structs)*
524
525 pub struct #repo_ident {
526 #pool_vis pool: #pool_type,
527 }
528
529 impl #repo_ident {
530 pub fn new(pool: #pool_type) -> Self {
531 Self { pool }
532 }
533
534 #(#method_tokens)*
535 }
536 };
537
538 (tokens, imports)
539}
540
541fn pool_type_tokens(db_kind: DatabaseKind) -> TokenStream {
542 match db_kind {
543 DatabaseKind::Postgres => quote! { sqlx::PgPool },
544 DatabaseKind::Mysql => quote! { sqlx::MySqlPool },
545 DatabaseKind::Sqlite => quote! { sqlx::SqlitePool },
546 }
547}
548
549fn placeholder(db_kind: DatabaseKind, index: usize) -> String {
550 match db_kind {
551 DatabaseKind::Postgres => format!("${}", index),
552 DatabaseKind::Mysql | DatabaseKind::Sqlite => "?".to_string(),
553 }
554}
555
556fn placeholder_with_cast(db_kind: DatabaseKind, index: usize, field: &ParsedField) -> String {
557 let base = placeholder(db_kind, index);
558 match (&field.sql_type, field.is_sql_array) {
559 (Some(t), true) => format!("{} as {}[]", base, t),
560 (Some(t), false) => format!("{} as {}", base, t),
561 (None, _) => base,
562 }
563}
564
565fn build_placeholders(count: usize, db_kind: DatabaseKind, start: usize) -> String {
566 (0..count)
567 .map(|i| placeholder(db_kind, start + i))
568 .collect::<Vec<_>>()
569 .join(", ")
570}
571
572fn build_placeholders_with_cast(fields: &[&ParsedField], db_kind: DatabaseKind, start: usize, use_cast: bool) -> String {
573 fields
574 .iter()
575 .enumerate()
576 .map(|(i, f)| {
577 if use_cast {
578 placeholder_with_cast(db_kind, start + i, f)
579 } else {
580 placeholder(db_kind, start + i)
581 }
582 })
583 .collect::<Vec<_>>()
584 .join(", ")
585}
586
587fn build_where_clause_parsed(
588 pk_fields: &[&ParsedField],
589 db_kind: DatabaseKind,
590 start_index: usize,
591) -> String {
592 pk_fields
593 .iter()
594 .enumerate()
595 .map(|(i, f)| {
596 let p = placeholder(db_kind, start_index + i);
597 format!("{} = {}", f.column_name, p)
598 })
599 .collect::<Vec<_>>()
600 .join(" AND ")
601}
602
603fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
604 let name = format_ident!("{}", field.rust_name);
605 let check_type = if field.is_nullable {
606 &field.inner_type
607 } else {
608 &field.rust_type
609 };
610 let normalized = check_type.replace(' ', "");
611 if normalized.starts_with("Vec<") {
612 quote! { params.#name.as_slice() }
613 } else {
614 quote! { params.#name }
615 }
616}
617
618fn build_where_clause_cast(
619 pk_fields: &[&ParsedField],
620 db_kind: DatabaseKind,
621 start_index: usize,
622) -> String {
623 pk_fields
624 .iter()
625 .enumerate()
626 .map(|(i, f)| {
627 let p = placeholder_with_cast(db_kind, start_index + i, f);
628 format!("{} = {}", f.column_name, p)
629 })
630 .collect::<Vec<_>>()
631 .join(" AND ")
632}
633
634#[allow(clippy::too_many_arguments)]
635fn build_insert_method_parsed(
636 entity_ident: &proc_macro2::Ident,
637 insert_params_ident: &proc_macro2::Ident,
638 sql: &str,
639 sql_macro: &str,
640 binds: &[TokenStream],
641 db_kind: DatabaseKind,
642 table_name: &str,
643 pk_fields: &[&ParsedField],
644 non_pk_fields: &[&ParsedField],
645 use_macro: bool,
646) -> TokenStream {
647 if use_macro {
648 let macro_args: Vec<TokenStream> = non_pk_fields
649 .iter()
650 .map(|f| macro_arg_for_field(f))
651 .collect();
652
653 match db_kind {
654 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
655 quote! {
656 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
657 sqlx::query_as!(#entity_ident, #sql_macro, #(#macro_args),*)
658 .fetch_one(&self.pool)
659 .await
660 }
661 }
662 }
663 DatabaseKind::Mysql => {
664 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
665 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
666 quote! {
667 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
668 sqlx::query!(#sql_macro, #(#macro_args),*)
669 .execute(&self.pool)
670 .await?;
671 let id = sqlx::query_scalar!("SELECT LAST_INSERT_ID() as id")
672 .fetch_one(&self.pool)
673 .await?;
674 sqlx::query_as!(#entity_ident, #select_sql, id)
675 .fetch_one(&self.pool)
676 .await
677 }
678 }
679 }
680 }
681 } else {
682 match db_kind {
683 DatabaseKind::Postgres | DatabaseKind::Sqlite => {
684 quote! {
685 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
686 sqlx::query_as::<_, #entity_ident>(#sql)
687 #(#binds)*
688 .fetch_one(&self.pool)
689 .await
690 }
691 }
692 }
693 DatabaseKind::Mysql => {
694 let pk_where = build_where_clause_parsed(pk_fields, db_kind, 1);
695 let select_sql = format!("SELECT * FROM {} WHERE {}", table_name, pk_where);
696 quote! {
697 pub async fn insert(&self, params: &#insert_params_ident) -> Result<#entity_ident, sqlx::Error> {
698 sqlx::query(#sql)
699 #(#binds)*
700 .execute(&self.pool)
701 .await?;
702 let id = sqlx::query_scalar::<_, i64>("SELECT LAST_INSERT_ID()")
703 .fetch_one(&self.pool)
704 .await?;
705 sqlx::query_as::<_, #entity_ident>(#select_sql)
706 .bind(id)
707 .fetch_one(&self.pool)
708 .await
709 }
710 }
711 }
712 }
713 }
714}
715
716#[cfg(test)]
717mod tests {
718 use super::*;
719 use crate::codegen::parse_and_format;
720 use crate::cli::Methods;
721
722 fn make_field(rust_name: &str, column_name: &str, rust_type: &str, nullable: bool, is_pk: bool) -> ParsedField {
723 let inner_type = if nullable {
724 rust_type
726 .strip_prefix("Option<")
727 .and_then(|s| s.strip_suffix('>'))
728 .unwrap_or(rust_type)
729 .to_string()
730 } else {
731 rust_type.to_string()
732 };
733 ParsedField {
734 rust_name: rust_name.to_string(),
735 column_name: column_name.to_string(),
736 rust_type: rust_type.to_string(),
737 is_nullable: nullable,
738 inner_type,
739 is_primary_key: is_pk,
740 sql_type: None,
741 is_sql_array: false,
742 }
743 }
744
745 fn standard_entity() -> ParsedEntity {
746 ParsedEntity {
747 struct_name: "Users".to_string(),
748 table_name: "users".to_string(),
749 schema_name: None,
750 is_view: false,
751 fields: vec![
752 make_field("id", "id", "i32", false, true),
753 make_field("name", "name", "String", false, false),
754 make_field("email", "email", "Option<String>", true, false),
755 ],
756 imports: vec![],
757 }
758 }
759
760 fn gen(entity: &ParsedEntity, db: DatabaseKind) -> String {
761 let skip = Methods::all();
762 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, false, PoolVisibility::Private);
763 parse_and_format(&tokens)
764 }
765
766 fn gen_macro(entity: &ParsedEntity, db: DatabaseKind) -> String {
767 let skip = Methods::all();
768 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", &skip, true, PoolVisibility::Private);
769 parse_and_format(&tokens)
770 }
771
772 fn gen_with_methods(entity: &ParsedEntity, db: DatabaseKind, methods: &Methods) -> String {
773 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::users", methods, false, PoolVisibility::Private);
774 parse_and_format(&tokens)
775 }
776
777 #[test]
780 fn test_repo_struct_name() {
781 let code = gen(&standard_entity(), DatabaseKind::Postgres);
782 assert!(code.contains("pub struct UsersRepository"));
783 }
784
785 #[test]
786 fn test_repo_new_method() {
787 let code = gen(&standard_entity(), DatabaseKind::Postgres);
788 assert!(code.contains("pub fn new("));
789 }
790
791 #[test]
792 fn test_repo_pool_field_pg() {
793 let code = gen(&standard_entity(), DatabaseKind::Postgres);
794 assert!(code.contains("pool: sqlx::PgPool") || code.contains("pool: sqlx :: PgPool"));
795 }
796
797 #[test]
798 fn test_repo_pool_field_pub() {
799 let skip = Methods::all();
800 let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Pub);
801 let code = parse_and_format(&tokens);
802 assert!(code.contains("pub pool: sqlx::PgPool") || code.contains("pub pool: sqlx :: PgPool"));
803 }
804
805 #[test]
806 fn test_repo_pool_field_pub_crate() {
807 let skip = Methods::all();
808 let (tokens, _) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::PubCrate);
809 let code = parse_and_format(&tokens);
810 assert!(code.contains("pub(crate) pool: sqlx::PgPool") || code.contains("pub(crate) pool: sqlx :: PgPool"));
811 }
812
813 #[test]
814 fn test_repo_pool_field_private() {
815 let code = gen(&standard_entity(), DatabaseKind::Postgres);
816 assert!(!code.contains("pub pool"));
818 assert!(!code.contains("pub(crate) pool"));
819 }
820
821 #[test]
822 fn test_repo_pool_field_mysql() {
823 let code = gen(&standard_entity(), DatabaseKind::Mysql);
824 assert!(code.contains("MySqlPool") || code.contains("MySql"));
825 }
826
827 #[test]
828 fn test_repo_pool_field_sqlite() {
829 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
830 assert!(code.contains("SqlitePool") || code.contains("Sqlite"));
831 }
832
833 #[test]
836 fn test_get_all_method() {
837 let code = gen(&standard_entity(), DatabaseKind::Postgres);
838 assert!(code.contains("pub async fn get_all"));
839 }
840
841 #[test]
842 fn test_get_all_returns_vec() {
843 let code = gen(&standard_entity(), DatabaseKind::Postgres);
844 assert!(code.contains("Vec<Users>"));
845 }
846
847 #[test]
848 fn test_get_all_sql() {
849 let code = gen(&standard_entity(), DatabaseKind::Postgres);
850 assert!(code.contains("SELECT * FROM users"));
851 }
852
853 #[test]
856 fn test_paginate_method() {
857 let code = gen(&standard_entity(), DatabaseKind::Postgres);
858 assert!(code.contains("pub async fn paginate"));
859 }
860
861 #[test]
862 fn test_paginate_params_struct() {
863 let code = gen(&standard_entity(), DatabaseKind::Postgres);
864 assert!(code.contains("pub struct PaginateUsersParams"));
865 }
866
867 #[test]
868 fn test_paginate_params_fields() {
869 let code = gen(&standard_entity(), DatabaseKind::Postgres);
870 assert!(code.contains("pub page: i64"));
871 assert!(code.contains("pub per_page: i64"));
872 }
873
874 #[test]
875 fn test_paginate_returns_paginated() {
876 let code = gen(&standard_entity(), DatabaseKind::Postgres);
877 assert!(code.contains("PaginatedUsers"));
878 assert!(code.contains("PaginationUsersMeta"));
879 }
880
881 #[test]
882 fn test_paginate_meta_struct() {
883 let code = gen(&standard_entity(), DatabaseKind::Postgres);
884 assert!(code.contains("pub struct PaginationUsersMeta"));
885 assert!(code.contains("pub total: i64"));
886 assert!(code.contains("pub last_page: i64"));
887 assert!(code.contains("pub first_page: i64"));
888 assert!(code.contains("pub current_page: i64"));
889 }
890
891 #[test]
892 fn test_paginate_data_struct() {
893 let code = gen(&standard_entity(), DatabaseKind::Postgres);
894 assert!(code.contains("pub struct PaginatedUsers"));
895 assert!(code.contains("pub meta: PaginationUsersMeta"));
896 assert!(code.contains("pub data: Vec<Users>"));
897 }
898
899 #[test]
900 fn test_paginate_count_sql() {
901 let code = gen(&standard_entity(), DatabaseKind::Postgres);
902 assert!(code.contains("SELECT COUNT(*) FROM users"));
903 }
904
905 #[test]
906 fn test_paginate_sql_pg() {
907 let code = gen(&standard_entity(), DatabaseKind::Postgres);
908 assert!(code.contains("LIMIT $1 OFFSET $2"));
909 }
910
911 #[test]
912 fn test_paginate_sql_mysql() {
913 let code = gen(&standard_entity(), DatabaseKind::Mysql);
914 assert!(code.contains("LIMIT ? OFFSET ?"));
915 }
916
917 #[test]
920 fn test_get_method() {
921 let code = gen(&standard_entity(), DatabaseKind::Postgres);
922 assert!(code.contains("pub async fn get"));
923 }
924
925 #[test]
926 fn test_get_returns_option() {
927 let code = gen(&standard_entity(), DatabaseKind::Postgres);
928 assert!(code.contains("Option<Users>"));
929 }
930
931 #[test]
932 fn test_get_where_pk_pg() {
933 let code = gen(&standard_entity(), DatabaseKind::Postgres);
934 assert!(code.contains("WHERE id = $1"));
935 }
936
937 #[test]
938 fn test_get_where_pk_mysql() {
939 let code = gen(&standard_entity(), DatabaseKind::Mysql);
940 assert!(code.contains("WHERE id = ?"));
941 }
942
943 #[test]
946 fn test_insert_method() {
947 let code = gen(&standard_entity(), DatabaseKind::Postgres);
948 assert!(code.contains("pub async fn insert"));
949 }
950
951 #[test]
952 fn test_insert_params_struct() {
953 let code = gen(&standard_entity(), DatabaseKind::Postgres);
954 assert!(code.contains("pub struct InsertUsersParams"));
955 }
956
957 #[test]
958 fn test_insert_params_no_pk() {
959 let code = gen(&standard_entity(), DatabaseKind::Postgres);
960 assert!(code.contains("pub name: String"));
961 assert!(code.contains("pub email: Option<String>") || code.contains("pub email: Option < String >"));
962 }
963
964 #[test]
965 fn test_insert_returning_pg() {
966 let code = gen(&standard_entity(), DatabaseKind::Postgres);
967 assert!(code.contains("RETURNING *"));
968 }
969
970 #[test]
971 fn test_insert_returning_sqlite() {
972 let code = gen(&standard_entity(), DatabaseKind::Sqlite);
973 assert!(code.contains("RETURNING *"));
974 }
975
976 #[test]
977 fn test_insert_mysql_last_insert_id() {
978 let code = gen(&standard_entity(), DatabaseKind::Mysql);
979 assert!(code.contains("LAST_INSERT_ID"));
980 }
981
982 #[test]
985 fn test_update_method() {
986 let code = gen(&standard_entity(), DatabaseKind::Postgres);
987 assert!(code.contains("pub async fn update"));
988 }
989
990 #[test]
991 fn test_update_params_struct() {
992 let code = gen(&standard_entity(), DatabaseKind::Postgres);
993 assert!(code.contains("pub struct UpdateUsersParams"));
994 }
995
996 #[test]
997 fn test_update_params_all_cols() {
998 let code = gen(&standard_entity(), DatabaseKind::Postgres);
999 assert!(code.contains("pub id: i32"));
1000 assert!(code.contains("pub name: String"));
1001 }
1002
1003 #[test]
1004 fn test_update_set_clause_pg() {
1005 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1006 assert!(code.contains("SET name = $1"));
1007 assert!(code.contains("WHERE id = $3"));
1008 }
1009
1010 #[test]
1011 fn test_update_returning_pg() {
1012 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1013 assert!(code.contains("UPDATE users SET"));
1014 assert!(code.contains("RETURNING *"));
1015 }
1016
1017 #[test]
1020 fn test_delete_method() {
1021 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1022 assert!(code.contains("pub async fn delete"));
1023 }
1024
1025 #[test]
1026 fn test_delete_where_pk() {
1027 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1028 assert!(code.contains("DELETE FROM users WHERE id = $1"));
1029 }
1030
1031 #[test]
1032 fn test_delete_returns_unit() {
1033 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1034 assert!(code.contains("Result<(), sqlx::Error>") || code.contains("Result<(), sqlx :: Error>"));
1035 }
1036
1037 #[test]
1040 fn test_view_no_insert() {
1041 let mut entity = standard_entity();
1042 entity.is_view = true;
1043 let code = gen(&entity, DatabaseKind::Postgres);
1044 assert!(!code.contains("pub async fn insert"));
1045 }
1046
1047 #[test]
1048 fn test_view_no_update() {
1049 let mut entity = standard_entity();
1050 entity.is_view = true;
1051 let code = gen(&entity, DatabaseKind::Postgres);
1052 assert!(!code.contains("pub async fn update"));
1053 }
1054
1055 #[test]
1056 fn test_view_no_delete() {
1057 let mut entity = standard_entity();
1058 entity.is_view = true;
1059 let code = gen(&entity, DatabaseKind::Postgres);
1060 assert!(!code.contains("pub async fn delete"));
1061 }
1062
1063 #[test]
1064 fn test_view_has_get_all() {
1065 let mut entity = standard_entity();
1066 entity.is_view = true;
1067 let code = gen(&entity, DatabaseKind::Postgres);
1068 assert!(code.contains("pub async fn get_all"));
1069 }
1070
1071 #[test]
1072 fn test_view_has_paginate() {
1073 let mut entity = standard_entity();
1074 entity.is_view = true;
1075 let code = gen(&entity, DatabaseKind::Postgres);
1076 assert!(code.contains("pub async fn paginate"));
1077 }
1078
1079 #[test]
1080 fn test_view_has_get() {
1081 let mut entity = standard_entity();
1082 entity.is_view = true;
1083 let code = gen(&entity, DatabaseKind::Postgres);
1084 assert!(code.contains("pub async fn get"));
1085 }
1086
1087 #[test]
1090 fn test_only_get_all() {
1091 let m = Methods { get_all: true, ..Default::default() };
1092 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1093 assert!(code.contains("pub async fn get_all"));
1094 assert!(!code.contains("pub async fn paginate"));
1095 assert!(!code.contains("pub async fn insert"));
1096 }
1097
1098 #[test]
1099 fn test_without_get_all() {
1100 let m = Methods { get_all: false, ..Methods::all() };
1101 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1102 assert!(!code.contains("pub async fn get_all"));
1103 }
1104
1105 #[test]
1106 fn test_without_paginate() {
1107 let m = Methods { paginate: false, ..Methods::all() };
1108 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1109 assert!(!code.contains("pub async fn paginate"));
1110 assert!(!code.contains("PaginateUsersParams"));
1111 }
1112
1113 #[test]
1114 fn test_without_get() {
1115 let m = Methods { get: false, ..Methods::all() };
1116 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1117 assert!(code.contains("pub async fn get_all"));
1118 let without_get_all = code.replace("get_all", "XXX");
1119 assert!(!without_get_all.contains("fn get("));
1120 }
1121
1122 #[test]
1123 fn test_without_insert() {
1124 let m = Methods { insert: false, ..Methods::all() };
1125 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1126 assert!(!code.contains("pub async fn insert"));
1127 assert!(!code.contains("InsertUsersParams"));
1128 }
1129
1130 #[test]
1131 fn test_without_update() {
1132 let m = Methods { update: false, ..Methods::all() };
1133 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1134 assert!(!code.contains("pub async fn update"));
1135 assert!(!code.contains("UpdateUsersParams"));
1136 }
1137
1138 #[test]
1139 fn test_without_delete() {
1140 let m = Methods { delete: false, ..Methods::all() };
1141 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1142 assert!(!code.contains("pub async fn delete"));
1143 }
1144
1145 #[test]
1146 fn test_empty_methods_no_methods() {
1147 let m = Methods::default();
1148 let code = gen_with_methods(&standard_entity(), DatabaseKind::Postgres, &m);
1149 assert!(!code.contains("pub async fn get_all"));
1150 assert!(!code.contains("pub async fn paginate"));
1151 assert!(!code.contains("pub async fn insert"));
1152 assert!(!code.contains("pub async fn update"));
1153 assert!(!code.contains("pub async fn delete"));
1154 }
1155
1156 #[test]
1159 fn test_no_pool_import() {
1160 let skip = Methods::all();
1161 let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1162 assert!(!imports.iter().any(|i| i.contains("PgPool")));
1163 }
1164
1165 #[test]
1166 fn test_imports_contain_entity() {
1167 let skip = Methods::all();
1168 let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1169 assert!(imports.iter().any(|i| i.contains("crate::models::users::Users")));
1170 }
1171
1172 #[test]
1175 fn test_renamed_column_in_sql() {
1176 let entity = ParsedEntity {
1177 struct_name: "Connector".to_string(),
1178 table_name: "connector".to_string(),
1179 schema_name: None,
1180 is_view: false,
1181 fields: vec![
1182 make_field("id", "id", "i32", false, true),
1183 make_field("connector_type", "type", "String", false, false),
1184 ],
1185 imports: vec![],
1186 };
1187 let code = gen(&entity, DatabaseKind::Postgres);
1188 assert!(code.contains("type"));
1190 assert!(code.contains("pub connector_type: String"));
1192 }
1193
1194 #[test]
1197 fn test_no_pk_no_get() {
1198 let entity = ParsedEntity {
1199 struct_name: "Logs".to_string(),
1200 table_name: "logs".to_string(),
1201 schema_name: None,
1202 is_view: false,
1203 fields: vec![
1204 make_field("message", "message", "String", false, false),
1205 make_field("ts", "ts", "String", false, false),
1206 ],
1207 imports: vec![],
1208 };
1209 let code = gen(&entity, DatabaseKind::Postgres);
1210 assert!(code.contains("pub async fn get_all"));
1211 let without_get_all = code.replace("get_all", "XXX");
1212 assert!(!without_get_all.contains("fn get("));
1213 }
1214
1215 #[test]
1216 fn test_no_pk_no_delete() {
1217 let entity = ParsedEntity {
1218 struct_name: "Logs".to_string(),
1219 table_name: "logs".to_string(),
1220 schema_name: None,
1221 is_view: false,
1222 fields: vec![
1223 make_field("message", "message", "String", false, false),
1224 ],
1225 imports: vec![],
1226 };
1227 let code = gen(&entity, DatabaseKind::Postgres);
1228 assert!(!code.contains("pub async fn delete"));
1229 }
1230
1231 #[test]
1234 fn test_param_structs_have_default() {
1235 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1236 assert!(code.contains("Default"));
1237 }
1238
1239 #[test]
1242 fn test_entity_imports_forwarded() {
1243 let entity = ParsedEntity {
1244 struct_name: "Users".to_string(),
1245 table_name: "users".to_string(),
1246 schema_name: None,
1247 is_view: false,
1248 fields: vec![
1249 make_field("id", "id", "Uuid", false, true),
1250 make_field("created_at", "created_at", "DateTime<Utc>", false, false),
1251 ],
1252 imports: vec![
1253 "use chrono::{DateTime, Utc};".to_string(),
1254 "use uuid::Uuid;".to_string(),
1255 ],
1256 };
1257 let skip = Methods::all();
1258 let (_, imports) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1259 assert!(imports.iter().any(|i| i.contains("chrono")));
1260 assert!(imports.iter().any(|i| i.contains("uuid")));
1261 }
1262
1263 #[test]
1264 fn test_entity_imports_empty_when_no_imports() {
1265 let skip = Methods::all();
1266 let (_, imports) = generate_crud_from_parsed(&standard_entity(), DatabaseKind::Postgres, "crate::models::users", &skip, false, PoolVisibility::Private);
1267 assert!(!imports.iter().any(|i| i.contains("chrono")));
1269 assert!(!imports.iter().any(|i| i.contains("uuid")));
1270 }
1271
1272 #[test]
1275 fn test_macro_get_all() {
1276 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1277 assert!(code.contains("query_as!"));
1278 assert!(!code.contains("query_as::<"));
1279 }
1280
1281 #[test]
1282 fn test_macro_paginate() {
1283 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1284 assert!(code.contains("query_as!"));
1285 assert!(code.contains("per_page, offset"));
1286 }
1287
1288 #[test]
1289 fn test_macro_get() {
1290 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1291 assert!(code.contains("query_as!(Users"));
1293 }
1294
1295 #[test]
1296 fn test_macro_insert_pg() {
1297 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1298 assert!(code.contains("query_as!(Users"));
1299 assert!(code.contains("params.name"));
1300 assert!(code.contains("params.email"));
1301 }
1302
1303 #[test]
1304 fn test_macro_insert_mysql() {
1305 let code = gen_macro(&standard_entity(), DatabaseKind::Mysql);
1306 assert!(code.contains("query!"));
1308 assert!(code.contains("query_scalar!"));
1309 }
1310
1311 #[test]
1312 fn test_macro_update() {
1313 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1314 assert!(code.contains("query_as!(Users"));
1315 assert!(code.contains("params.name"));
1317 assert!(code.contains("params.id"));
1318 }
1319
1320 #[test]
1321 fn test_macro_delete() {
1322 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1323 assert!(code.contains("query!"));
1325 }
1326
1327 #[test]
1328 fn test_macro_no_bind_calls() {
1329 let code = gen_macro(&standard_entity(), DatabaseKind::Postgres);
1330 assert!(!code.contains(".bind("));
1331 }
1332
1333 #[test]
1334 fn test_function_style_uses_bind() {
1335 let code = gen(&standard_entity(), DatabaseKind::Postgres);
1336 assert!(code.contains(".bind("));
1337 assert!(!code.contains("query_as!("));
1338 assert!(!code.contains("query!("));
1339 }
1340
1341 fn entity_with_sql_array() -> ParsedEntity {
1344 ParsedEntity {
1345 struct_name: "AgentConnector".to_string(),
1346 table_name: "agent.agent_connector".to_string(),
1347 schema_name: Some("agent".to_string()),
1348 is_view: false,
1349 fields: vec![
1350 ParsedField {
1351 rust_name: "connector_id".to_string(),
1352 column_name: "connector_id".to_string(),
1353 rust_type: "Uuid".to_string(),
1354 inner_type: "Uuid".to_string(),
1355 is_nullable: false,
1356 is_primary_key: true,
1357 sql_type: None,
1358 is_sql_array: false,
1359 },
1360 ParsedField {
1361 rust_name: "agent_id".to_string(),
1362 column_name: "agent_id".to_string(),
1363 rust_type: "Uuid".to_string(),
1364 inner_type: "Uuid".to_string(),
1365 is_nullable: false,
1366 is_primary_key: false,
1367 sql_type: None,
1368 is_sql_array: false,
1369 },
1370 ParsedField {
1371 rust_name: "usages".to_string(),
1372 column_name: "usages".to_string(),
1373 rust_type: "Vec<ConnectorUsages>".to_string(),
1374 inner_type: "Vec<ConnectorUsages>".to_string(),
1375 is_nullable: false,
1376 is_primary_key: false,
1377 sql_type: Some("agent.connector_usages".to_string()),
1378 is_sql_array: true,
1379 },
1380 ],
1381 imports: vec!["use uuid::Uuid;".to_string()],
1382 }
1383 }
1384
1385 fn gen_macro_array(entity: &ParsedEntity, db: DatabaseKind) -> String {
1386 let skip = Methods::all();
1387 let (tokens, _) = generate_crud_from_parsed(entity, db, "crate::models::agent_connector", &skip, true, PoolVisibility::Private);
1388 parse_and_format(&tokens)
1389 }
1390
1391 #[test]
1392 fn test_sql_array_macro_get_all_uses_runtime() {
1393 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1394 assert!(code.contains("query_as::<"));
1396 }
1397
1398 #[test]
1399 fn test_sql_array_macro_get_uses_runtime() {
1400 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1401 assert!(code.contains(".bind("));
1403 }
1404
1405 #[test]
1406 fn test_sql_array_macro_insert_uses_runtime() {
1407 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1408 assert!(code.contains("query_as::<_ , AgentConnector>") || code.contains("query_as::<_, AgentConnector>"));
1410 }
1411
1412 #[test]
1413 fn test_sql_array_macro_update_uses_runtime() {
1414 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1415 assert!(code.contains("query_as::<"));
1417 }
1418
1419 #[test]
1420 fn test_sql_array_macro_delete_still_uses_macro() {
1421 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1422 assert!(code.contains("query!"));
1424 }
1425
1426 #[test]
1427 fn test_sql_array_no_query_as_macro() {
1428 let code = gen_macro_array(&entity_with_sql_array(), DatabaseKind::Postgres);
1429 assert!(!code.contains("query_as!("));
1431 }
1432
1433 fn entity_with_sql_enum() -> ParsedEntity {
1436 ParsedEntity {
1437 struct_name: "Task".to_string(),
1438 table_name: "tasks".to_string(),
1439 schema_name: None,
1440 is_view: false,
1441 fields: vec![
1442 ParsedField {
1443 rust_name: "id".to_string(),
1444 column_name: "id".to_string(),
1445 rust_type: "i32".to_string(),
1446 inner_type: "i32".to_string(),
1447 is_nullable: false,
1448 is_primary_key: true,
1449 sql_type: None,
1450 is_sql_array: false,
1451 },
1452 ParsedField {
1453 rust_name: "status".to_string(),
1454 column_name: "status".to_string(),
1455 rust_type: "TaskStatus".to_string(),
1456 inner_type: "TaskStatus".to_string(),
1457 is_nullable: false,
1458 is_primary_key: false,
1459 sql_type: Some("task_status".to_string()),
1460 is_sql_array: false,
1461 },
1462 ],
1463 imports: vec![],
1464 }
1465 }
1466
1467 #[test]
1468 fn test_sql_enum_macro_uses_runtime() {
1469 let skip = Methods::all();
1470 let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1471 let code = parse_and_format(&tokens);
1472 assert!(code.contains("query_as::<"));
1474 assert!(!code.contains("query_as!("));
1475 }
1476
1477 #[test]
1478 fn test_sql_enum_macro_delete_still_uses_macro() {
1479 let skip = Methods::all();
1480 let (tokens, _) = generate_crud_from_parsed(&entity_with_sql_enum(), DatabaseKind::Postgres, "crate::models::task", &skip, true, PoolVisibility::Private);
1481 let code = parse_and_format(&tokens);
1482 assert!(code.contains("query!"));
1484 }
1485
1486 fn entity_with_vec_string() -> ParsedEntity {
1489 ParsedEntity {
1490 struct_name: "PromptHistory".to_string(),
1491 table_name: "prompt_history".to_string(),
1492 schema_name: None,
1493 is_view: false,
1494 fields: vec![
1495 ParsedField {
1496 rust_name: "id".to_string(),
1497 column_name: "id".to_string(),
1498 rust_type: "Uuid".to_string(),
1499 inner_type: "Uuid".to_string(),
1500 is_nullable: false,
1501 is_primary_key: true,
1502 sql_type: None,
1503 is_sql_array: false,
1504 },
1505 ParsedField {
1506 rust_name: "content".to_string(),
1507 column_name: "content".to_string(),
1508 rust_type: "String".to_string(),
1509 inner_type: "String".to_string(),
1510 is_nullable: false,
1511 is_primary_key: false,
1512 sql_type: None,
1513 is_sql_array: false,
1514 },
1515 ParsedField {
1516 rust_name: "tags".to_string(),
1517 column_name: "tags".to_string(),
1518 rust_type: "Vec<String>".to_string(),
1519 inner_type: "Vec<String>".to_string(),
1520 is_nullable: false,
1521 is_primary_key: false,
1522 sql_type: None,
1523 is_sql_array: false,
1524 },
1525 ],
1526 imports: vec!["use uuid::Uuid;".to_string()],
1527 }
1528 }
1529
1530 #[test]
1531 fn test_vec_string_macro_insert_uses_as_slice() {
1532 let skip = Methods::all();
1533 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1534 let code = parse_and_format(&tokens);
1535 assert!(code.contains("as_slice()"));
1536 }
1537
1538 #[test]
1539 fn test_vec_string_macro_update_uses_as_slice() {
1540 let skip = Methods::all();
1541 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1542 let code = parse_and_format(&tokens);
1543 let count = code.matches("as_slice()").count();
1545 assert!(count >= 2, "expected at least 2 as_slice() calls (insert + update), found {}", count);
1546 }
1547
1548 #[test]
1549 fn test_vec_string_non_macro_no_as_slice() {
1550 let skip = Methods::all();
1551 let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false, PoolVisibility::Private);
1552 let code = parse_and_format(&tokens);
1553 assert!(!code.contains("as_slice()"));
1555 }
1556
1557 #[test]
1558 fn test_vec_string_parsed_from_source_uses_as_slice() {
1559 use crate::codegen::entity_parser::parse_entity_source;
1560 let source = r#"
1561 use uuid::Uuid;
1562
1563 #[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
1564 #[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
1565 pub struct PromptHistory {
1566 #[sqlx_gen(primary_key)]
1567 pub id: Uuid,
1568 pub content: String,
1569 pub tags: Vec<String>,
1570 }
1571 "#;
1572 let entity = parse_entity_source(source).unwrap();
1573 let skip = Methods::all();
1574 let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true, PoolVisibility::Private);
1575 let code = parse_and_format(&tokens);
1576 assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
1577 }
1578
1579 fn junction_entity() -> ParsedEntity {
1582 ParsedEntity {
1583 struct_name: "AnalysisRecord".to_string(),
1584 table_name: "analysis.analysis__record".to_string(),
1585 schema_name: None,
1586 is_view: false,
1587 fields: vec![
1588 make_field("record_id", "record_id", "uuid::Uuid", false, true),
1589 make_field("analysis_id", "analysis_id", "uuid::Uuid", false, true),
1590 ],
1591 imports: vec![],
1592 }
1593 }
1594
1595 #[test]
1596 fn test_composite_pk_only_insert_generated() {
1597 let code = gen(&junction_entity(), DatabaseKind::Postgres);
1598 assert!(code.contains("pub struct InsertAnalysisRecordParams"), "Expected InsertAnalysisRecordParams struct:\n{}", code);
1599 assert!(code.contains("pub record_id"), "Expected record_id field in insert params:\n{}", code);
1600 assert!(code.contains("pub analysis_id"), "Expected analysis_id field in insert params:\n{}", code);
1601 assert!(code.contains("INSERT INTO analysis.analysis__record (record_id, analysis_id) VALUES ($1, $2) RETURNING *"), "Expected valid INSERT SQL:\n{}", code);
1602 assert!(code.contains("pub async fn insert"), "Expected insert method:\n{}", code);
1603 }
1604
1605 #[test]
1606 fn test_composite_pk_only_no_update() {
1607 let code = gen(&junction_entity(), DatabaseKind::Postgres);
1608 assert!(!code.contains("UpdateAnalysisRecordParams"), "Expected no UpdateAnalysisRecordParams struct:\n{}", code);
1609 assert!(!code.contains("pub async fn update"), "Expected no update method:\n{}", code);
1610 }
1611
1612 #[test]
1613 fn test_composite_pk_only_delete_generated() {
1614 let code = gen(&junction_entity(), DatabaseKind::Postgres);
1615 assert!(code.contains("pub async fn delete"), "Expected delete method:\n{}", code);
1616 assert!(code.contains("DELETE FROM analysis.analysis__record WHERE record_id = $1 AND analysis_id = $2"), "Expected valid DELETE SQL:\n{}", code);
1617 }
1618
1619 #[test]
1620 fn test_composite_pk_only_get_generated() {
1621 let code = gen(&junction_entity(), DatabaseKind::Postgres);
1622 assert!(code.contains("pub async fn get"), "Expected get method:\n{}", code);
1623 assert!(code.contains("WHERE record_id = $1 AND analysis_id = $2"), "Expected WHERE clause with both PK columns:\n{}", code);
1624 }
1625}