Skip to main content

vtcode_core/llm/providers/ollama/
pull.rs

1use hashbrown::HashMap;
2use ratatui::crossterm::{
3    cursor::MoveToColumn,
4    execute,
5    terminal::{Clear, ClearType},
6};
7use std::io;
8use std::io::Write;
9
// Ollama model pull functionality with progress reporting.
// Adapted from OpenAI Codex's codex-ollama/src/pull.rs

/// Events emitted while pulling a model from Ollama.
///
/// The type is comparable (`PartialEq`/`Eq`) so that callers and tests can
/// assert on the exact events they receive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OllamaPullEvent {
    /// A human-readable status message (e.g., "verifying", "writing").
    Status(String),
    /// Byte-level progress update for a specific layer digest.
    ChunkProgress {
        /// Digest identifying the layer this update refers to.
        digest: String,
        /// Total size of the layer in bytes, if present in this update.
        total: Option<u64>,
        /// Number of bytes completed so far, if present in this update.
        completed: Option<u64>,
    },
    /// The pull finished successfully.
    Success,
    /// Error event with a message.
    Error(String),
}
28
29/// A progress reporter for pull operations. Implementations decide how to render progress
30/// (CLI, TUI, logs, etc.).
31pub trait OllamaPullProgressReporter {
32    fn on_event(&mut self, event: &OllamaPullEvent) -> io::Result<()>;
33}
34
35/// A minimal CLI reporter that writes inline progress to stderr.
36pub struct CliPullProgressReporter {
37    printed_header: bool,
38    last_line_len: usize,
39    last_completed_sum: u64,
40    last_instant: std::time::Instant,
41    totals_by_digest: HashMap<String, (u64, u64)>,
42}
43
44impl Default for CliPullProgressReporter {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl CliPullProgressReporter {
51    pub fn new() -> Self {
52        Self {
53            printed_header: false,
54            last_line_len: 0,
55            last_completed_sum: 0,
56            last_instant: std::time::Instant::now(),
57            totals_by_digest: HashMap::new(),
58        }
59    }
60}
61
62impl OllamaPullProgressReporter for CliPullProgressReporter {
63    fn on_event(&mut self, event: &OllamaPullEvent) -> io::Result<()> {
64        let mut out = io::stderr();
65        match event {
66            OllamaPullEvent::Status(status) => {
67                // Avoid noisy manifest messages; otherwise show status inline.
68                if status.eq_ignore_ascii_case("pulling manifest") {
69                    return Ok(());
70                }
71                let pad = self.last_line_len.saturating_sub(status.len());
72                let line = format!("\r{status}{}", " ".repeat(pad));
73                self.last_line_len = status.len();
74                out.write_all(line.as_bytes())?;
75                out.flush()
76            }
77            OllamaPullEvent::ChunkProgress {
78                digest,
79                total,
80                completed,
81            } => {
82                if let Some(t) = total {
83                    self.totals_by_digest
84                        .entry(digest.clone())
85                        .or_insert((0, 0))
86                        .0 = *t;
87                }
88                if let Some(c) = completed {
89                    self.totals_by_digest
90                        .entry(digest.clone())
91                        .or_insert((0, 0))
92                        .1 = *c;
93                }
94                let (sum_total, sum_completed) = self
95                    .totals_by_digest
96                    .values()
97                    .fold((0u64, 0u64), |acc, (t, c)| (acc.0 + t, acc.1 + c));
98
99                if sum_total > 0 {
100                    if !self.printed_header {
101                        let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
102                        let header = format!("Downloading model: total {gb:.2} GB\n");
103                        execute!(out, MoveToColumn(0), Clear(ClearType::CurrentLine))?;
104                        out.write_all(header.as_bytes())?;
105                        self.printed_header = true;
106                    }
107                    let now = std::time::Instant::now();
108                    let dt = now
109                        .duration_since(self.last_instant)
110                        .as_secs_f64()
111                        .max(0.001);
112                    let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64;
113                    let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
114                    self.last_completed_sum = sum_completed;
115                    self.last_instant = now;
116                    let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
117                    let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
118                    let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
119                    let text =
120                        format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s");
121                    let pad = self.last_line_len.saturating_sub(text.len());
122                    let line = format!("\r{text}{}", " ".repeat(pad));
123                    self.last_line_len = text.len();
124                    out.write_all(line.as_bytes())?;
125                    out.flush()
126                } else {
127                    Ok(())
128                }
129            }
130            OllamaPullEvent::Error(_) => {
131                // This will be handled by the caller, so we don't do anything
132                // here or the error will be printed twice.
133                Ok(())
134            }
135            OllamaPullEvent::Success => {
136                out.write_all(b"\n")?;
137                out.flush()
138            }
139        }
140    }
141}
142
143/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and
144/// CLI behavior aligned until a dedicated TUI integration is implemented.
145#[derive(Default)]
146pub struct TuiPullProgressReporter(CliPullProgressReporter);
147
148impl OllamaPullProgressReporter for TuiPullProgressReporter {
149    fn on_event(&mut self, event: &OllamaPullEvent) -> io::Result<()> {
150        self.0.on_event(event)
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// A regular status message is written and its length is remembered for
    /// padding the next inline line.
    #[test]
    fn cli_reporter_formats_status_messages() {
        let mut reporter = CliPullProgressReporter::new();
        let event = OllamaPullEvent::Status("verifying".to_string());
        let result = reporter.on_event(&event);
        result.unwrap();
        assert_eq!(reporter.last_line_len, "verifying".len());
    }

    /// The "pulling manifest" status is suppressed regardless of ASCII case,
    /// so no line length is recorded.
    #[test]
    fn cli_reporter_skips_manifest_status() {
        let mut reporter = CliPullProgressReporter::new();
        let event = OllamaPullEvent::Status("Pulling Manifest".to_string());
        let result = reporter.on_event(&event);
        result.unwrap();
        assert_eq!(reporter.last_line_len, 0);
    }

    /// A progress event records the per-digest byte counts, prints the header
    /// once and updates the baseline used for the speed calculation.
    #[test]
    fn cli_reporter_tracks_download_progress() {
        let mut reporter = CliPullProgressReporter::new();
        let event = OllamaPullEvent::ChunkProgress {
            digest: "sha256:abc".to_string(),
            total: Some(1_000_000_000), // 1 GB
            completed: Some(500_000_000),
        };
        let result = reporter.on_event(&event);
        result.unwrap();
        assert!(reporter.printed_header);
        assert_eq!(
            reporter.totals_by_digest.get("sha256:abc"),
            Some(&(1_000_000_000, 500_000_000))
        );
        assert_eq!(reporter.last_completed_sum, 500_000_000);
    }
}