Skip to main content

systemprompt_database/services/postgres/
mod.rs

//! `PostgreSQL` implementation of [`crate::services::DatabaseProvider`].
//!
//! This module is part of the documented sqlx allowlist: every `sqlx::query(_)`
//! call here executes SQL text that is only known at runtime — the string returned
//! by a [`crate::models::QuerySelector`] (extension-defined SQL, dynamic admin
//! queries) or a statement split out of a raw batch in `execute_batch` — or runs
//! `SELECT 1` for connection probing. `execute_raw` hands its SQL text directly to
//! `sqlx::Executor::execute`. Static SQL goes through the verified macros elsewhere.
8
9pub mod connection;
10pub mod conversion;
11mod ext;
12mod introspection;
13pub mod transaction;
14
15use async_trait::async_trait;
16use sqlx::Executor;
17use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
18use std::str::FromStr;
19use std::sync::Arc;
20
21use super::provider::DatabaseProvider;
22use crate::error::{DatabaseResult, RepositoryError};
23use crate::models::{
24    DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
25};
26use conversion::{bind_params, row_to_json, rows_to_result};
27use transaction::PostgresTransaction;
28
29#[derive(Debug)]
30pub struct PostgresProvider {
31    pool: Arc<PgPool>,
32}
33
34impl PostgresProvider {
35    pub async fn new(database_url: &str) -> DatabaseResult<Self> {
36        let mut connect_options = PgConnectOptions::from_str(database_url)?;
37
38        let ssl_mode = if database_url.contains("sslmode=require") {
39            PgSslMode::Require
40        } else if database_url.contains("sslmode=disable") {
41            PgSslMode::Disable
42        } else {
43            PgSslMode::Prefer
44        };
45
46        connect_options = connect_options
47            .application_name("systemprompt")
48            .statement_cache_capacity(0)
49            .ssl_mode(ssl_mode)
50            .options([("client_min_messages", "warning")]);
51
52        if let Some(ca_cert_path) = Self::get_cert_path() {
53            connect_options = connect_options.ssl_root_cert(&ca_cert_path);
54        }
55
56        let pool =
57            connection::connect_with_retry(connection::build_pool_options(), connect_options)
58                .await?;
59
60        Ok(Self {
61            pool: Arc::new(pool),
62        })
63    }
64
65    #[must_use]
66    pub const fn from_pool(pool: Arc<PgPool>) -> Self {
67        Self { pool }
68    }
69
70    fn get_cert_path() -> Option<std::path::PathBuf> {
71        std::env::var("PGCA_CERT_PATH")
72            .ok()
73            .map(std::path::PathBuf::from)
74    }
75
76    #[must_use]
77    pub fn pool(&self) -> &PgPool {
78        &self.pool
79    }
80}
81
82#[async_trait]
83impl DatabaseProvider for PostgresProvider {
84    fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
85        Some(Arc::clone(&self.pool))
86    }
87
88    async fn execute(
89        &self,
90        query: &dyn QuerySelector,
91        params: &[&dyn ToDbValue],
92    ) -> DatabaseResult<u64> {
93        let sql = query.select_query();
94        let query_obj = sqlx::query(sql);
95        let query_obj = bind_params(query_obj, params);
96
97        let result = query_obj.execute(&*self.pool).await?;
98
99        Ok(result.rows_affected())
100    }
101
102    async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
103        let mut conn = self.pool.acquire().await?;
104
105        conn.execute(sql).await?;
106
107        Ok(())
108    }
109
110    async fn fetch_all(
111        &self,
112        query: &dyn QuerySelector,
113        params: &[&dyn ToDbValue],
114    ) -> DatabaseResult<Vec<JsonRow>> {
115        let sql = query.select_query();
116        let query_obj = sqlx::query(sql);
117        let query_obj = bind_params(query_obj, params);
118
119        let rows = query_obj.fetch_all(&*self.pool).await?;
120
121        Ok(rows.iter().map(row_to_json).collect())
122    }
123
124    async fn fetch_one(
125        &self,
126        query: &dyn QuerySelector,
127        params: &[&dyn ToDbValue],
128    ) -> DatabaseResult<JsonRow> {
129        let sql = query.select_query();
130        let query_obj = sqlx::query(sql);
131        let query_obj = bind_params(query_obj, params);
132
133        let row = query_obj.fetch_one(&*self.pool).await?;
134
135        Ok(row_to_json(&row))
136    }
137
138    async fn fetch_optional(
139        &self,
140        query: &dyn QuerySelector,
141        params: &[&dyn ToDbValue],
142    ) -> DatabaseResult<Option<JsonRow>> {
143        let sql = query.select_query();
144        let query_obj = sqlx::query(sql);
145        let query_obj = bind_params(query_obj, params);
146
147        let row = query_obj.fetch_optional(&*self.pool).await?;
148
149        Ok(row.map(|r| row_to_json(&r)))
150    }
151
152    async fn fetch_scalar_value(
153        &self,
154        query: &dyn QuerySelector,
155        params: &[&dyn ToDbValue],
156    ) -> DatabaseResult<DbValue> {
157        let row = self.fetch_one(query, params).await?;
158
159        let first_value = row
160            .values()
161            .next()
162            .ok_or_else(|| RepositoryError::invalid_state("No columns in result"))?;
163
164        let db_value = match first_value {
165            serde_json::Value::String(s) => DbValue::String(s.clone()),
166            serde_json::Value::Number(n) => n
167                .as_i64()
168                .map(DbValue::Int)
169                .or_else(|| n.as_f64().map(DbValue::Float))
170                .unwrap_or(DbValue::NullFloat),
171            serde_json::Value::Bool(b) => DbValue::Bool(*b),
172            serde_json::Value::Null => DbValue::NullString,
173            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
174                return Err(RepositoryError::invalid_state("Unsupported value type"));
175            },
176        };
177
178        Ok(db_value)
179    }
180
181    async fn begin_transaction(&self) -> DatabaseResult<Box<dyn DatabaseTransaction>> {
182        let tx = self.pool.begin().await?;
183
184        Ok(Box::new(PostgresTransaction::new(tx)))
185    }
186
187    async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
188        introspection::get_database_info(&self.pool).await
189    }
190
191    async fn test_connection(&self) -> DatabaseResult<()> {
192        sqlx::query("SELECT 1").fetch_one(&*self.pool).await?;
193        Ok(())
194    }
195
196    async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
197        let statements = crate::services::SqlExecutor::parse_sql_statements(sql)?;
198        for statement in statements {
199            sqlx::query(&statement).execute(&*self.pool).await?;
200        }
201        Ok(())
202    }
203
204    async fn query_raw(&self, query: &dyn QuerySelector) -> DatabaseResult<QueryResult> {
205        let sql = query.select_query();
206        let start = std::time::Instant::now();
207
208        let rows = sqlx::query(sql).fetch_all(&*self.pool).await?;
209
210        Ok(rows_to_result(rows, start))
211    }
212
213    async fn query_raw_with(
214        &self,
215        query: &dyn QuerySelector,
216        params: &[&dyn ToDbValue],
217    ) -> DatabaseResult<QueryResult> {
218        let sql = query.select_query();
219        let start = std::time::Instant::now();
220
221        let query_obj = bind_params(sqlx::query(sql), params);
222        let rows = query_obj.fetch_all(&*self.pool).await?;
223
224        Ok(rows_to_result(rows, start))
225    }
226}