// tensorlogic_compiler/passes/advanced_analysis.rs

1//! Advanced graph analysis integration.
2//!
3//! This module integrates tensorlogic-ir's analysis capabilities into the compiler pipeline.
4//! It provides:
5//! - Memory usage estimation
6//! - Parallelization analysis
7//! - Optimization recommendations
8
9use std::cmp::Reverse;
10
11use anyhow::Result;
12use tensorlogic_ir::{analyze_memory, analyze_parallelization, EinsumGraph, OpType};
13
/// Result of advanced graph analysis.
///
/// Produced by [`analyze_graph`]; aggregates memory metrics, parallelization
/// opportunities, and prioritized optimization recommendations.
#[derive(Debug, Clone)]
pub struct AnalysisReport {
    /// Peak memory usage (bytes), taken from the IR memory analysis.
    pub peak_memory_bytes: usize,
    /// Memory-intensive operations (node indices).
    /// NOTE(review): `analyze_graph` in this module never populates this
    /// field — confirm whether another pass fills it in.
    pub memory_intensive_ops: Vec<usize>,
    /// Potential memory savings (bytes): total memory minus peak memory.
    pub potential_memory_savings: usize,
    /// Parallelization opportunities (groups of >1 independent nodes).
    pub parallel_opportunities: Vec<ParallelOpportunity>,
    /// Recommended optimizations, sorted by descending priority.
    pub recommendations: Vec<OptimizationRecommendation>,
}
28
/// Description of a parallelization opportunity.
///
/// Identifies a group of graph nodes with no mutual dependencies that could
/// be dispatched concurrently.
#[derive(Debug, Clone)]
pub struct ParallelOpportunity {
    /// Indices of nodes that can be executed in parallel.
    pub parallel_nodes: Vec<usize>,
    /// Estimated speedup from parallelization (heuristic: sqrt of group size).
    pub estimated_speedup: f64,
    /// Human-readable description of the opportunity.
    pub description: String,
}
39
/// Optimization recommendation based on analysis.
#[derive(Debug, Clone)]
pub struct OptimizationRecommendation {
    /// Priority level; higher means more important. Values >= 8 are treated
    /// as critical, >= 7 as high priority (see `AnalysisReport` methods).
    pub priority: u8,
    /// Category of recommendation.
    pub category: RecommendationCategory,
    /// Human-readable description.
    pub description: String,
    /// Estimated improvement as a multiplicative ratio
    /// (e.g. 1.5 = 50% improvement).
    pub estimated_improvement: f64,
}
52
/// Categories of optimization recommendations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecommendationCategory {
    /// Recommendations related to operation fusion.
    Fusion,
    /// Recommendations related to memory optimization.
    Memory,
    /// Recommendations related to parallelization.
    Parallelization,
    /// Recommendations related to layout optimization.
    Layout,
    /// Recommendations related to numerical stability.
    Numerical,
    /// General optimization recommendations (e.g. tiling).
    General,
}
69
70impl AnalysisReport {
71    /// Create a new empty analysis report
72    pub fn new() -> Self {
73        Self {
74            peak_memory_bytes: 0,
75            memory_intensive_ops: Vec::new(),
76            potential_memory_savings: 0,
77            parallel_opportunities: Vec::new(),
78            recommendations: Vec::new(),
79        }
80    }
81
82    /// Check if there are high-priority recommendations
83    pub fn has_critical_recommendations(&self) -> bool {
84        self.recommendations.iter().any(|r| r.priority >= 8)
85    }
86
87    /// Get high-priority recommendations (priority >= 7)
88    pub fn high_priority_recommendations(&self) -> Vec<&OptimizationRecommendation> {
89        self.recommendations
90            .iter()
91            .filter(|r| r.priority >= 7)
92            .collect()
93    }
94
95    /// Estimate total potential speedup from all recommendations
96    pub fn estimated_total_speedup(&self) -> f64 {
97        self.recommendations
98            .iter()
99            .map(|r| r.estimated_improvement)
100            .fold(1.0, |acc, x| acc * x)
101    }
102}
103
104impl Default for AnalysisReport {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110/// Perform comprehensive analysis of a compiled graph
111///
112/// This analyzes the graph structure, identifies optimization opportunities,
113/// and provides actionable recommendations for improving performance.
114///
115/// # Example
116///
117/// ```
118/// use tensorlogic_compiler::passes::advanced_analysis::analyze_graph;
119/// use tensorlogic_ir::EinsumGraph;
120///
121/// let graph = EinsumGraph::new();
122/// let report = analyze_graph(&graph).expect("unwrap");
123///
124/// println!("Peak memory: {} bytes", report.peak_memory_bytes);
125///
126/// for rec in report.high_priority_recommendations() {
127///     println!("Priority {}: {}", rec.priority, rec.description);
128/// }
129/// ```
130pub fn analyze_graph(graph: &EinsumGraph) -> Result<AnalysisReport> {
131    let mut report = AnalysisReport::new();
132
133    // Memory analysis
134    let memory_analysis = analyze_memory(graph, 8)?; // 8 bytes for f64
135    report.peak_memory_bytes = memory_analysis.peak_memory_bytes;
136    report.potential_memory_savings = memory_analysis
137        .total_memory_bytes
138        .saturating_sub(memory_analysis.peak_memory_bytes);
139
140    // Parallelization analysis
141    let parallel_analysis = analyze_parallelization(graph)?;
142    for group in parallel_analysis.parallel_groups {
143        if group.nodes.len() > 1 {
144            report.parallel_opportunities.push(ParallelOpportunity {
145                parallel_nodes: group.nodes.clone(),
146                estimated_speedup: (group.nodes.len() as f64).sqrt(),
147                description: format!("{} operations can execute in parallel", group.nodes.len()),
148            });
149        }
150    }
151
152    // Generate recommendations
153    generate_recommendations(graph, &mut report);
154
155    Ok(report)
156}
157
158/// Generate optimization recommendations based on analysis
159fn generate_recommendations(graph: &EinsumGraph, report: &mut AnalysisReport) {
160    // Recommendation 1: Fusion opportunities
161    if has_fusible_operations(graph) {
162        report.recommendations.push(OptimizationRecommendation {
163            priority: 9,
164            category: RecommendationCategory::Fusion,
165            description: "Enable operation fusion to reduce kernel launches".to_string(),
166            estimated_improvement: 1.3,
167        });
168    }
169
170    // Recommendation 2: Memory optimization
171    if report.peak_memory_bytes > 100_000_000 {
172        // > 100MB
173        report.recommendations.push(OptimizationRecommendation {
174            priority: 8,
175            category: RecommendationCategory::Memory,
176            description: "Enable memory optimization to reduce peak usage".to_string(),
177            estimated_improvement: 1.2,
178        });
179    }
180
181    // Recommendation 3: Parallelization
182    if !report.parallel_opportunities.is_empty() {
183        let max_speedup = report
184            .parallel_opportunities
185            .iter()
186            .map(|p| p.estimated_speedup)
187            .fold(0.0, f64::max);
188        report.recommendations.push(OptimizationRecommendation {
189            priority: 7,
190            category: RecommendationCategory::Parallelization,
191            description: format!(
192                "Parallelize {} independent operation groups",
193                report.parallel_opportunities.len()
194            ),
195            estimated_improvement: max_speedup,
196        });
197    }
198
199    // Recommendation 4: Layout optimization for large graphs
200    if graph.nodes.len() > 50 {
201        report.recommendations.push(OptimizationRecommendation {
202            priority: 6,
203            category: RecommendationCategory::Layout,
204            description: "Apply layout optimization for better cache locality".to_string(),
205            estimated_improvement: 1.15,
206        });
207    }
208
209    // Recommendation 5: Tiling for memory-intensive graphs
210    if report.peak_memory_bytes > 500_000_000 {
211        // > 500MB
212        report.recommendations.push(OptimizationRecommendation {
213            priority: 5,
214            category: RecommendationCategory::General,
215            description: "Consider tiling to reduce memory pressure".to_string(),
216            estimated_improvement: 1.1,
217        });
218    }
219
220    // Sort recommendations by priority (descending)
221    report.recommendations.sort_by_key(|r| Reverse(r.priority));
222}
223
224/// Check if the graph has fusible operations
225fn has_fusible_operations(graph: &EinsumGraph) -> bool {
226    let mut consecutive_elementwise = 0;
227
228    for node in &graph.nodes {
229        match &node.op {
230            OpType::ElemUnary { op: _ } | OpType::ElemBinary { op: _ } => {
231                consecutive_elementwise += 1;
232                if consecutive_elementwise >= 2 {
233                    return true;
234                }
235            }
236            _ => {
237                consecutive_elementwise = 0;
238            }
239        }
240    }
241
242    false
243}
244
245/// Quick analysis for fast feedback during compilation
246///
247/// Provides essential metrics without deep analysis.
248///
249/// # Example
250///
251/// ```
252/// use tensorlogic_compiler::passes::advanced_analysis::quick_analyze;
253/// use tensorlogic_ir::EinsumGraph;
254///
255/// let graph = EinsumGraph::new();
256/// let (peak_memory, parallel_groups) = quick_analyze(&graph).expect("unwrap");
257/// println!("Memory: {} bytes, Parallelism: {}", peak_memory, parallel_groups);
258/// ```
259pub fn quick_analyze(graph: &EinsumGraph) -> Result<(usize, usize)> {
260    let memory = analyze_memory(graph, 8)?;
261    let parallel = analyze_parallelization(graph)?;
262    let parallel_groups = parallel
263        .parallel_groups
264        .iter()
265        .filter(|g| g.nodes.len() > 1)
266        .count();
267    Ok((memory.peak_memory_bytes, parallel_groups))
268}
269
270/// Print a human-readable analysis report
271///
272/// # Example
273///
274/// ```
275/// use tensorlogic_compiler::passes::advanced_analysis::{analyze_graph, print_report};
276/// use tensorlogic_ir::EinsumGraph;
277///
278/// let graph = EinsumGraph::new();
279/// let report = analyze_graph(&graph).expect("unwrap");
280/// print_report(&report);
281/// ```
282pub fn print_report(report: &AnalysisReport) {
283    println!("=== Graph Analysis Report ===");
284    println!(
285        "Peak Memory Usage: {:.2} MB",
286        report.peak_memory_bytes as f64 / 1_048_576.0
287    );
288    println!(
289        "Potential Memory Savings: {:.2} MB",
290        report.potential_memory_savings as f64 / 1_048_576.0
291    );
292
293    if !report.parallel_opportunities.is_empty() {
294        println!(
295            "\nParallelization Opportunities: {}",
296            report.parallel_opportunities.len()
297        );
298        for (i, opp) in report.parallel_opportunities.iter().enumerate() {
299            println!(
300                "  {}. {} (speedup: {:.2}x)",
301                i + 1,
302                opp.description,
303                opp.estimated_speedup
304            );
305        }
306    }
307
308    if !report.recommendations.is_empty() {
309        println!("\nOptimization Recommendations:");
310        for (i, rec) in report.recommendations.iter().enumerate() {
311            println!(
312                "  {}. [Priority {}] {} (improvement: {:.2}x)",
313                i + 1,
314                rec.priority,
315                rec.description,
316                rec.estimated_improvement
317            );
318        }
319
320        println!(
321            "\nEstimated Total Speedup: {:.2}x",
322            report.estimated_total_speedup()
323        );
324    }
325
326    println!("============================");
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Shorthand constructor for recommendations used across these tests.
    fn rec(
        priority: u8,
        category: RecommendationCategory,
        description: &str,
        estimated_improvement: f64,
    ) -> OptimizationRecommendation {
        OptimizationRecommendation {
            priority,
            category,
            description: description.to_string(),
            estimated_improvement,
        }
    }

    #[test]
    fn test_analysis_report_new() {
        let fresh = AnalysisReport::new();
        assert_eq!(fresh.peak_memory_bytes, 0);
        assert!(fresh.parallel_opportunities.is_empty());
    }

    #[test]
    fn test_has_critical_recommendations() {
        let mut report = AnalysisReport::new();
        assert!(!report.has_critical_recommendations());

        report
            .recommendations
            .push(rec(9, RecommendationCategory::Fusion, "Test", 1.5));
        assert!(report.has_critical_recommendations());
    }

    #[test]
    fn test_high_priority_recommendations() {
        let mut report = AnalysisReport::new();
        report
            .recommendations
            .push(rec(9, RecommendationCategory::Fusion, "High priority", 1.5));
        report
            .recommendations
            .push(rec(5, RecommendationCategory::General, "Low priority", 1.1));

        let high = report.high_priority_recommendations();
        assert_eq!(high.len(), 1);
        assert_eq!(high[0].priority, 9);
    }

    #[test]
    fn test_estimated_total_speedup() {
        let mut report = AnalysisReport::new();
        report
            .recommendations
            .push(rec(9, RecommendationCategory::Fusion, "Test 1", 1.5));
        report
            .recommendations
            .push(rec(8, RecommendationCategory::Memory, "Test 2", 1.2));

        // Multiplicative model: 1.5 * 1.2 = 1.8.
        assert!((report.estimated_total_speedup() - 1.8).abs() < 0.01);
    }

    #[test]
    fn test_analyze_empty_graph() {
        let graph = EinsumGraph::new();
        let result = analyze_graph(&graph);
        assert!(result.is_ok());

        // An empty graph uses no memory.
        assert_eq!(result.expect("unwrap").peak_memory_bytes, 0);
    }

    #[test]
    fn test_quick_analyze_empty() {
        let graph = EinsumGraph::new();
        let result = quick_analyze(&graph);
        assert!(result.is_ok());

        let (memory, parallel_groups) = result.expect("unwrap");
        assert_eq!(memory, 0);
        assert_eq!(parallel_groups, 0);
    }

    #[test]
    fn test_has_fusible_operations_no_fusion() {
        let mut graph = EinsumGraph::new();
        let input = graph.add_tensor("x");
        let output = graph.add_tensor("x_relu");
        let _ = graph.add_node(EinsumNode::elem_unary("relu", input, output));

        // A single element-wise op has nothing adjacent to fuse with.
        assert!(!has_fusible_operations(&graph));
    }

    #[test]
    fn test_has_fusible_operations_with_fusion() {
        let mut graph = EinsumGraph::new();
        let first_in = graph.add_tensor("x");
        let second_in = graph.add_tensor("y");
        let first_out = graph.add_tensor("x_relu");
        let second_out = graph.add_tensor("y_tanh");
        let _ = graph.add_node(EinsumNode::elem_unary("relu", first_in, first_out));
        let _ = graph.add_node(EinsumNode::elem_unary("tanh", second_in, second_out));

        // Two adjacent element-wise operations are a fusion candidate.
        assert!(has_fusible_operations(&graph));
    }

    #[test]
    fn test_recommendation_category_equality() {
        assert_eq!(
            RecommendationCategory::Fusion,
            RecommendationCategory::Fusion
        );
        assert_ne!(
            RecommendationCategory::Fusion,
            RecommendationCategory::Memory
        );
    }

    #[test]
    fn test_parallel_opportunity_creation() {
        let opportunity = ParallelOpportunity {
            parallel_nodes: vec![1, 2, 3],
            estimated_speedup: 1.7,
            description: "Test opportunity".to_string(),
        };

        assert_eq!(opportunity.parallel_nodes.len(), 3);
        assert!((opportunity.estimated_speedup - 1.7).abs() < 0.01);
    }
}