1use regex::Regex;
4use rusqlite::{params, Connection, OptionalExtension};
5use std::path::Path;
6use std::sync::Mutex;
7use veta_core::{
8 get_pending_migrations, CreateNote, Database, Error, Note, NoteQuery, TagCount, UpdateNote,
9 SCHEMA_VERSION,
10};
11
12pub struct SqliteDatabase {
14 conn: Mutex<Connection>,
15}
16
17impl SqliteDatabase {
18 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
20 let conn = Connection::open(path).map_err(|e| Error::Database(e.to_string()))?;
21 conn.execute_batch("PRAGMA foreign_keys = ON;")
22 .map_err(|e| Error::Database(e.to_string()))?;
23 let db = Self {
24 conn: Mutex::new(conn),
25 };
26 db.run_migrations()?;
27 Ok(db)
28 }
29
30 pub fn open_in_memory() -> Result<Self, Error> {
32 let conn = Connection::open_in_memory().map_err(|e| Error::Database(e.to_string()))?;
33 conn.execute_batch("PRAGMA foreign_keys = ON;")
34 .map_err(|e| Error::Database(e.to_string()))?;
35 let db = Self {
36 conn: Mutex::new(conn),
37 };
38 db.run_migrations()?;
39 Ok(db)
40 }
41
42 fn run_migrations(&self) -> Result<(), Error> {
44 let conn = self.conn.lock().unwrap();
45
46 conn.execute(
48 "CREATE TABLE IF NOT EXISTS _veta_meta (
49 key TEXT PRIMARY KEY,
50 value TEXT NOT NULL
51 )",
52 [],
53 )
54 .map_err(|e| Error::Database(e.to_string()))?;
55
56 let current_version: i64 = conn
58 .query_row(
59 "SELECT value FROM _veta_meta WHERE key = 'schema_version'",
60 [],
61 |row| {
62 let val: String = row.get(0)?;
63 Ok(val.parse().unwrap_or(0))
64 },
65 )
66 .unwrap_or(0);
67
68 if current_version >= SCHEMA_VERSION {
70 return Ok(());
71 }
72
73 for migration in get_pending_migrations(current_version) {
75 for statement in migration.statements {
76 if statement.contains("_veta_meta") {
78 continue;
79 }
80 if statement.starts_with("ALTER TABLE") {
82 let _ = conn.execute(statement, []);
83 } else {
84 conn.execute(statement, []).map_err(|e| {
85 Error::Database(format!("Migration {} failed: {}", migration.name, e))
86 })?;
87 }
88 }
89 }
90
91 conn.execute(
93 "INSERT OR REPLACE INTO _veta_meta (key, value) VALUES ('schema_version', ?1)",
94 params![SCHEMA_VERSION.to_string()],
95 )
96 .map_err(|e| Error::Database(e.to_string()))?;
97
98 Ok(())
99 }
100
101 fn parse_tags(tags_str: Option<String>) -> Vec<String> {
102 let mut tags: Vec<String> = tags_str
103 .map(|s| {
104 s.split(',')
105 .map(String::from)
106 .filter(|s| !s.is_empty())
107 .collect()
108 })
109 .unwrap_or_default();
110 tags.sort();
111 tags
112 }
113
114 fn parse_references(refs_str: Option<String>) -> Vec<String> {
115 refs_str
116 .and_then(|s| serde_json::from_str(&s).ok())
117 .unwrap_or_default()
118 }
119
120 fn serialize_references(refs: &[String]) -> String {
121 serde_json::to_string(refs).unwrap_or_else(|_| "[]".to_string())
122 }
123}
124
125#[async_trait::async_trait(?Send)]
126impl Database for SqliteDatabase {
127 async fn add_note(&self, note: CreateNote) -> Result<i64, Error> {
128 let conn = self.conn.lock().unwrap();
129
130 let refs_json = Self::serialize_references(¬e.references);
131
132 conn.execute(
134 "INSERT INTO notes (title, body, \"references\") VALUES (?1, ?2, ?3)",
135 params![note.title, note.body, refs_json],
136 )
137 .map_err(|e| Error::Database(e.to_string()))?;
138
139 let note_id = conn.last_insert_rowid();
140
141 for tag in ¬e.tags {
143 conn.execute(
144 "INSERT INTO tags (name) VALUES (?1) ON CONFLICT (name) DO NOTHING",
145 params![tag],
146 )
147 .map_err(|e| Error::Database(e.to_string()))?;
148
149 conn.execute(
150 "INSERT INTO note_tags (note_id, tag_id) SELECT ?1, id FROM tags WHERE name = ?2",
151 params![note_id, tag],
152 )
153 .map_err(|e| Error::Database(e.to_string()))?;
154 }
155
156 Ok(note_id)
157 }
158
159 async fn get_note(&self, id: i64) -> Result<Option<Note>, Error> {
160 let conn = self.conn.lock().unwrap();
161
162 let note = conn
163 .query_row(
164 "SELECT n.id, n.title, n.body, n.updated_at, n.\"references\", GROUP_CONCAT(t.name) as tags
165 FROM notes n
166 LEFT JOIN note_tags nt ON n.id = nt.note_id
167 LEFT JOIN tags t ON nt.tag_id = t.id
168 WHERE n.id = ?1
169 GROUP BY n.id",
170 params![id],
171 |row| {
172 Ok(Note {
173 id: row.get(0)?,
174 title: row.get(1)?,
175 body: row.get(2)?,
176 updated_at: row.get(3)?,
177 references: Self::parse_references(row.get(4)?),
178 tags: Self::parse_tags(row.get(5)?),
179 })
180 },
181 )
182 .optional()
183 .map_err(|e| Error::Database(e.to_string()))?;
184
185 Ok(note)
186 }
187
188 async fn list_notes(&self, query: NoteQuery) -> Result<Vec<Note>, Error> {
189 let conn = self.conn.lock().unwrap();
190
191 let mut sql = String::from(
192 "SELECT n.id, n.title, n.body, n.updated_at, n.\"references\", GROUP_CONCAT(t.name) as tags
193 FROM notes n
194 LEFT JOIN note_tags nt ON n.id = nt.note_id
195 LEFT JOIN tags t ON nt.tag_id = t.id",
196 );
197
198 let mut conditions = Vec::new();
199 let mut params_vec: Vec<String> = Vec::new();
200
201 if let Some(ref tags) = query.tags {
202 if !tags.is_empty() {
203 let placeholders: Vec<_> = (0..tags.len()).map(|i| format!("?{}", i + 1)).collect();
204 conditions.push(format!(
205 "n.id IN (SELECT note_id FROM note_tags nt2
206 JOIN tags t2 ON nt2.tag_id = t2.id
207 WHERE t2.name IN ({}))",
208 placeholders.join(",")
209 ));
210 params_vec.extend(tags.clone());
211 }
212 }
213
214 if let Some(ref from) = query.from {
215 conditions.push(format!("n.updated_at >= ?{}", params_vec.len() + 1));
216 params_vec.push(from.clone());
217 }
218
219 if let Some(ref to) = query.to {
220 conditions.push(format!("n.updated_at <= ?{}", params_vec.len() + 1));
221 params_vec.push(to.clone());
222 }
223
224 if !conditions.is_empty() {
225 sql.push_str(" WHERE ");
226 sql.push_str(&conditions.join(" AND "));
227 }
228
229 sql.push_str(" GROUP BY n.id ORDER BY n.updated_at DESC, n.id DESC");
230
231 if let Some(limit) = query.limit {
232 sql.push_str(&format!(" LIMIT {}", limit));
233 }
234
235 let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec
236 .iter()
237 .map(|p| p as &dyn rusqlite::ToSql)
238 .collect();
239
240 let mut stmt = conn
241 .prepare(&sql)
242 .map_err(|e| Error::Database(e.to_string()))?;
243
244 let notes = stmt
245 .query_map(params_refs.as_slice(), |row| {
246 Ok(Note {
247 id: row.get(0)?,
248 title: row.get(1)?,
249 body: row.get(2)?,
250 updated_at: row.get(3)?,
251 references: Self::parse_references(row.get(4)?),
252 tags: Self::parse_tags(row.get(5)?),
253 })
254 })
255 .map_err(|e| Error::Database(e.to_string()))?
256 .collect::<Result<Vec<_>, _>>()
257 .map_err(|e| Error::Database(e.to_string()))?;
258
259 Ok(notes)
260 }
261
262 async fn count_notes(&self, query: NoteQuery) -> Result<i64, Error> {
263 let conn = self.conn.lock().unwrap();
264
265 let mut sql = String::from("SELECT COUNT(DISTINCT n.id) FROM notes n");
266
267 let mut conditions = Vec::new();
268 let mut params_vec: Vec<String> = Vec::new();
269
270 if let Some(ref tags) = query.tags {
271 if !tags.is_empty() {
272 let placeholders: Vec<_> = (0..tags.len()).map(|i| format!("?{}", i + 1)).collect();
273 conditions.push(format!(
274 "n.id IN (SELECT note_id FROM note_tags nt2
275 JOIN tags t2 ON nt2.tag_id = t2.id
276 WHERE t2.name IN ({}))",
277 placeholders.join(",")
278 ));
279 params_vec.extend(tags.clone());
280 }
281 }
282
283 if let Some(ref from) = query.from {
284 conditions.push(format!("n.updated_at >= ?{}", params_vec.len() + 1));
285 params_vec.push(from.clone());
286 }
287
288 if let Some(ref to) = query.to {
289 conditions.push(format!("n.updated_at <= ?{}", params_vec.len() + 1));
290 params_vec.push(to.clone());
291 }
292
293 if !conditions.is_empty() {
294 sql.push_str(" WHERE ");
295 sql.push_str(&conditions.join(" AND "));
296 }
297
298 let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec
299 .iter()
300 .map(|p| p as &dyn rusqlite::ToSql)
301 .collect();
302
303 let count: i64 = conn
304 .query_row(&sql, params_refs.as_slice(), |row| row.get(0))
305 .map_err(|e| Error::Database(e.to_string()))?;
306
307 Ok(count)
308 }
309
310 async fn update_note(&self, id: i64, update: UpdateNote) -> Result<bool, Error> {
311 let conn = self.conn.lock().unwrap();
312
313 let exists: bool = conn
315 .query_row("SELECT 1 FROM notes WHERE id = ?1", params![id], |_| {
316 Ok(true)
317 })
318 .optional()
319 .map_err(|e| Error::Database(e.to_string()))?
320 .unwrap_or(false);
321
322 if !exists {
323 return Ok(false);
324 }
325
326 if let Some(ref title) = update.title {
328 conn.execute(
329 "UPDATE notes SET title = ?1, updated_at = datetime('now') WHERE id = ?2",
330 params![title, id],
331 )
332 .map_err(|e| Error::Database(e.to_string()))?;
333 }
334
335 if let Some(ref body) = update.body {
337 conn.execute(
338 "UPDATE notes SET body = ?1, updated_at = datetime('now') WHERE id = ?2",
339 params![body, id],
340 )
341 .map_err(|e| Error::Database(e.to_string()))?;
342 }
343
344 if let Some(ref tags) = update.tags {
346 conn.execute("DELETE FROM note_tags WHERE note_id = ?1", params![id])
348 .map_err(|e| Error::Database(e.to_string()))?;
349
350 for tag in tags {
352 conn.execute(
353 "INSERT INTO tags (name) VALUES (?1) ON CONFLICT (name) DO NOTHING",
354 params![tag],
355 )
356 .map_err(|e| Error::Database(e.to_string()))?;
357
358 conn.execute(
359 "INSERT INTO note_tags (note_id, tag_id) SELECT ?1, id FROM tags WHERE name = ?2",
360 params![id, tag],
361 )
362 .map_err(|e| Error::Database(e.to_string()))?;
363 }
364
365 conn.execute(
367 "UPDATE notes SET updated_at = datetime('now') WHERE id = ?1",
368 params![id],
369 )
370 .map_err(|e| Error::Database(e.to_string()))?;
371 }
372
373 if let Some(ref references) = update.references {
375 let refs_json = Self::serialize_references(references);
376 conn.execute(
377 "UPDATE notes SET \"references\" = ?1, updated_at = datetime('now') WHERE id = ?2",
378 params![refs_json, id],
379 )
380 .map_err(|e| Error::Database(e.to_string()))?;
381 }
382
383 Ok(true)
384 }
385
386 async fn delete_note(&self, id: i64) -> Result<bool, Error> {
387 let conn = self.conn.lock().unwrap();
388
389 let rows = conn
390 .execute("DELETE FROM notes WHERE id = ?1", params![id])
391 .map_err(|e| Error::Database(e.to_string()))?;
392
393 Ok(rows > 0)
394 }
395
396 async fn list_tags(&self) -> Result<Vec<TagCount>, Error> {
397 let conn = self.conn.lock().unwrap();
398
399 let mut stmt = conn
400 .prepare(
401 "SELECT t.name, COUNT(nt.note_id) as count
402 FROM tags t
403 LEFT JOIN note_tags nt ON t.id = nt.tag_id
404 GROUP BY t.id
405 HAVING count > 0
406 ORDER BY count DESC, t.name",
407 )
408 .map_err(|e| Error::Database(e.to_string()))?;
409
410 let tags = stmt
411 .query_map([], |row| {
412 Ok(TagCount {
413 name: row.get(0)?,
414 count: row.get(1)?,
415 })
416 })
417 .map_err(|e| Error::Database(e.to_string()))?
418 .collect::<Result<Vec<_>, _>>()
419 .map_err(|e| Error::Database(e.to_string()))?;
420
421 Ok(tags)
422 }
423
424 async fn grep(
425 &self,
426 pattern: &str,
427 tags: Option<&[String]>,
428 case_sensitive: bool,
429 ) -> Result<Vec<Note>, Error> {
430 let conn = self.conn.lock().unwrap();
431
432 let regex = if case_sensitive {
434 Regex::new(pattern).map_err(|e| Error::Validation(format!("invalid regex: {}", e)))?
435 } else {
436 Regex::new(&format!("(?i){}", pattern))
437 .map_err(|e| Error::Validation(format!("invalid regex: {}", e)))?
438 };
439
440 let mut sql = String::from(
442 "SELECT n.id, n.title, n.body, n.updated_at, n.\"references\", GROUP_CONCAT(t.name) as tags
443 FROM notes n
444 LEFT JOIN note_tags nt ON n.id = nt.note_id
445 LEFT JOIN tags t ON nt.tag_id = t.id",
446 );
447
448 let mut params_vec: Vec<String> = Vec::new();
449
450 if let Some(tag_list) = tags {
451 if !tag_list.is_empty() {
452 let placeholders: Vec<_> =
453 (0..tag_list.len()).map(|i| format!("?{}", i + 1)).collect();
454 sql.push_str(&format!(
455 " WHERE n.id IN (SELECT note_id FROM note_tags nt2
456 JOIN tags t2 ON nt2.tag_id = t2.id
457 WHERE t2.name IN ({}))",
458 placeholders.join(",")
459 ));
460 params_vec.extend(tag_list.iter().cloned());
461 }
462 }
463
464 sql.push_str(" GROUP BY n.id ORDER BY n.updated_at DESC, n.id DESC");
465
466 let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec
467 .iter()
468 .map(|p| p as &dyn rusqlite::ToSql)
469 .collect();
470
471 let mut stmt = conn
472 .prepare(&sql)
473 .map_err(|e| Error::Database(e.to_string()))?;
474
475 let all_notes: Vec<Note> = stmt
476 .query_map(params_refs.as_slice(), |row| {
477 Ok(Note {
478 id: row.get(0)?,
479 title: row.get(1)?,
480 body: row.get(2)?,
481 updated_at: row.get(3)?,
482 references: Self::parse_references(row.get(4)?),
483 tags: Self::parse_tags(row.get(5)?),
484 })
485 })
486 .map_err(|e| Error::Database(e.to_string()))?
487 .collect::<Result<Vec<_>, _>>()
488 .map_err(|e| Error::Database(e.to_string()))?;
489
490 let matching: Vec<Note> = all_notes
492 .into_iter()
493 .filter(|note| regex.is_match(¬e.title) || regex.is_match(¬e.body))
494 .collect();
495
496 Ok(matching)
497 }
498}