1use std::collections::{HashMap, HashSet, VecDeque};
8
9use super::EinsumGraph;
10use crate::error::IrError;
11
/// A set of graph nodes that share the same dependency level.
///
/// `analyze_parallelization` emits one group per non-empty level.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelGroup {
    /// Indices of the member nodes (positions in the graph's node list).
    pub nodes: Vec<usize>,
    /// Estimated cost of the group; `analyze_parallelization` fills this with
    /// the result of `estimate_group_cost` for the member nodes.
    pub estimated_cost: f64,
    /// Dependency level shared by every node in the group.
    pub level: usize,
}
22
/// Result of analysing a graph for parallel execution opportunities.
#[derive(Debug, Clone)]
pub struct ParallelizationAnalysis {
    /// Groups of same-level nodes, in ascending level order.
    pub parallel_groups: Vec<ParallelGroup>,
    /// Number of nodes in the largest group (`0` when there are no groups).
    pub max_parallelism: usize,
    /// Total number of grouped nodes divided by the number of groups
    /// (`0.0` when there are no groups).
    pub avg_parallelism: f64,
    /// Number of nodes on `critical_path`.
    pub critical_path_length: usize,
    /// Node indices along the critical path, earliest node first.
    pub critical_path: Vec<usize>,
    /// Sum of per-node cost estimates divided by the sum of per-group cost
    /// estimates; `1.0` when the group sum is not positive.
    pub estimated_speedup: f64,
}
39
40impl ParallelizationAnalysis {
41 pub fn new() -> Self {
43 Self {
44 parallel_groups: Vec::new(),
45 max_parallelism: 0,
46 avg_parallelism: 0.0,
47 critical_path_length: 0,
48 critical_path: Vec::new(),
49 estimated_speedup: 1.0,
50 }
51 }
52
53 pub fn has_parallelism(&self) -> bool {
55 self.max_parallelism > 1
56 }
57
58 pub fn total_nodes(&self) -> usize {
60 self.parallel_groups.iter().map(|g| g.nodes.len()).sum()
61 }
62}
63
/// The default analysis is the empty one returned by `ParallelizationAnalysis::new`.
impl Default for ParallelizationAnalysis {
    fn default() -> Self {
        // Delegate so `new()` and `default()` can never drift apart.
        Self::new()
    }
}
69
70pub fn analyze_parallelization(graph: &EinsumGraph) -> Result<ParallelizationAnalysis, IrError> {
98 if graph.nodes.is_empty() {
99 return Ok(ParallelizationAnalysis::new());
100 }
101
102 let (dependencies, dependents) = build_dependency_graph(graph);
104
105 let node_levels = compute_node_levels(graph, &dependencies);
107
108 let mut level_groups: HashMap<usize, Vec<usize>> = HashMap::new();
110 for (node_idx, &level) in node_levels.iter().enumerate() {
111 level_groups.entry(level).or_default().push(node_idx);
112 }
113
114 let mut parallel_groups = Vec::new();
116 let max_level = node_levels.iter().max().copied().unwrap_or(0);
117
118 for level in 0..=max_level {
119 if let Some(nodes) = level_groups.get(&level) {
120 let estimated_cost = estimate_group_cost(graph, nodes);
121 parallel_groups.push(ParallelGroup {
122 nodes: nodes.clone(),
123 estimated_cost,
124 level,
125 });
126 }
127 }
128
129 let max_parallelism = parallel_groups
131 .iter()
132 .map(|g| g.nodes.len())
133 .max()
134 .unwrap_or(0);
135
136 let total_nodes: usize = parallel_groups.iter().map(|g| g.nodes.len()).sum();
137 let avg_parallelism = if !parallel_groups.is_empty() {
138 total_nodes as f64 / parallel_groups.len() as f64
139 } else {
140 0.0
141 };
142
143 let (critical_path, critical_path_length) =
145 find_critical_path(graph, &node_levels, &dependents);
146
147 let sequential_cost: f64 = (0..graph.nodes.len())
149 .map(|i| estimate_node_cost(graph, i))
150 .sum();
151 let parallel_cost: f64 = parallel_groups.iter().map(|g| g.estimated_cost).sum();
152 let estimated_speedup = if parallel_cost > 0.0 {
153 sequential_cost / parallel_cost
154 } else {
155 1.0
156 };
157
158 Ok(ParallelizationAnalysis {
159 parallel_groups,
160 max_parallelism,
161 avg_parallelism,
162 critical_path_length,
163 critical_path,
164 estimated_speedup,
165 })
166}
167
168fn build_dependency_graph(
170 graph: &EinsumGraph,
171) -> (HashMap<usize, Vec<usize>>, HashMap<usize, Vec<usize>>) {
172 let mut dependencies: HashMap<usize, Vec<usize>> = HashMap::new();
173 let mut dependents: HashMap<usize, Vec<usize>> = HashMap::new();
174
175 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
177 for (node_idx, node) in graph.nodes.iter().enumerate() {
178 for &output_idx in &node.outputs {
179 tensor_producer.insert(output_idx, node_idx);
180 }
181 }
182
183 for (node_idx, node) in graph.nodes.iter().enumerate() {
185 let mut node_deps = Vec::new();
186 for &input_idx in &node.inputs {
187 if let Some(&producer_idx) = tensor_producer.get(&input_idx) {
188 if producer_idx != node_idx {
189 node_deps.push(producer_idx);
190 dependents.entry(producer_idx).or_default().push(node_idx);
191 }
192 }
193 }
194 dependencies.insert(node_idx, node_deps);
195 }
196
197 (dependencies, dependents)
198}
199
200fn compute_node_levels(
202 graph: &EinsumGraph,
203 dependencies: &HashMap<usize, Vec<usize>>,
204) -> Vec<usize> {
205 let mut levels = vec![0; graph.nodes.len()];
206 let mut in_degree = vec![0; graph.nodes.len()];
207
208 for (node_idx, deps) in dependencies.iter() {
210 in_degree[*node_idx] = deps.len();
211 }
212
213 let mut queue: VecDeque<usize> = VecDeque::new();
215 for (node_idx, °ree) in in_degree.iter().enumerate() {
216 if degree == 0 && node_idx < graph.nodes.len() {
217 queue.push_back(node_idx);
218 levels[node_idx] = 0;
219 }
220 }
221
222 let mut dependents: HashMap<usize, Vec<usize>> = HashMap::new();
224 for (node_idx, deps) in dependencies.iter() {
225 for &dep in deps {
226 dependents.entry(dep).or_default().push(*node_idx);
227 }
228 }
229
230 let mut visited = HashSet::new();
232 while let Some(node_idx) = queue.pop_front() {
233 if visited.contains(&node_idx) {
234 continue;
235 }
236 visited.insert(node_idx);
237
238 let current_level = levels[node_idx];
239
240 if let Some(deps) = dependents.get(&node_idx) {
242 for &dep_idx in deps {
243 if dep_idx < graph.nodes.len() {
244 levels[dep_idx] = levels[dep_idx].max(current_level + 1);
245 queue.push_back(dep_idx);
246 }
247 }
248 }
249 }
250
251 levels
252}
253
254fn estimate_group_cost(graph: &EinsumGraph, nodes: &[usize]) -> f64 {
256 nodes
257 .iter()
258 .map(|&idx| estimate_node_cost(graph, idx))
259 .max_by(|a, b| a.partial_cmp(b).unwrap())
260 .unwrap_or(0.0)
261}
262
/// Estimates the execution cost of a single node.
///
/// Placeholder model: both arguments are ignored and every node is assigned
/// the same unit cost.
fn estimate_node_cost(_graph: &EinsumGraph, _node_idx: usize) -> f64 {
    // Uniform cost; with this model a group's cost is 1.0 regardless of size.
    1.0
}
269
270fn find_critical_path(
272 graph: &EinsumGraph,
273 node_levels: &[usize],
274 _dependents: &HashMap<usize, Vec<usize>>,
275) -> (Vec<usize>, usize) {
276 let max_level = node_levels.iter().max().copied().unwrap_or(0);
277
278 let end_nodes: Vec<usize> = node_levels
280 .iter()
281 .enumerate()
282 .filter(|(_, &level)| level == max_level)
283 .map(|(idx, _)| idx)
284 .collect();
285
286 if end_nodes.is_empty() {
287 return (Vec::new(), 0);
288 }
289
290 let mut path = Vec::new();
292 let mut current = end_nodes[0];
293 path.push(current);
294
295 while node_levels[current] > 0 {
296 let predecessors = get_predecessors(graph, current);
298 if let Some(&pred) = predecessors
299 .iter()
300 .max_by_key(|&&idx| node_levels.get(idx).copied().unwrap_or(0))
301 {
302 path.push(pred);
303 current = pred;
304 } else {
305 break;
306 }
307 }
308
309 path.reverse();
310 let length = path.len();
311 (path, length)
312}
313
314fn get_predecessors(graph: &EinsumGraph, node_idx: usize) -> Vec<usize> {
316 let mut predecessors = Vec::new();
317
318 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
320 for (idx, node) in graph.nodes.iter().enumerate() {
321 for &output in &node.outputs {
322 tensor_producer.insert(output, idx);
323 }
324 }
325
326 if let Some(node) = graph.nodes.get(node_idx) {
328 for &input in &node.inputs {
329 if let Some(&producer) = tensor_producer.get(&input) {
330 predecessors.push(producer);
331 }
332 }
333 }
334
335 predecessors
336}
337
/// Splits the graph into groups of nodes that share no dependency edges with
/// any other group.
///
/// Two nodes land in the same group when they are connected through any chain
/// of producer/consumer edges, followed in either direction. Each returned
/// vector holds the node indices of one group in the order the traversal
/// discovered them; groups are ordered by their lowest node index. An empty
/// graph yields an empty vector.
///
/// # Errors
///
/// The body shown here only ever returns `Ok`; the `Result` return type is
/// part of the public signature.
pub fn partition_independent_subgraphs(graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
    if graph.nodes.is_empty() {
        return Ok(Vec::new());
    }

    let (dependencies, dependents) = build_dependency_graph(graph);
    let mut visited = HashSet::new();
    let mut subgraphs = Vec::new();

    for node_idx in 0..graph.nodes.len() {
        // Already swept up by an earlier component.
        if visited.contains(&node_idx) {
            continue;
        }

        let mut subgraph = Vec::new();
        let mut stack = vec![node_idx];

        // Depth-first flood fill over the undirected view of the edges.
        while let Some(current) = stack.pop() {
            // `insert` returns `false` when the node was visited before.
            if !visited.insert(current) {
                continue;
            }
            subgraph.push(current);

            if let Some(deps) = dependencies.get(&current) {
                stack.extend(deps.iter().copied());
            }
            if let Some(deps) = dependents.get(&current) {
                stack.extend(deps.iter().copied());
            }
        }

        subgraphs.push(subgraph);
    }

    Ok(subgraphs)
}
381
// Unit tests for the parallelization analysis and the subgraph partitioning.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    // A default-constructed analysis reports no parallelism at all.
    #[test]
    fn test_parallelization_analysis_default() {
        let analysis = ParallelizationAnalysis::default();
        assert_eq!(analysis.max_parallelism, 0);
        assert!(!analysis.has_parallelism());
    }

    // An empty graph short-circuits to the empty analysis.
    #[test]
    fn test_analyze_empty_graph() {
        let graph = EinsumGraph::new();
        let analysis = analyze_parallelization(&graph).unwrap();
        assert_eq!(analysis.max_parallelism, 0);
        assert_eq!(analysis.total_nodes(), 0);
    }

    // One node forms a single group of size one.
    #[test]
    fn test_analyze_single_node() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let analysis = analyze_parallelization(&graph).unwrap();
        assert_eq!(analysis.max_parallelism, 1);
        assert_eq!(analysis.total_nodes(), 1);
    }

    // Two nodes that share no tensors land on the same level.
    #[test]
    fn test_analyze_parallel_nodes() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        // A -> B and C -> D are unrelated.
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", c, d))
            .unwrap();

        let analysis = analyze_parallelization(&graph).unwrap();
        assert_eq!(analysis.max_parallelism, 2);
        assert!(analysis.has_parallelism());
    }

    // A two-node chain has a critical path covering both nodes.
    #[test]
    fn test_analyze_sequential_nodes() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        // A -> B -> C: the second node consumes the first node's output.
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let analysis = analyze_parallelization(&graph).unwrap();
        assert_eq!(analysis.critical_path_length, 2);
    }

    // Partitioning an empty graph yields no subgraphs.
    #[test]
    fn test_partition_empty_graph() {
        let graph = EinsumGraph::new();
        let subgraphs = partition_independent_subgraphs(&graph).unwrap();
        assert!(subgraphs.is_empty());
    }

    // A single node is its own subgraph.
    #[test]
    fn test_partition_single_node() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let subgraphs = partition_independent_subgraphs(&graph).unwrap();
        assert_eq!(subgraphs.len(), 1);
        assert_eq!(subgraphs[0].len(), 1);
    }

    // Nodes that share no tensors end up in separate subgraphs.
    #[test]
    fn test_partition_independent_nodes() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        // A -> B and C -> D are unrelated.
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", c, d))
            .unwrap();

        let subgraphs = partition_independent_subgraphs(&graph).unwrap();
        assert_eq!(subgraphs.len(), 2);
    }

    // The node cost estimate ignores its arguments, so an empty graph and any
    // index are fine here.
    #[test]
    fn test_estimate_node_cost() {
        let graph = EinsumGraph::new();
        let cost = estimate_node_cost(&graph, 0);
        assert_eq!(cost, 1.0);
    }

    // The group cost is the maximum member cost: max(1.0, 1.0, 1.0).
    #[test]
    fn test_estimate_group_cost() {
        let graph = EinsumGraph::new();
        let cost = estimate_group_cost(&graph, &[0, 1, 2]);
        assert_eq!(cost, 1.0);
    }

    // `ParallelGroup` is plain data: fields read back exactly as written.
    #[test]
    fn test_parallel_group_creation() {
        let group = ParallelGroup {
            nodes: vec![0, 1, 2],
            estimated_cost: 3.5,
            level: 1,
        };
        assert_eq!(group.nodes.len(), 3);
        assert_eq!(group.estimated_cost, 3.5);
        assert_eq!(group.level, 1);
    }
}