Skip to main content

systemprompt_database/services/postgres/
conversion.rs

1//! Conversions between `SQLx` `PostgreSQL` rows, [`QueryResult`] /
2//! [`crate::models::JsonRow`] values, and parameter binders.
3//!
4//! Part of the documented sqlx allowlist: the binder operates on the dynamic
5//! `sqlx::query::Query` value passed in by the trait implementation.
6
7use sqlx::{Column, Row};
8use std::collections::HashMap;
9
10use crate::models::{DbValue, QueryResult, ToDbValue};
11
12pub fn rows_to_result(rows: Vec<sqlx::postgres::PgRow>, start: std::time::Instant) -> QueryResult {
13    let mut columns = Vec::new();
14    let mut result_rows = Vec::new();
15
16    if let Some(first_row) = rows.first() {
17        columns = first_row
18            .columns()
19            .iter()
20            .map(|c| c.name().to_string())
21            .collect();
22    }
23
24    for row in rows {
25        result_rows.push(row_to_json(&row));
26    }
27
28    let row_count = result_rows.len();
29    let execution_time_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
30
31    QueryResult {
32        columns,
33        rows: result_rows,
34        row_count,
35        execution_time_ms,
36    }
37}
38
39pub fn row_to_json(row: &sqlx::postgres::PgRow) -> HashMap<String, serde_json::Value> {
40    row.columns()
41        .iter()
42        .map(|col| (col.name().to_string(), column_to_json(row, col.ordinal())))
43        .collect()
44}
45
46fn column_to_json(row: &sqlx::postgres::PgRow, ordinal: usize) -> serde_json::Value {
47    if let Ok(val) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(ordinal) {
48        return val.map_or(serde_json::Value::Null, |v| {
49            serde_json::Value::String(v.to_rfc3339())
50        });
51    }
52    if let Ok(val) = row.try_get::<Option<uuid::Uuid>, _>(ordinal) {
53        return val.map_or(serde_json::Value::Null, |v| {
54            serde_json::Value::String(v.to_string())
55        });
56    }
57    if let Ok(val) = row.try_get::<Option<String>, _>(ordinal) {
58        return val.map_or(serde_json::Value::Null, serde_json::Value::String);
59    }
60    if let Ok(val) = row.try_get::<Option<i64>, _>(ordinal) {
61        return val.map_or(serde_json::Value::Null, |v| {
62            serde_json::Value::Number(v.into())
63        });
64    }
65    if let Ok(val) = row.try_get::<Option<i32>, _>(ordinal) {
66        return val.map_or(serde_json::Value::Null, |v| {
67            serde_json::Value::Number(i64::from(v).into())
68        });
69    }
70    if let Ok(val) = row.try_get::<Option<f64>, _>(ordinal) {
71        return val.map_or(serde_json::Value::Null, |v| serde_json::json!(v));
72    }
73    if let Ok(val) = row.try_get::<Option<rust_decimal::Decimal>, _>(ordinal) {
74        return val.map_or(serde_json::Value::Null, |v| {
75            v.to_string().parse::<f64>().map_or_else(
76                |_| serde_json::Value::String(v.to_string()),
77                |f| serde_json::json!(f),
78            )
79        });
80    }
81    if let Ok(val) = row.try_get::<Option<bool>, _>(ordinal) {
82        return val.map_or(serde_json::Value::Null, serde_json::Value::Bool);
83    }
84    if let Ok(val) = row.try_get::<Option<Vec<String>>, _>(ordinal) {
85        return val.map_or(serde_json::Value::Null, |v| {
86            serde_json::Value::Array(v.into_iter().map(serde_json::Value::String).collect())
87        });
88    }
89    if let Ok(val) = row.try_get::<Option<serde_json::Value>, _>(ordinal) {
90        return val.unwrap_or(serde_json::Value::Null);
91    }
92    if let Ok(val) = row.try_get::<Option<Vec<u8>>, _>(ordinal) {
93        return val.map_or(serde_json::Value::Null, |bytes| {
94            use base64::Engine;
95            use base64::engine::general_purpose::STANDARD;
96            serde_json::Value::String(STANDARD.encode(&bytes))
97        });
98    }
99    serde_json::Value::Null
100}
101
102pub fn bind_params<'q>(
103    mut query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
104    params: &[&dyn ToDbValue],
105) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
106    for param in params {
107        let value = param.to_db_value();
108        query = match value {
109            DbValue::String(s) => query.bind(s),
110            DbValue::Int(i) => query.bind(i),
111            DbValue::Float(f) => query.bind(f),
112            DbValue::Bool(b) => query.bind(b),
113            DbValue::Bytes(b) => query.bind(b),
114            DbValue::Timestamp(dt) => query.bind(dt),
115            DbValue::StringArray(arr) => query.bind(arr),
116            DbValue::NullString => query.bind(None::<String>),
117            DbValue::NullInt => query.bind(None::<i64>),
118            DbValue::NullFloat => query.bind(None::<f64>),
119            DbValue::NullBool => query.bind(None::<bool>),
120            DbValue::NullBytes => query.bind(None::<Vec<u8>>),
121            DbValue::NullTimestamp => query.bind(None::<chrono::DateTime<chrono::Utc>>),
122            DbValue::NullStringArray => query.bind(None::<Vec<String>>),
123        };
124    }
125    query
126}