Skip to main content

tensorlogic_compiler/passes/
contraction_opt.rs

//! Tensor contraction optimization pass.
//!
//! This module optimizes the order of tensor contractions in einsum operations
//! to minimize computational cost and memory usage.
//!
//! # Overview
//!
//! Tensor contractions (einsum operations) can be performed in different orders,
//! with dramatically different computational costs. For example,
//! ```text
//! einsum("ij,jk,kl->il", A, B, C)
//! ```
//! can be computed as either `(A @ B) @ C` or `A @ (B @ C)`. For square `n × n`
//! operands both orders cost O(n³), but for rectangular operands one order can be
//! much cheaper. With `A: 10×100`, `B: 100×5` and `C: 5×50`:
//! - `(A @ B) @ C` costs 10·100·5 + 10·5·50 = 7,500 multiply-adds
//! - `A @ (B @ C)` costs 100·5·50 + 10·100·50 = 75,000 multiply-adds
//!
//! # Optimization Strategy
//!
//! This pass searches for a contraction order with a dynamic programming
//! algorithm (falling back to a greedy search for large problems), aiming to
//! minimize:
//! 1. Total FLOPs (floating-point operations)
//! 2. Peak memory usage
//! 3. Number of intermediate tensors
//!
//! Costs are currently estimated with shape-agnostic heuristics, and the graph
//! returned by [`optimize_contractions`] is an unmodified copy of the input: only
//! the returned statistics reflect the search.
//!
//! # Examples
//!
//! ```rust
//! use tensorlogic_compiler::passes::optimize_contractions;
//! use tensorlogic_ir::EinsumGraph;
//!
//! let graph = EinsumGraph::new();
//! // ... build graph with einsum operations ...
//!
//! let (optimized, stats) = optimize_contractions(&graph);
//! println!("Reduced FLOPs by {:.1}%", stats.flops_reduction_percent);
//! ```
39
40use std::collections::HashMap;
41use tensorlogic_ir::{EinsumGraph, OpType};
42
/// Aggregate statistics reported by the contraction optimization pass.
#[derive(Debug, Clone, Default)]
pub struct ContractionOptStats {
    /// How many einsum nodes had a cheaper contraction order found.
    pub contractions_reordered: usize,
    /// Estimated FLOP reduction, as a percentage.
    pub flops_reduction_percent: f64,
    /// Estimated memory reduction, as a percentage.
    pub memory_reduction_percent: f64,
    /// How many intermediate tensors were avoided.
    pub intermediates_saved: usize,
    /// How many graph nodes were visited (einsum or not).
    pub total_processed: usize,
}

impl ContractionOptStats {
    /// Returns the number of optimizations applied: reordered contractions plus
    /// saved intermediates.
    pub fn total_optimizations(&self) -> usize {
        let Self {
            contractions_reordered,
            intermediates_saved,
            ..
        } = *self;
        contractions_reordered + intermediates_saved
    }
}
64
/// Tuning knobs for the contraction optimization pass.
#[derive(Debug, Clone)]
pub struct ContractionOptConfig {
    /// Search with dynamic programming when the einsum has at most
    /// `max_dp_size` operands.
    pub use_dynamic_programming: bool,
    /// Largest operand count handed to the DP search. The DP enumerates every
    /// (subset, sub-subset) pair of operands, so it costs O(3^n) time and keeps
    /// a table with one entry per operand subset (2^n); keep this small.
    pub max_dp_size: usize,
    /// Trade-off between optimizing for FLOPs and for memory
    /// (0.0 = memory, 1.0 = FLOPs).
    pub flops_memory_tradeoff: f64,
    /// Use the greedy search when DP is disabled or the problem has more than
    /// `max_dp_size` operands.
    pub enable_greedy_fallback: bool,
}

impl Default for ContractionOptConfig {
    fn default() -> Self {
        // NOTE(review): with an O(3^n) DP, 26 operands means ~2.5e12 split
        // evaluations, which is not cheap; the value is kept for compatibility.
        let max_dp_size = 26;
        Self {
            enable_greedy_fallback: true,
            flops_memory_tradeoff: 0.7, // lean toward reducing FLOPs
            max_dp_size,
            use_dynamic_programming: true,
        }
    }
}
88
/// Tensor shape information for cost estimation.
#[derive(Debug, Clone)]
pub struct TensorShape {
    /// Dimension sizes; `None` marks a dimension whose size is unknown.
    pub dims: Vec<Option<usize>>,
}

impl TensorShape {
    /// Creates a shape from its per-dimension sizes.
    pub fn new(dims: Vec<Option<usize>>) -> Self {
        Self { dims }
    }

    /// Returns the total number of elements.
    ///
    /// Returns `None` if any dimension is unknown, or if the element count does
    /// not fit in `usize`. A rank-0 (scalar) shape has one element.
    pub fn num_elements(&self) -> Option<usize> {
        // `checked_mul` keeps huge shapes from overflowing: unchecked `*=` would
        // panic in debug builds and silently wrap in release builds.
        self.dims
            .iter()
            .try_fold(1usize, |total, &dim| total.checked_mul(dim?))
    }

    /// Returns the rank (number of dimensions).
    pub fn rank(&self) -> usize {
        self.dims.len()
    }
}
116
/// An ordered plan for contracting the operands of one einsum.
///
/// Operands are identified by their position `0..n`. The greedy search numbers
/// intermediates so that the result of step `k` gets id `n + k`.
#[derive(Debug, Clone)]
pub struct ContractionPath {
    /// Pairs of tensor ids contracted at each step, in execution order.
    pub steps: Vec<(usize, usize)>,
    /// Estimated total FLOPs for the whole path (heuristic units).
    pub estimated_flops: f64,
    /// Estimated peak memory usage (heuristic units).
    pub estimated_memory: f64,
}
127
128/// Optimize tensor contractions in an einsum graph.
129pub fn optimize_contractions(graph: &EinsumGraph) -> (EinsumGraph, ContractionOptStats) {
130    optimize_contractions_with_config(graph, &ContractionOptConfig::default())
131}
132
133/// Optimize contractions with custom configuration.
134pub fn optimize_contractions_with_config(
135    graph: &EinsumGraph,
136    config: &ContractionOptConfig,
137) -> (EinsumGraph, ContractionOptStats) {
138    let optimized = graph.clone();
139    let mut stats = ContractionOptStats::default();
140
141    // Find einsum nodes that can be optimized
142    for node in graph.nodes.iter() {
143        if let OpType::Einsum { spec } = &node.op {
144            // Parse einsum spec and optimize contraction order
145            if let Some(optimal_path) = find_optimal_path(spec.as_str(), &node.inputs, config) {
146                // Estimate cost reduction
147                let original_cost = estimate_einsum_cost(spec.as_str(), &node.inputs);
148                let new_cost = optimal_path.estimated_flops;
149
150                if new_cost < original_cost {
151                    let reduction = (original_cost - new_cost) / original_cost * 100.0;
152                    stats.flops_reduction_percent =
153                        (stats.flops_reduction_percent + reduction) / 2.0;
154                    stats.contractions_reordered += 1;
155                }
156            }
157        }
158
159        stats.total_processed += 1;
160    }
161
162    (optimized, stats)
163}
164
165/// Find the optimal contraction path for an einsum operation.
166fn find_optimal_path(
167    spec: &str,
168    inputs: &[usize],
169    config: &ContractionOptConfig,
170) -> Option<ContractionPath> {
171    // Parse the einsum specification
172    let (input_specs, output_spec) = parse_einsum_spec(spec)?;
173
174    if input_specs.len() != inputs.len() {
175        return None;
176    }
177
178    // Use dynamic programming for small problems
179    if config.use_dynamic_programming && inputs.len() <= config.max_dp_size {
180        find_optimal_path_dp(&input_specs, output_spec, config)
181    } else if config.enable_greedy_fallback {
182        // Use greedy algorithm for large problems
183        find_optimal_path_greedy(&input_specs, output_spec)
184    } else {
185        None
186    }
187}
188
189/// Find optimal path using dynamic programming (optimal but exponential complexity).
190fn find_optimal_path_dp(
191    input_specs: &[String],
192    _output_spec: &str,
193    config: &ContractionOptConfig,
194) -> Option<ContractionPath> {
195    let n = input_specs.len();
196    if n < 2 {
197        return None;
198    }
199
200    // DP table: dp[mask] = (best_cost, best_split)
201    let mut dp: HashMap<u64, (f64, Option<(u64, u64)>)> = HashMap::new();
202
203    // Base case: single tensors
204    for i in 0..n {
205        let mask = 1u64 << i;
206        dp.insert(mask, (0.0, None));
207    }
208
209    // Fill DP table
210    for mask in 1u64..(1u64 << n) {
211        if mask.count_ones() == 1 {
212            continue; // Already handled in base case
213        }
214
215        let mut best_cost = f64::INFINITY;
216        let mut best_split = None;
217
218        // Try all possible splits
219        let mut submask = mask;
220        while submask > 0 {
221            if submask != mask {
222                let complement = mask ^ submask;
223
224                // Cost of this split
225                let left_cost = dp.get(&submask).map(|(c, _)| *c).unwrap_or(0.0);
226                let right_cost = dp.get(&complement).map(|(c, _)| *c).unwrap_or(0.0);
227                let merge_cost = estimate_merge_cost(submask, complement, n);
228
229                let total_cost = left_cost + right_cost + merge_cost;
230
231                if total_cost < best_cost {
232                    best_cost = total_cost;
233                    best_split = Some((submask, complement));
234                }
235            }
236
237            submask = (submask.wrapping_sub(1)) & mask;
238        }
239
240        dp.insert(mask, (best_cost, best_split));
241    }
242
243    // Reconstruct the path
244    let full_mask = (1u64 << n) - 1;
245    let (final_cost, _) = dp.get(&full_mask)?;
246
247    Some(ContractionPath {
248        steps: vec![], // Would need to reconstruct from DP table
249        estimated_flops: *final_cost * config.flops_memory_tradeoff,
250        estimated_memory: *final_cost * (1.0 - config.flops_memory_tradeoff),
251    })
252}
253
254/// Find optimal path using greedy algorithm (fast but suboptimal).
255fn find_optimal_path_greedy(input_specs: &[String], _output_spec: &str) -> Option<ContractionPath> {
256    let n = input_specs.len();
257    if n < 2 {
258        return None;
259    }
260
261    let mut steps = Vec::new();
262    let mut remaining: Vec<usize> = (0..n).collect();
263    let mut total_flops = 0.0;
264
265    while remaining.len() > 1 {
266        // Find the pair with minimum contraction cost
267        let mut best_pair = (0, 1);
268        let mut best_cost = f64::INFINITY;
269
270        for i in 0..remaining.len() {
271            for j in (i + 1)..remaining.len() {
272                let cost = estimate_pairwise_cost(remaining[i], remaining[j], n);
273                if cost < best_cost {
274                    best_cost = cost;
275                    best_pair = (i, j);
276                }
277            }
278        }
279
280        // Contract the best pair
281        steps.push((remaining[best_pair.0], remaining[best_pair.1]));
282        total_flops += best_cost;
283
284        // Remove contracted tensors and add result
285        let new_idx = n + steps.len() - 1;
286        remaining.remove(best_pair.1);
287        remaining.remove(best_pair.0);
288        remaining.push(new_idx);
289    }
290
291    Some(ContractionPath {
292        steps,
293        estimated_flops: total_flops,
294        estimated_memory: total_flops * 0.5, // Rough estimate
295    })
296}
297
/// Splits an einsum specification such as `"ij,jk->ik"` into its operand
/// subscripts (`["ij", "jk"]`) and its output subscripts (`"ik"`).
///
/// Whitespace around each part is trimmed. Returns `None` unless the spec
/// contains exactly one `->`.
fn parse_einsum_spec(spec: &str) -> Option<(Vec<String>, &str)> {
    let (lhs, rhs) = spec.split_once("->")?;
    if rhs.contains("->") {
        return None;
    }

    let operands = lhs.split(',').map(|part| part.trim().to_owned()).collect();
    Some((operands, rhs.trim()))
}
308
/// Heuristic cost of evaluating an einsum as written (no reordering).
///
/// The spec is currently ignored. The cost is 1000 per input plus 10 times the
/// value of each entry in `inputs`.
fn estimate_einsum_cost(_spec: &str, inputs: &[usize]) -> f64 {
    const COST_PER_INPUT: f64 = 1000.0;
    const COST_PER_INDEX_UNIT: f64 = 10.0;

    let input_term = COST_PER_INPUT * inputs.len() as f64;
    let index_term: f64 = inputs
        .iter()
        .map(|&idx| COST_PER_INDEX_UNIT * idx as f64)
        .sum();

    input_term + index_term
}
319
/// Heuristic cost of merging two groups of tensors, each given as a bitmask of
/// operand positions.
///
/// Proportional to the product of the two group sizes; `_n` is currently unused.
fn estimate_merge_cost(mask1: u64, mask2: u64, _n: usize) -> f64 {
    const SCALE: f64 = 100.0;
    let (lhs_size, rhs_size) = (f64::from(mask1.count_ones()), f64::from(mask2.count_ones()));
    lhs_size * rhs_size * SCALE
}
329
/// Heuristic cost of contracting the tensors with ids `idx1` and `idx2`.
///
/// Depends only on the ids themselves: `50 * (idx1 + 1) * (idx2 + 1)`, computed
/// in floating point. `_n` is currently unused.
fn estimate_pairwise_cost(idx1: usize, idx2: usize, _n: usize) -> f64 {
    let weight = |idx: usize| idx as f64 + 1.0;
    weight(idx1) * weight(idx2) * 50.0
}
335
336/// Analyze contraction path and provide recommendations.
337pub fn analyze_contraction_path(path: &ContractionPath) -> String {
338    let mut analysis = String::new();
339
340    analysis.push_str("Contraction Path Analysis:\n");
341    analysis.push_str(&format!("  Steps: {}\n", path.steps.len()));
342    analysis.push_str(&format!(
343        "  Estimated FLOPs: {:.2e}\n",
344        path.estimated_flops
345    ));
346    analysis.push_str(&format!(
347        "  Estimated Memory: {:.2e}\n",
348        path.estimated_memory
349    ));
350
351    if path.estimated_flops > 1e9 {
352        analysis.push_str("  Warning: High computational cost\n");
353    }
354
355    if path.estimated_memory > 1e8 {
356        analysis.push_str("  Warning: High memory usage\n");
357    }
358
359    analysis
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_shape() {
        let shape = TensorShape::new(vec![Some(10), Some(20), Some(30)]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), Some(10 * 20 * 30));
    }

    #[test]
    fn test_tensor_shape_unknown_dims() {
        let shape = TensorShape::new(vec![Some(10), None, Some(30)]);
        assert_eq!(shape.rank(), 3);
        assert!(shape.num_elements().is_none());
    }

    #[test]
    fn test_parse_einsum_spec() {
        let (operands, result) = parse_einsum_spec("ij,jk->ik").expect("spec should parse");
        assert_eq!(operands, ["ij", "jk"]);
        assert_eq!(result, "ik");
    }

    #[test]
    fn test_parse_einsum_spec_complex() {
        let (operands, result) =
            parse_einsum_spec("ijk,klm,mnp->ijnp").expect("spec should parse");
        assert_eq!(operands.len(), 3);
        assert_eq!(result, "ijnp");
    }

    #[test]
    fn test_find_optimal_path_greedy() {
        let operands: Vec<String> = ["ij", "jk", "kl"].iter().map(|s| s.to_string()).collect();
        let path = find_optimal_path_greedy(&operands, "il").expect("three operands yield a path");

        // Three tensors require two pairwise contractions.
        assert_eq!(path.steps.len(), 2);
        assert!(path.estimated_flops > 0.0);
    }

    #[test]
    fn test_estimate_einsum_cost() {
        let two_operands = estimate_einsum_cost("ij,jk->ik", &[0, 1]);
        let three_operands = estimate_einsum_cost("ijk,klm,mnp->ijnp", &[0, 1, 2]);

        assert!(two_operands > 0.0);
        // More operands means a higher estimated cost.
        assert!(three_operands > two_operands);
    }

    #[test]
    fn test_optimize_contractions() {
        let empty = EinsumGraph::new();
        let (_optimized, stats) = optimize_contractions(&empty);

        // An empty graph offers nothing to reorder.
        assert_eq!(stats.contractions_reordered, 0);
    }

    #[test]
    fn test_config_default() {
        let ContractionOptConfig {
            use_dynamic_programming,
            max_dp_size,
            flops_memory_tradeoff,
            ..
        } = ContractionOptConfig::default();

        assert!(use_dynamic_programming);
        assert_eq!(max_dp_size, 26);
        assert!(flops_memory_tradeoff > 0.0 && flops_memory_tradeoff <= 1.0);
    }

    #[test]
    fn test_stats_total_optimizations() {
        let stats = ContractionOptStats {
            contractions_reordered: 3,
            intermediates_saved: 2,
            flops_reduction_percent: 25.0,
            memory_reduction_percent: 15.0,
            total_processed: 10,
        };

        assert_eq!(stats.total_optimizations(), 3 + 2);
    }

    #[test]
    fn test_analyze_contraction_path() {
        let report = analyze_contraction_path(&ContractionPath {
            steps: vec![(0, 1), (2, 3)],
            estimated_flops: 1e6,
            estimated_memory: 1e5,
        });

        for &needle in ["Steps: 2", "FLOPs", "Memory"].iter() {
            assert!(report.contains(needle), "missing {:?} in {}", needle, report);
        }
    }

    #[test]
    fn test_estimate_merge_cost() {
        let singletons = estimate_merge_cost(0b0001, 0b0010, 4);
        let pairs = estimate_merge_cost(0b0011, 0b1100, 4);

        assert!(singletons > 0.0);
        // Merging larger groups costs more.
        assert!(pairs > singletons);
    }

    #[test]
    fn test_estimate_pairwise_cost() {
        for &(lhs, rhs) in [(0, 1), (1, 2)].iter() {
            assert!(estimate_pairwise_cost(lhs, rhs, 3) > 0.0);
        }
    }

    #[test]
    fn test_contraction_path_high_cost_warning() {
        let costly = ContractionPath {
            steps: vec![(0, 1)],
            estimated_flops: 1e10, // above the FLOPs warning threshold
            estimated_memory: 1e9, // above the memory warning threshold
        };

        assert!(analyze_contraction_path(&costly).contains("Warning"));
    }
}