1use crate::{GraphData, GraphLayer};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9use torsh_tensor::Tensor;
10
/// Configuration for distributed GNN training across multiple workers.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Total number of workers participating in training.
    pub num_workers: usize,
    /// Zero-based rank identifying this worker.
    pub rank: usize,
    /// Backend used to exchange tensors between workers.
    pub backend: CommunicationBackend,
    /// Strategy used to split the input graph across workers.
    pub partitioning: GraphPartitioning,
    /// How per-worker parameter updates are combined during sync.
    pub aggregation: AggregationMethod,
    /// Number of training steps between parameter synchronizations.
    pub sync_frequency: usize,
}
27
/// Transport used for inter-worker collective communication.
#[derive(Debug, Clone, PartialEq)]
pub enum CommunicationBackend {
    /// Message Passing Interface (CPU clusters).
    MPI,
    /// NVIDIA collective communication library (multi-GPU).
    NCCL,
    /// Facebook's Gloo collective library.
    Gloo,
    /// Plain TCP sockets.
    TCP,
    /// Single-process fallback; collectives are no-ops (see
    /// `CommunicationManager::all_reduce`).
    InMemory,
}
42
/// Strategy for splitting a graph into per-worker partitions.
pub enum GraphPartitioning {
    /// Contiguous ranges of node ids (despite the name, the current
    /// implementation in `DistributedGNN::random_partition` is deterministic,
    /// not randomized).
    Random,
    /// METIS graph partitioning (not implemented yet).
    METIS,
    /// Node id modulo worker count.
    Hash,
    /// Community-detection-based partitioning (not implemented yet).
    Community,
    /// User-supplied function mapping `(graph, num_workers)` to per-worker
    /// `PartitionInfo`.
    Custom(Box<dyn Fn(&GraphData, usize) -> Vec<PartitionInfo> + Send + Sync>),
}
56
57impl std::fmt::Debug for GraphPartitioning {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 GraphPartitioning::Random => write!(f, "GraphPartitioning::Random"),
62 GraphPartitioning::METIS => write!(f, "GraphPartitioning::METIS"),
63 GraphPartitioning::Hash => write!(f, "GraphPartitioning::Hash"),
64 GraphPartitioning::Community => write!(f, "GraphPartitioning::Community"),
65 GraphPartitioning::Custom(_) => write!(f, "GraphPartitioning::Custom(<function>)"),
66 }
67 }
68}
69
impl Clone for GraphPartitioning {
    /// Hand-written `Clone` because `Custom` holds a `Box<dyn Fn>`, which
    /// cannot be cloned.
    ///
    /// WARNING: cloning a `Custom` variant silently degrades it to `Random` —
    /// the boxed closure is dropped. Callers that clone a config carrying a
    /// custom partitioner lose it; switching the variant to `Arc<dyn Fn…>`
    /// would allow lossless cloning if that ever matters.
    fn clone(&self) -> Self {
        match self {
            GraphPartitioning::Random => GraphPartitioning::Random,
            GraphPartitioning::METIS => GraphPartitioning::METIS,
            GraphPartitioning::Hash => GraphPartitioning::Hash,
            GraphPartitioning::Community => GraphPartitioning::Community,
            GraphPartitioning::Custom(_) => {
                // Lossy fallback: the closure cannot be duplicated.
                GraphPartitioning::Random
            }
        }
    }
}
85
/// How parameter updates from workers are combined during synchronization.
#[derive(Debug, Clone)]
pub enum AggregationMethod {
    /// Sum-reduce, then divide by the number of workers.
    Average,
    /// Element-wise sum across workers.
    Sum,
    /// Average weighted by each partition's node count.
    WeightedAverage,
    /// Rank 0 accumulates, averages, and broadcasts back.
    ParameterServer,
    /// Backend all-reduce across all workers.
    AllReduce,
}
100
/// Description of one worker's share of the graph.
#[derive(Debug, Clone)]
pub struct PartitionInfo {
    /// Rank of the worker that owns this partition.
    pub worker_rank: usize,
    /// Global node ids assigned to this partition.
    pub nodes: Vec<usize>,
    /// Edges with both endpoints inside this partition.
    pub internal_edges: Vec<(usize, usize)>,
    /// Edges crossing partition boundaries. The third element is the remote
    /// worker's rank (see `gather_boundary_features`); the first two are
    /// presumably the edge endpoints — TODO confirm ordering with producers.
    pub boundary_edges: Vec<(usize, usize, usize)>,
    /// Quality metrics for this partition.
    pub metrics: PartitionMetrics,
}
115
/// Size and quality statistics for a single partition.
#[derive(Debug, Clone)]
pub struct PartitionMetrics {
    /// Number of nodes in the partition.
    pub num_nodes: usize,
    /// Number of edges fully inside the partition.
    pub num_internal_edges: usize,
    /// Number of edges crossing to other partitions.
    pub num_boundary_edges: usize,
    /// Load-balance score (0.0 in the current stubs; see
    /// `utils::calculate_load_balance` for the intended metric).
    pub load_balance_score: f32,
    /// Estimated communication cost induced by boundary edges.
    pub communication_cost: f32,
}
130
/// Per-worker coordinator for distributed GNN training.
///
/// Each instance represents one worker: it owns that worker's graph
/// partition, the communication channel to its peers, and shared
/// synchronization bookkeeping.
#[derive(Debug)]
pub struct DistributedGNN {
    /// Distributed-training configuration (rank, backend, partitioning, …).
    pub config: DistributedConfig,
    /// The subgraph assigned to this worker.
    pub local_partition: GraphData,
    /// Node/edge membership and metrics for the local partition.
    pub partition_info: PartitionInfo,
    /// Backend-specific communication primitives.
    pub comm_manager: CommunicationManager,
    /// Step counters and buffered updates, shareable across threads.
    pub sync_state: Arc<Mutex<SyncState>>,
    /// Timing and traffic statistics.
    pub metrics: DistributedMetrics,
}
147
148impl DistributedGNN {
149 pub fn new(
151 config: DistributedConfig,
152 full_graph: &GraphData,
153 ) -> Result<Self, DistributedError> {
154 let partitions = Self::partition_graph(full_graph, &config)?;
156 let local_partition = partitions[config.rank].clone();
157
158 let comm_manager = CommunicationManager::new(&config)?;
160
161 let partition_info = Self::create_partition_info(&local_partition, config.rank);
163
164 let sync_state = Arc::new(Mutex::new(SyncState::new()));
165 let metrics = DistributedMetrics::new();
166
167 Ok(Self {
168 config,
169 local_partition,
170 partition_info,
171 comm_manager,
172 sync_state,
173 metrics,
174 })
175 }
176
177 pub fn distributed_forward(
179 &mut self,
180 layer: &dyn GraphLayer,
181 ) -> Result<GraphData, DistributedError> {
182 let boundary_features = self.gather_boundary_features()?;
184
185 let augmented_graph = self.augment_local_graph(&boundary_features)?;
187
188 let local_output = layer.forward(&augmented_graph);
190
191 self.communicate_boundary_updates(&local_output)?;
193
194 Ok(local_output)
195 }
196
197 pub fn synchronize_parameters(
199 &mut self,
200 parameters: &[Tensor],
201 ) -> Result<Vec<Tensor>, DistributedError> {
202 match self.config.aggregation {
203 AggregationMethod::AllReduce => self.all_reduce_parameters(parameters),
204 AggregationMethod::Average => self.average_parameters(parameters),
205 AggregationMethod::Sum => self.sum_parameters(parameters),
206 AggregationMethod::WeightedAverage => self.weighted_average_parameters(parameters),
207 AggregationMethod::ParameterServer => self.parameter_server_sync(parameters),
208 }
209 }
210
211 fn all_reduce_parameters(
213 &mut self,
214 parameters: &[Tensor],
215 ) -> Result<Vec<Tensor>, DistributedError> {
216 let mut reduced_params = Vec::new();
217
218 for param in parameters {
219 let param_data = param.to_vec().map_err(|e| {
221 DistributedError::CommunicationError(format!(
222 "Failed to serialize parameter: {:?}",
223 e
224 ))
225 })?;
226
227 let reduced_data = self.comm_manager.all_reduce(¶m_data)?;
229
230 let reduced_param = self.vec_to_tensor(&reduced_data, param.shape().dims())?;
232 reduced_params.push(reduced_param);
233 }
234
235 Ok(reduced_params)
236 }
237
238 fn average_parameters(
240 &mut self,
241 parameters: &[Tensor],
242 ) -> Result<Vec<Tensor>, DistributedError> {
243 let summed_params = self.sum_parameters(parameters)?;
244 let num_workers = self.config.num_workers as f32;
245
246 Ok(summed_params
247 .into_iter()
248 .map(|param| {
249 param
250 .div_scalar(num_workers)
251 .expect("parameter division should succeed")
252 })
253 .collect())
254 }
255
256 fn sum_parameters(&mut self, parameters: &[Tensor]) -> Result<Vec<Tensor>, DistributedError> {
258 let mut summed_params = Vec::new();
259
260 for param in parameters {
261 let param_data = param.to_vec().map_err(|e| {
262 DistributedError::CommunicationError(format!(
263 "Failed to serialize parameter: {:?}",
264 e
265 ))
266 })?;
267
268 let summed_data = self.comm_manager.all_reduce_sum(¶m_data)?;
269 let summed_param = self.vec_to_tensor(&summed_data, param.shape().dims())?;
270 summed_params.push(summed_param);
271 }
272
273 Ok(summed_params)
274 }
275
276 fn weighted_average_parameters(
278 &mut self,
279 parameters: &[Tensor],
280 ) -> Result<Vec<Tensor>, DistributedError> {
281 let local_weight = self.partition_info.metrics.num_nodes as f32;
282 let total_weight = self.comm_manager.all_reduce_sum(&[local_weight])?[0];
283
284 let weighted_params = parameters
285 .iter()
286 .map(|param| {
287 param
288 .mul_scalar(local_weight)
289 .expect("parameter weighting should succeed")
290 })
291 .collect::<Vec<_>>();
292
293 let summed_params = self.sum_parameters(&weighted_params)?;
294
295 Ok(summed_params
296 .into_iter()
297 .map(|param| {
298 param
299 .div_scalar(total_weight)
300 .expect("weighted parameter division should succeed")
301 })
302 .collect())
303 }
304
305 fn parameter_server_sync(
307 &mut self,
308 parameters: &[Tensor],
309 ) -> Result<Vec<Tensor>, DistributedError> {
310 if self.config.rank == 0 {
311 self.parameter_server_master(parameters)
313 } else {
314 self.parameter_server_worker(parameters)
316 }
317 }
318
319 fn parameter_server_master(
320 &mut self,
321 parameters: &[Tensor],
322 ) -> Result<Vec<Tensor>, DistributedError> {
323 let mut accumulated_updates = parameters.to_vec();
325
326 for worker_rank in 1..self.config.num_workers {
327 let worker_updates = self.comm_manager.receive_from(worker_rank)?;
328 for (i, update) in worker_updates.iter().enumerate() {
330 if i < accumulated_updates.len() {
331 accumulated_updates[i] = accumulated_updates[i]
332 .add(update)
333 .expect("operation should succeed");
334 }
335 }
336 }
337
338 let num_workers = self.config.num_workers as f32;
340 let averaged_params: Vec<Tensor> = accumulated_updates
341 .into_iter()
342 .map(|param| {
343 param
344 .div_scalar(num_workers)
345 .expect("parameter server division should succeed")
346 })
347 .collect();
348
349 for worker_rank in 1..self.config.num_workers {
351 self.comm_manager.send_to(worker_rank, &averaged_params)?;
352 }
353
354 Ok(averaged_params)
355 }
356
357 fn parameter_server_worker(
358 &mut self,
359 parameters: &[Tensor],
360 ) -> Result<Vec<Tensor>, DistributedError> {
361 self.comm_manager.send_to(0, parameters)?;
363
364 self.comm_manager.receive_from(0)
366 }
367
368 fn gather_boundary_features(&mut self) -> Result<HashMap<usize, Tensor>, DistributedError> {
370 let mut boundary_features = HashMap::new();
371
372 for &(_, _, target_worker) in &self.partition_info.boundary_edges {
374 if target_worker != self.config.rank {
375 let features = self.comm_manager.request_boundary_features(target_worker)?;
377 boundary_features.insert(target_worker, features);
378 }
379 }
380
381 Ok(boundary_features)
382 }
383
384 fn augment_local_graph(
386 &self,
387 _boundary_features: &HashMap<usize, Tensor>,
388 ) -> Result<GraphData, DistributedError> {
389 Ok(self.local_partition.clone())
392 }
393
394 fn communicate_boundary_updates(
396 &mut self,
397 _local_output: &GraphData,
398 ) -> Result<(), DistributedError> {
399 Ok(())
402 }
403
404 fn partition_graph(
406 graph: &GraphData,
407 config: &DistributedConfig,
408 ) -> Result<Vec<GraphData>, DistributedError> {
409 match &config.partitioning {
410 GraphPartitioning::Random => Self::random_partition(graph, config.num_workers),
411 GraphPartitioning::Hash => Self::hash_partition(graph, config.num_workers),
412 GraphPartitioning::METIS => Self::metis_partition(graph, config.num_workers),
413 GraphPartitioning::Community => Self::community_partition(graph, config.num_workers),
414 GraphPartitioning::Custom(partition_fn) => {
415 let partition_infos = partition_fn(graph, config.num_workers);
416 Self::create_partitions_from_info(graph, &partition_infos)
417 }
418 }
419 }
420
421 fn random_partition(
422 graph: &GraphData,
423 num_partitions: usize,
424 ) -> Result<Vec<GraphData>, DistributedError> {
425 let mut partitions = Vec::new();
426 let nodes_per_partition = graph.num_nodes / num_partitions;
427
428 for i in 0..num_partitions {
429 let start_node = i * nodes_per_partition;
430 let end_node = if i == num_partitions - 1 {
431 graph.num_nodes
432 } else {
433 (i + 1) * nodes_per_partition
434 };
435
436 let partition_nodes = (start_node..end_node).collect::<Vec<_>>();
438 let partition_graph = Self::extract_subgraph(graph, &partition_nodes)?;
439 partitions.push(partition_graph);
440 }
441
442 Ok(partitions)
443 }
444
445 fn hash_partition(
446 graph: &GraphData,
447 num_partitions: usize,
448 ) -> Result<Vec<GraphData>, DistributedError> {
449 let mut partition_nodes: Vec<Vec<usize>> = vec![Vec::new(); num_partitions];
450
451 for node in 0..graph.num_nodes {
453 let partition_id = node % num_partitions;
454 partition_nodes[partition_id].push(node);
455 }
456
457 let mut partitions = Vec::new();
458 for nodes in partition_nodes {
459 let partition_graph = Self::extract_subgraph(graph, &nodes)?;
460 partitions.push(partition_graph);
461 }
462
463 Ok(partitions)
464 }
465
466 fn metis_partition(
467 _graph: &GraphData,
468 _num_partitions: usize,
469 ) -> Result<Vec<GraphData>, DistributedError> {
470 Err(DistributedError::PartitioningError(
472 "METIS partitioning not implemented".to_string(),
473 ))
474 }
475
476 fn community_partition(
477 _graph: &GraphData,
478 _num_partitions: usize,
479 ) -> Result<Vec<GraphData>, DistributedError> {
480 Err(DistributedError::PartitioningError(
482 "Community partitioning not implemented".to_string(),
483 ))
484 }
485
486 fn create_partitions_from_info(
487 graph: &GraphData,
488 partition_infos: &[PartitionInfo],
489 ) -> Result<Vec<GraphData>, DistributedError> {
490 let mut partitions = Vec::new();
491
492 for info in partition_infos {
493 let partition_graph = Self::extract_subgraph(graph, &info.nodes)?;
494 partitions.push(partition_graph);
495 }
496
497 Ok(partitions)
498 }
499
500 fn extract_subgraph(graph: &GraphData, nodes: &[usize]) -> Result<GraphData, DistributedError> {
501 if nodes.is_empty() {
505 return Ok(GraphData::new(
506 torsh_tensor::creation::zeros(&[0, graph.x.shape().dims()[1]])
507 .expect("empty features tensor creation should succeed"),
508 torsh_tensor::creation::zeros(&[2, 0])
509 .expect("empty edge index tensor creation should succeed"),
510 ));
511 }
512
513 let feature_dim = graph.x.shape().dims()[1];
515 let mut subgraph_features = Vec::new();
516
517 for &node in nodes {
518 if node < graph.num_nodes {
519 for _f in 0..feature_dim {
521 subgraph_features.push(1.0); }
523 }
524 }
525
526 let x = torsh_tensor::creation::from_vec(
527 subgraph_features,
528 &[nodes.len(), feature_dim],
529 graph.x.device(),
530 )
531 .map_err(|e| {
532 DistributedError::TensorError(format!("Failed to create features tensor: {:?}", e))
533 })?;
534
535 let edge_index = torsh_tensor::creation::zeros(&[2, 0])
537 .expect("minimal edge index creation should succeed");
538
539 Ok(GraphData::new(x, edge_index))
540 }
541
542 fn create_partition_info(graph: &GraphData, rank: usize) -> PartitionInfo {
543 PartitionInfo {
544 worker_rank: rank,
545 nodes: (0..graph.num_nodes).collect(),
546 internal_edges: Vec::new(),
547 boundary_edges: Vec::new(),
548 metrics: PartitionMetrics {
549 num_nodes: graph.num_nodes,
550 num_internal_edges: 0,
551 num_boundary_edges: 0,
552 load_balance_score: 0.0,
553 communication_cost: 0.0,
554 },
555 }
556 }
557
558 fn vec_to_tensor(&self, data: &[f32], shape: &[usize]) -> Result<Tensor, DistributedError> {
559 torsh_tensor::creation::from_vec(data.to_vec(), shape, torsh_core::device::DeviceType::Cpu)
560 .map_err(|e| DistributedError::TensorError(format!("Failed to create tensor: {:?}", e)))
561 }
562}
563
/// Handle for exchanging data with peer workers over the selected backend.
#[derive(Debug)]
pub struct CommunicationManager {
    /// Selected transport backend.
    backend: CommunicationBackend,
    /// This worker's rank.
    rank: usize,
    /// Total number of workers in the group.
    num_workers: usize,
}
572
573impl CommunicationManager {
574 pub fn new(config: &DistributedConfig) -> Result<Self, DistributedError> {
575 Ok(Self {
576 backend: config.backend.clone(),
577 rank: config.rank,
578 num_workers: config.num_workers,
579 })
580 }
581
582 pub fn rank(&self) -> usize {
584 self.rank
585 }
586
587 pub fn num_workers(&self) -> usize {
589 self.num_workers
590 }
591
592 pub fn all_reduce(&mut self, data: &[f32]) -> Result<Vec<f32>, DistributedError> {
593 match self.backend {
594 CommunicationBackend::InMemory => {
595 Ok(data.to_vec())
597 }
598 _ => Err(DistributedError::CommunicationError(
599 "Backend not implemented".to_string(),
600 )),
601 }
602 }
603
604 pub fn all_reduce_sum(&mut self, data: &[f32]) -> Result<Vec<f32>, DistributedError> {
605 Ok(data.to_vec())
607 }
608
609 pub fn send_to(
610 &mut self,
611 _target_rank: usize,
612 _data: &[Tensor],
613 ) -> Result<(), DistributedError> {
614 Ok(())
616 }
617
618 pub fn receive_from(&mut self, _source_rank: usize) -> Result<Vec<Tensor>, DistributedError> {
619 Ok(Vec::new())
621 }
622
623 pub fn request_boundary_features(
624 &mut self,
625 _target_worker: usize,
626 ) -> Result<Tensor, DistributedError> {
627 torsh_tensor::creation::zeros(&[1, 1])
629 .map_err(|e| DistributedError::TensorError(format!("Failed to create tensor: {:?}", e)))
630 }
631}
632
/// Mutable bookkeeping for periodic parameter synchronization.
#[derive(Debug)]
pub struct SyncState {
    /// Current training step.
    pub current_step: usize,
    /// Step at which the last synchronization completed.
    pub last_sync_step: usize,
    /// Buffered updates keyed by worker rank; cleared on each sync.
    pub pending_updates: HashMap<usize, Vec<Tensor>>,
}
640
641impl SyncState {
642 pub fn new() -> Self {
643 Self {
644 current_step: 0,
645 last_sync_step: 0,
646 pending_updates: HashMap::new(),
647 }
648 }
649
650 pub fn should_sync(&self, sync_frequency: usize) -> bool {
651 self.current_step - self.last_sync_step >= sync_frequency
652 }
653
654 pub fn mark_synced(&mut self) {
655 self.last_sync_step = self.current_step;
656 self.pending_updates.clear();
657 }
658}
659
/// Timing and traffic statistics for distributed training.
#[derive(Debug, Clone)]
pub struct DistributedMetrics {
    /// Total time spent in inter-worker communication.
    pub communication_time_ms: f64,
    /// Total time spent in local computation.
    pub computation_time_ms: f64,
    /// Total time spent waiting in synchronization barriers.
    pub synchronization_time_ms: f64,
    /// Total bytes sent/received over the backend.
    pub total_bytes_communicated: usize,
    /// Number of parameter synchronizations performed.
    pub num_synchronizations: usize,
    /// Fraction of time spent computing (vs. communicating); see
    /// `compute_efficiency`.
    pub efficiency_score: f32,
}
670
671impl DistributedMetrics {
672 pub fn new() -> Self {
673 Self {
674 communication_time_ms: 0.0,
675 computation_time_ms: 0.0,
676 synchronization_time_ms: 0.0,
677 total_bytes_communicated: 0,
678 num_synchronizations: 0,
679 efficiency_score: 1.0,
680 }
681 }
682
683 pub fn compute_efficiency(&mut self) {
684 let total_time = self.communication_time_ms + self.computation_time_ms;
685 if total_time > 0.0 {
686 self.efficiency_score = (self.computation_time_ms / total_time) as f32;
687 }
688 }
689}
690
/// Errors produced by the distributed training machinery.
#[derive(Debug, Clone)]
pub enum DistributedError {
    /// Backend send/receive/reduce failure.
    CommunicationError(String),
    /// Graph partitioning failed or is unsupported.
    PartitioningError(String),
    /// Tensor creation/serialization failure.
    TensorError(String),
    /// Invalid `DistributedConfig`.
    ConfigError(String),
    /// Parameter synchronization failure.
    SynchronizationError(String),
}
705
706impl std::fmt::Display for DistributedError {
707 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
708 match self {
709 DistributedError::CommunicationError(msg) => write!(f, "Communication error: {}", msg),
710 DistributedError::PartitioningError(msg) => write!(f, "Partitioning error: {}", msg),
711 DistributedError::TensorError(msg) => write!(f, "Tensor error: {}", msg),
712 DistributedError::ConfigError(msg) => write!(f, "Configuration error: {}", msg),
713 DistributedError::SynchronizationError(msg) => {
714 write!(f, "Synchronization error: {}", msg)
715 }
716 }
717 }
718}
719
// Marker impl: the Debug derive and Display impl above satisfy the
// `std::error::Error` contract with default methods.
impl std::error::Error for DistributedError {}
721
/// A graph layer paired with a distributed coordinator so it can participate
/// in multi-worker training.
#[derive(Debug)]
pub struct DistributedGraphLayer {
    /// The wrapped single-machine layer.
    pub base_layer: Box<dyn GraphLayer>,
    /// Per-worker distributed state for this layer.
    pub coordinator: DistributedGNN,
}
730
731impl DistributedGraphLayer {
732 pub fn new(
733 base_layer: Box<dyn GraphLayer>,
734 config: DistributedConfig,
735 full_graph: &GraphData,
736 ) -> Result<Self, DistributedError> {
737 let coordinator = DistributedGNN::new(config, full_graph)?;
738
739 Ok(Self {
740 base_layer,
741 coordinator,
742 })
743 }
744}
745
impl GraphLayer for DistributedGraphLayer {
    /// Forwards through the wrapped base layer only.
    ///
    /// NOTE(review): the distributed path (`coordinator.distributed_forward`)
    /// is never invoked here, so this currently behaves exactly like the
    /// plain base layer — confirm whether that is intended.
    fn forward(&self, graph: &GraphData) -> GraphData {
        self.base_layer.forward(graph)
    }

    /// Parameters of the wrapped base layer.
    fn parameters(&self) -> Vec<Tensor> {
        self.base_layer.parameters()
    }
}
757
758pub mod utils {
760 use super::*;
761
762 pub fn calculate_load_balance(partition_sizes: &[usize]) -> f32 {
764 if partition_sizes.is_empty() {
765 return 0.0;
766 }
767
768 let mean_size = partition_sizes.iter().sum::<usize>() as f32 / partition_sizes.len() as f32;
769 let variance: f32 = partition_sizes
770 .iter()
771 .map(|&size| (size as f32 - mean_size).powi(2))
772 .sum::<f32>()
773 / partition_sizes.len() as f32;
774
775 variance / mean_size.max(1.0)
776 }
777
778 pub fn estimate_communication_cost(partition_infos: &[PartitionInfo]) -> f32 {
780 partition_infos
781 .iter()
782 .map(|info| info.metrics.num_boundary_edges as f32)
783 .sum()
784 }
785
786 pub fn create_optimal_config(num_gpus: usize, graph_size: usize) -> DistributedConfig {
788 let num_workers = num_gpus.max(1);
789 let backend = if num_gpus > 1 {
790 CommunicationBackend::NCCL
791 } else {
792 CommunicationBackend::InMemory
793 };
794
795 let partitioning = if graph_size > 1_000_000 {
796 GraphPartitioning::METIS
797 } else if graph_size > 10_000 {
798 GraphPartitioning::Community
799 } else {
800 GraphPartitioning::Hash
801 };
802
803 DistributedConfig {
804 num_workers,
805 rank: 0, backend,
807 partitioning,
808 aggregation: AggregationMethod::AllReduce,
809 sync_frequency: 10,
810 }
811 }
812}
813
#[cfg(test)]
mod tests {
    use super::*;

    use torsh_tensor::creation::randn;

    // A manually assembled config keeps its field values.
    #[test]
    fn test_distributed_config_creation() {
        let config = DistributedConfig {
            num_workers: 4,
            rank: 0,
            backend: CommunicationBackend::InMemory,
            partitioning: GraphPartitioning::Random,
            aggregation: AggregationMethod::Average,
            sync_frequency: 10,
        };

        assert_eq!(config.num_workers, 4);
        assert_eq!(config.rank, 0);
    }

    // Perfectly balanced partitions score 0; any imbalance scores > 0.
    #[test]
    fn test_load_balance_calculation() {
        let partition_sizes = vec![100, 100, 100, 100];
        let balance_score = utils::calculate_load_balance(&partition_sizes);
        assert_eq!(balance_score, 0.0);
        let unbalanced_sizes = vec![200, 50, 50, 50];
        let unbalanced_score = utils::calculate_load_balance(&unbalanced_sizes);
        assert!(unbalanced_score > 0.0);
    }

    // Communication cost equals the total boundary-edge count.
    #[test]
    fn test_communication_cost_estimation() {
        let partition_info = PartitionInfo {
            worker_rank: 0,
            nodes: vec![0, 1, 2],
            internal_edges: vec![(0, 1)],
            boundary_edges: vec![(2, 3, 1)],
            metrics: PartitionMetrics {
                num_nodes: 3,
                num_internal_edges: 1,
                num_boundary_edges: 1,
                load_balance_score: 0.0,
                communication_cost: 1.0,
            },
        };

        let cost = utils::estimate_communication_cost(&[partition_info]);
        assert_eq!(cost, 1.0);
    }

    // Multi-GPU + large graph picks NCCL; single GPU falls back to InMemory.
    #[test]
    fn test_optimal_config_creation() {
        let config = utils::create_optimal_config(4, 1_000_000);
        assert_eq!(config.num_workers, 4);
        assert_eq!(config.backend, CommunicationBackend::NCCL);

        let small_config = utils::create_optimal_config(1, 1000);
        assert_eq!(small_config.num_workers, 1);
        assert_eq!(small_config.backend, CommunicationBackend::InMemory);
    }

    // should_sync triggers after sync_frequency steps; mark_synced resets
    // the baseline step.
    #[test]
    fn test_sync_state() {
        let mut sync_state = SyncState::new();
        assert_eq!(sync_state.current_step, 0);
        assert!(!sync_state.should_sync(10));

        sync_state.current_step = 10;
        assert!(sync_state.should_sync(10));

        sync_state.mark_synced();
        assert_eq!(sync_state.last_sync_step, 10);
    }

    // Efficiency = computation / (computation + communication).
    #[test]
    fn test_distributed_metrics() {
        let mut metrics = DistributedMetrics::new();
        metrics.computation_time_ms = 800.0;
        metrics.communication_time_ms = 200.0;

        metrics.compute_efficiency();
        assert_eq!(metrics.efficiency_score, 0.8);
    }

    // create_partition_info assigns every node of the graph to the rank.
    #[test]
    fn test_partition_info_creation() {
        let x = randn(&[5, 3]).unwrap();
        let edge_index = torsh_tensor::creation::zeros(&[2, 0]).unwrap();
        let graph = GraphData::new(x, edge_index);

        let partition_info = DistributedGNN::create_partition_info(&graph, 0);
        assert_eq!(partition_info.worker_rank, 0);
        assert_eq!(partition_info.nodes.len(), 5);
        assert_eq!(partition_info.metrics.num_nodes, 5);
    }
}