Skip to main content

sklears_simd/
advanced_optimizations.rs

1//! Advanced SIMD optimization techniques
2//!
3//! This module provides cutting-edge optimization techniques for SIMD operations,
4//! including cache-aware algorithms, vectorization strategies, and memory-efficient
5//! implementations for high-performance machine learning computations.
6
7use crate::traits::SimdError;
8#[cfg(target_arch = "x86")]
9use core::arch::x86::*;
10#[cfg(target_arch = "x86_64")]
11use core::arch::x86_64::*;
12
13#[cfg(feature = "no-std")]
14use alloc::{vec, vec::Vec};
15#[cfg(not(feature = "no-std"))]
16use std::{vec, vec::Vec};
17
/// Parameters for a 2-D convolution.
///
/// All tensors are dense and stored channel-major (`channel, row, column`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConvolutionParams {
    /// Shape of the input tensor as `(channels, height, width)`.
    pub input_shape: (usize, usize, usize),
    /// Shape of the kernel tensor as `(filters, height, width)`.
    ///
    /// The kernel *buffer* holds one full `in_channels × height × width` patch of
    /// weights per filter, i.e. `filters * in_channels * height * width` values.
    pub kernel_shape: (usize, usize, usize),
    /// Convolution stride (must be non-zero).
    pub stride: usize,
    /// Zero-padding applied on every side of both spatial dimensions.
    pub padding: usize,
}
29
/// Half-open index ranges describing one tile of a blocked matrix product,
/// plus the full row strides needed to index the row-major operands.
struct BlockRange {
    /// First row of `A`/`C` covered by the tile.
    i_start: usize,
    /// One past the last row of `A`/`C` covered by the tile.
    i_end: usize,
    /// First column of `B`/`C` covered by the tile.
    j_start: usize,
    /// One past the last column of `B`/`C` covered by the tile.
    j_end: usize,
    /// First index of the shared (inner) dimension covered by the tile.
    k_start: usize,
    /// One past the last index of the shared dimension covered by the tile.
    k_end: usize,
    /// Row stride of `B` and `C` (their column count).
    n: usize,
    /// Row stride of `A` (its column count).
    k: usize,
}
41
/// Advanced SIMD optimization strategies (blocked GEMM, im2col convolution,
/// runtime-dispatched dot products and reductions).
#[derive(Debug, Clone)]
pub struct AdvancedSimdOptimizer {
    // Not read by any algorithm yet (block sizing uses its own L1 estimate);
    // kept as a tuning knob for future cache-line-aware layouts.
    #[allow(dead_code)]
    cache_line_size: usize,
    #[allow(dead_code)] // Reserved for prefetch-hint emission in future AVX512 prefetch path
    prefetch_distance: usize,
    #[allow(dead_code)] // Stored for reporting and future adaptive-width selection
    vectorization_width: usize,
}
52
53impl AdvancedSimdOptimizer {
54    /// Create a new advanced SIMD optimizer with platform-specific tuning
55    pub fn new() -> Self {
56        Self {
57            cache_line_size: 64,    // Common cache line size
58            prefetch_distance: 512, // Prefetch distance in bytes
59            vectorization_width: 8, // AVX-256 width for f32
60        }
61    }
62
63    /// Cache-aware matrix multiplication with blocking
64    pub fn cache_aware_matrix_multiply(
65        &self,
66        a: &[f32],
67        b: &[f32],
68        c: &mut [f32],
69        m: usize,
70        n: usize,
71        k: usize,
72    ) -> Result<(), SimdError> {
73        if a.len() != m * k || b.len() != k * n || c.len() != m * n {
74            return Err(SimdError::DimensionMismatch {
75                expected: m * n,
76                actual: c.len(),
77            });
78        }
79
80        // Optimal block sizes for cache efficiency
81        let block_size = self.calculate_optimal_block_size(m, n, k);
82
83        for i in (0..m).step_by(block_size) {
84            for j in (0..n).step_by(block_size) {
85                for kk in (0..k).step_by(block_size) {
86                    let i_max = (i + block_size).min(m);
87                    let j_max = (j + block_size).min(n);
88                    let k_max = (kk + block_size).min(k);
89
90                    self.matrix_multiply_block(
91                        a,
92                        b,
93                        c,
94                        &BlockRange {
95                            i_start: i,
96                            j_start: j,
97                            k_start: kk,
98                            i_end: i_max,
99                            j_end: j_max,
100                            k_end: k_max,
101                            n,
102                            k,
103                        },
104                    )?;
105                }
106            }
107        }
108
109        Ok(())
110    }
111
112    /// Vectorized dot product with manual loop unrolling
113    pub fn vectorized_dot_product(&self, a: &[f32], b: &[f32]) -> Result<f32, SimdError> {
114        if a.len() != b.len() {
115            return Err(SimdError::DimensionMismatch {
116                expected: a.len(),
117                actual: b.len(),
118            });
119        }
120
121        let len = a.len();
122        if len == 0 {
123            return Ok(0.0);
124        }
125
126        let mut result = 0.0f32;
127
128        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
129        {
130            if crate::simd_feature_detected!("avx2") {
131                return unsafe { self.dot_product_avx2(a, b) };
132            } else if crate::simd_feature_detected!("sse2") {
133                return unsafe { self.dot_product_sse2(a, b) };
134            }
135        }
136
137        // Fallback scalar implementation with loop unrolling
138        let chunks = len / 4;
139        let remainder = len % 4;
140
141        for i in 0..chunks {
142            let base = i * 4;
143            result += a[base] * b[base]
144                + a[base + 1] * b[base + 1]
145                + a[base + 2] * b[base + 2]
146                + a[base + 3] * b[base + 3];
147        }
148
149        for i in (chunks * 4)..(chunks * 4 + remainder) {
150            result += a[i] * b[i];
151        }
152
153        Ok(result)
154    }
155
156    /// Memory-efficient convolution with spatial locality optimization
157    pub fn optimized_convolution(
158        &self,
159        input: &[f32],
160        kernel: &[f32],
161        output: &mut [f32],
162        params: &ConvolutionParams,
163    ) -> Result<(), SimdError> {
164        let (in_channels, in_height, in_width) = params.input_shape;
165        let (out_channels, k_height, k_width) = params.kernel_shape;
166        let stride = params.stride;
167        let padding = params.padding;
168
169        let out_height = (in_height + 2 * padding - k_height) / stride + 1;
170        let out_width = (in_width + 2 * padding - k_width) / stride + 1;
171
172        if output.len() != out_channels * out_height * out_width {
173            return Err(SimdError::DimensionMismatch {
174                expected: out_channels * out_height * out_width,
175                actual: output.len(),
176            });
177        }
178
179        // Use im2col transformation for better memory access patterns
180        let im2col_data = self.im2col_transform(
181            input,
182            params.input_shape,
183            params.kernel_shape,
184            stride,
185            padding,
186        )?;
187
188        // Perform optimized matrix multiplication
189        self.cache_aware_matrix_multiply(
190            kernel,
191            &im2col_data,
192            output,
193            out_channels,
194            out_height * out_width,
195            in_channels * k_height * k_width,
196        )?;
197
198        Ok(())
199    }
200
201    /// Advanced vectorized reduction with tree reduction pattern
202    pub fn vectorized_reduction(&self, data: &[f32], op: ReductionOp) -> Result<f32, SimdError> {
203        if data.is_empty() {
204            return Err(SimdError::EmptyInput);
205        }
206
207        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
208        {
209            if crate::simd_feature_detected!("avx2") {
210                return unsafe { self.reduction_avx2(data, op) };
211            }
212        }
213
214        // Fallback scalar implementation
215        match op {
216            ReductionOp::Sum => Ok(data.iter().sum()),
217            ReductionOp::Max => Ok(data.iter().copied().fold(f32::NEG_INFINITY, f32::max)),
218            ReductionOp::Min => Ok(data.iter().copied().fold(f32::INFINITY, f32::min)),
219            ReductionOp::Mean => Ok(data.iter().sum::<f32>() / data.len() as f32),
220        }
221    }
222
223    // Private helper methods
224
225    fn calculate_optimal_block_size(&self, _m: usize, _n: usize, _k: usize) -> usize {
226        // Estimate optimal block size based on cache size and data dimensions
227        let cache_size = 32768; // L1 cache size estimate
228        let element_size = 4; // f32 size
229        let block_elements = cache_size / (3 * element_size); // Account for A, B, C matrices
230
231        let block_size = (block_elements as f32).sqrt() as usize;
232        block_size.clamp(8, 64) // Clamp to reasonable range
233    }
234
235    fn matrix_multiply_block(
236        &self,
237        a: &[f32],
238        b: &[f32],
239        c: &mut [f32],
240        block: &BlockRange,
241    ) -> Result<(), SimdError> {
242        for i in block.i_start..block.i_end {
243            for j in block.j_start..block.j_end {
244                let mut sum = 0.0f32;
245                for kk in block.k_start..block.k_end {
246                    sum += a[i * block.k + kk] * b[kk * block.n + j];
247                }
248                c[i * block.n + j] += sum;
249            }
250        }
251        Ok(())
252    }
253
254    fn im2col_transform(
255        &self,
256        input: &[f32],
257        input_shape: (usize, usize, usize),
258        kernel_shape: (usize, usize, usize),
259        stride: usize,
260        padding: usize,
261    ) -> Result<Vec<f32>, SimdError> {
262        let (in_channels, in_height, in_width) = input_shape;
263        let (_, k_height, k_width) = kernel_shape;
264
265        let out_height = (in_height + 2 * padding - k_height) / stride + 1;
266        let out_width = (in_width + 2 * padding - k_width) / stride + 1;
267
268        let mut result = vec![0.0f32; in_channels * k_height * k_width * out_height * out_width];
269
270        for c in 0..in_channels {
271            for kh in 0..k_height {
272                for kw in 0..k_width {
273                    for oh in 0..out_height {
274                        for ow in 0..out_width {
275                            let ih = oh * stride + kh;
276                            let iw = ow * stride + kw;
277
278                            let value = if ih >= padding
279                                && ih < in_height + padding
280                                && iw >= padding
281                                && iw < in_width + padding
282                            {
283                                let adjusted_ih = ih - padding;
284                                let adjusted_iw = iw - padding;
285                                input[c * in_height * in_width
286                                    + adjusted_ih * in_width
287                                    + adjusted_iw]
288                            } else {
289                                0.0f32
290                            };
291
292                            let col_idx = (c * k_height * k_width + kh * k_width + kw)
293                                * out_height
294                                * out_width
295                                + oh * out_width
296                                + ow;
297                            result[col_idx] = value;
298                        }
299                    }
300                }
301            }
302        }
303
304        Ok(result)
305    }
306
307    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
308    #[target_feature(enable = "avx2")]
309    unsafe fn dot_product_avx2(&self, a: &[f32], b: &[f32]) -> Result<f32, SimdError> {
310        let len = a.len();
311        let mut sum = _mm256_setzero_ps();
312
313        let chunks = len / 8;
314        for i in 0..chunks {
315            let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
316            let b_vec = _mm256_loadu_ps(b.as_ptr().add(i * 8));
317            let product = _mm256_mul_ps(a_vec, b_vec);
318            sum = _mm256_add_ps(sum, product);
319        }
320
321        // Horizontal sum
322        let sum_high = _mm256_extractf128_ps(sum, 1);
323        let sum_low = _mm256_castps256_ps128(sum);
324        let sum128 = _mm_add_ps(sum_high, sum_low);
325
326        let mut result = [0.0f32; 4];
327        _mm_storeu_ps(result.as_mut_ptr(), sum128);
328        let mut final_sum = result[0] + result[1] + result[2] + result[3];
329
330        // Handle remaining elements
331        for i in (chunks * 8)..len {
332            final_sum += a[i] * b[i];
333        }
334
335        Ok(final_sum)
336    }
337
338    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
339    #[target_feature(enable = "sse2")]
340    unsafe fn dot_product_sse2(&self, a: &[f32], b: &[f32]) -> Result<f32, SimdError> {
341        let len = a.len();
342        let mut sum = _mm_setzero_ps();
343
344        let chunks = len / 4;
345        for i in 0..chunks {
346            let a_vec = _mm_loadu_ps(a.as_ptr().add(i * 4));
347            let b_vec = _mm_loadu_ps(b.as_ptr().add(i * 4));
348            let product = _mm_mul_ps(a_vec, b_vec);
349            sum = _mm_add_ps(sum, product);
350        }
351
352        let mut result = [0.0f32; 4];
353        _mm_storeu_ps(result.as_mut_ptr(), sum);
354        let mut final_sum = result[0] + result[1] + result[2] + result[3];
355
356        // Handle remaining elements
357        for i in (chunks * 4)..len {
358            final_sum += a[i] * b[i];
359        }
360
361        Ok(final_sum)
362    }
363
364    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
365    #[target_feature(enable = "avx2")]
366    unsafe fn reduction_avx2(&self, data: &[f32], op: ReductionOp) -> Result<f32, SimdError> {
367        let len = data.len();
368        let chunks = len / 8;
369
370        let mut accumulator = match op {
371            ReductionOp::Sum | ReductionOp::Mean => _mm256_setzero_ps(),
372            ReductionOp::Max => _mm256_set1_ps(f32::NEG_INFINITY),
373            ReductionOp::Min => _mm256_set1_ps(f32::INFINITY),
374        };
375
376        for i in 0..chunks {
377            let data_vec = _mm256_loadu_ps(data.as_ptr().add(i * 8));
378            accumulator = match op {
379                ReductionOp::Sum | ReductionOp::Mean => _mm256_add_ps(accumulator, data_vec),
380                ReductionOp::Max => _mm256_max_ps(accumulator, data_vec),
381                ReductionOp::Min => _mm256_min_ps(accumulator, data_vec),
382            };
383        }
384
385        // Horizontal reduction
386        let mut result = [0.0f32; 8];
387        _mm256_storeu_ps(result.as_mut_ptr(), accumulator);
388
389        let mut final_result = match op {
390            ReductionOp::Sum | ReductionOp::Mean => result.iter().sum::<f32>(),
391            ReductionOp::Max => result.iter().copied().fold(f32::NEG_INFINITY, f32::max),
392            ReductionOp::Min => result.iter().copied().fold(f32::INFINITY, f32::min),
393        };
394
395        // Handle remaining elements
396        for val in data.iter().take(len).skip(chunks * 8) {
397            final_result = match op {
398                ReductionOp::Sum | ReductionOp::Mean => final_result + *val,
399                ReductionOp::Max => final_result.max(*val),
400                ReductionOp::Min => final_result.min(*val),
401            };
402        }
403
404        if matches!(op, ReductionOp::Mean) {
405            final_result /= len as f32;
406        }
407
408        Ok(final_result)
409    }
410}
411
412impl Default for AdvancedSimdOptimizer {
413    fn default() -> Self {
414        Self::new()
415    }
416}
417
/// Reduction operation applied to a slice of `f32` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    /// Sum of all elements.
    Sum,
    /// Largest element.
    Max,
    /// Smallest element.
    Min,
    /// Arithmetic mean: the sum divided by the number of elements.
    Mean,
}
426
/// Cache-aware sorting for SIMD operations.
///
/// Currently provides a stable, top-down merge sort for `f32` slices.
pub struct CacheAwareSort;

impl CacheAwareSort {
    /// Sorts `data` in ascending order using a stable top-down merge sort.
    ///
    /// One scratch buffer the size of `data` is allocated up front and reused
    /// by every merge. The previous version allocated a fresh buffer at each
    /// recursion level. Elements are compared with `<=`, so inputs containing
    /// NaN are not guaranteed to come out in a meaningful order.
    pub fn vectorized_merge_sort(data: &mut [f32]) {
        if data.len() <= 1 {
            return;
        }

        let mut scratch = vec![0.0f32; data.len()];
        Self::merge_sort_with_scratch(data, &mut scratch);
    }

    /// Recursive worker. `scratch` must be at least as long as `data`.
    fn merge_sort_with_scratch(data: &mut [f32], scratch: &mut [f32]) {
        let len = data.len();
        if len <= 1 {
            return;
        }

        let mid = len / 2;
        {
            // Each half gets its own disjoint region of the scratch buffer.
            let (left, right) = data.split_at_mut(mid);
            let (left_scratch, right_scratch) = scratch.split_at_mut(mid);
            Self::merge_sort_with_scratch(left, left_scratch);
            Self::merge_sort_with_scratch(right, right_scratch);
        }

        let scratch = &mut scratch[..len];
        Self::cache_friendly_merge(data, scratch, mid);
        data.copy_from_slice(scratch);
    }

    /// Merges the sorted runs `data[..mid]` and `data[mid..]` into `temp`.
    ///
    /// On ties the left run wins, which keeps the sort stable. `temp` must be
    /// at least `data.len()` long.
    fn cache_friendly_merge(data: &[f32], temp: &mut [f32], mid: usize) {
        let (left, right) = data.split_at(mid);
        let (mut i, mut j, mut k) = (0, 0, 0);

        while i < left.len() && j < right.len() {
            if left[i] <= right[j] {
                temp[k] = left[i];
                i += 1;
            } else {
                temp[k] = right[j];
                j += 1;
            }
            k += 1;
        }

        // At most one of the two runs still has elements; copy them in bulk.
        let left_tail = &left[i..];
        temp[k..k + left_tail.len()].copy_from_slice(left_tail);
        k += left_tail.len();

        let right_tail = &right[j..];
        temp[k..k + right_tail.len()].copy_from_slice(right_tail);
    }
}
477
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_vectorized_dot_product() {
        let optimizer = AdvancedSimdOptimizer::new();
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let result = optimizer
            .vectorized_dot_product(&a, &b)
            .expect("operation should succeed");
        let expected = 1.0 * 5.0 + 2.0 * 6.0 + 3.0 * 7.0 + 4.0 * 8.0;

        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_vectorized_dot_product_with_remainder() {
        // 11 elements: one full 8-lane chunk plus a 3-element tail.
        let optimizer = AdvancedSimdOptimizer::new();
        let a: Vec<f32> = (1..=11).map(|x| x as f32).collect();
        let b = vec![1.0f32; 11];

        let result = optimizer
            .vectorized_dot_product(&a, &b)
            .expect("operation should succeed");
        assert_eq!(result, 66.0);
    }

    #[test]
    fn test_vectorized_dot_product_edge_cases() {
        let optimizer = AdvancedSimdOptimizer::new();
        let empty: [f32; 0] = [];

        let zero = optimizer
            .vectorized_dot_product(&empty, &empty)
            .expect("empty inputs should succeed");
        assert_eq!(zero, 0.0);

        assert!(optimizer
            .vectorized_dot_product(&[1.0, 2.0, 3.0], &[1.0, 2.0])
            .is_err());
    }

    #[test]
    fn test_vectorized_reduction() {
        let optimizer = AdvancedSimdOptimizer::new();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let sum = optimizer
            .vectorized_reduction(&data, ReductionOp::Sum)
            .expect("operation should succeed");
        assert_eq!(sum, 15.0);

        let max = optimizer
            .vectorized_reduction(&data, ReductionOp::Max)
            .expect("operation should succeed");
        assert_eq!(max, 5.0);

        let min = optimizer
            .vectorized_reduction(&data, ReductionOp::Min)
            .expect("operation should succeed");
        assert_eq!(min, 1.0);

        let mean = optimizer
            .vectorized_reduction(&data, ReductionOp::Mean)
            .expect("operation should succeed");
        assert_eq!(mean, 3.0);
    }

    #[test]
    fn test_vectorized_reduction_with_remainder() {
        // 19 elements: two full 8-lane chunks plus a 3-element tail.
        let optimizer = AdvancedSimdOptimizer::new();
        let data: Vec<f32> = (1..=19).map(|x| x as f32).collect();

        let reduce = |op| {
            optimizer
                .vectorized_reduction(&data, op)
                .expect("operation should succeed")
        };
        assert_eq!(reduce(ReductionOp::Sum), 190.0);
        assert_eq!(reduce(ReductionOp::Max), 19.0);
        assert_eq!(reduce(ReductionOp::Min), 1.0);
        assert_eq!(reduce(ReductionOp::Mean), 10.0);
    }

    #[test]
    fn test_vectorized_reduction_empty_input() {
        let optimizer = AdvancedSimdOptimizer::new();
        let empty: [f32; 0] = [];
        assert!(matches!(
            optimizer.vectorized_reduction(&empty, ReductionOp::Sum),
            Err(SimdError::EmptyInput)
        ));
    }

    #[test]
    fn test_cache_aware_matrix_multiply() {
        let optimizer = AdvancedSimdOptimizer::new();
        // (2 x 3) * (3 x 1)
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 1.0, 1.0];
        let mut c = vec![0.0; 2];

        optimizer
            .cache_aware_matrix_multiply(&a, &b, &mut c, 2, 1, 3)
            .expect("operation should succeed");
        assert_eq!(c, vec![6.0, 15.0]);

        // Wrong-sized `a` must be rejected.
        let mut c = vec![0.0; 2];
        assert!(optimizer
            .cache_aware_matrix_multiply(&a[..5], &b, &mut c, 2, 1, 3)
            .is_err());
    }

    #[test]
    fn test_optimized_convolution() {
        let optimizer = AdvancedSimdOptimizer::new();
        let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
        let kernel = vec![1.0; 4];
        let mut output = vec![0.0; 4];
        let params = ConvolutionParams {
            input_shape: (1, 3, 3),
            kernel_shape: (1, 2, 2),
            stride: 1,
            padding: 0,
        };

        optimizer
            .optimized_convolution(&input, &kernel, &mut output, &params)
            .expect("operation should succeed");
        assert_eq!(output, vec![12.0, 16.0, 24.0, 28.0]);
    }

    #[test]
    fn test_optimized_convolution_with_padding() {
        let optimizer = AdvancedSimdOptimizer::new();
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let kernel = vec![1.0; 4];
        let mut output = vec![0.0; 9];
        let params = ConvolutionParams {
            input_shape: (1, 2, 2),
            kernel_shape: (1, 2, 2),
            stride: 1,
            padding: 1,
        };

        optimizer
            .optimized_convolution(&input, &kernel, &mut output, &params)
            .expect("operation should succeed");
        assert_eq!(output, vec![1.0, 3.0, 2.0, 4.0, 10.0, 6.0, 3.0, 7.0, 4.0]);
    }

    #[test]
    fn test_cache_aware_sort() {
        let mut data = vec![5.0, 2.0, 8.0, 1.0, 9.0, 3.0];
        CacheAwareSort::vectorized_merge_sort(&mut data);

        let expected = vec![1.0, 2.0, 3.0, 5.0, 8.0, 9.0];
        assert_eq!(data, expected);
    }

    #[test]
    fn test_cache_aware_sort_edge_cases() {
        let mut empty: Vec<f32> = Vec::new();
        CacheAwareSort::vectorized_merge_sort(&mut empty);
        assert!(empty.is_empty());

        let mut dups = vec![3.0, -1.0, 3.0, 0.0, -1.0];
        CacheAwareSort::vectorized_merge_sort(&mut dups);
        assert_eq!(dups, vec![-1.0, -1.0, 0.0, 3.0, 3.0]);
    }
}