1use std::fmt::Write;
2
3use async_trait::async_trait;
4use bb8_postgres::{
5 bb8::Pool,
6 tokio_postgres::{
7 tls::{MakeTlsConnect, TlsConnect},
8 types::ToSql,
9 IsolationLevel, Socket,
10 },
11 PostgresConnectionManager,
12};
13use serde::{de::DeserializeOwned, Serialize};
14use thalo::{
15 aggregate::{Aggregate, TypeId},
16 event::{AggregateEventEnvelope, EventType},
17 event_store::EventStore,
18};
19
20use crate::error::Error;
21
/// Prefix of the outbox insert; `create_insert_outbox_events_query` appends one `($n)`
/// placeholder tuple per event id.
const INSERT_OUTBOX_EVENTS_QUERY: &str = include_str!("queries/insert_outbox_events.sql");
/// Bound with `$1` = aggregate type id and `$2` = stringified aggregate id; must return exactly
/// one row whose first column is a nullable `i64` sequence (read via `query_one`).
const LOAD_AGGREGATE_SEQUENCE_QUERY: &str = include_str!("queries/load_aggregate_sequence.sql");
/// Bound with `$1` = aggregate type id and `$2` = optional stringified aggregate id (`NULL` when
/// no id is given). Rows are decoded by column position:
/// id, created_at, aggregate_type, aggregate_id, sequence, event JSON.
const LOAD_EVENTS_QUERY: &str = include_str!("queries/load_events.sql");
/// Bound with `$1` = the requested event ids joined into one comma-separated text value.
/// Rows use the same column layout as `LOAD_EVENTS_QUERY`.
const LOAD_EVENTS_BY_ID_QUERY: &str = include_str!("queries/load_events_by_id.sql");
/// `INSERT INTO "event" (...) VALUES ` prefix; `create_insert_events_query` appends the
/// placeholder tuples and the `RETURNING "id"` clause.
const SAVE_EVENTS_QUERY: &str = include_str!("queries/save_events.sql");
27
28#[derive(Clone)]
30pub struct PgEventStore<Tls>
31where
32 Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
33 <Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
34 <Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
35 <<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
36{
37 pool: Pool<PostgresConnectionManager<Tls>>,
38}
39
40impl<Tls> PgEventStore<Tls>
41where
42 Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
43 <Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
44 <Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
45 <<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
46{
47 pub async fn connect(
49 uri: impl ToString,
50 tls: Tls,
51 ) -> Result<Self, bb8_postgres::tokio_postgres::Error> {
52 let manager = PostgresConnectionManager::new_from_stringlike(uri, tls)?;
53 let pool = Pool::builder().build(manager).await?;
54
55 Ok(Self { pool })
56 }
57}
58
59#[async_trait]
60impl<Tls> EventStore for PgEventStore<Tls>
61where
62 Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
63 <Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
64 <Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
65 <<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
66{
67 type Error = Error;
68
69 async fn load_events<A>(
70 &self,
71 id: Option<&<A as Aggregate>::ID>,
72 ) -> Result<Vec<AggregateEventEnvelope<A>>, Self::Error>
73 where
74 A: Aggregate,
75 <A as Aggregate>::Event: DeserializeOwned,
76 {
77 let conn = self.pool.get().await.map_err(Error::GetDbPoolConnection)?;
78
79 let rows = conn
80 .query(
81 LOAD_EVENTS_QUERY,
82 &[&<A as TypeId>::type_id(), &id.map(|id| id.to_string())],
83 )
84 .await?;
85
86 Ok(rows
87 .into_iter()
88 .map(|row| {
89 let event_id = row.get::<_, i64>(0) as usize;
90
91 let event_json = row.get(5);
92 let event = serde_json::from_value(event_json)
93 .map_err(|err| Error::DeserializeDbEvent(event_id, err))?;
94
95 Result::<_, Self::Error>::Ok(AggregateEventEnvelope::<A> {
96 id: event_id,
97 created_at: row.get(1),
98 aggregate_type: row.get(2),
99 aggregate_id: row.get(3),
100 sequence: row.get::<_, i64>(4) as usize,
101 event,
102 })
103 })
104 .collect::<Result<Vec<_>, _>>()?)
105 }
106
107 async fn load_events_by_id<A>(
108 &self,
109 ids: &[usize],
110 ) -> Result<Vec<AggregateEventEnvelope<A>>, Self::Error>
111 where
112 A: Aggregate,
113 <A as Aggregate>::Event: DeserializeOwned,
114 {
115 let conn = self.pool.get().await.map_err(Error::GetDbPoolConnection)?;
116
117 let rows = conn
118 .query(
119 LOAD_EVENTS_BY_ID_QUERY,
120 &[&ids
121 .iter()
122 .map(|id| id.to_string())
123 .collect::<Vec<_>>()
124 .join(",")],
125 )
126 .await?;
127
128 Ok(rows
129 .into_iter()
130 .map(|row| {
131 let event_id = row.get::<_, i64>(0) as usize;
132
133 let event_json = row.get(5);
134 let event = serde_json::from_value(event_json)
135 .map_err(|err| Error::DeserializeDbEvent(event_id, err))?;
136
137 Result::<_, Self::Error>::Ok(AggregateEventEnvelope::<A> {
138 id: event_id,
139 created_at: row.get(1),
140 aggregate_type: row.get(2),
141 aggregate_id: row.get(3),
142 sequence: row.get::<_, i64>(4) as usize,
143 event,
144 })
145 })
146 .collect::<Result<Vec<_>, _>>()?)
147 }
148
149 async fn load_aggregate_sequence<A>(
150 &self,
151 id: &<A as Aggregate>::ID,
152 ) -> Result<Option<usize>, Self::Error>
153 where
154 A: Aggregate,
155 {
156 let conn = self.pool.get().await.map_err(Error::GetDbPoolConnection)?;
157
158 let row = conn
159 .query_one(
160 LOAD_AGGREGATE_SEQUENCE_QUERY,
161 &[&<A as TypeId>::type_id(), &id.to_string()],
162 )
163 .await?;
164
165 Ok(row
166 .get::<_, Option<i64>>(0)
167 .map(|sequence| sequence as usize))
168 }
169
170 async fn save_events<A>(
171 &self,
172 id: &<A as Aggregate>::ID,
173 events: &[<A as Aggregate>::Event],
174 ) -> Result<Vec<usize>, Self::Error>
175 where
176 A: Aggregate,
177 <A as Aggregate>::Event: Serialize,
178 {
179 if events.is_empty() {
180 return Ok(vec![]);
181 }
182
183 let sequence = self.load_aggregate_sequence::<A>(id).await?;
184
185 let (query, values) = create_insert_events_query::<A>(id, sequence, events)?;
186
187 let mut conn = self.pool.get().await.map_err(Error::GetDbPoolConnection)?;
188
189 let transaction = conn
190 .build_transaction()
191 .isolation_level(IsolationLevel::ReadCommitted)
192 .start()
193 .await?;
194
195 let rows = transaction
196 .query(
197 &query,
198 &values
199 .iter()
200 .map(|value| value.as_ref() as &(dyn ToSql + Sync))
201 .collect::<Vec<_>>(),
202 )
203 .await?;
204
205 let event_ids: Vec<_> = rows
206 .into_iter()
207 .map(|row| row.get::<_, i64>(0) as usize)
208 .collect();
209 let query = create_insert_outbox_events_query(&event_ids);
210
211 transaction
212 .execute(
213 &query,
214 &event_ids
215 .iter()
216 .map(|event_id| *event_id as i64)
217 .collect::<Vec<_>>()
218 .iter()
219 .map(|event_id| event_id as &(dyn ToSql + Sync))
220 .collect::<Vec<_>>(),
221 )
222 .await?;
223
224 transaction.commit().await?;
225
226 Ok(event_ids)
227 }
228}
229
230fn create_insert_events_query<A>(
231 id: &<A as Aggregate>::ID,
232 sequence: Option<usize>,
233 events: &[<A as Aggregate>::Event],
234) -> Result<(String, Vec<Box<dyn ToSql + Send + Sync>>), Error>
235where
236 A: Aggregate,
237 <A as Aggregate>::Event: Serialize,
238{
239 let mut query = SAVE_EVENTS_QUERY.to_string();
240 let mut values: Vec<Box<dyn ToSql + Send + Sync>> = Vec::with_capacity(events.len() * 5);
241 let event_values = events
242 .iter()
243 .enumerate()
244 .map(|(index, event)| {
245 Result::<_, Error>::Ok((
246 Box::new(<A as TypeId>::type_id()),
247 Box::new(id.to_string()),
248 Box::new(sequence.map(|sequence| sequence + index + 1).unwrap_or(0) as i64),
249 Box::new(event.event_type()),
250 Box::new(serde_json::to_value(event).map_err(Error::SerializeEvent)?),
251 ))
252 })
253 .collect::<Result<Vec<_>, _>>()?;
254 let values_len = event_values.len();
255 for (index, (aggregate_type, aggregate_id, sequence, event_type, event_data)) in
256 event_values.into_iter().enumerate()
257 {
258 write!(
259 query,
260 "(${}, ${}, ${}, ${}, ${})",
261 values.len() + 1,
262 values.len() + 2,
263 values.len() + 3,
264 values.len() + 4,
265 values.len() + 5
266 )
267 .unwrap();
268 if index < values_len - 1 {
269 write!(query, ", ").unwrap();
270 }
271
272 values.extend([
273 aggregate_type,
274 aggregate_id,
275 sequence,
276 event_type,
277 event_data,
278 ] as [Box<dyn ToSql + Send + Sync>; 5]);
279 }
280 write!(query, r#" RETURNING "id""#).unwrap();
281
282 Ok((query, values))
283}
284
285fn create_insert_outbox_events_query(event_ids: &[usize]) -> String {
286 INSERT_OUTBOX_EVENTS_QUERY.to_string()
287 + &(1..event_ids.len() + 1)
288 .into_iter()
289 .map(|index| format!("(${})", index))
290 .collect::<Vec<_>>()
291 .join(", ")
292}
293
#[cfg(test)]
mod tests {
    use thalo::tests_cfg::bank_account::{
        BankAccount, BankAccountEvent, DepositedFundsEvent, OpenedAccountEvent,
    };

    /// Two events must yield two consecutive five-parameter tuples (`$1..$5`, `$6..$10`)
    /// after the `SAVE_EVENTS_QUERY` prefix, followed by `RETURNING "id"`.
    #[test]
    fn insert_events_query() -> Result<(), super::Error> {
        let id = "abc123".to_string();

        // Only the generated SQL text is checked here; the bound values are discarded.
        let (query, _) = super::create_insert_events_query::<BankAccount>(
            &id,
            None,
            &[
                BankAccountEvent::OpenedAccount(OpenedAccountEvent { balance: 0.0 }),
                BankAccountEvent::DepositedFunds(DepositedFundsEvent { amount: 25.0 }),
            ],
        )?;

        // The expected literal also pins the contents of queries/save_events.sql.
        assert_eq!(
            query,
            r#"INSERT INTO "event" (
 "aggregate_type",
 "aggregate_id",
 "sequence",
 "event_type",
 "event_data"
) VALUES ($1, $2, $3, $4, $5), ($6, $7, $8, $9, $10) RETURNING "id""#
        );

        Ok(())
    }

    /// Each event id gets its own single-parameter tuple, numbered from `$1`.
    #[test]
    fn insert_outbox_events_query() {
        let query = super::create_insert_outbox_events_query(&[0, 1, 2]);

        assert_eq!(
            query,
            r#"INSERT INTO "outbox" ("id") VALUES ($1), ($2), ($3)"#
        );
    }
}