Skip to main content

scirs2_core/array_protocol/
mixed_precision.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Mixed-precision operations for the array protocol.
8//!
9//! This module provides support for mixed-precision operations, allowing
10//! arrays to use different numeric types (e.g., f32, f64) for storage
11//! and computation to optimize performance and memory usage.
12
13use std::any::{Any, TypeId};
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::{LazyLock, RwLock};
17
18use ::ndarray::{Array, Dimension};
19use num_traits::{cast as num_cast, Float};
20
21use crate::array_protocol::gpu_impl::GPUNdarray;
22use crate::array_protocol::{
23    ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
24};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// Precision levels for array operations.
///
/// `Half`, `Single` and `Double` name a floating-point width. `Mixed` is
/// reported when an array's storage and compute precisions differ, and is also
/// used as a fallback for element sizes that match none of the other levels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Precision {
    /// Half-precision floating point (16-bit).
    Half,

    /// Single-precision floating point (32-bit).
    Single,

    /// Double-precision floating point (64-bit).
    Double,

    /// Mixed precision (e.g. store in 16/32-bit, compute in 64-bit).
    Mixed,
}
42
43impl fmt::Display for Precision {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            Precision::Half => write!(f, "half"),
47            Precision::Single => write!(f, "single"),
48            Precision::Double => write!(f, "double"),
49            Precision::Mixed => write!(f, "mixed"),
50        }
51    }
52}
53
54/// Configuration for mixed-precision operations.
55#[derive(Debug, Clone)]
56pub struct MixedPrecisionConfig {
57    /// Storage precision for arrays.
58    pub storage_precision: Precision,
59
60    /// Computation precision for operations.
61    pub computeprecision: Precision,
62
63    /// Automatic precision selection based on array size and operation.
64    pub auto_precision: bool,
65
66    /// Threshold for automatic downcast to lower precision.
67    pub downcast_threshold: usize,
68
69    /// Always use double precision for intermediate results.
70    pub double_precision_accumulation: bool,
71}
72
73impl Default for MixedPrecisionConfig {
74    fn default() -> Self {
75        Self {
76            storage_precision: Precision::Single,
77            computeprecision: Precision::Double,
78            auto_precision: true,
79            downcast_threshold: 10_000_000, // 10M elements
80            double_precision_accumulation: true,
81        }
82    }
83}
84
85/// Global mixed-precision configuration.
86pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
87    RwLock::new(MixedPrecisionConfig {
88        storage_precision: Precision::Single,
89        computeprecision: Precision::Double,
90        auto_precision: true,
91        downcast_threshold: 10_000_000, // 10M elements
92        double_precision_accumulation: true,
93    })
94});
95
96/// Set the global mixed-precision configuration.
97#[allow(dead_code)]
98pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
99    if let Ok(mut global_config) = MIXED_PRECISION_CONFIG.write() {
100        *global_config = config;
101    }
102}
103
104/// Get the current mixed-precision configuration.
105#[allow(dead_code)]
106pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
107    MIXED_PRECISION_CONFIG
108        .read()
109        .map(|c| c.clone())
110        .unwrap_or_default()
111}
112
113/// Determine the optimal precision for an array based on its size.
114#[allow(dead_code)]
115pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
116where
117    T: Clone + 'static,
118    D: Dimension,
119{
120    let config = get_mixed_precision_config();
121    let size = array.len();
122
123    if config.auto_precision {
124        if size >= config.downcast_threshold {
125            Precision::Single
126        } else {
127            Precision::Double
128        }
129    } else {
130        config.storage_precision
131    }
132}
133
134/// Mixed-precision array that can automatically convert between precisions.
135///
136/// This wrapper enables arrays to use different precision levels for storage
137/// and computation, automatically converting between precisions as needed.
138#[derive(Debug, Clone)]
139pub struct MixedPrecisionArray<T, D>
140where
141    T: Clone + 'static,
142    D: Dimension,
143{
144    /// The array stored at the specified precision.
145    array: Array<T, D>,
146
147    /// The current storage precision.
148    storage_precision: Precision,
149
150    /// The precision used for computations.
151    computeprecision: Precision,
152}
153
154impl<T, D> MixedPrecisionArray<T, D>
155where
156    T: Clone + Float + 'static,
157    D: Dimension,
158{
159    /// Create a new mixed-precision array.
160    pub fn new(array: Array<T, D>) -> Self {
161        let precision = match std::mem::size_of::<T>() {
162            2 => Precision::Half,
163            4 => Precision::Single,
164            8 => Precision::Double,
165            _ => Precision::Mixed,
166        };
167
168        Self {
169            array,
170            storage_precision: precision,
171            computeprecision: precision,
172        }
173    }
174
175    /// Create a new mixed-precision array with specified compute precision.
176    pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
177        let storage_precision = match std::mem::size_of::<T>() {
178            2 => Precision::Half,
179            4 => Precision::Single,
180            8 => Precision::Double,
181            _ => Precision::Mixed,
182        };
183
184        Self {
185            array: data,
186            storage_precision,
187            computeprecision,
188        }
189    }
190
191    /// Convert the array to a different floating-point precision `U`.
192    ///
193    /// Each element is cast from `T` to `U` using [`fn@num_traits::cast`].  If any
194    /// element cannot be represented in `U` (e.g. an `f64` infinity cast to a
195    /// hypothetical narrow type) the method returns a
196    /// [`CoreError::ComputationError`].
197    ///
198    /// # Example
199    /// ```
200    /// use ndarray::array;
201    /// use scirs2_core::array_protocol::mixed_precision::MixedPrecisionArray;
202    ///
203    /// let arr = array![1.0_f64, 2.5_f64, 1.75_f64];
204    /// let mp = MixedPrecisionArray::new(arr.into_dyn());
205    /// let as_f32: ndarray::ArrayD<f32> = mp.at_precision()
206    ///     .expect("f64 -> f32 conversion should succeed");
207    /// assert!((as_f32[0] - 1.0_f32).abs() < 1e-6);
208    /// ```
209    pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
210    where
211        U: Clone + Float + 'static,
212    {
213        // ndarray does not have a fallible mapv, so we collect into a Vec<U> first.
214        let mut converted: Vec<U> = Vec::with_capacity(self.array.len());
215        for x in self.array.iter() {
216            match num_cast::<T, U>(*x) {
217                Some(v) => converted.push(v),
218                None => {
219                    return Err(CoreError::ComputationError(ErrorContext::new(format!(
220                        "at_precision: failed to cast element to target precision (source size \
221                         {} bytes, target size {} bytes)",
222                        std::mem::size_of::<T>(),
223                        std::mem::size_of::<U>(),
224                    ))))
225                }
226            }
227        }
228
229        // Reconstruct with the same shape.
230        Array::from_shape_vec(self.array.raw_dim(), converted).map_err(|e| {
231            CoreError::ShapeError(ErrorContext::new(format!(
232                "at_precision: failed to reconstruct array from converted elements: {e}"
233            )))
234        })
235    }
236
237    /// Get the current storage precision.
238    pub fn storage_precision(&self) -> Precision {
239        self.storage_precision
240    }
241
242    /// Get the underlying array.
243    pub const fn array(&self) -> &Array<T, D> {
244        &self.array
245    }
246}
247
248/// Trait for arrays that support mixed-precision operations.
249pub trait MixedPrecisionSupport: ArrayProtocol {
250    /// Convert the array to the specified precision.
251    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
252
253    /// Get the current precision of the array.
254    fn precision(&self) -> Precision;
255
256    /// Check if the array supports the specified precision.
257    fn supports_precision(&self, precision: Precision) -> bool;
258
259    /// Borrow this value as an [`ArrayProtocol`] trait object.
260    ///
261    /// `MixedPrecisionSupport` has `ArrayProtocol` as a supertrait, but on stable
262    /// Rust a `&dyn MixedPrecisionSupport` cannot be upcast to `&dyn ArrayProtocol`
263    /// without the unstable `trait_upcasting` feature (RFC #65991). This method
264    /// provides that bridge explicitly: every implementor already *is* an
265    /// `ArrayProtocol`, so the default implementation simply returns `self`.
266    ///
267    /// This allows mixed-precision arrays to be passed to operations that are
268    /// generic over `&dyn ArrayProtocol` (such as those in
269    /// `crate::array_protocol::operations`) on stable Rust.
270    fn as_array_protocol(&self) -> &dyn ArrayProtocol;
271}
272
273/// Extract the inner `ndarray` of element type `T` and dimension `D` from a
274/// boxed argument produced by the operation dispatcher.
275///
276/// The dispatcher in [`crate::array_protocol::operations`] boxes operands as
277/// `Box<dyn ArrayProtocol>` (further boxed into `Box<dyn Any>`). This helper
278/// looks through that indirection and recognises both [`MixedPrecisionArray`]
279/// and [`NdarrayWrapper`] operands, returning an owned copy of the underlying
280/// `ndarray`. Returns `None` if the argument is not a recognised array of the
281/// requested `T`/`D`.
282fn extract_inner_ndarray<T, D>(arg: &dyn Any) -> Option<Array<T, D>>
283where
284    T: Clone + Float + Send + Sync + 'static,
285    D: Dimension + Send + Sync + 'static,
286{
287    // Case 1: the operand was boxed as `Box<dyn ArrayProtocol>` (the path used
288    // by the operation dispatcher).
289    if let Some(ap) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
290        let inner: &dyn ArrayProtocol = &**ap;
291        if let Some(mp) = inner.as_any().downcast_ref::<MixedPrecisionArray<T, D>>() {
292            return Some(mp.array.clone());
293        }
294        if let Some(nd) = inner.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
295            return Some(nd.as_array().clone());
296        }
297        return None;
298    }
299
300    // Case 2: the operand was boxed directly as its concrete type.
301    if let Some(mp) = arg.downcast_ref::<MixedPrecisionArray<T, D>>() {
302        return Some(mp.array.clone());
303    }
304    if let Some(nd) = arg.downcast_ref::<NdarrayWrapper<T, D>>() {
305        return Some(nd.as_array().clone());
306    }
307
308    None
309}
310
311/// Normalise the result returned by `NdarrayWrapper`'s array-function kernels
312/// into the `Box<dyn Any>`-of-`Box<dyn ArrayProtocol>` shape expected by the
313/// operation dispatcher in [`crate::array_protocol::operations`].
314///
315/// `NdarrayWrapper::array_function` returns its array results boxed as the
316/// concrete `NdarrayWrapper<T, _>` type, whereas the dispatcher downcasts the
317/// result to `Box<dyn ArrayProtocol>`. This helper bridges that gap for the
318/// floating-point element type `T` across the dimensionalities the kernels can
319/// produce (`Ix1`, `Ix2`, `IxDyn`). Non-array results (for example the scalar
320/// produced by `sum`) do not match any branch and are returned unchanged so the
321/// caller can downcast them directly.
322fn rewrap_result_as_array_protocol<T>(result: Box<dyn Any>) -> Box<dyn Any>
323where
324    T: Clone + Float + Send + Sync + 'static,
325{
326    use crate::ndarray::{Ix1, Ix2, IxDyn};
327
328    // Already in the expected shape (e.g. produced by another delegating layer).
329    if result.is::<Box<dyn ArrayProtocol>>() {
330        return result;
331    }
332
333    // 2-D results: matmul, and element-wise ops on 2-D inputs.
334    let result = match result.downcast::<NdarrayWrapper<T, Ix2>>() {
335        Ok(wrapper) => {
336            let boxed: Box<dyn ArrayProtocol> = wrapper;
337            return Box::new(boxed);
338        }
339        Err(other) => other,
340    };
341
342    // 1-D results: element-wise ops on 1-D inputs, reshape to 1-D.
343    let result = match result.downcast::<NdarrayWrapper<T, Ix1>>() {
344        Ok(wrapper) => {
345            let boxed: Box<dyn ArrayProtocol> = wrapper;
346            return Box::new(boxed);
347        }
348        Err(other) => other,
349    };
350
351    // Dynamic-dimension results.
352    match result.downcast::<NdarrayWrapper<T, IxDyn>>() {
353        Ok(wrapper) => {
354            let boxed: Box<dyn ArrayProtocol> = wrapper;
355            Box::new(boxed)
356        }
357        // Not an array result (e.g. a scalar from `sum`): pass through unchanged.
358        Err(other) => other,
359    }
360}
361
362/// Implement ArrayProtocol for MixedPrecisionArray.
363impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
364where
365    T: Clone + Float + Send + Sync + 'static,
366    D: Dimension + Send + Sync + 'static,
367{
368    fn array_function(
369        &self,
370        func: &ArrayFunction,
371        types: &[TypeId],
372        args: &[Box<dyn Any>],
373        kwargs: &HashMap<String, Box<dyn Any>>,
374    ) -> Result<Box<dyn Any>, NotImplemented> {
375        // Wrap `self` as a plain `NdarrayWrapper`. The mixed-precision storage is
376        // a regular `ndarray`, so all numeric kernels live in `NdarrayWrapper`'s
377        // implementation; this struct only manages precision metadata.
378        let wrapped_self = NdarrayWrapper::new(self.array.clone());
379
380        // Determine operating precision based on function and arguments. The
381        // precision is currently used to validate the requested operation; the
382        // actual numeric computation is delegated to `NdarrayWrapper`.
383        let precision = kwargs
384            .get("precision")
385            .and_then(|p| p.downcast_ref::<Precision>())
386            .cloned()
387            .unwrap_or(self.computeprecision);
388
389        match func.name {
390            "scirs2::array_protocol::operations::matmul"
391            | "scirs2::array_protocol::operations::add"
392            | "scirs2::array_protocol::operations::subtract"
393            | "scirs2::array_protocol::operations::multiply" => {
394                // Binary operations need the second operand. The dispatcher boxes
395                // operands as `Box<dyn ArrayProtocol>`, so we extract the inner
396                // ndarray (whether it arrived as a `MixedPrecisionArray` or an
397                // `NdarrayWrapper`) and re-wrap it as an `NdarrayWrapper` so the
398                // delegated kernel receives the concrete type it expects.
399                if args.len() < 2 {
400                    return Err(NotImplemented);
401                }
402
403                let Some(other_array) = extract_inner_ndarray::<T, D>(args[1].as_ref()) else {
404                    return Err(NotImplemented);
405                };
406                let wrapped_other = NdarrayWrapper::new(other_array);
407
408                // Forbid precision levels we cannot honour numerically. Half is
409                // not representable by the underlying storage on stable Rust.
410                if matches!(precision, Precision::Half) {
411                    return Err(NotImplemented);
412                }
413
414                let new_args: Vec<Box<dyn Any>> =
415                    vec![Box::new(wrapped_self.clone()), Box::new(wrapped_other)];
416                wrapped_self
417                    .array_function(func, types, &new_args, kwargs)
418                    .map(rewrap_result_as_array_protocol::<T>)
419            }
420            "scirs2::array_protocol::operations::transpose"
421            | "scirs2::array_protocol::operations::reshape"
422            | "scirs2::array_protocol::operations::sum" => {
423                // Unary operations: delegate to `NdarrayWrapper` with `self`
424                // re-wrapped as the first argument. Array results (transpose,
425                // reshape) are normalised; scalar results (sum) pass through.
426                let new_args: Vec<Box<dyn Any>> = vec![Box::new(wrapped_self.clone())];
427                wrapped_self
428                    .array_function(func, types, &new_args, kwargs)
429                    .map(rewrap_result_as_array_protocol::<T>)
430            }
431            _ => {
432                // For any other function, delegate to the standard implementation
433                // with the original arguments.
434                wrapped_self.array_function(func, types, args, kwargs)
435            }
436        }
437    }
438
439    fn as_any(&self) -> &dyn Any {
440        self
441    }
442
443    fn shape(&self) -> &[usize] {
444        self.array.shape()
445    }
446
447    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
448        Box::new(Self {
449            array: self.array.clone(),
450            storage_precision: self.storage_precision,
451            computeprecision: self.computeprecision,
452        })
453    }
454}
455
456/// Implement MixedPrecisionSupport for MixedPrecisionArray.
457impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
458where
459    T: Clone + Float + Send + Sync + 'static,
460    D: Dimension + Send + Sync + 'static,
461{
462    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
463        match precision {
464            Precision::Single => {
465                // For actual implementation, this would convert f64 to f32 if needed
466                // This is a simplified version - in reality, we would need to convert between types
467
468                let current_precision = self.precision();
469                if current_precision == Precision::Single {
470                    // Already in single precision
471                    return Ok(Box::new(self.clone()));
472                }
473
474                // In real implementation, would handle proper conversion from T to f32
475                // For now, create a new array with the requested precision
476                let array_single = self.array.clone();
477                let newarray = MixedPrecisionArray::with_computeprecision(array_single, precision);
478                Ok(Box::new(newarray))
479            }
480            Precision::Double => {
481                // For actual implementation, this would convert f32 to f64 if needed
482
483                let current_precision = self.precision();
484                if current_precision == Precision::Double {
485                    // Already in double precision
486                    return Ok(Box::new(self.clone()));
487                }
488
489                // In real implementation, would handle proper conversion from T to f64
490                // For now, create a new array with the requested precision
491                let array_double = self.array.clone();
492                let newarray = MixedPrecisionArray::with_computeprecision(array_double, precision);
493                Ok(Box::new(newarray))
494            }
495            Precision::Mixed => {
496                // For mixed precision, use storage precision of the current array and double compute precision
497                let array_mixed = self.array.clone();
498                let newarray =
499                    MixedPrecisionArray::with_computeprecision(array_mixed, Precision::Double);
500                Ok(Box::new(newarray))
501            }
502            _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
503                "Conversion to {precision} precision not implemented"
504            )))),
505        }
506    }
507
508    fn precision(&self) -> Precision {
509        // If storage and compute precision differ, return Mixed
510        if self.storage_precision != self.computeprecision {
511            Precision::Mixed
512        } else {
513            self.storage_precision
514        }
515    }
516
517    fn supports_precision(&self, precision: Precision) -> bool {
518        matches!(precision, Precision::Single | Precision::Double)
519    }
520
521    fn as_array_protocol(&self) -> &dyn ArrayProtocol {
522        self
523    }
524}
525
526/// Implement MixedPrecisionSupport for GPUNdarray.
527impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
528where
529    T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
530    D: Dimension + Send + Sync + 'static + crate::ndarray::RemoveAxis,
531{
532    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
533        // For GPUs, creating a new array with mixed precision enabled
534        let mut config = self.config().clone();
535        config.mixed_precision = precision == Precision::Mixed;
536
537        if let Ok(cpu_array) = self.to_cpu() {
538            // Use as_any() to downcast the ArrayProtocol trait object
539            if let Some(ndarray) = cpu_array.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
540                let new_gpu_array = GPUNdarray::new(ndarray.as_array().clone(), config);
541                return Ok(Box::new(new_gpu_array));
542            }
543        }
544
545        Err(CoreError::NotImplementedError(ErrorContext::new(format!(
546            "Conversion to {precision} precision not implemented for GPU arrays"
547        ))))
548    }
549
550    fn precision(&self) -> Precision {
551        if self.config().mixed_precision {
552            Precision::Mixed
553        } else {
554            match std::mem::size_of::<T>() {
555                4 => Precision::Single,
556                8 => Precision::Double,
557                _ => Precision::Mixed,
558            }
559        }
560    }
561
562    fn supports_precision(&self, precision: Precision) -> bool {
563        // Most GPUs support all precision levels
564        true
565    }
566
567    fn as_array_protocol(&self) -> &dyn ArrayProtocol {
568        self
569    }
570}
571
572/// Execute an operation with a specific precision.
573///
574/// This function automatically converts arrays to the specified precision
575/// before executing the operation.
576#[allow(dead_code)]
577pub fn execute_with_precision<F, R>(
578    arrays: &[&dyn MixedPrecisionSupport],
579    precision: Precision,
580    executor: F,
581) -> CoreResult<R>
582where
583    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
584    R: 'static,
585{
586    // Check if all arrays support the requested precision
587    for array in arrays {
588        if !array.supports_precision(precision) {
589            return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
590                "One or more arrays do not support {precision} precision"
591            ))));
592        }
593    }
594
595    // Convert arrays to the requested precision. Each conversion yields a
596    // `Box<dyn MixedPrecisionSupport>` that owns its precision-converted data.
597    let mut converted_arrays: Vec<Box<dyn MixedPrecisionSupport>> =
598        Vec::with_capacity(arrays.len());
599
600    for &array in arrays {
601        let converted = array.to_precision(precision)?;
602        converted_arrays.push(converted);
603    }
604
605    // Bridge `&dyn MixedPrecisionSupport` to `&dyn ArrayProtocol` on stable Rust.
606    //
607    // Trait upcasting (`&dyn MixedPrecisionSupport` -> `&dyn ArrayProtocol`) is
608    // unstable (RFC #65991), so instead of relying on it we use the explicit
609    // `as_array_protocol` bridge method defined on `MixedPrecisionSupport`. Every
610    // implementor already *is* an `ArrayProtocol`, so this is a zero-cost borrow.
611    let protocol_refs: Vec<&dyn ArrayProtocol> = converted_arrays
612        .iter()
613        .map(|array| array.as_array_protocol())
614        .collect();
615
616    // Run the requested operation on the precision-converted arrays.
617    executor(&protocol_refs)
618}
619
620/// Implementation of common array operations with mixed precision.
621pub mod ops {
622    use super::*;
623    use crate::array_protocol::operations as array_ops;
624
625    /// Matrix multiplication with specified precision.
626    pub fn matmul(
627        a: &dyn MixedPrecisionSupport,
628        b: &dyn MixedPrecisionSupport,
629        precision: Precision,
630    ) -> CoreResult<Box<dyn ArrayProtocol>> {
631        execute_with_precision(&[a, b], precision, |arrays| {
632            // Convert OperationError to CoreError
633            match array_ops::matmul(arrays[0], arrays[1]) {
634                Ok(result) => Ok(result),
635                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
636                    e.to_string(),
637                ))),
638            }
639        })
640    }
641
642    /// Element-wise addition with specified precision.
643    pub fn add(
644        a: &dyn MixedPrecisionSupport,
645        b: &dyn MixedPrecisionSupport,
646        precision: Precision,
647    ) -> CoreResult<Box<dyn ArrayProtocol>> {
648        execute_with_precision(&[a, b], precision, |arrays| {
649            // Convert OperationError to CoreError
650            match array_ops::add(arrays[0], arrays[1]) {
651                Ok(result) => Ok(result),
652                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
653                    e.to_string(),
654                ))),
655            }
656        })
657    }
658
659    /// Element-wise multiplication with specified precision.
660    pub fn multiply(
661        a: &dyn MixedPrecisionSupport,
662        b: &dyn MixedPrecisionSupport,
663        precision: Precision,
664    ) -> CoreResult<Box<dyn ArrayProtocol>> {
665        execute_with_precision(&[a, b], precision, |arrays| {
666            // Convert OperationError to CoreError
667            match array_ops::multiply(arrays[0], arrays[1]) {
668                Ok(result) => Ok(result),
669                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
670                    e.to_string(),
671                ))),
672            }
673        })
674    }
675}
676
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::arr2;

    #[test]
    fn test_mixed_precision_array() {
        // Create a mixed-precision array
        let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let mixed_array = MixedPrecisionArray::new(array);

        // Check the storage precision (should be double for f64 arrays)
        assert_eq!(mixed_array.storage_precision(), Precision::Double);

        // Test the ArrayProtocol implementation
        let array_protocol: &dyn ArrayProtocol = &mixed_array;
        // The array is of type MixedPrecisionArray<f64, Ix2> (not IxDyn)
        assert!(array_protocol
            .as_any()
            .is::<MixedPrecisionArray<f64, crate::ndarray::Ix2>>());
    }

    /// f32 element storage is reported as single precision.
    #[test]
    fn test_mixed_precision_array_f32_storage() {
        let mixed_array = MixedPrecisionArray::new(arr2(&[[1.0_f32, 2.0], [3.0, 4.0]]));
        assert_eq!(mixed_array.storage_precision(), Precision::Single);
    }

    #[test]
    fn test_mixed_precision_support() {
        // Initialize the array protocol
        crate::array_protocol::init();

        // Create a mixed-precision array
        let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let mixed_array = MixedPrecisionArray::new(array);

        // Test MixedPrecisionSupport implementation
        let mixed_support: &dyn MixedPrecisionSupport = &mixed_array;
        assert_eq!(mixed_support.precision(), Precision::Double);
        assert!(mixed_support.supports_precision(Precision::Single));
        assert!(mixed_support.supports_precision(Precision::Double));
        assert!(!mixed_support.supports_precision(Precision::Half));
    }

    /// `Display` renders the lowercase level names used in error messages.
    #[test]
    fn test_precision_display() {
        assert_eq!(Precision::Half.to_string(), "half");
        assert_eq!(Precision::Single.to_string(), "single");
        assert_eq!(Precision::Double.to_string(), "double");
        assert_eq!(Precision::Mixed.to_string(), "mixed");
    }

    /// Differing storage and compute precision is reported as `Mixed`.
    #[test]
    fn test_with_computeprecision_reports_mixed() {
        let arr = arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]);
        let mp = MixedPrecisionArray::with_computeprecision(arr, Precision::Single);
        assert_eq!(mp.storage_precision(), Precision::Double);

        let support: &dyn MixedPrecisionSupport = &mp;
        assert_eq!(support.precision(), Precision::Mixed);
    }

    /// Converting to the precision the array already has keeps that precision.
    #[test]
    fn test_to_precision_same_precision_is_noop() {
        let mp = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0]]));
        let converted = mp
            .to_precision(Precision::Double)
            .expect("f64 -> Double conversion should succeed");
        assert_eq!(converted.precision(), Precision::Double);
        assert_eq!(converted.as_array_protocol().shape(), &[1, 2]);
    }

    /// Half precision cannot be produced by `to_precision`.
    #[test]
    fn test_to_precision_half_is_error() {
        let mp = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        assert!(mp.to_precision(Precision::Half).is_err());
    }

    // ── at_precision tests ───────────────────────────────────────────────────

    /// Downcast f64 → f32: values should be preserved within f32 precision.
    #[test]
    fn test_at_precision_f64_to_f32() {
        use ::ndarray::array;
        // Use values that are not approximate constants recognized by clippy.
        let arr = array![1.0_f64, 2.5_f64, -1.75_f64].into_dyn();
        let mp = MixedPrecisionArray::new(arr);
        let as_f32: crate::ndarray::ArrayD<f32> = mp
            .at_precision()
            .expect("f64 → f32 precision conversion should succeed");
        assert!((as_f32[0] - 1.0_f32).abs() < 1e-6);
        assert!((as_f32[1] - 2.5_f32).abs() < 1e-6);
        assert!((as_f32[2] - (-1.75_f32)).abs() < 1e-6);
    }

    /// Upcast f32 → f64: precision should be maintained.
    #[test]
    fn test_at_precision_f32_to_f64() {
        use ::ndarray::array;
        let arr = array![0.5_f32, 1.25_f32, -2.0_f32].into_dyn();
        let mp = MixedPrecisionArray::new(arr);
        let as_f64: crate::ndarray::ArrayD<f64> = mp
            .at_precision()
            .expect("f32 → f64 precision conversion should succeed");
        assert!((as_f64[0] - 0.5_f64).abs() < 1e-12);
        assert!((as_f64[1] - 1.25_f64).abs() < 1e-12);
        assert!((as_f64[2] - (-2.0_f64)).abs() < 1e-12);
    }

    /// Identity conversion f64 → f64 should be a no-op.
    #[test]
    fn test_at_precision_same_type_is_identity() {
        use ::ndarray::array;
        let arr = array![42.0_f64, -7.5_f64].into_dyn();
        let mp = MixedPrecisionArray::new(arr.clone());
        let result: crate::ndarray::ArrayD<f64> = mp
            .at_precision()
            .expect("f64 → f64 precision conversion should succeed");
        for (a, b) in arr.iter().zip(result.iter()) {
            assert_eq!(*a, *b, "Identity conversion must not change values");
        }
    }

    /// 2-D array conversion preserves shape.
    #[test]
    fn test_at_precision_preserves_shape() {
        let arr = arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]);
        let mp = MixedPrecisionArray::new(arr);
        let as_f32: crate::ndarray::Array<f32, crate::ndarray::Ix2> = mp
            .at_precision()
            .expect("2D f64 → f32 conversion should succeed");
        assert_eq!(as_f32.shape(), &[2, 2]);
        assert!((as_f32[[0, 0]] - 1.0_f32).abs() < 1e-6);
        assert!((as_f32[[1, 1]] - 4.0_f32).abs() < 1e-6);
    }

    // ── execute_with_precision end-to-end tests ──────────────────────────────

    /// `ops::matmul` must run through `execute_with_precision` end-to-end on
    /// stable Rust and return the correct numeric result (no `Err`, no upcast).
    #[test]
    fn test_execute_with_precision_matmul_single() {
        crate::array_protocol::init();

        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let b = MixedPrecisionArray::new(arr2(&[[5.0_f64, 6.0], [7.0, 8.0]]));

        let result = ops::matmul(&a, &b, Precision::Single)
            .expect("mixed-precision matmul should succeed on stable Rust");

        let wrapper = result
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>()
            .expect("matmul result should be an NdarrayWrapper<f64, Ix2>");
        let out = wrapper.as_array();

        assert_eq!(out.shape(), &[2, 2]);
        assert!((out[[0, 0]] - 19.0).abs() < 1e-9);
        assert!((out[[0, 1]] - 22.0).abs() < 1e-9);
        assert!((out[[1, 0]] - 43.0).abs() < 1e-9);
        assert!((out[[1, 1]] - 50.0).abs() < 1e-9);
    }

    /// `ops::add` must run through `execute_with_precision` end-to-end and return
    /// the correct element-wise sum.
    #[test]
    fn test_execute_with_precision_add_single() {
        crate::array_protocol::init();

        let a = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let b = MixedPrecisionArray::new(arr2(&[[10.0_f64, 20.0], [30.0, 40.0]]));

        let result = ops::add(&a, &b, Precision::Single)
            .expect("mixed-precision add should succeed on stable Rust");

        let wrapper = result
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>()
            .expect("add result should be an NdarrayWrapper<f64, Ix2>");
        let out = wrapper.as_array();

        assert_eq!(out.shape(), &[2, 2]);
        assert!((out[[0, 0]] - 11.0).abs() < 1e-9);
        assert!((out[[0, 1]] - 22.0).abs() < 1e-9);
        assert!((out[[1, 0]] - 33.0).abs() < 1e-9);
        assert!((out[[1, 1]] - 44.0).abs() < 1e-9);
    }

    /// Half precision is not numerically representable by the stable backend, so
    /// the operation must surface an error rather than silently producing wrong
    /// results.
    #[test]
    fn test_execute_with_precision_half_is_rejected() {
        crate::array_protocol::init();

        let a = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let b = MixedPrecisionArray::new(arr2(&[[5.0_f64, 6.0], [7.0, 8.0]]));

        // `supports_precision` returns false for Half, so this must be rejected.
        let result = ops::matmul(&a, &b, Precision::Half);
        assert!(
            result.is_err(),
            "Half precision matmul must return an error"
        );
    }
}