tokio_process_tools/output_stream/visitors/
wait.rs1use crate::output_stream::Next;
2use crate::output_stream::line::adapter::LineSink;
3use std::borrow::Cow;
4
/// Line sink that waits for the first line accepted by a predicate.
///
/// Each incoming line is handed to `predicate`. The first `true` result marks the sink as
/// matched and makes it request `Next::Break`. The sink's output is that matched flag.
pub struct WaitForLineSink<P> {
    /// Test applied to every incoming line.
    predicate: P,
    /// `true` once `predicate` has accepted a line; it is never cleared afterwards.
    matched: bool,
}

impl<P> WaitForLineSink<P>
where
    P: Fn(Cow<'_, str>) -> bool + Send + Sync + 'static,
{
    /// Builds a sink around `predicate` that starts out unmatched.
    pub fn new(predicate: P) -> Self {
        let matched = false;
        WaitForLineSink { predicate, matched }
    }
}

27impl<P> LineSink for WaitForLineSink<P>
28where
29 P: Fn(Cow<'_, str>) -> bool + Send + Sync + 'static,
30{
31 type Output = bool;
32
33 fn on_line(&mut self, line: Cow<'_, str>) -> Next {
34 if (self.predicate)(line) {
35 self.matched = true;
36 Next::Break
37 } else {
38 Next::Continue
39 }
40 }
41
42 fn into_output(self) -> Self::Output {
43 self.matched
44 }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output_stream::consumer::driver::consume_sync;
    use crate::output_stream::event::{Chunk, StreamEvent};
    use crate::output_stream::line::adapter::LineAdapter;
    use crate::output_stream::line::options::LineParsingOptions;
    use crate::{LineOverflowBehavior, NumBytesExt, StreamReadError, WaitForLineResult};
    use assertr::prelude::*;
    use bytes::Bytes;
    use std::io;
    use std::time::Duration;
    use tokio::sync::{mpsc, oneshot};

    /// Translates the sink's boolean output into the public result type.
    ///
    /// Shared by every test that drives a `WaitForLineSink`, so the mapping is defined once.
    fn to_wait_result(matched: bool) -> WaitForLineResult {
        if matched {
            WaitForLineResult::Matched
        } else {
            WaitForLineResult::StreamClosed
        }
    }

    /// Feeds `events` through a `LineAdapter` wrapping a `WaitForLineSink` and reports the
    /// outcome.
    ///
    /// All events are queued up front and the sender is then dropped, so the consumer sees a
    /// finite stream. The termination-signal sender is held (never sent on) for the whole call.
    async fn drive_wait_for_line(
        events: Vec<StreamEvent>,
        predicate: impl Fn(Cow<'_, str>) -> bool + Send + Sync + 'static,
        options: LineParsingOptions,
    ) -> Result<WaitForLineResult, StreamReadError> {
        // `mpsc::channel` panics on a zero capacity, hence `max(1)`. The buffer fits every
        // event, so the sends below never wait for the receiver.
        let (tx, rx) = mpsc::channel(events.len().max(1));
        for event in events {
            tx.send(event).await.unwrap();
        }
        drop(tx);

        let (_term_sig_tx, term_sig_rx) = oneshot::channel::<()>();
        let visitor = LineAdapter::new(options, WaitForLineSink::new(predicate));
        let matched = consume_sync(rx, visitor, term_sig_rx).await?;
        Ok(to_wait_result(matched))
    }

    /// Runs `drive_wait_for_line` looking for the exact line `ready` with default options.
    async fn wait_for_ready(
        events: Vec<StreamEvent>,
    ) -> Result<WaitForLineResult, StreamReadError> {
        drive_wait_for_line(
            events,
            |line| line == "ready",
            LineParsingOptions::default(),
        )
        .await
    }

    mod wait_for_line {
        use super::*;

        #[tokio::test]
        async fn matches_intermediary_line() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nready\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }

        #[tokio::test]
        async fn matches_final_line() {
            // The last line has no trailing newline; EOF must still flush it to the sink.
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nready"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }

        #[tokio::test]
        async fn returns_stream_closed_when_expected_is_not_matched_before_eof() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\nstill starting\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::StreamClosed);
        }

        #[tokio::test]
        async fn gap_does_not_join_lines() {
            // `rea` and `dy` are separated by a gap, so they must not be stitched into `ready`.
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"rea"))),
                StreamEvent::Gap,
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"dy\n"))),
                StreamEvent::Eof,
            ])
            .await;
            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::StreamClosed);
        }

        #[tokio::test]
        async fn reports_read_error() {
            let result = wait_for_ready(vec![
                StreamEvent::Chunk(Chunk(Bytes::from_static(b"booting\npartial"))),
                StreamEvent::ReadError(StreamReadError::new(
                    "custom",
                    io::Error::from(io::ErrorKind::BrokenPipe),
                )),
            ])
            .await;

            let err = result.expect_err("read failure should be surfaced");
            assert_that!(err.stream_name()).is_equal_to("custom");
            assert_that!(err.kind()).is_equal_to(io::ErrorKind::BrokenPipe);
        }

        #[test]
        #[should_panic(expected = "LineParsingOptions::max_line_length must be greater than zero")]
        fn panics_when_max_line_length_is_zero() {
            let _visitor = LineAdapter::new(
                LineParsingOptions {
                    max_line_length: 0.bytes(),
                    overflow_behavior: LineOverflowBehavior::default(),
                    buffer_compaction_threshold: None,
                },
                WaitForLineSink::new(|_line| true),
            );
        }

        #[tokio::test]
        async fn honors_line_parsing_options() {
            // With a 4-byte limit and `DropAdditionalData`, the test expects `readiness` to be
            // delivered to the predicate as `read`.
            let result = drive_wait_for_line(
                vec![
                    StreamEvent::Chunk(Chunk(Bytes::from_static(b"readiness\n"))),
                    StreamEvent::Eof,
                ],
                |line| line == "read",
                LineParsingOptions {
                    max_line_length: 4.bytes(),
                    overflow_behavior: LineOverflowBehavior::DropAdditionalData,
                    buffer_compaction_threshold: None,
                },
            )
            .await;

            assert_that!(result)
                .is_ok()
                .is_equal_to(WaitForLineResult::Matched);
        }
    }

    mod wait_for_line_bounded {
        use super::*;

        #[tokio::test]
        async fn times_out_with_timeout_error() {
            // `_tx` stays alive and nothing is ever sent, so the stream never ends. Only the
            // timeout can finish the wait.
            let (_tx, rx) = mpsc::channel::<StreamEvent>(1);
            let (_term_sig_tx, term_sig_rx) = oneshot::channel::<()>();
            let visitor = LineAdapter::new(
                LineParsingOptions::default(),
                WaitForLineSink::new(|line| line == "ready"),
            );
            let timeout = tokio::time::timeout(
                Duration::from_millis(25),
                consume_sync(rx, visitor, term_sig_rx),
            )
            .await
            .map_or(Ok(WaitForLineResult::Timeout), |inner| {
                inner.map(to_wait_result)
            });
            assert_that!(timeout).is_equal_to(Ok(WaitForLineResult::Timeout));
        }
    }
}