Skip to main content

sourcery_core/snapshot/
inmemory.rs

//! In-memory snapshot store implementation.
2
3use std::{
4    collections::HashMap,
5    hash::Hash,
6    sync::{Arc, RwLock},
7};
8
9use serde::{Serialize, de::DeserializeOwned};
10
11use super::{OfferSnapshotError, Snapshot, SnapshotOffer, SnapshotStore};
12
/// Snapshot store policy for when to accept snapshot offers.
///
/// This policy is used by snapshot store implementations to determine when to
/// persist snapshots:
///
/// - [`SnapshotPolicy::Always`]: Snapshot after every command (high storage
///   cost, minimal replay)
/// - [`SnapshotPolicy::EveryNEvents`]: Snapshot every N events (balanced
///   approach)
/// - [`SnapshotPolicy::Never`]: Don't persist snapshots (load-only mode)
///
/// # Guidelines
///
/// Choose based on your aggregate's characteristics:
///
/// ## Use `Always` when:
/// - Event replay is expensive (complex business logic per event)
/// - Aggregates accumulate many events (hundreds or thousands)
/// - Storage cost is less important than read performance
///
/// ## Use `EveryNEvents(n)` when:
/// - Balancing storage cost vs. replay cost
/// - Aggregates have moderate event counts
/// - Start with n=50-100 and tune based on profiling
///
/// ## Use `Never` when:
/// - Running a read replica that consumes snapshots created elsewhere
/// - Aggregates are short-lived (few events per instance)
/// - Managing snapshots through an external process
/// - Testing without snapshot overhead
///
/// # Example
///
/// ```ignore
/// use sourcery::snapshot::inmemory::SnapshotPolicy;
///
/// // Use in custom snapshot store implementations
/// impl MySnapshotStore {
///     pub fn new(policy: SnapshotPolicy) -> Self {
///         Self { policy }
///     }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SnapshotPolicy {
    /// Create a snapshot after every command.
    Always,
    /// Create a snapshot once at least N events have accumulated since the
    /// last snapshot.
    ///
    /// `EveryNEvents(0)` accepts every offer and therefore behaves exactly
    /// like [`SnapshotPolicy::Always`].
    EveryNEvents(u64),
    /// Never create snapshots (load-only mode).
    Never,
}

impl SnapshotPolicy {
    /// Check if a snapshot should be created based on the number of events
    /// applied since the last snapshot.
    ///
    /// - [`SnapshotPolicy::Always`] returns `true` for any count, including `0`.
    /// - [`SnapshotPolicy::EveryNEvents`] returns `true` once
    ///   `events_since >= n` (the threshold is inclusive).
    /// - [`SnapshotPolicy::Never`] always returns `false`.
    #[must_use]
    pub const fn should_snapshot(&self, events_since: u64) -> bool {
        match self {
            Self::Always => true,
            Self::EveryNEvents(threshold) => events_since >= *threshold,
            Self::Never => false,
        }
    }
}
78
79/// Error type for in-memory snapshot store.
80#[derive(Debug, thiserror::Error)]
81pub enum Error {
82    #[error("serialization error: {0}")]
83    Serialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
84    #[error("deserialization error: {0}")]
85    Deserialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
86}
87
88impl Error {
89    fn serialization(err: impl std::error::Error + Send + Sync + 'static) -> Self {
90        Self::Serialization(Box::new(err))
91    }
92
93    fn deserialization(err: impl std::error::Error + Send + Sync + 'static) -> Self {
94        Self::Deserialization(Box::new(err))
95    }
96}
97
98/// In-memory snapshot store with configurable policy.
99///
100/// This is a reference implementation suitable for testing and development.
101/// Production systems should implement [`SnapshotStore`] with durable storage.
102///
103/// Keys are derived from `(kind, id)` directly, so `Id: Eq + Hash + Clone` is
104/// required for this store.
105///
106/// Generic over `Pos` to match the `EventStore` position type.
107///
108/// # Example
109///
110/// ```ignore
111/// use sourcery::{Repository, store::inmemory, snapshot::inmemory};
112///
113/// let repo = Repository::new(store::inmemory::Store::new())
114///     .with_snapshots(snapshot::inmemory::Store::every(100));
115/// ```
116type SnapshotMap<Id, Pos> = HashMap<SnapshotKey<Id>, Snapshot<Pos, serde_json::Value>>;
117type SharedSnapshots<Id, Pos> = Arc<RwLock<SnapshotMap<Id, Pos>>>;
118
119#[derive(Clone, Debug)]
120pub struct Store<Id, Pos> {
121    snapshots: SharedSnapshots<Id, Pos>,
122    policy: SnapshotPolicy,
123}
124
125impl<Id, Pos> Store<Id, Pos> {
126    /// Create a snapshot store that saves after every command.
127    ///
128    /// Best for aggregates with expensive replay or many events.
129    /// See the policy guidelines above for choosing an appropriate cadence.
130    #[must_use]
131    pub fn always() -> Self {
132        Self {
133            snapshots: Arc::new(RwLock::new(HashMap::new())),
134            policy: SnapshotPolicy::Always,
135        }
136    }
137
138    /// Create a snapshot store that saves every N events.
139    ///
140    /// Recommended for most use cases. Start with `n = 50-100` and tune
141    /// based on your aggregate's replay cost.
142    /// See the policy guidelines above for choosing a policy.
143    #[must_use]
144    pub fn every(n: u64) -> Self {
145        Self {
146            snapshots: Arc::new(RwLock::new(HashMap::new())),
147            policy: SnapshotPolicy::EveryNEvents(n),
148        }
149    }
150
151    /// Create a snapshot store that never saves (load-only).
152    ///
153    /// Use for read replicas, short-lived aggregates, or when managing
154    /// snapshots externally. See the policy guidelines above for when this
155    /// fits.
156    #[must_use]
157    pub fn never() -> Self {
158        Self {
159            snapshots: Arc::new(RwLock::new(HashMap::new())),
160            policy: SnapshotPolicy::Never,
161        }
162    }
163}
164
165impl<Id, Pos> Default for Store<Id, Pos> {
166    fn default() -> Self {
167        Self::always()
168    }
169}
170
171impl<Id, Pos> SnapshotStore<Id> for Store<Id, Pos>
172where
173    Id: Clone + Eq + Hash + Send + Sync,
174    Pos: Clone + Ord + Send + Sync,
175{
176    type Error = Error;
177    type Position = Pos;
178
179    #[tracing::instrument(skip(self, id))]
180    async fn load<T>(&self, kind: &str, id: &Id) -> Result<Option<Snapshot<Pos, T>>, Self::Error>
181    where
182        T: DeserializeOwned,
183    {
184        let key = SnapshotKey::new(kind, id.clone());
185        let stored = {
186            let snapshots = self.snapshots.read().expect("snapshot store lock poisoned");
187            snapshots.get(&key).cloned()
188        };
189        let snapshot = match stored {
190            Some(snapshot) => {
191                let data = serde_json::from_value(snapshot.data.clone())
192                    .map_err(Error::deserialization)?;
193                Some(Snapshot {
194                    position: snapshot.position,
195                    data,
196                })
197            }
198            None => None,
199        };
200        tracing::trace!(found = snapshot.is_some(), "snapshot lookup");
201        Ok(snapshot)
202    }
203
204    #[tracing::instrument(skip(self, id, create_snapshot))]
205    async fn offer_snapshot<CE, T, Create>(
206        &self,
207        kind: &str,
208        id: &Id,
209        events_since_last_snapshot: u64,
210        create_snapshot: Create,
211    ) -> Result<SnapshotOffer, OfferSnapshotError<Self::Error, CE>>
212    where
213        CE: std::error::Error + Send + Sync + 'static,
214        T: Serialize,
215        Create: FnOnce() -> Result<Snapshot<Pos, T>, CE>,
216    {
217        if !self.policy.should_snapshot(events_since_last_snapshot) {
218            return Ok(SnapshotOffer::Declined);
219        }
220
221        let snapshot = match create_snapshot() {
222            Ok(snapshot) => snapshot,
223            Err(e) => return Err(OfferSnapshotError::Create(e)),
224        };
225        let data = serde_json::to_value(&snapshot.data)
226            .map_err(|e| OfferSnapshotError::Snapshot(Error::serialization(e)))?;
227        let key = SnapshotKey::new(kind, id.clone());
228        let stored = Snapshot {
229            position: snapshot.position,
230            data,
231        };
232
233        let offer = {
234            let mut snapshots = self
235                .snapshots
236                .write()
237                .expect("snapshot store lock poisoned");
238            match snapshots.get(&key) {
239                Some(existing) if existing.position >= stored.position => SnapshotOffer::Declined,
240                _ => {
241                    snapshots.insert(key, stored);
242                    SnapshotOffer::Stored
243                }
244            }
245        };
246
247        tracing::debug!(
248            events_since_last_snapshot,
249            ?offer,
250            "snapshot offer evaluated"
251        );
252        Ok(offer)
253    }
254}
255
/// Map key identifying one aggregate instance: its kind together with its id.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SnapshotKey<Id> {
    kind: String,
    id: Id,
}

impl<Id> SnapshotKey<Id> {
    /// Builds a key, copying the borrowed `kind` into an owned `String`.
    fn new(kind: &str, id: Id) -> Self {
        let kind = String::from(kind);
        Self { kind, id }
    }
}
270
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;

    #[test]
    fn always_should_snapshot() {
        let policy = SnapshotPolicy::Always;
        assert!(policy.should_snapshot(0));
        assert!(policy.should_snapshot(1));
        assert!(policy.should_snapshot(100));
    }

    #[test]
    fn every_n_at_threshold() {
        let policy = SnapshotPolicy::EveryNEvents(3);
        assert!(policy.should_snapshot(3));
        assert!(policy.should_snapshot(4));
        assert!(policy.should_snapshot(100));
    }

    #[test]
    fn every_n_below_threshold() {
        let policy = SnapshotPolicy::EveryNEvents(3);
        assert!(!policy.should_snapshot(0));
        assert!(!policy.should_snapshot(1));
        assert!(!policy.should_snapshot(2));
    }

    #[test]
    fn never_should_snapshot() {
        let policy = SnapshotPolicy::Never;
        assert!(!policy.should_snapshot(0));
        assert!(!policy.should_snapshot(1));
        assert!(!policy.should_snapshot(100));
    }

    #[tokio::test]
    async fn load_returns_none_for_missing() {
        let store = Store::<String, u64>::always();
        let result: Option<Snapshot<u64, String>> =
            store.load("test", &"id".to_string()).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn load_returns_stored_snapshot() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 5,
                    data: "snapshot-data".to_string(),
                })
            })
            .await
            .unwrap();

        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.position, 5);
        assert_eq!(loaded.data, "snapshot-data");
    }

    #[tokio::test]
    async fn load_is_scoped_by_kind_and_id() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("order", &id, 1, || {
                Ok(Snapshot {
                    position: 1,
                    data: "order-data",
                })
            })
            .await
            .unwrap();

        // Same id under a different kind must not see the snapshot.
        let other_kind: Option<Snapshot<u64, String>> =
            store.load("invoice", &id).await.unwrap();
        assert!(other_kind.is_none());

        // Same kind under a different id must not see the snapshot either.
        let other_id: Option<Snapshot<u64, String>> =
            store.load("order", &"other-id".to_string()).await.unwrap();
        assert!(other_id.is_none());

        let loaded: Snapshot<u64, String> = store.load("order", &id).await.unwrap().unwrap();
        assert_eq!(loaded.data, "order-data");
    }

    #[tokio::test]
    async fn offer_declines_older_position() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        // Store initial snapshot at position 10
        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 10,
                    data: "first",
                })
            })
            .await
            .unwrap();

        // Try to store older snapshot at position 5 - should be declined
        let result = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 5,
                    data: "older",
                })
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Declined);

        // Verify original snapshot is still there
        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.position, 10);
        assert_eq!(loaded.data, "first");
    }

    #[tokio::test]
    async fn offer_declines_equal_position() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 10,
                    data: "first",
                })
            })
            .await
            .unwrap();

        let result = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 10,
                    data: "second",
                })
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Declined);
        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.data, "first");
    }

    #[tokio::test]
    async fn offer_replaces_with_newer_position() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 5,
                    data: "first",
                })
            })
            .await
            .unwrap();

        let result = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 1, || {
                Ok(Snapshot {
                    position: 10,
                    data: "second",
                })
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Stored);
        let loaded: Snapshot<u64, String> = store.load("test", &id).await.unwrap().unwrap();
        assert_eq!(loaded.position, 10);
        assert_eq!(loaded.data, "second");
    }

    #[tokio::test]
    async fn offer_declined_by_policy_skips_create() {
        let store = Store::<String, u64>::every(3);
        let id = "test-id".to_string();
        let mut created = false;

        let result = store
            .offer_snapshot::<Infallible, _, _>("test", &id, 2, || {
                created = true;
                Ok(Snapshot {
                    position: 2,
                    data: "data",
                })
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Declined);
        assert!(!created, "create_snapshot must not run when the policy declines");
        let loaded: Option<Snapshot<u64, String>> = store.load("test", &id).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn never_store_declines_offers() {
        let store = Store::<String, u64>::never();
        let id = "test-id".to_string();

        let result = store
            .offer_snapshot::<Infallible, _, _>("test", &id, u64::MAX, || {
                Ok(Snapshot {
                    position: 1,
                    data: "data",
                })
            })
            .await
            .unwrap();

        assert_eq!(result, SnapshotOffer::Declined);
    }

    #[tokio::test]
    async fn offer_propagates_create_error() {
        let store = Store::<String, u64>::always();
        let id = "test-id".to_string();

        let result = store
            .offer_snapshot::<std::io::Error, String, _>("test", &id, 1, || {
                Err(std::io::Error::other("boom"))
            })
            .await;

        assert!(matches!(result, Err(OfferSnapshotError::Create(_))));
        let loaded: Option<Snapshot<u64, String>> = store.load("test", &id).await.unwrap();
        assert!(loaded.is_none());
    }

    #[test]
    fn default_is_always() {
        let store = Store::<String, u64>::default();
        assert!(matches!(store.policy, SnapshotPolicy::Always));
    }
}