Skip to main content

uboot_shell/
lib.rs

1//! Async U-Boot shell communication over runtime-neutral futures I/O.
2
3#[macro_use]
4extern crate log;
5
6use std::{
7    io::{Error, ErrorKind, Result, stdout},
8    path::PathBuf,
9    pin::Pin,
10    task::{Context, Poll},
11    time::Duration,
12};
13
14use futures::{
15    AsyncReadExt, AsyncWriteExt,
16    future::{Either, FutureExt, select},
17    io::{AllowStdIo, AsyncRead, AsyncWrite},
18    pin_mut,
19};
20use futures_timer::Delay;
21
22/// CRC16-CCITT checksum implementation.
23pub mod crc;
24
25/// YMODEM file transfer protocol implementation.
26pub mod ymodem;
27
/// Logs one formatted line at `debug` level, prefixed with `$ ` so that text
/// received from the U-Boot shell stands out in the log.
///
/// Accepts the same arguments as `format!`. Note that this definition shadows
/// the standard library `dbg!` macro for the code that follows it in this file.
macro_rules! dbg {
    ($($arg:tt)*) => {{
        // Forward the arguments lazily instead of rendering an intermediate
        // `String` first; the logger formats them directly.
        debug!("$ {}", format_args!($($arg)*));
    }};
}
33
/// Ctrl-C (ASCII ETX), sent repeatedly by `wait_for_interrupt` to interrupt U-Boot.
const CTRL_C: u8 = b'\x03';
/// Marker that `wait_for_interrupt` expects at the end of a line once U-Boot
/// has reacted to Ctrl-C.
const INT_STR: &str = "<INTERRUPT>";
/// Byte view of `INT_STR`, for matching against raw received bytes.
const INT: &[u8] = INT_STR.as_bytes();
37
38type Tx = Box<dyn AsyncWrite + Send + Unpin>;
39type Rx = Box<dyn AsyncRead + Send + Unpin>;
40
41pub struct UbootShell {
42    /// Transmit stream for sending bytes to U-Boot.
43    pub tx: Option<Tx>,
44    /// Receive stream for reading bytes from U-Boot.
45    pub rx: Option<Rx>,
46    /// Shell prompt prefix detected during initialization.
47    perfix: String,
48}
49
50impl UbootShell {
51    pub async fn new(
52        tx: impl AsyncWrite + Send + Unpin + 'static,
53        rx: impl AsyncRead + Send + Unpin + 'static,
54    ) -> Result<Self> {
55        let mut shell = Self {
56            tx: Some(Box::new(tx)),
57            rx: Some(Box::new(rx)),
58            perfix: String::new(),
59        };
60        shell.wait_for_shell().await?;
61        debug!("shell ready, perfix: `{}`", shell.perfix);
62        Ok(shell)
63    }
64
65    fn rx(&mut self) -> &mut Rx {
66        self.rx.as_mut().unwrap()
67    }
68
69    fn tx(&mut self) -> &mut Tx {
70        self.tx.as_mut().unwrap()
71    }
72
73    async fn wait_for_interrupt(&mut self) -> Result<Vec<u8>> {
74        let mut history = Vec::new();
75        let mut interrupt_line = Vec::new();
76        let interval = Duration::from_millis(20);
77        let mut last_interrupt = std::time::Instant::now() - interval;
78
79        debug!("wait for interrupt");
80        loop {
81            if last_interrupt.elapsed() >= interval {
82                self.tx().write_all(&[CTRL_C]).await?;
83                self.tx().flush().await?;
84                last_interrupt = std::time::Instant::now();
85            }
86
87            match self.read_byte_with_timeout(interval).await {
88                Ok(ch) => {
89                    history.push(ch);
90                    if history.last() == Some(&b'\n') {
91                        let line = history.trim_ascii_end();
92                        dbg!("{}", String::from_utf8_lossy(line));
93                        let interrupted = line.ends_with(INT);
94                        if interrupted {
95                            interrupt_line.extend_from_slice(line);
96                        }
97                        history.clear();
98                        if interrupted {
99                            break;
100                        }
101                    }
102                }
103                Err(err) if err.kind() == ErrorKind::TimedOut => {}
104                Err(err) => return Err(err),
105            }
106        }
107
108        Ok(interrupt_line)
109    }
110
111    async fn clear_shell(&mut self) -> Result<()> {
112        loop {
113            match self
114                .read_byte_with_timeout(Duration::from_millis(300))
115                .await
116            {
117                Ok(_) => {}
118                Err(err) if err.kind() == ErrorKind::TimedOut => return Ok(()),
119                Err(err) => return Err(err),
120            }
121        }
122    }
123
124    async fn wait_for_shell(&mut self) -> Result<()> {
125        let mut line = self.wait_for_interrupt().await?;
126        debug!("got {}", String::from_utf8_lossy(&line));
127        line.resize(line.len().saturating_sub(INT.len()), 0);
128        self.perfix = String::from_utf8_lossy(&line).to_string();
129        self.clear_shell().await?;
130        Ok(())
131    }
132
133    async fn read_byte(&mut self) -> Result<u8> {
134        self.read_byte_with_timeout(Duration::from_secs(5)).await
135    }
136
137    async fn read_byte_with_timeout(&mut self, timeout_limit: Duration) -> Result<u8> {
138        let mut buff = [0u8; 1];
139        let start = std::time::Instant::now();
140
141        loop {
142            let read = self.rx().read_exact(&mut buff).fuse();
143            let delay = Delay::new(Duration::from_millis(200)).fuse();
144            pin_mut!(read, delay);
145
146            match select(read, delay).await {
147                Either::Left((Ok(_), _)) => return Ok(buff[0]),
148                Either::Left((Err(err), _)) => return Err(err),
149                Either::Right((_, _)) => {
150                    if start.elapsed() > timeout_limit {
151                        return Err(Error::new(ErrorKind::TimedOut, "Timeout"));
152                    }
153                }
154            }
155        }
156    }
157
158    pub async fn wait_for_reply(&mut self, val: &str) -> Result<String> {
159        let mut reply = Vec::new();
160        let mut display = Vec::new();
161        debug!("wait for `{}`", val);
162
163        loop {
164            let byte = self.read_byte().await?;
165            reply.push(byte);
166            display.push(byte);
167            if byte == b'\n' {
168                dbg!("{}", String::from_utf8_lossy(&display).trim_end());
169                display.clear();
170            }
171
172            if reply.ends_with(val.as_bytes()) {
173                dbg!("{}", String::from_utf8_lossy(&display).trim_end());
174                break;
175            }
176        }
177
178        Ok(String::from_utf8_lossy(&reply)
179            .trim()
180            .trim_end_matches(&self.perfix)
181            .to_string())
182    }
183
184    pub async fn cmd_without_reply(&mut self, cmd: &str) -> Result<()> {
185        self.tx().write_all(cmd.as_bytes()).await?;
186        self.tx().write_all(b"\n").await?;
187        self.tx().flush().await?;
188        Ok(())
189    }
190
191    async fn _cmd(&mut self, cmd: &str) -> Result<String> {
192        self.clear_shell().await?;
193        let ok_str = "cmd-ok";
194        let cmd_with_id = format!("{cmd}&& echo {ok_str}");
195        self.cmd_without_reply(&cmd_with_id).await?;
196        let perfix = self.perfix.clone();
197        let res = self
198            .wait_for_reply(&perfix)
199            .await?
200            .trim_end()
201            .trim_end_matches(self.perfix.as_str().trim())
202            .trim_end()
203            .to_string();
204
205        if res.ends_with(ok_str) {
206            Ok(res
207                .trim()
208                .trim_end_matches(ok_str)
209                .trim_end()
210                .trim_start_matches(&cmd_with_id)
211                .trim()
212                .to_string())
213        } else {
214            Err(Error::other(format!(
215                "command `{cmd}` failed, response: {res}",
216            )))
217        }
218    }
219
220    pub async fn cmd(&mut self, cmd: &str) -> Result<String> {
221        info!("cmd: {cmd}");
222        let mut retry = 3;
223        while retry > 0 {
224            match self._cmd(cmd).await {
225                Ok(res) => return Ok(res),
226                Err(err) => {
227                    warn!("cmd `{}` failed: {}, retrying...", cmd, err);
228                    retry -= 1;
229                    Delay::new(Duration::from_millis(100)).await;
230                }
231            }
232        }
233        Err(Error::other(format!(
234            "command `{cmd}` failed after retries",
235        )))
236    }
237
238    pub async fn set_env(
239        &mut self,
240        name: impl Into<String>,
241        value: impl Into<String>,
242    ) -> Result<()> {
243        self.cmd(&format!("setenv {} {}", name.into(), value.into()))
244            .await?;
245        Ok(())
246    }
247
248    pub async fn env(&mut self, name: impl Into<String>) -> Result<String> {
249        let name = name.into();
250        let s = self.cmd(&format!("echo ${name}")).await?;
251        let parts = s
252            .split('\n')
253            .filter(|line| !line.trim().is_empty())
254            .collect::<Vec<_>>();
255        let value = parts
256            .last()
257            .ok_or(Error::new(
258                ErrorKind::NotFound,
259                format!("env {name} not found"),
260            ))?
261            .to_string();
262        Ok(value)
263    }
264
265    pub async fn env_int(&mut self, name: impl Into<String>) -> Result<usize> {
266        let name = name.into();
267        let line = self.env(&name).await?;
268        debug!("env {name} = {line}");
269
270        parse_int(&line).ok_or(Error::new(
271            ErrorKind::InvalidData,
272            format!("env {name} is not a number"),
273        ))
274    }
275
276    pub async fn loady(
277        &mut self,
278        addr: usize,
279        file: impl Into<PathBuf>,
280        on_progress: impl Fn(usize, usize),
281    ) -> Result<String> {
282        self.cmd_without_reply(&format!("loady {addr:#x}")).await?;
283        let crc = self.wait_for_load_crc().await?;
284        let mut protocol = ymodem::Ymodem::new(crc);
285
286        let file = file.into();
287        let name = file
288            .file_name()
289            .and_then(|name| name.to_str())
290            .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "file name must be valid UTF-8"))?;
291        let size = std::fs::metadata(&file)?.len() as usize;
292        let mut file = AllowStdIo::new(std::fs::File::open(&file)?);
293
294        protocol
295            .send(self, &mut file, name, size, |sent| on_progress(sent, size))
296            .await?;
297        let perfix = self.perfix.clone();
298        self.wait_for_reply(&perfix).await
299    }
300
301    async fn wait_for_load_crc(&mut self) -> Result<bool> {
302        let mut reply = Vec::new();
303        loop {
304            let byte = self.read_byte().await?;
305            reply.push(byte);
306            print_raw(&[byte]).await?;
307
308            if reply.ends_with(b"C") {
309                return Ok(true);
310            }
311            let res = String::from_utf8_lossy(&reply);
312            if res.contains("try 'help'") {
313                return Err(Error::new(
314                    ErrorKind::InvalidData,
315                    format!("U-Boot loady failed: {res}"),
316                ));
317            }
318        }
319    }
320}
321
322impl AsyncRead for UbootShell {
323    fn poll_read(
324        self: Pin<&mut Self>,
325        cx: &mut Context<'_>,
326        buf: &mut [u8],
327    ) -> Poll<Result<usize>> {
328        let this = self.get_mut();
329        Pin::new(this.rx.as_mut().unwrap().as_mut()).poll_read(cx, buf)
330    }
331}
332
333impl AsyncWrite for UbootShell {
334    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
335        let this = self.get_mut();
336        Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_write(cx, buf)
337    }
338
339    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
340        let this = self.get_mut();
341        Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_flush(cx)
342    }
343
344    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
345        let this = self.get_mut();
346        Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_close(cx)
347    }
348}
349
/// Parses an unsigned integer from text such as an `echo $var` result.
///
/// Surrounding whitespace is ignored. A `0x` or `0X` prefix selects
/// hexadecimal; otherwise the text is parsed as decimal.
///
/// Returns `None` if the text is not a valid number in that radix, or if the
/// value does not fit in `usize`. Previously such values were silently
/// truncated on 32-bit hosts.
fn parse_int(line: &str) -> Option<usize> {
    let line = line.trim();
    let (digits, radix) = match line
        .strip_prefix("0x")
        .or_else(|| line.strip_prefix("0X"))
    {
        Some(hex) => (hex, 16),
        None => (line, 10),
    };
    usize::from_str_radix(digits, radix).ok()
}
361
/// Echoes raw bytes received from U-Boot to the local console.
///
/// On Windows the bytes go through `print_raw_win`, which collects them until
/// a newline. Elsewhere they are written to stdout as they arrive.
async fn print_raw(buff: &[u8]) -> Result<()> {
    #[cfg(not(target_os = "windows"))]
    {
        use std::io::Write as _;
        stdout().write_all(buff)
    }
    #[cfg(target_os = "windows")]
    {
        print_raw_win(buff);
        Ok(())
    }
}
374
/// Windows console echo: collects bytes in a process-wide buffer and prints
/// the buffered text, trimmed, once it ends with a newline.
#[cfg(target_os = "windows")]
fn print_raw_win(buff: &[u8]) {
    use std::sync::{Mutex, PoisonError};
    static PRINT_BUFF: Mutex<Vec<u8>> = Mutex::new(Vec::new());

    // The buffer only holds plain bytes, so it is still usable after a panic
    // elsewhere poisoned the lock. Recover it rather than letting one panic
    // disable console echo for the rest of the process.
    let mut pending = PRINT_BUFF.lock().unwrap_or_else(PoisonError::into_inner);
    pending.extend_from_slice(buff);

    if pending.ends_with(b"\n") {
        println!("{}", String::from_utf8_lossy(&pending).trim());
        pending.clear();
    }
}