1use std::any::{Any, TypeId};
20use std::collections::HashMap;
21use std::fmt;
22use std::sync::{LazyLock, RwLock};
23
24use ndarray::{Array, Dimension};
25use num_traits::Float;
26
27use crate::array_protocol::gpu_impl::GPUNdarray;
28use crate::array_protocol::{
29 ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
30};
31use crate::error::{CoreError, CoreResult, ErrorContext};
32
/// Floating-point precision levels understood by the mixed-precision helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Precision {
    /// Half precision (2-byte elements).
    Half,
    /// Single precision (4-byte elements).
    Single,
    /// Double precision (8-byte elements).
    Double,
    /// Storage and compute precision differ, or the element size matches none
    /// of the fixed levels above.
    Mixed,
}

impl fmt::Display for Precision {
    /// Writes the lowercase name (`half`, `single`, `double`, `mixed`).
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill and alignment flags such as
    /// `{:>8}` are honoured like they are for `str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Precision::Half => "half",
            Precision::Single => "single",
            Precision::Double => "double",
            Precision::Mixed => "mixed",
        };
        f.pad(name)
    }
}
59
60#[derive(Debug, Clone)]
62pub struct MixedPrecisionConfig {
63 pub storage_precision: Precision,
65
66 pub computeprecision: Precision,
68
69 pub auto_precision: bool,
71
72 pub downcast_threshold: usize,
74
75 pub double_precision_accumulation: bool,
77}
78
79impl Default for MixedPrecisionConfig {
80 fn default() -> Self {
81 Self {
82 storage_precision: Precision::Single,
83 computeprecision: Precision::Double,
84 auto_precision: true,
85 downcast_threshold: 10_000_000, double_precision_accumulation: true,
87 }
88 }
89}
90
91pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
93 RwLock::new(MixedPrecisionConfig {
94 storage_precision: Precision::Single,
95 computeprecision: Precision::Double,
96 auto_precision: true,
97 downcast_threshold: 10_000_000, double_precision_accumulation: true,
99 })
100});
101
102#[allow(dead_code)]
104pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
105 if let Ok(mut global_config) = MIXED_PRECISION_CONFIG.write() {
106 *global_config = config;
107 }
108}
109
110#[allow(dead_code)]
112pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
113 MIXED_PRECISION_CONFIG
114 .read()
115 .map(|c| c.clone())
116 .unwrap_or_default()
117}
118
119#[allow(dead_code)]
121pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
122where
123 T: Clone + 'static,
124 D: Dimension,
125{
126 let config = get_mixed_precision_config();
127 let size = array.len();
128
129 if config.auto_precision {
130 if size >= config.downcast_threshold {
131 Precision::Single
132 } else {
133 Precision::Double
134 }
135 } else {
136 config.storage_precision
137 }
138}
139
140#[derive(Debug, Clone)]
145pub struct MixedPrecisionArray<T, D>
146where
147 T: Clone + 'static,
148 D: Dimension,
149{
150 array: Array<T, D>,
152
153 storage_precision: Precision,
155
156 computeprecision: Precision,
158}
159
160impl<T, D> MixedPrecisionArray<T, D>
161where
162 T: Clone + Float + 'static,
163 D: Dimension,
164{
165 pub fn new(array: Array<T, D>) -> Self {
167 let precision = match std::mem::size_of::<T>() {
168 2 => Precision::Half,
169 4 => Precision::Single,
170 8 => Precision::Double,
171 _ => Precision::Mixed,
172 };
173
174 Self {
175 array,
176 storage_precision: precision,
177 computeprecision: precision,
178 }
179 }
180
181 pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
183 let storage_precision = match std::mem::size_of::<T>() {
184 2 => Precision::Half,
185 4 => Precision::Single,
186 8 => Precision::Double,
187 _ => Precision::Mixed,
188 };
189
190 Self {
191 array: data,
192 storage_precision,
193 computeprecision,
194 }
195 }
196
197 pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
202 where
203 U: Clone + Float + 'static,
204 {
205 Err(CoreError::NotImplementedError(ErrorContext::new(
208 "Precision conversion not fully implemented yet",
209 )))
210 }
211
212 pub fn storage_precision(&self) -> Precision {
214 self.storage_precision
215 }
216
217 pub const fn array(&self) -> &Array<T, D> {
219 &self.array
220 }
221}
222
223pub trait MixedPrecisionSupport: ArrayProtocol {
225 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
227
228 fn precision(&self) -> Precision;
230
231 fn supports_precision(&self, precision: Precision) -> bool;
233}
234
235impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
237where
238 T: Clone + Float + Send + Sync + 'static,
239 D: Dimension + Send + Sync + 'static,
240{
241 fn array_function(
242 &self,
243 func: &ArrayFunction,
244 types: &[TypeId],
245 args: &[Box<dyn Any>],
246 kwargs: &HashMap<String, Box<dyn Any>>,
247 ) -> Result<Box<dyn Any>, NotImplemented> {
248 let precision = kwargs
250 .get("precision")
251 .and_then(|p| p.downcast_ref::<Precision>())
252 .cloned()
253 .unwrap_or(self.computeprecision);
254
255 match func.name {
257 "scirs2::array_protocol::operations::matmul" => {
258 if args.len() >= 2 {
260 if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
262 let other_precision = other.computeprecision;
263 let _precision_to_use = match (precision, other_precision) {
264 (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
265 (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
266 (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
267 (Precision::Half, Precision::Half) => Precision::Half,
268 };
269
270 let wrapped_self = NdarrayWrapper::new(self.array.clone());
273
274 return wrapped_self.array_function(func, types, args, kwargs);
276 }
277 }
278
279 match precision {
281 Precision::Single | Precision::Double => {
282 let wrapped = NdarrayWrapper::new(self.array.clone());
284
285 let mut new_args = Vec::with_capacity(args.len());
287 new_args.push(Box::new(wrapped.clone()));
288
289 wrapped.array_function(func, types, args, kwargs)
293 }
294 Precision::Mixed => {
295 let wrapped = NdarrayWrapper::new(self.array.clone());
297
298 let mut new_args = Vec::with_capacity(args.len());
300 new_args.push(Box::new(wrapped.clone()));
301
302 wrapped.array_function(func, types, args, kwargs)
305 }
306 _ => Err(NotImplemented),
307 }
308 }
309 "scirs2::array_protocol::operations::add"
310 | "scirs2::array_protocol::operations::subtract"
311 | "scirs2::array_protocol::operations::multiply" => {
312 if args.len() >= 2 {
315 if let Some(other) = args[1].downcast_ref::<MixedPrecisionArray<T, D>>() {
316 let other_precision = other.computeprecision;
318 let _precision_to_use = match (precision, other_precision) {
319 (Precision::Double, _) | (_, Precision::Double) => Precision::Double,
320 (Precision::Mixed, _) | (_, Precision::Mixed) => Precision::Mixed,
321 (Precision::Single, _) | (_, Precision::Single) => Precision::Single,
322 (Precision::Half, Precision::Half) => Precision::Half,
323 };
324
325 let wrapped_self = NdarrayWrapper::new(self.array.clone());
328
329 return wrapped_self.array_function(func, types, args, kwargs);
331 }
332 }
333
334 let wrapped = NdarrayWrapper::new(self.array.clone());
336
337 wrapped.array_function(func, types, args, kwargs)
339 }
340 "scirs2::array_protocol::operations::transpose"
341 | "scirs2::array_protocol::operations::reshape"
342 | "scirs2::array_protocol::operations::sum" => {
343 let wrapped = NdarrayWrapper::new(self.array.clone());
346
347 wrapped.array_function(func, types, args, kwargs)
349 }
350 _ => {
351 let wrapped = NdarrayWrapper::new(self.array.clone());
353 wrapped.array_function(func, types, args, kwargs)
354 }
355 }
356 }
357
358 fn as_any(&self) -> &dyn Any {
359 self
360 }
361
362 fn shape(&self) -> &[usize] {
363 self.array.shape()
364 }
365
366 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
367 Box::new(Self {
368 array: self.array.clone(),
369 storage_precision: self.storage_precision,
370 computeprecision: self.computeprecision,
371 })
372 }
373}
374
375impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
377where
378 T: Clone + Float + Send + Sync + 'static,
379 D: Dimension + Send + Sync + 'static,
380{
381 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
382 match precision {
383 Precision::Single => {
384 let current_precision = self.precision();
388 if current_precision == Precision::Single {
389 return Ok(Box::new(self.clone()));
391 }
392
393 let array_single = self.array.clone();
396 let newarray = MixedPrecisionArray::with_computeprecision(array_single, precision);
397 Ok(Box::new(newarray))
398 }
399 Precision::Double => {
400 let current_precision = self.precision();
403 if current_precision == Precision::Double {
404 return Ok(Box::new(self.clone()));
406 }
407
408 let array_double = self.array.clone();
411 let newarray = MixedPrecisionArray::with_computeprecision(array_double, precision);
412 Ok(Box::new(newarray))
413 }
414 Precision::Mixed => {
415 let array_mixed = self.array.clone();
417 let newarray =
418 MixedPrecisionArray::with_computeprecision(array_mixed, Precision::Double);
419 Ok(Box::new(newarray))
420 }
421 _ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
422 "Conversion to {precision} precision not implemented"
423 )))),
424 }
425 }
426
427 fn precision(&self) -> Precision {
428 if self.storage_precision != self.computeprecision {
430 Precision::Mixed
431 } else {
432 self.storage_precision
433 }
434 }
435
436 fn supports_precision(&self, precision: Precision) -> bool {
437 matches!(precision, Precision::Single | Precision::Double)
438 }
439}
440
441impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
443where
444 T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
445 D: Dimension + Send + Sync + 'static + ndarray::RemoveAxis,
446{
447 fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
448 let mut config = self.config().clone();
450 config.mixed_precision = precision == Precision::Mixed;
451
452 if let Ok(cpu_array) = self.to_cpu() {
453 if let Some(ndarray) = cpu_array.as_any().downcast_ref::<NdarrayWrapper<T, D>>() {
455 let new_gpu_array = GPUNdarray::new(ndarray.as_array().clone(), config);
456 return Ok(Box::new(new_gpu_array));
457 }
458 }
459
460 Err(CoreError::NotImplementedError(ErrorContext::new(format!(
461 "Conversion to {precision} precision not implemented for GPU arrays"
462 ))))
463 }
464
465 fn precision(&self) -> Precision {
466 if self.config().mixed_precision {
467 Precision::Mixed
468 } else {
469 match std::mem::size_of::<T>() {
470 4 => Precision::Single,
471 8 => Precision::Double,
472 _ => Precision::Mixed,
473 }
474 }
475 }
476
477 fn supports_precision(&self, precision: Precision) -> bool {
478 true
480 }
481}
482
483#[allow(dead_code)]
488pub fn execute_with_precision<F, R>(
489 arrays: &[&dyn MixedPrecisionSupport],
490 precision: Precision,
491 executor: F,
492) -> CoreResult<R>
493where
494 F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
495 R: 'static,
496{
497 for array in arrays {
499 if !array.supports_precision(precision) {
500 return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
501 "One or more arrays do not support {precision} precision"
502 ))));
503 }
504 }
505
506 let mut converted_arrays = Vec::with_capacity(arrays.len());
508
509 for &array in arrays {
510 let converted = array.to_precision(precision)?;
511 converted_arrays.push(converted);
512 }
513
514 let array_refs: Vec<&dyn ArrayProtocol> = converted_arrays
516 .iter()
517 .map(|arr| arr.as_ref() as &dyn ArrayProtocol)
518 .collect();
519
520 executor(&array_refs)
522}
523
524pub mod ops {
526 use super::*;
527 use crate::array_protocol::operations as array_ops;
528
529 pub fn matmul(
531 a: &dyn MixedPrecisionSupport,
532 b: &dyn MixedPrecisionSupport,
533 precision: Precision,
534 ) -> CoreResult<Box<dyn ArrayProtocol>> {
535 execute_with_precision(&[a, b], precision, |arrays| {
536 match array_ops::matmul(arrays[0], arrays[1]) {
538 Ok(result) => Ok(result),
539 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
540 e.to_string(),
541 ))),
542 }
543 })
544 }
545
546 pub fn add(
548 a: &dyn MixedPrecisionSupport,
549 b: &dyn MixedPrecisionSupport,
550 precision: Precision,
551 ) -> CoreResult<Box<dyn ArrayProtocol>> {
552 execute_with_precision(&[a, b], precision, |arrays| {
553 match array_ops::add(arrays[0], arrays[1]) {
555 Ok(result) => Ok(result),
556 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
557 e.to_string(),
558 ))),
559 }
560 })
561 }
562
563 pub fn multiply(
565 a: &dyn MixedPrecisionSupport,
566 b: &dyn MixedPrecisionSupport,
567 precision: Precision,
568 ) -> CoreResult<Box<dyn ArrayProtocol>> {
569 execute_with_precision(&[a, b], precision, |arrays| {
570 match array_ops::multiply(arrays[0], arrays[1]) {
572 Ok(result) => Ok(result),
573 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
574 e.to_string(),
575 ))),
576 }
577 })
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// An `f64` array reports double storage precision and can be recovered
    /// through `ArrayProtocol::as_any`.
    #[test]
    fn test_mixed_precision_array() {
        let wrapped = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));
        assert_eq!(wrapped.storage_precision(), Precision::Double);

        let as_protocol: &dyn ArrayProtocol = &wrapped;
        let is_mixed = as_protocol
            .as_any()
            .is::<MixedPrecisionArray<f64, ndarray::Ix2>>();
        assert!(is_mixed);
    }

    /// The trait-object view reports double precision and accepts both single
    /// and double precision requests.
    #[test]
    fn test_mixed_precision_support() {
        crate::array_protocol::init();

        let wrapped = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));
        let support: &dyn MixedPrecisionSupport = &wrapped;

        assert_eq!(support.precision(), Precision::Double);
        for requested in [Precision::Single, Precision::Double] {
            assert!(support.supports_precision(requested));
        }
    }
}