scirs2_spatial/
tensor_cores.rs

1//! Advanced GPU Tensor Core utilization for spatial algorithms
2//!
3//! This module provides cutting-edge implementations that leverage modern GPU tensor cores
4//! (NVIDIA's Tensor Cores, AMD's Matrix Cores, Intel's XMX units) for maximum performance
5//! in spatial computing. It includes mixed-precision operations, automatic layout optimization,
6//! and hardware-specific kernel selection for optimal throughput.
7//!
8//! # Features
9//!
10//! - **Tensor Core acceleration** for matrix operations in spatial algorithms
11//! - **Mixed-precision computing** (FP16, BF16, INT8, INT4) for maximum throughput
12//! - **Automatic tensor layout optimization** for memory coalescing
13//! - **Hierarchical tiling strategies** for large datasets
14//! - **Multi-GPU tensor parallelism** for distributed spatial computation
15//! - **Dynamic precision selection** based on numerical stability requirements
16//! - **Fused kernel operations** to minimize memory bandwidth
17//! - **Async execution pipelines** for maximum GPU utilization
18//!
19//! # Supported Hardware
20//!
21//! - **NVIDIA**: V100, A100, H100, RTX 30/40 series (Tensor Cores)
22//! - **AMD**: MI250X, MI300 series (Matrix Cores)
23//! - **Intel**: Ponte Vecchio, Arc GPUs (XMX units)
24//! - **Automatic fallback** to standard compute units when tensor cores unavailable
25//!
26//! # Examples
27//!
28//! ```
29//! use scirs2_spatial::tensor_cores::{TensorCoreDistanceMatrix, TensorCoreClustering, PrecisionMode};
30//! use ndarray::array;
31//!
32//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
33//! // Tensor core distance matrix computation
34//! let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
35//!
36//! let mut tensor_matrix = TensorCoreDistanceMatrix::new()?
37//!     .with_precision_mode(PrecisionMode::Mixed16)
38//!     .with_tensor_layout_optimization(true)
39//!     .with_hierarchical_tiling(true);
40//!
41//! let distances = tensor_matrix.compute_parallel(&points.view()).await?;
42//! println!("Tensor core distance matrix: {:?}", distances);
43//!
44//! // Tensor core k-means clustering
45//! let mut tensor_kmeans = TensorCoreClustering::new(2)?
46//!     .with_tensor_cores(true)
47//!     .with_mixed_precision(true)
48//!     .with_dynamic_precision_scaling(true);
49//!
50//! let (centroids, assignments) = tensor_kmeans.fit(&points.view()).await?;
51//! println!("Tensor core centroids: {:?}", centroids);
52//! # Ok(())
53//! # }
54//! ```
55
56use crate::error::{SpatialError, SpatialResult};
57use ndarray::{s, Array1, Array2, ArrayView2};
58use rand::Rng;
59use statrs::statistics::Statistics;
60use std::collections::{HashMap, VecDeque};
61use std::sync::{Arc, Mutex};
62use std::time::{Duration, Instant};
63
/// Numeric precision used for tensor core distance computations.
///
/// The per-tile dispatcher of [`TensorCoreDistanceMatrix`] selects one distance
/// kernel per variant; `Adaptive` and `AdvancedAdaptive` are both routed to the
/// same adaptive kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrecisionMode {
    /// Full precision (FP32)
    Full32,
    /// Mixed precision (FP16 compute, FP32 accumulate)
    Mixed16,
    /// Brain floating point (BF16)
    BrainFloat16,
    /// 8-bit integer with dynamic scaling
    Int8Dynamic,
    /// 4-bit integer with advanced quantization
    Int4Advanced,
    /// Automatic precision selection
    Adaptive,
    /// Advanced-adaptive precision selection with stability monitoring
    AdvancedAdaptive,
}
82
/// Qualitative rating of how numerically stable a computation currently is.
///
/// Stored in [`StabilityMetrics::stability_level`]. Derives `Eq` and `Hash`
/// (like the other fieldless enums in this module) so values can be used as
/// keys in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StabilityLevel {
    /// Excellent numerical stability
    Excellent,
    /// Good numerical stability
    Good,
    /// Moderate numerical stability
    Moderate,
    /// Poor numerical stability - increase precision
    Poor,
    /// Critical numerical instability - recovery needed
    Critical,
}
97
/// Categories of numerical problems that can be detected during a computation.
///
/// Used as the key of the recovery-strategy table in [`ErrorRecoverySystem`]
/// and recorded in [`StabilityMetrics::error_types`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NumericalErrorType {
    /// Overflow in computation
    Overflow,
    /// Underflow in computation
    Underflow,
    /// Loss of precision
    PrecisionLoss,
    /// Convergence failure
    ConvergenceFailure,
    /// Ill-conditioned matrix
    IllConditioned,
    /// NaN or Inf values
    InvalidValues,
}
114
/// Policy that governs dynamic precision scaling.
///
/// Stored in [`DynamicPrecisionConfig::strategy`]. Derives `Eq` and `Hash`
/// (like the other fieldless enums in this module) so values can be used as
/// keys in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalingStrategy {
    /// Conservative - always use higher precision when uncertain
    Conservative,
    /// Balanced - balance performance and accuracy
    Balanced,
    /// Aggressive - favor performance over precision
    Aggressive,
    /// Custom - user-defined thresholds
    Custom,
}
127
/// Memory layout strategy applied to the input points before computation.
///
/// Read by the layout-optimization step of [`TensorCoreDistanceMatrix`], which
/// dispatches to one layout builder per variant. Derives `Eq` and `Hash` (like
/// the other fieldless enums in this module) so values can be used as keys in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorLayout {
    /// Row-major layout (C-style)
    RowMajor,
    /// Column-major layout (Fortran-style)
    ColMajor,
    /// Blocked layout for cache efficiency
    Blocked,
    /// Hierarchical Z-order layout
    ZOrder,
    /// Hardware-optimized layout
    HardwareOptimized,
}
142
/// GPU architecture families distinguished by this module.
///
/// Stored in [`TensorCoreCapabilities::architecture`] and matched on when a
/// hardware-specific data layout is chosen. Derives `Eq` and `Hash` (like the
/// other fieldless enums in this module) so values can be used as keys in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuArchitecture {
    /// NVIDIA Volta (V100)
    Volta,
    /// NVIDIA Ampere (A100, RTX 30 series)
    Ampere,
    /// NVIDIA Hopper (H100)
    Hopper,
    /// AMD CDNA2 (MI250X)
    CDNA2,
    /// AMD CDNA3 (MI300)
    CDNA3,
    /// Intel Xe HPC (Ponte Vecchio)
    XeHPC,
    /// Intel Xe Graphics (Arc)
    XeGraphics,
    /// Unknown or fallback
    Unknown,
}
163
/// Description of the tensor core hardware a computation can target.
///
/// [`TensorCoreDistanceMatrix::new`] stores the value returned by
/// `detect_tensor_core_capabilities` in this form; the `architecture` field is
/// later read to choose a hardware-specific data layout.
#[derive(Debug, Clone)]
pub struct TensorCoreCapabilities {
    /// Available tensor core types
    pub tensor_core_types: Vec<TensorCoreType>,
    /// Supported precision modes
    pub supported_precisions: Vec<PrecisionMode>,
    /// Maximum tensor dimensions
    // NOTE(review): the meaning of the three components is not documented
    // here (presumably the M, N, K extents of a matrix product) — confirm.
    pub max_tensor_size: (usize, usize, usize),
    /// Peak throughput (TOPS)
    pub peak_throughput_tops: f64,
    /// Memory bandwidth (GB/s)
    pub memory_bandwidth_gbps: f64,
    /// L2 cache size (MB)
    pub l2_cache_mb: f64,
    /// Number of streaming multiprocessors
    pub num_sms: usize,
    /// GPU architecture family of the device
    pub architecture: GpuArchitecture,
}
184
/// Kinds of matrix-acceleration units a device can expose.
///
/// Listed in [`TensorCoreCapabilities::tensor_core_types`]. Derives `Eq` and
/// `Hash` (like the other fieldless enums in this module) so values can be
/// used as keys in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorCoreType {
    /// NVIDIA Tensor Cores (WMMA)
    NvidiaTensorCore,
    /// AMD Matrix Cores
    AmdMatrixCore,
    /// Intel XMX units
    IntelXMX,
    /// Standard CUDA/OpenCL cores (fallback)
    StandardCores,
}
197
/// Snapshot of numerical stability measurements taken at one point in time.
///
/// [`NumericalStabilityMonitor`] keeps the current snapshot plus a history of
/// earlier ones, and [`RecoveryAttempt`] can carry the snapshot measured after
/// a recovery.
#[derive(Debug, Clone)]
pub struct StabilityMetrics {
    /// Condition number of the computation
    pub condition_number: f64,
    /// Relative error estimate
    pub relative_error: f64,
    /// Forward error bound
    pub forward_error: f64,
    /// Backward error bound
    pub backward_error: f64,
    /// Loss of significant digits
    pub digit_loss: f64,
    /// Current stability level
    pub stability_level: StabilityLevel,
    /// Detected error types
    pub error_types: Vec<NumericalErrorType>,
    /// Timestamp of measurement
    pub timestamp: Instant,
}
218
/// Configuration for dynamic precision scaling.
///
/// Held by [`NumericalStabilityMonitor`]. How the thresholds and weights are
/// interpreted is decided by code outside this definition.
#[derive(Debug, Clone)]
pub struct DynamicPrecisionConfig {
    /// Scaling strategy
    pub strategy: ScalingStrategy,
    /// Minimum precision level
    pub min_precision: PrecisionMode,
    /// Maximum precision level
    pub max_precision: PrecisionMode,
    /// Stability threshold for precision increase
    pub stability_threshold_up: f64,
    /// Stability threshold for precision decrease
    pub stability_threshold_down: f64,
    /// Performance weight in decision making
    pub performance_weight: f64,
    /// Accuracy weight in decision making
    pub accuracy_weight: f64,
    /// Maximum precision changes per operation
    pub max_changes_per_operation: usize,
    /// Cooldown period between precision changes
    pub change_cooldown: Duration,
}
241
/// Real-time numerical stability monitor.
///
/// Tracks the current [`StabilityMetrics`], a history of earlier measurements,
/// and the precision mode in use together with a log of precision changes.
/// [`AdvancedTensorCoreDistanceMatrix`] holds one behind an `Arc<Mutex<_>>`.
// NOTE(review): `dead_code` is allowed for the whole struct, presumably
// because not every field is read yet — confirm before removing.
#[allow(dead_code)]
#[derive(Debug)]
pub struct NumericalStabilityMonitor {
    /// Current stability metrics
    current_metrics: StabilityMetrics,
    /// Historical stability data
    stability_history: VecDeque<StabilityMetrics>,
    /// Dynamic precision configuration
    precision_config: DynamicPrecisionConfig,
    /// Current precision mode
    current_precision: PrecisionMode,
    /// Precision change history
    // NOTE(review): the meaning of the `f64` component is not shown in this
    // definition — confirm against the code that pushes entries.
    precision_history: VecDeque<(Instant, PrecisionMode, f64)>,
    /// Error recovery attempts
    #[allow(dead_code)]
    recovery_attempts: usize,
    /// Maximum history length
    max_history_length: usize,
    /// Last precision change time (`None` if no change has been recorded)
    last_precision_change: Option<Instant>,
}
264
/// Advanced error recovery system.
///
/// Maps each [`NumericalErrorType`] to a list of candidate [`RecoveryAction`]s
/// and keeps a record of past [`RecoveryAttempt`]s together with a per-action
/// success rate.
// NOTE(review): `dead_code` is allowed for the whole struct, presumably
// because not every field is read yet — confirm before removing.
#[allow(dead_code)]
#[derive(Debug)]
pub struct ErrorRecoverySystem {
    /// Recovery strategies by error type
    recovery_strategies: HashMap<NumericalErrorType, Vec<RecoveryAction>>,
    /// Recovery attempt history
    recovery_history: VecDeque<RecoveryAttempt>,
    /// Maximum recovery attempts per operation
    max_recovery_attempts: usize,
    /// Recovery success rate tracking, keyed by action
    success_rates: HashMap<RecoveryAction, f64>,
}
278
/// Actions the recovery system can take in response to a numerical error.
///
/// Used as values of the per-error strategy table and as keys of the
/// success-rate table in [`ErrorRecoverySystem`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecoveryAction {
    /// Increase precision mode
    IncreasePrecision,
    /// Reduce tile size
    ReduceTileSize,
    /// Switch to fallback algorithm
    FallbackAlgorithm,
    /// Apply numerical stabilization
    NumericalStabilization,
    /// Retry with different parameters
    RetryWithNewParams,
    /// Switch to CPU computation
    SwitchToCPU,
}
295
/// Record of a single recovery attempt.
///
/// Entries of this type make up the recovery history kept by
/// [`ErrorRecoverySystem`].
#[derive(Debug, Clone)]
pub struct RecoveryAttempt {
    /// Error type that triggered recovery
    pub error_type: NumericalErrorType,
    /// Recovery action taken
    pub action: RecoveryAction,
    /// Success/failure of recovery
    pub success: bool,
    /// Time taken for recovery
    pub duration: Duration,
    /// Stability metrics after recovery, if any were recorded
    pub post_recovery_metrics: Option<StabilityMetrics>,
    /// Timestamp
    pub timestamp: Instant,
}
312
/// Performance-accuracy trade-off analyzer.
///
/// Collects timing and accuracy samples per [`PrecisionMode`] and keeps the
/// parameters and the current Pareto frontier used for trade-off decisions.
#[derive(Debug)]
pub struct PerformanceAccuracyAnalyzer {
    /// Performance measurements by precision mode
    performance_data: HashMap<PrecisionMode, VecDeque<Duration>>,
    /// Accuracy measurements by precision mode
    accuracy_data: HashMap<PrecisionMode, VecDeque<f64>>,
    /// Trade-off optimization parameters
    optimization_params: TradeOffParams,
    /// Current Pareto frontier
    pareto_frontier: Vec<(f64, f64, PrecisionMode)>, // (performance, accuracy, mode)
}
325
/// Parameters for performance/accuracy trade-off optimization.
///
/// Held by [`PerformanceAccuracyAnalyzer`]. How the weights are combined is
/// decided by code outside this definition.
#[derive(Debug, Clone)]
pub struct TradeOffParams {
    /// Weight for performance (speed)
    pub performance_weight: f64,
    /// Weight for accuracy
    pub accuracy_weight: f64,
    /// Weight for energy efficiency
    pub energy_weight: f64,
    /// Minimum acceptable accuracy
    pub min_accuracy: f64,
    /// Maximum acceptable time
    pub max_time: Duration,
    /// Optimization objective
    pub objective: OptimizationObjective,
}
342
/// Goal pursued when trading performance against accuracy.
///
/// Stored in [`TradeOffParams::objective`]. Derives `Eq` and `Hash` (like the
/// other fieldless enums in this module) so values can be used as keys in
/// hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationObjective {
    /// Maximize performance (minimize time)
    MaxPerformance,
    /// Maximize accuracy
    MaxAccuracy,
    /// Balance performance and accuracy
    Balanced,
    /// Minimize energy consumption
    MinEnergy,
    /// Custom weighted objective
    Custom,
}
357
/// Tensor core distance matrix computer with advanced stability monitoring.
///
/// Combines a plain [`TensorCoreDistanceMatrix`] with a shared
/// [`NumericalStabilityMonitor`], an [`ErrorRecoverySystem`] and a
/// [`PerformanceAccuracyAnalyzer`], plus two switches that enable dynamic
/// precision scaling and automatic error recovery.
#[derive(Debug)]
pub struct AdvancedTensorCoreDistanceMatrix {
    /// Base tensor core computer
    base_computer: TensorCoreDistanceMatrix,
    /// Numerical stability monitor (shared behind a mutex)
    stability_monitor: Arc<Mutex<NumericalStabilityMonitor>>,
    /// Error recovery system
    recovery_system: ErrorRecoverySystem,
    /// Performance-accuracy analyzer
    performance_analyzer: PerformanceAccuracyAnalyzer,
    /// Enable dynamic precision scaling
    dynamic_precision_enabled: bool,
    /// Enable automatic error recovery
    auto_recovery_enabled: bool,
}
374
/// Tensor core distance matrix computer.
///
/// Built with [`TensorCoreDistanceMatrix::new`] and customised through the
/// `with_*` builder methods. `new` starts from `Mixed16` precision, layout
/// optimization and hierarchical tiling enabled, 256x256 tiles, the
/// hardware-optimized layout and 4 execution streams.
#[derive(Debug, Clone)]
pub struct TensorCoreDistanceMatrix {
    /// Precision mode
    precision_mode: PrecisionMode,
    /// Enable tensor layout optimization
    layout_optimization: bool,
    /// Enable hierarchical tiling
    hierarchical_tiling: bool,
    /// Tile size for blocking, as `(rows, cols)`
    tile_size: (usize, usize),
    /// GPU capabilities (`new` always fills this with the detected value)
    capabilities: Option<TensorCoreCapabilities>,
    /// Current tensor layout
    tensor_layout: TensorLayout,
    /// Async execution streams
    execution_streams: usize,
}
393
394impl TensorCoreDistanceMatrix {
395    /// Create new tensor core distance matrix computer
396    pub fn new() -> SpatialResult<Self> {
397        let capabilities = detect_tensor_core_capabilities()?;
398
399        Ok(Self {
400            precision_mode: PrecisionMode::Mixed16,
401            layout_optimization: true,
402            hierarchical_tiling: true,
403            tile_size: (256, 256),
404            capabilities: Some(capabilities),
405            tensor_layout: TensorLayout::HardwareOptimized,
406            execution_streams: 4,
407        })
408    }
409
410    /// Configure precision mode
411    pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
412        self.precision_mode = mode;
413        self
414    }
415
416    /// Enable tensor layout optimization
417    pub fn with_tensor_layout_optimization(mut self, enabled: bool) -> Self {
418        self.layout_optimization = enabled;
419        self
420    }
421
422    /// Enable hierarchical tiling
423    pub fn with_hierarchical_tiling(mut self, enabled: bool) -> Self {
424        self.hierarchical_tiling = enabled;
425        self
426    }
427
428    /// Configure tile size
429    pub fn with_tile_size(mut self, rows: usize, cols: usize) -> Self {
430        self.tile_size = (rows, cols);
431        self
432    }
433
434    /// Configure execution streams
435    pub fn with_execution_streams(mut self, streams: usize) -> Self {
436        self.execution_streams = streams;
437        self
438    }
439
440    /// Compute distance matrix using tensor cores
441    pub async fn compute_parallel(
442        &mut self,
443        points: &ArrayView2<'_, f64>,
444    ) -> SpatialResult<Array2<f64>> {
445        let (npoints, ndims) = points.dim();
446
447        if npoints == 0 || ndims == 0 {
448            return Err(SpatialError::InvalidInput("Empty input data".to_string()));
449        }
450
451        // Optimize tensor layout
452        let optimizedpoints = if self.layout_optimization {
453            self.optimize_tensor_layout(points)?
454        } else {
455            points.to_owned()
456        };
457
458        // Choose computation strategy based on data size
459        if self.hierarchical_tiling && npoints > 1024 {
460            self.compute_hierarchical_tiled(&optimizedpoints.view())
461                .await
462        } else {
463            self.compute_direct_tensor_cores(&optimizedpoints.view())
464                .await
465        }
466    }
467
468    /// Optimize tensor layout for hardware
469    fn optimize_tensor_layout(
470        &mut self,
471        points: &ArrayView2<'_, f64>,
472    ) -> SpatialResult<Array2<f64>> {
473        let (npoints, ndims) = points.dim();
474
475        match self.tensor_layout {
476            TensorLayout::RowMajor => Ok(points.to_owned()),
477            TensorLayout::ColMajor => {
478                let mut transposed = Array2::zeros((ndims, npoints));
479                for (i, point) in points.outer_iter().enumerate() {
480                    transposed.column_mut(i).assign(&point);
481                }
482                Ok(transposed.t().to_owned())
483            }
484            TensorLayout::Blocked => TensorCoreDistanceMatrix::create_blocked_layout(points),
485            TensorLayout::ZOrder => self.create_zorder_layout(points),
486            TensorLayout::HardwareOptimized => self.create_hardware_optimized_layout(points),
487        }
488    }
489
490    /// Create blocked tensor layout
491    fn create_blocked_layout(points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
492        let (npoints, ndims) = points.dim();
493        let block_size = 64; // Optimize for cache lines
494
495        let blocked_rows = npoints.div_ceil(block_size) * block_size;
496        let blocked_cols = ndims.div_ceil(block_size) * block_size;
497
498        let mut blocked_data = Array2::zeros((blocked_rows, blocked_cols));
499
500        for block_i in 0..(npoints / block_size + 1) {
501            for block_j in 0..(ndims / block_size + 1) {
502                let start_i = block_i * block_size;
503                let start_j = block_j * block_size;
504                let end_i = (start_i + block_size).min(npoints);
505                let end_j = (start_j + block_size).min(ndims);
506
507                for i in start_i..end_i {
508                    for j in start_j..end_j {
509                        blocked_data[[i, j]] = points[[i, j]];
510                    }
511                }
512            }
513        }
514
515        Ok(blocked_data.slice(s![..npoints, ..ndims]).to_owned())
516    }
517
518    /// Create Z-order (Morton order) layout
519    fn create_zorder_layout(&mut self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
520        let (npoints, ndims) = points.dim();
521
522        // Create Z-order mapping
523        let mut z_indices: Vec<(usize, usize)> = (0..npoints)
524            .map(|i| {
525                (
526                    i,
527                    TensorCoreDistanceMatrix::calculate_z_order_index(i, ndims),
528                )
529            })
530            .collect();
531
532        z_indices.sort_by_key(|(_, z_idx)| *z_idx);
533
534        let mut reordered_data = Array2::zeros((npoints, ndims));
535        for (new_idx, (old_idx, z_idx)) in z_indices.iter().enumerate() {
536            reordered_data
537                .row_mut(new_idx)
538                .assign(&points.row(*old_idx));
539        }
540
541        Ok(reordered_data)
542    }
543
544    /// Calculate Z-order (Morton) index
545    fn calculate_z_order_index(point_idx: usize, ndims: usize) -> usize {
546        // Simplified Z-order calculation
547        let mut z_index = 0;
548        let temp_idx = point_idx;
549
550        for bit in 0..16 {
551            // Limit to 16 bits for practical purposes
552            for dim in 0..ndims.min(3) {
553                // Limit to 3 dimensions
554                if temp_idx & (1 << bit) != 0 {
555                    z_index |= 1 << (bit * ndims + dim);
556                }
557            }
558        }
559
560        z_index
561    }
562
563    /// Create hardware-optimized layout
564    fn create_hardware_optimized_layout(
565        &self,
566        points: &ArrayView2<'_, f64>,
567    ) -> SpatialResult<Array2<f64>> {
568        if let Some(ref capabilities) = self.capabilities {
569            match capabilities.architecture {
570                GpuArchitecture::Ampere | GpuArchitecture::Hopper => {
571                    // Use NVIDIA-optimized layout (NHWC-like for spatial data)
572                    self.create_nvidia_optimized_layout(points)
573                }
574                GpuArchitecture::CDNA2 | GpuArchitecture::CDNA3 => {
575                    // Use AMD-optimized layout
576                    self.create_amd_optimized_layout(points)
577                }
578                GpuArchitecture::XeHPC | GpuArchitecture::XeGraphics => {
579                    // Use Intel-optimized layout
580                    self.create_intel_optimized_layout(points)
581                }
582                _ => {
583                    // Fallback to blocked layout
584                    TensorCoreDistanceMatrix::create_blocked_layout(points)
585                }
586            }
587        } else {
588            TensorCoreDistanceMatrix::create_blocked_layout(points)
589        }
590    }
591
592    /// Create NVIDIA-optimized tensor layout
593    fn create_nvidia_optimized_layout(
594        &self,
595        points: &ArrayView2<'_, f64>,
596    ) -> SpatialResult<Array2<f64>> {
597        let (npoints, ndims) = points.dim();
598
599        // Pad dimensions to multiples of 8 for tensor core efficiency
600        let paddedpoints = npoints.div_ceil(8) * 8;
601        let padded_dims = ndims.div_ceil(8) * 8;
602
603        let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
604
605        // Copy original data
606        for i in 0..npoints {
607            for j in 0..ndims {
608                padded_data[[i, j]] = points[[i, j]];
609            }
610        }
611
612        // Return view of original size
613        Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
614    }
615
616    /// Create AMD-optimized tensor layout
617    fn create_amd_optimized_layout(
618        &self,
619        points: &ArrayView2<'_, f64>,
620    ) -> SpatialResult<Array2<f64>> {
621        let (npoints, ndims) = points.dim();
622
623        // AMD matrix cores prefer multiples of 16
624        let paddedpoints = npoints.div_ceil(16) * 16;
625        let padded_dims = ndims.div_ceil(16) * 16;
626
627        let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
628
629        for i in 0..npoints {
630            for j in 0..ndims {
631                padded_data[[i, j]] = points[[i, j]];
632            }
633        }
634
635        Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
636    }
637
638    /// Create Intel-optimized tensor layout  
639    fn create_intel_optimized_layout(
640        &self,
641        points: &ArrayView2<'_, f64>,
642    ) -> SpatialResult<Array2<f64>> {
643        let (npoints, ndims) = points.dim();
644
645        // Intel XMX units prefer multiples of 32
646        let paddedpoints = npoints.div_ceil(32) * 32;
647        let padded_dims = ndims.div_ceil(32) * 32;
648
649        let mut padded_data = Array2::zeros((paddedpoints, padded_dims));
650
651        for i in 0..npoints {
652            for j in 0..ndims {
653                padded_data[[i, j]] = points[[i, j]];
654            }
655        }
656
657        Ok(padded_data.slice(s![..npoints, ..ndims]).to_owned())
658    }
659
660    /// Compute using hierarchical tiling strategy
661    async fn compute_hierarchical_tiled(
662        &mut self,
663        points: &ArrayView2<'_, f64>,
664    ) -> SpatialResult<Array2<f64>> {
665        let (npoints, ndims) = points.dim();
666        let mut distance_matrix = Array2::zeros((npoints, npoints));
667
668        let (tile_rows, tile_cols) = self.tile_size;
669        let precision_mode = self.precision_mode; // Extract before loop
670
671        // Create async tasks for tile computation
672        let mut tile_futures = Vec::new();
673
674        for i in (0..npoints).step_by(tile_rows) {
675            for j in (0..npoints).step_by(tile_cols) {
676                let end_i = (i + tile_rows).min(npoints);
677                let end_j = (j + tile_cols).min(npoints);
678
679                let tilepoints_i = points.slice(s![i..end_i, ..]).to_owned();
680                let tilepoints_j = points.slice(s![j..end_j, ..]).to_owned();
681
682                // Use extracted precision_mode instead of accessing self
683                let future = async move {
684                    // Basic distance computation for tile
685                    let (rows_i, _) = tilepoints_i.dim();
686                    let (rows_j, _) = tilepoints_j.dim();
687                    let mut tile_distances = Array2::zeros((rows_i, rows_j));
688
689                    for r in 0..rows_i {
690                        for c in 0..rows_j {
691                            let p1 = tilepoints_i.row(r);
692                            let p2 = tilepoints_j.row(c);
693                            let diff = &p1 - &p2;
694                            let dist = diff.iter().map(|x| x.powi(2)).sum::<f64>().sqrt();
695                            tile_distances[[r, c]] = dist;
696                        }
697                    }
698                    Ok::<Array2<f64>, SpatialError>(tile_distances)
699                };
700                tile_futures.push((i, j, end_i, end_j, future));
701            }
702        }
703
704        // Execute tiles and collect results
705        for (i, j, end_i, end_j, future) in tile_futures {
706            let tile_result = future.await?;
707
708            // Copy tile result to main matrix
709            let tile_rows = end_i - i;
710            let tile_cols = end_j - j;
711
712            for row in 0..tile_rows {
713                for col in 0..tile_cols {
714                    distance_matrix[[i + row, j + col]] = tile_result[[row, col]];
715                }
716            }
717        }
718
719        Ok(distance_matrix)
720    }
721
722    /// Compute tile using tensor cores
723    async fn compute_tile_tensor_cores(
724        &mut self,
725        points_i: Array2<f64>,
726        points_j: Array2<f64>,
727        precision_mode: PrecisionMode,
728    ) -> SpatialResult<Array2<f64>> {
729        let (_n_i, ndims) = points_i.dim();
730        let (_n_j, _) = points_j.dim();
731
732        match precision_mode {
733            PrecisionMode::Full32 => {
734                self.compute_distances_fp32(&points_i.view(), &points_j.view())
735                    .await
736            }
737            PrecisionMode::Mixed16 => {
738                self.compute_distances_mixed16(&points_i.view(), &points_j.view())
739                    .await
740            }
741            PrecisionMode::BrainFloat16 => {
742                self.compute_distances_bf16(&points_i.view(), &points_j.view())
743                    .await
744            }
745            PrecisionMode::Int8Dynamic => {
746                self.compute_distances_int8(&points_i.view(), &points_j.view())
747                    .await
748            }
749            PrecisionMode::Int4Advanced => {
750                self.compute_distances_int4(&points_i.view(), &points_j.view())
751                    .await
752            }
753            PrecisionMode::Adaptive => {
754                self.compute_distances_adaptive(&points_i.view(), &points_j.view())
755                    .await
756            }
757            PrecisionMode::AdvancedAdaptive => {
758                self.compute_distances_adaptive(&points_i.view(), &points_j.view())
759                    .await
760            }
761        }
762    }
763
764    /// Direct tensor core computation (no tiling)
765    async fn compute_direct_tensor_cores(
766        &mut self,
767        points: &ArrayView2<'_, f64>,
768    ) -> SpatialResult<Array2<f64>> {
769        self.compute_tile_tensor_cores(points.to_owned(), points.to_owned(), self.precision_mode)
770            .await
771    }
772
773    /// Compute distances using FP32 precision
774    async fn compute_distances_fp32(
775        &self,
776        points_i: &ArrayView2<'_, f64>,
777        points_j: &ArrayView2<'_, f64>,
778    ) -> SpatialResult<Array2<f64>> {
779        let (n_i, ndims) = points_i.dim();
780        let (n_j, _) = points_j.dim();
781        let mut distances = Array2::zeros((n_i, n_j));
782
783        // Simulate tensor core operation using GEMM
784        // D[_i_j] = ||points_i[_i] - points_j[_j]||²
785
786        // Compute ||points_i||² for each point
787        let norms_i: Array1<f64> = points_i
788            .outer_iter()
789            .map(|point| point.iter().map(|&x| x * x).sum())
790            .collect();
791
792        // Compute ||points_j||² for each point
793        let norms_j: Array1<f64> = points_j
794            .outer_iter()
795            .map(|point| point.iter().map(|&x| x * x).sum())
796            .collect();
797
798        // Compute cross terms using matrix multiplication (tensor core operation)
799        let cross_terms = self
800            .tensor_core_gemm_fp32(points_i, &points_j.t().to_owned().view())
801            .await?;
802
803        // Combine terms: ||a-b||² = ||a||² + ||b||² - 2⟨a,b⟩
804        for _i in 0..n_i {
805            for _j in 0..n_j {
806                distances[[_i, _j]] = (norms_i[_i] + norms_j[_j] - 2.0 * cross_terms[[_i, _j]])
807                    .max(0.0)
808                    .sqrt();
809            }
810        }
811
812        Ok(distances)
813    }
814
815    /// Compute distances using mixed FP16 precision
816    async fn compute_distances_mixed16(
817        &self,
818        points_i: &ArrayView2<'_, f64>,
819        points_j: &ArrayView2<'_, f64>,
820    ) -> SpatialResult<Array2<f64>> {
821        // Convert to FP16 for computation, accumulate in FP32
822        let points_i_f16 = TensorCoreDistanceMatrix::convert_to_fp16(points_i)?;
823        let points_j_f16 = TensorCoreDistanceMatrix::convert_to_fp16(points_j)?;
824
825        let (n_i, _) = points_i.dim();
826        let (n_j, _) = points_j.dim();
827        let mut distances = Array2::zeros((n_i, n_j));
828
829        // Simulate mixed precision computation
830        let norms_i_f16 = TensorCoreDistanceMatrix::compute_norms_fp16(&points_i_f16)?;
831        let norms_j_f16 = TensorCoreDistanceMatrix::compute_norms_fp16(&points_j_f16)?;
832
833        // Tensor core GEMM in FP16 with FP32 accumulation
834        let cross_terms = self
835            .tensor_core_gemm_mixed16(&points_i_f16, &points_j_f16.t().to_owned())
836            .await?;
837
838        for _i in 0..n_i {
839            for _j in 0..n_j {
840                let distance_sq = norms_i_f16[_i] as f64 + norms_j_f16[_j] as f64
841                    - 2.0 * cross_terms[[_i, _j]] as f64;
842                distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
843            }
844        }
845
846        Ok(distances)
847    }
848
849    /// Compute distances using BFloat16 precision
850    async fn compute_distances_bf16(
851        &mut self,
852        points_i: &ArrayView2<'_, f64>,
853        points_j: &ArrayView2<'_, f64>,
854    ) -> SpatialResult<Array2<f64>> {
855        // Similar to FP16 but with BFloat16 format
856        // BF16 has better dynamic range than FP16
857        let points_i_bf16 = self.convert_to_bf16(points_i)?;
858        let points_j_bf16 = self.convert_to_bf16(points_j)?;
859
860        let (n_i, _) = points_i.dim();
861        let (n_j, _) = points_j.dim();
862        let mut distances = Array2::zeros((n_i, n_j));
863
864        let norms_i_bf16 = self.compute_norms_bf16(&points_i_bf16)?;
865        let norms_j_bf16 = self.compute_norms_bf16(&points_j_bf16)?;
866
867        let cross_terms = self
868            .tensor_core_gemm_bf16(&points_i_bf16, &points_j_bf16.t().to_owned())
869            .await?;
870
871        for _i in 0..n_i {
872            for _j in 0..n_j {
873                let distance_sq = norms_i_bf16[_i] as f64 + norms_j_bf16[_j] as f64
874                    - 2.0 * cross_terms[[_i, _j]] as f64;
875                distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
876            }
877        }
878
879        Ok(distances)
880    }
881
882    /// Compute distances using INT8 with dynamic scaling
883    async fn compute_distances_int8(
884        &self,
885        points_i: &ArrayView2<'_, f64>,
886        points_j: &ArrayView2<'_, f64>,
887    ) -> SpatialResult<Array2<f64>> {
888        // Dynamic quantization to INT8
889        let (scale_i, points_i_int8) = self.quantize_to_int8_dynamic(points_i)?;
890        let (scale_j, points_j_int8) = self.quantize_to_int8_dynamic(points_j)?;
891
892        let (n_i, _) = points_i.dim();
893        let (n_j, _) = points_j.dim();
894        let mut distances = Array2::zeros((n_i, n_j));
895
896        // Compute using INT8 tensor cores
897        let combined_scale = scale_i * scale_j;
898
899        for _i in 0..n_i {
900            for _j in 0..n_j {
901                // Compute cross term using INT8
902                let cross_term_int32 = points_i_int8
903                    .row(_i)
904                    .iter()
905                    .zip(points_j_int8.row(_j).iter())
906                    .map(|(&a, &b)| (a as i32) * (b as i32))
907                    .sum::<i32>();
908                let cross_term_f64 = cross_term_int32 as f64 * combined_scale;
909
910                // Compute norms in original space
911                let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
912                let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
913
914                let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
915                distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
916            }
917        }
918
919        Ok(distances)
920    }
921
922    /// Compute distances using INT4 with advanced quantization
923    async fn compute_distances_int4(
924        &self,
925        points_i: &ArrayView2<'_, f64>,
926        points_j: &ArrayView2<'_, f64>,
927    ) -> SpatialResult<Array2<f64>> {
928        // Advanced INT4 quantization with optimal scaling
929        let (scale_i, points_i_int4) = self.quantize_to_int4_advanced(points_i)?;
930        let (scale_j, points_j_int4) = self.quantize_to_int4_advanced(points_j)?;
931
932        // For simplicity, convert INT4 to INT8 for computation
933        let points_i_int8 = TensorCoreDistanceMatrix::int4_to_int8(&points_i_int4);
934        let points_j_int8 = TensorCoreDistanceMatrix::int4_to_int8(&points_j_int4);
935
936        let (n_i, _) = points_i.dim();
937        let (n_j, _) = points_j.dim();
938        let mut distances = Array2::zeros((n_i, n_j));
939
940        // TODO: Implement cross terms calculation with scale
941        // let cross_terms_int32 = self
942        //     .tensor_core_gemm_int8(&points_i_int8, &points_j_int8.t()) as f64 * combined_scale;
943
944        // Calculate distances
945        for _i in 0..n_i {
946            for _j in 0..n_j {
947                let norm_i_sq: f64 = points_i.row(_i).iter().map(|&x| x * x).sum();
948                let norm_j_sq: f64 = points_j.row(_j).iter().map(|&x| x * x).sum();
949
950                // TODO: Use cross_term from tensor core computation
951                let cross_term_f64 = 0.0; // Placeholder
952                let distance_sq = norm_i_sq + norm_j_sq - 2.0 * cross_term_f64;
953                distances[[_i, _j]] = distance_sq.max(0.0).sqrt();
954            }
955        }
956
957        Ok(distances)
958    }
959
960    /// Adaptive precision computation based on numerical requirements
961    async fn compute_distances_adaptive(
962        &mut self,
963        points_i: &ArrayView2<'_, f64>,
964        points_j: &ArrayView2<'_, f64>,
965    ) -> SpatialResult<Array2<f64>> {
966        // Analyze data characteristics to choose optimal precision
967        let data_range = self.analyze_data_range(points_i, points_j);
968        let condition_number = self.estimate_condition_number(points_i, points_j);
969
970        let optimal_precision = if condition_number > 1e6 {
971            PrecisionMode::Full32
972        } else if data_range > 1e3 {
973            PrecisionMode::BrainFloat16
974        } else if data_range > 100.0 {
975            PrecisionMode::Mixed16
976        } else {
977            PrecisionMode::Int8Dynamic
978        };
979
980        match optimal_precision {
981            PrecisionMode::Full32 => self.compute_distances_fp32(points_i, points_j).await,
982            PrecisionMode::Mixed16 => self.compute_distances_mixed16(points_i, points_j).await,
983            PrecisionMode::BrainFloat16 => self.compute_distances_bf16(points_i, points_j).await,
984            PrecisionMode::Int8Dynamic => self.compute_distances_int8(points_i, points_j).await,
985            PrecisionMode::Int4Advanced => self.compute_distances_int8(points_i, points_j).await, // Fallback to int8
986            PrecisionMode::Adaptive => self.compute_distances_mixed16(points_i, points_j).await, // Fallback to mixed16
987            PrecisionMode::AdvancedAdaptive => {
988                self.compute_distances_fp32(points_i, points_j).await
989            } // Fallback to fp32
990        }
991    }
992
993    /// Tensor core GEMM operation in FP32
994    async fn tensor_core_gemm_fp32(
995        &self,
996        a: &ArrayView2<'_, f64>,
997        b: &ArrayView2<'_, f64>,
998    ) -> SpatialResult<Array2<f64>> {
999        // Simulate tensor core GEMM C = A * B
1000        let (m, k) = a.dim();
1001        let (k2, n) = b.dim();
1002
1003        if k != k2 {
1004            return Err(SpatialError::InvalidInput(
1005                "Matrix dimensions don't match for multiplication".to_string(),
1006            ));
1007        }
1008
1009        let mut c = Array2::zeros((m, n));
1010
1011        // Simulate blocked matrix multiplication with tensor cores
1012        let block_size = 16; // Typical tensor core block size
1013
1014        for i in (0..m).step_by(block_size) {
1015            for j in (0..n).step_by(block_size) {
1016                for kk in (0..k).step_by(block_size) {
1017                    let end_i = (i + block_size).min(m);
1018                    let end_j = (j + block_size).min(n);
1019                    let end_k = (kk + block_size).min(k);
1020
1021                    // Simulate tensor core computation for this block
1022                    for ii in i..end_i {
1023                        for jj in j..end_j {
1024                            for kkk in kk..end_k {
1025                                c[[ii, jj]] += a[[ii, kkk]] * b[[kkk, jj]];
1026                            }
1027                        }
1028                    }
1029                }
1030            }
1031        }
1032
1033        Ok(c)
1034    }
1035
1036    /// Tensor core GEMM operation in mixed FP16
1037    async fn tensor_core_gemm_mixed16(
1038        &self,
1039        a: &Array2<f32>,
1040        b: &Array2<f32>,
1041    ) -> SpatialResult<Array2<f32>> {
1042        // Similar to FP32 but with FP16 inputs and FP32 accumulation
1043        let (m, k) = a.dim();
1044        let (k2, n) = b.dim();
1045
1046        if k != k2 {
1047            return Err(SpatialError::InvalidInput(
1048                "Matrix dimensions don't match".to_string(),
1049            ));
1050        }
1051
1052        let mut c = Array2::zeros((m, n));
1053        let block_size = 16;
1054
1055        for i in (0..m).step_by(block_size) {
1056            for j in (0..n).step_by(block_size) {
1057                for kk in (0..k).step_by(block_size) {
1058                    let end_i = (i + block_size).min(m);
1059                    let end_j = (j + block_size).min(n);
1060                    let end_k = (kk + block_size).min(k);
1061
1062                    for ii in i..end_i {
1063                        for jj in j..end_j {
1064                            for kkk in kk..end_k {
1065                                // Simulate FP16 multiply with FP32 accumulate
1066                                c[[ii, jj]] += a[[ii, kkk]] * b[[kkk, jj]];
1067                            }
1068                        }
1069                    }
1070                }
1071            }
1072        }
1073
1074        Ok(c)
1075    }
1076
1077    /// Tensor core GEMM operation in BFloat16
1078    async fn tensor_core_gemm_bf16(
1079        &self,
1080        a: &Array2<f32>,
1081        b: &Array2<f32>,
1082    ) -> SpatialResult<Array2<f32>> {
1083        // Similar to mixed16 but simulating BF16 characteristics
1084        self.tensor_core_gemm_mixed16(a, b).await
1085    }
1086
1087    /// Tensor core GEMM operation in INT8
1088    #[allow(dead_code)]
1089    async fn tensor_core_gemm_int8(
1090        &self,
1091        a: &Array2<i8>,
1092        b: &Array2<i8>,
1093    ) -> SpatialResult<Array2<i32>> {
1094        let (m, k) = a.dim();
1095        let (k2, n) = b.dim();
1096
1097        if k != k2 {
1098            return Err(SpatialError::InvalidInput(
1099                "Matrix dimensions don't match".to_string(),
1100            ));
1101        }
1102
1103        let mut c = Array2::zeros((m, n));
1104        let block_size = 16;
1105
1106        for i in (0..m).step_by(block_size) {
1107            for j in (0..n).step_by(block_size) {
1108                for kk in (0..k).step_by(block_size) {
1109                    let end_i = (i + block_size).min(m);
1110                    let end_j = (j + block_size).min(n);
1111                    let end_k = (kk + block_size).min(k);
1112
1113                    for ii in i..end_i {
1114                        for jj in j..end_j {
1115                            for kkk in kk..end_k {
1116                                // INT8 multiply with INT32 accumulate
1117                                c[[ii, jj]] += a[[ii, kkk]] as i32 * b[[kkk, jj]] as i32;
1118                            }
1119                        }
1120                    }
1121                }
1122            }
1123        }
1124
1125        Ok(c)
1126    }
1127
1128    /// Convert FP64 to FP16 format
1129    fn convert_to_fp16(data: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f32>> {
1130        let (rows, cols) = data.dim();
1131        let mut fp16_data = Array2::zeros((rows, cols));
1132
1133        for i in 0..rows {
1134            for j in 0..cols {
1135                // Simple conversion to FP32 (FP16 would need special library)
1136                fp16_data[[i, j]] = data[[i, j]] as f32;
1137            }
1138        }
1139
1140        Ok(fp16_data)
1141    }
1142
1143    /// Convert FP64 to BFloat16 format
1144    fn convert_to_bf16(&mut self, data: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f32>> {
1145        // Similar to FP16 but with BF16 characteristics
1146        TensorCoreDistanceMatrix::convert_to_fp16(data)
1147    }
1148
1149    /// Quantize to INT8 with dynamic scaling
1150    fn quantize_to_int8_dynamic(
1151        &self,
1152        data: &ArrayView2<'_, f64>,
1153    ) -> SpatialResult<(f64, Array2<i8>)> {
1154        let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1155        let scale = max_val / 127.0; // Map to [-127, 127]
1156
1157        let (rows, cols) = data.dim();
1158        let mut quantized = Array2::zeros((rows, cols));
1159
1160        for i in 0..rows {
1161            for j in 0..cols {
1162                let quantized_val = (data[[i, j]] / scale).round() as i8;
1163                quantized[[i, j]] = quantized_val.clamp(-127, 127);
1164            }
1165        }
1166
1167        Ok((scale, quantized))
1168    }
1169
1170    /// Quantize to INT4 with advanced quantization
1171    fn quantize_to_int4_advanced(
1172        &self,
1173        data: &ArrayView2<'_, f64>,
1174    ) -> SpatialResult<(f64, Array2<i8>)> {
1175        let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1176        let scale = max_val / 7.0; // Map to [-7, 7] for 4-bit
1177
1178        let (rows, cols) = data.dim();
1179        let mut quantized = Array2::zeros((rows, cols));
1180
1181        for i in 0..rows {
1182            for j in 0..cols {
1183                let quantized_val = (data[[i, j]] / scale).round() as i8;
1184                quantized[[i, j]] = quantized_val.clamp(-7, 7);
1185            }
1186        }
1187
1188        Ok((scale, quantized))
1189    }
1190
1191    /// Convert INT4 to INT8 for computation
1192    fn int4_to_int8(data: &Array2<i8>) -> Array2<i8> {
1193        // INT4 values are already in INT8 format, just clamp to ensure 4-bit range
1194        data.mapv(|x| x.clamp(-7, 7))
1195    }
1196
1197    /// Compute norms for FP16 data
1198    fn compute_norms_fp16(data: &Array2<f32>) -> SpatialResult<Array1<f32>> {
1199        let norms = data
1200            .outer_iter()
1201            .map(|row| row.iter().map(|&x| x * x).sum())
1202            .collect();
1203        Ok(norms)
1204    }
1205
1206    /// Compute norms for BF16 data
1207    fn compute_norms_bf16(&mut self, data: &Array2<f32>) -> SpatialResult<Array1<f32>> {
1208        TensorCoreDistanceMatrix::compute_norms_fp16(data)
1209    }
1210
1211    /// Analyze data range for adaptive precision
1212    fn analyze_data_range(
1213        &self,
1214        points_i: &ArrayView2<'_, f64>,
1215        points_j: &ArrayView2<'_, f64>,
1216    ) -> f64 {
1217        let min_i = points_i.fold(f64::INFINITY, |acc, &x| acc.min(x));
1218        let max_i = points_i.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
1219        let min_j = points_j.fold(f64::INFINITY, |acc, &x| acc.min(x));
1220        let max_j = points_j.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
1221
1222        let overall_min = min_i.min(min_j);
1223        let overall_max = max_i.max(max_j);
1224
1225        overall_max - overall_min
1226    }
1227
1228    /// Estimate condition number for numerical stability
1229    fn estimate_condition_number(
1230        &self,
1231        points_i: &ArrayView2<'_, f64>,
1232        points_j: &ArrayView2<'_, f64>,
1233    ) -> f64 {
1234        // Simplified condition number estimation
1235        let data_range = self.analyze_data_range(points_i, points_j);
1236        let mean_i: f64 = points_i.sum() / (points_i.len() as f64);
1237        let mean_j: f64 = points_j.sum() / (points_j.len() as f64);
1238        let overall_mean = (mean_i + mean_j) / 2.0;
1239
1240        if overall_mean.abs() < 1e-10 {
1241            1e6 // High condition number for near-zero data
1242        } else {
1243            data_range / overall_mean.abs()
1244        }
1245    }
1246}
1247
/// K-means style clustering whose distance and centroid computations can be
/// routed through the (simulated) tensor core kernels.
///
/// Instances are created with `TensorCoreClustering::new` and configured with
/// the `with_*` builder methods before calling `fit`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct TensorCoreClustering {
    /// Number of clusters (k) to fit.
    _numclusters: usize,
    /// Precision mode handed to `TensorCoreDistanceMatrix` when distances are
    /// computed on the tensor core path.
    precision_mode: PrecisionMode,
    /// When `true`, `fit` uses the tensor core distance and centroid routines;
    /// otherwise the plain fallback implementations are used.
    tensor_cores: bool,
    /// Mixed precision flag set through `with_mixed_precision`.
    mixed_precision: bool,
    /// Dynamic precision scaling flag set through `with_dynamic_precision_scaling`.
    dynamic_precision: bool,
    /// Detected GPU capabilities; `None` when detection failed.
    capabilities: Option<TensorCoreCapabilities>,
}
1265
1266impl TensorCoreClustering {
1267    /// Create new tensor core clustering
1268    pub fn new(_numclusters: usize) -> SpatialResult<Self> {
1269        let capabilities = detect_tensor_core_capabilities().ok();
1270
1271        Ok(Self {
1272            _numclusters,
1273            precision_mode: PrecisionMode::Mixed16,
1274            tensor_cores: true,
1275            mixed_precision: true,
1276            dynamic_precision: false,
1277            capabilities,
1278        })
1279    }
1280
1281    /// Enable tensor cores
1282    pub fn with_tensor_cores(mut self, enabled: bool) -> Self {
1283        self.tensor_cores = enabled;
1284        self
1285    }
1286
1287    /// Enable mixed precision
1288    pub fn with_mixed_precision(mut self, enabled: bool) -> Self {
1289        self.mixed_precision = enabled;
1290        self
1291    }
1292
1293    /// Enable dynamic precision scaling
1294    pub fn with_dynamic_precision_scaling(mut self, enabled: bool) -> Self {
1295        self.dynamic_precision = enabled;
1296        self
1297    }
1298
1299    /// Fit clustering using tensor cores
1300    pub async fn fit(
1301        &mut self,
1302        points: &ArrayView2<'_, f64>,
1303    ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
1304        let (npoints, ndims) = points.dim();
1305
1306        if npoints < self._numclusters {
1307            return Err(SpatialError::InvalidInput(
1308                "Number of points must be >= number of clusters".to_string(),
1309            ));
1310        }
1311
1312        // Initialize centroids
1313        let mut centroids = self.initialize_centroids(points)?;
1314        let mut assignments = Array1::zeros(npoints);
1315
1316        // Tensor core k-means iterations
1317        for _iteration in 0..100 {
1318            // Compute distances using tensor cores
1319            let distance_matrix = if self.tensor_cores {
1320                let tensor_computer =
1321                    TensorCoreDistanceMatrix::new()?.with_precision_mode(self.precision_mode);
1322                tensor_computer
1323                    .compute_distances_to_centroids(points, &centroids.view())
1324                    .await?
1325            } else {
1326                self.compute_distances_fallback(points, &centroids.view())?
1327            };
1328
1329            // Update assignments
1330            let new_assignments = self.update_assignments(&distance_matrix)?;
1331
1332            // Update centroids using tensor core operations
1333            let new_centroids = if self.tensor_cores {
1334                self.update_centroids_tensor_cores(points, &new_assignments)
1335                    .await?
1336            } else {
1337                self.update_centroids_fallback(points, &new_assignments)?
1338            };
1339
1340            // Check convergence
1341            let centroid_change = self.compute_centroid_change(&centroids, &new_centroids);
1342            if centroid_change < 1e-6 {
1343                break;
1344            }
1345
1346            centroids = new_centroids;
1347            assignments = new_assignments;
1348        }
1349
1350        Ok((centroids, assignments))
1351    }
1352
1353    /// Initialize centroids using k-means++
1354    fn initialize_centroids(&mut self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
1355        let (npoints, ndims) = points.dim();
1356        let mut centroids = Array2::zeros((self._numclusters, ndims));
1357
1358        // k-means++ initialization
1359        let mut rng = rand::rng();
1360
1361        // Choose first centroid randomly
1362        let first_idx = rng.gen_range(0..npoints);
1363        centroids.row_mut(0).assign(&points.row(first_idx));
1364
1365        // Choose remaining centroids with probability proportional to distance
1366        for k in 1..self._numclusters {
1367            let mut distances = Array1::zeros(npoints);
1368
1369            for i in 0..npoints {
1370                let point = points.row(i);
1371                let mut min_dist = f64::INFINITY;
1372
1373                for j in 0..k {
1374                    let centroid = centroids.row(j);
1375                    let dist: f64 = point
1376                        .iter()
1377                        .zip(centroid.iter())
1378                        .map(|(&a, &b)| (a - b).powi(2))
1379                        .sum::<f64>();
1380                    min_dist = min_dist.min(dist);
1381                }
1382
1383                distances[i] = min_dist;
1384            }
1385
1386            // Choose next centroid with probability proportional to squared distance
1387            let total_dist: f64 = distances.sum();
1388            let mut cumulative = 0.0;
1389            let random_val = rand::random::<f64>() * total_dist;
1390
1391            for i in 0..npoints {
1392                cumulative += distances[i];
1393                if cumulative >= random_val {
1394                    centroids.row_mut(k).assign(&points.row(i));
1395                    break;
1396                }
1397            }
1398        }
1399
1400        Ok(centroids)
1401    }
1402
1403    /// Update assignments based on distance matrix
1404    fn update_assignments(
1405        &mut self,
1406        distance_matrix: &Array2<f64>,
1407    ) -> SpatialResult<Array1<usize>> {
1408        let npoints = distance_matrix.nrows();
1409        let mut assignments = Array1::zeros(npoints);
1410
1411        for i in 0..npoints {
1412            let mut min_dist = f64::INFINITY;
1413            let mut best_cluster = 0;
1414
1415            for j in 0..self._numclusters {
1416                if distance_matrix[[i, j]] < min_dist {
1417                    min_dist = distance_matrix[[i, j]];
1418                    best_cluster = j;
1419                }
1420            }
1421
1422            assignments[i] = best_cluster;
1423        }
1424
1425        Ok(assignments)
1426    }
1427
1428    /// Update centroids using tensor core operations
1429    async fn update_centroids_tensor_cores(
1430        &self,
1431        points: &ArrayView2<'_, f64>,
1432        assignments: &Array1<usize>,
1433    ) -> SpatialResult<Array2<f64>> {
1434        let (_npoints, ndims) = points.dim();
1435        let mut new_centroids = Array2::zeros((self._numclusters, ndims));
1436        let mut cluster_counts = vec![0; self._numclusters];
1437
1438        // Count points in each cluster
1439        for &cluster in assignments {
1440            cluster_counts[cluster] += 1;
1441        }
1442
1443        // Compute new centroids using tensor operations
1444        for cluster in 0..self._numclusters {
1445            if cluster_counts[cluster] == 0 {
1446                continue;
1447            }
1448
1449            // Create mask for points in this cluster
1450            let clusterpoints: Vec<usize> = assignments
1451                .iter()
1452                .enumerate()
1453                .filter(|(_, &c)| c == cluster)
1454                .map(|(i, _)| i)
1455                .collect();
1456
1457            // Extract cluster points
1458            let cluster_data = Array2::from_shape_fn((clusterpoints.len(), ndims), |(i, j)| {
1459                points[[clusterpoints[i], j]]
1460            });
1461
1462            // Compute mean using tensor operations (sum + scale)
1463            let sum_vector = self.tensor_sum_reduction(&cluster_data.view()).await?;
1464            let count = clusterpoints.len() as f64;
1465
1466            for j in 0..ndims {
1467                new_centroids[[cluster, j]] = sum_vector[j] / count;
1468            }
1469        }
1470
1471        Ok(new_centroids)
1472    }
1473
1474    /// Tensor sum reduction operation
1475    async fn tensor_sum_reduction(&self, data: &ArrayView2<'_, f64>) -> SpatialResult<Array1<f64>> {
1476        let (_npoints, ndims) = data.dim();
1477        let mut sum_vector = Array1::zeros(ndims);
1478
1479        // Simulate tensor reduction operation
1480        for j in 0..ndims {
1481            let column_sum: f64 = data.column(j).sum();
1482            sum_vector[j] = column_sum;
1483        }
1484
1485        Ok(sum_vector)
1486    }
1487
1488    /// Fallback distance computation without tensor cores
1489    fn compute_distances_fallback(
1490        &self,
1491        points: &ArrayView2<'_, f64>,
1492        centroids: &ArrayView2<'_, f64>,
1493    ) -> SpatialResult<Array2<f64>> {
1494        let (npoints, ndims) = points.dim();
1495        let (n_clusters_, _) = centroids.dim();
1496        let mut distances = Array2::zeros((npoints, n_clusters_));
1497
1498        for i in 0..npoints {
1499            for j in 0..n_clusters_ {
1500                let distance: f64 = points
1501                    .row(i)
1502                    .iter()
1503                    .zip(centroids.row(j).iter())
1504                    .map(|(&a, &b)| (a - b).powi(2))
1505                    .sum::<f64>()
1506                    .sqrt();
1507                distances[[i, j]] = distance;
1508            }
1509        }
1510
1511        Ok(distances)
1512    }
1513
1514    /// Fallback centroid update without tensor cores
1515    fn update_centroids_fallback(
1516        &self,
1517        points: &ArrayView2<'_, f64>,
1518        assignments: &Array1<usize>,
1519    ) -> SpatialResult<Array2<f64>> {
1520        let (npoints, ndims) = points.dim();
1521        let mut new_centroids = Array2::zeros((self._numclusters, ndims));
1522        let mut cluster_counts = vec![0; self._numclusters];
1523
1524        // Sum points for each cluster
1525        for i in 0..npoints {
1526            let cluster = assignments[i];
1527            cluster_counts[cluster] += 1;
1528
1529            for j in 0..ndims {
1530                new_centroids[[cluster, j]] += points[[i, j]];
1531            }
1532        }
1533
1534        // Compute means
1535        for cluster in 0..self._numclusters {
1536            if cluster_counts[cluster] > 0 {
1537                let count = cluster_counts[cluster] as f64;
1538                for j in 0..ndims {
1539                    new_centroids[[cluster, j]] /= count;
1540                }
1541            }
1542        }
1543
1544        Ok(new_centroids)
1545    }
1546
1547    /// Compute change in centroids for convergence checking
1548    fn compute_centroid_change(
1549        &self,
1550        old_centroids: &Array2<f64>,
1551        new_centroids: &Array2<f64>,
1552    ) -> f64 {
1553        let mut total_change = 0.0;
1554
1555        for i in 0..self._numclusters {
1556            let change: f64 = old_centroids
1557                .row(i)
1558                .iter()
1559                .zip(new_centroids.row(i).iter())
1560                .map(|(&a, &b)| (a - b).powi(2))
1561                .sum::<f64>()
1562                .sqrt();
1563            total_change += change;
1564        }
1565
1566        total_change / (self._numclusters as f64)
1567    }
1568}
1569
1570impl Default for StabilityMetrics {
1571    fn default() -> Self {
1572        Self::new()
1573    }
1574}
1575
1576impl StabilityMetrics {
1577    /// Create new stability metrics
1578    pub fn new() -> Self {
1579        Self {
1580            condition_number: 1.0,
1581            relative_error: 0.0,
1582            forward_error: 0.0,
1583            backward_error: 0.0,
1584            digit_loss: 0.0,
1585            stability_level: StabilityLevel::Excellent,
1586            error_types: Vec::new(),
1587            timestamp: Instant::now(),
1588        }
1589    }
1590
1591    /// Update stability level based on metrics
1592    pub fn update_stability_level(&mut self) {
1593        self.stability_level = if self.condition_number > 1e12 || self.relative_error > 1e-3 {
1594            StabilityLevel::Critical
1595        } else if self.condition_number > 1e8 || self.relative_error > 1e-6 {
1596            StabilityLevel::Poor
1597        } else if self.condition_number > 1e4 || self.relative_error > 1e-9 {
1598            StabilityLevel::Moderate
1599        } else if self.condition_number > 1e2 || self.relative_error > 1e-12 {
1600            StabilityLevel::Good
1601        } else {
1602            StabilityLevel::Excellent
1603        };
1604    }
1605
1606    /// Check for numerical errors
1607    pub fn detect_errors(&mut self, data: &Array2<f64>) {
1608        self.error_types.clear();
1609
1610        // Check for NaN or Inf values
1611        for &value in data.iter() {
1612            if !value.is_finite() {
1613                self.error_types.push(NumericalErrorType::InvalidValues);
1614                break;
1615            }
1616        }
1617
1618        // Check for overflow/underflow
1619        let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1620        if max_val > 1e100 {
1621            self.error_types.push(NumericalErrorType::Overflow);
1622        } else if max_val < 1e-100 && max_val > 0.0 {
1623            self.error_types.push(NumericalErrorType::Underflow);
1624        }
1625
1626        // Check for precision loss
1627        if self.digit_loss > 6.0 {
1628            self.error_types.push(NumericalErrorType::PrecisionLoss);
1629        }
1630
1631        // Check for ill-conditioning
1632        if self.condition_number > 1e12 {
1633            self.error_types.push(NumericalErrorType::IllConditioned);
1634        }
1635    }
1636}
1637
1638impl Default for DynamicPrecisionConfig {
1639    fn default() -> Self {
1640        Self {
1641            strategy: ScalingStrategy::Balanced,
1642            min_precision: PrecisionMode::Int8Dynamic,
1643            max_precision: PrecisionMode::Full32,
1644            stability_threshold_up: 1e-6,
1645            stability_threshold_down: 1e-9,
1646            performance_weight: 0.6,
1647            accuracy_weight: 0.4,
1648            max_changes_per_operation: 3,
1649            change_cooldown: Duration::from_millis(100),
1650        }
1651    }
1652}
1653
1654impl NumericalStabilityMonitor {
1655    /// Create new stability monitor
1656    pub fn new(config: DynamicPrecisionConfig) -> Self {
1657        Self {
1658            current_metrics: StabilityMetrics::new(),
1659            stability_history: VecDeque::new(),
1660            precision_config: config,
1661            current_precision: PrecisionMode::Mixed16,
1662            precision_history: VecDeque::new(),
1663            recovery_attempts: 0,
1664            max_history_length: 1000,
1665            last_precision_change: None,
1666        }
1667    }
1668
    /// Refresh the current stability metrics from an input/result pair and
    /// record them in the history.
    ///
    /// `data` is the input of the computation and `computation_result` its
    /// output. The metrics are filled in by the estimator helpers, then the
    /// stability level and the list of detected error types are derived from
    /// them, and a snapshot is appended to `stability_history`.
    ///
    /// The statements below are order-sensitive: the level and the error list
    /// read fields assigned by the preceding lines.
    ///
    /// As written this function always returns `Ok(())`.
    pub fn monitor_stability(
        &mut self,
        data: &Array2<f64>,
        computation_result: &Array2<f64>,
    ) -> SpatialResult<()> {
        // Compute condition number estimate
        self.current_metrics.condition_number =
            NumericalStabilityMonitor::estimate_condition_number(data);

        // Estimate relative error
        self.current_metrics.relative_error =
            self.estimate_relative_error(data, computation_result);

        // Compute forward and backward error bounds
        self.current_metrics.forward_error = self.estimate_forward_error(data, computation_result);
        self.current_metrics.backward_error =
            self.estimate_backward_error(data, computation_result);

        // Estimate digit loss
        self.current_metrics.digit_loss = self.estimate_digit_loss();

        // Update stability level (reads condition_number and relative_error)
        self.current_metrics.update_stability_level();

        // Detect errors in the result (also reads digit_loss and condition_number)
        self.current_metrics.detect_errors(computation_result);

        // Update timestamp
        self.current_metrics.timestamp = Instant::now();

        // Add a snapshot to the history; one entry is pushed per call, so
        // dropping a single oldest entry keeps the length at the cap.
        self.stability_history
            .push_back(self.current_metrics.clone());
        if self.stability_history.len() > self.max_history_length {
            self.stability_history.pop_front();
        }

        Ok(())
    }
1709
1710    /// Dynamically adjust precision based on stability
1711    pub fn adjust_precision(&mut self) -> SpatialResult<PrecisionMode> {
1712        // Check cooldown period
1713        if let Some(last_change) = self.last_precision_change {
1714            if last_change.elapsed() < self.precision_config.change_cooldown {
1715                return Ok(self.current_precision);
1716            }
1717        }
1718
1719        let new_precision = match self.current_metrics.stability_level {
1720            StabilityLevel::Critical => {
1721                // Use highest precision for critical stability
1722                self.precision_config.max_precision
1723            }
1724            StabilityLevel::Poor => {
1725                // Increase precision
1726                NumericalStabilityMonitor::increase_precision(self.current_precision)
1727            }
1728            StabilityLevel::Moderate => {
1729                // Maintain current precision or slightly adjust
1730                if self.current_metrics.relative_error
1731                    > self.precision_config.stability_threshold_up
1732                {
1733                    NumericalStabilityMonitor::increase_precision(self.current_precision)
1734                } else {
1735                    self.current_precision
1736                }
1737            }
1738            StabilityLevel::Good => {
1739                // Can potentially decrease precision for performance
1740                if self.current_metrics.relative_error
1741                    < self.precision_config.stability_threshold_down
1742                {
1743                    NumericalStabilityMonitor::decrease_precision(self.current_precision)
1744                } else {
1745                    self.current_precision
1746                }
1747            }
1748            StabilityLevel::Excellent => {
1749                // Use lowest precision for maximum performance
1750                if self.precision_config.strategy == ScalingStrategy::Aggressive {
1751                    self.precision_config.min_precision
1752                } else {
1753                    NumericalStabilityMonitor::decrease_precision(self.current_precision)
1754                }
1755            }
1756        };
1757
1758        // Update precision if changed
1759        if new_precision != self.current_precision {
1760            self.precision_history.push_back((
1761                Instant::now(),
1762                new_precision,
1763                self.current_metrics.relative_error,
1764            ));
1765            self.current_precision = new_precision;
1766            self.last_precision_change = Some(Instant::now());
1767        }
1768
1769        Ok(new_precision)
1770    }
1771
1772    /// Increase precision mode
1773    fn increase_precision(current: PrecisionMode) -> PrecisionMode {
1774        match current {
1775            PrecisionMode::Int4Advanced => PrecisionMode::Int8Dynamic,
1776            PrecisionMode::Int8Dynamic => PrecisionMode::Mixed16,
1777            PrecisionMode::Mixed16 => PrecisionMode::BrainFloat16,
1778            PrecisionMode::BrainFloat16 => PrecisionMode::Full32,
1779            PrecisionMode::Full32 => PrecisionMode::Full32, // Already at max
1780            _ => PrecisionMode::Mixed16,
1781        }
1782    }
1783
1784    /// Decrease precision mode
1785    fn decrease_precision(current: PrecisionMode) -> PrecisionMode {
1786        match current {
1787            PrecisionMode::Full32 => PrecisionMode::BrainFloat16,
1788            PrecisionMode::BrainFloat16 => PrecisionMode::Mixed16,
1789            PrecisionMode::Mixed16 => PrecisionMode::Int8Dynamic,
1790            PrecisionMode::Int8Dynamic => PrecisionMode::Int4Advanced,
1791            PrecisionMode::Int4Advanced => PrecisionMode::Int4Advanced, // Already at min
1792            _ => PrecisionMode::Mixed16,
1793        }
1794    }
1795
1796    /// Estimate condition number
1797    fn estimate_condition_number(data: &Array2<f64>) -> f64 {
1798        // Simplified condition number estimation
1799        let max_val = data.fold(0.0f64, |acc, &x| acc.max(x.abs()));
1800        let min_val = data.fold(f64::INFINITY, |acc, &x| {
1801            if x.abs() > 1e-15 {
1802                acc.min(x.abs())
1803            } else {
1804                acc
1805            }
1806        });
1807
1808        if min_val.is_finite() && min_val > 0.0 {
1809            max_val / min_val
1810        } else {
1811            1e12 // High condition number for near-singular cases
1812        }
1813    }
1814
1815    /// Estimate relative error
1816    fn estimate_relative_error(&mut self, input: &Array2<f64>, output: &Array2<f64>) -> f64 {
1817        // Simplified relative error estimation
1818        let mean_val = output.mean().unwrap_or(0.0);
1819        if mean_val.abs() > 1e-15 {
1820            // Use machine epsilon scaled by condition number
1821            let machine_eps = match self.current_precision {
1822                PrecisionMode::Full32 => 2.22e-16,
1823                PrecisionMode::Mixed16 | PrecisionMode::BrainFloat16 => 9.77e-4,
1824                PrecisionMode::Int8Dynamic => 1.0 / 256.0,
1825                PrecisionMode::Int4Advanced => 1.0 / 16.0,
1826                _ => 1e-6,
1827            };
1828            machine_eps * self.current_metrics.condition_number
1829        } else {
1830            0.0
1831        }
1832    }
1833
1834    /// Estimate forward error
1835    fn estimate_forward_error(&mut self, _input: &Array2<f64>, output: &Array2<f64>) -> f64 {
1836        // Forward error bound estimate
1837        self.current_metrics.relative_error * self.current_metrics.condition_number
1838    }
1839
1840    /// Estimate backward error
1841    fn estimate_backward_error(&mut self, _input: &Array2<f64>, output: &Array2<f64>) -> f64 {
1842        // Backward error bound estimate
1843        self.current_metrics.relative_error
1844    }
1845
1846    /// Estimate digit loss
1847    fn estimate_digit_loss(&self) -> f64 {
1848        if self.current_metrics.condition_number > 1.0 {
1849            self.current_metrics.condition_number.log10().max(0.0)
1850        } else {
1851            0.0
1852        }
1853    }
1854}
1855
1856impl Default for ErrorRecoverySystem {
1857    fn default() -> Self {
1858        Self::new()
1859    }
1860}
1861
1862impl ErrorRecoverySystem {
1863    /// Create new error recovery system
1864    pub fn new() -> Self {
1865        let mut recovery_strategies = HashMap::new();
1866
1867        // Define recovery strategies for each error type
1868        recovery_strategies.insert(
1869            NumericalErrorType::Overflow,
1870            vec![
1871                RecoveryAction::IncreasePrecision,
1872                RecoveryAction::ReduceTileSize,
1873                RecoveryAction::NumericalStabilization,
1874            ],
1875        );
1876        recovery_strategies.insert(
1877            NumericalErrorType::Underflow,
1878            vec![
1879                RecoveryAction::IncreasePrecision,
1880                RecoveryAction::NumericalStabilization,
1881            ],
1882        );
1883        recovery_strategies.insert(
1884            NumericalErrorType::PrecisionLoss,
1885            vec![
1886                RecoveryAction::IncreasePrecision,
1887                RecoveryAction::RetryWithNewParams,
1888            ],
1889        );
1890        recovery_strategies.insert(
1891            NumericalErrorType::IllConditioned,
1892            vec![
1893                RecoveryAction::IncreasePrecision,
1894                RecoveryAction::NumericalStabilization,
1895                RecoveryAction::SwitchToCPU,
1896            ],
1897        );
1898        recovery_strategies.insert(
1899            NumericalErrorType::InvalidValues,
1900            vec![
1901                RecoveryAction::FallbackAlgorithm,
1902                RecoveryAction::SwitchToCPU,
1903            ],
1904        );
1905
1906        Self {
1907            recovery_strategies,
1908            recovery_history: VecDeque::new(),
1909            max_recovery_attempts: 3,
1910            success_rates: HashMap::new(),
1911        }
1912    }
1913
1914    /// Attempt recovery from numerical error
1915    pub async fn attempt_recovery(
1916        &mut self,
1917        error_type: NumericalErrorType,
1918    ) -> SpatialResult<RecoveryAction> {
1919        let start_time = Instant::now();
1920
1921        // Get recovery strategies for this error _type
1922        let strategies = self
1923            .recovery_strategies
1924            .get(&error_type)
1925            .ok_or_else(|| SpatialError::InvalidInput("Unknown error _type".to_string()))?
1926            .clone(); // Clone to avoid borrowing conflict
1927
1928        // Choose best strategy based on success rates
1929        let best_action = self.choose_best_recovery_action(&strategies);
1930
1931        // Record recovery attempt
1932        let attempt = RecoveryAttempt {
1933            error_type,
1934            action: best_action,
1935            success: false, // Will be updated after actual recovery
1936            duration: start_time.elapsed(),
1937            post_recovery_metrics: None,
1938            timestamp: start_time,
1939        };
1940
1941        self.recovery_history.push_back(attempt);
1942
1943        Ok(best_action)
1944    }
1945
1946    /// Choose best recovery action based on success rates
1947    fn choose_best_recovery_action(&mut self, strategies: &[RecoveryAction]) -> RecoveryAction {
1948        strategies
1949            .iter()
1950            .max_by(|&a, &b| {
1951                let rate_a = self.success_rates.get(a).unwrap_or(&0.5);
1952                let rate_b = self.success_rates.get(b).unwrap_or(&0.5);
1953                rate_a
1954                    .partial_cmp(rate_b)
1955                    .unwrap_or(std::cmp::Ordering::Equal)
1956            })
1957            .copied()
1958            .unwrap_or(RecoveryAction::IncreasePrecision)
1959    }
1960
1961    /// Update success rate for recovery action
1962    pub fn update_success_rate(&mut self, action: RecoveryAction, success: bool) {
1963        let current_rate = self.success_rates.get(&action).unwrap_or(&0.5);
1964        let new_rate = if success {
1965            current_rate * 0.9 + 0.1 // Exponential moving average
1966        } else {
1967            current_rate * 0.9
1968        };
1969        self.success_rates.insert(action, new_rate);
1970    }
1971}
1972
1973impl PerformanceAccuracyAnalyzer {
1974    /// Create new performance-accuracy analyzer
1975    pub fn new(params: TradeOffParams) -> Self {
1976        Self {
1977            performance_data: HashMap::new(),
1978            accuracy_data: HashMap::new(),
1979            optimization_params: params,
1980            pareto_frontier: Vec::new(),
1981        }
1982    }
1983
1984    /// Record performance measurement
1985    pub fn record_performance(&mut self, precision: PrecisionMode, duration: Duration) {
1986        self.performance_data
1987            .entry(precision)
1988            .or_default()
1989            .push_back(duration);
1990
1991        // Maintain reasonable history size
1992        if let Some(history) = self.performance_data.get_mut(&precision) {
1993            if history.len() > 100 {
1994                history.pop_front();
1995            }
1996        }
1997    }
1998
1999    /// Record accuracy measurement
2000    pub fn record_accuracy(&mut self, precision: PrecisionMode, accuracy: f64) {
2001        self.accuracy_data
2002            .entry(precision)
2003            .or_default()
2004            .push_back(accuracy);
2005
2006        // Maintain reasonable history size
2007        if let Some(history) = self.accuracy_data.get_mut(&precision) {
2008            if history.len() > 100 {
2009                history.pop_front();
2010            }
2011        }
2012    }
2013
2014    /// Optimize precision mode based on trade-offs
2015    pub fn optimize_precision(&mut self) -> PrecisionMode {
2016        self.update_pareto_frontier();
2017
2018        match self.optimization_params.objective {
2019            OptimizationObjective::MaxPerformance => self
2020                .pareto_frontier
2021                .iter()
2022                .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
2023                .map(|(_a, b, mode)| *mode)
2024                .unwrap_or(PrecisionMode::Mixed16),
2025            OptimizationObjective::MaxAccuracy => self
2026                .pareto_frontier
2027                .iter()
2028                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
2029                .map(|(_a, b, mode)| *mode)
2030                .unwrap_or(PrecisionMode::Full32),
2031            OptimizationObjective::Balanced => {
2032                // Weighted combination - compute scores first to avoid borrowing conflict
2033                let mut best_score = f64::NEG_INFINITY;
2034                let mut best_mode = PrecisionMode::Mixed16;
2035
2036                // Extract weights to avoid borrowing conflict
2037                let performance_weight = self.optimization_params.performance_weight;
2038                let accuracy_weight = self.optimization_params.accuracy_weight;
2039
2040                for &(perf, acc, mode) in &self.pareto_frontier {
2041                    // Inline compute_weighted_score logic to avoid borrowing conflict
2042                    let perf_score = 1.0 / (perf + 1e-9);
2043                    let score = performance_weight * perf_score + accuracy_weight * acc;
2044                    if score > best_score {
2045                        best_score = score;
2046                        best_mode = mode;
2047                    }
2048                }
2049
2050                best_mode
2051            }
2052            _ => PrecisionMode::Mixed16,
2053        }
2054    }
2055
2056    /// Update Pareto frontier
2057    fn update_pareto_frontier(&mut self) {
2058        self.pareto_frontier.clear();
2059
2060        for precision in [
2061            PrecisionMode::Full32,
2062            PrecisionMode::BrainFloat16,
2063            PrecisionMode::Mixed16,
2064            PrecisionMode::Int8Dynamic,
2065            PrecisionMode::Int4Advanced,
2066        ] {
2067            if let (Some(perf_data), Some(acc_data)) = (
2068                self.performance_data.get(&precision),
2069                self.accuracy_data.get(&precision),
2070            ) {
2071                if !perf_data.is_empty() && !acc_data.is_empty() {
2072                    let avg_perf = perf_data.iter().map(|d| d.as_secs_f64()).sum::<f64>()
2073                        / perf_data.len() as f64;
2074                    let avg_acc = acc_data.iter().sum::<f64>() / acc_data.len() as f64;
2075
2076                    self.pareto_frontier.push((avg_perf, avg_acc, precision));
2077                }
2078            }
2079        }
2080    }
2081
2082    /// Compute weighted score for balanced optimization
2083    #[allow(dead_code)]
2084    fn compute_weighted_score(&mut self, performance: f64, accuracy: f64) -> f64 {
2085        // Performance score (inverse of time - higher is better)
2086        let perf_score = 1.0 / (performance + 1e-9);
2087
2088        // Weighted combination
2089        self.optimization_params.performance_weight * perf_score
2090            + self.optimization_params.accuracy_weight * accuracy
2091    }
2092}
2093
2094impl AdvancedTensorCoreDistanceMatrix {
2095    /// Create new advanced tensor core distance matrix computer
2096    pub fn new() -> SpatialResult<Self> {
2097        let base_computer = TensorCoreDistanceMatrix::new()?;
2098        let precision_config = DynamicPrecisionConfig::default();
2099        let stability_monitor =
2100            Arc::new(Mutex::new(NumericalStabilityMonitor::new(precision_config)));
2101        let recovery_system = ErrorRecoverySystem::new();
2102        let trade_off_params = TradeOffParams {
2103            performance_weight: 0.6,
2104            accuracy_weight: 0.4,
2105            energy_weight: 0.0,
2106            min_accuracy: 0.95,
2107            max_time: Duration::from_secs(30),
2108            objective: OptimizationObjective::Balanced,
2109        };
2110        let performance_analyzer = PerformanceAccuracyAnalyzer::new(trade_off_params);
2111
2112        Ok(Self {
2113            base_computer,
2114            stability_monitor,
2115            recovery_system,
2116            performance_analyzer,
2117            dynamic_precision_enabled: true,
2118            auto_recovery_enabled: true,
2119        })
2120    }
2121
2122    /// Configure dynamic precision scaling
2123    pub fn with_dynamic_precision(mut self, enabled: bool) -> Self {
2124        self.dynamic_precision_enabled = enabled;
2125        self
2126    }
2127
2128    /// Configure automatic error recovery
2129    pub fn with_auto_recovery(mut self, enabled: bool) -> Self {
2130        self.auto_recovery_enabled = enabled;
2131        self
2132    }
2133
2134    /// Compute distance matrix with advanced stability monitoring
2135    pub async fn compute_with_stability_monitoring(
2136        &mut self,
2137        points: &ArrayView2<'_, f64>,
2138    ) -> SpatialResult<Array2<f64>> {
2139        let start_time = Instant::now();
2140
2141        // Initial stability assessment
2142        {
2143            let mut monitor = self.stability_monitor.lock().unwrap();
2144            // Skip initial stability check as we don't have a result yet
2145
2146            if self.dynamic_precision_enabled {
2147                let optimal_precision = monitor.adjust_precision()?;
2148                self.base_computer.precision_mode = optimal_precision;
2149            }
2150        }
2151
2152        let mut result = None;
2153        let mut recovery_attempts = 0;
2154        let max_attempts = 3;
2155
2156        while result.is_none() && recovery_attempts < max_attempts {
2157            match self.base_computer.compute_parallel(points).await {
2158                Ok(distances) => {
2159                    // Monitor stability of result
2160                    {
2161                        let mut monitor = self.stability_monitor.lock().unwrap();
2162                        monitor.monitor_stability(&points.to_owned(), &distances)?;
2163                    }
2164
2165                    // Check for numerical errors
2166                    let stability_level = {
2167                        let monitor = self.stability_monitor.lock().unwrap();
2168                        monitor.current_metrics.stability_level
2169                    };
2170
2171                    if stability_level == StabilityLevel::Critical && self.auto_recovery_enabled {
2172                        // Attempt recovery
2173                        recovery_attempts += 1;
2174                        let recovery_action = self
2175                            .recovery_system
2176                            .attempt_recovery(NumericalErrorType::IllConditioned)
2177                            .await?;
2178
2179                        // Apply recovery action
2180                        self.apply_recovery_action(recovery_action).await?;
2181                        continue;
2182                    } else {
2183                        result = Some(distances);
2184                    }
2185                }
2186                Err(e) => {
2187                    if self.auto_recovery_enabled && recovery_attempts < max_attempts {
2188                        recovery_attempts += 1;
2189                        let recovery_action = self
2190                            .recovery_system
2191                            .attempt_recovery(NumericalErrorType::InvalidValues)
2192                            .await?;
2193                        self.apply_recovery_action(recovery_action).await?;
2194                        continue;
2195                    } else {
2196                        return Err(e);
2197                    }
2198                }
2199            }
2200        }
2201
2202        let final_result = result.ok_or_else(|| {
2203            SpatialError::InvalidInput(
2204                "Failed to compute stable result after recovery attempts".to_string(),
2205            )
2206        })?;
2207
2208        // Record performance data
2209        let duration = start_time.elapsed();
2210        let precision = self.base_computer.precision_mode;
2211        self.performance_analyzer
2212            .record_performance(precision, duration);
2213
2214        // Estimate accuracy (simplified)
2215        let accuracy = self.estimate_result_accuracy(&final_result);
2216        self.performance_analyzer
2217            .record_accuracy(precision, accuracy);
2218
2219        Ok(final_result)
2220    }
2221
2222    /// Apply recovery action
2223    async fn apply_recovery_action(&mut self, action: RecoveryAction) -> SpatialResult<()> {
2224        match action {
2225            RecoveryAction::IncreasePrecision => {
2226                self.base_computer.precision_mode = match self.base_computer.precision_mode {
2227                    PrecisionMode::Int4Advanced => PrecisionMode::Int8Dynamic,
2228                    PrecisionMode::Int8Dynamic => PrecisionMode::Mixed16,
2229                    PrecisionMode::Mixed16 => PrecisionMode::BrainFloat16,
2230                    PrecisionMode::BrainFloat16 => PrecisionMode::Full32,
2231                    PrecisionMode::Full32 => PrecisionMode::Full32,
2232                    _ => PrecisionMode::Mixed16,
2233                };
2234            }
2235            RecoveryAction::ReduceTileSize => {
2236                let (current_row, current_col) = self.base_computer.tile_size;
2237                self.base_computer.tile_size = (current_row / 2, current_col / 2);
2238                if self.base_computer.tile_size.0 < 16 {
2239                    self.base_computer.tile_size = (16, 16);
2240                }
2241            }
2242            RecoveryAction::FallbackAlgorithm => {
2243                // Switch to more conservative settings
2244                self.base_computer.precision_mode = PrecisionMode::Full32;
2245                self.base_computer.hierarchical_tiling = false;
2246            }
2247            RecoveryAction::NumericalStabilization => {
2248                // Apply numerical stabilization techniques
2249                self.base_computer.precision_mode = PrecisionMode::Full32;
2250                self.base_computer.tile_size = (64, 64);
2251            }
2252            _ => {
2253                // Default recovery
2254                self.base_computer.precision_mode = PrecisionMode::Full32;
2255            }
2256        }
2257
2258        Ok(())
2259    }
2260
2261    /// Estimate result accuracy (simplified)
2262    fn estimate_result_accuracy(&self, result: &Array2<f64>) -> f64 {
2263        // Simplified accuracy estimation based on numerical properties
2264        let has_invalid = result.iter().any(|&x| !x.is_finite());
2265        if has_invalid {
2266            return 0.0;
2267        }
2268
2269        let max_val = result.fold(0.0f64, |acc, &x| acc.max(x.abs()));
2270        let min_val = result.fold(f64::INFINITY, |acc, &x| {
2271            if x.abs() > 1e-15 {
2272                acc.min(x.abs())
2273            } else {
2274                acc
2275            }
2276        });
2277
2278        if min_val.is_finite() && min_val > 0.0 {
2279            let dynamic_range = max_val / min_val;
2280            (1.0 / (1.0 + dynamic_range.log10() / 10.0)).clamp(0.0, 1.0)
2281        } else {
2282            0.95 // Default good accuracy
2283        }
2284    }
2285}
2286
2287/// Detect tensor core capabilities of available GPU hardware
2288#[allow(dead_code)]
2289pub fn detect_tensor_core_capabilities() -> SpatialResult<TensorCoreCapabilities> {
2290    // Simulate hardware detection
2291    // In a real implementation, this would use CUDA/ROCm/OpenCL APIs
2292
2293    Ok(TensorCoreCapabilities {
2294        tensor_core_types: vec![
2295            TensorCoreType::NvidiaTensorCore,
2296            TensorCoreType::StandardCores,
2297        ],
2298        supported_precisions: vec![
2299            PrecisionMode::Full32,
2300            PrecisionMode::Mixed16,
2301            PrecisionMode::BrainFloat16,
2302            PrecisionMode::Int8Dynamic,
2303        ],
2304        max_tensor_size: (4096, 4096, 4096),
2305        peak_throughput_tops: 312.0,   // A100 FP16 performance
2306        memory_bandwidth_gbps: 1555.0, // A100 HBM2 bandwidth
2307        l2_cache_mb: 40.0,
2308        num_sms: 108,
2309        architecture: GpuArchitecture::Ampere,
2310    })
2311}
2312
2313/// Extension trait for TensorCoreDistanceMatrix
2314impl TensorCoreDistanceMatrix {
2315    /// Compute distances from points to centroids
2316    pub async fn compute_distances_to_centroids(
2317        &self,
2318        points: &ArrayView2<'_, f64>,
2319        centroids: &ArrayView2<'_, f64>,
2320    ) -> SpatialResult<Array2<f64>> {
2321        let (npoints, ndims) = points.dim();
2322        let (n_clusters_, n_dims_c) = centroids.dim();
2323        let mut distances = Array2::zeros((npoints, n_clusters_));
2324
2325        // Compute distances using optimized tensor operations
2326        for i in 0..npoints {
2327            for j in 0..n_clusters_ {
2328                let distance: f64 = points
2329                    .row(i)
2330                    .iter()
2331                    .zip(centroids.row(j).iter())
2332                    .map(|(&a, &b)| (a - b).powi(2))
2333                    .sum::<f64>()
2334                    .sqrt();
2335                distances[[i, j]] = distance;
2336            }
2337        }
2338
2339        Ok(distances)
2340    }
2341}
2342
2343#[cfg(test)]
2344mod tests {
2345    use super::*;
2346    use ndarray::array;
2347
2348    #[test]
2349    fn test_precision_mode() {
2350        assert_eq!(PrecisionMode::Mixed16, PrecisionMode::Mixed16);
2351        assert_ne!(PrecisionMode::Mixed16, PrecisionMode::Full32);
2352    }
2353
2354    #[test]
2355    fn test_tensor_core_capabilities() {
2356        let capabilities = detect_tensor_core_capabilities();
2357        assert!(capabilities.is_ok());
2358
2359        let caps = capabilities.unwrap();
2360        assert!(!caps.tensor_core_types.is_empty());
2361        assert!(!caps.supported_precisions.is_empty());
2362    }
2363
2364    #[test]
2365    fn test_tensor_core_distance_matrix_creation() {
2366        let result = TensorCoreDistanceMatrix::new();
2367        assert!(result.is_ok());
2368
2369        let matrix_computer = result.unwrap();
2370        assert_eq!(matrix_computer.precision_mode, PrecisionMode::Mixed16);
2371    }
2372
2373    #[test]
2374    fn test_tensor_core_clustering_creation() {
2375        let result = TensorCoreClustering::new(3);
2376        assert!(result.is_ok());
2377
2378        let clustering = result.unwrap();
2379        assert_eq!(clustering._numclusters, 3);
2380    }
2381
2382    #[tokio::test]
2383    async fn test_tensor_core_distance_computation() {
2384        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
2385        let mut matrix_computer = TensorCoreDistanceMatrix::new().unwrap();
2386
2387        let result = matrix_computer.compute_parallel(&points.view()).await;
2388        assert!(result.is_ok());
2389
2390        let distances = result.unwrap();
2391        assert_eq!(distances.shape(), &[3, 3]);
2392
2393        // Check diagonal is zero
2394        for i in 0..3 {
2395            assert!((distances[[i, i]]).abs() < 1e-10);
2396        }
2397    }
2398
2399    #[tokio::test]
2400    async fn test_tensor_core_clustering() {
2401        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
2402        let mut clustering = TensorCoreClustering::new(2).unwrap();
2403
2404        let result = clustering.fit(&points.view()).await;
2405        assert!(result.is_ok());
2406
2407        let (centroids, assignments) = result.unwrap();
2408        assert_eq!(centroids.shape(), &[2, 2]);
2409        assert_eq!(assignments.len(), 4);
2410    }
2411
2412    #[test]
2413    fn test_stability_metrics_creation() {
2414        let metrics = StabilityMetrics::new();
2415        assert_eq!(metrics.condition_number, 1.0);
2416        assert_eq!(metrics.relative_error, 0.0);
2417        assert_eq!(metrics.stability_level, StabilityLevel::Excellent);
2418        assert!(metrics.error_types.is_empty());
2419    }
2420
2421    #[test]
2422    fn test_stability_level_update() {
2423        let mut metrics = StabilityMetrics::new();
2424
2425        // Test critical stability
2426        metrics.condition_number = 1e15;
2427        metrics.update_stability_level();
2428        assert_eq!(metrics.stability_level, StabilityLevel::Critical);
2429
2430        // Test poor stability
2431        metrics.condition_number = 1e9;
2432        metrics.relative_error = 1e-7;
2433        metrics.update_stability_level();
2434        assert_eq!(metrics.stability_level, StabilityLevel::Poor);
2435
2436        // Test good stability
2437        metrics.condition_number = 1e3;
2438        metrics.relative_error = 1e-10;
2439        metrics.update_stability_level();
2440        assert_eq!(metrics.stability_level, StabilityLevel::Good);
2441    }
2442
2443    #[test]
2444    #[ignore]
2445    fn test_error_detection() {
2446        let mut metrics = StabilityMetrics::new();
2447
2448        // Test NaN detection
2449        let data_with_nan = array![[1.0, 2.0], [f64::NAN, 4.0]];
2450        metrics.detect_errors(&data_with_nan);
2451        assert!(metrics
2452            .error_types
2453            .contains(&NumericalErrorType::InvalidValues));
2454
2455        // Test overflow detection
2456        let data_with_overflow = array![[1e150, 2.0], [3.0, 4.0]];
2457        metrics.detect_errors(&data_with_overflow);
2458        assert!(metrics.error_types.contains(&NumericalErrorType::Overflow));
2459
2460        // Test underflow detection
2461        let data_with_underflow = array![[1e-150, 2.0], [3.0, 4.0]];
2462        metrics.detect_errors(&data_with_underflow);
2463        assert!(metrics.error_types.contains(&NumericalErrorType::Underflow));
2464    }
2465
2466    #[test]
2467    fn test_dynamic_precision_config() {
2468        let config = DynamicPrecisionConfig::default();
2469        assert_eq!(config.strategy, ScalingStrategy::Balanced);
2470        assert_eq!(config.min_precision, PrecisionMode::Int8Dynamic);
2471        assert_eq!(config.max_precision, PrecisionMode::Full32);
2472        assert_eq!(config.performance_weight, 0.6);
2473        assert_eq!(config.accuracy_weight, 0.4);
2474    }
2475
2476    #[test]
2477    fn test_numerical_stability_monitor_creation() {
2478        let config = DynamicPrecisionConfig::default();
2479        let monitor = NumericalStabilityMonitor::new(config);
2480
2481        assert_eq!(monitor.current_precision, PrecisionMode::Mixed16);
2482        assert!(monitor.stability_history.is_empty());
2483        assert_eq!(monitor.recovery_attempts, 0);
2484    }
2485
2486    #[test]
2487    fn test_precision_increase_decrease() {
2488        let config = DynamicPrecisionConfig::default();
2489        let monitor = NumericalStabilityMonitor::new(config);
2490
2491        // Test precision increase
2492        let increased = NumericalStabilityMonitor::increase_precision(PrecisionMode::Int8Dynamic);
2493        assert_eq!(increased, PrecisionMode::Mixed16);
2494
2495        let max_increased = NumericalStabilityMonitor::increase_precision(PrecisionMode::Full32);
2496        assert_eq!(max_increased, PrecisionMode::Full32); // Should stay at max
2497
2498        // Test precision decrease
2499        let decreased = NumericalStabilityMonitor::decrease_precision(PrecisionMode::Mixed16);
2500        assert_eq!(decreased, PrecisionMode::Int8Dynamic);
2501
2502        let min_decreased =
2503            NumericalStabilityMonitor::decrease_precision(PrecisionMode::Int4Advanced);
2504        assert_eq!(min_decreased, PrecisionMode::Int4Advanced); // Should stay at min
2505    }
2506
2507    #[test]
2508    fn test_condition_number_estimation() {
2509        let config = DynamicPrecisionConfig::default();
2510        let monitor = NumericalStabilityMonitor::new(config);
2511
2512        // Well-conditioned data
2513        let well_conditioned = array![[1.0, 2.0], [3.0, 4.0]];
2514        let condition_1 = NumericalStabilityMonitor::estimate_condition_number(&well_conditioned);
2515        assert!(condition_1 > 1.0 && condition_1 < 100.0);
2516
2517        // Ill-conditioned data (large range)
2518        let ill_conditioned = array![[1e-10, 2.0], [3.0, 1e10]];
2519        let condition_2 = NumericalStabilityMonitor::estimate_condition_number(&ill_conditioned);
2520        assert!(condition_2 > 1e15);
2521    }
2522
2523    #[test]
2524    fn test_error_recovery_system_creation() {
2525        let recovery_system = ErrorRecoverySystem::new();
2526
2527        // Check that recovery strategies are defined
2528        assert!(!recovery_system.recovery_strategies.is_empty());
2529        assert!(recovery_system
2530            .recovery_strategies
2531            .contains_key(&NumericalErrorType::Overflow));
2532        assert!(recovery_system
2533            .recovery_strategies
2534            .contains_key(&NumericalErrorType::IllConditioned));
2535        assert_eq!(recovery_system.max_recovery_attempts, 3);
2536    }
2537
2538    #[tokio::test]
2539    async fn test_recovery_action_selection() {
2540        let mut recovery_system = ErrorRecoverySystem::new();
2541
2542        let action = recovery_system
2543            .attempt_recovery(NumericalErrorType::Overflow)
2544            .await;
2545        assert!(action.is_ok());
2546
2547        let recovery_action = action.unwrap();
2548        assert!(matches!(
2549            recovery_action,
2550            RecoveryAction::IncreasePrecision
2551                | RecoveryAction::ReduceTileSize
2552                | RecoveryAction::NumericalStabilization
2553        ));
2554    }
2555
2556    #[test]
2557    fn test_success_rate_update() {
2558        let mut recovery_system = ErrorRecoverySystem::new();
2559
2560        // Test successful recovery
2561        recovery_system.update_success_rate(RecoveryAction::IncreasePrecision, true);
2562        let rate = recovery_system
2563            .success_rates
2564            .get(&RecoveryAction::IncreasePrecision);
2565        assert!(rate.is_some());
2566        assert!(*rate.unwrap() > 0.5);
2567
2568        // Test failed recovery
2569        recovery_system.update_success_rate(RecoveryAction::ReduceTileSize, false);
2570        let rate = recovery_system
2571            .success_rates
2572            .get(&RecoveryAction::ReduceTileSize);
2573        assert!(rate.is_some());
2574        assert!(*rate.unwrap() < 0.5);
2575    }
2576
2577    #[test]
2578    fn test_performance_accuracy_analyzer() {
2579        let params = TradeOffParams {
2580            performance_weight: 0.7,
2581            accuracy_weight: 0.3,
2582            energy_weight: 0.0,
2583            min_accuracy: 0.9,
2584            max_time: Duration::from_secs(10),
2585            objective: OptimizationObjective::Balanced,
2586        };
2587
2588        let mut analyzer = PerformanceAccuracyAnalyzer::new(params);
2589
2590        // Record some performance data
2591        analyzer.record_performance(PrecisionMode::Mixed16, Duration::from_millis(100));
2592        analyzer.record_performance(PrecisionMode::Full32, Duration::from_millis(200));
2593
2594        // Record some accuracy data
2595        analyzer.record_accuracy(PrecisionMode::Mixed16, 0.95);
2596        analyzer.record_accuracy(PrecisionMode::Full32, 0.99);
2597
2598        // Test optimization
2599        let optimal_precision = analyzer.optimize_precision();
2600        assert!(matches!(
2601            optimal_precision,
2602            PrecisionMode::Mixed16 | PrecisionMode::Full32
2603        ));
2604    }
2605
2606    #[test]
2607    fn test_pareto_frontier_update() {
2608        let params = TradeOffParams {
2609            performance_weight: 0.5,
2610            accuracy_weight: 0.5,
2611            energy_weight: 0.0,
2612            min_accuracy: 0.8,
2613            max_time: Duration::from_secs(5),
2614            objective: OptimizationObjective::Balanced,
2615        };
2616
2617        let mut analyzer = PerformanceAccuracyAnalyzer::new(params);
2618
2619        // Add data for multiple precision modes
2620        analyzer.record_performance(PrecisionMode::Int8Dynamic, Duration::from_millis(50));
2621        analyzer.record_accuracy(PrecisionMode::Int8Dynamic, 0.85);
2622
2623        analyzer.record_performance(PrecisionMode::Mixed16, Duration::from_millis(100));
2624        analyzer.record_accuracy(PrecisionMode::Mixed16, 0.95);
2625
2626        analyzer.record_performance(PrecisionMode::Full32, Duration::from_millis(200));
2627        analyzer.record_accuracy(PrecisionMode::Full32, 0.99);
2628
2629        analyzer.update_pareto_frontier();
2630        assert!(!analyzer.pareto_frontier.is_empty());
2631        assert_eq!(analyzer.pareto_frontier.len(), 3);
2632    }
2633
2634    #[test]
2635    fn test_weighted_score_computation() {
2636        let params = TradeOffParams {
2637            performance_weight: 0.6,
2638            accuracy_weight: 0.4,
2639            energy_weight: 0.0,
2640            min_accuracy: 0.8,
2641            max_time: Duration::from_secs(5),
2642            objective: OptimizationObjective::Custom,
2643        };
2644
2645        let mut analyzer = PerformanceAccuracyAnalyzer::new(params);
2646
2647        // Test different performance-accuracy combinations
2648        let score1 = analyzer.compute_weighted_score(0.1, 0.9); // Fast, accurate
2649        let score2 = analyzer.compute_weighted_score(0.2, 0.95); // Slower, more accurate
2650
2651        assert!(score1 > 0.0);
2652        assert!(score2 > 0.0);
2653    }
2654
2655    #[test]
2656    fn test_advanced_tensor_core_distance_matrix_creation() {
2657        let result = AdvancedTensorCoreDistanceMatrix::new();
2658        assert!(result.is_ok());
2659
2660        let advanced_computer = result.unwrap();
2661        assert!(advanced_computer.dynamic_precision_enabled);
2662        assert!(advanced_computer.auto_recovery_enabled);
2663    }
2664
2665    #[tokio::test]
2666    #[ignore]
2667    async fn test_stability_monitoring_computation() {
2668        let mut advanced_computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();
2669        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
2670
2671        let result = advanced_computer
2672            .compute_with_stability_monitoring(&points.view())
2673            .await;
2674        assert!(result.is_ok());
2675
2676        let distances = result.unwrap();
2677        assert_eq!(distances.shape(), &[3, 3]);
2678
2679        // Check that stability monitoring was performed
2680        let monitor = advanced_computer.stability_monitor.lock().unwrap();
2681        assert!(!monitor.stability_history.is_empty());
2682    }
2683
2684    #[tokio::test]
2685    async fn test_recovery_action_application() {
2686        let mut advanced_computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();
2687        let original_precision = advanced_computer.base_computer.precision_mode;
2688
2689        // Test precision increase recovery
2690        let result = advanced_computer
2691            .apply_recovery_action(RecoveryAction::IncreasePrecision)
2692            .await;
2693        assert!(result.is_ok());
2694
2695        // Precision should have increased (unless already at max)
2696        if original_precision != PrecisionMode::Full32 {
2697            assert_ne!(
2698                advanced_computer.base_computer.precision_mode,
2699                original_precision
2700            );
2701        }
2702
2703        // Test tile size reduction recovery
2704        let original_tile_size = advanced_computer.base_computer.tile_size;
2705        let result = advanced_computer
2706            .apply_recovery_action(RecoveryAction::ReduceTileSize)
2707            .await;
2708        assert!(result.is_ok());
2709
2710        let new_tile_size = advanced_computer.base_computer.tile_size;
2711        assert!(new_tile_size.0 <= original_tile_size.0);
2712        assert!(new_tile_size.1 <= original_tile_size.1);
2713    }
2714
2715    #[test]
2716    fn test_result_accuracy_estimation() {
2717        let advanced_computer = AdvancedTensorCoreDistanceMatrix::new().unwrap();
2718
2719        // Test with valid data
2720        let valid_result = array![[0.0, 1.0], [1.0, 0.0]];
2721        let accuracy = advanced_computer.estimate_result_accuracy(&valid_result);
2722        assert!(accuracy > 0.8 && accuracy <= 1.0);
2723
2724        // Test with invalid data (NaN)
2725        let invalid_result = array![[0.0, f64::NAN], [1.0, 0.0]];
2726        let accuracy = advanced_computer.estimate_result_accuracy(&invalid_result);
2727        assert_eq!(accuracy, 0.0);
2728
2729        // Test with high dynamic range data
2730        let high_range_result = array![[1e-10, 1e10], [1e5, 1e-5]];
2731        let accuracy = advanced_computer.estimate_result_accuracy(&high_range_result);
2732        assert!(accuracy > 0.0 && accuracy < 1.0);
2733    }
2734
2735    #[test]
2736    fn test_precision_mode_ordering() {
2737        // Test AdvancedAdaptive mode
2738        assert!(matches!(
2739            PrecisionMode::AdvancedAdaptive,
2740            PrecisionMode::AdvancedAdaptive
2741        ));
2742        assert_ne!(PrecisionMode::AdvancedAdaptive, PrecisionMode::Adaptive);
2743    }
2744
2745    #[test]
2746    fn test_stability_levels() {
2747        assert!(matches!(StabilityLevel::Critical, StabilityLevel::Critical));
2748        assert_ne!(StabilityLevel::Critical, StabilityLevel::Excellent);
2749    }
2750
2751    #[test]
2752    fn test_error_types() {
2753        let error_types = [
2754            NumericalErrorType::Overflow,
2755            NumericalErrorType::Underflow,
2756            NumericalErrorType::PrecisionLoss,
2757            NumericalErrorType::ConvergenceFailure,
2758            NumericalErrorType::IllConditioned,
2759            NumericalErrorType::InvalidValues,
2760        ];
2761
2762        assert_eq!(error_types.len(), 6);
2763        assert!(error_types.contains(&NumericalErrorType::Overflow));
2764    }
2765
2766    #[test]
2767    fn test_scaling_strategies() {
2768        let strategies = [
2769            ScalingStrategy::Conservative,
2770            ScalingStrategy::Balanced,
2771            ScalingStrategy::Aggressive,
2772            ScalingStrategy::Custom,
2773        ];
2774
2775        assert_eq!(strategies.len(), 4);
2776        assert!(strategies.contains(&ScalingStrategy::Balanced));
2777    }
2778
2779    #[test]
2780    fn test_recovery_actions() {
2781        let actions = [
2782            RecoveryAction::IncreasePrecision,
2783            RecoveryAction::ReduceTileSize,
2784            RecoveryAction::FallbackAlgorithm,
2785            RecoveryAction::NumericalStabilization,
2786            RecoveryAction::RetryWithNewParams,
2787            RecoveryAction::SwitchToCPU,
2788        ];
2789
2790        assert_eq!(actions.len(), 6);
2791        assert!(actions.contains(&RecoveryAction::IncreasePrecision));
2792    }
2793
2794    #[test]
2795    fn test_optimization_objectives() {
2796        let objectives = [
2797            OptimizationObjective::MaxPerformance,
2798            OptimizationObjective::MaxAccuracy,
2799            OptimizationObjective::Balanced,
2800            OptimizationObjective::MinEnergy,
2801            OptimizationObjective::Custom,
2802        ];
2803
2804        assert_eq!(objectives.len(), 5);
2805        assert!(objectives.contains(&OptimizationObjective::Balanced));
2806    }
2807}