1use torsh_core::{Result as TorshResult, TorshError};
13use torsh_tensor::Tensor;
14
15pub fn einsum_optimized(equation: &str, operands: &[&Tensor]) -> TorshResult<Tensor> {
58 if operands.is_empty() {
59 return Err(TorshError::invalid_argument_with_context(
60 "einsum requires at least one operand",
61 "einsum_optimized",
62 ));
63 }
64
65 let (inputs, output) = parse_einsum_equation(equation)?;
67
68 if inputs.len() != operands.len() {
70 return Err(TorshError::invalid_argument_with_context(
71 &format!(
72 "einsum equation expects {} operands, got {}",
73 inputs.len(),
74 operands.len()
75 ),
76 "einsum_optimized",
77 ));
78 }
79
80 let optimal_path = optimize_contraction_path(&inputs, &output)?;
82
83 execute_contraction_path(operands, &optimal_path, &output)
85}
86
87fn parse_einsum_equation(equation: &str) -> TorshResult<(Vec<String>, String)> {
89 let parts: Vec<&str> = equation.split("->").collect();
90
91 if parts.len() > 2 {
92 return Err(TorshError::invalid_argument_with_context(
93 "einsum equation can have at most one '->' separator",
94 "parse_einsum_equation",
95 ));
96 }
97
98 let input_str = parts[0];
99 let inputs: Vec<String> = input_str.split(',').map(|s| s.trim().to_string()).collect();
100
101 let output = if parts.len() == 2 {
102 parts[1].trim().to_string()
103 } else {
104 infer_output_indices(&inputs)
106 };
107
108 Ok((inputs, output))
109}
110
/// Infer the implicit einsum output: every alphabetic index that occurs
/// exactly once across all inputs, in sorted (ASCII) order.
fn infer_output_indices(inputs: &[String]) -> String {
    use std::collections::HashMap;

    // Tally occurrences of each index letter across every input term.
    let mut counts: HashMap<char, usize> = HashMap::new();
    for term in inputs {
        for ch in term.chars().filter(|c| c.is_alphabetic()) {
            *counts.entry(ch).or_default() += 1;
        }
    }

    // Free (non-contracted) indices appear exactly once; emit them sorted.
    let mut free: Vec<char> = counts
        .into_iter()
        .filter_map(|(ch, n)| (n == 1).then_some(ch))
        .collect();
    free.sort_unstable();
    free.into_iter().collect()
}
134
135fn optimize_contraction_path(
137 inputs: &[String],
138 _output: &str,
139) -> TorshResult<Vec<ContractionStep>> {
140 let mut steps = Vec::new();
143 let mut remaining = inputs.to_vec();
144
145 while remaining.len() > 1 {
146 let (idx1, idx2) = find_best_contraction_pair(&remaining)?;
148
149 let indices1 = &remaining[idx1];
150 let indices2 = &remaining[idx2];
151
152 let result_indices = compute_contraction_result(indices1, indices2);
154
155 steps.push(ContractionStep {
156 _operand1: idx1,
157 _operand2: idx2,
158 _result_indices: result_indices.clone(),
159 });
160
161 remaining.remove(idx2.max(idx1));
163 remaining.remove(idx1.min(idx2));
164 remaining.push(result_indices);
165 }
166
167 Ok(steps)
168}
169
/// One pairwise step in a planned einsum contraction: contract the operands
/// at positions `_operand1` and `_operand2` of the working operand list
/// (which shrinks as steps execute), producing an intermediate labelled by
/// `_result_indices`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ContractionStep {
    // Position of the first operand in the working list at this step.
    _operand1: usize,
    // Position of the second operand in the working list at this step.
    _operand2: usize,
    // Subscript labels of the intermediate tensor this step produces.
    _result_indices: String,
}
177
178fn find_best_contraction_pair(remaining: &[String]) -> TorshResult<(usize, usize)> {
179 if remaining.len() < 2 {
180 return Err(TorshError::invalid_argument_with_context(
181 "need at least 2 tensors to find contraction pair",
182 "find_best_contraction_pair",
183 ));
184 }
185
186 Ok((0, 1))
188}
189
/// Subscript labels of the intermediate produced by contracting two terms:
/// indices shared by both operands are summed away; the survivors keep their
/// first-appearance order, deduplicated.
fn compute_contraction_result(indices1: &str, indices2: &str) -> String {
    use std::collections::HashSet;

    // Indices present in both operands are contracted and dropped.
    let left: HashSet<char> = indices1.chars().collect();
    let shared: HashSet<char> = indices2.chars().filter(|c| left.contains(c)).collect();

    // Emit surviving indices in order of first appearance, without repeats.
    let mut emitted = HashSet::new();
    indices1
        .chars()
        .chain(indices2.chars())
        .filter(|ch| !shared.contains(ch) && emitted.insert(*ch))
        .collect()
}
211
/// Execute a planned contraction over `operands`.
///
/// NOTE(review): the two-operand fast path hardcodes the equation
/// "ij,jk->ik" and ignores both `_path` and `_output`, so it is only correct
/// for plain matrix multiplication of 2-D operands — any other two-operand
/// equation is silently mis-evaluated here. The real equation is not
/// reachable from this signature; callers should route two-operand cases to
/// `crate::math::einsum` with the original equation instead. Confirm before
/// relying on this path.
fn execute_contraction_path(
    operands: &[&Tensor],
    _path: &[ContractionStep],
    _output: &str,
) -> TorshResult<Tensor> {
    if operands.len() == 2 {
        // math::einsum takes owned tensors, so clone the borrowed operands.
        let operand_vec: Vec<Tensor> = operands.iter().map(|&t| t.clone()).collect();
        return crate::math::einsum("ij,jk->ik", &operand_vec);
    }

    // General N-ary path execution is not implemented yet.
    Err(TorshError::InvalidOperation(
        "general einsum contraction path execution not yet implemented (execute_contraction_path)"
            .to_string(),
    ))
}
229
230pub fn tensor_contract(
266 a: &Tensor,
267 b: &Tensor,
268 axes_a: &[usize],
269 axes_b: &[usize],
270) -> TorshResult<Tensor> {
271 if axes_a.len() != axes_b.len() {
272 return Err(TorshError::invalid_argument_with_context(
273 "number of contraction axes must match",
274 "tensor_contract",
275 ));
276 }
277
278 let a_shape_obj = a.shape();
280 let shape_a = a_shape_obj.dims();
281 let b_shape_obj = b.shape();
282 let shape_b = b_shape_obj.dims();
283
284 for &axis in axes_a {
285 if axis >= shape_a.len() {
286 return Err(TorshError::invalid_argument_with_context(
287 &format!(
288 "axis {} out of range for tensor with {} dimensions",
289 axis,
290 shape_a.len()
291 ),
292 "tensor_contract",
293 ));
294 }
295 }
296
297 for &axis in axes_b {
298 if axis >= shape_b.len() {
299 return Err(TorshError::invalid_argument_with_context(
300 &format!(
301 "axis {} out of range for tensor with {} dimensions",
302 axis,
303 shape_b.len()
304 ),
305 "tensor_contract",
306 ));
307 }
308 }
309
310 for (&axis_a, &axis_b) in axes_a.iter().zip(axes_b.iter()) {
312 if shape_a[axis_a] != shape_b[axis_b] {
313 return Err(TorshError::invalid_argument_with_context(
314 &format!(
315 "contracted dimensions must match: {} != {}",
316 shape_a[axis_a], shape_b[axis_b]
317 ),
318 "tensor_contract",
319 ));
320 }
321 }
322
323 crate::manipulation::tensordot(
325 a,
326 b,
327 crate::manipulation::TensorDotAxes::Arrays(axes_a.to_vec(), axes_b.to_vec()),
328 )
329}
330
331pub fn tensor_map<F>(input: &Tensor<f32>, f: F) -> TorshResult<Tensor<f32>>
355where
356 F: Fn(f32) -> f32 + Send + Sync,
357{
358 let data = input.data()?;
359 let shape = input.shape().dims().to_vec();
360 let device = input.device();
361
362 let result_data: Vec<f32> = if data.len() > 10000 {
364 use scirs2_core::parallel_ops::*;
365 data.iter()
366 .copied()
367 .collect::<Vec<_>>()
368 .into_par_iter()
369 .map(f)
370 .collect()
371 } else {
372 data.iter().map(|&x| f(x)).collect()
373 };
374
375 Tensor::from_data(result_data, shape, device)
376}
377
378pub fn tensor_reduce<F>(
406 input: &Tensor<f32>,
407 axis: Option<usize>,
408 f: F,
409 init: f32,
410) -> TorshResult<Tensor<f32>>
411where
412 F: Fn(f32, f32) -> f32 + Send + Sync,
413{
414 let input_shape = input.shape();
415 let shape = input_shape.dims();
416
417 if let Some(ax) = axis {
418 if ax >= shape.len() {
419 return Err(TorshError::invalid_argument_with_context(
420 &format!(
421 "axis {} out of range for tensor with {} dimensions",
422 ax,
423 shape.len()
424 ),
425 "tensor_reduce",
426 ));
427 }
428
429 let data = input.data()?;
431 let mut output_shape = shape.to_vec();
432 output_shape.remove(ax);
433
434 if output_shape.is_empty() {
435 let result = data.iter().fold(init, |acc, &x| f(acc, x));
437 return Tensor::from_data(vec![result], vec![1], input.device());
438 }
439
440 let mut strides = vec![1; shape.len()];
442 for i in (0..shape.len() - 1).rev() {
443 strides[i] = strides[i + 1] * shape[i + 1];
444 }
445
446 let output_size: usize = output_shape.iter().product();
447 let axis_size = shape[ax];
448 let mut result_data = vec![init; output_size];
449
450 for (out_idx, result_val) in result_data.iter_mut().enumerate() {
452 for axis_idx in 0..axis_size {
453 let mut in_idx = 0;
455 let mut remaining = out_idx;
456 let mut out_dim_idx = 0;
457
458 for dim_idx in 0..shape.len() {
459 if dim_idx == ax {
460 in_idx += axis_idx * strides[dim_idx];
461 } else {
462 let size = output_shape[out_dim_idx];
463 let coord = remaining % size;
464 remaining /= size;
465 in_idx += coord * strides[dim_idx];
466 out_dim_idx += 1;
467 }
468 }
469
470 if in_idx < data.len() {
471 *result_val = f(*result_val, data[in_idx]);
472 }
473 }
474 }
475
476 Tensor::from_data(result_data, output_shape, input.device())
477 } else {
478 let data = input.data()?;
480 let result = data.iter().fold(init, |acc, &x| f(acc, x));
481 Tensor::from_data(vec![result], vec![1], input.device())
482 }
483}
484
/// Inclusive scan (prefix accumulation) of `f` along `axis`, seeded with
/// `init`; every other axis is carried through unchanged. The result has the
/// same shape and device as the input.
///
/// # Errors
/// Returns an error when `axis` is out of range for the input's rank, or
/// when reading the data / constructing the result tensor fails.
pub fn tensor_scan<F>(input: &Tensor<f32>, axis: usize, f: F, init: f32) -> TorshResult<Tensor<f32>>
where
    F: Fn(f32, f32) -> f32,
{
    let input_shape = input.shape();
    let shape = input_shape.dims();

    if axis >= shape.len() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "axis {} out of range for tensor with {} dimensions",
                axis,
                shape.len()
            ),
            "tensor_scan",
        ));
    }

    let data = input.data()?;
    // The scan is computed in place over a copy of the input buffer.
    let mut result_data = data.to_vec();

    // Row-major strides of the input tensor.
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let axis_size = shape[axis];
    let axis_stride = strides[axis];

    // Number of independent 1-D "lanes" along the scanned axis.
    let other_size: usize = shape
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &s)| s)
        .product();

    for other_idx in 0..other_size {
        // Decompose the lane index into non-axis coordinates. Note: this
        // mixed-radix decomposition runs first-dim-fastest, so lanes are not
        // visited in row-major order — harmless here, since each lane writes
        // to absolute indices and every lane is visited exactly once.
        let mut base_idx = 0;
        let mut remaining = other_idx;

        for (dim_idx, &size) in shape.iter().enumerate() {
            if dim_idx != axis {
                let coord = remaining % size;
                remaining /= size;
                base_idx += coord * strides[dim_idx];
            }
        }

        // Walk the lane, replacing each element with the running accumulation.
        let mut acc = init;
        for axis_idx in 0..axis_size {
            let idx = base_idx + axis_idx * axis_stride;
            if idx < result_data.len() {
                acc = f(acc, result_data[idx]);
                result_data[idx] = acc;
            }
        }
    }

    Tensor::from_data(result_data, shape.to_vec(), input.device())
}
576
577pub fn tensor_fold<F>(input: &Tensor<f32>, f: F, init: f32) -> TorshResult<f32>
603where
604 F: Fn(f32, f32) -> f32,
605{
606 let data = input.data()?;
607 Ok(data.iter().fold(init, |acc, &x| f(acc, x)))
608}
609
610pub fn tensor_outer(a: &Tensor<f32>, b: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
644 let a_shape_obj = a.shape();
645 let shape_a = a_shape_obj.dims();
646 let b_shape_obj = b.shape();
647 let shape_b = b_shape_obj.dims();
648
649 let mut new_shape_a = shape_a.to_vec();
651 new_shape_a.extend(vec![1; shape_b.len()]);
652
653 let mut new_shape_b = vec![1; shape_a.len()];
654 new_shape_b.extend(shape_b);
655
656 let a_reshaped = a.view(&new_shape_a.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
657 let b_reshaped = b.view(&new_shape_b.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
658
659 a_reshaped.mul(&b_reshaped)
661}
662
663pub fn tensor_zip<F>(a: &Tensor<f32>, b: &Tensor<f32>, f: F) -> TorshResult<Tensor<f32>>
689where
690 F: Fn(f32, f32) -> f32 + Send + Sync,
691{
692 if a.shape().dims() != b.shape().dims() {
694 return Err(TorshError::invalid_argument_with_context(
695 &format!(
696 "tensor shapes must match for zip: {:?} vs {:?}",
697 a.shape().dims(),
698 b.shape().dims()
699 ),
700 "tensor_zip",
701 ));
702 }
703
704 let data_a = a.data()?;
705 let data_b = b.data()?;
706 let shape = a.shape().dims().to_vec();
707 let device = a.device();
708
709 let result_data: Vec<f32> = if data_a.len() > 10000 {
711 use scirs2_core::parallel_ops::*;
712 let pairs: Vec<(f32, f32)> = data_a.iter().copied().zip(data_b.iter().copied()).collect();
713 pairs.into_par_iter().map(|(x, y)| f(x, y)).collect()
714 } else {
715 data_a
716 .iter()
717 .zip(data_b.iter())
718 .map(|(&x, &y)| f(x, y))
719 .collect()
720 };
721
722 Tensor::from_data(result_data, shape, device)
723}
724
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Test-only helper: build a CPU-resident f32 tensor, panicking on failure.
    fn cpu_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        Tensor::from_data(data, shape, torsh_core::device::DeviceType::Cpu)
            .expect("failed to create tensor")
    }

    #[test]
    fn test_tensor_map() {
        let input = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let doubled = tensor_map(&input, |x| x * 2.0).expect("map failed");
        let got = doubled.data().expect("failed to get data");

        for (actual, expected) in got.iter().zip([2.0, 4.0, 6.0, 8.0]) {
            assert_relative_eq!(*actual, expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_tensor_reduce() {
        let input = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);

        let summed = tensor_reduce(&input, None, |a, b| a + b, 0.0).expect("reduce failed");
        let got = summed.data().expect("failed to get data");

        assert_relative_eq!(got[0], 10.0, epsilon = 1e-6);
    }

    #[test]
    fn test_tensor_fold() {
        let input = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);

        let total = tensor_fold(&input, |acc, x| acc + x, 0.0).expect("fold failed");
        assert_relative_eq!(total, 10.0, epsilon = 1e-6);
    }

    #[test]
    fn test_tensor_scan() {
        let input = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);

        let prefix = tensor_scan(&input, 0, |a, b| a + b, 0.0).expect("scan failed");
        let got = prefix.data().expect("failed to get data");

        // Running sums: 1, 1+2, 1+2+3, 1+2+3+4.
        for (actual, expected) in got.iter().zip([1.0, 3.0, 6.0, 10.0]) {
            assert_relative_eq!(*actual, expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_tensor_outer() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = cpu_tensor(vec![4.0, 5.0], vec![2]);

        let c = tensor_outer(&a, &b).expect("outer product failed");
        assert_eq!(c.shape().dims(), &[3, 2]);

        let got = c.data().expect("failed to get data");
        for (actual, expected) in got.iter().zip([4.0, 5.0, 8.0, 10.0]) {
            assert_relative_eq!(*actual, expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_tensor_zip() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let b = cpu_tensor(vec![5.0, 6.0, 7.0, 8.0], vec![4]);

        let c = tensor_zip(&a, &b, |x, y| x + y).expect("zip failed");
        let got = c.data().expect("failed to get data");

        for (actual, expected) in got.iter().zip([6.0, 8.0, 10.0, 12.0]) {
            assert_relative_eq!(*actual, expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_parse_einsum_equation() {
        // Explicit output form.
        let (inputs, output) = parse_einsum_equation("ij,jk->ik").expect("parse failed");
        assert_eq!(inputs, vec!["ij", "jk"]);
        assert_eq!(output, "ik");

        // Explicit but empty output (full contraction / trace).
        let (inputs, output) = parse_einsum_equation("ii->").expect("parse failed");
        assert_eq!(inputs, vec!["ii"]);
        assert_eq!(output, "");
    }

    #[test]
    fn test_tensor_reduce_axis() {
        let input = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let summed = tensor_reduce(&input, Some(0), |a, b| a + b, 0.0).expect("reduce failed");
        assert_eq!(summed.shape().dims(), &[3]);

        // Column sums of [[1,2,3],[4,5,6]].
        let got = summed.data().expect("failed to get data");
        for (actual, expected) in got.iter().zip([5.0, 7.0, 9.0]) {
            assert_relative_eq!(*actual, expected, epsilon = 1e-6);
        }
    }
}