// tokio_process_tools/output_stream/visitors/inspect.rs

use crate::output_stream::Next;
use crate::output_stream::event::Chunk;
use crate::output_stream::line::adapter::{AsyncLineSink, LineSink};
use crate::output_stream::visitor::{AsyncStreamVisitor, StreamVisitor};
use std::borrow::Cow;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use typed_builder::TypedBuilder;
9
10#[derive(TypedBuilder)]
11pub(crate) struct InspectChunks<F>
12where
13    F: FnMut(Chunk) -> Next + Send + 'static,
14{
15    pub f: F,
16}
17
18impl<F> StreamVisitor for InspectChunks<F>
19where
20    F: FnMut(Chunk) -> Next + Send + 'static,
21{
22    type Output = ();
23
24    fn on_chunk(&mut self, chunk: Chunk) -> Next {
25        (self.f)(chunk)
26    }
27
28    fn into_output(self) -> Self::Output {}
29}
30
31#[derive(TypedBuilder)]
32pub(crate) struct InspectChunksAsync<F, Fut>
33where
34    F: FnMut(Chunk) -> Fut + Send + 'static,
35    Fut: Future<Output = Next> + Send + 'static,
36{
37    pub f: F,
38    /// Phantom marker so the `Fut` bound lives on the struct (and on the derived builder)
39    /// rather than only on the impl block. The closure's return type carries `Fut`, so the
40    /// builder infers it from `f` — users never spell `Fut` out.
41    #[builder(default, setter(skip))]
42    pub _fut: PhantomData<fn() -> Fut>,
43}
44
45impl<F, Fut> AsyncStreamVisitor for InspectChunksAsync<F, Fut>
46where
47    F: FnMut(Chunk) -> Fut + Send + 'static,
48    Fut: Future<Output = Next> + Send + 'static,
49{
50    type Output = ();
51
52    fn on_chunk(&mut self, chunk: Chunk) -> impl Future<Output = Next> + Send + '_ {
53        (self.f)(chunk)
54    }
55
56    fn into_output(self) -> Self::Output {}
57}
58
/// [`LineSink`] wrapping a per-line closure. Compose with
/// [`LineAdapter`](crate::output_stream::line::adapter::LineAdapter) to drive `inspect_lines`, or to
/// build your own custom inspect-lines consumer outside the built-in factory methods.
pub struct InspectLineSink<F> {
    f: F,
}

// Public types should implement `Debug` (Rust API guideline C-DEBUG). Closures are not `Debug`,
// so a derive would add an `F: Debug` bound that real callbacks could never satisfy; instead the
// callback is elided and only the type name is printed.
impl<F> fmt::Debug for InspectLineSink<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InspectLineSink").finish_non_exhaustive()
    }
}
65
66impl<F> InspectLineSink<F>
67where
68    F: FnMut(Cow<'_, str>) -> Next + Send + 'static,
69{
70    /// Creates a new sink that calls `f` once for each parsed line.
71    pub fn new(f: F) -> Self {
72        Self { f }
73    }
74}
75
76impl<F> LineSink for InspectLineSink<F>
77where
78    F: FnMut(Cow<'_, str>) -> Next + Send + 'static,
79{
80    type Output = ();
81
82    fn on_line(&mut self, line: Cow<'_, str>) -> Next {
83        (self.f)(line)
84    }
85
86    fn into_output(self) -> Self::Output {}
87}
88
/// [`AsyncLineSink`] wrapping a per-line async closure. Compose with
/// [`LineAdapter`](crate::output_stream::line::adapter::LineAdapter) (its [`AsyncStreamVisitor`] impl
/// is selected automatically when the inner sink is an [`AsyncLineSink`]) to drive
/// `inspect_lines_async`. The `PhantomData<fn() -> Fut>` carries the future's type onto the
/// struct so callers never name `Fut` explicitly.
pub struct InspectLineSinkAsync<F, Fut> {
    f: F,
    _fut: PhantomData<fn() -> Fut>,
}

// Public types should implement `Debug` (Rust API guideline C-DEBUG). Neither the closure nor
// its future type is generally `Debug`, so a derive would impose `F: Debug` / `Fut: Debug` bounds
// that real callbacks cannot meet; both are elided and only the type name is printed.
impl<F, Fut> fmt::Debug for InspectLineSinkAsync<F, Fut> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InspectLineSinkAsync").finish_non_exhaustive()
    }
}
98
99impl<F, Fut> InspectLineSinkAsync<F, Fut>
100where
101    F: FnMut(Cow<'_, str>) -> Fut + Send + 'static,
102    Fut: Future<Output = Next> + Send + 'static,
103{
104    /// Creates a new sink that awaits `f` once for each parsed line.
105    pub fn new(f: F) -> Self {
106        Self {
107            f,
108            _fut: PhantomData,
109        }
110    }
111}
112
113impl<F, Fut> AsyncLineSink for InspectLineSinkAsync<F, Fut>
114where
115    F: FnMut(Cow<'_, str>) -> Fut + Send + 'static,
116    Fut: Future<Output = Next> + Send + 'static,
117{
118    type Output = ();
119
120    fn on_line<'a>(&'a mut self, line: Cow<'a, str>) -> impl Future<Output = Next> + Send + 'a {
121        (self.f)(line)
122    }
123
124    fn into_output(self) -> Self::Output {}
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output_stream::consumer::Consumer;
    use crate::output_stream::consumer::driver::spawn_consumer_sync;
    use crate::output_stream::event::StreamEvent;
    use crate::output_stream::event::tests::event_receiver;
    use crate::output_stream::line::adapter::LineAdapter;
    use crate::output_stream::line::options::LineParsingOptions;
    use crate::{ConsumerCancelOutcome, ConsumerError, StreamReadError};
    use assertr::prelude::*;
    use bytes::Bytes;
    use std::io;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// A consumer task that finishes as soon as its termination channel resolves is expected
    /// to be reported as `Cancelled`.
    #[tokio::test]
    async fn cancel_returns_cancelled_when_cooperative() {
        let (stop_tx, stop_rx) = oneshot::channel();
        let inspector: Consumer<()> = Consumer {
            stream_name: "custom",
            task: Some(tokio::spawn(async move {
                // Park until the termination sender fires or is dropped.
                let _signal = stop_rx.await;
                Ok(())
            })),
            task_termination_sender: Some(stop_tx),
        };

        let outcome = inspector.cancel(Duration::from_secs(1)).await.unwrap();
        let cancelled = matches!(outcome, ConsumerCancelOutcome::Cancelled(()));

        assert_that!(cancelled).is_true();
    }

    mod inspect_lines {
        use super::*;
        use crate::NumBytesExt;

        #[test]
        #[should_panic(expected = "LineParsingOptions::max_line_length must be greater than zero")]
        fn panics_when_max_line_length_is_zero() {
            let options = LineParsingOptions {
                max_line_length: 0.bytes(),
                overflow_behavior: crate::LineOverflowBehavior::default(),
                buffer_compaction_threshold: None,
            };
            let sink = InspectLineSink::new(|_line| Next::Continue);

            let _visitor = LineAdapter::new(options, sink);
        }

        /// A read error after some data must surface as `ConsumerError::StreamRead`.
        #[tokio::test]
        async fn inspectors_return_stream_read_error() {
            let read_error =
                StreamReadError::new("custom", io::Error::from(io::ErrorKind::BrokenPipe));
            let events = event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"complete\npartial"))),
                StreamEvent::ReadError(read_error),
            ])
            .await;
            let visitor = LineAdapter::new(
                LineParsingOptions::default(),
                InspectLineSink::new(|_line| Next::Continue),
            );

            let inspector = spawn_consumer_sync("custom", events, visitor);

            match inspector.wait().await {
                Err(ConsumerError::StreamRead { source }) => {
                    assert_that!(source.stream_name()).is_equal_to("custom");
                    assert_that!(source.kind()).is_equal_to(io::ErrorKind::BrokenPipe);
                }
                other => {
                    assert_that!(&other).fail(format_args!(
                        "expected inspector stream read error, got {other:?}"
                    ));
                }
            }
        }

        /// `par` is cut by the gap and must not be reported; the unterminated `final` must
        /// still be delivered at EOF.
        #[tokio::test]
        async fn inspectors_skip_gaps_and_visit_final_unterminated_line() {
            let collected = Arc::new(Mutex::new(Vec::<String>::new()));
            let sink_handle = Arc::clone(&collected);
            let events = event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"one\npar"))),
                StreamEvent::Gap,
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"\ntwo\nfinal"))),
                StreamEvent::Eof,
            ])
            .await;
            let visitor = LineAdapter::new(
                LineParsingOptions::default(),
                InspectLineSink::new(move |line| {
                    sink_handle.lock().unwrap().push(line.into_owned());
                    Next::Continue
                }),
            );

            let inspector = spawn_consumer_sync("custom", events, visitor);
            inspector.wait().await.unwrap();

            let collected = collected.lock().unwrap().clone();
            assert_that!(collected).contains_exactly(["one", "two", "final"]);
        }
    }

    mod inspect_chunks {
        use super::*;

        /// The callback keeps its own counter and asks to stop on the third chunk.
        #[tokio::test]
        async fn accepts_stateful_callback() {
            let (count_tx, count_rx) = oneshot::channel();
            let mut report = Some(count_tx);
            let mut calls = 0;
            let events = event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"ab"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"cd"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"ef"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"gh"))),
                StreamEvent::Eof,
            ])
            .await;
            let visitor = InspectChunks::builder()
                .f(move |_chunk| {
                    calls += 1;
                    if calls != 3 {
                        return Next::Continue;
                    }
                    // Report the count exactly once, at the moment we ask to stop.
                    report.take().unwrap().send(calls).unwrap();
                    Next::Break
                })
                .build();

            let inspector = spawn_consumer_sync("custom", events, visitor);
            inspector.wait().await.unwrap();

            let observed = count_rx.await.unwrap();
            assert_that!(observed).is_equal_to(3);
        }
    }

    mod inspect_chunks_async {
        use super::*;
        use crate::output_stream::consumer::driver::spawn_consumer_async;

        /// Only the chunks up to and including the one that returned `Break` are recorded.
        #[tokio::test]
        async fn accepts_stateful_callback() {
            let recorded = Arc::new(Mutex::new(Vec::<Vec<u8>>::new()));
            let recorder = Arc::clone(&recorded);
            let mut calls = 0;
            let events = event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"ab"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"cd"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"ef"))),
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"gh"))),
                StreamEvent::Eof,
            ])
            .await;
            let visitor = InspectChunksAsync::builder()
                .f(move |chunk| {
                    calls += 1;
                    let stop = calls == 3;
                    let bytes = chunk.as_ref().to_vec();
                    let store = Arc::clone(&recorder);
                    async move {
                        store.lock().unwrap().push(bytes);
                        if stop { Next::Break } else { Next::Continue }
                    }
                })
                .build();

            let inspector = spawn_consumer_async("custom", events, visitor);
            inspector.wait().await.unwrap();

            let recorded = recorded.lock().unwrap().clone();
            assert_that!(recorded).is_equal_to(vec![
                b"ab".to_vec(),
                b"cd".to_vec(),
                b"ef".to_vec(),
            ]);
        }
    }

    mod inspect_lines_async {
        use super::*;
        use crate::NumBytesExt;
        use crate::output_stream::consumer::driver::spawn_consumer_async;

        #[test]
        #[should_panic(expected = "LineParsingOptions::max_line_length must be greater than zero")]
        fn panics_when_max_line_length_is_zero() {
            let options = LineParsingOptions {
                max_line_length: 0.bytes(),
                overflow_behavior: crate::LineOverflowBehavior::default(),
                buffer_compaction_threshold: None,
            };
            let sink = InspectLineSinkAsync::new(|_line| async { Next::Continue });

            let _visitor = LineAdapter::new(options, sink);
        }

        /// A stream ending without a trailing newline still yields its last line.
        #[tokio::test]
        async fn preserves_unterminated_final_line() {
            let collected = Arc::new(Mutex::new(Vec::<String>::new()));
            let sink_handle = Arc::clone(&collected);
            let events = event_receiver(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"tail"))),
                StreamEvent::Eof,
            ])
            .await;
            let visitor = LineAdapter::new(
                LineParsingOptions::default(),
                InspectLineSinkAsync::new(move |line| {
                    // The returned future cannot borrow `line`, so take an owned copy first.
                    let owned = line.into_owned();
                    let store = Arc::clone(&sink_handle);
                    async move {
                        store.lock().unwrap().push(owned);
                        Next::Continue
                    }
                }),
            );

            let inspector = spawn_consumer_async("custom", events, visitor);
            inspector.wait().await.unwrap();

            let collected = collected.lock().unwrap().clone();
            assert_that!(collected).contains_exactly(["tail"]);
        }
    }
}