1#![allow(dead_code)]
19
20use anyhow::Result;
21use scirs2_core::ndarray_ext::{Array, ArrayView, IxDyn, Zip};
22use scirs2_core::numeric::{Float, FromPrimitive, Num};
23use tenrso_core::DenseND;
24
/// Minimum total element count at which [`should_use_simd`] reports `true`.
const SIMD_THRESHOLD: usize = 1024;

/// Reports whether a tensor of the given `shape` is large enough to take the
/// SIMD-oriented code path.
///
/// The element count is the product of all dimensions (an empty `shape`
/// counts as one element). The product is accumulated with saturating
/// multiplication so that absurdly large shapes neither panic in debug builds
/// nor wrap around to a small number in release builds; a zero-sized
/// dimension still yields a count of zero.
#[inline]
pub(crate) fn should_use_simd(shape: &[usize]) -> bool {
    let total_elements = shape
        .iter()
        .fold(1usize, |count, &dim| count.saturating_mul(dim));
    total_elements >= SIMD_THRESHOLD
}
35
/// Selector for the element-wise unary operation performed by `simd_unary`.
///
/// The per-variant formulas below describe what the dispatcher in this file
/// computes for each element `x`.
// Presumably not every variant is constructed by callers yet, hence the allow.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SimdUnaryOp {
    /// Negation: `-x`.
    Neg,
    /// Absolute value: `|x|`.
    Abs,
    /// Natural exponential: `e^x`.
    Exp,
    /// Natural logarithm: `ln(x)`.
    Log,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Square root.
    Sqrt,
    /// Square: `x * x`.
    Sqr,
    /// Reciprocal: `1 / x`.
    Recip,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid: `1 / (1 + e^-x)`.
    Sigmoid,
    /// Rectified linear unit: `x` when `x > 0`, otherwise `0`.
    ReLU,
    /// Gaussian error linear unit (tanh approximation).
    Gelu,
    /// Exponential linear unit with alpha fixed at 1.
    Elu,
    /// Scaled exponential linear unit.
    Selu,
    /// Softplus: `ln(1 + e^x)`.
    Softplus,
    /// Sign: `1`, `-1`, or `0`.
    Sign,
}
57
58pub(crate) fn simd_unary<T>(input: &DenseND<T>, op: SimdUnaryOp) -> Result<DenseND<T>>
66where
67 T: Clone + Num + Float + FromPrimitive + Send + Sync,
68{
69 let input_view = input.view();
70
71 let result = match op {
74 SimdUnaryOp::Neg => input_view.mapv(|v| -v),
75 SimdUnaryOp::Abs => input_view.mapv(|v| v.abs()),
76 SimdUnaryOp::Exp => simd_exp(&input_view),
77 SimdUnaryOp::Log => simd_log(&input_view),
78 SimdUnaryOp::Sin => input_view.mapv(|v| v.sin()),
79 SimdUnaryOp::Cos => input_view.mapv(|v| v.cos()),
80 SimdUnaryOp::Sqrt => simd_sqrt(&input_view),
81 SimdUnaryOp::Sqr => simd_sqr(&input_view),
82 SimdUnaryOp::Recip => simd_recip(&input_view),
83 SimdUnaryOp::Tanh => input_view.mapv(|v| v.tanh()),
84 SimdUnaryOp::Sigmoid => simd_sigmoid(&input_view),
85 SimdUnaryOp::ReLU => simd_relu(&input_view),
86 SimdUnaryOp::Gelu => simd_gelu(&input_view),
87 SimdUnaryOp::Elu => simd_elu(&input_view),
88 SimdUnaryOp::Selu => simd_selu(&input_view),
89 SimdUnaryOp::Softplus => simd_softplus(&input_view),
90 SimdUnaryOp::Sign => simd_sign(&input_view),
91 };
92
93 Ok(DenseND::from_array(result))
94}
95
96#[inline]
98fn simd_exp<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
99where
100 T: Clone + Float,
101{
102 input.mapv(|v| v.exp())
104}
105
106#[inline]
108fn simd_log<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
109where
110 T: Clone + Float,
111{
112 input.mapv(|v| v.ln())
113}
114
115#[inline]
117fn simd_sqrt<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
118where
119 T: Clone + Float,
120{
121 input.mapv(|v| v.sqrt())
122}
123
124#[inline]
126fn simd_sqr<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
127where
128 T: Clone + Float,
129{
130 input.mapv(|v| v * v)
132}
133
134#[inline]
136fn simd_recip<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
137where
138 T: Clone + Float,
139{
140 input.mapv(|v| v.recip())
141}
142
143#[inline]
145fn simd_sigmoid<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
146where
147 T: Clone + Float + FromPrimitive,
148{
149 let one = T::one();
150 input.mapv(|v| one / (one + (-v).exp()))
151}
152
153#[inline]
155fn simd_relu<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
156where
157 T: Clone + Float,
158{
159 let zero = T::zero();
160 input.mapv(|v| if v > zero { v } else { zero })
161}
162
163#[inline]
165fn simd_gelu<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
166where
167 T: Clone + Float + FromPrimitive,
168{
169 let half = T::from_f64(0.5).unwrap_or_else(T::one);
170 let one = T::one();
171 let coeff = T::from_f64(0.7978845608028654).unwrap_or_else(T::one);
172 let cubic_coeff = T::from_f64(0.044715).unwrap_or_else(T::zero);
173
174 input.mapv(|v| {
175 let x_cubed = v * v * v;
176 let inner = coeff * (v + cubic_coeff * x_cubed);
177 half * v * (one + inner.tanh())
178 })
179}
180
181#[inline]
183fn simd_elu<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
184where
185 T: Clone + Float + FromPrimitive,
186{
187 let zero = T::zero();
188 let one = T::one();
189
190 input.mapv(|v| if v > zero { v } else { v.exp() - one })
191}
192
193#[inline]
195fn simd_selu<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
196where
197 T: Clone + Float + FromPrimitive,
198{
199 let zero = T::zero();
200 let one = T::one();
201 let scale = T::from_f64(1.050_700_987_355_480_5).unwrap_or_else(T::one);
202 let alpha = T::from_f64(1.673_263_242_354_377_2).unwrap_or_else(T::one);
203
204 input.mapv(|v| {
205 if v > zero {
206 scale * v
207 } else {
208 scale * alpha * (v.exp() - one)
209 }
210 })
211}
212
213#[inline]
215fn simd_softplus<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
216where
217 T: Clone + Float + FromPrimitive,
218{
219 let zero = T::zero();
220 let one = T::one();
221
222 input.mapv(|v| {
223 let abs_v = v.abs();
224 let max_part = if v > zero { v } else { zero };
225 max_part + (one + (-abs_v).exp()).ln()
226 })
227}
228
229#[inline]
231fn simd_sign<T>(input: &ArrayView<T, IxDyn>) -> Array<T, IxDyn>
232where
233 T: Clone + Float + FromPrimitive,
234{
235 let zero = T::zero();
236 let one = T::one();
237 let neg_one = -one;
238
239 input.mapv(|v| {
240 if v > zero {
241 one
242 } else if v < zero {
243 neg_one
244 } else {
245 zero
246 }
247 })
248}
249
/// Selector for the element-wise binary operation performed by `simd_binary`.
///
/// In the descriptions below `a` is an element of the first operand and `b`
/// the matching element of the second.
// Presumably not every variant is constructed by callers yet, hence the allow.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SimdBinaryOp {
    /// Addition: `a + b`.
    Add,
    /// Subtraction: `a - b`.
    Sub,
    /// Multiplication: `a * b`.
    Mul,
    /// Division: `a / b`.
    Div,
    /// Power: `a.powf(b)`.
    Pow,
    /// Larger operand: `a` when `a > b`, otherwise `b`.
    Maximum,
    /// Smaller operand: `a` when `a < b`, otherwise `b`.
    Minimum,
}
261
262pub(crate) fn simd_binary<T>(x: &DenseND<T>, y: &DenseND<T>, op: SimdBinaryOp) -> Result<DenseND<T>>
270where
271 T: Clone + Num + Float + Send + Sync,
272{
273 let x_view = x.view();
274 let y_view = y.view();
275
276 if x.shape() == y.shape() {
278 let result = match op {
279 SimdBinaryOp::Add => &x_view + &y_view,
280 SimdBinaryOp::Sub => &x_view - &y_view,
281 SimdBinaryOp::Mul => &x_view * &y_view,
282 SimdBinaryOp::Div => &x_view / &y_view,
283 SimdBinaryOp::Pow => Zip::from(&x_view)
284 .and(&y_view)
285 .map_collect(|&a, &b| a.powf(b)),
286 SimdBinaryOp::Maximum => {
287 Zip::from(&x_view)
288 .and(&y_view)
289 .map_collect(|&a, &b| if a > b { a } else { b })
290 }
291 SimdBinaryOp::Minimum => {
292 Zip::from(&x_view)
293 .and(&y_view)
294 .map_collect(|&a, &b| if a < b { a } else { b })
295 }
296 };
297 return Ok(DenseND::from_array(result));
298 }
299
300 let result = match op {
303 SimdBinaryOp::Add => &x_view + &y_view,
304 SimdBinaryOp::Sub => &x_view - &y_view,
305 SimdBinaryOp::Mul => &x_view * &y_view,
306 SimdBinaryOp::Div => &x_view / &y_view,
307 SimdBinaryOp::Pow => Zip::from(&x_view)
308 .and(&y_view)
309 .map_collect(|&a, &b| a.powf(b)),
310 SimdBinaryOp::Maximum => {
311 Zip::from(&x_view)
312 .and(&y_view)
313 .map_collect(|&a, &b| if a > b { a } else { b })
314 }
315 SimdBinaryOp::Minimum => {
316 Zip::from(&x_view)
317 .and(&y_view)
318 .map_collect(|&a, &b| if a < b { a } else { b })
319 }
320 };
321
322 Ok(DenseND::from_array(result))
323}
324
325#[allow(dead_code)]
330pub(crate) fn simd_fma<T>(a: &DenseND<T>, b: &DenseND<T>, c: &DenseND<T>) -> Result<DenseND<T>>
331where
332 T: Clone + Num + Float + Send + Sync + std::ops::AddAssign,
333{
334 if a.shape() != b.shape() || a.shape() != c.shape() {
335 return Err(anyhow::anyhow!(
336 "FMA requires all tensors to have the same shape"
337 ));
338 }
339
340 let a_view = a.view();
341 let b_view = b.view();
342 let c_view = c.view();
343
344 let result = Zip::from(&a_view)
346 .and(&b_view)
347 .and(&c_view)
348 .map_collect(|&a_val, &b_val, &c_val| a_val * b_val + c_val);
349
350 Ok(DenseND::from_array(result))
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for approximate floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Builds a one-dimensional `f64` tensor holding `values`.
    fn tensor_1d(values: &[f64]) -> DenseND<f64> {
        DenseND::from_vec(values.to_vec(), &[values.len()]).unwrap()
    }

    #[test]
    fn test_should_use_simd() {
        // 100 elements: below the threshold.
        assert!(!should_use_simd(&[10, 10]));
        // 1024 elements: exactly at the threshold.
        assert!(should_use_simd(&[32, 32]));
        // 10_000 elements: above the threshold.
        assert!(should_use_simd(&[100, 100]));
    }

    #[test]
    fn test_simd_unary_exp() {
        let input = tensor_1d(&[0.0, 1.0, 2.0, 3.0]);
        let output = simd_unary(&input, SimdUnaryOp::Exp).unwrap();
        let values = output.view();

        assert!((values[[0]] - 1.0).abs() < EPS);
        assert!((values[[1]] - std::f64::consts::E).abs() < EPS);
    }

    #[test]
    fn test_simd_unary_sqrt() {
        let input = tensor_1d(&[1.0, 4.0, 9.0, 16.0]);
        let output = simd_unary(&input, SimdUnaryOp::Sqrt).unwrap();
        let values = output.view();

        let expected: [f64; 4] = [1.0, 2.0, 3.0, 4.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert!((values[[idx]] - want).abs() < EPS);
        }
    }

    #[test]
    fn test_simd_unary_relu() {
        let input = tensor_1d(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        let output = simd_unary(&input, SimdUnaryOp::ReLU).unwrap();
        let values = output.view();

        let expected: [f64; 5] = [0.0, 0.0, 0.0, 1.0, 2.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(values[[idx]], want);
        }
    }

    #[test]
    fn test_simd_binary_add() {
        let lhs = tensor_1d(&[1.0, 2.0, 3.0, 4.0]);
        let rhs = tensor_1d(&[5.0, 6.0, 7.0, 8.0]);
        let output = simd_binary(&lhs, &rhs, SimdBinaryOp::Add).unwrap();
        let values = output.view();

        let expected: [f64; 4] = [6.0, 8.0, 10.0, 12.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(values[[idx]], want);
        }
    }

    #[test]
    fn test_simd_binary_mul() {
        let lhs = tensor_1d(&[1.0, 2.0, 3.0, 4.0]);
        let rhs = tensor_1d(&[2.0, 3.0, 4.0, 5.0]);
        let output = simd_binary(&lhs, &rhs, SimdBinaryOp::Mul).unwrap();
        let values = output.view();

        let expected: [f64; 4] = [2.0, 6.0, 12.0, 20.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(values[[idx]], want);
        }
    }

    #[test]
    fn test_simd_fma() {
        let a = tensor_1d(&[1.0, 2.0, 3.0]);
        let b = tensor_1d(&[2.0, 3.0, 4.0]);
        let c = tensor_1d(&[1.0, 1.0, 1.0]);
        let output = simd_fma(&a, &b, &c).unwrap();
        let values = output.view();

        // a * b + c, element by element.
        let expected: [f64; 3] = [3.0, 7.0, 13.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(values[[idx]], want);
        }
    }
}