Skip to main content

systemprompt_database/admin/
query_executor.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use sqlx::postgres::PgPool;
5use sqlx::{Column, Row};
6use thiserror::Error;
7
8use crate::models::QueryResult;
9
10#[derive(Error, Debug)]
11pub enum QueryExecutorError {
12    #[error(
13        "Write query not allowed in read-only mode: only SELECT, WITH, EXPLAIN, and PRAGMA \
14         queries are permitted"
15    )]
16    WriteQueryNotAllowed,
17
18    #[error("Query execution failed: {0}")]
19    ExecutionFailed(#[from] sqlx::Error),
20}
21
22#[derive(Debug)]
23pub struct QueryExecutor {
24    pool: Arc<PgPool>,
25}
26
27impl QueryExecutor {
28    pub const fn new(pool: Arc<PgPool>) -> Self {
29        Self { pool }
30    }
31
32    pub async fn execute_query(
33        &self,
34        query: &str,
35        read_only: bool,
36    ) -> Result<QueryResult, QueryExecutorError> {
37        let start = std::time::Instant::now();
38
39        if read_only && !Self::is_safe_query(query) {
40            return Err(QueryExecutorError::WriteQueryNotAllowed);
41        }
42
43        let rows = sqlx::query(query).fetch_all(&*self.pool).await?;
44        let execution_time = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
45
46        let mut columns = Vec::new();
47        let mut result_rows = Vec::new();
48
49        if let Some(first_row) = rows.first() {
50            columns = first_row
51                .columns()
52                .iter()
53                .map(|c| c.name().to_string())
54                .collect();
55        }
56
57        for row in &rows {
58            let mut row_map = HashMap::new();
59            for (i, column) in row.columns().iter().enumerate() {
60                row_map.insert(column.name().to_string(), Self::extract_value(row, i));
61            }
62            result_rows.push(row_map);
63        }
64
65        Ok(QueryResult {
66            columns,
67            rows: result_rows,
68            row_count: rows.len(),
69            execution_time_ms: execution_time,
70        })
71    }
72
73    fn is_safe_query(query: &str) -> bool {
74        let trimmed = query.trim().to_lowercase();
75        let safe_starts = ["select", "with", "explain", "pragma"];
76        let unsafe_ops = [
77            " drop ", " delete ", " insert ", " update ", " alter ", " create ",
78        ];
79
80        safe_starts.iter().any(|s| trimmed.starts_with(s))
81            && !unsafe_ops.iter().any(|op| trimmed.contains(op))
82    }
83
84    fn extract_value(row: &sqlx::postgres::PgRow, column_index: usize) -> serde_json::Value {
85        if let Ok(val) = row.try_get::<Option<chrono::NaiveDateTime>, _>(column_index) {
86            return val.map_or(serde_json::Value::Null, |dt| {
87                serde_json::Value::String(dt.and_utc().to_rfc3339())
88            });
89        }
90        if let Ok(val) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(column_index) {
91            return val.map_or(serde_json::Value::Null, |dt| {
92                serde_json::Value::String(dt.to_rfc3339())
93            });
94        }
95        if let Ok(val) = row.try_get::<Option<String>, _>(column_index) {
96            return val.map_or(serde_json::Value::Null, serde_json::Value::String);
97        }
98        if let Ok(val) = row.try_get::<Option<i64>, _>(column_index) {
99            return val.map_or(serde_json::Value::Null, |i| {
100                serde_json::Value::Number(i.into())
101            });
102        }
103        if let Ok(val) = row.try_get::<Option<i32>, _>(column_index) {
104            return val.map_or(serde_json::Value::Null, |i| {
105                serde_json::Value::Number(i.into())
106            });
107        }
108        if let Ok(val) = row.try_get::<Option<f64>, _>(column_index) {
109            return val.map_or(serde_json::Value::Null, |f| {
110                serde_json::Number::from_f64(f)
111                    .map_or(serde_json::Value::Null, serde_json::Value::Number)
112            });
113        }
114        if let Ok(val) = row.try_get::<Option<bool>, _>(column_index) {
115            return val.map_or(serde_json::Value::Null, serde_json::Value::Bool);
116        }
117        if let Ok(val) = row.try_get::<Option<Vec<String>>, _>(column_index) {
118            return val.map_or(serde_json::Value::Null, |arr| {
119                serde_json::Value::Array(arr.into_iter().map(serde_json::Value::String).collect())
120            });
121        }
122        if let Ok(val) = row.try_get::<Option<serde_json::Value>, _>(column_index) {
123            return val.unwrap_or(serde_json::Value::Null);
124        }
125        serde_json::Value::Null
126    }
127}