sqlx_ledger/
event.rs

1//! Use [ledger.events()](crate::SqlxLedger::events()) to subscribe to events triggered by changes to the ledger.
2use chrono::{DateTime, Utc};
3use serde::{de::Error as _, Deserialize, Serialize};
4use sqlx::{postgres::PgListener, PgPool};
5use tokio::{
6    sync::{
7        broadcast::{self, error::RecvError},
8        RwLock,
9    },
10    task,
11};
12use tracing::instrument;
13#[cfg(feature = "otel")]
14use tracing_opentelemetry::OpenTelemetrySpanExt;
15
16use std::{
17    collections::HashMap,
18    sync::{
19        atomic::{AtomicBool, Ordering},
20        Arc,
21    },
22};
23
24use crate::{
25    balance::BalanceDetails, transaction::Transaction, AccountId, JournalId, SqlxLedgerError,
26};
27
28/// Options when initializing the EventSubscriber
29pub struct EventSubscriberOpts {
30    pub close_on_lag: bool,
31    pub buffer: usize,
32    pub after_id: Option<SqlxLedgerEventId>,
33}
34impl Default for EventSubscriberOpts {
35    fn default() -> Self {
36        Self {
37            close_on_lag: false,
38            buffer: 100,
39            after_id: None,
40        }
41    }
42}
43
44/// Contains fields to store & manage various ledger-related `SqlxLedgerEvent` event receivers.
45#[derive(Debug, Clone)]
46pub struct EventSubscriber {
47    buffer: usize,
48    closed: Arc<AtomicBool>,
49    #[allow(clippy::type_complexity)]
50    balance_receivers:
51        Arc<RwLock<HashMap<(JournalId, AccountId), broadcast::Sender<SqlxLedgerEvent>>>>,
52    journal_receivers: Arc<RwLock<HashMap<JournalId, broadcast::Sender<SqlxLedgerEvent>>>>,
53    all: Arc<broadcast::Receiver<SqlxLedgerEvent>>,
54}
55
56impl EventSubscriber {
57    pub(crate) async fn connect(
58        pool: &PgPool,
59        EventSubscriberOpts {
60            close_on_lag,
61            buffer,
62            after_id: start_id,
63        }: EventSubscriberOpts,
64    ) -> Result<Self, SqlxLedgerError> {
65        let closed = Arc::new(AtomicBool::new(false));
66        let mut incoming = subscribe(pool.clone(), Arc::clone(&closed), buffer, start_id).await?;
67        let all = Arc::new(incoming.resubscribe());
68        let balance_receivers = Arc::new(RwLock::new(HashMap::<
69            (JournalId, AccountId),
70            broadcast::Sender<SqlxLedgerEvent>,
71        >::new()));
72        let journal_receivers = Arc::new(RwLock::new(HashMap::<
73            JournalId,
74            broadcast::Sender<SqlxLedgerEvent>,
75        >::new()));
76        let inner_balance_receivers = Arc::clone(&balance_receivers);
77        let inner_journal_receivers = Arc::clone(&journal_receivers);
78        let inner_closed = Arc::clone(&closed);
79        tokio::spawn(async move {
80            loop {
81                match incoming.recv().await {
82                    Ok(event) => {
83                        let journal_id = event.journal_id();
84                        if let Some(journal_receivers) =
85                            inner_journal_receivers.read().await.get(&journal_id)
86                        {
87                            let _ = journal_receivers.send(event.clone());
88                        }
89                        if let Some(account_id) = event.account_id() {
90                            let receivers = inner_balance_receivers.read().await;
91                            if let Some(receiver) = receivers.get(&(journal_id, account_id)) {
92                                let _ = receiver.send(event);
93                            }
94                        }
95                    }
96                    Err(RecvError::Lagged(_)) => {
97                        if close_on_lag {
98                            inner_closed.store(true, Ordering::SeqCst);
99                            inner_balance_receivers.write().await.clear();
100                            inner_journal_receivers.write().await.clear();
101                        }
102                    }
103                    Err(RecvError::Closed) => {
104                        tracing::warn!("Event subscriber closed");
105                        inner_closed.store(true, Ordering::SeqCst);
106                        inner_balance_receivers.write().await.clear();
107                        inner_journal_receivers.write().await.clear();
108                        break;
109                    }
110                }
111            }
112        });
113        Ok(Self {
114            buffer,
115            closed,
116            balance_receivers,
117            journal_receivers,
118            all,
119        })
120    }
121
122    pub fn all(&self) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
123        let recv = self.all.resubscribe();
124        if self.closed.load(Ordering::SeqCst) {
125            return Err(SqlxLedgerError::EventSubscriberClosed);
126        }
127        Ok(recv)
128    }
129
130    pub async fn account_balance(
131        &self,
132        journal_id: JournalId,
133        account_id: AccountId,
134    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
135        let mut listeners = self.balance_receivers.write().await;
136        let mut ret = None;
137        let sender = listeners
138            .entry((journal_id, account_id))
139            .or_insert_with(|| {
140                let (sender, recv) = broadcast::channel(self.buffer);
141                ret = Some(recv);
142                sender
143            });
144        let ret = ret.unwrap_or_else(|| sender.subscribe());
145        if self.closed.load(Ordering::SeqCst) {
146            listeners.remove(&(journal_id, account_id));
147            return Err(SqlxLedgerError::EventSubscriberClosed);
148        }
149        Ok(ret)
150    }
151
152    pub async fn journal(
153        &self,
154        journal_id: JournalId,
155    ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
156        let mut listeners = self.journal_receivers.write().await;
157        let mut ret = None;
158        let sender = listeners.entry(journal_id).or_insert_with(|| {
159            let (sender, recv) = broadcast::channel(self.buffer);
160            ret = Some(recv);
161            sender
162        });
163        let ret = ret.unwrap_or_else(|| sender.subscribe());
164        if self.closed.load(Ordering::SeqCst) {
165            listeners.remove(&journal_id);
166            return Err(SqlxLedgerError::EventSubscriberClosed);
167        }
168        Ok(ret)
169    }
170}
171
172#[derive(
173    sqlx::Type, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Hash, Copy,
174)]
175#[serde(transparent)]
176#[sqlx(transparent)]
177pub struct SqlxLedgerEventId(i64);
178impl SqlxLedgerEventId {
179    pub const BEGIN: Self = Self(0);
180}
181
/// Representation of a ledger event.
///
/// Deserialization goes through the private `EventRaw` type (`#[serde(try_from)]`).
/// That conversion requires the `data` field to be present and decodes it according
/// to `type`.
#[derive(Debug, Clone, Deserialize)]
#[serde(try_from = "EventRaw")]
pub struct SqlxLedgerEvent {
    /// Id of the event's row in `sqlx_ledger_events`; used to order events and detect gaps.
    pub id: SqlxLedgerEventId,
    /// Decoded payload. When deserialized, its variant matches `r#type`.
    pub data: SqlxLedgerEventData,
    /// Kind of event.
    pub r#type: SqlxLedgerEventType,
    /// When the event was recorded (the `recorded_at` column).
    pub recorded_at: DateTime<Utc>,
    /// OpenTelemetry context of the `tracing` span that was current when the event was
    /// last decoded or processed.
    #[cfg(feature = "otel")]
    pub otel_context: opentelemetry::Context,
}
193
194impl SqlxLedgerEvent {
195    #[cfg(feature = "otel")]
196    fn record_otel_context(&mut self) {
197        self.otel_context = tracing::Span::current().context();
198    }
199
200    #[cfg(not(feature = "otel"))]
201    fn record_otel_context(&mut self) {}
202}
203
204impl SqlxLedgerEvent {
205    pub fn journal_id(&self) -> JournalId {
206        match &self.data {
207            SqlxLedgerEventData::BalanceUpdated(b) => b.journal_id,
208            SqlxLedgerEventData::TransactionCreated(t) => t.journal_id,
209            SqlxLedgerEventData::TransactionUpdated(t) => t.journal_id,
210        }
211    }
212
213    pub fn account_id(&self) -> Option<AccountId> {
214        match &self.data {
215            SqlxLedgerEventData::BalanceUpdated(b) => Some(b.account_id),
216            _ => None,
217        }
218    }
219}
220
/// Represents the different kinds of data that can be included in an `SqlxLedgerEvent` event.
///
/// The variant is chosen from the event's `SqlxLedgerEventType` when the event is decoded.
#[derive(Debug, Clone, Serialize, Deserialize)]
// Clippy flags a large size difference between the variants. They are kept unboxed,
// presumably because boxing would change this public enum's API — confirm before changing.
#[allow(clippy::large_enum_variant)]
pub enum SqlxLedgerEventData {
    /// Payload of a `BalanceUpdated` event; also supplies the event's account id.
    BalanceUpdated(BalanceDetails),
    /// Payload of a `TransactionCreated` event.
    TransactionCreated(Transaction),
    /// Payload of a `TransactionUpdated` event.
    TransactionUpdated(Transaction),
}
229
/// Defines possible event types for `SqlxLedgerEvent`.
///
/// Each variant names the `SqlxLedgerEventData` variant that events of this type carry.
/// Decoding an event uses this value to pick the payload type for its `data` field.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlxLedgerEventType {
    /// Payload decodes as `SqlxLedgerEventData::BalanceUpdated`.
    BalanceUpdated,
    /// Payload decodes as `SqlxLedgerEventData::TransactionCreated`.
    TransactionCreated,
    /// Payload decodes as `SqlxLedgerEventData::TransactionUpdated`.
    TransactionUpdated,
}
237
238pub(crate) async fn subscribe(
239    pool: PgPool,
240    closed: Arc<AtomicBool>,
241    buffer: usize,
242    after_id: Option<SqlxLedgerEventId>,
243) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
244    let mut listener = PgListener::connect_with(&pool).await?;
245    listener.listen("sqlx_ledger_events").await?;
246    let (snd, recv) = broadcast::channel(buffer);
247    let mut reload = after_id.is_some();
248    task::spawn(async move {
249        let mut num_errors = 0;
250        let mut last_id = after_id.unwrap_or(SqlxLedgerEventId(0));
251        loop {
252            if reload {
253                match sqlx::query!(
254                    r#"SELECT json_build_object(
255                      'id', id,
256                      'type', type,
257                      'data', data,
258                      'recorded_at', recorded_at
259                    ) AS "payload!" FROM sqlx_ledger_events WHERE id > $1 ORDER BY id"#,
260                    last_id.0
261                )
262                .fetch_all(&pool)
263                .await
264                {
265                    Ok(rows) => {
266                        num_errors = 0;
267                        for row in rows {
268                            let event: Result<SqlxLedgerEvent, _> =
269                                serde_json::from_value(row.payload);
270                            if sqlx_ledger_notification_received(event, &snd, &mut last_id, true)
271                                .is_err()
272                            {
273                                closed.store(true, Ordering::SeqCst);
274                                break;
275                            }
276                        }
277                    }
278                    Err(e) if num_errors == 0 => {
279                        tracing::error!("Error fetching events: {}", e);
280                        tokio::time::sleep(std::time::Duration::from_secs(1 << num_errors)).await;
281                        num_errors += 1;
282                        continue;
283                    }
284                    _ => {
285                        num_errors = 0;
286                        continue;
287                    }
288                }
289            }
290            if closed.load(Ordering::Relaxed) {
291                break;
292            }
293            while let Ok(notification) = listener.recv().await {
294                let event: Result<SqlxLedgerEvent, _> =
295                    serde_json::from_str(notification.payload());
296                if let Err(e) = &event {
297                    if e.to_string().contains("data field missing") {
298                        reload = true;
299                        break;
300                    }
301                }
302                match sqlx_ledger_notification_received(event, &snd, &mut last_id, !reload) {
303                    Ok(false) => break,
304                    Ok(_) => num_errors = 0,
305                    Err(_) => {
306                        closed.store(true, Ordering::SeqCst);
307                    }
308                }
309                reload = true;
310            }
311        }
312        let _ = listener.unlisten("sqlx_ledger_events").await;
313    });
314    Ok(recv)
315}
316
317#[instrument(name = "sqlx_ledger.notification_received", skip(sender), err)]
318fn sqlx_ledger_notification_received(
319    event: Result<SqlxLedgerEvent, serde_json::Error>,
320    sender: &broadcast::Sender<SqlxLedgerEvent>,
321    last_id: &mut SqlxLedgerEventId,
322    ignore_gap: bool,
323) -> Result<bool, SqlxLedgerError> {
324    let mut event = event?;
325    event.record_otel_context();
326    let id = event.id;
327    if id <= *last_id {
328        return Ok(true);
329    }
330    if !ignore_gap && last_id.0 + 1 != id.0 {
331        return Ok(false);
332    }
333    sender.send(event)?;
334    *last_id = id;
335    Ok(true)
336}
337
/// Undecoded form of an event. It is parsed from a notification payload, or from the
/// object built by the reload query's `json_build_object`.
///
/// Converting it into `SqlxLedgerEvent` fails with a "data field missing" error when
/// `data` is absent.
#[derive(Deserialize)]
struct EventRaw {
    id: SqlxLedgerEventId,
    // Raw JSON payload, decoded according to `r#type` during conversion. It defaults
    // to `None` when the field is absent from the input.
    #[serde(default)]
    data: Option<serde_json::Value>,
    r#type: SqlxLedgerEventType,
    recorded_at: DateTime<Utc>,
}
346
347impl TryFrom<EventRaw> for SqlxLedgerEvent {
348    type Error = serde_json::Error;
349
350    fn try_from(value: EventRaw) -> Result<Self, Self::Error> {
351        let data_value = value
352            .data
353            .ok_or_else(|| serde_json::Error::custom("data field missing"))?;
354
355        let data = match value.r#type {
356            SqlxLedgerEventType::BalanceUpdated => {
357                SqlxLedgerEventData::BalanceUpdated(serde_json::from_value(data_value)?)
358            }
359            SqlxLedgerEventType::TransactionCreated => {
360                SqlxLedgerEventData::TransactionCreated(serde_json::from_value(data_value)?)
361            }
362            SqlxLedgerEventType::TransactionUpdated => {
363                SqlxLedgerEventData::TransactionUpdated(serde_json::from_value(data_value)?)
364            }
365        };
366
367        Ok(SqlxLedgerEvent {
368            id: value.id,
369            data,
370            r#type: value.r#type,
371            recorded_at: value.recorded_at,
372            #[cfg(feature = "otel")]
373            otel_context: tracing::Span::current().context(),
374        })
375    }
376}