Skip to main content

tokio_process_tools/output_stream/visitors/
inspect.rs

1use crate::output_stream::Next;
2use crate::output_stream::event::Chunk;
3use crate::output_stream::line::adapter::{AsyncLineVisitor, LineVisitor};
4use crate::output_stream::visitor::{AsyncStreamVisitor, StreamVisitor};
5use std::borrow::Cow;
6use std::future::Future;
7use std::marker::PhantomData;
8use typed_builder::TypedBuilder;
9
10/// Synchronous [`StreamVisitor`] that calls `f` once per observed chunk and discards the
11/// chunk afterwards. Construct via [`builder`](Self::builder) and pass to
12/// [`Consumable::consume`](crate::Consumable::consume).
13#[derive(TypedBuilder)]
14pub struct InspectChunks<F>
15where
16    F: FnMut(Chunk) -> Next + Send + 'static,
17{
18    /// Per-chunk callback. Return [`Next::Break`] to stop the consumer early.
19    pub f: F,
20}
21
22impl<F> StreamVisitor for InspectChunks<F>
23where
24    F: FnMut(Chunk) -> Next + Send + 'static,
25{
26    type Output = ();
27
28    fn on_chunk(&mut self, chunk: Chunk) -> Next {
29        (self.f)(chunk)
30    }
31
32    fn into_output(self) -> Self::Output {}
33}
34
35/// Asynchronous counterpart to [`InspectChunks`]. Construct via [`builder`](Self::builder) and
36/// pass to [`Consumable::consume_async`](crate::Consumable::consume_async).
37#[derive(TypedBuilder)]
38pub struct InspectChunksAsync<F, Fut>
39where
40    F: FnMut(Chunk) -> Fut + Send + 'static,
41    Fut: Future<Output = Next> + Send + 'static,
42{
43    /// Per-chunk async callback. The returned future yields [`Next::Break`] to stop early.
44    pub f: F,
45    /// Phantom marker so the `Fut` bound lives on the struct (and on the derived builder)
46    /// rather than only on the impl block. The closure's return type carries `Fut`, so the
47    /// builder infers it from `f`. Users never spell `Fut` out.
48    #[builder(default, setter(skip))]
49    pub _fut: PhantomData<fn() -> Fut>,
50}
51
52impl<F, Fut> AsyncStreamVisitor for InspectChunksAsync<F, Fut>
53where
54    F: FnMut(Chunk) -> Fut + Send + 'static,
55    Fut: Future<Output = Next> + Send + 'static,
56{
57    type Output = ();
58
59    fn on_chunk(&mut self, chunk: Chunk) -> impl Future<Output = Next> + Send + '_ {
60        (self.f)(chunk)
61    }
62
63    fn into_output(self) -> Self::Output {}
64}
65
/// Per-line inspector: a [`LineVisitor`] backed by a closure.
///
/// Pair it with [`ParseLines`](crate::output_stream::line::adapter::ParseLines) to get a
/// consumer that inspects output line by line; the shortest route is
/// [`ParseLines::inspect`](crate::output_stream::line::adapter::ParseLines::inspect).
pub struct InspectLines<F> {
    /// Closure invoked with every parsed line.
    f: F,
}
73
74impl<F> InspectLines<F>
75where
76    F: FnMut(Cow<'_, str>) -> Next + Send + 'static,
77{
78    /// Creates a new sink that calls `f` once for each parsed line.
79    pub fn new(f: F) -> Self {
80        Self { f }
81    }
82}
83
84impl<F> LineVisitor for InspectLines<F>
85where
86    F: FnMut(Cow<'_, str>) -> Next + Send + 'static,
87{
88    type Output = ();
89
90    fn on_line(&mut self, line: Cow<'_, str>) -> Next {
91        (self.f)(line)
92    }
93
94    fn into_output(self) -> Self::Output {}
95}
96
/// Async per-line inspector: an [`AsyncLineVisitor`] backed by a closure that returns a
/// future.
///
/// Pair it with [`ParseLines`](crate::output_stream::line::adapter::ParseLines) to get an
/// async consumer that inspects output line by line; the shortest route is
/// [`ParseLines::inspect_async`](crate::output_stream::line::adapter::ParseLines::inspect_async).
pub struct InspectLinesAsync<F, Fut> {
    /// Async closure invoked with every parsed line.
    f: F,
    /// Zero-sized marker that records the closure's future type on the struct, so callers
    /// never have to name `Fut` themselves.
    _fut: PhantomData<fn() -> Fut>,
}
106
107impl<F, Fut> InspectLinesAsync<F, Fut>
108where
109    F: FnMut(Cow<'_, str>) -> Fut + Send + 'static,
110    Fut: Future<Output = Next> + Send + 'static,
111{
112    /// Creates a new sink that awaits `f` once for each parsed line.
113    pub fn new(f: F) -> Self {
114        Self {
115            f,
116            _fut: PhantomData,
117        }
118    }
119}
120
121impl<F, Fut> AsyncLineVisitor for InspectLinesAsync<F, Fut>
122where
123    F: FnMut(Cow<'_, str>) -> Fut + Send + 'static,
124    Fut: Future<Output = Next> + Send + 'static,
125{
126    type Output = ();
127
128    fn on_line<'a>(&'a mut self, line: Cow<'a, str>) -> impl Future<Output = Next> + Send + 'a {
129        (self.f)(line)
130    }
131
132    fn into_output(self) -> Self::Output {}
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output_stream::consumer::Consumer;
    use crate::output_stream::consumer::driver::spawn_consumer_sync;
    use crate::output_stream::event::StreamEvent;
    use crate::output_stream::event::tests::event_receiver;
    use crate::output_stream::line::adapter::ParseLines;
    use crate::output_stream::line::options::LineParsingOptions;
    use crate::{ConsumerCancelOutcome, ConsumerError, StreamReadError};
    use assertr::prelude::*;
    use bytes::Bytes;
    use std::io;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::sync::oneshot;

    // A consumer whose task finishes as soon as the termination signal resolves is expected
    // to report a cooperative cancellation.
    #[tokio::test]
    async fn cancel_returns_cancelled_when_cooperative() {
        let (stop_tx, stop_rx) = oneshot::channel();
        let cooperative_task = tokio::spawn(async move {
            let _signal = stop_rx.await;
            Ok(())
        });
        let consumer: Consumer<()> = Consumer {
            stream_name: "custom",
            task: Some(cooperative_task),
            task_termination_sender: Some(stop_tx),
        };

        let outcome = consumer.cancel(Duration::from_secs(1)).await.unwrap();

        let was_cancelled = matches!(outcome, ConsumerCancelOutcome::Cancelled(()));
        assert_that!(was_cancelled).is_true();
    }

    mod inspect_lines {
        use super::*;
        use crate::NumBytesExt;

        // Setting up line parsing with a zero `max_line_length` is expected to panic with
        // the message pinned below.
        #[test]
        #[should_panic(expected = "LineParsingOptions::max_line_length must be greater than zero")]
        fn panics_when_max_line_length_is_zero() {
            let rejected_options = LineParsingOptions {
                max_line_length: 0.bytes(),
                overflow_behavior: crate::LineOverflowBehavior::default(),
                buffer_compaction_threshold: None,
            };
            let line_visitor = InspectLines::new(|_line| Next::Continue);

            let _visitor = ParseLines::new(rejected_options, line_visitor);
        }

        // A read error in the event stream must surface from `wait` as
        // `ConsumerError::StreamRead`, carrying the stream name and the I/O error kind.
        #[tokio::test]
        async fn inspectors_return_stream_read_error() {
            let read_error =
                StreamReadError::new("custom", io::Error::from(io::ErrorKind::BrokenPipe));
            let events = vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"complete\npartial"))),
                StreamEvent::ReadError(read_error),
            ];
            let receiver = event_receiver(events).await;
            let visitor = ParseLines::inspect(LineParsingOptions::default(), |_line| Next::Continue);
            let consumer = spawn_consumer_sync("custom", receiver, visitor);

            let result = consumer.wait().await;
            match result {
                Err(ConsumerError::StreamRead { source }) => {
                    assert_that!(source.stream_name()).is_equal_to("custom");
                    assert_that!(source.kind()).is_equal_to(io::ErrorKind::BrokenPipe);
                }
                other => {
                    assert_that!(&other).fail(format_args!(
                        "expected inspector stream read error, got {other:?}"
                    ));
                }
            }
        }

        // The partial line interrupted by the gap ("par") must not be reported, while the
        // unterminated "final" line still has to be visited at end of stream.
        #[tokio::test]
        async fn inspectors_skip_gaps_and_visit_final_unterminated_line() {
            let collected = Arc::new(Mutex::new(Vec::<String>::new()));
            let sink = Arc::clone(&collected);
            let events = vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"one\npar"))),
                StreamEvent::Gap,
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"\ntwo\nfinal"))),
                StreamEvent::Eof,
            ];
            let receiver = event_receiver(events).await;
            let visitor = ParseLines::inspect(LineParsingOptions::default(), move |line| {
                let mut lines = sink.lock().unwrap();
                lines.push(line.into_owned());
                Next::Continue
            });
            let consumer = spawn_consumer_sync("custom", receiver, visitor);

            consumer.wait().await.unwrap();

            let observed = collected.lock().unwrap().clone();
            assert_that!(observed).contains_exactly(["one", "two", "final"]);
        }
    }

    mod inspect_chunks {
        use super::*;

        // The callback owns mutable state (a counter and a one-shot sender) and breaks on
        // the third chunk, so the fourth chunk must never be counted.
        #[tokio::test]
        async fn accepts_stateful_callback() {
            let (report_tx, report_rx) = oneshot::channel();
            let mut visited = 0;
            let mut report_tx = Some(report_tx);
            let events = vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"ab"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"cd"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"ef"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"gh"))),
                StreamEvent::Eof,
            ];
            let receiver = event_receiver(events).await;
            let visitor = InspectChunks::builder()
                .f(move |_chunk| {
                    visited += 1;
                    if visited != 3 {
                        return Next::Continue;
                    }
                    report_tx.take().unwrap().send(visited).unwrap();
                    Next::Break
                })
                .build();
            let consumer = spawn_consumer_sync("custom", receiver, visitor);

            consumer.wait().await.unwrap();
            let reported = report_rx.await.unwrap();
            assert_that!(reported).is_equal_to(3);
        }
    }

    mod inspect_chunks_async {
        use super::*;
        use crate::output_stream::consumer::driver::spawn_consumer_async;

        // The async callback keeps a counter across calls and breaks on the third chunk, so
        // only the first three payloads may be recorded.
        #[tokio::test]
        async fn accepts_stateful_callback() {
            let collected = Arc::new(Mutex::new(Vec::<Vec<u8>>::new()));
            let sink = Arc::clone(&collected);
            let mut visited = 0;
            let events = vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"ab"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"cd"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"ef"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"gh"))),
                StreamEvent::Eof,
            ];
            let receiver = event_receiver(events).await;
            let visitor = InspectChunksAsync::builder()
                .f(move |chunk| {
                    visited += 1;
                    // Everything the future needs is prepared up front so that it owns its
                    // data and does not borrow from the closure.
                    let sink = Arc::clone(&sink);
                    let payload = chunk.as_ref().to_vec();
                    let is_third = visited == 3;
                    async move {
                        sink.lock().unwrap().push(payload);
                        if is_third {
                            Next::Break
                        } else {
                            Next::Continue
                        }
                    }
                })
                .build();
            let consumer = spawn_consumer_async("custom", receiver, visitor);

            consumer.wait().await.unwrap();

            let observed = collected.lock().unwrap().clone();
            let expected = vec![b"ab".to_vec(), b"cd".to_vec(), b"ef".to_vec()];
            assert_that!(observed).is_equal_to(expected);
        }
    }

    mod inspect_lines_async {
        use super::*;
        use crate::NumBytesExt;
        use crate::output_stream::consumer::driver::spawn_consumer_async;

        // Setting up line parsing with a zero `max_line_length` is expected to panic with
        // the message pinned below, also when the inner visitor is the async one.
        #[test]
        #[should_panic(expected = "LineParsingOptions::max_line_length must be greater than zero")]
        fn panics_when_max_line_length_is_zero() {
            let rejected_options = LineParsingOptions {
                max_line_length: 0.bytes(),
                overflow_behavior: crate::LineOverflowBehavior::default(),
                buffer_compaction_threshold: None,
            };
            let line_visitor = InspectLinesAsync::new(|_line| async { Next::Continue });

            let _visitor = ParseLines::new(rejected_options, line_visitor);
        }

        // A stream that ends without a trailing newline must still deliver its last line.
        #[tokio::test]
        async fn preserves_unterminated_final_line() {
            let collected = Arc::new(Mutex::new(Vec::<String>::new()));
            let sink = Arc::clone(&collected);
            let events = vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"tail"))),
                StreamEvent::Eof,
            ];
            let receiver = event_receiver(events).await;
            let visitor = ParseLines::inspect_async(LineParsingOptions::default(), move |line| {
                // The line is copied into an owned `String` before the future is created,
                // so the future does not hold on to the borrowed input.
                let sink = Arc::clone(&sink);
                let owned_line = line.into_owned();
                async move {
                    sink.lock().unwrap().push(owned_line);
                    Next::Continue
                }
            });
            let consumer = spawn_consumer_async("custom", receiver, visitor);

            consumer.wait().await.unwrap();

            let observed = collected.lock().unwrap().clone();
            assert_that!(observed).contains_exactly(["tail"]);
        }
    }
}