// sourcery_core/store/inmemory.rs

1use std::{collections::HashMap, convert::Infallible};
2
3use crate::{
4    codec::Codec,
5    concurrency::{ConcurrencyConflict, ConcurrencyStrategy},
6    store::{
7        AppendError, EventFilter, EventStore, PersistableEvent, StoredEvent, StreamKey, Transaction,
8    },
9};
10
11/// In-memory event store that keeps streams in a hash map.
12///
13/// Uses a global sequence counter (`Position = u64`) to maintain chronological
14/// ordering across streams, enabling cross-aggregate projections that need to
15/// interleave events by time rather than by stream name.
16///
17/// Generic over:
18/// - `Id`: Aggregate identifier type (must be hashable/equatable for map keys)
19/// - `C`: Serialization codec
20/// - `M`: Metadata type (use `()` when not needed)
21pub struct Store<Id, C, M>
22where
23    C: Codec,
24{
25    codec: C,
26    streams: HashMap<StreamKey<Id>, Vec<StoredEvent<Id, u64, M>>>,
27    next_position: u64,
28}
29
30impl<Id, C, M> Store<Id, C, M>
31where
32    C: Codec,
33{
34    #[must_use]
35    pub fn new(codec: C) -> Self {
36        Self {
37            codec,
38            streams: HashMap::new(),
39            next_position: 0,
40        }
41    }
42}
43
44/// Infallible error type that implements `std::error::Error`.
45///
46/// Used by [`InMemoryEventStore`] which cannot fail.
47#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
48#[error("infallible")]
49pub enum InMemoryError {}
50
51impl From<Infallible> for InMemoryError {
52    fn from(x: Infallible) -> Self {
53        match x {}
54    }
55}
56
57impl<Id, C, M> EventStore for Store<Id, C, M>
58where
59    Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
60    C: Codec + Clone + Send + Sync + 'static,
61    M: Clone + Send + Sync + 'static,
62{
63    type Codec = C;
64    // Global sequence for chronological ordering
65    type Error = InMemoryError;
66    type Id = Id;
67    type Metadata = M;
68    type Position = u64;
69
70    fn codec(&self) -> &Self::Codec {
71        &self.codec
72    }
73
74    #[tracing::instrument(skip(self, aggregate_id))]
75    fn stream_version<'a>(
76        &'a self,
77        aggregate_kind: &'a str,
78        aggregate_id: &'a Self::Id,
79    ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + 'a {
80        let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
81        let version = self
82            .streams
83            .get(&stream_key)
84            .and_then(|s| s.last().map(|e| e.position));
85        tracing::trace!(?version, "retrieved stream version");
86        std::future::ready(Ok(version))
87    }
88
89    fn begin<Conc: ConcurrencyStrategy>(
90        &mut self,
91        aggregate_kind: &str,
92        aggregate_id: Self::Id,
93        expected_version: Option<Self::Position>,
94    ) -> Transaction<'_, Self, Conc> {
95        Transaction::new(
96            self,
97            aggregate_kind.to_string(),
98            aggregate_id,
99            expected_version,
100        )
101    }
102
103    #[tracing::instrument(skip(self, aggregate_id, events), fields(event_count = events.len()))]
104    fn append<'a>(
105        &'a mut self,
106        aggregate_kind: &'a str,
107        aggregate_id: &'a Self::Id,
108        expected_version: Option<u64>,
109        events: Vec<PersistableEvent<Self::Metadata>>,
110    ) -> impl Future<Output = Result<(), AppendError<u64, Self::Error>>> + Send + 'a {
111        let event_count = events.len();
112
113        let result = (|| {
114            // Check version if provided
115            if let Some(expected) = expected_version {
116                let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
117                let current = self
118                    .streams
119                    .get(&stream_key)
120                    .and_then(|s| s.last().map(|e| e.position));
121                if current != Some(expected) {
122                    tracing::debug!(?expected, ?current, "version mismatch, rejecting append");
123                    return Err(ConcurrencyConflict {
124                        expected: Some(expected),
125                        actual: current,
126                    }
127                    .into());
128                }
129            }
130
131            let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
132            let stored: Vec<StoredEvent<Id, u64, M>> = events
133                .into_iter()
134                .map(|e| {
135                    let position = self.next_position;
136                    self.next_position += 1;
137                    StoredEvent {
138                        aggregate_kind: aggregate_kind.to_string(),
139                        aggregate_id: aggregate_id.clone(),
140                        kind: e.kind,
141                        position,
142                        data: e.data,
143                        metadata: e.metadata,
144                    }
145                })
146                .collect();
147
148            self.streams.entry(stream_key).or_default().extend(stored);
149            tracing::debug!(events_appended = event_count, "events appended to stream");
150            Ok(())
151        })();
152
153        std::future::ready(result)
154    }
155
156    #[tracing::instrument(skip(self, aggregate_id, events), fields(event_count = events.len()))]
157    fn append_expecting_new<'a>(
158        &'a mut self,
159        aggregate_kind: &'a str,
160        aggregate_id: &'a Self::Id,
161        events: Vec<PersistableEvent<Self::Metadata>>,
162    ) -> impl Future<Output = Result<(), AppendError<u64, Self::Error>>> + Send + 'a {
163        let event_count = events.len();
164
165        let result = (|| {
166            // Check that stream is empty (new aggregate)
167            let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
168            let current = self
169                .streams
170                .get(&stream_key)
171                .and_then(|s| s.last().map(|e| e.position));
172
173            if let Some(actual) = current {
174                // Stream already has events - conflict!
175                tracing::debug!(
176                    ?actual,
177                    "stream already exists, rejecting new aggregate append"
178                );
179                return Err(ConcurrencyConflict {
180                    expected: None, // "expected new stream"
181                    actual: Some(actual),
182                }
183                .into());
184            }
185
186            // Stream is empty, proceed with append (no further version check needed)
187            let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
188            let stored: Vec<StoredEvent<Id, u64, M>> = events
189                .into_iter()
190                .map(|e| {
191                    let position = self.next_position;
192                    self.next_position += 1;
193                    StoredEvent {
194                        aggregate_kind: aggregate_kind.to_string(),
195                        aggregate_id: aggregate_id.clone(),
196                        kind: e.kind,
197                        position,
198                        data: e.data,
199                        metadata: e.metadata,
200                    }
201                })
202                .collect();
203
204            self.streams.entry(stream_key).or_default().extend(stored);
205            tracing::debug!(
206                events_appended = event_count,
207                "new stream created with events"
208            );
209            Ok(())
210        })();
211
212        std::future::ready(result)
213    }
214
215    #[tracing::instrument(skip(self, filters), fields(filter_count = filters.len()))]
216    fn load_events<'a>(
217        &'a self,
218        filters: &'a [EventFilter<Self::Id, Self::Position>],
219    ) -> impl Future<Output = Result<Vec<StoredEvent<Id, u64, M>>, Self::Error>> + Send + 'a {
220        use std::collections::HashSet;
221
222        let mut result = Vec::new();
223        let mut seen: HashSet<(StreamKey<Id>, String)> = HashSet::new(); // (stream key, event kind)
224
225        // Group filters by aggregate ID, tracking each filter's individual position
226        // constraint Maps event_kind -> after_position for that specific filter
227        let mut all_kinds: HashMap<String, Option<u64>> = HashMap::new(); // Filters with no aggregate restriction
228        let mut by_aggregate: HashMap<StreamKey<Id>, HashMap<String, Option<u64>>> = HashMap::new(); // Filters targeting a specific aggregate
229
230        for filter in filters {
231            if let (Some(kind), Some(id)) = (&filter.aggregate_kind, &filter.aggregate_id) {
232                by_aggregate
233                    .entry(StreamKey::new(kind.clone(), id.clone()))
234                    .or_default()
235                    .insert(filter.event_kind.clone(), filter.after_position);
236            } else {
237                all_kinds.insert(filter.event_kind.clone(), filter.after_position);
238            }
239        }
240
241        // Helper to check position filter for a specific after_position constraint
242        let passes_position_filter =
243            |event: &StoredEvent<Id, u64, M>, after_position: Option<u64>| -> bool {
244                after_position.is_none_or(|after| event.position > after)
245            };
246
247        // Load events for specific aggregates
248        for (stream_key, kinds) in &by_aggregate {
249            if let Some(stream) = self.streams.get(stream_key) {
250                for event in stream {
251                    // Check if this event kind is requested AND passes its specific position filter
252                    if let Some(&after_pos) = kinds.get(&event.kind)
253                        && passes_position_filter(event, after_pos)
254                    {
255                        // Track that we've seen this (aggregate_kind, aggregate_id, kind) triple
256                        seen.insert((
257                            StreamKey::new(
258                                event.aggregate_kind.clone(),
259                                event.aggregate_id.clone(),
260                            ),
261                            event.kind.clone(),
262                        ));
263                        result.push(event.clone());
264                    }
265                }
266            }
267        }
268
269        // Load events from all aggregates for unfiltered kinds
270        // Skip events we've already loaded for specific aggregates
271        if !all_kinds.is_empty() {
272            for stream in self.streams.values() {
273                for event in stream {
274                    // Check if this event kind is requested AND passes its specific position filter
275                    if let Some(&after_pos) = all_kinds.get(&event.kind)
276                        && passes_position_filter(event, after_pos)
277                    {
278                        let key = (
279                            StreamKey::new(
280                                event.aggregate_kind.clone(),
281                                event.aggregate_id.clone(),
282                            ),
283                            event.kind.clone(),
284                        );
285                        if !seen.contains(&key) {
286                            result.push(event.clone());
287                        }
288                    }
289                }
290            }
291        }
292
293        // Sort by position for chronological ordering across streams
294        result.sort_by_key(|event| event.position);
295
296        tracing::debug!(events_loaded = result.len(), "loaded events from store");
297        std::future::ready(Ok(result))
298    }
299}