1use chrono::{DateTime, Utc};
3use futures::StreamExt;
4use serde::{de::Error as _, Deserialize, Serialize};
5use sqlx::{postgres::PgListener, PgPool};
6use std::time::Duration;
7use tokio::{
8 sync::{
9 broadcast::{self, error::RecvError},
10 RwLock,
11 },
12 task,
13};
14use tracing::instrument;
15#[cfg(feature = "otel")]
16use tracing_opentelemetry::OpenTelemetrySpanExt;
17
18use std::{
19 collections::HashMap,
20 sync::{
21 atomic::{AtomicBool, Ordering},
22 Arc,
23 },
24};
25
26use crate::{
27 balance::BalanceDetails, transaction::Transaction, AccountId, JournalId, SqlxLedgerError,
28};
29
/// Options controlling how an `EventSubscriber` connects and buffers events.
pub struct EventSubscriberOpts {
    /// When `true`, mark the subscriber closed (rejecting all further
    /// subscriptions) if the internal broadcast channel lags and drops events.
    pub close_on_lag: bool,
    /// Capacity of each broadcast channel handed to subscribers.
    pub buffer: usize,
    /// If set, replay stored events with an id strictly greater than this
    /// before switching to live notifications; `None` starts live-only.
    pub after_id: Option<SqlxLedgerEventId>,
}
36impl Default for EventSubscriberOpts {
37 fn default() -> Self {
38 Self {
39 close_on_lag: false,
40 buffer: 100,
41 after_id: None,
42 }
43 }
44}
45
/// Fans ledger events out to subscribers: a firehose channel plus lazily
/// created per-journal and per-(journal, account) broadcast channels.
#[derive(Debug, Clone)]
pub struct EventSubscriber {
    // Capacity used when creating per-key broadcast channels.
    buffer: usize,
    // Set once the underlying subscription has shut down (stream closed, or
    // lagged with `close_on_lag`); subscription methods then return
    // `SqlxLedgerError::EventSubscriberClosed`.
    closed: Arc<AtomicBool>,
    // Senders keyed by (journal, account); fed by the fan-out task in `connect`.
    #[allow(clippy::type_complexity)]
    balance_receivers:
        Arc<RwLock<HashMap<(JournalId, AccountId), broadcast::Sender<SqlxLedgerEvent>>>>,
    // Senders keyed by journal; fed by the same fan-out task.
    journal_receivers: Arc<RwLock<HashMap<JournalId, broadcast::Sender<SqlxLedgerEvent>>>>,
    // Kept alive so `all()` can resubscribe to hand out firehose receivers.
    all: Arc<broadcast::Receiver<SqlxLedgerEvent>>,
}
57
58impl EventSubscriber {
59 pub(crate) async fn connect(
60 pool: &PgPool,
61 EventSubscriberOpts {
62 close_on_lag,
63 buffer,
64 after_id: start_id,
65 }: EventSubscriberOpts,
66 ) -> Result<Self, SqlxLedgerError> {
67 let closed = Arc::new(AtomicBool::new(false));
68 let mut incoming = subscribe(pool.clone(), Arc::clone(&closed), buffer, start_id).await?;
69 let all = Arc::new(incoming.resubscribe());
70 let balance_receivers = Arc::new(RwLock::new(HashMap::<
71 (JournalId, AccountId),
72 broadcast::Sender<SqlxLedgerEvent>,
73 >::new()));
74 let journal_receivers = Arc::new(RwLock::new(HashMap::<
75 JournalId,
76 broadcast::Sender<SqlxLedgerEvent>,
77 >::new()));
78 let inner_balance_receivers = Arc::clone(&balance_receivers);
79 let inner_journal_receivers = Arc::clone(&journal_receivers);
80 let inner_closed = Arc::clone(&closed);
81 tokio::spawn(async move {
82 loop {
83 match incoming.recv().await {
84 Ok(event) => {
85 let journal_id = event.journal_id();
86 if let Some(journal_receivers) =
87 inner_journal_receivers.read().await.get(&journal_id)
88 {
89 let _ = journal_receivers.send(event.clone());
90 }
91 if let Some(account_id) = event.account_id() {
92 let receivers = inner_balance_receivers.read().await;
93 if let Some(receiver) = receivers.get(&(journal_id, account_id)) {
94 let _ = receiver.send(event);
95 }
96 }
97 }
98 Err(RecvError::Lagged(_)) => {
99 if close_on_lag {
100 inner_closed.store(true, Ordering::SeqCst);
101 inner_balance_receivers.write().await.clear();
102 inner_journal_receivers.write().await.clear();
103 }
104 }
105 Err(RecvError::Closed) => {
106 tracing::warn!("Event subscriber closed");
107 inner_closed.store(true, Ordering::SeqCst);
108 inner_balance_receivers.write().await.clear();
109 inner_journal_receivers.write().await.clear();
110 break;
111 }
112 }
113 }
114 });
115 Ok(Self {
116 buffer,
117 closed,
118 balance_receivers,
119 journal_receivers,
120 all,
121 })
122 }
123
124 pub fn all(&self) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
125 let recv = self.all.resubscribe();
126 if self.closed.load(Ordering::SeqCst) {
127 return Err(SqlxLedgerError::EventSubscriberClosed);
128 }
129 Ok(recv)
130 }
131
132 pub async fn account_balance(
133 &self,
134 journal_id: JournalId,
135 account_id: AccountId,
136 ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
137 let mut listeners = self.balance_receivers.write().await;
138 let mut ret = None;
139 let sender = listeners
140 .entry((journal_id, account_id))
141 .or_insert_with(|| {
142 let (sender, recv) = broadcast::channel(self.buffer);
143 ret = Some(recv);
144 sender
145 });
146 let ret = ret.unwrap_or_else(|| sender.subscribe());
147 if self.closed.load(Ordering::SeqCst) {
148 listeners.remove(&(journal_id, account_id));
149 return Err(SqlxLedgerError::EventSubscriberClosed);
150 }
151 Ok(ret)
152 }
153
154 pub async fn journal(
155 &self,
156 journal_id: JournalId,
157 ) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
158 let mut listeners = self.journal_receivers.write().await;
159 let mut ret = None;
160 let sender = listeners.entry(journal_id).or_insert_with(|| {
161 let (sender, recv) = broadcast::channel(self.buffer);
162 ret = Some(recv);
163 sender
164 });
165 let ret = ret.unwrap_or_else(|| sender.subscribe());
166 if self.closed.load(Ordering::SeqCst) {
167 listeners.remove(&journal_id);
168 return Err(SqlxLedgerError::EventSubscriberClosed);
169 }
170 Ok(ret)
171 }
172}
173
174#[derive(
175 sqlx::Type, Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Hash, Copy,
176)]
177#[serde(transparent)]
178#[sqlx(transparent)]
179pub struct SqlxLedgerEventId(i64);
180impl SqlxLedgerEventId {
181 pub const BEGIN: Self = Self(0);
182}
183
/// A single ledger event as delivered to subscribers; deserialized from the
/// JSON payload via the intermediate `EventRaw` shape.
#[derive(Debug, Clone, Deserialize)]
#[serde(try_from = "EventRaw")]
pub struct SqlxLedgerEvent {
    /// Sequential event id (consumed for duplicate/gap detection).
    pub id: SqlxLedgerEventId,
    /// The typed payload.
    pub data: SqlxLedgerEventData,
    /// Discriminant mirroring which `data` variant is populated.
    pub r#type: SqlxLedgerEventType,
    /// Timestamp carried in the event payload.
    pub recorded_at: DateTime<Utc>,
    /// Tracing context attached to the event (see `record_otel_context`).
    #[cfg(feature = "otel")]
    pub otel_context: opentelemetry::Context,
}
195
impl SqlxLedgerEvent {
    /// Overwrites the event's otel context with the context of the current
    /// tracing span.
    #[cfg(feature = "otel")]
    fn record_otel_context(&mut self) {
        self.otel_context = tracing::Span::current().context();
    }

    /// No-op stub so callers compile identically without the `otel` feature.
    #[cfg(not(feature = "otel"))]
    fn record_otel_context(&mut self) {}
}
205
206impl SqlxLedgerEvent {
207 pub fn journal_id(&self) -> JournalId {
208 match &self.data {
209 SqlxLedgerEventData::BalanceUpdated(b) => b.journal_id,
210 SqlxLedgerEventData::TransactionCreated(t) => t.journal_id,
211 SqlxLedgerEventData::TransactionUpdated(t) => t.journal_id,
212 }
213 }
214
215 pub fn account_id(&self) -> Option<AccountId> {
216 match &self.data {
217 SqlxLedgerEventData::BalanceUpdated(b) => Some(b.account_id),
218 _ => None,
219 }
220 }
221}
222
/// The typed payload of a ledger event.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum SqlxLedgerEventData {
    /// An account balance changed.
    BalanceUpdated(BalanceDetails),
    /// A transaction was created.
    TransactionCreated(Transaction),
    /// An existing transaction was updated.
    TransactionUpdated(Transaction),
}
231
/// Tag identifying which payload variant an event carries; used to select
/// the deserializer for the raw `data` JSON.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlxLedgerEventType {
    BalanceUpdated,
    TransactionCreated,
    TransactionUpdated,
}
239
240pub(crate) async fn subscribe(
241 pool: PgPool,
242 closed: Arc<AtomicBool>,
243 buffer: usize,
244 after_id: Option<SqlxLedgerEventId>,
245) -> Result<broadcast::Receiver<SqlxLedgerEvent>, SqlxLedgerError> {
246 let mut listener = PgListener::connect_with(&pool).await?;
247 listener.listen("sqlx_ledger_events").await?;
248 let (snd, recv) = broadcast::channel(buffer);
249 let mut reload = after_id.is_some();
250
251 task::spawn(async move {
252 let mut last_id = after_id.unwrap_or(SqlxLedgerEventId(0));
253 let mut consecutive_errors = 0;
254 const MAX_RETRY_DELAY: u64 = 60;
255 const MAX_CONSECUTIVE_ERRORS: u32 = 5;
256
257 let subscriber_loop = async {
258 loop {
259 if reload {
260 let mut stream = sqlx::query!(
261 r#"SELECT json_build_object(
262 'id', id,
263 'type', type,
264 'data', data,
265 'recorded_at', recorded_at
266 ) AS "payload!" FROM sqlx_ledger_events WHERE id > $1 ORDER BY id"#,
267 last_id.0
268 )
269 .fetch(&pool);
270
271 let mut stream_failed = false;
272 while let Some(result) = stream.next().await {
273 match result {
274 Ok(row) => {
275 consecutive_errors = 0;
276 let event: Result<SqlxLedgerEvent, _> =
277 serde_json::from_value(row.payload);
278 if sqlx_ledger_notification_received(
279 event,
280 &snd,
281 &mut last_id,
282 true,
283 )
284 .is_err()
285 {
286 return Err("channel closed");
287 }
288 }
289 Err(e) => {
290 consecutive_errors += 1;
291 tracing::error!(
292 "Error fetching events after id {}: {}",
293 last_id.0,
294 e
295 );
296
297 if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
298 tracing::error!("Max retries exceeded");
299 return Err("max retries");
300 }
301
302 let delay =
303 (1_u64 << consecutive_errors.min(6)).min(MAX_RETRY_DELAY);
304 tokio::time::sleep(Duration::from_secs(delay)).await;
305 stream_failed = true;
306 break;
307 }
308 }
309 }
310
311 if stream_failed {
312 continue;
313 }
314
315 reload = false;
316 }
317
318 if closed.load(Ordering::Relaxed) {
319 return Ok(());
320 }
321
322 tokio::select! {
323 notification = listener.recv() => {
324 match notification {
325 Ok(n) => {
326 let event: Result<SqlxLedgerEvent, _> =
327 serde_json::from_str(n.payload());
328
329 if let Err(e) = &event {
330 tracing::warn!("Failed to parse: {}", e);
331 if e.to_string().contains("data field missing") {
332 reload = true;
333 continue;
334 }
335 }
336
337 match sqlx_ledger_notification_received(event, &snd, &mut last_id, false) {
338 Ok(false) => return Err("channel closed"),
339 Ok(_) => consecutive_errors = 0,
340 Err(_) => return Err("channel closed"),
341 }
342 }
343 Err(e) => {
344 tracing::error!("Listener error: {}", e);
345 consecutive_errors += 1;
346 if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
347 return Err("max retries");
348 }
349 tokio::time::sleep(Duration::from_secs(1)).await;
350 }
351 }
352 }
353 _ = tokio::time::sleep(Duration::from_secs(30)) => {
354 if closed.load(Ordering::Relaxed) {
356 return Ok(());
357 }
358 }
359 }
360 }
361 };
362
363 let result = subscriber_loop.await;
364 match result {
365 Ok(()) => tracing::info!("Subscriber shutting down gracefully"),
366 Err(reason) => tracing::warn!("Subscriber shutting down: {}", reason),
367 }
368
369 if let Err(e) = listener.unlisten("sqlx_ledger_events").await {
370 tracing::warn!("Failed to unlisten from sqlx_ledger_events: {}", e);
371 }
372 });
373
374 Ok(recv)
375}
376
377#[instrument(name = "sqlx_ledger.notification_received", skip(sender), err)]
378fn sqlx_ledger_notification_received(
379 event: Result<SqlxLedgerEvent, serde_json::Error>,
380 sender: &broadcast::Sender<SqlxLedgerEvent>,
381 last_id: &mut SqlxLedgerEventId,
382 ignore_gap: bool,
383) -> Result<bool, SqlxLedgerError> {
384 let mut event = event?;
385 event.record_otel_context();
386 let id = event.id;
387 if id <= *last_id {
388 return Ok(true);
389 }
390 if !ignore_gap && last_id.0 + 1 != id.0 {
391 return Ok(false);
392 }
393 sender.send(event)?;
394 *last_id = id;
395 Ok(true)
396}
397
/// Wire shape of an event payload before `data` is decoded into a typed
/// variant (conversion lives in `TryFrom<EventRaw> for SqlxLedgerEvent`).
#[derive(Deserialize)]
struct EventRaw {
    id: SqlxLedgerEventId,
    // Optional + defaulted so an absent `data` key deserializes successfully
    // and surfaces later as the distinctive "data field missing" conversion
    // error, which the subscriber loop recognizes to trigger a table reload.
    #[serde(default)]
    data: Option<serde_json::Value>,
    r#type: SqlxLedgerEventType,
    recorded_at: DateTime<Utc>,
}
406
407impl TryFrom<EventRaw> for SqlxLedgerEvent {
408 type Error = serde_json::Error;
409
410 fn try_from(value: EventRaw) -> Result<Self, Self::Error> {
411 let data_value = value
412 .data
413 .ok_or_else(|| serde_json::Error::custom("data field missing"))?;
414
415 let data = match value.r#type {
416 SqlxLedgerEventType::BalanceUpdated => {
417 SqlxLedgerEventData::BalanceUpdated(serde_json::from_value(data_value)?)
418 }
419 SqlxLedgerEventType::TransactionCreated => {
420 SqlxLedgerEventData::TransactionCreated(serde_json::from_value(data_value)?)
421 }
422 SqlxLedgerEventType::TransactionUpdated => {
423 SqlxLedgerEventData::TransactionUpdated(serde_json::from_value(data_value)?)
424 }
425 };
426
427 Ok(SqlxLedgerEvent {
428 id: value.id,
429 data,
430 r#type: value.r#type,
431 recorded_at: value.recorded_at,
432 #[cfg(feature = "otel")]
433 otel_context: tracing::Span::current().context(),
434 })
435 }
436}