1use crate::distributed::ProcessGroup;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex, RwLock};
6use std::time::{Duration, Instant};
7use trustformers_core::tensor::Tensor;
8
/// Configuration describing how a mixture-of-experts (MoE) layer is
/// sharded and routed across devices.
///
/// Invariant (enforced by `ExpertParallelism::new`):
/// `experts_per_device * expert_parallel_size == num_experts`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertParallelismConfig {
    /// Total number of experts in the MoE layer.
    pub num_experts: usize,
    /// Number of experts hosted on each device.
    pub experts_per_device: usize,
    /// Number of devices participating in expert parallelism.
    pub expert_parallel_size: usize,
    /// Number of experts each token is routed to (top-k selection).
    pub top_k: usize,
    /// Strategy used to balance token load across experts.
    /// NOTE(review): stored but not dispatched on in this file — confirm wiring.
    pub load_balancing: LoadBalancingStrategy,
    /// Strategy used to pick the expert(s) for each token
    /// (dispatched in `ExpertParallelism::route_tokens`).
    pub routing_strategy: ExpertRoutingStrategy,
    /// Headroom multiplier over an even token split per expert.
    /// NOTE(review): not read by the routing code in this file — confirm usage.
    pub capacity_factor: f32,
    /// Whether tokens exceeding expert capacity are dropped.
    /// NOTE(review): not read by the routing code in this file.
    pub drop_tokens: bool,
    /// Whether to add an auxiliary load-balancing loss term.
    /// NOTE(review): not read in this file — presumably consumed by training code.
    pub use_auxiliary_loss: bool,
    /// Weight of the auxiliary load-balancing loss.
    pub auxiliary_loss_weight: f32,
    /// Communication topology used to exchange tokens between devices.
    pub communication_pattern: ExpertCommunicationPattern,
}
38
39impl Default for ExpertParallelismConfig {
40 fn default() -> Self {
41 Self {
42 num_experts: 8,
43 experts_per_device: 2,
44 expert_parallel_size: 4,
45 top_k: 2,
46 load_balancing: LoadBalancingStrategy::TokenChoiceBased,
47 routing_strategy: ExpertRoutingStrategy::LearnedGating,
48 capacity_factor: 1.25,
49 drop_tokens: false,
50 use_auxiliary_loss: true,
51 auxiliary_loss_weight: 0.01,
52 communication_pattern: ExpertCommunicationPattern::AllToAll,
53 }
54 }
55}
56
/// Strategies for balancing token load across experts.
///
/// NOTE(review): these variants are configuration-only in this file — no
/// strategy-specific logic dispatches on them here; descriptions are
/// inferred from the names and should be confirmed against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    /// Each token selects its preferred expert(s).
    TokenChoiceBased,
    /// Each expert selects its preferred tokens.
    ExpertChoiceBased,
    /// Adapts the balancing behavior at runtime.
    Dynamic,
    /// Cycles tokens through experts in a fixed order.
    RoundRobin,
    /// Considers current expert load when assigning tokens.
    LoadAware,
}
71
/// Strategies for choosing which expert(s) receive each token.
///
/// Dispatched in `ExpertParallelism::route_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExpertRoutingStrategy {
    /// Gating-network scores with top-k selection
    /// (currently a placeholder implementation in this file).
    LearnedGating,
    /// Deterministic assignment by token position (`idx % num_experts`).
    HashBased,
    /// Pseudo-random assignment derived from hashing the token index.
    Random,
    /// Greedy assignment to the least-loaded expert.
    LoadBased,
    /// Similarity-driven routing (currently delegates to learned gating).
    SimilarityBased,
}
86
/// Communication topologies for exchanging tokens between expert devices.
///
/// NOTE(review): only `AllToAll` has a (placeholder) code path in this
/// file; the remaining variants are configuration-only here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExpertCommunicationPattern {
    /// Every device exchanges tokens with every other device.
    AllToAll,
    /// Direct sends between specific device pairs.
    PointToPoint,
    /// Staged exchange through a device hierarchy.
    Hierarchical,
    /// Tokens circulate around a ring of devices.
    Ring,
}
99
/// Placement record mapping one global expert to its hosting device.
#[derive(Debug, Clone)]
pub struct ExpertAssignment {
    /// Global expert id (`0..num_experts`).
    pub expert_id: usize,
    /// Rank of the device hosting this expert.
    pub device_rank: usize,
    /// Index of the expert within its device (`0..experts_per_device`).
    pub local_expert_id: usize,
    /// Relative load weight; initialized to 1.0 for every expert.
    pub load_weight: f32,
}
112
/// Result of routing a batch of tokens to experts.
#[derive(Debug, Clone)]
pub struct TokenRouting {
    /// Indices of the routed tokens (`0..batch_size`).
    pub token_indices: Vec<usize>,
    /// Per-token list of `(expert_id, weight)` assignments,
    /// indexed in parallel with `token_indices`.
    pub expert_assignments: Vec<Vec<(usize, f32)>>,
    /// Destination device rank -> token indices bound for that device.
    pub destinations: HashMap<usize, Vec<usize>>,
    /// Expert id -> number of tokens consumed against its capacity.
    /// NOTE(review): never populated by the routing code in this file —
    /// confirm whether capacity accounting is implemented elsewhere.
    pub capacity_usage: HashMap<usize, usize>,
}
125
/// Coordinates expert placement, token routing, and load balancing for a
/// mixture-of-experts layer distributed across devices.
#[allow(dead_code)]
pub struct ExpertParallelism {
    // Expert-parallel configuration, validated at construction.
    config: ExpertParallelismConfig,
    // This process's global rank.
    #[allow(dead_code)]
    global_rank: usize,
    // Total number of participating ranks.
    world_size: usize,

    // Placement of every global expert (index == expert id).
    expert_assignments: Vec<ExpertAssignment>,
    // Global ids of the experts hosted on this rank.
    local_experts: Vec<usize>,
    // Process group used for expert communication.
    // NOTE(review): not exercised by the placeholder all-to-all below.
    expert_group: Arc<dyn ProcessGroup>,

    // Mutable load-balancing statistics, shared across threads.
    load_balancing_state: Arc<RwLock<LoadBalancingState>>,

    // Timing and throughput counters for routing/communication.
    communication_stats: Arc<Mutex<ExpertCommunicationStats>>,

    // Cache of computed routings keyed by string.
    // NOTE(review): never read or written in this file — confirm intent.
    routing_cache: Arc<Mutex<HashMap<String, TokenRouting>>>,
}
150
/// Internal snapshot of per-expert load used for balancing decisions.
#[derive(Debug, Default)]
#[allow(dead_code)]
struct LoadBalancingState {
    // Most recent load per expert (token count as f32).
    expert_loads: HashMap<usize, f32>,
    // Per-expert utilization ratio.
    // NOTE(review): never updated in this file.
    #[allow(dead_code)]
    expert_utilization: HashMap<usize, f32>,
    // Tokens sent to each expert.
    // NOTE(review): never updated in this file.
    token_distribution: HashMap<usize, usize>,
    // Coefficient of variation of expert loads (0 == perfectly balanced).
    imbalance_score: f32,
    // When the statistics were last refreshed.
    last_rebalance_time: Option<Instant>,
}
162
/// Timing and throughput counters for expert communication.
#[derive(Debug, Default)]
#[allow(dead_code)]
struct ExpertCommunicationStats {
    // Cumulative time spent in all-to-all exchanges.
    all_to_all_time: Duration,
    // Cumulative time spent in point-to-point sends.
    // NOTE(review): never updated in this file.
    #[allow(dead_code)]
    point_to_point_time: Duration,
    // Total number of tokens routed so far.
    total_tokens_routed: u64,
    // Variance of per-expert load.
    // NOTE(review): never updated in this file.
    expert_load_variance: f32,
    // Ratio of useful to total communication.
    // NOTE(review): never updated in this file.
    communication_efficiency: f32,
    // Cumulative time spent computing routings.
    routing_overhead: Duration,
}
175
176impl ExpertParallelism {
177 pub fn new(
179 config: ExpertParallelismConfig,
180 global_rank: usize,
181 world_size: usize,
182 expert_group: Arc<dyn ProcessGroup>,
183 ) -> Result<Self> {
184 if config.num_experts % config.expert_parallel_size != 0 {
186 return Err(anyhow!(
187 "Number of experts ({}) must be divisible by expert parallel size ({})",
188 config.num_experts,
189 config.expert_parallel_size
190 ));
191 }
192
193 if config.experts_per_device * config.expert_parallel_size != config.num_experts {
194 return Err(anyhow!(
195 "Expert assignment mismatch: experts_per_device ({}) * expert_parallel_size ({}) != num_experts ({})",
196 config.experts_per_device, config.expert_parallel_size, config.num_experts
197 ));
198 }
199
200 let expert_assignments = Self::create_expert_assignments(&config, world_size)?;
202 let local_experts = Self::get_local_experts(&expert_assignments, global_rank);
203
204 Ok(Self {
205 config,
206 global_rank,
207 world_size,
208 expert_assignments,
209 local_experts,
210 expert_group,
211 load_balancing_state: Arc::new(RwLock::new(LoadBalancingState::default())),
212 communication_stats: Arc::new(Mutex::new(ExpertCommunicationStats::default())),
213 routing_cache: Arc::new(Mutex::new(HashMap::new())),
214 })
215 }
216
217 fn create_expert_assignments(
219 config: &ExpertParallelismConfig,
220 _world_size: usize,
221 ) -> Result<Vec<ExpertAssignment>> {
222 let mut assignments = Vec::new();
223
224 for expert_id in 0..config.num_experts {
225 let device_rank = expert_id / config.experts_per_device;
226 let local_expert_id = expert_id % config.experts_per_device;
227
228 assignments.push(ExpertAssignment {
229 expert_id,
230 device_rank,
231 local_expert_id,
232 load_weight: 1.0, });
234 }
235
236 Ok(assignments)
237 }
238
239 fn get_local_experts(assignments: &[ExpertAssignment], device_rank: usize) -> Vec<usize> {
241 assignments
242 .iter()
243 .filter(|assignment| assignment.device_rank == device_rank)
244 .map(|assignment| assignment.expert_id)
245 .collect()
246 }
247
248 pub fn route_tokens(&self, tokens: &Tensor, gating_scores: &Tensor) -> Result<TokenRouting> {
250 let start_time = Instant::now();
251
252 let routing = match self.config.routing_strategy {
254 ExpertRoutingStrategy::LearnedGating => {
255 self.learned_gating_routing(tokens, gating_scores)?
256 },
257 ExpertRoutingStrategy::HashBased => self.hash_based_routing(tokens)?,
258 ExpertRoutingStrategy::Random => self.random_routing(tokens)?,
259 ExpertRoutingStrategy::LoadBased => self.load_based_routing(tokens, gating_scores)?,
260 ExpertRoutingStrategy::SimilarityBased => {
261 self.similarity_based_routing(tokens, gating_scores)?
262 },
263 };
264
265 {
267 let mut stats = self
268 .communication_stats
269 .lock()
270 .expect("communication_stats lock should not be poisoned");
271 stats.routing_overhead += start_time.elapsed();
272 stats.total_tokens_routed += tokens.shape()[0] as u64;
273 }
274
275 Ok(routing)
276 }
277
278 fn learned_gating_routing(
280 &self,
281 tokens: &Tensor,
282 _gating_scores: &Tensor,
283 ) -> Result<TokenRouting> {
284 let batch_size = tokens.shape()[0];
285 let num_tokens = batch_size;
286
287 let mut token_routing = TokenRouting {
288 token_indices: (0..num_tokens).collect(),
289 expert_assignments: Vec::new(),
290 destinations: HashMap::new(),
291 capacity_usage: HashMap::new(),
292 };
293
294 for token_idx in 0..num_tokens {
296 let mut expert_scores = Vec::new();
297
298 for expert_id in 0..self.config.num_experts {
300 let score = 1.0 / (expert_id + 1) as f32; expert_scores.push((expert_id, score));
302 }
303
304 expert_scores
306 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
307 expert_scores.truncate(self.config.top_k);
308
309 let total_weight: f32 = expert_scores.iter().map(|(_, w)| w).sum();
311 let normalized_assignments: Vec<(usize, f32)> = expert_scores
312 .iter()
313 .map(|(expert_id, weight)| (*expert_id, weight / total_weight))
314 .collect();
315
316 token_routing.expert_assignments.push(normalized_assignments.clone());
317
318 for (expert_id, _) in normalized_assignments {
320 let device_rank = self.expert_assignments[expert_id].device_rank;
321 token_routing.destinations.entry(device_rank).or_default().push(token_idx);
322 }
323 }
324
325 Ok(token_routing)
326 }
327
328 fn hash_based_routing(&self, tokens: &Tensor) -> Result<TokenRouting> {
330 let batch_size = tokens.shape()[0];
331 let mut token_routing = TokenRouting {
332 token_indices: (0..batch_size).collect(),
333 expert_assignments: Vec::new(),
334 destinations: HashMap::new(),
335 capacity_usage: HashMap::new(),
336 };
337
338 for token_idx in 0..batch_size {
339 let expert_id = token_idx % self.config.num_experts;
341 let device_rank = self.expert_assignments[expert_id].device_rank;
342
343 token_routing.expert_assignments.push(vec![(expert_id, 1.0)]);
344 token_routing.destinations.entry(device_rank).or_default().push(token_idx);
345 }
346
347 Ok(token_routing)
348 }
349
350 fn random_routing(&self, tokens: &Tensor) -> Result<TokenRouting> {
352 let batch_size = tokens.shape()[0];
353 let mut token_routing = TokenRouting {
354 token_indices: (0..batch_size).collect(),
355 expert_assignments: Vec::new(),
356 destinations: HashMap::new(),
357 capacity_usage: HashMap::new(),
358 };
359
360 use std::collections::hash_map::DefaultHasher;
361 use std::hash::{Hash, Hasher};
362
363 for token_idx in 0..batch_size {
364 let mut hasher = DefaultHasher::new();
365 token_idx.hash(&mut hasher);
366 let expert_id = (hasher.finish() as usize) % self.config.num_experts;
367 let device_rank = self.expert_assignments[expert_id].device_rank;
368
369 token_routing.expert_assignments.push(vec![(expert_id, 1.0)]);
370 token_routing.destinations.entry(device_rank).or_default().push(token_idx);
371 }
372
373 Ok(token_routing)
374 }
375
376 fn load_based_routing(&self, tokens: &Tensor, _gating_scores: &Tensor) -> Result<TokenRouting> {
378 let batch_size = tokens.shape()[0];
379 let mut token_routing = TokenRouting {
380 token_indices: (0..batch_size).collect(),
381 expert_assignments: Vec::new(),
382 destinations: HashMap::new(),
383 capacity_usage: HashMap::new(),
384 };
385
386 let load_state = self
387 .load_balancing_state
388 .read()
389 .expect("load_balancing_state lock should not be poisoned");
390
391 for token_idx in 0..batch_size {
392 let mut min_load = f32::INFINITY;
394 let mut selected_expert = 0;
395
396 for expert_id in 0..self.config.num_experts {
397 let load = load_state.expert_loads.get(&expert_id).unwrap_or(&0.0);
398 if *load < min_load {
399 min_load = *load;
400 selected_expert = expert_id;
401 }
402 }
403
404 let device_rank = self.expert_assignments[selected_expert].device_rank;
405 token_routing.expert_assignments.push(vec![(selected_expert, 1.0)]);
406 token_routing.destinations.entry(device_rank).or_default().push(token_idx);
407 }
408
409 Ok(token_routing)
410 }
411
412 fn similarity_based_routing(
414 &self,
415 tokens: &Tensor,
416 gating_scores: &Tensor,
417 ) -> Result<TokenRouting> {
418 self.learned_gating_routing(tokens, gating_scores)
420 }
421
422 pub fn all_to_all_communication(
424 &self,
425 local_tokens: &Tensor,
426 routing: &TokenRouting,
427 ) -> Result<HashMap<usize, Tensor>> {
428 let start_time = Instant::now();
429
430 let mut expert_inputs = HashMap::new();
433
434 for expert_id in &self.local_experts {
435 let mut expert_tokens = Vec::new();
437
438 for (token_idx, assignments) in routing.expert_assignments.iter().enumerate() {
439 for (assigned_expert_id, weight) in assignments {
440 if *assigned_expert_id == *expert_id && *weight > 0.0 {
441 expert_tokens.push(token_idx);
443 }
444 }
445 }
446
447 if !expert_tokens.is_empty() {
449 let expert_tensor = Tensor::zeros(&[expert_tokens.len(), local_tokens.shape()[1]])?;
450 expert_inputs.insert(*expert_id, expert_tensor);
451 }
452 }
453
454 {
456 let mut stats = self
457 .communication_stats
458 .lock()
459 .expect("communication_stats lock should not be poisoned");
460 stats.all_to_all_time += start_time.elapsed();
461 }
462
463 Ok(expert_inputs)
464 }
465
466 pub fn update_load_balancing(&self, expert_outputs: &HashMap<usize, Tensor>) -> Result<()> {
468 let mut load_state = self
469 .load_balancing_state
470 .write()
471 .expect("load_balancing_state lock should not be poisoned");
472
473 for (expert_id, output) in expert_outputs {
475 let load = output.shape()[0] as f32; load_state.expert_loads.insert(*expert_id, load);
477 }
478
479 let total_load: f32 = load_state.expert_loads.values().sum();
481 let avg_load = total_load / self.config.num_experts as f32;
482
483 let mut variance = 0.0;
484 for load in load_state.expert_loads.values() {
485 variance += (load - avg_load).powi(2);
486 }
487 variance /= self.config.num_experts as f32;
488
489 load_state.imbalance_score = variance.sqrt() / avg_load.max(1e-6);
490 load_state.last_rebalance_time = Some(Instant::now());
491
492 Ok(())
493 }
494
495 pub fn get_load_balancing_stats(&self) -> LoadBalancingStats {
497 let load_state = self
498 .load_balancing_state
499 .read()
500 .expect("load_balancing_state lock should not be poisoned");
501 let comm_stats = self.communication_stats.lock().expect("lock should not be poisoned");
502
503 LoadBalancingStats {
504 expert_loads: load_state.expert_loads.clone(),
505 imbalance_score: load_state.imbalance_score,
506 total_tokens_routed: comm_stats.total_tokens_routed,
507 communication_efficiency: comm_stats.communication_efficiency,
508 routing_overhead: comm_stats.routing_overhead,
509 }
510 }
511
512 pub fn local_experts(&self) -> &[usize] {
514 &self.local_experts
515 }
516
517 pub fn get_expert_assignment(&self, expert_id: usize) -> Option<&ExpertAssignment> {
519 self.expert_assignments.get(expert_id)
520 }
521
522 pub fn config(&self) -> &ExpertParallelismConfig {
524 &self.config
525 }
526}
527
/// Public snapshot of load-balancing and routing statistics.
#[derive(Debug, Clone)]
pub struct LoadBalancingStats {
    /// Most recent load per expert id (token counts as f32).
    pub expert_loads: HashMap<usize, f32>,
    /// Coefficient of variation of expert loads (0 == perfectly balanced).
    pub imbalance_score: f32,
    /// Total tokens routed since construction.
    pub total_tokens_routed: u64,
    /// Ratio of useful to total communication.
    /// NOTE(review): never computed in this file; always the default value.
    pub communication_efficiency: f32,
    /// Cumulative time spent computing routings.
    pub routing_overhead: Duration,
}
537
538pub mod utils {
540 use super::*;
541
542 pub fn calculate_optimal_expert_config(
544 num_experts: usize,
545 world_size: usize,
546 memory_per_expert_mb: usize,
547 available_memory_mb: usize,
548 ) -> Result<ExpertParallelismConfig> {
549 let experts_per_device = std::cmp::min(
550 available_memory_mb / memory_per_expert_mb,
551 num_experts / world_size,
552 );
553
554 if experts_per_device == 0 {
555 return Err(anyhow!("Insufficient memory for expert parallelism"));
556 }
557
558 let expert_parallel_size = num_experts.div_ceil(experts_per_device);
559
560 Ok(ExpertParallelismConfig {
561 num_experts,
562 experts_per_device,
563 expert_parallel_size,
564 ..Default::default()
565 })
566 }
567
568 pub fn estimate_communication_cost(
570 config: &ExpertParallelismConfig,
571 batch_size: usize,
572 sequence_length: usize,
573 hidden_size: usize,
574 ) -> f32 {
575 let total_tokens = batch_size * sequence_length;
576 let tokens_per_expert = total_tokens / config.num_experts;
577 let communication_volume = tokens_per_expert * hidden_size * 4; communication_volume as f32 / (1024.0 * 1024.0) }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::SimulatedProcessGroup;
    use std::sync::Arc;

    // Default config must describe a consistent layout:
    // 2 experts/device * 4 devices == 8 experts.
    #[test]
    fn test_expert_parallelism_config() {
        let config = ExpertParallelismConfig::default();
        assert_eq!(config.num_experts, 8);
        assert_eq!(config.expert_parallel_size, 4);
        assert_eq!(config.experts_per_device, 2);
    }

    // Contiguous placement: expert e -> rank e/2, local id e%2.
    #[test]
    fn test_expert_assignment_creation() {
        let config = ExpertParallelismConfig {
            num_experts: 8,
            experts_per_device: 2,
            expert_parallel_size: 4,
            ..Default::default()
        };

        let assignments = ExpertParallelism::create_expert_assignments(&config, 4)
            .expect("operation failed in test");
        assert_eq!(assignments.len(), 8);

        for (i, assignment) in assignments.iter().enumerate() {
            assert_eq!(assignment.expert_id, i);
            assert_eq!(assignment.device_rank, i / 2);
            assert_eq!(assignment.local_expert_id, i % 2);
        }
    }

    // Rank 1 hosts the second contiguous pair of experts: ids 2 and 3.
    #[test]
    fn test_local_experts() {
        let config = ExpertParallelismConfig {
            num_experts: 8,
            experts_per_device: 2,
            expert_parallel_size: 4,
            ..Default::default()
        };

        let assignments = ExpertParallelism::create_expert_assignments(&config, 4)
            .expect("operation failed in test");
        let local_experts = ExpertParallelism::get_local_experts(&assignments, 1);

        assert_eq!(local_experts, vec![2, 3]);
    }

    // Construction succeeds for a valid layout; rank 0 hosts experts 0, 1.
    #[test]
    fn test_expert_parallelism_creation() {
        let config = ExpertParallelismConfig {
            num_experts: 8,
            experts_per_device: 2,
            expert_parallel_size: 4,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 4));
        let expert_parallelism = ExpertParallelism::new(config, 0, 4, process_group);

        assert!(expert_parallelism.is_ok());
        let ep = expert_parallelism.expect("operation failed in test");
        assert_eq!(ep.local_experts(), &[0, 1]);
    }

    // Position-based routing assigns exactly one expert per token.
    #[test]
    fn test_hash_based_routing() {
        let config = ExpertParallelismConfig {
            num_experts: 4,
            experts_per_device: 1,
            expert_parallel_size: 4,
            routing_strategy: ExpertRoutingStrategy::HashBased,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 4));
        let expert_parallelism =
            ExpertParallelism::new(config, 0, 4, process_group).expect("operation failed in test");

        // 8 tokens with hidden size 16 (zeros suffice: routing ignores values).
        let tokens = Tensor::zeros(&[8, 16]).expect("tensor operation failed");
        let routing = expert_parallelism
            .hash_based_routing(&tokens)
            .expect("operation failed in test");

        assert_eq!(routing.token_indices.len(), 8);
        assert_eq!(routing.expert_assignments.len(), 8);
    }

    // Load update records one entry per expert that produced output
    // (load == output batch size: 10 and 15 tokens here).
    #[test]
    fn test_load_balancing_update() {
        let config = ExpertParallelismConfig::default();
        let process_group = Arc::new(SimulatedProcessGroup::new(0, 4));
        let expert_parallelism =
            ExpertParallelism::new(config, 0, 4, process_group).expect("operation failed in test");

        let mut expert_outputs = HashMap::new();
        expert_outputs.insert(
            0,
            Tensor::zeros(&[10, 16]).expect("tensor operation failed"),
        );
        expert_outputs.insert(
            1,
            Tensor::zeros(&[15, 16]).expect("tensor operation failed"),
        );

        let result = expert_parallelism.update_load_balancing(&expert_outputs);
        assert!(result.is_ok());

        let stats = expert_parallelism.get_load_balancing_stats();
        assert!(stats.expert_loads.contains_key(&0));
        assert!(stats.expert_loads.contains_key(&1));
    }

    // 16 experts, 8 ranks, 4 experts fit in memory, even share is 2:
    // the derived config must be valid and bounded.
    #[test]
    fn test_optimal_expert_config_calculation() {
        let config = utils::calculate_optimal_expert_config(16, 8, 1000, 4000)
            .expect("operation failed in test");
        assert!(config.experts_per_device <= 4);
        assert!(config.expert_parallel_size > 0);
    }

    // Any non-trivial workload should report a positive payload estimate.
    #[test]
    fn test_communication_cost_estimation() {
        let config = ExpertParallelismConfig::default();
        let cost = utils::estimate_communication_cost(&config, 32, 512, 768);
        assert!(cost > 0.0);
    }
}