systemprompt_database/services/postgres/
mod.rs1pub mod conversion;
2mod ext;
3mod introspection;
4pub mod transaction;
5
6use anyhow::{anyhow, Result};
7use async_trait::async_trait;
8use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
9use sqlx::Executor;
10use std::str::FromStr;
11use std::sync::Arc;
12
13use super::provider::DatabaseProvider;
14use crate::models::{
15 DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
16};
17use conversion::{bind_params, row_to_json, rows_to_result};
18use transaction::PostgresTransaction;
19
20#[derive(Debug)]
21pub struct PostgresProvider {
22 pool: Arc<PgPool>,
23}
24
25impl PostgresProvider {
26 pub async fn new(database_url: &str) -> Result<Self> {
27 let mut connect_options = PgConnectOptions::from_str(database_url)
28 .map_err(|e| anyhow!("Failed to parse database URL: {e}"))?;
29
30 let ssl_mode = if database_url.contains("sslmode=require") {
31 PgSslMode::Require
32 } else if database_url.contains("sslmode=disable") {
33 PgSslMode::Disable
34 } else {
35 PgSslMode::Prefer
36 };
37
38 connect_options = connect_options
39 .application_name("systemprompt")
40 .statement_cache_capacity(0)
41 .ssl_mode(ssl_mode)
42 .options([("client_min_messages", "warning")]);
43
44 if let Some(ca_cert_path) = Self::get_cert_path() {
45 connect_options = connect_options.ssl_root_cert(&ca_cert_path);
46 }
47
48 let pool = sqlx::postgres::PgPoolOptions::new()
49 .max_connections(50)
50 .min_connections(0)
51 .max_lifetime(std::time::Duration::from_secs(1800))
52 .acquire_timeout(std::time::Duration::from_secs(30))
53 .idle_timeout(std::time::Duration::from_secs(300))
54 .connect_with(connect_options)
55 .await
56 .map_err(|e| anyhow!("Failed to connect to PostgreSQL: {e}"))?;
57
58 Ok(Self {
59 pool: Arc::new(pool),
60 })
61 }
62
63 fn get_cert_path() -> Option<std::path::PathBuf> {
64 std::env::var("PGCA_CERT_PATH")
65 .ok()
66 .map(std::path::PathBuf::from)
67 }
68
69 #[must_use]
70 pub fn pool(&self) -> &PgPool {
71 &self.pool
72 }
73}
74
75#[async_trait]
76impl DatabaseProvider for PostgresProvider {
77 fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
78 Some(Arc::clone(&self.pool))
79 }
80
81 async fn execute(&self, query: &dyn QuerySelector, params: &[&dyn ToDbValue]) -> Result<u64> {
82 let sql = query.select_query();
83 let query_obj = sqlx::query(sql);
84 let query_obj = bind_params(query_obj, params);
85
86 let result = query_obj
87 .execute(&*self.pool)
88 .await
89 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
90
91 Ok(result.rows_affected())
92 }
93
94 async fn execute_raw(&self, sql: &str) -> Result<()> {
95 let mut conn = self
96 .pool
97 .acquire()
98 .await
99 .map_err(|e| anyhow!("Failed to acquire connection: {e}"))?;
100
101 conn.execute(sql)
102 .await
103 .map_err(|e| anyhow!("Raw query execution failed: {e}"))?;
104
105 Ok(())
106 }
107
108 async fn fetch_all(
109 &self,
110 query: &dyn QuerySelector,
111 params: &[&dyn ToDbValue],
112 ) -> Result<Vec<JsonRow>> {
113 let sql = query.select_query();
114 let query_obj = sqlx::query(sql);
115 let query_obj = bind_params(query_obj, params);
116
117 let rows = query_obj
118 .fetch_all(&*self.pool)
119 .await
120 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
121
122 Ok(rows.iter().map(row_to_json).collect())
123 }
124
125 async fn fetch_one(
126 &self,
127 query: &dyn QuerySelector,
128 params: &[&dyn ToDbValue],
129 ) -> Result<JsonRow> {
130 let sql = query.select_query();
131 let query_obj = sqlx::query(sql);
132 let query_obj = bind_params(query_obj, params);
133
134 let row = query_obj
135 .fetch_one(&*self.pool)
136 .await
137 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
138
139 Ok(row_to_json(&row))
140 }
141
142 async fn fetch_optional(
143 &self,
144 query: &dyn QuerySelector,
145 params: &[&dyn ToDbValue],
146 ) -> Result<Option<JsonRow>> {
147 let sql = query.select_query();
148 let query_obj = sqlx::query(sql);
149 let query_obj = bind_params(query_obj, params);
150
151 let row = query_obj
152 .fetch_optional(&*self.pool)
153 .await
154 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
155
156 Ok(row.map(|r| row_to_json(&r)))
157 }
158
159 async fn fetch_scalar_value(
160 &self,
161 query: &dyn QuerySelector,
162 params: &[&dyn ToDbValue],
163 ) -> Result<DbValue> {
164 let row = self.fetch_one(query, params).await?;
165
166 let first_value = row
167 .values()
168 .next()
169 .ok_or_else(|| anyhow!("No columns in result"))?;
170
171 let db_value = match first_value {
172 serde_json::Value::String(s) => DbValue::String(s.clone()),
173 serde_json::Value::Number(n) => n
174 .as_i64()
175 .map(DbValue::Int)
176 .or_else(|| n.as_f64().map(DbValue::Float))
177 .unwrap_or(DbValue::NullFloat),
178 serde_json::Value::Bool(b) => DbValue::Bool(*b),
179 serde_json::Value::Null => DbValue::NullString,
180 serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
181 return Err(anyhow!("Unsupported value type"))
182 },
183 };
184
185 Ok(db_value)
186 }
187
188 async fn begin_transaction(&self) -> Result<Box<dyn DatabaseTransaction>> {
189 let tx = self
190 .pool
191 .begin()
192 .await
193 .map_err(|e| anyhow!("Failed to begin transaction: {e}"))?;
194
195 Ok(Box::new(PostgresTransaction::new(tx)))
196 }
197
198 async fn get_database_info(&self) -> Result<DatabaseInfo> {
199 introspection::get_database_info(&self.pool).await
200 }
201
202 async fn test_connection(&self) -> Result<()> {
203 sqlx::query("SELECT 1")
204 .fetch_one(&*self.pool)
205 .await
206 .map_err(|e| anyhow!("Connection test failed: {e}"))?;
207 Ok(())
208 }
209
210 async fn execute_batch(&self, sql: &str) -> Result<()> {
211 let statements = crate::services::SqlExecutor::parse_sql_statements(sql);
212 for statement in statements {
213 sqlx::query(&statement)
214 .execute(&*self.pool)
215 .await
216 .map_err(|e| anyhow!("Batch execution failed: {e}"))?;
217 }
218 Ok(())
219 }
220
221 async fn query_raw(&self, query: &dyn QuerySelector) -> Result<QueryResult> {
222 let sql = query.select_query();
223 let start = std::time::Instant::now();
224
225 let rows = sqlx::query(sql)
226 .fetch_all(&*self.pool)
227 .await
228 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
229
230 Ok(rows_to_result(rows, start))
231 }
232
233 async fn query_raw_with(
234 &self,
235 query: &dyn QuerySelector,
236 params: Vec<serde_json::Value>,
237 ) -> Result<QueryResult> {
238 let sql = query.select_query();
239 let start = std::time::Instant::now();
240
241 let mut query_obj = sqlx::query(sql);
242 for param in params {
243 query_obj = match param {
244 serde_json::Value::String(s) => query_obj.bind(s),
245 serde_json::Value::Number(n) => {
246 if let Some(i) = n.as_i64() {
247 query_obj.bind(i)
248 } else if let Some(f) = n.as_f64() {
249 query_obj.bind(f)
250 } else {
251 query_obj.bind(None::<i64>)
252 }
253 },
254 serde_json::Value::Bool(b) => query_obj.bind(b),
255 serde_json::Value::Null => query_obj.bind(None::<String>),
256 serde_json::Value::Array(arr) => {
257 let strings: Vec<String> = arr
258 .into_iter()
259 .filter_map(|v| v.as_str().map(String::from))
260 .collect();
261 query_obj.bind(strings)
262 },
263 serde_json::Value::Object(_) => {
264 let json_str = serde_json::to_string(¶m)
265 .map_err(|e| anyhow!("Failed to serialize JSON object: {e}"))?;
266 query_obj.bind(Some(json_str))
267 },
268 };
269 }
270
271 let rows = query_obj
272 .fetch_all(&*self.pool)
273 .await
274 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
275
276 Ok(rows_to_result(rows, start))
277 }
278}