sqlx_ledger/
event.rs

1//! Use [ledger.events()](crate::SqlxLedger::events()) to subscribe to events triggered by changes to the ledger.
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use sqlx::{postgres::PgListener, PgPool};
5use tokio::{
6    sync::{
7        broadcast::{self, error::RecvError},
8        RwLock,
9    },
10    task,
11};
12use tracing::instrument;
13#[cfg(feature = "otel")]
14use tracing_opentelemetry::OpenTelemetrySpanExt;
15
16use std::{
17    collections::HashMap,
18    sync::{
19        atomic::{AtomicBool, Ordering},
20        Arc,
21    },
22};
23
24use crate::{
25    balance::BalanceDetails, transaction::Transaction, AccountId, JournalId, SqlxLedgerError,
26};
27
28/// Options when initializing the EventSubscriber
29pub struct EventSubscriberOpts {
30    pub close_on_lag: bool,
31    pub buffer: usize,
32    pub after_id: Option<SqlxLedgerEventId>,
33}
34impl Default for EventSubscriberOpts {
35    fn default() -> Self {
36        Self {
37            close_on_lag: false,
38            buffer: 100,
39            after_id: None,
40        }
41    }
42}
43
44/// Contains fields to store & manage various ledger-related `SqlxLedgerEvent` event receivers.
45#[derive(Debug, Clone)]
46pub struct EventSubscriber {
47    buffer: usize,
48    closed: Arc<AtomicBool>,
49    #[allow(clippy::type_complexity)]
50    balance_receivers:
51        Arc<RwLock<HashMap<(JournalId, AccountId), broadcast::Sender<SqlxLedgerEvent>>>>,
52    journal_receivers: Arc<RwLock<HashMap<JournalId, broadcast::Sender<SqlxLedgerEvent>>>>,
53    all: Arc<broadcast::Receiver<SqlxLedgerEvent>>,
54}
55
56impl EventSubscriber {
57    pub(crate) async fn connect(
58        pool: &PgPool,
59        EventSubscriberOpts {
60            close_on_lag,
61            buffer,
62            after_id: start_id,
63        }: EventSubscriberOpts,
64    ) -> Result<Self, SqlxLedgerError> {
65        let closed = Arc::new(AtomicBool::new(false));
66        let mut incoming = subscribe(pool.clone(), Arc::clone(&closed), buffer, start_id).await?;
67        let all = Arc::new(incoming.resubscribe());
68        let balance_receivers = Arc::new(RwLock::new(HashMap::<
69            (JournalId, AccountId),
70            broadcast::Sender<SqlxLedgerEvent>,
71        >::new()));
72        let journal_receivers = Arc::new(RwLock::new(HashMap::<
73            JournalId,
74            broadcast::Sender<SqlxLedgerEvent>,
75        >::new()));
76        let inner_balance_receivers = Arc::clone(&balance_receivers);
77        let inner_journal_receivers = Arc::clone(&journal_receivers);
78        let inner_closed = Arc::clone(&closed);
79        tokio::spawn(async move {
80            loop {
81                match incoming.recv().await {
82                    Ok(event) => {
83                        let journal_id = event.journal_id();
84                        if let Some(journal_receivers) =
85                            inner_journal_receivers.read().await.get(&journal_id)
86                        {
87                            let _ = journal_receivers.send(event.clone());
88                        }
89                        if let Some(account_id) = event.account_id() {
90                            let receivers = inner_balance_receivers.read().await;
91                            if let Some(receiver) = receivers.get(&(journal_id, account_id)) {
92                                let _ = receiver.send(event);
93                            }
94                        }
95                    }
96                    Err(RecvError::Lagged(_)) => {
97                        if close_on_lag {
98                            inner_closed.store(true, Ordering::SeqCst);
99                            inner_balance_receivers.write().await.clear();
100                            inner_journal_receivers.write().await.clear();
101                        }
102                    }
103                    Err(RecvError::Closed) => {
104                        tracing::warn!("Event subscriber closed");
105                        inner_closed.store(true, Ordering::SeqCst);
106                        inner_balance_receivers.write().await.clear();
107                        inner_journal_receivers.write().await.clear();
108                        break;
109                    }
110                }
111            }
112        });
113        Ok(Self {
114            buffer,
115            closed,
116            balance_receivers,
117            journal_receivers,
118            all,
119        })
120    }
121
122    pub fn all(&self) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
123        let recv = self.all.resubscribe();
124        if self.closed.load(Ordering::SeqCst) {
125            return Err(SqlxLedgerError::EventSubscriberClosed);
126        }
127        Ok(recv)
128    }
129
130    pub async fn account_balance(
131        &self,
132        journal_id: JournalId,
133        account_id: AccountId,
134    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
135        let mut listeners = self.balance_receivers.write().await;
136        let mut ret = None;
137        let sender = listeners
138            .entry((journal_id, account_id))
139            .or_insert_with(|| {
140                let (sender, recv) = broadcast::channel(self.buffer);
141                ret = Some(recv);
142                sender
143            });
144        let ret = ret.unwrap_or_else(|| sender.subscribe());
145        if self.closed.load(Ordering::SeqCst) {
146            listeners.remove(&(journal_id, account_id));
147            return Err(SqlxLedgerError::EventSubscriberClosed);
148        }
149        Ok(ret)
150    }
151
152    pub async fn journal(
153        &self,
154        journal_id: JournalId,
155    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
156        let mut listeners = self.journal_receivers.write().await;
157        let mut ret = None;
158        let sender = listeners.entry(journal_id).or_insert_with(|| {
159            let (sender, recv) = broadcast::channel(self.buffer);
160            ret = Some(recv);
161            sender
162        });
163        let ret = ret.unwrap_or_else(|| sender.subscribe());
164        if self.closed.load(Ordering::SeqCst) {
165            listeners.remove(&journal_id);
166            return Err(SqlxLedgerError::EventSubscriberClosed);
167        }
168        Ok(ret)
169    }
170}
171
172#[derive(
173    sqlx::Type, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Hash, Copy,
174)]
175#[serde(transparent)]
176#[sqlx(transparent)]
177pub struct SqlxLedgerEventId(i64);
178impl SqlxLedgerEventId {
179    pub const BEGIN: Self = Self(0);
180}
181
182/// Representation of a ledger event.
183#[derive(Debug, Clone, Deserialize)]
184#[serde(try_from = "EventRaw")]
185pub struct SqlxLedgerEvent {
186    pub id: SqlxLedgerEventId,
187    pub data: SqlxLedgerEventData,
188    pub r#type: SqlxLedgerEventType,
189    pub recorded_at: DateTime<Utc>,
190    #[cfg(feature = "otel")]
191    pub otel_context: opentelemetry::Context,
192}
193
194impl SqlxLedgerEvent {
195    #[cfg(feature = "otel")]
196    fn record_otel_context(&mut self) {
197        self.otel_context = tracing::Span::current().context();
198    }
199
200    #[cfg(not(feature = "otel"))]
201    fn record_otel_context(&mut self) {}
202}
203
204impl SqlxLedgerEvent {
205    pub fn journal_id(&self) -> JournalId {
206        match &self.data {
207            SqlxLedgerEventData::BalanceUpdated(b) => b.journal_id,
208            SqlxLedgerEventData::TransactionCreated(t) => t.journal_id,
209            SqlxLedgerEventData::TransactionUpdated(t) => t.journal_id,
210        }
211    }
212
213    pub fn account_id(&self) -> Option<AccountId> {
214        match &self.data {
215            SqlxLedgerEventData::BalanceUpdated(b) => Some(b.account_id),
216            _ => None,
217        }
218    }
219}
220
221/// Represents the different kinds of data that can be included in an `SqlxLedgerEvent` event.
222#[derive(Debug, Clone, Serialize, Deserialize)]
223#[allow(clippy::large_enum_variant)]
224pub enum SqlxLedgerEventData {
225    BalanceUpdated(BalanceDetails),
226    TransactionCreated(Transaction),
227    TransactionUpdated(Transaction),
228}
229
230/// Defines possible event types for `SqlxLedgerEvent`.
231#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
232pub enum SqlxLedgerEventType {
233    BalanceUpdated,
234    TransactionCreated,
235    TransactionUpdated,
236}
237
238pub(crate) async fn subscribe(
239    pool: PgPool,
240    closed: Arc<AtomicBool>,
241    buffer: usize,
242    after_id: Option<SqlxLedgerEventId>,
243) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
244    let mut listener = PgListener::connect_with(&pool).await?;
245    listener.listen("sqlx_ledger_events").await?;
246    let (snd, recv) = broadcast::channel(buffer);
247    let mut reload = after_id.is_some();
248    task::spawn(async move {
249        let mut num_errors = 0;
250        let mut last_id = after_id.unwrap_or(SqlxLedgerEventId(0));
251        loop {
252            if reload {
253                match sqlx::query!(
254                    r#"SELECT json_build_object(
255                      'id', id,
256                      'type', type,
257                      'data', data,
258                      'recorded_at', recorded_at
259                    ) AS "payload!" FROM sqlx_ledger_events WHERE id > $1 ORDER BY id"#,
260                    last_id.0
261                )
262                .fetch_all(&pool)
263                .await
264                {
265                    Ok(rows) => {
266                        num_errors = 0;
267                        for row in rows {
268                            let event: Result<SqlxLedgerEvent, _> =
269                                serde_json::from_value(row.payload);
270                            if sqlx_ledger_notification_received(event, &snd, &mut last_id, true)
271                                .is_err()
272                            {
273                                closed.store(true, Ordering::SeqCst);
274                                break;
275                            }
276                        }
277                    }
278                    Err(e) if num_errors == 0 => {
279                        tracing::error!("Error fetching events: {}", e);
280                        tokio::time::sleep(std::time::Duration::from_secs(1 << num_errors)).await;
281                        num_errors += 1;
282                        continue;
283                    }
284                    _ => {
285                        num_errors = 0;
286                        continue;
287                    }
288                }
289            }
290            if closed.load(Ordering::Relaxed) {
291                break;
292            }
293            while let Ok(notification) = listener.recv().await {
294                let event: Result<SqlxLedgerEvent, _> =
295                    serde_json::from_str(notification.payload());
296                match sqlx_ledger_notification_received(event, &snd, &mut last_id, !reload) {
297                    Ok(false) => break,
298                    Ok(_) => num_errors = 0,
299                    Err(_) => {
300                        closed.store(true, Ordering::SeqCst);
301                    }
302                }
303                reload = true;
304            }
305        }
306    });
307    Ok(recv)
308}
309
310#[instrument(name = "sqlx_ledger.notification_received", skip(sender), err)]
311fn sqlx_ledger_notification_received(
312    event: Result<SqlxLedgerEvent, serde_json::Error>,
313    sender: &broadcast::Sender<SqlxLedgerEvent>,
314    last_id: &mut SqlxLedgerEventId,
315    ignore_gap: bool,
316) -> Result<bool, SqlxLedgerError> {
317    let mut event = event?;
318    event.record_otel_context();
319    let id = event.id;
320    if id <= *last_id {
321        return Ok(true);
322    }
323    if !ignore_gap && last_id.0 + 1 != id.0 {
324        return Ok(false);
325    }
326    sender.send(event)?;
327    *last_id = id;
328    Ok(true)
329}
330
331#[derive(Deserialize)]
332struct EventRaw {
333    id: SqlxLedgerEventId,
334    data: serde_json::Value,
335    r#type: SqlxLedgerEventType,
336    recorded_at: DateTime<Utc>,
337}
338
339impl TryFrom<EventRaw> for SqlxLedgerEvent {
340    type Error = serde_json::Error;
341
342    fn try_from(value: EventRaw) -> Result<Self, Self::Error> {
343        let data = match value.r#type {
344            SqlxLedgerEventType::BalanceUpdated => {
345                SqlxLedgerEventData::BalanceUpdated(serde_json::from_value(value.data)?)
346            }
347            SqlxLedgerEventType::TransactionCreated => {
348                SqlxLedgerEventData::TransactionCreated(serde_json::from_value(value.data)?)
349            }
350            SqlxLedgerEventType::TransactionUpdated => {
351                SqlxLedgerEventData::TransactionUpdated(serde_json::from_value(value.data)?)
352            }
353        };
354
355        Ok(SqlxLedgerEvent {
356            id: value.id,
357            data,
358            r#type: value.r#type,
359            recorded_at: value.recorded_at,
360            #[cfg(feature = "otel")]
361            otel_context: tracing::Span::current().context(),
362        })
363    }
364}