// scirs2_core/array_protocol/auto_device.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Automatic device placement for array operations.
8//!
9//! This module provides functionality for automatically determining the best
10//! device (CPU, GPU, distributed) for array operations based on array size,
11//! available hardware, and operation complexity.
12
13use std::any::{Any, TypeId};
14use std::collections::HashMap;
15use std::sync::RwLock;
16
17use ::ndarray::{Array, Dim, Dimension, SliceArg, SliceInfo, SliceInfoElem};
18use num_traits;
19
20use crate::array_protocol::{
21    ArrayFunction, ArrayProtocol, DistributedBackend, DistributedConfig, DistributedNdarray,
22    DistributionStrategy, GPUBackend, GPUConfig, GPUNdarray, NdarrayWrapper, NotImplemented,
23};
24use crate::error::CoreResult;
25
/// Configuration for automatic device placement.
///
/// Controls the size thresholds and preferences that
/// `determine_best_device` / `determine_best_device_for_operation` use when
/// choosing between CPU, GPU, and distributed execution.
#[derive(Debug, Clone)]
pub struct AutoDeviceConfig {
    /// Minimum array size (total elements) to consider GPU placement.
    pub gpu_threshold: usize,

    /// Minimum array size (total elements) to consider distributed placement.
    pub distributed_threshold: usize,

    /// Enable mixed-precision operations on devices that support them.
    pub enable_mixed_precision: bool,

    /// Prefer memory efficiency over speed.
    pub prefer_memory_efficiency: bool,

    /// Automatically transfer arrays between devices when needed.
    pub auto_transfer: bool,

    /// Prefer device data locality (avoid transfers).
    pub prefer_data_locality: bool,

    /// Preferred GPU backend (CUDA, ROCm, Metal, ...).
    pub preferred_gpu_backend: GPUBackend,

    /// Fall back to CPU if the preferred GPU backend is not available.
    pub fallback_to_cpu: bool,
}
53
54impl Default for AutoDeviceConfig {
55    fn default() -> Self {
56        Self {
57            gpu_threshold: 1_000_000,           // 1M elements
58            distributed_threshold: 100_000_000, // 100M elements
59            enable_mixed_precision: false,
60            prefer_memory_efficiency: false,
61            auto_transfer: true,
62            prefer_data_locality: true,
63            preferred_gpu_backend: GPUBackend::CUDA,
64            fallback_to_cpu: true,
65        }
66    }
67}
68
/// Global auto device configuration, shared by all placement helpers.
///
/// NOTE(review): the initializer below must stay in sync with
/// `AutoDeviceConfig::default()` — the values are duplicated here because
/// `Default::default()` is not a `const fn` and so cannot be called in a
/// `static` initializer.
pub static AUTO_DEVICE_CONFIG: RwLock<AutoDeviceConfig> = RwLock::new(AutoDeviceConfig {
    gpu_threshold: 1_000_000,
    distributed_threshold: 100_000_000,
    enable_mixed_precision: false,
    prefer_memory_efficiency: false,
    auto_transfer: true,
    prefer_data_locality: true,
    preferred_gpu_backend: GPUBackend::CUDA,
    fallback_to_cpu: true,
});
80
81/// Set the global auto device configuration.
82#[allow(dead_code)]
83pub fn set_auto_device_config(config: AutoDeviceConfig) {
84    if let Ok(mut global_config) = AUTO_DEVICE_CONFIG.write() {
85        *global_config = config;
86    }
87}
88
89/// Get the current auto device configuration.
90#[allow(dead_code)]
91pub fn get_auto_device_config() -> AutoDeviceConfig {
92    AUTO_DEVICE_CONFIG
93        .read()
94        .map(|c| c.clone())
95        .unwrap_or_default()
96}
97
98/// Determine the best device for an array.
99///
100/// This function determines the best device (CPU, GPU, distributed) for an array
101/// based on its size, the operation being performed, and the current configuration.
102#[allow(dead_code)]
103pub fn determine_best_device<T, D>(array: &Array<T, D>) -> DeviceType
104where
105    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
106    D: Dimension + crate::ndarray::RemoveAxis,
107{
108    let config = get_auto_device_config();
109    let size = array.len();
110
111    if size >= config.distributed_threshold {
112        DeviceType::Distributed
113    } else if size >= config.gpu_threshold {
114        DeviceType::GPU
115    } else {
116        DeviceType::CPU
117    }
118}
119
120/// Determine the best device for an operation with multiple arrays.
121///
122/// This function determines the best device for an operation based on
123/// the arrays involved and the operation being performed.
124#[allow(dead_code)]
125pub fn determine_best_device_for_operation<T, D>(
126    arrays: &[&Array<T, D>],
127    operation: &str,
128) -> DeviceType
129where
130    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
131    D: Dimension + crate::ndarray::RemoveAxis,
132{
133    let config = get_auto_device_config();
134
135    // Complex operations (matrix multiplication, SVD, etc.) benefit more from GPU
136    let is_complex_operation = matches!(operation, "matmul" | "svd" | "inverse" | "conv2d");
137
138    // Compute total size of all arrays
139    let total_size: usize = arrays.iter().map(|arr| arr.len()).sum();
140
141    // Adjust thresholds based on operation complexity
142    let gpu_threshold = if is_complex_operation {
143        config.gpu_threshold / 10 // Lower threshold for complex operations
144    } else {
145        config.gpu_threshold
146    };
147
148    let distributed_threshold = if is_complex_operation {
149        config.distributed_threshold / 2 // Lower threshold for complex operations
150    } else {
151        config.distributed_threshold
152    };
153
154    if total_size >= distributed_threshold {
155        DeviceType::Distributed
156    } else if total_size >= gpu_threshold {
157        DeviceType::GPU
158    } else {
159        DeviceType::CPU
160    }
161}
162
/// Available device types for array operations.
///
/// Selected automatically by `determine_best_device` /
/// `determine_best_device_for_operation` and consumed by `convert_to_device`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    /// CPU-based computation (plain `NdarrayWrapper`).
    CPU,

    /// GPU-accelerated computation (`GPUNdarray`).
    GPU,

    /// Distributed computation across multiple machines/processes
    /// (`DistributedNdarray`).
    Distributed,
}
175
176/// Convert an array to the specified device type.
177///
178/// This function converts an array to the specified device type, creating
179/// the appropriate array wrapper for the target device.
180#[allow(dead_code)]
181pub fn convert_to_device<T, D>(array: Array<T, D>, device: DeviceType) -> Box<dyn ArrayProtocol>
182where
183    T: Clone
184        + Send
185        + Sync
186        + 'static
187        + num_traits::Zero
188        + std::ops::Div<f64, Output = T>
189        + Default
190        + std::ops::Mul<Output = T>
191        + std::ops::Add<Output = T>,
192    D: Dimension + crate::ndarray::RemoveAxis + 'static,
193    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
194    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
195{
196    match device {
197        DeviceType::CPU => Box::new(NdarrayWrapper::new(array.clone())),
198        DeviceType::GPU => {
199            let config = get_auto_device_config();
200            let gpu_config = GPUConfig {
201                backend: config.preferred_gpu_backend,
202                device_id: 0,
203                async_ops: true,
204                mixed_precision: config.enable_mixed_precision,
205                memory_fraction: 0.9,
206            };
207
208            Box::new(GPUNdarray::new(array.clone(), gpu_config))
209        }
210        DeviceType::Distributed => {
211            let dist_config = DistributedConfig {
212                chunks: 2, // Using 2 chunks as a default instead of num_cpus / 2
213                balance: true,
214                strategy: DistributionStrategy::RowWise,
215                backend: DistributedBackend::Threaded,
216            };
217
218            Box::new(DistributedNdarray::from_array(&array, dist_config))
219        }
220    }
221}
222
/// A wrapper for arrays that automatically chooses the best device.
///
/// Keeps the canonical data on the host (`array`) and lazily materializes a
/// device-specific wrapper (`device_array`) the first time one is requested
/// via `on_device` or `array_function`.
///
/// NOTE(review): trait bounds on a struct definition are usually discouraged,
/// but they are kept here because every impl block for this type (Debug,
/// Clone, ArrayProtocol) repeats the same bound set — confirm before removing.
pub struct AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// The underlying host-side array (source of truth).
    array: Array<T, D>,

    /// The device the array is currently placed on.
    device: DeviceType,

    /// Lazily-initialized device wrapper; `None` until first requested.
    device_array: Option<Box<dyn ArrayProtocol>>,
}
251
252// Manually implement Debug for AutoDevice since Box<dyn ArrayProtocol> doesn't implement Debug
253impl<T, D> std::fmt::Debug for AutoDevice<T, D>
254where
255    T: Clone
256        + Send
257        + Sync
258        + std::fmt::Debug
259        + 'static
260        + num_traits::Zero
261        + std::ops::Div<f64, Output = T>
262        + Default
263        + std::ops::Mul<Output = T>
264        + std::ops::Add<Output = T>,
265    D: Dimension + crate::ndarray::RemoveAxis + std::fmt::Debug + 'static,
266    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
267    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
268{
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        f.debug_struct("AutoDevice")
271            .field("array", &self.array)
272            .field("device", &self.device)
273            .field("device_array", &self.device_array.is_some())
274            .finish()
275    }
276}
277
278impl<T, D> AutoDevice<T, D>
279where
280    T: Clone
281        + Send
282        + Sync
283        + 'static
284        + num_traits::Zero
285        + std::ops::Div<f64, Output = T>
286        + Default
287        + std::ops::Mul<Output = T>
288        + std::ops::Add<Output = T>,
289    D: Dimension + crate::ndarray::RemoveAxis + 'static,
290    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
291    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
292{
293    /// Create a new auto-device array.
294    pub fn new(array: Array<T, D>) -> Self {
295        let device = determine_best_device(&array);
296        let device_array = None; // Lazily initialized
297
298        Self {
299            array,
300            device,
301            device_array,
302        }
303    }
304
305    /// Get the array on the specified device.
306    pub fn on_device(&mut self, device: DeviceType) -> &dyn ArrayProtocol {
307        if self.device != device || self.device_array.is_none() {
308            // Convert to the requested device
309            self.device = device;
310            self.device_array = Some(convert_to_device(self.array.clone(), device));
311        }
312
313        self.device_array
314            .as_ref()
315            .expect("Operation failed")
316            .as_ref()
317    }
318
319    /// Get the current device.
320    pub fn device(&self) -> DeviceType {
321        self.device
322    }
323
324    /// Get the underlying array.
325    pub const fn array(&self) -> &Array<T, D> {
326        &self.array
327    }
328}
329
330impl<T, D> Clone for AutoDevice<T, D>
331where
332    T: Clone
333        + Send
334        + Sync
335        + 'static
336        + num_traits::Zero
337        + std::ops::Div<f64, Output = T>
338        + Default
339        + std::ops::Mul<Output = T>
340        + std::ops::Add<Output = T>,
341    D: Dimension + crate::ndarray::RemoveAxis + 'static,
342    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
343    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
344{
345    fn clone(&self) -> Self {
346        Self {
347            array: self.array.clone(),
348            device: self.device,
349            device_array: self.device_array.clone(),
350        }
351    }
352}
353
354impl<T, D> ArrayProtocol for AutoDevice<T, D>
355where
356    T: Clone
357        + Send
358        + Sync
359        + 'static
360        + num_traits::Zero
361        + std::ops::Div<f64, Output = T>
362        + Default
363        + std::ops::Mul<Output = T>
364        + std::ops::Add<Output = T>,
365    D: Dimension + crate::ndarray::RemoveAxis + 'static,
366    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
367    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
368{
369    fn array_function(
370        &self,
371        func: &ArrayFunction,
372        types: &[TypeId],
373        args: &[Box<dyn Any>],
374        kwargs: &HashMap<String, Box<dyn Any>>,
375    ) -> Result<Box<dyn Any>, NotImplemented> {
376        // If we already have a device array, delegate to it
377        if let Some(device_array) = &self.device_array {
378            device_array.array_function(func, types, args, kwargs)
379        } else {
380            // Otherwise, create a temporary array on the appropriate device
381            let device = determine_best_device(&self.array);
382            let temp_array = convert_to_device(self.array.clone(), device);
383            temp_array.array_function(func, types, args, kwargs)
384        }
385    }
386
387    fn as_any(&self) -> &dyn Any {
388        self
389    }
390
391    fn shape(&self) -> &[usize] {
392        self.array.shape()
393    }
394
395    fn dtype(&self) -> TypeId {
396        TypeId::of::<T>()
397    }
398
399    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
400        Box::new(self.clone())
401    }
402}
403
/// Execute an operation with automatic device selection.
///
/// Determines the best device for `operation` from the combined size of all
/// input arrays, places every array on that device, and invokes `executor`
/// with the device-side views.
///
/// # Arguments
/// * `arrays` - Mutable wrappers; each may be converted to the chosen device
///   as a side effect (their `device()` is updated).
/// * `operation` - Operation name used by the placement heuristic
///   (e.g. "matmul" lowers the GPU threshold).
/// * `executor` - Closure receiving the device-resident arrays, in the same
///   order as `arrays`.
///
/// # Errors
/// Propagates whatever `executor` returns.
#[allow(dead_code)]
pub fn auto_execute<T, D, F, R>(
    arrays: &mut [&mut AutoDevice<T, D>],
    operation: &str,
    executor: F,
) -> CoreResult<R>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
    R: 'static,
{
    // Determine the best device for this operation from all inputs combined.
    let best_device = determine_best_device_for_operation(
        &arrays.iter().map(|a| &a.array).collect::<Vec<_>>(),
        operation,
    );

    // Place each array on the selected device (mutates the wrappers' caches).
    let device_arrays: Vec<&dyn ArrayProtocol> = arrays
        .iter_mut()
        .map(|a| a.on_device(best_device))
        .collect();

    // Execute the operation on the device-resident arrays.
    executor(&device_arrays)
}
445
446/// Implementation of common array operations with automatic device selection.
447pub mod ops {
448    use super::*;
449    use crate::array_protocol::operations as ap_ops;
450    use crate::error::{CoreError, ErrorContext};
451
452    /// Matrix multiplication with automatic device selection.
453    pub fn matmul<T, D>(
454        a: &mut AutoDevice<T, D>,
455        b: &mut AutoDevice<T, D>,
456    ) -> CoreResult<Box<dyn ArrayProtocol>>
457    where
458        T: Clone
459            + Send
460            + Sync
461            + 'static
462            + num_traits::Zero
463            + std::ops::Div<f64, Output = T>
464            + Default
465            + std::ops::Mul<Output = T>
466            + std::ops::Add<Output = T>,
467        D: Dimension + crate::ndarray::RemoveAxis + 'static,
468        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
469        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
470    {
471        auto_execute(&mut [a, b], "matmul", |arrays| {
472            // Convert OperationError to CoreError
473            match ap_ops::matmul(arrays[0], arrays[1]) {
474                Ok(result) => Ok(result),
475                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
476                    e.to_string(),
477                ))),
478            }
479        })
480    }
481
482    /// Element-wise addition with automatic device selection.
483    pub fn add<T, D>(
484        a: &mut AutoDevice<T, D>,
485        b: &mut AutoDevice<T, D>,
486    ) -> CoreResult<Box<dyn ArrayProtocol>>
487    where
488        T: Clone
489            + Send
490            + Sync
491            + 'static
492            + num_traits::Zero
493            + std::ops::Div<f64, Output = T>
494            + Default
495            + std::ops::Mul<Output = T>
496            + std::ops::Add<Output = T>,
497        D: Dimension + crate::ndarray::RemoveAxis + 'static,
498        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
499        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
500    {
501        auto_execute(&mut [a, b], "add", |arrays| {
502            // Convert OperationError to CoreError
503            match ap_ops::add(arrays[0], arrays[1]) {
504                Ok(result) => Ok(result),
505                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
506                    e.to_string(),
507                ))),
508            }
509        })
510    }
511
512    /// Element-wise multiplication with automatic device selection.
513    pub fn multiply<T, D>(
514        a: &mut AutoDevice<T, D>,
515        b: &mut AutoDevice<T, D>,
516    ) -> CoreResult<Box<dyn ArrayProtocol>>
517    where
518        T: Clone
519            + Send
520            + Sync
521            + 'static
522            + num_traits::Zero
523            + std::ops::Div<f64, Output = T>
524            + Default
525            + std::ops::Mul<Output = T>
526            + std::ops::Add<Output = T>,
527        D: Dimension + crate::ndarray::RemoveAxis + 'static,
528        SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
529        SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
530    {
531        auto_execute(&mut [a, b], "multiply", |arrays| {
532            // Convert OperationError to CoreError
533            match ap_ops::multiply(arrays[0], arrays[1]) {
534                Ok(result) => Ok(result),
535                Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
536                    e.to_string(),
537                ))),
538            }
539        })
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, Array2};

    /// Placement thresholds from the global config drive device selection.
    ///
    /// NOTE(review): this test mutates the process-global config; if tests
    /// run in parallel it can race with other config-dependent tests —
    /// confirm the test harness serializes these.
    #[test]
    fn test_auto_device_selection() {
        // Initialize the array protocol
        crate::array_protocol::init();

        // Create a small array (should be placed on CPU)
        let small_array = Array2::<f64>::ones((10, 10));
        let device = determine_best_device(&small_array);
        assert_eq!(device, DeviceType::CPU);

        // Modify config to place smaller arrays on GPU
        let mut config = get_auto_device_config();
        config.gpu_threshold = 50; // 50 elements < 100 in the array above
        set_auto_device_config(config);

        // Check device selection with new config
        let device = determine_best_device(&small_array);
        assert_eq!(device, DeviceType::GPU);

        // Reset config so other tests see the defaults
        set_auto_device_config(AutoDeviceConfig::default());
    }

    /// `AutoDevice` starts on CPU for small arrays and converts on demand.
    #[test]
    fn test_auto_device_wrapper() {
        // Initialize the array protocol
        crate::array_protocol::init();

        // Create a small array - using IxDyn to match the trait bounds
        let array_2d = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let array = array_2d.into_dyn();
        let mut auto_array = AutoDevice::new(array.clone());

        // Check initial device (should be CPU for small array)
        assert_eq!(auto_array.device(), DeviceType::CPU);

        // Explicitly request the GPU; the returned wrapper should be a GPUNdarray
        let gpu_array = auto_array.on_device(DeviceType::GPU);
        assert!(gpu_array
            .as_any()
            .downcast_ref::<GPUNdarray<f64, crate::ndarray::IxDyn>>()
            .is_some());

        // Check that the wrapper's tracked device was updated
        assert_eq!(auto_array.device(), DeviceType::GPU);
    }
}