1use std::any::{Any, TypeId};
20use std::collections::HashMap;
21use std::sync::RwLock;
22
23use ::ndarray::{Array, Dim, Dimension, SliceArg, SliceInfo, SliceInfoElem};
24use num_traits;
25
26use crate::array_protocol::{
27 ArrayFunction, ArrayProtocol, DistributedBackend, DistributedConfig, DistributedNdarray,
28 DistributionStrategy, GPUBackend, GPUConfig, GPUNdarray, NdarrayWrapper, NotImplemented,
29};
30use crate::error::CoreResult;
31
32#[derive(Debug, Clone)]
34pub struct AutoDeviceConfig {
35 pub gpu_threshold: usize,
37
38 pub distributed_threshold: usize,
40
41 pub enable_mixed_precision: bool,
43
44 pub prefer_memory_efficiency: bool,
46
47 pub auto_transfer: bool,
49
50 pub prefer_data_locality: bool,
52
53 pub preferred_gpu_backend: GPUBackend,
55
56 pub fallback_to_cpu: bool,
58}
59
60impl Default for AutoDeviceConfig {
61 fn default() -> Self {
62 Self {
63 gpu_threshold: 1_000_000, distributed_threshold: 100_000_000, enable_mixed_precision: false,
66 prefer_memory_efficiency: false,
67 auto_transfer: true,
68 prefer_data_locality: true,
69 preferred_gpu_backend: GPUBackend::CUDA,
70 fallback_to_cpu: true,
71 }
72 }
73}
74
75pub static AUTO_DEVICE_CONFIG: RwLock<AutoDeviceConfig> = RwLock::new(AutoDeviceConfig {
77 gpu_threshold: 1_000_000,
78 distributed_threshold: 100_000_000,
79 enable_mixed_precision: false,
80 prefer_memory_efficiency: false,
81 auto_transfer: true,
82 prefer_data_locality: true,
83 preferred_gpu_backend: GPUBackend::CUDA,
84 fallback_to_cpu: true,
85});
86
87#[allow(dead_code)]
89pub fn set_auto_device_config(config: AutoDeviceConfig) {
90 if let Ok(mut global_config) = AUTO_DEVICE_CONFIG.write() {
91 *global_config = config;
92 }
93}
94
95#[allow(dead_code)]
97pub fn get_auto_device_config() -> AutoDeviceConfig {
98 AUTO_DEVICE_CONFIG
99 .read()
100 .map(|c| c.clone())
101 .unwrap_or_default()
102}
103
104#[allow(dead_code)]
109pub fn determine_best_device<T, D>(array: &Array<T, D>) -> DeviceType
110where
111 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
112 D: Dimension + crate::ndarray::RemoveAxis,
113{
114 let config = get_auto_device_config();
115 let size = array.len();
116
117 if size >= config.distributed_threshold {
118 DeviceType::Distributed
119 } else if size >= config.gpu_threshold {
120 DeviceType::GPU
121 } else {
122 DeviceType::CPU
123 }
124}
125
126#[allow(dead_code)]
131pub fn determine_best_device_for_operation<T, D>(
132 arrays: &[&Array<T, D>],
133 operation: &str,
134) -> DeviceType
135where
136 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
137 D: Dimension + crate::ndarray::RemoveAxis,
138{
139 let config = get_auto_device_config();
140
141 let is_complex_operation = matches!(operation, "matmul" | "svd" | "inverse" | "conv2d");
143
144 let total_size: usize = arrays.iter().map(|arr| arr.len()).sum();
146
147 let gpu_threshold = if is_complex_operation {
149 config.gpu_threshold / 10 } else {
151 config.gpu_threshold
152 };
153
154 let distributed_threshold = if is_complex_operation {
155 config.distributed_threshold / 2 } else {
157 config.distributed_threshold
158 };
159
160 if total_size >= distributed_threshold {
161 DeviceType::Distributed
162 } else if total_size >= gpu_threshold {
163 DeviceType::GPU
164 } else {
165 DeviceType::CPU
166 }
167}
168
/// Target device an array can be placed on.
///
/// [`convert_to_device`] maps each variant to a concrete array wrapper.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    /// Host memory; converted to an `NdarrayWrapper`.
    CPU,

    /// GPU; converted to a `GPUNdarray`.
    GPU,

    /// Distributed backend; converted to a `DistributedNdarray`.
    Distributed,
}
181
182#[allow(dead_code)]
187pub fn convert_to_device<T, D>(array: Array<T, D>, device: DeviceType) -> Box<dyn ArrayProtocol>
188where
189 T: Clone
190 + Send
191 + Sync
192 + 'static
193 + num_traits::Zero
194 + std::ops::Div<f64, Output = T>
195 + Default
196 + std::ops::Mul<Output = T>
197 + std::ops::Add<Output = T>,
198 D: Dimension + crate::ndarray::RemoveAxis + 'static,
199 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
200 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
201{
202 match device {
203 DeviceType::CPU => Box::new(NdarrayWrapper::new(array.clone())),
204 DeviceType::GPU => {
205 let config = get_auto_device_config();
206 let gpu_config = GPUConfig {
207 backend: config.preferred_gpu_backend,
208 device_id: 0,
209 async_ops: true,
210 mixed_precision: config.enable_mixed_precision,
211 memory_fraction: 0.9,
212 };
213
214 Box::new(GPUNdarray::new(array.clone(), gpu_config))
215 }
216 DeviceType::Distributed => {
217 let dist_config = DistributedConfig {
218 chunks: 2, balance: true,
220 strategy: DistributionStrategy::RowWise,
221 backend: DistributedBackend::Threaded,
222 };
223
224 Box::new(DistributedNdarray::from_array(&array, dist_config))
225 }
226 }
227}
228
229pub struct AutoDevice<T, D>
234where
235 T: Clone
236 + Send
237 + Sync
238 + 'static
239 + num_traits::Zero
240 + std::ops::Div<f64, Output = T>
241 + Default
242 + std::ops::Mul<Output = T>
243 + std::ops::Add<Output = T>,
244 D: Dimension + crate::ndarray::RemoveAxis,
245 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
246 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
247{
248 array: Array<T, D>,
250
251 device: DeviceType,
253
254 device_array: Option<Box<dyn ArrayProtocol>>,
256}
257
258impl<T, D> std::fmt::Debug for AutoDevice<T, D>
260where
261 T: Clone
262 + Send
263 + Sync
264 + std::fmt::Debug
265 + 'static
266 + num_traits::Zero
267 + std::ops::Div<f64, Output = T>
268 + Default
269 + std::ops::Mul<Output = T>
270 + std::ops::Add<Output = T>,
271 D: Dimension + crate::ndarray::RemoveAxis + std::fmt::Debug + 'static,
272 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
273 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
274{
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 f.debug_struct("AutoDevice")
277 .field("array", &self.array)
278 .field("device", &self.device)
279 .field("device_array", &self.device_array.is_some())
280 .finish()
281 }
282}
283
284impl<T, D> AutoDevice<T, D>
285where
286 T: Clone
287 + Send
288 + Sync
289 + 'static
290 + num_traits::Zero
291 + std::ops::Div<f64, Output = T>
292 + Default
293 + std::ops::Mul<Output = T>
294 + std::ops::Add<Output = T>,
295 D: Dimension + crate::ndarray::RemoveAxis + 'static,
296 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
297 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
298{
299 pub fn new(array: Array<T, D>) -> Self {
301 let device = determine_best_device(&array);
302 let device_array = None; Self {
305 array,
306 device,
307 device_array,
308 }
309 }
310
311 pub fn on_device(&mut self, device: DeviceType) -> &dyn ArrayProtocol {
313 if self.device != device || self.device_array.is_none() {
314 self.device = device;
316 self.device_array = Some(convert_to_device(self.array.clone(), device));
317 }
318
319 self.device_array
320 .as_ref()
321 .expect("Operation failed")
322 .as_ref()
323 }
324
325 pub fn device(&self) -> DeviceType {
327 self.device
328 }
329
330 pub const fn array(&self) -> &Array<T, D> {
332 &self.array
333 }
334}
335
336impl<T, D> Clone for AutoDevice<T, D>
337where
338 T: Clone
339 + Send
340 + Sync
341 + 'static
342 + num_traits::Zero
343 + std::ops::Div<f64, Output = T>
344 + Default
345 + std::ops::Mul<Output = T>
346 + std::ops::Add<Output = T>,
347 D: Dimension + crate::ndarray::RemoveAxis + 'static,
348 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
349 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
350{
351 fn clone(&self) -> Self {
352 Self {
353 array: self.array.clone(),
354 device: self.device,
355 device_array: self.device_array.clone(),
356 }
357 }
358}
359
360impl<T, D> ArrayProtocol for AutoDevice<T, D>
361where
362 T: Clone
363 + Send
364 + Sync
365 + 'static
366 + num_traits::Zero
367 + std::ops::Div<f64, Output = T>
368 + Default
369 + std::ops::Mul<Output = T>
370 + std::ops::Add<Output = T>,
371 D: Dimension + crate::ndarray::RemoveAxis + 'static,
372 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
373 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
374{
375 fn array_function(
376 &self,
377 func: &ArrayFunction,
378 types: &[TypeId],
379 args: &[Box<dyn Any>],
380 kwargs: &HashMap<String, Box<dyn Any>>,
381 ) -> Result<Box<dyn Any>, NotImplemented> {
382 if let Some(device_array) = &self.device_array {
384 device_array.array_function(func, types, args, kwargs)
385 } else {
386 let device = determine_best_device(&self.array);
388 let temp_array = convert_to_device(self.array.clone(), device);
389 temp_array.array_function(func, types, args, kwargs)
390 }
391 }
392
393 fn as_any(&self) -> &dyn Any {
394 self
395 }
396
397 fn shape(&self) -> &[usize] {
398 self.array.shape()
399 }
400
401 fn dtype(&self) -> TypeId {
402 TypeId::of::<T>()
403 }
404
405 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
406 Box::new(self.clone())
407 }
408}
409
410#[allow(dead_code)]
415pub fn auto_execute<T, D, F, R>(
416 arrays: &mut [&mut AutoDevice<T, D>],
417 operation: &str,
418 executor: F,
419) -> CoreResult<R>
420where
421 T: Clone
422 + Send
423 + Sync
424 + 'static
425 + num_traits::Zero
426 + std::ops::Div<f64, Output = T>
427 + Default
428 + std::ops::Mul<Output = T>
429 + std::ops::Add<Output = T>,
430 D: Dimension + crate::ndarray::RemoveAxis + 'static,
431 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
432 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
433 F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
434 R: 'static,
435{
436 let best_device = determine_best_device_for_operation(
438 &arrays.iter().map(|a| &a.array).collect::<Vec<_>>(),
439 operation,
440 );
441
442 let device_arrays: Vec<&dyn ArrayProtocol> = arrays
444 .iter_mut()
445 .map(|a| a.on_device(best_device))
446 .collect();
447
448 executor(&device_arrays)
450}
451
452pub mod ops {
454 use super::*;
455 use crate::array_protocol::operations as ap_ops;
456 use crate::error::{CoreError, ErrorContext};
457
458 pub fn matmul<T, D>(
460 a: &mut AutoDevice<T, D>,
461 b: &mut AutoDevice<T, D>,
462 ) -> CoreResult<Box<dyn ArrayProtocol>>
463 where
464 T: Clone
465 + Send
466 + Sync
467 + 'static
468 + num_traits::Zero
469 + std::ops::Div<f64, Output = T>
470 + Default
471 + std::ops::Mul<Output = T>
472 + std::ops::Add<Output = T>,
473 D: Dimension + crate::ndarray::RemoveAxis + 'static,
474 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
475 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
476 {
477 auto_execute(&mut [a, b], "matmul", |arrays| {
478 match ap_ops::matmul(arrays[0], arrays[1]) {
480 Ok(result) => Ok(result),
481 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
482 e.to_string(),
483 ))),
484 }
485 })
486 }
487
488 pub fn add<T, D>(
490 a: &mut AutoDevice<T, D>,
491 b: &mut AutoDevice<T, D>,
492 ) -> CoreResult<Box<dyn ArrayProtocol>>
493 where
494 T: Clone
495 + Send
496 + Sync
497 + 'static
498 + num_traits::Zero
499 + std::ops::Div<f64, Output = T>
500 + Default
501 + std::ops::Mul<Output = T>
502 + std::ops::Add<Output = T>,
503 D: Dimension + crate::ndarray::RemoveAxis + 'static,
504 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
505 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
506 {
507 auto_execute(&mut [a, b], "add", |arrays| {
508 match ap_ops::add(arrays[0], arrays[1]) {
510 Ok(result) => Ok(result),
511 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
512 e.to_string(),
513 ))),
514 }
515 })
516 }
517
518 pub fn multiply<T, D>(
520 a: &mut AutoDevice<T, D>,
521 b: &mut AutoDevice<T, D>,
522 ) -> CoreResult<Box<dyn ArrayProtocol>>
523 where
524 T: Clone
525 + Send
526 + Sync
527 + 'static
528 + num_traits::Zero
529 + std::ops::Div<f64, Output = T>
530 + Default
531 + std::ops::Mul<Output = T>
532 + std::ops::Add<Output = T>,
533 D: Dimension + crate::ndarray::RemoveAxis + 'static,
534 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
535 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
536 {
537 auto_execute(&mut [a, b], "multiply", |arrays| {
538 match ap_ops::multiply(arrays[0], arrays[1]) {
540 Ok(result) => Ok(result),
541 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
542 e.to_string(),
543 ))),
544 }
545 })
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, Array2};

    /// A 100-element array stays on the CPU with the default thresholds and
    /// moves to the GPU once the GPU threshold is lowered to 50.
    #[test]
    fn test_auto_device_selection() {
        crate::array_protocol::init();

        let hundred_elements = Array2::<f64>::ones((10, 10));
        assert_eq!(determine_best_device(&hundred_elements), DeviceType::CPU);

        // Lower only the GPU threshold; keep every other setting as it is.
        let lowered = AutoDeviceConfig {
            gpu_threshold: 50,
            ..get_auto_device_config()
        };
        set_auto_device_config(lowered);

        assert_eq!(determine_best_device(&hundred_elements), DeviceType::GPU);

        // Restore the defaults in the global configuration.
        set_auto_device_config(AutoDeviceConfig::default());
    }

    /// A small array starts on the CPU; requesting the GPU yields a
    /// `GPUNdarray` and updates the wrapper's selected device.
    #[test]
    fn test_auto_device_wrapper() {
        crate::array_protocol::init();

        let dyn_array = arr2(&[[1.0, 2.0], [3.0, 4.0]]).into_dyn();
        let mut auto_array = AutoDevice::new(dyn_array.clone());

        assert_eq!(auto_array.device(), DeviceType::CPU);

        let on_gpu = auto_array.on_device(DeviceType::GPU);
        let as_gpu_ndarray = on_gpu
            .as_any()
            .downcast_ref::<GPUNdarray<f64, crate::ndarray::IxDyn>>();
        assert!(as_gpu_ndarray.is_some());

        assert_eq!(auto_array.device(), DeviceType::GPU);
    }
}