1pub mod cpu;
28
29#[cfg(feature = "cuda")]
30pub mod cuda;
31
32#[cfg(feature = "avx2")]
33pub mod avx2;
34
35#[cfg(feature = "vulkan")]
36pub mod vulkan;
37
38use crate::{
39 backend::cpu::CpuDevice,
40 error::GpuError,
41 kernel::{Kernel, KernelDispatch},
42 kernels::{
43 em_reduce::{EmReduce, EmReduceInput, EmReduceOutput},
44 hello_backend::{HelloBackend, HelloBackendInput, HelloBackendOutput},
45 },
46};
47
/// Compute backend chosen at runtime.
///
/// `Cpu` exists in every build. Each of the other variants is compiled in
/// only when the Cargo feature of the same name (`cuda`, `vulkan`, `avx2`)
/// is enabled.
pub enum DeviceBackend {
    /// Scalar CPU backend; always available and carries no state.
    Cpu,

    /// CUDA backend, holding the `cuda::CudaDevice` it runs on.
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaDevice),

    /// Vulkan backend, holding the `vulkan::VulkanDevice` it runs on.
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::VulkanDevice),

    /// AVX2 backend; carries no device state.
    #[cfg(feature = "avx2")]
    Avx2,
}
76
77pub type GpuBackend = DeviceBackend;
79
/// Backend requested by the caller, resolved into a concrete
/// `DeviceBackend` by `DeviceBackend::from_preference`.
///
/// The type is a plain field-less selector, so it derives the usual value
/// traits: it can be logged (`Debug`), passed by value (`Clone`/`Copy`),
/// compared (`PartialEq`/`Eq`) and used as a map or set key (`Hash`).
///
/// Marked `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendPreference {
    /// No specific request; resolved through `DeviceBackend::auto_detect`.
    Auto,
    /// Request the CUDA backend.
    Cuda,
    /// Request the Vulkan backend.
    Vulkan,
    /// Request the AVX2 backend.
    Avx2,
    /// Request the scalar CPU backend.
    Cpu,
}
96
97impl DeviceBackend {
100 pub fn auto_detect() -> Self {
105 #[cfg(feature = "cuda")]
106 match cuda::CudaDevice::init() {
107 Ok(dev) => {
108 tracing::info!(
109 device_name = %dev.name(),
110 vram_bytes = dev.total_vram_bytes(),
111 "compute backend: CUDA selected"
112 );
113 return Self::Cuda(dev);
114 }
115 Err(e) => tracing::warn!(%e, "CUDA init failed, trying Vulkan"),
116 }
117
118 #[cfg(feature = "vulkan")]
119 match vulkan::VulkanDevice::init() {
120 Ok(dev) => {
121 tracing::info!(
122 device_name = %dev.name(),
123 vram_bytes = dev.total_vram_bytes(),
124 "compute backend: Vulkan selected"
125 );
126 return Self::Vulkan(dev);
127 }
128 Err(e) => tracing::warn!(%e, "Vulkan init failed, trying AVX2"),
129 }
130
131 #[cfg(feature = "avx2")]
132 if is_x86_feature_detected!("avx2") {
133 tracing::info!("compute backend: AVX2 selected");
134 return Self::Avx2;
135 }
136
137 tracing::warn!("compute backend: scalar CPU fallback");
138 Self::Cpu
139 }
140
141 pub fn cpu() -> Self {
145 Self::Cpu
146 }
147
148 #[cfg(feature = "cuda")]
153 pub fn cuda() -> Result<Self, GpuError> {
154 Ok(Self::Cuda(cuda::CudaDevice::init()?))
155 }
156
157 #[cfg(feature = "vulkan")]
162 pub fn vulkan() -> Result<Self, GpuError> {
163 Ok(Self::Vulkan(vulkan::VulkanDevice::init()?))
164 }
165
166 #[cfg(feature = "avx2")]
171 pub fn avx2() -> Result<Self, GpuError> {
172 if is_x86_feature_detected!("avx2") {
173 Ok(Self::Avx2)
174 } else {
175 Err(GpuError::BackendUnavailable(
176 "AVX2 not supported by this CPU".into(),
177 ))
178 }
179 }
180
181 pub fn from_preference(pref: BackendPreference) -> Result<Self, GpuError> {
188 match pref {
189 BackendPreference::Auto => Ok(Self::auto_detect()),
190 BackendPreference::Cpu => Ok(Self::Cpu),
191
192 BackendPreference::Cuda => {
193 #[cfg(feature = "cuda")]
194 return Ok(Self::Cuda(cuda::CudaDevice::init()?));
195 #[allow(unreachable_code)]
196 Err(GpuError::BackendUnavailable(
197 "CUDA backend not compiled in; rebuild with --features cuda".into(),
198 ))
199 }
200
201 BackendPreference::Vulkan => {
202 #[cfg(feature = "vulkan")]
203 return Ok(Self::Vulkan(vulkan::VulkanDevice::init()?));
204 #[allow(unreachable_code)]
205 Err(GpuError::BackendUnavailable(
206 "Vulkan backend not compiled in; rebuild with --features vulkan".into(),
207 ))
208 }
209
210 BackendPreference::Avx2 => {
211 #[cfg(feature = "avx2")]
212 {
213 if is_x86_feature_detected!("avx2") {
214 return Ok(Self::Avx2);
215 }
216 return Err(GpuError::BackendUnavailable(
217 "AVX2 not supported by this CPU".into(),
218 ));
219 }
220 #[allow(unreachable_code)]
221 Err(GpuError::BackendUnavailable(
222 "AVX2 backend not compiled in; rebuild with --features avx2".into(),
223 ))
224 }
225 }
226 }
227
228 pub fn run<K: Kernel>(&self, input: K::Input<'_>) -> Result<K::Output, GpuError>
230 where
231 Self: KernelDispatch<K>,
232 {
233 self.dispatch(input)
234 }
235
236 pub fn name(&self) -> &'static str {
238 match self {
239 Self::Cpu => "cpu",
240 #[cfg(feature = "cuda")]
241 Self::Cuda(_) => "cuda",
242 #[cfg(feature = "vulkan")]
243 Self::Vulkan(_) => "vulkan",
244 #[cfg(feature = "avx2")]
245 Self::Avx2 => "avx2",
246 }
247 }
248
249 pub fn is_gpu(&self) -> bool {
251 match self {
252 #[cfg(feature = "cuda")]
253 Self::Cuda(_) => true,
254 #[cfg(feature = "vulkan")]
255 Self::Vulkan(_) => true,
256 _ => false,
257 }
258 }
259
260 pub fn is_accelerated(&self) -> bool {
263 !matches!(self, Self::Cpu)
264 }
265
266 pub fn available_vram_bytes(&self) -> Option<u64> {
268 match self {
269 Self::Cpu => None,
270 #[cfg(feature = "cuda")]
271 Self::Cuda(dev) => dev.available_vram_bytes().ok(),
272 #[cfg(feature = "vulkan")]
273 Self::Vulkan(dev) => dev.available_vram_bytes(),
274 #[cfg(feature = "avx2")]
275 Self::Avx2 => None,
276 }
277 }
278
279 pub fn total_vram_bytes(&self) -> Option<u64> {
281 match self {
282 Self::Cpu => None,
283 #[cfg(feature = "cuda")]
284 Self::Cuda(dev) => Some(dev.total_vram_bytes()),
285 #[cfg(feature = "vulkan")]
286 Self::Vulkan(dev) => Some(dev.total_vram_bytes()),
287 #[cfg(feature = "avx2")]
288 Self::Avx2 => None,
289 }
290 }
291}
292
/// Backend-specific session state used by the `em_*` helpers on
/// `DeviceBackend`.
///
/// Each variant wraps the session type of one accelerated backend and is
/// compiled in only with that backend's Cargo feature. The enum itself is
/// compiled only when at least one of those features is enabled.
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
pub(crate) enum EmSession {
    /// Session owned by the CUDA backend.
    #[cfg(feature = "cuda")]
    Cuda(cuda::launch::em_reduce::CudaEmSession),
    /// Session owned by the Vulkan backend.
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::launch::em_reduce::VulkanEmSession),
    /// Session owned by the AVX2 backend.
    #[cfg(feature = "avx2")]
    Avx2(avx2::launch::em_reduce::Avx2EmSession),
}
308
#[cfg(any(feature = "cuda", feature = "vulkan", feature = "avx2"))]
impl DeviceBackend {
    /// Creates a backend-specific [`EmSession`] for this backend.
    ///
    /// `comparison_levels`, `n_pairs` and `n_fields` are forwarded unchanged
    /// to the backend's own `em_init_session`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BackendUnavailable`] for the scalar CPU backend,
    /// and propagates any error from the CUDA or Vulkan session setup.
    pub(crate) fn em_init_session(
        &self,
        comparison_levels: &[u32],
        n_pairs: usize,
        n_fields: usize,
    ) -> Result<EmSession, GpuError> {
        match self {
            #[cfg(feature = "cuda")]
            Self::Cuda(dev) => dev
                .em_init_session(comparison_levels, n_pairs, n_fields)
                .map(EmSession::Cuda),
            #[cfg(feature = "vulkan")]
            Self::Vulkan(dev) => dev
                .em_init_session(comparison_levels, n_pairs, n_fields)
                .map(EmSession::Vulkan),
            #[cfg(feature = "avx2")]
            Self::Avx2 => Ok(EmSession::Avx2(
                avx2::device::Avx2Device::em_init_session(comparison_levels, n_pairs, n_fields),
            )),
            // Named explicitly (instead of `_`) so that adding a backend
            // variant forces this match to be revisited.
            Self::Cpu => Err(GpuError::BackendUnavailable(
                "em_init_session requires an accelerated backend".into(),
            )),
        }
    }

    /// Runs one iteration on an existing session and returns its output.
    ///
    /// `weights` and `log_prior_odds` are forwarded unchanged to the
    /// backend's own `em_run_iteration`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BackendUnavailable`] when the backend is the
    /// scalar CPU one or when `session` was created by a different backend;
    /// otherwise propagates the backend's error.
    pub(crate) fn em_run_iteration(
        &self,
        session: &mut EmSession,
        weights: &[f32],
        log_prior_odds: f32,
    ) -> Result<EmReduceOutput, GpuError> {
        match (self, session) {
            #[cfg(feature = "cuda")]
            (Self::Cuda(dev), EmSession::Cuda(s)) => {
                dev.em_run_iteration(s, weights, log_prior_odds)
            }
            #[cfg(feature = "vulkan")]
            (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
                dev.em_run_iteration(s, weights, log_prior_odds)
            }
            #[cfg(feature = "avx2")]
            (Self::Avx2, EmSession::Avx2(s)) => {
                avx2::device::Avx2Device::em_run_iteration(s, weights, log_prior_odds)
            }
            // Covers both the CPU backend and a backend/session mismatch.
            _ => Err(GpuError::BackendUnavailable(
                "em_run_iteration requires an accelerated backend".into(),
            )),
        }
    }

    /// Releases a session created by [`Self::em_init_session`].
    ///
    /// CUDA and AVX2 sessions need no explicit teardown here; the value is
    /// simply dropped. A Vulkan session is destroyed explicitly through the
    /// device and its allocator.
    ///
    /// This is a best-effort cleanup path and never panics on a poisoned
    /// allocator lock: the guard is recovered so the session's resources
    /// are still released. A session that does not match this backend is
    /// dropped without backend-specific teardown and a warning is logged.
    pub(crate) fn em_drop_session(&self, session: EmSession) {
        match (self, session) {
            #[cfg(feature = "cuda")]
            (Self::Cuda(_), EmSession::Cuda(_session)) => {}
            #[cfg(feature = "vulkan")]
            (Self::Vulkan(dev), EmSession::Vulkan(s)) => {
                // A panic elsewhere while the allocator was locked poisons
                // the mutex. Unwrapping here would turn cleanup into a second
                // panic (an abort if we are already unwinding) and leak the
                // session, so take the guard regardless of poisoning.
                let mut alloc = dev
                    .allocator
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                s.destroy(&dev.device, &mut alloc);
            }
            #[cfg(feature = "avx2")]
            (Self::Avx2, EmSession::Avx2(_session)) => {}
            // Backend/session mismatch: nothing backend-specific can be done
            // without the owning device, but do not hide the problem.
            _ => tracing::warn!(
                "em_drop_session: session does not match this backend; \
                 dropped without backend-specific teardown"
            ),
        }
    }
}
382
383impl KernelDispatch<HelloBackend> for DeviceBackend {
389 fn dispatch(&self, input: HelloBackendInput) -> Result<HelloBackendOutput, GpuError> {
390 match self {
391 #[cfg(feature = "cuda")]
392 Self::Cuda(dev) => <cuda::CudaDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input),
393 #[cfg(feature = "vulkan")]
394 Self::Vulkan(dev) => <vulkan::VulkanDevice as KernelDispatch<HelloBackend>>::dispatch(dev, input),
395 #[cfg(feature = "avx2")]
396 Self::Avx2 => <avx2::Avx2Device as KernelDispatch<HelloBackend>>::dispatch(&avx2::Avx2Device, input),
397 Self::Cpu => <CpuDevice as KernelDispatch<HelloBackend>>::dispatch(&CpuDevice, input),
398 }
399 }
400}
401
402impl KernelDispatch<EmReduce> for DeviceBackend {
403 fn dispatch(&self, input: EmReduceInput<'_>) -> Result<EmReduceOutput, GpuError> {
404 match self {
405 #[cfg(feature = "cuda")]
406 Self::Cuda(dev) => <cuda::CudaDevice as KernelDispatch<EmReduce>>::dispatch(dev, input),
407 #[cfg(feature = "vulkan")]
408 Self::Vulkan(dev) => <vulkan::VulkanDevice as KernelDispatch<EmReduce>>::dispatch(dev, input),
409 #[cfg(feature = "avx2")]
410 Self::Avx2 => <avx2::Avx2Device as KernelDispatch<EmReduce>>::dispatch(&avx2::Avx2Device, input),
411 Self::Cpu => <CpuDevice as KernelDispatch<EmReduce>>::dispatch(&CpuDevice, input),
412 }
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Auto-detection must complete and report one of the known names,
    /// whichever backends happen to be compiled in or present on the host.
    #[test]
    fn auto_detect_does_not_panic() {
        const KNOWN_NAMES: [&str; 4] = ["cpu", "cuda", "vulkan", "avx2"];
        let name = DeviceBackend::auto_detect().name();
        assert!(KNOWN_NAMES.contains(&name), "unexpected backend name: {name}");
    }

    /// The CPU backend reports no device memory and is neither a GPU nor
    /// an accelerated backend.
    #[test]
    fn cpu_backend_has_no_vram() {
        let cpu = DeviceBackend::cpu();
        assert!(cpu.available_vram_bytes().is_none());
        assert!(cpu.total_vram_bytes().is_none());
        assert!(!cpu.is_gpu());
        assert!(!cpu.is_accelerated());
    }

    /// The CPU backend identifies itself as `"cpu"`.
    #[test]
    fn cpu_backend_name() {
        let cpu = DeviceBackend::cpu();
        assert_eq!(cpu.name(), "cpu");
    }

    /// Requesting the CPU backend explicitly can never fail.
    #[test]
    fn cpu_preference_always_succeeds() {
        let resolved = DeviceBackend::from_preference(BackendPreference::Cpu);
        assert!(resolved.is_ok());
    }

    /// AVX2 counts as accelerated but not as a GPU, and has no device memory.
    #[cfg(feature = "avx2")]
    #[test]
    fn avx2_backend_is_accelerated_not_gpu() {
        let simd = DeviceBackend::Avx2;
        assert!(simd.is_accelerated());
        assert!(!simd.is_gpu());
        assert_eq!(simd.name(), "avx2");
        assert!(simd.available_vram_bytes().is_none());
    }
}
460}