1pub mod snapshot;
13
14use std::marker::PhantomData;
15
16use nonempty::NonEmpty;
17use serde::{Serialize, de::DeserializeOwned};
18use sourcery_core::{
19 concurrency::ConcurrencyConflict,
20 event::DomainEvent,
21 store::{
22 CommitError, Committed, EventFilter, EventStore, GloballyOrderedStore, LoadEventsResult,
23 OptimisticCommitError, StoredEvent,
24 },
25};
26use sqlx::{PgPool, Postgres, QueryBuilder, Row};
27
/// Errors produced by the Postgres-backed event store.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Any failure reported by the underlying `sqlx` driver or Postgres itself.
    #[error("database error: {0}")]
    Database(#[from] sqlx::Error),
    /// A position value that cannot exist in the store.
    // NOTE(review): the display text says "from database", but the only
    // construction site in this file is `load_events` rejecting a negative,
    // caller-supplied `after_position` filter. Possibly also built from DB
    // values in the `snapshot` module — confirm before rewording.
    #[error("invalid position value from database: {0}")]
    InvalidPosition(i64),
    /// An `INSERT ... RETURNING position` produced no rows; commits always
    /// insert at least one event, so this indicates a database-side problem.
    #[error("database did not return an inserted position")]
    MissingReturnedPosition,
    /// Failed to serialize an event payload to JSON before committing.
    #[error("serialization error: {0}")]
    Serialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// Failed to deserialize a stored JSON payload into a domain event.
    #[error("deserialization error: {0}")]
    Deserialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
41
/// Postgres-backed event store, generic over the metadata type `M` attached
/// to every committed event.
///
/// `Clone` is cheap: `sqlx::PgPool` clones by reference, so clones share the
/// same underlying connection pool.
#[derive(Clone)]
pub struct Store<M> {
    // Connection pool used for every query issued by this store.
    pool: PgPool,
    // Ties the metadata type parameter to the struct without storing a value.
    _phantom: PhantomData<M>,
}
53
impl<M> Store<M> {
    /// Creates a store that issues all queries through the given pool.
    ///
    /// Performs no I/O; call [`Store::migrate`] before first use to ensure
    /// the backing tables and indexes exist.
    #[must_use]
    pub const fn new(pool: PgPool) -> Self {
        Self {
            pool,
            _phantom: PhantomData,
        }
    }
}
63
64impl<M> Store<M>
65where
66 M: Sync,
67{
68 #[tracing::instrument(skip(self))]
77 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
78 sqlx::query(
80 r"
81 CREATE TABLE IF NOT EXISTS es_streams (
82 aggregate_kind TEXT NOT NULL,
83 aggregate_id UUID NOT NULL,
84 last_position BIGINT NULL,
85 PRIMARY KEY (aggregate_kind, aggregate_id)
86 )
87 ",
88 )
89 .execute(&self.pool)
90 .await?;
91
92 sqlx::query(
93 r"
94 CREATE TABLE IF NOT EXISTS es_events (
95 position BIGSERIAL PRIMARY KEY,
96 aggregate_kind TEXT NOT NULL,
97 aggregate_id UUID NOT NULL,
98 event_kind TEXT NOT NULL,
99 data JSONB NOT NULL,
100 metadata JSONB NOT NULL,
101 created_at TIMESTAMPTZ NOT NULL DEFAULT now()
102 )
103 ",
104 )
105 .execute(&self.pool)
106 .await?;
107
108 sqlx::query(
109 r"CREATE INDEX IF NOT EXISTS es_events_by_kind_and_position ON es_events(event_kind, position)",
110 )
111 .execute(&self.pool)
112 .await?;
113
114 sqlx::query(
115 r"CREATE INDEX IF NOT EXISTS es_events_by_stream_and_position ON es_events(aggregate_kind, aggregate_id, position)",
116 )
117 .execute(&self.pool)
118 .await?;
119
120 Ok(())
121 }
122}
123
// JSONB-backed implementation of the core `EventStore` contract: event
// payloads and metadata are stored as JSON in `es_events`, while
// `es_streams.last_position` tracks each stream's head for version checks.
impl<M> EventStore for Store<M>
where
    M: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
    // Payloads stay as raw JSON until `decode_event` turns them into a
    // concrete domain event type.
    type Data = serde_json::Value;
    type Error = Error;
    type Id = uuid::Uuid;
    type Metadata = M;
    // Store-wide position assigned by the `BIGSERIAL` column of `es_events`.
    type Position = i64;

    /// Deserializes a stored event's JSON payload into the domain event `E`.
    ///
    /// Returns [`Error::Deserialization`] when the stored payload does not
    /// match the shape `E` expects.
    fn decode_event<E>(
        &self,
        stored: &StoredEvent<Self::Id, Self::Position, Self::Data, Self::Metadata>,
    ) -> Result<E, Self::Error>
    where
        E: DomainEvent + serde::de::DeserializeOwned,
    {
        serde_json::from_value(stored.data.clone()).map_err(|e| Error::Deserialization(Box::new(e)))
    }

    /// Returns the last committed position of a stream, or `None` when the
    /// stream row is absent or its head is SQL NULL (row created, nothing
    /// committed yet). The two "no version" cases are deliberately collapsed
    /// by the `flatten`.
    async fn stream_version<'a>(
        &'a self,
        aggregate_kind: &'a str,
        aggregate_id: &'a Self::Id,
    ) -> Result<Option<Self::Position>, Self::Error> {
        // `fetch_optional` of a nullable column yields Option<Option<i64>>.
        let result: Option<i64> = sqlx::query_scalar(
            r"SELECT last_position FROM es_streams WHERE aggregate_kind = $1 AND aggregate_id = $2",
        )
        .bind(aggregate_kind)
        .bind(aggregate_id)
        .fetch_optional(&self.pool)
        .await?
        .flatten();

        Ok(result)
    }

    /// Appends `events` to the stream with no concurrency check.
    ///
    /// Events are serialized up front, so a serialization failure aborts
    /// before a transaction is opened. The database work — ensuring the
    /// stream row exists, batch-inserting the events, and advancing the
    /// stream head — happens inside a single transaction.
    ///
    /// Returns the position of the last inserted event.
    #[tracing::instrument(
        skip(self, events, metadata),
        fields(
            aggregate_kind,
            aggregate_id = %aggregate_id,
            events_len = events.len()
        )
    )]
    async fn commit_events<'a, E>(
        &'a self,
        aggregate_kind: &'a str,
        aggregate_id: &'a Self::Id,
        events: NonEmpty<E>,
        metadata: &'a Self::Metadata,
    ) -> Result<Committed<i64>, CommitError<Self::Error>>
    where
        E: sourcery_core::event::EventKind + serde::Serialize + Send + Sync + 'a,
        Self::Metadata: Clone,
    {
        // Serialize every event first; the failing event's index is reported
        // so the caller knows exactly which one was rejected.
        let mut prepared: Vec<(String, serde_json::Value)> = Vec::with_capacity(events.len());
        for (index, event) in events.iter().enumerate() {
            let data = serde_json::to_value(event).map_err(|e| CommitError::Serialization {
                index,
                source: Error::Serialization(Box::new(e)),
            })?;
            prepared.push((event.kind().to_string(), data));
        }

        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(|e| CommitError::Store(Error::Database(e)))?;

        // Ensure the stream row exists; a NULL head means "no events yet".
        sqlx::query(
            r"
            INSERT INTO es_streams (aggregate_kind, aggregate_id, last_position)
            VALUES ($1, $2, NULL)
            ON CONFLICT (aggregate_kind, aggregate_id) DO NOTHING
            ",
        )
        .bind(aggregate_kind)
        .bind(aggregate_id)
        .execute(&mut *tx)
        .await
        .map_err(|e| CommitError::Store(Error::Database(e)))?;

        // Batch-insert all events in one statement; RETURNING yields the
        // sequence-assigned positions in insertion order.
        let mut qb = QueryBuilder::<Postgres>::new(
            "INSERT INTO es_events (aggregate_kind, aggregate_id, event_kind, data, metadata) ",
        );
        qb.push_values(prepared, |mut b, (kind, data)| {
            b.push_bind(aggregate_kind);
            b.push_bind(aggregate_id);
            b.push_bind(kind);
            b.push_bind(sqlx::types::Json(data));
            // The same metadata value is attached to every event in the batch.
            b.push_bind(sqlx::types::Json(metadata.clone()));
        });
        qb.push(" RETURNING position");

        let rows: Vec<i64> = qb
            .build_query_scalar()
            .fetch_all(&mut *tx)
            .await
            .map_err(|e| CommitError::Store(Error::Database(e)))?;

        // `events` is NonEmpty, so an empty RETURNING set means the database
        // misbehaved, not the caller.
        let last_position = rows
            .last()
            .ok_or_else(|| CommitError::Store(Error::MissingReturnedPosition))?;

        // NOTE(review): the stream row is not locked on this path, so two
        // concurrent unchecked commits can interleave and the later COMMIT
        // overwrites `last_position` even if it holds the smaller value —
        // confirm this is acceptable for the non-optimistic path.
        sqlx::query(
            r"
            UPDATE es_streams
            SET last_position = $1
            WHERE aggregate_kind = $2 AND aggregate_id = $3
            ",
        )
        .bind(last_position)
        .bind(aggregate_kind)
        .bind(aggregate_id)
        .execute(&mut *tx)
        .await
        .map_err(|e| CommitError::Store(Error::Database(e)))?;

        tx.commit()
            .await
            .map_err(|e| CommitError::Store(Error::Database(e)))?;

        Ok(Committed {
            last_position: *last_position,
        })
    }

    /// Appends `events` only if the stream head matches `expected_version`.
    ///
    /// Semantics of `expected_version`:
    /// - `Some(v)` — the stream's current head must be exactly `v`;
    /// - `None` — the stream must have no committed events (absent row or
    ///   NULL head both qualify, since `fetch_one` runs after the row is
    ///   ensured to exist).
    ///
    /// The head row is read with `FOR UPDATE`, so the version check and the
    /// subsequent insert/head-update are serialized against concurrent
    /// optimistic commits to the same stream.
    #[tracing::instrument(
        skip(self, events, metadata),
        fields(
            aggregate_kind,
            aggregate_id = %aggregate_id,
            expected_version,
            events_len = events.len()
        )
    )]
    async fn commit_events_optimistic<'a, E>(
        &'a self,
        aggregate_kind: &'a str,
        aggregate_id: &'a Self::Id,
        expected_version: Option<Self::Position>,
        events: NonEmpty<E>,
        metadata: &'a Self::Metadata,
    ) -> Result<Committed<i64>, OptimisticCommitError<i64, Self::Error>>
    where
        E: sourcery_core::event::EventKind + serde::Serialize + Send + Sync + 'a,
        Self::Metadata: Clone,
    {
        // Serialize everything before touching the database, reporting the
        // index of the first failing event.
        let mut prepared: Vec<(String, serde_json::Value)> = Vec::with_capacity(events.len());
        for (index, event) in events.iter().enumerate() {
            let data =
                serde_json::to_value(event).map_err(|e| OptimisticCommitError::Serialization {
                    index,
                    source: Error::Serialization(Box::new(e)),
                })?;
            prepared.push((event.kind().to_string(), data));
        }

        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;

        // Ensure the stream row exists so the FOR UPDATE read below always
        // finds (and locks) exactly one row.
        sqlx::query(
            r"
            INSERT INTO es_streams (aggregate_kind, aggregate_id, last_position)
            VALUES ($1, $2, NULL)
            ON CONFLICT (aggregate_kind, aggregate_id) DO NOTHING
            ",
        )
        .bind(aggregate_kind)
        .bind(aggregate_id)
        .execute(&mut *tx)
        .await
        .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;

        // Lock the stream row for the remainder of the transaction; `current`
        // is the head position, NULL (None) when nothing was committed yet.
        let current: Option<i64> = sqlx::query_scalar::<_, Option<i64>>(
            r"
            SELECT last_position
            FROM es_streams
            WHERE aggregate_kind = $1 AND aggregate_id = $2
            FOR UPDATE
            ",
        )
        .bind(aggregate_kind)
        .bind(aggregate_id)
        .fetch_one(&mut *tx)
        .await
        .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;

        // Version check; any mismatch rolls the transaction back on drop.
        match expected_version {
            Some(expected) => {
                if current != Some(expected) {
                    return Err(OptimisticCommitError::Conflict(ConcurrencyConflict {
                        expected: Some(expected),
                        actual: current,
                    }));
                }
            }
            None => {
                if let Some(actual) = current {
                    return Err(OptimisticCommitError::Conflict(ConcurrencyConflict {
                        expected: None,
                        actual: Some(actual),
                    }));
                }
            }
        }

        // Same batch insert as the unchecked path: one multi-row INSERT,
        // RETURNING the sequence-assigned positions in insertion order.
        let mut qb = QueryBuilder::<Postgres>::new(
            "INSERT INTO es_events (aggregate_kind, aggregate_id, event_kind, data, metadata) ",
        );
        qb.push_values(prepared, |mut b, (kind, data)| {
            b.push_bind(aggregate_kind);
            b.push_bind(aggregate_id);
            b.push_bind(kind);
            b.push_bind(sqlx::types::Json(data));
            b.push_bind(sqlx::types::Json(metadata.clone()));
        });
        qb.push(" RETURNING position");

        let rows: Vec<i64> = qb
            .build_query_scalar()
            .fetch_all(&mut *tx)
            .await
            .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;

        // `events` is NonEmpty; an empty RETURNING set is a database fault.
        let last_position = rows
            .last()
            .ok_or_else(|| OptimisticCommitError::Store(Error::MissingReturnedPosition))?;

        // Advance the head under the row lock taken by the FOR UPDATE above.
        sqlx::query(
            r"
            UPDATE es_streams
            SET last_position = $1
            WHERE aggregate_kind = $2 AND aggregate_id = $3
            ",
        )
        .bind(last_position)
        .bind(aggregate_kind)
        .bind(aggregate_id)
        .execute(&mut *tx)
        .await
        .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;

        tx.commit()
            .await
            .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;

        Ok(Committed {
            last_position: *last_position,
        })
    }

    /// Loads all events matching any of `filters`, ordered by global
    /// position ascending.
    ///
    /// Each filter becomes one SELECT arm of a UNION ALL, so an event that
    /// matches more than one filter is returned once per matching filter
    /// (UNION ALL does not deduplicate). An empty `filters` slice returns an
    /// empty vec without querying.
    ///
    /// Returns [`Error::InvalidPosition`] for a negative `after_position`
    /// before issuing any query; positions start at 1 (`BIGSERIAL`).
    #[allow(clippy::type_complexity)]
    #[tracing::instrument(skip(self, filters), fields(filters_len = filters.len()))]
    async fn load_events<'a>(
        &'a self,
        filters: &'a [EventFilter<Self::Id, Self::Position>],
    ) -> LoadEventsResult<Self::Id, Self::Position, Self::Data, Self::Metadata, Self::Error> {
        if filters.is_empty() {
            return Ok(Vec::new());
        }

        // Build: SELECT ... FROM ( arm UNION ALL arm ... ) t ORDER BY position
        let mut qb = QueryBuilder::<Postgres>::new(
            "SELECT aggregate_kind, aggregate_id, event_kind, position, data, metadata FROM (",
        );

        for (i, filter) in filters.iter().enumerate() {
            if i > 0 {
                qb.push(" UNION ALL ");
            }

            // `event_kind` is the only mandatory predicate; the rest of the
            // filter fields narrow the arm when present.
            qb.push(
                "SELECT aggregate_kind, aggregate_id, event_kind, position, data, metadata FROM \
                 es_events WHERE event_kind = ",
            )
            .push_bind(&filter.event_kind);

            if let Some(kind) = &filter.aggregate_kind {
                qb.push(" AND aggregate_kind = ").push_bind(kind);
            }

            if let Some(id) = &filter.aggregate_id {
                qb.push(" AND aggregate_id = ").push_bind(id);
            }

            if let Some(after) = filter.after_position {
                // Reject impossible cursors before hitting the database.
                if after < 0 {
                    return Err(Error::InvalidPosition(after));
                }
                qb.push(" AND position > ").push_bind(after);
            }
        }

        qb.push(") t ORDER BY position ASC");

        let rows = qb.build().fetch_all(&self.pool).await?;

        // Decode rows by column name; JSON columns come back wrapped in
        // `sqlx::types::Json` and are unwrapped via `.0`.
        let mut out = Vec::with_capacity(rows.len());
        for row in rows {
            let aggregate_kind: String = row.try_get("aggregate_kind")?;
            let aggregate_id: uuid::Uuid = row.try_get("aggregate_id")?;
            let event_kind: String = row.try_get("event_kind")?;
            let position: i64 = row.try_get("position")?;
            let data: sqlx::types::Json<serde_json::Value> = row.try_get("data")?;
            let metadata: sqlx::types::Json<M> = row.try_get("metadata")?;

            out.push(StoredEvent {
                aggregate_kind,
                aggregate_id,
                kind: event_kind,
                position,
                data: data.0,
                metadata: metadata.0,
            });
        }

        Ok(out)
    }
}
452
// Marker impl (no members provided here): all events share the single
// `es_events.position` BIGSERIAL sequence, giving the store one global
// ordering across every stream.
impl<M> GloballyOrderedStore for Store<M> where
    M: Serialize + DeserializeOwned + Clone + Send + Sync + 'static
{
}