// voirs_cli/commands/checkpoint.rs

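//! Checkpoint management commands
//!
//! Provides utilities for inspecting, listing, comparing, converting, and pruning
//! model checkpoints stored as `.safetensors` files, together with their optional
//! companion `.json` metadata files.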
4
5use crate::GlobalOptions;
6use clap::Subcommand;
7use safetensors::SafeTensors;
8use std::path::{Path, PathBuf};
9use voirs_sdk::Result;
10
/// Checkpoint management subcommands
// NOTE: clap's derive turns the `///` comments on the variants and fields below
// into the `--help` text, so they are user-facing strings. Maintainer notes in
// this enum are therefore written as plain `//` comments.
#[derive(Debug, Clone, Subcommand)]
pub enum CheckpointCommands {
    /// Inspect checkpoint file and show metadata
    Inspect {
        /// Path to checkpoint file (.safetensors)
        #[arg(value_name = "FILE")]
        checkpoint: PathBuf,

        /// Show detailed tensor information
        #[arg(long)]
        verbose: bool,

        // Free-form string, not validated at parse time: `inspect_checkpoint`
        // renders JSON only for the exact value "json" and text for anything else.
        /// Output format (text, json)
        #[arg(long, default_value = "text")]
        format: String,
    },

    /// List all checkpoints in directory
    List {
        /// Directory to search for checkpoints
        #[arg(value_name = "DIR", default_value = "checkpoints")]
        directory: PathBuf,

        // Free-form string: `list_checkpoints` leaves the list in directory-read
        // order (no error) when the value is not one of the five listed keys.
        /// Sort by (name, epoch, loss, size, date)
        #[arg(long, default_value = "epoch")]
        sort_by: String,

        // Applied after sorting, so "best" means "first N in the chosen order".
        /// Show only best N checkpoints
        #[arg(long)]
        top: Option<usize>,
    },

    /// Compare two checkpoints
    Compare {
        /// First checkpoint file
        #[arg(value_name = "FILE1")]
        checkpoint1: PathBuf,

        /// Second checkpoint file
        #[arg(value_name = "FILE2")]
        checkpoint2: PathBuf,

        /// Show parameter differences
        #[arg(long)]
        diff_params: bool,
    },

    /// Convert checkpoint format
    Convert {
        /// Input checkpoint file
        #[arg(value_name = "INPUT")]
        input: PathBuf,

        /// Output checkpoint file
        #[arg(value_name = "OUTPUT")]
        output: PathBuf,

        // "auto" selects the format from the input file extension
        // (.safetensors, .pt/.pth, .onnx) inside `convert_checkpoint`.
        /// Input format (auto, safetensors, pytorch, onnx)
        #[arg(long, default_value = "auto")]
        input_format: String,

        // Neither format string is validated by clap; unsupported input/output
        // pairs are rejected at run time by `convert_checkpoint`.
        /// Output format (safetensors, pytorch, onnx)
        #[arg(long, default_value = "safetensors")]
        output_format: String,
    },

    /// Prune checkpoints (keep only best/latest)
    Prune {
        /// Directory containing checkpoints
        #[arg(value_name = "DIR")]
        directory: PathBuf,

        // At least one of `keep_best` / `keep_latest` must be supplied; this is
        // enforced by `prune_checkpoints`, not by the argument parser.
        /// Keep best N checkpoints by validation loss
        #[arg(long)]
        keep_best: Option<usize>,

        /// Keep latest N checkpoints
        #[arg(long)]
        keep_latest: Option<usize>,

        /// Dry run (don't actually delete)
        #[arg(long)]
        dry_run: bool,
    },
}
97
98/// Execute checkpoint management command
99pub async fn execute_checkpoint_command(
100    command: CheckpointCommands,
101    global: &GlobalOptions,
102) -> Result<()> {
103    match command {
104        CheckpointCommands::Inspect {
105            checkpoint,
106            verbose,
107            format,
108        } => inspect_checkpoint(&checkpoint, verbose, &format, global).await,
109        CheckpointCommands::List {
110            directory,
111            sort_by,
112            top,
113        } => list_checkpoints(&directory, &sort_by, top, global).await,
114        CheckpointCommands::Compare {
115            checkpoint1,
116            checkpoint2,
117            diff_params,
118        } => compare_checkpoints(&checkpoint1, &checkpoint2, diff_params, global).await,
119        CheckpointCommands::Convert {
120            input,
121            output,
122            input_format,
123            output_format,
124        } => convert_checkpoint(&input, &output, &input_format, &output_format, global).await,
125        CheckpointCommands::Prune {
126            directory,
127            keep_best,
128            keep_latest,
129            dry_run,
130        } => prune_checkpoints(&directory, keep_best, keep_latest, dry_run, global).await,
131    }
132}
133
134/// Inspect a checkpoint file
135async fn inspect_checkpoint(
136    checkpoint_path: &PathBuf,
137    verbose: bool,
138    format: &str,
139    global: &GlobalOptions,
140) -> Result<()> {
141    if !checkpoint_path.exists() {
142        return Err(voirs_sdk::VoirsError::config_error(format!(
143            "Checkpoint file not found: {}",
144            checkpoint_path.display()
145        )));
146    }
147
148    // Read checkpoint file
149    let data = tokio::fs::read(checkpoint_path).await?;
150    let tensors = SafeTensors::deserialize(&data).map_err(|e| {
151        voirs_sdk::VoirsError::config_error(format!("Failed to parse checkpoint: {}", e))
152    })?;
153
154    // Try to load metadata from companion .json file
155    let json_path = checkpoint_path.with_extension("json");
156    let metadata = if json_path.exists() {
157        tokio::fs::read_to_string(&json_path)
158            .await
159            .ok()
160            .and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
161    } else {
162        None
163    };
164
165    if format == "json" {
166        output_json_format(&tensors, checkpoint_path, verbose, metadata.as_ref())?;
167    } else {
168        output_text_format(
169            &tensors,
170            checkpoint_path,
171            verbose,
172            global,
173            metadata.as_ref(),
174        )?;
175    }
176
177    Ok(())
178}
179
180/// Output checkpoint info in text format
181fn output_text_format(
182    tensors: &SafeTensors,
183    checkpoint_path: &Path,
184    verbose: bool,
185    global: &GlobalOptions,
186    metadata: Option<&serde_json::Value>,
187) -> Result<()> {
188    if !global.quiet {
189        println!("\n╔══════════════════════════════════════════════════════════╗");
190        println!("║              Checkpoint Inspection                       ║");
191        println!("╠══════════════════════════════════════════════════════════╣");
192        println!(
193            "║ File: {:<50} ║",
194            truncate_str(&checkpoint_path.display().to_string(), 50)
195        );
196
197        // Display metadata if available
198        if let Some(meta_val) = metadata {
199            if let Some(obj) = meta_val.as_object() {
200                for (key, value) in obj {
201                    if key != "tensors" {
202                        // Skip tensors array in metadata
203                        let value_str = match value {
204                            serde_json::Value::String(s) => s.clone(),
205                            serde_json::Value::Number(n) => n.to_string(),
206                            _ => value.to_string(),
207                        };
208                        println!(
209                            "║ {}: {:<47} ║",
210                            key,
211                            truncate_str(&value_str, 47 - key.len())
212                        );
213                    }
214                }
215            }
216        }
217
218        println!("╠══════════════════════════════════════════════════════════╣");
219        println!("║ Tensors: {:<47} ║", tensors.names().len());
220
221        // Calculate total parameters
222        let mut total_params: usize = 0;
223        let mut total_size: usize = 0;
224
225        for name in tensors.names() {
226            if let Ok(tensor) = tensors.tensor(name) {
227                let shape = tensor.shape();
228                let params: usize = shape.iter().product();
229                total_params += params;
230                total_size += tensor.data().len();
231            }
232        }
233
234        println!("║ Total parameters: {:<38} ║", format_number(total_params));
235        println!("║ Total size: {:<44} ║", format_bytes(total_size));
236        println!("╚══════════════════════════════════════════════════════════╝\n");
237
238        if verbose {
239            println!("\n📊 Tensor Details:\n");
240            println!("{:<50} {:>15} {:>15}", "Name", "Shape", "Parameters");
241            println!("{}", "─".repeat(82));
242
243            for name in tensors.names() {
244                if let Ok(tensor) = tensors.tensor(name) {
245                    let shape = tensor.shape();
246                    let params: usize = shape.iter().product();
247                    let shape_str = format!("{:?}", shape);
248
249                    println!(
250                        "{:<50} {:>15} {:>15}",
251                        truncate_str(name, 50),
252                        truncate_str(&shape_str, 15),
253                        format_number(params)
254                    );
255                }
256            }
257            println!();
258        }
259    }
260
261    Ok(())
262}
263
264/// Output checkpoint info in JSON format
265fn output_json_format(
266    tensors: &SafeTensors,
267    checkpoint_path: &Path,
268    verbose: bool,
269    metadata: Option<&serde_json::Value>,
270) -> Result<()> {
271    use serde_json::json;
272
273    let mut tensor_info = Vec::new();
274    let mut total_params: usize = 0;
275
276    for name in tensors.names() {
277        if let Ok(tensor) = tensors.tensor(name) {
278            let shape: Vec<usize> = tensor.shape().to_vec();
279            let params: usize = shape.iter().product();
280            total_params += params;
281
282            if verbose {
283                tensor_info.push(json!({
284                    "name": name,
285                    "shape": shape,
286                    "parameters": params,
287                    "dtype": "F32",
288                }));
289            }
290        }
291    }
292
293    let output = json!({
294        "file": checkpoint_path.display().to_string(),
295        "num_tensors": tensors.names().len(),
296        "total_parameters": total_params,
297        "metadata": metadata,
298        "tensors": if verbose { Some(tensor_info) } else { None },
299    });
300
301    println!("{}", serde_json::to_string_pretty(&output)?);
302
303    Ok(())
304}
305
306/// List checkpoints in a directory
307async fn list_checkpoints(
308    directory: &PathBuf,
309    sort_by: &str,
310    top: Option<usize>,
311    global: &GlobalOptions,
312) -> Result<()> {
313    if !directory.exists() {
314        return Err(voirs_sdk::VoirsError::config_error(format!(
315            "Directory not found: {}",
316            directory.display()
317        )));
318    }
319
320    let mut checkpoints = Vec::new();
321
322    // Read all .safetensors files
323    let mut entries = tokio::fs::read_dir(directory).await?;
324    while let Some(entry) = entries.next_entry().await? {
325        let path = entry.path();
326        if path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
327            if let Ok(metadata) = entry.metadata().await {
328                // Try to read checkpoint metadata from companion .json file
329                let json_path = path.with_extension("json");
330                let mut epoch = 0;
331                let mut train_loss = 0.0;
332                let mut val_loss = 0.0;
333
334                if json_path.exists() {
335                    if let Ok(meta_str) = tokio::fs::read_to_string(&json_path).await {
336                        if let Ok(meta_json) = serde_json::from_str::<serde_json::Value>(&meta_str)
337                        {
338                            if let Some(obj) = meta_json.as_object() {
339                                epoch = obj
340                                    .get("epoch")
341                                    .and_then(|v| {
342                                        v.as_u64()
343                                            .map(|n| n as usize)
344                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
345                                    })
346                                    .unwrap_or(0);
347                                train_loss = obj
348                                    .get("train_loss")
349                                    .and_then(|v| {
350                                        v.as_f64()
351                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
352                                    })
353                                    .unwrap_or(0.0);
354                                val_loss = obj
355                                    .get("val_loss")
356                                    .and_then(|v| {
357                                        v.as_f64()
358                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
359                                    })
360                                    .unwrap_or(0.0);
361                            }
362                        }
363                    }
364                }
365
366                if let Ok(data) = tokio::fs::read(&path).await {
367                    if SafeTensors::deserialize(&data).is_ok() {
368                        checkpoints.push(CheckpointInfo {
369                            path: path.clone(),
370                            name: path.file_name().unwrap().to_string_lossy().to_string(),
371                            epoch,
372                            train_loss,
373                            val_loss,
374                            size: metadata.len(),
375                            modified: metadata
376                                .modified()
377                                .ok()
378                                .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
379                                .map(|d| d.as_secs())
380                                .unwrap_or(0),
381                        });
382                    }
383                }
384            }
385        }
386    }
387
388    // Sort checkpoints
389    match sort_by {
390        "name" => checkpoints.sort_by(|a, b| a.name.cmp(&b.name)),
391        "epoch" => checkpoints.sort_by(|a, b| b.epoch.cmp(&a.epoch)),
392        "loss" => checkpoints.sort_by(|a, b| a.val_loss.partial_cmp(&b.val_loss).unwrap()),
393        "size" => checkpoints.sort_by(|a, b| b.size.cmp(&a.size)),
394        "date" => checkpoints.sort_by(|a, b| b.modified.cmp(&a.modified)),
395        _ => {}
396    }
397
398    // Limit to top N
399    if let Some(n) = top {
400        checkpoints.truncate(n);
401    }
402
403    if !global.quiet {
404        println!("\n📁 Checkpoints in {}:\n", directory.display());
405        println!(
406            "{:<35} {:>8} {:>12} {:>12} {:>10}",
407            "Name", "Epoch", "Train Loss", "Val Loss", "Size"
408        );
409        println!("{}", "─".repeat(82));
410
411        for ckpt in &checkpoints {
412            println!(
413                "{:<35} {:>8} {:>12.6} {:>12.6} {:>10}",
414                truncate_str(&ckpt.name, 35),
415                ckpt.epoch,
416                ckpt.train_loss,
417                ckpt.val_loss,
418                format_bytes(ckpt.size as usize)
419            );
420        }
421
422        println!("\nTotal: {} checkpoints\n", checkpoints.len());
423    }
424
425    Ok(())
426}
427
/// Summary of one checkpoint file, as collected by `list_checkpoints` and
/// `prune_checkpoints` while scanning a directory.
#[derive(Debug, Clone)]
struct CheckpointInfo {
    /// Path of the `.safetensors` file as returned by the directory entry.
    path: PathBuf,
    /// File name component of `path`, converted lossily to UTF-8.
    name: String,
    /// `epoch` from the companion `.json` file; `0` when unavailable.
    epoch: usize,
    /// `train_loss` from the companion `.json` file; `0.0` when unavailable.
    train_loss: f64,
    /// `val_loss` from the companion `.json` file. The fallback for a missing
    /// value is chosen by the collector: `0.0` when listing, `f64::MAX` when
    /// pruning.
    val_loss: f64,
    /// File size in bytes, from the file metadata.
    size: u64,
    /// Modification time in whole seconds since the Unix epoch; `0` when the
    /// timestamp is unavailable.
    modified: u64,
}
438
439/// Compare two checkpoints
440async fn compare_checkpoints(
441    checkpoint1: &PathBuf,
442    checkpoint2: &PathBuf,
443    diff_params: bool,
444    global: &GlobalOptions,
445) -> Result<()> {
446    let data1 = tokio::fs::read(checkpoint1).await?;
447    let data2 = tokio::fs::read(checkpoint2).await?;
448
449    let tensors1 = SafeTensors::deserialize(&data1).map_err(|e| {
450        voirs_sdk::VoirsError::config_error(format!("Failed to parse checkpoint 1: {}", e))
451    })?;
452
453    let tensors2 = SafeTensors::deserialize(&data2).map_err(|e| {
454        voirs_sdk::VoirsError::config_error(format!("Failed to parse checkpoint 2: {}", e))
455    })?;
456
457    // Try to load metadata from companion .json files
458    let json_path1 = checkpoint1.with_extension("json");
459    let json_path2 = checkpoint2.with_extension("json");
460
461    let meta1 = if json_path1.exists() {
462        tokio::fs::read_to_string(&json_path1)
463            .await
464            .ok()
465            .and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
466    } else {
467        None
468    };
469
470    let meta2 = if json_path2.exists() {
471        tokio::fs::read_to_string(&json_path2)
472            .await
473            .ok()
474            .and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
475    } else {
476        None
477    };
478
479    if !global.quiet {
480        println!("\n╔══════════════════════════════════════════════════════════╗");
481        println!("║              Checkpoint Comparison                       ║");
482        println!("╠══════════════════════════════════════════════════════════╣");
483
484        // Compare metadata
485        if let (Some(m1), Some(m2)) = (meta1.as_ref(), meta2.as_ref()) {
486            if let (Some(o1), Some(o2)) = (m1.as_object(), m2.as_object()) {
487                println!(
488                    "║ {:<25} {:<12} {:<15} ║",
489                    "Metric", "Checkpoint 1", "Checkpoint 2"
490                );
491                println!("╠══════════════════════════════════════════════════════════╣");
492
493                for key in o1.keys() {
494                    if key != "tensors" {
495                        // Skip tensors array
496                        if let (Some(v1), Some(v2)) = (o1.get(key), o2.get(key)) {
497                            let s1 = match v1 {
498                                serde_json::Value::String(s) => s.clone(),
499                                serde_json::Value::Number(n) => n.to_string(),
500                                _ => v1.to_string(),
501                            };
502                            let s2 = match v2 {
503                                serde_json::Value::String(s) => s.clone(),
504                                serde_json::Value::Number(n) => n.to_string(),
505                                _ => v2.to_string(),
506                            };
507
508                            println!(
509                                "║ {:<25} {:<12} {:<15} ║",
510                                truncate_str(key, 25),
511                                truncate_str(&s1, 12),
512                                truncate_str(&s2, 15)
513                            );
514                        }
515                    }
516                }
517            }
518        }
519
520        println!("╠══════════════════════════════════════════════════════════╣");
521        println!(
522            "║ Tensors in checkpoint 1: {:<31} ║",
523            tensors1.names().len()
524        );
525        println!(
526            "║ Tensors in checkpoint 2: {:<31} ║",
527            tensors2.names().len()
528        );
529        println!("╚══════════════════════════════════════════════════════════╝\n");
530
531        if diff_params {
532            // Show parameter differences
533            let names1: std::collections::HashSet<String> =
534                tensors1.names().iter().map(|s| s.to_string()).collect();
535            let names2: std::collections::HashSet<String> =
536                tensors2.names().iter().map(|s| s.to_string()).collect();
537
538            let only_in_1: Vec<_> = names1.difference(&names2).collect();
539            let only_in_2: Vec<_> = names2.difference(&names1).collect();
540
541            if !only_in_1.is_empty() {
542                println!("⚠️  Tensors only in checkpoint 1:");
543                for name in only_in_1 {
544                    println!("   - {}", name);
545                }
546                println!();
547            }
548
549            if !only_in_2.is_empty() {
550                println!("⚠️  Tensors only in checkpoint 2:");
551                for name in only_in_2 {
552                    println!("   - {}", name);
553                }
554                println!();
555            }
556        }
557    }
558
559    Ok(())
560}
561
562/// Convert checkpoint format
563async fn convert_checkpoint(
564    input: &PathBuf,
565    output: &PathBuf,
566    input_format: &str,
567    output_format: &str,
568    global: &GlobalOptions,
569) -> Result<()> {
570    if !input.exists() {
571        return Err(voirs_sdk::VoirsError::config_error(format!(
572            "Input checkpoint not found: {}",
573            input.display()
574        )));
575    }
576
577    // Auto-detect input format if needed
578    let detected_input_format = if input_format == "auto" {
579        match input.extension().and_then(|s| s.to_str()) {
580            Some("safetensors") => "safetensors",
581            Some("pt") | Some("pth") => "pytorch",
582            Some("onnx") => "onnx",
583            _ => {
584                return Err(voirs_sdk::VoirsError::config_error(
585                    "Could not auto-detect input format. Please specify --input-format",
586                ));
587            }
588        }
589    } else {
590        input_format
591    };
592
593    if !global.quiet {
594        println!("\n🔄 Checkpoint Conversion:");
595        println!("   Input:  {} ({})", input.display(), detected_input_format);
596        println!("   Output: {} ({})", output.display(), output_format);
597        println!();
598    }
599
600    // Handle conversion based on format pair
601    match (detected_input_format, output_format) {
602        ("safetensors", "safetensors") => {
603            convert_safetensors_to_safetensors(input, output, global).await
604        }
605        ("safetensors", "pytorch") => {
606            Err(voirs_sdk::VoirsError::config_error(
607                "SafeTensors to PyTorch conversion not yet implemented. Consider using Python: \
608                import safetensors.torch; safetensors.torch.save_file(tensors, 'output.pt')",
609            ))
610        }
611        ("safetensors", "onnx") => {
612            Err(voirs_sdk::VoirsError::config_error(
613                "SafeTensors to ONNX conversion not supported. ONNX requires model architecture definition.",
614            ))
615        }
616        ("pytorch", "safetensors") => {
617            Err(voirs_sdk::VoirsError::config_error(
618                "PyTorch to SafeTensors conversion not yet implemented. Consider using Python: \
619                import safetensors.torch; safetensors.torch.save_file(torch.load('input.pt'), 'output.safetensors')",
620            ))
621        }
622        ("pytorch", "pytorch") => {
623            // Simple copy with potential re-serialization
624            tokio::fs::copy(input, output).await?;
625            if !global.quiet {
626                println!("✅ Checkpoint copied successfully");
627            }
628            Ok(())
629        }
630        ("onnx", _) => {
631            Err(voirs_sdk::VoirsError::config_error(
632                "ONNX checkpoint conversion not supported. ONNX models are runtime-optimized formats.",
633            ))
634        }
635        _ => {
636            Err(voirs_sdk::VoirsError::config_error(format!(
637                "Unsupported conversion: {} to {}",
638                detected_input_format, output_format
639            )))
640        }
641    }
642}
643
644/// Convert SafeTensors to SafeTensors (with potential metadata updates)
645async fn convert_safetensors_to_safetensors(
646    input: &PathBuf,
647    output: &PathBuf,
648    global: &GlobalOptions,
649) -> Result<()> {
650    // Read input checkpoint
651    let data = tokio::fs::read(input).await?;
652    let tensors = SafeTensors::deserialize(&data).map_err(|e| {
653        voirs_sdk::VoirsError::config_error(format!("Failed to parse input checkpoint: {}", e))
654    })?;
655
656    // Read metadata if available
657    let json_path = input.with_extension("json");
658    let metadata = if json_path.exists() {
659        tokio::fs::read_to_string(&json_path)
660            .await
661            .ok()
662            .and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
663    } else {
664        None
665    };
666
667    // For now, just copy the file and metadata
668    tokio::fs::copy(input, output).await?;
669
670    if let Some(ref meta) = metadata {
671        let output_json = output.with_extension("json");
672        tokio::fs::write(&output_json, serde_json::to_string_pretty(meta)?).await?;
673    }
674
675    if !global.quiet {
676        println!("✅ SafeTensors checkpoint converted successfully");
677        println!("   Tensors: {}", tensors.names().len());
678
679        if metadata.is_some() {
680            println!(
681                "   Metadata copied: {}",
682                output.with_extension("json").display()
683            );
684        }
685    }
686
687    Ok(())
688}
689
690/// Prune checkpoints
691async fn prune_checkpoints(
692    directory: &PathBuf,
693    keep_best: Option<usize>,
694    keep_latest: Option<usize>,
695    dry_run: bool,
696    global: &GlobalOptions,
697) -> Result<()> {
698    if !directory.exists() {
699        return Err(voirs_sdk::VoirsError::config_error(format!(
700            "Directory not found: {}",
701            directory.display()
702        )));
703    }
704
705    if keep_best.is_none() && keep_latest.is_none() {
706        return Err(voirs_sdk::VoirsError::config_error(
707            "Must specify at least one of --keep-best or --keep-latest",
708        ));
709    }
710
711    // Collect all checkpoints
712    let mut checkpoints = Vec::new();
713    let mut entries = tokio::fs::read_dir(directory).await?;
714
715    while let Some(entry) = entries.next_entry().await? {
716        let path = entry.path();
717        if path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
718            if let Ok(metadata) = entry.metadata().await {
719                let json_path = path.with_extension("json");
720                let mut epoch = 0;
721                let mut train_loss = 0.0;
722                let mut val_loss = f64::MAX;
723
724                if json_path.exists() {
725                    if let Ok(meta_str) = tokio::fs::read_to_string(&json_path).await {
726                        if let Ok(meta_json) = serde_json::from_str::<serde_json::Value>(&meta_str)
727                        {
728                            if let Some(obj) = meta_json.as_object() {
729                                epoch = obj
730                                    .get("epoch")
731                                    .and_then(|v| {
732                                        v.as_u64()
733                                            .map(|n| n as usize)
734                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
735                                    })
736                                    .unwrap_or(0);
737                                train_loss = obj
738                                    .get("train_loss")
739                                    .and_then(|v| {
740                                        v.as_f64()
741                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
742                                    })
743                                    .unwrap_or(0.0);
744                                val_loss = obj
745                                    .get("val_loss")
746                                    .and_then(|v| {
747                                        v.as_f64()
748                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
749                                    })
750                                    .unwrap_or(f64::MAX);
751                            }
752                        }
753                    }
754                }
755
756                if let Ok(data) = tokio::fs::read(&path).await {
757                    if SafeTensors::deserialize(&data).is_ok() {
758                        checkpoints.push(CheckpointInfo {
759                            path: path.clone(),
760                            name: path.file_name().unwrap().to_string_lossy().to_string(),
761                            epoch,
762                            train_loss,
763                            val_loss,
764                            size: metadata.len(),
765                            modified: metadata
766                                .modified()
767                                .ok()
768                                .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
769                                .map(|d| d.as_secs())
770                                .unwrap_or(0),
771                        });
772                    }
773                }
774            }
775        }
776    }
777
778    if checkpoints.is_empty() {
779        if !global.quiet {
780            println!("No checkpoints found in {}", directory.display());
781        }
782        return Ok(());
783    }
784
785    let mut to_delete = Vec::new();
786
787    // Determine which checkpoints to keep
788    if let Some(n) = keep_best {
789        // Sort by validation loss (ascending - lower is better)
790        let mut sorted = checkpoints.clone();
791        sorted.sort_by(|a, b| {
792            a.val_loss
793                .partial_cmp(&b.val_loss)
794                .unwrap_or(std::cmp::Ordering::Equal)
795        });
796
797        // Keep best N, mark rest for deletion
798        let to_keep: std::collections::HashSet<_> =
799            sorted.iter().take(n).map(|c| c.path.clone()).collect();
800
801        for ckpt in &checkpoints {
802            if !to_keep.contains(&ckpt.path) {
803                to_delete.push(ckpt.clone());
804            }
805        }
806    }
807
808    if let Some(n) = keep_latest {
809        // Sort by modification time (descending - newer first)
810        let mut sorted = checkpoints.clone();
811        sorted.sort_by(|a, b| b.modified.cmp(&a.modified));
812
813        // Keep latest N
814        let to_keep: std::collections::HashSet<_> =
815            sorted.iter().take(n).map(|c| c.path.clone()).collect();
816
817        // Only delete if not already marked and not in keep set
818        for ckpt in &checkpoints {
819            if !to_keep.contains(&ckpt.path) && !to_delete.iter().any(|d| d.path == ckpt.path) {
820                to_delete.push(ckpt.clone());
821            }
822        }
823    }
824
825    if to_delete.is_empty() {
826        if !global.quiet {
827            println!("✅ No checkpoints need to be pruned");
828        }
829        return Ok(());
830    }
831
832    if !global.quiet {
833        println!("\n🗑️  Checkpoint Pruning:\n");
834        println!("Total checkpoints: {}", checkpoints.len());
835        println!("To delete: {}", to_delete.len());
836
837        if dry_run {
838            println!("\n⚠️  DRY RUN - No files will be deleted\n");
839        }
840
841        println!("\nCheckpoints to be deleted:");
842        println!(
843            "{:<35} {:>8} {:>12} {:>10}",
844            "Name", "Epoch", "Val Loss", "Size"
845        );
846        println!("{}", "─".repeat(70));
847
848        for ckpt in &to_delete {
849            println!(
850                "{:<35} {:>8} {:>12.6} {:>10}",
851                truncate_str(&ckpt.name, 35),
852                ckpt.epoch,
853                if ckpt.val_loss == f64::MAX {
854                    0.0
855                } else {
856                    ckpt.val_loss
857                },
858                format_bytes(ckpt.size as usize)
859            );
860        }
861        println!();
862    }
863
864    if !dry_run {
865        let mut deleted_count = 0;
866        for ckpt in &to_delete {
867            // Delete .safetensors file
868            if let Err(e) = tokio::fs::remove_file(&ckpt.path).await {
869                if !global.quiet {
870                    eprintln!("⚠️  Failed to delete {}: {}", ckpt.name, e);
871                }
872            } else {
873                deleted_count += 1;
874
875                // Also delete companion .json file if exists
876                let json_path = ckpt.path.with_extension("json");
877                if json_path.exists() {
878                    let _ = tokio::fs::remove_file(&json_path).await;
879                }
880            }
881        }
882
883        if !global.quiet {
884            println!("✅ Deleted {} checkpoint(s)", deleted_count);
885        }
886    }
887
888    Ok(())
889}
890
891// Helper functions
892
/// Shorten `s` to at most `max_len` characters, marking the cut with `...`.
///
/// Strings that already fit are returned unchanged. Longer strings keep their
/// first `max_len - 3` characters followed by `...`. When `max_len` is smaller
/// than 3 the ellipsis itself is shortened so the result never exceeds
/// `max_len`.
///
/// Lengths are counted in `char`s, not bytes: this matches how `{:<N}` padding
/// counts width and means the cut always falls on a character boundary, so
/// non-ASCII input cannot cause a slicing panic.
fn truncate_str(s: &str, max_len: usize) -> String {
    // `nth(max_len)` only exists when the string has more than `max_len` chars.
    if s.char_indices().nth(max_len).is_none() {
        return s.to_string();
    }

    // Shrink the ellipsis when there is no room for all three dots.
    let ellipsis = &"..."[..max_len.min(3)];
    let keep_chars = max_len - ellipsis.len();

    // Byte offset of the first character that is dropped.
    let cut = s
        .char_indices()
        .nth(keep_chars)
        .map_or(s.len(), |(idx, _)| idx);

    format!("{}{}", &s[..cut], ellipsis)
}
900
/// Render a count compactly: plain digits below one thousand, thousands with
/// a `K` suffix below one million, and millions with an `M` suffix above that.
/// Scaled values are shown with two decimal places.
fn format_number(n: usize) -> String {
    match n {
        0..=999 => n.to_string(),
        1_000..=999_999 => format!("{:.2}K", n as f64 / 1_000.0),
        _ => format!("{:.2}M", n as f64 / 1_000_000.0),
    }
}
910
/// Render a byte count using decimal (power-of-1000) units.
///
/// The largest unit among GB, MB and KB whose threshold the value reaches is
/// used, with two decimal places; values under 1000 are printed as whole bytes.
fn format_bytes(bytes: usize) -> String {
    // Ordered from largest to smallest so the first match wins.
    const UNITS: [(usize, &str); 3] = [
        (1_000_000_000, "GB"),
        (1_000_000, "MB"),
        (1_000, "KB"),
    ];

    for (threshold, suffix) in UNITS.iter() {
        if bytes >= *threshold {
            let scaled = bytes as f64 / *threshold as f64;
            return format!("{:.2} {}", scaled, suffix);
        }
    }

    format!("{} B", bytes)
}
921}