1use crate::distributed::ProcessGroup;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex, RwLock};
6use std::time::{Duration, Instant};
7use trustformers_core::tensor::Tensor;
8
/// Configuration for tensor-model parallelism: how tensors are split across
/// the tensor-parallel group and how the resulting shards communicate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorParallelismConfig {
    /// Number of ranks participating in tensor parallelism (must divide the
    /// world size; see `TensorParallelism::new`).
    pub tensor_parallel_size: usize,
    /// Default strategy used when `partition_tensor` is called without an
    /// explicit override.
    pub partitioning_strategy: TensorPartitioningStrategy,
    /// Enable the column-parallel process group.
    pub column_parallel: bool,
    /// Enable the row-parallel process group.
    pub row_parallel: bool,
    /// Collective pattern assumed when estimating communication overhead.
    pub communication_pattern: TensorCommunicationPattern,
    /// Allow communication to overlap with computation.
    pub async_communication: bool,
    /// Size threshold in bytes for message fusion (default 1 MiB); presumably
    /// transfers below this are batched by the optimizer — confirm with
    /// `CommunicationOptimizer` usage.
    pub fusion_threshold_bytes: usize,
    /// Accumulate gradients across micro-batches.
    pub gradient_accumulation: bool,
    /// How aggressively to trade compute/communication for memory.
    pub memory_optimization: TensorMemoryOptimization,
    /// Run shard math in reduced precision where supported.
    pub mixed_precision: bool,
}
37
38impl Default for TensorParallelismConfig {
39 fn default() -> Self {
40 Self {
41 tensor_parallel_size: 1,
42 partitioning_strategy: TensorPartitioningStrategy::ColumnWise,
43 column_parallel: true,
44 row_parallel: true,
45 communication_pattern: TensorCommunicationPattern::AllReduce,
46 async_communication: true,
47 fusion_threshold_bytes: 1024 * 1024, gradient_accumulation: true,
49 memory_optimization: TensorMemoryOptimization::Medium,
50 mixed_precision: false,
51 }
52 }
53}
54
/// How a tensor is split across tensor-parallel ranks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorPartitioningStrategy {
    /// Split along the last (column) dimension; requires >= 2D.
    ColumnWise,
    /// Split along the second-to-last (row) dimension; requires >= 2D.
    RowWise,
    /// Split along the leading (batch) dimension; requires >= 1D.
    BatchWise,
    /// Split along dimension 1 (sequence); requires >= 2D.
    SequenceWise,
    /// Pick column/row/batch automatically from the tensor shape.
    Dynamic,
    /// Split a 2D tensor into a square grid of blocks (perfect-square
    /// partition counts only; otherwise falls back to column-wise).
    BlockWise,
    /// User-defined strategy; currently delegates to column-wise.
    Custom,
}
73
/// Collective communication pattern used to synchronize tensor shards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorCommunicationPattern {
    /// Reduce across all ranks, every rank gets the full result.
    AllReduce,
    /// Every rank gathers all shards.
    AllGather,
    /// Reduce, then scatter one chunk per rank.
    ReduceScatter,
    /// Direct send/receive between two ranks.
    PointToPoint,
    /// Two-level (intra-node then inter-node) communication.
    Hierarchical,
}
88
/// Aggressiveness level of memory optimizations, from none to extreme.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorMemoryOptimization {
    None,
    Low,
    Medium,
    High,
    Extreme,
}
98
/// Metadata for one shard of a partitioned tensor.
#[derive(Debug, Clone)]
pub struct TensorPartition {
    /// Index of this shard within its tensor's partition list.
    pub partition_id: usize,
    /// Rank that owns this shard (`partition_id % world_size`).
    pub device_rank: usize,
    /// Name of the tensor this shard belongs to.
    pub tensor_name: String,
    /// Shape of the local shard.
    pub shape: Vec<usize>,
    /// Element offset of the shard within the full tensor, per dimension.
    pub offset: Vec<usize>,
    /// Whether this shard must be communicated to reassemble results
    /// (false for batch-wise splits, true for the other strategies).
    pub needs_communication: bool,
    /// Partition ids this shard depends on (currently always empty).
    pub dependencies: Vec<usize>,
}
117
/// A single distributed tensor operation scheduled for execution.
#[derive(Debug, Clone)]
pub struct TensorOperation {
    /// Unique id of this operation.
    pub operation_id: usize,
    /// Which kernel this operation performs.
    pub operation_type: TensorOperationType,
    /// Partition ids consumed by the operation.
    pub input_partitions: Vec<usize>,
    /// Partition ids produced by the operation.
    pub output_partitions: Vec<usize>,
    /// Collectives that must run after the computation.
    pub communication_requirements: Vec<CommunicationRequirement>,
    /// Estimated memory footprint in bytes.
    pub memory_requirements: usize,
}
134
/// Kind of tensor operation; hashable so it can key per-op statistics.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum TensorOperationType {
    MatMul,
    Add,
    Attention,
    Linear,
    Embedding,
    LayerNorm,
    Activation,
    /// Named user-defined operation.
    Custom(String),
}
147
/// One required data transfer between two partitions.
#[derive(Debug, Clone)]
pub struct CommunicationRequirement {
    /// Partition that provides the data.
    pub source_partition: usize,
    /// Partition that receives the data.
    pub target_partition: usize,
    /// Collective pattern used for the transfer.
    pub communication_type: TensorCommunicationPattern,
    /// Payload size in bytes (drives the simulated transfer latency).
    pub data_size: usize,
}
160
/// Coordinator for tensor-model parallelism: partitions tensors, executes
/// distributed operations, and simulates the resulting communication.
#[allow(dead_code)]
pub struct TensorParallelism {
    config: TensorParallelismConfig,
    /// This process's rank within the world.
    global_rank: usize,
    /// Total number of ranks.
    world_size: usize,

    /// All partitions per tensor name (across every rank).
    tensor_partitions: HashMap<String, Vec<TensorPartition>>,
    /// Indices (into the tensor's partition list) owned by this rank.
    local_partitions: HashMap<String, Vec<usize>>,
    /// Group spanning the whole tensor-parallel set.
    tensor_group: Arc<dyn ProcessGroup>,
    /// Optional group for column-parallel collectives (currently the same
    /// underlying group as `tensor_group`).
    column_group: Option<Arc<dyn ProcessGroup>>,
    /// Optional group for row-parallel collectives (same note as above).
    row_group: Option<Arc<dyn ProcessGroup>>,

    #[allow(dead_code)]
    operation_scheduler: Arc<RwLock<OperationScheduler>>,

    communication_optimizer: Arc<Mutex<CommunicationOptimizer>>,

    /// Shared runtime statistics, updated by execute/communication paths.
    statistics: Arc<Mutex<TensorParallelismStats>>,
}
187
/// Bookkeeping for operation scheduling (not yet driven by any visible
/// code path in this file).
#[derive(Debug, Default)]
#[allow(dead_code)]
struct OperationScheduler {
    #[allow(dead_code)]
    pending_operations: Vec<TensorOperation>,
    running_operations: Vec<TensorOperation>,
    completed_operations: Vec<TensorOperation>,
    /// Dependency edges: operation id -> ids it depends on.
    operation_graph: HashMap<usize, Vec<usize>>,
}
198
/// State for fusing and scheduling communication (not yet exercised by the
/// visible code paths).
#[derive(Debug, Default)]
#[allow(dead_code)]
struct CommunicationOptimizer {
    #[allow(dead_code)]
    /// Requirements buffered for fusion into larger transfers.
    fusion_buffer: Vec<CommunicationRequirement>,
    /// Planned communication rounds.
    communication_schedule: Vec<Vec<CommunicationRequirement>>,
    /// In-flight asynchronous transfers.
    async_handles: Vec<AsyncCommHandle>,
    bandwidth_usage: f32,
    /// Estimated latency per collective pattern.
    latency_estimates: HashMap<TensorCommunicationPattern, Duration>,
}
210
/// Handle tracking an in-flight asynchronous communication.
#[derive(Debug)]
#[allow(dead_code)]
struct AsyncCommHandle {
    #[allow(dead_code)]
    id: usize,
    /// When the transfer is expected to complete.
    completion_time: Instant,
}
219
/// Internal mutable counters; snapshotted into the public
/// `TensorParallelismStatistics` by `get_statistics`.
#[derive(Debug, Default)]
struct TensorParallelismStats {
    /// Accumulated wall-clock time spent in communication handlers.
    total_communication_time: Duration,
    /// Accumulated wall-clock time spent in `execute_operation`.
    computation_time: Duration,
    memory_usage_per_device: HashMap<usize, u64>,
    /// Total bytes moved across all communication requirements.
    communication_volume: u64,
    /// Executions per operation type.
    operation_count: HashMap<TensorOperationType, usize>,
    efficiency_score: f32,
}
230
231impl TensorParallelism {
232 pub fn new(
234 config: TensorParallelismConfig,
235 global_rank: usize,
236 world_size: usize,
237 tensor_group: Arc<dyn ProcessGroup>,
238 ) -> Result<Self> {
239 if config.tensor_parallel_size > world_size {
241 return Err(anyhow!(
242 "Tensor parallel size ({}) cannot exceed world size ({})",
243 config.tensor_parallel_size,
244 world_size
245 ));
246 }
247
248 if world_size % config.tensor_parallel_size != 0 {
249 return Err(anyhow!(
250 "World size ({}) must be divisible by tensor parallel size ({})",
251 world_size,
252 config.tensor_parallel_size
253 ));
254 }
255
256 let column_group = if config.column_parallel {
258 Some(tensor_group.clone())
260 } else {
261 None
262 };
263
264 let row_group = if config.row_parallel {
265 Some(tensor_group.clone())
267 } else {
268 None
269 };
270
271 Ok(Self {
272 config,
273 global_rank,
274 world_size,
275 tensor_partitions: HashMap::new(),
276 local_partitions: HashMap::new(),
277 tensor_group,
278 column_group,
279 row_group,
280 operation_scheduler: Arc::new(RwLock::new(OperationScheduler::default())),
281 communication_optimizer: Arc::new(Mutex::new(CommunicationOptimizer::default())),
282 statistics: Arc::new(Mutex::new(TensorParallelismStats::default())),
283 })
284 }
285
286 pub fn partition_tensor(
288 &mut self,
289 tensor_name: &str,
290 tensor_shape: &[usize],
291 strategy: Option<TensorPartitioningStrategy>,
292 ) -> Result<Vec<TensorPartition>> {
293 let partitioning_strategy = strategy.unwrap_or(self.config.partitioning_strategy.clone());
294
295 let partitions = match partitioning_strategy {
296 TensorPartitioningStrategy::ColumnWise => {
297 self.partition_column_wise(tensor_name, tensor_shape)?
298 },
299 TensorPartitioningStrategy::RowWise => {
300 self.partition_row_wise(tensor_name, tensor_shape)?
301 },
302 TensorPartitioningStrategy::BatchWise => {
303 self.partition_batch_wise(tensor_name, tensor_shape)?
304 },
305 TensorPartitioningStrategy::SequenceWise => {
306 self.partition_sequence_wise(tensor_name, tensor_shape)?
307 },
308 TensorPartitioningStrategy::Dynamic => {
309 self.partition_dynamic(tensor_name, tensor_shape)?
310 },
311 TensorPartitioningStrategy::BlockWise => {
312 self.partition_block_wise(tensor_name, tensor_shape)?
313 },
314 TensorPartitioningStrategy::Custom => {
315 self.partition_custom(tensor_name, tensor_shape)?
316 },
317 };
318
319 let local_partition_ids: Vec<usize> = partitions
321 .iter()
322 .enumerate()
323 .filter(|(_, partition)| partition.device_rank == self.global_rank)
324 .map(|(i, _)| i)
325 .collect();
326
327 self.tensor_partitions.insert(tensor_name.to_string(), partitions.clone());
328 self.local_partitions.insert(tensor_name.to_string(), local_partition_ids);
329
330 Ok(partitions)
331 }
332
333 fn partition_column_wise(
335 &self,
336 tensor_name: &str,
337 tensor_shape: &[usize],
338 ) -> Result<Vec<TensorPartition>> {
339 if tensor_shape.len() < 2 {
340 return Err(anyhow!(
341 "Column-wise partitioning requires at least 2D tensor"
342 ));
343 }
344
345 let num_partitions = self.config.tensor_parallel_size;
346 let columns = tensor_shape[tensor_shape.len() - 1];
347 let columns_per_partition = columns.div_ceil(num_partitions);
348
349 let mut partitions = Vec::new();
350
351 for partition_id in 0..num_partitions {
352 let start_col = partition_id * columns_per_partition;
353 let end_col = std::cmp::min(start_col + columns_per_partition, columns);
354
355 if start_col < columns {
356 let mut partition_shape = tensor_shape.to_vec();
357 partition_shape[tensor_shape.len() - 1] = end_col - start_col;
358
359 let mut offset = vec![0; tensor_shape.len()];
360 offset[tensor_shape.len() - 1] = start_col;
361
362 let partition = TensorPartition {
363 partition_id,
364 device_rank: partition_id % self.world_size,
365 tensor_name: tensor_name.to_string(),
366 shape: partition_shape,
367 offset,
368 needs_communication: true,
369 dependencies: Vec::new(),
370 };
371
372 partitions.push(partition);
373 }
374 }
375
376 Ok(partitions)
377 }
378
379 fn partition_row_wise(
381 &self,
382 tensor_name: &str,
383 tensor_shape: &[usize],
384 ) -> Result<Vec<TensorPartition>> {
385 if tensor_shape.len() < 2 {
386 return Err(anyhow!("Row-wise partitioning requires at least 2D tensor"));
387 }
388
389 let num_partitions = self.config.tensor_parallel_size;
390 let rows = tensor_shape[tensor_shape.len() - 2];
391 let rows_per_partition = rows.div_ceil(num_partitions);
392
393 let mut partitions = Vec::new();
394
395 for partition_id in 0..num_partitions {
396 let start_row = partition_id * rows_per_partition;
397 let end_row = std::cmp::min(start_row + rows_per_partition, rows);
398
399 if start_row < rows {
400 let mut partition_shape = tensor_shape.to_vec();
401 partition_shape[tensor_shape.len() - 2] = end_row - start_row;
402
403 let mut offset = vec![0; tensor_shape.len()];
404 offset[tensor_shape.len() - 2] = start_row;
405
406 let partition = TensorPartition {
407 partition_id,
408 device_rank: partition_id % self.world_size,
409 tensor_name: tensor_name.to_string(),
410 shape: partition_shape,
411 offset,
412 needs_communication: true,
413 dependencies: Vec::new(),
414 };
415
416 partitions.push(partition);
417 }
418 }
419
420 Ok(partitions)
421 }
422
423 fn partition_batch_wise(
425 &self,
426 tensor_name: &str,
427 tensor_shape: &[usize],
428 ) -> Result<Vec<TensorPartition>> {
429 if tensor_shape.is_empty() {
430 return Err(anyhow!(
431 "Batch-wise partitioning requires at least 1D tensor"
432 ));
433 }
434
435 let num_partitions = self.config.tensor_parallel_size;
436 let batch_size = tensor_shape[0];
437 let batch_per_partition = batch_size.div_ceil(num_partitions);
438
439 let mut partitions = Vec::new();
440
441 for partition_id in 0..num_partitions {
442 let start_batch = partition_id * batch_per_partition;
443 let end_batch = std::cmp::min(start_batch + batch_per_partition, batch_size);
444
445 if start_batch < batch_size {
446 let mut partition_shape = tensor_shape.to_vec();
447 partition_shape[0] = end_batch - start_batch;
448
449 let mut offset = vec![0; tensor_shape.len()];
450 offset[0] = start_batch;
451
452 let partition = TensorPartition {
453 partition_id,
454 device_rank: partition_id % self.world_size,
455 tensor_name: tensor_name.to_string(),
456 shape: partition_shape,
457 offset,
458 needs_communication: false, dependencies: Vec::new(),
460 };
461
462 partitions.push(partition);
463 }
464 }
465
466 Ok(partitions)
467 }
468
469 fn partition_sequence_wise(
471 &self,
472 tensor_name: &str,
473 tensor_shape: &[usize],
474 ) -> Result<Vec<TensorPartition>> {
475 if tensor_shape.len() < 2 {
476 return Err(anyhow!(
477 "Sequence-wise partitioning requires at least 2D tensor"
478 ));
479 }
480
481 let num_partitions = self.config.tensor_parallel_size;
483 let sequence_length = tensor_shape[1];
484 let seq_per_partition = sequence_length.div_ceil(num_partitions);
485
486 let mut partitions = Vec::new();
487
488 for partition_id in 0..num_partitions {
489 let start_seq = partition_id * seq_per_partition;
490 let end_seq = std::cmp::min(start_seq + seq_per_partition, sequence_length);
491
492 if start_seq < sequence_length {
493 let mut partition_shape = tensor_shape.to_vec();
494 partition_shape[1] = end_seq - start_seq;
495
496 let mut offset = vec![0; tensor_shape.len()];
497 offset[1] = start_seq;
498
499 let partition = TensorPartition {
500 partition_id,
501 device_rank: partition_id % self.world_size,
502 tensor_name: tensor_name.to_string(),
503 shape: partition_shape,
504 offset,
505 needs_communication: true,
506 dependencies: Vec::new(),
507 };
508
509 partitions.push(partition);
510 }
511 }
512
513 Ok(partitions)
514 }
515
516 fn partition_dynamic(
518 &self,
519 tensor_name: &str,
520 tensor_shape: &[usize],
521 ) -> Result<Vec<TensorPartition>> {
522 if tensor_shape.len() >= 2 {
524 let last_dim = tensor_shape[tensor_shape.len() - 1];
525 let second_last_dim = tensor_shape[tensor_shape.len() - 2];
526
527 if last_dim > second_last_dim {
528 self.partition_column_wise(tensor_name, tensor_shape)
530 } else {
531 self.partition_row_wise(tensor_name, tensor_shape)
533 }
534 } else {
535 self.partition_batch_wise(tensor_name, tensor_shape)
537 }
538 }
539
    /// Partition a 2D tensor into a `grid_size` x `grid_size` grid of blocks,
    /// where `grid_size = ceil(sqrt(tensor_parallel_size))`.
    ///
    /// Falls back to column-wise splitting when the partition count is not a
    /// perfect square. Partition ids are assigned in row-major grid order and
    /// only incremented for blocks that actually exist, so ids stay dense.
    fn partition_block_wise(
        &self,
        tensor_name: &str,
        tensor_shape: &[usize],
    ) -> Result<Vec<TensorPartition>> {
        if tensor_shape.len() != 2 {
            return Err(anyhow!("Block-wise partitioning only supports 2D tensors"));
        }

        let num_partitions = self.config.tensor_parallel_size;
        let grid_size = (num_partitions as f64).sqrt().ceil() as usize;

        // A square grid needs a perfect-square partition count; otherwise
        // delegate to the column-wise strategy.
        if grid_size * grid_size != num_partitions {
            return self.partition_column_wise(tensor_name, tensor_shape);
        }

        let rows = tensor_shape[0];
        let cols = tensor_shape[1];
        let rows_per_block = rows.div_ceil(grid_size);
        let cols_per_block = cols.div_ceil(grid_size);

        let mut partitions = Vec::new();
        let mut partition_id = 0;

        for row_block in 0..grid_size {
            for col_block in 0..grid_size {
                let start_row = row_block * rows_per_block;
                let end_row = std::cmp::min(start_row + rows_per_block, rows);
                let start_col = col_block * cols_per_block;
                let end_col = std::cmp::min(start_col + cols_per_block, cols);

                // Skip grid cells that fall entirely outside the tensor
                // (possible when a dimension is smaller than the grid).
                if start_row < rows && start_col < cols {
                    let partition_shape = vec![end_row - start_row, end_col - start_col];
                    let offset = vec![start_row, start_col];

                    let partition = TensorPartition {
                        partition_id,
                        device_rank: partition_id % self.world_size,
                        tensor_name: tensor_name.to_string(),
                        shape: partition_shape,
                        offset,
                        needs_communication: true,
                        dependencies: Vec::new(),
                    };

                    partitions.push(partition);
                    partition_id += 1;
                }
            }
        }

        Ok(partitions)
    }
595
    /// Placeholder for user-defined partitioning; currently delegates to the
    /// column-wise strategy.
    fn partition_custom(
        &self,
        tensor_name: &str,
        tensor_shape: &[usize],
    ) -> Result<Vec<TensorPartition>> {
        self.partition_column_wise(tensor_name, tensor_shape)
    }
605
    /// Execute one distributed operation: dispatch to the type-specific
    /// kernel, run any required communication, and record timing statistics.
    ///
    /// Returns the operation's named outputs (may be empty when the expected
    /// inputs are absent from `inputs`).
    pub fn execute_operation(
        &self,
        operation: &TensorOperation,
        inputs: &HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Tensor>> {
        let start_time = Instant::now();

        let outputs = match &operation.operation_type {
            TensorOperationType::MatMul => self.execute_matmul(operation, inputs)?,
            TensorOperationType::Add => self.execute_add(operation, inputs)?,
            TensorOperationType::Attention => self.execute_attention(operation, inputs)?,
            TensorOperationType::Linear => self.execute_linear(operation, inputs)?,
            TensorOperationType::Embedding => self.execute_embedding(operation, inputs)?,
            TensorOperationType::LayerNorm => self.execute_layernorm(operation, inputs)?,
            TensorOperationType::Activation => self.execute_activation(operation, inputs)?,
            TensorOperationType::Custom(name) => self.execute_custom(name, operation, inputs)?,
        };

        // Communication is simulated after the computation completes.
        self.handle_communication_requirements(&operation.communication_requirements)?;

        // Scope the lock so it is released before returning.
        {
            let mut stats = self.statistics.lock().expect("statistics lock should not be poisoned");
            stats.computation_time += start_time.elapsed();
            *stats.operation_count.entry(operation.operation_type.clone()).or_insert(0) += 1;
        }

        Ok(outputs)
    }
638
639 fn execute_matmul(
641 &self,
642 _operation: &TensorOperation,
643 inputs: &HashMap<String, Tensor>,
644 ) -> Result<HashMap<String, Tensor>> {
645 let mut outputs = HashMap::new();
648
649 if let (Some(a), Some(b)) = (inputs.get("A"), inputs.get("B")) {
650 let result = a.matmul(b)?;
651 outputs.insert("output".to_string(), result);
652 }
653
654 Ok(outputs)
655 }
656
657 fn execute_add(
659 &self,
660 _operation: &TensorOperation,
661 inputs: &HashMap<String, Tensor>,
662 ) -> Result<HashMap<String, Tensor>> {
663 let mut outputs = HashMap::new();
664
665 if let (Some(a), Some(b)) = (inputs.get("A"), inputs.get("B")) {
666 let result = a.add(b)?;
667 outputs.insert("output".to_string(), result);
668 }
669
670 Ok(outputs)
671 }
672
673 fn execute_attention(
675 &self,
676 _operation: &TensorOperation,
677 inputs: &HashMap<String, Tensor>,
678 ) -> Result<HashMap<String, Tensor>> {
679 let mut outputs = HashMap::new();
681
682 if let Some(input) = inputs.get("input") {
683 outputs.insert("output".to_string(), input.clone());
685 }
686
687 Ok(outputs)
688 }
689
690 fn execute_linear(
692 &self,
693 _operation: &TensorOperation,
694 inputs: &HashMap<String, Tensor>,
695 ) -> Result<HashMap<String, Tensor>> {
696 let mut outputs = HashMap::new();
697
698 if let (Some(input), Some(weight)) = (inputs.get("input"), inputs.get("weight")) {
699 let result = input.matmul(weight)?;
700 outputs.insert("output".to_string(), result);
701 }
702
703 Ok(outputs)
704 }
705
706 fn execute_embedding(
708 &self,
709 _operation: &TensorOperation,
710 inputs: &HashMap<String, Tensor>,
711 ) -> Result<HashMap<String, Tensor>> {
712 let mut outputs = HashMap::new();
713
714 if let Some(input) = inputs.get("input") {
715 outputs.insert("output".to_string(), input.clone());
717 }
718
719 Ok(outputs)
720 }
721
722 fn execute_layernorm(
724 &self,
725 _operation: &TensorOperation,
726 inputs: &HashMap<String, Tensor>,
727 ) -> Result<HashMap<String, Tensor>> {
728 let mut outputs = HashMap::new();
729
730 if let Some(input) = inputs.get("input") {
731 outputs.insert("output".to_string(), input.clone());
733 }
734
735 Ok(outputs)
736 }
737
738 fn execute_activation(
740 &self,
741 _operation: &TensorOperation,
742 inputs: &HashMap<String, Tensor>,
743 ) -> Result<HashMap<String, Tensor>> {
744 let mut outputs = HashMap::new();
745
746 if let Some(input) = inputs.get("input") {
747 outputs.insert("output".to_string(), input.clone());
749 }
750
751 Ok(outputs)
752 }
753
754 fn execute_custom(
756 &self,
757 _operation_name: &str,
758 _operation: &TensorOperation,
759 inputs: &HashMap<String, Tensor>,
760 ) -> Result<HashMap<String, Tensor>> {
761 let mut outputs = HashMap::new();
762
763 if let Some(input) = inputs.get("input") {
764 outputs.insert("output".to_string(), input.clone());
765 }
766
767 Ok(outputs)
768 }
769
770 fn handle_communication_requirements(
772 &self,
773 requirements: &[CommunicationRequirement],
774 ) -> Result<()> {
775 let start_time = Instant::now();
776
777 for requirement in requirements {
778 match requirement.communication_type {
779 TensorCommunicationPattern::AllReduce => {
780 self.handle_all_reduce(requirement)?;
781 },
782 TensorCommunicationPattern::AllGather => {
783 self.handle_all_gather(requirement)?;
784 },
785 TensorCommunicationPattern::ReduceScatter => {
786 self.handle_reduce_scatter(requirement)?;
787 },
788 TensorCommunicationPattern::PointToPoint => {
789 self.handle_point_to_point(requirement)?;
790 },
791 TensorCommunicationPattern::Hierarchical => {
792 self.handle_hierarchical(requirement)?;
793 },
794 }
795 }
796
797 {
799 let mut stats = self.statistics.lock().expect("statistics lock should not be poisoned");
800 stats.total_communication_time += start_time.elapsed();
801 stats.communication_volume +=
802 requirements.iter().map(|r| r.data_size as u64).sum::<u64>();
803 }
804
805 Ok(())
806 }
807
808 fn handle_all_reduce(&self, requirement: &CommunicationRequirement) -> Result<()> {
810 let partition_id = requirement.source_partition;
812
813 let partition = self
815 .tensor_partitions
816 .values()
817 .flatten()
818 .find(|p| p.partition_id == partition_id)
819 .ok_or_else(|| anyhow!("Partition {} not found for all-reduce", partition_id))?;
820
821 let _group = if self.config.column_parallel && partition.needs_communication {
823 self.column_group.as_ref().unwrap_or(&self.tensor_group)
824 } else {
825 &self.tensor_group
826 };
827
828 println!(
833 "All-reduce: Processing partition {} on device {} (size: {} bytes)",
834 partition_id, partition.device_rank, requirement.data_size
835 );
836
837 std::thread::sleep(Duration::from_micros((requirement.data_size / 1000) as u64));
839
840 Ok(())
841 }
842
    /// Simulate an all-gather from the source partition: logs the collection,
    /// names the tensor being updated, and sleeps proportionally to the
    /// payload size (all-gather is modeled as ~2x the all-reduce rate).
    fn handle_all_gather(&self, requirement: &CommunicationRequirement) -> Result<()> {
        let source_partition = requirement.source_partition;
        let target_partition = requirement.target_partition;

        // Validate that the source partition is registered.
        let _source = self
            .tensor_partitions
            .values()
            .flatten()
            .find(|p| p.partition_id == source_partition)
            .ok_or_else(|| {
                anyhow!(
                    "Source partition {} not found for all-gather",
                    source_partition
                )
            })?;

        // Row-parallel gathers prefer the row group when it exists.
        let _group = if self.config.row_parallel {
            self.row_group.as_ref().unwrap_or(&self.tensor_group)
        } else {
            &self.tensor_group
        };

        println!(
            "All-gather: Collecting from partition {} to partition {} (size: {} bytes)",
            source_partition, target_partition, requirement.data_size
        );

        // Report which tensor the gathered shard belongs to, if any.
        if let Some(tensor_name) = self
            .tensor_partitions
            .iter()
            .find(|(_, partitions)| partitions.iter().any(|p| p.partition_id == source_partition))
            .map(|(name, _)| name.clone())
        {
            println!(
                "All-gather: Updated tensor '{}' with gathered data",
                tensor_name
            );
        }

        // Simulated transfer latency: ~1 microsecond per 500 bytes.
        std::thread::sleep(Duration::from_micros((requirement.data_size / 500) as u64));

        Ok(())
    }
897
    /// Simulate a reduce-scatter: each rank ends up with one reduced chunk of
    /// size `data_size / world_size`; this rank receives the chunk at its own
    /// rank index. Sleeps proportionally to the payload size.
    fn handle_reduce_scatter(&self, requirement: &CommunicationRequirement) -> Result<()> {
        let source_partition = requirement.source_partition;
        let target_partition = requirement.target_partition;

        // Validate that the source partition is registered.
        let _source = self
            .tensor_partitions
            .values()
            .flatten()
            .find(|p| p.partition_id == source_partition)
            .ok_or_else(|| {
                anyhow!(
                    "Source partition {} not found for reduce-scatter",
                    source_partition
                )
            })?;

        let _group = &self.tensor_group;

        // Each rank receives an equal share of the reduced data.
        let chunk_size = requirement.data_size / self.world_size;

        println!("Reduce-scatter: Reducing partition {} and scattering to partition {} (chunk size: {} bytes)",
            source_partition, target_partition, chunk_size);

        // This rank's chunk is indexed by its global rank.
        let my_chunk_index = self.global_rank;
        println!(
            "Reduce-scatter: Device {} will receive chunk {}",
            self.global_rank, my_chunk_index
        );

        // Simulated transfer latency: ~1 microsecond per 750 bytes.
        std::thread::sleep(Duration::from_micros((requirement.data_size / 750) as u64));

        Ok(())
    }
943
    /// Simulate a point-to-point transfer between the ranks owning the source
    /// and target partitions. Only the sending and receiving ranks pay the
    /// simulated latency; uninvolved ranks just log and return.
    fn handle_point_to_point(&self, requirement: &CommunicationRequirement) -> Result<()> {
        let source_partition = requirement.source_partition;
        let target_partition = requirement.target_partition;

        // Both endpoints must be registered partitions.
        let source = self
            .tensor_partitions
            .values()
            .flatten()
            .find(|p| p.partition_id == source_partition)
            .ok_or_else(|| anyhow!("Source partition {} not found for P2P", source_partition))?;

        let target = self
            .tensor_partitions
            .values()
            .flatten()
            .find(|p| p.partition_id == target_partition)
            .ok_or_else(|| anyhow!("Target partition {} not found for P2P", target_partition))?;

        // Determine this rank's role in the transfer.
        let is_sender = source.device_rank == self.global_rank;
        let is_receiver = target.device_rank == self.global_rank;

        if is_sender {
            println!(
                "P2P: Sending from partition {} to device {} (size: {} bytes)",
                source_partition, target.device_rank, requirement.data_size
            );

        } else if is_receiver {
            println!(
                "P2P: Receiving from device {} to partition {} (size: {} bytes)",
                source.device_rank, target_partition, requirement.data_size
            );

        } else {
            println!(
                "P2P: Device {} not involved in communication {} -> {}",
                self.global_rank, source.device_rank, target.device_rank
            );
        }

        // Simulated latency (~0.5 us/KB plus fixed 100 us setup cost), paid
        // only by the two endpoints.
        if is_sender || is_receiver {
            std::thread::sleep(Duration::from_micros(
                (requirement.data_size / 2000 + 100) as u64,
            ));
        }

        Ok(())
    }
1006
    /// Simulate two-level (hierarchical) communication. Ranks are grouped into
    /// nodes of `ceil(sqrt(world_size))` members; local rank 0 of each node
    /// acts as the leader and additionally pays the inter-node cost, while
    /// other members only pay the cheaper intra-node cost.
    fn handle_hierarchical(&self, requirement: &CommunicationRequirement) -> Result<()> {
        let source_partition = requirement.source_partition;
        let target_partition = requirement.target_partition;

        // Validate that the source partition is registered.
        let _source = self
            .tensor_partitions
            .values()
            .flatten()
            .find(|p| p.partition_id == source_partition)
            .ok_or_else(|| {
                anyhow!(
                    "Source partition {} not found for hierarchical comm",
                    source_partition
                )
            })?;

        // Derive this rank's node id and its rank within the node.
        let nodes_per_level = (self.world_size as f64).sqrt().ceil() as usize;
        let node_id = self.global_rank / nodes_per_level;
        let local_rank = self.global_rank % nodes_per_level;

        println!(
            "Hierarchical: Device {} (node {}, local rank {}) processing partition {}",
            self.global_rank, node_id, local_rank, source_partition
        );

        if local_rank == 0 {
            // Node leader: intra-node gather, inter-node exchange, then
            // intra-node broadcast (three simulated phases).
            println!(
                "Hierarchical: Node leader {} participating in inter-node communication",
                self.global_rank
            );

            std::thread::sleep(Duration::from_micros((requirement.data_size / 1000) as u64));

            std::thread::sleep(Duration::from_micros((requirement.data_size / 500) as u64));

            std::thread::sleep(Duration::from_micros((requirement.data_size / 2000) as u64));
        } else {
            // Non-leader: only communicates with its node leader (two cheaper
            // simulated phases).
            println!(
                "Hierarchical: Device {} participating in intra-node communication with leader",
                self.global_rank
            );

            std::thread::sleep(Duration::from_micros((requirement.data_size / 2000) as u64));

            std::thread::sleep(Duration::from_micros((requirement.data_size / 4000) as u64));
        }

        println!(
            "Hierarchical: Completed hierarchical communication for partition {} (target: {})",
            source_partition, target_partition
        );

        Ok(())
    }
1077
1078 pub fn get_statistics(&self) -> TensorParallelismStatistics {
1080 let stats = self.statistics.lock().expect("lock should not be poisoned");
1081
1082 TensorParallelismStatistics {
1083 total_partitions: self.tensor_partitions.values().map(|v| v.len()).sum(),
1084 local_partitions: self.local_partitions.values().map(|v| v.len()).sum(),
1085 communication_time: stats.total_communication_time,
1086 computation_time: stats.computation_time,
1087 communication_volume: stats.communication_volume,
1088 efficiency_score: stats.efficiency_score,
1089 memory_usage_per_device: stats.memory_usage_per_device.clone(),
1090 }
1091 }
1092
    /// Borrow the active configuration.
    pub fn config(&self) -> &TensorParallelismConfig {
        &self.config
    }
1097
    /// Indices of the named tensor's partitions owned by this rank, or `None`
    /// if the tensor was never partitioned.
    pub fn get_local_partitions(&self, tensor_name: &str) -> Option<&Vec<usize>> {
        self.local_partitions.get(tensor_name)
    }
1102
    /// All partitions of the named tensor (across every rank), or `None` if
    /// the tensor was never partitioned.
    pub fn get_tensor_partitions(&self, tensor_name: &str) -> Option<&Vec<TensorPartition>> {
        self.tensor_partitions.get(tensor_name)
    }
1107}
1108
/// Public snapshot of tensor-parallelism runtime statistics, produced by
/// `TensorParallelism::get_statistics`.
#[derive(Debug, Clone)]
pub struct TensorParallelismStatistics {
    /// Total partitions registered across all tensors and ranks.
    pub total_partitions: usize,
    /// Partitions owned by this rank.
    pub local_partitions: usize,
    /// Accumulated time spent in communication handlers.
    pub communication_time: Duration,
    /// Accumulated time spent executing operations.
    pub computation_time: Duration,
    /// Total bytes moved by communication requirements.
    pub communication_volume: u64,
    pub efficiency_score: f32,
    pub memory_usage_per_device: HashMap<usize, u64>,
}
1120
1121pub mod utils {
1123 use super::*;
1124
1125 pub fn calculate_optimal_tensor_config(
1127 model_size_params: u64,
1128 memory_per_device: u64,
1129 world_size: usize,
1130 ) -> Result<TensorParallelismConfig> {
1131 let memory_per_param = 4; let model_memory_size = model_size_params * memory_per_param;
1133
1134 let required_devices = model_memory_size.div_ceil(memory_per_device);
1135 let tensor_parallel_size = std::cmp::min(required_devices as usize, world_size);
1136
1137 Ok(TensorParallelismConfig {
1138 tensor_parallel_size,
1139 ..Default::default()
1140 })
1141 }
1142
1143 pub fn estimate_communication_overhead(
1145 config: &TensorParallelismConfig,
1146 tensor_size_bytes: usize,
1147 operations_per_step: usize,
1148 ) -> f32 {
1149 let communication_per_operation = match config.communication_pattern {
1150 TensorCommunicationPattern::AllReduce => tensor_size_bytes * 2, TensorCommunicationPattern::AllGather => {
1152 tensor_size_bytes * config.tensor_parallel_size
1153 },
1154 TensorCommunicationPattern::ReduceScatter => tensor_size_bytes,
1155 _ => tensor_size_bytes,
1156 };
1157
1158 (communication_per_operation * operations_per_step) as f32 / (1024.0 * 1024.0)
1159 }
1161
1162 pub fn calculate_memory_savings(model_params: u64, tensor_parallel_size: usize) -> f32 {
1164 if tensor_parallel_size <= 1 {
1165 return 0.0;
1166 }
1167
1168 let memory_per_device = model_params / tensor_parallel_size as u64;
1169 let total_memory_without_tp = model_params;
1170
1171 1.0 - (memory_per_device as f32 / total_memory_without_tp as f32)
1172 }
1173}
1174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::SimulatedProcessGroup;
    use std::sync::Arc;

    // Defaults describe a single-device (non-parallel) setup.
    #[test]
    fn test_tensor_parallelism_config() {
        let config = TensorParallelismConfig::default();
        assert_eq!(config.tensor_parallel_size, 1);
        assert!(config.column_parallel);
        assert!(config.row_parallel);
    }

    // Construction succeeds when the parallel size divides the world size.
    #[test]
    fn test_tensor_parallelism_creation() {
        let config = TensorParallelismConfig {
            tensor_parallel_size: 4,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 4));
        let tensor_parallelism = TensorParallelism::new(config, 0, 4, process_group);

        assert!(tensor_parallelism.is_ok());
    }

    // Column-wise: the last axis (200) is split evenly across 2 partitions.
    #[test]
    fn test_column_wise_partitioning() {
        let config = TensorParallelismConfig {
            tensor_parallel_size: 2,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 2));
        let mut tensor_parallelism =
            TensorParallelism::new(config, 0, 2, process_group).expect("tensor operation failed");

        let partitions = tensor_parallelism
            .partition_tensor("test", &[100, 200], None)
            .expect("tensor operation failed");
        assert_eq!(partitions.len(), 2);
        assert_eq!(partitions[0].shape, vec![100, 100]);
        assert_eq!(partitions[1].shape, vec![100, 100]);
    }

    // Row-wise: the second-to-last axis (100) is split evenly.
    #[test]
    fn test_row_wise_partitioning() {
        let config = TensorParallelismConfig {
            tensor_parallel_size: 2,
            partitioning_strategy: TensorPartitioningStrategy::RowWise,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 2));
        let mut tensor_parallelism =
            TensorParallelism::new(config, 0, 2, process_group).expect("tensor operation failed");

        let partitions = tensor_parallelism
            .partition_tensor("test", &[100, 200], None)
            .expect("tensor operation failed");
        assert_eq!(partitions.len(), 2);
        assert_eq!(partitions[0].shape, vec![50, 200]);
        assert_eq!(partitions[1].shape, vec![50, 200]);
    }

    // Batch-wise: the leading axis (64) is split evenly.
    #[test]
    fn test_batch_wise_partitioning() {
        let config = TensorParallelismConfig {
            tensor_parallel_size: 2,
            partitioning_strategy: TensorPartitioningStrategy::BatchWise,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 2));
        let mut tensor_parallelism =
            TensorParallelism::new(config, 0, 2, process_group).expect("tensor operation failed");

        let partitions = tensor_parallelism
            .partition_tensor("test", &[64, 100, 200], None)
            .expect("tensor operation failed");
        assert_eq!(partitions.len(), 2);
        assert_eq!(partitions[0].shape, vec![32, 100, 200]);
        assert_eq!(partitions[1].shape, vec![32, 100, 200]);
    }

    // An Add with both inputs present executes without error.
    #[test]
    fn test_tensor_operation_execution() {
        let config = TensorParallelismConfig::default();
        let process_group = Arc::new(SimulatedProcessGroup::new(0, 1));
        let tensor_parallelism =
            TensorParallelism::new(config, 0, 1, process_group).expect("tensor operation failed");

        let operation = TensorOperation {
            operation_id: 0,
            operation_type: TensorOperationType::Add,
            input_partitions: vec![0, 1],
            output_partitions: vec![0],
            communication_requirements: vec![],
            memory_requirements: 1024,
        };

        let mut inputs = HashMap::new();
        inputs.insert(
            "A".to_string(),
            Tensor::ones(&[10, 10]).expect("tensor operation failed"),
        );
        inputs.insert(
            "B".to_string(),
            Tensor::ones(&[10, 10]).expect("tensor operation failed"),
        );

        let result = tensor_parallelism.execute_operation(&operation, &inputs);
        assert!(result.is_ok());
    }

    // A 10B-parameter model (40 GB at fp32) on 8 GB devices needs more than
    // one device worth of memory.
    #[test]
    fn test_optimal_tensor_config_calculation() {
        let config = utils::calculate_optimal_tensor_config(
            10_000_000_000,
            8 * 1024 * 1024 * 1024,
            8,
        )
        .expect("operation failed in test");

        assert!(
            config.tensor_parallel_size > 1,
            "Expected tensor_parallel_size > 1, got {}",
            config.tensor_parallel_size
        );
    }

    // Any non-trivial workload yields a positive overhead estimate.
    #[test]
    fn test_communication_overhead_estimation() {
        let config = TensorParallelismConfig::default();
        let overhead = utils::estimate_communication_overhead(&config, 1024 * 1024, 100);
        assert!(overhead > 0.0);
    }

    // Sharding over 4 devices saves a strictly positive fraction of memory.
    #[test]
    fn test_memory_savings_calculation() {
        let savings = utils::calculate_memory_savings(1_000_000_000, 4);
        assert!(savings > 0.0 && savings < 1.0);
    }
}