Skip to main content

uboot_shell/
ymodem.rs

1//! Async YMODEM file transfer protocol implementation.
2
3use std::io::{Error, ErrorKind, Result};
4
5use futures::{
6    AsyncReadExt, AsyncWriteExt,
7    io::{AllowStdIo, AsyncRead, AsyncWrite},
8};
9
10use crate::crc::crc16_ccitt;
11
/// Header byte of a block carrying a 128-byte payload.
const SOH: u8 = 0x01;
/// Header byte of a block carrying a 1024-byte payload.
const STX: u8 = 0x02;
/// End-of-transmission byte sent after the last data block.
const EOT: u8 = 0x04;
/// Receiver's positive acknowledgement.
pub(crate) const ACK: u8 = 0x06;
/// Receiver's negative acknowledgement; when it opens a transfer it selects
/// non-CRC mode.
pub(crate) const NAK: u8 = 0x15;
/// Filler used to pad a short data block (ASCII SUB).
const EOF: u8 = 0x1A;
/// ASCII `'C'`: the receiver's request for CRC-16 mode.
pub(crate) const CRC: u8 = b'C';
/// Default number of transmission attempts allowed per block.
pub(crate) const DEFAULT_BLOCK_RETRIES: usize = 10;
20
/// State of a YMODEM sender.
#[derive(Debug)]
pub struct Ymodem {
    /// Whether CRC-16 mode is active (selected by the receiver's `'C'`),
    /// as opposed to the mode selected by `NAK`.
    crc_mode: bool,
    /// Sequence number of the next block; wraps back to 0 after 255.
    blk: u8,
    /// Number of transmission attempts allowed per block before giving up.
    max_block_retries: usize,
}
26
27impl Ymodem {
28    pub fn new(crc_mode: bool) -> Self {
29        Self {
30            crc_mode,
31            blk: 0,
32            max_block_retries: DEFAULT_BLOCK_RETRIES,
33        }
34    }
35
36    fn nak(&self) -> u8 {
37        if self.crc_mode { CRC } else { NAK }
38    }
39
40    async fn getc<D: AsyncRead + Unpin>(&mut self, dev: &mut D) -> Result<u8> {
41        let mut buff = [0u8; 1];
42        dev.read_exact(&mut buff).await?;
43        Ok(buff[0])
44    }
45
46    async fn wait_for_start<D: AsyncRead + Unpin>(&mut self, dev: &mut D) -> Result<()> {
47        loop {
48            match self.getc(dev).await? {
49                NAK => {
50                    self.crc_mode = false;
51                    return Ok(());
52                }
53                CRC => {
54                    self.crc_mode = true;
55                    return Ok(());
56                }
57                _ => {}
58            }
59        }
60    }
61
62    pub async fn send<D, F>(
63        &mut self,
64        dev: &mut D,
65        file: &mut F,
66        name: &str,
67        size: usize,
68        on_progress: impl Fn(usize),
69    ) -> Result<()>
70    where
71        D: AsyncWrite + AsyncRead + Unpin,
72        F: AsyncRead + Unpin,
73    {
74        info!("Sending file: {name}");
75
76        self.send_header(dev, name, size).await?;
77
78        let mut buff = [0u8; 1024];
79        let mut send_size = 0;
80
81        loop {
82            let n = file.read(&mut buff).await?;
83            if n == 0 {
84                break;
85            }
86            self.send_blk(dev, &buff[..n], EOF, false).await?;
87            send_size += n;
88            on_progress(send_size);
89        }
90
91        dev.write_all(&[EOT]).await?;
92        dev.flush().await?;
93        self.wait_ack(dev).await?;
94
95        self.send_blk(dev, &[0], 0, true).await?;
96        self.wait_for_start(dev).await?;
97        Ok(())
98    }
99
100    async fn wait_ack<D: AsyncRead + Unpin>(&mut self, dev: &mut D) -> Result<()> {
101        let nak = self.nak();
102        loop {
103            let c = self.getc(dev).await?;
104            match c {
105                ACK => return Ok(()),
106                _ => {
107                    if c == nak {
108                        return Err(Error::new(ErrorKind::BrokenPipe, "NAK"));
109                    }
110                    let mut out = AllowStdIo::new(std::io::stdout());
111                    out.write_all(&[c]).await?;
112                }
113            }
114        }
115    }
116
117    async fn send_header<D: AsyncWrite + AsyncRead + Unpin>(
118        &mut self,
119        dev: &mut D,
120        name: &str,
121        size: usize,
122    ) -> Result<()> {
123        let mut buff = Vec::new();
124        buff.append(&mut name.as_bytes().to_vec());
125        buff.push(0);
126        buff.append(&mut format!("{size}").as_bytes().to_vec());
127        buff.push(0);
128        self.send_blk(dev, &buff, 0, false).await
129    }
130
131    async fn send_blk<D: AsyncWrite + AsyncRead + Unpin>(
132        &mut self,
133        dev: &mut D,
134        data: &[u8],
135        pad: u8,
136        last: bool,
137    ) -> Result<()> {
138        let (len, p) = if data.len() > 128 {
139            (1024, STX)
140        } else {
141            (128, SOH)
142        };
143        let blk = if last { 0 } else { self.blk };
144        let mut err = None;
145        let mut retries = self.max_block_retries;
146
147        loop {
148            if retries == 0 {
149                return Err(err.unwrap_or(Error::new(ErrorKind::BrokenPipe, "retry too much")));
150            }
151
152            dev.write_all(&[p, blk, !blk]).await?;
153
154            let mut buf = vec![pad; len];
155            buf[..data.len()].copy_from_slice(data);
156            dev.write_all(&buf).await?;
157
158            if self.crc_mode {
159                let chsum = crc16_ccitt(0, &buf);
160                let crc1 = (chsum >> 8) as u8;
161                let crc2 = (chsum & 0xff) as u8;
162                dev.write_all(&[crc1, crc2]).await?;
163            }
164            dev.flush().await?;
165
166            match self.wait_ack(dev).await {
167                Ok(_) => break,
168                Err(e) => {
169                    err = Some(e);
170                    retries -= 1;
171                }
172            }
173        }
174
175        if self.blk == u8::MAX {
176            self.blk = 0;
177        } else {
178            self.blk += 1;
179        }
180
181        Ok(())
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        cell::RefCell,
        collections::VecDeque,
        pin::Pin,
        task::{Context, Poll},
    };

    use futures::io::Cursor;

    /// In-memory device.
    ///
    /// Receiver replies are served from a fixed script. Every byte the sender
    /// writes is captured for inspection.
    struct ScriptedDevice {
        reads: VecDeque<u8>,
        writes: Vec<u8>,
    }

    impl AsyncRead for ScriptedDevice {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            // Once the script is exhausted, `count` is 0, which reads as EOF.
            let count = buf.len().min(self.reads.len());
            for (dst, byte) in buf.iter_mut().zip(self.reads.drain(..count)) {
                *dst = byte;
            }
            Poll::Ready(Ok(count))
        }
    }

    impl AsyncWrite for ScriptedDevice {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            self.writes.extend(buf.iter().copied());
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn acked_block_resets_retry_budget() -> Result<()> {
        // Script, in order:
        // - header: rejected once, then acked;
        // - data block: rejected DEFAULT_BLOCK_RETRIES - 1 times, then acked
        //   (this only fits if each block gets its own retry budget);
        // - ACKs for EOT and the end-of-batch block;
        // - the receiver's next 'C'.
        let reads: VecDeque<u8> = [CRC, ACK]
            .into_iter()
            .chain(std::iter::repeat_n(CRC, DEFAULT_BLOCK_RETRIES - 1))
            .chain([ACK, ACK, ACK, CRC])
            .collect();

        let mut dev = ScriptedDevice {
            reads,
            writes: Vec::new(),
        };
        let mut file = Cursor::new(b"payload".to_vec());
        let progress = RefCell::new(Vec::new());

        let mut sender = Ymodem::new(true);
        sender
            .send(&mut dev, &mut file, "kernel", 7, |sent| {
                progress.borrow_mut().push(sent)
            })
            .await?;

        assert_eq!(progress.into_inner(), vec![7]);
        assert!(dev.writes.contains(&EOT));
        Ok(())
    }
}