systemprompt_database/services/postgres/
mod.rs1pub mod conversion;
10mod ext;
11mod introspection;
12pub mod transaction;
13
14use async_trait::async_trait;
15use sqlx::Executor;
16use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
17use std::str::FromStr;
18use std::sync::Arc;
19
20use super::provider::DatabaseProvider;
21use crate::error::{DatabaseResult, RepositoryError};
22use crate::models::{
23 DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
24};
25use conversion::{bind_params, row_to_json, rows_to_result};
26use transaction::PostgresTransaction;
27
28#[derive(Debug)]
29pub struct PostgresProvider {
30 pool: Arc<PgPool>,
31}
32
33impl PostgresProvider {
34 pub async fn new(database_url: &str) -> DatabaseResult<Self> {
35 let mut connect_options = PgConnectOptions::from_str(database_url)?;
36
37 let ssl_mode = if database_url.contains("sslmode=require") {
38 PgSslMode::Require
39 } else if database_url.contains("sslmode=disable") {
40 PgSslMode::Disable
41 } else {
42 PgSslMode::Prefer
43 };
44
45 connect_options = connect_options
46 .application_name("systemprompt")
47 .statement_cache_capacity(0)
48 .ssl_mode(ssl_mode)
49 .options([("client_min_messages", "warning")]);
50
51 if let Some(ca_cert_path) = Self::get_cert_path() {
52 connect_options = connect_options.ssl_root_cert(&ca_cert_path);
53 }
54
55 let pool = sqlx::postgres::PgPoolOptions::new()
56 .max_connections(50)
57 .min_connections(0)
58 .max_lifetime(std::time::Duration::from_secs(1800))
59 .acquire_timeout(std::time::Duration::from_secs(30))
60 .idle_timeout(std::time::Duration::from_secs(300))
61 .connect_with(connect_options)
62 .await?;
63
64 Ok(Self {
65 pool: Arc::new(pool),
66 })
67 }
68
69 fn get_cert_path() -> Option<std::path::PathBuf> {
70 std::env::var("PGCA_CERT_PATH")
71 .ok()
72 .map(std::path::PathBuf::from)
73 }
74
75 #[must_use]
76 pub fn pool(&self) -> &PgPool {
77 &self.pool
78 }
79}
80
81#[async_trait]
82impl DatabaseProvider for PostgresProvider {
83 fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
84 Some(Arc::clone(&self.pool))
85 }
86
87 async fn execute(
88 &self,
89 query: &dyn QuerySelector,
90 params: &[&dyn ToDbValue],
91 ) -> DatabaseResult<u64> {
92 let sql = query.select_query();
93 let query_obj = sqlx::query(sql);
94 let query_obj = bind_params(query_obj, params);
95
96 let result = query_obj.execute(&*self.pool).await?;
97
98 Ok(result.rows_affected())
99 }
100
101 async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
102 let mut conn = self.pool.acquire().await?;
103
104 conn.execute(sql).await?;
105
106 Ok(())
107 }
108
109 async fn fetch_all(
110 &self,
111 query: &dyn QuerySelector,
112 params: &[&dyn ToDbValue],
113 ) -> DatabaseResult<Vec<JsonRow>> {
114 let sql = query.select_query();
115 let query_obj = sqlx::query(sql);
116 let query_obj = bind_params(query_obj, params);
117
118 let rows = query_obj.fetch_all(&*self.pool).await?;
119
120 Ok(rows.iter().map(row_to_json).collect())
121 }
122
123 async fn fetch_one(
124 &self,
125 query: &dyn QuerySelector,
126 params: &[&dyn ToDbValue],
127 ) -> DatabaseResult<JsonRow> {
128 let sql = query.select_query();
129 let query_obj = sqlx::query(sql);
130 let query_obj = bind_params(query_obj, params);
131
132 let row = query_obj.fetch_one(&*self.pool).await?;
133
134 Ok(row_to_json(&row))
135 }
136
137 async fn fetch_optional(
138 &self,
139 query: &dyn QuerySelector,
140 params: &[&dyn ToDbValue],
141 ) -> DatabaseResult<Option<JsonRow>> {
142 let sql = query.select_query();
143 let query_obj = sqlx::query(sql);
144 let query_obj = bind_params(query_obj, params);
145
146 let row = query_obj.fetch_optional(&*self.pool).await?;
147
148 Ok(row.map(|r| row_to_json(&r)))
149 }
150
151 async fn fetch_scalar_value(
152 &self,
153 query: &dyn QuerySelector,
154 params: &[&dyn ToDbValue],
155 ) -> DatabaseResult<DbValue> {
156 let row = self.fetch_one(query, params).await?;
157
158 let first_value = row
159 .values()
160 .next()
161 .ok_or_else(|| RepositoryError::invalid_state("No columns in result"))?;
162
163 let db_value = match first_value {
164 serde_json::Value::String(s) => DbValue::String(s.clone()),
165 serde_json::Value::Number(n) => n
166 .as_i64()
167 .map(DbValue::Int)
168 .or_else(|| n.as_f64().map(DbValue::Float))
169 .unwrap_or(DbValue::NullFloat),
170 serde_json::Value::Bool(b) => DbValue::Bool(*b),
171 serde_json::Value::Null => DbValue::NullString,
172 serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
173 return Err(RepositoryError::invalid_state("Unsupported value type"));
174 },
175 };
176
177 Ok(db_value)
178 }
179
180 async fn begin_transaction(&self) -> DatabaseResult<Box<dyn DatabaseTransaction>> {
181 let tx = self.pool.begin().await?;
182
183 Ok(Box::new(PostgresTransaction::new(tx)))
184 }
185
186 async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
187 introspection::get_database_info(&self.pool).await
188 }
189
190 async fn test_connection(&self) -> DatabaseResult<()> {
191 sqlx::query("SELECT 1").fetch_one(&*self.pool).await?;
192 Ok(())
193 }
194
195 async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
196 let statements = crate::services::SqlExecutor::parse_sql_statements(sql);
197 for statement in statements {
198 sqlx::query(&statement).execute(&*self.pool).await?;
199 }
200 Ok(())
201 }
202
203 async fn query_raw(&self, query: &dyn QuerySelector) -> DatabaseResult<QueryResult> {
204 let sql = query.select_query();
205 let start = std::time::Instant::now();
206
207 let rows = sqlx::query(sql).fetch_all(&*self.pool).await?;
208
209 Ok(rows_to_result(rows, start))
210 }
211
212 async fn query_raw_with(
213 &self,
214 query: &dyn QuerySelector,
215 params: Vec<serde_json::Value>,
216 ) -> DatabaseResult<QueryResult> {
217 let sql = query.select_query();
218 let start = std::time::Instant::now();
219
220 let mut query_obj = sqlx::query(sql);
221 for param in params {
222 query_obj = match param {
223 serde_json::Value::String(s) => query_obj.bind(s),
224 serde_json::Value::Number(n) => {
225 if let Some(i) = n.as_i64() {
226 query_obj.bind(i)
227 } else if let Some(f) = n.as_f64() {
228 query_obj.bind(f)
229 } else {
230 query_obj.bind(None::<i64>)
231 }
232 },
233 serde_json::Value::Bool(b) => query_obj.bind(b),
234 serde_json::Value::Null => query_obj.bind(None::<String>),
235 serde_json::Value::Array(arr) => {
236 let strings: Vec<String> = arr
237 .into_iter()
238 .filter_map(|v| v.as_str().map(String::from))
239 .collect();
240 query_obj.bind(strings)
241 },
242 serde_json::Value::Object(_) => {
243 let json_str = serde_json::to_string(¶m)?;
244 query_obj.bind(Some(json_str))
245 },
246 };
247 }
248
249 let rows = query_obj.fetch_all(&*self.pool).await?;
250
251 Ok(rows_to_result(rows, start))
252 }
253}