Skip to main content

zer_compute/backend/
mod.rs

1//! Backend abstraction, [`DeviceBackend`] enum, [`BackendPreference`], and
2//! auto-detection logic.
3//!
4//! # Selecting a backend
5//!
6//! ```rust,ignore
//! // Auto: CUDA → Vulkan → AVX2 → CPU scalar
8//! let backend = DeviceBackend::auto_detect();
9//!
10//! // Explicit with error if not compiled in or hardware missing
11//! let backend = DeviceBackend::from_preference(BackendPreference::Cuda)?;
12//! ```
13//!
14//! # Dispatching kernels
15//!
16//! ```rust,ignore
17//! use zer_compute::kernels::hello_backend::{HelloBackend, HelloBackendInput};
18//!
19//! let out = backend.run::<HelloBackend>(HelloBackendInput)?;
20//! ```
21//!
22//! `run<K>` is generic over any `K: Kernel` for which `DeviceBackend:
23//! KernelDispatch<K>` is implemented.  The `impl KernelDispatch<K> for
24//! DeviceBackend` blocks at the bottom of this file do the N-arm match that
25//! delegates to the active variant.
26
27pub mod cpu;
28
29#[cfg(feature = "cuda")]
30pub mod cuda;
31
32#[cfg(feature = "avx2")]
33pub mod avx2;
34
35#[cfg(feature = "vulkan")]
36pub mod vulkan;
37
38use crate::{
39    backend::cpu::CpuDevice,
40    error::GpuError,
41    kernel::{Kernel, KernelDispatch},
42    kernels::{
43        em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
44        hello_backend::{HelloBackend, HelloBackendInput, HelloBackendOutput},
45    },
46};
47
48// ── DeviceBackend ─────────────────────────────────────────────────────────────
49
/// Active backend selected at runtime.
///
/// Obtain via [`DeviceBackend::auto_detect`] for automatic selection, or
/// [`DeviceBackend::from_preference`] to request a specific backend.
///
/// # Feature gating
///
/// - `Cuda` variant requires `--features cuda`.
/// - `Vulkan` variant requires `--features vulkan`.
/// - `Avx2` variant requires `--features avx2` on an x86_64 host.
/// - `Cpu` is always present and is the fallback of last resort.
///
/// Because variants are compiled in or out per feature, every `match` on this
/// enum must gate its arms with the same `#[cfg(feature = "...")]` attributes.
pub enum DeviceBackend {
    /// Scalar CPU path, delegates to `zer-compare` (Rayon parallel).
    Cpu,

    /// NVIDIA CUDA path via `cudarc`, preferred when available.
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaDevice),

    /// Vulkan 1.3 compute path, works on NVIDIA Maxwell+ and other Vulkan-capable GPUs.
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::VulkanDevice),

    /// x86_64 AVX2 SIMD path, no external toolchain required.
    ///
    /// Carries no payload: the variant itself is the whole backend state.
    #[cfg(feature = "avx2")]
    Avx2,
}
76
/// Backward-compatibility alias for [`DeviceBackend`].
///
/// The alias names exactly the same type, so values are interchangeable.
/// New code should use [`DeviceBackend`] directly.
pub type GpuBackend = DeviceBackend;
79
80// ── BackendPreference ─────────────────────────────────────────────────────────
81
/// Explicit backend preference passed to [`DeviceBackend::from_preference`].
///
/// The enum is a plain, fieldless selector, so it is cheap to copy, compare
/// and log.  It is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendPreference {
    /// Try CUDA → Vulkan → AVX2 → CPU in order (same as `auto_detect`).
    Auto,
    /// Require CUDA; error if not compiled in or no CUDA GPU available.
    Cuda,
    /// Require Vulkan; error if not compiled in or no Vulkan GPU available.
    Vulkan,
    /// Require AVX2; error if not compiled in or CPU lacks AVX2 support.
    Avx2,
    /// Always use the scalar CPU path.
    Cpu,
}
96
97// ── DeviceBackend impl ────────────────────────────────────────────────────────
98
99impl DeviceBackend {
100    /// Auto-detect the best available backend: CUDA → AVX2 → CPU scalar.
101    ///
102    /// Never panics; always returns a usable backend.  Tracing output explains
103    /// which path was selected and why alternatives were skipped.
104    pub fn auto_detect() -> Self {
105        #[cfg(feature = "cuda")]
106        match cuda::CudaDevice::init() {
107            Ok(dev) => {
108                tracing::info!(
109                    device_name = %dev.name(),
110                    vram_bytes  = dev.total_vram_bytes(),
111                    "compute backend: CUDA selected"
112                );
113                return Self::Cuda(dev);
114            }
115            Err(e) => tracing::warn!(%e, "CUDA init failed, trying Vulkan"),
116        }
117
118        #[cfg(feature = "vulkan")]
119        match vulkan::VulkanDevice::init() {
120            Ok(dev) => {
121                tracing::info!(
122                    device_name = %dev.name(),
123                    vram_bytes  = dev.total_vram_bytes(),
124                    "compute backend: Vulkan selected"
125                );
126                return Self::Vulkan(dev);
127            }
128            Err(e) => tracing::warn!(%e, "Vulkan init failed, trying AVX2"),
129        }
130
131        #[cfg(feature = "avx2")]
132        if is_x86_feature_detected!("avx2") {
133            tracing::info!("compute backend: AVX2 selected");
134            return Self::Avx2;
135        }
136
137        tracing::warn!("compute backend: scalar CPU fallback");
138        Self::Cpu
139    }
140
141    /// Force the scalar CPU backend regardless of available hardware.
142    ///
143    /// Useful in tests where deterministic, non-SIMD behaviour is required.
144    pub fn cpu() -> Self {
145        Self::Cpu
146    }
147
148    /// Initialise the CUDA backend explicitly.
149    ///
150    /// Requires `--features cuda`; the method does not exist without it.
151    /// Returns `Err` when no CUDA-capable GPU is present or driver init fails.
152    #[cfg(feature = "cuda")]
153    pub fn cuda() -> Result<Self, GpuError> {
154        Ok(Self::Cuda(cuda::CudaDevice::init()?))
155    }
156
157    /// Initialise the Vulkan compute backend explicitly.
158    ///
159    /// Requires `--features vulkan`; the method does not exist without it.
160    /// Returns `Err` when no Vulkan-capable GPU is present or init fails.
161    #[cfg(feature = "vulkan")]
162    pub fn vulkan() -> Result<Self, GpuError> {
163        Ok(Self::Vulkan(vulkan::VulkanDevice::init()?))
164    }
165
166    /// Initialise the AVX2 SIMD backend explicitly.
167    ///
168    /// Requires `--features avx2`; the method does not exist without it.
169    /// Returns `Err` when the running CPU does not support AVX2.
170    #[cfg(feature = "avx2")]
171    pub fn avx2() -> Result<Self, GpuError> {
172        if is_x86_feature_detected!("avx2") {
173            Ok(Self::Avx2)
174        } else {
175            Err(GpuError::BackendUnavailable(
176                "AVX2 not supported by this CPU".into(),
177            ))
178        }
179    }
180
181    /// Request a specific backend.
182    ///
183    /// Returns `Err(BackendUnavailable)` when:
184    /// - The requested feature flag was not compiled in.
185    /// - The hardware initialisation fails (e.g. no CUDA GPU present).
186    /// - The requested ISA extension is absent at runtime (AVX2).
187    pub fn from_preference(pref: BackendPreference) -> Result<Self, GpuError> {
188        match pref {
189            BackendPreference::Auto => Ok(Self::auto_detect()),
190            BackendPreference::Cpu  => Ok(Self::Cpu),
191
192            BackendPreference::Cuda => {
193                #[cfg(feature = "cuda")]
194                return Ok(Self::Cuda(cuda::CudaDevice::init()?));
195                #[allow(unreachable_code)]
196                Err(GpuError::BackendUnavailable(
197                    "CUDA backend not compiled in; rebuild with --features cuda".into(),
198                ))
199            }
200
201            BackendPreference::Vulkan => {
202                #[cfg(feature = "vulkan")]
203                return Ok(Self::Vulkan(vulkan::VulkanDevice::init()?));
204                #[allow(unreachable_code)]
205                Err(GpuError::BackendUnavailable(
206                    "Vulkan backend not compiled in; rebuild with --features vulkan".into(),
207                ))
208            }
209
210            BackendPreference::Avx2 => {
211                #[cfg(feature = "avx2")]
212                {
213                    if is_x86_feature_detected!("avx2") {
214                        return Ok(Self::Avx2);
215                    }
216                    return Err(GpuError::BackendUnavailable(
217                        "AVX2 not supported by this CPU".into(),
218                    ));
219                }
220                #[allow(unreachable_code)]
221                Err(GpuError::BackendUnavailable(
222                    "AVX2 backend not compiled in; rebuild with --features avx2".into(),
223                ))
224            }
225        }
226    }
227
228    /// Dispatch kernel `K` on this backend.
229    pub fn run<K: Kernel>(&self, input: K::Input<'_>) -> Result<K::Output, GpuError>
230    where
231        Self: KernelDispatch<K>,
232    {
233        self.dispatch(input)
234    }
235
236    /// Human-readable name of the active backend.
237    pub fn name(&self) -> &'static str {
238        match self {
239            Self::Cpu => "cpu",
240            #[cfg(feature = "cuda")]
241            Self::Cuda(_) => "cuda",
242            #[cfg(feature = "vulkan")]
243            Self::Vulkan(_) => "vulkan",
244            #[cfg(feature = "avx2")]
245            Self::Avx2 => "avx2",
246        }
247    }
248
249    /// `true` when a GPU backend (CUDA or Vulkan) is active.
250    pub fn is_gpu(&self) -> bool {
251        match self {
252            #[cfg(feature = "cuda")]
253            Self::Cuda(_) => true,
254            #[cfg(feature = "vulkan")]
255            Self::Vulkan(_) => true,
256            _ => false,
257        }
258    }
259
260    /// `true` when this backend provides any hardware acceleration
261    /// (GPU or SIMD), i.e. it is not the scalar CPU fallback.
262    pub fn is_accelerated(&self) -> bool {
263        !matches!(self, Self::Cpu)
264    }
265
266    /// Available VRAM in bytes, or `None` for CPU/AVX2 paths.
267    pub fn available_vram_bytes(&self) -> Option<u64> {
268        match self {
269            Self::Cpu => None,
270            #[cfg(feature = "cuda")]
271            Self::Cuda(dev) => dev.available_vram_bytes().ok(),
272            #[cfg(feature = "vulkan")]
273            Self::Vulkan(dev) => dev.available_vram_bytes(),
274            #[cfg(feature = "avx2")]
275            Self::Avx2 => None,
276        }
277    }
278
279    /// Total (installed) VRAM in bytes, or `None` for CPU/AVX2 paths.
280    pub fn total_vram_bytes(&self) -> Option<u64> {
281        match self {
282            Self::Cpu => None,
283            #[cfg(feature = "cuda")]
284            Self::Cuda(dev) => Some(dev.total_vram_bytes()),
285            #[cfg(feature = "vulkan")]
286            Self::Vulkan(dev) => Some(dev.total_vram_bytes()),
287            #[cfg(feature = "avx2")]
288            Self::Avx2 => None,
289        }
290    }
291}
292
293// ── Session-based full-GPU EM ─────────────────────────────────────────────────
294
/// Unified EM session handle, wraps CUDA, Vulkan, or AVX2 session state.
///
/// Callers treat this as an opaque token.  The owning `DeviceBackend` is
/// responsible for cleanup via [`DeviceBackend::em_drop_session`].
///
/// The type only exists when at least one accelerated backend feature is
/// enabled, and each variant is compiled in only with its own feature.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
pub(crate) enum EmSession {
    /// Session state owned by the CUDA backend.
    #[cfg(feature = "cuda")]
    Cuda(cuda::launch::em_reduce::CudaEmSession),
    /// Session state owned by the Vulkan backend.
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::launch::em_reduce::VulkanEmSession),
    /// Session state owned by the AVX2 backend.
    #[cfg(feature = "avx2")]
    Avx2(avx2::launch::em_reduce::Avx2EmSession),
}
308
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
impl DeviceBackend {
    /// Allocate an EM session: upload `comparison_levels` once, pre-allocate
    /// all backend-specific buffers.  Call this once before the EM loop.
    ///
    /// # Errors
    ///
    /// Returns `BackendUnavailable` on the scalar CPU backend, which has no
    /// session support, and propagates any error from the CUDA or Vulkan
    /// session setup.
    pub(crate) fn em_init_session(
        &self,
        comparison_levels: &[u32],
        n_pairs: usize,
        n_fields: usize,
    ) -> Result<EmSession, GpuError> {
        match self {
            #[cfg(feature = "cuda")]
            Self::Cuda(dev) => dev
                .em_init_session(comparison_levels, n_pairs, n_fields)
                .map(EmSession::Cuda),
            #[cfg(feature = "vulkan")]
            Self::Vulkan(dev) => dev
                .em_init_session(comparison_levels, n_pairs, n_fields)
                .map(EmSession::Vulkan),
            #[cfg(feature = "avx2")]
            Self::Avx2 => Ok(EmSession::Avx2(
                avx2::device::Avx2Device::em_init_session(comparison_levels, n_pairs, n_fields),
            )),
            // Named explicitly (no `_`) so that a future variant must decide
            // here whether it supports EM sessions.
            Self::Cpu => Err(GpuError::BackendUnavailable(
                "em_init_session requires an accelerated backend".into(),
            )),
        }
    }

    /// Run one full EM iteration (E-step + M-step) on the active backend.
    ///
    /// `weights` must be `ln(m[f][l] / u[f][l])`, `n_fields * 4` floats.
    /// Returns raw M-step counts; the caller normalises them into `ModelParams`.
    ///
    /// # Errors
    ///
    /// Returns `BackendUnavailable` when `self` is the scalar CPU backend or
    /// when `session` was created by a different backend than `self`;
    /// otherwise propagates the backend's own error.
    pub(crate) fn em_run_iteration(
        &self,
        session: &mut EmSession,
        weights: &[f32],
        log_prior_odds: f32,
    ) -> Result<EmReduceOutput, GpuError> {
        match (self, session) {
            #[cfg(feature = "cuda")]
            (Self::Cuda(dev), EmSession::Cuda(s)) => {
                dev.em_run_iteration(s, weights, log_prior_odds)
            }
            #[cfg(feature = "vulkan")]
            (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
                dev.em_run_iteration(s, weights, log_prior_odds)
            }
            #[cfg(feature = "avx2")]
            (Self::Avx2, EmSession::Avx2(s)) => {
                avx2::device::Avx2Device::em_run_iteration(s, weights, log_prior_odds)
            }
            // Covers both the CPU backend and any backend/session mismatch.
            _ => Err(GpuError::BackendUnavailable(
                "em_run_iteration requires an accelerated backend".into(),
            )),
        }
    }

    /// Release all backend-side resources held by `session`.
    ///
    /// For CUDA/AVX2: fields auto-drop.
    /// For Vulkan: explicit `VulkanEmSession::destroy` is required because
    /// `VulkanBuffer` has no `Drop` impl.
    ///
    /// This is a cleanup path, so it does not panic on a poisoned allocator
    /// lock: the guard is recovered and the buffers are still destroyed.
    /// A session that does not match `self` is dropped without explicit
    /// destruction.
    pub(crate) fn em_drop_session(&self, session: EmSession) {
        match (self, session) {
            #[cfg(feature = "cuda")]
            (Self::Cuda(_), EmSession::Cuda(_s)) => { /* CudaSlice fields auto-drop */ }
            #[cfg(feature = "vulkan")]
            (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
                // A panic elsewhere while the lock was held poisons it.
                // Unwrapping here would panic again and leak the session's
                // buffers, so take the guard out of the poison error instead.
                let mut alloc = dev
                    .allocator
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                s.destroy(&dev.device, &mut alloc);
            }
            #[cfg(feature = "avx2")]
            (Self::Avx2, EmSession::Avx2(_s)) => { /* Vec fields auto-drop */ }
            _ => {}
        }
    }
}
382
383// ── KernelDispatch impls for DeviceBackend ────────────────────────────────────
384//
385// One impl block per kernel.  Each block delegates to the active variant's
386// per-backend KernelDispatch impl (defined in the respective launch/ modules).
387
388impl KernelDispatch<HelloBackend> for DeviceBackend {
389    fn dispatch(&self, input: HelloBackendInput) -> Result<HelloBackendOutput, GpuError> {
390        match self {
391            #[cfg(feature = "cuda")]
392            Self::Cuda(dev)   => <cuda::CudaDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input),
393            #[cfg(feature = "vulkan")]
394            Self::Vulkan(dev) => <vulkan::VulkanDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input),
395            #[cfg(feature = "avx2")]
396            Self::Avx2        => <avx2::Avx2Device as KernelDispatch<HelloBackend>>::dispatch(&avx2::Avx2Device, input),
397            Self::Cpu         => <CpuDevice as KernelDispatch<HelloBackend>>::dispatch(&CpuDevice, input),
398        }
399    }
400}
401
402impl KernelDispatch<EmReduce> for DeviceBackend {
403    fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
404        match self {
405            #[cfg(feature = "cuda")]
406            Self::Cuda(dev)   => <cuda::CudaDevice as KernelDispatch<EmReduce>>::dispatch(dev, input),
407            #[cfg(feature = "vulkan")]
408            Self::Vulkan(dev) => <vulkan::VulkanDevice as KernelDispatch<EmReduce>>::dispatch(dev, input),
409            #[cfg(feature = "avx2")]
410            Self::Avx2        => <avx2::Avx2Device as KernelDispatch<EmReduce>>::dispatch(&avx2::Avx2Device, input),
411            Self::Cpu         => <CpuDevice as KernelDispatch<EmReduce>>::dispatch(&CpuDevice, input),
412        }
413    }
414}
415
416// ── Unit tests ────────────────────────────────────────────────────────────────
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Whatever hardware is present, auto-detection must yield a known backend.
    #[test]
    fn auto_detect_does_not_panic() {
        let name = DeviceBackend::auto_detect().name();
        let recognised = matches!(name, "cpu" | "cuda" | "vulkan" | "avx2");
        assert!(recognised, "unexpected backend name: {name}");
    }

    /// The scalar CPU backend reports no VRAM and no acceleration of any kind.
    #[test]
    fn cpu_backend_has_no_vram() {
        let scalar = DeviceBackend::cpu();
        assert!(!scalar.is_accelerated());
        assert!(!scalar.is_gpu());
        assert_eq!(scalar.total_vram_bytes(), None);
        assert_eq!(scalar.available_vram_bytes(), None);
    }

    /// The CPU backend identifies itself as `"cpu"`.
    #[test]
    fn cpu_backend_name() {
        let scalar = DeviceBackend::cpu();
        assert_eq!(scalar.name(), "cpu");
    }

    /// Requesting the CPU backend can never fail.
    #[test]
    fn cpu_preference_always_succeeds() {
        let outcome = DeviceBackend::from_preference(BackendPreference::Cpu);
        assert!(outcome.is_ok());
    }

    /// AVX2 counts as accelerated, but it is not a GPU and has no VRAM.
    #[cfg(feature = "avx2")]
    #[test]
    fn avx2_backend_is_accelerated_not_gpu() {
        let simd = DeviceBackend::Avx2;
        assert_eq!(simd.name(), "avx2");
        assert!(simd.is_accelerated());
        assert!(!simd.is_gpu());
        assert_eq!(simd.available_vram_bytes(), None);
    }
}