1use std::collections::{HashMap, HashSet};
8
9use super::{EinsumGraph, EinsumNode, OpType};
10use crate::error::IrError;
11
/// Counters describing what a fusion pass found in a graph.
///
/// A freshly created value means "nothing fused": both counters are zero and
/// the speedup factor is the neutral `1.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct FusionStats {
    /// Number of operations that were counted as part of a fusion group.
    pub ops_fused: usize,
    /// Number of fusion groups (chains or pairs) that were found.
    pub fusion_groups: usize,
    /// Heuristic speedup multiplier; `1.0` means no expected change.
    pub estimated_speedup: f64,
}

impl FusionStats {
    /// Creates statistics with no fused operations and a speedup of `1.0`.
    ///
    /// Equivalent to [`FusionStats::default`].
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for FusionStats {
    /// Returns the neutral statistics: zero counters, speedup `1.0`.
    fn default() -> Self {
        Self {
            ops_fused: 0,
            fusion_groups: 0,
            // Speedups are combined by multiplication, so the identity is 1.0.
            estimated_speedup: 1.0,
        }
    }
}
39
40pub fn fuse_elementwise_operations(graph: &mut EinsumGraph) -> Result<FusionStats, IrError> {
62 let mut stats = FusionStats::new();
63
64 let mut tensor_users: HashMap<usize, Vec<usize>> = HashMap::new();
66 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
67
68 for (node_idx, node) in graph.nodes.iter().enumerate() {
69 for &output_idx in &node.outputs {
70 tensor_producer.insert(output_idx, node_idx);
71 }
72 for &input_idx in &node.inputs {
73 tensor_users.entry(input_idx).or_default().push(node_idx);
74 }
75 }
76
77 let mut fusible_chains = find_fusible_chains(graph, &tensor_users, &tensor_producer);
79
80 for chain in fusible_chains.drain(..) {
82 if chain.len() > 1 {
83 stats.ops_fused += chain.len();
84 stats.fusion_groups += 1;
85 stats.estimated_speedup *= 1.0 + (chain.len() as f64 * 0.1);
87 }
88 }
89
90 Ok(stats)
91}
92
93fn find_fusible_chains(
95 graph: &EinsumGraph,
96 tensor_users: &HashMap<usize, Vec<usize>>,
97 tensor_producer: &HashMap<usize, usize>,
98) -> Vec<Vec<usize>> {
99 let mut chains = Vec::new();
100 let mut visited = HashSet::new();
101
102 for (node_idx, node) in graph.nodes.iter().enumerate() {
103 if visited.contains(&node_idx) {
104 continue;
105 }
106
107 if is_fusible_operation(&node.op) {
108 let mut chain = vec![node_idx];
109 visited.insert(node_idx);
110
111 extend_chain_forward(
113 graph,
114 node_idx,
115 &mut chain,
116 &mut visited,
117 tensor_users,
118 tensor_producer,
119 );
120
121 if chain.len() > 1 {
122 chains.push(chain);
123 }
124 }
125 }
126
127 chains
128}
129
130fn is_fusible_operation(op_type: &OpType) -> bool {
132 matches!(
133 op_type,
134 OpType::ElemUnary { .. } | OpType::ElemBinary { .. }
135 )
136}
137
138fn extend_chain_forward(
140 graph: &EinsumGraph,
141 current_node: usize,
142 chain: &mut Vec<usize>,
143 visited: &mut HashSet<usize>,
144 tensor_users: &HashMap<usize, Vec<usize>>,
145 _tensor_producer: &HashMap<usize, usize>,
146) {
147 let node = &graph.nodes[current_node];
148
149 for &output_idx in &node.outputs {
151 if let Some(users) = tensor_users.get(&output_idx) {
153 if users.len() == 1 {
154 let next_node_idx = users[0];
155 if visited.contains(&next_node_idx) {
156 continue;
157 }
158
159 let next_node = &graph.nodes[next_node_idx];
160 if is_fusible_operation(&next_node.op) && can_fuse_nodes(node, next_node) {
161 visited.insert(next_node_idx);
162 chain.push(next_node_idx);
163 extend_chain_forward(
165 graph,
166 next_node_idx,
167 chain,
168 visited,
169 tensor_users,
170 _tensor_producer,
171 );
172 }
173 }
174 }
175 }
176}
177
178fn can_fuse_nodes(node1: &EinsumNode, node2: &EinsumNode) -> bool {
180 if !is_fusible_operation(&node1.op) || !is_fusible_operation(&node2.op) {
182 return false;
183 }
184
185 matches!(
188 (&node1.op, &node2.op),
189 (OpType::ElemUnary { .. }, OpType::ElemUnary { .. })
190 | (OpType::ElemUnary { .. }, OpType::ElemBinary { .. })
191 | (OpType::ElemBinary { .. }, OpType::ElemUnary { .. })
192 )
193}
194
195pub fn fuse_map_reduce(graph: &mut EinsumGraph) -> Result<FusionStats, IrError> {
211 let mut stats = FusionStats::new();
212
213 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
215 for (node_idx, node) in graph.nodes.iter().enumerate() {
216 for &output_idx in &node.outputs {
217 tensor_producer.insert(output_idx, node_idx);
218 }
219 }
220
221 let mut fuse_pairs = Vec::new();
223
224 for (reduce_idx, reduce_node) in graph.nodes.iter().enumerate() {
225 if matches!(reduce_node.op, OpType::Reduce { .. }) {
226 if let Some(&input_idx) = reduce_node.inputs.first() {
228 if let Some(&map_idx) = tensor_producer.get(&input_idx) {
229 let map_node = &graph.nodes[map_idx];
230 if is_fusible_operation(&map_node.op) {
231 fuse_pairs.push((map_idx, reduce_idx));
232 }
233 }
234 }
235 }
236 }
237
238 stats.ops_fused = fuse_pairs.len() * 2; stats.fusion_groups = fuse_pairs.len();
240 stats.estimated_speedup = 1.0 + (fuse_pairs.len() as f64 * 0.2);
241
242 Ok(stats)
243}
244
245pub fn fuse_einsum_operations(graph: &mut EinsumGraph) -> Result<FusionStats, IrError> {
261 let mut stats = FusionStats::new();
262
263 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
265 let mut tensor_users: HashMap<usize, Vec<usize>> = HashMap::new();
266
267 for (node_idx, node) in graph.nodes.iter().enumerate() {
268 for &output_idx in &node.outputs {
269 tensor_producer.insert(output_idx, node_idx);
270 }
271 for &input_idx in &node.inputs {
272 tensor_users.entry(input_idx).or_default().push(node_idx);
273 }
274 }
275
276 let mut fuse_pairs = Vec::new();
278
279 for (node2_idx, node2) in graph.nodes.iter().enumerate() {
280 if let OpType::Einsum { spec: spec2 } = &node2.op {
281 for &input_idx in &node2.inputs {
283 if let Some(&node1_idx) = tensor_producer.get(&input_idx) {
284 let node1 = &graph.nodes[node1_idx];
285 if let OpType::Einsum { spec: spec1 } = &node1.op {
286 if can_fuse_einsums(spec1, spec2, &tensor_users, input_idx) {
288 fuse_pairs.push((node1_idx, node2_idx));
289 }
290 }
291 }
292 }
293 }
294 }
295
296 stats.ops_fused = fuse_pairs.len() * 2;
297 stats.fusion_groups = fuse_pairs.len();
298 stats.estimated_speedup = 1.0 + (fuse_pairs.len() as f64 * 0.3);
299
300 Ok(stats)
301}
302
/// Decides whether two einsums connected through `intermediate_tensor` may
/// be fused.
///
/// The only condition checked is the consumer count of the intermediate
/// tensor: fusion is allowed when `tensor_users` lists exactly one user for
/// it, or has no entry for it at all. The two spec strings are currently
/// ignored.
fn can_fuse_einsums(
    _spec1: &str,
    _spec2: &str,
    tensor_users: &HashMap<usize, Vec<usize>>,
    intermediate_tensor: usize,
) -> bool {
    match tensor_users.get(&intermediate_tensor) {
        Some(users) => users.len() == 1,
        // No recorded consumers: nothing prevents fusion.
        None => true,
    }
}
321
322pub fn fuse_all(graph: &mut EinsumGraph) -> Result<FusionStats, IrError> {
327 let mut total_stats = FusionStats::new();
328
329 let elem_stats = fuse_elementwise_operations(graph)?;
331 total_stats.ops_fused += elem_stats.ops_fused;
332 total_stats.fusion_groups += elem_stats.fusion_groups;
333 total_stats.estimated_speedup *= elem_stats.estimated_speedup;
334
335 let map_reduce_stats = fuse_map_reduce(graph)?;
337 total_stats.ops_fused += map_reduce_stats.ops_fused;
338 total_stats.fusion_groups += map_reduce_stats.fusion_groups;
339 total_stats.estimated_speedup *= map_reduce_stats.estimated_speedup;
340
341 let einsum_stats = fuse_einsum_operations(graph)?;
343 total_stats.ops_fused += einsum_stats.ops_fused;
344 total_stats.fusion_groups += einsum_stats.fusion_groups;
345 total_stats.estimated_speedup *= einsum_stats.estimated_speedup;
346
347 Ok(total_stats)
348}
349
#[cfg(test)]
mod tests {
    //! Unit tests for the fusion statistics passes and their helpers.

    use super::*;

    /// Default statistics are neutral: zero counters and a speedup of 1.0.
    #[test]
    fn test_fusion_stats_default() {
        let stats = FusionStats::default();
        assert_eq!(stats.ops_fused, 0);
        assert_eq!(stats.fusion_groups, 0);
        assert_eq!(stats.estimated_speedup, 1.0);
    }

    /// `new` and `default` must describe the same neutral value.
    #[test]
    fn test_fusion_stats_new_matches_default() {
        assert_eq!(FusionStats::new(), FusionStats::default());
    }

    /// Only element-wise operations are fusible; einsums are not.
    #[test]
    fn test_is_fusible_operation() {
        assert!(is_fusible_operation(&OpType::ElemUnary {
            op: "relu".to_string()
        }));
        assert!(is_fusible_operation(&OpType::ElemBinary {
            op: "add".to_string()
        }));
        assert!(!is_fusible_operation(&OpType::Einsum {
            spec: "ij,jk->ik".to_string()
        }));
    }

    #[test]
    fn test_can_fuse_unary_nodes() {
        let node1 = EinsumNode::elem_unary("relu", 0, 1);
        let node2 = EinsumNode::elem_unary("tanh", 1, 2);
        assert!(can_fuse_nodes(&node1, &node2));
    }

    #[test]
    fn test_can_fuse_unary_binary_nodes() {
        let node1 = EinsumNode::elem_unary("relu", 0, 1);
        let node2 = EinsumNode::elem_binary("add", 1, 2, 3);
        assert!(can_fuse_nodes(&node1, &node2));
    }

    /// A binary producer feeding a unary consumer is accepted.
    #[test]
    fn test_can_fuse_binary_unary_nodes() {
        let node1 = EinsumNode::elem_binary("add", 0, 1, 2);
        let node2 = EinsumNode::elem_unary("relu", 2, 3);
        assert!(can_fuse_nodes(&node1, &node2));
    }

    /// Binary→binary is the one element-wise combination that is rejected.
    #[test]
    fn test_cannot_fuse_binary_binary_nodes() {
        let node1 = EinsumNode::elem_binary("add", 0, 1, 2);
        let node2 = EinsumNode::elem_binary("add", 2, 3, 4);
        assert!(!can_fuse_nodes(&node1, &node2));
    }

    /// Einsum nodes are never accepted by the element-wise pair check.
    #[test]
    fn test_cannot_fuse_einsum_nodes() {
        let node1 = EinsumNode::einsum("ij,jk->ik", vec![0, 1], vec![2]);
        let node2 = EinsumNode::einsum("ik,kl->il", vec![2, 3], vec![4]);
        assert!(!can_fuse_nodes(&node1, &node2));
    }

    #[test]
    fn test_fuse_elementwise_empty_graph() {
        let mut graph = EinsumGraph::new();
        let stats = fuse_elementwise_operations(&mut graph).unwrap();
        assert_eq!(stats.ops_fused, 0);
        assert_eq!(stats.fusion_groups, 0);
    }

    /// A lone operation forms a chain of length one, which is not reported.
    #[test]
    fn test_fuse_elementwise_single_op() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let stats = fuse_elementwise_operations(&mut graph).unwrap();
        assert_eq!(stats.ops_fused, 0);
    }

    /// Two unary operations linked by a single-consumer tensor form one chain
    /// of length two.
    #[test]
    fn test_fuse_elementwise_unary_chain() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("tanh", b, c))
            .unwrap();

        let stats = fuse_elementwise_operations(&mut graph).unwrap();
        assert_eq!(stats.ops_fused, 2);
        assert_eq!(stats.fusion_groups, 1);
        // One chain of length 2 contributes a factor of 1.0 + 2 * 0.1.
        assert!((stats.estimated_speedup - 1.2).abs() < 1e-12);
    }

    #[test]
    fn test_fuse_map_reduce_empty_graph() {
        let mut graph = EinsumGraph::new();
        let stats = fuse_map_reduce(&mut graph).unwrap();
        assert_eq!(stats.ops_fused, 0);
    }

    #[test]
    fn test_fuse_einsum_empty_graph() {
        let mut graph = EinsumGraph::new();
        let stats = fuse_einsum_operations(&mut graph).unwrap();
        assert_eq!(stats.ops_fused, 0);
    }

    #[test]
    fn test_fuse_all_empty_graph() {
        let mut graph = EinsumGraph::new();
        let stats = fuse_all(&mut graph).unwrap();
        assert_eq!(stats.ops_fused, 0);
        assert_eq!(stats.fusion_groups, 0);
    }

    #[test]
    fn test_find_fusible_chains_empty() {
        let graph = EinsumGraph::new();
        let tensor_users = HashMap::new();
        let tensor_producer = HashMap::new();
        let chains = find_fusible_chains(&graph, &tensor_users, &tensor_producer);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_can_fuse_einsums_single_user() {
        let tensor_users = HashMap::from([(1, vec![2])]);
        assert!(can_fuse_einsums("ij,jk->ik", "ik,kl->il", &tensor_users, 1));
    }

    #[test]
    fn test_cannot_fuse_einsums_multiple_users() {
        let tensor_users = HashMap::from([(1, vec![2, 3])]);
        assert!(!can_fuse_einsums(
            "ij,jk->ik",
            "ik,kl->il",
            &tensor_users,
            1
        ));
    }

    /// A tensor with no entry in the consumer map does not block fusion.
    #[test]
    fn test_can_fuse_einsums_unknown_tensor() {
        let tensor_users: HashMap<usize, Vec<usize>> = HashMap::new();
        assert!(can_fuse_einsums("ij,jk->ik", "ik,kl->il", &tensor_users, 1));
    }
}