tensorlogic_compiler/passes/graph_opt_integration.rs

//! Integration pass for IR-level graph optimizations.
//!
//! This module provides compiler-level integration with the graph
//! optimization passes available in tensorlogic-ir. It applies fusion,
//! layout optimization, memory optimization, and pattern-based
//! transformations to compiled EinsumGraphs.
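//!
//! Passes run in a fixed order: constant folding, fusion, layout
//! optimization, memory analysis, tiling, then scheduling; see
//! [`apply_graph_optimizations`] for the orchestration.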

use anyhow::Result;
use tensorlogic_ir::{
    analyze_memory, apply_tiling, fold_constants_aggressive, fuse_elementwise_operations,
    optimize_layouts, EinsumGraph, GraphScheduler, OpType, SchedulingObjective,
};

/// Configuration for graph optimization integration
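///
/// Any preset can be customized via struct update syntax, e.g.:
///
/// ```
/// use tensorlogic_compiler::passes::graph_opt_integration::GraphOptConfig;
///
/// let config = GraphOptConfig {
///     enable_tiling: true,
///     ..GraphOptConfig::default()
/// };
/// assert!(config.enable_tiling);
/// ```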
#[derive(Debug, Clone)]
pub struct GraphOptConfig {
    /// Enable operation fusion
    pub enable_fusion: bool,
    /// Enable layout optimization
    pub enable_layout_opt: bool,
    /// Enable memory optimization
    pub enable_memory_opt: bool,
    /// Enable constant folding at graph level
    pub enable_constant_folding: bool,
    /// Enable tiling for large operations
    pub enable_tiling: bool,
    /// Enable scheduling optimization
    pub enable_scheduling: bool,
    /// Target tile size for tiling optimization
    pub tile_size: Option<usize>,
    /// Memory budget for memory optimization (in bytes)
    pub memory_budget: Option<usize>,
}

impl Default for GraphOptConfig {
    fn default() -> Self {
        Self {
            enable_fusion: true,
            enable_layout_opt: true,
            enable_memory_opt: true,
            enable_constant_folding: true,
            enable_tiling: false, // Conservative default
            enable_scheduling: true,
            tile_size: Some(64),
            memory_budget: None, // Auto-detect
        }
    }
}

impl GraphOptConfig {
    /// Create a new configuration with all optimizations enabled
    pub fn aggressive() -> Self {
        Self {
            enable_fusion: true,
            enable_layout_opt: true,
            enable_memory_opt: true,
            enable_constant_folding: true,
            enable_tiling: true,
            enable_scheduling: true,
            tile_size: Some(128),
            memory_budget: None,
        }
    }

    /// Create a configuration with conservative optimizations (safe defaults)
    pub fn conservative() -> Self {
        Self {
            enable_fusion: true,
            enable_layout_opt: false,
            enable_memory_opt: false,
            enable_constant_folding: true,
            enable_tiling: false,
            enable_scheduling: false,
            tile_size: None,
            memory_budget: None,
        }
    }

    /// Create a configuration optimized for inference (minimize latency)
    pub fn for_inference() -> Self {
        Self {
            enable_fusion: true,
            enable_layout_opt: true,
            enable_memory_opt: false,
            enable_constant_folding: true,
            enable_tiling: false,
            enable_scheduling: true,
            tile_size: None,
            memory_budget: None,
        }
    }

    /// Create a configuration optimized for training (minimize memory)
    pub fn for_training() -> Self {
        Self {
            enable_fusion: true,
            enable_layout_opt: true,
            enable_memory_opt: true,
            enable_constant_folding: true,
            enable_tiling: true,
            enable_scheduling: true,
            tile_size: Some(64),
            memory_budget: None,
        }
    }
}

/// Statistics from graph optimization integration
#[derive(Debug, Default, Clone)]
pub struct GraphOptStats {
    /// Number of operations fused
    pub ops_fused: usize,
    /// Number of layout transformations applied
    pub layout_transforms: usize,
    /// Number of memory optimizations applied
    pub memory_opts: usize,
    /// Number of constants folded at graph level
    pub graph_constants_folded: usize,
    /// Number of tiles created
    pub tiles_created: usize,
    /// Estimated memory reduction (bytes)
    pub memory_saved: usize,
    /// Estimated speedup from optimizations
    pub estimated_speedup: f64,
}

impl GraphOptStats {
    /// Total number of optimizations applied
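    ///
    /// # Example
    ///
    /// ```
    /// use tensorlogic_compiler::passes::graph_opt_integration::GraphOptStats;
    ///
    /// let stats = GraphOptStats { ops_fused: 2, tiles_created: 1, ..Default::default() };
    /// assert_eq!(stats.total_optimizations(), 3);
    /// ```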
    pub fn total_optimizations(&self) -> usize {
        self.ops_fused
            + self.layout_transforms
            + self.memory_opts
            + self.graph_constants_folded
            + self.tiles_created
    }

    /// Check if any optimizations were applied
    pub fn any_applied(&self) -> bool {
        self.total_optimizations() > 0
    }
}

/// Apply integrated graph optimizations to a compiled EinsumGraph
///
/// This is the main entry point for applying IR-level optimizations to
/// graphs produced by the compiler. It orchestrates multiple optimization
/// passes in a deliberate order: cheap graph simplifications run first so
/// that later, more expensive passes see a smaller graph.
///
/// # Example
///
/// ```
/// use tensorlogic_compiler::passes::graph_opt_integration::{
///     apply_graph_optimizations, GraphOptConfig
/// };
/// use tensorlogic_ir::EinsumGraph;
///
/// let mut graph = EinsumGraph::new();
/// // ... compile logic expressions into graph ...
///
/// let config = GraphOptConfig::default();
/// let (optimized, stats) = apply_graph_optimizations(&graph, &config).unwrap();
///
/// println!("Applied {} optimizations", stats.total_optimizations());
/// println!("Estimated speedup: {:.2}x", stats.estimated_speedup);
/// ```
pub fn apply_graph_optimizations(
    graph: &EinsumGraph,
    config: &GraphOptConfig,
) -> Result<(EinsumGraph, GraphOptStats)> {
    let mut optimized = graph.clone();
    let mut stats = GraphOptStats {
        estimated_speedup: 1.0,
        ..Default::default()
    };

    // Phase 1: Constant folding (do this early to simplify later passes)
    if config.enable_constant_folding {
        let fold_result = fold_constants_aggressive(&mut optimized)?;
        stats.graph_constants_folded = fold_result.operations_folded;
        stats.estimated_speedup *= fold_result.estimated_speedup;
    }

    // Phase 2: Operation fusion (reduce kernel launches)
    if config.enable_fusion {
        let fusion_result = fuse_elementwise_operations(&mut optimized)?;
        stats.ops_fused = fusion_result.ops_fused;
        stats.estimated_speedup *= fusion_result.estimated_speedup;
    }

    // Phase 3: Layout optimization (improve cache locality)
    if config.enable_layout_opt {
        let layout_result = optimize_layouts(&optimized)?;
        stats.layout_transforms = layout_result.transformations_needed;
        stats.estimated_speedup *= layout_result.estimated_speedup;
    }

    // Phase 4: Memory analysis (understand memory usage patterns)
    if config.enable_memory_opt {
        let mem_result = analyze_memory(&optimized, 8)?; // 8 bytes per f64 element
        // Total allocation should never be below the peak; saturating_sub
        // guards against a usize underflow panic if the analysis ever
        // reports otherwise.
        stats.memory_saved = mem_result
            .total_memory_bytes
            .saturating_sub(mem_result.peak_memory_bytes);
        // Memory analysis provides insights but doesn't modify the graph;
        // count one applied optimization when utilization looks poor.
        stats.memory_opts = if mem_result.avg_utilization < 0.8 { 1 } else { 0 };
    }

    // Phase 5: Tiling (for large operations)
    if config.enable_tiling {
        if let Some(tile_size) = config.tile_size {
            use tensorlogic_ir::{TileConfig as IrTileConfig, TilingStrategy};
            let mut strategy = TilingStrategy::new();
            strategy.add_tile(IrTileConfig::new(0, tile_size));
            let tiling_result = apply_tiling(&mut optimized, &strategy)?;
            stats.tiles_created = tiling_result.nodes_tiled + tiling_result.loops_unrolled;
        }
    }

    // Phase 6: Scheduling (optimize execution order)
    if config.enable_scheduling {
        let scheduler = GraphScheduler::new();
        let _schedule = scheduler.schedule(&optimized, SchedulingObjective::MinimizeMemory)?;
        // Scheduling doesn't modify the graph; it only computes an execution
        // order. Credit a small nominal speedup for the improved ordering.
        stats.estimated_speedup *= 1.05;
    }

    Ok((optimized, stats))
}

/// Apply pattern-based graph transformations
///
/// Intended to use the IR's pattern matching system to apply domain-specific
/// optimizations and transformations to the graph. Currently a stub: the
/// graph is returned unchanged with a rewrite count of 0.
///
/// # Example
///
/// ```
/// use tensorlogic_compiler::passes::graph_opt_integration::apply_pattern_optimizations;
/// use tensorlogic_ir::EinsumGraph;
///
/// let graph = EinsumGraph::new();
/// let (optimized, count) = apply_pattern_optimizations(&graph).unwrap();
/// println!("Applied {} pattern-based optimizations", count);
/// ```
pub fn apply_pattern_optimizations(graph: &EinsumGraph) -> Result<(EinsumGraph, usize)> {
    // Pattern rewriting is available but needs to be explicitly configured.
    // For now, return the original graph with 0 rewrites. In the future,
    // this will use tensorlogic_ir::PatternMatcher.
    Ok((graph.clone(), 0))
}

/// Quick optimization pass with sensible defaults
///
/// Applies a curated set of fast optimizations (fusion and constant folding)
/// that offer a good cost/benefit ratio with minimal compilation overhead.
///
/// # Example
///
/// ```
/// use tensorlogic_compiler::passes::graph_opt_integration::quick_optimize;
/// use tensorlogic_ir::EinsumGraph;
///
/// let graph = EinsumGraph::new();
/// let optimized = quick_optimize(&graph).unwrap();
/// ```
pub fn quick_optimize(graph: &EinsumGraph) -> Result<EinsumGraph> {
    // Field-for-field identical to the `conservative()` preset, so reuse it.
    let config = GraphOptConfig::conservative();

    let (optimized, _) = apply_graph_optimizations(graph, &config)?;
    Ok(optimized)
}

/// Analyze graph and recommend optimization configuration
///
/// Examines the structure and characteristics of a graph to suggest
/// which optimizations are likely to be beneficial.
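///
/// Current heuristics: graphs with fewer than 10 nodes get conservative
/// settings, graphs with fewer than 50 nodes get selective settings based
/// on the operation mix, and larger graphs get aggressive settings.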
///
/// # Example
///
/// ```
/// use tensorlogic_compiler::passes::graph_opt_integration::recommend_optimizations;
/// use tensorlogic_ir::EinsumGraph;
///
/// let graph = EinsumGraph::new();
/// let config = recommend_optimizations(&graph);
/// println!("Recommended fusion: {}", config.enable_fusion);
/// ```
pub fn recommend_optimizations(graph: &EinsumGraph) -> GraphOptConfig {
    let node_count = graph.nodes.len();
    let tensor_count = graph.tensors.len();

    // Count operation types
    let mut elementwise_count = 0;
    let mut einsum_count = 0;

    for node in &graph.nodes {
        match &node.op {
            OpType::ElemUnary { .. } | OpType::ElemBinary { .. } => elementwise_count += 1,
            OpType::Einsum { .. } => einsum_count += 1,
            _ => {}
        }
    }

    // Small graphs: conservative optimizations
    if node_count < 10 {
        return GraphOptConfig {
            enable_fusion: elementwise_count > 2,
            enable_layout_opt: false,
            enable_memory_opt: false,
            enable_constant_folding: true,
            enable_tiling: false,
            enable_scheduling: false,
            tile_size: None,
            memory_budget: None,
        };
    }

    // Medium graphs: selective optimizations
    if node_count < 50 {
        return GraphOptConfig {
            enable_fusion: elementwise_count > 3,
            enable_layout_opt: einsum_count > 5,
            enable_memory_opt: tensor_count > 20,
            enable_constant_folding: true,
            enable_tiling: false,
            enable_scheduling: true,
            tile_size: Some(64),
            memory_budget: None,
        };
    }

    // Large graphs: aggressive optimizations
    GraphOptConfig {
        enable_fusion: true,
        enable_layout_opt: true,
        enable_memory_opt: true,
        enable_constant_folding: true,
        enable_tiling: einsum_count > 10,
        enable_scheduling: true,
        tile_size: Some(128),
        memory_budget: None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    #[test]
    fn test_config_defaults() {
        let config = GraphOptConfig::default();
        assert!(config.enable_fusion);
        assert!(config.enable_constant_folding);
    }

    #[test]
    fn test_config_aggressive() {
        let config = GraphOptConfig::aggressive();
        assert!(config.enable_fusion);
        assert!(config.enable_layout_opt);
        assert!(config.enable_memory_opt);
        assert!(config.enable_tiling);
        assert!(config.enable_scheduling);
    }

    #[test]
    fn test_config_conservative() {
        let config = GraphOptConfig::conservative();
        assert!(config.enable_fusion);
        assert!(!config.enable_layout_opt);
        assert!(!config.enable_tiling);
    }

    #[test]
    fn test_stats_total_optimizations() {
        let stats = GraphOptStats {
            ops_fused: 5,
            layout_transforms: 3,
            memory_opts: 2,
            graph_constants_folded: 1,
            tiles_created: 0,
            memory_saved: 1024,
            estimated_speedup: 1.5,
        };
        assert_eq!(stats.total_optimizations(), 11);
        assert!(stats.any_applied());
    }

    #[test]
    fn test_quick_optimize_empty_graph() {
        let graph = EinsumGraph::new();
        let result = quick_optimize(&graph);
        assert!(result.is_ok());
    }

    #[test]
    fn test_recommend_optimizations_small_graph() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x");
        let t1 = graph.add_tensor("y");
        let t2 = graph.add_tensor("z");
        let _ = graph.add_node(EinsumNode::elem_binary("add", t0, t1, t2));

        let config = recommend_optimizations(&graph);
        // Small graph should have conservative settings
        assert!(!config.enable_tiling);
        assert!(config.enable_constant_folding);
    }

    #[test]
    fn test_recommend_optimizations_medium_graph() {
        let mut graph = EinsumGraph::new();
        // Create a graph with ~30 nodes
        for i in 0..30 {
            let t_in = graph.add_tensor(format!("t{}", i));
            let t_out = graph.add_tensor(format!("t{}_out", i));
            let _ = graph.add_node(EinsumNode::elem_unary("relu", t_in, t_out));
        }

        let config = recommend_optimizations(&graph);
        assert!(config.enable_fusion);
        assert!(config.enable_scheduling);
    }
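
    // Graphs with 50 or more nodes take the aggressive recommendation path.
    // With only elementwise ops (no einsums), tiling should remain off since
    // it additionally requires more than 10 einsum nodes.
    #[test]
    fn test_recommend_optimizations_large_graph() {
        let mut graph = EinsumGraph::new();
        for i in 0..60 {
            let t_in = graph.add_tensor(format!("t{}", i));
            let t_out = graph.add_tensor(format!("t{}_out", i));
            let _ = graph.add_node(EinsumNode::elem_unary("relu", t_in, t_out));
        }

        let config = recommend_optimizations(&graph);
        assert!(config.enable_fusion);
        assert!(config.enable_layout_opt);
        assert!(config.enable_memory_opt);
        assert!(!config.enable_tiling);
    }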

    #[test]
    fn test_apply_optimizations_with_default_config() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x");
        let t1 = graph.add_tensor("const");
        let t2 = graph.add_tensor("result");
        let _ = graph.add_node(EinsumNode::elem_binary("add", t0, t1, t2));

        let config = GraphOptConfig::default();
        let result = apply_graph_optimizations(&graph, &config);
        assert!(result.is_ok());

        let (_optimized, stats) = result.unwrap();
        assert!(stats.estimated_speedup >= 1.0);
    }

    #[test]
    fn test_apply_pattern_optimizations_empty() {
        let graph = EinsumGraph::new();
        let result = apply_pattern_optimizations(&graph);
        assert!(result.is_ok());
    }

    #[test]
    fn test_stats_any_applied_false() {
        let stats = GraphOptStats::default();
        assert!(!stats.any_applied());
    }

    #[test]
    fn test_config_for_inference() {
        let config = GraphOptConfig::for_inference();
        assert!(config.enable_fusion);
        assert!(config.enable_layout_opt);
        assert!(!config.enable_memory_opt); // Prioritize speed over memory
        assert!(!config.enable_tiling);
    }

    #[test]
    fn test_config_for_training() {
        let config = GraphOptConfig::for_training();
        assert!(config.enable_memory_opt); // Prioritize memory efficiency
        assert!(config.enable_tiling);
        assert_eq!(config.tile_size, Some(64));
    }
}