1use anyhow::Result;
10use std::str::FromStr;
11use tensorlogic_compiler::passes::{
12 optimize_einsum_graph as compiler_optimize_graph, EinsumOptResult,
13};
14use tensorlogic_ir::EinsumGraph;
15
16use crate::output::{print_info, print_success};
17
/// Counters describing what the einsum-graph optimizer changed.
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// Identity operations eliminated (the compiler's `identity_eliminated`).
    pub identity_simplifications: usize,
    /// Einsum operations merged (the compiler's `merged_count`).
    pub merged_einsums: usize,
    /// Operations reordered (the compiler's `reordered_count`).
    pub reordered_ops: usize,
    /// Estimated speedup factor reported by the optimizer; only values greater
    /// than 1.0 are treated as an improvement when accumulated or printed.
    ///
    /// Note: the derived `Default` sets this to 0.0, not the neutral factor 1.0.
    pub estimated_speedup: f64,
}
30
31impl From<EinsumOptResult> for OptimizationStats {
32 fn from(result: EinsumOptResult) -> Self {
33 Self {
34 identity_simplifications: result.identity_eliminated,
35 merged_einsums: result.merged_count,
36 reordered_ops: result.reordered_count,
37 estimated_speedup: result.estimated_speedup,
38 }
39 }
40}
41
42fn optimize_graph_internal(graph: &mut EinsumGraph) -> OptimizationStats {
44 match compiler_optimize_graph(graph) {
45 Ok(result) => result.into(),
46 Err(_) => OptimizationStats::default(),
47 }
48}
49
/// How much effort the optimizer spends on a graph.
///
/// Each level maps to a maximum number of optimizer passes (see
/// [`OptimizationLevel::num_passes`]). The default is
/// [`OptimizationLevel::Standard`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptimizationLevel {
    /// Skip optimization entirely (zero passes).
    None,
    /// A single optimizer pass.
    Basic,
    /// Two optimizer passes; the default level.
    #[default]
    Standard,
    /// Repeated passes, capped at a fixed upper bound.
    Aggressive,
}

impl OptimizationLevel {
    /// Upper bound on passes for [`OptimizationLevel::Aggressive`].
    const AGGRESSIVE_PASS_LIMIT: usize = 10;

    /// Returns the maximum number of optimizer passes this level allows.
    pub fn num_passes(&self) -> usize {
        match *self {
            Self::None => 0,
            Self::Basic => 1,
            Self::Standard => 2,
            Self::Aggressive => Self::AGGRESSIVE_PASS_LIMIT,
        }
    }

    /// Returns a short human-readable label for this level, used in CLI output.
    pub fn description(&self) -> &'static str {
        match *self {
            Self::None => "No optimizations",
            Self::Basic => "Basic (1 pass: DCE + CSE)",
            Self::Standard => "Standard (2 passes: DCE + CSE + Identity)",
            Self::Aggressive => "Aggressive (until convergence)",
        }
    }
}
85
86impl FromStr for OptimizationLevel {
88 type Err = anyhow::Error;
89
90 fn from_str(s: &str) -> Result<Self> {
91 match s.to_lowercase().as_str() {
92 "none" | "0" => Ok(OptimizationLevel::None),
93 "basic" | "1" => Ok(OptimizationLevel::Basic),
94 "standard" | "2" => Ok(OptimizationLevel::Standard),
95 "aggressive" | "3" => Ok(OptimizationLevel::Aggressive),
96 _ => anyhow::bail!("Unknown optimization level: {}", s),
97 }
98 }
99}
100
101#[derive(Debug, Clone)]
105pub struct OptimizationConfig {
106 pub level: OptimizationLevel,
108 #[allow(dead_code)]
110 pub enable_dce: bool,
111 #[allow(dead_code)]
113 pub enable_cse: bool,
114 #[allow(dead_code)]
116 pub enable_identity: bool,
117 pub show_stats: bool,
119 pub verbose: bool,
121}
122
123impl Default for OptimizationConfig {
124 fn default() -> Self {
125 Self {
126 level: OptimizationLevel::default(),
127 enable_dce: true,
128 enable_cse: true,
129 enable_identity: true,
130 show_stats: false,
131 verbose: false,
132 }
133 }
134}
135
136pub fn optimize_einsum_graph(
138 mut graph: EinsumGraph,
139 config: &OptimizationConfig,
140) -> Result<(EinsumGraph, OptimizationStats)> {
141 if config.level == OptimizationLevel::None {
142 if config.verbose {
143 print_info("Skipping optimizations (level: None)");
144 }
145 return Ok((graph, OptimizationStats::default()));
146 }
147
148 let num_passes = config.level.num_passes();
149 let mut total_stats = OptimizationStats::default();
150
151 if config.verbose {
152 print_info(&format!(
153 "Applying {} ({})",
154 config.level.description(),
155 num_passes
156 ));
157 println!(
158 " Initial: {} nodes, {} tensors",
159 graph.nodes.len(),
160 graph.tensors.len()
161 );
162 }
163
164 for pass in 0..num_passes {
165 let before_nodes = graph.nodes.len();
166 let before_tensors = graph.tensors.len();
167
168 let stats = optimize_graph_internal(&mut graph);
170
171 if stats.identity_simplifications == 0
173 && stats.merged_einsums == 0
174 && stats.reordered_ops == 0
175 {
176 if config.verbose {
177 println!(" Converged after {} passes", pass + 1);
178 }
179 break;
180 }
181
182 total_stats.identity_simplifications += stats.identity_simplifications;
184 total_stats.merged_einsums += stats.merged_einsums;
185 total_stats.reordered_ops += stats.reordered_ops;
186 if stats.estimated_speedup > 1.0 {
187 total_stats.estimated_speedup *= stats.estimated_speedup;
188 }
189
190 if config.verbose {
191 println!(
192 " Pass {}: {} → {} nodes, {} → {} tensors",
193 pass + 1,
194 before_nodes,
195 graph.nodes.len(),
196 before_tensors,
197 graph.tensors.len()
198 );
199 }
200 }
201
202 if config.show_stats || config.verbose {
203 print_optimization_stats(&total_stats);
204 }
205
206 let total_improvements = total_stats.identity_simplifications
207 + total_stats.merged_einsums
208 + total_stats.reordered_ops;
209
210 if total_improvements > 0 {
211 print_success(&format!(
212 "Optimization complete: {} identities removed, {} einsums merged, {} reordered",
213 total_stats.identity_simplifications,
214 total_stats.merged_einsums,
215 total_stats.reordered_ops
216 ));
217 } else if config.verbose {
218 print_info("No optimizations applied (graph already optimal)");
219 }
220
221 Ok((graph, total_stats))
222}
223
224fn print_optimization_stats(stats: &OptimizationStats) {
226 println!("\nOptimization Statistics:");
227 println!(
228 " Identity operations eliminated: {}",
229 stats.identity_simplifications
230 );
231 println!(" Einsum operations merged: {}", stats.merged_einsums);
232 println!(" Operations reordered: {}", stats.reordered_ops);
233
234 let total = stats.identity_simplifications + stats.merged_einsums + stats.reordered_ops;
235 if total > 0 {
236 println!(" Total improvements: {}", total);
237 if stats.estimated_speedup > 1.0 {
238 println!(" Estimated speedup: {:.2}x", stats.estimated_speedup);
239 }
240 }
241}
242
243#[allow(dead_code)]
245pub fn list_optimization_levels() {
246 println!("Optimization Levels:");
247 println!();
248
249 for level in &[
250 OptimizationLevel::None,
251 OptimizationLevel::Basic,
252 OptimizationLevel::Standard,
253 OptimizationLevel::Aggressive,
254 ] {
255 println!(" {:?}: {}", level, level.description());
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Every accepted spelling parses, matching ignores case, and unknown
    /// inputs are rejected.
    #[test]
    fn test_optimization_level_from_str() {
        let cases = [
            ("none", OptimizationLevel::None),
            ("0", OptimizationLevel::None),
            ("basic", OptimizationLevel::Basic),
            ("1", OptimizationLevel::Basic),
            ("standard", OptimizationLevel::Standard),
            ("2", OptimizationLevel::Standard),
            ("aggressive", OptimizationLevel::Aggressive),
            ("3", OptimizationLevel::Aggressive),
            // Input is lower-cased before matching.
            ("BASIC", OptimizationLevel::Basic),
            ("Aggressive", OptimizationLevel::Aggressive),
        ];
        for (input, expected) in cases {
            assert_eq!(
                OptimizationLevel::from_str(input).unwrap(),
                expected,
                "parsing {:?}",
                input
            );
        }

        assert!(OptimizationLevel::from_str("invalid").is_err());
        assert!(OptimizationLevel::from_str("4").is_err());
    }

    #[test]
    fn test_optimization_level_num_passes() {
        assert_eq!(OptimizationLevel::None.num_passes(), 0);
        assert_eq!(OptimizationLevel::Basic.num_passes(), 1);
        assert_eq!(OptimizationLevel::Standard.num_passes(), 2);
        assert_eq!(OptimizationLevel::Aggressive.num_passes(), 10);
    }

    #[test]
    fn test_optimization_config_default() {
        let config = OptimizationConfig::default();
        assert_eq!(config.level, OptimizationLevel::Standard);
        assert!(config.enable_dce);
        assert!(config.enable_cse);
        assert!(config.enable_identity);
        assert!(!config.show_stats);
        assert!(!config.verbose);
    }
}