1use shaperail_core::{EndpointSpec, FieldSchema, FieldType, HttpMethod, ResourceDefinition};
2
/// A single generated Rust source file for one resource.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedRustModule {
    /// File name of the module, built as `<resource>.rs`.
    pub file_name: String,
    /// Full Rust source text of the module.
    pub contents: String,
}

/// The complete set of generated Rust sources for a project.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedRustProject {
    /// One module per resource, in the order the resources were supplied.
    pub modules: Vec<GeneratedRustModule>,
    /// Contents of the registry module that declares every resource module and builds the
    /// store registry.
    pub mod_rs: String,
}
12
13pub fn generate_project(resources: &[ResourceDefinition]) -> Result<GeneratedRustProject, String> {
14 let mut modules = Vec::with_capacity(resources.len());
15 for resource in resources {
16 modules.push(GeneratedRustModule {
17 file_name: format!("{}.rs", resource.resource),
18 contents: generate_resource_module(resource)?,
19 });
20 }
21
22 Ok(GeneratedRustProject {
23 modules,
24 mod_rs: generate_registry_module(resources),
25 })
26}
27
/// Generates the complete query module for a single resource.
///
/// The emitted module defines a `{PascalCase}Record` model struct, a `{PascalCase}Store`
/// holding a `sqlx::PgPool` with one `find_all_*` helper per collection endpoint, and a
/// `ResourceStore` implementation covering lookup by id, listing, insert, update, soft delete
/// and hard delete. The database table name is the resource name.
///
/// # Errors
///
/// Propagates errors from the context, list-helper, insert and update generators, e.g. an
/// unknown filter or sort field, or a default value of the wrong JSON type.
pub fn generate_resource_module(resource: &ResourceDefinition) -> Result<String, String> {
    let context = ResourceContext::new(resource)?;

    // One `pub <field>: <type>,` line per schema field, in schema iteration order.
    let model_fields = resource
        .schema
        .iter()
        .map(|(name, field)| format!(" pub {name}: {},", model_field_type(field)))
        .collect::<Vec<_>>()
        .join("\n");

    // Inherent `find_all_*` methods on the store, one per collection endpoint.
    let list_helpers = context
        .collection_endpoints
        .iter()
        .map(|endpoint| generate_list_helper(&context, endpoint))
        .collect::<Result<Vec<_>, _>>()?
        .join("\n\n");

    // Body of `ResourceStore::find_all`: route on the endpoint path to the matching helper.
    // Without collection endpoints, the body explicitly discards its arguments (avoiding
    // unused-variable warnings) and returns an internal error.
    let list_dispatch = if context.collection_endpoints.is_empty() {
        " let _ = (endpoint, filters, search, sort, page);\n Err(shaperail_core::ShaperailError::Internal(\"No collection endpoints are available for generated list queries\".to_string()))".to_string()
    } else {
        let arms = context
            .collection_endpoints
            .iter()
            .map(|endpoint| {
                format!(
                    " {path:?} => self.{helper}(filters, search, sort, page).await,",
                    path = endpoint.spec.path,
                    helper = endpoint.helper_name
                )
            })
            .collect::<Vec<_>>()
            .join("\n");

        format!(
            " match endpoint.path.as_str() {{\n{arms}\n _ => Err(shaperail_core::ShaperailError::Internal(format!(\"No generated list query for {{}}\", endpoint.path))),\n }}"
        )
    };

    // The template is emitted verbatim; `{{` / `}}` are escaped literal braces.
    Ok(format!(
        r###"//! Generated query module for the `{resource_name}` resource.
//! DO NOT EDIT — this file is auto-generated by `shaperail generate`.

use serde::{{Deserialize, Serialize}};
use serde_json::{{Map, Value}};
use shaperail_core::EndpointSpec;
#[allow(unused_imports)]
use shaperail_runtime::db::{{
 async_trait, parse_embedded_json, parse_filter, parse_optional_json, require_field,
 row_from_model, sort_direction_at, sort_field_at, FilterSet, PageRequest, ResourceRow,
 ResourceStore, SearchParam, SortParam,
}};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct {record_name} {{
{model_fields}
}}

pub struct {store_name} {{
 pool: sqlx::PgPool,
}}

impl {store_name} {{
 pub fn new(pool: sqlx::PgPool) -> Self {{
 Self {{ pool }}
 }}

{list_helpers}
}}

#[async_trait]
impl ResourceStore for {store_name} {{
 fn resource_name(&self) -> &str {{
 "{resource_name}"
 }}

 async fn find_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
 let row = sqlx::query_as!(
 {record_name},
 r#"
 SELECT
 {select_columns}
 FROM "{table_name}"
 WHERE "{primary_key}" = $1{soft_delete_where}
 "#,
 id
 )
 .fetch_optional(&self.pool)
 .await?
 .ok_or(shaperail_core::ShaperailError::NotFound)?;

 row_from_model(&row)
 }}

 async fn find_all(
 &self,
 endpoint: &EndpointSpec,
 filters: &FilterSet,
 search: Option<&SearchParam>,
 sort: &SortParam,
 page: &PageRequest,
 ) -> Result<(Vec<ResourceRow>, Value), shaperail_core::ShaperailError> {{
{list_dispatch}
 }}

 async fn insert(&self, data: &Map<String, Value>) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{insert_body}
 }}

 async fn update_by_id(
 &self,
 id: &uuid::Uuid,
 data: &Map<String, Value>,
 ) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{update_body}
 }}

 async fn soft_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{soft_delete_body}
 }}

 async fn hard_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{hard_delete_body}
 }}
}}
"###,
        resource_name = resource.resource,
        record_name = context.record_name,
        store_name = context.store_name,
        model_fields = model_fields,
        list_helpers = list_helpers,
        select_columns = context.select_columns,
        // The table is named after the resource.
        table_name = resource.resource,
        primary_key = context.primary_key,
        soft_delete_where = context.soft_delete_where,
        list_dispatch = list_dispatch,
        insert_body = generate_insert_body(&context)?,
        update_body = generate_update_body(&context)?,
        soft_delete_body = generate_soft_delete_body(&context),
        hard_delete_body = generate_hard_delete_body(&context),
    ))
}
169
170fn generate_registry_module(resources: &[ResourceDefinition]) -> String {
171 let module_lines = resources
172 .iter()
173 .map(|resource| format!("pub mod {};", resource.resource))
174 .collect::<Vec<_>>()
175 .join("\n");
176
177 let registry_lines = resources
178 .iter()
179 .map(|resource| {
180 let store_name = format!("{}Store", to_pascal_case(&resource.resource));
181 format!(
182 " stores.insert({name:?}.to_string(), std::sync::Arc::new({module}::{store_name}::new(pool.clone())));",
183 name = resource.resource,
184 module = resource.resource
185 )
186 })
187 .collect::<Vec<_>>()
188 .join("\n");
189
190 format!(
191 r#"{module_lines}
192
193pub fn build_store_registry(pool: sqlx::PgPool) -> shaperail_runtime::db::StoreRegistry {{
194 let mut stores: std::collections::HashMap<
195 String,
196 std::sync::Arc<dyn shaperail_runtime::db::ResourceStore>,
197 > = std::collections::HashMap::new();
198{registry_lines}
199 std::sync::Arc::new(stores)
200}}
201"#
202 )
203}
204
/// A collection (list) endpoint of a resource, paired with the name of the generated store
/// method that serves it.
#[derive(Clone)]
struct CollectionEndpoint<'a> {
    /// Endpoint specification borrowed from the resource definition.
    spec: &'a EndpointSpec,
    /// Name of the generated helper method, `find_all_<sanitized endpoint name>`.
    helper_name: String,
}

/// Per-resource values computed once and shared by all body generators in this module.
struct ResourceContext<'a> {
    /// The resource being generated.
    resource: &'a ResourceDefinition,
    /// Name of the generated model struct (`<PascalCase>Record`).
    record_name: String,
    /// Name of the generated store struct (`<PascalCase>Store`).
    store_name: String,
    /// Primary-key column name; `id` when no schema field is marked primary.
    primary_key: String,
    /// `SELECT` column list with sqlx nullability/type overrides, one column per schema field.
    select_columns: String,
    /// Extra `WHERE` fragment excluding soft-deleted rows, or empty when no endpoint uses
    /// soft delete.
    soft_delete_where: String,
    /// `GET` endpoints whose path does not contain `:id`, in the endpoint map's iteration order.
    collection_endpoints: Vec<CollectionEndpoint<'a>>,
}
220
221impl<'a> ResourceContext<'a> {
222 fn new(resource: &'a ResourceDefinition) -> Result<Self, String> {
223 let primary_key = resource
224 .schema
225 .iter()
226 .find(|(_, field)| field.primary)
227 .map(|(name, _)| name.clone())
228 .unwrap_or_else(|| "id".to_string());
229
230 let select_columns = resource
231 .schema
232 .iter()
233 .map(|(name, field)| select_column_sql(name, field))
234 .collect::<Vec<_>>()
235 .join(",\n ");
236
237 let collection_endpoints = resource
238 .endpoints
239 .as_ref()
240 .map(|endpoints| {
241 endpoints
242 .iter()
243 .filter(|(_, endpoint)| {
244 endpoint.method == HttpMethod::Get && !endpoint.path.contains(":id")
245 })
246 .map(|(name, endpoint)| CollectionEndpoint {
247 spec: endpoint,
248 helper_name: format!("find_all_{}", sanitize_identifier(name)),
249 })
250 .collect::<Vec<_>>()
251 })
252 .unwrap_or_default();
253
254 Ok(Self {
255 resource,
256 record_name: format!("{}Record", to_pascal_case(&resource.resource)),
257 store_name: format!("{}Store", to_pascal_case(&resource.resource)),
258 primary_key,
259 select_columns,
260 soft_delete_where: if has_soft_delete(resource) {
261 " AND \"deleted_at\" IS NULL".to_string()
262 } else {
263 String::new()
264 },
265 collection_endpoints,
266 })
267 }
268}
269
270fn generate_list_helper(
271 context: &ResourceContext<'_>,
272 endpoint: &CollectionEndpoint<'_>,
273) -> Result<String, String> {
274 let filters = endpoint.spec.filters.clone().unwrap_or_default();
275 let search_fields = endpoint.spec.search.clone().unwrap_or_default();
276 let sort_fields = endpoint.spec.sort.clone().unwrap_or_default();
277
278 let filter_decls = filters
279 .iter()
280 .map(|field_name| {
281 let field = context.resource.schema.get(field_name).ok_or_else(|| {
282 format!(
283 "Unknown filter field '{field_name}' on resource '{}'",
284 context.resource.resource
285 )
286 })?;
287 Ok(generate_filter_declaration(field_name, field))
288 })
289 .collect::<Result<Vec<_>, String>>()?
290 .join("\n");
291
292 let filter_args = filters
293 .iter()
294 .map(|field_name| {
295 parameter_expression(
296 field_name,
297 context
298 .resource
299 .schema
300 .get(field_name)
301 .expect("filter field validated"),
302 )
303 })
304 .collect::<Vec<_>>();
305
306 let search_decl = if search_fields.is_empty() {
307 String::new()
308 } else {
309 " let search_term = search.map(|value| value.term.clone());".to_string()
310 };
311
312 let search_predicate = if search_fields.is_empty() {
313 String::new()
314 } else {
315 search_expression(&search_fields)
316 };
317
318 let sort_decls = (0..sort_fields.len())
319 .map(|index| {
320 format!(
321 " let sort_field_{index} = sort_field_at(sort, {index});\n let sort_direction_{index} = sort_direction_at(sort, {index});"
322 )
323 })
324 .collect::<Vec<_>>()
325 .join("\n");
326
327 let filter_positions = filters
328 .iter()
329 .enumerate()
330 .map(|(index, field_name)| (field_name.clone(), index + 1))
331 .collect::<Vec<_>>();
332
333 let search_position = if search_fields.is_empty() {
334 None
335 } else {
336 Some(filter_positions.len() + 1)
337 };
338
339 let cursor_position = filter_positions.len() + usize::from(search_position.is_some()) + 1;
340 let cursor_sort_positions = (0..sort_fields.len())
341 .map(|index| {
342 let base = cursor_position + 1 + (index * 2);
343 (base, base + 1)
344 })
345 .collect::<Vec<_>>();
346 let offset_sort_positions = (0..sort_fields.len())
347 .map(|index| {
348 let base =
349 filter_positions.len() + usize::from(search_position.is_some()) + 1 + (index * 2);
350 (base, base + 1)
351 })
352 .collect::<Vec<_>>();
353
354 let filter_predicates = generate_filter_predicates(context, &filter_positions)?;
355 let filter_clause = filter_predicates.join("\n");
356
357 let cursor_order_by = generate_order_by(context, &sort_fields, &cursor_sort_positions)?;
358 let offset_order_by = generate_order_by(context, &sort_fields, &offset_sort_positions)?;
359
360 let mut cursor_args = filter_args.clone();
361 if search_position.is_some() {
362 cursor_args.push("search_term.as_deref()".to_string());
363 }
364 cursor_args.push("cursor".to_string());
365 for index in 0..sort_fields.len() {
366 cursor_args.push(format!("sort_field_{index}.as_deref()"));
367 cursor_args.push(format!("sort_direction_{index}"));
368 }
369 cursor_args.push("*limit + 1".to_string());
370
371 let mut count_args = filter_args.clone();
372 if search_position.is_some() {
373 count_args.push("search_term.as_deref()".to_string());
374 }
375
376 let mut row_args = count_args.clone();
377 for index in 0..sort_fields.len() {
378 row_args.push(format!("sort_field_{index}.as_deref()"));
379 row_args.push(format!("sort_direction_{index}"));
380 }
381 row_args.push("*limit".to_string());
382 row_args.push("*offset".to_string());
383
384 let cursor_query = generate_cursor_query(
385 context,
386 &filter_clause,
387 search_position.map(|position| (position, search_predicate.as_str())),
388 cursor_position,
389 &cursor_order_by,
390 cursor_args.len(),
391 &cursor_args,
392 );
393 let offset_query = generate_offset_query(
394 context,
395 &filter_clause,
396 search_position.map(|position| (position, search_predicate.as_str())),
397 &offset_order_by,
398 row_args.len() - 1,
399 row_args.len(),
400 &count_args,
401 &row_args,
402 );
403
404 Ok(format!(
405 r###" async fn {helper_name}(
406 &self,
407 filters: &FilterSet,
408 search: Option<&SearchParam>,
409 sort: &SortParam,
410 page: &PageRequest,
411 ) -> Result<(Vec<ResourceRow>, Value), shaperail_core::ShaperailError> {{
412{filter_decls}
413{search_decl}
414{sort_decls}
415
416 match page {{
417 PageRequest::Cursor {{ after, limit }} => {{
418 let cursor = match after {{
419 Some(cursor_value) => Some(uuid::Uuid::parse_str(
420 &shaperail_runtime::db::decode_cursor(cursor_value)?
421 ).map_err(|_| shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {{
422 field: "cursor".to_string(),
423 message: "Invalid cursor value".to_string(),
424 code: "invalid_cursor".to_string(),
425 }}]))?),
426 None => None,
427 }};
428{cursor_query}
429 }}
430 PageRequest::Offset {{ offset, limit }} => {{
431{offset_query}
432 }}
433 }}
434 }}"###,
435 helper_name = endpoint.helper_name,
436 filter_decls = indent_block(&filter_decls, 2),
437 search_decl = indent_block(&search_decl, 2),
438 sort_decls = indent_block(&sort_decls, 2),
439 cursor_query = indent_block(&cursor_query, 4),
440 offset_query = indent_block(&offset_query, 4),
441 ))
442}
443
444fn generate_cursor_query(
445 context: &ResourceContext<'_>,
446 filter_clause: &str,
447 search_position: Option<(usize, &str)>,
448 cursor_position: usize,
449 order_by: &str,
450 limit_position: usize,
451 args: &[String],
452) -> String {
453 format!(
454 r###" let rows = sqlx::query_as!(
455 {record_name},
456 r#"
457 SELECT
458 {select_columns}
459 FROM "{table_name}"
460 WHERE TRUE
461{soft_delete_clause}
462{filter_clause}{search_clause}
463 AND (${cursor_position}::uuid IS NULL OR "{primary_key}" > ${cursor_position})
464 ORDER BY
465{order_by}
466 LIMIT ${limit_position}
467 "#,
468 {args}
469 )
470 .fetch_all(&self.pool)
471 .await?;
472
473 let has_more = rows.len() as i64 > *limit;
474 let mut result_rows = rows;
475 if has_more {{
476 result_rows.truncate(*limit as usize);
477 }}
478
479 let data = result_rows
480 .iter()
481 .map(row_from_model)
482 .collect::<Result<Vec<_>, _>>()?;
483 let cursor = if has_more {{
484 result_rows
485 .last()
486 .map(|row| shaperail_runtime::db::encode_cursor(&row.{primary_key}.to_string()))
487 }} else {{
488 None
489 }};
490
491 Ok((
492 data,
493 serde_json::json!({{
494 "cursor": cursor,
495 "has_more": has_more
496 }})
497 ))"###,
498 record_name = context.record_name,
499 select_columns = context.select_columns,
500 table_name = context.resource.resource,
501 soft_delete_clause = if has_soft_delete(context.resource) {
502 " AND \"deleted_at\" IS NULL\n"
503 } else {
504 ""
505 },
506 filter_clause = if filter_clause.is_empty() {
507 String::new()
508 } else {
509 format!("{filter_clause}\n")
510 },
511 search_clause = search_position
512 .map(|(position, expression)| {
513 format!(
514 "\n AND (${position}::text IS NULL OR to_tsvector('english', {expression}) @@ plainto_tsquery('english', ${position}))"
515 )
516 })
517 .unwrap_or_default(),
518 cursor_position = cursor_position,
519 primary_key = context.primary_key,
520 order_by = order_by,
521 limit_position = limit_position,
522 args = args.join(",\n "),
523 )
524}
525
526#[allow(clippy::too_many_arguments)]
527fn generate_offset_query(
528 context: &ResourceContext<'_>,
529 filter_clause: &str,
530 search_position: Option<(usize, &str)>,
531 order_by: &str,
532 limit_position: usize,
533 offset_position: usize,
534 count_args: &[String],
535 row_args: &[String],
536) -> String {
537 let count_macro_args = if count_args.is_empty() {
538 String::new()
539 } else {
540 format!(
541 ",\n {}",
542 count_args.join(",\n ")
543 )
544 };
545
546 format!(
547 r###" let total = sqlx::query_scalar!(
548 r#"
549 SELECT COUNT(*) as "count!"
550 FROM "{table_name}"
551 WHERE TRUE
552{soft_delete_clause}
553{filter_clause}{search_clause}
554 "#{count_macro_args}
555 )
556 .fetch_one(&self.pool)
557 .await?;
558
559 let rows = sqlx::query_as!(
560 {record_name},
561 r#"
562 SELECT
563 {select_columns}
564 FROM "{table_name}"
565 WHERE TRUE
566{soft_delete_clause}
567{filter_clause}{search_clause}
568 ORDER BY
569{order_by}
570 LIMIT ${limit_param}
571 OFFSET ${offset_param}
572 "#,
573 {row_args}
574 )
575 .fetch_all(&self.pool)
576 .await?;
577
578 let data = rows
579 .iter()
580 .map(row_from_model)
581 .collect::<Result<Vec<_>, _>>()?;
582
583 Ok((
584 data,
585 serde_json::json!({{
586 "offset": offset,
587 "limit": limit,
588 "total": total
589 }})
590 ))"###,
591 table_name = context.resource.resource,
592 soft_delete_clause = if has_soft_delete(context.resource) {
593 " AND \"deleted_at\" IS NULL\n"
594 } else {
595 ""
596 },
597 filter_clause = if filter_clause.is_empty() {
598 String::new()
599 } else {
600 format!("{filter_clause}\n")
601 },
602 search_clause = search_position
603 .map(|(position, expression)| {
604 format!(
605 "\n AND (${position}::text IS NULL OR to_tsvector('english', {expression}) @@ plainto_tsquery('english', ${position}))"
606 )
607 })
608 .unwrap_or_default(),
609 count_macro_args = count_macro_args,
610 record_name = context.record_name,
611 select_columns = context.select_columns,
612 order_by = order_by,
613 limit_param = limit_position,
614 offset_param = offset_position,
615 row_args = row_args.join(",\n "),
616 )
617}
618
619fn generate_filter_predicates(
620 context: &ResourceContext<'_>,
621 positions: &[(String, usize)],
622) -> Result<Vec<String>, String> {
623 positions
624 .iter()
625 .map(|(field_name, position)| {
626 let field = context.resource.schema.get(field_name).ok_or_else(|| {
627 format!(
628 "Unknown filter field '{field_name}' on resource '{}'",
629 context.resource.resource
630 )
631 })?;
632 Ok(format!(
633 " AND (${position}::{cast} IS NULL OR \"{field_name}\" = ${position})",
634 cast = sql_cast_type(field)
635 ))
636 })
637 .collect()
638}
639
640fn generate_order_by(
641 context: &ResourceContext<'_>,
642 sort_fields: &[String],
643 positions: &[(usize, usize)],
644) -> Result<String, String> {
645 if sort_fields.is_empty() {
646 return Ok(format!("\"{}\" ASC", context.primary_key));
647 }
648
649 let mut clauses = Vec::new();
650 for ((field_param, direction_param), field_name) in positions.iter().zip(sort_fields) {
651 for candidate in sort_fields {
652 let field = context.resource.schema.get(candidate).ok_or_else(|| {
653 format!(
654 "Unknown sort field '{candidate}' on resource '{}'",
655 context.resource.resource
656 )
657 })?;
658 let sort_expr = sortable_expression(candidate, field);
659 clauses.push(format!(
660 " CASE WHEN ${field_param}::text = '{candidate}' AND ${direction_param}::text = 'asc' THEN {sort_expr} END ASC"
661 ));
662 clauses.push(format!(
663 " CASE WHEN ${field_param}::text = '{candidate}' AND ${direction_param}::text = 'desc' THEN {sort_expr} END DESC"
664 ));
665 }
666 let _ = field_name;
667 }
668 clauses.push(format!(" \"{}\" ASC", context.primary_key));
669 Ok(clauses.join(",\n"))
670}
671
672fn generate_insert_body(context: &ResourceContext<'_>) -> Result<String, String> {
673 let mut declarations = Vec::new();
674 let mut columns = Vec::new();
675 let mut values = Vec::new();
676 let mut args = Vec::new();
677
678 for (index, (field_name, field)) in context.resource.schema.iter().enumerate() {
679 let variable_name = sanitize_identifier(field_name);
680 declarations.push(generate_insert_declaration(
681 field_name,
682 field,
683 &variable_name,
684 )?);
685 columns.push(format!("\"{field_name}\""));
686 values.push(format!("${}", index + 1));
687 args.push(variable_name);
688 }
689
690 Ok(format!(
691 r###"{declarations}
692 let row = sqlx::query_as!(
693 {record_name},
694 r#"
695 INSERT INTO "{table_name}" ({columns})
696 VALUES ({values})
697 RETURNING
698 {select_columns}
699 "#,
700 {args}
701 )
702 .fetch_one(&self.pool)
703 .await?;
704
705 row_from_model(&row)"###,
706 declarations = declarations.join("\n"),
707 record_name = context.record_name,
708 table_name = context.resource.resource,
709 columns = columns.join(", "),
710 values = values.join(", "),
711 select_columns = context.select_columns,
712 args = args.join(",\n "),
713 ))
714}
715
716fn generate_update_body(context: &ResourceContext<'_>) -> Result<String, String> {
717 let mut declarations = Vec::new();
718 let mut set_clauses = Vec::new();
719 let mut args = vec!["id".to_string()];
720 let mut has_mutable_fields = Vec::new();
721 let mut index = 2usize;
722
723 for (field_name, field) in &context.resource.schema {
724 if field.primary || field.generated {
725 continue;
726 }
727
728 let present_name = format!("{}_present", sanitize_identifier(field_name));
729 let value_name = sanitize_identifier(field_name);
730 declarations.push(generate_update_declaration(
731 field_name,
732 field,
733 &present_name,
734 &value_name,
735 ));
736 has_mutable_fields.push(present_name.clone());
737 set_clauses.push(format!(
738 "\"{field_name}\" = CASE WHEN ${present_param} THEN ${value_param} ELSE \"{field_name}\" END",
739 present_param = index,
740 value_param = index + 1
741 ));
742 args.push(present_name);
743 args.push(value_name);
744 index += 2;
745 }
746
747 if let Some(updated_at) = context.resource.schema.get("updated_at") {
748 if updated_at.generated && updated_at.field_type == FieldType::Timestamp {
749 declarations.push(" let updated_at = chrono::Utc::now();".to_string());
750 set_clauses.push(format!("\"updated_at\" = ${index}"));
751 args.push("updated_at".to_string());
752 }
753 }
754
755 let guard = if has_mutable_fields.is_empty() {
756 String::new()
757 } else {
758 format!(
759 " if !({}) {}",
760 has_mutable_fields.join(" || "),
761 r#"{
762 return Err(shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {
763 field: "body".to_string(),
764 message: "No valid fields to update".to_string(),
765 code: "empty_update".to_string(),
766 }]));
767 }"#
768 )
769 };
770
771 Ok(format!(
772 r###"{declarations}
773{guard}
774 let row = sqlx::query_as!(
775 {record_name},
776 r#"
777 UPDATE "{table_name}"
778 SET {set_clauses}
779 WHERE "{primary_key}" = $1{soft_delete_where}
780 RETURNING
781 {select_columns}
782 "#,
783 {args}
784 )
785 .fetch_optional(&self.pool)
786 .await?
787 .ok_or(shaperail_core::ShaperailError::NotFound)?;
788
789 row_from_model(&row)"###,
790 declarations = declarations.join("\n"),
791 guard = guard,
792 record_name = context.record_name,
793 table_name = context.resource.resource,
794 set_clauses = set_clauses.join(", "),
795 primary_key = context.primary_key,
796 soft_delete_where = context.soft_delete_where,
797 select_columns = context.select_columns,
798 args = args.join(",\n "),
799 ))
800}
801
802fn generate_soft_delete_body(context: &ResourceContext<'_>) -> String {
803 format!(
804 r###" let deleted_at = chrono::Utc::now();
805 let row = sqlx::query_as!(
806 {record_name},
807 r#"
808 UPDATE "{table_name}"
809 SET "deleted_at" = $2
810 WHERE "{primary_key}" = $1 AND "deleted_at" IS NULL
811 RETURNING
812 {select_columns}
813 "#,
814 id,
815 deleted_at
816 )
817 .fetch_optional(&self.pool)
818 .await?
819 .ok_or(shaperail_core::ShaperailError::NotFound)?;
820
821 row_from_model(&row)"###,
822 record_name = context.record_name,
823 table_name = context.resource.resource,
824 primary_key = context.primary_key,
825 select_columns = context.select_columns,
826 )
827}
828
829fn generate_hard_delete_body(context: &ResourceContext<'_>) -> String {
830 format!(
831 r###" let row = sqlx::query_as!(
832 {record_name},
833 r#"
834 DELETE FROM "{table_name}"
835 WHERE "{primary_key}" = $1
836 RETURNING
837 {select_columns}
838 "#,
839 id
840 )
841 .fetch_optional(&self.pool)
842 .await?
843 .ok_or(shaperail_core::ShaperailError::NotFound)?;
844
845 row_from_model(&row)"###,
846 record_name = context.record_name,
847 table_name = context.resource.resource,
848 primary_key = context.primary_key,
849 select_columns = context.select_columns,
850 )
851}
852
853fn generate_insert_declaration(
854 field_name: &str,
855 field: &FieldSchema,
856 variable_name: &str,
857) -> Result<String, String> {
858 if field.generated {
859 return Ok(format!(
860 " let {variable_name} = {};",
861 generated_value_expression(field)
862 ));
863 }
864
865 let parse_type = parse_type(field);
866 let parsed = format!(
867 "shaperail_runtime::db::parse_optional_json::<{parse_type}>(data, {field_name:?})?"
868 );
869
870 let expression = match (field_is_required(field), field.default.as_ref()) {
871 (true, Some(default)) => format!(
872 "match {parsed} {{ Some(value) => value, None => {} }}",
873 default_expression(field_name, field, default)?
874 ),
875 (true, None) => format!("shaperail_runtime::db::require_field({parsed}, {field_name:?})?"),
876 (false, Some(default)) if model_field_is_optional(field) => format!(
877 "match {parsed} {{ Some(value) => Some(value), None => Some({}) }}",
878 default_expression(field_name, field, default)?
879 ),
880 (false, Some(default)) => format!(
881 "match {parsed} {{ Some(value) => value, None => {} }}",
882 default_expression(field_name, field, default)?
883 ),
884 (false, None) => parsed,
885 };
886
887 Ok(format!(" let {variable_name} = {expression};"))
888}
889
890fn generate_update_declaration(
891 field_name: &str,
892 field: &FieldSchema,
893 present_name: &str,
894 value_name: &str,
895) -> String {
896 format!(
897 " let {present_name} = data.contains_key({field_name:?});\n let {value_name} = shaperail_runtime::db::parse_optional_json::<{parse_type}>(data, {field_name:?})?;",
898 parse_type = parse_type(field)
899 )
900}
901
902fn generate_filter_declaration(field_name: &str, field: &FieldSchema) -> String {
903 let parser = match field.field_type {
904 FieldType::Uuid => "uuid::Uuid::parse_str(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid uuid filter\".to_string()))",
905 FieldType::String | FieldType::Enum | FieldType::File => "Ok(text.to_string())",
906 FieldType::Integer => "text.parse::<i32>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid integer filter\".to_string()))",
907 FieldType::Bigint => "text.parse::<i64>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid bigint filter\".to_string()))",
908 FieldType::Number => "text.parse::<f64>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid number filter\".to_string()))",
909 FieldType::Boolean => "text.parse::<bool>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid boolean filter\".to_string()))",
910 FieldType::Timestamp => "chrono::DateTime::parse_from_rfc3339(text).map(|value| value.with_timezone(&chrono::Utc)).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid timestamp filter\".to_string()))",
911 FieldType::Date => "chrono::NaiveDate::parse_from_str(text, \"%Y-%m-%d\").map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid date filter\".to_string()))",
912 FieldType::Json => "serde_json::from_str::<serde_json::Value>(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid json filter\".to_string()))",
913 FieldType::Array => "serde_json::from_str::<Vec<serde_json::Value>>(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid array filter\".to_string()))",
914 };
915
916 format!(
917 " let {var} = parse_filter(filters, {field_name:?}, \"invalid_filter\", |text| {parser})?;",
918 var = field_parameter_name(field_name)
919 )
920}
921
922fn field_parameter_name(field_name: &str) -> String {
923 format!("filter_{}", sanitize_identifier(field_name))
924}
925
926fn parameter_expression(field_name: &str, field: &FieldSchema) -> String {
927 let var = field_parameter_name(field_name);
928 match field.field_type {
929 FieldType::String | FieldType::Enum | FieldType::File => format!("{var}.as_deref()"),
930 _ => var,
931 }
932}
933
934fn select_column_sql(field_name: &str, field: &FieldSchema) -> String {
935 let nullability = if model_field_is_optional(field) {
936 "?"
937 } else {
938 "!"
939 };
940 let expression = match field.field_type {
941 FieldType::Number => format!("\"{field_name}\"::DOUBLE PRECISION"),
942 _ => format!("\"{field_name}\""),
943 };
944 format!(
945 "{expression} as \"{field_name}{nullability}: {type_name}\"",
946 type_name = query_type(field)
947 )
948}
949
950fn sortable_expression(field_name: &str, field: &FieldSchema) -> String {
951 match field.field_type {
952 FieldType::Json | FieldType::Array | FieldType::Uuid => format!("\"{field_name}\"::text"),
953 FieldType::Number => format!("\"{field_name}\"::DOUBLE PRECISION"),
954 _ => format!("\"{field_name}\""),
955 }
956}
957
/// Concatenates the given columns, each coerced to text with NULL treated as empty and
/// separated by a space, into one SQL expression for full-text search.
fn search_expression(fields: &[String]) -> String {
    let mut sql = String::new();
    for (position, column) in fields.iter().enumerate() {
        if position > 0 {
            sql.push_str(" || ' ' || ");
        }
        sql.push_str(&format!("COALESCE(\"{column}\"::text, '')"));
    }
    sql
}
965
966fn sql_cast_type(field: &FieldSchema) -> String {
967 match field.field_type {
968 FieldType::Uuid => "uuid".to_string(),
969 FieldType::String | FieldType::Enum | FieldType::File => "text".to_string(),
970 FieldType::Integer => "integer".to_string(),
971 FieldType::Bigint => "bigint".to_string(),
972 FieldType::Number => "double precision".to_string(),
973 FieldType::Boolean => "boolean".to_string(),
974 FieldType::Timestamp => "timestamptz".to_string(),
975 FieldType::Date => "date".to_string(),
976 FieldType::Json => "jsonb".to_string(),
977 FieldType::Array => match field.items.as_deref() {
978 Some("uuid") => "uuid[]".to_string(),
979 Some("integer") => "integer[]".to_string(),
980 Some("bigint") => "bigint[]".to_string(),
981 Some("number") => "double precision[]".to_string(),
982 Some("boolean") => "boolean[]".to_string(),
983 _ => "text[]".to_string(),
984 },
985 }
986}
987
988fn query_type(field: &FieldSchema) -> String {
989 match field.field_type {
990 FieldType::Uuid => "uuid::Uuid".to_string(),
991 FieldType::String | FieldType::Enum | FieldType::File => "String".to_string(),
992 FieldType::Integer => "i32".to_string(),
993 FieldType::Bigint => "i64".to_string(),
994 FieldType::Number => "f64".to_string(),
995 FieldType::Boolean => "bool".to_string(),
996 FieldType::Timestamp => "chrono::DateTime<chrono::Utc>".to_string(),
997 FieldType::Date => "chrono::NaiveDate".to_string(),
998 FieldType::Json => "serde_json::Value".to_string(),
999 FieldType::Array => match field.items.as_deref() {
1000 Some("uuid") => "Vec<uuid::Uuid>".to_string(),
1001 Some("integer") => "Vec<i32>".to_string(),
1002 Some("bigint") => "Vec<i64>".to_string(),
1003 Some("number") => "Vec<f64>".to_string(),
1004 Some("boolean") => "Vec<bool>".to_string(),
1005 Some("timestamp") => "Vec<chrono::DateTime<chrono::Utc>>".to_string(),
1006 Some("date") => "Vec<chrono::NaiveDate>".to_string(),
1007 _ => "Vec<String>".to_string(),
1008 },
1009 }
1010}
1011
/// Rust type that request-body JSON for `field` is deserialized into.
///
/// Delegates to `query_type`, so parsed values have exactly the type bound into queries.
fn parse_type(field: &FieldSchema) -> String {
    query_type(field)
}
1015
1016fn model_field_type(field: &FieldSchema) -> String {
1017 let base = query_type(field);
1018 if model_field_is_optional(field) {
1019 format!("Option<{base}>")
1020 } else {
1021 base
1022 }
1023}
1024
1025fn model_field_is_optional(field: &FieldSchema) -> bool {
1026 !(field.primary || (field.required && !field.nullable))
1027}
1028
1029fn field_is_required(field: &FieldSchema) -> bool {
1030 field.primary || (field.required && !field.nullable)
1031}
1032
1033fn generated_value_expression(field: &FieldSchema) -> String {
1034 match field.field_type {
1035 FieldType::Uuid => "uuid::Uuid::new_v4()".to_string(),
1036 FieldType::Timestamp => {
1037 if model_field_is_optional(field) {
1038 "Some(chrono::Utc::now())".to_string()
1039 } else {
1040 "chrono::Utc::now()".to_string()
1041 }
1042 }
1043 FieldType::Date => {
1044 if model_field_is_optional(field) {
1045 "Some(chrono::Utc::now().date_naive())".to_string()
1046 } else {
1047 "chrono::Utc::now().date_naive()".to_string()
1048 }
1049 }
1050 _ => "Default::default()".to_string(),
1051 }
1052}
1053
1054fn default_expression(
1055 field_name: &str,
1056 field: &FieldSchema,
1057 default: &serde_json::Value,
1058) -> Result<String, String> {
1059 Ok(match field.field_type {
1060 FieldType::Uuid => format!(
1061 "parse_embedded_json::<uuid::Uuid>({field_name:?}, serde_json::json!({default}))?"
1062 ),
1063 FieldType::String | FieldType::Enum | FieldType::File => {
1064 let value = default
1065 .as_str()
1066 .ok_or_else(|| format!("Default for '{field_name}' must be a string"))?;
1067 format!("{value:?}.to_string()")
1068 }
1069 FieldType::Integer => format!(
1070 "parse_embedded_json::<i32>({field_name:?}, serde_json::json!({default}))?"
1071 ),
1072 FieldType::Bigint => format!(
1073 "parse_embedded_json::<i64>({field_name:?}, serde_json::json!({default}))?"
1074 ),
1075 FieldType::Number => format!(
1076 "parse_embedded_json::<f64>({field_name:?}, serde_json::json!({default}))?"
1077 ),
1078 FieldType::Boolean => default
1079 .as_bool()
1080 .ok_or_else(|| format!("Default for '{field_name}' must be a boolean"))?
1081 .to_string(),
1082 FieldType::Timestamp => format!(
1083 "parse_embedded_json::<chrono::DateTime<chrono::Utc>>({field_name:?}, serde_json::json!({default}))?"
1084 ),
1085 FieldType::Date => format!(
1086 "parse_embedded_json::<chrono::NaiveDate>({field_name:?}, serde_json::json!({default}))?"
1087 ),
1088 FieldType::Json => format!("serde_json::json!({default})"),
1089 FieldType::Array => format!(
1090 "parse_embedded_json::<{}>({field_name:?}, serde_json::json!({default}))?",
1091 query_type(field)
1092 ),
1093 })
1094}
1095
1096fn has_soft_delete(resource: &ResourceDefinition) -> bool {
1097 resource
1098 .endpoints
1099 .as_ref()
1100 .map(|endpoints| endpoints.values().any(|endpoint| endpoint.soft_delete))
1101 .unwrap_or(false)
1102}
1103
/// Maps an arbitrary string onto identifier-safe characters.
///
/// Each ASCII letter or digit is lower-cased. Every other character, including non-ASCII
/// ones, becomes a single `_`. A leading `_` is prepended when the result would otherwise
/// start with a digit. Rust keywords and the empty string are not special-cased.
fn sanitize_identifier(value: &str) -> String {
    let mut ident: String = value
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();

    if ident.starts_with(|c: char| c.is_ascii_digit()) {
        ident.insert(0, '_');
    }

    ident
}
1120
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The input is split on `_`. In each non-empty segment the first character is
/// upper-cased (Unicode-aware, so it may expand to several characters) and the remainder
/// is copied unchanged. Empty segments from leading, trailing or repeated underscores
/// contribute nothing.
fn to_pascal_case(value: &str) -> String {
    let mut pascal = String::with_capacity(value.len());
    for word in value.split('_') {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}
1139
/// Prefixes every non-empty line of `block` with `indent` spaces.
///
/// Lines are split with `str::lines` and rejoined with `\n`, so a trailing newline is not
/// preserved. Empty lines stay empty, without trailing whitespace. Lines that contain only
/// whitespace still receive the prefix. A block that is blank after trimming yields an
/// empty string.
fn indent_block(block: &str, indent: usize) -> String {
    if block.trim().is_empty() {
        return String::new();
    }

    let pad = " ".repeat(indent);
    let mut indented = String::with_capacity(block.len());
    for (index, line) in block.lines().enumerate() {
        if index > 0 {
            indented.push('\n');
        }
        if !line.is_empty() {
            indented.push_str(&pad);
            indented.push_str(line);
        }
    }
    indented
}
1158
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use shaperail_core::{
        AuthRule, EndpointSpec, FieldSchema, HttpMethod, PaginationStyle, ResourceDefinition,
    };

    /// Builds a minimal `users` resource. It has a generated UUID primary key `id`, a
    /// required, unique, searchable `email`, and a generated `created_at` timestamp. Its
    /// single public `list` endpoint is `GET /users` with cursor pagination.
    fn sample_resource() -> ResourceDefinition {
        let mut schema = IndexMap::new();
        schema.insert(
            "id".to_string(),
            FieldSchema {
                field_type: FieldType::Uuid,
                primary: true,
                generated: true,
                required: false,
                unique: false,
                nullable: false,
                reference: None,
                min: None,
                max: None,
                format: None,
                values: None,
                default: None,
                sensitive: false,
                search: false,
                items: None,
            },
        );
        schema.insert(
            "email".to_string(),
            FieldSchema {
                field_type: FieldType::String,
                primary: false,
                generated: false,
                required: true,
                unique: true,
                nullable: false,
                reference: None,
                min: None,
                max: None,
                format: None,
                values: None,
                default: None,
                sensitive: false,
                search: true,
                items: None,
            },
        );
        schema.insert(
            "created_at".to_string(),
            FieldSchema {
                field_type: FieldType::Timestamp,
                primary: false,
                generated: true,
                required: false,
                unique: false,
                nullable: false,
                reference: None,
                min: None,
                max: None,
                format: None,
                values: None,
                default: None,
                sensitive: false,
                search: false,
                items: None,
            },
        );

        let mut endpoints = IndexMap::new();
        endpoints.insert(
            "list".to_string(),
            EndpointSpec {
                method: HttpMethod::Get,
                path: "/users".to_string(),
                auth: Some(AuthRule::Public),
                input: None,
                filters: Some(vec!["email".to_string()]),
                search: Some(vec!["email".to_string()]),
                pagination: Some(PaginationStyle::Cursor),
                sort: Some(vec!["created_at".to_string()]),
                cache: None,
                controller: None,
                events: None,
                jobs: None,
                upload: None,
                soft_delete: false,
            },
        );

        ResourceDefinition {
            resource: "users".to_string(),
            version: 1,
            db: None,
            schema,
            endpoints: Some(endpoints),
            relations: None,
            indexes: None,
        }
    }

    #[test]
    fn generates_query_as_store_module() {
        let resource = sample_resource();
        let code = generate_resource_module(&resource).unwrap();

        assert!(code.contains("impl ResourceStore for UsersStore"));
        assert!(code.contains("sqlx::query_as!"));
        assert!(code.contains("find_all_list"));
    }

    #[test]
    fn generates_registry_module() {
        let resource = sample_resource();
        let project = generate_project(&[resource]).unwrap();

        assert!(project.mod_rs.contains("pub mod users;"));
        assert!(project.mod_rs.contains("build_store_registry"));
    }

    #[test]
    fn optional_model_fields_are_exactly_the_non_required_ones() {
        let resource = sample_resource();
        for (name, field) in &resource.schema {
            assert_eq!(
                model_field_is_optional(field),
                !field_is_required(field),
                "predicates disagree for field '{name}'"
            );
        }

        let id = resource.schema.get("id").expect("sample defines id");
        let email = resource.schema.get("email").expect("sample defines email");
        let created_at = resource
            .schema
            .get("created_at")
            .expect("sample defines created_at");

        // Primary keys are required even though `required` is false.
        assert!(field_is_required(id));
        assert!(field_is_required(email));
        assert!(model_field_is_optional(created_at));
    }

    #[test]
    fn sanitizes_identifiers() {
        assert_eq!(sanitize_identifier("Foo-Bar"), "foo_bar");
        assert_eq!(sanitize_identifier("/users/{id}"), "_users__id_");
        assert_eq!(sanitize_identifier("1st"), "_1st");
        assert_eq!(sanitize_identifier(""), "");
    }

    #[test]
    fn converts_snake_case_to_pascal_case() {
        assert_eq!(to_pascal_case("users"), "Users");
        assert_eq!(to_pascal_case("user_profiles"), "UserProfiles");
        assert_eq!(to_pascal_case("__a__b_"), "AB");
        assert_eq!(to_pascal_case(""), "");
    }

    #[test]
    fn indents_only_non_empty_lines() {
        assert_eq!(indent_block("a\n\nb", 2), "  a\n\n  b");
        assert_eq!(indent_block("x\n", 1), " x");
        assert_eq!(indent_block("  \n\t", 4), "");
    }
}