Skip to main content

sourcery_core/store/
inmemory.rs

//! In-memory event store implementation for testing.
//!
//! This module provides [`Store`], a thread-safe in-memory implementation of
//! [`EventStore`](super::EventStore) suitable for unit tests and examples.
//!
//! # Example
//!
//! ```
//! use sourcery_core::store::inmemory;
//!
//! let store: inmemory::Store<String, ()> = inmemory::Store::new();
//! ```

14use std::{
15    collections::HashMap,
16    pin::Pin,
17    sync::{Arc, RwLock},
18};
19
20use nonempty::NonEmpty;
21use tokio::sync::broadcast;
22
23use crate::{
24    concurrency::ConcurrencyConflict,
25    store::{
26        CommitError, Committed, EventFilter, EventStore, GloballyOrderedStore, LoadEventsResult,
27        OptimisticCommitError, StoredEvent, StreamKey,
28    },
29    subscription::SubscribableStore,
30};
31
32/// Event stream stored in memory with fixed position and data types.
33type InMemoryStream<Id, M> = Vec<StoredEvent<Id, u64, serde_json::Value, M>>;
34
35/// In-memory event store that keeps streams in a hash map.
36///
37/// Uses a global sequence counter (`Position = u64`) to maintain chronological
38/// ordering across streams, enabling cross-aggregate projections that need to
39/// interleave events by time rather than by stream name.
40///
41/// Generic over:
42/// - `Id`: Aggregate identifier type (must be hashable/equatable for map keys)
43/// - `M`: Metadata type (use `()` when not needed)
44///
45/// This store uses `serde_json::Value` as the internal data representation.
46#[derive(Clone)]
47pub struct Store<Id, M> {
48    inner: Arc<RwLock<Inner<Id, M>>>,
49}
50
51struct Inner<Id, M> {
52    streams: HashMap<StreamKey<Id>, InMemoryStream<Id, M>>,
53    next_position: u64,
54    /// Broadcasts the position of newly committed events. Subscribers use this
55    /// to detect live writes without polling.
56    notify_tx: broadcast::Sender<u64>,
57}
58
59impl<Id, M> Store<Id, M> {
60    #[must_use]
61    pub fn new() -> Self {
62        let (notify_tx, _) = broadcast::channel(1024);
63        Self {
64            inner: Arc::new(RwLock::new(Inner {
65                streams: HashMap::new(),
66                next_position: 0,
67                notify_tx,
68            })),
69        }
70    }
71}
72
73impl<Id, M> Default for Store<Id, M> {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79/// Error type for in-memory store.
80#[derive(Debug, thiserror::Error)]
81pub enum InMemoryError {
82    #[error("serialization error: {0}")]
83    Serialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
84    #[error("deserialization error: {0}")]
85    Deserialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
86}
87
88impl<Id, M> EventStore for Store<Id, M>
89where
90    Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
91    M: Clone + Send + Sync + 'static,
92{
93    type Data = serde_json::Value;
94    type Error = InMemoryError;
95    type Id = Id;
96    type Metadata = M;
97    type Position = u64;
98
99    fn decode_event<E>(
100        &self,
101        stored: &StoredEvent<Self::Id, Self::Position, Self::Data, Self::Metadata>,
102    ) -> Result<E, Self::Error>
103    where
104        E: crate::event::DomainEvent + serde::de::DeserializeOwned,
105    {
106        serde_json::from_value(stored.data.clone())
107            .map_err(|e| InMemoryError::Deserialization(Box::new(e)))
108    }
109
110    #[tracing::instrument(skip(self, aggregate_id))]
111    fn stream_version<'a>(
112        &'a self,
113        aggregate_kind: &'a str,
114        aggregate_id: &'a Self::Id,
115    ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + 'a {
116        let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
117        let version = {
118            let inner = self.inner.read().expect("in-memory store lock poisoned");
119            inner
120                .streams
121                .get(&stream_key)
122                .and_then(|s| s.last().map(|e| e.position))
123        };
124        tracing::trace!(?version, "retrieved stream version");
125        std::future::ready(Ok(version))
126    }
127
128    #[tracing::instrument(skip(self, aggregate_id, events, metadata), fields(event_count = events.len()))]
129    fn commit_events<'a, E>(
130        &'a self,
131        aggregate_kind: &'a str,
132        aggregate_id: &'a Self::Id,
133        events: NonEmpty<E>,
134        metadata: &'a Self::Metadata,
135    ) -> impl Future<Output = Result<Committed<u64>, CommitError<Self::Error>>> + Send + 'a
136    where
137        E: crate::event::EventKind + serde::Serialize + Send + Sync + 'a,
138        Self::Metadata: Clone,
139    {
140        let result = (|| {
141            // Serialize all events first
142            let mut staged = Vec::with_capacity(events.len());
143            for (index, event) in events.iter().enumerate() {
144                let data = serde_json::to_value(event).map_err(|e| CommitError::Serialization {
145                    index,
146                    source: InMemoryError::Serialization(Box::new(e)),
147                })?;
148                staged.push((event.kind().to_string(), data));
149            }
150
151            let mut inner = self.inner.write().expect("in-memory store lock poisoned");
152            let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
153            let mut last_position = 0;
154            let mut stored = Vec::with_capacity(staged.len());
155
156            for (kind, data) in staged {
157                let position = inner.next_position;
158                inner.next_position += 1;
159                last_position = position;
160                stored.push(StoredEvent {
161                    aggregate_kind: aggregate_kind.to_string(),
162                    aggregate_id: aggregate_id.clone(),
163                    kind,
164                    position,
165                    data,
166                    metadata: metadata.clone(),
167                });
168            }
169
170            inner.streams.entry(stream_key).or_default().extend(stored);
171            let notify_tx = inner.notify_tx.clone();
172            drop(inner);
173            // Notify subscribers of the new position (ignore send errors -- no
174            // receivers is fine)
175            let _ = notify_tx.send(last_position);
176            tracing::debug!(events_appended = events.len(), "events committed to stream");
177            Ok(Committed { last_position })
178        })();
179
180        std::future::ready(result)
181    }
182
183    #[tracing::instrument(skip(self, aggregate_id, events, metadata), fields(event_count = events.len()))]
184    fn commit_events_optimistic<'a, E>(
185        &'a self,
186        aggregate_kind: &'a str,
187        aggregate_id: &'a Self::Id,
188        expected_version: Option<Self::Position>,
189        events: NonEmpty<E>,
190        metadata: &'a Self::Metadata,
191    ) -> impl Future<Output = Result<Committed<u64>, OptimisticCommitError<u64, Self::Error>>> + Send + 'a
192    where
193        E: crate::event::EventKind + serde::Serialize + Send + Sync + 'a,
194        Self::Metadata: Clone,
195    {
196        let result = (|| {
197            // Serialize all events first
198            let mut staged = Vec::with_capacity(events.len());
199            for (index, event) in events.iter().enumerate() {
200                let data = serde_json::to_value(event).map_err(|e| {
201                    OptimisticCommitError::Serialization {
202                        index,
203                        source: InMemoryError::Serialization(Box::new(e)),
204                    }
205                })?;
206                staged.push((event.kind().to_string(), data));
207            }
208
209            let mut inner = self.inner.write().expect("in-memory store lock poisoned");
210            let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
211
212            // Check version
213            let current = inner
214                .streams
215                .get(&stream_key)
216                .and_then(|s| s.last().map(|e| e.position));
217
218            match expected_version {
219                Some(expected) => {
220                    // Expected specific version
221                    if current != Some(expected) {
222                        tracing::debug!(?expected, ?current, "version mismatch, rejecting commit");
223                        return Err(ConcurrencyConflict {
224                            expected: Some(expected),
225                            actual: current,
226                        }
227                        .into());
228                    }
229                }
230                None => {
231                    // Expected new stream (no events)
232                    if let Some(actual) = current {
233                        tracing::debug!(
234                            ?actual,
235                            "stream already exists, rejecting new aggregate commit"
236                        );
237                        return Err(ConcurrencyConflict {
238                            expected: None,
239                            actual: Some(actual),
240                        }
241                        .into());
242                    }
243                }
244            }
245
246            let mut last_position = 0;
247            let mut stored = Vec::with_capacity(staged.len());
248
249            for (kind, data) in staged {
250                let position = inner.next_position;
251                inner.next_position += 1;
252                last_position = position;
253                stored.push(StoredEvent {
254                    aggregate_kind: aggregate_kind.to_string(),
255                    aggregate_id: aggregate_id.clone(),
256                    kind,
257                    position,
258                    data,
259                    metadata: metadata.clone(),
260                });
261            }
262
263            inner.streams.entry(stream_key).or_default().extend(stored);
264            let notify_tx = inner.notify_tx.clone();
265            drop(inner);
266            let _ = notify_tx.send(last_position);
267            tracing::debug!(
268                events_appended = events.len(),
269                "events committed to stream (optimistic)"
270            );
271            Ok(Committed { last_position })
272        })();
273
274        std::future::ready(result)
275    }
276
277    #[tracing::instrument(skip(self, filters), fields(filter_count = filters.len()))]
278    fn load_events<'a>(
279        &'a self,
280        filters: &'a [EventFilter<Self::Id, Self::Position>],
281    ) -> impl Future<
282        Output = LoadEventsResult<
283            Self::Id,
284            Self::Position,
285            Self::Data,
286            Self::Metadata,
287            Self::Error,
288        >,
289    > + Send
290    + 'a {
291        let result = load_matching_events(
292            &self.inner.read().expect("in-memory store lock poisoned"),
293            filters,
294        );
295
296        tracing::debug!(events_loaded = result.len(), "loaded events from store");
297        std::future::ready(Ok(result))
298    }
299}
300
301impl<Id, M> GloballyOrderedStore for Store<Id, M>
302where
303    Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
304    M: Clone + Send + Sync + 'static,
305{
306}
307
308impl<Id, M> SubscribableStore for Store<Id, M>
309where
310    Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
311    M: Clone + Send + Sync + 'static,
312{
313    fn subscribe(
314        &self,
315        filters: &[EventFilter<Self::Id, Self::Position>],
316        from_position: Option<Self::Position>,
317    ) -> Pin<
318        Box<
319            dyn futures_core::Stream<
320                    Item = Result<
321                        StoredEvent<Self::Id, Self::Position, Self::Data, Self::Metadata>,
322                        Self::Error,
323                    >,
324                > + Send
325                + '_,
326        >,
327    >
328    where
329        Self::Position: Ord,
330    {
331        let filters = filters.to_vec();
332        let inner = self.inner.clone();
333
334        // Subscribe to broadcast FIRST to avoid missing events committed
335        // between the historical load and live listening.
336        let mut rx = {
337            let guard = inner.read().expect("in-memory store lock poisoned");
338            guard.notify_tx.subscribe()
339        };
340
341        Box::pin(async_stream::stream! {
342            // 1. Load and yield historical events
343            let historical = {
344                let guard = inner.read().expect("in-memory store lock poisoned");
345                load_matching_events(&guard, &filters)
346            };
347
348            let mut last_position = from_position;
349
350            for event in historical {
351                // Skip events at or before our starting position
352                if let Some(ref lp) = last_position
353                    && event.position <= *lp
354                {
355                    continue;
356                }
357                last_position = Some(event.position);
358                yield Ok(event);
359            }
360
361            // 2. Process live events from broadcast
362            loop {
363                match rx.recv().await {
364                    Ok(notified_position) => {
365                        // Skip positions we've already seen
366                        if let Some(ref lp) = last_position
367                            && notified_position <= *lp
368                        {
369                            continue;
370                        }
371
372                        // Load the event at this position and check filters
373                        let events = {
374                            let guard = inner.read().expect("in-memory store lock poisoned");
375                            load_matching_events(&guard, &filters)
376                        };
377
378                        for event in events {
379                            if let Some(ref lp) = last_position
380                                && event.position <= *lp
381                            {
382                                continue;
383                            }
384                            last_position = Some(event.position);
385                            yield Ok(event);
386                        }
387                    }
388                    Err(broadcast::error::RecvError::Lagged(_)) => {
389                        // We missed some notifications. Re-load from our last
390                        // known position to catch up.
391                        let events = {
392                            let guard = inner.read().expect("in-memory store lock poisoned");
393                            load_matching_events(&guard, &filters)
394                        };
395
396                        for event in events {
397                            if let Some(ref lp) = last_position
398                                && event.position <= *lp
399                            {
400                                continue;
401                            }
402                            last_position = Some(event.position);
403                            yield Ok(event);
404                        }
405                    }
406                    Err(broadcast::error::RecvError::Closed) => {
407                        // Store dropped, end stream
408                        break;
409                    }
410                }
411            }
412        })
413    }
414}
415
416/// Load all events matching the given filters, sorted by position.
417fn load_matching_events<Id, M>(
418    inner: &Inner<Id, M>,
419    filters: &[EventFilter<Id, u64>],
420) -> Vec<StoredEvent<Id, u64, serde_json::Value, M>>
421where
422    Id: Clone + Eq + std::hash::Hash,
423    M: Clone,
424{
425    use std::collections::HashSet;
426
427    let mut result = Vec::new();
428    let mut seen: HashSet<(StreamKey<Id>, String)> = HashSet::new();
429
430    let mut all_kinds: HashMap<String, Option<u64>> = HashMap::new();
431    let mut by_aggregate: HashMap<StreamKey<Id>, HashMap<String, Option<u64>>> = HashMap::new();
432
433    for filter in filters {
434        if let (Some(kind), Some(id)) = (&filter.aggregate_kind, &filter.aggregate_id) {
435            by_aggregate
436                .entry(StreamKey::new(kind.clone(), id.clone()))
437                .or_default()
438                .insert(filter.event_kind.clone(), filter.after_position);
439        } else {
440            all_kinds.insert(filter.event_kind.clone(), filter.after_position);
441        }
442    }
443
444    let passes_position_filter =
445        |event: &StoredEvent<Id, u64, serde_json::Value, M>, after_position: Option<u64>| -> bool {
446            after_position.is_none_or(|after| event.position > after)
447        };
448
449    for (stream_key, kinds) in &by_aggregate {
450        if let Some(stream) = inner.streams.get(stream_key) {
451            for event in stream {
452                if let Some(&after_pos) = kinds.get(&event.kind)
453                    && passes_position_filter(event, after_pos)
454                {
455                    seen.insert((
456                        StreamKey::new(event.aggregate_kind.clone(), event.aggregate_id.clone()),
457                        event.kind.clone(),
458                    ));
459                    result.push(event.clone());
460                }
461            }
462        }
463    }
464
465    if !all_kinds.is_empty() {
466        for stream in inner.streams.values() {
467            for event in stream {
468                if let Some(&after_pos) = all_kinds.get(&event.kind)
469                    && passes_position_filter(event, after_pos)
470                {
471                    let key = (
472                        StreamKey::new(event.aggregate_kind.clone(), event.aggregate_id.clone()),
473                        event.kind.clone(),
474                    );
475                    if !seen.contains(&key) {
476                        result.push(event.clone());
477                    }
478                }
479            }
480        }
481    }
482
483    result.sort_by_key(|event| event.position);
484    result
485}
486
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::event::DomainEvent;

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestEvent {
        value: i32,
    }

    impl DomainEvent for TestEvent {
        const KIND: &'static str = "test-event";
    }

    /// Shorthand for a single-event batch.
    fn one(value: i32) -> NonEmpty<TestEvent> {
        NonEmpty::singleton(TestEvent { value })
    }

    #[test]
    fn new_has_no_streams() {
        let store = Store::<String, ()>::new();
        let inner = store.inner.read().unwrap();
        assert!(inner.streams.is_empty());
        assert_eq!(inner.next_position, 0);
        drop(inner);
    }

    #[test]
    fn decode_event_deserializes() {
        let store = Store::<String, ()>::new();
        let event = TestEvent { value: 42 };
        let data = serde_json::to_value(&event).unwrap();

        // Create a stored event
        let stored = StoredEvent {
            aggregate_kind: "test-agg".to_string(),
            aggregate_id: "id".to_string(),
            kind: "test-event".to_string(),
            position: 0,
            data,
            metadata: (),
        };

        let decoded: TestEvent = store.decode_event(&stored).unwrap();
        assert_eq!(decoded, event);
    }

    #[test]
    fn error_display_serialization() {
        let err = InMemoryError::Serialization(Box::new(std::io::Error::other("test")));
        assert!(err.to_string().contains("serialization error"));
    }

    #[test]
    fn error_display_deserialization() {
        let err = InMemoryError::Deserialization(Box::new(std::io::Error::other("test")));
        assert!(err.to_string().contains("deserialization error"));
    }

    #[tokio::test]
    async fn version_returns_none_for_new_stream() {
        let store = Store::<String, ()>::new();
        let version = store
            .stream_version("test-agg", &"id".to_string())
            .await
            .unwrap();
        assert!(version.is_none());
    }

    #[tokio::test]
    async fn version_returns_position_after_commit() {
        let store = Store::<String, ()>::new();
        let id = "id".to_string();
        let events = NonEmpty::singleton(TestEvent { value: 1 });

        store
            .commit_events("test-agg", &id, events, &())
            .await
            .unwrap();

        let version = store.stream_version("test-agg", &id).await.unwrap();
        assert_eq!(version, Some(0));
    }

    #[tokio::test]
    async fn positions_are_global_across_streams() {
        let store = Store::<String, ()>::new();
        let a = "a".to_string();
        let b = "b".to_string();

        let first = store.commit_events("test-agg", &a, one(1), &()).await.unwrap();
        let mut batch = one(2);
        batch.push(TestEvent { value: 3 });
        let second = store.commit_events("test-agg", &b, batch, &()).await.unwrap();

        assert_eq!(first.last_position, 0);
        assert_eq!(second.last_position, 2);
        assert_eq!(store.stream_version("test-agg", &a).await.unwrap(), Some(0));
        assert_eq!(store.stream_version("test-agg", &b).await.unwrap(), Some(2));
    }

    #[tokio::test]
    async fn optimistic_commit_succeeds_with_expected_versions() {
        let store = Store::<String, ()>::new();
        let id = "id".to_string();

        let created = store
            .commit_events_optimistic("test-agg", &id, None, one(1), &())
            .await
            .unwrap();
        assert_eq!(created.last_position, 0);

        let appended = store
            .commit_events_optimistic("test-agg", &id, Some(0), one(2), &())
            .await
            .unwrap();
        assert_eq!(appended.last_position, 1);
        assert_eq!(store.stream_version("test-agg", &id).await.unwrap(), Some(1));
    }

    #[tokio::test]
    async fn commit_with_wrong_version_returns_conflict() {
        let store = Store::<String, ()>::new();
        let id = "id".to_string();
        let events1 = NonEmpty::singleton(TestEvent { value: 1 });

        // First, create the stream
        store
            .commit_events("test-agg", &id, events1, &())
            .await
            .unwrap();

        // Try to commit with wrong expected version
        let events2 = NonEmpty::singleton(TestEvent { value: 2 });
        let result = store
            .commit_events_optimistic("test-agg", &id, Some(99), events2, &())
            .await;

        assert!(matches!(result, Err(OptimisticCommitError::Conflict(_))));
    }

    #[tokio::test]
    async fn rejected_commit_leaves_store_unchanged() {
        let store = Store::<String, ()>::new();
        let id = "id".to_string();

        store.commit_events("test-agg", &id, one(1), &()).await.unwrap();
        let rejected = store
            .commit_events_optimistic("test-agg", &id, Some(99), one(2), &())
            .await;

        assert!(rejected.is_err());
        assert_eq!(store.stream_version("test-agg", &id).await.unwrap(), Some(0));
        // No position was consumed by the rejected commit.
        assert_eq!(store.inner.read().unwrap().next_position, 1);
    }

    #[tokio::test]
    async fn commit_new_stream_fails_if_stream_exists() {
        let store = Store::<String, ()>::new();
        let id = "id".to_string();
        let events = NonEmpty::singleton(TestEvent { value: 1 });

        // First, create the stream
        store
            .commit_events("test-agg", &id, events, &())
            .await
            .unwrap();

        // Try to commit expecting new stream
        let events2 = NonEmpty::singleton(TestEvent { value: 2 });
        let result = store
            .commit_events_optimistic("test-agg", &id, None, events2, &())
            .await;

        assert!(matches!(result, Err(OptimisticCommitError::Conflict(_))));
    }
}