Skip to main content

systemprompt_database/services/
database.rs

1//! Top-level [`Database`] handle that owns one or two
2//! [`DatabaseProvider`] instances (read + optional write) and exposes the
3//! query and transaction surface.
4
5use super::postgres::PostgresProvider;
6use super::provider::DatabaseProvider;
7use crate::error::{DatabaseResult, RepositoryError};
8use crate::models::{DatabaseInfo, QueryResult};
9use std::sync::Arc;
10
11pub struct Database {
12    provider: Arc<dyn DatabaseProvider>,
13    write_provider: Option<Arc<dyn DatabaseProvider>>,
14}
15
16impl std::fmt::Debug for Database {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("Database")
19            .field("backend", &"PostgreSQL")
20            .finish()
21    }
22}
23
24impl Database {
25    pub async fn new_postgres(url: &str) -> DatabaseResult<Self> {
26        let provider = PostgresProvider::new(url).await?;
27        Ok(Self {
28            provider: Arc::new(provider),
29            write_provider: None,
30        })
31    }
32
33    pub async fn from_config(db_type: &str, url: &str) -> DatabaseResult<Self> {
34        match db_type.to_lowercase().as_str() {
35            "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
36            other => Err(RepositoryError::invalid_argument(format!(
37                "Unsupported database type: {other}. Only PostgreSQL is supported."
38            ))),
39        }
40    }
41
42    pub async fn from_config_with_write(
43        db_type: &str,
44        read_url: &str,
45        write_url: Option<&str>,
46    ) -> DatabaseResult<Self> {
47        let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
48            "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
49            other => {
50                return Err(RepositoryError::invalid_argument(format!(
51                    "Unsupported database type: {other}. Only PostgreSQL is supported."
52                )));
53            },
54        };
55
56        let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
57            Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
58            None => None,
59        };
60
61        Ok(Self {
62            provider,
63            write_provider,
64        })
65    }
66
67    pub fn get_postgres_pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
68        self.pool_arc()
69    }
70
71    pub fn write_pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
72        self.write_provider.as_ref().map_or_else(
73            || self.get_postgres_pool_arc(),
74            |wp| {
75                wp.get_postgres_pool().ok_or_else(|| {
76                    RepositoryError::invalid_state("Write database is not PostgreSQL")
77                })
78            },
79        )
80    }
81
82    #[must_use]
83    pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
84        self.write_provider
85            .as_ref()
86            .and_then(|wp| wp.get_postgres_pool())
87            .or_else(|| self.provider.get_postgres_pool())
88    }
89
90    #[must_use]
91    pub fn has_write_pool(&self) -> bool {
92        self.write_provider.is_some()
93    }
94
95    #[must_use]
96    pub fn write_provider(&self) -> &dyn DatabaseProvider {
97        self.write_provider
98            .as_deref()
99            .unwrap_or_else(|| self.provider.as_ref())
100    }
101
102    pub async fn query(
103        &self,
104        sql: &dyn crate::models::QuerySelector,
105    ) -> DatabaseResult<QueryResult> {
106        self.provider.query_raw(sql).await
107    }
108
109    pub async fn query_with(
110        &self,
111        sql: &dyn crate::models::QuerySelector,
112        params: Vec<serde_json::Value>,
113    ) -> DatabaseResult<QueryResult> {
114        self.provider.query_raw_with(sql, params).await
115    }
116
117    pub async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
118        self.provider.execute_batch(sql).await
119    }
120
121    pub async fn get_info(&self) -> DatabaseResult<DatabaseInfo> {
122        self.provider.get_database_info().await
123    }
124
125    pub async fn test_connection(&self) -> DatabaseResult<()> {
126        self.provider.test_connection().await?;
127        if let Some(wp) = &self.write_provider {
128            wp.test_connection().await?;
129        }
130        Ok(())
131    }
132
133    #[must_use]
134    pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
135        self.write_provider
136            .as_ref()
137            .and_then(|wp| wp.get_postgres_pool())
138            .or_else(|| self.provider.get_postgres_pool())
139    }
140
141    pub fn pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
142        self.get_postgres_pool()
143            .ok_or_else(|| RepositoryError::invalid_state("Database is not PostgreSQL"))
144    }
145
146    #[must_use]
147    pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
148        self.get_postgres_pool()
149    }
150
151    #[must_use]
152    pub fn read_pool(&self) -> Option<Arc<sqlx::PgPool>> {
153        self.provider.get_postgres_pool()
154    }
155
156    pub fn read_pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
157        self.provider
158            .get_postgres_pool()
159            .ok_or_else(|| RepositoryError::invalid_state("Database is not PostgreSQL"))
160    }
161
162    pub async fn begin(&self) -> DatabaseResult<sqlx::Transaction<'_, sqlx::Postgres>> {
163        let pool = self.write_pool_arc()?;
164        pool.begin().await.map_err(Into::into)
165    }
166}
167
168pub type DbPool = Arc<Database>;
169
170pub trait DatabaseExt {
171    fn database(&self) -> Arc<Database>;
172}
173
174impl DatabaseExt for Arc<Database> {
175    fn database(&self) -> Arc<Database> {
176        Self::clone(self)
177    }
178}
179
180#[async_trait::async_trait]
181impl DatabaseProvider for Database {
182    fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
183        self.write_provider
184            .as_ref()
185            .and_then(|wp| wp.get_postgres_pool())
186            .or_else(|| self.provider.get_postgres_pool())
187    }
188
189    async fn execute(
190        &self,
191        query: &dyn crate::models::QuerySelector,
192        params: &[&dyn crate::models::ToDbValue],
193    ) -> DatabaseResult<u64> {
194        self.write_provider().execute(query, params).await
195    }
196
197    async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
198        self.write_provider().execute_raw(sql).await
199    }
200
201    async fn fetch_all(
202        &self,
203        query: &dyn crate::models::QuerySelector,
204        params: &[&dyn crate::models::ToDbValue],
205    ) -> DatabaseResult<Vec<crate::models::JsonRow>> {
206        self.provider.fetch_all(query, params).await
207    }
208
209    async fn fetch_one(
210        &self,
211        query: &dyn crate::models::QuerySelector,
212        params: &[&dyn crate::models::ToDbValue],
213    ) -> DatabaseResult<crate::models::JsonRow> {
214        self.provider.fetch_one(query, params).await
215    }
216
217    async fn fetch_optional(
218        &self,
219        query: &dyn crate::models::QuerySelector,
220        params: &[&dyn crate::models::ToDbValue],
221    ) -> DatabaseResult<Option<crate::models::JsonRow>> {
222        self.provider.fetch_optional(query, params).await
223    }
224
225    async fn fetch_scalar_value(
226        &self,
227        query: &dyn crate::models::QuerySelector,
228        params: &[&dyn crate::models::ToDbValue],
229    ) -> DatabaseResult<crate::models::DbValue> {
230        self.provider.fetch_scalar_value(query, params).await
231    }
232
233    async fn begin_transaction(
234        &self,
235    ) -> DatabaseResult<Box<dyn crate::models::DatabaseTransaction>> {
236        self.write_provider().begin_transaction().await
237    }
238
239    async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
240        self.provider.get_database_info().await
241    }
242
243    async fn test_connection(&self) -> DatabaseResult<()> {
244        self.provider.test_connection().await
245    }
246
247    async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
248        self.write_provider().execute_batch(sql).await
249    }
250
251    async fn query_raw(
252        &self,
253        query: &dyn crate::models::QuerySelector,
254    ) -> DatabaseResult<QueryResult> {
255        self.provider.query_raw(query).await
256    }
257
258    async fn query_raw_with(
259        &self,
260        query: &dyn crate::models::QuerySelector,
261        params: Vec<serde_json::Value>,
262    ) -> DatabaseResult<QueryResult> {
263        self.provider.query_raw_with(query, params).await
264    }
265}