1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::parallel_ops::*;
9use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
10use std::fmt::Debug;
11
/// Tuning knobs for the SIMD-accelerated distance routines.
#[derive(Debug, Clone)]
pub struct SimdConfig {
    /// Number of elements per processing chunk.
    /// NOTE(review): none of the pairwise routines in this module appear to read it — confirm.
    pub chunk_size: usize,
    /// Requests software prefetching in blocked kernels (advisory; may be ignored).
    pub enable_prefetch: bool,
    /// When `true` (and SIMD is available), `pairwise_euclidean_blocked` takes the
    /// cache-blocked kernel instead of the plain scalar loop.
    pub cache_friendly: bool,
    /// Tile edge length, in rows, used by the cache-blocked kernel.
    pub block_size: usize,
}

impl Default for SimdConfig {
    /// 1024-element chunks, prefetch and cache blocking enabled, 256-row tiles.
    fn default() -> Self {
        SimdConfig {
            block_size: 256,
            cache_friendly: true,
            enable_prefetch: true,
            chunk_size: 1024,
        }
    }
}
35
36#[allow(dead_code)]
41pub fn pairwise_euclidean_blocked<F>(data: ArrayView2<F>, config: Option<SimdConfig>) -> Array1<F>
42where
43 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
44{
45 let config = config.unwrap_or_default();
46 let n_samples = data.shape()[0];
47 let _n_features = data.shape()[1];
48 let n_distances = n_samples * (n_samples - 1) / 2;
49 let mut distances = Array1::zeros(n_distances);
50
51 let caps = PlatformCapabilities::detect();
52
53 if caps.simd_available && config.cache_friendly {
54 pairwise_euclidean_blocked_simd(data, &mut distances, &config);
55 } else {
56 pairwise_euclidean_standard(data, &mut distances);
57 }
58
59 distances
60}
61
62#[allow(dead_code)]
64fn pairwise_euclidean_blocked_simd<F>(
65 data: ArrayView2<F>,
66 distances: &mut Array1<F>,
67 config: &SimdConfig,
68) where
69 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
70{
71 let n_samples = data.shape()[0];
72 let block_size = config.block_size;
73
74 let mut idx = 0;
75
76 for block_i in (0..n_samples).step_by(block_size) {
78 let end_i = (block_i + block_size).min(n_samples);
79
80 for block_j in (block_i..n_samples).step_by(block_size) {
81 let end_j = (block_j + block_size).min(n_samples);
82
83 for i in block_i..end_i {
85 let start_j = if block_i == block_j { i + 1 } else { block_j };
86
87 for j in start_j..end_j {
88 let row_i = data.row(i);
89 let row_j = data.row(j);
90
91 if config.enable_prefetch && j + 1 < end_j {
93 std::hint::spin_loop(); }
96
97 let diff = F::simd_sub(&row_i, &row_j);
98 let distance = F::simd_norm(&diff.view());
99
100 distances[idx] = distance;
101 idx += 1;
102 }
103 }
104 }
105 }
106}
107
108#[allow(dead_code)]
113pub fn pairwise_euclidean_streaming<'a, F>(
114 data_chunks: impl Iterator<Item = ArrayView2<'a, F>>,
115 chunk_size: usize,
116) -> Array1<F>
117where
118 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + 'a,
119{
120 let mut total_samples = 0;
122 let mut data_cache = Vec::new();
123
124 for chunk in data_chunks {
126 total_samples += chunk.nrows();
127 data_cache.push(chunk.to_owned());
128 }
129
130 let n_distances = total_samples * (total_samples - 1) / 2;
131 let mut distances = Array1::zeros(n_distances);
132 let mut idx = 0;
133
134 for (chunk_i, data_i) in data_cache.iter().enumerate() {
136 for (chunk_j, data_j) in data_cache.iter().enumerate().skip(chunk_i) {
137 if chunk_i == chunk_j {
138 idx += compute_intra_chunk_distances(data_i.view(), &mut distances, idx);
140 } else {
141 idx += compute_inter_chunk_distances(
143 data_i.view(),
144 data_j.view(),
145 &mut distances,
146 idx,
147 );
148 }
149 }
150 }
151
152 distances
153}
154
155#[allow(dead_code)]
157fn compute_intra_chunk_distances<F>(
158 chunk: ArrayView2<F>,
159 distances: &mut Array1<F>,
160 start_idx: usize,
161) -> usize
162where
163 F: Float + FromPrimitive + Debug + SimdUnifiedOps,
164{
165 let n_samples = chunk.nrows();
166 let mut _idx = start_idx;
167
168 for i in 0..n_samples {
169 for j in (i + 1)..n_samples {
170 let row_i = chunk.row(i);
171 let row_j = chunk.row(j);
172
173 let diff = F::simd_sub(&row_i, &row_j);
174 let distance = F::simd_norm(&diff.view());
175
176 distances[_idx] = distance;
177 _idx += 1;
178 }
179 }
180
181 _idx - start_idx
182}
183
184#[allow(dead_code)]
186fn compute_inter_chunk_distances<F>(
187 chunk_i: ArrayView2<F>,
188 chunk_j: ArrayView2<F>,
189 distances: &mut Array1<F>,
190 start_idx: usize,
191) -> usize
192where
193 F: Float + FromPrimitive + Debug + SimdUnifiedOps,
194{
195 let n_samples_i = chunk_i.nrows();
196 let n_samples_j = chunk_j.nrows();
197 let mut _idx = start_idx;
198
199 for _i in 0..n_samples_i {
200 for _j in 0..n_samples_j {
201 let row_i = chunk_i.row(_i);
202 let row_j = chunk_j.row(_j);
203
204 let diff = F::simd_sub(&row_i, &row_j);
205 let distance = F::simd_norm(&diff.view());
206
207 distances[_idx] = distance;
208 _idx += 1;
209 }
210 }
211
212 _idx - start_idx
213}
214
215#[allow(dead_code)]
225pub fn pairwise_euclidean_simd<F>(data: ArrayView2<F>) -> Array1<F>
226where
227 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
228{
229 let n_samples = data.shape()[0];
230 let n_features = data.shape()[1];
231 let n_distances = n_samples * (n_samples - 1) / 2;
232 let mut distances = Array1::zeros(n_distances);
233
234 let caps = PlatformCapabilities::detect();
235 let optimizer = AutoOptimizer::new();
236
237 if caps.simd_available && optimizer.should_use_simd(n_samples * n_features) {
238 pairwise_euclidean_simd_optimized(data, &mut distances);
239 } else {
240 pairwise_euclidean_standard(data, &mut distances);
241 }
242
243 distances
244}
245
246#[allow(dead_code)]
248fn pairwise_euclidean_standard<F>(data: ArrayView2<F>, distances: &mut Array1<F>)
249where
250 F: Float + FromPrimitive + Debug,
251{
252 let n_samples = data.shape()[0];
253 let n_features = data.shape()[1];
254
255 let mut idx = 0;
256 for i in 0..n_samples {
257 for j in (i + 1)..n_samples {
258 let mut sum_sq = F::zero();
259 for k in 0..n_features {
260 let diff = data[[i, k]] - data[[j, k]];
261 sum_sq = sum_sq + diff * diff;
262 }
263 distances[idx] = sum_sq.sqrt();
264 idx += 1;
265 }
266 }
267}
268
269#[allow(dead_code)]
271fn pairwise_euclidean_simd_optimized<F>(data: ArrayView2<F>, distances: &mut Array1<F>)
272where
273 F: Float + FromPrimitive + Debug + SimdUnifiedOps,
274{
275 let n_samples = data.shape()[0];
276
277 let mut idx = 0;
278 for i in 0..n_samples {
279 for j in (i + 1)..n_samples {
280 let row_i = data.row(i);
281 let row_j = data.row(j);
282
283 let diff = F::simd_sub(&row_i, &row_j);
285 let distance = F::simd_norm(&diff.view());
286
287 distances[idx] = distance;
288 idx += 1;
289 }
290 }
291}
292
293#[allow(dead_code)]
308pub fn distance_to_centroids_simd<F>(
309 data: ArrayView2<F>,
310 centroids: ArrayView2<F>,
311) -> Result<Array2<F>, String>
312where
313 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
314{
315 let n_samples = data.shape()[0];
316 let n_clusters = centroids.shape()[0];
317 let n_features = data.shape()[1];
318
319 if centroids.shape()[1] != n_features {
320 return Err(format!(
321 "Data and centroids must have the same number of features: data has {}, centroids have {}",
322 n_features, centroids.shape()[1]
323 ));
324 }
325
326 let mut distances = Array2::zeros((n_samples, n_clusters));
327
328 let caps = PlatformCapabilities::detect();
329 let optimizer = AutoOptimizer::new();
330
331 if caps.simd_available && optimizer.should_use_simd(n_samples * n_features) {
332 distance_to_centroids_simd_optimized(data, centroids, &mut distances);
333 } else {
334 distance_to_centroids_standard(data, centroids, &mut distances);
335 }
336
337 Ok(distances)
338}
339
340#[allow(dead_code)]
342fn distance_to_centroids_standard<F>(
343 data: ArrayView2<F>,
344 centroids: ArrayView2<F>,
345 distances: &mut Array2<F>,
346) where
347 F: Float + FromPrimitive + Debug,
348{
349 let n_samples = data.shape()[0];
350 let n_clusters = centroids.shape()[0];
351 let n_features = data.shape()[1];
352
353 for i in 0..n_samples {
354 for j in 0..n_clusters {
355 let mut sum_sq = F::zero();
356 for k in 0..n_features {
357 let diff = data[[i, k]] - centroids[[j, k]];
358 sum_sq = sum_sq + diff * diff;
359 }
360 distances[[i, j]] = sum_sq.sqrt();
361 }
362 }
363}
364
365#[allow(dead_code)]
367fn distance_to_centroids_simd_optimized<F>(
368 data: ArrayView2<F>,
369 centroids: ArrayView2<F>,
370 distances: &mut Array2<F>,
371) where
372 F: Float + FromPrimitive + Debug + SimdUnifiedOps,
373{
374 let n_samples = data.shape()[0];
375 let n_clusters = centroids.shape()[0];
376
377 for i in 0..n_samples {
378 for j in 0..n_clusters {
379 let data_row = data.row(i);
380 let centroid_row = centroids.row(j);
381
382 let diff = F::simd_sub(&data_row, ¢roid_row);
384 let distance = F::simd_norm(&diff.view());
385
386 distances[[i, j]] = distance;
387 }
388 }
389}
390
391#[allow(dead_code)]
401pub fn pairwise_euclidean_parallel<F>(data: ArrayView2<F>) -> Array1<F>
402where
403 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
404{
405 let n_samples = data.shape()[0];
406 let n_distances = n_samples * (n_samples - 1) / 2;
407
408 let mut pairs = Vec::with_capacity(n_distances);
410 for i in 0..n_samples {
411 for j in (i + 1)..n_samples {
412 pairs.push((i, j));
413 }
414 }
415
416 if is_parallel_enabled() && pairs.len() > 100 {
418 let distances: Vec<F> = pairs
420 .into_par_iter()
421 .map(|(i, j)| {
422 let row_i = data.row(i);
423 let row_j = data.row(j);
424
425 let diff = F::simd_sub(&row_i, &row_j);
427 F::simd_norm(&diff.view())
428 })
429 .collect();
430 Array1::from_vec(distances)
431 } else {
432 pairwise_euclidean_simd(data)
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    /// Corners of the unit square, one per row: (0,0), (1,0), (0,1), (1,1).
    fn unit_square() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap()
    }

    #[test]
    fn test_pairwise_euclidean_simd() {
        let data = unit_square();

        let distances = pairwise_euclidean_simd(data.view());

        let diag = 2.0_f64.sqrt();
        // Condensed order: (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
        let expected = [1.0, 1.0, diag, diag, 1.0, 1.0];
        assert_eq!(distances.len(), expected.len());
        for (k, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(distances[k], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_distance_to_centroids_simd() {
        let data = unit_square();
        let centroids = Array2::from_shape_vec((2, 2), vec![0.5, 0.0, 0.5, 1.0]).unwrap();

        let distances = distance_to_centroids_simd(data.view(), centroids.view()).unwrap();

        assert_eq!(distances.shape(), &[4, 2]);
        // (0,0) is 0.5 away from centroid (0.5,0); (1,1) is 0.5 away from centroid (0.5,1).
        assert_abs_diff_eq!(distances[[0, 0]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(distances[[3, 1]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_parallel_vs_standard() {
        let values: Vec<f64> = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, //
            2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0,
        ];
        let data = Array2::from_shape_vec((6, 3), values).unwrap();

        let distances_simd = pairwise_euclidean_simd(data.view());
        let distances_parallel = pairwise_euclidean_parallel(data.view());

        assert_eq!(distances_simd.len(), distances_parallel.len());
        for (&s, &p) in distances_simd.iter().zip(distances_parallel.iter()) {
            assert_abs_diff_eq!(s, p, epsilon = 1e-10);
        }
    }
}