1pub mod cpu;
28
29#[cfg(feature = "cuda")]
30pub mod cuda;
31
32#[cfg(feature = "avx2")]
33pub mod avx2;
34
35#[cfg(feature = "vulkan")]
36pub mod vulkan;
37
38use crate::{
39 backend::cpu::CpuDevice,
40 error::GpuError,
41 kernel::{Kernel, KernelDispatch},
42 kernels::{
43 em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
44 hello_backend::{HelloBackend, HelloBackendInput, HelloBackendOutput},
45 },
46};
47
/// Compute backend selected at runtime.
///
/// `Cpu` is always part of the enum. Every other variant only exists when the
/// Cargo feature of the same name is enabled, so code matching on this type
/// must gate the corresponding arms with the same `cfg` attributes.
pub enum DeviceBackend {
    /// Scalar CPU backend; compiled in unconditionally.
    Cpu,

    /// CUDA backend holding its device handle (`cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaDevice),

    /// Vulkan backend holding its device handle (`vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::VulkanDevice),

    /// AVX2 backend (`avx2` feature); carries no per-device state.
    #[cfg(feature = "avx2")]
    Avx2,
}
76
77pub type GpuBackend = DeviceBackend;
79
/// Backend requested by the caller, resolved by
/// `DeviceBackend::from_preference`.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
///
/// The common value traits are derived so a preference can be logged,
/// copied into configuration structs and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum BackendPreference {
    /// Let `DeviceBackend::auto_detect` choose the backend.
    Auto,
    /// Require the CUDA backend.
    Cuda,
    /// Require the Vulkan backend.
    Vulkan,
    /// Require the AVX2 backend.
    Avx2,
    /// Require the scalar CPU backend.
    Cpu,
}
96
97impl DeviceBackend {
100 pub fn auto_detect() -> Self {
105 #[cfg(feature = "cuda")]
106 match cuda::CudaDevice::init() {
107 Ok(dev) => {
108 tracing::info!(
109 device_name = %dev.name(),
110 vram_bytes = dev.total_vram_bytes(),
111 "compute backend: CUDA selected"
112 );
113 return Self::Cuda(dev);
114 }
115 Err(e) => tracing::warn!(%e, "CUDA init failed, trying Vulkan"),
116 }
117
118 #[cfg(feature = "vulkan")]
119 match vulkan::VulkanDevice::init() {
120 Ok(dev) => {
121 tracing::info!(
122 device_name = %dev.name(),
123 vram_bytes = dev.total_vram_bytes(),
124 "compute backend: Vulkan selected"
125 );
126 return Self::Vulkan(dev);
127 }
128 Err(e) => tracing::warn!(%e, "Vulkan init failed, trying AVX2"),
129 }
130
131 #[cfg(feature = "avx2")]
132 if is_x86_feature_detected!("avx2") {
133 tracing::info!("compute backend: AVX2 selected");
134 return Self::Avx2;
135 }
136
137 tracing::warn!("compute backend: scalar CPU fallback");
138 Self::Cpu
139 }
140
141 pub fn cpu() -> Self {
145 Self::Cpu
146 }
147
148 #[cfg(feature = "cuda")]
153 pub fn cuda() -> Result<Self, GpuError> {
154 Ok(Self::Cuda(cuda::CudaDevice::init()?))
155 }
156
157 #[cfg(feature = "vulkan")]
162 pub fn vulkan() -> Result<Self, GpuError> {
163 Ok(Self::Vulkan(vulkan::VulkanDevice::init()?))
164 }
165
166 #[cfg(feature = "avx2")]
171 pub fn avx2() -> Result<Self, GpuError> {
172 if is_x86_feature_detected!("avx2") {
173 Ok(Self::Avx2)
174 } else {
175 Err(GpuError::BackendUnavailable(
176 "AVX2 not supported by this CPU".into(),
177 ))
178 }
179 }
180
181 pub fn from_preference(pref: BackendPreference) -> Result<Self, GpuError> {
188 match pref {
189 BackendPreference::Auto => Ok(Self::auto_detect()),
190 BackendPreference::Cpu => Ok(Self::Cpu),
191
192 BackendPreference::Cuda => {
193 #[cfg(feature = "cuda")]
194 return Ok(Self::Cuda(cuda::CudaDevice::init()?));
195 #[allow(unreachable_code)]
196 Err(GpuError::BackendUnavailable(
197 "CUDA backend not compiled in; rebuild with --features cuda".into(),
198 ))
199 }
200
201 BackendPreference::Vulkan => {
202 #[cfg(feature = "vulkan")]
203 return Ok(Self::Vulkan(vulkan::VulkanDevice::init()?));
204 #[allow(unreachable_code)]
205 Err(GpuError::BackendUnavailable(
206 "Vulkan backend not compiled in; rebuild with --features vulkan".into(),
207 ))
208 }
209
210 BackendPreference::Avx2 => {
211 #[cfg(feature = "avx2")]
212 {
213 if is_x86_feature_detected!("avx2") {
214 return Ok(Self::Avx2);
215 }
216 return Err(GpuError::BackendUnavailable(
217 "AVX2 not supported by this CPU".into(),
218 ));
219 }
220 #[allow(unreachable_code)]
221 Err(GpuError::BackendUnavailable(
222 "AVX2 backend not compiled in; rebuild with --features avx2".into(),
223 ))
224 }
225 }
226 }
227
228 pub fn run<K: Kernel>(&self, input: K::Input<'_>) -> Result<K::Output, GpuError>
230 where
231 Self: KernelDispatch<K>,
232 {
233 self.dispatch(input)
234 }
235
236 pub fn name(&self) -> &'static str {
238 match self {
239 Self::Cpu => "cpu",
240 #[cfg(feature = "cuda")]
241 Self::Cuda(_) => "cuda",
242 #[cfg(feature = "vulkan")]
243 Self::Vulkan(_) => "vulkan",
244 #[cfg(feature = "avx2")]
245 Self::Avx2 => "avx2",
246 }
247 }
248
249 pub fn is_gpu(&self) -> bool {
251 match self {
252 #[cfg(feature = "cuda")]
253 Self::Cuda(_) => true,
254 #[cfg(feature = "vulkan")]
255 Self::Vulkan(_) => true,
256 _ => false,
257 }
258 }
259
260 pub fn is_accelerated(&self) -> bool {
263 !matches!(self, Self::Cpu)
264 }
265
266 pub fn available_vram_bytes(&self) -> Option<u64> {
268 match self {
269 Self::Cpu => None,
270 #[cfg(feature = "cuda")]
271 Self::Cuda(dev) => dev.available_vram_bytes().ok(),
272 #[cfg(feature = "vulkan")]
273 Self::Vulkan(dev) => dev.available_vram_bytes(),
274 #[cfg(feature = "avx2")]
275 Self::Avx2 => None,
276 }
277 }
278
279 pub fn total_vram_bytes(&self) -> Option<u64> {
281 match self {
282 Self::Cpu => None,
283 #[cfg(feature = "cuda")]
284 Self::Cuda(dev) => Some(dev.total_vram_bytes()),
285 #[cfg(feature = "vulkan")]
286 Self::Vulkan(dev) => Some(dev.total_vram_bytes()),
287 #[cfg(feature = "avx2")]
288 Self::Avx2 => None,
289 }
290 }
291}
292
/// Backend-specific state of an EM session.
///
/// A value is produced by `DeviceBackend::em_init_session`, passed to
/// `DeviceBackend::em_run_iteration`, and finally handed to
/// `DeviceBackend::em_drop_session`. The type only exists when at least one
/// of the `cuda`, `vulkan` or `avx2` features is enabled, and each variant is
/// gated on its own feature.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
pub(crate) enum EmSession {
    /// Session state owned by the CUDA backend.
    #[cfg(feature = "cuda")]
    Cuda(cuda::launch::em_reduce::CudaEmSession),
    /// Session state owned by the Vulkan backend.
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::launch::em_reduce::VulkanEmSession),
    /// Session state owned by the AVX2 backend.
    #[cfg(feature = "avx2")]
    Avx2(avx2::launch::em_reduce::Avx2EmSession),
}
308
309#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
310impl DeviceBackend {
311 pub(crate) fn em_init_session(
314 &self,
315 comparison_levels: &[u32],
316 n_pairs: usize,
317 n_fields: usize,
318 ) -> Result<EmSession, GpuError> {
319 match self {
320 #[cfg(feature = "cuda")]
321 Self::Cuda(dev) => dev
322 .em_init_session(comparison_levels, n_pairs, n_fields)
323 .map(EmSession::Cuda),
324 #[cfg(feature = "vulkan")]
325 Self::Vulkan(dev) => dev
326 .em_init_session(comparison_levels, n_pairs, n_fields)
327 .map(EmSession::Vulkan),
328 #[cfg(feature = "avx2")]
329 Self::Avx2 => Ok(EmSession::Avx2(avx2::device::Avx2Device::em_init_session(
330 comparison_levels,
331 n_pairs,
332 n_fields,
333 ))),
334 _ => Err(GpuError::BackendUnavailable(
335 "em_init_session requires an accelerated backend".into(),
336 )),
337 }
338 }
339
340 pub(crate) fn em_run_iteration(
345 &self,
346 session: &mut EmSession,
347 weights: &[f32],
348 log_prior_odds: f32,
349 ) -> Result<EmReduceOutput, GpuError> {
350 match (self, session) {
351 #[cfg(feature = "cuda")]
352 (Self::Cuda(dev), EmSession::Cuda(s)) => {
353 dev.em_run_iteration(s, weights, log_prior_odds)
354 }
355 #[cfg(feature = "vulkan")]
356 (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
357 dev.em_run_iteration(s, weights, log_prior_odds)
358 }
359 #[cfg(feature = "avx2")]
360 (Self::Avx2, EmSession::Avx2(s)) => {
361 avx2::device::Avx2Device::em_run_iteration(s, weights, log_prior_odds)
362 }
363 _ => Err(GpuError::BackendUnavailable(
364 "em_run_iteration requires an accelerated backend".into(),
365 )),
366 }
367 }
368
369 pub(crate) fn em_drop_session(&self, session: EmSession) {
375 match (self, session) {
376 #[cfg(feature = "cuda")]
377 (Self::Cuda(_), EmSession::Cuda(_s)) => { }
378 #[cfg(feature = "vulkan")]
379 (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
380 let mut alloc = dev.allocator.lock().unwrap();
381 s.destroy(&dev.device, &mut alloc);
382 }
383 #[cfg(feature = "avx2")]
384 (Self::Avx2, EmSession::Avx2(_s)) => { }
385 _ => {}
386 }
387 }
388}
389
390impl KernelDispatch<HelloBackend> for DeviceBackend {
396 fn dispatch(&self, input: HelloBackendInput) -> Result<HelloBackendOutput, GpuError> {
397 match self {
398 #[cfg(feature = "cuda")]
399 Self::Cuda(dev) => {
400 <cuda::CudaDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input)
401 }
402 #[cfg(feature = "vulkan")]
403 Self::Vulkan(dev) => {
404 <vulkan::VulkanDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input)
405 }
406 #[cfg(feature = "avx2")]
407 Self::Avx2 => <avx2::Avx2Device as KernelDispatch<HelloBackend>>::dispatch(
408 &avx2::Avx2Device,
409 input,
410 ),
411 Self::Cpu => <CpuDevice as KernelDispatch<HelloBackend>>::dispatch(&CpuDevice, input),
412 }
413 }
414}
415
416impl KernelDispatch<EmReduce> for DeviceBackend {
417 fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
418 match self {
419 #[cfg(feature = "cuda")]
420 Self::Cuda(dev) => <cuda::CudaDevice as KernelDispatch<EmReduce>>::dispatch(dev, input),
421 #[cfg(feature = "vulkan")]
422 Self::Vulkan(dev) => {
423 <vulkan::VulkanDevice as KernelDispatch<EmReduce>>::dispatch(dev, input)
424 }
425 #[cfg(feature = "avx2")]
426 Self::Avx2 => {
427 <avx2::Avx2Device as KernelDispatch<EmReduce>>::dispatch(&avx2::Avx2Device, input)
428 }
429 Self::Cpu => <CpuDevice as KernelDispatch<EmReduce>>::dispatch(&CpuDevice, input),
430 }
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Auto-detection must always yield one of the four known backends,
    /// whatever hardware the test machine has.
    #[test]
    fn auto_detect_does_not_panic() {
        let name = DeviceBackend::auto_detect().name();
        let is_known = matches!(name, "cpu" | "cuda" | "vulkan" | "avx2");
        assert!(is_known, "unexpected backend name: {name}");
    }

    /// The scalar CPU backend reports no device memory and is neither a GPU
    /// nor an accelerated backend.
    #[test]
    fn cpu_backend_has_no_vram() {
        let scalar = DeviceBackend::cpu();
        assert_eq!(scalar.available_vram_bytes(), None);
        assert_eq!(scalar.total_vram_bytes(), None);
        assert!(!scalar.is_gpu());
        assert!(!scalar.is_accelerated());
    }

    /// The scalar CPU backend is named `"cpu"`.
    #[test]
    fn cpu_backend_name() {
        let scalar = DeviceBackend::cpu();
        assert_eq!(scalar.name(), "cpu");
    }

    /// Requesting the CPU backend explicitly can never fail.
    #[test]
    fn cpu_preference_always_succeeds() {
        let outcome = DeviceBackend::from_preference(BackendPreference::Cpu);
        assert!(outcome.is_ok());
    }

    /// AVX2 counts as accelerated but is not a GPU and has no device memory.
    #[cfg(feature = "avx2")]
    #[test]
    fn avx2_backend_is_accelerated_not_gpu() {
        let simd = DeviceBackend::Avx2;
        assert!(simd.is_accelerated());
        assert!(!simd.is_gpu());
        assert_eq!(simd.name(), "avx2");
        assert_eq!(simd.available_vram_bytes(), None);
    }
}