sourcery_postgres/
snapshot.rs1use serde::{Serialize, de::DeserializeOwned};
7use sourcery_core::snapshot::{
8 OfferSnapshotError, Snapshot, SnapshotOffer, SnapshotStore, inmemory::SnapshotPolicy,
9};
10use sqlx::{PgPool, Row};
11
12#[derive(Debug, thiserror::Error)]
14pub enum Error {
15 #[error("database error: {0}")]
17 Database(#[from] sqlx::Error),
18 #[error("serialization error: {0}")]
20 Serialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
21 #[error("deserialization error: {0}")]
23 Deserialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
24}
25
26#[derive(Clone)]
66pub struct Store {
67 pool: PgPool,
68 policy: SnapshotPolicy,
69}
70
71impl std::fmt::Debug for Store {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("SnapshotStore")
74 .field("policy", &self.policy)
75 .finish_non_exhaustive()
76 }
77}
78
79impl Store {
80 #[must_use]
86 pub const fn always(pool: PgPool) -> Self {
87 Self {
88 pool,
89 policy: SnapshotPolicy::Always,
90 }
91 }
92
93 #[must_use]
98 pub const fn every(pool: PgPool, n: u64) -> Self {
99 Self {
100 pool,
101 policy: SnapshotPolicy::EveryNEvents(n),
102 }
103 }
104
105 #[must_use]
110 pub const fn never(pool: PgPool) -> Self {
111 Self {
112 pool,
113 policy: SnapshotPolicy::Never,
114 }
115 }
116
117 #[tracing::instrument(skip(self))]
126 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
127 sqlx::query(
128 r"
129 CREATE TABLE IF NOT EXISTS es_snapshots (
130 aggregate_kind TEXT NOT NULL,
131 aggregate_id UUID NOT NULL,
132 position BIGINT NOT NULL,
133 data JSONB NOT NULL,
134 created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
135 PRIMARY KEY (aggregate_kind, aggregate_id)
136 )
137 ",
138 )
139 .execute(&self.pool)
140 .await?;
141
142 Ok(())
143 }
144}
145
146impl SnapshotStore<uuid::Uuid> for Store {
147 type Error = Error;
148 type Position = i64;
149
150 #[tracing::instrument(skip(self))]
151 async fn load<T>(
152 &self,
153 kind: &str,
154 id: &uuid::Uuid,
155 ) -> Result<Option<Snapshot<Self::Position, T>>, Self::Error>
156 where
157 T: DeserializeOwned,
158 {
159 let result = sqlx::query(
160 r"
161 SELECT position, data
162 FROM es_snapshots
163 WHERE aggregate_kind = $1 AND aggregate_id = $2
164 ",
165 )
166 .bind(kind)
167 .bind(id)
168 .fetch_optional(&self.pool)
169 .await?;
170
171 let snapshot = result
172 .map(|row| {
173 let position: i64 = row.get("position");
174 let data: sqlx::types::Json<serde_json::Value> = row.get("data");
175 serde_json::from_value::<T>(data.0)
176 .map(|decoded| Snapshot {
177 position,
178 data: decoded,
179 })
180 .map_err(|e| Error::Deserialization(Box::new(e)))
181 })
182 .transpose()?;
183
184 tracing::trace!(found = snapshot.is_some(), "snapshot lookup");
185 Ok(snapshot)
186 }
187
188 #[tracing::instrument(skip(self, create_snapshot))]
189 async fn offer_snapshot<CE, T, Create>(
190 &self,
191 kind: &str,
192 id: &uuid::Uuid,
193 events_since_last_snapshot: u64,
194 create_snapshot: Create,
195 ) -> Result<SnapshotOffer, OfferSnapshotError<Self::Error, CE>>
196 where
197 CE: std::error::Error + Send + Sync + 'static,
198 T: Serialize,
199 Create: FnOnce() -> Result<Snapshot<Self::Position, T>, CE>,
200 {
201 let prepared = if self.policy.should_snapshot(events_since_last_snapshot) {
202 match create_snapshot() {
203 Ok(snapshot) => serde_json::to_value(&snapshot.data)
204 .map(|data| Some((snapshot.position, data)))
205 .map_err(|e| OfferSnapshotError::Snapshot(Error::Serialization(Box::new(e)))),
206 Err(e) => Err(OfferSnapshotError::Create(e)),
207 }
208 } else {
209 Ok(None)
210 }?;
211
212 let Some((position, data)) = prepared else {
213 return Ok(SnapshotOffer::Declined);
214 };
215
216 let result = sqlx::query(
220 r"
221 INSERT INTO es_snapshots (aggregate_kind, aggregate_id, position, data)
222 VALUES ($1, $2, $3, $4)
223 ON CONFLICT (aggregate_kind, aggregate_id)
224 DO UPDATE SET position = EXCLUDED.position, data = EXCLUDED.data, created_at = now()
225 WHERE es_snapshots.position < EXCLUDED.position
226 ",
227 )
228 .bind(kind)
229 .bind(id)
230 .bind(position)
231 .bind(sqlx::types::Json(data))
232 .execute(&self.pool)
233 .await
234 .map_err(|e| OfferSnapshotError::Snapshot(Error::Database(e)))?;
235
236 let offer = if result.rows_affected() > 0 {
239 SnapshotOffer::Stored
240 } else {
241 SnapshotOffer::Declined
242 };
243
244 tracing::debug!(
245 events_since_last_snapshot,
246 ?offer,
247 "snapshot offer evaluated"
248 );
249 Ok(offer)
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// `Always` accepts regardless of how many events have accumulated.
    #[test]
    fn policy_always_should_snapshot() {
        let policy = SnapshotPolicy::Always;
        for events in [0_u64, 1, 100] {
            assert!(policy.should_snapshot(events));
        }
    }

    /// `EveryNEvents(50)` declines below 50 events and accepts from 50 upward.
    #[test]
    fn policy_every_n_events_should_snapshot() {
        let policy = SnapshotPolicy::EveryNEvents(50);
        for events in [0_u64, 49] {
            assert!(!policy.should_snapshot(events));
        }
        for events in [50_u64, 100] {
            assert!(policy.should_snapshot(events));
        }
    }

    /// `Never` declines regardless of how many events have accumulated.
    #[test]
    fn policy_never_should_snapshot() {
        let policy = SnapshotPolicy::Never;
        for events in [0_u64, 1, 1000] {
            assert!(!policy.should_snapshot(events));
        }
    }
}