// sourcery_postgres/snapshot.rs

1//! PostgreSQL-backed snapshot store implementation.
2//!
3//! This module provides [`Store`], an implementation of
4//! [`sourcery_core::snapshot::SnapshotStore`] for `PostgreSQL`.
5
6use serde::{Serialize, de::DeserializeOwned};
7use sourcery_core::snapshot::{
8    OfferSnapshotError, Snapshot, SnapshotOffer, SnapshotStore, inmemory::SnapshotPolicy,
9};
10use sqlx::{PgPool, Row};
11
/// Error type for `PostgreSQL` snapshot operations.
///
/// Database failures are converted automatically via `#[from]`; the two
/// serde-related variants box the underlying `serde_json` error as a
/// trait object so this enum does not expose `serde_json` in its API.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Database error during snapshot operations (connection, query, or
    /// decoding failures surfaced by `sqlx`).
    #[error("database error: {0}")]
    Database(#[from] sqlx::Error),
    /// Serialisation error: the aggregate state could not be encoded to JSON
    /// when offering a snapshot.
    #[error("serialization error: {0}")]
    Serialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// Deserialisation error: a stored snapshot payload could not be decoded
    /// into the requested aggregate type when loading.
    #[error("deserialization error: {0}")]
    Deserialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
25
/// A `PostgreSQL`-backed snapshot store with configurable policy.
///
/// This implementation stores snapshots in a dedicated `PostgreSQL` table
/// (`es_snapshots`), using the same database as the event store for
/// consistency.
///
/// # Schema
///
/// The store uses the following table schema (created by
/// [`migrate()`](Self::migrate)):
///
/// ```sql
/// CREATE TABLE IF NOT EXISTS es_snapshots (
///     aggregate_kind TEXT NOT NULL,
///     aggregate_id   UUID NOT NULL,
///     position       BIGINT NOT NULL,
///     data           JSONB NOT NULL,
///     created_at     TIMESTAMPTZ NOT NULL DEFAULT now(),
///     PRIMARY KEY (aggregate_kind, aggregate_id)
/// )
/// ```
///
/// # Example
///
/// ```ignore
/// use sourcery_postgres::{Store as EventStore};
/// use sourcery_postgres::snapshot::Store as SnapshotStore;
/// use sourcery_core::Repository;
///
/// let pool = PgPool::connect("postgres://...").await?;
/// let event_store = EventStore::new(pool.clone());
/// let snapshot_store = SnapshotStore::every(pool, 100);
///
/// // Run migrations
/// event_store.migrate().await?;
/// snapshot_store.migrate().await?;
///
/// let repo = Repository::new(event_store).with_snapshots(snapshot_store);
/// ```
#[derive(Clone)]
pub struct Store {
    /// Connection pool used for all snapshot queries (typically shared with
    /// the event store, per the module docs above).
    pool: PgPool,
    /// Decides, per offer, whether a snapshot should be persisted.
    policy: SnapshotPolicy,
}
70
71impl std::fmt::Debug for Store {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("SnapshotStore")
74            .field("policy", &self.policy)
75            .finish_non_exhaustive()
76    }
77}
78
79impl Store {
80    /// Create a snapshot store that saves after every command.
81    ///
82    /// Best for aggregates with expensive replay or many events.
83    /// See the policy guidelines in [`SnapshotPolicy`] for choosing an
84    /// appropriate cadence.
85    #[must_use]
86    pub const fn always(pool: PgPool) -> Self {
87        Self {
88            pool,
89            policy: SnapshotPolicy::Always,
90        }
91    }
92
93    /// Create a snapshot store that saves every N events.
94    ///
95    /// Recommended for most use cases. Start with `n = 50-100` and tune
96    /// based on your aggregate's replay cost.
97    #[must_use]
98    pub const fn every(pool: PgPool, n: u64) -> Self {
99        Self {
100            pool,
101            policy: SnapshotPolicy::EveryNEvents(n),
102        }
103    }
104
105    /// Create a snapshot store that never saves (load-only).
106    ///
107    /// Use for read replicas, short-lived aggregates, or when managing
108    /// snapshots externally.
109    #[must_use]
110    pub const fn never(pool: PgPool) -> Self {
111        Self {
112            pool,
113            policy: SnapshotPolicy::Never,
114        }
115    }
116
117    /// Apply the snapshot schema (idempotent).
118    ///
119    /// This uses `CREATE TABLE IF NOT EXISTS` style DDL so it can be run on
120    /// startup alongside the event store migrations.
121    ///
122    /// # Errors
123    ///
124    /// Returns a `sqlx::Error` if the schema creation query fails.
125    #[tracing::instrument(skip(self))]
126    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
127        sqlx::query(
128            r"
129            CREATE TABLE IF NOT EXISTS es_snapshots (
130                aggregate_kind TEXT NOT NULL,
131                aggregate_id   UUID NOT NULL,
132                position       BIGINT NOT NULL,
133                data           JSONB NOT NULL,
134                created_at     TIMESTAMPTZ NOT NULL DEFAULT now(),
135                PRIMARY KEY (aggregate_kind, aggregate_id)
136            )
137            ",
138        )
139        .execute(&self.pool)
140        .await?;
141
142        Ok(())
143    }
144}
145
146impl SnapshotStore<uuid::Uuid> for Store {
147    type Error = Error;
148    type Position = i64;
149
150    #[tracing::instrument(skip(self))]
151    async fn load<T>(
152        &self,
153        kind: &str,
154        id: &uuid::Uuid,
155    ) -> Result<Option<Snapshot<Self::Position, T>>, Self::Error>
156    where
157        T: DeserializeOwned,
158    {
159        let result = sqlx::query(
160            r"
161            SELECT position, data
162            FROM es_snapshots
163            WHERE aggregate_kind = $1 AND aggregate_id = $2
164            ",
165        )
166        .bind(kind)
167        .bind(id)
168        .fetch_optional(&self.pool)
169        .await?;
170
171        let snapshot = result
172            .map(|row| {
173                let position: i64 = row.get("position");
174                let data: sqlx::types::Json<serde_json::Value> = row.get("data");
175                serde_json::from_value::<T>(data.0)
176                    .map(|decoded| Snapshot {
177                        position,
178                        data: decoded,
179                    })
180                    .map_err(|e| Error::Deserialization(Box::new(e)))
181            })
182            .transpose()?;
183
184        tracing::trace!(found = snapshot.is_some(), "snapshot lookup");
185        Ok(snapshot)
186    }
187
188    #[tracing::instrument(skip(self, create_snapshot))]
189    async fn offer_snapshot<CE, T, Create>(
190        &self,
191        kind: &str,
192        id: &uuid::Uuid,
193        events_since_last_snapshot: u64,
194        create_snapshot: Create,
195    ) -> Result<SnapshotOffer, OfferSnapshotError<Self::Error, CE>>
196    where
197        CE: std::error::Error + Send + Sync + 'static,
198        T: Serialize,
199        Create: FnOnce() -> Result<Snapshot<Self::Position, T>, CE>,
200    {
201        let prepared = if self.policy.should_snapshot(events_since_last_snapshot) {
202            match create_snapshot() {
203                Ok(snapshot) => serde_json::to_value(&snapshot.data)
204                    .map(|data| Some((snapshot.position, data)))
205                    .map_err(|e| OfferSnapshotError::Snapshot(Error::Serialization(Box::new(e)))),
206                Err(e) => Err(OfferSnapshotError::Create(e)),
207            }
208        } else {
209            Ok(None)
210        }?;
211
212        let Some((position, data)) = prepared else {
213            return Ok(SnapshotOffer::Declined);
214        };
215
216        // Use ON CONFLICT to upsert, but only if the new position is greater
217        // than the existing one. This prevents race conditions where an older
218        // snapshot could overwrite a newer one.
219        let result = sqlx::query(
220            r"
221            INSERT INTO es_snapshots (aggregate_kind, aggregate_id, position, data)
222            VALUES ($1, $2, $3, $4)
223            ON CONFLICT (aggregate_kind, aggregate_id)
224            DO UPDATE SET position = EXCLUDED.position, data = EXCLUDED.data, created_at = now()
225            WHERE es_snapshots.position < EXCLUDED.position
226            ",
227        )
228        .bind(kind)
229        .bind(id)
230        .bind(position)
231        .bind(sqlx::types::Json(data))
232        .execute(&self.pool)
233        .await
234        .map_err(|e| OfferSnapshotError::Snapshot(Error::Database(e)))?;
235
236        // rows_affected() will be 1 if inserted or updated, 0 if the existing
237        // snapshot has a >= position (declined due to staleness)
238        let offer = if result.rows_affected() > 0 {
239            SnapshotOffer::Stored
240        } else {
241            SnapshotOffer::Declined
242        };
243
244        tracing::debug!(
245            events_since_last_snapshot,
246            ?offer,
247            "snapshot offer evaluated"
248        );
249        Ok(offer)
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    // The store delegates its cadence decision to `SnapshotPolicy`; these
    // tests pin that behaviour for each policy variant.

    #[test]
    fn policy_always_should_snapshot() {
        let policy = SnapshotPolicy::Always;
        for events in [0, 1, 100] {
            assert!(policy.should_snapshot(events));
        }
    }

    #[test]
    fn policy_every_n_events_should_snapshot() {
        let policy = SnapshotPolicy::EveryNEvents(50);
        for below_threshold in [0, 49] {
            assert!(!policy.should_snapshot(below_threshold));
        }
        for at_or_above in [50, 100] {
            assert!(policy.should_snapshot(at_or_above));
        }
    }

    #[test]
    fn policy_never_should_snapshot() {
        let policy = SnapshotPolicy::Never;
        for events in [0, 1, 1000] {
            assert!(!policy.should_snapshot(events));
        }
    }
}