1use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod async_transfer;
11pub mod auto_tuning;
12pub mod backends;
13pub mod benchmarks;
14mod cpu_ops;
15pub mod device_info;
16pub mod heterogeneous;
17pub mod kernels;
18pub mod memory_management;
19pub mod stream_allocator;
20pub mod tensor_cores;
21
22pub use async_transfer::{
23 AsyncTransferError, AsyncTransferPipeline, TransferDirection, TransferHandle,
24};
25pub use device_info::GpuDeviceInfo;
26pub use memory_management::unified_memory::{SyncState, UnifiedAllocator, UnifiedBuffer};
27pub use stream_allocator::{StreamAllocator, StreamId};
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
31pub enum GpuBackend {
32 Cuda,
34 Rocm,
36 Wgpu,
38 Metal,
40 OpenCL,
42 Cpu,
44}
45
46impl Default for GpuBackend {
47 fn default() -> Self {
48 Self::preferred()
49 }
50}
51
52impl GpuBackend {
53 pub fn preferred() -> Self {
55 match backends::initialize_optimal_backend() {
58 Ok(backend) => {
59 if backend != GpuBackend::Cpu {
61 #[cfg(not(test))]
64 {
65 return GpuBackend::Cpu;
67 }
68 #[cfg(test)]
69 {
70 return backend;
72 }
73 }
74 backend
75 }
76 Err(_) => {
77 GpuBackend::Cpu
79 }
80 }
81 }
82
83 pub fn is_available(&self) -> bool {
85 match self {
86 GpuBackend::Cuda => {
88 #[cfg(feature = "cuda")]
89 {
90 use crate::gpu::backends::cuda::CudaContext;
91 CudaContext::is_available()
92 }
93 #[cfg(not(feature = "cuda"))]
94 {
95 false
96 }
97 }
98 GpuBackend::Rocm => cfg!(feature = "rocm"), GpuBackend::Wgpu => {
100 #[cfg(feature = "wgpu_backend")]
101 {
102 use crate::gpu::backends::wgpu::WebGPUContext;
103 WebGPUContext::is_available()
104 }
105 #[cfg(not(feature = "wgpu_backend"))]
106 {
107 false
108 }
109 }
110 GpuBackend::Metal => {
111 #[cfg(all(feature = "metal", target_os = "macos"))]
112 {
113 true
115 }
116 #[cfg(not(all(feature = "metal", target_os = "macos")))]
117 {
118 false
119 }
120 }
121 GpuBackend::OpenCL => {
122 #[cfg(feature = "opencl")]
123 {
124 use crate::gpu::backends::opencl::OpenCLContext;
125 OpenCLContext::is_available()
126 }
127 #[cfg(not(feature = "opencl"))]
128 {
129 false
130 }
131 }
132 GpuBackend::Cpu => true,
133 }
134 }
135}
136
137impl fmt::Display for GpuBackend {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 match self {
140 GpuBackend::Cuda => write!(f, "CUDA"),
141 GpuBackend::Rocm => write!(f, "ROCm"),
142 GpuBackend::Wgpu => write!(f, "WebGPU"),
143 GpuBackend::Metal => write!(f, "Metal"),
144 GpuBackend::OpenCL => write!(f, "OpenCL"),
145 GpuBackend::Cpu => write!(f, "CPU"),
146 }
147 }
148}
149
150use crate::error::{CoreError, ErrorContext, ErrorLocation};
151
152#[derive(Debug, thiserror::Error)]
154pub enum GpuError {
155 #[error("GPU backend {0} is not available")]
157 BackendNotAvailable(String),
158
159 #[error("GPU backend {0} is not supported")]
161 UnsupportedBackend(GpuBackend),
162
163 #[error("GPU backend {0:?} is not supported for this kernel")]
165 BackendNotSupported(GpuBackend),
166
167 #[error("GPU backend {0} is not implemented yet")]
169 BackendNotImplemented(GpuBackend),
170
171 #[error("GPU out of memory: {0}")]
173 OutOfMemory(String),
174
175 #[error("Kernel compilation error: {0}")]
177 KernelCompilationError(String),
178
179 #[error("Kernel execution error: {0}")]
181 KernelExecutionError(String),
182
183 #[error("Invalid parameter: {0}")]
185 InvalidParameter(String),
186
187 #[error("Kernel not found: {0}")]
189 KernelNotFound(String),
190
191 #[error("Kernel specialization not supported")]
193 SpecializationNotSupported,
194
195 #[error("Unsupported data type: {0:?}")]
197 UnsupportedDataType(kernels::DataType),
198
199 #[error("{0}")]
201 Other(String),
202}
203
204#[derive(Debug, Clone, Copy, PartialEq, Eq)]
206pub struct GpuDevice {
207 backend: GpuBackend,
208 device_id: usize,
209}
210
211impl GpuDevice {
212 pub fn new(backend: GpuBackend, device_id: usize) -> Self {
214 Self { backend, device_id }
215 }
216
217 pub fn backend(&self) -> GpuBackend {
219 self.backend
220 }
221
222 pub fn device_id(&self) -> usize {
224 self.device_id
225 }
226
227 pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
229 Ok(GpuKernel {
231 backend: self.backend,
232 entry_point: entrypoint.to_string(),
233 })
234 }
235
236 pub fn get_info(&self) -> Result<GpuDeviceInfo, GpuError> {
253 Ok(GpuDeviceInfo::for_backend(self.backend))
254 }
255}
256
257pub struct GpuKernel {
259 backend: GpuBackend,
260 entry_point: String,
261}
262
263impl GpuKernel {
264 pub fn backend(&self) -> GpuBackend {
266 self.backend
267 }
268
269 pub fn entry_point(&self) -> &str {
271 &self.entry_point
272 }
273}
274
275impl From<GpuError> for CoreError {
277 fn from(err: GpuError) -> Self {
278 match err {
279 GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
280 ErrorContext::new(format!("GPU backend {backend} is not available"))
281 .with_location(ErrorLocation::new(file!(), line!())),
282 ),
283 GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
284 ErrorContext::new(format!("GPU backend {backend} is not supported"))
285 .with_location(ErrorLocation::new(file!(), line!())),
286 ),
287 GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
288 ErrorContext::new(format!(
289 "GPU backend {backend:?} is not supported for this kernel"
290 ))
291 .with_location(ErrorLocation::new(file!(), line!())),
292 ),
293 GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
294 ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
295 .with_location(ErrorLocation::new(file!(), line!())),
296 ),
297 GpuError::OutOfMemory(details) => CoreError::MemoryError(
298 ErrorContext::new(details.to_string())
299 .with_location(ErrorLocation::new(file!(), line!())),
300 ),
301 GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
302 ErrorContext::new(msg.to_string())
303 .with_location(ErrorLocation::new(file!(), line!())),
304 ),
305 GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
306 ErrorContext::new(msg.to_string())
307 .with_location(ErrorLocation::new(file!(), line!())),
308 ),
309 GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
310 ErrorContext::new(msg.to_string())
311 .with_location(ErrorLocation::new(file!(), line!())),
312 ),
313 GpuError::KernelNotFound(name) => CoreError::ComputationError(
314 ErrorContext::new(name.to_string())
315 .with_location(ErrorLocation::new(file!(), line!())),
316 ),
317 GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
318 ErrorContext::new("Kernel specialization not supported".to_string())
319 .with_location(ErrorLocation::new(file!(), line!())),
320 ),
321 GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
322 ErrorContext::new(format!("{dtype:?}"))
323 .with_location(ErrorLocation::new(file!(), line!())),
324 ),
325 GpuError::Other(msg) => CoreError::ComputationError(
326 ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
327 ),
328 }
329 }
330}
331
332pub trait GpuDataType: Copy + Send + Sync + 'static {}
334
335#[derive(Debug)]
337pub struct GpuPtr<T: GpuDataType> {
338 ptr: u64,
339 size: usize,
340 phantom: PhantomData<T>,
341}
342
343impl<T: GpuDataType> GpuPtr<T> {
344 pub fn allocate(size: usize) -> Result<Self, GpuError> {
346 Ok(GpuPtr {
347 ptr: 0x1000_0000, size,
349 phantom: PhantomData,
350 })
351 }
352
353 pub fn as_ptr(&self) -> u64 {
355 self.ptr
356 }
357
358 pub fn len(&self) -> usize {
360 self.size
361 }
362
363 pub fn is_empty(&self) -> bool {
365 self.size == 0
366 }
367}
368
369#[derive(Debug, Clone)]
371pub enum KernelArg<'a, T: GpuDataType> {
372 Buffer(&'a GpuPtr<T>),
374 Scalar(T),
376}
377
378#[derive(Debug, Clone)]
380pub enum DynamicKernelArg {
381 Buffer(u64), F32(f32),
385 F64(f64),
387 I32(i32),
389 U32(u32),
391 Usize(usize),
393}
394
395pub struct GpuChannel {
397 #[allow(dead_code)]
398 source_device: usize,
399 #[allow(dead_code)]
400 target_device: usize,
401 #[allow(dead_code)]
402 bandwidth: f64, }
404
405impl GpuDataType for f32 {}
407impl GpuDataType for f64 {}
408impl GpuDataType for i32 {}
409impl GpuDataType for u32 {}
410impl GpuDataType for u8 {}
411impl GpuDataType for i8 {}
412impl GpuDataType for u16 {}
413impl GpuDataType for i16 {}
414impl GpuDataType for u64 {}
415impl GpuDataType for i64 {}
416impl GpuDataType for usize {}
417impl GpuDataType for isize {}
418
419pub struct GpuBuffer<T: GpuDataType> {
421 inner: Arc<dyn GpuBufferImpl>,
422 size: usize,
423 phantom: PhantomData<T>,
424}
425
426impl<T: GpuDataType> GpuBuffer<T> {
427 pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
429 Self {
430 inner,
431 size,
432 phantom: PhantomData,
433 }
434 }
435
436 pub fn len(&self) -> usize {
438 self.size
439 }
440
441 pub fn is_empty(&self) -> bool {
443 self.size == 0
444 }
445
446 pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
448 if data.len() > self.size {
449 return Err(GpuError::InvalidParameter(
450 "Data size exceeds buffer size".to_string(),
451 ));
452 }
453 unsafe {
454 self.inner
455 .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
456 }
457 Ok(())
458 }
459
460 pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
462 if data.len() > self.size {
463 return Err(GpuError::InvalidParameter(
464 "Data size exceeds buffer size".to_string(),
465 ));
466 }
467 unsafe {
468 self.inner
469 .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
470 }
471 Ok(())
472 }
473
474 pub fn to_vec(&self) -> Vec<T> {
476 let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
477 let _ = self.copy_to_host(&mut result);
478 result
479 }
480}
481
482impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
483 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
484 f.debug_struct("GpuBuffer")
485 .field("size", &self.size)
486 .finish()
487 }
488}
489
490impl<T: GpuDataType> Clone for GpuBuffer<T> {
491 fn clone(&self) -> Self {
492 Self {
493 inner: Arc::clone(&self.inner),
494 size: self.size,
495 phantom: PhantomData,
496 }
497 }
498}
499
500#[derive(Clone)]
502pub struct GpuKernelHandle {
503 inner: Arc<dyn GpuKernelImpl>,
504}
505
506impl GpuKernelHandle {
507 pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
509 Self { inner }
510 }
511
512 pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
514 self.inner.set_buffer(name, &buffer.inner);
515 }
516
517 pub fn set_u32(&self, name: &str, value: u32) {
519 self.inner.set_u32(name, value);
520 }
521
522 pub fn set_i32(&self, name: &str, value: i32) {
524 self.inner.set_i32(name, value);
525 }
526
527 pub fn set_f32(&self, name: &str, value: f32) {
529 self.inner.set_f32(name, value);
530 }
531
532 pub fn set_f64(&self, name: &str, value: f64) {
534 self.inner.set_f64(name, value);
535 }
536
537 pub fn dispatch(&self, workgroups: [u32; 3]) {
543 if !self.inner.try_batch_dispatch(workgroups) {
544 self.inner.dispatch(workgroups);
545 }
546 }
547
548 pub fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
553 self.inner.dispatch_no_wait(workgroups);
554 }
555}
556
557pub struct GpuCompiler {
559 inner: Arc<dyn GpuCompilerImpl>,
560}
561
562impl GpuCompiler {
563 pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
565 Self { inner }
566 }
567
568 pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
570 let kernel = self.inner.compile(source)?;
571 Ok(GpuKernelHandle::new(kernel))
572 }
573
574 pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
576 let kernel = self.inner.compile_typed(
577 name,
578 std::any::TypeId::of::<I>(),
579 std::any::TypeId::of::<O>(),
580 );
581 GpuKernelHandle::new(kernel)
582 }
583}
584
585pub struct GpuContext {
587 inner: Arc<dyn GpuContextImpl>,
588 backend: GpuBackend,
589 kernel_registry: kernels::KernelRegistry,
590}
591
592impl GpuContext {
593 pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
595 if !backend.is_available() {
597 return Err(GpuError::BackendNotAvailable(backend.to_string()));
598 }
599
600 if backend != GpuBackend::Cpu {
602 let detection_result = backends::detect_gpu_backends();
603 let backend_available = detection_result
604 .devices
605 .iter()
606 .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);
607
608 if !backend_available {
609 return Err(GpuError::BackendNotAvailable(format!(
610 "{backend} (no devices detected at runtime)"
611 )));
612 }
613 }
614
615 let inner = match backend {
616 GpuBackend::Cuda => {
617 #[cfg(feature = "cuda")]
618 {
619 use crate::gpu::backends::cuda::CudaContext;
620 match CudaContext::new() {
621 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
622 Err(e) => return Err(e),
623 }
624 }
625 #[cfg(not(feature = "cuda"))]
626 {
627 return Err(GpuError::UnsupportedBackend(backend));
628 }
629 }
630 GpuBackend::Rocm => {
631 #[cfg(feature = "rocm")]
632 {
633 #[cfg(test)]
636 {
637 Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
639 }
640 #[cfg(not(test))]
641 {
642 return Err(GpuError::BackendNotImplemented(backend));
643 }
644 }
645 #[cfg(not(feature = "rocm"))]
646 {
647 return Err(GpuError::UnsupportedBackend(backend));
648 }
649 }
650 GpuBackend::Wgpu => {
651 #[cfg(feature = "wgpu_backend")]
652 {
653 use crate::gpu::backends::wgpu::WebGPUContext;
654 match WebGPUContext::new() {
655 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
656 Err(e) => return Err(e),
657 }
658 }
659 #[cfg(not(feature = "wgpu_backend"))]
660 {
661 return Err(GpuError::UnsupportedBackend(backend));
662 }
663 }
664 GpuBackend::Metal => {
665 #[cfg(all(feature = "metal", target_os = "macos"))]
666 {
667 use crate::gpu::backends::metal::MetalContext;
668 match MetalContext::new() {
669 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
670 Err(e) => return Err(e),
671 }
672 }
673 #[cfg(not(all(feature = "metal", target_os = "macos")))]
674 {
675 return Err(GpuError::UnsupportedBackend(backend));
676 }
677 }
678 GpuBackend::OpenCL => {
679 #[cfg(feature = "opencl")]
680 {
681 use crate::gpu::backends::opencl::OpenCLContext;
682 match OpenCLContext::new() {
683 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
684 Err(e) => return Err(e),
685 }
686 }
687 #[cfg(not(feature = "opencl"))]
688 {
689 return Err(GpuError::UnsupportedBackend(backend));
690 }
691 }
692 GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
693 };
694
695 Ok(Self {
696 inner,
697 backend,
698 kernel_registry: kernels::KernelRegistry::with_default_kernels(),
699 })
700 }
701
702 pub fn backend(&self) -> GpuBackend {
704 self.backend
705 }
706
707 pub fn backend_name(&self) -> &str {
709 match self.backend {
710 GpuBackend::Cuda => "CUDA",
711 GpuBackend::Rocm => "ROCm",
712 GpuBackend::Wgpu => "WebGPU",
713 GpuBackend::Metal => "Metal",
714 GpuBackend::OpenCL => "OpenCL",
715 GpuBackend::Cpu => "CPU",
716 }
717 }
718
719 pub fn gpu_sync(&self) -> Result<(), GpuError> {
724 self.inner.gpu_sync()
725 }
726
727 pub fn begin_batch(&self) -> Result<(), GpuError> {
733 self.inner.begin_batch()
734 }
735
736 pub fn end_batch(&self) -> Result<(), GpuError> {
742 self.inner.end_batch()
743 }
744
745 pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
746 let byte_size = size.saturating_mul(std::mem::size_of::<T>());
747 let inner = self.inner.create_buffer(byte_size);
748 GpuBuffer::new(inner, size)
749 }
750
751 pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
753 let buffer = self.create_buffer::<T>(data.len());
754 let _ = buffer.copy_from_host(data);
755 buffer
756 }
757
758 pub fn execute<F, R>(&self, f: F) -> R
760 where
761 F: FnOnce(&GpuCompiler) -> R,
762 {
763 let compiler = GpuCompiler::new(self.inner.create_compiler());
764 f(&compiler)
765 }
766
767 pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
769 let kernel = self
770 .kernel_registry
771 .get(name)
772 .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;
773
774 let kernel_source = kernel.source_for_backend(self.backend)?;
775 let metadata = kernel.metadata();
776
777 let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
778 Ok(handle)
779 }
780
781 pub fn get_specialized_kernel(
783 &self,
784 name: &str,
785 params: &kernels::KernelParams,
786 ) -> Result<GpuKernelHandle, GpuError> {
787 let specialized = self.kernel_registry.get_specialized(name, params)?;
788 let kernel_source = specialized.source_for_backend(self.backend)?;
789 let metadata = specialized.metadata();
790
791 let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
792 Ok(handle)
793 }
794
795 fn compile_kernel_with_metadata(
797 &self,
798 source: &str,
799 _metadata: &kernels::KernelMetadata,
800 ) -> Result<GpuKernelHandle, GpuError> {
801 self.execute(|compiler| compiler.compile(source))
802 }
803
804 pub fn get_available_memory(&self) -> Option<usize> {
806 Some(1024 * 1024 * 1024) }
810
811 pub fn get_total_memory(&self) -> Option<usize> {
813 #[cfg(target_arch = "wasm32")]
816 return Some(512 * 1024 * 1024); #[cfg(not(target_arch = "wasm32"))]
819 Some((4u64 * 1024 * 1024 * 1024) as usize) }
821
822 pub fn launch_kernel(
824 &self,
825 kernel_name: &str,
826 grid_size: (usize, usize, usize),
827 block_size: (usize, usize, usize),
828 args: &[DynamicKernelArg],
829 ) -> Result<(), GpuError> {
830 let _ = (kernel_name, grid_size, block_size, args);
832 Ok(())
833 }
834
835 pub fn transfer_async_host_to_device<T: GpuDataType>(
837 &self,
838 ptr: &GpuPtr<T>,
839 data: &[T],
840 ) -> Result<(), GpuError> {
841 let _ = (ptr, data);
843 Ok(())
844 }
845
846 pub fn transfer_host_to_device<T: GpuDataType>(
848 &self,
849 ptr: &GpuPtr<T>,
850 data: &[T],
851 ) -> Result<(), GpuError> {
852 let _ = (ptr, data);
854 Ok(())
855 }
856
857 pub fn transfer_async_device_to_host<T: GpuDataType>(
859 &self,
860 ptr: &GpuPtr<T>,
861 data: &mut [T],
862 ) -> Result<(), GpuError> {
863 let _ = (ptr, data);
865 Ok(())
866 }
867
868 pub fn transfer_device_to_host<T: GpuDataType>(
870 &self,
871 ptr: &GpuPtr<T>,
872 data: &mut [T],
873 ) -> Result<(), GpuError> {
874 let _ = (ptr, data);
876 Ok(())
877 }
878
879 pub fn execute_kernel(
882 &self,
883 source: &str,
884 buffers: &[GpuBuffer<f32>],
885 work_groups: (u32, u32, u32),
886 int_params: &[u32],
887 float_params: &[f32],
888 ) -> Result<(), GpuError> {
889 eprintln!(
892 "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
893 source.len(),
894 buffers.len(),
895 work_groups
896 );
897 eprintln!("Int params: {int_params:?}");
898 eprintln!("Float params: {float_params:?}");
899 Ok(())
900 }
901
902 pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
905 Ok(buffer.to_vec())
906 }
907
908 pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
910 self.sum_all_cpu_fallback(buffer)
911 }
912
913 pub fn mean_all<T: GpuDataType>(
915 &self,
916 buffer: &GpuBuffer<T>,
917 ) -> Result<GpuBuffer<T>, GpuError> {
918 self.mean_all_cpu_fallback(buffer)
919 }
920
921 pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
923 self.max_all_cpu_fallback(buffer)
924 }
925
926 pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
928 self.min_all_cpu_fallback(buffer)
929 }
930
931 pub fn sum_axis<T: GpuDataType>(
933 &self,
934 buffer: &GpuBuffer<T>,
935 shape: &[usize],
936 axis: usize,
937 ) -> Result<GpuBuffer<T>, GpuError> {
938 self.sum_axis_cpu_fallback(buffer, shape, axis)
939 }
940
941 pub fn mean_axis<T: GpuDataType>(
943 &self,
944 buffer: &GpuBuffer<T>,
945 shape: &[usize],
946 axis: usize,
947 ) -> Result<GpuBuffer<T>, GpuError> {
948 self.mean_axis_cpu_fallback(buffer, shape, axis)
949 }
950
951 pub fn max_axis<T: GpuDataType>(
953 &self,
954 buffer: &GpuBuffer<T>,
955 shape: &[usize],
956 axis: usize,
957 ) -> Result<GpuBuffer<T>, GpuError> {
958 self.max_axis_cpu_fallback(buffer, shape, axis)
959 }
960
961 pub fn min_axis<T: GpuDataType>(
963 &self,
964 buffer: &GpuBuffer<T>,
965 shape: &[usize],
966 axis: usize,
967 ) -> Result<GpuBuffer<T>, GpuError> {
968 self.min_axis_cpu_fallback(buffer, shape, axis)
969 }
970
971 pub fn broadcast<T: GpuDataType>(
973 &self,
974 buffer: &GpuBuffer<T>,
975 from_shape: &[usize],
976 to_shape: &[usize],
977 ) -> Result<GpuBuffer<T>, GpuError> {
978 self.broadcast_cpu_fallback(buffer, from_shape, to_shape)
979 }
980
981 pub fn scale<T: GpuDataType>(
983 &self,
984 buffer: &GpuBuffer<T>,
985 scalar: T,
986 ) -> Result<GpuBuffer<T>, GpuError> {
987 self.scale_cpu_fallback(buffer, scalar)
988 }
989
990 pub fn gemm<T: GpuDataType>(
992 &self,
993 a: &GpuBuffer<T>,
994 b: &GpuBuffer<T>,
995 m: usize,
996 k: usize,
997 n: usize,
998 ) -> Result<GpuBuffer<T>, GpuError> {
999 self.gemm_cpu_fallback(a, b, m, k, n)
1000 }
1001
1002 pub fn gemm_transpose_b<T: GpuDataType>(
1004 &self,
1005 a: &GpuBuffer<T>,
1006 b: &GpuBuffer<T>,
1007 m: usize,
1008 k: usize,
1009 n: usize,
1010 ) -> Result<GpuBuffer<T>, GpuError> {
1011 self.gemm_transpose_b_cpu_fallback(a, b, m, k, n)
1012 }
1013
1014 pub fn gemm_transpose_a<T: GpuDataType>(
1016 &self,
1017 a: &GpuBuffer<T>,
1018 b: &GpuBuffer<T>,
1019 m: usize,
1020 k: usize,
1021 n: usize,
1022 ) -> Result<GpuBuffer<T>, GpuError> {
1023 self.gemm_transpose_a_cpu_fallback(a, b, m, k, n)
1024 }
1025
1026 pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1028 self.relu_cpu_fallback(input)
1029 }
1030
1031 pub fn relu_backward<T: GpuDataType>(
1033 &self,
1034 grad_output: &GpuBuffer<T>,
1035 input: &GpuBuffer<T>,
1036 ) -> Result<GpuBuffer<T>, GpuError> {
1037 self.relu_backward_cpu_fallback(grad_output, input)
1038 }
1039
1040 pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1042 self.sigmoid_cpu_fallback(input)
1043 }
1044
1045 pub fn sigmoid_backward<T: GpuDataType>(
1047 &self,
1048 grad_output: &GpuBuffer<T>,
1049 input: &GpuBuffer<T>,
1050 ) -> Result<GpuBuffer<T>, GpuError> {
1051 self.sigmoid_backward_cpu_fallback(grad_output, input)
1052 }
1053
1054 pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1056 self.tanh_cpu_fallback(input)
1057 }
1058
1059 pub fn tanh_backward<T: GpuDataType>(
1061 &self,
1062 grad_output: &GpuBuffer<T>,
1063 input: &GpuBuffer<T>,
1064 ) -> Result<GpuBuffer<T>, GpuError> {
1065 self.tanh_backward_cpu_fallback(grad_output, input)
1066 }
1067
1068 pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1070 self.gelu_cpu_fallback(input)
1071 }
1072
1073 pub fn gelu_backward<T: GpuDataType>(
1075 &self,
1076 grad_output: &GpuBuffer<T>,
1077 input: &GpuBuffer<T>,
1078 ) -> Result<GpuBuffer<T>, GpuError> {
1079 self.gelu_backward_cpu_fallback(grad_output, input)
1080 }
1081}
1082
1083impl fmt::Debug for GpuContext {
1084 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1085 f.debug_struct("GpuContext")
1086 .field("backend", &self.backend)
1087 .finish()
1088 }
1089}
1090
1091pub(crate) trait GpuBufferImpl: Send + Sync {
1096 unsafe fn copy_from_host(&self, data: *const u8, size: usize);
1098
1099 unsafe fn copy_to_host(&self, data: *mut u8, size: usize);
1101
1102 #[allow(dead_code)]
1104 fn as_any(&self) -> &dyn std::any::Any;
1105
1106 #[allow(dead_code)]
1108 fn size(&self) -> usize {
1109 0 }
1111
1112 #[allow(dead_code)]
1114 fn device_ptr(&self) -> u64 {
1115 0 }
1117}
1118
1119pub(crate) trait GpuKernelImpl: Send + Sync {
1121 fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);
1123
1124 fn set_u32(&self, name: &str, value: u32);
1126
1127 fn set_i32(&self, name: &str, value: i32);
1129
1130 fn set_f32(&self, name: &str, value: f32);
1132
1133 fn set_f64(&self, name: &str, value: f64);
1135
1136 fn dispatch(&self, workgroups: [u32; 3]);
1138
1139 fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
1143 self.dispatch(workgroups);
1145 }
1146
1147 fn try_batch_dispatch(&self, _workgroups: [u32; 3]) -> bool {
1152 false
1153 }
1154}
1155
1156pub(crate) trait GpuCompilerImpl: Send + Sync {
1158 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;
1160
1161 fn compile_typed(
1163 &self,
1164 name: &str,
1165 input_type: std::any::TypeId,
1166 output_type: std::any::TypeId,
1167 ) -> Arc<dyn GpuKernelImpl>;
1168}
1169
1170pub(crate) trait GpuContextImpl: Send + Sync {
1172 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;
1174
1175 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;
1177
1178 fn gpu_sync(&self) -> Result<(), GpuError> {
1180 Ok(()) }
1182
1183 fn begin_batch(&self) -> Result<(), GpuError> {
1188 Ok(()) }
1190
1191 fn end_batch(&self) -> Result<(), GpuError> {
1196 Ok(()) }
1198
1199 fn as_any(&self) -> &dyn std::any::Any
1201 where
1202 Self: 'static + Sized,
1203 {
1204 self
1205 }
1206}
1207
1208struct CpuContext;
1212
1213impl CpuContext {
1214 fn new() -> Self {
1216 Self
1217 }
1218}
1219
1220impl GpuContextImpl for CpuContext {
1221 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
1222 Arc::new(CpuBuffer::new(size))
1223 }
1224
1225 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
1226 Arc::new(CpuCompiler)
1227 }
1228}
1229
1230struct CpuBuffer {
1232 data: Vec<u8>,
1233}
1234
1235impl CpuBuffer {
1236 fn new(size: usize) -> Self {
1238 Self {
1239 data: vec![0; size],
1240 }
1241 }
1242}
1243
1244impl GpuBufferImpl for CpuBuffer {
1245 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1246 let mut_self = self as *const Self as *mut Self;
1247 let data_ptr = (*mut_self).data.as_mut_ptr();
1248 std::ptr::copy_nonoverlapping(data, data_ptr, size);
1249 }
1250
1251 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1252 let data_ptr = self.data.as_ptr();
1253 std::ptr::copy_nonoverlapping(data_ptr, data, size);
1254 }
1255
1256 fn as_any(&self) -> &dyn std::any::Any {
1257 self
1258 }
1259
1260 fn size(&self) -> usize {
1261 self.data.len()
1262 }
1263
1264 fn device_ptr(&self) -> u64 {
1265 self.data.as_ptr() as u64
1266 }
1267}
1268
1269struct CpuCompiler;
1271
1272impl GpuCompilerImpl for CpuCompiler {
1273 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
1274 Ok(Arc::new(CpuKernel))
1277 }
1278
1279 fn compile_typed(
1280 &self,
1281 _name: &str,
1282 _input_type: std::any::TypeId,
1283 _output_type: std::any::TypeId,
1284 ) -> Arc<dyn GpuKernelImpl> {
1285 Arc::new(CpuKernel)
1288 }
1289}
1290
1291struct CpuKernel;
1293
1294impl GpuKernelImpl for CpuKernel {
1295 fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
1296 }
1298
1299 fn set_u32(&self, _name: &str, value: u32) {
1300 }
1302
1303 fn set_i32(&self, _name: &str, value: i32) {
1304 }
1306
1307 fn set_f32(&self, _name: &str, value: f32) {
1308 }
1310
1311 fn set_f64(&self, _name: &str, value: f64) {
1312 }
1314
1315 fn dispatch(&self, workgroups: [u32; 3]) {
1316 }
1318}
1319
1320#[cfg(test)]
1324mod tests {
1325 use super::*;
1326
1327 #[test]
1328 fn test_gpu_backend_preferred() {
1329 let backend = GpuBackend::preferred();
1330 match backend {
1332 GpuBackend::Cuda
1333 | GpuBackend::Rocm
1334 | GpuBackend::Wgpu
1335 | GpuBackend::Metal
1336 | GpuBackend::OpenCL
1337 | GpuBackend::Cpu => {}
1338 }
1339 }
1340
1341 #[test]
1342 fn test_gpu_backend_default() {
1343 let backend = GpuBackend::default();
1344 assert_eq!(backend, GpuBackend::preferred());
1345 }
1346
1347 #[test]
1348 fn test_gpu_backend_is_available() {
1349 let backend = GpuBackend::Cpu;
1350 assert!(backend.is_available());
1351
1352 #[cfg(feature = "cuda")]
1354 {
1355 let _ = GpuBackend::Cuda.is_available(); }
1358 #[cfg(not(feature = "cuda"))]
1359 assert!(!GpuBackend::Cuda.is_available());
1360
1361 #[cfg(feature = "rocm")]
1362 {
1363 let _ = GpuBackend::Rocm.is_available(); }
1366 #[cfg(not(feature = "rocm"))]
1367 assert!(!GpuBackend::Rocm.is_available());
1368
1369 #[cfg(all(feature = "metal", target_os = "macos"))]
1370 assert!(GpuBackend::Metal.is_available());
1371 #[cfg(not(all(feature = "metal", target_os = "macos")))]
1372 assert!(!GpuBackend::Metal.is_available());
1373 }
1374
1375 #[test]
1376 fn test_gpu_backend_display() {
1377 assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
1378 assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
1379 assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
1380 assert_eq!(GpuBackend::Metal.to_string(), "Metal");
1381 assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
1382 assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
1383 }
1384
1385 #[test]
1386 fn test_gpuerror_from_conversion() {
1387 let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
1388 let coreerror: CoreError = gpuerror.into();
1389 match coreerror {
1390 CoreError::ComputationError(_) => {}
1391 _ => panic!("Expected ComputationError"),
1392 }
1393
1394 let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
1395 let coreerror: CoreError = gpuerror.into();
1396 match coreerror {
1397 CoreError::MemoryError(_) => {}
1398 _ => panic!("Expected MemoryError"),
1399 }
1400
1401 let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
1402 let coreerror: CoreError = gpuerror.into();
1403 match coreerror {
1404 CoreError::InvalidArgument(_) => {}
1405 _ => panic!("Expected InvalidArgument"),
1406 }
1407
1408 let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
1409 let coreerror: CoreError = gpuerror.into();
1410 match coreerror {
1411 CoreError::TypeError(_) => {}
1412 _ => panic!("Expected TypeError"),
1413 }
1414 }
1415
1416 #[test]
1417 fn test_gpu_datatype_trait() {
1418 fn assert_gpu_datatype<T: GpuDataType>() {}
1420
1421 assert_gpu_datatype::<f32>();
1422 assert_gpu_datatype::<f64>();
1423 assert_gpu_datatype::<i32>();
1424 assert_gpu_datatype::<u32>();
1425 assert_gpu_datatype::<u8>();
1426 assert_gpu_datatype::<i8>();
1427 assert_gpu_datatype::<u16>();
1428 assert_gpu_datatype::<i16>();
1429 assert_gpu_datatype::<u64>();
1430 assert_gpu_datatype::<i64>();
1431 }
1432
1433 #[test]
1434 fn test_gpu_buffer_creation() {
1435 let inner = Arc::new(CpuBuffer::new(100));
1436 let buffer = GpuBuffer::<f32>::new(inner, 25);
1437
1438 assert_eq!(buffer.len(), 25);
1439 assert!(!buffer.is_empty());
1440 }
1441
1442 #[test]
1443 fn test_gpu_buffer_empty() {
1444 let inner = Arc::new(CpuBuffer::new(0));
1445 let buffer = GpuBuffer::<f32>::new(inner, 0);
1446
1447 assert_eq!(buffer.len(), 0);
1448 assert!(buffer.is_empty());
1449 }
1450
1451 #[test]
1452 fn test_gpu_buffer_copy_operations() {
1453 let inner = Arc::new(CpuBuffer::new(16));
1454 let buffer = GpuBuffer::<f32>::new(inner, 4);
1455
1456 let data = vec![1.0f32, 2.0, 3.0, 4.0];
1457 let _ = buffer.copy_from_host(&data);
1458
1459 let mut result = vec![0.0f32; 4];
1460 let _ = buffer.copy_to_host(&mut result);
1461
1462 assert_eq!(result, data);
1463 }
1464
1465 #[test]
1466 fn test_gpu_buffer_to_vec() {
1467 let inner = Arc::new(CpuBuffer::new(12));
1468 let buffer = GpuBuffer::<f32>::new(inner, 3);
1469
1470 let data = vec![5.0f32, 6.0, 7.0];
1471 let _ = buffer.copy_from_host(&data);
1472
1473 let result = buffer.to_vec();
1474 assert_eq!(result, data);
1475 }
1476
1477 #[test]
1478 #[should_panic(expected = "Data size exceeds buffer size")]
1479 fn test_gpu_buffer_copy_from_host_overflow() {
1480 let inner = Arc::new(CpuBuffer::new(8));
1481 let buffer = GpuBuffer::<f32>::new(inner, 2);
1482
1483 let data = vec![1.0f32, 2.0, 3.0]; buffer.copy_from_host(&data).expect("Operation failed");
1485 }
1486
1487 #[test]
1488 #[should_panic(expected = "Data size exceeds buffer size")]
1489 fn test_gpu_buffer_copy_to_host_overflow() {
1490 let inner = Arc::new(CpuBuffer::new(8));
1491 let buffer = GpuBuffer::<f32>::new(inner, 2);
1492
1493 let mut data = vec![0.0f32; 3]; buffer.copy_to_host(&mut data).expect("Operation failed");
1495 }
1496
1497 #[test]
1498 fn test_gpu_kernel_handle() {
1499 let kernel = Arc::new(CpuKernel);
1500 let handle = GpuKernelHandle::new(kernel);
1501
1502 let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
1504 handle.set_buffer("input", &buffer);
1505 handle.set_u32("size", 100);
1506 handle.set_i32("offset", -5);
1507 handle.set_f32("scale", 2.5);
1508 handle.set_f64("precision", 0.0001);
1509
1510 handle.dispatch([16, 8, 1]);
1512 }
1513
1514 #[test]
1515 fn test_gpu_context_cpu_backend() {
1516 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1517 assert_eq!(context.backend(), GpuBackend::Cpu);
1518 assert_eq!(context.backend_name(), "CPU");
1519
1520 assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
1522 assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
1523 }
1524
1525 #[test]
1526 fn test_gpu_context_buffer_creation() {
1527 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1528
1529 let buffer = context.create_buffer::<f32>(100);
1530 assert_eq!(buffer.len(), 100);
1531
1532 let data = vec![1.0f32; 50];
1533 let buffer_from_slice = context.create_buffer_from_slice(&data);
1534 assert_eq!(buffer_from_slice.len(), 50);
1535
1536 let result = buffer_from_slice.to_vec();
1537 assert_eq!(result, data);
1538 }
1539
1540 #[test]
1541 fn test_gpu_context_unsupported_backend() {
1542 #[cfg(not(feature = "cuda"))]
1544 {
1545 let result = GpuContext::new(GpuBackend::Cuda);
1546 assert!(result.is_err());
1547 match result {
1548 Err(GpuError::UnsupportedBackend(_)) => {}
1549 Err(GpuError::BackendNotAvailable(_)) => {} Err(e) => panic!(
1551 "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
1552 e
1553 ),
1554 Ok(_) => panic!("Expected error, got Ok"),
1555 }
1556 }
1557 }
1558
1559 #[test]
1560 fn test_gpu_compiler() {
1561 let compiler_impl = Arc::new(CpuCompiler);
1562 let compiler = GpuCompiler::new(compiler_impl);
1563
1564 let kernel = compiler
1566 .compile("dummy kernel source")
1567 .expect("Operation failed");
1568 kernel.dispatch([1, 1, 1]);
1569
1570 let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
1572 typed_kernel.dispatch([32, 1, 1]);
1573 }
1574
1575 #[test]
1576 fn test_gpu_context_execute() {
1577 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1578
1579 let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());
1580
1581 assert!(result);
1582 }
1583
1584 #[test]
1585 fn test_gpu_context_kernel_registry() {
1586 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1587
1588 let result = context.get_kernel("non_existent_kernel");
1590 assert!(result.is_err());
1591 match result {
1592 Err(GpuError::KernelNotFound(_)) => {}
1593 _ => panic!("Expected KernelNotFound error"),
1594 }
1595 }
1596
1597 #[test]
1598 fn test_cpu_buffer_implementation() {
1599 let buffer = CpuBuffer::new(256);
1600 assert_eq!(buffer.data.len(), 256);
1601
1602 assert!(buffer.data.iter().all(|&b| b == 0));
1604 }
1605
1606 #[test]
1607 fn test_gpuerror_display() {
1608 let error = GpuError::BackendNotAvailable("CUDA".to_string());
1609 assert_eq!(error.to_string(), "GPU backend CUDA is not available");
1610
1611 let error = GpuError::OutOfMemory("allocation failed".to_string());
1612 assert_eq!(error.to_string(), "GPU out of memory: allocation failed");
1613
1614 let error = GpuError::KernelCompilationError("syntax error".to_string());
1615 assert_eq!(error.to_string(), "Kernel compilation error: syntax error");
1616
1617 let error = GpuError::KernelNotFound("gemm".to_string());
1618 assert_eq!(error.to_string(), "Kernel not found: gemm");
1619 }
1620
1621 #[test]
1622 fn test_backend_equality() {
1623 assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
1624 assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);
1625
1626 let backend = GpuBackend::Metal;
1628 let cloned = backend;
1629 let copied = backend;
1630 assert_eq!(backend, cloned);
1631 assert_eq!(backend, copied);
1632 }
1633
1634 #[test]
1635 fn test_backend_hash() {
1636 use std::collections::HashSet;
1637
1638 let mut set = HashSet::new();
1639 set.insert(GpuBackend::Cuda);
1640 set.insert(GpuBackend::Rocm);
1641 set.insert(GpuBackend::Cuda); assert_eq!(set.len(), 2); assert!(set.contains(&GpuBackend::Cuda));
1645 assert!(set.contains(&GpuBackend::Rocm));
1646 }
1647
1648 #[test]
1649 fn test_gpu_buffer_debug_clone() {
1650 let inner = Arc::new(CpuBuffer::new(16));
1651 let buffer = GpuBuffer::<f32>::new(inner, 4);
1652
1653 let debug_str = format!("{:?}", buffer);
1655 assert!(debug_str.contains("GpuBuffer"));
1656 assert!(debug_str.contains("size"));
1657
1658 let cloned = buffer.clone();
1660 assert_eq!(cloned.len(), buffer.len());
1661 assert_eq!(cloned.len(), 4);
1662
1663 let data = vec![1.0f32, 2.0, 3.0, 4.0];
1665 let _ = buffer.copy_from_host(&data);
1666
1667 let mut result = vec![0.0f32; 4];
1668 let _ = cloned.copy_to_host(&mut result);
1669 assert_eq!(result, data);
1670 }
1671
1672 #[test]
1673 fn test_gpu_context_debug() {
1674 let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
1675
1676 let debug_str = format!("{:?}", context);
1678 assert!(debug_str.contains("GpuContext"));
1679 assert!(debug_str.contains("backend"));
1680 assert!(debug_str.contains("Cpu"));
1681 }
1682
1683 #[test]
1684 fn test_gpu_context_batch_dispatch() {
1685 let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");
1687
1688 let begin_result = context.begin_batch();
1690 assert!(
1691 begin_result.is_ok(),
1692 "begin_batch should succeed on CPU backend"
1693 );
1694
1695 let dispatch_result = context.execute(|compiler| {
1697 compiler.compile("dummy kernel source").map(|kernel| {
1698 kernel.dispatch([4, 1, 1]);
1699 })
1700 });
1701 assert!(
1702 dispatch_result.is_ok(),
1703 "kernel dispatch inside batch should succeed"
1704 );
1705
1706 let end_result = context.end_batch();
1708 assert!(
1709 end_result.is_ok(),
1710 "end_batch should succeed on CPU backend"
1711 );
1712 }
1713
1714 #[test]
1715 fn test_gpu_context_gpu_sync() {
1716 let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");
1718
1719 let result = context.gpu_sync();
1721 assert!(result.is_ok(), "gpu_sync should return Ok on CPU backend");
1722 }
1723
1724 #[test]
1725 fn test_gpu_kernel_dispatch_no_wait() {
1726 let kernel = Arc::new(CpuKernel);
1728 let handle = GpuKernelHandle::new(kernel);
1729
1730 let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
1732 handle.set_buffer("input", &buffer);
1733 handle.set_u32("size", 4);
1734
1735 handle.dispatch_no_wait([4, 1, 1]);
1737 }
1738
1739 #[test]
1740 fn test_gpu_device_get_info_cpu() {
1741 let device = GpuDevice::new(GpuBackend::Cpu, 0);
1742 let info = device
1743 .get_info()
1744 .expect("get_info should succeed for the CPU backend");
1745
1746 assert_eq!(info.backend, GpuBackend::Cpu);
1748 assert_eq!(info.device_name, "CPU");
1749 assert_eq!(info.device_type, "CPU");
1750
1751 assert!(!info.device_name.is_empty());
1754 assert!(!info.compute_capability.is_empty());
1755 assert!(info.max_work_group_size >= 1);
1756 assert!(info.supports_fp64);
1757 assert!(info.supports_fp16);
1758 }
1759
1760 #[test]
1761 fn test_gpu_device_info_is_deterministic_per_backend() {
1762 for backend in [
1764 GpuBackend::Cpu,
1765 GpuBackend::Cuda,
1766 GpuBackend::Rocm,
1767 GpuBackend::Wgpu,
1768 GpuBackend::Metal,
1769 GpuBackend::OpenCL,
1770 ] {
1771 let device = GpuDevice::new(backend, 0);
1772 let info = device.get_info().expect("get_info should not fail");
1773 let info_again = device.get_info().expect("get_info should not fail");
1774
1775 assert_eq!(info.backend, backend);
1776 assert!(!info.device_name.is_empty());
1777 assert!(!info.device_type.is_empty());
1778 assert!(!info.compute_capability.is_empty());
1779 assert!(info.max_work_group_size >= 1);
1780
1781 assert_eq!(info.device_name, info_again.device_name);
1783 assert_eq!(info.max_work_group_size, info_again.max_work_group_size);
1784 assert_eq!(info.supports_fp64, info_again.supports_fp64);
1785 }
1786 }
1787}