// tensorlogic_cli/optimize.rs

//! Graph optimization pipeline for CLI
//!
//! Provides IR-level optimizations:
//! - Identity operation elimination
//! - Einsum operation merging
//! - Contraction order optimization
//! - Multi-pass optimization

9use anyhow::Result;
10use std::str::FromStr;
11use tensorlogic_compiler::passes::{
12    optimize_einsum_graph as compiler_optimize_graph, EinsumOptResult,
13};
14use tensorlogic_ir::EinsumGraph;
15
16use crate::output::{print_info, print_success};
17
18/// Optimization statistics
19#[derive(Debug, Clone, Default)]
20pub struct OptimizationStats {
21    /// Number of identity operations simplified
22    pub identity_simplifications: usize,
23    /// Number of einsum operations merged
24    pub merged_einsums: usize,
25    /// Number of operations reordered
26    pub reordered_ops: usize,
27    /// Estimated speedup factor
28    pub estimated_speedup: f64,
29}
30
31impl From<EinsumOptResult> for OptimizationStats {
32    fn from(result: EinsumOptResult) -> Self {
33        Self {
34            identity_simplifications: result.identity_eliminated,
35            merged_einsums: result.merged_count,
36            reordered_ops: result.reordered_count,
37            estimated_speedup: result.estimated_speedup,
38        }
39    }
40}
41
42/// Optimize graph using tensorlogic-compiler's optimization passes
43fn optimize_graph_internal(graph: &mut EinsumGraph) -> OptimizationStats {
44    match compiler_optimize_graph(graph) {
45        Ok(result) => result.into(),
46        Err(_) => OptimizationStats::default(),
47    }
48}
49
50/// Optimization level
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
52pub enum OptimizationLevel {
53    /// No optimizations
54    None,
55    /// Basic optimizations (1 pass)
56    Basic,
57    /// Standard optimizations (2 passes)
58    #[default]
59    Standard,
60    /// Aggressive optimizations (until convergence)
61    Aggressive,
62}
63
64impl OptimizationLevel {
65    /// Get number of optimization passes
66    pub fn num_passes(&self) -> usize {
67        match self {
68            OptimizationLevel::None => 0,
69            OptimizationLevel::Basic => 1,
70            OptimizationLevel::Standard => 2,
71            OptimizationLevel::Aggressive => 10, // Until convergence
72        }
73    }
74
75    /// Get description
76    pub fn description(&self) -> &'static str {
77        match self {
78            OptimizationLevel::None => "No optimizations",
79            OptimizationLevel::Basic => "Basic (1 pass: DCE + CSE)",
80            OptimizationLevel::Standard => "Standard (2 passes: DCE + CSE + Identity)",
81            OptimizationLevel::Aggressive => "Aggressive (until convergence)",
82        }
83    }
84}
85
86// Implement FromStr trait
87impl FromStr for OptimizationLevel {
88    type Err = anyhow::Error;
89
90    fn from_str(s: &str) -> Result<Self> {
91        match s.to_lowercase().as_str() {
92            "none" | "0" => Ok(OptimizationLevel::None),
93            "basic" | "1" => Ok(OptimizationLevel::Basic),
94            "standard" | "2" => Ok(OptimizationLevel::Standard),
95            "aggressive" | "3" => Ok(OptimizationLevel::Aggressive),
96            _ => anyhow::bail!("Unknown optimization level: {}", s),
97        }
98    }
99}
100
// Default implementation is now derived with the #[default] attribute.
102
103/// Optimization configuration
104#[derive(Debug, Clone)]
105pub struct OptimizationConfig {
106    /// Optimization level
107    pub level: OptimizationLevel,
108    /// Enable Dead Code Elimination (reserved for future use)
109    #[allow(dead_code)]
110    pub enable_dce: bool,
111    /// Enable Common Subexpression Elimination (reserved for future use)
112    #[allow(dead_code)]
113    pub enable_cse: bool,
114    /// Enable identity simplification (reserved for future use)
115    #[allow(dead_code)]
116    pub enable_identity: bool,
117    /// Show optimization statistics
118    pub show_stats: bool,
119    /// Verbose output
120    pub verbose: bool,
121}
122
123impl Default for OptimizationConfig {
124    fn default() -> Self {
125        Self {
126            level: OptimizationLevel::default(),
127            enable_dce: true,
128            enable_cse: true,
129            enable_identity: true,
130            show_stats: false,
131            verbose: false,
132        }
133    }
134}
135
136/// Optimize einsum graph with configuration
137pub fn optimize_einsum_graph(
138    mut graph: EinsumGraph,
139    config: &OptimizationConfig,
140) -> Result<(EinsumGraph, OptimizationStats)> {
141    if config.level == OptimizationLevel::None {
142        if config.verbose {
143            print_info("Skipping optimizations (level: None)");
144        }
145        return Ok((graph, OptimizationStats::default()));
146    }
147
148    let num_passes = config.level.num_passes();
149    let mut total_stats = OptimizationStats::default();
150
151    if config.verbose {
152        print_info(&format!(
153            "Applying {} ({})",
154            config.level.description(),
155            num_passes
156        ));
157        println!(
158            "  Initial: {} nodes, {} tensors",
159            graph.nodes.len(),
160            graph.tensors.len()
161        );
162    }
163
164    for pass in 0..num_passes {
165        let before_nodes = graph.nodes.len();
166        let before_tensors = graph.tensors.len();
167
168        // Apply optimization
169        let stats = optimize_graph_internal(&mut graph);
170
171        // Check for convergence
172        if stats.identity_simplifications == 0
173            && stats.merged_einsums == 0
174            && stats.reordered_ops == 0
175        {
176            if config.verbose {
177                println!("  Converged after {} passes", pass + 1);
178            }
179            break;
180        }
181
182        // Accumulate stats
183        total_stats.identity_simplifications += stats.identity_simplifications;
184        total_stats.merged_einsums += stats.merged_einsums;
185        total_stats.reordered_ops += stats.reordered_ops;
186        if stats.estimated_speedup > 1.0 {
187            total_stats.estimated_speedup *= stats.estimated_speedup;
188        }
189
190        if config.verbose {
191            println!(
192                "  Pass {}: {} → {} nodes, {} → {} tensors",
193                pass + 1,
194                before_nodes,
195                graph.nodes.len(),
196                before_tensors,
197                graph.tensors.len()
198            );
199        }
200    }
201
202    if config.show_stats || config.verbose {
203        print_optimization_stats(&total_stats);
204    }
205
206    let total_improvements = total_stats.identity_simplifications
207        + total_stats.merged_einsums
208        + total_stats.reordered_ops;
209
210    if total_improvements > 0 {
211        print_success(&format!(
212            "Optimization complete: {} identities removed, {} einsums merged, {} reordered",
213            total_stats.identity_simplifications,
214            total_stats.merged_einsums,
215            total_stats.reordered_ops
216        ));
217    } else if config.verbose {
218        print_info("No optimizations applied (graph already optimal)");
219    }
220
221    Ok((graph, total_stats))
222}
223
224/// Print optimization statistics
225fn print_optimization_stats(stats: &OptimizationStats) {
226    println!("\nOptimization Statistics:");
227    println!(
228        "  Identity operations eliminated: {}",
229        stats.identity_simplifications
230    );
231    println!("  Einsum operations merged: {}", stats.merged_einsums);
232    println!("  Operations reordered: {}", stats.reordered_ops);
233
234    let total = stats.identity_simplifications + stats.merged_einsums + stats.reordered_ops;
235    if total > 0 {
236        println!("  Total improvements: {}", total);
237        if stats.estimated_speedup > 1.0 {
238            println!("  Estimated speedup: {:.2}x", stats.estimated_speedup);
239        }
240    }
241}
242
243/// List available optimization levels (reserved for future use)
244#[allow(dead_code)]
245pub fn list_optimization_levels() {
246    println!("Optimization Levels:");
247    println!();
248
249    for level in &[
250        OptimizationLevel::None,
251        OptimizationLevel::Basic,
252        OptimizationLevel::Standard,
253        OptimizationLevel::Aggressive,
254    ] {
255        println!("  {:?}: {}", level, level.description());
256    }
257}
258
259#[cfg(test)]
260mod tests {
261    use super::*;
262
263    #[test]
264    fn test_optimization_level_from_str() {
265        assert_eq!(
266            OptimizationLevel::from_str("none").unwrap(),
267            OptimizationLevel::None
268        );
269        assert_eq!(
270            OptimizationLevel::from_str("basic").unwrap(),
271            OptimizationLevel::Basic
272        );
273        assert_eq!(
274            OptimizationLevel::from_str("2").unwrap(),
275            OptimizationLevel::Standard
276        );
277        assert!(OptimizationLevel::from_str("invalid").is_err());
278    }
279
280    #[test]
281    fn test_optimization_level_num_passes() {
282        assert_eq!(OptimizationLevel::None.num_passes(), 0);
283        assert_eq!(OptimizationLevel::Basic.num_passes(), 1);
284        assert_eq!(OptimizationLevel::Standard.num_passes(), 2);
285        assert_eq!(OptimizationLevel::Aggressive.num_passes(), 10);
286    }
287
288    #[test]
289    fn test_optimization_config_default() {
290        let config = OptimizationConfig::default();
291        assert_eq!(config.level, OptimizationLevel::Standard);
292        assert!(config.enable_dce);
293        assert!(config.enable_cse);
294        assert!(config.enable_identity);
295    }
296}