1#[cfg(feature = "sqlite")]
2use crate::backend::{DatabaseBackend, QueryResult, QueryRow};
3use crate::error::{QueryError, Result};
4use async_trait::async_trait;
5use serde_json::Value;
6use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions, SqliteRow};
7use sqlx::{Row, Column};
8use std::str::FromStr;
9
/// SQLite implementation of [`DatabaseBackend`], built on an `sqlx` connection pool.
pub struct SqliteBackend {
    /// Pool that every statement issued by this backend is executed against.
    pool: SqlitePool,
    /// Bookkeeping flag: set by `begin_transaction`, cleared by `commit` and `rollback`,
    /// and used by those methods to reject out-of-order calls.
    in_transaction: bool,
}
14
15impl SqliteBackend {
16 pub async fn new(url: &str) -> Result<Self> {
18 let options = SqliteConnectOptions::from_str(url)?
19 .create_if_missing(true);
20
21 let pool = SqlitePoolOptions::new()
22 .max_connections(5)
23 .connect_with(options)
24 .await?;
25
26 sqlx::query("PRAGMA foreign_keys = ON")
28 .execute(&pool)
29 .await?;
30
31 sqlx::query("PRAGMA journal_mode = WAL")
33 .execute(&pool)
34 .await?;
35
36 Ok(Self {
37 pool,
38 in_transaction: false,
39 })
40 }
41
42 pub async fn memory() -> Result<Self> {
44 Self::new("sqlite::memory:").await
45 }
46
47 pub async fn file(path: &str) -> Result<Self> {
49 Self::new(&format!("sqlite://{}", path)).await
50 }
51
52 fn convert_row(row: &SqliteRow) -> QueryRow {
54 let mut query_row = QueryRow::new();
55
56 for column in row.columns() {
57 let column_name = column.name();
58
59 if let Ok(Some(value)) = row.try_get::<Option<String>, _>(column_name) {
61 query_row.insert(column_name.to_string(), Value::String(value));
62 } else if let Ok(Some(value)) = row.try_get::<Option<i64>, _>(column_name) {
63 query_row.insert(column_name.to_string(), Value::Number(value.into()));
64 } else if let Ok(Some(value)) = row.try_get::<Option<i32>, _>(column_name) {
65 query_row.insert(column_name.to_string(), Value::Number(value.into()));
66 } else if let Ok(Some(value)) = row.try_get::<Option<f64>, _>(column_name) {
67 if let Some(num) = serde_json::Number::from_f64(value) {
68 query_row.insert(column_name.to_string(), Value::Number(num));
69 }
70 } else if let Ok(Some(value)) = row.try_get::<Option<bool>, _>(column_name) {
71 query_row.insert(column_name.to_string(), Value::Bool(value));
72 } else if let Ok(Some(value)) = row.try_get::<Option<Vec<u8>>, _>(column_name) {
73 let base64 = base64_encode(&value);
75 query_row.insert(column_name.to_string(), Value::String(base64));
76 } else if let Ok(value) = row.try_get::<Value, _>(column_name) {
77 query_row.insert(column_name.to_string(), value);
78 } else {
79 query_row.insert(column_name.to_string(), Value::Null);
81 }
82 }
83
84 query_row
85 }
86
87 pub async fn optimize(&mut self) -> Result<()> {
89 self.execute("VACUUM").await?;
90 Ok(())
91 }
92
93 pub async fn analyze(&mut self) -> Result<()> {
95 self.execute("ANALYZE").await?;
96 Ok(())
97 }
98
99 pub async fn database_size(&mut self) -> Result<i64> {
101 let result = sqlx::query("SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()")
102 .fetch_one(&self.pool)
103 .await?;
104
105 Ok(result.get::<i64, _>("size"))
106 }
107
108 pub async fn list_tables(&mut self) -> Result<Vec<String>> {
110 let rows = sqlx::query("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
111 .fetch_all(&self.pool)
112 .await?;
113
114 Ok(rows.iter().map(|row| row.get::<String, _>("name")).collect())
115 }
116
117 pub async fn table_info(&mut self, table_name: &str) -> Result<Vec<ColumnInfo>> {
119 let query = format!("PRAGMA table_info({})", table_name);
120 let rows = sqlx::query(&query).fetch_all(&self.pool).await?;
121
122 let mut columns = Vec::new();
123 for row in rows {
124 columns.push(ColumnInfo {
125 cid: row.get::<i32, _>("cid"),
126 name: row.get::<String, _>("name"),
127 type_name: row.get::<String, _>("type"),
128 not_null: row.get::<i32, _>("notnull") != 0,
129 default_value: row.try_get::<Option<String>, _>("dflt_value").ok().flatten(),
130 primary_key: row.get::<i32, _>("pk") != 0,
131 });
132 }
133
134 Ok(columns)
135 }
136
137 pub async fn checkpoint(&mut self) -> Result<()> {
139 self.execute("PRAGMA wal_checkpoint(TRUNCATE)").await?;
140 Ok(())
141 }
142}
143
/// Description of a single table column, as produced by `SqliteBackend::table_info`.
#[derive(Debug, Clone)]
pub struct ColumnInfo {
    /// Value of the `cid` column in the table-info result.
    pub cid: i32,
    /// Column name (`name`).
    pub name: String,
    /// Declared type text (`type`).
    pub type_name: String,
    /// `true` when the `notnull` value is non-zero.
    pub not_null: bool,
    /// Text of `dflt_value`, or `None` when it is NULL or cannot be read as text.
    pub default_value: Option<String>,
    /// `true` when the `pk` value is non-zero.
    pub primary_key: bool,
}
154
155#[async_trait]
156impl DatabaseBackend for SqliteBackend {
157 async fn connect(url: &str) -> Result<Self> {
158 Self::new(url).await
159 }
160
161 async fn execute(&mut self, sql: &str) -> Result<u64> {
162 let result = sqlx::query(sql).execute(&self.pool).await?;
163 Ok(result.rows_affected())
164 }
165
166 async fn query(&mut self, sql: &str) -> Result<QueryResult> {
167 let rows = sqlx::query(sql).fetch_all(&self.pool).await?;
168
169 let result = rows.iter().map(Self::convert_row).collect();
170
171 Ok(result)
172 }
173
174 async fn begin_transaction(&mut self) -> Result<()> {
175 if self.in_transaction {
176 return Err(QueryError::Transaction(
177 "Already in transaction".to_string(),
178 ));
179 }
180
181 self.execute("BEGIN TRANSACTION").await?;
182 self.in_transaction = true;
183 Ok(())
184 }
185
186 async fn commit(&mut self) -> Result<()> {
187 if !self.in_transaction {
188 return Err(QueryError::Transaction("Not in transaction".to_string()));
189 }
190
191 self.execute("COMMIT").await?;
192 self.in_transaction = false;
193 Ok(())
194 }
195
196 async fn rollback(&mut self) -> Result<()> {
197 if !self.in_transaction {
198 return Err(QueryError::Transaction("Not in transaction".to_string()));
199 }
200
201 self.execute("ROLLBACK").await?;
202 self.in_transaction = false;
203 Ok(())
204 }
205
206 fn is_connected(&self) -> bool {
207 !self.pool.is_closed()
208 }
209
210 async fn close(self) -> Result<()> {
211 self.pool.close().await;
212 Ok(())
213 }
214}
215
/// Encodes `bytes` as Base64 using the `A-Z a-z 0-9 + /` alphabet with `=` padding.
///
/// Input is consumed three bytes at a time; each group is packed into 24 bits and
/// emitted as four 6-bit symbols. A trailing group of one or two bytes is
/// zero-extended and its unused symbol positions are written as `=`.
fn base64_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const PAD: char = '=';

    // Maps the low six bits of `sextet` to its alphabet character.
    let symbol = |sextet: u32| char::from(ALPHABET[(sextet & 0x3f) as usize]);

    let mut encoded = String::new();

    for group in bytes.chunks(3) {
        let first = u32::from(group[0]);
        let second = group.get(1).copied().map_or(0, u32::from);
        let third = group.get(2).copied().map_or(0, u32::from);
        let packed = (first << 16) | (second << 8) | third;

        encoded.push(symbol(packed >> 18));
        encoded.push(symbol(packed >> 12));
        encoded.push(if group.len() > 1 {
            symbol(packed >> 6)
        } else {
            PAD
        });
        encoded.push(if group.len() > 2 { symbol(packed) } else { PAD });
    }

    encoded
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// An in-memory backend can be opened.
    #[tokio::test]
    async fn test_sqlite_memory() {
        let backend = SqliteBackend::memory().await;
        assert!(backend.is_ok());
    }

    /// A created table shows up in `list_tables`.
    #[tokio::test]
    async fn test_sqlite_create_table() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
            .await
            .unwrap();

        let tables = backend.list_tables().await.unwrap();
        assert!(tables.contains(&"test_table".to_string()));
    }

    /// Inserted rows come back in order with integer and text columns decoded.
    #[tokio::test]
    async fn test_sqlite_insert_query() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO users (id, name, age) VALUES (2, 'Bob', 25)")
            .await
            .unwrap();

        let results = backend.query("SELECT * FROM users ORDER BY id").await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].get_i64("id"), Some(1));
        assert_eq!(results[0].get_string("name"), Some("Alice".to_string()));
        assert_eq!(results[0].get_i64("age"), Some(30));

        assert_eq!(results[1].get_i64("id"), Some(2));
        assert_eq!(results[1].get_string("name"), Some("Bob".to_string()));
        assert_eq!(results[1].get_i64("age"), Some(25));
    }

    /// `begin_transaction` sets the flag and refuses to nest.
    #[tokio::test]
    async fn test_sqlite_transaction() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance INTEGER)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO accounts (id, balance) VALUES (1, 100)")
            .await
            .unwrap();

        assert!(!backend.in_transaction);
        backend.begin_transaction().await.unwrap();
        assert!(backend.in_transaction);

        // A second BEGIN must be rejected while the flag is set.
        assert!(backend.begin_transaction().await.is_err());

        // This test never commits or rolls back; it clears the flag by hand and then
        // checks that the seeded row is still readable.
        backend.in_transaction = false;

        let results = backend.query("SELECT balance FROM accounts WHERE id = 1").await.unwrap();
        assert_eq!(results[0].get_i64("balance"), Some(100));
    }

    /// `table_info` reports names, primary-key and NOT NULL flags in column order.
    #[tokio::test]
    async fn test_sqlite_table_info() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, price REAL)")
            .await
            .unwrap();

        let info = backend.table_info("products").await.unwrap();

        assert_eq!(info.len(), 3);
        assert_eq!(info[0].name, "id");
        assert!(info[0].primary_key);
        assert_eq!(info[1].name, "name");
        assert!(info[1].not_null);
        assert_eq!(info[2].name, "price");
    }

    /// Known Base64 vectors, covering zero, one and two padding characters.
    #[test]
    fn test_base64_encode() {
        assert_eq!(base64_encode(b"hello"), "aGVsbG8=");
        assert_eq!(base64_encode(b"hello world"), "aGVsbG8gd29ybGQ=");
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(&[0, 1, 2, 3, 4, 5]), "AAECAwQF");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
    }

    /// Blob columns are returned as the Base64 encoding of their bytes.
    #[tokio::test]
    async fn test_sqlite_blob() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE files (id INTEGER PRIMARY KEY, data BLOB)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO files (id, data) VALUES (1, X'48656c6c6f')")
            .await
            .unwrap();

        let results = backend.query("SELECT data FROM files WHERE id = 1").await.unwrap();

        assert_eq!(results.len(), 1);
        let data_str = results[0].get_string("data").unwrap();
        // X'48656c6c6f' is the ASCII text "Hello"; pin the exact encoding rather than
        // only checking that something came back.
        assert_eq!(data_str, "SGVsbG8=");
    }

    /// SQL NULL is not reported as a string.
    #[tokio::test]
    async fn test_sqlite_null_values() {
        let mut backend = SqliteBackend::memory().await.unwrap();

        backend
            .execute("CREATE TABLE nullable_test (id INTEGER PRIMARY KEY, value TEXT)")
            .await
            .unwrap();

        backend
            .execute("INSERT INTO nullable_test (id, value) VALUES (1, NULL)")
            .await
            .unwrap();

        let results = backend.query("SELECT * FROM nullable_test WHERE id = 1").await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].get_string("value"), None);
    }
}