Skip to main content

scirs2_core/array_protocol/
mixed_precision.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Mixed-precision operations for the array protocol.
8//!
9//! This module provides support for mixed-precision operations, allowing
10//! arrays to use different numeric types (e.g., f32, f64) for storage
11//! and computation to optimize performance and memory usage.
12
13use std::any::{Any, TypeId};
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::{LazyLock, RwLock};
17
18use ::ndarray::{Array, Dimension};
19use num_traits::{cast as num_cast, Float};
20
21use crate::array_protocol::gpu_impl::GPUNdarray;
22use crate::array_protocol::{
23    ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
24};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// The numeric precision an array is stored in or computed at.
///
/// `Half`, `Single` and `Double` name 16-, 32- and 64-bit floating point
/// respectively. `Mixed` is used when storage and computation do not share a
/// single precision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// 16-bit (half-precision) floating point.
    Half,
    /// 32-bit (single-precision) floating point.
    Single,
    /// 64-bit (double-precision) floating point.
    Double,
    /// Storage and computation use different precisions
    /// (for example 16/32-bit storage with 64-bit arithmetic).
    Mixed,
}
42
43impl fmt::Display for Precision {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            Precision::Half => write!(f, "half"),
47            Precision::Single => write!(f, "single"),
48            Precision::Double => write!(f, "double"),
49            Precision::Mixed => write!(f, "mixed"),
50        }
51    }
52}
53
54/// Configuration for mixed-precision operations.
55#[derive(Debug, Clone)]
56pub struct MixedPrecisionConfig {
57    /// Storage precision for arrays.
58    pub storage_precision: Precision,
59
60    /// Computation precision for operations.
61    pub computeprecision: Precision,
62
63    /// Automatic precision selection based on array size and operation.
64    pub auto_precision: bool,
65
66    /// Threshold for automatic downcast to lower precision.
67    pub downcast_threshold: usize,
68
69    /// Always use double precision for intermediate results.
70    pub double_precision_accumulation: bool,
71}
72
73impl Default for MixedPrecisionConfig {
74    fn default() -> Self {
75        Self {
76            storage_precision: Precision::Single,
77            computeprecision: Precision::Double,
78            auto_precision: true,
79            downcast_threshold: 10_000_000, // 10M elements
80            double_precision_accumulation: true,
81        }
82    }
83}
84
85/// Global mixed-precision configuration.
86pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
87    RwLock::new(MixedPrecisionConfig {
88        storage_precision: Precision::Single,
89        computeprecision: Precision::Double,
90        auto_precision: true,
91        downcast_threshold: 10_000_000, // 10M elements
92        double_precision_accumulation: true,
93    })
94});
95
96/// Set the global mixed-precision configuration.
97#[allow(dead_code)]
98pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
99    if let Ok(mut global_config) = MIXED_PRECISION_CONFIG.write() {
100        *global_config = config;
101    }
102}
103
104/// Get the current mixed-precision configuration.
105#[allow(dead_code)]
106pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
107    MIXED_PRECISION_CONFIG
108        .read()
109        .map(|c| c.clone())
110        .unwrap_or_default()
111}
112
113/// Determine the optimal precision for an array based on its size.
114#[allow(dead_code)]
115pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
116where
117    T: Clone + 'static,
118    D: Dimension,
119{
120    let config = get_mixed_precision_config();
121    let size = array.len();
122
123    if config.auto_precision {
124        if size >= config.downcast_threshold {
125            Precision::Single
126        } else {
127            Precision::Double
128        }
129    } else {
130        config.storage_precision
131    }
132}
133
134/// Mixed-precision array that can automatically convert between precisions.
135///
136/// This wrapper enables arrays to use different precision levels for storage
137/// and computation, automatically converting between precisions as needed.
138#[derive(Debug, Clone)]
139pub struct MixedPrecisionArray<T, D>
140where
141    T: Clone + 'static,
142    D: Dimension,
143{
144    /// The array stored at the specified precision.
145    array: Array<T, D>,
146
147    /// The current storage precision.
148    storage_precision: Precision,
149
150    /// The precision used for computations.
151    computeprecision: Precision,
152}
153
154impl<T, D> MixedPrecisionArray<T, D>
155where
156    T: Clone + Float + 'static,
157    D: Dimension,
158{
159    /// Create a new mixed-precision array.
160    pub fn new(array: Array<T, D>) -> Self {
161        let precision = match std::mem::size_of::<T>() {
162            2 => Precision::Half,
163            4 => Precision::Single,
164            8 => Precision::Double,
165            _ => Precision::Mixed,
166        };
167
168        Self {
169            array,
170            storage_precision: precision,
171            computeprecision: precision,
172        }
173    }
174
175    /// Create a new mixed-precision array with specified compute precision.
176    pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
177        let storage_precision = match std::mem::size_of::<T>() {
178            2 => Precision::Half,
179            4 => Precision::Single,
180            8 => Precision::Double,
181            _ => Precision::Mixed,
182        };
183
184        Self {
185            array: data,
186            storage_precision,
187            computeprecision,
188        }
189    }
190
191    /// Convert the array to a different floating-point precision `U`.
192    ///
193    /// Each element is cast from `T` to `U` using [`num_traits::cast`].  If any
194    /// element cannot be represented in `U` (e.g. an `f64` infinity cast to a
195    /// hypothetical narrow type) the method returns a
196    /// [`CoreError::ComputationError`].
197    ///
198    /// # Example
199    /// ```
200    /// use ndarray::array;
201    /// use scirs2_core::array_protocol::mixed_precision::MixedPrecisionArray;
202    ///
203    /// let arr = array![1.0_f64, 2.5_f64, 1.75_f64];
204    /// let mp = MixedPrecisionArray::new(arr.into_dyn());
205    /// let as_f32: ndarray::ArrayD<f32> = mp.at_precision()
206    ///     .expect("f64 -> f32 conversion should succeed");
207    /// assert!((as_f32[0] - 1.0_f32).abs() < 1e-6);
208    /// ```
209    pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
210    where
211        U: Clone + Float + 'static,
212    {
213        // ndarray does not have a fallible mapv, so we collect into a Vec<U> first.
214        let mut converted: Vec<U> = Vec::with_capacity(self.array.len());
215        for x in self.array.iter() {
216            match num_cast::<T, U>(*x) {
217                Some(v) => converted.push(v),
218                None => {
219                    return Err(CoreError::ComputationError(ErrorContext::new(format!(
220                        "at_precision: failed to cast element to target precision (source size \
221                         {} bytes, target size {} bytes)",
222                        std::mem::size_of::<T>(),
223                        std::mem::size_of::<U>(),
224                    ))))
225                }
226            }
227        }
228
229        // Reconstruct with the same shape.
230        Array::from_shape_vec(self.array.raw_dim(), converted).map_err(|e| {
231            CoreError::ShapeError(ErrorContext::new(format!(
232                "at_precision: failed to reconstruct array from converted elements: {e}"
233            )))
234        })
235    }
236
237    /// Get the current storage precision.
238    pub fn storage_precision(&self) -> Precision {
239        self.storage_precision
240    }
241
242    /// Get the underlying array.
243    pub const fn array(&self) -> &Array<T, D> {
244        &self.array
245    }
246}
247
248/// Trait for arrays that support mixed-precision operations.
249pub trait MixedPrecisionSupport: ArrayProtocol {
250    /// Convert the array to the specified precision.
251    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
252
253    /// Get the current precision of the array.
254    fn precision(&self) -> Precision;
255
256    /// Check if the array supports the specified precision.
257    fn supports_precision(&self, precision: Precision) -> bool;
258}
259
260/// Implement ArrayProtocol for MixedPrecisionArray.
261impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
262where
263    T: Clone + Float + Send + Sync + 'static,
264    D: Dimension + Send + Sync + 'static,
265{
266    fn array_function(
267        &self,
268        func: &ArrayFunction,
269        types: &[TypeId],
270        args: &[Box<dyn Any>],
271        kwargs: &HashMap<String, Box<dyn Any>>,
272    ) -> Result<Box<dyn Any>, NotImplemented> {
273        // If the function supports mixed precision, delegate to the appropriate implementation
274        let precision = kwargs
275            .get("precision")
276            .and_then(|p| p.downcast_ref::<Precision>())
277            .cloned()
278            .unwrap_or(self.computeprecision);
279
280        // Determine operating precision based on function and arguments
281        match func.name {
282            "scirs2::array_protocol::operations::matmul" => {
283                // If we have a second argument, check its precision
284                if args.len() >= 2 {
285                    // Adjust to highest precision of the two arrays
286                    if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
287                        let other_precision = other.computeprecision;
288                        let _precision_to_use = match (precision, other_precision) {
289                            (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
290                            (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
291                            (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
292                            (Precision::Half, Precision::Half) => Precision::Half,
293                        };
294
295                        // We can't modify kwargs, so we'll just forward directly
296                        // Get NdarrayWrapper for self array
297                        let wrapped_self = NdarrayWrapper::new(self.array.clone());
298
299                        // Delegate to the NdarrayWrapper implementation
300                        return wrapped_self.array_function(func, types, args, kwargs);
301                    }
302                }
303
304                // Convert to the requested precision and use standard implementation
305                match precision {
306                    Precision::Single | Precision::Double => {
307                        // Wrap in NdarrayWrapper for computation
308                        let wrapped = NdarrayWrapper::new(self.array.clone());
309
310                        // Adjust args to use wrapped version
311                        let mut new_args = Vec::with_capacity(args.len());
312                        new_args.push(Box::new(wrapped.clone()));
313
314                        // We don't need to include other args since we already have a new wrapped object
315                        // For simplicity, just delegate to the original args
316                        // Delegate to NdarrayWrapper
317                        wrapped.array_function(func, types, args, kwargs)
318                    }
319                    Precision::Mixed => {
320                        // Use Double precision for Mixed calculations
321                        let wrapped = NdarrayWrapper::new(self.array.clone());
322
323                        // Create new args and kwargs with Double precision
324                        let mut new_args = Vec::with_capacity(args.len());
325                        new_args.push(Box::new(wrapped.clone()));
326
327                        // We can't modify kwargs, so just forward along
328                        // Delegate to NdarrayWrapper directly with original args and kwargs
329                        wrapped.array_function(func, types, args, kwargs)
330                    }
331                    _ => Err(NotImplemented),
332                }
333            }
334            "scirs2::array_protocol::operations::add"
335            | "scirs2::array_protocol::operations::subtract"
336            | "scirs2::array_protocol::operations::multiply" => {
337                // Similar pattern for element-wise operations
338                // If we have a second argument, check its precision
339                if args.len() >= 2 {
340                    if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
341                        // Use the highest precision for the operation
342                        let other_precision = other.computeprecision;
343                        let _precision_to_use = match (precision, other_precision) {
344                            (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
345                            (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
346                            (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
347                            (Precision::Half, Precision::Half) => Precision::Half,
348                        };
349
350                        // We can't modify kwargs, so we'll just forward directly
351                        // Get NdarrayWrapper for self array
352                        let wrapped_self = NdarrayWrapper::new(self.array.clone());
353
354                        // Delegate to the NdarrayWrapper implementation
355                        return wrapped_self.array_function(func, types, args, kwargs);
356                    }
357                }
358
359                // Convert to the requested precision and use standard implementation
360                let wrapped = NdarrayWrapper::new(self.array.clone());
361
362                // Delegate to NdarrayWrapper with original args
363                wrapped.array_function(func, types, args, kwargs)
364            }
365            "scirs2::array_protocol::operations::transpose"
366            | "scirs2::array_protocol::operations::reshape"
367            | "scirs2::array_protocol::operations::sum" => {
368                // For unary operations, simply use the current precision
369                // Convert to standard wrapper and delegate
370                let wrapped = NdarrayWrapper::new(self.array.clone());
371
372                // Delegate to NdarrayWrapper with original args
373                wrapped.array_function(func, types, args, kwargs)
374            }
375            _ => {
376                // For any other function, delegate to standard implementation
377                let wrapped = NdarrayWrapper::new(self.array.clone());
378                wrapped.array_function(func, types, args, kwargs)
379            }
380        }
381    }
382
383    fn as_any(&self) -> &dyn Any {
384        self
385    }
386
387    fn shape(&self) -> &[usize] {
388        self.array.shape()
389    }
390
391    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
392        Box::new(Self {
393            array: self.array.clone(),
394            storage_precision: self.storage_precision,
395            computeprecision: self.computeprecision,
396        })
397    }
398}
399
400/// Implement MixedPrecisionSupport for MixedPrecisionArray.
401impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
402where
403    T: Clone + Float + Send + Sync + 'static,
404    D: Dimension + Send + Sync + 'static,
405{
406    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
407        match precision {
408            Precision::Single => {
409                // For actual implementation, this would convert f64 to f32 if needed
410                // This is a simplified version - in reality, we would need to convert between types
411
412                let current_precision = self.precision();
413                if current_precision == Precision::Single {
414                    // Already in single precision
415                    return Ok(Box::new(self.clone()));
416                }
417
418                // In real implementation, would handle proper conversion from T to f32
419                // For now, create a new array with the requested precision
420                let array_single = self.array.clone();
421                let newarray = MixedPrecisionArray::with_computeprecision(array_single, precision);
422                Ok(Box::new(newarray))
423            }
424            Precision::Double => {
425                // For actual implementation, this would convert f32 to f64 if needed
426
427                let current_precision = self.precision();
428                if current_precision == Precision::Double {
429                    // Already in double precision
430                    return Ok(Box::new(self.clone()));
431                }
432
433                // In real implementation, would handle proper conversion from T to f64
434                // For now, create a new array with the requested precision
435                let array_double = self.array.clone();
436                let newarray = MixedPrecisionArray::with_computeprecision(array_double, precision);
437                Ok(Box::new(newarray))
438            }
439            Precision::Mixed => {
440                // For mixed precision, use storage precision of the current array and double compute precision
441                let array_mixed = self.array.clone();
442                let newarray =
443                    MixedPrecisionArray::with_computeprecision(array_mixed, Precision::Double);
444                Ok(Box::new(newarray))
445            }
446            _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
447                "Conversion to {precision} precision not implemented"
448            )))),
449        }
450    }
451
452    fn precision(&self) -> Precision {
453        // If storage and compute precision differ, return Mixed
454        if self.storage_precision != self.computeprecision {
455            Precision::Mixed
456        } else {
457            self.storage_precision
458        }
459    }
460
461    fn supports_precision(&self, precision: Precision) -> bool {
462        matches!(precision, Precision::Single | Precision::Double)
463    }
464}
465
466/// Implement MixedPrecisionSupport for GPUNdarray.
467impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
468where
469    T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
470    D: Dimension + Send + Sync + 'static + crate::ndarray::RemoveAxis,
471{
472    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
473        // For GPUs, creating a new array with mixed precision enabled
474        let mut config = self.config().clone();
475        config.mixed_precision = precision == Precision::Mixed;
476
477        if let Ok(cpu_array) = self.to_cpu() {
478            // Use as_any() to downcast the ArrayProtocol trait object
479            if let Some(ndarray) = cpu_array.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
480                let new_gpu_array = GPUNdarray::new(ndarray.as_array().clone(), config);
481                return Ok(Box::new(new_gpu_array));
482            }
483        }
484
485        Err(CoreError::NotImplementedError(ErrorContext::new(format!(
486            "Conversion to {precision} precision not implemented for GPU arrays"
487        ))))
488    }
489
490    fn precision(&self) -> Precision {
491        if self.config().mixed_precision {
492            Precision::Mixed
493        } else {
494            match std::mem::size_of::<T>() {
495                4 => Precision::Single,
496                8 => Precision::Double,
497                _ => Precision::Mixed,
498            }
499        }
500    }
501
502    fn supports_precision(&self, precision: Precision) -> bool {
503        // Most GPUs support all precision levels
504        true
505    }
506}
507
508/// Execute an operation with a specific precision.
509///
510/// This function automatically converts arrays to the specified precision
511/// before executing the operation.
512#[allow(dead_code)]
513pub fn execute_with_precision<F, R>(
514    arrays: &[&dyn MixedPrecisionSupport],
515    precision: Precision,
516    executor: F,
517) -> CoreResult<R>
518where
519    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
520    R: 'static,
521{
522    // Check if all arrays support the requested precision
523    for array in arrays {
524        if !array.supports_precision(precision) {
525            return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
526                "One or more arrays do not support {precision} precision"
527            ))));
528        }
529    }
530
531    // Convert arrays to the requested precision
532    let mut converted_arrays = Vec::with_capacity(arrays.len());
533
534    for &array in arrays {
535        let converted = array.to_precision(precision)?;
536        converted_arrays.push(converted);
537    }
538
539    // NOTE: Trait upcasting is unstable, so we skip this for now
540    // This functionality is not critical for TenRSo
541    // TODO: Re-enable once trait_upcasting is stabilized (RFC #65991)
542
543    // Workaround: just return error for now
544    Err("Mixed precision batch execution not supported on stable Rust - requires trait_upcasting feature".to_string().into())
545}
546
547/// Implementation of common array operations with mixed precision.
548pub mod ops {
549    use super::*;
550    use crate::array_protocol::operations as array_ops;
551
552    /// Matrix multiplication with specified precision.
553    pub fn matmul(
554        a: &dyn MixedPrecisionSupport,
555        b: &dyn MixedPrecisionSupport,
556        precision: Precision,
557    ) -> CoreResult<Box<dyn ArrayProtocol>> {
558        execute_with_precision(&[a, b], precision, |arrays| {
559            // Convert OperationError to CoreError
560            match array_ops::matmul(arrays[0], arrays[1]) {
561                Ok(result) => Ok(result),
562                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
563                    e.to_string(),
564                ))),
565            }
566        })
567    }
568
569    /// Element-wise addition with specified precision.
570    pub fn add(
571        a: &dyn MixedPrecisionSupport,
572        b: &dyn MixedPrecisionSupport,
573        precision: Precision,
574    ) -> CoreResult<Box<dyn ArrayProtocol>> {
575        execute_with_precision(&[a, b], precision, |arrays| {
576            // Convert OperationError to CoreError
577            match array_ops::add(arrays[0], arrays[1]) {
578                Ok(result) => Ok(result),
579                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
580                    e.to_string(),
581                ))),
582            }
583        })
584    }
585
586    /// Element-wise multiplication with specified precision.
587    pub fn multiply(
588        a: &dyn MixedPrecisionSupport,
589        b: &dyn MixedPrecisionSupport,
590        precision: Precision,
591    ) -> CoreResult<Box<dyn ArrayProtocol>> {
592        execute_with_precision(&[a, b], precision, |arrays| {
593            // Convert OperationError to CoreError
594            match array_ops::multiply(arrays[0], arrays[1]) {
595                Ok(result) => Ok(result),
596                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
597                    e.to_string(),
598                ))),
599            }
600        })
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, array};

    /// The 2x2 `f64` matrix shared by the construction tests.
    fn square_matrix() -> Array<f64, crate::ndarray::Ix2> {
        arr2(&[[1.0, 2.0], [3.0, 4.0]])
    }

    #[test]
    fn test_mixed_precision_array() {
        let wrapped = MixedPrecisionArray::new(square_matrix());

        // f64 elements are 8 bytes wide, so storage is double precision.
        assert_eq!(wrapped.storage_precision(), Precision::Double);

        // Behind the protocol trait object the concrete type (Ix2, not IxDyn)
        // must still be recoverable.
        let as_protocol: &dyn ArrayProtocol = &wrapped;
        let is_expected_type = as_protocol
            .as_any()
            .is::<MixedPrecisionArray<f64, crate::ndarray::Ix2>>();
        assert!(is_expected_type);
    }

    #[test]
    fn test_mixed_precision_support() {
        crate::array_protocol::init();

        let wrapped = MixedPrecisionArray::new(square_matrix());
        let support: &dyn MixedPrecisionSupport = &wrapped;

        assert_eq!(support.precision(), Precision::Double);
        for level in [Precision::Single, Precision::Double] {
            assert!(
                support.supports_precision(level),
                "expected {level} precision to be supported"
            );
        }
    }

    // ── at_precision tests ───────────────────────────────────────────────────

    /// Narrowing f64 → f32 keeps every value within f32 tolerance.
    #[test]
    fn test_at_precision_f64_to_f32() {
        // Values chosen so clippy does not flag them as approximate constants.
        let source = array![1.0_f64, 2.5_f64, -1.75_f64].into_dyn();
        let narrowed: crate::ndarray::ArrayD<f32> = MixedPrecisionArray::new(source)
            .at_precision()
            .expect("f64 → f32 precision conversion should succeed");

        let expected = [1.0_f32, 2.5_f32, -1.75_f32];
        assert_eq!(narrowed.len(), expected.len());
        for (got, want) in narrowed.iter().zip(expected) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    /// Widening f32 → f64 keeps every value exactly representable.
    #[test]
    fn test_at_precision_f32_to_f64() {
        let source = array![0.5_f32, 1.25_f32, -2.0_f32].into_dyn();
        let widened: crate::ndarray::ArrayD<f64> = MixedPrecisionArray::new(source)
            .at_precision()
            .expect("f32 → f64 precision conversion should succeed");

        let expected = [0.5_f64, 1.25_f64, -2.0_f64];
        assert_eq!(widened.len(), expected.len());
        for (got, want) in widened.iter().zip(expected) {
            assert!((got - want).abs() < 1e-12);
        }
    }

    /// Converting f64 → f64 must leave every value untouched.
    #[test]
    fn test_at_precision_same_type_is_identity() {
        let original = array![42.0_f64, -7.5_f64].into_dyn();
        let round_tripped: crate::ndarray::ArrayD<f64> = MixedPrecisionArray::new(original.clone())
            .at_precision()
            .expect("f64 → f64 precision conversion should succeed");

        assert_eq!(
            round_tripped, original,
            "Identity conversion must not change values"
        );
    }

    /// A 2-D conversion keeps the original shape and element positions.
    #[test]
    fn test_at_precision_preserves_shape() {
        let narrowed: crate::ndarray::Array<f32, crate::ndarray::Ix2> =
            MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]))
                .at_precision()
                .expect("2D f64 → f32 conversion should succeed");

        assert_eq!(narrowed.shape(), &[2, 2]);
        assert!((narrowed[[0, 0]] - 1.0_f32).abs() < 1e-6);
        assert!((narrowed[[1, 1]] - 4.0_f32).abs() < 1e-6);
    }
}