// systemprompt_database/services/postgres/conversion.rs

1use sqlx::{Column, Row};
2use std::collections::HashMap;
3
4use crate::models::{DbValue, QueryResult, ToDbValue};
5
6pub fn rows_to_result(rows: Vec<sqlx::postgres::PgRow>, start: std::time::Instant) -> QueryResult {
7    let mut columns = Vec::new();
8    let mut result_rows = Vec::new();
9
10    if let Some(first_row) = rows.first() {
11        columns = first_row
12            .columns()
13            .iter()
14            .map(|c| c.name().to_string())
15            .collect();
16    }
17
18    for row in rows {
19        result_rows.push(row_to_json(&row));
20    }
21
22    let row_count = result_rows.len();
23    let execution_time_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
24
25    QueryResult {
26        columns,
27        rows: result_rows,
28        row_count,
29        execution_time_ms,
30    }
31}
32
33pub fn row_to_json(row: &sqlx::postgres::PgRow) -> HashMap<String, serde_json::Value> {
34    row.columns()
35        .iter()
36        .map(|col| (col.name().to_string(), column_to_json(row, col.ordinal())))
37        .collect()
38}
39
40fn column_to_json(row: &sqlx::postgres::PgRow, ordinal: usize) -> serde_json::Value {
41    if let Ok(val) = row.try_get::<Option<chrono::NaiveDateTime>, _>(ordinal) {
42        return val.map_or(serde_json::Value::Null, |v| {
43            serde_json::Value::String(v.and_utc().to_rfc3339())
44        });
45    }
46    if let Ok(val) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(ordinal) {
47        return val.map_or(serde_json::Value::Null, |v| {
48            serde_json::Value::String(v.to_rfc3339())
49        });
50    }
51    if let Ok(val) = row.try_get::<Option<uuid::Uuid>, _>(ordinal) {
52        return val.map_or(serde_json::Value::Null, |v| {
53            serde_json::Value::String(v.to_string())
54        });
55    }
56    if let Ok(val) = row.try_get::<Option<String>, _>(ordinal) {
57        return val.map_or(serde_json::Value::Null, serde_json::Value::String);
58    }
59    if let Ok(val) = row.try_get::<Option<i64>, _>(ordinal) {
60        return val.map_or(serde_json::Value::Null, |v| {
61            serde_json::Value::Number(v.into())
62        });
63    }
64    if let Ok(val) = row.try_get::<Option<i32>, _>(ordinal) {
65        return val.map_or(serde_json::Value::Null, |v| {
66            serde_json::Value::Number(i64::from(v).into())
67        });
68    }
69    if let Ok(val) = row.try_get::<Option<f64>, _>(ordinal) {
70        return val.map_or(serde_json::Value::Null, |v| serde_json::json!(v));
71    }
72    if let Ok(val) = row.try_get::<Option<rust_decimal::Decimal>, _>(ordinal) {
73        return val.map_or(serde_json::Value::Null, |v| {
74            v.to_string().parse::<f64>().map_or_else(
75                |_| serde_json::Value::String(v.to_string()),
76                |f| serde_json::json!(f),
77            )
78        });
79    }
80    if let Ok(val) = row.try_get::<Option<bool>, _>(ordinal) {
81        return val.map_or(serde_json::Value::Null, serde_json::Value::Bool);
82    }
83    if let Ok(val) = row.try_get::<Option<Vec<String>>, _>(ordinal) {
84        return val.map_or(serde_json::Value::Null, |v| {
85            serde_json::Value::Array(v.into_iter().map(serde_json::Value::String).collect())
86        });
87    }
88    if let Ok(val) = row.try_get::<Option<serde_json::Value>, _>(ordinal) {
89        return val.unwrap_or(serde_json::Value::Null);
90    }
91    if let Ok(val) = row.try_get::<Option<Vec<u8>>, _>(ordinal) {
92        return val.map_or(serde_json::Value::Null, |bytes| {
93            use base64::engine::general_purpose::STANDARD;
94            use base64::Engine;
95            serde_json::Value::String(STANDARD.encode(&bytes))
96        });
97    }
98    serde_json::Value::Null
99}
100
101pub fn bind_params<'q>(
102    mut query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
103    params: &[&dyn ToDbValue],
104) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
105    for param in params {
106        let value = param.to_db_value();
107        query = match value {
108            DbValue::String(s) => query.bind(s),
109            DbValue::Int(i) => query.bind(i),
110            DbValue::Float(f) => query.bind(f),
111            DbValue::Bool(b) => query.bind(b),
112            DbValue::Bytes(b) => query.bind(b),
113            DbValue::Timestamp(dt) => query.bind(dt),
114            DbValue::StringArray(arr) => query.bind(arr),
115            DbValue::NullString => query.bind(None::<String>),
116            DbValue::NullInt => query.bind(None::<i64>),
117            DbValue::NullFloat => query.bind(None::<f64>),
118            DbValue::NullBool => query.bind(None::<bool>),
119            DbValue::NullBytes => query.bind(None::<Vec<u8>>),
120            DbValue::NullTimestamp => query.bind(None::<chrono::DateTime<chrono::Utc>>),
121            DbValue::NullStringArray => query.bind(None::<Vec<String>>),
122        };
123    }
124    query
125}