scirs2_core/array_protocol/auto_device.rs

// Copyright (c) 2025, `SciRS2` Team
//
// Licensed under either of
//
// * Apache License, Version 2.0
//   (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
// * MIT license
//   (LICENSE-MIT or http://opensource.org/licenses/MIT)
//
// at your option.
//

//! Automatic device placement for array operations.
//!
//! This module provides functionality for automatically determining the best
//! device (CPU, GPU, distributed) for array operations based on array size,
//! available hardware, and operation complexity.
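//!
//! # Example
//!
//! A minimal sketch of the intended usage, assuming the default thresholds in
//! `AutoDeviceConfig::default()` (1M elements for GPU, 100M for distributed);
//! it mirrors the tests at the bottom of this module and is not run as a
//! doctest:
//!
//! ```ignore
//! use ndarray::Array2;
//!
//! // 100 x 100 = 10_000 elements, well below the GPU threshold, so the
//! // wrapper keeps the array on the CPU.
//! let array = Array2::<f64>::ones((100, 100)).into_dyn();
//! let auto = AutoDevice::new(array);
//! assert_eq!(auto.device(), DeviceType::CPU);
//! ```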

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::RwLock;

use ndarray::{Array, Dim, Dimension, SliceArg, SliceInfo, SliceInfoElem};
use num_traits;

use crate::array_protocol::{
    ArrayFunction, ArrayProtocol, DistributedBackend, DistributedConfig, DistributedNdarray,
    DistributionStrategy, GPUBackend, GPUConfig, GPUNdarray, NdarrayWrapper, NotImplemented,
};
use crate::error::CoreResult;

/// Configuration for automatic device placement.
#[derive(Debug, Clone)]
pub struct AutoDeviceConfig {
    /// Minimum array size (total elements) to consider GPU placement.
    pub gpu_threshold: usize,

    /// Minimum array size to consider distributed placement.
    pub distributed_threshold: usize,

    /// Enable mixed-precision operations.
    pub enable_mixed_precision: bool,

    /// Prefer memory efficiency over speed.
    pub prefer_memory_efficiency: bool,

    /// Automatically transfer arrays between devices when needed.
    pub auto_transfer: bool,

    /// Prefer device data locality (avoid transfers).
    pub prefer_data_locality: bool,

    /// Preferred GPU backend.
    pub preferred_gpu_backend: GPUBackend,

    /// Fallback to CPU if GPU is not available.
    pub fallback_to_cpu: bool,
}

impl Default for AutoDeviceConfig {
    fn default() -> Self {
        Self {
            gpu_threshold: 1_000_000,           // 1M elements
            distributed_threshold: 100_000_000, // 100M elements
            enable_mixed_precision: false,
            prefer_memory_efficiency: false,
            auto_transfer: true,
            prefer_data_locality: true,
            preferred_gpu_backend: GPUBackend::CUDA,
            fallback_to_cpu: true,
        }
    }
}

/// Global auto device configuration.
pub static AUTO_DEVICE_CONFIG: RwLock<AutoDeviceConfig> = RwLock::new(AutoDeviceConfig {
    gpu_threshold: 1_000_000,
    distributed_threshold: 100_000_000,
    enable_mixed_precision: false,
    prefer_memory_efficiency: false,
    auto_transfer: true,
    prefer_data_locality: true,
    preferred_gpu_backend: GPUBackend::CUDA,
    fallback_to_cpu: true,
});

/// Set the global auto device configuration.
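///
/// # Example
///
/// An illustrative sketch of lowering the GPU threshold so that smaller arrays
/// are routed to the GPU (the same pattern used by the tests in this module):
///
/// ```ignore
/// let mut config = get_auto_device_config();
/// config.gpu_threshold = 10_000; // place arrays of >= 10k elements on the GPU
/// set_auto_device_config(config);
/// ```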
#[allow(dead_code)]
pub fn set_auto_device_config(config: AutoDeviceConfig) {
    if let Ok(mut global_config) = AUTO_DEVICE_CONFIG.write() {
        *global_config = config;
    }
}

/// Get the current auto device configuration.
#[allow(dead_code)]
pub fn get_auto_device_config() -> AutoDeviceConfig {
    AUTO_DEVICE_CONFIG
        .read()
        .map(|c| c.clone())
        .unwrap_or_default()
}

/// Determine the best device for an array.
///
/// This function determines the best device (CPU, GPU, distributed) for an array
/// based on its size and the current auto-device configuration.
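///
/// # Example
///
/// A minimal sketch using the default thresholds (mirrors the tests below):
///
/// ```ignore
/// use ndarray::Array2;
///
/// // 10 x 10 = 100 elements, below the default 1M-element GPU threshold.
/// let small = Array2::<f64>::ones((10, 10));
/// assert_eq!(determine_best_device(&small), DeviceType::CPU);
/// ```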
#[allow(dead_code)]
pub fn determine_best_device<T, D>(array: &Array<T, D>) -> DeviceType
where
    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + ndarray::RemoveAxis,
{
    let config = get_auto_device_config();
    let size = array.len();

    if size >= config.distributed_threshold {
        DeviceType::Distributed
    } else if size >= config.gpu_threshold {
        DeviceType::GPU
    } else {
        DeviceType::CPU
    }
}

/// Determine the best device for an operation with multiple arrays.
///
/// This function determines the best device for an operation based on
/// the arrays involved and the operation being performed.
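///
/// # Example
///
/// Illustrative sketch: with the default 1M-element GPU threshold, two
/// 300 x 300 matrices (180_000 elements combined) are still routed to the GPU
/// for `"matmul"`, because complex operations use a threshold that is ten
/// times lower:
///
/// ```ignore
/// use ndarray::Array2;
///
/// let a = Array2::<f64>::ones((300, 300));
/// let b = Array2::<f64>::ones((300, 300));
/// assert_eq!(
///     determine_best_device_for_operation(&[&a, &b], "matmul"),
///     DeviceType::GPU
/// );
/// ```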
#[allow(dead_code)]
pub fn determine_best_device_for_operation<T, D>(
    arrays: &[&Array<T, D>],
    operation: &str,
) -> DeviceType
where
    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + ndarray::RemoveAxis,
{
    let config = get_auto_device_config();

    // Complex operations (matrix multiplication, SVD, etc.) benefit more from GPU
    let is_complex_operation = matches!(operation, "matmul" | "svd" | "inverse" | "conv2d");

    // Compute total size of all arrays
    let total_size: usize = arrays.iter().map(|arr| arr.len()).sum();

    // Adjust thresholds based on operation complexity
    let gpu_threshold = if is_complex_operation {
        config.gpu_threshold / 10 // Lower threshold for complex operations
    } else {
        config.gpu_threshold
    };

    let distributed_threshold = if is_complex_operation {
        config.distributed_threshold / 2 // Lower threshold for complex operations
    } else {
        config.distributed_threshold
    };

    if total_size >= distributed_threshold {
        DeviceType::Distributed
    } else if total_size >= gpu_threshold {
        DeviceType::GPU
    } else {
        DeviceType::CPU
    }
}

/// Available device types for array operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    /// CPU-based computation.
    CPU,

    /// GPU-accelerated computation.
    GPU,

    /// Distributed computation across multiple machines/processes.
    Distributed,
}

/// Convert an array to the specified device type.
///
/// This function converts an array to the specified device type, creating
/// the appropriate array wrapper for the target device.
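///
/// # Example
///
/// A sketch of an explicit conversion (assumes the array protocol has been
/// initialized via `crate::array_protocol::init()`, as in the tests below):
///
/// ```ignore
/// use ndarray::Array2;
///
/// let array = Array2::<f64>::ones((64, 64)).into_dyn();
/// let gpu_array = convert_to_device(array, DeviceType::GPU);
/// ```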
#[allow(dead_code)]
pub fn convert_to_device<T, D>(array: Array<T, D>, device: DeviceType) -> Box<dyn ArrayProtocol>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    match device {
        DeviceType::CPU => Box::new(NdarrayWrapper::new(array.clone())),
        DeviceType::GPU => {
            let config = get_auto_device_config();
            let gpu_config = GPUConfig {
                backend: config.preferred_gpu_backend,
                device_id: 0,
                async_ops: true,
                mixed_precision: config.enable_mixed_precision,
                memory_fraction: 0.9,
            };

            Box::new(GPUNdarray::new(array.clone(), gpu_config))
        }
        DeviceType::Distributed => {
            let dist_config = DistributedConfig {
                chunks: 2, // Using 2 chunks as a default instead of num_cpus / 2
                balance: true,
                strategy: DistributionStrategy::RowWise,
                backend: DistributedBackend::Threaded,
            };

            Box::new(DistributedNdarray::from_array(&array, dist_config))
        }
    }
}

/// A wrapper for arrays that automatically chooses the best device.
///
/// This wrapper automatically places arrays on the most appropriate device
/// based on their size and the operations being performed.
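///
/// # Example
///
/// A minimal sketch mirroring the tests at the bottom of this module:
///
/// ```ignore
/// use ndarray::arr2;
///
/// let mut auto = AutoDevice::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]).into_dyn());
/// assert_eq!(auto.device(), DeviceType::CPU); // small arrays stay on the CPU
///
/// // Explicitly move the array to the GPU; `device()` is updated accordingly.
/// let _gpu_view = auto.on_device(DeviceType::GPU);
/// assert_eq!(auto.device(), DeviceType::GPU);
/// ```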
pub struct AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// The underlying array.
    array: Array<T, D>,

    /// The current device the array is on.
    device: DeviceType,

    /// The array on the current device.
    device_array: Option<Box<dyn ArrayProtocol>>,
}

// Manually implement Debug for AutoDevice since Box<dyn ArrayProtocol> doesn't implement Debug
impl<T, D> std::fmt::Debug for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + std::fmt::Debug
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + std::fmt::Debug + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AutoDevice")
            .field("array", &self.array)
            .field("device", &self.device)
            .field("device_array", &self.device_array.is_some())
            .finish()
    }
}

impl<T, D> AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// Create a new auto-device array.
    pub fn new(array: Array<T, D>) -> Self {
        let device = determine_best_device(&array);
        let device_array = None; // Lazily initialized

        Self {
            array,
            device,
            device_array,
        }
    }

    /// Get the array on the specified device.
    pub fn on_device(&mut self, device: DeviceType) -> &dyn ArrayProtocol {
        if self.device != device || self.device_array.is_none() {
            // Convert to the requested device
            self.device = device;
            self.device_array = Some(convert_to_device(self.array.clone(), device));
        }

        self.device_array.as_ref().unwrap().as_ref()
    }

    /// Get the current device.
    pub fn device(&self) -> DeviceType {
        self.device
    }

    /// Get the underlying array.
    pub const fn array(&self) -> &Array<T, D> {
        &self.array
    }
}

impl<T, D> Clone for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    fn clone(&self) -> Self {
        Self {
            array: self.array.clone(),
            device: self.device,
            device_array: self.device_array.clone(),
        }
    }
}

impl<T, D> ArrayProtocol for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        // If we already have a device array, delegate to it
        if let Some(device_array) = &self.device_array {
            device_array.array_function(func, types, args, kwargs)
        } else {
            // Otherwise, create a temporary array on the appropriate device
            let device = determine_best_device(&self.array);
            let temp_array = convert_to_device(self.array.clone(), device);
            temp_array.array_function(func, types, args, kwargs)
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        self.array.shape()
    }

    fn dtype(&self) -> TypeId {
        TypeId::of::<T>()
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}

/// Execute an operation with automatic device selection.
///
/// This function automatically selects the best device for the operation
/// based on the arrays involved and the operation being performed.
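///
/// # Example
///
/// A sketch of a caller, following the same pattern as the wrappers in the
/// `ops` module below (`ap_ops`, `CoreError`, and `ErrorContext` are the
/// imports used there; `a` and `b` are `AutoDevice` wrappers created with
/// `AutoDevice::new`):
///
/// ```ignore
/// let result = auto_execute(&mut [&mut a, &mut b], "matmul", |arrays| {
///     ap_ops::matmul(arrays[0], arrays[1])
///         .map_err(|e| CoreError::NotImplementedError(ErrorContext::new(e.to_string())))
/// })?;
/// ```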
#[allow(dead_code)]
pub fn auto_execute<T, D, F, R>(
    arrays: &mut [&mut AutoDevice<T, D>],
    operation: &str,
    executor: F,
) -> CoreResult<R>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
    R: 'static,
{
    // Determine the best device for this operation
    let best_device = determine_best_device_for_operation(
        &arrays.iter().map(|a| &a.array).collect::<Vec<_>>(),
        operation,
    );

    // Get arrays on the selected device
    let device_arrays: Vec<&dyn ArrayProtocol> = arrays
        .iter_mut()
        .map(|a| a.on_device(best_device))
        .collect();

    // Execute the operation
    executor(&device_arrays)
}

/// Implementation of common array operations with automatic device selection.
pub mod ops {
    use super::*;
    use crate::array_protocol::operations as ap_ops;
    use crate::error::{CoreError, ErrorContext};

    /// Matrix multiplication with automatic device selection.
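    ///
    /// # Example
    ///
    /// A sketch of the intended call pattern (illustrative sizes; assumes the
    /// array protocol has been initialized):
    ///
    /// ```ignore
    /// use ndarray::Array2;
    ///
    /// let mut a = AutoDevice::new(Array2::<f64>::ones((512, 512)).into_dyn());
    /// let mut b = AutoDevice::new(Array2::<f64>::ones((512, 512)).into_dyn());
    /// let result = ops::matmul(&mut a, &mut b)?;
    /// ```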
    pub fn matmul<T, D>(
        a: &mut AutoDevice<T, D>,
        b: &mut AutoDevice<T, D>,
    ) -> CoreResult<Box<dyn ArrayProtocol>>
    where
        T: Clone
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>
            + Default
            + std::ops::Mul<Output = T>
            + std::ops::Add<Output = T>,
        D: Dimension + ndarray::RemoveAxis + 'static,
        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    {
        auto_execute(&mut [a, b], "matmul", |arrays| {
            // Convert OperationError to CoreError
            match ap_ops::matmul(arrays[0], arrays[1]) {
                Ok(result) => Ok(result),
                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
                    e.to_string(),
                ))),
            }
        })
    }

    /// Element-wise addition with automatic device selection.
    pub fn add<T, D>(
        a: &mut AutoDevice<T, D>,
        b: &mut AutoDevice<T, D>,
    ) -> CoreResult<Box<dyn ArrayProtocol>>
    where
        T: Clone
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>
            + Default
            + std::ops::Mul<Output = T>
            + std::ops::Add<Output = T>,
        D: Dimension + ndarray::RemoveAxis + 'static,
        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    {
        auto_execute(&mut [a, b], "add", |arrays| {
            // Convert OperationError to CoreError
            match ap_ops::add(arrays[0], arrays[1]) {
                Ok(result) => Ok(result),
                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
                    e.to_string(),
                ))),
            }
        })
    }

    /// Element-wise multiplication with automatic device selection.
    pub fn multiply<T, D>(
        a: &mut AutoDevice<T, D>,
        b: &mut AutoDevice<T, D>,
    ) -> CoreResult<Box<dyn ArrayProtocol>>
    where
        T: Clone
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>
            + Default
            + std::ops::Mul<Output = T>
            + std::ops::Add<Output = T>,
        D: Dimension + ndarray::RemoveAxis + 'static,
        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    {
        auto_execute(&mut [a, b], "multiply", |arrays| {
            // Convert OperationError to CoreError
            match ap_ops::multiply(arrays[0], arrays[1]) {
                Ok(result) => Ok(result),
                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
                    e.to_string(),
                ))),
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr2, Array2};

    #[test]
    fn test_auto_device_selection() {
        // Initialize the array protocol
        crate::array_protocol::init();

        // Create a small array (should be placed on CPU)
        let small_array = Array2::<f64>::ones((10, 10));
        let device = determine_best_device(&small_array);
        assert_eq!(device, DeviceType::CPU);

        // Modify config to place smaller arrays on GPU
        let mut config = get_auto_device_config();
        config.gpu_threshold = 50; // 50 elements
        set_auto_device_config(config);

        // Check device selection with new config
        let device = determine_best_device(&small_array);
        assert_eq!(device, DeviceType::GPU);

        // Reset config
        set_auto_device_config(AutoDeviceConfig::default());
    }

    #[test]
    fn test_auto_device_wrapper() {
        // Initialize the array protocol
        crate::array_protocol::init();

        // Create a small array - using IxDyn to match the trait bounds
        let array_2d = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let array = array_2d.into_dyn();
        let mut auto_array = AutoDevice::new(array.clone());

        // Check initial device (should be CPU for small array)
        assert_eq!(auto_array.device(), DeviceType::CPU);

        // Get array on GPU
        let gpu_array = auto_array.on_device(DeviceType::GPU);
        assert!(gpu_array
            .as_any()
            .downcast_ref::<GPUNdarray<f64, ndarray::IxDyn>>()
            .is_some());

        // Check that device was updated
        assert_eq!(auto_array.device(), DeviceType::GPU);
    }
}