taskchampion_sync_server_storage_sqlite/
lib.rs

1//! This crate implements a SQLite storage backend for the TaskChampion sync server.
2//!
3//! Use the [`SqliteStorage`] type as an implementation of the [`Storage`] trait.
4//!
5//! This crate is intended for small deployments of a sync server, supporting one or a small number
6//! of users. The schema for the database is considered an implementation detail. For more robust
7//! database support, consider `taskchampion-sync-server-storage-postgres`.
8
9use anyhow::Context;
10use chrono::{TimeZone, Utc};
11use rusqlite::types::{FromSql, ToSql};
12use rusqlite::{params, Connection, OptionalExtension};
13use std::path::Path;
14use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version};
15use uuid::Uuid;
16
17/// Newtype to allow implementing `FromSql` for foreign `uuid::Uuid`
18struct StoredUuid(Uuid);
19
20/// Conversion from Uuid stored as a string (rusqlite's uuid feature stores as binary blob)
21impl FromSql for StoredUuid {
22    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
23        let u = Uuid::parse_str(value.as_str()?)
24            .map_err(|_| rusqlite::types::FromSqlError::InvalidType)?;
25        Ok(StoredUuid(u))
26    }
27}
28
29/// Store Uuid as string in database
30impl ToSql for StoredUuid {
31    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
32        let s = self.0.to_string();
33        Ok(s.into())
34    }
35}
36
37/// An on-disk storage backend which uses SQLite.
38///
39/// A new connection is opened for each transaction, and only one transaction may be active at a
40/// time; a second call to `txn` will block until the first transaction is dropped.
41pub struct SqliteStorage {
42    db_file: std::path::PathBuf,
43}
44
45impl SqliteStorage {
46    fn new_connection(&self) -> anyhow::Result<Connection> {
47        Ok(Connection::open(&self.db_file)?)
48    }
49
50    /// Create a new instance using a database at the given directory.
51    ///
52    /// The database will be stored in a file named `taskchampion-sync-server.sqlite3` in the given
53    /// directory. The database will be created if it does not exist.
54    pub fn new<P: AsRef<Path>>(directory: P) -> anyhow::Result<SqliteStorage> {
55        std::fs::create_dir_all(&directory)
56            .with_context(|| format!("Failed to create `{}`.", directory.as_ref().display()))?;
57        let db_file = directory.as_ref().join("taskchampion-sync-server.sqlite3");
58
59        let o = SqliteStorage { db_file };
60
61        let con = o.new_connection()?;
62
63        // Use the modern WAL mode.
64        con.query_row("PRAGMA journal_mode=WAL", [], |_row| Ok(()))
65            .context("Setting journal_mode=WAL")?;
66
67        let queries = vec![
68                "CREATE TABLE IF NOT EXISTS clients (
69                    client_id STRING PRIMARY KEY,
70                    latest_version_id STRING,
71                    snapshot_version_id STRING,
72                    versions_since_snapshot INTEGER,
73                    snapshot_timestamp INTEGER,
74                    snapshot BLOB);",
75                "CREATE TABLE IF NOT EXISTS versions (version_id STRING PRIMARY KEY, client_id STRING, parent_version_id STRING, history_segment BLOB);",
76                "CREATE INDEX IF NOT EXISTS versions_by_parent ON versions (parent_version_id);",
77            ];
78        for q in queries {
79            con.execute(q, [])
80                .context("Error while creating SQLite tables")?;
81        }
82
83        Ok(o)
84    }
85}
86
87#[async_trait::async_trait]
88impl Storage for SqliteStorage {
89    async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
90        let con = self.new_connection()?;
91        // Begin the transaction on this new connection. An IMMEDIATE connection is in
92        // write (exclusive) mode from the start.
93        con.execute("BEGIN IMMEDIATE", [])?;
94        let txn = Txn { con, client_id };
95        Ok(Box::new(txn))
96    }
97}
98
99struct Txn {
100    // SQLite only allows one concurrent transaction per connection, and rusqlite emulates
101    // transactions by running `BEGIN ...` and `COMMIT` at appropriate times. So we will do
102    // the same.
103    con: Connection,
104    client_id: Uuid,
105}
106
107impl Txn {
108    /// Implementation for queries from the versions table
109    fn get_version_impl(
110        &mut self,
111        query: &'static str,
112        client_id: Uuid,
113        version_id_arg: Uuid,
114    ) -> anyhow::Result<Option<Version>> {
115        let r = self
116            .con
117            .query_row(
118                query,
119                params![&StoredUuid(version_id_arg), &StoredUuid(client_id)],
120                |r| {
121                    let version_id: StoredUuid = r.get("version_id")?;
122                    let parent_version_id: StoredUuid = r.get("parent_version_id")?;
123
124                    Ok(Version {
125                        version_id: version_id.0,
126                        parent_version_id: parent_version_id.0,
127                        history_segment: r.get("history_segment")?,
128                    })
129                },
130            )
131            .optional()
132            .context("Error getting version")?;
133        Ok(r)
134    }
135}
136
137#[async_trait::async_trait(?Send)]
138impl StorageTxn for Txn {
139    async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
140        let result: Option<Client> = self
141            .con
142            .query_row(
143                "SELECT
144                    latest_version_id,
145                    snapshot_timestamp,
146                    versions_since_snapshot,
147                    snapshot_version_id
148                 FROM clients
149                 WHERE client_id = ?
150                 LIMIT 1",
151                [&StoredUuid(self.client_id)],
152                |r| {
153                    let latest_version_id: StoredUuid = r.get(0)?;
154                    let snapshot_timestamp: Option<i64> = r.get(1)?;
155                    let versions_since_snapshot: Option<u32> = r.get(2)?;
156                    let snapshot_version_id: Option<StoredUuid> = r.get(3)?;
157
158                    // if all of the relevant fields are non-NULL, return a snapshot
159                    let snapshot = match (
160                        snapshot_timestamp,
161                        versions_since_snapshot,
162                        snapshot_version_id,
163                    ) {
164                        (Some(ts), Some(vs), Some(v)) => Some(Snapshot {
165                            version_id: v.0,
166                            timestamp: Utc.timestamp_opt(ts, 0).unwrap(),
167                            versions_since: vs,
168                        }),
169                        _ => None,
170                    };
171                    Ok(Client {
172                        latest_version_id: latest_version_id.0,
173                        snapshot,
174                    })
175                },
176            )
177            .optional()
178            .context("Error getting client")?;
179
180        Ok(result)
181    }
182
183    async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
184        self.con
185            .execute(
186                "INSERT INTO clients (client_id, latest_version_id) VALUES (?, ?)",
187                params![&StoredUuid(self.client_id), &StoredUuid(latest_version_id)],
188            )
189            .context("Error creating/updating client")?;
190        Ok(())
191    }
192
193    async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
194        self.con
195            .execute(
196                "UPDATE clients
197             SET
198               snapshot_version_id = ?,
199               snapshot_timestamp = ?,
200               versions_since_snapshot = ?,
201               snapshot = ?
202             WHERE client_id = ?",
203                params![
204                    &StoredUuid(snapshot.version_id),
205                    snapshot.timestamp.timestamp(),
206                    snapshot.versions_since,
207                    data,
208                    &StoredUuid(self.client_id),
209                ],
210            )
211            .context("Error creating/updating snapshot")?;
212        Ok(())
213    }
214
215    async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
216        let r = self
217            .con
218            .query_row(
219                "SELECT snapshot, snapshot_version_id FROM clients WHERE client_id = ?",
220                params![&StoredUuid(self.client_id)],
221                |r| {
222                    let v: StoredUuid = r.get("snapshot_version_id")?;
223                    let d: Vec<u8> = r.get("snapshot")?;
224                    Ok((v.0, d))
225                },
226            )
227            .optional()
228            .context("Error getting snapshot")?;
229        r.map(|(v, d)| {
230            if v != version_id {
231                return Err(anyhow::anyhow!("unexpected snapshot_version_id"));
232            }
233
234            Ok(d)
235        })
236        .transpose()
237    }
238
239    async fn get_version_by_parent(
240        &mut self,
241        parent_version_id: Uuid,
242    ) -> anyhow::Result<Option<Version>> {
243        self.get_version_impl(
244            "SELECT version_id, parent_version_id, history_segment FROM versions WHERE parent_version_id = ? AND client_id = ?",
245            self.client_id,
246            parent_version_id)
247    }
248
249    async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
250        self.get_version_impl(
251            "SELECT version_id, parent_version_id, history_segment FROM versions WHERE version_id = ? AND client_id = ?",
252            self.client_id,
253            version_id)
254    }
255
256    async fn add_version(
257        &mut self,
258        version_id: Uuid,
259        parent_version_id: Uuid,
260        history_segment: Vec<u8>,
261    ) -> anyhow::Result<()> {
262        self.con.execute(
263            "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment) VALUES(?, ?, ?, ?)",
264            params![
265                StoredUuid(version_id),
266                StoredUuid(self.client_id),
267                StoredUuid(parent_version_id),
268                history_segment
269            ]
270        )
271        .context("Error adding version")?;
272        let rows_changed = self
273            .con
274            .execute(
275                "UPDATE clients
276             SET
277               latest_version_id = ?,
278               versions_since_snapshot = versions_since_snapshot + 1
279             WHERE client_id = ? and (latest_version_id = ? or latest_version_id = ?)",
280                params![
281                    StoredUuid(version_id),
282                    StoredUuid(self.client_id),
283                    StoredUuid(parent_version_id),
284                    StoredUuid(Uuid::nil())
285                ],
286            )
287            .context("Error updating client for new version")?;
288
289        if rows_changed == 0 {
290            anyhow::bail!("clients.latest_version_id does not match parent_version_id");
291        }
292
293        Ok(())
294    }
295
296    async fn commit(&mut self) -> anyhow::Result<()> {
297        self.con.execute("COMMIT", [])?;
298        Ok(())
299    }
300}
301
#[cfg(test)]
mod test {
    use super::*;
    use chrono::DateTime;
    use pretty_assertions::assert_eq;
    use tempfile::TempDir;

    /// Storage can be created in a directory that does not exist yet.
    #[tokio::test]
    async fn test_empty_dir() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let non_existent = tmp_dir.path().join("subdir");
        let storage = SqliteStorage::new(non_existent)?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_client = txn.get_client().await?;
        assert!(maybe_client.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_get_client_empty() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_client = txn.get_client().await?;
        assert!(maybe_client.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_client_storage() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let latest_version_id = Uuid::new_v4();
        txn.new_client(latest_version_id).await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, latest_version_id);
        assert!(client.snapshot.is_none());

        let new_version_id = Uuid::new_v4();
        txn.add_version(new_version_id, latest_version_id, vec![1, 1])
            .await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, new_version_id);
        assert!(client.snapshot.is_none());

        let snap = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2014-11-28T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 4,
        };
        txn.set_snapshot(snap.clone(), vec![1, 2, 3]).await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, new_version_id);
        assert_eq!(client.snapshot.unwrap(), snap);

        Ok(())
    }

    #[tokio::test]
    async fn test_gvbp_empty() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_version = txn.get_version_by_parent(Uuid::new_v4()).await?;
        assert!(maybe_version.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_and_get_version() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let parent_version_id = Uuid::new_v4();
        txn.new_client(parent_version_id).await?;

        let version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        txn.add_version(version_id, parent_version_id, history_segment.clone())
            .await?;

        let expected = Version {
            version_id,
            parent_version_id,
            history_segment,
        };

        let version = txn.get_version_by_parent(parent_version_id).await?.unwrap();
        assert_eq!(version, expected);

        let version = txn.get_version(version_id).await?.unwrap();
        assert_eq!(version, expected);

        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_exists() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let parent_version_id = Uuid::new_v4();
        txn.new_client(parent_version_id).await?;

        let version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        txn.add_version(version_id, parent_version_id, history_segment.clone())
            .await?;
        // Fails because the version already exists.
        assert!(txn
            .add_version(version_id, parent_version_id, history_segment.clone())
            .await
            .is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_mismatch() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let latest_version_id = Uuid::new_v4();
        txn.new_client(latest_version_id).await?;

        let version_id = Uuid::new_v4();
        let parent_version_id = Uuid::new_v4(); // != latest_version_id
        let history_segment = b"abc".to_vec();
        // Fails because the latest_version_id is not parent_version_id.
        assert!(txn
            .add_version(version_id, parent_version_id, history_segment.clone())
            .await
            .is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_snapshots() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        txn.new_client(Uuid::new_v4()).await?;
        assert!(txn.get_client().await?.unwrap().snapshot.is_none());

        let snap = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2013-10-08T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 3,
        };
        txn.set_snapshot(snap.clone(), vec![9, 8, 9]).await?;

        assert_eq!(
            txn.get_snapshot_data(snap.version_id).await?.unwrap(),
            vec![9, 8, 9]
        );
        assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snap));

        let snap2 = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2014-11-28T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 10,
        };
        txn.set_snapshot(snap2.clone(), vec![0, 2, 4, 6]).await?;

        assert_eq!(
            txn.get_snapshot_data(snap2.version_id).await?.unwrap(),
            vec![0, 2, 4, 6]
        );
        assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snap2));

        // check that mismatched version is detected
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await.is_err());

        Ok(())
    }

    /// Asking for snapshot data of a client that does not exist yields `None`.
    #[tokio::test]
    async fn test_get_snapshot_data_no_client() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let mut txn = storage.txn(Uuid::new_v4()).await?;
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await?.is_none());
        Ok(())
    }

    /// Asking for snapshot data of a client that has no snapshot is an error.
    #[tokio::test]
    async fn test_get_snapshot_data_no_snapshot() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let mut txn = storage.txn(Uuid::new_v4()).await?;
        txn.new_client(Uuid::new_v4()).await?;
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await.is_err());
        Ok(())
    }

    /// When an add_version call specifies a `parent_version_id` that does not exist in the
    /// DB, but no other versions exist, the call succeeds.
    #[tokio::test]
    async fn test_add_version_no_history() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        txn.new_client(Uuid::nil()).await?;

        let version_id = Uuid::new_v4();
        let parent_version_id = Uuid::new_v4();
        txn.add_version(version_id, parent_version_id, b"v1".to_vec())
            .await?;
        Ok(())
    }

    /// Changes made in a committed transaction are visible to later transactions.
    #[tokio::test]
    async fn test_commit_persists() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let latest_version_id = Uuid::new_v4();
        {
            let mut txn = storage.txn(client_id).await?;
            txn.new_client(latest_version_id).await?;
            txn.commit().await?;
        }

        let mut txn = storage.txn(client_id).await?;
        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, latest_version_id);
        Ok(())
    }

    /// Changes made in a transaction that is dropped without `commit` are discarded.
    #[tokio::test]
    async fn test_uncommitted_txn_discarded() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        {
            let mut txn = storage.txn(client_id).await?;
            txn.new_client(Uuid::new_v4()).await?;
            // dropped here without commit
        }

        let mut txn = storage.txn(client_id).await?;
        assert!(txn.get_client().await?.is_none());
        Ok(())
    }
}
510}