1use anyhow::Context;
10use chrono::{TimeZone, Utc};
11use rusqlite::types::{FromSql, ToSql};
12use rusqlite::{params, Connection, OptionalExtension};
13use std::path::Path;
14use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version};
15use uuid::Uuid;
16
17struct StoredUuid(Uuid);
19
20impl FromSql for StoredUuid {
22 fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
23 let u = Uuid::parse_str(value.as_str()?)
24 .map_err(|_| rusqlite::types::FromSqlError::InvalidType)?;
25 Ok(StoredUuid(u))
26 }
27}
28
29impl ToSql for StoredUuid {
31 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
32 let s = self.0.to_string();
33 Ok(s.into())
34 }
35}
36
37pub struct SqliteStorage {
42 db_file: std::path::PathBuf,
43}
44
45impl SqliteStorage {
46 fn new_connection(&self) -> anyhow::Result<Connection> {
47 Ok(Connection::open(&self.db_file)?)
48 }
49
50 pub fn new<P: AsRef<Path>>(directory: P) -> anyhow::Result<SqliteStorage> {
55 std::fs::create_dir_all(&directory)
56 .with_context(|| format!("Failed to create `{}`.", directory.as_ref().display()))?;
57 let db_file = directory.as_ref().join("taskchampion-sync-server.sqlite3");
58
59 let o = SqliteStorage { db_file };
60
61 let con = o.new_connection()?;
62
63 con.query_row("PRAGMA journal_mode=WAL", [], |_row| Ok(()))
65 .context("Setting journal_mode=WAL")?;
66
67 let queries = vec![
68 "CREATE TABLE IF NOT EXISTS clients (
69 client_id STRING PRIMARY KEY,
70 latest_version_id STRING,
71 snapshot_version_id STRING,
72 versions_since_snapshot INTEGER,
73 snapshot_timestamp INTEGER,
74 snapshot BLOB);",
75 "CREATE TABLE IF NOT EXISTS versions (version_id STRING PRIMARY KEY, client_id STRING, parent_version_id STRING, history_segment BLOB);",
76 "CREATE INDEX IF NOT EXISTS versions_by_parent ON versions (parent_version_id);",
77 ];
78 for q in queries {
79 con.execute(q, [])
80 .context("Error while creating SQLite tables")?;
81 }
82
83 Ok(o)
84 }
85}
86
87#[async_trait::async_trait]
88impl Storage for SqliteStorage {
89 async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
90 let con = self.new_connection()?;
91 con.execute("BEGIN IMMEDIATE", [])?;
94 let txn = Txn { con, client_id };
95 Ok(Box::new(txn))
96 }
97}
98
99struct Txn {
100 con: Connection,
104 client_id: Uuid,
105}
106
107impl Txn {
108 fn get_version_impl(
110 &mut self,
111 query: &'static str,
112 client_id: Uuid,
113 version_id_arg: Uuid,
114 ) -> anyhow::Result<Option<Version>> {
115 let r = self
116 .con
117 .query_row(
118 query,
119 params![&StoredUuid(version_id_arg), &StoredUuid(client_id)],
120 |r| {
121 let version_id: StoredUuid = r.get("version_id")?;
122 let parent_version_id: StoredUuid = r.get("parent_version_id")?;
123
124 Ok(Version {
125 version_id: version_id.0,
126 parent_version_id: parent_version_id.0,
127 history_segment: r.get("history_segment")?,
128 })
129 },
130 )
131 .optional()
132 .context("Error getting version")?;
133 Ok(r)
134 }
135}
136
137#[async_trait::async_trait(?Send)]
138impl StorageTxn for Txn {
139 async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
140 let result: Option<Client> = self
141 .con
142 .query_row(
143 "SELECT
144 latest_version_id,
145 snapshot_timestamp,
146 versions_since_snapshot,
147 snapshot_version_id
148 FROM clients
149 WHERE client_id = ?
150 LIMIT 1",
151 [&StoredUuid(self.client_id)],
152 |r| {
153 let latest_version_id: StoredUuid = r.get(0)?;
154 let snapshot_timestamp: Option<i64> = r.get(1)?;
155 let versions_since_snapshot: Option<u32> = r.get(2)?;
156 let snapshot_version_id: Option<StoredUuid> = r.get(3)?;
157
158 let snapshot = match (
160 snapshot_timestamp,
161 versions_since_snapshot,
162 snapshot_version_id,
163 ) {
164 (Some(ts), Some(vs), Some(v)) => Some(Snapshot {
165 version_id: v.0,
166 timestamp: Utc.timestamp_opt(ts, 0).unwrap(),
167 versions_since: vs,
168 }),
169 _ => None,
170 };
171 Ok(Client {
172 latest_version_id: latest_version_id.0,
173 snapshot,
174 })
175 },
176 )
177 .optional()
178 .context("Error getting client")?;
179
180 Ok(result)
181 }
182
183 async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
184 self.con
185 .execute(
186 "INSERT INTO clients (client_id, latest_version_id) VALUES (?, ?)",
187 params![&StoredUuid(self.client_id), &StoredUuid(latest_version_id)],
188 )
189 .context("Error creating/updating client")?;
190 Ok(())
191 }
192
193 async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
194 self.con
195 .execute(
196 "UPDATE clients
197 SET
198 snapshot_version_id = ?,
199 snapshot_timestamp = ?,
200 versions_since_snapshot = ?,
201 snapshot = ?
202 WHERE client_id = ?",
203 params![
204 &StoredUuid(snapshot.version_id),
205 snapshot.timestamp.timestamp(),
206 snapshot.versions_since,
207 data,
208 &StoredUuid(self.client_id),
209 ],
210 )
211 .context("Error creating/updating snapshot")?;
212 Ok(())
213 }
214
215 async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
216 let r = self
217 .con
218 .query_row(
219 "SELECT snapshot, snapshot_version_id FROM clients WHERE client_id = ?",
220 params![&StoredUuid(self.client_id)],
221 |r| {
222 let v: StoredUuid = r.get("snapshot_version_id")?;
223 let d: Vec<u8> = r.get("snapshot")?;
224 Ok((v.0, d))
225 },
226 )
227 .optional()
228 .context("Error getting snapshot")?;
229 r.map(|(v, d)| {
230 if v != version_id {
231 return Err(anyhow::anyhow!("unexpected snapshot_version_id"));
232 }
233
234 Ok(d)
235 })
236 .transpose()
237 }
238
239 async fn get_version_by_parent(
240 &mut self,
241 parent_version_id: Uuid,
242 ) -> anyhow::Result<Option<Version>> {
243 self.get_version_impl(
244 "SELECT version_id, parent_version_id, history_segment FROM versions WHERE parent_version_id = ? AND client_id = ?",
245 self.client_id,
246 parent_version_id)
247 }
248
249 async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
250 self.get_version_impl(
251 "SELECT version_id, parent_version_id, history_segment FROM versions WHERE version_id = ? AND client_id = ?",
252 self.client_id,
253 version_id)
254 }
255
256 async fn add_version(
257 &mut self,
258 version_id: Uuid,
259 parent_version_id: Uuid,
260 history_segment: Vec<u8>,
261 ) -> anyhow::Result<()> {
262 self.con.execute(
263 "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment) VALUES(?, ?, ?, ?)",
264 params![
265 StoredUuid(version_id),
266 StoredUuid(self.client_id),
267 StoredUuid(parent_version_id),
268 history_segment
269 ]
270 )
271 .context("Error adding version")?;
272 let rows_changed = self
273 .con
274 .execute(
275 "UPDATE clients
276 SET
277 latest_version_id = ?,
278 versions_since_snapshot = versions_since_snapshot + 1
279 WHERE client_id = ? and (latest_version_id = ? or latest_version_id = ?)",
280 params![
281 StoredUuid(version_id),
282 StoredUuid(self.client_id),
283 StoredUuid(parent_version_id),
284 StoredUuid(Uuid::nil())
285 ],
286 )
287 .context("Error updating client for new version")?;
288
289 if rows_changed == 0 {
290 anyhow::bail!("clients.latest_version_id does not match parent_version_id");
291 }
292
293 Ok(())
294 }
295
296 async fn commit(&mut self) -> anyhow::Result<()> {
297 self.con.execute("COMMIT", [])?;
298 Ok(())
299 }
300}
301
#[cfg(test)]
mod test {
    use super::*;
    use chrono::DateTime;
    use pretty_assertions::assert_eq;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_empty_dir() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        // `SqliteStorage::new` must create the directory itself.
        let non_existent = tmp_dir.path().join("subdir");
        let storage = SqliteStorage::new(non_existent)?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_client = txn.get_client().await?;
        assert!(maybe_client.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_get_client_empty() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_client = txn.get_client().await?;
        assert!(maybe_client.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_client_storage() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let latest_version_id = Uuid::new_v4();
        txn.new_client(latest_version_id).await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, latest_version_id);
        assert!(client.snapshot.is_none());

        let new_version_id = Uuid::new_v4();
        txn.add_version(new_version_id, latest_version_id, vec![1, 1])
            .await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, new_version_id);
        assert!(client.snapshot.is_none());

        let snap = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2014-11-28T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 4,
        };
        txn.set_snapshot(snap.clone(), vec![1, 2, 3]).await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, new_version_id);
        assert_eq!(client.snapshot.unwrap(), snap);

        Ok(())
    }

    #[tokio::test]
    async fn test_gvbp_empty() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_version = txn.get_version_by_parent(Uuid::new_v4()).await?;
        assert!(maybe_version.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_and_get_version() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let parent_version_id = Uuid::new_v4();
        txn.new_client(parent_version_id).await?;

        let version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        txn.add_version(version_id, parent_version_id, history_segment.clone())
            .await?;

        let expected = Version {
            version_id,
            parent_version_id,
            history_segment,
        };

        let version = txn.get_version_by_parent(parent_version_id).await?.unwrap();
        assert_eq!(version, expected);

        let version = txn.get_version(version_id).await?.unwrap();
        assert_eq!(version, expected);

        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_exists() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let parent_version_id = Uuid::new_v4();
        txn.new_client(parent_version_id).await?;

        let version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        txn.add_version(version_id, parent_version_id, history_segment.clone())
            .await?;
        // Adding the same version a second time must fail.
        assert!(txn
            .add_version(version_id, parent_version_id, history_segment)
            .await
            .is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_mismatch() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let latest_version_id = Uuid::new_v4();
        txn.new_client(latest_version_id).await?;

        let version_id = Uuid::new_v4();
        // Deliberately different from `latest_version_id`.
        let parent_version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        assert!(txn
            .add_version(version_id, parent_version_id, history_segment)
            .await
            .is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_snapshots() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        txn.new_client(Uuid::new_v4()).await?;
        assert!(txn.get_client().await?.unwrap().snapshot.is_none());

        let snap = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2013-10-08T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 3,
        };
        txn.set_snapshot(snap.clone(), vec![9, 8, 9]).await?;

        assert_eq!(
            txn.get_snapshot_data(snap.version_id).await?.unwrap(),
            vec![9, 8, 9]
        );
        assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snap));

        let snap2 = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2014-11-28T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 10,
        };
        txn.set_snapshot(snap2.clone(), vec![0, 2, 4, 6]).await?;

        assert_eq!(
            txn.get_snapshot_data(snap2.version_id).await?.unwrap(),
            vec![0, 2, 4, 6]
        );
        assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snap2));

        // Asking for a version other than the stored snapshot's is an error.
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await.is_err());

        Ok(())
    }

    #[tokio::test]
    async fn test_get_snapshot_data_without_snapshot() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        // No client row at all: there is nothing to return.
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await?.is_none());

        // The client exists but has never stored a snapshot: asking for one is an error.
        txn.new_client(Uuid::new_v4()).await?;
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await.is_err());

        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_no_history() -> anyhow::Result<()> {
        // A client whose latest version is the nil UUID accepts any parent version.
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        txn.new_client(Uuid::nil()).await?;

        let version_id = Uuid::new_v4();
        let parent_version_id = Uuid::new_v4();
        txn.add_version(version_id, parent_version_id, b"v1".to_vec())
            .await?;
        Ok(())
    }
}