systemprompt_database/admin/
query_executor.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use sqlx::postgres::PgPool;
5use sqlx::{Column, Row};
6use thiserror::Error;
7
8use crate::models::QueryResult;
9
10#[derive(Error, Debug)]
11pub enum QueryExecutorError {
12 #[error(
13 "Write query not allowed in read-only mode: only SELECT, WITH, EXPLAIN, and PRAGMA \
14 queries are permitted"
15 )]
16 WriteQueryNotAllowed,
17
18 #[error("Query execution failed: {0}")]
19 ExecutionFailed(#[from] sqlx::Error),
20}
21
22#[derive(Debug)]
23pub struct QueryExecutor {
24 pool: Arc<PgPool>,
25}
26
27impl QueryExecutor {
28 pub const fn new(pool: Arc<PgPool>) -> Self {
29 Self { pool }
30 }
31
32 pub async fn execute_query(
33 &self,
34 query: &str,
35 read_only: bool,
36 ) -> Result<QueryResult, QueryExecutorError> {
37 let start = std::time::Instant::now();
38
39 if read_only && !Self::is_safe_query(query) {
40 return Err(QueryExecutorError::WriteQueryNotAllowed);
41 }
42
43 let rows = sqlx::query(query).fetch_all(&*self.pool).await?;
44 let execution_time = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
45
46 let mut columns = Vec::new();
47 let mut result_rows = Vec::new();
48
49 if let Some(first_row) = rows.first() {
50 columns = first_row
51 .columns()
52 .iter()
53 .map(|c| c.name().to_string())
54 .collect();
55 }
56
57 for row in &rows {
58 let mut row_map = HashMap::new();
59 for (i, column) in row.columns().iter().enumerate() {
60 row_map.insert(column.name().to_string(), Self::extract_value(row, i));
61 }
62 result_rows.push(row_map);
63 }
64
65 Ok(QueryResult {
66 columns,
67 rows: result_rows,
68 row_count: rows.len(),
69 execution_time_ms: execution_time,
70 })
71 }
72
73 fn is_safe_query(query: &str) -> bool {
74 let trimmed = query.trim().to_lowercase();
75 let safe_starts = ["select", "with", "explain", "pragma"];
76 let unsafe_ops = [
77 " drop ", " delete ", " insert ", " update ", " alter ", " create ",
78 ];
79
80 safe_starts.iter().any(|s| trimmed.starts_with(s))
81 && !unsafe_ops.iter().any(|op| trimmed.contains(op))
82 }
83
84 fn extract_value(row: &sqlx::postgres::PgRow, column_index: usize) -> serde_json::Value {
85 if let Ok(val) = row.try_get::<Option<chrono::NaiveDateTime>, _>(column_index) {
86 return val.map_or(serde_json::Value::Null, |dt| {
87 serde_json::Value::String(dt.and_utc().to_rfc3339())
88 });
89 }
90 if let Ok(val) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(column_index) {
91 return val.map_or(serde_json::Value::Null, |dt| {
92 serde_json::Value::String(dt.to_rfc3339())
93 });
94 }
95 if let Ok(val) = row.try_get::<Option<String>, _>(column_index) {
96 return val.map_or(serde_json::Value::Null, serde_json::Value::String);
97 }
98 if let Ok(val) = row.try_get::<Option<i64>, _>(column_index) {
99 return val.map_or(serde_json::Value::Null, |i| {
100 serde_json::Value::Number(i.into())
101 });
102 }
103 if let Ok(val) = row.try_get::<Option<i32>, _>(column_index) {
104 return val.map_or(serde_json::Value::Null, |i| {
105 serde_json::Value::Number(i.into())
106 });
107 }
108 if let Ok(val) = row.try_get::<Option<f64>, _>(column_index) {
109 return val.map_or(serde_json::Value::Null, |f| {
110 serde_json::Number::from_f64(f)
111 .map_or(serde_json::Value::Null, serde_json::Value::Number)
112 });
113 }
114 if let Ok(val) = row.try_get::<Option<bool>, _>(column_index) {
115 return val.map_or(serde_json::Value::Null, serde_json::Value::Bool);
116 }
117 if let Ok(val) = row.try_get::<Option<Vec<String>>, _>(column_index) {
118 return val.map_or(serde_json::Value::Null, |arr| {
119 serde_json::Value::Array(arr.into_iter().map(serde_json::Value::String).collect())
120 });
121 }
122 if let Ok(val) = row.try_get::<Option<serde_json::Value>, _>(column_index) {
123 return val.unwrap_or(serde_json::Value::Null);
124 }
125 serde_json::Value::Null
126 }
127}