1use crate::Scalar;
34use crate::error::{CoreError, Result};
35use crate::tensor::Tensor;
36use std::collections::{BTreeMap, BTreeSet};
37
38use super::einsum::einsum;
39
/// Selects how the order of pairwise contractions is chosen.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PathStrategy {
    /// At every step, contract the pair of operands with the lowest
    /// estimated cost.
    Greedy,
    /// Search all pairwise contraction orders for the lowest total estimated
    /// cost. With more than 8 operands the planner uses the greedy strategy
    /// instead.
    Optimal,
}
51
/// One step of a contraction path: the two operands to contract next.
///
/// Both positions index the working operand list as it stands when the step
/// runs. After a step, the two operands are removed from that list and the
/// contraction result is appended to its end.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContractionPair {
    /// Position of the first operand (the planners always emit
    /// `first < second`).
    pub first: usize,
    /// Position of the second operand.
    pub second: usize,
}
61
62#[derive(Debug, Clone)]
64pub struct PathInfo {
65 pub path: Vec<ContractionPair>,
69 pub flops: usize,
71 pub largest_intermediate: usize,
73}
74
/// Planner-side description of an operand: only its index labels, no data.
#[derive(Clone, Debug)]
struct OperandDesc {
    /// Index labels of the operand, one per dimension, in dimension order.
    indices: Vec<char>,
}
81
/// Result of parsing an einsum specification: the index labels of every
/// input operand, the index labels of the output, and the size bound to each
/// index label.
type ParsedSubscripts = (
    Vec<Vec<char>>,
    Vec<char>,
    BTreeMap<char, usize>,
);
84
85pub fn einsum_path<T: Scalar>(
90 subscripts: &str,
91 operands: &[&Tensor<T>],
92 strategy: PathStrategy,
93) -> Result<PathInfo> {
94 let (input_subs, output_sub, index_sizes) = parse_path_subscripts(subscripts, operands)?;
95
96 let descs: Vec<OperandDesc> = input_subs
97 .iter()
98 .map(|indices| OperandDesc {
99 indices: indices.clone(),
100 })
101 .collect();
102
103 match strategy {
104 PathStrategy::Greedy => greedy_path(&descs, &output_sub, &index_sizes),
105 PathStrategy::Optimal => optimal_path(&descs, &output_sub, &index_sizes),
106 }
107}
108
109pub fn einsum_optimized<T: Scalar>(
115 subscripts: &str,
116 operands: &[&Tensor<T>],
117 strategy: PathStrategy,
118) -> Result<Tensor<T>> {
119 if operands.len() <= 2 {
120 return einsum(subscripts, operands);
121 }
122
123 let (input_subs, output_sub, index_sizes) = parse_path_subscripts(subscripts, operands)?;
124
125 let descs: Vec<OperandDesc> = input_subs
126 .iter()
127 .map(|indices| OperandDesc {
128 indices: indices.clone(),
129 })
130 .collect();
131
132 let path_info = match strategy {
133 PathStrategy::Greedy => greedy_path(&descs, &output_sub, &index_sizes)?,
134 PathStrategy::Optimal => optimal_path(&descs, &output_sub, &index_sizes)?,
135 };
136
137 execute_path(subscripts, operands, &path_info, &input_subs, &output_sub)
138}
139
140fn parse_path_subscripts<T: Scalar>(
147 subscripts: &str,
148 operands: &[&Tensor<T>],
149) -> Result<ParsedSubscripts> {
150 let subscripts = subscripts.replace(' ', "");
151
152 let (inputs_str, output_sub) = if let Some((inp, out)) = subscripts.split_once("->") {
153 let output_indices: Vec<char> = out.chars().collect();
154 (inp.to_string(), output_indices)
155 } else {
156 let mut counts: BTreeMap<char, usize> = BTreeMap::new();
158 for c in subscripts.chars() {
159 if c == ',' {
160 continue;
161 }
162 *counts.entry(c).or_insert(0) += 1;
163 }
164 let output_indices: Vec<char> = counts
165 .iter()
166 .filter(|(_, count)| **count == 1)
167 .map(|(&c, _)| c)
168 .collect();
169 (subscripts.clone(), output_indices)
170 };
171
172 let input_parts: Vec<&str> = inputs_str.split(',').collect();
173 if input_parts.len() != operands.len() {
174 return Err(CoreError::InvalidArgument {
175 reason: "number of subscript groups does not match number of operands",
176 });
177 }
178
179 let mut input_subs = Vec::with_capacity(input_parts.len());
180 let mut index_sizes: BTreeMap<char, usize> = BTreeMap::new();
181
182 for (i, part) in input_parts.iter().enumerate() {
183 let indices: Vec<char> = part.chars().collect();
184 if indices.len() != operands[i].ndim() {
185 return Err(CoreError::InvalidArgument {
186 reason: "operand rank does not match number of subscript indices",
187 });
188 }
189 let shape = operands[i].shape();
190 for (d, &c) in indices.iter().enumerate() {
191 if let Some(&existing) = index_sizes.get(&c) {
192 if existing != shape[d] {
193 return Err(CoreError::DimensionMismatch {
194 expected: vec![existing],
195 got: vec![shape[d]],
196 });
197 }
198 } else {
199 index_sizes.insert(c, shape[d]);
200 }
201 }
202 input_subs.push(indices);
203 }
204
205 Ok((input_subs, output_sub, index_sizes))
206}
207
208fn contraction_cost(
214 a: &OperandDesc,
215 b: &OperandDesc,
216 output_indices: &[char],
217 index_sizes: &BTreeMap<char, usize>,
218) -> (usize, Vec<char>, Vec<usize>) {
219 let a_set: BTreeSet<char> = a.indices.iter().copied().collect();
225 let b_set: BTreeSet<char> = b.indices.iter().copied().collect();
226 let output_set: BTreeSet<char> = output_indices.iter().copied().collect();
227
228 let contracted: BTreeSet<char> = a_set
230 .intersection(&b_set)
231 .filter(|c| !output_set.contains(c))
232 .copied()
233 .collect();
234
235 let mut result_indices: Vec<char> = Vec::new();
237 for &c in &a.indices {
239 if !contracted.contains(&c) && !result_indices.contains(&c) {
240 result_indices.push(c);
241 }
242 }
243 for &c in &b.indices {
245 if !contracted.contains(&c) && !result_indices.contains(&c) {
246 result_indices.push(c);
247 }
248 }
249
250 let result_shape: Vec<usize> = result_indices.iter().map(|c| index_sizes[c]).collect();
251
252 let result_size: usize = result_shape.iter().product::<usize>().max(1);
253 let contract_size: usize = contracted
254 .iter()
255 .map(|c| index_sizes[c])
256 .product::<usize>()
257 .max(1);
258
259 let flops = result_size * contract_size;
261
262 (flops, result_indices, result_shape)
263}
264
265fn pairwise_result_indices(
268 a: &OperandDesc,
269 b: &OperandDesc,
270 remaining: &[OperandDesc],
271 final_output: &[char],
272 index_sizes: &BTreeMap<char, usize>,
273) -> (Vec<char>, Vec<usize>) {
274 let a_set: BTreeSet<char> = a.indices.iter().copied().collect();
275 let b_set: BTreeSet<char> = b.indices.iter().copied().collect();
276
277 let mut needed: BTreeSet<char> = final_output.iter().copied().collect();
279 for op in remaining {
280 for &c in &op.indices {
281 needed.insert(c);
282 }
283 }
284
285 let contracted: BTreeSet<char> = a_set
287 .intersection(&b_set)
288 .filter(|c| !needed.contains(c))
289 .copied()
290 .collect();
291
292 let mut result_indices: Vec<char> = Vec::new();
293 for &c in &a.indices {
294 if !contracted.contains(&c) && !result_indices.contains(&c) {
295 result_indices.push(c);
296 }
297 }
298 for &c in &b.indices {
299 if !contracted.contains(&c) && !result_indices.contains(&c) {
300 result_indices.push(c);
301 }
302 }
303
304 let result_shape: Vec<usize> = result_indices.iter().map(|c| index_sizes[c]).collect();
305 (result_indices, result_shape)
306}
307
308#[allow(clippy::unnecessary_wraps)]
313fn greedy_path(
314 descs: &[OperandDesc],
315 output_sub: &[char],
316 index_sizes: &BTreeMap<char, usize>,
317) -> Result<PathInfo> {
318 let n = descs.len();
319 if n <= 1 {
320 return Ok(PathInfo {
321 path: vec![],
322 flops: 0,
323 largest_intermediate: 0,
324 });
325 }
326
327 let mut current: Vec<OperandDesc> = descs.to_vec();
328 let mut path = Vec::with_capacity(n - 1);
329 let mut total_flops = 0usize;
330 let mut largest_intermediate = 0usize;
331
332 while current.len() > 1 {
333 let mut best_cost = usize::MAX;
334 let mut best_i = 0;
335 let mut best_j = 1;
336
337 for i in 0..current.len() {
339 for j in (i + 1)..current.len() {
340 let (cost, _, _) =
342 contraction_cost(¤t[i], ¤t[j], output_sub, index_sizes);
343
344 if cost < best_cost {
346 best_cost = cost;
347 best_i = i;
348 best_j = j;
349 }
350 }
351 }
352
353 let remaining: Vec<OperandDesc> = current
355 .iter()
356 .enumerate()
357 .filter(|&(k, _)| k != best_i && k != best_j)
358 .map(|(_, d)| d.clone())
359 .collect();
360
361 let (result_indices, result_shape) = pairwise_result_indices(
362 ¤t[best_i],
363 ¤t[best_j],
364 &remaining,
365 output_sub,
366 index_sizes,
367 );
368
369 let result_size: usize = result_shape.iter().product::<usize>().max(1);
370 largest_intermediate = largest_intermediate.max(result_size);
371 total_flops += best_cost;
372
373 path.push(ContractionPair {
374 first: best_i,
375 second: best_j,
376 });
377
378 current.remove(best_j);
380 current.remove(best_i);
381
382 current.push(OperandDesc {
384 indices: result_indices,
385 });
386 }
387
388 Ok(PathInfo {
389 path,
390 flops: total_flops,
391 largest_intermediate,
392 })
393}
394
395fn optimal_path(
400 descs: &[OperandDesc],
401 output_sub: &[char],
402 index_sizes: &BTreeMap<char, usize>,
403) -> Result<PathInfo> {
404 let n = descs.len();
405 if n <= 1 {
406 return Ok(PathInfo {
407 path: vec![],
408 flops: 0,
409 largest_intermediate: 0,
410 });
411 }
412 if n > 8 {
413 return greedy_path(descs, output_sub, index_sizes);
415 }
416
417 let mut best_path: Vec<ContractionPair> = Vec::new();
418 let mut best_flops = usize::MAX;
419 let mut best_largest = 0usize;
420
421 find_optimal(
422 descs,
423 output_sub,
424 index_sizes,
425 &mut vec![],
426 0,
427 0,
428 &mut best_path,
429 &mut best_flops,
430 &mut best_largest,
431 );
432
433 Ok(PathInfo {
434 path: best_path,
435 flops: best_flops,
436 largest_intermediate: best_largest,
437 })
438}
439
440#[allow(clippy::too_many_arguments)]
441fn find_optimal(
442 current: &[OperandDesc],
443 output_sub: &[char],
444 index_sizes: &BTreeMap<char, usize>,
445 current_path: &mut Vec<ContractionPair>,
446 current_flops: usize,
447 current_largest: usize,
448 best_path: &mut Vec<ContractionPair>,
449 best_flops: &mut usize,
450 best_largest: &mut usize,
451) {
452 if current.len() <= 1 {
453 if current_flops < *best_flops {
454 *best_flops = current_flops;
455 *best_path = current_path.clone();
456 *best_largest = current_largest;
457 }
458 return;
459 }
460
461 if current_flops >= *best_flops {
463 return;
464 }
465
466 for i in 0..current.len() {
467 for j in (i + 1)..current.len() {
468 let remaining: Vec<OperandDesc> = current
469 .iter()
470 .enumerate()
471 .filter(|&(k, _)| k != i && k != j)
472 .map(|(_, d)| d.clone())
473 .collect();
474
475 let (cost, _, _) = contraction_cost(¤t[i], ¤t[j], output_sub, index_sizes);
476
477 let (result_indices, result_shape) = pairwise_result_indices(
478 ¤t[i],
479 ¤t[j],
480 &remaining,
481 output_sub,
482 index_sizes,
483 );
484
485 let result_size: usize = result_shape.iter().product::<usize>().max(1);
486
487 let mut next = remaining;
488 next.push(OperandDesc {
489 indices: result_indices,
490 });
491
492 current_path.push(ContractionPair {
493 first: i,
494 second: j,
495 });
496
497 find_optimal(
498 &next,
499 output_sub,
500 index_sizes,
501 current_path,
502 current_flops + cost,
503 current_largest.max(result_size),
504 best_path,
505 best_flops,
506 best_largest,
507 );
508
509 current_path.pop();
510 }
511 }
512}
513
514fn execute_path<T: Scalar>(
520 _subscripts: &str,
521 operands: &[&Tensor<T>],
522 path_info: &PathInfo,
523 input_subs: &[Vec<char>],
524 final_output: &[char],
525) -> Result<Tensor<T>> {
526 let mut tensors: Vec<(Vec<char>, Tensor<T>)> = input_subs
529 .iter()
530 .zip(operands.iter())
531 .map(|(indices, &t)| (indices.clone(), t.clone()))
532 .collect();
533
534 for step in &path_info.path {
535 let j = step.second;
536 let i = step.first;
537
538 let (b_indices, b_tensor) = tensors.remove(j);
540 let (a_indices, a_tensor) = tensors.remove(i);
541
542 let remaining_descs: Vec<OperandDesc> = tensors
544 .iter()
545 .map(|(indices, _)| OperandDesc {
546 indices: indices.clone(),
547 })
548 .collect();
549
550 let a_desc = OperandDesc {
551 indices: a_indices.clone(),
552 };
553 let b_desc = OperandDesc {
554 indices: b_indices.clone(),
555 };
556
557 let mut local_sizes: BTreeMap<char, usize> = BTreeMap::new();
559 for (c, &s) in a_indices.iter().zip(a_tensor.shape().iter()) {
560 local_sizes.insert(*c, s);
561 }
562 for (c, &s) in b_indices.iter().zip(b_tensor.shape().iter()) {
563 local_sizes.insert(*c, s);
564 }
565
566 let (result_indices, _result_shape) = pairwise_result_indices(
567 &a_desc,
568 &b_desc,
569 &remaining_descs,
570 final_output,
571 &local_sizes,
572 );
573
574 let a_sub: String = a_indices.iter().collect();
576 let b_sub: String = b_indices.iter().collect();
577 let out_sub: String = result_indices.iter().collect();
578 let pair_subscripts = format!("{a_sub},{b_sub}->{out_sub}");
579
580 let result = einsum(&pair_subscripts, &[&a_tensor, &b_tensor])?;
581 tensors.push((result_indices, result));
582 }
583
584 if tensors.len() == 1 {
585 let (current_indices, tensor) = tensors.pop().unwrap();
586 if current_indices == final_output {
589 Ok(tensor)
590 } else {
591 let cur_sub: String = current_indices.iter().collect();
592 let out_sub: String = final_output.iter().collect();
593 let reorder = format!("{cur_sub}->{out_sub}");
594 einsum(&reorder, &[&tensor])
595 }
596 } else {
597 Err(CoreError::InvalidArgument {
598 reason: "einsum path execution did not reduce to a single tensor",
599 })
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `f64` tensor with the given data and shape.
    fn tensor(data: Vec<f64>, shape: &[usize]) -> Tensor<f64> {
        Tensor::from_vec(data, shape.to_vec()).unwrap()
    }

    /// Builds an `f64` tensor of the given shape with every element 1.0.
    fn ones(shape: &[usize]) -> Tensor<f64> {
        tensor(vec![1.0; shape.iter().product()], shape)
    }

    /// Asserts equal shapes and element-wise agreement within 1e-10.
    fn assert_matches(direct: &Tensor<f64>, optimized: &Tensor<f64>) {
        assert_eq!(direct.shape(), optimized.shape());
        let pairs = direct.as_slice().iter().zip(optimized.as_slice().iter());
        for (x, y) in pairs {
            assert!((x - y).abs() < 1e-10, "mismatch: {x} vs {y}");
        }
    }

    #[test]
    fn test_einsum_path_chain_matmul() {
        let (a, b, c) = (ones(&[2, 3]), ones(&[3, 4]), ones(&[4, 5]));

        let info = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();

        assert_eq!(info.path.len(), 2);
        assert!(info.flops > 0);
    }

    #[test]
    fn test_einsum_path_optimal_vs_greedy() {
        let (a, b, c) = (ones(&[2, 3]), ones(&[3, 4]), ones(&[4, 5]));
        let operands = [&a, &b, &c];

        let greedy = einsum_path("ij,jk,kl->il", &operands, PathStrategy::Greedy).unwrap();
        let optimal = einsum_path("ij,jk,kl->il", &operands, PathStrategy::Optimal).unwrap();

        // The exhaustive search can never do worse than the greedy one.
        assert!(optimal.flops <= greedy.flops);
        assert_eq!(optimal.path.len(), 2);
    }

    #[test]
    fn test_einsum_optimized_chain_matmul() {
        let a = tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = tensor(
            vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
            &[3, 4],
        );
        let c = tensor(
            vec![
                1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 1.0, 0.0,
            ],
            &[4, 5],
        );

        let direct = einsum("ij,jk,kl->il", &[&a, &b, &c]).unwrap();
        let optimized =
            einsum_optimized("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();

        assert_matches(&direct, &optimized);
    }

    #[test]
    fn test_einsum_optimized_four_operands() {
        let a = tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = tensor(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let c = tensor(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]);
        let d = tensor(vec![2.0, 1.0, 1.0, 2.0], &[2, 2]);

        let direct = einsum("ij,jk,kl,lm->im", &[&a, &b, &c, &d]).unwrap();
        let optimized =
            einsum_optimized("ij,jk,kl,lm->im", &[&a, &b, &c, &d], PathStrategy::Optimal).unwrap();

        assert_matches(&direct, &optimized);
    }

    #[test]
    fn test_einsum_path_two_operands() {
        let a = ones(&[2, 3]);
        let b = ones(&[3, 2]);

        let result = einsum_optimized("ij,jk->ik", &[&a, &b], PathStrategy::Greedy).unwrap();

        assert_eq!(result.shape(), &[2, 2]);
    }

    #[test]
    fn test_einsum_path_single_operand() {
        let a = ones(&[2, 2]);

        let info = einsum_path("ij->ji", &[&a], PathStrategy::Greedy).unwrap();

        assert!(info.path.is_empty());
    }

    #[test]
    fn test_einsum_path_asymmetric_shapes() {
        let a = ones(&[100, 2]);
        let b = ones(&[2, 3]);
        let c = ones(&[3, 100]);

        let info = einsum_path("ij,jk,kl->il", &[&a, &b, &c], PathStrategy::Greedy).unwrap();

        assert_eq!(info.path.len(), 2);
        assert!(info.flops > 0);
    }

    #[test]
    fn test_einsum_optimized_with_trace() {
        let a = tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = tensor(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]);
        let c = tensor(vec![3.0, 1.0, 1.0, 3.0], &[2, 2]);

        let direct = einsum("ij,jk,kk->i", &[&a, &b, &c]).unwrap();
        let optimized =
            einsum_optimized("ij,jk,kk->i", &[&a, &b, &c], PathStrategy::Greedy).unwrap();

        assert_eq!(direct.shape(), optimized.shape());
        for (x, y) in direct.as_slice().iter().zip(optimized.as_slice().iter()) {
            assert!((x - y).abs() < 1e-10);
        }
    }
}