systemprompt_database/services/postgres/
mod.rs1pub mod conversion;
2mod ext;
3mod introspection;
4pub mod transaction;
5
6use anyhow::{anyhow, Result};
7use async_trait::async_trait;
8use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
9use sqlx::Executor;
10use std::str::FromStr;
11use std::sync::Arc;
12
13use super::provider::DatabaseProvider;
14use crate::models::{
15 DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
16};
17use conversion::{bind_params, row_to_json, rows_to_result};
18use transaction::PostgresTransaction;
19
20#[derive(Debug)]
21pub struct PostgresProvider {
22 pool: Arc<PgPool>,
23}
24
25impl PostgresProvider {
26 pub async fn new(database_url: &str) -> Result<Self> {
27 let mut connect_options = PgConnectOptions::from_str(database_url)
28 .map_err(|e| anyhow!("Failed to parse database URL: {e}"))?;
29
30 let ssl_mode = if database_url.contains("sslmode=require") {
31 PgSslMode::Require
32 } else if database_url.contains("sslmode=disable") {
33 PgSslMode::Disable
34 } else {
35 PgSslMode::Prefer
36 };
37
38 connect_options = connect_options
39 .application_name("systemprompt")
40 .statement_cache_capacity(0)
41 .ssl_mode(ssl_mode);
42
43 if let Some(ca_cert_path) = Self::get_cert_path() {
44 connect_options = connect_options.ssl_root_cert(&ca_cert_path);
45 }
46
47 let pool = sqlx::postgres::PgPoolOptions::new()
48 .max_connections(50)
49 .min_connections(0)
50 .max_lifetime(std::time::Duration::from_secs(1800))
51 .acquire_timeout(std::time::Duration::from_secs(30))
52 .idle_timeout(std::time::Duration::from_secs(300))
53 .connect_with(connect_options)
54 .await
55 .map_err(|e| anyhow!("Failed to connect to PostgreSQL: {e}"))?;
56
57 Ok(Self {
58 pool: Arc::new(pool),
59 })
60 }
61
62 fn get_cert_path() -> Option<std::path::PathBuf> {
63 std::env::var("PGCA_CERT_PATH")
64 .ok()
65 .map(std::path::PathBuf::from)
66 }
67
68 #[must_use]
69 pub fn pool(&self) -> &PgPool {
70 &self.pool
71 }
72}
73
74#[async_trait]
75impl DatabaseProvider for PostgresProvider {
76 fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
77 Some(Arc::clone(&self.pool))
78 }
79
80 async fn execute(&self, query: &dyn QuerySelector, params: &[&dyn ToDbValue]) -> Result<u64> {
81 let sql = query.select_query();
82 let query_obj = sqlx::query(sql);
83 let query_obj = bind_params(query_obj, params);
84
85 let result = query_obj
86 .execute(&*self.pool)
87 .await
88 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
89
90 Ok(result.rows_affected())
91 }
92
93 async fn execute_raw(&self, sql: &str) -> Result<()> {
94 let mut conn = self
95 .pool
96 .acquire()
97 .await
98 .map_err(|e| anyhow!("Failed to acquire connection: {e}"))?;
99
100 conn.execute(sql)
101 .await
102 .map_err(|e| anyhow!("Raw query execution failed: {e}"))?;
103
104 Ok(())
105 }
106
107 async fn fetch_all(
108 &self,
109 query: &dyn QuerySelector,
110 params: &[&dyn ToDbValue],
111 ) -> Result<Vec<JsonRow>> {
112 let sql = query.select_query();
113 let query_obj = sqlx::query(sql);
114 let query_obj = bind_params(query_obj, params);
115
116 let rows = query_obj
117 .fetch_all(&*self.pool)
118 .await
119 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
120
121 Ok(rows.iter().map(row_to_json).collect())
122 }
123
124 async fn fetch_one(
125 &self,
126 query: &dyn QuerySelector,
127 params: &[&dyn ToDbValue],
128 ) -> Result<JsonRow> {
129 let sql = query.select_query();
130 let query_obj = sqlx::query(sql);
131 let query_obj = bind_params(query_obj, params);
132
133 let row = query_obj
134 .fetch_one(&*self.pool)
135 .await
136 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
137
138 Ok(row_to_json(&row))
139 }
140
141 async fn fetch_optional(
142 &self,
143 query: &dyn QuerySelector,
144 params: &[&dyn ToDbValue],
145 ) -> Result<Option<JsonRow>> {
146 let sql = query.select_query();
147 let query_obj = sqlx::query(sql);
148 let query_obj = bind_params(query_obj, params);
149
150 let row = query_obj
151 .fetch_optional(&*self.pool)
152 .await
153 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
154
155 Ok(row.map(|r| row_to_json(&r)))
156 }
157
158 async fn fetch_scalar_value(
159 &self,
160 query: &dyn QuerySelector,
161 params: &[&dyn ToDbValue],
162 ) -> Result<DbValue> {
163 let row = self.fetch_one(query, params).await?;
164
165 let first_value = row
166 .values()
167 .next()
168 .ok_or_else(|| anyhow!("No columns in result"))?;
169
170 let db_value = match first_value {
171 serde_json::Value::String(s) => DbValue::String(s.clone()),
172 serde_json::Value::Number(n) => n
173 .as_i64()
174 .map(DbValue::Int)
175 .or_else(|| n.as_f64().map(DbValue::Float))
176 .unwrap_or(DbValue::NullFloat),
177 serde_json::Value::Bool(b) => DbValue::Bool(*b),
178 serde_json::Value::Null => DbValue::NullString,
179 serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
180 return Err(anyhow!("Unsupported value type"))
181 },
182 };
183
184 Ok(db_value)
185 }
186
187 async fn begin_transaction(&self) -> Result<Box<dyn DatabaseTransaction>> {
188 let tx = self
189 .pool
190 .begin()
191 .await
192 .map_err(|e| anyhow!("Failed to begin transaction: {e}"))?;
193
194 Ok(Box::new(PostgresTransaction::new(tx)))
195 }
196
197 async fn get_database_info(&self) -> Result<DatabaseInfo> {
198 introspection::get_database_info(&self.pool).await
199 }
200
201 async fn test_connection(&self) -> Result<()> {
202 sqlx::query("SELECT 1")
203 .fetch_one(&*self.pool)
204 .await
205 .map_err(|e| anyhow!("Connection test failed: {e}"))?;
206 Ok(())
207 }
208
209 async fn execute_batch(&self, sql: &str) -> Result<()> {
210 let statements = crate::services::SqlExecutor::parse_sql_statements(sql);
211 for statement in statements {
212 sqlx::query(&statement)
213 .execute(&*self.pool)
214 .await
215 .map_err(|e| anyhow!("Batch execution failed: {e}"))?;
216 }
217 Ok(())
218 }
219
220 async fn query_raw(&self, query: &dyn QuerySelector) -> Result<QueryResult> {
221 let sql = query.select_query();
222 let start = std::time::Instant::now();
223
224 let rows = sqlx::query(sql)
225 .fetch_all(&*self.pool)
226 .await
227 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
228
229 Ok(rows_to_result(rows, start))
230 }
231
232 async fn query_raw_with(
233 &self,
234 query: &dyn QuerySelector,
235 params: Vec<serde_json::Value>,
236 ) -> Result<QueryResult> {
237 let sql = query.select_query();
238 let start = std::time::Instant::now();
239
240 let mut query_obj = sqlx::query(sql);
241 for param in params {
242 query_obj = match param {
243 serde_json::Value::String(s) => query_obj.bind(s),
244 serde_json::Value::Number(n) => {
245 if let Some(i) = n.as_i64() {
246 query_obj.bind(i)
247 } else if let Some(f) = n.as_f64() {
248 query_obj.bind(f)
249 } else {
250 query_obj.bind(None::<i64>)
251 }
252 },
253 serde_json::Value::Bool(b) => query_obj.bind(b),
254 serde_json::Value::Null => query_obj.bind(None::<String>),
255 serde_json::Value::Array(arr) => {
256 let strings: Vec<String> = arr
257 .into_iter()
258 .filter_map(|v| v.as_str().map(String::from))
259 .collect();
260 query_obj.bind(strings)
261 },
262 serde_json::Value::Object(_) => {
263 let json_str = serde_json::to_string(¶m)
264 .map_err(|e| anyhow!("Failed to serialize JSON object: {e}"))?;
265 query_obj.bind(Some(json_str))
266 },
267 };
268 }
269
270 let rows = query_obj
271 .fetch_all(&*self.pool)
272 .await
273 .map_err(|e| anyhow!("Query execution failed: {e}"))?;
274
275 Ok(rows_to_result(rows, start))
276 }
277}