// tensorlogic_compiler/passes/contraction_opt.rs
use std::collections::HashMap;
use tensorlogic_ir::{EinsumGraph, OpType};
42
/// Counters and percentages reported by the contraction-optimization pass.
#[derive(Debug, Clone, Default)]
pub struct ContractionOptStats {
    /// Number of einsum nodes for which a cheaper contraction path was found.
    pub contractions_reordered: usize,
    /// Aggregate FLOP reduction, expressed as a percentage.
    pub flops_reduction_percent: f64,
    /// Aggregate memory reduction, expressed as a percentage.
    pub memory_reduction_percent: f64,
    /// Number of intermediate tensors that were avoided.
    pub intermediates_saved: usize,
    /// Number of graph nodes visited by the pass (einsum or not).
    pub total_processed: usize,
}

impl ContractionOptStats {
    /// Returns the combined count of optimizations: reordered contractions
    /// plus saved intermediates.
    pub fn total_optimizations(&self) -> usize {
        let Self {
            contractions_reordered,
            intermediates_saved,
            ..
        } = self;
        *contractions_reordered + *intermediates_saved
    }
}
64
/// Tuning knobs for the contraction-order search.
#[derive(Debug, Clone)]
pub struct ContractionOptConfig {
    /// When `true`, the subset dynamic-programming search is used for
    /// contractions with at most `max_dp_size` inputs.
    pub use_dynamic_programming: bool,
    /// Largest number of inputs for which the DP search is attempted.
    /// The DP visits every subset of the inputs, so its work grows
    /// exponentially with this value.
    pub max_dp_size: usize,
    /// Fraction of the DP cost reported as FLOPs; the remainder
    /// (`1.0 - flops_memory_tradeoff`) is reported as memory.
    pub flops_memory_tradeoff: f64,
    /// When `true`, the greedy search is used whenever the DP search is not
    /// selected.
    pub enable_greedy_fallback: bool,
}

impl Default for ContractionOptConfig {
    /// DP enabled for up to 26 inputs, a 0.7 FLOPs/memory split, and the
    /// greedy fallback switched on.
    fn default() -> Self {
        let use_dynamic_programming = true;
        let enable_greedy_fallback = true;
        Self {
            use_dynamic_programming,
            enable_greedy_fallback,
            max_dp_size: 26,
            flops_memory_tradeoff: 0.7,
        }
    }
}
88
/// Shape of a tensor whose individual dimensions may be unknown.
#[derive(Debug, Clone)]
pub struct TensorShape {
    /// Extent of each axis, outermost first; `None` marks an unknown extent.
    pub dims: Vec<Option<usize>>,
}

impl TensorShape {
    /// Wraps the given per-axis extents in a `TensorShape`.
    pub fn new(dims: Vec<Option<usize>>) -> Self {
        Self { dims }
    }

    /// Returns the total number of elements (the product of all extents).
    ///
    /// Returns `None` when any extent is unknown, or when the product does
    /// not fit in `usize`. A rank-0 shape yields `Some(1)`.
    pub fn num_elements(&self) -> Option<usize> {
        // `checked_mul` turns an overflowing product into `None` instead of a
        // debug-build panic or a silently wrapped value in release builds.
        self.dims
            .iter()
            .try_fold(1usize, |total, dim| total.checked_mul((*dim)?))
    }

    /// Returns the number of axes.
    pub fn rank(&self) -> usize {
        self.dims.len()
    }
}
116
/// A pairwise contraction order together with its estimated cost.
#[derive(Debug, Clone)]
pub struct ContractionPath {
    /// One entry per pairwise contraction, naming the two operands combined
    /// in that step. The greedy search numbers the original inputs `0..n`
    /// and gives the result of step `k` the id `n + k`.
    pub steps: Vec<(usize, usize)>,
    /// Estimated arithmetic cost, in the units of the internal cost model.
    pub estimated_flops: f64,
    /// Estimated memory cost, in the units of the internal cost model.
    pub estimated_memory: f64,
}
127
128pub fn optimize_contractions(graph: &EinsumGraph) -> (EinsumGraph, ContractionOptStats) {
130 optimize_contractions_with_config(graph, &ContractionOptConfig::default())
131}
132
133pub fn optimize_contractions_with_config(
135 graph: &EinsumGraph,
136 config: &ContractionOptConfig,
137) -> (EinsumGraph, ContractionOptStats) {
138 let optimized = graph.clone();
139 let mut stats = ContractionOptStats::default();
140
141 for node in graph.nodes.iter() {
143 if let OpType::Einsum { spec } = &node.op {
144 if let Some(optimal_path) = find_optimal_path(spec.as_str(), &node.inputs, config) {
146 let original_cost = estimate_einsum_cost(spec.as_str(), &node.inputs);
148 let new_cost = optimal_path.estimated_flops;
149
150 if new_cost < original_cost {
151 let reduction = (original_cost - new_cost) / original_cost * 100.0;
152 stats.flops_reduction_percent =
153 (stats.flops_reduction_percent + reduction) / 2.0;
154 stats.contractions_reordered += 1;
155 }
156 }
157 }
158
159 stats.total_processed += 1;
160 }
161
162 (optimized, stats)
163}
164
165fn find_optimal_path(
167 spec: &str,
168 inputs: &[usize],
169 config: &ContractionOptConfig,
170) -> Option<ContractionPath> {
171 let (input_specs, output_spec) = parse_einsum_spec(spec)?;
173
174 if input_specs.len() != inputs.len() {
175 return None;
176 }
177
178 if config.use_dynamic_programming && inputs.len() <= config.max_dp_size {
180 find_optimal_path_dp(&input_specs, output_spec, config)
181 } else if config.enable_greedy_fallback {
182 find_optimal_path_greedy(&input_specs, output_spec)
184 } else {
185 None
186 }
187}
188
189fn find_optimal_path_dp(
191 input_specs: &[String],
192 _output_spec: &str,
193 config: &ContractionOptConfig,
194) -> Option<ContractionPath> {
195 let n = input_specs.len();
196 if n < 2 {
197 return None;
198 }
199
200 let mut dp: HashMap<u64, (f64, Option<(u64, u64)>)> = HashMap::new();
202
203 for i in 0..n {
205 let mask = 1u64 << i;
206 dp.insert(mask, (0.0, None));
207 }
208
209 for mask in 1u64..(1u64 << n) {
211 if mask.count_ones() == 1 {
212 continue; }
214
215 let mut best_cost = f64::INFINITY;
216 let mut best_split = None;
217
218 let mut submask = mask;
220 while submask > 0 {
221 if submask != mask {
222 let complement = mask ^ submask;
223
224 let left_cost = dp.get(&submask).map(|(c, _)| *c).unwrap_or(0.0);
226 let right_cost = dp.get(&complement).map(|(c, _)| *c).unwrap_or(0.0);
227 let merge_cost = estimate_merge_cost(submask, complement, n);
228
229 let total_cost = left_cost + right_cost + merge_cost;
230
231 if total_cost < best_cost {
232 best_cost = total_cost;
233 best_split = Some((submask, complement));
234 }
235 }
236
237 submask = (submask.wrapping_sub(1)) & mask;
238 }
239
240 dp.insert(mask, (best_cost, best_split));
241 }
242
243 let full_mask = (1u64 << n) - 1;
245 let (final_cost, _) = dp.get(&full_mask)?;
246
247 Some(ContractionPath {
248 steps: vec![], estimated_flops: *final_cost * config.flops_memory_tradeoff,
250 estimated_memory: *final_cost * (1.0 - config.flops_memory_tradeoff),
251 })
252}
253
254fn find_optimal_path_greedy(input_specs: &[String], _output_spec: &str) -> Option<ContractionPath> {
256 let n = input_specs.len();
257 if n < 2 {
258 return None;
259 }
260
261 let mut steps = Vec::new();
262 let mut remaining: Vec<usize> = (0..n).collect();
263 let mut total_flops = 0.0;
264
265 while remaining.len() > 1 {
266 let mut best_pair = (0, 1);
268 let mut best_cost = f64::INFINITY;
269
270 for i in 0..remaining.len() {
271 for j in (i + 1)..remaining.len() {
272 let cost = estimate_pairwise_cost(remaining[i], remaining[j], n);
273 if cost < best_cost {
274 best_cost = cost;
275 best_pair = (i, j);
276 }
277 }
278 }
279
280 steps.push((remaining[best_pair.0], remaining[best_pair.1]));
282 total_flops += best_cost;
283
284 let new_idx = n + steps.len() - 1;
286 remaining.remove(best_pair.1);
287 remaining.remove(best_pair.0);
288 remaining.push(new_idx);
289 }
290
291 Some(ContractionPath {
292 steps,
293 estimated_flops: total_flops,
294 estimated_memory: total_flops * 0.5, })
296}
297
/// Splits an einsum spec such as `"ij,jk->ik"` into its operand terms and
/// its output term.
///
/// The spec must contain exactly one `->`. The left side is split on commas;
/// every term, and the output, is trimmed of surrounding whitespace.
/// Returns `None` when `->` is missing or occurs more than once.
fn parse_einsum_spec(spec: &str) -> Option<(Vec<String>, &str)> {
    let mut sides = spec.split("->");
    match (sides.next(), sides.next(), sides.next()) {
        (Some(operands), Some(result), None) => {
            let operand_specs = operands
                .split(',')
                .map(|term| term.trim().to_string())
                .collect();
            Some((operand_specs, result.trim()))
        }
        _ => None,
    }
}
308
/// Heuristic cost of evaluating an einsum node as given.
///
/// The cost is `1000.0` per input plus `10.0` times each input id. The spec
/// string is currently not consulted.
fn estimate_einsum_cost(_spec: &str, inputs: &[usize]) -> f64 {
    let id_term: f64 = inputs.iter().map(|&id| 10.0 * id as f64).sum();
    let per_input_term = 1000.0 * inputs.len() as f64;
    per_input_term + id_term
}
319
/// Heuristic cost of joining two input subsets given as bit masks.
///
/// The cost is `100.0` times the product of the number of set bits in each
/// mask. The total input count `_n` is currently not consulted.
fn estimate_merge_cost(mask1: u64, mask2: u64, _n: usize) -> f64 {
    let left_members = f64::from(mask1.count_ones());
    let right_members = f64::from(mask2.count_ones());
    (left_members * right_members) * 100.0
}
329
/// Heuristic cost of contracting the tensors with ids `idx1` and `idx2`.
///
/// The cost is `50.0 * (idx1 + 1) * (idx2 + 1)`. The total input count `_n`
/// is currently not consulted.
fn estimate_pairwise_cost(idx1: usize, idx2: usize, _n: usize) -> f64 {
    let first_weight = idx1 as f64 + 1.0;
    let second_weight = idx2 as f64 + 1.0;
    first_weight * second_weight * 50.0
}
335
336pub fn analyze_contraction_path(path: &ContractionPath) -> String {
338 let mut analysis = String::new();
339
340 analysis.push_str("Contraction Path Analysis:\n");
341 analysis.push_str(&format!(" Steps: {}\n", path.steps.len()));
342 analysis.push_str(&format!(
343 " Estimated FLOPs: {:.2e}\n",
344 path.estimated_flops
345 ));
346 analysis.push_str(&format!(
347 " Estimated Memory: {:.2e}\n",
348 path.estimated_memory
349 ));
350
351 if path.estimated_flops > 1e9 {
352 analysis.push_str(" Warning: High computational cost\n");
353 }
354
355 if path.estimated_memory > 1e8 {
356 analysis.push_str(" Warning: High memory usage\n");
357 }
358
359 analysis
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully known shape reports its rank and element count.
    #[test]
    fn test_tensor_shape() {
        let known = TensorShape::new(vec![Some(10), Some(20), Some(30)]);
        assert_eq!(known.rank(), 3);
        assert_eq!(known.num_elements(), Some(6000));
    }

    /// One unknown extent makes the element count unknown.
    #[test]
    fn test_tensor_shape_unknown_dims() {
        let partial = TensorShape::new(vec![Some(10), None, Some(30)]);
        assert_eq!(partial.rank(), 3);
        assert_eq!(partial.num_elements(), None);
    }

    /// A two-operand spec splits into its operand terms and output term.
    #[test]
    fn test_parse_einsum_spec() {
        let (operands, result) = parse_einsum_spec("ij,jk->ik").unwrap();

        assert_eq!(operands.len(), 2);
        assert_eq!(operands[0], "ij");
        assert_eq!(operands[1], "jk");
        assert_eq!(result, "ik");
    }

    /// A three-operand spec is parsed the same way.
    #[test]
    fn test_parse_einsum_spec_complex() {
        let (operands, result) = parse_einsum_spec("ijk,klm,mnp->ijnp").unwrap();

        assert_eq!(operands.len(), 3);
        assert_eq!(result, "ijnp");
    }

    /// Greedy search over three inputs yields two pairwise steps.
    #[test]
    fn test_find_optimal_path_greedy() {
        let operands: Vec<String> = ["ij", "jk", "kl"].iter().map(|s| s.to_string()).collect();

        let found = find_optimal_path_greedy(&operands, "il");
        assert!(found.is_some());

        let found = found.unwrap();
        assert_eq!(found.steps.len(), 2);
        assert!(found.estimated_flops > 0.0);
    }

    /// More inputs cost more under the einsum cost heuristic.
    #[test]
    fn test_estimate_einsum_cost() {
        let two_input_cost = estimate_einsum_cost("ij,jk->ik", &[0, 1]);
        let three_input_cost = estimate_einsum_cost("ijk,klm,mnp->ijnp", &[0, 1, 2]);

        assert!(two_input_cost > 0.0);
        assert!(three_input_cost > two_input_cost);
    }

    /// An empty graph produces no reordered contractions.
    #[test]
    fn test_optimize_contractions() {
        let empty_graph = EinsumGraph::new();
        let (_optimized, stats) = optimize_contractions(&empty_graph);

        assert_eq!(stats.contractions_reordered, 0);
    }

    /// The default configuration enables DP with the expected limits.
    #[test]
    fn test_config_default() {
        let defaults = ContractionOptConfig::default();

        assert!(defaults.use_dynamic_programming);
        assert_eq!(defaults.max_dp_size, 26);
        assert!(defaults.flops_memory_tradeoff > 0.0);
        assert!(defaults.flops_memory_tradeoff <= 1.0);
    }

    /// `total_optimizations` adds reordered contractions and saved intermediates.
    #[test]
    fn test_stats_total_optimizations() {
        let sample = ContractionOptStats {
            contractions_reordered: 3,
            flops_reduction_percent: 25.0,
            memory_reduction_percent: 15.0,
            intermediates_saved: 2,
            total_processed: 10,
        };

        assert_eq!(sample.total_optimizations(), 5);
    }

    /// The report mentions the step count and both cost estimates.
    #[test]
    fn test_analyze_contraction_path() {
        let sample = ContractionPath {
            steps: vec![(0, 1), (2, 3)],
            estimated_flops: 1e6,
            estimated_memory: 1e5,
        };

        let report = analyze_contraction_path(&sample);
        assert!(report.contains("Steps: 2"));
        assert!(report.contains("FLOPs"));
        assert!(report.contains("Memory"));
    }

    /// Merging larger subsets costs more than merging single inputs.
    #[test]
    fn test_estimate_merge_cost() {
        let singles = estimate_merge_cost(0b0001u64, 0b0010u64, 4);
        let pairs = estimate_merge_cost(0b0011u64, 0b1100u64, 4);

        assert!(singles > 0.0);
        assert!(pairs > singles);
    }

    /// Pairwise costs are strictly positive.
    #[test]
    fn test_estimate_pairwise_cost() {
        let low_ids = estimate_pairwise_cost(0, 1, 3);
        let high_ids = estimate_pairwise_cost(1, 2, 3);

        assert!(low_ids > 0.0);
        assert!(high_ids > 0.0);
    }

    /// Estimates above the thresholds add a warning to the report.
    #[test]
    fn test_contraction_path_high_cost_warning() {
        let expensive = ContractionPath {
            steps: vec![(0, 1)],
            estimated_flops: 1e10,
            estimated_memory: 1e9,
        };

        let report = analyze_contraction_path(&expensive);
        assert!(report.contains("Warning"));
    }
}