scirs2_core/simd/arithmetic.rs

//! Arithmetic operations with SIMD acceleration
//!
//! This module provides optimized implementations of arithmetic operations.
//! Currently it includes scalar multiplication and linspace helpers, with more
//! operations to be added.

use ndarray::{Array1, ArrayView1};

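/// Multiply each element of an f32 array by a scalar, using SIMD instructions
/// (AVX2/SSE on x86_64, NEON on aarch64) when available and falling back to a
/// scalar loop otherwise.
///
/// # Panics
///
/// Panics if the input view is not contiguous in memory.
///
/// # Examples
///
/// A minimal usage sketch (the `scirs2_core::simd::arithmetic` import path is
/// assumed from this file's location and may differ in the published crate):
///
/// ```ignore
/// use ndarray::array;
/// use scirs2_core::simd::arithmetic::simd_scalar_mul_f32;
///
/// let a = array![1.0_f32, 2.0, 3.0];
/// let doubled = simd_scalar_mul_f32(&a.view(), 2.0);
/// assert_eq!(doubled, array![2.0_f32, 4.0, 6.0]);
/// ```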
pub fn simd_scalar_mul_f32(a: &ArrayView1<f32>, scalar: f32) -> Array1<f32> {
    let len = a.len();
    let mut result = Array1::zeros(len);
    let a_slice = a.as_slice().expect("input array must be contiguous");
    let result_slice: &mut [f32] = result
        .as_slice_mut()
        .expect("freshly allocated array is contiguous");

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if is_x86_feature_detected!("avx2") {
            unsafe {
                let scalar_vec = _mm256_set1_ps(scalar);
                let mut i = 0;

                // Process 8 f32s at a time with AVX2 - direct pointer writes
                while i + 8 <= len {
                    let a_vec = _mm256_loadu_ps(a_slice.as_ptr().add(i));
                    let result_vec = _mm256_mul_ps(a_vec, scalar_vec);
                    _mm256_storeu_ps(result_slice.as_mut_ptr().add(i), result_vec);
                    i += 8;
                }

                // Handle remaining elements
                for j in i..len {
                    result_slice[j] = a_slice[j] * scalar;
                }

                return result;
            }
        } else if is_x86_feature_detected!("sse") {
            unsafe {
                let scalar_vec = _mm_set1_ps(scalar);
                let mut i = 0;

                // Process 4 f32s at a time with SSE
                while i + 4 <= len {
                    let a_vec = _mm_loadu_ps(a_slice.as_ptr().add(i));
                    let result_vec = _mm_mul_ps(a_vec, scalar_vec);
                    _mm_storeu_ps(result_slice.as_mut_ptr().add(i), result_vec);
                    i += 4;
                }

                // Handle remaining elements
                for j in i..len {
                    result_slice[j] = a_slice[j] * scalar;
                }

                return result;
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        use std::arch::aarch64::*;

        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                let scalar_vec = vdupq_n_f32(scalar);
                let mut i = 0;

                // Process 4 f32s at a time with NEON
                while i + 4 <= len {
                    let a_vec = vld1q_f32(a_slice.as_ptr().add(i));
                    let result_vec = vmulq_f32(a_vec, scalar_vec);
                    vst1q_f32(result_slice.as_mut_ptr().add(i), result_vec);
                    i += 4;
                }

                // Handle remaining elements
                for j in i..len {
                    result_slice[j] = a_slice[j] * scalar;
                }

                return result;
            }
        }
    }

    // Fallback to scalar implementation
    for i in 0..len {
        result_slice[i] = a_slice[i] * scalar;
    }

    result
}

/// Multiply each element of an f64 array by a scalar, using SIMD instructions
/// (AVX2/SSE2 on x86_64, NEON on aarch64) when available and falling back to a
/// scalar loop otherwise.
///
/// # Panics
///
/// Panics if the input view is not contiguous in memory.
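///
/// # Examples
///
/// A minimal usage sketch (the `scirs2_core::simd::arithmetic` import path is
/// assumed from this file's location and may differ in the published crate):
///
/// ```ignore
/// use ndarray::array;
/// use scirs2_core::simd::arithmetic::simd_scalar_mul_f64;
///
/// let a = array![1.5_f64, -2.0, 4.0];
/// let scaled = simd_scalar_mul_f64(&a.view(), 2.0);
/// assert_eq!(scaled, array![3.0_f64, -4.0, 8.0]);
/// ```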
#[allow(dead_code)]
pub fn simd_scalar_mul_f64(a: &ArrayView1<f64>, scalar: f64) -> Array1<f64> {
    let len = a.len();
    let mut result = Array1::zeros(len);
    let a_slice = a.as_slice().expect("input array must be contiguous");
    let result_slice: &mut [f64] = result
        .as_slice_mut()
        .expect("freshly allocated array is contiguous");

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if is_x86_feature_detected!("avx2") {
            unsafe {
                let scalar_vec = _mm256_set1_pd(scalar);
                let mut i = 0;

                // Process 4 f64s at a time with AVX2 - direct pointer writes
                while i + 4 <= len {
                    let a_vec = _mm256_loadu_pd(a_slice.as_ptr().add(i));
                    let result_vec = _mm256_mul_pd(a_vec, scalar_vec);
                    _mm256_storeu_pd(result_slice.as_mut_ptr().add(i), result_vec);
                    i += 4;
                }

                // Handle remaining elements
                for j in i..len {
                    result_slice[j] = a_slice[j] * scalar;
                }

                return result;
            }
        } else if is_x86_feature_detected!("sse2") {
            unsafe {
                let scalar_vec = _mm_set1_pd(scalar);
                let mut i = 0;

                // Process 2 f64s at a time with SSE2
                while i + 2 <= len {
                    let a_vec = _mm_loadu_pd(a_slice.as_ptr().add(i));
                    let result_vec = _mm_mul_pd(a_vec, scalar_vec);
                    _mm_storeu_pd(result_slice.as_mut_ptr().add(i), result_vec);
                    i += 2;
                }

                // Handle remaining elements
                for j in i..len {
                    result_slice[j] = a_slice[j] * scalar;
                }

                return result;
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        use std::arch::aarch64::*;

        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                let scalar_vec = vdupq_n_f64(scalar);
                let mut i = 0;

                // Process 2 f64s at a time with NEON
                while i + 2 <= len {
                    let a_vec = vld1q_f64(a_slice.as_ptr().add(i));
                    let result_vec = vmulq_f64(a_vec, scalar_vec);
                    vst1q_f64(result_slice.as_mut_ptr().add(i), result_vec);
                    i += 2;
                }

                // Handle remaining elements
                for j in i..len {
                    result_slice[j] = a_slice[j] * scalar;
                }

                return result;
            }
        }
    }

    // Fallback to scalar implementation
    for i in 0..len {
        result_slice[i] = a_slice[i] * scalar;
    }

    result
}

/// Linspace function for f32 values
///
/// Creates a linearly spaced array of `num` points between `startval` and
/// `end` (inclusive). The implementation is currently a scalar loop; SIMD
/// acceleration may be added later.
///
/// # Arguments
///
/// * `startval` - Start value
/// * `end` - End value (inclusive)
/// * `num` - Number of points
///
/// # Returns
///
/// * Array of linearly spaced values
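///
/// # Examples
///
/// A minimal usage sketch (the `scirs2_core::simd::arithmetic` import path is
/// assumed from this file's location and may differ in the published crate):
///
/// ```ignore
/// use scirs2_core::simd::arithmetic::linspace_f32;
///
/// let xs = linspace_f32(0.0, 1.0, 5);
/// assert_eq!(xs.len(), 5);
/// assert_eq!(xs[0], 0.0);
/// assert_eq!(xs[4], 1.0);
/// ```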
#[allow(dead_code)]
pub fn linspace_f32(startval: f32, end: f32, num: usize) -> Array1<f32> {
    if num == 0 {
        return Array1::zeros(0);
    }
    if num == 1 {
        return Array1::from_vec(vec![startval]);
    }

    let mut result = Array1::zeros(num);
    let step = (end - startval) / (num as f32 - 1.0);

    // Use scalar implementation for now - could be optimized with SIMD
    for (i, elem) in result.iter_mut().enumerate() {
        *elem = startval + step * i as f32;
    }

    // Make sure the last value is exactly end to avoid floating point precision issues
    if let Some(last) = result.last_mut() {
        *last = end;
    }

    result
}

/// Linspace function for f64 values
///
/// Creates a linearly spaced array of `num` points between `startval` and
/// `end` (inclusive), using a scalar loop for now.
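///
/// # Examples
///
/// A minimal usage sketch (the import path is assumed, as above):
///
/// ```ignore
/// use scirs2_core::simd::arithmetic::linspace_f64;
///
/// let ys = linspace_f64(-1.0, 1.0, 3);
/// assert_eq!(ys.len(), 3);
/// assert_eq!(ys[1], 0.0);
/// assert_eq!(ys[2], 1.0);
/// ```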
#[allow(dead_code)]
pub fn linspace_f64(startval: f64, end: f64, num: usize) -> Array1<f64> {
    if num == 0 {
        return Array1::zeros(0);
    }
    if num == 1 {
        return Array1::from_vec(vec![startval]);
    }

    let mut result = Array1::zeros(num);
    let step = (end - startval) / (num as f64 - 1.0);

    // Use scalar implementation for now - could be optimized with SIMD
    for (i, elem) in result.iter_mut().enumerate() {
        *elem = startval + step * i as f64;
    }

    // Make sure the last value is exactly end to avoid floating point precision issues
    if let Some(last) = result.last_mut() {
        *last = end;
    }

    result
}
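
// Minimal sanity-test sketch: the SIMD paths should agree exactly with a plain
// scalar reference (IEEE multiplication rounds identically in both), and
// linspace should hit its endpoints. Test lengths and values are illustrative
// choices; lengths are picked so both the vectorized body and the remainder
// loop are exercised.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_mul_f32_matches_scalar_reference() {
        // 19 is not a multiple of 8 or 4, so the remainder loop runs too.
        let a = Array1::from_iter((0..19).map(|i| i as f32 * 0.5));
        let expected: Array1<f32> = a.mapv(|x| x * 3.0);
        let got = simd_scalar_mul_f32(&a.view(), 3.0);
        assert_eq!(got, expected);
    }

    #[test]
    fn scalar_mul_f64_matches_scalar_reference() {
        // 11 is not a multiple of 4 or 2, so the remainder loop runs too.
        let a = Array1::from_iter((0..11).map(|i| i as f64 * 0.25));
        let expected: Array1<f64> = a.mapv(|x| x * -2.0);
        let got = simd_scalar_mul_f64(&a.view(), -2.0);
        assert_eq!(got, expected);
    }

    #[test]
    fn linspace_endpoints_are_exact() {
        let xs = linspace_f32(0.0, 1.0, 5);
        assert_eq!(xs.len(), 5);
        assert_eq!(xs[0], 0.0);
        assert_eq!(xs[4], 1.0);

        let ys = linspace_f64(-1.0, 1.0, 3);
        assert_eq!(ys.len(), 3);
        assert_eq!(ys[1], 0.0);
        assert_eq!(ys[2], 1.0);
    }
}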