1use std::{collections::HashMap, convert::Infallible};
2
3use crate::{
4 codec::Codec,
5 concurrency::{ConcurrencyConflict, ConcurrencyStrategy},
6 store::{
7 AppendError, EventFilter, EventStore, PersistableEvent, StoredEvent, StreamKey, Transaction,
8 },
9};
10
11pub struct Store<Id, C, M>
22where
23 C: Codec,
24{
25 codec: C,
26 streams: HashMap<StreamKey<Id>, Vec<StoredEvent<Id, u64, M>>>,
27 next_position: u64,
28}
29
30impl<Id, C, M> Store<Id, C, M>
31where
32 C: Codec,
33{
34 #[must_use]
35 pub fn new(codec: C) -> Self {
36 Self {
37 codec,
38 streams: HashMap::new(),
39 next_position: 0,
40 }
41 }
42}
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
48#[error("infallible")]
49pub enum InMemoryError {}
50
51impl From<Infallible> for InMemoryError {
52 fn from(x: Infallible) -> Self {
53 match x {}
54 }
55}
56
57impl<Id, C, M> EventStore for Store<Id, C, M>
58where
59 Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
60 C: Codec + Clone + Send + Sync + 'static,
61 M: Clone + Send + Sync + 'static,
62{
63 type Codec = C;
64 type Error = InMemoryError;
66 type Id = Id;
67 type Metadata = M;
68 type Position = u64;
69
70 fn codec(&self) -> &Self::Codec {
71 &self.codec
72 }
73
74 #[tracing::instrument(skip(self, aggregate_id))]
75 fn stream_version<'a>(
76 &'a self,
77 aggregate_kind: &'a str,
78 aggregate_id: &'a Self::Id,
79 ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + 'a {
80 let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
81 let version = self
82 .streams
83 .get(&stream_key)
84 .and_then(|s| s.last().map(|e| e.position));
85 tracing::trace!(?version, "retrieved stream version");
86 std::future::ready(Ok(version))
87 }
88
89 fn begin<Conc: ConcurrencyStrategy>(
90 &mut self,
91 aggregate_kind: &str,
92 aggregate_id: Self::Id,
93 expected_version: Option<Self::Position>,
94 ) -> Transaction<'_, Self, Conc> {
95 Transaction::new(
96 self,
97 aggregate_kind.to_string(),
98 aggregate_id,
99 expected_version,
100 )
101 }
102
103 #[tracing::instrument(skip(self, aggregate_id, events), fields(event_count = events.len()))]
104 fn append<'a>(
105 &'a mut self,
106 aggregate_kind: &'a str,
107 aggregate_id: &'a Self::Id,
108 expected_version: Option<u64>,
109 events: Vec<PersistableEvent<Self::Metadata>>,
110 ) -> impl Future<Output = Result<(), AppendError<u64, Self::Error>>> + Send + 'a {
111 let event_count = events.len();
112
113 let result = (|| {
114 if let Some(expected) = expected_version {
116 let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
117 let current = self
118 .streams
119 .get(&stream_key)
120 .and_then(|s| s.last().map(|e| e.position));
121 if current != Some(expected) {
122 tracing::debug!(?expected, ?current, "version mismatch, rejecting append");
123 return Err(ConcurrencyConflict {
124 expected: Some(expected),
125 actual: current,
126 }
127 .into());
128 }
129 }
130
131 let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
132 let stored: Vec<StoredEvent<Id, u64, M>> = events
133 .into_iter()
134 .map(|e| {
135 let position = self.next_position;
136 self.next_position += 1;
137 StoredEvent {
138 aggregate_kind: aggregate_kind.to_string(),
139 aggregate_id: aggregate_id.clone(),
140 kind: e.kind,
141 position,
142 data: e.data,
143 metadata: e.metadata,
144 }
145 })
146 .collect();
147
148 self.streams.entry(stream_key).or_default().extend(stored);
149 tracing::debug!(events_appended = event_count, "events appended to stream");
150 Ok(())
151 })();
152
153 std::future::ready(result)
154 }
155
156 #[tracing::instrument(skip(self, aggregate_id, events), fields(event_count = events.len()))]
157 fn append_expecting_new<'a>(
158 &'a mut self,
159 aggregate_kind: &'a str,
160 aggregate_id: &'a Self::Id,
161 events: Vec<PersistableEvent<Self::Metadata>>,
162 ) -> impl Future<Output = Result<(), AppendError<u64, Self::Error>>> + Send + 'a {
163 let event_count = events.len();
164
165 let result = (|| {
166 let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
168 let current = self
169 .streams
170 .get(&stream_key)
171 .and_then(|s| s.last().map(|e| e.position));
172
173 if let Some(actual) = current {
174 tracing::debug!(
176 ?actual,
177 "stream already exists, rejecting new aggregate append"
178 );
179 return Err(ConcurrencyConflict {
180 expected: None, actual: Some(actual),
182 }
183 .into());
184 }
185
186 let stream_key = StreamKey::new(aggregate_kind, aggregate_id.clone());
188 let stored: Vec<StoredEvent<Id, u64, M>> = events
189 .into_iter()
190 .map(|e| {
191 let position = self.next_position;
192 self.next_position += 1;
193 StoredEvent {
194 aggregate_kind: aggregate_kind.to_string(),
195 aggregate_id: aggregate_id.clone(),
196 kind: e.kind,
197 position,
198 data: e.data,
199 metadata: e.metadata,
200 }
201 })
202 .collect();
203
204 self.streams.entry(stream_key).or_default().extend(stored);
205 tracing::debug!(
206 events_appended = event_count,
207 "new stream created with events"
208 );
209 Ok(())
210 })();
211
212 std::future::ready(result)
213 }
214
215 #[tracing::instrument(skip(self, filters), fields(filter_count = filters.len()))]
216 fn load_events<'a>(
217 &'a self,
218 filters: &'a [EventFilter<Self::Id, Self::Position>],
219 ) -> impl Future<Output = Result<Vec<StoredEvent<Id, u64, M>>, Self::Error>> + Send + 'a {
220 use std::collections::HashSet;
221
222 let mut result = Vec::new();
223 let mut seen: HashSet<(StreamKey<Id>, String)> = HashSet::new(); let mut all_kinds: HashMap<String, Option<u64>> = HashMap::new(); let mut by_aggregate: HashMap<StreamKey<Id>, HashMap<String, Option<u64>>> = HashMap::new(); for filter in filters {
231 if let (Some(kind), Some(id)) = (&filter.aggregate_kind, &filter.aggregate_id) {
232 by_aggregate
233 .entry(StreamKey::new(kind.clone(), id.clone()))
234 .or_default()
235 .insert(filter.event_kind.clone(), filter.after_position);
236 } else {
237 all_kinds.insert(filter.event_kind.clone(), filter.after_position);
238 }
239 }
240
241 let passes_position_filter =
243 |event: &StoredEvent<Id, u64, M>, after_position: Option<u64>| -> bool {
244 after_position.is_none_or(|after| event.position > after)
245 };
246
247 for (stream_key, kinds) in &by_aggregate {
249 if let Some(stream) = self.streams.get(stream_key) {
250 for event in stream {
251 if let Some(&after_pos) = kinds.get(&event.kind)
253 && passes_position_filter(event, after_pos)
254 {
255 seen.insert((
257 StreamKey::new(
258 event.aggregate_kind.clone(),
259 event.aggregate_id.clone(),
260 ),
261 event.kind.clone(),
262 ));
263 result.push(event.clone());
264 }
265 }
266 }
267 }
268
269 if !all_kinds.is_empty() {
272 for stream in self.streams.values() {
273 for event in stream {
274 if let Some(&after_pos) = all_kinds.get(&event.kind)
276 && passes_position_filter(event, after_pos)
277 {
278 let key = (
279 StreamKey::new(
280 event.aggregate_kind.clone(),
281 event.aggregate_id.clone(),
282 ),
283 event.kind.clone(),
284 );
285 if !seen.contains(&key) {
286 result.push(event.clone());
287 }
288 }
289 }
290 }
291 }
292
293 result.sort_by_key(|event| event.position);
295
296 tracing::debug!(events_loaded = result.len(), "loaded events from store");
297 std::future::ready(Ok(result))
298 }
299}