1#[macro_use]
4extern crate log;
5
6use std::{
7 io::{Error, ErrorKind, Result, stdout},
8 path::{Path, PathBuf},
9 pin::Pin,
10 task::{Context, Poll},
11 time::Duration,
12};
13
14use futures::{
15 AsyncReadExt, AsyncWriteExt,
16 future::{Either, FutureExt, select},
17 io::{AllowStdIo, AsyncRead, AsyncWrite},
18 pin_mut,
19};
20use futures_timer::Delay;
21
22pub mod crc;
24
25pub mod ymodem;
27
/// Logs a formatted message at `debug` level, prefixed with `$ `.
///
/// Shadows `std::dbg!` for the code that follows it in this file: unlike the std
/// macro it never writes to stderr and does not return its argument.
macro_rules! dbg {
    ($($arg:tt)*) => {{
        debug!("$ {}", ::std::format!($($arg)*));
    }};
}
33
/// Byte produced by Ctrl-C (ETX); written repeatedly to interrupt the console.
const CTRL_C: u8 = b'\x03';
/// Marker that U-Boot echoes after an interrupt, as text.
const INT_STR: &str = "<INTERRUPT>";
/// [`INT_STR`] as raw bytes, for matching against received data.
const INT: &[u8] = INT_STR.as_bytes();
/// Total number of `loady` transfers tried before giving up.
const LOADY_MAX_ATTEMPTS: usize = 3;
/// Pause between a failed `loady` attempt (after shell recovery) and the next one.
const LOADY_RETRY_DELAY: Duration = Duration::from_millis(300);
39
40type Tx = Box<dyn AsyncWrite + Send + Unpin>;
41type Rx = Box<dyn AsyncRead + Send + Unpin>;
42
43pub struct UbootShell {
44 pub tx: Option<Tx>,
46 pub rx: Option<Rx>,
48 perfix: String,
50}
51
52impl UbootShell {
53 pub async fn new(
54 tx: impl AsyncWrite + Send + Unpin + 'static,
55 rx: impl AsyncRead + Send + Unpin + 'static,
56 ) -> Result<Self> {
57 let mut shell = Self {
58 tx: Some(Box::new(tx)),
59 rx: Some(Box::new(rx)),
60 perfix: String::new(),
61 };
62 shell.wait_for_shell().await?;
63 debug!("shell ready, perfix: `{}`", shell.perfix);
64 Ok(shell)
65 }
66
67 fn rx(&mut self) -> &mut Rx {
68 self.rx.as_mut().unwrap()
69 }
70
71 fn tx(&mut self) -> &mut Tx {
72 self.tx.as_mut().unwrap()
73 }
74
75 async fn wait_for_interrupt(&mut self) -> Result<Vec<u8>> {
76 let mut history = Vec::new();
77 let mut interrupt_line = Vec::new();
78 let interval = Duration::from_millis(20);
79 let mut last_interrupt = std::time::Instant::now() - interval;
80
81 debug!("wait for interrupt");
82 loop {
83 if last_interrupt.elapsed() >= interval {
84 self.tx().write_all(&[CTRL_C]).await?;
85 self.tx().flush().await?;
86 last_interrupt = std::time::Instant::now();
87 }
88
89 match self.read_byte_with_timeout(interval).await {
90 Ok(ch) => {
91 history.push(ch);
92 if history.last() == Some(&b'\n') {
93 let line = history.trim_ascii_end();
94 dbg!("{}", String::from_utf8_lossy(line));
95 let interrupted = line.ends_with(INT);
96 if interrupted {
97 interrupt_line.extend_from_slice(line);
98 }
99 history.clear();
100 if interrupted {
101 break;
102 }
103 }
104 }
105 Err(err) if err.kind() == ErrorKind::TimedOut => {}
106 Err(err) => return Err(err),
107 }
108 }
109
110 Ok(interrupt_line)
111 }
112
113 async fn clear_shell(&mut self) -> Result<()> {
114 loop {
115 match self
116 .read_byte_with_timeout(Duration::from_millis(300))
117 .await
118 {
119 Ok(_) => {}
120 Err(err) if err.kind() == ErrorKind::TimedOut => return Ok(()),
121 Err(err) => return Err(err),
122 }
123 }
124 }
125
126 async fn wait_for_shell(&mut self) -> Result<()> {
127 let mut line = self.wait_for_interrupt().await?;
128 debug!("got {}", String::from_utf8_lossy(&line));
129 line.resize(line.len().saturating_sub(INT.len()), 0);
130 self.perfix = String::from_utf8_lossy(&line).to_string();
131 self.clear_shell().await?;
132 Ok(())
133 }
134
135 async fn read_byte(&mut self) -> Result<u8> {
136 self.read_byte_with_timeout(Duration::from_secs(5)).await
137 }
138
139 async fn read_byte_with_timeout(&mut self, timeout_limit: Duration) -> Result<u8> {
140 let mut buff = [0u8; 1];
141 let start = std::time::Instant::now();
142
143 loop {
144 let read = self.rx().read_exact(&mut buff).fuse();
145 let delay = Delay::new(Duration::from_millis(200)).fuse();
146 pin_mut!(read, delay);
147
148 match select(read, delay).await {
149 Either::Left((Ok(_), _)) => return Ok(buff[0]),
150 Either::Left((Err(err), _)) => return Err(err),
151 Either::Right((_, _)) => {
152 if start.elapsed() > timeout_limit {
153 return Err(Error::new(ErrorKind::TimedOut, "Timeout"));
154 }
155 }
156 }
157 }
158 }
159
160 pub async fn wait_for_reply(&mut self, val: &str) -> Result<String> {
161 let mut reply = Vec::new();
162 let mut display = Vec::new();
163 debug!("wait for `{val}`");
164
165 loop {
166 let byte = self.read_byte().await?;
167 reply.push(byte);
168 display.push(byte);
169 if byte == b'\n' {
170 dbg!("{}", String::from_utf8_lossy(&display).trim_end());
171 display.clear();
172 }
173
174 if reply.ends_with(val.as_bytes()) {
175 dbg!("{}", String::from_utf8_lossy(&display).trim_end());
176 break;
177 }
178 }
179
180 Ok(String::from_utf8_lossy(&reply)
181 .trim()
182 .trim_end_matches(&self.perfix)
183 .to_string())
184 }
185
186 pub async fn cmd_without_reply(&mut self, cmd: &str) -> Result<()> {
187 self.tx().write_all(cmd.as_bytes()).await?;
188 self.tx().write_all(b"\n").await?;
189 self.tx().flush().await?;
190 Ok(())
191 }
192
193 async fn _cmd(&mut self, cmd: &str) -> Result<String> {
194 self.clear_shell().await?;
195 let ok_str = "cmd-ok";
196 let cmd_with_id = format!("{cmd}&& echo {ok_str}");
197 self.cmd_without_reply(&cmd_with_id).await?;
198 let perfix = self.perfix.clone();
199 let res = self
200 .wait_for_reply(&perfix)
201 .await?
202 .trim_end()
203 .trim_end_matches(self.perfix.as_str().trim())
204 .trim_end()
205 .to_string();
206
207 if res.ends_with(ok_str) {
208 Ok(res
209 .trim()
210 .trim_end_matches(ok_str)
211 .trim_end()
212 .trim_start_matches(&cmd_with_id)
213 .trim()
214 .to_string())
215 } else {
216 Err(Error::other(format!(
217 "command `{cmd}` failed, response: {res}",
218 )))
219 }
220 }
221
222 pub async fn cmd(&mut self, cmd: &str) -> Result<String> {
223 info!("cmd: {cmd}");
224 let mut retry = 3;
225 while retry > 0 {
226 match self._cmd(cmd).await {
227 Ok(res) => return Ok(res),
228 Err(err) => {
229 warn!("cmd `{cmd}` failed: {err}, retrying...");
230 retry -= 1;
231 Delay::new(Duration::from_millis(100)).await;
232 }
233 }
234 }
235 Err(Error::other(format!(
236 "command `{cmd}` failed after retries",
237 )))
238 }
239
240 pub async fn set_env(
241 &mut self,
242 name: impl Into<String>,
243 value: impl Into<String>,
244 ) -> Result<()> {
245 self.cmd(&format!("setenv {} {}", name.into(), value.into()))
246 .await?;
247 Ok(())
248 }
249
250 pub async fn env(&mut self, name: impl Into<String>) -> Result<String> {
251 let name = name.into();
252 let s = self.cmd(&format!("echo ${name}")).await?;
253 let parts = s
254 .split('\n')
255 .filter(|line| !line.trim().is_empty())
256 .collect::<Vec<_>>();
257 let value = parts
258 .last()
259 .ok_or(Error::new(
260 ErrorKind::NotFound,
261 format!("env {name} not found"),
262 ))?
263 .to_string();
264 Ok(value)
265 }
266
267 pub async fn env_int(&mut self, name: impl Into<String>) -> Result<usize> {
268 let name = name.into();
269 let line = self.env(&name).await?;
270 debug!("env {name} = {line}");
271
272 parse_int(&line).ok_or(Error::new(
273 ErrorKind::InvalidData,
274 format!("env {name} is not a number"),
275 ))
276 }
277
278 pub async fn loady(
279 &mut self,
280 addr: usize,
281 file: impl Into<PathBuf>,
282 on_progress: impl Fn(usize, usize),
283 ) -> Result<String> {
284 let file = file.into();
285
286 for attempt in 1..=LOADY_MAX_ATTEMPTS {
287 match self.loady_once(addr, &file, &on_progress).await {
288 Ok(reply) => return Ok(reply),
289 Err(err) if attempt < LOADY_MAX_ATTEMPTS => {
290 warn!(
291 "loady attempt {attempt}/{LOADY_MAX_ATTEMPTS} failed: {err}; retrying..."
292 );
293 self.wait_for_shell().await.map_err(|recover_err| {
294 Error::other(format!(
295 "loady attempt {attempt} failed and shell recovery failed: {recover_err}",
296 ))
297 })?;
298 Delay::new(LOADY_RETRY_DELAY).await;
299 }
300 Err(err) => {
301 return Err(Error::other(format!(
302 "loady failed after {LOADY_MAX_ATTEMPTS} attempts: {err}"
303 )));
304 }
305 }
306 }
307
308 unreachable!("LOADY_MAX_ATTEMPTS must be greater than zero")
309 }
310
311 async fn loady_once(
312 &mut self,
313 addr: usize,
314 file: &Path,
315 on_progress: &impl Fn(usize, usize),
316 ) -> Result<String> {
317 self.clear_shell().await?;
318 self.cmd_without_reply(&format!("loady {addr:#x}")).await?;
319 let crc = self.wait_for_load_crc().await?;
320 let mut protocol = ymodem::Ymodem::new(crc);
321
322 let name = file
323 .file_name()
324 .and_then(|name| name.to_str())
325 .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "file name must be valid UTF-8"))?;
326 let size = std::fs::metadata(file)?.len() as usize;
327 let mut file = AllowStdIo::new(std::fs::File::open(file)?);
328
329 on_progress(0, size);
330 protocol
331 .send(self, &mut file, name, size, |sent| on_progress(sent, size))
332 .await?;
333 let perfix = self.perfix.clone();
334 self.wait_for_reply(&perfix).await
335 }
336
337 async fn wait_for_load_crc(&mut self) -> Result<bool> {
338 let mut reply = Vec::new();
339 loop {
340 let byte = self.read_byte().await?;
341 reply.push(byte);
342 print_raw(&[byte]).await?;
343
344 if reply.ends_with(b"C") {
345 return Ok(true);
346 }
347 let res = String::from_utf8_lossy(&reply);
348 if res.contains("try 'help'") {
349 return Err(Error::new(
350 ErrorKind::InvalidData,
351 format!("U-Boot loady failed: {res}"),
352 ));
353 }
354 }
355 }
356}
357
358impl AsyncRead for UbootShell {
359 fn poll_read(
360 self: Pin<&mut Self>,
361 cx: &mut Context<'_>,
362 buf: &mut [u8],
363 ) -> Poll<Result<usize>> {
364 let this = self.get_mut();
365 Pin::new(this.rx.as_mut().unwrap().as_mut()).poll_read(cx, buf)
366 }
367}
368
369impl AsyncWrite for UbootShell {
370 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
371 let this = self.get_mut();
372 Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_write(cx, buf)
373 }
374
375 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
376 let this = self.get_mut();
377 Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_flush(cx)
378 }
379
380 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
381 let this = self.get_mut();
382 Pin::new(this.tx.as_mut().unwrap().as_mut()).poll_close(cx)
383 }
384}
385
/// Parses a decimal or `0x`/`0X`-prefixed hexadecimal integer, ignoring
/// surrounding whitespace.
///
/// Returns `None` if the text is not a valid number or does not fit in `usize`.
/// Parsing directly into `usize` avoids the silent truncation a `u64 as usize`
/// cast would cause on 32-bit targets.
fn parse_int(line: &str) -> Option<usize> {
    let line = line.trim();
    let (digits, radix) = match line
        .strip_prefix("0x")
        .or_else(|| line.strip_prefix("0X"))
    {
        Some(hex) => (hex, 16),
        None => (line, 10),
    };
    usize::from_str_radix(digits, radix).ok()
}
397
398async fn print_raw(buff: &[u8]) -> Result<()> {
399 #[cfg(target_os = "windows")]
400 {
401 print_raw_win(buff);
402 Ok(())
403 }
404 #[cfg(not(target_os = "windows"))]
405 {
406 let mut out = AllowStdIo::new(stdout());
407 out.write_all(buff).await
408 }
409}
410
#[cfg(target_os = "windows")]
/// Buffers console bytes and prints them one completed line at a time.
///
/// Once the buffer ends with `\n`, its lossily-decoded, trimmed contents are
/// printed with `println!` and the buffer is emptied.
fn print_raw_win(buff: &[u8]) {
    use std::sync::{Mutex, PoisonError};
    static PRINT_BUFF: Mutex<Vec<u8>> = Mutex::new(Vec::new());

    // A panic while the lock is held (e.g. `println!` failing on a closed stdout)
    // poisons the mutex. Keep using the buffer rather than panicking on every
    // later call.
    let mut pending = PRINT_BUFF.lock().unwrap_or_else(PoisonError::into_inner);
    pending.extend_from_slice(buff);

    if pending.ends_with(b"\n") {
        let text = String::from_utf8_lossy(&pending[..]);
        println!("{}", text.trim());
        pending.clear();
    }
}
425
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        fs,
        sync::{Arc, Mutex},
    };

    /// Fake U-Boot console shared by the scripted transport halves.
    ///
    /// Everything the shell writes is recorded in `writes`; bytes queued in
    /// `reads` are what the shell reads next. The first lone Ctrl-C yields a
    /// `=> <INTERRUPT>` line. Each `loady` command line queues a canned reply:
    /// the first is `C` followed by `DEFAULT_BLOCK_RETRIES` copies of
    /// `ymodem::CRC`, which the test expects to make that attempt fail. The
    /// second is `C`, four ACKs, `C`, then `done` and a prompt.
    #[derive(Default)]
    struct LoadyScript {
        /// Bytes the shell will read, front first.
        reads: VecDeque<u8>,
        /// Every byte the shell has written, in order.
        writes: Vec<u8>,
        /// Command line currently being assembled from written bytes.
        command: Vec<u8>,
        /// Number of `loady` command lines seen so far.
        loady_count: usize,
        /// Whether the one-off `<INTERRUPT>` reply has been queued.
        interrupted: bool,
        /// Whether written bytes are currently parsed as command input.
        accepting_commands: bool,
    }

    impl LoadyScript {
        /// Appends `bytes` to the pending read queue.
        fn queue_read(&mut self, bytes: impl AsRef<[u8]>) {
            self.reads.extend(bytes.as_ref().iter().copied());
        }

        /// Records one `poll_write` payload and reacts to it like the console would.
        fn handle_write(&mut self, bytes: &[u8]) {
            self.writes.extend_from_slice(bytes);

            if bytes == [CTRL_C] {
                self.on_interrupt();
                return;
            }
            if self.accepting_commands {
                for &byte in bytes {
                    self.feed_command_byte(byte);
                }
            }
        }

        /// A lone Ctrl-C discards the partial line and re-enables command input.
        fn on_interrupt(&mut self) {
            self.command.clear();
            self.accepting_commands = true;
            if !self.interrupted {
                self.interrupted = true;
                self.queue_read(b"=> <INTERRUPT>\n");
            }
        }

        /// Adds one byte to the current command line and dispatches it on `\n`.
        fn feed_command_byte(&mut self, byte: u8) {
            self.command.push(byte);
            if byte != b'\n' {
                if self.command.len() > 256 {
                    self.command.clear();
                }
                return;
            }

            let line = std::mem::take(&mut self.command);
            if line.starts_with(b"loady ") {
                self.loady_count += 1;
                self.accepting_commands = false;
                self.queue_loady_response();
            }
        }

        /// Queues the canned reply for the current `loady` attempt.
        fn queue_loady_response(&mut self) {
            if self.loady_count == 1 {
                self.queue_read(b"C");
                self.queue_read([ymodem::CRC; ymodem::DEFAULT_BLOCK_RETRIES]);
            } else if self.loady_count == 2 {
                self.queue_read(b"C");
                self.queue_read([ymodem::ACK, ymodem::ACK, ymodem::ACK, ymodem::ACK, b'C']);
                self.queue_read(b"done\n=> ");
            }
        }
    }

    /// Write half of the scripted console.
    #[derive(Clone)]
    struct ScriptedTx {
        script: Arc<Mutex<LoadyScript>>,
    }

    /// Read half of the scripted console.
    #[derive(Clone)]
    struct ScriptedRx {
        script: Arc<Mutex<LoadyScript>>,
    }

    impl AsyncWrite for ScriptedTx {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            let mut state = self.script.lock().unwrap();
            state.handle_write(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ScriptedRx {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            let mut state = self.script.lock().unwrap();
            // No waker is registered here; progress relies on the shell's own
            // periodic re-poll in `read_byte_with_timeout`.
            if state.reads.is_empty() {
                return Poll::Pending;
            }

            let count = buf.len().min(state.reads.len());
            for (slot, byte) in buf.iter_mut().zip(state.reads.drain(..count)) {
                *slot = byte;
            }
            Poll::Ready(Ok(count))
        }
    }

    #[tokio::test]
    async fn loady_restarts_transfer_after_receiver_rejects_first_attempt() -> Result<()> {
        let script = Arc::new(Mutex::new(LoadyScript {
            accepting_commands: true,
            ..LoadyScript::default()
        }));
        let mut shell = UbootShell {
            tx: Some(Box::new(ScriptedTx {
                script: Arc::clone(&script),
            })),
            rx: Some(Box::new(ScriptedRx {
                script: Arc::clone(&script),
            })),
            perfix: "=> ".to_string(),
        };

        let payload_path =
            std::env::temp_dir().join(format!("uboot-shell-loady-retry-{}", std::process::id()));
        fs::write(&payload_path, b"payload")?;

        let progress = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&progress);
        let reply = shell
            .loady(0x80200000, payload_path.clone(), move |sent, size| {
                recorder.lock().unwrap().push((sent, size))
            })
            .await;
        // Remove the temp file before asserting, so a failed assertion does not
        // leave it behind.
        let _ = fs::remove_file(&payload_path);

        assert!(reply?.contains("done"));
        let state = script.lock().unwrap();
        let writes = String::from_utf8_lossy(&state.writes);
        assert_eq!(writes.matches("loady 0x80200000").count(), 2);
        assert!(state.writes.contains(&CTRL_C));
        assert_eq!(*progress.lock().unwrap(), vec![(0, 7), (0, 7), (7, 7)]);
        Ok(())
    }
}