// File: torsh_fx/graph_partitioning.rs

1//! Graph partitioning module for distributed execution
2//!
3//! This module provides functionality to partition FX graphs across multiple devices or nodes
4//! for distributed processing.
5
6use crate::{FxGraph, Node};
7use petgraph::graph::NodeIndex;
8use petgraph::visit::EdgeRef;
9use std::collections::{HashMap, HashSet, VecDeque};
10use torsh_core::Result;
11
/// Device information for graph partitioning
///
/// Describes one execution target (capacity, relative speed, link bandwidth)
/// that the partitioner can place graph nodes on.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceInfo {
    /// Unique device identifier, e.g. "cpu_0" or "cuda_1".
    pub id: String,
    /// Hardware backend of this device.
    pub device_type: DeviceType,
    pub memory_capacity: usize,  // in bytes
    pub compute_capability: f64, // relative compute power
    pub bandwidth: f64,          // communication bandwidth, bytes per second

// NOTE(review): asserting `Eq` on a struct with `f64` fields is sound only if
// `compute_capability`/`bandwidth` are never NaN (NaN != NaN would break
// `Eq`'s reflexivity contract) — confirm that device configs guarantee this.
impl Eq for DeviceInfo {}

impl std::hash::Hash for DeviceInfo {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.id.hash(state);
        self.device_type.hash(state);
        self.memory_capacity.hash(state);
        // Hash f64 values as bits to make them hashable
        // (consistent with the bitwise equality implied by derived PartialEq).
        self.compute_capability.to_bits().hash(state);
        self.bandwidth.to_bits().hash(state);
    }
}
34
/// Device types for partitioning
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU.
    CPU,
    /// NVIDIA GPU; the two fields are the CUDA compute capability (major, minor).
    CUDA(u8, u8), // compute capability
    /// Apple Metal GPU.
    Metal,
    /// OpenCL-capable accelerator.
    OpenCL,
    /// WebGPU backend.
    WebGPU,
}
44
/// Partitioning strategy
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PartitioningStrategy {
    /// Minimize communication between partitions
    MinCommunication,
    /// Balance computational load across devices
    LoadBalance,
    /// Minimize memory usage per device
    MemoryOptimal,
    /// Custom strategy with weights
    ///
    /// NOTE(review): the weights are currently ignored by the partitioner,
    /// which falls back to load balancing — see `partition_weighted`.
    Weighted {
        /// Relative importance of low inter-device communication.
        communication_weight: f64,
        /// Relative importance of an even compute load.
        load_balance_weight: f64,
        /// Relative importance of low per-device memory use.
        memory_weight: f64,
    },
}
61
/// Graph partition representation
///
/// The set of nodes assigned to one device, plus the edges that stay inside
/// the partition and those that cross over to other partitions.
#[derive(Debug, Clone)]
pub struct GraphPartition {
    /// Device this partition is placed on.
    pub device: DeviceInfo,
    /// Graph nodes assigned to this partition.
    pub nodes: Vec<NodeIndex>,
    /// Edges with both endpoints inside this partition.
    pub local_edges: Vec<(NodeIndex, NodeIndex)>,
    /// Edges leaving this partition (recorded on the source side).
    pub communication_edges: Vec<CommunicationEdge>,
    /// Estimated memory footprint in bytes.
    pub estimated_memory: usize,
    /// Estimated compute time in the partitioner's arbitrary cost units.
    pub estimated_compute_time: f64,
}
72
/// Communication edge between partitions
///
/// One graph edge whose endpoints were assigned to different partitions and
/// therefore requires a data transfer between devices.
#[derive(Debug, Clone)]
pub struct CommunicationEdge {
    /// Index of the partition producing the data.
    pub source_partition: usize,
    /// Index of the partition consuming the data.
    pub target_partition: usize,
    /// Producing node in the original graph.
    pub source_node: NodeIndex,
    /// Consuming node in the original graph.
    pub target_node: NodeIndex,
    /// Estimated payload size in bytes.
    pub data_size: usize,
    /// Estimated transfer cost (see `compute_communication_cost`).
    pub communication_cost: f64,
}
83
/// Partitioned graph result
///
/// Output of `GraphPartitioner::partition`: the per-device partitions, the
/// transfer schedule connecting them, and summary quality metrics.
#[derive(Debug, Clone)]
pub struct PartitionedGraph {
    /// One partition per device (in device order).
    pub partitions: Vec<GraphPartition>,
    /// Ordered stages of inter-device data transfers.
    pub communication_schedule: CommunicationSchedule,
    /// Sum of all cross-partition edge costs.
    pub total_communication_cost: f64,
    /// Mean/max compute-time ratio in (0, 1]; 1.0 is perfectly balanced.
    pub load_balance_score: f64,
    /// Used memory over total device capacity, in [0, 1].
    pub memory_efficiency: f64,
}
93
/// Communication schedule for coordinating data transfer
#[derive(Debug, Clone)]
pub struct CommunicationSchedule {
    /// Stages in execution order; transfers within a stage may run in parallel.
    pub stages: Vec<CommunicationStage>,
    /// Cached `stages.len()`.
    pub total_stages: usize,
}
100
/// Communication stage with parallel transfers
#[derive(Debug, Clone)]
pub struct CommunicationStage {
    /// Position of this stage within the schedule.
    pub stage_id: usize,
    /// Transfers that can execute concurrently in this stage.
    pub transfers: Vec<DataTransfer>,
    pub dependencies: Vec<usize>, // prerequisite stages
}
108
/// Data transfer between devices
#[derive(Debug, Clone)]
pub struct DataTransfer {
    /// `DeviceInfo::id` of the sending device.
    pub source_device: String,
    /// `DeviceInfo::id` of the receiving device.
    pub target_device: String,
    /// Identifier of the payload, derived from the endpoint node indices.
    pub data_id: String,
    /// Payload size in bytes.
    pub data_size: usize,
    /// Scheduling priority of this transfer.
    pub priority: TransferPriority,
}
118
/// Transfer priority for scheduling
///
/// Ordered so that higher variants compare greater (`Low < … < Critical`),
/// allowing priority queues to sort transfers directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TransferPriority {
    Low = 0,
    Medium = 1,
    High = 2,
    Critical = 3,
}
127
/// Graph partitioner implementation
///
/// Splits an `FxGraph` across a set of devices according to a
/// `PartitioningStrategy`.
pub struct GraphPartitioner {
    // Available execution targets, in placement order.
    devices: Vec<DeviceInfo>,
    // Strategy used by `partition`.
    strategy: PartitioningStrategy,
    // Optional cap on partition count.
    // NOTE(review): currently stored but not consulted by any strategy —
    // confirm whether enforcement is still planned.
    max_partitions: Option<usize>,
}
134
135impl GraphPartitioner {
136    /// Create a new graph partitioner
137    pub fn new(devices: Vec<DeviceInfo>, strategy: PartitioningStrategy) -> Self {
138        Self {
139            devices,
140            strategy,
141            max_partitions: None,
142        }
143    }
144
145    /// Set maximum number of partitions
146    pub fn with_max_partitions(mut self, max_partitions: usize) -> Self {
147        self.max_partitions = Some(max_partitions);
148        self
149    }
150
    /// Partition the graph according to the strategy
    ///
    /// Dispatches to the strategy-specific routine and returns the partitions
    /// together with a communication schedule and quality metrics.
    pub fn partition(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
        match self.strategy {
            PartitioningStrategy::MinCommunication => self.partition_min_communication(graph),
            PartitioningStrategy::LoadBalance => self.partition_load_balance(graph),
            PartitioningStrategy::MemoryOptimal => self.partition_memory_optimal(graph),
            // Weighted currently delegates to load balancing; the weights are
            // ignored (see `partition_weighted`).
            PartitioningStrategy::Weighted { .. } => self.partition_weighted(graph),
        }
    }
160
161    fn partition_min_communication(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
162        let mut partitions = Vec::new();
163        let mut node_to_partition = HashMap::new();
164
165        // Start with a simple graph cut algorithm
166        let _node_weights = self.compute_node_weights(graph);
167        let _edge_weights = self.compute_edge_weights(graph);
168
169        // Use a greedy approach to minimize cut edges
170        let mut remaining_nodes: HashSet<NodeIndex> = graph.nodes().map(|(idx, _)| idx).collect();
171
172        for (device_idx, device) in self.devices.iter().enumerate() {
173            if remaining_nodes.is_empty() {
174                break;
175            }
176
177            let mut partition_nodes = Vec::new();
178            let target_size = remaining_nodes.len() / (self.devices.len() - device_idx);
179
180            // Start with a random node or input node
181            let start_node = if let Some(&node) = remaining_nodes.iter().next() {
182                node
183            } else {
184                break;
185            };
186
187            let mut to_visit = VecDeque::new();
188            to_visit.push_back(start_node);
189            remaining_nodes.remove(&start_node);
190
191            // BFS expansion while minimizing communication
192            while let Some(current_node) = to_visit.pop_front() {
193                partition_nodes.push(current_node);
194                node_to_partition.insert(current_node, device_idx);
195
196                if partition_nodes.len() >= target_size {
197                    break;
198                }
199
200                // Add neighbors that minimize communication cost
201                let neighbors = self.get_neighbors(graph, current_node);
202                for neighbor in neighbors {
203                    if remaining_nodes.contains(&neighbor) {
204                        to_visit.push_back(neighbor);
205                        remaining_nodes.remove(&neighbor);
206                    }
207                }
208            }
209
210            partitions.push(GraphPartition {
211                device: device.clone(),
212                nodes: partition_nodes,
213                local_edges: Vec::new(),
214                communication_edges: Vec::new(),
215                estimated_memory: 0,
216                estimated_compute_time: 0.0,
217            });
218        }
219
220        // Compute edges and communication costs
221        self.compute_partition_edges(graph, &mut partitions, &node_to_partition)?;
222
223        let communication_schedule = self.create_communication_schedule(&partitions)?;
224        let metrics = self.compute_partition_metrics(&partitions);
225
226        Ok(PartitionedGraph {
227            partitions,
228            communication_schedule,
229            total_communication_cost: metrics.0,
230            load_balance_score: metrics.1,
231            memory_efficiency: metrics.2,
232        })
233    }
234
235    fn partition_load_balance(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
236        let node_weights = self.compute_node_weights(graph);
237        let total_weight: f64 = node_weights.values().sum();
238        let target_weight_per_device = total_weight / self.devices.len() as f64;
239
240        let mut partitions = Vec::new();
241        let mut node_to_partition = HashMap::new();
242        let mut remaining_nodes: Vec<_> = graph.nodes().map(|(idx, _)| idx).collect();
243
244        // Sort nodes by weight (descending) for better load balancing
245        remaining_nodes.sort_by(|&a, &b| {
246            node_weights
247                .get(&b)
248                .unwrap_or(&0.0)
249                .partial_cmp(node_weights.get(&a).unwrap_or(&0.0))
250                .expect("node weights should be comparable")
251        });
252
253        for (device_idx, device) in self.devices.iter().enumerate() {
254            let mut partition_nodes = Vec::new();
255            let mut current_weight = 0.0;
256            let adjusted_target = target_weight_per_device * device.compute_capability;
257
258            let mut i = 0;
259            while i < remaining_nodes.len() && current_weight < adjusted_target {
260                let node = remaining_nodes[i];
261                let node_weight = *node_weights.get(&node).unwrap_or(&0.0);
262
263                if current_weight + node_weight <= adjusted_target * 1.2
264                    || partition_nodes.is_empty()
265                {
266                    partition_nodes.push(node);
267                    node_to_partition.insert(node, device_idx);
268                    current_weight += node_weight;
269                    remaining_nodes.remove(i);
270                } else {
271                    i += 1;
272                }
273            }
274
275            partitions.push(GraphPartition {
276                device: device.clone(),
277                nodes: partition_nodes,
278                local_edges: Vec::new(),
279                communication_edges: Vec::new(),
280                estimated_memory: 0,
281                estimated_compute_time: current_weight,
282            });
283        }
284
285        // Handle remaining nodes
286        for node in remaining_nodes {
287            let min_partition = partitions
288                .iter()
289                .enumerate()
290                .min_by_key(|(_, p)| p.estimated_compute_time as u64)
291                .map(|(idx, _)| idx)
292                .unwrap_or(0);
293
294            partitions[min_partition].nodes.push(node);
295            node_to_partition.insert(node, min_partition);
296        }
297
298        self.compute_partition_edges(graph, &mut partitions, &node_to_partition)?;
299
300        let communication_schedule = self.create_communication_schedule(&partitions)?;
301        let metrics = self.compute_partition_metrics(&partitions);
302
303        Ok(PartitionedGraph {
304            partitions,
305            communication_schedule,
306            total_communication_cost: metrics.0,
307            load_balance_score: metrics.1,
308            memory_efficiency: metrics.2,
309        })
310    }
311
312    fn partition_memory_optimal(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
313        let node_memory = self.compute_node_memory_usage(graph);
314
315        let mut partitions = Vec::new();
316        let mut node_to_partition = HashMap::new();
317        let mut remaining_nodes: Vec<_> = graph.nodes().map(|(idx, _)| idx).collect();
318
319        for (device_idx, device) in self.devices.iter().enumerate() {
320            let mut partition_nodes = Vec::new();
321            let mut current_memory = 0;
322            let memory_limit = device.memory_capacity;
323
324            let mut i = 0;
325            while i < remaining_nodes.len() {
326                let node = remaining_nodes[i];
327                let node_mem = *node_memory.get(&node).unwrap_or(&0);
328
329                if current_memory + node_mem <= memory_limit || partition_nodes.is_empty() {
330                    partition_nodes.push(node);
331                    node_to_partition.insert(node, device_idx);
332                    current_memory += node_mem;
333                    remaining_nodes.remove(i);
334                } else {
335                    i += 1;
336                }
337            }
338
339            partitions.push(GraphPartition {
340                device: device.clone(),
341                nodes: partition_nodes,
342                local_edges: Vec::new(),
343                communication_edges: Vec::new(),
344                estimated_memory: current_memory,
345                estimated_compute_time: 0.0,
346            });
347        }
348
349        self.compute_partition_edges(graph, &mut partitions, &node_to_partition)?;
350
351        let communication_schedule = self.create_communication_schedule(&partitions)?;
352        let metrics = self.compute_partition_metrics(&partitions);
353
354        Ok(PartitionedGraph {
355            partitions,
356            communication_schedule,
357            total_communication_cost: metrics.0,
358            load_balance_score: metrics.1,
359            memory_efficiency: metrics.2,
360        })
361    }
362
    /// Weighted-strategy partitioning.
    ///
    /// TODO: the `communication_weight` / `load_balance_weight` /
    /// `memory_weight` fields of [`PartitioningStrategy::Weighted`] are
    /// currently ignored; this is a placeholder that simply delegates to the
    /// load-balancing strategy.
    fn partition_weighted(&self, graph: &FxGraph) -> Result<PartitionedGraph> {
        self.partition_load_balance(graph)
    }
368
369    fn compute_node_weights(&self, graph: &FxGraph) -> HashMap<NodeIndex, f64> {
370        let mut weights = HashMap::new();
371
372        for (idx, node) in graph.nodes() {
373            let weight = match node {
374                Node::Input(_) => 0.1,
375                Node::Output => 0.1,
376                Node::Call(op_name, _) => self.get_operation_weight(op_name),
377                Node::Conditional { .. } => 2.0,
378                Node::Loop { .. } => 5.0,
379                Node::Merge { .. } => 0.5,
380                Node::GetAttr { .. } => 0.1,
381            };
382            weights.insert(idx, weight);
383        }
384
385        weights
386    }
387
388    fn compute_edge_weights(&self, graph: &FxGraph) -> HashMap<(NodeIndex, NodeIndex), f64> {
389        let mut weights = HashMap::new();
390
391        for edge_ref in graph.graph.edge_references() {
392            let source = edge_ref.source();
393            let target = edge_ref.target();
394
395            // Estimate data size and communication cost
396            let weight = 1.0; // Default weight
397            weights.insert((source, target), weight);
398        }
399
400        weights
401    }
402
403    fn compute_node_memory_usage(&self, graph: &FxGraph) -> HashMap<NodeIndex, usize> {
404        let mut memory = HashMap::new();
405
406        for (idx, node) in graph.nodes() {
407            let mem_usage = match node {
408                Node::Input(_) => 1024 * 1024, // 1MB default
409                Node::Output => 0,
410                Node::Call(op_name, _) => self.get_operation_memory(op_name),
411                Node::Conditional { .. } => 512 * 1024,
412                Node::Loop { .. } => 2 * 1024 * 1024,
413                Node::Merge { .. } => 256 * 1024,
414                Node::GetAttr { .. } => 0,
415            };
416            memory.insert(idx, mem_usage);
417        }
418
419        memory
420    }
421
422    fn get_operation_weight(&self, op_name: &str) -> f64 {
423        match op_name {
424            "add" | "sub" | "mul" | "div" => 1.0,
425            "relu" | "sigmoid" | "tanh" => 1.5,
426            "conv2d" => 10.0,
427            "matmul" => 8.0,
428            "batch_norm" => 3.0,
429            "softmax" => 4.0,
430            _ => 2.0, // Default weight
431        }
432    }
433
434    fn get_operation_memory(&self, op_name: &str) -> usize {
435        match op_name {
436            "add" | "sub" | "mul" | "div" => 512 * 1024,
437            "relu" | "sigmoid" | "tanh" => 256 * 1024,
438            "conv2d" => 10 * 1024 * 1024,
439            "matmul" => 8 * 1024 * 1024,
440            "batch_norm" => 2 * 1024 * 1024,
441            "softmax" => 1 * 1024 * 1024,
442            _ => 1 * 1024 * 1024, // 1MB default
443        }
444    }
445
446    fn get_neighbors(&self, graph: &FxGraph, node: NodeIndex) -> Vec<NodeIndex> {
447        let mut neighbors = Vec::new();
448
449        // Get incoming edges
450        for edge_ref in graph
451            .graph
452            .edges_directed(node, petgraph::Direction::Incoming)
453        {
454            neighbors.push(edge_ref.source());
455        }
456
457        // Get outgoing edges
458        for edge_ref in graph
459            .graph
460            .edges_directed(node, petgraph::Direction::Outgoing)
461        {
462            neighbors.push(edge_ref.target());
463        }
464
465        neighbors
466    }
467
    /// Classify every graph edge as either local to one partition or as a
    /// cross-partition communication edge, filling each partition's
    /// `local_edges` and `communication_edges`.
    ///
    /// Communication edges are recorded on the *source* partition. Edges with
    /// an endpoint missing from `node_to_partition` are skipped.
    fn compute_partition_edges(
        &self,
        graph: &FxGraph,
        partitions: &mut [GraphPartition],
        node_to_partition: &HashMap<NodeIndex, usize>,
    ) -> Result<()> {
        // Clear existing edges so repeated calls stay idempotent.
        for partition in partitions.iter_mut() {
            partition.local_edges.clear();
            partition.communication_edges.clear();
        }

        for edge_ref in graph.graph.edge_references() {
            let source = edge_ref.source();
            let target = edge_ref.target();

            let source_partition = match node_to_partition.get(&source) {
                Some(partition) => *partition,
                None => continue, // Skip edges with unmapped nodes
            };
            let target_partition = match node_to_partition.get(&target) {
                Some(partition) => *partition,
                None => continue, // Skip edges with unmapped nodes
            };

            if source_partition == target_partition {
                // Local edge within partition
                partitions[source_partition]
                    .local_edges
                    .push((source, target));
            } else {
                // Communication edge between partitions
                let comm_edge = CommunicationEdge {
                    source_partition,
                    target_partition,
                    source_node: source,
                    target_node: target,
                    data_size: 1024, // Fixed estimate; no real size model yet
                    communication_cost: self.compute_communication_cost(
                        &partitions[source_partition].device,
                        &partitions[target_partition].device,
                        1024,
                    ),
                };

                partitions[source_partition]
                    .communication_edges
                    .push(comm_edge);
            }
        }

        Ok(())
    }
521
522    fn compute_communication_cost(
523        &self,
524        source: &DeviceInfo,
525        target: &DeviceInfo,
526        data_size: usize,
527    ) -> f64 {
528        let bandwidth = source.bandwidth.min(target.bandwidth);
529        let latency = if source.device_type == target.device_type {
530            0.001
531        } else {
532            0.01
533        };
534
535        (data_size as f64) / bandwidth + latency
536    }
537
538    fn create_communication_schedule(
539        &self,
540        partitions: &[GraphPartition],
541    ) -> Result<CommunicationSchedule> {
542        let mut stages = Vec::new();
543        let mut processed_transfers = HashSet::new();
544        let mut stage_id = 0;
545
546        // Group communication edges by dependencies
547        let mut remaining_edges: Vec<_> = partitions
548            .iter()
549            .enumerate()
550            .flat_map(|(partition_idx, partition)| {
551                partition
552                    .communication_edges
553                    .iter()
554                    .map(move |edge| (partition_idx, edge))
555            })
556            .collect();
557
558        while !remaining_edges.is_empty() {
559            let mut current_stage = CommunicationStage {
560                stage_id,
561                transfers: Vec::new(),
562                dependencies: Vec::new(),
563            };
564
565            let mut i = 0;
566            while i < remaining_edges.len() {
567                let (_, edge) = &remaining_edges[i];
568                let transfer_key = (
569                    edge.source_partition,
570                    edge.target_partition,
571                    edge.source_node,
572                    edge.target_node,
573                );
574
575                if !processed_transfers.contains(&transfer_key) {
576                    let transfer = DataTransfer {
577                        source_device: partitions[edge.source_partition].device.id.clone(),
578                        target_device: partitions[edge.target_partition].device.id.clone(),
579                        data_id: format!(
580                            "data_{}_{}",
581                            edge.source_node.index(),
582                            edge.target_node.index()
583                        ),
584                        data_size: edge.data_size,
585                        priority: TransferPriority::Medium,
586                    };
587
588                    current_stage.transfers.push(transfer);
589                    processed_transfers.insert(transfer_key);
590                    remaining_edges.remove(i);
591                } else {
592                    i += 1;
593                }
594            }
595
596            if !current_stage.transfers.is_empty() {
597                stages.push(current_stage);
598                stage_id += 1;
599            }
600        }
601
602        Ok(CommunicationSchedule {
603            total_stages: stages.len(),
604            stages,
605        })
606    }
607
608    fn compute_partition_metrics(&self, partitions: &[GraphPartition]) -> (f64, f64, f64) {
609        let total_communication_cost = partitions
610            .iter()
611            .flat_map(|p| &p.communication_edges)
612            .map(|edge| edge.communication_cost)
613            .sum();
614
615        let compute_times: Vec<f64> = partitions
616            .iter()
617            .map(|p| p.estimated_compute_time)
618            .collect();
619        let max_compute_time = compute_times.iter().cloned().fold(0.0, f64::max);
620        let avg_compute_time = compute_times.iter().sum::<f64>() / compute_times.len() as f64;
621        let load_balance_score = if max_compute_time > 0.0 {
622            avg_compute_time / max_compute_time
623        } else {
624            1.0
625        };
626
627        let memory_usage: Vec<usize> = partitions.iter().map(|p| p.estimated_memory).collect();
628        let total_memory = memory_usage.iter().sum::<usize>();
629        let total_capacity: usize = partitions.iter().map(|p| p.device.memory_capacity).sum();
630        let memory_efficiency = if total_capacity > 0 {
631            total_memory as f64 / total_capacity as f64
632        } else {
633            0.0
634        };
635
636        (
637            total_communication_cost,
638            load_balance_score,
639            memory_efficiency,
640        )
641    }
642}
643
644/// Utility functions for graph partitioning
645impl GraphPartitioner {
646    /// Create a default CPU cluster configuration
647    pub fn create_cpu_cluster(num_devices: usize) -> Vec<DeviceInfo> {
648        (0..num_devices)
649            .map(|i| DeviceInfo {
650                id: format!("cpu_{i}"),
651                device_type: DeviceType::CPU,
652                memory_capacity: 8 * 1024 * 1024 * 1024, // 8GB
653                compute_capability: 1.0,
654                bandwidth: 10_000_000_000.0, // 10 GB/s
655            })
656            .collect()
657    }
658
659    /// Create a heterogeneous cluster with CPU and GPU devices
660    pub fn create_heterogeneous_cluster() -> Vec<DeviceInfo> {
661        vec![
662            DeviceInfo {
663                id: "cpu_0".to_string(),
664                device_type: DeviceType::CPU,
665                memory_capacity: 16 * 1024 * 1024 * 1024, // 16GB
666                compute_capability: 1.0,
667                bandwidth: 50_000_000_000.0, // 50 GB/s
668            },
669            DeviceInfo {
670                id: "cuda_0".to_string(),
671                device_type: DeviceType::CUDA(8, 0), // RTX 3080 class
672                memory_capacity: 10 * 1024 * 1024 * 1024, // 10GB
673                compute_capability: 5.0,
674                bandwidth: 760_000_000_000.0, // 760 GB/s
675            },
676            DeviceInfo {
677                id: "cuda_1".to_string(),
678                device_type: DeviceType::CUDA(8, 6), // RTX 4090 class
679                memory_capacity: 24 * 1024 * 1024 * 1024, // 24GB
680                compute_capability: 8.0,
681                bandwidth: 1_000_000_000_000.0, // 1 TB/s
682            },
683        ]
684    }
685}
686
687#[cfg(test)]
688mod tests {
689    use super::*;
690    use crate::{Edge, FxGraph, Node};
691
    // Partition a tiny (x, y) -> add -> relu -> output graph across two equal
    // CPUs and sanity-check the result shape and metric ranges.
    #[test]
    fn test_graph_partitioning_min_communication() {
        let mut graph = FxGraph::new();
        let input1 = graph.graph.add_node(Node::Input("x".to_string()));
        let input2 = graph.graph.add_node(Node::Input("y".to_string()));
        let add = graph.graph.add_node(Node::Call(
            "add".to_string(),
            vec!["x".to_string(), "y".to_string()],
        ));
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["add_out".to_string()]));
        let output = graph.graph.add_node(Node::Output);

        // Wire up the dataflow edges.
        graph.graph.add_edge(
            input1,
            add,
            Edge {
                name: "x".to_string(),
            },
        );
        graph.graph.add_edge(
            input2,
            add,
            Edge {
                name: "y".to_string(),
            },
        );
        graph.graph.add_edge(
            add,
            relu,
            Edge {
                name: "add_out".to_string(),
            },
        );
        graph.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );

        graph.inputs = vec![input1, input2];
        graph.outputs = vec![output];

        let devices = GraphPartitioner::create_cpu_cluster(2);
        let partitioner = GraphPartitioner::new(devices, PartitioningStrategy::MinCommunication);

        let result = partitioner.partition(&graph).unwrap();

        // One partition per device; metrics stay in their documented ranges.
        assert_eq!(result.partitions.len(), 2);
        assert!(result.total_communication_cost >= 0.0);
        assert!(result.load_balance_score > 0.0);
    }
747
    // Load-balance a linear chain of expensive matmuls across a heterogeneous
    // CPU+GPU cluster and check that GPU partitions exist.
    #[test]
    fn test_graph_partitioning_load_balance() {
        let mut graph = FxGraph::new();
        let input = graph.graph.add_node(Node::Input("x".to_string()));

        // Create a linear chain of expensive operations
        let mut prev = input;
        for i in 0..6 {
            let op = graph
                .graph
                .add_node(Node::Call("matmul".to_string(), vec![format!("input_{i}")]));
            graph.graph.add_edge(
                prev,
                op,
                Edge {
                    name: format!("edge_{i}"),
                },
            );
            prev = op;
        }

        let output = graph.graph.add_node(Node::Output);
        graph.graph.add_edge(
            prev,
            output,
            Edge {
                name: "final".to_string(),
            },
        );

        graph.inputs = vec![input];
        graph.outputs = vec![output];

        let devices = GraphPartitioner::create_heterogeneous_cluster();
        let partitioner = GraphPartitioner::new(devices, PartitioningStrategy::LoadBalance);

        let result = partitioner.partition(&graph).unwrap();

        // One partition per device in the 3-device fixture.
        assert_eq!(result.partitions.len(), 3);
        assert!(result.load_balance_score > 0.0);

        // Check that high-compute devices get more work
        let gpu_partitions: Vec<_> = result
            .partitions
            .iter()
            .filter(|p| matches!(p.device.device_type, DeviceType::CUDA(_, _)))
            .collect();

        assert!(!gpu_partitions.is_empty());
    }
798
    // Feed the scheduler two hand-built partitions with one cross-partition
    // edge and verify a non-empty schedule comes back.
    #[test]
    fn test_communication_schedule() {
        let devices = vec![
            DeviceInfo {
                id: "device_0".to_string(),
                device_type: DeviceType::CPU,
                memory_capacity: 1024 * 1024 * 1024,
                compute_capability: 1.0,
                bandwidth: 1_000_000_000.0,
            },
            DeviceInfo {
                id: "device_1".to_string(),
                device_type: DeviceType::CPU,
                memory_capacity: 1024 * 1024 * 1024,
                compute_capability: 1.0,
                bandwidth: 1_000_000_000.0,
            },
        ];

        // Partition 0 sends one transfer to partition 1; partition 1 is idle.
        let partitions = vec![
            GraphPartition {
                device: devices[0].clone(),
                nodes: vec![],
                local_edges: vec![],
                communication_edges: vec![CommunicationEdge {
                    source_partition: 0,
                    target_partition: 1,
                    source_node: NodeIndex::new(0),
                    target_node: NodeIndex::new(1),
                    data_size: 1024,
                    communication_cost: 0.001,
                }],
                estimated_memory: 0,
                estimated_compute_time: 0.0,
            },
            GraphPartition {
                device: devices[1].clone(),
                nodes: vec![],
                local_edges: vec![],
                communication_edges: vec![],
                estimated_memory: 0,
                estimated_compute_time: 0.0,
            },
        ];

        let partitioner = GraphPartitioner::new(devices, PartitioningStrategy::MinCommunication);
        let schedule = partitioner
            .create_communication_schedule(&partitions)
            .unwrap();

        assert!(schedule.total_stages > 0);
        assert!(!schedule.stages.is_empty());
        assert!(!schedule.stages[0].transfers.is_empty());
    }
853}