1use crate::enhanced_distributed_training::{DistributedConfig, PerformanceMetrics};
51use serde::{Deserialize, Serialize};
52use std::collections::{HashMap, VecDeque};
53use std::path::PathBuf;
54use std::sync::{Arc, Mutex};
55use std::time::{Duration, Instant, SystemTime};
56use trustformers_core::errors::Result;
57use trustformers_core::tensor::Tensor;
58
/// Configuration for [`AutoScaler`]: node bounds, strategy selection, and
/// the thresholds/cooldown that gate scaling decisions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoScalerConfig {
    /// Lower bound on active nodes; the scaler never shrinks below this.
    pub min_nodes: usize,
    /// Upper bound on active nodes; the scaler never grows beyond this.
    pub max_nodes: usize,
    /// Policy used to turn metrics into scaling decisions.
    pub strategy: ScalingStrategy,
    /// Average GPU utilization above which a scale-up is considered.
    pub scale_up_threshold: f32,
    /// Average GPU utilization below which a scale-down is considered.
    pub scale_down_threshold: f32,
    /// Minimum wall-clock time between two scaling actions.
    pub scaling_cooldown: Duration,
    /// Enables the workload-forecast path of [`ScalingStrategy::Predictive`].
    pub predictive_scaling: bool,
    /// Weight (0..1) given to cost considerations in cost-optimized scaling.
    pub cost_priority: f32,
}
79
80impl Default for AutoScalerConfig {
81 fn default() -> Self {
82 Self {
83 min_nodes: 1,
84 max_nodes: 16,
85 strategy: ScalingStrategy::Performance,
86 scale_up_threshold: 0.85,
87 scale_down_threshold: 0.6,
88 scaling_cooldown: Duration::from_secs(300), predictive_scaling: true,
90 cost_priority: 0.3, }
92 }
93}
94
/// Policy used by [`AutoScaler`] to decide when to add or remove nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScalingStrategy {
    /// Scale on average GPU utilization versus the configured thresholds.
    Performance,
    /// Scale on throughput relative to a fixed reference rate.
    QueueBased,
    /// Scale on the workload predictor's forecast; falls back to
    /// performance-based scaling when prediction is disabled.
    Predictive,
    /// Scale only when the cost model's benefit estimate clears the bar.
    CostOptimized,
    /// Named custom policy; currently a no-op placeholder.
    Custom(String),
}
109
/// Automatic cluster scaler: consumes performance metrics, applies the
/// configured [`ScalingStrategy`], and tracks the resulting scaling events.
pub struct AutoScaler {
    config: AutoScalerConfig,
    // Number of nodes the scaler currently believes are active.
    current_nodes: usize,
    // Timestamp of the last executed scale-up/scale-down (cooldown anchor).
    last_scaling_action: Instant,
    // Bounded (1000-entry) rolling window of observed metrics.
    performance_history: VecDeque<PerformanceMetrics>,
    // Every scaling action taken, in order.
    scaling_history: Vec<ScalingEvent>,
    workload_predictor: WorkloadPredictor,
    cost_optimizer: CostOptimizer,
}
120
121impl AutoScaler {
122 pub fn new(config: AutoScalerConfig) -> Self {
123 Self {
124 current_nodes: config.min_nodes,
125 config,
126 last_scaling_action: Instant::now(),
127 performance_history: VecDeque::with_capacity(1000),
128 scaling_history: Vec::new(),
129 workload_predictor: WorkloadPredictor::new(),
130 cost_optimizer: CostOptimizer::new(),
131 }
132 }
133
134 pub fn with_min_nodes(mut self, min_nodes: usize) -> Self {
136 self.config.min_nodes = min_nodes;
137 if self.current_nodes < min_nodes {
139 self.current_nodes = min_nodes;
140 }
141 self
142 }
143
144 pub fn with_max_nodes(mut self, max_nodes: usize) -> Self {
145 self.config.max_nodes = max_nodes;
146 self
147 }
148
149 pub fn with_scaling_strategy(mut self, strategy: ScalingStrategy) -> Self {
150 self.config.strategy = strategy;
151 self
152 }
153
154 pub fn with_scale_up_threshold(mut self, threshold: f32) -> Self {
155 self.config.scale_up_threshold = threshold;
156 self
157 }
158
159 pub fn with_scale_down_threshold(mut self, threshold: f32) -> Self {
160 self.config.scale_down_threshold = threshold;
161 self
162 }
163
164 pub fn update_and_scale(&mut self, metrics: &PerformanceMetrics) -> Result<ScalingDecision> {
166 self.performance_history.push_back(metrics.clone());
168 if self.performance_history.len() > 1000 {
169 self.performance_history.pop_front();
170 }
171
172 if self.last_scaling_action.elapsed() < self.config.scaling_cooldown {
174 return Ok(ScalingDecision::NoAction);
175 }
176
177 let avg_utilization =
179 metrics.gpu_utilization.iter().sum::<f32>() / metrics.gpu_utilization.len() as f32;
180 let _avg_memory =
181 metrics.memory_usage.iter().sum::<f32>() / metrics.memory_usage.len() as f32;
182
183 let decision = match &self.config.strategy {
185 ScalingStrategy::Performance => self.performance_based_scaling(avg_utilization)?,
186 ScalingStrategy::QueueBased => self.queue_based_scaling(metrics)?,
187 ScalingStrategy::Predictive => self.predictive_scaling(metrics)?,
188 ScalingStrategy::CostOptimized => {
189 self.cost_optimized_scaling(avg_utilization, metrics)?
190 },
191 ScalingStrategy::Custom(_) => self.custom_scaling(metrics)?,
192 };
193
194 match &decision {
196 ScalingDecision::ScaleUp(nodes) => {
197 self.execute_scale_up(*nodes)?;
198 },
199 ScalingDecision::ScaleDown(nodes) => {
200 self.execute_scale_down(*nodes)?;
201 },
202 ScalingDecision::NoAction => {},
203 }
204
205 Ok(decision)
206 }
207
208 fn performance_based_scaling(&self, avg_utilization: f32) -> Result<ScalingDecision> {
209 if avg_utilization > self.config.scale_up_threshold
210 && self.current_nodes < self.config.max_nodes
211 {
212 let target_utilization = 0.75; let utilization_ratio = avg_utilization / target_utilization;
215 let nodes_to_add =
216 ((utilization_ratio - 1.0) * self.current_nodes as f32).ceil() as usize;
217 let nodes_to_add = nodes_to_add.min(self.config.max_nodes - self.current_nodes);
218
219 Ok(ScalingDecision::ScaleUp(nodes_to_add))
220 } else if avg_utilization < self.config.scale_down_threshold
221 && self.current_nodes > self.config.min_nodes
222 {
223 let target_utilization = 0.8; let required_nodes =
226 (avg_utilization * self.current_nodes as f32 / target_utilization).ceil() as usize;
227 let nodes_to_remove = self.current_nodes.saturating_sub(required_nodes);
228 let nodes_to_remove = nodes_to_remove.min(self.current_nodes - self.config.min_nodes);
229
230 if nodes_to_remove > 0 {
231 Ok(ScalingDecision::ScaleDown(nodes_to_remove))
232 } else {
233 Ok(ScalingDecision::NoAction)
234 }
235 } else {
236 Ok(ScalingDecision::NoAction)
237 }
238 }
239
240 fn queue_based_scaling(&self, metrics: &PerformanceMetrics) -> Result<ScalingDecision> {
241 let throughput_ratio = metrics.throughput / 1000.0; if throughput_ratio < 0.5 && self.current_nodes < self.config.max_nodes {
245 Ok(ScalingDecision::ScaleUp(1))
246 } else if throughput_ratio > 2.0 && self.current_nodes > self.config.min_nodes {
247 Ok(ScalingDecision::ScaleDown(1))
248 } else {
249 Ok(ScalingDecision::NoAction)
250 }
251 }
252
253 fn predictive_scaling(&mut self, metrics: &PerformanceMetrics) -> Result<ScalingDecision> {
254 if !self.config.predictive_scaling {
255 return self.performance_based_scaling(
256 metrics.gpu_utilization.iter().sum::<f32>() / metrics.gpu_utilization.len() as f32,
257 );
258 }
259
260 self.workload_predictor.update_metrics(metrics);
262
263 let predicted_load = self.workload_predictor.predict_workload(Duration::from_secs(600))?;
265
266 if predicted_load > self.config.scale_up_threshold * 1.1 && self.current_nodes < self.config.max_nodes
269 {
270 let nodes_to_add =
271 ((predicted_load - 0.75) * self.current_nodes as f32).ceil() as usize;
272 Ok(ScalingDecision::ScaleUp(
273 nodes_to_add.min(self.config.max_nodes - self.current_nodes),
274 ))
275 } else if predicted_load < self.config.scale_down_threshold * 0.9 && self.current_nodes > self.config.min_nodes
277 {
278 let target_nodes = (predicted_load / 0.8 * self.current_nodes as f32).ceil() as usize;
279 let nodes_to_remove = self.current_nodes.saturating_sub(target_nodes);
280 if nodes_to_remove > 0 {
281 Ok(ScalingDecision::ScaleDown(
282 nodes_to_remove.min(self.current_nodes - self.config.min_nodes),
283 ))
284 } else {
285 Ok(ScalingDecision::NoAction)
286 }
287 } else {
288 Ok(ScalingDecision::NoAction)
289 }
290 }
291
292 fn cost_optimized_scaling(
293 &mut self,
294 avg_utilization: f32,
295 metrics: &PerformanceMetrics,
296 ) -> Result<ScalingDecision> {
297 let current_cost = self.cost_optimizer.calculate_current_cost(self.current_nodes, metrics);
299
300 if avg_utilization > self.config.scale_up_threshold
302 && self.current_nodes < self.config.max_nodes
303 {
304 let scale_up_cost =
305 self.cost_optimizer.calculate_scale_up_cost(self.current_nodes + 1, metrics);
306 let cost_benefit_ratio = current_cost / scale_up_cost;
307
308 if cost_benefit_ratio > (1.0 - self.config.cost_priority) {
309 Ok(ScalingDecision::ScaleUp(1))
310 } else {
311 Ok(ScalingDecision::NoAction)
312 }
313 } else if avg_utilization < self.config.scale_down_threshold
314 && self.current_nodes > self.config.min_nodes
315 {
316 let scale_down_cost =
317 self.cost_optimizer.calculate_scale_down_cost(self.current_nodes - 1, metrics);
318 let cost_savings = current_cost - scale_down_cost;
319
320 if cost_savings > current_cost * 0.1 {
321 Ok(ScalingDecision::ScaleDown(1))
323 } else {
324 Ok(ScalingDecision::NoAction)
325 }
326 } else {
327 Ok(ScalingDecision::NoAction)
328 }
329 }
330
331 fn custom_scaling(&self, _metrics: &PerformanceMetrics) -> Result<ScalingDecision> {
332 Ok(ScalingDecision::NoAction)
334 }
335
336 fn execute_scale_up(&mut self, nodes: usize) -> Result<()> {
337 println!(
338 "🔼 Scaling up: Adding {} nodes (current: {})",
339 nodes, self.current_nodes
340 );
341
342 self.current_nodes += nodes;
343 self.last_scaling_action = Instant::now();
344
345 self.scaling_history.push(ScalingEvent {
346 timestamp: SystemTime::now(),
347 action: ScalingAction::ScaleUp,
348 nodes_changed: nodes,
349 reason: "Performance threshold exceeded".to_string(),
350 });
351
352 Ok(())
359 }
360
361 fn execute_scale_down(&mut self, nodes: usize) -> Result<()> {
362 println!(
363 "🔽 Scaling down: Removing {} nodes (current: {})",
364 nodes, self.current_nodes
365 );
366
367 self.current_nodes -= nodes;
368 self.last_scaling_action = Instant::now();
369
370 self.scaling_history.push(ScalingEvent {
371 timestamp: SystemTime::now(),
372 action: ScalingAction::ScaleDown,
373 nodes_changed: nodes,
374 reason: "Low utilization detected".to_string(),
375 });
376
377 Ok(())
384 }
385
386 pub fn get_current_nodes(&self) -> usize {
387 self.current_nodes
388 }
389
390 pub fn get_scaling_history(&self) -> &[ScalingEvent] {
391 &self.scaling_history
392 }
393}
394
/// Outcome of one scaling evaluation; the payload is the node delta.
#[derive(Debug, Clone)]
pub enum ScalingDecision {
    ScaleUp(usize),
    ScaleDown(usize),
    NoAction,
}
402
/// Record of one executed scaling action, kept in the scaler's history.
#[derive(Debug, Clone)]
pub struct ScalingEvent {
    pub timestamp: SystemTime,
    pub action: ScalingAction,
    /// Number of nodes added or removed by this event.
    pub nodes_changed: usize,
    /// Human-readable explanation of why the action fired.
    pub reason: String,
}
411
/// Direction of an executed scaling action.
#[derive(Debug, Clone)]
pub enum ScalingAction {
    ScaleUp,
    ScaleDown,
}
417
418pub struct WorkloadPredictor {
420 historical_data: VecDeque<(Instant, f32)>, trend_analyzer: TrendAnalyzer,
422 seasonal_analyzer: SeasonalAnalyzer,
423}
424
impl Default for WorkloadPredictor {
    /// Equivalent to [`WorkloadPredictor::new`].
    fn default() -> Self {
        Self::new()
    }
}
430
431impl WorkloadPredictor {
432 pub fn new() -> Self {
433 Self {
434 historical_data: VecDeque::with_capacity(10000),
435 trend_analyzer: TrendAnalyzer::new(),
436 seasonal_analyzer: SeasonalAnalyzer::new(),
437 }
438 }
439
440 pub fn update_metrics(&mut self, metrics: &PerformanceMetrics) {
441 let avg_utilization =
442 metrics.gpu_utilization.iter().sum::<f32>() / metrics.gpu_utilization.len() as f32;
443 let now = Instant::now();
444
445 self.historical_data.push_back((now, avg_utilization));
446 if self.historical_data.len() > 10000 {
447 self.historical_data.pop_front();
448 }
449
450 self.trend_analyzer.update(avg_utilization);
451 self.seasonal_analyzer.update(now, avg_utilization);
452 }
453
454 pub fn predict_workload(&self, horizon: Duration) -> Result<f32> {
455 if self.historical_data.len() < 10 {
456 return Ok(0.75); }
459
460 let trend_prediction = self.trend_analyzer.predict(horizon)?;
462 let seasonal_prediction = self.seasonal_analyzer.predict(horizon)?;
463
464 let prediction = trend_prediction * 0.7 + seasonal_prediction * 0.3;
466
467 Ok(prediction.clamp(0.0, 1.0))
469 }
470}
471
/// Sliding-window linear-trend estimator over recent utilization samples.
pub struct TrendAnalyzer {
    // Most recent samples, bounded to `window_size` entries.
    values: VecDeque<f32>,
    // Maximum number of samples retained for the regression.
    window_size: usize,
}
477
impl Default for TrendAnalyzer {
    /// Equivalent to [`TrendAnalyzer::new`].
    fn default() -> Self {
        Self::new()
    }
}
483
484impl TrendAnalyzer {
485 pub fn new() -> Self {
486 Self {
487 values: VecDeque::with_capacity(100),
488 window_size: 50,
489 }
490 }
491
492 pub fn update(&mut self, value: f32) {
493 self.values.push_back(value);
494 if self.values.len() > self.window_size {
495 self.values.pop_front();
496 }
497 }
498
499 pub fn predict(&self, _horizon: Duration) -> Result<f32> {
500 if self.values.len() < 10 {
501 return Ok(0.75); }
503
504 let values: Vec<f32> = self.values.iter().cloned().collect();
506 let n = values.len() as f32;
507
508 let x_sum = (0..values.len()).sum::<usize>() as f32;
509 let y_sum = values.iter().sum::<f32>();
510 let xy_sum = values.iter().enumerate().map(|(i, &y)| i as f32 * y).sum::<f32>();
511 let x2_sum = (0..values.len()).map(|i| (i * i) as f32).sum::<f32>();
512
513 let slope = (n * xy_sum - x_sum * y_sum) / (n * x2_sum - x_sum * x_sum);
515 let intercept = (y_sum - slope * x_sum) / n;
516
517 let next_x = values.len() as f32;
519 let prediction = slope * next_x + intercept;
520
521 Ok(prediction)
522 }
523}
524
/// Buckets utilization samples by a pseudo hour-of-day so recurring daily
/// patterns can inform predictions.
pub struct SeasonalAnalyzer {
    /// Samples per pseudo-hour bucket (0–23), each bounded in length.
    hourly_patterns: HashMap<u32, Vec<f32>>,
    last_update: Option<Instant>,
}
530
impl Default for SeasonalAnalyzer {
    /// Equivalent to [`SeasonalAnalyzer::new`].
    fn default() -> Self {
        Self::new()
    }
}
536
537impl SeasonalAnalyzer {
538 pub fn new() -> Self {
539 Self {
540 hourly_patterns: HashMap::new(),
541 last_update: None,
542 }
543 }
544
545 pub fn update(&mut self, timestamp: Instant, value: f32) {
546 let pseudo_hour = (timestamp.elapsed().as_secs() / 3600) % 24;
548
549 self.hourly_patterns.entry(pseudo_hour as u32).or_default().push(value);
550
551 for values in self.hourly_patterns.values_mut() {
553 if values.len() > 100 {
554 values.drain(0..50); }
556 }
557
558 self.last_update = Some(timestamp);
559 }
560
561 pub fn predict(&self, _horizon: Duration) -> Result<f32> {
562 if self.hourly_patterns.is_empty() {
563 return Ok(0.75); }
565
566 let all_values: Vec<f32> =
568 self.hourly_patterns.values().flat_map(|v| v.iter()).cloned().collect();
569
570 if all_values.is_empty() {
571 Ok(0.75)
572 } else {
573 Ok(all_values.iter().sum::<f32>() / all_values.len() as f32)
574 }
575 }
576}
577
/// Wraps the cost model (and a currently unused performance model) to price
/// scaling decisions for [`ScalingStrategy::CostOptimized`].
pub struct CostOptimizer {
    cost_model: CostModel,
    #[allow(dead_code)]
    performance_model: PerformanceModel,
}
584
impl Default for CostOptimizer {
    /// Equivalent to [`CostOptimizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
590
591impl CostOptimizer {
592 pub fn new() -> Self {
593 Self {
594 cost_model: CostModel::new(),
595 performance_model: PerformanceModel::new(),
596 }
597 }
598
599 pub fn calculate_current_cost(&self, nodes: usize, metrics: &PerformanceMetrics) -> f32 {
600 self.cost_model.calculate_cost(nodes, metrics)
601 }
602
603 pub fn calculate_scale_up_cost(&self, new_nodes: usize, metrics: &PerformanceMetrics) -> f32 {
604 self.cost_model.calculate_cost(new_nodes, metrics)
605 }
606
607 pub fn calculate_scale_down_cost(&self, new_nodes: usize, metrics: &PerformanceMetrics) -> f32 {
608 self.cost_model.calculate_cost(new_nodes, metrics)
609 }
610}
611
/// Simple linear pricing model: per-node hourly rate plus a bandwidth term.
pub struct CostModel {
    // Flat hourly price per node.
    cost_per_node_hour: f32,
    // Multiplier applied to bandwidth utilization.
    bandwidth_cost_factor: f32,
}
617
impl Default for CostModel {
    /// Equivalent to [`CostModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
623
624impl CostModel {
625 pub fn new() -> Self {
626 Self {
627 cost_per_node_hour: 3.0, bandwidth_cost_factor: 0.1, }
630 }
631
632 pub fn calculate_cost(&self, nodes: usize, metrics: &PerformanceMetrics) -> f32 {
633 let compute_cost = nodes as f32 * self.cost_per_node_hour;
634 let bandwidth_cost = metrics.bandwidth_utilization * self.bandwidth_cost_factor;
635 compute_cost + bandwidth_cost
636 }
637}
638
/// Linear throughput-scaling model with a fixed per-node efficiency factor.
pub struct PerformanceModel {
    // Fraction of ideal linear speedup each added node contributes.
    scaling_efficiency: f32,
}
643
impl Default for PerformanceModel {
    /// Equivalent to [`PerformanceModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
649
650impl PerformanceModel {
651 pub fn new() -> Self {
652 Self {
653 scaling_efficiency: 0.85, }
655 }
656
657 pub fn predict_performance(&self, nodes: usize, base_throughput: f32) -> f32 {
658 base_throughput * nodes as f32 * self.scaling_efficiency
659 }
660}
661
/// Writes, validates, and prunes training checkpoints, with optional
/// compression, differential checkpoints, and adaptive scheduling.
pub struct SmartCheckpointManager {
    config: CheckpointConfig,
    // Metadata for every checkpoint still on disk, oldest first.
    checkpoint_history: Vec<CheckpointInfo>,
    // Cached copies of the corresponding config flags.
    compression_enabled: bool,
    validation_enabled: bool,
    differential_enabled: bool,
    // Directory all checkpoint files are written into.
    checkpoint_dir: PathBuf,
}
671
/// Configuration for [`SmartCheckpointManager`].
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Checkpoint every this many steps (step % base_frequency == 0).
    pub base_frequency: usize,
    /// Also checkpoint off-schedule when training looks unstable.
    pub adaptive_frequency: bool,
    /// Intended per-file size cap in MiB (not enforced in the visible code).
    pub max_file_size_mb: usize,
    /// Number of most-recent checkpoints to keep on disk.
    pub retention_count: usize,
    /// Enables the (placeholder) compression pass.
    pub compression: bool,
    /// Enables post-write validation of the checkpoint file.
    pub validation: bool,
    /// Enables differential checkpoints against the previous one.
    pub differential: bool,
}
689
690impl Default for CheckpointConfig {
691 fn default() -> Self {
692 Self {
693 base_frequency: 1000,
694 adaptive_frequency: true,
695 max_file_size_mb: 1024, retention_count: 5,
697 compression: true,
698 validation: true,
699 differential: true,
700 }
701 }
702}
703
/// Metadata recorded for each checkpoint written to disk.
#[derive(Debug, Clone)]
pub struct CheckpointInfo {
    /// Training step at which the checkpoint was taken.
    pub step: usize,
    pub timestamp: SystemTime,
    pub file_path: PathBuf,
    /// Size in bytes of the serialized (possibly compressed) file.
    pub file_size: usize,
    pub validation_passed: bool,
    /// True when this checkpoint stores a delta against `base_checkpoint`.
    pub is_differential: bool,
    /// Step of the checkpoint this delta is based on, when differential.
    pub base_checkpoint: Option<usize>,
}
714
715impl SmartCheckpointManager {
716 pub fn new(config: CheckpointConfig, checkpoint_dir: PathBuf) -> Result<Self> {
717 std::fs::create_dir_all(&checkpoint_dir)?;
718
719 let compression_enabled = config.compression;
720 let validation_enabled = config.validation;
721 let differential_enabled = config.differential;
722
723 Ok(Self {
724 config,
725 checkpoint_history: Vec::new(),
726 compression_enabled,
727 validation_enabled,
728 differential_enabled,
729 checkpoint_dir,
730 })
731 }
732
733 pub fn should_checkpoint(&self, step: usize, performance_metrics: &PerformanceMetrics) -> bool {
734 if step % self.config.base_frequency == 0 {
735 return true;
736 }
737
738 if self.config.adaptive_frequency {
739 self.adaptive_checkpoint_decision(step, performance_metrics)
741 } else {
742 false
743 }
744 }
745
746 fn adaptive_checkpoint_decision(&self, _step: usize, metrics: &PerformanceMetrics) -> bool {
747 let avg_gpu_util =
749 metrics.gpu_utilization.iter().sum::<f32>() / metrics.gpu_utilization.len() as f32;
750 let performance_variance = self.calculate_performance_variance(metrics);
751
752 performance_variance > 0.1 || avg_gpu_util < 0.5
754 }
755
756 fn calculate_performance_variance(&self, metrics: &PerformanceMetrics) -> f32 {
757 if metrics.gpu_utilization.is_empty() {
758 return 0.0;
759 }
760
761 let mean =
762 metrics.gpu_utilization.iter().sum::<f32>() / metrics.gpu_utilization.len() as f32;
763 let variance = metrics.gpu_utilization.iter().map(|x| (x - mean).powi(2)).sum::<f32>()
764 / metrics.gpu_utilization.len() as f32;
765
766 variance.sqrt()
767 }
768
769 pub fn create_checkpoint(
770 &mut self,
771 step: usize,
772 model_state: &HashMap<String, Tensor>,
773 ) -> Result<CheckpointInfo> {
774 let timestamp = SystemTime::now();
775
776 let is_differential = self.differential_enabled && !self.checkpoint_history.is_empty();
778 let base_checkpoint = if is_differential {
779 self.checkpoint_history.last().map(|c| c.step)
780 } else {
781 None
782 };
783
784 let filename = if is_differential {
786 format!(
787 "checkpoint_step_{}_diff_{}.ckpt",
788 step,
789 base_checkpoint.unwrap()
790 )
791 } else {
792 format!("checkpoint_step_{}_full.ckpt", step)
793 };
794 let file_path = self.checkpoint_dir.join(filename);
795
796 let checkpoint_data = if is_differential {
798 self.create_differential_checkpoint(model_state)?
799 } else {
800 self.create_full_checkpoint(model_state)?
801 };
802
803 let final_data = if self.compression_enabled {
805 self.compress_checkpoint(&checkpoint_data)?
806 } else {
807 checkpoint_data
808 };
809
810 std::fs::write(&file_path, &final_data)?;
812 let file_size = final_data.len();
813
814 let validation_passed = if self.validation_enabled {
816 self.validate_checkpoint(&file_path)?
817 } else {
818 true
819 };
820
821 let checkpoint_info = CheckpointInfo {
822 step,
823 timestamp,
824 file_path,
825 file_size,
826 validation_passed,
827 is_differential,
828 base_checkpoint,
829 };
830
831 self.checkpoint_history.push(checkpoint_info.clone());
832
833 self.cleanup_old_checkpoints()?;
835
836 println!(
837 "📁 Checkpoint created: Step {}, Size: {:.2}MB, Type: {}",
838 step,
839 file_size as f32 / (1024.0 * 1024.0),
840 if is_differential { "Differential" } else { "Full" }
841 );
842
843 Ok(checkpoint_info)
844 }
845
846 fn create_full_checkpoint(&self, model_state: &HashMap<String, Tensor>) -> Result<Vec<u8>> {
847 let mut data = Vec::new();
850
851 data.extend_from_slice(b"TFRS_CKPT_FULL");
853
854 data.extend_from_slice(&(model_state.len() as u32).to_le_bytes());
856
857 for (name, tensor) in model_state {
859 data.extend_from_slice(&(name.len() as u32).to_le_bytes());
861 data.extend_from_slice(name.as_bytes());
862
863 let shape = tensor.shape();
865 data.extend_from_slice(&(shape.len() as u32).to_le_bytes());
866 for dim in shape {
867 data.extend_from_slice(&(dim as u32).to_le_bytes());
868 }
869
870 let tensor_data = tensor.to_vec_u8()?;
872 data.extend_from_slice(&(tensor_data.len() as u32).to_le_bytes());
873 for &value in &tensor_data {
874 data.extend_from_slice(&value.to_le_bytes());
875 }
876 }
877
878 Ok(data)
879 }
880
881 fn create_differential_checkpoint(
882 &self,
883 model_state: &HashMap<String, Tensor>,
884 ) -> Result<Vec<u8>> {
885 let mut data = Vec::new();
888
889 data.extend_from_slice(b"TFRS_CKPT_DIFF");
891
892 if let Some(base_step) = self.checkpoint_history.last().map(|c| c.step) {
894 data.extend_from_slice(&(base_step as u32).to_le_bytes());
895 }
896
897 let full_data = self.create_full_checkpoint(model_state)?;
900 data.extend_from_slice(&full_data);
901
902 Ok(data)
903 }
904
905 fn compress_checkpoint(&self, data: &[u8]) -> Result<Vec<u8>> {
906 let mut compressed = Vec::new();
909 compressed.extend_from_slice(b"COMPRESSED");
910 compressed.extend_from_slice(&(data.len() as u32).to_le_bytes());
911 compressed.extend_from_slice(data);
912 Ok(compressed)
913 }
914
915 fn validate_checkpoint(&self, file_path: &PathBuf) -> Result<bool> {
916 let metadata = std::fs::metadata(file_path)?;
918 Ok(metadata.len() > 100) }
920
921 fn cleanup_old_checkpoints(&mut self) -> Result<()> {
922 if self.checkpoint_history.len() <= self.config.retention_count {
923 return Ok(());
924 }
925
926 let to_remove = self.checkpoint_history.len() - self.config.retention_count;
928 for _ in 0..to_remove {
929 if let Some(old_checkpoint) = self.checkpoint_history.first() {
930 if let Err(e) = std::fs::remove_file(&old_checkpoint.file_path) {
931 eprintln!("Warning: Failed to remove old checkpoint: {}", e);
932 }
933 }
934 self.checkpoint_history.remove(0);
935 }
936
937 Ok(())
938 }
939
940 pub fn get_latest_checkpoint(&self) -> Option<&CheckpointInfo> {
941 self.checkpoint_history.last()
942 }
943
944 pub fn get_checkpoint_history(&self) -> &[CheckpointInfo] {
945 &self.checkpoint_history
946 }
947}
948
/// Periodically tunes distributed-training parameters (batch size,
/// compression, communication) using an online linear performance model.
pub struct PerformanceMLOptimizer {
    config: MLOptimizerConfig,
    // Shared so the model could be updated from other threads.
    performance_model: Arc<Mutex<MLPerformanceModel>>,
    // Every optimization applied so far, in order.
    optimization_history: Vec<OptimizationResult>,
    // Rate-limit anchor for `should_optimize`.
    last_optimization: Instant,
}
956
/// Configuration for [`PerformanceMLOptimizer`].
#[derive(Debug, Clone)]
pub struct MLOptimizerConfig {
    /// Number of steps the model looks ahead (currently informational).
    pub prediction_horizon: usize,
    /// Optimize on steps where `step % optimization_frequency == 0`.
    pub optimization_frequency: usize,
    /// Master switch for the automatic tuning passes.
    pub auto_tuning: bool,
    /// Learning rate intended for the internal model (informational here).
    pub model_learning_rate: f32,
    /// Enables feature engineering (currently informational).
    pub feature_engineering: bool,
}
970
971impl Default for MLOptimizerConfig {
972 fn default() -> Self {
973 Self {
974 prediction_horizon: 100,
975 optimization_frequency: 50,
976 auto_tuning: true,
977 model_learning_rate: 0.001,
978 feature_engineering: true,
979 }
980 }
981}
982
/// Record of one applied optimization and its (estimated) effect.
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    pub timestamp: SystemTime,
    pub optimization_type: OptimizationType,
    /// Relative improvement estimate (some passes use fixed estimates).
    pub performance_improvement: f32,
    /// Parameter name → new value for every parameter the pass changed.
    pub parameters_changed: HashMap<String, f32>,
}
990
/// Category of tuning a single [`OptimizationResult`] belongs to.
#[derive(Debug, Clone)]
pub enum OptimizationType {
    BatchSizeOptimization,
    LearningRateScheduling,
    CommunicationPatternOptimization,
    MemoryOptimization,
    CompressionOptimization,
}
999
1000impl PerformanceMLOptimizer {
1001 pub fn new(config: MLOptimizerConfig) -> Self {
1002 Self {
1003 config,
1004 performance_model: Arc::new(Mutex::new(MLPerformanceModel::new())),
1005 optimization_history: Vec::new(),
1006 last_optimization: Instant::now() - Duration::from_secs(120),
1008 }
1009 }
1010
1011 pub fn with_prediction_horizon(mut self, horizon: usize) -> Self {
1012 self.config.prediction_horizon = horizon;
1013 self
1014 }
1015
1016 pub fn with_optimization_frequency(mut self, frequency: usize) -> Self {
1017 self.config.optimization_frequency = frequency;
1018 self
1019 }
1020
1021 pub fn should_optimize(&self, step: usize) -> bool {
1022 step % self.config.optimization_frequency == 0
1023 && self.last_optimization.elapsed() > Duration::from_secs(60) }
1025
1026 pub fn optimize_performance(
1027 &mut self,
1028 current_metrics: &PerformanceMetrics,
1029 training_config: &mut DistributedConfig,
1030 ) -> Result<Vec<OptimizationResult>> {
1031 let mut optimizations = Vec::new();
1032
1033 {
1035 let mut model = self.performance_model.lock().unwrap();
1036 model.update_training_data(current_metrics)?;
1037 }
1038
1039 if self.config.auto_tuning {
1041 if let Some(result) = self.optimize_batch_sizes(current_metrics, training_config)? {
1043 optimizations.push(result);
1044 }
1045
1046 if let Some(result) = self.optimize_compression(current_metrics, training_config)? {
1048 optimizations.push(result);
1049 }
1050
1051 if let Some(result) = self.optimize_communication(current_metrics, training_config)? {
1053 optimizations.push(result);
1054 }
1055 }
1056
1057 self.optimization_history.extend(optimizations.clone());
1058 self.last_optimization = Instant::now();
1059
1060 Ok(optimizations)
1061 }
1062
1063 fn optimize_batch_sizes(
1064 &self,
1065 metrics: &PerformanceMetrics,
1066 config: &mut DistributedConfig,
1067 ) -> Result<Option<OptimizationResult>> {
1068 let avg_utilization =
1069 metrics.gpu_utilization.iter().sum::<f32>() / metrics.gpu_utilization.len() as f32;
1070 let avg_memory =
1071 metrics.memory_usage.iter().sum::<f32>() / metrics.memory_usage.len() as f32;
1072
1073 let model = self.performance_model.lock().unwrap();
1075 let predicted_optimal_batch =
1076 model.predict_optimal_batch_size(avg_utilization, avg_memory)?;
1077
1078 let current_batch = config.dynamic_batching.initial_batch_size as f32;
1079 let improvement = (predicted_optimal_batch - current_batch) / current_batch;
1080
1081 if improvement.abs() > 0.1 {
1082 config.dynamic_batching.initial_batch_size = predicted_optimal_batch as usize;
1084
1085 let mut params_changed = HashMap::new();
1086 params_changed.insert("batch_size".to_string(), predicted_optimal_batch);
1087
1088 Ok(Some(OptimizationResult {
1089 timestamp: SystemTime::now(),
1090 optimization_type: OptimizationType::BatchSizeOptimization,
1091 performance_improvement: improvement,
1092 parameters_changed: params_changed,
1093 }))
1094 } else {
1095 Ok(None)
1096 }
1097 }
1098
1099 fn optimize_compression(
1100 &self,
1101 metrics: &PerformanceMetrics,
1102 config: &mut DistributedConfig,
1103 ) -> Result<Option<OptimizationResult>> {
1104 if metrics.communication_overhead > 0.3 {
1105 config.compression.target_ratio = (config.compression.target_ratio * 0.8).max(0.05);
1108
1109 let mut params_changed = HashMap::new();
1110 params_changed.insert(
1111 "compression_ratio".to_string(),
1112 config.compression.target_ratio,
1113 );
1114
1115 Ok(Some(OptimizationResult {
1116 timestamp: SystemTime::now(),
1117 optimization_type: OptimizationType::CompressionOptimization,
1118 performance_improvement: 0.15, parameters_changed: params_changed,
1120 }))
1121 } else {
1122 Ok(None)
1123 }
1124 }
1125
1126 fn optimize_communication(
1127 &self,
1128 metrics: &PerformanceMetrics,
1129 _config: &mut DistributedConfig,
1130 ) -> Result<Option<OptimizationResult>> {
1131 if metrics.bandwidth_utilization < 0.5 {
1133 let mut params_changed = HashMap::new();
1135 params_changed.insert("communication_frequency".to_string(), 1.2);
1136
1137 Ok(Some(OptimizationResult {
1138 timestamp: SystemTime::now(),
1139 optimization_type: OptimizationType::CommunicationPatternOptimization,
1140 performance_improvement: 0.08, parameters_changed: params_changed,
1142 }))
1143 } else {
1144 Ok(None)
1145 }
1146 }
1147
1148 pub fn get_optimization_history(&self) -> &[OptimizationResult] {
1149 &self.optimization_history
1150 }
1151}
1152
/// Tiny online linear regression relating metric feature vectors to
/// observed throughput.
pub struct MLPerformanceModel {
    /// (feature vector, observed throughput) pairs, bounded in length.
    training_data: Vec<(Vec<f32>, f32)>,
    model_weights: Vec<f32>,
    learning_rate: f32,
}
1159
impl Default for MLPerformanceModel {
    /// Equivalent to [`MLPerformanceModel::new`].
    fn default() -> Self {
        Self::new()
    }
}
1165
1166impl MLPerformanceModel {
1167 pub fn new() -> Self {
1168 Self {
1169 training_data: Vec::new(),
1170 model_weights: vec![0.5, 0.3, 0.2, 0.1], learning_rate: 0.001,
1172 }
1173 }
1174
1175 pub fn update_training_data(&mut self, metrics: &PerformanceMetrics) -> Result<()> {
1176 let features = vec![
1178 metrics.gpu_utilization.iter().sum::<f32>() / metrics.gpu_utilization.len() as f32,
1179 metrics.memory_usage.iter().sum::<f32>() / metrics.memory_usage.len() as f32,
1180 metrics.communication_overhead,
1181 metrics.bandwidth_utilization,
1182 ];
1183
1184 let target = metrics.throughput;
1185
1186 self.training_data.push((features, target));
1187
1188 if self.training_data.len() > 1000 {
1190 self.training_data.drain(0..500);
1191 }
1192
1193 if self.training_data.len() > 10 {
1195 self.update_model_weights()?;
1196 }
1197
1198 Ok(())
1199 }
1200
1201 fn update_model_weights(&mut self) -> Result<()> {
1202 if self.training_data.is_empty() {
1203 return Ok(());
1204 }
1205
1206 for (features, target) in &self.training_data {
1208 let prediction = self.predict_with_features(features)?;
1209 let error = target - prediction;
1210
1211 for i in 0..self.model_weights.len().min(features.len()) {
1213 self.model_weights[i] += self.learning_rate * error * features[i];
1214 }
1215 }
1216
1217 Ok(())
1218 }
1219
1220 pub fn predict_optimal_batch_size(
1221 &self,
1222 gpu_utilization: f32,
1223 memory_usage: f32,
1224 ) -> Result<f32> {
1225 let utilization_factor = if gpu_utilization < 0.7 {
1227 1.2
1228 } else if gpu_utilization > 0.9 {
1229 0.8
1230 } else {
1231 1.0
1232 };
1233 let memory_factor = if memory_usage > 0.9 {
1234 0.7
1235 } else if memory_usage < 0.5 {
1236 1.3
1237 } else {
1238 1.0
1239 };
1240
1241 let base_batch_size = 32.0_f32;
1242 let optimal_batch: f32 = base_batch_size * utilization_factor * memory_factor;
1243
1244 Ok(optimal_batch.clamp(8.0_f32, 256.0_f32)) }
1246
1247 fn predict_with_features(&self, features: &[f32]) -> Result<f32> {
1248 let prediction = features
1249 .iter()
1250 .zip(self.model_weights.iter())
1251 .map(|(&f, &w)| f * w)
1252 .sum::<f32>();
1253
1254 Ok(prediction.max(0.0)) }
1256}
1257
#[cfg(test)]
mod tests {
    use super::*;

    // Config is plain data: fields stay writable after `default()`.
    #[test]
    fn test_auto_scaler_config() {
        let mut config = AutoScalerConfig::default();
        config.min_nodes = 2;
        config.max_nodes = 32;

        assert_eq!(config.min_nodes, 2);
        assert_eq!(config.max_nodes, 32);
    }

    // Builder methods apply, and `with_min_nodes` lifts the starting count.
    #[test]
    fn test_auto_scaler_creation() {
        let config = AutoScalerConfig::default();
        let auto_scaler = AutoScaler::new(config)
            .with_min_nodes(2)
            .with_max_nodes(16)
            .with_scaling_strategy(ScalingStrategy::Performance);

        assert_eq!(auto_scaler.get_current_nodes(), 2);
        assert!(matches!(
            auto_scaler.config.strategy,
            ScalingStrategy::Performance
        ));
    }

    // With fewer than 10 samples the predictor returns its neutral default,
    // which must still lie inside [0, 1].
    #[test]
    fn test_workload_predictor() {
        let mut predictor = WorkloadPredictor::new();

        let metrics = PerformanceMetrics {
            throughput: 1000.0,
            gpu_utilization: vec![0.8, 0.7, 0.9],
            memory_usage: vec![0.6, 0.7, 0.5],
            communication_overhead: 0.2,
            compression_ratio: 0.1,
            bandwidth_utilization: 0.8,
            step_time: Duration::from_millis(100),
        };

        predictor.update_metrics(&metrics);

        let prediction = predictor.predict_workload(Duration::from_secs(600)).unwrap();
        assert!(prediction >= 0.0 && prediction <= 1.0);
    }

    // Steps on the base-frequency boundary always checkpoint; a healthy
    // off-boundary step does not.
    #[test]
    fn test_checkpoint_manager() {
        let config = CheckpointConfig::default();
        let temp_dir = std::env::temp_dir().join("test_checkpoints");

        if temp_dir.exists() {
            std::fs::remove_dir_all(&temp_dir).ok();
        }

        let manager = SmartCheckpointManager::new(config, temp_dir).unwrap();

        let metrics = PerformanceMetrics {
            throughput: 1000.0,
            gpu_utilization: vec![0.8],
            memory_usage: vec![0.6],
            communication_overhead: 0.2,
            compression_ratio: 0.1,
            bandwidth_utilization: 0.8,
            step_time: Duration::from_millis(100),
        };

        assert!(manager.should_checkpoint(1000, &metrics));
        assert!(!manager.should_checkpoint(999, &metrics));
    }

    // Builders apply, and the back-dated `last_optimization` lets the first
    // frequency-aligned step optimize immediately.
    #[test]
    fn test_ml_optimizer() {
        let config = MLOptimizerConfig::default();
        let optimizer = PerformanceMLOptimizer::new(config)
            .with_prediction_horizon(50)
            .with_optimization_frequency(25);

        assert_eq!(optimizer.config.prediction_horizon, 50);
        assert_eq!(optimizer.config.optimization_frequency, 25);

        assert!(optimizer.should_optimize(25));
        assert!(!optimizer.should_optimize(24));
    }

    // A strictly increasing series must extrapolate above its last value.
    #[test]
    fn test_trend_analyzer() {
        let mut analyzer = TrendAnalyzer::new();

        for i in 0..20 {
            analyzer.update(i as f32 * 0.1);
        }

        let prediction = analyzer.predict(Duration::from_secs(60)).unwrap();
        assert!(prediction > 1.0);
    }
}