sqlx_ledger/
event.rs

1//! Use [ledger.events()](crate::SqlxLedger::events()) to subscribe to events triggered by changes to the ledger.
2use chrono::{DateTime, Utc};
3use futures::StreamExt;
4use serde::{de::Error as _, Deserialize, Serialize};
5use sqlx::{postgres::PgListener, PgPool};
6use std::time::Duration;
7use tokio::{
8    sync::{
9        broadcast::{self, error::RecvError},
10        RwLock,
11    },
12    task,
13};
14use tracing::instrument;
15#[cfg(feature = "otel")]
16use tracing_opentelemetry::OpenTelemetrySpanExt;
17
18use std::{
19    collections::HashMap,
20    sync::{
21        atomic::{AtomicBool, Ordering},
22        Arc,
23    },
24};
25
26use crate::{
27    balance::BalanceDetails, transaction::Transaction, AccountId, JournalId, SqlxLedgerError,
28};
29
30/// Options when initializing the EventSubscriber
31pub struct EventSubscriberOpts {
32    pub close_on_lag: bool,
33    pub buffer: usize,
34    pub after_id: Option<SqlxLedgerEventId>,
35}
36impl Default for EventSubscriberOpts {
37    fn default() -> Self {
38        Self {
39            close_on_lag: false,
40            buffer: 100,
41            after_id: None,
42        }
43    }
44}
45
46/// Contains fields to store & manage various ledger-related `SqlxLedgerEvent` event receivers.
47#[derive(Debug, Clone)]
48pub struct EventSubscriber {
49    buffer: usize,
50    closed: Arc<AtomicBool>,
51    #[allow(clippy::type_complexity)]
52    balance_receivers:
53        Arc<RwLock<HashMap<(JournalId, AccountId), broadcast::Sender<SqlxLedgerEvent>>>>,
54    journal_receivers: Arc<RwLock<HashMap<JournalId, broadcast::Sender<SqlxLedgerEvent>>>>,
55    all: Arc<broadcast::Receiver<SqlxLedgerEvent>>,
56}
57
58impl EventSubscriber {
59    pub(crate) async fn connect(
60        pool: &PgPool,
61        EventSubscriberOpts {
62            close_on_lag,
63            buffer,
64            after_id: start_id,
65        }: EventSubscriberOpts,
66    ) -> Result<Self, SqlxLedgerError> {
67        let closed = Arc::new(AtomicBool::new(false));
68        let mut incoming = subscribe(pool.clone(), Arc::clone(&closed), buffer, start_id).await?;
69        let all = Arc::new(incoming.resubscribe());
70        let balance_receivers = Arc::new(RwLock::new(HashMap::<
71            (JournalId, AccountId),
72            broadcast::Sender<SqlxLedgerEvent>,
73        >::new()));
74        let journal_receivers = Arc::new(RwLock::new(HashMap::<
75            JournalId,
76            broadcast::Sender<SqlxLedgerEvent>,
77        >::new()));
78        let inner_balance_receivers = Arc::clone(&balance_receivers);
79        let inner_journal_receivers = Arc::clone(&journal_receivers);
80        let inner_closed = Arc::clone(&closed);
81        tokio::spawn(async move {
82            loop {
83                match incoming.recv().await {
84                    Ok(event) => {
85                        let journal_id = event.journal_id();
86                        if let Some(journal_receivers) =
87                            inner_journal_receivers.read().await.get(&journal_id)
88                        {
89                            let _ = journal_receivers.send(event.clone());
90                        }
91                        if let Some(account_id) = event.account_id() {
92                            let receivers = inner_balance_receivers.read().await;
93                            if let Some(receiver) = receivers.get(&(journal_id, account_id)) {
94                                let _ = receiver.send(event);
95                            }
96                        }
97                    }
98                    Err(RecvError::Lagged(_)) => {
99                        if close_on_lag {
100                            inner_closed.store(true, Ordering::SeqCst);
101                            inner_balance_receivers.write().await.clear();
102                            inner_journal_receivers.write().await.clear();
103                        }
104                    }
105                    Err(RecvError::Closed) => {
106                        tracing::warn!("Event subscriber closed");
107                        inner_closed.store(true, Ordering::SeqCst);
108                        inner_balance_receivers.write().await.clear();
109                        inner_journal_receivers.write().await.clear();
110                        break;
111                    }
112                }
113            }
114        });
115        Ok(Self {
116            buffer,
117            closed,
118            balance_receivers,
119            journal_receivers,
120            all,
121        })
122    }
123
124    pub fn all(&self) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
125        let recv = self.all.resubscribe();
126        if self.closed.load(Ordering::SeqCst) {
127            return Err(SqlxLedgerError::EventSubscriberClosed);
128        }
129        Ok(recv)
130    }
131
132    pub async fn account_balance(
133        &self,
134        journal_id: JournalId,
135        account_id: AccountId,
136    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
137        let mut listeners = self.balance_receivers.write().await;
138        let mut ret = None;
139        let sender = listeners
140            .entry((journal_id, account_id))
141            .or_insert_with(|| {
142                let (sender, recv) = broadcast::channel(self.buffer);
143                ret = Some(recv);
144                sender
145            });
146        let ret = ret.unwrap_or_else(|| sender.subscribe());
147        if self.closed.load(Ordering::SeqCst) {
148            listeners.remove(&(journal_id, account_id));
149            return Err(SqlxLedgerError::EventSubscriberClosed);
150        }
151        Ok(ret)
152    }
153
154    pub async fn journal(
155        &self,
156        journal_id: JournalId,
157    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
158        let mut listeners = self.journal_receivers.write().await;
159        let mut ret = None;
160        let sender = listeners.entry(journal_id).or_insert_with(|| {
161            let (sender, recv) = broadcast::channel(self.buffer);
162            ret = Some(recv);
163            sender
164        });
165        let ret = ret.unwrap_or_else(|| sender.subscribe());
166        if self.closed.load(Ordering::SeqCst) {
167            listeners.remove(&journal_id);
168            return Err(SqlxLedgerError::EventSubscriberClosed);
169        }
170        Ok(ret)
171    }
172}
173
174#[derive(
175    sqlx::Type, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Hash, Copy,
176)]
177#[serde(transparent)]
178#[sqlx(transparent)]
179pub struct SqlxLedgerEventId(i64);
180impl SqlxLedgerEventId {
181    pub const BEGIN: Self = Self(0);
182}
183
184/// Representation of a ledger event.
185#[derive(Debug, Clone, Deserialize)]
186#[serde(try_from = "EventRaw")]
187pub struct SqlxLedgerEvent {
188    pub id: SqlxLedgerEventId,
189    pub data: SqlxLedgerEventData,
190    pub r#type: SqlxLedgerEventType,
191    pub recorded_at: DateTime<Utc>,
192    #[cfg(feature = "otel")]
193    pub otel_context: opentelemetry::Context,
194}
195
196impl SqlxLedgerEvent {
197    #[cfg(feature = "otel")]
198    fn record_otel_context(&mut self) {
199        self.otel_context = tracing::Span::current().context();
200    }
201
202    #[cfg(not(feature = "otel"))]
203    fn record_otel_context(&mut self) {}
204}
205
206impl SqlxLedgerEvent {
207    pub fn journal_id(&self) -> JournalId {
208        match &self.data {
209            SqlxLedgerEventData::BalanceUpdated(b) => b.journal_id,
210            SqlxLedgerEventData::TransactionCreated(t) => t.journal_id,
211            SqlxLedgerEventData::TransactionUpdated(t) => t.journal_id,
212        }
213    }
214
215    pub fn account_id(&self) -> Option<AccountId> {
216        match &self.data {
217            SqlxLedgerEventData::BalanceUpdated(b) => Some(b.account_id),
218            _ => None,
219        }
220    }
221}
222
223/// Represents the different kinds of data that can be included in an `SqlxLedgerEvent` event.
224#[derive(Debug, Clone, Serialize, Deserialize)]
225#[allow(clippy::large_enum_variant)]
226pub enum SqlxLedgerEventData {
227    BalanceUpdated(BalanceDetails),
228    TransactionCreated(Transaction),
229    TransactionUpdated(Transaction),
230}
231
232/// Defines possible event types for `SqlxLedgerEvent`.
233#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
234pub enum SqlxLedgerEventType {
235    BalanceUpdated,
236    TransactionCreated,
237    TransactionUpdated,
238}
239
240pub(crate) async fn subscribe(
241    pool: PgPool,
242    closed: Arc<AtomicBool>,
243    buffer: usize,
244    after_id: Option<SqlxLedgerEventId>,
245) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
246    let mut listener = PgListener::connect_with(&pool).await?;
247    listener.listen("sqlx_ledger_events").await?;
248    let (snd, recv) = broadcast::channel(buffer);
249    let mut reload = after_id.is_some();
250
251    task::spawn(async move {
252        let mut last_id = after_id.unwrap_or(SqlxLedgerEventId(0));
253        let mut consecutive_errors = 0;
254        const MAX_RETRY_DELAY: u64 = 60;
255        const MAX_CONSECUTIVE_ERRORS: u32 = 5;
256
257        let subscriber_loop = async {
258            loop {
259                if reload {
260                    let mut stream = sqlx::query!(
261                        r#"SELECT json_build_object(
262                          'id', id,
263                          'type', type,
264                          'data', data,
265                          'recorded_at', recorded_at
266                        ) AS "payload!" FROM sqlx_ledger_events WHERE id > $1 ORDER BY id"#,
267                        last_id.0
268                    )
269                    .fetch(&pool);
270
271                    let mut stream_failed = false;
272                    while let Some(result) = stream.next().await {
273                        match result {
274                            Ok(row) => {
275                                consecutive_errors = 0;
276                                let event: Result<SqlxLedgerEvent, _> =
277                                    serde_json::from_value(row.payload);
278                                if sqlx_ledger_notification_received(
279                                    event,
280                                    &snd,
281                                    &mut last_id,
282                                    true,
283                                )
284                                .is_err()
285                                {
286                                    return Err("channel closed");
287                                }
288                            }
289                            Err(e) => {
290                                consecutive_errors += 1;
291                                tracing::error!(
292                                    "Error fetching events after id {}: {}",
293                                    last_id.0,
294                                    e
295                                );
296
297                                if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
298                                    tracing::error!("Max retries exceeded");
299                                    return Err("max retries");
300                                }
301
302                                let delay =
303                                    (1_u64 << consecutive_errors.min(6)).min(MAX_RETRY_DELAY);
304                                tokio::time::sleep(Duration::from_secs(delay)).await;
305                                stream_failed = true;
306                                break;
307                            }
308                        }
309                    }
310
311                    if stream_failed {
312                        continue;
313                    }
314
315                    reload = false;
316                }
317
318                if closed.load(Ordering::Relaxed) {
319                    return Ok(());
320                }
321
322                tokio::select! {
323                    notification = listener.recv() => {
324                        match notification {
325                            Ok(n) => {
326                                let event: Result<SqlxLedgerEvent, _> =
327                                    serde_json::from_str(n.payload());
328
329                                if let Err(e) = &event {
330                                    tracing::warn!("Failed to parse: {}", e);
331                                    if e.to_string().contains("data field missing") {
332                                        reload = true;
333                                        continue;
334                                    }
335                                }
336
337                                match sqlx_ledger_notification_received(event, &snd, &mut last_id, false) {
338                                    Ok(false) => return Err("channel closed"),
339                                    Ok(_) => consecutive_errors = 0,
340                                    Err(_) => return Err("channel closed"),
341                                }
342                            }
343                            Err(e) => {
344                                tracing::error!("Listener error: {}", e);
345                                consecutive_errors += 1;
346                                if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
347                                    return Err("max retries");
348                                }
349                                tokio::time::sleep(Duration::from_secs(1)).await;
350                            }
351                        }
352                    }
353                    _ = tokio::time::sleep(Duration::from_secs(30)) => {
354                        // Health check: periodically verify stream is not closed
355                        if closed.load(Ordering::Relaxed) {
356                            return Ok(());
357                        }
358                    }
359                }
360            }
361        };
362
363        let result = subscriber_loop.await;
364        match result {
365            Ok(()) => tracing::info!("Subscriber shutting down gracefully"),
366            Err(reason) => tracing::warn!("Subscriber shutting down: {}", reason),
367        }
368
369        if let Err(e) = listener.unlisten("sqlx_ledger_events").await {
370            tracing::warn!("Failed to unlisten from sqlx_ledger_events: {}", e);
371        }
372    });
373
374    Ok(recv)
375}
376
377#[instrument(name = "sqlx_ledger.notification_received", skip(sender), err)]
378fn sqlx_ledger_notification_received(
379    event: Result<SqlxLedgerEvent, serde_json::Error>,
380    sender: &broadcast::Sender<SqlxLedgerEvent>,
381    last_id: &mut SqlxLedgerEventId,
382    ignore_gap: bool,
383) -> Result<bool, SqlxLedgerError> {
384    let mut event = event?;
385    event.record_otel_context();
386    let id = event.id;
387    if id <= *last_id {
388        return Ok(true);
389    }
390    if !ignore_gap && last_id.0 + 1 != id.0 {
391        return Ok(false);
392    }
393    sender.send(event)?;
394    *last_id = id;
395    Ok(true)
396}
397
398#[derive(Deserialize)]
399struct EventRaw {
400    id: SqlxLedgerEventId,
401    #[serde(default)]
402    data: Option<serde_json::Value>,
403    r#type: SqlxLedgerEventType,
404    recorded_at: DateTime<Utc>,
405}
406
407impl TryFrom<EventRaw> for SqlxLedgerEvent {
408    type Error = serde_json::Error;
409
410    fn try_from(value: EventRaw) -> Result<Self, Self::Error> {
411        let data_value = value
412            .data
413            .ok_or_else(|| serde_json::Error::custom("data field missing"))?;
414
415        let data = match value.r#type {
416            SqlxLedgerEventType::BalanceUpdated => {
417                SqlxLedgerEventData::BalanceUpdated(serde_json::from_value(data_value)?)
418            }
419            SqlxLedgerEventType::TransactionCreated => {
420                SqlxLedgerEventData::TransactionCreated(serde_json::from_value(data_value)?)
421            }
422            SqlxLedgerEventType::TransactionUpdated => {
423                SqlxLedgerEventData::TransactionUpdated(serde_json::from_value(data_value)?)
424            }
425        };
426
427        Ok(SqlxLedgerEvent {
428            id: value.id,
429            data,
430            r#type: value.r#type,
431            recorded_at: value.recorded_at,
432            #[cfg(feature = "otel")]
433            otel_context: tracing::Span::current().context(),
434        })
435    }
436}