1use crate::{FxGraph, Node};
7use petgraph::graph::NodeIndex;
8use petgraph::visit::EdgeRef;
9use std::collections::{HashMap, HashSet, VecDeque};
10use torsh_core::Result;
11
/// Describes one execution device available for graph placement.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceInfo {
    /// Unique device identifier, e.g. "cpu_0" or "cuda_1".
    pub id: String,
    /// Hardware backend kind.
    pub device_type: DeviceType,
    /// Total memory available on the device, in bytes.
    pub memory_capacity: usize,
    /// Relative compute speed (1.0 = baseline CPU); load-balance partitioning
    /// scales each device's work target by this factor.
    pub compute_capability: f64,
    /// Interconnect bandwidth used to estimate transfer time
    /// (bytes/sec, judging by the cluster presets below).
    pub bandwidth: f64,
}
21
// `Eq` cannot be derived because of the f64 fields. NOTE(review): this asserts
// total equality for the floats — a NaN in any float field would break Eq's
// reflexivity contract (NaN != NaN). Assumes device configs never carry NaN.
impl Eq for DeviceInfo {}
23
/// Manual `Hash`: f64 does not implement `Hash`, so floats are hashed through
/// their raw bit patterns.
///
/// NOTE(review): derived `PartialEq` treats `0.0 == -0.0`, but their bit
/// patterns differ, so two "equal" devices could hash differently — confirm
/// device configs never mix signed zeros before relying on hashed lookups.
impl std::hash::Hash for DeviceInfo {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.id.hash(state);
        self.device_type.hash(state);
        self.memory_capacity.hash(state);
        // to_bits() yields a stable, hashable integer view of each float.
        self.compute_capability.to_bits().hash(state);
        self.bandwidth.to_bits().hash(state);
    }
}
34
/// Hardware backend a device belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    CPU,
    /// CUDA GPU with (major, minor) compute capability, e.g. `CUDA(8, 6)`.
    CUDA(u8, u8),
    Metal,
    OpenCL,
    WebGPU,
}
44
/// Objective used when splitting a graph across devices.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PartitioningStrategy {
    /// Minimize data transferred across partition boundaries (BFS regions).
    MinCommunication,
    /// Balance estimated compute time across devices.
    LoadBalance,
    /// Pack nodes to respect each device's memory capacity.
    MemoryOptimal,
    /// Blend of the three objectives.
    /// NOTE(review): the weights are currently unused — this strategy falls
    /// back to plain load balancing (see `partition_weighted`).
    Weighted {
        communication_weight: f64,
        load_balance_weight: f64,
        memory_weight: f64,
    },
}
61
/// The slice of the graph assigned to one device.
#[derive(Debug, Clone)]
pub struct GraphPartition {
    /// Device this partition executes on.
    pub device: DeviceInfo,
    /// Graph nodes placed on this device.
    pub nodes: Vec<NodeIndex>,
    /// Edges whose endpoints both live in this partition.
    pub local_edges: Vec<(NodeIndex, NodeIndex)>,
    /// Outgoing cross-partition edges (stored on the source side).
    pub communication_edges: Vec<CommunicationEdge>,
    /// Estimated memory footprint in bytes (0 when the strategy doesn't track it).
    pub estimated_memory: usize,
    /// Estimated compute time in abstract weight units (0.0 when untracked).
    pub estimated_compute_time: f64,
}
72
/// A graph edge that crosses a partition boundary and therefore requires a
/// device-to-device transfer.
#[derive(Debug, Clone)]
pub struct CommunicationEdge {
    /// Index of the partition that produces the data.
    pub source_partition: usize,
    /// Index of the partition that consumes the data.
    pub target_partition: usize,
    pub source_node: NodeIndex,
    pub target_node: NodeIndex,
    /// Payload size in bytes (currently a fixed 1 KiB placeholder — see
    /// `compute_partition_edges`).
    pub data_size: usize,
    /// Estimated transfer cost: bandwidth term plus a latency constant.
    pub communication_cost: f64,
}
83
/// Result of partitioning: per-device partitions plus the transfer schedule
/// and aggregate quality metrics.
#[derive(Debug, Clone)]
pub struct PartitionedGraph {
    pub partitions: Vec<GraphPartition>,
    /// Staged plan for the cross-partition data transfers.
    pub communication_schedule: CommunicationSchedule,
    /// Sum of all cross-partition edge costs.
    pub total_communication_cost: f64,
    /// Mean/max compute-time ratio across partitions; 1.0 = perfectly balanced.
    pub load_balance_score: f64,
    /// Total estimated memory used divided by total device capacity.
    pub memory_efficiency: f64,
}
93
/// Ordered stages of data transfers between partitions.
#[derive(Debug, Clone)]
pub struct CommunicationSchedule {
    pub stages: Vec<CommunicationStage>,
    /// Cached `stages.len()`.
    pub total_stages: usize,
}
100
/// One batch of transfers intended to run together.
#[derive(Debug, Clone)]
pub struct CommunicationStage {
    pub stage_id: usize,
    pub transfers: Vec<DataTransfer>,
    /// Stage ids that must complete before this one starts.
    /// NOTE(review): never populated by the current scheduler.
    pub dependencies: Vec<usize>,
}
108
/// A single device-to-device data movement.
#[derive(Debug, Clone)]
pub struct DataTransfer {
    /// `DeviceInfo::id` of the sending device.
    pub source_device: String,
    /// `DeviceInfo::id` of the receiving device.
    pub target_device: String,
    /// Synthetic identifier derived from the endpoint node indices.
    pub data_id: String,
    /// Bytes to move.
    pub data_size: usize,
    pub priority: TransferPriority,
}
118
/// Relative urgency of a transfer; derived `Ord` follows the numeric
/// discriminants, so `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TransferPriority {
    Low = 0,
    Medium = 1,
    High = 2,
    Critical = 3,
}
127
/// Splits an `FxGraph` across a set of devices according to a
/// [`PartitioningStrategy`].
pub struct GraphPartitioner {
    // Devices available for placement, in assignment order.
    devices: Vec<DeviceInfo>,
    // Objective used by `partition`.
    strategy: PartitioningStrategy,
    // Optional cap on partition count.
    // NOTE(review): currently stored but never read by any strategy.
    max_partitions: Option<usize>,
}
134
135impl GraphPartitioner {
136 pub fn new(devices: Vec<DeviceInfo>, strategy: PartitioningStrategy) -> Self {
138 Self {
139 devices,
140 strategy,
141 max_partitions: None,
142 }
143 }
144
145 pub fn with_max_partitions(mut self, max_partitions: usize) -> Self {
147 self.max_partitions = Some(max_partitions);
148 self
149 }
150
151 pub fn partition(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
153 match self.strategy {
154 PartitioningStrategy::MinCommunication => self.partition_min_communication(graph),
155 PartitioningStrategy::LoadBalance => self.partition_load_balance(graph),
156 PartitioningStrategy::MemoryOptimal => self.partition_memory_optimal(graph),
157 PartitioningStrategy::Weighted { .. } => self.partition_weighted(graph),
158 }
159 }
160
161 fn partition_min_communication(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
162 let mut partitions = Vec::new();
163 let mut node_to_partition = HashMap::new();
164
165 let _node_weights = self.compute_node_weights(graph);
167 let _edge_weights = self.compute_edge_weights(graph);
168
169 let mut remaining_nodes: HashSet<NodeIndex> = graph.nodes().map(|(idx, _)| idx).collect();
171
172 for (device_idx, device) in self.devices.iter().enumerate() {
173 if remaining_nodes.is_empty() {
174 break;
175 }
176
177 let mut partition_nodes = Vec::new();
178 let target_size = remaining_nodes.len() / (self.devices.len() - device_idx);
179
180 let start_node = if let Some(&node) = remaining_nodes.iter().next() {
182 node
183 } else {
184 break;
185 };
186
187 let mut to_visit = VecDeque::new();
188 to_visit.push_back(start_node);
189 remaining_nodes.remove(&start_node);
190
191 while let Some(current_node) = to_visit.pop_front() {
193 partition_nodes.push(current_node);
194 node_to_partition.insert(current_node, device_idx);
195
196 if partition_nodes.len() >= target_size {
197 break;
198 }
199
200 let neighbors = self.get_neighbors(graph, current_node);
202 for neighbor in neighbors {
203 if remaining_nodes.contains(&neighbor) {
204 to_visit.push_back(neighbor);
205 remaining_nodes.remove(&neighbor);
206 }
207 }
208 }
209
210 partitions.push(GraphPartition {
211 device: device.clone(),
212 nodes: partition_nodes,
213 local_edges: Vec::new(),
214 communication_edges: Vec::new(),
215 estimated_memory: 0,
216 estimated_compute_time: 0.0,
217 });
218 }
219
220 self.compute_partition_edges(graph, &mut partitions, &node_to_partition)?;
222
223 let communication_schedule = self.create_communication_schedule(&partitions)?;
224 let metrics = self.compute_partition_metrics(&partitions);
225
226 Ok(PartitionedGraph {
227 partitions,
228 communication_schedule,
229 total_communication_cost: metrics.0,
230 load_balance_score: metrics.1,
231 memory_efficiency: metrics.2,
232 })
233 }
234
235 fn partition_load_balance(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
236 let node_weights = self.compute_node_weights(graph);
237 let total_weight: f64 = node_weights.values().sum();
238 let target_weight_per_device = total_weight / self.devices.len() as f64;
239
240 let mut partitions = Vec::new();
241 let mut node_to_partition = HashMap::new();
242 let mut remaining_nodes: Vec<_> = graph.nodes().map(|(idx, _)| idx).collect();
243
244 remaining_nodes.sort_by(|&a, &b| {
246 node_weights
247 .get(&b)
248 .unwrap_or(&0.0)
249 .partial_cmp(node_weights.get(&a).unwrap_or(&0.0))
250 .expect("node weights should be comparable")
251 });
252
253 for (device_idx, device) in self.devices.iter().enumerate() {
254 let mut partition_nodes = Vec::new();
255 let mut current_weight = 0.0;
256 let adjusted_target = target_weight_per_device * device.compute_capability;
257
258 let mut i = 0;
259 while i < remaining_nodes.len() && current_weight < adjusted_target {
260 let node = remaining_nodes[i];
261 let node_weight = *node_weights.get(&node).unwrap_or(&0.0);
262
263 if current_weight + node_weight <= adjusted_target * 1.2
264 || partition_nodes.is_empty()
265 {
266 partition_nodes.push(node);
267 node_to_partition.insert(node, device_idx);
268 current_weight += node_weight;
269 remaining_nodes.remove(i);
270 } else {
271 i += 1;
272 }
273 }
274
275 partitions.push(GraphPartition {
276 device: device.clone(),
277 nodes: partition_nodes,
278 local_edges: Vec::new(),
279 communication_edges: Vec::new(),
280 estimated_memory: 0,
281 estimated_compute_time: current_weight,
282 });
283 }
284
285 for node in remaining_nodes {
287 let min_partition = partitions
288 .iter()
289 .enumerate()
290 .min_by_key(|(_, p)| p.estimated_compute_time as u64)
291 .map(|(idx, _)| idx)
292 .unwrap_or(0);
293
294 partitions[min_partition].nodes.push(node);
295 node_to_partition.insert(node, min_partition);
296 }
297
298 self.compute_partition_edges(graph, &mut partitions, &node_to_partition)?;
299
300 let communication_schedule = self.create_communication_schedule(&partitions)?;
301 let metrics = self.compute_partition_metrics(&partitions);
302
303 Ok(PartitionedGraph {
304 partitions,
305 communication_schedule,
306 total_communication_cost: metrics.0,
307 load_balance_score: metrics.1,
308 memory_efficiency: metrics.2,
309 })
310 }
311
312 fn partition_memory_optimal(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
313 let node_memory = self.compute_node_memory_usage(graph);
314
315 let mut partitions = Vec::new();
316 let mut node_to_partition = HashMap::new();
317 let mut remaining_nodes: Vec<_> = graph.nodes().map(|(idx, _)| idx).collect();
318
319 for (device_idx, device) in self.devices.iter().enumerate() {
320 let mut partition_nodes = Vec::new();
321 let mut current_memory = 0;
322 let memory_limit = device.memory_capacity;
323
324 let mut i = 0;
325 while i < remaining_nodes.len() {
326 let node = remaining_nodes[i];
327 let node_mem = *node_memory.get(&node).unwrap_or(&0);
328
329 if current_memory + node_mem <= memory_limit || partition_nodes.is_empty() {
330 partition_nodes.push(node);
331 node_to_partition.insert(node, device_idx);
332 current_memory += node_mem;
333 remaining_nodes.remove(i);
334 } else {
335 i += 1;
336 }
337 }
338
339 partitions.push(GraphPartition {
340 device: device.clone(),
341 nodes: partition_nodes,
342 local_edges: Vec::new(),
343 communication_edges: Vec::new(),
344 estimated_memory: current_memory,
345 estimated_compute_time: 0.0,
346 });
347 }
348
349 self.compute_partition_edges(graph, &mut partitions, &node_to_partition)?;
350
351 let communication_schedule = self.create_communication_schedule(&partitions)?;
352 let metrics = self.compute_partition_metrics(&partitions);
353
354 Ok(PartitionedGraph {
355 partitions,
356 communication_schedule,
357 total_communication_cost: metrics.0,
358 load_balance_score: metrics.1,
359 memory_efficiency: metrics.2,
360 })
361 }
362
    /// Weighted multi-objective partitioning.
    ///
    /// NOTE(review): the `Weighted { .. }` strategy currently falls back to
    /// the plain load-balance heuristic; its communication/load/memory weights
    /// are not consulted. TODO: combine the three objectives.
    fn partition_weighted(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
        self.partition_load_balance(graph)
    }
368
369 fn compute_node_weights(&self, graph: &FxGraph) -> HashMap<NodeIndex, f64> {
370 let mut weights = HashMap::new();
371
372 for (idx, node) in graph.nodes() {
373 let weight = match node {
374 Node::Input(_) => 0.1,
375 Node::Output => 0.1,
376 Node::Call(op_name, _) => self.get_operation_weight(op_name),
377 Node::Conditional { .. } => 2.0,
378 Node::Loop { .. } => 5.0,
379 Node::Merge { .. } => 0.5,
380 Node::GetAttr { .. } => 0.1,
381 };
382 weights.insert(idx, weight);
383 }
384
385 weights
386 }
387
388 fn compute_edge_weights(&self, graph: &FxGraph) -> HashMap<(NodeIndex, NodeIndex), f64> {
389 let mut weights = HashMap::new();
390
391 for edge_ref in graph.graph.edge_references() {
392 let source = edge_ref.source();
393 let target = edge_ref.target();
394
395 let weight = 1.0; weights.insert((source, target), weight);
398 }
399
400 weights
401 }
402
403 fn compute_node_memory_usage(&self, graph: &FxGraph) -> HashMap<NodeIndex, usize> {
404 let mut memory = HashMap::new();
405
406 for (idx, node) in graph.nodes() {
407 let mem_usage = match node {
408 Node::Input(_) => 1024 * 1024, Node::Output => 0,
410 Node::Call(op_name, _) => self.get_operation_memory(op_name),
411 Node::Conditional { .. } => 512 * 1024,
412 Node::Loop { .. } => 2 * 1024 * 1024,
413 Node::Merge { .. } => 256 * 1024,
414 Node::GetAttr { .. } => 0,
415 };
416 memory.insert(idx, mem_usage);
417 }
418
419 memory
420 }
421
422 fn get_operation_weight(&self, op_name: &str) -> f64 {
423 match op_name {
424 "add" | "sub" | "mul" | "div" => 1.0,
425 "relu" | "sigmoid" | "tanh" => 1.5,
426 "conv2d" => 10.0,
427 "matmul" => 8.0,
428 "batch_norm" => 3.0,
429 "softmax" => 4.0,
430 _ => 2.0, }
432 }
433
434 fn get_operation_memory(&self, op_name: &str) -> usize {
435 match op_name {
436 "add" | "sub" | "mul" | "div" => 512 * 1024,
437 "relu" | "sigmoid" | "tanh" => 256 * 1024,
438 "conv2d" => 10 * 1024 * 1024,
439 "matmul" => 8 * 1024 * 1024,
440 "batch_norm" => 2 * 1024 * 1024,
441 "softmax" => 1 * 1024 * 1024,
442 _ => 1 * 1024 * 1024, }
444 }
445
446 fn get_neighbors(&self, graph: &FxGraph, node: NodeIndex) -> Vec<NodeIndex> {
447 let mut neighbors = Vec::new();
448
449 for edge_ref in graph
451 .graph
452 .edges_directed(node, petgraph::Direction::Incoming)
453 {
454 neighbors.push(edge_ref.source());
455 }
456
457 for edge_ref in graph
459 .graph
460 .edges_directed(node, petgraph::Direction::Outgoing)
461 {
462 neighbors.push(edge_ref.target());
463 }
464
465 neighbors
466 }
467
468 fn compute_partition_edges(
469 &self,
470 graph: &FxGraph,
471 partitions: &mut [GraphPartition],
472 node_to_partition: &HashMap<NodeIndex, usize>,
473 ) -> Result<()> {
474 for partition in partitions.iter_mut() {
476 partition.local_edges.clear();
477 partition.communication_edges.clear();
478 }
479
480 for edge_ref in graph.graph.edge_references() {
481 let source = edge_ref.source();
482 let target = edge_ref.target();
483
484 let source_partition = match node_to_partition.get(&source) {
485 Some(partition) => *partition,
486 None => continue, };
488 let target_partition = match node_to_partition.get(&target) {
489 Some(partition) => *partition,
490 None => continue, };
492
493 if source_partition == target_partition {
494 partitions[source_partition]
496 .local_edges
497 .push((source, target));
498 } else {
499 let comm_edge = CommunicationEdge {
501 source_partition,
502 target_partition,
503 source_node: source,
504 target_node: target,
505 data_size: 1024, communication_cost: self.compute_communication_cost(
507 &partitions[source_partition].device,
508 &partitions[target_partition].device,
509 1024,
510 ),
511 };
512
513 partitions[source_partition]
514 .communication_edges
515 .push(comm_edge);
516 }
517 }
518
519 Ok(())
520 }
521
522 fn compute_communication_cost(
523 &self,
524 source: &DeviceInfo,
525 target: &DeviceInfo,
526 data_size: usize,
527 ) -> f64 {
528 let bandwidth = source.bandwidth.min(target.bandwidth);
529 let latency = if source.device_type == target.device_type {
530 0.001
531 } else {
532 0.01
533 };
534
535 (data_size as f64) / bandwidth + latency
536 }
537
538 fn create_communication_schedule(
539 &self,
540 partitions: &[GraphPartition],
541 ) -> Result<CommunicationSchedule> {
542 let mut stages = Vec::new();
543 let mut processed_transfers = HashSet::new();
544 let mut stage_id = 0;
545
546 let mut remaining_edges: Vec<_> = partitions
548 .iter()
549 .enumerate()
550 .flat_map(|(partition_idx, partition)| {
551 partition
552 .communication_edges
553 .iter()
554 .map(move |edge| (partition_idx, edge))
555 })
556 .collect();
557
558 while !remaining_edges.is_empty() {
559 let mut current_stage = CommunicationStage {
560 stage_id,
561 transfers: Vec::new(),
562 dependencies: Vec::new(),
563 };
564
565 let mut i = 0;
566 while i < remaining_edges.len() {
567 let (_, edge) = &remaining_edges[i];
568 let transfer_key = (
569 edge.source_partition,
570 edge.target_partition,
571 edge.source_node,
572 edge.target_node,
573 );
574
575 if !processed_transfers.contains(&transfer_key) {
576 let transfer = DataTransfer {
577 source_device: partitions[edge.source_partition].device.id.clone(),
578 target_device: partitions[edge.target_partition].device.id.clone(),
579 data_id: format!(
580 "data_{}_{}",
581 edge.source_node.index(),
582 edge.target_node.index()
583 ),
584 data_size: edge.data_size,
585 priority: TransferPriority::Medium,
586 };
587
588 current_stage.transfers.push(transfer);
589 processed_transfers.insert(transfer_key);
590 remaining_edges.remove(i);
591 } else {
592 i += 1;
593 }
594 }
595
596 if !current_stage.transfers.is_empty() {
597 stages.push(current_stage);
598 stage_id += 1;
599 }
600 }
601
602 Ok(CommunicationSchedule {
603 total_stages: stages.len(),
604 stages,
605 })
606 }
607
608 fn compute_partition_metrics(&self, partitions: &[GraphPartition]) -> (f64, f64, f64) {
609 let total_communication_cost = partitions
610 .iter()
611 .flat_map(|p| &p.communication_edges)
612 .map(|edge| edge.communication_cost)
613 .sum();
614
615 let compute_times: Vec<f64> = partitions
616 .iter()
617 .map(|p| p.estimated_compute_time)
618 .collect();
619 let max_compute_time = compute_times.iter().cloned().fold(0.0, f64::max);
620 let avg_compute_time = compute_times.iter().sum::<f64>() / compute_times.len() as f64;
621 let load_balance_score = if max_compute_time > 0.0 {
622 avg_compute_time / max_compute_time
623 } else {
624 1.0
625 };
626
627 let memory_usage: Vec<usize> = partitions.iter().map(|p| p.estimated_memory).collect();
628 let total_memory = memory_usage.iter().sum::<usize>();
629 let total_capacity: usize = partitions.iter().map(|p| p.device.memory_capacity).sum();
630 let memory_efficiency = if total_capacity > 0 {
631 total_memory as f64 / total_capacity as f64
632 } else {
633 0.0
634 };
635
636 (
637 total_communication_cost,
638 load_balance_score,
639 memory_efficiency,
640 )
641 }
642}
643
644impl GraphPartitioner {
646 pub fn create_cpu_cluster(num_devices: usize) -> Vec<DeviceInfo> {
648 (0..num_devices)
649 .map(|i| DeviceInfo {
650 id: format!("cpu_{i}"),
651 device_type: DeviceType::CPU,
652 memory_capacity: 8 * 1024 * 1024 * 1024, compute_capability: 1.0,
654 bandwidth: 10_000_000_000.0, })
656 .collect()
657 }
658
    /// Builds a three-device example cluster: one CPU plus two CUDA GPUs of
    /// increasing capability. Values are illustrative presets, not probed
    /// hardware facts.
    pub fn create_heterogeneous_cluster() -> Vec<DeviceInfo> {
        vec![
            DeviceInfo {
                id: "cpu_0".to_string(),
                device_type: DeviceType::CPU,
                memory_capacity: 16 * 1024 * 1024 * 1024, // 16 GiB
                compute_capability: 1.0,                  // baseline
                bandwidth: 50_000_000_000.0,              // 50 GB/s
            },
            DeviceInfo {
                id: "cuda_0".to_string(),
                device_type: DeviceType::CUDA(8, 0),      // SM 8.0
                memory_capacity: 10 * 1024 * 1024 * 1024, // 10 GiB
                compute_capability: 5.0,
                bandwidth: 760_000_000_000.0, // 760 GB/s
            },
            DeviceInfo {
                id: "cuda_1".to_string(),
                device_type: DeviceType::CUDA(8, 6),      // SM 8.6
                memory_capacity: 24 * 1024 * 1024 * 1024, // 24 GiB
                compute_capability: 8.0,
                bandwidth: 1_000_000_000_000.0, // 1 TB/s
            },
        ]
    }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Edge, FxGraph, Node};

    /// Min-communication partitioning of a small (x, y) -> add -> relu ->
    /// output graph over two CPUs: expects one partition per device and
    /// well-formed metrics.
    #[test]
    fn test_graph_partitioning_min_communication() {
        let mut graph = FxGraph::new();
        // Node insertion order fixes the NodeIndex values used below.
        let input1 = graph.graph.add_node(Node::Input("x".to_string()));
        let input2 = graph.graph.add_node(Node::Input("y".to_string()));
        let add = graph.graph.add_node(Node::Call(
            "add".to_string(),
            vec!["x".to_string(), "y".to_string()],
        ));
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["add_out".to_string()]));
        let output = graph.graph.add_node(Node::Output);

        // Wire (x, y) -> add -> relu -> output.
        graph.graph.add_edge(
            input1,
            add,
            Edge {
                name: "x".to_string(),
            },
        );
        graph.graph.add_edge(
            input2,
            add,
            Edge {
                name: "y".to_string(),
            },
        );
        graph.graph.add_edge(
            add,
            relu,
            Edge {
                name: "add_out".to_string(),
            },
        );
        graph.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );

        graph.inputs = vec![input1, input2];
        graph.outputs = vec![output];

        let devices = GraphPartitioner::create_cpu_cluster(2);
        let partitioner = GraphPartitioner::new(devices, PartitioningStrategy::MinCommunication);

        let result = partitioner.partition(&graph).unwrap();

        // One partition per device; cost is non-negative and the balance
        // score is strictly positive.
        assert_eq!(result.partitions.len(), 2);
        assert!(result.total_communication_cost >= 0.0);
        assert!(result.load_balance_score > 0.0);
    }

    /// Load balancing a chain of six matmuls over the heterogeneous preset
    /// cluster: the CUDA devices (higher compute capability) must receive work.
    #[test]
    fn test_graph_partitioning_load_balance() {
        let mut graph = FxGraph::new();
        let input = graph.graph.add_node(Node::Input("x".to_string()));

        // Build a linear chain: input -> matmul x6 -> output.
        let mut prev = input;
        for i in 0..6 {
            let op = graph
                .graph
                .add_node(Node::Call("matmul".to_string(), vec![format!("input_{i}")]));
            graph.graph.add_edge(
                prev,
                op,
                Edge {
                    name: format!("edge_{i}"),
                },
            );
            prev = op;
        }

        let output = graph.graph.add_node(Node::Output);
        graph.graph.add_edge(
            prev,
            output,
            Edge {
                name: "final".to_string(),
            },
        );

        graph.inputs = vec![input];
        graph.outputs = vec![output];

        let devices = GraphPartitioner::create_heterogeneous_cluster();
        let partitioner = GraphPartitioner::new(devices, PartitioningStrategy::LoadBalance);

        let result = partitioner.partition(&graph).unwrap();

        assert_eq!(result.partitions.len(), 3);
        assert!(result.load_balance_score > 0.0);

        // At least one CUDA partition must exist in the result.
        let gpu_partitions: Vec<_> = result
            .partitions
            .iter()
            .filter(|p| matches!(p.device.device_type, DeviceType::CUDA(_, _)))
            .collect();

        assert!(!gpu_partitions.is_empty());
    }

    /// A single cross-partition edge must produce a non-empty schedule whose
    /// first stage contains that transfer.
    #[test]
    fn test_communication_schedule() {
        let devices = vec![
            DeviceInfo {
                id: "device_0".to_string(),
                device_type: DeviceType::CPU,
                memory_capacity: 1024 * 1024 * 1024,
                compute_capability: 1.0,
                bandwidth: 1_000_000_000.0,
            },
            DeviceInfo {
                id: "device_1".to_string(),
                device_type: DeviceType::CPU,
                memory_capacity: 1024 * 1024 * 1024,
                compute_capability: 1.0,
                bandwidth: 1_000_000_000.0,
            },
        ];

        // Partition 0 owns one outgoing transfer to partition 1.
        let partitions = vec![
            GraphPartition {
                device: devices[0].clone(),
                nodes: vec![],
                local_edges: vec![],
                communication_edges: vec![CommunicationEdge {
                    source_partition: 0,
                    target_partition: 1,
                    source_node: NodeIndex::new(0),
                    target_node: NodeIndex::new(1),
                    data_size: 1024,
                    communication_cost: 0.001,
                }],
                estimated_memory: 0,
                estimated_compute_time: 0.0,
            },
            GraphPartition {
                device: devices[1].clone(),
                nodes: vec![],
                local_edges: vec![],
                communication_edges: vec![],
                estimated_memory: 0,
                estimated_compute_time: 0.0,
            },
        ];

        let partitioner = GraphPartitioner::new(devices, PartitioningStrategy::MinCommunication);
        let schedule = partitioner
            .create_communication_schedule(&partitions)
            .unwrap();

        assert!(schedule.total_stages > 0);
        assert!(!schedule.stages.is_empty());
        assert!(!schedule.stages[0].transfers.is_empty());
    }
}