scirs2_core/array_protocol/
auto_device.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under either of
4//
5// * Apache License, Version 2.0
6//   (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
7// * MIT license
8//   (LICENSE-MIT or http://opensource.org/licenses/MIT)
9//
10// at your option.
11//
12
13//! Automatic device placement for array operations.
14//!
15//! This module provides functionality for automatically determining the best
16//! device (CPU, GPU, distributed) for array operations based on array size,
17//! available hardware, and operation complexity.
18
19use std::any::{Any, TypeId};
20use std::collections::HashMap;
21use std::sync::RwLock;
22
23use ::ndarray::{Array, Dim, Dimension, SliceArg, SliceInfo, SliceInfoElem};
24use num_traits;
25
26use crate::array_protocol::{
27    ArrayFunction, ArrayProtocol, DistributedBackend, DistributedConfig, DistributedNdarray,
28    DistributionStrategy, GPUBackend, GPUConfig, GPUNdarray, NdarrayWrapper, NotImplemented,
29};
30use crate::error::CoreResult;
31
/// Configuration for automatic device placement.
///
/// A process-wide instance is stored in [`AUTO_DEVICE_CONFIG`]; it is read with
/// [`get_auto_device_config`] and replaced with [`set_auto_device_config`].
#[derive(Debug, Clone)]
pub struct AutoDeviceConfig {
    /// Minimum array size (total elements) to consider GPU placement.
    ///
    /// [`determine_best_device`] picks the GPU once the element count reaches this
    /// value; [`determine_best_device_for_operation`] divides it by 10 for
    /// `"matmul"`, `"svd"`, `"inverse"` and `"conv2d"`.
    pub gpu_threshold: usize,

    /// Minimum array size to consider distributed placement.
    ///
    /// Checked before `gpu_threshold`, so it wins when both are reached.
    /// [`determine_best_device_for_operation`] halves it for the operations
    /// listed on `gpu_threshold`.
    pub distributed_threshold: usize,

    /// Enable mixed-precision operations.
    ///
    /// Copied into `GPUConfig::mixed_precision` by [`convert_to_device`].
    pub enable_mixed_precision: bool,

    /// Prefer memory efficiency over speed.
    // NOTE(review): not read by any code in this module — presumably consumed
    // elsewhere; confirm.
    pub prefer_memory_efficiency: bool,

    /// Automatically transfer arrays between devices when needed.
    // NOTE(review): not read by any code in this module — confirm where it applies.
    pub auto_transfer: bool,

    /// Prefer device data locality (avoid transfers).
    // NOTE(review): not read by any code in this module — confirm where it applies.
    pub prefer_data_locality: bool,

    /// Preferred GPU backend.
    ///
    /// Copied into `GPUConfig::backend` by [`convert_to_device`].
    pub preferred_gpu_backend: GPUBackend,

    /// Fallback to CPU if GPU is not available.
    // NOTE(review): not read by any code in this module; `convert_to_device`
    // builds a GPU array without consulting it — confirm intended behaviour.
    pub fallback_to_cpu: bool,
}
59
60impl Default for AutoDeviceConfig {
61    fn default() -> Self {
62        Self {
63            gpu_threshold: 1_000_000,           // 1M elements
64            distributed_threshold: 100_000_000, // 100M elements
65            enable_mixed_precision: false,
66            prefer_memory_efficiency: false,
67            auto_transfer: true,
68            prefer_data_locality: true,
69            preferred_gpu_backend: GPUBackend::CUDA,
70            fallback_to_cpu: true,
71        }
72    }
73}
74
/// Global auto device configuration.
///
/// Shared by every placement decision in this module. The value sits behind an
/// [`RwLock`]: [`get_auto_device_config`] takes the read lock and
/// [`set_auto_device_config`] takes the write lock.
///
/// The initial field values repeat those of `AutoDeviceConfig::default()`; keep
/// the two in sync when changing a default.
// NOTE(review): presumably spelled out as a literal because a trait method such
// as `Default::default()` cannot be called in a `static` initializer — confirm.
pub static AUTO_DEVICE_CONFIG: RwLock<AutoDeviceConfig> = RwLock::new(AutoDeviceConfig {
    gpu_threshold: 1_000_000,
    distributed_threshold: 100_000_000,
    enable_mixed_precision: false,
    prefer_memory_efficiency: false,
    auto_transfer: true,
    prefer_data_locality: true,
    preferred_gpu_backend: GPUBackend::CUDA,
    fallback_to_cpu: true,
});
86
87/// Set the global auto device configuration.
88#[allow(dead_code)]
89pub fn set_auto_device_config(config: AutoDeviceConfig) {
90    if let Ok(mut global_config) = AUTO_DEVICE_CONFIG.write() {
91        *global_config = config;
92    }
93}
94
95/// Get the current auto device configuration.
96#[allow(dead_code)]
97pub fn get_auto_device_config() -> AutoDeviceConfig {
98    AUTO_DEVICE_CONFIG
99        .read()
100        .map(|c| c.clone())
101        .unwrap_or_default()
102}
103
104/// Determine the best device for an array.
105///
106/// This function determines the best device (CPU, GPU, distributed) for an array
107/// based on its size, the operation being performed, and the current configuration.
108#[allow(dead_code)]
109pub fn determine_best_device<T, D>(array: &Array<T, D>) -> DeviceType
110where
111    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
112    D: Dimension + crate::ndarray::RemoveAxis,
113{
114    let config = get_auto_device_config();
115    let size = array.len();
116
117    if size >= config.distributed_threshold {
118        DeviceType::Distributed
119    } else if size >= config.gpu_threshold {
120        DeviceType::GPU
121    } else {
122        DeviceType::CPU
123    }
124}
125
126/// Determine the best device for an operation with multiple arrays.
127///
128/// This function determines the best device for an operation based on
129/// the arrays involved and the operation being performed.
130#[allow(dead_code)]
131pub fn determine_best_device_for_operation<T, D>(
132    arrays: &[&Array<T, D>],
133    operation: &str,
134) -> DeviceType
135where
136    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
137    D: Dimension + crate::ndarray::RemoveAxis,
138{
139    let config = get_auto_device_config();
140
141    // Complex operations (matrix multiplication, SVD, etc.) benefit more from GPU
142    let is_complex_operation = matches!(operation, "matmul" | "svd" | "inverse" | "conv2d");
143
144    // Compute total size of all arrays
145    let total_size: usize = arrays.iter().map(|arr| arr.len()).sum();
146
147    // Adjust thresholds based on operation complexity
148    let gpu_threshold = if is_complex_operation {
149        config.gpu_threshold / 10 // Lower threshold for complex operations
150    } else {
151        config.gpu_threshold
152    };
153
154    let distributed_threshold = if is_complex_operation {
155        config.distributed_threshold / 2 // Lower threshold for complex operations
156    } else {
157        config.distributed_threshold
158    };
159
160    if total_size >= distributed_threshold {
161        DeviceType::Distributed
162    } else if total_size >= gpu_threshold {
163        DeviceType::GPU
164    } else {
165        DeviceType::CPU
166    }
167}
168
/// Available device types for array operations.
///
/// The type is a plain `Copy` tag. `Hash` is derived alongside `Eq` so a device
/// can be used directly as a `HashMap`/`HashSet` key (for example to keep
/// per-device caches or counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// CPU-based computation.
    CPU,

    /// GPU-accelerated computation.
    GPU,

    /// Distributed computation across multiple machines/processes.
    Distributed,
}
181
182/// Convert an array to the specified device type.
183///
184/// This function converts an array to the specified device type, creating
185/// the appropriate array wrapper for the target device.
186#[allow(dead_code)]
187pub fn convert_to_device<T, D>(array: Array<T, D>, device: DeviceType) -> Box<dyn ArrayProtocol>
188where
189    T: Clone
190        + Send
191        + Sync
192        + 'static
193        + num_traits::Zero
194        + std::ops::Div<f64, Output = T>
195        + Default
196        + std::ops::Mul<Output = T>
197        + std::ops::Add<Output = T>,
198    D: Dimension + crate::ndarray::RemoveAxis + 'static,
199    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
200    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
201{
202    match device {
203        DeviceType::CPU => Box::new(NdarrayWrapper::new(array.clone())),
204        DeviceType::GPU => {
205            let config = get_auto_device_config();
206            let gpu_config = GPUConfig {
207                backend: config.preferred_gpu_backend,
208                device_id: 0,
209                async_ops: true,
210                mixed_precision: config.enable_mixed_precision,
211                memory_fraction: 0.9,
212            };
213
214            Box::new(GPUNdarray::new(array.clone(), gpu_config))
215        }
216        DeviceType::Distributed => {
217            let dist_config = DistributedConfig {
218                chunks: 2, // Using 2 chunks as a default instead of num_cpus / 2
219                balance: true,
220                strategy: DistributionStrategy::RowWise,
221                backend: DistributedBackend::Threaded,
222            };
223
224            Box::new(DistributedNdarray::from_array(&array, dist_config))
225        }
226    }
227}
228
/// A wrapper for arrays that automatically chooses the best device.
///
/// This wrapper automatically places arrays on the most appropriate device
/// based on their size and the operations being performed.
///
/// The wrapper always keeps the original `ndarray` array. A device-specific
/// representation is only built on demand (see `on_device`) and cached until a
/// different device is requested.
pub struct AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// The underlying array.
    ///
    /// Source of truth for every conversion; it is cloned each time a device
    /// representation is built.
    array: Array<T, D>,

    /// The current device the array is on.
    ///
    /// Initialised from `determine_best_device` in `new` and overwritten by
    /// `on_device` whenever a conversion happens.
    device: DeviceType,

    /// The array on the current device.
    ///
    /// `None` until `on_device` first builds a representation.
    device_array: Option<Box<dyn ArrayProtocol>>,
}
257
258// Manually implement Debug for AutoDevice since Box<dyn ArrayProtocol> doesn't implement Debug
259impl<T, D> std::fmt::Debug for AutoDevice<T, D>
260where
261    T: Clone
262        + Send
263        + Sync
264        + std::fmt::Debug
265        + 'static
266        + num_traits::Zero
267        + std::ops::Div<f64, Output = T>
268        + Default
269        + std::ops::Mul<Output = T>
270        + std::ops::Add<Output = T>,
271    D: Dimension + crate::ndarray::RemoveAxis + std::fmt::Debug + 'static,
272    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
273    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
274{
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        f.debug_struct("AutoDevice")
277            .field("array", &self.array)
278            .field("device", &self.device)
279            .field("device_array", &self.device_array.is_some())
280            .finish()
281    }
282}
283
284impl<T, D> AutoDevice<T, D>
285where
286    T: Clone
287        + Send
288        + Sync
289        + 'static
290        + num_traits::Zero
291        + std::ops::Div<f64, Output = T>
292        + Default
293        + std::ops::Mul<Output = T>
294        + std::ops::Add<Output = T>,
295    D: Dimension + crate::ndarray::RemoveAxis + 'static,
296    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
297    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
298{
299    /// Create a new auto-device array.
300    pub fn new(array: Array<T, D>) -> Self {
301        let device = determine_best_device(&array);
302        let device_array = None; // Lazily initialized
303
304        Self {
305            array,
306            device,
307            device_array,
308        }
309    }
310
311    /// Get the array on the specified device.
312    pub fn on_device(&mut self, device: DeviceType) -> &dyn ArrayProtocol {
313        if self.device != device || self.device_array.is_none() {
314            // Convert to the requested device
315            self.device = device;
316            self.device_array = Some(convert_to_device(self.array.clone(), device));
317        }
318
319        self.device_array
320            .as_ref()
321            .expect("Operation failed")
322            .as_ref()
323    }
324
325    /// Get the current device.
326    pub fn device(&self) -> DeviceType {
327        self.device
328    }
329
330    /// Get the underlying array.
331    pub const fn array(&self) -> &Array<T, D> {
332        &self.array
333    }
334}
335
336impl<T, D> Clone for AutoDevice<T, D>
337where
338    T: Clone
339        + Send
340        + Sync
341        + 'static
342        + num_traits::Zero
343        + std::ops::Div<f64, Output = T>
344        + Default
345        + std::ops::Mul<Output = T>
346        + std::ops::Add<Output = T>,
347    D: Dimension + crate::ndarray::RemoveAxis + 'static,
348    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
349    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
350{
351    fn clone(&self) -> Self {
352        Self {
353            array: self.array.clone(),
354            device: self.device,
355            device_array: self.device_array.clone(),
356        }
357    }
358}
359
360impl<T, D> ArrayProtocol for AutoDevice<T, D>
361where
362    T: Clone
363        + Send
364        + Sync
365        + 'static
366        + num_traits::Zero
367        + std::ops::Div<f64, Output = T>
368        + Default
369        + std::ops::Mul<Output = T>
370        + std::ops::Add<Output = T>,
371    D: Dimension + crate::ndarray::RemoveAxis + 'static,
372    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
373    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
374{
375    fn array_function(
376        &self,
377        func: &ArrayFunction,
378        types: &[TypeId],
379        args: &[Box<dyn Any>],
380        kwargs: &HashMap<String, Box<dyn Any>>,
381    ) -> Result<Box<dyn Any>, NotImplemented> {
382        // If we already have a device array, delegate to it
383        if let Some(device_array) = &self.device_array {
384            device_array.array_function(func, types, args, kwargs)
385        } else {
386            // Otherwise, create a temporary array on the appropriate device
387            let device = determine_best_device(&self.array);
388            let temp_array = convert_to_device(self.array.clone(), device);
389            temp_array.array_function(func, types, args, kwargs)
390        }
391    }
392
393    fn as_any(&self) -> &dyn Any {
394        self
395    }
396
397    fn shape(&self) -> &[usize] {
398        self.array.shape()
399    }
400
401    fn dtype(&self) -> TypeId {
402        TypeId::of::<T>()
403    }
404
405    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
406        Box::new(self.clone())
407    }
408}
409
410/// Execute an operation with automatic device selection.
411///
412/// This function automatically selects the best device for the operation
413/// based on the arrays involved and the operation being performed.
414#[allow(dead_code)]
415pub fn auto_execute<T, D, F, R>(
416    arrays: &mut [&mut AutoDevice<T, D>],
417    operation: &str,
418    executor: F,
419) -> CoreResult<R>
420where
421    T: Clone
422        + Send
423        + Sync
424        + 'static
425        + num_traits::Zero
426        + std::ops::Div<f64, Output = T>
427        + Default
428        + std::ops::Mul<Output = T>
429        + std::ops::Add<Output = T>,
430    D: Dimension + crate::ndarray::RemoveAxis + 'static,
431    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
432    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
433    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
434    R: 'static,
435{
436    // Determine the best device for this operation
437    let best_device = determine_best_device_for_operation(
438        &arrays.iter().map(|a| &a.array).collect::<Vec<_>>(),
439        operation,
440    );
441
442    // Get arrays on the selected device
443    let device_arrays: Vec<&dyn ArrayProtocol> = arrays
444        .iter_mut()
445        .map(|a| a.on_device(best_device))
446        .collect();
447
448    // Execute the operation
449    executor(&device_arrays)
450}
451
452/// Implementation of common array operations with automatic device selection.
453pub mod ops {
454    use super::*;
455    use crate::array_protocol::operations as ap_ops;
456    use crate::error::{CoreError, ErrorContext};
457
458    /// Matrix multiplication with automatic device selection.
459    pub fn matmul<T, D>(
460        a: &mut AutoDevice<T, D>,
461        b: &mut AutoDevice<T, D>,
462    ) -> CoreResult<Box<dyn ArrayProtocol>>
463    where
464        T: Clone
465            + Send
466            + Sync
467            + 'static
468            + num_traits::Zero
469            + std::ops::Div<f64, Output = T>
470            + Default
471            + std::ops::Mul<Output = T>
472            + std::ops::Add<Output = T>,
473        D: Dimension + crate::ndarray::RemoveAxis + 'static,
474        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
475        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
476    {
477        auto_execute(&mut [a, b], "matmul", |arrays| {
478            // Convert OperationError to CoreError
479            match ap_ops::matmul(arrays[0], arrays[1]) {
480                Ok(result) => Ok(result),
481                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
482                    e.to_string(),
483                ))),
484            }
485        })
486    }
487
488    /// Element-wise addition with automatic device selection.
489    pub fn add<T, D>(
490        a: &mut AutoDevice<T, D>,
491        b: &mut AutoDevice<T, D>,
492    ) -> CoreResult<Box<dyn ArrayProtocol>>
493    where
494        T: Clone
495            + Send
496            + Sync
497            + 'static
498            + num_traits::Zero
499            + std::ops::Div<f64, Output = T>
500            + Default
501            + std::ops::Mul<Output = T>
502            + std::ops::Add<Output = T>,
503        D: Dimension + crate::ndarray::RemoveAxis + 'static,
504        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
505        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
506    {
507        auto_execute(&mut [a, b], "add", |arrays| {
508            // Convert OperationError to CoreError
509            match ap_ops::add(arrays[0], arrays[1]) {
510                Ok(result) => Ok(result),
511                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
512                    e.to_string(),
513                ))),
514            }
515        })
516    }
517
518    /// Element-wise multiplication with automatic device selection.
519    pub fn multiply<T, D>(
520        a: &mut AutoDevice<T, D>,
521        b: &mut AutoDevice<T, D>,
522    ) -> CoreResult<Box<dyn ArrayProtocol>>
523    where
524        T: Clone
525            + Send
526            + Sync
527            + 'static
528            + num_traits::Zero
529            + std::ops::Div<f64, Output = T>
530            + Default
531            + std::ops::Mul<Output = T>
532            + std::ops::Add<Output = T>,
533        D: Dimension + crate::ndarray::RemoveAxis + 'static,
534        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
535        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
536    {
537        auto_execute(&mut [a, b], "multiply", |arrays| {
538            // Convert OperationError to CoreError
539            match ap_ops::multiply(arrays[0], arrays[1]) {
540                Ok(result) => Ok(result),
541                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
542                    e.to_string(),
543                ))),
544            }
545        })
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, Array2};

    /// Puts a saved configuration back into the global slot when dropped.
    ///
    /// Using a drop guard means the global configuration is restored even when an
    /// assertion fails part-way through a test, so a failure cannot leak modified
    /// thresholds into other tests running in the same process.
    struct ConfigRestore(AutoDeviceConfig);

    impl Drop for ConfigRestore {
        fn drop(&mut self) {
            set_auto_device_config(self.0.clone());
        }
    }

    /// Device selection follows the configured GPU threshold.
    #[test]
    fn test_auto_device_selection() {
        // Initialize the array protocol
        crate::array_protocol::init();

        // Create a small array (should be placed on CPU)
        let small_array = Array2::<f64>::ones((10, 10));
        let device = determine_best_device(&small_array);
        assert_eq!(device, DeviceType::CPU);

        // Remember the configuration in force; it is restored when `_restore`
        // goes out of scope, whether or not the assertions below pass.
        let _restore = ConfigRestore(get_auto_device_config());

        // Modify config to place smaller arrays on GPU
        let mut config = get_auto_device_config();
        config.gpu_threshold = 50; // 50 elements
        set_auto_device_config(config);

        // Check device selection with new config
        let device = determine_best_device(&small_array);
        assert_eq!(device, DeviceType::GPU);
    }

    /// `on_device` converts the wrapped array and records the new device.
    #[test]
    fn test_auto_device_wrapper() {
        // Initialize the array protocol
        crate::array_protocol::init();

        // Create a small array - using IxDyn to match the trait bounds
        let array_2d = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let array = array_2d.into_dyn();
        // `array` is not used again, so it is moved in rather than cloned.
        let mut auto_array = AutoDevice::new(array);

        // Check initial device (should be CPU for small array)
        assert_eq!(auto_array.device(), DeviceType::CPU);

        // Get array on GPU
        let gpu_array = auto_array.on_device(DeviceType::GPU);
        assert!(gpu_array
            .as_any()
            .downcast_ref::<GPUNdarray<f64, crate::ndarray::IxDyn>>()
            .is_some());

        // Check that device was updated
        assert_eq!(auto_array.device(), DeviceType::GPU);
    }
}