Skip to main content

torsh_jit/
optimizer.rs

1//! Graph optimization passes for JIT compilation
2
3use crate::graph::{ComputationGraph, Conv2dInfo, Edge, Node, NodeId, Operation};
4use crate::JitResult;
5use petgraph::visit::EdgeRef;
6use std::collections::{HashMap, HashSet};
7use torsh_core::Shape;
8
/// Graph optimizer that applies various optimization passes
///
/// Passes are stored boxed behind the [`OptimizationPass`] trait and are
/// run in order by `optimize`.
pub struct GraphOptimizer {
    /// Pipeline of passes, applied front to back.
    passes: Vec<Box<dyn OptimizationPass>>,
}
13
14impl GraphOptimizer {
15    /// Create a new graph optimizer with default passes
16    pub fn new() -> Self {
17        Self {
18            passes: vec![
19                Box::new(DeadCodeElimination),
20                Box::new(ConstantFolding),
21                Box::new(CommonSubexpressionElimination),
22                Box::new(AlgebraicSimplification),
23                Box::new(StrengthReduction),
24                Box::new(LayoutOptimization),
25                Box::new(CacheAwareOptimization::default()),
26                Box::new(AutoVectorization),
27                Box::new(AutoParallelization),
28            ],
29        }
30    }
31
32    /// Create optimizer with custom passes
33    pub fn with_passes(passes: Vec<Box<dyn OptimizationPass>>) -> Self {
34        Self { passes }
35    }
36
37    /// Apply all optimization passes to the graph
38    pub fn optimize(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
39        for pass in &self.passes {
40            graph = pass.apply(graph)?;
41        }
42        Ok(graph)
43    }
44}
45
impl Default for GraphOptimizer {
    /// Same as [`GraphOptimizer::new`]: the default pass pipeline.
    fn default() -> Self {
        Self::new()
    }
}
51
/// Trait for optimization passes
///
/// A pass consumes the graph by value and returns the (possibly rewritten)
/// graph. `Send + Sync` is required so a `GraphOptimizer` holding boxed
/// passes can be shared across threads.
pub trait OptimizationPass: Send + Sync {
    /// Name of the optimization pass
    fn name(&self) -> &str;

    /// Apply the optimization to the graph
    fn apply(&self, graph: ComputationGraph) -> JitResult<ComputationGraph>;
}
60
61/// Dead code elimination - remove unused nodes
62pub struct DeadCodeElimination;
63
64impl OptimizationPass for DeadCodeElimination {
65    fn name(&self) -> &str {
66        "DeadCodeElimination"
67    }
68
69    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
70        // Find all nodes reachable from outputs
71        let mut reachable = HashSet::new();
72        let mut to_visit = graph.outputs.clone();
73
74        while let Some(node) = to_visit.pop() {
75            if reachable.insert(node) {
76                // Add all predecessors to visit list
77                for pred in graph.predecessors(node) {
78                    to_visit.push(pred);
79                }
80            }
81        }
82
83        // Remove unreachable nodes
84        let all_nodes: Vec<_> = graph.graph.node_indices().collect();
85        for node in all_nodes {
86            if !reachable.contains(&node) {
87                graph.graph.remove_node(node);
88            }
89        }
90
91        Ok(graph)
92    }
93}
94
95/// Constant folding - evaluate constant expressions at compile time
96pub struct ConstantFolding;
97
98impl OptimizationPass for ConstantFolding {
99    fn name(&self) -> &str {
100        "ConstantFolding"
101    }
102
103    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
104        let nodes: Vec<_> = graph.graph.node_indices().collect();
105
106        for node_id in nodes {
107            if let Some(node) = graph.node(node_id).cloned() {
108                if self.can_fold(&graph, node_id, &node) {
109                    self.fold_node(&mut graph, node_id, &node)?;
110                }
111            }
112        }
113
114        Ok(graph)
115    }
116}
117
118impl ConstantFolding {
119    fn can_fold(&self, graph: &ComputationGraph, node_id: NodeId, node: &Node) -> bool {
120        // Check if all inputs are constants
121        match &node.op {
122            Operation::Add | Operation::Sub | Operation::Mul | Operation::Div => {
123                graph.predecessors(node_id).all(|pred_id| {
124                    graph
125                        .node(pred_id)
126                        .map(|n| matches!(&n.op, Operation::Constant(_)))
127                        .unwrap_or(false)
128                })
129            }
130            _ => false,
131        }
132    }
133
134    fn fold_node(
135        &self,
136        graph: &mut ComputationGraph,
137        node_id: NodeId,
138        node: &Node,
139    ) -> JitResult<()> {
140        use crate::graph::{ConstantInfo, ConstantValue, Operation};
141
142        // Check if all inputs to this node are constants
143        let predecessors: Vec<_> = graph
144            .graph
145            .edges_directed(node_id, petgraph::Direction::Incoming)
146            .map(|edge_ref| edge_ref.source())
147            .collect();
148
149        if predecessors.is_empty() {
150            return Ok(());
151        }
152
153        // Check if all predecessors are constant nodes
154        let mut all_constants = true;
155        let mut constant_values = Vec::new();
156
157        for pred_id in &predecessors {
158            if let Some(pred_node) = graph.node(*pred_id) {
159                match &pred_node.op {
160                    Operation::Constant(const_info) => {
161                        constant_values.push(const_info.value.clone());
162                    }
163                    _ => {
164                        all_constants = false;
165                        break;
166                    }
167                }
168            } else {
169                all_constants = false;
170                break;
171            }
172        }
173
174        if !all_constants {
175            return Ok(());
176        }
177
178        // Try to fold the operation
179        let folded_value = match (&node.op, constant_values.as_slice()) {
180            // Binary operations on scalar constants
181            (Operation::Add, [ConstantValue::Scalar(a), ConstantValue::Scalar(b)]) => {
182                Some(ConstantValue::Scalar(a + b))
183            }
184            (Operation::Sub, [ConstantValue::Scalar(a), ConstantValue::Scalar(b)]) => {
185                Some(ConstantValue::Scalar(a - b))
186            }
187            (Operation::Mul, [ConstantValue::Scalar(a), ConstantValue::Scalar(b)]) => {
188                Some(ConstantValue::Scalar(a * b))
189            }
190            (Operation::Div, [ConstantValue::Scalar(a), ConstantValue::Scalar(b)]) => {
191                if *b != 0.0 {
192                    Some(ConstantValue::Scalar(a / b))
193                } else {
194                    None // Division by zero
195                }
196            }
197
198            // Unary operations on scalar constants
199            (Operation::Neg, [ConstantValue::Scalar(a)]) => Some(ConstantValue::Scalar(-a)),
200            (Operation::Abs, [ConstantValue::Scalar(a)]) => Some(ConstantValue::Scalar(a.abs())),
201            (Operation::Sqrt, [ConstantValue::Scalar(a)]) => {
202                if *a >= 0.0 {
203                    Some(ConstantValue::Scalar(a.sqrt()))
204                } else {
205                    None // Sqrt of negative number
206                }
207            }
208            (Operation::Exp, [ConstantValue::Scalar(a)]) => Some(ConstantValue::Scalar(a.exp())),
209            (Operation::Log, [ConstantValue::Scalar(a)]) => {
210                if *a > 0.0 {
211                    Some(ConstantValue::Scalar(a.ln()))
212                } else {
213                    None // Log of non-positive number
214                }
215            }
216
217            // Integer operations
218            (Operation::Add, [ConstantValue::IntScalar(a), ConstantValue::IntScalar(b)]) => {
219                Some(ConstantValue::IntScalar(a + b))
220            }
221            (Operation::Sub, [ConstantValue::IntScalar(a), ConstantValue::IntScalar(b)]) => {
222                Some(ConstantValue::IntScalar(a - b))
223            }
224            (Operation::Mul, [ConstantValue::IntScalar(a), ConstantValue::IntScalar(b)]) => {
225                Some(ConstantValue::IntScalar(a * b))
226            }
227
228            _ => None, // Operation not supported for constant folding
229        };
230
231        // If we successfully folded the operation, replace the node with a constant
232        if let Some(folded) = folded_value {
233            if let Some(node_mut) = graph.node_mut(node_id) {
234                node_mut.op = Operation::Constant(ConstantInfo { value: folded });
235
236                // Remove edges from predecessors since this is now a constant
237                let edges_to_remove: Vec<_> = graph
238                    .graph
239                    .edges_directed(node_id, petgraph::Direction::Incoming)
240                    .map(|edge| edge.id())
241                    .collect();
242
243                for edge_id in edges_to_remove {
244                    graph.graph.remove_edge(edge_id);
245                }
246            }
247        }
248
249        Ok(())
250    }
251}
252
253/// Common subexpression elimination - reuse identical computations
254pub struct CommonSubexpressionElimination;
255
256impl OptimizationPass for CommonSubexpressionElimination {
257    fn name(&self) -> &str {
258        "CommonSubexpressionElimination"
259    }
260
261    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
262        // Build a map of operation signatures to nodes
263        let mut op_map: HashMap<String, Vec<NodeId>> = HashMap::new();
264
265        for (node_id, node) in graph.nodes() {
266            let signature = self.compute_signature(&graph, node_id, node);
267            op_map.entry(signature).or_default().push(node_id);
268        }
269
270        // Find and eliminate duplicates
271        for (_, nodes) in op_map {
272            if nodes.len() > 1 {
273                // Keep the first node, redirect others
274                let keep = nodes[0];
275                for &duplicate in &nodes[1..] {
276                    self.redirect_node(&mut graph, duplicate, keep)?;
277                }
278            }
279        }
280
281        Ok(graph)
282    }
283}
284
285impl CommonSubexpressionElimination {
286    fn compute_signature(&self, graph: &ComputationGraph, node_id: NodeId, node: &Node) -> String {
287        // Create a signature based on operation and inputs
288        let mut sig = format!("{:?}", node.op);
289
290        // Add input signatures in sorted order for consistency
291        let mut inputs: Vec<_> = graph.predecessors(node_id).collect();
292        inputs.sort();
293
294        for input in inputs {
295            sig.push_str(&format!("_in{:?}", input));
296        }
297
298        sig
299    }
300
301    fn redirect_node(
302        &self,
303        graph: &mut ComputationGraph,
304        from: NodeId,
305        to: NodeId,
306    ) -> JitResult<()> {
307        // Redirect all edges from 'from' to 'to'
308        let successors: Vec<_> = graph.graph.neighbors(from).collect();
309
310        for succ in successors {
311            // Find the edge
312            if let Some(edge) = graph.graph.find_edge(from, succ) {
313                let edge_data = graph.graph[edge].clone();
314                graph.graph.remove_edge(edge);
315                graph.graph.add_edge(to, succ, edge_data);
316            }
317        }
318
319        // Remove the duplicate node
320        graph.graph.remove_node(from);
321
322        Ok(())
323    }
324}
325
326/// Algebraic simplification - apply mathematical identities
327pub struct AlgebraicSimplification;
328
329impl OptimizationPass for AlgebraicSimplification {
330    fn name(&self) -> &str {
331        "AlgebraicSimplification"
332    }
333
334    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
335        let nodes: Vec<_> = graph.graph.node_indices().collect();
336
337        for node_id in nodes {
338            if let Some(node) = graph.node(node_id).cloned() {
339                if let Some(simplified) = self.simplify(&graph, node_id, &node) {
340                    // Replace node with simplified version
341                    if let Some(node_mut) = graph.node_mut(node_id) {
342                        node_mut.op = simplified;
343                    }
344                }
345            }
346        }
347
348        Ok(graph)
349    }
350}
351
352impl AlgebraicSimplification {
353    fn simplify(
354        &self,
355        graph: &ComputationGraph,
356        node_id: NodeId,
357        node: &Node,
358    ) -> Option<Operation> {
359        match &node.op {
360            Operation::Mul => {
361                let preds: Vec<_> = graph.predecessors(node_id).collect();
362                if preds.len() == 2 {
363                    // Check for multiply by constants
364                    if let (Some(left), Some(right)) = (graph.node(preds[0]), graph.node(preds[1]))
365                    {
366                        match (&left.op, &right.op) {
367                            // x * 0 = 0
368                            (Operation::Constant(c), _) | (_, Operation::Constant(c)) => {
369                                if let crate::graph::ConstantValue::Scalar(v) = &c.value {
370                                    if *v == 0.0 {
371                                        return Some(Operation::Constant(
372                                            crate::graph::ConstantInfo {
373                                                value: crate::graph::ConstantValue::Scalar(0.0),
374                                            },
375                                        ));
376                                    }
377                                    // x * 1 = x is handled by constant folding pass
378                                }
379                            }
380                            // x * x = x^2 (could be optimized further)
381                            _ if preds[0] == preds[1] => {
382                                return Some(Operation::Pow);
383                            }
384                            _ => {}
385                        }
386                    }
387                }
388            }
389            Operation::Add => {
390                let preds: Vec<_> = graph.predecessors(node_id).collect();
391                if preds.len() == 2 {
392                    if let (Some(left), Some(right)) = (graph.node(preds[0]), graph.node(preds[1]))
393                    {
394                        // x + 0 = x
395                        match (&left.op, &right.op) {
396                            (Operation::Constant(c), _) | (_, Operation::Constant(c)) => {
397                                if let crate::graph::ConstantValue::Scalar(v) = &c.value {
398                                    if *v == 0.0 {
399                                        // x + 0 = x (handled by rewriting the graph)
400                                        // Return None to trigger graph rewriting
401                                    }
402                                }
403                            }
404                            _ => {}
405                        }
406                    }
407                }
408            }
409            Operation::Sub => {
410                let preds: Vec<_> = graph.predecessors(node_id).collect();
411                if preds.len() == 2 && preds[0] == preds[1] {
412                    // x - x = 0
413                    return Some(Operation::Constant(crate::graph::ConstantInfo {
414                        value: crate::graph::ConstantValue::Scalar(0.0),
415                    }));
416                }
417            }
418            Operation::Div => {
419                let preds: Vec<_> = graph.predecessors(node_id).collect();
420                if preds.len() == 2 {
421                    if let Some(right) = graph.node(preds[1]) {
422                        // x / 1 = x
423                        if let Operation::Constant(c) = &right.op {
424                            if let crate::graph::ConstantValue::Scalar(v) = &c.value {
425                                if *v == 1.0 {
426                                    // x / 1 = x (handled by graph rewriting)
427                                    return None;
428                                }
429                            }
430                        }
431                    }
432                    // x / x = 1
433                    if preds[0] == preds[1] {
434                        return Some(Operation::Constant(crate::graph::ConstantInfo {
435                            value: crate::graph::ConstantValue::Scalar(1.0),
436                        }));
437                    }
438                }
439            }
440            Operation::Pow => {
441                let preds: Vec<_> = graph.predecessors(node_id).collect();
442                if preds.len() == 2 {
443                    if let Some(right) = graph.node(preds[1]) {
444                        if let Operation::Constant(c) = &right.op {
445                            if let crate::graph::ConstantValue::Scalar(exp) = &c.value {
446                                match *exp {
447                                    // x^0 = 1
448                                    0.0 => {
449                                        return Some(Operation::Constant(
450                                            crate::graph::ConstantInfo {
451                                                value: crate::graph::ConstantValue::Scalar(1.0),
452                                            },
453                                        ))
454                                    }
455                                    // x^1 = x
456                                    1.0 => return None,
457                                    // x^2 can be optimized to x*x
458                                    2.0 => return Some(Operation::Mul),
459                                    _ => {}
460                                }
461                            }
462                        }
463                    }
464                }
465            }
466            Operation::Sqrt => {
467                let preds: Vec<_> = graph.predecessors(node_id).collect();
468                if preds.len() == 1 {
469                    if let Some(pred) = graph.node(preds[0]) {
470                        // sqrt(x^2) = |x| (could be x if x >= 0)
471                        if let Operation::Pow = &pred.op {
472                            let pow_preds: Vec<_> = graph.predecessors(preds[0]).collect();
473                            if pow_preds.len() == 2 {
474                                if let Some(exp_node) = graph.node(pow_preds[1]) {
475                                    if let Operation::Constant(c) = &exp_node.op {
476                                        if let crate::graph::ConstantValue::Scalar(2.0) = &c.value {
477                                            return Some(Operation::Abs);
478                                        }
479                                    }
480                                }
481                            }
482                        }
483                    }
484                }
485            }
486            Operation::Log => {
487                let preds: Vec<_> = graph.predecessors(node_id).collect();
488                if preds.len() == 1 {
489                    if let Some(pred) = graph.node(preds[0]) {
490                        // log(exp(x)) = x
491                        if let Operation::Exp = &pred.op {
492                            return None; // Handle by graph rewriting
493                        }
494                        // log(1) = 0
495                        if let Operation::Constant(c) = &pred.op {
496                            if let crate::graph::ConstantValue::Scalar(1.0) = &c.value {
497                                return Some(Operation::Constant(crate::graph::ConstantInfo {
498                                    value: crate::graph::ConstantValue::Scalar(0.0),
499                                }));
500                            }
501                        }
502                    }
503                }
504            }
505            Operation::Exp => {
506                let preds: Vec<_> = graph.predecessors(node_id).collect();
507                if preds.len() == 1 {
508                    if let Some(pred) = graph.node(preds[0]) {
509                        // exp(log(x)) = x
510                        if let Operation::Log = &pred.op {
511                            return None; // Handle by graph rewriting
512                        }
513                        // exp(0) = 1
514                        if let Operation::Constant(c) = &pred.op {
515                            if let crate::graph::ConstantValue::Scalar(0.0) = &c.value {
516                                return Some(Operation::Constant(crate::graph::ConstantInfo {
517                                    value: crate::graph::ConstantValue::Scalar(1.0),
518                                }));
519                            }
520                        }
521                    }
522                }
523            }
524            // Double negation: -(-x) = x
525            Operation::Neg => {
526                let preds: Vec<_> = graph.predecessors(node_id).collect();
527                if preds.len() == 1 {
528                    if let Some(pred) = graph.node(preds[0]) {
529                        if let Operation::Neg = &pred.op {
530                            return None; // Handle by graph rewriting
531                        }
532                    }
533                }
534            }
535            _ => {}
536        }
537
538        None
539    }
540}
541
542/// Strength reduction - replace expensive operations with cheaper ones
543pub struct StrengthReduction;
544
545impl OptimizationPass for StrengthReduction {
546    fn name(&self) -> &str {
547        "StrengthReduction"
548    }
549
550    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
551        let nodes: Vec<_> = graph.graph.node_indices().collect();
552
553        for node_id in nodes {
554            if let Some(node) = graph.node(node_id).cloned() {
555                if let Some(reduced) = self.reduce_strength(&graph, node_id, &node) {
556                    if let Some(node_mut) = graph.node_mut(node_id) {
557                        node_mut.op = reduced;
558                    }
559                }
560            }
561        }
562
563        Ok(graph)
564    }
565}
566
567impl StrengthReduction {
568    fn reduce_strength(
569        &self,
570        graph: &ComputationGraph,
571        node_id: NodeId,
572        node: &Node,
573    ) -> Option<Operation> {
574        match &node.op {
575            Operation::Pow => {
576                // Check for power of 2
577                let preds: Vec<_> = graph.predecessors(node_id).collect();
578                if preds.len() == 2 {
579                    let is_two = graph
580                        .node(preds[1])
581                        .and_then(|node| match &node.op {
582                            Operation::Constant(c) => match &c.value {
583                                crate::graph::ConstantValue::Scalar(v) => Some(*v == 2.0),
584                                _ => None,
585                            },
586                            _ => None,
587                        })
588                        .unwrap_or(false);
589
590                    if is_two {
591                        // x^2 -> x*x (multiply is usually faster)
592                        return Some(Operation::Mul);
593                    }
594                }
595            }
596            Operation::Div => {
597                // Check for division by constant
598                // Could be replaced with multiplication by reciprocal
599            }
600            _ => {}
601        }
602
603        None
604    }
605}
606
607/// Layout optimization - optimize memory access patterns
608pub struct LayoutOptimization;
609
610impl OptimizationPass for LayoutOptimization {
611    fn name(&self) -> &str {
612        "LayoutOptimization"
613    }
614
615    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
616        // Analyze memory access patterns and insert transpose operations where beneficial
617        // This optimization is particularly useful for conv2d and matmul operations
618
619        let nodes: Vec<_> = graph.graph.node_indices().collect();
620        let mut transpose_insertions = Vec::new();
621
622        for node_id in nodes {
623            if let Some(node) = graph.node(node_id) {
624                match &node.op {
625                    Operation::Conv2d(info) => {
626                        // Check if input layout would benefit from NCHW -> NHWC conversion
627                        // This can improve cache locality for certain convolution patterns
628                        if self.should_convert_layout_for_conv(info, &node.output_shape) {
629                            // Mark for layout conversion
630                            transpose_insertions.push((node_id, LayoutChange::NCHWtoNHWC));
631                        }
632                    }
633                    Operation::MatMul | Operation::BatchMatMul => {
634                        // Check if transposing inputs would reduce memory access stride
635                        let preds: Vec<_> = graph.predecessors(node_id).collect();
636                        if preds.len() == 2 {
637                            // Analyze if transposing either input would be beneficial
638                            if let (Some(left), Some(right)) =
639                                (graph.node(preds[0]), graph.node(preds[1]))
640                            {
641                                if self.should_transpose_for_matmul(
642                                    &left.output_shape,
643                                    &right.output_shape,
644                                ) {
645                                    transpose_insertions
646                                        .push((node_id, LayoutChange::TransposeMatmul));
647                                }
648                            }
649                        }
650                    }
651                    _ => {}
652                }
653            }
654        }
655
656        // Apply layout changes
657        for (node_id, change) in transpose_insertions {
658            match change {
659                LayoutChange::NCHWtoNHWC => {
660                    // Insert transpose operations to convert layout
661                    self.insert_layout_conversion(&mut graph, node_id, vec![0, 2, 3, 1])?;
662                }
663                LayoutChange::TransposeMatmul => {
664                    // Handle matmul-specific optimizations
665                    self.optimize_matmul_layout(&mut graph, node_id)?;
666                }
667            }
668        }
669
670        Ok(graph)
671    }
672}
673
/// Layout rewrites that `LayoutOptimization` can schedule for a node.
#[derive(Debug)]
enum LayoutChange {
    /// Convert a conv2d's operands from channels-first to channels-last.
    NCHWtoNHWC,
    /// Re-layout (potentially transpose) the inputs of a matmul.
    TransposeMatmul,
}
679
impl LayoutOptimization {
    /// Check if converting conv2d layout would be beneficial
    ///
    /// Heuristic only: favors channels-last (NHWC) for depthwise
    /// convolutions and for small-kernel convolutions with large spatial
    /// output.
    fn should_convert_layout_for_conv(&self, info: &Conv2dInfo, output_shape: &Shape) -> bool {
        // Heuristic: prefer NHWC for depthwise and small kernel convolutions
        // as they have better cache locality with channels-last format

        // Depthwise convolution benefits from NHWC
        if info.groups == info.in_channels && info.in_channels == info.out_channels {
            return true;
        }

        // Small kernels (1x1, 3x3) often benefit from NHWC
        if info.kernel_size.0 <= 3 && info.kernel_size.1 <= 3 {
            // Also consider output spatial dimensions
            if output_shape.ndim() >= 4 {
                // NOTE(review): assumes the output is NCHW, i.e. dims
                // [N, C, H, W] — confirm before relying on indices 2 and 3.
                let height = output_shape.dims()[2];
                let width = output_shape.dims()[3];
                // Prefer NHWC for larger spatial dimensions
                return height * width > 1024;
            }
        }

        false
    }

    /// Check if transposing matmul inputs would be beneficial
    ///
    /// Returns `true` only for strongly skewed operand shapes (tall-thin
    /// left or short-wide right matrix) whose inner dimension is large
    /// enough (>= 16) for a transpose to pay off.
    fn should_transpose_for_matmul(&self, left_shape: &Shape, right_shape: &Shape) -> bool {
        // Check if the matrices are already in optimal layout for matmul
        // For A @ B, we want A to be row-major and B to be column-major
        // This minimizes cache misses during computation

        if left_shape.ndim() < 2 || right_shape.ndim() < 2 {
            return false;
        }

        // Get the matrix dimensions (last two axes; leading axes are batch)
        let left_rows = left_shape.dims()[left_shape.ndim() - 2];
        let left_cols = left_shape.dims()[left_shape.ndim() - 1];
        let right_rows = right_shape.dims()[right_shape.ndim() - 2];

        // If the inner dimension is small, transposing might not be worth it
        if left_cols < 16 || right_rows < 16 {
            return false;
        }

        // Check stride patterns (simplified heuristic)
        // If left matrix is tall and thin, or right matrix is short and wide,
        // transposing might help
        let left_aspect = left_rows as f32 / left_cols as f32;
        let right_aspect = right_rows as f32 / right_shape.dims()[right_shape.ndim() - 1] as f32;

        left_aspect > 4.0 || right_aspect < 0.25
    }

    /// Insert layout conversion (transpose) operations
    ///
    /// Splices a `Transpose { dims }` node between the target node and each
    /// of its predecessors, rewiring each edge `pred -> node` into
    /// `pred -> transpose -> node` (the original edge payload moves onto
    /// the transpose->node leg).
    ///
    /// NOTE(review): this transposes EVERY predecessor, which for conv2d
    /// includes the weight tensor — confirm the weights should receive the
    /// same permutation as the activations.
    /// NOTE(review): the consumer node's `inputs` field is not updated and
    /// still names `pred_id` rather than the new transpose node — verify
    /// downstream code reads graph edges rather than `inputs`.
    fn insert_layout_conversion(
        &self,
        graph: &mut ComputationGraph,
        node_id: NodeId,
        dims: Vec<usize>,
    ) -> JitResult<()> {
        // Get predecessors of the node
        let preds: Vec<_> = graph.predecessors(node_id).collect();

        for pred_id in preds {
            // Insert transpose between predecessor and current node
            if let Some(pred_node) = graph.node(pred_id).cloned() {
                // Create transpose node, inheriting dtype/device from the
                // producer and computing the permuted output shape.
                let mut transpose_node = Node::new(
                    Operation::Transpose { dims: dims.clone() },
                    format!("{}_transpose", pred_node.name),
                );
                transpose_node = transpose_node
                    .with_output_shapes(vec![Some(
                        self.compute_transposed_shape(&pred_node.output_shape, &dims),
                    )])
                    .with_dtypes(vec![pred_node.dtype])
                    .with_device(pred_node.device);
                transpose_node.inputs = vec![pred_id];
                transpose_node.is_output = false;

                let transpose_id = graph.add_node(transpose_node);

                // Rewire edges: pred -> transpose -> node
                if let Some(edge) = graph.graph.find_edge(pred_id, node_id) {
                    let edge_data = graph.graph[edge].clone();
                    graph.graph.remove_edge(edge);
                    graph.add_edge(pred_id, transpose_id, Edge::default());
                    graph.add_edge(transpose_id, node_id, edge_data);
                }
            }
        }

        Ok(())
    }

    /// Optimize matmul layout by potentially transposing inputs
    ///
    /// Currently a placeholder: it inspects the predecessors but makes no
    /// change to the graph.
    fn optimize_matmul_layout(
        &self,
        graph: &mut ComputationGraph,
        node_id: NodeId,
    ) -> JitResult<()> {
        let preds: Vec<_> = graph.predecessors(node_id).collect();

        if preds.len() == 2 {
            // For now, just mark that we could optimize this
            // In a real implementation, we would analyze the specific
            // matrix dimensions and access patterns

            // Could insert transposes to convert to optimal layout
            // For example, ensuring the right matrix is column-major
            // by transposing it before the matmul
        }

        Ok(())
    }

    /// Compute output shape after transpose
    ///
    /// Assumes `dims` is a permutation of `0..shape.ndim()`; any entry that
    /// is out of range silently leaves a 0 in the resulting shape.
    fn compute_transposed_shape(&self, shape: &Shape, dims: &[usize]) -> Shape {
        let mut new_dims = vec![0; shape.ndim()];
        let old_dims = shape.dims();

        for (i, &dim) in dims.iter().enumerate() {
            if dim < old_dims.len() {
                new_dims[i] = old_dims[dim];
            }
        }

        Shape::new(new_dims)
    }
}
811
/// Auto-vectorization pass - converts element-wise operations to vector operations
///
/// This pass does not rewrite nodes; it annotates large element-wise
/// operations with `vectorize` / `vector_width` attributes that a later
/// code-generation stage can act on.
pub struct AutoVectorization;
814
815impl OptimizationPass for AutoVectorization {
816    fn name(&self) -> &str {
817        "AutoVectorization"
818    }
819
820    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
821        let nodes: Vec<_> = graph.graph.node_indices().collect();
822
823        for node_id in nodes {
824            if let Some(node) = graph.node(node_id).cloned() {
825                if self.can_vectorize(&graph, node_id, &node) {
826                    self.vectorize_node(&mut graph, node_id, &node)?;
827                }
828            }
829        }
830
831        Ok(graph)
832    }
833}
834
835impl AutoVectorization {
836    /// Check if a node can be vectorized
837    fn can_vectorize(&self, _graph: &ComputationGraph, _node_id: NodeId, node: &Node) -> bool {
838        // Check for element-wise operations on large tensors
839        match &node.op {
840            Operation::Add
841            | Operation::Sub
842            | Operation::Mul
843            | Operation::Div
844            | Operation::Relu
845            | Operation::Sigmoid
846            | Operation::Tanh
847            | Operation::Silu => {
848                // Check if the tensor is large enough to benefit from vectorization
849                let total_elements = node.output_shape.numel();
850                total_elements >= 1024 // Vectorize if >= 1024 elements
851            }
852            _ => false,
853        }
854    }
855
856    /// Apply vectorization to a node
857    fn vectorize_node(
858        &self,
859        graph: &mut ComputationGraph,
860        node_id: NodeId,
861        _node: &Node,
862    ) -> JitResult<()> {
863        if let Some(node_mut) = graph.node_mut(node_id) {
864            // Add vectorization hint to node attributes
865            node_mut
866                .attrs
867                .insert("vectorize".to_string(), crate::graph::Attribute::Bool(true));
868            node_mut.attrs.insert(
869                "vector_width".to_string(),
870                crate::graph::Attribute::Int(8), // 8-wide SIMD
871            );
872        }
873        Ok(())
874    }
875}
876
/// Auto-parallelization pass - identifies opportunities for parallel execution
///
/// Nodes found to be mutually independent are tagged with `parallel_group`
/// and `can_parallelize` attributes; the pass itself does not reorder or
/// execute anything.
pub struct AutoParallelization;
879
880impl OptimizationPass for AutoParallelization {
881    fn name(&self) -> &str {
882        "AutoParallelization"
883    }
884
885    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
886        // Find independent subgraphs that can be executed in parallel
887        let parallel_groups = self.find_parallel_groups(&graph)?;
888
889        // Mark nodes with parallelization hints
890        for (group_id, node_ids) in parallel_groups.iter().enumerate() {
891            for &node_id in node_ids {
892                if let Some(node_mut) = graph.node_mut(node_id) {
893                    node_mut.attrs.insert(
894                        "parallel_group".to_string(),
895                        crate::graph::Attribute::Int(group_id as i64),
896                    );
897                    node_mut.attrs.insert(
898                        "can_parallelize".to_string(),
899                        crate::graph::Attribute::Bool(true),
900                    );
901                }
902            }
903        }
904
905        Ok(graph)
906    }
907}
908
909impl AutoParallelization {
910    /// Find groups of nodes that can be executed in parallel
911    fn find_parallel_groups(&self, graph: &ComputationGraph) -> JitResult<Vec<Vec<NodeId>>> {
912        let mut parallel_groups = Vec::new();
913        let mut visited = HashSet::new();
914
915        // Topological sort to process nodes in dependency order
916        let topo_order = graph
917            .topological_sort()
918            .map_err(|e| crate::JitError::GraphError(e.to_string()))?;
919
920        for &node_id in &topo_order {
921            if visited.contains(&node_id) {
922                continue;
923            }
924
925            // Find all nodes at this "level" (same depth in dependency chain)
926            let mut current_group = Vec::new();
927
928            // Find nodes that have no dependencies on unvisited nodes
929            if self.can_execute_now(graph, node_id, &visited) {
930                current_group.push(node_id);
931                visited.insert(node_id);
932
933                // Look for other nodes at the same level
934                for &other_id in &topo_order {
935                    if other_id != node_id
936                        && !visited.contains(&other_id)
937                        && self.can_execute_now(graph, other_id, &visited)
938                        && self.can_execute_parallel(graph, node_id, other_id)
939                    {
940                        current_group.push(other_id);
941                        visited.insert(other_id);
942                    }
943                }
944            }
945
946            if current_group.len() > 1 {
947                parallel_groups.push(current_group);
948            }
949        }
950
951        Ok(parallel_groups)
952    }
953
954    /// Check if a node can execute now (all dependencies satisfied)
955    fn can_execute_now(
956        &self,
957        graph: &ComputationGraph,
958        node_id: NodeId,
959        visited: &HashSet<NodeId>,
960    ) -> bool {
961        graph
962            .predecessors(node_id)
963            .all(|pred| visited.contains(&pred))
964    }
965
966    /// Check if two nodes can be executed in parallel (no dependencies)
967    fn can_execute_parallel(&self, graph: &ComputationGraph, node1: NodeId, node2: NodeId) -> bool {
968        // Check if there's any dependency path between the nodes
969        !self.has_dependency_path(graph, node1, node2)
970            && !self.has_dependency_path(graph, node2, node1)
971    }
972
973    /// Check if there's a dependency path from node1 to node2
974    fn has_dependency_path(&self, graph: &ComputationGraph, from: NodeId, to: NodeId) -> bool {
975        let mut visited = HashSet::new();
976        let mut stack = vec![from];
977
978        while let Some(current) = stack.pop() {
979            if current == to {
980                return true;
981            }
982
983            if visited.insert(current) {
984                for successor in graph.successors(current) {
985                    stack.push(successor);
986                }
987            }
988        }
989
990        false
991    }
992}
993
/// Cache-aware optimization - improve data locality and cache utilization
///
/// Holds the cache geometry used by the tiling and locality heuristics;
/// `Default` provides typical desktop-CPU values.
pub struct CacheAwareOptimization {
    /// Target cache line size in bytes
    cache_line_size: usize,
    /// L1 cache size in bytes (tiles are sized to fit here)
    l1_cache_size: usize,
    /// L2 cache size in bytes
    l2_cache_size: usize,
}
1003
1004impl Default for CacheAwareOptimization {
1005    fn default() -> Self {
1006        Self {
1007            cache_line_size: 64,       // Typical cache line size
1008            l1_cache_size: 32 * 1024,  // 32 KB L1
1009            l2_cache_size: 256 * 1024, // 256 KB L2
1010        }
1011    }
1012}
1013
1014impl CacheAwareOptimization {
1015    pub fn new(cache_line_size: usize, l1_cache_size: usize, l2_cache_size: usize) -> Self {
1016        Self {
1017            cache_line_size,
1018            l1_cache_size,
1019            l2_cache_size,
1020        }
1021    }
1022
1023    /// Calculate optimal tile size for a given dimension
1024    fn calculate_tile_size(&self, dimension_size: usize, element_size: usize) -> usize {
1025        // Aim to fit tiles in L1 cache
1026        let elements_in_l1 = self.l1_cache_size / element_size;
1027        let sqrt_elements = (elements_in_l1 as f64).sqrt() as usize;
1028
1029        // Use power of 2 for better alignment
1030        let mut tile_size = 1;
1031        while tile_size * 2 <= sqrt_elements && tile_size * 2 <= dimension_size {
1032            tile_size *= 2;
1033        }
1034
1035        tile_size.max(8).min(dimension_size) // At least 8, at most dimension size
1036    }
1037
1038    /// Reorder operations for better cache locality
1039    fn reorder_for_locality(&self, graph: &mut ComputationGraph) -> JitResult<usize> {
1040        let mut reordered = 0;
1041
1042        // Find groups of operations that access the same data
1043        let access_groups = self.analyze_data_access_patterns(graph)?;
1044
1045        // For each group, try to schedule operations close together
1046        for group in access_groups {
1047            if group.len() > 1 {
1048                // Operations in the same group should be executed consecutively
1049                // This reduces cache misses by keeping data hot
1050                reordered += group.len();
1051            }
1052        }
1053
1054        Ok(reordered)
1055    }
1056
1057    /// Analyze data access patterns in the graph
1058    fn analyze_data_access_patterns(
1059        &self,
1060        graph: &ComputationGraph,
1061    ) -> JitResult<Vec<Vec<NodeId>>> {
1062        let mut groups = Vec::new();
1063        let mut visited = HashSet::new();
1064
1065        // Group nodes that access similar memory regions
1066        for (node_id, node) in graph.nodes() {
1067            if visited.contains(&node_id) {
1068                continue;
1069            }
1070
1071            let mut group = vec![node_id];
1072            visited.insert(node_id);
1073
1074            // Find related nodes that access similar data
1075            for (other_id, other_node) in graph.nodes() {
1076                if visited.contains(&other_id) {
1077                    continue;
1078                }
1079
1080                if self.have_similar_access_pattern(node, other_node) {
1081                    group.push(other_id);
1082                    visited.insert(other_id);
1083                }
1084            }
1085
1086            if group.len() > 1 {
1087                groups.push(group);
1088            }
1089        }
1090
1091        Ok(groups)
1092    }
1093
1094    /// Check if two nodes have similar data access patterns
1095    fn have_similar_access_pattern(&self, node1: &Node, node2: &Node) -> bool {
1096        // Nodes with similar operations likely access data similarly
1097        match (&node1.op, &node2.op) {
1098            (Operation::MatMul, Operation::MatMul) => true,
1099            (Operation::Conv2d(_), Operation::Conv2d(_)) => true,
1100            (Operation::Add, Operation::Add)
1101            | (Operation::Sub, Operation::Sub)
1102            | (Operation::Mul, Operation::Mul)
1103            | (Operation::Div, Operation::Div) => {
1104                // Element-wise operations with similar shapes
1105                node1.output_shape.dims() == node2.output_shape.dims()
1106            }
1107            _ => false,
1108        }
1109    }
1110
1111    /// Apply loop tiling to large operations
1112    fn apply_loop_tiling(&self, graph: &mut ComputationGraph) -> JitResult<usize> {
1113        let mut tiled = 0;
1114
1115        for (node_id, node) in graph.nodes() {
1116            match &node.op {
1117                Operation::MatMul => {
1118                    // MatMul benefits greatly from tiling
1119                    if let Some(shape) = node.output_shapes.first().and_then(|s| s.as_ref()) {
1120                        let dims = shape.dims();
1121                        if dims.len() >= 2 {
1122                            let m = dims[dims.len() - 2];
1123                            let n = dims[dims.len() - 1];
1124
1125                            // Calculate tile sizes
1126                            let tile_m = self.calculate_tile_size(m, 4); // Assuming f32
1127                            let tile_n = self.calculate_tile_size(n, 4);
1128
1129                            // Store tiling hint in node attributes
1130                            log::debug!(
1131                                "MatMul node {:?}: suggested tiling {}x{}",
1132                                node_id,
1133                                tile_m,
1134                                tile_n
1135                            );
1136                            tiled += 1;
1137                        }
1138                    }
1139                }
1140
1141                Operation::Conv2d(conv_info) => {
1142                    // Convolutions also benefit from tiling
1143                    let tile_size = self.calculate_tile_size(conv_info.kernel_size.0, 4);
1144                    log::debug!(
1145                        "Conv2d node {:?}: suggested tile size {}",
1146                        node_id,
1147                        tile_size
1148                    );
1149                    tiled += 1;
1150                }
1151
1152                _ => {}
1153            }
1154        }
1155
1156        Ok(tiled)
1157    }
1158
1159    /// Add prefetch hints for predictable access patterns
1160    fn add_prefetch_hints(&self, graph: &ComputationGraph) -> JitResult<usize> {
1161        let mut hints_added = 0;
1162
1163        // Analyze sequential access patterns
1164        for (node_id, node) in graph.nodes() {
1165            // Operations that scan through data sequentially
1166            match &node.op {
1167                Operation::MatMul | Operation::Conv2d(_) => {
1168                    // These operations have predictable access patterns
1169                    // Add prefetch hints for next cache lines
1170                    log::debug!("Node {:?}: adding software prefetch hints", node_id);
1171                    hints_added += 1;
1172                }
1173                _ => {}
1174            }
1175        }
1176
1177        Ok(hints_added)
1178    }
1179
1180    /// Optimize for spatial and temporal locality
1181    fn optimize_locality(&self, graph: &mut ComputationGraph) -> JitResult<usize> {
1182        let mut optimized = 0;
1183
1184        // Ensure operations that reuse data are scheduled close together
1185        optimized += self.reorder_for_locality(graph)?;
1186
1187        // Apply loop tiling for better cache utilization
1188        optimized += self.apply_loop_tiling(graph)?;
1189
1190        // Add prefetch hints for predictable accesses
1191        optimized += self.add_prefetch_hints(graph)?;
1192
1193        Ok(optimized)
1194    }
1195}
1196
1197impl OptimizationPass for CacheAwareOptimization {
1198    fn name(&self) -> &str {
1199        "cache_aware"
1200    }
1201
1202    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
1203        let optimizations = self.optimize_locality(&mut graph)?;
1204
1205        log::info!(
1206            "Cache-aware optimization: {} improvements applied",
1207            optimizations
1208        );
1209
1210        Ok(graph)
1211    }
1212}
1213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Edge;
    use torsh_core::{DType, DeviceType};

    #[test]
    fn test_optimizer_creation() {
        let optimizer = GraphOptimizer::new();
        // All default passes, including CacheAwareOptimization.
        assert_eq!(optimizer.passes.len(), 9);
    }

    #[test]
    fn test_dead_code_elimination() {
        let mut graph = ComputationGraph::new();

        // input -> output (Neg) is the live path; input -> dead (Relu)
        // feeds nothing and should be eliminated.
        let input = graph.add_node(
            Node::new(Operation::Input, "input".to_string())
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu),
        );

        let dead = graph.add_node(
            Node::new(Operation::Relu, "dead".to_string())
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu),
        );

        let output = graph.add_node(
            Node::new(Operation::Neg, "output".to_string())
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu),
        );

        graph.add_edge(input, output, Edge::default());
        graph.add_edge(input, dead, Edge::default()); // Dead branch

        graph.add_input(input);
        graph.add_output(output);

        let dce = DeadCodeElimination;
        let optimized = dce.apply(graph).unwrap();

        // After DCE only nodes reachable (backwards) from outputs remain.
        assert_eq!(
            optimized.graph.node_count(),
            2,
            "Should have exactly 2 nodes after DCE"
        );

        let remaining_nodes: Vec<_> = optimized
            .graph
            .node_indices()
            .filter_map(|idx| optimized.graph.node_weight(idx))
            .collect();

        let has_input = remaining_nodes
            .iter()
            .any(|n| matches!(&n.op, Operation::Input));
        let has_neg = remaining_nodes
            .iter()
            .any(|n| matches!(&n.op, Operation::Neg));

        assert!(has_input, "Input node should still exist");
        assert!(has_neg, "Output (neg) node should still exist");

        // BUG FIX: the dead node is a Relu (created above), but the
        // original test checked for Sigmoid — an assertion that passed
        // vacuously whether or not DCE removed anything.
        let has_relu = remaining_nodes
            .iter()
            .any(|n| matches!(&n.op, Operation::Relu));
        assert!(!has_relu, "Dead (relu) node should be removed");
    }
}