1#[cfg(feature = "postgres")]
2use crate::backend::{DatabaseBackend, QueryResult, QueryRow};
3use crate::error::{QueryError, Result};
4use async_trait::async_trait;
5use serde_json::Value;
6use sqlx::postgres::{PgPool, PgPoolOptions, PgRow};
7use sqlx::{Row, Column};
8
9pub struct PostgresBackend {
10 pool: PgPool,
11 in_transaction: bool,
12}
13
14impl PostgresBackend {
15 pub async fn new(url: &str) -> Result<Self> {
17 let pool = PgPoolOptions::new()
18 .max_connections(5)
19 .connect(url)
20 .await?;
21
22 Ok(Self {
23 pool,
24 in_transaction: false,
25 })
26 }
27
28 fn convert_row(row: &PgRow) -> QueryRow {
30 let mut query_row = QueryRow::new();
31
32 for column in row.columns() {
33 let column_name = column.name();
34
35 if let Ok(value) = row.try_get::<Value, _>(column_name) {
37 query_row.insert(column_name.to_string(), value);
38 } else if let Ok(value) = row.try_get::<String, _>(column_name) {
39 query_row.insert(column_name.to_string(), Value::String(value));
40 } else if let Ok(value) = row.try_get::<i64, _>(column_name) {
41 query_row.insert(column_name.to_string(), Value::Number(value.into()));
42 } else if let Ok(value) = row.try_get::<i32, _>(column_name) {
43 query_row.insert(column_name.to_string(), Value::Number(value.into()));
44 } else if let Ok(value) = row.try_get::<f64, _>(column_name) {
45 if let Some(num) = serde_json::Number::from_f64(value) {
46 query_row.insert(column_name.to_string(), Value::Number(num));
47 }
48 } else if let Ok(value) = row.try_get::<bool, _>(column_name) {
49 query_row.insert(column_name.to_string(), Value::Bool(value));
50 } else {
51 query_row.insert(column_name.to_string(), Value::Null);
53 }
54 }
55
56 query_row
57 }
58}
59
60#[async_trait]
61impl DatabaseBackend for PostgresBackend {
62 async fn connect(url: &str) -> Result<Self> {
63 Self::new(url).await
64 }
65
66 async fn execute(&mut self, sql: &str) -> Result<u64> {
67 let result = sqlx::query(sql).execute(&self.pool).await?;
68 Ok(result.rows_affected())
69 }
70
71 async fn query(&mut self, sql: &str) -> Result<QueryResult> {
72 let rows = sqlx::query(sql).fetch_all(&self.pool).await?;
73
74 let result = rows.iter().map(Self::convert_row).collect();
75
76 Ok(result)
77 }
78
79 async fn begin_transaction(&mut self) -> Result<()> {
80 if self.in_transaction {
81 return Err(QueryError::Transaction(
82 "Already in transaction".to_string(),
83 ));
84 }
85
86 self.execute("BEGIN").await?;
88 self.in_transaction = true;
89 Ok(())
90 }
91
92 async fn commit(&mut self) -> Result<()> {
93 if !self.in_transaction {
94 return Err(QueryError::Transaction("Not in transaction".to_string()));
95 }
96
97 self.execute("COMMIT").await?;
98 self.in_transaction = false;
99 Ok(())
100 }
101
102 async fn rollback(&mut self) -> Result<()> {
103 if !self.in_transaction {
104 return Err(QueryError::Transaction("Not in transaction".to_string()));
105 }
106
107 self.execute("ROLLBACK").await?;
108 self.in_transaction = false;
109 Ok(())
110 }
111
112 fn is_connected(&self) -> bool {
113 !self.pool.is_closed()
114 }
115
116 async fn close(self) -> Result<()> {
117 self.pool.close().await;
118 Ok(())
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection string for these integration tests.
    ///
    /// They are `#[ignore]`d by default because they need a reachable PostgreSQL server
    /// with a `test` database.
    const TEST_URL: &str = "postgresql://localhost/test";

    #[tokio::test]
    #[ignore = "requires a running PostgreSQL server"]
    async fn test_postgres_connection() {
        let backend = PostgresBackend::connect(TEST_URL)
            .await
            .expect("failed to connect to PostgreSQL");
        assert!(backend.is_connected());

        backend.close().await.expect("failed to close the backend");
    }

    #[tokio::test]
    #[ignore = "requires a running PostgreSQL server"]
    async fn test_postgres_query() {
        let mut backend = PostgresBackend::connect(TEST_URL)
            .await
            .expect("failed to connect to PostgreSQL");

        // Run everything inside one transaction: a TEMPORARY table is visible only to the
        // session that created it, and the final rollback discards what the test wrote.
        backend.begin_transaction().await.unwrap();

        backend
            .execute("CREATE TEMPORARY TABLE test_table (id BIGINT PRIMARY KEY, name TEXT)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO test_table (id, name) VALUES (1, 'Alice')")
            .await
            .unwrap();

        let results = backend
            .query("SELECT * FROM test_table WHERE id = 1")
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].get_i64("id"), Some(1));
        assert_eq!(results[0].get_string("name"), Some("Alice".to_string()));

        backend.rollback().await.unwrap();
        backend.close().await.unwrap();
    }

    #[tokio::test]
    #[ignore = "requires a running PostgreSQL server"]
    async fn test_transaction_state_errors() {
        let mut backend = PostgresBackend::connect(TEST_URL)
            .await
            .expect("failed to connect to PostgreSQL");

        assert!(backend.commit().await.is_err(), "commit without begin must fail");
        assert!(backend.rollback().await.is_err(), "rollback without begin must fail");

        backend.begin_transaction().await.unwrap();
        assert!(
            backend.begin_transaction().await.is_err(),
            "nested begin must fail"
        );
        backend.commit().await.unwrap();
        assert!(backend.commit().await.is_err(), "second commit must fail");

        backend.close().await.unwrap();
    }
}