taskchampion_sync_server_storage_postgres/
lib.rs

1//! This crate implements a Postgres storage backend for the TaskChampion sync server.
2//!
3//! Use the [`PostgresStorage`] type as an implementation of the [`Storage`] trait.
4//!
5//! This implementation is tested with Postgres version 17 but should work with any recent version.
6//!
7//! ## Schema Setup
8//!
9//! The database identified by the connection string must already exist and be set up with the
10//! following schema (also available in `postgres/schema.sql` in the repository):
11//!
12//! ```sql
13#![doc=include_str!("../schema.sql")]
14//! ```
15//!
16//! ## Integration with External Applications
17//!
18//! The schema is stable, and any changes to the schema will be made in a major version with
19//! migration instructions provided.
20//!
21//! An external application may:
22//!  - Add additional tables to the database
//!  - Add additional columns to the `clients` table. If those columns do not have default
//!    values, calls to `StorageTxn::new_client` will fail. It is possible to configure
//!    `taskchampion-sync-server` to never call this method.
26//!  - Insert rows into the `clients` table, using default values for all columns except
27//!    `client_id` and application-specific columns.
28//!  - Delete rows from the `clients` table, noting that any associated task data
29//!    is also deleted.
30
31use anyhow::Context;
32use bb8::PooledConnection;
33use bb8_postgres::PostgresConnectionManager;
34use chrono::{TimeZone, Utc};
35use postgres_native_tls::MakeTlsConnector;
36use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version};
37use uuid::Uuid;
38
39#[cfg(test)]
40mod testing;
41
/// A connection-pool error sink that reports every error through the `log` crate.
///
/// It carries no state, so copies are interchangeable.
#[derive(Debug, Clone, Copy)]
pub struct LogErrorSink;

impl LogErrorSink {
    /// Build a boxed sink, in the form expected by `bb8::Builder::error_sink`.
    fn new() -> Box<Self> {
        let sink = LogErrorSink;
        Box::new(sink)
    }
}
51
52impl bb8::ErrorSink<tokio_postgres::Error> for LogErrorSink {
53    fn sink(&self, e: tokio_postgres::Error) {
54        log::error!("Postgres connection error: {e}");
55    }
56
57    fn boxed_clone(&self) -> Box<dyn bb8::ErrorSink<tokio_postgres::Error>> {
58        Self::new()
59    }
60}
61
62/// A storage backend which uses Postgres.
63pub struct PostgresStorage {
64    pool: bb8::Pool<PostgresConnectionManager<MakeTlsConnector>>,
65}
66
67impl PostgresStorage {
68    pub async fn new(connection_string: impl ToString) -> anyhow::Result<Self> {
69        let connector = native_tls::TlsConnector::new()?;
70        let connector = postgres_native_tls::MakeTlsConnector::new(connector);
71        let manager = PostgresConnectionManager::new_from_stringlike(connection_string, connector)?;
72        let pool = bb8::Pool::builder()
73            .error_sink(LogErrorSink::new())
74            .build(manager)
75            .await?;
76        Ok(Self { pool })
77    }
78}
79
80#[async_trait::async_trait]
81impl Storage for PostgresStorage {
82    async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
83        let db_client = self.pool.get_owned().await?;
84
85        db_client
86            .execute("BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE", &[])
87            .await?;
88
89        Ok(Box::new(Txn {
90            client_id,
91            db_client: Some(db_client),
92        }))
93    }
94}
95
96struct Txn {
97    client_id: Uuid,
98    /// The DB client or, if `commit` has been called, None. This ensures queries aren't executed
99    /// after commit, and also frees connections back to the pool as quickly as possible.
100    db_client: Option<PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>>,
101}
102
103impl Txn {
104    /// Get the db_client, or panic if it is gone (after commit).
105    fn db_client(&self) -> &tokio_postgres::Client {
106        let Some(db_client) = &self.db_client else {
107            panic!("Cannot use a postgres Txn after commit");
108        };
109        db_client
110    }
111
112    /// Implementation for queries from the versions table
113    async fn get_version_impl(
114        &mut self,
115        query: &'static str,
116        client_id: Uuid,
117        version_id_arg: Uuid,
118    ) -> anyhow::Result<Option<Version>> {
119        Ok(self
120            .db_client()
121            .query_opt(query, &[&version_id_arg, &client_id])
122            .await
123            .context("error getting version")?
124            .map(|r| Version {
125                version_id: r.get(0),
126                parent_version_id: r.get(1),
127                history_segment: r.get("history_segment"),
128            }))
129    }
130}
131
132#[async_trait::async_trait(?Send)]
133impl StorageTxn for Txn {
134    async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
135        Ok(self
136            .db_client()
137            .query_opt(
138                "SELECT
139                    latest_version_id,
140                    snapshot_timestamp,
141                    versions_since_snapshot,
142                    snapshot_version_id
143                 FROM clients
144                 WHERE client_id = $1
145                 LIMIT 1",
146                &[&self.client_id],
147            )
148            .await
149            .context("error getting client")?
150            .map(|r| {
151                let latest_version_id: Uuid = r.get(0);
152                let snapshot_timestamp: Option<i64> = r.get(1);
153                let versions_since_snapshot: Option<i32> = r.get(2);
154                let snapshot_version_id: Option<Uuid> = r.get(3);
155
156                // if all of the relevant fields are non-NULL, return a snapshot
157                let snapshot = match (
158                    snapshot_timestamp,
159                    versions_since_snapshot,
160                    snapshot_version_id,
161                ) {
162                    (Some(ts), Some(vs), Some(v)) => Some(Snapshot {
163                        version_id: v,
164                        timestamp: Utc.timestamp_opt(ts, 0).unwrap(),
165                        versions_since: vs as u32,
166                    }),
167                    _ => None,
168                };
169                Client {
170                    latest_version_id,
171                    snapshot,
172                }
173            }))
174    }
175
176    async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
177        self.db_client()
178            .execute(
179                "INSERT INTO clients (client_id, latest_version_id) VALUES ($1, $2)",
180                &[&self.client_id, &latest_version_id],
181            )
182            .await
183            .context("error creating/updating client")?;
184        Ok(())
185    }
186
187    async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
188        let timestamp = snapshot.timestamp.timestamp();
189        self.db_client()
190            .execute(
191                "UPDATE clients
192                    SET snapshot_version_id = $1,
193                        versions_since_snapshot = $2,
194                        snapshot_timestamp = $3,
195                        snapshot = $4
196                    WHERE client_id = $5",
197                &[
198                    &snapshot.version_id,
199                    &(snapshot.versions_since as i32),
200                    &timestamp,
201                    &data,
202                    &self.client_id,
203                ],
204            )
205            .await
206            .context("error setting snapshot")?;
207        Ok(())
208    }
209
210    async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
211        Ok(self
212            .db_client()
213            .query_opt(
214                "SELECT snapshot
215                 FROM clients
216                 WHERE client_id = $1 and snapshot_version_id = $2
217                 LIMIT 1",
218                &[&self.client_id, &version_id],
219            )
220            .await
221            .context("error getting snapshot data")?
222            .map(|r| r.get(0)))
223    }
224
225    async fn get_version_by_parent(
226        &mut self,
227        parent_version_id: Uuid,
228    ) -> anyhow::Result<Option<Version>> {
229        self.get_version_impl(
230            "SELECT version_id, parent_version_id, history_segment
231                FROM versions
232                WHERE parent_version_id = $1 AND client_id = $2",
233            self.client_id,
234            parent_version_id,
235        )
236        .await
237    }
238
239    async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
240        self.get_version_impl(
241            "SELECT version_id, parent_version_id, history_segment
242                FROM versions
243                WHERE version_id = $1 AND client_id = $2",
244            self.client_id,
245            version_id,
246        )
247        .await
248    }
249
250    async fn add_version(
251        &mut self,
252        version_id: Uuid,
253        parent_version_id: Uuid,
254        history_segment: Vec<u8>,
255    ) -> anyhow::Result<()> {
256        self.db_client()
257            .execute(
258                "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment)
259                VALUES ($1, $2, $3, $4)",
260                &[
261                    &version_id,
262                    &self.client_id,
263                    &parent_version_id,
264                    &history_segment,
265                ],
266            )
267            .await
268            .context("error inserting new version")?;
269        let rows_modified = self
270            .db_client()
271            .execute(
272                "UPDATE clients
273                    SET latest_version_id = $1,
274                        versions_since_snapshot = versions_since_snapshot + 1
275                    WHERE client_id = $2 and (latest_version_id = $3 or latest_version_id = $4)",
276                &[
277                    &version_id,
278                    &self.client_id,
279                    &parent_version_id,
280                    &Uuid::nil(),
281                ],
282            )
283            .await
284            .context("error updating latest_version_id")?;
285
286        // If no rows were modified, this operation failed.
287        if rows_modified == 0 {
288            anyhow::bail!("clients.latest_version_id does not match parent_version_id");
289        }
290        Ok(())
291    }
292
293    async fn commit(&mut self) -> anyhow::Result<()> {
294        self.db_client().execute("COMMIT", &[]).await?;
295        self.db_client = None;
296        Ok(())
297    }
298}
299
#[cfg(test)]
mod test {
    use super::*;
    use crate::testing::with_db;

    /// Insert a client row (all other columns take their defaults) and return its ID.
    async fn make_client(db_client: &tokio_postgres::Client) -> anyhow::Result<Uuid> {
        let client_id = Uuid::new_v4();
        db_client
            .execute("insert into clients (client_id) values ($1)", &[&client_id])
            .await?;
        Ok(client_id)
    }

    /// Insert a version row for `client_id` and return the new version's ID.
    async fn make_version(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        parent_version_id: Uuid,
        history_segment: &[u8],
    ) -> anyhow::Result<Uuid> {
        let version_id = Uuid::new_v4();
        db_client
            .execute(
                "insert into versions
                    (version_id, client_id, parent_version_id, history_segment)
                    values ($1, $2, $3, $4)",
                &[
                    &version_id,
                    &client_id,
                    &parent_version_id,
                    &history_segment,
                ],
            )
            .await?;
        Ok(version_id)
    }

    /// Overwrite `latest_version_id` for an existing client.
    async fn set_client_latest_version_id(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        latest_version_id: Uuid,
    ) -> anyhow::Result<()> {
        db_client
            .execute(
                "update clients set latest_version_id = $1 where client_id = $2",
                &[&latest_version_id, &client_id],
            )
            .await?;
        Ok(())
    }

    /// Overwrite all snapshot-related columns for an existing client.
    async fn set_client_snapshot(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        snapshot_version_id: Uuid,
        versions_since_snapshot: u32,
        snapshot_timestamp: i64,
        snapshot: &[u8],
    ) -> anyhow::Result<()> {
        db_client
            .execute(
                "
                update clients
                    set snapshot_version_id = $1,
                        versions_since_snapshot = $2,
                        snapshot_timestamp = $3,
                        snapshot = $4
                    where client_id = $5",
                &[
                    &snapshot_version_id,
                    &(versions_since_snapshot as i32),
                    &snapshot_timestamp,
                    &snapshot,
                    &client_id,
                ],
            )
            .await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_get_client_none() -> anyhow::Result<()> {
        with_db(async |connection_string, _db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = Uuid::new_v4();
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(txn.get_client().await?, None);
            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_client_exists_empty() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: None
                })
            );
            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_client_exists_latest() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let latest_version_id = Uuid::new_v4();
            set_client_latest_version_id(&db_client, client_id, latest_version_id).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id,
                    snapshot: None
                })
            );
            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_client_exists_with_snapshot() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";
            set_client_snapshot(
                &db_client,
                client_id,
                snapshot_version_id,
                versions_since_snapshot,
                snapshot_timestamp,
                snapshot,
            )
            .await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: Some(Snapshot {
                        version_id: snapshot_version_id,
                        timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                        versions_since: versions_since_snapshot,
                    })
                })
            );
            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_new_client() -> anyhow::Result<()> {
        with_db(async |connection_string, _db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = Uuid::new_v4();
            let latest_version_id = Uuid::new_v4();

            let mut txn1 = storage.txn(client_id).await?;
            txn1.new_client(latest_version_id).await?;

            // Client is not visible yet as txn1 is not committed.
            let mut txn2 = storage.txn(client_id).await?;
            assert_eq!(txn2.get_client().await?, None);

            txn1.commit().await?;

            // Client is now visible.
            let mut txn2 = storage.txn(client_id).await?;
            assert_eq!(
                txn2.get_client().await?,
                Some(Client {
                    latest_version_id,
                    snapshot: None
                })
            );

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_set_snapshot() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";

            txn.set_snapshot(
                Snapshot {
                    version_id: snapshot_version_id,
                    timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                    versions_since: versions_since_snapshot,
                },
                snapshot.to_vec(),
            )
            .await?;
            txn.commit().await?;

            txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: Some(Snapshot {
                        version_id: snapshot_version_id,
                        timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                        versions_since: versions_since_snapshot,
                    })
                })
            );

            let row = db_client
                .query_one(
                    "select snapshot from clients where client_id = $1",
                    &[&client_id],
                )
                .await?;
            assert_eq!(row.get::<_, &[u8]>(0), b"abcd");

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_snapshot_none() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(txn.get_snapshot_data(Uuid::new_v4()).await?, None);

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_snapshot_mismatched_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;

            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";
            txn.set_snapshot(
                Snapshot {
                    version_id: snapshot_version_id,
                    timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                    versions_since: versions_since_snapshot,
                },
                snapshot.to_vec(),
            )
            .await?;

            assert_eq!(txn.get_snapshot_data(Uuid::new_v4()).await?, None);

            Ok(())
        })
        .await
    }

    /// Snapshot data is returned when the requested version matches the stored snapshot.
    #[tokio::test]
    async fn test_get_snapshot_found() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let snapshot_version_id = Uuid::new_v4();
            set_client_snapshot(
                &db_client,
                client_id,
                snapshot_version_id,
                10,
                10000000,
                b"abcd",
            )
            .await?;

            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_snapshot_data(snapshot_version_id).await?,
                Some(b"abcd".to_vec())
            );

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let parent_version_id = Uuid::new_v4();
            let version_id = make_version(&db_client, client_id, parent_version_id, b"v1").await?;

            let mut txn = storage.txn(client_id).await?;

            // Different parent doesn't exist.
            assert_eq!(txn.get_version_by_parent(Uuid::new_v4()).await?, None);

            // Different version doesn't exist.
            assert_eq!(txn.get_version(Uuid::new_v4()).await?, None);

            let version = Version {
                version_id,
                parent_version_id,
                history_segment: b"v1".to_vec(),
            };

            // Version found by parent.
            assert_eq!(
                txn.get_version_by_parent(parent_version_id).await?,
                Some(version.clone())
            );

            // Version found by ID.
            assert_eq!(txn.get_version(version_id).await?, Some(version));

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_add_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            txn.add_version(version_id, Uuid::nil(), b"v1".to_vec())
                .await?;
            assert_eq!(
                txn.get_version(version_id).await?,
                Some(Version {
                    version_id,
                    parent_version_id: Uuid::nil(),
                    history_segment: b"v1".to_vec()
                })
            );
            // The client's latest version now points at the new version.
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: version_id,
                    snapshot: None
                })
            );
            Ok(())
        })
        .await
    }

    /// When an add_version call specifies an incorrect `parent_version_id`, it fails. This is
    /// typically avoided by calling `get_client` beforehand, which (because reads within the
    /// transaction are repeatable) allows the caller to check the `latest_version_id` before
    /// calling `add_version`.
    #[tokio::test]
    async fn test_add_version_mismatch() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let latest_version_id = Uuid::new_v4();
            set_client_latest_version_id(&db_client, client_id, latest_version_id).await?;

            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            let parent_version_id = Uuid::new_v4(); // != latest_version_id
            let res = txn
                .add_version(version_id, parent_version_id, b"v1".to_vec())
                .await;
            assert!(res.is_err());
            Ok(())
        })
        .await
    }

    /// Adding versions to two different clients can proceed concurrently.
    #[tokio::test]
    async fn test_add_version_no_conflict_different_clients() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;

            // Clients 1 and 2 do not interfere with each other; if these are the same client, then
            // this will deadlock as one transaction waits for the other. If the postgres storage
            // implementation serialized _all_ transactions across clients, that would limit its
            // scalability.
            //
            // So the assertion here is "does not deadlock".

            let client_id1 = make_client(&db_client).await?;
            let mut txn1 = storage.txn(client_id1).await?;
            let version_id1 = Uuid::new_v4();
            txn1.add_version(version_id1, Uuid::nil(), b"v1".to_vec())
                .await?;

            let client_id2 = make_client(&db_client).await?;
            let mut txn2 = storage.txn(client_id2).await?;
            let version_id2 = Uuid::new_v4();
            txn2.add_version(version_id2, Uuid::nil(), b"v2".to_vec())
                .await?;

            txn1.commit().await?;
            txn2.commit().await?;

            Ok(())
        })
        .await
    }

    /// When an add_version call specifies a `parent_version_id` that does not exist in the
    /// DB, but no other versions exist, the call succeeds.
    #[tokio::test]
    async fn test_add_version_no_history() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;

            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            let parent_version_id = Uuid::new_v4();
            txn.add_version(version_id, parent_version_id, b"v1".to_vec())
                .await?;
            Ok(())
        })
        .await
    }
}