// tensorlogic_compiler/passes/advanced_analysis.rs

use anyhow::Result;
use tensorlogic_ir::{analyze_memory, analyze_parallelization, EinsumGraph, OpType};

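/// Aggregated results of a graph analysis pass: peak memory statistics,
/// detected parallelization opportunities, and the heuristic optimization
/// recommendations derived from them.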
#[derive(Debug, Clone)]
pub struct AnalysisReport {
    pub peak_memory_bytes: usize,
    pub memory_intensive_ops: Vec<usize>,
    pub potential_memory_savings: usize,
    pub parallel_opportunities: Vec<ParallelOpportunity>,
    pub recommendations: Vec<OptimizationRecommendation>,
}

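/// A group of graph nodes with no data dependencies on each other that could
/// therefore execute concurrently.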
#[derive(Debug, Clone)]
pub struct ParallelOpportunity {
    pub parallel_nodes: Vec<usize>,
    pub estimated_speedup: f64,
    pub description: String,
}

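/// A single optimization suggestion. Higher `priority` values are more
/// important (8 and above is treated as critical), and
/// `estimated_improvement` is a multiplicative speedup factor
/// (e.g. 1.3 means roughly 30% faster).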
#[derive(Debug, Clone)]
pub struct OptimizationRecommendation {
    pub priority: u8,
    pub category: RecommendationCategory,
    pub description: String,
    pub estimated_improvement: f64,
}

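/// Broad category a recommendation falls into, used to group related
/// suggestions.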
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecommendationCategory {
    Fusion,
    Memory,
    Parallelization,
    Layout,
    Numerical,
    General,
}

impl AnalysisReport {
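    /// Creates an empty report with zeroed statistics and no recommendations.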
    pub fn new() -> Self {
        Self {
            peak_memory_bytes: 0,
            memory_intensive_ops: Vec::new(),
            potential_memory_savings: 0,
            parallel_opportunities: Vec::new(),
            recommendations: Vec::new(),
        }
    }

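    /// Returns `true` if any recommendation has priority 8 or higher.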
    pub fn has_critical_recommendations(&self) -> bool {
        self.recommendations.iter().any(|r| r.priority >= 8)
    }

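    /// Returns all recommendations with priority 7 or higher, in report order.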
    pub fn high_priority_recommendations(&self) -> Vec<&OptimizationRecommendation> {
        self.recommendations
            .iter()
            .filter(|r| r.priority >= 7)
            .collect()
    }

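    /// Product of all estimated improvement factors, i.e. the combined speedup
    /// if every recommendation's gain were realized independently.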
    pub fn estimated_total_speedup(&self) -> f64 {
        self.recommendations
            .iter()
            .map(|r| r.estimated_improvement)
            .fold(1.0, |acc, x| acc * x)
    }
}

impl Default for AnalysisReport {
    fn default() -> Self {
        Self::new()
    }
}

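/// Runs memory and parallelization analysis over an [`EinsumGraph`] and
/// derives heuristic optimization recommendations from the results.
///
/// Illustrative usage sketch (assumes this module and `EinsumGraph` are
/// reachable from the caller; exact import paths may differ):
///
/// ```ignore
/// use tensorlogic_ir::EinsumGraph;
///
/// let graph = EinsumGraph::new();
/// let report = analyze_graph(&graph).expect("analysis failed");
/// if report.has_critical_recommendations() {
///     print_report(&report);
/// }
/// ```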
pub fn analyze_graph(graph: &EinsumGraph) -> Result<AnalysisReport> {
    let mut report = AnalysisReport::new();

    let memory_analysis = analyze_memory(graph, 8)?;
    report.peak_memory_bytes = memory_analysis.peak_memory_bytes;
    report.potential_memory_savings = memory_analysis
        .total_memory_bytes
        .saturating_sub(memory_analysis.peak_memory_bytes);

    let parallel_analysis = analyze_parallelization(graph)?;
    for group in parallel_analysis.parallel_groups {
        if group.nodes.len() > 1 {
            report.parallel_opportunities.push(ParallelOpportunity {
                parallel_nodes: group.nodes.clone(),
                estimated_speedup: (group.nodes.len() as f64).sqrt(),
                description: format!("{} operations can execute in parallel", group.nodes.len()),
            });
        }
    }

    generate_recommendations(graph, &mut report);

    Ok(report)
}

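/// Populates `report.recommendations` from simple heuristics (fusion
/// candidates, peak memory thresholds, parallel groups, graph size) and sorts
/// them by descending priority.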
fn generate_recommendations(graph: &EinsumGraph, report: &mut AnalysisReport) {
    if has_fusible_operations(graph) {
        report.recommendations.push(OptimizationRecommendation {
            priority: 9,
            category: RecommendationCategory::Fusion,
            description: "Enable operation fusion to reduce kernel launches".to_string(),
            estimated_improvement: 1.3,
        });
    }

    if report.peak_memory_bytes > 100_000_000 {
        report.recommendations.push(OptimizationRecommendation {
            priority: 8,
            category: RecommendationCategory::Memory,
            description: "Enable memory optimization to reduce peak usage".to_string(),
            estimated_improvement: 1.2,
        });
    }

    if !report.parallel_opportunities.is_empty() {
        let max_speedup = report
            .parallel_opportunities
            .iter()
            .map(|p| p.estimated_speedup)
            .fold(0.0, f64::max);
        report.recommendations.push(OptimizationRecommendation {
            priority: 7,
            category: RecommendationCategory::Parallelization,
            description: format!(
                "Parallelize {} independent operation groups",
                report.parallel_opportunities.len()
            ),
            estimated_improvement: max_speedup,
        });
    }

    if graph.nodes.len() > 50 {
        report.recommendations.push(OptimizationRecommendation {
            priority: 6,
            category: RecommendationCategory::Layout,
            description: "Apply layout optimization for better cache locality".to_string(),
            estimated_improvement: 1.15,
        });
    }

    if report.peak_memory_bytes > 500_000_000 {
        report.recommendations.push(OptimizationRecommendation {
            priority: 5,
            category: RecommendationCategory::General,
            description: "Consider tiling to reduce memory pressure".to_string(),
            estimated_improvement: 1.1,
        });
    }

    report
        .recommendations
        .sort_by(|a, b| b.priority.cmp(&a.priority));
}

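/// Returns `true` if the graph contains at least two consecutive elementwise
/// nodes (in node order), the pattern targeted by the fusion recommendation.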
fn has_fusible_operations(graph: &EinsumGraph) -> bool {
    let mut consecutive_elementwise = 0;

    for node in &graph.nodes {
        match &node.op {
            OpType::ElemUnary { op: _ } | OpType::ElemBinary { op: _ } => {
                consecutive_elementwise += 1;
                if consecutive_elementwise >= 2 {
                    return true;
                }
            }
            _ => {
                consecutive_elementwise = 0;
            }
        }
    }

    false
}

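/// Lightweight variant of [`analyze_graph`]: returns only the peak memory in
/// bytes and the number of parallel groups containing more than one node.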
pub fn quick_analyze(graph: &EinsumGraph) -> Result<(usize, usize)> {
    let memory = analyze_memory(graph, 8)?;
    let parallel = analyze_parallelization(graph)?;
    let parallel_groups = parallel
        .parallel_groups
        .iter()
        .filter(|g| g.nodes.len() > 1)
        .count();
    Ok((memory.peak_memory_bytes, parallel_groups))
}

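/// Prints a human-readable summary of an [`AnalysisReport`] to stdout.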
pub fn print_report(report: &AnalysisReport) {
    println!("=== Graph Analysis Report ===");
    println!(
        "Peak Memory Usage: {:.2} MiB",
        report.peak_memory_bytes as f64 / 1_048_576.0
    );
    println!(
        "Potential Memory Savings: {:.2} MiB",
        report.potential_memory_savings as f64 / 1_048_576.0
    );

    if !report.parallel_opportunities.is_empty() {
        println!(
            "\nParallelization Opportunities: {}",
            report.parallel_opportunities.len()
        );
        for (i, opp) in report.parallel_opportunities.iter().enumerate() {
            println!(
                "  {}. {} (speedup: {:.2}x)",
                i + 1,
                opp.description,
                opp.estimated_speedup
            );
        }
    }

    if !report.recommendations.is_empty() {
        println!("\nOptimization Recommendations:");
        for (i, rec) in report.recommendations.iter().enumerate() {
            println!(
                "  {}. [Priority {}] {} (improvement: {:.2}x)",
                i + 1,
                rec.priority,
                rec.description,
                rec.estimated_improvement
            );
        }

        println!(
            "\nEstimated Total Speedup: {:.2}x",
            report.estimated_total_speedup()
        );
    }

    println!("============================");
}

#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    #[test]
    fn test_analysis_report_new() {
        let report = AnalysisReport::new();
        assert_eq!(report.peak_memory_bytes, 0);
        assert!(report.parallel_opportunities.is_empty());
    }

    #[test]
    fn test_has_critical_recommendations() {
        let mut report = AnalysisReport::new();
        assert!(!report.has_critical_recommendations());

        report.recommendations.push(OptimizationRecommendation {
            priority: 9,
            category: RecommendationCategory::Fusion,
            description: "Test".to_string(),
            estimated_improvement: 1.5,
        });
        assert!(report.has_critical_recommendations());
    }

    #[test]
    fn test_high_priority_recommendations() {
        let mut report = AnalysisReport::new();

        report.recommendations.push(OptimizationRecommendation {
            priority: 9,
            category: RecommendationCategory::Fusion,
            description: "High priority".to_string(),
            estimated_improvement: 1.5,
        });

        report.recommendations.push(OptimizationRecommendation {
            priority: 5,
            category: RecommendationCategory::General,
            description: "Low priority".to_string(),
            estimated_improvement: 1.1,
        });

        let high_priority = report.high_priority_recommendations();
        assert_eq!(high_priority.len(), 1);
        assert_eq!(high_priority[0].priority, 9);
    }

    #[test]
    fn test_estimated_total_speedup() {
        let mut report = AnalysisReport::new();

        report.recommendations.push(OptimizationRecommendation {
            priority: 9,
            category: RecommendationCategory::Fusion,
            description: "Test 1".to_string(),
            estimated_improvement: 1.5,
        });

        report.recommendations.push(OptimizationRecommendation {
            priority: 8,
            category: RecommendationCategory::Memory,
            description: "Test 2".to_string(),
            estimated_improvement: 1.2,
        });

        let speedup = report.estimated_total_speedup();
        assert!((speedup - 1.8).abs() < 0.01);
    }

    #[test]
    fn test_analyze_empty_graph() {
        let graph = EinsumGraph::new();
        let result = analyze_graph(&graph);
        assert!(result.is_ok());

        let report = result.unwrap();
        assert_eq!(report.peak_memory_bytes, 0);
    }

    #[test]
    fn test_quick_analyze_empty() {
        let graph = EinsumGraph::new();
        let result = quick_analyze(&graph);
        assert!(result.is_ok());

        let (memory, parallel_groups) = result.unwrap();
        assert_eq!(memory, 0);
        assert_eq!(parallel_groups, 0);
    }

    #[test]
    fn test_has_fusible_operations_no_fusion() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x");
        let t1 = graph.add_tensor("x_relu");
        let _ = graph.add_node(EinsumNode::elem_unary("relu", t0, t1));

        assert!(!has_fusible_operations(&graph));
    }

    #[test]
    fn test_has_fusible_operations_with_fusion() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x");
        let t1 = graph.add_tensor("y");
        let t2 = graph.add_tensor("x_relu");
        let t3 = graph.add_tensor("y_tanh");
        let _ = graph.add_node(EinsumNode::elem_unary("relu", t0, t2));
        let _ = graph.add_node(EinsumNode::elem_unary("tanh", t1, t3));

        assert!(has_fusible_operations(&graph));
    }

    #[test]
    fn test_recommendation_category_equality() {
        assert_eq!(
            RecommendationCategory::Fusion,
            RecommendationCategory::Fusion
        );
        assert_ne!(
            RecommendationCategory::Fusion,
            RecommendationCategory::Memory
        );
    }

    #[test]
    fn test_parallel_opportunity_creation() {
        let opp = ParallelOpportunity {
            parallel_nodes: vec![1, 2, 3],
            estimated_speedup: 1.7,
            description: "Test opportunity".to_string(),
        };

        assert_eq!(opp.parallel_nodes.len(), 3);
        assert!((opp.estimated_speedup - 1.7).abs() < 0.01);
    }
}