use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::RwLock;

use ndarray::{Array, Dim, Dimension, SliceArg, SliceInfo, SliceInfoElem};
use num_traits;

use crate::array_protocol::{
    ArrayFunction, ArrayProtocol, DistributedBackend, DistributedConfig, DistributedNdarray,
    DistributionStrategy, GPUBackend, GPUConfig, GPUNdarray, NdarrayWrapper, NotImplemented,
};
use crate::error::CoreResult;

/// Configuration for automatic device selection.
#[derive(Debug, Clone)]
pub struct AutoDeviceConfig {
    /// Array size (in elements) at or above which the GPU is preferred.
    pub gpu_threshold: usize,

    /// Array size (in elements) at or above which distributed execution is preferred.
    pub distributed_threshold: usize,

    /// Enable mixed-precision execution on the GPU.
    pub enable_mixed_precision: bool,

    /// Prefer memory efficiency over raw performance.
    pub prefer_memory_efficiency: bool,

    /// Automatically transfer arrays to the selected device.
    pub auto_transfer: bool,

    /// Prefer keeping computation where the data already resides.
    pub prefer_data_locality: bool,

    /// GPU backend to use when an array is placed on the GPU.
    pub preferred_gpu_backend: GPUBackend,

    /// Fall back to the CPU when the preferred device is unavailable.
    pub fallback_to_cpu: bool,
}

impl Default for AutoDeviceConfig {
    fn default() -> Self {
        Self {
            gpu_threshold: 1_000_000,
            distributed_threshold: 100_000_000,
            enable_mixed_precision: false,
            prefer_memory_efficiency: false,
            auto_transfer: true,
            prefer_data_locality: true,
            preferred_gpu_backend: GPUBackend::CUDA,
            fallback_to_cpu: true,
        }
    }
}

/// Global configuration for automatic device selection.
///
/// The default values are duplicated here because `Default::default()`
/// cannot be called in a `static` initializer.
pub static AUTO_DEVICE_CONFIG: RwLock<AutoDeviceConfig> = RwLock::new(AutoDeviceConfig {
    gpu_threshold: 1_000_000,
    distributed_threshold: 100_000_000,
    enable_mixed_precision: false,
    prefer_memory_efficiency: false,
    auto_transfer: true,
    prefer_data_locality: true,
    preferred_gpu_backend: GPUBackend::CUDA,
    fallback_to_cpu: true,
});

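/// Set the global automatic device selection configuration.
///
/// # Example
///
/// A minimal sketch (marked `ignore`, so it is not compiled as a doctest)
/// showing how the GPU threshold can be lowered so that smaller arrays are
/// offloaded:
///
/// ```ignore
/// let mut config = get_auto_device_config();
/// config.gpu_threshold = 100_000;
/// set_auto_device_config(config);
/// ```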
#[allow(dead_code)]
pub fn set_auto_device_config(config: AutoDeviceConfig) {
    if let Ok(mut global_config) = AUTO_DEVICE_CONFIG.write() {
        *global_config = config;
    }
}

/// Get a clone of the global automatic device selection configuration.
#[allow(dead_code)]
pub fn get_auto_device_config() -> AutoDeviceConfig {
    AUTO_DEVICE_CONFIG
        .read()
        .map(|c| c.clone())
        .unwrap_or_default()
}

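/// Determine the best device for a single array based on its element count
/// and the global configuration thresholds.
///
/// # Example
///
/// A sketch (marked `ignore`, not compiled as a doctest) using the default
/// thresholds, under which a 10x10 array stays on the CPU:
///
/// ```ignore
/// let small = ndarray::Array2::<f64>::ones((10, 10));
/// assert_eq!(determine_best_device(&small), DeviceType::CPU);
/// ```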
#[allow(dead_code)]
pub fn determine_best_device<T, D>(array: &Array<T, D>) -> DeviceType
where
    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + ndarray::RemoveAxis,
{
    let config = get_auto_device_config();
    let size = array.len();

    if size >= config.distributed_threshold {
        DeviceType::Distributed
    } else if size >= config.gpu_threshold {
        DeviceType::GPU
    } else {
        DeviceType::CPU
    }
}

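/// Determine the best device for an operation over multiple input arrays.
///
/// Compute-heavy operations (`matmul`, `svd`, `inverse`, `conv2d`) use lower
/// thresholds, since they benefit from acceleration at smaller sizes.
///
/// # Example
///
/// A sketch (marked `ignore`, not compiled as a doctest): two 400x400 inputs
/// total 320,000 elements, which exceeds the default matmul GPU threshold of
/// 1,000,000 / 10 = 100,000:
///
/// ```ignore
/// let a = ndarray::Array2::<f64>::ones((400, 400));
/// let b = ndarray::Array2::<f64>::ones((400, 400));
/// assert_eq!(
///     determine_best_device_for_operation(&[&a, &b], "matmul"),
///     DeviceType::GPU
/// );
/// ```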
#[allow(dead_code)]
pub fn determine_best_device_for_operation<T, D>(
    arrays: &[&Array<T, D>],
    operation: &str,
) -> DeviceType
where
    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + ndarray::RemoveAxis,
{
    let config = get_auto_device_config();

    let is_complex_operation = matches!(operation, "matmul" | "svd" | "inverse" | "conv2d");

    let total_size: usize = arrays.iter().map(|arr| arr.len()).sum();

    // Compute-heavy operations justify acceleration at smaller sizes.
    let gpu_threshold = if is_complex_operation {
        config.gpu_threshold / 10
    } else {
        config.gpu_threshold
    };

    let distributed_threshold = if is_complex_operation {
        config.distributed_threshold / 2
    } else {
        config.distributed_threshold
    };

    if total_size >= distributed_threshold {
        DeviceType::Distributed
    } else if total_size >= gpu_threshold {
        DeviceType::GPU
    } else {
        DeviceType::CPU
    }
}

/// The device on which an array is stored and operations are executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    /// Standard in-memory `ndarray` on the CPU.
    CPU,

    /// GPU execution via the configured backend.
    GPU,

    /// Distributed execution across multiple chunks/workers.
    Distributed,
}

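/// Convert an `ndarray` array into an array-protocol wrapper for the given
/// device, using the global configuration for GPU settings.
///
/// # Example
///
/// A sketch (marked `ignore`, not compiled as a doctest) that places a
/// dynamic-dimension array on the GPU; the shape check assumes the GPU
/// wrapper reports the host shape unchanged:
///
/// ```ignore
/// let array = ndarray::Array2::<f64>::ones((1024, 1024)).into_dyn();
/// let gpu_array = convert_to_device(array, DeviceType::GPU);
/// assert_eq!(gpu_array.shape(), &[1024, 1024]);
/// ```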
#[allow(dead_code)]
pub fn convert_to_device<T, D>(array: Array<T, D>, device: DeviceType) -> Box<dyn ArrayProtocol>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    match device {
        DeviceType::CPU => Box::new(NdarrayWrapper::new(array)),
        DeviceType::GPU => {
            let config = get_auto_device_config();
            let gpu_config = GPUConfig {
                backend: config.preferred_gpu_backend,
                device_id: 0,
                async_ops: true,
                mixed_precision: config.enable_mixed_precision,
                memory_fraction: 0.9,
            };

            Box::new(GPUNdarray::new(array, gpu_config))
        }
        DeviceType::Distributed => {
            let dist_config = DistributedConfig {
                chunks: 2,
                balance: true,
                strategy: DistributionStrategy::RowWise,
                backend: DistributedBackend::Threaded,
            };

            Box::new(DistributedNdarray::from_array(&array, dist_config))
        }
    }
}

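/// An array wrapper that automatically selects an execution device based on
/// array size and the global [`AutoDeviceConfig`], converting lazily when a
/// device representation is first requested.
///
/// # Example
///
/// A sketch (marked `ignore`, not compiled as a doctest); a 512x512 array is
/// below the default GPU threshold, so it starts on the CPU and is only moved
/// when explicitly requested:
///
/// ```ignore
/// let array = ndarray::Array2::<f64>::ones((512, 512)).into_dyn();
/// let mut auto = AutoDevice::new(array);
/// assert_eq!(auto.device(), DeviceType::CPU);
///
/// let _gpu_view = auto.on_device(DeviceType::GPU);
/// assert_eq!(auto.device(), DeviceType::GPU);
/// ```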
pub struct AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// The underlying CPU array.
    array: Array<T, D>,

    /// The device currently selected for this array.
    device: DeviceType,

    /// The device-specific representation, created lazily.
    device_array: Option<Box<dyn ArrayProtocol>>,
}

impl<T, D> std::fmt::Debug for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + std::fmt::Debug
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + std::fmt::Debug + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AutoDevice")
            .field("array", &self.array)
            .field("device", &self.device)
            .field("device_array", &self.device_array.is_some())
            .finish()
    }
}

impl<T, D> AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// Create a new auto-device array, selecting the best device for its size.
    /// The device-specific representation is created lazily on first use.
    pub fn new(array: Array<T, D>) -> Self {
        let device = determine_best_device(&array);
        let device_array = None;

        Self {
            array,
            device,
            device_array,
        }
    }

    /// Get the array on the given device, converting it if necessary.
    pub fn on_device(&mut self, device: DeviceType) -> &dyn ArrayProtocol {
        if self.device != device || self.device_array.is_none() {
            self.device = device;
            self.device_array = Some(convert_to_device(self.array.clone(), device));
        }

        self.device_array.as_ref().unwrap().as_ref()
    }

    /// The device currently selected for this array.
    pub fn device(&self) -> DeviceType {
        self.device
    }

    /// A reference to the underlying CPU array.
    pub const fn array(&self) -> &Array<T, D> {
        &self.array
    }
}

impl<T, D> Clone for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    fn clone(&self) -> Self {
        Self {
            array: self.array.clone(),
            device: self.device,
            device_array: self.device_array.clone(),
        }
    }
}

impl<T, D> ArrayProtocol for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        if let Some(device_array) = &self.device_array {
            // Delegate to the already-materialized device array.
            device_array.array_function(func, types, args, kwargs)
        } else {
            // No device array yet: select a device and convert a temporary copy.
            let device = determine_best_device(&self.array);
            let temp_array = convert_to_device(self.array.clone(), device);
            temp_array.array_function(func, types, args, kwargs)
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        self.array.shape()
    }

    fn dtype(&self) -> TypeId {
        TypeId::of::<T>()
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}

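/// Execute an operation on the device best suited to the combined input size.
///
/// All inputs are converted to the selected device before `executor` is
/// called with their array-protocol representations.
///
/// # Example
///
/// A sketch (marked `ignore`, not compiled as a doctest); the closure receives
/// `&dyn ArrayProtocol` handles already placed on the chosen device:
///
/// ```ignore
/// let mut a = AutoDevice::new(ndarray::Array2::<f64>::ones((256, 256)).into_dyn());
/// let mut b = AutoDevice::new(ndarray::Array2::<f64>::ones((256, 256)).into_dyn());
/// let result = auto_execute(&mut [&mut a, &mut b], "matmul", |arrays| {
///     // Dispatch the real kernel here; this sketch just counts the inputs.
///     Ok(arrays.len())
/// });
/// assert!(result.is_ok());
/// ```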
#[allow(dead_code)]
pub fn auto_execute<T, D, F, R>(
    arrays: &mut [&mut AutoDevice<T, D>],
    operation: &str,
    executor: F,
) -> CoreResult<R>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
    R: 'static,
{
    // Select the device based on the combined size of all inputs.
    let best_device = determine_best_device_for_operation(
        &arrays.iter().map(|a| &a.array).collect::<Vec<_>>(),
        operation,
    );

    // Materialize every input on the selected device.
    let device_arrays: Vec<&dyn ArrayProtocol> = arrays
        .iter_mut()
        .map(|a| a.on_device(best_device))
        .collect();

    executor(&device_arrays)
}

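/// Device-aware wrappers around the array-protocol operations.
///
/// # Example
///
/// A sketch (marked `ignore`, not compiled as a doctest); device placement is
/// handled automatically before each kernel runs:
///
/// ```ignore
/// let mut a = AutoDevice::new(ndarray::Array2::<f64>::ones((128, 128)).into_dyn());
/// let mut b = AutoDevice::new(ndarray::Array2::<f64>::ones((128, 128)).into_dyn());
/// let product = ops::matmul(&mut a, &mut b);
/// let sum = ops::add(&mut a, &mut b);
/// ```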
pub mod ops {
    use super::*;
    use crate::array_protocol::operations as ap_ops;
    use crate::error::{CoreError, ErrorContext};

    /// Matrix multiplication with automatic device selection.
    pub fn matmul<T, D>(
        a: &mut AutoDevice<T, D>,
        b: &mut AutoDevice<T, D>,
    ) -> CoreResult<Box<dyn ArrayProtocol>>
    where
        T: Clone
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>
            + Default
            + std::ops::Mul<Output = T>
            + std::ops::Add<Output = T>,
        D: Dimension + ndarray::RemoveAxis + 'static,
        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    {
        auto_execute(&mut [a, b], "matmul", |arrays| {
            match ap_ops::matmul(arrays[0], arrays[1]) {
                Ok(result) => Ok(result),
                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
                    e.to_string(),
                ))),
            }
        })
    }

    /// Element-wise addition with automatic device selection.
    pub fn add<T, D>(
        a: &mut AutoDevice<T, D>,
        b: &mut AutoDevice<T, D>,
    ) -> CoreResult<Box<dyn ArrayProtocol>>
    where
        T: Clone
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>
            + Default
            + std::ops::Mul<Output = T>
            + std::ops::Add<Output = T>,
        D: Dimension + ndarray::RemoveAxis + 'static,
        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    {
        auto_execute(&mut [a, b], "add", |arrays| {
            match ap_ops::add(arrays[0], arrays[1]) {
                Ok(result) => Ok(result),
                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
                    e.to_string(),
                ))),
            }
        })
    }

    /// Element-wise multiplication with automatic device selection.
    pub fn multiply<T, D>(
        a: &mut AutoDevice<T, D>,
        b: &mut AutoDevice<T, D>,
    ) -> CoreResult<Box<dyn ArrayProtocol>>
    where
        T: Clone
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>
            + Default
            + std::ops::Mul<Output = T>
            + std::ops::Add<Output = T>,
        D: Dimension + ndarray::RemoveAxis + 'static,
        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    {
        auto_execute(&mut [a, b], "multiply", |arrays| {
            match ap_ops::multiply(arrays[0], arrays[1]) {
                Ok(result) => Ok(result),
                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
                    e.to_string(),
                ))),
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr2, Array2};

    #[test]
    fn test_auto_device_selection() {
        crate::array_protocol::init();

        // A small array stays on the CPU under the default thresholds.
        let small_array = Array2::<f64>::ones((10, 10));
        let device = determine_best_device(&small_array);
        assert_eq!(device, DeviceType::CPU);

        // Lower the GPU threshold so the same array now prefers the GPU.
        let mut config = get_auto_device_config();
        config.gpu_threshold = 50;
        set_auto_device_config(config);

        let device = determine_best_device(&small_array);
        assert_eq!(device, DeviceType::GPU);

        // Restore the default configuration for other tests.
        set_auto_device_config(AutoDeviceConfig::default());
    }

    #[test]
    fn test_auto_device_wrapper() {
        crate::array_protocol::init();

        let array_2d = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let array = array_2d.into_dyn();
        let mut auto_array = AutoDevice::new(array.clone());

        // A small 2x2 array starts on the CPU.
        assert_eq!(auto_array.device(), DeviceType::CPU);

        // Explicitly moving it to the GPU produces a GPUNdarray wrapper.
        let gpu_array = auto_array.on_device(DeviceType::GPU);
        assert!(gpu_array
            .as_any()
            .downcast_ref::<GPUNdarray<f64, ndarray::IxDyn>>()
            .is_some());

        assert_eq!(auto_array.device(), DeviceType::GPU);
    }
}