use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{postgres::PgListener, PgPool};
use tokio::{
    sync::{
        broadcast::{self, error::RecvError},
        RwLock,
    },
    task,
};
use tracing::instrument;
#[cfg(feature = "otel")]
use tracing_opentelemetry::OpenTelemetrySpanExt;

use std::{
    collections::HashMap,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
};

use crate::{
    balance::BalanceDetails, transaction::Transaction, AccountId, JournalId, SqlxLedgerError,
};

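/// Configuration for constructing an [`EventSubscriber`].
///
/// - `buffer`: capacity of the underlying `tokio::sync::broadcast` channels.
/// - `close_on_lag`: when `true`, the subscriber marks itself closed and drops all
///   downstream receivers as soon as the internal receiver reports `RecvError::Lagged`.
/// - `after_id`: if set, previously persisted events with an id greater than this are
///   replayed from the `sqlx_ledger_events` table before live notifications are streamed.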
pub struct EventSubscriberOpts {
    pub close_on_lag: bool,
    pub buffer: usize,
    pub after_id: Option<SqlxLedgerEventId>,
}
impl Default for EventSubscriberOpts {
    fn default() -> Self {
        Self {
            close_on_lag: false,
            buffer: 100,
            after_id: None,
        }
    }
}

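/// Fans out ledger events received via Postgres `LISTEN`/`NOTIFY` to per-journal and
/// per-`(journal, account)` broadcast channels, plus a catch-all stream.
///
/// A minimal usage sketch (hedged: how an `EventSubscriber` is obtained depends on the
/// ledger handle exposed by the rest of the crate, so the `subscriber` binding below is
/// hypothetical):
///
/// ```ignore
/// // Hypothetical: `subscriber` is an `EventSubscriber` handed out by the ledger.
/// let mut events = subscriber.all()?;
/// while let Ok(event) = events.recv().await {
///     println!("{:?} at {}", event.r#type, event.recorded_at);
/// }
/// ```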
#[derive(Debug, Clone)]
pub struct EventSubscriber {
    buffer: usize,
    closed: Arc<AtomicBool>,
    #[allow(clippy::type_complexity)]
    balance_receivers:
        Arc<RwLock<HashMap<(JournalId, AccountId), broadcast::Sender<SqlxLedgerEvent>>>>,
    journal_receivers: Arc<RwLock<HashMap<JournalId, broadcast::Sender<SqlxLedgerEvent>>>>,
    all: Arc<broadcast::Receiver<SqlxLedgerEvent>>,
}

impl EventSubscriber {
    pub(crate) async fn connect(
        pool: &PgPool,
        EventSubscriberOpts {
            close_on_lag,
            buffer,
            after_id: start_id,
        }: EventSubscriberOpts,
    ) -> Result<Self, SqlxLedgerError> {
        let closed = Arc::new(AtomicBool::new(false));
        let mut incoming = subscribe(pool.clone(), Arc::clone(&closed), buffer, start_id).await?;
        // Keep a receiver around so `all()` can hand out fresh receivers via `resubscribe()`.
        let all = Arc::new(incoming.resubscribe());
        let balance_receivers = Arc::new(RwLock::new(HashMap::<
            (JournalId, AccountId),
            broadcast::Sender<SqlxLedgerEvent>,
        >::new()));
        let journal_receivers = Arc::new(RwLock::new(HashMap::<
            JournalId,
            broadcast::Sender<SqlxLedgerEvent>,
        >::new()));
        let inner_balance_receivers = Arc::clone(&balance_receivers);
        let inner_journal_receivers = Arc::clone(&journal_receivers);
        let inner_closed = Arc::clone(&closed);
        // Dispatcher task: fan incoming events out to the matching journal and
        // (journal, account) channels. The routing tables are cleared when the source
        // closes, or on lag when `close_on_lag` is set.
        tokio::spawn(async move {
            loop {
                match incoming.recv().await {
                    Ok(event) => {
                        let journal_id = event.journal_id();
                        if let Some(journal_receivers) =
                            inner_journal_receivers.read().await.get(&journal_id)
                        {
                            let _ = journal_receivers.send(event.clone());
                        }
                        if let Some(account_id) = event.account_id() {
                            let receivers = inner_balance_receivers.read().await;
                            if let Some(receiver) = receivers.get(&(journal_id, account_id)) {
                                let _ = receiver.send(event);
                            }
                        }
                    }
                    Err(RecvError::Lagged(_)) => {
                        if close_on_lag {
                            inner_closed.store(true, Ordering::SeqCst);
                            inner_balance_receivers.write().await.clear();
                            inner_journal_receivers.write().await.clear();
                        }
                    }
                    Err(RecvError::Closed) => {
                        tracing::warn!("Event subscriber closed");
                        inner_closed.store(true, Ordering::SeqCst);
                        inner_balance_receivers.write().await.clear();
                        inner_journal_receivers.write().await.clear();
                        break;
                    }
                }
            }
        });
        Ok(Self {
            buffer,
            closed,
            balance_receivers,
            journal_receivers,
            all,
        })
    }

    /// Subscribe to every event, across all journals and accounts.
    pub fn all(&self) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
        let recv = self.all.resubscribe();
        if self.closed.load(Ordering::SeqCst) {
            return Err(SqlxLedgerError::EventSubscriberClosed);
        }
        Ok(recv)
    }

    /// Subscribe to balance updates for the given account within the given journal.
    pub async fn account_balance(
        &self,
        journal_id: JournalId,
        account_id: AccountId,
    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
        let mut listeners = self.balance_receivers.write().await;
        let mut ret = None;
        let sender = listeners
            .entry((journal_id, account_id))
            .or_insert_with(|| {
                let (sender, recv) = broadcast::channel(self.buffer);
                ret = Some(recv);
                sender
            });
        let ret = ret.unwrap_or_else(|| sender.subscribe());
        if self.closed.load(Ordering::SeqCst) {
            listeners.remove(&(journal_id, account_id));
            return Err(SqlxLedgerError::EventSubscriberClosed);
        }
        Ok(ret)
    }

    /// Subscribe to all events recorded against the given journal.
    pub async fn journal(
        &self,
        journal_id: JournalId,
    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
        let mut listeners = self.journal_receivers.write().await;
        let mut ret = None;
        let sender = listeners.entry(journal_id).or_insert_with(|| {
            let (sender, recv) = broadcast::channel(self.buffer);
            ret = Some(recv);
            sender
        });
        let ret = ret.unwrap_or_else(|| sender.subscribe());
        if self.closed.load(Ordering::SeqCst) {
            listeners.remove(&journal_id);
            return Err(SqlxLedgerError::EventSubscriberClosed);
        }
        Ok(ret)
    }
}

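/// Identifier of a row in the `sqlx_ledger_events` table. Ids are expected to be
/// assigned sequentially: the subscriber treats a jump of more than one as a gap and
/// reloads from the table. `SqlxLedgerEventId::BEGIN` (`0`) can be passed as `after_id`
/// to replay all persisted events from the beginning.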
#[derive(
    sqlx::Type, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Hash, Copy,
)]
#[serde(transparent)]
#[sqlx(transparent)]
pub struct SqlxLedgerEventId(i64);
impl SqlxLedgerEventId {
    pub const BEGIN: Self = Self(0);
}

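/// A single ledger event as delivered to subscribers: the persisted event id, the typed
/// payload, the event type discriminant, and the time it was recorded. With the `otel`
/// feature enabled, the tracing context active when the notification was received is
/// carried along so downstream spans can be linked to it.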
#[derive(Debug, Clone, Deserialize)]
#[serde(try_from = "EventRaw")]
pub struct SqlxLedgerEvent {
    pub id: SqlxLedgerEventId,
    pub data: SqlxLedgerEventData,
    pub r#type: SqlxLedgerEventType,
    pub recorded_at: DateTime<Utc>,
    #[cfg(feature = "otel")]
    pub otel_context: opentelemetry::Context,
}

impl SqlxLedgerEvent {
    #[cfg(feature = "otel")]
    fn record_otel_context(&mut self) {
        self.otel_context = tracing::Span::current().context();
    }

    #[cfg(not(feature = "otel"))]
    fn record_otel_context(&mut self) {}
}

impl SqlxLedgerEvent {
    /// The journal this event was recorded against.
    pub fn journal_id(&self) -> JournalId {
        match &self.data {
            SqlxLedgerEventData::BalanceUpdated(b) => b.journal_id,
            SqlxLedgerEventData::TransactionCreated(t) => t.journal_id,
            SqlxLedgerEventData::TransactionUpdated(t) => t.journal_id,
        }
    }

    /// The account affected by the event, if it is a balance update.
    pub fn account_id(&self) -> Option<AccountId> {
        match &self.data {
            SqlxLedgerEventData::BalanceUpdated(b) => Some(b.account_id),
            _ => None,
        }
    }
}

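/// The typed payload carried by an event: either the updated balance details or the
/// created / updated transaction.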
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum SqlxLedgerEventData {
    BalanceUpdated(BalanceDetails),
    TransactionCreated(Transaction),
    TransactionUpdated(Transaction),
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlxLedgerEventType {
    BalanceUpdated,
    TransactionCreated,
    TransactionUpdated,
}

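/// Opens a dedicated `LISTEN sqlx_ledger_events` connection and forwards deserialized
/// notifications into a broadcast channel. When `after_id` is given, or a gap in event
/// ids is detected, missed events are re-read from the `sqlx_ledger_events` table before
/// live streaming resumes.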
pub(crate) async fn subscribe(
    pool: PgPool,
    closed: Arc<AtomicBool>,
    buffer: usize,
    after_id: Option<SqlxLedgerEventId>,
) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
    let mut listener = PgListener::connect_with(&pool).await?;
    listener.listen("sqlx_ledger_events").await?;
    let (snd, recv) = broadcast::channel(buffer);
    let mut reload = after_id.is_some();
    task::spawn(async move {
        let mut num_errors = 0;
        let mut last_id = after_id.unwrap_or(SqlxLedgerEventId(0));
        loop {
            // Replay persisted events that have not been forwarded yet (either because an
            // `after_id` was supplied or because a gap was detected in the live stream).
            if reload {
                match sqlx::query!(
                    r#"SELECT json_build_object(
                        'id', id,
                        'type', type,
                        'data', data,
                        'recorded_at', recorded_at
                    ) AS "payload!" FROM sqlx_ledger_events WHERE id > $1 ORDER BY id"#,
                    last_id.0
                )
                .fetch_all(&pool)
                .await
                {
                    Ok(rows) => {
                        num_errors = 0;
                        for row in rows {
                            let event: Result<SqlxLedgerEvent, _> =
                                serde_json::from_value(row.payload);
                            if sqlx_ledger_notification_received(event, &snd, &mut last_id, true)
                                .is_err()
                            {
                                closed.store(true, Ordering::SeqCst);
                                break;
                            }
                        }
                    }
                    Err(e) if num_errors == 0 => {
                        tracing::error!("Error fetching events: {}", e);
                        tokio::time::sleep(std::time::Duration::from_secs(1 << num_errors)).await;
                        num_errors += 1;
                        continue;
                    }
                    _ => {
                        num_errors = 0;
                        continue;
                    }
                }
            }
            if closed.load(Ordering::Relaxed) {
                break;
            }
            // Stream live notifications until the connection drops or a gap is detected.
            while let Ok(notification) = listener.recv().await {
                let event: Result<SqlxLedgerEvent, _> =
                    serde_json::from_str(notification.payload());
                match sqlx_ledger_notification_received(event, &snd, &mut last_id, !reload) {
                    // Gap in event ids: break out so missed events get reloaded from the table.
                    Ok(false) => break,
                    Ok(_) => num_errors = 0,
                    Err(_) => {
                        closed.store(true, Ordering::SeqCst);
                    }
                }
                reload = true;
            }
        }
    });
    Ok(recv)
}

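/// Applies a freshly received (or replayed) event to the outgoing broadcast channel.
///
/// Returns `Ok(true)` when the event was forwarded or was an already-seen duplicate,
/// `Ok(false)` when a gap in event ids was detected (the caller should reload from the
/// table), and an error when deserialization or the broadcast send fails.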
#[instrument(name = "sqlx_ledger.notification_received", skip(sender), err)]
fn sqlx_ledger_notification_received(
    event: Result<SqlxLedgerEvent, serde_json::Error>,
    sender: &broadcast::Sender<SqlxLedgerEvent>,
    last_id: &mut SqlxLedgerEventId,
    ignore_gap: bool,
) -> Result<bool, SqlxLedgerError> {
    let mut event = event?;
    event.record_otel_context();
    let id = event.id;
    // Already forwarded (e.g. seen during a replay): skip silently.
    if id <= *last_id {
        return Ok(true);
    }
    // Non-contiguous id: signal the caller to reload missed events from the table.
    if !ignore_gap && last_id.0 + 1 != id.0 {
        return Ok(false);
    }
    sender.send(event)?;
    *last_id = id;
    Ok(true)
}

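/// The raw JSON shape of an event payload, as delivered by the `NOTIFY` channel and by
/// the `json_build_object` replay query above. `data` is deserialized into the concrete
/// [`SqlxLedgerEventData`] variant selected by `type` in the `TryFrom` impl below.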
#[derive(Deserialize)]
struct EventRaw {
    id: SqlxLedgerEventId,
    data: serde_json::Value,
    r#type: SqlxLedgerEventType,
    recorded_at: DateTime<Utc>,
}

impl TryFrom<EventRaw> for SqlxLedgerEvent {
    type Error = serde_json::Error;

    fn try_from(value: EventRaw) -> Result<Self, Self::Error> {
        let data = match value.r#type {
            SqlxLedgerEventType::BalanceUpdated => {
                SqlxLedgerEventData::BalanceUpdated(serde_json::from_value(value.data)?)
            }
            SqlxLedgerEventType::TransactionCreated => {
                SqlxLedgerEventData::TransactionCreated(serde_json::from_value(value.data)?)
            }
            SqlxLedgerEventType::TransactionUpdated => {
                SqlxLedgerEventData::TransactionUpdated(serde_json::from_value(value.data)?)
            }
        };

        Ok(SqlxLedgerEvent {
            id: value.id,
            data,
            r#type: value.r#type,
            recorded_at: value.recorded_at,
            #[cfg(feature = "otel")]
            otel_context: tracing::Span::current().context(),
        })
    }
}