taskchampion_sync_server_storage_sqlite/
lib.rs

1//! This crate implements a SQLite storage backend for the TaskChampion sync server.
2//!
3//! Use the [`SqliteStorage`] type as an implementation of the [`Storage`] trait.
4//!
5//! This crate is intended for small deployments of a sync server, supporting one or a small number
6//! of users. The schema for the database is considered an implementation detail. For more robust
7//! database support, consider `taskchampion-sync-server-storage-postgres`.
8
9use anyhow::Context;
10use chrono::{TimeZone, Utc};
11use rusqlite::types::{FromSql, ToSql};
12use rusqlite::{params, Connection, OptionalExtension};
13use std::path::Path;
14use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version};
15use uuid::Uuid;
16
17/// Newtype to allow implementing `FromSql` for foreign `uuid::Uuid`
18struct StoredUuid(Uuid);
19
20/// Conversion from Uuid stored as a string (rusqlite's uuid feature stores as binary blob)
21impl FromSql for StoredUuid {
22    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
23        let u = Uuid::parse_str(value.as_str()?)
24            .map_err(|_| rusqlite::types::FromSqlError::InvalidType)?;
25        Ok(StoredUuid(u))
26    }
27}
28
29/// Store Uuid as string in database
30impl ToSql for StoredUuid {
31    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
32        let s = self.0.to_string();
33        Ok(s.into())
34    }
35}
36
37/// An on-disk storage backend which uses SQLite.
38///
39/// A new connection is opened for each transaction, and only one transaction may be active at a
40/// time; a second call to `txn` will block until the first transaction is dropped.
41pub struct SqliteStorage {
42    db_file: std::path::PathBuf,
43}
44
45impl SqliteStorage {
46    fn new_connection(&self) -> anyhow::Result<Connection> {
47        Ok(Connection::open(&self.db_file)?)
48    }
49
50    /// Create a new instance using a database at the given directory.
51    ///
52    /// The database will be stored in a file named `taskchampion-sync-server.sqlite3` in the given
53    /// directory. The database will be created if it does not exist.
54    pub fn new<P: AsRef<Path>>(directory: P) -> anyhow::Result<SqliteStorage> {
55        std::fs::create_dir_all(&directory)
56            .with_context(|| format!("Failed to create `{}`.", directory.as_ref().display()))?;
57        let db_file = directory.as_ref().join("taskchampion-sync-server.sqlite3");
58
59        let o = SqliteStorage { db_file };
60
61        let con = o.new_connection()?;
62
63        // Use the modern WAL mode.
64        con.query_row("PRAGMA journal_mode=WAL", [], |_row| Ok(()))
65            .context("Setting journal_mode=WAL")?;
66
67        let queries = vec![
68                "CREATE TABLE IF NOT EXISTS clients (
69                    client_id STRING PRIMARY KEY,
70                    latest_version_id STRING,
71                    snapshot_version_id STRING,
72                    versions_since_snapshot INTEGER,
73                    snapshot_timestamp INTEGER,
74                    snapshot BLOB);",
75                "CREATE TABLE IF NOT EXISTS versions (version_id STRING PRIMARY KEY, client_id STRING, parent_version_id STRING, history_segment BLOB);",
76                "CREATE INDEX IF NOT EXISTS versions_by_parent ON versions (parent_version_id);",
77            ];
78        for q in queries {
79            con.execute(q, [])
80                .context("Error while creating SQLite tables")?;
81        }
82
83        Ok(o)
84    }
85}
86
87#[async_trait::async_trait]
88impl Storage for SqliteStorage {
89    async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
90        let con = self.new_connection()?;
91        // Begin the transaction on this new connection. An IMMEDIATE connection is in
92        // write (exclusive) mode from the start.
93        con.execute("BEGIN IMMEDIATE", [])?;
94        let txn = Txn { con, client_id };
95        Ok(Box::new(txn))
96    }
97}
98
99struct Txn {
100    // SQLite only allows one concurrent transaction per connection, and rusqlite emulates
101    // transactions by running `BEGIN ...` and `COMMIT` at appropriate times. So we will do
102    // the same.
103    con: Connection,
104    client_id: Uuid,
105}
106
107impl Txn {
108    /// Implementation for queries from the versions table
109    fn get_version_impl(
110        &mut self,
111        query: &'static str,
112        client_id: Uuid,
113        version_id_arg: Uuid,
114    ) -> anyhow::Result<Option<Version>> {
115        let r = self
116            .con
117            .query_row(
118                query,
119                params![&StoredUuid(version_id_arg), &StoredUuid(client_id)],
120                |r| {
121                    let version_id: StoredUuid = r.get("version_id")?;
122                    let parent_version_id: StoredUuid = r.get("parent_version_id")?;
123
124                    Ok(Version {
125                        version_id: version_id.0,
126                        parent_version_id: parent_version_id.0,
127                        history_segment: r.get("history_segment")?,
128                    })
129                },
130            )
131            .optional()
132            .context("Error getting version")?;
133        Ok(r)
134    }
135}
136
137#[async_trait::async_trait(?Send)]
138impl StorageTxn for Txn {
139    async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
140        let result: Option<Client> = self
141            .con
142            .query_row(
143                "SELECT
144                    latest_version_id,
145                    snapshot_timestamp,
146                    versions_since_snapshot,
147                    snapshot_version_id
148                 FROM clients
149                 WHERE client_id = ?
150                 LIMIT 1",
151                [&StoredUuid(self.client_id)],
152                |r| {
153                    let latest_version_id: StoredUuid = r.get(0)?;
154                    let snapshot_timestamp: Option<i64> = r.get(1)?;
155                    let versions_since_snapshot: Option<u32> = r.get(2)?;
156                    let snapshot_version_id: Option<StoredUuid> = r.get(3)?;
157
158                    // if all of the relevant fields are non-NULL, return a snapshot
159                    let snapshot = match (
160                        snapshot_timestamp,
161                        versions_since_snapshot,
162                        snapshot_version_id,
163                    ) {
164                        (Some(ts), Some(vs), Some(v)) => Some(Snapshot {
165                            version_id: v.0,
166                            timestamp: Utc.timestamp_opt(ts, 0).unwrap(),
167                            versions_since: vs,
168                        }),
169                        _ => None,
170                    };
171                    Ok(Client {
172                        latest_version_id: latest_version_id.0,
173                        snapshot,
174                    })
175                },
176            )
177            .optional()
178            .context("Error getting client")?;
179
180        Ok(result)
181    }
182
183    async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
184        self.con
185            .execute(
186                "INSERT INTO clients (client_id, latest_version_id) VALUES (?, ?)",
187                params![&StoredUuid(self.client_id), &StoredUuid(latest_version_id)],
188            )
189            .context("Error creating/updating client")?;
190        Ok(())
191    }
192
193    async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
194        self.con
195            .execute(
196                "UPDATE clients
197             SET
198               snapshot_version_id = ?,
199               snapshot_timestamp = ?,
200               versions_since_snapshot = ?,
201               snapshot = ?
202             WHERE client_id = ?",
203                params![
204                    &StoredUuid(snapshot.version_id),
205                    snapshot.timestamp.timestamp(),
206                    snapshot.versions_since,
207                    data,
208                    &StoredUuid(self.client_id),
209                ],
210            )
211            .context("Error creating/updating snapshot")?;
212        Ok(())
213    }
214
215    async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
216        let r = self
217            .con
218            .query_row(
219                "SELECT snapshot, snapshot_version_id FROM clients WHERE client_id = ?",
220                params![&StoredUuid(self.client_id)],
221                |r| {
222                    let v: StoredUuid = r.get("snapshot_version_id")?;
223                    let d: Vec<u8> = r.get("snapshot")?;
224                    Ok((v.0, d))
225                },
226            )
227            .optional()
228            .context("Error getting snapshot")?;
229        r.map(|(v, d)| {
230            if v != version_id {
231                return Err(anyhow::anyhow!("unexpected snapshot_version_id"));
232            }
233
234            Ok(d)
235        })
236        .transpose()
237    }
238
239    async fn get_version_by_parent(
240        &mut self,
241        parent_version_id: Uuid,
242    ) -> anyhow::Result<Option<Version>> {
243        self.get_version_impl(
244            "SELECT version_id, parent_version_id, history_segment FROM versions WHERE parent_version_id = ? AND client_id = ?",
245            self.client_id,
246            parent_version_id)
247    }
248
249    async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
250        self.get_version_impl(
251            "SELECT version_id, parent_version_id, history_segment FROM versions WHERE version_id = ? AND client_id = ?",
252            self.client_id,
253            version_id)
254    }
255
256    async fn add_version(
257        &mut self,
258        version_id: Uuid,
259        parent_version_id: Uuid,
260        history_segment: Vec<u8>,
261    ) -> anyhow::Result<()> {
262        self.con.execute(
263            "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment) VALUES(?, ?, ?, ?)",
264            params![
265                StoredUuid(version_id),
266                StoredUuid(self.client_id),
267                StoredUuid(parent_version_id),
268                history_segment
269            ]
270        )
271        .context("Error adding version")?;
272        let rows_changed = self
273            .con
274            .execute(
275                "UPDATE clients
276             SET
277               latest_version_id = ?,
278               versions_since_snapshot = versions_since_snapshot + 1
279             WHERE client_id = ? and latest_version_id = ?",
280                params![
281                    StoredUuid(version_id),
282                    StoredUuid(self.client_id),
283                    StoredUuid(parent_version_id)
284                ],
285            )
286            .context("Error updating client for new version")?;
287
288        if rows_changed == 0 {
289            anyhow::bail!("clients.latest_version_id does not match parent_version_id");
290        }
291
292        Ok(())
293    }
294
295    async fn commit(&mut self) -> anyhow::Result<()> {
296        self.con.execute("COMMIT", [])?;
297        Ok(())
298    }
299}
300
/// Tests for the SQLite storage backend. Each test works on a database in its own
/// temporary directory.
#[cfg(test)]
mod test {
    use super::*;
    use chrono::DateTime;
    use pretty_assertions::assert_eq;
    use tempfile::TempDir;

    /// `SqliteStorage::new` creates the target directory when it does not exist yet.
    #[tokio::test]
    async fn test_empty_dir() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let nonexistent = tmp_dir.path().join("subdir");
        let storage = SqliteStorage::new(nonexistent)?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_client = txn.get_client().await?;
        assert!(maybe_client.is_none());
        Ok(())
    }

    /// An unknown client is reported as `None`.
    #[tokio::test]
    async fn test_get_client_empty() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_client = txn.get_client().await?;
        assert!(maybe_client.is_none());
        Ok(())
    }

    /// Client rows reflect `new_client`, `add_version` and `set_snapshot` in turn.
    #[tokio::test]
    async fn test_client_storage() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let latest_version_id = Uuid::new_v4();
        txn.new_client(latest_version_id).await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, latest_version_id);
        assert!(client.snapshot.is_none());

        let new_version_id = Uuid::new_v4();
        txn.add_version(new_version_id, latest_version_id, vec![1, 1])
            .await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, new_version_id);
        assert!(client.snapshot.is_none());

        let snap = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2014-11-28T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 4,
        };
        txn.set_snapshot(snap.clone(), vec![1, 2, 3]).await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, new_version_id);
        assert_eq!(client.snapshot.unwrap(), snap);

        Ok(())
    }

    /// Changes made before `commit` are visible to a later transaction.
    #[tokio::test]
    async fn test_commit_persists() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let latest_version_id = Uuid::new_v4();

        {
            // Scoped so the first transaction is dropped before the next one begins.
            let mut txn = storage.txn(client_id).await?;
            txn.new_client(latest_version_id).await?;
            txn.commit().await?;
        }

        let mut txn = storage.txn(client_id).await?;
        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, latest_version_id);
        assert!(client.snapshot.is_none());
        Ok(())
    }

    /// Changes made in a transaction that is dropped without `commit` are discarded.
    #[tokio::test]
    async fn test_drop_without_commit_rolls_back() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();

        {
            let mut txn = storage.txn(client_id).await?;
            txn.new_client(Uuid::new_v4()).await?;
            // no commit
        }

        let mut txn = storage.txn(client_id).await?;
        assert!(txn.get_client().await?.is_none());
        Ok(())
    }

    /// Looking up a version by an unknown parent yields `None`.
    #[tokio::test]
    async fn test_gvbp_empty() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_version = txn.get_version_by_parent(Uuid::new_v4()).await?;
        assert!(maybe_version.is_none());
        Ok(())
    }

    /// A version that was added can be fetched both by parent and by its own id.
    #[tokio::test]
    async fn test_add_version_and_get_version() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let parent_version_id = Uuid::new_v4();
        txn.new_client(parent_version_id).await?;

        let version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        txn.add_version(version_id, parent_version_id, history_segment.clone())
            .await?;

        let expected = Version {
            version_id,
            parent_version_id,
            history_segment,
        };

        let version = txn.get_version_by_parent(parent_version_id).await?.unwrap();
        assert_eq!(version, expected);

        let version = txn.get_version(version_id).await?.unwrap();
        assert_eq!(version, expected);

        Ok(())
    }

    /// Adding the same version twice fails.
    #[tokio::test]
    async fn test_add_version_exists() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let parent_version_id = Uuid::new_v4();
        txn.new_client(parent_version_id).await?;

        let version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        txn.add_version(version_id, parent_version_id, history_segment.clone())
            .await?;
        // Fails because the version already exists.
        assert!(txn
            .add_version(version_id, parent_version_id, history_segment)
            .await
            .is_err());
        Ok(())
    }

    /// Adding a version whose parent is not the client's latest version fails.
    #[tokio::test]
    async fn test_add_version_mismatch() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let latest_version_id = Uuid::new_v4();
        txn.new_client(latest_version_id).await?;

        let version_id = Uuid::new_v4();
        let parent_version_id = Uuid::new_v4(); // != latest_version_id
        let history_segment = b"abc".to_vec();
        // Fails because the latest_version_id is not parent_version_id.
        assert!(txn
            .add_version(version_id, parent_version_id, history_segment)
            .await
            .is_err());
        Ok(())
    }

    /// Snapshots can be set, replaced and read back, and a version mismatch is rejected.
    #[tokio::test]
    async fn test_snapshots() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        txn.new_client(Uuid::new_v4()).await?;
        assert!(txn.get_client().await?.unwrap().snapshot.is_none());

        let snap = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2013-10-08T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 3,
        };
        txn.set_snapshot(snap.clone(), vec![9, 8, 9]).await?;

        assert_eq!(
            txn.get_snapshot_data(snap.version_id).await?.unwrap(),
            vec![9, 8, 9]
        );
        assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snap));

        let snap2 = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2014-11-28T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 10,
        };
        txn.set_snapshot(snap2.clone(), vec![0, 2, 4, 6]).await?;

        assert_eq!(
            txn.get_snapshot_data(snap2.version_id).await?.unwrap(),
            vec![0, 2, 4, 6]
        );
        assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snap2));

        // check that mismatched version is detected
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await.is_err());

        Ok(())
    }
}