1use super::simd_ops::{self, SimdBinaryOp, SimdUnaryOp};
21use super::tiled_reductions;
22use super::types::{BinaryOp, CpuExecutor};
23use anyhow::Result;
24use scirs2_core::numeric::{Float, FromPrimitive, Num};
25use tenrso_core::{Axis, DenseND};
26
#[allow(dead_code)]
/// Applies the element-wise unary operation `op` to `input`.
///
/// When the executor has SIMD enabled and `simd_ops::should_use_simd`
/// approves the tensor shape, the work is dispatched to the matching
/// SIMD kernel; otherwise each element is evaluated with a scalar
/// `mapv` closure. Both paths are intended to compute the same function
/// for every `UnaryOpType` variant.
///
/// # Errors
/// Propagates any error returned by `simd_ops::simd_unary`.
pub(crate) fn optimized_unary<T>(
    executor: &CpuExecutor,
    input: &DenseND<T>,
    op: UnaryOpType,
) -> Result<DenseND<T>>
where
    T: Clone + Num + Float + FromPrimitive + Send + Sync,
{
    if executor.enable_simd && simd_ops::should_use_simd(input.shape()) {
        // 1:1 dispatch table from op to SIMD kernel. The match is
        // exhaustive, so adding a `UnaryOpType` variant forces an update
        // here at compile time.
        let simd_op = match op {
            UnaryOpType::Neg => SimdUnaryOp::Neg,
            UnaryOpType::Abs => SimdUnaryOp::Abs,
            UnaryOpType::Exp => SimdUnaryOp::Exp,
            UnaryOpType::Log => SimdUnaryOp::Log,
            UnaryOpType::Sin => SimdUnaryOp::Sin,
            UnaryOpType::Cos => SimdUnaryOp::Cos,
            UnaryOpType::Sqrt => SimdUnaryOp::Sqrt,
            UnaryOpType::Sqr => SimdUnaryOp::Sqr,
            UnaryOpType::Recip => SimdUnaryOp::Recip,
            UnaryOpType::Tanh => SimdUnaryOp::Tanh,
            UnaryOpType::Sigmoid => SimdUnaryOp::Sigmoid,
            UnaryOpType::ReLU => SimdUnaryOp::ReLU,
            UnaryOpType::Gelu => SimdUnaryOp::Gelu,
            UnaryOpType::Elu => SimdUnaryOp::Elu,
            UnaryOpType::Selu => SimdUnaryOp::Selu,
            UnaryOpType::Softplus => SimdUnaryOp::Softplus,
            UnaryOpType::Sign => SimdUnaryOp::Sign,
        };
        return simd_ops::simd_unary(input, simd_op);
    }

    // Scalar fallback: one `mapv` pass per operation.
    let result = match op {
        UnaryOpType::Neg => input.view().mapv(|v| -v),
        UnaryOpType::Abs => input.view().mapv(|v| v.abs()),
        UnaryOpType::Exp => input.view().mapv(|v| v.exp()),
        UnaryOpType::Log => input.view().mapv(|v| v.ln()),
        UnaryOpType::Sin => input.view().mapv(|v| v.sin()),
        UnaryOpType::Cos => input.view().mapv(|v| v.cos()),
        UnaryOpType::Sqrt => input.view().mapv(|v| v.sqrt()),
        UnaryOpType::Sqr => input.view().mapv(|v| v * v),
        UnaryOpType::Recip => input.view().mapv(|v| v.recip()),
        UnaryOpType::Tanh => input.view().mapv(|v| v.tanh()),
        // Logistic sigmoid: 1 / (1 + e^-v).
        UnaryOpType::Sigmoid => input.view().mapv(|v| {
            let one = T::one();
            one / (one + (-v).exp())
        }),
        // ReLU: max(v, 0). Note that NaN compares false against zero,
        // so NaN inputs map to zero on this path.
        UnaryOpType::ReLU => input.view().mapv(|v| {
            let zero = T::zero();
            if v > zero {
                v
            } else {
                zero
            }
        }),
        // Tanh approximation of GELU:
        //   0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v^3)))
        // `coeff` is sqrt(2/pi). The `unwrap_or_else` fallbacks only
        // trigger if the constant is not representable in `T` (never the
        // case for the standard float types).
        UnaryOpType::Gelu => input.view().mapv(|v| {
            let half = T::from_f64(0.5).unwrap_or_else(T::one);
            let one = T::one();
            let coeff = T::from_f64(0.7978845608028654).unwrap_or_else(T::one);
            let cubic_coeff = T::from_f64(0.044715).unwrap_or_else(T::zero);
            let x_cubed = v * v * v;
            let inner = coeff * (v + cubic_coeff * x_cubed);
            half * v * (one + inner.tanh())
        }),
        // ELU with alpha = 1: v for v > 0, else e^v - 1.
        UnaryOpType::Elu => input.view().mapv(|v| {
            let zero = T::zero();
            let one = T::one();
            if v > zero {
                v
            } else {
                v.exp() - one
            }
        }),
        // SELU with the standard scale (lambda) and alpha constants:
        //   scale * v               for v > 0
        //   scale * alpha * (e^v-1) otherwise
        UnaryOpType::Selu => input.view().mapv(|v| {
            let zero = T::zero();
            let one = T::one();
            let scale = T::from_f64(1.050_700_987_355_480_5).unwrap_or_else(T::one);
            let alpha = T::from_f64(1.673_263_242_354_377_2).unwrap_or_else(T::one);
            if v > zero {
                scale * v
            } else {
                scale * alpha * (v.exp() - one)
            }
        }),
        // Numerically stable softplus:
        //   ln(1 + e^v) = max(v, 0) + ln(1 + e^(-|v|))
        // which avoids overflow of e^v for large positive v.
        UnaryOpType::Softplus => input.view().mapv(|v| {
            let zero = T::zero();
            let one = T::one();
            let abs_v = v.abs();
            let max_part = if v > zero { v } else { zero };
            max_part + (one + (-abs_v).exp()).ln()
        }),
        // Sign: +1 / -1 / 0. NaN falls through to the final branch and
        // maps to zero.
        UnaryOpType::Sign => input.view().mapv(|v| {
            let zero = T::zero();
            let one = T::one();
            let neg_one = -one;
            if v > zero {
                one
            } else if v < zero {
                neg_one
            } else {
                zero
            }
        }),
    };

    Ok(DenseND::from_array(result))
}
140
141#[allow(dead_code)]
146pub(crate) fn optimized_binary<T>(
147 executor: &CpuExecutor,
148 x: &DenseND<T>,
149 y: &DenseND<T>,
150 op: BinaryOp,
151) -> Result<DenseND<T>>
152where
153 T: Clone + Num + Float + Send + Sync + std::ops::AddAssign,
154{
155 if x.shape() == y.shape() {
157 if executor.enable_simd && simd_ops::should_use_simd(x.shape()) {
159 let simd_op = match op {
160 BinaryOp::Add => SimdBinaryOp::Add,
161 BinaryOp::Sub => SimdBinaryOp::Sub,
162 BinaryOp::Mul => SimdBinaryOp::Mul,
163 BinaryOp::Div => SimdBinaryOp::Div,
164 BinaryOp::Pow => SimdBinaryOp::Pow,
165 BinaryOp::Maximum => SimdBinaryOp::Maximum,
166 BinaryOp::Minimum => SimdBinaryOp::Minimum,
167 };
168 return simd_ops::simd_binary(x, y, simd_op);
169 }
170 }
171
172 use scirs2_core::ndarray_ext::Zip;
174 let result = match op {
175 BinaryOp::Add => &x.view() + &y.view(),
176 BinaryOp::Sub => &x.view() - &y.view(),
177 BinaryOp::Mul => &x.view() * &y.view(),
178 BinaryOp::Div => &x.view() / &y.view(),
179 BinaryOp::Pow => Zip::from(&x.view())
180 .and(&y.view())
181 .map_collect(|&x_val, &y_val| x_val.powf(y_val)),
182 BinaryOp::Maximum => Zip::from(&x.view())
183 .and(&y.view())
184 .map_collect(|&x_val, &y_val| if x_val > y_val { x_val } else { y_val }),
185 BinaryOp::Minimum => Zip::from(&x.view())
186 .and(&y.view())
187 .map_collect(|&x_val, &y_val| if x_val < y_val { x_val } else { y_val }),
188 };
189
190 Ok(DenseND::from_array(result))
191}
192
193#[allow(dead_code)]
197pub(crate) fn optimized_sum<T>(
198 executor: &CpuExecutor,
199 input: &DenseND<T>,
200 axes: &[Axis],
201) -> Result<DenseND<T>>
202where
203 T: Clone + Num + Send + Sync + std::ops::AddAssign + std::iter::Sum,
204{
205 if axes.is_empty() {
207 if executor.enable_tiled_reductions && tiled_reductions::should_use_tiling(input.shape()) {
208 let sum_val = tiled_reductions::tiled_sum_all(input)?;
209 let result = scirs2_core::ndarray_ext::Array::from_elem(
210 scirs2_core::ndarray_ext::IxDyn(&[]),
211 sum_val,
212 );
213 return Ok(DenseND::from_array(result));
214 } else {
215 let sum_val: T = input.view().iter().cloned().sum();
217 let result = scirs2_core::ndarray_ext::Array::from_elem(
218 scirs2_core::ndarray_ext::IxDyn(&[]),
219 sum_val,
220 );
221 return Ok(DenseND::from_array(result));
222 }
223 }
224
225 if axes.len() == 1 && executor.enable_tiled_reductions {
227 return tiled_reductions::tiled_sum_axis(input, axes[0]);
228 }
229
230 let mut result = input.view().to_owned();
232 let mut sorted_axes = axes.to_vec();
233 sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
234
235 for &axis_idx in &sorted_axes {
236 let axis = scirs2_core::ndarray_ext::Axis(axis_idx);
237 result = result.sum_axis(axis);
238 }
239
240 Ok(DenseND::from_array(result))
241}
242
243#[allow(dead_code)]
247pub(crate) fn optimized_mean<T>(
248 executor: &CpuExecutor,
249 input: &DenseND<T>,
250 axes: &[Axis],
251) -> Result<DenseND<T>>
252where
253 T: Clone + Num + Send + Sync + std::ops::AddAssign + Float + FromPrimitive + std::iter::Sum,
254{
255 if axes.is_empty() {
257 if executor.enable_tiled_reductions && tiled_reductions::should_use_tiling(input.shape()) {
258 let mean_val = tiled_reductions::tiled_mean_all(input)?;
259 let result = scirs2_core::ndarray_ext::Array::from_elem(
260 scirs2_core::ndarray_ext::IxDyn(&[]),
261 mean_val,
262 );
263 return Ok(DenseND::from_array(result));
264 } else {
265 let total_elements = input.view().len();
267 let sum: T = input.view().iter().cloned().sum();
268 let mean = sum / T::from_usize(total_elements).unwrap();
269 let result = scirs2_core::ndarray_ext::Array::from_elem(
270 scirs2_core::ndarray_ext::IxDyn(&[]),
271 mean,
272 );
273 return Ok(DenseND::from_array(result));
274 }
275 }
276
277 let mut result = input.view().to_owned();
279 let mut sorted_axes = axes.to_vec();
280 sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
281
282 for &axis_idx in &sorted_axes {
283 let axis = scirs2_core::ndarray_ext::Axis(axis_idx);
284 result = result
285 .mean_axis(axis)
286 .ok_or_else(|| anyhow::anyhow!("Mean computation failed"))?;
287 }
288
289 Ok(DenseND::from_array(result))
290}
291
#[derive(Clone, Copy, Debug)]
#[allow(dead_code)]
/// Element-wise unary operations supported by `optimized_unary`.
///
/// Each variant has both a SIMD kernel (`SimdUnaryOp`) and a scalar
/// fallback; the dispatch tables in `optimized_unary` must cover every
/// variant.
pub(crate) enum UnaryOpType {
    /// Arithmetic negation: `-v`.
    Neg,
    /// Absolute value.
    Abs,
    /// Natural exponential `e^v`.
    Exp,
    /// Natural logarithm `ln(v)`.
    Log,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Square root.
    Sqrt,
    /// Square: `v * v`.
    Sqr,
    /// Reciprocal: `1 / v`.
    Recip,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid: `1 / (1 + e^-v)`.
    Sigmoid,
    /// Rectified linear unit: `max(v, 0)`.
    ReLU,
    /// GELU activation (tanh approximation).
    Gelu,
    /// ELU activation with alpha = 1.
    Elu,
    /// SELU activation (standard scale/alpha constants).
    Selu,
    /// Softplus: `ln(1 + e^v)` (evaluated in a numerically stable form).
    Softplus,
    /// Sign function: +1, -1, or 0.
    Sign,
}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Negation through the unary fast path on a tensor small enough to
    /// take the scalar fallback.
    #[test]
    fn test_optimized_unary_small_tensor() {
        let executor = CpuExecutor::new();
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let negated = optimized_unary(&executor, &input, UnaryOpType::Neg).unwrap();
        let view = negated.view();

        for (idx, expected) in [-1.0, -2.0, -3.0, -4.0].iter().enumerate() {
            assert_eq!(view[[idx]], *expected);
        }
    }

    /// Element-wise addition of two equal-shape tensors.
    #[test]
    fn test_optimized_binary_same_shape() {
        let executor = CpuExecutor::new();
        let lhs = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rhs = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[4]).unwrap();

        let total = optimized_binary(&executor, &lhs, &rhs, BinaryOp::Add).unwrap();
        let view = total.view();

        for (idx, expected) in [6.0, 8.0, 10.0, 12.0].iter().enumerate() {
            assert_eq!(view[[idx]], *expected);
        }
    }

    /// Full sum reduction (empty axis list) yields a 0-d tensor.
    #[test]
    fn test_optimized_sum_all() {
        let executor = CpuExecutor::new();
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let total = optimized_sum(&executor, &input, &[]).unwrap();

        assert_eq!(total.view()[[]], 15.0);
    }

    /// Full mean reduction (empty axis list) yields a 0-d tensor.
    #[test]
    fn test_optimized_mean_all() {
        let executor = CpuExecutor::new();
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let avg = optimized_mean(&executor, &input, &[]).unwrap();

        assert_eq!(avg.view()[[]], 3.0);
    }

    /// With optimizations disabled the scalar fallback still produces
    /// correct results.
    #[test]
    fn test_optimization_disabled() {
        let executor = CpuExecutor::unoptimized();
        let input = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();

        let exps = optimized_unary(&executor, &input, UnaryOpType::Exp).unwrap();

        assert!((exps.view()[[0]] - std::f64::consts::E).abs() < 1e-10);
    }
}