1use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod async_transfer;
11pub mod auto_tuning;
12pub mod backends;
13pub mod benchmarks;
14mod cpu_ops;
15pub mod device_info;
16pub mod heterogeneous;
17pub mod kernels;
18pub mod memory_management;
19pub mod stream_allocator;
20pub mod tensor_cores;
21
22pub use async_transfer::{
23 AsyncTransferError, AsyncTransferPipeline, TransferDirection, TransferHandle,
24};
25pub use device_info::GpuDeviceInfo;
26pub use memory_management::unified_memory::{SyncState, UnifiedAllocator, UnifiedBuffer};
27pub use stream_allocator::{StreamAllocator, StreamId};
28
/// Compute backends a GPU context can be created for.
///
/// The human-readable label of each variant is produced by the `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// CUDA backend.
    Cuda,
    /// ROCm backend.
    Rocm,
    /// WebGPU backend.
    Wgpu,
    /// Metal backend.
    Metal,
    /// OpenCL backend.
    OpenCL,
    /// Host CPU fallback.
    Cpu,
}
45
46impl Default for GpuBackend {
47 fn default() -> Self {
48 Self::preferred()
49 }
50}
51
52impl GpuBackend {
53 pub fn preferred() -> Self {
55 match backends::initialize_optimal_backend() {
58 Ok(backend) => {
59 if backend != GpuBackend::Cpu {
61 #[cfg(not(test))]
64 {
65 return GpuBackend::Cpu;
67 }
68 #[cfg(test)]
69 {
70 return backend;
72 }
73 }
74 backend
75 }
76 Err(_) => {
77 GpuBackend::Cpu
79 }
80 }
81 }
82
83 pub fn is_available(&self) -> bool {
85 match self {
86 GpuBackend::Cuda => false, GpuBackend::Rocm => cfg!(feature = "rocm"), GpuBackend::Wgpu => {
90 #[cfg(feature = "wgpu")]
91 {
92 use crate::gpu::backends::wgpu::WebGPUContext;
93 WebGPUContext::is_available()
94 }
95 #[cfg(not(feature = "wgpu"))]
96 {
97 false
98 }
99 }
100 GpuBackend::Metal => {
101 #[cfg(all(feature = "metal", target_os = "macos"))]
102 {
103 true
105 }
106 #[cfg(not(all(feature = "metal", target_os = "macos")))]
107 {
108 false
109 }
110 }
111 GpuBackend::OpenCL => {
112 #[cfg(feature = "opencl")]
113 {
114 use crate::gpu::backends::opencl::OpenCLContext;
115 OpenCLContext::is_available()
116 }
117 #[cfg(not(feature = "opencl"))]
118 {
119 false
120 }
121 }
122 GpuBackend::Cpu => true,
123 }
124 }
125}
126
127impl fmt::Display for GpuBackend {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 match self {
130 GpuBackend::Cuda => write!(f, "CUDA"),
131 GpuBackend::Rocm => write!(f, "ROCm"),
132 GpuBackend::Wgpu => write!(f, "WebGPU"),
133 GpuBackend::Metal => write!(f, "Metal"),
134 GpuBackend::OpenCL => write!(f, "OpenCL"),
135 GpuBackend::Cpu => write!(f, "CPU"),
136 }
137 }
138}
139
140use crate::error::{CoreError, ErrorContext, ErrorLocation};
141
142#[derive(Debug, thiserror::Error)]
144pub enum GpuError {
145 #[error("GPU backend {0} is not available")]
147 BackendNotAvailable(String),
148
149 #[error("GPU backend {0} is not supported")]
151 UnsupportedBackend(GpuBackend),
152
153 #[error("GPU backend {0:?} is not supported for this kernel")]
155 BackendNotSupported(GpuBackend),
156
157 #[error("GPU backend {0} is not implemented yet")]
159 BackendNotImplemented(GpuBackend),
160
161 #[error("GPU out of memory: {0}")]
163 OutOfMemory(String),
164
165 #[error("Kernel compilation error: {0}")]
167 KernelCompilationError(String),
168
169 #[error("Kernel execution error: {0}")]
171 KernelExecutionError(String),
172
173 #[error("Invalid parameter: {0}")]
175 InvalidParameter(String),
176
177 #[error("Kernel not found: {0}")]
179 KernelNotFound(String),
180
181 #[error("Kernel specialization not supported")]
183 SpecializationNotSupported,
184
185 #[error("Unsupported data type: {0:?}")]
187 UnsupportedDataType(kernels::DataType),
188
189 #[error("{0}")]
191 Other(String),
192}
193
194#[derive(Debug, Clone, Copy, PartialEq, Eq)]
196pub struct GpuDevice {
197 backend: GpuBackend,
198 device_id: usize,
199}
200
201impl GpuDevice {
202 pub fn new(backend: GpuBackend, device_id: usize) -> Self {
204 Self { backend, device_id }
205 }
206
207 pub fn backend(&self) -> GpuBackend {
209 self.backend
210 }
211
212 pub fn device_id(&self) -> usize {
214 self.device_id
215 }
216
217 pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
219 Ok(GpuKernel {
221 backend: self.backend,
222 entry_point: entrypoint.to_string(),
223 })
224 }
225
226 pub fn get_info(&self) -> Result<GpuDeviceInfo, GpuError> {
243 Ok(GpuDeviceInfo::for_backend(self.backend))
244 }
245}
246
247pub struct GpuKernel {
249 backend: GpuBackend,
250 entry_point: String,
251}
252
253impl GpuKernel {
254 pub fn backend(&self) -> GpuBackend {
256 self.backend
257 }
258
259 pub fn entry_point(&self) -> &str {
261 &self.entry_point
262 }
263}
264
265impl From<GpuError> for CoreError {
267 fn from(err: GpuError) -> Self {
268 match err {
269 GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
270 ErrorContext::new(format!("GPU backend {backend} is not available"))
271 .with_location(ErrorLocation::new(file!(), line!())),
272 ),
273 GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
274 ErrorContext::new(format!("GPU backend {backend} is not supported"))
275 .with_location(ErrorLocation::new(file!(), line!())),
276 ),
277 GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
278 ErrorContext::new(format!(
279 "GPU backend {backend:?} is not supported for this kernel"
280 ))
281 .with_location(ErrorLocation::new(file!(), line!())),
282 ),
283 GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
284 ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
285 .with_location(ErrorLocation::new(file!(), line!())),
286 ),
287 GpuError::OutOfMemory(details) => CoreError::MemoryError(
288 ErrorContext::new(details.to_string())
289 .with_location(ErrorLocation::new(file!(), line!())),
290 ),
291 GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
292 ErrorContext::new(msg.to_string())
293 .with_location(ErrorLocation::new(file!(), line!())),
294 ),
295 GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
296 ErrorContext::new(msg.to_string())
297 .with_location(ErrorLocation::new(file!(), line!())),
298 ),
299 GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
300 ErrorContext::new(msg.to_string())
301 .with_location(ErrorLocation::new(file!(), line!())),
302 ),
303 GpuError::KernelNotFound(name) => CoreError::ComputationError(
304 ErrorContext::new(name.to_string())
305 .with_location(ErrorLocation::new(file!(), line!())),
306 ),
307 GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
308 ErrorContext::new("Kernel specialization not supported".to_string())
309 .with_location(ErrorLocation::new(file!(), line!())),
310 ),
311 GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
312 ErrorContext::new(format!("{dtype:?}"))
313 .with_location(ErrorLocation::new(file!(), line!())),
314 ),
315 GpuError::Other(msg) => CoreError::ComputationError(
316 ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
317 ),
318 }
319 }
320}
321
/// Marker trait for element types usable with `GpuBuffer`, `GpuPtr` and `KernelArg`.
///
/// It has no methods; implementors only need to be `Copy + Send + Sync + 'static`.
pub trait GpuDataType: Copy + Send + Sync + 'static {}
324
325#[derive(Debug)]
327pub struct GpuPtr<T: GpuDataType> {
328 ptr: u64,
329 size: usize,
330 phantom: PhantomData<T>,
331}
332
333impl<T: GpuDataType> GpuPtr<T> {
334 pub fn allocate(size: usize) -> Result<Self, GpuError> {
336 Ok(GpuPtr {
337 ptr: 0x1000_0000, size,
339 phantom: PhantomData,
340 })
341 }
342
343 pub fn as_ptr(&self) -> u64 {
345 self.ptr
346 }
347
348 pub fn len(&self) -> usize {
350 self.size
351 }
352
353 pub fn is_empty(&self) -> bool {
355 self.size == 0
356 }
357}
358
359#[derive(Debug, Clone)]
361pub enum KernelArg<'a, T: GpuDataType> {
362 Buffer(&'a GpuPtr<T>),
364 Scalar(T),
366}
367
/// Kernel argument whose type is selected at run time.
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    /// Raw 64-bit buffer address.
    Buffer(u64),
    /// 32-bit float scalar.
    F32(f32),
    /// 64-bit float scalar.
    F64(f64),
    /// Signed 32-bit scalar.
    I32(i32),
    /// Unsigned 32-bit scalar.
    U32(u32),
    /// Pointer-sized unsigned scalar.
    Usize(usize),
}
384
/// Describes a link between two devices.
///
/// None of the fields are read anywhere in this part of the file, hence the lint allowances.
pub struct GpuChannel {
    /// Index of the device the channel starts at.
    #[allow(dead_code)]
    source_device: usize,
    /// Index of the device the channel ends at.
    #[allow(dead_code)]
    target_device: usize,
    /// Channel bandwidth. NOTE(review): the unit is not stated anywhere in view — confirm.
    #[allow(dead_code)]
    bandwidth: f64,
}
394
395impl GpuDataType for f32 {}
397impl GpuDataType for f64 {}
398impl GpuDataType for i32 {}
399impl GpuDataType for u32 {}
400impl GpuDataType for u8 {}
401impl GpuDataType for i8 {}
402impl GpuDataType for u16 {}
403impl GpuDataType for i16 {}
404impl GpuDataType for u64 {}
405impl GpuDataType for i64 {}
406impl GpuDataType for usize {}
407impl GpuDataType for isize {}
408
409pub struct GpuBuffer<T: GpuDataType> {
411 inner: Arc<dyn GpuBufferImpl>,
412 size: usize,
413 phantom: PhantomData<T>,
414}
415
416impl<T: GpuDataType> GpuBuffer<T> {
417 pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
419 Self {
420 inner,
421 size,
422 phantom: PhantomData,
423 }
424 }
425
426 pub fn len(&self) -> usize {
428 self.size
429 }
430
431 pub fn is_empty(&self) -> bool {
433 self.size == 0
434 }
435
436 pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
438 if data.len() > self.size {
439 return Err(GpuError::InvalidParameter(
440 "Data size exceeds buffer size".to_string(),
441 ));
442 }
443 unsafe {
444 self.inner
445 .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
446 }
447 Ok(())
448 }
449
450 pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
452 if data.len() > self.size {
453 return Err(GpuError::InvalidParameter(
454 "Data size exceeds buffer size".to_string(),
455 ));
456 }
457 unsafe {
458 self.inner
459 .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
460 }
461 Ok(())
462 }
463
464 pub fn to_vec(&self) -> Vec<T> {
466 let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
467 let _ = self.copy_to_host(&mut result);
468 result
469 }
470}
471
472impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
473 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
474 f.debug_struct("GpuBuffer")
475 .field("size", &self.size)
476 .finish()
477 }
478}
479
480impl<T: GpuDataType> Clone for GpuBuffer<T> {
481 fn clone(&self) -> Self {
482 Self {
483 inner: Arc::clone(&self.inner),
484 size: self.size,
485 phantom: PhantomData,
486 }
487 }
488}
489
490#[derive(Clone)]
492pub struct GpuKernelHandle {
493 inner: Arc<dyn GpuKernelImpl>,
494}
495
496impl GpuKernelHandle {
497 pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
499 Self { inner }
500 }
501
502 pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
504 self.inner.set_buffer(name, &buffer.inner);
505 }
506
507 pub fn set_u32(&self, name: &str, value: u32) {
509 self.inner.set_u32(name, value);
510 }
511
512 pub fn set_i32(&self, name: &str, value: i32) {
514 self.inner.set_i32(name, value);
515 }
516
517 pub fn set_f32(&self, name: &str, value: f32) {
519 self.inner.set_f32(name, value);
520 }
521
522 pub fn set_f64(&self, name: &str, value: f64) {
524 self.inner.set_f64(name, value);
525 }
526
527 pub fn dispatch(&self, workgroups: [u32; 3]) {
533 if !self.inner.try_batch_dispatch(workgroups) {
534 self.inner.dispatch(workgroups);
535 }
536 }
537
538 pub fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
543 self.inner.dispatch_no_wait(workgroups);
544 }
545}
546
547pub struct GpuCompiler {
549 inner: Arc<dyn GpuCompilerImpl>,
550}
551
552impl GpuCompiler {
553 pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
555 Self { inner }
556 }
557
558 pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
560 let kernel = self.inner.compile(source)?;
561 Ok(GpuKernelHandle::new(kernel))
562 }
563
564 pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
566 let kernel = self.inner.compile_typed(
567 name,
568 std::any::TypeId::of::<I>(),
569 std::any::TypeId::of::<O>(),
570 );
571 GpuKernelHandle::new(kernel)
572 }
573}
574
575pub struct GpuContext {
577 inner: Arc<dyn GpuContextImpl>,
578 backend: GpuBackend,
579 kernel_registry: kernels::KernelRegistry,
580}
581
582impl GpuContext {
583 pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
585 if backend == GpuBackend::Cuda {
589 return Err(GpuError::BackendNotAvailable(
590 "CUDA backend was removed from scirs2-core in 0.6.x; use the oxicuda-* crates directly (per-crate `cuda` feature). See SCIRS2_POLICY.md".to_string(),
591 ));
592 }
593
594 if !backend.is_available() {
596 return Err(GpuError::BackendNotAvailable(backend.to_string()));
597 }
598
599 if backend != GpuBackend::Cpu {
601 let detection_result = backends::detect_gpu_backends();
602 let backend_available = detection_result
603 .devices
604 .iter()
605 .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);
606
607 if !backend_available {
608 return Err(GpuError::BackendNotAvailable(format!(
609 "{backend} (no devices detected at runtime)"
610 )));
611 }
612 }
613
614 let inner = match backend {
615 GpuBackend::Cuda => {
616 return Err(GpuError::BackendNotAvailable(
617 "CUDA backend was removed from scirs2-core in 0.6.x; use the oxicuda-* crates directly (per-crate `cuda` feature). See SCIRS2_POLICY.md".to_string(),
618 ));
619 }
620 GpuBackend::Rocm => {
621 #[cfg(feature = "rocm")]
622 {
623 #[cfg(test)]
626 {
627 Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
629 }
630 #[cfg(not(test))]
631 {
632 return Err(GpuError::BackendNotImplemented(backend));
633 }
634 }
635 #[cfg(not(feature = "rocm"))]
636 {
637 return Err(GpuError::UnsupportedBackend(backend));
638 }
639 }
640 GpuBackend::Wgpu => {
641 #[cfg(feature = "wgpu")]
642 {
643 use crate::gpu::backends::wgpu::WebGPUContext;
644 match WebGPUContext::new() {
645 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
646 Err(e) => return Err(e),
647 }
648 }
649 #[cfg(not(feature = "wgpu"))]
650 {
651 return Err(GpuError::UnsupportedBackend(backend));
652 }
653 }
654 GpuBackend::Metal => {
655 #[cfg(all(feature = "metal", target_os = "macos"))]
656 {
657 use crate::gpu::backends::metal::MetalContext;
658 match MetalContext::new() {
659 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
660 Err(e) => return Err(e),
661 }
662 }
663 #[cfg(not(all(feature = "metal", target_os = "macos")))]
664 {
665 return Err(GpuError::UnsupportedBackend(backend));
666 }
667 }
668 GpuBackend::OpenCL => {
669 #[cfg(feature = "opencl")]
670 {
671 use crate::gpu::backends::opencl::OpenCLContext;
672 match OpenCLContext::new() {
673 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
674 Err(e) => return Err(e),
675 }
676 }
677 #[cfg(not(feature = "opencl"))]
678 {
679 return Err(GpuError::UnsupportedBackend(backend));
680 }
681 }
682 GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
683 };
684
685 Ok(Self {
686 inner,
687 backend,
688 kernel_registry: kernels::KernelRegistry::with_default_kernels(),
689 })
690 }
691
692 pub fn backend(&self) -> GpuBackend {
694 self.backend
695 }
696
697 pub fn backend_name(&self) -> &str {
699 match self.backend {
700 GpuBackend::Cuda => "CUDA",
701 GpuBackend::Rocm => "ROCm",
702 GpuBackend::Wgpu => "WebGPU",
703 GpuBackend::Metal => "Metal",
704 GpuBackend::OpenCL => "OpenCL",
705 GpuBackend::Cpu => "CPU",
706 }
707 }
708
709 pub fn gpu_sync(&self) -> Result<(), GpuError> {
714 self.inner.gpu_sync()
715 }
716
717 pub fn begin_batch(&self) -> Result<(), GpuError> {
723 self.inner.begin_batch()
724 }
725
726 pub fn end_batch(&self) -> Result<(), GpuError> {
732 self.inner.end_batch()
733 }
734
735 pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
736 let byte_size = size.saturating_mul(std::mem::size_of::<T>());
737 let inner = self.inner.create_buffer(byte_size);
738 GpuBuffer::new(inner, size)
739 }
740
741 pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
743 let buffer = self.create_buffer::<T>(data.len());
744 let _ = buffer.copy_from_host(data);
745 buffer
746 }
747
748 pub fn execute<F, R>(&self, f: F) -> R
750 where
751 F: FnOnce(&GpuCompiler) -> R,
752 {
753 let compiler = GpuCompiler::new(self.inner.create_compiler());
754 f(&compiler)
755 }
756
757 pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
759 let kernel = self
760 .kernel_registry
761 .get(name)
762 .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;
763
764 let kernel_source = kernel.source_for_backend(self.backend)?;
765 let metadata = kernel.metadata();
766
767 let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
768 Ok(handle)
769 }
770
771 pub fn get_specialized_kernel(
773 &self,
774 name: &str,
775 params: &kernels::KernelParams,
776 ) -> Result<GpuKernelHandle, GpuError> {
777 let specialized = self.kernel_registry.get_specialized(name, params)?;
778 let kernel_source = specialized.source_for_backend(self.backend)?;
779 let metadata = specialized.metadata();
780
781 let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
782 Ok(handle)
783 }
784
785 fn compile_kernel_with_metadata(
787 &self,
788 source: &str,
789 _metadata: &kernels::KernelMetadata,
790 ) -> Result<GpuKernelHandle, GpuError> {
791 self.execute(|compiler| compiler.compile(source))
792 }
793
794 pub fn get_available_memory(&self) -> Option<usize> {
796 Some(1024 * 1024 * 1024) }
800
801 pub fn get_total_memory(&self) -> Option<usize> {
803 #[cfg(target_arch = "wasm32")]
806 return Some(512 * 1024 * 1024); #[cfg(not(target_arch = "wasm32"))]
809 Some((4u64 * 1024 * 1024 * 1024) as usize) }
811
812 pub fn launch_kernel(
814 &self,
815 kernel_name: &str,
816 grid_size: (usize, usize, usize),
817 block_size: (usize, usize, usize),
818 args: &[DynamicKernelArg],
819 ) -> Result<(), GpuError> {
820 let _ = (kernel_name, grid_size, block_size, args);
822 Ok(())
823 }
824
825 pub fn transfer_async_host_to_device<T: GpuDataType>(
827 &self,
828 ptr: &GpuPtr<T>,
829 data: &[T],
830 ) -> Result<(), GpuError> {
831 let _ = (ptr, data);
833 Ok(())
834 }
835
836 pub fn transfer_host_to_device<T: GpuDataType>(
838 &self,
839 ptr: &GpuPtr<T>,
840 data: &[T],
841 ) -> Result<(), GpuError> {
842 let _ = (ptr, data);
844 Ok(())
845 }
846
847 pub fn transfer_async_device_to_host<T: GpuDataType>(
849 &self,
850 ptr: &GpuPtr<T>,
851 data: &mut [T],
852 ) -> Result<(), GpuError> {
853 let _ = (ptr, data);
855 Ok(())
856 }
857
858 pub fn transfer_device_to_host<T: GpuDataType>(
860 &self,
861 ptr: &GpuPtr<T>,
862 data: &mut [T],
863 ) -> Result<(), GpuError> {
864 let _ = (ptr, data);
866 Ok(())
867 }
868
869 pub fn execute_kernel(
872 &self,
873 source: &str,
874 buffers: &[GpuBuffer<f32>],
875 work_groups: (u32, u32, u32),
876 int_params: &[u32],
877 float_params: &[f32],
878 ) -> Result<(), GpuError> {
879 eprintln!(
882 "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
883 source.len(),
884 buffers.len(),
885 work_groups
886 );
887 eprintln!("Int params: {int_params:?}");
888 eprintln!("Float params: {float_params:?}");
889 Ok(())
890 }
891
892 pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
895 Ok(buffer.to_vec())
896 }
897
898 pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
900 self.sum_all_cpu_fallback(buffer)
901 }
902
903 pub fn mean_all<T: GpuDataType>(
905 &self,
906 buffer: &GpuBuffer<T>,
907 ) -> Result<GpuBuffer<T>, GpuError> {
908 self.mean_all_cpu_fallback(buffer)
909 }
910
911 pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
913 self.max_all_cpu_fallback(buffer)
914 }
915
916 pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
918 self.min_all_cpu_fallback(buffer)
919 }
920
921 pub fn sum_axis<T: GpuDataType>(
923 &self,
924 buffer: &GpuBuffer<T>,
925 shape: &[usize],
926 axis: usize,
927 ) -> Result<GpuBuffer<T>, GpuError> {
928 self.sum_axis_cpu_fallback(buffer, shape, axis)
929 }
930
931 pub fn mean_axis<T: GpuDataType>(
933 &self,
934 buffer: &GpuBuffer<T>,
935 shape: &[usize],
936 axis: usize,
937 ) -> Result<GpuBuffer<T>, GpuError> {
938 self.mean_axis_cpu_fallback(buffer, shape, axis)
939 }
940
941 pub fn max_axis<T: GpuDataType>(
943 &self,
944 buffer: &GpuBuffer<T>,
945 shape: &[usize],
946 axis: usize,
947 ) -> Result<GpuBuffer<T>, GpuError> {
948 self.max_axis_cpu_fallback(buffer, shape, axis)
949 }
950
951 pub fn min_axis<T: GpuDataType>(
953 &self,
954 buffer: &GpuBuffer<T>,
955 shape: &[usize],
956 axis: usize,
957 ) -> Result<GpuBuffer<T>, GpuError> {
958 self.min_axis_cpu_fallback(buffer, shape, axis)
959 }
960
961 pub fn broadcast<T: GpuDataType>(
963 &self,
964 buffer: &GpuBuffer<T>,
965 from_shape: &[usize],
966 to_shape: &[usize],
967 ) -> Result<GpuBuffer<T>, GpuError> {
968 self.broadcast_cpu_fallback(buffer, from_shape, to_shape)
969 }
970
971 pub fn scale<T: GpuDataType>(
973 &self,
974 buffer: &GpuBuffer<T>,
975 scalar: T,
976 ) -> Result<GpuBuffer<T>, GpuError> {
977 self.scale_cpu_fallback(buffer, scalar)
978 }
979
980 pub fn gemm<T: GpuDataType>(
982 &self,
983 a: &GpuBuffer<T>,
984 b: &GpuBuffer<T>,
985 m: usize,
986 k: usize,
987 n: usize,
988 ) -> Result<GpuBuffer<T>, GpuError> {
989 self.gemm_cpu_fallback(a, b, m, k, n)
990 }
991
992 pub fn gemm_transpose_b<T: GpuDataType>(
994 &self,
995 a: &GpuBuffer<T>,
996 b: &GpuBuffer<T>,
997 m: usize,
998 k: usize,
999 n: usize,
1000 ) -> Result<GpuBuffer<T>, GpuError> {
1001 self.gemm_transpose_b_cpu_fallback(a, b, m, k, n)
1002 }
1003
1004 pub fn gemm_transpose_a<T: GpuDataType>(
1006 &self,
1007 a: &GpuBuffer<T>,
1008 b: &GpuBuffer<T>,
1009 m: usize,
1010 k: usize,
1011 n: usize,
1012 ) -> Result<GpuBuffer<T>, GpuError> {
1013 self.gemm_transpose_a_cpu_fallback(a, b, m, k, n)
1014 }
1015
1016 pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1018 self.relu_cpu_fallback(input)
1019 }
1020
1021 pub fn relu_backward<T: GpuDataType>(
1023 &self,
1024 grad_output: &GpuBuffer<T>,
1025 input: &GpuBuffer<T>,
1026 ) -> Result<GpuBuffer<T>, GpuError> {
1027 self.relu_backward_cpu_fallback(grad_output, input)
1028 }
1029
1030 pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1032 self.sigmoid_cpu_fallback(input)
1033 }
1034
1035 pub fn sigmoid_backward<T: GpuDataType>(
1037 &self,
1038 grad_output: &GpuBuffer<T>,
1039 input: &GpuBuffer<T>,
1040 ) -> Result<GpuBuffer<T>, GpuError> {
1041 self.sigmoid_backward_cpu_fallback(grad_output, input)
1042 }
1043
1044 pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1046 self.tanh_cpu_fallback(input)
1047 }
1048
1049 pub fn tanh_backward<T: GpuDataType>(
1051 &self,
1052 grad_output: &GpuBuffer<T>,
1053 input: &GpuBuffer<T>,
1054 ) -> Result<GpuBuffer<T>, GpuError> {
1055 self.tanh_backward_cpu_fallback(grad_output, input)
1056 }
1057
1058 pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1060 self.gelu_cpu_fallback(input)
1061 }
1062
1063 pub fn gelu_backward<T: GpuDataType>(
1065 &self,
1066 grad_output: &GpuBuffer<T>,
1067 input: &GpuBuffer<T>,
1068 ) -> Result<GpuBuffer<T>, GpuError> {
1069 self.gelu_backward_cpu_fallback(grad_output, input)
1070 }
1071}
1072
1073impl fmt::Debug for GpuContext {
1074 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1075 f.debug_struct("GpuContext")
1076 .field("backend", &self.backend)
1077 .finish()
1078 }
1079}
1080
/// Backend-side buffer storage.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copies `size` bytes from host memory at `data` into the buffer.
    ///
    /// # Safety
    ///
    /// `data` must be valid for reads of `size` bytes.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copies `size` bytes from the buffer into host memory at `data`.
    ///
    /// # Safety
    ///
    /// `data` must be valid for writes of `size` bytes.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Exposes the implementation as `Any` so callers can downcast to the concrete type.
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Size of the buffer; implementations that do not override this report `0`.
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0
    }

    /// Raw address of the storage; implementations that do not override this report `0`.
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0
    }
}
1108
1109pub(crate) trait GpuKernelImpl: Send + Sync {
1111 fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);
1113
1114 fn set_u32(&self, name: &str, value: u32);
1116
1117 fn set_i32(&self, name: &str, value: i32);
1119
1120 fn set_f32(&self, name: &str, value: f32);
1122
1123 fn set_f64(&self, name: &str, value: f64);
1125
1126 fn dispatch(&self, workgroups: [u32; 3]);
1128
1129 fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
1133 self.dispatch(workgroups);
1135 }
1136
1137 fn try_batch_dispatch(&self, _workgroups: [u32; 3]) -> bool {
1142 false
1143 }
1144}
1145
1146pub(crate) trait GpuCompilerImpl: Send + Sync {
1148 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;
1150
1151 fn compile_typed(
1153 &self,
1154 name: &str,
1155 input_type: std::any::TypeId,
1156 output_type: std::any::TypeId,
1157 ) -> Arc<dyn GpuKernelImpl>;
1158}
1159
1160pub(crate) trait GpuContextImpl: Send + Sync {
1162 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;
1164
1165 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;
1167
1168 fn gpu_sync(&self) -> Result<(), GpuError> {
1170 Ok(()) }
1172
1173 fn begin_batch(&self) -> Result<(), GpuError> {
1178 Ok(()) }
1180
1181 fn end_batch(&self) -> Result<(), GpuError> {
1186 Ok(()) }
1188
1189 fn as_any(&self) -> &dyn std::any::Any
1191 where
1192 Self: 'static + Sized,
1193 {
1194 self
1195 }
1196}
1197
/// Stateless context used for the CPU backend.
struct CpuContext;

impl CpuContext {
    /// Creates the (zero-sized) CPU context.
    fn new() -> Self {
        CpuContext
    }
}
1209
1210impl GpuContextImpl for CpuContext {
1211 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
1212 Arc::new(CpuBuffer::new(size))
1213 }
1214
1215 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
1216 Arc::new(CpuCompiler)
1217 }
1218}
1219
1220struct CpuBuffer {
1222 data: Vec<u8>,
1223}
1224
1225impl CpuBuffer {
1226 fn new(size: usize) -> Self {
1228 Self {
1229 data: vec![0; size],
1230 }
1231 }
1232}
1233
1234impl GpuBufferImpl for CpuBuffer {
1235 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1236 let mut_self = self as *const Self as *mut Self;
1237 let data_ptr = (*mut_self).data.as_mut_ptr();
1238 std::ptr::copy_nonoverlapping(data, data_ptr, size);
1239 }
1240
1241 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1242 let data_ptr = self.data.as_ptr();
1243 std::ptr::copy_nonoverlapping(data_ptr, data, size);
1244 }
1245
1246 fn as_any(&self) -> &dyn std::any::Any {
1247 self
1248 }
1249
1250 fn size(&self) -> usize {
1251 self.data.len()
1252 }
1253
1254 fn device_ptr(&self) -> u64 {
1255 self.data.as_ptr() as u64
1256 }
1257}
1258
1259struct CpuCompiler;
1261
1262impl GpuCompilerImpl for CpuCompiler {
1263 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
1264 Ok(Arc::new(CpuKernel))
1267 }
1268
1269 fn compile_typed(
1270 &self,
1271 _name: &str,
1272 _input_type: std::any::TypeId,
1273 _output_type: std::any::TypeId,
1274 ) -> Arc<dyn GpuKernelImpl> {
1275 Arc::new(CpuKernel)
1278 }
1279}
1280
1281struct CpuKernel;
1283
1284impl GpuKernelImpl for CpuKernel {
1285 fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
1286 }
1288
1289 fn set_u32(&self, _name: &str, value: u32) {
1290 }
1292
1293 fn set_i32(&self, _name: &str, value: i32) {
1294 }
1296
1297 fn set_f32(&self, _name: &str, value: f32) {
1298 }
1300
1301 fn set_f64(&self, _name: &str, value: f64) {
1302 }
1304
1305 fn dispatch(&self, workgroups: [u32; 3]) {
1306 }
1308}
1309
1310#[cfg(test)]
1314mod tests {
1315 use super::*;
1316
1317 #[test]
1318 fn test_gpu_backend_preferred() {
1319 let backend = GpuBackend::preferred();
1320 match backend {
1322 GpuBackend::Cuda
1323 | GpuBackend::Rocm
1324 | GpuBackend::Wgpu
1325 | GpuBackend::Metal
1326 | GpuBackend::OpenCL
1327 | GpuBackend::Cpu => {}
1328 }
1329 }
1330
1331 #[test]
1332 fn test_gpu_backend_default() {
1333 let backend = GpuBackend::default();
1334 assert_eq!(backend, GpuBackend::preferred());
1335 }
1336
1337 #[test]
1338 fn test_gpu_backend_is_available() {
1339 let backend = GpuBackend::Cpu;
1340 assert!(backend.is_available());
1341
1342 assert!(!GpuBackend::Cuda.is_available());
1345
1346 #[cfg(feature = "rocm")]
1347 {
1348 let _ = GpuBackend::Rocm.is_available(); }
1351 #[cfg(not(feature = "rocm"))]
1352 assert!(!GpuBackend::Rocm.is_available());
1353
1354 #[cfg(all(feature = "metal", target_os = "macos"))]
1355 assert!(GpuBackend::Metal.is_available());
1356 #[cfg(not(all(feature = "metal", target_os = "macos")))]
1357 assert!(!GpuBackend::Metal.is_available());
1358 }
1359
1360 #[test]
1361 fn test_gpu_backend_display() {
1362 assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
1363 assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
1364 assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
1365 assert_eq!(GpuBackend::Metal.to_string(), "Metal");
1366 assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
1367 assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
1368 }
1369
1370 #[test]
1371 fn test_gpuerror_from_conversion() {
1372 let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
1373 let coreerror: CoreError = gpuerror.into();
1374 match coreerror {
1375 CoreError::ComputationError(_) => {}
1376 _ => panic!("Expected ComputationError"),
1377 }
1378
1379 let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
1380 let coreerror: CoreError = gpuerror.into();
1381 match coreerror {
1382 CoreError::MemoryError(_) => {}
1383 _ => panic!("Expected MemoryError"),
1384 }
1385
1386 let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
1387 let coreerror: CoreError = gpuerror.into();
1388 match coreerror {
1389 CoreError::InvalidArgument(_) => {}
1390 _ => panic!("Expected InvalidArgument"),
1391 }
1392
1393 let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
1394 let coreerror: CoreError = gpuerror.into();
1395 match coreerror {
1396 CoreError::TypeError(_) => {}
1397 _ => panic!("Expected TypeError"),
1398 }
1399 }
1400
1401 #[test]
1402 fn test_gpu_datatype_trait() {
1403 fn assert_gpu_datatype<T: GpuDataType>() {}
1405
1406 assert_gpu_datatype::<f32>();
1407 assert_gpu_datatype::<f64>();
1408 assert_gpu_datatype::<i32>();
1409 assert_gpu_datatype::<u32>();
1410 assert_gpu_datatype::<u8>();
1411 assert_gpu_datatype::<i8>();
1412 assert_gpu_datatype::<u16>();
1413 assert_gpu_datatype::<i16>();
1414 assert_gpu_datatype::<u64>();
1415 assert_gpu_datatype::<i64>();
1416 }
1417
1418 #[test]
1419 fn test_gpu_buffer_creation() {
1420 let inner = Arc::new(CpuBuffer::new(100));
1421 let buffer = GpuBuffer::<f32>::new(inner, 25);
1422
1423 assert_eq!(buffer.len(), 25);
1424 assert!(!buffer.is_empty());
1425 }
1426
1427 #[test]
1428 fn test_gpu_buffer_empty() {
1429 let inner = Arc::new(CpuBuffer::new(0));
1430 let buffer = GpuBuffer::<f32>::new(inner, 0);
1431
1432 assert_eq!(buffer.len(), 0);
1433 assert!(buffer.is_empty());
1434 }
1435
1436 #[test]
1437 fn test_gpu_buffer_copy_operations() {
1438 let inner = Arc::new(CpuBuffer::new(16));
1439 let buffer = GpuBuffer::<f32>::new(inner, 4);
1440
1441 let data = vec![1.0f32, 2.0, 3.0, 4.0];
1442 let _ = buffer.copy_from_host(&data);
1443
1444 let mut result = vec![0.0f32; 4];
1445 let _ = buffer.copy_to_host(&mut result);
1446
1447 assert_eq!(result, data);
1448 }
1449
1450 #[test]
1451 fn test_gpu_buffer_to_vec() {
1452 let inner = Arc::new(CpuBuffer::new(12));
1453 let buffer = GpuBuffer::<f32>::new(inner, 3);
1454
1455 let data = vec![5.0f32, 6.0, 7.0];
1456 let _ = buffer.copy_from_host(&data);
1457
1458 let result = buffer.to_vec();
1459 assert_eq!(result, data);
1460 }
1461
1462 #[test]
1463 #[should_panic(expected = "Data size exceeds buffer size")]
1464 fn test_gpu_buffer_copy_from_host_overflow() {
1465 let inner = Arc::new(CpuBuffer::new(8));
1466 let buffer = GpuBuffer::<f32>::new(inner, 2);
1467
1468 let data = vec![1.0f32, 2.0, 3.0]; buffer.copy_from_host(&data).expect("Operation failed");
1470 }
1471
1472 #[test]
1473 #[should_panic(expected = "Data size exceeds buffer size")]
1474 fn test_gpu_buffer_copy_to_host_overflow() {
1475 let inner = Arc::new(CpuBuffer::new(8));
1476 let buffer = GpuBuffer::<f32>::new(inner, 2);
1477
1478 let mut data = vec![0.0f32; 3]; buffer.copy_to_host(&mut data).expect("Operation failed");
1480 }
1481
/// Smoke test for `GpuKernelHandle`: binds a buffer and one scalar of each
/// supported kind, then dispatches. The test makes no assertions; it passes
/// as long as none of the calls panic.
#[test]
fn test_gpu_kernel_handle() {
    let handle = GpuKernelHandle::new(Arc::new(CpuKernel));

    let storage = Arc::new(CpuBuffer::new(16));
    let input = GpuBuffer::<f32>::new(storage, 4);

    // Parameter binding: one call per setter.
    handle.set_buffer("input", &input);
    handle.set_u32("size", 100);
    handle.set_i32("offset", -5);
    handle.set_f32("scale", 2.5);
    handle.set_f64("precision", 0.0001);

    // Launch with the work-group counts passed as a 3-element array.
    handle.dispatch([16, 8, 1]);
}
1498
/// A context created for the CPU backend identifies itself as `Cpu` /
/// "CPU" and reports 1 GiB available and 4 GiB total memory.
#[test]
fn test_gpu_context_cpu_backend() {
    let ctx = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

    // Memory figures reported by the CPU backend.
    assert_eq!(ctx.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
    assert_eq!(ctx.get_available_memory(), Some(1024 * 1024 * 1024));

    // Backend identity.
    assert_eq!(ctx.backend_name(), "CPU");
    assert_eq!(ctx.backend(), GpuBackend::Cpu);
}
1509
/// Buffers created through a CPU context have the requested length, and a
/// buffer created from a slice reads back the slice contents.
#[test]
fn test_gpu_context_buffer_creation() {
    let ctx = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

    // Sized allocation.
    let sized = ctx.create_buffer::<f32>(100);
    assert_eq!(sized.len(), 100);

    // Allocation initialised from host data.
    let ones = vec![1.0f32; 50];
    let seeded = ctx.create_buffer_from_slice(&ones);
    assert_eq!(seeded.len(), 50);
    assert_eq!(seeded.to_vec(), ones);
}
1524
/// Requesting a CUDA context must fail with `BackendNotAvailable`, and the
/// error message must mention "oxicuda".
#[test]
fn test_gpu_context_unsupported_backend() {
    // Extract the message, or fail showing whatever came back instead.
    let msg = match GpuContext::new(GpuBackend::Cuda) {
        Err(GpuError::BackendNotAvailable(msg)) => msg,
        other => {
            panic!("expected BackendNotAvailable for retired CUDA backend, got: {other:?}")
        }
    };

    assert!(
        msg.contains("oxicuda"),
        "expected oxicuda guidance, got: {msg}"
    );
}
1542
/// Exercises both `GpuCompiler` entry points with the CPU compiler: compile
/// from source text, and compile a typed kernel by name. Each resulting
/// kernel is dispatched once.
#[test]
fn test_gpu_compiler() {
    let compiler = GpuCompiler::new(Arc::new(CpuCompiler));

    // Source-based compilation returns a `Result` that must be `Ok`.
    let from_source = compiler
        .compile("dummy kernel source")
        .expect("Operation failed");
    from_source.dispatch([1, 1, 1]);

    // Typed compilation by kernel name.
    let vector_add = compiler.compile_kernel::<f32, f32>("vector_add");
    vector_add.dispatch([32, 1, 1]);
}
1558
/// `GpuContext::execute` hands a compiler to the closure and returns the
/// closure's value; here that value is whether compilation succeeded.
#[test]
fn test_gpu_context_execute() {
    let ctx = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

    let compiled_ok = ctx.execute(|compiler| {
        let outcome = compiler.compile("test kernel");
        outcome.is_ok()
    });

    assert!(compiled_ok);
}
1567
/// Looking up a kernel name that was never registered must yield
/// `GpuError::KernelNotFound`.
#[test]
fn test_gpu_context_kernel_registry() {
    let ctx = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

    let lookup = ctx.get_kernel("non_existent_kernel");

    // First require an error of any kind, then require the specific variant.
    assert!(lookup.is_err());
    if !matches!(lookup, Err(GpuError::KernelNotFound(_))) {
        panic!("Expected KernelNotFound error");
    }
}
1580
/// `CpuBuffer::new(256)` yields a `data` field with 256 entries, all zero.
#[test]
fn test_cpu_buffer_implementation() {
    let cpu_buf = CpuBuffer::new(256);

    assert_eq!(cpu_buf.data.len(), 256);

    // No entry may be non-zero.
    let has_nonzero = cpu_buf.data.iter().any(|&byte| byte != 0);
    assert!(!has_nonzero);
}
1589
/// Checks the `Display` text of four `GpuError` variants against their
/// expected strings.
#[test]
fn test_gpuerror_display() {
    // (error value, expected `to_string()` output)
    let cases = [
        (
            GpuError::BackendNotAvailable("CUDA".to_string()),
            "GPU backend CUDA is not available",
        ),
        (
            GpuError::OutOfMemory("allocation failed".to_string()),
            "GPU out of memory: allocation failed",
        ),
        (
            GpuError::KernelCompilationError("syntax error".to_string()),
            "Kernel compilation error: syntax error",
        ),
        (
            GpuError::KernelNotFound("gemm".to_string()),
            "Kernel not found: gemm",
        ),
    ];

    for (error, expected) in cases {
        assert_eq!(error.to_string(), expected);
    }
}
1604
/// `GpuBackend` compares equal to itself, unequal to a different variant,
/// and values duplicated by plain assignment compare equal to the source.
#[test]
fn test_backend_equality() {
    // Two bindings taken from the same value; the source stays usable
    // afterwards because the type is `Copy`.
    let original = GpuBackend::Metal;
    let (first_dup, second_dup) = (original, original);
    assert_eq!(original, first_dup);
    assert_eq!(original, second_dup);

    // Variant-level equality and inequality.
    assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);
    assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
}
1617
/// `GpuBackend` works as a `HashSet` key: inserting `Cuda` twice and `Rocm`
/// once leaves exactly two members.
#[test]
fn test_backend_hash() {
    use std::collections::HashSet;

    // `Cuda` appears twice on purpose; the duplicate must collapse.
    let inserted = [GpuBackend::Cuda, GpuBackend::Rocm, GpuBackend::Cuda];
    let unique: HashSet<GpuBackend> = inserted.iter().copied().collect();

    assert_eq!(unique.len(), 2);
    assert!(unique.contains(&GpuBackend::Cuda));
    assert!(unique.contains(&GpuBackend::Rocm));
}
1631
/// Checks the `Debug` and `Clone` behaviour of `GpuBuffer`:
/// - the debug output contains "GpuBuffer" and "size";
/// - a clone reports the same length as the original;
/// - data written through the original is read back through the clone.
#[test]
fn test_gpu_buffer_debug_clone() {
    let inner = Arc::new(CpuBuffer::new(16));
    let buffer = GpuBuffer::<f32>::new(inner, 4);

    let debug_str = format!("{:?}", buffer);
    assert!(debug_str.contains("GpuBuffer"));
    assert!(debug_str.contains("size"));

    let cloned = buffer.clone();
    assert_eq!(cloned.len(), buffer.len());
    assert_eq!(cloned.len(), 4);

    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    // Do not discard the copy results: a failed transfer should be reported
    // as such, not as a data mismatch in the final comparison.
    buffer
        .copy_from_host(&data)
        .expect("copy_from_host should succeed when sizes match");

    let mut result = vec![0.0f32; 4];
    cloned
        .copy_to_host(&mut result)
        .expect("copy_to_host should succeed when sizes match");
    assert_eq!(result, data);
}
1655
/// The `Debug` output of a CPU `GpuContext` mentions the type name, the
/// `backend` field and the `Cpu` variant.
#[test]
fn test_gpu_context_debug() {
    let ctx = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
    let rendered = format!("{:?}", ctx);

    for fragment in ["GpuContext", "backend", "Cpu"] {
        assert!(rendered.contains(fragment));
    }
}
1666
/// On the CPU backend, `begin_batch`, a kernel compile-and-dispatch, and
/// `end_batch` must each return `Ok` when called in that order.
#[test]
fn test_gpu_context_batch_dispatch() {
    let ctx = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");

    // Open the batch.
    let opened = ctx.begin_batch();
    assert!(
        opened.is_ok(),
        "begin_batch should succeed on CPU backend"
    );

    // Compile and dispatch between begin and end; a compile error is
    // passed through unchanged.
    let dispatched = ctx.execute(|compiler| match compiler.compile("dummy kernel source") {
        Ok(kernel) => {
            kernel.dispatch([4, 1, 1]);
            Ok(())
        }
        Err(err) => Err(err),
    });
    assert!(
        dispatched.is_ok(),
        "kernel dispatch inside batch should succeed"
    );

    // Close the batch.
    let closed = ctx.end_batch();
    assert!(
        closed.is_ok(),
        "end_batch should succeed on CPU backend"
    );
}
1697
/// `gpu_sync` on a CPU context must return `Ok`.
#[test]
fn test_gpu_context_gpu_sync() {
    let ctx = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");

    let synced = ctx.gpu_sync().is_ok();
    assert!(synced, "gpu_sync should return Ok on CPU backend");
}
1707
/// Smoke test for `dispatch_no_wait`: binds a buffer and a size on a CPU
/// kernel handle, then dispatches. The test makes no assertions; it passes
/// as long as none of the calls panic.
#[test]
fn test_gpu_kernel_dispatch_no_wait() {
    let handle = GpuKernelHandle::new(Arc::new(CpuKernel));

    let storage = Arc::new(CpuBuffer::new(16));
    let input = GpuBuffer::<f32>::new(storage, 4);

    handle.set_buffer("input", &input);
    handle.set_u32("size", 4);

    handle.dispatch_no_wait([4, 1, 1]);
}
1722
/// `GpuDevice::get_info` for the CPU backend (device index 0) succeeds and
/// describes a device named and typed "CPU", with a non-empty compute
/// capability, a work-group size of at least 1, and fp64/fp16 support.
#[test]
fn test_gpu_device_get_info_cpu() {
    let cpu_info = GpuDevice::new(GpuBackend::Cpu, 0)
        .get_info()
        .expect("get_info should succeed for the CPU backend");

    // Identity fields.
    assert_eq!(cpu_info.backend, GpuBackend::Cpu);
    assert_eq!(cpu_info.device_name, "CPU");
    assert_eq!(cpu_info.device_type, "CPU");

    // Populated descriptive fields.
    assert!(!cpu_info.device_name.is_empty());
    assert!(!cpu_info.compute_capability.is_empty());

    // Capability fields.
    assert!(cpu_info.max_work_group_size >= 1);
    assert!(cpu_info.supports_fp64);
    assert!(cpu_info.supports_fp16);
}
1743
/// For every `GpuBackend` variant, `get_info` on device index 0 succeeds,
/// returns populated fields, and two consecutive calls agree on the device
/// name, maximum work-group size and fp64 support.
#[test]
fn test_gpu_device_info_is_deterministic_per_backend() {
    let all_backends = [
        GpuBackend::Cpu,
        GpuBackend::Cuda,
        GpuBackend::Rocm,
        GpuBackend::Wgpu,
        GpuBackend::Metal,
        GpuBackend::OpenCL,
    ];

    for backend in all_backends {
        let device = GpuDevice::new(backend, 0);
        let first = device.get_info().expect("get_info should not fail");
        let second = device.get_info().expect("get_info should not fail");

        // The first query must describe the requested backend and be populated.
        assert_eq!(first.backend, backend);
        assert!(!first.device_name.is_empty());
        assert!(!first.device_type.is_empty());
        assert!(!first.compute_capability.is_empty());
        assert!(first.max_work_group_size >= 1);

        // The second query must agree with the first.
        assert_eq!(first.device_name, second.device_name);
        assert_eq!(first.max_work_group_size, second.max_work_group_size);
        assert_eq!(first.supports_fp64, second.supports_fp64);
    }
}
1771}