// systemprompt_database/admin/query_executor.rs
use std::collections::HashMap;
7use std::sync::Arc;
8
9use sqlx::postgres::PgPool;
10use sqlx::{Column, Row};
11use thiserror::Error;
12
13use crate::admin::admin_sql::{AdminSql, AdminSqlError, DEFAULT_READONLY_ROW_LIMIT};
14use crate::models::QueryResult;
15
/// Errors produced while validating or executing admin SQL.
#[derive(Error, Debug)]
pub enum QueryExecutorError {
    /// A mutating statement was submitted through the read-only path.
    #[error(
        "Write query not allowed in read-only mode: only SELECT, WITH, EXPLAIN, SHOW, TABLE, and \
         VALUES queries are permitted"
    )]
    WriteQueryNotAllowed,

    /// The statement failed admin-side SQL validation (see [`AdminSqlError`]).
    #[error("Invalid admin SQL: {0}")]
    InvalidSql(#[from] AdminSqlError),

    /// The database driver reported a failure while running the statement.
    #[error("Query execution failed: {0}")]
    ExecutionFailed(#[from] sqlx::Error),
}
30
/// Runs admin-issued SQL against a shared Postgres pool and converts the
/// resulting rows into JSON-friendly [`QueryResult`] values.
#[derive(Debug)]
pub struct QueryExecutor {
    // Shared connection pool; `Arc` allows cheap cloning of the handle.
    pool: Arc<PgPool>,
}
35
36impl QueryExecutor {
37 pub const fn new(pool: Arc<PgPool>) -> Self {
38 Self { pool }
39 }
40
41 pub async fn execute_readonly(
42 &self,
43 raw_sql: &str,
44 row_limit: Option<usize>,
45 ) -> Result<QueryResult, QueryExecutorError> {
46 let sql = AdminSql::parse_readonly(raw_sql)?;
47 self.execute(sql, row_limit.unwrap_or(DEFAULT_READONLY_ROW_LIMIT))
48 .await
49 }
50
51 pub async fn execute_write(&self, raw_sql: &str) -> Result<QueryResult, QueryExecutorError> {
52 let sql = AdminSql::parse_unrestricted(raw_sql)?;
53 self.execute(sql, usize::MAX).await
54 }
55
56 async fn execute(
57 &self,
58 sql: AdminSql,
59 row_limit: usize,
60 ) -> Result<QueryResult, QueryExecutorError> {
61 let start = std::time::Instant::now();
62
63 let rows = sqlx::query(sql.as_str()).fetch_all(&*self.pool).await?;
64 let execution_time = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
65
66 let columns = rows.first().map_or_else(Vec::new, |first_row| {
67 first_row
68 .columns()
69 .iter()
70 .map(|c| c.name().to_string())
71 .collect()
72 });
73
74 let total_rows = rows.len();
75 let capped_rows = rows.iter().take(row_limit);
76 let mut result_rows = Vec::with_capacity(total_rows.min(row_limit));
77
78 for row in capped_rows {
79 let mut row_map = HashMap::new();
80 for (i, column) in row.columns().iter().enumerate() {
81 row_map.insert(column.name().to_string(), extract_value(row, i));
82 }
83 result_rows.push(row_map);
84 }
85
86 Ok(QueryResult {
87 columns,
88 rows: result_rows,
89 row_count: total_rows,
90 execution_time_ms: execution_time,
91 })
92 }
93}
94
95fn extract_value(row: &sqlx::postgres::PgRow, column_index: usize) -> serde_json::Value {
96 if let Ok(val) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(column_index) {
97 return val.map_or(serde_json::Value::Null, |dt| {
98 serde_json::Value::String(dt.to_rfc3339())
99 });
100 }
101 if let Ok(val) = row.try_get::<Option<String>, _>(column_index) {
102 return val.map_or(serde_json::Value::Null, serde_json::Value::String);
103 }
104 if let Ok(val) = row.try_get::<Option<i64>, _>(column_index) {
105 return val.map_or(serde_json::Value::Null, |i| {
106 serde_json::Value::Number(i.into())
107 });
108 }
109 if let Ok(val) = row.try_get::<Option<i32>, _>(column_index) {
110 return val.map_or(serde_json::Value::Null, |i| {
111 serde_json::Value::Number(i.into())
112 });
113 }
114 if let Ok(val) = row.try_get::<Option<f64>, _>(column_index) {
115 return val.map_or(serde_json::Value::Null, |f| {
116 serde_json::Number::from_f64(f)
117 .map_or(serde_json::Value::Null, serde_json::Value::Number)
118 });
119 }
120 if let Ok(val) = row.try_get::<Option<bool>, _>(column_index) {
121 return val.map_or(serde_json::Value::Null, serde_json::Value::Bool);
122 }
123 if let Ok(val) = row.try_get::<Option<Vec<String>>, _>(column_index) {
124 return val.map_or(serde_json::Value::Null, |arr| {
125 serde_json::Value::Array(arr.into_iter().map(serde_json::Value::String).collect())
126 });
127 }
128 if let Ok(val) = row.try_get::<Option<serde_json::Value>, _>(column_index) {
129 return val.unwrap_or(serde_json::Value::Null);
130 }
131 serde_json::Value::Null
132}