// systemprompt_database/services/postgres/mod.rs
pub mod connection;
10pub mod conversion;
11mod ext;
12mod introspection;
13pub mod transaction;
14
15use async_trait::async_trait;
16use sqlx::Executor;
17use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
18use std::str::FromStr;
19use std::sync::Arc;
20
21use super::provider::DatabaseProvider;
22use crate::error::{DatabaseResult, RepositoryError};
23use crate::models::{
24 DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
25};
26use conversion::{bind_params, row_to_json, rows_to_result};
27use transaction::PostgresTransaction;
28
29#[derive(Debug)]
30pub struct PostgresProvider {
31 pool: Arc<PgPool>,
32}
33
34impl PostgresProvider {
35 pub async fn new(database_url: &str) -> DatabaseResult<Self> {
36 let mut connect_options = PgConnectOptions::from_str(database_url)?;
37
38 let ssl_mode = if database_url.contains("sslmode=require") {
39 PgSslMode::Require
40 } else if database_url.contains("sslmode=disable") {
41 PgSslMode::Disable
42 } else {
43 PgSslMode::Prefer
44 };
45
46 connect_options = connect_options
47 .application_name("systemprompt")
48 .statement_cache_capacity(0)
49 .ssl_mode(ssl_mode)
50 .options([("client_min_messages", "warning")]);
51
52 if let Some(ca_cert_path) = Self::get_cert_path() {
53 connect_options = connect_options.ssl_root_cert(&ca_cert_path);
54 }
55
56 let pool =
57 connection::connect_with_retry(connection::build_pool_options(), connect_options)
58 .await?;
59
60 Ok(Self {
61 pool: Arc::new(pool),
62 })
63 }
64
65 fn get_cert_path() -> Option<std::path::PathBuf> {
66 std::env::var("PGCA_CERT_PATH")
67 .ok()
68 .map(std::path::PathBuf::from)
69 }
70
71 #[must_use]
72 pub fn pool(&self) -> &PgPool {
73 &self.pool
74 }
75}
76
77#[async_trait]
78impl DatabaseProvider for PostgresProvider {
79 fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
80 Some(Arc::clone(&self.pool))
81 }
82
83 async fn execute(
84 &self,
85 query: &dyn QuerySelector,
86 params: &[&dyn ToDbValue],
87 ) -> DatabaseResult<u64> {
88 let sql = query.select_query();
89 let query_obj = sqlx::query(sql);
90 let query_obj = bind_params(query_obj, params);
91
92 let result = query_obj.execute(&*self.pool).await?;
93
94 Ok(result.rows_affected())
95 }
96
97 async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
98 let mut conn = self.pool.acquire().await?;
99
100 conn.execute(sql).await?;
101
102 Ok(())
103 }
104
105 async fn fetch_all(
106 &self,
107 query: &dyn QuerySelector,
108 params: &[&dyn ToDbValue],
109 ) -> DatabaseResult<Vec<JsonRow>> {
110 let sql = query.select_query();
111 let query_obj = sqlx::query(sql);
112 let query_obj = bind_params(query_obj, params);
113
114 let rows = query_obj.fetch_all(&*self.pool).await?;
115
116 Ok(rows.iter().map(row_to_json).collect())
117 }
118
119 async fn fetch_one(
120 &self,
121 query: &dyn QuerySelector,
122 params: &[&dyn ToDbValue],
123 ) -> DatabaseResult<JsonRow> {
124 let sql = query.select_query();
125 let query_obj = sqlx::query(sql);
126 let query_obj = bind_params(query_obj, params);
127
128 let row = query_obj.fetch_one(&*self.pool).await?;
129
130 Ok(row_to_json(&row))
131 }
132
133 async fn fetch_optional(
134 &self,
135 query: &dyn QuerySelector,
136 params: &[&dyn ToDbValue],
137 ) -> DatabaseResult<Option<JsonRow>> {
138 let sql = query.select_query();
139 let query_obj = sqlx::query(sql);
140 let query_obj = bind_params(query_obj, params);
141
142 let row = query_obj.fetch_optional(&*self.pool).await?;
143
144 Ok(row.map(|r| row_to_json(&r)))
145 }
146
147 async fn fetch_scalar_value(
148 &self,
149 query: &dyn QuerySelector,
150 params: &[&dyn ToDbValue],
151 ) -> DatabaseResult<DbValue> {
152 let row = self.fetch_one(query, params).await?;
153
154 let first_value = row
155 .values()
156 .next()
157 .ok_or_else(|| RepositoryError::invalid_state("No columns in result"))?;
158
159 let db_value = match first_value {
160 serde_json::Value::String(s) => DbValue::String(s.clone()),
161 serde_json::Value::Number(n) => n
162 .as_i64()
163 .map(DbValue::Int)
164 .or_else(|| n.as_f64().map(DbValue::Float))
165 .unwrap_or(DbValue::NullFloat),
166 serde_json::Value::Bool(b) => DbValue::Bool(*b),
167 serde_json::Value::Null => DbValue::NullString,
168 serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
169 return Err(RepositoryError::invalid_state("Unsupported value type"));
170 },
171 };
172
173 Ok(db_value)
174 }
175
176 async fn begin_transaction(&self) -> DatabaseResult<Box<dyn DatabaseTransaction>> {
177 let tx = self.pool.begin().await?;
178
179 Ok(Box::new(PostgresTransaction::new(tx)))
180 }
181
182 async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
183 introspection::get_database_info(&self.pool).await
184 }
185
186 async fn test_connection(&self) -> DatabaseResult<()> {
187 sqlx::query("SELECT 1").fetch_one(&*self.pool).await?;
188 Ok(())
189 }
190
191 async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
192 let statements = crate::services::SqlExecutor::parse_sql_statements(sql)?;
193 for statement in statements {
194 sqlx::query(&statement).execute(&*self.pool).await?;
195 }
196 Ok(())
197 }
198
199 async fn query_raw(&self, query: &dyn QuerySelector) -> DatabaseResult<QueryResult> {
200 let sql = query.select_query();
201 let start = std::time::Instant::now();
202
203 let rows = sqlx::query(sql).fetch_all(&*self.pool).await?;
204
205 Ok(rows_to_result(rows, start))
206 }
207
208 async fn query_raw_with(
209 &self,
210 query: &dyn QuerySelector,
211 params: &[&dyn ToDbValue],
212 ) -> DatabaseResult<QueryResult> {
213 let sql = query.select_query();
214 let start = std::time::Instant::now();
215
216 let query_obj = bind_params(sqlx::query(sql), params);
217 let rows = query_obj.fetch_all(&*self.pool).await?;
218
219 Ok(rows_to_result(rows, start))
220 }
221}