// torsh_distributed/expert_parallelism/manager.rs

1// Framework infrastructure - components designed for future use
2#![allow(dead_code)]
3#![allow(clippy::await_holding_lock)]
4use crate::expert_parallelism::{
5    config::{ExpertParallelismConfig, ExpertShardingStrategy},
6    router::RoutingDecision,
7};
8use crate::ProcessGroup;
9use crate::TorshResult;
10use log::{debug, info};
11use scirs2_core::random::thread_rng;
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio;
15use torsh_core::DeviceType;
16use torsh_tensor::Tensor;
17
/// Expert assignment for a single token.
#[derive(Debug, Clone)]
pub struct ExpertAssignment {
    /// Index of the expert this token is routed to.
    pub expert_id: usize,
    /// Router probability (gate weight) for this expert.
    pub probability: f32,
    /// Position of the token within the input batch.
    pub token_idx: usize,
    /// Rank among selected experts (0 = highest probability).
    pub expert_rank: usize,
}
26
/// Expert shard information: where a given expert lives in the cluster.
#[derive(Debug, Clone)]
pub struct ExpertShardInfo {
    /// Expert this entry describes.
    pub expert_id: usize,
    /// Rank owning the primary copy of the expert.
    pub owner_rank: usize,
    /// Whether the current rank holds a copy of this expert.
    pub is_local: bool,
    /// Ranks that have copies of this expert; a single entry means the
    /// expert is sharded (one owner), more than one means it is replicated.
    pub replicas: Vec<usize>,
}
35
/// Individual expert model: a single linear layer with bias and ReLU.
pub struct Expert {
    /// Global id of this expert.
    pub expert_id: usize,
    /// Weight matrix with shape `[input_dim, hidden_dim]`.
    pub weights: Tensor<f32>,
    /// Bias vector of length `hidden_dim`.
    pub bias: Tensor<f32>,
    /// Expected width of input rows.
    pub input_dim: usize,
    /// Width of the layer output produced by `forward_async`.
    pub hidden_dim: usize,
    /// NOTE(review): recorded but unused — `forward_async` produces
    /// `hidden_dim`-wide output; confirm whether a second projection to
    /// `output_dim` is intended.
    pub output_dim: usize,
}
45
46impl Expert {
47    pub fn new(expert_id: usize, params: &ExpertParameters) -> TorshResult<Self> {
48        let mut rng = thread_rng();
49        let weights_data: Vec<f32> = (0..(params.input_dim * params.hidden_dim))
50            .map(|_| rng.random::<f32>() * 0.02)
51            .collect();
52        let weights = Tensor::from_vec(weights_data, &[params.input_dim, params.hidden_dim])?;
53        let bias = Tensor::zeros(&[params.hidden_dim], DeviceType::Cpu)?;
54
55        Ok(Self {
56            expert_id,
57            weights,
58            bias,
59            input_dim: params.input_dim,
60            hidden_dim: params.hidden_dim,
61            output_dim: params.output_dim,
62        })
63    }
64
65    pub async fn forward_async(&self, input: Tensor<f32>) -> TorshResult<Tensor<f32>> {
66        // Simulate async computation
67        tokio::task::yield_now().await;
68
69        // Expert computation: input @ weights + bias
70        let output = input.matmul(&self.weights)?;
71        let output = output.add(&self.bias)?;
72
73        // Apply activation (ReLU for simplicity)
74        let output = output.relu()?;
75
76        Ok(output)
77    }
78}
79
/// Expert parameter configuration.
#[derive(Debug, Clone)]
pub struct ExpertParameters {
    /// Input feature dimension.
    pub input_dim: usize,
    /// Hidden (projection) dimension.
    pub hidden_dim: usize,
    /// Output feature dimension.
    pub output_dim: usize,
    /// Activation name. NOTE(review): currently not consulted —
    /// `Expert::forward_async` always applies ReLU.
    pub activation: String,
}
88
89impl Default for ExpertParameters {
90    fn default() -> Self {
91        Self {
92            input_dim: 512,
93            hidden_dim: 2048,
94            output_dim: 512,
95            activation: "relu".to_string(),
96        }
97    }
98}
99
/// All-to-All communication scheduler for token routing.
pub struct AllToAllScheduler {
    /// Communicator used to look up rank and world size.
    process_group: Arc<ProcessGroup>,
}
104
105impl AllToAllScheduler {
106    pub fn new(process_group: Arc<ProcessGroup>) -> Self {
107        Self { process_group }
108    }
109
110    pub async fn route_tokens_to_experts(
111        &self,
112        tokens: &Tensor<f32>,
113        routing_decision: &RoutingDecision,
114        sharding_map: &HashMap<usize, ExpertShardInfo>,
115    ) -> TorshResult<HashMap<usize, Tensor<f32>>> {
116        info!("All-to-All: Routing tokens to experts");
117
118        // Enhanced all-to-all communication implementation for token routing
119        // This involves:
120        // 1. Grouping tokens by destination rank based on expert assignment
121        // 2. Performing all-to-all scatter to send tokens to expert owners
122        // 3. Receiving tokens assigned to local experts
123
124        let start_time = std::time::Instant::now();
125        let backend = self.process_group.backend();
126        #[allow(clippy::await_holding_lock)]
127        let backend_guard = backend.read();
128
129        // Step 1: Group tokens by destination rank
130        let mut tokens_by_rank: HashMap<usize, Vec<Vec<f32>>> = HashMap::new();
131        let token_data = tokens.to_vec()?;
132        let tokens_per_row = tokens.shape().dims()[1];
133
134        debug!(
135            "Grouping {} tokens by destination rank",
136            routing_decision.total_tokens
137        );
138
139        // Process each token and determine its destination rank
140        for (token_idx, assignments) in routing_decision.expert_assignments.iter().enumerate() {
141            if let Some(assignment) = assignments.first() {
142                let expert_id = assignment.expert_id;
143                if let Some(shard_info) = sharding_map.get(&expert_id) {
144                    let dest_rank = shard_info.owner_rank;
145
146                    // Extract token data for this token
147                    let token_start = token_idx * tokens_per_row;
148                    let token_end = token_start + tokens_per_row;
149                    if token_end <= token_data.len() {
150                        let token_values = token_data[token_start..token_end].to_vec();
151                        tokens_by_rank
152                            .entry(dest_rank)
153                            .or_default()
154                            .push(token_values);
155                    }
156                }
157            }
158        }
159
160        // Step 2: Perform all-to-all scatter simulation
161        debug!(
162            "Performing all-to-all scatter: {} destination ranks",
163            tokens_by_rank.len()
164        );
165
166        // Simulate all-to-all communication latency
167        let total_elements: usize = tokens_by_rank
168            .values()
169            .map(|v| v.len() * tokens_per_row)
170            .sum();
171        let world_size = backend_guard.world_size() as usize;
172        let latency_us = (total_elements as f64 * world_size as f64 * 0.01).max(50.0);
173        tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;
174
175        debug!(
176            "All-to-all scatter: {} elements across {} ranks",
177            total_elements, world_size
178        );
179
180        // Step 3: Receive and organize tokens for local experts
181        let mut routed_tokens = HashMap::new();
182        let current_rank = backend_guard.rank() as usize;
183
184        for (&expert_id, shard_info) in sharding_map {
185            if shard_info.is_local && shard_info.owner_rank == current_rank {
186                // Get tokens assigned to this local expert
187                if let Some(expert_tokens) = tokens_by_rank.get(&current_rank) {
188                    // Flatten the token vectors into a single tensor
189                    let mut flattened_tokens = Vec::new();
190                    for token in expert_tokens {
191                        flattened_tokens.extend(token);
192                    }
193
194                    if !flattened_tokens.is_empty() {
195                        let num_tokens = expert_tokens.len();
196                        let tensor_shape = vec![num_tokens, tokens_per_row];
197                        let expert_tensor = Tensor::from_vec(flattened_tokens, &tensor_shape)?;
198                        routed_tokens.insert(expert_id, expert_tensor);
199
200                        debug!(
201                            "Routed {} tokens to local expert {} ({} elements)",
202                            num_tokens,
203                            expert_id,
204                            num_tokens * tokens_per_row
205                        );
206                    }
207                } else {
208                    // Create empty tensor for expert with no assigned tokens
209                    let empty_tensor = Tensor::zeros(&[0, tokens_per_row], DeviceType::Cpu)?;
210                    routed_tokens.insert(expert_id, empty_tensor);
211                }
212            }
213        }
214
215        let duration = start_time.elapsed();
216        info!(
217            "All-to-all token routing completed: {} local experts in {:?}",
218            routed_tokens.len(),
219            duration
220        );
221
222        Ok(routed_tokens)
223    }
224
225    pub async fn route_results_back(
226        &self,
227        expert_outputs: &HashMap<usize, Tensor<f32>>,
228        routing_decision: &RoutingDecision,
229        sharding_map: &HashMap<usize, ExpertShardInfo>,
230    ) -> TorshResult<Tensor<f32>> {
231        info!("All-to-All: Routing expert results back");
232
233        // Enhanced all-to-all gather implementation for expert result collection
234        // This involves:
235        // 1. Performing all-to-all gather to collect results from all experts
236        // 2. Reassembling tokens in their original order
237        // 3. Combining results from multiple experts per token
238
239        let start_time = std::time::Instant::now();
240        #[allow(clippy::await_holding_lock)]
241        let backend = self.process_group.backend();
242        let backend_guard = backend.read();
243
244        debug!("Performing all-to-all gather: collecting expert results");
245
246        // Step 1: Prepare expert results for all-to-all gather
247        let mut results_by_rank: HashMap<usize, Vec<Vec<f32>>> = HashMap::new();
248        let mut total_output_elements = 0;
249
250        // Process expert results
251        for (&expert_id, expert_result) in expert_outputs {
252            if let Some(shard_info) = sharding_map.get(&expert_id) {
253                let expert_data = expert_result.to_vec()?;
254                results_by_rank
255                    .entry(shard_info.owner_rank)
256                    .or_default()
257                    .push(expert_data.clone());
258                total_output_elements += expert_data.len();
259            }
260        }
261
262        // Step 2: Perform all-to-all gather simulation
263        let world_size = backend_guard.world_size() as usize;
264        // Simulate all-to-all gather latency (typically more expensive than scatter)
265        let latency_us = (total_output_elements as f64 * world_size as f64 * 0.02).max(100.0);
266        tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;
267
268        debug!(
269            "All-to-all gather: {} elements from {} ranks",
270            total_output_elements,
271            results_by_rank.len()
272        );
273
274        // Step 3: Reassemble tokens in their original order
275        let output_dim = if let Some(first_result) = expert_outputs.values().next() {
276            first_result.shape().dims()[1]
277        } else {
278            512 // Default output dimension
279        };
280
281        let mut final_output_data = vec![0.0f32; routing_decision.total_tokens * output_dim];
282        let mut tokens_processed = 0;
283
284        // Process each token according to routing decision
285        for (token_idx, assignments) in routing_decision.expert_assignments.iter().enumerate() {
286            if let Some(assignment) = assignments.first() {
287                let expert_id = assignment.expert_id;
288                if let Some(expert_result) = expert_outputs.get(&expert_id) {
289                    let expert_data = expert_result.to_vec()?;
290                    let tokens_in_result = expert_data.len() / output_dim;
291
292                    // Find the appropriate token result within this expert's output
293                    let token_in_expert = token_idx % tokens_in_result.max(1);
294                    let result_start = token_in_expert * output_dim;
295                    let result_end = result_start + output_dim;
296
297                    if result_end <= expert_data.len() {
298                        let output_start = token_idx * output_dim;
299                        let output_end = output_start + output_dim;
300
301                        if output_end <= final_output_data.len() {
302                            final_output_data[output_start..output_end]
303                                .copy_from_slice(&expert_data[result_start..result_end]);
304                            tokens_processed += 1;
305                        }
306                    }
307                }
308            }
309        }
310
311        // Step 4: Create final output tensor
312        let output_shape = [routing_decision.total_tokens, output_dim];
313        let final_output = Tensor::from_vec(final_output_data, &output_shape)?;
314
315        let duration = start_time.elapsed();
316        info!(
317            "All-to-all result gathering completed: {} tokens processed in {:?}",
318            tokens_processed, duration
319        );
320
321        Ok(final_output)
322    }
323}
324
/// Expert gradient aggregation across replicas.
pub struct ExpertGradientAggregator {
    /// Communicator used for (simulated) cross-rank reductions.
    process_group: Arc<ProcessGroup>,
}
329
330impl ExpertGradientAggregator {
331    pub fn new(process_group: Arc<ProcessGroup>) -> Self {
332        Self { process_group }
333    }
334
335    pub async fn aggregate_gradients(
336        &self,
337        expert_gradients: &HashMap<usize, Tensor<f32>>,
338        sharding_map: &HashMap<usize, ExpertShardInfo>,
339    ) -> TorshResult<()> {
340        info!(
341            "Aggregating expert gradients across {} experts",
342            expert_gradients.len()
343        );
344
345        for (&expert_id, gradient) in expert_gradients {
346            if let Some(shard_info) = sharding_map.get(&expert_id) {
347                match shard_info.replicas.len() {
348                    1 => {
349                        // Expert is sharded, no aggregation needed
350                        continue;
351                    }
352                    _ => {
353                        // Expert is replicated, need to aggregate gradients
354                        self.aggregate_replicated_expert_gradients(expert_id, gradient, shard_info)
355                            .await?;
356                    }
357                }
358            }
359        }
360
361        Ok(())
362    }
363
364    async fn aggregate_replicated_expert_gradients(
365        &self,
366        expert_id: usize,
367        gradient: &Tensor<f32>,
368        shard_info: &ExpertShardInfo,
369    ) -> TorshResult<()> {
370        info!(
371            "    Aggregating gradients for replicated expert {} across {} replicas",
372            expert_id,
373            shard_info.replicas.len()
374        );
375
376        // Enhanced gradient aggregation using all-reduce for replicated experts
377        // For replicated experts, we need to:
378        // 1. All-reduce gradients across all replicas
379        // 2. Average the gradients
380        // 3. Update expert parameters consistently
381
382        #[allow(clippy::await_holding_lock)]
383        let start_time = std::time::Instant::now();
384        let backend = self.process_group.backend();
385        let _backend_guard = backend.read();
386
387        let _aggregated_gradient = if shard_info.replicas.len() > 1 {
388            // Extract gradient data for all-reduce
389            let grad_data = gradient.to_vec()?;
390
391            info!(
392                "      All-reducing gradients across {} replicas",
393                shard_info.replicas.len()
394            );
395
396            // Simulate all-reduce operation across expert replicas
397            // In production, this would use a subgroup communicator for the replica set
398            let summed_gradients: Vec<f32> = grad_data
399                .iter()
400                .map(|&g| g * shard_info.replicas.len() as f32) // Simulate sum across replicas
401                .collect();
402
403            // Average the gradients
404            let averaged_gradients: Vec<f32> = summed_gradients
405                .iter()
406                .map(|&g| g / shard_info.replicas.len() as f32)
407                .collect();
408
409            // Simulate network latency for replica all-reduce
410            let latency_us =
411                (grad_data.len() as f64 * shard_info.replicas.len() as f64 * 0.01).max(20.0);
412            tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;
413
414            // Create aggregated gradient tensor
415            let result = Tensor::from_vec(averaged_gradients, gradient.shape().dims())?;
416
417            info!(
418                "      Expert {} gradient all-reduce: {} elements across {} replicas",
419                expert_id,
420                grad_data.len(),
421                shard_info.replicas.len()
422            );
423
424            result
425        } else {
426            info!(
427                "       Single replica expert {}, no aggregation needed",
428                expert_id
429            );
430            gradient.clone()
431        };
432
433        let duration = start_time.elapsed();
434        info!(
435            "      Expert {} gradient aggregation completed in {:?}",
436            expert_id, duration
437        );
438
439        Ok(())
440    }
441}
442
/// Distributed expert execution manager.
pub struct DistributedExpertManager {
    /// Expert-parallelism configuration (expert count, sharding strategy, ...).
    config: ExpertParallelismConfig,
    /// Communicator shared with the scheduler and aggregator.
    process_group: Arc<ProcessGroup>,
    /// Experts instantiated on this rank.
    local_experts: Vec<Expert>,
    /// Placement of every expert (local and remote), keyed by expert id.
    expert_sharding_map: HashMap<usize, ExpertShardInfo>,
    /// Handles token scatter/gather between ranks.
    all_to_all_scheduler: AllToAllScheduler,
    /// Handles gradient reduction for replicated experts.
    gradient_aggregator: ExpertGradientAggregator,
}
452
453impl DistributedExpertManager {
454    pub fn new(
455        config: ExpertParallelismConfig,
456        process_group: Arc<ProcessGroup>,
457        expert_params: &ExpertParameters,
458    ) -> TorshResult<Self> {
459        let world_size = process_group.world_size() as usize;
460        let rank = process_group.rank() as usize;
461
462        // Create expert sharding map
463        let expert_sharding_map = Self::create_expert_sharding_map(&config, world_size, rank);
464
465        // Initialize local experts
466        let local_experts =
467            Self::initialize_local_experts(&config, &expert_sharding_map, expert_params)?;
468
469        let all_to_all_scheduler = AllToAllScheduler::new(process_group.clone());
470        let gradient_aggregator = ExpertGradientAggregator::new(process_group.clone());
471
472        Ok(Self {
473            config,
474            process_group,
475            local_experts,
476            expert_sharding_map,
477            all_to_all_scheduler,
478            gradient_aggregator,
479        })
480    }
481
482    pub fn create_expert_sharding_map(
483        config: &ExpertParallelismConfig,
484        world_size: usize,
485        rank: usize,
486    ) -> HashMap<usize, ExpertShardInfo> {
487        let mut sharding_map = HashMap::new();
488
489        match config.sharding_strategy {
490            ExpertShardingStrategy::DataParallel => {
491                // All experts on all devices
492                for expert_id in 0..config.num_experts {
493                    sharding_map.insert(
494                        expert_id,
495                        ExpertShardInfo {
496                            expert_id,
497                            owner_rank: rank,
498                            is_local: true,
499                            replicas: (0..world_size).collect(),
500                        },
501                    );
502                }
503            }
504            ExpertShardingStrategy::ModelParallel => {
505                // Distribute experts across devices
506                let experts_per_device = config.num_experts.div_ceil(world_size);
507                let start_expert = rank * experts_per_device;
508                let end_expert = ((rank + 1) * experts_per_device).min(config.num_experts);
509
510                for expert_id in 0..config.num_experts {
511                    let owner_rank = expert_id / experts_per_device;
512                    let is_local = expert_id >= start_expert && expert_id < end_expert;
513
514                    sharding_map.insert(
515                        expert_id,
516                        ExpertShardInfo {
517                            expert_id,
518                            owner_rank,
519                            is_local,
520                            replicas: vec![owner_rank],
521                        },
522                    );
523                }
524            }
525            ExpertShardingStrategy::Hybrid => {
526                // Mix of replicated and sharded experts
527                let replicated_experts = config.num_experts / 2;
528
529                for expert_id in 0..config.num_experts {
530                    if expert_id < replicated_experts {
531                        // Replicated experts
532                        sharding_map.insert(
533                            expert_id,
534                            ExpertShardInfo {
535                                expert_id,
536                                owner_rank: rank,
537                                is_local: true,
538                                replicas: (0..world_size).collect(),
539                            },
540                        );
541                    } else {
542                        // Sharded experts
543                        let sharded_id = expert_id - replicated_experts;
544                        let experts_per_device =
545                            (config.num_experts - replicated_experts).div_ceil(world_size);
546                        let owner_rank = sharded_id / experts_per_device;
547                        let is_local = owner_rank == rank;
548
549                        sharding_map.insert(
550                            expert_id,
551                            ExpertShardInfo {
552                                expert_id,
553                                owner_rank,
554                                is_local,
555                                replicas: vec![owner_rank],
556                            },
557                        );
558                    }
559                }
560            }
561            ExpertShardingStrategy::Dynamic => {
562                // Dynamic expert migration based on load balancing and communication patterns
563                // This implements intelligent expert placement and migration
564
565                // Initialize with model parallel as baseline
566                let experts_per_device = config.num_experts.div_ceil(world_size);
567
568                // Simulate load-based expert migration decisions
569                // In a real implementation, this would use historical routing statistics
570                for expert_id in 0..config.num_experts {
571                    // Calculate optimal placement based on simulated usage patterns
572                    let base_owner = expert_id / experts_per_device;
573
574                    // Dynamic migration logic: redistribute based on load patterns
575                    let optimal_owner = if config.num_experts > 32 {
576                        // For large numbers of experts, use load-balancing migration
577                        // Simulate expert usage frequency (in practice, would use real statistics)
578                        let usage_frequency = ((expert_id as f32 * 7.0).sin().abs() * 100.0) as u32;
579
580                        // High-usage experts get moved to less loaded devices
581                        if usage_frequency > 70 {
582                            // Move high-usage experts to spread load
583                            (base_owner + 1) % world_size
584                        } else if usage_frequency < 30 {
585                            // Consolidate low-usage experts
586                            (base_owner + world_size / 2) % world_size
587                        } else {
588                            base_owner
589                        }
590                    } else {
591                        // For smaller numbers, use communication-aware placement
592                        // Group related experts on the same device for better cache locality
593                        if expert_id % 4 == rank % 4 {
594                            rank // Keep related experts local
595                        } else {
596                            base_owner
597                        }
598                    };
599
600                    // Implement memory-aware migration: don't overload any single device
601                    let final_owner = if config.memory_per_expert_mb > 0 {
602                        let memory_per_device = config.memory_per_expert_mb * experts_per_device;
603                        let max_memory_mb = 16 * 1024; // 16GB limit per device
604
605                        if memory_per_device > max_memory_mb {
606                            // Redistribute to prevent memory overflow
607                            expert_id % world_size
608                        } else {
609                            optimal_owner
610                        }
611                    } else {
612                        optimal_owner
613                    };
614
615                    let is_local = final_owner == rank;
616
617                    // Determine replication strategy for dynamic experts
618                    let replicas = if config.num_experts <= 16 {
619                        // Small number of experts: replicate critical ones
620                        if expert_id < 4 {
621                            // Replicate first few experts across all devices
622                            (0..world_size).collect()
623                        } else {
624                            // Single owner for others
625                            vec![final_owner]
626                        }
627                    } else {
628                        // Large number of experts: selective replication
629                        if expert_id % 8 == 0 {
630                            // Replicate every 8th expert for load distribution
631                            vec![final_owner, (final_owner + 1) % world_size]
632                        } else {
633                            vec![final_owner]
634                        }
635                    };
636
637                    sharding_map.insert(
638                        expert_id,
639                        ExpertShardInfo {
640                            expert_id,
641                            owner_rank: final_owner,
642                            is_local,
643                            replicas,
644                        },
645                    );
646                }
647
648                info!(
649                    " Dynamic expert migration completed: {} experts distributed across {} devices",
650                    config.num_experts, world_size
651                );
652            }
653        }
654
655        sharding_map
656    }
657
658    fn initialize_local_experts(
659        _config: &ExpertParallelismConfig,
660        sharding_map: &HashMap<usize, ExpertShardInfo>,
661        expert_params: &ExpertParameters,
662    ) -> TorshResult<Vec<Expert>> {
663        let mut local_experts = Vec::new();
664
665        for (&expert_id, shard_info) in sharding_map {
666            if shard_info.is_local {
667                let expert = Expert::new(expert_id, expert_params)?;
668                local_experts.push(expert);
669            }
670        }
671
672        info!(" Initialized {} local experts", local_experts.len());
673        Ok(local_experts)
674    }
675
676    /// Execute distributed expert computation
677    pub async fn execute_experts(
678        &mut self,
679        tokens: &Tensor<f32>,
680        routing_decision: &RoutingDecision,
681    ) -> TorshResult<Tensor<f32>> {
682        // Step 1: All-to-All communication to route tokens to expert owners
683        let routed_tokens = self
684            .all_to_all_scheduler
685            .route_tokens_to_experts(tokens, routing_decision, &self.expert_sharding_map)
686            .await?;
687
688        // Step 2: Execute local experts in parallel
689        let local_outputs = self.execute_local_experts(&routed_tokens).await?;
690
691        // Step 3: All-to-All communication to route results back to original positions
692        let final_output = self
693            .all_to_all_scheduler
694            .route_results_back(&local_outputs, routing_decision, &self.expert_sharding_map)
695            .await?;
696
697        Ok(final_output)
698    }
699
700    async fn execute_local_experts(
701        &mut self,
702        routed_tokens: &HashMap<usize, Tensor<f32>>,
703    ) -> TorshResult<HashMap<usize, Tensor<f32>>> {
704        let mut outputs = HashMap::new();
705
706        // Execute all local experts in parallel
707        let mut futures = Vec::new();
708
709        for expert in &mut self.local_experts {
710            if let Some(expert_tokens) = routed_tokens.get(&expert.expert_id) {
711                let future = expert.forward_async(expert_tokens.clone());
712                futures.push((expert.expert_id, future));
713            }
714        }
715
716        // Await all expert computations
717        for (expert_id, future) in futures {
718            let output = future.await?;
719            outputs.insert(expert_id, output);
720        }
721
722        Ok(outputs)
723    }
724
725    /// Aggregate gradients across distributed experts
726    pub async fn aggregate_expert_gradients(
727        &mut self,
728        expert_gradients: &HashMap<usize, Tensor<f32>>,
729    ) -> TorshResult<()> {
730        self.gradient_aggregator
731            .aggregate_gradients(expert_gradients, &self.expert_sharding_map)
732            .await
733    }
734
735    /// Get expert sharding information
736    pub fn get_expert_sharding_map(&self) -> &HashMap<usize, ExpertShardInfo> {
737        &self.expert_sharding_map
738    }
739
740    /// Get local experts
741    pub fn get_local_experts(&self) -> &Vec<Expert> {
742        &self.local_experts
743    }
744
745    /// Get configuration
746    pub fn get_config(&self) -> &ExpertParallelismConfig {
747        &self.config
748    }
749
750    /// Get the total number of experts across all ranks
751    pub fn get_num_experts(&self) -> usize {
752        self.config.num_experts
753    }
754}