systemprompt_database/services/
database.rs1use super::postgres::PostgresProvider;
2use super::provider::DatabaseProvider;
3use crate::models::{DatabaseInfo, QueryResult};
4use anyhow::Result;
5use std::sync::Arc;
6
7pub struct Database {
8 provider: Arc<dyn DatabaseProvider>,
9 write_provider: Option<Arc<dyn DatabaseProvider>>,
10}
11
12impl std::fmt::Debug for Database {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 f.debug_struct("Database")
15 .field("backend", &"PostgreSQL")
16 .finish()
17 }
18}
19
20impl Database {
21 pub async fn new_postgres(url: &str) -> Result<Self> {
22 let provider = PostgresProvider::new(url).await?;
23 Ok(Self {
24 provider: Arc::new(provider),
25 write_provider: None,
26 })
27 }
28
29 pub async fn from_config(db_type: &str, url: &str) -> Result<Self> {
30 match db_type.to_lowercase().as_str() {
31 "postgres" | "postgresql" | "" => Self::new_postgres(url).await,
32 other => Err(anyhow::anyhow!(
33 "Unsupported database type: {other}. Only PostgreSQL is supported."
34 )),
35 }
36 }
37
38 pub async fn from_config_with_write(
39 db_type: &str,
40 read_url: &str,
41 write_url: Option<&str>,
42 ) -> Result<Self> {
43 let provider: Arc<dyn DatabaseProvider> = match db_type.to_lowercase().as_str() {
44 "postgres" | "postgresql" | "" => Arc::new(PostgresProvider::new(read_url).await?),
45 other => {
46 return Err(anyhow::anyhow!(
47 "Unsupported database type: {other}. Only PostgreSQL is supported."
48 ))
49 },
50 };
51
52 let write_provider: Option<Arc<dyn DatabaseProvider>> = match write_url {
53 Some(url) => Some(Arc::new(PostgresProvider::new(url).await?)),
54 None => None,
55 };
56
57 Ok(Self {
58 provider,
59 write_provider,
60 })
61 }
62
63 pub fn get_postgres_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
64 self.provider
65 .get_postgres_pool()
66 .ok_or_else(|| anyhow::anyhow!("Database is not PostgreSQL"))
67 }
68
69 pub fn write_pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
70 self.write_provider.as_ref().map_or_else(
71 || self.get_postgres_pool_arc(),
72 |wp| {
73 wp.get_postgres_pool()
74 .ok_or_else(|| anyhow::anyhow!("Write database is not PostgreSQL"))
75 },
76 )
77 }
78
79 #[must_use]
80 pub fn write_pool(&self) -> Option<Arc<sqlx::PgPool>> {
81 self.write_provider
82 .as_ref()
83 .and_then(|wp| wp.get_postgres_pool())
84 .or_else(|| self.provider.get_postgres_pool())
85 }
86
87 #[must_use]
88 pub fn has_write_pool(&self) -> bool {
89 self.write_provider.is_some()
90 }
91
92 #[must_use]
93 pub fn write_provider(&self) -> &dyn DatabaseProvider {
94 self.write_provider
95 .as_deref()
96 .unwrap_or_else(|| self.provider.as_ref())
97 }
98
99 pub async fn query(&self, sql: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
100 self.provider.query_raw(sql).await
101 }
102
103 pub async fn query_with(
104 &self,
105 sql: &dyn crate::models::QuerySelector,
106 params: Vec<serde_json::Value>,
107 ) -> Result<QueryResult> {
108 self.provider.query_raw_with(sql, params).await
109 }
110
111 pub async fn execute_batch(&self, sql: &str) -> Result<()> {
112 self.provider.execute_batch(sql).await
113 }
114
115 pub async fn get_info(&self) -> Result<DatabaseInfo> {
116 self.provider.get_database_info().await
117 }
118
119 pub async fn test_connection(&self) -> Result<()> {
120 self.provider.test_connection().await?;
121 if let Some(wp) = &self.write_provider {
122 wp.test_connection().await?;
123 }
124 Ok(())
125 }
126
127 #[must_use]
128 pub fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
129 self.provider.get_postgres_pool()
130 }
131
132 pub fn pool_arc(&self) -> Result<Arc<sqlx::PgPool>> {
133 self.get_postgres_pool_arc()
134 }
135
136 #[must_use]
137 pub fn pool(&self) -> Option<Arc<sqlx::PgPool>> {
138 self.get_postgres_pool()
139 }
140
141 pub async fn begin(&self) -> Result<sqlx::Transaction<'_, sqlx::Postgres>> {
142 let pool = self.write_pool_arc()?;
143 pool.begin().await.map_err(Into::into)
144 }
145}
146
147pub type DbPool = Arc<Database>;
148
149pub trait DatabaseExt {
150 fn database(&self) -> Arc<Database>;
151}
152
153impl DatabaseExt for Arc<Database> {
154 fn database(&self) -> Arc<Database> {
155 Self::clone(self)
156 }
157}
158
159#[async_trait::async_trait]
160impl DatabaseProvider for Database {
161 fn get_postgres_pool(&self) -> Option<Arc<sqlx::PgPool>> {
162 self.provider.get_postgres_pool()
163 }
164
165 async fn execute(
166 &self,
167 query: &dyn crate::models::QuerySelector,
168 params: &[&dyn crate::models::ToDbValue],
169 ) -> Result<u64> {
170 self.provider.execute(query, params).await
171 }
172
173 async fn execute_raw(&self, sql: &str) -> Result<()> {
174 self.provider.execute_raw(sql).await
175 }
176
177 async fn fetch_all(
178 &self,
179 query: &dyn crate::models::QuerySelector,
180 params: &[&dyn crate::models::ToDbValue],
181 ) -> Result<Vec<crate::models::JsonRow>> {
182 self.provider.fetch_all(query, params).await
183 }
184
185 async fn fetch_one(
186 &self,
187 query: &dyn crate::models::QuerySelector,
188 params: &[&dyn crate::models::ToDbValue],
189 ) -> Result<crate::models::JsonRow> {
190 self.provider.fetch_one(query, params).await
191 }
192
193 async fn fetch_optional(
194 &self,
195 query: &dyn crate::models::QuerySelector,
196 params: &[&dyn crate::models::ToDbValue],
197 ) -> Result<Option<crate::models::JsonRow>> {
198 self.provider.fetch_optional(query, params).await
199 }
200
201 async fn fetch_scalar_value(
202 &self,
203 query: &dyn crate::models::QuerySelector,
204 params: &[&dyn crate::models::ToDbValue],
205 ) -> Result<crate::models::DbValue> {
206 self.provider.fetch_scalar_value(query, params).await
207 }
208
209 async fn begin_transaction(&self) -> Result<Box<dyn crate::models::DatabaseTransaction>> {
210 self.provider.begin_transaction().await
211 }
212
213 async fn get_database_info(&self) -> Result<DatabaseInfo> {
214 self.provider.get_database_info().await
215 }
216
217 async fn test_connection(&self) -> Result<()> {
218 self.provider.test_connection().await
219 }
220
221 async fn execute_batch(&self, sql: &str) -> Result<()> {
222 self.provider.execute_batch(sql).await
223 }
224
225 async fn query_raw(&self, query: &dyn crate::models::QuerySelector) -> Result<QueryResult> {
226 self.provider.query_raw(query).await
227 }
228
229 async fn query_raw_with(
230 &self,
231 query: &dyn crate::models::QuerySelector,
232 params: Vec<serde_json::Value>,
233 ) -> Result<QueryResult> {
234 self.provider.query_raw_with(query, params).await
235 }
236}