tx2_query/
sqlite.rs

1#[cfg(feature = "sqlite")]
2use crate::backend::{DatabaseBackend, QueryResult, QueryRow};
3use crate::error::{QueryError, Result};
4use async_trait::async_trait;
5use serde_json::Value;
6use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions, SqliteRow};
7use sqlx::{Row, Column};
8use std::str::FromStr;
9
10pub struct SqliteBackend {
11    pool: SqlitePool,
12    in_transaction: bool,
13}
14
15impl SqliteBackend {
16    /// Create a new SQLite backend with connection pool
17    pub async fn new(url: &str) -> Result<Self> {
18        let options = SqliteConnectOptions::from_str(url)?
19            .create_if_missing(true);
20
21        let pool = SqlitePoolOptions::new()
22            .max_connections(5)
23            .connect_with(options)
24            .await?;
25
26        // Enable foreign keys
27        sqlx::query("PRAGMA foreign_keys = ON")
28            .execute(&pool)
29            .await?;
30
31        // Set WAL mode for better concurrency
32        sqlx::query("PRAGMA journal_mode = WAL")
33            .execute(&pool)
34            .await?;
35
36        Ok(Self {
37            pool,
38            in_transaction: false,
39        })
40    }
41
42    /// Create an in-memory SQLite database
43    pub async fn memory() -> Result<Self> {
44        Self::new("sqlite::memory:").await
45    }
46
47    /// Create a file-based SQLite database
48    pub async fn file(path: &str) -> Result<Self> {
49        Self::new(&format!("sqlite://{}", path)).await
50    }
51
52    /// Convert SQLite row to QueryRow
53    fn convert_row(row: &SqliteRow) -> QueryRow {
54        let mut query_row = QueryRow::new();
55
56        for column in row.columns() {
57            let column_name = column.name();
58
59            // Check for NULL explicitly first
60            if let Ok(Some(value)) = row.try_get::<Option<String>, _>(column_name) {
61                query_row.insert(column_name.to_string(), Value::String(value));
62            } else if let Ok(Some(value)) = row.try_get::<Option<i64>, _>(column_name) {
63                query_row.insert(column_name.to_string(), Value::Number(value.into()));
64            } else if let Ok(Some(value)) = row.try_get::<Option<i32>, _>(column_name) {
65                query_row.insert(column_name.to_string(), Value::Number(value.into()));
66            } else if let Ok(Some(value)) = row.try_get::<Option<f64>, _>(column_name) {
67                if let Some(num) = serde_json::Number::from_f64(value) {
68                    query_row.insert(column_name.to_string(), Value::Number(num));
69                }
70            } else if let Ok(Some(value)) = row.try_get::<Option<bool>, _>(column_name) {
71                query_row.insert(column_name.to_string(), Value::Bool(value));
72            } else if let Ok(Some(value)) = row.try_get::<Option<Vec<u8>>, _>(column_name) {
73                // Convert binary data to base64 string
74                let base64 = base64_encode(&value);
75                query_row.insert(column_name.to_string(), Value::String(base64));
76            } else if let Ok(value) = row.try_get::<Value, _>(column_name) {
77                query_row.insert(column_name.to_string(), value);
78            } else {
79                // Default to null if we can't extract the value
80                query_row.insert(column_name.to_string(), Value::Null);
81            }
82        }
83
84        query_row
85    }
86
87    /// Optimize the database (VACUUM)
88    pub async fn optimize(&mut self) -> Result<()> {
89        self.execute("VACUUM").await?;
90        Ok(())
91    }
92
93    /// Analyze the database for query optimization
94    pub async fn analyze(&mut self) -> Result<()> {
95        self.execute("ANALYZE").await?;
96        Ok(())
97    }
98
99    /// Get database size in bytes
100    pub async fn database_size(&mut self) -> Result<i64> {
101        let result = sqlx::query("SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()")
102            .fetch_one(&self.pool)
103            .await?;
104
105        Ok(result.get::<i64, _>("size"))
106    }
107
108    /// Get list of all tables
109    pub async fn list_tables(&mut self) -> Result<Vec<String>> {
110        let rows = sqlx::query("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
111            .fetch_all(&self.pool)
112            .await?;
113
114        Ok(rows.iter().map(|row| row.get::<String, _>("name")).collect())
115    }
116
117    /// Get table info
118    pub async fn table_info(&mut self, table_name: &str) -> Result<Vec<ColumnInfo>> {
119        let query = format!("PRAGMA table_info({})", table_name);
120        let rows = sqlx::query(&query).fetch_all(&self.pool).await?;
121
122        let mut columns = Vec::new();
123        for row in rows {
124            columns.push(ColumnInfo {
125                cid: row.get::<i32, _>("cid"),
126                name: row.get::<String, _>("name"),
127                type_name: row.get::<String, _>("type"),
128                not_null: row.get::<i32, _>("notnull") != 0,
129                default_value: row.try_get::<Option<String>, _>("dflt_value").ok().flatten(),
130                primary_key: row.get::<i32, _>("pk") != 0,
131            });
132        }
133
134        Ok(columns)
135    }
136
137    /// Checkpoint the WAL file
138    pub async fn checkpoint(&mut self) -> Result<()> {
139        self.execute("PRAGMA wal_checkpoint(TRUNCATE)").await?;
140        Ok(())
141    }
142}
143
/// Column information from `PRAGMA table_info`.
///
/// Each value describes one column of a table, as returned by
/// `SqliteBackend::table_info`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnInfo {
    /// Column position within the table (`cid`).
    pub cid: i32,
    /// Column name.
    pub name: String,
    /// Declared type as written in the table definition (may be empty).
    pub type_name: String,
    /// Whether the column is declared `NOT NULL`.
    pub not_null: bool,
    /// Default-value expression as text, if the column has one.
    pub default_value: Option<String>,
    /// Whether the column is part of the primary key.
    pub primary_key: bool,
}
154
155#[async_trait]
156impl DatabaseBackend for SqliteBackend {
157    async fn connect(url: &str) -> Result<Self> {
158        Self::new(url).await
159    }
160
161    async fn execute(&mut self, sql: &str) -> Result<u64> {
162        let result = sqlx::query(sql).execute(&self.pool).await?;
163        Ok(result.rows_affected())
164    }
165
166    async fn query(&mut self, sql: &str) -> Result<QueryResult> {
167        let rows = sqlx::query(sql).fetch_all(&self.pool).await?;
168
169        let result = rows.iter().map(Self::convert_row).collect();
170
171        Ok(result)
172    }
173
174    async fn begin_transaction(&mut self) -> Result<()> {
175        if self.in_transaction {
176            return Err(QueryError::Transaction(
177                "Already in transaction".to_string(),
178            ));
179        }
180
181        self.execute("BEGIN TRANSACTION").await?;
182        self.in_transaction = true;
183        Ok(())
184    }
185
186    async fn commit(&mut self) -> Result<()> {
187        if !self.in_transaction {
188            return Err(QueryError::Transaction("Not in transaction".to_string()));
189        }
190
191        self.execute("COMMIT").await?;
192        self.in_transaction = false;
193        Ok(())
194    }
195
196    async fn rollback(&mut self) -> Result<()> {
197        if !self.in_transaction {
198            return Err(QueryError::Transaction("Not in transaction".to_string()));
199        }
200
201        self.execute("ROLLBACK").await?;
202        self.in_transaction = false;
203        Ok(())
204    }
205
206    fn is_connected(&self) -> bool {
207        !self.pool.is_closed()
208    }
209
210    async fn close(self) -> Result<()> {
211        self.pool.close().await;
212        Ok(())
213    }
214}
215
/// Encode `bytes` as standard base64 (`A-Z a-z 0-9 + /`) with `=` padding.
fn base64_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Every started group of 3 input bytes produces exactly 4 output chars.
    let mut encoded = String::with_capacity((bytes.len() + 2) / 3 * 4);

    for chunk in bytes.chunks(3) {
        // Missing trailing bytes of the final group are treated as zero.
        let first = chunk[0];
        let second = chunk.get(1).copied().unwrap_or(0);
        let third = chunk.get(2).copied().unwrap_or(0);
        let triple = (u32::from(first) << 16) | (u32::from(second) << 8) | u32::from(third);

        let sextet = |shift: u32| char::from(ALPHABET[((triple >> shift) & 0x3F) as usize]);

        encoded.push(sextet(18));
        encoded.push(sextet(12));
        encoded.push(if chunk.len() > 1 { sextet(6) } else { '=' });
        encoded.push(if chunk.len() > 2 { sextet(0) } else { '=' });
    }

    encoded
}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_sqlite_memory() {
        let backend = SqliteBackend::memory().await;
        assert!(backend.is_ok());
    }

    #[tokio::test]
    async fn test_sqlite_is_connected_and_close() {
        let backend = SqliteBackend::memory().await.unwrap();
        assert!(backend.is_connected());
        backend.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_sqlite_create_table() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
            .await
            .unwrap();

        let tables = backend.list_tables().await.unwrap();
        assert!(tables.contains(&"test_table".to_string()));
    }

    #[tokio::test]
    async fn test_sqlite_insert_query() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
            .await
            .unwrap();

        let results = backend.query("SELECT * FROM users ORDER BY id").await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].get_i64("id"), Some(1));
        assert_eq!(results[0].get_string("name"), Some("Alice".to_string()));
        assert_eq!(results[0].get_i64("age"), Some(30));

        assert_eq!(results[1].get_i64("id"), Some(2));
        assert_eq!(results[1].get_string("name"), Some("Bob".to_string()));
        assert_eq!(results[1].get_i64("age"), Some(25));
    }

    #[tokio::test]
    async fn test_sqlite_transaction() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance INTEGER)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO accounts (id, balance) VALUES (1, 100)")
            .await
            .unwrap();

        // NOTE: BEGIN/ROLLBACK are sent through the backend's pool. Every
        // statement here runs sequentially, so they are expected to share one
        // pooled connection.

        // Transaction state tracking
        assert!(!backend.in_transaction);
        backend.begin_transaction().await.unwrap();
        assert!(backend.in_transaction);

        // Can't begin another transaction while one is active
        assert!(backend.begin_transaction().await.is_err());

        // Changes made inside the transaction are discarded by a rollback, which
        // also leaves no open BEGIN behind.
        backend
            .execute("UPDATE accounts SET balance = 50 WHERE id = 1")
            .await
            .unwrap();
        backend.rollback().await.unwrap();
        assert!(!backend.in_transaction);

        // Ending a transaction that isn't active is an error
        assert!(backend.commit().await.is_err());
        assert!(backend.rollback().await.is_err());

        let results = backend
            .query("SELECT balance FROM accounts WHERE id = 1")
            .await
            .unwrap();
        assert_eq!(results[0].get_i64("balance"), Some(100));
    }

    #[tokio::test]
    async fn test_sqlite_table_info() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, price REAL)")
            .await
            .unwrap();

        let info = backend.table_info("products").await.unwrap();

        assert_eq!(info.len(), 3);
        assert_eq!(info[0].name, "id");
        assert!(info[0].primary_key);
        assert_eq!(info[1].name, "name");
        assert!(info[1].not_null);
        assert_eq!(info[2].name, "price");
    }

    #[test]
    fn test_base64_encode() {
        assert_eq!(base64_encode(b"hello"), "aGVsbG8=");
        assert_eq!(base64_encode(b"hello world"), "aGVsbG8gd29ybGQ=");
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(&[0, 1, 2, 3, 4, 5]), "AAECAwQF");
    }

    #[tokio::test]
    async fn test_sqlite_blob() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE files (id INTEGER PRIMARY KEY, data BLOB)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO files (id, data) VALUES (1, X'48656c6c6f')")
            .await
            .unwrap();

        let results = backend.query("SELECT data FROM files WHERE id = 1").await.unwrap();

        assert_eq!(results.len(), 1);
        // The blob bytes spell "Hello" and are returned as base64 text.
        assert_eq!(results[0].get_string("data"), Some("SGVsbG8=".to_string()));
    }

    #[tokio::test]
    async fn test_sqlite_null_values() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE nullable_test (id INTEGER PRIMARY KEY, value TEXT)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO nullable_test (id, value) VALUES (1, NULL)")
            .await
            .unwrap();

        let results = backend.query("SELECT * FROM nullable_test WHERE id = 1").await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].get_string("value"), None);
    }
}