taskchampion_sync_server_storage_sqlite/
lib.rs1use anyhow::Context;
10use chrono::{TimeZone, Utc};
11use rusqlite::types::{FromSql, ToSql};
12use rusqlite::{params, Connection, OptionalExtension};
13use std::path::Path;
14use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version};
15use uuid::Uuid;
16
17struct StoredUuid(Uuid);
19
20impl FromSql for StoredUuid {
22 fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
23 let u = Uuid::parse_str(value.as_str()?)
24 .map_err(|_| rusqlite::types::FromSqlError::InvalidType)?;
25 Ok(StoredUuid(u))
26 }
27}
28
29impl ToSql for StoredUuid {
31 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
32 let s = self.0.to_string();
33 Ok(s.into())
34 }
35}
36
37pub struct SqliteStorage {
42 db_file: std::path::PathBuf,
43}
44
45impl SqliteStorage {
46 fn new_connection(&self) -> anyhow::Result<Connection> {
47 Ok(Connection::open(&self.db_file)?)
48 }
49
50 pub fn new<P: AsRef<Path>>(directory: P) -> anyhow::Result<SqliteStorage> {
55 std::fs::create_dir_all(&directory)
56 .with_context(|| format!("Failed to create `{}`.", directory.as_ref().display()))?;
57 let db_file = directory.as_ref().join("taskchampion-sync-server.sqlite3");
58
59 let o = SqliteStorage { db_file };
60
61 let con = o.new_connection()?;
62
63 con.query_row("PRAGMA journal_mode=WAL", [], |_row| Ok(()))
65 .context("Setting journal_mode=WAL")?;
66
67 let queries = vec![
68 "CREATE TABLE IF NOT EXISTS clients (
69 client_id STRING PRIMARY KEY,
70 latest_version_id STRING,
71 snapshot_version_id STRING,
72 versions_since_snapshot INTEGER,
73 snapshot_timestamp INTEGER,
74 snapshot BLOB);",
75 "CREATE TABLE IF NOT EXISTS versions (version_id STRING PRIMARY KEY, client_id STRING, parent_version_id STRING, history_segment BLOB);",
76 "CREATE INDEX IF NOT EXISTS versions_by_parent ON versions (parent_version_id);",
77 ];
78 for q in queries {
79 con.execute(q, [])
80 .context("Error while creating SQLite tables")?;
81 }
82
83 Ok(o)
84 }
85}
86
87#[async_trait::async_trait]
88impl Storage for SqliteStorage {
89 async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
90 let con = self.new_connection()?;
91 con.execute("BEGIN IMMEDIATE", [])?;
94 let txn = Txn { con, client_id };
95 Ok(Box::new(txn))
96 }
97}
98
99struct Txn {
100 con: Connection,
104 client_id: Uuid,
105}
106
107impl Txn {
108 fn get_version_impl(
110 &mut self,
111 query: &'static str,
112 client_id: Uuid,
113 version_id_arg: Uuid,
114 ) -> anyhow::Result<Option<Version>> {
115 let r = self
116 .con
117 .query_row(
118 query,
119 params![&StoredUuid(version_id_arg), &StoredUuid(client_id)],
120 |r| {
121 let version_id: StoredUuid = r.get("version_id")?;
122 let parent_version_id: StoredUuid = r.get("parent_version_id")?;
123
124 Ok(Version {
125 version_id: version_id.0,
126 parent_version_id: parent_version_id.0,
127 history_segment: r.get("history_segment")?,
128 })
129 },
130 )
131 .optional()
132 .context("Error getting version")?;
133 Ok(r)
134 }
135}
136
137#[async_trait::async_trait(?Send)]
138impl StorageTxn for Txn {
139 async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
140 let result: Option<Client> = self
141 .con
142 .query_row(
143 "SELECT
144 latest_version_id,
145 snapshot_timestamp,
146 versions_since_snapshot,
147 snapshot_version_id
148 FROM clients
149 WHERE client_id = ?
150 LIMIT 1",
151 [&StoredUuid(self.client_id)],
152 |r| {
153 let latest_version_id: StoredUuid = r.get(0)?;
154 let snapshot_timestamp: Option<i64> = r.get(1)?;
155 let versions_since_snapshot: Option<u32> = r.get(2)?;
156 let snapshot_version_id: Option<StoredUuid> = r.get(3)?;
157
158 let snapshot = match (
160 snapshot_timestamp,
161 versions_since_snapshot,
162 snapshot_version_id,
163 ) {
164 (Some(ts), Some(vs), Some(v)) => Some(Snapshot {
165 version_id: v.0,
166 timestamp: Utc.timestamp_opt(ts, 0).unwrap(),
167 versions_since: vs,
168 }),
169 _ => None,
170 };
171 Ok(Client {
172 latest_version_id: latest_version_id.0,
173 snapshot,
174 })
175 },
176 )
177 .optional()
178 .context("Error getting client")?;
179
180 Ok(result)
181 }
182
183 async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
184 self.con
185 .execute(
186 "INSERT INTO clients (client_id, latest_version_id) VALUES (?, ?)",
187 params![&StoredUuid(self.client_id), &StoredUuid(latest_version_id)],
188 )
189 .context("Error creating/updating client")?;
190 Ok(())
191 }
192
193 async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
194 self.con
195 .execute(
196 "UPDATE clients
197 SET
198 snapshot_version_id = ?,
199 snapshot_timestamp = ?,
200 versions_since_snapshot = ?,
201 snapshot = ?
202 WHERE client_id = ?",
203 params![
204 &StoredUuid(snapshot.version_id),
205 snapshot.timestamp.timestamp(),
206 snapshot.versions_since,
207 data,
208 &StoredUuid(self.client_id),
209 ],
210 )
211 .context("Error creating/updating snapshot")?;
212 Ok(())
213 }
214
215 async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
216 let r = self
217 .con
218 .query_row(
219 "SELECT snapshot, snapshot_version_id FROM clients WHERE client_id = ?",
220 params![&StoredUuid(self.client_id)],
221 |r| {
222 let v: StoredUuid = r.get("snapshot_version_id")?;
223 let d: Vec<u8> = r.get("snapshot")?;
224 Ok((v.0, d))
225 },
226 )
227 .optional()
228 .context("Error getting snapshot")?;
229 r.map(|(v, d)| {
230 if v != version_id {
231 return Err(anyhow::anyhow!("unexpected snapshot_version_id"));
232 }
233
234 Ok(d)
235 })
236 .transpose()
237 }
238
239 async fn get_version_by_parent(
240 &mut self,
241 parent_version_id: Uuid,
242 ) -> anyhow::Result<Option<Version>> {
243 self.get_version_impl(
244 "SELECT version_id, parent_version_id, history_segment FROM versions WHERE parent_version_id = ? AND client_id = ?",
245 self.client_id,
246 parent_version_id)
247 }
248
249 async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
250 self.get_version_impl(
251 "SELECT version_id, parent_version_id, history_segment FROM versions WHERE version_id = ? AND client_id = ?",
252 self.client_id,
253 version_id)
254 }
255
256 async fn add_version(
257 &mut self,
258 version_id: Uuid,
259 parent_version_id: Uuid,
260 history_segment: Vec<u8>,
261 ) -> anyhow::Result<()> {
262 self.con.execute(
263 "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment) VALUES(?, ?, ?, ?)",
264 params![
265 StoredUuid(version_id),
266 StoredUuid(self.client_id),
267 StoredUuid(parent_version_id),
268 history_segment
269 ]
270 )
271 .context("Error adding version")?;
272 let rows_changed = self
273 .con
274 .execute(
275 "UPDATE clients
276 SET
277 latest_version_id = ?,
278 versions_since_snapshot = versions_since_snapshot + 1
279 WHERE client_id = ? and latest_version_id = ?",
280 params![
281 StoredUuid(version_id),
282 StoredUuid(self.client_id),
283 StoredUuid(parent_version_id)
284 ],
285 )
286 .context("Error updating client for new version")?;
287
288 if rows_changed == 0 {
289 anyhow::bail!("clients.latest_version_id does not match parent_version_id");
290 }
291
292 Ok(())
293 }
294
295 async fn commit(&mut self) -> anyhow::Result<()> {
296 self.con.execute("COMMIT", [])?;
297 Ok(())
298 }
299}
300
#[cfg(test)]
mod test {
    use super::*;
    use chrono::DateTime;
    use pretty_assertions::assert_eq;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_empty_dir() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        // `SqliteStorage::new` must create a missing storage directory itself.
        let non_existent = tmp_dir.path().join("subdir");
        let storage = SqliteStorage::new(non_existent)?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_client = txn.get_client().await?;
        assert!(maybe_client.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_get_client_empty() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_client = txn.get_client().await?;
        assert!(maybe_client.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_client_storage() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let latest_version_id = Uuid::new_v4();
        txn.new_client(latest_version_id).await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, latest_version_id);
        assert!(client.snapshot.is_none());

        let new_version_id = Uuid::new_v4();
        txn.add_version(new_version_id, latest_version_id, vec![1, 1])
            .await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, new_version_id);
        assert!(client.snapshot.is_none());

        let snap = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2014-11-28T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 4,
        };
        txn.set_snapshot(snap.clone(), vec![1, 2, 3]).await?;

        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, new_version_id);
        assert_eq!(client.snapshot.unwrap(), snap);

        Ok(())
    }

    #[tokio::test]
    async fn test_commit_persists() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let latest_version_id = Uuid::new_v4();

        let mut txn = storage.txn(client_id).await?;
        txn.new_client(latest_version_id).await?;
        txn.commit().await?;
        // Release the connection before opening another transaction.
        drop(txn);

        let mut txn = storage.txn(client_id).await?;
        let client = txn.get_client().await?.unwrap();
        assert_eq!(client.latest_version_id, latest_version_id);
        Ok(())
    }

    #[tokio::test]
    async fn test_uncommitted_changes_discarded() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();

        let mut txn = storage.txn(client_id).await?;
        txn.new_client(Uuid::new_v4()).await?;
        // Dropping without `commit` closes the connection and discards the writes.
        drop(txn);

        let mut txn = storage.txn(client_id).await?;
        assert!(txn.get_client().await?.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_gvbp_empty() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;
        let maybe_version = txn.get_version_by_parent(Uuid::new_v4()).await?;
        assert!(maybe_version.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_and_get_version() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let parent_version_id = Uuid::new_v4();
        txn.new_client(parent_version_id).await?;

        let version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        txn.add_version(version_id, parent_version_id, history_segment.clone())
            .await?;

        let expected = Version {
            version_id,
            parent_version_id,
            history_segment,
        };

        let version = txn.get_version_by_parent(parent_version_id).await?.unwrap();
        assert_eq!(version, expected);

        let version = txn.get_version(version_id).await?.unwrap();
        assert_eq!(version, expected);

        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_exists() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let parent_version_id = Uuid::new_v4();
        txn.new_client(parent_version_id).await?;

        let version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        txn.add_version(version_id, parent_version_id, history_segment.clone())
            .await?;
        // Re-adding the same version id must fail.
        assert!(txn
            .add_version(version_id, parent_version_id, history_segment.clone())
            .await
            .is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_mismatch() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        let latest_version_id = Uuid::new_v4();
        txn.new_client(latest_version_id).await?;

        let version_id = Uuid::new_v4();
        // A parent that is not the client's latest version must be rejected.
        let parent_version_id = Uuid::new_v4();
        let history_segment = b"abc".to_vec();
        assert!(txn
            .add_version(version_id, parent_version_id, history_segment.clone())
            .await
            .is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_get_snapshot_data_without_snapshot() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        // No client row at all: nothing to return.
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await?.is_none());

        // Client exists but has no snapshot: asking for one is an error.
        txn.new_client(Uuid::new_v4()).await?;
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_snapshots() -> anyhow::Result<()> {
        let tmp_dir = TempDir::new()?;
        let storage = SqliteStorage::new(tmp_dir.path())?;
        let client_id = Uuid::new_v4();
        let mut txn = storage.txn(client_id).await?;

        txn.new_client(Uuid::new_v4()).await?;
        assert!(txn.get_client().await?.unwrap().snapshot.is_none());

        let snap = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2013-10-08T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 3,
        };
        txn.set_snapshot(snap.clone(), vec![9, 8, 9]).await?;

        assert_eq!(
            txn.get_snapshot_data(snap.version_id).await?.unwrap(),
            vec![9, 8, 9]
        );
        assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snap));

        let snap2 = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: "2014-11-28T12:00:09Z".parse::<DateTime<Utc>>().unwrap(),
            versions_since: 10,
        };
        txn.set_snapshot(snap2.clone(), vec![0, 2, 4, 6]).await?;

        assert_eq!(
            txn.get_snapshot_data(snap2.version_id).await?.unwrap(),
            vec![0, 2, 4, 6]
        );
        assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snap2));

        // Asking for a snapshot of some other version is an error.
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await.is_err());

        Ok(())
    }
}