Skip to main content

veta_sqlite/
lib.rs

1//! SQLite implementation of the Veta database trait.
2
3use regex::Regex;
4use rusqlite::{params, Connection, OptionalExtension};
5use std::path::Path;
6use std::sync::Mutex;
7use veta_core::{
8    get_pending_migrations, CreateNote, Database, Error, Note, NoteQuery, TagCount, UpdateNote,
9    SCHEMA_VERSION,
10};
11
12/// SQLite-backed database implementation.
13pub struct SqliteDatabase {
14    conn: Mutex<Connection>,
15}
16
17impl SqliteDatabase {
18    /// Open a database at the given path and run any pending migrations.
19    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
20        let conn = Connection::open(path).map_err(|e| Error::Database(e.to_string()))?;
21        conn.execute_batch("PRAGMA foreign_keys = ON;")
22            .map_err(|e| Error::Database(e.to_string()))?;
23        let db = Self {
24            conn: Mutex::new(conn),
25        };
26        db.run_migrations()?;
27        Ok(db)
28    }
29
30    /// Open an in-memory database and run migrations.
31    pub fn open_in_memory() -> Result<Self, Error> {
32        let conn = Connection::open_in_memory().map_err(|e| Error::Database(e.to_string()))?;
33        conn.execute_batch("PRAGMA foreign_keys = ON;")
34            .map_err(|e| Error::Database(e.to_string()))?;
35        let db = Self {
36            conn: Mutex::new(conn),
37        };
38        db.run_migrations()?;
39        Ok(db)
40    }
41
42    /// Run any pending database migrations.
43    fn run_migrations(&self) -> Result<(), Error> {
44        let conn = self.conn.lock().unwrap();
45
46        // Ensure _veta_meta table exists
47        conn.execute(
48            "CREATE TABLE IF NOT EXISTS _veta_meta (
49                key TEXT PRIMARY KEY,
50                value TEXT NOT NULL
51            )",
52            [],
53        )
54        .map_err(|e| Error::Database(e.to_string()))?;
55
56        // Get current schema version
57        let current_version: i64 = conn
58            .query_row(
59                "SELECT value FROM _veta_meta WHERE key = 'schema_version'",
60                [],
61                |row| {
62                    let val: String = row.get(0)?;
63                    Ok(val.parse().unwrap_or(0))
64                },
65            )
66            .unwrap_or(0);
67
68        // Already up to date
69        if current_version >= SCHEMA_VERSION {
70            return Ok(());
71        }
72
73        // Run pending migrations
74        for migration in get_pending_migrations(current_version) {
75            for statement in migration.statements {
76                // Skip _veta_meta creation (already done above)
77                if statement.contains("_veta_meta") {
78                    continue;
79                }
80                // ALTER TABLE doesn't support IF NOT EXISTS, so ignore errors for those
81                if statement.starts_with("ALTER TABLE") {
82                    let _ = conn.execute(statement, []);
83                } else {
84                    conn.execute(statement, []).map_err(|e| {
85                        Error::Database(format!("Migration {} failed: {}", migration.name, e))
86                    })?;
87                }
88            }
89        }
90
91        // Update schema version
92        conn.execute(
93            "INSERT OR REPLACE INTO _veta_meta (key, value) VALUES ('schema_version', ?1)",
94            params![SCHEMA_VERSION.to_string()],
95        )
96        .map_err(|e| Error::Database(e.to_string()))?;
97
98        Ok(())
99    }
100
101    fn parse_tags(tags_str: Option<String>) -> Vec<String> {
102        let mut tags: Vec<String> = tags_str
103            .map(|s| {
104                s.split(',')
105                    .map(String::from)
106                    .filter(|s| !s.is_empty())
107                    .collect()
108            })
109            .unwrap_or_default();
110        tags.sort();
111        tags
112    }
113
114    fn parse_references(refs_str: Option<String>) -> Vec<String> {
115        refs_str
116            .and_then(|s| serde_json::from_str(&s).ok())
117            .unwrap_or_default()
118    }
119
120    fn serialize_references(refs: &[String]) -> String {
121        serde_json::to_string(refs).unwrap_or_else(|_| "[]".to_string())
122    }
123}
124
125#[async_trait::async_trait(?Send)]
126impl Database for SqliteDatabase {
127    async fn add_note(&self, note: CreateNote) -> Result<i64, Error> {
128        let conn = self.conn.lock().unwrap();
129
130        let refs_json = Self::serialize_references(&note.references);
131
132        // Insert the note
133        conn.execute(
134            "INSERT INTO notes (title, body, \"references\") VALUES (?1, ?2, ?3)",
135            params![note.title, note.body, refs_json],
136        )
137        .map_err(|e| Error::Database(e.to_string()))?;
138
139        let note_id = conn.last_insert_rowid();
140
141        // Insert tags
142        for tag in &note.tags {
143            conn.execute(
144                "INSERT INTO tags (name) VALUES (?1) ON CONFLICT (name) DO NOTHING",
145                params![tag],
146            )
147            .map_err(|e| Error::Database(e.to_string()))?;
148
149            conn.execute(
150                "INSERT INTO note_tags (note_id, tag_id) SELECT ?1, id FROM tags WHERE name = ?2",
151                params![note_id, tag],
152            )
153            .map_err(|e| Error::Database(e.to_string()))?;
154        }
155
156        Ok(note_id)
157    }
158
159    async fn get_note(&self, id: i64) -> Result<Option<Note>, Error> {
160        let conn = self.conn.lock().unwrap();
161
162        let note = conn
163            .query_row(
164                "SELECT n.id, n.title, n.body, n.updated_at, n.\"references\", GROUP_CONCAT(t.name) as tags
165                 FROM notes n
166                 LEFT JOIN note_tags nt ON n.id = nt.note_id
167                 LEFT JOIN tags t ON nt.tag_id = t.id
168                 WHERE n.id = ?1
169                 GROUP BY n.id",
170                params![id],
171                |row| {
172                    Ok(Note {
173                        id: row.get(0)?,
174                        title: row.get(1)?,
175                        body: row.get(2)?,
176                        updated_at: row.get(3)?,
177                        references: Self::parse_references(row.get(4)?),
178                        tags: Self::parse_tags(row.get(5)?),
179                    })
180                },
181            )
182            .optional()
183            .map_err(|e| Error::Database(e.to_string()))?;
184
185        Ok(note)
186    }
187
188    async fn list_notes(&self, query: NoteQuery) -> Result<Vec<Note>, Error> {
189        let conn = self.conn.lock().unwrap();
190
191        let mut sql = String::from(
192            "SELECT n.id, n.title, n.body, n.updated_at, n.\"references\", GROUP_CONCAT(t.name) as tags
193             FROM notes n
194             LEFT JOIN note_tags nt ON n.id = nt.note_id
195             LEFT JOIN tags t ON nt.tag_id = t.id",
196        );
197
198        let mut conditions = Vec::new();
199        let mut params_vec: Vec<String> = Vec::new();
200
201        if let Some(ref tags) = query.tags {
202            if !tags.is_empty() {
203                let placeholders: Vec<_> = (0..tags.len()).map(|i| format!("?{}", i + 1)).collect();
204                conditions.push(format!(
205                    "n.id IN (SELECT note_id FROM note_tags nt2 
206                              JOIN tags t2 ON nt2.tag_id = t2.id 
207                              WHERE t2.name IN ({}))",
208                    placeholders.join(",")
209                ));
210                params_vec.extend(tags.clone());
211            }
212        }
213
214        if let Some(ref from) = query.from {
215            conditions.push(format!("n.updated_at >= ?{}", params_vec.len() + 1));
216            params_vec.push(from.clone());
217        }
218
219        if let Some(ref to) = query.to {
220            conditions.push(format!("n.updated_at <= ?{}", params_vec.len() + 1));
221            params_vec.push(to.clone());
222        }
223
224        if !conditions.is_empty() {
225            sql.push_str(" WHERE ");
226            sql.push_str(&conditions.join(" AND "));
227        }
228
229        sql.push_str(" GROUP BY n.id ORDER BY n.updated_at DESC, n.id DESC");
230
231        if let Some(limit) = query.limit {
232            sql.push_str(&format!(" LIMIT {}", limit));
233        }
234
235        let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec
236            .iter()
237            .map(|p| p as &dyn rusqlite::ToSql)
238            .collect();
239
240        let mut stmt = conn
241            .prepare(&sql)
242            .map_err(|e| Error::Database(e.to_string()))?;
243
244        let notes = stmt
245            .query_map(params_refs.as_slice(), |row| {
246                Ok(Note {
247                    id: row.get(0)?,
248                    title: row.get(1)?,
249                    body: row.get(2)?,
250                    updated_at: row.get(3)?,
251                    references: Self::parse_references(row.get(4)?),
252                    tags: Self::parse_tags(row.get(5)?),
253                })
254            })
255            .map_err(|e| Error::Database(e.to_string()))?
256            .collect::<Result<Vec<_>, _>>()
257            .map_err(|e| Error::Database(e.to_string()))?;
258
259        Ok(notes)
260    }
261
262    async fn count_notes(&self, query: NoteQuery) -> Result<i64, Error> {
263        let conn = self.conn.lock().unwrap();
264
265        let mut sql = String::from("SELECT COUNT(DISTINCT n.id) FROM notes n");
266
267        let mut conditions = Vec::new();
268        let mut params_vec: Vec<String> = Vec::new();
269
270        if let Some(ref tags) = query.tags {
271            if !tags.is_empty() {
272                let placeholders: Vec<_> = (0..tags.len()).map(|i| format!("?{}", i + 1)).collect();
273                conditions.push(format!(
274                    "n.id IN (SELECT note_id FROM note_tags nt2 
275                              JOIN tags t2 ON nt2.tag_id = t2.id 
276                              WHERE t2.name IN ({}))",
277                    placeholders.join(",")
278                ));
279                params_vec.extend(tags.clone());
280            }
281        }
282
283        if let Some(ref from) = query.from {
284            conditions.push(format!("n.updated_at >= ?{}", params_vec.len() + 1));
285            params_vec.push(from.clone());
286        }
287
288        if let Some(ref to) = query.to {
289            conditions.push(format!("n.updated_at <= ?{}", params_vec.len() + 1));
290            params_vec.push(to.clone());
291        }
292
293        if !conditions.is_empty() {
294            sql.push_str(" WHERE ");
295            sql.push_str(&conditions.join(" AND "));
296        }
297
298        let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec
299            .iter()
300            .map(|p| p as &dyn rusqlite::ToSql)
301            .collect();
302
303        let count: i64 = conn
304            .query_row(&sql, params_refs.as_slice(), |row| row.get(0))
305            .map_err(|e| Error::Database(e.to_string()))?;
306
307        Ok(count)
308    }
309
310    async fn update_note(&self, id: i64, update: UpdateNote) -> Result<bool, Error> {
311        let conn = self.conn.lock().unwrap();
312
313        // Check if note exists
314        let exists: bool = conn
315            .query_row("SELECT 1 FROM notes WHERE id = ?1", params![id], |_| {
316                Ok(true)
317            })
318            .optional()
319            .map_err(|e| Error::Database(e.to_string()))?
320            .unwrap_or(false);
321
322        if !exists {
323            return Ok(false);
324        }
325
326        // Update title if provided
327        if let Some(ref title) = update.title {
328            conn.execute(
329                "UPDATE notes SET title = ?1, updated_at = datetime('now') WHERE id = ?2",
330                params![title, id],
331            )
332            .map_err(|e| Error::Database(e.to_string()))?;
333        }
334
335        // Update body if provided
336        if let Some(ref body) = update.body {
337            conn.execute(
338                "UPDATE notes SET body = ?1, updated_at = datetime('now') WHERE id = ?2",
339                params![body, id],
340            )
341            .map_err(|e| Error::Database(e.to_string()))?;
342        }
343
344        // Update tags if provided
345        if let Some(ref tags) = update.tags {
346            // Delete existing tags
347            conn.execute("DELETE FROM note_tags WHERE note_id = ?1", params![id])
348                .map_err(|e| Error::Database(e.to_string()))?;
349
350            // Insert new tags
351            for tag in tags {
352                conn.execute(
353                    "INSERT INTO tags (name) VALUES (?1) ON CONFLICT (name) DO NOTHING",
354                    params![tag],
355                )
356                .map_err(|e| Error::Database(e.to_string()))?;
357
358                conn.execute(
359                    "INSERT INTO note_tags (note_id, tag_id) SELECT ?1, id FROM tags WHERE name = ?2",
360                    params![id, tag],
361                )
362                .map_err(|e| Error::Database(e.to_string()))?;
363            }
364
365            // Update timestamp
366            conn.execute(
367                "UPDATE notes SET updated_at = datetime('now') WHERE id = ?1",
368                params![id],
369            )
370            .map_err(|e| Error::Database(e.to_string()))?;
371        }
372
373        // Update references if provided
374        if let Some(ref references) = update.references {
375            let refs_json = Self::serialize_references(references);
376            conn.execute(
377                "UPDATE notes SET \"references\" = ?1, updated_at = datetime('now') WHERE id = ?2",
378                params![refs_json, id],
379            )
380            .map_err(|e| Error::Database(e.to_string()))?;
381        }
382
383        Ok(true)
384    }
385
386    async fn delete_note(&self, id: i64) -> Result<bool, Error> {
387        let conn = self.conn.lock().unwrap();
388
389        let rows = conn
390            .execute("DELETE FROM notes WHERE id = ?1", params![id])
391            .map_err(|e| Error::Database(e.to_string()))?;
392
393        Ok(rows > 0)
394    }
395
396    async fn list_tags(&self) -> Result<Vec<TagCount>, Error> {
397        let conn = self.conn.lock().unwrap();
398
399        let mut stmt = conn
400            .prepare(
401                "SELECT t.name, COUNT(nt.note_id) as count
402                 FROM tags t
403                 LEFT JOIN note_tags nt ON t.id = nt.tag_id
404                 GROUP BY t.id
405                 HAVING count > 0
406                 ORDER BY count DESC, t.name",
407            )
408            .map_err(|e| Error::Database(e.to_string()))?;
409
410        let tags = stmt
411            .query_map([], |row| {
412                Ok(TagCount {
413                    name: row.get(0)?,
414                    count: row.get(1)?,
415                })
416            })
417            .map_err(|e| Error::Database(e.to_string()))?
418            .collect::<Result<Vec<_>, _>>()
419            .map_err(|e| Error::Database(e.to_string()))?;
420
421        Ok(tags)
422    }
423
424    async fn grep(
425        &self,
426        pattern: &str,
427        tags: Option<&[String]>,
428        case_sensitive: bool,
429    ) -> Result<Vec<Note>, Error> {
430        let conn = self.conn.lock().unwrap();
431
432        // Build regex
433        let regex = if case_sensitive {
434            Regex::new(pattern).map_err(|e| Error::Validation(format!("invalid regex: {}", e)))?
435        } else {
436            Regex::new(&format!("(?i){}", pattern))
437                .map_err(|e| Error::Validation(format!("invalid regex: {}", e)))?
438        };
439
440        // Query all notes (with tag filter if provided)
441        let mut sql = String::from(
442            "SELECT n.id, n.title, n.body, n.updated_at, n.\"references\", GROUP_CONCAT(t.name) as tags
443             FROM notes n
444             LEFT JOIN note_tags nt ON n.id = nt.note_id
445             LEFT JOIN tags t ON nt.tag_id = t.id",
446        );
447
448        let mut params_vec: Vec<String> = Vec::new();
449
450        if let Some(tag_list) = tags {
451            if !tag_list.is_empty() {
452                let placeholders: Vec<_> =
453                    (0..tag_list.len()).map(|i| format!("?{}", i + 1)).collect();
454                sql.push_str(&format!(
455                    " WHERE n.id IN (SELECT note_id FROM note_tags nt2 
456                                     JOIN tags t2 ON nt2.tag_id = t2.id 
457                                     WHERE t2.name IN ({}))",
458                    placeholders.join(",")
459                ));
460                params_vec.extend(tag_list.iter().cloned());
461            }
462        }
463
464        sql.push_str(" GROUP BY n.id ORDER BY n.updated_at DESC, n.id DESC");
465
466        let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec
467            .iter()
468            .map(|p| p as &dyn rusqlite::ToSql)
469            .collect();
470
471        let mut stmt = conn
472            .prepare(&sql)
473            .map_err(|e| Error::Database(e.to_string()))?;
474
475        let all_notes: Vec<Note> = stmt
476            .query_map(params_refs.as_slice(), |row| {
477                Ok(Note {
478                    id: row.get(0)?,
479                    title: row.get(1)?,
480                    body: row.get(2)?,
481                    updated_at: row.get(3)?,
482                    references: Self::parse_references(row.get(4)?),
483                    tags: Self::parse_tags(row.get(5)?),
484                })
485            })
486            .map_err(|e| Error::Database(e.to_string()))?
487            .collect::<Result<Vec<_>, _>>()
488            .map_err(|e| Error::Database(e.to_string()))?;
489
490        // Filter by regex
491        let matching: Vec<Note> = all_notes
492            .into_iter()
493            .filter(|note| regex.is_match(&note.title) || regex.is_match(&note.body))
494            .collect();
495
496        Ok(matching)
497    }
498}