Skip to main content

sqlx_ledger/
event.rs

1//! Use [ledger.events()](crate::SqlxLedger::events()) to subscribe to events triggered by changes to the ledger.
2use chrono::{DateTime, Utc};
3use serde::{de::Error as _, Deserialize, Serialize};
4use sqlx::{postgres::PgListener, PgPool};
5use tokio::{
6    sync::{
7        broadcast::{self, error::RecvError},
8        RwLock,
9    },
10    task,
11};
12use tracing::instrument;
13#[cfg(feature = "otel")]
14use tracing_opentelemetry::OpenTelemetrySpanExt;
15
16use std::{
17    collections::HashMap,
18    sync::{
19        atomic::{AtomicBool, Ordering},
20        Arc,
21    },
22};
23
24use crate::{
25    balance::BalanceDetails, transaction::Transaction, AccountId, JournalId, SqlxLedgerError,
26};
27
28/// Options when initializing the EventSubscriber
29pub struct EventSubscriberOpts {
30    pub close_on_lag: bool,
31    pub buffer: usize,
32    pub after_id: Option<SqlxLedgerEventId>,
33    pub batch_size: i64,
34}
35impl Default for EventSubscriberOpts {
36    fn default() -> Self {
37        Self {
38            close_on_lag: false,
39            buffer: 100,
40            after_id: None,
41            batch_size: 1000,
42        }
43    }
44}
45
46/// Contains fields to store & manage various ledger-related `SqlxLedgerEvent` event receivers.
47#[derive(Debug, Clone)]
48pub struct EventSubscriber {
49    buffer: usize,
50    closed: Arc<AtomicBool>,
51    #[allow(clippy::type_complexity)]
52    balance_receivers:
53        Arc<RwLock<HashMap<(JournalId, AccountId), broadcast::Sender<SqlxLedgerEvent>>>>,
54    journal_receivers: Arc<RwLock<HashMap<JournalId, broadcast::Sender<SqlxLedgerEvent>>>>,
55    all: Arc<broadcast::Receiver<SqlxLedgerEvent>>,
56}
57
58impl EventSubscriber {
59    pub(crate) async fn connect(
60        pool: &PgPool,
61        EventSubscriberOpts {
62            close_on_lag,
63            buffer,
64            after_id: start_id,
65            batch_size,
66        }: EventSubscriberOpts,
67    ) -> Result<Self, SqlxLedgerError> {
68        let closed = Arc::new(AtomicBool::new(false));
69        let mut incoming = subscribe(
70            pool.clone(),
71            Arc::clone(&closed),
72            buffer,
73            start_id,
74            batch_size,
75        )
76        .await?;
77        let all = Arc::new(incoming.resubscribe());
78        let balance_receivers = Arc::new(RwLock::new(HashMap::<
79            (JournalId, AccountId),
80            broadcast::Sender<SqlxLedgerEvent>,
81        >::new()));
82        let journal_receivers = Arc::new(RwLock::new(HashMap::<
83            JournalId,
84            broadcast::Sender<SqlxLedgerEvent>,
85        >::new()));
86        let inner_balance_receivers = Arc::clone(&balance_receivers);
87        let inner_journal_receivers = Arc::clone(&journal_receivers);
88        let inner_closed = Arc::clone(&closed);
89        tokio::spawn(async move {
90            loop {
91                match incoming.recv().await {
92                    Ok(event) => {
93                        let journal_id = event.journal_id();
94                        if let Some(journal_receivers) =
95                            inner_journal_receivers.read().await.get(&journal_id)
96                        {
97                            let _ = journal_receivers.send(event.clone());
98                        }
99                        if let Some(account_id) = event.account_id() {
100                            let receivers = inner_balance_receivers.read().await;
101                            if let Some(receiver) = receivers.get(&(journal_id, account_id)) {
102                                let _ = receiver.send(event);
103                            }
104                        }
105                    }
106                    Err(RecvError::Lagged(_)) => {
107                        if close_on_lag {
108                            inner_closed.store(true, Ordering::SeqCst);
109                            inner_balance_receivers.write().await.clear();
110                            inner_journal_receivers.write().await.clear();
111                        }
112                    }
113                    Err(RecvError::Closed) => {
114                        tracing::warn!("Event subscriber closed");
115                        inner_closed.store(true, Ordering::SeqCst);
116                        inner_balance_receivers.write().await.clear();
117                        inner_journal_receivers.write().await.clear();
118                        break;
119                    }
120                }
121            }
122        });
123        Ok(Self {
124            buffer,
125            closed,
126            balance_receivers,
127            journal_receivers,
128            all,
129        })
130    }
131
132    pub fn all(&self) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
133        let recv = self.all.resubscribe();
134        if self.closed.load(Ordering::SeqCst) {
135            return Err(SqlxLedgerError::EventSubscriberClosed);
136        }
137        Ok(recv)
138    }
139
140    pub async fn account_balance(
141        &self,
142        journal_id: JournalId,
143        account_id: AccountId,
144    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
145        let mut listeners = self.balance_receivers.write().await;
146        let mut ret = None;
147        let sender = listeners
148            .entry((journal_id, account_id))
149            .or_insert_with(|| {
150                let (sender, recv) = broadcast::channel(self.buffer);
151                ret = Some(recv);
152                sender
153            });
154        let ret = ret.unwrap_or_else(|| sender.subscribe());
155        if self.closed.load(Ordering::SeqCst) {
156            listeners.remove(&(journal_id, account_id));
157            return Err(SqlxLedgerError::EventSubscriberClosed);
158        }
159        Ok(ret)
160    }
161
162    pub async fn journal(
163        &self,
164        journal_id: JournalId,
165    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
166        let mut listeners = self.journal_receivers.write().await;
167        let mut ret = None;
168        let sender = listeners.entry(journal_id).or_insert_with(|| {
169            let (sender, recv) = broadcast::channel(self.buffer);
170            ret = Some(recv);
171            sender
172        });
173        let ret = ret.unwrap_or_else(|| sender.subscribe());
174        if self.closed.load(Ordering::SeqCst) {
175            listeners.remove(&journal_id);
176            return Err(SqlxLedgerError::EventSubscriberClosed);
177        }
178        Ok(ret)
179    }
180}
181
182#[derive(
183    sqlx::Type, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Hash, Copy,
184)]
185#[serde(transparent)]
186#[sqlx(transparent)]
187pub struct SqlxLedgerEventId(i64);
188impl SqlxLedgerEventId {
189    pub const BEGIN: Self = Self(0);
190}
191
192impl From<i64> for SqlxLedgerEventId {
193    fn from(value: i64) -> Self {
194        Self(value)
195    }
196}
197
/// Representation of a ledger event.
///
/// Deserialization goes through the private `EventRaw` shape
/// (`#[serde(try_from = "EventRaw")]`), which decodes `data` according to
/// `type`.
#[derive(Debug, Clone, Deserialize)]
#[serde(try_from = "EventRaw")]
pub struct SqlxLedgerEvent {
    /// Id of the stored event; consecutive ids are used to detect missed
    /// notifications.
    pub id: SqlxLedgerEventId,
    /// Typed payload; its variant is chosen from `type` during deserialization.
    pub data: SqlxLedgerEventData,
    /// Kind of the event (the `type` column).
    pub r#type: SqlxLedgerEventType,
    /// The `recorded_at` column of the stored event.
    pub recorded_at: DateTime<Utc>,
    /// OpenTelemetry context of the tracing span that decoded the event.
    /// It is refreshed when the event is forwarded to subscribers.
    #[cfg(feature = "otel")]
    pub otel_context: opentelemetry::Context,
}
209
210impl SqlxLedgerEvent {
211    #[cfg(feature = "otel")]
212    fn record_otel_context(&mut self) {
213        self.otel_context = tracing::Span::current().context();
214    }
215
216    #[cfg(not(feature = "otel"))]
217    fn record_otel_context(&mut self) {}
218}
219
220impl SqlxLedgerEvent {
221    pub fn journal_id(&self) -> JournalId {
222        match &self.data {
223            SqlxLedgerEventData::BalanceUpdated(b) => b.journal_id,
224            SqlxLedgerEventData::TransactionCreated(t) => t.journal_id,
225            SqlxLedgerEventData::TransactionUpdated(t) => t.journal_id,
226        }
227    }
228
229    pub fn account_id(&self) -> Option<AccountId> {
230        match &self.data {
231            SqlxLedgerEventData::BalanceUpdated(b) => Some(b.account_id),
232            _ => None,
233        }
234    }
235}
236
/// Represents the different kinds of data that can be included in an `SqlxLedgerEvent` event.
///
/// When an event is deserialized, the variant is selected by its
/// `SqlxLedgerEventType`.
#[derive(Debug, Clone, Serialize, Deserialize)]
// NOTE(review): boxing the larger payload would change the public variant
// types, which is presumably why this lint is allowed rather than fixed.
#[allow(clippy::large_enum_variant)]
pub enum SqlxLedgerEventData {
    /// Balance details of the affected account (carries `journal_id` and `account_id`).
    BalanceUpdated(BalanceDetails),
    /// Transaction carried by a `TransactionCreated` event.
    TransactionCreated(Transaction),
    /// Transaction carried by a `TransactionUpdated` event.
    TransactionUpdated(Transaction),
}
245
/// Defines possible event types for `SqlxLedgerEvent`.
///
/// Each variant selects the `SqlxLedgerEventData` variant that an event's
/// `data` is decoded into.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlxLedgerEventType {
    /// `data` decodes as `BalanceDetails`.
    BalanceUpdated,
    /// `data` decodes as `Transaction` (`SqlxLedgerEventData::TransactionCreated`).
    TransactionCreated,
    /// `data` decodes as `Transaction` (`SqlxLedgerEventData::TransactionUpdated`).
    TransactionUpdated,
}
253
254pub(crate) async fn subscribe(
255    pool: PgPool,
256    closed: Arc<AtomicBool>,
257    buffer: usize,
258    after_id: Option<SqlxLedgerEventId>,
259    batch_size: i64,
260) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
261    let mut listener = PgListener::connect_with(&pool).await?;
262    listener.listen("sqlx_ledger_events").await?;
263    let (snd, recv) = broadcast::channel(buffer);
264    let mut reload = after_id.is_some();
265    task::spawn(async move {
266        let mut num_errors: u32 = 0;
267        let mut last_id = after_id.unwrap_or(SqlxLedgerEventId(0));
268        loop {
269            if reload {
270                // Fetch missed events in paginated batches to avoid loading
271                // the entire event history into memory at once.
272                let batch_result = async {
273                    loop {
274                        let rows = sqlx::query!(
275                            r#"SELECT json_build_object(
276                              'id', id,
277                              'type', type,
278                              'data', data,
279                              'recorded_at', recorded_at
280                            ) AS "payload!" FROM sqlx_ledger_events WHERE id > $1 ORDER BY id LIMIT $2"#,
281                            last_id.0,
282                            batch_size
283                        )
284                        .fetch_all(&pool)
285                        .await?;
286
287                        let is_last_batch = (rows.len() as i64) < batch_size;
288
289                        for row in rows {
290                            let event: Result<SqlxLedgerEvent, _> =
291                                serde_json::from_value(row.payload);
292                            if sqlx_ledger_notification_received(event, &snd, &mut last_id, true)
293                                .is_err()
294                            {
295                                return Err::<(), SqlxLedgerError>(
296                                    SqlxLedgerError::EventSubscriberClosed,
297                                );
298                            }
299                        }
300
301                        if is_last_batch {
302                            break;
303                        }
304                    }
305                    Ok(())
306                }
307                .await;
308
309                match batch_result {
310                    Ok(()) => {
311                        num_errors = 0;
312                        reload = false;
313                    }
314                    Err(SqlxLedgerError::EventSubscriberClosed) => {
315                        closed.store(true, Ordering::SeqCst);
316                        break;
317                    }
318                    Err(e) => {
319                        num_errors += 1;
320                        let delay = backoff_delay(num_errors);
321                        tracing::error!(
322                            "Error fetching events (attempt {}): {}. Retrying in {:?}",
323                            num_errors,
324                            e,
325                            delay
326                        );
327                        tokio::time::sleep(delay).await;
328                        continue;
329                    }
330                }
331            }
332            if closed.load(Ordering::Relaxed) {
333                break;
334            }
335            loop {
336                match listener.recv().await {
337                    Ok(notification) => {
338                        let event: Result<SqlxLedgerEvent, _> =
339                            serde_json::from_str(notification.payload());
340                        if let Err(e) = &event {
341                            if e.to_string().contains("data field missing") {
342                                reload = true;
343                                break;
344                            }
345                        }
346                        match sqlx_ledger_notification_received(event, &snd, &mut last_id, reload) {
347                            Ok(false) => {
348                                reload = true;
349                                break;
350                            }
351                            Ok(_) => num_errors = 0,
352                            Err(_) => {
353                                closed.store(true, Ordering::SeqCst);
354                                break;
355                            }
356                        }
357                    }
358                    Err(e) => {
359                        // PgListener::recv() auto-reconnects on connection loss,
360                        // but still surfaces errors. Log and backoff to avoid a
361                        // tight spin loop, then trigger a reload to catch up on
362                        // any events missed during the reconnection.
363                        num_errors += 1;
364                        let delay = backoff_delay(num_errors);
365                        tracing::warn!(
366                            "PgListener recv error (attempt {}): {}. Retrying in {:?}",
367                            num_errors,
368                            e,
369                            delay
370                        );
371                        tokio::time::sleep(delay).await;
372                        reload = true;
373                        break;
374                    }
375                }
376            }
377        }
378        let _ = listener.unlisten("sqlx_ledger_events").await;
379    });
380    Ok(recv)
381}
382
/// Delay before retrying after `num_errors` consecutive failures.
///
/// The delay is `2^num_errors` seconds, capped at 32 seconds from the fifth
/// failure onwards.
fn backoff_delay(num_errors: u32) -> std::time::Duration {
    let exponent = std::cmp::min(num_errors, 5);
    std::time::Duration::from_secs(2u64.pow(exponent))
}
386
387#[instrument(name = "sqlx_ledger.notification_received", skip(sender), err)]
388fn sqlx_ledger_notification_received(
389    event: Result<SqlxLedgerEvent, serde_json::Error>,
390    sender: &broadcast::Sender<SqlxLedgerEvent>,
391    last_id: &mut SqlxLedgerEventId,
392    ignore_gap: bool,
393) -> Result<bool, SqlxLedgerError> {
394    let mut event = event?;
395    event.record_otel_context();
396    let id = event.id;
397    if id <= *last_id {
398        return Ok(true);
399    }
400    if !ignore_gap && last_id.0 + 1 != id.0 {
401        return Ok(false);
402    }
403    sender.send(event)?;
404    *last_id = id;
405    Ok(true)
406}
407
/// Intermediate shape that `SqlxLedgerEvent` deserializes through
/// (`#[serde(try_from = "EventRaw")]`). `data` is decoded later, according
/// to `type`, by the `TryFrom` impl.
#[derive(Deserialize)]
struct EventRaw {
    id: SqlxLedgerEventId,
    // Optional so that a payload with no (or a null) `data` still parses.
    // `TryFrom` then reports "data field missing", which `subscribe` treats
    // as a signal to reload the event from the table.
    #[serde(default)]
    data: Option<serde_json::Value>,
    r#type: SqlxLedgerEventType,
    recorded_at: DateTime<Utc>,
}
416
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use rust_decimal::Decimal;
    use tokio::sync::broadcast;

    use crate::balance::BalanceDetails;
    use crate::{CorrelationId, Currency, EntryId, TransactionId, TxTemplateId};

    /// Builds a `BalanceUpdated` event with the given id and zeroed balances.
    fn make_balance_event(id: i64) -> SqlxLedgerEvent {
        let now = Utc::now();
        let journal_id = JournalId::new();
        let account_id = AccountId::new();
        let entry_id = EntryId::new();
        let currency: Currency = "USD".parse().unwrap();

        SqlxLedgerEvent {
            id: SqlxLedgerEventId(id),
            data: SqlxLedgerEventData::BalanceUpdated(BalanceDetails {
                journal_id,
                account_id,
                entry_id,
                currency,
                settled_dr_balance: Decimal::ZERO,
                settled_cr_balance: Decimal::ZERO,
                settled_entry_id: entry_id,
                settled_modified_at: now,
                pending_dr_balance: Decimal::ZERO,
                pending_cr_balance: Decimal::ZERO,
                pending_entry_id: entry_id,
                pending_modified_at: now,
                encumbered_dr_balance: Decimal::ZERO,
                encumbered_cr_balance: Decimal::ZERO,
                encumbered_entry_id: entry_id,
                encumbered_modified_at: now,
                version: 1,
                modified_at: now,
                created_at: now,
            }),
            r#type: SqlxLedgerEventType::BalanceUpdated,
            recorded_at: now,
            // Required field when the `otel` feature is enabled; without it
            // the test suite does not compile under that feature.
            #[cfg(feature = "otel")]
            otel_context: opentelemetry::Context::new(),
        }
    }

    /// Builds a `TransactionCreated` event with the given id.
    fn make_transaction_event(id: i64) -> SqlxLedgerEvent {
        let now = Utc::now();
        SqlxLedgerEvent {
            id: SqlxLedgerEventId(id),
            data: SqlxLedgerEventData::TransactionCreated(Transaction {
                id: TransactionId::new(),
                version: 1,
                journal_id: JournalId::new(),
                tx_template_id: TxTemplateId::new(),
                effective: now.date_naive(),
                correlation_id: CorrelationId::new(),
                external_id: "test-ext".to_string(),
                description: None,
                metadata_json: None,
                created_at: now,
                modified_at: now,
            }),
            r#type: SqlxLedgerEventType::TransactionCreated,
            recorded_at: now,
            #[cfg(feature = "otel")]
            otel_context: opentelemetry::Context::new(),
        }
    }

    #[test]
    fn notification_received_sends_event_and_updates_last_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);
        let event = make_balance_event(1);

        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(matches!(result, Ok(true)));
        assert_eq!(last_id, SqlxLedgerEventId(1));

        let received = recv.try_recv().unwrap();
        assert_eq!(received.id, SqlxLedgerEventId(1));
    }

    #[test]
    fn notification_received_skips_duplicate_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(5);

        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(matches!(result, Ok(true)));
        assert_eq!(last_id, SqlxLedgerEventId(5));

        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_skips_older_id() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(10);

        let event = make_balance_event(3);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(matches!(result, Ok(true)));
        assert_eq!(last_id, SqlxLedgerEventId(10));

        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_detects_gap_when_ignore_gap_false() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(1);

        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(matches!(result, Ok(false)));
        assert_eq!(last_id, SqlxLedgerEventId(1));

        assert!(recv.try_recv().is_err());
    }

    #[test]
    fn notification_received_ignores_gap_when_flag_set() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(1);

        let event = make_balance_event(5);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, true);
        assert!(matches!(result, Ok(true)));
        assert_eq!(last_id, SqlxLedgerEventId(5));

        let received = recv.try_recv().unwrap();
        assert_eq!(received.id, SqlxLedgerEventId(5));
    }

    #[test]
    fn notification_received_propagates_deserialization_error() {
        let (sender, _recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);

        let deser_err: Result<SqlxLedgerEvent, _> = serde_json::from_str::<SqlxLedgerEvent>("{}");
        assert!(deser_err.is_err());

        let result = sqlx_ledger_notification_received(deser_err, &sender, &mut last_id, false);
        assert!(result.is_err());
        assert_eq!(last_id, SqlxLedgerEventId(0));
    }

    #[test]
    fn notification_received_errors_when_no_receivers() {
        // Create a sender but drop every receiver, so send() returns Err.
        let (sender, recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        drop(recv);
        let mut last_id = SqlxLedgerEventId(0);

        let event = make_balance_event(1);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(result.is_err());
        assert_eq!(last_id, SqlxLedgerEventId(0));
    }

    #[test]
    fn notification_received_sequential_ids() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);

        for i in 1..=5 {
            let event = make_balance_event(i);
            let result =
                sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
            assert!(matches!(result, Ok(true)));
            assert_eq!(last_id, SqlxLedgerEventId(i));
        }

        for i in 1..=5 {
            let received = recv.try_recv().unwrap();
            assert_eq!(received.id, SqlxLedgerEventId(i));
        }
    }

    #[test]
    fn notification_received_handles_transaction_event() {
        let (sender, mut recv) = broadcast::channel::<SqlxLedgerEvent>(16);
        let mut last_id = SqlxLedgerEventId(0);

        let event = make_transaction_event(1);
        let result = sqlx_ledger_notification_received(Ok(event), &sender, &mut last_id, false);
        assert!(matches!(result, Ok(true)));
        assert_eq!(last_id, SqlxLedgerEventId(1));

        let received = recv.try_recv().unwrap();
        assert!(matches!(
            received.r#type,
            SqlxLedgerEventType::TransactionCreated
        ));
    }

    #[test]
    fn notification_received_data_field_missing_error_string() {
        // Verify that the "data field missing" error string matches what
        // the subscribe loop checks for in the notification path.
        let raw_json =
            r#"{"id": 1, "type": "BalanceUpdated", "recorded_at": "2024-01-01T00:00:00Z"}"#;
        let result: Result<SqlxLedgerEvent, _> = serde_json::from_str(raw_json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("data field missing"),
            "Expected 'data field missing' in error: {err_msg}"
        );
    }

    #[test]
    fn null_data_field_is_reported_as_missing() {
        // `"data": null` decodes to `None`, so it must take the same reload
        // path as an absent `data` field.
        let raw_json = r#"{"id": 1, "type": "BalanceUpdated", "data": null, "recorded_at": "2024-01-01T00:00:00Z"}"#;
        let err_msg = serde_json::from_str::<SqlxLedgerEvent>(raw_json)
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("data field missing"),
            "Expected 'data field missing' in error: {err_msg}"
        );
    }

    #[test]
    fn backoff_delay_caps_at_32_seconds() {
        use std::time::Duration;
        // No preceding failures yields the base delay.
        assert_eq!(backoff_delay(0), Duration::from_secs(1));
        // num_errors is incremented before calling backoff_delay, so the first
        // call uses num_errors=1.
        assert_eq!(backoff_delay(1), Duration::from_secs(2));
        assert_eq!(backoff_delay(2), Duration::from_secs(4));
        assert_eq!(backoff_delay(3), Duration::from_secs(8));
        assert_eq!(backoff_delay(4), Duration::from_secs(16));
        assert_eq!(backoff_delay(5), Duration::from_secs(32));
        // Capped at 32s for any num_errors >= 5.
        assert_eq!(backoff_delay(6), Duration::from_secs(32));
        assert_eq!(backoff_delay(10), Duration::from_secs(32));
        assert_eq!(backoff_delay(100), Duration::from_secs(32));
    }
}
649
650impl TryFrom<EventRaw> for SqlxLedgerEvent {
651    type Error = serde_json::Error;
652
653    fn try_from(value: EventRaw) -> Result<Self, Self::Error> {
654        let data_value = value
655            .data
656            .ok_or_else(|| serde_json::Error::custom("data field missing"))?;
657
658        let data = match value.r#type {
659            SqlxLedgerEventType::BalanceUpdated => {
660                SqlxLedgerEventData::BalanceUpdated(serde_json::from_value(data_value)?)
661            }
662            SqlxLedgerEventType::TransactionCreated => {
663                SqlxLedgerEventData::TransactionCreated(serde_json::from_value(data_value)?)
664            }
665            SqlxLedgerEventType::TransactionUpdated => {
666                SqlxLedgerEventData::TransactionUpdated(serde_json::from_value(data_value)?)
667            }
668        };
669
670        Ok(SqlxLedgerEvent {
671            id: value.id,
672            data,
673            r#type: value.r#type,
674            recorded_at: value.recorded_at,
675            #[cfg(feature = "otel")]
676            otel_context: tracing::Span::current().context(),
677        })
678    }
679}