tenrso_exec/executor/
custom_ops.rs

1//! Custom operations with user-defined functions
2//!
3//! This module provides support for custom operations where users can define
4//! their own reduction functions, element-wise operations, and more.
5
6use anyhow::Result;
7use scirs2_core::ndarray_ext::{Array, IxDyn};
8use scirs2_core::numeric::{Float, FromPrimitive, Num};
9use tenrso_core::{Axis, DenseND, TensorHandle};
10
11/// Custom reduction operation with user-defined reduction function
12///
13/// # Arguments
14/// * `input` - Input tensor
15/// * `axes` - Axes along which to reduce (empty = reduce all)
16/// * `init_value` - Initial value for the reduction
17/// * `reduce_fn` - Binary reduction function: (accumulator, element) -> new_accumulator
18///
19/// # Example
20/// ```ignore
21/// // Compute product along axis 0
22/// custom_reduce(&tensor, &[0], 1.0, |acc, x| acc * x)?;
23/// ```
24pub fn custom_reduce<T, F>(
25    input: &DenseND<T>,
26    axes: &[Axis],
27    init_value: T,
28    reduce_fn: F,
29) -> Result<DenseND<T>>
30where
31    T: Clone + Num + std::ops::AddAssign,
32    F: Fn(T, T) -> T,
33{
34    if axes.is_empty() {
35        // Reduce all axes
36        let input_view = input.view();
37        let result = input_view.iter().cloned().fold(init_value, &reduce_fn);
38        let result_array = Array::from_elem(IxDyn(&[]), result);
39        return Ok(DenseND::from_array(result_array));
40    }
41
42    // Reduce along specific axes
43    let mut result = input.clone();
44    for &axis in axes {
45        if axis >= result.shape().len() {
46            return Err(anyhow::anyhow!(
47                "Axis {} out of bounds for tensor with {} dimensions",
48                axis,
49                result.shape().len()
50            ));
51        }
52
53        let result_view = result.view();
54
55        // Manual reduction along the specified axis
56        let new_shape: Vec<usize> = result
57            .shape()
58            .iter()
59            .enumerate()
60            .filter(|(i, _)| *i != axis)
61            .map(|(_, &s)| s)
62            .collect();
63
64        let axis_size = result.shape()[axis];
65        let output_size: usize = new_shape.iter().product();
66        let mut output_data = Vec::with_capacity(output_size);
67
68        // Iterate over all output positions
69        for out_idx in 0..output_size {
70            let mut acc = init_value.clone();
71
72            // Reduce along the specified axis
73            for axis_idx in 0..axis_size {
74                // Convert flat output index to multi-dimensional index
75                let mut in_idx = Vec::with_capacity(result.shape().len());
76                let mut remaining = out_idx;
77
78                for (dim_idx, &_dim_size) in result.shape().iter().enumerate() {
79                    if dim_idx == axis {
80                        in_idx.push(axis_idx);
81                    } else {
82                        let stride: usize = new_shape
83                            [if dim_idx < axis { dim_idx } else { dim_idx - 1 }..]
84                            .iter()
85                            .product();
86                        in_idx.push(remaining / stride);
87                        remaining %= stride;
88                    }
89                }
90
91                let value = result_view[in_idx.as_slice()].clone();
92                acc = reduce_fn(acc, value);
93            }
94
95            output_data.push(acc);
96        }
97
98        let result_array = Array::from_shape_vec(IxDyn(&new_shape), output_data)
99            .map_err(|e| anyhow::anyhow!("Failed to create result array: {}", e))?;
100        result = DenseND::from_array(result_array);
101    }
102
103    Ok(result)
104}
105
106/// Custom element-wise binary operation with user-defined function
107///
108/// Applies a custom binary operation element-wise to two tensors with broadcasting support.
109///
110/// # Arguments
111/// * `x` - First input tensor
112/// * `y` - Second input tensor
113/// * `op_fn` - Binary operation function: (x_elem, y_elem) -> result_elem
114///
115/// # Example
116/// ```ignore
117/// // Custom operation: (x + y) / 2
118/// custom_binary_op(&x, &y, |a, b| (a + b) / 2.0)?;
119/// ```
120pub fn custom_binary_op<T, F>(x: &DenseND<T>, y: &DenseND<T>, op_fn: F) -> Result<DenseND<T>>
121where
122    T: Clone + Num,
123    F: Fn(T, T) -> T,
124{
125    let x_view = x.view();
126    let y_view = y.view();
127
128    if x.shape() == y.shape() {
129        // Same shape - direct element-wise operation
130        let result_data: Vec<T> = x_view
131            .iter()
132            .zip(y_view.iter())
133            .map(|(a, b)| op_fn(a.clone(), b.clone()))
134            .collect();
135
136        let result_array = Array::from_shape_vec(IxDyn(x.shape()), result_data)
137            .map_err(|e| anyhow::anyhow!("Failed to create result array: {}", e))?;
138        return Ok(DenseND::from_array(result_array));
139    }
140
141    // Broadcasting case - simplified implementation
142    // TODO: Implement full broadcasting support
143    Err(anyhow::anyhow!(
144        "Custom binary operations with broadcasting not yet implemented. Shapes: {:?} vs {:?}",
145        x.shape(),
146        y.shape()
147    ))
148}
149
150/// Custom element-wise unary operation with user-defined function
151///
152/// # Arguments
153/// * `input` - Input tensor
154/// * `op_fn` - Unary operation function: (element) -> result_element
155///
156/// # Example
157/// ```ignore
158/// // Custom operation: sigmoid-like function
159/// custom_unary_op(&input, |x| x / (1.0 + x.abs()))?;
160/// ```
161pub fn custom_unary_op<T, F>(input: &DenseND<T>, op_fn: F) -> Result<DenseND<T>>
162where
163    T: Clone + Num,
164    F: Fn(T) -> T,
165{
166    let input_view = input.view();
167    let result = input_view.mapv(op_fn);
168    Ok(DenseND::from_array(result))
169}
170
171/// Apply a custom operation to a tensor handle
172pub fn apply_custom_unary<T, F>(input: &TensorHandle<T>, op_fn: F) -> Result<TensorHandle<T>>
173where
174    T: Clone + Num + Float + FromPrimitive,
175    F: Fn(T) -> T,
176{
177    if let Some(dense) = input.as_dense() {
178        let result = custom_unary_op(dense, op_fn)?;
179        Ok(TensorHandle::from_dense_auto(result))
180    } else {
181        Err(anyhow::anyhow!(
182            "Custom operations only supported for dense tensors"
183        ))
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies within 1e-10 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_custom_reduce_product() {
        // Full reduction by multiplication: 2 * 3 * 4 * 5 = 120.
        let input = DenseND::from_vec(vec![2.0, 3.0, 4.0, 5.0], &[4]).unwrap();
        let product = custom_reduce(&input, &[], 1.0, |acc, x| acc * x).unwrap();
        assert_close(product.view()[[]], 120.0);
    }

    #[test]
    fn test_custom_reduce_max() {
        // Full reduction by maximum, seeded with -inf: expected 8.0.
        let input = DenseND::from_vec(vec![2.0, 8.0, 4.0, 5.0], &[4]).unwrap();
        let maximum = custom_reduce(&input, &[], f64::NEG_INFINITY, f64::max).unwrap();
        assert_close(maximum.view()[[]], 8.0);
    }

    #[test]
    fn test_custom_unary_op() {
        // Squaring each element.
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let squared = custom_unary_op(&input, |x| x * x).unwrap();
        let view = squared.view();
        for (i, &expected) in [1.0, 4.0, 9.0, 16.0].iter().enumerate() {
            assert_close(view[[i]], expected);
        }
    }

    #[test]
    fn test_custom_binary_op() {
        // Element-wise midpoint: (x + y) / 2.
        let x = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let y = DenseND::from_vec(vec![2.0, 3.0, 4.0, 5.0], &[4]).unwrap();
        let midpoint = custom_binary_op(&x, &y, |a, b| (a + b) / 2.0).unwrap();
        let view = midpoint.view();
        for (i, &expected) in [1.5, 2.5, 3.5, 4.5].iter().enumerate() {
            assert_close(view[[i]], expected);
        }
    }

    #[test]
    fn test_apply_custom_unary() {
        let handle = TensorHandle::from_dense_auto(
            DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap(),
        );

        // Softsign-style squashing: x / (1 + |x|).
        let result = apply_custom_unary(&handle, |x: f64| x / (1.0 + x.abs())).unwrap();
        let view = result.as_dense().unwrap().view();

        // Positive inputs must land strictly inside (0, 1).
        for i in 0..4 {
            let val = view[[i]];
            assert!(val > 0.0 && val < 1.0, "element {i} = {val} not in (0, 1)");
        }
    }
}