1use std::any::{Any, TypeId};
14use std::collections::HashMap;
15use std::sync::RwLock;
16
17use ::ndarray::{Array, Dim, Dimension, SliceArg, SliceInfo, SliceInfoElem};
18use num_traits;
19
20use crate::array_protocol::{
21 ArrayFunction, ArrayProtocol, DistributedBackend, DistributedConfig, DistributedNdarray,
22 DistributionStrategy, GPUBackend, GPUConfig, GPUNdarray, NdarrayWrapper, NotImplemented,
23};
24use crate::error::CoreResult;
25
26#[derive(Debug, Clone)]
28pub struct AutoDeviceConfig {
29 pub gpu_threshold: usize,
31
32 pub distributed_threshold: usize,
34
35 pub enable_mixed_precision: bool,
37
38 pub prefer_memory_efficiency: bool,
40
41 pub auto_transfer: bool,
43
44 pub prefer_data_locality: bool,
46
47 pub preferred_gpu_backend: GPUBackend,
49
50 pub fallback_to_cpu: bool,
52}
53
54impl Default for AutoDeviceConfig {
55 fn default() -> Self {
56 Self {
57 gpu_threshold: 1_000_000, distributed_threshold: 100_000_000, enable_mixed_precision: false,
60 prefer_memory_efficiency: false,
61 auto_transfer: true,
62 prefer_data_locality: true,
63 preferred_gpu_backend: GPUBackend::CUDA,
64 fallback_to_cpu: true,
65 }
66 }
67}
68
69pub static AUTO_DEVICE_CONFIG: RwLock<AutoDeviceConfig> = RwLock::new(AutoDeviceConfig {
71 gpu_threshold: 1_000_000,
72 distributed_threshold: 100_000_000,
73 enable_mixed_precision: false,
74 prefer_memory_efficiency: false,
75 auto_transfer: true,
76 prefer_data_locality: true,
77 preferred_gpu_backend: GPUBackend::CUDA,
78 fallback_to_cpu: true,
79});
80
81#[allow(dead_code)]
83pub fn set_auto_device_config(config: AutoDeviceConfig) {
84 if let Ok(mut global_config) = AUTO_DEVICE_CONFIG.write() {
85 *global_config = config;
86 }
87}
88
89#[allow(dead_code)]
91pub fn get_auto_device_config() -> AutoDeviceConfig {
92 AUTO_DEVICE_CONFIG
93 .read()
94 .map(|c| c.clone())
95 .unwrap_or_default()
96}
97
98#[allow(dead_code)]
103pub fn determine_best_device<T, D>(array: &Array<T, D>) -> DeviceType
104where
105 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
106 D: Dimension + crate::ndarray::RemoveAxis,
107{
108 let config = get_auto_device_config();
109 let size = array.len();
110
111 if size >= config.distributed_threshold {
112 DeviceType::Distributed
113 } else if size >= config.gpu_threshold {
114 DeviceType::GPU
115 } else {
116 DeviceType::CPU
117 }
118}
119
120#[allow(dead_code)]
125pub fn determine_best_device_for_operation<T, D>(
126 arrays: &[&Array<T, D>],
127 operation: &str,
128) -> DeviceType
129where
130 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
131 D: Dimension + crate::ndarray::RemoveAxis,
132{
133 let config = get_auto_device_config();
134
135 let is_complex_operation = matches!(operation, "matmul" | "svd" | "inverse" | "conv2d");
137
138 let total_size: usize = arrays.iter().map(|arr| arr.len()).sum();
140
141 let gpu_threshold = if is_complex_operation {
143 config.gpu_threshold / 10 } else {
145 config.gpu_threshold
146 };
147
148 let distributed_threshold = if is_complex_operation {
149 config.distributed_threshold / 2 } else {
151 config.distributed_threshold
152 };
153
154 if total_size >= distributed_threshold {
155 DeviceType::Distributed
156 } else if total_size >= gpu_threshold {
157 DeviceType::GPU
158 } else {
159 DeviceType::CPU
160 }
161}
162
/// Target on which an array is placed or an operation is run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeviceType {
    /// Host memory; `convert_to_device` wraps the array in an `NdarrayWrapper`.
    CPU,

    /// GPU; `convert_to_device` wraps the array in a `GPUNdarray`.
    GPU,

    /// Distributed; `convert_to_device` builds a `DistributedNdarray`.
    Distributed,
}
175
176#[allow(dead_code)]
181pub fn convert_to_device<T, D>(array: Array<T, D>, device: DeviceType) -> Box<dyn ArrayProtocol>
182where
183 T: Clone
184 + Send
185 + Sync
186 + 'static
187 + num_traits::Zero
188 + std::ops::Div<f64, Output = T>
189 + Default
190 + std::ops::Mul<Output = T>
191 + std::ops::Add<Output = T>,
192 D: Dimension + crate::ndarray::RemoveAxis + 'static,
193 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
194 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
195{
196 match device {
197 DeviceType::CPU => Box::new(NdarrayWrapper::new(array.clone())),
198 DeviceType::GPU => {
199 let config = get_auto_device_config();
200 let gpu_config = GPUConfig {
201 backend: config.preferred_gpu_backend,
202 device_id: 0,
203 async_ops: true,
204 mixed_precision: config.enable_mixed_precision,
205 memory_fraction: 0.9,
206 };
207
208 Box::new(GPUNdarray::new(array.clone(), gpu_config))
209 }
210 DeviceType::Distributed => {
211 let dist_config = DistributedConfig {
212 chunks: 2, balance: true,
214 strategy: DistributionStrategy::RowWise,
215 backend: DistributedBackend::Threaded,
216 };
217
218 Box::new(DistributedNdarray::from_array(&array, dist_config))
219 }
220 }
221}
222
223pub struct AutoDevice<T, D>
228where
229 T: Clone
230 + Send
231 + Sync
232 + 'static
233 + num_traits::Zero
234 + std::ops::Div<f64, Output = T>
235 + Default
236 + std::ops::Mul<Output = T>
237 + std::ops::Add<Output = T>,
238 D: Dimension + crate::ndarray::RemoveAxis,
239 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
240 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
241{
242 array: Array<T, D>,
244
245 device: DeviceType,
247
248 device_array: Option<Box<dyn ArrayProtocol>>,
250}
251
252impl<T, D> std::fmt::Debug for AutoDevice<T, D>
254where
255 T: Clone
256 + Send
257 + Sync
258 + std::fmt::Debug
259 + 'static
260 + num_traits::Zero
261 + std::ops::Div<f64, Output = T>
262 + Default
263 + std::ops::Mul<Output = T>
264 + std::ops::Add<Output = T>,
265 D: Dimension + crate::ndarray::RemoveAxis + std::fmt::Debug + 'static,
266 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
267 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
268{
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 f.debug_struct("AutoDevice")
271 .field("array", &self.array)
272 .field("device", &self.device)
273 .field("device_array", &self.device_array.is_some())
274 .finish()
275 }
276}
277
278impl<T, D> AutoDevice<T, D>
279where
280 T: Clone
281 + Send
282 + Sync
283 + 'static
284 + num_traits::Zero
285 + std::ops::Div<f64, Output = T>
286 + Default
287 + std::ops::Mul<Output = T>
288 + std::ops::Add<Output = T>,
289 D: Dimension + crate::ndarray::RemoveAxis + 'static,
290 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
291 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
292{
293 pub fn new(array: Array<T, D>) -> Self {
295 let device = determine_best_device(&array);
296 let device_array = None; Self {
299 array,
300 device,
301 device_array,
302 }
303 }
304
305 pub fn on_device(&mut self, device: DeviceType) -> &dyn ArrayProtocol {
307 if self.device != device || self.device_array.is_none() {
308 self.device = device;
310 self.device_array = Some(convert_to_device(self.array.clone(), device));
311 }
312
313 self.device_array
314 .as_ref()
315 .expect("Operation failed")
316 .as_ref()
317 }
318
319 pub fn device(&self) -> DeviceType {
321 self.device
322 }
323
324 pub const fn array(&self) -> &Array<T, D> {
326 &self.array
327 }
328}
329
330impl<T, D> Clone for AutoDevice<T, D>
331where
332 T: Clone
333 + Send
334 + Sync
335 + 'static
336 + num_traits::Zero
337 + std::ops::Div<f64, Output = T>
338 + Default
339 + std::ops::Mul<Output = T>
340 + std::ops::Add<Output = T>,
341 D: Dimension + crate::ndarray::RemoveAxis + 'static,
342 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
343 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
344{
345 fn clone(&self) -> Self {
346 Self {
347 array: self.array.clone(),
348 device: self.device,
349 device_array: self.device_array.clone(),
350 }
351 }
352}
353
354impl<T, D> ArrayProtocol for AutoDevice<T, D>
355where
356 T: Clone
357 + Send
358 + Sync
359 + 'static
360 + num_traits::Zero
361 + std::ops::Div<f64, Output = T>
362 + Default
363 + std::ops::Mul<Output = T>
364 + std::ops::Add<Output = T>,
365 D: Dimension + crate::ndarray::RemoveAxis + 'static,
366 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
367 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
368{
369 fn array_function(
370 &self,
371 func: &ArrayFunction,
372 types: &[TypeId],
373 args: &[Box<dyn Any>],
374 kwargs: &HashMap<String, Box<dyn Any>>,
375 ) -> Result<Box<dyn Any>, NotImplemented> {
376 if let Some(device_array) = &self.device_array {
378 device_array.array_function(func, types, args, kwargs)
379 } else {
380 let device = determine_best_device(&self.array);
382 let temp_array = convert_to_device(self.array.clone(), device);
383 temp_array.array_function(func, types, args, kwargs)
384 }
385 }
386
387 fn as_any(&self) -> &dyn Any {
388 self
389 }
390
391 fn shape(&self) -> &[usize] {
392 self.array.shape()
393 }
394
395 fn dtype(&self) -> TypeId {
396 TypeId::of::<T>()
397 }
398
399 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
400 Box::new(self.clone())
401 }
402}
403
404#[allow(dead_code)]
409pub fn auto_execute<T, D, F, R>(
410 arrays: &mut [&mut AutoDevice<T, D>],
411 operation: &str,
412 executor: F,
413) -> CoreResult<R>
414where
415 T: Clone
416 + Send
417 + Sync
418 + 'static
419 + num_traits::Zero
420 + std::ops::Div<f64, Output = T>
421 + Default
422 + std::ops::Mul<Output = T>
423 + std::ops::Add<Output = T>,
424 D: Dimension + crate::ndarray::RemoveAxis + 'static,
425 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
426 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
427 F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
428 R: 'static,
429{
430 let best_device = determine_best_device_for_operation(
432 &arrays.iter().map(|a| &a.array).collect::<Vec<_>>(),
433 operation,
434 );
435
436 let device_arrays: Vec<&dyn ArrayProtocol> = arrays
438 .iter_mut()
439 .map(|a| a.on_device(best_device))
440 .collect();
441
442 executor(&device_arrays)
444}
445
446pub mod ops {
448 use super::*;
449 use crate::array_protocol::operations as ap_ops;
450 use crate::error::{CoreError, ErrorContext};
451
452 pub fn matmul<T, D>(
454 a: &mut AutoDevice<T, D>,
455 b: &mut AutoDevice<T, D>,
456 ) -> CoreResult<Box<dyn ArrayProtocol>>
457 where
458 T: Clone
459 + Send
460 + Sync
461 + 'static
462 + num_traits::Zero
463 + std::ops::Div<f64, Output = T>
464 + Default
465 + std::ops::Mul<Output = T>
466 + std::ops::Add<Output = T>,
467 D: Dimension + crate::ndarray::RemoveAxis + 'static,
468 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
469 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
470 {
471 auto_execute(&mut [a, b], "matmul", |arrays| {
472 match ap_ops::matmul(arrays[0], arrays[1]) {
474 Ok(result) => Ok(result),
475 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
476 e.to_string(),
477 ))),
478 }
479 })
480 }
481
482 pub fn add<T, D>(
484 a: &mut AutoDevice<T, D>,
485 b: &mut AutoDevice<T, D>,
486 ) -> CoreResult<Box<dyn ArrayProtocol>>
487 where
488 T: Clone
489 + Send
490 + Sync
491 + 'static
492 + num_traits::Zero
493 + std::ops::Div<f64, Output = T>
494 + Default
495 + std::ops::Mul<Output = T>
496 + std::ops::Add<Output = T>,
497 D: Dimension + crate::ndarray::RemoveAxis + 'static,
498 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
499 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
500 {
501 auto_execute(&mut [a, b], "add", |arrays| {
502 match ap_ops::add(arrays[0], arrays[1]) {
504 Ok(result) => Ok(result),
505 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
506 e.to_string(),
507 ))),
508 }
509 })
510 }
511
512 pub fn multiply<T, D>(
514 a: &mut AutoDevice<T, D>,
515 b: &mut AutoDevice<T, D>,
516 ) -> CoreResult<Box<dyn ArrayProtocol>>
517 where
518 T: Clone
519 + Send
520 + Sync
521 + 'static
522 + num_traits::Zero
523 + std::ops::Div<f64, Output = T>
524 + Default
525 + std::ops::Mul<Output = T>
526 + std::ops::Add<Output = T>,
527 D: Dimension + crate::ndarray::RemoveAxis + 'static,
528 SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
529 SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
530 {
531 auto_execute(&mut [a, b], "multiply", |arrays| {
532 match ap_ops::multiply(arrays[0], arrays[1]) {
534 Ok(result) => Ok(result),
535 Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
536 e.to_string(),
537 ))),
538 }
539 })
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, Array2};

    /// Device selection follows the configured GPU threshold.
    #[test]
    fn test_auto_device_selection() {
        crate::array_protocol::init();

        // 100 elements: on the CPU side of the default GPU threshold.
        let small_array = Array2::<f64>::ones((10, 10));
        assert_eq!(determine_best_device(&small_array), DeviceType::CPU);

        // Lower the GPU threshold below the array size; the choice flips.
        set_auto_device_config(AutoDeviceConfig {
            gpu_threshold: 50,
            ..get_auto_device_config()
        });
        assert_eq!(determine_best_device(&small_array), DeviceType::GPU);

        // Restore the defaults so the global state does not leak out.
        set_auto_device_config(AutoDeviceConfig::default());
    }

    /// `on_device` builds a GPU representation and records the new device.
    #[test]
    fn test_auto_device_wrapper() {
        crate::array_protocol::init();

        let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]).into_dyn();
        let mut auto_array = AutoDevice::new(array.clone());

        // Four elements: starts out on the CPU.
        assert_eq!(auto_array.device(), DeviceType::CPU);

        let gpu_array = auto_array.on_device(DeviceType::GPU);
        let as_gpu = gpu_array
            .as_any()
            .downcast_ref::<GPUNdarray<f64, crate::ndarray::IxDyn>>();
        assert!(as_gpu.is_some());

        assert_eq!(auto_array.device(), DeviceType::GPU);
    }
}