1use std::{
15 collections::HashMap,
16 pin::Pin,
17 sync::{Arc, RwLock},
18};
19
20use nonempty::NonEmpty;
21use tokio::sync::broadcast;
22
23use crate::{
24 concurrency::ConcurrencyConflict,
25 store::{
26 CommitError, Committed, EventFilter, EventStore, GloballyOrderedStore, LoadEventsResult,
27 OptimisticCommitError, StoredEvent, StreamKey,
28 },
29 subscription::SubscribableStore,
30};
31
/// Ordered list of stored events for a single stream. Positions are global
/// (allocated from `Inner::next_position`) and payloads are kept as JSON.
type InMemoryStream<Id, M> = Vec<StoredEvent<Id, u64, serde_json::Value, M>>;
34
/// In-memory event store. Cloning is cheap: all clones share the same state
/// through the inner `Arc`.
#[derive(Clone)]
pub struct Store<Id, M> {
    // Shared mutable state; every clone of this `Store` points at the same data.
    inner: Arc<RwLock<Inner<Id, M>>>,
}
50
/// State shared by all clones of a [`Store`].
struct Inner<Id, M> {
    /// Event streams keyed by `(aggregate_kind, aggregate_id)`.
    streams: HashMap<StreamKey<Id>, InMemoryStream<Id, M>>,
    /// Next global position to hand out; strictly increasing across ALL
    /// streams, which is what makes positions globally ordered.
    next_position: u64,
    /// Wakes subscribers after a commit; the payload is the last position
    /// appended by that commit.
    notify_tx: broadcast::Sender<u64>,
}
58
59impl<Id, M> Store<Id, M> {
60 #[must_use]
61 pub fn new() -> Self {
62 let (notify_tx, _) = broadcast::channel(1024);
63 Self {
64 inner: Arc::new(RwLock::new(Inner {
65 streams: HashMap::new(),
66 next_position: 0,
67 notify_tx,
68 })),
69 }
70 }
71}
72
73impl<Id, M> Default for Store<Id, M> {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
/// Errors produced by the in-memory [`Store`].
#[derive(Debug, thiserror::Error)]
pub enum InMemoryError {
    /// An event could not be serialized to JSON while committing.
    #[error("serialization error: {0}")]
    Serialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// A stored JSON payload could not be deserialized into a domain event.
    #[error("deserialization error: {0}")]
    Deserialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
87
impl<Id, M> EventStore for Store<Id, M>
where
    Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
    M: Clone + Send + Sync + 'static,
{
    type Data = serde_json::Value;
    type Error = InMemoryError;
    type Id = Id;
    type Metadata = M;
    type Position = u64;

    /// Deserializes the JSON payload of a stored event into a domain event.
    ///
    /// # Errors
    ///
    /// Returns [`InMemoryError::Deserialization`] if the payload does not
    /// match `E`'s expected shape.
    fn decode_event<E>(
        &self,
        stored: &StoredEvent<Self::Id, Self::Position, Self::Data, Self::Metadata>,
    ) -> Result<E, Self::Error>
    where
        E: crate::event::DomainEvent + serde::de::DeserializeOwned,
    {
        serde_json::from_value(stored.data.clone())
            .map_err(|e| InMemoryError::Deserialization(Box::new(e)))
    }

    /// Returns the position of the last event in the stream, or `None` when
    /// the stream does not exist yet.
    #[tracing::instrument(skip(self, aggregate_id))]
    fn stream_version<'a>(
        &'a self,
        aggregate_kind: &'a str,
        aggregate_id: &'a Self::Id,
    ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + 'a {
        let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
        // Scope the read guard so it is released before the future is built.
        let version = {
            let inner = self.inner.read().expect("in-memory store lock poisoned");
            inner
                .streams
                .get(&stream_key)
                .and_then(|s| s.last().map(|e| e.position))
        };
        tracing::trace!(?version, "retrieved stream version");
        // All work is synchronous; hand back an already-resolved future.
        std::future::ready(Ok(version))
    }

    /// Appends `events` to the stream without any concurrency check.
    ///
    /// All events are serialized to JSON *before* the write lock is taken, so
    /// a serialization failure leaves the store untouched. After the append,
    /// subscribers are notified with the last position written.
    #[tracing::instrument(skip(self, aggregate_id, events, metadata), fields(event_count = events.len()))]
    fn commit_events<'a, E>(
        &'a self,
        aggregate_kind: &'a str,
        aggregate_id: &'a Self::Id,
        events: NonEmpty<E>,
        metadata: &'a Self::Metadata,
    ) -> impl Future<Output = Result<Committed<u64>, CommitError<Self::Error>>> + Send + 'a
    where
        E: crate::event::EventKind + serde::Serialize + Send + Sync + 'a,
        Self::Metadata: Clone,
    {
        let result = (|| {
            // Stage (kind, payload) pairs up front, outside the lock, so a
            // serialization error cannot leave a partially-written stream.
            let mut staged = Vec::with_capacity(events.len());
            for (index, event) in events.iter().enumerate() {
                let data = serde_json::to_value(event).map_err(|e| CommitError::Serialization {
                    index,
                    source: InMemoryError::Serialization(Box::new(e)),
                })?;
                staged.push((event.kind().to_string(), data));
            }

            let mut inner = self.inner.write().expect("in-memory store lock poisoned");
            let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
            // `events` is NonEmpty, so the loop below always overwrites this.
            let mut last_position = 0;
            let mut stored = Vec::with_capacity(staged.len());

            for (kind, data) in staged {
                // Positions are allocated from the single global counter.
                let position = inner.next_position;
                inner.next_position += 1;
                last_position = position;
                stored.push(StoredEvent {
                    aggregate_kind: aggregate_kind.to_string(),
                    aggregate_id: aggregate_id.clone(),
                    kind,
                    position,
                    data,
                    metadata: metadata.clone(),
                });
            }

            inner.streams.entry(stream_key).or_default().extend(stored);
            let notify_tx = inner.notify_tx.clone();
            // Release the lock before notifying so woken subscribers can read.
            drop(inner);
            // A send error only means there are no subscribers right now.
            let _ = notify_tx.send(last_position);
            tracing::debug!(events_appended = events.len(), "events committed to stream");
            Ok(Committed { last_position })
        })();

        std::future::ready(result)
    }

    /// Appends `events` with an optimistic concurrency check.
    ///
    /// `expected_version` semantics:
    /// - `Some(v)`: the stream's current last position must equal `v`;
    /// - `None`: the stream must not exist yet (new-aggregate commit).
    ///
    /// On mismatch a [`ConcurrencyConflict`] is returned and nothing is
    /// written. Serialization happens before the lock is taken, as in
    /// [`Self::commit_events`].
    #[tracing::instrument(skip(self, aggregate_id, events, metadata), fields(event_count = events.len()))]
    fn commit_events_optimistic<'a, E>(
        &'a self,
        aggregate_kind: &'a str,
        aggregate_id: &'a Self::Id,
        expected_version: Option<Self::Position>,
        events: NonEmpty<E>,
        metadata: &'a Self::Metadata,
    ) -> impl Future<Output = Result<Committed<u64>, OptimisticCommitError<u64, Self::Error>>> + Send + 'a
    where
        E: crate::event::EventKind + serde::Serialize + Send + Sync + 'a,
        Self::Metadata: Clone,
    {
        let result = (|| {
            // Stage serialized events outside the lock (see commit_events).
            let mut staged = Vec::with_capacity(events.len());
            for (index, event) in events.iter().enumerate() {
                let data = serde_json::to_value(event).map_err(|e| {
                    OptimisticCommitError::Serialization {
                        index,
                        source: InMemoryError::Serialization(Box::new(e)),
                    }
                })?;
                staged.push((event.kind().to_string(), data));
            }

            let mut inner = self.inner.write().expect("in-memory store lock poisoned");
            let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());

            // Current last position of the stream, `None` if it doesn't exist.
            let current = inner
                .streams
                .get(&stream_key)
                .and_then(|s| s.last().map(|e| e.position));

            match expected_version {
                Some(expected) => {
                    if current != Some(expected) {
                        tracing::debug!(?expected, ?current, "version mismatch, rejecting commit");
                        return Err(ConcurrencyConflict {
                            expected: Some(expected),
                            actual: current,
                        }
                        .into());
                    }
                }
                None => {
                    // `None` means "this must be a brand-new stream".
                    if let Some(actual) = current {
                        tracing::debug!(
                            ?actual,
                            "stream already exists, rejecting new aggregate commit"
                        );
                        return Err(ConcurrencyConflict {
                            expected: None,
                            actual: Some(actual),
                        }
                        .into());
                    }
                }
            }

            // `events` is NonEmpty, so the loop below always overwrites this.
            let mut last_position = 0;
            let mut stored = Vec::with_capacity(staged.len());

            for (kind, data) in staged {
                // Positions come from the single global counter.
                let position = inner.next_position;
                inner.next_position += 1;
                last_position = position;
                stored.push(StoredEvent {
                    aggregate_kind: aggregate_kind.to_string(),
                    aggregate_id: aggregate_id.clone(),
                    kind,
                    position,
                    data,
                    metadata: metadata.clone(),
                });
            }

            inner.streams.entry(stream_key).or_default().extend(stored);
            let notify_tx = inner.notify_tx.clone();
            // Release the lock before notifying so woken subscribers can read.
            drop(inner);
            // A send error only means there are no subscribers right now.
            let _ = notify_tx.send(last_position);
            tracing::debug!(
                events_appended = events.len(),
                "events committed to stream (optimistic)"
            );
            Ok(Committed { last_position })
        })();

        std::future::ready(result)
    }

    /// Loads every event matching any of `filters`, ordered by global
    /// position (see [`load_matching_events`] for filter semantics).
    #[tracing::instrument(skip(self, filters), fields(filter_count = filters.len()))]
    fn load_events<'a>(
        &'a self,
        filters: &'a [EventFilter<Self::Id, Self::Position>],
    ) -> impl Future<
        Output = LoadEventsResult<
            Self::Id,
            Self::Position,
            Self::Data,
            Self::Metadata,
            Self::Error,
        >,
    > + Send
    + 'a {
        // The read guard is a temporary dropped at the end of this statement,
        // before the future is constructed.
        let result = load_matching_events(
            &self.inner.read().expect("in-memory store lock poisoned"),
            filters,
        );

        tracing::debug!(events_loaded = result.len(), "loaded events from store");
        std::future::ready(Ok(result))
    }
}
300
/// Marker impl: positions are allocated from a single counter shared by all
/// streams (`Inner::next_position`), so they form one total order.
impl<Id, M> GloballyOrderedStore for Store<Id, M>
where
    Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
    M: Clone + Send + Sync + 'static,
{
}
307
impl<Id, M> SubscribableStore for Store<Id, M>
where
    Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
    M: Clone + Send + Sync + 'static,
{
    /// Streams matching events: first everything already in the store (above
    /// `from_position`, if given), then live events as commits happen.
    ///
    /// The broadcast receiver is subscribed *before* the historical snapshot
    /// is taken, so a commit landing in between is not lost — it shows up in
    /// the snapshot and its notification is deduplicated by the
    /// `last_position` watermark.
    fn subscribe(
        &self,
        filters: &[EventFilter<Self::Id, Self::Position>],
        from_position: Option<Self::Position>,
    ) -> Pin<
        Box<
            dyn futures_core::Stream<
                    Item = Result<
                        StoredEvent<Self::Id, Self::Position, Self::Data, Self::Metadata>,
                        Self::Error,
                    >,
                > + Send
                + '_,
        >,
    >
    where
        Self::Position: Ord,
    {
        let filters = filters.to_vec();
        let inner = self.inner.clone();

        // Subscribe BEFORE reading history so no commit can slip between the
        // snapshot and the first `recv`.
        let mut rx = {
            let guard = inner.read().expect("in-memory store lock poisoned");
            guard.notify_tx.subscribe()
        };

        Box::pin(async_stream::stream! {
            let historical = {
                let guard = inner.read().expect("in-memory store lock poisoned");
                load_matching_events(&guard, &filters)
            };

            // Watermark: highest position yielded so far; events at or below
            // it are skipped. Starts at the caller-supplied resume point.
            let mut last_position = from_position;

            for event in historical {
                if let Some(ref lp) = last_position
                    && event.position <= *lp
                {
                    continue;
                }
                last_position = Some(event.position);
                yield Ok(event);
            }

            loop {
                match rx.recv().await {
                    Ok(notified_position) => {
                        // The notification carries the last position of the
                        // commit; nothing new if we have already passed it.
                        if let Some(ref lp) = last_position
                            && notified_position <= *lp
                        {
                            continue;
                        }

                        // Rescan the store and emit anything above the
                        // watermark. NOTE: this reloads all matching events
                        // on every notification.
                        let events = {
                            let guard = inner.read().expect("in-memory store lock poisoned");
                            load_matching_events(&guard, &filters)
                        };

                        for event in events {
                            if let Some(ref lp) = last_position
                                && event.position <= *lp
                            {
                                continue;
                            }
                            last_position = Some(event.position);
                            yield Ok(event);
                        }
                    }
                    Err(broadcast::error::RecvError::Lagged(_)) => {
                        // We missed some notifications; the watermark makes a
                        // full rescan safe and lossless.
                        let events = {
                            let guard = inner.read().expect("in-memory store lock poisoned");
                            load_matching_events(&guard, &filters)
                        };

                        for event in events {
                            if let Some(ref lp) = last_position
                                && event.position <= *lp
                            {
                                continue;
                            }
                            last_position = Some(event.position);
                            yield Ok(event);
                        }
                    }
                    Err(broadcast::error::RecvError::Closed) => {
                        // Channel closed: no further notifications can
                        // arrive, so end the stream.
                        break;
                    }
                }
            }
        })
    }
}
415
416fn load_matching_events<Id, M>(
418 inner: &Inner<Id, M>,
419 filters: &[EventFilter<Id, u64>],
420) -> Vec<StoredEvent<Id, u64, serde_json::Value, M>>
421where
422 Id: Clone + Eq + std::hash::Hash,
423 M: Clone,
424{
425 use std::collections::HashSet;
426
427 let mut result = Vec::new();
428 let mut seen: HashSet<(StreamKey<Id>, String)> = HashSet::new();
429
430 let mut all_kinds: HashMap<String, Option<u64>> = HashMap::new();
431 let mut by_aggregate: HashMap<StreamKey<Id>, HashMap<String, Option<u64>>> = HashMap::new();
432
433 for filter in filters {
434 if let (Some(kind), Some(id)) = (&filter.aggregate_kind, &filter.aggregate_id) {
435 by_aggregate
436 .entry(StreamKey::new(kind.clone(), id.clone()))
437 .or_default()
438 .insert(filter.event_kind.clone(), filter.after_position);
439 } else {
440 all_kinds.insert(filter.event_kind.clone(), filter.after_position);
441 }
442 }
443
444 let passes_position_filter =
445 |event: &StoredEvent<Id, u64, serde_json::Value, M>, after_position: Option<u64>| -> bool {
446 after_position.is_none_or(|after| event.position > after)
447 };
448
449 for (stream_key, kinds) in &by_aggregate {
450 if let Some(stream) = inner.streams.get(stream_key) {
451 for event in stream {
452 if let Some(&after_pos) = kinds.get(&event.kind)
453 && passes_position_filter(event, after_pos)
454 {
455 seen.insert((
456 StreamKey::new(event.aggregate_kind.clone(), event.aggregate_id.clone()),
457 event.kind.clone(),
458 ));
459 result.push(event.clone());
460 }
461 }
462 }
463 }
464
465 if !all_kinds.is_empty() {
466 for stream in inner.streams.values() {
467 for event in stream {
468 if let Some(&after_pos) = all_kinds.get(&event.kind)
469 && passes_position_filter(event, after_pos)
470 {
471 let key = (
472 StreamKey::new(event.aggregate_kind.clone(), event.aggregate_id.clone()),
473 event.kind.clone(),
474 );
475 if !seen.contains(&key) {
476 result.push(event.clone());
477 }
478 }
479 }
480 }
481 }
482
483 result.sort_by_key(|event| event.position);
484 result
485}
486
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::event::DomainEvent;

    /// Minimal serializable domain event used throughout these tests.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestEvent {
        value: i32,
    }

    impl DomainEvent for TestEvent {
        const KIND: &'static str = "test-event";
    }

    /// A freshly-created store has no streams and starts the global position
    /// counter at zero.
    #[test]
    fn new_has_no_streams() {
        let store = Store::<String, ()>::new();
        let inner = store.inner.read().unwrap();
        assert!(inner.streams.is_empty());
        assert_eq!(inner.next_position, 0);
        drop(inner);
    }

    /// `decode_event` round-trips a JSON payload back into the domain event.
    #[test]
    fn decode_event_deserializes() {
        let store = Store::<String, ()>::new();
        let event = TestEvent { value: 42 };
        let data = serde_json::to_value(&event).unwrap();

        let stored = StoredEvent {
            aggregate_kind: "test-agg".to_string(),
            aggregate_id: "id".to_string(),
            kind: "test-event".to_string(),
            position: 0,
            data,
            metadata: (),
        };

        let decoded: TestEvent = store.decode_event(&stored).unwrap();
        assert_eq!(decoded, event);
    }

    /// The Serialization variant's Display output includes its prefix.
    #[test]
    fn error_display_serialization() {
        let err = InMemoryError::Serialization(Box::new(std::io::Error::other("test")));
        assert!(err.to_string().contains("serialization error"));
    }

    /// The Deserialization variant's Display output includes its prefix.
    #[test]
    fn error_display_deserialization() {
        let err = InMemoryError::Deserialization(Box::new(std::io::Error::other("test")));
        assert!(err.to_string().contains("deserialization error"));
    }

    /// A stream that has never been committed to reports no version.
    #[tokio::test]
    async fn version_returns_none_for_new_stream() {
        let store = Store::<String, ()>::new();
        let version = store
            .stream_version("test-agg", &"id".to_string())
            .await
            .unwrap();
        assert!(version.is_none());
    }

    /// After one commit the stream version equals the first global position.
    #[tokio::test]
    async fn version_returns_position_after_commit() {
        let store = Store::<String, ()>::new();
        let id = "id".to_string();
        let events = NonEmpty::singleton(TestEvent { value: 1 });

        store
            .commit_events("test-agg", &id, events, &())
            .await
            .unwrap();

        let version = store.stream_version("test-agg", &id).await.unwrap();
        assert_eq!(version, Some(0));
    }

    /// An optimistic commit with a stale expected version is rejected with a
    /// conflict.
    #[tokio::test]
    async fn commit_with_wrong_version_returns_conflict() {
        let store = Store::<String, ()>::new();
        let id = "id".to_string();
        let events1 = NonEmpty::singleton(TestEvent { value: 1 });

        store
            .commit_events("test-agg", &id, events1, &())
            .await
            .unwrap();

        let events2 = NonEmpty::singleton(TestEvent { value: 2 });
        let result = store
            .commit_events_optimistic("test-agg", &id, Some(99), events2, &())
            .await;

        assert!(matches!(result, Err(OptimisticCommitError::Conflict(_))));
    }

    /// `expected_version = None` means "new stream"; it conflicts when the
    /// stream already exists.
    #[tokio::test]
    async fn commit_new_stream_fails_if_stream_exists() {
        let store = Store::<String, ()>::new();
        let id = "id".to_string();
        let events = NonEmpty::singleton(TestEvent { value: 1 });

        store
            .commit_events("test-agg", &id, events, &())
            .await
            .unwrap();

        let events2 = NonEmpty::singleton(TestEvent { value: 2 });
        let result = store
            .commit_events_optimistic("test-agg", &id, None, events2, &())
            .await;

        assert!(matches!(result, Err(OptimisticCommitError::Conflict(_))));
    }
}