// tokio_process_tools/output_stream/visitors/wait.rs

1use crate::output_stream::Next;
2use crate::output_stream::line::adapter::LineVisitor;
3use std::borrow::Cow;
4
/// [`LineVisitor`] that breaks the moment a predicate accepts a line and remembers whether it
/// has matched yet. Compose with
/// [`ParseLines`](crate::output_stream::line::adapter::ParseLines) to build a
/// predicate-driven consumer; the backends' `wait_for_line` helpers wrap this same visitor with
/// an additional [`tokio::time::timeout`].
///
/// The visitor's output is a `bool`: `true` once the predicate has accepted a line, `false` if
/// no line was accepted.
///
/// There is no async counterpart. The predicate is a synchronous
/// `Fn(Cow<'_, str>) -> bool` and has no `.await`, so an [`AsyncLineVisitor`] would not add
/// capability. If you need to react asynchronously when a matching line is observed, await
/// [`Consumer::wait`](crate::Consumer::wait) on this visitor's consumer and dispatch from there.
///
/// [`AsyncLineVisitor`]: crate::output_stream::line::adapter::AsyncLineVisitor
pub struct WaitForLine<P> {
    /// Caller-supplied test applied to each line handed to the visitor.
    predicate: P,
    /// `true` once `predicate` has accepted a line; starts out `false`.
    matched: bool,
}

impl<P> WaitForLine<P>
where
    P: Fn(Cow<'_, str>) -> bool + Send + 'static,
{
    /// Creates a new visitor that breaks the parser the first time `predicate` returns `true`.
    ///
    /// The returned visitor starts in the "not matched" state.
    pub fn new(predicate: P) -> Self {
        Self {
            predicate,
            matched: false,
        }
    }
}
34
35impl<P> LineVisitor for WaitForLine<P>
36where
37    P: Fn(Cow<'_, str>) -> bool + Send + 'static,
38{
39    type Output = bool;
40
41    fn on_line(&mut self, line: Cow<'_, str>) -> Next {
42        if (self.predicate)(line) {
43            self.matched = true;
44            Next::Break
45        } else {
46            Next::Continue
47        }
48    }
49
50    fn into_output(self) -> Self::Output {
51        self.matched
52    }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output_stream::consumer::driver::consume_sync;
    use crate::output_stream::event::{Chunk, StreamEvent};
    use crate::output_stream::line::adapter::ParseLines;
    use crate::output_stream::line::options::LineParsingOptions;
    use crate::{LineOverflowBehavior, NumBytesExt, StreamReadError, WaitForLineResult};
    use assertr::prelude::*;
    use bytes::Bytes;
    use std::io;
    use std::time::Duration;
    use tokio::sync::{mpsc, oneshot};

    /// Translates the visitor's `bool` output into a [`WaitForLineResult`]: `true` becomes
    /// `Matched`, `false` becomes `StreamClosed`. Shared by every test that drives the visitor
    /// so the mapping is spelled out in exactly one place.
    fn wait_result_from(matched: bool) -> WaitForLineResult {
        if matched {
            WaitForLineResult::Matched
        } else {
            WaitForLineResult::StreamClosed
        }
    }

    /// Drive a `WaitForLine` over the supplied events and translate the visitor's `bool`
    /// output into [`WaitForLineResult`]. Mirrors what the deleted `wait_for_line` factory
    /// used to do; lives in tests because production code now drives the visitor straight from
    /// the backend method.
    ///
    /// Read errors reported by `consume_sync` are propagated unchanged.
    async fn drive_wait_for_line(
        events: Vec<StreamEvent>,
        predicate: impl Fn(Cow<'_, str>) -> bool + Send + 'static,
        options: LineParsingOptions,
    ) -> Result<WaitForLineResult, StreamReadError> {
        // Size the buffer to hold every event so all sends complete before anything reads.
        // `.max(1)` covers an empty event list: tokio's bounded channel rejects capacity 0.
        let (tx, rx) = mpsc::channel(events.len().max(1));
        for event in events {
            tx.send(event)
                .await
                .expect("receiver is alive while events are queued");
        }
        // Close the sending side so the receiver sees the end of the channel after the
        // queued events.
        drop(tx);

        // The sender half is bound (not `_`) so it stays alive for the whole call and the
        // termination channel is never closed or signalled while the consumer runs.
        let (_term_sig_tx, term_sig_rx) = oneshot::channel::<()>();
        let visitor = ParseLines::new(options, WaitForLine::new(predicate));
        let matched = consume_sync(rx, visitor, term_sig_rx).await?;
        Ok(wait_result_from(matched))
    }

    /// Convenience wrapper: waits for a line equal to `"ready"` using default parsing options.
    async fn wait_for_ready(
        events: Vec<StreamEvent>,
    ) -> Result<WaitForLineResult, StreamReadError> {
        drive_wait_for_line(
            events,
            |line| line == "ready",
            LineParsingOptions::default(),
        )
        .await
    }

    mod wait_for_line {
        use super::*;

        /// A matching line followed by a newline, in the middle of the stream, is detected.
        #[tokio::test]
        async fn matches_intermediary_line() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nready\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }

        /// A matching final line without a trailing newline is still detected at EOF.
        #[tokio::test]
        async fn matches_final_line() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nready"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }

        /// Without any matching line the result is `StreamClosed`.
        #[tokio::test]
        async fn returns_stream_closed_when_expected_is_not_matched_before_eof() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nstill starting\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::StreamClosed);
        }

        /// `rea` and `dy` separated by a gap must not be stitched together into `ready`.
        #[tokio::test]
        async fn gap_does_not_join_lines() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"rea"))),
                StreamEvent::Gap,
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"dy\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::StreamClosed);
        }

        /// A read error event surfaces as `Err` carrying the stream name and I/O error kind.
        #[tokio::test]
        async fn reports_read_error() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\npartial"))),
                StreamEvent::ReadError(StreamReadError::new(
                    "custom",
                    io::Error::from(io::ErrorKind::BrokenPipe),
                )),
            ])
            .await;

            let err = result.expect_err("read failure should be surfaced");
            assert_that!(err.stream_name()).is_equal_to("custom");
            assert_that!(err.kind()).is_equal_to(io::ErrorKind::BrokenPipe);
        }

        /// Constructing the parser with a zero `max_line_length` is expected to panic.
        #[test]
        #[should_panic(expected = "LineParsingOptions::max_line_length must be greater than zero")]
        fn panics_when_max_line_length_is_zero() {
            let _visitor = ParseLines::new(
                LineParsingOptions {
                    max_line_length: 0.bytes(),
                    overflow_behavior: LineOverflowBehavior::default(),
                    buffer_compaction_threshold: None,
                },
                WaitForLine::new(|_line| true),
            );
        }

        /// With a 4-byte line limit and `DropAdditionalData`, `readiness` is expected to reach
        /// the predicate as `read`.
        #[tokio::test]
        async fn honors_line_parsing_options() {
            let result = drive_wait_for_line(
                vec![
                    StreamEvent::Chunk(Chunk(Bytes::from_static(b"readiness\n"))),
                    StreamEvent::Eof,
                ],
                |line| line == "read",
                LineParsingOptions {
                    max_line_length: 4.bytes(),
                    overflow_behavior: LineOverflowBehavior::DropAdditionalData,
                    buffer_compaction_threshold: None,
                },
            )
            .await;

            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }
    }

    mod wait_for_line_bounded {
        use super::*;

        /// Wrapping the consumer in `tokio::time::timeout` yields `Timeout` when the stream
        /// stays open but never produces a matching line.
        #[tokio::test]
        async fn times_out_with_timeout_error() {
            // Keep the sender alive and never send: the stream stays open and silent, so only
            // the timeout can end the wait.
            let (_tx, rx) = mpsc::channel::<StreamEvent>(1);
            let (_term_sig_tx, term_sig_rx) = oneshot::channel::<()>();
            let visitor = ParseLines::new(
                LineParsingOptions::default(),
                WaitForLine::new(|line| line == "ready"),
            );
            let result = tokio::time::timeout(
                Duration::from_millis(25),
                consume_sync(rx, visitor, term_sig_rx),
            )
            .await
            // An elapsed timer maps to `Timeout`; otherwise translate the visitor output.
            .map_or(Ok(WaitForLineResult::Timeout), |inner| {
                inner.map(wait_result_from)
            });
            assert_that!(result).is_equal_to(Ok(WaitForLineResult::Timeout));
        }
    }
}