// tensorlogic_scirs_backend/device_manager.rs

//! Operation-level device selection and management.
//!
//! This module provides a pluggable device-selection framework that decides
//! *per operation* whether to execute on CPU or GPU, based on tensor shape,
//! operation kind, and hardware availability.
//!
//! ## Architecture
//!
//! ```text
//! DeviceManager          — owns a Box<dyn DeviceSelector>
//!   └─ DeviceSelector    — trait: select(op, shape) → Device
//!        └─ HeuristicSelector — GPU iff available ∧ large ∧ gpu-friendly op
//! ```
//!
//! ## Quick start
//!
//! ```rust
//! use tensorlogic_scirs_backend::device_manager::{
//!     DeviceConfig, DeviceManager, OpDescriptor, OpKind,
//! };
//!
//! let config = DeviceConfig::default().with_gpu_available(true).with_gpu_threshold(1_048_576);
//! let mgr = DeviceManager::with_heuristic(config);
//!
//! let op = OpDescriptor { kind: OpKind::MatMul };
//! let large_shape = [1024_usize, 1024];
//! let device = mgr.select(&op, &large_shape);
//! // → CUDA device 0: the GPU is available and the shape product
//! //   (1 048 576 elements) is ≥ the threshold
//! ```

use crate::device::{Device, DeviceType};

// ──────────────────────────────────────────────
// OpKind
// ──────────────────────────────────────────────

/// Describes the kind of compute operation, used by the device selector
/// heuristic to decide whether GPU execution is beneficial.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// Dense matrix multiplication / tensor contraction.
    MatMul,

    /// Element-wise operations (add, relu, sigmoid, …).
    Elementwise,

    /// Reduction operations (sum, max, mean, …).
    Reduce,

    /// Any other operation type.
    Other,
}

impl OpKind {
    /// Returns `true` when this kind of operation is treated as well-suited
    /// for GPU execution.
    ///
    /// Currently `MatMul` and `Elementwise` are considered GPU-friendly;
    /// `Reduce` and `Other` are not.
    #[must_use]
    pub const fn is_gpu_friendly(self) -> bool {
        // Exhaustive on purpose: adding a variant must force an explicit
        // decision here instead of silently defaulting to "not GPU-friendly".
        match self {
            OpKind::MatMul | OpKind::Elementwise => true,
            OpKind::Reduce | OpKind::Other => false,
        }
    }
}

64// ──────────────────────────────────────────────
65// OpDescriptor
66// ──────────────────────────────────────────────
67
68/// Descriptor passed to the [`DeviceSelector`] for each scheduled operation.
69///
70/// Callers can extend this with additional fields in future without breaking
71/// implementations that only inspect `kind`.
72#[derive(Debug, Clone)]
73pub struct OpDescriptor {
74    /// The high-level kind of the operation.
75    pub kind: OpKind,
76}
77
78// ──────────────────────────────────────────────
79// DeviceSelector trait
80// ──────────────────────────────────────────────
81
82/// Trait for selecting a compute [`Device`] for a given operation.
83///
84/// Implementors decide, given an [`OpDescriptor`] and the tensor shape,
85/// which device should execute the operation.  The returned device must
86/// be valid for the current system; callers are free to treat an invalid
87/// device as an error.
88///
89/// # Thread safety
90///
91/// Implementations must be `Send + Sync` so that [`DeviceManager`] can be
92/// shared across threads.
93pub trait DeviceSelector: Send + Sync {
94    /// Select the best device for an operation described by `op` acting on a
95    /// tensor with the given `shape`.
96    fn select(&self, op: &OpDescriptor, shape: &[usize]) -> Device;
97}
98
// ──────────────────────────────────────────────
// DeviceConfig
// ──────────────────────────────────────────────

/// Configuration for the built-in [`HeuristicSelector`].
///
/// Use the builder methods to customise thresholds and forced-device overrides.
/// Every builder method consumes the config and returns the updated value, so
/// calls are meant to be chained.
///
/// # Examples
///
/// ```rust
/// use tensorlogic_scirs_backend::device_manager::DeviceConfig;
///
/// // Enable GPU when tensors have ≥ 4 M elements
/// let cfg = DeviceConfig::default()
///     .with_gpu_available(true)
///     .with_gpu_threshold(4_194_304);
/// ```
#[derive(Debug, Clone)]
pub struct DeviceConfig {
    /// Minimum number of tensor elements required to consider GPU execution.
    gpu_threshold_elems: usize,

    /// Whether a GPU device has been declared available on this machine.
    gpu_available: bool,

    /// Index of the GPU to target (used only when a GPU is selected).
    gpu_index: u32,

    /// When `Some`, always return this device regardless of other settings.
    forced: Option<ForcedDevice>,
}

/// Internal forced-device discriminant to avoid storing a full `Device` clone
/// (which is not `Copy`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ForcedDevice {
    /// Always run on the CPU.
    Cpu,
    /// Always run on the GPU with this device index.
    Gpu(u32),
}

impl Default for DeviceConfig {
    /// CPU-only defaults: no GPU declared, 1 M-element threshold, GPU index 0,
    /// and no forced device.
    fn default() -> Self {
        Self {
            gpu_threshold_elems: 1_048_576, // 1 M elements
            gpu_available: false,
            gpu_index: 0,
            forced: None,
        }
    }
}

impl DeviceConfig {
    /// Set the minimum element count at which GPU execution is considered.
    ///
    /// The bound is inclusive: a tensor with exactly `n` elements is eligible.
    /// Tensors with fewer than `n` elements run on CPU regardless of GPU
    /// availability (unless a device is forced).
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_gpu_threshold(mut self, n: usize) -> Self {
        self.gpu_threshold_elems = n;
        self
    }

    /// Declare whether a GPU is available on the current system.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_gpu_available(mut self, avail: bool) -> Self {
        self.gpu_available = avail;
        self
    }

    /// Force all operations to run on CPU, overriding every other setting.
    ///
    /// Replaces any earlier `force_gpu` call.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn force_cpu(mut self) -> Self {
        self.forced = Some(ForcedDevice::Cpu);
        self
    }

    /// Force all operations to run on the GPU with the given device index,
    /// overriding every other setting.
    ///
    /// The forced index is stored separately from the one set by
    /// [`DeviceConfig::with_gpu_index`]. Replaces any earlier `force_cpu` call.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn force_gpu(mut self, idx: u32) -> Self {
        self.forced = Some(ForcedDevice::Gpu(idx));
        self
    }

    /// Set the GPU device index used when GPU execution is selected.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_gpu_index(mut self, idx: u32) -> Self {
        self.gpu_index = idx;
        self
    }
}

187// ──────────────────────────────────────────────
188// HeuristicSelector
189// ──────────────────────────────────────────────
190
191/// A heuristic [`DeviceSelector`] that routes ops to GPU when three conditions
192/// are simultaneously satisfied:
193///
194/// 1. `config.gpu_available` is `true`.
195/// 2. The tensor element count (`shape.iter().product()`) is ≥
196///    `config.gpu_threshold_elems`.
197/// 3. `op.kind.is_gpu_friendly()` returns `true`.
198///
199/// Any `force_cpu` / `force_gpu` override in [`DeviceConfig`] takes
200/// precedence over all three conditions.
201pub struct HeuristicSelector {
202    config: DeviceConfig,
203}
204
205impl HeuristicSelector {
206    /// Create a new `HeuristicSelector` from the given configuration.
207    pub fn new(config: DeviceConfig) -> Self {
208        Self { config }
209    }
210}
211
212impl DeviceSelector for HeuristicSelector {
213    fn select(&self, op: &OpDescriptor, shape: &[usize]) -> Device {
214        // Forced override wins unconditionally.
215        if let Some(forced) = self.config.forced {
216            return match forced {
217                ForcedDevice::Cpu => Device::cpu(),
218                ForcedDevice::Gpu(idx) => Device {
219                    device_type: DeviceType::Cuda,
220                    index: idx as usize,
221                },
222            };
223        }
224
225        let n_elems: usize = shape.iter().product();
226
227        if self.config.gpu_available
228            && n_elems >= self.config.gpu_threshold_elems
229            && op.kind.is_gpu_friendly()
230        {
231            Device {
232                device_type: DeviceType::Cuda,
233                index: self.config.gpu_index as usize,
234            }
235        } else {
236            Device::cpu()
237        }
238    }
239}
240
241// ──────────────────────────────────────────────
242// DeviceManager
243// ──────────────────────────────────────────────
244
245/// Operation-level device manager wrapping a pluggable [`DeviceSelector`].
246///
247/// `DeviceManager` is the public entry point for the device-selection framework.
248/// Callers construct one with a selector of their choice (or use
249/// [`DeviceManager::with_heuristic`] for the built-in heuristic), then call
250/// [`DeviceManager::select`] once per scheduled operation.
251///
252/// # Examples
253///
254/// ```rust
255/// use tensorlogic_scirs_backend::device_manager::{
256///     DeviceConfig, DeviceManager, OpDescriptor, OpKind,
257/// };
258///
259/// let mgr = DeviceManager::with_heuristic(DeviceConfig::default());
260/// let op  = OpDescriptor { kind: OpKind::MatMul };
261/// let dev = mgr.select(&op, &[32, 32]);
262/// assert!(dev.is_cpu()); // GPU not available by default
263/// ```
264pub struct DeviceManager {
265    selector: Box<dyn DeviceSelector>,
266}
267
268impl DeviceManager {
269    /// Create a `DeviceManager` backed by any [`DeviceSelector`] implementation.
270    pub fn new(selector: impl DeviceSelector + 'static) -> Self {
271        Self {
272            selector: Box::new(selector),
273        }
274    }
275
276    /// Create a `DeviceManager` backed by the built-in [`HeuristicSelector`]
277    /// configured with `config`.
278    pub fn with_heuristic(config: DeviceConfig) -> Self {
279        Self::new(HeuristicSelector::new(config))
280    }
281
282    /// Select the compute device for an operation described by `op` acting on
283    /// a tensor with the given `shape`.
284    pub fn select(&self, op: &OpDescriptor, shape: &[usize]) -> Device {
285        self.selector.select(op, shape)
286    }
287}
288
// ──────────────────────────────────────────────
// Tests
// ──────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Config with a GPU declared available at the default threshold (1 M).
    fn gpu_config() -> DeviceConfig {
        DeviceConfig::default().with_gpu_available(true)
    }

    /// Tiny shape: 10 elements in total.
    fn tiny_shape() -> [usize; 2] {
        [2, 5]
    }

    /// Large shape: 2 097 152 elements, above the 1 M default threshold.
    fn large_shape() -> [usize; 2] {
        [1024, 2048]
    }

    /// Shorthand for an [`OpDescriptor`] of the given kind.
    fn op_of(kind: OpKind) -> OpDescriptor {
        OpDescriptor { kind }
    }

    // ── OpKind ──────────────────────────────────────────────────────────────

    #[test]
    fn test_op_kind_gpu_friendly() {
        assert!(OpKind::MatMul.is_gpu_friendly());
        assert!(OpKind::Elementwise.is_gpu_friendly());
        assert!(!OpKind::Reduce.is_gpu_friendly());
        assert!(!OpKind::Other.is_gpu_friendly());
    }

    // ── Heuristic: tiny tensor → CPU even when GPU available ─────────────

    #[test]
    fn test_tiny_tensor_routes_to_cpu() {
        let mgr = DeviceManager::with_heuristic(gpu_config());
        let dev = mgr.select(&op_of(OpKind::MatMul), &tiny_shape());
        assert!(dev.is_cpu(), "tiny tensor should use CPU");
    }

    // ── Heuristic: large + gpu_available + MatMul → Gpu ─────────────────

    #[test]
    fn test_large_matmul_routes_to_gpu_when_available() {
        let mgr = DeviceManager::with_heuristic(gpu_config());
        let dev = mgr.select(&op_of(OpKind::MatMul), &large_shape());
        assert!(
            dev.is_gpu(),
            "large MatMul with GPU available should use GPU"
        );
    }

    // ── Heuristic: large + gpu_available=false → CPU ─────────────────────

    #[test]
    fn test_large_tensor_cpu_when_gpu_unavailable() {
        let cfg = DeviceConfig::default().with_gpu_available(false);
        let mgr = DeviceManager::with_heuristic(cfg);
        let dev = mgr.select(&op_of(OpKind::MatMul), &large_shape());
        assert!(dev.is_cpu(), "no GPU available → must stay on CPU");
    }

    // ── Heuristic: large + gpu_available + Other kind → CPU ──────────────

    #[test]
    fn test_large_non_gpu_friendly_op_routes_to_cpu() {
        let mgr = DeviceManager::with_heuristic(gpu_config());

        for kind in [OpKind::Reduce, OpKind::Other] {
            let dev = mgr.select(&op_of(kind), &large_shape());
            assert!(
                dev.is_cpu(),
                "{kind:?} is not GPU-friendly and should run on CPU"
            );
        }
    }

    // ── force_cpu overrides GPU-eligible combination ──────────────────────

    #[test]
    fn test_force_cpu_overrides_gpu_eligible() {
        let mgr = DeviceManager::with_heuristic(gpu_config().force_cpu());
        let dev = mgr.select(&op_of(OpKind::MatMul), &large_shape());
        assert!(dev.is_cpu(), "force_cpu must override GPU eligibility");
    }

    // ── force_gpu overrides CPU-only config ──────────────────────────────

    #[test]
    fn test_force_gpu_overrides_cpu_config() {
        // GPU is not "available" and tensor is tiny, but force_gpu wins.
        let cfg = DeviceConfig::default()
            .with_gpu_available(false)
            .force_gpu(0);
        let mgr = DeviceManager::with_heuristic(cfg);
        let dev = mgr.select(&op_of(OpKind::Other), &tiny_shape());
        assert!(dev.is_gpu(), "force_gpu must override all other conditions");
    }

    #[test]
    fn test_force_gpu_uses_forced_index() {
        // The forced index is independent of the one set via with_gpu_index.
        let cfg = gpu_config().with_gpu_index(1).force_gpu(3);
        let mgr = DeviceManager::with_heuristic(cfg);
        let dev = mgr.select(&op_of(OpKind::MatMul), &large_shape());
        assert!(dev.is_gpu(), "force_gpu must select a GPU");
        assert_eq!(dev.index, 3, "forced index must win over gpu_index");
    }

    // ── Elementwise large tensor + GPU available ─────────────────────────

    #[test]
    fn test_large_elementwise_routes_to_gpu() {
        let mgr = DeviceManager::with_heuristic(gpu_config());
        let dev = mgr.select(&op_of(OpKind::Elementwise), &large_shape());
        assert!(
            dev.is_gpu(),
            "large Elementwise with GPU available should use GPU"
        );
    }

    // ── Custom selector injection ─────────────────────────────────────────

    #[test]
    fn test_custom_selector_always_cpu() {
        struct AlwaysCpu;
        impl DeviceSelector for AlwaysCpu {
            fn select(&self, _op: &OpDescriptor, _shape: &[usize]) -> Device {
                Device::cpu()
            }
        }

        let mgr = DeviceManager::new(AlwaysCpu);
        let dev = mgr.select(&op_of(OpKind::MatMul), &large_shape());
        assert!(dev.is_cpu(), "custom selector should override heuristic");
    }

    // ── DeviceConfig builder API ──────────────────────────────────────────

    #[test]
    fn test_device_config_builder_threshold() {
        // Tensor with 512 elements, threshold set to 256 → should go to GPU.
        let cfg = DeviceConfig::default()
            .with_gpu_available(true)
            .with_gpu_threshold(256);
        let mgr = DeviceManager::with_heuristic(cfg);
        let shape = [16_usize, 32]; // 512 elements
        let dev = mgr.select(&op_of(OpKind::MatMul), &shape);
        assert!(dev.is_gpu(), "512 elems > 256 threshold should use GPU");
    }

    #[test]
    fn test_threshold_is_inclusive() {
        let mgr = DeviceManager::with_heuristic(gpu_config().with_gpu_threshold(512));
        let op = op_of(OpKind::MatMul);

        // Exactly at the threshold → GPU.
        let at = [16_usize, 32]; // 512 elements
        assert!(mgr.select(&op, &at).is_gpu(), "threshold is inclusive");

        // One element short → CPU.
        let below = [7_usize, 73]; // 511 elements
        assert!(mgr.select(&op, &below).is_cpu(), "below threshold → CPU");
    }

    #[test]
    fn test_gpu_index_is_propagated() {
        let mgr = DeviceManager::with_heuristic(gpu_config().with_gpu_index(2));
        let dev = mgr.select(&op_of(OpKind::MatMul), &large_shape());
        assert!(dev.is_gpu(), "eligible op should use GPU");
        assert_eq!(dev.index, 2, "configured GPU index must be used");
    }

    #[test]
    fn test_scalar_and_empty_tensors_route_to_cpu() {
        let mgr = DeviceManager::with_heuristic(gpu_config());
        let op = op_of(OpKind::MatMul);

        // Rank-0 shape: the empty product is 1 element.
        let scalar: [usize; 0] = [];
        assert!(mgr.select(&op, &scalar).is_cpu(), "scalar → CPU");

        // A zero-sized dimension means zero elements.
        let empty = [0_usize, 4096];
        assert!(mgr.select(&op, &empty).is_cpu(), "empty tensor → CPU");
    }

    #[test]
    fn test_device_config_default_no_gpu() {
        // Default config: gpu_available = false.
        let mgr = DeviceManager::with_heuristic(DeviceConfig::default());
        let dev = mgr.select(&op_of(OpKind::MatMul), &large_shape());
        assert!(dev.is_cpu(), "default config has no GPU available");
    }
}