Skip to main content

sqlite_es/
event_repository.rs

1use cqrs_es::persist::{
2    PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
3};
4use cqrs_es::Aggregate;
5use futures::TryStreamExt;
6use serde_json::Value;
7use sqlx::sqlite::SqliteRow;
8use sqlx::{Pool, Row, Sqlite, Transaction};
9
10use crate::error::SqliteAggregateError;
11use crate::sql_query::SqlQueryFactory;
12
/// Event table name used when the repository is built with `SqliteEventRepository::new`.
const DEFAULT_EVENT_TABLE: &str = "events";
/// Snapshot table name used when the repository is built with `SqliteEventRepository::new`.
const DEFAULT_SNAPSHOT_TABLE: &str = "snapshots";

/// Replay stream channel size used unless `with_streaming_channel_size` overrides it.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
17
/// An event repository relying on a SQLite database for persistence.
pub struct SqliteEventRepository {
    /// Connection pool that every query and transaction of this repository runs against.
    pool: Pool<Sqlite>,
    /// Source of the SQL statements, built from the configured event and snapshot table names.
    query_factory: SqlQueryFactory,
    /// Channel size handed to `ReplayStream::new` when events are streamed.
    stream_channel_size: usize,
}
24
25impl PersistedEventRepository for SqliteEventRepository {
26    async fn get_events<A: Aggregate>(
27        &self,
28        aggregate_id: &str,
29    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
30        self.select_events::<A>(aggregate_id, self.query_factory.select_events())
31            .await
32    }
33
34    async fn get_last_events<A: Aggregate>(
35        &self,
36        aggregate_id: &str,
37        last_sequence: usize,
38    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
39        let query = self.query_factory.get_last_events(last_sequence);
40        self.select_events::<A>(aggregate_id, &query).await
41    }
42
43    async fn get_snapshot<A: Aggregate>(
44        &self,
45        aggregate_id: &str,
46    ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
47        let row: SqliteRow = match sqlx::query(self.query_factory.select_snapshot())
48            .bind(A::TYPE)
49            .bind(aggregate_id)
50            .fetch_optional(&self.pool)
51            .await
52            .map_err(SqliteAggregateError::from)?
53        {
54            Some(row) => row,
55            None => {
56                return Ok(None);
57            }
58        };
59        Ok(Some(self.deser_snapshot(row)?))
60    }
61
62    async fn persist<A: Aggregate>(
63        &self,
64        events: &[SerializedEvent],
65        snapshot_update: Option<(String, Value, usize)>,
66    ) -> Result<(), PersistenceError> {
67        match snapshot_update {
68            None => {
69                self.insert_events::<A>(events).await?;
70            }
71            Some((aggregate_id, aggregate, current_snapshot)) => {
72                if current_snapshot == 1 {
73                    self.insert::<A>(aggregate, aggregate_id, current_snapshot, events)
74                        .await?;
75                } else {
76                    self.update::<A>(aggregate, aggregate_id, current_snapshot, events)
77                        .await?;
78                }
79            }
80        };
81        Ok(())
82    }
83
84    async fn stream_events<A: Aggregate>(
85        &self,
86        aggregate_id: &str,
87    ) -> Result<ReplayStream, PersistenceError> {
88        Ok(stream_events(
89            self.query_factory.select_events().to_string(),
90            A::TYPE.to_string(),
91            aggregate_id.to_string(),
92            self.pool.clone(),
93            self.stream_channel_size,
94        ))
95    }
96
97    // TODO: aggregate id is unused here, `stream_events` function needs to be broken up
98    async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
99        Ok(stream_events(
100            self.query_factory.all_events().to_string(),
101            A::TYPE.to_string(),
102            "".to_string(),
103            self.pool.clone(),
104            self.stream_channel_size,
105        ))
106    }
107}
108
109fn stream_events(
110    query: String,
111    aggregate_type: String,
112    aggregate_id: String,
113    pool: Pool<Sqlite>,
114    channel_size: usize,
115) -> ReplayStream {
116    let (mut feed, stream) = ReplayStream::new(channel_size);
117    tokio::spawn(async move {
118        let query = sqlx::query(&query)
119            .bind(&aggregate_type)
120            .bind(&aggregate_id);
121        let mut rows = query.fetch(&pool);
122        while let Some(row) = rows.try_next().await.unwrap() {
123            let event_result: Result<SerializedEvent, PersistenceError> =
124                SqliteEventRepository::deser_event(row).map_err(Into::into);
125            if feed.push(event_result).await.is_err() {
126                // TODO: in the unlikely event of a broken channel this error should be reported.
127                return;
128            };
129        }
130    });
131    stream
132}
133
134impl SqliteEventRepository {
135    async fn select_events<A: Aggregate>(
136        &self,
137        aggregate_id: &str,
138        query: &str,
139    ) -> Result<Vec<SerializedEvent>, PersistenceError> {
140        let mut rows = sqlx::query(query)
141            .bind(A::TYPE)
142            .bind(aggregate_id)
143            .fetch(&self.pool);
144        let mut result: Vec<SerializedEvent> = Default::default();
145        while let Some(row) = rows.try_next().await.map_err(SqliteAggregateError::from)? {
146            result.push(SqliteEventRepository::deser_event(row)?);
147        }
148        Ok(result)
149    }
150}
151
152impl SqliteEventRepository {
153    /// Creates a new `SqliteEventRepository` from the provided database connection.
154    /// This uses the default tables 'events' and 'snapshots'.
155    ///
156    /// ```
157    /// use sqlx::{Pool, Sqlite};
158    /// use sqlite_es::SqliteEventRepository;
159    ///
160    /// fn configure_repo(pool: Pool<Sqlite>) -> SqliteEventRepository {
161    ///     SqliteEventRepository::new(pool)
162    /// }
163    /// ```
164    pub fn new(pool: Pool<Sqlite>) -> Self {
165        Self::use_tables(pool, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
166    }
167
168    /// Configures a `SqliteEventRepository` to use a streaming queue of the provided size.
169    ///
170    /// _Example: configure the repository to stream with a 1000 event buffer._
171    /// ```
172    /// use sqlx::{Pool, Sqlite};
173    /// use sqlite_es::SqliteEventRepository;
174    ///
175    /// fn configure_repo(pool: Pool<Sqlite>) -> SqliteEventRepository {
176    ///     let store = SqliteEventRepository::new(pool);
177    ///     store.with_streaming_channel_size(1000)
178    /// }
179    /// ```
180    pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
181        Self {
182            pool: self.pool,
183            query_factory: self.query_factory,
184            stream_channel_size,
185        }
186    }
187
188    /// Configures a `SqliteEventRepository` to use the provided table names.
189    ///
190    /// _Example: configure the repository to use "my_event_table" and "my_snapshot_table"
191    /// for the event and snapshot table names._
192    /// ```
193    /// use sqlx::{Pool, Sqlite};
194    /// use sqlite_es::SqliteEventRepository;
195    ///
196    /// fn configure_repo(pool: Pool<Sqlite>) -> SqliteEventRepository {
197    ///     let store = SqliteEventRepository::new(pool);
198    ///     store.with_tables("my_event_table", "my_snapshot_table")
199    /// }
200    /// ```
201    pub fn with_tables(self, events_table: &str, snapshots_table: &str) -> Self {
202        Self::use_tables(self.pool, events_table, snapshots_table)
203    }
204
205    fn use_tables(pool: Pool<Sqlite>, events_table: &str, snapshots_table: &str) -> Self {
206        Self {
207            pool,
208            query_factory: SqlQueryFactory::new(events_table, snapshots_table),
209            stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
210        }
211    }
212
213    pub(crate) async fn insert_events<A: Aggregate>(
214        &self,
215        events: &[SerializedEvent],
216    ) -> Result<(), SqliteAggregateError> {
217        let mut tx = self.pool.begin().await?;
218        self.persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
219            .await?;
220        tx.commit().await?;
221        Ok(())
222    }
223
224    pub(crate) async fn insert<A: Aggregate>(
225        &self,
226        aggregate_payload: Value,
227        aggregate_id: String,
228        current_snapshot: usize,
229        events: &[SerializedEvent],
230    ) -> Result<(), SqliteAggregateError> {
231        let mut tx = self.pool.begin().await?;
232        let current_sequence = self
233            .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
234            .await?;
235        sqlx::query(self.query_factory.insert_snapshot())
236            .bind(A::TYPE)
237            .bind(aggregate_id.as_str())
238            .bind(current_sequence as i32)
239            .bind(current_snapshot as i32)
240            .bind(&aggregate_payload)
241            .execute(&mut *tx)
242            .await?;
243        tx.commit().await?;
244        Ok(())
245    }
246
247    pub(crate) async fn update<A: Aggregate>(
248        &self,
249        aggregate: Value,
250        aggregate_id: String,
251        current_snapshot: usize,
252        events: &[SerializedEvent],
253    ) -> Result<(), SqliteAggregateError> {
254        let mut tx = self.pool.begin().await?;
255        let current_sequence = self
256            .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
257            .await?;
258
259        let aggregate_payload = serde_json::to_value(&aggregate)?;
260        let result = sqlx::query(self.query_factory.update_snapshot())
261            .bind(current_sequence as i32)
262            .bind(&aggregate_payload)
263            .bind(current_snapshot as i32)
264            .bind(A::TYPE)
265            .bind(aggregate_id.as_str())
266            .bind((current_snapshot - 1) as i32)
267            .execute(&mut *tx)
268            .await?;
269        tx.commit().await?;
270        match result.rows_affected() {
271            1 => Ok(()),
272            _ => Err(SqliteAggregateError::OptimisticLock),
273        }
274    }
275
276    fn deser_event(row: SqliteRow) -> Result<SerializedEvent, SqliteAggregateError> {
277        let aggregate_type: String = row.get("aggregate_type");
278        let aggregate_id: String = row.get("aggregate_id");
279        let sequence = {
280            let s: i64 = row.get("sequence");
281            s as usize
282        };
283        let event_type: String = row.get("event_type");
284        let event_version: String = row.get("event_version");
285        let payload: Value = row.get("payload");
286        let metadata: Value = row.get("metadata");
287        Ok(SerializedEvent::new(
288            aggregate_id,
289            sequence,
290            aggregate_type,
291            event_type,
292            event_version,
293            payload,
294            metadata,
295        ))
296    }
297
298    fn deser_snapshot(&self, row: SqliteRow) -> Result<SerializedSnapshot, SqliteAggregateError> {
299        let aggregate_id = row.get("aggregate_id");
300        let s: i64 = row.get("last_sequence");
301        let current_sequence = s as usize;
302        let s: i64 = row.get("current_snapshot");
303        let current_snapshot = s as usize;
304        let aggregate: Value = row.get("payload");
305        Ok(SerializedSnapshot {
306            aggregate_id,
307            aggregate,
308            current_sequence,
309            current_snapshot,
310        })
311    }
312
313    pub(crate) async fn persist_events<A: Aggregate>(
314        &self,
315        inser_event_query: &str,
316        tx: &mut Transaction<'_, Sqlite>,
317        events: &[SerializedEvent],
318    ) -> Result<usize, SqliteAggregateError> {
319        let mut current_sequence: usize = 0;
320        for event in events {
321            current_sequence = event.sequence;
322            let event_type = &event.event_type;
323            let event_version = &event.event_version;
324            let payload = serde_json::to_value(&event.payload)?;
325            let metadata = serde_json::to_value(&event.metadata)?;
326            sqlx::query(inser_event_query)
327                .bind(A::TYPE)
328                .bind(event.aggregate_id.as_str())
329                .bind(event.sequence as i32)
330                .bind(event_type)
331                .bind(event_version)
332                .bind(&payload)
333                .bind(&metadata)
334                .execute(&mut **tx)
335                .await?;
336        }
337        Ok(current_sequence)
338    }
339}
340
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::SqliteAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_event_envelope, Created, SomethingElse, TestAggregate, TestEvent,
        Tested, TEST_CONNECTION_STRING,
    };
    use crate::{default_sqlite_pool, SqliteEventRepository};

    /// Inserts events, checks that a batch with a duplicate sequence is rejected as a whole,
    /// and then verifies the replay streams.
    #[tokio::test]
    async fn event_repositories() {
        let pool = default_sqlite_pool(TEST_CONNECTION_STRING).await;
        sqlx::migrate!().run(&pool).await.unwrap();

        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: SqliteEventRepository =
            SqliteEventRepository::new(pool.clone()).with_streaming_channel_size(1);
        let initial_events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(initial_events.is_empty());

        let created = test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() }));
        let tested = test_event_envelope(
            &id,
            2,
            TestEvent::Tested(Tested {
                test_name: "a test was run".to_string(),
            }),
        );
        event_repo
            .insert_events::<TestAggregate>(&[created, tested])
            .await
            .unwrap();
        let stored_events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, stored_events.len());
        for event in &stored_events {
            assert_eq!(&id, &event.aggregate_id);
        }

        // Optimistic lock error
        let not_persisted = test_event_envelope(
            &id,
            3,
            TestEvent::SomethingElse(SomethingElse {
                description: "this should not persist".to_string(),
            }),
        );
        let duplicate_sequence = test_event_envelope(
            &id,
            2,
            TestEvent::SomethingElse(SomethingElse {
                description: "bad sequence number".to_string(),
            }),
        );
        let result = event_repo
            .insert_events::<TestAggregate>(&[not_persisted, duplicate_sequence])
            .await
            .unwrap_err();
        if !matches!(result, SqliteAggregateError::OptimisticLock) {
            panic!("invalid error result found during insert: {result}");
        }

        // The failed batch must not have added anything.
        let events_after_failure = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events_after_failure.len());

        verify_replay_stream(&id, event_repo).await;
    }

    /// Counts the items produced by the per-aggregate stream and by the all-events stream.
    async fn verify_replay_stream(id: &str, event_repo: SqliteEventRepository) {
        let mut aggregate_stream = event_repo.stream_events::<TestAggregate>(id).await.unwrap();
        let mut aggregate_count = 0;
        while aggregate_stream
            .next::<TestAggregate>(&[])
            .await
            .is_some()
        {
            aggregate_count += 1;
        }
        assert_eq!(aggregate_count, 2);

        let mut all_stream = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        let mut all_count = 0;
        while all_stream.next::<TestAggregate>(&[]).await.is_some() {
            all_count += 1;
        }
        assert!(all_count >= 2);
    }

    /// Inserts a snapshot, updates it once, and checks that a stale update is rejected
    /// without changing the stored snapshot.
    #[tokio::test]
    async fn snapshot_repositories() {
        let pool = default_sqlite_pool(TEST_CONNECTION_STRING).await;
        sqlx::migrate!().run(&pool).await.unwrap();

        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: SqliteEventRepository = SqliteEventRepository::new(pool.clone());
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        // Serializes a `TestAggregate` for this id with the given description.
        let aggregate_json = |description: &str| {
            serde_json::to_value(TestAggregate {
                id: id.clone(),
                description: description.to_string(),
                tests: test_tests.clone(),
            })
            .unwrap()
        };

        event_repo
            .insert::<TestAggregate>(aggregate_json(&test_description), id.clone(), 1, &[])
            .await
            .unwrap();

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                aggregate_json(&test_description)
            )),
            snapshot
        );

        // sequence iterated, does update
        event_repo
            .update::<TestAggregate>(
                aggregate_json("a test description that should be saved"),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap();

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                aggregate_json("a test description that should be saved")
            )),
            snapshot
        );

        // sequence out of order or not iterated, does not update
        let result = event_repo
            .update::<TestAggregate>(
                aggregate_json("a test description that should not be saved"),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        if !matches!(result, SqliteAggregateError::OptimisticLock) {
            panic!("invalid error result found during insert: {result}");
        }

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                aggregate_json("a test description that should be saved")
            )),
            snapshot
        );
    }
}
537            )),
538            snapshot
539        );
540    }
541}