Skip to main content

systemprompt_database/services/
database.rs

1use super::postgres::PostgresProvider;
2use super::provider::DatabaseProvider;
3use crate::models::{DatabaseInfo, QueryResult};
4use anyhow::Result;
5use std::sync::Arc;
6
7pub struct Database {
8    provider: Arc<dyn DatabaseProvider>,
9    write_provider: Option<Arc<dyn DatabaseProvider>>,
10}
11
12impl std::fmt::Debug for Database {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        f.debug_struct("Database")
15            .field("backend", &"PostgreSQL")
16            .finish()
17    }
18}
19
20impl Database {
21    pub async fn new_postgres(url: &str) -> Result<Self> {
22        let provider = PostgresProvider::new(url).await?;
23        Ok(Self {
24            provider: Arc::new(provider),
25            write_provider: None,
26        })
27    }
28
29    pub async fn from_config(db_type: &str, url: &str) -> Result<Self> {
30        match db_type.to_lowercase().as_str() {
31            "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
32            other => Err(anyhow::anyhow!(
33                "Unsupported database type: {other}. Only PostgreSQL is supported."
34            )),
35        }
36    }
37
38    pub async fn from_config_with_write(
39        db_type: &str,
40        read_url: &str,
41        write_url: Option<&str>,
42    ) -> Result<Self> {
43        let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
44            "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
45            other => {
46                return Err(anyhow::anyhow!(
47                    "Unsupported database type: {other}. Only PostgreSQL is supported."
48                ))
49            },
50        };
51
52        let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
53            Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
54            None => None,
55        };
56
57        Ok(Self {
58            provider,
59            write_provider,
60        })
61    }
62
63    pub fn get_postgres_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
64        self.provider
65            .get_postgres_pool()
66            .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
67    }
68
69    pub fn write_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
70        self.write_provider.as_ref().map_or_else(
71            || self.get_postgres_pool_arc(),
72            |wp| {
73                wp.get_postgres_pool()
74                    .ok_or_else(|| anyhow::anyhow!("Write database is not PostgreSQL"))
75            },
76        )
77    }
78
79    #[must_use]
80    pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
81        self.write_provider
82            .as_ref()
83            .and_then(|wp| wp.get_postgres_pool())
84            .or_else(|| self.provider.get_postgres_pool())
85    }
86
87    #[must_use]
88    pub fn has_write_pool(&self) -> bool {
89        self.write_provider.is_some()
90    }
91
92    #[must_use]
93    pub fn write_provider(&self) -> &dyn DatabaseProvider {
94        self.write_provider
95            .as_deref()
96            .unwrap_or_else(|| self.provider.as_ref())
97    }
98
99    pub async fn query(&self, sql: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
100        self.provider.query_raw(sql).await
101    }
102
103    pub async fn query_with(
104        &self,
105        sql: &dyn crate::models::QuerySelector,
106        params: Vec<serde_json::Value>,
107    ) -> Result<QueryResult> {
108        self.provider.query_raw_with(sql, params).await
109    }
110
111    pub async fn execute_batch(&self, sql: &str) -> Result<()> {
112        self.provider.execute_batch(sql).await
113    }
114
115    pub async fn get_info(&self) -> Result<DatabaseInfo> {
116        self.provider.get_database_info().await
117    }
118
119    pub async fn test_connection(&self) -> Result<()> {
120        self.provider.test_connection().await?;
121        if let Some(wp) = &self.write_provider {
122            wp.test_connection().await?;
123        }
124        Ok(())
125    }
126
127    #[must_use]
128    pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
129        self.provider.get_postgres_pool()
130    }
131
132    pub fn pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
133        self.get_postgres_pool_arc()
134    }
135
136    #[must_use]
137    pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
138        self.get_postgres_pool()
139    }
140
141    pub async fn begin(&self) -> Result<sqlx::Transaction<'_, sqlx::Postgres>> {
142        let pool = self.write_pool_arc()?;
143        pool.begin().await.map_err(Into::into)
144    }
145}
146
147pub type DbPool = Arc<Database>;
148
149pub trait DatabaseExt {
150    fn database(&self) -> Arc<Database>;
151}
152
153impl DatabaseExt for Arc<Database> {
154    fn database(&self) -> Arc<Database> {
155        Self::clone(self)
156    }
157}
158
159#[async_trait::async_trait]
160impl DatabaseProvider for Database {
161    fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
162        self.provider.get_postgres_pool()
163    }
164
165    async fn execute(
166        &self,
167        query: &dyn crate::models::QuerySelector,
168        params: &[&dyn crate::models::ToDbValue],
169    ) -> Result<u64> {
170        self.provider.execute(query, params).await
171    }
172
173    async fn execute_raw(&self, sql: &str) -> Result<()> {
174        self.provider.execute_raw(sql).await
175    }
176
177    async fn fetch_all(
178        &self,
179        query: &dyn crate::models::QuerySelector,
180        params: &[&dyn crate::models::ToDbValue],
181    ) -> Result<Vec<crate::models::JsonRow>> {
182        self.provider.fetch_all(query, params).await
183    }
184
185    async fn fetch_one(
186        &self,
187        query: &dyn crate::models::QuerySelector,
188        params: &[&dyn crate::models::ToDbValue],
189    ) -> Result<crate::models::JsonRow> {
190        self.provider.fetch_one(query, params).await
191    }
192
193    async fn fetch_optional(
194        &self,
195        query: &dyn crate::models::QuerySelector,
196        params: &[&dyn crate::models::ToDbValue],
197    ) -> Result<Option<crate::models::JsonRow>> {
198        self.provider.fetch_optional(query, params).await
199    }
200
201    async fn fetch_scalar_value(
202        &self,
203        query: &dyn crate::models::QuerySelector,
204        params: &[&dyn crate::models::ToDbValue],
205    ) -> Result<crate::models::DbValue> {
206        self.provider.fetch_scalar_value(query, params).await
207    }
208
209    async fn begin_transaction(&self) -> Result<Box<dyn crate::models::DatabaseTransaction>> {
210        self.provider.begin_transaction().await
211    }
212
213    async fn get_database_info(&self) -> Result<DatabaseInfo> {
214        self.provider.get_database_info().await
215    }
216
217    async fn test_connection(&self) -> Result<()> {
218        self.provider.test_connection().await
219    }
220
221    async fn execute_batch(&self, sql: &str) -> Result<()> {
222        self.provider.execute_batch(sql).await
223    }
224
225    async fn query_raw(&self, query: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
226        self.provider.query_raw(query).await
227    }
228
229    async fn query_raw_with(
230        &self,
231        query: &dyn crate::models::QuerySelector,
232        params: Vec<serde_json::Value>,
233    ) -> Result<QueryResult> {
234        self.provider.query_raw_with(query, params).await
235    }
236}