Skip to main content

sftool_lib/common/
serial_io.rs

1use crate::{CancelToken, Error, Result, SifliToolTrait};
2use serialport::{ClearBuffer, SerialPort};
3use std::collections::VecDeque;
4use std::io::{self, ErrorKind, Read, Write};
5use std::time::{Duration, Instant};
6
7#[cfg(test)]
8use serialport::{DataBits, FlowControl, Parity, StopBits};
9#[cfg(test)]
10use std::sync::{Arc, Mutex};
11
/// Longest single `thread::sleep` performed by `sleep_with_cancel` between cancellation checks.
const SLEEP_CHUNK: Duration = Duration::from_millis(25);
/// Pause inserted after a read attempt that produced no data before polling the port again.
const IDLE_BACKOFF: Duration = Duration::from_millis(5);
/// Maximum number of trailing bytes kept in the capture buffer returned by `wait_for_patterns`.
const MAX_CAPTURE_BUFFER: usize = 1024;
15
/// Outcome of a successful multi-pattern wait.
///
/// Derives the common value-type traits so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatternMatch {
    /// Position, within the slice of patterns that was searched, of the pattern that matched.
    pub index: usize,
    /// Bytes captured from the port up to and including the end of the matched pattern.
    pub buffer: Vec<u8>,
}
20
21pub fn sleep_with_cancel(cancel_token: &CancelToken, duration: Duration) -> Result<()> {
22    let mut remaining = duration;
23    while remaining > Duration::ZERO {
24        cancel_token.check_cancelled()?;
25        let sleep_for = remaining.min(SLEEP_CHUNK);
26        std::thread::sleep(sleep_for);
27        remaining = remaining.saturating_sub(sleep_for);
28    }
29    cancel_token.check_cancelled()
30}
31
32pub fn io_cancelled_error() -> io::Error {
33    io::Error::new(ErrorKind::Interrupted, Error::Cancelled)
34}
35
36pub fn is_cancelled_io_error(error: &io::Error) -> bool {
37    if error.kind() != ErrorKind::Interrupted {
38        return false;
39    }
40
41    error
42        .get_ref()
43        .and_then(|inner| inner.downcast_ref::<Error>())
44        .is_some_and(|inner| matches!(inner, Error::Cancelled))
45}
46
47pub struct CancelableReader {
48    port: Box<dyn SerialPort>,
49    cancel_token: CancelToken,
50}
51
52impl Read for CancelableReader {
53    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
54        self.cancel_token
55            .check_cancelled()
56            .map_err(|_| io_cancelled_error())?;
57        self.port.read(buf)
58    }
59}
60
61pub struct CancelableWriter {
62    port: Box<dyn SerialPort>,
63    cancel_token: CancelToken,
64}
65
66impl Write for CancelableWriter {
67    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
68        self.cancel_token
69            .check_cancelled()
70            .map_err(|_| io_cancelled_error())?;
71        self.port.write(buf)
72    }
73
74    fn flush(&mut self) -> io::Result<()> {
75        self.cancel_token
76            .check_cancelled()
77            .map_err(|_| io_cancelled_error())?;
78        self.port.flush()
79    }
80}
81
82pub struct SerialIo<'a> {
83    port: &'a mut dyn SerialPort,
84    cancel_token: CancelToken,
85}
86
87impl<'a> SerialIo<'a> {
88    pub fn new(port: &'a mut dyn SerialPort, cancel_token: CancelToken) -> Self {
89        Self { port, cancel_token }
90    }
91
92    pub fn cancel_token(&self) -> &CancelToken {
93        &self.cancel_token
94    }
95
96    pub fn check_cancelled(&self) -> Result<()> {
97        self.cancel_token.check_cancelled()
98    }
99
100    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
101        self.check_cancelled()?;
102        self.port.read(buf).map_err(Into::into)
103    }
104
105    pub fn write_all(&mut self, buf: &[u8]) -> Result<()> {
106        self.check_cancelled()?;
107        self.port.write_all(buf)?;
108        self.check_cancelled()
109    }
110
111    pub fn flush(&mut self) -> Result<()> {
112        self.check_cancelled()?;
113        self.port.flush()?;
114        self.check_cancelled()
115    }
116
117    pub fn clear(&mut self, buffer: ClearBuffer) -> Result<()> {
118        self.check_cancelled()?;
119        self.port.clear(buffer)?;
120        self.check_cancelled()
121    }
122
123    pub fn set_baud_rate(&mut self, baud_rate: u32) -> Result<()> {
124        self.check_cancelled()?;
125        self.port.set_baud_rate(baud_rate)?;
126        self.check_cancelled()
127    }
128
129    pub fn write_request_to_send(&mut self, level: bool) -> Result<()> {
130        self.check_cancelled()?;
131        self.port.write_request_to_send(level)?;
132        self.check_cancelled()
133    }
134
135    pub fn sleep(&self, duration: Duration) -> Result<()> {
136        sleep_with_cancel(&self.cancel_token, duration)
137    }
138
139    pub fn try_clone_reader(&mut self) -> Result<CancelableReader> {
140        self.check_cancelled()?;
141        Ok(CancelableReader {
142            port: self.port.try_clone()?,
143            cancel_token: self.cancel_token.clone(),
144        })
145    }
146
147    pub fn try_clone_writer(&mut self) -> Result<CancelableWriter> {
148        self.check_cancelled()?;
149        Ok(CancelableWriter {
150            port: self.port.try_clone()?,
151            cancel_token: self.cancel_token.clone(),
152        })
153    }
154
155    pub fn read_exact_with_timeout(
156        &mut self,
157        buf: &mut [u8],
158        timeout: Duration,
159        context: &str,
160    ) -> Result<()> {
161        if buf.is_empty() {
162            return Ok(());
163        }
164
165        let mut last_activity = Instant::now();
166        let mut offset = 0usize;
167
168        while offset < buf.len() {
169            self.check_cancelled()?;
170            match self.port.read(&mut buf[offset..]) {
171                Ok(0) => {
172                    if last_activity.elapsed() > timeout {
173                        return Err(Error::timeout(format!("waiting for {}", context)));
174                    }
175                    self.sleep(IDLE_BACKOFF)?;
176                }
177                Ok(n) => {
178                    offset += n;
179                    last_activity = Instant::now();
180                }
181                Err(error)
182                    if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) =>
183                {
184                    if last_activity.elapsed() > timeout {
185                        return Err(Error::timeout(format!("waiting for {}", context)));
186                    }
187                    self.sleep(IDLE_BACKOFF)?;
188                }
189                Err(error) if error.kind() == ErrorKind::Interrupted => continue,
190                Err(error) => return Err(error.into()),
191            }
192        }
193
194        Ok(())
195    }
196
197    pub fn read_line_with_timeout(&mut self, timeout: Duration, context: &str) -> Result<String> {
198        let mut buffer = Vec::new();
199        let mut last_activity = Instant::now();
200
201        loop {
202            self.check_cancelled()?;
203            let mut byte = [0u8; 1];
204            match self.port.read(&mut byte) {
205                Ok(0) => {
206                    if last_activity.elapsed() > timeout {
207                        return Err(Error::timeout(format!("waiting for {}", context)));
208                    }
209                }
210                Ok(_) => {
211                    last_activity = Instant::now();
212                    match byte[0] {
213                        b'\n' => break,
214                        b'\r' => continue,
215                        ch => buffer.push(ch),
216                    }
217                }
218                Err(error)
219                    if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) =>
220                {
221                    if last_activity.elapsed() > timeout {
222                        return Err(Error::timeout(format!("waiting for {}", context)));
223                    }
224                }
225                Err(error) if error.kind() == ErrorKind::Interrupted => continue,
226                Err(error) => return Err(error.into()),
227            }
228        }
229
230        Ok(String::from_utf8_lossy(&buffer).into_owned())
231    }
232
233    pub fn read_non_empty_line_with_timeout(
234        &mut self,
235        timeout: Duration,
236        context: &str,
237    ) -> Result<String> {
238        loop {
239            let line = self.read_line_with_timeout(timeout, context)?;
240            let trimmed = line.trim().to_string();
241            if !trimmed.is_empty() {
242                return Ok(trimmed);
243            }
244        }
245    }
246
247    pub fn wait_for_pattern(
248        &mut self,
249        pattern: &[u8],
250        timeout: Duration,
251        context: &str,
252    ) -> Result<Vec<u8>> {
253        let matched = self.wait_for_patterns(&[pattern], timeout, context)?;
254        Ok(matched.buffer)
255    }
256
257    pub fn wait_for_patterns(
258        &mut self,
259        patterns: &[&[u8]],
260        timeout: Duration,
261        context: &str,
262    ) -> Result<PatternMatch> {
263        let start = Instant::now();
264        let max_len = patterns
265            .iter()
266            .map(|pattern| pattern.len())
267            .max()
268            .unwrap_or(0);
269        let mut buffer = Vec::new();
270        let mut window = VecDeque::with_capacity(max_len.max(1));
271
272        loop {
273            self.check_cancelled()?;
274            if start.elapsed() > timeout {
275                return Err(Error::timeout(format!("waiting for {}", context)));
276            }
277
278            let mut byte = [0u8; 1];
279            match self.port.read(&mut byte) {
280                Ok(0) => continue,
281                Ok(_) => {
282                    buffer.push(byte[0]);
283                    if buffer.len() > MAX_CAPTURE_BUFFER {
284                        let drain_len = buffer.len() - MAX_CAPTURE_BUFFER;
285                        buffer.drain(..drain_len);
286                    }
287                    window.push_back(byte[0]);
288                    if window.len() > max_len {
289                        window.pop_front();
290                    }
291
292                    for (index, pattern) in patterns.iter().enumerate() {
293                        if window.len() >= pattern.len()
294                            && window
295                                .iter()
296                                .rev()
297                                .take(pattern.len())
298                                .rev()
299                                .copied()
300                                .eq(pattern.iter().copied())
301                        {
302                            return Ok(PatternMatch { index, buffer });
303                        }
304                    }
305                }
306                Err(error)
307                    if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) =>
308                {
309                    continue;
310                }
311                Err(error) if error.kind() == ErrorKind::Interrupted => continue,
312                Err(error) => return Err(error.into()),
313            }
314        }
315    }
316
317    pub fn wait_for_prompt(
318        &mut self,
319        prompt: &[u8],
320        retry_interval: Duration,
321        max_retries: u32,
322    ) -> Result<()> {
323        let mut retry_count = 0u32;
324        let mut window = VecDeque::with_capacity(prompt.len().max(1));
325        let mut last_retry = Instant::now();
326
327        self.write_all(b"\r\n")?;
328        self.flush()?;
329
330        loop {
331            self.check_cancelled()?;
332
333            if last_retry.elapsed() > retry_interval {
334                self.clear(ClearBuffer::All)?;
335                self.sleep(Duration::from_millis(100))?;
336                retry_count = retry_count.saturating_add(1);
337                if retry_count > max_retries {
338                    return Err(Error::timeout("waiting for shell prompt"));
339                }
340                last_retry = Instant::now();
341                window.clear();
342                self.write_all(b"\r\n")?;
343                self.flush()?;
344            }
345
346            let mut byte = [0u8; 1];
347            match self.port.read(&mut byte) {
348                Ok(0) => self.sleep(IDLE_BACKOFF)?,
349                Ok(_) => {
350                    window.push_back(byte[0]);
351                    if window.len() > prompt.len() {
352                        window.pop_front();
353                    }
354
355                    if window.len() == prompt.len()
356                        && window.iter().copied().eq(prompt.iter().copied())
357                    {
358                        return Ok(());
359                    }
360                }
361                Err(error)
362                    if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) =>
363                {
364                    self.sleep(IDLE_BACKOFF)?;
365                }
366                Err(error) if error.kind() == ErrorKind::Interrupted => continue,
367                Err(error) => return Err(error.into()),
368            }
369        }
370    }
371}
372
373pub fn for_tool<T: SifliToolTrait + ?Sized>(tool: &mut T) -> SerialIo<'_> {
374    let cancel_token = tool.base().cancel_token.clone();
375    SerialIo::new(tool.port().as_mut(), cancel_token)
376}
377
#[cfg(test)]
pub(crate) mod test_support {
    use super::*;

    /// Shared, inspectable state behind a [`TestSerialPort`] and all of its clones.
    #[derive(Default)]
    pub struct TestSerialPortState {
        /// Bytes that `read` hands out, front first.
        pub read_data: VecDeque<u8>,
        /// Every byte passed to `write`, in order.
        pub writes: Vec<u8>,
        /// Value reported by `baud_rate` and updated by `set_baud_rate`.
        pub baud_rate: u32,
        /// Value reported by `timeout` and updated by `set_timeout`.
        pub timeout: Duration,
        /// Number of `clear` invocations.
        pub clear_calls: usize,
        /// Levels passed to `write_request_to_send`, in order.
        pub rts_history: Vec<bool>,
        /// Number of `write` invocations.
        pub write_calls: usize,
        /// `(n, token)`: `token` is cancelled by every `write` once `write_calls >= n`.
        pub cancel_on_write_call: Option<(usize, CancelToken)>,
    }

    /// In-memory `SerialPort` double whose state can be inspected by tests.
    pub struct TestSerialPort {
        state: Arc<Mutex<TestSerialPortState>>,
    }

    impl TestSerialPort {
        /// Creates a port preloaded with `bytes` to read, plus a handle to its shared state.
        pub fn from_bytes(bytes: &[u8]) -> (Self, Arc<Mutex<TestSerialPortState>>) {
            let initial = TestSerialPortState {
                read_data: VecDeque::from(bytes.to_vec()),
                baud_rate: 1_000_000,
                timeout: Duration::from_millis(5),
                ..Default::default()
            };
            let shared = Arc::new(Mutex::new(initial));
            let port = Self {
                state: Arc::clone(&shared),
            };
            (port, shared)
        }
    }

    impl Read for TestSerialPort {
        /// Copies queued bytes into `buf`; reports `TimedOut` when the queue is empty.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let mut state = self.state.lock().unwrap();
            if state.read_data.is_empty() {
                return Err(io::Error::new(ErrorKind::TimedOut, "no data"));
            }

            let count = buf.len().min(state.read_data.len());
            for (slot, byte) in buf.iter_mut().zip(state.read_data.drain(..count)) {
                *slot = byte;
            }
            Ok(count)
        }
    }

    impl Write for TestSerialPort {
        /// Records the write and fires the configured cancellation once its threshold is met.
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let mut state = self.state.lock().unwrap();
            state.write_calls = state.write_calls.saturating_add(1);
            state.writes.extend_from_slice(buf);

            let calls = state.write_calls;
            match &state.cancel_on_write_call {
                Some((threshold, token)) if calls >= *threshold => {
                    token.cancel();
                }
                _ => {}
            }
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    impl SerialPort for TestSerialPort {
        fn name(&self) -> Option<String> {
            Some(String::from("test-port"))
        }

        fn baud_rate(&self) -> serialport::Result<u32> {
            let state = self.state.lock().unwrap();
            Ok(state.baud_rate)
        }

        fn data_bits(&self) -> serialport::Result<DataBits> {
            Ok(DataBits::Eight)
        }

        fn flow_control(&self) -> serialport::Result<FlowControl> {
            Ok(FlowControl::None)
        }

        fn parity(&self) -> serialport::Result<Parity> {
            Ok(Parity::None)
        }

        fn stop_bits(&self) -> serialport::Result<StopBits> {
            Ok(StopBits::One)
        }

        fn timeout(&self) -> Duration {
            let state = self.state.lock().unwrap();
            state.timeout
        }

        fn set_baud_rate(&mut self, baud_rate: u32) -> serialport::Result<()> {
            let mut state = self.state.lock().unwrap();
            state.baud_rate = baud_rate;
            Ok(())
        }

        fn set_data_bits(&mut self, _: DataBits) -> serialport::Result<()> {
            Ok(())
        }

        fn set_flow_control(&mut self, _: FlowControl) -> serialport::Result<()> {
            Ok(())
        }

        fn set_parity(&mut self, _: Parity) -> serialport::Result<()> {
            Ok(())
        }

        fn set_stop_bits(&mut self, _: StopBits) -> serialport::Result<()> {
            Ok(())
        }

        fn set_timeout(&mut self, timeout: Duration) -> serialport::Result<()> {
            let mut state = self.state.lock().unwrap();
            state.timeout = timeout;
            Ok(())
        }

        fn write_request_to_send(&mut self, level: bool) -> serialport::Result<()> {
            let mut state = self.state.lock().unwrap();
            state.rts_history.push(level);
            Ok(())
        }

        fn write_data_terminal_ready(&mut self, _: bool) -> serialport::Result<()> {
            Ok(())
        }

        fn read_clear_to_send(&mut self) -> serialport::Result<bool> {
            Ok(false)
        }

        fn read_data_set_ready(&mut self) -> serialport::Result<bool> {
            Ok(false)
        }

        fn read_ring_indicator(&mut self) -> serialport::Result<bool> {
            Ok(false)
        }

        fn read_carrier_detect(&mut self) -> serialport::Result<bool> {
            Ok(false)
        }

        fn bytes_to_read(&self) -> serialport::Result<u32> {
            let state = self.state.lock().unwrap();
            Ok(state.read_data.len() as u32)
        }

        fn bytes_to_write(&self) -> serialport::Result<u32> {
            Ok(0)
        }

        /// Only counts the call; queued read data is left untouched.
        fn clear(&self, _: ClearBuffer) -> serialport::Result<()> {
            let mut state = self.state.lock().unwrap();
            state.clear_calls += 1;
            Ok(())
        }

        /// Returns another handle onto the same shared state.
        fn try_clone(&self) -> serialport::Result<Box<dyn SerialPort>> {
            let twin = Self {
                state: Arc::clone(&self.state),
            };
            Ok(Box::new(twin))
        }

        fn set_break(&self) -> serialport::Result<()> {
            Ok(())
        }

        fn clear_break(&self) -> serialport::Result<()> {
            Ok(())
        }
    }
}
556
#[cfg(test)]
mod tests {
    use super::{Duration, *};
    use crate::CancelToken;

    /// Returns a token that has already been cancelled.
    fn cancelled_token() -> CancelToken {
        let token = CancelToken::new();
        token.cancel();
        token
    }

    /// A pattern wait on an already-cancelled token fails with `Error::Cancelled`.
    #[test]
    fn wait_for_pattern_stops_when_cancelled() {
        let (mut port, _) = test_support::TestSerialPort::from_bytes(&[]);
        let mut io = SerialIo::new(&mut port, cancelled_token());

        let outcome = io.wait_for_pattern(b"OK", Duration::from_millis(50), "OK response");

        assert!(matches!(outcome, Err(Error::Cancelled)));
    }

    /// A prompt wait on an already-cancelled token fails with `Error::Cancelled`.
    #[test]
    fn wait_for_prompt_retries_and_can_be_cancelled() {
        let (mut port, _) = test_support::TestSerialPort::from_bytes(&[]);
        let mut io = SerialIo::new(&mut port, cancelled_token());

        let outcome = io.wait_for_prompt(b"msh >", Duration::from_millis(50), 1);

        assert!(matches!(outcome, Err(Error::Cancelled)));
    }

    /// Once a write has triggered cancellation, a cloned reader reports the cancelled I/O error.
    #[test]
    fn cloned_reader_reports_cancelled_io_error() {
        let (mut port, state) = test_support::TestSerialPort::from_bytes(b"abc");
        let token = CancelToken::new();
        // The first write through the mock port cancels the token.
        state.lock().unwrap().cancel_on_write_call = Some((1, token.clone()));
        let mut io = SerialIo::new(&mut port, token);

        let mut reader = io.try_clone_reader().unwrap();
        let mut writer = io.try_clone_writer().unwrap();
        writer.write_all(b"x").unwrap();

        let mut scratch = [0u8; 1];
        let failure = reader.read(&mut scratch).unwrap_err();
        assert!(is_cancelled_io_error(&failure));
    }

    /// The capture buffer never exceeds `MAX_CAPTURE_BUFFER` and still ends with the match.
    #[test]
    fn wait_for_patterns_bounds_captured_buffer() {
        let mut incoming = vec![b'a'; MAX_CAPTURE_BUFFER + 32];
        incoming.extend_from_slice(b"OK");
        let (mut port, _) = test_support::TestSerialPort::from_bytes(&incoming);
        let mut io = SerialIo::new(&mut port, CancelToken::new());

        let found = io
            .wait_for_patterns(&[b"OK"], Duration::from_millis(100), "OK response")
            .unwrap();

        assert!(found.buffer.len() <= MAX_CAPTURE_BUFFER);
        assert!(found.buffer.ends_with(b"OK"));
    }
}