1#![doc=include_str!("../schema.sql")]
14use anyhow::Context;
32use bb8::PooledConnection;
33use bb8_postgres::PostgresConnectionManager;
34use chrono::{TimeZone, Utc};
35use postgres_native_tls::MakeTlsConnector;
36use taskchampion_sync_server_core::{Client, Snapshot, Storage, StorageTxn, Version};
37use uuid::Uuid;
38
39#[cfg(test)]
40mod testing;
41
42pub struct PostgresStorage {
44 pool: bb8::Pool<PostgresConnectionManager<MakeTlsConnector>>,
45}
46
47impl PostgresStorage {
48 pub async fn new(connection_string: impl ToString) -> anyhow::Result<Self> {
49 let connector = native_tls::TlsConnector::new()?;
50 let connector = postgres_native_tls::MakeTlsConnector::new(connector);
51 let manager = PostgresConnectionManager::new_from_stringlike(connection_string, connector)?;
52 let pool = bb8::Pool::builder().build(manager).await?;
53 Ok(Self { pool })
54 }
55}
56
57#[async_trait::async_trait]
58impl Storage for PostgresStorage {
59 async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
60 let db_client = self.pool.get_owned().await?;
61
62 db_client
63 .execute("BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE", &[])
64 .await?;
65
66 Ok(Box::new(Txn {
67 client_id,
68 db_client: Some(db_client),
69 }))
70 }
71}
72
73struct Txn {
74 client_id: Uuid,
75 db_client: Option<PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>>,
78}
79
80impl Txn {
81 fn db_client(&self) -> &tokio_postgres::Client {
83 let Some(db_client) = &self.db_client else {
84 panic!("Cannot use a postgres Txn after commit");
85 };
86 db_client
87 }
88
89 async fn get_version_impl(
91 &mut self,
92 query: &'static str,
93 client_id: Uuid,
94 version_id_arg: Uuid,
95 ) -> anyhow::Result<Option<Version>> {
96 Ok(self
97 .db_client()
98 .query_opt(query, &[&version_id_arg, &client_id])
99 .await
100 .context("error getting version")?
101 .map(|r| Version {
102 version_id: r.get(0),
103 parent_version_id: r.get(1),
104 history_segment: r.get("history_segment"),
105 }))
106 }
107}
108
109#[async_trait::async_trait(?Send)]
110impl StorageTxn for Txn {
111 async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
112 Ok(self
113 .db_client()
114 .query_opt(
115 "SELECT
116 latest_version_id,
117 snapshot_timestamp,
118 versions_since_snapshot,
119 snapshot_version_id
120 FROM clients
121 WHERE client_id = $1
122 LIMIT 1",
123 &[&self.client_id],
124 )
125 .await
126 .context("error getting client")?
127 .map(|r| {
128 let latest_version_id: Uuid = r.get(0);
129 let snapshot_timestamp: Option<i64> = r.get(1);
130 let versions_since_snapshot: Option<i32> = r.get(2);
131 let snapshot_version_id: Option<Uuid> = r.get(3);
132
133 let snapshot = match (
135 snapshot_timestamp,
136 versions_since_snapshot,
137 snapshot_version_id,
138 ) {
139 (Some(ts), Some(vs), Some(v)) => Some(Snapshot {
140 version_id: v,
141 timestamp: Utc.timestamp_opt(ts, 0).unwrap(),
142 versions_since: vs as u32,
143 }),
144 _ => None,
145 };
146 Client {
147 latest_version_id,
148 snapshot,
149 }
150 }))
151 }
152
153 async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
154 self.db_client()
155 .execute(
156 "INSERT INTO clients (client_id, latest_version_id) VALUES ($1, $2)",
157 &[&self.client_id, &latest_version_id],
158 )
159 .await
160 .context("error creating/updating client")?;
161 Ok(())
162 }
163
164 async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
165 let timestamp = snapshot.timestamp.timestamp();
166 self.db_client()
167 .execute(
168 "UPDATE clients
169 SET snapshot_version_id = $1,
170 versions_since_snapshot = $2,
171 snapshot_timestamp = $3,
172 snapshot = $4
173 WHERE client_id = $5",
174 &[
175 &snapshot.version_id,
176 &(snapshot.versions_since as i32),
177 ×tamp,
178 &data,
179 &self.client_id,
180 ],
181 )
182 .await
183 .context("error setting snapshot")?;
184 Ok(())
185 }
186
187 async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
188 Ok(self
189 .db_client()
190 .query_opt(
191 "SELECT snapshot
192 FROM clients
193 WHERE client_id = $1 and snapshot_version_id = $2
194 LIMIT 1",
195 &[&self.client_id, &version_id],
196 )
197 .await
198 .context("error getting snapshot data")?
199 .map(|r| r.get(0)))
200 }
201
202 async fn get_version_by_parent(
203 &mut self,
204 parent_version_id: Uuid,
205 ) -> anyhow::Result<Option<Version>> {
206 self.get_version_impl(
207 "SELECT version_id, parent_version_id, history_segment
208 FROM versions
209 WHERE parent_version_id = $1 AND client_id = $2",
210 self.client_id,
211 parent_version_id,
212 )
213 .await
214 }
215
216 async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
217 self.get_version_impl(
218 "SELECT version_id, parent_version_id, history_segment
219 FROM versions
220 WHERE version_id = $1 AND client_id = $2",
221 self.client_id,
222 version_id,
223 )
224 .await
225 }
226
227 async fn add_version(
228 &mut self,
229 version_id: Uuid,
230 parent_version_id: Uuid,
231 history_segment: Vec<u8>,
232 ) -> anyhow::Result<()> {
233 self.db_client()
234 .execute(
235 "INSERT INTO versions (version_id, client_id, parent_version_id, history_segment)
236 VALUES ($1, $2, $3, $4)",
237 &[
238 &version_id,
239 &self.client_id,
240 &parent_version_id,
241 &history_segment,
242 ],
243 )
244 .await
245 .context("error inserting new version")?;
246 let rows_modified = self
247 .db_client()
248 .execute(
249 "UPDATE clients
250 SET latest_version_id = $1,
251 versions_since_snapshot = versions_since_snapshot + 1
252 WHERE client_id = $2 and latest_version_id = $3",
253 &[&version_id, &self.client_id, &parent_version_id],
254 )
255 .await
256 .context("error updating latest_version_id")?;
257
258 if rows_modified == 0 {
260 anyhow::bail!("clients.latest_version_id does not match parent_version_id");
261 }
262 Ok(())
263 }
264
265 async fn commit(&mut self) -> anyhow::Result<()> {
266 self.db_client().execute("COMMIT", &[]).await?;
267 self.db_client = None;
268 Ok(())
269 }
270}
271
#[cfg(test)]
mod test {
    use super::*;
    use crate::testing::with_db;

    // Fixture values shared by the snapshot-related tests.
    const VERSIONS_SINCE_SNAPSHOT: u32 = 10;
    const SNAPSHOT_TIMESTAMP: i64 = 10000000;
    const SNAPSHOT_DATA: &[u8] = b"abcd";

    /// Build the fixture snapshot metadata for `version_id`.
    fn snapshot_fixture(version_id: Uuid) -> Snapshot {
        Snapshot {
            version_id,
            timestamp: Utc.timestamp_opt(SNAPSHOT_TIMESTAMP, 0).unwrap(),
            versions_since: VERSIONS_SINCE_SNAPSHOT,
        }
    }

    /// Insert a `clients` row holding only a freshly generated client ID, and return that ID.
    async fn make_client(db_client: &tokio_postgres::Client) -> anyhow::Result<Uuid> {
        let client_id = Uuid::new_v4();
        let sql = "insert into clients (client_id) values ($1)";
        db_client.execute(sql, &[&client_id]).await?;
        Ok(client_id)
    }

    /// Insert a `versions` row with a freshly generated version ID, and return that ID.
    async fn make_version(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        parent_version_id: Uuid,
        history_segment: &[u8],
    ) -> anyhow::Result<Uuid> {
        let version_id = Uuid::new_v4();
        db_client
            .execute(
                "insert into versions
                    (version_id, client_id, parent_version_id, history_segment)
                    values ($1, $2, $3, $4)",
                &[
                    &version_id,
                    &client_id,
                    &parent_version_id,
                    &history_segment,
                ],
            )
            .await?;
        Ok(version_id)
    }

    /// Directly overwrite `latest_version_id` on an existing `clients` row.
    async fn set_client_latest_version_id(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        latest_version_id: Uuid,
    ) -> anyhow::Result<()> {
        let sql = "update clients set latest_version_id = $1 where client_id = $2";
        db_client
            .execute(sql, &[&latest_version_id, &client_id])
            .await?;
        Ok(())
    }

    /// Directly overwrite all snapshot columns on an existing `clients` row.
    async fn set_client_snapshot(
        db_client: &tokio_postgres::Client,
        client_id: Uuid,
        snapshot_version_id: Uuid,
        versions_since_snapshot: u32,
        snapshot_timestamp: i64,
        snapshot: &[u8],
    ) -> anyhow::Result<()> {
        let versions_since = versions_since_snapshot as i32;
        db_client
            .execute(
                "
                update clients
                    set snapshot_version_id = $1,
                    versions_since_snapshot = $2,
                    snapshot_timestamp = $3,
                    snapshot = $4
                    where client_id = $5",
                &[
                    &snapshot_version_id,
                    &versions_since,
                    &snapshot_timestamp,
                    &snapshot,
                    &client_id,
                ],
            )
            .await?;
        Ok(())
    }

    /// A client ID with no row yields `None`.
    #[tokio::test]
    async fn test_get_client_none() -> anyhow::Result<()> {
        with_db(async |connection_string, _db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = Uuid::new_v4();
            let mut txn = storage.txn(client_id).await?;
            let found = txn.get_client().await?;
            assert_eq!(found, None);
            Ok(())
        })
        .await
    }

    /// A bare client row yields the nil latest version and no snapshot.
    #[tokio::test]
    async fn test_get_client_exists_empty() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let expected = Client {
                latest_version_id: Uuid::nil(),
                snapshot: None,
            };
            assert_eq!(txn.get_client().await?, Some(expected));
            Ok(())
        })
        .await
    }

    /// A client row with a latest version set reports that version.
    #[tokio::test]
    async fn test_get_client_exists_latest() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let latest_version_id = Uuid::new_v4();
            set_client_latest_version_id(&db_client, client_id, latest_version_id).await?;
            let mut txn = storage.txn(client_id).await?;
            let expected = Client {
                latest_version_id,
                snapshot: None,
            };
            assert_eq!(txn.get_client().await?, Some(expected));
            Ok(())
        })
        .await
    }

    /// A client row with every snapshot column set reports the snapshot metadata.
    #[tokio::test]
    async fn test_get_client_exists_with_snapshot() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let snapshot_version_id = Uuid::new_v4();
            set_client_snapshot(
                &db_client,
                client_id,
                snapshot_version_id,
                VERSIONS_SINCE_SNAPSHOT,
                SNAPSHOT_TIMESTAMP,
                SNAPSHOT_DATA,
            )
            .await?;
            let mut txn = storage.txn(client_id).await?;
            let expected = Client {
                latest_version_id: Uuid::nil(),
                snapshot: Some(snapshot_fixture(snapshot_version_id)),
            };
            assert_eq!(txn.get_client().await?, Some(expected));
            Ok(())
        })
        .await
    }

    /// A new client is invisible to a concurrent transaction until the creator commits.
    #[tokio::test]
    async fn test_new_client() -> anyhow::Result<()> {
        with_db(async |connection_string, _db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = Uuid::new_v4();
            let latest_version_id = Uuid::new_v4();

            let mut creator = storage.txn(client_id).await?;
            creator.new_client(latest_version_id).await?;

            // Opened before the creator commits; it stays open for the rest of the test.
            let mut concurrent = storage.txn(client_id).await?;
            assert_eq!(concurrent.get_client().await?, None);

            creator.commit().await?;

            // Opened after the commit, so the new row is visible.
            let mut later = storage.txn(client_id).await?;
            let expected = Client {
                latest_version_id,
                snapshot: None,
            };
            assert_eq!(later.get_client().await?, Some(expected));

            Ok(())
        })
        .await
    }

    /// `set_snapshot` persists both the snapshot metadata and the snapshot bytes.
    #[tokio::test]
    async fn test_set_snapshot() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let snapshot_version_id = Uuid::new_v4();

            txn.set_snapshot(
                snapshot_fixture(snapshot_version_id),
                SNAPSHOT_DATA.to_vec(),
            )
            .await?;
            txn.commit().await?;

            txn = storage.txn(client_id).await?;
            let expected = Client {
                latest_version_id: Uuid::nil(),
                snapshot: Some(snapshot_fixture(snapshot_version_id)),
            };
            assert_eq!(txn.get_client().await?, Some(expected));

            // Read the raw column through the independent test connection.
            let row = db_client
                .query_one(
                    "select snapshot from clients where client_id = $1",
                    &[&client_id],
                )
                .await?;
            assert_eq!(row.get::<_, &[u8]>(0), b"abcd");

            Ok(())
        })
        .await
    }

    /// A client without a snapshot has no snapshot data.
    #[tokio::test]
    async fn test_get_snapshot_none() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let data = txn.get_snapshot_data(Uuid::new_v4()).await?;
            assert_eq!(data, None);

            Ok(())
        })
        .await
    }

    /// Asking for a snapshot version other than the stored one yields no data.
    #[tokio::test]
    async fn test_get_snapshot_mismatched_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;

            let snapshot_version_id = Uuid::new_v4();
            txn.set_snapshot(
                snapshot_fixture(snapshot_version_id),
                SNAPSHOT_DATA.to_vec(),
            )
            .await?;

            let data = txn.get_snapshot_data(Uuid::new_v4()).await?;
            assert_eq!(data, None);

            Ok(())
        })
        .await
    }

    /// Versions can be looked up by ID and by parent ID; unknown IDs yield `None`.
    #[tokio::test]
    async fn test_get_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let parent_version_id = Uuid::new_v4();
            let version_id = make_version(&db_client, client_id, parent_version_id, b"v1").await?;

            let mut txn = storage.txn(client_id).await?;

            // Unknown parent and unknown version IDs find nothing.
            assert_eq!(txn.get_version_by_parent(Uuid::new_v4()).await?, None);
            assert_eq!(txn.get_version(Uuid::new_v4()).await?, None);

            let expected = Version {
                version_id,
                parent_version_id,
                history_segment: b"v1".to_vec(),
            };

            let by_parent = txn.get_version_by_parent(parent_version_id).await?;
            assert_eq!(by_parent, Some(expected.clone()));

            let by_id = txn.get_version(version_id).await?;
            assert_eq!(by_id, Some(expected));

            Ok(())
        })
        .await
    }

    /// A version added on top of the nil version can be read back in the same transaction.
    #[tokio::test]
    async fn test_add_version() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            txn.add_version(version_id, Uuid::nil(), b"v1".to_vec())
                .await?;

            let expected = Version {
                version_id,
                parent_version_id: Uuid::nil(),
                history_segment: b"v1".to_vec(),
            };
            assert_eq!(txn.get_version(version_id).await?, Some(expected));
            Ok(())
        })
        .await
    }

    /// Adding a version whose parent is not the client's latest version fails.
    #[tokio::test]
    async fn test_add_version_mismatch() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;
            let client_id = make_client(&db_client).await?;
            let latest_version_id = Uuid::new_v4();
            set_client_latest_version_id(&db_client, client_id, latest_version_id).await?;

            let mut txn = storage.txn(client_id).await?;
            let version_id = Uuid::new_v4();
            // A fresh random parent cannot equal the latest version set above.
            let parent_version_id = Uuid::new_v4();
            let res = txn
                .add_version(version_id, parent_version_id, b"v1".to_vec())
                .await;
            assert!(res.is_err());
            Ok(())
        })
        .await
    }

    /// Overlapping transactions for different clients can both add versions and commit.
    #[tokio::test]
    async fn test_add_version_no_conflict_different_clients() -> anyhow::Result<()> {
        with_db(async |connection_string, db_client| {
            let storage = PostgresStorage::new(connection_string).await?;

            let client_id1 = make_client(&db_client).await?;
            let mut txn1 = storage.txn(client_id1).await?;
            let version_id1 = Uuid::new_v4();
            txn1.add_version(version_id1, Uuid::nil(), b"v1".to_vec())
                .await?;

            let client_id2 = make_client(&db_client).await?;
            let mut txn2 = storage.txn(client_id2).await?;
            let version_id2 = Uuid::new_v4();
            txn2.add_version(version_id2, Uuid::nil(), b"v2".to_vec())
                .await?;

            // Both transactions are still open here; commit them one after the other.
            txn1.commit().await?;
            txn2.commit().await?;

            Ok(())
        })
        .await
    }
}