1#![allow(dead_code)]
3#![allow(clippy::await_holding_lock)]
4use crate::expert_parallelism::{
5 config::{ExpertParallelismConfig, ExpertShardingStrategy},
6 router::RoutingDecision,
7};
8use crate::ProcessGroup;
9use crate::TorshResult;
10use log::{debug, info};
11use scirs2_core::random::thread_rng;
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio;
15use torsh_core::DeviceType;
16use torsh_tensor::Tensor;
17
/// One router decision: which expert a given token should be processed by.
#[derive(Debug, Clone)]
pub struct ExpertAssignment {
    /// Id of the expert selected for this token.
    pub expert_id: usize,
    /// Gating probability / score the router assigned to this expert.
    pub probability: f32,
    /// Index of the token within the current batch.
    pub token_idx: usize,
    /// Rank expected to host the selected expert.
    pub expert_rank: usize,
}
26
/// Placement metadata for a single expert in the distributed sharding map.
#[derive(Debug, Clone)]
pub struct ExpertShardInfo {
    /// Id of the expert this entry describes.
    pub expert_id: usize,
    /// Rank that owns the primary copy of the expert.
    pub owner_rank: usize,
    /// True when the expert is hosted on the current rank.
    pub is_local: bool,
    /// All ranks holding a replica (length 1 means the expert is not replicated).
    pub replicas: Vec<usize>,
}
35
/// A single feed-forward expert (one linear layer + ReLU; see `forward_async`).
pub struct Expert {
    /// Unique id of this expert.
    pub expert_id: usize,
    /// Weight matrix, shape `[input_dim, hidden_dim]` (built in `Expert::new`).
    pub weights: Tensor<f32>,
    /// Bias vector, shape `[hidden_dim]`.
    pub bias: Tensor<f32>,
    /// Expected input feature dimension.
    pub input_dim: usize,
    /// Hidden dimension; also the width of `forward_async`'s output.
    pub hidden_dim: usize,
    // NOTE(review): `output_dim` is stored but never used by `forward_async`,
    // whose output width is `hidden_dim` — confirm whether a second projection
    // layer is missing.
    pub output_dim: usize,
}
45
46impl Expert {
47 pub fn new(expert_id: usize, params: &ExpertParameters) -> TorshResult<Self> {
48 let mut rng = thread_rng();
49 let weights_data: Vec<f32> = (0..(params.input_dim * params.hidden_dim))
50 .map(|_| rng.random::<f32>() * 0.02)
51 .collect();
52 let weights = Tensor::from_vec(weights_data, &[params.input_dim, params.hidden_dim])?;
53 let bias = Tensor::zeros(&[params.hidden_dim], DeviceType::Cpu)?;
54
55 Ok(Self {
56 expert_id,
57 weights,
58 bias,
59 input_dim: params.input_dim,
60 hidden_dim: params.hidden_dim,
61 output_dim: params.output_dim,
62 })
63 }
64
65 pub async fn forward_async(&self, input: Tensor<f32>) -> TorshResult<Tensor<f32>> {
66 tokio::task::yield_now().await;
68
69 let output = input.matmul(&self.weights)?;
71 let output = output.add(&self.bias)?;
72
73 let output = output.relu()?;
75
76 Ok(output)
77 }
78}
79
/// Hyper-parameters used to construct an [`Expert`].
#[derive(Debug, Clone)]
pub struct ExpertParameters {
    /// Input feature dimension.
    pub input_dim: usize,
    /// Hidden (projection) dimension.
    pub hidden_dim: usize,
    /// Declared output dimension.
    pub output_dim: usize,
    /// Name of the activation function (e.g. "relu").
    pub activation: String,
}

impl Default for ExpertParameters {
    /// A 512 -> 2048 -> 512 expert with a ReLU activation.
    fn default() -> Self {
        Self {
            activation: String::from("relu"),
            input_dim: 512,
            hidden_dim: 2048,
            output_dim: 512,
        }
    }
}
99
/// Simulates all-to-all token scatter/gather between ranks for expert dispatch.
pub struct AllToAllScheduler {
    // Communicator handle; queried for rank / world size during routing.
    process_group: Arc<ProcessGroup>,
}
104
105impl AllToAllScheduler {
106 pub fn new(process_group: Arc<ProcessGroup>) -> Self {
107 Self { process_group }
108 }
109
110 pub async fn route_tokens_to_experts(
111 &self,
112 tokens: &Tensor<f32>,
113 routing_decision: &RoutingDecision,
114 sharding_map: &HashMap<usize, ExpertShardInfo>,
115 ) -> TorshResult<HashMap<usize, Tensor<f32>>> {
116 info!("All-to-All: Routing tokens to experts");
117
118 let start_time = std::time::Instant::now();
125 let backend = self.process_group.backend();
126 #[allow(clippy::await_holding_lock)]
127 let backend_guard = backend.read();
128
129 let mut tokens_by_rank: HashMap<usize, Vec<Vec<f32>>> = HashMap::new();
131 let token_data = tokens.to_vec()?;
132 let tokens_per_row = tokens.shape().dims()[1];
133
134 debug!(
135 "Grouping {} tokens by destination rank",
136 routing_decision.total_tokens
137 );
138
139 for (token_idx, assignments) in routing_decision.expert_assignments.iter().enumerate() {
141 if let Some(assignment) = assignments.first() {
142 let expert_id = assignment.expert_id;
143 if let Some(shard_info) = sharding_map.get(&expert_id) {
144 let dest_rank = shard_info.owner_rank;
145
146 let token_start = token_idx * tokens_per_row;
148 let token_end = token_start + tokens_per_row;
149 if token_end <= token_data.len() {
150 let token_values = token_data[token_start..token_end].to_vec();
151 tokens_by_rank
152 .entry(dest_rank)
153 .or_default()
154 .push(token_values);
155 }
156 }
157 }
158 }
159
160 debug!(
162 "Performing all-to-all scatter: {} destination ranks",
163 tokens_by_rank.len()
164 );
165
166 let total_elements: usize = tokens_by_rank
168 .values()
169 .map(|v| v.len() * tokens_per_row)
170 .sum();
171 let world_size = backend_guard.world_size() as usize;
172 let latency_us = (total_elements as f64 * world_size as f64 * 0.01).max(50.0);
173 tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;
174
175 debug!(
176 "All-to-all scatter: {} elements across {} ranks",
177 total_elements, world_size
178 );
179
180 let mut routed_tokens = HashMap::new();
182 let current_rank = backend_guard.rank() as usize;
183
184 for (&expert_id, shard_info) in sharding_map {
185 if shard_info.is_local && shard_info.owner_rank == current_rank {
186 if let Some(expert_tokens) = tokens_by_rank.get(¤t_rank) {
188 let mut flattened_tokens = Vec::new();
190 for token in expert_tokens {
191 flattened_tokens.extend(token);
192 }
193
194 if !flattened_tokens.is_empty() {
195 let num_tokens = expert_tokens.len();
196 let tensor_shape = vec![num_tokens, tokens_per_row];
197 let expert_tensor = Tensor::from_vec(flattened_tokens, &tensor_shape)?;
198 routed_tokens.insert(expert_id, expert_tensor);
199
200 debug!(
201 "Routed {} tokens to local expert {} ({} elements)",
202 num_tokens,
203 expert_id,
204 num_tokens * tokens_per_row
205 );
206 }
207 } else {
208 let empty_tensor = Tensor::zeros(&[0, tokens_per_row], DeviceType::Cpu)?;
210 routed_tokens.insert(expert_id, empty_tensor);
211 }
212 }
213 }
214
215 let duration = start_time.elapsed();
216 info!(
217 "All-to-all token routing completed: {} local experts in {:?}",
218 routed_tokens.len(),
219 duration
220 );
221
222 Ok(routed_tokens)
223 }
224
225 pub async fn route_results_back(
226 &self,
227 expert_outputs: &HashMap<usize, Tensor<f32>>,
228 routing_decision: &RoutingDecision,
229 sharding_map: &HashMap<usize, ExpertShardInfo>,
230 ) -> TorshResult<Tensor<f32>> {
231 info!("All-to-All: Routing expert results back");
232
233 let start_time = std::time::Instant::now();
240 #[allow(clippy::await_holding_lock)]
241 let backend = self.process_group.backend();
242 let backend_guard = backend.read();
243
244 debug!("Performing all-to-all gather: collecting expert results");
245
246 let mut results_by_rank: HashMap<usize, Vec<Vec<f32>>> = HashMap::new();
248 let mut total_output_elements = 0;
249
250 for (&expert_id, expert_result) in expert_outputs {
252 if let Some(shard_info) = sharding_map.get(&expert_id) {
253 let expert_data = expert_result.to_vec()?;
254 results_by_rank
255 .entry(shard_info.owner_rank)
256 .or_default()
257 .push(expert_data.clone());
258 total_output_elements += expert_data.len();
259 }
260 }
261
262 let world_size = backend_guard.world_size() as usize;
264 let latency_us = (total_output_elements as f64 * world_size as f64 * 0.02).max(100.0);
266 tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;
267
268 debug!(
269 "All-to-all gather: {} elements from {} ranks",
270 total_output_elements,
271 results_by_rank.len()
272 );
273
274 let output_dim = if let Some(first_result) = expert_outputs.values().next() {
276 first_result.shape().dims()[1]
277 } else {
278 512 };
280
281 let mut final_output_data = vec![0.0f32; routing_decision.total_tokens * output_dim];
282 let mut tokens_processed = 0;
283
284 for (token_idx, assignments) in routing_decision.expert_assignments.iter().enumerate() {
286 if let Some(assignment) = assignments.first() {
287 let expert_id = assignment.expert_id;
288 if let Some(expert_result) = expert_outputs.get(&expert_id) {
289 let expert_data = expert_result.to_vec()?;
290 let tokens_in_result = expert_data.len() / output_dim;
291
292 let token_in_expert = token_idx % tokens_in_result.max(1);
294 let result_start = token_in_expert * output_dim;
295 let result_end = result_start + output_dim;
296
297 if result_end <= expert_data.len() {
298 let output_start = token_idx * output_dim;
299 let output_end = output_start + output_dim;
300
301 if output_end <= final_output_data.len() {
302 final_output_data[output_start..output_end]
303 .copy_from_slice(&expert_data[result_start..result_end]);
304 tokens_processed += 1;
305 }
306 }
307 }
308 }
309 }
310
311 let output_shape = [routing_decision.total_tokens, output_dim];
313 let final_output = Tensor::from_vec(final_output_data, &output_shape)?;
314
315 let duration = start_time.elapsed();
316 info!(
317 "All-to-all result gathering completed: {} tokens processed in {:?}",
318 tokens_processed, duration
319 );
320
321 Ok(final_output)
322 }
323}
324
/// Aggregates expert gradients across replica ranks (simulated all-reduce).
pub struct ExpertGradientAggregator {
    // Communicator handle for the (simulated) collective operations.
    process_group: Arc<ProcessGroup>,
}
329
330impl ExpertGradientAggregator {
331 pub fn new(process_group: Arc<ProcessGroup>) -> Self {
332 Self { process_group }
333 }
334
335 pub async fn aggregate_gradients(
336 &self,
337 expert_gradients: &HashMap<usize, Tensor<f32>>,
338 sharding_map: &HashMap<usize, ExpertShardInfo>,
339 ) -> TorshResult<()> {
340 info!(
341 "Aggregating expert gradients across {} experts",
342 expert_gradients.len()
343 );
344
345 for (&expert_id, gradient) in expert_gradients {
346 if let Some(shard_info) = sharding_map.get(&expert_id) {
347 match shard_info.replicas.len() {
348 1 => {
349 continue;
351 }
352 _ => {
353 self.aggregate_replicated_expert_gradients(expert_id, gradient, shard_info)
355 .await?;
356 }
357 }
358 }
359 }
360
361 Ok(())
362 }
363
364 async fn aggregate_replicated_expert_gradients(
365 &self,
366 expert_id: usize,
367 gradient: &Tensor<f32>,
368 shard_info: &ExpertShardInfo,
369 ) -> TorshResult<()> {
370 info!(
371 " Aggregating gradients for replicated expert {} across {} replicas",
372 expert_id,
373 shard_info.replicas.len()
374 );
375
376 #[allow(clippy::await_holding_lock)]
383 let start_time = std::time::Instant::now();
384 let backend = self.process_group.backend();
385 let _backend_guard = backend.read();
386
387 let _aggregated_gradient = if shard_info.replicas.len() > 1 {
388 let grad_data = gradient.to_vec()?;
390
391 info!(
392 " All-reducing gradients across {} replicas",
393 shard_info.replicas.len()
394 );
395
396 let summed_gradients: Vec<f32> = grad_data
399 .iter()
400 .map(|&g| g * shard_info.replicas.len() as f32) .collect();
402
403 let averaged_gradients: Vec<f32> = summed_gradients
405 .iter()
406 .map(|&g| g / shard_info.replicas.len() as f32)
407 .collect();
408
409 let latency_us =
411 (grad_data.len() as f64 * shard_info.replicas.len() as f64 * 0.01).max(20.0);
412 tokio::time::sleep(tokio::time::Duration::from_micros(latency_us as u64)).await;
413
414 let result = Tensor::from_vec(averaged_gradients, gradient.shape().dims())?;
416
417 info!(
418 " Expert {} gradient all-reduce: {} elements across {} replicas",
419 expert_id,
420 grad_data.len(),
421 shard_info.replicas.len()
422 );
423
424 result
425 } else {
426 info!(
427 " Single replica expert {}, no aggregation needed",
428 expert_id
429 );
430 gradient.clone()
431 };
432
433 let duration = start_time.elapsed();
434 info!(
435 " Expert {} gradient aggregation completed in {:?}",
436 expert_id, duration
437 );
438
439 Ok(())
440 }
441}
442
/// Orchestrates distributed mixture-of-experts execution: sharding, token
/// routing, local expert execution, and gradient aggregation.
pub struct DistributedExpertManager {
    // Expert-parallelism configuration (expert count, sharding strategy, ...).
    config: ExpertParallelismConfig,
    // Communicator handle shared with the scheduler and aggregator.
    process_group: Arc<ProcessGroup>,
    // Experts hosted on this rank.
    local_experts: Vec<Expert>,
    // Placement info for every expert id (local and remote).
    expert_sharding_map: HashMap<usize, ExpertShardInfo>,
    // Simulated all-to-all scatter/gather scheduler.
    all_to_all_scheduler: AllToAllScheduler,
    // Simulated cross-replica gradient aggregator.
    gradient_aggregator: ExpertGradientAggregator,
}
452
impl DistributedExpertManager {
    /// Builds the manager for this rank: derives the expert sharding map,
    /// instantiates the locally owned experts, and wires the all-to-all
    /// scheduler plus gradient aggregator onto the shared process group.
    pub fn new(
        config: ExpertParallelismConfig,
        process_group: Arc<ProcessGroup>,
        expert_params: &ExpertParameters,
    ) -> TorshResult<Self> {
        let world_size = process_group.world_size() as usize;
        let rank = process_group.rank() as usize;

        let expert_sharding_map = Self::create_expert_sharding_map(&config, world_size, rank);

        let local_experts =
            Self::initialize_local_experts(&config, &expert_sharding_map, expert_params)?;

        let all_to_all_scheduler = AllToAllScheduler::new(process_group.clone());
        let gradient_aggregator = ExpertGradientAggregator::new(process_group.clone());

        Ok(Self {
            config,
            process_group,
            local_experts,
            expert_sharding_map,
            all_to_all_scheduler,
            gradient_aggregator,
        })
    }

    /// Computes the owner rank, locality flag, and replica set for every
    /// expert id under the configured sharding strategy.
    ///
    /// Strategies:
    /// - `DataParallel`: every expert is replicated on all ranks.
    /// - `ModelParallel`: contiguous chunks of `ceil(num_experts / world_size)`
    ///   experts per rank, no replication.
    /// - `Hybrid`: the first `num_experts / 2` experts are fully replicated,
    ///   the rest are sharded as in `ModelParallel`.
    /// - `Dynamic`: heuristic placement; see the inline comments.
    pub fn create_expert_sharding_map(
        config: &ExpertParallelismConfig,
        world_size: usize,
        rank: usize,
    ) -> HashMap<usize, ExpertShardInfo> {
        let mut sharding_map = HashMap::new();

        match config.sharding_strategy {
            ExpertShardingStrategy::DataParallel => {
                // Full replication: every rank hosts every expert and marks
                // its own copy as local.
                for expert_id in 0..config.num_experts {
                    sharding_map.insert(
                        expert_id,
                        ExpertShardInfo {
                            expert_id,
                            owner_rank: rank,
                            is_local: true,
                            replicas: (0..world_size).collect(),
                        },
                    );
                }
            }
            ExpertShardingStrategy::ModelParallel => {
                // Contiguous block partitioning; the last rank's block may be
                // short when num_experts is not divisible by world_size.
                let experts_per_device = config.num_experts.div_ceil(world_size);
                let start_expert = rank * experts_per_device;
                let end_expert = ((rank + 1) * experts_per_device).min(config.num_experts);

                for expert_id in 0..config.num_experts {
                    let owner_rank = expert_id / experts_per_device;
                    let is_local = expert_id >= start_expert && expert_id < end_expert;

                    sharding_map.insert(
                        expert_id,
                        ExpertShardInfo {
                            expert_id,
                            owner_rank,
                            is_local,
                            replicas: vec![owner_rank],
                        },
                    );
                }
            }
            ExpertShardingStrategy::Hybrid => {
                // First half replicated everywhere, second half sharded.
                let replicated_experts = config.num_experts / 2;

                for expert_id in 0..config.num_experts {
                    if expert_id < replicated_experts {
                        sharding_map.insert(
                            expert_id,
                            ExpertShardInfo {
                                expert_id,
                                owner_rank: rank,
                                is_local: true,
                                replicas: (0..world_size).collect(),
                            },
                        );
                    } else {
                        // Re-index the sharded tail starting at zero before
                        // block-partitioning it across ranks.
                        let sharded_id = expert_id - replicated_experts;
                        let experts_per_device =
                            (config.num_experts - replicated_experts).div_ceil(world_size);
                        let owner_rank = sharded_id / experts_per_device;
                        let is_local = owner_rank == rank;

                        sharding_map.insert(
                            expert_id,
                            ExpertShardInfo {
                                expert_id,
                                owner_rank,
                                is_local,
                                replicas: vec![owner_rank],
                            },
                        );
                    }
                }
            }
            ExpertShardingStrategy::Dynamic => {
                let experts_per_device = config.num_experts.div_ceil(world_size);

                for expert_id in 0..config.num_experts {
                    // Start from the same contiguous placement that
                    // ModelParallel would use.
                    let base_owner = expert_id / experts_per_device;

                    let optimal_owner = if config.num_experts > 32 {
                        // Synthetic "usage frequency" in [0, 100) derived from
                        // the expert id — a deterministic stand-in for real
                        // usage statistics.
                        let usage_frequency = ((expert_id as f32 * 7.0).sin().abs() * 100.0) as u32;

                        if usage_frequency > 70 {
                            // "Hot" experts shift to the neighboring rank.
                            (base_owner + 1) % world_size
                        } else if usage_frequency < 30 {
                            // "Cold" experts move half the ring away.
                            (base_owner + world_size / 2) % world_size
                        } else {
                            base_owner
                        }
                    } else {
                        // Small expert counts: pull every 4th matching expert
                        // onto this rank, otherwise keep the base placement.
                        if expert_id % 4 == rank % 4 {
                            rank
                        } else {
                            base_owner
                        }
                    };

                    // Memory guard: when the projected per-device footprint
                    // exceeds 16 * 1024 MB, fall back to round-robin placement.
                    let final_owner = if config.memory_per_expert_mb > 0 {
                        let memory_per_device = config.memory_per_expert_mb * experts_per_device;
                        let max_memory_mb = 16 * 1024;
                        if memory_per_device > max_memory_mb {
                            expert_id % world_size
                        } else {
                            optimal_owner
                        }
                    } else {
                        optimal_owner
                    };

                    let is_local = final_owner == rank;

                    // Replication policy: small deployments replicate the
                    // first 4 experts everywhere; larger ones keep one backup
                    // replica for every 8th expert.
                    let replicas = if config.num_experts <= 16 {
                        if expert_id < 4 {
                            (0..world_size).collect()
                        } else {
                            vec![final_owner]
                        }
                    } else {
                        if expert_id % 8 == 0 {
                            vec![final_owner, (final_owner + 1) % world_size]
                        } else {
                            vec![final_owner]
                        }
                    };

                    sharding_map.insert(
                        expert_id,
                        ExpertShardInfo {
                            expert_id,
                            owner_rank: final_owner,
                            is_local,
                            replicas,
                        },
                    );
                }

                info!(
                    " Dynamic expert migration completed: {} experts distributed across {} devices",
                    config.num_experts, world_size
                );
            }
        }

        sharding_map
    }

    /// Instantiates an [`Expert`] for every sharding-map entry marked local.
    ///
    /// NOTE(review): iterating over the `HashMap` makes the order of
    /// `local_experts` nondeterministic across runs — confirm no caller
    /// relies on a stable ordering.
    fn initialize_local_experts(
        _config: &ExpertParallelismConfig,
        sharding_map: &HashMap<usize, ExpertShardInfo>,
        expert_params: &ExpertParameters,
    ) -> TorshResult<Vec<Expert>> {
        let mut local_experts = Vec::new();

        for (&expert_id, shard_info) in sharding_map {
            if shard_info.is_local {
                let expert = Expert::new(expert_id, expert_params)?;
                local_experts.push(expert);
            }
        }

        info!(" Initialized {} local experts", local_experts.len());
        Ok(local_experts)
    }

    /// Full MoE forward step: scatter tokens to the ranks owning their
    /// experts, run the local experts, then gather the outputs back into
    /// token order.
    pub async fn execute_experts(
        &mut self,
        tokens: &Tensor<f32>,
        routing_decision: &RoutingDecision,
    ) -> TorshResult<Tensor<f32>> {
        let routed_tokens = self
            .all_to_all_scheduler
            .route_tokens_to_experts(tokens, routing_decision, &self.expert_sharding_map)
            .await?;

        let local_outputs = self.execute_local_experts(&routed_tokens).await?;

        let final_output = self
            .all_to_all_scheduler
            .route_results_back(&local_outputs, routing_decision, &self.expert_sharding_map)
            .await?;

        Ok(final_output)
    }

    /// Runs each local expert on the tokens routed to it.
    ///
    /// NOTE(review): the futures are collected first but awaited one by one;
    /// since futures are lazy this executes the experts sequentially, not
    /// concurrently — use a join combinator if overlap is intended.
    async fn execute_local_experts(
        &mut self,
        routed_tokens: &HashMap<usize, Tensor<f32>>,
    ) -> TorshResult<HashMap<usize, Tensor<f32>>> {
        let mut outputs = HashMap::new();

        let mut futures = Vec::new();

        for expert in &mut self.local_experts {
            if let Some(expert_tokens) = routed_tokens.get(&expert.expert_id) {
                let future = expert.forward_async(expert_tokens.clone());
                futures.push((expert.expert_id, future));
            }
        }

        for (expert_id, future) in futures {
            let output = future.await?;
            outputs.insert(expert_id, output);
        }

        Ok(outputs)
    }

    /// Reduces the given expert gradients across replicas via the aggregator.
    pub async fn aggregate_expert_gradients(
        &mut self,
        expert_gradients: &HashMap<usize, Tensor<f32>>,
    ) -> TorshResult<()> {
        self.gradient_aggregator
            .aggregate_gradients(expert_gradients, &self.expert_sharding_map)
            .await
    }

    /// Returns placement info for every expert id.
    pub fn get_expert_sharding_map(&self) -> &HashMap<usize, ExpertShardInfo> {
        &self.expert_sharding_map
    }

    /// Returns the experts hosted on this rank.
    pub fn get_local_experts(&self) -> &Vec<Expert> {
        &self.local_experts
    }

    /// Returns the active expert-parallelism configuration.
    pub fn get_config(&self) -> &ExpertParallelismConfig {
        &self.config
    }

    /// Returns the configured total number of experts.
    pub fn get_num_experts(&self) -> usize {
        self.config.num_experts
    }
}