sourcery_postgres/
lib.rs

1//! Postgres-backed event store implementation.
2//!
//! This crate provides [`Store`], an implementation of
//! [`sourcery_core::store::EventStore`] for `PostgreSQL`.
5
6use std::marker::PhantomData;
7
8use serde::{Serialize, de::DeserializeOwned};
9use sourcery_core::{
10    codec::Codec,
11    concurrency::{ConcurrencyConflict, ConcurrencyStrategy},
12    store::{
13        AppendError, EventFilter, EventStore, LoadEventsResult, PersistableEvent, StoredEvent,
14        Transaction,
15    },
16};
17use sqlx::{PgPool, Postgres, QueryBuilder, Row};
18
/// Errors produced by the Postgres-backed event store.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Any underlying `sqlx` failure (connection, query execution, row
    /// decoding).
    #[error("database error: {0}")]
    Database(#[from] sqlx::Error),
    /// A position value was out of range; currently raised by `load_events`
    /// when a filter's `after_position` is negative (positions are
    /// `BIGSERIAL`-backed, so valid values start at 1).
    #[error("invalid position value from database: {0}")]
    InvalidPosition(i64),
    /// An `INSERT ... RETURNING position` unexpectedly yielded no rows.
    #[error("database did not return an inserted position")]
    MissingReturnedPosition,
}
28
/// A PostgreSQL-backed [`EventStore`].
///
/// Defaults are intentionally conservative:
/// - Positions are global and monotonic (`i64`, backed by `BIGSERIAL`).
/// - Metadata is stored as `jsonb` (`M: Serialize + DeserializeOwned`).
#[derive(Clone)]
pub struct Store<C, M> {
    // Connection pool used by all queries.
    pool: PgPool,
    // Payload codec, exposed to callers through `EventStore::codec`.
    codec: C,
    // Ties the metadata type `M` to the store without storing a value of it.
    _phantom: PhantomData<M>,
}
40
41impl<C, M> Store<C, M>
42where
43    C: Sync,
44    M: Sync,
45{
46    #[must_use]
47    pub const fn new(pool: PgPool, codec: C) -> Self {
48        Self {
49            pool,
50            codec,
51            _phantom: PhantomData,
52        }
53    }
54
55    /// Apply the initial schema (idempotent).
56    ///
57    /// This uses `CREATE TABLE IF NOT EXISTS` style DDL so it can be run on
58    /// startup.
59    ///
60    /// # Errors
61    ///
62    /// Returns a `sqlx::Error` if any of the schema creation queries fail.
63    #[tracing::instrument(skip(self))]
64    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
65        // Streams track per-aggregate last position for optimistic concurrency.
66        sqlx::query(
67            r"
68            CREATE TABLE IF NOT EXISTS es_streams (
69                aggregate_kind TEXT NOT NULL,
70                aggregate_id   UUID NOT NULL,
71                last_position  BIGINT NULL,
72                PRIMARY KEY (aggregate_kind, aggregate_id)
73            )
74            ",
75        )
76        .execute(&self.pool)
77        .await?;
78
79        sqlx::query(
80            r"
81            CREATE TABLE IF NOT EXISTS es_events (
82                position       BIGSERIAL PRIMARY KEY,
83                aggregate_kind TEXT NOT NULL,
84                aggregate_id   UUID NOT NULL,
85                event_kind     TEXT NOT NULL,
86                data           BYTEA NOT NULL,
87                metadata       JSONB NOT NULL,
88                created_at     TIMESTAMPTZ NOT NULL DEFAULT now()
89            )
90            ",
91        )
92        .execute(&self.pool)
93        .await?;
94
95        sqlx::query(
96            r"CREATE INDEX IF NOT EXISTS es_events_by_kind_and_position ON es_events(event_kind, position)",
97        )
98        .execute(&self.pool)
99        .await?;
100
101        sqlx::query(
102            r"CREATE INDEX IF NOT EXISTS es_events_by_stream_and_position ON es_events(aggregate_kind, aggregate_id, position)",
103        )
104        .execute(&self.pool)
105        .await?;
106
107        Ok(())
108    }
109}
110
111impl<C, M> EventStore for Store<C, M>
112where
113    C: Codec + Clone + Send + Sync + 'static,
114    M: Serialize + DeserializeOwned + Send + Sync + 'static,
115{
116    type Codec = C;
117    type Error = Error;
118    type Id = uuid::Uuid;
119    type Metadata = M;
120    type Position = i64;
121
122    fn codec(&self) -> &Self::Codec {
123        &self.codec
124    }
125
126    async fn stream_version<'a>(
127        &'a self,
128        aggregate_kind: &'a str,
129        aggregate_id: &'a Self::Id,
130    ) -> Result<Option<Self::Position>, Self::Error> {
131        let result: Option<i64> = sqlx::query_scalar(
132            r"SELECT last_position FROM es_streams WHERE aggregate_kind = $1 AND aggregate_id = $2",
133        )
134        .bind(aggregate_kind)
135        .bind(aggregate_id)
136        .fetch_optional(&self.pool)
137        .await?
138        .flatten();
139
140        Ok(result)
141    }
142
143    fn begin<Conc: ConcurrencyStrategy>(
144        &mut self,
145        aggregate_kind: &str,
146        aggregate_id: Self::Id,
147        expected_version: Option<Self::Position>,
148    ) -> Transaction<'_, Self, Conc> {
149        Transaction::new(
150            self,
151            aggregate_kind.to_string(),
152            aggregate_id,
153            expected_version,
154        )
155    }
156
157    #[tracing::instrument(
158        skip(self, events),
159        fields(
160            aggregate_kind,
161            aggregate_id = %aggregate_id,
162            expected_version,
163            events_len = events.len()
164        )
165    )]
166    async fn append<'a>(
167        &'a mut self,
168        aggregate_kind: &'a str,
169        aggregate_id: &'a Self::Id,
170        expected_version: Option<Self::Position>,
171        events: Vec<PersistableEvent<Self::Metadata>>,
172    ) -> Result<(), AppendError<Self::Position, Self::Error>> {
173        if events.is_empty() {
174            return Ok(());
175        }
176
177        let mut tx = self
178            .pool
179            .begin()
180            .await
181            .map_err(|e| AppendError::store(Error::Database(e)))?;
182
183        sqlx::query(
184            r"
185                INSERT INTO es_streams (aggregate_kind, aggregate_id, last_position)
186                VALUES ($1, $2, NULL)
187                ON CONFLICT (aggregate_kind, aggregate_id) DO NOTHING
188                ",
189        )
190        .bind(aggregate_kind)
191        .bind(aggregate_id)
192        .execute(&mut *tx)
193        .await
194        .map_err(|e| AppendError::store(Error::Database(e)))?;
195
196        let current: Option<i64> = sqlx::query_scalar(
197            r"
198                SELECT last_position
199                FROM es_streams
200                WHERE aggregate_kind = $1 AND aggregate_id = $2
201                FOR UPDATE
202                ",
203        )
204        .bind(aggregate_kind)
205        .bind(aggregate_id)
206        .fetch_one(&mut *tx)
207        .await
208        .map_err(|e| AppendError::store(Error::Database(e)))?;
209
210        if let Some(expected) = expected_version
211            && current != Some(expected)
212        {
213            return Err(AppendError::Conflict(ConcurrencyConflict {
214                expected: Some(expected),
215                actual: current,
216            }));
217        }
218
219        let mut qb = QueryBuilder::<Postgres>::new(
220            "INSERT INTO es_events (aggregate_kind, aggregate_id, event_kind, data, metadata) ",
221        );
222        qb.push_values(events, |mut b, event| {
223            b.push_bind(aggregate_kind);
224            b.push_bind(aggregate_id);
225            b.push_bind(event.kind);
226            b.push_bind(event.data);
227            b.push_bind(sqlx::types::Json(event.metadata));
228        });
229        qb.push(" RETURNING position");
230
231        let rows: Vec<i64> = qb
232            .build_query_scalar()
233            .fetch_all(&mut *tx)
234            .await
235            .map_err(|e| AppendError::store(Error::Database(e)))?;
236
237        let last_position = rows
238            .last()
239            .ok_or_else(|| AppendError::store(Error::MissingReturnedPosition))?;
240
241        sqlx::query(
242            r"
243                UPDATE es_streams
244                SET last_position = $1
245                WHERE aggregate_kind = $2 AND aggregate_id = $3
246                ",
247        )
248        .bind(last_position)
249        .bind(aggregate_kind)
250        .bind(aggregate_id)
251        .execute(&mut *tx)
252        .await
253        .map_err(|e| AppendError::store(Error::Database(e)))?;
254
255        tx.commit()
256            .await
257            .map_err(|e| AppendError::store(Error::Database(e)))?;
258
259        Ok(())
260    }
261
262    #[tracing::instrument(skip(self, filters), fields(filters_len = filters.len()))]
263    async fn load_events<'a>(
264        &'a self,
265        filters: &'a [EventFilter<Self::Id, Self::Position>],
266    ) -> LoadEventsResult<Self::Id, Self::Position, Self::Metadata, Self::Error> {
267        if filters.is_empty() {
268            return Ok(Vec::new());
269        }
270
271        let mut qb = QueryBuilder::<Postgres>::new(
272            "SELECT aggregate_kind, aggregate_id, event_kind, position, data, metadata FROM (",
273        );
274
275        for (i, filter) in filters.iter().enumerate() {
276            if i > 0 {
277                qb.push(" UNION ALL ");
278            }
279
280            qb.push(
281                "SELECT aggregate_kind, aggregate_id, event_kind, position, data, metadata FROM \
282                 es_events WHERE event_kind = ",
283            )
284            .push_bind(&filter.event_kind);
285
286            if let Some(kind) = &filter.aggregate_kind {
287                qb.push(" AND aggregate_kind = ").push_bind(kind);
288            }
289
290            if let Some(id) = &filter.aggregate_id {
291                qb.push(" AND aggregate_id = ").push_bind(id);
292            }
293
294            if let Some(after) = filter.after_position {
295                if after < 0 {
296                    return Err(Error::InvalidPosition(after));
297                }
298                qb.push(" AND position > ").push_bind(after);
299            }
300        }
301
302        qb.push(") t ORDER BY position ASC");
303
304        let rows = qb.build().fetch_all(&self.pool).await?;
305
306        let mut out = Vec::with_capacity(rows.len());
307        for row in rows {
308            let aggregate_kind: String = row.try_get("aggregate_kind")?;
309            let aggregate_id: uuid::Uuid = row.try_get("aggregate_id")?;
310            let event_kind: String = row.try_get("event_kind")?;
311            let position: i64 = row.try_get("position")?;
312            let data: Vec<u8> = row.try_get("data")?;
313            let metadata: sqlx::types::Json<M> = row.try_get("metadata")?;
314
315            out.push(StoredEvent {
316                aggregate_kind,
317                aggregate_id,
318                kind: event_kind,
319                position,
320                data,
321                metadata: metadata.0,
322            });
323        }
324
325        Ok(out)
326    }
327
328    #[tracing::instrument(
329        skip(self, events),
330        fields(aggregate_kind, aggregate_id = %aggregate_id, events_len = events.len())
331    )]
332    async fn append_expecting_new<'a>(
333        &'a mut self,
334        aggregate_kind: &'a str,
335        aggregate_id: &'a Self::Id,
336        events: Vec<PersistableEvent<Self::Metadata>>,
337    ) -> Result<(), AppendError<Self::Position, Self::Error>> {
338        if events.is_empty() {
339            return Ok(());
340        }
341
342        let mut tx = self
343            .pool
344            .begin()
345            .await
346            .map_err(|e| AppendError::store(Error::Database(e)))?;
347
348        sqlx::query(
349            r"
350                INSERT INTO es_streams (aggregate_kind, aggregate_id, last_position)
351                VALUES ($1, $2, NULL)
352                ON CONFLICT (aggregate_kind, aggregate_id) DO NOTHING
353                ",
354        )
355        .bind(aggregate_kind)
356        .bind(aggregate_id)
357        .execute(&mut *tx)
358        .await
359        .map_err(|e| AppendError::store(Error::Database(e)))?;
360
361        let current: Option<i64> = sqlx::query_scalar(
362            r"
363                SELECT last_position
364                FROM es_streams
365                WHERE aggregate_kind = $1 AND aggregate_id = $2
366                FOR UPDATE
367                ",
368        )
369        .bind(aggregate_kind)
370        .bind(aggregate_id)
371        .fetch_optional(&mut *tx)
372        .await
373        .map_err(|e| AppendError::store(Error::Database(e)))?;
374
375        if let Some(actual) = current {
376            return Err(AppendError::Conflict(ConcurrencyConflict {
377                expected: None,
378                actual: Some(actual),
379            }));
380        }
381
382        let mut qb = QueryBuilder::<Postgres>::new(
383            "INSERT INTO es_events (aggregate_kind, aggregate_id, event_kind, data, metadata) ",
384        );
385        qb.push_values(events, |mut b, event| {
386            b.push_bind(aggregate_kind);
387            b.push_bind(aggregate_id);
388            b.push_bind(event.kind);
389            b.push_bind(event.data);
390            b.push_bind(sqlx::types::Json(event.metadata));
391        });
392        qb.push(" RETURNING position");
393
394        let rows: Vec<i64> = qb
395            .build_query_scalar()
396            .fetch_all(&mut *tx)
397            .await
398            .map_err(|e| AppendError::store(Error::Database(e)))?;
399
400        let last_position = rows
401            .last()
402            .ok_or_else(|| AppendError::store(Error::MissingReturnedPosition))?;
403
404        sqlx::query(
405            r"
406                UPDATE es_streams
407                SET last_position = $1
408                WHERE aggregate_kind = $2 AND aggregate_id = $3
409                ",
410        )
411        .bind(last_position)
412        .bind(aggregate_kind)
413        .bind(aggregate_id)
414        .execute(&mut *tx)
415        .await
416        .map_err(|e| AppendError::store(Error::Database(e)))?;
417
418        tx.commit()
419            .await
420            .map_err(|e| AppendError::store(Error::Database(e)))?;
421
422        Ok(())
423    }
424}