1use crate::distributed::DistributedConfig;
2use crate::expert_parallelism::ExpertParallelismConfig;
3use crate::parallelism_3d::ParallelismConfig;
4use crate::sequence_parallelism::SequenceParallelismConfig;
5use crate::tensor_parallelism::TensorParallelismConfig;
6use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10use trustformers_core::Model;
11
/// Configuration for the automatic parallelism-strategy selector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoParallelismConfig {
    /// Master switch for automatic strategy selection.
    pub enabled: bool,
    /// Search algorithm used to generate candidate strategies.
    pub selection_algorithm: SelectionAlgorithm,
    /// Objective the selector optimizes for when ranking candidates.
    pub optimization_objective: OptimizationObjective,
    /// Description of the available hardware (devices, memory, interconnect).
    pub hardware_constraints: HardwareConstraints,
    /// Description of the model being parallelized.
    pub model_constraints: ModelConstraints,
    /// Desired performance targets/limits used to guide selection.
    pub performance_requirements: PerformanceRequirements,
    /// How candidate strategies are evaluated (analytical model, simulation, ...).
    pub evaluation_method: EvaluationMethod,
    /// Whether the strategy may be re-selected during training.
    pub dynamic_adaptation: bool,
    /// Re-selection interval, in steps, when `dynamic_adaptation` is on.
    pub adaptation_frequency: usize,
}
41
impl Default for AutoParallelismConfig {
    /// Defaults: selection enabled, cost-based search minimizing step time,
    /// analytical (model-based) evaluation, no dynamic re-adaptation.
    fn default() -> Self {
        Self {
            enabled: true,
            selection_algorithm: SelectionAlgorithm::CostBasedOptimization,
            optimization_objective: OptimizationObjective::MinimizeTime,
            hardware_constraints: HardwareConstraints::default(),
            model_constraints: ModelConstraints::default(),
            performance_requirements: PerformanceRequirements::default(),
            evaluation_method: EvaluationMethod::ModelBased,
            dynamic_adaptation: false,
            // Only consulted when dynamic_adaptation is enabled.
            adaptation_frequency: 1000,
        }
    }
}
57
/// Algorithm used to produce candidate parallelism strategies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SelectionAlgorithm {
    /// Hand-written heuristics keyed on model/hardware properties.
    RuleBased,
    /// Grid enumeration of (dp, mp, pp) factorizations that fit the device budget.
    CostBasedOptimization,
    /// Feature-based decision rules, refined by recorded performance history.
    MLBased,
    /// Evolutionary search over (dp, mp, pp) genes.
    GeneticAlgorithm,
    /// Simulated annealing (currently falls back to the cost-based generator).
    SimulatedAnnealing,
    /// Multi-objective search (currently falls back to the cost-based generator).
    MultiObjective,
}
74
/// Criterion used to rank candidate strategies against each other.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationObjective {
    /// Prefer the smallest estimated time per training step.
    MinimizeTime,
    /// Prefer the smallest estimated per-device memory footprint.
    MinimizeMemory,
    /// Prefer the smallest communication overhead fraction.
    MinimizeCommunication,
    /// Prefer the highest estimated throughput.
    MaximizeThroughput,
    /// Prefer the highest estimated efficiency.
    MaximizeEfficiency,
    /// Combine several objectives into a single blended score.
    MultiObjective(Vec<OptimizationObjective>),
}
91
/// Description of the hardware the strategy must run on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareConstraints {
    /// Total number of accelerator devices available.
    pub num_devices: usize,
    /// Memory per device, in bytes (default is 80 GiB).
    pub memory_per_device: u64,
    /// Peak compute per device — presumably FLOPS (default 312e12); confirm units.
    pub compute_per_device: f64,
    /// Inter-device interconnect bandwidth, in bytes/s.
    pub inter_device_bandwidth: u64,
    /// Intra-node interconnect bandwidth, in bytes/s.
    pub intra_node_bandwidth: u64,
    // NOTE(review): units not established by this file (default 5.0) — confirm µs vs ms.
    pub network_latency: f64,
    /// Type of each device, one entry per device.
    pub device_types: Vec<DeviceType>,
    /// Interconnect topology between devices.
    pub topology: NetworkTopology,
}
112
impl Default for HardwareConstraints {
    /// Defaults modelled on a single 8-GPU node (A100-80GB-class devices).
    fn default() -> Self {
        Self {
            num_devices: 8,
            // 80 GiB of device memory.
            memory_per_device: 80 * 1024 * 1024 * 1024,
            // ~312 TFLOPS peak per device.
            compute_per_device: 312e12,
            // 600 GiB/s between devices, 900 GiB/s within a node.
            inter_device_bandwidth: 600 * 1024 * 1024 * 1024,
            intra_node_bandwidth: 900 * 1024 * 1024 * 1024,
            network_latency: 5.0,
            device_types: vec![DeviceType::GPU; 8],
            topology: NetworkTopology::FullyConnected,
        }
    }
}
127
/// Kind of accelerator device present in the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DeviceType {
    GPU,
    TPU,
    CPU,
    /// Vendor/implementation-specific device, identified by name.
    Custom(String),
}
136
/// Interconnect topology linking the devices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NetworkTopology {
    FullyConnected,
    Ring,
    Tree,
    Mesh2D,
    Mesh3D,
    Torus,
    /// Site-specific topology, identified by name.
    Custom(String),
}
148
/// Description of the model being parallelized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConstraints {
    /// Total trainable parameter count.
    pub num_parameters: u64,
    /// Number of (transformer) layers.
    pub num_layers: usize,
    /// Hidden dimension of the model.
    pub hidden_size: usize,
    /// Number of attention heads per layer.
    pub num_attention_heads: usize,
    /// Maximum supported sequence length, in tokens.
    pub max_sequence_length: usize,
    /// Vocabulary size of the tokenizer/embedding.
    pub vocab_size: usize,
    /// High-level architecture family.
    pub architecture_type: ArchitectureType,
    /// Whether the model uses mixture-of-experts layers.
    pub has_mixture_of_experts: bool,
    /// Number of experts when `has_mixture_of_experts` is true.
    pub num_experts: Option<usize>,
}
171
impl Default for ModelConstraints {
    /// Defaults describing a 7B-parameter, 32-layer dense transformer.
    fn default() -> Self {
        Self {
            num_parameters: 7_000_000_000,
            num_layers: 32,
            hidden_size: 4096,
            num_attention_heads: 32,
            max_sequence_length: 2048,
            // GPT-2/GPT-3-style BPE vocabulary size.
            vocab_size: 50257,
            architecture_type: ArchitectureType::Transformer,
            has_mixture_of_experts: false,
            num_experts: None,
        }
    }
}
187
/// High-level model architecture family.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ArchitectureType {
    Transformer,
    GPT,
    BERT,
    T5,
    /// Mixture-of-experts architecture.
    MoE,
    ConvNet,
    RNN,
    /// Any other architecture, identified by name.
    Custom(String),
}
200
/// Optional performance targets/limits; `None` means "no constraint".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceRequirements {
    /// Upper bound on total training time.
    pub max_training_time: Option<Duration>,
    /// Lower bound on throughput.
    pub min_throughput: Option<f64>,
    /// Upper bound on per-device memory, in bytes.
    pub max_memory_per_device: Option<u64>,
    /// Upper bound on communication overhead, as a fraction in [0, 1].
    pub max_communication_overhead: Option<f32>,
    /// Lower bound on efficiency, as a fraction in [0, 1].
    pub min_efficiency: Option<f32>,
}
215
impl Default for PerformanceRequirements {
    fn default() -> Self {
        Self {
            // No hard limits on time, throughput, or memory by default.
            max_training_time: None,
            min_throughput: None,
            max_memory_per_device: None,
            // Tolerate at most 30% communication overhead...
            max_communication_overhead: Some(0.3),
            // ...and require at least 70% efficiency.
            min_efficiency: Some(0.7),
        }
    }
}
227
/// How candidate strategies are scored before selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EvaluationMethod {
    /// Analytical cost model (no execution).
    ModelBased,
    /// Simulation-driven evaluation (currently delegates to the analytical model).
    SimulationBased,
    /// Profiling-driven evaluation (currently delegates to the analytical model).
    ProfilingBased,
    /// Mix of the above (currently delegates to the analytical model).
    Hybrid,
}
240
/// A concrete parallelization plan: which parallelism dimensions to use and
/// with what configuration. Unused dimensions are `None`.
#[derive(Debug, Clone)]
pub struct ParallelismStrategy {
    /// Identifier used as the key for caching and performance history.
    pub strategy_id: String,
    /// Pure data-parallel configuration, if used.
    pub data_parallel: Option<DistributedConfig>,
    /// Combined data/model/pipeline (3D) configuration, if used.
    pub parallelism_3d: Option<ParallelismConfig>,
    /// Mixture-of-experts expert-parallel configuration, if used.
    pub expert_parallel: Option<ExpertParallelismConfig>,
    /// Sequence-parallel configuration, if used.
    pub sequence_parallel: Option<SequenceParallelismConfig>,
    /// Tensor-parallel configuration, if used.
    pub tensor_parallel: Option<TensorParallelismConfig>,
    /// Performance predicted by the evaluation method.
    pub expected_performance: PerformanceMetrics,
    /// Selector's confidence in this strategy, in [0, 1].
    pub confidence: f32,
    /// Human-readable reason why this strategy was proposed.
    pub rationale: String,
}
263
/// Estimated (or measured) performance figures for one strategy.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Wall-clock time per training step.
    pub time_per_step: Duration,
    /// Peak memory per device, in bytes.
    pub memory_per_device: u64,
    /// Fraction of step time spent communicating, in [0, 1].
    pub communication_overhead: f32,
    // NOTE(review): unit varies by estimator (steps/s vs samples/s) — confirm.
    pub throughput: f64,
    /// Hardware-utilization efficiency, in [0, 1].
    pub efficiency: f32,
    /// How well the strategy is expected to scale, in [0, 1].
    pub scalability: f32,
}
280
/// Feature vector derived from the model/hardware configs, used by the
/// ML-based strategy predictor. Size-like quantities are log10-scaled.
#[derive(Debug, Clone)]
pub struct MLFeatures {
    pub log_num_parameters: f64,
    pub num_layers: f64,
    pub log_hidden_size: f64,
    pub num_attention_heads: f64,
    pub log_sequence_length: f64,
    pub log_vocab_size: f64,
    /// 1.0 if the model has mixture-of-experts layers, else 0.0.
    pub has_moe: f64,
    pub log_num_devices: f64,
    pub log_memory_per_device: f64,
    pub log_compute_per_device: f64,
    pub log_bandwidth: f64,
    pub network_latency: f64,

    // Derived, cross-cutting ratios.
    pub memory_to_compute_ratio: f64,
    pub parameters_per_device: f64,
    pub communication_intensity: f64,
}
305
/// One member of the genetic-algorithm population: a strategy plus the
/// (dp, mp, pp) genes it was built from and its evaluated fitness.
#[derive(Debug, Clone)]
pub struct GeneticIndividual {
    /// Strategy materialized from the genes below.
    pub strategy: ParallelismStrategy,
    /// Fitness under the configured objective; 0.0 until evaluated.
    pub fitness: f32,
    /// Data-parallel degree gene.
    pub dp_size: usize,
    /// Model(tensor)-parallel degree gene.
    pub mp_size: usize,
    /// Pipeline-parallel degree gene.
    pub pp_size: usize,
}
320
/// Stateful selector that generates, evaluates, and picks parallelism
/// strategies according to an [`AutoParallelismConfig`].
pub struct AutoParallelismSelector {
    config: AutoParallelismConfig,
    // Cache of strategies keyed by strategy_id (not yet populated in this file).
    #[allow(dead_code)]
    strategy_cache: HashMap<String, ParallelismStrategy>,
    // Past (strategy, observed metrics) pairs used to refine confidence.
    performance_history: Vec<(ParallelismStrategy, PerformanceMetrics)>,
    // Most recently selected strategy, if any.
    current_strategy: Option<ParallelismStrategy>,
}
329
330impl AutoParallelismSelector {
    /// Creates a selector with empty cache, history, and no current strategy.
    pub fn new(config: AutoParallelismConfig) -> Self {
        Self {
            config,
            strategy_cache: HashMap::new(),
            performance_history: Vec::new(),
            current_strategy: None,
        }
    }
340
341 pub fn select_strategy(&mut self) -> Result<ParallelismStrategy> {
343 let strategies = self.generate_candidate_strategies()?;
344 let evaluated_strategies = self.evaluate_strategies(strategies)?;
345 let optimal_strategy = self.select_optimal_strategy(evaluated_strategies)?;
346
347 self.current_strategy = Some(optimal_strategy.clone());
348 Ok(optimal_strategy)
349 }
350
351 fn generate_candidate_strategies(&self) -> Result<Vec<ParallelismStrategy>> {
353 let mut strategies = Vec::new();
354
355 match self.config.selection_algorithm {
357 SelectionAlgorithm::RuleBased => {
358 strategies.extend(self.generate_rule_based_strategies()?);
359 },
360 SelectionAlgorithm::CostBasedOptimization => {
361 strategies.extend(self.generate_cost_based_strategies()?);
362 },
363 SelectionAlgorithm::MLBased => {
364 strategies.extend(self.generate_ml_based_strategies()?);
365 },
366 SelectionAlgorithm::GeneticAlgorithm => {
367 strategies.extend(self.generate_genetic_strategies()?);
368 },
369 SelectionAlgorithm::SimulatedAnnealing => {
370 strategies.extend(self.generate_annealing_strategies()?);
371 },
372 SelectionAlgorithm::MultiObjective => {
373 strategies.extend(self.generate_multi_objective_strategies()?);
374 },
375 }
376
377 Ok(strategies)
378 }
379
380 fn generate_rule_based_strategies(&self) -> Result<Vec<ParallelismStrategy>> {
382 let mut strategies = Vec::new();
383 let hardware = &self.config.hardware_constraints;
384 let model = &self.config.model_constraints;
385
386 if model.num_parameters < 1_000_000_000 {
388 strategies.push(self.create_data_parallel_strategy()?);
390 }
391
392 if model.num_parameters > 10_000_000_000 {
394 strategies.push(self.create_3d_parallel_strategy()?);
396 }
397
398 if model.has_mixture_of_experts {
400 strategies.push(self.create_expert_parallel_strategy()?);
401 }
402
403 if model.max_sequence_length > 8192 {
405 strategies.push(self.create_sequence_parallel_strategy()?);
406 }
407
408 if model.hidden_size > 8192 {
410 strategies.push(self.create_tensor_parallel_strategy()?);
411 }
412
413 if hardware.num_devices > 16 {
415 strategies.push(self.create_hybrid_strategy()?);
416 }
417
418 if strategies.is_empty() {
420 strategies.push(self.create_data_parallel_strategy()?);
422 }
423
424 Ok(strategies)
425 }
426
427 fn generate_cost_based_strategies(&self) -> Result<Vec<ParallelismStrategy>> {
429 let mut strategies = Vec::new();
430
431 let dp_sizes = vec![1, 2, 4, 8];
433 let mp_sizes = vec![1, 2, 4];
434 let pp_sizes = vec![1, 2, 4];
435
436 for dp in &dp_sizes {
437 for mp in &mp_sizes {
438 for pp in &pp_sizes {
439 if dp * mp * pp <= self.config.hardware_constraints.num_devices {
440 let strategy = self.create_3d_strategy_with_config(*dp, *mp, *pp)?;
441 strategies.push(strategy);
442 }
443 }
444 }
445 }
446
447 Ok(strategies)
448 }
449
450 fn generate_ml_based_strategies(&self) -> Result<Vec<ParallelismStrategy>> {
452 let features = self.extract_ml_features()?;
454
455 let predicted_strategies = self.predict_strategies_with_ml(&features)?;
457
458 if !self.performance_history.is_empty() {
460 return self.refine_strategies_with_history(predicted_strategies);
461 }
462
463 Ok(predicted_strategies)
464 }
465
    /// Builds the ML feature vector from the current model/hardware configs.
    /// Size-like quantities are log10-scaled to compress their dynamic range.
    fn extract_ml_features(&self) -> Result<MLFeatures> {
        let hardware = &self.config.hardware_constraints;
        let model = &self.config.model_constraints;

        Ok(MLFeatures {
            log_num_parameters: (model.num_parameters as f64).log10(),
            num_layers: model.num_layers as f64,
            log_hidden_size: (model.hidden_size as f64).log10(),
            num_attention_heads: model.num_attention_heads as f64,
            log_sequence_length: (model.max_sequence_length as f64).log10(),
            log_vocab_size: (model.vocab_size as f64).log10(),
            // Booleans encoded as 0.0 / 1.0 for the feature space.
            has_moe: if model.has_mixture_of_experts { 1.0 } else { 0.0 },

            log_num_devices: (hardware.num_devices as f64).log10(),
            log_memory_per_device: (hardware.memory_per_device as f64).log10(),
            log_compute_per_device: hardware.compute_per_device.log10(),
            log_bandwidth: (hardware.inter_device_bandwidth as f64).log10(),
            network_latency: hardware.network_latency,

            // Derived ratios (not log-scaled).
            memory_to_compute_ratio: (hardware.memory_per_device as f64)
                / hardware.compute_per_device,
            parameters_per_device: (model.num_parameters as f64) / (hardware.num_devices as f64),
            // Activation width relative to bandwidth in GB/s — higher means
            // communication-bound.
            communication_intensity: (model.hidden_size * model.num_attention_heads) as f64
                / (hardware.inter_device_bandwidth as f64 / 1e9),
        })
    }
496
    /// Hand-written "learned" decision rules mapping features to candidate
    /// strategies. Thresholds are on log10 scales: e.g. `log_num_parameters`
    /// 9.0 ≈ 1B params, 10.3 ≈ 20B; `log_num_devices` 1.0 ≈ 10 devices.
    fn predict_strategies_with_ml(
        &self,
        features: &MLFeatures,
    ) -> Result<Vec<ParallelismStrategy>> {
        let mut strategies = Vec::new();

        // Small models (<~1B params).
        if features.log_num_parameters < 9.0 {
            if features.log_num_devices < 1.0 {
                strategies.push(self.create_data_parallel_strategy()?);
            } else {
                strategies.push(self.create_data_parallel_strategy()?);
                // Wide hidden size (>~3162) also merits tensor parallelism.
                if features.log_hidden_size > 3.5 {
                    strategies.push(self.create_tensor_parallel_strategy()?);
                }
            }
        // Medium models (~1B–20B params).
        } else if features.log_num_parameters < 10.3 {
            if features.log_num_devices < 0.9 {
                strategies.push(self.create_data_parallel_strategy()?);
                if features.log_hidden_size > 3.6 {
                    strategies.push(self.create_tensor_parallel_strategy()?);
                }
            } else {
                strategies.push(self.create_3d_parallel_strategy()?);
                if features.has_moe > 0.5 {
                    strategies.push(self.create_expert_parallel_strategy()?);
                }
            }
        // Large models (>~20B params).
        } else {
            strategies.push(self.create_3d_parallel_strategy()?);
            if features.log_num_devices > 1.2 {
                strategies.push(self.create_hybrid_strategy()?);
            }
            if features.has_moe > 0.5 {
                strategies.push(self.create_expert_parallel_strategy()?);
            }
            // Long sequences (>~8000 tokens).
            if features.log_sequence_length > 3.9 {
                strategies.push(self.create_sequence_parallel_strategy()?);
            }
        }

        // Communication-bound workloads: ensure tensor parallelism is considered.
        if features.communication_intensity > 0.1 {
            if !strategies.iter().any(|s| s.strategy_id.contains("tensor_parallel")) {
                strategies.push(self.create_tensor_parallel_strategy()?);
            }
        }

        // Memory-bound workloads (>10B params/device): ensure 3D is considered.
        if features.parameters_per_device > 10e9 {
            if !strategies.iter().any(|s| s.strategy_id.contains("3d_parallel")) {
                strategies.push(self.create_3d_parallel_strategy()?);
            }
        }

        Ok(strategies)
    }
569
570 fn refine_strategies_with_history(
572 &self,
573 mut strategies: Vec<ParallelismStrategy>,
574 ) -> Result<Vec<ParallelismStrategy>> {
575 let mut strategy_performance_map: HashMap<String, Vec<f32>> = HashMap::new();
577
578 for (historical_strategy, historical_performance) in &self.performance_history {
579 let performance_score = self.calculate_performance_score(historical_performance);
580 strategy_performance_map
581 .entry(historical_strategy.strategy_id.clone())
582 .or_default()
583 .push(performance_score);
584 }
585
586 for strategy in &mut strategies {
588 if let Some(historical_scores) = strategy_performance_map.get(&strategy.strategy_id) {
589 let avg_score =
590 historical_scores.iter().sum::<f32>() / historical_scores.len() as f32;
591
592 if avg_score > 0.8 {
594 strategy.confidence = (strategy.confidence + 0.2).min(1.0);
595 } else if avg_score < 0.5 {
596 strategy.confidence = (strategy.confidence - 0.2).max(0.1);
597 }
598 }
599 }
600
601 strategies.sort_by(|a, b| {
603 b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
604 });
605 Ok(strategies)
606 }
607
608 fn calculate_performance_score(&self, metrics: &PerformanceMetrics) -> f32 {
610 let time_score = 1.0 / (metrics.time_per_step.as_secs_f32() + 1e-6);
611 let memory_score = 1.0 / (metrics.memory_per_device as f32 / 1e9 + 1e-6);
612 let comm_score = 1.0 - metrics.communication_overhead.clamp(0.0, 1.0);
613 let throughput_score = (metrics.throughput as f32).min(10.0) / 10.0;
614 let efficiency_score = metrics.efficiency;
615
616 (time_score * 0.25
618 + memory_score * 0.15
619 + comm_score * 0.2
620 + throughput_score * 0.2
621 + efficiency_score * 0.2)
622 .clamp(0.0, 1.0)
623 }
624
    /// Evolutionary search over (dp, mp, pp) factorizations: elitist GA with
    /// tournament selection, uniform crossover, and doubling mutation.
    /// Returns the five fittest strategies found.
    fn generate_genetic_strategies(&self) -> Result<Vec<ParallelismStrategy>> {
        // Fixed GA hyperparameters.
        let population_size = 20;
        let generations = 10;
        let mutation_rate = 0.2;
        let elite_size = 4;

        let mut population = self.initialize_genetic_population(population_size)?;

        for _generation in 0..generations {
            self.evaluate_genetic_fitness(&mut population)?;

            // Fittest first.
            population.sort_by(|a, b| {
                b.fitness.partial_cmp(&a.fitness).unwrap_or(std::cmp::Ordering::Equal)
            });

            let mut new_population = Vec::new();

            // Elitism: carry the best individuals over unchanged.
            for i in 0..elite_size.min(population.len()) {
                new_population.push(population[i].clone());
            }

            // Fill the rest via tournament selection + crossover (+ mutation).
            while new_population.len() < population_size {
                let parent1 = self.tournament_selection(&population, 3)?;
                let parent2 = self.tournament_selection(&population, 3)?;

                let mut offspring = self.crossover_genetic_individual(parent1, parent2)?;

                if fastrand::f32() < mutation_rate {
                    self.mutate_genetic_individual(&mut offspring)?;
                }

                new_population.push(offspring);
            }

            population = new_population;
        }

        // Final evaluation and ranking before extracting the winners.
        self.evaluate_genetic_fitness(&mut population)?;
        population
            .sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap_or(std::cmp::Ordering::Equal));

        Ok(population.into_iter().take(5).map(|gi| gi.strategy).collect())
    }
677
678 fn initialize_genetic_population(&self, size: usize) -> Result<Vec<GeneticIndividual>> {
680 let mut population = Vec::new();
681 let max_devices = self.config.hardware_constraints.num_devices;
682
683 for _ in 0..size {
684 let dp_size = 1 << fastrand::usize(0..4); let mp_size = 1 << fastrand::usize(0..3); let pp_size = max_devices / (dp_size * mp_size).max(1);
688
689 let strategy = if pp_size > 1 {
690 self.create_3d_strategy_with_config(dp_size, mp_size, pp_size)?
691 } else if mp_size > 1 {
692 self.create_tensor_parallel_strategy()?
693 } else {
694 self.create_data_parallel_strategy()?
695 };
696
697 population.push(GeneticIndividual {
698 strategy,
699 fitness: 0.0,
700 dp_size,
701 mp_size,
702 pp_size,
703 });
704 }
705
706 Ok(population)
707 }
708
709 fn evaluate_genetic_fitness(&self, population: &mut [GeneticIndividual]) -> Result<()> {
711 for individual in population {
712 individual.fitness = self.calculate_strategy_fitness(&individual.strategy);
713 }
714 Ok(())
715 }
716
717 fn calculate_strategy_fitness(&self, strategy: &ParallelismStrategy) -> f32 {
719 let metrics = &strategy.expected_performance;
720
721 let time_fitness = 1.0 / (metrics.time_per_step.as_secs_f32() + 1e-6);
723 let memory_fitness = 1.0 / (metrics.memory_per_device as f32 / 1e9 + 1e-6);
724 let comm_fitness = 1.0 - metrics.communication_overhead.clamp(0.0, 1.0);
725 let throughput_fitness = (metrics.throughput as f32).min(10.0);
726 let efficiency_fitness = metrics.efficiency;
727
728 match &self.config.optimization_objective {
730 OptimizationObjective::MinimizeTime => time_fitness,
731 OptimizationObjective::MinimizeMemory => memory_fitness,
732 OptimizationObjective::MinimizeCommunication => comm_fitness,
733 OptimizationObjective::MaximizeThroughput => throughput_fitness,
734 OptimizationObjective::MaximizeEfficiency => efficiency_fitness,
735 OptimizationObjective::MultiObjective(_) => {
736 (time_fitness
737 + memory_fitness
738 + comm_fitness
739 + throughput_fitness
740 + efficiency_fitness)
741 / 5.0
742 },
743 }
744 }
745
746 fn tournament_selection<'a>(
748 &self,
749 population: &'a [GeneticIndividual],
750 tournament_size: usize,
751 ) -> Result<&'a GeneticIndividual> {
752 let mut best_individual = &population[fastrand::usize(0..population.len())];
753
754 for _ in 1..tournament_size {
755 let candidate = &population[fastrand::usize(0..population.len())];
756 if candidate.fitness > best_individual.fitness {
757 best_individual = candidate;
758 }
759 }
760
761 Ok(best_individual)
762 }
763
764 fn crossover_genetic_individual(
766 &self,
767 parent1: &GeneticIndividual,
768 parent2: &GeneticIndividual,
769 ) -> Result<GeneticIndividual> {
770 let dp_size = if fastrand::bool() { parent1.dp_size } else { parent2.dp_size };
772 let mp_size = if fastrand::bool() { parent1.mp_size } else { parent2.mp_size };
773 let pp_size = if fastrand::bool() { parent1.pp_size } else { parent2.pp_size };
774
775 let total_devices = dp_size * mp_size * pp_size;
777 let max_devices = self.config.hardware_constraints.num_devices;
778
779 if total_devices <= max_devices {
780 let strategy = self.create_3d_strategy_with_config(dp_size, mp_size, pp_size)?;
781 Ok(GeneticIndividual {
782 strategy,
783 fitness: 0.0,
784 dp_size,
785 mp_size,
786 pp_size,
787 })
788 } else {
789 Ok(if parent1.fitness > parent2.fitness {
791 parent1.clone()
792 } else {
793 parent2.clone()
794 })
795 }
796 }
797
    /// Mutation operator: randomly picks one gene and doubles it, but only if
    /// the resulting factorization still fits the device budget. The strategy
    /// is rebuilt from the (possibly unchanged) genes and fitness is reset.
    fn mutate_genetic_individual(&self, individual: &mut GeneticIndividual) -> Result<()> {
        let max_devices = self.config.hardware_constraints.num_devices;

        match fastrand::usize(0..3) {
            // Double the data-parallel degree (capped at the device count).
            0 => {
                let new_dp = (individual.dp_size * 2).min(max_devices);
                if new_dp * individual.mp_size * individual.pp_size <= max_devices {
                    individual.dp_size = new_dp;
                }
            },
            // Double the model-parallel degree (capped at 8).
            1 => {
                let new_mp = (individual.mp_size * 2).min(8);
                if individual.dp_size * new_mp * individual.pp_size <= max_devices {
                    individual.mp_size = new_mp;
                }
            },
            // Double the pipeline-parallel degree (capped at the device count).
            2 => {
                let new_pp = (individual.pp_size * 2).min(max_devices);
                if individual.dp_size * individual.mp_size * new_pp <= max_devices {
                    individual.pp_size = new_pp;
                }
            },
            // fastrand::usize(0..3) only yields 0..=2; this arm is unreachable.
            _ => {},
        }

        // Re-materialize the strategy from the mutated genes.
        individual.strategy = self.create_3d_strategy_with_config(
            individual.dp_size,
            individual.mp_size,
            individual.pp_size,
        )?;
        individual.fitness = 0.0; // stale after mutation; re-evaluated later
        Ok(())
    }
838
    /// Simulated-annealing generation.
    /// Currently a placeholder that delegates to the cost-based grid search.
    fn generate_annealing_strategies(&self) -> Result<Vec<ParallelismStrategy>> {
        self.generate_cost_based_strategies()
    }
844
    /// Multi-objective generation.
    /// Currently a placeholder that delegates to the cost-based grid search;
    /// the multi-objective weighting happens later during comparison.
    fn generate_multi_objective_strategies(&self) -> Result<Vec<ParallelismStrategy>> {
        self.generate_cost_based_strategies()
    }
850
    /// Builds a pure data-parallel strategy spanning all available devices
    /// (NCCL backend, rank 0 as the template — presumably re-ranked per
    /// worker at launch; confirm against the launcher).
    fn create_data_parallel_strategy(&self) -> Result<ParallelismStrategy> {
        let data_parallel = Some(DistributedConfig {
            world_size: self.config.hardware_constraints.num_devices,
            rank: 0,
            backend: crate::distributed::DistributedBackend::NCCL,
            master_addr: "localhost".to_string(),
            master_port: 29500,
            gradient_compression: false,
            // All-reduce bucketing size in MB.
            bucket_size_mb: 25,
        });

        let expected_performance = self.estimate_performance_data_parallel()?;

        Ok(ParallelismStrategy {
            strategy_id: "data_parallel".to_string(),
            data_parallel,
            parallelism_3d: None,
            expert_parallel: None,
            sequence_parallel: None,
            tensor_parallel: None,
            expected_performance,
            confidence: 0.9,
            rationale: "Model size suitable for data parallelism".to_string(),
        })
    }
877
    /// Builds a 3D strategy with a default factorization: up to 4-way data
    /// parallelism, up to 2-way model parallelism, and the remaining devices
    /// assigned to the pipeline dimension.
    fn create_3d_parallel_strategy(&self) -> Result<ParallelismStrategy> {
        let num_devices = self.config.hardware_constraints.num_devices;

        let dp_size = std::cmp::min(4, num_devices);
        let mp_size = std::cmp::min(2, num_devices / dp_size);
        let pp_size = num_devices / (dp_size * mp_size);

        self.create_3d_strategy_with_config(dp_size, mp_size, pp_size)
    }
889
890 fn create_3d_strategy_with_config(
892 &self,
893 dp_size: usize,
894 mp_size: usize,
895 pp_size: usize,
896 ) -> Result<ParallelismStrategy> {
897 let parallelism_3d = Some(ParallelismConfig {
898 dp_size,
899 mp_size,
900 pp_size,
901 num_micro_batches: 4,
902 gradient_accumulation: true,
903 accumulation_steps: 1,
904 activation_checkpointing: true,
905 comm_backend: crate::parallelism_3d::CommBackend::NCCL,
906 pipeline_schedule: crate::parallelism_3d::PipelineSchedule::GPipe,
907 memory_optimization: crate::parallelism_3d::MemoryOptimization::Medium,
908 });
909
910 let expected_performance =
911 self.estimate_performance_3d_parallel(dp_size, mp_size, pp_size)?;
912
913 Ok(ParallelismStrategy {
914 strategy_id: format!("3d_parallel_{}_{}_", dp_size, mp_size),
915 data_parallel: None,
916 parallelism_3d,
917 expert_parallel: None,
918 sequence_parallel: None,
919 tensor_parallel: None,
920 expected_performance,
921 confidence: 0.8,
922 rationale: format!(
923 "Large model requiring 3D parallelism: DP={}, MP={}, PP={}",
924 dp_size, mp_size, pp_size
925 ),
926 })
927 }
928
929 fn create_expert_parallel_strategy(&self) -> Result<ParallelismStrategy> {
931 let num_experts = self.config.model_constraints.num_experts.unwrap_or(8);
932 let expert_parallel_size =
933 std::cmp::min(num_experts, self.config.hardware_constraints.num_devices);
934
935 let expert_parallel = Some(ExpertParallelismConfig {
936 num_experts,
937 experts_per_device: num_experts / expert_parallel_size,
938 expert_parallel_size,
939 top_k: 2,
940 load_balancing: crate::expert_parallelism::LoadBalancingStrategy::TokenChoiceBased,
941 routing_strategy: crate::expert_parallelism::ExpertRoutingStrategy::LearnedGating,
942 capacity_factor: 1.25,
943 drop_tokens: false,
944 use_auxiliary_loss: true,
945 auxiliary_loss_weight: 0.01,
946 communication_pattern: crate::expert_parallelism::ExpertCommunicationPattern::AllToAll,
947 });
948
949 let expected_performance = self.estimate_performance_expert_parallel()?;
950
951 Ok(ParallelismStrategy {
952 strategy_id: "expert_parallel".to_string(),
953 data_parallel: None,
954 parallelism_3d: None,
955 expert_parallel,
956 sequence_parallel: None,
957 tensor_parallel: None,
958 expected_performance,
959 confidence: 0.85,
960 rationale: "MoE model requiring expert parallelism".to_string(),
961 })
962 }
963
    /// Builds a sequence-parallel strategy splitting the sequence dimension
    /// across up to 4 devices in equal chunks with ring-all-reduce attention
    /// communication.
    fn create_sequence_parallel_strategy(&self) -> Result<ParallelismStrategy> {
        // NOTE(review): assumes num_devices >= 1; a zero device count would
        // make the division below panic — confirm upstream validation.
        let sequence_parallel_size = std::cmp::min(4, self.config.hardware_constraints.num_devices);
        let max_seq_per_device =
            self.config.model_constraints.max_sequence_length / sequence_parallel_size;

        let sequence_parallel = Some(SequenceParallelismConfig {
            sequence_parallel_size,
            max_sequence_length_per_device: max_seq_per_device,
            // Overlap window between adjacent chunks: 10% of the chunk, capped at 128.
            overlap_size: std::cmp::min(128, max_seq_per_device / 10),
            attention_communication_opt: true,
            communication_pattern:
                crate::sequence_parallelism::SequenceCommunicationPattern::RingAllReduce,
            splitting_strategy: crate::sequence_parallelism::SequenceSplittingStrategy::EqualChunks,
            sync_gradients: true,
            memory_optimization: crate::sequence_parallelism::SequenceMemoryOptimization::Medium,
            use_checkpointing: true,
        });

        let expected_performance = self.estimate_performance_sequence_parallel()?;

        Ok(ParallelismStrategy {
            strategy_id: "sequence_parallel".to_string(),
            data_parallel: None,
            parallelism_3d: None,
            expert_parallel: None,
            sequence_parallel,
            tensor_parallel: None,
            expected_performance,
            confidence: 0.8,
            rationale: "Long sequences requiring sequence parallelism".to_string(),
        })
    }
997
    /// Builds a tensor-parallel strategy splitting weight matrices across up
    /// to 4 devices (column-wise partitioning, all-reduce communication).
    fn create_tensor_parallel_strategy(&self) -> Result<ParallelismStrategy> {
        let tensor_parallel_size = std::cmp::min(4, self.config.hardware_constraints.num_devices);

        let tensor_parallel = Some(TensorParallelismConfig {
            tensor_parallel_size,
            partitioning_strategy:
                crate::tensor_parallelism::TensorPartitioningStrategy::ColumnWise,
            column_parallel: true,
            row_parallel: true,
            communication_pattern: crate::tensor_parallelism::TensorCommunicationPattern::AllReduce,
            async_communication: true,
            // Fuse small communications below 1 MiB.
            fusion_threshold_bytes: 1024 * 1024,
            gradient_accumulation: true,
            memory_optimization: crate::tensor_parallelism::TensorMemoryOptimization::Medium,
            mixed_precision: false,
        });

        let expected_performance = self.estimate_performance_tensor_parallel()?;

        Ok(ParallelismStrategy {
            strategy_id: "tensor_parallel".to_string(),
            data_parallel: None,
            parallelism_3d: None,
            expert_parallel: None,
            sequence_parallel: None,
            tensor_parallel,
            expected_performance,
            confidence: 0.85,
            rationale: "Wide model requiring tensor parallelism".to_string(),
        })
    }
1030
1031 fn create_hybrid_strategy(&self) -> Result<ParallelismStrategy> {
1033 let num_devices = self.config.hardware_constraints.num_devices;
1034
1035 let dp_size = 2;
1037 let mp_size = 2;
1038 let pp_size = num_devices / (dp_size * mp_size);
1039
1040 let parallelism_3d = Some(ParallelismConfig {
1041 dp_size,
1042 mp_size,
1043 pp_size,
1044 num_micro_batches: 4,
1045 gradient_accumulation: true,
1046 accumulation_steps: 1,
1047 activation_checkpointing: true,
1048 comm_backend: crate::parallelism_3d::CommBackend::NCCL,
1049 pipeline_schedule: crate::parallelism_3d::PipelineSchedule::GPipe,
1050 memory_optimization: crate::parallelism_3d::MemoryOptimization::High,
1051 });
1052
1053 let tensor_parallel = if self.config.model_constraints.hidden_size > 4096 {
1054 Some(TensorParallelismConfig {
1055 tensor_parallel_size: mp_size,
1056 ..Default::default()
1057 })
1058 } else {
1059 None
1060 };
1061
1062 let expected_performance = self.estimate_performance_hybrid()?;
1063
1064 Ok(ParallelismStrategy {
1065 strategy_id: "hybrid".to_string(),
1066 data_parallel: None,
1067 parallelism_3d,
1068 expert_parallel: None,
1069 sequence_parallel: None,
1070 tensor_parallel,
1071 expected_performance,
1072 confidence: 0.75,
1073 rationale: "Complex model and many devices requiring hybrid parallelism".to_string(),
1074 })
1075 }
1076
    /// Dispatches candidate evaluation to the configured method. All methods
    /// currently route to the analytical model-based evaluator.
    fn evaluate_strategies(
        &self,
        strategies: Vec<ParallelismStrategy>,
    ) -> Result<Vec<ParallelismStrategy>> {
        match self.config.evaluation_method {
            EvaluationMethod::ModelBased => self.evaluate_model_based(strategies),
            EvaluationMethod::SimulationBased => self.evaluate_simulation_based(strategies),
            EvaluationMethod::ProfilingBased => self.evaluate_profiling_based(strategies),
            EvaluationMethod::Hybrid => self.evaluate_hybrid(strategies),
        }
    }
1089
1090 fn evaluate_model_based(
1092 &self,
1093 mut strategies: Vec<ParallelismStrategy>,
1094 ) -> Result<Vec<ParallelismStrategy>> {
1095 for strategy in &mut strategies {
1097 strategy.expected_performance = self.refine_performance_estimate(strategy)?;
1098 strategy.confidence = self.calculate_confidence(strategy);
1099 }
1100 Ok(strategies)
1101 }
1102
    /// Simulation-based evaluation.
    /// Currently a placeholder that delegates to the analytical evaluator.
    fn evaluate_simulation_based(
        &self,
        strategies: Vec<ParallelismStrategy>,
    ) -> Result<Vec<ParallelismStrategy>> {
        self.evaluate_model_based(strategies)
    }
1111
    /// Profiling-based evaluation.
    /// Currently a placeholder that delegates to the analytical evaluator.
    fn evaluate_profiling_based(
        &self,
        strategies: Vec<ParallelismStrategy>,
    ) -> Result<Vec<ParallelismStrategy>> {
        self.evaluate_model_based(strategies)
    }
1120
    /// Hybrid evaluation.
    /// Currently a placeholder that delegates to the analytical evaluator.
    fn evaluate_hybrid(
        &self,
        strategies: Vec<ParallelismStrategy>,
    ) -> Result<Vec<ParallelismStrategy>> {
        self.evaluate_model_based(strategies)
    }
1129
1130 fn select_optimal_strategy(
1132 &self,
1133 mut strategies: Vec<ParallelismStrategy>,
1134 ) -> Result<ParallelismStrategy> {
1135 if strategies.is_empty() {
1136 return Err(anyhow!("No strategies available for selection"));
1137 }
1138
1139 strategies
1141 .sort_by(|a, b| self.compare_strategies(a, b).unwrap_or(std::cmp::Ordering::Equal));
1142
1143 Ok(strategies.into_iter().next().expect("strategies is not empty"))
1144 }
1145
    /// Objective-aware comparator: orders `a` before `b` when `a` is better
    /// under the configured objective (ascending for minimization targets,
    /// descending for maximization targets). Incomparable floats tie.
    fn compare_strategies(
        &self,
        a: &ParallelismStrategy,
        b: &ParallelismStrategy,
    ) -> Result<std::cmp::Ordering> {
        match &self.config.optimization_objective {
            OptimizationObjective::MinimizeTime => {
                Ok(a.expected_performance.time_per_step.cmp(&b.expected_performance.time_per_step))
            },
            OptimizationObjective::MinimizeMemory => Ok(a
                .expected_performance
                .memory_per_device
                .cmp(&b.expected_performance.memory_per_device)),
            OptimizationObjective::MinimizeCommunication => Ok(a
                .expected_performance
                .communication_overhead
                .partial_cmp(&b.expected_performance.communication_overhead)
                .unwrap_or(std::cmp::Ordering::Equal)),
            // Note the reversed operands: larger throughput sorts first.
            OptimizationObjective::MaximizeThroughput => Ok(b
                .expected_performance
                .throughput
                .partial_cmp(&a.expected_performance.throughput)
                .unwrap_or(std::cmp::Ordering::Equal)),
            // Reversed operands: larger efficiency sorts first.
            OptimizationObjective::MaximizeEfficiency => Ok(b
                .expected_performance
                .efficiency
                .partial_cmp(&a.expected_performance.efficiency)
                .unwrap_or(std::cmp::Ordering::Equal)),
            OptimizationObjective::MultiObjective(_objectives) => {
                // Blend all criteria into one score; higher score sorts first.
                let score_a = self.calculate_multi_objective_score(a);
                let score_b = self.calculate_multi_objective_score(b);
                Ok(score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal))
            },
        }
    }
1183
1184 fn calculate_multi_objective_score(&self, strategy: &ParallelismStrategy) -> f32 {
1186 let time_score = 1.0 / (strategy.expected_performance.time_per_step.as_secs_f32() + 1e-6);
1188 let memory_score =
1189 1.0 / (strategy.expected_performance.memory_per_device as f32 / 1e9 + 1e-6);
1190 let comm_score = 1.0 / (strategy.expected_performance.communication_overhead + 1e-6);
1191 let throughput_score = strategy.expected_performance.throughput as f32;
1192 let efficiency_score = strategy.expected_performance.efficiency;
1193
1194 (time_score + memory_score + comm_score + throughput_score + efficiency_score) / 5.0
1195 }
1196
1197 fn estimate_performance_data_parallel(&self) -> Result<PerformanceMetrics> {
1199 let model = &self.config.model_constraints;
1200 let hardware = &self.config.hardware_constraints;
1201
1202 let params_per_device = model.num_parameters * 4; let memory_per_device = params_per_device + 2 * params_per_device; let compute_time = (model.num_parameters as f64 * 2.0) / hardware.compute_per_device; let communication_time =
1208 (params_per_device as f64) / hardware.inter_device_bandwidth as f64;
1209 let total_time = compute_time + communication_time;
1210
1211 Ok(PerformanceMetrics {
1212 time_per_step: Duration::from_secs_f64(total_time),
1213 memory_per_device,
1214 communication_overhead: communication_time as f32 / total_time as f32,
1215 throughput: 1.0 / total_time,
1216 efficiency: 0.8,
1217 scalability: 0.9,
1218 })
1219 }
1220
1221 fn estimate_performance_3d_parallel(
1223 &self,
1224 dp_size: usize,
1225 mp_size: usize,
1226 _pp_size: usize,
1227 ) -> Result<PerformanceMetrics> {
1228 let model = &self.config.model_constraints;
1229 let hardware = &self.config.hardware_constraints;
1230
1231 let params_per_device = model.num_parameters / (mp_size as u64);
1233 let memory_per_device = params_per_device * 4 + 2 * params_per_device;
1234
1235 let compute_time = (params_per_device as f64 * 2.0) / hardware.compute_per_device;
1236 let pipeline_bubble = 0.1; let communication_time = compute_time * 0.2; let total_time = compute_time * (1.0 + pipeline_bubble) + communication_time;
1239
1240 Ok(PerformanceMetrics {
1241 time_per_step: Duration::from_secs_f64(total_time),
1242 memory_per_device,
1243 communication_overhead: communication_time as f32 / total_time as f32,
1244 throughput: dp_size as f64 / total_time,
1245 efficiency: 0.85,
1246 scalability: 0.95,
1247 })
1248 }
1249
1250 fn estimate_performance_expert_parallel(&self) -> Result<PerformanceMetrics> {
1252 let model = &self.config.model_constraints;
1253 let hardware = &self.config.hardware_constraints;
1254
1255 let experts_per_device = model.num_experts.unwrap_or(8) / hardware.num_devices;
1256 let params_per_expert = model.num_parameters / model.num_experts.unwrap_or(8) as u64;
1257 let memory_per_device = params_per_expert * experts_per_device as u64 * 4;
1258
1259 let compute_time = (params_per_expert as f64 * 2.0) / hardware.compute_per_device;
1260 let routing_overhead = 0.1; let communication_time = compute_time * 0.15; let total_time = compute_time * (1.0 + routing_overhead) + communication_time;
1263
1264 Ok(PerformanceMetrics {
1265 time_per_step: Duration::from_secs_f64(total_time),
1266 memory_per_device,
1267 communication_overhead: communication_time as f32 / total_time as f32,
1268 throughput: 1.0 / total_time,
1269 efficiency: 0.9,
1270 scalability: 0.95,
1271 })
1272 }
1273
1274 fn estimate_performance_sequence_parallel(&self) -> Result<PerformanceMetrics> {
1276 let model = &self.config.model_constraints;
1277 let hardware = &self.config.hardware_constraints;
1278
1279 let seq_per_device = model.max_sequence_length / hardware.num_devices;
1280 let memory_per_device = (seq_per_device * model.hidden_size * 4) as u64;
1281
1282 let compute_time = (model.num_parameters as f64 * 2.0) / hardware.compute_per_device;
1283 let attention_comm_overhead = 0.2; let total_time = compute_time * (1.0 + attention_comm_overhead);
1285
1286 Ok(PerformanceMetrics {
1287 time_per_step: Duration::from_secs_f64(total_time),
1288 memory_per_device,
1289 communication_overhead: attention_comm_overhead as f32,
1290 throughput: 1.0 / total_time,
1291 efficiency: 0.8,
1292 scalability: 0.85,
1293 })
1294 }
1295
1296 fn estimate_performance_tensor_parallel(&self) -> Result<PerformanceMetrics> {
1298 let model = &self.config.model_constraints;
1299 let hardware = &self.config.hardware_constraints;
1300
1301 let params_per_device = model.num_parameters / hardware.num_devices as u64;
1302 let memory_per_device = params_per_device * 4;
1303
1304 let compute_time = (params_per_device as f64 * 2.0) / hardware.compute_per_device;
1305 let tensor_comm_overhead = 0.25; let total_time = compute_time * (1.0 + tensor_comm_overhead);
1307
1308 Ok(PerformanceMetrics {
1309 time_per_step: Duration::from_secs_f64(total_time),
1310 memory_per_device,
1311 communication_overhead: tensor_comm_overhead as f32,
1312 throughput: 1.0 / total_time,
1313 efficiency: 0.75,
1314 scalability: 0.8,
1315 })
1316 }
1317
1318 fn estimate_performance_hybrid(&self) -> Result<PerformanceMetrics> {
1320 let base_metrics = self.estimate_performance_3d_parallel(2, 2, 2)?;
1322
1323 Ok(PerformanceMetrics {
1324 time_per_step: base_metrics.time_per_step,
1325 memory_per_device: base_metrics.memory_per_device / 2, communication_overhead: base_metrics.communication_overhead * 1.1, throughput: base_metrics.throughput * 0.95, efficiency: 0.9,
1329 scalability: 0.95,
1330 })
1331 }
1332
1333 fn refine_performance_estimate(
1335 &self,
1336 strategy: &ParallelismStrategy,
1337 ) -> Result<PerformanceMetrics> {
1338 Ok(strategy.expected_performance.clone())
1341 }
1342
1343 fn calculate_confidence(&self, strategy: &ParallelismStrategy) -> f32 {
1345 let mut confidence: f32 = 0.5;
1347
1348 if strategy.strategy_id.contains("data_parallel") {
1350 confidence += 0.3;
1351 }
1352 if strategy.strategy_id.contains("3d_parallel") {
1353 confidence += 0.2;
1354 }
1355
1356 if strategy.strategy_id.contains("hybrid") {
1358 confidence -= 0.1;
1359 }
1360
1361 confidence.clamp(0.0, 1.0)
1362 }
1363
    /// Returns the most recently selected parallelism strategy, or `None`
    /// if no selection has been made yet.
    pub fn current_strategy(&self) -> Option<&ParallelismStrategy> {
        self.current_strategy.as_ref()
    }
1368
1369 pub fn update_performance_history(&mut self, actual_performance: PerformanceMetrics) {
1371 if let Some(current_strategy) = &self.current_strategy {
1372 self.performance_history.push((current_strategy.clone(), actual_performance));
1373
1374 if self.performance_history.len() > 100 {
1376 self.performance_history.remove(0);
1377 }
1378 }
1379 }
1380
    /// Returns a reference to the selector's configuration.
    pub fn config(&self) -> &AutoParallelismConfig {
        &self.config
    }
1385}
1386
1387pub mod utils {
1389 use super::*;
1390
1391 pub fn estimate_model_memory(constraints: &ModelConstraints) -> u64 {
1393 let param_memory = constraints.num_parameters * 4; let gradient_memory = param_memory; let optimizer_memory = param_memory * 2; param_memory + gradient_memory + optimizer_memory
1398 }
1399
1400 pub fn meets_requirements(
1402 strategy: &ParallelismStrategy,
1403 requirements: &PerformanceRequirements,
1404 ) -> bool {
1405 if let Some(max_time) = requirements.max_training_time {
1406 if strategy.expected_performance.time_per_step > max_time {
1407 return false;
1408 }
1409 }
1410
1411 if let Some(min_throughput) = requirements.min_throughput {
1412 if strategy.expected_performance.throughput < min_throughput {
1413 return false;
1414 }
1415 }
1416
1417 if let Some(max_memory) = requirements.max_memory_per_device {
1418 if strategy.expected_performance.memory_per_device > max_memory {
1419 return false;
1420 }
1421 }
1422
1423 if let Some(max_comm_overhead) = requirements.max_communication_overhead {
1424 if strategy.expected_performance.communication_overhead > max_comm_overhead {
1425 return false;
1426 }
1427 }
1428
1429 if let Some(min_efficiency) = requirements.min_efficiency {
1430 if strategy.expected_performance.efficiency < min_efficiency {
1431 return false;
1432 }
1433 }
1434
1435 true
1436 }
1437
1438 pub fn detect_hardware_constraints() -> Result<HardwareConstraints> {
1440 Ok(HardwareConstraints::default())
1442 }
1443
1444 pub fn analyze_model_constraints<M: Model>(_model: &M) -> Result<ModelConstraints> {
1446 Ok(ModelConstraints::default())
1448 }
1449}
1450
#[cfg(test)]
mod tests {
    use super::*;

    // Default config should be enabled and assume the default 8-device setup
    // (pinned against HardwareConstraints::default()).
    #[test]
    fn test_auto_parallelism_config() {
        let config = AutoParallelismConfig::default();
        assert!(config.enabled);
        assert_eq!(config.hardware_constraints.num_devices, 8);
    }

    // A freshly constructed selector has no strategy chosen yet.
    #[test]
    fn test_auto_parallelism_selector_creation() {
        let config = AutoParallelismConfig::default();
        let selector = AutoParallelismSelector::new(config);
        assert!(selector.current_strategy.is_none());
    }

    // Selecting a strategy succeeds with defaults and records the choice.
    #[test]
    fn test_strategy_selection() {
        let config = AutoParallelismConfig::default();
        let mut selector = AutoParallelismSelector::new(config);

        let strategy = selector.select_strategy();
        assert!(strategy.is_ok());
        assert!(selector.current_strategy.is_some());
    }

    // The rule-based generator must produce at least one candidate strategy.
    #[test]
    fn test_rule_based_strategy_generation() {
        let config = AutoParallelismConfig {
            selection_algorithm: SelectionAlgorithm::RuleBased,
            ..Default::default()
        };
        let selector = AutoParallelismSelector::new(config);

        let strategies = selector.generate_rule_based_strategies();
        assert!(strategies.is_ok());
        assert!(!strategies.expect("operation failed in test").is_empty());
    }

    // The data-parallel estimator must yield positive time and memory figures.
    #[test]
    fn test_performance_estimation() {
        let config = AutoParallelismConfig::default();
        let selector = AutoParallelismSelector::new(config);

        let metrics = selector.estimate_performance_data_parallel();
        assert!(metrics.is_ok());

        let metrics = metrics.expect("operation failed in test");
        assert!(metrics.time_per_step.as_secs_f64() > 0.0);
        assert!(metrics.memory_per_device > 0);
    }

    // Under MinimizeTime, the faster strategy (1s vs 2s per step) orders
    // strictly before the slower one.
    #[test]
    fn test_strategy_comparison() {
        let config = AutoParallelismConfig {
            optimization_objective: OptimizationObjective::MinimizeTime,
            ..Default::default()
        };
        let selector = AutoParallelismSelector::new(config);

        // Faster but more memory-hungry candidate.
        let strategy1 = ParallelismStrategy {
            strategy_id: "test1".to_string(),
            data_parallel: None,
            parallelism_3d: None,
            expert_parallel: None,
            sequence_parallel: None,
            tensor_parallel: None,
            expected_performance: PerformanceMetrics {
                time_per_step: Duration::from_secs(1),
                memory_per_device: 1000,
                communication_overhead: 0.1,
                throughput: 1.0,
                efficiency: 0.8,
                scalability: 0.9,
            },
            confidence: 0.8,
            rationale: "Test strategy 1".to_string(),
        };

        // Slower but leaner candidate.
        let strategy2 = ParallelismStrategy {
            strategy_id: "test2".to_string(),
            data_parallel: None,
            parallelism_3d: None,
            expert_parallel: None,
            sequence_parallel: None,
            tensor_parallel: None,
            expected_performance: PerformanceMetrics {
                time_per_step: Duration::from_secs(2),
                memory_per_device: 800,
                communication_overhead: 0.05,
                throughput: 0.5,
                efficiency: 0.9,
                scalability: 0.85,
            },
            confidence: 0.9,
            rationale: "Test strategy 2".to_string(),
        };

        let comparison = selector.compare_strategies(&strategy1, &strategy2);
        assert!(comparison.is_ok());
        // MinimizeTime compares step times: 1s < 2s => Ordering::Less.
        assert_eq!(
            comparison.expect("operation failed in test"),
            std::cmp::Ordering::Less
        );
    }

    // 1M params: 4 bytes params + 4 bytes gradients + 8 bytes optimizer
    // state = 16 bytes/param => 16 MB total.
    #[test]
    fn test_memory_estimation() {
        let constraints = ModelConstraints {
            num_parameters: 1_000_000,
            ..Default::default()
        };

        let memory = utils::estimate_model_memory(&constraints);
        assert_eq!(memory, 16_000_000);
    }

    // A strategy inside all limits passes; tightening one limit below the
    // strategy's expected value (time: 500ms < 1s) makes it fail.
    #[test]
    fn test_requirements_checking() {
        let strategy = ParallelismStrategy {
            strategy_id: "test".to_string(),
            data_parallel: None,
            parallelism_3d: None,
            expert_parallel: None,
            sequence_parallel: None,
            tensor_parallel: None,
            expected_performance: PerformanceMetrics {
                time_per_step: Duration::from_secs(1),
                memory_per_device: 1000,
                communication_overhead: 0.2,
                throughput: 2.0,
                efficiency: 0.8,
                scalability: 0.9,
            },
            confidence: 0.8,
            rationale: "Test strategy".to_string(),
        };

        // Every limit comfortably above/below the strategy's metrics.
        let requirements = PerformanceRequirements {
            max_training_time: Some(Duration::from_secs(2)),
            min_throughput: Some(1.0),
            max_memory_per_device: Some(2000),
            max_communication_overhead: Some(0.3),
            min_efficiency: Some(0.7),
        };

        assert!(utils::meets_requirements(&strategy, &requirements));

        // Same limits but an unmeetable time budget.
        let strict_requirements = PerformanceRequirements {
            max_training_time: Some(Duration::from_millis(500)),
            ..requirements
        };

        assert!(!utils::meets_requirements(&strategy, &strict_requirements));
    }
}