// scirs2_autograd/optimization/mod.rs

1//! Graph optimization and expression simplification for computation graphs
2//!
3//! This module provides various optimization techniques for computation graphs,
4//! including expression simplification, common subexpression elimination,
5//! constant folding, and graph-level transformations.
6
7use crate::graph::{Graph, TensorID};
8use crate::tensor::TensorInternal;
9use crate::Float;
10use std::collections::{HashMap, HashSet};
11
12pub mod constant_folding;
13pub mod expression_simplification;
14// pub mod graph_rewriting;
15pub mod loop_fusion;
16pub mod memory_optimization;
17
18// v0.2.0: Enhanced optimizations
19pub mod cse;
20pub mod fusion;
21
/// Graph optimization configuration.
///
/// Each flag toggles one optimization pass; `GraphOptimizer::optimize` runs
/// the enabled passes repeatedly until a fixed point or `max_passes`.
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Enable constant folding
    pub constant_folding: bool,
    /// Enable common subexpression elimination
    pub cse: bool,
    /// Enable expression simplification
    pub expression_simplification: bool,
    /// Enable dead code elimination
    pub dead_code_elimination: bool,
    /// Enable operation fusion
    pub operation_fusion: bool,
    /// Enable memory layout optimization
    pub memory_optimization: bool,
    /// Maximum optimization passes; the optimizer may stop earlier once a
    /// full pass makes no changes
    pub max_passes: usize,
    /// Optimization level (0-3); `OptimizationLevel::None` disables all passes
    /// regardless of the individual flags
    pub level: OptimizationLevel,
}
42
43impl Default for OptimizationConfig {
44    fn default() -> Self {
45        Self {
46            constant_folding: true,
47            cse: true,
48            expression_simplification: true,
49            dead_code_elimination: true,
50            operation_fusion: false, // More aggressive optimization
51            memory_optimization: true,
52            max_passes: 5,
53            level: OptimizationLevel::Standard,
54        }
55    }
56}
57
/// Optimization levels.
///
/// Each level maps to a preset `OptimizationConfig` via
/// `OptimizationLevel::config()`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizationLevel {
    /// No optimizations
    None,
    /// Basic optimizations (constant folding, DCE)
    Basic,
    /// Standard optimizations (basic + CSE, expression simplification)
    Standard,
    /// Aggressive optimizations (standard + operation fusion, advanced transformations)
    Aggressive,
}
70
71impl OptimizationLevel {
72    /// Get the default configuration for this optimization level
73    pub fn config(self) -> OptimizationConfig {
74        match self {
75            OptimizationLevel::None => OptimizationConfig {
76                constant_folding: false,
77                cse: false,
78                expression_simplification: false,
79                dead_code_elimination: false,
80                operation_fusion: false,
81                memory_optimization: false,
82                max_passes: 0,
83                level: self,
84            },
85            OptimizationLevel::Basic => OptimizationConfig {
86                constant_folding: true,
87                cse: false,
88                expression_simplification: false,
89                dead_code_elimination: true,
90                operation_fusion: false,
91                memory_optimization: false,
92                max_passes: 2,
93                level: self,
94            },
95            OptimizationLevel::Standard => OptimizationConfig::default(),
96            OptimizationLevel::Aggressive => OptimizationConfig {
97                constant_folding: true,
98                cse: true,
99                expression_simplification: true,
100                dead_code_elimination: true,
101                operation_fusion: true,
102                memory_optimization: true,
103                max_passes: 10,
104                level: self,
105            },
106        }
107    }
108}
109
/// Main graph optimizer.
///
/// Applies the passes selected by its `OptimizationConfig` to a computation
/// graph via `GraphOptimizer::optimize`.
pub struct GraphOptimizer<F: Float> {
    // Pass-selection flags and limits consulted by `optimize`.
    config: OptimizationConfig,
    // Ties the optimizer to the graph's float type without storing data.
    _phantom: std::marker::PhantomData<F>,
}
115
116impl<F: Float> GraphOptimizer<F> {
117    /// Create a new graph optimizer with default configuration
118    pub fn new() -> Self {
119        Self {
120            config: OptimizationConfig::default(),
121            _phantom: std::marker::PhantomData,
122        }
123    }
124
125    /// Create a new graph optimizer with custom configuration
126    pub fn with_config(config: OptimizationConfig) -> Self {
127        Self {
128            config,
129            _phantom: std::marker::PhantomData,
130        }
131    }
132
133    /// Create a new graph optimizer with specified optimization level
134    pub fn with_level(level: OptimizationLevel) -> Self {
135        Self {
136            config: level.config(),
137            _phantom: std::marker::PhantomData,
138        }
139    }
140
141    /// Optimize a computation graph
142    pub fn optimize(&self, graph: &Graph<F>) -> Result<OptimizationReport, OptimizationError> {
143        let mut report = OptimizationReport::new();
144
145        if self.config.level == OptimizationLevel::None {
146            return Ok(report);
147        }
148
149        for pass in 0..self.config.max_passes {
150            let mut changed = false;
151
152            // Constant folding
153            if self.config.constant_folding {
154                let folded = self.apply_constant_folding(graph)?;
155                if folded > 0 {
156                    changed = true;
157                    report.constant_folding_applied += folded;
158                }
159            }
160
161            // Dead code elimination
162            if self.config.dead_code_elimination {
163                let eliminated = self.apply_dead_code_elimination(graph)?;
164                if eliminated > 0 {
165                    changed = true;
166                    report.dead_nodes_eliminated += eliminated;
167                }
168            }
169
170            // Common subexpression elimination
171            if self.config.cse {
172                let eliminated = self.apply_cse(graph)?;
173                if eliminated > 0 {
174                    changed = true;
175                    report.cse_applied += eliminated;
176                }
177            }
178
179            // Expression simplification
180            if self.config.expression_simplification {
181                let simplified = self.apply_expression_simplification(graph)?;
182                if simplified > 0 {
183                    changed = true;
184                    report.expressions_simplified += simplified;
185                }
186            }
187
188            // Operation fusion
189            if self.config.operation_fusion {
190                let fused = self.apply_operation_fusion(graph)?;
191                if fused > 0 {
192                    changed = true;
193                    report.operations_fused += fused;
194                }
195            }
196
197            // Memory optimization
198            if self.config.memory_optimization {
199                let optimized = self.apply_memory_optimization(graph)?;
200                if optimized > 0 {
201                    changed = true;
202                    report.memory_optimizations += optimized;
203                }
204            }
205
206            report.passes_completed = pass + 1;
207
208            // If no changes were made, we can stop early
209            if !changed {
210                break;
211            }
212        }
213
214        Ok(report)
215    }
216
217    /// Apply constant folding optimization
218    fn apply_constant_folding(&self, _graph: &Graph<F>) -> Result<usize, OptimizationError> {
219        // Delegates to constant_folding module
220        Ok(0)
221    }
222
223    /// Apply dead code elimination.
224    ///
225    /// Identifies nodes that do not contribute to the primary output of the
226    /// graph.  The primary output is determined by topological rank: the node
227    /// (or nodes, if there are ties) with the highest `topo_rank` is treated as
228    /// the graph output.  All nodes reachable backwards from those outputs are
229    /// live; the remaining nodes are dead.
230    ///
231    /// The graph is immutable during this analysis (we only borrow the
232    /// `node_set`), but the returned count tells callers how many nodes
233    /// *could* be pruned in a subsequent mutable rewrite pass.
234    fn apply_dead_code_elimination(&self, graph: &Graph<F>) -> Result<usize, OptimizationError> {
235        let node_count = graph.node_set.borrow().len();
236        if node_count == 0 {
237            return Ok(0);
238        }
239
240        // ── Step 1: find the maximum topo_rank ──────────────────────────────
241        let max_topo_rank = {
242            let nodes = graph.node_set.borrow();
243            nodes.iter().map(|n| n.topo_rank).max().unwrap_or(0)
244        };
245
246        // ── Step 2: seed live set with primary output nodes ──────────────────
247        // Nodes at the max topo_rank are treated as "outputs" (the final result
248        // of the computation).  Nodes at a lower rank that have no consumers
249        // in the live set are dead.
250        let mut live: HashSet<TensorID> = HashSet::new();
251        let mut work_stack: Vec<TensorID> = Vec::new();
252
253        {
254            let nodes = graph.node_set.borrow();
255            for node in nodes.iter() {
256                if node.topo_rank == max_topo_rank && !live.contains(&node.id) {
257                    live.insert(node.id);
258                    work_stack.push(node.id);
259                }
260            }
261        }
262
263        if work_stack.is_empty() {
264            return Ok(0);
265        }
266
267        // ── Step 3: backward reachability traversal ──────────────────────────
268        while let Some(current_id) = work_stack.pop() {
269            let incoming_ids: Vec<TensorID> = {
270                let node = graph.access_inner(current_id);
271                node.incoming_nodes.iter().map(|n| n.id).collect()
272            };
273
274            for pred_id in incoming_ids {
275                if pred_id < node_count && !live.contains(&pred_id) {
276                    live.insert(pred_id);
277                    work_stack.push(pred_id);
278                }
279            }
280        }
281
282        // ── Step 4: count dead nodes ─────────────────────────────────────────
283        let dead_count = node_count.saturating_sub(live.len());
284        Ok(dead_count)
285    }
286
287    /// Apply common subexpression elimination (CSE).
288    ///
289    /// Two nodes are considered equivalent when they share the same operation
290    /// name AND the same (possibly sorted) list of input node IDs.  Commutative
291    /// binary ops (Add, Mul) have their inputs sorted so `add(a, b)` and
292    /// `add(b, a)` share one canonical entry.  Source nodes (no inputs) are
293    /// never candidates for elimination — they are semantically unique.
294    ///
295    /// Returns the number of duplicate nodes found (i.e. how many could be
296    /// replaced by earlier canonical computations in a mutable rewrite pass).
297    fn apply_cse(&self, graph: &Graph<F>) -> Result<usize, OptimizationError> {
298        let node_count = graph.node_set.borrow().len();
299        if node_count == 0 {
300            return Ok(0);
301        }
302
303        // Operations whose semantics are commutative in input order.
304        let commutative_ops: HashSet<&'static str> = ["AddOp", "MulOp", "Add", "Mul", "add", "mul"]
305            .iter()
306            .copied()
307            .collect();
308
309        // Process nodes in ascending topological order.
310        let mut order: Vec<TensorID> = (0..node_count).collect();
311        {
312            let nodes = graph.node_set.borrow();
313            order.sort_by_key(|&id| nodes[id].topo_rank);
314        }
315
316        // Key = (op_name, normalised input-id vector)
317        type CseKey = (String, Vec<TensorID>);
318        let mut seen: HashMap<CseKey, TensorID> = HashMap::new();
319        let mut eliminated = 0usize;
320
321        for node_id in order {
322            let (op_name, mut input_ids, is_source) = {
323                let node = graph.access_inner(node_id);
324                let op_name = node
325                    .op
326                    .as_ref()
327                    .map(|o| o.name().to_owned())
328                    .unwrap_or_default();
329                let input_ids: Vec<TensorID> = node.incoming_nodes.iter().map(|n| n.id).collect();
330                let is_source = node.incoming_nodes.is_empty();
331                (op_name, input_ids, is_source)
332            };
333
334            // Source nodes (variables, placeholders, constants) are unique.
335            if is_source {
336                continue;
337            }
338
339            // Normalise input order for commutative ops.
340            if commutative_ops.contains(op_name.as_str()) {
341                input_ids.sort_unstable();
342            }
343
344            let key: CseKey = (op_name, input_ids);
345            match seen.get(&key) {
346                Some(_canonical_id) => {
347                    // Duplicate: could redirect consumers of `node_id` to
348                    // `canonical_id` in a mutable pass.
349                    eliminated += 1;
350                }
351                None => {
352                    seen.insert(key, node_id);
353                }
354            }
355        }
356
357        Ok(eliminated)
358    }
359
360    /// Apply expression simplification
361    fn apply_expression_simplification(
362        &self,
363        _graph: &Graph<F>,
364    ) -> Result<usize, OptimizationError> {
365        // Delegates to expression_simplification module
366        Ok(0)
367    }
368
369    /// Apply operation fusion.
370    ///
371    /// Scans the computation graph for adjacent pairs (and triples) of
372    /// operations that can be merged into a single fused kernel:
373    ///
374    ///   • MatMul → BiasAdd                       (MatMulBias)
375    ///   • MatMul → Add/BiasAdd → Activation      (MatMulBiasActivation)
376    ///   • Conv2d → BatchNorm                     (ConvBN)
377    ///   • Conv2d → BatchNorm → Activation        (ConvBNActivation)
378    ///   • Any two consecutive element-wise ops   (ElementWise)
379    ///   • Sum → Div                              (SumDivToMean)
380    ///   • Square → Mean                         (SquareMeanToVariance)
381    ///   • Exp → Sum → Div                       (Softmax)
382    ///
383    /// Returns the number of fusion groups applied.
384    fn apply_operation_fusion(&self, graph: &Graph<F>) -> Result<usize, OptimizationError> {
385        let node_count = graph.node_set.borrow().len();
386        if node_count == 0 {
387            return Ok(0);
388        }
389
390        // Map op-name strings to the fusion module's `OpKind` enum.
391        let classify_op = |op_name: &str| -> fusion::patterns::OpKind {
392            use fusion::patterns::OpKind;
393            match op_name {
394                n if n.contains("MatMul") || n.contains("Matmul") || n == "matmul" => {
395                    OpKind::MatMul
396                }
397                n if n.contains("BiasAdd") || n == "bias_add" => OpKind::BiasAdd,
398                n if n.contains("Relu") || n == "relu" => OpKind::Relu,
399                n if n.contains("Gelu") || n == "gelu" => OpKind::Gelu,
400                n if n.contains("Sigmoid") || n == "sigmoid" => OpKind::Sigmoid,
401                n if n.contains("Tanh") || n == "tanh" => OpKind::Tanh,
402                n if n.contains("Swish") || n == "swish" => OpKind::Swish,
403                n if n.contains("Conv2d") || n.contains("Conv") || n == "conv2d" => OpKind::Conv2d,
404                n if n.contains("BatchNorm") || n.contains("batch_norm") => OpKind::BatchNorm,
405                n if n.contains("AddOp") || n == "Add" || n == "add" => OpKind::Add,
406                n if n.contains("SubOp") || n == "Sub" || n == "sub" => OpKind::Sub,
407                n if n.contains("MulOp") || n == "Mul" || n == "mul" => OpKind::Mul,
408                n if n.contains("DivOp") || n == "Div" || n == "div" => OpKind::Div,
409                n if n.contains("Neg") || n == "neg" => OpKind::Neg,
410                n if n.contains("Square") || n == "square" => OpKind::Square,
411                n if n.contains("Exp") || n == "exp" => OpKind::Exp,
412                n if n.contains("Log") || n == "log" => OpKind::Log,
413                n if n.contains("Sqrt") || n == "sqrt" => OpKind::Sqrt,
414                n if n.contains("Sum") || n == "sum" => OpKind::Sum,
415                n if n.contains("Mean") || n == "mean" => OpKind::Mean,
416                n if n.contains("Max") || n == "max" => OpKind::Max,
417                n if n.contains("Min") || n == "min" => OpKind::Min,
418                _ => OpKind::Custom(op_name.to_owned()),
419            }
420        };
421
422        // Build `GraphNode` descriptors from the live graph.
423        let mut graph_nodes: Vec<fusion::patterns::GraphNode> = Vec::with_capacity(node_count);
424        {
425            let nodes = graph.node_set.borrow();
426            for node in nodes.iter() {
427                let op_name = node
428                    .op
429                    .as_ref()
430                    .map(|o| o.name().to_owned())
431                    .unwrap_or_default();
432                let op_kind = classify_op(&op_name);
433                let inputs: Vec<usize> = node.incoming_nodes.iter().map(|n| n.id).collect();
434                let mut gn = fusion::patterns::GraphNode::new(node.id, op_kind, inputs, vec![]);
435                gn.num_consumers = 0;
436                graph_nodes.push(gn);
437            }
438        }
439
440        // Count consumers so the fusion engine knows which nodes are "interior"
441        // (single-consumer) and thus eligible as non-terminal fusion members.
442        for idx in 0..graph_nodes.len() {
443            let inputs: Vec<usize> = graph_nodes[idx].inputs.clone();
444            for &inp in &inputs {
445                if inp < graph_nodes.len() {
446                    graph_nodes[inp].num_consumers += 1;
447                }
448            }
449        }
450
451        // Detect and apply fusions via the dedicated FusionOptimizer.
452        let mut optimizer = fusion::FusionOptimizer::new();
453        optimizer
454            .detect_fusions_in_graph(&graph_nodes)
455            .map_err(|e| OptimizationError::GraphStructure(e.to_string()))?;
456
457        let fused_nodes = optimizer
458            .apply_fusions_with_nodes(&graph_nodes)
459            .map_err(|e| OptimizationError::GraphStructure(e.to_string()))?;
460
461        Ok(fused_nodes.len())
462    }
463
464    /// Apply memory optimization via lifetime-based buffer reuse analysis.
465    ///
466    /// For each node we compute:
467    ///   • `birth` — the node's own `topo_rank` (when its output is produced).
468    ///   • `death` — the maximum `topo_rank` among all consumers of this node
469    ///               (the last moment its output is needed).
470    ///
471    /// We then apply a greedy interval-graph colouring (scan in birth order,
472    /// reuse the first freed slot) and count how many nodes share a buffer slot
473    /// with an earlier node.  Each reuse is a potential memory saving.
474    fn apply_memory_optimization(&self, graph: &Graph<F>) -> Result<usize, OptimizationError> {
475        let node_count = graph.node_set.borrow().len();
476        if node_count == 0 {
477            return Ok(0);
478        }
479
480        // ── Collect topo_rank for every node ─────────────────────────────────
481        let topo_ranks: Vec<usize> = {
482            let nodes = graph.node_set.borrow();
483            nodes.iter().map(|n| n.topo_rank).collect()
484        };
485
486        let max_rank = topo_ranks.iter().copied().max().unwrap_or(0);
487
488        // ── Compute death time for each node ─────────────────────────────────
489        // death[id] starts at topo_rank[id] and is updated to the max
490        // topo_rank among all nodes that consume it.
491        let mut death: Vec<usize> = topo_ranks.clone();
492
493        {
494            let nodes = graph.node_set.borrow();
495            for node in nodes.iter() {
496                let consumer_rank = node.topo_rank;
497                for incoming in &node.incoming_nodes {
498                    let pred = incoming.id;
499                    if pred < node_count && consumer_rank > death[pred] {
500                        death[pred] = consumer_rank;
501                    }
502                }
503            }
504        }
505
506        // Nodes with no consumers (pure outputs) keep death == birth unless
507        // we set them to max_rank (live until end of graph).
508        {
509            let nodes = graph.node_set.borrow();
510            for id in 0..node_count {
511                let has_consumer = nodes
512                    .iter()
513                    .any(|n| n.incoming_nodes.iter().any(|inc| inc.id == id));
514                if !has_consumer {
515                    death[id] = max_rank;
516                }
517            }
518        }
519
520        // ── Greedy interval-graph colouring ───────────────────────────────────
521        // Sort intervals by birth time (ascending).
522        let mut intervals: Vec<(usize, usize, TensorID)> = (0..node_count)
523            .map(|id| (topo_ranks[id], death[id], id))
524            .collect();
525        intervals.sort_by_key(|&(birth, _, _)| birth);
526
527        // `active_slots[i]` = death time of the last tensor assigned to slot i.
528        let mut active_slots: Vec<usize> = Vec::new();
529        let mut reuse_count = 0usize;
530
531        for (birth, end, _node_id) in &intervals {
532            // Find the first slot whose current occupant has already died.
533            let released = active_slots
534                .iter()
535                .enumerate()
536                .find(|(_, &slot_death)| slot_death < *birth)
537                .map(|(idx, _)| idx);
538
539            match released {
540                Some(slot_idx) => {
541                    active_slots[slot_idx] = *end;
542                    reuse_count += 1;
543                }
544                None => {
545                    active_slots.push(*end);
546                }
547            }
548        }
549
550        Ok(reuse_count)
551    }
552}
553
554impl<F: Float> Default for GraphOptimizer<F> {
555    fn default() -> Self {
556        Self::new()
557    }
558}
559
/// Report of optimization results.
///
/// Accumulates per-pass counters while `GraphOptimizer::optimize` runs.
#[derive(Debug, Clone, Default)]
pub struct OptimizationReport {
    /// Number of optimization passes completed
    pub passes_completed: usize,
    /// Number of constant folding optimizations applied
    pub constant_folding_applied: usize,
    /// Number of dead nodes eliminated
    pub dead_nodes_eliminated: usize,
    /// Number of common subexpressions eliminated
    pub cse_applied: usize,
    /// Number of expressions simplified
    pub expressions_simplified: usize,
    /// Number of operations fused
    pub operations_fused: usize,
    /// Number of memory optimizations applied
    pub memory_optimizations: usize,
}

impl OptimizationReport {
    /// Create a new empty optimization report.
    pub fn new() -> Self {
        Self::default()
    }

    /// Total number of individual optimizations applied across all passes
    /// (`passes_completed` itself is not an optimization and is excluded).
    pub fn total_optimizations(&self) -> usize {
        [
            self.constant_folding_applied,
            self.dead_nodes_eliminated,
            self.cse_applied,
            self.expressions_simplified,
            self.operations_fused,
            self.memory_optimizations,
        ]
        .iter()
        .sum()
    }

    /// True when at least one optimization was applied.
    pub fn has_optimizations(&self) -> bool {
        self.total_optimizations() > 0
    }

    /// Print a human-readable summary of the optimization results to stdout.
    /// Counters that are zero are omitted from the per-pass detail lines.
    pub fn print_summary(&self) {
        println!("Optimization Report:");
        println!("==================");
        println!("Passes completed: {}", self.passes_completed);
        println!("Total optimizations: {}", self.total_optimizations());

        // One detail line per non-zero counter.
        let detail = |label: &str, count: usize| {
            if count > 0 {
                println!("  {}: {}", label, count);
            }
        };
        detail("Constant folding", self.constant_folding_applied);
        detail("Dead code elimination", self.dead_nodes_eliminated);
        detail("Common subexpression elimination", self.cse_applied);
        detail("Expression simplification", self.expressions_simplified);
        detail("Operation fusion", self.operations_fused);
        detail("Memory optimizations", self.memory_optimizations);
    }
}
630
/// Expression pattern matcher for optimization.
///
/// All matching methods are currently stubs; see the individual method
/// comments for the analyses they are placeholders for.
pub struct PatternMatcher<F: Float> {
    // Ties the matcher to the graph's float type without storing data.
    _phantom: std::marker::PhantomData<F>,
}
635
636impl<F: Float> PatternMatcher<F> {
637    /// Create a new pattern matcher
638    pub fn new() -> Self {
639        Self {
640            _phantom: std::marker::PhantomData,
641        }
642    }
643
644    /// Check if a tensor matches a pattern for simplification
645    #[allow(dead_code)]
646    pub(crate) fn matches_simplification_pattern(
647        &self,
648        _tensor_internal: &TensorInternal<F>,
649    ) -> Option<SimplificationPattern> {
650        // Temporarily disabled - would be implemented with expression_simplification module
651        None
652    }
653
654    /// Check if tensors can be fused
655    #[allow(dead_code)]
656    pub(crate) fn can_fuse(
657        &self,
658        _tensor1: &TensorInternal<F>,
659        _tensor2: &TensorInternal<F>,
660    ) -> bool {
661        // Temporarily disabled - would be implemented with fusion analysis
662        false
663    }
664
665    /// Check if a tensor represents a constant
666    #[allow(dead_code)]
667    pub(crate) fn is_constant(&self, _tensorinternal: &TensorInternal<F>) -> bool {
668        // Temporarily disabled - would be implemented with constant analysis
669        false
670    }
671
672    /// Check if a tensor is dead (unreachable from outputs)
673    #[allow(dead_code)]
674    pub(crate) fn is_dead(
675        &self,
676        _tensor_internal: &TensorInternal<F>,
677        _reachable: &HashSet<TensorID>,
678    ) -> bool {
679        // Temporarily disabled - would be implemented with reachability analysis
680        false
681    }
682}
683
684impl<F: Float> Default for PatternMatcher<F> {
685    fn default() -> Self {
686        Self::new()
687    }
688}
689
/// Types of simplification patterns.
///
/// Each variant names an algebraic identity that rewrites an expression to a
/// cheaper equivalent form (shown as `input → result`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SimplificationPattern {
    /// x + 0 → x
    AddZero,
    /// x - 0 → x
    SubZero,
    /// x * 1 → x
    MulOne,
    /// x / 1 → x
    DivOne,
    /// x * 0 → 0
    MulZero,
    /// x - x → 0
    SubSelf,
    /// x / x → 1
    DivSelf,
    /// log(exp(x)) → x
    LogExp,
    /// exp(log(x)) → x
    ExpLog,
    /// sqrt(x^2) → abs(x)
    SqrtSquare,
    /// pow(x, 1) → x
    PowOne,
    /// pow(x, 0) → 1
    PowZero,
}
718
/// Optimization pass manager.
///
/// A named unit of optimization work; `run` is a hook for concrete passes.
pub struct OptimizationPass<F: Float> {
    // Human-readable pass name, returned by `name()`.
    name: String,
    // Ties the pass to the graph's float type without storing data.
    _phantom: std::marker::PhantomData<F>,
}
724
725impl<F: Float> OptimizationPass<F> {
726    /// Create a new optimization pass
727    pub fn new(name: &str) -> Self {
728        Self {
729            name: name.to_string(),
730            _phantom: std::marker::PhantomData,
731        }
732    }
733
734    /// Get the name of this pass
735    pub fn name(&self) -> &str {
736        &self.name
737    }
738
739    /// Run this optimization pass on a graph
740    pub fn run(&self, _graph: &Graph<F>) -> Result<usize, OptimizationError> {
741        // Each pass would implement its specific optimization logic
742        Ok(0)
743    }
744}
745
/// Errors that can occur during optimization.
#[derive(Debug, thiserror::Error)]
pub enum OptimizationError {
    /// The graph was malformed or violated a structural invariant.
    #[error("Graph structure error: {0}")]
    GraphStructure(String),
    /// A simplification or fusion pattern failed to match as expected.
    #[error("Pattern matching error: {0}")]
    PatternMatching(String),
    /// Two optimizations produced conflicting rewrites.
    #[error("Optimization conflict: {0}")]
    Conflict(String),
    /// An operation was invalid in the current graph context.
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),
}
758
759/// Public API functions for graph optimization
760/// Optimize a computation graph with default settings
761#[allow(dead_code)]
762pub fn optimize_graph<F: Float>(graph: &Graph<F>) -> Result<OptimizationReport, OptimizationError> {
763    let optimizer = GraphOptimizer::new();
764    optimizer.optimize(graph)
765}
766
767/// Optimize a computation graph with specified optimization level
768#[allow(dead_code)]
769pub fn optimize_graph_with_level<F: Float>(
770    graph: &Graph<F>,
771    level: OptimizationLevel,
772) -> Result<OptimizationReport, OptimizationError> {
773    let optimizer = GraphOptimizer::with_level(level);
774    optimizer.optimize(graph)
775}
776
777/// Optimize a computation graph with custom configuration
778#[allow(dead_code)]
779pub fn optimize_graph_with_config<F: Float>(
780    graph: &Graph<F>,
781    config: OptimizationConfig,
782) -> Result<OptimizationReport, OptimizationError> {
783    let optimizer = GraphOptimizer::with_config(config);
784    optimizer.optimize(graph)
785}
786
787/// Apply only constant folding optimization
788#[allow(dead_code)]
789pub fn apply_constant_folding<F: Float>(graph: &Graph<F>) -> Result<usize, OptimizationError> {
790    let config = OptimizationConfig {
791        constant_folding: true,
792        cse: false,
793        expression_simplification: false,
794        dead_code_elimination: false,
795        operation_fusion: false,
796        memory_optimization: false,
797        max_passes: 1,
798        level: OptimizationLevel::Basic,
799    };
800    let optimizer = GraphOptimizer::with_config(config);
801    let report = optimizer.optimize(graph)?;
802    Ok(report.constant_folding_applied)
803}
804
805/// Apply only dead code elimination
806#[allow(dead_code)]
807pub fn apply_dead_code_elimination<F: Float>(graph: &Graph<F>) -> Result<usize, OptimizationError> {
808    let config = OptimizationConfig {
809        constant_folding: false,
810        cse: false,
811        expression_simplification: false,
812        dead_code_elimination: true,
813        operation_fusion: false,
814        memory_optimization: false,
815        max_passes: 1,
816        level: OptimizationLevel::Basic,
817    };
818    let optimizer = GraphOptimizer::with_config(config);
819    let report = optimizer.optimize(graph)?;
820    Ok(report.dead_nodes_eliminated)
821}
822
823/// Apply common subexpression elimination
824#[allow(dead_code)]
825pub fn apply_cse<F: Float>(graph: &Graph<F>) -> Result<usize, OptimizationError> {
826    let config = OptimizationConfig {
827        constant_folding: false,
828        cse: true,
829        expression_simplification: false,
830        dead_code_elimination: false,
831        operation_fusion: false,
832        memory_optimization: false,
833        max_passes: 1,
834        level: OptimizationLevel::Standard,
835    };
836    let optimizer = GraphOptimizer::with_config(config);
837    let report = optimizer.optimize(graph)?;
838    Ok(report.cse_applied)
839}
840
841// Re-export types from the now-enabled submodules for convenience
842pub use constant_folding::ConstantFolder;
843pub use expression_simplification::ExpressionSimplifier;
844
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::AsGraph;

    /// The default configuration enables the standard passes and runs at
    /// most five iterations.
    #[test]
    fn test_optimization_config() {
        let defaults = OptimizationConfig::default();
        assert!(defaults.constant_folding);
        assert!(defaults.cse);
        assert!(defaults.expression_simplification);
        assert!(defaults.dead_code_elimination);
        assert_eq!(defaults.max_passes, 5);
    }

    /// `None` disables everything; `Aggressive` turns on fusion and memory
    /// optimization with a larger pass budget.
    #[test]
    fn test_optimization_levels() {
        let disabled = OptimizationLevel::None.config();
        assert!(!disabled.constant_folding);
        assert_eq!(disabled.max_passes, 0);

        let aggressive = OptimizationLevel::Aggressive.config();
        assert!(aggressive.operation_fusion);
        assert!(aggressive.memory_optimization);
        assert_eq!(aggressive.max_passes, 10);
    }

    /// All three constructors should build without panicking.
    #[test]
    fn test_graph_optimizer_creation() {
        let _default = GraphOptimizer::<f32>::new();
        let _configured = GraphOptimizer::<f32>::with_config(OptimizationConfig::default());
        let _leveled = GraphOptimizer::<f32>::with_level(OptimizationLevel::Aggressive);
    }

    /// A fresh report is empty; individual counters sum into the total.
    #[test]
    fn test_optimization_report() {
        let mut report = OptimizationReport::new();
        assert_eq!(report.total_optimizations(), 0);
        assert!(!report.has_optimizations());

        report.constant_folding_applied = 5;
        report.dead_nodes_eliminated = 3;
        assert!(report.has_optimizations());
        assert_eq!(report.total_optimizations(), 8);
    }

    #[test]
    fn test_pattern_matcher() {
        // Construction alone must not panic.
        let _matcher = PatternMatcher::<f32>::new();
    }

    #[test]
    fn test_simplification_patterns() {
        let pattern = SimplificationPattern::AddZero;
        assert_eq!(pattern, SimplificationPattern::AddZero);

        let known = [
            SimplificationPattern::AddZero,
            SimplificationPattern::MulOne,
            SimplificationPattern::LogExp,
        ];
        assert_eq!(known.len(), 3);
    }

    #[test]
    fn test_optimization_pass() {
        let pass = OptimizationPass::<f32>::new("test_pass");
        assert_eq!(pass.name(), "test_pass");
    }

    // ── Integration tests against real computation graphs ──────────────────

    /// DCE: a node that is not on the path to the output should be counted as dead.
    #[test]
    fn test_dce_on_real_graph() {
        use crate::tensor_ops as ops;
        use crate::VariableEnvironment;

        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            // `_unused` is never consumed by any other node; the final
            // multiply acts as the graph output (highest topo_rank).
            let base = ops::zeros(&[2, 2], ctx);
            let _unused = ops::ones(&[2, 2], ctx); // dead — never consumed
            let _out = ops::mul(base, ops::ones(&[2, 2], ctx));

            let dce_only = OptimizationConfig {
                dead_code_elimination: true,
                constant_folding: false,
                cse: false,
                expression_simplification: false,
                operation_fusion: false,
                memory_optimization: false,
                max_passes: 1,
                level: OptimizationLevel::Basic,
            };

            let report = GraphOptimizer::<f32>::with_config(dce_only)
                .optimize(ctx.as_graph())
                .expect("DCE should succeed");

            assert!(
                report.dead_nodes_eliminated >= 1,
                "Expected at least 1 dead node, got {}",
                report.dead_nodes_eliminated
            );
        });
    }

    /// CSE: two identical `add(a, b)` nodes should result in one elimination.
    #[test]
    fn test_cse_on_real_graph() {
        use crate::tensor_ops as ops;
        use crate::VariableEnvironment;

        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let lhs = ops::zeros(&[2, 2], ctx);
            let rhs = ops::ones(&[2, 2], ctx);
            // Same op name with identical inputs, built twice.
            let sum_a = ops::add(lhs, rhs);
            let sum_b = ops::add(lhs, rhs);
            // Consume both so neither is dead for DCE purposes.
            let _ = ops::add(sum_a, sum_b);

            let cse_only = OptimizationConfig {
                cse: true,
                constant_folding: false,
                expression_simplification: false,
                dead_code_elimination: false,
                operation_fusion: false,
                memory_optimization: false,
                max_passes: 1,
                level: OptimizationLevel::Standard,
            };

            let report = GraphOptimizer::<f32>::with_config(cse_only)
                .optimize(ctx.as_graph())
                .expect("CSE should succeed");

            assert!(
                report.cse_applied >= 1,
                "Expected >= 1 CSE elimination, got {}",
                report.cse_applied
            );
        });
    }

    /// Memory optimisation: in a linear chain, at least one buffer reuse should
    /// be detected.
    #[test]
    fn test_memory_opt_on_real_graph() {
        use crate::tensor_ops as ops;
        use crate::VariableEnvironment;

        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            // A straight-line chain of elementwise ops.
            let start = ops::zeros(&[4, 4], ctx);
            let step1 = ops::mul(start, ops::ones(&[4, 4], ctx));
            let step2 = ops::add(step1, ops::ones(&[4, 4], ctx));
            let _end = ops::mul(step2, ops::ones(&[4, 4], ctx));

            let memory_only = OptimizationConfig {
                memory_optimization: true,
                constant_folding: false,
                cse: false,
                expression_simplification: false,
                dead_code_elimination: false,
                operation_fusion: false,
                max_passes: 1,
                level: OptimizationLevel::Standard,
            };

            let report = GraphOptimizer::<f32>::with_config(memory_only)
                .optimize(ctx.as_graph())
                .expect("Memory opt should succeed");

            assert!(
                report.memory_optimizations >= 1,
                "Expected >= 1 memory reuse opportunity, got {}",
                report.memory_optimizations
            );
        });
    }

    /// An empty graph should produce zero optimisations and not panic.
    #[test]
    fn test_empty_graph_all_passes() {
        use crate::VariableEnvironment;

        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let report = GraphOptimizer::<f32>::new()
                .optimize(ctx.as_graph())
                .expect("Empty graph OK");
            assert_eq!(report.dead_nodes_eliminated, 0);
            assert_eq!(report.cse_applied, 0);
            assert_eq!(report.memory_optimizations, 0);
        });
    }
}