// vtcode_core/llm/providers/ollama/pull.rs
use hashbrown::HashMap;
use ratatui::crossterm::{
    cursor::MoveToColumn,
    execute,
    terminal::{Clear, ClearType},
};
use std::io;
use std::io::Write;
use std::time::Instant;
9
/// An event observed while pulling a model through Ollama, handed to an
/// [`OllamaPullProgressReporter`] for display.
#[derive(Clone, Debug)]
pub enum OllamaPullEvent {
    /// A free-form status message (for example `"verifying"`).
    Status(String),
    /// Byte counts for a single layer, identified by `digest`.
    ///
    /// Either count may be absent in a given update; `CliPullProgressReporter`
    /// keeps the most recent value it has seen for each digest.
    ChunkProgress {
        /// Identifier of the layer these counts belong to.
        digest: String,
        /// Total size of the layer in bytes, when reported.
        total: Option<u64>,
        /// Bytes of the layer downloaded so far, when reported.
        completed: Option<u64>,
    },
    /// Signals that the pull finished successfully.
    Success,
    /// Signals a failed pull; carries a message describing the failure.
    Error(String),
}
28
29pub trait OllamaPullProgressReporter {
32 fn on_event(&mut self, event: &OllamaPullEvent) -> io::Result<()>;
33}
34
/// Renders Ollama pull progress to stderr as a single line that is rewritten in place.
///
/// Per-layer byte counts are accumulated in `totals_by_digest` so the progress line can
/// show the aggregate over every digest seen so far.
#[derive(Debug)]
pub struct CliPullProgressReporter {
    /// Whether the one-time "Downloading model: total … GB" header has been written.
    printed_header: bool,
    /// Length of the most recently written in-place line; the next line is padded with
    /// spaces up to this length so leftover characters are overwritten.
    last_line_len: usize,
    /// Aggregate completed bytes at the previous progress update (speed estimate).
    last_completed_sum: u64,
    /// Time of the previous progress update (speed estimate).
    last_instant: Instant,
    /// Latest `(total, completed)` byte counts, keyed by layer digest.
    totals_by_digest: HashMap<String, (u64, u64)>,
}
43
44impl Default for CliPullProgressReporter {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl CliPullProgressReporter {
51 pub fn new() -> Self {
52 Self {
53 printed_header: false,
54 last_line_len: 0,
55 last_completed_sum: 0,
56 last_instant: std::time::Instant::now(),
57 totals_by_digest: HashMap::new(),
58 }
59 }
60}
61
62impl OllamaPullProgressReporter for CliPullProgressReporter {
63 fn on_event(&mut self, event: &OllamaPullEvent) -> io::Result<()> {
64 let mut out = io::stderr();
65 match event {
66 OllamaPullEvent::Status(status) => {
67 if status.eq_ignore_ascii_case("pulling manifest") {
69 return Ok(());
70 }
71 let pad = self.last_line_len.saturating_sub(status.len());
72 let line = format!("\r{status}{}", " ".repeat(pad));
73 self.last_line_len = status.len();
74 out.write_all(line.as_bytes())?;
75 out.flush()
76 }
77 OllamaPullEvent::ChunkProgress {
78 digest,
79 total,
80 completed,
81 } => {
82 if let Some(t) = total {
83 self.totals_by_digest
84 .entry(digest.clone())
85 .or_insert((0, 0))
86 .0 = *t;
87 }
88 if let Some(c) = completed {
89 self.totals_by_digest
90 .entry(digest.clone())
91 .or_insert((0, 0))
92 .1 = *c;
93 }
94 let (sum_total, sum_completed) = self
95 .totals_by_digest
96 .values()
97 .fold((0u64, 0u64), |acc, (t, c)| (acc.0 + t, acc.1 + c));
98
99 if sum_total > 0 {
100 if !self.printed_header {
101 let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
102 let header = format!("Downloading model: total {gb:.2} GB\n");
103 execute!(out, MoveToColumn(0), Clear(ClearType::CurrentLine))?;
104 out.write_all(header.as_bytes())?;
105 self.printed_header = true;
106 }
107 let now = std::time::Instant::now();
108 let dt = now
109 .duration_since(self.last_instant)
110 .as_secs_f64()
111 .max(0.001);
112 let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64;
113 let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
114 self.last_completed_sum = sum_completed;
115 self.last_instant = now;
116 let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
117 let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
118 let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
119 let text =
120 format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s");
121 let pad = self.last_line_len.saturating_sub(text.len());
122 let line = format!("\r{text}{}", " ".repeat(pad));
123 self.last_line_len = text.len();
124 out.write_all(line.as_bytes())?;
125 out.flush()
126 } else {
127 Ok(())
128 }
129 }
130 OllamaPullEvent::Error(_) => {
131 Ok(())
134 }
135 OllamaPullEvent::Success => {
136 out.write_all(b"\n")?;
137 out.flush()
138 }
139 }
140 }
141}
142
143#[derive(Default)]
146pub struct TuiPullProgressReporter(CliPullProgressReporter);
147
148impl OllamaPullProgressReporter for TuiPullProgressReporter {
149 fn on_event(&mut self, event: &OllamaPullEvent) -> io::Result<()> {
150 self.0.on_event(event)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ChunkProgress` event for `digest` with the given byte counts.
    fn chunk(digest: &str, total: Option<u64>, completed: Option<u64>) -> OllamaPullEvent {
        OllamaPullEvent::ChunkProgress {
            digest: digest.to_string(),
            total,
            completed,
        }
    }

    #[test]
    fn cli_reporter_formats_status_messages() {
        let mut reporter = CliPullProgressReporter::new();
        reporter
            .on_event(&OllamaPullEvent::Status("verifying".to_string()))
            .unwrap();
        // The in-place line now holds the status text.
        assert_eq!(reporter.last_line_len, "verifying".len());
    }

    #[test]
    fn cli_reporter_skips_pulling_manifest_status() {
        let mut reporter = CliPullProgressReporter::new();
        reporter
            .on_event(&OllamaPullEvent::Status("Pulling Manifest".to_string()))
            .unwrap();
        // Matched case-insensitively and suppressed, so no line was written.
        assert_eq!(reporter.last_line_len, 0);
    }

    #[test]
    fn cli_reporter_tracks_download_progress() {
        let mut reporter = CliPullProgressReporter::new();
        reporter
            .on_event(&chunk(
                "sha256:abc",
                Some(1_000_000_000),
                Some(500_000_000),
            ))
            .unwrap();
        assert!(reporter.printed_header);
        assert_eq!(reporter.last_completed_sum, 500_000_000);
        assert_eq!(
            reporter.totals_by_digest.get("sha256:abc"),
            Some(&(1_000_000_000, 500_000_000))
        );
    }

    #[test]
    fn cli_reporter_merges_partial_updates_per_digest() {
        let mut reporter = CliPullProgressReporter::new();
        reporter.on_event(&chunk("a", Some(100), None)).unwrap();
        reporter.on_event(&chunk("a", None, Some(40))).unwrap();
        reporter.on_event(&chunk("b", Some(50), Some(10))).unwrap();
        assert_eq!(reporter.totals_by_digest.get("a"), Some(&(100, 40)));
        assert_eq!(reporter.totals_by_digest.get("b"), Some(&(50, 10)));
        // Completed bytes are summed across every digest.
        assert_eq!(reporter.last_completed_sum, 50);
    }

    #[test]
    fn cli_reporter_ignores_progress_without_sizes() {
        let mut reporter = CliPullProgressReporter::new();
        reporter.on_event(&chunk("a", None, None)).unwrap();
        assert!(reporter.totals_by_digest.is_empty());
        assert!(!reporter.printed_header);
    }

    #[test]
    fn cli_reporter_accepts_error_and_success_events() {
        let mut reporter = CliPullProgressReporter::new();
        reporter
            .on_event(&OllamaPullEvent::Error("boom".to_string()))
            .unwrap();
        reporter.on_event(&OllamaPullEvent::Success).unwrap();
        // Neither event touches the progress state.
        assert!(reporter.totals_by_digest.is_empty());
        assert!(!reporter.printed_header);
    }

    #[test]
    fn tui_reporter_delegates_to_cli_reporter() {
        let mut reporter = TuiPullProgressReporter::default();
        reporter
            .on_event(&OllamaPullEvent::Status("verifying".to_string()))
            .unwrap();
        assert_eq!(reporter.0.last_line_len, "verifying".len());
    }
}