1use std::any::{Any, TypeId};
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::{LazyLock, RwLock};
17
18use ::ndarray::{Array, Dimension};
19use num_traits::{cast as num_cast, Float};
20
21use crate::array_protocol::gpu_impl::GPUNdarray;
22use crate::array_protocol::{
23 ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
24};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// Floating-point precision levels an array can be stored in or computed at.
///
/// `Hash` is derived (alongside `Eq`) so precisions can be used as keys in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Precision {
    /// Half precision (element types that are 2 bytes wide).
    Half,
    /// Single precision (element types that are 4 bytes wide).
    Single,
    /// Double precision (element types that are 8 bytes wide).
    Double,
    /// Storage and compute precision differ; also the fallback for element
    /// types of any other width.
    Mixed,
}
42
43impl fmt::Display for Precision {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 Precision::Half => write!(f, "half"),
47 Precision::Single => write!(f, "single"),
48 Precision::Double => write!(f, "double"),
49 Precision::Mixed => write!(f, "mixed"),
50 }
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct MixedPrecisionConfig {
57 pub storage_precision: Precision,
59
60 pub computeprecision: Precision,
62
63 pub auto_precision: bool,
65
66 pub downcast_threshold: usize,
68
69 pub double_precision_accumulation: bool,
71}
72
73impl Default for MixedPrecisionConfig {
74 fn default() -> Self {
75 Self {
76 storage_precision: Precision::Single,
77 computeprecision: Precision::Double,
78 auto_precision: true,
79 downcast_threshold: 10_000_000, double_precision_accumulation: true,
81 }
82 }
83}
84
85pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
87 RwLock::new(MixedPrecisionConfig {
88 storage_precision: Precision::Single,
89 computeprecision: Precision::Double,
90 auto_precision: true,
91 downcast_threshold: 10_000_000, double_precision_accumulation: true,
93 })
94});
95
96#[allow(dead_code)]
98pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
99 if let Ok(mut global_config) = MIXED_PRECISION_CONFIG.write() {
100 *global_config = config;
101 }
102}
103
104#[allow(dead_code)]
106pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
107 MIXED_PRECISION_CONFIG
108 .read()
109 .map(|c| c.clone())
110 .unwrap_or_default()
111}
112
113#[allow(dead_code)]
115pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
116where
117 T: Clone + 'static,
118 D: Dimension,
119{
120 let config = get_mixed_precision_config();
121 let size = array.len();
122
123 if config.auto_precision {
124 if size >= config.downcast_threshold {
125 Precision::Single
126 } else {
127 Precision::Double
128 }
129 } else {
130 config.storage_precision
131 }
132}
133
134#[derive(Debug, Clone)]
139pub struct MixedPrecisionArray<T, D>
140where
141 T: Clone + 'static,
142 D: Dimension,
143{
144 array: Array<T, D>,
146
147 storage_precision: Precision,
149
150 computeprecision: Precision,
152}
153
154impl<T, D> MixedPrecisionArray<T, D>
155where
156 T: Clone + Float + 'static,
157 D: Dimension,
158{
159 pub fn new(array: Array<T, D>) -> Self {
161 let precision = match std::mem::size_of::<T>() {
162 2 => Precision::Half,
163 4 => Precision::Single,
164 8 => Precision::Double,
165 _ => Precision::Mixed,
166 };
167
168 Self {
169 array,
170 storage_precision: precision,
171 computeprecision: precision,
172 }
173 }
174
175 pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
177 let storage_precision = match std::mem::size_of::<T>() {
178 2 => Precision::Half,
179 4 => Precision::Single,
180 8 => Precision::Double,
181 _ => Precision::Mixed,
182 };
183
184 Self {
185 array: data,
186 storage_precision,
187 computeprecision,
188 }
189 }
190
191 pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
210 where
211 U: Clone + Float + 'static,
212 {
213 let mut converted: Vec<U> = Vec::with_capacity(self.array.len());
215 for x in self.array.iter() {
216 match num_cast::<T, U>(*x) {
217 Some(v) => converted.push(v),
218 None => {
219 return Err(CoreError::ComputationError(ErrorContext::new(format!(
220 "at_precision: failed to cast element to target precision (source size \
221 {} bytes, target size {} bytes)",
222 std::mem::size_of::<T>(),
223 std::mem::size_of::<U>(),
224 ))))
225 }
226 }
227 }
228
229 Array::from_shape_vec(self.array.raw_dim(), converted).map_err(|e| {
231 CoreError::ShapeError(ErrorContext::new(format!(
232 "at_precision: failed to reconstruct array from converted elements: {e}"
233 )))
234 })
235 }
236
237 pub fn storage_precision(&self) -> Precision {
239 self.storage_precision
240 }
241
242 pub const fn array(&self) -> &Array<T, D> {
244 &self.array
245 }
246}
247
248pub trait MixedPrecisionSupport: ArrayProtocol {
250 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
252
253 fn precision(&self) -> Precision;
255
256 fn supports_precision(&self, precision: Precision) -> bool;
258}
259
260impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
262where
263 T: Clone + Float + Send + Sync + 'static,
264 D: Dimension + Send + Sync + 'static,
265{
266 fn array_function(
267 &self,
268 func: &ArrayFunction,
269 types: &[TypeId],
270 args: &[Box<dyn Any>],
271 kwargs: &HashMap<String, Box<dyn Any>>,
272 ) -> Result<Box<dyn Any>, NotImplemented> {
273 let precision = kwargs
275 .get("precision")
276 .and_then(|p| p.downcast_ref::<Precision>())
277 .cloned()
278 .unwrap_or(self.computeprecision);
279
280 match func.name {
282 "scirs2::array_protocol::operations::matmul" => {
283 if args.len() >= 2 {
285 if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
287 let other_precision = other.computeprecision;
288 let _precision_to_use = match (precision, other_precision) {
289 (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
290 (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
291 (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
292 (Precision::Half, Precision::Half) => Precision::Half,
293 };
294
295 let wrapped_self = NdarrayWrapper::new(self.array.clone());
298
299 return wrapped_self.array_function(func, types, args, kwargs);
301 }
302 }
303
304 match precision {
306 Precision::Single | Precision::Double => {
307 let wrapped = NdarrayWrapper::new(self.array.clone());
309
310 let mut new_args = Vec::with_capacity(args.len());
312 new_args.push(Box::new(wrapped.clone()));
313
314 wrapped.array_function(func, types, args, kwargs)
318 }
319 Precision::Mixed => {
320 let wrapped = NdarrayWrapper::new(self.array.clone());
322
323 let mut new_args = Vec::with_capacity(args.len());
325 new_args.push(Box::new(wrapped.clone()));
326
327 wrapped.array_function(func, types, args, kwargs)
330 }
331 _ => Err(NotImplemented),
332 }
333 }
334 "scirs2::array_protocol::operations::add"
335 | "scirs2::array_protocol::operations::subtract"
336 | "scirs2::array_protocol::operations::multiply" => {
337 if args.len() >= 2 {
340 if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
341 let other_precision = other.computeprecision;
343 let _precision_to_use = match (precision, other_precision) {
344 (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
345 (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
346 (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
347 (Precision::Half, Precision::Half) => Precision::Half,
348 };
349
350 let wrapped_self = NdarrayWrapper::new(self.array.clone());
353
354 return wrapped_self.array_function(func, types, args, kwargs);
356 }
357 }
358
359 let wrapped = NdarrayWrapper::new(self.array.clone());
361
362 wrapped.array_function(func, types, args, kwargs)
364 }
365 "scirs2::array_protocol::operations::transpose"
366 | "scirs2::array_protocol::operations::reshape"
367 | "scirs2::array_protocol::operations::sum" => {
368 let wrapped = NdarrayWrapper::new(self.array.clone());
371
372 wrapped.array_function(func, types, args, kwargs)
374 }
375 _ => {
376 let wrapped = NdarrayWrapper::new(self.array.clone());
378 wrapped.array_function(func, types, args, kwargs)
379 }
380 }
381 }
382
383 fn as_any(&self) -> &dyn Any {
384 self
385 }
386
387 fn shape(&self) -> &[usize] {
388 self.array.shape()
389 }
390
391 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
392 Box::new(Self {
393 array: self.array.clone(),
394 storage_precision: self.storage_precision,
395 computeprecision: self.computeprecision,
396 })
397 }
398}
399
400impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
402where
403 T: Clone + Float + Send + Sync + 'static,
404 D: Dimension + Send + Sync + 'static,
405{
406 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
407 match precision {
408 Precision::Single => {
409 let current_precision = self.precision();
413 if current_precision == Precision::Single {
414 return Ok(Box::new(self.clone()));
416 }
417
418 let array_single = self.array.clone();
421 let newarray = MixedPrecisionArray::with_computeprecision(array_single, precision);
422 Ok(Box::new(newarray))
423 }
424 Precision::Double => {
425 let current_precision = self.precision();
428 if current_precision == Precision::Double {
429 return Ok(Box::new(self.clone()));
431 }
432
433 let array_double = self.array.clone();
436 let newarray = MixedPrecisionArray::with_computeprecision(array_double, precision);
437 Ok(Box::new(newarray))
438 }
439 Precision::Mixed => {
440 let array_mixed = self.array.clone();
442 let newarray =
443 MixedPrecisionArray::with_computeprecision(array_mixed, Precision::Double);
444 Ok(Box::new(newarray))
445 }
446 _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
447 "Conversion to {precision} precision not implemented"
448 )))),
449 }
450 }
451
452 fn precision(&self) -> Precision {
453 if self.storage_precision != self.computeprecision {
455 Precision::Mixed
456 } else {
457 self.storage_precision
458 }
459 }
460
461 fn supports_precision(&self, precision: Precision) -> bool {
462 matches!(precision, Precision::Single | Precision::Double)
463 }
464}
465
466impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
468where
469 T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
470 D: Dimension + Send + Sync + 'static + crate::ndarray::RemoveAxis,
471{
472 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
473 let mut config = self.config().clone();
475 config.mixed_precision = precision == Precision::Mixed;
476
477 if let Ok(cpu_array) = self.to_cpu() {
478 if let Some(ndarray) = cpu_array.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
480 let new_gpu_array = GPUNdarray::new(ndarray.as_array().clone(), config);
481 return Ok(Box::new(new_gpu_array));
482 }
483 }
484
485 Err(CoreError::NotImplementedError(ErrorContext::new(format!(
486 "Conversion to {precision} precision not implemented for GPU arrays"
487 ))))
488 }
489
490 fn precision(&self) -> Precision {
491 if self.config().mixed_precision {
492 Precision::Mixed
493 } else {
494 match std::mem::size_of::<T>() {
495 4 => Precision::Single,
496 8 => Precision::Double,
497 _ => Precision::Mixed,
498 }
499 }
500 }
501
502 fn supports_precision(&self, precision: Precision) -> bool {
503 true
505 }
506}
507
508#[allow(dead_code)]
513pub fn execute_with_precision<F, R>(
514 arrays: &[&dyn MixedPrecisionSupport],
515 precision: Precision,
516 executor: F,
517) -> CoreResult<R>
518where
519 F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
520 R: 'static,
521{
522 for array in arrays {
524 if !array.supports_precision(precision) {
525 return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
526 "One or more arrays do not support {precision} precision"
527 ))));
528 }
529 }
530
531 let mut converted_arrays = Vec::with_capacity(arrays.len());
533
534 for &array in arrays {
535 let converted = array.to_precision(precision)?;
536 converted_arrays.push(converted);
537 }
538
539 Err("Mixed precision batch execution not supported on stable Rust - requires trait_upcasting feature".to_string().into())
545}
546
547pub mod ops {
549 use super::*;
550 use crate::array_protocol::operations as array_ops;
551
552 pub fn matmul(
554 a: &dyn MixedPrecisionSupport,
555 b: &dyn MixedPrecisionSupport,
556 precision: Precision,
557 ) -> CoreResult<Box<dyn ArrayProtocol>> {
558 execute_with_precision(&[a, b], precision, |arrays| {
559 match array_ops::matmul(arrays[0], arrays[1]) {
561 Ok(result) => Ok(result),
562 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
563 e.to_string(),
564 ))),
565 }
566 })
567 }
568
569 pub fn add(
571 a: &dyn MixedPrecisionSupport,
572 b: &dyn MixedPrecisionSupport,
573 precision: Precision,
574 ) -> CoreResult<Box<dyn ArrayProtocol>> {
575 execute_with_precision(&[a, b], precision, |arrays| {
576 match array_ops::add(arrays[0], arrays[1]) {
578 Ok(result) => Ok(result),
579 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
580 e.to_string(),
581 ))),
582 }
583 })
584 }
585
586 pub fn multiply(
588 a: &dyn MixedPrecisionSupport,
589 b: &dyn MixedPrecisionSupport,
590 precision: Precision,
591 ) -> CoreResult<Box<dyn ArrayProtocol>> {
592 execute_with_precision(&[a, b], precision, |arrays| {
593 match array_ops::multiply(arrays[0], arrays[1]) {
595 Ok(result) => Ok(result),
596 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
597 e.to_string(),
598 ))),
599 }
600 })
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::arr2;

    #[test]
    fn test_mixed_precision_array() {
        let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let mixed_array = MixedPrecisionArray::new(array.clone());

        assert_eq!(mixed_array.storage_precision(), Precision::Double);

        let array_protocol: &dyn ArrayProtocol = &mixed_array;
        assert!(array_protocol
            .as_any()
            .is::<MixedPrecisionArray<f64, crate::ndarray::Ix2>>());
    }

    #[test]
    fn test_mixed_precision_support() {
        crate::array_protocol::init();

        let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let mixed_array = MixedPrecisionArray::new(array.clone());

        let mixed_support: &dyn MixedPrecisionSupport = &mixed_array;
        assert_eq!(mixed_support.precision(), Precision::Double);
        assert!(mixed_support.supports_precision(Precision::Single));
        assert!(mixed_support.supports_precision(Precision::Double));
    }

    #[test]
    fn test_precision_display() {
        assert_eq!(Precision::Half.to_string(), "half");
        assert_eq!(Precision::Single.to_string(), "single");
        assert_eq!(Precision::Double.to_string(), "double");
        assert_eq!(Precision::Mixed.to_string(), "mixed");
    }

    #[test]
    fn test_with_computeprecision_reports_mixed() {
        let array = arr2(&[[1.0_f32, 2.0], [3.0, 4.0]]);
        let mixed_array = MixedPrecisionArray::with_computeprecision(array, Precision::Double);

        // f32 storage with double compute must be reported as mixed.
        assert_eq!(mixed_array.storage_precision(), Precision::Single);
        assert_eq!(
            MixedPrecisionSupport::precision(&mixed_array),
            Precision::Mixed
        );
    }

    #[test]
    fn test_to_precision_single_from_double_is_mixed() {
        let mixed_array = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let converted = mixed_array
            .to_precision(Precision::Single)
            .expect("f64 array should accept single compute precision");

        // Storage stays f64 while compute becomes single, so the result is mixed.
        assert_eq!(
            MixedPrecisionSupport::precision(&*converted),
            Precision::Mixed
        );
    }

    #[test]
    fn test_to_precision_half_is_rejected() {
        let mixed_array = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));

        assert!(!mixed_array.supports_precision(Precision::Half));
        assert!(mixed_array.to_precision(Precision::Half).is_err());
    }

    #[test]
    fn test_at_precision_f64_to_f32() {
        use ::ndarray::array;
        let arr = array![1.0_f64, 2.5_f64, -1.75_f64].into_dyn();
        let mp = MixedPrecisionArray::new(arr);
        let as_f32: crate::ndarray::ArrayD<f32> = mp
            .at_precision()
            .expect("f64 → f32 precision conversion should succeed");
        assert!((as_f32[0] - 1.0_f32).abs() < 1e-6);
        assert!((as_f32[1] - 2.5_f32).abs() < 1e-6);
        assert!((as_f32[2] - (-1.75_f32)).abs() < 1e-6);
    }

    #[test]
    fn test_at_precision_f32_to_f64() {
        use ::ndarray::array;
        let arr = array![0.5_f32, 1.25_f32, -2.0_f32].into_dyn();
        let mp = MixedPrecisionArray::new(arr);
        let as_f64: crate::ndarray::ArrayD<f64> = mp
            .at_precision()
            .expect("f32 → f64 precision conversion should succeed");
        assert!((as_f64[0] - 0.5_f64).abs() < 1e-12);
        assert!((as_f64[1] - 1.25_f64).abs() < 1e-12);
        assert!((as_f64[2] - (-2.0_f64)).abs() < 1e-12);
    }

    #[test]
    fn test_at_precision_same_type_is_identity() {
        use ::ndarray::array;
        let arr = array![42.0_f64, -7.5_f64].into_dyn();
        let mp = MixedPrecisionArray::new(arr.clone());
        let result: crate::ndarray::ArrayD<f64> = mp
            .at_precision()
            .expect("f64 → f64 precision conversion should succeed");
        for (a, b) in arr.iter().zip(result.iter()) {
            assert_eq!(*a, *b, "Identity conversion must not change values");
        }
    }

    #[test]
    fn test_at_precision_preserves_shape() {
        let arr = arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]);
        let mp = MixedPrecisionArray::new(arr);
        let as_f32: crate::ndarray::Array<f32, crate::ndarray::Ix2> = mp
            .at_precision()
            .expect("2D f64 → f32 conversion should succeed");
        assert_eq!(as_f32.shape(), &[2, 2]);
        assert!((as_f32[[0, 0]] - 1.0_f32).abs() < 1e-6);
        assert!((as_f32[[1, 1]] - 4.0_f32).abs() < 1e-6);
    }
}
699}