sklears_ensemble/simd_ops.rs

//! SIMD optimizations for ensemble operations
//!
//! This module provides SIMD-accelerated implementations of common ensemble operations
//! such as array addition, scalar multiplication, and weighted averaging.
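//!
//! The SIMD code paths below assume `Float` is `f64`. Illustrative usage (the crate
//! and module path are assumed here, so the snippet is marked `ignore` rather than
//! compiled as a doctest):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//! use sklears_ensemble::simd_ops::SimdOps;
//!
//! let a = array![1.0, 2.0, 3.0];
//! let b = array![4.0, 5.0, 6.0];
//! let sum = SimdOps::add_arrays(&a, &b);                    // [5.0, 7.0, 9.0]
//! let scaled = SimdOps::scalar_multiply(&a, 2.0);           // [2.0, 4.0, 6.0]
//! let avg = SimdOps::weighted_sum(&[&a, &b], &[0.5, 0.5]);  // [2.5, 3.5, 4.5]
//! ```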

use scirs2_core::ndarray::Array1;
use sklears_core::types::Float;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

/// SIMD-optimized array operations for ensemble methods
pub struct SimdOps;

impl SimdOps {
    /// Add two arrays with SIMD acceleration when available
    pub fn add_arrays(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        debug_assert_eq!(a.len(), b.len(), "Arrays must have the same length");

        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            Self::add_arrays_avx2(a, b)
        }

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx",
            not(target_feature = "avx2")
        ))]
        {
            Self::add_arrays_avx(a, b)
        }

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "sse2",
            not(target_feature = "avx")
        ))]
        {
            Self::add_arrays_sse2(a, b)
        }

        #[cfg(target_arch = "aarch64")]
        {
            Self::add_arrays_neon(a, b)
        }

        #[cfg(not(any(
            all(target_arch = "x86_64", target_feature = "sse2"),
            target_arch = "aarch64"
        )))]
        {
            Self::add_arrays_scalar(a, b)
        }
    }

    /// Multiply array by scalar with SIMD acceleration when available
    pub fn scalar_multiply(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            Self::scalar_multiply_avx2(array, scalar)
        }

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx",
            not(target_feature = "avx2")
        ))]
        {
            Self::scalar_multiply_avx(array, scalar)
        }

        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "sse2",
            not(target_feature = "avx")
        ))]
        {
            Self::scalar_multiply_sse2(array, scalar)
        }

        #[cfg(target_arch = "aarch64")]
        {
            Self::scalar_multiply_neon(array, scalar)
        }

        #[cfg(not(any(
            all(target_arch = "x86_64", target_feature = "sse2"),
            target_arch = "aarch64"
        )))]
        {
            Self::scalar_multiply_scalar(array, scalar)
        }
    }

    /// Weighted sum of multiple arrays with SIMD acceleration
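    ///
    /// Computes the element-wise sum `weights[0] * arrays[0] + weights[1] * arrays[1] + ...`
    /// by reusing the SIMD-dispatched `scalar_multiply` and `add_arrays` above.
    /// Illustrative call, with `preds_a` and `preds_b` as placeholder prediction
    /// arrays (a sketch, not a compiled doctest):
    ///
    /// ```ignore
    /// // 0.25 * [4.0, 8.0] + 0.75 * [0.0, 4.0] == [1.0, 5.0]
    /// let blended = SimdOps::weighted_sum(&[&preds_a, &preds_b], &[0.25, 0.75]);
    /// ```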
    pub fn weighted_sum(arrays: &[&Array1<Float>], weights: &[Float]) -> Array1<Float> {
        debug_assert_eq!(
            arrays.len(),
            weights.len(),
            "Arrays and weights must have the same length"
        );
        debug_assert!(!arrays.is_empty(), "Must have at least one array");

        let len = arrays[0].len();
        debug_assert!(
            arrays.iter().all(|a| a.len() == len),
            "All arrays must have the same length"
        );

        let mut result = Array1::zeros(len);

        for (array, &weight) in arrays.iter().zip(weights.iter()) {
            let weighted_array = Self::scalar_multiply(array, weight);
            result = Self::add_arrays(&result, &weighted_array);
        }

        result
    }

    /// Scalar implementation for fallback
    fn add_arrays_scalar(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        a + b
    }

    /// Scalar implementation for fallback
    fn scalar_multiply_scalar(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        array * scalar
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    fn add_arrays_avx2(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        unsafe {
            let len = a.len();
            let mut result = Array1::zeros(len);
            let a_slice = a.as_slice().unwrap();
            let b_slice = b.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            // A 256-bit register holds 4 f64 lanes, so process 4 elements per iteration
            let simd_len = len & !3;

            for i in (0..simd_len).step_by(4) {
                let a_vec = _mm256_loadu_pd(a_slice.as_ptr().add(i));
                let b_vec = _mm256_loadu_pd(b_slice.as_ptr().add(i));
                let sum = _mm256_add_pd(a_vec, b_vec);
                _mm256_storeu_pd(result_slice.as_mut_ptr().add(i), sum);
            }

            // Handle remaining elements
            for i in simd_len..len {
                result_slice[i] = a_slice[i] + b_slice[i];
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    fn scalar_multiply_avx2(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        unsafe {
            let len = array.len();
            let mut result = Array1::zeros(len);
            let array_slice = array.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let scalar_vec = _mm256_set1_pd(scalar);
            // A 256-bit register holds 4 f64 lanes, so process 4 elements per iteration
            let simd_len = len & !3;

            for i in (0..simd_len).step_by(4) {
                let array_vec = _mm256_loadu_pd(array_slice.as_ptr().add(i));
                let product = _mm256_mul_pd(array_vec, scalar_vec);
                _mm256_storeu_pd(result_slice.as_mut_ptr().add(i), product);
            }

            // Handle remaining elements
            for i in simd_len..len {
                result_slice[i] = array_slice[i] * scalar;
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
    fn add_arrays_avx(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        unsafe {
            let len = a.len();
            let mut result = Array1::zeros(len);
            let a_slice = a.as_slice().unwrap();
            let b_slice = b.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let simd_len = len & !3; // Process 4 elements at a time

            for i in (0..simd_len).step_by(4) {
                let a_vec = _mm256_loadu_pd(a_slice.as_ptr().add(i));
                let b_vec = _mm256_loadu_pd(b_slice.as_ptr().add(i));
                let sum = _mm256_add_pd(a_vec, b_vec);
                _mm256_storeu_pd(result_slice.as_mut_ptr().add(i), sum);
            }

            // Handle remaining elements
            for i in simd_len..len {
                result_slice[i] = a_slice[i] + b_slice[i];
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
    fn scalar_multiply_avx(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        unsafe {
            let len = array.len();
            let mut result = Array1::zeros(len);
            let array_slice = array.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let scalar_vec = _mm256_set1_pd(scalar);
            let simd_len = len & !3; // Process 4 elements at a time

            for i in (0..simd_len).step_by(4) {
                let array_vec = _mm256_loadu_pd(array_slice.as_ptr().add(i));
                let product = _mm256_mul_pd(array_vec, scalar_vec);
                _mm256_storeu_pd(result_slice.as_mut_ptr().add(i), product);
            }

            // Handle remaining elements
            for i in simd_len..len {
                result_slice[i] = array_slice[i] * scalar;
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
    fn add_arrays_sse2(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        unsafe {
            let len = a.len();
            let mut result = Array1::zeros(len);
            let a_slice = a.as_slice().unwrap();
            let b_slice = b.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let simd_len = len & !1; // Process 2 elements at a time

            for i in (0..simd_len).step_by(2) {
                let a_vec = _mm_loadu_pd(a_slice.as_ptr().add(i));
                let b_vec = _mm_loadu_pd(b_slice.as_ptr().add(i));
                let sum = _mm_add_pd(a_vec, b_vec);
                _mm_storeu_pd(result_slice.as_mut_ptr().add(i), sum);
            }

            // Handle remaining elements
            for i in simd_len..len {
                result_slice[i] = a_slice[i] + b_slice[i];
            }

            result
        }
    }

    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
    fn scalar_multiply_sse2(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        unsafe {
            let len = array.len();
            let mut result = Array1::zeros(len);
            let array_slice = array.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let scalar_vec = _mm_set1_pd(scalar);
            let simd_len = len & !1; // Process 2 elements at a time

            for i in (0..simd_len).step_by(2) {
                let array_vec = _mm_loadu_pd(array_slice.as_ptr().add(i));
                let product = _mm_mul_pd(array_vec, scalar_vec);
                _mm_storeu_pd(result_slice.as_mut_ptr().add(i), product);
            }

            // Handle remaining elements
            for i in simd_len..len {
                result_slice[i] = array_slice[i] * scalar;
            }

            result
        }
    }

    #[cfg(target_arch = "aarch64")]
    fn add_arrays_neon(a: &Array1<Float>, b: &Array1<Float>) -> Array1<Float> {
        unsafe {
            let len = a.len();
            let mut result = Array1::zeros(len);
            let a_slice = a.as_slice().unwrap();
            let b_slice = b.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let simd_len = len & !1; // Process 2 elements at a time

            for i in (0..simd_len).step_by(2) {
                let a_vec = vld1q_f64(a_slice.as_ptr().add(i));
                let b_vec = vld1q_f64(b_slice.as_ptr().add(i));
                let sum = vaddq_f64(a_vec, b_vec);
                vst1q_f64(result_slice.as_mut_ptr().add(i), sum);
            }

            // Handle remaining elements
            for i in simd_len..len {
                result_slice[i] = a_slice[i] + b_slice[i];
            }

            result
        }
    }

    #[cfg(target_arch = "aarch64")]
    fn scalar_multiply_neon(array: &Array1<Float>, scalar: Float) -> Array1<Float> {
        unsafe {
            let len = array.len();
            let mut result = Array1::zeros(len);
            let array_slice = array.as_slice().unwrap();
            let result_slice = result.as_slice_mut().unwrap();

            let scalar_vec = vdupq_n_f64(scalar);
            let simd_len = len & !1; // Process 2 elements at a time

            for i in (0..simd_len).step_by(2) {
                let array_vec = vld1q_f64(array_slice.as_ptr().add(i));
                let product = vmulq_f64(array_vec, scalar_vec);
                vst1q_f64(result_slice.as_mut_ptr().add(i), product);
            }

            // Handle remaining elements
            for i in simd_len..len {
                result_slice[i] = array_slice[i] * scalar;
            }

            result
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_add_arrays() {
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = array![2.0, 3.0, 4.0, 5.0, 6.0];
        let result = SimdOps::add_arrays(&a, &b);
        let expected = array![3.0, 5.0, 7.0, 9.0, 11.0];

        assert_eq!(result, expected);
    }

    #[test]
    fn test_scalar_multiply() {
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = SimdOps::scalar_multiply(&a, 2.0);
        let expected = array![2.0, 4.0, 6.0, 8.0, 10.0];

        assert_eq!(result, expected);
    }

    #[test]
    fn test_weighted_sum() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let arrays = vec![&a, &b];
        let weights = vec![0.5, 0.5];

        let result = SimdOps::weighted_sum(&arrays, &weights);
        let expected = array![2.5, 3.5, 4.5];

        for (actual, expected) in result.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_large_array_operations() {
        let size = 1000;
        let a = Array1::from_elem(size, 1.0);
        let b = Array1::from_elem(size, 2.0);

        let result = SimdOps::add_arrays(&a, &b);
        assert_eq!(result.len(), size);
        assert!(result.iter().all(|&x| (x - 3.0).abs() < 1e-10));

        let scaled = SimdOps::scalar_multiply(&a, 5.0);
        assert!(scaled.iter().all(|&x| (x - 5.0).abs() < 1e-10));
    }
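
    // Illustrative extra test (a sketch): a length of 7 is not a multiple of any SIMD
    // lane count used above, so this exercises both the vectorized loops and the
    // scalar remainder handling. Assumes `Float` is `f64`, as the intrinsics do.
    #[test]
    fn test_remainder_elements() {
        let a: Array1<Float> = (0..7).map(|i| i as Float).collect();
        let b: Array1<Float> = (0..7).map(|i| (2 * i) as Float).collect();

        let sum = SimdOps::add_arrays(&a, &b);
        let scaled = SimdOps::scalar_multiply(&a, 3.0);

        for i in 0..7 {
            // a[i] = i and b[i] = 2 * i, so both a[i] + b[i] and 3 * a[i] equal 3 * i
            assert!((sum[i] - 3.0 * i as Float).abs() < 1e-10);
            assert!((scaled[i] - 3.0 * i as Float).abs() < 1e-10);
        }
    }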
}