// tensorlogic_compiler/passes/advanced_analysis.rs

use std::cmp::Reverse;

use anyhow::Result;
use tensorlogic_ir::{analyze_memory, analyze_parallelization, EinsumGraph, OpType};
/// Aggregated results of whole-graph analysis: memory footprint,
/// parallelization opportunities, and prioritized optimization advice.
#[derive(Debug, Clone)]
pub struct AnalysisReport {
    /// Peak memory usage in bytes reported by the memory analyzer.
    pub peak_memory_bytes: usize,
    /// Indices of operations considered memory-intensive.
    /// NOTE(review): not populated by `analyze_graph` in this file — confirm
    /// whether another pass fills it in.
    pub memory_intensive_ops: Vec<usize>,
    /// Bytes reclaimable: total allocated minus peak (saturating).
    pub potential_memory_savings: usize,
    /// Groups of nodes that can execute concurrently.
    pub parallel_opportunities: Vec<ParallelOpportunity>,
    /// Optimization suggestions, sorted by descending priority.
    pub recommendations: Vec<OptimizationRecommendation>,
}
28
/// A group of graph nodes that can run concurrently.
#[derive(Debug, Clone)]
pub struct ParallelOpportunity {
    /// Indices of the mutually independent nodes in the group.
    pub parallel_nodes: Vec<usize>,
    /// Heuristic speedup estimate (`analyze_graph` uses sqrt of group size).
    pub estimated_speedup: f64,
    /// Human-readable summary of the opportunity.
    pub description: String,
}
39
/// A single optimization suggestion produced by graph analysis.
#[derive(Debug, Clone)]
pub struct OptimizationRecommendation {
    /// Urgency; >= 8 counts as critical, >= 7 as high priority
    /// (see `AnalysisReport::has_critical_recommendations`).
    pub priority: u8,
    /// Optimization area the suggestion falls under.
    pub category: RecommendationCategory,
    /// Human-readable description of the suggestion.
    pub description: String,
    /// Multiplicative speedup estimate (e.g. 1.3 = ~30% faster).
    pub estimated_improvement: f64,
}
52
/// Broad optimization areas a recommendation can target.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecommendationCategory {
    /// Operation/kernel fusion.
    Fusion,
    /// Peak-memory reduction.
    Memory,
    /// Concurrent execution of independent nodes.
    Parallelization,
    /// Data layout / cache locality.
    Layout,
    /// Numerical-stability adjustments.
    Numerical,
    /// Anything that does not fit the other categories.
    General,
}
69
70impl AnalysisReport {
71 pub fn new() -> Self {
73 Self {
74 peak_memory_bytes: 0,
75 memory_intensive_ops: Vec::new(),
76 potential_memory_savings: 0,
77 parallel_opportunities: Vec::new(),
78 recommendations: Vec::new(),
79 }
80 }
81
82 pub fn has_critical_recommendations(&self) -> bool {
84 self.recommendations.iter().any(|r| r.priority >= 8)
85 }
86
87 pub fn high_priority_recommendations(&self) -> Vec<&OptimizationRecommendation> {
89 self.recommendations
90 .iter()
91 .filter(|r| r.priority >= 7)
92 .collect()
93 }
94
95 pub fn estimated_total_speedup(&self) -> f64 {
97 self.recommendations
98 .iter()
99 .map(|r| r.estimated_improvement)
100 .fold(1.0, |acc, x| acc * x)
101 }
102}
103
impl Default for AnalysisReport {
    /// Equivalent to [`AnalysisReport::new`].
    fn default() -> Self {
        Self::new()
    }
}
109
110pub fn analyze_graph(graph: &EinsumGraph) -> Result<AnalysisReport> {
131 let mut report = AnalysisReport::new();
132
133 let memory_analysis = analyze_memory(graph, 8)?; report.peak_memory_bytes = memory_analysis.peak_memory_bytes;
136 report.potential_memory_savings = memory_analysis
137 .total_memory_bytes
138 .saturating_sub(memory_analysis.peak_memory_bytes);
139
140 let parallel_analysis = analyze_parallelization(graph)?;
142 for group in parallel_analysis.parallel_groups {
143 if group.nodes.len() > 1 {
144 report.parallel_opportunities.push(ParallelOpportunity {
145 parallel_nodes: group.nodes.clone(),
146 estimated_speedup: (group.nodes.len() as f64).sqrt(),
147 description: format!("{} operations can execute in parallel", group.nodes.len()),
148 });
149 }
150 }
151
152 generate_recommendations(graph, &mut report);
154
155 Ok(report)
156}
157
158fn generate_recommendations(graph: &EinsumGraph, report: &mut AnalysisReport) {
160 if has_fusible_operations(graph) {
162 report.recommendations.push(OptimizationRecommendation {
163 priority: 9,
164 category: RecommendationCategory::Fusion,
165 description: "Enable operation fusion to reduce kernel launches".to_string(),
166 estimated_improvement: 1.3,
167 });
168 }
169
170 if report.peak_memory_bytes > 100_000_000 {
172 report.recommendations.push(OptimizationRecommendation {
174 priority: 8,
175 category: RecommendationCategory::Memory,
176 description: "Enable memory optimization to reduce peak usage".to_string(),
177 estimated_improvement: 1.2,
178 });
179 }
180
181 if !report.parallel_opportunities.is_empty() {
183 let max_speedup = report
184 .parallel_opportunities
185 .iter()
186 .map(|p| p.estimated_speedup)
187 .fold(0.0, f64::max);
188 report.recommendations.push(OptimizationRecommendation {
189 priority: 7,
190 category: RecommendationCategory::Parallelization,
191 description: format!(
192 "Parallelize {} independent operation groups",
193 report.parallel_opportunities.len()
194 ),
195 estimated_improvement: max_speedup,
196 });
197 }
198
199 if graph.nodes.len() > 50 {
201 report.recommendations.push(OptimizationRecommendation {
202 priority: 6,
203 category: RecommendationCategory::Layout,
204 description: "Apply layout optimization for better cache locality".to_string(),
205 estimated_improvement: 1.15,
206 });
207 }
208
209 if report.peak_memory_bytes > 500_000_000 {
211 report.recommendations.push(OptimizationRecommendation {
213 priority: 5,
214 category: RecommendationCategory::General,
215 description: "Consider tiling to reduce memory pressure".to_string(),
216 estimated_improvement: 1.1,
217 });
218 }
219
220 report.recommendations.sort_by_key(|r| Reverse(r.priority));
222}
223
224fn has_fusible_operations(graph: &EinsumGraph) -> bool {
226 let mut consecutive_elementwise = 0;
227
228 for node in &graph.nodes {
229 match &node.op {
230 OpType::ElemUnary { op: _ } | OpType::ElemBinary { op: _ } => {
231 consecutive_elementwise += 1;
232 if consecutive_elementwise >= 2 {
233 return true;
234 }
235 }
236 _ => {
237 consecutive_elementwise = 0;
238 }
239 }
240 }
241
242 false
243}
244
245pub fn quick_analyze(graph: &EinsumGraph) -> Result<(usize, usize)> {
260 let memory = analyze_memory(graph, 8)?;
261 let parallel = analyze_parallelization(graph)?;
262 let parallel_groups = parallel
263 .parallel_groups
264 .iter()
265 .filter(|g| g.nodes.len() > 1)
266 .count();
267 Ok((memory.peak_memory_bytes, parallel_groups))
268}
269
270pub fn print_report(report: &AnalysisReport) {
283 println!("=== Graph Analysis Report ===");
284 println!(
285 "Peak Memory Usage: {:.2} MB",
286 report.peak_memory_bytes as f64 / 1_048_576.0
287 );
288 println!(
289 "Potential Memory Savings: {:.2} MB",
290 report.potential_memory_savings as f64 / 1_048_576.0
291 );
292
293 if !report.parallel_opportunities.is_empty() {
294 println!(
295 "\nParallelization Opportunities: {}",
296 report.parallel_opportunities.len()
297 );
298 for (i, opp) in report.parallel_opportunities.iter().enumerate() {
299 println!(
300 " {}. {} (speedup: {:.2}x)",
301 i + 1,
302 opp.description,
303 opp.estimated_speedup
304 );
305 }
306 }
307
308 if !report.recommendations.is_empty() {
309 println!("\nOptimization Recommendations:");
310 for (i, rec) in report.recommendations.iter().enumerate() {
311 println!(
312 " {}. [Priority {}] {} (improvement: {:.2}x)",
313 i + 1,
314 rec.priority,
315 rec.description,
316 rec.estimated_improvement
317 );
318 }
319
320 println!(
321 "\nEstimated Total Speedup: {:.2}x",
322 report.estimated_total_speedup()
323 );
324 }
325
326 println!("============================");
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    // A fresh report starts with all metrics zeroed/empty.
    #[test]
    fn test_analysis_report_new() {
        let report = AnalysisReport::new();
        assert_eq!(report.peak_memory_bytes, 0);
        assert!(report.parallel_opportunities.is_empty());
    }

    // Priority >= 8 is the "critical" threshold.
    #[test]
    fn test_has_critical_recommendations() {
        let mut report = AnalysisReport::new();
        assert!(!report.has_critical_recommendations());

        report.recommendations.push(OptimizationRecommendation {
            priority: 9,
            category: RecommendationCategory::Fusion,
            description: "Test".to_string(),
            estimated_improvement: 1.5,
        });
        assert!(report.has_critical_recommendations());
    }

    // Only recommendations with priority >= 7 are returned.
    #[test]
    fn test_high_priority_recommendations() {
        let mut report = AnalysisReport::new();

        report.recommendations.push(OptimizationRecommendation {
            priority: 9,
            category: RecommendationCategory::Fusion,
            description: "High priority".to_string(),
            estimated_improvement: 1.5,
        });

        report.recommendations.push(OptimizationRecommendation {
            priority: 5,
            category: RecommendationCategory::General,
            description: "Low priority".to_string(),
            estimated_improvement: 1.1,
        });

        let high_priority = report.high_priority_recommendations();
        assert_eq!(high_priority.len(), 1);
        assert_eq!(high_priority[0].priority, 9);
    }

    // Total speedup is the product of improvements: 1.5 * 1.2 = 1.8.
    #[test]
    fn test_estimated_total_speedup() {
        let mut report = AnalysisReport::new();

        report.recommendations.push(OptimizationRecommendation {
            priority: 9,
            category: RecommendationCategory::Fusion,
            description: "Test 1".to_string(),
            estimated_improvement: 1.5,
        });

        report.recommendations.push(OptimizationRecommendation {
            priority: 8,
            category: RecommendationCategory::Memory,
            description: "Test 2".to_string(),
            estimated_improvement: 1.2,
        });

        let speedup = report.estimated_total_speedup();
        assert!((speedup - 1.8).abs() < 0.01);
    }

    // An empty graph analyzes successfully with zero memory usage.
    #[test]
    fn test_analyze_empty_graph() {
        let graph = EinsumGraph::new();
        let result = analyze_graph(&graph);
        assert!(result.is_ok());

        let report = result.expect("unwrap");
        assert_eq!(report.peak_memory_bytes, 0);
    }

    // quick_analyze on an empty graph returns (0, 0).
    #[test]
    fn test_quick_analyze_empty() {
        let graph = EinsumGraph::new();
        let result = quick_analyze(&graph);
        assert!(result.is_ok());

        let (memory, parallel_groups) = result.expect("unwrap");
        assert_eq!(memory, 0);
        assert_eq!(parallel_groups, 0);
    }

    // A single elementwise op is not enough to fuse.
    #[test]
    fn test_has_fusible_operations_no_fusion() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x");
        let t1 = graph.add_tensor("x_relu");
        let _ = graph.add_node(EinsumNode::elem_unary("relu", t0, t1));

        assert!(!has_fusible_operations(&graph));
    }

    // Two consecutive elementwise ops form a fusion candidate.
    #[test]
    fn test_has_fusible_operations_with_fusion() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x");
        let t1 = graph.add_tensor("y");
        let t2 = graph.add_tensor("x_relu");
        let t3 = graph.add_tensor("y_tanh");
        let _ = graph.add_node(EinsumNode::elem_unary("relu", t0, t2));
        let _ = graph.add_node(EinsumNode::elem_unary("tanh", t1, t3));

        assert!(has_fusible_operations(&graph));
    }

    // Derived PartialEq/Eq behave as expected on the category enum.
    #[test]
    fn test_recommendation_category_equality() {
        assert_eq!(
            RecommendationCategory::Fusion,
            RecommendationCategory::Fusion
        );
        assert_ne!(
            RecommendationCategory::Fusion,
            RecommendationCategory::Memory
        );
    }

    // Plain struct construction sanity check.
    #[test]
    fn test_parallel_opportunity_creation() {
        let opp = ParallelOpportunity {
            parallel_nodes: vec![1, 2, 3],
            estimated_speedup: 1.7,
            description: "Test opportunity".to_string(),
        };

        assert_eq!(opp.parallel_nodes.len(), 3);
        assert!((opp.estimated_speedup - 1.7).abs() < 0.01);
    }
}