#![allow(dead_code)]

use anyhow::Result;
use clap::{Args, Subcommand};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Instant;
use tracing::{info, warn};

use crate::config::Config;
use crate::utils::{output, progress, time, validation};

use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::random::thread_rng;

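/// Subcommands for managing model training runs.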
#[derive(Subcommand)]
pub enum TrainCommands {
    /// Start a new training run
    Start(StartArgs),

    /// Resume training from a saved checkpoint
    Resume(ResumeArgs),

    /// Monitor the progress of a training run
    Monitor(MonitorArgs),

    /// Stop an active training run
    Stop(StopArgs),
}

#[derive(Args)]
pub struct StartArgs {
    /// Path to the training configuration file (JSON)
    #[arg(short, long)]
    pub config: PathBuf,

    /// Path to the training data directory
    #[arg(short, long)]
    pub data: PathBuf,

    /// Number of training epochs
    #[arg(short, long, default_value = "10")]
    pub epochs: usize,

    /// Batch size
    #[arg(short, long, default_value = "32")]
    pub batch_size: usize,

    /// Learning rate
    #[arg(short, long, default_value = "0.001")]
    pub learning_rate: f64,

    /// Enable distributed training
    #[arg(long)]
    pub distributed: bool,

    /// Device to train on
    #[arg(long, default_value = "cpu")]
    pub device: String,

    /// Optimizer (adam, adamw, sgd, rmsprop)
    #[arg(long, default_value = "adam")]
    pub optimizer: String,

    /// Learning rate scheduler (constant, step, cosine)
    #[arg(long, default_value = "constant")]
    pub scheduler: String,

    /// Enable mixed-precision training
    #[arg(long)]
    pub mixed_precision: bool,

    /// Gradient clipping threshold (L2 norm)
    #[arg(long)]
    pub grad_clip: Option<f64>,

    /// Save a checkpoint every N epochs
    #[arg(long, default_value = "5")]
    pub save_every: usize,

    /// Directory for training run outputs
    #[arg(short, long, default_value = "./runs")]
    pub output_dir: PathBuf,
}

#[derive(Args)]
pub struct ResumeArgs {
    /// Path to the checkpoint file to resume from
    #[arg(short = 'k', long)]
    pub checkpoint: PathBuf,

    /// New total number of epochs (overrides the checkpointed value)
    #[arg(long)]
    pub epochs: Option<usize>,
}

#[derive(Args)]
pub struct MonitorArgs {
    /// Path to the training run directory
    #[arg(short, long)]
    pub run: PathBuf,

    /// Follow the logs in real time
    #[arg(short, long)]
    pub follow: bool,
}

#[derive(Args)]
pub struct StopArgs {
    /// ID of the training run to stop
    #[arg(short, long)]
    pub run: String,

    /// Force-stop without waiting for a graceful shutdown
    #[arg(long)]
    pub force: bool,
}

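/// Entry point for the `train` subcommand family, dispatching to the
/// individual handlers below.
///
/// A typical invocation, assuming this module is wired up as a `train`
/// subcommand of the parent CLI (the binary name here is illustrative):
///
/// ```text
/// $ torsh train start --config model.json --data ./data --epochs 20 --optimizer adam
/// ```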
pub async fn execute(command: TrainCommands, _config: &Config, _output_format: &str) -> Result<()> {
    match command {
        TrainCommands::Start(args) => start_training(args).await,
        TrainCommands::Resume(args) => resume_training(args).await,
        TrainCommands::Monitor(args) => monitor_training(args).await,
        TrainCommands::Stop(args) => stop_training(args).await,
    }
}

async fn start_training(args: StartArgs) -> Result<()> {
    validation::validate_file_exists(&args.config)?;
    validation::validate_directory_exists(&args.data)?;
    validation::validate_device(&args.device)?;

    let (training_result, total_duration) = time::measure_time(async {
        info!("Starting model training with real ToRSh/SciRS2 implementation");
        info!("Configuration: {}", args.config.display());
        info!("Dataset: {}", args.data.display());
        info!("Device: {}", args.device);
        info!("Optimizer: {}", args.optimizer);

        let training_config = load_training_config(&args.config).await?;
        info!(
            "Loaded training configuration: {}",
            training_config.model_name
        );

        let mut model = initialize_model(&training_config, &args.device).await?;
        info!(
            "Model initialized with {} parameters",
            model.parameter_count
        );

        let (train_dataset, val_dataset) =
            load_training_datasets(&args.data, args.batch_size).await?;
        info!(
            "Loaded {} training and {} validation samples",
            train_dataset.samples.len(),
            val_dataset.samples.len()
        );

        let mut optimizer = initialize_optimizer(&args.optimizer, args.learning_rate, &model)?;
        info!(
            "Initialized {} optimizer with lr={}",
            args.optimizer, args.learning_rate
        );

        let mut scheduler = initialize_scheduler(&args.scheduler, &optimizer)?;
        info!("Initialized {} learning rate scheduler", args.scheduler);

        tokio::fs::create_dir_all(&args.output_dir).await?;
        let run_id = generate_run_id();
        let run_dir = args.output_dir.join(&run_id);
        tokio::fs::create_dir_all(&run_dir).await?;
        info!("Created training run directory: {}", run_dir.display());

        let training_results = execute_training_loop(
            &mut model,
            &mut optimizer,
            &mut scheduler,
            &train_dataset,
            &val_dataset,
            &args,
            &run_dir,
        )
        .await?;

        Ok::<TrainingResults, anyhow::Error>(training_results)
    })
    .await;

    let results = training_result?;

    output::print_success("Training completed successfully!");
    output::print_info(&format!(
        "Total duration: {}",
        time::format_duration(total_duration)
    ));
    output::print_info(&format!(
        "Final training loss: {:.6}",
        results.final_train_loss
    ));
    output::print_info(&format!(
        "Final validation accuracy: {:.2}%",
        results.final_val_accuracy * 100.0
    ));
    output::print_info(&format!(
        "Best validation accuracy: {:.2}%",
        results.best_val_accuracy * 100.0
    ));
    output::print_info(&format!("Run ID: {}", results.run_id));

    if results.converged {
        output::print_success("Training converged successfully");
    } else {
        output::print_warning("Training did not converge within the specified epochs");
    }

    Ok(())
}

async fn resume_training(args: ResumeArgs) -> Result<()> {
    validation::validate_file_exists(&args.checkpoint)?;

    let (resume_result, resume_duration) = time::measure_time(async {
        info!(
            "Resuming training from checkpoint: {}",
            args.checkpoint.display()
        );

        let checkpoint = load_checkpoint(&args.checkpoint).await?;
        info!("Loaded checkpoint from epoch {}", checkpoint.epoch);
        info!(
            "Previous best validation accuracy: {:.2}%",
            checkpoint.best_val_accuracy * 100.0
        );

        let mut model = restore_model_from_checkpoint(&checkpoint).await?;
        info!("Restored model with {} parameters", model.parameter_count);

        let mut optimizer = restore_optimizer_from_checkpoint(&checkpoint)?;
        info!("Restored {} optimizer state", checkpoint.optimizer_type);

        let training_config = checkpoint.training_config.clone();
        let (train_dataset, val_dataset) =
            load_training_datasets(&training_config.data_path, training_config.batch_size).await?;

        let mut scheduler = initialize_scheduler(&training_config.scheduler, &optimizer)?;

        let remaining_epochs = if let Some(new_epochs) = args.epochs {
            new_epochs.saturating_sub(checkpoint.epoch)
        } else {
            training_config
                .total_epochs
                .saturating_sub(checkpoint.epoch)
        };

        info!("Resuming training for {} more epochs", remaining_epochs);

        let resume_run_id = format!("{}_resumed", checkpoint.run_id);
        let run_dir = checkpoint.output_dir.join(&resume_run_id);
        tokio::fs::create_dir_all(&run_dir).await?;

        let resume_args = StartArgs {
            config: training_config.config_path.clone(),
            data: training_config.data_path.clone(),
            epochs: remaining_epochs,
            batch_size: training_config.batch_size,
            learning_rate: training_config.learning_rate,
            distributed: training_config.distributed,
            device: training_config.device.clone(),
            optimizer: training_config.optimizer.clone(),
            scheduler: training_config.scheduler.clone(),
            mixed_precision: training_config.mixed_precision,
            grad_clip: training_config.grad_clip,
            save_every: training_config.save_every,
            output_dir: run_dir.clone(),
        };

        let training_results = execute_training_loop(
            &mut model,
            &mut optimizer,
            &mut scheduler,
            &train_dataset,
            &val_dataset,
            &resume_args,
            &run_dir,
        )
        .await?;

        Ok::<TrainingResults, anyhow::Error>(training_results)
    })
    .await;

    let results = resume_result?;

    output::print_success("Training resumed and completed successfully!");
    output::print_info(&format!(
        "Resume duration: {}",
        time::format_duration(resume_duration)
    ));
    output::print_info(&format!(
        "Final validation accuracy: {:.2}%",
        results.final_val_accuracy * 100.0
    ));
    output::print_info(&format!("Resumed run ID: {}", results.run_id));

    Ok(())
}

async fn monitor_training(args: MonitorArgs) -> Result<()> {
    validation::validate_directory_exists(&args.run)?;

    info!(
        "Monitoring training progress for run: {}",
        args.run.display()
    );

    let metrics_file = args.run.join("training_metrics.json");
    let log_file = args.run.join("training.log");

    if metrics_file.exists() {
        let metrics = load_training_metrics(&metrics_file).await?;
        display_training_metrics(&metrics)?;
    } else {
        output::print_warning("No metrics file found in the specified run directory");
    }

    if args.follow && log_file.exists() {
        output::print_info("Following training logs in real-time...");
        follow_training_logs(&log_file).await?;
    } else if log_file.exists() {
        output::print_info("Recent training log entries:");
        display_recent_logs(&log_file).await?;
    } else {
        output::print_warning("No log file found in the specified run directory");
    }

    Ok(())
}

async fn stop_training(args: StopArgs) -> Result<()> {
    info!("Attempting to stop training run: {}", args.run);

    let stop_result = if args.force {
        force_stop_training(&args.run).await
    } else {
        graceful_stop_training(&args.run).await
    };

    match stop_result {
        Ok(stopped) => {
            if stopped {
                output::print_success(&format!("Training run '{}' stopped successfully", args.run));
            } else {
                output::print_warning(&format!(
                    "No active training found for run ID: {}",
                    args.run
                ));
            }
        }
        Err(e) => {
            output::print_error(&format!("Failed to stop training: {}", e));
            return Err(e);
        }
    }

    Ok(())
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct TrainingConfig {
    model_name: String,
    model_config: HashMap<String, serde_json::Value>,
    config_path: PathBuf,
    data_path: PathBuf,
    total_epochs: usize,
    batch_size: usize,
    learning_rate: f64,
    device: String,
    optimizer: String,
    scheduler: String,
    mixed_precision: bool,
    grad_clip: Option<f64>,
    save_every: usize,
    distributed: bool,
}

#[derive(Debug, Clone)]
struct TrainingModel {
    parameters: Vec<Array2<f32>>,
    parameter_count: usize,
    architecture: String,
    device: String,
}

#[derive(Debug, Clone)]
struct TrainingDataset {
    samples: Vec<Array3<f32>>,
    labels: Vec<usize>,
    batch_size: usize,
}

#[derive(Debug, Clone)]
struct TrainingOptimizer {
    optimizer_type: String,
    learning_rate: f64,
    state: HashMap<String, serde_json::Value>,
    momentum_buffers: Vec<Array2<f32>>,
}

#[derive(Debug, Clone)]
struct LearningRateScheduler {
    scheduler_type: String,
    base_lr: f64,
    current_lr: f64,
    params: HashMap<String, f64>,
}

#[derive(Debug, Clone)]
struct TrainingResults {
    run_id: String,
    epochs_completed: usize,
    final_train_loss: f64,
    final_val_accuracy: f64,
    best_val_accuracy: f64,
    converged: bool,
    duration: std::time::Duration,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct TrainingCheckpoint {
    run_id: String,
    epoch: usize,
    model_state: Vec<u8>,
    optimizer_state: Vec<u8>,
    optimizer_type: String,
    best_val_accuracy: f64,
    training_config: TrainingConfig,
    output_dir: PathBuf,
    timestamp: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct TrainingMetrics {
    run_id: String,
    train_losses: Vec<f64>,
    val_losses: Vec<f64>,
    val_accuracies: Vec<f64>,
    learning_rates: Vec<f64>,
    epoch_times: Vec<f64>,
}

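/// Loads a training configuration from JSON. Unknown or missing keys fall
/// back to the defaults hard-coded below. A minimal config matching the keys
/// this function reads (the values shown are the illustrative defaults):
///
/// ```json
/// {
///   "model": { "name": "resnet18" },
///   "data": { "path": "./data" },
///   "training": {
///     "epochs": 10,
///     "batch_size": 32,
///     "learning_rate": 0.001,
///     "device": "cpu",
///     "optimizer": "adam",
///     "scheduler": "constant",
///     "mixed_precision": false,
///     "save_every": 5,
///     "distributed": false
///   }
/// }
/// ```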
async fn load_training_config(config_path: &PathBuf) -> Result<TrainingConfig> {
    info!(
        "Loading training configuration from {}",
        config_path.display()
    );

    let config_content = tokio::fs::read_to_string(config_path).await?;
    let config: serde_json::Value = serde_json::from_str(&config_content)?;

    Ok(TrainingConfig {
        model_name: config["model"]["name"]
            .as_str()
            .unwrap_or("resnet18")
            .to_string(),
        model_config: config["model"]
            .as_object()
            .map(|m| m.clone().into_iter().collect())
            .unwrap_or_default(),
        config_path: config_path.clone(),
        data_path: PathBuf::from(config["data"]["path"].as_str().unwrap_or("./data")),
        total_epochs: config["training"]["epochs"].as_u64().unwrap_or(10) as usize,
        batch_size: config["training"]["batch_size"].as_u64().unwrap_or(32) as usize,
        learning_rate: config["training"]["learning_rate"]
            .as_f64()
            .unwrap_or(0.001),
        device: config["training"]["device"]
            .as_str()
            .unwrap_or("cpu")
            .to_string(),
        optimizer: config["training"]["optimizer"]
            .as_str()
            .unwrap_or("adam")
            .to_string(),
        scheduler: config["training"]["scheduler"]
            .as_str()
            .unwrap_or("constant")
            .to_string(),
        mixed_precision: config["training"]["mixed_precision"]
            .as_bool()
            .unwrap_or(false),
        grad_clip: config["training"]["grad_clip"].as_f64(),
        save_every: config["training"]["save_every"].as_u64().unwrap_or(5) as usize,
        distributed: config["training"]["distributed"].as_bool().unwrap_or(false),
    })
}

async fn initialize_model(config: &TrainingConfig, device: &str) -> Result<TrainingModel> {
    info!(
        "Initializing {} model on device: {}",
        config.model_name, device
    );

    let mut rng = thread_rng();

    // Parameters are initialized with small random weights; convolutional
    // kernels are stored flattened as 2-D matrices (out_channels x in_features).
    let mut parameters = Vec::new();

    match config.model_name.as_str() {
        "resnet18" => {
            // conv1: 64 filters over 3x7x7 input patches
            let conv1_weights: Vec<f32> = (0..64 * 3 * 7 * 7)
                .map(|_| rng.gen_range(-0.1..0.1))
                .collect();
            parameters.push(Array2::from_shape_vec((64, 3 * 7 * 7), conv1_weights)?);

            // conv2: 128 filters over 64x3x3 input patches
            let conv2_weights: Vec<f32> = (0..128 * 64 * 3 * 3)
                .map(|_| rng.gen_range(-0.05..0.05))
                .collect();
            parameters.push(Array2::from_shape_vec((128, 64 * 3 * 3), conv2_weights)?);

            // final fully connected layer: 512 features -> 1000 classes
            let fc_weights: Vec<f32> = (0..1000 * 512)
                .map(|_| rng.gen_range(-0.01..0.01))
                .collect();
            parameters.push(Array2::from_shape_vec((1000, 512), fc_weights)?);
        }
        "mobilenet" => {
            let conv_weights: Vec<f32> = (0..32 * 3 * 3 * 3)
                .map(|_| rng.gen_range(-0.1..0.1))
                .collect();
            parameters.push(Array2::from_shape_vec((32, 3 * 3 * 3), conv_weights)?);

            let fc_weights: Vec<f32> = (0..1000 * 1024)
                .map(|_| rng.gen_range(-0.01..0.01))
                .collect();
            parameters.push(Array2::from_shape_vec((1000, 1024), fc_weights)?);
        }
        _ => {
            // Fallback: a single fully connected layer
            let weights: Vec<f32> = (0..1000 * 512).map(|_| rng.gen_range(-0.1..0.1)).collect();
            parameters.push(Array2::from_shape_vec((1000, 512), weights)?);
        }
    }

    let parameter_count: usize = parameters.iter().map(|p| p.len()).sum();

    // Simulate device setup / weight transfer latency.
    tokio::time::sleep(std::time::Duration::from_millis(500)).await;

    Ok(TrainingModel {
        parameters,
        parameter_count,
        architecture: config.model_name.clone(),
        device: device.to_string(),
    })
}

async fn load_training_datasets(
    data_path: &PathBuf,
    batch_size: usize,
) -> Result<(TrainingDataset, TrainingDataset)> {
    info!(
        "Loading training datasets from {} with batch size {}",
        data_path.display(),
        batch_size
    );

    let mut rng = thread_rng();

    // Generate synthetic 3x32x32 images with random labels from 10 classes.
    let train_size = 1000;
    let mut train_samples = Vec::new();
    let mut train_labels = Vec::new();

    for _ in 0..train_size {
        let image_data: Vec<f32> = (0..3 * 32 * 32).map(|_| rng.gen_range(0.0..1.0)).collect();
        let image_array = Array3::from_shape_vec((3, 32, 32), image_data)?;
        train_samples.push(image_array);
        train_labels.push(rng.gen_range(0..10));
    }

    let val_size = 200;
    let mut val_samples = Vec::new();
    let mut val_labels = Vec::new();

    for _ in 0..val_size {
        let image_data: Vec<f32> = (0..3 * 32 * 32).map(|_| rng.gen_range(0.0..1.0)).collect();
        let image_array = Array3::from_shape_vec((3, 32, 32), image_data)?;
        val_samples.push(image_array);
        val_labels.push(rng.gen_range(0..10));
    }

    // Simulate I/O latency for dataset loading.
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;

    let train_dataset = TrainingDataset {
        samples: train_samples,
        labels: train_labels,
        batch_size,
    };

    let val_dataset = TrainingDataset {
        samples: val_samples,
        labels: val_labels,
        batch_size,
    };

    Ok((train_dataset, val_dataset))
}

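/// Builds the optimizer state. Supported values: "adam", "adamw", "sgd", and
/// "rmsprop"; anything else is rejected. Every supported optimizer stores its
/// hyperparameters in `state` and gets one zeroed momentum buffer per
/// parameter tensor.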
fn initialize_optimizer(
    optimizer_type: &str,
    learning_rate: f64,
    model: &TrainingModel,
) -> Result<TrainingOptimizer> {
    info!(
        "Initializing {} optimizer with lr={}",
        optimizer_type, learning_rate
    );

    let mut state = HashMap::new();

    match optimizer_type {
        "adam" => {
            state.insert("beta1".to_string(), serde_json::json!(0.9));
            state.insert("beta2".to_string(), serde_json::json!(0.999));
            state.insert("eps".to_string(), serde_json::json!(1e-8));
        }
        "adamw" => {
            state.insert("beta1".to_string(), serde_json::json!(0.9));
            state.insert("beta2".to_string(), serde_json::json!(0.999));
            state.insert("eps".to_string(), serde_json::json!(1e-8));
            state.insert("weight_decay".to_string(), serde_json::json!(0.01));
        }
        "sgd" => {
            state.insert("momentum".to_string(), serde_json::json!(0.9));
            state.insert("dampening".to_string(), serde_json::json!(0.0));
            state.insert("weight_decay".to_string(), serde_json::json!(0.0));
        }
        "rmsprop" => {
            state.insert("alpha".to_string(), serde_json::json!(0.99));
            state.insert("eps".to_string(), serde_json::json!(1e-8));
            state.insert("weight_decay".to_string(), serde_json::json!(0.0));
        }
        _ => {
            return Err(anyhow::anyhow!("Unsupported optimizer: {}", optimizer_type));
        }
    }

    // One zero-initialized momentum buffer per parameter tensor, shared by
    // all supported optimizers.
    let momentum_buffers: Vec<Array2<f32>> = model
        .parameters
        .iter()
        .map(|param| Array2::zeros(param.dim()))
        .collect();

    Ok(TrainingOptimizer {
        optimizer_type: optimizer_type.to_string(),
        learning_rate,
        state,
        momentum_buffers,
    })
}

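/// Builds the learning rate schedule. "constant" keeps the base LR; "step"
/// multiplies the LR by `gamma` every `step_size` epochs; "cosine" anneals
/// from the base LR down to `eta_min` following
/// `lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2`.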
fn initialize_scheduler(
    scheduler_type: &str,
    optimizer: &TrainingOptimizer,
) -> Result<LearningRateScheduler> {
    info!("Initializing {} learning rate scheduler", scheduler_type);

    let mut params = HashMap::new();

    match scheduler_type {
        "constant" => {
            // No parameters needed; the learning rate never changes.
        }
        "step" => {
            params.insert("step_size".to_string(), 5.0);
            params.insert("gamma".to_string(), 0.5);
        }
        "cosine" => {
            params.insert("t_max".to_string(), 10.0);
            params.insert("eta_min".to_string(), 0.0001);
        }
        _ => {
            return Err(anyhow::anyhow!("Unsupported scheduler: {}", scheduler_type));
        }
    }

    Ok(LearningRateScheduler {
        scheduler_type: scheduler_type.to_string(),
        base_lr: optimizer.learning_rate,
        current_lr: optimizer.learning_rate,
        params,
    })
}

fn generate_run_id() -> String {
    let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S");
    let mut rng = thread_rng();
    let random_suffix: String = (0..4)
        .map(|_| char::from(b'a' + rng.gen_range(0..26)))
        .collect();
    format!("run_{}_{}", timestamp, random_suffix)
}

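/// Core training loop: for each epoch, runs a training pass and a validation
/// pass, steps the LR scheduler, records metrics, checkpoints the best model
/// (plus a periodic checkpoint every `save_every` epochs), and early-stops
/// after `patience` epochs without validation improvement.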
async fn execute_training_loop(
    model: &mut TrainingModel,
    optimizer: &mut TrainingOptimizer,
    scheduler: &mut LearningRateScheduler,
    train_dataset: &TrainingDataset,
    val_dataset: &TrainingDataset,
    args: &StartArgs,
    run_dir: &PathBuf,
) -> Result<TrainingResults> {
    info!("Starting training loop for {} epochs", args.epochs);

    // Reuse the run directory's name as the run ID so the reported ID matches
    // the directory the caller created; fall back to a fresh ID otherwise.
    let run_id = run_dir
        .file_name()
        .and_then(|name| name.to_str())
        .map(|name| name.to_string())
        .unwrap_or_else(generate_run_id);
    let mut training_metrics = TrainingMetrics {
        run_id: run_id.clone(),
        train_losses: Vec::new(),
        val_losses: Vec::new(),
        val_accuracies: Vec::new(),
        learning_rates: Vec::new(),
        epoch_times: Vec::new(),
    };

    let mut best_val_accuracy = 0.0;
    let mut epochs_without_improvement = 0;
    // Early stopping: give up after this many epochs without improvement.
    let patience = 5;

    let training_start = Instant::now();
    let total_batches =
        (train_dataset.samples.len() + train_dataset.batch_size - 1) / train_dataset.batch_size;

    for epoch in 0..args.epochs {
        let epoch_start = Instant::now();
        info!("Starting epoch {}/{}", epoch + 1, args.epochs);

        let pb =
            progress::create_progress_bar(total_batches as u64, &format!("Epoch {}", epoch + 1));

        let train_loss = train_epoch(model, optimizer, train_dataset, args, &pb).await?;
        pb.finish_with_message(format!("Epoch {} training completed", epoch + 1));

        let (val_loss, val_accuracy) = validate_epoch(model, val_dataset, args).await?;

        update_learning_rate(scheduler, epoch, val_loss)?;

        training_metrics.train_losses.push(train_loss);
        training_metrics.val_losses.push(val_loss);
        training_metrics.val_accuracies.push(val_accuracy);
        training_metrics.learning_rates.push(scheduler.current_lr);
        training_metrics
            .epoch_times
            .push(epoch_start.elapsed().as_secs_f64());

        if val_accuracy > best_val_accuracy {
            best_val_accuracy = val_accuracy;
            epochs_without_improvement = 0;

            let checkpoint_path = run_dir.join("best_model.ckpt");
            save_checkpoint(
                model,
                optimizer,
                epoch,
                best_val_accuracy,
                args,
                &checkpoint_path,
                &run_id,
            )
            .await?;
            info!(
                "New best model saved with validation accuracy: {:.4}",
                val_accuracy
            );
        } else {
            epochs_without_improvement += 1;
        }

        if (epoch + 1) % args.save_every == 0 {
            let checkpoint_path = run_dir.join(format!("checkpoint_epoch_{}.ckpt", epoch + 1));
            save_checkpoint(
                model,
                optimizer,
                epoch,
                best_val_accuracy,
                args,
                &checkpoint_path,
                &run_id,
            )
            .await?;
        }

        let metrics_path = run_dir.join("training_metrics.json");
        save_training_metrics(&training_metrics, &metrics_path).await?;

        output::print_info(&format!(
            "Epoch {}/{} - Train Loss: {:.6}, Val Loss: {:.6}, Val Acc: {:.2}%, LR: {:.6}",
            epoch + 1,
            args.epochs,
            train_loss,
            val_loss,
            val_accuracy * 100.0,
            scheduler.current_lr
        ));

        if epochs_without_improvement >= patience {
            info!(
                "Early stopping triggered after {} epochs without improvement",
                patience
            );
            break;
        }
    }

    let total_duration = training_start.elapsed();
    let final_train_loss = training_metrics.train_losses.last().copied().unwrap_or(0.0);
    let final_val_accuracy = training_metrics
        .val_accuracies
        .last()
        .copied()
        .unwrap_or(0.0);
    let converged = epochs_without_improvement < patience;

    Ok(TrainingResults {
        run_id,
        epochs_completed: training_metrics.train_losses.len(),
        final_train_loss,
        final_val_accuracy,
        best_val_accuracy,
        converged,
        duration: total_duration,
    })
}

async fn train_epoch(
    model: &mut TrainingModel,
    optimizer: &mut TrainingOptimizer,
    dataset: &TrainingDataset,
    args: &StartArgs,
    progress_bar: &indicatif::ProgressBar,
) -> Result<f64> {
    let num_batches = (dataset.samples.len() + dataset.batch_size - 1) / dataset.batch_size;
    let mut total_loss = 0.0;

    for batch_idx in 0..num_batches {
        let start_idx = batch_idx * dataset.batch_size;
        let end_idx = std::cmp::min(start_idx + dataset.batch_size, dataset.samples.len());

        let batch_loss = forward_pass_batch(model, dataset, start_idx, end_idx).await?;

        backward_pass_and_update(model, optimizer, batch_loss, args).await?;

        total_loss += batch_loss;

        progress_bar.set_position(batch_idx as u64 + 1);

        // Simulate per-batch compute time.
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
    }

    Ok(total_loss / num_batches as f64)
}

async fn validate_epoch(
    model: &TrainingModel,
    dataset: &TrainingDataset,
    _args: &StartArgs,
) -> Result<(f64, f64)> {
    let num_batches = (dataset.samples.len() + dataset.batch_size - 1) / dataset.batch_size;
    let mut total_loss = 0.0;
    let mut correct_predictions = 0;
    let mut total_predictions = 0;

    for batch_idx in 0..num_batches {
        let start_idx = batch_idx * dataset.batch_size;
        let end_idx = std::cmp::min(start_idx + dataset.batch_size, dataset.samples.len());

        let (batch_loss, batch_correct) =
            validate_batch(model, dataset, start_idx, end_idx).await?;

        total_loss += batch_loss;
        correct_predictions += batch_correct;
        total_predictions += end_idx - start_idx;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }

    let avg_loss = total_loss / num_batches as f64;
    let accuracy = correct_predictions as f64 / total_predictions as f64;

    Ok((avg_loss, accuracy))
}

async fn forward_pass_batch(
    model: &TrainingModel,
    dataset: &TrainingDataset,
    start_idx: usize,
    end_idx: usize,
) -> Result<f64> {
    let mut rng = thread_rng();

    let batch_size = end_idx - start_idx;
    let mut total_loss = 0.0;

    for i in start_idx..end_idx {
        let input = &dataset.samples[i];
        let target = dataset.labels[i];

        // Flatten the image and truncate so it can feed the first layer.
        let flattened_size = std::cmp::min(input.len(), 1000);
        let mut activations = Array1::from_vec(
            input.as_slice().expect("input array should be contiguous")[..flattened_size].to_vec(),
        );

        for param in &model.parameters {
            if activations.len() == param.ncols() {
                // Keep at most 10 outputs to stand in for class logits.
                let mut output = Array1::zeros(param.nrows().min(10));

                for (j, row) in param.rows().into_iter().enumerate().take(output.len()) {
                    let dot_product: f32 =
                        row.iter().zip(activations.iter()).map(|(w, a)| w * a).sum();
                    output[j] = dot_product;
                }

                // ReLU activation
                activations = output.map(|x| x.max(0.0));
            }
        }

        let predicted_class = activations
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);

        // Simulated loss: small for a correct prediction, large for a miss.
        let loss = if predicted_class == target {
            0.1 + rng.gen_range(0.0..0.1)
        } else {
            1.0 + rng.gen_range(0.0..1.0)
        };

        total_loss += loss as f64;
    }

    Ok(total_loss / batch_size as f64)
}

async fn backward_pass_and_update(
    model: &mut TrainingModel,
    optimizer: &mut TrainingOptimizer,
    loss: f64,
    args: &StartArgs,
) -> Result<()> {
    let mut rng = thread_rng();

    for (param_idx, param) in model.parameters.iter_mut().enumerate() {
        // Simulated backprop: random gradients scaled by the batch loss.
        let gradient = param.map(|_| rng.gen_range(-0.01..0.01) * (loss as f32));

        // Optional gradient clipping by L2 norm.
        let clipped_gradient = if let Some(clip_value) = args.grad_clip {
            let grad_norm: f32 = gradient.iter().map(|g| g * g).sum::<f32>().sqrt();
            if grad_norm > clip_value as f32 {
                gradient.map(|g| g * (clip_value as f32) / grad_norm)
            } else {
                gradient
            }
        } else {
            gradient
        };

        match optimizer.optimizer_type.as_str() {
            "adam" => {
                apply_adam_update(param, &clipped_gradient, optimizer, param_idx)?;
            }
            "sgd" => {
                apply_sgd_update(param, &clipped_gradient, optimizer, param_idx)?;
            }
            _ => {
                // Plain gradient descent fallback for other optimizers.
                *param = &*param - &(clipped_gradient.map(|g| g * optimizer.learning_rate as f32));
            }
        }
    }

    Ok(())
}

fn apply_adam_update(
    param: &mut Array2<f32>,
    gradient: &Array2<f32>,
    optimizer: &mut TrainingOptimizer,
    param_idx: usize,
) -> Result<()> {
    let beta1 = optimizer.state["beta1"].as_f64().unwrap_or(0.9) as f32;
    let _beta2 = optimizer.state["beta2"].as_f64().unwrap_or(0.999) as f32;
    let _eps = optimizer.state["eps"].as_f64().unwrap_or(1e-8) as f32;
    let lr = optimizer.learning_rate as f32;

    if param_idx < optimizer.momentum_buffers.len() {
        let momentum = &mut optimizer.momentum_buffers[param_idx];

        // Simplified Adam: only the first-moment estimate is tracked, so
        // beta2 and eps are currently unused.
        *momentum = momentum.map(|m| m * beta1) + gradient.map(|g| g * (1.0 - beta1));

        *param = &*param - &momentum.map(|m| m * lr);
    }

    Ok(())
}

fn apply_sgd_update(
    param: &mut Array2<f32>,
    gradient: &Array2<f32>,
    optimizer: &mut TrainingOptimizer,
    param_idx: usize,
) -> Result<()> {
    let momentum = optimizer.state["momentum"].as_f64().unwrap_or(0.9) as f32;
    let lr = optimizer.learning_rate as f32;

    if param_idx < optimizer.momentum_buffers.len() {
        let momentum_buffer = &mut optimizer.momentum_buffers[param_idx];

        // v <- momentum * v + g, then step against the velocity.
        *momentum_buffer = momentum_buffer.map(|m| m * momentum) + gradient;

        *param = &*param - &momentum_buffer.map(|m| m * lr);
    }

    Ok(())
}

async fn validate_batch(
    model: &TrainingModel,
    dataset: &TrainingDataset,
    start_idx: usize,
    end_idx: usize,
) -> Result<(f64, usize)> {
    let mut total_loss = 0.0;
    let mut correct_predictions = 0;

    for i in start_idx..end_idx {
        let input = &dataset.samples[i];
        let target = dataset.labels[i];

        let flattened_size = std::cmp::min(input.len(), 1000);
        let mut activations = Array1::from_vec(
            input.as_slice().expect("input array should be contiguous")[..flattened_size].to_vec(),
        );

        for param in &model.parameters {
            if activations.len() == param.ncols() {
                let mut output = Array1::zeros(param.nrows().min(10));

                for (j, row) in param.rows().into_iter().enumerate().take(output.len()) {
                    let dot_product: f32 =
                        row.iter().zip(activations.iter()).map(|(w, a)| w * a).sum();
                    output[j] = dot_product;
                }

                activations = output.map(|x| x.max(0.0));
            }
        }

        let predicted_class = activations
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);

        if predicted_class == target {
            correct_predictions += 1;
            total_loss += 0.1; // fixed low loss for a correct prediction
        } else {
            total_loss += 1.0; // fixed high loss for a miss
        }
    }

    let batch_size = end_idx - start_idx;
    Ok((total_loss / batch_size as f64, correct_predictions))
}

fn update_learning_rate(
    scheduler: &mut LearningRateScheduler,
    epoch: usize,
    _val_loss: f64,
) -> Result<()> {
    match scheduler.scheduler_type.as_str() {
        "constant" => {
            // Nothing to do: the learning rate stays at base_lr.
        }
        "step" => {
            let step_size = scheduler.params["step_size"] as usize;
            let gamma = scheduler.params["gamma"];

            if (epoch + 1) % step_size == 0 {
                scheduler.current_lr *= gamma;
            }
        }
        "cosine" => {
            let t_max = scheduler.params["t_max"];
            let eta_min = scheduler.params["eta_min"];

            scheduler.current_lr = eta_min
                + (scheduler.base_lr - eta_min)
                    * (1.0 + (std::f64::consts::PI * epoch as f64 / t_max).cos())
                    / 2.0;
        }
        _ => {}
    }

    Ok(())
}

async fn save_checkpoint(
    model: &TrainingModel,
    optimizer: &TrainingOptimizer,
    epoch: usize,
    best_val_accuracy: f64,
    args: &StartArgs,
    checkpoint_path: &PathBuf,
    run_id: &str,
) -> Result<()> {
    info!("Saving checkpoint to {}", checkpoint_path.display());

    let model_state = serialize_model_state(model)?;
    let optimizer_state = serialize_optimizer_state(optimizer)?;

    let training_config = TrainingConfig {
        model_name: model.architecture.clone(),
        model_config: HashMap::new(),
        config_path: args.config.clone(),
        data_path: args.data.clone(),
        total_epochs: args.epochs,
        batch_size: args.batch_size,
        learning_rate: args.learning_rate,
        device: args.device.clone(),
        optimizer: args.optimizer.clone(),
        scheduler: args.scheduler.clone(),
        mixed_precision: args.mixed_precision,
        grad_clip: args.grad_clip,
        save_every: args.save_every,
        distributed: args.distributed,
    };

    let checkpoint = TrainingCheckpoint {
        run_id: run_id.to_string(),
        epoch,
        model_state,
        optimizer_state,
        optimizer_type: optimizer.optimizer_type.clone(),
        best_val_accuracy,
        training_config,
        output_dir: args.output_dir.clone(),
        timestamp: chrono::Local::now().to_rfc3339(),
    };

    let checkpoint_data = serde_json::to_vec_pretty(&checkpoint)?;
    tokio::fs::write(checkpoint_path, checkpoint_data).await?;

    Ok(())
}

async fn save_training_metrics(metrics: &TrainingMetrics, metrics_path: &PathBuf) -> Result<()> {
    let metrics_data = serde_json::to_vec_pretty(metrics)?;
    tokio::fs::write(metrics_path, metrics_data).await?;
    Ok(())
}

fn serialize_model_state(model: &TrainingModel) -> Result<Vec<u8>> {
    // Concatenate every parameter tensor as raw little-endian f32 bytes.
    let mut serialized = Vec::new();

    for param in &model.parameters {
        let param_bytes = param
            .as_slice()
            .expect("parameter array should be contiguous");
        let bytes: Vec<u8> = param_bytes
            .iter()
            .flat_map(|&f| f.to_le_bytes())
            .collect();
        serialized.extend_from_slice(&bytes);
    }

    Ok(serialized)
}

fn serialize_optimizer_state(optimizer: &TrainingOptimizer) -> Result<Vec<u8>> {
    let state_json = serde_json::to_vec(&optimizer.state)?;
    Ok(state_json)
}

async fn load_checkpoint(checkpoint_path: &PathBuf) -> Result<TrainingCheckpoint> {
    let checkpoint_data = tokio::fs::read(checkpoint_path).await?;
    let checkpoint: TrainingCheckpoint = serde_json::from_slice(&checkpoint_data)?;
    Ok(checkpoint)
}

async fn restore_model_from_checkpoint(checkpoint: &TrainingCheckpoint) -> Result<TrainingModel> {
    info!("Restoring model from checkpoint");

    // Placeholder restore: the serialized bytes carry no shape metadata, so
    // the weights are re-initialized randomly rather than deserialized.
    let mut rng = thread_rng();
    let weights: Vec<f32> = (0..1000 * 512).map(|_| rng.gen_range(-0.1..0.1)).collect();
    let parameters = vec![Array2::from_shape_vec((1000, 512), weights)?];

    Ok(TrainingModel {
        parameters,
        parameter_count: 1000 * 512,
        architecture: "restored_model".to_string(),
        device: checkpoint.training_config.device.clone(),
    })
}

fn restore_optimizer_from_checkpoint(checkpoint: &TrainingCheckpoint) -> Result<TrainingOptimizer> {
    info!("Restoring optimizer from checkpoint");

    let state: HashMap<String, serde_json::Value> =
        serde_json::from_slice(&checkpoint.optimizer_state)?;

    Ok(TrainingOptimizer {
        optimizer_type: checkpoint.optimizer_type.clone(),
        learning_rate: checkpoint.training_config.learning_rate,
        state,
        // Momentum buffers are not persisted; they start out empty and the
        // update functions skip parameters without a buffer.
        momentum_buffers: Vec::new(),
    })
}

async fn load_training_metrics(metrics_path: &PathBuf) -> Result<TrainingMetrics> {
    let metrics_data = tokio::fs::read(metrics_path).await?;
    let metrics: TrainingMetrics = serde_json::from_slice(&metrics_data)?;
    Ok(metrics)
}

fn display_training_metrics(metrics: &TrainingMetrics) -> Result<()> {
    output::print_info(&format!("Run ID: {}", metrics.run_id));
    output::print_info(&format!("Epochs completed: {}", metrics.train_losses.len()));

    if let (Some(&final_train_loss), Some(&final_val_loss), Some(&final_val_acc)) = (
        metrics.train_losses.last(),
        metrics.val_losses.last(),
        metrics.val_accuracies.last(),
    ) {
        output::print_info(&format!("Final training loss: {:.6}", final_train_loss));
        output::print_info(&format!("Final validation loss: {:.6}", final_val_loss));
        output::print_info(&format!(
            "Final validation accuracy: {:.2}%",
            final_val_acc * 100.0
        ));
    }

    if let Some(&best_val_acc) = metrics
        .val_accuracies
        .iter()
        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    {
        output::print_info(&format!(
            "Best validation accuracy: {:.2}%",
            best_val_acc * 100.0
        ));
    }

    Ok(())
}

async fn follow_training_logs(_log_path: &PathBuf) -> Result<()> {
    output::print_info("Log following not implemented yet");
    Ok(())
}

async fn display_recent_logs(log_path: &PathBuf) -> Result<()> {
    let log_content = tokio::fs::read_to_string(log_path).await?;
    let lines: Vec<&str> = log_content.lines().collect();
    // Show the last 20 lines in their original order.
    let recent_lines = lines.iter().rev().take(20).rev();

    for line in recent_lines {
        println!("{}", line);
    }

    Ok(())
}

async fn graceful_stop_training(run_id: &str) -> Result<bool> {
    info!("Attempting graceful stop for run: {}", run_id);
    // Simulated: a real implementation would signal the training process.
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    Ok(true)
}

async fn force_stop_training(run_id: &str) -> Result<bool> {
    warn!("Force stopping run: {}", run_id);
    // Simulated: a real implementation would kill the training process.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    Ok(true)
}
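
// A few sanity tests for the synchronous helpers above. This is a minimal
// sketch: the async/IO paths (checkpointing, dataset loading) are not covered,
// and the tests assume the simplified optimizer/scheduler semantics
// implemented in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn run_ids_have_timestamp_and_random_suffix() {
        let id = generate_run_id();
        // "run_" + "%Y%m%d_%H%M%S" (15 chars) + "_" + 4 random letters
        assert!(id.starts_with("run_"));
        assert_eq!(id.len(), 4 + 15 + 1 + 4);
    }

    #[test]
    fn unsupported_optimizer_is_rejected() {
        let model = TrainingModel {
            parameters: Vec::new(),
            parameter_count: 0,
            architecture: "test".to_string(),
            device: "cpu".to_string(),
        };
        assert!(initialize_optimizer("lamb", 0.001, &model).is_err());
    }

    #[test]
    fn step_scheduler_decays_every_step_size_epochs() {
        let optimizer = TrainingOptimizer {
            optimizer_type: "sgd".to_string(),
            learning_rate: 0.1,
            state: HashMap::new(),
            momentum_buffers: Vec::new(),
        };
        let mut scheduler = initialize_scheduler("step", &optimizer).unwrap();

        // step_size = 5, gamma = 0.5: the LR is unchanged for epoch indices
        // 0..4 and halves once epoch index 4 (the 5th epoch) completes.
        for epoch in 0..4 {
            update_learning_rate(&mut scheduler, epoch, 0.0).unwrap();
            assert!((scheduler.current_lr - 0.1).abs() < 1e-12);
        }
        update_learning_rate(&mut scheduler, 4, 0.0).unwrap();
        assert!((scheduler.current_lr - 0.05).abs() < 1e-12);
    }

    #[test]
    fn model_state_serializes_four_bytes_per_parameter() {
        let model = TrainingModel {
            parameters: vec![Array2::zeros((2, 3))],
            parameter_count: 6,
            architecture: "test".to_string(),
            device: "cpu".to_string(),
        };
        let bytes = serialize_model_state(&model).unwrap();
        assert_eq!(bytes.len(), 6 * std::mem::size_of::<f32>());
    }
}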