1use std::{
4 collections::HashMap,
5 hash::Hash,
6 sync::{Arc, RwLock},
7};
8
9use serde::{Serialize, de::DeserializeOwned};
10
11use super::{OfferSnapshotError, Snapshot, SnapshotOffer, SnapshotStore};
12
/// Decides whether a snapshot offer should be taken, based on how many events
/// have been applied since the last snapshot.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SnapshotPolicy {
    /// Accept every offer, regardless of the event count.
    Always,
    /// Accept an offer once at least this many events have been applied since
    /// the last snapshot. `EveryNEvents(0)` behaves like [`Always`](Self::Always).
    EveryNEvents(u64),
    /// Never accept an offer.
    Never,
}

impl SnapshotPolicy {
    /// Returns `true` if a snapshot should be taken after `events_since` events
    /// have been applied since the previous snapshot.
    #[must_use]
    pub const fn should_snapshot(&self, events_since: u64) -> bool {
        match self {
            Self::Always => true,
            Self::EveryNEvents(threshold) => events_since >= *threshold,
            Self::Never => false,
        }
    }
}
78
79#[derive(Debug, thiserror::Error)]
81pub enum Error {
82 #[error("serialization error: {0}")]
83 Serialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
84 #[error("deserialization error: {0}")]
85 Deserialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
86}
87
88impl Error {
89 fn serialization(err: impl std::error::Error + Send + Sync + 'static) -> Self {
90 Self::Serialization(Box::new(err))
91 }
92
93 fn deserialization(err: impl std::error::Error + Send + Sync + 'static) -> Self {
94 Self::Deserialization(Box::new(err))
95 }
96}
97
/// Backing map: one JSON-encoded snapshot per `(kind, id)` key.
type SnapshotMap<Id, Pos> = HashMap<SnapshotKey<Id>, Snapshot<Pos, serde_json::Value>>;
/// The map behind an `Arc<RwLock<..>>`; cloning the `Arc` shares the same map.
type SharedSnapshots<Id, Pos> = Arc<RwLock<SnapshotMap<Id, Pos>>>;

/// In-memory [`SnapshotStore`] that keeps each snapshot's data as a
/// `serde_json::Value` and consults a [`SnapshotPolicy`] before accepting an
/// offer.
///
/// The derived `Clone` clones the inner `Arc`, so all clones of a `Store` read
/// and write the same snapshot map.
#[derive(Clone, Debug)]
pub struct Store<Id, Pos> {
    // Snapshots keyed by aggregate kind and id.
    snapshots: SharedSnapshots<Id, Pos>,
    // Policy checked by `offer_snapshot` before the snapshot is created.
    policy: SnapshotPolicy,
}
124
125impl<Id, Pos> Store<Id, Pos> {
126 #[must_use]
131 pub fn always() -> Self {
132 Self {
133 snapshots: Arc::new(RwLock::new(HashMap::new())),
134 policy: SnapshotPolicy::Always,
135 }
136 }
137
138 #[must_use]
144 pub fn every(n: u64) -> Self {
145 Self {
146 snapshots: Arc::new(RwLock::new(HashMap::new())),
147 policy: SnapshotPolicy::EveryNEvents(n),
148 }
149 }
150
151 #[must_use]
157 pub fn never() -> Self {
158 Self {
159 snapshots: Arc::new(RwLock::new(HashMap::new())),
160 policy: SnapshotPolicy::Never,
161 }
162 }
163}
164
165impl<Id, Pos> Default for Store<Id, Pos> {
166 fn default() -> Self {
167 Self::always()
168 }
169}
170
171impl<Id, Pos> SnapshotStore<Id> for Store<Id, Pos>
172where
173 Id: Clone + Eq + Hash + Send + Sync,
174 Pos: Clone + Ord + Send + Sync,
175{
176 type Error = Error;
177 type Position = Pos;
178
179 #[tracing::instrument(skip(self, id))]
180 async fn load<T>(&self, kind: &str, id: &Id) -> Result<Option<Snapshot<Pos, T>>, Self::Error>
181 where
182 T: DeserializeOwned,
183 {
184 let key = SnapshotKey::new(kind, id.clone());
185 let stored = {
186 let snapshots = self.snapshots.read().expect("snapshot store lock poisoned");
187 snapshots.get(&key).cloned()
188 };
189 let snapshot = match stored {
190 Some(snapshot) => {
191 let data = serde_json::from_value(snapshot.data.clone())
192 .map_err(Error::deserialization)?;
193 Some(Snapshot {
194 position: snapshot.position,
195 data,
196 })
197 }
198 None => None,
199 };
200 tracing::trace!(found = snapshot.is_some(), "snapshot lookup");
201 Ok(snapshot)
202 }
203
204 #[tracing::instrument(skip(self, id, create_snapshot))]
205 async fn offer_snapshot<CE, T, Create>(
206 &self,
207 kind: &str,
208 id: &Id,
209 events_since_last_snapshot: u64,
210 create_snapshot: Create,
211 ) -> Result<SnapshotOffer, OfferSnapshotError<Self::Error, CE>>
212 where
213 CE: std::error::Error + Send + Sync + 'static,
214 T: Serialize,
215 Create: FnOnce() -> Result<Snapshot<Pos, T>, CE>,
216 {
217 if !self.policy.should_snapshot(events_since_last_snapshot) {
218 return Ok(SnapshotOffer::Declined);
219 }
220
221 let snapshot = match create_snapshot() {
222 Ok(snapshot) => snapshot,
223 Err(e) => return Err(OfferSnapshotError::Create(e)),
224 };
225 let data = serde_json::to_value(&snapshot.data)
226 .map_err(|e| OfferSnapshotError::Snapshot(Error::serialization(e)))?;
227 let key = SnapshotKey::new(kind, id.clone());
228 let stored = Snapshot {
229 position: snapshot.position,
230 data,
231 };
232
233 let offer = {
234 let mut snapshots = self
235 .snapshots
236 .write()
237 .expect("snapshot store lock poisoned");
238 match snapshots.get(&key) {
239 Some(existing) if existing.position >= stored.position => SnapshotOffer::Declined,
240 _ => {
241 snapshots.insert(key, stored);
242 SnapshotOffer::Stored
243 }
244 }
245 };
246
247 tracing::debug!(
248 events_since_last_snapshot,
249 ?offer,
250 "snapshot offer evaluated"
251 );
252 Ok(offer)
253 }
254}
255
/// Composite map key pairing an aggregate `kind` with its identifier, so
/// different kinds never collide even when they share an id.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct SnapshotKey<Id> {
    kind: String,
    id: Id,
}

impl<Id> SnapshotKey<Id> {
    /// Builds a key, copying the borrowed `kind` into an owned `String`.
    fn new(kind: &str, id: Id) -> Self {
        let kind = String::from(kind);
        Self { kind, id }
    }
}
270
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;

    #[test]
    fn always_should_snapshot() {
        let policy = SnapshotPolicy::Always;
        assert!(policy.should_snapshot(0));
        assert!(policy.should_snapshot(1));
        assert!(policy.should_snapshot(100));
    }

    #[test]
    fn every_n_at_threshold() {
        let policy = SnapshotPolicy::EveryNEvents(3);
        assert!(policy.should_snapshot(3));
        assert!(policy.should_snapshot(4));
        assert!(policy.should_snapshot(100));
    }

    #[test]
    fn every_n_below_threshold() {
        let policy = SnapshotPolicy::EveryNEvents(3);
        assert!(!policy.should_snapshot(0));
        assert!(!policy.should_snapshot(1));
        assert!(!policy.should_snapshot(2));
    }

    #[test]
    fn never_should_snapshot() {
        let policy = SnapshotPolicy::Never;
        assert!(!policy.should_snapshot(0));
        assert!(!policy.should_snapshot(1));
        assert!(!policy.should_snapshot(100));
    }

    #[tokio::test]
    async fn load_returns_none_for_missing() {
        let store = Store::<String, u64>::always();
        let result: Option<Snapshot<u64, String>> =
            store.load("test", &"id".to_string()).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn load_returns_stored_snapshot() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 5,
                    data: "snapshot-data".to_string(),
                })
            })
            .await
            .unwrap();

        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.position, 5);
        assert_eq!(loaded.data, "snapshot-data");
    }

    #[tokio::test]
    async fn offer_declines_older_position() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 10,
                    data: "first",
                })
            })
            .await
            .unwrap();

        let result = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 5,
                    data: "older",
                })
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Declined);

        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.position, 10);
        assert_eq!(loaded.data, "first");
    }

    #[tokio::test]
    async fn offer_declines_equal_position() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 5,
                    data: "first",
                })
            })
            .await
            .unwrap();

        let result = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 5,
                    data: "same-position",
                })
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Declined);

        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.data, "first");
    }

    #[tokio::test]
    async fn offer_replaces_with_newer_position() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 5,
                    data: "first",
                })
            })
            .await
            .unwrap();

        let result = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 10,
                    data: "second",
                })
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Stored);

        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.position, 10);
        assert_eq!(loaded.data, "second");
    }

    #[tokio::test]
    async fn never_declines_without_creating_snapshot() {
        let store = Store::<String, u64>::never();
        let id = "test-id".to_string();

        let result = store
            .offer_snapshot::<Infallible, String, _>("test", &id, 100, || {
                panic!("create_snapshot must not be called when the policy declines")
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Declined);

        let loaded: Option<Snapshot<u64, String>> = store.load("test", &id).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn every_n_only_stores_at_threshold() {
        let store = Store::<String, u64>::every(3);
        let id = "test-id".to_string();

        let below = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 2, || {
                Ok(Snapshot {
                    position: 2,
                    data: "early",
                })
            })
            .await
            .unwrap();
        assert_eq!(below, SnapshotOffer::Declined);

        let at = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 3, || {
                Ok(Snapshot {
                    position: 3,
                    data: "due",
                })
            })
            .await
            .unwrap();
        assert_eq!(at, SnapshotOffer::Stored);

        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.position, 3);
        assert_eq!(loaded.data, "due");
    }

    #[tokio::test]
    async fn create_error_is_propagated() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        let result = store
            .offer_snapshot::<std::io::Error, String, _>("test", &id, 1, || {
                Err(std::io::Error::other("boom"))
            })
            .await;

        assert!(matches!(
            result,
            Err(OfferSnapshotError::Create(ref e)) if e.to_string() == "boom"
        ));

        let loaded: Option<Snapshot<u64, String>> = store.load("test", &id).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn load_reports_type_mismatch() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 1,
                    data: "not-a-number",
                })
            })
            .await
            .unwrap();

        let result = store.load::<u64>("test", &id).await;
        assert!(matches!(result, Err(Error::Deserialization(_))));
    }

    #[tokio::test]
    async fn kinds_are_isolated() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("account", &id, 1, || {
                Ok(Snapshot {
                    position: 1,
                    data: "account-data",
                })
            })
            .await
            .unwrap();

        let other: Option<Snapshot<u64, String>> = store.load("order", &id).await.unwrap();
        assert!(other.is_none());
    }

    #[tokio::test]
    async fn clones_share_snapshots() {
        let store = Store::<String, u64>::always();
        let clone = store.clone();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 7,
                    data: "shared",
                })
            })
            .await
            .unwrap();

        let loaded: Snapshot<u64, String> = clone.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.position, 7);
        assert_eq!(loaded.data, "shared");
    }

    #[test]
    fn default_is_always() {
        let store = Store::<String, u64>::default();
        assert!(matches!(store.policy, SnapshotPolicy::Always));
    }
}