Skip to main content

shaperail_codegen/
rust.rs

1use shaperail_core::{EndpointSpec, FieldSchema, FieldType, HttpMethod, ResourceDefinition};
2
/// One generated Rust source file produced for a single resource.
pub struct GeneratedRustModule {
    /// File name of the module; `generate_project` uses `<resource>.rs`.
    pub file_name: String,
    /// Full source text of the generated module.
    pub contents: String,
}
7
8pub struct GeneratedRustProject {
9    pub modules: Vec<GeneratedRustModule>,
10    pub mod_rs: String,
11}
12
13pub fn generate_project(resources: &[ResourceDefinition]) -> Result<GeneratedRustProject, String> {
14    let mut modules = Vec::with_capacity(resources.len());
15    for resource in resources {
16        modules.push(GeneratedRustModule {
17            file_name: format!("{}.rs", resource.resource),
18            contents: generate_resource_module(resource)?,
19        });
20    }
21
22    Ok(GeneratedRustProject {
23        modules,
24        mod_rs: generate_registry_module(resources),
25    })
26}
27
/// Generate the complete query module source for a single resource.
///
/// The emitted module contains:
/// - a `<Pascal>Record` struct with one field per schema entry;
/// - a `<Pascal>Store` holding an `sqlx::PgPool`;
/// - one `find_all_*` helper per collection endpoint;
/// - a `ResourceStore` implementation whose `find_all` dispatches on the endpoint path.
///
/// # Errors
///
/// Returns the error message from building the resource context, from any list
/// helper (e.g. an unknown filter or sort field), or from the insert/update bodies.
pub fn generate_resource_module(resource: &ResourceDefinition) -> Result<String, String> {
    let context = ResourceContext::new(resource)?;

    // One `pub <name>: <type>,` line per schema field, in schema iteration order.
    let model_fields = resource
        .schema
        .iter()
        .map(|(name, field)| format!("    pub {name}: {},", model_field_type(field)))
        .collect::<Vec<_>>()
        .join("\n");

    // Inherent `find_all_*` helper methods, one per collection endpoint.
    let list_helpers = context
        .collection_endpoints
        .iter()
        .map(|endpoint| generate_list_helper(&context, endpoint))
        .collect::<Result<Vec<_>, _>>()?
        .join("\n\n");

    // Body of the generated `find_all`: route by endpoint path to the matching helper.
    // With no collection endpoints, the generated method only returns an Internal error.
    // The `let _ = (...)` line keeps the unused parameters from warning in the generated code.
    let list_dispatch = if context.collection_endpoints.is_empty() {
        "        let _ = (endpoint, filters, search, sort, page);\n        Err(shaperail_core::ShaperailError::Internal(\"No collection endpoints are available for generated list queries\".to_string()))".to_string()
    } else {
        let arms = context
            .collection_endpoints
            .iter()
            .map(|endpoint| {
                format!(
                    "            {path:?} => self.{helper}(filters, search, sort, page).await,",
                    path = endpoint.spec.path(),
                    helper = endpoint.helper_name
                )
            })
            .collect::<Vec<_>>()
            .join("\n");

        format!(
            "        match endpoint.path() {{\n{arms}\n            _ => Err(shaperail_core::ShaperailError::Internal(format!(\"No generated list query for {{}}\", endpoint.path()))),\n        }}"
        )
    };

    // Module template. Doubled braces (`{{` / `}}`) are literal braces in the output;
    // single-brace names are filled from the named arguments below.
    Ok(format!(
        r###"//! Generated query module for the `{resource_name}` resource.
//! DO NOT EDIT — this file is auto-generated by `shaperail generate`.

use serde::{{Deserialize, Serialize}};
use serde_json::{{Map, Value}};
use shaperail_core::EndpointSpec;
#[allow(unused_imports)]
use shaperail_runtime::db::{{
    async_trait, parse_embedded_json, parse_filter, parse_optional_json, require_field,
    row_from_model, sort_direction_at, sort_field_at, FilterSet, PageRequest, ResourceRow,
    ResourceStore, SearchParam, SortParam,
}};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct {record_name} {{
{model_fields}
}}

pub struct {store_name} {{
    pool: sqlx::PgPool,
}}

impl {store_name} {{
    pub fn new(pool: sqlx::PgPool) -> Self {{
        Self {{ pool }}
    }}

{list_helpers}
}}

#[async_trait]
impl ResourceStore for {store_name} {{
    fn resource_name(&self) -> &str {{
        "{resource_name}"
    }}

    async fn find_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
        let row = sqlx::query_as!(
            {record_name},
            r#"
            SELECT
                {select_columns}
            FROM "{table_name}"
            WHERE "{primary_key}" = $1{soft_delete_where}
            "#,
            id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(shaperail_core::ShaperailError::NotFound)?;

        row_from_model(&row)
    }}

    async fn find_all(
        &self,
        endpoint: &EndpointSpec,
        filters: &FilterSet,
        search: Option<&SearchParam>,
        sort: &SortParam,
        page: &PageRequest,
    ) -> Result<(Vec<ResourceRow>, Value), shaperail_core::ShaperailError> {{
{list_dispatch}
    }}

    async fn insert(&self, data: &Map<String, Value>) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{insert_body}
    }}

    async fn update_by_id(
        &self,
        id: &uuid::Uuid,
        data: &Map<String, Value>,
    ) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{update_body}
    }}

    async fn soft_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{soft_delete_body}
    }}

    async fn hard_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{hard_delete_body}
    }}
}}
"###,
        resource_name = resource.resource,
        record_name = context.record_name,
        store_name = context.store_name,
        model_fields = model_fields,
        list_helpers = list_helpers,
        select_columns = context.select_columns,
        table_name = resource.resource,
        primary_key = context.primary_key,
        soft_delete_where = context.soft_delete_where,
        list_dispatch = list_dispatch,
        insert_body = generate_insert_body(&context)?,
        update_body = generate_update_body(&context)?,
        soft_delete_body = generate_soft_delete_body(&context),
        hard_delete_body = generate_hard_delete_body(&context),
    ))
}
169
170fn generate_registry_module(resources: &[ResourceDefinition]) -> String {
171    let module_lines = resources
172        .iter()
173        .map(|resource| format!("pub mod {};", resource.resource))
174        .collect::<Vec<_>>()
175        .join("\n");
176
177    let registry_lines = resources
178        .iter()
179        .map(|resource| {
180            let store_name = format!("{}Store", to_pascal_case(&resource.resource));
181            format!(
182                "    stores.insert({name:?}.to_string(), std::sync::Arc::new({module}::{store_name}::new(pool.clone())));",
183                name = resource.resource,
184                module = resource.resource
185            )
186        })
187        .collect::<Vec<_>>()
188        .join("\n");
189
190    format!(
191        r#"{module_lines}
192
193pub fn build_store_registry(pool: sqlx::PgPool) -> shaperail_runtime::db::StoreRegistry {{
194    let mut stores: std::collections::HashMap<
195        String,
196        std::sync::Arc<dyn shaperail_runtime::db::ResourceStore>,
197    > = std::collections::HashMap::new();
198{registry_lines}
199    std::sync::Arc::new(stores)
200}}
201
202/// Returns an empty controller map. Register custom controller functions here
203/// or populate from `resources/<name>.controller.rs` files.
204pub fn build_controller_map() -> shaperail_runtime::handlers::controller::ControllerMap {{
205    shaperail_runtime::handlers::controller::ControllerMap::new()
206}}
207
208{controller_traits}
209"#,
210        controller_traits = generate_controller_traits(resources)
211    )
212}
213
214/// Generate typed controller trait stubs for all resources that declare controllers.
215///
216/// For each resource with controller declarations, generates:
217/// - An input struct for each endpoint action that has a controller
218/// - A trait with the exact function signatures the controller must implement
219///
220/// This eliminates the #1 source of LLM errors: guessing controller function signatures.
221fn generate_controller_traits(resources: &[ResourceDefinition]) -> String {
222    let mut output = String::new();
223
224    for resource in resources {
225        let endpoints_with_controllers: Vec<_> = resource
226            .endpoints
227            .as_ref()
228            .map(|endpoints| {
229                endpoints
230                    .iter()
231                    .filter(|(_, ep)| ep.controller.is_some())
232                    .collect::<Vec<_>>()
233            })
234            .unwrap_or_default();
235
236        if endpoints_with_controllers.is_empty() {
237            continue;
238        }
239
240        let pascal = to_pascal_case(&resource.resource);
241        let mut trait_methods = Vec::new();
242
243        for (action, ep) in &endpoints_with_controllers {
244            let controller = ep.controller.as_ref().unwrap();
245            let action_pascal = to_pascal_case(action);
246
247            // Determine input fields for this endpoint
248            let input_fields: Vec<_> = ep
249                .input
250                .as_ref()
251                .map(|fields| {
252                    fields
253                        .iter()
254                        .filter_map(|name| {
255                            resource.schema.get(name).map(|field| {
256                                format!("    pub {name}: {},", model_field_type(field))
257                            })
258                        })
259                        .collect()
260                })
261                .unwrap_or_default();
262
263            if !input_fields.is_empty() {
264                output.push_str(&format!(
265                    r#"
266/// Input fields for the {resource_name} {action} endpoint.
267/// Auto-generated from the resource schema — do not edit.
268#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
269pub struct {pascal}{action_pascal}Input {{
270{fields}
271}}
272"#,
273                    resource_name = resource.resource,
274                    fields = input_fields.join("\n"),
275                ));
276            }
277
278            // Generate trait method for before hook
279            if let Some(before) = &controller.before {
280                if !before.starts_with(shaperail_core::WASM_HOOK_PREFIX) {
281                    let input_type = if input_fields.is_empty() {
282                        "serde_json::Value".to_string()
283                    } else {
284                        format!("{pascal}{action_pascal}Input")
285                    };
286                    trait_methods.push(format!(
287                        "    /// Before-hook for the {action} endpoint. Called before the DB operation.\n    async fn {before}(ctx: &shaperail_runtime::handlers::controller::ControllerContext, input: &{input_type}) -> Result<(), shaperail_core::ShaperailError>;"
288                    ));
289                }
290            }
291
292            // Generate trait method for after hook
293            if let Some(after) = &controller.after {
294                if !after.starts_with(shaperail_core::WASM_HOOK_PREFIX) {
295                    trait_methods.push(format!(
296                        "    /// After-hook for the {action} endpoint. Called after the DB operation.\n    async fn {after}(ctx: &shaperail_runtime::handlers::controller::ControllerContext, result: &serde_json::Value) -> Result<serde_json::Value, shaperail_core::ShaperailError>;"
297                    ));
298                }
299            }
300        }
301
302        if !trait_methods.is_empty() {
303            output.push_str(&format!(
304                r#"
305/// Controller trait for the {resource_name} resource.
306/// Implement this trait in `controllers/{resource_name}.controller.rs`.
307/// The compiler will enforce correct signatures — no guessing needed.
308#[async_trait::async_trait]
309pub trait {pascal}Controller {{
310{methods}
311}}
312"#,
313                resource_name = resource.resource,
314                methods = trait_methods.join("\n\n"),
315            ));
316        }
317    }
318
319    output
320}
321
322#[derive(Clone)]
323struct CollectionEndpoint<'a> {
324    spec: &'a EndpointSpec,
325    helper_name: String,
326}
327
328struct ResourceContext<'a> {
329    resource: &'a ResourceDefinition,
330    record_name: String,
331    store_name: String,
332    primary_key: String,
333    select_columns: String,
334    soft_delete_where: String,
335    collection_endpoints: Vec<CollectionEndpoint<'a>>,
336}
337
338impl<'a> ResourceContext<'a> {
339    fn new(resource: &'a ResourceDefinition) -> Result<Self, String> {
340        let primary_key = resource
341            .schema
342            .iter()
343            .find(|(_, field)| field.primary)
344            .map(|(name, _)| name.clone())
345            .unwrap_or_else(|| "id".to_string());
346
347        let select_columns = resource
348            .schema
349            .iter()
350            .map(|(name, field)| select_column_sql(name, field))
351            .collect::<Vec<_>>()
352            .join(",\n                ");
353
354        let collection_endpoints = resource
355            .endpoints
356            .as_ref()
357            .map(|endpoints| {
358                endpoints
359                    .iter()
360                    .filter(|(_, endpoint)| {
361                        *endpoint.method() == HttpMethod::Get && !endpoint.path().contains(":id")
362                    })
363                    .map(|(name, endpoint)| CollectionEndpoint {
364                        spec: endpoint,
365                        helper_name: format!("find_all_{}", sanitize_identifier(name)),
366                    })
367                    .collect::<Vec<_>>()
368            })
369            .unwrap_or_default();
370
371        Ok(Self {
372            resource,
373            record_name: format!("{}Record", to_pascal_case(&resource.resource)),
374            store_name: format!("{}Store", to_pascal_case(&resource.resource)),
375            primary_key,
376            select_columns,
377            soft_delete_where: if has_soft_delete(resource) {
378                " AND \"deleted_at\" IS NULL".to_string()
379            } else {
380                String::new()
381            },
382            collection_endpoints,
383        })
384    }
385}
386
/// Generate the `find_all_<endpoint>` inherent helper for one collection endpoint.
///
/// The helper supports both cursor and offset pagination. SQL parameters are
/// numbered positionally, in this order:
/// - filters: `$1..$F`, one per declared filter field;
/// - search term: `$F+1`, only when the endpoint declares search fields;
/// - cursor mode: the cursor id, then a (field, direction) pair per sort slot, then
///   `*limit + 1` (one extra row so the generated code can detect `has_more`);
/// - offset mode: a (field, direction) pair per sort slot, then limit, then offset.
///   The `COUNT(*)` query binds only the filter and search parameters.
///
/// # Errors
///
/// Returns an error if a filter or sort field is not present in the resource schema.
fn generate_list_helper(
    context: &ResourceContext<'_>,
    endpoint: &CollectionEndpoint<'_>,
) -> Result<String, String> {
    let filters = endpoint.spec.filters.clone().unwrap_or_default();
    let search_fields = endpoint.spec.search.clone().unwrap_or_default();
    let sort_fields = endpoint.spec.sort.clone().unwrap_or_default();

    // Validates every filter field; the `expect` in `filter_args` relies on this pass.
    let filter_decls = filters
        .iter()
        .map(|field_name| {
            let field = context.resource.schema.get(field_name).ok_or_else(|| {
                format!(
                    "Unknown filter field '{field_name}' on resource '{}'",
                    context.resource.resource
                )
            })?;
            Ok(generate_filter_declaration(field_name, field))
        })
        .collect::<Result<Vec<_>, String>>()?
        .join("\n");

    // Rust expressions bound to `$1..$F`, in filter order.
    let filter_args = filters
        .iter()
        .map(|field_name| {
            parameter_expression(
                field_name,
                context
                    .resource
                    .schema
                    .get(field_name)
                    .expect("filter field validated"),
            )
        })
        .collect::<Vec<_>>();

    let search_decl = if search_fields.is_empty() {
        String::new()
    } else {
        "        let search_term = search.map(|value| value.term.clone());".to_string()
    };

    let search_predicate = if search_fields.is_empty() {
        String::new()
    } else {
        search_expression(&search_fields)
    };

    // One `sort_field_<i>` / `sort_direction_<i>` local pair per sort slot.
    let sort_decls = (0..sort_fields.len())
        .map(|index| {
            format!(
                "        let sort_field_{index} = sort_field_at(sort, {index});\n        let sort_direction_{index} = sort_direction_at(sort, {index});"
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    // Filters occupy parameters `$1..$F`.
    let filter_positions = filters
        .iter()
        .enumerate()
        .map(|(index, field_name)| (field_name.clone(), index + 1))
        .collect::<Vec<_>>();

    // The search term, if any, directly follows the filters.
    let search_position = if search_fields.is_empty() {
        None
    } else {
        Some(filter_positions.len() + 1)
    };

    // Cursor mode: cursor id next, then (field, direction) pairs per sort slot.
    let cursor_position = filter_positions.len() + usize::from(search_position.is_some()) + 1;
    let cursor_sort_positions = (0..sort_fields.len())
        .map(|index| {
            let base = cursor_position + 1 + (index * 2);
            (base, base + 1)
        })
        .collect::<Vec<_>>();
    // Offset mode: no cursor, so sort pairs start right after filters/search.
    let offset_sort_positions = (0..sort_fields.len())
        .map(|index| {
            let base =
                filter_positions.len() + usize::from(search_position.is_some()) + 1 + (index * 2);
            (base, base + 1)
        })
        .collect::<Vec<_>>();

    let filter_predicates = generate_filter_predicates(context, &filter_positions)?;
    let filter_clause = filter_predicates.join("\n");

    let cursor_order_by = generate_order_by(context, &sort_fields, &cursor_sort_positions)?;
    let offset_order_by = generate_order_by(context, &sort_fields, &offset_sort_positions)?;

    // The push order below must match the positions computed above.
    let mut cursor_args = filter_args.clone();
    if search_position.is_some() {
        cursor_args.push("search_term.as_deref()".to_string());
    }
    cursor_args.push("cursor".to_string());
    for index in 0..sort_fields.len() {
        cursor_args.push(format!("sort_field_{index}.as_deref()"));
        cursor_args.push(format!("sort_direction_{index}"));
    }
    // One extra row so the generated code can tell whether another page exists.
    cursor_args.push("*limit + 1".to_string());

    let mut count_args = filter_args.clone();
    if search_position.is_some() {
        count_args.push("search_term.as_deref()".to_string());
    }

    let mut row_args = count_args.clone();
    for index in 0..sort_fields.len() {
        row_args.push(format!("sort_field_{index}.as_deref()"));
        row_args.push(format!("sort_direction_{index}"));
    }
    row_args.push("*limit".to_string());
    row_args.push("*offset".to_string());

    // LIMIT is the last cursor argument.
    let cursor_query = generate_cursor_query(
        context,
        &filter_clause,
        search_position.map(|position| (position, search_predicate.as_str())),
        cursor_position,
        &cursor_order_by,
        cursor_args.len(),
        &cursor_args,
    );
    // LIMIT and OFFSET are the last two row arguments.
    let offset_query = generate_offset_query(
        context,
        &filter_clause,
        search_position.map(|position| (position, search_predicate.as_str())),
        &offset_order_by,
        row_args.len() - 1,
        row_args.len(),
        &count_args,
        &row_args,
    );

    // Doubled braces are literal braces in the generated method.
    Ok(format!(
        r###"    async fn {helper_name}(
        &self,
        filters: &FilterSet,
        search: Option<&SearchParam>,
        sort: &SortParam,
        page: &PageRequest,
    ) -> Result<(Vec<ResourceRow>, Value), shaperail_core::ShaperailError> {{
{filter_decls}
{search_decl}
{sort_decls}

        match page {{
            PageRequest::Cursor {{ after, limit }} => {{
                let cursor = match after {{
                    Some(cursor_value) => Some(uuid::Uuid::parse_str(
                        &shaperail_runtime::db::decode_cursor(cursor_value)?
                    ).map_err(|_| shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {{
                        field: "cursor".to_string(),
                        message: "Invalid cursor value".to_string(),
                        code: "invalid_cursor".to_string(),
                    }}]))?),
                    None => None,
                }};
{cursor_query}
            }}
            PageRequest::Offset {{ offset, limit }} => {{
{offset_query}
            }}
        }}
    }}"###,
        helper_name = endpoint.helper_name,
        filter_decls = indent_block(&filter_decls, 2),
        search_decl = indent_block(&search_decl, 2),
        sort_decls = indent_block(&sort_decls, 2),
        cursor_query = indent_block(&cursor_query, 4),
        offset_query = indent_block(&offset_query, 4),
    ))
}
560
/// Render the cursor-pagination branch of a generated list helper.
///
/// The generated code:
/// - selects rows whose primary key is greater than the bound cursor (all rows when
///   the cursor is `NULL`), ordered by `order_by`, with `$limit_position` as the SQL `LIMIT`;
/// - reports `has_more` when more than `*limit` rows come back, then truncates to `*limit`;
/// - encodes the last returned row's primary key as the next cursor.
///
/// `search_position`, when present, is the search term's parameter index and the
/// SQL expression matched against it with `to_tsvector` / `plainto_tsquery`.
/// `args` are the Rust expressions bound to `$1..$n`, in order.
///
/// NOTE(review): the keyset predicate compares only the primary key, while
/// `order_by` may sort by other fields first. With a non-primary-key sort, rows
/// could be skipped or repeated across pages — confirm whether sorting is meant to
/// be combined with cursor pagination.
///
/// NOTE(review): the `::uuid` cast assumes a UUID primary key — confirm for
/// resources with other key types.
fn generate_cursor_query(
    context: &ResourceContext<'_>,
    filter_clause: &str,
    search_position: Option<(usize, &str)>,
    cursor_position: usize,
    order_by: &str,
    limit_position: usize,
    args: &[String],
) -> String {
    format!(
        r###"                let rows = sqlx::query_as!(
                    {record_name},
                    r#"
                    SELECT
                        {select_columns}
                    FROM "{table_name}"
                    WHERE TRUE
{soft_delete_clause}
{filter_clause}{search_clause}
                        AND (${cursor_position}::uuid IS NULL OR "{primary_key}" > ${cursor_position})
                    ORDER BY
{order_by}
                    LIMIT ${limit_position}
                    "#,
                    {args}
                )
                .fetch_all(&self.pool)
                .await?;

                let has_more = rows.len() as i64 > *limit;
                let mut result_rows = rows;
                if has_more {{
                    result_rows.truncate(*limit as usize);
                }}

                let data = result_rows
                    .iter()
                    .map(row_from_model)
                    .collect::<Result<Vec<_>, _>>()?;
                let cursor = if has_more {{
                    result_rows
                        .last()
                        .map(|row| shaperail_runtime::db::encode_cursor(&row.{primary_key}.to_string()))
                }} else {{
                    None
                }};

                Ok((
                    data,
                    serde_json::json!({{
                        "cursor": cursor,
                        "has_more": has_more
                    }})
                ))"###,
        record_name = context.record_name,
        select_columns = context.select_columns,
        table_name = context.resource.resource,
        // Literal predicate excluding soft-deleted rows (no bound parameter).
        soft_delete_clause = if has_soft_delete(context.resource) {
            "                        AND \"deleted_at\" IS NULL\n"
        } else {
            ""
        },
        filter_clause = if filter_clause.is_empty() {
            String::new()
        } else {
            format!("{filter_clause}\n")
        },
        // A NULL search term disables the full-text predicate.
        search_clause = search_position
            .map(|(position, expression)| {
                format!(
                    "\n                        AND (${position}::text IS NULL OR to_tsvector('english', {expression}) @@ plainto_tsquery('english', ${position}))"
                )
            })
            .unwrap_or_default(),
        cursor_position = cursor_position,
        primary_key = context.primary_key,
        order_by = order_by,
        limit_position = limit_position,
        args = args.join(",\n                    "),
    )
}
642
/// Render the offset-pagination branch of a generated list helper.
///
/// The generated code runs two queries over the same soft-delete, filter and
/// search predicates:
/// 1. a `COUNT(*)` bound with `count_args`;
/// 2. the page query bound with `row_args`, ordered by `order_by` and paged with
///    `LIMIT $limit_position` / `OFFSET $offset_position`.
///
/// It returns the rows plus a JSON object with `offset`, `limit` and `total`.
// Eight parameters: the template needs every query fragment and parameter position.
#[allow(clippy::too_many_arguments)]
fn generate_offset_query(
    context: &ResourceContext<'_>,
    filter_clause: &str,
    search_position: Option<(usize, &str)>,
    order_by: &str,
    limit_position: usize,
    offset_position: usize,
    count_args: &[String],
    row_args: &[String],
) -> String {
    // `query_scalar!` takes no trailing comma when there are no bind arguments.
    let count_macro_args = if count_args.is_empty() {
        String::new()
    } else {
        format!(
            ",\n                    {}",
            count_args.join(",\n                    ")
        )
    };

    format!(
        r###"                let total = sqlx::query_scalar!(
                    r#"
                    SELECT COUNT(*) as "count!"
                    FROM "{table_name}"
                    WHERE TRUE
{soft_delete_clause}
{filter_clause}{search_clause}
                    "#{count_macro_args}
                )
                .fetch_one(&self.pool)
                .await?;

                let rows = sqlx::query_as!(
                    {record_name},
                    r#"
                    SELECT
                        {select_columns}
                    FROM "{table_name}"
                    WHERE TRUE
{soft_delete_clause}
{filter_clause}{search_clause}
                    ORDER BY
{order_by}
                    LIMIT ${limit_param}
                    OFFSET ${offset_param}
                    "#,
                    {row_args}
                )
                .fetch_all(&self.pool)
                .await?;

                let data = rows
                    .iter()
                    .map(row_from_model)
                    .collect::<Result<Vec<_>, _>>()?;

                Ok((
                    data,
                    serde_json::json!({{
                        "offset": offset,
                        "limit": limit,
                        "total": total
                    }})
                ))"###,
        table_name = context.resource.resource,
        // Literal predicate excluding soft-deleted rows (no bound parameter).
        soft_delete_clause = if has_soft_delete(context.resource) {
            "                        AND \"deleted_at\" IS NULL\n"
        } else {
            ""
        },
        filter_clause = if filter_clause.is_empty() {
            String::new()
        } else {
            format!("{filter_clause}\n")
        },
        // A NULL search term disables the full-text predicate.
        search_clause = search_position
            .map(|(position, expression)| {
                format!(
                    "\n                        AND (${position}::text IS NULL OR to_tsvector('english', {expression}) @@ plainto_tsquery('english', ${position}))"
                )
            })
            .unwrap_or_default(),
        count_macro_args = count_macro_args,
        record_name = context.record_name,
        select_columns = context.select_columns,
        order_by = order_by,
        limit_param = limit_position,
        offset_param = offset_position,
        row_args = row_args.join(",\n                    "),
    )
}
735
736fn generate_filter_predicates(
737    context: &ResourceContext<'_>,
738    positions: &[(String, usize)],
739) -> Result<Vec<String>, String> {
740    positions
741        .iter()
742        .map(|(field_name, position)| {
743            let field = context.resource.schema.get(field_name).ok_or_else(|| {
744                format!(
745                    "Unknown filter field '{field_name}' on resource '{}'",
746                    context.resource.resource
747                )
748            })?;
749            Ok(format!(
750                "                AND (${position}::{cast} IS NULL OR \"{field_name}\" = ${position})",
751                cast = sql_cast_type(field)
752            ))
753        })
754        .collect()
755}
756
757fn generate_order_by(
758    context: &ResourceContext<'_>,
759    sort_fields: &[String],
760    positions: &[(usize, usize)],
761) -> Result<String, String> {
762    if sort_fields.is_empty() {
763        return Ok(format!("\"{}\" ASC", context.primary_key));
764    }
765
766    let mut clauses = Vec::new();
767    for ((field_param, direction_param), field_name) in positions.iter().zip(sort_fields) {
768        for candidate in sort_fields {
769            let field = context.resource.schema.get(candidate).ok_or_else(|| {
770                format!(
771                    "Unknown sort field '{candidate}' on resource '{}'",
772                    context.resource.resource
773                )
774            })?;
775            let sort_expr = sortable_expression(candidate, field);
776            clauses.push(format!(
777                "                CASE WHEN ${field_param}::text = '{candidate}' AND ${direction_param}::text = 'asc' THEN {sort_expr} END ASC"
778            ));
779            clauses.push(format!(
780                "                CASE WHEN ${field_param}::text = '{candidate}' AND ${direction_param}::text = 'desc' THEN {sort_expr} END DESC"
781            ));
782        }
783        let _ = field_name;
784    }
785    clauses.push(format!("                \"{}\" ASC", context.primary_key));
786    Ok(clauses.join(",\n"))
787}
788
789fn generate_insert_body(context: &ResourceContext<'_>) -> Result<String, String> {
790    let mut declarations = Vec::new();
791    let mut columns = Vec::new();
792    let mut values = Vec::new();
793    let mut args = Vec::new();
794
795    for (index, (field_name, field)) in context.resource.schema.iter().enumerate() {
796        let variable_name = sanitize_identifier(field_name);
797        declarations.push(generate_insert_declaration(
798            field_name,
799            field,
800            &variable_name,
801        )?);
802        columns.push(format!("\"{field_name}\""));
803        values.push(format!("${}", index + 1));
804        args.push(variable_name);
805    }
806
807    Ok(format!(
808        r###"{declarations}
809        let row = sqlx::query_as!(
810            {record_name},
811            r#"
812            INSERT INTO "{table_name}" ({columns})
813            VALUES ({values})
814            RETURNING
815                {select_columns}
816            "#,
817            {args}
818        )
819        .fetch_one(&self.pool)
820        .await?;
821
822        row_from_model(&row)"###,
823        declarations = declarations.join("\n"),
824        record_name = context.record_name,
825        table_name = context.resource.resource,
826        columns = columns.join(", "),
827        values = values.join(", "),
828        select_columns = context.select_columns,
829        args = args.join(",\n            "),
830    ))
831}
832
833fn generate_update_body(context: &ResourceContext<'_>) -> Result<String, String> {
834    let mut declarations = Vec::new();
835    let mut set_clauses = Vec::new();
836    let mut args = vec!["id".to_string()];
837    let mut has_mutable_fields = Vec::new();
838    let mut index = 2usize;
839
840    for (field_name, field) in &context.resource.schema {
841        if field.primary || field.generated {
842            continue;
843        }
844
845        let present_name = format!("{}_present", sanitize_identifier(field_name));
846        let value_name = sanitize_identifier(field_name);
847        declarations.push(generate_update_declaration(
848            field_name,
849            field,
850            &present_name,
851            &value_name,
852        ));
853        has_mutable_fields.push(present_name.clone());
854        set_clauses.push(format!(
855            "\"{field_name}\" = CASE WHEN ${present_param} THEN ${value_param} ELSE \"{field_name}\" END",
856            present_param = index,
857            value_param = index + 1
858        ));
859        args.push(present_name);
860        args.push(value_name);
861        index += 2;
862    }
863
864    if let Some(updated_at) = context.resource.schema.get("updated_at") {
865        if updated_at.generated && updated_at.field_type == FieldType::Timestamp {
866            declarations.push("        let updated_at = chrono::Utc::now();".to_string());
867            set_clauses.push(format!("\"updated_at\" = ${index}"));
868            args.push("updated_at".to_string());
869        }
870    }
871
872    let guard = if has_mutable_fields.is_empty() {
873        String::new()
874    } else {
875        format!(
876            "        if !({}) {}",
877            has_mutable_fields.join(" || "),
878            r#"{
879            return Err(shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {
880                field: "body".to_string(),
881                message: "No valid fields to update".to_string(),
882                code: "empty_update".to_string(),
883            }]));
884        }"#
885        )
886    };
887
888    Ok(format!(
889        r###"{declarations}
890{guard}
891        let row = sqlx::query_as!(
892            {record_name},
893            r#"
894            UPDATE "{table_name}"
895            SET {set_clauses}
896            WHERE "{primary_key}" = $1{soft_delete_where}
897            RETURNING
898                {select_columns}
899            "#,
900            {args}
901        )
902        .fetch_optional(&self.pool)
903        .await?
904        .ok_or(shaperail_core::ShaperailError::NotFound)?;
905
906        row_from_model(&row)"###,
907        declarations = declarations.join("\n"),
908        guard = guard,
909        record_name = context.record_name,
910        table_name = context.resource.resource,
911        set_clauses = set_clauses.join(", "),
912        primary_key = context.primary_key,
913        soft_delete_where = context.soft_delete_where,
914        select_columns = context.select_columns,
915        args = args.join(",\n            "),
916    ))
917}
918
919fn generate_soft_delete_body(context: &ResourceContext<'_>) -> String {
920    format!(
921        r###"        let deleted_at = chrono::Utc::now();
922        let row = sqlx::query_as!(
923            {record_name},
924            r#"
925            UPDATE "{table_name}"
926            SET "deleted_at" = $2
927            WHERE "{primary_key}" = $1 AND "deleted_at" IS NULL
928            RETURNING
929                {select_columns}
930            "#,
931            id,
932            deleted_at
933        )
934        .fetch_optional(&self.pool)
935        .await?
936        .ok_or(shaperail_core::ShaperailError::NotFound)?;
937
938        row_from_model(&row)"###,
939        record_name = context.record_name,
940        table_name = context.resource.resource,
941        primary_key = context.primary_key,
942        select_columns = context.select_columns,
943    )
944}
945
946fn generate_hard_delete_body(context: &ResourceContext<'_>) -> String {
947    format!(
948        r###"        let row = sqlx::query_as!(
949            {record_name},
950            r#"
951            DELETE FROM "{table_name}"
952            WHERE "{primary_key}" = $1
953            RETURNING
954                {select_columns}
955            "#,
956            id
957        )
958        .fetch_optional(&self.pool)
959        .await?
960        .ok_or(shaperail_core::ShaperailError::NotFound)?;
961
962        row_from_model(&row)"###,
963        record_name = context.record_name,
964        table_name = context.resource.resource,
965        primary_key = context.primary_key,
966        select_columns = context.select_columns,
967    )
968}
969
970fn generate_insert_declaration(
971    field_name: &str,
972    field: &FieldSchema,
973    variable_name: &str,
974) -> Result<String, String> {
975    if field.generated {
976        return Ok(format!(
977            "        let {variable_name} = {};",
978            generated_value_expression(field)
979        ));
980    }
981
982    let parse_type = parse_type(field);
983    let parsed = format!(
984        "shaperail_runtime::db::parse_optional_json::<{parse_type}>(data, {field_name:?})?"
985    );
986
987    let expression = match (field_is_required(field), field.default.as_ref()) {
988        (true, Some(default)) => format!(
989            "match {parsed} {{ Some(value) => value, None => {} }}",
990            default_expression(field_name, field, default)?
991        ),
992        (true, None) => format!("shaperail_runtime::db::require_field({parsed}, {field_name:?})?"),
993        (false, Some(default)) if model_field_is_optional(field) => format!(
994            "match {parsed} {{ Some(value) => Some(value), None => Some({}) }}",
995            default_expression(field_name, field, default)?
996        ),
997        (false, Some(default)) => format!(
998            "match {parsed} {{ Some(value) => value, None => {} }}",
999            default_expression(field_name, field, default)?
1000        ),
1001        (false, None) => parsed,
1002    };
1003
1004    Ok(format!("        let {variable_name} = {expression};"))
1005}
1006
1007fn generate_update_declaration(
1008    field_name: &str,
1009    field: &FieldSchema,
1010    present_name: &str,
1011    value_name: &str,
1012) -> String {
1013    format!(
1014        "        let {present_name} = data.contains_key({field_name:?});\n        let {value_name} = shaperail_runtime::db::parse_optional_json::<{parse_type}>(data, {field_name:?})?;",
1015        parse_type = parse_type(field)
1016    )
1017}
1018
1019fn generate_filter_declaration(field_name: &str, field: &FieldSchema) -> String {
1020    let parser = match field.field_type {
1021        FieldType::Uuid => "uuid::Uuid::parse_str(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid uuid filter\".to_string()))",
1022        FieldType::String | FieldType::Enum | FieldType::File => "Ok(text.to_string())",
1023        FieldType::Integer => "text.parse::<i32>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid integer filter\".to_string()))",
1024        FieldType::Bigint => "text.parse::<i64>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid bigint filter\".to_string()))",
1025        FieldType::Number => "text.parse::<f64>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid number filter\".to_string()))",
1026        FieldType::Boolean => "text.parse::<bool>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid boolean filter\".to_string()))",
1027        FieldType::Timestamp => "chrono::DateTime::parse_from_rfc3339(text).map(|value| value.with_timezone(&chrono::Utc)).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid timestamp filter\".to_string()))",
1028        FieldType::Date => "chrono::NaiveDate::parse_from_str(text, \"%Y-%m-%d\").map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid date filter\".to_string()))",
1029        FieldType::Json => "serde_json::from_str::<serde_json::Value>(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid json filter\".to_string()))",
1030        FieldType::Array => "serde_json::from_str::<Vec<serde_json::Value>>(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid array filter\".to_string()))",
1031    };
1032
1033    format!(
1034        "        let {var} = parse_filter(filters, {field_name:?}, \"invalid_filter\", |text| {parser})?;",
1035        var = field_parameter_name(field_name)
1036    )
1037}
1038
1039fn field_parameter_name(field_name: &str) -> String {
1040    format!("filter_{}", sanitize_identifier(field_name))
1041}
1042
1043fn parameter_expression(field_name: &str, field: &FieldSchema) -> String {
1044    let var = field_parameter_name(field_name);
1045    match field.field_type {
1046        FieldType::String | FieldType::Enum | FieldType::File => format!("{var}.as_deref()"),
1047        _ => var,
1048    }
1049}
1050
1051fn select_column_sql(field_name: &str, field: &FieldSchema) -> String {
1052    let nullability = if model_field_is_optional(field) {
1053        "?"
1054    } else {
1055        "!"
1056    };
1057    let expression = match field.field_type {
1058        FieldType::Number => format!("\"{field_name}\"::DOUBLE PRECISION"),
1059        _ => format!("\"{field_name}\""),
1060    };
1061    format!(
1062        "{expression} as \"{field_name}{nullability}: {type_name}\"",
1063        type_name = query_type(field)
1064    )
1065}
1066
1067fn sortable_expression(field_name: &str, field: &FieldSchema) -> String {
1068    match field.field_type {
1069        FieldType::Json | FieldType::Array | FieldType::Uuid => format!("\"{field_name}\"::text"),
1070        FieldType::Number => format!("\"{field_name}\"::DOUBLE PRECISION"),
1071        _ => format!("\"{field_name}\""),
1072    }
1073}
1074
/// Builds the SQL text searched by full-text style queries.
///
/// Each column is rendered as `COALESCE("col"::text, '')` and the columns are
/// concatenated with `' '` between them. An empty slice yields an empty string.
fn search_expression(fields: &[String]) -> String {
    let mut sql = String::new();
    for (position, column) in fields.iter().enumerate() {
        if position > 0 {
            sql.push_str(" || ' ' || ");
        }
        sql.push_str("COALESCE(\"");
        sql.push_str(column);
        sql.push_str("\"::text, '')");
    }
    sql
}
1082
1083fn sql_cast_type(field: &FieldSchema) -> String {
1084    match field.field_type {
1085        FieldType::Uuid => "uuid".to_string(),
1086        FieldType::String | FieldType::Enum | FieldType::File => "text".to_string(),
1087        FieldType::Integer => "integer".to_string(),
1088        FieldType::Bigint => "bigint".to_string(),
1089        FieldType::Number => "double precision".to_string(),
1090        FieldType::Boolean => "boolean".to_string(),
1091        FieldType::Timestamp => "timestamptz".to_string(),
1092        FieldType::Date => "date".to_string(),
1093        FieldType::Json => "jsonb".to_string(),
1094        FieldType::Array => match field.items.as_deref() {
1095            Some("uuid") => "uuid[]".to_string(),
1096            Some("integer") => "integer[]".to_string(),
1097            Some("bigint") => "bigint[]".to_string(),
1098            Some("number") => "double precision[]".to_string(),
1099            Some("boolean") => "boolean[]".to_string(),
1100            _ => "text[]".to_string(),
1101        },
1102    }
1103}
1104
1105fn query_type(field: &FieldSchema) -> String {
1106    match field.field_type {
1107        FieldType::Uuid => "uuid::Uuid".to_string(),
1108        FieldType::String | FieldType::Enum | FieldType::File => "String".to_string(),
1109        FieldType::Integer => "i32".to_string(),
1110        FieldType::Bigint => "i64".to_string(),
1111        FieldType::Number => "f64".to_string(),
1112        FieldType::Boolean => "bool".to_string(),
1113        FieldType::Timestamp => "chrono::DateTime<chrono::Utc>".to_string(),
1114        FieldType::Date => "chrono::NaiveDate".to_string(),
1115        FieldType::Json => "serde_json::Value".to_string(),
1116        FieldType::Array => match field.items.as_deref() {
1117            Some("uuid") => "Vec<uuid::Uuid>".to_string(),
1118            Some("integer") => "Vec<i32>".to_string(),
1119            Some("bigint") => "Vec<i64>".to_string(),
1120            Some("number") => "Vec<f64>".to_string(),
1121            Some("boolean") => "Vec<bool>".to_string(),
1122            Some("timestamp") => "Vec<chrono::DateTime<chrono::Utc>>".to_string(),
1123            Some("date") => "Vec<chrono::NaiveDate>".to_string(),
1124            _ => "Vec<String>".to_string(),
1125        },
1126    }
1127}
1128
1129fn parse_type(field: &FieldSchema) -> String {
1130    query_type(field)
1131}
1132
1133fn model_field_type(field: &FieldSchema) -> String {
1134    let base = query_type(field);
1135    if model_field_is_optional(field) {
1136        format!("Option<{base}>")
1137    } else {
1138        base
1139    }
1140}
1141
1142fn model_field_is_optional(field: &FieldSchema) -> bool {
1143    !(field.primary || (field.required && !field.nullable))
1144}
1145
1146fn field_is_required(field: &FieldSchema) -> bool {
1147    field.primary || (field.required && !field.nullable)
1148}
1149
1150fn generated_value_expression(field: &FieldSchema) -> String {
1151    match field.field_type {
1152        FieldType::Uuid => "uuid::Uuid::new_v4()".to_string(),
1153        FieldType::Timestamp => {
1154            if model_field_is_optional(field) {
1155                "Some(chrono::Utc::now())".to_string()
1156            } else {
1157                "chrono::Utc::now()".to_string()
1158            }
1159        }
1160        FieldType::Date => {
1161            if model_field_is_optional(field) {
1162                "Some(chrono::Utc::now().date_naive())".to_string()
1163            } else {
1164                "chrono::Utc::now().date_naive()".to_string()
1165            }
1166        }
1167        _ => "Default::default()".to_string(),
1168    }
1169}
1170
1171fn default_expression(
1172    field_name: &str,
1173    field: &FieldSchema,
1174    default: &serde_json::Value,
1175) -> Result<String, String> {
1176    Ok(match field.field_type {
1177        FieldType::Uuid => format!(
1178            "parse_embedded_json::<uuid::Uuid>({field_name:?}, serde_json::json!({default}))?"
1179        ),
1180        FieldType::String | FieldType::Enum | FieldType::File => {
1181            let value = default
1182                .as_str()
1183                .ok_or_else(|| format!("Default for '{field_name}' must be a string"))?;
1184            format!("{value:?}.to_string()")
1185        }
1186        FieldType::Integer => format!(
1187            "parse_embedded_json::<i32>({field_name:?}, serde_json::json!({default}))?"
1188        ),
1189        FieldType::Bigint => format!(
1190            "parse_embedded_json::<i64>({field_name:?}, serde_json::json!({default}))?"
1191        ),
1192        FieldType::Number => format!(
1193            "parse_embedded_json::<f64>({field_name:?}, serde_json::json!({default}))?"
1194        ),
1195        FieldType::Boolean => default
1196            .as_bool()
1197            .ok_or_else(|| format!("Default for '{field_name}' must be a boolean"))?
1198            .to_string(),
1199        FieldType::Timestamp => format!(
1200            "parse_embedded_json::<chrono::DateTime<chrono::Utc>>({field_name:?}, serde_json::json!({default}))?"
1201        ),
1202        FieldType::Date => format!(
1203            "parse_embedded_json::<chrono::NaiveDate>({field_name:?}, serde_json::json!({default}))?"
1204        ),
1205        FieldType::Json => format!("serde_json::json!({default})"),
1206        FieldType::Array => format!(
1207            "parse_embedded_json::<{}>({field_name:?}, serde_json::json!({default}))?",
1208            query_type(field)
1209        ),
1210    })
1211}
1212
1213fn has_soft_delete(resource: &ResourceDefinition) -> bool {
1214    resource
1215        .endpoints
1216        .as_ref()
1217        .map(|endpoints| endpoints.values().any(|endpoint| endpoint.soft_delete))
1218        .unwrap_or(false)
1219}
1220
/// Turns an arbitrary schema name into a usable lowercase Rust identifier.
///
/// The transformation works as follows:
/// * ASCII alphanumerics are lowercased.
/// * Every other character becomes `_`.
/// * A leading digit gets a `_` prefix.
/// * Names that collide with a Rust keyword (strict or reserved, up to the 2021
///   edition) get a trailing `_`, e.g. `type` becomes `type_`. Without this,
///   generated code such as `let type = ...;` would not compile.
///
/// Output for non-keyword names is unchanged.
fn sanitize_identifier(value: &str) -> String {
    // Lowercase-only list: the output below is always lowercase, so `Self` is omitted.
    const RESERVED: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "if",
        "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true",
        "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while",
        "yield",
    ];

    let mut output: String = value
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();

    if output.starts_with(|ch: char| ch.is_ascii_digit()) {
        output.insert(0, '_');
    }

    if RESERVED.contains(&output.as_str()) {
        output.push('_');
    }

    output
}
1237
/// Converts a `snake_case` name to `PascalCase`.
///
/// Each `_`-separated segment has its first character uppercased and the rest
/// copied verbatim. Empty segments, from leading, trailing or doubled
/// underscores, are dropped.
fn to_pascal_case(value: &str) -> String {
    let mut result = String::with_capacity(value.len());
    for segment in value.split('_') {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            result.extend(head.to_uppercase());
            result.push_str(chars.as_str());
        }
    }
    result
}
1256
/// Prefixes every non-empty line of `block` with `indent` levels of four spaces.
///
/// Lines are split with `str::lines`, so a trailing newline is dropped and
/// `\r\n` is normalised to `\n`. Empty lines stay empty. A block that is
/// entirely whitespace yields an empty string.
fn indent_block(block: &str, indent: usize) -> String {
    if block.trim().is_empty() {
        return String::new();
    }

    let prefix = "    ".repeat(indent);
    let mut indented = String::with_capacity(block.len() + prefix.len() * block.lines().count());
    for (index, line) in block.lines().enumerate() {
        if index > 0 {
            indented.push('\n');
        }
        if !line.is_empty() {
            indented.push_str(&prefix);
            indented.push_str(line);
        }
    }
    indented
}
1275
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use shaperail_core::{
        AuthRule, EndpointSpec, FieldSchema, HttpMethod, PaginationStyle, ResourceDefinition,
    };

    /// A `FieldSchema` of `field_type` with every flag off and no constraints.
    /// Tests override only the flags they care about via struct-update syntax.
    fn plain_field(field_type: FieldType) -> FieldSchema {
        FieldSchema {
            field_type,
            primary: false,
            generated: false,
            required: false,
            unique: false,
            nullable: false,
            reference: None,
            min: None,
            max: None,
            format: None,
            values: None,
            default: None,
            sensitive: false,
            search: false,
            items: None,
        }
    }

    /// A `users` resource with:
    /// * a generated UUID primary key `id`;
    /// * a required, unique, searchable `email`;
    /// * a generated `created_at`;
    /// * one public, cursor-paginated `list` endpoint.
    fn sample_resource() -> ResourceDefinition {
        let mut schema = IndexMap::new();
        schema.insert(
            "id".to_string(),
            FieldSchema {
                primary: true,
                generated: true,
                ..plain_field(FieldType::Uuid)
            },
        );
        schema.insert(
            "email".to_string(),
            FieldSchema {
                required: true,
                unique: true,
                search: true,
                ..plain_field(FieldType::String)
            },
        );
        schema.insert(
            "created_at".to_string(),
            FieldSchema {
                generated: true,
                ..plain_field(FieldType::Timestamp)
            },
        );

        let mut endpoints = IndexMap::new();
        endpoints.insert(
            "list".to_string(),
            EndpointSpec {
                method: Some(HttpMethod::Get),
                path: Some("/users".to_string()),
                auth: Some(AuthRule::Public),
                input: None,
                filters: Some(vec!["email".to_string()]),
                search: Some(vec!["email".to_string()]),
                pagination: Some(PaginationStyle::Cursor),
                sort: Some(vec!["created_at".to_string()]),
                cache: None,
                controller: None,
                events: None,
                jobs: None,
                upload: None,
                soft_delete: false,
            },
        );

        ResourceDefinition {
            resource: "users".to_string(),
            version: 1,
            db: None,
            tenant_key: None,
            schema,
            endpoints: Some(endpoints),
            relations: None,
            indexes: None,
        }
    }

    #[test]
    fn generates_query_as_store_module() {
        let resource = sample_resource();
        let code = generate_resource_module(&resource).unwrap();

        assert!(code.contains("impl ResourceStore for UsersStore"));
        assert!(code.contains("sqlx::query_as!"));
        assert!(code.contains("find_all_list"));
    }

    #[test]
    fn generates_registry_module() {
        let resource = sample_resource();
        let project = generate_project(&[resource]).unwrap();

        assert!(project.mod_rs.contains("pub mod users;"));
        assert!(project.mod_rs.contains("build_store_registry"));
    }

    #[test]
    fn pascal_cases_snake_identifiers() {
        assert_eq!(to_pascal_case("user_accounts"), "UserAccounts");
        assert_eq!(to_pascal_case("__leading__double_"), "LeadingDouble");
        assert_eq!(to_pascal_case(""), "");
    }

    #[test]
    fn sanitizes_punctuation_and_leading_digits() {
        assert_eq!(sanitize_identifier("First-Name"), "first_name");
        assert_eq!(sanitize_identifier("2fa"), "_2fa");
    }

    #[test]
    fn indents_only_non_empty_lines() {
        assert_eq!(indent_block("a\n\nb", 1), "    a\n\n    b");
        assert_eq!(indent_block("   \n", 2), "");
    }

    #[test]
    fn search_expression_concatenates_coalesced_columns() {
        assert_eq!(
            search_expression(&["a".to_string(), "b".to_string()]),
            "COALESCE(\"a\"::text, '') || ' ' || COALESCE(\"b\"::text, '')"
        );
    }
}