Skip to main content

systemprompt_database/services/postgres/
mod.rs

1pub mod conversion;
2mod ext;
3mod introspection;
4pub mod transaction;
5
6use anyhow::{anyhow, Result};
7use async_trait::async_trait;
8use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
9use sqlx::Executor;
10use std::str::FromStr;
11use std::sync::Arc;
12
13use super::provider::DatabaseProvider;
14use crate::models::{
15    DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
16};
17use conversion::{bind_params, row_to_json, rows_to_result};
18use transaction::PostgresTransaction;
19
20#[derive(Debug)]
21pub struct PostgresProvider {
22    pool: Arc<PgPool>,
23}
24
25impl PostgresProvider {
26    pub async fn new(database_url: &str) -> Result<Self> {
27        let mut connect_options = PgConnectOptions::from_str(database_url)
28            .map_err(|e| anyhow!("Failed to parse database URL: {e}"))?;
29
30        let ssl_mode = if database_url.contains("sslmode=require") {
31            PgSslMode::Require
32        } else if database_url.contains("sslmode=disable") {
33            PgSslMode::Disable
34        } else {
35            PgSslMode::Prefer
36        };
37
38        connect_options = connect_options
39            .application_name("systemprompt")
40            .statement_cache_capacity(0)
41            .ssl_mode(ssl_mode);
42
43        if let Some(ca_cert_path) = Self::get_cert_path() {
44            connect_options = connect_options.ssl_root_cert(&ca_cert_path);
45        }
46
47        let pool = sqlx::postgres::PgPoolOptions::new()
48            .max_connections(50)
49            .min_connections(0)
50            .max_lifetime(std::time::Duration::from_secs(1800))
51            .acquire_timeout(std::time::Duration::from_secs(30))
52            .idle_timeout(std::time::Duration::from_secs(300))
53            .connect_with(connect_options)
54            .await
55            .map_err(|e| anyhow!("Failed to connect to PostgreSQL: {e}"))?;
56
57        Ok(Self {
58            pool: Arc::new(pool),
59        })
60    }
61
62    fn get_cert_path() -> Option<std::path::PathBuf> {
63        std::env::var("PGCA_CERT_PATH")
64            .ok()
65            .map(std::path::PathBuf::from)
66    }
67
68    #[must_use]
69    pub fn pool(&self) -> &PgPool {
70        &self.pool
71    }
72}
73
74#[async_trait]
75impl DatabaseProvider for PostgresProvider {
76    fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
77        Some(Arc::clone(&self.pool))
78    }
79
80    async fn execute(&self, query: &dyn QuerySelector, params: &[&dyn ToDbValue]) -> Result<u64> {
81        let sql = query.select_query();
82        let query_obj = sqlx::query(sql);
83        let query_obj = bind_params(query_obj, params);
84
85        let result = query_obj
86            .execute(&*self.pool)
87            .await
88            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
89
90        Ok(result.rows_affected())
91    }
92
93    async fn execute_raw(&self, sql: &str) -> Result<()> {
94        let mut conn = self
95            .pool
96            .acquire()
97            .await
98            .map_err(|e| anyhow!("Failed to acquire connection: {e}"))?;
99
100        conn.execute(sql)
101            .await
102            .map_err(|e| anyhow!("Raw query execution failed: {e}"))?;
103
104        Ok(())
105    }
106
107    async fn fetch_all(
108        &self,
109        query: &dyn QuerySelector,
110        params: &[&dyn ToDbValue],
111    ) -> Result<Vec<JsonRow>> {
112        let sql = query.select_query();
113        let query_obj = sqlx::query(sql);
114        let query_obj = bind_params(query_obj, params);
115
116        let rows = query_obj
117            .fetch_all(&*self.pool)
118            .await
119            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
120
121        Ok(rows.iter().map(row_to_json).collect())
122    }
123
124    async fn fetch_one(
125        &self,
126        query: &dyn QuerySelector,
127        params: &[&dyn ToDbValue],
128    ) -> Result<JsonRow> {
129        let sql = query.select_query();
130        let query_obj = sqlx::query(sql);
131        let query_obj = bind_params(query_obj, params);
132
133        let row = query_obj
134            .fetch_one(&*self.pool)
135            .await
136            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
137
138        Ok(row_to_json(&row))
139    }
140
141    async fn fetch_optional(
142        &self,
143        query: &dyn QuerySelector,
144        params: &[&dyn ToDbValue],
145    ) -> Result<Option<JsonRow>> {
146        let sql = query.select_query();
147        let query_obj = sqlx::query(sql);
148        let query_obj = bind_params(query_obj, params);
149
150        let row = query_obj
151            .fetch_optional(&*self.pool)
152            .await
153            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
154
155        Ok(row.map(|r| row_to_json(&r)))
156    }
157
158    async fn fetch_scalar_value(
159        &self,
160        query: &dyn QuerySelector,
161        params: &[&dyn ToDbValue],
162    ) -> Result<DbValue> {
163        let row = self.fetch_one(query, params).await?;
164
165        let first_value = row
166            .values()
167            .next()
168            .ok_or_else(|| anyhow!("No columns in result"))?;
169
170        let db_value = match first_value {
171            serde_json::Value::String(s) => DbValue::String(s.clone()),
172            serde_json::Value::Number(n) => n
173                .as_i64()
174                .map(DbValue::Int)
175                .or_else(|| n.as_f64().map(DbValue::Float))
176                .unwrap_or(DbValue::NullFloat),
177            serde_json::Value::Bool(b) => DbValue::Bool(*b),
178            serde_json::Value::Null => DbValue::NullString,
179            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
180                return Err(anyhow!("Unsupported value type"))
181            },
182        };
183
184        Ok(db_value)
185    }
186
187    async fn begin_transaction(&self) -> Result<Box<dyn DatabaseTransaction>> {
188        let tx = self
189            .pool
190            .begin()
191            .await
192            .map_err(|e| anyhow!("Failed to begin transaction: {e}"))?;
193
194        Ok(Box::new(PostgresTransaction::new(tx)))
195    }
196
197    async fn get_database_info(&self) -> Result<DatabaseInfo> {
198        introspection::get_database_info(&self.pool).await
199    }
200
201    async fn test_connection(&self) -> Result<()> {
202        sqlx::query("SELECT 1")
203            .fetch_one(&*self.pool)
204            .await
205            .map_err(|e| anyhow!("Connection test failed: {e}"))?;
206        Ok(())
207    }
208
209    async fn execute_batch(&self, sql: &str) -> Result<()> {
210        let statements = crate::services::SqlExecutor::parse_sql_statements(sql);
211        for statement in statements {
212            sqlx::query(&statement)
213                .execute(&*self.pool)
214                .await
215                .map_err(|e| anyhow!("Batch execution failed: {e}"))?;
216        }
217        Ok(())
218    }
219
220    async fn query_raw(&self, query: &dyn QuerySelector) -> Result<QueryResult> {
221        let sql = query.select_query();
222        let start = std::time::Instant::now();
223
224        let rows = sqlx::query(sql)
225            .fetch_all(&*self.pool)
226            .await
227            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
228
229        Ok(rows_to_result(rows, start))
230    }
231
232    async fn query_raw_with(
233        &self,
234        query: &dyn QuerySelector,
235        params: Vec<serde_json::Value>,
236    ) -> Result<QueryResult> {
237        let sql = query.select_query();
238        let start = std::time::Instant::now();
239
240        let mut query_obj = sqlx::query(sql);
241        for param in params {
242            query_obj = match param {
243                serde_json::Value::String(s) => query_obj.bind(s),
244                serde_json::Value::Number(n) => {
245                    if let Some(i) = n.as_i64() {
246                        query_obj.bind(i)
247                    } else if let Some(f) = n.as_f64() {
248                        query_obj.bind(f)
249                    } else {
250                        query_obj.bind(None::<i64>)
251                    }
252                },
253                serde_json::Value::Bool(b) => query_obj.bind(b),
254                serde_json::Value::Null => query_obj.bind(None::<String>),
255                serde_json::Value::Array(arr) => {
256                    let strings: Vec<String> = arr
257                        .into_iter()
258                        .filter_map(|v| v.as_str().map(String::from))
259                        .collect();
260                    query_obj.bind(strings)
261                },
262                serde_json::Value::Object(_) => {
263                    let json_str = serde_json::to_string(&param)
264                        .map_err(|e| anyhow!("Failed to serialize JSON object: {e}"))?;
265                    query_obj.bind(Some(json_str))
266                },
267            };
268        }
269
270        let rows = query_obj
271            .fetch_all(&*self.pool)
272            .await
273            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
274
275        Ok(rows_to_result(rows, start))
276    }
277}