sklears_kernel_approximation/
unsafe_optimizations.rs

1//! Performance-critical unsafe optimizations
2//!
3//! This module contains unsafe code for performance-critical paths in kernel
4//! approximation. All unsafe code is carefully reviewed and documented.
5//!
6//! # Safety
7//!
8//! All functions in this module that use unsafe code have detailed safety
9//! documentation explaining the invariants that must be upheld.
10
11use scirs2_core::ndarray::{Array1, Array2};
12
/// Dot product of two slices, accumulated in four independent partial sums.
///
/// Elements are consumed in groups of four; position `k` of every group feeds
/// partial sum `k`. The `len % 4` trailing elements are accumulated separately,
/// and the result is `sum0 + sum1 + sum2 + sum3 + tail`, evaluated left to
/// right. Keeping several accumulators is intended to shorten the dependency
/// chain between consecutive floating-point additions.
///
/// NaN and infinite inputs are ordinary IEEE-754 values: they propagate into
/// the result and never cause undefined behaviour.
///
/// # Safety
///
/// The body performs no unsafe operations; the function stays `unsafe fn` only
/// so that existing call sites keep compiling unchanged. Callers are still
/// expected to pass slices of equal length:
///
/// - debug builds panic on a length mismatch;
/// - release builds compute the product over the common prefix
///   (`min(a.len(), b.len())` elements) instead of reading out of bounds.
#[inline]
pub unsafe fn dot_product_unrolled(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    // Clamp to the common prefix so a mismatch in release builds stays in bounds.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);

    let a_groups = a.chunks_exact(4);
    let b_groups = b.chunks_exact(4);
    // `remainder()` borrows from the underlying slice, so it stays valid after
    // the chunk iterators are consumed by the loop below.
    let a_tail = a_groups.remainder();
    let b_tail = b_groups.remainder();

    let mut sum0 = 0.0_f64;
    let mut sum1 = 0.0_f64;
    let mut sum2 = 0.0_f64;
    let mut sum3 = 0.0_f64;

    // Every chunk is exactly four elements long, so the constant indices below
    // are always in range.
    for (ca, cb) in a_groups.zip(b_groups) {
        sum0 += ca[0] * cb[0];
        sum1 += ca[1] * cb[1];
        sum2 += ca[2] * cb[2];
        sum3 += ca[3] * cb[3];
    }

    // Explicit loop (rather than `Iterator::sum`) so the accumulator starts at
    // +0.0 and the additions happen in index order.
    let mut tail = 0.0_f64;
    for (x, y) in a_tail.iter().zip(b_tail) {
        tail += x * y;
    }

    sum0 + sum1 + sum2 + sum3 + tail
}
60
61/// Fast matrix-vector multiplication using unsafe optimizations
62///
63/// Computes `result = matrix * vector`
64///
65/// # Safety
66///
67/// - `matrix` must have the same number of columns as `vector` has elements
68/// - `result` must have the same length as `matrix` has rows
69/// - All slices must be properly aligned and valid
70#[inline]
71pub unsafe fn matvec_multiply_fast(matrix: &Array2<f64>, vector: &[f64], result: &mut [f64]) {
72    let (n_rows, n_cols) = matrix.dim();
73    debug_assert_eq!(n_cols, vector.len(), "Dimension mismatch");
74    debug_assert_eq!(n_rows, result.len(), "Result size mismatch");
75
76    let matrix_ptr = matrix.as_ptr();
77    let vector_ptr = vector.as_ptr();
78    let result_ptr = result.as_mut_ptr();
79
80    for i in 0..n_rows {
81        let row_offset = i * n_cols;
82        let mut sum = 0.0;
83
84        // Manual loop unrolling for better performance
85        let chunks = n_cols / 4;
86        let remainder = n_cols % 4;
87
88        for j in 0..chunks {
89            let idx = j * 4;
90            sum += *matrix_ptr.add(row_offset + idx) * *vector_ptr.add(idx);
91            sum += *matrix_ptr.add(row_offset + idx + 1) * *vector_ptr.add(idx + 1);
92            sum += *matrix_ptr.add(row_offset + idx + 2) * *vector_ptr.add(idx + 2);
93            sum += *matrix_ptr.add(row_offset + idx + 3) * *vector_ptr.add(idx + 3);
94        }
95
96        for j in 0..remainder {
97            let idx = chunks * 4 + j;
98            sum += *matrix_ptr.add(row_offset + idx) * *vector_ptr.add(idx);
99        }
100
101        *result_ptr.add(i) = sum;
102    }
103}
104
/// Applies a binary operation element by element: `out[i] = op(a[i], b[i])`.
///
/// `op` is invoked exactly once per index, in ascending index order, which
/// matters when the closure carries state (`FnMut`).
///
/// # Safety
///
/// The body performs no unsafe operations; the function stays `unsafe fn` only
/// so that existing call sites keep compiling unchanged. Callers are still
/// expected to pass three slices of equal length:
///
/// - debug builds panic on a length mismatch;
/// - release builds process only the first `min(a.len(), b.len(), out.len())`
///   elements and leave the rest of `out` untouched.
#[inline]
pub unsafe fn elementwise_op_fast<F>(a: &[f64], b: &[f64], out: &mut [f64], mut op: F)
where
    F: FnMut(f64, f64) -> f64,
{
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), out.len());

    // Zipping the three slices removes every index computation, so there is
    // nothing left that could go out of bounds.
    for ((dst, &lhs), &rhs) in out.iter_mut().zip(a).zip(b) {
        *dst = op(lhs, rhs);
    }
}
142
/// RBF (Gaussian) kernel value `exp(-gamma * ||x - y||^2)` for one pair of points.
///
/// The squared Euclidean distance is accumulated in four independent partial
/// sums (one per position within each group of four coordinates) plus a fifth
/// sum for the `len % 4` trailing coordinates, then combined as
/// `sum0 + sum1 + sum2 + sum3 + tail`, evaluated left to right.
///
/// `gamma` is used as given; NaN or infinite inputs simply propagate through
/// the arithmetic.
///
/// # Safety
///
/// The body performs no unsafe operations; the function stays `unsafe fn` only
/// so that existing call sites keep compiling unchanged. Callers are still
/// expected to pass slices of equal length:
///
/// - debug builds panic on a length mismatch;
/// - release builds use only the common prefix
///   (`min(x.len(), y.len())` coordinates) instead of reading out of bounds.
#[inline]
pub unsafe fn rbf_kernel_fast(x: &[f64], y: &[f64], gamma: f64) -> f64 {
    debug_assert_eq!(x.len(), y.len());

    // Clamp to the common prefix so a mismatch in release builds stays in bounds.
    let n = x.len().min(y.len());
    let (x, y) = (&x[..n], &y[..n]);

    let x_groups = x.chunks_exact(4);
    let y_groups = y.chunks_exact(4);
    let x_tail = x_groups.remainder();
    let y_tail = y_groups.remainder();

    let mut sum0 = 0.0_f64;
    let mut sum1 = 0.0_f64;
    let mut sum2 = 0.0_f64;
    let mut sum3 = 0.0_f64;

    // Every chunk is exactly four elements long, so the constant indices below
    // are always in range.
    for (cx, cy) in x_groups.zip(y_groups) {
        let d0 = cx[0] - cy[0];
        let d1 = cx[1] - cy[1];
        let d2 = cx[2] - cy[2];
        let d3 = cx[3] - cy[3];

        sum0 += d0 * d0;
        sum1 += d1 * d1;
        sum2 += d2 * d2;
        sum3 += d3 * d3;
    }

    let mut tail = 0.0_f64;
    for (xv, yv) in x_tail.iter().zip(y_tail) {
        let d = xv - yv;
        tail += d * d;
    }

    let squared_dist = sum0 + sum1 + sum2 + sum3 + tail;
    (-gamma * squared_dist).exp()
}
191
192/// Safe wrapper for dot product with bounds checking
193#[inline]
194pub fn safe_dot_product(a: &[f64], b: &[f64]) -> Option<f64> {
195    if a.len() != b.len() {
196        return None;
197    }
198
199    // Check for NaN values
200    if a.iter().any(|x| x.is_nan()) || b.iter().any(|x| x.is_nan()) {
201        return None;
202    }
203
204    Some(unsafe { dot_product_unrolled(a, b) })
205}
206
207/// Safe wrapper for matrix-vector multiplication
208#[inline]
209pub fn safe_matvec_multiply(matrix: &Array2<f64>, vector: &Array1<f64>) -> Option<Array1<f64>> {
210    let (n_rows, n_cols) = matrix.dim();
211    if n_cols != vector.len() {
212        return None;
213    }
214
215    let mut result = Array1::zeros(n_rows);
216    unsafe {
217        matvec_multiply_fast(
218            matrix,
219            vector.as_slice().unwrap(),
220            result.as_slice_mut().unwrap(),
221        );
222    }
223    Some(result)
224}
225
226/// Batch RBF kernel computation with unsafe optimizations
227///
228/// # Safety
229///
230/// - All matrices must have compatible dimensions
231/// - gamma must be a valid positive f64
232pub unsafe fn batch_rbf_kernel_fast(
233    x_matrix: &Array2<f64>,
234    y_matrix: &Array2<f64>,
235    gamma: f64,
236    output: &mut Array2<f64>,
237) {
238    let (n_x, d_x) = x_matrix.dim();
239    let (n_y, d_y) = y_matrix.dim();
240    let (out_rows, out_cols) = output.dim();
241
242    debug_assert_eq!(d_x, d_y, "Feature dimensions must match");
243    debug_assert_eq!(out_rows, n_x, "Output rows mismatch");
244    debug_assert_eq!(out_cols, n_y, "Output cols mismatch");
245
246    let x_ptr = x_matrix.as_ptr();
247    let y_ptr = y_matrix.as_ptr();
248    let out_ptr = output.as_mut_ptr();
249
250    for i in 0..n_x {
251        for j in 0..n_y {
252            let mut squared_dist = 0.0;
253
254            let x_offset = i * d_x;
255            let y_offset = j * d_y;
256
257            // Compute squared Euclidean distance
258            for k in 0..d_x {
259                let diff = *x_ptr.add(x_offset + k) - *y_ptr.add(y_offset + k);
260                squared_dist += diff * diff;
261            }
262
263            *out_ptr.add(i * n_y + j) = (-gamma * squared_dist).exp();
264        }
265    }
266}
267
/// Cosine feature map used by Random Fourier Features:
/// `output[i] = scale * cos(projection[i] + offset[i])`.
///
/// The inputs are shared borrows and the output is an exclusive borrow, so
/// the borrow checker already guarantees they do not alias.
///
/// # Safety
///
/// The body performs no unsafe operations; the function stays `unsafe fn` only
/// so that existing call sites keep compiling unchanged. Callers are still
/// expected to pass three slices of equal length:
///
/// - debug builds panic on a length mismatch;
/// - release builds process only the first
///   `min(projection.len(), offset.len(), output.len())` elements and leave
///   the rest of `output` untouched.
#[inline]
pub unsafe fn fast_cosine_features(
    projection: &[f64],
    offset: &[f64],
    scale: f64,
    output: &mut [f64],
) {
    debug_assert_eq!(projection.len(), offset.len());
    debug_assert_eq!(projection.len(), output.len());

    for ((dst, &proj), &shift) in output.iter_mut().zip(projection).zip(offset) {
        let phase = proj + shift;
        *dst = scale * phase.cos();
    }
}
294
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance used for the floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// Equal-length, NaN-free inputs yield the exact dot product.
    #[test]
    fn test_safe_dot_product() {
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];

        // 1*4 + 2*5 + 3*6
        let dot = safe_dot_product(&lhs, &rhs).unwrap();
        assert_eq!(dot, 32.0);
    }

    /// Slices of different lengths are rejected.
    #[test]
    fn test_safe_dot_product_length_mismatch() {
        let short = [1.0, 2.0];
        let long = [3.0, 4.0, 5.0];

        assert!(safe_dot_product(&short, &long).is_none());
    }

    /// A NaN anywhere in the input is rejected.
    #[test]
    fn test_safe_dot_product_nan() {
        let with_nan = [1.0, f64::NAN, 3.0];
        let clean = [4.0, 5.0, 6.0];

        assert!(safe_dot_product(&with_nan, &clean).is_none());
    }

    /// 2x2 matrix times a 2-vector.
    #[test]
    fn test_safe_matvec_multiply() {
        let m = array![[1.0, 2.0], [3.0, 4.0]];
        let v = array![5.0, 6.0];

        let product = safe_matvec_multiply(&m, &v).unwrap();
        assert_eq!(product[0], 17.0); // 1*5 + 2*6
        assert_eq!(product[1], 39.0); // 3*5 + 4*6
    }

    /// Identical points are at distance zero, so the kernel value is 1.
    #[test]
    fn test_unsafe_rbf_kernel() {
        let point = [1.0, 2.0, 3.0];
        let same_point = [1.0, 2.0, 3.0];

        let k = unsafe { rbf_kernel_fast(&point, &same_point, 0.5) };
        assert!((k - 1.0).abs() < TOL);
    }

    /// Points at squared distance 1 give `exp(-gamma)`.
    #[test]
    fn test_unsafe_rbf_kernel_different() {
        let origin = [0.0, 0.0];
        let unit_x = [1.0, 0.0];
        let gamma = 0.5;

        let k = unsafe { rbf_kernel_fast(&origin, &unit_x, gamma) };
        let want = (-gamma * 1.0).exp();
        assert!((k - want).abs() < TOL);
    }

    /// cos(0) = 1 and cos(pi/2) = 0 with unit scale and zero offsets.
    #[test]
    fn test_fast_cosine_features() {
        let angles = [0.0, std::f64::consts::PI / 2.0];
        let shifts = [0.0, 0.0];
        let mut features = [0.0; 2];

        unsafe {
            fast_cosine_features(&angles, &shifts, 1.0, &mut features);
        }

        assert!((features[0] - 1.0).abs() < TOL);
        assert!(features[1].abs() < TOL);
    }

    /// Five elements: one full group of four plus one trailing element.
    #[test]
    fn test_elementwise_op() {
        let lhs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let rhs = [2.0, 3.0, 4.0, 5.0, 6.0];
        let mut sums = [0.0; 5];

        unsafe {
            elementwise_op_fast(&lhs, &rhs, &mut sums, |x, y| x + y);
        }

        assert_eq!(sums, [3.0, 5.0, 7.0, 9.0, 11.0]);
    }
}
383}