scirs2_cluster/vq/
distance_simd.rs

1//! SIMD-accelerated distance computations for clustering algorithms
2//!
3//! This module provides highly optimized distance calculations using the unified
4//! SIMD operations from scirs2-core, with fallbacks to standard implementations.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::parallel_ops::*;
9use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
10use std::fmt::Debug;
11
/// Tuning parameters for the memory-aware SIMD distance routines.
#[derive(Debug, Clone)]
pub struct SimdConfig {
    /// Chunk size used for memory-efficient processing.
    pub chunk_size: usize,
    /// Whether memory prefetching is requested.
    pub enable_prefetch: bool,
    /// Whether cache-friendly (blocked) algorithms should be preferred.
    pub cache_friendly: bool,
    /// Number of rows per block in the blocked algorithms.
    pub block_size: usize,
}

impl Default for SimdConfig {
    /// Builds the stock configuration: a chunk size of 1024, a block size of
    /// 256, and both prefetching and cache-friendly processing switched on.
    fn default() -> Self {
        let (chunk_size, block_size) = (1024, 256);
        SimdConfig {
            block_size,
            chunk_size,
            cache_friendly: true,
            enable_prefetch: true,
        }
    }
}
35
36/// Memory-efficient blocked distance computation for large datasets
37///
38/// This function uses cache-friendly blocking to compute distances efficiently
39/// for datasets that don't fit in cache.
40#[allow(dead_code)]
41pub fn pairwise_euclidean_blocked<F>(data: ArrayView2<F>, config: Option<SimdConfig>) -> Array1<F>
42where
43    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
44{
45    let config = config.unwrap_or_default();
46    let n_samples = data.shape()[0];
47    let _n_features = data.shape()[1];
48    let n_distances = n_samples * (n_samples - 1) / 2;
49    let mut distances = Array1::zeros(n_distances);
50
51    let caps = PlatformCapabilities::detect();
52
53    if caps.simd_available && config.cache_friendly {
54        pairwise_euclidean_blocked_simd(data, &mut distances, &config);
55    } else {
56        pairwise_euclidean_standard(data, &mut distances);
57    }
58
59    distances
60}
61
62/// Cache-friendly blocked SIMD implementation
63#[allow(dead_code)]
64fn pairwise_euclidean_blocked_simd<F>(
65    data: ArrayView2<F>,
66    distances: &mut Array1<F>,
67    config: &SimdConfig,
68) where
69    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
70{
71    let n_samples = data.shape()[0];
72    let block_size = config.block_size;
73
74    let mut idx = 0;
75
76    // Process data in blocks to improve cache efficiency
77    for block_i in (0..n_samples).step_by(block_size) {
78        let end_i = (block_i + block_size).min(n_samples);
79
80        for block_j in (block_i..n_samples).step_by(block_size) {
81            let end_j = (block_j + block_size).min(n_samples);
82
83            // Process block [block_i..end_i) × [block_j..end_j)
84            for i in block_i..end_i {
85                let start_j = if block_i == block_j { i + 1 } else { block_j };
86
87                for j in start_j..end_j {
88                    let row_i = data.row(i);
89                    let row_j = data.row(j);
90
91                    // Use SIMD operations with prefetching if enabled
92                    if config.enable_prefetch && j + 1 < end_j {
93                        // Prefetch next row for better memory access patterns
94                        std::hint::spin_loop(); // Simplified prefetch simulation
95                    }
96
97                    let diff = F::simd_sub(&row_i, &row_j);
98                    let distance = F::simd_norm(&diff.view());
99
100                    distances[idx] = distance;
101                    idx += 1;
102                }
103            }
104        }
105    }
106}
107
108/// Streaming distance computation for out-of-core datasets
109///
110/// This function computes distances in streaming fashion, suitable for
111/// datasets that don't fit in memory.
112#[allow(dead_code)]
113pub fn pairwise_euclidean_streaming<'a, F>(
114    data_chunks: impl Iterator<Item = ArrayView2<'a, F>>,
115    chunk_size: usize,
116) -> Array1<F>
117where
118    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + 'a,
119{
120    // Remove unused variable
121    let mut total_samples = 0;
122    let mut data_cache = Vec::new();
123
124    // First pass: collect data and count samples
125    for chunk in data_chunks {
126        total_samples += chunk.nrows();
127        data_cache.push(chunk.to_owned());
128    }
129
130    let n_distances = total_samples * (total_samples - 1) / 2;
131    let mut distances = Array1::zeros(n_distances);
132    let mut idx = 0;
133
134    // Second pass: compute distances between _chunks
135    for (chunk_i, data_i) in data_cache.iter().enumerate() {
136        for (chunk_j, data_j) in data_cache.iter().enumerate().skip(chunk_i) {
137            if chunk_i == chunk_j {
138                // Intra-chunk distances
139                idx += compute_intra_chunk_distances(data_i.view(), &mut distances, idx);
140            } else {
141                // Inter-chunk distances
142                idx += compute_inter_chunk_distances(
143                    data_i.view(),
144                    data_j.view(),
145                    &mut distances,
146                    idx,
147                );
148            }
149        }
150    }
151
152    distances
153}
154
155/// Compute distances within a single chunk
156#[allow(dead_code)]
157fn compute_intra_chunk_distances<F>(
158    chunk: ArrayView2<F>,
159    distances: &mut Array1<F>,
160    start_idx: usize,
161) -> usize
162where
163    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
164{
165    let n_samples = chunk.nrows();
166    let mut _idx = start_idx;
167
168    for i in 0..n_samples {
169        for j in (i + 1)..n_samples {
170            let row_i = chunk.row(i);
171            let row_j = chunk.row(j);
172
173            let diff = F::simd_sub(&row_i, &row_j);
174            let distance = F::simd_norm(&diff.view());
175
176            distances[_idx] = distance;
177            _idx += 1;
178        }
179    }
180
181    _idx - start_idx
182}
183
184/// Compute distances between two chunks
185#[allow(dead_code)]
186fn compute_inter_chunk_distances<F>(
187    chunk_i: ArrayView2<F>,
188    chunk_j: ArrayView2<F>,
189    distances: &mut Array1<F>,
190    start_idx: usize,
191) -> usize
192where
193    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
194{
195    let n_samples_i = chunk_i.nrows();
196    let n_samples_j = chunk_j.nrows();
197    let mut _idx = start_idx;
198
199    for _i in 0..n_samples_i {
200        for _j in 0..n_samples_j {
201            let row_i = chunk_i.row(_i);
202            let row_j = chunk_j.row(_j);
203
204            let diff = F::simd_sub(&row_i, &row_j);
205            let distance = F::simd_norm(&diff.view());
206
207            distances[_idx] = distance;
208            _idx += 1;
209        }
210    }
211
212    _idx - start_idx
213}
214
215/// Compute Euclidean distances between all pairs of points using SIMD when available
216///
217/// # Arguments
218///
219/// * `data` - Input data (n_samples × n_features)
220///
221/// # Returns
222///
223/// * Condensed distance matrix as a 1D array
224#[allow(dead_code)]
225pub fn pairwise_euclidean_simd<F>(data: ArrayView2<F>) -> Array1<F>
226where
227    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
228{
229    let n_samples = data.shape()[0];
230    let n_features = data.shape()[1];
231    let n_distances = n_samples * (n_samples - 1) / 2;
232    let mut distances = Array1::zeros(n_distances);
233
234    let caps = PlatformCapabilities::detect();
235    let optimizer = AutoOptimizer::new();
236
237    if caps.simd_available && optimizer.should_use_simd(n_samples * n_features) {
238        pairwise_euclidean_simd_optimized(data, &mut distances);
239    } else {
240        pairwise_euclidean_standard(data, &mut distances);
241    }
242
243    distances
244}
245
246/// Standard pairwise Euclidean distance computation
247#[allow(dead_code)]
248fn pairwise_euclidean_standard<F>(data: ArrayView2<F>, distances: &mut Array1<F>)
249where
250    F: Float + FromPrimitive + Debug,
251{
252    let n_samples = data.shape()[0];
253    let n_features = data.shape()[1];
254
255    let mut idx = 0;
256    for i in 0..n_samples {
257        for j in (i + 1)..n_samples {
258            let mut sum_sq = F::zero();
259            for k in 0..n_features {
260                let diff = data[[i, k]] - data[[j, k]];
261                sum_sq = sum_sq + diff * diff;
262            }
263            distances[idx] = sum_sq.sqrt();
264            idx += 1;
265        }
266    }
267}
268
269/// SIMD-optimized pairwise Euclidean distance computation using unified operations
270#[allow(dead_code)]
271fn pairwise_euclidean_simd_optimized<F>(data: ArrayView2<F>, distances: &mut Array1<F>)
272where
273    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
274{
275    let n_samples = data.shape()[0];
276
277    let mut idx = 0;
278    for i in 0..n_samples {
279        for j in (i + 1)..n_samples {
280            let row_i = data.row(i);
281            let row_j = data.row(j);
282
283            // Use SIMD operations for vector subtraction and norm calculation
284            let diff = F::simd_sub(&row_i, &row_j);
285            let distance = F::simd_norm(&diff.view());
286
287            distances[idx] = distance;
288            idx += 1;
289        }
290    }
291}
292
293/// Compute distances from each point to a set of centroids using SIMD
294///
295/// # Arguments
296///
297/// * `data` - Input data (n_samples × n_features)
298/// * `centroids` - Cluster centroids (n_clusters × n_features)
299///
300/// # Returns
301///
302/// * Distance matrix (n_samples × n_clusters)
303///
304/// # Errors
305///
306/// * Returns error if data and centroids have different numbers of features
307#[allow(dead_code)]
308pub fn distance_to_centroids_simd<F>(
309    data: ArrayView2<F>,
310    centroids: ArrayView2<F>,
311) -> Result<Array2<F>, String>
312where
313    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
314{
315    let n_samples = data.shape()[0];
316    let n_clusters = centroids.shape()[0];
317    let n_features = data.shape()[1];
318
319    if centroids.shape()[1] != n_features {
320        return Err(format!(
321            "Data and centroids must have the same number of features: data has {}, centroids have {}",
322            n_features, centroids.shape()[1]
323        ));
324    }
325
326    let mut distances = Array2::zeros((n_samples, n_clusters));
327
328    let caps = PlatformCapabilities::detect();
329    let optimizer = AutoOptimizer::new();
330
331    if caps.simd_available && optimizer.should_use_simd(n_samples * n_features) {
332        distance_to_centroids_simd_optimized(data, centroids, &mut distances);
333    } else {
334        distance_to_centroids_standard(data, centroids, &mut distances);
335    }
336
337    Ok(distances)
338}
339
340/// Standard distance to centroids computation
341#[allow(dead_code)]
342fn distance_to_centroids_standard<F>(
343    data: ArrayView2<F>,
344    centroids: ArrayView2<F>,
345    distances: &mut Array2<F>,
346) where
347    F: Float + FromPrimitive + Debug,
348{
349    let n_samples = data.shape()[0];
350    let n_clusters = centroids.shape()[0];
351    let n_features = data.shape()[1];
352
353    for i in 0..n_samples {
354        for j in 0..n_clusters {
355            let mut sum_sq = F::zero();
356            for k in 0..n_features {
357                let diff = data[[i, k]] - centroids[[j, k]];
358                sum_sq = sum_sq + diff * diff;
359            }
360            distances[[i, j]] = sum_sq.sqrt();
361        }
362    }
363}
364
365/// SIMD-optimized distance to centroids computation using unified operations
366#[allow(dead_code)]
367fn distance_to_centroids_simd_optimized<F>(
368    data: ArrayView2<F>,
369    centroids: ArrayView2<F>,
370    distances: &mut Array2<F>,
371) where
372    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
373{
374    let n_samples = data.shape()[0];
375    let n_clusters = centroids.shape()[0];
376
377    for i in 0..n_samples {
378        for j in 0..n_clusters {
379            let data_row = data.row(i);
380            let centroid_row = centroids.row(j);
381
382            // Use SIMD operations for vector subtraction and norm calculation
383            let diff = F::simd_sub(&data_row, &centroid_row);
384            let distance = F::simd_norm(&diff.view());
385
386            distances[[i, j]] = distance;
387        }
388    }
389}
390
391/// Parallel distance matrix computation using core parallel operations
392///
393/// # Arguments
394///
395/// * `data` - Input data (n_samples × n_features)
396///
397/// # Returns
398///
399/// * Condensed distance matrix
400#[allow(dead_code)]
401pub fn pairwise_euclidean_parallel<F>(data: ArrayView2<F>) -> Array1<F>
402where
403    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
404{
405    let n_samples = data.shape()[0];
406    let n_distances = n_samples * (n_samples - 1) / 2;
407
408    // Create index pairs
409    let mut pairs = Vec::with_capacity(n_distances);
410    for i in 0..n_samples {
411        for j in (i + 1)..n_samples {
412            pairs.push((i, j));
413        }
414    }
415
416    // Use parallel operations from core
417    if is_parallel_enabled() && pairs.len() > 100 {
418        // Compute distances in parallel using core abstractions
419        let distances: Vec<F> = pairs
420            .into_par_iter()
421            .map(|(i, j)| {
422                let row_i = data.row(i);
423                let row_j = data.row(j);
424
425                // Use SIMD operations for distance calculation
426                let diff = F::simd_sub(&row_i, &row_j);
427                F::simd_norm(&diff.view())
428            })
429            .collect();
430        Array1::from_vec(distances)
431    } else {
432        // Fallback to SIMD version for small problems or when parallel is disabled
433        pairwise_euclidean_simd(data)
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    /// The four corners of the unit square as a 4 × 2 matrix:
    /// (0,0), (1,0), (0,1), (1,1).
    fn unit_square() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap()
    }

    #[test]
    fn test_pairwise_euclidean_simd() {
        let points = unit_square();
        let condensed = pairwise_euclidean_simd(points.view());

        let diagonal = 2.0_f64.sqrt();
        // Pair order: (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)
        let expected = [1.0, 1.0, diagonal, diagonal, 1.0, 1.0];

        assert_eq!(condensed.len(), 6);
        for (pair, want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(condensed[pair], *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_distance_to_centroids_simd() {
        let points = unit_square();
        let centers = Array2::from_shape_vec((2, 2), vec![0.5, 0.0, 0.5, 1.0]).unwrap();

        let table = distance_to_centroids_simd(points.view(), centers.view()).unwrap();

        assert_eq!(table.shape(), &[4, 2]);

        // Spot checks: point (0,0) against centroid 0 and point (1,1) against centroid 1.
        for &(sample, cluster) in &[(0_usize, 0_usize), (3, 1)] {
            assert_abs_diff_eq!(table[[sample, cluster]], 0.5, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_parallel_vs_standard() {
        let values = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
            10.0,
        ];
        let points = Array2::from_shape_vec((6, 3), values).unwrap();

        let from_simd = pairwise_euclidean_simd(points.view());
        let from_parallel = pairwise_euclidean_parallel(points.view());

        assert_eq!(from_simd.len(), from_parallel.len());

        // Both routines must agree element by element.
        for (a, b) in from_simd.iter().zip(from_parallel.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-10);
        }
    }
}
496}