Skip to main content

systemprompt_database/services/
database.rs

1//! Top-level [`Database`] handle that owns one or two
2//! [`DatabaseProvider`] instances (read + optional write) and exposes the
3//! query and transaction surface.
4
5use super::postgres::PostgresProvider;
6use super::provider::DatabaseProvider;
7use crate::error::{DatabaseResult, RepositoryError};
8use crate::models::{DatabaseInfo, QueryResult};
9use std::sync::Arc;
10
11pub struct Database {
12    provider: Arc<dyn DatabaseProvider>,
13    write_provider: Option<Arc<dyn DatabaseProvider>>,
14}
15
16impl std::fmt::Debug for Database {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("Database")
19            .field("backend", &"PostgreSQL")
20            .finish()
21    }
22}
23
24impl Database {
25    pub async fn new_postgres(url: &str) -> DatabaseResult<Self> {
26        let provider = PostgresProvider::new(url).await?;
27        Ok(Self {
28            provider: Arc::new(provider),
29            write_provider: None,
30        })
31    }
32
33    pub async fn from_config(db_type: &str, url: &str) -> DatabaseResult<Self> {
34        match db_type.to_lowercase().as_str() {
35            "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
36            other => Err(RepositoryError::invalid_argument(format!(
37                "Unsupported database type: {other}. Only PostgreSQL is supported."
38            ))),
39        }
40    }
41
42    pub async fn from_config_with_write(
43        db_type: &str,
44        read_url: &str,
45        write_url: Option<&str>,
46    ) -> DatabaseResult<Self> {
47        let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
48            "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
49            other => {
50                return Err(RepositoryError::invalid_argument(format!(
51                    "Unsupported database type: {other}. Only PostgreSQL is supported."
52                )));
53            },
54        };
55
56        let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
57            Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
58            None => None,
59        };
60
61        Ok(Self {
62            provider,
63            write_provider,
64        })
65    }
66
67    /// Builds a handle from pools the caller already holds, reusing the open
68    /// connections rather than dialing the database again. The intended caller
69    /// is an extension HTTP router that is handed an `Arc<PgPool>` and needs to
70    /// construct core data services (which require a `Database`) without a URL.
71    #[must_use]
72    pub fn from_pools(read: Arc<sqlx::PgPool>, write: Option<Arc<sqlx::PgPool>>) -> Self {
73        let write_provider = write.map(|pool| -> Arc<dyn DatabaseProvider> {
74            Arc::new(PostgresProvider::from_pool(pool))
75        });
76        Self {
77            provider: Arc::new(PostgresProvider::from_pool(read)),
78            write_provider,
79        }
80    }
81
82    fn require_postgres(pool: Option<Arc<sqlx::PgPool>>) -> DatabaseResult<Arc<sqlx::PgPool>> {
83        pool.ok_or_else(|| RepositoryError::invalid_state("Database is not PostgreSQL"))
84    }
85
86    /// Provider that serves reads. Equal to [`Self::write`] when no separate
87    /// write URL is configured (single-node deployments).
88    #[must_use]
89    pub fn read(&self) -> &dyn DatabaseProvider {
90        self.provider.as_ref()
91    }
92
93    /// Provider that serves writes and transactions. Falls back to the read
94    /// provider when no separate write URL is configured.
95    #[must_use]
96    pub fn write(&self) -> &dyn DatabaseProvider {
97        self.write_provider
98            .as_deref()
99            .unwrap_or_else(|| self.provider.as_ref())
100    }
101
102    #[must_use]
103    pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
104        self.read().get_postgres_pool()
105    }
106
107    pub fn pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
108        Self::require_postgres(self.read().get_postgres_pool())
109    }
110
111    #[must_use]
112    pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
113        self.write().get_postgres_pool()
114    }
115
116    pub fn write_pool_arc(&self) -> DatabaseResult<Arc<sqlx::PgPool>> {
117        Self::require_postgres(self.write().get_postgres_pool())
118    }
119
120    #[must_use]
121    pub fn has_write_pool(&self) -> bool {
122        self.write_provider.is_some()
123    }
124
125    pub async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
126        self.write().execute_batch(sql).await
127    }
128
129    pub async fn get_info(&self) -> DatabaseResult<DatabaseInfo> {
130        self.read().get_database_info().await
131    }
132
133    pub async fn test_connection(&self) -> DatabaseResult<()> {
134        self.provider.test_connection().await?;
135        if let Some(wp) = &self.write_provider {
136            wp.test_connection().await?;
137        }
138        Ok(())
139    }
140
141    pub async fn begin(&self) -> DatabaseResult<sqlx::Transaction<'_, sqlx::Postgres>> {
142        let pool = self.write_pool_arc()?;
143        pool.begin().await.map_err(Into::into)
144    }
145}
146
147pub type DbPool = Arc<Database>;
148
149pub trait DatabaseExt {
150    fn database(&self) -> Arc<Database>;
151}
152
153impl DatabaseExt for Arc<Database> {
154    fn database(&self) -> Arc<Database> {
155        Self::clone(self)
156    }
157}
158
159#[async_trait::async_trait]
160impl DatabaseProvider for Database {
161    fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
162        self.read().get_postgres_pool()
163    }
164
165    async fn execute(
166        &self,
167        query: &dyn crate::models::QuerySelector,
168        params: &[&dyn crate::models::ToDbValue],
169    ) -> DatabaseResult<u64> {
170        self.write().execute(query, params).await
171    }
172
173    async fn execute_raw(&self, sql: &str) -> DatabaseResult<()> {
174        self.write().execute_raw(sql).await
175    }
176
177    async fn fetch_all(
178        &self,
179        query: &dyn crate::models::QuerySelector,
180        params: &[&dyn crate::models::ToDbValue],
181    ) -> DatabaseResult<Vec<crate::models::JsonRow>> {
182        self.read().fetch_all(query, params).await
183    }
184
185    async fn fetch_one(
186        &self,
187        query: &dyn crate::models::QuerySelector,
188        params: &[&dyn crate::models::ToDbValue],
189    ) -> DatabaseResult<crate::models::JsonRow> {
190        self.read().fetch_one(query, params).await
191    }
192
193    async fn fetch_optional(
194        &self,
195        query: &dyn crate::models::QuerySelector,
196        params: &[&dyn crate::models::ToDbValue],
197    ) -> DatabaseResult<Option<crate::models::JsonRow>> {
198        self.read().fetch_optional(query, params).await
199    }
200
201    async fn fetch_scalar_value(
202        &self,
203        query: &dyn crate::models::QuerySelector,
204        params: &[&dyn crate::models::ToDbValue],
205    ) -> DatabaseResult<crate::models::DbValue> {
206        self.read().fetch_scalar_value(query, params).await
207    }
208
209    async fn begin_transaction(
210        &self,
211    ) -> DatabaseResult<Box<dyn crate::models::DatabaseTransaction>> {
212        self.write().begin_transaction().await
213    }
214
215    async fn get_database_info(&self) -> DatabaseResult<DatabaseInfo> {
216        self.read().get_database_info().await
217    }
218
219    async fn test_connection(&self) -> DatabaseResult<()> {
220        self.read().test_connection().await
221    }
222
223    async fn execute_batch(&self, sql: &str) -> DatabaseResult<()> {
224        self.write().execute_batch(sql).await
225    }
226
227    async fn query_raw(
228        &self,
229        query: &dyn crate::models::QuerySelector,
230    ) -> DatabaseResult<QueryResult> {
231        self.read().query_raw(query).await
232    }
233
234    async fn query_raw_with(
235        &self,
236        query: &dyn crate::models::QuerySelector,
237        params: &[&dyn crate::models::ToDbValue],
238    ) -> DatabaseResult<QueryResult> {
239        self.read().query_raw_with(query, params).await
240    }
241}