// tokio_process_tools/output_stream/backend/broadcast/subscription.rs

1use super::state::{BestEffortLiveQueue, IndexedEvent, Shared, SubscriberId};
2use crate::output_stream::Subscription;
3use crate::output_stream::event::StreamEvent;
4use crate::output_stream::policy::{Delivery, LossyWithoutBackpressure, NoReplay, Replay};
5use std::collections::VecDeque;
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use tokio::sync::broadcast;
10use tokio::sync::broadcast::error::RecvError;
11use tokio::sync::mpsc;
12
13#[derive(Debug)]
14pub(super) struct FastSubscription {
15    pub(super) receiver: broadcast::Receiver<StreamEvent>,
16    pub(super) emit_terminal_event: Option<StreamEvent>,
17}
18
19impl FastSubscription {
20    pub(super) async fn recv(&mut self) -> Option<StreamEvent> {
21        if let Some(event) = self.emit_terminal_event.take() {
22            return Some(event);
23        }
24
25        match self.receiver.recv().await {
26            Ok(event) => Some(event),
27            Err(RecvError::Closed) => None,
28            Err(RecvError::Lagged(lagged)) => {
29                tracing::warn!(lagged, "Broadcast subscriber is lagging behind");
30                Some(StreamEvent::Gap)
31            }
32        }
33    }
34}
35
36#[derive(Debug)]
37pub(super) enum LiveReceiver {
38    Reliable(mpsc::Receiver<IndexedEvent>),
39    BestEffort(Arc<BestEffortLiveQueue>),
40    Closed,
41}
42
43impl LiveReceiver {
44    async fn recv(&mut self) -> Option<IndexedEvent> {
45        match self {
46            Self::Reliable(receiver) => receiver.recv().await,
47            Self::BestEffort(queue) => queue.recv().await,
48            Self::Closed => None,
49        }
50    }
51}
52
53#[derive(Debug)]
54pub(super) struct SharedSubscription<D = LossyWithoutBackpressure, R = NoReplay>
55where
56    D: Delivery,
57    R: Replay,
58{
59    pub(super) shared: Arc<Shared>,
60    pub(super) id: Option<SubscriberId>,
61    pub(super) replay: VecDeque<IndexedEvent>,
62    pub(super) live_start_seq: u64,
63    pub(super) live_receiver: LiveReceiver,
64    pub(super) _marker: PhantomData<fn() -> (D, R)>,
65    pub(super) done: bool,
66}
67
68impl<D, R> Drop for SharedSubscription<D, R>
69where
70    D: Delivery,
71    R: Replay,
72{
73    fn drop(&mut self) {
74        if !self.done
75            && let Some(id) = self.id.take()
76        {
77            let mut state = self.shared.state.lock().expect("broadcast state poisoned");
78            state.remove_subscriber(id);
79        }
80    }
81}
82
83impl<D, R> SharedSubscription<D, R>
84where
85    D: Delivery,
86    R: Replay,
87{
88    pub(super) async fn recv(&mut self) -> Option<StreamEvent> {
89        if let Some(event) = self.replay.pop_front() {
90            if matches!(event.event, StreamEvent::Eof | StreamEvent::ReadError(_)) {
91                self.detach();
92            }
93            return Some(event.event);
94        }
95
96        loop {
97            let event = self.live_receiver.recv().await?;
98            if event.seq < self.live_start_seq {
99                continue;
100            }
101            if matches!(event.event, StreamEvent::Eof | StreamEvent::ReadError(_)) {
102                self.detach();
103            }
104            return Some(event.event);
105        }
106    }
107
108    fn detach(&mut self) {
109        if let Some(id) = self.id.take() {
110            let mut state = self.shared.state.lock().expect("broadcast state poisoned");
111            state.remove_subscriber(id);
112        }
113        self.done = true;
114    }
115}
116
117/// Subscription handle returned by
118/// [`BroadcastOutputStream::try_subscribe`](crate::BroadcastOutputStream).
119/// Treat this as an opaque value: pass it to a built-in consumer or your own
120/// [`Subscription`]-driven loop. The internal representation is not part of the public API.
121#[derive(Debug)]
122pub struct BroadcastSubscription<D = LossyWithoutBackpressure, R = NoReplay>
123where
124    D: Delivery,
125    R: Replay,
126{
127    inner: BroadcastSubscriptionInner<D, R>,
128}
129
130#[derive(Debug)]
131enum BroadcastSubscriptionInner<D, R>
132where
133    D: Delivery,
134    R: Replay,
135{
136    Fast(FastSubscription),
137    Shared(SharedSubscription<D, R>),
138}
139
140impl<D, R> BroadcastSubscription<D, R>
141where
142    D: Delivery,
143    R: Replay,
144{
145    pub(super) fn fast(subscription: FastSubscription) -> Self {
146        Self {
147            inner: BroadcastSubscriptionInner::Fast(subscription),
148        }
149    }
150
151    pub(super) fn shared(subscription: SharedSubscription<D, R>) -> Self {
152        Self {
153            inner: BroadcastSubscriptionInner::Shared(subscription),
154        }
155    }
156
157    pub(super) async fn recv(&mut self) -> Option<StreamEvent> {
158        match &mut self.inner {
159            BroadcastSubscriptionInner::Fast(subscription) => subscription.recv().await,
160            BroadcastSubscriptionInner::Shared(subscription) => subscription.recv().await,
161        }
162    }
163}
164
165impl<D, R> Subscription for BroadcastSubscription<D, R>
166where
167    D: Delivery,
168    R: Replay,
169{
170    #[allow(
171        clippy::manual_async_fn,
172        reason = "the trait method must expose a Send future for tokio::spawn"
173    )]
174    fn next_event(&mut self) -> impl Future<Output = Option<StreamEvent>> + Send + '_ {
175        async move { self.recv().await }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::super::state::{SubscriberSender, append_event};
    use super::*;
    use crate::StreamReadError;
    use crate::{
        NumBytesExt, ReliableWithBackpressure, ReplayEnabled, ReplayRetention, StreamConfig,
    };
    use assertr::prelude::*;
    use std::io;

    /// Stream name attached to the read errors injected by these tests.
    const ERROR_STREAM: &str = "custom";

    /// Lossy config with replay, 3-byte reads and room for a single buffered chunk.
    fn best_effort_options(
        retention: ReplayRetention,
    ) -> StreamConfig<LossyWithoutBackpressure, ReplayEnabled> {
        let lossy = StreamConfig::builder().lossy_without_backpressure();
        let replaying = match retention {
            ReplayRetention::LastChunks(chunks) => lossy.replay_last_chunks(chunks),
            ReplayRetention::LastBytes(bytes) => lossy.replay_last_bytes(bytes),
            ReplayRetention::All => lossy.replay_all(),
        };
        replaying
            .read_chunk_size(3.bytes())
            .max_buffered_chunks(1)
            .build()
    }

    /// Reliable config without replay, 1-byte reads and room for a single buffered chunk.
    fn reliable_no_replay_options() -> StreamConfig<ReliableWithBackpressure, NoReplay> {
        let reliable = StreamConfig::builder().reliable_with_backpressure();
        reliable
            .no_replay()
            .read_chunk_size(1.bytes())
            .max_buffered_chunks(1)
            .build()
    }

    /// Reliable config with replay, 1-byte reads and room for four buffered chunks.
    fn reliable_replay_options(
        retention: ReplayRetention,
    ) -> StreamConfig<ReliableWithBackpressure, ReplayEnabled> {
        let reliable = StreamConfig::builder().reliable_with_backpressure();
        let replaying = match retention {
            ReplayRetention::LastChunks(chunks) => reliable.replay_last_chunks(chunks),
            ReplayRetention::LastBytes(bytes) => reliable.replay_last_bytes(bytes),
            ReplayRetention::All => reliable.replay_all(),
        };
        replaying
            .read_chunk_size(1.bytes())
            .max_buffered_chunks(4)
            .build()
    }

    /// Builds a `SharedSubscription` against `shared`.
    ///
    /// The live channel type follows the config's delivery guarantee. The replay snapshot
    /// is taken under the state lock, and the sender is registered only if the stream has
    /// neither closed nor recorded a terminal event.
    fn subscribe<D, R>(
        shared: &Arc<Shared>,
        options: StreamConfig<D, R>,
    ) -> SharedSubscription<D, R>
    where
        D: Delivery,
        R: Replay,
    {
        let capacity = options.max_buffered_chunks;
        let (sender, live_receiver) = match options.delivery_guarantee() {
            crate::DeliveryGuarantee::ReliableWithBackpressure => {
                let (tx, rx) = mpsc::channel(capacity);
                (SubscriberSender::Reliable(tx), LiveReceiver::Reliable(rx))
            }
            crate::DeliveryGuarantee::LossyWithoutBackpressure => {
                let queue = Arc::new(BestEffortLiveQueue::new(capacity));
                let queue_sender = SubscriberSender::BestEffort(Arc::clone(&queue));
                (queue_sender, LiveReceiver::BestEffort(queue))
            }
        };

        // The lock guard lives only inside this block.
        let (replay, live_start_seq, id) = {
            let mut state = shared.state.lock().expect("broadcast state poisoned");
            let (replay, live_start_seq) = state.replay_snapshot(options);
            let already_finished = state.closed || state.terminal.is_some();
            let id = if already_finished {
                None
            } else {
                Some(state.add_subscriber(sender))
            };
            (replay, live_start_seq, id)
        };

        SharedSubscription {
            shared: Arc::clone(shared),
            id,
            replay,
            live_start_seq,
            live_receiver,
            _marker: PhantomData,
            done: false,
        }
    }

    /// The `BrokenPipe` read error on `ERROR_STREAM` injected by the read-error tests.
    fn broken_pipe_event() -> StreamEvent {
        StreamEvent::ReadError(StreamReadError::new(
            ERROR_STREAM,
            io::Error::from(io::ErrorKind::BrokenPipe),
        ))
    }

    /// Receives the next event and asserts it equals `expected`.
    async fn assert_next_event<D, R>(
        subscription: &mut SharedSubscription<D, R>,
        expected: StreamEvent,
    ) where
        D: Delivery,
        R: Replay,
    {
        assert_that!(subscription.recv().await)
            .is_some()
            .is_equal_to(expected);
    }

    /// Receives the next event and asserts it is a chunk carrying exactly `expected`.
    async fn assert_next_chunk<D, R>(
        subscription: &mut SharedSubscription<D, R>,
        expected: &'static [u8],
    ) where
        D: Delivery,
        R: Replay,
    {
        match subscription.recv().await {
            Some(StreamEvent::Chunk(received)) => {
                assert_that!(received.as_ref()).is_equal_to(expected);
            }
            unexpected => {
                assert_that!(&unexpected)
                    .fail(format_args!("expected chunk, got {unexpected:?}"));
            }
        }
    }

    /// Receives the next event and asserts it is the error produced by `broken_pipe_event`.
    async fn assert_next_broken_pipe<D, R>(subscription: &mut SharedSubscription<D, R>)
    where
        D: Delivery,
        R: Replay,
    {
        match subscription.recv().await {
            Some(StreamEvent::ReadError(err)) => {
                assert_that!(err.stream_name()).is_equal_to(ERROR_STREAM);
                assert_that!(err.kind()).is_equal_to(io::ErrorKind::BrokenPipe);
            }
            unexpected => {
                assert_that!(&unexpected)
                    .fail(format_args!("expected read error, got {unexpected:?}"));
            }
        }
    }

    #[tokio::test]
    async fn slow_best_effort_subscriber_observes_gap_then_newer_tail() {
        let options = best_effort_options(ReplayRetention::LastChunks(1));
        let shared = Arc::new(Shared::new());
        let mut subscription = subscribe(&shared, options);

        for event in [
            StreamEvent::chunk(b"old"),
            StreamEvent::chunk(b"new"),
            StreamEvent::Eof,
        ] {
            append_event(&shared, options, event).await;
        }

        assert_next_event(&mut subscription, StreamEvent::Gap).await;
        assert_next_event(&mut subscription, StreamEvent::Eof).await;
    }

    #[tokio::test]
    async fn eof_is_replayed_to_late_subscribers_before_seal() {
        let options = reliable_replay_options(ReplayRetention::All);
        let shared = Arc::new(Shared::new());

        for event in [StreamEvent::chunk(b"tail"), StreamEvent::Eof] {
            append_event(&shared, options, event).await;
        }

        let mut subscription = subscribe(&shared, options);
        assert_next_chunk(&mut subscription, b"tail").await;
        assert_next_event(&mut subscription, StreamEvent::Eof).await;
    }

    #[tokio::test]
    async fn no_replay_late_subscriber_observes_terminal_read_error() {
        let options = reliable_no_replay_options();
        let shared = Arc::new(Shared::new());

        for event in [StreamEvent::chunk(b"booting\n"), broken_pipe_event()] {
            append_event(&shared, options, event).await;
        }

        let mut subscription = subscribe(&shared, options);
        assert_next_broken_pipe(&mut subscription).await;
    }

    #[tokio::test]
    async fn replay_late_subscriber_observes_retained_output_then_read_error() {
        let options = reliable_replay_options(ReplayRetention::All);
        let shared = Arc::new(Shared::new());

        for event in [StreamEvent::chunk(b"booting\npartial"), broken_pipe_event()] {
            append_event(&shared, options, event).await;
        }

        let mut subscription = subscribe(&shared, options);
        assert_next_chunk(&mut subscription, b"booting\npartial").await;
        assert_next_broken_pipe(&mut subscription).await;
    }

    #[tokio::test]
    async fn active_subscription_does_not_duplicate_live_handoff() {
        let options = reliable_replay_options(ReplayRetention::All);
        let shared = Arc::new(Shared::new());

        append_event(&shared, options, StreamEvent::chunk(b"replay")).await;
        let mut subscription = subscribe(&shared, options);
        for event in [StreamEvent::chunk(b"live"), StreamEvent::Eof] {
            append_event(&shared, options, event).await;
        }

        assert_next_chunk(&mut subscription, b"replay").await;
        assert_next_chunk(&mut subscription, b"live").await;
        assert_next_event(&mut subscription, StreamEvent::Eof).await;
    }
}