1use crate::error::{SpatialError, SpatialResult};
57use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
58use scirs2_core::random::Rng;
59use statrs::statistics::Statistics;
60use std::collections::{HashMap, VecDeque};
61use std::sync::{Arc, Mutex};
62use std::time::{Duration, Instant};
63
/// Numeric precision strategy selected for a distance computation.
///
/// `TensorCoreDistanceMatrix::compute_tile_tensor_cores` dispatches on this
/// value to choose one of the `compute_distances_*` routines.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrecisionMode {
    /// Dispatched to `compute_distances_fp32`.
    Full32,
    /// Dispatched to `compute_distances_mixed16`.
    Mixed16,
    /// Dispatched to `compute_distances_bf16`.
    BrainFloat16,
    /// Dispatched to `compute_distances_int8`.
    Int8Dynamic,
    /// Dispatched to `compute_distances_int4`.
    Int4Advanced,
    /// Dispatched to `compute_distances_adaptive`.
    Adaptive,
    /// Dispatched to `compute_distances_adaptive` (same routine as `Adaptive`).
    AdvancedAdaptive,
}
82
/// Qualitative stability classification recorded in
/// `StabilityMetrics::stability_level`.
///
/// NOTE(review): the variant names suggest an ordering from best
/// (`Excellent`) to worst (`Critical`); the thresholds that map measurements
/// onto these levels are not defined in this part of the file.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StabilityLevel {
    Excellent,
    Good,
    Moderate,
    Poor,
    Critical,
}
97
/// Category of numerical problem.
///
/// Used as the key of `ErrorRecoverySystem::recovery_strategies`, listed in
/// `StabilityMetrics::error_types` and recorded per `RecoveryAttempt`.
/// NOTE(review): the detection criteria for each category are not visible in
/// this part of the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NumericalErrorType {
    Overflow,
    Underflow,
    PrecisionLoss,
    ConvergenceFailure,
    IllConditioned,
    InvalidValues,
}
114
/// Precision-scaling policy stored in `DynamicPrecisionConfig::strategy`.
///
/// NOTE(review): how each policy influences precision changes is implemented
/// outside the code visible here.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScalingStrategy {
    Conservative,
    Balanced,
    Aggressive,
    Custom,
}
127
/// Data layout applied by `TensorCoreDistanceMatrix::optimize_tensor_layout`
/// before distances are computed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TensorLayout {
    /// The input is copied unchanged.
    RowMajor,
    /// The input is copied through a transposed staging buffer.
    ColMajor,
    /// Handled by `create_blocked_layout`.
    Blocked,
    /// Handled by `create_zorder_layout`.
    ZOrder,
    /// Handled by `create_hardware_optimized_layout`, which picks a routine
    /// from the detected `GpuArchitecture`.
    HardwareOptimized,
}
142
/// GPU architecture family reported in `TensorCoreCapabilities::architecture`.
///
/// `create_hardware_optimized_layout` groups the variants as follows:
/// `Ampere`/`Hopper` use the NVIDIA layout, `CDNA2`/`CDNA3` the AMD layout,
/// `XeHPC`/`XeGraphics` the Intel layout, and every other variant falls back
/// to the blocked layout.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GpuArchitecture {
    Volta,
    Ampere,
    Hopper,
    CDNA2,
    CDNA3,
    XeHPC,
    XeGraphics,
    Unknown,
}
163
/// Description of the matrix-acceleration hardware, as produced by
/// `detect_tensor_core_capabilities`.
///
/// NOTE(review): units below are inferred from the field names only
/// (`_tops`, `_gbps`, `_mb`) — confirm against the detection code.
#[derive(Debug, Clone)]
pub struct TensorCoreCapabilities {
    /// Kinds of matrix units available.
    pub tensor_core_types: Vec<TensorCoreType>,
    /// Precision modes the hardware is reported to support.
    pub supported_precisions: Vec<PrecisionMode>,
    /// Maximum tensor extent as a 3-tuple (axis meaning not shown here).
    pub max_tensor_size: (usize, usize, usize),
    /// Peak throughput (presumably tera-operations per second).
    pub peak_throughput_tops: f64,
    /// Memory bandwidth (presumably GB/s).
    pub memory_bandwidth_gbps: f64,
    /// L2 cache size (presumably megabytes).
    pub l2_cache_mb: f64,
    /// Number of SMs (presumably streaming multiprocessors / compute units).
    pub num_sms: usize,
    /// Architecture family; drives `create_hardware_optimized_layout`.
    pub architecture: GpuArchitecture,
}
184
/// Kind of matrix unit listed in `TensorCoreCapabilities::tensor_core_types`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TensorCoreType {
    NvidiaTensorCore,
    AmdMatrixCore,
    IntelXMX,
    StandardCores,
}
197
/// Snapshot of numerical-stability measurements taken at `timestamp`.
///
/// NOTE(review): the formulas that fill these fields live outside the code
/// visible here; the descriptions below restate the field names only.
#[derive(Debug, Clone)]
pub struct StabilityMetrics {
    /// Condition-number estimate.
    pub condition_number: f64,
    /// Relative error estimate.
    pub relative_error: f64,
    /// Forward error estimate.
    pub forward_error: f64,
    /// Backward error estimate.
    pub backward_error: f64,
    /// Estimated loss of significant digits.
    pub digit_loss: f64,
    /// Qualitative classification of the snapshot.
    pub stability_level: StabilityLevel,
    /// Error categories associated with the snapshot.
    pub error_types: Vec<NumericalErrorType>,
    /// Moment the snapshot was taken.
    pub timestamp: Instant,
}
218
/// Tuning knobs for dynamic precision scaling, held by
/// `NumericalStabilityMonitor::precision_config`.
///
/// NOTE(review): the code that consumes these values is not visible here;
/// field semantics are taken from the names and should be confirmed.
#[derive(Debug, Clone)]
pub struct DynamicPrecisionConfig {
    /// Overall scaling policy.
    pub strategy: ScalingStrategy,
    /// Lower bound on the precision mode.
    pub min_precision: PrecisionMode,
    /// Upper bound on the precision mode.
    pub max_precision: PrecisionMode,
    /// Stability threshold associated with raising precision.
    pub stability_threshold_up: f64,
    /// Stability threshold associated with lowering precision.
    pub stability_threshold_down: f64,
    /// Weight given to performance.
    pub performance_weight: f64,
    /// Weight given to accuracy.
    pub accuracy_weight: f64,
    /// Cap on precision changes within one operation.
    pub max_changes_per_operation: usize,
    /// Minimum time between precision changes.
    pub change_cooldown: Duration,
}
241
/// State kept for monitoring numerical stability and the precision mode in
/// use. Shared through `Arc<Mutex<_>>` by `AdvancedTensorCoreDistanceMatrix`.
#[allow(dead_code)]
#[derive(Debug)]
pub struct NumericalStabilityMonitor {
    /// Most recent measurements.
    current_metrics: StabilityMetrics,
    /// Earlier measurements.
    stability_history: VecDeque<StabilityMetrics>,
    /// Configuration for dynamic precision scaling.
    precision_config: DynamicPrecisionConfig,
    /// Precision mode currently selected.
    current_precision: PrecisionMode,
    /// Log of `(time, precision, value)` entries; the meaning of the `f64`
    /// is not shown in this part of the file.
    precision_history: VecDeque<(Instant, PrecisionMode, f64)>,
    /// Counter of recovery attempts.
    #[allow(dead_code)]
    recovery_attempts: usize,
    /// Length limit for the history queues (enforcement not visible here).
    max_history_length: usize,
    /// Time of the last precision change, if any.
    last_precision_change: Option<Instant>,
}
264
/// Bookkeeping for recovering from numerical errors.
#[allow(dead_code)]
#[derive(Debug)]
pub struct ErrorRecoverySystem {
    /// Candidate actions for each error category.
    recovery_strategies: HashMap<NumericalErrorType, Vec<RecoveryAction>>,
    /// Record of past recovery attempts.
    recovery_history: VecDeque<RecoveryAttempt>,
    /// Limit on recovery attempts (enforcement not visible here).
    max_recovery_attempts: usize,
    /// Per-action success rate.
    success_rates: HashMap<RecoveryAction, f64>,
}
278
/// Action that can be attempted after a numerical error.
///
/// Hashable so it can key `ErrorRecoverySystem::success_rates`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecoveryAction {
    IncreasePrecision,
    ReduceTileSize,
    FallbackAlgorithm,
    NumericalStabilization,
    RetryWithNewParams,
    SwitchToCPU,
}
295
/// Record of one recovery attempt, as stored in
/// `ErrorRecoverySystem::recovery_history`.
#[derive(Debug, Clone)]
pub struct RecoveryAttempt {
    /// Error category that triggered the attempt.
    pub error_type: NumericalErrorType,
    /// Action that was tried.
    pub action: RecoveryAction,
    /// Whether the attempt succeeded.
    pub success: bool,
    /// How long the attempt took.
    pub duration: Duration,
    /// Stability measurements taken after the attempt, when available.
    pub post_recovery_metrics: Option<StabilityMetrics>,
    /// When the attempt was made.
    pub timestamp: Instant,
}
312
/// Collected timing and accuracy samples per precision mode, plus the
/// parameters used to trade one against the other.
#[derive(Debug)]
pub struct PerformanceAccuracyAnalyzer {
    /// Timing samples per precision mode.
    performance_data: HashMap<PrecisionMode, VecDeque<Duration>>,
    /// Accuracy samples per precision mode.
    accuracy_data: HashMap<PrecisionMode, VecDeque<f64>>,
    /// Weights and constraints for the trade-off.
    optimization_params: TradeOffParams,
    /// Pareto-frontier entries; the meaning of the two `f64` components is
    /// not shown in this part of the file.
    pareto_frontier: Vec<(f64, f64, PrecisionMode)>,
}
325
/// Weights and constraints stored in
/// `PerformanceAccuracyAnalyzer::optimization_params`.
///
/// NOTE(review): how these values are combined is implemented outside the
/// code visible here.
#[derive(Debug, Clone)]
pub struct TradeOffParams {
    /// Weight given to performance.
    pub performance_weight: f64,
    /// Weight given to accuracy.
    pub accuracy_weight: f64,
    /// Weight given to energy use.
    pub energy_weight: f64,
    /// Minimum acceptable accuracy.
    pub min_accuracy: f64,
    /// Maximum acceptable time.
    pub max_time: Duration,
    /// Objective being optimized.
    pub objective: OptimizationObjective,
}
342
/// Objective selected in `TradeOffParams::objective`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizationObjective {
    MaxPerformance,
    MaxAccuracy,
    Balanced,
    MinEnergy,
    Custom,
}
357
/// Distance-matrix computer that bundles the base computer with stability
/// monitoring, error recovery and performance/accuracy analysis state.
#[derive(Debug)]
pub struct AdvancedTensorCoreDistanceMatrix {
    /// Underlying distance-matrix computer.
    base_computer: TensorCoreDistanceMatrix,
    /// Stability monitor behind an `Arc<Mutex<_>>` so it can be shared.
    stability_monitor: Arc<Mutex<NumericalStabilityMonitor>>,
    /// Error-recovery bookkeeping.
    recovery_system: ErrorRecoverySystem,
    /// Performance/accuracy samples and trade-off parameters.
    performance_analyzer: PerformanceAccuracyAnalyzer,
    /// Flag for dynamic precision scaling.
    dynamic_precision_enabled: bool,
    /// Flag for automatic error recovery.
    auto_recovery_enabled: bool,
}
374
/// Configurable pairwise-distance computer.
///
/// Built with `new()` and customised through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct TensorCoreDistanceMatrix {
    /// Precision mode used to pick the `compute_distances_*` routine.
    precision_mode: PrecisionMode,
    /// When `true`, `compute_parallel` runs `optimize_tensor_layout` first.
    layout_optimization: bool,
    /// When `true`, inputs with more than 1024 points use the tiled path.
    hierarchical_tiling: bool,
    /// Tile extent `(rows, cols)` for the tiled path.
    tile_size: (usize, usize),
    /// Detected hardware description, if any.
    capabilities: Option<TensorCoreCapabilities>,
    /// Layout applied by `optimize_tensor_layout`.
    tensor_layout: TensorLayout,
    /// Configured stream count (not read by the methods visible here).
    execution_streams: usize,
}
393
394impl TensorCoreDistanceMatrix {
395 pub fn new() -> SpatialResult<Self> {
397 let capabilities = detect_tensor_core_capabilities()?;
398
399 Ok(Self {
400 precision_mode: PrecisionMode::Mixed16,
401 layout_optimization: true,
402 hierarchical_tiling: true,
403 tile_size: (256, 256),
404 capabilities: Some(capabilities),
405 tensor_layout: TensorLayout::HardwareOptimized,
406 execution_streams: 4,
407 })
408 }
409
410 pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
412 self.precision_mode = mode;
413 self
414 }
415
416 pub fn with_tensor_layout_optimization(mut self, enabled: bool) -> Self {
418 self.layout_optimization = enabled;
419 self
420 }
421
422 pub fn with_hierarchical_tiling(mut self, enabled: bool) -> Self {
424 self.hierarchical_tiling = enabled;
425 self
426 }
427
428 pub fn with_tile_size(mut self, rows: usize, cols: usize) -> Self {
430 self.tile_size = (rows, cols);
431 self
432 }
433
434 pub fn with_execution_streams(mut self, streams: usize) -> Self {
436 self.execution_streams = streams;
437 self
438 }
439
440 pub async fn compute_parallel(
442 &mut self,
443 points: &ArrayView2<'_, f64>,
444 ) -> SpatialResult<Array2<f64>> {
445 let (npoints, ndims) = points.dim();
446
447 if npoints == 0 || ndims == 0 {
448 return Err(SpatialError::InvalidInput("Empty input data".to_string()));
449 }
450
451 let optimizedpoints = if self.layout_optimization {
453 self.optimize_tensor_layout(points)?
454 } else {
455 points.to_owned()
456 };
457
458 if self.hierarchical_tiling && npoints > 1024 {
460 self.compute_hierarchical_tiled(&optimizedpoints.view())
461 .await
462 } else {
463 self.compute_direct_tensor_cores(&optimizedpoints.view())
464 .await
465 }
466 }
467
468 fn optimize_tensor_layout(
470 &mut self,
471 points: &ArrayView2<'_, f64>,
472 ) -> SpatialResult<Array2<f64>> {
473 let (npoints, ndims) = points.dim();
474
475 match self.tensor_layout {
476 TensorLayout::RowMajor => Ok(points.to_owned()),
477 TensorLayout::ColMajor => {
478 let mut transposed = Array2::zeros((ndims, npoints));
479 for (i, point) in points.outer_iter().enumerate() {
480 transposed.column_mut(i).assign(&point);
481 }
482 Ok(transposed.t().to_owned())
483 }
484 TensorLayout::Blocked => TensorCoreDistanceMatrix::create_blocked_layout(points),
485 TensorLayout::ZOrder => self.create_zorder_layout(points),
486 TensorLayout::HardwareOptimized => self.create_hardware_optimized_layout(points),
487 }
488 }
489
490 fn create_blocked_layout(points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
492 let (npoints, ndims) = points.dim();
493 let block_size = 64; let blocked_rows = npoints.div_ceil(block_size) * block_size;
496 let blocked_cols = ndims.div_ceil(block_size) * block_size;
497
498 let mut blocked_data = Array2::zeros((blocked_rows, blocked_cols));
499
500 for block_i in 0..(npoints / block_size + 1) {
501 for block_j in 0..(ndims / block_size + 1) {
502 let start_i = block_i * block_size;
503 let start_j = block_j * block_size;
504 let end_i = (start_i + block_size).min(npoints);
505 let end_j = (start_j + block_size).min(ndims);
506
507 for i in start_i..end_i {
508 for j in start_j..end_j {
509 blocked_data[[i, j]] = points[[i, j]];
510 }
511 }
512 }
513 }
514
515 Ok(blocked_data.slice(s![..npoints, ..ndims]).to_owned())
516 }
517
518 fn create_zorder_layout(&mut self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
520 let (npoints, ndims) = points.dim();
521
522 let mut z_indices: Vec<(usize, usize)> = (0..npoints)
524 .map(|i| {
525 (
526 i,
527 TensorCoreDistanceMatrix::calculate_z_order_index(i, ndims),
528 )
529 })
530 .collect();
531
532 z_indices.sort_by_key(|(_, z_idx)| *z_idx);
533
534 let mut reordered_data = Array2::zeros((npoints, ndims));
535 for (new_idx, (old_idx, z_idx)) in z_indices.iter().enumerate() {
536 reordered_data
537 .row_mut(new_idx)
538 .assign(&points.row(*old_idx));
539 }
540
541 Ok(reordered_data)
542 }
543
544 fn calculate_z_order_index(point_idx: usize, ndims: usize) -> usize {
546 let mut z_index = 0;
548 let temp_idx = point_idx;
549
550 for bit in 0..16 {
551 for dim in 0..ndims.min(3) {
553 if temp_idx & (1 << bit) != 0 {
555 z_index |= 1 << (bit * ndims + dim);
556 }
557 }
558 }
559
560 z_index
561 }
562
563 fn create_hardware_optimized_layout(
565 &self,
566 points: &ArrayView2<'_, f64>,
567 ) -> SpatialResult<Array2<f64>> {
568 if let Some(ref capabilities) = self.capabilities {
569 match capabilities.architecture {
570 GpuArchitecture::Ampere | GpuArchitecture::Hopper => {
571 self.create_nvidia_optimized_layout(points)
573 }
574 GpuArchitecture::CDNA2 | GpuArchitecture::CDNA3 => {
575 self.create_amd_optimized_layout(points)
577 }
578 GpuArchitecture::XeHPC | GpuArchitecture::XeGraphics => {
579 self.create_intel_optimized_layout(points)
581 }
582 _ => {
583 TensorCoreDistanceMatrix::create_blocked_layout(points)
585 }
586 }
587 } else {
588 TensorCoreDistanceMatrix::create_blocked_layout(points)
589 }
590 }
591
592 fn create_nvidia_optimized_layout(
594 &self,
595 points: &ArrayView2<'_, f64>,
596 ) -> SpatialResult<Array2<f64>> {
597 let (npoints, ndims) = points.dim();
598
599 let paddedpoints = npoints.div_ceil(8) * 8;
601 let padded_dims = ndims.div_ceil(8) * 8;
602
603 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
604
605 for i in 0..npoints {
607 for j in 0..ndims {
608 padded_data[[i, j]] = points[[i, j]];
609 }
610 }
611
612 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
614 }
615
616 fn create_amd_optimized_layout(
618 &self,
619 points: &ArrayView2<'_, f64>,
620 ) -> SpatialResult<Array2<f64>> {
621 let (npoints, ndims) = points.dim();
622
623 let paddedpoints = npoints.div_ceil(16) * 16;
625 let padded_dims = ndims.div_ceil(16) * 16;
626
627 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
628
629 for i in 0..npoints {
630 for j in 0..ndims {
631 padded_data[[i, j]] = points[[i, j]];
632 }
633 }
634
635 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
636 }
637
638 fn create_intel_optimized_layout(
640 &self,
641 points: &ArrayView2<'_, f64>,
642 ) -> SpatialResult<Array2<f64>> {
643 let (npoints, ndims) = points.dim();
644
645 let paddedpoints = npoints.div_ceil(32) * 32;
647 let padded_dims = ndims.div_ceil(32) * 32;
648
649 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
650
651 for i in 0..npoints {
652 for j in 0..ndims {
653 padded_data[[i, j]] = points[[i, j]];
654 }
655 }
656
657 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
658 }
659
660 async fn compute_hierarchical_tiled(
662 &mut self,
663 points: &ArrayView2<'_, f64>,
664 ) -> SpatialResult<Array2<f64>> {
665 let (npoints, ndims) = points.dim();
666 let mut distance_matrix = Array2::zeros((npoints, npoints));
667
668 let (tile_rows, tile_cols) = self.tile_size;
669 let precision_mode = self.precision_mode; let mut tile_futures = Vec::new();
673
674 for i in (0..npoints).step_by(tile_rows) {
675 for j in (0..npoints).step_by(tile_cols) {
676 let end_i = (i + tile_rows).min(npoints);
677 let end_j = (j + tile_cols).min(npoints);
678
679 let tilepoints_i = points.slice(s![i..end_i, ..]).to_owned();
680 let tilepoints_j = points.slice(s![j..end_j, ..]).to_owned();
681
682 let future = async move {
684 let (rows_i, _) = tilepoints_i.dim();
686 let (rows_j, _) = tilepoints_j.dim();
687 let mut tile_distances = Array2::zeros((rows_i, rows_j));
688
689 for r in 0..rows_i {
690 for c in 0..rows_j {
691 let p1 = tilepoints_i.row(r);
692 let p2 = tilepoints_j.row(c);
693
694 let dist = if ndims <= 16 {
696 use scirs2_core::simd_ops::SimdUnifiedOps;
698 let diff = f64::simd_sub(&p1, &p2);
699 let squared = f64::simd_mul(&diff.view(), &diff.view());
700 f64::simd_sum(&squared.view()).sqrt()
701 } else {
702 let diff = &p1 - &p2;
704 diff.iter().map(|x| x.powi(2)).sum::<f64>().sqrt()
705 };
706 tile_distances[[r, c]] = dist;
707 }
708 }
709 Ok::<Array2<f64>, SpatialError>(tile_distances)
710 };
711 tile_futures.push((i, j, end_i, end_j, future));
712 }
713 }
714
715 for (i, j, end_i, end_j, future) in tile_futures {
717 let tile_result = future.await?;
718
719 let tile_rows = end_i - i;
721 let tile_cols = end_j - j;
722
723 for row in 0..tile_rows {
724 for col in 0..tile_cols {
725 distance_matrix[[i + row, j + col]] = tile_result[[row, col]];
726 }
727 }
728 }
729
730 Ok(distance_matrix)
731 }
732
733 async fn compute_tile_tensor_cores(
735 &mut self,
736 points_i: Array2<f64>,
737 points_j: Array2<f64>,
738 precision_mode: PrecisionMode,
739 ) -> SpatialResult<Array2<f64>> {
740 let (_n_i, ndims) = points_i.dim();
741 let (_n_j, _) = points_j.dim();
742
743 match precision_mode {
744 PrecisionMode::Full32 => {
745 self.compute_distances_fp32(&points_i.view(), &points_j.view())
746 .await
747 }
748 PrecisionMode::Mixed16 => {
749 self.compute_distances_mixed16(&points_i.view(), &points_j.view())
750 .await
751 }
752 PrecisionMode::BrainFloat16 => {
753 self.compute_distances_bf16(&points_i.view(), &points_j.view())
754 .await
755 }
756 PrecisionMode::Int8Dynamic => {
757 self.compute_distances_int8(&points_i.view(), &points_j.view())
758 .await
759 }
760 PrecisionMode::Int4Advanced => {
761 self.compute_distances_int4(&points_i.view(), &points_j.view())
762 .await
763 }
764 PrecisionMode::Adaptive => {
765 self.compute_distances_adaptive(&points_i.view(), &points_j.view())
766 .await
767 }
768 PrecisionMode::AdvancedAdaptive => {
769 self.compute_distances_adaptive(&points_i.view(), &points_j.view())
770 .await
771 }
772 }
773 }
774
775 async fn compute_direct_tensor_cores(
777 &mut self,
778 points: &ArrayView2<'_, f64>,
779 ) -> SpatialResult<Array2<f64>> {
780 self.compute_tile_tensor_cores(points.to_owned(), points.to_owned(), self.precision_mode)
781 .await
782 }
783
784 async fn compute_distances_fp32(
786 &self,
787 points_i: &ArrayView2<'_, f64>,
788 points_j: &ArrayView2<'_, f64>,
789 ) -> SpatialResult<Array2<f64>> {
790 let (n_i, ndims) = points_i.dim();
791 let (n_j, _) = points_j.dim();
792 let mut distances = Array2::zeros((n_i, n_j));
793
794 let norms_i: Array1<f64> = points_i
799 .outer_iter()
800 .map(|point| point.iter().map(|&x| x * x).sum())
801 .collect();
802
803 let norms_j: Array1<f64> = points_j
805 .outer_iter()
806 .map(|point| point.iter().map(|&x| x * x).sum())
807 .collect();
808
809 let cross_terms = self
811 .tensor_core_gemm_fp32(points_i, &points_j.t().to_owned().view())
812 .await?;
813
814 for _i in 0..n_i {
816 for _j in 0..n_j {
817 distances[[_i, _j]] = (norms_i[_i] + norms_j[_j] - 2.0 * cross_terms[[_i, _j]])
818 .max(0.0)
819 .sqrt();
820 }
821 }
822
823 Ok(distances)
824 }
825
826 async fn compute_distances_mixed16(
828 &self,
829 points_i: &ArrayView2<'_, f64>,
830 points_j: &ArrayView2<'_, f64>,
831 ) -> SpatialResult<Array2<f64>> {
832 let points_i_f16 = TensorCoreDistanceMatrix::convert_to_fp16(points_i)?;
834 let points_j_f16 = TensorCoreDistanceMatrix::convert_to_fp16(points_j)?;
835
836 let (n_i, _) = points_i.dim();
837 let (n_j, _) = points_j.dim();
838 let mut distances = Array2::zeros((n_i, n_j));
839
840 let norms_i_f16 = TensorCoreDistanceMatrix::compute_norms_fp16(&points_i_f16)?;
842 let norms_j_f16 = TensorCoreDistanceMatrix::compute_norms_fp16(&points_j_f16)?;
843
844 let cross_terms = self
846 .tensor_core_gemm_mixed16(&points_i_f16, &points_j_f16.t().to_owned())
847 .await?;
848
849 for _i in 0..n_i {
850 for _j in 0..n_j {
851 let distance_sq = norms_i_f16[_i] as f64 + norms_j_f16[_j] as f64
852 - 2.0 * cross_terms[[_i, _j]] as f64;
853 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
854 }
855 }
856
857 Ok(distances)
858 }
859
860 async fn compute_distances_bf16(
862 &mut self,
863 points_i: &ArrayView2<'_, f64>,
864 points_j: &ArrayView2<'_, f64>,
865 ) -> SpatialResult<Array2<f64>> {
866 let points_i_bf16 = self.convert_to_bf16(points_i)?;
869 let points_j_bf16 = self.convert_to_bf16(points_j)?;
870
871 let (n_i, _) = points_i.dim();
872 let (n_j, _) = points_j.dim();
873 let mut distances = Array2::zeros((n_i, n_j));
874
875 let norms_i_bf16 = self.compute_norms_bf16(&points_i_bf16)?;
876 let norms_j_bf16 = self.compute_norms_bf16(&points_j_bf16)?;
877
878 let cross_terms = self
879 .tensor_core_gemm_bf16(&points_i_bf16, &points_j_bf16.t().to_owned())
880 .await?;
881
882 for _i in 0..n_i {
883 for _j in 0..n_j {
884 let distance_sq = norms_i_bf16[_i] as f64 + norms_j_bf16[_j] as f64
885 - 2.0 * cross_terms[[_i, _j]] as f64;
886 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
887 }
888 }
889
890 Ok(distances)
891 }
892
893 async fn compute_distances_int8(
895 &self,
896 points_i: &ArrayView2<'_, f64>,
897 points_j: &ArrayView2<'_, f64>,
898 ) -> SpatialResult<Array2<f64>> {
899 let (scale_i, points_i_int8) = self.quantize_to_int8_dynamic(points_i)?;
901 let (scale_j, points_j_int8) = self.quantize_to_int8_dynamic(points_j)?;
902
903 let (n_i, _) = points_i.dim();
904 let (n_j, _) = points_j.dim();
905 let mut distances = Array2::zeros((n_i, n_j));
906
907 let combined_scale = scale_i * scale_j;
909
910 for _i in 0..n_i {
911 for _j in 0..n_j {
912 let cross_term_int32 = points_i_int8
914 .row(_i)
915 .iter()
916 .zip(points_j_int8.row(_j).iter())
917 .map(|(&a, &b)| (a as i32) * (b as i32))
918 .sum::<i32>();
919 let cross_term_f64 = cross_term_int32 as f64 * combined_scale;
920
921 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
923 let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
924
925 let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
926 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
927 }
928 }
929
930 Ok(distances)
931 }
932
933 async fn compute_distances_int4(
935 &self,
936 points_i: &ArrayView2<'_, f64>,
937 points_j: &ArrayView2<'_, f64>,
938 ) -> SpatialResult<Array2<f64>> {
939 let (scale_i, points_i_int4) = self.quantize_to_int4_advanced(points_i)?;
941 let (scale_j, points_j_int4) = self.quantize_to_int4_advanced(points_j)?;
942
943 let points_i_int8 = TensorCoreDistanceMatrix::int4_to_int8(&points_i_int4);
945 let points_j_int8 = TensorCoreDistanceMatrix::int4_to_int8(&points_j_int4);
946
947 let (n_i, _) = points_i.dim();
948 let (n_j, _) = points_j.dim();
949 let mut distances = Array2::zeros((n_i, n_j));
950
951 let n_i_chunks = n_i / 4;
957 let n_j_chunks = n_j / 4;
958
959 for i_chunk in 0..n_i_chunks {
961 for j_chunk in 0..n_j_chunks {
962 let i_base = i_chunk * 4;
963 let j_base = j_chunk * 4;
964
965 for i_offset in 0..4 {
967 let _i = i_base + i_offset;
968 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
969
970 let _j0 = j_base;
972 let _j1 = j_base + 1;
973 let _j2 = j_base + 2;
974 let _j3 = j_base + 3;
975
976 let norm_j0_sq: f64 = points_j.row(_j0).iter().map(|&x| x * x).sum();
977 let norm_j1_sq: f64 = points_j.row(_j1).iter().map(|&x| x * x).sum();
978 let norm_j2_sq: f64 = points_j.row(_j2).iter().map(|&x| x * x).sum();
979 let norm_j3_sq: f64 = points_j.row(_j3).iter().map(|&x| x * x).sum();
980
981 let cross_term_f64 = 0.0; let distance_sq0 = norm_i_sq + norm_j0_sq - 2.0 * cross_term_f64;
985 let distance_sq1 = norm_i_sq + norm_j1_sq - 2.0 * cross_term_f64;
986 let distance_sq2 = norm_i_sq + norm_j2_sq - 2.0 * cross_term_f64;
987 let distance_sq3 = norm_i_sq + norm_j3_sq - 2.0 * cross_term_f64;
988
989 distances[[_i, _j0]] = distance_sq0.max(0.0).sqrt();
990 distances[[_i, _j1]] = distance_sq1.max(0.0).sqrt();
991 distances[[_i, _j2]] = distance_sq2.max(0.0).sqrt();
992 distances[[_i, _j3]] = distance_sq3.max(0.0).sqrt();
993 }
994 }
995 }
996
997 for _i in (n_i_chunks * 4)..n_i {
999 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
1000 for _j in 0..n_j {
1001 let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
1002 let cross_term_f64 = 0.0; let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
1004 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
1005 }
1006 }
1007
1008 for _i in 0..(n_i_chunks * 4) {
1010 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
1011 for _j in (n_j_chunks * 4)..n_j {
1012 let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
1013 let cross_term_f64 = 0.0; let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
1015 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
1016 }
1017 }
1018
1019 Ok(distances)
1020 }
1021
1022 async fn compute_distances_adaptive(
1024 &mut self,
1025 points_i: &ArrayView2<'_, f64>,
1026 points_j: &ArrayView2<'_, f64>,
1027 ) -> SpatialResult<Array2<f64>> {
1028 let data_range = self.analyze_data_range(points_i, points_j);
1030 let condition_number = self.estimate_condition_number(points_i, points_j);
1031
1032 let optimal_precision = if condition_number > 1e6 {
1033 PrecisionMode::Full32
1034 } else if data_range > 1e3 {
1035 PrecisionMode::BrainFloat16
1036 } else if data_range > 100.0 {
1037 PrecisionMode::Mixed16
1038 } else {
1039 PrecisionMode::Int8Dynamic
1040 };
1041
1042 match optimal_precision {
1043 PrecisionMode::Full32 => self.compute_distances_fp32(points_i, points_j).await,
1044 PrecisionMode::Mixed16 => self.compute_distances_mixed16(points_i, points_j).await,
1045 PrecisionMode::BrainFloat16 => self.compute_distances_bf16(points_i, points_j).await,
1046 PrecisionMode::Int8Dynamic => self.compute_distances_int8(points_i, points_j).await,
1047 PrecisionMode::Int4Advanced => self.compute_distances_int8(points_i, points_j).await, PrecisionMode::Adaptive => self.compute_distances_mixed16(points_i, points_j).await, PrecisionMode::AdvancedAdaptive => {
1050 self.compute_distances_fp32(points_i, points_j).await
1051 } }
1053 }
1054
1055 async fn tensor_core_gemm_fp32(
1057 &self,
1058 a: &ArrayView2<'_, f64>,
1059 b: &ArrayView2<'_, f64>,
1060 ) -> SpatialResult<Array2<f64>> {
1061 let (m, k) = a.dim();
1063 let (k2, n) = b.dim();
1064
1065 if k != k2 {
1066 return Err(SpatialError::InvalidInput(
1067 "Matrix dimensions don't match for multiplication".to_string(),
1068 ));
1069 }
1070
1071 let mut c = Array2::zeros((m, n));
1072
1073 let block_size = 16; for i in (0..m).step_by(block_size) {
1077 for j in (0..n).step_by(block_size) {
1078 for kk in (0..k).step_by(block_size) {
1079 let end_i = (i + block_size).min(m);
1080 let end_j = (j + block_size).min(n);
1081 let end_k = (kk + block_size).min(k);
1082
1083 let block_rows = end_i - i;
1085 let block_cols = end_j - j;
1086 let block_k = end_k - kk;
1087
1088 let k_chunks = block_k / 4;
1090
1091 for ii in i..end_i {
1092 for jj in j..end_j {
1093 let mut accumulator = c[[ii, jj]];
1094
1095 for k_chunk in 0..k_chunks {
1097 let k_base = kk + k_chunk * 4;
1098
1099 let a_val0 = a[[ii, k_base]];
1101 let a_val1 = a[[ii, k_base + 1]];
1102 let a_val2 = a[[ii, k_base + 2]];
1103 let a_val3 = a[[ii, k_base + 3]];
1104
1105 let b_val0 = b[[k_base, jj]];
1106 let b_val1 = b[[k_base + 1, jj]];
1107 let b_val2 = b[[k_base + 2, jj]];
1108 let b_val3 = b[[k_base + 3, jj]];
1109
1110 accumulator += a_val0 * b_val0
1111 + a_val1 * b_val1
1112 + a_val2 * b_val2
1113 + a_val3 * b_val3;
1114 }
1115
1116 for kkk in (kk + k_chunks * 4)..end_k {
1118 accumulator += a[[ii, kkk]] * b[[kkk, jj]];
1119 }
1120
1121 c[[ii, jj]] = accumulator;
1122 }
1123 }
1124 }
1125 }
1126 }
1127
1128 Ok(c)
1129 }
1130
1131 async fn tensor_core_gemm_mixed16(
1133 &self,
1134 a: &Array2<f32>,
1135 b: &Array2<f32>,
1136 ) -> SpatialResult<Array2<f32>> {
1137 let (m, k) = a.dim();
1139 let (k2, n) = b.dim();
1140
1141 if k != k2 {
1142 return Err(SpatialError::InvalidInput(
1143 "Matrix dimensions don't match".to_string(),
1144 ));
1145 }
1146
1147 let mut c = Array2::zeros((m, n));
1148 let block_size = 16;
1149
1150 for i in (0..m).step_by(block_size) {
1151 for j in (0..n).step_by(block_size) {
1152 for kk in (0..k).step_by(block_size) {
1153 let end_i = (i + block_size).min(m);
1154 let end_j = (j + block_size).min(n);
1155 let end_k = (kk + block_size).min(k);
1156
1157 let block_k = end_k - kk;
1159 let k_chunks = block_k / 4;
1160
1161 for ii in i..end_i {
1162 for jj in j..end_j {
1163 let mut accumulator = c[[ii, jj]];
1164
1165 for k_chunk in 0..k_chunks {
1167 let k_base = kk + k_chunk * 4;
1168
1169 let a_val0 = a[[ii, k_base]];
1171 let a_val1 = a[[ii, k_base + 1]];
1172 let a_val2 = a[[ii, k_base + 2]];
1173 let a_val3 = a[[ii, k_base + 3]];
1174
1175 let b_val0 = b[[k_base, jj]];
1176 let b_val1 = b[[k_base + 1, jj]];
1177 let b_val2 = b[[k_base + 2, jj]];
1178 let b_val3 = b[[k_base + 3, jj]];
1179
1180 accumulator += a_val0 * b_val0
1181 + a_val1 * b_val1
1182 + a_val2 * b_val2
1183 + a_val3 * b_val3;
1184 }
1185
1186 for kkk in (kk + k_chunks * 4)..end_k {
1188 accumulator += a[[ii, kkk]] * b[[kkk, jj]];
1189 }
1190
1191 c[[ii, jj]] = accumulator;
1192 }
1193 }
1194 }
1195 }
1196 }
1197
1198 Ok(c)
1199 }
1200
    /// Matrix product for the bf16 path.
    ///
    /// Delegates directly to `tensor_core_gemm_mixed16`, so operands and
    /// result are plain `f32` matrices and the same dimension check applies.
    async fn tensor_core_gemm_bf16(
        &self,
        a: &Array2<f32>,
        b: &Array2<f32>,
    ) -> SpatialResult<Array2<f32>> {
        self.tensor_core_gemm_mixed16(a, b).await
    }
1210
1211 #[allow(dead_code)]
1213 async fn tensor_core_gemm_int8(
1214 &self,
1215 a: &Array2<i8>,
1216 b: &Array2<i8>,
1217 ) -> SpatialResult<Array2<i32>> {
1218 let (m, k) = a.dim();
1219 let (k2, n) = b.dim();
1220
1221 if k != k2 {
1222 return Err(SpatialError::InvalidInput(
1223 "Matrix dimensions don't match".to_string(),
1224 ));
1225 }
1226
1227 let mut c = Array2::zeros((m, n));
1228 let block_size = 16;
1229
1230 for i in (0..m).step_by(block_size) {
1231 for j in (0..n).step_by(block_size) {
1232 for kk in (0..k).step_by(block_size) {
1233 let end_i = (i + block_size).min(m);
1234 let end_j = (j + block_size).min(n);
1235 let end_k = (kk + block_size).min(k);
1236
1237 for ii in i..end_i {
1238 for jj in j..end_j {
1239 for kkk in kk..end_k {
1240 c[[ii, jj]] += a[[ii, kkk]] as i32 * b[[kkk, jj]] as i32;
1242 }
1243 }
1244 }
1245 }
1246 }
1247 }
1248
1249 Ok(c)
1250 }
1251
1252 fn convert_to_fp16(data: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f32>> {
1254 let (rows, cols) = data.dim();
1255 let mut fp16_data = Array2::zeros((rows, cols));
1256
1257 for i in 0..rows {
1258 for j in 0..cols {
1259 fp16_data[[i, j]] = data[[i, j]] as f32;
1261 }
1262 }
1263
1264 Ok(fp16_data)
1265 }
1266
    /// Conversion used by the bf16 path.
    ///
    /// Delegates directly to `convert_to_fp16`, so the result holds `f32`
    /// values.
    fn convert_to_bf16(&mut self, data: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f32>> {
        TensorCoreDistanceMatrix::convert_to_fp16(data)
    }
1272
1273 fn quantize_to_int8_dynamic(
1275 &self,
1276 data: &ArrayView2<'_, f64>,
1277 ) -> SpatialResult<(f64, Array2<i8>)> {
1278 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1279 let scale = max_val / 127.0; let (rows, cols) = data.dim();
1282 let mut quantized = Array2::zeros((rows, cols));
1283
1284 for i in 0..rows {
1285 for j in 0..cols {
1286 let quantized_val = (data[[i, j]] / scale).round() as i8;
1287 quantized[[i, j]] = quantized_val.clamp(-127, 127);
1288 }
1289 }
1290
1291 Ok((scale, quantized))
1292 }
1293
1294 fn quantize_to_int4_advanced(
1296 &self,
1297 data: &ArrayView2<'_, f64>,
1298 ) -> SpatialResult<(f64, Array2<i8>)> {
1299 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1300 let scale = max_val / 7.0; let (rows, cols) = data.dim();
1303 let mut quantized = Array2::zeros((rows, cols));
1304
1305 for i in 0..rows {
1306 for j in 0..cols {
1307 let quantized_val = (data[[i, j]] / scale).round() as i8;
1308 quantized[[i, j]] = quantized_val.clamp(-7, 7);
1309 }
1310 }
1311
1312 Ok((scale, quantized))
1313 }
1314
    /// Returns a copy of `data` with every element clamped to `[-7, 7]`.
    ///
    /// The element type stays `i8`; values already inside the range are
    /// unchanged.
    fn int4_to_int8(data: &Array2<i8>) -> Array2<i8> {
        data.mapv(|x| x.clamp(-7, 7))
    }
1320
1321 fn compute_norms_fp16(data: &Array2<f32>) -> SpatialResult<Array1<f32>> {
1323 let norms = data
1324 .outer_iter()
1325 .map(|row| row.iter().map(|&x| x * x).sum())
1326 .collect();
1327 Ok(norms)
1328 }
1329
    /// Row norms for the bf16 path.
    ///
    /// Delegates directly to `compute_norms_fp16`, i.e. returns the squared
    /// norm of every row accumulated in `f32`.
    fn compute_norms_bf16(&mut self, data: &Array2<f32>) -> SpatialResult<Array1<f32>> {
        TensorCoreDistanceMatrix::compute_norms_fp16(data)
    }
1334
1335 fn analyze_data_range(
1337 &self,
1338 points_i: &ArrayView2<'_, f64>,
1339 points_j: &ArrayView2<'_, f64>,
1340 ) -> f64 {
1341 let min_i = points_i.fold(f64::INFINITY, |acc, &x| acc.min(x));
1342 let max_i = points_i.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
1343 let min_j = points_j.fold(f64::INFINITY, |acc, &x| acc.min(x));
1344 let max_j = points_j.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
1345
1346 let overall_min = min_i.min(min_j);
1347 let overall_max = max_i.max(max_j);
1348
1349 overall_max - overall_min
1350 }
1351
1352 fn estimate_condition_number(
1354 &self,
1355 points_i: &ArrayView2<'_, f64>,
1356 points_j: &ArrayView2<'_, f64>,
1357 ) -> f64 {
1358 let data_range = self.analyze_data_range(points_i, points_j);
1360 let mean_i: f64 = points_i.sum() / (points_i.len() as f64);
1361 let mean_j: f64 = points_j.sum() / (points_j.len() as f64);
1362 let overall_mean = (mean_i + mean_j) / 2.0;
1363
1364 if overall_mean.abs() < 1e-10 {
1365 1e6 } else {
1367 data_range / overall_mean.abs()
1368 }
1369 }
1370}
1371
/// Clustering driver configured through `new` and the `with_*` builders and
/// run with `fit`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct TensorCoreClustering {
    /// Number of clusters to produce.
    _numclusters: usize,
    /// Precision mode handed to `TensorCoreDistanceMatrix` during `fit`.
    precision_mode: PrecisionMode,
    /// When `true`, `fit` uses the tensor-core distance and centroid routines
    /// instead of the fallback ones.
    tensor_cores: bool,
    /// Mixed-precision flag (not read by the methods visible here).
    mixed_precision: bool,
    /// Dynamic-precision flag (not read by the methods visible here).
    dynamic_precision: bool,
    /// Detected hardware description, if detection succeeded.
    capabilities: Option<TensorCoreCapabilities>,
}
1389
1390impl TensorCoreClustering {
1391 pub fn new(_numclusters: usize) -> SpatialResult<Self> {
1393 let capabilities = detect_tensor_core_capabilities().ok();
1394
1395 Ok(Self {
1396 _numclusters,
1397 precision_mode: PrecisionMode::Mixed16,
1398 tensor_cores: true,
1399 mixed_precision: true,
1400 dynamic_precision: false,
1401 capabilities,
1402 })
1403 }
1404
1405 pub fn with_tensor_cores(mut self, enabled: bool) -> Self {
1407 self.tensor_cores = enabled;
1408 self
1409 }
1410
1411 pub fn with_mixed_precision(mut self, enabled: bool) -> Self {
1413 self.mixed_precision = enabled;
1414 self
1415 }
1416
1417 pub fn with_dynamic_precision_scaling(mut self, enabled: bool) -> Self {
1419 self.dynamic_precision = enabled;
1420 self
1421 }
1422
1423 pub async fn fit(
1425 &mut self,
1426 points: &ArrayView2<'_, f64>,
1427 ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
1428 let (npoints, ndims) = points.dim();
1429
1430 if npoints < self._numclusters {
1431 return Err(SpatialError::InvalidInput(
1432 "Number of points must be >= number of clusters".to_string(),
1433 ));
1434 }
1435
1436 let mut centroids = self.initialize_centroids(points)?;
1438 let mut assignments = Array1::zeros(npoints);
1439
1440 for _iteration in 0..100 {
1442 let distance_matrix = if self.tensor_cores {
1444 let tensor_computer =
1445 TensorCoreDistanceMatrix::new()?.with_precision_mode(self.precision_mode);
1446 tensor_computer
1447 .compute_distances_to_centroids(points, ¢roids.view())
1448 .await?
1449 } else {
1450 self.compute_distances_fallback(points, ¢roids.view())?
1451 };
1452
1453 let new_assignments = self.update_assignments(&distance_matrix)?;
1455
1456 let new_centroids = if self.tensor_cores {
1458 self.update_centroids_tensor_cores(points, &new_assignments)
1459 .await?
1460 } else {
1461 self.update_centroids_fallback(points, &new_assignments)?
1462 };
1463
1464 let centroid_change = self.compute_centroid_change(¢roids, &new_centroids);
1466 if centroid_change < 1e-6 {
1467 break;
1468 }
1469
1470 centroids = new_centroids;
1471 assignments = new_assignments;
1472 }
1473
1474 Ok((centroids, assignments))
1475 }
1476
1477 fn initialize_centroids(&mut self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
1479 let (npoints, ndims) = points.dim();
1480 let mut centroids = Array2::zeros((self._numclusters, ndims));
1481
1482 let mut rng = scirs2_core::random::rng();
1484
1485 let first_idx = rng.gen_range(0..npoints);
1487 centroids.row_mut(0).assign(&points.row(first_idx));
1488
1489 for k in 1..self._numclusters {
1491 let mut distances = Array1::zeros(npoints);
1492
1493 for i in 0..npoints {
1494 let point = points.row(i);
1495 let mut min_dist = f64::INFINITY;
1496
1497 for j in 0..k {
1498 let centroid = centroids.row(j);
1499 let dist: f64 = point
1500 .iter()
1501 .zip(centroid.iter())
1502 .map(|(&a, &b)| (a - b).powi(2))
1503 .sum::<f64>();
1504 min_dist = min_dist.min(dist);
1505 }
1506
1507 distances[i] = min_dist;
1508 }
1509
1510 let total_dist: f64 = distances.sum();
1512 let mut cumulative = 0.0;
1513 let random_val = scirs2_core::random::random::<f64>() * total_dist;
1514
1515 for i in 0..npoints {
1516 cumulative += distances[i];
1517 if cumulative >= random_val {
1518 centroids.row_mut(k).assign(&points.row(i));
1519 break;
1520 }
1521 }
1522 }
1523
1524 Ok(centroids)
1525 }
1526
1527 fn update_assignments(
1529 &mut self,
1530 distance_matrix: &Array2<f64>,
1531 ) -> SpatialResult<Array1<usize>> {
1532 let npoints = distance_matrix.nrows();
1533 let mut assignments = Array1::zeros(npoints);
1534
1535 for i in 0..npoints {
1536 let mut min_dist = f64::INFINITY;
1537 let mut best_cluster = 0;
1538
1539 for j in 0..self._numclusters {
1540 if distance_matrix[[i, j]] < min_dist {
1541 min_dist = distance_matrix[[i, j]];
1542 best_cluster = j;
1543 }
1544 }
1545
1546 assignments[i] = best_cluster;
1547 }
1548
1549 Ok(assignments)
1550 }
1551
1552 async fn update_centroids_tensor_cores(
1554 &self,
1555 points: &ArrayView2<'_, f64>,
1556 assignments: &Array1<usize>,
1557 ) -> SpatialResult<Array2<f64>> {
1558 let (_npoints, ndims) = points.dim();
1559 let mut new_centroids = Array2::zeros((self._numclusters, ndims));
1560 let mut cluster_counts = vec![0; self._numclusters];
1561
1562 for &cluster in assignments {
1564 cluster_counts[cluster] += 1;
1565 }
1566
1567 for cluster in 0..self._numclusters {
1569 if cluster_counts[cluster] == 0 {
1570 continue;
1571 }
1572
1573 let clusterpoints: Vec<usize> = assignments
1575 .iter()
1576 .enumerate()
1577 .filter(|(_, &c)| c == cluster)
1578 .map(|(i, _)| i)
1579 .collect();
1580
1581 let cluster_data = Array2::from_shape_fn((clusterpoints.len(), ndims), |(i, j)| {
1583 points[[clusterpoints[i], j]]
1584 });
1585
1586 let sum_vector = self.tensor_sum_reduction(&cluster_data.view()).await?;
1588 let count = clusterpoints.len() as f64;
1589
1590 for j in 0..ndims {
1591 new_centroids[[cluster, j]] = sum_vector[j] / count;
1592 }
1593 }
1594
1595 Ok(new_centroids)
1596 }
1597
1598 async fn tensor_sum_reduction(&self, data: &ArrayView2<'_, f64>) -> SpatialResult<Array1<f64>> {
1600 let (_npoints, ndims) = data.dim();
1601 let mut sum_vector = Array1::zeros(ndims);
1602
1603 for j in 0..ndims {
1605 let column_sum: f64 = data.column(j).sum();
1606 sum_vector[j] = column_sum;
1607 }
1608
1609 Ok(sum_vector)
1610 }
1611
1612 fn compute_distances_fallback(
1614 &self,
1615 points: &ArrayView2<'_, f64>,
1616 centroids: &ArrayView2<'_, f64>,
1617 ) -> SpatialResult<Array2<f64>> {
1618 let (npoints, ndims) = points.dim();
1619 let (n_clusters_, _) = centroids.dim();
1620 let mut distances = Array2::zeros((npoints, n_clusters_));
1621
1622 let cluster_chunks = n_clusters_ / 4;
1624
1625 for i in 0..npoints {
1626 let point_row = points.row(i);
1627
1628 for j_chunk in 0..cluster_chunks {
1630 let j_base = j_chunk * 4;
1631
1632 let j0 = j_base;
1634 let j1 = j_base + 1;
1635 let j2 = j_base + 2;
1636 let j3 = j_base + 3;
1637
1638 let centroid_row0 = centroids.row(j0);
1639 let centroid_row1 = centroids.row(j1);
1640 let centroid_row2 = centroids.row(j2);
1641 let centroid_row3 = centroids.row(j3);
1642
1643 let distance0: f64 = point_row
1644 .iter()
1645 .zip(centroid_row0.iter())
1646 .map(|(&a, &b)| (a - b).powi(2))
1647 .sum::<f64>()
1648 .sqrt();
1649
1650 let distance1: f64 = point_row
1651 .iter()
1652 .zip(centroid_row1.iter())
1653 .map(|(&a, &b)| (a - b).powi(2))
1654 .sum::<f64>()
1655 .sqrt();
1656
1657 let distance2: f64 = point_row
1658 .iter()
1659 .zip(centroid_row2.iter())
1660 .map(|(&a, &b)| (a - b).powi(2))
1661 .sum::<f64>()
1662 .sqrt();
1663
1664 let distance3: f64 = point_row
1665 .iter()
1666 .zip(centroid_row3.iter())
1667 .map(|(&a, &b)| (a - b).powi(2))
1668 .sum::<f64>()
1669 .sqrt();
1670
1671 distances[[i, j0]] = distance0;
1672 distances[[i, j1]] = distance1;
1673 distances[[i, j2]] = distance2;
1674 distances[[i, j3]] = distance3;
1675 }
1676
1677 for j in (cluster_chunks * 4)..n_clusters_ {
1679 let distance: f64 = point_row
1680 .iter()
1681 .zip(centroids.row(j).iter())
1682 .map(|(&a, &b)| (a - b).powi(2))
1683 .sum::<f64>()
1684 .sqrt();
1685 distances[[i, j]] = distance;
1686 }
1687 }
1688
1689 Ok(distances)
1690 }
1691
1692 fn update_centroids_fallback(
1694 &self,
1695 points: &ArrayView2<'_, f64>,
1696 assignments: &Array1<usize>,
1697 ) -> SpatialResult<Array2<f64>> {
1698 let (npoints, ndims) = points.dim();
1699 let mut new_centroids = Array2::zeros((self._numclusters, ndims));
1700 let mut cluster_counts = vec![0; self._numclusters];
1701
1702 for i in 0..npoints {
1704 let cluster = assignments[i];
1705 cluster_counts[cluster] += 1;
1706
1707 for j in 0..ndims {
1708 new_centroids[[cluster, j]] += points[[i, j]];
1709 }
1710 }
1711
1712 for cluster in 0..self._numclusters {
1714 if cluster_counts[cluster] > 0 {
1715 let count = cluster_counts[cluster] as f64;
1716 for j in 0..ndims {
1717 new_centroids[[cluster, j]] /= count;
1718 }
1719 }
1720 }
1721
1722 Ok(new_centroids)
1723 }
1724
1725 fn compute_centroid_change(
1727 &self,
1728 old_centroids: &Array2<f64>,
1729 new_centroids: &Array2<f64>,
1730 ) -> f64 {
1731 let mut total_change = 0.0;
1732
1733 for i in 0..self._numclusters {
1734 let change: f64 = old_centroids
1735 .row(i)
1736 .iter()
1737 .zip(new_centroids.row(i).iter())
1738 .map(|(&a, &b)| (a - b).powi(2))
1739 .sum::<f64>()
1740 .sqrt();
1741 total_change += change;
1742 }
1743
1744 total_change / (self._numclusters as f64)
1745 }
1746}
1747
1748impl Default for StabilityMetrics {
1749 fn default() -> Self {
1750 Self::new()
1751 }
1752}
1753
1754impl StabilityMetrics {
1755 pub fn new() -> Self {
1757 Self {
1758 condition_number: 1.0,
1759 relative_error: 0.0,
1760 forward_error: 0.0,
1761 backward_error: 0.0,
1762 digit_loss: 0.0,
1763 stability_level: StabilityLevel::Excellent,
1764 error_types: Vec::new(),
1765 timestamp: Instant::now(),
1766 }
1767 }
1768
1769 pub fn update_stability_level(&mut self) {
1771 self.stability_level = if self.condition_number > 1e12 || self.relative_error > 1e-3 {
1772 StabilityLevel::Critical
1773 } else if self.condition_number > 1e8 || self.relative_error > 1e-6 {
1774 StabilityLevel::Poor
1775 } else if self.condition_number > 1e4 || self.relative_error > 1e-9 {
1776 StabilityLevel::Moderate
1777 } else if self.condition_number > 1e2 || self.relative_error > 1e-12 {
1778 StabilityLevel::Good
1779 } else {
1780 StabilityLevel::Excellent
1781 };
1782 }
1783
1784 pub fn detect_errors(&mut self, data: &Array2<f64>) {
1786 self.error_types.clear();
1787
1788 for &value in data.iter() {
1790 if !value.is_finite() {
1791 self.error_types.push(NumericalErrorType::InvalidValues);
1792 break;
1793 }
1794 }
1795
1796 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1798 if max_val > 1e100 {
1799 self.error_types.push(NumericalErrorType::Overflow);
1800 } else if max_val < 1e-100 && max_val > 0.0 {
1801 self.error_types.push(NumericalErrorType::Underflow);
1802 }
1803
1804 if self.digit_loss > 6.0 {
1806 self.error_types.push(NumericalErrorType::PrecisionLoss);
1807 }
1808
1809 if self.condition_number > 1e12 {
1811 self.error_types.push(NumericalErrorType::IllConditioned);
1812 }
1813 }
1814}
1815
1816impl Default for DynamicPrecisionConfig {
1817 fn default() -> Self {
1818 Self {
1819 strategy: ScalingStrategy::Balanced,
1820 min_precision: PrecisionMode::Int8Dynamic,
1821 max_precision: PrecisionMode::Full32,
1822 stability_threshold_up: 1e-6,
1823 stability_threshold_down: 1e-9,
1824 performance_weight: 0.6,
1825 accuracy_weight: 0.4,
1826 max_changes_per_operation: 3,
1827 change_cooldown: Duration::from_millis(100),
1828 }
1829 }
1830}
1831
1832impl NumericalStabilityMonitor {
1833 pub fn new(config: DynamicPrecisionConfig) -> Self {
1835 Self {
1836 current_metrics: StabilityMetrics::new(),
1837 stability_history: VecDeque::new(),
1838 precision_config: config,
1839 current_precision: PrecisionMode::Mixed16,
1840 precision_history: VecDeque::new(),
1841 recovery_attempts: 0,
1842 max_history_length: 1000,
1843 last_precision_change: None,
1844 }
1845 }
1846
1847 pub fn monitor_stability(
1849 &mut self,
1850 data: &Array2<f64>,
1851 computation_result: &Array2<f64>,
1852 ) -> SpatialResult<()> {
1853 self.current_metrics.condition_number =
1855 NumericalStabilityMonitor::estimate_condition_number(data);
1856
1857 self.current_metrics.relative_error =
1859 self.estimate_relative_error(data, computation_result);
1860
1861 self.current_metrics.forward_error = self.estimate_forward_error(data, computation_result);
1863 self.current_metrics.backward_error =
1864 self.estimate_backward_error(data, computation_result);
1865
1866 self.current_metrics.digit_loss = self.estimate_digit_loss();
1868
1869 self.current_metrics.update_stability_level();
1871
1872 self.current_metrics.detect_errors(computation_result);
1874
1875 self.current_metrics.timestamp = Instant::now();
1877
1878 self.stability_history
1880 .push_back(self.current_metrics.clone());
1881 if self.stability_history.len() > self.max_history_length {
1882 self.stability_history.pop_front();
1883 }
1884
1885 Ok(())
1886 }
1887
1888 pub fn adjust_precision(&mut self) -> SpatialResult<PrecisionMode> {
1890 if let Some(last_change) = self.last_precision_change {
1892 if last_change.elapsed() < self.precision_config.change_cooldown {
1893 return Ok(self.current_precision);
1894 }
1895 }
1896
1897 let new_precision = match self.current_metrics.stability_level {
1898 StabilityLevel::Critical => {
1899 self.precision_config.max_precision
1901 }
1902 StabilityLevel::Poor => {
1903 NumericalStabilityMonitor::increase_precision(self.current_precision)
1905 }
1906 StabilityLevel::Moderate => {
1907 if self.current_metrics.relative_error
1909 > self.precision_config.stability_threshold_up
1910 {
1911 NumericalStabilityMonitor::increase_precision(self.current_precision)
1912 } else {
1913 self.current_precision
1914 }
1915 }
1916 StabilityLevel::Good => {
1917 if self.current_metrics.relative_error
1919 < self.precision_config.stability_threshold_down
1920 {
1921 NumericalStabilityMonitor::decrease_precision(self.current_precision)
1922 } else {
1923 self.current_precision
1924 }
1925 }
1926 StabilityLevel::Excellent => {
1927 if self.precision_config.strategy == ScalingStrategy::Aggressive {
1929 self.precision_config.min_precision
1930 } else {
1931 NumericalStabilityMonitor::decrease_precision(self.current_precision)
1932 }
1933 }
1934 };
1935
1936 if new_precision != self.current_precision {
1938 self.precision_history.push_back((
1939 Instant::now(),
1940 new_precision,
1941 self.current_metrics.relative_error,
1942 ));
1943 self.current_precision = new_precision;
1944 self.last_precision_change = Some(Instant::now());
1945 }
1946
1947 Ok(new_precision)
1948 }
1949
1950 fn increase_precision(current: PrecisionMode) -> PrecisionMode {
1952 match current {
1953 PrecisionMode::Int4Advanced => PrecisionMode::Int8Dynamic,
1954 PrecisionMode::Int8Dynamic => PrecisionMode::Mixed16,
1955 PrecisionMode::Mixed16 => PrecisionMode::BrainFloat16,
1956 PrecisionMode::BrainFloat16 => PrecisionMode::Full32,
1957 PrecisionMode::Full32 => PrecisionMode::Full32, _ => PrecisionMode::Mixed16,
1959 }
1960 }
1961
1962 fn decrease_precision(current: PrecisionMode) -> PrecisionMode {
1964 match current {
1965 PrecisionMode::Full32 => PrecisionMode::BrainFloat16,
1966 PrecisionMode::BrainFloat16 => PrecisionMode::Mixed16,
1967 PrecisionMode::Mixed16 => PrecisionMode::Int8Dynamic,
1968 PrecisionMode::Int8Dynamic => PrecisionMode::Int4Advanced,
1969 PrecisionMode::Int4Advanced => PrecisionMode::Int4Advanced, _ => PrecisionMode::Mixed16,
1971 }
1972 }
1973
1974 fn estimate_condition_number(data: &Array2<f64>) -> f64 {
1976 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1978 let min_val = data.fold(f64::INFINITY, |acc, &x| {
1979 if x.abs() > 1e-15 {
1980 acc.min(x.abs())
1981 } else {
1982 acc
1983 }
1984 });
1985
1986 if min_val.is_finite() && min_val > 0.0 {
1987 max_val / min_val
1988 } else {
1989 1e12 }
1991 }
1992
1993 fn estimate_relative_error(&mut self, input: &Array2<f64>, output: &Array2<f64>) -> f64 {
1995 let mean_val = output.mean().unwrap_or(0.0);
1997 if mean_val.abs() > 1e-15 {
1998 let machine_eps = match self.current_precision {
2000 PrecisionMode::Full32 => 2.22e-16,
2001 PrecisionMode::Mixed16 | PrecisionMode::BrainFloat16 => 9.77e-4,
2002 PrecisionMode::Int8Dynamic => 1.0 / 256.0,
2003 PrecisionMode::Int4Advanced => 1.0 / 16.0,
2004 _ => 1e-6,
2005 };
2006 machine_eps * self.current_metrics.condition_number
2007 } else {
2008 0.0
2009 }
2010 }
2011
2012 fn estimate_forward_error(&mut self, _input: &Array2<f64>, output: &Array2<f64>) -> f64 {
2014 self.current_metrics.relative_error * self.current_metrics.condition_number
2016 }
2017
2018 fn estimate_backward_error(&mut self, _input: &Array2<f64>, output: &Array2<f64>) -> f64 {
2020 self.current_metrics.relative_error
2022 }
2023
2024 fn estimate_digit_loss(&self) -> f64 {
2026 if self.current_metrics.condition_number > 1.0 {
2027 self.current_metrics.condition_number.log10().max(0.0)
2028 } else {
2029 0.0
2030 }
2031 }
2032}
2033
2034impl Default for ErrorRecoverySystem {
2035 fn default() -> Self {
2036 Self::new()
2037 }
2038}
2039
2040impl ErrorRecoverySystem {
2041 pub fn new() -> Self {
2043 let mut recovery_strategies = HashMap::new();
2044
2045 recovery_strategies.insert(
2047 NumericalErrorType::Overflow,
2048 vec![
2049 RecoveryAction::IncreasePrecision,
2050 RecoveryAction::ReduceTileSize,
2051 RecoveryAction::NumericalStabilization,
2052 ],
2053 );
2054 recovery_strategies.insert(
2055 NumericalErrorType::Underflow,
2056 vec![
2057 RecoveryAction::IncreasePrecision,
2058 RecoveryAction::NumericalStabilization,
2059 ],
2060 );
2061 recovery_strategies.insert(
2062 NumericalErrorType::PrecisionLoss,
2063 vec![
2064 RecoveryAction::IncreasePrecision,
2065 RecoveryAction::RetryWithNewParams,
2066 ],
2067 );
2068 recovery_strategies.insert(
2069 NumericalErrorType::IllConditioned,
2070 vec![
2071 RecoveryAction::IncreasePrecision,
2072 RecoveryAction::NumericalStabilization,
2073 RecoveryAction::SwitchToCPU,
2074 ],
2075 );
2076 recovery_strategies.insert(
2077 NumericalErrorType::InvalidValues,
2078 vec![
2079 RecoveryAction::FallbackAlgorithm,
2080 RecoveryAction::SwitchToCPU,
2081 ],
2082 );
2083
2084 Self {
2085 recovery_strategies,
2086 recovery_history: VecDeque::new(),
2087 max_recovery_attempts: 3,
2088 success_rates: HashMap::new(),
2089 }
2090 }
2091
2092 pub async fn attempt_recovery(
2094 &mut self,
2095 error_type: NumericalErrorType,
2096 ) -> SpatialResult<RecoveryAction> {
2097 let start_time = Instant::now();
2098
2099 let strategies = self
2101 .recovery_strategies
2102 .get(&error_type)
2103 .ok_or_else(|| SpatialError::InvalidInput("Unknown error _type".to_string()))?
2104 .clone(); let best_action = self.choose_best_recovery_action(&strategies);
2108
2109 let attempt = RecoveryAttempt {
2111 error_type,
2112 action: best_action,
2113 success: false, duration: start_time.elapsed(),
2115 post_recovery_metrics: None,
2116 timestamp: start_time,
2117 };
2118
2119 self.recovery_history.push_back(attempt);
2120
2121 Ok(best_action)
2122 }
2123
2124 fn choose_best_recovery_action(&mut self, strategies: &[RecoveryAction]) -> RecoveryAction {
2126 strategies
2127 .iter()
2128 .max_by(|&a, &b| {
2129 let rate_a = self.success_rates.get(a).unwrap_or(&0.5);
2130 let rate_b = self.success_rates.get(b).unwrap_or(&0.5);
2131 rate_a
2132 .partial_cmp(rate_b)
2133 .unwrap_or(std::cmp::Ordering::Equal)
2134 })
2135 .copied()
2136 .unwrap_or(RecoveryAction::IncreasePrecision)
2137 }
2138
2139 pub fn update_success_rate(&mut self, action: RecoveryAction, success: bool) {
2141 let current_rate = self.success_rates.get(&action).unwrap_or(&0.5);
2142 let new_rate = if success {
2143 current_rate * 0.9 + 0.1 } else {
2145 current_rate * 0.9
2146 };
2147 self.success_rates.insert(action, new_rate);
2148 }
2149}
2150
2151impl PerformanceAccuracyAnalyzer {
2152 pub fn new(params: TradeOffParams) -> Self {
2154 Self {
2155 performance_data: HashMap::new(),
2156 accuracy_data: HashMap::new(),
2157 optimization_params: params,
2158 pareto_frontier: Vec::new(),
2159 }
2160 }
2161
2162 pub fn record_performance(&mut self, precision: PrecisionMode, duration: Duration) {
2164 self.performance_data
2165 .entry(precision)
2166 .or_default()
2167 .push_back(duration);
2168
2169 if let Some(history) = self.performance_data.get_mut(&precision) {
2171 if history.len() > 100 {
2172 history.pop_front();
2173 }
2174 }
2175 }
2176
2177 pub fn record_accuracy(&mut self, precision: PrecisionMode, accuracy: f64) {
2179 self.accuracy_data
2180 .entry(precision)
2181 .or_default()
2182 .push_back(accuracy);
2183
2184 if let Some(history) = self.accuracy_data.get_mut(&precision) {
2186 if history.len() > 100 {
2187 history.pop_front();
2188 }
2189 }
2190 }
2191
2192 pub fn optimize_precision(&mut self) -> PrecisionMode {
2194 self.update_pareto_frontier();
2195
2196 match self.optimization_params.objective {
2197 OptimizationObjective::MaxPerformance => self
2198 .pareto_frontier
2199 .iter()
2200 .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
2201 .map(|(_a, b, mode)| *mode)
2202 .unwrap_or(PrecisionMode::Mixed16),
2203 OptimizationObjective::MaxAccuracy => self
2204 .pareto_frontier
2205 .iter()
2206 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
2207 .map(|(_a, b, mode)| *mode)
2208 .unwrap_or(PrecisionMode::Full32),
2209 OptimizationObjective::Balanced => {
2210 let mut best_score = f64::NEG_INFINITY;
2212 let mut best_mode = PrecisionMode::Mixed16;
2213
2214 let performance_weight = self.optimization_params.performance_weight;
2216 let accuracy_weight = self.optimization_params.accuracy_weight;
2217
2218 for &(perf, acc, mode) in &self.pareto_frontier {
2219 let perf_score = 1.0 / (perf + 1e-9);
2221 let score = performance_weight * perf_score + accuracy_weight * acc;
2222 if score > best_score {
2223 best_score = score;
2224 best_mode = mode;
2225 }
2226 }
2227
2228 best_mode
2229 }
2230 _ => PrecisionMode::Mixed16,
2231 }
2232 }
2233
2234 fn update_pareto_frontier(&mut self) {
2236 self.pareto_frontier.clear();
2237
2238 for precision in [
2239 PrecisionMode::Full32,
2240 PrecisionMode::BrainFloat16,
2241 PrecisionMode::Mixed16,
2242 PrecisionMode::Int8Dynamic,
2243 PrecisionMode::Int4Advanced,
2244 ] {
2245 if let (Some(perf_data), Some(acc_data)) = (
2246 self.performance_data.get(&precision),
2247 self.accuracy_data.get(&precision),
2248 ) {
2249 if !perf_data.is_empty() && !acc_data.is_empty() {
2250 let avg_perf = perf_data.iter().map(|d| d.as_secs_f64()).sum::<f64>()
2251 / perf_data.len() as f64;
2252 let avg_acc = acc_data.iter().sum::<f64>() / acc_data.len() as f64;
2253
2254 self.pareto_frontier.push((avg_perf, avg_acc, precision));
2255 }
2256 }
2257 }
2258 }
2259
2260 #[allow(dead_code)]
2262 fn compute_weighted_score(&mut self, performance: f64, accuracy: f64) -> f64 {
2263 let perf_score = 1.0 / (performance + 1e-9);
2265
2266 self.optimization_params.performance_weight * perf_score
2268 + self.optimization_params.accuracy_weight * accuracy
2269 }
2270}
2271
2272impl AdvancedTensorCoreDistanceMatrix {
2273 pub fn new() -> SpatialResult<Self> {
2275 let base_computer = TensorCoreDistanceMatrix::new()?;
2276 let precision_config = DynamicPrecisionConfig::default();
2277 let stability_monitor =
2278 Arc::new(Mutex::new(NumericalStabilityMonitor::new(precision_config)));
2279 let recovery_system = ErrorRecoverySystem::new();
2280 let trade_off_params = TradeOffParams {
2281 performance_weight: 0.6,
2282 accuracy_weight: 0.4,
2283 energy_weight: 0.0,
2284 min_accuracy: 0.95,
2285 max_time: Duration::from_secs(30),
2286 objective: OptimizationObjective::Balanced,
2287 };
2288 let performance_analyzer = PerformanceAccuracyAnalyzer::new(trade_off_params);
2289
2290 Ok(Self {
2291 base_computer,
2292 stability_monitor,
2293 recovery_system,
2294 performance_analyzer,
2295 dynamic_precision_enabled: true,
2296 auto_recovery_enabled: true,
2297 })
2298 }
2299
2300 pub fn with_dynamic_precision(mut self, enabled: bool) -> Self {
2302 self.dynamic_precision_enabled = enabled;
2303 self
2304 }
2305
2306 pub fn with_auto_recovery(mut self, enabled: bool) -> Self {
2308 self.auto_recovery_enabled = enabled;
2309 self
2310 }
2311
2312 pub async fn compute_with_stability_monitoring(
2314 &mut self,
2315 points: &ArrayView2<'_, f64>,
2316 ) -> SpatialResult<Array2<f64>> {
2317 let start_time = Instant::now();
2318
2319 {
2321 let mut monitor = self.stability_monitor.lock().unwrap();
2322 if self.dynamic_precision_enabled {
2325 let optimal_precision = monitor.adjust_precision()?;
2326 self.base_computer.precision_mode = optimal_precision;
2327 }
2328 }
2329
2330 let mut result = None;
2331 let mut recovery_attempts = 0;
2332 let max_attempts = 3;
2333
2334 while result.is_none() && recovery_attempts < max_attempts {
2335 match self.base_computer.compute_parallel(points).await {
2336 Ok(distances) => {
2337 {
2339 let mut monitor = self.stability_monitor.lock().unwrap();
2340 monitor.monitor_stability(&points.to_owned(), &distances)?;
2341 }
2342
2343 let stability_level = {
2345 let monitor = self.stability_monitor.lock().unwrap();
2346 monitor.current_metrics.stability_level
2347 };
2348
2349 if stability_level == StabilityLevel::Critical && self.auto_recovery_enabled {
2350 recovery_attempts += 1;
2352 let recovery_action = self
2353 .recovery_system
2354 .attempt_recovery(NumericalErrorType::IllConditioned)
2355 .await?;
2356
2357 self.apply_recovery_action(recovery_action).await?;
2359 continue;
2360 } else {
2361 result = Some(distances);
2362 }
2363 }
2364 Err(e) => {
2365 if self.auto_recovery_enabled && recovery_attempts < max_attempts {
2366 recovery_attempts += 1;
2367 let recovery_action = self
2368 .recovery_system
2369 .attempt_recovery(NumericalErrorType::InvalidValues)
2370 .await?;
2371 self.apply_recovery_action(recovery_action).await?;
2372 continue;
2373 } else {
2374 return Err(e);
2375 }
2376 }
2377 }
2378 }
2379
2380 let final_result = result.ok_or_else(|| {
2381 SpatialError::InvalidInput(
2382 "Failed to compute stable result after recovery attempts".to_string(),
2383 )
2384 })?;
2385
2386 let duration = start_time.elapsed();
2388 let precision = self.base_computer.precision_mode;
2389 self.performance_analyzer
2390 .record_performance(precision, duration);
2391
2392 let accuracy = self.estimate_result_accuracy(&final_result);
2394 self.performance_analyzer
2395 .record_accuracy(precision, accuracy);
2396
2397 Ok(final_result)
2398 }
2399
2400 async fn apply_recovery_action(&mut self, action: RecoveryAction) -> SpatialResult<()> {
2402 match action {
2403 RecoveryAction::IncreasePrecision => {
2404 self.base_computer.precision_mode = match self.base_computer.precision_mode {
2405 PrecisionMode::Int4Advanced => PrecisionMode::Int8Dynamic,
2406 PrecisionMode::Int8Dynamic => PrecisionMode::Mixed16,
2407 PrecisionMode::Mixed16 => PrecisionMode::BrainFloat16,
2408 PrecisionMode::BrainFloat16 => PrecisionMode::Full32,
2409 PrecisionMode::Full32 => PrecisionMode::Full32,
2410 _ => PrecisionMode::Mixed16,
2411 };
2412 }
2413 RecoveryAction::ReduceTileSize => {
2414 let (current_row, current_col) = self.base_computer.tile_size;
2415 self.base_computer.tile_size = (current_row / 2, current_col / 2);
2416 if self.base_computer.tile_size.0 < 16 {
2417 self.base_computer.tile_size = (16, 16);
2418 }
2419 }
2420 RecoveryAction::FallbackAlgorithm => {
2421 self.base_computer.precision_mode = PrecisionMode::Full32;
2423 self.base_computer.hierarchical_tiling = false;
2424 }
2425 RecoveryAction::NumericalStabilization => {
2426 self.base_computer.precision_mode = PrecisionMode::Full32;
2428 self.base_computer.tile_size = (64, 64);
2429 }
2430 _ => {
2431 self.base_computer.precision_mode = PrecisionMode::Full32;
2433 }
2434 }
2435
2436 Ok(())
2437 }
2438
2439 fn estimate_result_accuracy(&self, result: &Array2<f64>) -> f64 {
2441 let has_invalid = result.iter().any(|&x| !x.is_finite());
2443 if has_invalid {
2444 return 0.0;
2445 }
2446
2447 let max_val = result.fold(0.0f64, |acc, &x| acc.max(x.abs()));
2448 let min_val = result.fold(f64::INFINITY, |acc, &x| {
2449 if x.abs() > 1e-15 {
2450 acc.min(x.abs())
2451 } else {
2452 acc
2453 }
2454 });
2455
2456 if min_val.is_finite() && min_val > 0.0 {
2457 let dynamic_range = max_val / min_val;
2458 (1.0 / (1.0 + dynamic_range.log10() / 10.0)).clamp(0.0, 1.0)
2459 } else {
2460 0.95 }
2462 }
2463}
2464
2465#[allow(dead_code)]
2467pub fn detect_tensor_core_capabilities() -> SpatialResult<TensorCoreCapabilities> {
2468 Ok(TensorCoreCapabilities {
2472 tensor_core_types: vec![
2473 TensorCoreType::NvidiaTensorCore,
2474 TensorCoreType::StandardCores,
2475 ],
2476 supported_precisions: vec![
2477 PrecisionMode::Full32,
2478 PrecisionMode::Mixed16,
2479 PrecisionMode::BrainFloat16,
2480 PrecisionMode::Int8Dynamic,
2481 ],
2482 max_tensor_size: (4096, 4096, 4096),
2483 peak_throughput_tops: 312.0, memory_bandwidth_gbps: 1555.0, l2_cache_mb: 40.0,
2486 num_sms: 108,
2487 architecture: GpuArchitecture::Ampere,
2488 })
2489}
2490
2491impl TensorCoreDistanceMatrix {
2493 pub async fn compute_distances_to_centroids(
2495 &self,
2496 points: &ArrayView2<'_, f64>,
2497 centroids: &ArrayView2<'_, f64>,
2498 ) -> SpatialResult<Array2<f64>> {
2499 let (npoints, ndims) = points.dim();
2500 let (n_clusters_, n_dims_c) = centroids.dim();
2501 let mut distances = Array2::zeros((npoints, n_clusters_));
2502
2503 let cluster_chunks = n_clusters_ / 4;
2505
2506 for i in 0..npoints {
2507 let point_row = points.row(i);
2508
2509 for j_chunk in 0..cluster_chunks {
2511 let j_base = j_chunk * 4;
2512
2513 let j0 = j_base;
2515 let j1 = j_base + 1;
2516 let j2 = j_base + 2;
2517 let j3 = j_base + 3;
2518
2519 let centroid_row0 = centroids.row(j0);
2520 let centroid_row1 = centroids.row(j1);
2521 let centroid_row2 = centroids.row(j2);
2522 let centroid_row3 = centroids.row(j3);
2523
2524 let distance0: f64 = point_row
2525 .iter()
2526 .zip(centroid_row0.iter())
2527 .map(|(&a, &b)| (a - b).powi(2))
2528 .sum::<f64>()
2529 .sqrt();
2530
2531 let distance1: f64 = point_row
2532 .iter()
2533 .zip(centroid_row1.iter())
2534 .map(|(&a, &b)| (a - b).powi(2))
2535 .sum::<f64>()
2536 .sqrt();
2537
2538 let distance2: f64 = point_row
2539 .iter()
2540 .zip(centroid_row2.iter())
2541 .map(|(&a, &b)| (a - b).powi(2))
2542 .sum::<f64>()
2543 .sqrt();
2544
2545 let distance3: f64 = point_row
2546 .iter()
2547 .zip(centroid_row3.iter())
2548 .map(|(&a, &b)| (a - b).powi(2))
2549 .sum::<f64>()
2550 .sqrt();
2551
2552 distances[[i, j0]] = distance0;
2553 distances[[i, j1]] = distance1;
2554 distances[[i, j2]] = distance2;
2555 distances[[i, j3]] = distance3;
2556 }
2557
2558 for j in (cluster_chunks * 4)..n_clusters_ {
2560 let distance: f64 = point_row
2561 .iter()
2562 .zip(centroids.row(j).iter())
2563 .map(|(&a, &b)| (a - b).powi(2))
2564 .sum::<f64>()
2565 .sqrt();
2566 distances[[i, j]] = distance;
2567 }
2568 }
2569
2570 Ok(distances)
2571 }
2572}
2573
2574#[cfg(test)]
2575mod tests {
2576 use super::*;
2577 use scirs2_core::ndarray::array;
2578
2579 #[test]
2580 fn test_precision_mode() {
2581 assert_eq!(PrecisionMode::Mixed16, PrecisionMode::Mixed16);
2582 assert_ne!(PrecisionMode::Mixed16, PrecisionMode::Full32);
2583 }
2584
2585 #[test]
2586 fn test_tensor_core_capabilities() {
2587 let capabilities = detect_tensor_core_capabilities();
2588 assert!(capabilities.is_ok());
2589
2590 let caps = capabilities.unwrap();
2591 assert!(!caps.tensor_core_types.is_empty());
2592 assert!(!caps.supported_precisions.is_empty());
2593 }
2594
2595 #[test]
2596 fn test_tensor_core_distance_matrix_creation() {
2597 let result = TensorCoreDistanceMatrix::new();
2598 assert!(result.is_ok());
2599
2600 let matrix_computer = result.unwrap();
2601 assert_eq!(matrix_computer.precision_mode, PrecisionMode::Mixed16);
2602 }
2603
2604 #[test]
2605 fn test_tensor_core_clustering_creation() {
2606 let result = TensorCoreClustering::new(3);
2607 assert!(result.is_ok());
2608
2609 let clustering = result.unwrap();
2610 assert_eq!(clustering._numclusters, 3);
2611 }
2612
2613 #[tokio::test]
2614 async fn test_tensor_core_distance_computation() {
2615 let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
2616 let mut matrix_computer = TensorCoreDistanceMatrix::new().unwrap();
2617
2618 let result = matrix_computer.compute_parallel(&points.view()).await;
2619 assert!(result.is_ok());
2620
2621 let distances = result.unwrap();
2622 assert_eq!(distances.shape(), &[3, 3]);
2623
2624 for i in 0..3 {
2626 assert!((distances[[i, i]]).abs() < 1e-10);
2627 }
2628 }
2629
2630 #[tokio::test]
2631 async fn test_tensor_core_clustering() {
2632 let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
2633 let mut clustering = TensorCoreClustering::new(2).unwrap();
2634
2635 let result = clustering.fit(&points.view()).await;
2636 assert!(result.is_ok());
2637
2638 let (centroids, assignments) = result.unwrap();
2639 assert_eq!(centroids.shape(), &[2, 2]);
2640 assert_eq!(assignments.len(), 4);
2641 }
2642
2643 #[test]
2644 fn test_stability_metrics_creation() {
2645 let metrics = StabilityMetrics::new();
2646 assert_eq!(metrics.condition_number, 1.0);
2647 assert_eq!(metrics.relative_error, 0.0);
2648 assert_eq!(metrics.stability_level, StabilityLevel::Excellent);
2649 assert!(metrics.error_types.is_empty());
2650 }
2651
2652 #[test]
2653 fn test_stability_level_update() {
2654 let mut metrics = StabilityMetrics::new();
2655
2656 metrics.condition_number = 1e15;
2658 metrics.update_stability_level();
2659 assert_eq!(metrics.stability_level, StabilityLevel::Critical);
2660
2661 metrics.condition_number = 1e9;
2663 metrics.relative_error = 1e-7;
2664 metrics.update_stability_level();
2665 assert_eq!(metrics.stability_level, StabilityLevel::Poor);
2666
2667 metrics.condition_number = 1e3;
2669 metrics.relative_error = 1e-10;
2670 metrics.update_stability_level();
2671 assert_eq!(metrics.stability_level, StabilityLevel::Good);
2672 }
2673
2674 #[test]
2675 fn test_error_detection() {
2676 let mut metrics = StabilityMetrics::new();
2677
2678 let data_with_nan = array![[1.0, 2.0], [f64::NAN, 4.0]];
2680 metrics.detect_errors(&data_with_nan);
2681 assert!(metrics
2682 .error_types
2683 .contains(&NumericalErrorType::InvalidValues));
2684
2685 let data_with_overflow = array![[1e150, 2.0], [3.0, 4.0]];
2687 metrics.detect_errors(&data_with_overflow);
2688 assert!(metrics.error_types.contains(&NumericalErrorType::Overflow));
2689
2690 let data_with_underflow = array![[1e-150, 1e-120], [1e-130, 1e-140]];
2692 metrics.detect_errors(&data_with_underflow);
2693 assert!(metrics.error_types.contains(&NumericalErrorType::Underflow));
2694 }
2695
2696 #[test]
2697 fn test_dynamic_precision_config() {
2698 let config = DynamicPrecisionConfig::default();
2699 assert_eq!(config.strategy, ScalingStrategy::Balanced);
2700 assert_eq!(config.min_precision, PrecisionMode::Int8Dynamic);
2701 assert_eq!(config.max_precision, PrecisionMode::Full32);
2702 assert_eq!(config.performance_weight, 0.6);
2703 assert_eq!(config.accuracy_weight, 0.4);
2704 }
2705
2706 #[test]
2707 fn test_numerical_stability_monitor_creation() {
2708 let config = DynamicPrecisionConfig::default();
2709 let monitor = NumericalStabilityMonitor::new(config);
2710
2711 assert_eq!(monitor.current_precision, PrecisionMode::Mixed16);
2712 assert!(monitor.stability_history.is_empty());
2713 assert_eq!(monitor.recovery_attempts, 0);
2714 }
2715
2716 #[test]
2717 fn test_precision_increase_decrease() {
2718 let config = DynamicPrecisionConfig::default();
2719 let monitor = NumericalStabilityMonitor::new(config);
2720
2721 let increased = NumericalStabilityMonitor::increase_precision(PrecisionMode::Int8Dynamic);
2723 assert_eq!(increased, PrecisionMode::Mixed16);
2724
2725 let max_increased = NumericalStabilityMonitor::increase_precision(PrecisionMode::Full32);
2726 assert_eq!(max_increased, PrecisionMode::Full32); let decreased = NumericalStabilityMonitor::decrease_precision(PrecisionMode::Mixed16);
2730 assert_eq!(decreased, PrecisionMode::Int8Dynamic);
2731
2732 let min_decreased =
2733 NumericalStabilityMonitor::decrease_precision(PrecisionMode::Int4Advanced);
2734 assert_eq!(min_decreased, PrecisionMode::Int4Advanced); }
2736
2737 #[test]
2738 fn test_condition_number_estimation() {
2739 let config = DynamicPrecisionConfig::default();
2740 let monitor = NumericalStabilityMonitor::new(config);
2741
2742 let well_conditioned = array![[1.0, 2.0], [3.0, 4.0]];
2744 let condition_1 = NumericalStabilityMonitor::estimate_condition_number(&well_conditioned);
2745 assert!(condition_1 > 1.0 && condition_1 < 100.0);
2746
2747 let ill_conditioned = array![[1e-10, 2.0], [3.0, 1e10]];
2749 let condition_2 = NumericalStabilityMonitor::estimate_condition_number(&ill_conditioned);
2750 assert!(condition_2 > 1e15);
2751 }
2752
2753 #[test]
2754 fn test_error_recovery_system_creation() {
2755 let recovery_system = ErrorRecoverySystem::new();
2756
2757 assert!(!recovery_system.recovery_strategies.is_empty());
2759 assert!(recovery_system
2760 .recovery_strategies
2761 .contains_key(&NumericalErrorType::Overflow));
2762 assert!(recovery_system
2763 .recovery_strategies
2764 .contains_key(&NumericalErrorType::IllConditioned));
2765 assert_eq!(recovery_system.max_recovery_attempts, 3);
2766 }
2767
2768 #[tokio::test]
2769 async fn test_recovery_action_selection() {
2770 let mut recovery_system = ErrorRecoverySystem::new();
2771
2772 let action = recovery_system
2773 .attempt_recovery(NumericalErrorType::Overflow)
2774 .await;
2775 assert!(action.is_ok());
2776
2777 let recovery_action = action.unwrap();
2778 assert!(matches!(
2779 recovery_action,
2780 RecoveryAction::IncreasePrecision
2781 | RecoveryAction::ReduceTileSize
2782 | RecoveryAction::NumericalStabilization
2783 ));
2784 }
2785
2786 #[test]
2787 fn test_success_rate_update() {
2788 let mut recovery_system = ErrorRecoverySystem::new();
2789
2790 recovery_system.update_success_rate(RecoveryAction::IncreasePrecision, true);
2792 let rate = recovery_system
2793 .success_rates
2794 .get(&RecoveryAction::IncreasePrecision);
2795 assert!(rate.is_some());
2796 assert!(*rate.unwrap() > 0.5);
2797
2798 recovery_system.update_success_rate(RecoveryAction::ReduceTileSize, false);
2800 let rate = recovery_system
2801 .success_rates
2802 .get(&RecoveryAction::ReduceTileSize);
2803 assert!(rate.is_some());
2804 assert!(*rate.unwrap() < 0.5);
2805 }
2806
2807 #[test]
2808 fn test_performance_accuracy_analyzer() {
2809 let params = TradeOffParams {
2810 performance_weight: 0.7,
2811 accuracy_weight: 0.3,
2812 energy_weight: 0.0,
2813 min_accuracy: 0.9,
2814 max_time: Duration::from_secs(10),
2815 objective: OptimizationObjective::Balanced,
2816 };
2817
2818 let mut analyzer = PerformanceAccuracyAnalyzer::new(params);
2819
2820 analyzer.record_performance(PrecisionMode::Mixed16, Duration::from_millis(100));
2822 analyzer.record_performance(PrecisionMode::Full32, Duration::from_millis(200));
2823
2824 analyzer.record_accuracy(PrecisionMode::Mixed16, 0.95);
2826 analyzer.record_accuracy(PrecisionMode::Full32, 0.99);
2827
2828 let optimal_precision = analyzer.optimize_precision();
2830 assert!(matches!(
2831 optimal_precision,
2832 PrecisionMode::Mixed16 | PrecisionMode::Full32
2833 ));
2834 }
2835
2836 #[test]
2837 fn test_pareto_frontier_update() {
2838 let params = TradeOffParams {
2839 performance_weight: 0.5,
2840 accuracy_weight: 0.5,
2841 energy_weight: 0.0,
2842 min_accuracy: 0.8,
2843 max_time: Duration::from_secs(5),
2844 objective: OptimizationObjective::Balanced,
2845 };
2846
2847 let mut analyzer = PerformanceAccuracyAnalyzer::new(params);
2848
2849 analyzer.record_performance(PrecisionMode::Int8Dynamic, Duration::from_millis(50));
2851 analyzer.record_accuracy(PrecisionMode::Int8Dynamic, 0.85);
2852
2853 analyzer.record_performance(PrecisionMode::Mixed16, Duration::from_millis(100));
2854 analyzer.record_accuracy(PrecisionMode::Mixed16, 0.95);
2855
2856 analyzer.record_performance(PrecisionMode::Full32, Duration::from_millis(200));
2857 analyzer.record_accuracy(PrecisionMode::Full32, 0.99);
2858
2859 analyzer.update_pareto_frontier();
2860 assert!(!analyzer.pareto_frontier.is_empty());
2861 assert_eq!(analyzer.pareto_frontier.len(), 3);
2862 }
2863
2864 #[test]
2865 fn test_weighted_score_computation() {
2866 let params = TradeOffParams {
2867 performance_weight: 0.6,
2868 accuracy_weight: 0.4,
2869 energy_weight: 0.0,
2870 min_accuracy: 0.8,
2871 max_time: Duration::from_secs(5),
2872 objective: OptimizationObjective::Custom,
2873 };
2874
2875 let mut analyzer = PerformanceAccuracyAnalyzer::new(params);
2876
2877 let score1 = analyzer.compute_weighted_score(0.1, 0.9); let score2 = analyzer.compute_weighted_score(0.2, 0.95); assert!(score1 > 0.0);
2882 assert!(score2 > 0.0);
2883 }
2884
2885 #[test]
2886 fn test_advanced_tensor_core_distance_matrix_creation() {
2887 let result = AdvancedTensorCoreDistanceMatrix::new();
2888 assert!(result.is_ok());
2889
2890 let advanced_computer = result.unwrap();
2891 assert!(advanced_computer.dynamic_precision_enabled);
2892 assert!(advanced_computer.auto_recovery_enabled);
2893 }
2894
2895 #[tokio::test]
2896 #[ignore]
2897 async fn test_stability_monitoring_computation() {
2898 let mut advanced_computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();
2899 let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
2900
2901 let result = advanced_computer
2902 .compute_with_stability_monitoring(&points.view())
2903 .await;
2904 assert!(result.is_ok());
2905
2906 let distances = result.unwrap();
2907 assert_eq!(distances.shape(), &[3, 3]);
2908
2909 let monitor = advanced_computer.stability_monitor.lock().unwrap();
2911 assert!(!monitor.stability_history.is_empty());
2912 }
2913
2914 #[tokio::test]
2915 async fn test_recovery_action_application() {
2916 let mut advanced_computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();
2917 let original_precision = advanced_computer.base_computer.precision_mode;
2918
2919 let result = advanced_computer
2921 .apply_recovery_action(RecoveryAction::IncreasePrecision)
2922 .await;
2923 assert!(result.is_ok());
2924
2925 if original_precision != PrecisionMode::Full32 {
2927 assert_ne!(
2928 advanced_computer.base_computer.precision_mode,
2929 original_precision
2930 );
2931 }
2932
2933 let original_tile_size = advanced_computer.base_computer.tile_size;
2935 let result = advanced_computer
2936 .apply_recovery_action(RecoveryAction::ReduceTileSize)
2937 .await;
2938 assert!(result.is_ok());
2939
2940 let new_tile_size = advanced_computer.base_computer.tile_size;
2941 assert!(new_tile_size.0 <= original_tile_size.0);
2942 assert!(new_tile_size.1 <= original_tile_size.1);
2943 }
2944
2945 #[test]
2946 fn test_result_accuracy_estimation() {
2947 let advanced_computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();
2948
2949 let valid_result = array![[0.0, 1.0], [1.0, 0.0]];
2951 let accuracy = advanced_computer.estimate_result_accuracy(&valid_result);
2952 assert!(accuracy > 0.8 && accuracy <= 1.0);
2953
2954 let invalid_result = array![[0.0, f64::NAN], [1.0, 0.0]];
2956 let accuracy = advanced_computer.estimate_result_accuracy(&invalid_result);
2957 assert_eq!(accuracy, 0.0);
2958
2959 let high_range_result = array![[1e-10, 1e10], [1e5, 1e-5]];
2961 let accuracy = advanced_computer.estimate_result_accuracy(&high_range_result);
2962 assert!(accuracy > 0.0 && accuracy < 1.0);
2963 }
2964
2965 #[test]
2966 fn test_precision_mode_ordering() {
2967 assert!(matches!(
2969 PrecisionMode::AdvancedAdaptive,
2970 PrecisionMode::AdvancedAdaptive
2971 ));
2972 assert_ne!(PrecisionMode::AdvancedAdaptive, PrecisionMode::Adaptive);
2973 }
2974
2975 #[test]
2976 fn test_stability_levels() {
2977 assert!(matches!(StabilityLevel::Critical, StabilityLevel::Critical));
2978 assert_ne!(StabilityLevel::Critical, StabilityLevel::Excellent);
2979 }
2980
2981 #[test]
2982 fn test_error_types() {
2983 let error_types = [
2984 NumericalErrorType::Overflow,
2985 NumericalErrorType::Underflow,
2986 NumericalErrorType::PrecisionLoss,
2987 NumericalErrorType::ConvergenceFailure,
2988 NumericalErrorType::IllConditioned,
2989 NumericalErrorType::InvalidValues,
2990 ];
2991
2992 assert_eq!(error_types.len(), 6);
2993 assert!(error_types.contains(&NumericalErrorType::Overflow));
2994 }
2995
2996 #[test]
2997 fn test_scaling_strategies() {
2998 let strategies = [
2999 ScalingStrategy::Conservative,
3000 ScalingStrategy::Balanced,
3001 ScalingStrategy::Aggressive,
3002 ScalingStrategy::Custom,
3003 ];
3004
3005 assert_eq!(strategies.len(), 4);
3006 assert!(strategies.contains(&ScalingStrategy::Balanced));
3007 }
3008
3009 #[test]
3010 fn test_recovery_actions() {
3011 let actions = [
3012 RecoveryAction::IncreasePrecision,
3013 RecoveryAction::ReduceTileSize,
3014 RecoveryAction::FallbackAlgorithm,
3015 RecoveryAction::NumericalStabilization,
3016 RecoveryAction::RetryWithNewParams,
3017 RecoveryAction::SwitchToCPU,
3018 ];
3019
3020 assert_eq!(actions.len(), 6);
3021 assert!(actions.contains(&RecoveryAction::IncreasePrecision));
3022 }
3023
3024 #[test]
3025 fn test_optimization_objectives() {
3026 let objectives = [
3027 OptimizationObjective::MaxPerformance,
3028 OptimizationObjective::MaxAccuracy,
3029 OptimizationObjective::Balanced,
3030 OptimizationObjective::MinEnergy,
3031 OptimizationObjective::Custom,
3032 ];
3033
3034 assert_eq!(objectives.len(), 5);
3035 assert!(objectives.contains(&OptimizationObjective::Balanced));
3036 }
3037}