1use std::{collections::HashMap, pin::Pin};
31
32use futures_core::Stream;
33use serde::{Serialize, de::DeserializeOwned};
34use thiserror::Error;
35use tokio::task::JoinHandle;
36use tokio_stream::StreamExt as _;
37
38use crate::{
39 event::EventDecodeError,
40 projection::{HandlerError, Projection, ProjectionFilters},
41 snapshot::{Snapshot, SnapshotStore},
42 store::{EventFilter, EventStore, GloballyOrderedStore, StoredEvent},
43};
44
45pub type EventStream<'a, S> = Pin<
48 Box<
49 dyn Stream<
50 Item = Result<
51 StoredEvent<
52 <S as EventStore>::Id,
53 <S as EventStore>::Position,
54 <S as EventStore>::Data,
55 <S as EventStore>::Metadata,
56 >,
57 <S as EventStore>::Error,
58 >,
59 > + Send
60 + 'a,
61 >,
62>;
63
64pub trait SubscribableStore: EventStore + GloballyOrderedStore {
75 fn subscribe(
88 &self,
89 filters: &[EventFilter<Self::Id, Self::Position>],
90 from_position: Option<Self::Position>,
91 ) -> EventStream<'_, Self>
92 where
93 Self::Position: Ord;
94}
95
96#[derive(Debug, Error)]
98pub enum SubscriptionError<StoreError>
99where
100 StoreError: std::error::Error + 'static,
101{
102 #[error("store error: {0}")]
104 Store(#[source] StoreError),
105 #[error("failed to decode event: {0}")]
107 EventDecode(#[source] EventDecodeError<StoreError>),
108 #[error("subscription ended before catch-up completed")]
110 CatchupInterrupted,
111 #[error("subscription task panicked")]
113 TaskPanicked,
114}
115
116pub struct SubscriptionHandle<StoreError>
123where
124 StoreError: std::error::Error + 'static,
125{
126 stop_tx: Option<tokio::sync::oneshot::Sender<()>>,
127 task: Option<JoinHandle<Result<(), SubscriptionError<StoreError>>>>,
128}
129
130impl<StoreError> SubscriptionHandle<StoreError>
131where
132 StoreError: std::error::Error + 'static,
133{
134 #[allow(clippy::missing_panics_doc)]
140 pub async fn stop(mut self) -> Result<(), SubscriptionError<StoreError>> {
141 if let Some(tx) = self.stop_tx.take() {
143 let _ = tx.send(());
144 }
145
146 if let Some(task) = self.task.take() {
148 return task.await.map_err(|_| SubscriptionError::TaskPanicked)?;
149 }
150
151 Ok(())
152 }
153
154 #[must_use]
156 pub fn is_running(&self) -> bool {
157 self.task.as_ref().is_some_and(|task| !task.is_finished())
158 }
159}
160
161impl<StoreError> Drop for SubscriptionHandle<StoreError>
162where
163 StoreError: std::error::Error + 'static,
164{
165 fn drop(&mut self) {
166 if self.is_running() {
167 tracing::warn!(
168 "subscription handle dropped without stop(); signaling background task to stop"
169 );
170 if let Some(tx) = self.stop_tx.take() {
171 let _ = tx.send(());
172 }
173 }
174 }
175}
176
/// Boxed, thread-safe callback that receives a shared reference to the
/// projection `P`.
type UpdateCallback<P> = Box<dyn Fn(&P) + Send + Sync + 'static>;
179
180pub struct SubscriptionBuilder<S, P, SS>
189where
190 S: EventStore,
191 P: ProjectionFilters,
192{
193 store: S,
194 snapshots: SS,
195 instance_id: P::InstanceId,
196 on_update: Option<UpdateCallback<P>>,
197}
198
199impl<S, P, SS> SubscriptionBuilder<S, P, SS>
200where
201 S: SubscribableStore + Clone + Send + Sync + 'static,
202 S::Position: Ord + Send + Sync,
203 S::Data: Send,
204 S::Metadata: Send + Sync,
205 P: Projection
206 + ProjectionFilters<Id = S::Id, Metadata = S::Metadata>
207 + Serialize
208 + DeserializeOwned
209 + Send
210 + Sync
211 + 'static,
212 P::InstanceId: Clone + Send + Sync + 'static,
213 P::Metadata: Send,
214 SS: SnapshotStore<P::InstanceId, Position = S::Position> + Send + Sync + 'static,
215{
216 pub(crate) fn new(store: S, snapshots: SS, instance_id: P::InstanceId) -> Self {
217 Self {
218 store,
219 snapshots,
220 instance_id,
221 on_update: None,
222 }
223 }
224
225 #[must_use]
231 pub fn on_update<F>(mut self, callback: F) -> Self
232 where
233 F: Fn(&P) + Send + Sync + 'static,
234 {
235 self.on_update = Some(Box::new(callback));
236 self
237 }
238
239 #[allow(clippy::too_many_lines)]
254 pub async fn start(self) -> Result<SubscriptionHandle<S::Error>, SubscriptionError<S::Error>> {
255 let Self {
256 store,
257 snapshots,
258 instance_id,
259 on_update,
260 } = self;
261
262 let (mut projection, snapshot_position) =
263 load_snapshot::<P, SS>(&snapshots, &instance_id).await;
264
265 let filters = P::filters::<S>(&instance_id);
267 let (event_filters, handlers) = filters.into_event_filters(snapshot_position.as_ref());
268
269 let current_events = store
270 .load_events(&event_filters)
271 .await
272 .map_err(SubscriptionError::Store)?;
273
274 let catchup_target_position = current_events.last().map(|e| e.position.clone());
275
276 let mut last_position = snapshot_position;
278 let mut events_since_snapshot: u64 = 0;
279
280 for stored in ¤t_events {
281 process_subscription_event(
282 &mut projection,
283 stored,
284 &handlers,
285 &store,
286 on_update.as_ref(),
287 &mut last_position,
288 &mut events_since_snapshot,
289 )?;
290 }
291
292 if events_since_snapshot > 0
294 && let Some(ref pos) = last_position
295 && offer_projection_snapshot(
296 &snapshots,
297 &instance_id,
298 events_since_snapshot,
299 pos,
300 &projection,
301 )
302 .await
303 {
304 events_since_snapshot = 0;
305 }
306
307 let (stop_tx, mut stop_rx) = tokio::sync::oneshot::channel();
309 let (ready_tx, ready_rx) = tokio::sync::oneshot::channel();
310
311 let task = tokio::spawn(async move {
312 let mut ready_tx = Some(ready_tx);
313
314 let signal_ready = |ready_tx: &mut Option<tokio::sync::oneshot::Sender<()>>| {
315 if let Some(tx) = ready_tx.take() {
316 let _ = tx.send(());
317 }
318 };
319
320 let filters = P::filters::<S>(&instance_id);
322 let (live_filters, handlers) = filters.into_event_filters(last_position.as_ref());
323
324 let mut stream = store.subscribe(&live_filters, last_position.clone());
325
326 let catchup_target = store
332 .load_events(&live_filters)
333 .await
334 .map_err(SubscriptionError::Store)?
335 .last()
336 .map(|e| e.position.clone())
337 .or(catchup_target_position);
338
339 if catchup_target.is_none() || last_position >= catchup_target {
341 signal_ready(&mut ready_tx);
342 }
343
344 loop {
345 tokio::select! {
346 biased;
347 _ = &mut stop_rx => {
348 tracing::debug!("subscription stopped");
349 break;
350 }
351 event = stream.next() => {
352 let Some(result) = event else {
353 tracing::debug!("subscription stream ended");
354 break;
355 };
356
357 let stored = result.map_err(SubscriptionError::Store)?;
358
359 if let Some(ref lp) = last_position
361 && stored.position <= *lp
362 {
363 continue;
364 }
365
366 process_subscription_event(
367 &mut projection,
368 &stored,
369 &handlers,
370 &store,
371 on_update.as_ref(),
372 &mut last_position,
373 &mut events_since_snapshot,
374 )?;
375
376 if catchup_target.is_none() || last_position >= catchup_target {
379 signal_ready(&mut ready_tx);
380 }
381
382 if events_since_snapshot.is_multiple_of(100)
384 && let Some(ref pos) = last_position
385 && offer_projection_snapshot(
386 &snapshots,
387 &instance_id,
388 events_since_snapshot,
389 pos,
390 &projection,
391 )
392 .await
393 {
394 events_since_snapshot = 0;
395 }
396 }
397 }
398 }
399
400 if events_since_snapshot > 0
402 && let Some(ref pos) = last_position
403 {
404 let _ = offer_projection_snapshot(
405 &snapshots,
406 &instance_id,
407 events_since_snapshot,
408 pos,
409 &projection,
410 )
411 .await;
412 }
413
414 Ok(())
415 });
416
417 match ready_rx.await {
418 Ok(()) => Ok(SubscriptionHandle {
419 stop_tx: Some(stop_tx),
420 task: Some(task),
421 }),
422 Err(_) => match task.await {
423 Ok(Ok(())) => Err(SubscriptionError::CatchupInterrupted),
424 Ok(Err(error)) => Err(error),
425 Err(_) => Err(SubscriptionError::TaskPanicked),
426 },
427 }
428 }
429}
430
431async fn load_snapshot<P, SS>(
432 snapshots: &SS,
433 instance_id: &P::InstanceId,
434) -> (P, Option<SS::Position>)
435where
436 P: Projection + ProjectionFilters + DeserializeOwned,
437 P::InstanceId: Sync,
438 SS: SnapshotStore<P::InstanceId>,
439{
440 let snapshot_result = snapshots
441 .load::<P>(P::KIND, instance_id)
442 .await
443 .inspect_err(|e| {
444 tracing::error!(error = %e, "failed to load subscription snapshot");
445 })
446 .ok()
447 .flatten();
448
449 if let Some(snapshot) = snapshot_result {
450 (snapshot.data, Some(snapshot.position))
451 } else {
452 (P::init(instance_id), None)
453 }
454}
455
456fn apply_handler<P, S>(
457 handler: &crate::projection::EventHandler<P, S>,
458 projection: &mut P,
459 stored: &StoredEvent<S::Id, S::Position, S::Data, S::Metadata>,
460 store: &S,
461) -> Result<(), SubscriptionError<S::Error>>
462where
463 P: ProjectionFilters<Id = S::Id>,
464 S: EventStore,
465{
466 (handler)(
467 projection,
468 stored.aggregate_id(),
469 stored,
470 stored.metadata(),
471 store,
472 )
473 .map_err(|error| match error {
474 HandlerError::EventDecode(error) => SubscriptionError::EventDecode(error),
475 HandlerError::Store(error) => {
476 SubscriptionError::EventDecode(EventDecodeError::Store(error))
477 }
478 })
479}
480
481fn process_subscription_event<P, S>(
482 projection: &mut P,
483 stored: &StoredEvent<S::Id, S::Position, S::Data, S::Metadata>,
484 handlers: &HashMap<&'static str, crate::projection::EventHandler<P, S>>,
485 store: &S,
486 on_update: Option<&UpdateCallback<P>>,
487 last_position: &mut Option<S::Position>,
488 events_since_snapshot: &mut u64,
489) -> Result<(), SubscriptionError<S::Error>>
490where
491 P: ProjectionFilters<Id = S::Id>,
492 S: EventStore,
493 S::Position: Clone,
494{
495 if let Some(handler) = handlers.get(stored.kind()) {
496 apply_handler(handler, projection, stored, store)?;
497 }
498
499 *last_position = Some(stored.position());
500 *events_since_snapshot += 1;
501
502 if let Some(callback) = on_update {
503 callback(projection);
504 }
505
506 Ok(())
507}
508
509async fn offer_projection_snapshot<P, SS>(
510 snapshots: &SS,
511 instance_id: &P::InstanceId,
512 events_since_snapshot: u64,
513 position: &SS::Position,
514 projection: &P,
515) -> bool
516where
517 P: Projection + ProjectionFilters + Serialize + Sync,
518 P::InstanceId: Sync,
519 SS: SnapshotStore<P::InstanceId>,
520 SS::Position: Clone,
521{
522 let pos = position.clone();
523 let result = snapshots
524 .offer_snapshot(
525 P::KIND,
526 instance_id,
527 events_since_snapshot,
528 move || -> Result<Snapshot<SS::Position, &P>, std::convert::Infallible> {
529 Ok(Snapshot {
530 position: pos,
531 data: projection,
532 })
533 },
534 )
535 .await;
536
537 match result {
538 Ok(crate::snapshot::SnapshotOffer::Stored) => true,
539 Ok(crate::snapshot::SnapshotOffer::Declined) => false,
540 Err(e) => {
541 tracing::warn!(error = %e, "failed to store subscription snapshot");
542 false
543 }
544 }
545}
546
#[cfg(test)]
mod tests {
    use std::{error::Error, io};

    use super::*;

    /// The `Store` variant renders its prefix and exposes the inner error as source.
    #[test]
    fn subscription_error_store_displays() {
        let inner = io::Error::other("test");
        let wrapped: SubscriptionError<io::Error> = SubscriptionError::Store(inner);

        let rendered = wrapped.to_string();
        assert!(rendered.contains("store error"));
        assert!(wrapped.source().is_some());
    }

    /// The `TaskPanicked` variant mentions the panic in its message.
    #[test]
    fn subscription_error_task_panicked_displays() {
        let panicked: SubscriptionError<io::Error> = SubscriptionError::TaskPanicked;

        let rendered = panicked.to_string();
        assert!(rendered.contains("panicked"));
    }

    /// A handle with no task attached reports that it is not running.
    #[test]
    fn subscription_not_alive_after_stop_consumes_task_handle() {
        let detached: SubscriptionHandle<io::Error> = SubscriptionHandle {
            stop_tx: None,
            task: None,
        };

        assert!(!detached.is_running());
    }
}
574}