scirs2_core/array_protocol/
mixed_precision.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under either of
4//
5// * Apache License, Version 2.0
6//   (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
7// * MIT license
8//   (LICENSE-MIT or http://opensource.org/licenses/MIT)
9//
10// at your option.
11//
12
13//! Mixed-precision operations for the array protocol.
14//!
15//! This module provides support for mixed-precision operations, allowing
16//! arrays to use different numeric types (e.g., f32, f64) for storage
17//! and computation to optimize performance and memory usage.
18
19use std::any::{Any, TypeId};
20use std::collections::HashMap;
21use std::fmt;
22use std::sync::{LazyLock, RwLock};
23
24use ndarray::{Array, Dimension};
25use num_traits::Float;
26
27use crate::array_protocol::gpu_impl::GPUNdarray;
28use crate::array_protocol::{
29    ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
30};
31use crate::error::{CoreError, CoreResult, ErrorContext};
32
/// Precision levels for array operations.
///
/// `Hash` is derived alongside `Eq` so a precision can be used directly as a
/// key in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Precision {
    /// Half-precision floating point (16-bit)
    Half,

    /// Single-precision floating point (32-bit)
    Single,

    /// Double-precision floating point (64-bit)
    Double,

    /// Mixed precision (e.g., store in 16/32-bit, compute in 64-bit)
    Mixed,
}
48
49impl fmt::Display for Precision {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        match self {
52            Precision::Half => write!(f, "half"),
53            Precision::Single => write!(f, "single"),
54            Precision::Double => write!(f, "double"),
55            Precision::Mixed => write!(f, "mixed"),
56        }
57    }
58}
59
60/// Configuration for mixed-precision operations.
61#[derive(Debug, Clone)]
62pub struct MixedPrecisionConfig {
63    /// Storage precision for arrays.
64    pub storage_precision: Precision,
65
66    /// Computation precision for operations.
67    pub computeprecision: Precision,
68
69    /// Automatic precision selection based on array size and operation.
70    pub auto_precision: bool,
71
72    /// Threshold for automatic downcast to lower precision.
73    pub downcast_threshold: usize,
74
75    /// Always use double precision for intermediate results.
76    pub double_precision_accumulation: bool,
77}
78
79impl Default for MixedPrecisionConfig {
80    fn default() -> Self {
81        Self {
82            storage_precision: Precision::Single,
83            computeprecision: Precision::Double,
84            auto_precision: true,
85            downcast_threshold: 10_000_000, // 10M elements
86            double_precision_accumulation: true,
87        }
88    }
89}
90
91/// Global mixed-precision configuration.
92pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
93    RwLock::new(MixedPrecisionConfig {
94        storage_precision: Precision::Single,
95        computeprecision: Precision::Double,
96        auto_precision: true,
97        downcast_threshold: 10_000_000, // 10M elements
98        double_precision_accumulation: true,
99    })
100});
101
102/// Set the global mixed-precision configuration.
103#[allow(dead_code)]
104pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
105    if let Ok(mut global_config) = MIXED_PRECISION_CONFIG.write() {
106        *global_config = config;
107    }
108}
109
110/// Get the current mixed-precision configuration.
111#[allow(dead_code)]
112pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
113    MIXED_PRECISION_CONFIG
114        .read()
115        .map(|c| c.clone())
116        .unwrap_or_default()
117}
118
119/// Determine the optimal precision for an array based on its size.
120#[allow(dead_code)]
121pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
122where
123    T: Clone + 'static,
124    D: Dimension,
125{
126    let config = get_mixed_precision_config();
127    let size = array.len();
128
129    if config.auto_precision {
130        if size >= config.downcast_threshold {
131            Precision::Single
132        } else {
133            Precision::Double
134        }
135    } else {
136        config.storage_precision
137    }
138}
139
140/// Mixed-precision array that can automatically convert between precisions.
141///
142/// This wrapper enables arrays to use different precision levels for storage
143/// and computation, automatically converting between precisions as needed.
144#[derive(Debug, Clone)]
145pub struct MixedPrecisionArray<T, D>
146where
147    T: Clone + 'static,
148    D: Dimension,
149{
150    /// The array stored at the specified precision.
151    array: Array<T, D>,
152
153    /// The current storage precision.
154    storage_precision: Precision,
155
156    /// The precision used for computations.
157    computeprecision: Precision,
158}
159
160impl<T, D> MixedPrecisionArray<T, D>
161where
162    T: Clone + Float + 'static,
163    D: Dimension,
164{
165    /// Create a new mixed-precision array.
166    pub fn new(array: Array<T, D>) -> Self {
167        let precision = match std::mem::size_of::<T>() {
168            2 => Precision::Half,
169            4 => Precision::Single,
170            8 => Precision::Double,
171            _ => Precision::Mixed,
172        };
173
174        Self {
175            array,
176            storage_precision: precision,
177            computeprecision: precision,
178        }
179    }
180
181    /// Create a new mixed-precision array with specified compute precision.
182    pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
183        let storage_precision = match std::mem::size_of::<T>() {
184            2 => Precision::Half,
185            4 => Precision::Single,
186            8 => Precision::Double,
187            _ => Precision::Mixed,
188        };
189
190        Self {
191            array: data,
192            storage_precision,
193            computeprecision,
194        }
195    }
196
197    /// Get the array at the specified precision.
198    ///
199    /// This is a placeholder implementation. In a real implementation,
200    /// this would convert the array to the requested precision.
201    pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
202    where
203        U: Clone + Float + 'static,
204    {
205        // This is a simplified implementation for demonstration purposes.
206        // In a real implementation, this would handle proper type conversion.
207        Err(CoreError::NotImplementedError(ErrorContext::new(
208            "Precision conversion not fully implemented yet",
209        )))
210    }
211
212    /// Get the current storage precision.
213    pub fn storage_precision(&self) -> Precision {
214        self.storage_precision
215    }
216
217    /// Get the underlying array.
218    pub const fn array(&self) -> &Array<T, D> {
219        &self.array
220    }
221}
222
223/// Trait for arrays that support mixed-precision operations.
224pub trait MixedPrecisionSupport: ArrayProtocol {
225    /// Convert the array to the specified precision.
226    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
227
228    /// Get the current precision of the array.
229    fn precision(&self) -> Precision;
230
231    /// Check if the array supports the specified precision.
232    fn supports_precision(&self, precision: Precision) -> bool;
233}
234
235/// Implement ArrayProtocol for MixedPrecisionArray.
236impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
237where
238    T: Clone + Float + Send + Sync + 'static,
239    D: Dimension + Send + Sync + 'static,
240{
241    fn array_function(
242        &self,
243        func: &ArrayFunction,
244        types: &[TypeId],
245        args: &[Box<dyn Any>],
246        kwargs: &HashMap<String, Box<dyn Any>>,
247    ) -> Result<Box<dyn Any>, NotImplemented> {
248        // If the function supports mixed precision, delegate to the appropriate implementation
249        let precision = kwargs
250            .get("precision")
251            .and_then(|p| p.downcast_ref::<Precision>())
252            .cloned()
253            .unwrap_or(self.computeprecision);
254
255        // Determine operating precision based on function and arguments
256        match func.name {
257            "scirs2::array_protocol::operations::matmul" => {
258                // If we have a second argument, check its precision
259                if args.len() >= 2 {
260                    // Adjust to highest precision of the two arrays
261                    if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
262                        let other_precision = other.computeprecision;
263                        let _precision_to_use = match (precision, other_precision) {
264                            (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
265                            (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
266                            (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
267                            (Precision::Half, Precision::Half) => Precision::Half,
268                        };
269
270                        // We can't modify kwargs, so we'll just forward directly
271                        // Get NdarrayWrapper for self array
272                        let wrapped_self = NdarrayWrapper::new(self.array.clone());
273
274                        // Delegate to the NdarrayWrapper implementation
275                        return wrapped_self.array_function(func, types, args, kwargs);
276                    }
277                }
278
279                // Convert to the requested precision and use standard implementation
280                match precision {
281                    Precision::Single | Precision::Double => {
282                        // Wrap in NdarrayWrapper for computation
283                        let wrapped = NdarrayWrapper::new(self.array.clone());
284
285                        // Adjust args to use wrapped version
286                        let mut new_args = Vec::with_capacity(args.len());
287                        new_args.push(Box::new(wrapped.clone()));
288
289                        // We don't need to include other args since we already have a new wrapped object
290                        // For simplicity, just delegate to the original args
291                        // Delegate to NdarrayWrapper
292                        wrapped.array_function(func, types, args, kwargs)
293                    }
294                    Precision::Mixed => {
295                        // Use Double precision for Mixed calculations
296                        let wrapped = NdarrayWrapper::new(self.array.clone());
297
298                        // Create new args and kwargs with Double precision
299                        let mut new_args = Vec::with_capacity(args.len());
300                        new_args.push(Box::new(wrapped.clone()));
301
302                        // We can't modify kwargs, so just forward along
303                        // Delegate to NdarrayWrapper directly with original args and kwargs
304                        wrapped.array_function(func, types, args, kwargs)
305                    }
306                    _ => Err(NotImplemented),
307                }
308            }
309            "scirs2::array_protocol::operations::add"
310            | "scirs2::array_protocol::operations::subtract"
311            | "scirs2::array_protocol::operations::multiply" => {
312                // Similar pattern for element-wise operations
313                // If we have a second argument, check its precision
314                if args.len() >= 2 {
315                    if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
316                        // Use the highest precision for the operation
317                        let other_precision = other.computeprecision;
318                        let _precision_to_use = match (precision, other_precision) {
319                            (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
320                            (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
321                            (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
322                            (Precision::Half, Precision::Half) => Precision::Half,
323                        };
324
325                        // We can't modify kwargs, so we'll just forward directly
326                        // Get NdarrayWrapper for self array
327                        let wrapped_self = NdarrayWrapper::new(self.array.clone());
328
329                        // Delegate to the NdarrayWrapper implementation
330                        return wrapped_self.array_function(func, types, args, kwargs);
331                    }
332                }
333
334                // Convert to the requested precision and use standard implementation
335                let wrapped = NdarrayWrapper::new(self.array.clone());
336
337                // Delegate to NdarrayWrapper with original args
338                wrapped.array_function(func, types, args, kwargs)
339            }
340            "scirs2::array_protocol::operations::transpose"
341            | "scirs2::array_protocol::operations::reshape"
342            | "scirs2::array_protocol::operations::sum" => {
343                // For unary operations, simply use the current precision
344                // Convert to standard wrapper and delegate
345                let wrapped = NdarrayWrapper::new(self.array.clone());
346
347                // Delegate to NdarrayWrapper with original args
348                wrapped.array_function(func, types, args, kwargs)
349            }
350            _ => {
351                // For any other function, delegate to standard implementation
352                let wrapped = NdarrayWrapper::new(self.array.clone());
353                wrapped.array_function(func, types, args, kwargs)
354            }
355        }
356    }
357
358    fn as_any(&self) -> &dyn Any {
359        self
360    }
361
362    fn shape(&self) -> &[usize] {
363        self.array.shape()
364    }
365
366    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
367        Box::new(Self {
368            array: self.array.clone(),
369            storage_precision: self.storage_precision,
370            computeprecision: self.computeprecision,
371        })
372    }
373}
374
375/// Implement MixedPrecisionSupport for MixedPrecisionArray.
376impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
377where
378    T: Clone + Float + Send + Sync + 'static,
379    D: Dimension + Send + Sync + 'static,
380{
381    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
382        match precision {
383            Precision::Single => {
384                // For actual implementation, this would convert f64 to f32 if needed
385                // This is a simplified version - in reality, we would need to convert between types
386
387                let current_precision = self.precision();
388                if current_precision == Precision::Single {
389                    // Already in single precision
390                    return Ok(Box::new(self.clone()));
391                }
392
393                // In real implementation, would handle proper conversion from T to f32
394                // For now, create a new array with the requested precision
395                let array_single = self.array.clone();
396                let newarray = MixedPrecisionArray::with_computeprecision(array_single, precision);
397                Ok(Box::new(newarray))
398            }
399            Precision::Double => {
400                // For actual implementation, this would convert f32 to f64 if needed
401
402                let current_precision = self.precision();
403                if current_precision == Precision::Double {
404                    // Already in double precision
405                    return Ok(Box::new(self.clone()));
406                }
407
408                // In real implementation, would handle proper conversion from T to f64
409                // For now, create a new array with the requested precision
410                let array_double = self.array.clone();
411                let newarray = MixedPrecisionArray::with_computeprecision(array_double, precision);
412                Ok(Box::new(newarray))
413            }
414            Precision::Mixed => {
415                // For mixed precision, use storage precision of the current array and double compute precision
416                let array_mixed = self.array.clone();
417                let newarray =
418                    MixedPrecisionArray::with_computeprecision(array_mixed, Precision::Double);
419                Ok(Box::new(newarray))
420            }
421            _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
422                "Conversion to {precision} precision not implemented"
423            )))),
424        }
425    }
426
427    fn precision(&self) -> Precision {
428        // If storage and compute precision differ, return Mixed
429        if self.storage_precision != self.computeprecision {
430            Precision::Mixed
431        } else {
432            self.storage_precision
433        }
434    }
435
436    fn supports_precision(&self, precision: Precision) -> bool {
437        matches!(precision, Precision::Single | Precision::Double)
438    }
439}
440
441/// Implement MixedPrecisionSupport for GPUNdarray.
442impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
443where
444    T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
445    D: Dimension + Send + Sync + 'static + ndarray::RemoveAxis,
446{
447    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
448        // For GPUs, creating a new array with mixed precision enabled
449        let mut config = self.config().clone();
450        config.mixed_precision = precision == Precision::Mixed;
451
452        if let Ok(cpu_array) = self.to_cpu() {
453            // Use as_any() to downcast the ArrayProtocol trait object
454            if let Some(ndarray) = cpu_array.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
455                let new_gpu_array = GPUNdarray::new(ndarray.as_array().clone(), config);
456                return Ok(Box::new(new_gpu_array));
457            }
458        }
459
460        Err(CoreError::NotImplementedError(ErrorContext::new(format!(
461            "Conversion to {precision} precision not implemented for GPU arrays"
462        ))))
463    }
464
465    fn precision(&self) -> Precision {
466        if self.config().mixed_precision {
467            Precision::Mixed
468        } else {
469            match std::mem::size_of::<T>() {
470                4 => Precision::Single,
471                8 => Precision::Double,
472                _ => Precision::Mixed,
473            }
474        }
475    }
476
477    fn supports_precision(&self, precision: Precision) -> bool {
478        // Most GPUs support all precision levels
479        true
480    }
481}
482
483/// Execute an operation with a specific precision.
484///
485/// This function automatically converts arrays to the specified precision
486/// before executing the operation.
487#[allow(dead_code)]
488pub fn execute_with_precision<F, R>(
489    arrays: &[&dyn MixedPrecisionSupport],
490    precision: Precision,
491    executor: F,
492) -> CoreResult<R>
493where
494    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
495    R: 'static,
496{
497    // Check if all arrays support the requested precision
498    for array in arrays {
499        if !array.supports_precision(precision) {
500            return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
501                "One or more arrays do not support {precision} precision"
502            ))));
503        }
504    }
505
506    // Convert arrays to the requested precision
507    let mut converted_arrays = Vec::with_capacity(arrays.len());
508
509    for &array in arrays {
510        let converted = array.to_precision(precision)?;
511        converted_arrays.push(converted);
512    }
513
514    // Create array references
515    let array_refs: Vec<&dyn ArrayProtocol> = converted_arrays
516        .iter()
517        .map(|arr| arr.as_ref() as &dyn ArrayProtocol)
518        .collect();
519
520    // Execute the operation
521    executor(&array_refs)
522}
523
524/// Implementation of common array operations with mixed precision.
525pub mod ops {
526    use super::*;
527    use crate::array_protocol::operations as array_ops;
528
529    /// Matrix multiplication with specified precision.
530    pub fn matmul(
531        a: &dyn MixedPrecisionSupport,
532        b: &dyn MixedPrecisionSupport,
533        precision: Precision,
534    ) -> CoreResult<Box<dyn ArrayProtocol>> {
535        execute_with_precision(&[a, b], precision, |arrays| {
536            // Convert OperationError to CoreError
537            match array_ops::matmul(arrays[0], arrays[1]) {
538                Ok(result) => Ok(result),
539                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
540                    e.to_string(),
541                ))),
542            }
543        })
544    }
545
546    /// Element-wise addition with specified precision.
547    pub fn add(
548        a: &dyn MixedPrecisionSupport,
549        b: &dyn MixedPrecisionSupport,
550        precision: Precision,
551    ) -> CoreResult<Box<dyn ArrayProtocol>> {
552        execute_with_precision(&[a, b], precision, |arrays| {
553            // Convert OperationError to CoreError
554            match array_ops::add(arrays[0], arrays[1]) {
555                Ok(result) => Ok(result),
556                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
557                    e.to_string(),
558                ))),
559            }
560        })
561    }
562
563    /// Element-wise multiplication with specified precision.
564    pub fn multiply(
565        a: &dyn MixedPrecisionSupport,
566        b: &dyn MixedPrecisionSupport,
567        precision: Precision,
568    ) -> CoreResult<Box<dyn ArrayProtocol>> {
569        execute_with_precision(&[a, b], precision, |arrays| {
570            // Convert OperationError to CoreError
571            match array_ops::multiply(arrays[0], arrays[1]) {
572                Ok(result) => Ok(result),
573                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
574                    e.to_string(),
575                ))),
576            }
577        })
578    }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_mixed_precision_array() {
        // f64 elements map to double-precision storage.
        let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let mixed_array = MixedPrecisionArray::new(array);
        assert_eq!(mixed_array.storage_precision(), Precision::Double);

        // The ArrayProtocol view still downcasts to the concrete
        // MixedPrecisionArray<f64, Ix2> (not IxDyn).
        let array_protocol: &dyn ArrayProtocol = &mixed_array;
        assert!(array_protocol
            .as_any()
            .is::<MixedPrecisionArray<f64, ndarray::Ix2>>());
    }

    #[test]
    fn test_mixed_precision_support() {
        // Initialize the array protocol.
        crate::array_protocol::init();

        let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let mixed_array = MixedPrecisionArray::new(array);

        let mixed_support: &dyn MixedPrecisionSupport = &mixed_array;
        assert_eq!(mixed_support.precision(), Precision::Double);
        assert!(mixed_support.supports_precision(Precision::Single));
        assert!(mixed_support.supports_precision(Precision::Double));
    }

    #[test]
    fn test_precision_display() {
        assert_eq!(Precision::Half.to_string(), "half");
        assert_eq!(Precision::Single.to_string(), "single");
        assert_eq!(Precision::Double.to_string(), "double");
        assert_eq!(Precision::Mixed.to_string(), "mixed");
    }

    #[test]
    fn test_differing_compute_precision_reports_mixed() {
        // f32 storage with double-precision compute is reported as Mixed.
        let array = arr2(&[[1.0f32, 2.0], [3.0, 4.0]]);
        let mixed_array = MixedPrecisionArray::with_computeprecision(array, Precision::Double);
        assert_eq!(mixed_array.storage_precision(), Precision::Single);

        let mixed_support: &dyn MixedPrecisionSupport = &mixed_array;
        assert_eq!(mixed_support.precision(), Precision::Mixed);
    }
}