Skip to main content

voirs_cli/commands/train/
progress.rs

1//! Training progress monitoring utilities
2//!
3//! Provides comprehensive progress tracking and visualization for model training:
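//! - Multi-level progress bars (epoch, batch, metrics, optional resources)
//! - Live metrics display (loss, learning rate, gradient norm, elapsed time)
//! - Optional resource monitoring (CPU and memory; GPU fields are not yet populated)
//! - End-of-training summary (duration, epochs, steps, losses, samples/sec)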
8
9use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12
/// Training progress tracker that stacks several `indicatif` progress bars.
///
/// Displays an epoch bar, a per-epoch batch bar, a metrics spinner and, when
/// requested at construction, a resource-usage spinner.
pub struct TrainingProgress {
    /// Container every bar below is added to, so they are drawn together.
    multi: MultiProgress,
    /// Epoch-level bar; advanced by one each time an epoch is finished.
    epoch_bar: ProgressBar,
    /// Batch-level bar; resized and rewound to 0 at the start of every epoch.
    batch_bar: ProgressBar,
    /// Spinner whose message shows loss, learning rate, gradient norm and elapsed time.
    metrics_bar: ProgressBar,
    /// Resource spinner; `None` unless resource monitoring was requested.
    resource_bar: Option<ProgressBar>,
    /// Moment the tracker was created; the elapsed-time display is measured from here.
    start_time: Instant,
    /// Epoch index most recently passed to `start_epoch`.
    current_epoch: usize,
    /// Total epoch count, used in the "Epoch x/y" label.
    total_epochs: usize,
    /// Whether resource monitoring was requested.
    /// NOTE(review): not read by any method in this file; `resource_bar.is_some()`
    /// carries the same information.
    monitor_resources: bool,
}
34
35impl TrainingProgress {
36    /// Create new training progress tracker
37    pub fn new(total_epochs: usize, batches_per_epoch: usize, monitor_resources: bool) -> Self {
38        let multi = MultiProgress::new();
39
40        // Epoch progress bar
41        let epoch_bar = multi.add(ProgressBar::new(total_epochs as u64));
42        epoch_bar.set_style(
43            ProgressStyle::default_bar()
44                .template("{prefix:.bold.cyan} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
45                .unwrap()
46                .progress_chars("█▓▒░ "),
47        );
48        epoch_bar.set_prefix("Epochs");
49
50        // Batch progress bar
51        let batch_bar = multi.add(ProgressBar::new(batches_per_epoch as u64));
52        batch_bar.set_style(
53            ProgressStyle::default_bar()
54                .template("{prefix:.bold.green} [{bar:40.green/blue}] {pos}/{len} {msg}")
55                .unwrap()
56                .progress_chars("█▓▒░ "),
57        );
58        batch_bar.set_prefix("Batches");
59
60        // Metrics display (spinner with metrics)
61        let metrics_bar = multi.add(ProgressBar::new_spinner());
62        metrics_bar.set_style(
63            ProgressStyle::default_spinner()
64                .template("{prefix:.bold.yellow} {spinner:.yellow} {msg}")
65                .unwrap(),
66        );
67        metrics_bar.set_prefix("Metrics");
68
69        // Resource monitoring bar (optional)
70        let resource_bar = if monitor_resources {
71            let bar = multi.add(ProgressBar::new_spinner());
72            bar.set_style(
73                ProgressStyle::default_spinner()
74                    .template("{prefix:.bold.magenta} {spinner:.magenta} {msg}")
75                    .unwrap(),
76            );
77            bar.set_prefix("Resources");
78            Some(bar)
79        } else {
80            None
81        };
82
83        Self {
84            multi,
85            epoch_bar,
86            batch_bar,
87            metrics_bar,
88            resource_bar,
89            start_time: Instant::now(),
90            current_epoch: 0,
91            total_epochs,
92            monitor_resources,
93        }
94    }
95
96    /// Start new epoch
97    pub fn start_epoch(&mut self, epoch: usize, batches: usize) {
98        self.current_epoch = epoch;
99        self.batch_bar.set_length(batches as u64);
100        self.batch_bar.set_position(0);
101        self.epoch_bar
102            .set_message(format!("Epoch {}/{}", epoch + 1, self.total_epochs));
103    }
104
105    /// Update batch progress
106    pub fn update_batch(&self, batch: usize, batch_loss: f64, samples_per_sec: f64) {
107        self.batch_bar.set_position(batch as u64);
108        self.batch_bar.set_message(format!(
109            "loss: {:.4}, {:.1} samples/s",
110            batch_loss, samples_per_sec
111        ));
112    }
113
114    /// Finish current batch
115    pub fn finish_batch(&self) {
116        self.batch_bar.inc(1);
117    }
118
119    /// Update metrics display
120    pub fn update_metrics(&self, metrics: &TrainingMetrics) {
121        let msg = format!(
122            "Loss: {:.4} | LR: {:.6} | Grad: {:.4} | Time: {}",
123            metrics.loss,
124            metrics.learning_rate,
125            metrics.grad_norm.unwrap_or(0.0),
126            format_duration(self.start_time.elapsed())
127        );
128        self.metrics_bar.set_message(msg);
129        self.metrics_bar.tick();
130    }
131
132    /// Update resource monitoring
133    pub fn update_resources(&self, resources: &ResourceUsage) {
134        if let Some(bar) = &self.resource_bar {
135            let msg = format!(
136                "CPU: {:.1}% | RAM: {:.1}GB | GPU: {}",
137                resources.cpu_percent,
138                resources.ram_gb,
139                resources
140                    .gpu_percent
141                    .map(|g| format!("{:.1}%", g))
142                    .unwrap_or_else(|| "N/A".to_string())
143            );
144            bar.set_message(msg);
145            bar.tick();
146        }
147    }
148
149    /// Finish epoch
150    pub fn finish_epoch(&mut self, epoch_metrics: &EpochMetrics) {
151        self.epoch_bar.inc(1);
152        self.epoch_bar.set_message(format!(
153            "train_loss: {:.4}, val_loss: {:.4}, time: {}s",
154            epoch_metrics.train_loss,
155            epoch_metrics.val_loss.unwrap_or(0.0),
156            epoch_metrics.duration.as_secs()
157        ));
158    }
159
160    /// Finish training
161    pub fn finish(&self, final_message: &str) {
162        self.epoch_bar
163            .finish_with_message(final_message.to_string());
164        self.batch_bar.finish_and_clear();
165        self.metrics_bar.finish_and_clear();
166        if let Some(bar) = &self.resource_bar {
167            bar.finish_and_clear();
168        }
169    }
170
171    /// Print summary statistics
172    pub fn print_summary(&self, stats: &TrainingStats) {
173        println!("\n╔══════════════════════════════════════════════════════════╗");
174        println!("║              Training Summary                            ║");
175        println!("╠══════════════════════════════════════════════════════════╣");
176        println!(
177            "║ Total time:         {:<30} ║",
178            format_duration(stats.total_duration)
179        );
180        println!("║ Epochs completed:   {:<30} ║", stats.epochs_completed);
181        println!("║ Total steps:        {:<30} ║", stats.total_steps);
182        println!("║ Final train loss:   {:<30.4} ║", stats.final_train_loss);
183        if let Some(val_loss) = stats.final_val_loss {
184            println!("║ Final val loss:     {:<30.4} ║", val_loss);
185        }
186        println!(
187            "║ Best val loss:      {:<30.4} ║",
188            stats.best_val_loss.unwrap_or(0.0)
189        );
190        println!(
191            "║ Avg samples/sec:    {:<30.1} ║",
192            stats.avg_samples_per_sec
193        );
194        println!("╚══════════════════════════════════════════════════════════╝");
195    }
196}
197
/// Metrics for the current training step, rendered on the "Metrics" line.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrainingMetrics {
    /// Loss value for the step.
    pub loss: f64,
    /// Learning rate in effect for the step.
    pub learning_rate: f64,
    /// Gradient norm, or `None` when the trainer did not compute one.
    pub grad_norm: Option<f64>,
}
205
/// Metrics for a completed epoch, as passed to `TrainingProgress::finish_epoch`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EpochMetrics {
    /// Epoch index. NOTE(review): not displayed by the progress tracker; whether it
    /// is 0- or 1-based is up to the caller.
    pub epoch: usize,
    /// Training loss reported for the epoch.
    pub train_loss: f64,
    /// Validation loss, or `None` when no validation value is available.
    pub val_loss: Option<f64>,
    /// Wall-clock time the epoch took.
    pub duration: Duration,
}
214
/// Snapshot of host resource usage, rendered on the "Resources" line.
///
/// The GPU fields are placeholders: [`ResourceUsage::current`] always leaves
/// them `None`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ResourceUsage {
    /// CPU utilisation in percent, within `0.0..=100.0` (see `current` for how it is derived).
    pub cpu_percent: f64,
    /// Memory in use, in GiB; `0.0` when it could not be determined.
    pub ram_gb: f64,
    /// GPU utilisation in percent, if known.
    pub gpu_percent: Option<f64>,
    /// GPU memory in use, in GiB, if known.
    pub gpu_memory_gb: Option<f64>,
}

impl ResourceUsage {
    /// Sample current resource usage.
    ///
    /// RAM is read from the OS on Linux (`/proc/meminfo`) and macOS
    /// (`host_statistics64`); other platforms report `0.0`. CPU comes from the
    /// Linux 1-minute load average, falling back to a fixed estimate elsewhere.
    /// GPU fields stay `None` because querying them needs vendor-specific APIs
    /// (CUDA for NVIDIA, Metal for macOS, ...).
    pub fn current() -> Self {
        let ram_gb = Self::get_memory_usage_gb();
        let cpu_percent = Self::get_cpu_usage_percent();

        Self {
            cpu_percent,
            ram_gb,
            gpu_percent: None,
            gpu_memory_gb: None,
        }
    }

    /// Host memory in use on macOS, in GiB, counting active, inactive and wired
    /// pages as used. Returns `0.0` if the kernel query fails.
    #[cfg(target_os = "macos")]
    fn get_memory_usage_gb() -> f64 {
        use std::mem;

        // SAFETY: `vm_statistics64` is a plain C struct of integer counters, so an
        // all-zero value is valid. `count` is its size in `integer_t` units, as
        // `host_statistics64` expects, and both out-pointers refer to live locals
        // for the whole call.
        unsafe {
            let mut info: libc::vm_statistics64 = mem::zeroed();
            let mut count = (mem::size_of::<libc::vm_statistics64>()
                / mem::size_of::<libc::integer_t>())
                as libc::mach_msg_type_number_t;

            let host_port = libc::mach_host_self();
            let result = libc::host_statistics64(
                host_port,
                libc::HOST_VM_INFO64,
                &mut info as *mut _ as *mut _,
                &mut count,
            );

            if result == libc::KERN_SUCCESS {
                let page_size = Self::get_page_size();
                // Widen each counter before summing so the 32-bit page counts
                // cannot overflow.
                let used_pages = info.active_count as u64
                    + info.inactive_count as u64
                    + info.wire_count as u64;
                (used_pages * page_size) as f64 / 1_073_741_824.0 // bytes -> GiB
            } else {
                0.0
            }
        }
    }

    /// Host memory in use on Linux, in GiB, from `/proc/meminfo` (`0.0` on failure).
    #[cfg(target_os = "linux")]
    fn get_memory_usage_gb() -> f64 {
        std::fs::read_to_string("/proc/meminfo")
            .ok()
            .and_then(|content| Self::parse_meminfo_used_gb(&content))
            .unwrap_or(0.0)
    }

    #[cfg(not(any(target_os = "macos", target_os = "linux")))]
    fn get_memory_usage_gb() -> f64 {
        // Fallback for unsupported platforms
        0.0
    }

    /// Compute used memory (GiB) from the text of `/proc/meminfo`.
    ///
    /// Used = `MemTotal - MemAvailable`. Kernels older than 3.14 lack
    /// `MemAvailable`, so `MemFree + Buffers + Cached` is used instead. Returns
    /// `None` if `MemTotal` (or, without `MemAvailable`, `MemFree`) is missing or
    /// unparsable, or if `MemTotal` is zero.
    // Only wired into `get_memory_usage_gb` on Linux; kept platform-neutral so the
    // parsing can be unit-tested anywhere.
    #[cfg_attr(not(target_os = "linux"), allow(dead_code))]
    fn parse_meminfo_used_gb(content: &str) -> Option<f64> {
        // Value in kB of the `name:` line, e.g. "MemTotal:  16384 kB" -> 16384.
        let field_kb = |name: &str| -> Option<u64> {
            content.lines().find_map(|line| {
                let rest = line.strip_prefix(name)?.strip_prefix(':')?;
                rest.split_whitespace().next()?.parse().ok()
            })
        };

        let total_kb = field_kb("MemTotal").filter(|&kb| kb > 0)?;
        let available_kb = match field_kb("MemAvailable") {
            Some(kb) => kb,
            None => {
                field_kb("MemFree")?
                    + field_kb("Buffers").unwrap_or(0)
                    + field_kb("Cached").unwrap_or(0)
            }
        };
        let used_kb = total_kb.saturating_sub(available_kb);
        Some(used_kb as f64 / 1_048_576.0) // KiB -> GiB
    }

    #[cfg(target_os = "macos")]
    fn get_page_size() -> u64 {
        // SAFETY: `sysconf` takes no pointers and has no preconditions.
        unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u64 }
    }

    /// Estimated host CPU utilisation in percent, within `0.0..=100.0`.
    ///
    /// On Linux this is the 1-minute load average from `/proc/loadavg` divided by
    /// the number of CPUs available to this process, capped at 100%. Load average
    /// counts runnable (and I/O-blocked) tasks rather than sampling CPU time, so it
    /// is an approximation. Precise figures would need CPU-time sampling over an
    /// interval. When no load figure is available, a fixed 75% estimate of typical
    /// training load is returned.
    fn get_cpu_usage_percent() -> f64 {
        const FALLBACK_ESTIMATE_PERCENT: f64 = 75.0;
        Self::load_average_percent().unwrap_or(FALLBACK_ESTIMATE_PERCENT)
    }

    /// 1-minute load average as a percentage of available CPUs (Linux only).
    #[cfg(target_os = "linux")]
    fn load_average_percent() -> Option<f64> {
        let content = std::fs::read_to_string("/proc/loadavg").ok()?;
        let cpus = std::thread::available_parallelism().map_or(1, |n| n.get());
        Self::parse_loadavg_percent(&content, cpus)
    }

    #[cfg(not(target_os = "linux"))]
    fn load_average_percent() -> Option<f64> {
        None
    }

    /// Turn the first field of `/proc/loadavg` into a percentage of `cpus`,
    /// capped at 100. Returns `None` for zero CPUs or an unparsable/negative load.
    // Only wired into `load_average_percent` on Linux; kept platform-neutral so the
    // parsing can be unit-tested anywhere.
    #[cfg_attr(not(target_os = "linux"), allow(dead_code))]
    fn parse_loadavg_percent(content: &str, cpus: usize) -> Option<f64> {
        if cpus == 0 {
            return None;
        }
        let load_1m: f64 = content.split_whitespace().next()?.parse().ok()?;
        if !load_1m.is_finite() || load_1m < 0.0 {
            return None;
        }
        Some((load_1m / cpus as f64 * 100.0).min(100.0))
    }
}
329
/// Summary statistics for a finished training run, as printed by
/// `TrainingProgress::print_summary`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrainingStats {
    /// Wall-clock duration of the whole run.
    pub total_duration: Duration,
    /// Number of epochs that completed.
    pub epochs_completed: usize,
    /// Total number of training steps.
    pub total_steps: usize,
    /// Training loss at the end of the run.
    pub final_train_loss: f64,
    /// Validation loss at the end of the run, if any.
    pub final_val_loss: Option<f64>,
    /// Best validation loss seen during the run, if any.
    pub best_val_loss: Option<f64>,
    /// Average throughput over the run, in samples per second.
    pub avg_samples_per_sec: f64,
}
341
/// Render a duration as a compact, human-readable string.
///
/// Under a minute prints `"<s>s"`, under an hour `"<m>m <s>s"`, and anything
/// longer `"<h>h <m>m"` (seconds are dropped at that scale). Sub-second
/// fractions are truncated.
fn format_duration(duration: Duration) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;

    let total = duration.as_secs();
    let (hours, within_hour) = (total / HOUR, total % HOUR);
    let (minutes, seconds) = (within_hour / MINUTE, within_hour % MINUTE);

    match (hours, minutes) {
        (0, 0) => format!("{}s", seconds),
        (0, _) => format!("{}m {}s", minutes, seconds),
        _ => format!("{}h {}m", hours, minutes),
    }
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// One representative value per output format (seconds, minutes, hours).
    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(Duration::from_secs(45)), "45s");
        assert_eq!(format_duration(Duration::from_secs(125)), "2m 5s");
        assert_eq!(format_duration(Duration::from_secs(7325)), "2h 2m");
    }

    /// Exact unit boundaries, where an off-by-one in the thresholds would show,
    /// plus sub-second truncation.
    #[test]
    fn test_format_duration_boundaries() {
        assert_eq!(format_duration(Duration::from_secs(0)), "0s");
        assert_eq!(format_duration(Duration::from_secs(59)), "59s");
        assert_eq!(format_duration(Duration::from_millis(59_999)), "59s");
        assert_eq!(format_duration(Duration::from_secs(60)), "1m 0s");
        assert_eq!(format_duration(Duration::from_secs(3599)), "59m 59s");
        assert_eq!(format_duration(Duration::from_secs(3600)), "1h 0m");
    }
}