Skip to main content

tensorlogic_ir/graph/
parallel.rs

1//! Parallelization analysis for identifying independent subgraphs.
2//!
3//! This module provides utilities for analyzing computational graphs to identify
4//! opportunities for parallel execution. It finds groups of operations that have
5//! no dependencies on each other and can safely execute concurrently.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8
9use super::EinsumGraph;
10use crate::error::IrError;
11
/// A set of graph nodes that share one scheduling level and may therefore
/// run concurrently.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelGroup {
    /// Positions (in the graph's node list) of the nodes in this group.
    pub nodes: Vec<usize>,
    /// Estimated cost of executing the group.
    pub estimated_cost: f64,
    /// Scheduling level the group belongs to; level 0 holds the nodes that
    /// have no dependencies. Mainly useful for visualization.
    pub level: usize,
}
22
23/// Analysis result containing parallel execution opportunities
24#[derive(Debug, Clone)]
25pub struct ParallelizationAnalysis {
26    /// Groups of nodes that can execute in parallel at each level
27    pub parallel_groups: Vec<ParallelGroup>,
28    /// Maximum parallelism (largest group size)
29    pub max_parallelism: usize,
30    /// Average parallelism across all levels
31    pub avg_parallelism: f64,
32    /// Critical path length (longest dependency chain)
33    pub critical_path_length: usize,
34    /// Nodes on the critical path
35    pub critical_path: Vec<usize>,
36    /// Estimated parallel speedup (compared to sequential execution)
37    pub estimated_speedup: f64,
38}
39
40impl ParallelizationAnalysis {
41    /// Create a new empty analysis
42    pub fn new() -> Self {
43        Self {
44            parallel_groups: Vec::new(),
45            max_parallelism: 0,
46            avg_parallelism: 0.0,
47            critical_path_length: 0,
48            critical_path: Vec::new(),
49            estimated_speedup: 1.0,
50        }
51    }
52
53    /// Check if the graph has any parallelism opportunities
54    pub fn has_parallelism(&self) -> bool {
55        self.max_parallelism > 1
56    }
57
58    /// Get total number of nodes across all parallel groups
59    pub fn total_nodes(&self) -> usize {
60        self.parallel_groups.iter().map(|g| g.nodes.len()).sum()
61    }
62}
63
64impl Default for ParallelizationAnalysis {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70/// Analyze graph for parallel execution opportunities
71///
72/// This function performs a topological analysis of the computation graph
73/// to identify sets of operations that can execute in parallel. It uses
74/// level-based scheduling to group independent operations.
75///
76/// # Returns
77///
78/// Returns a `ParallelizationAnalysis` containing:
79/// - Groups of parallelizable operations at each level
80/// - Critical path analysis
81/// - Estimated speedup from parallelization
82///
83/// # Example
84///
85/// ```rust
86/// use tensorlogic_ir::{EinsumGraph, analyze_parallelization};
87///
88/// let mut graph = EinsumGraph::new();
89/// // Build your graph...
90///
91/// let analysis = analyze_parallelization(&graph).unwrap();
92/// if analysis.has_parallelism() {
93///     println!("Max parallelism: {}", analysis.max_parallelism);
94///     println!("Estimated speedup: {:.2}x", analysis.estimated_speedup);
95/// }
96/// ```
97pub fn analyze_parallelization(graph: &EinsumGraph) -> Result<ParallelizationAnalysis, IrError> {
98    if graph.nodes.is_empty() {
99        return Ok(ParallelizationAnalysis::new());
100    }
101
102    // Build dependency information
103    let (dependencies, dependents) = build_dependency_graph(graph);
104
105    // Compute node levels using topological sort
106    let node_levels = compute_node_levels(graph, &dependencies);
107
108    // Group nodes by level
109    let mut level_groups: HashMap<usize, Vec<usize>> = HashMap::new();
110    for (node_idx, &level) in node_levels.iter().enumerate() {
111        level_groups.entry(level).or_default().push(node_idx);
112    }
113
114    // Create parallel groups
115    let mut parallel_groups = Vec::new();
116    let max_level = node_levels.iter().max().copied().unwrap_or(0);
117
118    for level in 0..=max_level {
119        if let Some(nodes) = level_groups.get(&level) {
120            let estimated_cost = estimate_group_cost(graph, nodes);
121            parallel_groups.push(ParallelGroup {
122                nodes: nodes.clone(),
123                estimated_cost,
124                level,
125            });
126        }
127    }
128
129    // Compute statistics
130    let max_parallelism = parallel_groups
131        .iter()
132        .map(|g| g.nodes.len())
133        .max()
134        .unwrap_or(0);
135
136    let total_nodes: usize = parallel_groups.iter().map(|g| g.nodes.len()).sum();
137    let avg_parallelism = if !parallel_groups.is_empty() {
138        total_nodes as f64 / parallel_groups.len() as f64
139    } else {
140        0.0
141    };
142
143    // Find critical path
144    let (critical_path, critical_path_length) =
145        find_critical_path(graph, &node_levels, &dependents);
146
147    // Estimate speedup (simplified model)
148    let sequential_cost: f64 = (0..graph.nodes.len())
149        .map(|i| estimate_node_cost(graph, i))
150        .sum();
151    let parallel_cost: f64 = parallel_groups.iter().map(|g| g.estimated_cost).sum();
152    let estimated_speedup = if parallel_cost > 0.0 {
153        sequential_cost / parallel_cost
154    } else {
155        1.0
156    };
157
158    Ok(ParallelizationAnalysis {
159        parallel_groups,
160        max_parallelism,
161        avg_parallelism,
162        critical_path_length,
163        critical_path,
164        estimated_speedup,
165    })
166}
167
168/// Build dependency graph (forward and backward)
169fn build_dependency_graph(
170    graph: &EinsumGraph,
171) -> (HashMap<usize, Vec<usize>>, HashMap<usize, Vec<usize>>) {
172    let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
173    let mut dependents: HashMap<usize, Vec<usize>> = HashMap::new();
174
175    // Build tensor-to-producer mapping
176    let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
177    for (node_idx, node) in graph.nodes.iter().enumerate() {
178        for &output_idx in &node.outputs {
179            tensor_producer.insert(output_idx, node_idx);
180        }
181    }
182
183    // Build dependency relationships
184    for (node_idx, node) in graph.nodes.iter().enumerate() {
185        let mut node_deps = Vec::new();
186        for &input_idx in &node.inputs {
187            if let Some(&producer_idx) = tensor_producer.get(&input_idx) {
188                if producer_idx != node_idx {
189                    node_deps.push(producer_idx);
190                    dependents.entry(producer_idx).or_default().push(node_idx);
191                }
192            }
193        }
194        dependencies.insert(node_idx, node_deps);
195    }
196
197    (dependencies, dependents)
198}
199
200/// Compute the execution level for each node using topological sort
201fn compute_node_levels(
202    graph: &EinsumGraph,
203    dependencies: &HashMap<usize, Vec<usize>>,
204) -> Vec<usize> {
205    let mut levels = vec![0; graph.nodes.len()];
206    let mut in_degree = vec![0; graph.nodes.len()];
207
208    // Calculate in-degrees (count how many nodes each node depends on)
209    for (node_idx, deps) in dependencies.iter() {
210        in_degree[*node_idx] = deps.len();
211    }
212
213    // Find all nodes with no dependencies (level 0)
214    let mut queue: VecDeque<usize> = VecDeque::new();
215    for (node_idx, &degree) in in_degree.iter().enumerate() {
216        if degree == 0 && node_idx < graph.nodes.len() {
217            queue.push_back(node_idx);
218            levels[node_idx] = 0;
219        }
220    }
221
222    // Build reverse dependency map (who depends on me?)
223    let mut dependents: HashMap<usize, Vec<usize>> = HashMap::new();
224    for (node_idx, deps) in dependencies.iter() {
225        for &dep in deps {
226            dependents.entry(dep).or_default().push(*node_idx);
227        }
228    }
229
230    // BFS to assign levels
231    let mut visited = HashSet::new();
232    while let Some(node_idx) = queue.pop_front() {
233        if visited.contains(&node_idx) {
234            continue;
235        }
236        visited.insert(node_idx);
237
238        let current_level = levels[node_idx];
239
240        // Update nodes that depend on this one
241        if let Some(deps) = dependents.get(&node_idx) {
242            for &dep_idx in deps {
243                if dep_idx < graph.nodes.len() {
244                    levels[dep_idx] = levels[dep_idx].max(current_level + 1);
245                    queue.push_back(dep_idx);
246                }
247            }
248        }
249    }
250
251    levels
252}
253
254/// Estimate computational cost for a group of nodes
255fn estimate_group_cost(graph: &EinsumGraph, nodes: &[usize]) -> f64 {
256    nodes
257        .iter()
258        .map(|&idx| estimate_node_cost(graph, idx))
259        .max_by(|a, b| a.partial_cmp(b).unwrap())
260        .unwrap_or(0.0)
261}
262
263/// Estimate computational cost for a single node (simplified)
264fn estimate_node_cost(_graph: &EinsumGraph, _node_idx: usize) -> f64 {
265    // Simplified cost model - in practice, this would analyze the operation type
266    // and tensor sizes to estimate FLOPs and memory traffic
267    1.0
268}
269
270/// Find the critical path in the computation graph
271fn find_critical_path(
272    graph: &EinsumGraph,
273    node_levels: &[usize],
274    _dependents: &HashMap<usize, Vec<usize>>,
275) -> (Vec<usize>, usize) {
276    let max_level = node_levels.iter().max().copied().unwrap_or(0);
277
278    // Find nodes at maximum level
279    let end_nodes: Vec<usize> = node_levels
280        .iter()
281        .enumerate()
282        .filter(|(_, &level)| level == max_level)
283        .map(|(idx, _)| idx)
284        .collect();
285
286    if end_nodes.is_empty() {
287        return (Vec::new(), 0);
288    }
289
290    // Backtrack from end node to find critical path
291    let mut path = Vec::new();
292    let mut current = end_nodes[0];
293    path.push(current);
294
295    while node_levels[current] > 0 {
296        // Find predecessor with highest level
297        let predecessors = get_predecessors(graph, current);
298        if let Some(&pred) = predecessors
299            .iter()
300            .max_by_key(|&&idx| node_levels.get(idx).copied().unwrap_or(0))
301        {
302            path.push(pred);
303            current = pred;
304        } else {
305            break;
306        }
307    }
308
309    path.reverse();
310    let length = path.len();
311    (path, length)
312}
313
314/// Get predecessor nodes for a given node
315fn get_predecessors(graph: &EinsumGraph, node_idx: usize) -> Vec<usize> {
316    let mut predecessors = Vec::new();
317
318    // Build tensor-to-producer mapping
319    let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
320    for (idx, node) in graph.nodes.iter().enumerate() {
321        for &output in &node.outputs {
322            tensor_producer.insert(output, idx);
323        }
324    }
325
326    // Find nodes that produce inputs for this node
327    if let Some(node) = graph.nodes.get(node_idx) {
328        for &input in &node.inputs {
329            if let Some(&producer) = tensor_producer.get(&input) {
330                predecessors.push(producer);
331            }
332        }
333    }
334
335    predecessors
336}
337
338/// Partition graph into independent subgraphs for parallel execution
339///
340/// This function divides the computation graph into the largest possible
341/// independent subgraphs that can execute in parallel without any
342/// data dependencies between them.
343pub fn partition_independent_subgraphs(graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
344    if graph.nodes.is_empty() {
345        return Ok(Vec::new());
346    }
347
348    let (dependencies, dependents) = build_dependency_graph(graph);
349    let mut visited = HashSet::new();
350    let mut subgraphs = Vec::new();
351
352    for node_idx in 0..graph.nodes.len() {
353        if visited.contains(&node_idx) {
354            continue;
355        }
356
357        let mut subgraph = Vec::new();
358        let mut stack = vec![node_idx];
359
360        while let Some(current) = stack.pop() {
361            if visited.contains(&current) {
362                continue;
363            }
364            visited.insert(current);
365            subgraph.push(current);
366
367            // Add dependencies and dependents
368            if let Some(deps) = dependencies.get(&current) {
369                stack.extend(deps.iter().copied());
370            }
371            if let Some(deps) = dependents.get(&current) {
372                stack.extend(deps.iter().copied());
373            }
374        }
375
376        subgraphs.push(subgraph);
377    }
378
379    Ok(subgraphs)
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    // ---- value types -------------------------------------------------------

    /// A default analysis reports no parallelism at all.
    #[test]
    fn test_parallelization_analysis_default() {
        let analysis = ParallelizationAnalysis::default();
        assert!(!analysis.has_parallelism());
        assert_eq!(analysis.max_parallelism, 0);
    }

    /// Fields of a hand-built group are stored as given.
    #[test]
    fn test_parallel_group_creation() {
        let group = ParallelGroup {
            nodes: vec![0, 1, 2],
            estimated_cost: 3.5,
            level: 1,
        };
        assert_eq!(group.level, 1);
        assert_eq!(group.estimated_cost, 3.5);
        assert_eq!(group.nodes.len(), 3);
    }

    // ---- analyze_parallelization -------------------------------------------

    /// An empty graph yields an empty analysis.
    #[test]
    fn test_analyze_empty_graph() {
        let analysis = analyze_parallelization(&EinsumGraph::new()).unwrap();
        assert_eq!(analysis.total_nodes(), 0);
        assert_eq!(analysis.max_parallelism, 0);
    }

    /// A single operation forms one group of size one.
    #[test]
    fn test_analyze_single_node() {
        let mut graph = EinsumGraph::new();
        let [a, b] = ["A", "B"].map(|name| graph.add_tensor(name));
        let relu = EinsumNode::elem_unary("relu", a, b);
        graph.add_node(relu).unwrap();

        let analysis = analyze_parallelization(&graph).unwrap();
        assert_eq!(analysis.total_nodes(), 1);
        assert_eq!(analysis.max_parallelism, 1);
    }

    /// Two operations on disjoint tensors can run side by side.
    #[test]
    fn test_analyze_parallel_nodes() {
        let mut graph = EinsumGraph::new();
        let [a, b, c, d] = ["A", "B", "C", "D"].map(|name| graph.add_tensor(name));

        let relu = EinsumNode::elem_unary("relu", a, b);
        graph.add_node(relu).unwrap();
        let tanh = EinsumNode::elem_unary("tanh", c, d);
        graph.add_node(tanh).unwrap();

        let analysis = analyze_parallelization(&graph).unwrap();
        assert!(analysis.has_parallelism());
        assert_eq!(analysis.max_parallelism, 2);
    }

    /// Two chained operations give a critical path of two nodes.
    #[test]
    fn test_analyze_sequential_nodes() {
        let mut graph = EinsumGraph::new();
        let [a, b, c] = ["A", "B", "C"].map(|name| graph.add_tensor(name));

        // The second op consumes the output of the first.
        let relu = EinsumNode::elem_unary("relu", a, b);
        graph.add_node(relu).unwrap();
        let tanh = EinsumNode::elem_unary("tanh", b, c);
        graph.add_node(tanh).unwrap();

        let analysis = analyze_parallelization(&graph).unwrap();
        assert_eq!(analysis.critical_path_length, 2);
    }

    // ---- partition_independent_subgraphs -----------------------------------

    /// Nothing to partition in an empty graph.
    #[test]
    fn test_partition_empty_graph() {
        let subgraphs = partition_independent_subgraphs(&EinsumGraph::new()).unwrap();
        assert!(subgraphs.is_empty());
    }

    /// A single operation is its own subgraph.
    #[test]
    fn test_partition_single_node() {
        let mut graph = EinsumGraph::new();
        let [a, b] = ["A", "B"].map(|name| graph.add_tensor(name));
        let relu = EinsumNode::elem_unary("relu", a, b);
        graph.add_node(relu).unwrap();

        let subgraphs = partition_independent_subgraphs(&graph).unwrap();
        assert_eq!(subgraphs.len(), 1);
        assert_eq!(subgraphs[0].len(), 1);
    }

    /// Operations without shared tensors land in separate subgraphs.
    #[test]
    fn test_partition_independent_nodes() {
        let mut graph = EinsumGraph::new();
        let [a, b, c, d] = ["A", "B", "C", "D"].map(|name| graph.add_tensor(name));

        let relu = EinsumNode::elem_unary("relu", a, b);
        graph.add_node(relu).unwrap();
        let tanh = EinsumNode::elem_unary("tanh", c, d);
        graph.add_node(tanh).unwrap();

        let subgraphs = partition_independent_subgraphs(&graph).unwrap();
        assert_eq!(subgraphs.len(), 2);
    }

    // ---- cost model --------------------------------------------------------

    /// The simplified model charges every node a unit cost.
    #[test]
    fn test_estimate_node_cost() {
        let graph = EinsumGraph::new();
        assert_eq!(estimate_node_cost(&graph, 0), 1.0);
    }

    /// A group costs as much as its most expensive member.
    #[test]
    fn test_estimate_group_cost() {
        let graph = EinsumGraph::new();
        assert_eq!(estimate_group_cost(&graph, &[0, 1, 2]), 1.0);
    }
}