Skip to main content

systemprompt_database/services/
database.rs

1use super::postgres::PostgresProvider;
2use super::provider::DatabaseProvider;
3use crate::models::{DatabaseInfo, QueryResult};
4use anyhow::Result;
5use std::sync::Arc;
6
7pub struct Database {
8    provider: Arc<dyn DatabaseProvider>,
9    write_provider: Option<Arc<dyn DatabaseProvider>>,
10}
11
12impl std::fmt::Debug for Database {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        f.debug_struct("Database")
15            .field("backend", &"PostgreSQL")
16            .finish()
17    }
18}
19
20impl Database {
21    pub async fn new_postgres(url: &str) -> Result<Self> {
22        let provider = PostgresProvider::new(url).await?;
23        Ok(Self {
24            provider: Arc::new(provider),
25            write_provider: None,
26        })
27    }
28
29    pub async fn from_config(db_type: &str, url: &str) -> Result<Self> {
30        match db_type.to_lowercase().as_str() {
31            "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
32            other => Err(anyhow::anyhow!(
33                "Unsupported database type: {other}. Only PostgreSQL is supported."
34            )),
35        }
36    }
37
38    pub async fn from_config_with_write(
39        db_type: &str,
40        read_url: &str,
41        write_url: Option<&str>,
42    ) -> Result<Self> {
43        let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
44            "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
45            other => {
46                return Err(anyhow::anyhow!(
47                    "Unsupported database type: {other}. Only PostgreSQL is supported."
48                ));
49            },
50        };
51
52        let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
53            Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
54            None => None,
55        };
56
57        Ok(Self {
58            provider,
59            write_provider,
60        })
61    }
62
63    pub fn get_postgres_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
64        self.pool_arc()
65    }
66
67    pub fn write_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
68        self.write_provider.as_ref().map_or_else(
69            || self.get_postgres_pool_arc(),
70            |wp| {
71                wp.get_postgres_pool()
72                    .ok_or_else(|| anyhow::anyhow!("Write database is not PostgreSQL"))
73            },
74        )
75    }
76
77    #[must_use]
78    pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
79        self.write_provider
80            .as_ref()
81            .and_then(|wp| wp.get_postgres_pool())
82            .or_else(|| self.provider.get_postgres_pool())
83    }
84
85    #[must_use]
86    pub fn has_write_pool(&self) -> bool {
87        self.write_provider.is_some()
88    }
89
90    #[must_use]
91    pub fn write_provider(&self) -> &dyn DatabaseProvider {
92        self.write_provider
93            .as_deref()
94            .unwrap_or_else(|| self.provider.as_ref())
95    }
96
97    pub async fn query(&self, sql: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
98        self.provider.query_raw(sql).await
99    }
100
101    // JSON: dynamic query params — type-erased for heterogeneous admin queries
102    pub async fn query_with(
103        &self,
104        sql: &dyn crate::models::QuerySelector,
105        params: Vec<serde_json::Value>,
106    ) -> Result<QueryResult> {
107        self.provider.query_raw_with(sql, params).await
108    }
109
110    pub async fn execute_batch(&self, sql: &str) -> Result<()> {
111        self.provider.execute_batch(sql).await
112    }
113
114    pub async fn get_info(&self) -> Result<DatabaseInfo> {
115        self.provider.get_database_info().await
116    }
117
118    pub async fn test_connection(&self) -> Result<()> {
119        self.provider.test_connection().await?;
120        if let Some(wp) = &self.write_provider {
121            wp.test_connection().await?;
122        }
123        Ok(())
124    }
125
126    #[must_use]
127    pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
128        self.write_provider
129            .as_ref()
130            .and_then(|wp| wp.get_postgres_pool())
131            .or_else(|| self.provider.get_postgres_pool())
132    }
133
134    pub fn pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
135        self.get_postgres_pool()
136            .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
137    }
138
139    #[must_use]
140    pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
141        self.get_postgres_pool()
142    }
143
144    #[must_use]
145    pub fn read_pool(&self) -> Option<Arc<sqlx::PgPool>> {
146        self.provider.get_postgres_pool()
147    }
148
149    pub fn read_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
150        self.provider
151            .get_postgres_pool()
152            .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
153    }
154
155    pub async fn begin(&self) -> Result<sqlx::Transaction<'_, sqlx::Postgres>> {
156        let pool = self.write_pool_arc()?;
157        pool.begin().await.map_err(Into::into)
158    }
159}
160
161pub type DbPool = Arc<Database>;
162
163pub trait DatabaseExt {
164    fn database(&self) -> Arc<Database>;
165}
166
167impl DatabaseExt for Arc<Database> {
168    fn database(&self) -> Arc<Database> {
169        Self::clone(self)
170    }
171}
172
173#[async_trait::async_trait]
174impl DatabaseProvider for Database {
175    fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
176        self.write_provider
177            .as_ref()
178            .and_then(|wp| wp.get_postgres_pool())
179            .or_else(|| self.provider.get_postgres_pool())
180    }
181
182    async fn execute(
183        &self,
184        query: &dyn crate::models::QuerySelector,
185        params: &[&dyn crate::models::ToDbValue],
186    ) -> Result<u64> {
187        self.write_provider().execute(query, params).await
188    }
189
190    async fn execute_raw(&self, sql: &str) -> Result<()> {
191        self.write_provider().execute_raw(sql).await
192    }
193
194    async fn fetch_all(
195        &self,
196        query: &dyn crate::models::QuerySelector,
197        params: &[&dyn crate::models::ToDbValue],
198    ) -> Result<Vec<crate::models::JsonRow>> {
199        self.provider.fetch_all(query, params).await
200    }
201
202    async fn fetch_one(
203        &self,
204        query: &dyn crate::models::QuerySelector,
205        params: &[&dyn crate::models::ToDbValue],
206    ) -> Result<crate::models::JsonRow> {
207        self.provider.fetch_one(query, params).await
208    }
209
210    async fn fetch_optional(
211        &self,
212        query: &dyn crate::models::QuerySelector,
213        params: &[&dyn crate::models::ToDbValue],
214    ) -> Result<Option<crate::models::JsonRow>> {
215        self.provider.fetch_optional(query, params).await
216    }
217
218    async fn fetch_scalar_value(
219        &self,
220        query: &dyn crate::models::QuerySelector,
221        params: &[&dyn crate::models::ToDbValue],
222    ) -> Result<crate::models::DbValue> {
223        self.provider.fetch_scalar_value(query, params).await
224    }
225
226    async fn begin_transaction(&self) -> Result<Box<dyn crate::models::DatabaseTransaction>> {
227        self.write_provider().begin_transaction().await
228    }
229
230    async fn get_database_info(&self) -> Result<DatabaseInfo> {
231        self.provider.get_database_info().await
232    }
233
234    async fn test_connection(&self) -> Result<()> {
235        self.provider.test_connection().await
236    }
237
238    async fn execute_batch(&self, sql: &str) -> Result<()> {
239        self.write_provider().execute_batch(sql).await
240    }
241
242    async fn query_raw(&self, query: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
243        self.provider.query_raw(query).await
244    }
245
246    async fn query_raw_with(
247        &self,
248        query: &dyn crate::models::QuerySelector,
249        params: Vec<serde_json::Value>,
250    ) -> Result<QueryResult> {
251        self.provider.query_raw_with(query, params).await
252    }
253}