1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
9use scirs2_core::numeric::{Float, FromPrimitive, Zero};
10use scirs2_core::parallel_ops::*;
11use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
12use std::fmt::Debug;
13
14use crate::error::{ClusteringError, Result};
15use statrs::statistics::Statistics;
16
/// Tuning knobs consulted by the SIMD clustering routines in this module.
#[derive(Debug, Clone)]
pub struct SimdOptimizationConfig {
    /// Vector length at or above which `euclidean_distance_simd` takes its
    /// SIMD code path regardless of platform detection.
    pub simd_threshold: usize,
    /// Permits the public entry points to dispatch to their parallel variants.
    pub enable_parallel: bool,
    /// Size cut-off for going parallel: compared against the feature count in
    /// `whiten_simd` and against the sample count in `vq_simd` and
    /// `calculate_distortion_simd`.
    pub parallel_chunk_size: usize,
    /// Cache-friendliness hint.
    /// NOTE(review): no routine visible in this file reads this flag.
    pub cache_friendly: bool,
    /// Requests the SIMD code path instead of leaving the choice to the
    /// automatic platform/size heuristics.
    pub force_simd: bool,
}

impl Default for SimdOptimizationConfig {
    /// Defaults: SIMD threshold 64, parallelism enabled with a cut-off of
    /// 1024, cache-friendly hint on, SIMD not forced.
    fn default() -> Self {
        let simd_threshold = 64;
        let parallel_chunk_size = 1024;

        SimdOptimizationConfig {
            simd_threshold,
            enable_parallel: true,
            parallel_chunk_size,
            cache_friendly: true,
            force_simd: false,
        }
    }
}
43
44#[allow(dead_code)]
64pub fn euclidean_distance_simd<F>(
65 x: ArrayView1<F>,
66 y: ArrayView1<F>,
67 config: Option<&SimdOptimizationConfig>,
68) -> Result<F>
69where
70 F: Float + FromPrimitive + Debug + SimdUnifiedOps,
71{
72 if x.len() != y.len() {
73 return Err(ClusteringError::InvalidInput(format!(
74 "Vectors must have the same length: got {} and {}",
75 x.len(),
76 y.len()
77 )));
78 }
79
80 let default_config = SimdOptimizationConfig::default();
81 let config = config.unwrap_or(&default_config);
82 let caps = PlatformCapabilities::detect();
83 let optimizer = AutoOptimizer::new();
84
85 if (caps.simd_available && (optimizer.should_use_simd(x.len()) || config.force_simd))
86 || x.len() >= config.simd_threshold
87 {
88 let diff = F::simd_sub(&x, &y);
89 Ok(F::simd_norm(&diff.view()))
90 } else {
91 let mut sum = F::zero();
93 for i in 0..x.len() {
94 let diff = x[i] - y[i];
95 sum = sum + diff * diff;
96 }
97 Ok(sum.sqrt())
98 }
99}
100
101#[allow(dead_code)]
115pub fn whiten_simd<F>(obs: &Array2<F>, config: Option<&SimdOptimizationConfig>) -> Result<Array2<F>>
116where
117 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
118{
119 let default_config = SimdOptimizationConfig::default();
120 let config = config.unwrap_or(&default_config);
121 let n_samples = obs.shape()[0];
122 let n_features = obs.shape()[1];
123
124 if n_samples == 0 || n_features == 0 {
125 return Err(ClusteringError::InvalidInput(
126 "Input data cannot be empty".to_string(),
127 ));
128 }
129
130 let caps = PlatformCapabilities::detect();
131 let optimizer = AutoOptimizer::new();
132 let use_simd = caps.simd_available
133 && (optimizer.should_use_simd(n_samples * n_features) || config.force_simd);
134
135 if use_simd && config.enable_parallel && n_features > config.parallel_chunk_size {
136 whiten_simd_parallel(obs, config)
137 } else if use_simd {
138 whiten_simd_sequential(obs)
139 } else {
140 whiten_scalar_fallback(obs)
141 }
142}
143
144#[allow(dead_code)]
146fn whiten_simd_sequential<F>(obs: &Array2<F>) -> Result<Array2<F>>
147where
148 F: Float + FromPrimitive + Debug + SimdUnifiedOps,
149{
150 let n_samples = obs.shape()[0];
151 let n_features = obs.shape()[1];
152 let n_samples_f = F::from(n_samples).unwrap();
153
154 let mut means = Array1::<F>::zeros(n_features);
156 for j in 0..n_features {
157 let column = obs.column(j);
158 means[j] = F::simd_sum(&column) / n_samples_f;
159 }
160
161 let mut stds = Array1::<F>::zeros(n_features);
163 for j in 0..n_features {
164 let column = obs.column(j);
165 let mean_array = Array1::from_elem(n_samples, means[j]);
166 let diff = F::simd_sub(&column, &mean_array.view());
167 let squared_diff = F::simd_mul(&diff.view(), &diff.view());
168 let variance = F::simd_sum(&squared_diff.view()) / F::from(n_samples - 1).unwrap();
169 stds[j] = variance.sqrt();
170
171 if stds[j] < F::from(1e-10).unwrap() {
173 stds[j] = F::one();
174 }
175 }
176
177 let mut whitened = Array2::<F>::zeros((n_samples, n_features));
179 for j in 0..n_features {
180 let column = obs.column(j);
181 let mean_array = Array1::from_elem(n_samples, means[j]);
182 let std_array = Array1::from_elem(n_samples, stds[j]);
183
184 let centered = F::simd_sub(&column, &mean_array.view());
185 let normalized = F::simd_div(¢ered.view(), &std_array.view());
186
187 for i in 0..n_samples {
188 whitened[[i, j]] = normalized[i];
189 }
190 }
191
192 Ok(whitened)
193}
194
195#[allow(dead_code)]
197fn whiten_simd_parallel<F>(obs: &Array2<F>, config: &SimdOptimizationConfig) -> Result<Array2<F>>
198where
199 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
200{
201 let n_samples = obs.shape()[0];
202 let n_features = obs.shape()[1];
203 let n_samples_f = F::from(n_samples).unwrap();
204
205 let means: Array1<F> = if is_parallel_enabled() {
207 (0..n_features)
208 .into_par_iter()
209 .map(|j| {
210 let column = obs.column(j);
211 F::simd_sum(&column) / n_samples_f
212 })
213 .collect::<Vec<_>>()
214 .into()
215 } else {
216 let mut means = Array1::<F>::zeros(n_features);
217 for j in 0..n_features {
218 let column = obs.column(j);
219 means[j] = F::simd_sum(&column) / n_samples_f;
220 }
221 means
222 };
223
224 let stds: Array1<F> = if is_parallel_enabled() {
226 (0..n_features)
227 .into_par_iter()
228 .map(|j| {
229 let column = obs.column(j);
230 let mean_array = Array1::from_elem(n_samples, means[j]);
231 let diff = F::simd_sub(&column, &mean_array.view());
232 let squared_diff = F::simd_mul(&diff.view(), &diff.view());
233 let variance = F::simd_sum(&squared_diff.view()) / F::from(n_samples - 1).unwrap();
234 let std = variance.sqrt();
235
236 if std < F::from(1e-10).unwrap() {
238 F::one()
239 } else {
240 std
241 }
242 })
243 .collect::<Vec<_>>()
244 .into()
245 } else {
246 whiten_simd_sequential(obs)?
247 .into_shape((n_samples, n_features))
248 .unwrap();
249 return whiten_simd_sequential(obs);
250 };
251
252 let mut whitened = Array2::<F>::zeros((n_samples, n_features));
254
255 if is_parallel_enabled() {
256 let chunk_size = config.parallel_chunk_size;
258 let normalized_columns: Vec<Array1<F>> = (0..n_features)
259 .into_par_iter()
260 .map(|j| {
261 let column = obs.column(j);
262 let mean_array = Array1::from_elem(n_samples, means[j]);
263 let std_array = Array1::from_elem(n_samples, stds[j]);
264
265 let centered = F::simd_sub(&column, &mean_array.view());
266 F::simd_div(¢ered.view(), &std_array.view())
267 })
268 .collect();
269
270 for (j, normalized_column) in normalized_columns.iter().enumerate() {
272 for i in 0..n_samples {
273 whitened[[i, j]] = normalized_column[i];
274 }
275 }
276 } else {
277 for j in 0..n_features {
278 let column = obs.column(j);
279 let mean_array = Array1::from_elem(n_samples, means[j]);
280 let std_array = Array1::from_elem(n_samples, stds[j]);
281
282 let centered = F::simd_sub(&column, &mean_array.view());
283 let normalized = F::simd_div(¢ered.view(), &std_array.view());
284
285 for i in 0..n_samples {
286 whitened[[i, j]] = normalized[i];
287 }
288 }
289 }
290
291 Ok(whitened)
292}
293
294#[allow(dead_code)]
296fn whiten_scalar_fallback<F>(obs: &Array2<F>) -> Result<Array2<F>>
297where
298 F: Float + FromPrimitive + Debug,
299{
300 let n_samples = obs.shape()[0];
301 let n_features = obs.shape()[1];
302
303 let mut means = Array1::<F>::zeros(n_features);
305 for j in 0..n_features {
306 let mut sum = F::zero();
307 for i in 0..n_samples {
308 sum = sum + obs[[i, j]];
309 }
310 means[j] = sum / F::from(n_samples).unwrap();
311 }
312
313 let mut stds = Array1::<F>::zeros(n_features);
315 for j in 0..n_features {
316 let mut sum = F::zero();
317 for i in 0..n_samples {
318 let diff = obs[[i, j]] - means[j];
319 sum = sum + diff * diff;
320 }
321 stds[j] = (sum / F::from(n_samples - 1).unwrap()).sqrt();
322
323 if stds[j] < F::from(1e-10).unwrap() {
325 stds[j] = F::one();
326 }
327 }
328
329 let mut whitened = Array2::<F>::zeros((n_samples, n_features));
331 for i in 0..n_samples {
332 for j in 0..n_features {
333 whitened[[i, j]] = (obs[[i, j]] - means[j]) / stds[j];
334 }
335 }
336
337 Ok(whitened)
338}
339
340#[allow(dead_code)]
356pub fn vq_simd<F>(
357 data: ArrayView2<F>,
358 centroids: ArrayView2<F>,
359 config: Option<&SimdOptimizationConfig>,
360) -> Result<(Array1<usize>, Array1<F>)>
361where
362 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
363{
364 if data.shape()[1] != centroids.shape()[1] {
365 return Err(ClusteringError::InvalidInput(format!(
366 "Data and centroids must have the same number of features: {} vs {}",
367 data.shape()[1],
368 centroids.shape()[1]
369 )));
370 }
371
372 let default_config = SimdOptimizationConfig::default();
373 let config = config.unwrap_or(&default_config);
374 let n_samples = data.shape()[0];
375 let n_centroids = centroids.shape()[0];
376
377 if config.enable_parallel && is_parallel_enabled() && n_samples > config.parallel_chunk_size {
378 vq_simd_parallel(data, centroids, config)
379 } else {
380 vq_simd_sequential(data, centroids, config)
381 }
382}
383
384#[allow(dead_code)]
386fn vq_simd_sequential<F>(
387 data: ArrayView2<F>,
388 centroids: ArrayView2<F>,
389 config: &SimdOptimizationConfig,
390) -> Result<(Array1<usize>, Array1<F>)>
391where
392 F: Float + FromPrimitive + Debug + SimdUnifiedOps,
393{
394 let n_samples = data.shape()[0];
395 let n_centroids = centroids.shape()[0];
396
397 let mut labels = Array1::zeros(n_samples);
398 let mut distances = Array1::zeros(n_samples);
399
400 let caps = PlatformCapabilities::detect();
401 let use_simd = caps.simd_available || config.force_simd;
402
403 for i in 0..n_samples {
404 let point = data.slice(s![i, ..]);
405 let mut min_dist = F::infinity();
406 let mut closest_centroid = 0;
407
408 for j in 0..n_centroids {
409 let centroid = centroids.slice(s![j, ..]);
410
411 let dist = if use_simd {
412 let diff = F::simd_sub(&point, ¢roid);
413 F::simd_norm(&diff.view())
414 } else {
415 let mut sum = F::zero();
417 for k in 0..point.len() {
418 let diff = point[k] - centroid[k];
419 sum = sum + diff * diff;
420 }
421 sum.sqrt()
422 };
423
424 if dist < min_dist {
425 min_dist = dist;
426 closest_centroid = j;
427 }
428 }
429
430 labels[i] = closest_centroid;
431 distances[i] = min_dist;
432 }
433
434 Ok((labels, distances))
435}
436
437#[allow(dead_code)]
439fn vq_simd_parallel<F>(
440 data: ArrayView2<F>,
441 centroids: ArrayView2<F>,
442 config: &SimdOptimizationConfig,
443) -> Result<(Array1<usize>, Array1<F>)>
444where
445 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
446{
447 let n_samples = data.shape()[0];
448 let n_centroids = centroids.shape()[0];
449
450 let caps = PlatformCapabilities::detect();
451 let use_simd = caps.simd_available || config.force_simd;
452
453 let results: Vec<(usize, F)> = (0..n_samples)
455 .into_par_iter()
456 .map(|i| {
457 let point = data.slice(s![i, ..]);
458 let mut min_dist = F::infinity();
459 let mut closest_centroid = 0;
460
461 for j in 0..n_centroids {
462 let centroid = centroids.slice(s![j, ..]);
463
464 let dist = if use_simd {
465 let diff = F::simd_sub(&point, ¢roid);
466 F::simd_norm(&diff.view())
467 } else {
468 let mut sum = F::zero();
470 for k in 0..point.len() {
471 let diff = point[k] - centroid[k];
472 sum = sum + diff * diff;
473 }
474 sum.sqrt()
475 };
476
477 if dist < min_dist {
478 min_dist = dist;
479 closest_centroid = j;
480 }
481 }
482
483 (closest_centroid, min_dist)
484 })
485 .collect();
486
487 let mut labels = Array1::zeros(n_samples);
488 let mut distances = Array1::zeros(n_samples);
489
490 for (i, (label, distance)) in results.into_iter().enumerate() {
491 labels[i] = label;
492 distances[i] = distance;
493 }
494
495 Ok((labels, distances))
496}
497
498#[allow(dead_code)]
514pub fn compute_centroids_simd<F>(
515 data: ArrayView2<F>,
516 labels: &Array1<usize>,
517 k: usize,
518 config: Option<&SimdOptimizationConfig>,
519) -> Result<Array2<F>>
520where
521 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
522{
523 let default_config = SimdOptimizationConfig::default();
524 let config = config.unwrap_or(&default_config);
525 let n_samples = data.shape()[0];
526 let n_features = data.shape()[1];
527
528 if labels.len() != n_samples {
529 return Err(ClusteringError::InvalidInput(
530 "Labels array length must match number of data points".to_string(),
531 ));
532 }
533
534 let caps = PlatformCapabilities::detect();
535 let use_simd = caps.simd_available || config.force_simd;
536
537 if config.enable_parallel && is_parallel_enabled() && k > 4 {
538 compute_centroids_simd_parallel(data, labels, k, use_simd)
539 } else {
540 compute_centroids_simd_sequential(data, labels, k, use_simd)
541 }
542}
543
544#[allow(dead_code)]
546fn compute_centroids_simd_sequential<F>(
547 data: ArrayView2<F>,
548 labels: &Array1<usize>,
549 k: usize,
550 use_simd: bool,
551) -> Result<Array2<F>>
552where
553 F: Float + FromPrimitive + Debug + SimdUnifiedOps + std::iter::Sum,
554{
555 let n_samples = data.shape()[0];
556 let n_features = data.shape()[1];
557
558 let mut centroids = Array2::zeros((k, n_features));
559 let mut counts = Array1::<usize>::zeros(k);
560
561 for i in 0..n_samples {
563 let cluster = labels[i];
564 if cluster >= k {
565 return Err(ClusteringError::InvalidInput(format!(
566 "Label {} exceeds number of clusters {}",
567 cluster, k
568 )));
569 }
570
571 counts[cluster] += 1;
572
573 if use_simd {
574 let point = data.slice(s![i, ..]);
575 let centroid_row = centroids.slice_mut(s![cluster, ..]);
576 let updated_centroid = F::simd_add(¢roid_row.view(), &point);
577 for j in 0..n_features {
578 centroids[[cluster, j]] = updated_centroid[j];
579 }
580 } else {
581 for j in 0..n_features {
583 centroids[[cluster, j]] = centroids[[cluster, j]] + data[[i, j]];
584 }
585 }
586 }
587
588 for i in 0..k {
590 if counts[i] == 0 {
591 if n_samples > 0 {
593 let random_idx = i % n_samples; for j in 0..n_features {
595 centroids[[i, j]] = data[[random_idx, j]];
596 }
597 }
598 } else {
599 let count_f = F::from(counts[i]).unwrap();
600 if use_simd {
601 let centroid_row = centroids.slice(s![i, ..]);
602 let count_array = Array1::from_elem(n_features, count_f);
603 let normalized = F::simd_div(¢roid_row, &count_array.view());
604 for j in 0..n_features {
605 centroids[[i, j]] = normalized[j];
606 }
607 } else {
608 for j in 0..n_features {
610 centroids[[i, j]] = centroids[[i, j]] / count_f;
611 }
612 }
613 }
614 }
615
616 Ok(centroids)
617}
618
619#[allow(dead_code)]
621fn compute_centroids_simd_parallel<F>(
622 data: ArrayView2<F>,
623 labels: &Array1<usize>,
624 k: usize,
625 use_simd: bool,
626) -> Result<Array2<F>>
627where
628 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
629{
630 let n_features = data.shape()[1];
631
632 let centroids: Vec<Array1<F>> = (0..k)
634 .into_par_iter()
635 .map(|cluster_id| {
636 let mut sum = Array1::zeros(n_features);
637 let mut count = 0;
638
639 for i in 0..data.shape()[0] {
641 if labels[i] == cluster_id {
642 count += 1;
643 let point = data.slice(s![i, ..]);
644
645 if use_simd {
646 let updated_sum = F::simd_add(&sum.view(), &point);
647 for j in 0..n_features {
648 sum[j] = updated_sum[j];
649 }
650 } else {
651 for j in 0..n_features {
653 sum[j] = sum[j] + point[j];
654 }
655 }
656 }
657 }
658
659 if count == 0 {
661 if data.shape()[0] > 0 {
663 let random_idx = cluster_id % data.shape()[0];
664 data.slice(s![random_idx, ..]).to_owned()
665 } else {
666 sum
667 }
668 } else {
669 let count_f = F::from(count).unwrap();
670 if use_simd {
671 let count_array = Array1::from_elem(n_features, count_f);
672 let normalized = F::simd_div(&sum.view(), &count_array.view());
673 normalized
674 } else {
675 sum.mapv(|x| x / count_f)
677 }
678 }
679 })
680 .collect();
681
682 let mut result = Array2::zeros((k, n_features));
684 for (i, centroid) in centroids.into_iter().enumerate() {
685 for j in 0..n_features {
686 result[[i, j]] = centroid[j];
687 }
688 }
689
690 Ok(result)
691}
692
693#[allow(dead_code)]
708pub fn calculate_distortion_simd<F>(
709 data: ArrayView2<F>,
710 centroids: ArrayView2<F>,
711 labels: &Array1<usize>,
712 config: Option<&SimdOptimizationConfig>,
713) -> Result<F>
714where
715 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
716{
717 let default_config = SimdOptimizationConfig::default();
718 let config = config.unwrap_or(&default_config);
719 let n_samples = data.shape()[0];
720
721 if labels.len() != n_samples {
722 return Err(ClusteringError::InvalidInput(
723 "Labels array length must match number of data points".to_string(),
724 ));
725 }
726
727 let caps = PlatformCapabilities::detect();
728 let use_simd = caps.simd_available || config.force_simd;
729
730 if config.enable_parallel && is_parallel_enabled() && n_samples > config.parallel_chunk_size {
731 calculate_distortion_simd_parallel(data, centroids, labels, use_simd)
732 } else {
733 calculate_distortion_simd_sequential(data, centroids, labels, use_simd)
734 }
735}
736
737#[allow(dead_code)]
739fn calculate_distortion_simd_sequential<F>(
740 data: ArrayView2<F>,
741 centroids: ArrayView2<F>,
742 labels: &Array1<usize>,
743 use_simd: bool,
744) -> Result<F>
745where
746 F: Float + FromPrimitive + Debug + SimdUnifiedOps,
747{
748 let n_samples = data.shape()[0];
749 let mut total_distortion = F::zero();
750
751 for i in 0..n_samples {
752 let cluster = labels[i];
753 if cluster >= centroids.shape()[0] {
754 return Err(ClusteringError::InvalidInput(format!(
755 "Label {} exceeds number of centroids {}",
756 cluster,
757 centroids.shape()[0]
758 )));
759 }
760
761 let point = data.slice(s![i, ..]);
762 let centroid = centroids.slice(s![cluster, ..]);
763
764 let squared_distance = if use_simd {
765 let diff = F::simd_sub(&point, ¢roid);
766 let squared_diff = F::simd_mul(&diff.view(), &diff.view());
767 F::simd_sum(&squared_diff.view())
768 } else {
769 let mut sum = F::zero();
771 for j in 0..point.len() {
772 let diff = point[j] - centroid[j];
773 sum = sum + diff * diff;
774 }
775 sum
776 };
777
778 total_distortion = total_distortion + squared_distance;
779 }
780
781 Ok(total_distortion)
782}
783
784#[allow(dead_code)]
786fn calculate_distortion_simd_parallel<F>(
787 data: ArrayView2<F>,
788 centroids: ArrayView2<F>,
789 labels: &Array1<usize>,
790 use_simd: bool,
791) -> Result<F>
792where
793 F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
794{
795 let n_samples = data.shape()[0];
796
797 for &label in labels.iter() {
799 if label >= centroids.shape()[0] {
800 return Err(ClusteringError::InvalidInput(format!(
801 "Label {} exceeds number of centroids {}",
802 label,
803 centroids.shape()[0]
804 )));
805 }
806 }
807
808 let squared_distances: Vec<F> = (0..n_samples)
810 .into_par_iter()
811 .map(|i| {
812 let cluster = labels[i];
813 let point = data.slice(s![i, ..]);
814 let centroid = centroids.slice(s![cluster, ..]);
815
816 if use_simd {
817 let diff = F::simd_sub(&point, ¢roid);
818 let squared_diff = F::simd_mul(&diff.view(), &diff.view());
819 F::simd_sum(&squared_diff.view())
820 } else {
821 let mut sum = F::zero();
823 for j in 0..point.len() {
824 let diff = point[j] - centroid[j];
825 sum = sum + diff * diff;
826 }
827 sum
828 }
829 })
830 .collect();
831
832 Ok(squared_distances.into_iter().sum())
833}
834
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    /// Configuration shared by the tests below: parallel dispatch disabled and
    /// SIMD not forced, everything else at its default.
    fn sequential_config() -> SimdOptimizationConfig {
        SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        }
    }

    /// Three unit-axis-aligned points used by the vq / centroid / distortion tests.
    fn corner_points() -> Array2<f64> {
        Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap()
    }

    #[test]
    fn test_euclidean_distance_simd() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let expected = ((4.0 - 1.0).powi(2) + (5.0 - 2.0).powi(2) + (6.0 - 3.0).powi(2)).sqrt();
        let distance = euclidean_distance_simd(x.view(), y.view(), None).unwrap();

        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_whiten_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 1.5, 2.5, 0.5, 1.5]).unwrap();
        let config = sequential_config();

        let whitened = whiten_simd(&data, Some(&config)).unwrap();

        // After whitening every column must be centered on zero.
        let col_means: Vec<f64> = (0..2).map(|j| whitened.column(j).mean()).collect();
        for mean in col_means {
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-8);
        }
    }

    #[test]
    #[ignore = "timeout"]
    fn test_vq_simd() {
        let data = corner_points();
        let centroids = Array2::from_shape_vec((2, 2), vec![0.25, 0.25, 0.75, 0.75]).unwrap();
        let config = sequential_config();

        let (labels, distances) = vq_simd(data.view(), centroids.view(), Some(&config)).unwrap();

        assert_eq!(labels.len(), 3);
        assert_eq!(distances.len(), 3);

        // Distances can never be negative.
        for &distance in distances.iter() {
            assert!(distance >= 0.0);
        }
    }

    #[test]
    #[ignore = "timeout"]
    fn test_compute_centroids_simd() {
        let data = corner_points();
        let labels = Array1::from_vec(vec![0, 0, 1]);
        let config = sequential_config();

        let centroids = compute_centroids_simd(data.view(), &labels, 2, Some(&config)).unwrap();

        assert_eq!(centroids.shape(), &[2, 2]);

        // Cluster 0 averages (0,0) and (1,0); cluster 1 is the single point (0,1).
        let expected = [(0, 0, 0.5), (0, 1, 0.0), (1, 0, 0.0), (1, 1, 1.0)];
        for &(row, col, want) in expected.iter() {
            assert_abs_diff_eq!(centroids[[row, col]], want, epsilon = 1e-8);
        }
    }

    #[test]
    #[ignore = "timeout"]
    fn test_calculate_distortion_simd() {
        let data = corner_points();
        let centroids = Array2::from_shape_vec((2, 2), vec![0.5, 0.0, 0.0, 1.0]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1]);
        let config = sequential_config();

        let distortion =
            calculate_distortion_simd(data.view(), centroids.view(), &labels, Some(&config))
                .unwrap();

        // Two points sit 0.5 away from centroid 0; the third coincides with centroid 1.
        let expected = 0.5 * 0.5 + 0.5 * 0.5 + 0.0;
        assert_abs_diff_eq!(distortion, expected, epsilon = 1e-8);
    }
}