1use std::any::{Any, TypeId};
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::{LazyLock, RwLock};
17
18use ::ndarray::{Array, Dimension};
19use num_traits::{cast as num_cast, Float};
20
21use crate::array_protocol::gpu_impl::GPUNdarray;
22use crate::array_protocol::{
23 ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
24};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// Floating-point precision levels understood by the mixed-precision layer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Precision {
    /// Half precision (2-byte element types are classified as `Half`).
    Half,

    /// Single precision (4-byte element types are classified as `Single`).
    Single,

    /// Double precision (8-byte element types are classified as `Double`).
    Double,

    /// Mixed precision: storage and compute precision differ, or the element
    /// width is not one of the recognised sizes.
    Mixed,
}
42
43impl fmt::Display for Precision {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 Precision::Half => write!(f, "half"),
47 Precision::Single => write!(f, "single"),
48 Precision::Double => write!(f, "double"),
49 Precision::Mixed => write!(f, "mixed"),
50 }
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct MixedPrecisionConfig {
57 pub storage_precision: Precision,
59
60 pub computeprecision: Precision,
62
63 pub auto_precision: bool,
65
66 pub downcast_threshold: usize,
68
69 pub double_precision_accumulation: bool,
71}
72
73impl Default for MixedPrecisionConfig {
74 fn default() -> Self {
75 Self {
76 storage_precision: Precision::Single,
77 computeprecision: Precision::Double,
78 auto_precision: true,
79 downcast_threshold: 10_000_000, double_precision_accumulation: true,
81 }
82 }
83}
84
85pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
87 RwLock::new(MixedPrecisionConfig {
88 storage_precision: Precision::Single,
89 computeprecision: Precision::Double,
90 auto_precision: true,
91 downcast_threshold: 10_000_000, double_precision_accumulation: true,
93 })
94});
95
96#[allow(dead_code)]
98pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
99 if let Ok(mut global_config) = MIXED_PRECISION_CONFIG.write() {
100 *global_config = config;
101 }
102}
103
104#[allow(dead_code)]
106pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
107 MIXED_PRECISION_CONFIG
108 .read()
109 .map(|c| c.clone())
110 .unwrap_or_default()
111}
112
113#[allow(dead_code)]
115pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
116where
117 T: Clone + 'static,
118 D: Dimension,
119{
120 let config = get_mixed_precision_config();
121 let size = array.len();
122
123 if config.auto_precision {
124 if size >= config.downcast_threshold {
125 Precision::Single
126 } else {
127 Precision::Double
128 }
129 } else {
130 config.storage_precision
131 }
132}
133
134#[derive(Debug, Clone)]
139pub struct MixedPrecisionArray<T, D>
140where
141 T: Clone + 'static,
142 D: Dimension,
143{
144 array: Array<T, D>,
146
147 storage_precision: Precision,
149
150 computeprecision: Precision,
152}
153
154impl<T, D> MixedPrecisionArray<T, D>
155where
156 T: Clone + Float + 'static,
157 D: Dimension,
158{
159 pub fn new(array: Array<T, D>) -> Self {
161 let precision = match std::mem::size_of::<T>() {
162 2 => Precision::Half,
163 4 => Precision::Single,
164 8 => Precision::Double,
165 _ => Precision::Mixed,
166 };
167
168 Self {
169 array,
170 storage_precision: precision,
171 computeprecision: precision,
172 }
173 }
174
175 pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
177 let storage_precision = match std::mem::size_of::<T>() {
178 2 => Precision::Half,
179 4 => Precision::Single,
180 8 => Precision::Double,
181 _ => Precision::Mixed,
182 };
183
184 Self {
185 array: data,
186 storage_precision,
187 computeprecision,
188 }
189 }
190
191 pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
210 where
211 U: Clone + Float + 'static,
212 {
213 let mut converted: Vec<U> = Vec::with_capacity(self.array.len());
215 for x in self.array.iter() {
216 match num_cast::<T, U>(*x) {
217 Some(v) => converted.push(v),
218 None => {
219 return Err(CoreError::ComputationError(ErrorContext::new(format!(
220 "at_precision: failed to cast element to target precision (source size \
221 {} bytes, target size {} bytes)",
222 std::mem::size_of::<T>(),
223 std::mem::size_of::<U>(),
224 ))))
225 }
226 }
227 }
228
229 Array::from_shape_vec(self.array.raw_dim(), converted).map_err(|e| {
231 CoreError::ShapeError(ErrorContext::new(format!(
232 "at_precision: failed to reconstruct array from converted elements: {e}"
233 )))
234 })
235 }
236
237 pub fn storage_precision(&self) -> Precision {
239 self.storage_precision
240 }
241
242 pub const fn array(&self) -> &Array<T, D> {
244 &self.array
245 }
246}
247
248pub trait MixedPrecisionSupport: ArrayProtocol {
250 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
252
253 fn precision(&self) -> Precision;
255
256 fn supports_precision(&self, precision: Precision) -> bool;
258
259 fn as_array_protocol(&self) -> &dyn ArrayProtocol;
271}
272
273fn extract_inner_ndarray<T, D>(arg: &dyn Any) -> Option<Array<T, D>>
283where
284 T: Clone + Float + Send + Sync + 'static,
285 D: Dimension + Send + Sync + 'static,
286{
287 if let Some(ap) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
290 let inner: &dyn ArrayProtocol = &**ap;
291 if let Some(mp) = inner.as_any().downcast_ref::<MixedPrecisionArray<T, D>>() {
292 return Some(mp.array.clone());
293 }
294 if let Some(nd) = inner.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
295 return Some(nd.as_array().clone());
296 }
297 return None;
298 }
299
300 if let Some(mp) = arg.downcast_ref::<MixedPrecisionArray<T, D>>() {
302 return Some(mp.array.clone());
303 }
304 if let Some(nd) = arg.downcast_ref::<NdarrayWrapper<T, D>>() {
305 return Some(nd.as_array().clone());
306 }
307
308 None
309}
310
311fn rewrap_result_as_array_protocol<T>(result: Box<dyn Any>) -> Box<dyn Any>
323where
324 T: Clone + Float + Send + Sync + 'static,
325{
326 use crate::ndarray::{Ix1, Ix2, IxDyn};
327
328 if result.is::<Box<dyn ArrayProtocol>>() {
330 return result;
331 }
332
333 let result = match result.downcast::<NdarrayWrapper<T, Ix2>>() {
335 Ok(wrapper) => {
336 let boxed: Box<dyn ArrayProtocol> = wrapper;
337 return Box::new(boxed);
338 }
339 Err(other) => other,
340 };
341
342 let result = match result.downcast::<NdarrayWrapper<T, Ix1>>() {
344 Ok(wrapper) => {
345 let boxed: Box<dyn ArrayProtocol> = wrapper;
346 return Box::new(boxed);
347 }
348 Err(other) => other,
349 };
350
351 match result.downcast::<NdarrayWrapper<T, IxDyn>>() {
353 Ok(wrapper) => {
354 let boxed: Box<dyn ArrayProtocol> = wrapper;
355 Box::new(boxed)
356 }
357 Err(other) => other,
359 }
360}
361
362impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
364where
365 T: Clone + Float + Send + Sync + 'static,
366 D: Dimension + Send + Sync + 'static,
367{
368 fn array_function(
369 &self,
370 func: &ArrayFunction,
371 types: &[TypeId],
372 args: &[Box<dyn Any>],
373 kwargs: &HashMap<String, Box<dyn Any>>,
374 ) -> Result<Box<dyn Any>, NotImplemented> {
375 let wrapped_self = NdarrayWrapper::new(self.array.clone());
379
380 let precision = kwargs
384 .get("precision")
385 .and_then(|p| p.downcast_ref::<Precision>())
386 .cloned()
387 .unwrap_or(self.computeprecision);
388
389 match func.name {
390 "scirs2::array_protocol::operations::matmul"
391 | "scirs2::array_protocol::operations::add"
392 | "scirs2::array_protocol::operations::subtract"
393 | "scirs2::array_protocol::operations::multiply" => {
394 if args.len() < 2 {
400 return Err(NotImplemented);
401 }
402
403 let Some(other_array) = extract_inner_ndarray::<T, D>(args[1].as_ref()) else {
404 return Err(NotImplemented);
405 };
406 let wrapped_other = NdarrayWrapper::new(other_array);
407
408 if matches!(precision, Precision::Half) {
411 return Err(NotImplemented);
412 }
413
414 let new_args: Vec<Box<dyn Any>> =
415 vec![Box::new(wrapped_self.clone()), Box::new(wrapped_other)];
416 wrapped_self
417 .array_function(func, types, &new_args, kwargs)
418 .map(rewrap_result_as_array_protocol::<T>)
419 }
420 "scirs2::array_protocol::operations::transpose"
421 | "scirs2::array_protocol::operations::reshape"
422 | "scirs2::array_protocol::operations::sum" => {
423 let new_args: Vec<Box<dyn Any>> = vec![Box::new(wrapped_self.clone())];
427 wrapped_self
428 .array_function(func, types, &new_args, kwargs)
429 .map(rewrap_result_as_array_protocol::<T>)
430 }
431 _ => {
432 wrapped_self.array_function(func, types, args, kwargs)
435 }
436 }
437 }
438
439 fn as_any(&self) -> &dyn Any {
440 self
441 }
442
443 fn shape(&self) -> &[usize] {
444 self.array.shape()
445 }
446
447 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
448 Box::new(Self {
449 array: self.array.clone(),
450 storage_precision: self.storage_precision,
451 computeprecision: self.computeprecision,
452 })
453 }
454}
455
456impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
458where
459 T: Clone + Float + Send + Sync + 'static,
460 D: Dimension + Send + Sync + 'static,
461{
462 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
463 match precision {
464 Precision::Single => {
465 let current_precision = self.precision();
469 if current_precision == Precision::Single {
470 return Ok(Box::new(self.clone()));
472 }
473
474 let array_single = self.array.clone();
477 let newarray = MixedPrecisionArray::with_computeprecision(array_single, precision);
478 Ok(Box::new(newarray))
479 }
480 Precision::Double => {
481 let current_precision = self.precision();
484 if current_precision == Precision::Double {
485 return Ok(Box::new(self.clone()));
487 }
488
489 let array_double = self.array.clone();
492 let newarray = MixedPrecisionArray::with_computeprecision(array_double, precision);
493 Ok(Box::new(newarray))
494 }
495 Precision::Mixed => {
496 let array_mixed = self.array.clone();
498 let newarray =
499 MixedPrecisionArray::with_computeprecision(array_mixed, Precision::Double);
500 Ok(Box::new(newarray))
501 }
502 _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
503 "Conversion to {precision} precision not implemented"
504 )))),
505 }
506 }
507
508 fn precision(&self) -> Precision {
509 if self.storage_precision != self.computeprecision {
511 Precision::Mixed
512 } else {
513 self.storage_precision
514 }
515 }
516
517 fn supports_precision(&self, precision: Precision) -> bool {
518 matches!(precision, Precision::Single | Precision::Double)
519 }
520
521 fn as_array_protocol(&self) -> &dyn ArrayProtocol {
522 self
523 }
524}
525
526impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
528where
529 T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
530 D: Dimension + Send + Sync + 'static + crate::ndarray::RemoveAxis,
531{
532 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
533 let mut config = self.config().clone();
535 config.mixed_precision = precision == Precision::Mixed;
536
537 if let Ok(cpu_array) = self.to_cpu() {
538 if let Some(ndarray) = cpu_array.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
540 let new_gpu_array = GPUNdarray::new(ndarray.as_array().clone(), config);
541 return Ok(Box::new(new_gpu_array));
542 }
543 }
544
545 Err(CoreError::NotImplementedError(ErrorContext::new(format!(
546 "Conversion to {precision} precision not implemented for GPU arrays"
547 ))))
548 }
549
550 fn precision(&self) -> Precision {
551 if self.config().mixed_precision {
552 Precision::Mixed
553 } else {
554 match std::mem::size_of::<T>() {
555 4 => Precision::Single,
556 8 => Precision::Double,
557 _ => Precision::Mixed,
558 }
559 }
560 }
561
562 fn supports_precision(&self, precision: Precision) -> bool {
563 true
565 }
566
567 fn as_array_protocol(&self) -> &dyn ArrayProtocol {
568 self
569 }
570}
571
572#[allow(dead_code)]
577pub fn execute_with_precision<F, R>(
578 arrays: &[&dyn MixedPrecisionSupport],
579 precision: Precision,
580 executor: F,
581) -> CoreResult<R>
582where
583 F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
584 R: 'static,
585{
586 for array in arrays {
588 if !array.supports_precision(precision) {
589 return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
590 "One or more arrays do not support {precision} precision"
591 ))));
592 }
593 }
594
595 let mut converted_arrays: Vec<Box<dyn MixedPrecisionSupport>> =
598 Vec::with_capacity(arrays.len());
599
600 for &array in arrays {
601 let converted = array.to_precision(precision)?;
602 converted_arrays.push(converted);
603 }
604
605 let protocol_refs: Vec<&dyn ArrayProtocol> = converted_arrays
612 .iter()
613 .map(|array| array.as_array_protocol())
614 .collect();
615
616 executor(&protocol_refs)
618}
619
620pub mod ops {
622 use super::*;
623 use crate::array_protocol::operations as array_ops;
624
625 pub fn matmul(
627 a: &dyn MixedPrecisionSupport,
628 b: &dyn MixedPrecisionSupport,
629 precision: Precision,
630 ) -> CoreResult<Box<dyn ArrayProtocol>> {
631 execute_with_precision(&[a, b], precision, |arrays| {
632 match array_ops::matmul(arrays[0], arrays[1]) {
634 Ok(result) => Ok(result),
635 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
636 e.to_string(),
637 ))),
638 }
639 })
640 }
641
642 pub fn add(
644 a: &dyn MixedPrecisionSupport,
645 b: &dyn MixedPrecisionSupport,
646 precision: Precision,
647 ) -> CoreResult<Box<dyn ArrayProtocol>> {
648 execute_with_precision(&[a, b], precision, |arrays| {
649 match array_ops::add(arrays[0], arrays[1]) {
651 Ok(result) => Ok(result),
652 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
653 e.to_string(),
654 ))),
655 }
656 })
657 }
658
659 pub fn multiply(
661 a: &dyn MixedPrecisionSupport,
662 b: &dyn MixedPrecisionSupport,
663 precision: Precision,
664 ) -> CoreResult<Box<dyn ArrayProtocol>> {
665 execute_with_precision(&[a, b], precision, |arrays| {
666 match array_ops::multiply(arrays[0], arrays[1]) {
668 Ok(result) => Ok(result),
669 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
670 e.to_string(),
671 ))),
672 }
673 })
674 }
675}
676
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, array};

    /// A freshly wrapped `f64` matrix reports double storage precision and is
    /// recoverable through `as_any`.
    #[test]
    fn test_mixed_precision_array() {
        let mixed_array = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));

        assert_eq!(mixed_array.storage_precision(), Precision::Double);

        let array_protocol: &dyn ArrayProtocol = &mixed_array;
        let is_expected_type = array_protocol
            .as_any()
            .is::<MixedPrecisionArray<f64, crate::ndarray::Ix2>>();
        assert!(is_expected_type);
    }

    /// The `MixedPrecisionSupport` view reports double precision and accepts
    /// single and double.
    #[test]
    fn test_mixed_precision_support() {
        crate::array_protocol::init();

        let mixed_array = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));
        let mixed_support: &dyn MixedPrecisionSupport = &mixed_array;

        assert_eq!(mixed_support.precision(), Precision::Double);
        for supported in [Precision::Single, Precision::Double].iter() {
            assert!(mixed_support.supports_precision(*supported));
        }
    }

    /// Narrowing conversion keeps exactly representable values.
    #[test]
    fn test_at_precision_f64_to_f32() {
        let mp = MixedPrecisionArray::new(array![1.0_f64, 2.5_f64, -1.75_f64].into_dyn());
        let as_f32: crate::ndarray::ArrayD<f32> = mp
            .at_precision()
            .expect("f64 → f32 precision conversion should succeed");

        for (idx, &expected) in [1.0_f32, 2.5_f32, -1.75_f32].iter().enumerate() {
            assert!((as_f32[idx] - expected).abs() < 1e-6);
        }
    }

    /// Widening conversion keeps the values.
    #[test]
    fn test_at_precision_f32_to_f64() {
        let mp = MixedPrecisionArray::new(array![0.5_f32, 1.25_f32, -2.0_f32].into_dyn());
        let as_f64: crate::ndarray::ArrayD<f64> = mp
            .at_precision()
            .expect("f32 → f64 precision conversion should succeed");

        for (idx, &expected) in [0.5_f64, 1.25_f64, -2.0_f64].iter().enumerate() {
            assert!((as_f64[idx] - expected).abs() < 1e-12);
        }
    }

    /// Converting to the same element type is a value-preserving copy.
    #[test]
    fn test_at_precision_same_type_is_identity() {
        let arr = array![42.0_f64, -7.5_f64].into_dyn();
        let mp = MixedPrecisionArray::new(arr.clone());
        let result: crate::ndarray::ArrayD<f64> = mp
            .at_precision()
            .expect("f64 → f64 precision conversion should succeed");

        for (a, b) in arr.iter().zip(result.iter()) {
            assert_eq!(*a, *b, "Identity conversion must not change values");
        }
    }

    /// Conversion keeps the 2-D shape and element positions.
    #[test]
    fn test_at_precision_preserves_shape() {
        let mp = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let as_f32: crate::ndarray::Array<f32, crate::ndarray::Ix2> = mp
            .at_precision()
            .expect("2D f64 → f32 conversion should succeed");

        assert_eq!(as_f32.shape(), &[2, 2]);
        assert!((as_f32[[0, 0]] - 1.0_f32).abs() < 1e-6);
        assert!((as_f32[[1, 1]] - 4.0_f32).abs() < 1e-6);
    }

    /// `ops::matmul` at single precision yields the expected 2x2 product.
    #[test]
    fn test_execute_with_precision_matmul_single() {
        crate::array_protocol::init();

        let a = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let b = MixedPrecisionArray::new(arr2(&[[5.0_f64, 6.0], [7.0, 8.0]]));

        let result = ops::matmul(&a, &b, Precision::Single)
            .expect("mixed-precision matmul should succeed on stable Rust");

        let wrapper = result
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>()
            .expect("matmul result should be an NdarrayWrapper<f64, Ix2>");
        let out = wrapper.as_array();

        assert_eq!(out.shape(), &[2, 2]);
        let expected = [[19.0_f64, 22.0], [43.0, 50.0]];
        for (row, expected_row) in expected.iter().enumerate() {
            for (col, &want) in expected_row.iter().enumerate() {
                assert!((out[[row, col]] - want).abs() < 1e-9);
            }
        }
    }

    /// `ops::add` at single precision yields the element-wise sum.
    #[test]
    fn test_execute_with_precision_add_single() {
        crate::array_protocol::init();

        let a = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let b = MixedPrecisionArray::new(arr2(&[[10.0_f64, 20.0], [30.0, 40.0]]));

        let result = ops::add(&a, &b, Precision::Single)
            .expect("mixed-precision add should succeed on stable Rust");

        let wrapper = result
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>()
            .expect("add result should be an NdarrayWrapper<f64, Ix2>");
        let out = wrapper.as_array();

        assert_eq!(out.shape(), &[2, 2]);
        let expected = [[11.0_f64, 22.0], [33.0, 44.0]];
        for (row, expected_row) in expected.iter().enumerate() {
            for (col, &want) in expected_row.iter().enumerate() {
                assert!((out[[row, col]] - want).abs() < 1e-9);
            }
        }
    }

    /// Requesting half precision is reported as an error.
    #[test]
    fn test_execute_with_precision_half_is_rejected() {
        crate::array_protocol::init();

        let a = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let b = MixedPrecisionArray::new(arr2(&[[5.0_f64, 6.0], [7.0, 8.0]]));

        let outcome = ops::matmul(&a, &b, Precision::Half);
        assert!(
            outcome.is_err(),
            "Half precision matmul must return an error"
        );
    }
}