Skip to main content

tokio_process_tools/output_stream/visitors/
wait.rs

1use crate::output_stream::Next;
2use crate::output_stream::line::adapter::LineSink;
3use std::borrow::Cow;
4
/// [`LineSink`] that breaks the moment a predicate accepts a line and remembers whether it
/// has matched yet.
///
/// Compose with [`LineAdapter`](crate::output_stream::line::adapter::LineAdapter) to drive
/// `wait_for_line`, or to build your own predicate-driven consumer outside the built-in
/// factory methods. The sink's output is a `bool`: `true` means the predicate accepted at
/// least one line before the sink was consumed.
pub struct WaitForLineSink<P> {
    /// Decides whether a line is the one being waited for.
    predicate: P,
    /// Set to `true` once `predicate` has accepted a line; a later rejected line does not
    /// clear it.
    matched: bool,
}

// The predicate is typically a closure, which does not implement `Debug`, so
// `#[derive(Debug)]` would be unusable in practice. Implement it by hand and report only the
// match state; `finish_non_exhaustive` signals that the predicate field is omitted.
impl<P> std::fmt::Debug for WaitForLineSink<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WaitForLineSink")
            .field("matched", &self.matched)
            .finish_non_exhaustive()
    }
}
13
14impl<P> WaitForLineSink<P>
15where
16    P: Fn(Cow<'_, str>) -> bool + Send + Sync + 'static,
17{
18    /// Creates a new sink that breaks the parser the first time `predicate` returns `true`.
19    pub fn new(predicate: P) -> Self {
20        Self {
21            predicate,
22            matched: false,
23        }
24    }
25}
26
27impl<P> LineSink for WaitForLineSink<P>
28where
29    P: Fn(Cow<'_, str>) -> bool + Send + Sync + 'static,
30{
31    type Output = bool;
32
33    fn on_line(&mut self, line: Cow<'_, str>) -> Next {
34        if (self.predicate)(line) {
35            self.matched = true;
36            Next::Break
37        } else {
38            Next::Continue
39        }
40    }
41
42    fn into_output(self) -> Self::Output {
43        self.matched
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output_stream::consumer::driver::consume_sync;
    use crate::output_stream::event::{Chunk, StreamEvent};
    use crate::output_stream::line::adapter::LineAdapter;
    use crate::output_stream::line::options::LineParsingOptions;
    use crate::{LineOverflowBehavior, NumBytesExt, StreamReadError, WaitForLineResult};
    use assertr::prelude::*;
    use bytes::Bytes;
    use std::io;
    use std::time::Duration;
    use tokio::sync::{mpsc, oneshot};

    /// Translates the sink's `bool` output into [`WaitForLineResult`]: a recorded match maps
    /// to `Matched`, anything else means the stream ended without one (`StreamClosed`).
    ///
    /// Shared by every test that turns a `WaitForLineSink` output into a result, so the
    /// mapping is defined exactly once.
    fn to_wait_result(matched: bool) -> WaitForLineResult {
        if matched {
            WaitForLineResult::Matched
        } else {
            WaitForLineResult::StreamClosed
        }
    }

    /// Drives a `WaitForLineSink`, wrapped in a `LineAdapter` configured with `options`, over
    /// the supplied events and maps the sink's output through [`to_wait_result`].
    ///
    /// This is the same translation production code performs when it drives the visitor from
    /// the backend method. Errors from `consume_sync` are propagated with `?`.
    async fn drive_wait_for_line(
        events: Vec<StreamEvent>,
        predicate: impl Fn(Cow<'_, str>) -> bool + Send + Sync + 'static,
        options: LineParsingOptions,
    ) -> Result<WaitForLineResult, StreamReadError> {
        // Capacity fits every event so all sends complete up front; `max(1)` because
        // `mpsc::channel` panics on a zero capacity.
        let (tx, rx) = mpsc::channel(events.len().max(1));
        for event in events {
            tx.send(event).await.unwrap();
        }
        // Close the channel so the consumer sees the end once the buffered events are drained.
        drop(tx);

        // Bound to a named (not `_`) variable so the sender lives for the whole call.
        // NOTE(review): presumably a dropped sender would be observed as termination by
        // `consume_sync`; confirm.
        let (_term_sig_tx, term_sig_rx) = oneshot::channel::<()>();
        let visitor = LineAdapter::new(options, WaitForLineSink::new(predicate));
        let matched = consume_sync(rx, visitor, term_sig_rx).await?;
        Ok(to_wait_result(matched))
    }

    /// Waits for a line exactly equal to `"ready"` using default parsing options.
    async fn wait_for_ready(
        events: Vec<StreamEvent>,
    ) -> Result<WaitForLineResult, StreamReadError> {
        drive_wait_for_line(
            events,
            |line| line == "ready",
            LineParsingOptions::default(),
        )
        .await
    }

    mod sink {
        use super::*;

        #[test]
        fn output_is_false_when_no_line_matches() {
            let mut sink = WaitForLineSink::new(|line| line == "ready");
            assert!(matches!(
                sink.on_line(Cow::Borrowed("booting")),
                Next::Continue
            ));
            assert_that!(sink.into_output()).is_equal_to(false);
        }

        #[test]
        fn match_is_remembered_after_later_non_matching_lines() {
            let mut sink = WaitForLineSink::new(|line| line == "ready");
            assert!(matches!(sink.on_line(Cow::Borrowed("ready")), Next::Break));
            assert!(matches!(
                sink.on_line(Cow::Owned(String::from("later"))),
                Next::Continue
            ));
            assert_that!(sink.into_output()).is_equal_to(true);
        }
    }

    mod wait_for_line {
        use super::*;

        #[tokio::test]
        async fn matches_intermediary_line() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nready\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }

        #[tokio::test]
        async fn matches_final_line() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nready"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }

        #[tokio::test]
        async fn returns_stream_closed_when_expected_is_not_matched_before_eof() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nstill starting\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::StreamClosed);
        }

        #[tokio::test]
        async fn gap_does_not_join_lines() {
            // The bytes on either side of the gap must not be stitched into "ready".
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"rea"))),
                StreamEvent::Gap,
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"dy\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::StreamClosed);
        }

        #[tokio::test]
        async fn reports_read_error() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\npartial"))),
                StreamEvent::ReadError(StreamReadError::new(
                    "custom",
                    io::Error::from(io::ErrorKind::BrokenPipe),
                )),
            ])
            .await;

            let err = result.expect_err("read failure should be surfaced");
            assert_that!(err.stream_name()).is_equal_to("custom");
            assert_that!(err.kind()).is_equal_to(io::ErrorKind::BrokenPipe);
        }

        #[test]
        #[should_panic(expected = "LineParsingOptions::max_line_length must be greater than zero")]
        fn panics_when_max_line_length_is_zero() {
            let _visitor = LineAdapter::new(
                LineParsingOptions {
                    max_line_length: 0.bytes(),
                    overflow_behavior: LineOverflowBehavior::default(),
                    buffer_compaction_threshold: None,
                },
                WaitForLineSink::new(|_line| true),
            );
        }

        #[tokio::test]
        async fn honors_line_parsing_options() {
            // With a 4-byte limit and `DropAdditionalData`, "readiness" is truncated to "read".
            let result = drive_wait_for_line(
                vec![
                    StreamEvent::Chunk(Chunk(Bytes::from_static(b"readiness\n"))),
                    StreamEvent::Eof,
                ],
                |line| line == "read",
                LineParsingOptions {
                    max_line_length: 4.bytes(),
                    overflow_behavior: LineOverflowBehavior::DropAdditionalData,
                    buffer_compaction_threshold: None,
                },
            )
            .await;

            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }
    }

    mod wait_for_line_bounded {
        use super::*;

        #[tokio::test]
        async fn times_out_with_timeout_error() {
            // `_tx` stays alive, so the channel never closes and no line ever arrives. The
            // timeout is therefore expected to elapse before `consume_sync` finishes.
            let (_tx, rx) = mpsc::channel::<StreamEvent>(1);
            let (_term_sig_tx, term_sig_rx) = oneshot::channel::<()>();
            let visitor = LineAdapter::new(
                LineParsingOptions::default(),
                WaitForLineSink::new(|line| line == "ready"),
            );
            let result = tokio::time::timeout(
                Duration::from_millis(25),
                consume_sync(rx, visitor, term_sig_rx),
            )
            .await
            .map_or(Ok(WaitForLineResult::Timeout), |inner| {
                inner.map(to_wait_result)
            });
            assert_that!(result).is_equal_to(Ok(WaitForLineResult::Timeout));
        }
    }
}