1use std::any::{Any, TypeId};
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::{LazyLock, RwLock};
17
18use ::ndarray::{Array, Dimension};
19use num_traits::Float;
20
21use crate::array_protocol::gpu_impl::GPUNdarray;
22use crate::array_protocol::{
23 ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
24};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// Precision tags used throughout the mixed-precision support code.
///
/// The variants carry no data, so values are cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Half precision; `MixedPrecisionArray::new` infers this for 2-byte elements.
    Half,

    /// Single precision; `MixedPrecisionArray::new` infers this for 4-byte elements.
    Single,

    /// Double precision; `MixedPrecisionArray::new` infers this for 8-byte elements.
    Double,

    /// Mixed precision; reported by `MixedPrecisionArray` when its storage and
    /// compute precision differ, and inferred for any other element width.
    Mixed,
}
42
43impl fmt::Display for Precision {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 Precision::Half => write!(f, "half"),
47 Precision::Single => write!(f, "single"),
48 Precision::Double => write!(f, "double"),
49 Precision::Mixed => write!(f, "mixed"),
50 }
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct MixedPrecisionConfig {
57 pub storage_precision: Precision,
59
60 pub computeprecision: Precision,
62
63 pub auto_precision: bool,
65
66 pub downcast_threshold: usize,
68
69 pub double_precision_accumulation: bool,
71}
72
73impl Default for MixedPrecisionConfig {
74 fn default() -> Self {
75 Self {
76 storage_precision: Precision::Single,
77 computeprecision: Precision::Double,
78 auto_precision: true,
79 downcast_threshold: 10_000_000, double_precision_accumulation: true,
81 }
82 }
83}
84
85pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
87 RwLock::new(MixedPrecisionConfig {
88 storage_precision: Precision::Single,
89 computeprecision: Precision::Double,
90 auto_precision: true,
91 downcast_threshold: 10_000_000, double_precision_accumulation: true,
93 })
94});
95
96#[allow(dead_code)]
98pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
99 if let Ok(mut global_config) = MIXED_PRECISION_CONFIG.write() {
100 *global_config = config;
101 }
102}
103
104#[allow(dead_code)]
106pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
107 MIXED_PRECISION_CONFIG
108 .read()
109 .map(|c| c.clone())
110 .unwrap_or_default()
111}
112
113#[allow(dead_code)]
115pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
116where
117 T: Clone + 'static,
118 D: Dimension,
119{
120 let config = get_mixed_precision_config();
121 let size = array.len();
122
123 if config.auto_precision {
124 if size >= config.downcast_threshold {
125 Precision::Single
126 } else {
127 Precision::Double
128 }
129 } else {
130 config.storage_precision
131 }
132}
133
134#[derive(Debug, Clone)]
139pub struct MixedPrecisionArray<T, D>
140where
141 T: Clone + 'static,
142 D: Dimension,
143{
144 array: Array<T, D>,
146
147 storage_precision: Precision,
149
150 computeprecision: Precision,
152}
153
154impl<T, D> MixedPrecisionArray<T, D>
155where
156 T: Clone + Float + 'static,
157 D: Dimension,
158{
159 pub fn new(array: Array<T, D>) -> Self {
161 let precision = match std::mem::size_of::<T>() {
162 2 => Precision::Half,
163 4 => Precision::Single,
164 8 => Precision::Double,
165 _ => Precision::Mixed,
166 };
167
168 Self {
169 array,
170 storage_precision: precision,
171 computeprecision: precision,
172 }
173 }
174
175 pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
177 let storage_precision = match std::mem::size_of::<T>() {
178 2 => Precision::Half,
179 4 => Precision::Single,
180 8 => Precision::Double,
181 _ => Precision::Mixed,
182 };
183
184 Self {
185 array: data,
186 storage_precision,
187 computeprecision,
188 }
189 }
190
191 pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
196 where
197 U: Clone + Float + 'static,
198 {
199 Err(CoreError::NotImplementedError(ErrorContext::new(
202 "Precision conversion not fully implemented yet",
203 )))
204 }
205
206 pub fn storage_precision(&self) -> Precision {
208 self.storage_precision
209 }
210
211 pub const fn array(&self) -> &Array<T, D> {
213 &self.array
214 }
215}
216
217pub trait MixedPrecisionSupport: ArrayProtocol {
219 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
221
222 fn precision(&self) -> Precision;
224
225 fn supports_precision(&self, precision: Precision) -> bool;
227}
228
229impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
231where
232 T: Clone + Float + Send + Sync + 'static,
233 D: Dimension + Send + Sync + 'static,
234{
235 fn array_function(
236 &self,
237 func: &ArrayFunction,
238 types: &[TypeId],
239 args: &[Box<dyn Any>],
240 kwargs: &HashMap<String, Box<dyn Any>>,
241 ) -> Result<Box<dyn Any>, NotImplemented> {
242 let precision = kwargs
244 .get("precision")
245 .and_then(|p| p.downcast_ref::<Precision>())
246 .cloned()
247 .unwrap_or(self.computeprecision);
248
249 match func.name {
251 "scirs2::array_protocol::operations::matmul" => {
252 if args.len() >= 2 {
254 if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
256 let other_precision = other.computeprecision;
257 let _precision_to_use = match (precision, other_precision) {
258 (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
259 (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
260 (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
261 (Precision::Half, Precision::Half) => Precision::Half,
262 };
263
264 let wrapped_self = NdarrayWrapper::new(self.array.clone());
267
268 return wrapped_self.array_function(func, types, args, kwargs);
270 }
271 }
272
273 match precision {
275 Precision::Single | Precision::Double => {
276 let wrapped = NdarrayWrapper::new(self.array.clone());
278
279 let mut new_args = Vec::with_capacity(args.len());
281 new_args.push(Box::new(wrapped.clone()));
282
283 wrapped.array_function(func, types, args, kwargs)
287 }
288 Precision::Mixed => {
289 let wrapped = NdarrayWrapper::new(self.array.clone());
291
292 let mut new_args = Vec::with_capacity(args.len());
294 new_args.push(Box::new(wrapped.clone()));
295
296 wrapped.array_function(func, types, args, kwargs)
299 }
300 _ => Err(NotImplemented),
301 }
302 }
303 "scirs2::array_protocol::operations::add"
304 | "scirs2::array_protocol::operations::subtract"
305 | "scirs2::array_protocol::operations::multiply" => {
306 if args.len() >= 2 {
309 if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
310 let other_precision = other.computeprecision;
312 let _precision_to_use = match (precision, other_precision) {
313 (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
314 (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
315 (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
316 (Precision::Half, Precision::Half) => Precision::Half,
317 };
318
319 let wrapped_self = NdarrayWrapper::new(self.array.clone());
322
323 return wrapped_self.array_function(func, types, args, kwargs);
325 }
326 }
327
328 let wrapped = NdarrayWrapper::new(self.array.clone());
330
331 wrapped.array_function(func, types, args, kwargs)
333 }
334 "scirs2::array_protocol::operations::transpose"
335 | "scirs2::array_protocol::operations::reshape"
336 | "scirs2::array_protocol::operations::sum" => {
337 let wrapped = NdarrayWrapper::new(self.array.clone());
340
341 wrapped.array_function(func, types, args, kwargs)
343 }
344 _ => {
345 let wrapped = NdarrayWrapper::new(self.array.clone());
347 wrapped.array_function(func, types, args, kwargs)
348 }
349 }
350 }
351
352 fn as_any(&self) -> &dyn Any {
353 self
354 }
355
356 fn shape(&self) -> &[usize] {
357 self.array.shape()
358 }
359
360 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
361 Box::new(Self {
362 array: self.array.clone(),
363 storage_precision: self.storage_precision,
364 computeprecision: self.computeprecision,
365 })
366 }
367}
368
369impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
371where
372 T: Clone + Float + Send + Sync + 'static,
373 D: Dimension + Send + Sync + 'static,
374{
375 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
376 match precision {
377 Precision::Single => {
378 let current_precision = self.precision();
382 if current_precision == Precision::Single {
383 return Ok(Box::new(self.clone()));
385 }
386
387 let array_single = self.array.clone();
390 let newarray = MixedPrecisionArray::with_computeprecision(array_single, precision);
391 Ok(Box::new(newarray))
392 }
393 Precision::Double => {
394 let current_precision = self.precision();
397 if current_precision == Precision::Double {
398 return Ok(Box::new(self.clone()));
400 }
401
402 let array_double = self.array.clone();
405 let newarray = MixedPrecisionArray::with_computeprecision(array_double, precision);
406 Ok(Box::new(newarray))
407 }
408 Precision::Mixed => {
409 let array_mixed = self.array.clone();
411 let newarray =
412 MixedPrecisionArray::with_computeprecision(array_mixed, Precision::Double);
413 Ok(Box::new(newarray))
414 }
415 _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
416 "Conversion to {precision} precision not implemented"
417 )))),
418 }
419 }
420
421 fn precision(&self) -> Precision {
422 if self.storage_precision != self.computeprecision {
424 Precision::Mixed
425 } else {
426 self.storage_precision
427 }
428 }
429
430 fn supports_precision(&self, precision: Precision) -> bool {
431 matches!(precision, Precision::Single | Precision::Double)
432 }
433}
434
435impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
437where
438 T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
439 D: Dimension + Send + Sync + 'static + crate::ndarray::RemoveAxis,
440{
441 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
442 let mut config = self.config().clone();
444 config.mixed_precision = precision == Precision::Mixed;
445
446 if let Ok(cpu_array) = self.to_cpu() {
447 if let Some(ndarray) = cpu_array.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
449 let new_gpu_array = GPUNdarray::new(ndarray.as_array().clone(), config);
450 return Ok(Box::new(new_gpu_array));
451 }
452 }
453
454 Err(CoreError::NotImplementedError(ErrorContext::new(format!(
455 "Conversion to {precision} precision not implemented for GPU arrays"
456 ))))
457 }
458
459 fn precision(&self) -> Precision {
460 if self.config().mixed_precision {
461 Precision::Mixed
462 } else {
463 match std::mem::size_of::<T>() {
464 4 => Precision::Single,
465 8 => Precision::Double,
466 _ => Precision::Mixed,
467 }
468 }
469 }
470
471 fn supports_precision(&self, precision: Precision) -> bool {
472 true
474 }
475}
476
477#[allow(dead_code)]
482pub fn execute_with_precision<F, R>(
483 arrays: &[&dyn MixedPrecisionSupport],
484 precision: Precision,
485 executor: F,
486) -> CoreResult<R>
487where
488 F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
489 R: 'static,
490{
491 for array in arrays {
493 if !array.supports_precision(precision) {
494 return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
495 "One or more arrays do not support {precision} precision"
496 ))));
497 }
498 }
499
500 let mut converted_arrays = Vec::with_capacity(arrays.len());
502
503 for &array in arrays {
504 let converted = array.to_precision(precision)?;
505 converted_arrays.push(converted);
506 }
507
508 Err("Mixed precision batch execution not supported on stable Rust - requires trait_upcasting feature".to_string().into())
514}
515
516pub mod ops {
518 use super::*;
519 use crate::array_protocol::operations as array_ops;
520
521 pub fn matmul(
523 a: &dyn MixedPrecisionSupport,
524 b: &dyn MixedPrecisionSupport,
525 precision: Precision,
526 ) -> CoreResult<Box<dyn ArrayProtocol>> {
527 execute_with_precision(&[a, b], precision, |arrays| {
528 match array_ops::matmul(arrays[0], arrays[1]) {
530 Ok(result) => Ok(result),
531 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
532 e.to_string(),
533 ))),
534 }
535 })
536 }
537
538 pub fn add(
540 a: &dyn MixedPrecisionSupport,
541 b: &dyn MixedPrecisionSupport,
542 precision: Precision,
543 ) -> CoreResult<Box<dyn ArrayProtocol>> {
544 execute_with_precision(&[a, b], precision, |arrays| {
545 match array_ops::add(arrays[0], arrays[1]) {
547 Ok(result) => Ok(result),
548 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
549 e.to_string(),
550 ))),
551 }
552 })
553 }
554
555 pub fn multiply(
557 a: &dyn MixedPrecisionSupport,
558 b: &dyn MixedPrecisionSupport,
559 precision: Precision,
560 ) -> CoreResult<Box<dyn ArrayProtocol>> {
561 execute_with_precision(&[a, b], precision, |arrays| {
562 match array_ops::multiply(arrays[0], arrays[1]) {
564 Ok(result) => Ok(result),
565 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
566 e.to_string(),
567 ))),
568 }
569 })
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::arr2;

    /// A wrapper built from an `f64` array reports double storage precision
    /// and can be identified through `ArrayProtocol::as_any`.
    #[test]
    fn test_mixed_precision_array() {
        let wrapped = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));

        assert_eq!(wrapped.storage_precision(), Precision::Double);

        let as_protocol: &dyn ArrayProtocol = &wrapped;
        let is_expected_type = as_protocol
            .as_any()
            .is::<MixedPrecisionArray<f64, crate::ndarray::Ix2>>();
        assert!(is_expected_type);
    }

    /// Through the `MixedPrecisionSupport` trait object, an `f64` wrapper
    /// reports double precision and supports single and double precision.
    #[test]
    fn test_mixed_precision_support() {
        crate::array_protocol::init();

        let wrapped = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));
        let support: &dyn MixedPrecisionSupport = &wrapped;

        assert_eq!(support.precision(), Precision::Double);
        for supported in [Precision::Single, Precision::Double] {
            assert!(support.supports_precision(supported));
        }
    }
}