Skip to main content

shaperail_codegen/
rust.rs

1use shaperail_core::{EndpointSpec, FieldSchema, FieldType, HttpMethod, ResourceDefinition};
2
/// A single generated Rust source file for one resource.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedRustModule {
    /// File name of the module, `<resource>.rs`.
    pub file_name: String,
    /// Full Rust source text of the module.
    pub contents: String,
}

/// All generated resource modules plus the `mod.rs` that declares them and
/// builds the store registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedRustProject {
    /// One module per resource, in the same order as the input resources.
    pub modules: Vec<GeneratedRustModule>,
    /// Contents of the registry `mod.rs`.
    pub mod_rs: String,
}
12
13pub fn generate_project(resources: &[ResourceDefinition]) -> Result<GeneratedRustProject, String> {
14    let mut modules = Vec::with_capacity(resources.len());
15    for resource in resources {
16        modules.push(GeneratedRustModule {
17            file_name: format!("{}.rs", resource.resource),
18            contents: generate_resource_module(resource)?,
19        });
20    }
21
22    Ok(GeneratedRustProject {
23        modules,
24        mod_rs: generate_registry_module(resources),
25    })
26}
27
/// Generates the Rust source of the query module for a single resource.
///
/// The emitted module contains:
/// * a `<Resource>Record` struct with one `pub` field per schema entry;
/// * a `<Resource>Store` wrapping a `sqlx::PgPool`, with one private list
///   helper per collection endpoint;
/// * a `ResourceStore` implementation (`find_by_id`, `find_all`, `insert`,
///   `update_by_id`, `soft_delete_by_id`, `hard_delete_by_id`) built on
///   `sqlx::query_as!`.
///
/// The resource name is used directly as the SQL table name.
///
/// # Errors
///
/// Returns an error when the context, a list helper, the insert body or the
/// update body cannot be generated (for example when an endpoint references a
/// filter or sort field that is not in the schema).
pub fn generate_resource_module(resource: &ResourceDefinition) -> Result<String, String> {
    let context = ResourceContext::new(resource)?;

    // One `    pub <name>: <type>,` line per schema field, in schema iteration order.
    let model_fields = resource
        .schema
        .iter()
        .map(|(name, field)| format!("    pub {name}: {},", model_field_type(field)))
        .collect::<Vec<_>>()
        .join("\n");

    let list_helpers = context
        .collection_endpoints
        .iter()
        .map(|endpoint| generate_list_helper(&context, endpoint))
        .collect::<Result<Vec<_>, _>>()?
        .join("\n\n");

    // Body of the generated `find_all`. It dispatches on the endpoint path to the
    // matching list helper. With no collection endpoints it always returns an
    // Internal error; the `let _` silences unused-parameter warnings.
    let list_dispatch = if context.collection_endpoints.is_empty() {
        "        let _ = (endpoint, filters, search, sort, page);\n        Err(shaperail_core::ShaperailError::Internal(\"No collection endpoints are available for generated list queries\".to_string()))".to_string()
    } else {
        let arms = context
            .collection_endpoints
            .iter()
            .map(|endpoint| {
                // `{path:?}` renders the path as a quoted Rust string literal pattern.
                format!(
                    "            {path:?} => self.{helper}(filters, search, sort, page).await,",
                    path = endpoint.spec.path,
                    helper = endpoint.helper_name
                )
            })
            .collect::<Vec<_>>()
            .join("\n");

        format!(
            "        match endpoint.path.as_str() {{\n{arms}\n            _ => Err(shaperail_core::ShaperailError::Internal(format!(\"No generated list query for {{}}\", endpoint.path))),\n        }}"
        )
    };

    // The template is emitted verbatim; `{{` / `}}` are literal braces in the output.
    Ok(format!(
        r###"//! Generated query module for the `{resource_name}` resource.
//! DO NOT EDIT — this file is auto-generated by `shaperail generate`.

use serde::{{Deserialize, Serialize}};
use serde_json::{{Map, Value}};
use shaperail_core::EndpointSpec;
#[allow(unused_imports)]
use shaperail_runtime::db::{{
    async_trait, parse_embedded_json, parse_filter, parse_optional_json, require_field,
    row_from_model, sort_direction_at, sort_field_at, FilterSet, PageRequest, ResourceRow,
    ResourceStore, SearchParam, SortParam,
}};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct {record_name} {{
{model_fields}
}}

pub struct {store_name} {{
    pool: sqlx::PgPool,
}}

impl {store_name} {{
    pub fn new(pool: sqlx::PgPool) -> Self {{
        Self {{ pool }}
    }}

{list_helpers}
}}

#[async_trait]
impl ResourceStore for {store_name} {{
    fn resource_name(&self) -> &str {{
        "{resource_name}"
    }}

    async fn find_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
        let row = sqlx::query_as!(
            {record_name},
            r#"
            SELECT
                {select_columns}
            FROM "{table_name}"
            WHERE "{primary_key}" = $1{soft_delete_where}
            "#,
            id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(shaperail_core::ShaperailError::NotFound)?;

        row_from_model(&row)
    }}

    async fn find_all(
        &self,
        endpoint: &EndpointSpec,
        filters: &FilterSet,
        search: Option<&SearchParam>,
        sort: &SortParam,
        page: &PageRequest,
    ) -> Result<(Vec<ResourceRow>, Value), shaperail_core::ShaperailError> {{
{list_dispatch}
    }}

    async fn insert(&self, data: &Map<String, Value>) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{insert_body}
    }}

    async fn update_by_id(
        &self,
        id: &uuid::Uuid,
        data: &Map<String, Value>,
    ) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{update_body}
    }}

    async fn soft_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{soft_delete_body}
    }}

    async fn hard_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, shaperail_core::ShaperailError> {{
{hard_delete_body}
    }}
}}
"###,
        resource_name = resource.resource,
        record_name = context.record_name,
        store_name = context.store_name,
        model_fields = model_fields,
        list_helpers = list_helpers,
        select_columns = context.select_columns,
        // The resource name doubles as the table name.
        table_name = resource.resource,
        primary_key = context.primary_key,
        soft_delete_where = context.soft_delete_where,
        list_dispatch = list_dispatch,
        insert_body = generate_insert_body(&context)?,
        update_body = generate_update_body(&context)?,
        soft_delete_body = generate_soft_delete_body(&context),
        hard_delete_body = generate_hard_delete_body(&context),
    ))
}
169
170fn generate_registry_module(resources: &[ResourceDefinition]) -> String {
171    let module_lines = resources
172        .iter()
173        .map(|resource| format!("pub mod {};", resource.resource))
174        .collect::<Vec<_>>()
175        .join("\n");
176
177    let registry_lines = resources
178        .iter()
179        .map(|resource| {
180            let store_name = format!("{}Store", to_pascal_case(&resource.resource));
181            format!(
182                "    stores.insert({name:?}.to_string(), std::sync::Arc::new({module}::{store_name}::new(pool.clone())));",
183                name = resource.resource,
184                module = resource.resource
185            )
186        })
187        .collect::<Vec<_>>()
188        .join("\n");
189
190    format!(
191        r#"{module_lines}
192
193pub fn build_store_registry(pool: sqlx::PgPool) -> shaperail_runtime::db::StoreRegistry {{
194    let mut stores: std::collections::HashMap<
195        String,
196        std::sync::Arc<dyn shaperail_runtime::db::ResourceStore>,
197    > = std::collections::HashMap::new();
198{registry_lines}
199    std::sync::Arc::new(stores)
200}}
201"#
202    )
203}
204
/// A GET endpoint treated as a collection (list) query, paired with the name of
/// the generated store method that serves it.
#[derive(Clone)]
struct CollectionEndpoint<'a> {
    /// The endpoint definition, borrowed from the resource.
    spec: &'a EndpointSpec,
    /// Generated helper method name: `find_all_<sanitized endpoint name>`.
    helper_name: String,
}
210
/// Per-resource values computed once and shared by every code generator in
/// this module.
struct ResourceContext<'a> {
    /// The resource being generated.
    resource: &'a ResourceDefinition,
    /// Name of the generated row struct: `<PascalCaseResource>Record`.
    record_name: String,
    /// Name of the generated store struct: `<PascalCaseResource>Store`.
    store_name: String,
    /// First schema field flagged `primary`, or `"id"` when none is.
    primary_key: String,
    /// SELECT/RETURNING column list, one entry per schema field, joined with
    /// `",\n"` plus indentation.
    select_columns: String,
    /// ` AND "deleted_at" IS NULL` when `has_soft_delete` is true, otherwise empty.
    soft_delete_where: String,
    /// GET endpoints whose path does not contain `:id`, each with its helper name.
    collection_endpoints: Vec<CollectionEndpoint<'a>>,
}
220
221impl<'a> ResourceContext<'a> {
222    fn new(resource: &'a ResourceDefinition) -> Result<Self, String> {
223        let primary_key = resource
224            .schema
225            .iter()
226            .find(|(_, field)| field.primary)
227            .map(|(name, _)| name.clone())
228            .unwrap_or_else(|| "id".to_string());
229
230        let select_columns = resource
231            .schema
232            .iter()
233            .map(|(name, field)| select_column_sql(name, field))
234            .collect::<Vec<_>>()
235            .join(",\n                ");
236
237        let collection_endpoints = resource
238            .endpoints
239            .as_ref()
240            .map(|endpoints| {
241                endpoints
242                    .iter()
243                    .filter(|(_, endpoint)| {
244                        endpoint.method == HttpMethod::Get && !endpoint.path.contains(":id")
245                    })
246                    .map(|(name, endpoint)| CollectionEndpoint {
247                        spec: endpoint,
248                        helper_name: format!("find_all_{}", sanitize_identifier(name)),
249                    })
250                    .collect::<Vec<_>>()
251            })
252            .unwrap_or_default();
253
254        Ok(Self {
255            resource,
256            record_name: format!("{}Record", to_pascal_case(&resource.resource)),
257            store_name: format!("{}Store", to_pascal_case(&resource.resource)),
258            primary_key,
259            select_columns,
260            soft_delete_where: if has_soft_delete(resource) {
261                " AND \"deleted_at\" IS NULL".to_string()
262            } else {
263                String::new()
264            },
265            collection_endpoints,
266        })
267    }
268}
269
/// Generates the private list helper that serves one collection endpoint.
///
/// The helper supports both cursor and offset pagination. Bind parameters are
/// laid out positionally:
///
/// * `$1..=$F`: one per declared filter field, in declaration order;
/// * `$F+1`: the search term, only when the endpoint declares search fields;
/// * cursor mode: then the decoded cursor id, then one `(field, direction)` pair
///   per sort slot, then the limit (bound as `*limit + 1` to detect `has_more`);
/// * offset mode: one `(field, direction)` pair per sort slot, then the limit,
///   then the offset.
///
/// # Errors
///
/// Returns an error when a filter or sort field is not in the resource schema.
///
/// NOTE(review): search fields are passed to `search_expression` without being
/// checked against the schema here. Confirm whether `search_expression` validates them.
fn generate_list_helper(
    context: &ResourceContext<'_>,
    endpoint: &CollectionEndpoint<'_>,
) -> Result<String, String> {
    let filters = endpoint.spec.filters.clone().unwrap_or_default();
    let search_fields = endpoint.spec.search.clone().unwrap_or_default();
    let sort_fields = endpoint.spec.sort.clone().unwrap_or_default();

    // Parse statements for each filter value; also validates every filter field.
    let filter_decls = filters
        .iter()
        .map(|field_name| {
            let field = context.resource.schema.get(field_name).ok_or_else(|| {
                format!(
                    "Unknown filter field '{field_name}' on resource '{}'",
                    context.resource.resource
                )
            })?;
            Ok(generate_filter_declaration(field_name, field))
        })
        .collect::<Result<Vec<_>, String>>()?
        .join("\n");

    // The `expect` cannot fire: `filter_decls` above already returned Err for any
    // filter field missing from the schema.
    let filter_args = filters
        .iter()
        .map(|field_name| {
            parameter_expression(
                field_name,
                context
                    .resource
                    .schema
                    .get(field_name)
                    .expect("filter field validated"),
            )
        })
        .collect::<Vec<_>>();

    let search_decl = if search_fields.is_empty() {
        String::new()
    } else {
        "        let search_term = search.map(|value| value.term.clone());".to_string()
    };

    let search_predicate = if search_fields.is_empty() {
        String::new()
    } else {
        search_expression(&search_fields)
    };

    // One `(sort_field_i, sort_direction_i)` pair per declared sort slot.
    let sort_decls = (0..sort_fields.len())
        .map(|index| {
            format!(
                "        let sort_field_{index} = sort_field_at(sort, {index});\n        let sort_direction_{index} = sort_direction_at(sort, {index});"
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    // Filters occupy `$1..=$F`.
    let filter_positions = filters
        .iter()
        .enumerate()
        .map(|(index, field_name)| (field_name.clone(), index + 1))
        .collect::<Vec<_>>();

    // The search term, when present, directly follows the filters.
    let search_position = if search_fields.is_empty() {
        None
    } else {
        Some(filter_positions.len() + 1)
    };

    // Cursor mode: cursor id next, then the sort pairs.
    let cursor_position = filter_positions.len() + usize::from(search_position.is_some()) + 1;
    let cursor_sort_positions = (0..sort_fields.len())
        .map(|index| {
            let base = cursor_position + 1 + (index * 2);
            (base, base + 1)
        })
        .collect::<Vec<_>>();
    // Offset mode: no cursor, so the sort pairs follow the filters/search directly.
    let offset_sort_positions = (0..sort_fields.len())
        .map(|index| {
            let base =
                filter_positions.len() + usize::from(search_position.is_some()) + 1 + (index * 2);
            (base, base + 1)
        })
        .collect::<Vec<_>>();

    let filter_predicates = generate_filter_predicates(context, &filter_positions)?;
    let filter_clause = filter_predicates.join("\n");

    let cursor_order_by = generate_order_by(context, &sort_fields, &cursor_sort_positions)?;
    let offset_order_by = generate_order_by(context, &sort_fields, &offset_sort_positions)?;

    // Macro argument lists: the order must match the `$n` positions computed above.
    let mut cursor_args = filter_args.clone();
    if search_position.is_some() {
        cursor_args.push("search_term.as_deref()".to_string());
    }
    cursor_args.push("cursor".to_string());
    for index in 0..sort_fields.len() {
        cursor_args.push(format!("sort_field_{index}.as_deref()"));
        cursor_args.push(format!("sort_direction_{index}"));
    }
    // One extra row is fetched so the generated code can detect `has_more`.
    cursor_args.push("*limit + 1".to_string());

    let mut count_args = filter_args.clone();
    if search_position.is_some() {
        count_args.push("search_term.as_deref()".to_string());
    }

    let mut row_args = count_args.clone();
    for index in 0..sort_fields.len() {
        row_args.push(format!("sort_field_{index}.as_deref()"));
        row_args.push(format!("sort_direction_{index}"));
    }
    row_args.push("*limit".to_string());
    row_args.push("*offset".to_string());

    // The limit is the last cursor bind.
    let cursor_query = generate_cursor_query(
        context,
        &filter_clause,
        search_position.map(|position| (position, search_predicate.as_str())),
        cursor_position,
        &cursor_order_by,
        cursor_args.len(),
        &cursor_args,
    );
    // LIMIT and OFFSET are the last two row binds.
    let offset_query = generate_offset_query(
        context,
        &filter_clause,
        search_position.map(|position| (position, search_predicate.as_str())),
        &offset_order_by,
        row_args.len() - 1,
        row_args.len(),
        &count_args,
        &row_args,
    );

    Ok(format!(
        r###"    async fn {helper_name}(
        &self,
        filters: &FilterSet,
        search: Option<&SearchParam>,
        sort: &SortParam,
        page: &PageRequest,
    ) -> Result<(Vec<ResourceRow>, Value), shaperail_core::ShaperailError> {{
{filter_decls}
{search_decl}
{sort_decls}

        match page {{
            PageRequest::Cursor {{ after, limit }} => {{
                let cursor = match after {{
                    Some(cursor_value) => Some(uuid::Uuid::parse_str(
                        &shaperail_runtime::db::decode_cursor(cursor_value)?
                    ).map_err(|_| shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {{
                        field: "cursor".to_string(),
                        message: "Invalid cursor value".to_string(),
                        code: "invalid_cursor".to_string(),
                    }}]))?),
                    None => None,
                }};
{cursor_query}
            }}
            PageRequest::Offset {{ offset, limit }} => {{
{offset_query}
            }}
        }}
    }}"###,
        helper_name = endpoint.helper_name,
        filter_decls = indent_block(&filter_decls, 2),
        search_decl = indent_block(&search_decl, 2),
        sort_decls = indent_block(&sort_decls, 2),
        cursor_query = indent_block(&cursor_query, 4),
        offset_query = indent_block(&offset_query, 4),
    ))
}
443
/// Renders the cursor-paginated branch of a generated list helper.
///
/// The query binds the limit at `limit_position`. The caller passes `*limit + 1`
/// there, so one extra row may come back; if it does, `has_more` is set and the
/// extra row is truncated. The last kept row's primary key is then encoded as
/// the next cursor. The result metadata is `{ "cursor", "has_more" }`.
///
/// The cursor is bound with a `::uuid` cast, so this assumes a UUID primary
/// key. TODO confirm for non-UUID keys.
///
/// NOTE(review): the page predicate is `"<pk>" > $cursor`, but the ORDER BY
/// (from `generate_order_by`) places the requested sort columns before the
/// primary key. When a request sorts by anything other than the primary key,
/// primary-key-only keyset pagination looks like it can skip or repeat rows
/// across pages. Confirm.
fn generate_cursor_query(
    context: &ResourceContext<'_>,
    filter_clause: &str,
    search_position: Option<(usize, &str)>,
    cursor_position: usize,
    order_by: &str,
    limit_position: usize,
    args: &[String],
) -> String {
    format!(
        r###"                let rows = sqlx::query_as!(
                    {record_name},
                    r#"
                    SELECT
                        {select_columns}
                    FROM "{table_name}"
                    WHERE TRUE
{soft_delete_clause}
{filter_clause}{search_clause}
                        AND (${cursor_position}::uuid IS NULL OR "{primary_key}" > ${cursor_position})
                    ORDER BY
{order_by}
                    LIMIT ${limit_position}
                    "#,
                    {args}
                )
                .fetch_all(&self.pool)
                .await?;

                let has_more = rows.len() as i64 > *limit;
                let mut result_rows = rows;
                if has_more {{
                    result_rows.truncate(*limit as usize);
                }}

                let data = result_rows
                    .iter()
                    .map(row_from_model)
                    .collect::<Result<Vec<_>, _>>()?;
                let cursor = if has_more {{
                    result_rows
                        .last()
                        .map(|row| shaperail_runtime::db::encode_cursor(&row.{primary_key}.to_string()))
                }} else {{
                    None
                }};

                Ok((
                    data,
                    serde_json::json!({{
                        "cursor": cursor,
                        "has_more": has_more
                    }})
                ))"###,
        record_name = context.record_name,
        select_columns = context.select_columns,
        table_name = context.resource.resource,
        // Each optional clause carries its own newline so absent clauses leave no gap.
        soft_delete_clause = if has_soft_delete(context.resource) {
            "                        AND \"deleted_at\" IS NULL\n"
        } else {
            ""
        },
        filter_clause = if filter_clause.is_empty() {
            String::new()
        } else {
            format!("{filter_clause}\n")
        },
        // A NULL search term disables full-text matching.
        search_clause = search_position
            .map(|(position, expression)| {
                format!(
                    "\n                        AND (${position}::text IS NULL OR to_tsvector('english', {expression}) @@ plainto_tsquery('english', ${position}))"
                )
            })
            .unwrap_or_default(),
        cursor_position = cursor_position,
        primary_key = context.primary_key,
        order_by = order_by,
        limit_position = limit_position,
        args = args.join(",\n                    "),
    )
}
525
/// Renders the offset-paginated branch of a generated list helper.
///
/// Two queries are emitted:
/// * a `COUNT(*)` over the same WHERE predicates, bound with `count_args`;
/// * the page query, with `LIMIT $limit_position OFFSET $offset_position`,
///   bound with `row_args`.
///
/// The result metadata is `{ "offset", "limit", "total" }`.
// Every argument is a separate piece of the rendered SQL/macro call supplied by
// `generate_list_helper`, hence the allow.
#[allow(clippy::too_many_arguments)]
fn generate_offset_query(
    context: &ResourceContext<'_>,
    filter_clause: &str,
    search_position: Option<(usize, &str)>,
    order_by: &str,
    limit_position: usize,
    offset_position: usize,
    count_args: &[String],
    row_args: &[String],
) -> String {
    // `query_scalar!` takes no trailing comma when there are no bind arguments.
    let count_macro_args = if count_args.is_empty() {
        String::new()
    } else {
        format!(
            ",\n                    {}",
            count_args.join(",\n                    ")
        )
    };

    format!(
        r###"                let total = sqlx::query_scalar!(
                    r#"
                    SELECT COUNT(*) as "count!"
                    FROM "{table_name}"
                    WHERE TRUE
{soft_delete_clause}
{filter_clause}{search_clause}
                    "#{count_macro_args}
                )
                .fetch_one(&self.pool)
                .await?;

                let rows = sqlx::query_as!(
                    {record_name},
                    r#"
                    SELECT
                        {select_columns}
                    FROM "{table_name}"
                    WHERE TRUE
{soft_delete_clause}
{filter_clause}{search_clause}
                    ORDER BY
{order_by}
                    LIMIT ${limit_param}
                    OFFSET ${offset_param}
                    "#,
                    {row_args}
                )
                .fetch_all(&self.pool)
                .await?;

                let data = rows
                    .iter()
                    .map(row_from_model)
                    .collect::<Result<Vec<_>, _>>()?;

                Ok((
                    data,
                    serde_json::json!({{
                        "offset": offset,
                        "limit": limit,
                        "total": total
                    }})
                ))"###,
        table_name = context.resource.resource,
        // Each optional clause carries its own newline so absent clauses leave no gap.
        soft_delete_clause = if has_soft_delete(context.resource) {
            "                        AND \"deleted_at\" IS NULL\n"
        } else {
            ""
        },
        filter_clause = if filter_clause.is_empty() {
            String::new()
        } else {
            format!("{filter_clause}\n")
        },
        // A NULL search term disables full-text matching.
        search_clause = search_position
            .map(|(position, expression)| {
                format!(
                    "\n                        AND (${position}::text IS NULL OR to_tsvector('english', {expression}) @@ plainto_tsquery('english', ${position}))"
                )
            })
            .unwrap_or_default(),
        count_macro_args = count_macro_args,
        record_name = context.record_name,
        select_columns = context.select_columns,
        order_by = order_by,
        limit_param = limit_position,
        offset_param = offset_position,
        row_args = row_args.join(",\n                    "),
    )
}
618
619fn generate_filter_predicates(
620    context: &ResourceContext<'_>,
621    positions: &[(String, usize)],
622) -> Result<Vec<String>, String> {
623    positions
624        .iter()
625        .map(|(field_name, position)| {
626            let field = context.resource.schema.get(field_name).ok_or_else(|| {
627                format!(
628                    "Unknown filter field '{field_name}' on resource '{}'",
629                    context.resource.resource
630                )
631            })?;
632            Ok(format!(
633                "                AND (${position}::{cast} IS NULL OR \"{field_name}\" = ${position})",
634                cast = sql_cast_type(field)
635            ))
636        })
637        .collect()
638}
639
640fn generate_order_by(
641    context: &ResourceContext<'_>,
642    sort_fields: &[String],
643    positions: &[(usize, usize)],
644) -> Result<String, String> {
645    if sort_fields.is_empty() {
646        return Ok(format!("\"{}\" ASC", context.primary_key));
647    }
648
649    let mut clauses = Vec::new();
650    for ((field_param, direction_param), field_name) in positions.iter().zip(sort_fields) {
651        for candidate in sort_fields {
652            let field = context.resource.schema.get(candidate).ok_or_else(|| {
653                format!(
654                    "Unknown sort field '{candidate}' on resource '{}'",
655                    context.resource.resource
656                )
657            })?;
658            let sort_expr = sortable_expression(candidate, field);
659            clauses.push(format!(
660                "                CASE WHEN ${field_param}::text = '{candidate}' AND ${direction_param}::text = 'asc' THEN {sort_expr} END ASC"
661            ));
662            clauses.push(format!(
663                "                CASE WHEN ${field_param}::text = '{candidate}' AND ${direction_param}::text = 'desc' THEN {sort_expr} END DESC"
664            ));
665        }
666        let _ = field_name;
667    }
668    clauses.push(format!("                \"{}\" ASC", context.primary_key));
669    Ok(clauses.join(",\n"))
670}
671
672fn generate_insert_body(context: &ResourceContext<'_>) -> Result<String, String> {
673    let mut declarations = Vec::new();
674    let mut columns = Vec::new();
675    let mut values = Vec::new();
676    let mut args = Vec::new();
677
678    for (index, (field_name, field)) in context.resource.schema.iter().enumerate() {
679        let variable_name = sanitize_identifier(field_name);
680        declarations.push(generate_insert_declaration(
681            field_name,
682            field,
683            &variable_name,
684        )?);
685        columns.push(format!("\"{field_name}\""));
686        values.push(format!("${}", index + 1));
687        args.push(variable_name);
688    }
689
690    Ok(format!(
691        r###"{declarations}
692        let row = sqlx::query_as!(
693            {record_name},
694            r#"
695            INSERT INTO "{table_name}" ({columns})
696            VALUES ({values})
697            RETURNING
698                {select_columns}
699            "#,
700            {args}
701        )
702        .fetch_one(&self.pool)
703        .await?;
704
705        row_from_model(&row)"###,
706        declarations = declarations.join("\n"),
707        record_name = context.record_name,
708        table_name = context.resource.resource,
709        columns = columns.join(", "),
710        values = values.join(", "),
711        select_columns = context.select_columns,
712        args = args.join(",\n            "),
713    ))
714}
715
716fn generate_update_body(context: &ResourceContext<'_>) -> Result<String, String> {
717    let mut declarations = Vec::new();
718    let mut set_clauses = Vec::new();
719    let mut args = vec!["id".to_string()];
720    let mut has_mutable_fields = Vec::new();
721    let mut index = 2usize;
722
723    for (field_name, field) in &context.resource.schema {
724        if field.primary || field.generated {
725            continue;
726        }
727
728        let present_name = format!("{}_present", sanitize_identifier(field_name));
729        let value_name = sanitize_identifier(field_name);
730        declarations.push(generate_update_declaration(
731            field_name,
732            field,
733            &present_name,
734            &value_name,
735        ));
736        has_mutable_fields.push(present_name.clone());
737        set_clauses.push(format!(
738            "\"{field_name}\" = CASE WHEN ${present_param} THEN ${value_param} ELSE \"{field_name}\" END",
739            present_param = index,
740            value_param = index + 1
741        ));
742        args.push(present_name);
743        args.push(value_name);
744        index += 2;
745    }
746
747    if let Some(updated_at) = context.resource.schema.get("updated_at") {
748        if updated_at.generated && updated_at.field_type == FieldType::Timestamp {
749            declarations.push("        let updated_at = chrono::Utc::now();".to_string());
750            set_clauses.push(format!("\"updated_at\" = ${index}"));
751            args.push("updated_at".to_string());
752        }
753    }
754
755    let guard = if has_mutable_fields.is_empty() {
756        String::new()
757    } else {
758        format!(
759            "        if !({}) {}",
760            has_mutable_fields.join(" || "),
761            r#"{
762            return Err(shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {
763                field: "body".to_string(),
764                message: "No valid fields to update".to_string(),
765                code: "empty_update".to_string(),
766            }]));
767        }"#
768        )
769    };
770
771    Ok(format!(
772        r###"{declarations}
773{guard}
774        let row = sqlx::query_as!(
775            {record_name},
776            r#"
777            UPDATE "{table_name}"
778            SET {set_clauses}
779            WHERE "{primary_key}" = $1{soft_delete_where}
780            RETURNING
781                {select_columns}
782            "#,
783            {args}
784        )
785        .fetch_optional(&self.pool)
786        .await?
787        .ok_or(shaperail_core::ShaperailError::NotFound)?;
788
789        row_from_model(&row)"###,
790        declarations = declarations.join("\n"),
791        guard = guard,
792        record_name = context.record_name,
793        table_name = context.resource.resource,
794        set_clauses = set_clauses.join(", "),
795        primary_key = context.primary_key,
796        soft_delete_where = context.soft_delete_where,
797        select_columns = context.select_columns,
798        args = args.join(",\n            "),
799    ))
800}
801
802fn generate_soft_delete_body(context: &ResourceContext<'_>) -> String {
803    format!(
804        r###"        let deleted_at = chrono::Utc::now();
805        let row = sqlx::query_as!(
806            {record_name},
807            r#"
808            UPDATE "{table_name}"
809            SET "deleted_at" = $2
810            WHERE "{primary_key}" = $1 AND "deleted_at" IS NULL
811            RETURNING
812                {select_columns}
813            "#,
814            id,
815            deleted_at
816        )
817        .fetch_optional(&self.pool)
818        .await?
819        .ok_or(shaperail_core::ShaperailError::NotFound)?;
820
821        row_from_model(&row)"###,
822        record_name = context.record_name,
823        table_name = context.resource.resource,
824        primary_key = context.primary_key,
825        select_columns = context.select_columns,
826    )
827}
828
829fn generate_hard_delete_body(context: &ResourceContext<'_>) -> String {
830    format!(
831        r###"        let row = sqlx::query_as!(
832            {record_name},
833            r#"
834            DELETE FROM "{table_name}"
835            WHERE "{primary_key}" = $1
836            RETURNING
837                {select_columns}
838            "#,
839            id
840        )
841        .fetch_optional(&self.pool)
842        .await?
843        .ok_or(shaperail_core::ShaperailError::NotFound)?;
844
845        row_from_model(&row)"###,
846        record_name = context.record_name,
847        table_name = context.resource.resource,
848        primary_key = context.primary_key,
849        select_columns = context.select_columns,
850    )
851}
852
853fn generate_insert_declaration(
854    field_name: &str,
855    field: &FieldSchema,
856    variable_name: &str,
857) -> Result<String, String> {
858    if field.generated {
859        return Ok(format!(
860            "        let {variable_name} = {};",
861            generated_value_expression(field)
862        ));
863    }
864
865    let parse_type = parse_type(field);
866    let parsed = format!(
867        "shaperail_runtime::db::parse_optional_json::<{parse_type}>(data, {field_name:?})?"
868    );
869
870    let expression = match (field_is_required(field), field.default.as_ref()) {
871        (true, Some(default)) => format!(
872            "match {parsed} {{ Some(value) => value, None => {} }}",
873            default_expression(field_name, field, default)?
874        ),
875        (true, None) => format!("shaperail_runtime::db::require_field({parsed}, {field_name:?})?"),
876        (false, Some(default)) if model_field_is_optional(field) => format!(
877            "match {parsed} {{ Some(value) => Some(value), None => Some({}) }}",
878            default_expression(field_name, field, default)?
879        ),
880        (false, Some(default)) => format!(
881            "match {parsed} {{ Some(value) => value, None => {} }}",
882            default_expression(field_name, field, default)?
883        ),
884        (false, None) => parsed,
885    };
886
887    Ok(format!("        let {variable_name} = {expression};"))
888}
889
890fn generate_update_declaration(
891    field_name: &str,
892    field: &FieldSchema,
893    present_name: &str,
894    value_name: &str,
895) -> String {
896    format!(
897        "        let {present_name} = data.contains_key({field_name:?});\n        let {value_name} = shaperail_runtime::db::parse_optional_json::<{parse_type}>(data, {field_name:?})?;",
898        parse_type = parse_type(field)
899    )
900}
901
902fn generate_filter_declaration(field_name: &str, field: &FieldSchema) -> String {
903    let parser = match field.field_type {
904        FieldType::Uuid => "uuid::Uuid::parse_str(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid uuid filter\".to_string()))",
905        FieldType::String | FieldType::Enum | FieldType::File => "Ok(text.to_string())",
906        FieldType::Integer => "text.parse::<i32>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid integer filter\".to_string()))",
907        FieldType::Bigint => "text.parse::<i64>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid bigint filter\".to_string()))",
908        FieldType::Number => "text.parse::<f64>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid number filter\".to_string()))",
909        FieldType::Boolean => "text.parse::<bool>().map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid boolean filter\".to_string()))",
910        FieldType::Timestamp => "chrono::DateTime::parse_from_rfc3339(text).map(|value| value.with_timezone(&chrono::Utc)).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid timestamp filter\".to_string()))",
911        FieldType::Date => "chrono::NaiveDate::parse_from_str(text, \"%Y-%m-%d\").map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid date filter\".to_string()))",
912        FieldType::Json => "serde_json::from_str::<serde_json::Value>(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid json filter\".to_string()))",
913        FieldType::Array => "serde_json::from_str::<Vec<serde_json::Value>>(text).map_err(|_| shaperail_core::ShaperailError::Internal(\"invalid array filter\".to_string()))",
914    };
915
916    format!(
917        "        let {var} = parse_filter(filters, {field_name:?}, \"invalid_filter\", |text| {parser})?;",
918        var = field_parameter_name(field_name)
919    )
920}
921
922fn field_parameter_name(field_name: &str) -> String {
923    format!("filter_{}", sanitize_identifier(field_name))
924}
925
926fn parameter_expression(field_name: &str, field: &FieldSchema) -> String {
927    let var = field_parameter_name(field_name);
928    match field.field_type {
929        FieldType::String | FieldType::Enum | FieldType::File => format!("{var}.as_deref()"),
930        _ => var,
931    }
932}
933
934fn select_column_sql(field_name: &str, field: &FieldSchema) -> String {
935    let nullability = if model_field_is_optional(field) {
936        "?"
937    } else {
938        "!"
939    };
940    let expression = match field.field_type {
941        FieldType::Number => format!("\"{field_name}\"::DOUBLE PRECISION"),
942        _ => format!("\"{field_name}\""),
943    };
944    format!(
945        "{expression} as \"{field_name}{nullability}: {type_name}\"",
946        type_name = query_type(field)
947    )
948}
949
950fn sortable_expression(field_name: &str, field: &FieldSchema) -> String {
951    match field.field_type {
952        FieldType::Json | FieldType::Array | FieldType::Uuid => format!("\"{field_name}\"::text"),
953        FieldType::Number => format!("\"{field_name}\"::DOUBLE PRECISION"),
954        _ => format!("\"{field_name}\""),
955    }
956}
957
/// Builds a SQL text expression that concatenates the given columns, separated by single
/// spaces, for full-text style searching.
///
/// Each column is cast to `text` and `COALESCE`d to `''`, so a NULL column does not null out
/// the whole string. An empty `fields` slice yields an empty string.
fn search_expression(fields: &[String]) -> String {
    let mut sql = String::new();
    for (position, column) in fields.iter().enumerate() {
        if position > 0 {
            sql.push_str(" || ' ' || ");
        }
        sql.push_str("COALESCE(\"");
        sql.push_str(column);
        sql.push_str("\"::text, '')");
    }
    sql
}
965
966fn sql_cast_type(field: &FieldSchema) -> String {
967    match field.field_type {
968        FieldType::Uuid => "uuid".to_string(),
969        FieldType::String | FieldType::Enum | FieldType::File => "text".to_string(),
970        FieldType::Integer => "integer".to_string(),
971        FieldType::Bigint => "bigint".to_string(),
972        FieldType::Number => "double precision".to_string(),
973        FieldType::Boolean => "boolean".to_string(),
974        FieldType::Timestamp => "timestamptz".to_string(),
975        FieldType::Date => "date".to_string(),
976        FieldType::Json => "jsonb".to_string(),
977        FieldType::Array => match field.items.as_deref() {
978            Some("uuid") => "uuid[]".to_string(),
979            Some("integer") => "integer[]".to_string(),
980            Some("bigint") => "bigint[]".to_string(),
981            Some("number") => "double precision[]".to_string(),
982            Some("boolean") => "boolean[]".to_string(),
983            _ => "text[]".to_string(),
984        },
985    }
986}
987
988fn query_type(field: &FieldSchema) -> String {
989    match field.field_type {
990        FieldType::Uuid => "uuid::Uuid".to_string(),
991        FieldType::String | FieldType::Enum | FieldType::File => "String".to_string(),
992        FieldType::Integer => "i32".to_string(),
993        FieldType::Bigint => "i64".to_string(),
994        FieldType::Number => "f64".to_string(),
995        FieldType::Boolean => "bool".to_string(),
996        FieldType::Timestamp => "chrono::DateTime<chrono::Utc>".to_string(),
997        FieldType::Date => "chrono::NaiveDate".to_string(),
998        FieldType::Json => "serde_json::Value".to_string(),
999        FieldType::Array => match field.items.as_deref() {
1000            Some("uuid") => "Vec<uuid::Uuid>".to_string(),
1001            Some("integer") => "Vec<i32>".to_string(),
1002            Some("bigint") => "Vec<i64>".to_string(),
1003            Some("number") => "Vec<f64>".to_string(),
1004            Some("boolean") => "Vec<bool>".to_string(),
1005            Some("timestamp") => "Vec<chrono::DateTime<chrono::Utc>>".to_string(),
1006            Some("date") => "Vec<chrono::NaiveDate>".to_string(),
1007            _ => "Vec<String>".to_string(),
1008        },
1009    }
1010}
1011
/// Rust type used as the turbofish argument of `parse_optional_json` when decoding a field from
/// the request body.
///
/// Request values are decoded into the same type the query layer uses, so this defers to
/// `query_type`. It is kept as a separate hook so the parse type can diverge later.
fn parse_type(field: &FieldSchema) -> String {
    query_type(field)
}
1015
1016fn model_field_type(field: &FieldSchema) -> String {
1017    let base = query_type(field);
1018    if model_field_is_optional(field) {
1019        format!("Option<{base}>")
1020    } else {
1021        base
1022    }
1023}
1024
1025fn model_field_is_optional(field: &FieldSchema) -> bool {
1026    !(field.primary || (field.required && !field.nullable))
1027}
1028
1029fn field_is_required(field: &FieldSchema) -> bool {
1030    field.primary || (field.required && !field.nullable)
1031}
1032
1033fn generated_value_expression(field: &FieldSchema) -> String {
1034    match field.field_type {
1035        FieldType::Uuid => "uuid::Uuid::new_v4()".to_string(),
1036        FieldType::Timestamp => {
1037            if model_field_is_optional(field) {
1038                "Some(chrono::Utc::now())".to_string()
1039            } else {
1040                "chrono::Utc::now()".to_string()
1041            }
1042        }
1043        FieldType::Date => {
1044            if model_field_is_optional(field) {
1045                "Some(chrono::Utc::now().date_naive())".to_string()
1046            } else {
1047                "chrono::Utc::now().date_naive()".to_string()
1048            }
1049        }
1050        _ => "Default::default()".to_string(),
1051    }
1052}
1053
1054fn default_expression(
1055    field_name: &str,
1056    field: &FieldSchema,
1057    default: &serde_json::Value,
1058) -> Result<String, String> {
1059    Ok(match field.field_type {
1060        FieldType::Uuid => format!(
1061            "parse_embedded_json::<uuid::Uuid>({field_name:?}, serde_json::json!({default}))?"
1062        ),
1063        FieldType::String | FieldType::Enum | FieldType::File => {
1064            let value = default
1065                .as_str()
1066                .ok_or_else(|| format!("Default for '{field_name}' must be a string"))?;
1067            format!("{value:?}.to_string()")
1068        }
1069        FieldType::Integer => format!(
1070            "parse_embedded_json::<i32>({field_name:?}, serde_json::json!({default}))?"
1071        ),
1072        FieldType::Bigint => format!(
1073            "parse_embedded_json::<i64>({field_name:?}, serde_json::json!({default}))?"
1074        ),
1075        FieldType::Number => format!(
1076            "parse_embedded_json::<f64>({field_name:?}, serde_json::json!({default}))?"
1077        ),
1078        FieldType::Boolean => default
1079            .as_bool()
1080            .ok_or_else(|| format!("Default for '{field_name}' must be a boolean"))?
1081            .to_string(),
1082        FieldType::Timestamp => format!(
1083            "parse_embedded_json::<chrono::DateTime<chrono::Utc>>({field_name:?}, serde_json::json!({default}))?"
1084        ),
1085        FieldType::Date => format!(
1086            "parse_embedded_json::<chrono::NaiveDate>({field_name:?}, serde_json::json!({default}))?"
1087        ),
1088        FieldType::Json => format!("serde_json::json!({default})"),
1089        FieldType::Array => format!(
1090            "parse_embedded_json::<{}>({field_name:?}, serde_json::json!({default}))?",
1091            query_type(field)
1092        ),
1093    })
1094}
1095
1096fn has_soft_delete(resource: &ResourceDefinition) -> bool {
1097    resource
1098        .endpoints
1099        .as_ref()
1100        .map(|endpoints| endpoints.values().any(|endpoint| endpoint.soft_delete))
1101        .unwrap_or(false)
1102}
1103
/// Turns an arbitrary name into an identifier-safe lowercase fragment.
///
/// ASCII letters and digits are lowercased. Every other character (including each non-ASCII
/// char) becomes `_`. A leading digit gets a `_` prefix.
fn sanitize_identifier(value: &str) -> String {
    let cleaned: String = value
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();

    match cleaned.as_bytes().first() {
        Some(first) if first.is_ascii_digit() => format!("_{cleaned}"),
        _ => cleaned,
    }
}
1120
/// Converts a snake_case name into PascalCase (e.g. `user_profiles` -> `UserProfiles`).
///
/// Empty segments from repeated, leading or trailing underscores are skipped. Only the first
/// character of each segment is uppercased (using full Unicode uppercasing); the rest is kept
/// verbatim.
fn to_pascal_case(value: &str) -> String {
    let mut result = String::with_capacity(value.len());
    for word in value.split('_') {
        let mut letters = word.chars();
        if let Some(head) = letters.next() {
            result.extend(head.to_uppercase());
            result.push_str(letters.as_str());
        }
    }
    result
}
1139
/// Prefixes every non-empty line of `block` with `indent` levels of four spaces.
///
/// Empty lines stay empty, with no trailing whitespace. A block containing only whitespace
/// yields an empty string. Lines are re-joined with `\n`, without a trailing newline.
fn indent_block(block: &str, indent: usize) -> String {
    if block.trim().is_empty() {
        return String::new();
    }

    let prefix = "    ".repeat(indent);
    let mut output = String::with_capacity(block.len());
    for (index, line) in block.lines().enumerate() {
        if index > 0 {
            output.push('\n');
        }
        if !line.is_empty() {
            output.push_str(&prefix);
            output.push_str(line);
        }
    }
    output
}
1158
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use shaperail_core::{
        AuthRule, EndpointSpec, FieldSchema, HttpMethod, PaginationStyle, ResourceDefinition,
    };

    /// Baseline field of the given type with every flag off and every option unset; fixtures
    /// override only what they need via struct-update syntax.
    fn field(field_type: FieldType) -> FieldSchema {
        FieldSchema {
            field_type,
            primary: false,
            generated: false,
            required: false,
            unique: false,
            nullable: false,
            reference: None,
            min: None,
            max: None,
            format: None,
            values: None,
            default: None,
            sensitive: false,
            search: false,
            items: None,
        }
    }

    fn sample_resource() -> ResourceDefinition {
        let mut schema = IndexMap::new();
        schema.insert(
            "id".to_string(),
            FieldSchema {
                primary: true,
                generated: true,
                ..field(FieldType::Uuid)
            },
        );
        schema.insert(
            "email".to_string(),
            FieldSchema {
                required: true,
                unique: true,
                search: true,
                ..field(FieldType::String)
            },
        );
        schema.insert(
            "created_at".to_string(),
            FieldSchema {
                generated: true,
                ..field(FieldType::Timestamp)
            },
        );

        let mut endpoints = IndexMap::new();
        endpoints.insert(
            "list".to_string(),
            EndpointSpec {
                method: HttpMethod::Get,
                path: "/users".to_string(),
                auth: Some(AuthRule::Public),
                input: None,
                filters: Some(vec!["email".to_string()]),
                search: Some(vec!["email".to_string()]),
                pagination: Some(PaginationStyle::Cursor),
                sort: Some(vec!["created_at".to_string()]),
                cache: None,
                controller: None,
                events: None,
                jobs: None,
                upload: None,
                soft_delete: false,
            },
        );

        ResourceDefinition {
            resource: "users".to_string(),
            version: 1,
            db: None,
            schema,
            endpoints: Some(endpoints),
            relations: None,
            indexes: None,
        }
    }

    #[test]
    fn generates_query_as_store_module() {
        let resource = sample_resource();
        let code = generate_resource_module(&resource).unwrap();

        assert!(code.contains("impl ResourceStore for UsersStore"));
        assert!(code.contains("sqlx::query_as!"));
        assert!(code.contains("find_all_list"));
    }

    #[test]
    fn generates_registry_module() {
        let resource = sample_resource();
        let project = generate_project(&[resource]).unwrap();

        assert!(project.mod_rs.contains("pub mod users;"));
        assert!(project.mod_rs.contains("build_store_registry"));
    }

    #[test]
    fn required_and_model_optional_are_complementary() {
        // Insert generation relies on these two predicates never agreeing.
        let cases = [
            field(FieldType::String),
            FieldSchema {
                primary: true,
                ..field(FieldType::Uuid)
            },
            FieldSchema {
                required: true,
                ..field(FieldType::Integer)
            },
            FieldSchema {
                required: true,
                nullable: true,
                ..field(FieldType::Integer)
            },
        ];
        for case in &cases {
            assert_eq!(field_is_required(case), !model_field_is_optional(case));
        }
    }

    #[test]
    fn sanitizes_identifiers() {
        assert_eq!(sanitize_identifier("Created-At"), "created_at");
        assert_eq!(sanitize_identifier("2fa"), "_2fa");
        assert_eq!(field_parameter_name("Created-At"), "filter_created_at");
    }

    #[test]
    fn converts_to_pascal_case() {
        assert_eq!(to_pascal_case("user_profiles"), "UserProfiles");
        assert_eq!(to_pascal_case("__users__"), "Users");
    }

    #[test]
    fn indents_non_empty_lines_only() {
        assert_eq!(indent_block("a\n\nb", 1), "    a\n\n    b");
        assert_eq!(indent_block("   \n", 2), "");
    }

    #[test]
    fn builds_search_expression() {
        assert_eq!(search_expression(&[]), "");
        assert_eq!(
            search_expression(&["first".to_string(), "last".to_string()]),
            "COALESCE(\"first\"::text, '') || ' ' || COALESCE(\"last\"::text, '')"
        );
    }
}
1276
1277        assert!(project.mod_rs.contains("pub mod users;"));
1278        assert!(project.mod_rs.contains("build_store_registry"));
1279    }
1280}