use scirs2_core::ndarray::Array1;
use sklears_core::types::Float;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

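/// SIMD-accelerated element-wise operations on `Array1<Float>`.
///
/// Each public method dispatches at compile time (via `cfg` on
/// `target_feature`) to an AVX2, AVX, SSE2, or NEON implementation, and
/// falls back to scalar `ndarray` arithmetic on all other targets.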
pub struct SimdOps;

impl SimdOps {
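    /// Adds two arrays element-wise using the widest SIMD path enabled at
    /// compile time. The `cfg` blocks below are mutually exclusive, so the
    /// single block that survives compilation is the function's tail
    /// expression.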
    pub fn add_arrays(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        debug_assert_eq!(a.len(), b.len(), "Arrays must have the same length");

        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            Self::add_arrays_avx2(a, b)
        }

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx",
            not(target_feature = "avx2")
        ))]
        {
            Self::add_arrays_avx(a, b)
        }

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "sse2",
            not(target_feature = "avx")
        ))]
        {
            Self::add_arrays_sse2(a, b)
        }

        #[cfg(target_arch = "aarch64")]
        {
            Self::add_arrays_neon(a, b)
        }

        #[cfg(not(any(
            all(target_arch = "x86_64", target_feature = "sse2"),
            target_arch = "aarch64"
        )))]
        {
            Self::add_arrays_scalar(a, b)
        }
    }

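    /// Multiplies every element of `array` by `scalar`, dispatching to the
    /// same per-target SIMD implementations as `add_arrays`.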
    pub fn scalar_multiply(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            Self::scalar_multiply_avx2(array, scalar)
        }

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx",
            not(target_feature = "avx2")
        ))]
        {
            Self::scalar_multiply_avx(array, scalar)
        }

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "sse2",
            not(target_feature = "avx")
        ))]
        {
            Self::scalar_multiply_sse2(array, scalar)
        }

        #[cfg(target_arch = "aarch64")]
        {
            Self::scalar_multiply_neon(array, scalar)
        }

        #[cfg(not(any(
            all(target_arch = "x86_64", target_feature = "sse2"),
            target_arch = "aarch64"
        )))]
        {
            Self::scalar_multiply_scalar(array, scalar)
        }
    }

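    /// Computes the weighted sum of `arrays` by chaining `scalar_multiply`
    /// and `add_arrays`, so each term reuses the SIMD paths above.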
    pub fn weighted_sum(arrays: &[&Array1<Float>], weights: &[Float]) -> Array1<Float> {
        debug_assert_eq!(
            arrays.len(),
            weights.len(),
            "Arrays and weights must have same length"
        );
        debug_assert!(!arrays.is_empty(), "Must have at least one array");

        let len = arrays[0].len();
        debug_assert!(
            arrays.iter().all(|a| a.len() == len),
            "All arrays must have same length"
        );

        let mut result = Array1::zeros(len);

        for (array, &weight) in arrays.iter().zip(weights.iter()) {
            let weighted_array = Self::scalar_multiply(array, weight);
            result = Self::add_arrays(&result, &weighted_array);
        }

        result
    }

    fn add_arrays_scalar(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        a + b
    }

    fn scalar_multiply_scalar(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        array * scalar
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    fn add_arrays_avx2(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        unsafe {
            let len = a.len();
            let mut result = Array1::zeros(len);
            let a_slice = a.as_slice().unwrap();
            let b_slice = b.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            // A 256-bit register holds four f64 lanes, so step by 4.
            let simd_len = len & !3;
            for i in (0..simd_len).step_by(4) {
                let a_vec = _mm256_loadu_pd(&a_slice[i] as *const f64);
                let b_vec = _mm256_loadu_pd(&b_slice[i] as *const f64);
                let sum = _mm256_add_pd(a_vec, b_vec);
                _mm256_storeu_pd(&mut result_slice[i] as *mut f64, sum);
            }

            // Scalar tail for the remaining elements.
            for i in simd_len..len {
                result_slice[i] = a_slice[i] + b_slice[i];
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    fn scalar_multiply_avx2(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        unsafe {
            let len = array.len();
            let mut result = Array1::zeros(len);
            let array_slice = array.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let scalar_vec = _mm256_set1_pd(scalar);
            // Four f64 lanes per 256-bit register, so step by 4.
            let simd_len = len & !3;
            for i in (0..simd_len).step_by(4) {
                let array_vec = _mm256_loadu_pd(&array_slice[i] as *const f64);
                let product = _mm256_mul_pd(array_vec, scalar_vec);
                _mm256_storeu_pd(&mut result_slice[i] as *mut f64, product);
            }

            // Scalar tail for the remaining elements.
            for i in simd_len..len {
                result_slice[i] = array_slice[i] * scalar;
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
    fn add_arrays_avx(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        unsafe {
            let len = a.len();
            let mut result = Array1::zeros(len);
            let a_slice = a.as_slice().unwrap();
            let b_slice = b.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let simd_len = len & !3;
            for i in (0..simd_len).step_by(4) {
                let a_vec = _mm256_loadu_pd(&a_slice[i] as *const f64);
                let b_vec = _mm256_loadu_pd(&b_slice[i] as *const f64);
                let sum = _mm256_add_pd(a_vec, b_vec);
                _mm256_storeu_pd(&mut result_slice[i] as *mut f64, sum);
            }

            for i in simd_len..len {
                result_slice[i] = a_slice[i] + b_slice[i];
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
    fn scalar_multiply_avx(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        unsafe {
            let len = array.len();
            let mut result = Array1::zeros(len);
            let array_slice = array.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let scalar_vec = _mm256_set1_pd(scalar);
            let simd_len = len & !3;
            for i in (0..simd_len).step_by(4) {
                let array_vec = _mm256_loadu_pd(&array_slice[i] as *const f64);
                let product = _mm256_mul_pd(array_vec, scalar_vec);
                _mm256_storeu_pd(&mut result_slice[i] as *mut f64, product);
            }

            for i in simd_len..len {
                result_slice[i] = array_slice[i] * scalar;
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
    fn add_arrays_sse2(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        unsafe {
            let len = a.len();
            let mut result = Array1::zeros(len);
            let a_slice = a.as_slice().unwrap();
            let b_slice = b.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let simd_len = len & !1;
            for i in (0..simd_len).step_by(2) {
                let a_vec = _mm_loadu_pd(&a_slice[i] as *const f64);
                let b_vec = _mm_loadu_pd(&b_slice[i] as *const f64);
                let sum = _mm_add_pd(a_vec, b_vec);
                _mm_storeu_pd(&mut result_slice[i] as *mut f64, sum);
            }

            for i in simd_len..len {
                result_slice[i] = a_slice[i] + b_slice[i];
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
    fn scalar_multiply_sse2(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        unsafe {
            let len = array.len();
            let mut result = Array1::zeros(len);
            let array_slice = array.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let scalar_vec = _mm_set1_pd(scalar);
            let simd_len = len & !1;
            for i in (0..simd_len).step_by(2) {
                let array_vec = _mm_loadu_pd(&array_slice[i] as *const f64);
                let product = _mm_mul_pd(array_vec, scalar_vec);
                _mm_storeu_pd(&mut result_slice[i] as *mut f64, product);
            }

            for i in simd_len..len {
                result_slice[i] = array_slice[i] * scalar;
            }

            result
        }
    }

    #[cfg(target_arch = "aarch64")]
    fn add_arrays_neon(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        unsafe {
            let len = a.len();
            let mut result = Array1::zeros(len);
            let a_slice = a.as_slice().unwrap();
            let b_slice = b.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let simd_len = len & !1;
            for i in (0..simd_len).step_by(2) {
                let a_vec = vld1q_f64(&a_slice[i] as *const f64);
                let b_vec = vld1q_f64(&b_slice[i] as *const f64);
                let sum = vaddq_f64(a_vec, b_vec);
                vst1q_f64(&mut result_slice[i] as *mut f64, sum);
            }

            for i in simd_len..len {
                result_slice[i] = a_slice[i] + b_slice[i];
            }

            result
        }
    }

    #[cfg(target_arch = "aarch64")]
    fn scalar_multiply_neon(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        unsafe {
            let len = array.len();
            let mut result = Array1::zeros(len);
            let array_slice = array.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let scalar_vec = vdupq_n_f64(scalar);
            let simd_len = len & !1;
            for i in (0..simd_len).step_by(2) {
                let array_vec = vld1q_f64(&array_slice[i] as *const f64);
                let product = vmulq_f64(array_vec, scalar_vec);
                vst1q_f64(&mut result_slice[i] as *mut f64, product);
            }

            for i in simd_len..len {
                result_slice[i] = array_slice[i] * scalar;
            }

            result
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_add_arrays() {
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = array![2.0, 3.0, 4.0, 5.0, 6.0];
        let result = SimdOps::add_arrays(&a, &b);
        let expected = array![3.0, 5.0, 7.0, 9.0, 11.0];

        assert_eq!(result, expected);
    }

    #[test]
    fn test_scalar_multiply() {
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = SimdOps::scalar_multiply(&a, 2.0);
        let expected = array![2.0, 4.0, 6.0, 8.0, 10.0];

        assert_eq!(result, expected);
    }

    #[test]
    fn test_weighted_sum() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let arrays = vec![&a, &b];
        let weights = vec![0.5, 0.5];

        let result = SimdOps::weighted_sum(&arrays, &weights);
        let expected = array![2.5, 3.5, 4.5];

        for (actual, expected) in result.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_large_array_operations() {
        let size = 1000;
        let a = Array1::from_elem(size, 1.0);
        let b = Array1::from_elem(size, 2.0);

        let result = SimdOps::add_arrays(&a, &b);
        assert_eq!(result.len(), size);
        assert!(result.iter().all(|&x| (x - 3.0).abs() < 1e-10));

        let scaled = SimdOps::scalar_multiply(&a, 5.0);
        assert!(scaled.iter().all(|&x| (x - 5.0).abs() < 1e-10));
    }
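
    // Added check (a minimal sketch reusing only the existing API): a length
    // that is not a multiple of any SIMD lane width (2 or 4 f64 lanes)
    // exercises the scalar remainder loops at the end of each vectorized path.
    #[test]
    fn test_length_not_multiple_of_lane_width() {
        let size = 1003; // odd, so not divisible by 2 or 4
        let a = Array1::from_elem(size, 1.5);
        let b = Array1::from_elem(size, 0.5);

        let sum = SimdOps::add_arrays(&a, &b);
        assert_eq!(sum.len(), size);
        assert!(sum.iter().all(|&x| (x - 2.0).abs() < 1e-10));

        let scaled = SimdOps::scalar_multiply(&b, 4.0);
        assert!(scaled.iter().all(|&x| (x - 2.0).abs() < 1e-10));
    }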
}