1use std::marker::PhantomData;
7
8use serde::{Serialize, de::DeserializeOwned};
9use sourcery_core::{
10 codec::Codec,
11 concurrency::{ConcurrencyConflict, ConcurrencyStrategy},
12 store::{
13 AppendError, EventFilter, EventStore, LoadEventsResult, PersistableEvent, StoredEvent,
14 Transaction,
15 },
16};
17use sqlx::{PgPool, Postgres, QueryBuilder, Row};
18
19#[derive(Debug, thiserror::Error)]
20pub enum Error {
21 #[error("database error: {0}")]
22 Database(#[from] sqlx::Error),
23 #[error("invalid position value from database: {0}")]
24 InvalidPosition(i64),
25 #[error("database did not return an inserted position")]
26 MissingReturnedPosition,
27}
28
29#[derive(Clone)]
35pub struct Store<C, M> {
36 pool: PgPool,
37 codec: C,
38 _phantom: PhantomData<M>,
39}
40
41impl<C, M> Store<C, M>
42where
43 C: Sync,
44 M: Sync,
45{
46 #[must_use]
47 pub const fn new(pool: PgPool, codec: C) -> Self {
48 Self {
49 pool,
50 codec,
51 _phantom: PhantomData,
52 }
53 }
54
55 #[tracing::instrument(skip(self))]
64 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
65 sqlx::query(
67 r"
68 CREATE TABLE IF NOT EXISTS es_streams (
69 aggregate_kind TEXT NOT NULL,
70 aggregate_id UUID NOT NULL,
71 last_position BIGINT NULL,
72 PRIMARY KEY (aggregate_kind, aggregate_id)
73 )
74 ",
75 )
76 .execute(&self.pool)
77 .await?;
78
79 sqlx::query(
80 r"
81 CREATE TABLE IF NOT EXISTS es_events (
82 position BIGSERIAL PRIMARY KEY,
83 aggregate_kind TEXT NOT NULL,
84 aggregate_id UUID NOT NULL,
85 event_kind TEXT NOT NULL,
86 data BYTEA NOT NULL,
87 metadata JSONB NOT NULL,
88 created_at TIMESTAMPTZ NOT NULL DEFAULT now()
89 )
90 ",
91 )
92 .execute(&self.pool)
93 .await?;
94
95 sqlx::query(
96 r"CREATE INDEX IF NOT EXISTS es_events_by_kind_and_position ON es_events(event_kind, position)",
97 )
98 .execute(&self.pool)
99 .await?;
100
101 sqlx::query(
102 r"CREATE INDEX IF NOT EXISTS es_events_by_stream_and_position ON es_events(aggregate_kind, aggregate_id, position)",
103 )
104 .execute(&self.pool)
105 .await?;
106
107 Ok(())
108 }
109}
110
111impl<C, M> EventStore for Store<C, M>
112where
113 C: Codec + Clone + Send + Sync + 'static,
114 M: Serialize + DeserializeOwned + Send + Sync + 'static,
115{
116 type Codec = C;
117 type Error = Error;
118 type Id = uuid::Uuid;
119 type Metadata = M;
120 type Position = i64;
121
122 fn codec(&self) -> &Self::Codec {
123 &self.codec
124 }
125
126 async fn stream_version<'a>(
127 &'a self,
128 aggregate_kind: &'a str,
129 aggregate_id: &'a Self::Id,
130 ) -> Result<Option<Self::Position>, Self::Error> {
131 let result: Option<i64> = sqlx::query_scalar(
132 r"SELECT last_position FROM es_streams WHERE aggregate_kind = $1 AND aggregate_id = $2",
133 )
134 .bind(aggregate_kind)
135 .bind(aggregate_id)
136 .fetch_optional(&self.pool)
137 .await?
138 .flatten();
139
140 Ok(result)
141 }
142
143 fn begin<Conc: ConcurrencyStrategy>(
144 &mut self,
145 aggregate_kind: &str,
146 aggregate_id: Self::Id,
147 expected_version: Option<Self::Position>,
148 ) -> Transaction<'_, Self, Conc> {
149 Transaction::new(
150 self,
151 aggregate_kind.to_string(),
152 aggregate_id,
153 expected_version,
154 )
155 }
156
157 #[tracing::instrument(
158 skip(self, events),
159 fields(
160 aggregate_kind,
161 aggregate_id = %aggregate_id,
162 expected_version,
163 events_len = events.len()
164 )
165 )]
166 async fn append<'a>(
167 &'a mut self,
168 aggregate_kind: &'a str,
169 aggregate_id: &'a Self::Id,
170 expected_version: Option<Self::Position>,
171 events: Vec<PersistableEvent<Self::Metadata>>,
172 ) -> Result<(), AppendError<Self::Position, Self::Error>> {
173 if events.is_empty() {
174 return Ok(());
175 }
176
177 let mut tx = self
178 .pool
179 .begin()
180 .await
181 .map_err(|e| AppendError::store(Error::Database(e)))?;
182
183 sqlx::query(
184 r"
185 INSERT INTO es_streams (aggregate_kind, aggregate_id, last_position)
186 VALUES ($1, $2, NULL)
187 ON CONFLICT (aggregate_kind, aggregate_id) DO NOTHING
188 ",
189 )
190 .bind(aggregate_kind)
191 .bind(aggregate_id)
192 .execute(&mut *tx)
193 .await
194 .map_err(|e| AppendError::store(Error::Database(e)))?;
195
196 let current: Option<i64> = sqlx::query_scalar(
197 r"
198 SELECT last_position
199 FROM es_streams
200 WHERE aggregate_kind = $1 AND aggregate_id = $2
201 FOR UPDATE
202 ",
203 )
204 .bind(aggregate_kind)
205 .bind(aggregate_id)
206 .fetch_one(&mut *tx)
207 .await
208 .map_err(|e| AppendError::store(Error::Database(e)))?;
209
210 if let Some(expected) = expected_version
211 && current != Some(expected)
212 {
213 return Err(AppendError::Conflict(ConcurrencyConflict {
214 expected: Some(expected),
215 actual: current,
216 }));
217 }
218
219 let mut qb = QueryBuilder::<Postgres>::new(
220 "INSERT INTO es_events (aggregate_kind, aggregate_id, event_kind, data, metadata) ",
221 );
222 qb.push_values(events, |mut b, event| {
223 b.push_bind(aggregate_kind);
224 b.push_bind(aggregate_id);
225 b.push_bind(event.kind);
226 b.push_bind(event.data);
227 b.push_bind(sqlx::types::Json(event.metadata));
228 });
229 qb.push(" RETURNING position");
230
231 let rows: Vec<i64> = qb
232 .build_query_scalar()
233 .fetch_all(&mut *tx)
234 .await
235 .map_err(|e| AppendError::store(Error::Database(e)))?;
236
237 let last_position = rows
238 .last()
239 .ok_or_else(|| AppendError::store(Error::MissingReturnedPosition))?;
240
241 sqlx::query(
242 r"
243 UPDATE es_streams
244 SET last_position = $1
245 WHERE aggregate_kind = $2 AND aggregate_id = $3
246 ",
247 )
248 .bind(last_position)
249 .bind(aggregate_kind)
250 .bind(aggregate_id)
251 .execute(&mut *tx)
252 .await
253 .map_err(|e| AppendError::store(Error::Database(e)))?;
254
255 tx.commit()
256 .await
257 .map_err(|e| AppendError::store(Error::Database(e)))?;
258
259 Ok(())
260 }
261
262 #[tracing::instrument(skip(self, filters), fields(filters_len = filters.len()))]
263 async fn load_events<'a>(
264 &'a self,
265 filters: &'a [EventFilter<Self::Id, Self::Position>],
266 ) -> LoadEventsResult<Self::Id, Self::Position, Self::Metadata, Self::Error> {
267 if filters.is_empty() {
268 return Ok(Vec::new());
269 }
270
271 let mut qb = QueryBuilder::<Postgres>::new(
272 "SELECT aggregate_kind, aggregate_id, event_kind, position, data, metadata FROM (",
273 );
274
275 for (i, filter) in filters.iter().enumerate() {
276 if i > 0 {
277 qb.push(" UNION ALL ");
278 }
279
280 qb.push(
281 "SELECT aggregate_kind, aggregate_id, event_kind, position, data, metadata FROM \
282 es_events WHERE event_kind = ",
283 )
284 .push_bind(&filter.event_kind);
285
286 if let Some(kind) = &filter.aggregate_kind {
287 qb.push(" AND aggregate_kind = ").push_bind(kind);
288 }
289
290 if let Some(id) = &filter.aggregate_id {
291 qb.push(" AND aggregate_id = ").push_bind(id);
292 }
293
294 if let Some(after) = filter.after_position {
295 if after < 0 {
296 return Err(Error::InvalidPosition(after));
297 }
298 qb.push(" AND position > ").push_bind(after);
299 }
300 }
301
302 qb.push(") t ORDER BY position ASC");
303
304 let rows = qb.build().fetch_all(&self.pool).await?;
305
306 let mut out = Vec::with_capacity(rows.len());
307 for row in rows {
308 let aggregate_kind: String = row.try_get("aggregate_kind")?;
309 let aggregate_id: uuid::Uuid = row.try_get("aggregate_id")?;
310 let event_kind: String = row.try_get("event_kind")?;
311 let position: i64 = row.try_get("position")?;
312 let data: Vec<u8> = row.try_get("data")?;
313 let metadata: sqlx::types::Json<M> = row.try_get("metadata")?;
314
315 out.push(StoredEvent {
316 aggregate_kind,
317 aggregate_id,
318 kind: event_kind,
319 position,
320 data,
321 metadata: metadata.0,
322 });
323 }
324
325 Ok(out)
326 }
327
328 #[tracing::instrument(
329 skip(self, events),
330 fields(aggregate_kind, aggregate_id = %aggregate_id, events_len = events.len())
331 )]
332 async fn append_expecting_new<'a>(
333 &'a mut self,
334 aggregate_kind: &'a str,
335 aggregate_id: &'a Self::Id,
336 events: Vec<PersistableEvent<Self::Metadata>>,
337 ) -> Result<(), AppendError<Self::Position, Self::Error>> {
338 if events.is_empty() {
339 return Ok(());
340 }
341
342 let mut tx = self
343 .pool
344 .begin()
345 .await
346 .map_err(|e| AppendError::store(Error::Database(e)))?;
347
348 sqlx::query(
349 r"
350 INSERT INTO es_streams (aggregate_kind, aggregate_id, last_position)
351 VALUES ($1, $2, NULL)
352 ON CONFLICT (aggregate_kind, aggregate_id) DO NOTHING
353 ",
354 )
355 .bind(aggregate_kind)
356 .bind(aggregate_id)
357 .execute(&mut *tx)
358 .await
359 .map_err(|e| AppendError::store(Error::Database(e)))?;
360
361 let current: Option<i64> = sqlx::query_scalar(
362 r"
363 SELECT last_position
364 FROM es_streams
365 WHERE aggregate_kind = $1 AND aggregate_id = $2
366 FOR UPDATE
367 ",
368 )
369 .bind(aggregate_kind)
370 .bind(aggregate_id)
371 .fetch_optional(&mut *tx)
372 .await
373 .map_err(|e| AppendError::store(Error::Database(e)))?;
374
375 if let Some(actual) = current {
376 return Err(AppendError::Conflict(ConcurrencyConflict {
377 expected: None,
378 actual: Some(actual),
379 }));
380 }
381
382 let mut qb = QueryBuilder::<Postgres>::new(
383 "INSERT INTO es_events (aggregate_kind, aggregate_id, event_kind, data, metadata) ",
384 );
385 qb.push_values(events, |mut b, event| {
386 b.push_bind(aggregate_kind);
387 b.push_bind(aggregate_id);
388 b.push_bind(event.kind);
389 b.push_bind(event.data);
390 b.push_bind(sqlx::types::Json(event.metadata));
391 });
392 qb.push(" RETURNING position");
393
394 let rows: Vec<i64> = qb
395 .build_query_scalar()
396 .fetch_all(&mut *tx)
397 .await
398 .map_err(|e| AppendError::store(Error::Database(e)))?;
399
400 let last_position = rows
401 .last()
402 .ok_or_else(|| AppendError::store(Error::MissingReturnedPosition))?;
403
404 sqlx::query(
405 r"
406 UPDATE es_streams
407 SET last_position = $1
408 WHERE aggregate_kind = $2 AND aggregate_id = $3
409 ",
410 )
411 .bind(last_position)
412 .bind(aggregate_kind)
413 .bind(aggregate_id)
414 .execute(&mut *tx)
415 .await
416 .map_err(|e| AppendError::store(Error::Database(e)))?;
417
418 tx.commit()
419 .await
420 .map_err(|e| AppendError::store(Error::Database(e)))?;
421
422 Ok(())
423 }
424}