1use crate::{CancelToken, Error, Result, SifliToolTrait};
2use serialport::{ClearBuffer, SerialPort};
3use std::collections::VecDeque;
4use std::io::{self, ErrorKind, Read, Write};
5use std::time::{Duration, Instant};
6
7#[cfg(test)]
8use serialport::{DataBits, FlowControl, Parity, StopBits};
9#[cfg(test)]
10use std::sync::{Arc, Mutex};
11
/// Longest single `thread::sleep` slice used by [`sleep_with_cancel`]; bounds how long a
/// cancellation request can go unnoticed while sleeping.
const SLEEP_CHUNK: Duration = Duration::from_millis(25);
/// Pause inserted after a read that produced no data (`Ok(0)`, `TimedOut`, `WouldBlock`)
/// before polling the port again.
const IDLE_BACKOFF: Duration = Duration::from_millis(5);
/// Maximum number of most-recent bytes kept in [`PatternMatch::buffer`]; older bytes are
/// discarded while waiting for a pattern.
const MAX_CAPTURE_BUFFER: usize = 1024;
15
/// Result of a successful `SerialIo::wait_for_patterns` call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatternMatch {
    /// Index (into the `patterns` slice that was passed in) of the first pattern found
    /// to match.
    pub index: usize,
    /// Bytes received while waiting, ending with the byte that completed the match.
    /// Only the most recent `MAX_CAPTURE_BUFFER` bytes are retained.
    pub buffer: Vec<u8>,
}
20
21pub fn sleep_with_cancel(cancel_token: &CancelToken, duration: Duration) -> Result<()> {
22 let mut remaining = duration;
23 while remaining > Duration::ZERO {
24 cancel_token.check_cancelled()?;
25 let sleep_for = remaining.min(SLEEP_CHUNK);
26 std::thread::sleep(sleep_for);
27 remaining = remaining.saturating_sub(sleep_for);
28 }
29 cancel_token.check_cancelled()
30}
31
32pub fn io_cancelled_error() -> io::Error {
33 io::Error::new(ErrorKind::Interrupted, Error::Cancelled)
34}
35
36pub fn is_cancelled_io_error(error: &io::Error) -> bool {
37 if error.kind() != ErrorKind::Interrupted {
38 return false;
39 }
40
41 error
42 .get_ref()
43 .and_then(|inner| inner.downcast_ref::<Error>())
44 .is_some_and(|inner| matches!(inner, Error::Cancelled))
45}
46
47pub struct CancelableReader {
48 port: Box<dyn SerialPort>,
49 cancel_token: CancelToken,
50}
51
52impl Read for CancelableReader {
53 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
54 self.cancel_token
55 .check_cancelled()
56 .map_err(|_| io_cancelled_error())?;
57 self.port.read(buf)
58 }
59}
60
61pub struct CancelableWriter {
62 port: Box<dyn SerialPort>,
63 cancel_token: CancelToken,
64}
65
66impl Write for CancelableWriter {
67 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
68 self.cancel_token
69 .check_cancelled()
70 .map_err(|_| io_cancelled_error())?;
71 self.port.write(buf)
72 }
73
74 fn flush(&mut self) -> io::Result<()> {
75 self.cancel_token
76 .check_cancelled()
77 .map_err(|_| io_cancelled_error())?;
78 self.port.flush()
79 }
80}
81
/// Cancellation-aware wrapper around a mutably borrowed serial port.
///
/// Every operation first consults `cancel_token`; once the token is cancelled (by any
/// holder of a clone of it), subsequent operations fail with the token's cancellation
/// error instead of touching the port.
pub struct SerialIo<'a> {
    // The underlying port, borrowed for the lifetime of the wrapper.
    port: &'a mut dyn SerialPort,
    // Checked before port operations; cloned into reader/writer handles.
    cancel_token: CancelToken,
}
86
87impl<'a> SerialIo<'a> {
88 pub fn new(port: &'a mut dyn SerialPort, cancel_token: CancelToken) -> Self {
89 Self { port, cancel_token }
90 }
91
92 pub fn cancel_token(&self) -> &CancelToken {
93 &self.cancel_token
94 }
95
96 pub fn check_cancelled(&self) -> Result<()> {
97 self.cancel_token.check_cancelled()
98 }
99
100 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
101 self.check_cancelled()?;
102 self.port.read(buf).map_err(Into::into)
103 }
104
105 pub fn write_all(&mut self, buf: &[u8]) -> Result<()> {
106 self.check_cancelled()?;
107 self.port.write_all(buf)?;
108 self.check_cancelled()
109 }
110
111 pub fn flush(&mut self) -> Result<()> {
112 self.check_cancelled()?;
113 self.port.flush()?;
114 self.check_cancelled()
115 }
116
117 pub fn clear(&mut self, buffer: ClearBuffer) -> Result<()> {
118 self.check_cancelled()?;
119 self.port.clear(buffer)?;
120 self.check_cancelled()
121 }
122
123 pub fn set_baud_rate(&mut self, baud_rate: u32) -> Result<()> {
124 self.check_cancelled()?;
125 self.port.set_baud_rate(baud_rate)?;
126 self.check_cancelled()
127 }
128
129 pub fn write_request_to_send(&mut self, level: bool) -> Result<()> {
130 self.check_cancelled()?;
131 self.port.write_request_to_send(level)?;
132 self.check_cancelled()
133 }
134
135 pub fn sleep(&self, duration: Duration) -> Result<()> {
136 sleep_with_cancel(&self.cancel_token, duration)
137 }
138
139 pub fn try_clone_reader(&mut self) -> Result<CancelableReader> {
140 self.check_cancelled()?;
141 Ok(CancelableReader {
142 port: self.port.try_clone()?,
143 cancel_token: self.cancel_token.clone(),
144 })
145 }
146
147 pub fn try_clone_writer(&mut self) -> Result<CancelableWriter> {
148 self.check_cancelled()?;
149 Ok(CancelableWriter {
150 port: self.port.try_clone()?,
151 cancel_token: self.cancel_token.clone(),
152 })
153 }
154
155 pub fn read_exact_with_timeout(
156 &mut self,
157 buf: &mut [u8],
158 timeout: Duration,
159 context: &str,
160 ) -> Result<()> {
161 if buf.is_empty() {
162 return Ok(());
163 }
164
165 let mut last_activity = Instant::now();
166 let mut offset = 0usize;
167
168 while offset < buf.len() {
169 self.check_cancelled()?;
170 match self.port.read(&mut buf[offset..]) {
171 Ok(0) => {
172 if last_activity.elapsed() > timeout {
173 return Err(Error::timeout(format!("waiting for {}", context)));
174 }
175 self.sleep(IDLE_BACKOFF)?;
176 }
177 Ok(n) => {
178 offset += n;
179 last_activity = Instant::now();
180 }
181 Err(error)
182 if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) =>
183 {
184 if last_activity.elapsed() > timeout {
185 return Err(Error::timeout(format!("waiting for {}", context)));
186 }
187 self.sleep(IDLE_BACKOFF)?;
188 }
189 Err(error) if error.kind() == ErrorKind::Interrupted => continue,
190 Err(error) => return Err(error.into()),
191 }
192 }
193
194 Ok(())
195 }
196
197 pub fn read_line_with_timeout(&mut self, timeout: Duration, context: &str) -> Result<String> {
198 let mut buffer = Vec::new();
199 let mut last_activity = Instant::now();
200
201 loop {
202 self.check_cancelled()?;
203 let mut byte = [0u8; 1];
204 match self.port.read(&mut byte) {
205 Ok(0) => {
206 if last_activity.elapsed() > timeout {
207 return Err(Error::timeout(format!("waiting for {}", context)));
208 }
209 }
210 Ok(_) => {
211 last_activity = Instant::now();
212 match byte[0] {
213 b'\n' => break,
214 b'\r' => continue,
215 ch => buffer.push(ch),
216 }
217 }
218 Err(error)
219 if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) =>
220 {
221 if last_activity.elapsed() > timeout {
222 return Err(Error::timeout(format!("waiting for {}", context)));
223 }
224 }
225 Err(error) if error.kind() == ErrorKind::Interrupted => continue,
226 Err(error) => return Err(error.into()),
227 }
228 }
229
230 Ok(String::from_utf8_lossy(&buffer).into_owned())
231 }
232
233 pub fn read_non_empty_line_with_timeout(
234 &mut self,
235 timeout: Duration,
236 context: &str,
237 ) -> Result<String> {
238 loop {
239 let line = self.read_line_with_timeout(timeout, context)?;
240 let trimmed = line.trim().to_string();
241 if !trimmed.is_empty() {
242 return Ok(trimmed);
243 }
244 }
245 }
246
247 pub fn wait_for_pattern(
248 &mut self,
249 pattern: &[u8],
250 timeout: Duration,
251 context: &str,
252 ) -> Result<Vec<u8>> {
253 let matched = self.wait_for_patterns(&[pattern], timeout, context)?;
254 Ok(matched.buffer)
255 }
256
257 pub fn wait_for_patterns(
258 &mut self,
259 patterns: &[&[u8]],
260 timeout: Duration,
261 context: &str,
262 ) -> Result<PatternMatch> {
263 let start = Instant::now();
264 let max_len = patterns
265 .iter()
266 .map(|pattern| pattern.len())
267 .max()
268 .unwrap_or(0);
269 let mut buffer = Vec::new();
270 let mut window = VecDeque::with_capacity(max_len.max(1));
271
272 loop {
273 self.check_cancelled()?;
274 if start.elapsed() > timeout {
275 return Err(Error::timeout(format!("waiting for {}", context)));
276 }
277
278 let mut byte = [0u8; 1];
279 match self.port.read(&mut byte) {
280 Ok(0) => continue,
281 Ok(_) => {
282 buffer.push(byte[0]);
283 if buffer.len() > MAX_CAPTURE_BUFFER {
284 let drain_len = buffer.len() - MAX_CAPTURE_BUFFER;
285 buffer.drain(..drain_len);
286 }
287 window.push_back(byte[0]);
288 if window.len() > max_len {
289 window.pop_front();
290 }
291
292 for (index, pattern) in patterns.iter().enumerate() {
293 if window.len() >= pattern.len()
294 && window
295 .iter()
296 .rev()
297 .take(pattern.len())
298 .rev()
299 .copied()
300 .eq(pattern.iter().copied())
301 {
302 return Ok(PatternMatch { index, buffer });
303 }
304 }
305 }
306 Err(error)
307 if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) =>
308 {
309 continue;
310 }
311 Err(error) if error.kind() == ErrorKind::Interrupted => continue,
312 Err(error) => return Err(error.into()),
313 }
314 }
315 }
316
317 pub fn wait_for_prompt(
318 &mut self,
319 prompt: &[u8],
320 retry_interval: Duration,
321 max_retries: u32,
322 ) -> Result<()> {
323 let mut retry_count = 0u32;
324 let mut window = VecDeque::with_capacity(prompt.len().max(1));
325 let mut last_retry = Instant::now();
326
327 self.write_all(b"\r\n")?;
328 self.flush()?;
329
330 loop {
331 self.check_cancelled()?;
332
333 if last_retry.elapsed() > retry_interval {
334 self.clear(ClearBuffer::All)?;
335 self.sleep(Duration::from_millis(100))?;
336 retry_count = retry_count.saturating_add(1);
337 if retry_count > max_retries {
338 return Err(Error::timeout("waiting for shell prompt"));
339 }
340 last_retry = Instant::now();
341 window.clear();
342 self.write_all(b"\r\n")?;
343 self.flush()?;
344 }
345
346 let mut byte = [0u8; 1];
347 match self.port.read(&mut byte) {
348 Ok(0) => self.sleep(IDLE_BACKOFF)?,
349 Ok(_) => {
350 window.push_back(byte[0]);
351 if window.len() > prompt.len() {
352 window.pop_front();
353 }
354
355 if window.len() == prompt.len()
356 && window.iter().copied().eq(prompt.iter().copied())
357 {
358 return Ok(());
359 }
360 }
361 Err(error)
362 if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) =>
363 {
364 self.sleep(IDLE_BACKOFF)?;
365 }
366 Err(error) if error.kind() == ErrorKind::Interrupted => continue,
367 Err(error) => return Err(error.into()),
368 }
369 }
370 }
371}
372
373pub fn for_tool<T: SifliToolTrait + ?Sized>(tool: &mut T) -> SerialIo<'_> {
374 let cancel_token = tool.base().cancel_token.clone();
375 SerialIo::new(tool.port().as_mut(), cancel_token)
376}
377
/// In-memory serial-port test double used by this module's tests.
///
/// A `TestSerialPort` and all of its clones share one `Arc<Mutex<TestSerialPortState>>`.
/// Tests can therefore pre-load incoming bytes and inspect what the code under test
/// wrote or toggled.
#[cfg(test)]
pub(crate) mod test_support {
    use super::*;

    /// Observable state shared by a `TestSerialPort`, its clones, and the test itself.
    #[derive(Default)]
    pub struct TestSerialPortState {
        /// Bytes still waiting to be "received"; `read` drains them front-first.
        pub read_data: VecDeque<u8>,
        /// Every byte passed to `write`, in call order.
        pub writes: Vec<u8>,
        /// Current baud rate as set by `from_bytes` / `set_baud_rate`.
        pub baud_rate: u32,
        /// Value reported by `timeout()`; `read` never actually blocks on it.
        pub timeout: Duration,
        /// Number of `clear` calls.
        pub clear_calls: usize,
        /// Every level passed to `write_request_to_send`, in call order.
        pub rts_history: Vec<bool>,
        /// Number of `write` calls.
        pub write_calls: usize,
        /// When `Some((n, token))`, `token` is cancelled on the `n`-th and every later
        /// `write` call, so tests can trigger cancellation mid-operation.
        pub cancel_on_write_call: Option<(usize, CancelToken)>,
    }

    /// Fake serial port whose behaviour is driven by a shared `TestSerialPortState`.
    pub struct TestSerialPort {
        state: Arc<Mutex<TestSerialPortState>>,
    }

    impl TestSerialPort {
        /// Creates a port that will yield `bytes` to readers, plus a handle to its shared
        /// state. Baud rate starts at 1_000_000 and the reported timeout at 5 ms.
        pub fn from_bytes(bytes: &[u8]) -> (Self, Arc<Mutex<TestSerialPortState>>) {
            let state = Arc::new(Mutex::new(TestSerialPortState {
                read_data: bytes.iter().copied().collect(),
                baud_rate: 1_000_000,
                timeout: Duration::from_millis(5),
                ..Default::default()
            }));
            (
                Self {
                    state: state.clone(),
                },
                state,
            )
        }
    }

    impl Read for TestSerialPort {
        // Returns as many queued bytes as fit in `buf`. With nothing queued it fails
        // immediately with `TimedOut`, mimicking an expired port timeout without blocking.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let mut state = self.state.lock().unwrap();
            if state.read_data.is_empty() {
                return Err(io::Error::new(ErrorKind::TimedOut, "no data"));
            }

            let bytes_read = buf.len().min(state.read_data.len());
            for slot in buf.iter_mut().take(bytes_read) {
                *slot = state.read_data.pop_front().unwrap();
            }
            Ok(bytes_read)
        }
    }

    impl Write for TestSerialPort {
        // Records the whole buffer (never a short write) and counts the call; cancels the
        // configured token once the call count reaches the configured threshold.
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let mut state = self.state.lock().unwrap();
            state.write_calls = state.write_calls.saturating_add(1);
            state.writes.extend_from_slice(buf);
            if let Some((target_call, token)) = &state.cancel_on_write_call
                && state.write_calls >= *target_call
            {
                token.cancel();
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    // Remaining `SerialPort` methods either report fixed 8N1/no-flow-control settings,
    // accept and ignore changes, or record into the shared state for assertions.
    impl SerialPort for TestSerialPort {
        fn name(&self) -> Option<String> {
            Some("test-port".to_string())
        }

        fn baud_rate(&self) -> serialport::Result<u32> {
            Ok(self.state.lock().unwrap().baud_rate)
        }

        fn data_bits(&self) -> serialport::Result<DataBits> {
            Ok(DataBits::Eight)
        }

        fn flow_control(&self) -> serialport::Result<FlowControl> {
            Ok(FlowControl::None)
        }

        fn parity(&self) -> serialport::Result<Parity> {
            Ok(Parity::None)
        }

        fn stop_bits(&self) -> serialport::Result<StopBits> {
            Ok(StopBits::One)
        }

        fn timeout(&self) -> Duration {
            self.state.lock().unwrap().timeout
        }

        fn set_baud_rate(&mut self, baud_rate: u32) -> serialport::Result<()> {
            self.state.lock().unwrap().baud_rate = baud_rate;
            Ok(())
        }

        fn set_data_bits(&mut self, _: DataBits) -> serialport::Result<()> {
            Ok(())
        }

        fn set_flow_control(&mut self, _: FlowControl) -> serialport::Result<()> {
            Ok(())
        }

        fn set_parity(&mut self, _: Parity) -> serialport::Result<()> {
            Ok(())
        }

        fn set_stop_bits(&mut self, _: StopBits) -> serialport::Result<()> {
            Ok(())
        }

        fn set_timeout(&mut self, timeout: Duration) -> serialport::Result<()> {
            self.state.lock().unwrap().timeout = timeout;
            Ok(())
        }

        fn write_request_to_send(&mut self, level: bool) -> serialport::Result<()> {
            self.state.lock().unwrap().rts_history.push(level);
            Ok(())
        }

        fn write_data_terminal_ready(&mut self, _: bool) -> serialport::Result<()> {
            Ok(())
        }

        fn read_clear_to_send(&mut self) -> serialport::Result<bool> {
            Ok(false)
        }

        fn read_data_set_ready(&mut self) -> serialport::Result<bool> {
            Ok(false)
        }

        fn read_ring_indicator(&mut self) -> serialport::Result<bool> {
            Ok(false)
        }

        fn read_carrier_detect(&mut self) -> serialport::Result<bool> {
            Ok(false)
        }

        fn bytes_to_read(&self) -> serialport::Result<u32> {
            Ok(self.state.lock().unwrap().read_data.len() as u32)
        }

        fn bytes_to_write(&self) -> serialport::Result<u32> {
            Ok(0)
        }

        // Only counts the call; queued `read_data` is deliberately left untouched.
        fn clear(&self, _: ClearBuffer) -> serialport::Result<()> {
            self.state.lock().unwrap().clear_calls += 1;
            Ok(())
        }

        // Clones share the same state, like OS-level handles to one physical port.
        fn try_clone(&self) -> serialport::Result<Box<dyn SerialPort>> {
            Ok(Box::new(Self {
                state: self.state.clone(),
            }))
        }

        fn set_break(&self) -> serialport::Result<()> {
            Ok(())
        }

        fn clear_break(&self) -> serialport::Result<()> {
            Ok(())
        }
    }
}
556
#[cfg(test)]
mod tests {
    use super::{Duration, *};
    use crate::CancelToken;

    #[test]
    fn wait_for_pattern_stops_when_cancelled() {
        let (mut port, _) = test_support::TestSerialPort::from_bytes(&[]);
        let token = CancelToken::new();
        token.cancel();
        let mut io = SerialIo::new(&mut port, token);

        let result = io.wait_for_pattern(b"OK", Duration::from_millis(50), "OK response");

        assert!(matches!(result, Err(Error::Cancelled)));
    }

    // The token is cancelled up front, so this only checks that `wait_for_prompt` bails
    // out with `Error::Cancelled`. Retry exhaustion is covered by
    // `wait_for_prompt_times_out_after_max_retries`.
    #[test]
    fn wait_for_prompt_retries_and_can_be_cancelled() {
        let (mut port, _) = test_support::TestSerialPort::from_bytes(&[]);
        let token = CancelToken::new();
        token.cancel();
        let mut io = SerialIo::new(&mut port, token);

        let result = io.wait_for_prompt(b"msh >", Duration::from_millis(50), 1);

        assert!(matches!(result, Err(Error::Cancelled)));
    }

    #[test]
    fn cloned_reader_reports_cancelled_io_error() {
        let (mut port, state) = test_support::TestSerialPort::from_bytes(b"abc");
        let token = CancelToken::new();
        state.lock().unwrap().cancel_on_write_call = Some((1, token.clone()));
        let mut io = SerialIo::new(&mut port, token);

        let mut reader = io.try_clone_reader().unwrap();
        let mut writer = io.try_clone_writer().unwrap();
        writer.write_all(b"x").unwrap();

        let mut buffer = [0u8; 1];
        let error = reader.read(&mut buffer).unwrap_err();
        assert!(is_cancelled_io_error(&error));
    }

    #[test]
    fn wait_for_patterns_bounds_captured_buffer() {
        let mut bytes = vec![b'a'; MAX_CAPTURE_BUFFER + 32];
        bytes.extend_from_slice(b"OK");
        let (mut port, _) = test_support::TestSerialPort::from_bytes(&bytes);
        let token = CancelToken::new();
        let mut io = SerialIo::new(&mut port, token);

        let matched = io
            .wait_for_patterns(&[b"OK"], Duration::from_millis(100), "OK response")
            .unwrap();

        assert!(matched.buffer.len() <= MAX_CAPTURE_BUFFER);
        assert!(matched.buffer.ends_with(b"OK"));
    }

    #[test]
    fn wait_for_patterns_reports_first_matching_index() {
        let (mut port, _) = test_support::TestSerialPort::from_bytes(b"boot...OK");
        let mut io = SerialIo::new(&mut port, CancelToken::new());
        let patterns: [&[u8]; 2] = [b"ERR", b"OK"];

        let matched = io
            .wait_for_patterns(&patterns, Duration::from_millis(100), "response")
            .unwrap();

        assert_eq!(matched.index, 1);
        assert_eq!(matched.buffer, b"boot...OK");
    }

    #[test]
    fn read_line_with_timeout_splits_on_newline_and_drops_carriage_returns() {
        let (mut port, _) = test_support::TestSerialPort::from_bytes(b"hello\r\nworld\n");
        let mut io = SerialIo::new(&mut port, CancelToken::new());

        let first = io
            .read_line_with_timeout(Duration::from_millis(100), "first line")
            .unwrap();
        let second = io
            .read_line_with_timeout(Duration::from_millis(100), "second line")
            .unwrap();

        assert_eq!(first, "hello");
        assert_eq!(second, "world");
    }

    #[test]
    fn read_non_empty_line_skips_blank_lines_and_trims() {
        let (mut port, _) =
            test_support::TestSerialPort::from_bytes(b"\r\n   \r\n  ready  \r\n");
        let mut io = SerialIo::new(&mut port, CancelToken::new());

        let line = io
            .read_non_empty_line_with_timeout(Duration::from_millis(100), "status line")
            .unwrap();

        assert_eq!(line, "ready");
    }

    #[test]
    fn read_exact_with_timeout_fills_buffer() {
        let (mut port, _) = test_support::TestSerialPort::from_bytes(b"abcd");
        let mut io = SerialIo::new(&mut port, CancelToken::new());
        let mut buf = [0u8; 4];

        io.read_exact_with_timeout(&mut buf, Duration::from_millis(100), "payload")
            .unwrap();

        assert_eq!(&buf, b"abcd");
    }

    #[test]
    fn read_exact_with_timeout_times_out_on_short_input() {
        let (mut port, _) = test_support::TestSerialPort::from_bytes(b"ab");
        let mut io = SerialIo::new(&mut port, CancelToken::new());
        let mut buf = [0u8; 4];

        let result = io.read_exact_with_timeout(&mut buf, Duration::from_millis(20), "payload");

        assert!(result.is_err());
        assert!(!matches!(result, Err(Error::Cancelled)));
        assert_eq!(&buf[..2], b"ab");
    }

    #[test]
    fn wait_for_prompt_returns_once_prompt_arrives() {
        let (mut port, state) = test_support::TestSerialPort::from_bytes(b"\r\nmsh >");
        let mut io = SerialIo::new(&mut port, CancelToken::new());

        io.wait_for_prompt(b"msh >", Duration::from_millis(500), 1)
            .unwrap();

        assert_eq!(state.lock().unwrap().writes, b"\r\n");
    }

    #[test]
    fn wait_for_prompt_times_out_after_max_retries() {
        let (mut port, state) = test_support::TestSerialPort::from_bytes(&[]);
        let mut io = SerialIo::new(&mut port, CancelToken::new());

        let result = io.wait_for_prompt(b"msh >", Duration::from_millis(10), 1);

        assert!(result.is_err());
        assert!(!matches!(result, Err(Error::Cancelled)));
        let guard = state.lock().unwrap();
        // One initial poke plus exactly `max_retries` resends; each expired interval
        // (including the final failing one) clears the port buffers once.
        assert_eq!(guard.writes, b"\r\n\r\n");
        assert_eq!(guard.clear_calls, 2);
    }
}