1use crate::error::{SpatialError, SpatialResult};
57use ndarray::{s, Array1, Array2, ArrayView2};
58use rand::Rng;
59use statrs::statistics::Statistics;
60use std::collections::{HashMap, VecDeque};
61use std::sync::{Arc, Mutex};
62use std::time::{Duration, Instant};
63
/// Numeric precision selector for distance computations.
///
/// `TensorCoreDistanceMatrix::compute_tile_tensor_cores` matches on this value to
/// decide which distance kernel runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrecisionMode {
    /// Routed to `compute_distances_fp32`.
    Full32,
    /// Routed to `compute_distances_mixed16`.
    Mixed16,
    /// Routed to `compute_distances_bf16`.
    BrainFloat16,
    /// Routed to `compute_distances_int8`.
    Int8Dynamic,
    /// Routed to `compute_distances_int4`.
    Int4Advanced,
    /// Routed to `compute_distances_adaptive`, which picks a kernel from the data.
    Adaptive,
    /// Routed to the same `compute_distances_adaptive` path as `Adaptive`.
    AdvancedAdaptive,
}
82
/// Coarse stability classification stored in `StabilityMetrics::stability_level`.
///
/// Variant meanings are read from their names; the thresholds that assign them are
/// not part of the visible code, and no ordering trait is derived.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the enum can be used
/// as a map/set key, matching the other fieldless enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StabilityLevel {
    /// "Excellent" classification.
    Excellent,
    /// "Good" classification.
    Good,
    /// "Moderate" classification.
    Moderate,
    /// "Poor" classification.
    Poor,
    /// "Critical" classification.
    Critical,
}
97
/// Category of numerical problem.
///
/// Used as the key of `ErrorRecoverySystem::recovery_strategies` and recorded in
/// `StabilityMetrics::error_types` and `RecoveryAttempt::error_type`. The code that
/// detects each category is not part of the visible source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NumericalErrorType {
    /// Overflow category.
    Overflow,
    /// Underflow category.
    Underflow,
    /// Precision-loss category.
    PrecisionLoss,
    /// Convergence-failure category.
    ConvergenceFailure,
    /// Ill-conditioned-problem category.
    IllConditioned,
    /// Invalid-values category (presumably NaN/infinite values; confirm at the detection site).
    InvalidValues,
}
114
/// Strategy selector stored in `DynamicPrecisionConfig::strategy`.
///
/// The code that interprets each strategy is not part of the visible source.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the enum can be used
/// as a map/set key, matching the other fieldless enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalingStrategy {
    /// "Conservative" strategy.
    Conservative,
    /// "Balanced" strategy.
    Balanced,
    /// "Aggressive" strategy.
    Aggressive,
    /// "Custom" strategy.
    Custom,
}
127
/// Data layout requested from `TensorCoreDistanceMatrix::optimize_tensor_layout`.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the enum can be used
/// as a map/set key, matching the other fieldless enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorLayout {
    /// Returns an owned copy of the input unchanged.
    RowMajor,
    /// Copies the points into a transposed buffer and returns its transpose.
    ColMajor,
    /// Handled by `create_blocked_layout`.
    Blocked,
    /// Handled by `create_zorder_layout`.
    ZOrder,
    /// Handled by `create_hardware_optimized_layout`, which chooses by GPU architecture.
    HardwareOptimized,
}
142
/// GPU architecture tag consulted by
/// `TensorCoreDistanceMatrix::create_hardware_optimized_layout`.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the enum can be used
/// as a map/set key, matching the other fieldless enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuArchitecture {
    /// Has no dedicated layout branch; falls back to the blocked layout.
    Volta,
    /// Uses `create_nvidia_optimized_layout`.
    Ampere,
    /// Uses `create_nvidia_optimized_layout`.
    Hopper,
    /// Uses `create_amd_optimized_layout`.
    CDNA2,
    /// Uses `create_amd_optimized_layout`.
    CDNA3,
    /// Uses `create_intel_optimized_layout`.
    XeHPC,
    /// Uses `create_intel_optimized_layout`.
    XeGraphics,
    /// Has no dedicated layout branch; falls back to the blocked layout.
    Unknown,
}
163
164#[derive(Debug, Clone)]
166pub struct TensorCoreCapabilities {
167 pub tensor_core_types: Vec<TensorCoreType>,
169 pub supported_precisions: Vec<PrecisionMode>,
171 pub max_tensor_size: (usize, usize, usize),
173 pub peak_throughput_tops: f64,
175 pub memory_bandwidth_gbps: f64,
177 pub l2_cache_mb: f64,
179 pub num_sms: usize,
181 pub architecture: GpuArchitecture,
183}
184
/// Kind of matrix unit listed in `TensorCoreCapabilities::tensor_core_types`.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the enum can be used
/// as a map/set key, matching the other fieldless enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorCoreType {
    /// NVIDIA tensor core.
    NvidiaTensorCore,
    /// AMD matrix core.
    AmdMatrixCore,
    /// Intel XMX unit.
    IntelXMX,
    /// Standard (non-matrix) cores.
    StandardCores,
}
197
198#[derive(Debug, Clone)]
200pub struct StabilityMetrics {
201 pub condition_number: f64,
203 pub relative_error: f64,
205 pub forward_error: f64,
207 pub backward_error: f64,
209 pub digit_loss: f64,
211 pub stability_level: StabilityLevel,
213 pub error_types: Vec<NumericalErrorType>,
215 pub timestamp: Instant,
217}
218
219#[derive(Debug, Clone)]
221pub struct DynamicPrecisionConfig {
222 pub strategy: ScalingStrategy,
224 pub min_precision: PrecisionMode,
226 pub max_precision: PrecisionMode,
228 pub stability_threshold_up: f64,
230 pub stability_threshold_down: f64,
232 pub performance_weight: f64,
234 pub accuracy_weight: f64,
236 pub max_changes_per_operation: usize,
238 pub change_cooldown: Duration,
240}
241
242#[allow(dead_code)]
244#[derive(Debug)]
245pub struct NumericalStabilityMonitor {
246 current_metrics: StabilityMetrics,
248 stability_history: VecDeque<StabilityMetrics>,
250 precision_config: DynamicPrecisionConfig,
252 current_precision: PrecisionMode,
254 precision_history: VecDeque<(Instant, PrecisionMode, f64)>,
256 #[allow(dead_code)]
258 recovery_attempts: usize,
259 max_history_length: usize,
261 last_precision_change: Option<Instant>,
263}
264
265#[allow(dead_code)]
267#[derive(Debug)]
268pub struct ErrorRecoverySystem {
269 recovery_strategies: HashMap<NumericalErrorType, Vec<RecoveryAction>>,
271 recovery_history: VecDeque<RecoveryAttempt>,
273 max_recovery_attempts: usize,
275 success_rates: HashMap<RecoveryAction, f64>,
277}
278
/// Action that can be tried to recover from a numerical error.
///
/// Used as a value in `ErrorRecoverySystem::recovery_strategies` and as the key of
/// `ErrorRecoverySystem::success_rates`; the code that executes each action is not
/// part of the visible source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecoveryAction {
    /// Increase the precision.
    IncreasePrecision,
    /// Reduce the tile size.
    ReduceTileSize,
    /// Fall back to another algorithm.
    FallbackAlgorithm,
    /// Apply numerical stabilization.
    NumericalStabilization,
    /// Retry with new parameters.
    RetryWithNewParams,
    /// Switch to the CPU.
    SwitchToCPU,
}
295
296#[derive(Debug, Clone)]
298pub struct RecoveryAttempt {
299 pub error_type: NumericalErrorType,
301 pub action: RecoveryAction,
303 pub success: bool,
305 pub duration: Duration,
307 pub post_recovery_metrics: Option<StabilityMetrics>,
309 pub timestamp: Instant,
311}
312
313#[derive(Debug)]
315pub struct PerformanceAccuracyAnalyzer {
316 performance_data: HashMap<PrecisionMode, VecDeque<Duration>>,
318 accuracy_data: HashMap<PrecisionMode, VecDeque<f64>>,
320 optimization_params: TradeOffParams,
322 pareto_frontier: Vec<(f64, f64, PrecisionMode)>, }
325
326#[derive(Debug, Clone)]
328pub struct TradeOffParams {
329 pub performance_weight: f64,
331 pub accuracy_weight: f64,
333 pub energy_weight: f64,
335 pub min_accuracy: f64,
337 pub max_time: Duration,
339 pub objective: OptimizationObjective,
341}
342
/// Objective selector stored in `TradeOffParams::objective`.
///
/// The code that interprets each objective is not part of the visible source.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the enum can be used
/// as a map/set key, matching the other fieldless enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationObjective {
    /// Maximize performance.
    MaxPerformance,
    /// Maximize accuracy.
    MaxAccuracy,
    /// Balance the objectives.
    Balanced,
    /// Minimize energy.
    MinEnergy,
    /// Custom objective.
    Custom,
}
357
358#[derive(Debug)]
360pub struct AdvancedTensorCoreDistanceMatrix {
361 base_computer: TensorCoreDistanceMatrix,
363 stability_monitor: Arc<Mutex<NumericalStabilityMonitor>>,
365 recovery_system: ErrorRecoverySystem,
367 performance_analyzer: PerformanceAccuracyAnalyzer,
369 dynamic_precision_enabled: bool,
371 auto_recovery_enabled: bool,
373}
374
375#[derive(Debug, Clone)]
377pub struct TensorCoreDistanceMatrix {
378 precision_mode: PrecisionMode,
380 layout_optimization: bool,
382 hierarchical_tiling: bool,
384 tile_size: (usize, usize),
386 capabilities: Option<TensorCoreCapabilities>,
388 tensor_layout: TensorLayout,
390 execution_streams: usize,
392}
393
394impl TensorCoreDistanceMatrix {
395 pub fn new() -> SpatialResult<Self> {
397 let capabilities = detect_tensor_core_capabilities()?;
398
399 Ok(Self {
400 precision_mode: PrecisionMode::Mixed16,
401 layout_optimization: true,
402 hierarchical_tiling: true,
403 tile_size: (256, 256),
404 capabilities: Some(capabilities),
405 tensor_layout: TensorLayout::HardwareOptimized,
406 execution_streams: 4,
407 })
408 }
409
410 pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
412 self.precision_mode = mode;
413 self
414 }
415
416 pub fn with_tensor_layout_optimization(mut self, enabled: bool) -> Self {
418 self.layout_optimization = enabled;
419 self
420 }
421
422 pub fn with_hierarchical_tiling(mut self, enabled: bool) -> Self {
424 self.hierarchical_tiling = enabled;
425 self
426 }
427
428 pub fn with_tile_size(mut self, rows: usize, cols: usize) -> Self {
430 self.tile_size = (rows, cols);
431 self
432 }
433
434 pub fn with_execution_streams(mut self, streams: usize) -> Self {
436 self.execution_streams = streams;
437 self
438 }
439
440 pub async fn compute_parallel(
442 &mut self,
443 points: &ArrayView2<'_, f64>,
444 ) -> SpatialResult<Array2<f64>> {
445 let (npoints, ndims) = points.dim();
446
447 if npoints == 0 || ndims == 0 {
448 return Err(SpatialError::InvalidInput("Empty input data".to_string()));
449 }
450
451 let optimizedpoints = if self.layout_optimization {
453 self.optimize_tensor_layout(points)?
454 } else {
455 points.to_owned()
456 };
457
458 if self.hierarchical_tiling && npoints > 1024 {
460 self.compute_hierarchical_tiled(&optimizedpoints.view())
461 .await
462 } else {
463 self.compute_direct_tensor_cores(&optimizedpoints.view())
464 .await
465 }
466 }
467
468 fn optimize_tensor_layout(
470 &mut self,
471 points: &ArrayView2<'_, f64>,
472 ) -> SpatialResult<Array2<f64>> {
473 let (npoints, ndims) = points.dim();
474
475 match self.tensor_layout {
476 TensorLayout::RowMajor => Ok(points.to_owned()),
477 TensorLayout::ColMajor => {
478 let mut transposed = Array2::zeros((ndims, npoints));
479 for (i, point) in points.outer_iter().enumerate() {
480 transposed.column_mut(i).assign(&point);
481 }
482 Ok(transposed.t().to_owned())
483 }
484 TensorLayout::Blocked => TensorCoreDistanceMatrix::create_blocked_layout(points),
485 TensorLayout::ZOrder => self.create_zorder_layout(points),
486 TensorLayout::HardwareOptimized => self.create_hardware_optimized_layout(points),
487 }
488 }
489
490 fn create_blocked_layout(points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
492 let (npoints, ndims) = points.dim();
493 let block_size = 64; let blocked_rows = npoints.div_ceil(block_size) * block_size;
496 let blocked_cols = ndims.div_ceil(block_size) * block_size;
497
498 let mut blocked_data = Array2::zeros((blocked_rows, blocked_cols));
499
500 for block_i in 0..(npoints / block_size + 1) {
501 for block_j in 0..(ndims / block_size + 1) {
502 let start_i = block_i * block_size;
503 let start_j = block_j * block_size;
504 let end_i = (start_i + block_size).min(npoints);
505 let end_j = (start_j + block_size).min(ndims);
506
507 for i in start_i..end_i {
508 for j in start_j..end_j {
509 blocked_data[[i, j]] = points[[i, j]];
510 }
511 }
512 }
513 }
514
515 Ok(blocked_data.slice(s![..npoints, ..ndims]).to_owned())
516 }
517
518 fn create_zorder_layout(&mut self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
520 let (npoints, ndims) = points.dim();
521
522 let mut z_indices: Vec<(usize, usize)> = (0..npoints)
524 .map(|i| {
525 (
526 i,
527 TensorCoreDistanceMatrix::calculate_z_order_index(i, ndims),
528 )
529 })
530 .collect();
531
532 z_indices.sort_by_key(|(_, z_idx)| *z_idx);
533
534 let mut reordered_data = Array2::zeros((npoints, ndims));
535 for (new_idx, (old_idx, z_idx)) in z_indices.iter().enumerate() {
536 reordered_data
537 .row_mut(new_idx)
538 .assign(&points.row(*old_idx));
539 }
540
541 Ok(reordered_data)
542 }
543
544 fn calculate_z_order_index(point_idx: usize, ndims: usize) -> usize {
546 let mut z_index = 0;
548 let temp_idx = point_idx;
549
550 for bit in 0..16 {
551 for dim in 0..ndims.min(3) {
553 if temp_idx & (1 << bit) != 0 {
555 z_index |= 1 << (bit * ndims + dim);
556 }
557 }
558 }
559
560 z_index
561 }
562
563 fn create_hardware_optimized_layout(
565 &self,
566 points: &ArrayView2<'_, f64>,
567 ) -> SpatialResult<Array2<f64>> {
568 if let Some(ref capabilities) = self.capabilities {
569 match capabilities.architecture {
570 GpuArchitecture::Ampere | GpuArchitecture::Hopper => {
571 self.create_nvidia_optimized_layout(points)
573 }
574 GpuArchitecture::CDNA2 | GpuArchitecture::CDNA3 => {
575 self.create_amd_optimized_layout(points)
577 }
578 GpuArchitecture::XeHPC | GpuArchitecture::XeGraphics => {
579 self.create_intel_optimized_layout(points)
581 }
582 _ => {
583 TensorCoreDistanceMatrix::create_blocked_layout(points)
585 }
586 }
587 } else {
588 TensorCoreDistanceMatrix::create_blocked_layout(points)
589 }
590 }
591
592 fn create_nvidia_optimized_layout(
594 &self,
595 points: &ArrayView2<'_, f64>,
596 ) -> SpatialResult<Array2<f64>> {
597 let (npoints, ndims) = points.dim();
598
599 let paddedpoints = npoints.div_ceil(8) * 8;
601 let padded_dims = ndims.div_ceil(8) * 8;
602
603 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
604
605 for i in 0..npoints {
607 for j in 0..ndims {
608 padded_data[[i, j]] = points[[i, j]];
609 }
610 }
611
612 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
614 }
615
616 fn create_amd_optimized_layout(
618 &self,
619 points: &ArrayView2<'_, f64>,
620 ) -> SpatialResult<Array2<f64>> {
621 let (npoints, ndims) = points.dim();
622
623 let paddedpoints = npoints.div_ceil(16) * 16;
625 let padded_dims = ndims.div_ceil(16) * 16;
626
627 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
628
629 for i in 0..npoints {
630 for j in 0..ndims {
631 padded_data[[i, j]] = points[[i, j]];
632 }
633 }
634
635 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
636 }
637
638 fn create_intel_optimized_layout(
640 &self,
641 points: &ArrayView2<'_, f64>,
642 ) -> SpatialResult<Array2<f64>> {
643 let (npoints, ndims) = points.dim();
644
645 let paddedpoints = npoints.div_ceil(32) * 32;
647 let padded_dims = ndims.div_ceil(32) * 32;
648
649 let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
650
651 for i in 0..npoints {
652 for j in 0..ndims {
653 padded_data[[i, j]] = points[[i, j]];
654 }
655 }
656
657 Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
658 }
659
660 async fn compute_hierarchical_tiled(
662 &mut self,
663 points: &ArrayView2<'_, f64>,
664 ) -> SpatialResult<Array2<f64>> {
665 let (npoints, ndims) = points.dim();
666 let mut distance_matrix = Array2::zeros((npoints, npoints));
667
668 let (tile_rows, tile_cols) = self.tile_size;
669 let precision_mode = self.precision_mode; let mut tile_futures = Vec::new();
673
674 for i in (0..npoints).step_by(tile_rows) {
675 for j in (0..npoints).step_by(tile_cols) {
676 let end_i = (i + tile_rows).min(npoints);
677 let end_j = (j + tile_cols).min(npoints);
678
679 let tilepoints_i = points.slice(s![i..end_i, ..]).to_owned();
680 let tilepoints_j = points.slice(s![j..end_j, ..]).to_owned();
681
682 let future = async move {
684 let (rows_i, _) = tilepoints_i.dim();
686 let (rows_j, _) = tilepoints_j.dim();
687 let mut tile_distances = Array2::zeros((rows_i, rows_j));
688
689 for r in 0..rows_i {
690 for c in 0..rows_j {
691 let p1 = tilepoints_i.row(r);
692 let p2 = tilepoints_j.row(c);
693 let diff = &p1 - &p2;
694 let dist = diff.iter().map(|x| x.powi(2)).sum::<f64>().sqrt();
695 tile_distances[[r, c]] = dist;
696 }
697 }
698 Ok::<Array2<f64>, SpatialError>(tile_distances)
699 };
700 tile_futures.push((i, j, end_i, end_j, future));
701 }
702 }
703
704 for (i, j, end_i, end_j, future) in tile_futures {
706 let tile_result = future.await?;
707
708 let tile_rows = end_i - i;
710 let tile_cols = end_j - j;
711
712 for row in 0..tile_rows {
713 for col in 0..tile_cols {
714 distance_matrix[[i + row, j + col]] = tile_result[[row, col]];
715 }
716 }
717 }
718
719 Ok(distance_matrix)
720 }
721
722 async fn compute_tile_tensor_cores(
724 &mut self,
725 points_i: Array2<f64>,
726 points_j: Array2<f64>,
727 precision_mode: PrecisionMode,
728 ) -> SpatialResult<Array2<f64>> {
729 let (_n_i, ndims) = points_i.dim();
730 let (_n_j, _) = points_j.dim();
731
732 match precision_mode {
733 PrecisionMode::Full32 => {
734 self.compute_distances_fp32(&points_i.view(), &points_j.view())
735 .await
736 }
737 PrecisionMode::Mixed16 => {
738 self.compute_distances_mixed16(&points_i.view(), &points_j.view())
739 .await
740 }
741 PrecisionMode::BrainFloat16 => {
742 self.compute_distances_bf16(&points_i.view(), &points_j.view())
743 .await
744 }
745 PrecisionMode::Int8Dynamic => {
746 self.compute_distances_int8(&points_i.view(), &points_j.view())
747 .await
748 }
749 PrecisionMode::Int4Advanced => {
750 self.compute_distances_int4(&points_i.view(), &points_j.view())
751 .await
752 }
753 PrecisionMode::Adaptive => {
754 self.compute_distances_adaptive(&points_i.view(), &points_j.view())
755 .await
756 }
757 PrecisionMode::AdvancedAdaptive => {
758 self.compute_distances_adaptive(&points_i.view(), &points_j.view())
759 .await
760 }
761 }
762 }
763
764 async fn compute_direct_tensor_cores(
766 &mut self,
767 points: &ArrayView2<'_, f64>,
768 ) -> SpatialResult<Array2<f64>> {
769 self.compute_tile_tensor_cores(points.to_owned(), points.to_owned(), self.precision_mode)
770 .await
771 }
772
773 async fn compute_distances_fp32(
775 &self,
776 points_i: &ArrayView2<'_, f64>,
777 points_j: &ArrayView2<'_, f64>,
778 ) -> SpatialResult<Array2<f64>> {
779 let (n_i, ndims) = points_i.dim();
780 let (n_j, _) = points_j.dim();
781 let mut distances = Array2::zeros((n_i, n_j));
782
783 let norms_i: Array1<f64> = points_i
788 .outer_iter()
789 .map(|point| point.iter().map(|&x| x * x).sum())
790 .collect();
791
792 let norms_j: Array1<f64> = points_j
794 .outer_iter()
795 .map(|point| point.iter().map(|&x| x * x).sum())
796 .collect();
797
798 let cross_terms = self
800 .tensor_core_gemm_fp32(points_i, &points_j.t().to_owned().view())
801 .await?;
802
803 for _i in 0..n_i {
805 for _j in 0..n_j {
806 distances[[_i, _j]] = (norms_i[_i] + norms_j[_j] - 2.0 * cross_terms[[_i, _j]])
807 .max(0.0)
808 .sqrt();
809 }
810 }
811
812 Ok(distances)
813 }
814
815 async fn compute_distances_mixed16(
817 &self,
818 points_i: &ArrayView2<'_, f64>,
819 points_j: &ArrayView2<'_, f64>,
820 ) -> SpatialResult<Array2<f64>> {
821 let points_i_f16 = TensorCoreDistanceMatrix::convert_to_fp16(points_i)?;
823 let points_j_f16 = TensorCoreDistanceMatrix::convert_to_fp16(points_j)?;
824
825 let (n_i, _) = points_i.dim();
826 let (n_j, _) = points_j.dim();
827 let mut distances = Array2::zeros((n_i, n_j));
828
829 let norms_i_f16 = TensorCoreDistanceMatrix::compute_norms_fp16(&points_i_f16)?;
831 let norms_j_f16 = TensorCoreDistanceMatrix::compute_norms_fp16(&points_j_f16)?;
832
833 let cross_terms = self
835 .tensor_core_gemm_mixed16(&points_i_f16, &points_j_f16.t().to_owned())
836 .await?;
837
838 for _i in 0..n_i {
839 for _j in 0..n_j {
840 let distance_sq = norms_i_f16[_i] as f64 + norms_j_f16[_j] as f64
841 - 2.0 * cross_terms[[_i, _j]] as f64;
842 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
843 }
844 }
845
846 Ok(distances)
847 }
848
849 async fn compute_distances_bf16(
851 &mut self,
852 points_i: &ArrayView2<'_, f64>,
853 points_j: &ArrayView2<'_, f64>,
854 ) -> SpatialResult<Array2<f64>> {
855 let points_i_bf16 = self.convert_to_bf16(points_i)?;
858 let points_j_bf16 = self.convert_to_bf16(points_j)?;
859
860 let (n_i, _) = points_i.dim();
861 let (n_j, _) = points_j.dim();
862 let mut distances = Array2::zeros((n_i, n_j));
863
864 let norms_i_bf16 = self.compute_norms_bf16(&points_i_bf16)?;
865 let norms_j_bf16 = self.compute_norms_bf16(&points_j_bf16)?;
866
867 let cross_terms = self
868 .tensor_core_gemm_bf16(&points_i_bf16, &points_j_bf16.t().to_owned())
869 .await?;
870
871 for _i in 0..n_i {
872 for _j in 0..n_j {
873 let distance_sq = norms_i_bf16[_i] as f64 + norms_j_bf16[_j] as f64
874 - 2.0 * cross_terms[[_i, _j]] as f64;
875 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
876 }
877 }
878
879 Ok(distances)
880 }
881
882 async fn compute_distances_int8(
884 &self,
885 points_i: &ArrayView2<'_, f64>,
886 points_j: &ArrayView2<'_, f64>,
887 ) -> SpatialResult<Array2<f64>> {
888 let (scale_i, points_i_int8) = self.quantize_to_int8_dynamic(points_i)?;
890 let (scale_j, points_j_int8) = self.quantize_to_int8_dynamic(points_j)?;
891
892 let (n_i, _) = points_i.dim();
893 let (n_j, _) = points_j.dim();
894 let mut distances = Array2::zeros((n_i, n_j));
895
896 let combined_scale = scale_i * scale_j;
898
899 for _i in 0..n_i {
900 for _j in 0..n_j {
901 let cross_term_int32 = points_i_int8
903 .row(_i)
904 .iter()
905 .zip(points_j_int8.row(_j).iter())
906 .map(|(&a, &b)| (a as i32) * (b as i32))
907 .sum::<i32>();
908 let cross_term_f64 = cross_term_int32 as f64 * combined_scale;
909
910 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
912 let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
913
914 let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
915 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
916 }
917 }
918
919 Ok(distances)
920 }
921
922 async fn compute_distances_int4(
924 &self,
925 points_i: &ArrayView2<'_, f64>,
926 points_j: &ArrayView2<'_, f64>,
927 ) -> SpatialResult<Array2<f64>> {
928 let (scale_i, points_i_int4) = self.quantize_to_int4_advanced(points_i)?;
930 let (scale_j, points_j_int4) = self.quantize_to_int4_advanced(points_j)?;
931
932 let points_i_int8 = TensorCoreDistanceMatrix::int4_to_int8(&points_i_int4);
934 let points_j_int8 = TensorCoreDistanceMatrix::int4_to_int8(&points_j_int4);
935
936 let (n_i, _) = points_i.dim();
937 let (n_j, _) = points_j.dim();
938 let mut distances = Array2::zeros((n_i, n_j));
939
940 for _i in 0..n_i {
946 for _j in 0..n_j {
947 let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
948 let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
949
950 let cross_term_f64 = 0.0; let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
953 distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
954 }
955 }
956
957 Ok(distances)
958 }
959
960 async fn compute_distances_adaptive(
962 &mut self,
963 points_i: &ArrayView2<'_, f64>,
964 points_j: &ArrayView2<'_, f64>,
965 ) -> SpatialResult<Array2<f64>> {
966 let data_range = self.analyze_data_range(points_i, points_j);
968 let condition_number = self.estimate_condition_number(points_i, points_j);
969
970 let optimal_precision = if condition_number > 1e6 {
971 PrecisionMode::Full32
972 } else if data_range > 1e3 {
973 PrecisionMode::BrainFloat16
974 } else if data_range > 100.0 {
975 PrecisionMode::Mixed16
976 } else {
977 PrecisionMode::Int8Dynamic
978 };
979
980 match optimal_precision {
981 PrecisionMode::Full32 => self.compute_distances_fp32(points_i, points_j).await,
982 PrecisionMode::Mixed16 => self.compute_distances_mixed16(points_i, points_j).await,
983 PrecisionMode::BrainFloat16 => self.compute_distances_bf16(points_i, points_j).await,
984 PrecisionMode::Int8Dynamic => self.compute_distances_int8(points_i, points_j).await,
985 PrecisionMode::Int4Advanced => self.compute_distances_int8(points_i, points_j).await, PrecisionMode::Adaptive => self.compute_distances_mixed16(points_i, points_j).await, PrecisionMode::AdvancedAdaptive => {
988 self.compute_distances_fp32(points_i, points_j).await
989 } }
991 }
992
993 async fn tensor_core_gemm_fp32(
995 &self,
996 a: &ArrayView2<'_, f64>,
997 b: &ArrayView2<'_, f64>,
998 ) -> SpatialResult<Array2<f64>> {
999 let (m, k) = a.dim();
1001 let (k2, n) = b.dim();
1002
1003 if k != k2 {
1004 return Err(SpatialError::InvalidInput(
1005 "Matrix dimensions don't match for multiplication".to_string(),
1006 ));
1007 }
1008
1009 let mut c = Array2::zeros((m, n));
1010
1011 let block_size = 16; for i in (0..m).step_by(block_size) {
1015 for j in (0..n).step_by(block_size) {
1016 for kk in (0..k).step_by(block_size) {
1017 let end_i = (i + block_size).min(m);
1018 let end_j = (j + block_size).min(n);
1019 let end_k = (kk + block_size).min(k);
1020
1021 for ii in i..end_i {
1023 for jj in j..end_j {
1024 for kkk in kk..end_k {
1025 c[[ii, jj]] += a[[ii, kkk]] * b[[kkk, jj]];
1026 }
1027 }
1028 }
1029 }
1030 }
1031 }
1032
1033 Ok(c)
1034 }
1035
1036 async fn tensor_core_gemm_mixed16(
1038 &self,
1039 a: &Array2<f32>,
1040 b: &Array2<f32>,
1041 ) -> SpatialResult<Array2<f32>> {
1042 let (m, k) = a.dim();
1044 let (k2, n) = b.dim();
1045
1046 if k != k2 {
1047 return Err(SpatialError::InvalidInput(
1048 "Matrix dimensions don't match".to_string(),
1049 ));
1050 }
1051
1052 let mut c = Array2::zeros((m, n));
1053 let block_size = 16;
1054
1055 for i in (0..m).step_by(block_size) {
1056 for j in (0..n).step_by(block_size) {
1057 for kk in (0..k).step_by(block_size) {
1058 let end_i = (i + block_size).min(m);
1059 let end_j = (j + block_size).min(n);
1060 let end_k = (kk + block_size).min(k);
1061
1062 for ii in i..end_i {
1063 for jj in j..end_j {
1064 for kkk in kk..end_k {
1065 c[[ii, jj]] += a[[ii, kkk]] * b[[kkk, jj]];
1067 }
1068 }
1069 }
1070 }
1071 }
1072 }
1073
1074 Ok(c)
1075 }
1076
1077 async fn tensor_core_gemm_bf16(
1079 &self,
1080 a: &Array2<f32>,
1081 b: &Array2<f32>,
1082 ) -> SpatialResult<Array2<f32>> {
1083 self.tensor_core_gemm_mixed16(a, b).await
1085 }
1086
1087 #[allow(dead_code)]
1089 async fn tensor_core_gemm_int8(
1090 &self,
1091 a: &Array2<i8>,
1092 b: &Array2<i8>,
1093 ) -> SpatialResult<Array2<i32>> {
1094 let (m, k) = a.dim();
1095 let (k2, n) = b.dim();
1096
1097 if k != k2 {
1098 return Err(SpatialError::InvalidInput(
1099 "Matrix dimensions don't match".to_string(),
1100 ));
1101 }
1102
1103 let mut c = Array2::zeros((m, n));
1104 let block_size = 16;
1105
1106 for i in (0..m).step_by(block_size) {
1107 for j in (0..n).step_by(block_size) {
1108 for kk in (0..k).step_by(block_size) {
1109 let end_i = (i + block_size).min(m);
1110 let end_j = (j + block_size).min(n);
1111 let end_k = (kk + block_size).min(k);
1112
1113 for ii in i..end_i {
1114 for jj in j..end_j {
1115 for kkk in kk..end_k {
1116 c[[ii, jj]] += a[[ii, kkk]] as i32 * b[[kkk, jj]] as i32;
1118 }
1119 }
1120 }
1121 }
1122 }
1123 }
1124
1125 Ok(c)
1126 }
1127
1128 fn convert_to_fp16(data: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f32>> {
1130 let (rows, cols) = data.dim();
1131 let mut fp16_data = Array2::zeros((rows, cols));
1132
1133 for i in 0..rows {
1134 for j in 0..cols {
1135 fp16_data[[i, j]] = data[[i, j]] as f32;
1137 }
1138 }
1139
1140 Ok(fp16_data)
1141 }
1142
1143 fn convert_to_bf16(&mut self, data: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f32>> {
1145 TensorCoreDistanceMatrix::convert_to_fp16(data)
1147 }
1148
1149 fn quantize_to_int8_dynamic(
1151 &self,
1152 data: &ArrayView2<'_, f64>,
1153 ) -> SpatialResult<(f64, Array2<i8>)> {
1154 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1155 let scale = max_val / 127.0; let (rows, cols) = data.dim();
1158 let mut quantized = Array2::zeros((rows, cols));
1159
1160 for i in 0..rows {
1161 for j in 0..cols {
1162 let quantized_val = (data[[i, j]] / scale).round() as i8;
1163 quantized[[i, j]] = quantized_val.clamp(-127, 127);
1164 }
1165 }
1166
1167 Ok((scale, quantized))
1168 }
1169
1170 fn quantize_to_int4_advanced(
1172 &self,
1173 data: &ArrayView2<'_, f64>,
1174 ) -> SpatialResult<(f64, Array2<i8>)> {
1175 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1176 let scale = max_val / 7.0; let (rows, cols) = data.dim();
1179 let mut quantized = Array2::zeros((rows, cols));
1180
1181 for i in 0..rows {
1182 for j in 0..cols {
1183 let quantized_val = (data[[i, j]] / scale).round() as i8;
1184 quantized[[i, j]] = quantized_val.clamp(-7, 7);
1185 }
1186 }
1187
1188 Ok((scale, quantized))
1189 }
1190
1191 fn int4_to_int8(data: &Array2<i8>) -> Array2<i8> {
1193 data.mapv(|x| x.clamp(-7, 7))
1195 }
1196
1197 fn compute_norms_fp16(data: &Array2<f32>) -> SpatialResult<Array1<f32>> {
1199 let norms = data
1200 .outer_iter()
1201 .map(|row| row.iter().map(|&x| x * x).sum())
1202 .collect();
1203 Ok(norms)
1204 }
1205
1206 fn compute_norms_bf16(&mut self, data: &Array2<f32>) -> SpatialResult<Array1<f32>> {
1208 TensorCoreDistanceMatrix::compute_norms_fp16(data)
1209 }
1210
1211 fn analyze_data_range(
1213 &self,
1214 points_i: &ArrayView2<'_, f64>,
1215 points_j: &ArrayView2<'_, f64>,
1216 ) -> f64 {
1217 let min_i = points_i.fold(f64::INFINITY, |acc, &x| acc.min(x));
1218 let max_i = points_i.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
1219 let min_j = points_j.fold(f64::INFINITY, |acc, &x| acc.min(x));
1220 let max_j = points_j.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
1221
1222 let overall_min = min_i.min(min_j);
1223 let overall_max = max_i.max(max_j);
1224
1225 overall_max - overall_min
1226 }
1227
1228 fn estimate_condition_number(
1230 &self,
1231 points_i: &ArrayView2<'_, f64>,
1232 points_j: &ArrayView2<'_, f64>,
1233 ) -> f64 {
1234 let data_range = self.analyze_data_range(points_i, points_j);
1236 let mean_i: f64 = points_i.sum() / (points_i.len() as f64);
1237 let mean_j: f64 = points_j.sum() / (points_j.len() as f64);
1238 let overall_mean = (mean_i + mean_j) / 2.0;
1239
1240 if overall_mean.abs() < 1e-10 {
1241 1e6 } else {
1243 data_range / overall_mean.abs()
1244 }
1245 }
1246}
1247
1248#[allow(dead_code)]
1250#[derive(Debug, Clone)]
1251pub struct TensorCoreClustering {
1252 _numclusters: usize,
1254 precision_mode: PrecisionMode,
1256 tensor_cores: bool,
1258 mixed_precision: bool,
1260 dynamic_precision: bool,
1262 capabilities: Option<TensorCoreCapabilities>,
1264}
1265
1266impl TensorCoreClustering {
1267 pub fn new(_numclusters: usize) -> SpatialResult<Self> {
1269 let capabilities = detect_tensor_core_capabilities().ok();
1270
1271 Ok(Self {
1272 _numclusters,
1273 precision_mode: PrecisionMode::Mixed16,
1274 tensor_cores: true,
1275 mixed_precision: true,
1276 dynamic_precision: false,
1277 capabilities,
1278 })
1279 }
1280
1281 pub fn with_tensor_cores(mut self, enabled: bool) -> Self {
1283 self.tensor_cores = enabled;
1284 self
1285 }
1286
1287 pub fn with_mixed_precision(mut self, enabled: bool) -> Self {
1289 self.mixed_precision = enabled;
1290 self
1291 }
1292
1293 pub fn with_dynamic_precision_scaling(mut self, enabled: bool) -> Self {
1295 self.dynamic_precision = enabled;
1296 self
1297 }
1298
1299 pub async fn fit(
1301 &mut self,
1302 points: &ArrayView2<'_, f64>,
1303 ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
1304 let (npoints, ndims) = points.dim();
1305
1306 if npoints < self._numclusters {
1307 return Err(SpatialError::InvalidInput(
1308 "Number of points must be >= number of clusters".to_string(),
1309 ));
1310 }
1311
1312 let mut centroids = self.initialize_centroids(points)?;
1314 let mut assignments = Array1::zeros(npoints);
1315
1316 for _iteration in 0..100 {
1318 let distance_matrix = if self.tensor_cores {
1320 let tensor_computer =
1321 TensorCoreDistanceMatrix::new()?.with_precision_mode(self.precision_mode);
1322 tensor_computer
1323 .compute_distances_to_centroids(points, ¢roids.view())
1324 .await?
1325 } else {
1326 self.compute_distances_fallback(points, ¢roids.view())?
1327 };
1328
1329 let new_assignments = self.update_assignments(&distance_matrix)?;
1331
1332 let new_centroids = if self.tensor_cores {
1334 self.update_centroids_tensor_cores(points, &new_assignments)
1335 .await?
1336 } else {
1337 self.update_centroids_fallback(points, &new_assignments)?
1338 };
1339
1340 let centroid_change = self.compute_centroid_change(¢roids, &new_centroids);
1342 if centroid_change < 1e-6 {
1343 break;
1344 }
1345
1346 centroids = new_centroids;
1347 assignments = new_assignments;
1348 }
1349
1350 Ok((centroids, assignments))
1351 }
1352
1353 fn initialize_centroids(&mut self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
1355 let (npoints, ndims) = points.dim();
1356 let mut centroids = Array2::zeros((self._numclusters, ndims));
1357
1358 let mut rng = rand::rng();
1360
1361 let first_idx = rng.gen_range(0..npoints);
1363 centroids.row_mut(0).assign(&points.row(first_idx));
1364
1365 for k in 1..self._numclusters {
1367 let mut distances = Array1::zeros(npoints);
1368
1369 for i in 0..npoints {
1370 let point = points.row(i);
1371 let mut min_dist = f64::INFINITY;
1372
1373 for j in 0..k {
1374 let centroid = centroids.row(j);
1375 let dist: f64 = point
1376 .iter()
1377 .zip(centroid.iter())
1378 .map(|(&a, &b)| (a - b).powi(2))
1379 .sum::<f64>();
1380 min_dist = min_dist.min(dist);
1381 }
1382
1383 distances[i] = min_dist;
1384 }
1385
1386 let total_dist: f64 = distances.sum();
1388 let mut cumulative = 0.0;
1389 let random_val = rand::random::<f64>() * total_dist;
1390
1391 for i in 0..npoints {
1392 cumulative += distances[i];
1393 if cumulative >= random_val {
1394 centroids.row_mut(k).assign(&points.row(i));
1395 break;
1396 }
1397 }
1398 }
1399
1400 Ok(centroids)
1401 }
1402
1403 fn update_assignments(
1405 &mut self,
1406 distance_matrix: &Array2<f64>,
1407 ) -> SpatialResult<Array1<usize>> {
1408 let npoints = distance_matrix.nrows();
1409 let mut assignments = Array1::zeros(npoints);
1410
1411 for i in 0..npoints {
1412 let mut min_dist = f64::INFINITY;
1413 let mut best_cluster = 0;
1414
1415 for j in 0..self._numclusters {
1416 if distance_matrix[[i, j]] < min_dist {
1417 min_dist = distance_matrix[[i, j]];
1418 best_cluster = j;
1419 }
1420 }
1421
1422 assignments[i] = best_cluster;
1423 }
1424
1425 Ok(assignments)
1426 }
1427
1428 async fn update_centroids_tensor_cores(
1430 &self,
1431 points: &ArrayView2<'_, f64>,
1432 assignments: &Array1<usize>,
1433 ) -> SpatialResult<Array2<f64>> {
1434 let (_npoints, ndims) = points.dim();
1435 let mut new_centroids = Array2::zeros((self._numclusters, ndims));
1436 let mut cluster_counts = vec![0; self._numclusters];
1437
1438 for &cluster in assignments {
1440 cluster_counts[cluster] += 1;
1441 }
1442
1443 for cluster in 0..self._numclusters {
1445 if cluster_counts[cluster] == 0 {
1446 continue;
1447 }
1448
1449 let clusterpoints: Vec<usize> = assignments
1451 .iter()
1452 .enumerate()
1453 .filter(|(_, &c)| c == cluster)
1454 .map(|(i, _)| i)
1455 .collect();
1456
1457 let cluster_data = Array2::from_shape_fn((clusterpoints.len(), ndims), |(i, j)| {
1459 points[[clusterpoints[i], j]]
1460 });
1461
1462 let sum_vector = self.tensor_sum_reduction(&cluster_data.view()).await?;
1464 let count = clusterpoints.len() as f64;
1465
1466 for j in 0..ndims {
1467 new_centroids[[cluster, j]] = sum_vector[j] / count;
1468 }
1469 }
1470
1471 Ok(new_centroids)
1472 }
1473
1474 async fn tensor_sum_reduction(&self, data: &ArrayView2<'_, f64>) -> SpatialResult<Array1<f64>> {
1476 let (_npoints, ndims) = data.dim();
1477 let mut sum_vector = Array1::zeros(ndims);
1478
1479 for j in 0..ndims {
1481 let column_sum: f64 = data.column(j).sum();
1482 sum_vector[j] = column_sum;
1483 }
1484
1485 Ok(sum_vector)
1486 }
1487
1488 fn compute_distances_fallback(
1490 &self,
1491 points: &ArrayView2<'_, f64>,
1492 centroids: &ArrayView2<'_, f64>,
1493 ) -> SpatialResult<Array2<f64>> {
1494 let (npoints, ndims) = points.dim();
1495 let (n_clusters_, _) = centroids.dim();
1496 let mut distances = Array2::zeros((npoints, n_clusters_));
1497
1498 for i in 0..npoints {
1499 for j in 0..n_clusters_ {
1500 let distance: f64 = points
1501 .row(i)
1502 .iter()
1503 .zip(centroids.row(j).iter())
1504 .map(|(&a, &b)| (a - b).powi(2))
1505 .sum::<f64>()
1506 .sqrt();
1507 distances[[i, j]] = distance;
1508 }
1509 }
1510
1511 Ok(distances)
1512 }
1513
1514 fn update_centroids_fallback(
1516 &self,
1517 points: &ArrayView2<'_, f64>,
1518 assignments: &Array1<usize>,
1519 ) -> SpatialResult<Array2<f64>> {
1520 let (npoints, ndims) = points.dim();
1521 let mut new_centroids = Array2::zeros((self._numclusters, ndims));
1522 let mut cluster_counts = vec![0; self._numclusters];
1523
1524 for i in 0..npoints {
1526 let cluster = assignments[i];
1527 cluster_counts[cluster] += 1;
1528
1529 for j in 0..ndims {
1530 new_centroids[[cluster, j]] += points[[i, j]];
1531 }
1532 }
1533
1534 for cluster in 0..self._numclusters {
1536 if cluster_counts[cluster] > 0 {
1537 let count = cluster_counts[cluster] as f64;
1538 for j in 0..ndims {
1539 new_centroids[[cluster, j]] /= count;
1540 }
1541 }
1542 }
1543
1544 Ok(new_centroids)
1545 }
1546
1547 fn compute_centroid_change(
1549 &self,
1550 old_centroids: &Array2<f64>,
1551 new_centroids: &Array2<f64>,
1552 ) -> f64 {
1553 let mut total_change = 0.0;
1554
1555 for i in 0..self._numclusters {
1556 let change: f64 = old_centroids
1557 .row(i)
1558 .iter()
1559 .zip(new_centroids.row(i).iter())
1560 .map(|(&a, &b)| (a - b).powi(2))
1561 .sum::<f64>()
1562 .sqrt();
1563 total_change += change;
1564 }
1565
1566 total_change / (self._numclusters as f64)
1567 }
1568}
1569
1570impl Default for StabilityMetrics {
1571 fn default() -> Self {
1572 Self::new()
1573 }
1574}
1575
1576impl StabilityMetrics {
1577 pub fn new() -> Self {
1579 Self {
1580 condition_number: 1.0,
1581 relative_error: 0.0,
1582 forward_error: 0.0,
1583 backward_error: 0.0,
1584 digit_loss: 0.0,
1585 stability_level: StabilityLevel::Excellent,
1586 error_types: Vec::new(),
1587 timestamp: Instant::now(),
1588 }
1589 }
1590
1591 pub fn update_stability_level(&mut self) {
1593 self.stability_level = if self.condition_number > 1e12 || self.relative_error > 1e-3 {
1594 StabilityLevel::Critical
1595 } else if self.condition_number > 1e8 || self.relative_error > 1e-6 {
1596 StabilityLevel::Poor
1597 } else if self.condition_number > 1e4 || self.relative_error > 1e-9 {
1598 StabilityLevel::Moderate
1599 } else if self.condition_number > 1e2 || self.relative_error > 1e-12 {
1600 StabilityLevel::Good
1601 } else {
1602 StabilityLevel::Excellent
1603 };
1604 }
1605
1606 pub fn detect_errors(&mut self, data: &Array2<f64>) {
1608 self.error_types.clear();
1609
1610 for &value in data.iter() {
1612 if !value.is_finite() {
1613 self.error_types.push(NumericalErrorType::InvalidValues);
1614 break;
1615 }
1616 }
1617
1618 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1620 if max_val > 1e100 {
1621 self.error_types.push(NumericalErrorType::Overflow);
1622 } else if max_val < 1e-100 && max_val > 0.0 {
1623 self.error_types.push(NumericalErrorType::Underflow);
1624 }
1625
1626 if self.digit_loss > 6.0 {
1628 self.error_types.push(NumericalErrorType::PrecisionLoss);
1629 }
1630
1631 if self.condition_number > 1e12 {
1633 self.error_types.push(NumericalErrorType::IllConditioned);
1634 }
1635 }
1636}
1637
1638impl Default for DynamicPrecisionConfig {
1639 fn default() -> Self {
1640 Self {
1641 strategy: ScalingStrategy::Balanced,
1642 min_precision: PrecisionMode::Int8Dynamic,
1643 max_precision: PrecisionMode::Full32,
1644 stability_threshold_up: 1e-6,
1645 stability_threshold_down: 1e-9,
1646 performance_weight: 0.6,
1647 accuracy_weight: 0.4,
1648 max_changes_per_operation: 3,
1649 change_cooldown: Duration::from_millis(100),
1650 }
1651 }
1652}
1653
1654impl NumericalStabilityMonitor {
1655 pub fn new(config: DynamicPrecisionConfig) -> Self {
1657 Self {
1658 current_metrics: StabilityMetrics::new(),
1659 stability_history: VecDeque::new(),
1660 precision_config: config,
1661 current_precision: PrecisionMode::Mixed16,
1662 precision_history: VecDeque::new(),
1663 recovery_attempts: 0,
1664 max_history_length: 1000,
1665 last_precision_change: None,
1666 }
1667 }
1668
1669 pub fn monitor_stability(
1671 &mut self,
1672 data: &Array2<f64>,
1673 computation_result: &Array2<f64>,
1674 ) -> SpatialResult<()> {
1675 self.current_metrics.condition_number =
1677 NumericalStabilityMonitor::estimate_condition_number(data);
1678
1679 self.current_metrics.relative_error =
1681 self.estimate_relative_error(data, computation_result);
1682
1683 self.current_metrics.forward_error = self.estimate_forward_error(data, computation_result);
1685 self.current_metrics.backward_error =
1686 self.estimate_backward_error(data, computation_result);
1687
1688 self.current_metrics.digit_loss = self.estimate_digit_loss();
1690
1691 self.current_metrics.update_stability_level();
1693
1694 self.current_metrics.detect_errors(computation_result);
1696
1697 self.current_metrics.timestamp = Instant::now();
1699
1700 self.stability_history
1702 .push_back(self.current_metrics.clone());
1703 if self.stability_history.len() > self.max_history_length {
1704 self.stability_history.pop_front();
1705 }
1706
1707 Ok(())
1708 }
1709
1710 pub fn adjust_precision(&mut self) -> SpatialResult<PrecisionMode> {
1712 if let Some(last_change) = self.last_precision_change {
1714 if last_change.elapsed() < self.precision_config.change_cooldown {
1715 return Ok(self.current_precision);
1716 }
1717 }
1718
1719 let new_precision = match self.current_metrics.stability_level {
1720 StabilityLevel::Critical => {
1721 self.precision_config.max_precision
1723 }
1724 StabilityLevel::Poor => {
1725 NumericalStabilityMonitor::increase_precision(self.current_precision)
1727 }
1728 StabilityLevel::Moderate => {
1729 if self.current_metrics.relative_error
1731 > self.precision_config.stability_threshold_up
1732 {
1733 NumericalStabilityMonitor::increase_precision(self.current_precision)
1734 } else {
1735 self.current_precision
1736 }
1737 }
1738 StabilityLevel::Good => {
1739 if self.current_metrics.relative_error
1741 < self.precision_config.stability_threshold_down
1742 {
1743 NumericalStabilityMonitor::decrease_precision(self.current_precision)
1744 } else {
1745 self.current_precision
1746 }
1747 }
1748 StabilityLevel::Excellent => {
1749 if self.precision_config.strategy == ScalingStrategy::Aggressive {
1751 self.precision_config.min_precision
1752 } else {
1753 NumericalStabilityMonitor::decrease_precision(self.current_precision)
1754 }
1755 }
1756 };
1757
1758 if new_precision != self.current_precision {
1760 self.precision_history.push_back((
1761 Instant::now(),
1762 new_precision,
1763 self.current_metrics.relative_error,
1764 ));
1765 self.current_precision = new_precision;
1766 self.last_precision_change = Some(Instant::now());
1767 }
1768
1769 Ok(new_precision)
1770 }
1771
1772 fn increase_precision(current: PrecisionMode) -> PrecisionMode {
1774 match current {
1775 PrecisionMode::Int4Advanced => PrecisionMode::Int8Dynamic,
1776 PrecisionMode::Int8Dynamic => PrecisionMode::Mixed16,
1777 PrecisionMode::Mixed16 => PrecisionMode::BrainFloat16,
1778 PrecisionMode::BrainFloat16 => PrecisionMode::Full32,
1779 PrecisionMode::Full32 => PrecisionMode::Full32, _ => PrecisionMode::Mixed16,
1781 }
1782 }
1783
1784 fn decrease_precision(current: PrecisionMode) -> PrecisionMode {
1786 match current {
1787 PrecisionMode::Full32 => PrecisionMode::BrainFloat16,
1788 PrecisionMode::BrainFloat16 => PrecisionMode::Mixed16,
1789 PrecisionMode::Mixed16 => PrecisionMode::Int8Dynamic,
1790 PrecisionMode::Int8Dynamic => PrecisionMode::Int4Advanced,
1791 PrecisionMode::Int4Advanced => PrecisionMode::Int4Advanced, _ => PrecisionMode::Mixed16,
1793 }
1794 }
1795
1796 fn estimate_condition_number(data: &Array2<f64>) -> f64 {
1798 let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1800 let min_val = data.fold(f64::INFINITY, |acc, &x| {
1801 if x.abs() > 1e-15 {
1802 acc.min(x.abs())
1803 } else {
1804 acc
1805 }
1806 });
1807
1808 if min_val.is_finite() && min_val > 0.0 {
1809 max_val / min_val
1810 } else {
1811 1e12 }
1813 }
1814
1815 fn estimate_relative_error(&mut self, input: &Array2<f64>, output: &Array2<f64>) -> f64 {
1817 let mean_val = output.mean().unwrap_or(0.0);
1819 if mean_val.abs() > 1e-15 {
1820 let machine_eps = match self.current_precision {
1822 PrecisionMode::Full32 => 2.22e-16,
1823 PrecisionMode::Mixed16 | PrecisionMode::BrainFloat16 => 9.77e-4,
1824 PrecisionMode::Int8Dynamic => 1.0 / 256.0,
1825 PrecisionMode::Int4Advanced => 1.0 / 16.0,
1826 _ => 1e-6,
1827 };
1828 machine_eps * self.current_metrics.condition_number
1829 } else {
1830 0.0
1831 }
1832 }
1833
1834 fn estimate_forward_error(&mut self, _input: &Array2<f64>, output: &Array2<f64>) -> f64 {
1836 self.current_metrics.relative_error * self.current_metrics.condition_number
1838 }
1839
1840 fn estimate_backward_error(&mut self, _input: &Array2<f64>, output: &Array2<f64>) -> f64 {
1842 self.current_metrics.relative_error
1844 }
1845
1846 fn estimate_digit_loss(&self) -> f64 {
1848 if self.current_metrics.condition_number > 1.0 {
1849 self.current_metrics.condition_number.log10().max(0.0)
1850 } else {
1851 0.0
1852 }
1853 }
1854}
1855
1856impl Default for ErrorRecoverySystem {
1857 fn default() -> Self {
1858 Self::new()
1859 }
1860}
1861
1862impl ErrorRecoverySystem {
1863 pub fn new() -> Self {
1865 let mut recovery_strategies = HashMap::new();
1866
1867 recovery_strategies.insert(
1869 NumericalErrorType::Overflow,
1870 vec![
1871 RecoveryAction::IncreasePrecision,
1872 RecoveryAction::ReduceTileSize,
1873 RecoveryAction::NumericalStabilization,
1874 ],
1875 );
1876 recovery_strategies.insert(
1877 NumericalErrorType::Underflow,
1878 vec![
1879 RecoveryAction::IncreasePrecision,
1880 RecoveryAction::NumericalStabilization,
1881 ],
1882 );
1883 recovery_strategies.insert(
1884 NumericalErrorType::PrecisionLoss,
1885 vec![
1886 RecoveryAction::IncreasePrecision,
1887 RecoveryAction::RetryWithNewParams,
1888 ],
1889 );
1890 recovery_strategies.insert(
1891 NumericalErrorType::IllConditioned,
1892 vec![
1893 RecoveryAction::IncreasePrecision,
1894 RecoveryAction::NumericalStabilization,
1895 RecoveryAction::SwitchToCPU,
1896 ],
1897 );
1898 recovery_strategies.insert(
1899 NumericalErrorType::InvalidValues,
1900 vec![
1901 RecoveryAction::FallbackAlgorithm,
1902 RecoveryAction::SwitchToCPU,
1903 ],
1904 );
1905
1906 Self {
1907 recovery_strategies,
1908 recovery_history: VecDeque::new(),
1909 max_recovery_attempts: 3,
1910 success_rates: HashMap::new(),
1911 }
1912 }
1913
1914 pub async fn attempt_recovery(
1916 &mut self,
1917 error_type: NumericalErrorType,
1918 ) -> SpatialResult<RecoveryAction> {
1919 let start_time = Instant::now();
1920
1921 let strategies = self
1923 .recovery_strategies
1924 .get(&error_type)
1925 .ok_or_else(|| SpatialError::InvalidInput("Unknown error _type".to_string()))?
1926 .clone(); let best_action = self.choose_best_recovery_action(&strategies);
1930
1931 let attempt = RecoveryAttempt {
1933 error_type,
1934 action: best_action,
1935 success: false, duration: start_time.elapsed(),
1937 post_recovery_metrics: None,
1938 timestamp: start_time,
1939 };
1940
1941 self.recovery_history.push_back(attempt);
1942
1943 Ok(best_action)
1944 }
1945
1946 fn choose_best_recovery_action(&mut self, strategies: &[RecoveryAction]) -> RecoveryAction {
1948 strategies
1949 .iter()
1950 .max_by(|&a, &b| {
1951 let rate_a = self.success_rates.get(a).unwrap_or(&0.5);
1952 let rate_b = self.success_rates.get(b).unwrap_or(&0.5);
1953 rate_a
1954 .partial_cmp(rate_b)
1955 .unwrap_or(std::cmp::Ordering::Equal)
1956 })
1957 .copied()
1958 .unwrap_or(RecoveryAction::IncreasePrecision)
1959 }
1960
1961 pub fn update_success_rate(&mut self, action: RecoveryAction, success: bool) {
1963 let current_rate = self.success_rates.get(&action).unwrap_or(&0.5);
1964 let new_rate = if success {
1965 current_rate * 0.9 + 0.1 } else {
1967 current_rate * 0.9
1968 };
1969 self.success_rates.insert(action, new_rate);
1970 }
1971}
1972
1973impl PerformanceAccuracyAnalyzer {
1974 pub fn new(params: TradeOffParams) -> Self {
1976 Self {
1977 performance_data: HashMap::new(),
1978 accuracy_data: HashMap::new(),
1979 optimization_params: params,
1980 pareto_frontier: Vec::new(),
1981 }
1982 }
1983
1984 pub fn record_performance(&mut self, precision: PrecisionMode, duration: Duration) {
1986 self.performance_data
1987 .entry(precision)
1988 .or_default()
1989 .push_back(duration);
1990
1991 if let Some(history) = self.performance_data.get_mut(&precision) {
1993 if history.len() > 100 {
1994 history.pop_front();
1995 }
1996 }
1997 }
1998
1999 pub fn record_accuracy(&mut self, precision: PrecisionMode, accuracy: f64) {
2001 self.accuracy_data
2002 .entry(precision)
2003 .or_default()
2004 .push_back(accuracy);
2005
2006 if let Some(history) = self.accuracy_data.get_mut(&precision) {
2008 if history.len() > 100 {
2009 history.pop_front();
2010 }
2011 }
2012 }
2013
2014 pub fn optimize_precision(&mut self) -> PrecisionMode {
2016 self.update_pareto_frontier();
2017
2018 match self.optimization_params.objective {
2019 OptimizationObjective::MaxPerformance => self
2020 .pareto_frontier
2021 .iter()
2022 .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
2023 .map(|(_a, b, mode)| *mode)
2024 .unwrap_or(PrecisionMode::Mixed16),
2025 OptimizationObjective::MaxAccuracy => self
2026 .pareto_frontier
2027 .iter()
2028 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
2029 .map(|(_a, b, mode)| *mode)
2030 .unwrap_or(PrecisionMode::Full32),
2031 OptimizationObjective::Balanced => {
2032 let mut best_score = f64::NEG_INFINITY;
2034 let mut best_mode = PrecisionMode::Mixed16;
2035
2036 let performance_weight = self.optimization_params.performance_weight;
2038 let accuracy_weight = self.optimization_params.accuracy_weight;
2039
2040 for &(perf, acc, mode) in &self.pareto_frontier {
2041 let perf_score = 1.0 / (perf + 1e-9);
2043 let score = performance_weight * perf_score + accuracy_weight * acc;
2044 if score > best_score {
2045 best_score = score;
2046 best_mode = mode;
2047 }
2048 }
2049
2050 best_mode
2051 }
2052 _ => PrecisionMode::Mixed16,
2053 }
2054 }
2055
2056 fn update_pareto_frontier(&mut self) {
2058 self.pareto_frontier.clear();
2059
2060 for precision in [
2061 PrecisionMode::Full32,
2062 PrecisionMode::BrainFloat16,
2063 PrecisionMode::Mixed16,
2064 PrecisionMode::Int8Dynamic,
2065 PrecisionMode::Int4Advanced,
2066 ] {
2067 if let (Some(perf_data), Some(acc_data)) = (
2068 self.performance_data.get(&precision),
2069 self.accuracy_data.get(&precision),
2070 ) {
2071 if !perf_data.is_empty() && !acc_data.is_empty() {
2072 let avg_perf = perf_data.iter().map(|d| d.as_secs_f64()).sum::<f64>()
2073 / perf_data.len() as f64;
2074 let avg_acc = acc_data.iter().sum::<f64>() / acc_data.len() as f64;
2075
2076 self.pareto_frontier.push((avg_perf, avg_acc, precision));
2077 }
2078 }
2079 }
2080 }
2081
2082 #[allow(dead_code)]
2084 fn compute_weighted_score(&mut self, performance: f64, accuracy: f64) -> f64 {
2085 let perf_score = 1.0 / (performance + 1e-9);
2087
2088 self.optimization_params.performance_weight * perf_score
2090 + self.optimization_params.accuracy_weight * accuracy
2091 }
2092}
2093
2094impl AdvancedTensorCoreDistanceMatrix {
2095 pub fn new() -> SpatialResult<Self> {
2097 let base_computer = TensorCoreDistanceMatrix::new()?;
2098 let precision_config = DynamicPrecisionConfig::default();
2099 let stability_monitor =
2100 Arc::new(Mutex::new(NumericalStabilityMonitor::new(precision_config)));
2101 let recovery_system = ErrorRecoverySystem::new();
2102 let trade_off_params = TradeOffParams {
2103 performance_weight: 0.6,
2104 accuracy_weight: 0.4,
2105 energy_weight: 0.0,
2106 min_accuracy: 0.95,
2107 max_time: Duration::from_secs(30),
2108 objective: OptimizationObjective::Balanced,
2109 };
2110 let performance_analyzer = PerformanceAccuracyAnalyzer::new(trade_off_params);
2111
2112 Ok(Self {
2113 base_computer,
2114 stability_monitor,
2115 recovery_system,
2116 performance_analyzer,
2117 dynamic_precision_enabled: true,
2118 auto_recovery_enabled: true,
2119 })
2120 }
2121
2122 pub fn with_dynamic_precision(mut self, enabled: bool) -> Self {
2124 self.dynamic_precision_enabled = enabled;
2125 self
2126 }
2127
2128 pub fn with_auto_recovery(mut self, enabled: bool) -> Self {
2130 self.auto_recovery_enabled = enabled;
2131 self
2132 }
2133
2134 pub async fn compute_with_stability_monitoring(
2136 &mut self,
2137 points: &ArrayView2<'_, f64>,
2138 ) -> SpatialResult<Array2<f64>> {
2139 let start_time = Instant::now();
2140
2141 {
2143 let mut monitor = self.stability_monitor.lock().unwrap();
2144 if self.dynamic_precision_enabled {
2147 let optimal_precision = monitor.adjust_precision()?;
2148 self.base_computer.precision_mode = optimal_precision;
2149 }
2150 }
2151
2152 let mut result = None;
2153 let mut recovery_attempts = 0;
2154 let max_attempts = 3;
2155
2156 while result.is_none() && recovery_attempts < max_attempts {
2157 match self.base_computer.compute_parallel(points).await {
2158 Ok(distances) => {
2159 {
2161 let mut monitor = self.stability_monitor.lock().unwrap();
2162 monitor.monitor_stability(&points.to_owned(), &distances)?;
2163 }
2164
2165 let stability_level = {
2167 let monitor = self.stability_monitor.lock().unwrap();
2168 monitor.current_metrics.stability_level
2169 };
2170
2171 if stability_level == StabilityLevel::Critical && self.auto_recovery_enabled {
2172 recovery_attempts += 1;
2174 let recovery_action = self
2175 .recovery_system
2176 .attempt_recovery(NumericalErrorType::IllConditioned)
2177 .await?;
2178
2179 self.apply_recovery_action(recovery_action).await?;
2181 continue;
2182 } else {
2183 result = Some(distances);
2184 }
2185 }
2186 Err(e) => {
2187 if self.auto_recovery_enabled && recovery_attempts < max_attempts {
2188 recovery_attempts += 1;
2189 let recovery_action = self
2190 .recovery_system
2191 .attempt_recovery(NumericalErrorType::InvalidValues)
2192 .await?;
2193 self.apply_recovery_action(recovery_action).await?;
2194 continue;
2195 } else {
2196 return Err(e);
2197 }
2198 }
2199 }
2200 }
2201
2202 let final_result = result.ok_or_else(|| {
2203 SpatialError::InvalidInput(
2204 "Failed to compute stable result after recovery attempts".to_string(),
2205 )
2206 })?;
2207
2208 let duration = start_time.elapsed();
2210 let precision = self.base_computer.precision_mode;
2211 self.performance_analyzer
2212 .record_performance(precision, duration);
2213
2214 let accuracy = self.estimate_result_accuracy(&final_result);
2216 self.performance_analyzer
2217 .record_accuracy(precision, accuracy);
2218
2219 Ok(final_result)
2220 }
2221
2222 async fn apply_recovery_action(&mut self, action: RecoveryAction) -> SpatialResult<()> {
2224 match action {
2225 RecoveryAction::IncreasePrecision => {
2226 self.base_computer.precision_mode = match self.base_computer.precision_mode {
2227 PrecisionMode::Int4Advanced => PrecisionMode::Int8Dynamic,
2228 PrecisionMode::Int8Dynamic => PrecisionMode::Mixed16,
2229 PrecisionMode::Mixed16 => PrecisionMode::BrainFloat16,
2230 PrecisionMode::BrainFloat16 => PrecisionMode::Full32,
2231 PrecisionMode::Full32 => PrecisionMode::Full32,
2232 _ => PrecisionMode::Mixed16,
2233 };
2234 }
2235 RecoveryAction::ReduceTileSize => {
2236 let (current_row, current_col) = self.base_computer.tile_size;
2237 self.base_computer.tile_size = (current_row / 2, current_col / 2);
2238 if self.base_computer.tile_size.0 < 16 {
2239 self.base_computer.tile_size = (16, 16);
2240 }
2241 }
2242 RecoveryAction::FallbackAlgorithm => {
2243 self.base_computer.precision_mode = PrecisionMode::Full32;
2245 self.base_computer.hierarchical_tiling = false;
2246 }
2247 RecoveryAction::NumericalStabilization => {
2248 self.base_computer.precision_mode = PrecisionMode::Full32;
2250 self.base_computer.tile_size = (64, 64);
2251 }
2252 _ => {
2253 self.base_computer.precision_mode = PrecisionMode::Full32;
2255 }
2256 }
2257
2258 Ok(())
2259 }
2260
2261 fn estimate_result_accuracy(&self, result: &Array2<f64>) -> f64 {
2263 let has_invalid = result.iter().any(|&x| !x.is_finite());
2265 if has_invalid {
2266 return 0.0;
2267 }
2268
2269 let max_val = result.fold(0.0f64, |acc, &x| acc.max(x.abs()));
2270 let min_val = result.fold(f64::INFINITY, |acc, &x| {
2271 if x.abs() > 1e-15 {
2272 acc.min(x.abs())
2273 } else {
2274 acc
2275 }
2276 });
2277
2278 if min_val.is_finite() && min_val > 0.0 {
2279 let dynamic_range = max_val / min_val;
2280 (1.0 / (1.0 + dynamic_range.log10() / 10.0)).clamp(0.0, 1.0)
2281 } else {
2282 0.95 }
2284 }
2285}
2286
2287#[allow(dead_code)]
2289pub fn detect_tensor_core_capabilities() -> SpatialResult<TensorCoreCapabilities> {
2290 Ok(TensorCoreCapabilities {
2294 tensor_core_types: vec![
2295 TensorCoreType::NvidiaTensorCore,
2296 TensorCoreType::StandardCores,
2297 ],
2298 supported_precisions: vec![
2299 PrecisionMode::Full32,
2300 PrecisionMode::Mixed16,
2301 PrecisionMode::BrainFloat16,
2302 PrecisionMode::Int8Dynamic,
2303 ],
2304 max_tensor_size: (4096, 4096, 4096),
2305 peak_throughput_tops: 312.0, memory_bandwidth_gbps: 1555.0, l2_cache_mb: 40.0,
2308 num_sms: 108,
2309 architecture: GpuArchitecture::Ampere,
2310 })
2311}
2312
2313impl TensorCoreDistanceMatrix {
2315 pub async fn compute_distances_to_centroids(
2317 &self,
2318 points: &ArrayView2<'_, f64>,
2319 centroids: &ArrayView2<'_, f64>,
2320 ) -> SpatialResult<Array2<f64>> {
2321 let (npoints, ndims) = points.dim();
2322 let (n_clusters_, n_dims_c) = centroids.dim();
2323 let mut distances = Array2::zeros((npoints, n_clusters_));
2324
2325 for i in 0..npoints {
2327 for j in 0..n_clusters_ {
2328 let distance: f64 = points
2329 .row(i)
2330 .iter()
2331 .zip(centroids.row(j).iter())
2332 .map(|(&a, &b)| (a - b).powi(2))
2333 .sum::<f64>()
2334 .sqrt();
2335 distances[[i, j]] = distance;
2336 }
2337 }
2338
2339 Ok(distances)
2340 }
2341}
2342
2343#[cfg(test)]
2344mod tests {
2345 use super::*;
2346 use ndarray::array;
2347
/// `PrecisionMode` values compare equal exactly when the variants match.
#[test]
fn test_precision_mode() {
    let mixed = PrecisionMode::Mixed16;
    assert_eq!(mixed, PrecisionMode::Mixed16);
    assert_ne!(mixed, PrecisionMode::Full32);
}
2353
/// Capability detection succeeds and reports at least one core type and
/// one supported precision.
#[test]
fn test_tensor_core_capabilities() {
    let detection = detect_tensor_core_capabilities();
    assert!(detection.is_ok());

    let TensorCoreCapabilities {
        tensor_core_types,
        supported_precisions,
        ..
    } = detection.unwrap();
    assert!(!tensor_core_types.is_empty());
    assert!(!supported_precisions.is_empty());
}
2363
/// A freshly created distance-matrix computer starts in `Mixed16` mode.
#[test]
fn test_tensor_core_distance_matrix_creation() {
    let created = TensorCoreDistanceMatrix::new();
    assert!(created.is_ok());

    let computer = created.unwrap();
    assert_eq!(computer.precision_mode, PrecisionMode::Mixed16);
}
2372
/// The clustering constructor stores the requested number of clusters.
#[test]
fn test_tensor_core_clustering_creation() {
    let created = TensorCoreClustering::new(3);
    assert!(created.is_ok());

    let model = created.unwrap();
    assert_eq!(model._numclusters, 3);
}
2381
2382 #[tokio::test]
2383 async fn test_tensor_core_distance_computation() {
2384 let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
2385 let mut matrix_computer = TensorCoreDistanceMatrix::new().unwrap();
2386
2387 let result = matrix_computer.compute_parallel(&points.view()).await;
2388 assert!(result.is_ok());
2389
2390 let distances = result.unwrap();
2391 assert_eq!(distances.shape(), &[3, 3]);
2392
2393 for i in 0..3 {
2395 assert!((distances[[i, i]]).abs() < 1e-10);
2396 }
2397 }
2398
2399 #[tokio::test]
2400 async fn test_tensor_core_clustering() {
2401 let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
2402 let mut clustering = TensorCoreClustering::new(2).unwrap();
2403
2404 let result = clustering.fit(&points.view()).await;
2405 assert!(result.is_ok());
2406
2407 let (centroids, assignments) = result.unwrap();
2408 assert_eq!(centroids.shape(), &[2, 2]);
2409 assert_eq!(assignments.len(), 4);
2410 }
2411
/// New metrics describe a perfectly stable state with no recorded errors.
#[test]
fn test_stability_metrics_creation() {
    let fresh = StabilityMetrics::new();

    assert_eq!(fresh.condition_number, 1.0);
    assert_eq!(fresh.relative_error, 0.0);
    assert_eq!(fresh.stability_level, StabilityLevel::Excellent);
    assert!(fresh.error_types.is_empty());
}
2420
/// The stability level follows the condition number and relative error.
#[test]
fn test_stability_level_update() {
    let mut tracked = StabilityMetrics::new();

    // A huge condition number alone is enough for `Critical`.
    tracked.condition_number = 1e15;
    tracked.update_stability_level();
    assert_eq!(tracked.stability_level, StabilityLevel::Critical);

    // Both values inside the `Poor` band.
    tracked.condition_number = 1e9;
    tracked.relative_error = 1e-7;
    tracked.update_stability_level();
    assert_eq!(tracked.stability_level, StabilityLevel::Poor);

    // Both values inside the `Good` band.
    tracked.condition_number = 1e3;
    tracked.relative_error = 1e-10;
    tracked.update_stability_level();
    assert_eq!(tracked.stability_level, StabilityLevel::Good);
}
2442
/// `detect_errors` is expected to flag NaN, overflow-sized and
/// underflow-sized entries.
#[test]
#[ignore]
fn test_error_detection() {
    let mut tracked = StabilityMetrics::new();

    // NaN entry -> InvalidValues.
    let nan_case = array![[1.0, 2.0], [f64::NAN, 4.0]];
    tracked.detect_errors(&nan_case);
    assert!(tracked
        .error_types
        .contains(&NumericalErrorType::InvalidValues));

    // Very large entry -> Overflow.
    let overflow_case = array![[1e150, 2.0], [3.0, 4.0]];
    tracked.detect_errors(&overflow_case);
    assert!(tracked.error_types.contains(&NumericalErrorType::Overflow));

    // Very small non-zero entry -> Underflow.
    let underflow_case = array![[1e-150, 2.0], [3.0, 4.0]];
    tracked.detect_errors(&underflow_case);
    assert!(tracked.error_types.contains(&NumericalErrorType::Underflow));
}
2465
/// The default dynamic-precision configuration carries the expected values.
#[test]
fn test_dynamic_precision_config() {
    let defaults = DynamicPrecisionConfig::default();

    assert_eq!(defaults.strategy, ScalingStrategy::Balanced);
    assert_eq!(defaults.min_precision, PrecisionMode::Int8Dynamic);
    assert_eq!(defaults.max_precision, PrecisionMode::Full32);
    assert_eq!(defaults.performance_weight, 0.6);
    assert_eq!(defaults.accuracy_weight, 0.4);
}
2475
/// A new monitor starts at `Mixed16` with no history and no recovery attempts.
#[test]
fn test_numerical_stability_monitor_creation() {
    let watcher = NumericalStabilityMonitor::new(DynamicPrecisionConfig::default());

    assert_eq!(watcher.current_precision, PrecisionMode::Mixed16);
    assert!(watcher.stability_history.is_empty());
    assert_eq!(watcher.recovery_attempts, 0);
}
2485
/// Stepping precision up or down moves one rung and saturates at the ends.
#[test]
fn test_precision_increase_decrease() {
    let config = DynamicPrecisionConfig::default();
    let _monitor = NumericalStabilityMonitor::new(config);

    // One step up from Int8Dynamic.
    let stepped_up = NumericalStabilityMonitor::increase_precision(PrecisionMode::Int8Dynamic);
    assert_eq!(stepped_up, PrecisionMode::Mixed16);

    // Full32 is already the top rung.
    let at_top = NumericalStabilityMonitor::increase_precision(PrecisionMode::Full32);
    assert_eq!(at_top, PrecisionMode::Full32);

    // One step down from Mixed16.
    let stepped_down = NumericalStabilityMonitor::decrease_precision(PrecisionMode::Mixed16);
    assert_eq!(stepped_down, PrecisionMode::Int8Dynamic);

    // Int4Advanced is already the bottom rung.
    let at_bottom = NumericalStabilityMonitor::decrease_precision(PrecisionMode::Int4Advanced);
    assert_eq!(at_bottom, PrecisionMode::Int4Advanced);
}
2506
/// The condition-number estimate is small for similar magnitudes and large
/// for widely spread magnitudes.
#[test]
fn test_condition_number_estimation() {
    let config = DynamicPrecisionConfig::default();
    let _monitor = NumericalStabilityMonitor::new(config);

    let balanced = array![[1.0, 2.0], [3.0, 4.0]];
    let balanced_estimate = NumericalStabilityMonitor::estimate_condition_number(&balanced);
    assert!(balanced_estimate > 1.0 && balanced_estimate < 100.0);

    let spread = array![[1e-10, 2.0], [3.0, 1e10]];
    let spread_estimate = NumericalStabilityMonitor::estimate_condition_number(&spread);
    assert!(spread_estimate > 1e15);
}
2522
/// A new recovery system has strategies registered and allows 3 attempts.
#[test]
fn test_error_recovery_system_creation() {
    let system = ErrorRecoverySystem::new();
    let table = &system.recovery_strategies;

    assert!(!table.is_empty());
    assert!(table.contains_key(&NumericalErrorType::Overflow));
    assert!(table.contains_key(&NumericalErrorType::IllConditioned));
    assert_eq!(system.max_recovery_attempts, 3);
}
2537
2538 #[tokio::test]
2539 async fn test_recovery_action_selection() {
2540 let mut recovery_system = ErrorRecoverySystem::new();
2541
2542 let action = recovery_system
2543 .attempt_recovery(NumericalErrorType::Overflow)
2544 .await;
2545 assert!(action.is_ok());
2546
2547 let recovery_action = action.unwrap();
2548 assert!(matches!(
2549 recovery_action,
2550 RecoveryAction::IncreasePrecision
2551 | RecoveryAction::ReduceTileSize
2552 | RecoveryAction::NumericalStabilization
2553 ));
2554 }
2555
/// A success raises an action's rate above the 0.5 default; a failure
/// lowers it below.
#[test]
fn test_success_rate_update() {
    let mut system = ErrorRecoverySystem::new();

    system.update_success_rate(RecoveryAction::IncreasePrecision, true);
    let after_success = system
        .success_rates
        .get(&RecoveryAction::IncreasePrecision);
    assert!(after_success.is_some());
    assert!(*after_success.unwrap() > 0.5);

    system.update_success_rate(RecoveryAction::ReduceTileSize, false);
    let after_failure = system
        .success_rates
        .get(&RecoveryAction::ReduceTileSize);
    assert!(after_failure.is_some());
    assert!(*after_failure.unwrap() < 0.5);
}
2576
/// With samples for two modes, the balanced optimizer returns one of them.
#[test]
fn test_performance_accuracy_analyzer() {
    let mut analyzer = PerformanceAccuracyAnalyzer::new(TradeOffParams {
        performance_weight: 0.7,
        accuracy_weight: 0.3,
        energy_weight: 0.0,
        min_accuracy: 0.9,
        max_time: Duration::from_secs(10),
        objective: OptimizationObjective::Balanced,
    });

    // Timing samples.
    analyzer.record_performance(PrecisionMode::Mixed16, Duration::from_millis(100));
    analyzer.record_performance(PrecisionMode::Full32, Duration::from_millis(200));

    // Accuracy samples.
    analyzer.record_accuracy(PrecisionMode::Mixed16, 0.95);
    analyzer.record_accuracy(PrecisionMode::Full32, 0.99);

    let chosen = analyzer.optimize_precision();
    assert!(matches!(
        chosen,
        PrecisionMode::Mixed16 | PrecisionMode::Full32
    ));
}
2605
/// Every mode with both timing and accuracy samples gets one frontier entry.
#[test]
fn test_pareto_frontier_update() {
    let mut analyzer = PerformanceAccuracyAnalyzer::new(TradeOffParams {
        performance_weight: 0.5,
        accuracy_weight: 0.5,
        energy_weight: 0.0,
        min_accuracy: 0.8,
        max_time: Duration::from_secs(5),
        objective: OptimizationObjective::Balanced,
    });

    analyzer.record_performance(PrecisionMode::Int8Dynamic, Duration::from_millis(50));
    analyzer.record_accuracy(PrecisionMode::Int8Dynamic, 0.85);

    analyzer.record_performance(PrecisionMode::Mixed16, Duration::from_millis(100));
    analyzer.record_accuracy(PrecisionMode::Mixed16, 0.95);

    analyzer.record_performance(PrecisionMode::Full32, Duration::from_millis(200));
    analyzer.record_accuracy(PrecisionMode::Full32, 0.99);

    analyzer.update_pareto_frontier();

    let frontier = &analyzer.pareto_frontier;
    assert!(!frontier.is_empty());
    assert_eq!(frontier.len(), 3);
}
2633
/// Checks that `compute_weighted_score` yields a strictly positive score for
/// two sample input pairs under a `Custom` objective.
#[test]
fn test_weighted_score_computation() {
    let trade_off = TradeOffParams {
        performance_weight: 0.6,
        accuracy_weight: 0.4,
        energy_weight: 0.0,
        min_accuracy: 0.8,
        max_time: Duration::from_secs(5),
        objective: OptimizationObjective::Custom,
    };
    let mut analyzer = PerformanceAccuracyAnalyzer::new(trade_off);

    // Both scores are computed up front, then validated in the same order.
    let first = analyzer.compute_weighted_score(0.1, 0.9);
    let second = analyzer.compute_weighted_score(0.2, 0.95);
    for score in [first, second] {
        assert!(score > 0.0);
    }
}
2654
/// Checks that `AdvancedTensorCoreDistanceMatrix::new` succeeds and that the
/// new instance has dynamic precision and automatic recovery switched on.
#[test]
fn test_advanced_tensor_core_distance_matrix_creation() {
    let creation = AdvancedTensorCoreDistanceMatrix::new();
    assert!(creation.is_ok());

    // Construction succeeded, so the flags on the new instance can be inspected.
    let computer = creation.unwrap();
    assert!(computer.dynamic_precision_enabled);
    assert!(computer.auto_recovery_enabled);
}
2664
2665 #[tokio::test]
2666 #[ignore]
2667 async fn test_stability_monitoring_computation() {
2668 let mut advanced_computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();
2669 let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
2670
2671 let result = advanced_computer
2672 .compute_with_stability_monitoring(&points.view())
2673 .await;
2674 assert!(result.is_ok());
2675
2676 let distances = result.unwrap();
2677 assert_eq!(distances.shape(), &[3, 3]);
2678
2679 let monitor = advanced_computer.stability_monitor.lock().unwrap();
2681 assert!(!monitor.stability_history.is_empty());
2682 }
2683
2684 #[tokio::test]
2685 async fn test_recovery_action_application() {
2686 let mut advanced_computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();
2687 let original_precision = advanced_computer.base_computer.precision_mode;
2688
2689 let result = advanced_computer
2691 .apply_recovery_action(RecoveryAction::IncreasePrecision)
2692 .await;
2693 assert!(result.is_ok());
2694
2695 if original_precision != PrecisionMode::Full32 {
2697 assert_ne!(
2698 advanced_computer.base_computer.precision_mode,
2699 original_precision
2700 );
2701 }
2702
2703 let original_tile_size = advanced_computer.base_computer.tile_size;
2705 let result = advanced_computer
2706 .apply_recovery_action(RecoveryAction::ReduceTileSize)
2707 .await;
2708 assert!(result.is_ok());
2709
2710 let new_tile_size = advanced_computer.base_computer.tile_size;
2711 assert!(new_tile_size.0 <= original_tile_size.0);
2712 assert!(new_tile_size.1 <= original_tile_size.1);
2713 }
2714
/// Checks `estimate_result_accuracy` against three 2x2 inputs: a clean matrix
/// (score in `(0.8, 1.0]`), a matrix containing NaN (score exactly `0.0`), and
/// a matrix with a very wide value range (score strictly between 0 and 1).
#[test]
fn test_result_accuracy_estimation() {
    let computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();

    // Well-behaved values.
    let clean = array![[0.0, 1.0], [1.0, 0.0]];
    let clean_score = computer.estimate_result_accuracy(&clean);
    assert!(clean_score > 0.8 && clean_score <= 1.0);

    // A single NaN entry must drive the score to exactly zero.
    let with_nan = array![[0.0, f64::NAN], [1.0, 0.0]];
    let nan_score = computer.estimate_result_accuracy(&with_nan);
    assert_eq!(nan_score, 0.0);

    // Finite values spanning twenty orders of magnitude.
    let wide_range = array![[1e-10, 1e10], [1e5, 1e-5]];
    let wide_score = computer.estimate_result_accuracy(&wide_range);
    assert!(wide_score > 0.0 && wide_score < 1.0);
}
2734
/// Verifies the derived `PartialEq`/`Eq`/`Hash` behavior of `PrecisionMode`.
///
/// Every variant must compare equal to itself and unequal to every other
/// variant, and all variants must stay distinct when used as hash-set keys.
/// The comparisons go through `==` on runtime values so the derived impls are
/// actually exercised; a `matches!` of a constant against itself can never
/// fail and therefore checks nothing.
#[test]
fn test_precision_mode_ordering() {
    let modes = [
        PrecisionMode::Full32,
        PrecisionMode::Mixed16,
        PrecisionMode::BrainFloat16,
        PrecisionMode::Int8Dynamic,
        PrecisionMode::Int4Advanced,
        PrecisionMode::Adaptive,
        PrecisionMode::AdvancedAdaptive,
    ];

    // Two entries are equal exactly when they sit at the same index.
    for (i, lhs) in modes.iter().enumerate() {
        for (j, rhs) in modes.iter().enumerate() {
            assert_eq!(lhs == rhs, i == j, "{:?} vs {:?}", lhs, rhs);
        }
    }

    // Hashing must not collapse any two variants into one key.
    let unique: std::collections::HashSet<PrecisionMode> = modes.iter().copied().collect();
    assert_eq!(unique.len(), modes.len());

    // The two adaptive modes are the pair most easily confused by name.
    assert_ne!(PrecisionMode::AdvancedAdaptive, PrecisionMode::Adaptive);
}
2744
/// Verifies the derived `PartialEq` behavior of `StabilityLevel`.
///
/// Every level must compare equal to itself and unequal to every other level.
/// The comparisons go through `==` on runtime values so the derived impl is
/// actually exercised; a `matches!` of a constant against itself can never
/// fail and therefore checks nothing.
#[test]
fn test_stability_levels() {
    let levels = [
        StabilityLevel::Excellent,
        StabilityLevel::Good,
        StabilityLevel::Moderate,
        StabilityLevel::Poor,
        StabilityLevel::Critical,
    ];

    // Two entries are equal exactly when they sit at the same index.
    for (i, lhs) in levels.iter().enumerate() {
        for (j, rhs) in levels.iter().enumerate() {
            assert_eq!(lhs == rhs, i == j, "{:?} vs {:?}", lhs, rhs);
        }
    }

    // The two extremes of the scale must never be conflated.
    assert_ne!(StabilityLevel::Critical, StabilityLevel::Excellent);
}
2750
/// Smoke test: builds an array of six `NumericalErrorType` variants and
/// checks its length and that `Overflow` can be found in it by equality.
#[test]
fn test_error_types() {
    let kinds = [
        NumericalErrorType::Overflow,
        NumericalErrorType::Underflow,
        NumericalErrorType::PrecisionLoss,
        NumericalErrorType::ConvergenceFailure,
        NumericalErrorType::IllConditioned,
        NumericalErrorType::InvalidValues,
    ];

    assert_eq!(kinds.len(), 6);

    // Membership lookup goes through the derived `PartialEq`.
    let has_overflow = kinds
        .iter()
        .any(|kind| *kind == NumericalErrorType::Overflow);
    assert!(has_overflow);
}
2765
/// Smoke test: builds an array of four `ScalingStrategy` variants and checks
/// its length and that `Balanced` can be found in it by equality.
#[test]
fn test_scaling_strategies() {
    let options = [
        ScalingStrategy::Conservative,
        ScalingStrategy::Balanced,
        ScalingStrategy::Aggressive,
        ScalingStrategy::Custom,
    ];

    assert_eq!(options.len(), 4);

    // Membership lookup goes through the derived `PartialEq`.
    let has_balanced = options
        .iter()
        .any(|option| *option == ScalingStrategy::Balanced);
    assert!(has_balanced);
}
2778
/// Smoke test: builds an array of six `RecoveryAction` variants and checks
/// its length and that `IncreasePrecision` can be found in it by equality.
#[test]
fn test_recovery_actions() {
    let remedies = [
        RecoveryAction::IncreasePrecision,
        RecoveryAction::ReduceTileSize,
        RecoveryAction::FallbackAlgorithm,
        RecoveryAction::NumericalStabilization,
        RecoveryAction::RetryWithNewParams,
        RecoveryAction::SwitchToCPU,
    ];

    assert_eq!(remedies.len(), 6);

    // Membership lookup goes through the type's equality implementation.
    let has_increase_precision = remedies
        .iter()
        .any(|remedy| *remedy == RecoveryAction::IncreasePrecision);
    assert!(has_increase_precision);
}
2793
/// Smoke test: builds an array of five `OptimizationObjective` variants and
/// checks its length and that `Balanced` can be found in it by equality.
#[test]
fn test_optimization_objectives() {
    let goals = [
        OptimizationObjective::MaxPerformance,
        OptimizationObjective::MaxAccuracy,
        OptimizationObjective::Balanced,
        OptimizationObjective::MinEnergy,
        OptimizationObjective::Custom,
    ];

    assert_eq!(goals.len(), 5);

    // Membership lookup goes through the type's equality implementation.
    let has_balanced = goals
        .iter()
        .any(|goal| *goal == OptimizationObjective::Balanced);
    assert!(has_balanced);
}
2807}