1use crate::error::{SpatialError, SpatialResult};
57use scirs2_core::ndarray::ArrayStatCompat;
58use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
59use scirs2_core::random::Rng;
60use statrs::statistics::Statistics;
61use std::collections::{HashMap, VecDeque};
62use std::sync::{Arc, Mutex};
63use std::time::{Duration, Instant};
64
/// Arithmetic precision selected for the distance kernels.
///
/// `TensorCoreDistanceMatrix::compute_tile_tensor_cores` dispatches on this value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PrecisionMode {
    /// Full-precision path.
    Full32,
    /// Reduced-precision path (inputs converted to `f32` in this module).
    Mixed16,
    /// Brain-float path (currently shares the `f32` conversion with `Mixed16`).
    BrainFloat16,
    /// Int8 quantization with a per-matrix scale derived from the max absolute value.
    Int8Dynamic,
    /// 4-bit style quantization to the range `[-7, 7]`.
    Int4Advanced,
    /// Chooses a concrete precision from the data at run time.
    Adaptive,
    /// Adaptive variant; dispatched the same way as `Adaptive` in this module.
    AdvancedAdaptive,
}
83
/// Qualitative rating of numerical stability.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum StabilityLevel {
    /// Best rating.
    Excellent,
    /// Good stability.
    Good,
    /// Moderate stability.
    Moderate,
    /// Poor stability.
    Poor,
    /// Worst rating.
    Critical,
}
98
/// Category of numerical failure, usable as a hash-map key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NumericalErrorType {
    /// Value exceeded the representable range.
    Overflow,
    /// Value fell below the representable magnitude.
    Underflow,
    /// Significant digits were lost.
    PrecisionLoss,
    /// An iterative procedure did not converge.
    ConvergenceFailure,
    /// The problem is ill-conditioned.
    IllConditioned,
    /// NaN or otherwise invalid values were produced.
    InvalidValues,
}
115
/// How aggressively dynamic precision scaling reacts.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ScalingStrategy {
    /// Least aggressive setting.
    Conservative,
    /// Middle setting.
    Balanced,
    /// Most aggressive setting.
    Aggressive,
    /// User-defined behavior.
    Custom,
}
128
/// Memory layout applied to the input points before distance computation.
///
/// Consumed by `TensorCoreDistanceMatrix::optimize_tensor_layout`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TensorLayout {
    /// Plain row-major copy.
    RowMajor,
    /// Column-major copy with the same logical shape.
    ColMajor,
    /// Copy staged through a 64-aligned padded buffer.
    Blocked,
    /// Rows reordered by an index-derived bit-interleaving key.
    ZOrder,
    /// Layout chosen from the detected GPU architecture.
    HardwareOptimized,
}
143
/// GPU architecture family reported by capability detection.
///
/// Used by `create_hardware_optimized_layout` to pick a padding granularity.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GpuArchitecture {
    /// Volta generation.
    Volta,
    /// Ampere generation (NVIDIA layout path).
    Ampere,
    /// Hopper generation (NVIDIA layout path).
    Hopper,
    /// CDNA2 generation (AMD layout path).
    CDNA2,
    /// CDNA3 generation (AMD layout path).
    CDNA3,
    /// Xe HPC (Intel layout path).
    XeHPC,
    /// Xe graphics (Intel layout path).
    XeGraphics,
    /// Architecture could not be identified.
    Unknown,
}
164
165#[derive(Debug, Clone)]
167pub struct TensorCoreCapabilities {
168 pub tensor_core_types: Vec<TensorCoreType>,
170 pub supported_precisions: Vec<PrecisionMode>,
172 pub max_tensor_size: (usize, usize, usize),
174 pub peak_throughput_tops: f64,
176 pub memory_bandwidth_gbps: f64,
178 pub l2_cache_mb: f64,
180 pub num_sms: usize,
182 pub architecture: GpuArchitecture,
184}
185
/// Vendor family of the matrix-multiply units.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TensorCoreType {
    /// NVIDIA tensor cores.
    NvidiaTensorCore,
    /// AMD matrix cores.
    AmdMatrixCore,
    /// Intel XMX units.
    IntelXMX,
    /// No dedicated matrix units; regular cores only.
    StandardCores,
}
198
199#[derive(Debug, Clone)]
201pub struct StabilityMetrics {
202 pub condition_number: f64,
204 pub relative_error: f64,
206 pub forward_error: f64,
208 pub backward_error: f64,
210 pub digit_loss: f64,
212 pub stability_level: StabilityLevel,
214 pub error_types: Vec<NumericalErrorType>,
216 pub timestamp: Instant,
218}
219
220#[derive(Debug, Clone)]
222pub struct DynamicPrecisionConfig {
223 pub strategy: ScalingStrategy,
225 pub min_precision: PrecisionMode,
227 pub max_precision: PrecisionMode,
229 pub stability_threshold_up: f64,
231 pub stability_threshold_down: f64,
233 pub performance_weight: f64,
235 pub accuracy_weight: f64,
237 pub max_changes_per_operation: usize,
239 pub change_cooldown: Duration,
241}
242
243#[allow(dead_code)]
245#[derive(Debug)]
246pub struct NumericalStabilityMonitor {
247 current_metrics: StabilityMetrics,
249 stability_history: VecDeque<StabilityMetrics>,
251 precision_config: DynamicPrecisionConfig,
253 current_precision: PrecisionMode,
255 precision_history: VecDeque<(Instant, PrecisionMode, f64)>,
257 #[allow(dead_code)]
259 recovery_attempts: usize,
260 max_history_length: usize,
262 last_precision_change: Option<Instant>,
264}
265
266#[allow(dead_code)]
268#[derive(Debug)]
269pub struct ErrorRecoverySystem {
270 recovery_strategies: HashMap<NumericalErrorType, Vec<RecoveryAction>>,
272 recovery_history: VecDeque<RecoveryAttempt>,
274 max_recovery_attempts: usize,
276 success_rates: HashMap<RecoveryAction, f64>,
278}
279
/// Remedial action that may be taken after a numerical error.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RecoveryAction {
    /// Move to a higher precision mode.
    IncreasePrecision,
    /// Use smaller tiles.
    ReduceTileSize,
    /// Switch to an alternative algorithm.
    FallbackAlgorithm,
    /// Apply a stabilization technique.
    NumericalStabilization,
    /// Retry with different parameters.
    RetryWithNewParams,
    /// Run on the CPU instead.
    SwitchToCPU,
}
296
297#[derive(Debug, Clone)]
299pub struct RecoveryAttempt {
300 pub error_type: NumericalErrorType,
302 pub action: RecoveryAction,
304 pub success: bool,
306 pub duration: Duration,
308 pub post_recovery_metrics: Option<StabilityMetrics>,
310 pub timestamp: Instant,
312}
313
314#[derive(Debug)]
316pub struct PerformanceAccuracyAnalyzer {
317 performance_data: HashMap<PrecisionMode, VecDeque<Duration>>,
319 accuracy_data: HashMap<PrecisionMode, VecDeque<f64>>,
321 optimization_params: TradeOffParams,
323 pareto_frontier: Vec<(f64, f64, PrecisionMode)>, }
326
327#[derive(Debug, Clone)]
329pub struct TradeOffParams {
330 pub performance_weight: f64,
332 pub accuracy_weight: f64,
334 pub energy_weight: f64,
336 pub min_accuracy: f64,
338 pub max_time: Duration,
340 pub objective: OptimizationObjective,
342}
343
/// Goal used when weighing performance against accuracy.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum OptimizationObjective {
    /// Prefer speed.
    MaxPerformance,
    /// Prefer accuracy.
    MaxAccuracy,
    /// Balance speed and accuracy.
    Balanced,
    /// Prefer low energy use.
    MinEnergy,
    /// User-defined weighting.
    Custom,
}
358
359#[derive(Debug)]
361pub struct AdvancedTensorCoreDistanceMatrix {
362 base_computer: TensorCoreDistanceMatrix,
364 stability_monitor: Arc<Mutex<NumericalStabilityMonitor>>,
366 recovery_system: ErrorRecoverySystem,
368 performance_analyzer: PerformanceAccuracyAnalyzer,
370 dynamic_precision_enabled: bool,
372 auto_recovery_enabled: bool,
374}
375
376#[derive(Debug, Clone)]
378pub struct TensorCoreDistanceMatrix {
379 precision_mode: PrecisionMode,
381 layout_optimization: bool,
383 hierarchical_tiling: bool,
385 tile_size: (usize, usize),
387 capabilities: Option<TensorCoreCapabilities>,
389 tensor_layout: TensorLayout,
391 execution_streams: usize,
393}
394
395impl TensorCoreDistanceMatrix {
396 pub fn new() -> SpatialResult<Self> {
398 let capabilities = detect_tensor_core_capabilities()?;
399
400 Ok(Self {
401 precision_mode: PrecisionMode::Mixed16,
402 layout_optimization: true,
403 hierarchical_tiling: true,
404 tile_size: (256, 256),
405 capabilities: Some(capabilities),
406 tensor_layout: TensorLayout::HardwareOptimized,
407 execution_streams: 4,
408 })
409 }
410
411 pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
413 self.precision_mode = mode;
414 self
415 }
416
417 pub fn with_tensor_layout_optimization(mut self, enabled: bool) -> Self {
419 self.layout_optimization = enabled;
420 self
421 }
422
423 pub fn with_hierarchical_tiling(mut self, enabled: bool) -> Self {
425 self.hierarchical_tiling = enabled;
426 self
427 }
428
429 pub fn with_tile_size(mut self, rows: usize, cols: usize) -> Self {
431 self.tile_size = (rows, cols);
432 self
433 }
434
435 pub fn with_execution_streams(mut self, streams: usize) -> Self {
437 self.execution_streams = streams;
438 self
439 }
440
441 pub async fn compute_parallel(
443 &mut self,
444 points: &ArrayView2<'_, f64>,
445 ) -> SpatialResult<Array2<f64>> {
446 let (npoints, ndims) = points.dim();
447
448 if npoints == 0 || ndims == 0 {
449 return Err(SpatialError::InvalidInput("Empty input data".to_string()));
450 }
451
452 let optimizedpoints = if self.layout_optimization {
454 self.optimize_tensor_layout(points)?
455 } else {
456 points.to_owned()
457 };
458
459 if self.hierarchical_tiling && npoints > 1024 {
461 self.compute_hierarchical_tiled(&optimizedpoints.view())
462 .await
463 } else {
464 self.compute_direct_tensor_cores(&optimizedpoints.view())
465 .await
466 }
467 }
468
469 fn optimize_tensor_layout(
471 &mut self,
472 points: &ArrayView2<'_, f64>,
473 ) -> SpatialResult<Array2<f64>> {
474 let (npoints, ndims) = points.dim();
475
476 match self.tensor_layout {
477 TensorLayout::RowMajor => Ok(points.to_owned()),
478 TensorLayout::ColMajor => {
479 let mut transposed = Array2::zeros((ndims, npoints));
480 for (i, point) in points.outer_iter().enumerate() {
481 transposed.column_mut(i).assign(&point);
482 }
483 Ok(transposed.t().to_owned())
484 }
485 TensorLayout::Blocked => TensorCoreDistanceMatrix::create_blocked_layout(points),
486 TensorLayout::ZOrder => self.create_zorder_layout(points),
487 TensorLayout::HardwareOptimized => self.create_hardware_optimized_layout(points),
488 }
489 }
490
491 fn create_blocked_layout(points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
493 let (npoints, ndims) = points.dim();
494 let block_size = 64; let blocked_rows = npoints.div_ceil(block_size) * block_size;
497 let blocked_cols = ndims.div_ceil(block_size) * block_size;
498
499 let mut blocked_data = Array2::zeros((blocked_rows, blocked_cols));
500
501 for block_i in 0..(npoints / block_size + 1) {
502 for block_j in 0..(ndims / block_size + 1) {
503 let start_i = block_i * block_size;
504 let start_j = block_j * block_size;
505 let end_i = (start_i + block_size).min(npoints);
506 let end_j = (start_j + block_size).min(ndims);
507
508 for i in start_i..end_i {
509 for j in start_j..end_j {
510 blocked_data[[i, j]] = points[[i, j]];
511 }
512 }
513 }
514 }
515
516 Ok(blocked_data.slice(s![..npoints, ..ndims]).to_owned())
517 }
518
519 fn create_zorder_layout(&mut self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
521 let (npoints, ndims) = points.dim();
522
523 let mut z_indices: Vec<(usize, usize)> = (0..npoints)
525 .map(|i| {
526 (
527 i,
528 TensorCoreDistanceMatrix::calculate_z_order_index(i, ndims),
529 )
530 })
531 .collect();
532
533 z_indices.sort_by_key(|(_, z_idx)| *z_idx);
534
535 let mut reordered_data = Array2::zeros((npoints, ndims));
536 for (new_idx, (old_idx, z_idx)) in z_indices.iter().enumerate() {
537 reordered_data
538 .row_mut(new_idx)
539 .assign(&points.row(*old_idx));
540 }
541
542 Ok(reordered_data)
543 }
544
/// Builds a bit-interleaving key from the low 16 bits of `point_idx`.
///
/// Each set bit `b` of `point_idx` is replicated into `lanes = min(ndims, 3)` adjacent key
/// positions starting at `b * lanes`.
///
/// The stride is `lanes`, matching the number of lanes written. The previous stride of
/// `ndims` overflowed the shift for `ndims >= 5`. Positions that would not fit in a `usize`
/// are skipped rather than overflowing.
fn calculate_z_order_index(point_idx: usize, ndims: usize) -> usize {
    let lanes = ndims.min(3);
    let word_bits = usize::BITS as usize;
    let mut z_index = 0usize;

    for bit in 0..16usize {
        if point_idx & (1usize << bit) == 0 {
            continue;
        }
        for lane in 0..lanes {
            let pos = bit * lanes + lane;
            if pos < word_bits {
                z_index |= 1usize << pos;
            }
        }
    }

    z_index
}
563
564 fn create_hardware_optimized_layout(
566 &self,
567 points: &ArrayView2<'_, f64>,
568 ) -> SpatialResult<Array2<f64>> {
569 if let Some(ref capabilities) = self.capabilities {
570 match capabilities.architecture {
571 GpuArchitecture::Ampere | GpuArchitecture::Hopper => {
572 self.create_nvidia_optimized_layout(points)
574 }
575 GpuArchitecture::CDNA2 | GpuArchitecture::CDNA3 => {
576 self.create_amd_optimized_layout(points)
578 }
579 GpuArchitecture::XeHPC | GpuArchitecture::XeGraphics => {
580 self.create_intel_optimized_layout(points)
582 }
583 _ => {
584 TensorCoreDistanceMatrix::create_blocked_layout(points)
586 }
587 }
588 } else {
589 TensorCoreDistanceMatrix::create_blocked_layout(points)
590 }
591 }
592
593 fn create_nvidia_optimized_layout(
595 &self,
596 points: &ArrayView2<'_, f64>,
597 ) -> SpatialResult<Array2<f64>> {
598 let (npoints, ndims) = points.dim();
599
600 let paddedpoints = npoints.div_ceil(8) * 8;
602 let padded_dims = ndims.div_ceil(8) * 8;
603
604 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
605
606 for i in 0..npoints {
608 for j in 0..ndims {
609 padded_data[[i, j]] = points[[i, j]];
610 }
611 }
612
613 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
615 }
616
617 fn create_amd_optimized_layout(
619 &self,
620 points: &ArrayView2<'_, f64>,
621 ) -> SpatialResult<Array2<f64>> {
622 let (npoints, ndims) = points.dim();
623
624 let paddedpoints = npoints.div_ceil(16) * 16;
626 let padded_dims = ndims.div_ceil(16) * 16;
627
628 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
629
630 for i in 0..npoints {
631 for j in 0..ndims {
632 padded_data[[i, j]] = points[[i, j]];
633 }
634 }
635
636 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
637 }
638
639 fn create_intel_optimized_layout(
641 &self,
642 points: &ArrayView2<'_, f64>,
643 ) -> SpatialResult<Array2<f64>> {
644 let (npoints, ndims) = points.dim();
645
646 let paddedpoints = npoints.div_ceil(32) * 32;
648 let padded_dims = ndims.div_ceil(32) * 32;
649
650 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
651
652 for i in 0..npoints {
653 for j in 0..ndims {
654 padded_data[[i, j]] = points[[i, j]];
655 }
656 }
657
658 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
659 }
660
661 async fn compute_hierarchical_tiled(
663 &mut self,
664 points: &ArrayView2<'_, f64>,
665 ) -> SpatialResult<Array2<f64>> {
666 let (npoints, ndims) = points.dim();
667 let mut distance_matrix = Array2::zeros((npoints, npoints));
668
669 let (tile_rows, tile_cols) = self.tile_size;
670 let precision_mode = self.precision_mode; let mut tile_futures = Vec::new();
674
675 for i in (0..npoints).step_by(tile_rows) {
676 for j in (0..npoints).step_by(tile_cols) {
677 let end_i = (i + tile_rows).min(npoints);
678 let end_j = (j + tile_cols).min(npoints);
679
680 let tilepoints_i = points.slice(s![i..end_i, ..]).to_owned();
681 let tilepoints_j = points.slice(s![j..end_j, ..]).to_owned();
682
683 let future = async move {
685 let (rows_i, _) = tilepoints_i.dim();
687 let (rows_j, _) = tilepoints_j.dim();
688 let mut tile_distances = Array2::zeros((rows_i, rows_j));
689
690 for r in 0..rows_i {
691 for c in 0..rows_j {
692 let p1 = tilepoints_i.row(r);
693 let p2 = tilepoints_j.row(c);
694
695 let dist = if ndims <= 16 {
697 use scirs2_core::simd_ops::SimdUnifiedOps;
699 let diff = f64::simd_sub(&p1, &p2);
700 let squared = f64::simd_mul(&diff.view(), &diff.view());
701 f64::simd_sum(&squared.view()).sqrt()
702 } else {
703 let diff = &p1 - &p2;
705 diff.iter().map(|x| x.powi(2)).sum::<f64>().sqrt()
706 };
707 tile_distances[[r, c]] = dist;
708 }
709 }
710 Ok::<Array2<f64>, SpatialError>(tile_distances)
711 };
712 tile_futures.push((i, j, end_i, end_j, future));
713 }
714 }
715
716 for (i, j, end_i, end_j, future) in tile_futures {
718 let tile_result = future.await?;
719
720 let tile_rows = end_i - i;
722 let tile_cols = end_j - j;
723
724 for row in 0..tile_rows {
725 for col in 0..tile_cols {
726 distance_matrix[[i + row, j + col]] = tile_result[[row, col]];
727 }
728 }
729 }
730
731 Ok(distance_matrix)
732 }
733
734 async fn compute_tile_tensor_cores(
736 &mut self,
737 points_i: Array2<f64>,
738 points_j: Array2<f64>,
739 precision_mode: PrecisionMode,
740 ) -> SpatialResult<Array2<f64>> {
741 let (_n_i, ndims) = points_i.dim();
742 let (_n_j, _) = points_j.dim();
743
744 match precision_mode {
745 PrecisionMode::Full32 => {
746 self.compute_distances_fp32(&points_i.view(), &points_j.view())
747 .await
748 }
749 PrecisionMode::Mixed16 => {
750 self.compute_distances_mixed16(&points_i.view(), &points_j.view())
751 .await
752 }
753 PrecisionMode::BrainFloat16 => {
754 self.compute_distances_bf16(&points_i.view(), &points_j.view())
755 .await
756 }
757 PrecisionMode::Int8Dynamic => {
758 self.compute_distances_int8(&points_i.view(), &points_j.view())
759 .await
760 }
761 PrecisionMode::Int4Advanced => {
762 self.compute_distances_int4(&points_i.view(), &points_j.view())
763 .await
764 }
765 PrecisionMode::Adaptive => {
766 self.compute_distances_adaptive(&points_i.view(), &points_j.view())
767 .await
768 }
769 PrecisionMode::AdvancedAdaptive => {
770 self.compute_distances_adaptive(&points_i.view(), &points_j.view())
771 .await
772 }
773 }
774 }
775
776 async fn compute_direct_tensor_cores(
778 &mut self,
779 points: &ArrayView2<'_, f64>,
780 ) -> SpatialResult<Array2<f64>> {
781 self.compute_tile_tensor_cores(points.to_owned(), points.to_owned(), self.precision_mode)
782 .await
783 }
784
785 async fn compute_distances_fp32(
787 &self,
788 points_i: &ArrayView2<'_, f64>,
789 points_j: &ArrayView2<'_, f64>,
790 ) -> SpatialResult<Array2<f64>> {
791 let (n_i, ndims) = points_i.dim();
792 let (n_j, _) = points_j.dim();
793 let mut distances = Array2::zeros((n_i, n_j));
794
795 let norms_i: Array1<f64> = points_i
800 .outer_iter()
801 .map(|point| point.iter().map(|&x| x * x).sum())
802 .collect();
803
804 let norms_j: Array1<f64> = points_j
806 .outer_iter()
807 .map(|point| point.iter().map(|&x| x * x).sum())
808 .collect();
809
810 let cross_terms = self
812 .tensor_core_gemm_fp32(points_i, &points_j.t().to_owned().view())
813 .await?;
814
815 for _i in 0..n_i {
817 for _j in 0..n_j {
818 distances[[_i, _j]] = (norms_i[_i] + norms_j[_j] - 2.0 * cross_terms[[_i, _j]])
819 .max(0.0)
820 .sqrt();
821 }
822 }
823
824 Ok(distances)
825 }
826
827 async fn compute_distances_mixed16(
829 &self,
830 points_i: &ArrayView2<'_, f64>,
831 points_j: &ArrayView2<'_, f64>,
832 ) -> SpatialResult<Array2<f64>> {
833 let points_i_f16 = TensorCoreDistanceMatrix::convert_to_fp16(points_i)?;
835 let points_j_f16 = TensorCoreDistanceMatrix::convert_to_fp16(points_j)?;
836
837 let (n_i, _) = points_i.dim();
838 let (n_j, _) = points_j.dim();
839 let mut distances = Array2::zeros((n_i, n_j));
840
841 let norms_i_f16 = TensorCoreDistanceMatrix::compute_norms_fp16(&points_i_f16)?;
843 let norms_j_f16 = TensorCoreDistanceMatrix::compute_norms_fp16(&points_j_f16)?;
844
845 let cross_terms = self
847 .tensor_core_gemm_mixed16(&points_i_f16, &points_j_f16.t().to_owned())
848 .await?;
849
850 for _i in 0..n_i {
851 for _j in 0..n_j {
852 let distance_sq = norms_i_f16[_i] as f64 + norms_j_f16[_j] as f64
853 - 2.0 * cross_terms[[_i, _j]] as f64;
854 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
855 }
856 }
857
858 Ok(distances)
859 }
860
861 async fn compute_distances_bf16(
863 &mut self,
864 points_i: &ArrayView2<'_, f64>,
865 points_j: &ArrayView2<'_, f64>,
866 ) -> SpatialResult<Array2<f64>> {
867 let points_i_bf16 = self.convert_to_bf16(points_i)?;
870 let points_j_bf16 = self.convert_to_bf16(points_j)?;
871
872 let (n_i, _) = points_i.dim();
873 let (n_j, _) = points_j.dim();
874 let mut distances = Array2::zeros((n_i, n_j));
875
876 let norms_i_bf16 = self.compute_norms_bf16(&points_i_bf16)?;
877 let norms_j_bf16 = self.compute_norms_bf16(&points_j_bf16)?;
878
879 let cross_terms = self
880 .tensor_core_gemm_bf16(&points_i_bf16, &points_j_bf16.t().to_owned())
881 .await?;
882
883 for _i in 0..n_i {
884 for _j in 0..n_j {
885 let distance_sq = norms_i_bf16[_i] as f64 + norms_j_bf16[_j] as f64
886 - 2.0 * cross_terms[[_i, _j]] as f64;
887 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
888 }
889 }
890
891 Ok(distances)
892 }
893
894 async fn compute_distances_int8(
896 &self,
897 points_i: &ArrayView2<'_, f64>,
898 points_j: &ArrayView2<'_, f64>,
899 ) -> SpatialResult<Array2<f64>> {
900 let (scale_i, points_i_int8) = self.quantize_to_int8_dynamic(points_i)?;
902 let (scale_j, points_j_int8) = self.quantize_to_int8_dynamic(points_j)?;
903
904 let (n_i, _) = points_i.dim();
905 let (n_j, _) = points_j.dim();
906 let mut distances = Array2::zeros((n_i, n_j));
907
908 let combined_scale = scale_i * scale_j;
910
911 for _i in 0..n_i {
912 for _j in 0..n_j {
913 let cross_term_int32 = points_i_int8
915 .row(_i)
916 .iter()
917 .zip(points_j_int8.row(_j).iter())
918 .map(|(&a, &b)| (a as i32) * (b as i32))
919 .sum::<i32>();
920 let cross_term_f64 = cross_term_int32 as f64 * combined_scale;
921
922 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
924 let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
925
926 let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
927 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
928 }
929 }
930
931 Ok(distances)
932 }
933
934 async fn compute_distances_int4(
936 &self,
937 points_i: &ArrayView2<'_, f64>,
938 points_j: &ArrayView2<'_, f64>,
939 ) -> SpatialResult<Array2<f64>> {
940 let (scale_i, points_i_int4) = self.quantize_to_int4_advanced(points_i)?;
942 let (scale_j, points_j_int4) = self.quantize_to_int4_advanced(points_j)?;
943
944 let points_i_int8 = TensorCoreDistanceMatrix::int4_to_int8(&points_i_int4);
946 let points_j_int8 = TensorCoreDistanceMatrix::int4_to_int8(&points_j_int4);
947
948 let (n_i, _) = points_i.dim();
949 let (n_j, _) = points_j.dim();
950 let mut distances = Array2::zeros((n_i, n_j));
951
952 let n_i_chunks = n_i / 4;
958 let n_j_chunks = n_j / 4;
959
960 for i_chunk in 0..n_i_chunks {
962 for j_chunk in 0..n_j_chunks {
963 let i_base = i_chunk * 4;
964 let j_base = j_chunk * 4;
965
966 for i_offset in 0..4 {
968 let _i = i_base + i_offset;
969 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
970
971 let _j0 = j_base;
973 let _j1 = j_base + 1;
974 let _j2 = j_base + 2;
975 let _j3 = j_base + 3;
976
977 let norm_j0_sq: f64 = points_j.row(_j0).iter().map(|&x| x * x).sum();
978 let norm_j1_sq: f64 = points_j.row(_j1).iter().map(|&x| x * x).sum();
979 let norm_j2_sq: f64 = points_j.row(_j2).iter().map(|&x| x * x).sum();
980 let norm_j3_sq: f64 = points_j.row(_j3).iter().map(|&x| x * x).sum();
981
982 let cross_term_f64 = 0.0; let distance_sq0 = norm_i_sq + norm_j0_sq - 2.0 * cross_term_f64;
986 let distance_sq1 = norm_i_sq + norm_j1_sq - 2.0 * cross_term_f64;
987 let distance_sq2 = norm_i_sq + norm_j2_sq - 2.0 * cross_term_f64;
988 let distance_sq3 = norm_i_sq + norm_j3_sq - 2.0 * cross_term_f64;
989
990 distances[[_i, _j0]] = distance_sq0.max(0.0).sqrt();
991 distances[[_i, _j1]] = distance_sq1.max(0.0).sqrt();
992 distances[[_i, _j2]] = distance_sq2.max(0.0).sqrt();
993 distances[[_i, _j3]] = distance_sq3.max(0.0).sqrt();
994 }
995 }
996 }
997
998 for _i in (n_i_chunks * 4)..n_i {
1000 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
1001 for _j in 0..n_j {
1002 let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
1003 let cross_term_f64 = 0.0; let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
1005 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
1006 }
1007 }
1008
1009 for _i in 0..(n_i_chunks * 4) {
1011 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
1012 for _j in (n_j_chunks * 4)..n_j {
1013 let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
1014 let cross_term_f64 = 0.0; let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
1016 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
1017 }
1018 }
1019
1020 Ok(distances)
1021 }
1022
1023 async fn compute_distances_adaptive(
1025 &mut self,
1026 points_i: &ArrayView2<'_, f64>,
1027 points_j: &ArrayView2<'_, f64>,
1028 ) -> SpatialResult<Array2<f64>> {
1029 let data_range = self.analyze_data_range(points_i, points_j);
1031 let condition_number = self.estimate_condition_number(points_i, points_j);
1032
1033 let optimal_precision = if condition_number > 1e6 {
1034 PrecisionMode::Full32
1035 } else if data_range > 1e3 {
1036 PrecisionMode::BrainFloat16
1037 } else if data_range > 100.0 {
1038 PrecisionMode::Mixed16
1039 } else {
1040 PrecisionMode::Int8Dynamic
1041 };
1042
1043 match optimal_precision {
1044 PrecisionMode::Full32 => self.compute_distances_fp32(points_i, points_j).await,
1045 PrecisionMode::Mixed16 => self.compute_distances_mixed16(points_i, points_j).await,
1046 PrecisionMode::BrainFloat16 => self.compute_distances_bf16(points_i, points_j).await,
1047 PrecisionMode::Int8Dynamic => self.compute_distances_int8(points_i, points_j).await,
1048 PrecisionMode::Int4Advanced => self.compute_distances_int8(points_i, points_j).await, PrecisionMode::Adaptive => self.compute_distances_mixed16(points_i, points_j).await, PrecisionMode::AdvancedAdaptive => {
1051 self.compute_distances_fp32(points_i, points_j).await
1052 } }
1054 }
1055
1056 async fn tensor_core_gemm_fp32(
1058 &self,
1059 a: &ArrayView2<'_, f64>,
1060 b: &ArrayView2<'_, f64>,
1061 ) -> SpatialResult<Array2<f64>> {
1062 let (m, k) = a.dim();
1064 let (k2, n) = b.dim();
1065
1066 if k != k2 {
1067 return Err(SpatialError::InvalidInput(
1068 "Matrix dimensions don't match for multiplication".to_string(),
1069 ));
1070 }
1071
1072 let mut c = Array2::zeros((m, n));
1073
1074 let block_size = 16; for i in (0..m).step_by(block_size) {
1078 for j in (0..n).step_by(block_size) {
1079 for kk in (0..k).step_by(block_size) {
1080 let end_i = (i + block_size).min(m);
1081 let end_j = (j + block_size).min(n);
1082 let end_k = (kk + block_size).min(k);
1083
1084 let block_rows = end_i - i;
1086 let block_cols = end_j - j;
1087 let block_k = end_k - kk;
1088
1089 let k_chunks = block_k / 4;
1091
1092 for ii in i..end_i {
1093 for jj in j..end_j {
1094 let mut accumulator = c[[ii, jj]];
1095
1096 for k_chunk in 0..k_chunks {
1098 let k_base = kk + k_chunk * 4;
1099
1100 let a_val0 = a[[ii, k_base]];
1102 let a_val1 = a[[ii, k_base + 1]];
1103 let a_val2 = a[[ii, k_base + 2]];
1104 let a_val3 = a[[ii, k_base + 3]];
1105
1106 let b_val0 = b[[k_base, jj]];
1107 let b_val1 = b[[k_base + 1, jj]];
1108 let b_val2 = b[[k_base + 2, jj]];
1109 let b_val3 = b[[k_base + 3, jj]];
1110
1111 accumulator += a_val0 * b_val0
1112 + a_val1 * b_val1
1113 + a_val2 * b_val2
1114 + a_val3 * b_val3;
1115 }
1116
1117 for kkk in (kk + k_chunks * 4)..end_k {
1119 accumulator += a[[ii, kkk]] * b[[kkk, jj]];
1120 }
1121
1122 c[[ii, jj]] = accumulator;
1123 }
1124 }
1125 }
1126 }
1127 }
1128
1129 Ok(c)
1130 }
1131
1132 async fn tensor_core_gemm_mixed16(
1134 &self,
1135 a: &Array2<f32>,
1136 b: &Array2<f32>,
1137 ) -> SpatialResult<Array2<f32>> {
1138 let (m, k) = a.dim();
1140 let (k2, n) = b.dim();
1141
1142 if k != k2 {
1143 return Err(SpatialError::InvalidInput(
1144 "Matrix dimensions don't match".to_string(),
1145 ));
1146 }
1147
1148 let mut c = Array2::zeros((m, n));
1149 let block_size = 16;
1150
1151 for i in (0..m).step_by(block_size) {
1152 for j in (0..n).step_by(block_size) {
1153 for kk in (0..k).step_by(block_size) {
1154 let end_i = (i + block_size).min(m);
1155 let end_j = (j + block_size).min(n);
1156 let end_k = (kk + block_size).min(k);
1157
1158 let block_k = end_k - kk;
1160 let k_chunks = block_k / 4;
1161
1162 for ii in i..end_i {
1163 for jj in j..end_j {
1164 let mut accumulator = c[[ii, jj]];
1165
1166 for k_chunk in 0..k_chunks {
1168 let k_base = kk + k_chunk * 4;
1169
1170 let a_val0 = a[[ii, k_base]];
1172 let a_val1 = a[[ii, k_base + 1]];
1173 let a_val2 = a[[ii, k_base + 2]];
1174 let a_val3 = a[[ii, k_base + 3]];
1175
1176 let b_val0 = b[[k_base, jj]];
1177 let b_val1 = b[[k_base + 1, jj]];
1178 let b_val2 = b[[k_base + 2, jj]];
1179 let b_val3 = b[[k_base + 3, jj]];
1180
1181 accumulator += a_val0 * b_val0
1182 + a_val1 * b_val1
1183 + a_val2 * b_val2
1184 + a_val3 * b_val3;
1185 }
1186
1187 for kkk in (kk + k_chunks * 4)..end_k {
1189 accumulator += a[[ii, kkk]] * b[[kkk, jj]];
1190 }
1191
1192 c[[ii, jj]] = accumulator;
1193 }
1194 }
1195 }
1196 }
1197 }
1198
1199 Ok(c)
1200 }
1201
1202 async fn tensor_core_gemm_bf16(
1204 &self,
1205 a: &Array2<f32>,
1206 b: &Array2<f32>,
1207 ) -> SpatialResult<Array2<f32>> {
1208 self.tensor_core_gemm_mixed16(a, b).await
1210 }
1211
1212 #[allow(dead_code)]
1214 async fn tensor_core_gemm_int8(
1215 &self,
1216 a: &Array2<i8>,
1217 b: &Array2<i8>,
1218 ) -> SpatialResult<Array2<i32>> {
1219 let (m, k) = a.dim();
1220 let (k2, n) = b.dim();
1221
1222 if k != k2 {
1223 return Err(SpatialError::InvalidInput(
1224 "Matrix dimensions don't match".to_string(),
1225 ));
1226 }
1227
1228 let mut c = Array2::zeros((m, n));
1229 let block_size = 16;
1230
1231 for i in (0..m).step_by(block_size) {
1232 for j in (0..n).step_by(block_size) {
1233 for kk in (0..k).step_by(block_size) {
1234 let end_i = (i + block_size).min(m);
1235 let end_j = (j + block_size).min(n);
1236 let end_k = (kk + block_size).min(k);
1237
1238 for ii in i..end_i {
1239 for jj in j..end_j {
1240 for kkk in kk..end_k {
1241 c[[ii, jj]] += a[[ii, kkk]] as i32 * b[[kkk, jj]] as i32;
1243 }
1244 }
1245 }
1246 }
1247 }
1248 }
1249
1250 Ok(c)
1251 }
1252
1253 fn convert_to_fp16(data: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f32>> {
1255 let (rows, cols) = data.dim();
1256 let mut fp16_data = Array2::zeros((rows, cols));
1257
1258 for i in 0..rows {
1259 for j in 0..cols {
1260 fp16_data[[i, j]] = data[[i, j]] as f32;
1262 }
1263 }
1264
1265 Ok(fp16_data)
1266 }
1267
1268 fn convert_to_bf16(&mut self, data: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f32>> {
1270 TensorCoreDistanceMatrix::convert_to_fp16(data)
1272 }
1273
1274 fn quantize_to_int8_dynamic(
1276 &self,
1277 data: &ArrayView2<'_, f64>,
1278 ) -> SpatialResult<(f64, Array2<i8>)> {
1279 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1280 let scale = max_val / 127.0; let (rows, cols) = data.dim();
1283 let mut quantized = Array2::zeros((rows, cols));
1284
1285 for i in 0..rows {
1286 for j in 0..cols {
1287 let quantized_val = (data[[i, j]] / scale).round() as i8;
1288 quantized[[i, j]] = quantized_val.clamp(-127, 127);
1289 }
1290 }
1291
1292 Ok((scale, quantized))
1293 }
1294
1295 fn quantize_to_int4_advanced(
1297 &self,
1298 data: &ArrayView2<'_, f64>,
1299 ) -> SpatialResult<(f64, Array2<i8>)> {
1300 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1301 let scale = max_val / 7.0; let (rows, cols) = data.dim();
1304 let mut quantized = Array2::zeros((rows, cols));
1305
1306 for i in 0..rows {
1307 for j in 0..cols {
1308 let quantized_val = (data[[i, j]] / scale).round() as i8;
1309 quantized[[i, j]] = quantized_val.clamp(-7, 7);
1310 }
1311 }
1312
1313 Ok((scale, quantized))
1314 }
1315
1316 fn int4_to_int8(data: &Array2<i8>) -> Array2<i8> {
1318 data.mapv(|x| x.clamp(-7, 7))
1320 }
1321
1322 fn compute_norms_fp16(data: &Array2<f32>) -> SpatialResult<Array1<f32>> {
1324 let norms = data
1325 .outer_iter()
1326 .map(|row| row.iter().map(|&x| x * x).sum())
1327 .collect();
1328 Ok(norms)
1329 }
1330
1331 fn compute_norms_bf16(&mut self, data: &Array2<f32>) -> SpatialResult<Array1<f32>> {
1333 TensorCoreDistanceMatrix::compute_norms_fp16(data)
1334 }
1335
1336 fn analyze_data_range(
1338 &self,
1339 points_i: &ArrayView2<'_, f64>,
1340 points_j: &ArrayView2<'_, f64>,
1341 ) -> f64 {
1342 let min_i = points_i.fold(f64::INFINITY, |acc, &x| acc.min(x));
1343 let max_i = points_i.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
1344 let min_j = points_j.fold(f64::INFINITY, |acc, &x| acc.min(x));
1345 let max_j = points_j.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
1346
1347 let overall_min = min_i.min(min_j);
1348 let overall_max = max_i.max(max_j);
1349
1350 overall_max - overall_min
1351 }
1352
1353 fn estimate_condition_number(
1355 &self,
1356 points_i: &ArrayView2<'_, f64>,
1357 points_j: &ArrayView2<'_, f64>,
1358 ) -> f64 {
1359 let data_range = self.analyze_data_range(points_i, points_j);
1361 let mean_i: f64 = points_i.sum() / (points_i.len() as f64);
1362 let mean_j: f64 = points_j.sum() / (points_j.len() as f64);
1363 let overall_mean = (mean_i + mean_j) / 2.0;
1364
1365 if overall_mean.abs() < 1e-10 {
1366 1e6 } else {
1368 data_range / overall_mean.abs()
1369 }
1370 }
1371}
1372
1373#[allow(dead_code)]
1375#[derive(Debug, Clone)]
1376pub struct TensorCoreClustering {
1377 _numclusters: usize,
1379 precision_mode: PrecisionMode,
1381 tensor_cores: bool,
1383 mixed_precision: bool,
1385 dynamic_precision: bool,
1387 capabilities: Option<TensorCoreCapabilities>,
1389}
1390
1391impl TensorCoreClustering {
1392 pub fn new(_numclusters: usize) -> SpatialResult<Self> {
1394 let capabilities = detect_tensor_core_capabilities().ok();
1395
1396 Ok(Self {
1397 _numclusters,
1398 precision_mode: PrecisionMode::Mixed16,
1399 tensor_cores: true,
1400 mixed_precision: true,
1401 dynamic_precision: false,
1402 capabilities,
1403 })
1404 }
1405
1406 pub fn with_tensor_cores(mut self, enabled: bool) -> Self {
1408 self.tensor_cores = enabled;
1409 self
1410 }
1411
1412 pub fn with_mixed_precision(mut self, enabled: bool) -> Self {
1414 self.mixed_precision = enabled;
1415 self
1416 }
1417
1418 pub fn with_dynamic_precision_scaling(mut self, enabled: bool) -> Self {
1420 self.dynamic_precision = enabled;
1421 self
1422 }
1423
1424 pub async fn fit(
1426 &mut self,
1427 points: &ArrayView2<'_, f64>,
1428 ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
1429 let (npoints, ndims) = points.dim();
1430
1431 if npoints < self._numclusters {
1432 return Err(SpatialError::InvalidInput(
1433 "Number of points must be >= number of clusters".to_string(),
1434 ));
1435 }
1436
1437 let mut centroids = self.initialize_centroids(points)?;
1439 let mut assignments = Array1::zeros(npoints);
1440
1441 for _iteration in 0..100 {
1443 let distance_matrix = if self.tensor_cores {
1445 let tensor_computer =
1446 TensorCoreDistanceMatrix::new()?.with_precision_mode(self.precision_mode);
1447 tensor_computer
1448 .compute_distances_to_centroids(points, ¢roids.view())
1449 .await?
1450 } else {
1451 self.compute_distances_fallback(points, ¢roids.view())?
1452 };
1453
1454 let new_assignments = self.update_assignments(&distance_matrix)?;
1456
1457 let new_centroids = if self.tensor_cores {
1459 self.update_centroids_tensor_cores(points, &new_assignments)
1460 .await?
1461 } else {
1462 self.update_centroids_fallback(points, &new_assignments)?
1463 };
1464
1465 let centroid_change = self.compute_centroid_change(¢roids, &new_centroids);
1467 if centroid_change < 1e-6 {
1468 break;
1469 }
1470
1471 centroids = new_centroids;
1472 assignments = new_assignments;
1473 }
1474
1475 Ok((centroids, assignments))
1476 }
1477
1478 fn initialize_centroids(&mut self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
1480 let (npoints, ndims) = points.dim();
1481 let mut centroids = Array2::zeros((self._numclusters, ndims));
1482
1483 let mut rng = scirs2_core::random::rng();
1485
1486 let first_idx = rng.gen_range(0..npoints);
1488 centroids.row_mut(0).assign(&points.row(first_idx));
1489
1490 for k in 1..self._numclusters {
1492 let mut distances = Array1::zeros(npoints);
1493
1494 for i in 0..npoints {
1495 let point = points.row(i);
1496 let mut min_dist = f64::INFINITY;
1497
1498 for j in 0..k {
1499 let centroid = centroids.row(j);
1500 let dist: f64 = point
1501 .iter()
1502 .zip(centroid.iter())
1503 .map(|(&a, &b)| (a - b).powi(2))
1504 .sum::<f64>();
1505 min_dist = min_dist.min(dist);
1506 }
1507
1508 distances[i] = min_dist;
1509 }
1510
1511 let total_dist: f64 = distances.sum();
1513 let mut cumulative = 0.0;
1514 let random_val = scirs2_core::random::random::<f64>() * total_dist;
1515
1516 for i in 0..npoints {
1517 cumulative += distances[i];
1518 if cumulative >= random_val {
1519 centroids.row_mut(k).assign(&points.row(i));
1520 break;
1521 }
1522 }
1523 }
1524
1525 Ok(centroids)
1526 }
1527
1528 fn update_assignments(
1530 &mut self,
1531 distance_matrix: &Array2<f64>,
1532 ) -> SpatialResult<Array1<usize>> {
1533 let npoints = distance_matrix.nrows();
1534 let mut assignments = Array1::zeros(npoints);
1535
1536 for i in 0..npoints {
1537 let mut min_dist = f64::INFINITY;
1538 let mut best_cluster = 0;
1539
1540 for j in 0..self._numclusters {
1541 if distance_matrix[[i, j]] < min_dist {
1542 min_dist = distance_matrix[[i, j]];
1543 best_cluster = j;
1544 }
1545 }
1546
1547 assignments[i] = best_cluster;
1548 }
1549
1550 Ok(assignments)
1551 }
1552
1553 async fn update_centroids_tensor_cores(
1555 &self,
1556 points: &ArrayView2<'_, f64>,
1557 assignments: &Array1<usize>,
1558 ) -> SpatialResult<Array2<f64>> {
1559 let (_npoints, ndims) = points.dim();
1560 let mut new_centroids = Array2::zeros((self._numclusters, ndims));
1561 let mut cluster_counts = vec![0; self._numclusters];
1562
1563 for &cluster in assignments {
1565 cluster_counts[cluster] += 1;
1566 }
1567
1568 for cluster in 0..self._numclusters {
1570 if cluster_counts[cluster] == 0 {
1571 continue;
1572 }
1573
1574 let clusterpoints: Vec<usize> = assignments
1576 .iter()
1577 .enumerate()
1578 .filter(|(_, &c)| c == cluster)
1579 .map(|(i, _)| i)
1580 .collect();
1581
1582 let cluster_data = Array2::from_shape_fn((clusterpoints.len(), ndims), |(i, j)| {
1584 points[[clusterpoints[i], j]]
1585 });
1586
1587 let sum_vector = self.tensor_sum_reduction(&cluster_data.view()).await?;
1589 let count = clusterpoints.len() as f64;
1590
1591 for j in 0..ndims {
1592 new_centroids[[cluster, j]] = sum_vector[j] / count;
1593 }
1594 }
1595
1596 Ok(new_centroids)
1597 }
1598
1599 async fn tensor_sum_reduction(&self, data: &ArrayView2<'_, f64>) -> SpatialResult<Array1<f64>> {
1601 let (_npoints, ndims) = data.dim();
1602 let mut sum_vector = Array1::zeros(ndims);
1603
1604 for j in 0..ndims {
1606 let column_sum: f64 = data.column(j).sum();
1607 sum_vector[j] = column_sum;
1608 }
1609
1610 Ok(sum_vector)
1611 }
1612
1613 fn compute_distances_fallback(
1615 &self,
1616 points: &ArrayView2<'_, f64>,
1617 centroids: &ArrayView2<'_, f64>,
1618 ) -> SpatialResult<Array2<f64>> {
1619 let (npoints, ndims) = points.dim();
1620 let (n_clusters_, _) = centroids.dim();
1621 let mut distances = Array2::zeros((npoints, n_clusters_));
1622
1623 let cluster_chunks = n_clusters_ / 4;
1625
1626 for i in 0..npoints {
1627 let point_row = points.row(i);
1628
1629 for j_chunk in 0..cluster_chunks {
1631 let j_base = j_chunk * 4;
1632
1633 let j0 = j_base;
1635 let j1 = j_base + 1;
1636 let j2 = j_base + 2;
1637 let j3 = j_base + 3;
1638
1639 let centroid_row0 = centroids.row(j0);
1640 let centroid_row1 = centroids.row(j1);
1641 let centroid_row2 = centroids.row(j2);
1642 let centroid_row3 = centroids.row(j3);
1643
1644 let distance0: f64 = point_row
1645 .iter()
1646 .zip(centroid_row0.iter())
1647 .map(|(&a, &b)| (a - b).powi(2))
1648 .sum::<f64>()
1649 .sqrt();
1650
1651 let distance1: f64 = point_row
1652 .iter()
1653 .zip(centroid_row1.iter())
1654 .map(|(&a, &b)| (a - b).powi(2))
1655 .sum::<f64>()
1656 .sqrt();
1657
1658 let distance2: f64 = point_row
1659 .iter()
1660 .zip(centroid_row2.iter())
1661 .map(|(&a, &b)| (a - b).powi(2))
1662 .sum::<f64>()
1663 .sqrt();
1664
1665 let distance3: f64 = point_row
1666 .iter()
1667 .zip(centroid_row3.iter())
1668 .map(|(&a, &b)| (a - b).powi(2))
1669 .sum::<f64>()
1670 .sqrt();
1671
1672 distances[[i, j0]] = distance0;
1673 distances[[i, j1]] = distance1;
1674 distances[[i, j2]] = distance2;
1675 distances[[i, j3]] = distance3;
1676 }
1677
1678 for j in (cluster_chunks * 4)..n_clusters_ {
1680 let distance: f64 = point_row
1681 .iter()
1682 .zip(centroids.row(j).iter())
1683 .map(|(&a, &b)| (a - b).powi(2))
1684 .sum::<f64>()
1685 .sqrt();
1686 distances[[i, j]] = distance;
1687 }
1688 }
1689
1690 Ok(distances)
1691 }
1692
1693 fn update_centroids_fallback(
1695 &self,
1696 points: &ArrayView2<'_, f64>,
1697 assignments: &Array1<usize>,
1698 ) -> SpatialResult<Array2<f64>> {
1699 let (npoints, ndims) = points.dim();
1700 let mut new_centroids = Array2::zeros((self._numclusters, ndims));
1701 let mut cluster_counts = vec![0; self._numclusters];
1702
1703 for i in 0..npoints {
1705 let cluster = assignments[i];
1706 cluster_counts[cluster] += 1;
1707
1708 for j in 0..ndims {
1709 new_centroids[[cluster, j]] += points[[i, j]];
1710 }
1711 }
1712
1713 for cluster in 0..self._numclusters {
1715 if cluster_counts[cluster] > 0 {
1716 let count = cluster_counts[cluster] as f64;
1717 for j in 0..ndims {
1718 new_centroids[[cluster, j]] /= count;
1719 }
1720 }
1721 }
1722
1723 Ok(new_centroids)
1724 }
1725
1726 fn compute_centroid_change(
1728 &self,
1729 old_centroids: &Array2<f64>,
1730 new_centroids: &Array2<f64>,
1731 ) -> f64 {
1732 let mut total_change = 0.0;
1733
1734 for i in 0..self._numclusters {
1735 let change: f64 = old_centroids
1736 .row(i)
1737 .iter()
1738 .zip(new_centroids.row(i).iter())
1739 .map(|(&a, &b)| (a - b).powi(2))
1740 .sum::<f64>()
1741 .sqrt();
1742 total_change += change;
1743 }
1744
1745 total_change / (self._numclusters as f64)
1746 }
1747}
1748
1749impl Default for StabilityMetrics {
1750 fn default() -> Self {
1751 Self::new()
1752 }
1753}
1754
1755impl StabilityMetrics {
1756 pub fn new() -> Self {
1758 Self {
1759 condition_number: 1.0,
1760 relative_error: 0.0,
1761 forward_error: 0.0,
1762 backward_error: 0.0,
1763 digit_loss: 0.0,
1764 stability_level: StabilityLevel::Excellent,
1765 error_types: Vec::new(),
1766 timestamp: Instant::now(),
1767 }
1768 }
1769
1770 pub fn update_stability_level(&mut self) {
1772 self.stability_level = if self.condition_number > 1e12 || self.relative_error > 1e-3 {
1773 StabilityLevel::Critical
1774 } else if self.condition_number > 1e8 || self.relative_error > 1e-6 {
1775 StabilityLevel::Poor
1776 } else if self.condition_number > 1e4 || self.relative_error > 1e-9 {
1777 StabilityLevel::Moderate
1778 } else if self.condition_number > 1e2 || self.relative_error > 1e-12 {
1779 StabilityLevel::Good
1780 } else {
1781 StabilityLevel::Excellent
1782 };
1783 }
1784
1785 pub fn detect_errors(&mut self, data: &Array2<f64>) {
1787 self.error_types.clear();
1788
1789 for &value in data.iter() {
1791 if !value.is_finite() {
1792 self.error_types.push(NumericalErrorType::InvalidValues);
1793 break;
1794 }
1795 }
1796
1797 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1799 if max_val > 1e100 {
1800 self.error_types.push(NumericalErrorType::Overflow);
1801 } else if max_val < 1e-100 && max_val > 0.0 {
1802 self.error_types.push(NumericalErrorType::Underflow);
1803 }
1804
1805 if self.digit_loss > 6.0 {
1807 self.error_types.push(NumericalErrorType::PrecisionLoss);
1808 }
1809
1810 if self.condition_number > 1e12 {
1812 self.error_types.push(NumericalErrorType::IllConditioned);
1813 }
1814 }
1815}
1816
1817impl Default for DynamicPrecisionConfig {
1818 fn default() -> Self {
1819 Self {
1820 strategy: ScalingStrategy::Balanced,
1821 min_precision: PrecisionMode::Int8Dynamic,
1822 max_precision: PrecisionMode::Full32,
1823 stability_threshold_up: 1e-6,
1824 stability_threshold_down: 1e-9,
1825 performance_weight: 0.6,
1826 accuracy_weight: 0.4,
1827 max_changes_per_operation: 3,
1828 change_cooldown: Duration::from_millis(100),
1829 }
1830 }
1831}
1832
1833impl NumericalStabilityMonitor {
1834 pub fn new(config: DynamicPrecisionConfig) -> Self {
1836 Self {
1837 current_metrics: StabilityMetrics::new(),
1838 stability_history: VecDeque::new(),
1839 precision_config: config,
1840 current_precision: PrecisionMode::Mixed16,
1841 precision_history: VecDeque::new(),
1842 recovery_attempts: 0,
1843 max_history_length: 1000,
1844 last_precision_change: None,
1845 }
1846 }
1847
1848 pub fn monitor_stability(
1850 &mut self,
1851 data: &Array2<f64>,
1852 computation_result: &Array2<f64>,
1853 ) -> SpatialResult<()> {
1854 self.current_metrics.condition_number =
1856 NumericalStabilityMonitor::estimate_condition_number(data);
1857
1858 self.current_metrics.relative_error =
1860 self.estimate_relative_error(data, computation_result);
1861
1862 self.current_metrics.forward_error = self.estimate_forward_error(data, computation_result);
1864 self.current_metrics.backward_error =
1865 self.estimate_backward_error(data, computation_result);
1866
1867 self.current_metrics.digit_loss = self.estimate_digit_loss();
1869
1870 self.current_metrics.update_stability_level();
1872
1873 self.current_metrics.detect_errors(computation_result);
1875
1876 self.current_metrics.timestamp = Instant::now();
1878
1879 self.stability_history
1881 .push_back(self.current_metrics.clone());
1882 if self.stability_history.len() > self.max_history_length {
1883 self.stability_history.pop_front();
1884 }
1885
1886 Ok(())
1887 }
1888
1889 pub fn adjust_precision(&mut self) -> SpatialResult<PrecisionMode> {
1891 if let Some(last_change) = self.last_precision_change {
1893 if last_change.elapsed() < self.precision_config.change_cooldown {
1894 return Ok(self.current_precision);
1895 }
1896 }
1897
1898 let new_precision = match self.current_metrics.stability_level {
1899 StabilityLevel::Critical => {
1900 self.precision_config.max_precision
1902 }
1903 StabilityLevel::Poor => {
1904 NumericalStabilityMonitor::increase_precision(self.current_precision)
1906 }
1907 StabilityLevel::Moderate => {
1908 if self.current_metrics.relative_error
1910 > self.precision_config.stability_threshold_up
1911 {
1912 NumericalStabilityMonitor::increase_precision(self.current_precision)
1913 } else {
1914 self.current_precision
1915 }
1916 }
1917 StabilityLevel::Good => {
1918 if self.current_metrics.relative_error
1920 < self.precision_config.stability_threshold_down
1921 {
1922 NumericalStabilityMonitor::decrease_precision(self.current_precision)
1923 } else {
1924 self.current_precision
1925 }
1926 }
1927 StabilityLevel::Excellent => {
1928 if self.precision_config.strategy == ScalingStrategy::Aggressive {
1930 self.precision_config.min_precision
1931 } else {
1932 NumericalStabilityMonitor::decrease_precision(self.current_precision)
1933 }
1934 }
1935 };
1936
1937 if new_precision != self.current_precision {
1939 self.precision_history.push_back((
1940 Instant::now(),
1941 new_precision,
1942 self.current_metrics.relative_error,
1943 ));
1944 self.current_precision = new_precision;
1945 self.last_precision_change = Some(Instant::now());
1946 }
1947
1948 Ok(new_precision)
1949 }
1950
1951 fn increase_precision(current: PrecisionMode) -> PrecisionMode {
1953 match current {
1954 PrecisionMode::Int4Advanced => PrecisionMode::Int8Dynamic,
1955 PrecisionMode::Int8Dynamic => PrecisionMode::Mixed16,
1956 PrecisionMode::Mixed16 => PrecisionMode::BrainFloat16,
1957 PrecisionMode::BrainFloat16 => PrecisionMode::Full32,
1958 PrecisionMode::Full32 => PrecisionMode::Full32, _ => PrecisionMode::Mixed16,
1960 }
1961 }
1962
1963 fn decrease_precision(current: PrecisionMode) -> PrecisionMode {
1965 match current {
1966 PrecisionMode::Full32 => PrecisionMode::BrainFloat16,
1967 PrecisionMode::BrainFloat16 => PrecisionMode::Mixed16,
1968 PrecisionMode::Mixed16 => PrecisionMode::Int8Dynamic,
1969 PrecisionMode::Int8Dynamic => PrecisionMode::Int4Advanced,
1970 PrecisionMode::Int4Advanced => PrecisionMode::Int4Advanced, _ => PrecisionMode::Mixed16,
1972 }
1973 }
1974
1975 fn estimate_condition_number(data: &Array2<f64>) -> f64 {
1977 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1979 let min_val = data.fold(f64::INFINITY, |acc, &x| {
1980 if x.abs() > 1e-15 {
1981 acc.min(x.abs())
1982 } else {
1983 acc
1984 }
1985 });
1986
1987 if min_val.is_finite() && min_val > 0.0 {
1988 max_val / min_val
1989 } else {
1990 1e12 }
1992 }
1993
1994 fn estimate_relative_error(&mut self, input: &Array2<f64>, output: &Array2<f64>) -> f64 {
1996 let mean_val = output.mean_or(0.0);
1998 if mean_val.abs() > 1e-15 {
1999 let machine_eps = match self.current_precision {
2001 PrecisionMode::Full32 => 2.22e-16,
2002 PrecisionMode::Mixed16 | PrecisionMode::BrainFloat16 => 9.77e-4,
2003 PrecisionMode::Int8Dynamic => 1.0 / 256.0,
2004 PrecisionMode::Int4Advanced => 1.0 / 16.0,
2005 _ => 1e-6,
2006 };
2007 machine_eps * self.current_metrics.condition_number
2008 } else {
2009 0.0
2010 }
2011 }
2012
2013 fn estimate_forward_error(&mut self, _input: &Array2<f64>, output: &Array2<f64>) -> f64 {
2015 self.current_metrics.relative_error * self.current_metrics.condition_number
2017 }
2018
2019 fn estimate_backward_error(&mut self, _input: &Array2<f64>, output: &Array2<f64>) -> f64 {
2021 self.current_metrics.relative_error
2023 }
2024
2025 fn estimate_digit_loss(&self) -> f64 {
2027 if self.current_metrics.condition_number > 1.0 {
2028 self.current_metrics.condition_number.log10().max(0.0)
2029 } else {
2030 0.0
2031 }
2032 }
2033}
2034
2035impl Default for ErrorRecoverySystem {
2036 fn default() -> Self {
2037 Self::new()
2038 }
2039}
2040
2041impl ErrorRecoverySystem {
2042 pub fn new() -> Self {
2044 let mut recovery_strategies = HashMap::new();
2045
2046 recovery_strategies.insert(
2048 NumericalErrorType::Overflow,
2049 vec![
2050 RecoveryAction::IncreasePrecision,
2051 RecoveryAction::ReduceTileSize,
2052 RecoveryAction::NumericalStabilization,
2053 ],
2054 );
2055 recovery_strategies.insert(
2056 NumericalErrorType::Underflow,
2057 vec![
2058 RecoveryAction::IncreasePrecision,
2059 RecoveryAction::NumericalStabilization,
2060 ],
2061 );
2062 recovery_strategies.insert(
2063 NumericalErrorType::PrecisionLoss,
2064 vec![
2065 RecoveryAction::IncreasePrecision,
2066 RecoveryAction::RetryWithNewParams,
2067 ],
2068 );
2069 recovery_strategies.insert(
2070 NumericalErrorType::IllConditioned,
2071 vec![
2072 RecoveryAction::IncreasePrecision,
2073 RecoveryAction::NumericalStabilization,
2074 RecoveryAction::SwitchToCPU,
2075 ],
2076 );
2077 recovery_strategies.insert(
2078 NumericalErrorType::InvalidValues,
2079 vec![
2080 RecoveryAction::FallbackAlgorithm,
2081 RecoveryAction::SwitchToCPU,
2082 ],
2083 );
2084
2085 Self {
2086 recovery_strategies,
2087 recovery_history: VecDeque::new(),
2088 max_recovery_attempts: 3,
2089 success_rates: HashMap::new(),
2090 }
2091 }
2092
2093 pub async fn attempt_recovery(
2095 &mut self,
2096 error_type: NumericalErrorType,
2097 ) -> SpatialResult<RecoveryAction> {
2098 let start_time = Instant::now();
2099
2100 let strategies = self
2102 .recovery_strategies
2103 .get(&error_type)
2104 .ok_or_else(|| SpatialError::InvalidInput("Unknown error _type".to_string()))?
2105 .clone(); let best_action = self.choose_best_recovery_action(&strategies);
2109
2110 let attempt = RecoveryAttempt {
2112 error_type,
2113 action: best_action,
2114 success: false, duration: start_time.elapsed(),
2116 post_recovery_metrics: None,
2117 timestamp: start_time,
2118 };
2119
2120 self.recovery_history.push_back(attempt);
2121
2122 Ok(best_action)
2123 }
2124
2125 fn choose_best_recovery_action(&mut self, strategies: &[RecoveryAction]) -> RecoveryAction {
2127 strategies
2128 .iter()
2129 .max_by(|&a, &b| {
2130 let rate_a = self.success_rates.get(a).unwrap_or(&0.5);
2131 let rate_b = self.success_rates.get(b).unwrap_or(&0.5);
2132 rate_a
2133 .partial_cmp(rate_b)
2134 .unwrap_or(std::cmp::Ordering::Equal)
2135 })
2136 .copied()
2137 .unwrap_or(RecoveryAction::IncreasePrecision)
2138 }
2139
2140 pub fn update_success_rate(&mut self, action: RecoveryAction, success: bool) {
2142 let current_rate = self.success_rates.get(&action).unwrap_or(&0.5);
2143 let new_rate = if success {
2144 current_rate * 0.9 + 0.1 } else {
2146 current_rate * 0.9
2147 };
2148 self.success_rates.insert(action, new_rate);
2149 }
2150}
2151
2152impl PerformanceAccuracyAnalyzer {
2153 pub fn new(params: TradeOffParams) -> Self {
2155 Self {
2156 performance_data: HashMap::new(),
2157 accuracy_data: HashMap::new(),
2158 optimization_params: params,
2159 pareto_frontier: Vec::new(),
2160 }
2161 }
2162
2163 pub fn record_performance(&mut self, precision: PrecisionMode, duration: Duration) {
2165 self.performance_data
2166 .entry(precision)
2167 .or_default()
2168 .push_back(duration);
2169
2170 if let Some(history) = self.performance_data.get_mut(&precision) {
2172 if history.len() > 100 {
2173 history.pop_front();
2174 }
2175 }
2176 }
2177
2178 pub fn record_accuracy(&mut self, precision: PrecisionMode, accuracy: f64) {
2180 self.accuracy_data
2181 .entry(precision)
2182 .or_default()
2183 .push_back(accuracy);
2184
2185 if let Some(history) = self.accuracy_data.get_mut(&precision) {
2187 if history.len() > 100 {
2188 history.pop_front();
2189 }
2190 }
2191 }
2192
2193 pub fn optimize_precision(&mut self) -> PrecisionMode {
2195 self.update_pareto_frontier();
2196
2197 match self.optimization_params.objective {
2198 OptimizationObjective::MaxPerformance => self
2199 .pareto_frontier
2200 .iter()
2201 .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
2202 .map(|(_a, b, mode)| *mode)
2203 .unwrap_or(PrecisionMode::Mixed16),
2204 OptimizationObjective::MaxAccuracy => self
2205 .pareto_frontier
2206 .iter()
2207 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
2208 .map(|(_a, b, mode)| *mode)
2209 .unwrap_or(PrecisionMode::Full32),
2210 OptimizationObjective::Balanced => {
2211 let mut best_score = f64::NEG_INFINITY;
2213 let mut best_mode = PrecisionMode::Mixed16;
2214
2215 let performance_weight = self.optimization_params.performance_weight;
2217 let accuracy_weight = self.optimization_params.accuracy_weight;
2218
2219 for &(perf, acc, mode) in &self.pareto_frontier {
2220 let perf_score = 1.0 / (perf + 1e-9);
2222 let score = performance_weight * perf_score + accuracy_weight * acc;
2223 if score > best_score {
2224 best_score = score;
2225 best_mode = mode;
2226 }
2227 }
2228
2229 best_mode
2230 }
2231 _ => PrecisionMode::Mixed16,
2232 }
2233 }
2234
2235 fn update_pareto_frontier(&mut self) {
2237 self.pareto_frontier.clear();
2238
2239 for precision in [
2240 PrecisionMode::Full32,
2241 PrecisionMode::BrainFloat16,
2242 PrecisionMode::Mixed16,
2243 PrecisionMode::Int8Dynamic,
2244 PrecisionMode::Int4Advanced,
2245 ] {
2246 if let (Some(perf_data), Some(acc_data)) = (
2247 self.performance_data.get(&precision),
2248 self.accuracy_data.get(&precision),
2249 ) {
2250 if !perf_data.is_empty() && !acc_data.is_empty() {
2251 let avg_perf = perf_data.iter().map(|d| d.as_secs_f64()).sum::<f64>()
2252 / perf_data.len() as f64;
2253 let avg_acc = acc_data.iter().sum::<f64>() / acc_data.len() as f64;
2254
2255 self.pareto_frontier.push((avg_perf, avg_acc, precision));
2256 }
2257 }
2258 }
2259 }
2260
2261 #[allow(dead_code)]
2263 fn compute_weighted_score(&mut self, performance: f64, accuracy: f64) -> f64 {
2264 let perf_score = 1.0 / (performance + 1e-9);
2266
2267 self.optimization_params.performance_weight * perf_score
2269 + self.optimization_params.accuracy_weight * accuracy
2270 }
2271}
2272
2273impl AdvancedTensorCoreDistanceMatrix {
2274 pub fn new() -> SpatialResult<Self> {
2276 let base_computer = TensorCoreDistanceMatrix::new()?;
2277 let precision_config = DynamicPrecisionConfig::default();
2278 let stability_monitor =
2279 Arc::new(Mutex::new(NumericalStabilityMonitor::new(precision_config)));
2280 let recovery_system = ErrorRecoverySystem::new();
2281 let trade_off_params = TradeOffParams {
2282 performance_weight: 0.6,
2283 accuracy_weight: 0.4,
2284 energy_weight: 0.0,
2285 min_accuracy: 0.95,
2286 max_time: Duration::from_secs(30),
2287 objective: OptimizationObjective::Balanced,
2288 };
2289 let performance_analyzer = PerformanceAccuracyAnalyzer::new(trade_off_params);
2290
2291 Ok(Self {
2292 base_computer,
2293 stability_monitor,
2294 recovery_system,
2295 performance_analyzer,
2296 dynamic_precision_enabled: true,
2297 auto_recovery_enabled: true,
2298 })
2299 }
2300
2301 pub fn with_dynamic_precision(mut self, enabled: bool) -> Self {
2303 self.dynamic_precision_enabled = enabled;
2304 self
2305 }
2306
2307 pub fn with_auto_recovery(mut self, enabled: bool) -> Self {
2309 self.auto_recovery_enabled = enabled;
2310 self
2311 }
2312
2313 pub async fn compute_with_stability_monitoring(
2315 &mut self,
2316 points: &ArrayView2<'_, f64>,
2317 ) -> SpatialResult<Array2<f64>> {
2318 let start_time = Instant::now();
2319
2320 {
2322 let mut monitor = self.stability_monitor.lock().unwrap();
2323 if self.dynamic_precision_enabled {
2326 let optimal_precision = monitor.adjust_precision()?;
2327 self.base_computer.precision_mode = optimal_precision;
2328 }
2329 }
2330
2331 let mut result = None;
2332 let mut recovery_attempts = 0;
2333 let max_attempts = 3;
2334
2335 while result.is_none() && recovery_attempts < max_attempts {
2336 match self.base_computer.compute_parallel(points).await {
2337 Ok(distances) => {
2338 {
2340 let mut monitor = self.stability_monitor.lock().unwrap();
2341 monitor.monitor_stability(&points.to_owned(), &distances)?;
2342 }
2343
2344 let stability_level = {
2346 let monitor = self.stability_monitor.lock().unwrap();
2347 monitor.current_metrics.stability_level
2348 };
2349
2350 if stability_level == StabilityLevel::Critical && self.auto_recovery_enabled {
2351 recovery_attempts += 1;
2353 let recovery_action = self
2354 .recovery_system
2355 .attempt_recovery(NumericalErrorType::IllConditioned)
2356 .await?;
2357
2358 self.apply_recovery_action(recovery_action).await?;
2360 continue;
2361 } else {
2362 result = Some(distances);
2363 }
2364 }
2365 Err(e) => {
2366 if self.auto_recovery_enabled && recovery_attempts < max_attempts {
2367 recovery_attempts += 1;
2368 let recovery_action = self
2369 .recovery_system
2370 .attempt_recovery(NumericalErrorType::InvalidValues)
2371 .await?;
2372 self.apply_recovery_action(recovery_action).await?;
2373 continue;
2374 } else {
2375 return Err(e);
2376 }
2377 }
2378 }
2379 }
2380
2381 let final_result = result.ok_or_else(|| {
2382 SpatialError::InvalidInput(
2383 "Failed to compute stable result after recovery attempts".to_string(),
2384 )
2385 })?;
2386
2387 let duration = start_time.elapsed();
2389 let precision = self.base_computer.precision_mode;
2390 self.performance_analyzer
2391 .record_performance(precision, duration);
2392
2393 let accuracy = self.estimate_result_accuracy(&final_result);
2395 self.performance_analyzer
2396 .record_accuracy(precision, accuracy);
2397
2398 Ok(final_result)
2399 }
2400
2401 async fn apply_recovery_action(&mut self, action: RecoveryAction) -> SpatialResult<()> {
2403 match action {
2404 RecoveryAction::IncreasePrecision => {
2405 self.base_computer.precision_mode = match self.base_computer.precision_mode {
2406 PrecisionMode::Int4Advanced => PrecisionMode::Int8Dynamic,
2407 PrecisionMode::Int8Dynamic => PrecisionMode::Mixed16,
2408 PrecisionMode::Mixed16 => PrecisionMode::BrainFloat16,
2409 PrecisionMode::BrainFloat16 => PrecisionMode::Full32,
2410 PrecisionMode::Full32 => PrecisionMode::Full32,
2411 _ => PrecisionMode::Mixed16,
2412 };
2413 }
2414 RecoveryAction::ReduceTileSize => {
2415 let (current_row, current_col) = self.base_computer.tile_size;
2416 self.base_computer.tile_size = (current_row / 2, current_col / 2);
2417 if self.base_computer.tile_size.0 < 16 {
2418 self.base_computer.tile_size = (16, 16);
2419 }
2420 }
2421 RecoveryAction::FallbackAlgorithm => {
2422 self.base_computer.precision_mode = PrecisionMode::Full32;
2424 self.base_computer.hierarchical_tiling = false;
2425 }
2426 RecoveryAction::NumericalStabilization => {
2427 self.base_computer.precision_mode = PrecisionMode::Full32;
2429 self.base_computer.tile_size = (64, 64);
2430 }
2431 _ => {
2432 self.base_computer.precision_mode = PrecisionMode::Full32;
2434 }
2435 }
2436
2437 Ok(())
2438 }
2439
2440 fn estimate_result_accuracy(&self, result: &Array2<f64>) -> f64 {
2442 let has_invalid = result.iter().any(|&x| !x.is_finite());
2444 if has_invalid {
2445 return 0.0;
2446 }
2447
2448 let max_val = result.fold(0.0f64, |acc, &x| acc.max(x.abs()));
2449 let min_val = result.fold(f64::INFINITY, |acc, &x| {
2450 if x.abs() > 1e-15 {
2451 acc.min(x.abs())
2452 } else {
2453 acc
2454 }
2455 });
2456
2457 if min_val.is_finite() && min_val > 0.0 {
2458 let dynamic_range = max_val / min_val;
2459 (1.0 / (1.0 + dynamic_range.log10() / 10.0)).clamp(0.0, 1.0)
2460 } else {
2461 0.95 }
2463 }
2464}
2465
2466#[allow(dead_code)]
2468pub fn detect_tensor_core_capabilities() -> SpatialResult<TensorCoreCapabilities> {
2469 Ok(TensorCoreCapabilities {
2473 tensor_core_types: vec![
2474 TensorCoreType::NvidiaTensorCore,
2475 TensorCoreType::StandardCores,
2476 ],
2477 supported_precisions: vec![
2478 PrecisionMode::Full32,
2479 PrecisionMode::Mixed16,
2480 PrecisionMode::BrainFloat16,
2481 PrecisionMode::Int8Dynamic,
2482 ],
2483 max_tensor_size: (4096, 4096, 4096),
2484 peak_throughput_tops: 312.0, memory_bandwidth_gbps: 1555.0, l2_cache_mb: 40.0,
2487 num_sms: 108,
2488 architecture: GpuArchitecture::Ampere,
2489 })
2490}
2491
2492impl TensorCoreDistanceMatrix {
2494 pub async fn compute_distances_to_centroids(
2496 &self,
2497 points: &ArrayView2<'_, f64>,
2498 centroids: &ArrayView2<'_, f64>,
2499 ) -> SpatialResult<Array2<f64>> {
2500 let (npoints, ndims) = points.dim();
2501 let (n_clusters_, n_dims_c) = centroids.dim();
2502 let mut distances = Array2::zeros((npoints, n_clusters_));
2503
2504 let cluster_chunks = n_clusters_ / 4;
2506
2507 for i in 0..npoints {
2508 let point_row = points.row(i);
2509
2510 for j_chunk in 0..cluster_chunks {
2512 let j_base = j_chunk * 4;
2513
2514 let j0 = j_base;
2516 let j1 = j_base + 1;
2517 let j2 = j_base + 2;
2518 let j3 = j_base + 3;
2519
2520 let centroid_row0 = centroids.row(j0);
2521 let centroid_row1 = centroids.row(j1);
2522 let centroid_row2 = centroids.row(j2);
2523 let centroid_row3 = centroids.row(j3);
2524
2525 let distance0: f64 = point_row
2526 .iter()
2527 .zip(centroid_row0.iter())
2528 .map(|(&a, &b)| (a - b).powi(2))
2529 .sum::<f64>()
2530 .sqrt();
2531
2532 let distance1: f64 = point_row
2533 .iter()
2534 .zip(centroid_row1.iter())
2535 .map(|(&a, &b)| (a - b).powi(2))
2536 .sum::<f64>()
2537 .sqrt();
2538
2539 let distance2: f64 = point_row
2540 .iter()
2541 .zip(centroid_row2.iter())
2542 .map(|(&a, &b)| (a - b).powi(2))
2543 .sum::<f64>()
2544 .sqrt();
2545
2546 let distance3: f64 = point_row
2547 .iter()
2548 .zip(centroid_row3.iter())
2549 .map(|(&a, &b)| (a - b).powi(2))
2550 .sum::<f64>()
2551 .sqrt();
2552
2553 distances[[i, j0]] = distance0;
2554 distances[[i, j1]] = distance1;
2555 distances[[i, j2]] = distance2;
2556 distances[[i, j3]] = distance3;
2557 }
2558
2559 for j in (cluster_chunks * 4)..n_clusters_ {
2561 let distance: f64 = point_row
2562 .iter()
2563 .zip(centroids.row(j).iter())
2564 .map(|(&a, &b)| (a - b).powi(2))
2565 .sum::<f64>()
2566 .sqrt();
2567 distances[[i, j]] = distance;
2568 }
2569 }
2570
2571 Ok(distances)
2572 }
2573}
2574
2575#[cfg(test)]
2576#[path = "tensor_cores_tests.rs"]
2577mod tests;