Skip to main content

systemprompt_database/admin/
query_executor.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use sqlx::postgres::PgPool;
5use sqlx::{Column, Row};
6use thiserror::Error;
7
8use crate::admin::admin_sql::{AdminSql, AdminSqlError, DEFAULT_READONLY_ROW_LIMIT};
9use crate::models::QueryResult;
10
11#[derive(Error, Debug)]
12pub enum QueryExecutorError {
13    #[error(
14        "Write query not allowed in read-only mode: only SELECT, WITH, EXPLAIN, SHOW, TABLE, and \
15         VALUES queries are permitted"
16    )]
17    WriteQueryNotAllowed,
18
19    #[error("Invalid admin SQL: {0}")]
20    InvalidSql(#[from] AdminSqlError),
21
22    #[error("Query execution failed: {0}")]
23    ExecutionFailed(#[from] sqlx::Error),
24}
25
/// Executes admin SQL on a PostgreSQL pool and converts the returned rows into a
/// [`QueryResult`].
#[derive(Debug)]
pub struct QueryExecutor {
    /// Shared connection pool that every query is run against.
    pool: Arc<PgPool>,
}
30
31impl QueryExecutor {
32    pub const fn new(pool: Arc<PgPool>) -> Self {
33        Self { pool }
34    }
35
36    pub async fn execute_readonly(
37        &self,
38        raw_sql: &str,
39        row_limit: Option<usize>,
40    ) -> Result<QueryResult, QueryExecutorError> {
41        let sql = AdminSql::parse_readonly(raw_sql)?;
42        self.execute(sql, row_limit.unwrap_or(DEFAULT_READONLY_ROW_LIMIT))
43            .await
44    }
45
46    pub async fn execute_write(&self, raw_sql: &str) -> Result<QueryResult, QueryExecutorError> {
47        let sql = AdminSql::parse_unrestricted(raw_sql)?;
48        self.execute(sql, usize::MAX).await
49    }
50
51    async fn execute(
52        &self,
53        sql: AdminSql,
54        row_limit: usize,
55    ) -> Result<QueryResult, QueryExecutorError> {
56        let start = std::time::Instant::now();
57
58        let rows = sqlx::query(sql.as_str()).fetch_all(&*self.pool).await?;
59        let execution_time = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
60
61        let columns = rows.first().map_or_else(Vec::new, |first_row| {
62            first_row
63                .columns()
64                .iter()
65                .map(|c| c.name().to_string())
66                .collect()
67        });
68
69        let total_rows = rows.len();
70        let capped_rows = rows.iter().take(row_limit);
71        let mut result_rows = Vec::with_capacity(total_rows.min(row_limit));
72
73        for row in capped_rows {
74            let mut row_map = HashMap::new();
75            for (i, column) in row.columns().iter().enumerate() {
76                row_map.insert(column.name().to_string(), extract_value(row, i));
77            }
78            result_rows.push(row_map);
79        }
80
81        Ok(QueryResult {
82            columns,
83            rows: result_rows,
84            row_count: total_rows,
85            execution_time_ms: execution_time,
86        })
87    }
88}
89
90fn extract_value(row: &sqlx::postgres::PgRow, column_index: usize) -> serde_json::Value {
91    if let Ok(val) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(column_index) {
92        return val.map_or(serde_json::Value::Null, |dt| {
93            serde_json::Value::String(dt.to_rfc3339())
94        });
95    }
96    if let Ok(val) = row.try_get::<Option<String>, _>(column_index) {
97        return val.map_or(serde_json::Value::Null, serde_json::Value::String);
98    }
99    if let Ok(val) = row.try_get::<Option<i64>, _>(column_index) {
100        return val.map_or(serde_json::Value::Null, |i| {
101            serde_json::Value::Number(i.into())
102        });
103    }
104    if let Ok(val) = row.try_get::<Option<i32>, _>(column_index) {
105        return val.map_or(serde_json::Value::Null, |i| {
106            serde_json::Value::Number(i.into())
107        });
108    }
109    if let Ok(val) = row.try_get::<Option<f64>, _>(column_index) {
110        return val.map_or(serde_json::Value::Null, |f| {
111            serde_json::Number::from_f64(f)
112                .map_or(serde_json::Value::Null, serde_json::Value::Number)
113        });
114    }
115    if let Ok(val) = row.try_get::<Option<bool>, _>(column_index) {
116        return val.map_or(serde_json::Value::Null, serde_json::Value::Bool);
117    }
118    if let Ok(val) = row.try_get::<Option<Vec<String>>, _>(column_index) {
119        return val.map_or(serde_json::Value::Null, |arr| {
120            serde_json::Value::Array(arr.into_iter().map(serde_json::Value::String).collect())
121        });
122    }
123    if let Ok(val) = row.try_get::<Option<serde_json::Value>, _>(column_index) {
124        return val.unwrap_or(serde_json::Value::Null);
125    }
126    serde_json::Value::Null
127}