1use cqrs_es::persist::{
2 PersistedEventRepository, PersistenceError, ReplayStream, SerializedEvent, SerializedSnapshot,
3};
4use cqrs_es::Aggregate;
5use futures::TryStreamExt;
6use serde_json::Value;
7use sqlx::sqlite::SqliteRow;
8use sqlx::{Pool, Row, Sqlite, Transaction};
9
10use crate::error::SqliteAggregateError;
11use crate::sql_query::SqlQueryFactory;
12
// Default table names used by `SqliteEventRepository::new`; override with
// `with_tables`.
const DEFAULT_EVENT_TABLE: &str = "events";
const DEFAULT_SNAPSHOT_TABLE: &str = "snapshots";

// Default buffer size of the channel backing a `ReplayStream`; override with
// `with_streaming_channel_size`.
const DEFAULT_STREAMING_CHANNEL_SIZE: usize = 200;
17
/// A SQLite-backed event repository, implementing `PersistedEventRepository`
/// over the configured event and snapshot tables.
pub struct SqliteEventRepository {
    // Connection pool shared by all queries; cloned (cheaply) for streaming tasks.
    pool: Pool<Sqlite>,
    // Pre-built SQL strings for the configured event/snapshot table names.
    query_factory: SqlQueryFactory,
    // Buffer size of the channel backing `stream_events`/`stream_all_events`.
    stream_channel_size: usize,
}
24
25impl PersistedEventRepository for SqliteEventRepository {
26 async fn get_events<A: Aggregate>(
27 &self,
28 aggregate_id: &str,
29 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
30 self.select_events::<A>(aggregate_id, self.query_factory.select_events())
31 .await
32 }
33
34 async fn get_last_events<A: Aggregate>(
35 &self,
36 aggregate_id: &str,
37 last_sequence: usize,
38 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
39 let query = self.query_factory.get_last_events(last_sequence);
40 self.select_events::<A>(aggregate_id, &query).await
41 }
42
43 async fn get_snapshot<A: Aggregate>(
44 &self,
45 aggregate_id: &str,
46 ) -> Result<Option<SerializedSnapshot>, PersistenceError> {
47 let row: SqliteRow = match sqlx::query(self.query_factory.select_snapshot())
48 .bind(A::TYPE)
49 .bind(aggregate_id)
50 .fetch_optional(&self.pool)
51 .await
52 .map_err(SqliteAggregateError::from)?
53 {
54 Some(row) => row,
55 None => {
56 return Ok(None);
57 }
58 };
59 Ok(Some(self.deser_snapshot(row)?))
60 }
61
62 async fn persist<A: Aggregate>(
63 &self,
64 events: &[SerializedEvent],
65 snapshot_update: Option<(String, Value, usize)>,
66 ) -> Result<(), PersistenceError> {
67 match snapshot_update {
68 None => {
69 self.insert_events::<A>(events).await?;
70 }
71 Some((aggregate_id, aggregate, current_snapshot)) => {
72 if current_snapshot == 1 {
73 self.insert::<A>(aggregate, aggregate_id, current_snapshot, events)
74 .await?;
75 } else {
76 self.update::<A>(aggregate, aggregate_id, current_snapshot, events)
77 .await?;
78 }
79 }
80 };
81 Ok(())
82 }
83
84 async fn stream_events<A: Aggregate>(
85 &self,
86 aggregate_id: &str,
87 ) -> Result<ReplayStream, PersistenceError> {
88 Ok(stream_events(
89 self.query_factory.select_events().to_string(),
90 A::TYPE.to_string(),
91 aggregate_id.to_string(),
92 self.pool.clone(),
93 self.stream_channel_size,
94 ))
95 }
96
97 async fn stream_all_events<A: Aggregate>(&self) -> Result<ReplayStream, PersistenceError> {
99 Ok(stream_events(
100 self.query_factory.all_events().to_string(),
101 A::TYPE.to_string(),
102 "".to_string(),
103 self.pool.clone(),
104 self.stream_channel_size,
105 ))
106 }
107}
108
109fn stream_events(
110 query: String,
111 aggregate_type: String,
112 aggregate_id: String,
113 pool: Pool<Sqlite>,
114 channel_size: usize,
115) -> ReplayStream {
116 let (mut feed, stream) = ReplayStream::new(channel_size);
117 tokio::spawn(async move {
118 let query = sqlx::query(&query)
119 .bind(&aggregate_type)
120 .bind(&aggregate_id);
121 let mut rows = query.fetch(&pool);
122 while let Some(row) = rows.try_next().await.unwrap() {
123 let event_result: Result<SerializedEvent, PersistenceError> =
124 SqliteEventRepository::deser_event(row).map_err(Into::into);
125 if feed.push(event_result).await.is_err() {
126 return;
128 };
129 }
130 });
131 stream
132}
133
134impl SqliteEventRepository {
135 async fn select_events<A: Aggregate>(
136 &self,
137 aggregate_id: &str,
138 query: &str,
139 ) -> Result<Vec<SerializedEvent>, PersistenceError> {
140 let mut rows = sqlx::query(query)
141 .bind(A::TYPE)
142 .bind(aggregate_id)
143 .fetch(&self.pool);
144 let mut result: Vec<SerializedEvent> = Default::default();
145 while let Some(row) = rows.try_next().await.map_err(SqliteAggregateError::from)? {
146 result.push(SqliteEventRepository::deser_event(row)?);
147 }
148 Ok(result)
149 }
150}
151
152impl SqliteEventRepository {
153 pub fn new(pool: Pool<Sqlite>) -> Self {
165 Self::use_tables(pool, DEFAULT_EVENT_TABLE, DEFAULT_SNAPSHOT_TABLE)
166 }
167
168 pub fn with_streaming_channel_size(self, stream_channel_size: usize) -> Self {
181 Self {
182 pool: self.pool,
183 query_factory: self.query_factory,
184 stream_channel_size,
185 }
186 }
187
188 pub fn with_tables(self, events_table: &str, snapshots_table: &str) -> Self {
202 Self::use_tables(self.pool, events_table, snapshots_table)
203 }
204
205 fn use_tables(pool: Pool<Sqlite>, events_table: &str, snapshots_table: &str) -> Self {
206 Self {
207 pool,
208 query_factory: SqlQueryFactory::new(events_table, snapshots_table),
209 stream_channel_size: DEFAULT_STREAMING_CHANNEL_SIZE,
210 }
211 }
212
213 pub(crate) async fn insert_events<A: Aggregate>(
214 &self,
215 events: &[SerializedEvent],
216 ) -> Result<(), SqliteAggregateError> {
217 let mut tx = self.pool.begin().await?;
218 self.persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
219 .await?;
220 tx.commit().await?;
221 Ok(())
222 }
223
224 pub(crate) async fn insert<A: Aggregate>(
225 &self,
226 aggregate_payload: Value,
227 aggregate_id: String,
228 current_snapshot: usize,
229 events: &[SerializedEvent],
230 ) -> Result<(), SqliteAggregateError> {
231 let mut tx = self.pool.begin().await?;
232 let current_sequence = self
233 .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
234 .await?;
235 sqlx::query(self.query_factory.insert_snapshot())
236 .bind(A::TYPE)
237 .bind(aggregate_id.as_str())
238 .bind(current_sequence as i32)
239 .bind(current_snapshot as i32)
240 .bind(&aggregate_payload)
241 .execute(&mut *tx)
242 .await?;
243 tx.commit().await?;
244 Ok(())
245 }
246
247 pub(crate) async fn update<A: Aggregate>(
248 &self,
249 aggregate: Value,
250 aggregate_id: String,
251 current_snapshot: usize,
252 events: &[SerializedEvent],
253 ) -> Result<(), SqliteAggregateError> {
254 let mut tx = self.pool.begin().await?;
255 let current_sequence = self
256 .persist_events::<A>(self.query_factory.insert_event(), &mut tx, events)
257 .await?;
258
259 let aggregate_payload = serde_json::to_value(&aggregate)?;
260 let result = sqlx::query(self.query_factory.update_snapshot())
261 .bind(current_sequence as i32)
262 .bind(&aggregate_payload)
263 .bind(current_snapshot as i32)
264 .bind(A::TYPE)
265 .bind(aggregate_id.as_str())
266 .bind((current_snapshot - 1) as i32)
267 .execute(&mut *tx)
268 .await?;
269 tx.commit().await?;
270 match result.rows_affected() {
271 1 => Ok(()),
272 _ => Err(SqliteAggregateError::OptimisticLock),
273 }
274 }
275
276 fn deser_event(row: SqliteRow) -> Result<SerializedEvent, SqliteAggregateError> {
277 let aggregate_type: String = row.get("aggregate_type");
278 let aggregate_id: String = row.get("aggregate_id");
279 let sequence = {
280 let s: i64 = row.get("sequence");
281 s as usize
282 };
283 let event_type: String = row.get("event_type");
284 let event_version: String = row.get("event_version");
285 let payload: Value = row.get("payload");
286 let metadata: Value = row.get("metadata");
287 Ok(SerializedEvent::new(
288 aggregate_id,
289 sequence,
290 aggregate_type,
291 event_type,
292 event_version,
293 payload,
294 metadata,
295 ))
296 }
297
298 fn deser_snapshot(&self, row: SqliteRow) -> Result<SerializedSnapshot, SqliteAggregateError> {
299 let aggregate_id = row.get("aggregate_id");
300 let s: i64 = row.get("last_sequence");
301 let current_sequence = s as usize;
302 let s: i64 = row.get("current_snapshot");
303 let current_snapshot = s as usize;
304 let aggregate: Value = row.get("payload");
305 Ok(SerializedSnapshot {
306 aggregate_id,
307 aggregate,
308 current_sequence,
309 current_snapshot,
310 })
311 }
312
313 pub(crate) async fn persist_events<A: Aggregate>(
314 &self,
315 inser_event_query: &str,
316 tx: &mut Transaction<'_, Sqlite>,
317 events: &[SerializedEvent],
318 ) -> Result<usize, SqliteAggregateError> {
319 let mut current_sequence: usize = 0;
320 for event in events {
321 current_sequence = event.sequence;
322 let event_type = &event.event_type;
323 let event_version = &event.event_version;
324 let payload = serde_json::to_value(&event.payload)?;
325 let metadata = serde_json::to_value(&event.metadata)?;
326 sqlx::query(inser_event_query)
327 .bind(A::TYPE)
328 .bind(event.aggregate_id.as_str())
329 .bind(event.sequence as i32)
330 .bind(event_type)
331 .bind(event_version)
332 .bind(&payload)
333 .bind(&metadata)
334 .execute(&mut **tx)
335 .await?;
336 }
337 Ok(current_sequence)
338 }
339}
340
#[cfg(test)]
mod test {
    use cqrs_es::persist::PersistedEventRepository;

    use crate::error::SqliteAggregateError;
    use crate::testing::tests::{
        snapshot_context, test_event_envelope, Created, SomethingElse, TestAggregate, TestEvent,
        Tested, TEST_CONNECTION_STRING,
    };
    use crate::{default_sqlite_pool, SqliteEventRepository};

    // Integration test: requires a SQLite database reachable through
    // TEST_CONNECTION_STRING; migrations are applied at the start of the run.
    #[tokio::test]
    async fn event_repositories() {
        let pool = default_sqlite_pool(TEST_CONNECTION_STRING).await;
        sqlx::migrate!().run(&pool).await.unwrap();

        // A fresh aggregate id has no events.
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: SqliteEventRepository =
            SqliteEventRepository::new(pool.clone()).with_streaming_channel_size(1);
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert!(events.is_empty());

        // Events with sequences 1 and 2 persist and round-trip.
        event_repo
            .insert_events::<TestAggregate>(&[
                test_event_envelope(&id, 1, TestEvent::Created(Created { id: id.clone() })),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::Tested(Tested {
                        test_name: "a test was run".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap();
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());
        events.iter().for_each(|e| assert_eq!(&id, &e.aggregate_id));

        // Reusing sequence 2 fails with an optimistic-lock error, and the whole
        // batch (including the otherwise-valid sequence 3) is rejected.
        let result = event_repo
            .insert_events::<TestAggregate>(&[
                test_event_envelope(
                    &id,
                    3,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "this should not persist".to_string(),
                    }),
                ),
                test_event_envelope(
                    &id,
                    2,
                    TestEvent::SomethingElse(SomethingElse {
                        description: "bad sequence number".to_string(),
                    }),
                ),
            ])
            .await
            .unwrap_err();
        match result {
            SqliteAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {result}"),
        };

        // The failed batch left the store untouched: still exactly 2 events.
        let events = event_repo.get_events::<TestAggregate>(&id).await.unwrap();
        assert_eq!(2, events.len());

        verify_replay_stream(&id, event_repo).await;
    }

    // Confirms both streaming APIs replay the two events persisted above.
    // The channel size of 1 (set by the caller) exercises feed back-pressure.
    async fn verify_replay_stream(id: &str, event_repo: SqliteEventRepository) {
        let mut stream = event_repo.stream_events::<TestAggregate>(id).await.unwrap();
        let mut found_in_stream = 0;
        while (stream.next::<TestAggregate>(&[]).await).is_some() {
            found_in_stream += 1;
        }
        assert_eq!(found_in_stream, 2);

        // The all-events stream may also contain events from other test runs,
        // so only a lower bound can be asserted.
        let mut stream = event_repo
            .stream_all_events::<TestAggregate>()
            .await
            .unwrap();
        let mut found_in_stream = 0;
        while (stream.next::<TestAggregate>(&[]).await).is_some() {
            found_in_stream += 1;
        }
        assert!(found_in_stream >= 2);
    }

    // Integration test covering snapshot insert, update, and the
    // optimistic-lock guard on stale snapshot versions.
    #[tokio::test]
    async fn snapshot_repositories() {
        let pool = default_sqlite_pool(TEST_CONNECTION_STRING).await;
        sqlx::migrate!().run(&pool).await.unwrap();

        // A fresh aggregate id has no snapshot.
        let id = uuid::Uuid::new_v4().to_string();
        let event_repo: SqliteEventRepository = SqliteEventRepository::new(pool.clone());
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(None, snapshot);

        // Insert the first snapshot (version 1) with no accompanying events.
        let test_description = "some test snapshot here".to_string();
        let test_tests = vec!["testA".to_string(), "testB".to_string()];
        event_repo
            .insert::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                1,
                &[],
            )
            .await
            .unwrap();

        // last_sequence is 0 because no events were persisted alongside it.
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                1,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: test_description.clone(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );

        // Advancing the snapshot to version 2 succeeds (prior version was 1).
        event_repo
            .update::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap();

        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );

        // A second update claiming version 2 again (expecting prior version 1)
        // must fail the optimistic-lock check.
        let result = event_repo
            .update::<TestAggregate>(
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should not be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap(),
                id.clone(),
                2,
                &[],
            )
            .await
            .unwrap_err();
        match result {
            SqliteAggregateError::OptimisticLock => {}
            _ => panic!("invalid error result found during insert: {result}"),
        };

        // The stored snapshot is unchanged after the failed update.
        let snapshot = event_repo.get_snapshot::<TestAggregate>(&id).await.unwrap();
        assert_eq!(
            Some(snapshot_context(
                id.clone(),
                0,
                2,
                serde_json::to_value(TestAggregate {
                    id: id.clone(),
                    description: "a test description that should be saved".to_string(),
                    tests: test_tests.clone(),
                })
                .unwrap()
            )),
            snapshot
        );
    }
}