sklears_kernel_approximation/
unsafe_optimizations.rs1use scirs2_core::ndarray::{Array1, Array2};
12
/// Dot product of `a` and `b` using four independent partial sums.
///
/// The body of the input is processed four elements at a time, each position within a
/// group feeding its own accumulator ("lane"). The trailing `len % 4` elements are summed
/// into a separate tail. The result is `(((lane0 + lane1) + lane2) + lane3) + tail`, so it
/// can differ in the last bits from a strictly sequential sum.
///
/// # Safety
///
/// `b` must contain at least `a.len()` elements. Equal lengths are only checked with a
/// `debug_assert!`; in release builds a shorter `b` causes out-of-bounds reads.
#[inline]
pub unsafe fn dot_product_unrolled(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    let n = a.len();
    // Number of elements covered by whole groups of four.
    let main_len = n - n % 4;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut lanes = [0.0f64; 4];
    let mut base = 0;
    while base < main_len {
        for (lane, acc) in lanes.iter_mut().enumerate() {
            *acc += *pa.add(base + lane) * *pb.add(base + lane);
        }
        base += 4;
    }

    let mut tail = 0.0;
    for k in main_len..n {
        tail += *pa.add(k) * *pb.add(k);
    }

    lanes[0] + lanes[1] + lanes[2] + lanes[3] + tail
}
61#[inline]
71pub unsafe fn matvec_multiply_fast(matrix: &Array2<f64>, vector: &[f64], result: &mut [f64]) {
72 let (n_rows, n_cols) = matrix.dim();
73 debug_assert_eq!(n_cols, vector.len(), "Dimension mismatch");
74 debug_assert_eq!(n_rows, result.len(), "Result size mismatch");
75
76 let matrix_ptr = matrix.as_ptr();
77 let vector_ptr = vector.as_ptr();
78 let result_ptr = result.as_mut_ptr();
79
80 for i in 0..n_rows {
81 let row_offset = i * n_cols;
82 let mut sum = 0.0;
83
84 let chunks = n_cols / 4;
86 let remainder = n_cols % 4;
87
88 for j in 0..chunks {
89 let idx = j * 4;
90 sum += *matrix_ptr.add(row_offset + idx) * *vector_ptr.add(idx);
91 sum += *matrix_ptr.add(row_offset + idx + 1) * *vector_ptr.add(idx + 1);
92 sum += *matrix_ptr.add(row_offset + idx + 2) * *vector_ptr.add(idx + 2);
93 sum += *matrix_ptr.add(row_offset + idx + 3) * *vector_ptr.add(idx + 3);
94 }
95
96 for j in 0..remainder {
97 let idx = chunks * 4 + j;
98 sum += *matrix_ptr.add(row_offset + idx) * *vector_ptr.add(idx);
99 }
100
101 *result_ptr.add(i) = sum;
102 }
103}
104
/// Applies a binary operation element by element: `out[i] = op(a[i], b[i])`.
///
/// `op` is invoked exactly once per index, in increasing index order, for the first
/// `a.len()` positions.
///
/// # Safety
///
/// `b` and `out` must each hold at least `a.len()` elements. Length equality is only
/// checked with `debug_assert!`; in release builds shorter slices lead to out-of-bounds
/// reads or writes.
#[inline]
pub unsafe fn elementwise_op_fast<F>(a: &[f64], b: &[f64], out: &mut [f64], mut op: F)
where
    F: FnMut(f64, f64) -> f64,
{
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), out.len());

    let rhs = b.as_ptr();
    let dst = out.as_mut_ptr();
    for (k, &lhs) in a.iter().enumerate() {
        *dst.add(k) = op(lhs, *rhs.add(k));
    }
}
/// RBF (Gaussian) kernel value `exp(-gamma * ||x - y||^2)`.
///
/// The squared Euclidean distance is accumulated in four independent lanes over groups of
/// four elements, plus a separate tail for the last `len % 4` elements. The lanes are
/// combined as `(((lane0 + lane1) + lane2) + lane3) + tail`.
///
/// # Safety
///
/// `y` must contain at least `x.len()` elements. Equal lengths are only checked with a
/// `debug_assert!`; in release builds a shorter `y` causes out-of-bounds reads.
#[inline]
pub unsafe fn rbf_kernel_fast(x: &[f64], y: &[f64], gamma: f64) -> f64 {
    debug_assert_eq!(x.len(), y.len());

    let n = x.len();
    // Number of elements covered by whole groups of four.
    let main_len = n - n % 4;
    let (px, py) = (x.as_ptr(), y.as_ptr());

    let mut lanes = [0.0f64; 4];
    let mut base = 0;
    while base < main_len {
        for (lane, acc) in lanes.iter_mut().enumerate() {
            let d = *px.add(base + lane) - *py.add(base + lane);
            *acc += d * d;
        }
        base += 4;
    }

    let mut tail = 0.0;
    for k in main_len..n {
        let d = *px.add(k) - *py.add(k);
        tail += d * d;
    }

    let squared_dist = lanes[0] + lanes[1] + lanes[2] + lanes[3] + tail;
    (-gamma * squared_dist).exp()
}
192#[inline]
194pub fn safe_dot_product(a: &[f64], b: &[f64]) -> Option<f64> {
195 if a.len() != b.len() {
196 return None;
197 }
198
199 if a.iter().any(|x| x.is_nan()) || b.iter().any(|x| x.is_nan()) {
201 return None;
202 }
203
204 Some(unsafe { dot_product_unrolled(a, b) })
205}
206
207#[inline]
209pub fn safe_matvec_multiply(matrix: &Array2<f64>, vector: &Array1<f64>) -> Option<Array1<f64>> {
210 let (n_rows, n_cols) = matrix.dim();
211 if n_cols != vector.len() {
212 return None;
213 }
214
215 let mut result = Array1::zeros(n_rows);
216 unsafe {
217 matvec_multiply_fast(
218 matrix,
219 vector.as_slice().unwrap(),
220 result.as_slice_mut().unwrap(),
221 );
222 }
223 Some(result)
224}
225
226pub unsafe fn batch_rbf_kernel_fast(
233 x_matrix: &Array2<f64>,
234 y_matrix: &Array2<f64>,
235 gamma: f64,
236 output: &mut Array2<f64>,
237) {
238 let (n_x, d_x) = x_matrix.dim();
239 let (n_y, d_y) = y_matrix.dim();
240 let (out_rows, out_cols) = output.dim();
241
242 debug_assert_eq!(d_x, d_y, "Feature dimensions must match");
243 debug_assert_eq!(out_rows, n_x, "Output rows mismatch");
244 debug_assert_eq!(out_cols, n_y, "Output cols mismatch");
245
246 let x_ptr = x_matrix.as_ptr();
247 let y_ptr = y_matrix.as_ptr();
248 let out_ptr = output.as_mut_ptr();
249
250 for i in 0..n_x {
251 for j in 0..n_y {
252 let mut squared_dist = 0.0;
253
254 let x_offset = i * d_x;
255 let y_offset = j * d_y;
256
257 for k in 0..d_x {
259 let diff = *x_ptr.add(x_offset + k) - *y_ptr.add(y_offset + k);
260 squared_dist += diff * diff;
261 }
262
263 *out_ptr.add(i * n_y + j) = (-gamma * squared_dist).exp();
264 }
265 }
266}
267
/// Cosine feature map: `output[i] = scale * cos(projection[i] + offset[i])`.
///
/// Exactly `projection.len()` entries of `output` are written.
///
/// # Safety
///
/// `offset` and `output` must each hold at least `projection.len()` elements. Length
/// equality is only checked with `debug_assert!`; in release builds shorter slices lead to
/// out-of-bounds reads or writes.
#[inline]
pub unsafe fn fast_cosine_features(
    projection: &[f64],
    offset: &[f64],
    scale: f64,
    output: &mut [f64],
) {
    debug_assert_eq!(projection.len(), offset.len());
    debug_assert_eq!(projection.len(), output.len());

    let count = projection.len();
    let (src, shift, dst) = (projection.as_ptr(), offset.as_ptr(), output.as_mut_ptr());

    let mut k = 0;
    while k < count {
        let phase = *src.add(k) + *shift.add(k);
        *dst.add(k) = scale * phase.cos();
        k += 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_safe_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];

        let result = safe_dot_product(&a, &b).unwrap();
        assert_eq!(result, 32.0); // 1*4 + 2*5 + 3*6
    }

    #[test]
    fn test_safe_dot_product_unrolled_path() {
        // Length 9 exercises the 4-wide main loop (two groups) as well as the tail.
        let a: Vec<f64> = (1..=9).map(f64::from).collect();

        let result = safe_dot_product(&a, &a).unwrap();
        assert_eq!(result, 285.0); // 1² + 2² + … + 9²
    }

    #[test]
    fn test_safe_dot_product_length_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![3.0, 4.0, 5.0];

        assert!(safe_dot_product(&a, &b).is_none());
    }

    #[test]
    fn test_safe_dot_product_nan() {
        let a = vec![1.0, f64::NAN, 3.0];
        let b = vec![4.0, 5.0, 6.0];

        assert!(safe_dot_product(&a, &b).is_none());
    }

    #[test]
    fn test_safe_matvec_multiply() {
        let matrix = array![[1.0, 2.0], [3.0, 4.0]];
        let vector = array![5.0, 6.0];

        let result = safe_matvec_multiply(&matrix, &vector).unwrap();
        assert_eq!(result[0], 17.0); // 1*5 + 2*6
        assert_eq!(result[1], 39.0); // 3*5 + 4*6
    }

    #[test]
    fn test_safe_matvec_multiply_wide_rows() {
        // Five columns cover both the grouped loop and the remainder of each row.
        let matrix = array![[1.0, 2.0, 3.0, 4.0, 5.0], [5.0, 4.0, 3.0, 2.0, 1.0]];
        let vector = array![1.0, 1.0, 1.0, 1.0, 2.0];

        let result = safe_matvec_multiply(&matrix, &vector).unwrap();
        assert_eq!(result[0], 20.0);
        assert_eq!(result[1], 16.0);
    }

    #[test]
    fn test_safe_matvec_multiply_dimension_mismatch() {
        let matrix = array![[1.0, 2.0], [3.0, 4.0]];
        let vector = array![1.0];

        assert!(safe_matvec_multiply(&matrix, &vector).is_none());
    }

    #[test]
    fn test_unsafe_rbf_kernel() {
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![1.0, 2.0, 3.0];
        let gamma = 0.5;

        let result = unsafe { rbf_kernel_fast(&x, &y, gamma) };
        assert!((result - 1.0).abs() < 1e-10); // identical points
    }

    #[test]
    fn test_unsafe_rbf_kernel_different() {
        let x = vec![0.0, 0.0];
        let y = vec![1.0, 0.0];
        let gamma = 0.5;

        let result = unsafe { rbf_kernel_fast(&x, &y, gamma) };
        let expected = (-gamma * 1.0).exp(); // squared distance is 1
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_unsafe_rbf_kernel_unrolled_path() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let y = vec![0.0; 7];
        let gamma = 0.01;

        let result = unsafe { rbf_kernel_fast(&x, &y, gamma) };
        let expected = (-gamma * 140.0_f64).exp(); // squared distance is 1² + … + 7² = 140
        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_batch_rbf_kernel() {
        let x = array![[0.0, 0.0], [1.0, 1.0]];
        let y = array![[0.0, 0.0], [1.0, 0.0]];
        let gamma = 0.5;
        let mut output = Array2::zeros((2, 2));

        unsafe {
            batch_rbf_kernel_fast(&x, &y, gamma, &mut output);
        }

        // Pairwise squared distances are [[0, 1], [2, 1]].
        let expected = array![
            [1.0, (-0.5_f64).exp()],
            [(-1.0_f64).exp(), (-0.5_f64).exp()]
        ];
        for (got, want) in output.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12);
        }
    }

    #[test]
    fn test_fast_cosine_features() {
        let projection = vec![0.0, std::f64::consts::PI / 2.0];
        let offset = vec![0.0, 0.0];
        let scale = 1.0;
        let mut output = vec![0.0; 2];

        unsafe {
            fast_cosine_features(&projection, &offset, scale, &mut output);
        }

        assert!((output[0] - 1.0).abs() < 1e-10);
        assert!(output[1].abs() < 1e-10);
    }

    #[test]
    fn test_elementwise_op() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let mut out = vec![0.0; 5];

        unsafe {
            elementwise_op_fast(&a, &b, &mut out, |x, y| x + y);
        }

        assert_eq!(out, vec![3.0, 5.0, 7.0, 9.0, 11.0]);
    }
}