thalo_postgres/
event_store.rs

1use std::fmt::Write;
2
3use async_trait::async_trait;
4use bb8_postgres::{
5    bb8::Pool,
6    tokio_postgres::{
7        tls::{MakeTlsConnect, TlsConnect},
8        types::ToSql,
9        IsolationLevel, Socket,
10    },
11    PostgresConnectionManager,
12};
13use serde::{de::DeserializeOwned, Serialize};
14use thalo::{
15    aggregate::{Aggregate, TypeId},
16    event::{AggregateEventEnvelope, EventType},
17    event_store::EventStore,
18};
19
20use crate::error::Error;
21
22const INSERT_OUTBOX_EVENTS_QUERY: &str = include_str!("queries/insert_outbox_events.sql");
23const LOAD_AGGREGATE_SEQUENCE_QUERY: &str = include_str!("queries/load_aggregate_sequence.sql");
24const LOAD_EVENTS_QUERY: &str = include_str!("queries/load_events.sql");
25const LOAD_EVENTS_BY_ID_QUERY: &str = include_str!("queries/load_events_by_id.sql");
26const SAVE_EVENTS_QUERY: &str = include_str!("queries/save_events.sql");
27
28/// Postgres event store implementation.
29#[derive(Clone)]
30pub struct PgEventStore<Tls>
31where
32    Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
33    <Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
34    <Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
35    <<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
36{
37    pool: Pool<PostgresConnectionManager<Tls>>,
38}
39
40impl<Tls> PgEventStore<Tls>
41where
42    Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
43    <Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
44    <Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
45    <<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
46{
47    /// Connects to an event store.
48    pub async fn connect(
49        uri: impl ToString,
50        tls: Tls,
51    ) -> Result<Self, bb8_postgres::tokio_postgres::Error> {
52        let manager = PostgresConnectionManager::new_from_stringlike(uri, tls)?;
53        let pool = Pool::builder().build(manager).await?;
54
55        Ok(Self { pool })
56    }
57}
58
59#[async_trait]
60impl<Tls> EventStore for PgEventStore<Tls>
61where
62    Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
63    <Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
64    <Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
65    <<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
66{
67    type Error = Error;
68
69    async fn load_events<A>(
70        &self,
71        id: Option<&<A as Aggregate>::ID>,
72    ) -> Result<Vec<AggregateEventEnvelope<A>>, Self::Error>
73    where
74        A: Aggregate,
75        <A as Aggregate>::Event: DeserializeOwned,
76    {
77        let conn = self.pool.get().await.map_err(Error::GetDbPoolConnection)?;
78
79        let rows = conn
80            .query(
81                LOAD_EVENTS_QUERY,
82                &[&<A as TypeId>::type_id(), &id.map(|id| id.to_string())],
83            )
84            .await?;
85
86        Ok(rows
87            .into_iter()
88            .map(|row| {
89                let event_id = row.get::<_, i64>(0) as usize;
90
91                let event_json = row.get(5);
92                let event = serde_json::from_value(event_json)
93                    .map_err(|err| Error::DeserializeDbEvent(event_id, err))?;
94
95                Result::<_, Self::Error>::Ok(AggregateEventEnvelope::<A> {
96                    id: event_id,
97                    created_at: row.get(1),
98                    aggregate_type: row.get(2),
99                    aggregate_id: row.get(3),
100                    sequence: row.get::<_, i64>(4) as usize,
101                    event,
102                })
103            })
104            .collect::<Result<Vec<_>, _>>()?)
105    }
106
107    async fn load_events_by_id<A>(
108        &self,
109        ids: &[usize],
110    ) -> Result<Vec<AggregateEventEnvelope<A>>, Self::Error>
111    where
112        A: Aggregate,
113        <A as Aggregate>::Event: DeserializeOwned,
114    {
115        let conn = self.pool.get().await.map_err(Error::GetDbPoolConnection)?;
116
117        let rows = conn
118            .query(
119                LOAD_EVENTS_BY_ID_QUERY,
120                &[&ids
121                    .iter()
122                    .map(|id| id.to_string())
123                    .collect::<Vec<_>>()
124                    .join(",")],
125            )
126            .await?;
127
128        Ok(rows
129            .into_iter()
130            .map(|row| {
131                let event_id = row.get::<_, i64>(0) as usize;
132
133                let event_json = row.get(5);
134                let event = serde_json::from_value(event_json)
135                    .map_err(|err| Error::DeserializeDbEvent(event_id, err))?;
136
137                Result::<_, Self::Error>::Ok(AggregateEventEnvelope::<A> {
138                    id: event_id,
139                    created_at: row.get(1),
140                    aggregate_type: row.get(2),
141                    aggregate_id: row.get(3),
142                    sequence: row.get::<_, i64>(4) as usize,
143                    event,
144                })
145            })
146            .collect::<Result<Vec<_>, _>>()?)
147    }
148
149    async fn load_aggregate_sequence<A>(
150        &self,
151        id: &<A as Aggregate>::ID,
152    ) -> Result<Option<usize>, Self::Error>
153    where
154        A: Aggregate,
155    {
156        let conn = self.pool.get().await.map_err(Error::GetDbPoolConnection)?;
157
158        let row = conn
159            .query_one(
160                LOAD_AGGREGATE_SEQUENCE_QUERY,
161                &[&<A as TypeId>::type_id(), &id.to_string()],
162            )
163            .await?;
164
165        Ok(row
166            .get::<_, Option<i64>>(0)
167            .map(|sequence| sequence as usize))
168    }
169
170    async fn save_events<A>(
171        &self,
172        id: &<A as Aggregate>::ID,
173        events: &[<A as Aggregate>::Event],
174    ) -> Result<Vec<usize>, Self::Error>
175    where
176        A: Aggregate,
177        <A as Aggregate>::Event: Serialize,
178    {
179        if events.is_empty() {
180            return Ok(vec![]);
181        }
182
183        let sequence = self.load_aggregate_sequence::<A>(id).await?;
184
185        let (query, values) = create_insert_events_query::<A>(id, sequence, events)?;
186
187        let mut conn = self.pool.get().await.map_err(Error::GetDbPoolConnection)?;
188
189        let transaction = conn
190            .build_transaction()
191            .isolation_level(IsolationLevel::ReadCommitted)
192            .start()
193            .await?;
194
195        let rows = transaction
196            .query(
197                &query,
198                &values
199                    .iter()
200                    .map(|value| value.as_ref() as &(dyn ToSql + Sync))
201                    .collect::<Vec<_>>(),
202            )
203            .await?;
204
205        let event_ids: Vec<_> = rows
206            .into_iter()
207            .map(|row| row.get::<_, i64>(0) as usize)
208            .collect();
209        let query = create_insert_outbox_events_query(&event_ids);
210
211        transaction
212            .execute(
213                &query,
214                &event_ids
215                    .iter()
216                    .map(|event_id| *event_id as i64)
217                    .collect::<Vec<_>>()
218                    .iter()
219                    .map(|event_id| event_id as &(dyn ToSql + Sync))
220                    .collect::<Vec<_>>(),
221            )
222            .await?;
223
224        transaction.commit().await?;
225
226        Ok(event_ids)
227    }
228}
229
230fn create_insert_events_query<A>(
231    id: &<A as Aggregate>::ID,
232    sequence: Option<usize>,
233    events: &[<A as Aggregate>::Event],
234) -> Result<(String, Vec<Box<dyn ToSql + Send + Sync>>), Error>
235where
236    A: Aggregate,
237    <A as Aggregate>::Event: Serialize,
238{
239    let mut query = SAVE_EVENTS_QUERY.to_string();
240    let mut values: Vec<Box<dyn ToSql + Send + Sync>> = Vec::with_capacity(events.len() * 5);
241    let event_values = events
242        .iter()
243        .enumerate()
244        .map(|(index, event)| {
245            Result::<_, Error>::Ok((
246                Box::new(<A as TypeId>::type_id()),
247                Box::new(id.to_string()),
248                Box::new(sequence.map(|sequence| sequence + index + 1).unwrap_or(0) as i64),
249                Box::new(event.event_type()),
250                Box::new(serde_json::to_value(event).map_err(Error::SerializeEvent)?),
251            ))
252        })
253        .collect::<Result<Vec<_>, _>>()?;
254    let values_len = event_values.len();
255    for (index, (aggregate_type, aggregate_id, sequence, event_type, event_data)) in
256        event_values.into_iter().enumerate()
257    {
258        write!(
259            query,
260            "(${}, ${}, ${}, ${}, ${})",
261            values.len() + 1,
262            values.len() + 2,
263            values.len() + 3,
264            values.len() + 4,
265            values.len() + 5
266        )
267        .unwrap();
268        if index < values_len - 1 {
269            write!(query, ", ").unwrap();
270        }
271
272        values.extend([
273            aggregate_type,
274            aggregate_id,
275            sequence,
276            event_type,
277            event_data,
278        ] as [Box<dyn ToSql + Send + Sync>; 5]);
279    }
280    write!(query, r#" RETURNING "id""#).unwrap();
281
282    Ok((query, values))
283}
284
285fn create_insert_outbox_events_query(event_ids: &[usize]) -> String {
286    INSERT_OUTBOX_EVENTS_QUERY.to_string()
287        + &(1..event_ids.len() + 1)
288            .into_iter()
289            .map(|index| format!("(${})", index))
290            .collect::<Vec<_>>()
291            .join(", ")
292}
293
#[cfg(test)]
mod tests {
    use thalo::tests_cfg::bank_account::{
        BankAccount, BankAccountEvent, DepositedFundsEvent, OpenedAccountEvent,
    };

    use super::{create_insert_events_query, create_insert_outbox_events_query, Error};

    /// Two events must produce two five-placeholder value groups numbered
    /// consecutively, followed by the `RETURNING` clause.
    #[test]
    fn insert_events_query() -> Result<(), Error> {
        let aggregate_id = String::from("abc123");
        let events = [
            BankAccountEvent::OpenedAccount(OpenedAccountEvent { balance: 0.0 }),
            BankAccountEvent::DepositedFunds(DepositedFundsEvent { amount: 25.0 }),
        ];

        let (sql, _values) =
            create_insert_events_query::<BankAccount>(&aggregate_id, None, &events)?;

        let expected = r#"INSERT INTO "event" (
  "aggregate_type",
  "aggregate_id",
  "sequence",
  "event_type",
  "event_data"
) VALUES ($1, $2, $3, $4, $5), ($6, $7, $8, $9, $10) RETURNING "id""#;
        assert_eq!(sql, expected);

        Ok(())
    }

    /// Three event ids must produce three single-placeholder value groups.
    #[test]
    fn insert_outbox_events_query() {
        let event_ids = [0, 1, 2];

        let sql = create_insert_outbox_events_query(&event_ids);

        let expected = r#"INSERT INTO "outbox" ("id") VALUES ($1), ($2), ($3)"#;
        assert_eq!(sql, expected);
    }
}
336}