// victauri_plugin/database.rs

#[cfg(feature = "sqlite")]
use std::path::{Path, PathBuf};

/// Number of rows returned by [`query`] when the caller does not request a limit.
#[cfg(feature = "sqlite")]
const MAX_ROWS_DEFAULT: usize = 100;

/// Upper bound on the number of rows [`query`] will return; larger requests
/// are clamped to this value.
#[cfg(feature = "sqlite")]
const MAX_ROWS_LIMIT: usize = 10_000;

/// Lowercase leading keywords accepted by `is_read_only`.
#[cfg(feature = "sqlite")]
const READ_ONLY_PREFIXES: &[&str] = &["select", "pragma", "explain", "with"];

/// Return `sql` with SQL comments removed, for use by the read-only checks.
///
/// `-- …` comments are dropped up to (not including) the next newline. Each
/// `/* … */` comment is replaced by a single space so adjacent tokens stay
/// separated. As in `SQLite`, an unterminated block comment extends to the end
/// of the input.
///
/// Text inside `'…'`, `"…"`, `` `…` `` and `[…]` is copied verbatim. Comment
/// markers (and `;`) inside string literals or quoted identifiers are
/// therefore preserved. Non-ASCII text is copied unchanged.
#[cfg(feature = "sqlite")]
fn strip_sql_comments(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut result = String::with_capacity(len);
    // Every position `i` is reached either at 0 or right after/at an ASCII
    // byte, so it is always a UTF-8 char boundary and `&sql[..]` slicing is safe.
    let mut i = 0;
    while i < len {
        match bytes[i] {
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Line comment: skip to the newline, which is kept.
                i = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |off| i + off);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Block comment: skip past the closing `*/`, or to end of input.
                i = bytes[i + 2..]
                    .windows(2)
                    .position(|w| w == b"*/")
                    .map_or(len, |off| i + 2 + off + 2);
                result.push(' ');
            }
            open @ (b'\'' | b'"' | b'`' | b'[') => {
                // Quoted literal/identifier: copy through the closing delimiter.
                // A doubled quote (`''`) simply closes and reopens, which is
                // still copied verbatim.
                let close = if open == b'[' { b']' } else { open };
                let end = bytes[i + 1..]
                    .iter()
                    .position(|&b| b == close)
                    .map_or(len, |off| i + 1 + off + 1);
                result.push_str(&sql[i..end]);
                i = end;
            }
            _ => {
                // Plain text: copy up to the next byte that could start a
                // comment or a quoted region.
                let start = i;
                i += 1;
                while i < len && !matches!(bytes[i], b'-' | b'/' | b'\'' | b'"' | b'`' | b'[') {
                    i += 1;
                }
                result.push_str(&sql[start..i]);
            }
        }
    }
    result
}

/// Returns true if `sql` looks like a read-only statement.
///
/// Comments and leading whitespace are removed first. The remaining text
/// must then begin, case-insensitively, with one of `READ_ONLY_PREFIXES`.
/// Empty and comment-only input is rejected.
///
/// Only the leading text is inspected, so e.g. `WITH … DELETE …` passes this
/// check. `query` additionally opens the connection read-only.
#[cfg(feature = "sqlite")]
fn is_read_only(sql: &str) -> bool {
    let lowered = strip_sql_comments(sql).trim_start().to_lowercase();
    // An empty string starts with none of the (non-empty) prefixes, so empty
    // and comment-only input yields `false` without a special case.
    READ_ONLY_PREFIXES
        .iter()
        .copied()
        .any(|keyword| lowered.starts_with(keyword))
}

/// Returns true if `sql` is the write form of a PRAGMA (`PRAGMA name = value`).
///
/// Comments are stripped first, and the `pragma` keyword is matched
/// case-insensitively. The read forms (`PRAGMA name`, `PRAGMA name(arg)`) are
/// not flagged.
///
/// An `=` only counts when it appears outside any quoted string or quoted
/// identifier (`'…'`, `"…"`, `` `…` `` or `[…]`).
///
/// NOTE: `SQLite` also accepts `PRAGMA name(value)` as a setter for many
/// pragmas. That form cannot be told apart from a read with an argument here
/// and is not detected. It is left to the read-only connection to reject.
#[cfg(feature = "sqlite")]
fn is_pragma_write(sql: &str) -> bool {
    const KEYWORD: &str = "pragma";

    let cleaned = strip_sql_comments(sql);
    let trimmed = cleaned.trim_start();
    let is_pragma = trimmed
        .get(..KEYWORD.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(KEYWORD));
    if !is_pragma {
        return false;
    }

    // Byte that closes the quoted region we are currently inside, if any.
    // All delimiters are ASCII, so scanning raw bytes is UTF-8 safe.
    let mut closing: Option<u8> = None;
    for &b in trimmed.as_bytes() {
        match closing {
            Some(close) => {
                if b == close {
                    closing = None;
                }
            }
            None => match b {
                b'\'' | b'"' | b'`' => closing = Some(b),
                b'[' => closing = Some(b']'),
                b'=' => return true,
                _ => {}
            },
        }
    }
    false
}

/// Discover `SQLite` database files under `dir`.
///
/// Files directly in `dir` and in subdirectories up to two levels below it
/// are considered. A file matches by its `SQLite`-style extension (`sqlite`,
/// `sqlite3`, `db`, `sdb`). Symbolic links are not followed, and directories
/// that cannot be read are skipped.
///
/// The returned paths are sorted, so the result does not depend on the
/// platform's directory iteration order.
#[cfg(feature = "sqlite")]
#[must_use]
pub fn discover_databases(dir: &Path) -> Vec<PathBuf> {
    let mut results = Vec::new();
    discover_recursive(dir, 0, 2, &mut results);
    // `read_dir` yields entries in an unspecified, platform-dependent order.
    results.sort();
    results
}

/// Collect database files from `dir` into `results`. Recurses into
/// subdirectories while `depth < max_depth`.
///
/// Each entry's own file type is used, not its link target's, so symbolic
/// links are skipped rather than followed. Entries or directories that
/// cannot be read are ignored.
///
/// Extensions (`sqlite`, `sqlite3`, `db`, `sdb`) are matched
/// ASCII-case-insensitively, so e.g. `App.DB` is found too.
#[cfg(feature = "sqlite")]
fn discover_recursive(dir: &Path, depth: u32, max_depth: u32, results: &mut Vec<PathBuf>) {
    const DB_EXTENSIONS: [&str; 4] = ["sqlite", "sqlite3", "db", "sdb"];

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.flatten() {
        // `DirEntry::file_type` does not follow symlinks. On most platforms
        // it also needs no extra syscall, unlike `Path::is_file`/`is_dir`.
        let Ok(file_type) = entry.file_type() else {
            continue;
        };
        if file_type.is_file() {
            let path = entry.path();
            let is_db = path
                .extension()
                .and_then(|e| e.to_str())
                .is_some_and(|ext| {
                    DB_EXTENSIONS
                        .iter()
                        .any(|known| ext.eq_ignore_ascii_case(known))
                });
            if is_db {
                results.push(path);
            }
        } else if file_type.is_dir() && depth < max_depth {
            discover_recursive(&entry.path(), depth + 1, max_depth, results);
        }
    }
}

/// Execute a read-only SQL query against a `SQLite` database.
///
/// `params` are bound positionally (see `json_to_sql` for the JSON → `SQLite`
/// mapping). At most `max_rows` rows are returned: the default is
/// `MAX_ROWS_DEFAULT` and the value is capped at `MAX_ROWS_LIMIT`.
///
/// On success, returns a JSON object with these fields:
/// - `columns`: the column names.
/// - `rows`: one object per row, keyed by column name.
/// - `row_count`: the number of rows returned.
/// - `truncated`: true only if the query produced more rows than were
///   returned.
/// - `max_rows`: the effective row limit.
///
/// # Errors
///
/// Returns an error if any of the following holds:
/// - the query does not start with an allowed read-only keyword;
/// - it is a `PRAGMA name = value` write;
/// - it contains more than one statement;
/// - the database cannot be opened;
/// - preparing or executing the query fails, including a write rejected by
///   the read-only connection.
#[cfg(feature = "sqlite")]
pub fn query(
    db_path: &Path,
    sql: &str,
    params: &[serde_json::Value],
    max_rows: Option<usize>,
) -> Result<serde_json::Value, String> {
    if !is_read_only(sql) {
        return Err(
            "only SELECT, PRAGMA, EXPLAIN, and WITH queries are allowed (read-only access)"
                .to_string(),
        );
    }

    // Defence in depth: the connection is opened READ_ONLY, so SQLite rejects
    // actual writes. We still explicitly reject the write form of PRAGMA
    // (`PRAGMA name = value`), so the read-only contract is self-evident and
    // does not rely solely on the open flags. The read forms `PRAGMA name` and
    // `PRAGMA name(arg)` remain allowed.
    if is_pragma_write(sql) {
        return Err(
            "PRAGMA writes (PRAGMA name = value) are not allowed (read-only access)".to_string(),
        );
    }

    if has_stacked_statements(&strip_sql_comments(sql)) {
        return Err(
            "stacked queries (multiple statements separated by ;) are not allowed".to_string(),
        );
    }

    let max_rows = max_rows.unwrap_or(MAX_ROWS_DEFAULT).min(MAX_ROWS_LIMIT);

    let conn = rusqlite::Connection::open_with_flags(
        db_path,
        rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX,
    )
    .map_err(|e| format!("failed to open database: {e}"))?;

    // Wait up to 5 s for locks held by other connections before failing with
    // SQLITE_BUSY. This does not limit how long the query itself may run.
    conn.busy_timeout(std::time::Duration::from_secs(5))
        .map_err(|e| format!("failed to set timeout: {e}"))?;

    let mut stmt = conn
        .prepare(sql)
        .map_err(|e| format!("failed to prepare query: {e}"))?;

    let column_names: Vec<String> = stmt
        .column_names()
        .iter()
        .map(|s| (*s).to_string())
        .collect();

    let sqlite_params: Vec<Box<dyn rusqlite::types::ToSql>> =
        params.iter().map(json_to_sql).collect();
    let param_refs: Vec<&dyn rusqlite::types::ToSql> = sqlite_params.iter().map(|b| &**b).collect();

    let mut rows = stmt
        .query(param_refs.as_slice())
        .map_err(|e| format!("query execution failed: {e}"))?;

    let mut rows_out: Vec<serde_json::Value> = Vec::new();
    let mut truncated = false;
    while let Some(row) = rows.next().map_err(|e| format!("row read failed: {e}"))? {
        if rows_out.len() >= max_rows {
            // A row exists beyond the limit, so the result really is truncated.
            truncated = true;
            break;
        }
        let obj: serde_json::Map<String, serde_json::Value> = column_names
            .iter()
            .enumerate()
            .map(|(i, col_name)| (col_name.clone(), row_value_to_json(row, i)))
            .collect();
        rows_out.push(serde_json::Value::Object(obj));
    }

    Ok(serde_json::json!({
        "columns": column_names,
        "rows": rows_out,
        "row_count": rows_out.len(),
        "truncated": truncated,
        "max_rows": max_rows,
    }))
}

/// Returns true if `sql` (already stripped of comments) contains more than
/// one non-empty statement.
///
/// A `;` only separates statements when it appears outside `'…'`, `"…"`,
/// `` `…` `` and `[…]`. A trailing `;` (or empty statements such as `;;`)
/// is therefore accepted, while `SELECT ';'` is not mistaken for two
/// statements.
#[cfg(feature = "sqlite")]
fn has_stacked_statements(sql: &str) -> bool {
    let mut closing: Option<char> = None;
    let mut statements = 0usize;
    let mut current_has_content = false;
    for ch in sql.chars() {
        if let Some(close) = closing {
            if ch == close {
                closing = None;
            }
            continue;
        }
        match ch {
            ';' => {
                if current_has_content {
                    statements += 1;
                    current_has_content = false;
                }
            }
            '\'' | '"' | '`' => {
                closing = Some(ch);
                current_has_content = true;
            }
            '[' => {
                closing = Some(']');
                current_has_content = true;
            }
            c if c.is_whitespace() => {}
            _ => current_has_content = true,
        }
    }
    if current_has_content {
        statements += 1;
    }
    statements > 1
}

/// Convert a JSON value into a boxed `SQLite` bind parameter.
///
/// The mapping is:
/// - `null` → NULL.
/// - Booleans → bound via rusqlite's `bool` implementation.
/// - Numbers → INTEGER when they fit in `i64`, otherwise REAL, falling back
///   to their textual form if neither conversion applies.
/// - Strings → TEXT.
/// - Arrays and objects → their serialized JSON text.
#[cfg(feature = "sqlite")]
fn json_to_sql(val: &serde_json::Value) -> Box<dyn rusqlite::types::ToSql> {
    use serde_json::Value;

    match val {
        Value::Null => Box::new(rusqlite::types::Null),
        Value::Bool(flag) => Box::new(*flag),
        Value::Number(num) => match (num.as_i64(), num.as_f64()) {
            (Some(int), _) => Box::new(int),
            (None, Some(real)) => Box::new(real),
            (None, None) => Box::new(num.to_string()),
        },
        Value::String(text) => Box::new(text.clone()),
        Value::Array(_) | Value::Object(_) => Box::new(val.to_string()),
    }
}

/// Convert column `idx` of `row` into a JSON value.
///
/// The mapping is:
/// - NULL → `null`.
/// - INTEGER and REAL → JSON numbers.
/// - BLOB → an object `{"__blob": true, "size": <byte length>,
///   "base64": <standard base64>}`.
/// - TEXT → decoded lossily as UTF-8. It is returned parsed if the text is
///   itself a JSON object or array, and as a JSON string otherwise.
///
/// If the column cannot be read, the result is `null`.
#[cfg(feature = "sqlite")]
fn row_value_to_json(row: &rusqlite::Row, idx: usize) -> serde_json::Value {
    use rusqlite::types::ValueRef;

    let Ok(cell) = row.get_ref(idx) else {
        return serde_json::Value::Null;
    };
    match cell {
        ValueRef::Null => serde_json::Value::Null,
        ValueRef::Integer(int) => serde_json::json!(int),
        ValueRef::Real(real) => serde_json::json!(real),
        ValueRef::Text(raw) => {
            let text = String::from_utf8_lossy(raw);
            // Only text whose first non-whitespace char is `{` or `[` can parse
            // to an object or array, so skip the JSON parser for everything else.
            let looks_structured =
                matches!(text.trim_start().as_bytes().first(), Some(b'{' | b'['));
            let structured = looks_structured
                .then(|| serde_json::from_str::<serde_json::Value>(&text).ok())
                .flatten()
                .filter(|parsed| parsed.is_object() || parsed.is_array());
            structured.unwrap_or_else(|| serde_json::Value::String(text.into_owned()))
        }
        ValueRef::Blob(bytes) => {
            use base64::Engine;
            let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
            serde_json::json!({
                "__blob": true,
                "size": bytes.len(),
                "base64": encoded,
            })
        }
    }
}

#[cfg(all(test, feature = "sqlite"))]
mod tests {
    use super::*;

    /// Create a temporary database with a `users` table holding three rows.
    /// The returned `NamedTempFile` must be kept alive for the file to persist.
    fn create_test_db() -> (tempfile::NamedTempFile, PathBuf) {
        let file = tempfile::NamedTempFile::with_suffix(".sqlite").unwrap();
        let path = file.path().to_path_buf();
        let conn = rusqlite::Connection::open(&path).unwrap();
        conn.execute_batch(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, score REAL);
             INSERT INTO users VALUES (1, 'Alice', 95.5);
             INSERT INTO users VALUES (2, 'Bob', 87.0);
             INSERT INTO users VALUES (3, 'Charlie', 92.3);",
        )
        .unwrap();
        (file, path)
    }

    #[test]
    fn select_all_rows() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT * FROM users", &[], None).unwrap();
        assert_eq!(result["row_count"], 3);
        assert_eq!(
            result["columns"],
            serde_json::json!(["id", "name", "score"])
        );
        assert_eq!(result["rows"][0]["name"], "Alice");
        assert_eq!(result["rows"][1]["name"], "Bob");
    }

    #[test]
    fn select_with_params() {
        let (_f, path) = create_test_db();
        let result = query(
            &path,
            "SELECT name FROM users WHERE score > ?",
            &[serde_json::json!(90.0)],
            None,
        )
        .unwrap();
        assert_eq!(result["row_count"], 2);
    }

    #[test]
    fn max_rows_truncation() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT * FROM users", &[], Some(2)).unwrap();
        assert_eq!(result["row_count"], 2);
        assert_eq!(result["truncated"], true);
    }

    #[test]
    fn max_rows_zero_returns_no_rows() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT * FROM users", &[], Some(0)).unwrap();
        assert_eq!(result["row_count"], 0);
        assert_eq!(result["max_rows"], 0);
        assert_eq!(result["truncated"], true);
    }

    #[test]
    fn max_rows_is_capped_at_limit() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT * FROM users", &[], Some(usize::MAX)).unwrap();
        assert_eq!(result["max_rows"], MAX_ROWS_LIMIT);
        assert_eq!(result["row_count"], 3);
        assert_eq!(result["truncated"], false);
    }

    #[test]
    fn rejects_insert() {
        let (_f, path) = create_test_db();
        let err = query(
            &path,
            "INSERT INTO users VALUES (4, 'Eve', 99.0)",
            &[],
            None,
        )
        .unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_delete() {
        let (_f, path) = create_test_db();
        let err = query(&path, "DELETE FROM users", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_drop() {
        let (_f, path) = create_test_db();
        let err = query(&path, "DROP TABLE users", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_update() {
        let (_f, path) = create_test_db();
        let err = query(&path, "UPDATE users SET name = 'X'", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn with_cte_dml_blocked_by_read_only_connection() {
        let (_f, path) = create_test_db();
        // `WITH` passes the keyword allow-list, so the read-only open flags are
        // what must stop this DELETE.
        assert!(query(&path, "WITH doomed AS (SELECT 1) DELETE FROM users", &[], None).is_err());
        let count = query(&path, "SELECT COUNT(*) AS n FROM users", &[], None).unwrap();
        assert_eq!(count["rows"][0]["n"], 3);
    }

    #[test]
    fn pragma_works() {
        let (_f, path) = create_test_db();
        let result = query(&path, "PRAGMA table_info(users)", &[], None).unwrap();
        assert!(result["row_count"].as_u64().unwrap() >= 3);
    }

    #[test]
    fn pragma_read_allowed() {
        let (_f, path) = create_test_db();
        assert!(query(&path, "PRAGMA journal_mode", &[], None).is_ok());
        assert!(query(&path, "PRAGMA user_version", &[], None).is_ok());
    }

    #[test]
    fn rejects_pragma_write_form() {
        let (_f, path) = create_test_db();
        for sql in [
            "PRAGMA journal_mode=DELETE",
            "PRAGMA journal_mode = WAL",
            "PRAGMA user_version=12345",
            "  pragma  synchronous = 0 ",
        ] {
            let err = query(&path, sql, &[], None).unwrap_err();
            assert!(err.contains("PRAGMA writes"), "expected block for: {sql}");
        }
    }

    #[test]
    fn is_pragma_write_ignores_equals_in_strings() {
        // A read-form PRAGMA whose argument contains '=' inside quotes is not a write.
        assert!(!is_pragma_write("PRAGMA table_info('a=b')"));
        assert!(is_pragma_write("PRAGMA foo = 'a=b'"));
        assert!(!is_pragma_write("SELECT 1 = 1"));
    }

    #[test]
    fn with_cte_works() {
        let (_f, path) = create_test_db();
        let result = query(
            &path,
            "WITH top AS (SELECT * FROM users WHERE score > 90) SELECT name FROM top",
            &[],
            None,
        )
        .unwrap();
        assert_eq!(result["row_count"], 2);
    }

    #[test]
    fn nonexistent_db_fails() {
        let err = query(Path::new("/nonexistent/db.sqlite"), "SELECT 1", &[], None).unwrap_err();
        assert!(err.contains("failed to open"));
    }

    #[test]
    fn json_column_parsed() {
        let file = tempfile::NamedTempFile::with_suffix(".sqlite").unwrap();
        let path = file.path().to_path_buf();
        let conn = rusqlite::Connection::open(&path).unwrap();
        conn.execute_batch(
            r#"CREATE TABLE config (key TEXT, value TEXT);
               INSERT INTO config VALUES ('settings', '{"theme":"dark","lang":"en"}');"#,
        )
        .unwrap();
        let result = query(&path, "SELECT * FROM config", &[], None).unwrap();
        assert!(result["rows"][0]["value"].is_object());
        assert_eq!(result["rows"][0]["value"]["theme"], "dark");
    }

    #[test]
    fn select_preserves_non_ascii_text() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT 'héllo' AS greeting", &[], None).unwrap();
        assert_eq!(result["rows"][0]["greeting"], "héllo");
    }

    #[test]
    fn discover_finds_sqlite_files() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::File::create(dir.path().join("app.sqlite")).unwrap();
        std::fs::File::create(dir.path().join("cache.db")).unwrap();
        std::fs::File::create(dir.path().join("readme.txt")).unwrap();
        let sub = dir.path().join("subdir");
        std::fs::create_dir(&sub).unwrap();
        std::fs::File::create(sub.join("deep.sqlite3")).unwrap();

        let dbs = discover_databases(dir.path());
        assert_eq!(dbs.len(), 3);
    }

    #[test]
    fn discover_respects_depth_limit() {
        let dir = tempfile::tempdir().unwrap();
        let two_deep = dir.path().join("a").join("b");
        let three_deep = two_deep.join("c");
        std::fs::create_dir_all(&three_deep).unwrap();
        std::fs::File::create(two_deep.join("kept.db")).unwrap();
        std::fs::File::create(three_deep.join("skipped.db")).unwrap();

        let dbs = discover_databases(dir.path());
        assert_eq!(dbs, vec![two_deep.join("kept.db")]);
    }

    #[test]
    fn rejects_comment_bypass_block() {
        let (_f, path) = create_test_db();
        let err = query(&path, "/* sneaky */DELETE FROM users", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_line_comment_bypass() {
        let (_f, path) = create_test_db();
        let err = query(&path, "-- comment\nDELETE FROM users", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_stacked_queries() {
        let (_f, path) = create_test_db();
        let err = query(&path, "SELECT 1; DROP TABLE users", &[], None).unwrap_err();
        assert!(err.contains("stacked queries"));
    }

    #[test]
    fn allows_trailing_semicolon() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT * FROM users;", &[], None).unwrap();
        assert_eq!(result["row_count"], 3);
    }

    #[test]
    fn allows_select_with_block_comment() {
        let (_f, path) = create_test_db();
        let result = query(
            &path,
            "/* filter */ SELECT name FROM users WHERE id = 1",
            &[],
            None,
        )
        .unwrap();
        assert_eq!(result["row_count"], 1);
        assert_eq!(result["rows"][0]["name"], "Alice");
    }

    #[test]
    fn rejects_empty_query() {
        let (_f, path) = create_test_db();
        let err = query(&path, "", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_comment_only_query() {
        let (_f, path) = create_test_db();
        let err = query(&path, "/* just a comment */", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_nested_comment_bypass() {
        let (_f, path) = create_test_db();
        let err = query(
            &path,
            "/* outer /* inner */ still comment */ DROP TABLE users",
            &[],
            None,
        )
        .unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn blob_column_base64() {
        let file = tempfile::NamedTempFile::with_suffix(".sqlite").unwrap();
        let path = file.path().to_path_buf();
        let conn = rusqlite::Connection::open(&path).unwrap();
        conn.execute_batch("CREATE TABLE blobs (id INTEGER, data BLOB)")
            .unwrap();
        conn.execute("INSERT INTO blobs VALUES (1, X'DEADBEEF')", [])
            .unwrap();
        let result = query(&path, "SELECT * FROM blobs", &[], None).unwrap();
        assert!(result["rows"][0]["data"]["__blob"].as_bool().unwrap());
        assert_eq!(result["rows"][0]["data"]["size"], 4);
    }
}
506}