1use crate::GlobalOptions;
6use clap::Subcommand;
7use safetensors::SafeTensors;
8use std::path::{Path, PathBuf};
9use voirs_sdk::Result;
10
// Checkpoint management subcommands.
//
// NOTE: plain `//` comments (not `///`) are used on purpose here: clap turns
// doc comments on derive(Subcommand) items into CLI help text, and this
// review must not alter the generated help output.
#[derive(Debug, Clone, Subcommand)]
pub enum CheckpointCommands {
    // Inspect a single checkpoint: tensor count, parameters, total size.
    Inspect {
        // Path to the .safetensors checkpoint file.
        #[arg(value_name = "FILE")]
        checkpoint: PathBuf,

        // Also print a per-tensor table (name, shape, parameter count).
        #[arg(long)]
        verbose: bool,

        // Output format: "text" (default) or "json".
        #[arg(long, default_value = "text")]
        format: String,
    },

    // List checkpoints found in a directory, with JSON-sidecar metadata.
    List {
        // Directory to scan for .safetensors files.
        #[arg(value_name = "DIR", default_value = "checkpoints")]
        directory: PathBuf,

        // Sort key: "name", "epoch" (default), "loss", "size", or "date".
        #[arg(long, default_value = "epoch")]
        sort_by: String,

        // Show only the first N entries after sorting.
        #[arg(long)]
        top: Option<usize>,
    },

    // Compare two checkpoints (sidecar metadata, tensor counts/names).
    Compare {
        // First checkpoint file.
        #[arg(value_name = "FILE1")]
        checkpoint1: PathBuf,

        // Second checkpoint file.
        #[arg(value_name = "FILE2")]
        checkpoint2: PathBuf,

        // Also list tensors present in only one of the two files.
        #[arg(long)]
        diff_params: bool,
    },

    // Convert a checkpoint between formats (several pairs are stubs
    // that return explanatory errors; see `convert_checkpoint`).
    Convert {
        // Input checkpoint path.
        #[arg(value_name = "INPUT")]
        input: PathBuf,

        // Output checkpoint path.
        #[arg(value_name = "OUTPUT")]
        output: PathBuf,

        // Input format; "auto" detects it from the file extension.
        #[arg(long, default_value = "auto")]
        input_format: String,

        // Target format (default "safetensors").
        #[arg(long, default_value = "safetensors")]
        output_format: String,
    },

    // Delete checkpoints, keeping the best and/or most recent ones.
    Prune {
        // Directory containing the checkpoints.
        #[arg(value_name = "DIR")]
        directory: PathBuf,

        // Keep the N checkpoints with the lowest validation loss.
        #[arg(long)]
        keep_best: Option<usize>,

        // Keep the N most recently modified checkpoints.
        #[arg(long)]
        keep_latest: Option<usize>,

        // Report what would be deleted without deleting anything.
        #[arg(long)]
        dry_run: bool,
    },
}
97
/// Dispatch a parsed `CheckpointCommands` variant to its handler.
///
/// Destructures each variant and forwards its fields (by reference where
/// the handlers take borrows) together with the global CLI options.
/// Returns whatever the selected handler returns.
pub async fn execute_checkpoint_command(
    command: CheckpointCommands,
    global: &GlobalOptions,
) -> Result<()> {
    match command {
        CheckpointCommands::Inspect {
            checkpoint,
            verbose,
            format,
        } => inspect_checkpoint(&checkpoint, verbose, &format, global).await,
        CheckpointCommands::List {
            directory,
            sort_by,
            top,
        } => list_checkpoints(&directory, &sort_by, top, global).await,
        CheckpointCommands::Compare {
            checkpoint1,
            checkpoint2,
            diff_params,
        } => compare_checkpoints(&checkpoint1, &checkpoint2, diff_params, global).await,
        CheckpointCommands::Convert {
            input,
            output,
            input_format,
            output_format,
        } => convert_checkpoint(&input, &output, &input_format, &output_format, global).await,
        CheckpointCommands::Prune {
            directory,
            keep_best,
            keep_latest,
            dry_run,
        } => prune_checkpoints(&directory, keep_best, keep_latest, dry_run, global).await,
    }
}
133
134async fn inspect_checkpoint(
136 checkpoint_path: &PathBuf,
137 verbose: bool,
138 format: &str,
139 global: &GlobalOptions,
140) -> Result<()> {
141 if !checkpoint_path.exists() {
142 return Err(voirs_sdk::VoirsError::config_error(format!(
143 "Checkpoint file not found: {}",
144 checkpoint_path.display()
145 )));
146 }
147
148 let data = tokio::fs::read(checkpoint_path).await?;
150 let tensors = SafeTensors::deserialize(&data).map_err(|e| {
151 voirs_sdk::VoirsError::config_error(format!("Failed to parse checkpoint: {}", e))
152 })?;
153
154 let json_path = checkpoint_path.with_extension("json");
156 let metadata = if json_path.exists() {
157 tokio::fs::read_to_string(&json_path)
158 .await
159 .ok()
160 .and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
161 } else {
162 None
163 };
164
165 if format == "json" {
166 output_json_format(&tensors, checkpoint_path, verbose, metadata.as_ref())?;
167 } else {
168 output_text_format(
169 &tensors,
170 checkpoint_path,
171 verbose,
172 global,
173 metadata.as_ref(),
174 )?;
175 }
176
177 Ok(())
178}
179
180fn output_text_format(
182 tensors: &SafeTensors,
183 checkpoint_path: &Path,
184 verbose: bool,
185 global: &GlobalOptions,
186 metadata: Option<&serde_json::Value>,
187) -> Result<()> {
188 if !global.quiet {
189 println!("\n╔══════════════════════════════════════════════════════════╗");
190 println!("║ Checkpoint Inspection ║");
191 println!("╠══════════════════════════════════════════════════════════╣");
192 println!(
193 "║ File: {:<50} ║",
194 truncate_str(&checkpoint_path.display().to_string(), 50)
195 );
196
197 if let Some(meta_val) = metadata {
199 if let Some(obj) = meta_val.as_object() {
200 for (key, value) in obj {
201 if key != "tensors" {
202 let value_str = match value {
204 serde_json::Value::String(s) => s.clone(),
205 serde_json::Value::Number(n) => n.to_string(),
206 _ => value.to_string(),
207 };
208 println!(
209 "║ {}: {:<47} ║",
210 key,
211 truncate_str(&value_str, 47 - key.len())
212 );
213 }
214 }
215 }
216 }
217
218 println!("╠══════════════════════════════════════════════════════════╣");
219 println!("║ Tensors: {:<47} ║", tensors.names().len());
220
221 let mut total_params: usize = 0;
223 let mut total_size: usize = 0;
224
225 for name in tensors.names() {
226 if let Ok(tensor) = tensors.tensor(name) {
227 let shape = tensor.shape();
228 let params: usize = shape.iter().product();
229 total_params += params;
230 total_size += tensor.data().len();
231 }
232 }
233
234 println!("║ Total parameters: {:<38} ║", format_number(total_params));
235 println!("║ Total size: {:<44} ║", format_bytes(total_size));
236 println!("╚══════════════════════════════════════════════════════════╝\n");
237
238 if verbose {
239 println!("\n📊 Tensor Details:\n");
240 println!("{:<50} {:>15} {:>15}", "Name", "Shape", "Parameters");
241 println!("{}", "─".repeat(82));
242
243 for name in tensors.names() {
244 if let Ok(tensor) = tensors.tensor(name) {
245 let shape = tensor.shape();
246 let params: usize = shape.iter().product();
247 let shape_str = format!("{:?}", shape);
248
249 println!(
250 "{:<50} {:>15} {:>15}",
251 truncate_str(name, 50),
252 truncate_str(&shape_str, 15),
253 format_number(params)
254 );
255 }
256 }
257 println!();
258 }
259 }
260
261 Ok(())
262}
263
264fn output_json_format(
266 tensors: &SafeTensors,
267 checkpoint_path: &Path,
268 verbose: bool,
269 metadata: Option<&serde_json::Value>,
270) -> Result<()> {
271 use serde_json::json;
272
273 let mut tensor_info = Vec::new();
274 let mut total_params: usize = 0;
275
276 for name in tensors.names() {
277 if let Ok(tensor) = tensors.tensor(name) {
278 let shape: Vec<usize> = tensor.shape().to_vec();
279 let params: usize = shape.iter().product();
280 total_params += params;
281
282 if verbose {
283 tensor_info.push(json!({
284 "name": name,
285 "shape": shape,
286 "parameters": params,
287 "dtype": "F32",
288 }));
289 }
290 }
291 }
292
293 let output = json!({
294 "file": checkpoint_path.display().to_string(),
295 "num_tensors": tensors.names().len(),
296 "total_parameters": total_params,
297 "metadata": metadata,
298 "tensors": if verbose { Some(tensor_info) } else { None },
299 });
300
301 println!("{}", serde_json::to_string_pretty(&output)?);
302
303 Ok(())
304}
305
306async fn list_checkpoints(
308 directory: &PathBuf,
309 sort_by: &str,
310 top: Option<usize>,
311 global: &GlobalOptions,
312) -> Result<()> {
313 if !directory.exists() {
314 return Err(voirs_sdk::VoirsError::config_error(format!(
315 "Directory not found: {}",
316 directory.display()
317 )));
318 }
319
320 let mut checkpoints = Vec::new();
321
322 let mut entries = tokio::fs::read_dir(directory).await?;
324 while let Some(entry) = entries.next_entry().await? {
325 let path = entry.path();
326 if path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
327 if let Ok(metadata) = entry.metadata().await {
328 let json_path = path.with_extension("json");
330 let mut epoch = 0;
331 let mut train_loss = 0.0;
332 let mut val_loss = 0.0;
333
334 if json_path.exists() {
335 if let Ok(meta_str) = tokio::fs::read_to_string(&json_path).await {
336 if let Ok(meta_json) = serde_json::from_str::<serde_json::Value>(&meta_str)
337 {
338 if let Some(obj) = meta_json.as_object() {
339 epoch = obj
340 .get("epoch")
341 .and_then(|v| {
342 v.as_u64()
343 .map(|n| n as usize)
344 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
345 })
346 .unwrap_or(0);
347 train_loss = obj
348 .get("train_loss")
349 .and_then(|v| {
350 v.as_f64()
351 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
352 })
353 .unwrap_or(0.0);
354 val_loss = obj
355 .get("val_loss")
356 .and_then(|v| {
357 v.as_f64()
358 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
359 })
360 .unwrap_or(0.0);
361 }
362 }
363 }
364 }
365
366 if let Ok(data) = tokio::fs::read(&path).await {
367 if SafeTensors::deserialize(&data).is_ok() {
368 checkpoints.push(CheckpointInfo {
369 path: path.clone(),
370 name: path.file_name().unwrap().to_string_lossy().to_string(),
371 epoch,
372 train_loss,
373 val_loss,
374 size: metadata.len(),
375 modified: metadata
376 .modified()
377 .ok()
378 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
379 .map(|d| d.as_secs())
380 .unwrap_or(0),
381 });
382 }
383 }
384 }
385 }
386 }
387
388 match sort_by {
390 "name" => checkpoints.sort_by(|a, b| a.name.cmp(&b.name)),
391 "epoch" => checkpoints.sort_by(|a, b| b.epoch.cmp(&a.epoch)),
392 "loss" => checkpoints.sort_by(|a, b| a.val_loss.partial_cmp(&b.val_loss).unwrap()),
393 "size" => checkpoints.sort_by(|a, b| b.size.cmp(&a.size)),
394 "date" => checkpoints.sort_by(|a, b| b.modified.cmp(&a.modified)),
395 _ => {}
396 }
397
398 if let Some(n) = top {
400 checkpoints.truncate(n);
401 }
402
403 if !global.quiet {
404 println!("\n📁 Checkpoints in {}:\n", directory.display());
405 println!(
406 "{:<35} {:>8} {:>12} {:>12} {:>10}",
407 "Name", "Epoch", "Train Loss", "Val Loss", "Size"
408 );
409 println!("{}", "─".repeat(82));
410
411 for ckpt in &checkpoints {
412 println!(
413 "{:<35} {:>8} {:>12.6} {:>12.6} {:>10}",
414 truncate_str(&ckpt.name, 35),
415 ckpt.epoch,
416 ckpt.train_loss,
417 ckpt.val_loss,
418 format_bytes(ckpt.size as usize)
419 );
420 }
421
422 println!("\nTotal: {} checkpoints\n", checkpoints.len());
423 }
424
425 Ok(())
426}
427
/// Record describing one on-disk checkpoint, gathered while scanning a
/// directory (used by both `list_checkpoints` and `prune_checkpoints`).
#[derive(Debug, Clone)]
struct CheckpointInfo {
    // Full path to the .safetensors file.
    path: PathBuf,
    // File-name component only (lossy UTF-8 conversion).
    name: String,
    // Training epoch from the JSON sidecar; 0 when absent.
    epoch: usize,
    // Training loss from the sidecar; 0.0 when absent.
    train_loss: f64,
    // Validation loss from the sidecar; callers use different defaults
    // when absent (0.0 for listing, f64::MAX for pruning).
    val_loss: f64,
    // File size in bytes.
    size: u64,
    // Modification time as seconds since the Unix epoch; 0 when unavailable.
    modified: u64,
}
438
439async fn compare_checkpoints(
441 checkpoint1: &PathBuf,
442 checkpoint2: &PathBuf,
443 diff_params: bool,
444 global: &GlobalOptions,
445) -> Result<()> {
446 let data1 = tokio::fs::read(checkpoint1).await?;
447 let data2 = tokio::fs::read(checkpoint2).await?;
448
449 let tensors1 = SafeTensors::deserialize(&data1).map_err(|e| {
450 voirs_sdk::VoirsError::config_error(format!("Failed to parse checkpoint 1: {}", e))
451 })?;
452
453 let tensors2 = SafeTensors::deserialize(&data2).map_err(|e| {
454 voirs_sdk::VoirsError::config_error(format!("Failed to parse checkpoint 2: {}", e))
455 })?;
456
457 let json_path1 = checkpoint1.with_extension("json");
459 let json_path2 = checkpoint2.with_extension("json");
460
461 let meta1 = if json_path1.exists() {
462 tokio::fs::read_to_string(&json_path1)
463 .await
464 .ok()
465 .and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
466 } else {
467 None
468 };
469
470 let meta2 = if json_path2.exists() {
471 tokio::fs::read_to_string(&json_path2)
472 .await
473 .ok()
474 .and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
475 } else {
476 None
477 };
478
479 if !global.quiet {
480 println!("\n╔══════════════════════════════════════════════════════════╗");
481 println!("║ Checkpoint Comparison ║");
482 println!("╠══════════════════════════════════════════════════════════╣");
483
484 if let (Some(m1), Some(m2)) = (meta1.as_ref(), meta2.as_ref()) {
486 if let (Some(o1), Some(o2)) = (m1.as_object(), m2.as_object()) {
487 println!(
488 "║ {:<25} {:<12} {:<15} ║",
489 "Metric", "Checkpoint 1", "Checkpoint 2"
490 );
491 println!("╠══════════════════════════════════════════════════════════╣");
492
493 for key in o1.keys() {
494 if key != "tensors" {
495 if let (Some(v1), Some(v2)) = (o1.get(key), o2.get(key)) {
497 let s1 = match v1 {
498 serde_json::Value::String(s) => s.clone(),
499 serde_json::Value::Number(n) => n.to_string(),
500 _ => v1.to_string(),
501 };
502 let s2 = match v2 {
503 serde_json::Value::String(s) => s.clone(),
504 serde_json::Value::Number(n) => n.to_string(),
505 _ => v2.to_string(),
506 };
507
508 println!(
509 "║ {:<25} {:<12} {:<15} ║",
510 truncate_str(key, 25),
511 truncate_str(&s1, 12),
512 truncate_str(&s2, 15)
513 );
514 }
515 }
516 }
517 }
518 }
519
520 println!("╠══════════════════════════════════════════════════════════╣");
521 println!(
522 "║ Tensors in checkpoint 1: {:<31} ║",
523 tensors1.names().len()
524 );
525 println!(
526 "║ Tensors in checkpoint 2: {:<31} ║",
527 tensors2.names().len()
528 );
529 println!("╚══════════════════════════════════════════════════════════╝\n");
530
531 if diff_params {
532 let names1: std::collections::HashSet<String> =
534 tensors1.names().iter().map(|s| s.to_string()).collect();
535 let names2: std::collections::HashSet<String> =
536 tensors2.names().iter().map(|s| s.to_string()).collect();
537
538 let only_in_1: Vec<_> = names1.difference(&names2).collect();
539 let only_in_2: Vec<_> = names2.difference(&names1).collect();
540
541 if !only_in_1.is_empty() {
542 println!("⚠️ Tensors only in checkpoint 1:");
543 for name in only_in_1 {
544 println!(" - {}", name);
545 }
546 println!();
547 }
548
549 if !only_in_2.is_empty() {
550 println!("⚠️ Tensors only in checkpoint 2:");
551 for name in only_in_2 {
552 println!(" - {}", name);
553 }
554 println!();
555 }
556 }
557 }
558
559 Ok(())
560}
561
/// Convert a checkpoint between formats.
///
/// Resolves `input_format == "auto"` from the file extension, then
/// dispatches on the (input, output) format pair. Only
/// safetensors→safetensors (validated copy) and pytorch→pytorch (plain
/// copy) are implemented; the other combinations return explanatory
/// errors, some suggesting equivalent Python snippets.
///
/// # Errors
/// Config errors for a missing input, undetectable format, or an
/// unsupported conversion pair; I/O errors from copying are propagated.
async fn convert_checkpoint(
    input: &PathBuf,
    output: &PathBuf,
    input_format: &str,
    output_format: &str,
    global: &GlobalOptions,
) -> Result<()> {
    if !input.exists() {
        return Err(voirs_sdk::VoirsError::config_error(format!(
            "Input checkpoint not found: {}",
            input.display()
        )));
    }

    // Resolve "auto" to a concrete format based on the extension.
    let detected_input_format = if input_format == "auto" {
        match input.extension().and_then(|s| s.to_str()) {
            Some("safetensors") => "safetensors",
            Some("pt") | Some("pth") => "pytorch",
            Some("onnx") => "onnx",
            _ => {
                return Err(voirs_sdk::VoirsError::config_error(
                    "Could not auto-detect input format. Please specify --input-format",
                ));
            }
        }
    } else {
        input_format
    };

    if !global.quiet {
        println!("\n🔄 Checkpoint Conversion:");
        println!(" Input: {} ({})", input.display(), detected_input_format);
        println!(" Output: {} ({})", output.display(), output_format);
        println!();
    }

    // Arm order matters: ("onnx", _) must precede the catch-all so any
    // ONNX input gets the specific message.
    match (detected_input_format, output_format) {
        ("safetensors", "safetensors") => {
            convert_safetensors_to_safetensors(input, output, global).await
        }
        ("safetensors", "pytorch") => {
            Err(voirs_sdk::VoirsError::config_error(
                "SafeTensors to PyTorch conversion not yet implemented. Consider using Python: \
                import safetensors.torch; safetensors.torch.save_file(tensors, 'output.pt')",
            ))
        }
        ("safetensors", "onnx") => {
            Err(voirs_sdk::VoirsError::config_error(
                "SafeTensors to ONNX conversion not supported. ONNX requires model architecture definition.",
            ))
        }
        ("pytorch", "safetensors") => {
            Err(voirs_sdk::VoirsError::config_error(
                "PyTorch to SafeTensors conversion not yet implemented. Consider using Python: \
                import safetensors.torch; safetensors.torch.save_file(torch.load('input.pt'), 'output.safetensors')",
            ))
        }
        ("pytorch", "pytorch") => {
            // Same-format "conversion" is a straight file copy.
            tokio::fs::copy(input, output).await?;
            if !global.quiet {
                println!("✅ Checkpoint copied successfully");
            }
            Ok(())
        }
        ("onnx", _) => {
            Err(voirs_sdk::VoirsError::config_error(
                "ONNX checkpoint conversion not supported. ONNX models are runtime-optimized formats.",
            ))
        }
        _ => {
            Err(voirs_sdk::VoirsError::config_error(format!(
                "Unsupported conversion: {} to {}",
                detected_input_format, output_format
            )))
        }
    }
}
643
644async fn convert_safetensors_to_safetensors(
646 input: &PathBuf,
647 output: &PathBuf,
648 global: &GlobalOptions,
649) -> Result<()> {
650 let data = tokio::fs::read(input).await?;
652 let tensors = SafeTensors::deserialize(&data).map_err(|e| {
653 voirs_sdk::VoirsError::config_error(format!("Failed to parse input checkpoint: {}", e))
654 })?;
655
656 let json_path = input.with_extension("json");
658 let metadata = if json_path.exists() {
659 tokio::fs::read_to_string(&json_path)
660 .await
661 .ok()
662 .and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok())
663 } else {
664 None
665 };
666
667 tokio::fs::copy(input, output).await?;
669
670 if let Some(ref meta) = metadata {
671 let output_json = output.with_extension("json");
672 tokio::fs::write(&output_json, serde_json::to_string_pretty(meta)?).await?;
673 }
674
675 if !global.quiet {
676 println!("✅ SafeTensors checkpoint converted successfully");
677 println!(" Tensors: {}", tensors.names().len());
678
679 if metadata.is_some() {
680 println!(
681 " Metadata copied: {}",
682 output.with_extension("json").display()
683 );
684 }
685 }
686
687 Ok(())
688}
689
/// Delete checkpoints from `directory`, keeping the best and/or latest.
///
/// Scans for `.safetensors` files that successfully deserialize, reads
/// epoch/loss values from optional JSON sidecars, then deletes every
/// checkpoint not retained by `--keep-best` (lowest val_loss) and/or
/// `--keep-latest` (newest mtime). With `dry_run` the selection is only
/// reported, nothing is removed. Sidecar JSON files of deleted
/// checkpoints are removed as well (best-effort).
///
/// # Errors
/// Config errors for a missing directory or when neither keep option is
/// given; directory-traversal I/O errors are propagated. Individual
/// delete failures are reported to stderr but do not abort the run.
async fn prune_checkpoints(
    directory: &PathBuf,
    keep_best: Option<usize>,
    keep_latest: Option<usize>,
    dry_run: bool,
    global: &GlobalOptions,
) -> Result<()> {
    if !directory.exists() {
        return Err(voirs_sdk::VoirsError::config_error(format!(
            "Directory not found: {}",
            directory.display()
        )));
    }

    // Refuse to run with no retention policy — that would delete everything.
    if keep_best.is_none() && keep_latest.is_none() {
        return Err(voirs_sdk::VoirsError::config_error(
            "Must specify at least one of --keep-best or --keep-latest",
        ));
    }

    let mut checkpoints = Vec::new();
    let mut entries = tokio::fs::read_dir(directory).await?;

    while let Some(entry) = entries.next_entry().await? {
        let path = entry.path();
        if path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
            if let Ok(metadata) = entry.metadata().await {
                let json_path = path.with_extension("json");
                let mut epoch = 0;
                let mut train_loss = 0.0;
                // f64::MAX default makes sidecar-less checkpoints sort as
                // "worst" for --keep-best.
                let mut val_loss = f64::MAX;

                if json_path.exists() {
                    if let Ok(meta_str) = tokio::fs::read_to_string(&json_path).await {
                        if let Ok(meta_json) = serde_json::from_str::<serde_json::Value>(&meta_str)
                        {
                            if let Some(obj) = meta_json.as_object() {
                                // Each field accepts a native JSON number or a
                                // numeric string.
                                epoch = obj
                                    .get("epoch")
                                    .and_then(|v| {
                                        v.as_u64()
                                            .map(|n| n as usize)
                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
                                    })
                                    .unwrap_or(0);
                                train_loss = obj
                                    .get("train_loss")
                                    .and_then(|v| {
                                        v.as_f64()
                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
                                    })
                                    .unwrap_or(0.0);
                                val_loss = obj
                                    .get("val_loss")
                                    .and_then(|v| {
                                        v.as_f64()
                                            .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
                                    })
                                    .unwrap_or(f64::MAX);
                            }
                        }
                    }
                }

                // Only consider files that actually parse as SafeTensors.
                if let Ok(data) = tokio::fs::read(&path).await {
                    if SafeTensors::deserialize(&data).is_ok() {
                        checkpoints.push(CheckpointInfo {
                            path: path.clone(),
                            // Safe: a path with a matched extension has a file name.
                            name: path.file_name().unwrap().to_string_lossy().to_string(),
                            epoch,
                            train_loss,
                            val_loss,
                            size: metadata.len(),
                            modified: metadata
                                .modified()
                                .ok()
                                .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
                                .map(|d| d.as_secs())
                                .unwrap_or(0),
                        });
                    }
                }
            }
        }
    }

    if checkpoints.is_empty() {
        if !global.quiet {
            println!("No checkpoints found in {}", directory.display());
        }
        return Ok(());
    }

    let mut to_delete = Vec::new();

    // Pass 1: everything outside the N lowest-val_loss checkpoints.
    if let Some(n) = keep_best {
        let mut sorted = checkpoints.clone();
        sorted.sort_by(|a, b| {
            a.val_loss
                .partial_cmp(&b.val_loss)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let to_keep: std::collections::HashSet<_> =
            sorted.iter().take(n).map(|c| c.path.clone()).collect();

        for ckpt in &checkpoints {
            if !to_keep.contains(&ckpt.path) {
                to_delete.push(ckpt.clone());
            }
        }
    }

    // Pass 2: everything outside the N most recently modified, skipping
    // entries already queued by pass 1.
    // NOTE(review): when both options are given, a checkpoint survives
    // only if it is in BOTH keep sets (intersection). Confirm that
    // union semantics ("keep best N plus latest M") weren't intended.
    if let Some(n) = keep_latest {
        let mut sorted = checkpoints.clone();
        sorted.sort_by(|a, b| b.modified.cmp(&a.modified));

        let to_keep: std::collections::HashSet<_> =
            sorted.iter().take(n).map(|c| c.path.clone()).collect();

        for ckpt in &checkpoints {
            if !to_keep.contains(&ckpt.path) && !to_delete.iter().any(|d| d.path == ckpt.path) {
                to_delete.push(ckpt.clone());
            }
        }
    }

    if to_delete.is_empty() {
        if !global.quiet {
            println!("✅ No checkpoints need to be pruned");
        }
        return Ok(());
    }

    if !global.quiet {
        println!("\n🗑️ Checkpoint Pruning:\n");
        println!("Total checkpoints: {}", checkpoints.len());
        println!("To delete: {}", to_delete.len());

        if dry_run {
            println!("\n⚠️ DRY RUN - No files will be deleted\n");
        }

        println!("\nCheckpoints to be deleted:");
        println!(
            "{:<35} {:>8} {:>12} {:>10}",
            "Name", "Epoch", "Val Loss", "Size"
        );
        println!("{}", "─".repeat(70));

        for ckpt in &to_delete {
            println!(
                "{:<35} {:>8} {:>12.6} {:>10}",
                truncate_str(&ckpt.name, 35),
                ckpt.epoch,
                // Display the "no sidecar" sentinel (f64::MAX) as 0.0.
                if ckpt.val_loss == f64::MAX {
                    0.0
                } else {
                    ckpt.val_loss
                },
                format_bytes(ckpt.size as usize)
            );
        }
        println!();
    }

    if !dry_run {
        let mut deleted_count = 0;
        for ckpt in &to_delete {
            // Per-file failures are reported but do not abort the loop.
            if let Err(e) = tokio::fs::remove_file(&ckpt.path).await {
                if !global.quiet {
                    eprintln!("⚠️ Failed to delete {}: {}", ckpt.name, e);
                }
            } else {
                deleted_count += 1;

                // Best-effort removal of the matching JSON sidecar.
                let json_path = ckpt.path.with_extension("json");
                if json_path.exists() {
                    let _ = tokio::fs::remove_file(&json_path).await;
                }
            }
        }

        if !global.quiet {
            println!("✅ Deleted {} checkpoint(s)", deleted_count);
        }
    }

    Ok(())
}
890
/// Truncate `s` to at most `max_len` characters, appending "..." when
/// truncation occurs.
///
/// Operates on `char` boundaries: the previous byte-index slice
/// (`&s[..max_len - 3]`) panicked on multi-byte UTF-8 input (accented
/// names, emoji, box-drawing characters) when the cut landed inside a
/// code point.
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        s.to_string()
    } else {
        // Leave room for the "..." suffix; saturate for tiny limits.
        let keep = max_len.saturating_sub(3);
        let prefix: String = s.chars().take(keep).collect();
        format!("{}...", prefix)
    }
}
900
/// Format a count with decimal K/M suffixes, e.g. 1_500_000 -> "1.50M".
fn format_number(n: usize) -> String {
    let value = n as f64;
    match n {
        // First matching arm wins, so plain `1_000..` suffices here.
        1_000_000.. => format!("{:.2}M", value / 1_000_000.0),
        1_000.. => format!("{:.2}K", value / 1_000.0),
        _ => n.to_string(),
    }
}
910
/// Format a byte count using decimal (SI) units: B, KB, MB, GB.
fn format_bytes(bytes: usize) -> String {
    const KB: f64 = 1_000.0;
    const MB: f64 = 1_000_000.0;
    const GB: f64 = 1_000_000_000.0;

    // Exact for any realistic file size: f64 represents integers < 2^53.
    let value = bytes as f64;
    if value >= GB {
        format!("{:.2} GB", value / GB)
    } else if value >= MB {
        format!("{:.2} MB", value / MB)
    } else if value >= KB {
        format!("{:.2} KB", value / KB)
    } else {
        format!("{} B", bytes)
    }
}