tx2_query/
duckdb.rs

1#[cfg(feature = "duckdb")]
2use crate::backend::{DatabaseBackend, QueryResult, QueryRow};
3use crate::error::{QueryError, Result};
4use async_trait::async_trait;
5use duckdb::{params, Connection, Row};
6use serde_json::Value;
7use std::sync::{Arc, Mutex};
8
9/// DuckDB backend for OLAP workloads and analytics
10pub struct DuckDBBackend {
11    conn: Arc<Mutex<Connection>>,
12    in_transaction: bool,
13}
14
15impl DuckDBBackend {
16    /// Create a new DuckDB backend with an in-memory database
17    pub async fn memory() -> Result<Self> {
18        let conn = Connection::open_in_memory()
19            .map_err(|e| QueryError::Connection(e.to_string()))?;
20
21        Ok(Self {
22            conn: Arc::new(Mutex::new(conn)),
23            in_transaction: false,
24        })
25    }
26
27    /// Create a new DuckDB backend with a file-based database
28    pub async fn file(path: &str) -> Result<Self> {
29        let conn = Connection::open(path)
30            .map_err(|e| QueryError::Connection(e.to_string()))?;
31
32        Ok(Self {
33            conn: Arc::new(Mutex::new(conn)),
34            in_transaction: false,
35        })
36    }
37
38    /// Convert DuckDB row to QueryRow
39    fn convert_row(row: &Row) -> Result<QueryRow> {
40        let mut query_row = QueryRow::new();
41        let column_count = row.as_ref().column_count();
42
43        for i in 0..column_count {
44            let column_name = row.as_ref().column_name(i)
45                .map_err(|e| QueryError::Database(e.to_string()))?;
46
47            // Try to extract value with proper type handling
48            if let Ok(value) = row.get::<_, Option<String>>(i) {
49                if let Some(s) = value {
50                    query_row.insert(column_name.to_string(), Value::String(s));
51                } else {
52                    query_row.insert(column_name.to_string(), Value::Null);
53                }
54            } else if let Ok(value) = row.get::<_, Option<i64>>(i) {
55                if let Some(n) = value {
56                    query_row.insert(column_name.to_string(), Value::Number(n.into()));
57                } else {
58                    query_row.insert(column_name.to_string(), Value::Null);
59                }
60            } else if let Ok(value) = row.get::<_, Option<i32>>(i) {
61                if let Some(n) = value {
62                    query_row.insert(column_name.to_string(), Value::Number(n.into()));
63                } else {
64                    query_row.insert(column_name.to_string(), Value::Null);
65                }
66            } else if let Ok(value) = row.get::<_, Option<f64>>(i) {
67                if let Some(n) = value {
68                    if let Some(num) = serde_json::Number::from_f64(n) {
69                        query_row.insert(column_name.to_string(), Value::Number(num));
70                    }
71                } else {
72                    query_row.insert(column_name.to_string(), Value::Null);
73                }
74            } else if let Ok(value) = row.get::<_, Option<bool>>(i) {
75                if let Some(b) = value {
76                    query_row.insert(column_name.to_string(), Value::Bool(b));
77                } else {
78                    query_row.insert(column_name.to_string(), Value::Null);
79                }
80            } else {
81                // Default to null if we can't extract the value
82                query_row.insert(column_name.to_string(), Value::Null);
83            }
84        }
85
86        Ok(query_row)
87    }
88
89    /// Export database to Parquet file
90    pub async fn export_parquet(&mut self, table: &str, path: &str) -> Result<()> {
91        let sql = format!("COPY {} TO '{}' (FORMAT PARQUET)", table, path);
92        self.execute(&sql).await?;
93        Ok(())
94    }
95
96    /// Import data from Parquet file
97    pub async fn import_parquet(&mut self, table: &str, path: &str) -> Result<()> {
98        let sql = format!("COPY {} FROM '{}' (FORMAT PARQUET)", table, path);
99        self.execute(&sql).await?;
100        Ok(())
101    }
102
103    /// Export database to CSV file
104    pub async fn export_csv(&mut self, table: &str, path: &str) -> Result<()> {
105        let sql = format!("COPY {} TO '{}' (FORMAT CSV, HEADER)", table, path);
106        self.execute(&sql).await?;
107        Ok(())
108    }
109
110    /// Get table statistics
111    pub async fn analyze_table(&mut self, table: &str) -> Result<()> {
112        let sql = format!("ANALYZE {}", table);
113        self.execute(&sql).await?;
114        Ok(())
115    }
116
117    /// Create an index on a table
118    pub async fn create_index(&mut self, table: &str, column: &str, index_name: &str) -> Result<()> {
119        let sql = format!("CREATE INDEX IF NOT EXISTS {} ON {} ({})", index_name, table, column);
120        self.execute(&sql).await?;
121        Ok(())
122    }
123}
124
125#[async_trait]
126impl DatabaseBackend for DuckDBBackend {
127    async fn connect(url: &str) -> Result<Self> {
128        if url == ":memory:" || url == "memory" {
129            Self::memory().await
130        } else {
131            Self::file(url).await
132        }
133    }
134
135    async fn execute(&mut self, sql: &str) -> Result<u64> {
136        // Run in blocking task to avoid blocking async runtime
137        let sql = sql.to_string();
138        let conn = self.conn.clone();
139
140        tokio::task::spawn_blocking(move || {
141            let conn = conn.lock()
142                .map_err(|e| QueryError::Database(format!("Lock error: {}", e)))?;
143
144            let affected = conn.execute(&sql, params![])
145                .map_err(|e| QueryError::Database(e.to_string()))?;
146
147            Ok(affected as u64)
148        })
149        .await
150        .map_err(|e| QueryError::Database(format!("Join error: {}", e)))?
151    }
152
153    async fn query(&mut self, sql: &str) -> Result<QueryResult> {
154        // Run in blocking task to avoid blocking async runtime
155        let sql = sql.to_string();
156        let conn = self.conn.clone();
157
158        tokio::task::spawn_blocking(move || {
159            let conn = conn.lock()
160                .map_err(|e| QueryError::Database(format!("Lock error: {}", e)))?;
161
162            let mut stmt = conn.prepare(&sql)
163                .map_err(|e| QueryError::Database(e.to_string()))?;
164
165            let rows = stmt.query_map(params![], |row| {
166                Ok(DuckDBBackend::convert_row(row).map_err(|e| duckdb::Error::ToSqlConversionFailure(Box::new(e)))?)
167            })
168            .map_err(|e| QueryError::Database(e.to_string()))?;
169
170            let mut result = Vec::new();
171            for row_result in rows {
172                let query_row = row_result.map_err(|e| QueryError::Database(e.to_string()))?;
173                result.push(query_row);
174            }
175
176            Ok(result)
177        })
178        .await
179        .map_err(|e| QueryError::Database(format!("Join error: {}", e)))?
180    }
181
182    async fn begin_transaction(&mut self) -> Result<()> {
183        if self.in_transaction {
184            return Err(QueryError::Transaction("Already in transaction".to_string()));
185        }
186
187        self.execute("BEGIN TRANSACTION").await?;
188        self.in_transaction = true;
189        Ok(())
190    }
191
192    async fn commit(&mut self) -> Result<()> {
193        if !self.in_transaction {
194            return Err(QueryError::Transaction("Not in transaction".to_string()));
195        }
196
197        self.execute("COMMIT").await?;
198        self.in_transaction = false;
199        Ok(())
200    }
201
202    async fn rollback(&mut self) -> Result<()> {
203        if !self.in_transaction {
204            return Err(QueryError::Transaction("Not in transaction".to_string()));
205        }
206
207        self.execute("ROLLBACK").await?;
208        self.in_transaction = false;
209        Ok(())
210    }
211
212    fn is_connected(&self) -> bool {
213        // DuckDB connections are always "connected" once created
214        true
215    }
216
217    async fn close(self) -> Result<()> {
218        // DuckDB connection will be closed when dropped
219        Ok(())
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_duckdb_memory() {
        let backend = DuckDBBackend::memory().await;
        assert!(backend.is_ok());
    }

    #[tokio::test]
    async fn test_duckdb_create_table() {
        let mut backend = DuckDBBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name VARCHAR)")
            .await
            .unwrap();

        // Verify through the catalog; `is_connected()` is always true and proves nothing.
        let results = backend
            .query(
                "SELECT table_name FROM information_schema.tables WHERE table_name = 'test_table'",
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].get_string("table_name"),
            Some("test_table".to_string())
        );
    }

    #[tokio::test]
    async fn test_duckdb_insert_query() {
        let mut backend = DuckDBBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE users (id INTEGER, name VARCHAR, age INTEGER)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO users VALUES (1, 'Alice', 30)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO users VALUES (2, 'Bob', 25)")
            .await
            .unwrap();

        let results = backend
            .query("SELECT * FROM users ORDER BY id")
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].get_string("name"), Some("Alice".to_string()));
        assert_eq!(results[0].get_i64("age"), Some(30));

        assert_eq!(results[1].get_string("name"), Some("Bob".to_string()));
        assert_eq!(results[1].get_i64("age"), Some(25));
    }

    #[tokio::test]
    async fn test_duckdb_aggregation() {
        let mut backend = DuckDBBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE sales (product VARCHAR, amount INTEGER)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO sales VALUES ('A', 100), ('B', 200), ('A', 150)")
            .await
            .unwrap();

        let results = backend
            .query("SELECT product, SUM(amount) as total FROM sales GROUP BY product ORDER BY product")
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].get_string("product"), Some("A".to_string()));
        assert_eq!(results[0].get_i64("total"), Some(250));
    }

    #[tokio::test]
    async fn test_duckdb_analytical_query() {
        let mut backend = DuckDBBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE events (user_id INTEGER, event_type VARCHAR, timestamp BIGINT)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO events VALUES (1, 'login', 1000), (1, 'click', 2000), (2, 'login', 1500)")
            .await
            .unwrap();

        // Window function - DuckDB excels at these. The outer ORDER BY makes the
        // row order deterministic so the per-user sequence can be checked.
        let results = backend
            .query(
                "SELECT user_id, event_type, \
                 ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY timestamp) as event_seq \
                 FROM events ORDER BY user_id, event_seq",
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 3);

        assert_eq!(results[0].get_i64("user_id"), Some(1));
        assert_eq!(results[0].get_string("event_type"), Some("login".to_string()));
        assert_eq!(results[0].get_i64("event_seq"), Some(1));

        assert_eq!(results[1].get_i64("user_id"), Some(1));
        assert_eq!(results[1].get_string("event_type"), Some("click".to_string()));
        assert_eq!(results[1].get_i64("event_seq"), Some(2));

        assert_eq!(results[2].get_i64("user_id"), Some(2));
        assert_eq!(results[2].get_string("event_type"), Some("login".to_string()));
        assert_eq!(results[2].get_i64("event_seq"), Some(1));
    }

    #[tokio::test]
    async fn test_duckdb_null_handling() {
        let mut backend = DuckDBBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE nullable_test (id INTEGER, value VARCHAR)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO nullable_test VALUES (1, NULL)")
            .await
            .unwrap();

        let results = backend
            .query("SELECT * FROM nullable_test WHERE id = 1")
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].get_string("value"), None);
    }

    #[tokio::test]
    async fn test_duckdb_transaction() {
        let mut backend = DuckDBBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE accounts (id INTEGER, balance INTEGER)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO accounts VALUES (1, 100)")
            .await
            .unwrap();

        // Commit/rollback are rejected when no transaction is open.
        assert!(backend.commit().await.is_err());
        assert!(backend.rollback().await.is_err());

        // Test transaction state tracking
        assert!(!backend.in_transaction);
        backend.begin_transaction().await.unwrap();
        assert!(backend.in_transaction);

        // Can't begin another transaction while one is active
        assert!(backend.begin_transaction().await.is_err());

        // A rolled-back update must not be visible afterwards.
        backend
            .execute("UPDATE accounts SET balance = 50 WHERE id = 1")
            .await
            .unwrap();
        backend.rollback().await.unwrap();
        assert!(!backend.in_transaction);

        let results = backend
            .query("SELECT balance FROM accounts WHERE id = 1")
            .await
            .unwrap();
        assert_eq!(results[0].get_i64("balance"), Some(100));

        // A committed update must persist.
        backend.begin_transaction().await.unwrap();
        backend
            .execute("UPDATE accounts SET balance = 75 WHERE id = 1")
            .await
            .unwrap();
        backend.commit().await.unwrap();
        assert!(!backend.in_transaction);

        let results = backend
            .query("SELECT balance FROM accounts WHERE id = 1")
            .await
            .unwrap();
        assert_eq!(results[0].get_i64("balance"), Some(75));
    }
}
377}