voirs_cli/commands/train/
progress.rs1use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12
13pub struct TrainingProgress {
15 multi: MultiProgress,
17 epoch_bar: ProgressBar,
19 batch_bar: ProgressBar,
21 metrics_bar: ProgressBar,
23 resource_bar: Option<ProgressBar>,
25 start_time: Instant,
27 current_epoch: usize,
29 total_epochs: usize,
31 monitor_resources: bool,
33}
34
35impl TrainingProgress {
36 pub fn new(total_epochs: usize, batches_per_epoch: usize, monitor_resources: bool) -> Self {
38 let multi = MultiProgress::new();
39
40 let epoch_bar = multi.add(ProgressBar::new(total_epochs as u64));
42 epoch_bar.set_style(
43 ProgressStyle::default_bar()
44 .template("{prefix:.bold.cyan} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
45 .unwrap()
46 .progress_chars("█▓▒░ "),
47 );
48 epoch_bar.set_prefix("Epochs");
49
50 let batch_bar = multi.add(ProgressBar::new(batches_per_epoch as u64));
52 batch_bar.set_style(
53 ProgressStyle::default_bar()
54 .template("{prefix:.bold.green} [{bar:40.green/blue}] {pos}/{len} {msg}")
55 .unwrap()
56 .progress_chars("█▓▒░ "),
57 );
58 batch_bar.set_prefix("Batches");
59
60 let metrics_bar = multi.add(ProgressBar::new_spinner());
62 metrics_bar.set_style(
63 ProgressStyle::default_spinner()
64 .template("{prefix:.bold.yellow} {spinner:.yellow} {msg}")
65 .unwrap(),
66 );
67 metrics_bar.set_prefix("Metrics");
68
69 let resource_bar = if monitor_resources {
71 let bar = multi.add(ProgressBar::new_spinner());
72 bar.set_style(
73 ProgressStyle::default_spinner()
74 .template("{prefix:.bold.magenta} {spinner:.magenta} {msg}")
75 .unwrap(),
76 );
77 bar.set_prefix("Resources");
78 Some(bar)
79 } else {
80 None
81 };
82
83 Self {
84 multi,
85 epoch_bar,
86 batch_bar,
87 metrics_bar,
88 resource_bar,
89 start_time: Instant::now(),
90 current_epoch: 0,
91 total_epochs,
92 monitor_resources,
93 }
94 }
95
96 pub fn start_epoch(&mut self, epoch: usize, batches: usize) {
98 self.current_epoch = epoch;
99 self.batch_bar.set_length(batches as u64);
100 self.batch_bar.set_position(0);
101 self.epoch_bar
102 .set_message(format!("Epoch {}/{}", epoch + 1, self.total_epochs));
103 }
104
105 pub fn update_batch(&self, batch: usize, batch_loss: f64, samples_per_sec: f64) {
107 self.batch_bar.set_position(batch as u64);
108 self.batch_bar.set_message(format!(
109 "loss: {:.4}, {:.1} samples/s",
110 batch_loss, samples_per_sec
111 ));
112 }
113
114 pub fn finish_batch(&self) {
116 self.batch_bar.inc(1);
117 }
118
119 pub fn update_metrics(&self, metrics: &TrainingMetrics) {
121 let msg = format!(
122 "Loss: {:.4} | LR: {:.6} | Grad: {:.4} | Time: {}",
123 metrics.loss,
124 metrics.learning_rate,
125 metrics.grad_norm.unwrap_or(0.0),
126 format_duration(self.start_time.elapsed())
127 );
128 self.metrics_bar.set_message(msg);
129 self.metrics_bar.tick();
130 }
131
132 pub fn update_resources(&self, resources: &ResourceUsage) {
134 if let Some(bar) = &self.resource_bar {
135 let msg = format!(
136 "CPU: {:.1}% | RAM: {:.1}GB | GPU: {}",
137 resources.cpu_percent,
138 resources.ram_gb,
139 resources
140 .gpu_percent
141 .map(|g| format!("{:.1}%", g))
142 .unwrap_or_else(|| "N/A".to_string())
143 );
144 bar.set_message(msg);
145 bar.tick();
146 }
147 }
148
149 pub fn finish_epoch(&mut self, epoch_metrics: &EpochMetrics) {
151 self.epoch_bar.inc(1);
152 self.epoch_bar.set_message(format!(
153 "train_loss: {:.4}, val_loss: {:.4}, time: {}s",
154 epoch_metrics.train_loss,
155 epoch_metrics.val_loss.unwrap_or(0.0),
156 epoch_metrics.duration.as_secs()
157 ));
158 }
159
160 pub fn finish(&self, final_message: &str) {
162 self.epoch_bar
163 .finish_with_message(final_message.to_string());
164 self.batch_bar.finish_and_clear();
165 self.metrics_bar.finish_and_clear();
166 if let Some(bar) = &self.resource_bar {
167 bar.finish_and_clear();
168 }
169 }
170
171 pub fn print_summary(&self, stats: &TrainingStats) {
173 println!("\n╔══════════════════════════════════════════════════════════╗");
174 println!("║ Training Summary ║");
175 println!("╠══════════════════════════════════════════════════════════╣");
176 println!(
177 "║ Total time: {:<30} ║",
178 format_duration(stats.total_duration)
179 );
180 println!("║ Epochs completed: {:<30} ║", stats.epochs_completed);
181 println!("║ Total steps: {:<30} ║", stats.total_steps);
182 println!("║ Final train loss: {:<30.4} ║", stats.final_train_loss);
183 if let Some(val_loss) = stats.final_val_loss {
184 println!("║ Final val loss: {:<30.4} ║", val_loss);
185 }
186 println!(
187 "║ Best val loss: {:<30.4} ║",
188 stats.best_val_loss.unwrap_or(0.0)
189 );
190 println!(
191 "║ Avg samples/sec: {:<30.1} ║",
192 stats.avg_samples_per_sec
193 );
194 println!("╚══════════════════════════════════════════════════════════╝");
195 }
196}
197
/// Snapshot of optimisation metrics rendered by `TrainingProgress::update_metrics`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingMetrics {
    /// Most recent training loss.
    pub loss: f64,
    /// Learning rate currently in effect.
    pub learning_rate: f64,
    /// Gradient norm, or `None` when it was not supplied.
    pub grad_norm: Option<f64>,
}
205
/// Per-epoch results rendered by `TrainingProgress::finish_epoch`.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochMetrics {
    /// Index of the epoch described (presumably the same index given to
    /// `start_epoch`; confirm at call sites).
    pub epoch: usize,
    /// Training loss for the epoch.
    pub train_loss: f64,
    /// Validation loss, or `None` when no validation was run.
    pub val_loss: Option<f64>,
    /// Time spent on the epoch.
    pub duration: Duration,
}
214
/// One sample of resource usage rendered by `TrainingProgress::update_resources`.
#[derive(Debug, Clone, PartialEq)]
pub struct ResourceUsage {
    /// CPU utilisation estimate in percent.
    pub cpu_percent: f64,
    /// Used system memory in GiB (binary gigabytes).
    pub ram_gb: f64,
    /// GPU utilisation in percent, or `None` when not measured.
    pub gpu_percent: Option<f64>,
    /// GPU memory in use in GiB, or `None` when not measured.
    pub gpu_memory_gb: Option<f64>,
}
223
224impl ResourceUsage {
225 pub fn current() -> Self {
227 let ram_gb = Self::get_memory_usage_gb();
229
230 let cpu_percent = Self::get_cpu_usage_percent();
232
233 Self {
236 cpu_percent,
237 ram_gb,
238 gpu_percent: None,
239 gpu_memory_gb: None,
240 }
241 }
242
243 #[cfg(target_os = "macos")]
245 fn get_memory_usage_gb() -> f64 {
246 use std::mem;
247
248 unsafe {
249 let mut info: libc::vm_statistics64 = mem::zeroed();
250 let mut count = (mem::size_of::<libc::vm_statistics64>()
251 / mem::size_of::<libc::integer_t>())
252 as libc::mach_msg_type_number_t;
253
254 let host_port = libc::mach_host_self();
255 let result = libc::host_statistics64(
256 host_port,
257 libc::HOST_VM_INFO64,
258 &mut info as *mut _ as *mut _,
259 &mut count,
260 );
261
262 if result == libc::KERN_SUCCESS {
263 let page_size = Self::get_page_size();
264 let used_memory =
265 (info.active_count + info.inactive_count + info.wire_count) as u64 * page_size;
266 used_memory as f64 / 1_073_741_824.0 } else {
268 0.0
269 }
270 }
271 }
272
273 #[cfg(target_os = "linux")]
274 fn get_memory_usage_gb() -> f64 {
275 if let Ok(content) = std::fs::read_to_string("/proc/meminfo") {
277 let mut total_kb = 0u64;
278 let mut available_kb = 0u64;
279
280 for line in content.lines() {
281 if line.starts_with("MemTotal:") {
282 total_kb = line
283 .split_whitespace()
284 .nth(1)
285 .and_then(|s| s.parse().ok())
286 .unwrap_or(0);
287 } else if line.starts_with("MemAvailable:") {
288 available_kb = line
289 .split_whitespace()
290 .nth(1)
291 .and_then(|s| s.parse().ok())
292 .unwrap_or(0);
293 }
294 }
295
296 if total_kb > 0 && available_kb > 0 {
297 let used_kb = total_kb - available_kb;
298 return used_kb as f64 / 1_048_576.0; }
300 }
301 0.0
302 }
303
304 #[cfg(not(any(target_os = "macos", target_os = "linux")))]
305 fn get_memory_usage_gb() -> f64 {
306 0.0
308 }
309
310 #[cfg(target_os = "macos")]
311 fn get_page_size() -> u64 {
312 unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u64 }
313 }
314
315 fn get_cpu_usage_percent() -> f64 {
317 let cpu_count = num_cpus::get();
323
324 let estimated_usage = 75.0 * cpu_count as f64 / cpu_count as f64;
326 estimated_usage.min(100.0)
327 }
328}
329
/// Aggregate statistics for a finished run, printed by `TrainingProgress::print_summary`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingStats {
    /// Total duration of the run.
    pub total_duration: Duration,
    /// Number of epochs that completed.
    pub epochs_completed: usize,
    /// Total number of training steps taken.
    pub total_steps: usize,
    /// Training loss at the end of the run.
    pub final_train_loss: f64,
    /// Validation loss at the end of the run, or `None` when no validation ran.
    pub final_val_loss: Option<f64>,
    /// Best (presumably lowest) validation loss seen, or `None` when no validation ran.
    pub best_val_loss: Option<f64>,
    /// Mean throughput in samples per second.
    pub avg_samples_per_sec: f64,
}
341
/// Renders a duration at whole-second precision for progress messages.
///
/// The output is `"Ns"` below a minute and `"Mm Ss"` below an hour. From one
/// hour up it is `"Hh Mm"`, and the seconds are dropped.
fn format_duration(duration: Duration) -> String {
    let total = duration.as_secs();
    let hours = total / 3600;
    let minutes = (total % 3600) / 60;
    let seconds = total % 60;

    match total {
        0..=59 => format!("{}s", seconds),
        60..=3599 => format!("{}m {}s", minutes, seconds),
        _ => format!("{}h {}m", hours, minutes),
    }
}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(Duration::from_secs(45)), "45s");
        assert_eq!(format_duration(Duration::from_secs(125)), "2m 5s");
        assert_eq!(format_duration(Duration::from_secs(7325)), "2h 2m");
    }

    /// Pins the output at each unit switch-over, where off-by-one mistakes hide.
    #[test]
    fn test_format_duration_unit_boundaries() {
        // Sub-second durations truncate to whole seconds.
        assert_eq!(format_duration(Duration::from_millis(999)), "0s");
        assert_eq!(format_duration(Duration::from_secs(0)), "0s");
        assert_eq!(format_duration(Duration::from_secs(59)), "59s");
        assert_eq!(format_duration(Duration::from_secs(60)), "1m 0s");
        assert_eq!(format_duration(Duration::from_secs(3599)), "59m 59s");
        // Hour-scale output drops the seconds component.
        assert_eq!(format_duration(Duration::from_secs(3600)), "1h 0m");
    }
}