systemprompt_database/services/postgres/
mod.rs1pub mod connection;
10pub mod conversion;
11mod ext;
12mod introspection;
13pub mod transaction;
14
15use async_trait::async_trait;
16use sqlx::Executor;
17use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
18use std::str::FromStr;
19use std::sync::Arc;
20
21use super::provider::DatabaseProvider;
22use crate::error::{DatabaseResult, RepositoryError};
23use crate::models::{
24 DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
25};
26use conversion::{bind_params, row_to_json, rows_to_result};
27use transaction::PostgresTransaction;
28
29#[derive(Debug)]
30pub struct PostgresProvider {
31 pool: Arc<PgPool>,
32}
33
34impl PostgresProvider {
35 pub async fn new(database_url: &str) -> DatabaseResult<Self> {
36 let mut connect_options = PgConnectOptions::from_str(database_url)?;
37
38 let ssl_mode = if database_url.contains("sslmode=require") {
39 PgSslMode::Require
40 } else if database_url.contains("sslmode=disable") {
41 PgSslMode::Disable
42 } else {
43 PgSslMode::Prefer
44 };
45
46 connect_options = connect_options
47 .application_name("systemprompt")
48 .statement_cache_capacity(0)
49 .ssl_mode(ssl_mode)
50 .options([("client_min_messages", "warning")]);
51
52 if let Some(ca_cert_path) = Self::get_cert_path() {
53 connect_options = connect_options.ssl_root_cert(&ca_cert_path);
54 }
55
56 let pool =
57 connection::connect_with_retry(connection::build_pool_options(), connect_options)
58 .await?;
59
60 Ok(Self {
61 pool: Arc::new(pool),
62 })
63 }
64
65 #[must_use]
66 pub const fn from_pool(pool: Arc<PgPool>) -> Self {
67 Self { pool }
68 }
69
70 fn get_cert_path() -> Option<std::path::PathBuf> {
71 std::env::var("PGCA_CERT_PATH")
72 .ok()
73 .map(std::path::PathBuf::from)
74 }
75
76 #[must_use]
77 pub fn pool(&self) -> &PgPool {
78 &self.pool
79 }
80}
81
82#[async_trait]
83impl DatabaseProvider for PostgresProvider {
84 fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
85 Some(Arc::clone(&self.pool))
86 }
87
88 async fn execute(
89 &self,
90 query: &dyn QuerySelector,
91 params: &[&dyn ToDbValue],
92 ) -> DatabaseResult<u64> {
93 let sql = query.select_query();
94 let query_obj = sqlx::query(sql);
95 let query_obj = bind_params(query_obj, params);
96
97 let result = query_obj.execute(&*self.pool).await?;
98
99 Ok(result.rows_affected())
100 }
101
102 async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
103 let mut conn = self.pool.acquire().await?;
104
105 conn.execute(sql).await?;
106
107 Ok(())
108 }
109
110 async fn fetch_all(
111 &self,
112 query: &dyn QuerySelector,
113 params: &[&dyn ToDbValue],
114 ) -> DatabaseResult<Vec<JsonRow>> {
115 let sql = query.select_query();
116 let query_obj = sqlx::query(sql);
117 let query_obj = bind_params(query_obj, params);
118
119 let rows = query_obj.fetch_all(&*self.pool).await?;
120
121 Ok(rows.iter().map(row_to_json).collect())
122 }
123
124 async fn fetch_one(
125 &self,
126 query: &dyn QuerySelector,
127 params: &[&dyn ToDbValue],
128 ) -> DatabaseResult<JsonRow> {
129 let sql = query.select_query();
130 let query_obj = sqlx::query(sql);
131 let query_obj = bind_params(query_obj, params);
132
133 let row = query_obj.fetch_one(&*self.pool).await?;
134
135 Ok(row_to_json(&row))
136 }
137
138 async fn fetch_optional(
139 &self,
140 query: &dyn QuerySelector,
141 params: &[&dyn ToDbValue],
142 ) -> DatabaseResult<Option<JsonRow>> {
143 let sql = query.select_query();
144 let query_obj = sqlx::query(sql);
145 let query_obj = bind_params(query_obj, params);
146
147 let row = query_obj.fetch_optional(&*self.pool).await?;
148
149 Ok(row.map(|r| row_to_json(&r)))
150 }
151
152 async fn fetch_scalar_value(
153 &self,
154 query: &dyn QuerySelector,
155 params: &[&dyn ToDbValue],
156 ) -> DatabaseResult<DbValue> {
157 let row = self.fetch_one(query, params).await?;
158
159 let first_value = row
160 .values()
161 .next()
162 .ok_or_else(|| RepositoryError::invalid_state("No columns in result"))?;
163
164 let db_value = match first_value {
165 serde_json::Value::String(s) => DbValue::String(s.clone()),
166 serde_json::Value::Number(n) => n
167 .as_i64()
168 .map(DbValue::Int)
169 .or_else(|| n.as_f64().map(DbValue::Float))
170 .unwrap_or(DbValue::NullFloat),
171 serde_json::Value::Bool(b) => DbValue::Bool(*b),
172 serde_json::Value::Null => DbValue::NullString,
173 serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
174 return Err(RepositoryError::invalid_state("Unsupported value type"));
175 },
176 };
177
178 Ok(db_value)
179 }
180
181 async fn begin_transaction(&self) -> DatabaseResult<Box<dyn DatabaseTransaction>> {
182 let tx = self.pool.begin().await?;
183
184 Ok(Box::new(PostgresTransaction::new(tx)))
185 }
186
187 async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
188 introspection::get_database_info(&self.pool).await
189 }
190
191 async fn test_connection(&self) -> DatabaseResult<()> {
192 sqlx::query("SELECT 1").fetch_one(&*self.pool).await?;
193 Ok(())
194 }
195
196 async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
197 let statements = crate::services::SqlExecutor::parse_sql_statements(sql)?;
198 for statement in statements {
199 sqlx::query(&statement).execute(&*self.pool).await?;
200 }
201 Ok(())
202 }
203
204 async fn query_raw(&self, query: &dyn QuerySelector) -> DatabaseResult<QueryResult> {
205 let sql = query.select_query();
206 let start = std::time::Instant::now();
207
208 let rows = sqlx::query(sql).fetch_all(&*self.pool).await?;
209
210 Ok(rows_to_result(rows, start))
211 }
212
213 async fn query_raw_with(
214 &self,
215 query: &dyn QuerySelector,
216 params: &[&dyn ToDbValue],
217 ) -> DatabaseResult<QueryResult> {
218 let sql = query.select_query();
219 let start = std::time::Instant::now();
220
221 let query_obj = bind_params(sqlx::query(sql), params);
222 let rows = query_obj.fetch_all(&*self.pool).await?;
223
224 Ok(rows_to_result(rows, start))
225 }
226}