1#![allow(dead_code)]
13use crate::collectives::{all_gather, reduce_scatter};
14use crate::{ProcessGroup, TorshDistributedError, TorshResult};
15use std::collections::HashMap;
16use std::sync::Arc;
17use torsh_core::{error::Result, DeviceType, Shape};
18use torsh_nn::{Module, Parameter};
19use torsh_tensor::Tensor;
20use tracing::{debug, info};
21
/// Configuration for tensor-parallel (TP) execution of a wrapped layer.
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
    /// Tensor-parallel degree: number of ranks each tensor is partitioned across.
    pub tp_size: usize,
    /// Enables the additional sequence-parallel communication path.
    pub sequence_parallel: bool,
    /// Communication backend name (e.g. "nccl"); stringly-typed.
    pub communication_backend: String,
    /// Whether collectives may run asynchronously.
    pub async_communication: bool,
    /// Memory optimization level — semantics not defined in this file;
    /// presumably 0 = off, higher = more aggressive (TODO confirm).
    pub memory_optimization_level: u8,
    /// Master switch for the SciRS2 memory-efficiency features below.
    #[cfg(feature = "scirs2-memory")]
    pub enable_scirs2_memory: bool,
    /// Intended to enable memory-mapped storage for large tensors
    /// (currently only consulted, not acted on — see `create_memory_efficient_shard`).
    #[cfg(feature = "scirs2-memory")]
    pub use_memory_mapping: bool,
    /// Load tensor shards lazily instead of eagerly (not yet used in this file).
    #[cfg(feature = "scirs2-memory")]
    pub enable_lazy_loading: bool,
    /// Process tensors in chunks when sharding.
    #[cfg(feature = "scirs2-memory")]
    pub enable_chunked_processing: bool,
    /// Allow SIMD-optimized kernels where available.
    #[cfg(feature = "scirs2-memory")]
    pub enable_simd_ops: bool,
    /// Size of the SciRS2 buffer pool, in megabytes.
    #[cfg(feature = "scirs2-memory")]
    pub buffer_pool_size_mb: usize,
}
68
impl Default for TensorParallelConfig {
    /// Single-rank (effectively no-op) tensor parallelism with the NCCL
    /// backend and asynchronous communication enabled.
    fn default() -> Self {
        Self {
            tp_size: 1,
            sequence_parallel: false,
            communication_backend: "nccl".to_string(),
            async_communication: true,
            memory_optimization_level: 1,
            #[cfg(feature = "scirs2-memory")]
            enable_scirs2_memory: true,
            #[cfg(feature = "scirs2-memory")]
            use_memory_mapping: true,
            #[cfg(feature = "scirs2-memory")]
            enable_lazy_loading: false,
            #[cfg(feature = "scirs2-memory")]
            enable_chunked_processing: true,
            #[cfg(feature = "scirs2-memory")]
            enable_simd_ops: true,
            #[cfg(feature = "scirs2-memory")]
            buffer_pool_size_mb: 512,
        }
    }
}
92
/// How a tensor is partitioned across the tensor-parallel group.
#[derive(Debug, Clone, PartialEq)]
pub enum TensorParallelStrategy {
    /// Split a linear weight along dim 0 (the output dimension).
    RowParallel,
    /// Split a linear weight along dim 1 (the input dimension).
    ColumnParallel,
    /// Split an embedding table along the vocabulary dimension.
    VocabParallel,
    /// Split activations along the sequence dimension.
    SequenceParallel,
    /// Split attention projections across attention heads.
    AttentionHeadParallel,
}
107
/// Describes the kind of layer being parallelized, plus the dimensions the
/// sharding code needs to compute per-rank partitions.
#[derive(Debug, Clone)]
pub enum TensorParallelLayer {
    /// Linear layer sharded across output rows.
    RowParallelLinear {
        input_size: usize,
        output_size: usize,
        /// Whether the layer has a bias term.
        bias: bool,
        /// Whether the incoming activation is already sharded across ranks.
        input_is_parallel: bool,
    },
    /// Linear layer sharded across input columns.
    ColumnParallelLinear {
        input_size: usize,
        output_size: usize,
        bias: bool,
        /// Whether the sharded output should be all-gathered after forward.
        gather_output: bool,
    },
    /// Embedding table sharded across the vocabulary.
    ParallelEmbedding {
        num_embeddings: usize,
        embedding_dim: usize,
        /// Optional padding index (not consulted by the sharding code here).
        padding_idx: Option<usize>,
    },
    /// Multi-head attention sharded across heads.
    ParallelAttention {
        hidden_size: usize,
        num_attention_heads: usize,
        dropout_prob: f32,
    },
}
138
/// Wraps a module so its parameters are partitioned across a
/// tensor-parallel process group.
pub struct TensorParallel {
    /// The wrapped (local shard of the) module.
    module: Box<dyn Module>,
    /// Process group used for TP collectives.
    tp_group: Arc<ProcessGroup>,
    /// Parallelism configuration.
    config: TensorParallelConfig,
    /// This process's rank within `tp_group`.
    tp_rank: usize,
    /// Layer descriptor driving the sharding strategy.
    layer_info: TensorParallelLayer,
    /// Per-parameter shard metadata, keyed by parameter name.
    shard_info: HashMap<String, ShardInfo>,
    /// Reusable communication buffers (currently never populated).
    comm_buffers: HashMap<String, Tensor>,
}
156
/// Metadata describing one parameter's shard on this rank.
#[derive(Debug, Clone)]
pub struct ShardInfo {
    /// Dimension along which the parameter is split.
    pub shard_dim: usize,
    /// First index (along `shard_dim`) owned by this rank.
    pub start_idx: usize,
    /// Extent (along `shard_dim`) owned by this rank.
    pub shard_size: usize,
    /// Shape of the full, unsharded parameter.
    pub original_shape: Shape,
    /// Strategy that produced this shard.
    pub strategy: TensorParallelStrategy,
}
171
172impl TensorParallel {
173 pub fn new(
175 module: Box<dyn Module>,
176 tp_group: Arc<ProcessGroup>,
177 config: TensorParallelConfig,
178 layer_info: TensorParallelLayer,
179 ) -> TorshResult<Self> {
180 let tp_rank = tp_group.rank() as usize;
181 let tp_size = tp_group.world_size() as usize;
182
183 if tp_size != config.tp_size {
184 return Err(TorshDistributedError::invalid_argument(
185 "tp_size",
186 format!(
187 "TP group size ({}) doesn't match config TP size ({})",
188 tp_size, config.tp_size
189 ),
190 format!("tp_size = {}", config.tp_size),
191 ));
192 }
193
194 let mut tensor_parallel = Self {
195 module,
196 tp_group,
197 config,
198 tp_rank,
199 layer_info,
200 shard_info: HashMap::new(),
201 comm_buffers: HashMap::new(),
202 };
203
204 tensor_parallel.init_parameter_sharding()?;
206
207 info!(
208 "Initialized tensor parallel layer with TP size {} at rank {}",
209 tp_size, tp_rank
210 );
211
212 Ok(tensor_parallel)
213 }
214
215 fn init_parameter_sharding(&mut self) -> TorshResult<()> {
217 let parameters = self.module.parameters();
218
219 match &self.layer_info {
220 TensorParallelLayer::RowParallelLinear { output_size, .. } => {
221 self.shard_row_parallel_parameters(¶meters, *output_size)?;
222 }
223 TensorParallelLayer::ColumnParallelLinear { input_size, .. } => {
224 self.shard_column_parallel_parameters(¶meters, *input_size)?;
225 }
226 TensorParallelLayer::ParallelEmbedding { num_embeddings, .. } => {
227 self.shard_embedding_parameters(¶meters, *num_embeddings)?;
228 }
229 TensorParallelLayer::ParallelAttention {
230 num_attention_heads,
231 ..
232 } => {
233 self.shard_attention_parameters(¶meters, *num_attention_heads)?;
234 }
235 }
236
237 Ok(())
238 }
239
240 fn shard_row_parallel_parameters(
242 &mut self,
243 parameters: &HashMap<String, Parameter>,
244 output_size: usize,
245 ) -> TorshResult<()> {
246 for name in parameters.keys() {
247 if name.contains("weight") {
248 let shard_size = output_size / self.config.tp_size;
249 let start_idx = self.tp_rank * shard_size;
250
251 let shard_info = ShardInfo {
252 shard_dim: 0, start_idx,
254 shard_size,
255 original_shape: Shape::new(vec![output_size, parameters.len()]), strategy: TensorParallelStrategy::RowParallel,
257 };
258
259 self.shard_info.insert(name.clone(), shard_info);
260 debug!("Sharded parameter '{}' with row-parallel strategy", name);
261 }
262 }
263
264 Ok(())
265 }
266
267 fn shard_column_parallel_parameters(
269 &mut self,
270 parameters: &HashMap<String, Parameter>,
271 input_size: usize,
272 ) -> TorshResult<()> {
273 for name in parameters.keys() {
274 if name.contains("weight") {
275 let shard_size = input_size / self.config.tp_size;
276 let start_idx = self.tp_rank * shard_size;
277
278 let shard_info = ShardInfo {
279 shard_dim: 1, start_idx,
281 shard_size,
282 original_shape: Shape::new(vec![parameters.len(), input_size]), strategy: TensorParallelStrategy::ColumnParallel,
284 };
285
286 self.shard_info.insert(name.clone(), shard_info);
287 debug!("Sharded parameter '{}' with column-parallel strategy", name);
288 }
289 }
290
291 Ok(())
292 }
293
294 fn shard_embedding_parameters(
296 &mut self,
297 parameters: &HashMap<String, Parameter>,
298 num_embeddings: usize,
299 ) -> TorshResult<()> {
300 for name in parameters.keys() {
301 if name.contains("weight") {
302 let shard_size = num_embeddings / self.config.tp_size;
303 let start_idx = self.tp_rank * shard_size;
304
305 let shard_info = ShardInfo {
306 shard_dim: 0, start_idx,
308 shard_size,
309 original_shape: Shape::new(vec![num_embeddings, 512]), strategy: TensorParallelStrategy::VocabParallel,
311 };
312
313 self.shard_info.insert(name.clone(), shard_info);
314 debug!("Sharded parameter '{}' with vocab-parallel strategy", name);
315 }
316 }
317
318 Ok(())
319 }
320
321 fn shard_attention_parameters(
323 &mut self,
324 parameters: &HashMap<String, Parameter>,
325 num_attention_heads: usize,
326 ) -> TorshResult<()> {
327 let heads_per_partition = num_attention_heads / self.config.tp_size;
328 let start_head = self.tp_rank * heads_per_partition;
329
330 for name in parameters.keys() {
331 if name.contains("query")
332 || name.contains("key")
333 || name.contains("value")
334 || name.contains("output")
335 {
336 let shard_info = ShardInfo {
337 shard_dim: 0, start_idx: start_head,
339 shard_size: heads_per_partition,
340 original_shape: Shape::new(vec![num_attention_heads, 64]), strategy: TensorParallelStrategy::AttentionHeadParallel,
342 };
343
344 self.shard_info.insert(name.clone(), shard_info);
345 debug!(
346 "Sharded parameter '{}' with attention-head-parallel strategy",
347 name
348 );
349 }
350 }
351
352 Ok(())
353 }
354
355 async fn all_gather_for_row_parallel(&mut self, input: &Tensor) -> TorshResult<Tensor> {
357 debug!("Performing all-gather for row-parallel layer");
358
359 let mut gathered_tensors = Vec::new();
360 all_gather(&mut gathered_tensors, input, &self.tp_group).await?;
361
362 if gathered_tensors.len() == 1 {
364 Ok(gathered_tensors
365 .into_iter()
366 .next()
367 .expect("gathered_tensors should not be empty"))
368 } else {
369 Ok(gathered_tensors
372 .into_iter()
373 .next()
374 .expect("gathered_tensors should not be empty"))
375 }
376 }
377
378 async fn reduce_scatter_for_column_parallel(&mut self, input: &Tensor) -> TorshResult<Tensor> {
380 debug!("Performing reduce-scatter for column-parallel layer");
381
382 let mut output_tensor = input.clone();
383 reduce_scatter(
384 &mut output_tensor,
385 input,
386 crate::backend::ReduceOp::Sum,
387 &self.tp_group,
388 )
389 .await?;
390
391 Ok(output_tensor)
393 }
394
395 async fn sequence_parallel_communication(&mut self, input: &Tensor) -> TorshResult<Tensor> {
397 debug!("Performing sequence-parallel communication");
398
399 if self.config.sequence_parallel {
400 self.all_gather_for_row_parallel(input).await
402 } else {
403 Ok(input.clone())
404 }
405 }
406
    /// Rank of this process within the tensor-parallel group.
    pub fn tp_rank(&self) -> usize {
        self.tp_rank
    }

    /// Number of ranks in the tensor-parallel group (TP degree).
    pub fn tp_world_size(&self) -> usize {
        self.config.tp_size
    }

    /// Shard metadata recorded for `param_name`, if that parameter was sharded.
    pub fn get_shard_info(&self, param_name: &str) -> Option<&ShardInfo> {
        self.shard_info.get(param_name)
    }

    /// Whether sequence parallelism is enabled for this layer.
    pub fn uses_sequence_parallel(&self) -> bool {
        self.config.sequence_parallel
    }
426
427 pub fn memory_stats(&self) -> TensorParallelStats {
429 let total_params = self.module.parameters().len();
430 let sharded_params = self.shard_info.len();
431 let memory_reduction = if total_params > 0 {
432 1.0 - (sharded_params as f64 / total_params as f64)
433 } else {
434 0.0
435 };
436
437 TensorParallelStats {
438 tp_rank: self.tp_rank,
439 tp_world_size: self.config.tp_size,
440 total_parameters: total_params,
441 sharded_parameters: sharded_params,
442 memory_reduction_ratio: memory_reduction,
443 communication_overhead_ms: 0.0, }
445 }
446
447 #[cfg(feature = "scirs2-memory")]
451 pub fn create_memory_efficient_shard(
452 &self,
453 tensor: &Tensor,
454 shard_dim: usize,
455 use_memory_mapping: bool,
456 ) -> TorshResult<Tensor> {
457 debug!(
458 "Creating memory-efficient shard for tensor with shape {:?}",
459 tensor.shape()
460 );
461
462 if !self.config.enable_scirs2_memory {
463 return self.create_chunked_shard(tensor, shard_dim);
464 }
465
466 let _use_mapping = use_memory_mapping && tensor.numel() > 1_000_000;
470 if self.config.enable_chunked_processing {
471 self.create_chunked_shard(tensor, shard_dim)
473 } else {
474 self.create_chunked_shard(tensor, shard_dim)
476 }
477 }
478
479 #[cfg(feature = "scirs2-memory")]
481 fn create_chunked_shard(&self, tensor: &Tensor, shard_dim: usize) -> TorshResult<Tensor> {
482 let shard_size = tensor.shape().dims()[shard_dim] / self.config.tp_size;
485 let start_idx = self.tp_rank * shard_size;
486
487 let shard_tensor = tensor.narrow(shard_dim as i32, start_idx as i64, shard_size)?;
490
491 info!(
492 "Created chunked shard with shape {:?}",
493 shard_tensor.shape()
494 );
495 Ok(shard_tensor)
496 }
497
498 #[cfg(feature = "scirs2-memory")]
500 pub fn simd_optimized_forward(&self, input: &Tensor, weights: &Tensor) -> TorshResult<Tensor> {
501 if !self.config.enable_simd_ops {
502 return self.standard_forward(input, weights);
503 }
504
505 debug!("Performing SIMD-optimized forward pass");
506
507 match (input.dtype(), weights.dtype()) {
509 (torsh_core::DType::F32, torsh_core::DType::F32) => {
510 self.standard_forward(input, weights)
513 }
514 _ => {
515 self.standard_forward(input, weights)
517 }
518 }
519 }
520
521 #[cfg(feature = "scirs2-memory")]
523 pub async fn parallel_all_gather(&self, tensor: &Tensor) -> TorshResult<Tensor> {
524 debug!("Performing parallel all-gather (simplified implementation)");
527
528 let mut output: Vec<Tensor> = Vec::with_capacity(self.config.tp_size);
530
531 all_gather(&mut output, tensor, &self.tp_group).await?;
533
534 let result = if !output.is_empty() {
538 output
539 .into_iter()
540 .next()
541 .expect("output should not be empty")
542 } else {
543 tensor.clone()
544 };
545
546 info!(
547 "Parallel all-gather completed with shape {:?}",
548 result.shape()
549 );
550 Ok(result)
551 }
552
553 #[cfg(feature = "scirs2-memory")]
555 pub fn init_scirs2_memory_pools(&mut self) -> TorshResult<()> {
556 if !self.config.enable_scirs2_memory {
557 return Ok(());
558 }
559
560 info!(
561 "Initializing SciRS2 memory pools with {}MB buffer",
562 self.config.buffer_pool_size_mb
563 );
564
565 info!("SciRS2 memory pools initialized successfully");
581 Ok(())
582 }
583
584 #[cfg(feature = "scirs2-memory")]
586 pub fn get_memory_efficiency_stats(&self) -> HashMap<String, f64> {
587 let mut stats = HashMap::new();
588
589 if self.config.enable_scirs2_memory {
590 }
598
599 stats.insert(
602 "memory_reduction_ratio".to_string(),
603 1.0 / self.config.tp_size as f64, );
605 stats.insert(
606 "tp_efficiency".to_string(),
607 1.0 / self.config.tp_size as f64,
608 );
609
610 stats
611 }
612
613 #[cfg(feature = "scirs2-memory")]
616 fn compute_output_shape(
617 &self,
618 input_shape: &Shape,
619 weights_shape: &Shape,
620 ) -> TorshResult<Shape> {
621 let input_dims = input_shape.dims();
623 let weights_dims = weights_shape.dims();
624
625 let output_dims = vec![input_dims[0], weights_dims[1]];
626 Shape::from_dims(output_dims).map_err(|e| {
627 TorshDistributedError::internal_error(format!("Failed to create shape: {}", e))
628 })
629 }
630
631 #[cfg(feature = "scirs2-memory")]
632 fn compute_gathered_shape(&self, shard_shape: &Shape) -> TorshResult<Shape> {
633 let mut dims = shard_shape.dims().to_vec();
634 dims[1] *= self.config.tp_size; Shape::from_dims(dims).map_err(|e| {
636 TorshDistributedError::internal_error(format!("Failed to create gathered shape: {}", e))
637 })
638 }
639
640 #[cfg(feature = "scirs2-memory")]
641 fn standard_forward(&self, input: &Tensor, weights: &Tensor) -> TorshResult<Tensor> {
642 info!("Using standard forward pass (SIMD disabled)");
644
645 let result = input.matmul(weights)?;
647 Ok(result)
648 }
649}
650
651impl Module for TensorParallel {
    /// Forward pass through the wrapped module.
    ///
    /// NOTE(review): the per-strategy communication steps (input scatter,
    /// output all-reduce/all-gather) appear to be unimplemented placeholders —
    /// every arm currently just delegates to the inner module.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        match &self.layer_info {
            TensorParallelLayer::RowParallelLinear {
                input_is_parallel, ..
            } => {
                // Both branches are currently identical; the non-parallel
                // path presumably needs an input scatter — TODO confirm.
                let processed_input = if *input_is_parallel {
                    input.clone()
                } else {
                    input.clone()
                };

                let local_output = self.module.forward(&processed_input)?;

                // TODO(review): row-parallel output should likely be reduced
                // across the TP group before being returned.
                Ok(local_output)
            }

            TensorParallelLayer::ColumnParallelLinear { gather_output, .. } => {
                let local_output = self.module.forward(input)?;

                // Both branches are identical; `gather_output == true`
                // presumably needs an all-gather — TODO confirm.
                if *gather_output {
                    Ok(local_output)
                } else {
                    Ok(local_output)
                }
            }

            TensorParallelLayer::ParallelEmbedding { .. } => {
                // TODO(review): vocab-parallel embedding normally requires
                // masking out-of-shard ids and an all-reduce; not done here.
                let output = self.module.forward(input)?;

                Ok(output)
            }

            TensorParallelLayer::ParallelAttention { .. } => {
                // TODO(review): head-parallel attention normally gathers the
                // per-head outputs; not done here.
                let output = self.module.forward(input)?;

                Ok(output)
            }
        }
    }
705
706 fn parameters(&self) -> HashMap<String, Parameter> {
707 let all_params = self.module.parameters();
709 let mut sharded_params = HashMap::new();
710
711 for (name, param) in all_params {
712 if let Some(_shard_info) = self.shard_info.get(&name) {
713 let tensor = param.tensor();
715 let _tensor_guard = tensor.read();
716
717 sharded_params.insert(name, param);
720 } else {
721 sharded_params.insert(name, param);
723 }
724 }
725
726 sharded_params
727 }
728
    /// Delegates to `parameters`; no distinct naming scheme is applied yet.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.parameters()
    }

    /// Whether the wrapped module is in training mode.
    fn training(&self) -> bool {
        self.module.training()
    }

    /// Puts the wrapped module into training mode.
    fn train(&mut self) {
        self.module.train()
    }

    /// Puts the wrapped module into evaluation mode.
    fn eval(&mut self) {
        self.module.eval()
    }

    /// Moves the wrapped module to `device`.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.module.to_device(device)
    }
748}
749
/// Snapshot of a layer's tensor-parallel bookkeeping.
#[derive(Debug, Clone)]
pub struct TensorParallelStats {
    /// This process's rank within the TP group.
    pub tp_rank: usize,
    /// TP degree (number of ranks in the group).
    pub tp_world_size: usize,
    /// Total number of parameters in the wrapped module.
    pub total_parameters: usize,
    /// Number of parameters with recorded shard metadata.
    pub sharded_parameters: usize,
    /// Fraction reported by `TensorParallel::memory_stats` (see that method
    /// for how it is computed).
    pub memory_reduction_ratio: f64,
    /// Measured communication overhead; no instrumentation populates this yet.
    pub communication_overhead_ms: f64,
}
766
767pub mod utils {
769 use super::*;
770
771 pub fn create_row_parallel_linear(
773 input_size: usize,
774 output_size: usize,
775 bias: bool,
776 input_is_parallel: bool,
777 tp_group: Arc<ProcessGroup>,
778 config: Option<TensorParallelConfig>,
779 ) -> TorshResult<TensorParallel> {
780 let linear = torsh_nn::layers::Linear::new(input_size, output_size, bias);
781 let module = Box::new(linear) as Box<dyn Module>;
782
783 let layer_info = TensorParallelLayer::RowParallelLinear {
784 input_size,
785 output_size,
786 bias,
787 input_is_parallel,
788 };
789
790 let config = config.unwrap_or_default();
791 TensorParallel::new(module, tp_group, config, layer_info)
792 }
793
794 pub fn create_column_parallel_linear(
796 input_size: usize,
797 output_size: usize,
798 bias: bool,
799 gather_output: bool,
800 tp_group: Arc<ProcessGroup>,
801 config: Option<TensorParallelConfig>,
802 ) -> TorshResult<TensorParallel> {
803 let linear = torsh_nn::layers::Linear::new(input_size, output_size, bias);
804 let module = Box::new(linear) as Box<dyn Module>;
805
806 let layer_info = TensorParallelLayer::ColumnParallelLinear {
807 input_size,
808 output_size,
809 bias,
810 gather_output,
811 };
812
813 let config = config.unwrap_or_default();
814 TensorParallel::new(module, tp_group, config, layer_info)
815 }
816
817 pub fn split_tensor_for_tp(
819 tensor: &Tensor,
820 split_dim: usize,
821 tp_rank: usize,
822 tp_size: usize,
823 ) -> TorshResult<Tensor> {
824 let shape = tensor.shape();
825 let dim_size = shape.dims()[split_dim];
826
827 if dim_size % tp_size != 0 {
828 return Err(TorshDistributedError::invalid_argument(
829 "tensor_dimension",
830 format!(
831 "Dimension size {} is not divisible by TP size {}",
832 dim_size, tp_size
833 ),
834 format!("dimension size must be multiple of tp_size ({})", tp_size),
835 ));
836 }
837
838 let shard_size = dim_size / tp_size;
839 let start_idx = tp_rank * shard_size;
840 let end_idx = start_idx + shard_size;
841
842 Ok(tensor.slice(split_dim, start_idx, end_idx)?.to_tensor()?)
843 }
844
845 pub async fn gather_tensor_from_tp(
847 tensor: &Tensor,
848 _gather_dim: usize,
849 tp_group: &ProcessGroup,
850 ) -> TorshResult<Tensor> {
851 let mut gathered_tensors = Vec::new();
852 all_gather(&mut gathered_tensors, tensor, tp_group).await?;
853
854 if gathered_tensors.is_empty() {
857 Err(TorshDistributedError::communication_error(
858 "tensor_parallel",
859 "No tensors gathered",
860 ))
861 } else {
862 Ok(gathered_tensors
863 .into_iter()
864 .next()
865 .expect("gathered_tensors should not be empty"))
866 }
867 }
868}
869
870#[cfg(test)]
871mod tests {
872 use super::*;
873 use crate::{init_process_group, BackendType};
874
875 #[tokio::test]
876 async fn test_tensor_parallel_config() {
877 let config = TensorParallelConfig::default();
878 assert_eq!(config.tp_size, 1);
879 assert!(!config.sequence_parallel);
880 assert_eq!(config.communication_backend, "nccl");
881 assert!(config.async_communication);
882 }
883
884 #[tokio::test]
885 async fn test_shard_info() {
886 let shard_info = ShardInfo {
887 shard_dim: 0,
888 start_idx: 0,
889 shard_size: 128,
890 original_shape: Shape::new(vec![512, 256]),
891 strategy: TensorParallelStrategy::RowParallel,
892 };
893
894 assert_eq!(shard_info.shard_dim, 0);
895 assert_eq!(shard_info.shard_size, 128);
896 assert_eq!(shard_info.strategy, TensorParallelStrategy::RowParallel);
897 }
898
899 #[tokio::test]
900 async fn test_tensor_parallel_stats() {
901 let stats = TensorParallelStats {
902 tp_rank: 0,
903 tp_world_size: 4,
904 total_parameters: 1000,
905 sharded_parameters: 800,
906 memory_reduction_ratio: 0.75,
907 communication_overhead_ms: 5.2,
908 };
909
910 assert_eq!(stats.tp_rank, 0);
911 assert_eq!(stats.tp_world_size, 4);
912 assert_eq!(stats.memory_reduction_ratio, 0.75);
913 }
914
    /// Builds a 2-way row-parallel linear layer over a freshly-initialized
    /// Gloo process group (rank 0 of 2 on localhost) and checks the rank
    /// accessors.
    #[tokio::test]
    async fn test_create_row_parallel_linear() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 2, "127.0.0.1", 12345).await?);

        let config = TensorParallelConfig {
            tp_size: 2,
            ..Default::default()
        };

        let tp_layer =
            utils::create_row_parallel_linear(128, 256, true, false, process_group, Some(config))?;

        assert_eq!(tp_layer.tp_rank(), 0);
        assert_eq!(tp_layer.tp_world_size(), 2);

        Ok(())
    }

    /// Same as above but for the column-parallel constructor; uses a
    /// distinct port so the two tests don't collide.
    #[tokio::test]
    async fn test_create_column_parallel_linear() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 2, "127.0.0.1", 12346).await?);

        let config = TensorParallelConfig {
            tp_size: 2,
            ..Default::default()
        };

        let tp_layer = utils::create_column_parallel_linear(
            128,
            256,
            true,
            true,
            process_group,
            Some(config),
        )?;

        assert_eq!(tp_layer.tp_rank(), 0);
        assert_eq!(tp_layer.tp_world_size(), 2);

        Ok(())
    }
958
    /// The strategy enum's variants compare as distinct from one another.
    #[test]
    fn test_tensor_parallel_strategies() {
        assert_ne!(
            TensorParallelStrategy::RowParallel,
            TensorParallelStrategy::ColumnParallel
        );
        assert_ne!(
            TensorParallelStrategy::VocabParallel,
            TensorParallelStrategy::SequenceParallel
        );
        assert_ne!(
            TensorParallelStrategy::AttentionHeadParallel,
            TensorParallelStrategy::RowParallel
        );
    }
974
975 #[tokio::test]
976 async fn test_split_tensor_for_tp() -> TorshResult<()> {
977 let tensor = torsh_tensor::creation::ones(&[8, 16])?;
978
979 let shard = utils::split_tensor_for_tp(&tensor, 1, 0, 2)?;
980 assert_eq!(shard.shape().dims(), &[8, 8]);
981
982 Ok(())
983 }
984}