// systemprompt_database/admin/query_executor.rs
use std::collections::HashMap;
2use std::sync::Arc;
3
4use sqlx::postgres::PgPool;
5use sqlx::{Column, Row};
6use thiserror::Error;
7
8use crate::admin::admin_sql::{AdminSql, AdminSqlError, DEFAULT_READONLY_ROW_LIMIT};
9use crate::models::QueryResult;
10
/// Errors produced while validating or executing admin SQL statements.
#[derive(Error, Debug)]
pub enum QueryExecutorError {
    /// A mutating statement was submitted through the read-only path.
    #[error(
        "Write query not allowed in read-only mode: only SELECT, WITH, EXPLAIN, SHOW, TABLE, and \
         VALUES queries are permitted"
    )]
    WriteQueryNotAllowed,

    /// The statement failed `AdminSql` parsing/validation (wraps [`AdminSqlError`]).
    #[error("Invalid admin SQL: {0}")]
    InvalidSql(#[from] AdminSqlError),

    /// The database rejected the statement or the connection failed (wraps [`sqlx::Error`]).
    #[error("Query execution failed: {0}")]
    ExecutionFailed(#[from] sqlx::Error),
}
25
/// Executes validated admin SQL against a shared Postgres connection pool.
#[derive(Debug)]
pub struct QueryExecutor {
    // Shared handle to the Postgres pool; `Arc` allows the executor to be
    // cloned/held across tasks without owning the pool exclusively.
    pool: Arc<PgPool>,
}
30
31impl QueryExecutor {
32 pub const fn new(pool: Arc<PgPool>) -> Self {
33 Self { pool }
34 }
35
36 pub async fn execute_readonly(
37 &self,
38 raw_sql: &str,
39 row_limit: Option<usize>,
40 ) -> Result<QueryResult, QueryExecutorError> {
41 let sql = AdminSql::parse_readonly(raw_sql)?;
42 self.execute(sql, row_limit.unwrap_or(DEFAULT_READONLY_ROW_LIMIT))
43 .await
44 }
45
46 pub async fn execute_write(&self, raw_sql: &str) -> Result<QueryResult, QueryExecutorError> {
47 let sql = AdminSql::parse_unrestricted(raw_sql)?;
48 self.execute(sql, usize::MAX).await
49 }
50
51 async fn execute(
52 &self,
53 sql: AdminSql,
54 row_limit: usize,
55 ) -> Result<QueryResult, QueryExecutorError> {
56 let start = std::time::Instant::now();
57
58 let rows = sqlx::query(sql.as_str()).fetch_all(&*self.pool).await?;
59 let execution_time = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
60
61 let columns = rows.first().map_or_else(Vec::new, |first_row| {
62 first_row
63 .columns()
64 .iter()
65 .map(|c| c.name().to_string())
66 .collect()
67 });
68
69 let total_rows = rows.len();
70 let capped_rows = rows.iter().take(row_limit);
71 let mut result_rows = Vec::with_capacity(total_rows.min(row_limit));
72
73 for row in capped_rows {
74 let mut row_map = HashMap::new();
75 for (i, column) in row.columns().iter().enumerate() {
76 row_map.insert(column.name().to_string(), extract_value(row, i));
77 }
78 result_rows.push(row_map);
79 }
80
81 Ok(QueryResult {
82 columns,
83 rows: result_rows,
84 row_count: total_rows,
85 execution_time_ms: execution_time,
86 })
87 }
88}
89
90fn extract_value(row: &sqlx::postgres::PgRow, column_index: usize) -> serde_json::Value {
91 if let Ok(val) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(column_index) {
92 return val.map_or(serde_json::Value::Null, |dt| {
93 serde_json::Value::String(dt.to_rfc3339())
94 });
95 }
96 if let Ok(val) = row.try_get::<Option<String>, _>(column_index) {
97 return val.map_or(serde_json::Value::Null, serde_json::Value::String);
98 }
99 if let Ok(val) = row.try_get::<Option<i64>, _>(column_index) {
100 return val.map_or(serde_json::Value::Null, |i| {
101 serde_json::Value::Number(i.into())
102 });
103 }
104 if let Ok(val) = row.try_get::<Option<i32>, _>(column_index) {
105 return val.map_or(serde_json::Value::Null, |i| {
106 serde_json::Value::Number(i.into())
107 });
108 }
109 if let Ok(val) = row.try_get::<Option<f64>, _>(column_index) {
110 return val.map_or(serde_json::Value::Null, |f| {
111 serde_json::Number::from_f64(f)
112 .map_or(serde_json::Value::Null, serde_json::Value::Number)
113 });
114 }
115 if let Ok(val) = row.try_get::<Option<bool>, _>(column_index) {
116 return val.map_or(serde_json::Value::Null, serde_json::Value::Bool);
117 }
118 if let Ok(val) = row.try_get::<Option<Vec<String>>, _>(column_index) {
119 return val.map_or(serde_json::Value::Null, |arr| {
120 serde_json::Value::Array(arr.into_iter().map(serde_json::Value::String).collect())
121 });
122 }
123 if let Ok(val) = row.try_get::<Option<serde_json::Value>, _>(column_index) {
124 return val.unwrap_or(serde_json::Value::Null);
125 }
126 serde_json::Value::Null
127}