Skip to main content

zer_compute/backend/
mod.rs

//! Backend abstraction, [`DeviceBackend`] enum, [`BackendPreference`], and
//! auto-detection logic.
//!
//! # Selecting a backend
//!
//! ```rust,ignore
//! // Auto: CUDA → Vulkan → AVX2 → CPU scalar
//! let backend = DeviceBackend::auto_detect();
//!
//! // Explicit; errors if the backend is not compiled in or the hardware is missing
//! let backend = DeviceBackend::from_preference(BackendPreference::Cuda)?;
//! ```
//!
//! # Dispatching kernels
//!
//! ```rust,ignore
//! use zer_compute::kernels::hello_backend::{HelloBackend, HelloBackendInput};
//!
//! let out = backend.run::<HelloBackend>(HelloBackendInput)?;
//! ```
//!
//! `run<K>` is generic over any `K: Kernel` for which `DeviceBackend:
//! KernelDispatch<K>` is implemented.  The `impl KernelDispatch<K> for
//! DeviceBackend` blocks at the bottom of this file do the N-arm match that
//! delegates to the active variant.

27pub mod cpu;
28
29#[cfg(feature = "cuda")]
30pub mod cuda;
31
32#[cfg(feature = "avx2")]
33pub mod avx2;
34
35#[cfg(feature = "vulkan")]
36pub mod vulkan;
37
38use crate::{
39    backend::cpu::CpuDevice,
40    error::GpuError,
41    kernel::{Kernel, KernelDispatch},
42    kernels::{
43        em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
44        hello_backend::{HelloBackend, HelloBackendInput, HelloBackendOutput},
45    },
46};
47
48// ── DeviceBackend ─────────────────────────────────────────────────────────────
49
/// Active backend selected at runtime.
///
/// Obtain via [`DeviceBackend::auto_detect`] for automatic selection, or
/// [`DeviceBackend::from_preference`] to request a specific backend.
///
/// # Feature gating
///
/// - `Cuda` variant requires `--features cuda`.
/// - `Vulkan` variant requires `--features vulkan`.
/// - `Avx2` variant requires `--features avx2` on an x86_64 host.
/// - `Cpu` is always present and is the fallback of last resort.
pub enum DeviceBackend {
    /// Scalar CPU path, delegates to `zer-compare` (Rayon parallel).
    Cpu,

    /// NVIDIA CUDA path via `cudarc`, preferred when available.
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaDevice),

    /// Vulkan 1.3 compute path, works on NVIDIA Maxwell+ and other Vulkan-capable GPUs.
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::VulkanDevice),

    /// x86_64 AVX2 SIMD path, no external toolchain required.
    #[cfg(feature = "avx2")]
    Avx2,
}

/// Backward-compatibility alias.  New code should use [`DeviceBackend`].
pub type GpuBackend = DeviceBackend;
79
80// ── BackendPreference ─────────────────────────────────────────────────────────
81
/// Explicit backend preference passed to [`DeviceBackend::from_preference`].
///
/// The default preference is [`BackendPreference::Auto`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum BackendPreference {
    /// Try CUDA → Vulkan → AVX2 → CPU in order (same as `auto_detect`).
    #[default]
    Auto,
    /// Require CUDA; error if not compiled in or no CUDA GPU available.
    Cuda,
    /// Require Vulkan; error if not compiled in or no Vulkan GPU available.
    Vulkan,
    /// Require AVX2; error if not compiled in or CPU lacks AVX2 support.
    Avx2,
    /// Always use the scalar CPU path.
    Cpu,
}
96
97// ── DeviceBackend impl ────────────────────────────────────────────────────────
98
99impl DeviceBackend {
100    /// Auto-detect the best available backend: CUDA → AVX2 → CPU scalar.
101    ///
102    /// Never panics; always returns a usable backend.  Tracing output explains
103    /// which path was selected and why alternatives were skipped.
104    pub fn auto_detect() -> Self {
105        #[cfg(feature = "cuda")]
106        match cuda::CudaDevice::init() {
107            Ok(dev) => {
108                tracing::info!(
109                    device_name = %dev.name(),
110                    vram_bytes  = dev.total_vram_bytes(),
111                    "compute backend: CUDA selected"
112                );
113                return Self::Cuda(dev);
114            }
115            Err(e) => tracing::warn!(%e, "CUDA init failed, trying Vulkan"),
116        }
117
118        #[cfg(feature = "vulkan")]
119        match vulkan::VulkanDevice::init() {
120            Ok(dev) => {
121                tracing::info!(
122                    device_name = %dev.name(),
123                    vram_bytes  = dev.total_vram_bytes(),
124                    "compute backend: Vulkan selected"
125                );
126                return Self::Vulkan(dev);
127            }
128            Err(e) => tracing::warn!(%e, "Vulkan init failed, trying AVX2"),
129        }
130
131        #[cfg(feature = "avx2")]
132        if is_x86_feature_detected!("avx2") {
133            tracing::info!("compute backend: AVX2 selected");
134            return Self::Avx2;
135        }
136
137        tracing::warn!("compute backend: scalar CPU fallback");
138        Self::Cpu
139    }
140
141    /// Force the scalar CPU backend regardless of available hardware.
142    ///
143    /// Useful in tests where deterministic, non-SIMD behaviour is required.
144    pub fn cpu() -> Self {
145        Self::Cpu
146    }
147
148    /// Initialise the CUDA backend explicitly.
149    ///
150    /// Requires `--features cuda`; the method does not exist without it.
151    /// Returns `Err` when no CUDA-capable GPU is present or driver init fails.
152    #[cfg(feature = "cuda")]
153    pub fn cuda() -> Result<Self, GpuError> {
154        Ok(Self::Cuda(cuda::CudaDevice::init()?))
155    }
156
157    /// Initialise the Vulkan compute backend explicitly.
158    ///
159    /// Requires `--features vulkan`; the method does not exist without it.
160    /// Returns `Err` when no Vulkan-capable GPU is present or init fails.
161    #[cfg(feature = "vulkan")]
162    pub fn vulkan() -> Result<Self, GpuError> {
163        Ok(Self::Vulkan(vulkan::VulkanDevice::init()?))
164    }
165
166    /// Initialise the AVX2 SIMD backend explicitly.
167    ///
168    /// Requires `--features avx2`; the method does not exist without it.
169    /// Returns `Err` when the running CPU does not support AVX2.
170    #[cfg(feature = "avx2")]
171    pub fn avx2() -> Result<Self, GpuError> {
172        if is_x86_feature_detected!("avx2") {
173            Ok(Self::Avx2)
174        } else {
175            Err(GpuError::BackendUnavailable(
176                "AVX2 not supported by this CPU".into(),
177            ))
178        }
179    }
180
181    /// Request a specific backend.
182    ///
183    /// Returns `Err(BackendUnavailable)` when:
184    /// - The requested feature flag was not compiled in.
185    /// - The hardware initialisation fails (e.g. no CUDA GPU present).
186    /// - The requested ISA extension is absent at runtime (AVX2).
187    pub fn from_preference(pref: BackendPreference) -> Result<Self, GpuError> {
188        match pref {
189            BackendPreference::Auto => Ok(Self::auto_detect()),
190            BackendPreference::Cpu => Ok(Self::Cpu),
191
192            BackendPreference::Cuda => {
193                #[cfg(feature = "cuda")]
194                return Ok(Self::Cuda(cuda::CudaDevice::init()?));
195                #[allow(unreachable_code)]
196                Err(GpuError::BackendUnavailable(
197                    "CUDA backend not compiled in; rebuild with --features cuda".into(),
198                ))
199            }
200
201            BackendPreference::Vulkan => {
202                #[cfg(feature = "vulkan")]
203                return Ok(Self::Vulkan(vulkan::VulkanDevice::init()?));
204                #[allow(unreachable_code)]
205                Err(GpuError::BackendUnavailable(
206                    "Vulkan backend not compiled in; rebuild with --features vulkan".into(),
207                ))
208            }
209
210            BackendPreference::Avx2 => {
211                #[cfg(feature = "avx2")]
212                {
213                    if is_x86_feature_detected!("avx2") {
214                        return Ok(Self::Avx2);
215                    }
216                    return Err(GpuError::BackendUnavailable(
217                        "AVX2 not supported by this CPU".into(),
218                    ));
219                }
220                #[allow(unreachable_code)]
221                Err(GpuError::BackendUnavailable(
222                    "AVX2 backend not compiled in; rebuild with --features avx2".into(),
223                ))
224            }
225        }
226    }
227
228    /// Dispatch kernel `K` on this backend.
229    pub fn run<K: Kernel>(&self, input: K::Input<'_>) -> Result<K::Output, GpuError>
230    where
231        Self: KernelDispatch<K>,
232    {
233        self.dispatch(input)
234    }
235
236    /// Human-readable name of the active backend.
237    pub fn name(&self) -> &'static str {
238        match self {
239            Self::Cpu => "cpu",
240            #[cfg(feature = "cuda")]
241            Self::Cuda(_) => "cuda",
242            #[cfg(feature = "vulkan")]
243            Self::Vulkan(_) => "vulkan",
244            #[cfg(feature = "avx2")]
245            Self::Avx2 => "avx2",
246        }
247    }
248
249    /// `true` when a GPU backend (CUDA or Vulkan) is active.
250    pub fn is_gpu(&self) -> bool {
251        match self {
252            #[cfg(feature = "cuda")]
253            Self::Cuda(_) => true,
254            #[cfg(feature = "vulkan")]
255            Self::Vulkan(_) => true,
256            _ => false,
257        }
258    }
259
260    /// `true` when this backend provides any hardware acceleration
261    /// (GPU or SIMD), i.e. it is not the scalar CPU fallback.
262    pub fn is_accelerated(&self) -> bool {
263        !matches!(self, Self::Cpu)
264    }
265
266    /// Available VRAM in bytes, or `None` for CPU/AVX2 paths.
267    pub fn available_vram_bytes(&self) -> Option<u64> {
268        match self {
269            Self::Cpu => None,
270            #[cfg(feature = "cuda")]
271            Self::Cuda(dev) => dev.available_vram_bytes().ok(),
272            #[cfg(feature = "vulkan")]
273            Self::Vulkan(dev) => dev.available_vram_bytes(),
274            #[cfg(feature = "avx2")]
275            Self::Avx2 => None,
276        }
277    }
278
279    /// Total (installed) VRAM in bytes, or `None` for CPU/AVX2 paths.
280    pub fn total_vram_bytes(&self) -> Option<u64> {
281        match self {
282            Self::Cpu => None,
283            #[cfg(feature = "cuda")]
284            Self::Cuda(dev) => Some(dev.total_vram_bytes()),
285            #[cfg(feature = "vulkan")]
286            Self::Vulkan(dev) => Some(dev.total_vram_bytes()),
287            #[cfg(feature = "avx2")]
288            Self::Avx2 => None,
289        }
290    }
291}
292
// ── Session-based accelerated EM (CUDA / Vulkan / AVX2) ───────────────────────
294
/// Opaque per-backend state for the session-based EM loop.
///
/// There is one variant per compiled-in accelerated backend (CUDA, Vulkan,
/// AVX2).  Callers never inspect it; they hand it back to the same
/// `DeviceBackend` that created it and release it with
/// [`DeviceBackend::em_drop_session`], which is required for the Vulkan
/// variant because its buffers are not freed on drop.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
pub(crate) enum EmSession {
    /// AVX2 session state (plain `Vec` fields, freed on drop).
    #[cfg(feature = "avx2")]
    Avx2(avx2::launch::em_reduce::Avx2EmSession),
    /// Vulkan session state; needs an explicit `destroy`.
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::launch::em_reduce::VulkanEmSession),
    /// CUDA session state (`CudaSlice` fields, freed on drop).
    #[cfg(feature = "cuda")]
    Cuda(cuda::launch::em_reduce::CudaEmSession),
}
308
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
impl DeviceBackend {
    /// Allocate an EM session: upload `comparison_levels` once, pre-allocate
    /// all backend-specific buffers.  Call this once before the EM loop.
    ///
    /// # Errors
    ///
    /// Returns `Err(BackendUnavailable)` on the scalar CPU backend, and
    /// propagates device allocation/upload errors from CUDA or Vulkan.
    pub(crate) fn em_init_session(
        &self,
        comparison_levels: &[u32],
        n_pairs: usize,
        n_fields: usize,
    ) -> Result<EmSession, GpuError> {
        match self {
            #[cfg(feature = "cuda")]
            Self::Cuda(dev) => dev
                .em_init_session(comparison_levels, n_pairs, n_fields)
                .map(EmSession::Cuda),
            #[cfg(feature = "vulkan")]
            Self::Vulkan(dev) => dev
                .em_init_session(comparison_levels, n_pairs, n_fields)
                .map(EmSession::Vulkan),
            #[cfg(feature = "avx2")]
            Self::Avx2 => Ok(EmSession::Avx2(avx2::device::Avx2Device::em_init_session(
                comparison_levels,
                n_pairs,
                n_fields,
            ))),
            _ => Err(GpuError::BackendUnavailable(
                "em_init_session requires an accelerated backend".into(),
            )),
        }
    }

    /// Run one full EM iteration (E-step + M-step) on the active backend.
    ///
    /// `weights` must be `ln(m[f][l] / u[f][l])`, `n_fields * 4` floats.
    /// Returns raw M-step counts; the caller normalises them into `ModelParams`.
    ///
    /// # Errors
    ///
    /// Returns `Err(BackendUnavailable)` when the backend is the scalar CPU
    /// path, or when `session` was created by a different backend than `self`.
    pub(crate) fn em_run_iteration(
        &self,
        session: &mut EmSession,
        weights: &[f32],
        log_prior_odds: f32,
    ) -> Result<EmReduceOutput, GpuError> {
        match (self, session) {
            #[cfg(feature = "cuda")]
            (Self::Cuda(dev), EmSession::Cuda(s)) => {
                dev.em_run_iteration(s, weights, log_prior_odds)
            }
            #[cfg(feature = "vulkan")]
            (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
                dev.em_run_iteration(s, weights, log_prior_odds)
            }
            #[cfg(feature = "avx2")]
            (Self::Avx2, EmSession::Avx2(s)) => {
                avx2::device::Avx2Device::em_run_iteration(s, weights, log_prior_odds)
            }
            // An accelerated backend paired with another backend's session is a
            // caller bug; report it distinctly from the "no accelerator" case.
            _ if self.is_accelerated() => Err(GpuError::BackendUnavailable(
                "em_run_iteration: session was created by a different backend".into(),
            )),
            _ => Err(GpuError::BackendUnavailable(
                "em_run_iteration requires an accelerated backend".into(),
            )),
        }
    }

    /// Release all backend-side resources held by `session`.
    ///
    /// For CUDA/AVX2: fields auto-drop.
    /// For Vulkan: explicit `VulkanEmSession::destroy` is required because
    /// `VulkanBuffer` has no `Drop` impl.
    ///
    /// A session that does not belong to `self` is dropped without
    /// backend cleanup and a warning is logged, since Vulkan buffers in
    /// such a session may leak.
    pub(crate) fn em_drop_session(&self, session: EmSession) {
        match (self, session) {
            #[cfg(feature = "cuda")]
            (Self::Cuda(_), EmSession::Cuda(_s)) => { /* CudaSlice fields auto-drop */ }
            #[cfg(feature = "vulkan")]
            (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
                let mut alloc = dev
                    .allocator
                    .lock()
                    .expect("Vulkan allocator mutex poisoned by an earlier panic");
                s.destroy(&dev.device, &mut alloc);
            }
            #[cfg(feature = "avx2")]
            (Self::Avx2, EmSession::Avx2(_s)) => { /* Vec fields auto-drop */ }
            (backend, _) => {
                tracing::warn!(
                    backend = backend.name(),
                    "em_drop_session: session does not belong to the active backend; \
                     dropping it without backend cleanup"
                );
            }
        }
    }
}
389
390// ── KernelDispatch impls for DeviceBackend ────────────────────────────────────
391//
392// One impl block per kernel.  Each block delegates to the active variant's
393// per-backend KernelDispatch impl (defined in the respective launch/ modules).
394
395impl KernelDispatch<HelloBackend> for DeviceBackend {
396    fn dispatch(&self, input: HelloBackendInput) -> Result<HelloBackendOutput, GpuError> {
397        match self {
398            #[cfg(feature = "cuda")]
399            Self::Cuda(dev) => {
400                <cuda::CudaDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input)
401            }
402            #[cfg(feature = "vulkan")]
403            Self::Vulkan(dev) => {
404                <vulkan::VulkanDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input)
405            }
406            #[cfg(feature = "avx2")]
407            Self::Avx2 => <avx2::Avx2Device as KernelDispatch<HelloBackend>>::dispatch(
408                &avx2::Avx2Device,
409                input,
410            ),
411            Self::Cpu => <CpuDevice as KernelDispatch<HelloBackend>>::dispatch(&CpuDevice, input),
412        }
413    }
414}
415
416impl KernelDispatch<EmReduce> for DeviceBackend {
417    fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
418        match self {
419            #[cfg(feature = "cuda")]
420            Self::Cuda(dev) => <cuda::CudaDevice as KernelDispatch<EmReduce>>::dispatch(dev, input),
421            #[cfg(feature = "vulkan")]
422            Self::Vulkan(dev) => {
423                <vulkan::VulkanDevice as KernelDispatch<EmReduce>>::dispatch(dev, input)
424            }
425            #[cfg(feature = "avx2")]
426            Self::Avx2 => {
427                <avx2::Avx2Device as KernelDispatch<EmReduce>>::dispatch(&avx2::Avx2Device, input)
428            }
429            Self::Cpu => <CpuDevice as KernelDispatch<EmReduce>>::dispatch(&CpuDevice, input),
430        }
431    }
432}
433
434// ── Unit tests ────────────────────────────────────────────────────────────────
435
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn auto_detect_does_not_panic() {
        let backend = DeviceBackend::auto_detect();
        let name = backend.name();
        assert!(
            matches!(name, "cpu" | "cuda" | "vulkan" | "avx2"),
            "unexpected backend name: {name}"
        );
    }

    #[test]
    fn auto_preference_always_succeeds() {
        assert!(DeviceBackend::from_preference(BackendPreference::Auto).is_ok());
    }

    #[test]
    fn cpu_backend_has_no_vram() {
        let b = DeviceBackend::cpu();
        assert_eq!(b.available_vram_bytes(), None);
        assert_eq!(b.total_vram_bytes(), None);
        assert!(!b.is_gpu());
        assert!(!b.is_accelerated());
    }

    #[test]
    fn cpu_backend_name() {
        assert_eq!(DeviceBackend::cpu().name(), "cpu");
    }

    #[test]
    fn cpu_preference_always_succeeds() {
        assert!(DeviceBackend::from_preference(BackendPreference::Cpu).is_ok());
    }

    // Requesting a backend whose feature is not compiled in must yield
    // `BackendUnavailable`, never a silent fallback.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn cuda_preference_fails_when_not_compiled_in() {
        assert!(matches!(
            DeviceBackend::from_preference(BackendPreference::Cuda),
            Err(GpuError::BackendUnavailable(_))
        ));
    }

    #[cfg(not(feature = "vulkan"))]
    #[test]
    fn vulkan_preference_fails_when_not_compiled_in() {
        assert!(matches!(
            DeviceBackend::from_preference(BackendPreference::Vulkan),
            Err(GpuError::BackendUnavailable(_))
        ));
    }

    #[cfg(not(feature = "avx2"))]
    #[test]
    fn avx2_preference_fails_when_not_compiled_in() {
        assert!(matches!(
            DeviceBackend::from_preference(BackendPreference::Avx2),
            Err(GpuError::BackendUnavailable(_))
        ));
    }

    #[cfg(feature = "avx2")]
    #[test]
    fn avx2_preference_matches_cpu_support() {
        let ok = DeviceBackend::from_preference(BackendPreference::Avx2).is_ok();
        assert_eq!(ok, is_x86_feature_detected!("avx2"));
    }

    #[cfg(feature = "avx2")]
    #[test]
    fn avx2_backend_is_accelerated_not_gpu() {
        let b = DeviceBackend::Avx2;
        assert!(b.is_accelerated());
        assert!(!b.is_gpu());
        assert_eq!(b.name(), "avx2");
        assert_eq!(b.available_vram_bytes(), None);
    }
}