Skip to main content

shaperail_runtime/db/
query.rs

1use serde::{Deserialize, Serialize};
2use shaperail_core::{FieldSchema, FieldType, ResourceDefinition, ShaperailError};
3use sqlx::postgres::PgRow;
4use sqlx::{PgPool, Row};
5
6use super::filter::FilterSet;
7use super::pagination::{decode_cursor, encode_cursor, PageRequest};
8use super::search::SearchParam;
9use super::sort::SortParam;
10
11/// A single row returned from a resource query, represented as a JSON object.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ResourceRow(pub serde_json::Value);
14
15/// Dynamic query executor for a resource defined by a `ResourceDefinition`.
16///
17/// Generates and executes parameterized SQL queries against a PgPool,
18/// returning results as `ResourceRow` (JSON objects).
19///
20/// # SQL injection safety
21///
22/// All user-controllable input (filter values, search term, cursor, pagination
23/// offset/limit, insert/update body) is passed only as bound parameters via
24/// `BindValue` and `query.bind()`. Table and column names in the SQL string
25/// come solely from `ResourceDefinition` (trusted schema). Filter/sort field
26/// names are allow-listed (FilterSet::from_query_params, SortParam::parse).
27/// See db_integration tests `test_sql_injection_*` for verification.
28pub struct ResourceQuery<'a> {
29    pub resource: &'a ResourceDefinition,
30    pub pool: &'a PgPool,
31}
32
33impl<'a> ResourceQuery<'a> {
34    pub fn new(resource: &'a ResourceDefinition, pool: &'a PgPool) -> Self {
35        Self { resource, pool }
36    }
37
38    /// Returns the table name (same as resource name).
39    fn table(&self) -> &str {
40        &self.resource.resource
41    }
42
43    /// Returns all column names from the schema.
44    fn columns(&self) -> Vec<&str> {
45        self.resource.schema.keys().map(|k| k.as_str()).collect()
46    }
47
48    /// Returns the primary key field name.
49    fn primary_key(&self) -> &str {
50        self.resource
51            .schema
52            .iter()
53            .find(|(_, fs)| fs.primary)
54            .map(|(name, _)| name.as_str())
55            .unwrap_or("id")
56    }
57
58    /// Builds a SELECT column list.
59    fn select_columns(&self) -> String {
60        self.columns()
61            .iter()
62            .map(|c| format!("\"{c}\""))
63            .collect::<Vec<_>>()
64            .join(", ")
65    }
66
67    /// Converts a `PgRow` to a `serde_json::Value` object based on schema field types.
68    fn row_to_json(&self, row: &PgRow) -> Result<serde_json::Value, ShaperailError> {
69        let mut obj = serde_json::Map::new();
70        for (name, field) in &self.resource.schema {
71            let value = extract_column_value(row, name, field)?;
72            obj.insert(name.clone(), value);
73        }
74        Ok(serde_json::Value::Object(obj))
75    }
76
77    /// Returns `true` if any endpoint on this resource has `soft_delete: true`.
78    fn has_soft_delete(&self) -> bool {
79        self.resource
80            .endpoints
81            .as_ref()
82            .map(|eps| eps.values().any(|ep| ep.soft_delete))
83            .unwrap_or(false)
84    }
85
86    // -- Query methods --
87
88    /// Find a single record by its primary key.
89    pub async fn find_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, ShaperailError> {
90        let pk = self.primary_key();
91        let soft_delete_clause = if self.has_soft_delete() {
92            " AND \"deleted_at\" IS NULL"
93        } else {
94            ""
95        };
96        let sql = format!(
97            "SELECT {} FROM \"{}\" WHERE \"{}\" = $1{soft_delete_clause}",
98            self.select_columns(),
99            self.table(),
100            pk,
101        );
102
103        let row = sqlx::query(&sql)
104            .bind(id)
105            .fetch_optional(self.pool)
106            .await?
107            .ok_or(ShaperailError::NotFound)?;
108
109        let json = self.row_to_json(&row)?;
110        Ok(ResourceRow(json))
111    }
112
113    /// Find all records with filtering, searching, sorting, and pagination.
114    ///
115    /// Returns `(rows, cursor_page)` for cursor pagination or `(rows, offset_page)` for offset.
116    pub async fn find_all(
117        &self,
118        filters: &FilterSet,
119        search: Option<&SearchParam>,
120        sort: &SortParam,
121        page: &PageRequest,
122    ) -> Result<(Vec<ResourceRow>, serde_json::Value), ShaperailError> {
123        let mut sql = format!("SELECT {} FROM \"{}\"", self.select_columns(), self.table());
124        let mut has_where = false;
125        let mut param_offset: usize = 1;
126        let mut bind_values: Vec<BindValue> = Vec::new();
127
128        // Exclude soft-deleted rows
129        if self.has_soft_delete() {
130            sql.push_str(" WHERE \"deleted_at\" IS NULL");
131            has_where = true;
132        }
133
134        // Apply filters
135        if !filters.is_empty() {
136            param_offset = filters.apply_to_sql(&mut sql, has_where, param_offset);
137            has_where = true;
138            for f in &filters.filters {
139                bind_values.push(self.coerce_filter_value(&f.field, &f.value));
140            }
141        }
142
143        // Apply search
144        if let Some(sp) = search {
145            param_offset = sp.apply_to_sql(&mut sql, has_where, param_offset);
146            has_where = true;
147            bind_values.push(BindValue::Text(sp.term.clone()));
148        }
149
150        // Apply cursor/offset pagination
151        match page {
152            PageRequest::Cursor { after, limit } => {
153                let decoded_cursor = if let Some(cursor_str) = after {
154                    let id_str = decode_cursor(cursor_str)?;
155                    let id = uuid::Uuid::parse_str(&id_str).map_err(|_| {
156                        ShaperailError::Validation(vec![shaperail_core::FieldError {
157                            field: "cursor".to_string(),
158                            message: "Invalid cursor value".to_string(),
159                            code: "invalid_cursor".to_string(),
160                        }])
161                    })?;
162                    Some(id)
163                } else {
164                    None
165                };
166
167                if decoded_cursor.is_some() {
168                    if has_where {
169                        sql.push_str(" AND ");
170                    } else {
171                        sql.push_str(" WHERE ");
172                    }
173                    sql.push_str(&format!("\"id\" > ${param_offset}"));
174                    bind_values.push(BindValue::Uuid(decoded_cursor.unwrap_or_default()));
175                }
176
177                // Apply sort or default to id ASC for cursor pagination
178                if sort.is_empty() {
179                    sql.push_str(" ORDER BY \"id\" ASC");
180                } else {
181                    sort.apply_to_sql(&mut sql);
182                }
183                sql.push_str(&format!(" LIMIT {}", limit + 1));
184
185                let rows = self.execute_query(&sql, &bind_values).await?;
186
187                let has_more = rows.len() as i64 > *limit;
188                let result_rows: Vec<ResourceRow> =
189                    rows.into_iter().take(*limit as usize).collect();
190
191                let cursor = if has_more {
192                    result_rows
193                        .last()
194                        .and_then(|r| r.0.get("id"))
195                        .and_then(|v| v.as_str())
196                        .map(encode_cursor)
197                } else {
198                    None
199                };
200
201                let meta = serde_json::json!({
202                    "cursor": cursor,
203                    "has_more": has_more,
204                });
205                Ok((result_rows, meta))
206            }
207            PageRequest::Offset { offset, limit } => {
208                // For offset pagination, get total count
209                let mut count_sql = format!("SELECT COUNT(*) FROM \"{}\"", self.table());
210                let mut count_has_where = false;
211                let mut count_offset: usize = 1;
212                let mut count_binds: Vec<BindValue> = Vec::new();
213
214                // Exclude soft-deleted rows
215                if self.has_soft_delete() {
216                    count_sql.push_str(" WHERE \"deleted_at\" IS NULL");
217                    count_has_where = true;
218                }
219
220                if !filters.is_empty() {
221                    count_offset =
222                        filters.apply_to_sql(&mut count_sql, count_has_where, count_offset);
223                    count_has_where = true;
224                    for f in &filters.filters {
225                        count_binds.push(self.coerce_filter_value(&f.field, &f.value));
226                    }
227                }
228                if let Some(sp) = search {
229                    sp.apply_to_sql(&mut count_sql, count_has_where, count_offset);
230                    count_binds.push(BindValue::Text(sp.term.clone()));
231                }
232
233                let total = self.execute_count(&count_sql, &count_binds).await?;
234
235                // Apply sort
236                if !sort.is_empty() {
237                    sort.apply_to_sql(&mut sql);
238                }
239                sql.push_str(&format!(" LIMIT {limit} OFFSET {offset}"));
240
241                let rows = self.execute_query(&sql, &bind_values).await?;
242
243                let meta = serde_json::json!({
244                    "offset": offset,
245                    "limit": limit,
246                    "total": total,
247                });
248                Ok((rows, meta))
249            }
250        }
251    }
252
253    /// Insert a new record. Returns the inserted row.
254    pub async fn insert(
255        &self,
256        data: &serde_json::Map<String, serde_json::Value>,
257    ) -> Result<ResourceRow, ShaperailError> {
258        let mut columns = Vec::new();
259        let mut placeholders = Vec::new();
260        let mut bind_values = Vec::new();
261        let mut idx = 1usize;
262
263        // Add generated fields
264        for (name, field) in &self.resource.schema {
265            if field.generated {
266                match field.field_type {
267                    FieldType::Uuid => {
268                        columns.push(format!("\"{name}\""));
269                        placeholders.push(format!("${idx}"));
270                        bind_values.push(BindValue::Uuid(uuid::Uuid::new_v4()));
271                        idx += 1;
272                    }
273                    FieldType::Timestamp => {
274                        columns.push(format!("\"{name}\""));
275                        placeholders.push(format!("${idx}"));
276                        bind_values.push(BindValue::Timestamp(chrono::Utc::now()));
277                        idx += 1;
278                    }
279                    _ => {}
280                }
281                continue;
282            }
283
284            if let Some(value) = data.get(name) {
285                columns.push(format!("\"{name}\""));
286                placeholders.push(format!("${idx}"));
287                bind_values.push(json_to_bind(value, field));
288                idx += 1;
289            } else if let Some(default) = &field.default {
290                columns.push(format!("\"{name}\""));
291                placeholders.push(format!("${idx}"));
292                bind_values.push(json_to_bind(default, field));
293                idx += 1;
294            }
295        }
296
297        let sql = format!(
298            "INSERT INTO \"{}\" ({}) VALUES ({}) RETURNING {}",
299            self.table(),
300            columns.join(", "),
301            placeholders.join(", "),
302            self.select_columns(),
303        );
304
305        let rows = self.execute_query(&sql, &bind_values).await?;
306        rows.into_iter()
307            .next()
308            .ok_or_else(|| ShaperailError::Internal("Insert returned no rows".to_string()))
309    }
310
311    /// Update a record by primary key. Returns the updated row.
312    pub async fn update_by_id(
313        &self,
314        id: &uuid::Uuid,
315        data: &serde_json::Map<String, serde_json::Value>,
316    ) -> Result<ResourceRow, ShaperailError> {
317        let mut set_clauses = Vec::new();
318        let mut bind_values = Vec::new();
319        let mut idx = 1usize;
320
321        for (name, value) in data {
322            if let Some(field) = self.resource.schema.get(name) {
323                if field.primary || field.generated {
324                    continue;
325                }
326                set_clauses.push(format!("\"{name}\" = ${idx}"));
327                bind_values.push(json_to_bind(value, field));
328                idx += 1;
329            }
330        }
331
332        // Auto-update updated_at if it exists and is generated
333        if let Some(field) = self.resource.schema.get("updated_at") {
334            if field.generated && field.field_type == FieldType::Timestamp {
335                set_clauses.push(format!("\"updated_at\" = ${idx}"));
336                bind_values.push(BindValue::Timestamp(chrono::Utc::now()));
337                idx += 1;
338            }
339        }
340
341        if set_clauses.is_empty() {
342            return Err(ShaperailError::Validation(vec![
343                shaperail_core::FieldError {
344                    field: "body".to_string(),
345                    message: "No valid fields to update".to_string(),
346                    code: "empty_update".to_string(),
347                },
348            ]));
349        }
350
351        let pk = self.primary_key();
352        let soft_delete_clause = if self.has_soft_delete() {
353            " AND \"deleted_at\" IS NULL"
354        } else {
355            ""
356        };
357        let sql = format!(
358            "UPDATE \"{}\" SET {} WHERE \"{}\" = ${}{soft_delete_clause} RETURNING {}",
359            self.table(),
360            set_clauses.join(", "),
361            pk,
362            idx,
363            self.select_columns(),
364        );
365        bind_values.push(BindValue::Uuid(*id));
366
367        let rows = self.execute_query(&sql, &bind_values).await?;
368        rows.into_iter().next().ok_or(ShaperailError::NotFound)
369    }
370
371    /// Soft-delete a record by setting `deleted_at` to now.
372    pub async fn soft_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, ShaperailError> {
373        let pk = self.primary_key();
374        let sql = format!(
375            "UPDATE \"{}\" SET \"deleted_at\" = $1 WHERE \"{}\" = $2 AND \"deleted_at\" IS NULL RETURNING {}",
376            self.table(),
377            pk,
378            self.select_columns(),
379        );
380
381        let bind_values = vec![
382            BindValue::Timestamp(chrono::Utc::now()),
383            BindValue::Uuid(*id),
384        ];
385
386        let rows = self.execute_query(&sql, &bind_values).await?;
387        rows.into_iter().next().ok_or(ShaperailError::NotFound)
388    }
389
390    /// Hard-delete a record permanently.
391    pub async fn hard_delete_by_id(&self, id: &uuid::Uuid) -> Result<ResourceRow, ShaperailError> {
392        let pk = self.primary_key();
393        let sql = format!(
394            "DELETE FROM \"{}\" WHERE \"{}\" = $1 RETURNING {}",
395            self.table(),
396            pk,
397            self.select_columns(),
398        );
399
400        let bind_values = vec![BindValue::Uuid(*id)];
401        let rows = self.execute_query(&sql, &bind_values).await?;
402        rows.into_iter().next().ok_or(ShaperailError::NotFound)
403    }
404
405    // -- Internal helpers --
406
407    /// Coerces a filter string value to the correct `BindValue` based on the field's schema type.
408    fn coerce_filter_value(&self, field_name: &str, value: &str) -> BindValue {
409        if let Some(field) = self.resource.schema.get(field_name) {
410            match field.field_type {
411                FieldType::Uuid => {
412                    if let Ok(u) = uuid::Uuid::parse_str(value) {
413                        return BindValue::Uuid(u);
414                    }
415                }
416                FieldType::Integer => {
417                    if let Ok(n) = value.parse::<i32>() {
418                        return BindValue::Int(n);
419                    }
420                }
421                FieldType::Bigint => {
422                    if let Ok(n) = value.parse::<i64>() {
423                        return BindValue::Bigint(n);
424                    }
425                }
426                FieldType::Number => {
427                    if let Ok(n) = value.parse::<f64>() {
428                        return BindValue::Float(n);
429                    }
430                }
431                FieldType::Boolean => {
432                    if let Ok(b) = value.parse::<bool>() {
433                        return BindValue::Bool(b);
434                    }
435                }
436                _ => {}
437            }
438        }
439        BindValue::Text(value.to_string())
440    }
441
442    async fn execute_query(
443        &self,
444        sql: &str,
445        binds: &[BindValue],
446    ) -> Result<Vec<ResourceRow>, ShaperailError> {
447        let span = crate::observability::telemetry::db_span("query", self.table(), sql);
448        let _enter = span.enter();
449        let start = std::time::Instant::now();
450
451        // Dynamic queries use sqlx::query() with bind params (not string interpolation).
452        // The query_as! macro requires compile-time SQL; generated code (M04+) will use it.
453        let mut query = sqlx::query(sql);
454        for bind in binds {
455            query = match bind {
456                BindValue::Text(v) => query.bind(v),
457                BindValue::Int(v) => query.bind(v),
458                BindValue::Bigint(v) => query.bind(v),
459                BindValue::Float(v) => query.bind(v),
460                BindValue::Bool(v) => query.bind(v),
461                BindValue::Uuid(v) => query.bind(v),
462                BindValue::Timestamp(v) => query.bind(v),
463                BindValue::Date(v) => query.bind(v),
464                BindValue::Json(v) => query.bind(v),
465                BindValue::Null => query.bind(None::<String>),
466            };
467        }
468
469        let pg_rows = query.fetch_all(self.pool).await?;
470        let duration_ms = start.elapsed().as_millis() as u64;
471        log_slow_query(sql, duration_ms);
472
473        let mut results = Vec::with_capacity(pg_rows.len());
474        for row in &pg_rows {
475            results.push(ResourceRow(self.row_to_json(row)?));
476        }
477        Ok(results)
478    }
479
480    async fn execute_count(&self, sql: &str, binds: &[BindValue]) -> Result<i64, ShaperailError> {
481        let span = crate::observability::telemetry::db_span("count", self.table(), sql);
482        let _enter = span.enter();
483        let start = std::time::Instant::now();
484
485        let mut query = sqlx::query_scalar::<_, i64>(sql);
486        for bind in binds {
487            query = match bind {
488                BindValue::Text(v) => query.bind(v),
489                BindValue::Int(v) => query.bind(v),
490                BindValue::Bigint(v) => query.bind(v),
491                BindValue::Float(v) => query.bind(v),
492                BindValue::Bool(v) => query.bind(v),
493                BindValue::Uuid(v) => query.bind(v),
494                BindValue::Timestamp(v) => query.bind(v),
495                BindValue::Date(v) => query.bind(v),
496                BindValue::Json(v) => query.bind(v),
497                BindValue::Null => query.bind(None::<String>),
498            };
499        }
500        let count = query.fetch_one(self.pool).await?;
501        let duration_ms = start.elapsed().as_millis() as u64;
502        log_slow_query(sql, duration_ms);
503
504        Ok(count)
505    }
506}
507
508/// Internal enum for type-safe query parameter binding.
509#[derive(Debug, Clone)]
510enum BindValue {
511    Text(String),
512    Int(i32),
513    Bigint(i64),
514    Float(f64),
515    Bool(bool),
516    Uuid(uuid::Uuid),
517    Timestamp(chrono::DateTime<chrono::Utc>),
518    Date(chrono::NaiveDate),
519    Json(serde_json::Value),
520    Null,
521}
522
523/// Converts a JSON value to the appropriate `BindValue` based on the field schema.
524fn json_to_bind(value: &serde_json::Value, field: &FieldSchema) -> BindValue {
525    if value.is_null() {
526        return BindValue::Null;
527    }
528    match field.field_type {
529        FieldType::Uuid => {
530            if let Some(s) = value.as_str() {
531                if let Ok(u) = uuid::Uuid::parse_str(s) {
532                    return BindValue::Uuid(u);
533                }
534            }
535            BindValue::Text(value.to_string().trim_matches('"').to_string())
536        }
537        FieldType::String | FieldType::Enum | FieldType::File => {
538            BindValue::Text(value.as_str().unwrap_or(&value.to_string()).to_string())
539        }
540        FieldType::Integer => BindValue::Int(value.as_i64().unwrap_or(0) as i32),
541        FieldType::Bigint => BindValue::Bigint(value.as_i64().unwrap_or(0)),
542        FieldType::Number => BindValue::Float(value.as_f64().unwrap_or(0.0)),
543        FieldType::Boolean => BindValue::Bool(value.as_bool().unwrap_or(false)),
544        FieldType::Timestamp => {
545            if let Some(s) = value.as_str() {
546                if let Ok(dt) = s.parse::<chrono::DateTime<chrono::Utc>>() {
547                    return BindValue::Timestamp(dt);
548                }
549            }
550            BindValue::Timestamp(chrono::Utc::now())
551        }
552        FieldType::Date => {
553            if let Some(s) = value.as_str() {
554                if let Ok(d) = s.parse::<chrono::NaiveDate>() {
555                    return BindValue::Date(d);
556                }
557            }
558            BindValue::Date(chrono::Utc::now().date_naive())
559        }
560        FieldType::Json | FieldType::Array => BindValue::Json(value.clone()),
561    }
562}
563
564/// Extracts a column value from a `PgRow` as a `serde_json::Value`.
565fn extract_column_value(
566    row: &PgRow,
567    name: &str,
568    field: &FieldSchema,
569) -> Result<serde_json::Value, ShaperailError> {
570    // Try to get the column; if it doesn't exist, return null
571    let map_err = |e: sqlx::Error| ShaperailError::Internal(format!("Column '{name}' error: {e}"));
572
573    match field.field_type {
574        FieldType::Uuid => {
575            let v: Option<uuid::Uuid> = row.try_get(name).map_err(map_err)?;
576            Ok(v.map(|u| serde_json::Value::String(u.to_string()))
577                .unwrap_or(serde_json::Value::Null))
578        }
579        FieldType::String | FieldType::Enum | FieldType::File => {
580            let v: Option<String> = row.try_get(name).map_err(map_err)?;
581            Ok(v.map(serde_json::Value::String)
582                .unwrap_or(serde_json::Value::Null))
583        }
584        FieldType::Integer => {
585            let v: Option<i32> = row.try_get(name).map_err(map_err)?;
586            Ok(v.map(|n| serde_json::Value::Number(n.into()))
587                .unwrap_or(serde_json::Value::Null))
588        }
589        FieldType::Bigint => {
590            let v: Option<i64> = row.try_get(name).map_err(map_err)?;
591            Ok(v.map(|n| serde_json::Value::Number(n.into()))
592                .unwrap_or(serde_json::Value::Null))
593        }
594        FieldType::Number => {
595            let v: Option<f64> = row.try_get(name).map_err(map_err)?;
596            Ok(
597                v.and_then(|n| serde_json::Number::from_f64(n).map(serde_json::Value::Number))
598                    .unwrap_or(serde_json::Value::Null),
599            )
600        }
601        FieldType::Boolean => {
602            let v: Option<bool> = row.try_get(name).map_err(map_err)?;
603            Ok(v.map(serde_json::Value::Bool)
604                .unwrap_or(serde_json::Value::Null))
605        }
606        FieldType::Timestamp => {
607            let v: Option<chrono::DateTime<chrono::Utc>> = row.try_get(name).map_err(map_err)?;
608            Ok(v.map(|dt| serde_json::Value::String(dt.to_rfc3339()))
609                .unwrap_or(serde_json::Value::Null))
610        }
611        FieldType::Date => {
612            let v: Option<chrono::NaiveDate> = row.try_get(name).map_err(map_err)?;
613            Ok(v.map(|d| serde_json::Value::String(d.to_string()))
614                .unwrap_or(serde_json::Value::Null))
615        }
616        FieldType::Json | FieldType::Array => {
617            let v: Option<serde_json::Value> = row.try_get(name).map_err(map_err)?;
618            Ok(v.unwrap_or(serde_json::Value::Null))
619        }
620    }
621}
622
623/// Logs a warning if a query exceeds the slow query threshold.
624///
625/// The threshold is read from `SHAPERAIL_SLOW_QUERY_MS` env var (default: 100ms).
626fn log_slow_query(sql: &str, duration_ms: u64) {
627    let threshold: u64 = std::env::var("SHAPERAIL_SLOW_QUERY_MS")
628        .ok()
629        .and_then(|v| v.parse().ok())
630        .unwrap_or(100);
631
632    if duration_ms >= threshold {
633        tracing::warn!(
634            duration_ms = duration_ms,
635            sql = %sql,
636            threshold_ms = threshold,
637            "Slow query detected"
638        );
639    }
640}
641
642/// Builds a SQL `CREATE TABLE` statement from a `ResourceDefinition`.
643///
644/// Used by the migration generator to produce initial table creation SQL.
645pub fn build_create_table_sql(resource: &ResourceDefinition) -> String {
646    let mut columns = Vec::new();
647    let mut constraints = Vec::new();
648    let has_soft_delete = resource
649        .endpoints
650        .as_ref()
651        .map(|eps| eps.values().any(|ep| ep.soft_delete))
652        .unwrap_or(false);
653
654    for (name, field) in &resource.schema {
655        let mut col = format!(
656            "\"{}\" {}",
657            name,
658            field_type_to_sql(&field.field_type, field)
659        );
660
661        if field.primary {
662            col.push_str(" PRIMARY KEY");
663        }
664        if field.required && !field.primary && !field.nullable {
665            col.push_str(" NOT NULL");
666        }
667        if field.unique && !field.primary {
668            col.push_str(" UNIQUE");
669        }
670        if let Some(default) = &field.default {
671            col.push_str(&format!(" DEFAULT {}", sql_default_value(default, field)));
672        }
673        if field.field_type == FieldType::Uuid && field.generated {
674            col.push_str(" DEFAULT gen_random_uuid()");
675        }
676        if field.field_type == FieldType::Timestamp && field.generated {
677            col.push_str(" DEFAULT NOW()");
678        }
679        if field.field_type == FieldType::Date && field.generated {
680            col.push_str(" DEFAULT CURRENT_DATE");
681        }
682
683        // Enum CHECK constraint
684        if field.field_type == FieldType::Enum {
685            if let Some(values) = &field.values {
686                let vals = values
687                    .iter()
688                    .map(|v| format!("'{v}'"))
689                    .collect::<Vec<_>>()
690                    .join(", ");
691                constraints.push(format!(
692                    "CONSTRAINT \"chk_{table}_{name}\" CHECK (\"{name}\" IN ({vals}))",
693                    table = resource.resource,
694                ));
695            }
696        }
697
698        // Foreign key constraint
699        if let Some(reference) = &field.reference {
700            if let Some((ref_table, ref_col)) = reference.split_once('.') {
701                constraints.push(format!(
702                    "CONSTRAINT \"fk_{table}_{name}\" FOREIGN KEY (\"{name}\") REFERENCES \"{ref_table}\"(\"{ref_col}\")",
703                    table = resource.resource,
704                ));
705            }
706        }
707
708        columns.push(col);
709    }
710
711    if has_soft_delete && !resource.schema.contains_key("deleted_at") {
712        columns.push("\"deleted_at\" TIMESTAMPTZ".to_string());
713    }
714
715    let mut sql = format!(
716        "CREATE TABLE IF NOT EXISTS \"{}\" (\n  {}",
717        resource.resource,
718        columns.join(",\n  ")
719    );
720
721    if !constraints.is_empty() {
722        sql.push_str(",\n  ");
723        sql.push_str(&constraints.join(",\n  "));
724    }
725    sql.push_str("\n)");
726
727    // Add indexes
728    if let Some(indexes) = &resource.indexes {
729        for (i, idx) in indexes.iter().enumerate() {
730            let idx_cols = idx
731                .fields
732                .iter()
733                .map(|f| format!("\"{f}\""))
734                .collect::<Vec<_>>()
735                .join(", ");
736            let unique = if idx.unique { "UNIQUE " } else { "" };
737            let order = idx
738                .order
739                .as_deref()
740                .map(|o| format!(" {}", o.to_uppercase()))
741                .unwrap_or_default();
742            sql.push_str(&format!(
743                ";\nCREATE {unique}INDEX IF NOT EXISTS \"idx_{}_{i}\" ON \"{}\" ({idx_cols}{order})",
744                resource.resource, resource.resource,
745            ));
746        }
747    }
748
749    sql
750}
751
752/// Maps a `FieldType` to its PostgreSQL SQL type string.
753fn field_type_to_sql(ft: &FieldType, field: &FieldSchema) -> String {
754    match ft {
755        FieldType::Uuid => "UUID".to_string(),
756        FieldType::String => {
757            if let Some(max) = &field.max {
758                if let Some(n) = max.as_u64() {
759                    return format!("VARCHAR({n})");
760                }
761            }
762            "TEXT".to_string()
763        }
764        FieldType::Integer => "INTEGER".to_string(),
765        FieldType::Bigint => "BIGINT".to_string(),
766        FieldType::Number => "NUMERIC".to_string(),
767        FieldType::Boolean => "BOOLEAN".to_string(),
768        FieldType::Timestamp => "TIMESTAMPTZ".to_string(),
769        FieldType::Date => "DATE".to_string(),
770        FieldType::Enum => "TEXT".to_string(),
771        FieldType::Json => "JSONB".to_string(),
772        FieldType::Array => {
773            if let Some(items) = &field.items {
774                let item_sql = match items.as_str() {
775                    "string" => "TEXT",
776                    "integer" => "INTEGER",
777                    "uuid" => "UUID",
778                    _ => "TEXT",
779                };
780                format!("{item_sql}[]")
781            } else {
782                "TEXT[]".to_string()
783            }
784        }
785        FieldType::File => "TEXT".to_string(),
786    }
787}
788
789/// Converts a JSON default value to a SQL literal.
790fn sql_default_value(value: &serde_json::Value, _field: &FieldSchema) -> String {
791    match value {
792        serde_json::Value::String(s) => format!("'{s}'"),
793        serde_json::Value::Number(n) => n.to_string(),
794        serde_json::Value::Bool(b) => b.to_string().to_uppercase(),
795        serde_json::Value::Null => "NULL".to_string(),
796        other => format!("'{}'", other),
797    }
798}
799
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use shaperail_core::{EndpointSpec, HttpMethod, IndexSpec};

    /// Returns a field of `field_type` with every flag off and every optional
    /// attribute unset; fixtures override only the attributes they exercise.
    fn plain_field(field_type: FieldType) -> FieldSchema {
        FieldSchema {
            field_type,
            primary: false,
            generated: false,
            required: false,
            unique: false,
            nullable: false,
            reference: None,
            min: None,
            max: None,
            format: None,
            values: None,
            default: None,
            sensitive: false,
            search: false,
            items: None,
        }
    }

    /// Asserts that each fragment occurs in `sql`, printing the full statement
    /// when one is missing.
    fn assert_sql_contains(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(*fragment), "expected `{fragment}` in:\n{sql}");
        }
    }

    /// `users` resource fixture. It has:
    /// - a generated UUID key;
    /// - two searchable, length-limited strings;
    /// - an enum with a default;
    /// - a reference to `organizations.id`;
    /// - two generated timestamps;
    /// - two indexes.
    fn test_resource() -> ResourceDefinition {
        // Entry order is significant: it is the column order of the schema.
        let schema = IndexMap::from([
            (
                "id".to_string(),
                FieldSchema {
                    primary: true,
                    generated: true,
                    ..plain_field(FieldType::Uuid)
                },
            ),
            (
                "email".to_string(),
                FieldSchema {
                    required: true,
                    unique: true,
                    max: Some(serde_json::json!(255)),
                    format: Some("email".to_string()),
                    search: true,
                    ..plain_field(FieldType::String)
                },
            ),
            (
                "name".to_string(),
                FieldSchema {
                    required: true,
                    min: Some(serde_json::json!(1)),
                    max: Some(serde_json::json!(200)),
                    search: true,
                    ..plain_field(FieldType::String)
                },
            ),
            (
                "role".to_string(),
                FieldSchema {
                    required: true,
                    values: Some(
                        ["admin", "member", "viewer"]
                            .iter()
                            .map(|v| v.to_string())
                            .collect(),
                    ),
                    default: Some(serde_json::json!("member")),
                    ..plain_field(FieldType::Enum)
                },
            ),
            (
                "org_id".to_string(),
                FieldSchema {
                    required: true,
                    reference: Some("organizations.id".to_string()),
                    ..plain_field(FieldType::Uuid)
                },
            ),
            (
                "created_at".to_string(),
                FieldSchema {
                    generated: true,
                    ..plain_field(FieldType::Timestamp)
                },
            ),
            (
                "updated_at".to_string(),
                FieldSchema {
                    generated: true,
                    ..plain_field(FieldType::Timestamp)
                },
            ),
        ]);

        let composite_index = IndexSpec {
            fields: vec!["org_id".to_string(), "role".to_string()],
            unique: false,
            order: None,
        };
        let recency_index = IndexSpec {
            fields: vec!["created_at".to_string()],
            unique: false,
            order: Some("desc".to_string()),
        };

        ResourceDefinition {
            resource: "users".to_string(),
            version: 1,
            schema,
            endpoints: None,
            relations: None,
            indexes: Some(vec![composite_index, recency_index]),
        }
    }

    #[test]
    fn create_table_sql_basic() {
        let resource = test_resource();
        let sql = build_create_table_sql(&resource);

        assert_sql_contains(
            &sql,
            &[
                "CREATE TABLE IF NOT EXISTS \"users\"",
                "\"id\" UUID PRIMARY KEY DEFAULT gen_random_uuid()",
                "\"email\" VARCHAR(255) NOT NULL UNIQUE",
                "\"name\" VARCHAR(200) NOT NULL",
                "\"role\" TEXT NOT NULL DEFAULT 'member'",
                "\"org_id\" UUID NOT NULL",
                "\"created_at\" TIMESTAMPTZ DEFAULT NOW()",
                "\"updated_at\" TIMESTAMPTZ DEFAULT NOW()",
            ],
        );
    }

    #[test]
    fn create_table_sql_constraints() {
        let resource = test_resource();
        let sql = build_create_table_sql(&resource);

        assert_sql_contains(
            &sql,
            &[
                "CONSTRAINT \"chk_users_role\" CHECK",
                "'admin', 'member', 'viewer'",
                "CONSTRAINT \"fk_users_org_id\" FOREIGN KEY",
                "REFERENCES \"organizations\"(\"id\")",
            ],
        );
    }

    #[test]
    fn create_table_sql_indexes() {
        let resource = test_resource();
        let sql = build_create_table_sql(&resource);

        assert_sql_contains(
            &sql,
            &[
                "CREATE INDEX IF NOT EXISTS \"idx_users_0\" ON \"users\" (\"org_id\", \"role\")",
                "CREATE INDEX IF NOT EXISTS \"idx_users_1\" ON \"users\" (\"created_at\" DESC)",
            ],
        );
    }

    #[test]
    fn create_table_sql_adds_deleted_at_for_soft_delete() {
        let soft_delete_endpoint = EndpointSpec {
            method: HttpMethod::Delete,
            path: "/users/:id".to_string(),
            auth: None,
            input: None,
            filters: None,
            search: None,
            pagination: None,
            sort: None,
            cache: None,
            controller: None,
            events: None,
            jobs: None,
            upload: None,
            soft_delete: true,
        };

        let mut resource = test_resource();
        resource.endpoints = Some(IndexMap::from([(
            "delete".to_string(),
            soft_delete_endpoint,
        )]));

        let sql = build_create_table_sql(&resource);
        assert_sql_contains(&sql, &["\"deleted_at\" TIMESTAMPTZ"]);
    }

    #[test]
    fn field_type_to_sql_mapping() {
        // Bare String field: no `max` (so no VARCHAR) and no `items`.
        let field = plain_field(FieldType::String);

        let expectations = [
            (FieldType::Uuid, "UUID"),
            (FieldType::String, "TEXT"),
            (FieldType::Integer, "INTEGER"),
            (FieldType::Bigint, "BIGINT"),
            (FieldType::Number, "NUMERIC"),
            (FieldType::Boolean, "BOOLEAN"),
            (FieldType::Timestamp, "TIMESTAMPTZ"),
            (FieldType::Date, "DATE"),
            (FieldType::Enum, "TEXT"),
            (FieldType::Json, "JSONB"),
            (FieldType::File, "TEXT"),
        ];

        for (field_type, expected) in &expectations {
            assert_eq!(field_type_to_sql(field_type, &field), *expected);
        }
    }

    #[test]
    fn field_type_to_sql_varchar() {
        let field = FieldSchema {
            max: Some(serde_json::json!(100)),
            ..plain_field(FieldType::String)
        };
        assert_eq!(
            field_type_to_sql(&FieldType::String, &field),
            "VARCHAR(100)"
        );
    }

    #[test]
    fn field_type_to_sql_array() {
        let field = FieldSchema {
            items: Some("string".to_string()),
            ..plain_field(FieldType::Array)
        };
        assert_eq!(field_type_to_sql(&FieldType::Array, &field), "TEXT[]");
    }

    #[test]
    fn json_to_bind_types() {
        let field = plain_field(FieldType::String);

        let text = json_to_bind(&serde_json::json!("hello"), &field);
        assert!(matches!(text, BindValue::Text(s) if s == "hello"));

        let null = json_to_bind(&serde_json::Value::Null, &field);
        assert!(matches!(null, BindValue::Null));
    }
}