1#![doc=include_str!("../schema.sql")]
14use anyhow::Context;
32use bb8::PooledConnection;
33use bb8_postgres::PostgresConnectionManager;
34use chrono::{TimeZone, Utc};
35use postgres_native_tls::MakeTlsConnector;
36use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version};
37use uuid::Uuid;
38
39#[cfg(test)]
40mod testing;
41
42#[derive(Debug, Clone, Copy)]
44pub struct LogErrorSink;
45
46impl LogErrorSink {
47 fn new() -> Box<Self> {
48 Box::new(Self)
49 }
50}
51
52impl bb8::ErrorSink<tokio_postgres::Error> for LogErrorSink {
53 fn sink(&self, e: tokio_postgres::Error) {
54 log::error!("Postgres connection error: {e}");
55 }
56
57 fn boxed_clone(&self) -> Box<dyn bb8::ErrorSink<tokio_postgres::Error>> {
58 Self::new()
59 }
60}
61
62pub struct PostgresStorage {
64 pool: bb8::Pool<PostgresConnectionManager<MakeTlsConnector>>,
65}
66
67impl PostgresStorage {
68 pub async fn new(connection_string: impl ToString) -> anyhow::Result<Self> {
69 let connector = native_tls::TlsConnector::new()?;
70 let connector = postgres_native_tls::MakeTlsConnector::new(connector);
71 let manager = PostgresConnectionManager::new_from_stringlike(connection_string, connector)?;
72 let pool = bb8::Pool::builder()
73 .error_sink(LogErrorSink::new())
74 .build(manager)
75 .await?;
76 Ok(Self { pool })
77 }
78}
79
80#[async_trait::async_trait]
81impl Storage for PostgresStorage {
82 async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
83 let db_client = self.pool.get_owned().await?;
84
85 db_client
86 .execute("BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE", &[])
87 .await?;
88
89 Ok(Box::new(Txn {
90 client_id,
91 db_client: Some(db_client),
92 }))
93 }
94}
95
96struct Txn {
97 client_id: Uuid,
98 db_client: Option<PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>>,
101}
102
103impl Txn {
104 fn db_client(&self) -> &tokio_postgres::Client {
106 let Some(db_client) = &self.db_client else {
107 panic!("Cannot use a postgres Txn after commit");
108 };
109 db_client
110 }
111
112 async fn get_version_impl(
114 &mut self,
115 query: &'static str,
116 client_id: Uuid,
117 version_id_arg: Uuid,
118 ) -> anyhow::Result<Option<Version>> {
119 Ok(self
120 .db_client()
121 .query_opt(query, &[&version_id_arg, &client_id])
122 .await
123 .context("error getting version")?
124 .map(|r| Version {
125 version_id: r.get(0),
126 parent_version_id: r.get(1),
127 history_segment: r.get("history_segment"),
128 }))
129 }
130}
131
132#[async_trait::async_trait(?Send)]
133impl StorageTxn for Txn {
134 async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
135 Ok(self
136 .db_client()
137 .query_opt(
138 "SELECT
139 latest_version_id,
140 snapshot_timestamp,
141 versions_since_snapshot,
142 snapshot_version_id
143 FROM clients
144 WHERE client_id = $1
145 LIMIT 1",
146 &[&self.client_id],
147 )
148 .await
149 .context("error getting client")?
150 .map(|r| {
151 let latest_version_id: Uuid = r.get(0);
152 let snapshot_timestamp: Option<i64> = r.get(1);
153 let versions_since_snapshot: Option<i32> = r.get(2);
154 let snapshot_version_id: Option<Uuid> = r.get(3);
155
156 let snapshot = match (
158 snapshot_timestamp,
159 versions_since_snapshot,
160 snapshot_version_id,
161 ) {
162 (Some(ts), Some(vs), Some(v)) => Some(Snapshot {
163 version_id: v,
164 timestamp: Utc.timestamp_opt(ts, 0).unwrap(),
165 versions_since: vs as u32,
166 }),
167 _ => None,
168 };
169 Client {
170 latest_version_id,
171 snapshot,
172 }
173 }))
174 }
175
176 async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
177 self.db_client()
178 .execute(
179 "INSERT INTO clients (client_id, latest_version_id) VALUES ($1, $2)",
180 &[&self.client_id, &latest_version_id],
181 )
182 .await
183 .context("error creating/updating client")?;
184 Ok(())
185 }
186
187 async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
188 let timestamp = snapshot.timestamp.timestamp();
189 self.db_client()
190 .execute(
191 "UPDATE clients
192 SET snapshot_version_id = $1,
193 versions_since_snapshot = $2,
194 snapshot_timestamp = $3,
195 snapshot = $4
196 WHERE client_id = $5",
197 &[
198 &snapshot.version_id,
199 &(snapshot.versions_since as i32),
200 ×tamp,
201 &data,
202 &self.client_id,
203 ],
204 )
205 .await
206 .context("error setting snapshot")?;
207 Ok(())
208 }
209
210 async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
211 Ok(self
212 .db_client()
213 .query_opt(
214 "SELECT snapshot
215 FROM clients
216 WHERE client_id = $1 and snapshot_version_id = $2
217 LIMIT 1",
218 &[&self.client_id, &version_id],
219 )
220 .await
221 .context("error getting snapshot data")?
222 .map(|r| r.get(0)))
223 }
224
225 async fn get_version_by_parent(
226 &mut self,
227 parent_version_id: Uuid,
228 ) -> anyhow::Result<Option<Version>> {
229 self.get_version_impl(
230 "SELECT version_id, parent_version_id, history_segment
231 FROM versions
232 WHERE parent_version_id = $1 AND client_id = $2",
233 self.client_id,
234 parent_version_id,
235 )
236 .await
237 }
238
239 async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
240 self.get_version_impl(
241 "SELECT version_id, parent_version_id, history_segment
242 FROM versions
243 WHERE version_id = $1 AND client_id = $2",
244 self.client_id,
245 version_id,
246 )
247 .await
248 }
249
250 async fn add_version(
251 &mut self,
252 version_id: Uuid,
253 parent_version_id: Uuid,
254 history_segment: Vec<u8>,
255 ) -> anyhow::Result<()> {
256 self.db_client()
257 .execute(
258 "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment)
259 VALUES ($1, $2, $3, $4)",
260 &[
261 &version_id,
262 &self.client_id,
263 &parent_version_id,
264 &history_segment,
265 ],
266 )
267 .await
268 .context("error inserting new version")?;
269 let rows_modified = self
270 .db_client()
271 .execute(
272 "UPDATE clients
273 SET latest_version_id = $1,
274 versions_since_snapshot = versions_since_snapshot + 1
275 WHERE client_id = $2 and (latest_version_id = $3 or latest_version_id = $4)",
276 &[
277 &version_id,
278 &self.client_id,
279 &parent_version_id,
280 &Uuid::nil(),
281 ],
282 )
283 .await
284 .context("error updating latest_version_id")?;
285
286 if rows_modified == 0 {
288 anyhow::bail!("clients.latest_version_id does not match parent_version_id");
289 }
290 Ok(())
291 }
292
293 async fn commit(&mut self) -> anyhow::Result<()> {
294 self.db_client().execute("COMMIT", &[]).await?;
295 self.db_client = None;
296 Ok(())
297 }
298}
299
#[cfg(test)]
mod test {
    use super::*;
    use crate::testing::with_db;

    /// Insert a client row with only `client_id` set, directly through `db_client` (bypassing
    /// `PostgresStorage`), and return the new client's id.
    async fn make_client(db_client: &tokio_postgres::Client) -> anyhow::Result<Uuid> {
        let client_id = Uuid::new_v4();
        db_client
            .execute("insert into clients (client_id) values ($1)", &[&client_id])
            .await?;
        Ok(client_id)
    }

    /// Insert a version row for `client_id` directly through `db_client` and return the
    /// freshly generated version id.
    async fn make_version(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        parent_version_id: Uuid,
        history_segment: &[u8],
    ) -> anyhow::Result<Uuid> {
        let version_id = Uuid::new_v4();
        db_client
            .execute(
                "insert into versions
                 (version_id, client_id, parent_version_id, history_segment)
                 values ($1, $2, $3, $4)",
                &[
                    &version_id,
                    &client_id,
                    &parent_version_id,
                    &history_segment,
                ],
            )
            .await?;
        Ok(version_id)
    }

    /// Overwrite a client's `latest_version_id` directly through `db_client`.
    async fn set_client_latest_version_id(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        latest_version_id: Uuid,
    ) -> anyhow::Result<()> {
        db_client
            .execute(
                "update clients set latest_version_id = $1 where client_id = $2",
                &[&latest_version_id, &client_id],
            )
            .await?;
        Ok(())
    }

    /// Write all snapshot columns of a client row directly through `db_client`.
    async fn set_client_snapshot(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        snapshot_version_id: Uuid,
        versions_since_snapshot: u32,
        snapshot_timestamp: i64,
        snapshot: &[u8],
    ) -> anyhow::Result<()> {
        db_client
            .execute(
                "
                update clients
                set snapshot_version_id = $1,
                    versions_since_snapshot = $2,
                    snapshot_timestamp = $3,
                    snapshot = $4
                where client_id = $5",
                &[
                    &snapshot_version_id,
                    &(versions_since_snapshot as i32),
                    &snapshot_timestamp,
                    &snapshot,
                    &client_id,
                ],
            )
            .await?;
        Ok(())
    }

    /// An unknown client id yields no client.
    #[tokio::test]
    async fn test_get_client_none() -> anyhow::Result<()> {
        with_db(async |connection_string, _db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = Uuid::new_v4();
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(txn.get_client().await?, None);
            Ok(())
        })
        .await
    }

    /// A client row with only `client_id` set reads back with a nil `latest_version_id`
    /// (presumably the schema's column default) and no snapshot.
    #[tokio::test]
    async fn test_get_client_exists_empty() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: None
                })
            );
            Ok(())
        })
        .await
    }

    /// A directly-written `latest_version_id` is reported by `get_client`.
    #[tokio::test]
    async fn test_get_client_exists_latest() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let latest_version_id = Uuid::new_v4();
            set_client_latest_version_id(&db_client, client_id, latest_version_id).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id,
                    snapshot: None
                })
            );
            Ok(())
        })
        .await
    }

    /// Directly-written snapshot columns are reported as a `Snapshot` by `get_client`.
    #[tokio::test]
    async fn test_get_client_exists_with_snapshot() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";
            set_client_snapshot(
                &db_client,
                client_id,
                snapshot_version_id,
                versions_since_snapshot,
                snapshot_timestamp,
                snapshot,
            )
            .await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: Some(Snapshot {
                        version_id: snapshot_version_id,
                        timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                        versions_since: versions_since_snapshot,
                    })
                })
            );
            Ok(())
        })
        .await
    }

    /// A client created in one transaction is invisible to another until it commits.
    #[tokio::test]
    async fn test_new_client() -> anyhow::Result<()> {
        with_db(async |connection_string, _db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = Uuid::new_v4();
            let latest_version_id = Uuid::new_v4();

            let mut txn1 = storage.txn(client_id).await?;
            txn1.new_client(latest_version_id).await?;

            // txn1 has not committed yet, so a concurrent transaction must not see the client.
            let mut txn2 = storage.txn(client_id).await?;
            assert_eq!(txn2.get_client().await?, None);

            txn1.commit().await?;

            // A transaction started after the commit sees the new client. This shadows, but
            // does not drop, the earlier uncommitted txn2.
            let mut txn2 = storage.txn(client_id).await?;
            assert_eq!(
                txn2.get_client().await?,
                Some(Client {
                    latest_version_id,
                    snapshot: None
                })
            );

            Ok(())
        })
        .await
    }

    /// A committed `set_snapshot` is visible to a later transaction and stores the raw bytes.
    #[tokio::test]
    async fn test_set_snapshot() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";

            txn.set_snapshot(
                Snapshot {
                    version_id: snapshot_version_id,
                    timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                    versions_since: versions_since_snapshot,
                },
                snapshot.to_vec(),
            )
            .await?;
            txn.commit().await?;

            txn = storage.txn(client_id).await?;
            assert_eq!(
                txn.get_client().await?,
                Some(Client {
                    latest_version_id: Uuid::nil(),
                    snapshot: Some(Snapshot {
                        version_id: snapshot_version_id,
                        timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                        versions_since: versions_since_snapshot,
                    })
                })
            );

            // The snapshot payload itself lands in the `snapshot` column.
            let row = db_client
                .query_one(
                    "select snapshot from clients where client_id = $1",
                    &[&client_id],
                )
                .await?;
            assert_eq!(row.get::<_, &[u8]>(0), b"abcd");

            Ok(())
        })
        .await
    }

    /// With no snapshot stored, `get_snapshot_data` returns `None`.
    #[tokio::test]
    async fn test_get_snapshot_none() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            assert_eq!(txn.get_snapshot_data(Uuid::new_v4()).await?, None);

            Ok(())
        })
        .await
    }

    /// Asking for snapshot data under a version id other than the stored one returns `None`.
    #[tokio::test]
    async fn test_get_snapshot_mismatched_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;

            let snapshot_version_id = Uuid::new_v4();
            let versions_since_snapshot = 10;
            let snapshot_timestamp = 10000000;
            let snapshot = b"abcd";
            txn.set_snapshot(
                Snapshot {
                    version_id: snapshot_version_id,
                    timestamp: Utc.timestamp_opt(snapshot_timestamp, 0).unwrap(),
                    versions_since: versions_since_snapshot,
                },
                snapshot.to_vec(),
            )
            .await?;

            assert_eq!(txn.get_snapshot_data(Uuid::new_v4()).await?, None);

            Ok(())
        })
        .await
    }

    /// Versions can be looked up by their own id or by their parent's id; unknown ids miss.
    #[tokio::test]
    async fn test_get_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let parent_version_id = Uuid::new_v4();
            let version_id = make_version(&db_client, client_id, parent_version_id, b"v1").await?;

            let mut txn = storage.txn(client_id).await?;

            // An unrelated parent id finds nothing.
            assert_eq!(txn.get_version_by_parent(Uuid::new_v4()).await?, None);

            // An unrelated version id finds nothing.
            assert_eq!(txn.get_version(Uuid::new_v4()).await?, None);

            let version = Version {
                version_id,
                parent_version_id,
                history_segment: b"v1".to_vec(),
            };

            // Lookup by parent id finds the inserted version.
            assert_eq!(
                txn.get_version_by_parent(parent_version_id).await?,
                Some(version.clone())
            );

            // Lookup by version id finds the same version.
            assert_eq!(txn.get_version(version_id).await?, Some(version));

            Ok(())
        })
        .await
    }

    /// A version added with a nil parent is readable within the same transaction.
    #[tokio::test]
    async fn test_add_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            txn.add_version(version_id, Uuid::nil(), b"v1".to_vec())
                .await?;
            assert_eq!(
                txn.get_version(version_id).await?,
                Some(Version {
                    version_id,
                    parent_version_id: Uuid::nil(),
                    history_segment: b"v1".to_vec()
                })
            );
            Ok(())
        })
        .await
    }

    /// `add_version` fails when the client's latest version is set and differs from the
    /// given parent.
    #[tokio::test]
    async fn test_add_version_mismatch() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let latest_version_id = Uuid::new_v4();
            set_client_latest_version_id(&db_client, client_id, latest_version_id).await?;

            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            let parent_version_id = Uuid::new_v4();
            let res = txn
                .add_version(version_id, parent_version_id, b"v1".to_vec())
                .await;
            assert!(res.is_err());
            Ok(())
        })
        .await
    }

    /// Two overlapping serializable transactions that add versions for *different* clients
    /// must both be able to commit.
    #[tokio::test]
    async fn test_add_version_no_conflict_different_clients() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;

            let client_id1 = make_client(&db_client).await?;
            let mut txn1 = storage.txn(client_id1).await?;
            let version_id1 = Uuid::new_v4();
            txn1.add_version(version_id1, Uuid::nil(), b"v1".to_vec())
                .await?;

            // txn1 is still open while txn2 begins and writes.
            let client_id2 = make_client(&db_client).await?;
            let mut txn2 = storage.txn(client_id2).await?;
            let version_id2 = Uuid::new_v4();
            txn2.add_version(version_id2, Uuid::nil(), b"v2".to_vec())
                .await?;

            txn1.commit().await?;
            txn2.commit().await?;

            Ok(())
        })
        .await
    }

    /// A client whose `latest_version_id` is nil accepts a version with any parent id.
    #[tokio::test]
    async fn test_add_version_no_history() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;

            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            let parent_version_id = Uuid::new_v4();
            txn.add_version(version_id, parent_version_id, b"v1".to_vec())
                .await?;
            Ok(())
        })
        .await
    }
}