tx2_query/
postgres.rs

1#[cfg(feature = "postgres")]
2use crate::backend::{DatabaseBackend, QueryResult, QueryRow};
3use crate::error::{QueryError, Result};
4use async_trait::async_trait;
5use serde_json::Value;
6use sqlx::postgres::{PgPool, PgPoolOptions, PgRow};
7use sqlx::{Row, Column};
8
9pub struct PostgresBackend {
10    pool: PgPool,
11    in_transaction: bool,
12}
13
14impl PostgresBackend {
15    /// Create a new PostgreSQL backend with connection pool
16    pub async fn new(url: &str) -> Result<Self> {
17        let pool = PgPoolOptions::new()
18            .max_connections(5)
19            .connect(url)
20            .await?;
21
22        Ok(Self {
23            pool,
24            in_transaction: false,
25        })
26    }
27
28    /// Convert PostgreSQL row to QueryRow
29    fn convert_row(row: &PgRow) -> QueryRow {
30        let mut query_row = QueryRow::new();
31
32        for column in row.columns() {
33            let column_name = column.name();
34
35            // Try to extract value as JSON
36            if let Ok(value) = row.try_get::<Value, _>(column_name) {
37                query_row.insert(column_name.to_string(), value);
38            } else if let Ok(value) = row.try_get::<String, _>(column_name) {
39                query_row.insert(column_name.to_string(), Value::String(value));
40            } else if let Ok(value) = row.try_get::<i64, _>(column_name) {
41                query_row.insert(column_name.to_string(), Value::Number(value.into()));
42            } else if let Ok(value) = row.try_get::<i32, _>(column_name) {
43                query_row.insert(column_name.to_string(), Value::Number(value.into()));
44            } else if let Ok(value) = row.try_get::<f64, _>(column_name) {
45                if let Some(num) = serde_json::Number::from_f64(value) {
46                    query_row.insert(column_name.to_string(), Value::Number(num));
47                }
48            } else if let Ok(value) = row.try_get::<bool, _>(column_name) {
49                query_row.insert(column_name.to_string(), Value::Bool(value));
50            } else {
51                // Default to null if we can't extract the value
52                query_row.insert(column_name.to_string(), Value::Null);
53            }
54        }
55
56        query_row
57    }
58}
59
60#[async_trait]
61impl DatabaseBackend for PostgresBackend {
62    async fn connect(url: &str) -> Result<Self> {
63        Self::new(url).await
64    }
65
66    async fn execute(&mut self, sql: &str) -> Result<u64> {
67        let result = sqlx::query(sql).execute(&self.pool).await?;
68        Ok(result.rows_affected())
69    }
70
71    async fn query(&mut self, sql: &str) -> Result<QueryResult> {
72        let rows = sqlx::query(sql).fetch_all(&self.pool).await?;
73
74        let result = rows.iter().map(Self::convert_row).collect();
75
76        Ok(result)
77    }
78
79    async fn begin_transaction(&mut self) -> Result<()> {
80        if self.in_transaction {
81            return Err(QueryError::Transaction(
82                "Already in transaction".to_string(),
83            ));
84        }
85
86        // Start transaction using SAVEPOINT approach
87        self.execute("BEGIN").await?;
88        self.in_transaction = true;
89        Ok(())
90    }
91
92    async fn commit(&mut self) -> Result<()> {
93        if !self.in_transaction {
94            return Err(QueryError::Transaction("Not in transaction".to_string()));
95        }
96
97        self.execute("COMMIT").await?;
98        self.in_transaction = false;
99        Ok(())
100    }
101
102    async fn rollback(&mut self) -> Result<()> {
103        if !self.in_transaction {
104            return Err(QueryError::Transaction("Not in transaction".to_string()));
105        }
106
107        self.execute("ROLLBACK").await?;
108        self.in_transaction = false;
109        Ok(())
110    }
111
112    fn is_connected(&self) -> bool {
113        !self.pool.is_closed()
114    }
115
116    async fn close(self) -> Result<()> {
117        self.pool.close().await;
118        Ok(())
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection string of the local database the ignored tests expect.
    const TEST_DATABASE_URL: &str = "postgresql://localhost/test";

    /// Connecting to the local test database succeeds.
    #[tokio::test]
    #[ignore] // Requires PostgreSQL running
    async fn test_postgres_connection() {
        let connection_attempt = PostgresBackend::connect(TEST_DATABASE_URL).await;
        assert!(connection_attempt.is_ok());
    }

    /// Round-trips one row through `execute` and `query`.
    ///
    /// NOTE(review): the table is TEMPORARY, which PostgreSQL scopes to one
    /// session, while statements go through a connection pool. This presumably
    /// works because sequential calls reuse the same idle connection; confirm.
    #[tokio::test]
    #[ignore] // Requires PostgreSQL running
    async fn test_postgres_query() {
        let mut db = PostgresBackend::connect(TEST_DATABASE_URL)
            .await
            .unwrap();

        // Create the test table, then insert the single row we query for.
        let setup_statements = [
            "CREATE TEMPORARY TABLE test_table (id BIGINT PRIMARY KEY, name TEXT)",
            "INSERT INTO test_table (id, name) VALUES (1, 'Alice')",
        ];
        for statement in setup_statements.iter() {
            db.execute(*statement).await.unwrap();
        }

        // Read the row back and check both columns.
        let rows = db
            .query("SELECT * FROM test_table WHERE id = 1")
            .await
            .unwrap();

        assert_eq!(rows.len(), 1);
        let first = &rows[0];
        assert_eq!(first.get_i64("id"), Some(1));
        assert_eq!(first.get_string("name"), Some("Alice".to_string()));
    }
}