scirs2_cluster/vq/
simd_optimizations.rs

1//! SIMD-optimized core clustering operations
2//!
3//! This module provides comprehensive SIMD optimizations for fundamental clustering
4//! operations including distance computations, data preprocessing, vector quantization,
5//! and centroid calculations. All functions provide automatic fallback to scalar
6//! implementations when SIMD is not available.
7
8use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
9use scirs2_core::numeric::{Float, FromPrimitive, Zero};
10use scirs2_core::parallel_ops::*;
11use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
12use std::fmt::Debug;
13
14use crate::error::{ClusteringError, Result};
15use statrs::statistics::Statistics;
16
/// Configuration for SIMD optimizations.
///
/// Controls when the functions in this module take their SIMD and/or parallel code
/// paths. Typical use is struct-update syntax on top of the default, e.g.
/// `SimdOptimizationConfig { enable_parallel: false, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimdOptimizationConfig {
    /// Minimum array length that triggers the SIMD code path
    /// (consulted by `euclidean_distance_simd`).
    pub simd_threshold: usize,
    /// Enable parallel processing for large arrays.
    pub enable_parallel: bool,
    /// Size above which work is parallelized. The entry points compare it against the
    /// number of samples (`vq_simd`, `calculate_distortion_simd`) or the number of
    /// features (`whiten_simd`).
    pub parallel_chunk_size: usize,
    /// Enable cache-friendly memory access patterns.
    /// NOTE(review): not read by any function in this module at present.
    pub cache_friendly: bool,
    /// Force SIMD usage even for small arrays (for testing).
    pub force_simd: bool,
}
31
32impl Default for SimdOptimizationConfig {
33    fn default() -> Self {
34        Self {
35            simd_threshold: 64,
36            enable_parallel: true,
37            parallel_chunk_size: 1024,
38            cache_friendly: true,
39            force_simd: false,
40        }
41    }
42}
43
44/// SIMD-optimized Euclidean distance between two vectors
45///
46/// This function uses SIMD operations when available and beneficial,
47/// automatically falling back to scalar computation for small vectors
48/// or when SIMD is not available.
49///
50/// # Arguments
51///
52/// * `x` - First vector
53/// * `y` - Second vector
54/// * `config` - Optional SIMD configuration
55///
56/// # Returns
57///
58/// * Euclidean distance between the two vectors
59///
60/// # Errors
61///
62/// * Returns error if vectors have different lengths
63#[allow(dead_code)]
64pub fn euclidean_distance_simd<F>(
65    x: ArrayView1<F>,
66    y: ArrayView1<F>,
67    config: Option<&SimdOptimizationConfig>,
68) -> Result<F>
69where
70    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
71{
72    if x.len() != y.len() {
73        return Err(ClusteringError::InvalidInput(format!(
74            "Vectors must have the same length: got {} and {}",
75            x.len(),
76            y.len()
77        )));
78    }
79
80    let default_config = SimdOptimizationConfig::default();
81    let config = config.unwrap_or(&default_config);
82    let caps = PlatformCapabilities::detect();
83    let optimizer = AutoOptimizer::new();
84
85    if (caps.simd_available && (optimizer.should_use_simd(x.len()) || config.force_simd))
86        || x.len() >= config.simd_threshold
87    {
88        let diff = F::simd_sub(&x, &y);
89        Ok(F::simd_norm(&diff.view()))
90    } else {
91        // Scalar fallback
92        let mut sum = F::zero();
93        for i in 0..x.len() {
94            let diff = x[i] - y[i];
95            sum = sum + diff * diff;
96        }
97        Ok(sum.sqrt())
98    }
99}
100
101/// SIMD-optimized data whitening (normalization by standard deviation)
102///
103/// This function normalizes data features by subtracting the mean and dividing by
104/// the standard deviation using SIMD operations for improved performance.
105///
106/// # Arguments
107///
108/// * `obs` - Input data (n_samples × n_features)
109/// * `config` - Optional SIMD configuration
110///
111/// # Returns
112///
113/// * Whitened array with the same shape as input
114#[allow(dead_code)]
115pub fn whiten_simd<F>(obs: &Array2<F>, config: Option<&SimdOptimizationConfig>) -> Result<Array2<F>>
116where
117    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
118{
119    let default_config = SimdOptimizationConfig::default();
120    let config = config.unwrap_or(&default_config);
121    let n_samples = obs.shape()[0];
122    let n_features = obs.shape()[1];
123
124    if n_samples == 0 || n_features == 0 {
125        return Err(ClusteringError::InvalidInput(
126            "Input data cannot be empty".to_string(),
127        ));
128    }
129
130    let caps = PlatformCapabilities::detect();
131    let optimizer = AutoOptimizer::new();
132    let use_simd = caps.simd_available
133        && (optimizer.should_use_simd(n_samples * n_features) || config.force_simd);
134
135    if use_simd && config.enable_parallel && n_features > config.parallel_chunk_size {
136        whiten_simd_parallel(obs, config)
137    } else if use_simd {
138        whiten_simd_sequential(obs)
139    } else {
140        whiten_scalar_fallback(obs)
141    }
142}
143
144/// SIMD-optimized sequential whitening
145#[allow(dead_code)]
146fn whiten_simd_sequential<F>(obs: &Array2<F>) -> Result<Array2<F>>
147where
148    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
149{
150    let n_samples = obs.shape()[0];
151    let n_features = obs.shape()[1];
152    let n_samples_f = F::from(n_samples).unwrap();
153
154    // Calculate means using SIMD operations
155    let mut means = Array1::<F>::zeros(n_features);
156    for j in 0..n_features {
157        let column = obs.column(j);
158        means[j] = F::simd_sum(&column) / n_samples_f;
159    }
160
161    // Calculate standard deviations using SIMD operations
162    let mut stds = Array1::<F>::zeros(n_features);
163    for j in 0..n_features {
164        let column = obs.column(j);
165        let mean_array = Array1::from_elem(n_samples, means[j]);
166        let diff = F::simd_sub(&column, &mean_array.view());
167        let squared_diff = F::simd_mul(&diff.view(), &diff.view());
168        let variance = F::simd_sum(&squared_diff.view()) / F::from(n_samples - 1).unwrap();
169        stds[j] = variance.sqrt();
170
171        // Avoid division by zero
172        if stds[j] < F::from(1e-10).unwrap() {
173            stds[j] = F::one();
174        }
175    }
176
177    // Whiten the data using SIMD operations
178    let mut whitened = Array2::<F>::zeros((n_samples, n_features));
179    for j in 0..n_features {
180        let column = obs.column(j);
181        let mean_array = Array1::from_elem(n_samples, means[j]);
182        let std_array = Array1::from_elem(n_samples, stds[j]);
183
184        let centered = F::simd_sub(&column, &mean_array.view());
185        let normalized = F::simd_div(&centered.view(), &std_array.view());
186
187        for i in 0..n_samples {
188            whitened[[i, j]] = normalized[i];
189        }
190    }
191
192    Ok(whitened)
193}
194
195/// Parallel SIMD-optimized whitening for large datasets
196#[allow(dead_code)]
197fn whiten_simd_parallel<F>(obs: &Array2<F>, config: &SimdOptimizationConfig) -> Result<Array2<F>>
198where
199    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
200{
201    let n_samples = obs.shape()[0];
202    let n_features = obs.shape()[1];
203    let n_samples_f = F::from(n_samples).unwrap();
204
205    // Parallel mean calculation
206    let means: Array1<F> = if is_parallel_enabled() {
207        (0..n_features)
208            .into_par_iter()
209            .map(|j| {
210                let column = obs.column(j);
211                F::simd_sum(&column) / n_samples_f
212            })
213            .collect::<Vec<_>>()
214            .into()
215    } else {
216        let mut means = Array1::<F>::zeros(n_features);
217        for j in 0..n_features {
218            let column = obs.column(j);
219            means[j] = F::simd_sum(&column) / n_samples_f;
220        }
221        means
222    };
223
224    // Parallel standard deviation calculation
225    let stds: Array1<F> = if is_parallel_enabled() {
226        (0..n_features)
227            .into_par_iter()
228            .map(|j| {
229                let column = obs.column(j);
230                let mean_array = Array1::from_elem(n_samples, means[j]);
231                let diff = F::simd_sub(&column, &mean_array.view());
232                let squared_diff = F::simd_mul(&diff.view(), &diff.view());
233                let variance = F::simd_sum(&squared_diff.view()) / F::from(n_samples - 1).unwrap();
234                let std = variance.sqrt();
235
236                // Avoid division by zero
237                if std < F::from(1e-10).unwrap() {
238                    F::one()
239                } else {
240                    std
241                }
242            })
243            .collect::<Vec<_>>()
244            .into()
245    } else {
246        whiten_simd_sequential(obs)?
247            .into_shape((n_samples, n_features))
248            .unwrap();
249        return whiten_simd_sequential(obs);
250    };
251
252    // Parallel whitening
253    let mut whitened = Array2::<F>::zeros((n_samples, n_features));
254
255    if is_parallel_enabled() {
256        // Process features in parallel chunks
257        let chunk_size = config.parallel_chunk_size;
258        let normalized_columns: Vec<Array1<F>> = (0..n_features)
259            .into_par_iter()
260            .map(|j| {
261                let column = obs.column(j);
262                let mean_array = Array1::from_elem(n_samples, means[j]);
263                let std_array = Array1::from_elem(n_samples, stds[j]);
264
265                let centered = F::simd_sub(&column, &mean_array.view());
266                F::simd_div(&centered.view(), &std_array.view())
267            })
268            .collect();
269
270        // Assign the normalized columns to the whitened array
271        for (j, normalized_column) in normalized_columns.iter().enumerate() {
272            for i in 0..n_samples {
273                whitened[[i, j]] = normalized_column[i];
274            }
275        }
276    } else {
277        for j in 0..n_features {
278            let column = obs.column(j);
279            let mean_array = Array1::from_elem(n_samples, means[j]);
280            let std_array = Array1::from_elem(n_samples, stds[j]);
281
282            let centered = F::simd_sub(&column, &mean_array.view());
283            let normalized = F::simd_div(&centered.view(), &std_array.view());
284
285            for i in 0..n_samples {
286                whitened[[i, j]] = normalized[i];
287            }
288        }
289    }
290
291    Ok(whitened)
292}
293
294/// Scalar fallback for whitening when SIMD is not available
295#[allow(dead_code)]
296fn whiten_scalar_fallback<F>(obs: &Array2<F>) -> Result<Array2<F>>
297where
298    F: Float + FromPrimitive + Debug,
299{
300    let n_samples = obs.shape()[0];
301    let n_features = obs.shape()[1];
302
303    // Calculate mean for each feature
304    let mut means = Array1::<F>::zeros(n_features);
305    for j in 0..n_features {
306        let mut sum = F::zero();
307        for i in 0..n_samples {
308            sum = sum + obs[[i, j]];
309        }
310        means[j] = sum / F::from(n_samples).unwrap();
311    }
312
313    // Calculate standard deviation for each feature
314    let mut stds = Array1::<F>::zeros(n_features);
315    for j in 0..n_features {
316        let mut sum = F::zero();
317        for i in 0..n_samples {
318            let diff = obs[[i, j]] - means[j];
319            sum = sum + diff * diff;
320        }
321        stds[j] = (sum / F::from(n_samples - 1).unwrap()).sqrt();
322
323        // Avoid division by zero
324        if stds[j] < F::from(1e-10).unwrap() {
325            stds[j] = F::one();
326        }
327    }
328
329    // Whiten the data
330    let mut whitened = Array2::<F>::zeros((n_samples, n_features));
331    for i in 0..n_samples {
332        for j in 0..n_features {
333            whitened[[i, j]] = (obs[[i, j]] - means[j]) / stds[j];
334        }
335    }
336
337    Ok(whitened)
338}
339
340/// SIMD-optimized vector quantization (assignment to nearest centroids)
341///
342/// This function assigns each data point to its nearest centroid using SIMD
343/// operations for distance calculations.
344///
345/// # Arguments
346///
347/// * `data` - Input data (n_samples × n_features)
348/// * `centroids` - Centroids (n_centroids × n_features)
349/// * `config` - Optional SIMD configuration
350///
351/// # Returns
352///
353/// * Tuple of (labels, distances) where labels are cluster assignments
354///   and distances are distances to the nearest centroid
355#[allow(dead_code)]
356pub fn vq_simd<F>(
357    data: ArrayView2<F>,
358    centroids: ArrayView2<F>,
359    config: Option<&SimdOptimizationConfig>,
360) -> Result<(Array1<usize>, Array1<F>)>
361where
362    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
363{
364    if data.shape()[1] != centroids.shape()[1] {
365        return Err(ClusteringError::InvalidInput(format!(
366            "Data and centroids must have the same number of features: {} vs {}",
367            data.shape()[1],
368            centroids.shape()[1]
369        )));
370    }
371
372    let default_config = SimdOptimizationConfig::default();
373    let config = config.unwrap_or(&default_config);
374    let n_samples = data.shape()[0];
375    let n_centroids = centroids.shape()[0];
376
377    if config.enable_parallel && is_parallel_enabled() && n_samples > config.parallel_chunk_size {
378        vq_simd_parallel(data, centroids, config)
379    } else {
380        vq_simd_sequential(data, centroids, config)
381    }
382}
383
384/// Sequential SIMD-optimized vector quantization
385#[allow(dead_code)]
386fn vq_simd_sequential<F>(
387    data: ArrayView2<F>,
388    centroids: ArrayView2<F>,
389    config: &SimdOptimizationConfig,
390) -> Result<(Array1<usize>, Array1<F>)>
391where
392    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
393{
394    let n_samples = data.shape()[0];
395    let n_centroids = centroids.shape()[0];
396
397    let mut labels = Array1::zeros(n_samples);
398    let mut distances = Array1::zeros(n_samples);
399
400    let caps = PlatformCapabilities::detect();
401    let use_simd = caps.simd_available || config.force_simd;
402
403    for i in 0..n_samples {
404        let point = data.slice(s![i, ..]);
405        let mut min_dist = F::infinity();
406        let mut closest_centroid = 0;
407
408        for j in 0..n_centroids {
409            let centroid = centroids.slice(s![j, ..]);
410
411            let dist = if use_simd {
412                let diff = F::simd_sub(&point, &centroid);
413                F::simd_norm(&diff.view())
414            } else {
415                // Scalar fallback
416                let mut sum = F::zero();
417                for k in 0..point.len() {
418                    let diff = point[k] - centroid[k];
419                    sum = sum + diff * diff;
420                }
421                sum.sqrt()
422            };
423
424            if dist < min_dist {
425                min_dist = dist;
426                closest_centroid = j;
427            }
428        }
429
430        labels[i] = closest_centroid;
431        distances[i] = min_dist;
432    }
433
434    Ok((labels, distances))
435}
436
437/// Parallel SIMD-optimized vector quantization
438#[allow(dead_code)]
439fn vq_simd_parallel<F>(
440    data: ArrayView2<F>,
441    centroids: ArrayView2<F>,
442    config: &SimdOptimizationConfig,
443) -> Result<(Array1<usize>, Array1<F>)>
444where
445    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
446{
447    let n_samples = data.shape()[0];
448    let n_centroids = centroids.shape()[0];
449
450    let caps = PlatformCapabilities::detect();
451    let use_simd = caps.simd_available || config.force_simd;
452
453    // Process samples in parallel
454    let results: Vec<(usize, F)> = (0..n_samples)
455        .into_par_iter()
456        .map(|i| {
457            let point = data.slice(s![i, ..]);
458            let mut min_dist = F::infinity();
459            let mut closest_centroid = 0;
460
461            for j in 0..n_centroids {
462                let centroid = centroids.slice(s![j, ..]);
463
464                let dist = if use_simd {
465                    let diff = F::simd_sub(&point, &centroid);
466                    F::simd_norm(&diff.view())
467                } else {
468                    // Scalar fallback
469                    let mut sum = F::zero();
470                    for k in 0..point.len() {
471                        let diff = point[k] - centroid[k];
472                        sum = sum + diff * diff;
473                    }
474                    sum.sqrt()
475                };
476
477                if dist < min_dist {
478                    min_dist = dist;
479                    closest_centroid = j;
480                }
481            }
482
483            (closest_centroid, min_dist)
484        })
485        .collect();
486
487    let mut labels = Array1::zeros(n_samples);
488    let mut distances = Array1::zeros(n_samples);
489
490    for (i, (label, distance)) in results.into_iter().enumerate() {
491        labels[i] = label;
492        distances[i] = distance;
493    }
494
495    Ok((labels, distances))
496}
497
498/// SIMD-optimized centroid computation for K-means
499///
500/// This function computes new centroids from data points and their cluster assignments
501/// using SIMD operations for improved performance.
502///
503/// # Arguments
504///
505/// * `data` - Input data (n_samples × n_features)
506/// * `labels` - Cluster assignments for each data point
507/// * `k` - Number of clusters
508/// * `config` - Optional SIMD configuration
509///
510/// # Returns
511///
512/// * Array of new centroids (k × n_features)
513#[allow(dead_code)]
514pub fn compute_centroids_simd<F>(
515    data: ArrayView2<F>,
516    labels: &Array1<usize>,
517    k: usize,
518    config: Option<&SimdOptimizationConfig>,
519) -> Result<Array2<F>>
520where
521    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
522{
523    let default_config = SimdOptimizationConfig::default();
524    let config = config.unwrap_or(&default_config);
525    let n_samples = data.shape()[0];
526    let n_features = data.shape()[1];
527
528    if labels.len() != n_samples {
529        return Err(ClusteringError::InvalidInput(
530            "Labels array length must match number of data points".to_string(),
531        ));
532    }
533
534    let caps = PlatformCapabilities::detect();
535    let use_simd = caps.simd_available || config.force_simd;
536
537    if config.enable_parallel && is_parallel_enabled() && k > 4 {
538        compute_centroids_simd_parallel(data, labels, k, use_simd)
539    } else {
540        compute_centroids_simd_sequential(data, labels, k, use_simd)
541    }
542}
543
544/// Sequential SIMD-optimized centroid computation
545#[allow(dead_code)]
546fn compute_centroids_simd_sequential<F>(
547    data: ArrayView2<F>,
548    labels: &Array1<usize>,
549    k: usize,
550    use_simd: bool,
551) -> Result<Array2<F>>
552where
553    F: Float + FromPrimitive + Debug + SimdUnifiedOps + std::iter::Sum,
554{
555    let n_samples = data.shape()[0];
556    let n_features = data.shape()[1];
557
558    let mut centroids = Array2::zeros((k, n_features));
559    let mut counts = Array1::<usize>::zeros(k);
560
561    // Accumulate points for each cluster
562    for i in 0..n_samples {
563        let cluster = labels[i];
564        if cluster >= k {
565            return Err(ClusteringError::InvalidInput(format!(
566                "Label {} exceeds number of clusters {}",
567                cluster, k
568            )));
569        }
570
571        counts[cluster] += 1;
572
573        if use_simd {
574            let point = data.slice(s![i, ..]);
575            let centroid_row = centroids.slice_mut(s![cluster, ..]);
576            let updated_centroid = F::simd_add(&centroid_row.view(), &point);
577            for j in 0..n_features {
578                centroids[[cluster, j]] = updated_centroid[j];
579            }
580        } else {
581            // Scalar fallback
582            for j in 0..n_features {
583                centroids[[cluster, j]] = centroids[[cluster, j]] + data[[i, j]];
584            }
585        }
586    }
587
588    // Normalize by cluster sizes and handle empty clusters
589    for i in 0..k {
590        if counts[i] == 0 {
591            // Handle empty cluster by setting to a random data point
592            if n_samples > 0 {
593                let random_idx = i % n_samples; // Simple deterministic selection
594                for j in 0..n_features {
595                    centroids[[i, j]] = data[[random_idx, j]];
596                }
597            }
598        } else {
599            let count_f = F::from(counts[i]).unwrap();
600            if use_simd {
601                let centroid_row = centroids.slice(s![i, ..]);
602                let count_array = Array1::from_elem(n_features, count_f);
603                let normalized = F::simd_div(&centroid_row, &count_array.view());
604                for j in 0..n_features {
605                    centroids[[i, j]] = normalized[j];
606                }
607            } else {
608                // Scalar fallback
609                for j in 0..n_features {
610                    centroids[[i, j]] = centroids[[i, j]] / count_f;
611                }
612            }
613        }
614    }
615
616    Ok(centroids)
617}
618
619/// Parallel SIMD-optimized centroid computation
620#[allow(dead_code)]
621fn compute_centroids_simd_parallel<F>(
622    data: ArrayView2<F>,
623    labels: &Array1<usize>,
624    k: usize,
625    use_simd: bool,
626) -> Result<Array2<F>>
627where
628    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
629{
630    let n_features = data.shape()[1];
631
632    // Process clusters in parallel
633    let centroids: Vec<Array1<F>> = (0..k)
634        .into_par_iter()
635        .map(|cluster_id| {
636            let mut sum = Array1::zeros(n_features);
637            let mut count = 0;
638
639            // Accumulate points belonging to this cluster
640            for i in 0..data.shape()[0] {
641                if labels[i] == cluster_id {
642                    count += 1;
643                    let point = data.slice(s![i, ..]);
644
645                    if use_simd {
646                        let updated_sum = F::simd_add(&sum.view(), &point);
647                        for j in 0..n_features {
648                            sum[j] = updated_sum[j];
649                        }
650                    } else {
651                        // Scalar fallback
652                        for j in 0..n_features {
653                            sum[j] = sum[j] + point[j];
654                        }
655                    }
656                }
657            }
658
659            // Normalize or handle empty cluster
660            if count == 0 {
661                // Handle empty cluster
662                if data.shape()[0] > 0 {
663                    let random_idx = cluster_id % data.shape()[0];
664                    data.slice(s![random_idx, ..]).to_owned()
665                } else {
666                    sum
667                }
668            } else {
669                let count_f = F::from(count).unwrap();
670                if use_simd {
671                    let count_array = Array1::from_elem(n_features, count_f);
672                    let normalized = F::simd_div(&sum.view(), &count_array.view());
673                    normalized
674                } else {
675                    // Scalar fallback
676                    sum.mapv(|x| x / count_f)
677                }
678            }
679        })
680        .collect();
681
682    // Convert to 2D array
683    let mut result = Array2::zeros((k, n_features));
684    for (i, centroid) in centroids.into_iter().enumerate() {
685        for j in 0..n_features {
686            result[[i, j]] = centroid[j];
687        }
688    }
689
690    Ok(result)
691}
692
693/// SIMD-optimized distortion calculation
694///
695/// Computes the sum of squared distances from data points to their assigned centroids.
696///
697/// # Arguments
698///
699/// * `data` - Input data (n_samples × n_features)
700/// * `centroids` - Cluster centroids (k × n_features)
701/// * `labels` - Cluster assignments for each data point
702/// * `config` - Optional SIMD configuration
703///
704/// # Returns
705///
706/// * Total distortion (sum of squared distances)
707#[allow(dead_code)]
708pub fn calculate_distortion_simd<F>(
709    data: ArrayView2<F>,
710    centroids: ArrayView2<F>,
711    labels: &Array1<usize>,
712    config: Option<&SimdOptimizationConfig>,
713) -> Result<F>
714where
715    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
716{
717    let default_config = SimdOptimizationConfig::default();
718    let config = config.unwrap_or(&default_config);
719    let n_samples = data.shape()[0];
720
721    if labels.len() != n_samples {
722        return Err(ClusteringError::InvalidInput(
723            "Labels array length must match number of data points".to_string(),
724        ));
725    }
726
727    let caps = PlatformCapabilities::detect();
728    let use_simd = caps.simd_available || config.force_simd;
729
730    if config.enable_parallel && is_parallel_enabled() && n_samples > config.parallel_chunk_size {
731        calculate_distortion_simd_parallel(data, centroids, labels, use_simd)
732    } else {
733        calculate_distortion_simd_sequential(data, centroids, labels, use_simd)
734    }
735}
736
737/// Sequential SIMD-optimized distortion calculation
738#[allow(dead_code)]
739fn calculate_distortion_simd_sequential<F>(
740    data: ArrayView2<F>,
741    centroids: ArrayView2<F>,
742    labels: &Array1<usize>,
743    use_simd: bool,
744) -> Result<F>
745where
746    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
747{
748    let n_samples = data.shape()[0];
749    let mut total_distortion = F::zero();
750
751    for i in 0..n_samples {
752        let cluster = labels[i];
753        if cluster >= centroids.shape()[0] {
754            return Err(ClusteringError::InvalidInput(format!(
755                "Label {} exceeds number of centroids {}",
756                cluster,
757                centroids.shape()[0]
758            )));
759        }
760
761        let point = data.slice(s![i, ..]);
762        let centroid = centroids.slice(s![cluster, ..]);
763
764        let squared_distance = if use_simd {
765            let diff = F::simd_sub(&point, &centroid);
766            let squared_diff = F::simd_mul(&diff.view(), &diff.view());
767            F::simd_sum(&squared_diff.view())
768        } else {
769            // Scalar fallback
770            let mut sum = F::zero();
771            for j in 0..point.len() {
772                let diff = point[j] - centroid[j];
773                sum = sum + diff * diff;
774            }
775            sum
776        };
777
778        total_distortion = total_distortion + squared_distance;
779    }
780
781    Ok(total_distortion)
782}
783
784/// Parallel SIMD-optimized distortion calculation
785#[allow(dead_code)]
786fn calculate_distortion_simd_parallel<F>(
787    data: ArrayView2<F>,
788    centroids: ArrayView2<F>,
789    labels: &Array1<usize>,
790    use_simd: bool,
791) -> Result<F>
792where
793    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
794{
795    let n_samples = data.shape()[0];
796
797    // Validate all labels first
798    for &label in labels.iter() {
799        if label >= centroids.shape()[0] {
800            return Err(ClusteringError::InvalidInput(format!(
801                "Label {} exceeds number of centroids {}",
802                label,
803                centroids.shape()[0]
804            )));
805        }
806    }
807
808    // Compute squared distances in parallel
809    let squared_distances: Vec<F> = (0..n_samples)
810        .into_par_iter()
811        .map(|i| {
812            let cluster = labels[i];
813            let point = data.slice(s![i, ..]);
814            let centroid = centroids.slice(s![cluster, ..]);
815
816            if use_simd {
817                let diff = F::simd_sub(&point, &centroid);
818                let squared_diff = F::simd_mul(&diff.view(), &diff.view());
819                F::simd_sum(&squared_diff.view())
820            } else {
821                // Scalar fallback
822                let mut sum = F::zero();
823                for j in 0..point.len() {
824                    let diff = point[j] - centroid[j];
825                    sum = sum + diff * diff;
826                }
827                sum
828            }
829        })
830        .collect();
831
832    Ok(squared_distances.into_iter().sum())
833}
834
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_euclidean_distance_simd() {
        let x = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0_f64, 5.0, 6.0]);

        let distance = euclidean_distance_simd(x.view(), y.view(), None).unwrap();
        // Concrete `f64` literals: calling `powi` on an unsuffixed float literal is an
        // ambiguous-numeric-type error (E0689).
        let expected =
            ((4.0_f64 - 1.0).powi(2) + (5.0_f64 - 2.0).powi(2) + (6.0_f64 - 3.0).powi(2)).sqrt();

        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_euclidean_distance_simd_length_mismatch() {
        let x = Array1::from_vec(vec![1.0_f64, 2.0]);
        let y = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]);

        assert!(euclidean_distance_simd(x.view(), y.view(), None).is_err());
    }

    #[test]
    fn test_whiten_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 1.5, 2.5, 0.5, 1.5]).unwrap();

        // Use simple config to speed up test
        let config = SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        };

        let whitened = whiten_simd(&data, Some(&config)).unwrap();

        // Check that means are approximately zero.
        // NOTE(review): `mean()` here appears to resolve to the `statrs` `Statistics`
        // trait imported at module level (the result is collected as `f64`).
        let col_means: Vec<f64> = (0..2).map(|j| whitened.column(j).mean()).collect();

        for mean in col_means {
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-8);
        }

        // Each column is [c, c + 0.5, c - 0.5]; its sample std is 0.5, so whitening
        // yields [0, 1, -1] in both columns.
        for j in 0..2 {
            assert_abs_diff_eq!(whitened[[0, j]], 0.0, epsilon = 1e-8);
            assert_abs_diff_eq!(whitened[[1, j]], 1.0, epsilon = 1e-8);
            assert_abs_diff_eq!(whitened[[2, j]], -1.0, epsilon = 1e-8);
        }
    }

    #[test]
    #[ignore = "timeout"]
    fn test_vq_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();

        let centroids = Array2::from_shape_vec((2, 2), vec![0.25, 0.25, 0.75, 0.75]).unwrap();

        // Use simple config to speed up test
        let config = SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        };

        let (labels, distances) = vq_simd(data.view(), centroids.view(), Some(&config)).unwrap();

        assert_eq!(labels.len(), 3);
        assert_eq!(distances.len(), 3);

        // (0, 0) is clearly closer to (0.25, 0.25) than to (0.75, 0.75).
        assert_eq!(labels[0], 0);

        // Check that all distances are non-negative
        for &distance in distances.iter() {
            assert!(distance >= 0.0);
        }
    }

    #[test]
    #[ignore = "timeout"]
    fn test_compute_centroids_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();

        let labels = Array1::from_vec(vec![0, 0, 1]);

        // Use simple config to speed up test
        let config = SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        };

        let centroids = compute_centroids_simd(data.view(), &labels, 2, Some(&config)).unwrap();

        assert_eq!(centroids.shape(), &[2, 2]);

        // Centroid 0 should be average of (0,0) and (1,0) = (0.5, 0)
        assert_abs_diff_eq!(centroids[[0, 0]], 0.5, epsilon = 1e-8);
        assert_abs_diff_eq!(centroids[[0, 1]], 0.0, epsilon = 1e-8);

        // Centroid 1 should be (0,1) since only one point
        assert_abs_diff_eq!(centroids[[1, 0]], 0.0, epsilon = 1e-8);
        assert_abs_diff_eq!(centroids[[1, 1]], 1.0, epsilon = 1e-8);
    }

    #[test]
    #[ignore = "timeout"]
    fn test_calculate_distortion_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();

        let centroids = Array2::from_shape_vec((2, 2), vec![0.5, 0.0, 0.0, 1.0]).unwrap();

        let labels = Array1::from_vec(vec![0, 0, 1]);

        // Use simple config to speed up test
        let config = SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        };

        let distortion =
            calculate_distortion_simd(data.view(), centroids.view(), &labels, Some(&config))
                .unwrap();

        // Calculate expected distortion manually
        let expected = 0.5 * 0.5 + 0.5 * 0.5 + 0.0; // 0.5

        assert_abs_diff_eq!(distortion, expected, epsilon = 1e-8);
    }
}