// torsh_functional/transformations.rs
//! Advanced Functional Transformations with SciRS2
//!
//! This module provides advanced tensor transformation operations including:
//! - Einstein summation (einsum) with automatic optimization
//! - Tensor contractions and decompositions
//! - Graph transformations for computational graphs
//! - Functional programming patterns (map, reduce, scan, fold)
//! - Performance-critical operations using scirs2-core
//!
//! All implementations follow SciRS2 POLICY for consistent abstractions.

use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;

15/// Advanced einsum implementation with automatic optimization
16///
17/// Computes Einstein summation convention operations with automatic path optimization.
18///
19/// # Mathematical Formula
20///
21/// For a general einsum expression like "ij,jk->ik" (matrix multiplication):
22/// ```text
23/// C[i,k] = Σ_j A[i,j] * B[j,k]
24/// ```
25///
26/// # Arguments
27///
28/// * `equation` - Einstein summation equation (e.g., "ij,jk->ik")
29/// * `operands` - Input tensors for the operation
30///
31/// # Performance
32///
33/// - Time Complexity: O(∏ output_dims * ∏ contracted_dims)
34/// - Space Complexity: O(∏ output_dims)
35/// - Uses scirs2-core for optimized tensor contractions
36///
37/// # Examples
38///
39/// ```rust,ignore
40/// use torsh_functional::transformations::einsum_optimized;
41/// use torsh_tensor::Tensor;
42///
43/// // Matrix multiplication
44/// let a = Tensor::randn(&[10, 20])?;
45/// let b = Tensor::randn(&[20, 30])?;
46/// let c = einsum_optimized("ij,jk->ik", &[&a, &b])?;
47///
48/// // Batch matrix multiplication
49/// let a = Tensor::randn(&[32, 10, 20])?;
50/// let b = Tensor::randn(&[32, 20, 30])?;
51/// let c = einsum_optimized("bij,bjk->bik", &[&a, &b])?;
52///
53/// // Trace
54/// let a = Tensor::randn(&[10, 10])?;
55/// let trace = einsum_optimized("ii->", &[&a])?;
56/// ```
57pub fn einsum_optimized(equation: &str, operands: &[&Tensor]) -> TorshResult<Tensor> {
58    if operands.is_empty() {
59        return Err(TorshError::invalid_argument_with_context(
60            "einsum requires at least one operand",
61            "einsum_optimized",
62        ));
63    }
64
65    // Parse einsum equation
66    let (inputs, output) = parse_einsum_equation(equation)?;
67
68    // Validate number of operands matches inputs
69    if inputs.len() != operands.len() {
70        return Err(TorshError::invalid_argument_with_context(
71            &format!(
72                "einsum equation expects {} operands, got {}",
73                inputs.len(),
74                operands.len()
75            ),
76            "einsum_optimized",
77        ));
78    }
79
80    // Optimize contraction path using dynamic programming
81    let optimal_path = optimize_contraction_path(&inputs, &output)?;
82
83    // Execute optimized contraction
84    execute_contraction_path(operands, &optimal_path, &output)
85}
86
87/// Parse einsum equation into input and output specifications
88fn parse_einsum_equation(equation: &str) -> TorshResult<(Vec<String>, String)> {
89    let parts: Vec<&str> = equation.split("->").collect();
90
91    if parts.len() > 2 {
92        return Err(TorshError::invalid_argument_with_context(
93            "einsum equation can have at most one '->' separator",
94            "parse_einsum_equation",
95        ));
96    }
97
98    let input_str = parts[0];
99    let inputs: Vec<String> = input_str.split(',').map(|s| s.trim().to_string()).collect();
100
101    let output = if parts.len() == 2 {
102        parts[1].trim().to_string()
103    } else {
104        // Implicit output: all indices that appear exactly once
105        infer_output_indices(&inputs)
106    };
107
108    Ok((inputs, output))
109}
110
/// Infer the implicit einsum output: every index that appears exactly once
/// across all input specs, in ascending alphabetical order (NumPy convention).
fn infer_output_indices(inputs: &[String]) -> String {
    use std::collections::BTreeMap;

    // Count occurrences of each alphabetic index across all operand specs.
    // A BTreeMap keeps keys ordered, so no separate sort pass is needed.
    let mut counts: BTreeMap<char, usize> = BTreeMap::new();
    for spec in inputs {
        for ch in spec.chars().filter(|c| c.is_alphabetic()) {
            *counts.entry(ch).or_insert(0) += 1;
        }
    }

    // Free indices (seen exactly once) form the output, already sorted.
    counts
        .into_iter()
        .filter(|&(_, n)| n == 1)
        .map(|(ch, _)| ch)
        .collect()
}

135/// Optimize contraction path using dynamic programming
136fn optimize_contraction_path(
137    inputs: &[String],
138    _output: &str,
139) -> TorshResult<Vec<ContractionStep>> {
140    // For simplicity, use greedy algorithm
141    // TODO: Implement full dynamic programming optimization
142    let mut steps = Vec::new();
143    let mut remaining = inputs.to_vec();
144
145    while remaining.len() > 1 {
146        // Find pair with smallest intermediate result
147        let (idx1, idx2) = find_best_contraction_pair(&remaining)?;
148
149        let indices1 = &remaining[idx1];
150        let indices2 = &remaining[idx2];
151
152        // Compute result indices
153        let result_indices = compute_contraction_result(indices1, indices2);
154
155        steps.push(ContractionStep {
156            _operand1: idx1,
157            _operand2: idx2,
158            _result_indices: result_indices.clone(),
159        });
160
161        // Update remaining tensors
162        remaining.remove(idx2.max(idx1));
163        remaining.remove(idx1.min(idx2));
164        remaining.push(result_indices);
165    }
166
167    Ok(steps)
168}
169
/// One step of a pairwise contraction plan produced by
/// `optimize_contraction_path`: which two entries of the working tensor list
/// are contracted, and the einsum index string of the resulting intermediate.
///
/// Fields are underscore-prefixed (and `dead_code`-allowed) because not all
/// execution paths read them yet.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ContractionStep {
    // Position of the first operand in the working tensor list.
    _operand1: usize,
    // Position of the second operand in the working tensor list.
    _operand2: usize,
    // Einsum index string of the intermediate produced by this step.
    _result_indices: String,
}

178fn find_best_contraction_pair(remaining: &[String]) -> TorshResult<(usize, usize)> {
179    if remaining.len() < 2 {
180        return Err(TorshError::invalid_argument_with_context(
181            "need at least 2 tensors to find contraction pair",
182            "find_best_contraction_pair",
183        ));
184    }
185
186    // Simple greedy: contract first two tensors
187    Ok((0, 1))
188}
189
/// Compute the index string of the intermediate produced by contracting two
/// operands: every index of either operand that is NOT shared by both, in
/// first-occurrence order and without duplicates.
fn compute_contraction_result(indices1: &str, indices2: &str) -> String {
    use std::collections::HashSet;

    let left: HashSet<char> = indices1.chars().collect();
    let right: HashSet<char> = indices2.chars().collect();

    // Single pass over both strings: skip indices present in both operands
    // (those are summed over and disappear); `emitted` dedups survivors
    // while preserving their first-occurrence order.
    let mut emitted = HashSet::new();
    let mut result = String::new();
    for ch in indices1.chars().chain(indices2.chars()) {
        let contracted = left.contains(&ch) && right.contains(&ch);
        if !contracted && emitted.insert(ch) {
            result.push(ch);
        }
    }
    result
}

212fn execute_contraction_path(
213    operands: &[&Tensor],
214    _path: &[ContractionStep],
215    _output: &str,
216) -> TorshResult<Tensor> {
217    // Simple fallback: use matrix multiplication for basic patterns
218    if operands.len() == 2 {
219        // Convert &[&Tensor] to Vec<Tensor> by cloning
220        let operand_vec: Vec<Tensor> = operands.iter().map(|&t| t.clone()).collect();
221        return crate::math::einsum("ij,jk->ik", &operand_vec);
222    }
223
224    Err(TorshError::InvalidOperation(
225        "general einsum contraction path execution not yet implemented (execute_contraction_path)"
226            .to_string(),
227    ))
228}
229
230/// Tensor contraction with specified axes
231///
232/// Contracts (sums over) specified axes of input tensors.
233///
234/// # Mathematical Formula
235///
236/// For tensors A and B with contraction on axes (i,j):
237/// ```text
238/// C[...] = Σ_{i,j} A[...,i,j,...] * B[...,i,j,...]
239/// ```
240///
241/// # Arguments
242///
243/// * `a` - First input tensor
244/// * `b` - Second input tensor
245/// * `axes_a` - Axes to contract in first tensor
246/// * `axes_b` - Axes to contract in second tensor
247///
248/// # Performance
249///
250/// - Time Complexity: O(∏ result_dims * ∏ contracted_dims)
251/// - Space Complexity: O(∏ result_dims)
252/// - Uses scirs2-core optimized contractions
253///
254/// # Examples
255///
256/// ```rust,ignore
257/// use torsh_functional::transformations::tensor_contract;
258///
259/// let a = Tensor::randn(&[10, 20, 30])?;
260/// let b = Tensor::randn(&[30, 40])?;
261/// // Contract last axis of a with first axis of b
262/// let c = tensor_contract(&a, &b, &[2], &[0])?;
263/// // Result shape: [10, 20, 40]
264/// ```
265pub fn tensor_contract(
266    a: &Tensor,
267    b: &Tensor,
268    axes_a: &[usize],
269    axes_b: &[usize],
270) -> TorshResult<Tensor> {
271    if axes_a.len() != axes_b.len() {
272        return Err(TorshError::invalid_argument_with_context(
273            "number of contraction axes must match",
274            "tensor_contract",
275        ));
276    }
277
278    // Validate axes
279    let a_shape_obj = a.shape();
280    let shape_a = a_shape_obj.dims();
281    let b_shape_obj = b.shape();
282    let shape_b = b_shape_obj.dims();
283
284    for &axis in axes_a {
285        if axis >= shape_a.len() {
286            return Err(TorshError::invalid_argument_with_context(
287                &format!(
288                    "axis {} out of range for tensor with {} dimensions",
289                    axis,
290                    shape_a.len()
291                ),
292                "tensor_contract",
293            ));
294        }
295    }
296
297    for &axis in axes_b {
298        if axis >= shape_b.len() {
299            return Err(TorshError::invalid_argument_with_context(
300                &format!(
301                    "axis {} out of range for tensor with {} dimensions",
302                    axis,
303                    shape_b.len()
304                ),
305                "tensor_contract",
306            ));
307        }
308    }
309
310    // Check contracted dimensions match
311    for (&axis_a, &axis_b) in axes_a.iter().zip(axes_b.iter()) {
312        if shape_a[axis_a] != shape_b[axis_b] {
313            return Err(TorshError::invalid_argument_with_context(
314                &format!(
315                    "contracted dimensions must match: {} != {}",
316                    shape_a[axis_a], shape_b[axis_b]
317                ),
318                "tensor_contract",
319            ));
320        }
321    }
322
323    // Use tensordot for general contraction
324    crate::manipulation::tensordot(
325        a,
326        b,
327        crate::manipulation::TensorDotAxes::Arrays(axes_a.to_vec(), axes_b.to_vec()),
328    )
329}
330
331/// Functional map operation over tensor elements
332///
333/// Applies a function to each element of the tensor in parallel.
334///
335/// # Arguments
336///
337/// * `input` - Input tensor
338/// * `f` - Function to apply to each element
339///
340/// # Performance
341///
342/// - Time Complexity: O(n) where n is number of elements
343/// - Space Complexity: O(n) for output tensor
344/// - Uses scirs2-core parallel operations when beneficial
345///
346/// # Examples
347///
348/// ```rust,ignore
349/// use torsh_functional::transformations::tensor_map;
350///
351/// let input = Tensor::randn(&[100, 100])?;
352/// let output = tensor_map(&input, |x| x.powi(2))?;
353/// ```
354pub fn tensor_map<F>(input: &Tensor<f32>, f: F) -> TorshResult<Tensor<f32>>
355where
356    F: Fn(f32) -> f32 + Send + Sync,
357{
358    let data = input.data()?;
359    let shape = input.shape().dims().to_vec();
360    let device = input.device();
361
362    // Use parallel map for large tensors
363    let result_data: Vec<f32> = if data.len() > 10000 {
364        use scirs2_core::parallel_ops::*;
365        data.iter()
366            .copied()
367            .collect::<Vec<_>>()
368            .into_par_iter()
369            .map(f)
370            .collect()
371    } else {
372        data.iter().map(|&x| f(x)).collect()
373    };
374
375    Tensor::from_data(result_data, shape, device)
376}
377
378/// Functional reduce operation along specified axis
379///
380/// Reduces tensor along an axis using a binary operation.
381///
382/// # Arguments
383///
384/// * `input` - Input tensor
385/// * `axis` - Axis to reduce along (None for all axes)
386/// * `f` - Binary reduction function
387/// * `init` - Initial value for reduction
388///
389/// # Performance
390///
391/// - Time Complexity: O(n) where n is number of elements
392/// - Space Complexity: O(m) where m is output size
393/// - Uses scirs2-core parallel reductions
394///
395/// # Examples
396///
397/// ```rust,ignore
398/// use torsh_functional::transformations::tensor_reduce;
399///
400/// let input = Tensor::randn(&[10, 20])?;
401/// // Sum along axis 0
402/// let output = tensor_reduce(&input, Some(0), |a, b| a + b, 0.0)?;
403/// // Result shape: [20]
404/// ```
405pub fn tensor_reduce<F>(
406    input: &Tensor<f32>,
407    axis: Option<usize>,
408    f: F,
409    init: f32,
410) -> TorshResult<Tensor<f32>>
411where
412    F: Fn(f32, f32) -> f32 + Send + Sync,
413{
414    let input_shape = input.shape();
415    let shape = input_shape.dims();
416
417    if let Some(ax) = axis {
418        if ax >= shape.len() {
419            return Err(TorshError::invalid_argument_with_context(
420                &format!(
421                    "axis {} out of range for tensor with {} dimensions",
422                    ax,
423                    shape.len()
424                ),
425                "tensor_reduce",
426            ));
427        }
428
429        // Reduce along specific axis
430        let data = input.data()?;
431        let mut output_shape = shape.to_vec();
432        output_shape.remove(ax);
433
434        if output_shape.is_empty() {
435            // Reducing to scalar
436            let result = data.iter().fold(init, |acc, &x| f(acc, x));
437            return Tensor::from_data(vec![result], vec![1], input.device());
438        }
439
440        // Calculate strides
441        let mut strides = vec![1; shape.len()];
442        for i in (0..shape.len() - 1).rev() {
443            strides[i] = strides[i + 1] * shape[i + 1];
444        }
445
446        let output_size: usize = output_shape.iter().product();
447        let axis_size = shape[ax];
448        let mut result_data = vec![init; output_size];
449
450        // Perform reduction
451        for (out_idx, result_val) in result_data.iter_mut().enumerate() {
452            for axis_idx in 0..axis_size {
453                // Compute input index
454                let mut in_idx = 0;
455                let mut remaining = out_idx;
456                let mut out_dim_idx = 0;
457
458                for dim_idx in 0..shape.len() {
459                    if dim_idx == ax {
460                        in_idx += axis_idx * strides[dim_idx];
461                    } else {
462                        let size = output_shape[out_dim_idx];
463                        let coord = remaining % size;
464                        remaining /= size;
465                        in_idx += coord * strides[dim_idx];
466                        out_dim_idx += 1;
467                    }
468                }
469
470                if in_idx < data.len() {
471                    *result_val = f(*result_val, data[in_idx]);
472                }
473            }
474        }
475
476        Tensor::from_data(result_data, output_shape, input.device())
477    } else {
478        // Reduce all elements to scalar
479        let data = input.data()?;
480        let result = data.iter().fold(init, |acc, &x| f(acc, x));
481        Tensor::from_data(vec![result], vec![1], input.device())
482    }
483}
484
/// Functional scan (cumulative) operation along axis
///
/// Computes cumulative operation along specified axis.
///
/// # Arguments
///
/// * `input` - Input tensor
/// * `axis` - Axis to scan along
/// * `f` - Binary scan function
/// * `init` - Initial value for scan
///
/// # Errors
///
/// Errors when `axis` is out of range for the input's rank.
///
/// # Performance
///
/// - Time Complexity: O(n) where n is number of elements
/// - Space Complexity: O(n) for output tensor
/// - Uses sequential scan (not parallelizable)
///
/// # Examples
///
/// ```rust,ignore
/// use torsh_functional::transformations::tensor_scan;
///
/// let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![4])?;
/// // Cumulative sum
/// let output = tensor_scan(&input, 0, |a, b| a + b, 0.0)?;
/// // Result: [1.0, 3.0, 6.0, 10.0]
/// ```
pub fn tensor_scan<F>(input: &Tensor<f32>, axis: usize, f: F, init: f32) -> TorshResult<Tensor<f32>>
where
    F: Fn(f32, f32) -> f32,
{
    let input_shape = input.shape();
    let shape = input_shape.dims();

    if axis >= shape.len() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "axis {} out of range for tensor with {} dimensions",
                axis,
                shape.len()
            ),
            "tensor_scan",
        ));
    }

    let data = input.data()?;
    // The scan runs in place on a copy of the input: each element is read
    // exactly once (its original value) before being overwritten with the
    // running accumulator, so the in-place update is safe.
    let mut result_data = data.to_vec();

    // Row-major strides: strides[i] is the flat-index step for dimension i.
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let axis_size = shape[axis];
    let axis_stride = strides[axis];

    // Number of independent 1-D "lanes" to scan: product of all non-axis dims.
    let other_size: usize = shape
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &s)| s)
        .product();

    for other_idx in 0..other_size {
        // Decode other_idx into the starting flat offset of this lane. The
        // decode order (first non-axis dim fastest) differs from row-major,
        // but any bijection over lanes is valid here because each lane is
        // scanned independently and written back in place.
        let mut base_idx = 0;
        let mut remaining = other_idx;

        for (dim_idx, &size) in shape.iter().enumerate() {
            if dim_idx != axis {
                let coord = remaining % size;
                remaining /= size;
                base_idx += coord * strides[dim_idx];
            }
        }

        // Sequential prefix scan along the axis for this lane.
        let mut acc = init;
        for axis_idx in 0..axis_size {
            let idx = base_idx + axis_idx * axis_stride;
            if idx < result_data.len() {
                acc = f(acc, result_data[idx]);
                result_data[idx] = acc;
            }
        }
    }

    Tensor::from_data(result_data, shape.to_vec(), input.device())
}

577/// Functional fold operation (left fold) over tensor
578///
579/// Folds tensor elements from left to right using binary operation.
580///
581/// # Arguments
582///
583/// * `input` - Input tensor
584/// * `f` - Binary fold function
585/// * `init` - Initial accumulator value
586///
587/// # Performance
588///
589/// - Time Complexity: O(n) where n is number of elements
590/// - Space Complexity: O(1) for accumulator
591/// - Sequential operation (not parallelizable)
592///
593/// # Examples
594///
595/// ```rust,ignore
596/// use torsh_functional::transformations::tensor_fold;
597///
598/// let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![4])?;
599/// let sum = tensor_fold(&input, |acc, x| acc + x, 0.0)?;
600/// // Result: 10.0
601/// ```
602pub fn tensor_fold<F>(input: &Tensor<f32>, f: F, init: f32) -> TorshResult<f32>
603where
604    F: Fn(f32, f32) -> f32,
605{
606    let data = input.data()?;
607    Ok(data.iter().fold(init, |acc, &x| f(acc, x)))
608}
609
610/// Tensor outer product (generalized)
611///
612/// Computes generalized outer product of two tensors.
613///
614/// # Mathematical Formula
615///
616/// For tensors A and B:
617/// ```text
618/// C[i₁,...,iₘ,j₁,...,jₙ] = A[i₁,...,iₘ] * B[j₁,...,jₙ]
619/// ```
620///
621/// # Arguments
622///
623/// * `a` - First input tensor
624/// * `b` - Second input tensor
625///
626/// # Performance
627///
628/// - Time Complexity: O(mn) where m,n are input sizes
629/// - Space Complexity: O(mn) for output
630/// - Uses scirs2-core broadcasting
631///
632/// # Examples
633///
634/// ```rust,ignore
635/// use torsh_functional::transformations::tensor_outer;
636///
637/// let a = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3])?;
638/// let b = Tensor::from_data(vec![4.0, 5.0], vec![2])?;
639/// let c = tensor_outer(&a, &b)?;
640/// // Result shape: [3, 2]
641/// // [[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]]
642/// ```
643pub fn tensor_outer(a: &Tensor<f32>, b: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
644    let a_shape_obj = a.shape();
645    let shape_a = a_shape_obj.dims();
646    let b_shape_obj = b.shape();
647    let shape_b = b_shape_obj.dims();
648
649    // Reshape a to [..., 1, 1, ...] and b to [1, 1, ..., ...]
650    let mut new_shape_a = shape_a.to_vec();
651    new_shape_a.extend(vec![1; shape_b.len()]);
652
653    let mut new_shape_b = vec![1; shape_a.len()];
654    new_shape_b.extend(shape_b);
655
656    let a_reshaped = a.view(&new_shape_a.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
657    let b_reshaped = b.view(&new_shape_b.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
658
659    // Multiply (will broadcast)
660    a_reshaped.mul(&b_reshaped)
661}
662
663/// Zip two tensors element-wise with a binary function
664///
665/// Applies binary function to corresponding elements of two tensors.
666///
667/// # Arguments
668///
669/// * `a` - First input tensor
670/// * `b` - Second input tensor
671/// * `f` - Binary function to apply
672///
673/// # Performance
674///
675/// - Time Complexity: O(n) where n is number of elements
676/// - Space Complexity: O(n) for output
677/// - Uses scirs2-core parallel operations for large tensors
678///
679/// # Examples
680///
681/// ```rust,ignore
682/// use torsh_functional::transformations::tensor_zip;
683///
684/// let a = Tensor::randn(&[100])?;
685/// let b = Tensor::randn(&[100])?;
686/// let c = tensor_zip(&a, &b, |x, y| x * y + y * y)?;
687/// ```
688pub fn tensor_zip<F>(a: &Tensor<f32>, b: &Tensor<f32>, f: F) -> TorshResult<Tensor<f32>>
689where
690    F: Fn(f32, f32) -> f32 + Send + Sync,
691{
692    // Check shapes match
693    if a.shape().dims() != b.shape().dims() {
694        return Err(TorshError::invalid_argument_with_context(
695            &format!(
696                "tensor shapes must match for zip: {:?} vs {:?}",
697                a.shape().dims(),
698                b.shape().dims()
699            ),
700            "tensor_zip",
701        ));
702    }
703
704    let data_a = a.data()?;
705    let data_b = b.data()?;
706    let shape = a.shape().dims().to_vec();
707    let device = a.device();
708
709    // Use parallel zip for large tensors
710    let result_data: Vec<f32> = if data_a.len() > 10000 {
711        use scirs2_core::parallel_ops::*;
712        let pairs: Vec<(f32, f32)> = data_a.iter().copied().zip(data_b.iter().copied()).collect();
713        pairs.into_par_iter().map(|(x, y)| f(x, y)).collect()
714    } else {
715        data_a
716            .iter()
717            .zip(data_b.iter())
718            .map(|(&x, &y)| f(x, y))
719            .collect()
720    };
721
722    Tensor::from_data(result_data, shape, device)
723}
724
// Unit tests for the functional transformation primitives. All fixtures are
// small CPU tensors with hand-computable expected values; float comparisons
// use approx's relative-epsilon assertions.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // tensor_map: doubling each element of a 2x2 tensor.
    #[test]
    fn test_tensor_map() {
        let input = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");

        let output = tensor_map(&input, |x| x * 2.0).expect("map failed");
        let output_data = output.data().expect("failed to get data");

        assert_relative_eq!(output_data[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(output_data[1], 4.0, epsilon = 1e-6);
        assert_relative_eq!(output_data[2], 6.0, epsilon = 1e-6);
        assert_relative_eq!(output_data[3], 8.0, epsilon = 1e-6);
    }

    // tensor_reduce with axis=None: full sum to a scalar.
    #[test]
    fn test_tensor_reduce() {
        let input = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");

        let output = tensor_reduce(&input, None, |a, b| a + b, 0.0).expect("reduce failed");
        let output_data = output.data().expect("failed to get data");

        assert_relative_eq!(output_data[0], 10.0, epsilon = 1e-6);
    }

    // tensor_fold: left fold over all elements (sum).
    #[test]
    fn test_tensor_fold() {
        let input = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");

        let result = tensor_fold(&input, |acc, x| acc + x, 0.0).expect("fold failed");
        assert_relative_eq!(result, 10.0, epsilon = 1e-6);
    }

    // tensor_scan: cumulative sum along the only axis of a 1-D tensor.
    #[test]
    fn test_tensor_scan() {
        let input = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");

        let output = tensor_scan(&input, 0, |a, b| a + b, 0.0).expect("scan failed");
        let output_data = output.data().expect("failed to get data");

        assert_relative_eq!(output_data[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(output_data[1], 3.0, epsilon = 1e-6);
        assert_relative_eq!(output_data[2], 6.0, epsilon = 1e-6);
        assert_relative_eq!(output_data[3], 10.0, epsilon = 1e-6);
    }

    // tensor_outer: [3] x [2] -> [3, 2] with element-wise products.
    #[test]
    fn test_tensor_outer() {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");

        let b = Tensor::from_data(vec![4.0, 5.0], vec![2], torsh_core::device::DeviceType::Cpu)
            .expect("failed to create tensor");

        let c = tensor_outer(&a, &b).expect("outer product failed");
        assert_eq!(c.shape().dims(), &[3, 2]);

        let c_data = c.data().expect("failed to get data");
        assert_relative_eq!(c_data[0], 4.0, epsilon = 1e-6); // 1*4
        assert_relative_eq!(c_data[1], 5.0, epsilon = 1e-6); // 1*5
        assert_relative_eq!(c_data[2], 8.0, epsilon = 1e-6); // 2*4
        assert_relative_eq!(c_data[3], 10.0, epsilon = 1e-6); // 2*5
    }

    // tensor_zip: element-wise addition of two equal-shaped tensors.
    #[test]
    fn test_tensor_zip() {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");

        let b = Tensor::from_data(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![4],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");

        let c = tensor_zip(&a, &b, |x, y| x + y).expect("zip failed");
        let c_data = c.data().expect("failed to get data");

        assert_relative_eq!(c_data[0], 6.0, epsilon = 1e-6);
        assert_relative_eq!(c_data[1], 8.0, epsilon = 1e-6);
        assert_relative_eq!(c_data[2], 10.0, epsilon = 1e-6);
        assert_relative_eq!(c_data[3], 12.0, epsilon = 1e-6);
    }

    // Equation parsing: explicit output, and trace with an empty output spec.
    #[test]
    fn test_parse_einsum_equation() {
        let (inputs, output) = parse_einsum_equation("ij,jk->ik").expect("parse failed");
        assert_eq!(inputs, vec!["ij", "jk"]);
        assert_eq!(output, "ik");

        let (inputs, output) = parse_einsum_equation("ii->").expect("parse failed");
        assert_eq!(inputs, vec!["ii"]);
        assert_eq!(output, "");
    }

    // tensor_reduce along axis 0 of a [2, 3] tensor: column sums.
    #[test]
    fn test_tensor_reduce_axis() {
        let input = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");

        // Sum along axis 0
        let output = tensor_reduce(&input, Some(0), |a, b| a + b, 0.0).expect("reduce failed");
        assert_eq!(output.shape().dims(), &[3]);

        let output_data = output.data().expect("failed to get data");
        assert_relative_eq!(output_data[0], 5.0, epsilon = 1e-6); // 1+4
        assert_relative_eq!(output_data[1], 7.0, epsilon = 1e-6); // 2+5
        assert_relative_eq!(output_data[2], 9.0, epsilon = 1e-6); // 3+6
    }
}