1use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod async_transfer;
11pub mod auto_tuning;
12pub mod backends;
13pub mod benchmarks;
14mod cpu_ops;
15pub mod heterogeneous;
16pub mod kernels;
17pub mod memory_management;
18pub mod stream_allocator;
19pub mod tensor_cores;
20
21pub use async_transfer::{
22 AsyncTransferError, AsyncTransferPipeline, TransferDirection, TransferHandle,
23};
24pub use memory_management::unified_memory::{SyncState, UnifiedAllocator, UnifiedBuffer};
25pub use stream_allocator::{StreamAllocator, StreamId};
26
/// Identifies a GPU compute backend (or the always-available CPU fallback).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA.
    Cuda,
    /// AMD ROCm.
    Rocm,
    /// Cross-platform WebGPU (wgpu).
    Wgpu,
    /// Apple Metal (macOS only).
    Metal,
    /// Khronos OpenCL.
    OpenCL,
    /// CPU fallback; always available.
    Cpu,
}
43
44impl Default for GpuBackend {
45 fn default() -> Self {
46 Self::preferred()
47 }
48}
49
impl GpuBackend {
    /// Selects the backend to use by probing for optimal hardware.
    ///
    /// NOTE(review): when a real GPU backend is detected, non-test builds
    /// still return `Cpu` — only `cfg(test)` builds keep the detected
    /// backend. This looks inverted (or deliberately conservative); confirm
    /// it is the intended production behavior.
    pub fn preferred() -> Self {
        match backends::initialize_optimal_backend() {
            Ok(backend) => {
                if backend != GpuBackend::Cpu {
                    #[cfg(not(test))]
                    {
                        return GpuBackend::Cpu;
                    }
                    #[cfg(test)]
                    {
                        return backend;
                    }
                }
                backend
            }
            Err(_) => {
                // Detection failed outright; fall back to the CPU backend.
                GpuBackend::Cpu
            }
        }
    }

    /// Returns `true` if this backend is usable in the current build.
    ///
    /// CUDA/WebGPU/OpenCL perform a runtime probe when their feature is
    /// compiled in; `Rocm` and `Metal` are feature-gate-only checks with no
    /// runtime probe here; `Cpu` is always available.
    pub fn is_available(&self) -> bool {
        match self {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    CudaContext::is_available()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    false
                }
            }
            GpuBackend::Rocm => cfg!(feature = "rocm"),
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    WebGPUContext::is_available()
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    false
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    true
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    false
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    OpenCLContext::is_available()
                }
                #[cfg(not(feature = "opencl"))]
                {
                    false
                }
            }
            GpuBackend::Cpu => true,
        }
    }
}
134
135impl fmt::Display for GpuBackend {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 match self {
138 GpuBackend::Cuda => write!(f, "CUDA"),
139 GpuBackend::Rocm => write!(f, "ROCm"),
140 GpuBackend::Wgpu => write!(f, "WebGPU"),
141 GpuBackend::Metal => write!(f, "Metal"),
142 GpuBackend::OpenCL => write!(f, "OpenCL"),
143 GpuBackend::Cpu => write!(f, "CPU"),
144 }
145 }
146}
147
148use crate::error::{CoreError, ErrorContext, ErrorLocation};
149
/// Errors produced by the GPU subsystem.
#[derive(Debug, thiserror::Error)]
pub enum GpuError {
    /// The backend exists but cannot be used right now.
    #[error("GPU backend {0} is not available")]
    BackendNotAvailable(String),

    /// The backend's cargo feature is compiled out.
    #[error("GPU backend {0} is not supported")]
    UnsupportedBackend(GpuBackend),

    /// The backend cannot run the requested kernel.
    #[error("GPU backend {0:?} is not supported for this kernel")]
    BackendNotSupported(GpuBackend),

    /// The backend is recognized but has no implementation yet.
    #[error("GPU backend {0} is not implemented yet")]
    BackendNotImplemented(GpuBackend),

    /// Device memory allocation failed; payload describes the request.
    #[error("GPU out of memory: {0}")]
    OutOfMemory(String),

    /// Kernel source failed to compile.
    #[error("Kernel compilation error: {0}")]
    KernelCompilationError(String),

    /// Kernel failed during execution.
    #[error("Kernel execution error: {0}")]
    KernelExecutionError(String),

    /// A caller-supplied argument was invalid.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// No kernel with the given name is registered.
    #[error("Kernel not found: {0}")]
    KernelNotFound(String),

    /// The registry cannot specialize this kernel.
    #[error("Kernel specialization not supported")]
    SpecializationNotSupported,

    /// The element type is unsupported by the backend/kernel.
    #[error("Unsupported data type: {0:?}")]
    UnsupportedDataType(kernels::DataType),

    /// Catch-all for errors with no dedicated variant.
    #[error("{0}")]
    Other(String),
}
201
/// A (backend, device index) pair identifying a specific device.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuDevice {
    // Backend this device belongs to.
    backend: GpuBackend,
    // Backend-local device index.
    device_id: usize,
}
208
209impl GpuDevice {
210 pub fn new(backend: GpuBackend, device_id: usize) -> Self {
212 Self { backend, device_id }
213 }
214
215 pub fn backend(&self) -> GpuBackend {
217 self.backend
218 }
219
220 pub fn device_id(&self) -> usize {
222 self.device_id
223 }
224
225 pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
227 Ok(GpuKernel {
229 backend: self.backend,
230 entry_point: entrypoint.to_string(),
231 })
232 }
233}
234
/// A kernel associated with a backend and an entry-point name.
pub struct GpuKernel {
    // Backend the kernel targets.
    backend: GpuBackend,
    // Entry-point function name inside the kernel source.
    entry_point: String,
}
240
241impl GpuKernel {
242 pub fn backend(&self) -> GpuBackend {
244 self.backend
245 }
246
247 pub fn entry_point(&self) -> &str {
249 &self.entry_point
250 }
251}
252
253impl From<GpuError> for CoreError {
255 fn from(err: GpuError) -> Self {
256 match err {
257 GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
258 ErrorContext::new(format!("GPU backend {backend} is not available"))
259 .with_location(ErrorLocation::new(file!(), line!())),
260 ),
261 GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
262 ErrorContext::new(format!("GPU backend {backend} is not supported"))
263 .with_location(ErrorLocation::new(file!(), line!())),
264 ),
265 GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
266 ErrorContext::new(format!(
267 "GPU backend {backend:?} is not supported for this kernel"
268 ))
269 .with_location(ErrorLocation::new(file!(), line!())),
270 ),
271 GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
272 ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
273 .with_location(ErrorLocation::new(file!(), line!())),
274 ),
275 GpuError::OutOfMemory(details) => CoreError::MemoryError(
276 ErrorContext::new(details.to_string())
277 .with_location(ErrorLocation::new(file!(), line!())),
278 ),
279 GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
280 ErrorContext::new(msg.to_string())
281 .with_location(ErrorLocation::new(file!(), line!())),
282 ),
283 GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
284 ErrorContext::new(msg.to_string())
285 .with_location(ErrorLocation::new(file!(), line!())),
286 ),
287 GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
288 ErrorContext::new(msg.to_string())
289 .with_location(ErrorLocation::new(file!(), line!())),
290 ),
291 GpuError::KernelNotFound(name) => CoreError::ComputationError(
292 ErrorContext::new(name.to_string())
293 .with_location(ErrorLocation::new(file!(), line!())),
294 ),
295 GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
296 ErrorContext::new("Kernel specialization not supported".to_string())
297 .with_location(ErrorLocation::new(file!(), line!())),
298 ),
299 GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
300 ErrorContext::new(format!("{dtype:?}"))
301 .with_location(ErrorLocation::new(file!(), line!())),
302 ),
303 GpuError::Other(msg) => CoreError::ComputationError(
304 ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
305 ),
306 }
307 }
308}
309
310pub trait GpuDataType: Copy + Send + Sync + 'static {}
312
/// Typed raw device pointer: an address plus an element count.
#[derive(Debug)]
pub struct GpuPtr<T: GpuDataType> {
    // Raw device address.
    ptr: u64,
    // Number of `T` elements the allocation spans.
    size: usize,
    // Ties the untyped address to element type `T` without storing one.
    phantom: PhantomData<T>,
}
320
321impl<T: GpuDataType> GpuPtr<T> {
322 pub fn allocate(size: usize) -> Result<Self, GpuError> {
324 Ok(GpuPtr {
325 ptr: 0x1000_0000, size,
327 phantom: PhantomData,
328 })
329 }
330
331 pub fn as_ptr(&self) -> u64 {
333 self.ptr
334 }
335
336 pub fn len(&self) -> usize {
338 self.size
339 }
340
341 pub fn is_empty(&self) -> bool {
343 self.size == 0
344 }
345}
346
/// A typed kernel argument: either a device buffer or an inline scalar.
#[derive(Debug, Clone)]
pub enum KernelArg<'a, T: GpuDataType> {
    /// Borrowed device-buffer argument.
    Buffer(&'a GpuPtr<T>),
    /// Scalar passed by value.
    Scalar(T),
}
355
/// Type-erased kernel argument, used by `GpuContext::launch_kernel`.
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    /// Raw device-buffer address.
    Buffer(u64),
    /// 32-bit float scalar.
    F32(f32),
    /// 64-bit float scalar.
    F64(f64),
    /// Signed 32-bit integer scalar.
    I32(i32),
    /// Unsigned 32-bit integer scalar.
    U32(u32),
    /// Pointer-sized unsigned scalar.
    Usize(usize),
}
372
/// Describes a transfer channel between two devices.
pub struct GpuChannel {
    // Source device index.
    #[allow(dead_code)]
    source_device: usize,
    // Destination device index.
    #[allow(dead_code)]
    target_device: usize,
    // Link bandwidth — units are not specified here; presumably bytes/sec or
    // GB/s, TODO confirm at the point of use.
    #[allow(dead_code)]
    bandwidth: f64,
}
382
// Blanket `GpuDataType` markers for the primitive numeric types. All are
// `Copy` PODs for which the all-zero bit pattern is a valid value (a property
// `GpuBuffer::to_vec` relies on when zero-initializing its staging vector).
impl GpuDataType for f32 {}
impl GpuDataType for f64 {}
impl GpuDataType for i32 {}
impl GpuDataType for u32 {}
impl GpuDataType for u8 {}
impl GpuDataType for i8 {}
impl GpuDataType for u16 {}
impl GpuDataType for i16 {}
impl GpuDataType for u64 {}
impl GpuDataType for i64 {}
impl GpuDataType for usize {}
impl GpuDataType for isize {}
396
/// Typed device buffer. Cloning is cheap: clones share the same allocation.
pub struct GpuBuffer<T: GpuDataType> {
    // Backend allocation (byte-addressed).
    inner: Arc<dyn GpuBufferImpl>,
    // Element count, NOT a byte count.
    size: usize,
    // Carries the element type without storing a value.
    phantom: PhantomData<T>,
}
403
404impl<T: GpuDataType> GpuBuffer<T> {
405 pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
407 Self {
408 inner,
409 size,
410 phantom: PhantomData,
411 }
412 }
413
414 pub fn len(&self) -> usize {
416 self.size
417 }
418
419 pub fn is_empty(&self) -> bool {
421 self.size == 0
422 }
423
424 pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
426 if data.len() > self.size {
427 return Err(GpuError::InvalidParameter(
428 "Data size exceeds buffer size".to_string(),
429 ));
430 }
431 unsafe {
432 self.inner
433 .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
434 }
435 Ok(())
436 }
437
438 pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
440 if data.len() > self.size {
441 return Err(GpuError::InvalidParameter(
442 "Data size exceeds buffer size".to_string(),
443 ));
444 }
445 unsafe {
446 self.inner
447 .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
448 }
449 Ok(())
450 }
451
452 pub fn to_vec(&self) -> Vec<T> {
454 let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
455 let _ = self.copy_to_host(&mut result);
456 result
457 }
458}
459
impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
    /// Shows only the element count; device contents are not read back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GpuBuffer")
            .field("size", &self.size)
            .finish()
    }
}
467
468impl<T: GpuDataType> Clone for GpuBuffer<T> {
469 fn clone(&self) -> Self {
470 Self {
471 inner: Arc::clone(&self.inner),
472 size: self.size,
473 phantom: PhantomData,
474 }
475 }
476}
477
/// Clonable, shareable handle to a compiled kernel.
#[derive(Clone)]
pub struct GpuKernelHandle {
    // Backend-specific kernel object.
    inner: Arc<dyn GpuKernelImpl>,
}
483
484impl GpuKernelHandle {
485 pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
487 Self { inner }
488 }
489
490 pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
492 self.inner.set_buffer(name, &buffer.inner);
493 }
494
495 pub fn set_u32(&self, name: &str, value: u32) {
497 self.inner.set_u32(name, value);
498 }
499
500 pub fn set_i32(&self, name: &str, value: i32) {
502 self.inner.set_i32(name, value);
503 }
504
505 pub fn set_f32(&self, name: &str, value: f32) {
507 self.inner.set_f32(name, value);
508 }
509
510 pub fn set_f64(&self, name: &str, value: f64) {
512 self.inner.set_f64(name, value);
513 }
514
515 pub fn dispatch(&self, workgroups: [u32; 3]) {
521 if !self.inner.try_batch_dispatch(workgroups) {
522 self.inner.dispatch(workgroups);
523 }
524 }
525
526 pub fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
531 self.inner.dispatch_no_wait(workgroups);
532 }
533}
534
/// Kernel compiler front-end tied to a specific backend.
pub struct GpuCompiler {
    // Backend-specific compiler implementation.
    inner: Arc<dyn GpuCompilerImpl>,
}
539
540impl GpuCompiler {
541 pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
543 Self { inner }
544 }
545
546 pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
548 let kernel = self.inner.compile(source)?;
549 Ok(GpuKernelHandle::new(kernel))
550 }
551
552 pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
554 let kernel = self.inner.compile_typed(
555 name,
556 std::any::TypeId::of::<I>(),
557 std::any::TypeId::of::<O>(),
558 );
559 GpuKernelHandle::new(kernel)
560 }
561}
562
/// Entry point for GPU computation on a chosen backend.
pub struct GpuContext {
    // Backend-specific context implementation.
    inner: Arc<dyn GpuContextImpl>,
    // The backend this context was created for.
    backend: GpuBackend,
    // Registry of named / specializable kernels.
    kernel_registry: kernels::KernelRegistry,
}
569
impl GpuContext {
    /// Creates a context for `backend`.
    ///
    /// # Errors
    ///
    /// * `BackendNotAvailable` — the backend's probe failed, or (for non-CPU
    ///   backends) runtime enumeration found no matching device.
    /// * `UnsupportedBackend` — the corresponding cargo feature is disabled.
    /// * `BackendNotImplemented` — feature enabled but no real context exists
    ///   yet (ROCm outside of test builds).
    pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
        if !backend.is_available() {
            return Err(GpuError::BackendNotAvailable(backend.to_string()));
        }

        // A feature can be compiled in while no physical device is present,
        // so double-check against runtime device enumeration.
        if backend != GpuBackend::Cpu {
            let detection_result = backends::detect_gpu_backends();
            let backend_available = detection_result
                .devices
                .iter()
                .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);

            if !backend_available {
                return Err(GpuError::BackendNotAvailable(format!(
                    "{backend} (no devices detected at runtime)"
                )));
            }
        }

        let inner = match backend {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    match CudaContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    // No real ROCm context yet: tests run against the CPU
                    // stand-in; production builds report "not implemented".
                    #[cfg(test)]
                    {
                        Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
                    }
                    #[cfg(not(test))]
                    {
                        return Err(GpuError::BackendNotImplemented(backend));
                    }
                }
                #[cfg(not(feature = "rocm"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    match WebGPUContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    use crate::gpu::backends::metal::MetalContext;
                    match MetalContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    match OpenCLContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "opencl"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
        };

        Ok(Self {
            inner,
            backend,
            kernel_registry: kernels::KernelRegistry::with_default_kernels(),
        })
    }

    /// The backend this context was created for.
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Static display name of the backend (same text as its `Display` impl).
    pub fn backend_name(&self) -> &str {
        match self.backend {
            GpuBackend::Cuda => "CUDA",
            GpuBackend::Rocm => "ROCm",
            GpuBackend::Wgpu => "WebGPU",
            GpuBackend::Metal => "Metal",
            GpuBackend::OpenCL => "OpenCL",
            GpuBackend::Cpu => "CPU",
        }
    }

    /// Waits for previously submitted device work (backend-defined; the
    /// trait's default implementation is a no-op).
    pub fn gpu_sync(&self) -> Result<(), GpuError> {
        self.inner.gpu_sync()
    }

    /// Starts batching subsequent dispatches (backend-defined; default no-op).
    pub fn begin_batch(&self) -> Result<(), GpuError> {
        self.inner.begin_batch()
    }

    /// Flushes/ends a batch started with [`begin_batch`](Self::begin_batch).
    pub fn end_batch(&self) -> Result<(), GpuError> {
        self.inner.end_batch()
    }

    /// Allocates a device buffer holding `size` elements of `T`.
    pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
        // saturating_mul guards the element-count -> byte-count conversion
        // against overflow.
        let byte_size = size.saturating_mul(std::mem::size_of::<T>());
        let inner = self.inner.create_buffer(byte_size);
        GpuBuffer::new(inner, size)
    }

    /// Allocates a buffer sized to `data` and uploads it. The upload result is
    /// ignored — it cannot fail here because the sizes match by construction.
    pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
        let buffer = self.create_buffer::<T>(data.len());
        let _ = buffer.copy_from_host(data);
        buffer
    }

    /// Runs `f` with a freshly created kernel compiler for this backend.
    pub fn execute<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&GpuCompiler) -> R,
    {
        let compiler = GpuCompiler::new(self.inner.create_compiler());
        f(&compiler)
    }

    /// Looks up a registered kernel by name and compiles its source for this
    /// backend. Returns `KernelNotFound` for unknown names.
    pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
        let kernel = self
            .kernel_registry
            .get(name)
            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;

        let kernel_source = kernel.source_for_backend(self.backend)?;
        let metadata = kernel.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Like [`get_kernel`](Self::get_kernel) but specializes the kernel on
    /// `params` before compiling.
    pub fn get_specialized_kernel(
        &self,
        name: &str,
        params: &kernels::KernelParams,
    ) -> Result<GpuKernelHandle, GpuError> {
        let specialized = self.kernel_registry.get_specialized(name, params)?;
        let kernel_source = specialized.source_for_backend(self.backend)?;
        let metadata = specialized.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    // Compiles kernel source via a throwaway compiler; `metadata` is
    // currently unused.
    fn compile_kernel_with_metadata(
        &self,
        source: &str,
        _metadata: &kernels::KernelMetadata,
    ) -> Result<GpuKernelHandle, GpuError> {
        self.execute(|compiler| compiler.compile(source))
    }

    /// Available device memory in bytes.
    ///
    /// NOTE(review): hard-coded placeholder (1 GiB), not a real device query.
    pub fn get_available_memory(&self) -> Option<usize> {
        Some(1024 * 1024 * 1024)
    }

    /// Total device memory in bytes.
    ///
    /// NOTE(review): hard-coded placeholders — 512 MiB on wasm32, 4 GiB
    /// elsewhere; not a real device query.
    pub fn get_total_memory(&self) -> Option<usize> {
        #[cfg(target_arch = "wasm32")]
        return Some(512 * 1024 * 1024);
        #[cfg(not(target_arch = "wasm32"))]
        Some((4u64 * 1024 * 1024 * 1024) as usize)
    }

    /// Launches a named kernel with explicit launch geometry.
    ///
    /// NOTE(review): stub — all arguments are ignored and `Ok(())` is
    /// returned unconditionally.
    pub fn launch_kernel(
        &self,
        kernel_name: &str,
        grid_size: (usize, usize, usize),
        block_size: (usize, usize, usize),
        args: &[DynamicKernelArg],
    ) -> Result<(), GpuError> {
        let _ = (kernel_name, grid_size, block_size, args);
        Ok(())
    }

    /// Asynchronous host-to-device transfer.
    /// NOTE(review): stub — arguments ignored, no transfer occurs.
    pub fn transfer_async_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Synchronous host-to-device transfer.
    /// NOTE(review): stub — arguments ignored, no transfer occurs.
    pub fn transfer_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Asynchronous device-to-host transfer.
    /// NOTE(review): stub — arguments ignored, `data` is not written.
    pub fn transfer_async_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Synchronous device-to-host transfer.
    /// NOTE(review): stub — arguments ignored, `data` is not written.
    pub fn transfer_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Executes raw kernel source against a set of `f32` buffers.
    ///
    /// NOTE(review): stub — only logs the parameters to stderr; no kernel is
    /// actually executed.
    pub fn execute_kernel(
        &self,
        source: &str,
        buffers: &[GpuBuffer<f32>],
        work_groups: (u32, u32, u32),
        int_params: &[u32],
        float_params: &[f32],
    ) -> Result<(), GpuError> {
        eprintln!(
            "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
            source.len(),
            buffers.len(),
            work_groups
        );
        eprintln!("Int params: {int_params:?}");
        eprintln!("Float params: {float_params:?}");
        Ok(())
    }

    /// Copies the buffer contents back into a host `Vec`. Infallible in
    /// practice: `to_vec` swallows copy errors.
    pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
        Ok(buffer.to_vec())
    }

    // The reduction/math operations below currently delegate to
    // `*_cpu_fallback` helpers regardless of backend — presumably provided by
    // the `cpu_ops` module (TODO confirm; they are not defined in this file).

    /// Sum of all elements.
    pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_all_cpu_fallback(buffer)
    }

    /// Arithmetic mean of all elements.
    pub fn mean_all<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.mean_all_cpu_fallback(buffer)
    }

    /// Maximum over all elements.
    pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.max_all_cpu_fallback(buffer)
    }

    /// Minimum over all elements.
    pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.min_all_cpu_fallback(buffer)
    }

    /// Sum along `axis` of a tensor with the given `shape`.
    pub fn sum_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Mean along `axis` of a tensor with the given `shape`.
    pub fn mean_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.mean_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Maximum along `axis` of a tensor with the given `shape`.
    pub fn max_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.max_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Minimum along `axis` of a tensor with the given `shape`.
    pub fn min_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.min_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Broadcasts `buffer` from `from_shape` to `to_shape`.
    pub fn broadcast<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        from_shape: &[usize],
        to_shape: &[usize],
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.broadcast_cpu_fallback(buffer, from_shape, to_shape)
    }

    /// Multiplies every element by `scalar`.
    pub fn scale<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        scalar: T,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.scale_cpu_fallback(buffer, scalar)
    }

    /// Matrix multiply: `a` is m-by-k, `b` is k-by-n.
    pub fn gemm<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_cpu_fallback(a, b, m, k, n)
    }

    /// Matrix multiply with `b` transposed.
    pub fn gemm_transpose_b<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_transpose_b_cpu_fallback(a, b, m, k, n)
    }

    /// Matrix multiply with `a` transposed.
    pub fn gemm_transpose_a<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_transpose_a_cpu_fallback(a, b, m, k, n)
    }

    /// Element-wise ReLU.
    pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.relu_cpu_fallback(input)
    }

    /// Gradient of ReLU w.r.t. its input.
    pub fn relu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.relu_backward_cpu_fallback(grad_output, input)
    }

    /// Element-wise sigmoid.
    pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.sigmoid_cpu_fallback(input)
    }

    /// Gradient of sigmoid w.r.t. its input.
    pub fn sigmoid_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sigmoid_backward_cpu_fallback(grad_output, input)
    }

    /// Element-wise tanh.
    pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.tanh_cpu_fallback(input)
    }

    /// Gradient of tanh w.r.t. its input.
    pub fn tanh_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.tanh_backward_cpu_fallback(grad_output, input)
    }

    /// Element-wise GELU.
    pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.gelu_cpu_fallback(input)
    }

    /// Gradient of GELU w.r.t. its input.
    pub fn gelu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gelu_backward_cpu_fallback(grad_output, input)
    }
}
1060
impl fmt::Debug for GpuContext {
    /// Shows only the backend; the inner context and registry are opaque.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GpuContext")
            .field("backend", &self.backend)
            .finish()
    }
}
1068
/// Raw byte-level buffer interface implemented by each backend.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copies `size` bytes from host memory at `data` into this buffer.
    ///
    /// # Safety
    /// `data` must be valid for `size` bytes of reads, and `size` must not
    /// exceed this buffer's allocation.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copies `size` bytes from this buffer into host memory at `data`.
    ///
    /// # Safety
    /// `data` must be valid for `size` bytes of writes, and `size` must not
    /// exceed this buffer's allocation.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Downcast support for backend-specific access.
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Allocation size in bytes; defaults to 0 for backends that don't report it.
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0
    }

    /// Raw device address; defaults to 0 for backends that don't expose one.
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0
    }
}
1096
/// Backend-facing kernel interface: parameter binding plus dispatch.
pub(crate) trait GpuKernelImpl: Send + Sync {
    /// Binds a device buffer to the named kernel parameter.
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);

    /// Sets a `u32` scalar parameter.
    fn set_u32(&self, name: &str, value: u32);

    /// Sets an `i32` scalar parameter.
    fn set_i32(&self, name: &str, value: i32);

    /// Sets an `f32` scalar parameter.
    fn set_f32(&self, name: &str, value: f32);

    /// Sets an `f64` scalar parameter.
    fn set_f64(&self, name: &str, value: f64);

    /// Dispatches the kernel over the given workgroup grid.
    fn dispatch(&self, workgroups: [u32; 3]);

    /// Fire-and-forget dispatch; defaults to a plain (blocking) dispatch for
    /// backends without an async path.
    fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
        self.dispatch(workgroups);
    }

    /// Attempts to enqueue the dispatch into a batch; returning `false`
    /// (the default) makes the caller fall back to `dispatch`.
    fn try_batch_dispatch(&self, _workgroups: [u32; 3]) -> bool {
        false
    }
}
1133
/// Backend-facing kernel compiler interface.
pub(crate) trait GpuCompilerImpl: Send + Sync {
    /// Compiles kernel source text into an executable kernel object.
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;

    /// Compiles a named kernel specialized on input/output element types.
    fn compile_typed(
        &self,
        name: &str,
        input_type: std::any::TypeId,
        output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl>;
}
1147
/// Backend-facing context interface implemented by each GPU backend
/// (and by `CpuContext` as the fallback).
pub(crate) trait GpuContextImpl: Send + Sync {
    /// Allocates a raw device buffer of `size` bytes.
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;

    /// Creates a kernel compiler for this backend.
    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;

    /// Waits for outstanding device work; default no-op (synchronous backends).
    fn gpu_sync(&self) -> Result<(), GpuError> {
        Ok(())
    }

    /// Starts batching dispatches; default no-op.
    fn begin_batch(&self) -> Result<(), GpuError> {
        Ok(())
    }

    /// Ends/flushes a dispatch batch; default no-op.
    fn end_batch(&self) -> Result<(), GpuError> {
        Ok(())
    }

    /// Downcast support for backend-specific access.
    fn as_any(&self) -> &dyn std::any::Any
    where
        Self: 'static + Sized,
    {
        self
    }
}
1185
/// Stateless CPU-side context; serves as the always-available fallback backend.
struct CpuContext;

impl CpuContext {
    /// Creates the (zero-sized) CPU context.
    fn new() -> Self {
        CpuContext
    }
}
1197
1198impl GpuContextImpl for CpuContext {
1199 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
1200 Arc::new(CpuBuffer::new(size))
1201 }
1202
1203 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
1204 Arc::new(CpuCompiler)
1205 }
1206}
1207
1208struct CpuBuffer {
1210 data: Vec<u8>,
1211}
1212
1213impl CpuBuffer {
1214 fn new(size: usize) -> Self {
1216 Self {
1217 data: vec![0; size],
1218 }
1219 }
1220}
1221
1222impl GpuBufferImpl for CpuBuffer {
1223 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1224 let mut_self = self as *const Self as *mut Self;
1225 let data_ptr = (*mut_self).data.as_mut_ptr();
1226 std::ptr::copy_nonoverlapping(data, data_ptr, size);
1227 }
1228
1229 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1230 let data_ptr = self.data.as_ptr();
1231 std::ptr::copy_nonoverlapping(data_ptr, data, size);
1232 }
1233
1234 fn as_any(&self) -> &dyn std::any::Any {
1235 self
1236 }
1237
1238 fn size(&self) -> usize {
1239 self.data.len()
1240 }
1241
1242 fn device_ptr(&self) -> u64 {
1243 self.data.as_ptr() as u64
1244 }
1245}
1246
1247struct CpuCompiler;
1249
1250impl GpuCompilerImpl for CpuCompiler {
1251 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
1252 Ok(Arc::new(CpuKernel))
1255 }
1256
1257 fn compile_typed(
1258 &self,
1259 _name: &str,
1260 _input_type: std::any::TypeId,
1261 _output_type: std::any::TypeId,
1262 ) -> Arc<dyn GpuKernelImpl> {
1263 Arc::new(CpuKernel)
1266 }
1267}
1268
1269struct CpuKernel;
1271
1272impl GpuKernelImpl for CpuKernel {
1273 fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
1274 }
1276
1277 fn set_u32(&self, _name: &str, value: u32) {
1278 }
1280
1281 fn set_i32(&self, _name: &str, value: i32) {
1282 }
1284
1285 fn set_f32(&self, _name: &str, value: f32) {
1286 }
1288
1289 fn set_f64(&self, _name: &str, value: f64) {
1290 }
1292
1293 fn dispatch(&self, workgroups: [u32; 3]) {
1294 }
1296}
1297
1298#[cfg(test)]
1302mod tests {
1303 use super::*;
1304
1305 #[test]
1306 fn test_gpu_backend_preferred() {
1307 let backend = GpuBackend::preferred();
1308 match backend {
1310 GpuBackend::Cuda
1311 | GpuBackend::Rocm
1312 | GpuBackend::Wgpu
1313 | GpuBackend::Metal
1314 | GpuBackend::OpenCL
1315 | GpuBackend::Cpu => {}
1316 }
1317 }
1318
1319 #[test]
1320 fn test_gpu_backend_default() {
1321 let backend = GpuBackend::default();
1322 assert_eq!(backend, GpuBackend::preferred());
1323 }
1324
1325 #[test]
1326 fn test_gpu_backend_is_available() {
1327 let backend = GpuBackend::Cpu;
1328 assert!(backend.is_available());
1329
1330 #[cfg(feature = "cuda")]
1332 {
1333 let _ = GpuBackend::Cuda.is_available(); }
1336 #[cfg(not(feature = "cuda"))]
1337 assert!(!GpuBackend::Cuda.is_available());
1338
1339 #[cfg(feature = "rocm")]
1340 {
1341 let _ = GpuBackend::Rocm.is_available(); }
1344 #[cfg(not(feature = "rocm"))]
1345 assert!(!GpuBackend::Rocm.is_available());
1346
1347 #[cfg(all(feature = "metal", target_os = "macos"))]
1348 assert!(GpuBackend::Metal.is_available());
1349 #[cfg(not(all(feature = "metal", target_os = "macos")))]
1350 assert!(!GpuBackend::Metal.is_available());
1351 }
1352
1353 #[test]
1354 fn test_gpu_backend_display() {
1355 assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
1356 assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
1357 assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
1358 assert_eq!(GpuBackend::Metal.to_string(), "Metal");
1359 assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
1360 assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
1361 }
1362
1363 #[test]
1364 fn test_gpuerror_from_conversion() {
1365 let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
1366 let coreerror: CoreError = gpuerror.into();
1367 match coreerror {
1368 CoreError::ComputationError(_) => {}
1369 _ => panic!("Expected ComputationError"),
1370 }
1371
1372 let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
1373 let coreerror: CoreError = gpuerror.into();
1374 match coreerror {
1375 CoreError::MemoryError(_) => {}
1376 _ => panic!("Expected MemoryError"),
1377 }
1378
1379 let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
1380 let coreerror: CoreError = gpuerror.into();
1381 match coreerror {
1382 CoreError::InvalidArgument(_) => {}
1383 _ => panic!("Expected InvalidArgument"),
1384 }
1385
1386 let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
1387 let coreerror: CoreError = gpuerror.into();
1388 match coreerror {
1389 CoreError::TypeError(_) => {}
1390 _ => panic!("Expected TypeError"),
1391 }
1392 }
1393
1394 #[test]
1395 fn test_gpu_datatype_trait() {
1396 fn assert_gpu_datatype<T: GpuDataType>() {}
1398
1399 assert_gpu_datatype::<f32>();
1400 assert_gpu_datatype::<f64>();
1401 assert_gpu_datatype::<i32>();
1402 assert_gpu_datatype::<u32>();
1403 assert_gpu_datatype::<u8>();
1404 assert_gpu_datatype::<i8>();
1405 assert_gpu_datatype::<u16>();
1406 assert_gpu_datatype::<i16>();
1407 assert_gpu_datatype::<u64>();
1408 assert_gpu_datatype::<i64>();
1409 }
1410
1411 #[test]
1412 fn test_gpu_buffer_creation() {
1413 let inner = Arc::new(CpuBuffer::new(100));
1414 let buffer = GpuBuffer::<f32>::new(inner, 25);
1415
1416 assert_eq!(buffer.len(), 25);
1417 assert!(!buffer.is_empty());
1418 }
1419
1420 #[test]
1421 fn test_gpu_buffer_empty() {
1422 let inner = Arc::new(CpuBuffer::new(0));
1423 let buffer = GpuBuffer::<f32>::new(inner, 0);
1424
1425 assert_eq!(buffer.len(), 0);
1426 assert!(buffer.is_empty());
1427 }
1428
1429 #[test]
1430 fn test_gpu_buffer_copy_operations() {
1431 let inner = Arc::new(CpuBuffer::new(16));
1432 let buffer = GpuBuffer::<f32>::new(inner, 4);
1433
1434 let data = vec![1.0f32, 2.0, 3.0, 4.0];
1435 let _ = buffer.copy_from_host(&data);
1436
1437 let mut result = vec![0.0f32; 4];
1438 let _ = buffer.copy_to_host(&mut result);
1439
1440 assert_eq!(result, data);
1441 }
1442
1443 #[test]
1444 fn test_gpu_buffer_to_vec() {
1445 let inner = Arc::new(CpuBuffer::new(12));
1446 let buffer = GpuBuffer::<f32>::new(inner, 3);
1447
1448 let data = vec![5.0f32, 6.0, 7.0];
1449 let _ = buffer.copy_from_host(&data);
1450
1451 let result = buffer.to_vec();
1452 assert_eq!(result, data);
1453 }
1454
1455 #[test]
1456 #[should_panic(expected = "Data size exceeds buffer size")]
1457 fn test_gpu_buffer_copy_from_host_overflow() {
1458 let inner = Arc::new(CpuBuffer::new(8));
1459 let buffer = GpuBuffer::<f32>::new(inner, 2);
1460
1461 let data = vec![1.0f32, 2.0, 3.0]; buffer.copy_from_host(&data).expect("Operation failed");
1463 }
1464
1465 #[test]
1466 #[should_panic(expected = "Data size exceeds buffer size")]
1467 fn test_gpu_buffer_copy_to_host_overflow() {
1468 let inner = Arc::new(CpuBuffer::new(8));
1469 let buffer = GpuBuffer::<f32>::new(inner, 2);
1470
1471 let mut data = vec![0.0f32; 3]; buffer.copy_to_host(&mut data).expect("Operation failed");
1473 }
1474
1475 #[test]
1476 fn test_gpu_kernel_handle() {
1477 let kernel = Arc::new(CpuKernel);
1478 let handle = GpuKernelHandle::new(kernel);
1479
1480 let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
1482 handle.set_buffer("input", &buffer);
1483 handle.set_u32("size", 100);
1484 handle.set_i32("offset", -5);
1485 handle.set_f32("scale", 2.5);
1486 handle.set_f64("precision", 0.0001);
1487
1488 handle.dispatch([16, 8, 1]);
1490 }
1491
1492 #[test]
1493 fn test_gpu_context_cpu_backend() {
1494 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1495 assert_eq!(context.backend(), GpuBackend::Cpu);
1496 assert_eq!(context.backend_name(), "CPU");
1497
1498 assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
1500 assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
1501 }
1502
1503 #[test]
1504 fn test_gpu_context_buffer_creation() {
1505 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1506
1507 let buffer = context.create_buffer::<f32>(100);
1508 assert_eq!(buffer.len(), 100);
1509
1510 let data = vec![1.0f32; 50];
1511 let buffer_from_slice = context.create_buffer_from_slice(&data);
1512 assert_eq!(buffer_from_slice.len(), 50);
1513
1514 let result = buffer_from_slice.to_vec();
1515 assert_eq!(result, data);
1516 }
1517
1518 #[test]
1519 fn test_gpu_context_unsupported_backend() {
1520 #[cfg(not(feature = "cuda"))]
1522 {
1523 let result = GpuContext::new(GpuBackend::Cuda);
1524 assert!(result.is_err());
1525 match result {
1526 Err(GpuError::UnsupportedBackend(_)) => {}
1527 Err(GpuError::BackendNotAvailable(_)) => {} Err(e) => panic!(
1529 "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
1530 e
1531 ),
1532 Ok(_) => panic!("Expected error, got Ok"),
1533 }
1534 }
1535 }
1536
1537 #[test]
1538 fn test_gpu_compiler() {
1539 let compiler_impl = Arc::new(CpuCompiler);
1540 let compiler = GpuCompiler::new(compiler_impl);
1541
1542 let kernel = compiler
1544 .compile("dummy kernel source")
1545 .expect("Operation failed");
1546 kernel.dispatch([1, 1, 1]);
1547
1548 let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
1550 typed_kernel.dispatch([32, 1, 1]);
1551 }
1552
1553 #[test]
1554 fn test_gpu_context_execute() {
1555 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1556
1557 let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());
1558
1559 assert!(result);
1560 }
1561
1562 #[test]
1563 fn test_gpu_context_kernel_registry() {
1564 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1565
1566 let result = context.get_kernel("non_existent_kernel");
1568 assert!(result.is_err());
1569 match result {
1570 Err(GpuError::KernelNotFound(_)) => {}
1571 _ => panic!("Expected KernelNotFound error"),
1572 }
1573 }
1574
1575 #[test]
1576 fn test_cpu_buffer_implementation() {
1577 let buffer = CpuBuffer::new(256);
1578 assert_eq!(buffer.data.len(), 256);
1579
1580 assert!(buffer.data.iter().all(|&b| b == 0));
1582 }
1583
1584 #[test]
1585 fn test_gpuerror_display() {
1586 let error = GpuError::BackendNotAvailable("CUDA".to_string());
1587 assert_eq!(error.to_string(), "GPU backend CUDA is not available");
1588
1589 let error = GpuError::OutOfMemory("allocation failed".to_string());
1590 assert_eq!(error.to_string(), "GPU out of memory: allocation failed");
1591
1592 let error = GpuError::KernelCompilationError("syntax error".to_string());
1593 assert_eq!(error.to_string(), "Kernel compilation error: syntax error");
1594
1595 let error = GpuError::KernelNotFound("gemm".to_string());
1596 assert_eq!(error.to_string(), "Kernel not found: gemm");
1597 }
1598
1599 #[test]
1600 fn test_backend_equality() {
1601 assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
1602 assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);
1603
1604 let backend = GpuBackend::Metal;
1606 let cloned = backend;
1607 let copied = backend;
1608 assert_eq!(backend, cloned);
1609 assert_eq!(backend, copied);
1610 }
1611
1612 #[test]
1613 fn test_backend_hash() {
1614 use std::collections::HashSet;
1615
1616 let mut set = HashSet::new();
1617 set.insert(GpuBackend::Cuda);
1618 set.insert(GpuBackend::Rocm);
1619 set.insert(GpuBackend::Cuda); assert_eq!(set.len(), 2); assert!(set.contains(&GpuBackend::Cuda));
1623 assert!(set.contains(&GpuBackend::Rocm));
1624 }
1625
1626 #[test]
1627 fn test_gpu_buffer_debug_clone() {
1628 let inner = Arc::new(CpuBuffer::new(16));
1629 let buffer = GpuBuffer::<f32>::new(inner, 4);
1630
1631 let debug_str = format!("{:?}", buffer);
1633 assert!(debug_str.contains("GpuBuffer"));
1634 assert!(debug_str.contains("size"));
1635
1636 let cloned = buffer.clone();
1638 assert_eq!(cloned.len(), buffer.len());
1639 assert_eq!(cloned.len(), 4);
1640
1641 let data = vec![1.0f32, 2.0, 3.0, 4.0];
1643 let _ = buffer.copy_from_host(&data);
1644
1645 let mut result = vec![0.0f32; 4];
1646 let _ = cloned.copy_to_host(&mut result);
1647 assert_eq!(result, data);
1648 }
1649
1650 #[test]
1651 fn test_gpu_context_debug() {
1652 let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
1653
1654 let debug_str = format!("{:?}", context);
1656 assert!(debug_str.contains("GpuContext"));
1657 assert!(debug_str.contains("backend"));
1658 assert!(debug_str.contains("Cpu"));
1659 }
1660
1661 #[test]
1662 fn test_gpu_context_batch_dispatch() {
1663 let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");
1665
1666 let begin_result = context.begin_batch();
1668 assert!(
1669 begin_result.is_ok(),
1670 "begin_batch should succeed on CPU backend"
1671 );
1672
1673 let dispatch_result = context.execute(|compiler| {
1675 compiler.compile("dummy kernel source").map(|kernel| {
1676 kernel.dispatch([4, 1, 1]);
1677 })
1678 });
1679 assert!(
1680 dispatch_result.is_ok(),
1681 "kernel dispatch inside batch should succeed"
1682 );
1683
1684 let end_result = context.end_batch();
1686 assert!(
1687 end_result.is_ok(),
1688 "end_batch should succeed on CPU backend"
1689 );
1690 }
1691
1692 #[test]
1693 fn test_gpu_context_gpu_sync() {
1694 let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");
1696
1697 let result = context.gpu_sync();
1699 assert!(result.is_ok(), "gpu_sync should return Ok on CPU backend");
1700 }
1701
1702 #[test]
1703 fn test_gpu_kernel_dispatch_no_wait() {
1704 let kernel = Arc::new(CpuKernel);
1706 let handle = GpuKernelHandle::new(kernel);
1707
1708 let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
1710 handle.set_buffer("input", &buffer);
1711 handle.set_u32("size", 4);
1712
1713 handle.dispatch_no_wait([4, 1, 1]);
1715 }
1716}