taskchampion_sync_server_storage_postgres/
lib.rs

1//! This crate implements a Postgres storage backend for the TaskChampion sync server.
2//!
3//! Use the [`PostgresStorage`] type as an implementation of the [`Storage`] trait.
4//!
5//! This implementation is tested with Postgres version 17 but should work with any recent version.
6//!
7//! ## Schema Setup
8//!
9//! The database identified by the connection string must already exist and be set up with the
10//! following schema (also available in `postgres/schema.sql` in the repository):
11//!
12//! ```sql
13#![doc=include_str!("../schema.sql")]
14//! ```
15//!
16//! ## Integration with External Applications
17//!
18//! The schema is stable, and any changes to the schema will be made in a major version with
19//! migration instructions provided.
20//!
21//! An external application may:
22//!  - Add additional tables to the database
23//!  - Add additional columns to the `clients` table. If those columns do not have default
24//!    values, calls to `Txn::new_client` will fail. It is possible to configure
25//!    `taskchampion-sync-server` to never call this method.
26//!  - Insert rows into the `clients` table, using default values for all columns except
27//!    `client_id` and application-specific columns.
28//!  - Delete rows from the `clients` table, noting that any associated task data
29//!    is also deleted.
30
31use anyhow::Context;
32use bb8::PooledConnection;
33use bb8_postgres::PostgresConnectionManager;
34use chrono::{TimeZone, Utc};
35use postgres_native_tls::MakeTlsConnector;
36use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version};
37use uuid::Uuid;
38
39#[cfg(test)]
40mod testing;
41
42/// A storage backend which uses Postgres.
43pub struct PostgresStorage {
44    pool: bb8::Pool<PostgresConnectionManager<MakeTlsConnector>>,
45}
46
47impl PostgresStorage {
48    pub async fn new(connection_string: impl ToString) -> anyhow::Result<Self> {
49        let connector = native_tls::TlsConnector::new()?;
50        let connector = postgres_native_tls::MakeTlsConnector::new(connector);
51        let manager = PostgresConnectionManager::new_from_stringlike(connection_string, connector)?;
52        let pool = bb8::Pool::builder().build(manager).await?;
53        Ok(Self { pool })
54    }
55}
56
57#[async_trait::async_trait]
58impl Storage for PostgresStorage {
59    async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
60        let db_client = self.pool.get_owned().await?;
61
62        db_client
63            .execute("BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE", &[])
64            .await?;
65
66        Ok(Box::new(Txn {
67            client_id,
68            db_client: Some(db_client),
69        }))
70    }
71}
72
73struct Txn {
74    client_id: Uuid,
75    /// The DB client or, if `commit` has been called, None. This ensures queries aren't executed
76    /// after commit, and also frees connections back to the pool as quickly as possible.
77    db_client: Option<PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>>,
78}
79
80impl Txn {
81    /// Get the db_client, or panic if it is gone (after commit).
82    fn db_client(&self) -> &tokio_postgres::Client {
83        let Some(db_client) = &self.db_client else {
84            panic!("Cannot use a postgres Txn after commit");
85        };
86        db_client
87    }
88
89    /// Implementation for queries from the versions table
90    async fn get_version_impl(
91        &mut self,
92        query: &'static str,
93        client_id: Uuid,
94        version_id_arg: Uuid,
95    ) -> anyhow::Result<Option<Version>> {
96        Ok(self
97            .db_client()
98            .query_opt(query, &[&version_id_arg, &client_id])
99            .await
100            .context("error getting version")?
101            .map(|r| Version {
102                version_id: r.get(0),
103                parent_version_id: r.get(1),
104                history_segment: r.get("history_segment"),
105            }))
106    }
107}
108
109#[async_trait::async_trait(?Send)]
110impl StorageTxn for Txn {
111    async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
112        Ok(self
113            .db_client()
114            .query_opt(
115                "SELECT
116                    latest_version_id,
117                    snapshot_timestamp,
118                    versions_since_snapshot,
119                    snapshot_version_id
120                 FROM clients
121                 WHERE client_id = $1
122                 LIMIT 1",
123                &[&self.client_id],
124            )
125            .await
126            .context("error getting client")?
127            .map(|r| {
128                let latest_version_id: Uuid = r.get(0);
129                let snapshot_timestamp: Option<i64> = r.get(1);
130                let versions_since_snapshot: Option<i32> = r.get(2);
131                let snapshot_version_id: Option<Uuid> = r.get(3);
132
133                // if all of the relevant fields are non-NULL, return a snapshot
134                let snapshot = match (
135                    snapshot_timestamp,
136                    versions_since_snapshot,
137                    snapshot_version_id,
138                ) {
139                    (Some(ts), Some(vs), Some(v)) => Some(Snapshot {
140                        version_id: v,
141                        timestamp: Utc.timestamp_opt(ts, 0).unwrap(),
142                        versions_since: vs as u32,
143                    }),
144                    _ => None,
145                };
146                Client {
147                    latest_version_id,
148                    snapshot,
149                }
150            }))
151    }
152
153    async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
154        self.db_client()
155            .execute(
156                "INSERT INTO clients (client_id, latest_version_id) VALUES ($1, $2)",
157                &[&self.client_id, &latest_version_id],
158            )
159            .await
160            .context("error creating/updating client")?;
161        Ok(())
162    }
163
164    async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
165        let timestamp = snapshot.timestamp.timestamp();
166        self.db_client()
167            .execute(
168                "UPDATE clients
169                    SET snapshot_version_id = $1,
170                        versions_since_snapshot = $2,
171                        snapshot_timestamp = $3,
172                        snapshot = $4
173                    WHERE client_id = $5",
174                &[
175                    &snapshot.version_id,
176                    &(snapshot.versions_since as i32),
177                    &timestamp,
178                    &data,
179                    &self.client_id,
180                ],
181            )
182            .await
183            .context("error setting snapshot")?;
184        Ok(())
185    }
186
187    async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
188        Ok(self
189            .db_client()
190            .query_opt(
191                "SELECT snapshot
192                 FROM clients
193                 WHERE client_id = $1 and snapshot_version_id = $2
194                 LIMIT 1",
195                &[&self.client_id, &version_id],
196            )
197            .await
198            .context("error getting snapshot data")?
199            .map(|r| r.get(0)))
200    }
201
202    async fn get_version_by_parent(
203        &mut self,
204        parent_version_id: Uuid,
205    ) -> anyhow::Result<Option<Version>> {
206        self.get_version_impl(
207            "SELECT version_id, parent_version_id, history_segment
208                FROM versions
209                WHERE parent_version_id = $1 AND client_id = $2",
210            self.client_id,
211            parent_version_id,
212        )
213        .await
214    }
215
216    async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
217        self.get_version_impl(
218            "SELECT version_id, parent_version_id, history_segment
219                FROM versions
220                WHERE version_id = $1 AND client_id = $2",
221            self.client_id,
222            version_id,
223        )
224        .await
225    }
226
227    async fn add_version(
228        &mut self,
229        version_id: Uuid,
230        parent_version_id: Uuid,
231        history_segment: Vec<u8>,
232    ) -> anyhow::Result<()> {
233        self.db_client()
234            .execute(
235                "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment)
236                VALUES ($1, $2, $3, $4)",
237                &[
238                    &version_id,
239                    &self.client_id,
240                    &parent_version_id,
241                    &history_segment,
242                ],
243            )
244            .await
245            .context("error inserting new version")?;
246        let rows_modified = self
247            .db_client()
248            .execute(
249                "UPDATE clients
250                    SET latest_version_id = $1,
251                        versions_since_snapshot = versions_since_snapshot + 1
252                    WHERE client_id = $2 and latest_version_id = $3",
253                &[&version_id, &self.client_id, &parent_version_id],
254            )
255            .await
256            .context("error updating latest_version_id")?;
257
258        // If no rows were modified, this operation failed.
259        if rows_modified == 0 {
260            anyhow::bail!("clients.latest_version_id does not match parent_version_id");
261        }
262        Ok(())
263    }
264
265    async fn commit(&mut self) -> anyhow::Result<()> {
266        self.db_client().execute("COMMIT", &[]).await?;
267        self.db_client = None;
268        Ok(())
269    }
270}
271
#[cfg(test)]
mod test {
    use super::*;
    use crate::testing::with_db;

    /// Insert a client row with default values for everything but `client_id`.
    async fn make_client(db_client: &tokio_postgres::Client) -> anyhow::Result<Uuid> {
        let client_id = Uuid::new_v4();
        db_client
            .execute("insert into clients (client_id) values ($1)", &[&client_id])
            .await?;
        Ok(client_id)
    }

    /// Insert a version row directly, bypassing `add_version`.
    async fn make_version(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        parent_version_id: Uuid,
        history_segment: &[u8],
    ) -> anyhow::Result<Uuid> {
        let version_id = Uuid::new_v4();
        db_client
            .execute(
                "insert into versions
                    (version_id, client_id, parent_version_id, history_segment)
                    values ($1, $2, $3, $4)",
                &[
                    &version_id,
                    &client_id,
                    &parent_version_id,
                    &history_segment,
                ],
            )
            .await?;
        Ok(version_id)
    }

    /// Set a client's `latest_version_id` directly.
    async fn set_client_latest_version_id(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        latest_version_id: Uuid,
    ) -> anyhow::Result<()> {
        db_client
            .execute(
                "update clients set latest_version_id = $1 where client_id = $2",
                &[&latest_version_id, &client_id],
            )
            .await?;
        Ok(())
    }

    /// Set a client's snapshot columns directly, bypassing `set_snapshot`.
    async fn set_client_snapshot(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        snapshot_version_id: Uuid,
        versions_since_snapshot: u32,
        snapshot_timestamp: i64,
        snapshot: &[u8],
    ) -> anyhow::Result<()> {
        db_client
            .execute(
                "
                update clients
                    set snapshot_version_id = $1,
                        versions_since_snapshot = $2,
                        snapshot_timestamp = $3,
                        snapshot = $4
                    where client_id = $5",
                &[
                    &snapshot_version_id,
                    &(versions_since_snapshot as i32),
                    &snapshot_timestamp,
                    &snapshot,
                    &client_id,
                ],
            )
            .await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_get_client_none() -> anyhow::Result<()> {
        with_db(async |connection_string, _db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = Uuid::new_v4();
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(txn.get_client().await?, None);
            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_client_exists_empty() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: None
                })
            );
            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_client_exists_latest() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let latest_version_id = Uuid::new_v4();
            set_client_latest_version_id(&db_client, client_id, latest_version_id).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id,
                    snapshot: None
                })
            );
            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_client_exists_with_snapshot() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";
            set_client_snapshot(
                &db_client,
                client_id,
                snapshot_version_id,
                versions_since_snapshot,
                snapshot_timestamp,
                snapshot,
            )
            .await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: Some(Snapshot {
                        version_id: snapshot_version_id,
                        timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                        versions_since: versions_since_snapshot,
                    })
                })
            );
            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_new_client() -> anyhow::Result<()> {
        with_db(async |connection_string, _db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = Uuid::new_v4();
            let latest_version_id = Uuid::new_v4();

            let mut txn1 = storage.txn(client_id).await?;
            txn1.new_client(latest_version_id).await?;

            // Client is not visible yet as txn1 is not committed.
            let mut txn2 = storage.txn(client_id).await?;
            assert_eq!(txn2.get_client().await?, None);

            txn1.commit().await?;

            // Client is now visible.
            let mut txn2 = storage.txn(client_id).await?;
            assert_eq!(
                txn2.get_client().await?,
                Some(Client {
                    latest_version_id,
                    snapshot: None
                })
            );

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_set_snapshot() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";

            txn.set_snapshot(
                Snapshot {
                    version_id: snapshot_version_id,
                    timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                    versions_since: versions_since_snapshot,
                },
                snapshot.to_vec(),
            )
            .await?;
            txn.commit().await?;

            txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: Some(Snapshot {
                        version_id: snapshot_version_id,
                        timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                        versions_since: versions_since_snapshot,
                    })
                })
            );

            let row = db_client
                .query_one(
                    "select snapshot from clients where client_id = $1",
                    &[&client_id],
                )
                .await?;
            assert_eq!(row.get::<_, &[u8]>(0), b"abcd");

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_snapshot_none() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(txn.get_snapshot_data(Uuid::new_v4()).await?, None);

            Ok(())
        })
        .await
    }

    /// Snapshot data is returned when the requested version matches the stored snapshot.
    #[tokio::test]
    async fn test_get_snapshot_found() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let snapshot_version_id = Uuid::new_v4();
            set_client_snapshot(
                &db_client,
                client_id,
                snapshot_version_id,
                10,
                10000000,
                b"abcd",
            )
            .await?;

            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_snapshot_data(snapshot_version_id).await?,
                Some(b"abcd".to_vec())
            );

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_snapshot_mismatched_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;

            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";
            txn.set_snapshot(
                Snapshot {
                    version_id: snapshot_version_id,
                    timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                    versions_since: versions_since_snapshot,
                },
                snapshot.to_vec(),
            )
            .await?;

            assert_eq!(txn.get_snapshot_data(Uuid::new_v4()).await?, None);

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let parent_version_id = Uuid::new_v4();
            let version_id = make_version(&db_client, client_id, parent_version_id, b"v1").await?;

            let mut txn = storage.txn(client_id).await?;

            // Different parent doesn't exist.
            assert_eq!(txn.get_version_by_parent(Uuid::new_v4()).await?, None);

            // Different version doesn't exist.
            assert_eq!(txn.get_version(Uuid::new_v4()).await?, None);

            let version = Version {
                version_id,
                parent_version_id,
                history_segment: b"v1".to_vec(),
            };

            // Version found by parent.
            assert_eq!(
                txn.get_version_by_parent(parent_version_id).await?,
                Some(version.clone())
            );

            // Version found by ID.
            assert_eq!(txn.get_version(version_id).await?, Some(version));

            Ok(())
        })
        .await
    }

    /// Adding a version stores it and advances the client's `latest_version_id`.
    #[tokio::test]
    async fn test_add_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            txn.add_version(version_id, Uuid::nil(), b"v1".to_vec())
                .await?;
            assert_eq!(
                txn.get_version(version_id).await?,
                Some(Version {
                    version_id,
                    parent_version_id: Uuid::nil(),
                    history_segment: b"v1".to_vec()
                })
            );
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: version_id,
                    snapshot: None
                })
            );
            Ok(())
        })
        .await
    }

    /// When an add_version call specifies an incorrect `parent_version_id`, it fails. This is
    /// typically avoided by calling `get_client` beforehand, which (due to repeatable reads)
    /// allows the caller to check the `latest_version_id` before calling `add_version`.
    #[tokio::test]
    async fn test_add_version_mismatch() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let latest_version_id = Uuid::new_v4();
            set_client_latest_version_id(&db_client, client_id, latest_version_id).await?;

            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            let parent_version_id = Uuid::new_v4(); // != latest_version_id
            let res = txn
                .add_version(version_id, parent_version_id, b"v1".to_vec())
                .await;
            assert!(res.is_err());
            Ok(())
        })
        .await
    }

    /// Adding versions to two different clients can proceed concurrently.
    #[tokio::test]
    async fn test_add_version_no_conflict_different_clients() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;

            // Clients 1 and 2 do not interfere with each other; if these are the same client, then
            // this will deadlock as one transaction waits for the other. If the postgres storage
            // implementation serialized _all_ transactions across clients, that would limit its
            // scalability.
            //
            // So the assertion here is "does not deadlock".

            let client_id1 = make_client(&db_client).await?;
            let mut txn1 = storage.txn(client_id1).await?;
            let version_id1 = Uuid::new_v4();
            txn1.add_version(version_id1, Uuid::nil(), b"v1".to_vec())
                .await?;

            let client_id2 = make_client(&db_client).await?;
            let mut txn2 = storage.txn(client_id2).await?;
            let version_id2 = Uuid::new_v4();
            txn2.add_version(version_id2, Uuid::nil(), b"v2".to_vec())
                .await?;

            txn1.commit().await?;
            txn2.commit().await?;

            Ok(())
        })
        .await
    }
}