Skip to main content

systemprompt_database/services/
database.rs

1use super::postgres::PostgresProvider;
2use super::provider::DatabaseProvider;
3use crate::models::{DatabaseInfo, QueryResult};
4use anyhow::Result;
5use std::sync::Arc;
6
7pub struct Database {
8    provider: Arc<dyn DatabaseProvider>,
9}
10
11impl std::fmt::Debug for Database {
12    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13        f.debug_struct("Database")
14            .field("backend", &"PostgreSQL")
15            .finish()
16    }
17}
18
19impl Database {
20    pub async fn new_postgres(url: &str) -> Result<Self> {
21        let provider = PostgresProvider::new(url).await?;
22        Ok(Self {
23            provider: Arc::new(provider),
24        })
25    }
26
27    pub async fn from_config(db_type: &str, url: &str) -> Result<Self> {
28        match db_type.to_lowercase().as_str() {
29            "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
30            other => Err(anyhow::anyhow!(
31                "Unsupported database type: {other}. Only PostgreSQL is supported."
32            )),
33        }
34    }
35
36    pub fn get_postgres_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
37        self.provider
38            .get_postgres_pool()
39            .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
40    }
41
42    pub async fn query(&self, sql: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
43        self.provider.query_raw(sql).await
44    }
45
46    pub async fn query_with(
47        &self,
48        sql: &dyn crate::models::QuerySelector,
49        params: Vec<serde_json::Value>,
50    ) -> Result<QueryResult> {
51        self.provider.query_raw_with(sql, params).await
52    }
53
54    pub async fn execute_batch(&self, sql: &str) -> Result<()> {
55        self.provider.execute_batch(sql).await
56    }
57
58    pub async fn get_info(&self) -> Result<DatabaseInfo> {
59        self.provider.get_database_info().await
60    }
61
62    pub async fn test_connection(&self) -> Result<()> {
63        self.provider.test_connection().await
64    }
65
66    #[must_use]
67    pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
68        self.provider.get_postgres_pool()
69    }
70
71    pub fn pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
72        self.get_postgres_pool_arc()
73    }
74
75    #[must_use]
76    pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
77        self.get_postgres_pool()
78    }
79
80    pub async fn begin(&self) -> Result<sqlx::Transaction<'_, sqlx::Postgres>> {
81        let pool = self.pool_arc()?;
82        pool.begin().await.map_err(Into::into)
83    }
84}
85
86pub type DbPool = Arc<Database>;
87
88pub trait DatabaseExt {
89    fn database(&self) -> Arc<Database>;
90}
91
92impl DatabaseExt for Arc<Database> {
93    fn database(&self) -> Arc<Database> {
94        Self::clone(self)
95    }
96}
97
98#[async_trait::async_trait]
99impl DatabaseProvider for Database {
100    fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
101        self.provider.get_postgres_pool()
102    }
103
104    async fn execute(
105        &self,
106        query: &dyn crate::models::QuerySelector,
107        params: &[&dyn crate::models::ToDbValue],
108    ) -> Result<u64> {
109        self.provider.execute(query, params).await
110    }
111
112    async fn execute_raw(&self, sql: &str) -> Result<()> {
113        self.provider.execute_raw(sql).await
114    }
115
116    async fn fetch_all(
117        &self,
118        query: &dyn crate::models::QuerySelector,
119        params: &[&dyn crate::models::ToDbValue],
120    ) -> Result<Vec<crate::models::JsonRow>> {
121        self.provider.fetch_all(query, params).await
122    }
123
124    async fn fetch_one(
125        &self,
126        query: &dyn crate::models::QuerySelector,
127        params: &[&dyn crate::models::ToDbValue],
128    ) -> Result<crate::models::JsonRow> {
129        self.provider.fetch_one(query, params).await
130    }
131
132    async fn fetch_optional(
133        &self,
134        query: &dyn crate::models::QuerySelector,
135        params: &[&dyn crate::models::ToDbValue],
136    ) -> Result<Option<crate::models::JsonRow>> {
137        self.provider.fetch_optional(query, params).await
138    }
139
140    async fn fetch_scalar_value(
141        &self,
142        query: &dyn crate::models::QuerySelector,
143        params: &[&dyn crate::models::ToDbValue],
144    ) -> Result<crate::models::DbValue> {
145        self.provider.fetch_scalar_value(query, params).await
146    }
147
148    async fn begin_transaction(&self) -> Result<Box<dyn crate::models::DatabaseTransaction>> {
149        self.provider.begin_transaction().await
150    }
151
152    async fn get_database_info(&self) -> Result<DatabaseInfo> {
153        self.provider.get_database_info().await
154    }
155
156    async fn test_connection(&self) -> Result<()> {
157        self.provider.test_connection().await
158    }
159
160    async fn execute_batch(&self, sql: &str) -> Result<()> {
161        self.provider.execute_batch(sql).await
162    }
163
164    async fn query_raw(&self, query: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
165        self.provider.query_raw(query).await
166    }
167
168    async fn query_raw_with(
169        &self,
170        query: &dyn crate::models::QuerySelector,
171        params: Vec<serde_json::Value>,
172    ) -> Result<QueryResult> {
173        self.provider.query_raw_with(query, params).await
174    }
175}