1use shaperail_core::{EndpointSpec, FieldSchema, FieldType, HttpMethod, ResourceDefinition};
2
/// A single generated Rust source file.
///
/// [`generate_project`] produces one of these per resource, naming the file
/// `<resource>.rs` and filling `contents` with the rendered query module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedRustModule {
    /// File name of the generated module.
    pub file_name: String,
    /// Complete source text of the generated module.
    pub contents: String,
}
7
/// The full output of [`generate_project`].
pub struct GeneratedRustProject {
    /// One generated module per input resource, in input order.
    pub modules: Vec<GeneratedRustModule>,
    /// Source text of the registry module: the `pub mod` declarations, the
    /// store registry builder, and any controller traits.
    /// NOTE(review): presumably written to disk as `mod.rs` — confirm with the caller.
    pub mod_rs: String,
}
12
13pub fn generate_project(resources: &[ResourceDefinition]) -> Result<GeneratedRustProject, String> {
14 let mut modules = Vec::with_capacity(resources.len());
15 for resource in resources {
16 modules.push(GeneratedRustModule {
17 file_name: format!("{}.rs", resource.resource),
18 contents: generate_resource_module(resource)?,
19 });
20 }
21
22 Ok(GeneratedRustProject {
23 modules,
24 mod_rs: generate_registry_module(resources),
25 })
26}
27
28pub fn generate_resource_module(resource: &ResourceDefinition) -> Result<String, String> {
29 let context = ResourceContext::new(resource)?;
30
31 let model_fields = resource
32 .schema
33 .iter()
34 .map(|(name, field)| format!(" pub {name}: {},", model_field_type(field)))
35 .collect::<Vec<_>>()
36 .join("\n");
37
38 let list_helpers = context
39 .collection_endpoints
40 .iter()
41 .map(|endpoint| generate_list_helper(&context, endpoint))
42 .collect::<Result<Vec<_>, _>>()?
43 .join("\n\n");
44
45 let list_dispatch = if context.collection_endpoints.is_empty() {
46 " let _ = (endpoint, filters, search, sort, page);\n Err(shaperail_core::ShaperailError::Internal(\"No collection endpoints are available for generated list queries\".to_string()))".to_string()
47 } else {
48 let arms = context
49 .collection_endpoints
50 .iter()
51 .map(|endpoint| {
52 format!(
53 " {path:?} => self.{helper}(filters, search, sort, page).await,",
54 path = endpoint.spec.path(),
55 helper = endpoint.helper_name
56 )
57 })
58 .collect::<Vec<_>>()
59 .join("\n");
60
61 format!(
62 " match endpoint.path() {{\n{arms}\n _ => Err(shaperail_core::ShaperailError::Internal(format!(\"No generated list query for {{}}\", endpoint.path()))),\n }}"
63 )
64 };
65
66 Ok(format!(
67 r###"//! Generated query module for the `{resource_name}` resource.
68//! DO NOT EDIT — this file is auto-generated by `shaperail generate`.
69
70use serde::{{Deserialize, Serialize}};
71use serde_json::{{Map, Value}};
72use shaperail_core::EndpointSpec;
73#[allow(unused_imports)]
74use shaperail_runtime::db::{{
75 async_trait, parse_embedded_json, parse_filter, parse_optional_json, require_field,
76 row_from_model, sort_direction_at, sort_field_at, FilterSet, PageRequest, ResourceRow,
77 ResourceStore, SearchParam, SortParam,
78}};
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct {record_name} {{
82{model_fields}
83}}
84
85pub struct {store_name} {{
86 pool: sqlx::PgPool,
87}}
88
89impl {store_name} {{
90 pub fn new(pool: sqlx::PgPool) -> Self {{
91 Self {{ pool }}
92 }}
93
94{list_helpers}
95}}
96
97#[async_trait]
98impl ResourceStore for {store_name} {{
99 fn resource_name(&self) -> &str {{
100 "{resource_name}"
101 }}
102
103 async fn find_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
104 let row = sqlx::query_as!(
105 {record_name},
106 r#"
107 SELECT
108 {select_columns}
109 FROM "{table_name}"
110 WHERE "{primary_key}" = $1{soft_delete_where}
111 "#,
112 id
113 )
114 .fetch_optional(&self.pool)
115 .await?
116 .ok_or(shaperail_core::ShaperailError::NotFound)?;
117
118 row_from_model(&row)
119 }}
120
121 async fn find_all(
122 &self,
123 endpoint: &EndpointSpec,
124 filters: &FilterSet,
125 search: Option<&SearchParam>,
126 sort: &SortParam,
127 page: &PageRequest,
128 ) -> Result<(Vec<ResourceRow>, Value), shaperail_core::ShaperailError> {{
129{list_dispatch}
130 }}
131
132 async fn insert(&self, data: &Map<String, Value>) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
133{insert_body}
134 }}
135
136 async fn update_by_id(
137 &self,
138 id: &uuid::Uuid,
139 data: &Map<String, Value>,
140 ) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
141{update_body}
142 }}
143
144 async fn soft_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
145{soft_delete_body}
146 }}
147
148 async fn hard_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
149{hard_delete_body}
150 }}
151}}
152"###,
153 resource_name = resource.resource,
154 record_name = context.record_name,
155 store_name = context.store_name,
156 model_fields = model_fields,
157 list_helpers = list_helpers,
158 select_columns = context.select_columns,
159 table_name = resource.resource,
160 primary_key = context.primary_key,
161 soft_delete_where = context.soft_delete_where,
162 list_dispatch = list_dispatch,
163 insert_body = generate_insert_body(&context)?,
164 update_body = generate_update_body(&context)?,
165 soft_delete_body = generate_soft_delete_body(&context),
166 hard_delete_body = generate_hard_delete_body(&context),
167 ))
168}
169
170fn generate_registry_module(resources: &[ResourceDefinition]) -> String {
171 let module_lines = resources
172 .iter()
173 .map(|resource| format!("pub mod {};", resource.resource))
174 .collect::<Vec<_>>()
175 .join("\n");
176
177 let registry_lines = resources
178 .iter()
179 .map(|resource| {
180 let store_name = format!("{}Store", to_pascal_case(&resource.resource));
181 format!(
182 " stores.insert({name:?}.to_string(), std::sync::Arc::new({module}::{store_name}::new(pool.clone())));",
183 name = resource.resource,
184 module = resource.resource
185 )
186 })
187 .collect::<Vec<_>>()
188 .join("\n");
189
190 format!(
191 r#"{module_lines}
192
193pub fn build_store_registry(pool: sqlx::PgPool) -> shaperail_runtime::db::StoreRegistry {{
194 let mut stores: std::collections::HashMap<
195 String,
196 std::sync::Arc<dyn shaperail_runtime::db::ResourceStore>,
197 > = std::collections::HashMap::new();
198{registry_lines}
199 std::sync::Arc::new(stores)
200}}
201
202/// Returns an empty controller map. Register custom controller functions here
203/// or populate from `resources/<name>.controller.rs` files.
204pub fn build_controller_map() -> shaperail_runtime::handlers::controller::ControllerMap {{
205 shaperail_runtime::handlers::controller::ControllerMap::new()
206}}
207
208{controller_traits}
209"#,
210 controller_traits = generate_controller_traits(resources)
211 )
212}
213
214fn generate_controller_traits(resources: &[ResourceDefinition]) -> String {
222 let mut output = String::new();
223
224 for resource in resources {
225 let endpoints_with_controllers: Vec<_> = resource
226 .endpoints
227 .as_ref()
228 .map(|endpoints| {
229 endpoints
230 .iter()
231 .filter(|(_, ep)| ep.controller.is_some())
232 .collect::<Vec<_>>()
233 })
234 .unwrap_or_default();
235
236 if endpoints_with_controllers.is_empty() {
237 continue;
238 }
239
240 let pascal = to_pascal_case(&resource.resource);
241 let mut trait_methods = Vec::new();
242
243 for (action, ep) in &endpoints_with_controllers {
244 let controller = ep.controller.as_ref().unwrap();
245 let action_pascal = to_pascal_case(action);
246
247 let input_fields: Vec<_> = ep
249 .input
250 .as_ref()
251 .map(|fields| {
252 fields
253 .iter()
254 .filter_map(|name| {
255 resource.schema.get(name).map(|field| {
256 format!(" pub {name}: {},", model_field_type(field))
257 })
258 })
259 .collect()
260 })
261 .unwrap_or_default();
262
263 if !input_fields.is_empty() {
264 output.push_str(&format!(
265 r#"
266/// Input fields for the {resource_name} {action} endpoint.
267/// Auto-generated from the resource schema — do not edit.
268#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
269pub struct {pascal}{action_pascal}Input {{
270{fields}
271}}
272"#,
273 resource_name = resource.resource,
274 fields = input_fields.join("\n"),
275 ));
276 }
277
278 if let Some(before) = &controller.before {
280 if !before.starts_with(shaperail_core::WASM_HOOK_PREFIX) {
281 let input_type = if input_fields.is_empty() {
282 "serde_json::Value".to_string()
283 } else {
284 format!("{pascal}{action_pascal}Input")
285 };
286 trait_methods.push(format!(
287 " /// Before-hook for the {action} endpoint. Called before the DB operation.\n async fn {before}(ctx: &shaperail_runtime::handlers::controller::ControllerContext, input: &{input_type}) -> Result<(), shaperail_core::ShaperailError>;"
288 ));
289 }
290 }
291
292 if let Some(after) = &controller.after {
294 if !after.starts_with(shaperail_core::WASM_HOOK_PREFIX) {
295 trait_methods.push(format!(
296 " /// After-hook for the {action} endpoint. Called after the DB operation.\n async fn {after}(ctx: &shaperail_runtime::handlers::controller::ControllerContext, result: &serde_json::Value) -> Result<serde_json::Value, shaperail_core::ShaperailError>;"
297 ));
298 }
299 }
300 }
301
302 if !trait_methods.is_empty() {
303 output.push_str(&format!(
304 r#"
305/// Controller trait for the {resource_name} resource.
306/// Implement this trait in `controllers/{resource_name}.controller.rs`.
307/// The compiler will enforce correct signatures — no guessing needed.
308#[async_trait::async_trait]
309pub trait {pascal}Controller {{
310{methods}
311}}
312"#,
313 resource_name = resource.resource,
314 methods = trait_methods.join("\n\n"),
315 ));
316 }
317 }
318
319 output
320}
321
/// A collection (list) endpoint paired with the name of the generated helper
/// method that serves it.
#[derive(Clone)]
struct CollectionEndpoint<'a> {
    /// The endpoint specification the helper is generated from.
    spec: &'a EndpointSpec,
    /// Name of the generated helper method (`find_all_<sanitized endpoint name>`).
    helper_name: String,
}
327
/// Values derived once from a resource definition and shared by all the
/// per-resource generators in this file.
struct ResourceContext<'a> {
    /// The resource definition this context was built from.
    resource: &'a ResourceDefinition,
    /// Name of the generated record struct (`<PascalCase resource>Record`).
    record_name: String,
    /// Name of the generated store struct (`<PascalCase resource>Store`).
    store_name: String,
    /// Name of the first schema field flagged `primary`, or `"id"` if none is.
    primary_key: String,
    /// Comma-separated SELECT/RETURNING column list, one entry per schema field.
    select_columns: String,
    /// ` AND "deleted_at" IS NULL` when the resource uses soft delete, otherwise empty.
    soft_delete_where: String,
    /// GET endpoints whose path does not contain `:id`.
    collection_endpoints: Vec<CollectionEndpoint<'a>>,
}
337
338impl<'a> ResourceContext<'a> {
339 fn new(resource: &'a ResourceDefinition) -> Result<Self, String> {
340 let primary_key = resource
341 .schema
342 .iter()
343 .find(|(_, field)| field.primary)
344 .map(|(name, _)| name.clone())
345 .unwrap_or_else(|| "id".to_string());
346
347 let select_columns = resource
348 .schema
349 .iter()
350 .map(|(name, field)| select_column_sql(name, field))
351 .collect::<Vec<_>>()
352 .join(",\n ");
353
354 let collection_endpoints = resource
355 .endpoints
356 .as_ref()
357 .map(|endpoints| {
358 endpoints
359 .iter()
360 .filter(|(_, endpoint)| {
361 *endpoint.method() == HttpMethod::Get && !endpoint.path().contains(":id")
362 })
363 .map(|(name, endpoint)| CollectionEndpoint {
364 spec: endpoint,
365 helper_name: format!("find_all_{}", sanitize_identifier(name)),
366 })
367 .collect::<Vec<_>>()
368 })
369 .unwrap_or_default();
370
371 Ok(Self {
372 resource,
373 record_name: format!("{}Record", to_pascal_case(&resource.resource)),
374 store_name: format!("{}Store", to_pascal_case(&resource.resource)),
375 primary_key,
376 select_columns,
377 soft_delete_where: if has_soft_delete(resource) {
378 " AND \"deleted_at\" IS NULL".to_string()
379 } else {
380 String::new()
381 },
382 collection_endpoints,
383 })
384 }
385}
386
/// Renders the source of one generated list helper (`find_all_<endpoint>`).
///
/// The helper parses the endpoint's declared filters, optional search term and
/// sort parameters, then branches on the page request to run either the
/// cursor query or the count + offset queries.
///
/// SQL placeholders are numbered positionally and the bind-argument lists are
/// built in the same order, so the two must be kept in sync:
/// filters first, then the search term (if any), then — for the cursor query —
/// the cursor, then one (field, direction) pair per sort slot, then the limit
/// (and, for the offset query, the offset).
///
/// # Errors
///
/// Returns an error naming the field when a declared filter or sort field is
/// not part of the resource schema.
fn generate_list_helper(
    context: &ResourceContext<'_>,
    endpoint: &CollectionEndpoint<'_>,
) -> Result<String, String> {
    let filters = endpoint.spec.filters.clone().unwrap_or_default();
    let search_fields = endpoint.spec.search.clone().unwrap_or_default();
    let sort_fields = endpoint.spec.sort.clone().unwrap_or_default();

    // One `let filter_<name> = parse_filter(...)` line per declared filter;
    // this pass also validates that every filter field exists in the schema.
    let filter_decls = filters
        .iter()
        .map(|field_name| {
            let field = context.resource.schema.get(field_name).ok_or_else(|| {
                format!(
                    "Unknown filter field '{field_name}' on resource '{}'",
                    context.resource.resource
                )
            })?;
            Ok(generate_filter_declaration(field_name, field))
        })
        .collect::<Result<Vec<_>, String>>()?
        .join("\n");

    // Bind expressions for the filters, in declaration order. The lookup
    // cannot fail because the pass above already returned on unknown fields.
    let filter_args = filters
        .iter()
        .map(|field_name| {
            parameter_expression(
                field_name,
                context
                    .resource
                    .schema
                    .get(field_name)
                    .expect("filter field validated"),
            )
        })
        .collect::<Vec<_>>();

    let search_decl = if search_fields.is_empty() {
        String::new()
    } else {
        " let search_term = search.map(|value| value.term.clone());".to_string()
    };

    let search_predicate = if search_fields.is_empty() {
        String::new()
    } else {
        search_expression(&search_fields)
    };

    // One (field, direction) local pair per declared sort slot.
    let sort_decls = (0..sort_fields.len())
        .map(|index| {
            format!(
                " let sort_field_{index} = sort_field_at(sort, {index});\n let sort_direction_{index} = sort_direction_at(sort, {index});"
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    // Filters occupy placeholders $1..$n.
    let filter_positions = filters
        .iter()
        .enumerate()
        .map(|(index, field_name)| (field_name.clone(), index + 1))
        .collect::<Vec<_>>();

    // The search term, when present, takes the placeholder right after the filters.
    let search_position = if search_fields.is_empty() {
        None
    } else {
        Some(filter_positions.len() + 1)
    };

    // Cursor query: the cursor follows filters/search, sort pairs follow the cursor.
    let cursor_position = filter_positions.len() + usize::from(search_position.is_some()) + 1;
    let cursor_sort_positions = (0..sort_fields.len())
        .map(|index| {
            let base = cursor_position + 1 + (index * 2);
            (base, base + 1)
        })
        .collect::<Vec<_>>();
    // Offset query: there is no cursor, so sort pairs start where the cursor would be.
    let offset_sort_positions = (0..sort_fields.len())
        .map(|index| {
            let base =
                filter_positions.len() + usize::from(search_position.is_some()) + 1 + (index * 2);
            (base, base + 1)
        })
        .collect::<Vec<_>>();

    let filter_predicates = generate_filter_predicates(context, &filter_positions)?;
    let filter_clause = filter_predicates.join("\n");

    let cursor_order_by = generate_order_by(context, &sort_fields, &cursor_sort_positions)?;
    let offset_order_by = generate_order_by(context, &sort_fields, &offset_sort_positions)?;

    // Bind arguments for the cursor query, in placeholder order.
    let mut cursor_args = filter_args.clone();
    if search_position.is_some() {
        cursor_args.push("search_term.as_deref()".to_string());
    }
    cursor_args.push("cursor".to_string());
    for index in 0..sort_fields.len() {
        cursor_args.push(format!("sort_field_{index}.as_deref()"));
        cursor_args.push(format!("sort_direction_{index}"));
    }
    // One extra row is requested so the generated code can detect `has_more`.
    cursor_args.push("*limit + 1".to_string());

    // The COUNT query only binds filters and the search term.
    let mut count_args = filter_args.clone();
    if search_position.is_some() {
        count_args.push("search_term.as_deref()".to_string());
    }

    // The row query additionally binds sort pairs, then limit and offset.
    let mut row_args = count_args.clone();
    for index in 0..sort_fields.len() {
        row_args.push(format!("sort_field_{index}.as_deref()"));
        row_args.push(format!("sort_direction_{index}"));
    }
    row_args.push("*limit".to_string());
    row_args.push("*offset".to_string());

    // The limit is the last cursor argument, so its placeholder equals the argument count.
    let cursor_query = generate_cursor_query(
        context,
        &filter_clause,
        search_position.map(|position| (position, search_predicate.as_str())),
        cursor_position,
        &cursor_order_by,
        cursor_args.len(),
        &cursor_args,
    );
    // Limit and offset are the last two row arguments.
    let offset_query = generate_offset_query(
        context,
        &filter_clause,
        search_position.map(|position| (position, search_predicate.as_str())),
        &offset_order_by,
        row_args.len() - 1,
        row_args.len(),
        &count_args,
        &row_args,
    );

    Ok(format!(
        r###" async fn {helper_name}(
 &self,
 filters: &FilterSet,
 search: Option<&SearchParam>,
 sort: &SortParam,
 page: &PageRequest,
 ) -> Result<(Vec<ResourceRow>, Value), shaperail_core::ShaperailError> {{
{filter_decls}
{search_decl}
{sort_decls}

 match page {{
 PageRequest::Cursor {{ after, limit }} => {{
 let cursor = match after {{
 Some(cursor_value) => Some(uuid::Uuid::parse_str(
 &shaperail_runtime::db::decode_cursor(cursor_value)?
 ).map_err(|_| shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {{
 field: "cursor".to_string(),
 message: "Invalid cursor value".to_string(),
 code: "invalid_cursor".to_string(),
 }}]))?),
 None => None,
 }};
{cursor_query}
 }}
 PageRequest::Offset {{ offset, limit }} => {{
{offset_query}
 }}
 }}
 }}"###,
        helper_name = endpoint.helper_name,
        filter_decls = indent_block(&filter_decls, 2),
        search_decl = indent_block(&search_decl, 2),
        sort_decls = indent_block(&sort_decls, 2),
        cursor_query = indent_block(&cursor_query, 4),
        offset_query = indent_block(&offset_query, 4),
    ))
}
560
561fn generate_cursor_query(
562 context: &ResourceContext<'_>,
563 filter_clause: &str,
564 search_position: Option<(usize, &str)>,
565 cursor_position: usize,
566 order_by: &str,
567 limit_position: usize,
568 args: &[String],
569) -> String {
570 format!(
571 r###" let rows = sqlx::query_as!(
572 {record_name},
573 r#"
574 SELECT
575 {select_columns}
576 FROM "{table_name}"
577 WHERE TRUE
578{soft_delete_clause}
579{filter_clause}{search_clause}
580 AND (${cursor_position}::uuid IS NULL OR "{primary_key}" > ${cursor_position})
581 ORDER BY
582{order_by}
583 LIMIT ${limit_position}
584 "#,
585 {args}
586 )
587 .fetch_all(&self.pool)
588 .await?;
589
590 let has_more = rows.len() as i64 > *limit;
591 let mut result_rows = rows;
592 if has_more {{
593 result_rows.truncate(*limit as usize);
594 }}
595
596 let data = result_rows
597 .iter()
598 .map(row_from_model)
599 .collect::<Result<Vec<_>, _>>()?;
600 let cursor = if has_more {{
601 result_rows
602 .last()
603 .map(|row| shaperail_runtime::db::encode_cursor(&row.{primary_key}.to_string()))
604 }} else {{
605 None
606 }};
607
608 Ok((
609 data,
610 serde_json::json!({{
611 "cursor": cursor,
612 "has_more": has_more
613 }})
614 ))"###,
615 record_name = context.record_name,
616 select_columns = context.select_columns,
617 table_name = context.resource.resource,
618 soft_delete_clause = if has_soft_delete(context.resource) {
619 " AND \"deleted_at\" IS NULL\n"
620 } else {
621 ""
622 },
623 filter_clause = if filter_clause.is_empty() {
624 String::new()
625 } else {
626 format!("{filter_clause}\n")
627 },
628 search_clause = search_position
629 .map(|(position, expression)| {
630 format!(
631 "\n AND (${position}::text IS NULL OR to_tsvector('english', {expression}) @@ plainto_tsquery('english', ${position}))"
632 )
633 })
634 .unwrap_or_default(),
635 cursor_position = cursor_position,
636 primary_key = context.primary_key,
637 order_by = order_by,
638 limit_position = limit_position,
639 args = args.join(",\n "),
640 )
641}
642
643#[allow(clippy::too_many_arguments)]
644fn generate_offset_query(
645 context: &ResourceContext<'_>,
646 filter_clause: &str,
647 search_position: Option<(usize, &str)>,
648 order_by: &str,
649 limit_position: usize,
650 offset_position: usize,
651 count_args: &[String],
652 row_args: &[String],
653) -> String {
654 let count_macro_args = if count_args.is_empty() {
655 String::new()
656 } else {
657 format!(
658 ",\n {}",
659 count_args.join(",\n ")
660 )
661 };
662
663 format!(
664 r###" let total = sqlx::query_scalar!(
665 r#"
666 SELECT COUNT(*) as "count!"
667 FROM "{table_name}"
668 WHERE TRUE
669{soft_delete_clause}
670{filter_clause}{search_clause}
671 "#{count_macro_args}
672 )
673 .fetch_one(&self.pool)
674 .await?;
675
676 let rows = sqlx::query_as!(
677 {record_name},
678 r#"
679 SELECT
680 {select_columns}
681 FROM "{table_name}"
682 WHERE TRUE
683{soft_delete_clause}
684{filter_clause}{search_clause}
685 ORDER BY
686{order_by}
687 LIMIT ${limit_param}
688 OFFSET ${offset_param}
689 "#,
690 {row_args}
691 )
692 .fetch_all(&self.pool)
693 .await?;
694
695 let data = rows
696 .iter()
697 .map(row_from_model)
698 .collect::<Result<Vec<_>, _>>()?;
699
700 Ok((
701 data,
702 serde_json::json!({{
703 "offset": offset,
704 "limit": limit,
705 "total": total
706 }})
707 ))"###,
708 table_name = context.resource.resource,
709 soft_delete_clause = if has_soft_delete(context.resource) {
710 " AND \"deleted_at\" IS NULL\n"
711 } else {
712 ""
713 },
714 filter_clause = if filter_clause.is_empty() {
715 String::new()
716 } else {
717 format!("{filter_clause}\n")
718 },
719 search_clause = search_position
720 .map(|(position, expression)| {
721 format!(
722 "\n AND (${position}::text IS NULL OR to_tsvector('english', {expression}) @@ plainto_tsquery('english', ${position}))"
723 )
724 })
725 .unwrap_or_default(),
726 count_macro_args = count_macro_args,
727 record_name = context.record_name,
728 select_columns = context.select_columns,
729 order_by = order_by,
730 limit_param = limit_position,
731 offset_param = offset_position,
732 row_args = row_args.join(",\n "),
733 )
734}
735
736fn generate_filter_predicates(
737 context: &ResourceContext<'_>,
738 positions: &[(String, usize)],
739) -> Result<Vec<String>, String> {
740 positions
741 .iter()
742 .map(|(field_name, position)| {
743 let field = context.resource.schema.get(field_name).ok_or_else(|| {
744 format!(
745 "Unknown filter field '{field_name}' on resource '{}'",
746 context.resource.resource
747 )
748 })?;
749 Ok(format!(
750 " AND (${position}::{cast} IS NULL OR \"{field_name}\" = ${position})",
751 cast = sql_cast_type(field)
752 ))
753 })
754 .collect()
755}
756
757fn generate_order_by(
758 context: &ResourceContext<'_>,
759 sort_fields: &[String],
760 positions: &[(usize, usize)],
761) -> Result<String, String> {
762 if sort_fields.is_empty() {
763 return Ok(format!("\"{}\" ASC", context.primary_key));
764 }
765
766 let mut clauses = Vec::new();
767 for ((field_param, direction_param), field_name) in positions.iter().zip(sort_fields) {
768 for candidate in sort_fields {
769 let field = context.resource.schema.get(candidate).ok_or_else(|| {
770 format!(
771 "Unknown sort field '{candidate}' on resource '{}'",
772 context.resource.resource
773 )
774 })?;
775 let sort_expr = sortable_expression(candidate, field);
776 clauses.push(format!(
777 " CASE WHEN ${field_param}::text = '{candidate}' AND ${direction_param}::text = 'asc' THEN {sort_expr} END ASC"
778 ));
779 clauses.push(format!(
780 " CASE WHEN ${field_param}::text = '{candidate}' AND ${direction_param}::text = 'desc' THEN {sort_expr} END DESC"
781 ));
782 }
783 let _ = field_name;
784 }
785 clauses.push(format!(" \"{}\" ASC", context.primary_key));
786 Ok(clauses.join(",\n"))
787}
788
789fn generate_insert_body(context: &ResourceContext<'_>) -> Result<String, String> {
790 let mut declarations = Vec::new();
791 let mut columns = Vec::new();
792 let mut values = Vec::new();
793 let mut args = Vec::new();
794
795 for (index, (field_name, field)) in context.resource.schema.iter().enumerate() {
796 let variable_name = sanitize_identifier(field_name);
797 declarations.push(generate_insert_declaration(
798 field_name,
799 field,
800 &variable_name,
801 )?);
802 columns.push(format!("\"{field_name}\""));
803 values.push(format!("${}", index + 1));
804 args.push(variable_name);
805 }
806
807 Ok(format!(
808 r###"{declarations}
809 let row = sqlx::query_as!(
810 {record_name},
811 r#"
812 INSERT INTO "{table_name}" ({columns})
813 VALUES ({values})
814 RETURNING
815 {select_columns}
816 "#,
817 {args}
818 )
819 .fetch_one(&self.pool)
820 .await?;
821
822 row_from_model(&row)"###,
823 declarations = declarations.join("\n"),
824 record_name = context.record_name,
825 table_name = context.resource.resource,
826 columns = columns.join(", "),
827 values = values.join(", "),
828 select_columns = context.select_columns,
829 args = args.join(",\n "),
830 ))
831}
832
833fn generate_update_body(context: &ResourceContext<'_>) -> Result<String, String> {
834 let mut declarations = Vec::new();
835 let mut set_clauses = Vec::new();
836 let mut args = vec!["id".to_string()];
837 let mut has_mutable_fields = Vec::new();
838 let mut index = 2usize;
839
840 for (field_name, field) in &context.resource.schema {
841 if field.primary || field.generated {
842 continue;
843 }
844
845 let present_name = format!("{}_present", sanitize_identifier(field_name));
846 let value_name = sanitize_identifier(field_name);
847 declarations.push(generate_update_declaration(
848 field_name,
849 field,
850 &present_name,
851 &value_name,
852 ));
853 has_mutable_fields.push(present_name.clone());
854 set_clauses.push(format!(
855 "\"{field_name}\" = CASE WHEN ${present_param} THEN ${value_param} ELSE \"{field_name}\" END",
856 present_param = index,
857 value_param = index + 1
858 ));
859 args.push(present_name);
860 args.push(value_name);
861 index += 2;
862 }
863
864 if let Some(updated_at) = context.resource.schema.get("updated_at") {
865 if updated_at.generated && updated_at.field_type == FieldType::Timestamp {
866 declarations.push(" let updated_at = chrono::Utc::now();".to_string());
867 set_clauses.push(format!("\"updated_at\" = ${index}"));
868 args.push("updated_at".to_string());
869 }
870 }
871
872 let guard = if has_mutable_fields.is_empty() {
873 String::new()
874 } else {
875 format!(
876 " if !({}) {}",
877 has_mutable_fields.join(" || "),
878 r#"{
879 return Err(shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {
880 field: "body".to_string(),
881 message: "No valid fields to update".to_string(),
882 code: "empty_update".to_string(),
883 }]));
884 }"#
885 )
886 };
887
888 Ok(format!(
889 r###"{declarations}
890{guard}
891 let row = sqlx::query_as!(
892 {record_name},
893 r#"
894 UPDATE "{table_name}"
895 SET {set_clauses}
896 WHERE "{primary_key}" = $1{soft_delete_where}
897 RETURNING
898 {select_columns}
899 "#,
900 {args}
901 )
902 .fetch_optional(&self.pool)
903 .await?
904 .ok_or(shaperail_core::ShaperailError::NotFound)?;
905
906 row_from_model(&row)"###,
907 declarations = declarations.join("\n"),
908 guard = guard,
909 record_name = context.record_name,
910 table_name = context.resource.resource,
911 set_clauses = set_clauses.join(", "),
912 primary_key = context.primary_key,
913 soft_delete_where = context.soft_delete_where,
914 select_columns = context.select_columns,
915 args = args.join(",\n "),
916 ))
917}
918
919fn generate_soft_delete_body(context: &ResourceContext<'_>) -> String {
920 format!(
921 r###" let deleted_at = chrono::Utc::now();
922 let row = sqlx::query_as!(
923 {record_name},
924 r#"
925 UPDATE "{table_name}"
926 SET "deleted_at" = $2
927 WHERE "{primary_key}" = $1 AND "deleted_at" IS NULL
928 RETURNING
929 {select_columns}
930 "#,
931 id,
932 deleted_at
933 )
934 .fetch_optional(&self.pool)
935 .await?
936 .ok_or(shaperail_core::ShaperailError::NotFound)?;
937
938 row_from_model(&row)"###,
939 record_name = context.record_name,
940 table_name = context.resource.resource,
941 primary_key = context.primary_key,
942 select_columns = context.select_columns,
943 )
944}
945
946fn generate_hard_delete_body(context: &ResourceContext<'_>) -> String {
947 format!(
948 r###" let row = sqlx::query_as!(
949 {record_name},
950 r#"
951 DELETE FROM "{table_name}"
952 WHERE "{primary_key}" = $1
953 RETURNING
954 {select_columns}
955 "#,
956 id
957 )
958 .fetch_optional(&self.pool)
959 .await?
960 .ok_or(shaperail_core::ShaperailError::NotFound)?;
961
962 row_from_model(&row)"###,
963 record_name = context.record_name,
964 table_name = context.resource.resource,
965 primary_key = context.primary_key,
966 select_columns = context.select_columns,
967 )
968}
969
970fn generate_insert_declaration(
971 field_name: &str,
972 field: &FieldSchema,
973 variable_name: &str,
974) -> Result<String, String> {
975 if field.generated {
976 return Ok(format!(
977 " let {variable_name} = {};",
978 generated_value_expression(field)
979 ));
980 }
981
982 let parse_type = parse_type(field);
983 let parsed = format!(
984 "shaperail_runtime::db::parse_optional_json::<{parse_type}>(data, {field_name:?})?"
985 );
986
987 let expression = match (field_is_required(field), field.default.as_ref()) {
988 (true, Some(default)) => format!(
989 "match {parsed} {{ Some(value) => value, None => {} }}",
990 default_expression(field_name, field, default)?
991 ),
992 (true, None) => format!("shaperail_runtime::db::require_field({parsed}, {field_name:?})?"),
993 (false, Some(default)) if model_field_is_optional(field) => format!(
994 "match {parsed} {{ Some(value) => Some(value), None => Some({}) }}",
995 default_expression(field_name, field, default)?
996 ),
997 (false, Some(default)) => format!(
998 "match {parsed} {{ Some(value) => value, None => {} }}",
999 default_expression(field_name, field, default)?
1000 ),
1001 (false, None) => parsed,
1002 };
1003
1004 Ok(format!(" let {variable_name} = {expression};"))
1005}
1006
1007fn generate_update_declaration(
1008 field_name: &str,
1009 field: &FieldSchema,
1010 present_name: &str,
1011 value_name: &str,
1012) -> String {
1013 format!(
1014 " let {present_name} = data.contains_key({field_name:?});\n let {value_name} = shaperail_runtime::db::parse_optional_json::<{parse_type}>(data, {field_name:?})?;",
1015 parse_type = parse_type(field)
1016 )
1017}
1018
1019fn generate_filter_declaration(field_name: &str, field: &FieldSchema) -> String {
1020 let parser = match field.field_type {
1021 FieldType::Uuid => "uuid::Uuid::parse_str(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid uuid filter\".to_string()))",
1022 FieldType::String | FieldType::Enum | FieldType::File => "Ok(text.to_string())",
1023 FieldType::Integer => "text.parse::<i32>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid integer filter\".to_string()))",
1024 FieldType::Bigint => "text.parse::<i64>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid bigint filter\".to_string()))",
1025 FieldType::Number => "text.parse::<f64>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid number filter\".to_string()))",
1026 FieldType::Boolean => "text.parse::<bool>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid boolean filter\".to_string()))",
1027 FieldType::Timestamp => "chrono::DateTime::parse_from_rfc3339(text).map(|value| value.with_timezone(&chrono::Utc)).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid timestamp filter\".to_string()))",
1028 FieldType::Date => "chrono::NaiveDate::parse_from_str(text, \"%Y-%m-%d\").map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid date filter\".to_string()))",
1029 FieldType::Json => "serde_json::from_str::<serde_json::Value>(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid json filter\".to_string()))",
1030 FieldType::Array => "serde_json::from_str::<Vec<serde_json::Value>>(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid array filter\".to_string()))",
1031 };
1032
1033 format!(
1034 " let {var} = parse_filter(filters, {field_name:?}, \"invalid_filter\", |text| {parser})?;",
1035 var = field_parameter_name(field_name)
1036 )
1037}
1038
1039fn field_parameter_name(field_name: &str) -> String {
1040 format!("filter_{}", sanitize_identifier(field_name))
1041}
1042
1043fn parameter_expression(field_name: &str, field: &FieldSchema) -> String {
1044 let var = field_parameter_name(field_name);
1045 match field.field_type {
1046 FieldType::String | FieldType::Enum | FieldType::File => format!("{var}.as_deref()"),
1047 _ => var,
1048 }
1049}
1050
1051fn select_column_sql(field_name: &str, field: &FieldSchema) -> String {
1052 let nullability = if model_field_is_optional(field) {
1053 "?"
1054 } else {
1055 "!"
1056 };
1057 let expression = match field.field_type {
1058 FieldType::Number => format!("\"{field_name}\"::DOUBLE PRECISION"),
1059 _ => format!("\"{field_name}\""),
1060 };
1061 format!(
1062 "{expression} as \"{field_name}{nullability}: {type_name}\"",
1063 type_name = query_type(field)
1064 )
1065}
1066
1067fn sortable_expression(field_name: &str, field: &FieldSchema) -> String {
1068 match field.field_type {
1069 FieldType::Json | FieldType::Array | FieldType::Uuid => format!("\"{field_name}\"::text"),
1070 FieldType::Number => format!("\"{field_name}\"::DOUBLE PRECISION"),
1071 _ => format!("\"{field_name}\""),
1072 }
1073}
1074
/// Builds the SQL text expression searched by the full-text predicate: each
/// field is cast to text with NULL replaced by an empty string, and the parts
/// are concatenated with a single space between them. An empty field list
/// yields an empty string.
fn search_expression(fields: &[String]) -> String {
    let mut expression = String::new();

    for (index, field) in fields.iter().enumerate() {
        if index > 0 {
            expression.push_str(" || ' ' || ");
        }
        expression.push_str(&format!("COALESCE(\"{field}\"::text, '')"));
    }

    expression
}
1082
1083fn sql_cast_type(field: &FieldSchema) -> String {
1084 match field.field_type {
1085 FieldType::Uuid => "uuid".to_string(),
1086 FieldType::String | FieldType::Enum | FieldType::File => "text".to_string(),
1087 FieldType::Integer => "integer".to_string(),
1088 FieldType::Bigint => "bigint".to_string(),
1089 FieldType::Number => "double precision".to_string(),
1090 FieldType::Boolean => "boolean".to_string(),
1091 FieldType::Timestamp => "timestamptz".to_string(),
1092 FieldType::Date => "date".to_string(),
1093 FieldType::Json => "jsonb".to_string(),
1094 FieldType::Array => match field.items.as_deref() {
1095 Some("uuid") => "uuid[]".to_string(),
1096 Some("integer") => "integer[]".to_string(),
1097 Some("bigint") => "bigint[]".to_string(),
1098 Some("number") => "double precision[]".to_string(),
1099 Some("boolean") => "boolean[]".to_string(),
1100 _ => "text[]".to_string(),
1101 },
1102 }
1103}
1104
1105fn query_type(field: &FieldSchema) -> String {
1106 match field.field_type {
1107 FieldType::Uuid => "uuid::Uuid".to_string(),
1108 FieldType::String | FieldType::Enum | FieldType::File => "String".to_string(),
1109 FieldType::Integer => "i32".to_string(),
1110 FieldType::Bigint => "i64".to_string(),
1111 FieldType::Number => "f64".to_string(),
1112 FieldType::Boolean => "bool".to_string(),
1113 FieldType::Timestamp => "chrono::DateTime<chrono::Utc>".to_string(),
1114 FieldType::Date => "chrono::NaiveDate".to_string(),
1115 FieldType::Json => "serde_json::Value".to_string(),
1116 FieldType::Array => match field.items.as_deref() {
1117 Some("uuid") => "Vec<uuid::Uuid>".to_string(),
1118 Some("integer") => "Vec<i32>".to_string(),
1119 Some("bigint") => "Vec<i64>".to_string(),
1120 Some("number") => "Vec<f64>".to_string(),
1121 Some("boolean") => "Vec<bool>".to_string(),
1122 Some("timestamp") => "Vec<chrono::DateTime<chrono::Utc>>".to_string(),
1123 Some("date") => "Vec<chrono::NaiveDate>".to_string(),
1124 _ => "Vec<String>".to_string(),
1125 },
1126 }
1127}
1128
1129fn parse_type(field: &FieldSchema) -> String {
1130 query_type(field)
1131}
1132
1133fn model_field_type(field: &FieldSchema) -> String {
1134 let base = query_type(field);
1135 if model_field_is_optional(field) {
1136 format!("Option<{base}>")
1137 } else {
1138 base
1139 }
1140}
1141
1142fn model_field_is_optional(field: &FieldSchema) -> bool {
1143 !(field.primary || (field.required && !field.nullable))
1144}
1145
1146fn field_is_required(field: &FieldSchema) -> bool {
1147 field.primary || (field.required && !field.nullable)
1148}
1149
1150fn generated_value_expression(field: &FieldSchema) -> String {
1151 match field.field_type {
1152 FieldType::Uuid => "uuid::Uuid::new_v4()".to_string(),
1153 FieldType::Timestamp => {
1154 if model_field_is_optional(field) {
1155 "Some(chrono::Utc::now())".to_string()
1156 } else {
1157 "chrono::Utc::now()".to_string()
1158 }
1159 }
1160 FieldType::Date => {
1161 if model_field_is_optional(field) {
1162 "Some(chrono::Utc::now().date_naive())".to_string()
1163 } else {
1164 "chrono::Utc::now().date_naive()".to_string()
1165 }
1166 }
1167 _ => "Default::default()".to_string(),
1168 }
1169}
1170
1171fn default_expression(
1172 field_name: &str,
1173 field: &FieldSchema,
1174 default: &serde_json::Value,
1175) -> Result<String, String> {
1176 Ok(match field.field_type {
1177 FieldType::Uuid => format!(
1178 "parse_embedded_json::<uuid::Uuid>({field_name:?}, serde_json::json!({default}))?"
1179 ),
1180 FieldType::String | FieldType::Enum | FieldType::File => {
1181 let value = default
1182 .as_str()
1183 .ok_or_else(|| format!("Default for '{field_name}' must be a string"))?;
1184 format!("{value:?}.to_string()")
1185 }
1186 FieldType::Integer => format!(
1187 "parse_embedded_json::<i32>({field_name:?}, serde_json::json!({default}))?"
1188 ),
1189 FieldType::Bigint => format!(
1190 "parse_embedded_json::<i64>({field_name:?}, serde_json::json!({default}))?"
1191 ),
1192 FieldType::Number => format!(
1193 "parse_embedded_json::<f64>({field_name:?}, serde_json::json!({default}))?"
1194 ),
1195 FieldType::Boolean => default
1196 .as_bool()
1197 .ok_or_else(|| format!("Default for '{field_name}' must be a boolean"))?
1198 .to_string(),
1199 FieldType::Timestamp => format!(
1200 "parse_embedded_json::<chrono::DateTime<chrono::Utc>>({field_name:?}, serde_json::json!({default}))?"
1201 ),
1202 FieldType::Date => format!(
1203 "parse_embedded_json::<chrono::NaiveDate>({field_name:?}, serde_json::json!({default}))?"
1204 ),
1205 FieldType::Json => format!("serde_json::json!({default})"),
1206 FieldType::Array => format!(
1207 "parse_embedded_json::<{}>({field_name:?}, serde_json::json!({default}))?",
1208 query_type(field)
1209 ),
1210 })
1211}
1212
1213fn has_soft_delete(resource: &ResourceDefinition) -> bool {
1214 resource
1215 .endpoints
1216 .as_ref()
1217 .map(|endpoints| endpoints.values().any(|endpoint| endpoint.soft_delete))
1218 .unwrap_or(false)
1219}
1220
/// Converts an arbitrary string into text usable inside a Rust identifier.
///
/// ASCII letters and digits are kept (letters lowercased); every other
/// character, including non-ASCII ones, becomes `_`. If the result would
/// start with a digit, a leading `_` is prepended. An empty input yields an
/// empty string.
fn sanitize_identifier(value: &str) -> String {
    let cleaned: String = value
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();

    if cleaned.starts_with(|ch: char| ch.is_ascii_digit()) {
        format!("_{cleaned}")
    } else {
        cleaned
    }
}
1237
/// Converts a `snake_case` name to `PascalCase`.
///
/// The input is split on `_`; empty segments (from leading, trailing, or
/// repeated underscores) are dropped. The first character of each remaining
/// segment is uppercased and the rest of the segment is copied unchanged.
fn to_pascal_case(value: &str) -> String {
    let mut pascal = String::new();

    for word in value.split('_') {
        let mut rest = word.chars();
        // An empty segment has no first character and contributes nothing.
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }

    pascal
}
1256
/// Indents every non-empty line of `block` by `indent` spaces.
///
/// A block that is empty or contains only whitespace yields an empty string.
/// Completely empty lines stay empty (no trailing spaces are added), while
/// lines that contain only whitespace are still prefixed. Lines are rejoined
/// with `\n`, so a trailing newline in the input is not preserved.
fn indent_block(block: &str, indent: usize) -> String {
    if block.trim().is_empty() {
        return String::new();
    }

    let padding = " ".repeat(indent);
    let mut indented = String::new();

    for (index, line) in block.lines().enumerate() {
        if index > 0 {
            indented.push('\n');
        }
        if !line.is_empty() {
            indented.push_str(&padding);
            indented.push_str(line);
        }
    }

    indented
}
1275
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use shaperail_core::{
        AuthRule, EndpointSpec, FieldSchema, HttpMethod, PaginationStyle, ResourceDefinition,
    };

    /// Builds a field of the given type with every flag off and every
    /// optional attribute unset; tests override only what they need.
    fn plain_field(field_type: FieldType) -> FieldSchema {
        FieldSchema {
            field_type,
            primary: false,
            generated: false,
            required: false,
            unique: false,
            nullable: false,
            reference: None,
            min: None,
            max: None,
            format: None,
            values: None,
            default: None,
            sensitive: false,
            search: false,
            items: None,
        }
    }

    /// A minimal `users` resource: generated UUID primary key, a unique
    /// searchable email, a generated creation timestamp, and one public
    /// cursor-paginated list endpoint.
    fn sample_resource() -> ResourceDefinition {
        let id = FieldSchema {
            primary: true,
            generated: true,
            ..plain_field(FieldType::Uuid)
        };
        let email = FieldSchema {
            required: true,
            unique: true,
            search: true,
            ..plain_field(FieldType::String)
        };
        let created_at = FieldSchema {
            generated: true,
            ..plain_field(FieldType::Timestamp)
        };

        // Insertion order matters: the schema map preserves it.
        let mut schema = IndexMap::new();
        schema.insert("id".to_string(), id);
        schema.insert("email".to_string(), email);
        schema.insert("created_at".to_string(), created_at);

        let list_endpoint = EndpointSpec {
            method: Some(HttpMethod::Get),
            path: Some("/users".to_string()),
            auth: Some(AuthRule::Public),
            input: None,
            filters: Some(vec!["email".to_string()]),
            search: Some(vec!["email".to_string()]),
            pagination: Some(PaginationStyle::Cursor),
            sort: Some(vec!["created_at".to_string()]),
            cache: None,
            controller: None,
            events: None,
            jobs: None,
            upload: None,
            soft_delete: false,
        };

        let mut endpoints = IndexMap::new();
        endpoints.insert("list".to_string(), list_endpoint);

        ResourceDefinition {
            resource: "users".to_string(),
            version: 1,
            db: None,
            tenant_key: None,
            schema,
            endpoints: Some(endpoints),
            relations: None,
            indexes: None,
        }
    }

    /// The resource module must contain the store impl, a compile-time
    /// checked query, and the list helper for the sample endpoint.
    #[test]
    fn generates_query_as_store_module() {
        let code = generate_resource_module(&sample_resource()).unwrap();

        assert!(code.contains("impl ResourceStore for UsersStore"));
        assert!(code.contains("sqlx::query_as!"));
        assert!(code.contains("find_all_list"));
    }

    /// The registry module must declare the resource module and expose the
    /// registry builder.
    #[test]
    fn generates_registry_module() {
        let project = generate_project(&[sample_resource()]).unwrap();
        let registry = &project.mod_rs;

        assert!(registry.contains("pub mod users;"));
        assert!(registry.contains("build_store_registry"));
    }
}