tenrso_exec/executor/
custom_ops.rs1use anyhow::Result;
7use scirs2_core::ndarray_ext::{Array, IxDyn};
8use scirs2_core::numeric::{Float, FromPrimitive, Num};
9use tenrso_core::{Axis, DenseND, TensorHandle};
10
11pub fn custom_reduce<T, F>(
25 input: &DenseND<T>,
26 axes: &[Axis],
27 init_value: T,
28 reduce_fn: F,
29) -> Result<DenseND<T>>
30where
31 T: Clone + Num + std::ops::AddAssign,
32 F: Fn(T, T) -> T,
33{
34 if axes.is_empty() {
35 let input_view = input.view();
37 let result = input_view.iter().cloned().fold(init_value, &reduce_fn);
38 let result_array = Array::from_elem(IxDyn(&[]), result);
39 return Ok(DenseND::from_array(result_array));
40 }
41
42 let mut result = input.clone();
44 for &axis in axes {
45 if axis >= result.shape().len() {
46 return Err(anyhow::anyhow!(
47 "Axis {} out of bounds for tensor with {} dimensions",
48 axis,
49 result.shape().len()
50 ));
51 }
52
53 let result_view = result.view();
54
55 let new_shape: Vec<usize> = result
57 .shape()
58 .iter()
59 .enumerate()
60 .filter(|(i, _)| *i != axis)
61 .map(|(_, &s)| s)
62 .collect();
63
64 let axis_size = result.shape()[axis];
65 let output_size: usize = new_shape.iter().product();
66 let mut output_data = Vec::with_capacity(output_size);
67
68 for out_idx in 0..output_size {
70 let mut acc = init_value.clone();
71
72 for axis_idx in 0..axis_size {
74 let mut in_idx = Vec::with_capacity(result.shape().len());
76 let mut remaining = out_idx;
77
78 for (dim_idx, &_dim_size) in result.shape().iter().enumerate() {
79 if dim_idx == axis {
80 in_idx.push(axis_idx);
81 } else {
82 let stride: usize = new_shape
83 [if dim_idx < axis { dim_idx } else { dim_idx - 1 }..]
84 .iter()
85 .product();
86 in_idx.push(remaining / stride);
87 remaining %= stride;
88 }
89 }
90
91 let value = result_view[in_idx.as_slice()].clone();
92 acc = reduce_fn(acc, value);
93 }
94
95 output_data.push(acc);
96 }
97
98 let result_array = Array::from_shape_vec(IxDyn(&new_shape), output_data)
99 .map_err(|e| anyhow::anyhow!("Failed to create result array: {}", e))?;
100 result = DenseND::from_array(result_array);
101 }
102
103 Ok(result)
104}
105
106pub fn custom_binary_op<T, F>(x: &DenseND<T>, y: &DenseND<T>, op_fn: F) -> Result<DenseND<T>>
121where
122 T: Clone + Num,
123 F: Fn(T, T) -> T,
124{
125 let x_view = x.view();
126 let y_view = y.view();
127
128 if x.shape() == y.shape() {
129 let result_data: Vec<T> = x_view
131 .iter()
132 .zip(y_view.iter())
133 .map(|(a, b)| op_fn(a.clone(), b.clone()))
134 .collect();
135
136 let result_array = Array::from_shape_vec(IxDyn(x.shape()), result_data)
137 .map_err(|e| anyhow::anyhow!("Failed to create result array: {}", e))?;
138 return Ok(DenseND::from_array(result_array));
139 }
140
141 Err(anyhow::anyhow!(
144 "Custom binary operations with broadcasting not yet implemented. Shapes: {:?} vs {:?}",
145 x.shape(),
146 y.shape()
147 ))
148}
149
150pub fn custom_unary_op<T, F>(input: &DenseND<T>, op_fn: F) -> Result<DenseND<T>>
162where
163 T: Clone + Num,
164 F: Fn(T) -> T,
165{
166 let input_view = input.view();
167 let result = input_view.mapv(op_fn);
168 Ok(DenseND::from_array(result))
169}
170
171pub fn apply_custom_unary<T, F>(input: &TensorHandle<T>, op_fn: F) -> Result<TensorHandle<T>>
173where
174 T: Clone + Num + Float + FromPrimitive,
175 F: Fn(T) -> T,
176{
177 if let Some(dense) = input.as_dense() {
178 let result = custom_unary_op(dense, op_fn)?;
179 Ok(TensorHandle::from_dense_auto(result))
180 } else {
181 Err(anyhow::anyhow!(
182 "Custom operations only supported for dense tensors"
183 ))
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for all floating-point comparisons below.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_custom_reduce_product() {
        let input = DenseND::from_vec(vec![2.0_f64, 3.0, 4.0, 5.0], &[4]).unwrap();
        let result = custom_reduce(&input, &[], 1.0, |acc, x| acc * x).unwrap();
        let result_view = result.view();

        assert!((result_view[[]] - 120.0).abs() < EPS);
    }

    #[test]
    fn test_custom_reduce_max() {
        let input = DenseND::from_vec(vec![2.0_f64, 8.0, 4.0, 5.0], &[4]).unwrap();
        let result = custom_reduce(&input, &[], f64::NEG_INFINITY, |acc, x| {
            if x > acc {
                x
            } else {
                acc
            }
        })
        .unwrap();
        let result_view = result.view();

        assert!((result_view[[]] - 8.0).abs() < EPS);
    }

    #[test]
    fn test_custom_reduce_single_axis_1d() {
        // Reducing the only axis of a 1-D tensor yields a 0-D tensor.
        let input = DenseND::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = custom_reduce(&input, &[0], 0.0, |acc, x| acc + x).unwrap();
        let result_view = result.view();

        assert!(result.shape().is_empty());
        assert!((result_view[[]] - 10.0).abs() < EPS);
    }

    #[test]
    fn test_custom_reduce_axis_out_of_bounds() {
        let input = DenseND::from_vec(vec![1.0_f64, 2.0], &[2]).unwrap();

        assert!(custom_reduce(&input, &[1], 0.0, |acc, x| acc + x).is_err());
    }

    #[test]
    fn test_custom_unary_op() {
        let input = DenseND::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = custom_unary_op(&input, |x| x * x).unwrap();
        let result_view = result.view();

        for (i, expected) in [1.0, 4.0, 9.0, 16.0].iter().enumerate() {
            assert!((result_view[[i]] - expected).abs() < EPS);
        }
    }

    #[test]
    fn test_custom_binary_op() {
        let x = DenseND::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], &[4]).unwrap();
        let y = DenseND::from_vec(vec![2.0_f64, 3.0, 4.0, 5.0], &[4]).unwrap();
        let result = custom_binary_op(&x, &y, |a, b| (a + b) / 2.0).unwrap();
        let result_view = result.view();

        for (i, expected) in [1.5, 2.5, 3.5, 4.5].iter().enumerate() {
            assert!((result_view[[i]] - expected).abs() < EPS);
        }
    }

    #[test]
    fn test_custom_binary_op_incompatible_shapes() {
        // [4] and [3] can neither be matched nor broadcast.
        let x = DenseND::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], &[4]).unwrap();
        let y = DenseND::from_vec(vec![1.0_f64, 2.0, 3.0], &[3]).unwrap();

        assert!(custom_binary_op(&x, &y, |a, b| a + b).is_err());
    }

    #[test]
    fn test_apply_custom_unary() {
        let input = DenseND::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], &[4]).unwrap();
        let handle = TensorHandle::from_dense_auto(input);

        // Softsign maps positive inputs into (0, 1).
        let result = apply_custom_unary(&handle, |x: f64| x / (1.0 + x.abs())).unwrap();
        let result_dense = result.as_dense().unwrap();
        let result_view = result_dense.view();

        for i in 0..4 {
            let val = result_view[[i]];
            assert!(val > 0.0 && val < 1.0);
        }
    }
}