Skip to main content

uboot_shell/
lib.rs

1//! Async U-Boot shell communication over runtime-neutral futures I/O.
2
3#[macro_use]
4extern crate log;
5
6use std::{
7    io::{Error, ErrorKind, Result, stdout},
8    path::{Path, PathBuf},
9    pin::Pin,
10    task::{Context, Poll},
11    time::Duration,
12};
13
14use futures::{
15    AsyncReadExt, AsyncWriteExt,
16    future::{Either, FutureExt, select},
17    io::{AllowStdIo, AsyncRead, AsyncWrite},
18    pin_mut,
19};
20use futures_timer::Delay;
21
22/// CRC16-CCITT checksum implementation.
23pub mod crc;
24
25/// YMODEM file transfer protocol implementation.
26pub mod ymodem;
27
/// Emits a `debug!` record whose message is the formatted arguments prefixed
/// with `$ `.
///
/// Being a local `macro_rules!` named `dbg`, it takes the place of the
/// standard `dbg!` macro for code that follows it in this file.
macro_rules! dbg {
    ($($arg:tt)*) => {{
        debug!("$ {}", format!($($arg)*));
    }};
}
33
/// Byte written repeatedly by `wait_for_interrupt` until the interrupt marker
/// is seen (ASCII 0x03).
const CTRL_C: u8 = 0x03;
/// Marker text that a received line must end with to count as interrupted.
const INT_STR: &str = "<INTERRUPT>";
/// Byte form of [`INT_STR`], used for slice comparisons on raw input.
const INT: &[u8] = INT_STR.as_bytes();
/// Total number of `loady` attempts made before the error is returned.
const LOADY_MAX_ATTEMPTS: usize = 3;
/// Pause inserted after shell recovery and before the next `loady` attempt.
const LOADY_RETRY_DELAY: Duration = Duration::from_millis(300);

/// Boxed, type-erased writer used as the transmit half.
type Tx = Box<dyn AsyncWrite + Send + Unpin>;
/// Boxed, type-erased reader used as the receive half.
type Rx = Box<dyn AsyncRead + Send + Unpin>;
42
/// Handle for talking to a U-Boot shell over a pair of async byte streams.
///
/// Both streams are stored as `Option`s in public fields; the internal
/// accessors unwrap them, so they must be `Some` whenever a method that
/// performs I/O is called.
pub struct UbootShell {
    /// Transmit stream for sending bytes to U-Boot.
    pub tx: Option<Tx>,
    /// Receive stream for reading bytes from U-Boot.
    pub rx: Option<Rx>,
    /// Shell prompt prefix detected during initialization.
    /// (The field name is spelled `perfix` throughout this file.)
    perfix: String,
}
51
52impl UbootShell {
53    pub async fn new(
54        tx: impl AsyncWrite + Send + Unpin + 'static,
55        rx: impl AsyncRead + Send + Unpin + 'static,
56    ) -> Result<Self> {
57        let mut shell = Self {
58            tx: Some(Box::new(tx)),
59            rx: Some(Box::new(rx)),
60            perfix: String::new(),
61        };
62        shell.wait_for_shell().await?;
63        debug!("shell ready, perfix: `{}`", shell.perfix);
64        Ok(shell)
65    }
66
67    fn rx(&mut self) -> &mut Rx {
68        self.rx.as_mut().unwrap()
69    }
70
71    fn tx(&mut self) -> &mut Tx {
72        self.tx.as_mut().unwrap()
73    }
74
75    async fn wait_for_interrupt(&mut self) -> Result<Vec<u8>> {
76        let mut history = Vec::new();
77        let mut interrupt_line = Vec::new();
78        let interval = Duration::from_millis(20);
79        let mut last_interrupt = std::time::Instant::now() - interval;
80
81        debug!("wait for interrupt");
82        loop {
83            if last_interrupt.elapsed() >= interval {
84                self.tx().write_all(&[CTRL_C]).await?;
85                self.tx().flush().await?;
86                last_interrupt = std::time::Instant::now();
87            }
88
89            match self.read_byte_with_timeout(interval).await {
90                Ok(ch) => {
91                    history.push(ch);
92                    if history.last() == Some(&b'\n') {
93                        let line = history.trim_ascii_end();
94                        dbg!("{}", String::from_utf8_lossy(line));
95                        let interrupted = line.ends_with(INT);
96                        if interrupted {
97                            interrupt_line.extend_from_slice(line);
98                        }
99                        history.clear();
100                        if interrupted {
101                            break;
102                        }
103                    }
104                }
105                Err(err) if err.kind() == ErrorKind::TimedOut => {}
106                Err(err) => return Err(err),
107            }
108        }
109
110        Ok(interrupt_line)
111    }
112
113    async fn clear_shell(&mut self) -> Result<()> {
114        loop {
115            match self
116                .read_byte_with_timeout(Duration::from_millis(300))
117                .await
118            {
119                Ok(_) => {}
120                Err(err) if err.kind() == ErrorKind::TimedOut => return Ok(()),
121                Err(err) => return Err(err),
122            }
123        }
124    }
125
126    async fn wait_for_shell(&mut self) -> Result<()> {
127        let mut line = self.wait_for_interrupt().await?;
128        debug!("got {}", String::from_utf8_lossy(&line));
129        line.resize(line.len().saturating_sub(INT.len()), 0);
130        self.perfix = String::from_utf8_lossy(&line).to_string();
131        self.clear_shell().await?;
132        Ok(())
133    }
134
135    async fn read_byte(&mut self) -> Result<u8> {
136        self.read_byte_with_timeout(Duration::from_secs(5)).await
137    }
138
139    async fn read_byte_with_timeout(&mut self, timeout_limit: Duration) -> Result<u8> {
140        let mut buff = [0u8; 1];
141        let start = std::time::Instant::now();
142
143        loop {
144            let read = self.rx().read_exact(&mut buff).fuse();
145            let delay = Delay::new(Duration::from_millis(200)).fuse();
146            pin_mut!(read, delay);
147
148            match select(read, delay).await {
149                Either::Left((Ok(_), _)) => return Ok(buff[0]),
150                Either::Left((Err(err), _)) => return Err(err),
151                Either::Right((_, _)) => {
152                    if start.elapsed() > timeout_limit {
153                        return Err(Error::new(ErrorKind::TimedOut, "Timeout"));
154                    }
155                }
156            }
157        }
158    }
159
160    pub async fn wait_for_reply(&mut self, val: &str) -> Result<String> {
161        let mut reply = Vec::new();
162        let mut display = Vec::new();
163        debug!("wait for `{val}`");
164
165        loop {
166            let byte = self.read_byte().await?;
167            reply.push(byte);
168            display.push(byte);
169            if byte == b'\n' {
170                dbg!("{}", String::from_utf8_lossy(&display).trim_end());
171                display.clear();
172            }
173
174            if reply.ends_with(val.as_bytes()) {
175                dbg!("{}", String::from_utf8_lossy(&display).trim_end());
176                break;
177            }
178        }
179
180        Ok(String::from_utf8_lossy(&reply)
181            .trim()
182            .trim_end_matches(&self.perfix)
183            .to_string())
184    }
185
186    pub async fn cmd_without_reply(&mut self, cmd: &str) -> Result<()> {
187        self.tx().write_all(cmd.as_bytes()).await?;
188        self.tx().write_all(b"\n").await?;
189        self.tx().flush().await?;
190        Ok(())
191    }
192
193    async fn _cmd(&mut self, cmd: &str) -> Result<String> {
194        self.clear_shell().await?;
195        let ok_str = "cmd-ok";
196        let cmd_with_id = format!("{cmd}&& echo {ok_str}");
197        self.cmd_without_reply(&cmd_with_id).await?;
198        let perfix = self.perfix.clone();
199        let res = self
200            .wait_for_reply(&perfix)
201            .await?
202            .trim_end()
203            .trim_end_matches(self.perfix.as_str().trim())
204            .trim_end()
205            .to_string();
206
207        if res.ends_with(ok_str) {
208            Ok(res
209                .trim()
210                .trim_end_matches(ok_str)
211                .trim_end()
212                .trim_start_matches(&cmd_with_id)
213                .trim()
214                .to_string())
215        } else {
216            Err(Error::other(format!(
217                "command `{cmd}` failed, response: {res}",
218            )))
219        }
220    }
221
222    pub async fn cmd(&mut self, cmd: &str) -> Result<String> {
223        info!("cmd: {cmd}");
224        let mut retry = 3;
225        while retry > 0 {
226            match self._cmd(cmd).await {
227                Ok(res) => return Ok(res),
228                Err(err) => {
229                    warn!("cmd `{cmd}` failed: {err}, retrying...");
230                    retry -= 1;
231                    Delay::new(Duration::from_millis(100)).await;
232                }
233            }
234        }
235        Err(Error::other(format!(
236            "command `{cmd}` failed after retries",
237        )))
238    }
239
240    pub async fn set_env(
241        &mut self,
242        name: impl Into<String>,
243        value: impl Into<String>,
244    ) -> Result<()> {
245        self.cmd(&format!("setenv {} {}", name.into(), value.into()))
246            .await?;
247        Ok(())
248    }
249
250    pub async fn env(&mut self, name: impl Into<String>) -> Result<String> {
251        let name = name.into();
252        let s = self.cmd(&format!("echo ${name}")).await?;
253        let parts = s
254            .split('\n')
255            .filter(|line| !line.trim().is_empty())
256            .collect::<Vec<_>>();
257        let value = parts
258            .last()
259            .ok_or(Error::new(
260                ErrorKind::NotFound,
261                format!("env {name} not found"),
262            ))?
263            .to_string();
264        Ok(value)
265    }
266
267    pub async fn env_int(&mut self, name: impl Into<String>) -> Result<usize> {
268        let name = name.into();
269        let line = self.env(&name).await?;
270        debug!("env {name} = {line}");
271
272        parse_int(&line).ok_or(Error::new(
273            ErrorKind::InvalidData,
274            format!("env {name} is not a number"),
275        ))
276    }
277
278    pub async fn loady(
279        &mut self,
280        addr: usize,
281        file: impl Into<PathBuf>,
282        on_progress: impl Fn(usize, usize),
283    ) -> Result<String> {
284        let file = file.into();
285
286        for attempt in 1..=LOADY_MAX_ATTEMPTS {
287            match self.loady_once(addr, &file, &on_progress).await {
288                Ok(reply) => return Ok(reply),
289                Err(err) if attempt < LOADY_MAX_ATTEMPTS => {
290                    warn!(
291                        "loady attempt {attempt}/{LOADY_MAX_ATTEMPTS} failed: {err}; retrying..."
292                    );
293                    self.wait_for_shell().await.map_err(|recover_err| {
294                        Error::other(format!(
295                            "loady attempt {attempt} failed and shell recovery failed: {recover_err}",
296                        ))
297                    })?;
298                    Delay::new(LOADY_RETRY_DELAY).await;
299                }
300                Err(err) => {
301                    return Err(Error::other(format!(
302                        "loady failed after {LOADY_MAX_ATTEMPTS} attempts: {err}"
303                    )));
304                }
305            }
306        }
307
308        unreachable!("LOADY_MAX_ATTEMPTS must be greater than zero")
309    }
310
311    async fn loady_once(
312        &mut self,
313        addr: usize,
314        file: &Path,
315        on_progress: &impl Fn(usize, usize),
316    ) -> Result<String> {
317        self.clear_shell().await?;
318        self.cmd_without_reply(&format!("loady {addr:#x}")).await?;
319        let crc = self.wait_for_load_crc().await?;
320        let mut protocol = ymodem::Ymodem::new(crc);
321
322        let name = file
323            .file_name()
324            .and_then(|name| name.to_str())
325            .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "file name must be valid UTF-8"))?;
326        let size = std::fs::metadata(file)?.len() as usize;
327        let mut file = AllowStdIo::new(std::fs::File::open(file)?);
328
329        on_progress(0, size);
330        protocol
331            .send(self, &mut file, name, size, |sent| on_progress(sent, size))
332            .await?;
333        let perfix = self.perfix.clone();
334        self.wait_for_reply(&perfix).await
335    }
336
337    async fn wait_for_load_crc(&mut self) -> Result<bool> {
338        let mut reply = Vec::new();
339        loop {
340            let byte = self.read_byte().await?;
341            reply.push(byte);
342            print_raw(&[byte]).await?;
343
344            if reply.ends_with(b"C") {
345                return Ok(true);
346            }
347            let res = String::from_utf8_lossy(&reply);
348            if res.contains("try 'help'") {
349                return Err(Error::new(
350                    ErrorKind::InvalidData,
351                    format!("U-Boot loady failed: {res}"),
352                ));
353            }
354        }
355    }
356}
357
358impl AsyncRead for UbootShell {
359    fn poll_read(
360        self: Pin<&mut Self>,
361        cx: &mut Context<'_>,
362        buf: &mut [u8],
363    ) -> Poll<Result<usize>> {
364        let this = self.get_mut();
365        Pin::new(this.rx.as_mut().unwrap().as_mut()).poll_read(cx, buf)
366    }
367}
368
369impl AsyncWrite for UbootShell {
370    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
371        let this = self.get_mut();
372        Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_write(cx, buf)
373    }
374
375    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
376        let this = self.get_mut();
377        Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_flush(cx)
378    }
379
380    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
381        let this = self.get_mut();
382        Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_close(cx)
383    }
384}
385
/// Parses an unsigned integer from `line`, ignoring surrounding whitespace.
///
/// A `0x` or `0X` prefix selects hexadecimal; anything else is read as
/// decimal. Returns `None` when the text is not a valid number or the value
/// does not fit in `usize`.
fn parse_int(line: &str) -> Option<usize> {
    let text = line.trim();
    let hex_digits = text
        .strip_prefix("0x")
        .or_else(|| text.strip_prefix("0X"));
    let (digits, radix) = match hex_digits {
        Some(digits) => (digits, 16),
        None => (text, 10),
    };

    let value = u64::from_str_radix(digits, radix).ok()?;
    // Checked conversion: a plain `as usize` would truncate on 32-bit targets.
    usize::try_from(value).ok()
}
397
398async fn print_raw(buff: &[u8]) -> Result<()> {
399    #[cfg(target_os = "windows")]
400    {
401        print_raw_win(buff);
402        Ok(())
403    }
404    #[cfg(not(target_os = "windows"))]
405    {
406        let mut out = AllowStdIo::new(stdout());
407        out.write_all(buff).await
408    }
409}
410
/// Windows echo path: collects bytes in a process-wide buffer and prints them
/// once the buffer ends with a newline.
///
/// The printed text is decoded lossily as UTF-8 and trimmed. Panics if the
/// buffer's mutex is poisoned.
#[cfg(target_os = "windows")]
fn print_raw_win(buff: &[u8]) {
    use std::sync::Mutex;
    static PRINT_BUFF: Mutex<Vec<u8>> = Mutex::new(Vec::new());

    let mut pending = PRINT_BUFF.lock().unwrap();
    pending.extend_from_slice(buff);

    // Hold the text back until a complete line is available.
    if pending.last() != Some(&b'\n') {
        return;
    }

    println!("{}", String::from_utf8_lossy(&pending[..]).trim());
    pending.clear();
}
425
// Tests for the `loady` retry path, driven by an in-memory scripted device.
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        fs,
        sync::{Arc, Mutex},
    };

    /// State of the fake device shared by `ScriptedTx` and `ScriptedRx`.
    #[derive(Default)]
    struct LoadyScript {
        // Bytes queued for the shell to read.
        reads: VecDeque<u8>,
        // Every byte the shell has written, in order.
        writes: Vec<u8>,
        // Command line currently being assembled from written bytes.
        command: Vec<u8>,
        // Number of `loady ` command lines seen so far.
        loady_count: usize,
        // Set once the first Ctrl-C has been answered.
        interrupted: bool,
        // While false, written bytes are recorded but not parsed as commands.
        accepting_commands: bool,
    }

    impl LoadyScript {
        /// Appends `bytes` to the queue the shell reads from.
        fn queue_read(&mut self, bytes: impl AsRef<[u8]>) {
            self.reads.extend(bytes.as_ref());
        }

        /// Records a write and reacts to it as the fake device.
        fn handle_write(&mut self, bytes: &[u8]) {
            self.writes.extend_from_slice(bytes);

            // A lone Ctrl-C write re-enables command parsing; only the first
            // one queues the interrupt line.
            if bytes == [CTRL_C] {
                self.command.clear();
                self.accepting_commands = true;
                if !self.interrupted {
                    self.interrupted = true;
                    self.queue_read(b"=> <INTERRUPT>\n");
                }
                return;
            }

            if !self.accepting_commands {
                return;
            }

            for &byte in bytes {
                self.command.push(byte);
                if byte == b'\n' {
                    let command = std::mem::take(&mut self.command);
                    if command.starts_with(b"loady ") {
                        self.loady_count += 1;
                        // Bytes written after `loady` are not parsed as
                        // commands until the next Ctrl-C.
                        self.accepting_commands = false;
                        self.queue_loady_response();
                    }
                } else if self.command.len() > 256 {
                    // Cap the partial-command buffer.
                    self.command.clear();
                }
            }
        }

        /// Queues the device's response for the current `loady` attempt.
        fn queue_loady_response(&mut self) {
            match self.loady_count {
                // First attempt: answer with CRC bytes only.
                1 => {
                    self.queue_read(*b"C");
                    self.queue_read([ymodem::CRC; ymodem::DEFAULT_BLOCK_RETRIES]);
                }
                // Second attempt: queue ACKs and end with `done` plus a prompt.
                2 => {
                    self.queue_read(*b"C");
                    self.queue_read([ymodem::ACK, ymodem::ACK, ymodem::ACK, ymodem::ACK, b'C']);
                    self.queue_read(b"done\n=> ");
                }
                _ => {}
            }
        }
    }

    /// Write half of the fake device.
    #[derive(Clone)]
    struct ScriptedTx {
        script: Arc<Mutex<LoadyScript>>,
    }

    /// Read half of the fake device.
    #[derive(Clone)]
    struct ScriptedRx {
        script: Arc<Mutex<LoadyScript>>,
    }

    impl AsyncWrite for ScriptedTx {
        // Accepts the whole buffer at once and hands it to the script.
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            self.script.lock().unwrap().handle_write(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ScriptedRx {
        // Hands out queued bytes. When the queue is empty this returns
        // `Pending` without registering the waker from `_cx`.
        // NOTE(review): presumably the shell's read timer is what gets this
        // polled again - confirm before reusing this reader elsewhere.
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            let mut script = self.script.lock().unwrap();
            if script.reads.is_empty() {
                return Poll::Pending;
            }

            let n = buf.len().min(script.reads.len());
            for slot in &mut buf[..n] {
                *slot = script.reads.pop_front().unwrap();
            }
            Poll::Ready(Ok(n))
        }
    }

    // Checks that a failed first `loady` attempt is followed by a Ctrl-C and
    // a second `loady` command, and that the call then succeeds.
    #[tokio::test]
    async fn loady_restarts_transfer_after_receiver_rejects_first_attempt() -> Result<()> {
        let script = Arc::new(Mutex::new(LoadyScript::default()));
        script.lock().unwrap().accepting_commands = true;
        // Built directly rather than through `UbootShell::new`, so no prompt
        // detection runs before the first `loady`.
        let mut shell = UbootShell {
            tx: Some(Box::new(ScriptedTx {
                script: script.clone(),
            })),
            rx: Some(Box::new(ScriptedRx {
                script: script.clone(),
            })),
            perfix: "=> ".to_string(),
        };

        let file =
            std::env::temp_dir().join(format!("uboot-shell-loady-retry-{}", std::process::id()));
        fs::write(&file, b"payload")?;

        let progress = Arc::new(Mutex::new(Vec::new()));
        let reply = shell
            .loady(0x80200000, file.clone(), {
                let progress = progress.clone();
                move |sent, size| progress.lock().unwrap().push((sent, size))
            })
            .await;
        // Remove the temp file before asserting so a failure does not leak it.
        let _ = fs::remove_file(&file);

        assert!(reply?.contains("done"));
        let script = script.lock().unwrap();
        let writes = String::from_utf8_lossy(&script.writes);
        assert_eq!(writes.matches("loady 0x80200000").count(), 2);
        assert!(script.writes.contains(&CTRL_C));
        // One `(0, 7)` per attempt, then the completed transfer.
        assert_eq!(*progress.lock().unwrap(), vec![(0, 7), (0, 7), (7, 7)]);
        Ok(())
    }
}