Skip to main content

systemprompt_database/services/postgres/
mod.rs

//! `PostgreSQL` implementation of [`crate::services::DatabaseProvider`].
//!
//! This module is part of the documented sqlx allowlist: every `sqlx::query(_)`
//! call here either binds a [`crate::models::QuerySelector`] string supplied
//! at runtime (extension-defined SQL, dynamic admin queries), runs statements
//! split from a caller-supplied SQL batch (`execute_batch`), or executes
//! `SELECT 1` for connection probing. Static SQL goes through the verified
//! macros elsewhere.
8
9pub mod connection;
10pub mod conversion;
11mod ext;
12mod introspection;
13pub mod transaction;
14
15use async_trait::async_trait;
16use sqlx::Executor;
17use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
18use std::str::FromStr;
19use std::sync::Arc;
20
21use super::provider::DatabaseProvider;
22use crate::error::{DatabaseResult, RepositoryError};
23use crate::models::{
24    DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
25};
26use conversion::{bind_params, row_to_json, rows_to_result};
27use transaction::PostgresTransaction;
28
29#[derive(Debug)]
30pub struct PostgresProvider {
31    pool: Arc<PgPool>,
32}
33
34impl PostgresProvider {
35    pub async fn new(database_url: &str) -> DatabaseResult<Self> {
36        let mut connect_options = PgConnectOptions::from_str(database_url)?;
37
38        let ssl_mode = if database_url.contains("sslmode=require") {
39            PgSslMode::Require
40        } else if database_url.contains("sslmode=disable") {
41            PgSslMode::Disable
42        } else {
43            PgSslMode::Prefer
44        };
45
46        connect_options = connect_options
47            .application_name("systemprompt")
48            .statement_cache_capacity(0)
49            .ssl_mode(ssl_mode)
50            .options([("client_min_messages", "warning")]);
51
52        if let Some(ca_cert_path) = Self::get_cert_path() {
53            connect_options = connect_options.ssl_root_cert(&ca_cert_path);
54        }
55
56        let pool =
57            connection::connect_with_retry(connection::build_pool_options(), connect_options)
58                .await?;
59
60        Ok(Self {
61            pool: Arc::new(pool),
62        })
63    }
64
65    fn get_cert_path() -> Option<std::path::PathBuf> {
66        std::env::var("PGCA_CERT_PATH")
67            .ok()
68            .map(std::path::PathBuf::from)
69    }
70
71    #[must_use]
72    pub fn pool(&self) -> &PgPool {
73        &self.pool
74    }
75}
76
77#[async_trait]
78impl DatabaseProvider for PostgresProvider {
79    fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
80        Some(Arc::clone(&self.pool))
81    }
82
83    async fn execute(
84        &self,
85        query: &dyn QuerySelector,
86        params: &[&dyn ToDbValue],
87    ) -> DatabaseResult<u64> {
88        let sql = query.select_query();
89        let query_obj = sqlx::query(sql);
90        let query_obj = bind_params(query_obj, params);
91
92        let result = query_obj.execute(&*self.pool).await?;
93
94        Ok(result.rows_affected())
95    }
96
97    async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
98        let mut conn = self.pool.acquire().await?;
99
100        conn.execute(sql).await?;
101
102        Ok(())
103    }
104
105    async fn fetch_all(
106        &self,
107        query: &dyn QuerySelector,
108        params: &[&dyn ToDbValue],
109    ) -> DatabaseResult<Vec<JsonRow>> {
110        let sql = query.select_query();
111        let query_obj = sqlx::query(sql);
112        let query_obj = bind_params(query_obj, params);
113
114        let rows = query_obj.fetch_all(&*self.pool).await?;
115
116        Ok(rows.iter().map(row_to_json).collect())
117    }
118
119    async fn fetch_one(
120        &self,
121        query: &dyn QuerySelector,
122        params: &[&dyn ToDbValue],
123    ) -> DatabaseResult<JsonRow> {
124        let sql = query.select_query();
125        let query_obj = sqlx::query(sql);
126        let query_obj = bind_params(query_obj, params);
127
128        let row = query_obj.fetch_one(&*self.pool).await?;
129
130        Ok(row_to_json(&row))
131    }
132
133    async fn fetch_optional(
134        &self,
135        query: &dyn QuerySelector,
136        params: &[&dyn ToDbValue],
137    ) -> DatabaseResult<Option<JsonRow>> {
138        let sql = query.select_query();
139        let query_obj = sqlx::query(sql);
140        let query_obj = bind_params(query_obj, params);
141
142        let row = query_obj.fetch_optional(&*self.pool).await?;
143
144        Ok(row.map(|r| row_to_json(&r)))
145    }
146
147    async fn fetch_scalar_value(
148        &self,
149        query: &dyn QuerySelector,
150        params: &[&dyn ToDbValue],
151    ) -> DatabaseResult<DbValue> {
152        let row = self.fetch_one(query, params).await?;
153
154        let first_value = row
155            .values()
156            .next()
157            .ok_or_else(|| RepositoryError::invalid_state("No columns in result"))?;
158
159        let db_value = match first_value {
160            serde_json::Value::String(s) => DbValue::String(s.clone()),
161            serde_json::Value::Number(n) => n
162                .as_i64()
163                .map(DbValue::Int)
164                .or_else(|| n.as_f64().map(DbValue::Float))
165                .unwrap_or(DbValue::NullFloat),
166            serde_json::Value::Bool(b) => DbValue::Bool(*b),
167            serde_json::Value::Null => DbValue::NullString,
168            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
169                return Err(RepositoryError::invalid_state("Unsupported value type"));
170            },
171        };
172
173        Ok(db_value)
174    }
175
176    async fn begin_transaction(&self) -> DatabaseResult<Box<dyn DatabaseTransaction>> {
177        let tx = self.pool.begin().await?;
178
179        Ok(Box::new(PostgresTransaction::new(tx)))
180    }
181
182    async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
183        introspection::get_database_info(&self.pool).await
184    }
185
186    async fn test_connection(&self) -> DatabaseResult<()> {
187        sqlx::query("SELECT 1").fetch_one(&*self.pool).await?;
188        Ok(())
189    }
190
191    async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
192        let statements = crate::services::SqlExecutor::parse_sql_statements(sql)?;
193        for statement in statements {
194            sqlx::query(&statement).execute(&*self.pool).await?;
195        }
196        Ok(())
197    }
198
199    async fn query_raw(&self, query: &dyn QuerySelector) -> DatabaseResult<QueryResult> {
200        let sql = query.select_query();
201        let start = std::time::Instant::now();
202
203        let rows = sqlx::query(sql).fetch_all(&*self.pool).await?;
204
205        Ok(rows_to_result(rows, start))
206    }
207
208    async fn query_raw_with(
209        &self,
210        query: &dyn QuerySelector,
211        params: &[&dyn ToDbValue],
212    ) -> DatabaseResult<QueryResult> {
213        let sql = query.select_query();
214        let start = std::time::Instant::now();
215
216        let query_obj = bind_params(sqlx::query(sql), params);
217        let rows = query_obj.fetch_all(&*self.pool).await?;
218
219        Ok(rows_to_result(rows, start))
220    }
221}