Skip to main content

systemprompt_database/admin/
query_executor.rs

//! Query executor used by the CLI's `infra db query` and `db exec` commands.
//!
//! Part of the documented sqlx allowlist: SQL is supplied dynamically by
//! the operator and validated through [`AdminSql`].
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use sqlx::postgres::PgPool;
10use sqlx::{Column, Row};
11use thiserror::Error;
12
13use crate::admin::admin_sql::{AdminSql, AdminSqlError, DEFAULT_READONLY_ROW_LIMIT};
14use crate::models::QueryResult;
15
16#[derive(Error, Debug)]
17pub enum QueryExecutorError {
18    #[error(
19        "Write query not allowed in read-only mode: only SELECT, WITH, EXPLAIN, SHOW, TABLE, and \
20         VALUES queries are permitted"
21    )]
22    WriteQueryNotAllowed,
23
24    #[error("Invalid admin SQL: {0}")]
25    InvalidSql(#[from] AdminSqlError),
26
27    #[error("Query execution failed: {0}")]
28    ExecutionFailed(#[from] sqlx::Error),
29}
30
31#[derive(Debug)]
32pub struct QueryExecutor {
33    pool: Arc<PgPool>,
34}
35
36impl QueryExecutor {
37    pub const fn new(pool: Arc<PgPool>) -> Self {
38        Self { pool }
39    }
40
41    pub async fn execute_readonly(
42        &self,
43        raw_sql: &str,
44        row_limit: Option<usize>,
45    ) -> Result<QueryResult, QueryExecutorError> {
46        let sql = AdminSql::parse_readonly(raw_sql)?;
47        self.execute(sql, row_limit.unwrap_or(DEFAULT_READONLY_ROW_LIMIT))
48            .await
49    }
50
51    pub async fn execute_write(&self, raw_sql: &str) -> Result<QueryResult, QueryExecutorError> {
52        let sql = AdminSql::parse_unrestricted(raw_sql)?;
53        self.execute(sql, usize::MAX).await
54    }
55
56    async fn execute(
57        &self,
58        sql: AdminSql,
59        row_limit: usize,
60    ) -> Result<QueryResult, QueryExecutorError> {
61        let start = std::time::Instant::now();
62
63        let rows = sqlx::query(sql.as_str()).fetch_all(&*self.pool).await?;
64        let execution_time = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
65
66        let columns = rows.first().map_or_else(Vec::new, |first_row| {
67            first_row
68                .columns()
69                .iter()
70                .map(|c| c.name().to_string())
71                .collect()
72        });
73
74        let total_rows = rows.len();
75        let capped_rows = rows.iter().take(row_limit);
76        let mut result_rows = Vec::with_capacity(total_rows.min(row_limit));
77
78        for row in capped_rows {
79            let mut row_map = HashMap::new();
80            for (i, column) in row.columns().iter().enumerate() {
81                row_map.insert(column.name().to_string(), extract_value(row, i));
82            }
83            result_rows.push(row_map);
84        }
85
86        Ok(QueryResult {
87            columns,
88            rows: result_rows,
89            row_count: total_rows,
90            execution_time_ms: execution_time,
91        })
92    }
93}
94
95fn extract_value(row: &sqlx::postgres::PgRow, column_index: usize) -> serde_json::Value {
96    if let Ok(val) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(column_index) {
97        return val.map_or(serde_json::Value::Null, |dt| {
98            serde_json::Value::String(dt.to_rfc3339())
99        });
100    }
101    if let Ok(val) = row.try_get::<Option<String>, _>(column_index) {
102        return val.map_or(serde_json::Value::Null, serde_json::Value::String);
103    }
104    if let Ok(val) = row.try_get::<Option<i64>, _>(column_index) {
105        return val.map_or(serde_json::Value::Null, |i| {
106            serde_json::Value::Number(i.into())
107        });
108    }
109    if let Ok(val) = row.try_get::<Option<i32>, _>(column_index) {
110        return val.map_or(serde_json::Value::Null, |i| {
111            serde_json::Value::Number(i.into())
112        });
113    }
114    if let Ok(val) = row.try_get::<Option<f64>, _>(column_index) {
115        return val.map_or(serde_json::Value::Null, |f| {
116            serde_json::Number::from_f64(f)
117                .map_or(serde_json::Value::Null, serde_json::Value::Number)
118        });
119    }
120    if let Ok(val) = row.try_get::<Option<bool>, _>(column_index) {
121        return val.map_or(serde_json::Value::Null, serde_json::Value::Bool);
122    }
123    if let Ok(val) = row.try_get::<Option<Vec<String>>, _>(column_index) {
124        return val.map_or(serde_json::Value::Null, |arr| {
125            serde_json::Value::Array(arr.into_iter().map(serde_json::Value::String).collect())
126        });
127    }
128    if let Ok(val) = row.try_get::<Option<serde_json::Value>, _>(column_index) {
129        return val.unwrap_or(serde_json::Value::Null);
130    }
131    serde_json::Value::Null
132}