Skip to main content

systemprompt_database/services/postgres/
mod.rs

1//! `PostgreSQL` implementation of [`crate::services::DatabaseProvider`].
2//!
3//! This module is part of the documented sqlx allowlist: every `sqlx::query(_)`
4//! call here either binds a [`crate::models::QuerySelector`] string supplied
5//! at runtime (extension-defined SQL, dynamic admin queries) or executes
6//! `SELECT 1` for connection probing. Static SQL goes through the verified
7//! macros elsewhere.
8
9pub mod conversion;
10mod ext;
11mod introspection;
12pub mod transaction;
13
14use async_trait::async_trait;
15use sqlx::Executor;
16use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
17use std::str::FromStr;
18use std::sync::Arc;
19
20use super::provider::DatabaseProvider;
21use crate::error::{DatabaseResult, RepositoryError};
22use crate::models::{
23    DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
24};
25use conversion::{bind_params, row_to_json, rows_to_result};
26use transaction::PostgresTransaction;
27
28#[derive(Debug)]
29pub struct PostgresProvider {
30    pool: Arc<PgPool>,
31}
32
33impl PostgresProvider {
34    pub async fn new(database_url: &str) -> DatabaseResult<Self> {
35        let mut connect_options = PgConnectOptions::from_str(database_url)?;
36
37        let ssl_mode = if database_url.contains("sslmode=require") {
38            PgSslMode::Require
39        } else if database_url.contains("sslmode=disable") {
40            PgSslMode::Disable
41        } else {
42            PgSslMode::Prefer
43        };
44
45        connect_options = connect_options
46            .application_name("systemprompt")
47            .statement_cache_capacity(0)
48            .ssl_mode(ssl_mode)
49            .options([("client_min_messages", "warning")]);
50
51        if let Some(ca_cert_path) = Self::get_cert_path() {
52            connect_options = connect_options.ssl_root_cert(&ca_cert_path);
53        }
54
55        let pool = sqlx::postgres::PgPoolOptions::new()
56            .max_connections(50)
57            .min_connections(0)
58            .max_lifetime(std::time::Duration::from_secs(1800))
59            .acquire_timeout(std::time::Duration::from_secs(30))
60            .idle_timeout(std::time::Duration::from_secs(300))
61            .connect_with(connect_options)
62            .await?;
63
64        Ok(Self {
65            pool: Arc::new(pool),
66        })
67    }
68
69    fn get_cert_path() -> Option<std::path::PathBuf> {
70        std::env::var("PGCA_CERT_PATH")
71            .ok()
72            .map(std::path::PathBuf::from)
73    }
74
75    #[must_use]
76    pub fn pool(&self) -> &PgPool {
77        &self.pool
78    }
79}
80
81#[async_trait]
82impl DatabaseProvider for PostgresProvider {
83    fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
84        Some(Arc::clone(&self.pool))
85    }
86
87    async fn execute(
88        &self,
89        query: &dyn QuerySelector,
90        params: &[&dyn ToDbValue],
91    ) -> DatabaseResult<u64> {
92        let sql = query.select_query();
93        let query_obj = sqlx::query(sql);
94        let query_obj = bind_params(query_obj, params);
95
96        let result = query_obj.execute(&*self.pool).await?;
97
98        Ok(result.rows_affected())
99    }
100
101    async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
102        let mut conn = self.pool.acquire().await?;
103
104        conn.execute(sql).await?;
105
106        Ok(())
107    }
108
109    async fn fetch_all(
110        &self,
111        query: &dyn QuerySelector,
112        params: &[&dyn ToDbValue],
113    ) -> DatabaseResult<Vec<JsonRow>> {
114        let sql = query.select_query();
115        let query_obj = sqlx::query(sql);
116        let query_obj = bind_params(query_obj, params);
117
118        let rows = query_obj.fetch_all(&*self.pool).await?;
119
120        Ok(rows.iter().map(row_to_json).collect())
121    }
122
123    async fn fetch_one(
124        &self,
125        query: &dyn QuerySelector,
126        params: &[&dyn ToDbValue],
127    ) -> DatabaseResult<JsonRow> {
128        let sql = query.select_query();
129        let query_obj = sqlx::query(sql);
130        let query_obj = bind_params(query_obj, params);
131
132        let row = query_obj.fetch_one(&*self.pool).await?;
133
134        Ok(row_to_json(&row))
135    }
136
137    async fn fetch_optional(
138        &self,
139        query: &dyn QuerySelector,
140        params: &[&dyn ToDbValue],
141    ) -> DatabaseResult<Option<JsonRow>> {
142        let sql = query.select_query();
143        let query_obj = sqlx::query(sql);
144        let query_obj = bind_params(query_obj, params);
145
146        let row = query_obj.fetch_optional(&*self.pool).await?;
147
148        Ok(row.map(|r| row_to_json(&r)))
149    }
150
151    async fn fetch_scalar_value(
152        &self,
153        query: &dyn QuerySelector,
154        params: &[&dyn ToDbValue],
155    ) -> DatabaseResult<DbValue> {
156        let row = self.fetch_one(query, params).await?;
157
158        let first_value = row
159            .values()
160            .next()
161            .ok_or_else(|| RepositoryError::invalid_state("No columns in result"))?;
162
163        let db_value = match first_value {
164            serde_json::Value::String(s) => DbValue::String(s.clone()),
165            serde_json::Value::Number(n) => n
166                .as_i64()
167                .map(DbValue::Int)
168                .or_else(|| n.as_f64().map(DbValue::Float))
169                .unwrap_or(DbValue::NullFloat),
170            serde_json::Value::Bool(b) => DbValue::Bool(*b),
171            serde_json::Value::Null => DbValue::NullString,
172            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
173                return Err(RepositoryError::invalid_state("Unsupported value type"));
174            },
175        };
176
177        Ok(db_value)
178    }
179
180    async fn begin_transaction(&self) -> DatabaseResult<Box<dyn DatabaseTransaction>> {
181        let tx = self.pool.begin().await?;
182
183        Ok(Box::new(PostgresTransaction::new(tx)))
184    }
185
186    async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
187        introspection::get_database_info(&self.pool).await
188    }
189
190    async fn test_connection(&self) -> DatabaseResult<()> {
191        sqlx::query("SELECT 1").fetch_one(&*self.pool).await?;
192        Ok(())
193    }
194
195    async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
196        let statements = crate::services::SqlExecutor::parse_sql_statements(sql);
197        for statement in statements {
198            sqlx::query(&statement).execute(&*self.pool).await?;
199        }
200        Ok(())
201    }
202
203    async fn query_raw(&self, query: &dyn QuerySelector) -> DatabaseResult<QueryResult> {
204        let sql = query.select_query();
205        let start = std::time::Instant::now();
206
207        let rows = sqlx::query(sql).fetch_all(&*self.pool).await?;
208
209        Ok(rows_to_result(rows, start))
210    }
211
212    async fn query_raw_with(
213        &self,
214        query: &dyn QuerySelector,
215        params: Vec<serde_json::Value>,
216    ) -> DatabaseResult<QueryResult> {
217        let sql = query.select_query();
218        let start = std::time::Instant::now();
219
220        let mut query_obj = sqlx::query(sql);
221        for param in params {
222            query_obj = match param {
223                serde_json::Value::String(s) => query_obj.bind(s),
224                serde_json::Value::Number(n) => {
225                    if let Some(i) = n.as_i64() {
226                        query_obj.bind(i)
227                    } else if let Some(f) = n.as_f64() {
228                        query_obj.bind(f)
229                    } else {
230                        query_obj.bind(None::<i64>)
231                    }
232                },
233                serde_json::Value::Bool(b) => query_obj.bind(b),
234                serde_json::Value::Null => query_obj.bind(None::<String>),
235                serde_json::Value::Array(arr) => {
236                    let strings: Vec<String> = arr
237                        .into_iter()
238                        .filter_map(|v| v.as_str().map(String::from))
239                        .collect();
240                    query_obj.bind(strings)
241                },
242                serde_json::Value::Object(_) => {
243                    let json_str = serde_json::to_string(&param)?;
244                    query_obj.bind(Some(json_str))
245                },
246            };
247        }
248
249        let rows = query_obj.fetch_all(&*self.pool).await?;
250
251        Ok(rows_to_result(rows, start))
252    }
253}