Skip to main content

walrus_cli/cmd/
download.rs

1//! `walrus download` — download a model from HuggingFace with progress.
2
3use crate::runner::gateway::GatewayRunner;
4use anyhow::Result;
5use clap::Args;
6use compact_str::CompactString;
7use futures_util::StreamExt;
8use protocol::{ClientMessage, ServerMessage};
9use std::io::Write;
10
11/// Download a model's files from HuggingFace.
12#[derive(Args, Debug)]
13pub struct Download {
14    /// HuggingFace model ID (e.g. "microsoft/Phi-3.5-mini-instruct").
15    pub model: String,
16}
17
18impl Download {
19    /// Run the download, streaming progress to the terminal.
20    pub async fn run(self, runner: &mut GatewayRunner) -> Result<()> {
21        let msg = ClientMessage::Download {
22            model: CompactString::from(&self.model),
23        };
24        let stream = runner.download_stream(msg);
25        futures_util::pin_mut!(stream);
26
27        let mut current_size: u64 = 0;
28        let mut downloaded: u64 = 0;
29        let mut current_file = String::new();
30
31        while let Some(result) = stream.next().await {
32            match result? {
33                ServerMessage::DownloadStart { model } => {
34                    println!("Downloading {model}...");
35                }
36                ServerMessage::DownloadFileStart { filename, size } => {
37                    current_file = filename;
38                    current_size = size;
39                    downloaded = 0;
40                }
41                ServerMessage::DownloadProgress { bytes } => {
42                    downloaded += bytes;
43                    let pct = if current_size > 0 {
44                        downloaded * 100 / current_size
45                    } else {
46                        0
47                    };
48                    eprint!(
49                        "\r  {} {}% ({} / {})",
50                        current_file,
51                        pct,
52                        format_bytes(downloaded),
53                        format_bytes(current_size),
54                    );
55                    std::io::stderr().flush().ok();
56                }
57                ServerMessage::DownloadFileEnd { filename } => {
58                    eprintln!("\r  {filename} done{:30}", "");
59                }
60                ServerMessage::DownloadEnd { model } => {
61                    println!("Download complete: {model}");
62                }
63                ServerMessage::Error { code, message } => {
64                    eprintln!("Error ({code}): {message}");
65                    break;
66                }
67                _ => {}
68            }
69        }
70        Ok(())
71    }
72}
73
/// Format a byte count as a human-readable string using 1024-based units.
///
/// Counts below 1024 are printed as whole bytes (e.g. `"512 B"`). Larger counts
/// use one decimal place and the smallest of KB/MB/GB whose rendered value stays
/// below 1024. As a result, a count just under a unit boundary prints as
/// `"1.0 MB"` rather than `"1024.0 KB"`. Counts of 1024 GB and above stay in GB.
fn format_bytes(bytes: u64) -> String {
    const UNIT: f64 = 1024.0;
    const SCALED_UNITS: [&str; 3] = ["KB", "MB", "GB"];
    // With `{:.1}`, values at or above this render as "1024.0"; such values are
    // shown in the next unit up instead.
    const PROMOTE_AT: f64 = 1023.95;

    if bytes < 1024 {
        return format!("{bytes} B");
    }

    let mut value = bytes as f64 / UNIT;
    let mut idx = 0;
    while value >= PROMOTE_AT && idx + 1 < SCALED_UNITS.len() {
        value /= UNIT;
        idx += 1;
    }
    format!("{value:.1} {}", SCALED_UNITS[idx])
}