Skip to main content

scivex_core/tensor/
einsum_path.rs

1//! Einsum contraction path optimizer.
2//!
3//! Determines the optimal order in which to contract pairs of tensors in a
4//! multi-operand einsum expression. This can dramatically reduce the number
5//! of floating-point operations and intermediate memory usage compared to
6//! contracting all tensors at once.
7//!
8//! Two strategies are provided:
9//! - **Greedy**: O(n^3) heuristic that picks the cheapest pairwise contraction
10//!   at each step. Works well in practice and is the default.
11//! - **Optimal**: Exhaustive search over all contraction orderings. Guaranteed
12//!   to find the minimum-FLOP path but is O(n!) and only practical for small
13//!   numbers of operands (≤ ~6).
14//!
15//! # Examples
16//!
17//! ```
18//! # use scivex_core::Tensor;
19//! # use scivex_core::tensor::einsum_path::{einsum_path, PathStrategy, einsum_optimized};
20//! let a = Tensor::from_vec(vec![1.0_f64; 6], vec![2, 3]).unwrap();
21//! let b = Tensor::from_vec(vec![1.0_f64; 12], vec![3, 4]).unwrap();
22//! let c = Tensor::from_vec(vec![1.0_f64; 20], vec![4, 5]).unwrap();
23//!
24//! // Get the contraction path
25//! let info = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();
26//! assert_eq!(info.path.len(), 2); // Two pairwise contractions
27//!
28//! // Use the optimized einsum
29//! let result = einsum_optimized("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();
30//! assert_eq!(result.shape(), &[2, 5]);
31//! ```
32
33use crate::Scalar;
34use crate::error::{CoreError, Result};
35use crate::tensor::Tensor;
36use std::collections::{BTreeMap, BTreeSet};
37
38use super::einsum::einsum;
39
/// Selects how [`einsum_path`] orders the pairwise contractions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PathStrategy {
    /// Greedy heuristic: at every step, contract whichever pair of current
    /// operands has the lowest estimated FLOP cost. Each step inspects all
    /// O(n^2) pairs, so planning is O(n^3) in the number of operands.
    Greedy,
    /// Exhaustive branch-and-bound search over every pairwise contraction
    /// order, returning the one with the lowest total FLOP estimate. The search
    /// grows factorially with the operand count, so it is only practical for a
    /// handful of operands; with more than 8 operands the greedy strategy is
    /// used instead.
    Optimal,
}
51
/// One pairwise contraction step.
///
/// Both fields are positions in the operand list as it stands *at this step*,
/// and `first` is always strictly smaller than `second`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContractionPair {
    /// Position of the lower-indexed operand taking part in this step.
    pub first: usize,
    /// Position of the higher-indexed operand taking part in this step.
    pub second: usize,
}
61
62/// Information about the chosen contraction path.
63#[derive(Debug, Clone)]
64pub struct PathInfo {
65    /// Sequence of pairwise contractions. Each pair refers to indices in the
66    /// operand list *at that step* (operands are removed and the intermediate
67    /// result is appended after each step).
68    pub path: Vec<ContractionPair>,
69    /// Estimated total number of multiply-add operations (FLOPs).
70    pub flops: usize,
71    /// Size of the largest intermediate tensor (in number of elements).
72    pub largest_intermediate: usize,
73}
74
/// Index labels of one operand (an input or an intermediate) as seen by the
/// path search.
#[derive(Clone, Debug)]
struct OperandDesc {
    /// Subscript characters, one per tensor axis, in axis order.
    indices: Vec<char>,
}
81
/// Result of subscript parsing: the index labels of each input operand, the
/// output index labels, and the extent bound to every index label.
type ParsedSubscripts = (Vec<Vec<char>>, Vec<char>, std::collections::BTreeMap<char, usize>);
84
85/// Compute the contraction path for an einsum expression.
86///
87/// Returns a [`PathInfo`] describing the order in which to contract pairs of
88/// operands, along with estimated FLOPs and intermediate sizes.
89pub fn einsum_path<T: Scalar>(
90    subscripts: &str,
91    operands: &[&Tensor<T>],
92    strategy: PathStrategy,
93) -> Result<PathInfo> {
94    let (input_subs, output_sub, index_sizes) = parse_path_subscripts(subscripts, operands)?;
95
96    let descs: Vec<OperandDesc> = input_subs
97        .iter()
98        .map(|indices| OperandDesc {
99            indices: indices.clone(),
100        })
101        .collect();
102
103    match strategy {
104        PathStrategy::Greedy => greedy_path(&descs, &output_sub, &index_sizes),
105        PathStrategy::Optimal => optimal_path(&descs, &output_sub, &index_sizes),
106    }
107}
108
109/// Execute an einsum using an optimized contraction path.
110///
111/// For expressions with 2 or fewer operands, this falls through to the
112/// standard `einsum`. For 3+ operands, it first computes a contraction path
113/// and then executes pairwise contractions in that order.
114pub fn einsum_optimized<T: Scalar>(
115    subscripts: &str,
116    operands: &[&Tensor<T>],
117    strategy: PathStrategy,
118) -> Result<Tensor<T>> {
119    if operands.len() <= 2 {
120        return einsum(subscripts, operands);
121    }
122
123    let (input_subs, output_sub, index_sizes) = parse_path_subscripts(subscripts, operands)?;
124
125    let descs: Vec<OperandDesc> = input_subs
126        .iter()
127        .map(|indices| OperandDesc {
128            indices: indices.clone(),
129        })
130        .collect();
131
132    let path_info = match strategy {
133        PathStrategy::Greedy => greedy_path(&descs, &output_sub, &index_sizes)?,
134        PathStrategy::Optimal => optimal_path(&descs, &output_sub, &index_sizes)?,
135    };
136
137    execute_path(subscripts, operands, &path_info, &input_subs, &output_sub)
138}
139
140// ======================================================================
141// Parsing
142// ======================================================================
143
144/// Parse subscripts for path computation. Returns (input index lists, output
145/// indices, index-to-size mapping).
146fn parse_path_subscripts<T: Scalar>(
147    subscripts: &str,
148    operands: &[&Tensor<T>],
149) -> Result<ParsedSubscripts> {
150    let subscripts = subscripts.replace(' ', "");
151
152    let (inputs_str, output_sub) = if let Some((inp, out)) = subscripts.split_once("->") {
153        let output_indices: Vec<char> = out.chars().collect();
154        (inp.to_string(), output_indices)
155    } else {
156        // Implicit mode: output = sorted unique indices that appear exactly once
157        let mut counts: BTreeMap<char, usize> = BTreeMap::new();
158        for c in subscripts.chars() {
159            if c == ',' {
160                continue;
161            }
162            *counts.entry(c).or_insert(0) += 1;
163        }
164        let output_indices: Vec<char> = counts
165            .iter()
166            .filter(|(_, count)| **count == 1)
167            .map(|(&c, _)| c)
168            .collect();
169        (subscripts.clone(), output_indices)
170    };
171
172    let input_parts: Vec<&str> = inputs_str.split(',').collect();
173    if input_parts.len() != operands.len() {
174        return Err(CoreError::InvalidArgument {
175            reason: "number of subscript groups does not match number of operands",
176        });
177    }
178
179    let mut input_subs = Vec::with_capacity(input_parts.len());
180    let mut index_sizes: BTreeMap<char, usize> = BTreeMap::new();
181
182    for (i, part) in input_parts.iter().enumerate() {
183        let indices: Vec<char> = part.chars().collect();
184        if indices.len() != operands[i].ndim() {
185            return Err(CoreError::InvalidArgument {
186                reason: "operand rank does not match number of subscript indices",
187            });
188        }
189        let shape = operands[i].shape();
190        for (d, &c) in indices.iter().enumerate() {
191            if let Some(&existing) = index_sizes.get(&c) {
192                if existing != shape[d] {
193                    return Err(CoreError::DimensionMismatch {
194                        expected: vec![existing],
195                        got: vec![shape[d]],
196                    });
197                }
198            } else {
199                index_sizes.insert(c, shape[d]);
200            }
201        }
202        input_subs.push(indices);
203    }
204
205    Ok((input_subs, output_sub, index_sizes))
206}
207
208// ======================================================================
209// Cost estimation
210// ======================================================================
211
212/// Compute the cost (FLOPs) and output size of contracting two operands.
213fn contraction_cost(
214    a: &OperandDesc,
215    b: &OperandDesc,
216    output_indices: &[char],
217    index_sizes: &BTreeMap<char, usize>,
218) -> (usize, Vec<char>, Vec<usize>) {
219    // Indices in the result of contracting a and b:
220    // Keep indices that appear in a or b AND (appear in the final output OR
221    // appear in other operands that haven't been contracted yet).
222    // For simplicity in cost estimation, we keep all indices that are in a or b
223    // but contract those that appear in BOTH a and b and NOT in output_indices.
224    let a_set: BTreeSet<char> = a.indices.iter().copied().collect();
225    let b_set: BTreeSet<char> = b.indices.iter().copied().collect();
226    let output_set: BTreeSet<char> = output_indices.iter().copied().collect();
227
228    // Contracted indices: in both a and b, not in the final output
229    let contracted: BTreeSet<char> = a_set
230        .intersection(&b_set)
231        .filter(|c| !output_set.contains(c))
232        .copied()
233        .collect();
234
235    // Result indices: union of a and b minus contracted
236    let mut result_indices: Vec<char> = Vec::new();
237    // First add a's indices (in order), skipping contracted
238    for &c in &a.indices {
239        if !contracted.contains(&c) && !result_indices.contains(&c) {
240            result_indices.push(c);
241        }
242    }
243    // Then add b's indices not already present
244    for &c in &b.indices {
245        if !contracted.contains(&c) && !result_indices.contains(&c) {
246            result_indices.push(c);
247        }
248    }
249
250    let result_shape: Vec<usize> = result_indices.iter().map(|c| index_sizes[c]).collect();
251
252    let result_size: usize = result_shape.iter().product::<usize>().max(1);
253    let contract_size: usize = contracted
254        .iter()
255        .map(|c| index_sizes[c])
256        .product::<usize>()
257        .max(1);
258
259    // FLOPs ≈ result_size * contract_size (one multiply-add per element per contraction)
260    let flops = result_size * contract_size;
261
262    (flops, result_indices, result_shape)
263}
264
265/// Compute result indices when contracting two operands, keeping indices that
266/// are needed by remaining operands or the final output.
267fn pairwise_result_indices(
268    a: &OperandDesc,
269    b: &OperandDesc,
270    remaining: &[OperandDesc],
271    final_output: &[char],
272    index_sizes: &BTreeMap<char, usize>,
273) -> (Vec<char>, Vec<usize>) {
274    let a_set: BTreeSet<char> = a.indices.iter().copied().collect();
275    let b_set: BTreeSet<char> = b.indices.iter().copied().collect();
276
277    // Indices needed by remaining operands or final output
278    let mut needed: BTreeSet<char> = final_output.iter().copied().collect();
279    for op in remaining {
280        for &c in &op.indices {
281            needed.insert(c);
282        }
283    }
284
285    // Contract indices that appear in both a and b but are NOT needed elsewhere
286    let contracted: BTreeSet<char> = a_set
287        .intersection(&b_set)
288        .filter(|c| !needed.contains(c))
289        .copied()
290        .collect();
291
292    let mut result_indices: Vec<char> = Vec::new();
293    for &c in &a.indices {
294        if !contracted.contains(&c) && !result_indices.contains(&c) {
295            result_indices.push(c);
296        }
297    }
298    for &c in &b.indices {
299        if !contracted.contains(&c) && !result_indices.contains(&c) {
300            result_indices.push(c);
301        }
302    }
303
304    let result_shape: Vec<usize> = result_indices.iter().map(|c| index_sizes[c]).collect();
305    (result_indices, result_shape)
306}
307
308// ======================================================================
309// Greedy path
310// ======================================================================
311
312#[allow(clippy::unnecessary_wraps)]
313fn greedy_path(
314    descs: &[OperandDesc],
315    output_sub: &[char],
316    index_sizes: &BTreeMap<char, usize>,
317) -> Result<PathInfo> {
318    let n = descs.len();
319    if n <= 1 {
320        return Ok(PathInfo {
321            path: vec![],
322            flops: 0,
323            largest_intermediate: 0,
324        });
325    }
326
327    let mut current: Vec<OperandDesc> = descs.to_vec();
328    let mut path = Vec::with_capacity(n - 1);
329    let mut total_flops = 0usize;
330    let mut largest_intermediate = 0usize;
331
332    while current.len() > 1 {
333        let mut best_cost = usize::MAX;
334        let mut best_i = 0;
335        let mut best_j = 1;
336
337        // Find the cheapest pair to contract
338        for i in 0..current.len() {
339            for j in (i + 1)..current.len() {
340                // Build remaining list (excluding i and j)
341                let (cost, _, _) =
342                    contraction_cost(&current[i], &current[j], output_sub, index_sizes);
343
344                // Tie-break: prefer contractions that reduce more indices
345                if cost < best_cost {
346                    best_cost = cost;
347                    best_i = i;
348                    best_j = j;
349                }
350            }
351        }
352
353        // Build remaining for the chosen pair
354        let remaining: Vec<OperandDesc> = current
355            .iter()
356            .enumerate()
357            .filter(|&(k, _)| k != best_i && k != best_j)
358            .map(|(_, d)| d.clone())
359            .collect();
360
361        let (result_indices, result_shape) = pairwise_result_indices(
362            &current[best_i],
363            &current[best_j],
364            &remaining,
365            output_sub,
366            index_sizes,
367        );
368
369        let result_size: usize = result_shape.iter().product::<usize>().max(1);
370        largest_intermediate = largest_intermediate.max(result_size);
371        total_flops += best_cost;
372
373        path.push(ContractionPair {
374            first: best_i,
375            second: best_j,
376        });
377
378        // Remove j first (larger index), then i
379        current.remove(best_j);
380        current.remove(best_i);
381
382        // Append result
383        current.push(OperandDesc {
384            indices: result_indices,
385        });
386    }
387
388    Ok(PathInfo {
389        path,
390        flops: total_flops,
391        largest_intermediate,
392    })
393}
394
395// ======================================================================
396// Optimal path (exhaustive)
397// ======================================================================
398
399fn optimal_path(
400    descs: &[OperandDesc],
401    output_sub: &[char],
402    index_sizes: &BTreeMap<char, usize>,
403) -> Result<PathInfo> {
404    let n = descs.len();
405    if n <= 1 {
406        return Ok(PathInfo {
407            path: vec![],
408            flops: 0,
409            largest_intermediate: 0,
410        });
411    }
412    if n > 8 {
413        // Fall back to greedy for large inputs to avoid combinatorial explosion
414        return greedy_path(descs, output_sub, index_sizes);
415    }
416
417    let mut best_path: Vec<ContractionPair> = Vec::new();
418    let mut best_flops = usize::MAX;
419    let mut best_largest = 0usize;
420
421    find_optimal(
422        descs,
423        output_sub,
424        index_sizes,
425        &mut vec![],
426        0,
427        0,
428        &mut best_path,
429        &mut best_flops,
430        &mut best_largest,
431    );
432
433    Ok(PathInfo {
434        path: best_path,
435        flops: best_flops,
436        largest_intermediate: best_largest,
437    })
438}
439
440#[allow(clippy::too_many_arguments)]
441fn find_optimal(
442    current: &[OperandDesc],
443    output_sub: &[char],
444    index_sizes: &BTreeMap<char, usize>,
445    current_path: &mut Vec<ContractionPair>,
446    current_flops: usize,
447    current_largest: usize,
448    best_path: &mut Vec<ContractionPair>,
449    best_flops: &mut usize,
450    best_largest: &mut usize,
451) {
452    if current.len() <= 1 {
453        if current_flops < *best_flops {
454            *best_flops = current_flops;
455            *best_path = current_path.clone();
456            *best_largest = current_largest;
457        }
458        return;
459    }
460
461    // Prune: if we've already exceeded the best, stop
462    if current_flops >= *best_flops {
463        return;
464    }
465
466    for i in 0..current.len() {
467        for j in (i + 1)..current.len() {
468            let remaining: Vec<OperandDesc> = current
469                .iter()
470                .enumerate()
471                .filter(|&(k, _)| k != i && k != j)
472                .map(|(_, d)| d.clone())
473                .collect();
474
475            let (cost, _, _) = contraction_cost(&current[i], &current[j], output_sub, index_sizes);
476
477            let (result_indices, result_shape) = pairwise_result_indices(
478                &current[i],
479                &current[j],
480                &remaining,
481                output_sub,
482                index_sizes,
483            );
484
485            let result_size: usize = result_shape.iter().product::<usize>().max(1);
486
487            let mut next = remaining;
488            next.push(OperandDesc {
489                indices: result_indices,
490            });
491
492            current_path.push(ContractionPair {
493                first: i,
494                second: j,
495            });
496
497            find_optimal(
498                &next,
499                output_sub,
500                index_sizes,
501                current_path,
502                current_flops + cost,
503                current_largest.max(result_size),
504                best_path,
505                best_flops,
506                best_largest,
507            );
508
509            current_path.pop();
510        }
511    }
512}
513
514// ======================================================================
515// Path execution
516// ======================================================================
517
518/// Execute an einsum following a precomputed contraction path.
519fn execute_path<T: Scalar>(
520    _subscripts: &str,
521    operands: &[&Tensor<T>],
522    path_info: &PathInfo,
523    input_subs: &[Vec<char>],
524    final_output: &[char],
525) -> Result<Tensor<T>> {
526    // We maintain a list of (indices, tensor) pairs. At each step, we contract
527    // a pair and replace them with the result.
528    let mut tensors: Vec<(Vec<char>, Tensor<T>)> = input_subs
529        .iter()
530        .zip(operands.iter())
531        .map(|(indices, &t)| (indices.clone(), t.clone()))
532        .collect();
533
534    for step in &path_info.path {
535        let j = step.second;
536        let i = step.first;
537
538        // Remove j first (larger), then i
539        let (b_indices, b_tensor) = tensors.remove(j);
540        let (a_indices, a_tensor) = tensors.remove(i);
541
542        // Determine what indices should remain after this contraction
543        let remaining_descs: Vec<OperandDesc> = tensors
544            .iter()
545            .map(|(indices, _)| OperandDesc {
546                indices: indices.clone(),
547            })
548            .collect();
549
550        let a_desc = OperandDesc {
551            indices: a_indices.clone(),
552        };
553        let b_desc = OperandDesc {
554            indices: b_indices.clone(),
555        };
556
557        // Use a dummy index_sizes from actual tensor shapes
558        let mut local_sizes: BTreeMap<char, usize> = BTreeMap::new();
559        for (c, &s) in a_indices.iter().zip(a_tensor.shape().iter()) {
560            local_sizes.insert(*c, s);
561        }
562        for (c, &s) in b_indices.iter().zip(b_tensor.shape().iter()) {
563            local_sizes.insert(*c, s);
564        }
565
566        let (result_indices, _result_shape) = pairwise_result_indices(
567            &a_desc,
568            &b_desc,
569            &remaining_descs,
570            final_output,
571            &local_sizes,
572        );
573
574        // Build subscript string for this pairwise contraction
575        let a_sub: String = a_indices.iter().collect();
576        let b_sub: String = b_indices.iter().collect();
577        let out_sub: String = result_indices.iter().collect();
578        let pair_subscripts = format!("{a_sub},{b_sub}->{out_sub}");
579
580        let result = einsum(&pair_subscripts, &[&a_tensor, &b_tensor])?;
581        tensors.push((result_indices, result));
582    }
583
584    if tensors.len() == 1 {
585        let (current_indices, tensor) = tensors.pop().unwrap();
586        // If the indices don't match the final output order, we need a final
587        // transpose/rearrangement via einsum
588        if current_indices == final_output {
589            Ok(tensor)
590        } else {
591            let cur_sub: String = current_indices.iter().collect();
592            let out_sub: String = final_output.iter().collect();
593            let reorder = format!("{cur_sub}->{out_sub}");
594            einsum(&reorder, &[&tensor])
595        }
596    } else {
597        Err(CoreError::InvalidArgument {
598            reason: "einsum path execution did not reduce to a single tensor",
599        })
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` has the shape of `expected` and matches it
    /// element-wise to within 1e-10.
    fn assert_close(expected: &Tensor<f64>, actual: &Tensor<f64>) {
        assert_eq!(expected.shape(), actual.shape());
        for (x, y) in expected.as_slice().iter().zip(actual.as_slice().iter()) {
            assert!((x - y).abs() < 1e-10, "mismatch: {x} vs {y}");
        }
    }

    fn pair(first: usize, second: usize) -> ContractionPair {
        ContractionPair { first, second }
    }

    #[test]
    fn test_einsum_path_chain_matmul() {
        // Chain of three matrix multiplies: A(2x3) @ B(3x4) @ C(4x5) -> (2x5)
        let a = Tensor::from_vec(vec![1.0_f64; 6], vec![2, 3]).unwrap();
        let b = Tensor::from_vec(vec![1.0_f64; 12], vec![3, 4]).unwrap();
        let c = Tensor::from_vec(vec![1.0_f64; 20], vec![4, 5]).unwrap();

        let info = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();
        // Step 1: A@B (2*3*4 = 24 FLOPs). The 2x4 result is appended, so step 2
        // contracts C (now at position 0) with it (2*4*5 = 40 FLOPs).
        assert_eq!(info.path, vec![pair(0, 1), pair(0, 1)]);
        assert_eq!(info.flops, 64);
        assert_eq!(info.largest_intermediate, 10);
    }

    #[test]
    fn test_einsum_path_optimal_vs_greedy() {
        let a = Tensor::from_vec(vec![1.0_f64; 6], vec![2, 3]).unwrap();
        let b = Tensor::from_vec(vec![1.0_f64; 12], vec![3, 4]).unwrap();
        let c = Tensor::from_vec(vec![1.0_f64; 20], vec![4, 5]).unwrap();

        let greedy = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();
        let optimal = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Optimal).unwrap();

        // Optimal never does worse than greedy; here A@B first is already optimal.
        assert!(optimal.flops <= greedy.flops);
        assert_eq!(optimal.flops, 64);
        assert_eq!(optimal.path.len(), 2);
    }

    #[test]
    fn test_einsum_optimized_chain_matmul() {
        // Verify correctness: chain matmul matches direct einsum
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let b = Tensor::from_vec(
            vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
            vec![3, 4],
        )
        .unwrap();
        let c = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 1.0, 0.0,
            ],
            vec![4, 5],
        )
        .unwrap();

        let direct = einsum("ij,jk,kl->il", &[&a, &b, &c]).unwrap();
        let optimized =
            einsum_optimized("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();
        assert_close(&direct, &optimized);
    }

    #[test]
    fn test_einsum_optimized_four_operands() {
        // Four small matrices, evaluated with both strategies.
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let c = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let d = Tensor::from_vec(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]).unwrap();

        let direct = einsum("ij,jk,kl,lm->im", &[&a, &b, &c, &d]).unwrap();
        for strategy in [PathStrategy::Greedy, PathStrategy::Optimal] {
            let optimized =
                einsum_optimized("ij,jk,kl,lm->im", &[&a, &b, &c, &d], strategy).unwrap();
            assert_close(&direct, &optimized);
        }
    }

    #[test]
    fn test_einsum_path_two_operands() {
        // With two operands the path is a single step, and `einsum_optimized`
        // delegates directly to `einsum`.
        let a = Tensor::from_vec(vec![1.0_f64; 6], vec![2, 3]).unwrap();
        let b = Tensor::from_vec(vec![1.0_f64; 6], vec![3, 2]).unwrap();

        let info = einsum_path("ij,jk->ik", &[&a, &b], PathStrategy::Greedy).unwrap();
        assert_eq!(info.path, vec![pair(0, 1)]);
        assert_eq!(info.flops, 12); // 2 * 3 * 2

        let result = einsum_optimized("ij,jk->ik", &[&a, &b], PathStrategy::Greedy).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
    }

    #[test]
    fn test_einsum_path_single_operand() {
        let a = Tensor::from_vec(vec![1.0_f64; 4], vec![2, 2]).unwrap();
        let info = einsum_path("ij->ji", &[&a], PathStrategy::Greedy).unwrap();
        assert!(info.path.is_empty());
        assert_eq!(info.flops, 0);
    }

    #[test]
    fn test_einsum_path_asymmetric_shapes() {
        // A(100x2) @ B(2x3) @ C(3x100): contracting A@B and contracting B@C both
        // cost 600 FLOPs, so greedy takes the first pair found (A@B). It then
        // pays 100*3*100 = 30_000 FLOPs to finish. Contracting B@C first
        // finishes with 100*2*100 = 20_000 instead, which the optimal search
        // finds.
        let a = Tensor::from_vec(vec![1.0_f64; 200], vec![100, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0_f64; 6], vec![2, 3]).unwrap();
        let c = Tensor::from_vec(vec![1.0_f64; 300], vec![3, 100]).unwrap();

        let greedy = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();
        assert_eq!(greedy.path, vec![pair(0, 1), pair(0, 1)]);
        assert_eq!(greedy.flops, 30_600);

        let optimal = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Optimal).unwrap();
        assert_eq!(optimal.path, vec![pair(1, 2), pair(0, 1)]);
        assert_eq!(optimal.flops, 20_600);

        let direct = einsum("ij,jk,kl->il", &[&a, &b, &c]).unwrap();
        let optimized =
            einsum_optimized("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Optimal).unwrap();
        assert_close(&direct, &optimized);
    }

    #[test]
    fn test_einsum_optimized_with_trace() {
        // "ij,jk,kk->i": out_i = sum over j, k of A_ij * B_jk * C_kk, i.e. a
        // matmul chain that uses only the diagonal of C.
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let c = Tensor::from_vec(vec![3.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();

        let direct = einsum("ij,jk,kk->i", &[&a, &b, &c]).unwrap();
        for strategy in [PathStrategy::Greedy, PathStrategy::Optimal] {
            let optimized = einsum_optimized("ij,jk,kk->i", &[&a, &b, &c], strategy).unwrap();
            assert_close(&direct, &optimized);
        }
    }

    #[test]
    fn test_einsum_optimized_index_in_single_operand() {
        // `x` occurs only in the first operand and not in the output, so it is
        // summed away; the planned evaluation must agree with the direct one.
        let a = Tensor::from_vec((1..=12u8).map(f64::from).collect(), vec![2, 3, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
        let c = Tensor::from_vec(vec![1.0, 0.0, 2.0, 1.0], vec![2, 2]).unwrap();

        let direct = einsum("ijx,jk,kl->il", &[&a, &b, &c]).unwrap();
        for strategy in [PathStrategy::Greedy, PathStrategy::Optimal] {
            let optimized = einsum_optimized("ijx,jk,kl->il", &[&a, &b, &c], strategy).unwrap();
            assert_close(&direct, &optimized);
        }
    }
}