Skip to main content

scirs2_core/array_protocol/
mixed_precision.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Mixed-precision operations for the array protocol.
8//!
9//! This module provides support for mixed-precision operations, allowing
10//! arrays to use different numeric types (e.g., f32, f64) for storage
11//! and computation to optimize performance and memory usage.
12
13use std::any::{Any, TypeId};
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::{LazyLock, RwLock};
17
18use ::ndarray::{Array, Dimension};
19use num_traits::Float;
20
21use crate::array_protocol::gpu_impl::GPUNdarray;
22use crate::array_protocol::{
23    ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
24};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// Numeric precision tiers that an array operation can be carried out at.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Precision {
    /// 16-bit (half-precision) floating point.
    Half,

    /// 32-bit (single-precision) floating point.
    Single,

    /// 64-bit (double-precision) floating point.
    Double,

    /// Storage and computation use different precisions
    /// (e.g., store in 16/32-bit, compute in 64-bit).
    Mixed,
}
42
43impl fmt::Display for Precision {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            Precision::Half => write!(f, "half"),
47            Precision::Single => write!(f, "single"),
48            Precision::Double => write!(f, "double"),
49            Precision::Mixed => write!(f, "mixed"),
50        }
51    }
52}
53
54/// Configuration for mixed-precision operations.
55#[derive(Debug, Clone)]
56pub struct MixedPrecisionConfig {
57    /// Storage precision for arrays.
58    pub storage_precision: Precision,
59
60    /// Computation precision for operations.
61    pub computeprecision: Precision,
62
63    /// Automatic precision selection based on array size and operation.
64    pub auto_precision: bool,
65
66    /// Threshold for automatic downcast to lower precision.
67    pub downcast_threshold: usize,
68
69    /// Always use double precision for intermediate results.
70    pub double_precision_accumulation: bool,
71}
72
73impl Default for MixedPrecisionConfig {
74    fn default() -> Self {
75        Self {
76            storage_precision: Precision::Single,
77            computeprecision: Precision::Double,
78            auto_precision: true,
79            downcast_threshold: 10_000_000, // 10M elements
80            double_precision_accumulation: true,
81        }
82    }
83}
84
85/// Global mixed-precision configuration.
86pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
87    RwLock::new(MixedPrecisionConfig {
88        storage_precision: Precision::Single,
89        computeprecision: Precision::Double,
90        auto_precision: true,
91        downcast_threshold: 10_000_000, // 10M elements
92        double_precision_accumulation: true,
93    })
94});
95
96/// Set the global mixed-precision configuration.
97#[allow(dead_code)]
98pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
99    if let Ok(mut global_config) = MIXED_PRECISION_CONFIG.write() {
100        *global_config = config;
101    }
102}
103
104/// Get the current mixed-precision configuration.
105#[allow(dead_code)]
106pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
107    MIXED_PRECISION_CONFIG
108        .read()
109        .map(|c| c.clone())
110        .unwrap_or_default()
111}
112
113/// Determine the optimal precision for an array based on its size.
114#[allow(dead_code)]
115pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
116where
117    T: Clone + 'static,
118    D: Dimension,
119{
120    let config = get_mixed_precision_config();
121    let size = array.len();
122
123    if config.auto_precision {
124        if size >= config.downcast_threshold {
125            Precision::Single
126        } else {
127            Precision::Double
128        }
129    } else {
130        config.storage_precision
131    }
132}
133
134/// Mixed-precision array that can automatically convert between precisions.
135///
136/// This wrapper enables arrays to use different precision levels for storage
137/// and computation, automatically converting between precisions as needed.
138#[derive(Debug, Clone)]
139pub struct MixedPrecisionArray<T, D>
140where
141    T: Clone + 'static,
142    D: Dimension,
143{
144    /// The array stored at the specified precision.
145    array: Array<T, D>,
146
147    /// The current storage precision.
148    storage_precision: Precision,
149
150    /// The precision used for computations.
151    computeprecision: Precision,
152}
153
154impl<T, D> MixedPrecisionArray<T, D>
155where
156    T: Clone + Float + 'static,
157    D: Dimension,
158{
159    /// Create a new mixed-precision array.
160    pub fn new(array: Array<T, D>) -> Self {
161        let precision = match std::mem::size_of::<T>() {
162            2 => Precision::Half,
163            4 => Precision::Single,
164            8 => Precision::Double,
165            _ => Precision::Mixed,
166        };
167
168        Self {
169            array,
170            storage_precision: precision,
171            computeprecision: precision,
172        }
173    }
174
175    /// Create a new mixed-precision array with specified compute precision.
176    pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
177        let storage_precision = match std::mem::size_of::<T>() {
178            2 => Precision::Half,
179            4 => Precision::Single,
180            8 => Precision::Double,
181            _ => Precision::Mixed,
182        };
183
184        Self {
185            array: data,
186            storage_precision,
187            computeprecision,
188        }
189    }
190
191    /// Get the array at the specified precision.
192    ///
193    /// This is a placeholder implementation. In a real implementation,
194    /// this would convert the array to the requested precision.
195    pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
196    where
197        U: Clone + Float + 'static,
198    {
199        // This is a simplified implementation for demonstration purposes.
200        // In a real implementation, this would handle proper type conversion.
201        Err(CoreError::NotImplementedError(ErrorContext::new(
202            "Precision conversion not fully implemented yet",
203        )))
204    }
205
206    /// Get the current storage precision.
207    pub fn storage_precision(&self) -> Precision {
208        self.storage_precision
209    }
210
211    /// Get the underlying array.
212    pub const fn array(&self) -> &Array<T, D> {
213        &self.array
214    }
215}
216
217/// Trait for arrays that support mixed-precision operations.
218pub trait MixedPrecisionSupport: ArrayProtocol {
219    /// Convert the array to the specified precision.
220    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
221
222    /// Get the current precision of the array.
223    fn precision(&self) -> Precision;
224
225    /// Check if the array supports the specified precision.
226    fn supports_precision(&self, precision: Precision) -> bool;
227}
228
229/// Implement ArrayProtocol for MixedPrecisionArray.
230impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
231where
232    T: Clone + Float + Send + Sync + 'static,
233    D: Dimension + Send + Sync + 'static,
234{
235    fn array_function(
236        &self,
237        func: &ArrayFunction,
238        types: &[TypeId],
239        args: &[Box<dyn Any>],
240        kwargs: &HashMap<String, Box<dyn Any>>,
241    ) -> Result<Box<dyn Any>, NotImplemented> {
242        // If the function supports mixed precision, delegate to the appropriate implementation
243        let precision = kwargs
244            .get("precision")
245            .and_then(|p| p.downcast_ref::<Precision>())
246            .cloned()
247            .unwrap_or(self.computeprecision);
248
249        // Determine operating precision based on function and arguments
250        match func.name {
251            "scirs2::array_protocol::operations::matmul" => {
252                // If we have a second argument, check its precision
253                if args.len() >= 2 {
254                    // Adjust to highest precision of the two arrays
255                    if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
256                        let other_precision = other.computeprecision;
257                        let _precision_to_use = match (precision, other_precision) {
258                            (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
259                            (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
260                            (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
261                            (Precision::Half, Precision::Half) => Precision::Half,
262                        };
263
264                        // We can't modify kwargs, so we'll just forward directly
265                        // Get NdarrayWrapper for self array
266                        let wrapped_self = NdarrayWrapper::new(self.array.clone());
267
268                        // Delegate to the NdarrayWrapper implementation
269                        return wrapped_self.array_function(func, types, args, kwargs);
270                    }
271                }
272
273                // Convert to the requested precision and use standard implementation
274                match precision {
275                    Precision::Single | Precision::Double => {
276                        // Wrap in NdarrayWrapper for computation
277                        let wrapped = NdarrayWrapper::new(self.array.clone());
278
279                        // Adjust args to use wrapped version
280                        let mut new_args = Vec::with_capacity(args.len());
281                        new_args.push(Box::new(wrapped.clone()));
282
283                        // We don't need to include other args since we already have a new wrapped object
284                        // For simplicity, just delegate to the original args
285                        // Delegate to NdarrayWrapper
286                        wrapped.array_function(func, types, args, kwargs)
287                    }
288                    Precision::Mixed => {
289                        // Use Double precision for Mixed calculations
290                        let wrapped = NdarrayWrapper::new(self.array.clone());
291
292                        // Create new args and kwargs with Double precision
293                        let mut new_args = Vec::with_capacity(args.len());
294                        new_args.push(Box::new(wrapped.clone()));
295
296                        // We can't modify kwargs, so just forward along
297                        // Delegate to NdarrayWrapper directly with original args and kwargs
298                        wrapped.array_function(func, types, args, kwargs)
299                    }
300                    _ => Err(NotImplemented),
301                }
302            }
303            "scirs2::array_protocol::operations::add"
304            | "scirs2::array_protocol::operations::subtract"
305            | "scirs2::array_protocol::operations::multiply" => {
306                // Similar pattern for element-wise operations
307                // If we have a second argument, check its precision
308                if args.len() >= 2 {
309                    if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
310                        // Use the highest precision for the operation
311                        let other_precision = other.computeprecision;
312                        let _precision_to_use = match (precision, other_precision) {
313                            (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
314                            (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
315                            (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
316                            (Precision::Half, Precision::Half) => Precision::Half,
317                        };
318
319                        // We can't modify kwargs, so we'll just forward directly
320                        // Get NdarrayWrapper for self array
321                        let wrapped_self = NdarrayWrapper::new(self.array.clone());
322
323                        // Delegate to the NdarrayWrapper implementation
324                        return wrapped_self.array_function(func, types, args, kwargs);
325                    }
326                }
327
328                // Convert to the requested precision and use standard implementation
329                let wrapped = NdarrayWrapper::new(self.array.clone());
330
331                // Delegate to NdarrayWrapper with original args
332                wrapped.array_function(func, types, args, kwargs)
333            }
334            "scirs2::array_protocol::operations::transpose"
335            | "scirs2::array_protocol::operations::reshape"
336            | "scirs2::array_protocol::operations::sum" => {
337                // For unary operations, simply use the current precision
338                // Convert to standard wrapper and delegate
339                let wrapped = NdarrayWrapper::new(self.array.clone());
340
341                // Delegate to NdarrayWrapper with original args
342                wrapped.array_function(func, types, args, kwargs)
343            }
344            _ => {
345                // For any other function, delegate to standard implementation
346                let wrapped = NdarrayWrapper::new(self.array.clone());
347                wrapped.array_function(func, types, args, kwargs)
348            }
349        }
350    }
351
352    fn as_any(&self) -> &dyn Any {
353        self
354    }
355
356    fn shape(&self) -> &[usize] {
357        self.array.shape()
358    }
359
360    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
361        Box::new(Self {
362            array: self.array.clone(),
363            storage_precision: self.storage_precision,
364            computeprecision: self.computeprecision,
365        })
366    }
367}
368
369/// Implement MixedPrecisionSupport for MixedPrecisionArray.
370impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
371where
372    T: Clone + Float + Send + Sync + 'static,
373    D: Dimension + Send + Sync + 'static,
374{
375    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
376        match precision {
377            Precision::Single => {
378                // For actual implementation, this would convert f64 to f32 if needed
379                // This is a simplified version - in reality, we would need to convert between types
380
381                let current_precision = self.precision();
382                if current_precision == Precision::Single {
383                    // Already in single precision
384                    return Ok(Box::new(self.clone()));
385                }
386
387                // In real implementation, would handle proper conversion from T to f32
388                // For now, create a new array with the requested precision
389                let array_single = self.array.clone();
390                let newarray = MixedPrecisionArray::with_computeprecision(array_single, precision);
391                Ok(Box::new(newarray))
392            }
393            Precision::Double => {
394                // For actual implementation, this would convert f32 to f64 if needed
395
396                let current_precision = self.precision();
397                if current_precision == Precision::Double {
398                    // Already in double precision
399                    return Ok(Box::new(self.clone()));
400                }
401
402                // In real implementation, would handle proper conversion from T to f64
403                // For now, create a new array with the requested precision
404                let array_double = self.array.clone();
405                let newarray = MixedPrecisionArray::with_computeprecision(array_double, precision);
406                Ok(Box::new(newarray))
407            }
408            Precision::Mixed => {
409                // For mixed precision, use storage precision of the current array and double compute precision
410                let array_mixed = self.array.clone();
411                let newarray =
412                    MixedPrecisionArray::with_computeprecision(array_mixed, Precision::Double);
413                Ok(Box::new(newarray))
414            }
415            _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
416                "Conversion to {precision} precision not implemented"
417            )))),
418        }
419    }
420
421    fn precision(&self) -> Precision {
422        // If storage and compute precision differ, return Mixed
423        if self.storage_precision != self.computeprecision {
424            Precision::Mixed
425        } else {
426            self.storage_precision
427        }
428    }
429
430    fn supports_precision(&self, precision: Precision) -> bool {
431        matches!(precision, Precision::Single | Precision::Double)
432    }
433}
434
435/// Implement MixedPrecisionSupport for GPUNdarray.
436impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
437where
438    T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
439    D: Dimension + Send + Sync + 'static + crate::ndarray::RemoveAxis,
440{
441    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
442        // For GPUs, creating a new array with mixed precision enabled
443        let mut config = self.config().clone();
444        config.mixed_precision = precision == Precision::Mixed;
445
446        if let Ok(cpu_array) = self.to_cpu() {
447            // Use as_any() to downcast the ArrayProtocol trait object
448            if let Some(ndarray) = cpu_array.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
449                let new_gpu_array = GPUNdarray::new(ndarray.as_array().clone(), config);
450                return Ok(Box::new(new_gpu_array));
451            }
452        }
453
454        Err(CoreError::NotImplementedError(ErrorContext::new(format!(
455            "Conversion to {precision} precision not implemented for GPU arrays"
456        ))))
457    }
458
459    fn precision(&self) -> Precision {
460        if self.config().mixed_precision {
461            Precision::Mixed
462        } else {
463            match std::mem::size_of::<T>() {
464                4 => Precision::Single,
465                8 => Precision::Double,
466                _ => Precision::Mixed,
467            }
468        }
469    }
470
471    fn supports_precision(&self, precision: Precision) -> bool {
472        // Most GPUs support all precision levels
473        true
474    }
475}
476
477/// Execute an operation with a specific precision.
478///
479/// This function automatically converts arrays to the specified precision
480/// before executing the operation.
481#[allow(dead_code)]
482pub fn execute_with_precision<F, R>(
483    arrays: &[&dyn MixedPrecisionSupport],
484    precision: Precision,
485    executor: F,
486) -> CoreResult<R>
487where
488    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
489    R: 'static,
490{
491    // Check if all arrays support the requested precision
492    for array in arrays {
493        if !array.supports_precision(precision) {
494            return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
495                "One or more arrays do not support {precision} precision"
496            ))));
497        }
498    }
499
500    // Convert arrays to the requested precision
501    let mut converted_arrays = Vec::with_capacity(arrays.len());
502
503    for &array in arrays {
504        let converted = array.to_precision(precision)?;
505        converted_arrays.push(converted);
506    }
507
508    // NOTE: Trait upcasting is unstable, so we skip this for now
509    // This functionality is not critical for TenRSo
510    // TODO: Re-enable once trait_upcasting is stabilized (RFC #65991)
511
512    // Workaround: just return error for now
513    Err("Mixed precision batch execution not supported on stable Rust - requires trait_upcasting feature".to_string().into())
514}
515
516/// Implementation of common array operations with mixed precision.
517pub mod ops {
518    use super::*;
519    use crate::array_protocol::operations as array_ops;
520
521    /// Matrix multiplication with specified precision.
522    pub fn matmul(
523        a: &dyn MixedPrecisionSupport,
524        b: &dyn MixedPrecisionSupport,
525        precision: Precision,
526    ) -> CoreResult<Box<dyn ArrayProtocol>> {
527        execute_with_precision(&[a, b], precision, |arrays| {
528            // Convert OperationError to CoreError
529            match array_ops::matmul(arrays[0], arrays[1]) {
530                Ok(result) => Ok(result),
531                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
532                    e.to_string(),
533                ))),
534            }
535        })
536    }
537
538    /// Element-wise addition with specified precision.
539    pub fn add(
540        a: &dyn MixedPrecisionSupport,
541        b: &dyn MixedPrecisionSupport,
542        precision: Precision,
543    ) -> CoreResult<Box<dyn ArrayProtocol>> {
544        execute_with_precision(&[a, b], precision, |arrays| {
545            // Convert OperationError to CoreError
546            match array_ops::add(arrays[0], arrays[1]) {
547                Ok(result) => Ok(result),
548                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
549                    e.to_string(),
550                ))),
551            }
552        })
553    }
554
555    /// Element-wise multiplication with specified precision.
556    pub fn multiply(
557        a: &dyn MixedPrecisionSupport,
558        b: &dyn MixedPrecisionSupport,
559        precision: Precision,
560    ) -> CoreResult<Box<dyn ArrayProtocol>> {
561        execute_with_precision(&[a, b], precision, |arrays| {
562            // Convert OperationError to CoreError
563            match array_ops::multiply(arrays[0], arrays[1]) {
564                Ok(result) => Ok(result),
565                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
566                    e.to_string(),
567                ))),
568            }
569        })
570    }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::arr2;

    /// A freshly wrapped f64 array stores at double precision and is
    /// recoverable as its concrete type through `ArrayProtocol::as_any`.
    #[test]
    fn test_mixed_precision_array() {
        let mixed_array = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));

        // f64 elements are 8 bytes wide, hence double precision.
        assert_eq!(mixed_array.storage_precision(), Precision::Double);

        // The concrete type is MixedPrecisionArray<f64, Ix2> (not IxDyn).
        let array_protocol: &dyn ArrayProtocol = &mixed_array;
        let is_expected_type = array_protocol
            .as_any()
            .is::<MixedPrecisionArray<f64, crate::ndarray::Ix2>>();
        assert!(is_expected_type);
    }

    /// The `MixedPrecisionSupport` view of an f64 array reports double
    /// precision and accepts both single and double.
    #[test]
    fn test_mixed_precision_support() {
        // The array protocol has to be initialised first.
        crate::array_protocol::init();

        let mixed_array = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));
        let mixed_support: &dyn MixedPrecisionSupport = &mixed_array;

        assert_eq!(mixed_support.precision(), Precision::Double);
        for supported in [Precision::Single, Precision::Double] {
            assert!(mixed_support.supports_precision(supported));
        }
    }
}