Skip to main content

systemprompt_database/services/postgres/
mod.rs

1pub mod conversion;
2mod ext;
3mod introspection;
4pub mod transaction;
5
6use anyhow::{anyhow, Result};
7use async_trait::async_trait;
8use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
9use sqlx::Executor;
10use std::str::FromStr;
11use std::sync::Arc;
12
13use super::provider::DatabaseProvider;
14use crate::models::{
15    DatabaseInfo, DatabaseTransaction, DbValue, JsonRow, QueryResult, QuerySelector, ToDbValue,
16};
17use conversion::{bind_params, row_to_json, rows_to_result};
18use transaction::PostgresTransaction;
19
20#[derive(Debug)]
21pub struct PostgresProvider {
22    pool: Arc<PgPool>,
23}
24
25impl PostgresProvider {
26    pub async fn new(database_url: &str) -> Result<Self> {
27        let mut connect_options = PgConnectOptions::from_str(database_url)
28            .map_err(|e| anyhow!("Failed to parse database URL: {e}"))?;
29
30        let ssl_mode = if database_url.contains("sslmode=require") {
31            PgSslMode::Require
32        } else if database_url.contains("sslmode=disable") {
33            PgSslMode::Disable
34        } else {
35            PgSslMode::Prefer
36        };
37
38        connect_options = connect_options
39            .application_name("systemprompt")
40            .statement_cache_capacity(0)
41            .ssl_mode(ssl_mode)
42            .options([("client_min_messages", "warning")]);
43
44        if let Some(ca_cert_path) = Self::get_cert_path() {
45            connect_options = connect_options.ssl_root_cert(&ca_cert_path);
46        }
47
48        let pool = sqlx::postgres::PgPoolOptions::new()
49            .max_connections(50)
50            .min_connections(0)
51            .max_lifetime(std::time::Duration::from_secs(1800))
52            .acquire_timeout(std::time::Duration::from_secs(30))
53            .idle_timeout(std::time::Duration::from_secs(300))
54            .connect_with(connect_options)
55            .await
56            .map_err(|e| anyhow!("Failed to connect to PostgreSQL: {e}"))?;
57
58        Ok(Self {
59            pool: Arc::new(pool),
60        })
61    }
62
63    fn get_cert_path() -> Option<std::path::PathBuf> {
64        std::env::var("PGCA_CERT_PATH")
65            .ok()
66            .map(std::path::PathBuf::from)
67    }
68
69    #[must_use]
70    pub fn pool(&self) -> &PgPool {
71        &self.pool
72    }
73}
74
75#[async_trait]
76impl DatabaseProvider for PostgresProvider {
77    fn get_postgres_pool(&self) -> Option<Arc<PgPool>> {
78        Some(Arc::clone(&self.pool))
79    }
80
81    async fn execute(&self, query: &dyn QuerySelector, params: &[&dyn ToDbValue]) -> Result<u64> {
82        let sql = query.select_query();
83        let query_obj = sqlx::query(sql);
84        let query_obj = bind_params(query_obj, params);
85
86        let result = query_obj
87            .execute(&*self.pool)
88            .await
89            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
90
91        Ok(result.rows_affected())
92    }
93
94    async fn execute_raw(&self, sql: &str) -> Result<()> {
95        let mut conn = self
96            .pool
97            .acquire()
98            .await
99            .map_err(|e| anyhow!("Failed to acquire connection: {e}"))?;
100
101        conn.execute(sql)
102            .await
103            .map_err(|e| anyhow!("Raw query execution failed: {e}"))?;
104
105        Ok(())
106    }
107
108    async fn fetch_all(
109        &self,
110        query: &dyn QuerySelector,
111        params: &[&dyn ToDbValue],
112    ) -> Result<Vec<JsonRow>> {
113        let sql = query.select_query();
114        let query_obj = sqlx::query(sql);
115        let query_obj = bind_params(query_obj, params);
116
117        let rows = query_obj
118            .fetch_all(&*self.pool)
119            .await
120            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
121
122        Ok(rows.iter().map(row_to_json).collect())
123    }
124
125    async fn fetch_one(
126        &self,
127        query: &dyn QuerySelector,
128        params: &[&dyn ToDbValue],
129    ) -> Result<JsonRow> {
130        let sql = query.select_query();
131        let query_obj = sqlx::query(sql);
132        let query_obj = bind_params(query_obj, params);
133
134        let row = query_obj
135            .fetch_one(&*self.pool)
136            .await
137            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
138
139        Ok(row_to_json(&row))
140    }
141
142    async fn fetch_optional(
143        &self,
144        query: &dyn QuerySelector,
145        params: &[&dyn ToDbValue],
146    ) -> Result<Option<JsonRow>> {
147        let sql = query.select_query();
148        let query_obj = sqlx::query(sql);
149        let query_obj = bind_params(query_obj, params);
150
151        let row = query_obj
152            .fetch_optional(&*self.pool)
153            .await
154            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
155
156        Ok(row.map(|r| row_to_json(&r)))
157    }
158
159    async fn fetch_scalar_value(
160        &self,
161        query: &dyn QuerySelector,
162        params: &[&dyn ToDbValue],
163    ) -> Result<DbValue> {
164        let row = self.fetch_one(query, params).await?;
165
166        let first_value = row
167            .values()
168            .next()
169            .ok_or_else(|| anyhow!("No columns in result"))?;
170
171        let db_value = match first_value {
172            serde_json::Value::String(s) => DbValue::String(s.clone()),
173            serde_json::Value::Number(n) => n
174                .as_i64()
175                .map(DbValue::Int)
176                .or_else(|| n.as_f64().map(DbValue::Float))
177                .unwrap_or(DbValue::NullFloat),
178            serde_json::Value::Bool(b) => DbValue::Bool(*b),
179            serde_json::Value::Null => DbValue::NullString,
180            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
181                return Err(anyhow!("Unsupported value type"))
182            },
183        };
184
185        Ok(db_value)
186    }
187
188    async fn begin_transaction(&self) -> Result<Box<dyn DatabaseTransaction>> {
189        let tx = self
190            .pool
191            .begin()
192            .await
193            .map_err(|e| anyhow!("Failed to begin transaction: {e}"))?;
194
195        Ok(Box::new(PostgresTransaction::new(tx)))
196    }
197
198    async fn get_database_info(&self) -> Result<DatabaseInfo> {
199        introspection::get_database_info(&self.pool).await
200    }
201
202    async fn test_connection(&self) -> Result<()> {
203        sqlx::query("SELECT 1")
204            .fetch_one(&*self.pool)
205            .await
206            .map_err(|e| anyhow!("Connection test failed: {e}"))?;
207        Ok(())
208    }
209
210    async fn execute_batch(&self, sql: &str) -> Result<()> {
211        let statements = crate::services::SqlExecutor::parse_sql_statements(sql);
212        for statement in statements {
213            sqlx::query(&statement)
214                .execute(&*self.pool)
215                .await
216                .map_err(|e| anyhow!("Batch execution failed: {e}"))?;
217        }
218        Ok(())
219    }
220
221    async fn query_raw(&self, query: &dyn QuerySelector) -> Result<QueryResult> {
222        let sql = query.select_query();
223        let start = std::time::Instant::now();
224
225        let rows = sqlx::query(sql)
226            .fetch_all(&*self.pool)
227            .await
228            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
229
230        Ok(rows_to_result(rows, start))
231    }
232
233    async fn query_raw_with(
234        &self,
235        query: &dyn QuerySelector,
236        params: Vec<serde_json::Value>,
237    ) -> Result<QueryResult> {
238        let sql = query.select_query();
239        let start = std::time::Instant::now();
240
241        let mut query_obj = sqlx::query(sql);
242        for param in params {
243            query_obj = match param {
244                serde_json::Value::String(s) => query_obj.bind(s),
245                serde_json::Value::Number(n) => {
246                    if let Some(i) = n.as_i64() {
247                        query_obj.bind(i)
248                    } else if let Some(f) = n.as_f64() {
249                        query_obj.bind(f)
250                    } else {
251                        query_obj.bind(None::<i64>)
252                    }
253                },
254                serde_json::Value::Bool(b) => query_obj.bind(b),
255                serde_json::Value::Null => query_obj.bind(None::<String>),
256                serde_json::Value::Array(arr) => {
257                    let strings: Vec<String> = arr
258                        .into_iter()
259                        .filter_map(|v| v.as_str().map(String::from))
260                        .collect();
261                    query_obj.bind(strings)
262                },
263                serde_json::Value::Object(_) => {
264                    let json_str = serde_json::to_string(&param)
265                        .map_err(|e| anyhow!("Failed to serialize JSON object: {e}"))?;
266                    query_obj.bind(Some(json_str))
267                },
268            };
269        }
270
271        let rows = query_obj
272            .fetch_all(&*self.pool)
273            .await
274            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
275
276        Ok(rows_to_result(rows, start))
277    }
278}