// tensorlogic_compiler/passes/graph_opt_integration.rs

use anyhow::Result;
use tensorlogic_ir::{
    analyze_memory, apply_tiling, fold_constants_aggressive, fuse_elementwise_operations,
    optimize_layouts, EinsumGraph, GraphScheduler, OpType, SchedulingObjective,
};
/// Configuration selecting which graph-level optimization passes to run
/// and the parameters they use.
#[derive(Debug, Clone)]
pub struct GraphOptConfig {
    /// Fuse chains of elementwise operations (`fuse_elementwise_operations`).
    pub enable_fusion: bool,
    /// Optimize tensor memory layouts (`optimize_layouts`).
    pub enable_layout_opt: bool,
    /// Run memory analysis and count a memory optimization when utilization is low.
    pub enable_memory_opt: bool,
    /// Fold constant subexpressions (`fold_constants_aggressive`).
    pub enable_constant_folding: bool,
    /// Apply tiling to the graph (`apply_tiling`); requires `tile_size`.
    pub enable_tiling: bool,
    /// Compute an execution schedule (`GraphScheduler`).
    pub enable_scheduling: bool,
    /// Tile edge length used when `enable_tiling` is set; tiling is skipped if `None`.
    pub tile_size: Option<usize>,
    /// Optional memory budget in bytes. Not read by `apply_graph_optimizations`
    /// as written — presumably reserved for future budget-aware passes.
    pub memory_budget: Option<usize>,
}
34
35impl Default for GraphOptConfig {
36 fn default() -> Self {
37 Self {
38 enable_fusion: true,
39 enable_layout_opt: true,
40 enable_memory_opt: true,
41 enable_constant_folding: true,
42 enable_tiling: false, enable_scheduling: true,
44 tile_size: Some(64),
45 memory_budget: None, }
47 }
48}
49
50impl GraphOptConfig {
51 pub fn aggressive() -> Self {
53 Self {
54 enable_fusion: true,
55 enable_layout_opt: true,
56 enable_memory_opt: true,
57 enable_constant_folding: true,
58 enable_tiling: true,
59 enable_scheduling: true,
60 tile_size: Some(128),
61 memory_budget: None,
62 }
63 }
64
65 pub fn conservative() -> Self {
67 Self {
68 enable_fusion: true,
69 enable_layout_opt: false,
70 enable_memory_opt: false,
71 enable_constant_folding: true,
72 enable_tiling: false,
73 enable_scheduling: false,
74 tile_size: None,
75 memory_budget: None,
76 }
77 }
78
79 pub fn for_inference() -> Self {
81 Self {
82 enable_fusion: true,
83 enable_layout_opt: true,
84 enable_memory_opt: false,
85 enable_constant_folding: true,
86 enable_tiling: false,
87 enable_scheduling: true,
88 tile_size: None,
89 memory_budget: None,
90 }
91 }
92
93 pub fn for_training() -> Self {
95 Self {
96 enable_fusion: true,
97 enable_layout_opt: true,
98 enable_memory_opt: true,
99 enable_constant_folding: true,
100 enable_tiling: true,
101 enable_scheduling: true,
102 tile_size: Some(64),
103 memory_budget: None,
104 }
105 }
106}
107
/// Statistics gathered while running `apply_graph_optimizations`.
#[derive(Debug, Default, Clone)]
pub struct GraphOptStats {
    /// Elementwise operations fused together.
    pub ops_fused: usize,
    /// Layout transformations identified as needed.
    pub layout_transforms: usize,
    /// Memory optimizations counted (0 or 1 as written).
    pub memory_opts: usize,
    /// Constant operations folded away.
    pub graph_constants_folded: usize,
    /// Nodes tiled plus loops unrolled by the tiling pass.
    pub tiles_created: usize,
    /// Bytes saved, computed from the memory analysis (total minus peak).
    pub memory_saved: usize,
    /// Multiplicative speedup estimate accumulated across passes (starts at 1.0).
    pub estimated_speedup: f64,
}
126
127impl GraphOptStats {
128 pub fn total_optimizations(&self) -> usize {
130 self.ops_fused
131 + self.layout_transforms
132 + self.memory_opts
133 + self.graph_constants_folded
134 + self.tiles_created
135 }
136
137 pub fn any_applied(&self) -> bool {
139 self.total_optimizations() > 0
140 }
141}
142
143pub fn apply_graph_optimizations(
167 graph: &EinsumGraph,
168 config: &GraphOptConfig,
169) -> Result<(EinsumGraph, GraphOptStats)> {
170 let mut optimized = graph.clone();
171 let mut stats = GraphOptStats {
172 estimated_speedup: 1.0,
173 ..Default::default()
174 };
175
176 if config.enable_constant_folding {
178 let fold_result = fold_constants_aggressive(&mut optimized)?;
179 stats.graph_constants_folded = fold_result.operations_folded;
180 stats.estimated_speedup *= fold_result.estimated_speedup;
181 }
182
183 if config.enable_fusion {
185 let fusion_result = fuse_elementwise_operations(&mut optimized)?;
186 stats.ops_fused = fusion_result.ops_fused;
187 stats.estimated_speedup *= fusion_result.estimated_speedup;
188 }
189
190 if config.enable_layout_opt {
192 let layout_result = optimize_layouts(&optimized)?;
193 stats.layout_transforms = layout_result.transformations_needed;
194 stats.estimated_speedup *= layout_result.estimated_speedup;
195 }
196
197 if config.enable_memory_opt {
199 let mem_result = analyze_memory(&optimized, 8)?; stats.memory_saved = mem_result.total_memory_bytes - mem_result.peak_memory_bytes;
201 stats.memory_opts = if mem_result.avg_utilization < 0.8 {
203 1
204 } else {
205 0
206 };
207 }
208
209 if config.enable_tiling {
211 if let Some(tile_size) = config.tile_size {
212 use tensorlogic_ir::{TileConfig as IrTileConfig, TilingStrategy};
213 let mut strategy = TilingStrategy::new();
214 strategy.add_tile(IrTileConfig::new(0, tile_size));
215 let tiling_result = apply_tiling(&mut optimized, &strategy)?;
216 stats.tiles_created = tiling_result.nodes_tiled + tiling_result.loops_unrolled;
217 }
218 }
219
220 if config.enable_scheduling {
222 let scheduler = GraphScheduler::new();
223 let _schedule = scheduler.schedule(&optimized, SchedulingObjective::MinimizeMemory)?;
224 stats.estimated_speedup *= 1.05;
226 }
227
228 Ok((optimized, stats))
229}
230
/// Apply pattern-based rewrites to the graph.
///
/// Currently a no-op placeholder: returns an unchanged clone of `graph`
/// and a rewrite count of `0`.
pub fn apply_pattern_optimizations(graph: &EinsumGraph) -> Result<(EinsumGraph, usize)> {
    Ok((graph.clone(), 0))
}
252
253pub fn quick_optimize(graph: &EinsumGraph) -> Result<EinsumGraph> {
268 let config = GraphOptConfig {
269 enable_fusion: true,
270 enable_layout_opt: false,
271 enable_memory_opt: false,
272 enable_constant_folding: true,
273 enable_tiling: false,
274 enable_scheduling: false,
275 tile_size: None,
276 memory_budget: None,
277 };
278
279 let (optimized, _) = apply_graph_optimizations(graph, &config)?;
280 Ok(optimized)
281}
282
283pub fn recommend_optimizations(graph: &EinsumGraph) -> GraphOptConfig {
299 let node_count = graph.nodes.len();
300 let tensor_count = graph.tensors.len();
301
302 let mut elementwise_count = 0;
304 let mut einsum_count = 0;
305
306 for node in &graph.nodes {
307 match &node.op {
308 OpType::ElemUnary { .. } | OpType::ElemBinary { .. } => elementwise_count += 1,
309 OpType::Einsum { .. } => einsum_count += 1,
310 _ => {}
311 }
312 }
313
314 if node_count < 10 {
316 return GraphOptConfig {
317 enable_fusion: elementwise_count > 2,
318 enable_layout_opt: false,
319 enable_memory_opt: false,
320 enable_constant_folding: true,
321 enable_tiling: false,
322 enable_scheduling: false,
323 tile_size: None,
324 memory_budget: None,
325 };
326 }
327
328 if node_count < 50 {
330 return GraphOptConfig {
331 enable_fusion: elementwise_count > 3,
332 enable_layout_opt: einsum_count > 5,
333 enable_memory_opt: tensor_count > 20,
334 enable_constant_folding: true,
335 enable_tiling: false,
336 enable_scheduling: true,
337 tile_size: Some(64),
338 memory_budget: None,
339 };
340 }
341
342 GraphOptConfig {
344 enable_fusion: true,
345 enable_layout_opt: true,
346 enable_memory_opt: true,
347 enable_constant_folding: true,
348 enable_tiling: einsum_count > 10,
349 enable_scheduling: true,
350 tile_size: Some(128),
351 memory_budget: None,
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    // Default preset keeps the cheap, always-beneficial passes on.
    #[test]
    fn test_config_defaults() {
        let config = GraphOptConfig::default();
        assert!(config.enable_fusion);
        assert!(config.enable_constant_folding);
    }

    // Aggressive preset enables every pass.
    #[test]
    fn test_config_aggressive() {
        let config = GraphOptConfig::aggressive();
        assert!(config.enable_fusion);
        assert!(config.enable_layout_opt);
        assert!(config.enable_memory_opt);
        assert!(config.enable_tiling);
        assert!(config.enable_scheduling);
    }

    // Conservative preset keeps fusion but disables layout and tiling.
    #[test]
    fn test_config_conservative() {
        let config = GraphOptConfig::conservative();
        assert!(config.enable_fusion);
        assert!(!config.enable_layout_opt);
        assert!(!config.enable_tiling);
    }

    // total_optimizations sums only the discrete counters (5+3+2+1+0 = 11);
    // memory_saved and estimated_speedup are excluded.
    #[test]
    fn test_stats_total_optimizations() {
        let stats = GraphOptStats {
            ops_fused: 5,
            layout_transforms: 3,
            memory_opts: 2,
            graph_constants_folded: 1,
            tiles_created: 0,
            memory_saved: 1024,
            estimated_speedup: 1.5,
        };
        assert_eq!(stats.total_optimizations(), 11);
        assert!(stats.any_applied());
    }

    // quick_optimize must not fail on a graph with no nodes.
    #[test]
    fn test_quick_optimize_empty_graph() {
        let graph = EinsumGraph::new();
        let result = quick_optimize(&graph);
        assert!(result.is_ok());
    }

    // A 1-node graph falls into the small-graph branch: no tiling,
    // constant folding always on.
    #[test]
    fn test_recommend_optimizations_small_graph() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x");
        let t1 = graph.add_tensor("y");
        let t2 = graph.add_tensor("z");
        let _ = graph.add_node(EinsumNode::elem_binary("add", t0, t1, t2));

        let config = recommend_optimizations(&graph);
        assert!(!config.enable_tiling);
        assert!(config.enable_constant_folding);
    }

    // 30 elementwise nodes hit the medium-graph branch: fusion (count > 3)
    // and scheduling are recommended.
    #[test]
    fn test_recommend_optimizations_medium_graph() {
        let mut graph = EinsumGraph::new();
        for i in 0..30 {
            let t_in = graph.add_tensor(format!("t{}", i));
            let t_out = graph.add_tensor(format!("t{}_out", i));
            let _ = graph.add_node(EinsumNode::elem_unary("relu", t_in, t_out));
        }

        let config = recommend_optimizations(&graph);
        assert!(config.enable_fusion);
        assert!(config.enable_scheduling);
    }

    // The default pipeline should succeed and never report a slowdown
    // (speedup accumulates multiplicatively from 1.0).
    #[test]
    fn test_apply_optimizations_with_default_config() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("x");
        let t1 = graph.add_tensor("const");
        let t2 = graph.add_tensor("result");
        let _ = graph.add_node(EinsumNode::elem_binary("add", t0, t1, t2));

        let config = GraphOptConfig::default();
        let result = apply_graph_optimizations(&graph, &config);
        assert!(result.is_ok());

        let (_optimized, stats) = result.unwrap();
        assert!(stats.estimated_speedup >= 1.0);
    }

    // The pattern pass is a placeholder; it must succeed on an empty graph.
    #[test]
    fn test_apply_pattern_optimizations_empty() {
        let graph = EinsumGraph::new();
        let result = apply_pattern_optimizations(&graph);
        assert!(result.is_ok());
    }

    // All-zero stats report nothing applied.
    #[test]
    fn test_stats_any_applied_false() {
        let stats = GraphOptStats::default();
        assert!(!stats.any_applied());
    }

    // Inference preset: layout on, memory/tiling off.
    #[test]
    fn test_config_for_inference() {
        let config = GraphOptConfig::for_inference();
        assert!(config.enable_fusion);
        assert!(config.enable_layout_opt);
        assert!(!config.enable_memory_opt);
        assert!(!config.enable_tiling);
    }

    // Training preset: memory opt and tiling on, with a 64-wide tile.
    #[test]
    fn test_config_for_training() {
        let config = GraphOptConfig::for_training();
        assert!(config.enable_memory_opt);
        assert!(config.enable_tiling);
        assert_eq!(config.tile_size, Some(64));
    }
}