//! Expert parallelism for Mixture-of-Experts (MoE) training.
//! File: trustformers_training/expert_parallelism.rs

1use crate::distributed::ProcessGroup;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex, RwLock};
6use std::time::{Duration, Instant};
7use trustformers_core::tensor::Tensor;
8
/// Expert Parallelism Configuration for Mixture of Experts (MoE) models
///
/// Expert parallelism distributes experts across different devices/processes,
/// enabling scaling of MoE models with efficient expert routing and load balancing.
///
/// Invariant enforced by `ExpertParallelism::new`:
/// `experts_per_device * expert_parallel_size == num_experts`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertParallelismConfig {
    /// Number of experts in the MoE layer
    pub num_experts: usize,
    /// Number of experts per device/process
    pub experts_per_device: usize,
    /// Number of devices/processes for expert parallelism
    pub expert_parallel_size: usize,
    /// Top-k routing for expert selection (number of experts each token is sent to)
    pub top_k: usize,
    /// Load balancing strategy
    pub load_balancing: LoadBalancingStrategy,
    /// Expert routing strategy
    pub routing_strategy: ExpertRoutingStrategy,
    /// Expert capacity multiplier relative to an even token split
    /// (e.g. 1.25 allows 25% headroom). Despite the old comment, this
    /// is a factor, not an on/off switch.
    pub capacity_factor: f32,
    /// Drop tokens when capacity is exceeded
    pub drop_tokens: bool,
    /// Use auxiliary load balancing loss
    pub use_auxiliary_loss: bool,
    /// Auxiliary loss weight
    pub auxiliary_loss_weight: f32,
    /// Expert communication pattern
    pub communication_pattern: ExpertCommunicationPattern,
}
38
// Defaults: 8 experts split 2-per-device across 4 ranks, top-2 routing,
// token-choice balancing, auxiliary load-balancing loss enabled.
impl Default for ExpertParallelismConfig {
    fn default() -> Self {
        Self {
            num_experts: 8,
            experts_per_device: 2,
            expert_parallel_size: 4, // 2 * 4 == 8, satisfies the config invariant
            top_k: 2,
            load_balancing: LoadBalancingStrategy::TokenChoiceBased,
            routing_strategy: ExpertRoutingStrategy::LearnedGating,
            capacity_factor: 1.25, // 25% headroom over the even split
            drop_tokens: false,
            use_auxiliary_loss: true,
            auxiliary_loss_weight: 0.01,
            communication_pattern: ExpertCommunicationPattern::AllToAll,
        }
    }
}
56
/// Load balancing strategies for expert utilization
///
/// NOTE(review): carried in the config but not matched on anywhere in
/// this file — confirm consumers elsewhere before extending.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    /// Balance load based on token choice
    TokenChoiceBased,
    /// Balance load based on expert choice
    ExpertChoiceBased,
    /// Dynamic load balancing
    Dynamic,
    /// Round-robin assignment
    RoundRobin,
    /// Load-aware routing
    LoadAware,
}
71
/// Expert routing strategies
///
/// Dispatched on in `ExpertParallelism::route_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExpertRoutingStrategy {
    /// Learned gating network
    LearnedGating,
    /// Hash-based routing (deterministic per token index)
    HashBased,
    /// Random routing
    Random,
    /// Load-based routing (least-loaded expert first)
    LoadBased,
    /// Similarity-based routing
    SimilarityBased,
}
86
/// Communication patterns for expert parallelism
///
/// NOTE(review): stored in the config; the coordinator in this file
/// does not yet branch on it — confirm usage at the call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExpertCommunicationPattern {
    /// All-to-all communication
    AllToAll,
    /// Point-to-point communication
    PointToPoint,
    /// Hierarchical communication
    Hierarchical,
    /// Ring-based communication
    Ring,
}
99
/// Expert assignment and routing information
///
/// One entry per global expert, mapping it to a hosting rank.
#[derive(Debug, Clone)]
pub struct ExpertAssignment {
    /// Expert ID (global, in `0..num_experts`)
    pub expert_id: usize,
    /// Device/process rank where expert is located
    pub device_rank: usize,
    /// Local expert index on the device
    pub local_expert_id: usize,
    /// Load weight for this expert
    /// NOTE(review): initialized to 1.0 and never mutated in this file.
    pub load_weight: f32,
}
112
/// Token routing information
///
/// Produced by the routing strategies; `expert_assignments` is parallel
/// to `token_indices` (one weighted expert list per token).
#[derive(Debug, Clone)]
pub struct TokenRouting {
    /// Token indices
    pub token_indices: Vec<usize>,
    /// Expert assignments for each token
    pub expert_assignments: Vec<Vec<(usize, f32)>>, // (expert_id, weight)
    /// Communication destinations
    pub destinations: HashMap<usize, Vec<usize>>, // device_rank -> token_indices
    /// Capacity constraints
    /// NOTE(review): left empty by every routing strategy in this file.
    pub capacity_usage: HashMap<usize, usize>, // expert_id -> current_tokens
}
125
/// Expert parallelism coordinator
///
/// Owns the static expert-to-device assignment table plus shared,
/// thread-safe state for load balancing, communication statistics,
/// and a routing cache.
#[allow(dead_code)]
pub struct ExpertParallelism {
    config: ExpertParallelismConfig,
    #[allow(dead_code)]
    global_rank: usize,
    world_size: usize,

    // Expert assignment mapping, indexed by expert_id
    // (see create_expert_assignments for the layout).
    expert_assignments: Vec<ExpertAssignment>,
    local_experts: Vec<usize>, // Expert IDs local to this device

    // Process groups
    expert_group: Arc<dyn ProcessGroup>,

    // Load balancing state (written by update_load_balancing,
    // read by load_based_routing and get_load_balancing_stats).
    load_balancing_state: Arc<RwLock<LoadBalancingState>>,

    // Communication statistics
    communication_stats: Arc<Mutex<ExpertCommunicationStats>>,

    // Routing cache for efficiency
    // NOTE(review): never read or written by any method in this file —
    // confirm whether callers elsewhere use it.
    routing_cache: Arc<Mutex<HashMap<String, TokenRouting>>>,
}
150
/// Load balancing state tracking
///
/// Written by `update_load_balancing`; read by `load_based_routing`
/// and `get_load_balancing_stats`.
#[derive(Debug, Default)]
#[allow(dead_code)]
struct LoadBalancingState {
    // expert_id -> tokens most recently processed by that expert.
    expert_loads: HashMap<usize, f32>,
    #[allow(dead_code)]
    expert_utilization: HashMap<usize, f32>,
    // NOTE(review): never populated in this file.
    token_distribution: HashMap<usize, usize>,
    // Coefficient of variation of per-expert load (stddev / mean).
    imbalance_score: f32,
    // Timestamp of the last update_load_balancing call.
    last_rebalance_time: Option<Instant>,
}
162
/// Communication statistics for expert parallelism
///
/// Accumulated across calls; snapshot exposed via `LoadBalancingStats`.
#[derive(Debug, Default)]
#[allow(dead_code)]
struct ExpertCommunicationStats {
    // Cumulative time spent in all_to_all_communication.
    all_to_all_time: Duration,
    #[allow(dead_code)]
    point_to_point_time: Duration,
    // Cumulative tokens passed through route_tokens.
    total_tokens_routed: u64,
    // NOTE(review): the two fields below are never updated in this file.
    expert_load_variance: f32,
    communication_efficiency: f32,
    // Cumulative time spent computing routing decisions.
    routing_overhead: Duration,
}
175
176impl ExpertParallelism {
177    /// Create a new expert parallelism coordinator
178    pub fn new(
179        config: ExpertParallelismConfig,
180        global_rank: usize,
181        world_size: usize,
182        expert_group: Arc<dyn ProcessGroup>,
183    ) -> Result<Self> {
184        // Validate configuration
185        if config.num_experts % config.expert_parallel_size != 0 {
186            return Err(anyhow!(
187                "Number of experts ({}) must be divisible by expert parallel size ({})",
188                config.num_experts,
189                config.expert_parallel_size
190            ));
191        }
192
193        if config.experts_per_device * config.expert_parallel_size != config.num_experts {
194            return Err(anyhow!(
195                "Expert assignment mismatch: experts_per_device ({}) * expert_parallel_size ({}) != num_experts ({})",
196                config.experts_per_device, config.expert_parallel_size, config.num_experts
197            ));
198        }
199
200        // Create expert assignments
201        let expert_assignments = Self::create_expert_assignments(&config, world_size)?;
202        let local_experts = Self::get_local_experts(&expert_assignments, global_rank);
203
204        Ok(Self {
205            config,
206            global_rank,
207            world_size,
208            expert_assignments,
209            local_experts,
210            expert_group,
211            load_balancing_state: Arc::new(RwLock::new(LoadBalancingState::default())),
212            communication_stats: Arc::new(Mutex::new(ExpertCommunicationStats::default())),
213            routing_cache: Arc::new(Mutex::new(HashMap::new())),
214        })
215    }
216
217    /// Create expert assignments across devices
218    fn create_expert_assignments(
219        config: &ExpertParallelismConfig,
220        _world_size: usize,
221    ) -> Result<Vec<ExpertAssignment>> {
222        let mut assignments = Vec::new();
223
224        for expert_id in 0..config.num_experts {
225            let device_rank = expert_id / config.experts_per_device;
226            let local_expert_id = expert_id % config.experts_per_device;
227
228            assignments.push(ExpertAssignment {
229                expert_id,
230                device_rank,
231                local_expert_id,
232                load_weight: 1.0, // Initialize with equal weights
233            });
234        }
235
236        Ok(assignments)
237    }
238
239    /// Get local expert IDs for a given device rank
240    fn get_local_experts(assignments: &[ExpertAssignment], device_rank: usize) -> Vec<usize> {
241        assignments
242            .iter()
243            .filter(|assignment| assignment.device_rank == device_rank)
244            .map(|assignment| assignment.expert_id)
245            .collect()
246    }
247
248    /// Route tokens to experts based on gating scores
249    pub fn route_tokens(&self, tokens: &Tensor, gating_scores: &Tensor) -> Result<TokenRouting> {
250        let start_time = Instant::now();
251
252        // Implement token routing logic based on strategy
253        let routing = match self.config.routing_strategy {
254            ExpertRoutingStrategy::LearnedGating => {
255                self.learned_gating_routing(tokens, gating_scores)?
256            },
257            ExpertRoutingStrategy::HashBased => self.hash_based_routing(tokens)?,
258            ExpertRoutingStrategy::Random => self.random_routing(tokens)?,
259            ExpertRoutingStrategy::LoadBased => self.load_based_routing(tokens, gating_scores)?,
260            ExpertRoutingStrategy::SimilarityBased => {
261                self.similarity_based_routing(tokens, gating_scores)?
262            },
263        };
264
265        // Update statistics
266        {
267            let mut stats = self
268                .communication_stats
269                .lock()
270                .expect("communication_stats lock should not be poisoned");
271            stats.routing_overhead += start_time.elapsed();
272            stats.total_tokens_routed += tokens.shape()[0] as u64;
273        }
274
275        Ok(routing)
276    }
277
    /// Learned gating-based token routing
    ///
    /// NOTE(review): the gating-scores tensor is currently ignored —
    /// scores are synthesized as `1/(expert_id+1)`, so the same top-k
    /// experts (ids `0..top_k`) are selected for every token. This is a
    /// placeholder until real tensor-based gating is wired in.
    fn learned_gating_routing(
        &self,
        tokens: &Tensor,
        _gating_scores: &Tensor,
    ) -> Result<TokenRouting> {
        // One routing decision per token; the leading tensor dimension
        // is treated as the token count.
        let batch_size = tokens.shape()[0];
        let num_tokens = batch_size;

        let mut token_routing = TokenRouting {
            token_indices: (0..num_tokens).collect(),
            expert_assignments: Vec::new(),
            destinations: HashMap::new(),
            capacity_usage: HashMap::new(),
        };

        // Get top-k experts for each token
        for token_idx in 0..num_tokens {
            let mut expert_scores = Vec::new();

            // Extract scores for this token (simplified - in practice would use tensor operations)
            for expert_id in 0..self.config.num_experts {
                let score = 1.0 / (expert_id + 1) as f32; // Placeholder score calculation
                expert_scores.push((expert_id, score));
            }

            // Sort by score and take top-k (descending score order).
            expert_scores
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            expert_scores.truncate(self.config.top_k);

            // Normalize weights so they sum to 1. The placeholder scores
            // are strictly positive, so total_weight > 0 here.
            let total_weight: f32 = expert_scores.iter().map(|(_, w)| w).sum();
            let normalized_assignments: Vec<(usize, f32)> = expert_scores
                .iter()
                .map(|(expert_id, weight)| (*expert_id, weight / total_weight))
                .collect();

            token_routing.expert_assignments.push(normalized_assignments.clone());

            // Update destinations mapping: the token is listed once for
            // each distinct device that hosts one of its chosen experts.
            for (expert_id, _) in normalized_assignments {
                let device_rank = self.expert_assignments[expert_id].device_rank;
                token_routing.destinations.entry(device_rank).or_default().push(token_idx);
            }
        }

        Ok(token_routing)
    }
327
328    /// Hash-based token routing for deterministic assignment
329    fn hash_based_routing(&self, tokens: &Tensor) -> Result<TokenRouting> {
330        let batch_size = tokens.shape()[0];
331        let mut token_routing = TokenRouting {
332            token_indices: (0..batch_size).collect(),
333            expert_assignments: Vec::new(),
334            destinations: HashMap::new(),
335            capacity_usage: HashMap::new(),
336        };
337
338        for token_idx in 0..batch_size {
339            // Simple hash-based assignment (in practice, would use token content)
340            let expert_id = token_idx % self.config.num_experts;
341            let device_rank = self.expert_assignments[expert_id].device_rank;
342
343            token_routing.expert_assignments.push(vec![(expert_id, 1.0)]);
344            token_routing.destinations.entry(device_rank).or_default().push(token_idx);
345        }
346
347        Ok(token_routing)
348    }
349
350    /// Random token routing
351    fn random_routing(&self, tokens: &Tensor) -> Result<TokenRouting> {
352        let batch_size = tokens.shape()[0];
353        let mut token_routing = TokenRouting {
354            token_indices: (0..batch_size).collect(),
355            expert_assignments: Vec::new(),
356            destinations: HashMap::new(),
357            capacity_usage: HashMap::new(),
358        };
359
360        use std::collections::hash_map::DefaultHasher;
361        use std::hash::{Hash, Hasher};
362
363        for token_idx in 0..batch_size {
364            let mut hasher = DefaultHasher::new();
365            token_idx.hash(&mut hasher);
366            let expert_id = (hasher.finish() as usize) % self.config.num_experts;
367            let device_rank = self.expert_assignments[expert_id].device_rank;
368
369            token_routing.expert_assignments.push(vec![(expert_id, 1.0)]);
370            token_routing.destinations.entry(device_rank).or_default().push(token_idx);
371        }
372
373        Ok(token_routing)
374    }
375
376    /// Load-based token routing
377    fn load_based_routing(&self, tokens: &Tensor, _gating_scores: &Tensor) -> Result<TokenRouting> {
378        let batch_size = tokens.shape()[0];
379        let mut token_routing = TokenRouting {
380            token_indices: (0..batch_size).collect(),
381            expert_assignments: Vec::new(),
382            destinations: HashMap::new(),
383            capacity_usage: HashMap::new(),
384        };
385
386        let load_state = self
387            .load_balancing_state
388            .read()
389            .expect("load_balancing_state lock should not be poisoned");
390
391        for token_idx in 0..batch_size {
392            // Find least loaded expert
393            let mut min_load = f32::INFINITY;
394            let mut selected_expert = 0;
395
396            for expert_id in 0..self.config.num_experts {
397                let load = load_state.expert_loads.get(&expert_id).unwrap_or(&0.0);
398                if *load < min_load {
399                    min_load = *load;
400                    selected_expert = expert_id;
401                }
402            }
403
404            let device_rank = self.expert_assignments[selected_expert].device_rank;
405            token_routing.expert_assignments.push(vec![(selected_expert, 1.0)]);
406            token_routing.destinations.entry(device_rank).or_default().push(token_idx);
407        }
408
409        Ok(token_routing)
410    }
411
    /// Similarity-based token routing
    ///
    /// Currently delegates to learned gating: true similarity routing
    /// would require comparing token embeddings against per-expert
    /// representations, which is not implemented yet.
    fn similarity_based_routing(
        &self,
        tokens: &Tensor,
        gating_scores: &Tensor,
    ) -> Result<TokenRouting> {
        // For now, fall back to learned gating (similarity requires embedding analysis)
        self.learned_gating_routing(tokens, gating_scores)
    }
421
422    /// Perform all-to-all communication for expert parallelism
423    pub fn all_to_all_communication(
424        &self,
425        local_tokens: &Tensor,
426        routing: &TokenRouting,
427    ) -> Result<HashMap<usize, Tensor>> {
428        let start_time = Instant::now();
429
430        // Simulate all-to-all communication
431        // In practice, this would involve actual tensor communication
432        let mut expert_inputs = HashMap::new();
433
434        for expert_id in &self.local_experts {
435            // Collect tokens assigned to this expert
436            let mut expert_tokens = Vec::new();
437
438            for (token_idx, assignments) in routing.expert_assignments.iter().enumerate() {
439                for (assigned_expert_id, weight) in assignments {
440                    if *assigned_expert_id == *expert_id && *weight > 0.0 {
441                        // In practice, would extract actual token data
442                        expert_tokens.push(token_idx);
443                    }
444                }
445            }
446
447            // Create tensor for this expert (simplified)
448            if !expert_tokens.is_empty() {
449                let expert_tensor = Tensor::zeros(&[expert_tokens.len(), local_tokens.shape()[1]])?;
450                expert_inputs.insert(*expert_id, expert_tensor);
451            }
452        }
453
454        // Update communication statistics
455        {
456            let mut stats = self
457                .communication_stats
458                .lock()
459                .expect("communication_stats lock should not be poisoned");
460            stats.all_to_all_time += start_time.elapsed();
461        }
462
463        Ok(expert_inputs)
464    }
465
466    /// Update load balancing state
467    pub fn update_load_balancing(&self, expert_outputs: &HashMap<usize, Tensor>) -> Result<()> {
468        let mut load_state = self
469            .load_balancing_state
470            .write()
471            .expect("load_balancing_state lock should not be poisoned");
472
473        // Update expert loads based on output sizes
474        for (expert_id, output) in expert_outputs {
475            let load = output.shape()[0] as f32; // Number of tokens processed
476            load_state.expert_loads.insert(*expert_id, load);
477        }
478
479        // Calculate utilization and imbalance
480        let total_load: f32 = load_state.expert_loads.values().sum();
481        let avg_load = total_load / self.config.num_experts as f32;
482
483        let mut variance = 0.0;
484        for load in load_state.expert_loads.values() {
485            variance += (load - avg_load).powi(2);
486        }
487        variance /= self.config.num_experts as f32;
488
489        load_state.imbalance_score = variance.sqrt() / avg_load.max(1e-6);
490        load_state.last_rebalance_time = Some(Instant::now());
491
492        Ok(())
493    }
494
495    /// Get load balancing statistics
496    pub fn get_load_balancing_stats(&self) -> LoadBalancingStats {
497        let load_state = self
498            .load_balancing_state
499            .read()
500            .expect("load_balancing_state lock should not be poisoned");
501        let comm_stats = self.communication_stats.lock().expect("lock should not be poisoned");
502
503        LoadBalancingStats {
504            expert_loads: load_state.expert_loads.clone(),
505            imbalance_score: load_state.imbalance_score,
506            total_tokens_routed: comm_stats.total_tokens_routed,
507            communication_efficiency: comm_stats.communication_efficiency,
508            routing_overhead: comm_stats.routing_overhead,
509        }
510    }
511
    /// Expert IDs hosted on this rank (assignments whose `device_rank`
    /// equals this coordinator's `global_rank`).
    pub fn local_experts(&self) -> &[usize] {
        &self.local_experts
    }
516
    /// Get expert assignment for a given expert ID.
    ///
    /// Relies on `expert_assignments` being ordered so that index ==
    /// `expert_id` (true for tables built by `create_expert_assignments`).
    pub fn get_expert_assignment(&self, expert_id: usize) -> Option<&ExpertAssignment> {
        self.expert_assignments.get(expert_id)
    }
521
    /// Borrow the active expert-parallelism configuration.
    pub fn config(&self) -> &ExpertParallelismConfig {
        &self.config
    }
526}
527
/// Load balancing statistics
///
/// Read-only snapshot returned by `ExpertParallelism::get_load_balancing_stats`.
#[derive(Debug, Clone)]
pub struct LoadBalancingStats {
    /// Most recent per-expert load (tokens processed), keyed by expert id.
    pub expert_loads: HashMap<usize, f32>,
    /// Coefficient of variation of per-expert load (stddev / mean).
    pub imbalance_score: f32,
    /// Cumulative tokens passed through `route_tokens`.
    pub total_tokens_routed: u64,
    /// NOTE(review): never updated in this file; remains the default 0.0.
    pub communication_efficiency: f32,
    /// Cumulative time spent computing routing decisions.
    pub routing_overhead: Duration,
}
537
538/// Expert parallelism utilities
539pub mod utils {
540    use super::*;
541
542    /// Calculate optimal expert parallelism configuration
543    pub fn calculate_optimal_expert_config(
544        num_experts: usize,
545        world_size: usize,
546        memory_per_expert_mb: usize,
547        available_memory_mb: usize,
548    ) -> Result<ExpertParallelismConfig> {
549        let experts_per_device = std::cmp::min(
550            available_memory_mb / memory_per_expert_mb,
551            num_experts / world_size,
552        );
553
554        if experts_per_device == 0 {
555            return Err(anyhow!("Insufficient memory for expert parallelism"));
556        }
557
558        let expert_parallel_size = num_experts.div_ceil(experts_per_device);
559
560        Ok(ExpertParallelismConfig {
561            num_experts,
562            experts_per_device,
563            expert_parallel_size,
564            ..Default::default()
565        })
566    }
567
568    /// Estimate communication cost for expert parallelism
569    pub fn estimate_communication_cost(
570        config: &ExpertParallelismConfig,
571        batch_size: usize,
572        sequence_length: usize,
573        hidden_size: usize,
574    ) -> f32 {
575        let total_tokens = batch_size * sequence_length;
576        let tokens_per_expert = total_tokens / config.num_experts;
577        let communication_volume = tokens_per_expert * hidden_size * 4; // 4 bytes per float
578
579        // Simplified cost model
580        communication_volume as f32 / (1024.0 * 1024.0) // Convert to MB
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::SimulatedProcessGroup;
    use std::sync::Arc;

    // Defaults must satisfy experts_per_device * expert_parallel_size == num_experts.
    #[test]
    fn test_expert_parallelism_config() {
        let config = ExpertParallelismConfig::default();
        assert_eq!(config.num_experts, 8);
        assert_eq!(config.expert_parallel_size, 4);
        assert_eq!(config.experts_per_device, 2);
    }

    // Contiguous layout: expert e -> device e / experts_per_device,
    // local slot e % experts_per_device.
    #[test]
    fn test_expert_assignment_creation() {
        let config = ExpertParallelismConfig {
            num_experts: 8,
            experts_per_device: 2,
            expert_parallel_size: 4,
            ..Default::default()
        };

        let assignments = ExpertParallelism::create_expert_assignments(&config, 4)
            .expect("operation failed in test");
        assert_eq!(assignments.len(), 8);

        // Check that experts are distributed correctly
        for (i, assignment) in assignments.iter().enumerate() {
            assert_eq!(assignment.expert_id, i);
            assert_eq!(assignment.device_rank, i / 2);
            assert_eq!(assignment.local_expert_id, i % 2);
        }
    }

    // Rank 1 hosts experts 2 and 3 under the 2-per-device layout.
    #[test]
    fn test_local_experts() {
        let config = ExpertParallelismConfig {
            num_experts: 8,
            experts_per_device: 2,
            expert_parallel_size: 4,
            ..Default::default()
        };

        let assignments = ExpertParallelism::create_expert_assignments(&config, 4)
            .expect("operation failed in test");
        let local_experts = ExpertParallelism::get_local_experts(&assignments, 1);

        assert_eq!(local_experts, vec![2, 3]);
    }

    // A valid config constructs successfully and rank 0 hosts experts 0 and 1.
    #[test]
    fn test_expert_parallelism_creation() {
        let config = ExpertParallelismConfig {
            num_experts: 8,
            experts_per_device: 2,
            expert_parallel_size: 4,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 4));
        let expert_parallelism = ExpertParallelism::new(config, 0, 4, process_group);

        assert!(expert_parallelism.is_ok());
        let ep = expert_parallelism.expect("operation failed in test");
        assert_eq!(ep.local_experts(), &[0, 1]);
    }

    // Hash routing yields exactly one assignment list per token.
    #[test]
    fn test_hash_based_routing() {
        let config = ExpertParallelismConfig {
            num_experts: 4,
            experts_per_device: 1,
            expert_parallel_size: 4,
            routing_strategy: ExpertRoutingStrategy::HashBased,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 4));
        let expert_parallelism =
            ExpertParallelism::new(config, 0, 4, process_group).expect("operation failed in test");

        let tokens = Tensor::zeros(&[8, 16]).expect("tensor operation failed");
        let routing = expert_parallelism
            .hash_based_routing(&tokens)
            .expect("operation failed in test");

        assert_eq!(routing.token_indices.len(), 8);
        assert_eq!(routing.expert_assignments.len(), 8);
    }

    // Expert loads are recorded from the leading output dimension.
    #[test]
    fn test_load_balancing_update() {
        let config = ExpertParallelismConfig::default();
        let process_group = Arc::new(SimulatedProcessGroup::new(0, 4));
        let expert_parallelism =
            ExpertParallelism::new(config, 0, 4, process_group).expect("operation failed in test");

        let mut expert_outputs = HashMap::new();
        expert_outputs.insert(
            0,
            Tensor::zeros(&[10, 16]).expect("tensor operation failed"),
        );
        expert_outputs.insert(
            1,
            Tensor::zeros(&[15, 16]).expect("tensor operation failed"),
        );

        let result = expert_parallelism.update_load_balancing(&expert_outputs);
        assert!(result.is_ok());

        let stats = expert_parallelism.get_load_balancing_stats();
        assert!(stats.expert_loads.contains_key(&0));
        assert!(stats.expert_loads.contains_key(&1));
    }

    // 16 experts, 8 ranks, 4 experts' worth of memory: per-device count
    // is bounded by both memory and the even split across ranks.
    #[test]
    fn test_optimal_expert_config_calculation() {
        let config = utils::calculate_optimal_expert_config(16, 8, 1000, 4000)
            .expect("operation failed in test");
        assert!(config.experts_per_device <= 4);
        assert!(config.expert_parallel_size > 0);
    }

    // The simplified cost model returns a positive MB figure.
    #[test]
    fn test_communication_cost_estimation() {
        let config = ExpertParallelismConfig::default();
        let cost = utils::estimate_communication_cost(&config, 32, 512, 768);
        assert!(cost > 0.0);
    }
}