1use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod auto_tuning;
11pub mod backends;
12pub mod benchmarks;
13mod cpu_ops;
14pub mod heterogeneous;
15pub mod kernels;
16pub mod memory_management;
17pub mod tensor_cores;
18
/// Identifies a compute backend that kernels can be dispatched to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA.
    Cuda,
    /// AMD ROCm.
    Rocm,
    /// Portable WebGPU (wgpu).
    Wgpu,
    /// Apple Metal (macOS-only per the `is_available` gating below).
    Metal,
    /// OpenCL.
    OpenCL,
    /// Host-CPU fallback; always available.
    Cpu,
}
35
impl Default for GpuBackend {
    /// The default backend is whatever `preferred()` selects at runtime.
    fn default() -> Self {
        Self::preferred()
    }
}
41
impl GpuBackend {
    /// Selects a backend for general use.
    ///
    /// Probes via `backends::initialize_optimal_backend()`.
    /// NOTE(review): in non-test builds this deliberately returns `Cpu` even
    /// when a GPU backend was detected (see the `#[cfg(not(test))]` arm) —
    /// presumably a conservative production default; confirm this is intended.
    pub fn preferred() -> Self {
        match backends::initialize_optimal_backend() {
            Ok(backend) => {
                if backend != GpuBackend::Cpu {
                    // Production builds force the CPU backend even when a GPU
                    // backend was successfully initialized.
                    #[cfg(not(test))]
                    {
                        return GpuBackend::Cpu;
                    }
                    // Test builds exercise the detected GPU backend.
                    #[cfg(test)]
                    {
                        return backend;
                    }
                }
                backend
            }
            Err(_) => {
                // Detection failed entirely; CPU always works.
                GpuBackend::Cpu
            }
        }
    }

    /// Returns `true` if this backend is usable in the current build/runtime.
    ///
    /// Feature-gated backends report `false` when compiled out. CUDA, WebGPU
    /// and OpenCL additionally run a runtime probe via their context types;
    /// ROCm is a compile-time feature check only; Metal requires macOS.
    pub fn is_available(&self) -> bool {
        match self {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    CudaContext::is_available()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    false
                }
            }
            // No runtime probe for ROCm yet — feature flag only.
            GpuBackend::Rocm => cfg!(feature = "rocm"),
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    WebGPUContext::is_available()
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    false
                }
            }
            GpuBackend::Metal => {
                // Metal availability is decided entirely at compile time.
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    true
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    false
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    OpenCLContext::is_available()
                }
                #[cfg(not(feature = "opencl"))]
                {
                    false
                }
            }
            GpuBackend::Cpu => true,
        }
    }
}
126
127impl fmt::Display for GpuBackend {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 match self {
130 GpuBackend::Cuda => write!(f, "CUDA"),
131 GpuBackend::Rocm => write!(f, "ROCm"),
132 GpuBackend::Wgpu => write!(f, "WebGPU"),
133 GpuBackend::Metal => write!(f, "Metal"),
134 GpuBackend::OpenCL => write!(f, "OpenCL"),
135 GpuBackend::Cpu => write!(f, "CPU"),
136 }
137 }
138}
139
140use crate::error::{CoreError, ErrorContext, ErrorLocation};
141
/// Errors produced by the GPU subsystem.
/// Display messages come from the `thiserror` `#[error]` attributes.
#[derive(Debug, thiserror::Error)]
pub enum GpuError {
    /// The requested backend is compiled in but cannot be used right now.
    #[error("GPU backend {0} is not available")]
    BackendNotAvailable(String),

    /// The backend's cargo feature is disabled in this build.
    #[error("GPU backend {0} is not supported")]
    UnsupportedBackend(GpuBackend),

    /// A specific kernel has no source for this backend.
    #[error("GPU backend {0:?} is not supported for this kernel")]
    BackendNotSupported(GpuBackend),

    /// Backend is feature-enabled but its context is not wired up yet.
    #[error("GPU backend {0} is not implemented yet")]
    BackendNotImplemented(GpuBackend),

    /// Device allocation failed; payload describes the request.
    #[error("GPU out of memory: {0}")]
    OutOfMemory(String),

    #[error("Kernel compilation error: {0}")]
    KernelCompilationError(String),

    #[error("Kernel execution error: {0}")]
    KernelExecutionError(String),

    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// No kernel with the given name in the registry.
    #[error("Kernel not found: {0}")]
    KernelNotFound(String),

    #[error("Kernel specialization not supported")]
    SpecializationNotSupported,

    #[error("Unsupported data type: {0:?}")]
    UnsupportedDataType(kernels::DataType),

    /// Catch-all for backend-specific failures.
    #[error("{0}")]
    Other(String),
}
193
/// Identifies a single device: a backend plus a device index within it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuDevice {
    backend: GpuBackend,
    device_id: usize,
}
200
201impl GpuDevice {
202 pub fn new(backend: GpuBackend, device_id: usize) -> Self {
204 Self { backend, device_id }
205 }
206
207 pub fn backend(&self) -> GpuBackend {
209 self.backend
210 }
211
212 pub fn device_id(&self) -> usize {
214 self.device_id
215 }
216
217 pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
219 Ok(GpuKernel {
221 backend: self.backend,
222 entry_point: entrypoint.to_string(),
223 })
224 }
225}
226
/// A compiled kernel: the backend it targets plus its entry-point name.
pub struct GpuKernel {
    backend: GpuBackend,
    entry_point: String,
}
232
233impl GpuKernel {
234 pub fn backend(&self) -> GpuBackend {
236 self.backend
237 }
238
239 pub fn entry_point(&self) -> &str {
241 &self.entry_point
242 }
243}
244
245impl From<GpuError> for CoreError {
247 fn from(err: GpuError) -> Self {
248 match err {
249 GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
250 ErrorContext::new(format!("GPU backend {backend} is not available"))
251 .with_location(ErrorLocation::new(file!(), line!())),
252 ),
253 GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
254 ErrorContext::new(format!("GPU backend {backend} is not supported"))
255 .with_location(ErrorLocation::new(file!(), line!())),
256 ),
257 GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
258 ErrorContext::new(format!(
259 "GPU backend {backend:?} is not supported for this kernel"
260 ))
261 .with_location(ErrorLocation::new(file!(), line!())),
262 ),
263 GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
264 ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
265 .with_location(ErrorLocation::new(file!(), line!())),
266 ),
267 GpuError::OutOfMemory(details) => CoreError::MemoryError(
268 ErrorContext::new(details.to_string())
269 .with_location(ErrorLocation::new(file!(), line!())),
270 ),
271 GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
272 ErrorContext::new(msg.to_string())
273 .with_location(ErrorLocation::new(file!(), line!())),
274 ),
275 GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
276 ErrorContext::new(msg.to_string())
277 .with_location(ErrorLocation::new(file!(), line!())),
278 ),
279 GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
280 ErrorContext::new(msg.to_string())
281 .with_location(ErrorLocation::new(file!(), line!())),
282 ),
283 GpuError::KernelNotFound(name) => CoreError::ComputationError(
284 ErrorContext::new(name.to_string())
285 .with_location(ErrorLocation::new(file!(), line!())),
286 ),
287 GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
288 ErrorContext::new("Kernel specialization not supported".to_string())
289 .with_location(ErrorLocation::new(file!(), line!())),
290 ),
291 GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
292 ErrorContext::new(format!("{dtype:?}"))
293 .with_location(ErrorLocation::new(file!(), line!())),
294 ),
295 GpuError::Other(msg) => CoreError::ComputationError(
296 ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
297 ),
298 }
299 }
300}
301
302pub trait GpuDataType: Copy + Send + Sync + 'static {}
304
/// Raw device allocation handle: a device address plus an element count.
/// `PhantomData<T>` ties the untyped address to its element type.
#[derive(Debug)]
pub struct GpuPtr<T: GpuDataType> {
    // Device address (opaque integer, not a host pointer).
    ptr: u64,
    // Element count, not bytes.
    size: usize,
    phantom: PhantomData<T>,
}
312
313impl<T: GpuDataType> GpuPtr<T> {
314 pub fn allocate(size: usize) -> Result<Self, GpuError> {
316 Ok(GpuPtr {
317 ptr: 0x1000_0000, size,
319 phantom: PhantomData,
320 })
321 }
322
323 pub fn as_ptr(&self) -> u64 {
325 self.ptr
326 }
327
328 pub fn len(&self) -> usize {
330 self.size
331 }
332
333 pub fn is_empty(&self) -> bool {
335 self.size == 0
336 }
337}
338
/// Typed kernel argument: either a borrowed device buffer or an immediate scalar.
#[derive(Debug, Clone)]
pub enum KernelArg<'a, T: GpuDataType> {
    /// Borrowed device pointer argument.
    Buffer(&'a GpuPtr<T>),
    /// Pass-by-value scalar argument.
    Scalar(T),
}
347
/// Type-erased kernel argument, used by `GpuContext::launch_kernel`.
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    /// Raw device address of a buffer.
    Buffer(u64),
    F32(f32),
    F64(f64),
    I32(i32),
    U32(u32),
    Usize(usize),
}
364
/// A transfer channel between two devices.
/// All fields are currently unused placeholders (hence `#[allow(dead_code)]`).
pub struct GpuChannel {
    #[allow(dead_code)]
    source_device: usize,
    #[allow(dead_code)]
    target_device: usize,
    // Link bandwidth; units unspecified here — presumably GB/s, TODO confirm.
    #[allow(dead_code)]
    bandwidth: f64,
}
374
// Marker impls: every primitive numeric type is a valid GPU element type.
impl GpuDataType for f32 {}
impl GpuDataType for f64 {}
impl GpuDataType for i32 {}
impl GpuDataType for u32 {}
impl GpuDataType for u8 {}
impl GpuDataType for i8 {}
impl GpuDataType for u16 {}
impl GpuDataType for i16 {}
impl GpuDataType for u64 {}
impl GpuDataType for i64 {}
impl GpuDataType for usize {}
impl GpuDataType for isize {}
388
/// Typed handle to a device buffer.
/// `size` is the element count; the raw backend allocation is shared behind an
/// `Arc`, so clones alias the same device memory.
pub struct GpuBuffer<T: GpuDataType> {
    inner: Arc<dyn GpuBufferImpl>,
    size: usize,
    phantom: PhantomData<T>,
}
395
396impl<T: GpuDataType> GpuBuffer<T> {
397 pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
399 Self {
400 inner,
401 size,
402 phantom: PhantomData,
403 }
404 }
405
406 pub fn len(&self) -> usize {
408 self.size
409 }
410
411 pub fn is_empty(&self) -> bool {
413 self.size == 0
414 }
415
416 pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
418 if data.len() > self.size {
419 return Err(GpuError::InvalidParameter(
420 "Data size exceeds buffer size".to_string(),
421 ));
422 }
423 unsafe {
424 self.inner
425 .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
426 }
427 Ok(())
428 }
429
430 pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
432 if data.len() > self.size {
433 return Err(GpuError::InvalidParameter(
434 "Data size exceeds buffer size".to_string(),
435 ));
436 }
437 unsafe {
438 self.inner
439 .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
440 }
441 Ok(())
442 }
443
444 pub fn to_vec(&self) -> Vec<T> {
446 let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
447 let _ = self.copy_to_host(&mut result);
448 result
449 }
450}
451
452impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
453 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454 f.debug_struct("GpuBuffer")
455 .field("size", &self.size)
456 .finish()
457 }
458}
459
460impl<T: GpuDataType> Clone for GpuBuffer<T> {
461 fn clone(&self) -> Self {
462 Self {
463 inner: Arc::clone(&self.inner),
464 size: self.size,
465 phantom: PhantomData,
466 }
467 }
468}
469
/// Cheaply cloneable handle to a compiled, dispatchable kernel.
#[derive(Clone)]
pub struct GpuKernelHandle {
    inner: Arc<dyn GpuKernelImpl>,
}
475
impl GpuKernelHandle {
    /// Wraps a backend-specific compiled kernel.
    pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
        Self { inner }
    }

    /// Binds a device buffer to the named kernel parameter.
    pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
        self.inner.set_buffer(name, &buffer.inner);
    }

    /// Sets a `u32` scalar parameter.
    pub fn set_u32(&self, name: &str, value: u32) {
        self.inner.set_u32(name, value);
    }

    /// Sets an `i32` scalar parameter.
    pub fn set_i32(&self, name: &str, value: i32) {
        self.inner.set_i32(name, value);
    }

    /// Sets an `f32` scalar parameter.
    pub fn set_f32(&self, name: &str, value: f32) {
        self.inner.set_f32(name, value);
    }

    /// Sets an `f64` scalar parameter.
    pub fn set_f64(&self, name: &str, value: f64) {
        self.inner.set_f64(name, value);
    }

    /// Dispatches the kernel over `[x, y, z]` workgroup counts.
    pub fn dispatch(&self, workgroups: [u32; 3]) {
        self.inner.dispatch(workgroups);
    }
}
512
/// Front-end over a backend-specific kernel compiler.
pub struct GpuCompiler {
    inner: Arc<dyn GpuCompilerImpl>,
}
517
518impl GpuCompiler {
519 pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
521 Self { inner }
522 }
523
524 pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
526 let kernel = self.inner.compile(source)?;
527 Ok(GpuKernelHandle::new(kernel))
528 }
529
530 pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
532 let kernel = self.inner.compile_typed(
533 name,
534 std::any::TypeId::of::<I>(),
535 std::any::TypeId::of::<O>(),
536 );
537 GpuKernelHandle::new(kernel)
538 }
539}
540
/// Owns a backend context plus the kernel registry used to resolve named kernels.
pub struct GpuContext {
    inner: Arc<dyn GpuContextImpl>,
    backend: GpuBackend,
    kernel_registry: kernels::KernelRegistry,
}
547
impl GpuContext {
    /// Creates a context for `backend`.
    ///
    /// Errors: `BackendNotAvailable` when the backend is compiled out or no
    /// matching device is detected at runtime; `UnsupportedBackend` when the
    /// corresponding cargo feature is disabled; `BackendNotImplemented` for
    /// ROCm outside test builds (tests back ROCm with the CPU stub).
    pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
        if !backend.is_available() {
            return Err(GpuError::BackendNotAvailable(backend.to_string()));
        }

        // Compile-time availability is necessary but not sufficient: also
        // require that runtime detection actually found a device.
        if backend != GpuBackend::Cpu {
            let detection_result = backends::detect_gpu_backends();
            let backend_available = detection_result
                .devices
                .iter()
                .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);

            if !backend_available {
                return Err(GpuError::BackendNotAvailable(format!(
                    "{backend} (no devices detected at runtime)"
                )));
            }
        }

        let inner = match backend {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    match CudaContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    // Tests run ROCm against the CPU stub context; real ROCm
                    // support is not implemented yet.
                    #[cfg(test)]
                    {
                        Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
                    }
                    #[cfg(not(test))]
                    {
                        return Err(GpuError::BackendNotImplemented(backend));
                    }
                }
                #[cfg(not(feature = "rocm"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    match WebGPUContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    use crate::gpu::backends::metal::MetalContext;
                    match MetalContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    match OpenCLContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "opencl"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
        };

        Ok(Self {
            inner,
            backend,
            kernel_registry: kernels::KernelRegistry::with_default_kernels(),
        })
    }

    /// Backend this context was created for.
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Human-readable backend name (mirrors `GpuBackend`'s `Display` impl).
    pub fn backend_name(&self) -> &str {
        match self.backend {
            GpuBackend::Cuda => "CUDA",
            GpuBackend::Rocm => "ROCm",
            GpuBackend::Wgpu => "WebGPU",
            GpuBackend::Metal => "Metal",
            GpuBackend::OpenCL => "OpenCL",
            GpuBackend::Cpu => "CPU",
        }
    }

    /// Allocates a device buffer for `size` elements of `T`.
    /// `saturating_mul` guards the byte-size computation against overflow.
    pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
        let byte_size = size.saturating_mul(std::mem::size_of::<T>());
        let inner = self.inner.create_buffer(byte_size);
        GpuBuffer::new(inner, size)
    }

    /// Allocates a buffer sized to `data` and uploads it.
    /// NOTE(review): the upload result is discarded (`let _`), so on failure
    /// the returned buffer may hold unspecified contents.
    pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
        let buffer = self.create_buffer::<T>(data.len());
        let _ = buffer.copy_from_host(data);
        buffer
    }

    /// Runs `f` with a freshly created compiler for this backend.
    pub fn execute<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&GpuCompiler) -> R,
    {
        let compiler = GpuCompiler::new(self.inner.create_compiler());
        f(&compiler)
    }

    /// Resolves a registered kernel by name and compiles its source for this
    /// backend. Returns `KernelNotFound` for unregistered names.
    pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
        let kernel = self
            .kernel_registry
            .get(name)
            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;

        let kernel_source = kernel.source_for_backend(self.backend)?;
        let metadata = kernel.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Like `get_kernel`, but first specializes the kernel for `params`.
    pub fn get_specialized_kernel(
        &self,
        name: &str,
        params: &kernels::KernelParams,
    ) -> Result<GpuKernelHandle, GpuError> {
        let specialized = self.kernel_registry.get_specialized(name, params)?;
        let kernel_source = specialized.source_for_backend(self.backend)?;
        let metadata = specialized.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Compiles kernel source; `_metadata` is currently unused.
    fn compile_kernel_with_metadata(
        &self,
        source: &str,
        _metadata: &kernels::KernelMetadata,
    ) -> Result<GpuKernelHandle, GpuError> {
        self.execute(|compiler| compiler.compile(source))
    }

    /// Available device memory. Stub: always reports 1 GiB.
    pub fn get_available_memory(&self) -> Option<usize> {
        Some(1024 * 1024 * 1024)
    }

    /// Total device memory. Stub: 512 MiB on wasm32, 4 GiB elsewhere.
    pub fn get_total_memory(&self) -> Option<usize> {
        #[cfg(target_arch = "wasm32")]
        return Some(512 * 1024 * 1024);
        #[cfg(not(target_arch = "wasm32"))]
        Some((4u64 * 1024 * 1024 * 1024) as usize)
    }

    /// Launches a named kernel. Stub: all arguments are ignored and `Ok(())`
    /// is returned without executing anything.
    pub fn launch_kernel(
        &self,
        kernel_name: &str,
        grid_size: (usize, usize, usize),
        block_size: (usize, usize, usize),
        args: &[DynamicKernelArg],
    ) -> Result<(), GpuError> {
        let _ = (kernel_name, grid_size, block_size, args);
        Ok(())
    }

    /// Asynchronous host→device transfer. Stub: no-op.
    pub fn transfer_async_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Synchronous host→device transfer. Stub: no-op.
    pub fn transfer_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Asynchronous device→host transfer. Stub: no-op (`data` is untouched).
    pub fn transfer_async_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Synchronous device→host transfer. Stub: no-op (`data` is untouched).
    pub fn transfer_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Executes raw kernel source. Stub: only logs the request to stderr and
    /// reports success; no kernel is actually run.
    pub fn execute_kernel(
        &self,
        source: &str,
        buffers: &[GpuBuffer<f32>],
        work_groups: (u32, u32, u32),
        int_params: &[u32],
        float_params: &[f32],
    ) -> Result<(), GpuError> {
        eprintln!(
            "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
            source.len(),
            buffers.len(),
            work_groups
        );
        eprintln!("Int params: {int_params:?}");
        eprintln!("Float params: {float_params:?}");
        Ok(())
    }

    /// Downloads a buffer's contents into a `Vec`.
    pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
        Ok(buffer.to_vec())
    }

    // The reduction / GEMM / activation methods below all delegate to
    // `*_cpu_fallback` helpers that are defined elsewhere — presumably as
    // extension methods in the `cpu_ops` module (TODO confirm).

    /// Sum of all elements, returned as a 1-element buffer.
    pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_all_cpu_fallback(buffer)
    }

    /// Mean of all elements, returned as a 1-element buffer.
    pub fn mean_all<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.mean_all_cpu_fallback(buffer)
    }

    /// Maximum over all elements.
    pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.max_all_cpu_fallback(buffer)
    }

    /// Minimum over all elements.
    pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.min_all_cpu_fallback(buffer)
    }

    /// Sum reduction along `axis` of a tensor with the given `shape`.
    pub fn sum_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Mean reduction along `axis`.
    pub fn mean_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.mean_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Max reduction along `axis`.
    pub fn max_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.max_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Min reduction along `axis`.
    pub fn min_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.min_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Broadcasts a tensor from `from_shape` to `to_shape`.
    pub fn broadcast<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        from_shape: &[usize],
        to_shape: &[usize],
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.broadcast_cpu_fallback(buffer, from_shape, to_shape)
    }

    /// Element-wise multiplication by `scalar`.
    pub fn scale<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        scalar: T,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.scale_cpu_fallback(buffer, scalar)
    }

    /// Matrix multiply: `a` is m×k, `b` is k×n, result is m×n.
    pub fn gemm<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_cpu_fallback(a, b, m, k, n)
    }

    /// GEMM with `b` transposed.
    pub fn gemm_transpose_b<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_transpose_b_cpu_fallback(a, b, m, k, n)
    }

    /// GEMM with `a` transposed.
    pub fn gemm_transpose_a<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_transpose_a_cpu_fallback(a, b, m, k, n)
    }

    /// Element-wise ReLU.
    pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.relu_cpu_fallback(input)
    }

    /// Gradient of ReLU w.r.t. `input`.
    pub fn relu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.relu_backward_cpu_fallback(grad_output, input)
    }

    /// Element-wise sigmoid.
    pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.sigmoid_cpu_fallback(input)
    }

    /// Gradient of sigmoid w.r.t. `input`.
    pub fn sigmoid_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sigmoid_backward_cpu_fallback(grad_output, input)
    }

    /// Element-wise tanh.
    pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.tanh_cpu_fallback(input)
    }

    /// Gradient of tanh w.r.t. `input`.
    pub fn tanh_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.tanh_backward_cpu_fallback(grad_output, input)
    }

    /// Element-wise GELU.
    pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.gelu_cpu_fallback(input)
    }

    /// Gradient of GELU w.r.t. `input`.
    pub fn gelu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gelu_backward_cpu_fallback(grad_output, input)
    }
}
1013
1014impl fmt::Debug for GpuContext {
1015 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1016 f.debug_struct("GpuContext")
1017 .field("backend", &self.backend)
1018 .finish()
1019 }
1020}
1021
/// Backend interface for raw device buffers.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copies `size` bytes from host memory at `data` into the buffer.
    /// Caller must ensure `data` is valid for `size` bytes.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copies `size` bytes from the buffer into host memory at `data`.
    /// Caller must ensure `data` is writable for `size` bytes.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Downcast hook for backend-specific access.
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Buffer size in bytes; default reports 0 for backends that don't track it.
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0
    }

    /// Raw device address; default reports 0 when unavailable.
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0
    }
}
1049
/// Backend interface for a compiled kernel: parameter binding plus dispatch.
pub(crate) trait GpuKernelImpl: Send + Sync {
    /// Binds a buffer to the named kernel parameter.
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);

    fn set_u32(&self, name: &str, value: u32);

    fn set_i32(&self, name: &str, value: i32);

    fn set_f32(&self, name: &str, value: f32);

    fn set_f64(&self, name: &str, value: f64);

    /// Launches the kernel over `[x, y, z]` workgroup counts.
    fn dispatch(&self, workgroups: [u32; 3]);
}
1070
/// Backend interface for compiling kernel source into dispatchable kernels.
pub(crate) trait GpuCompilerImpl: Send + Sync {
    /// Compiles raw kernel source.
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;

    /// Resolves/compiles a named kernel for the given (type-erased) input and
    /// output element types.
    fn compile_typed(
        &self,
        name: &str,
        input_type: std::any::TypeId,
        output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl>;
}
1084
/// Backend interface for a device context: buffer allocation and compiler creation.
pub(crate) trait GpuContextImpl: Send + Sync {
    /// Allocates a raw buffer of `size` bytes.
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;

    /// Creates a kernel compiler for this backend.
    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;

    /// Downcast hook; default returns `self` for sized implementors.
    fn as_any(&self) -> &dyn std::any::Any
    where
        Self: 'static + Sized,
    {
        self
    }
}
1101
1102struct CpuContext;
1106
impl CpuContext {
    /// Creates the (stateless) CPU context.
    fn new() -> Self {
        Self
    }
}
1113
1114impl GpuContextImpl for CpuContext {
1115 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
1116 Arc::new(CpuBuffer::new(size))
1117 }
1118
1119 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
1120 Arc::new(CpuCompiler)
1121 }
1122}
1123
/// Host-memory buffer used by the CPU fallback backend.
struct CpuBuffer {
    data: Vec<u8>,
}
1128
1129impl CpuBuffer {
1130 fn new(size: usize) -> Self {
1132 Self {
1133 data: vec![0; size],
1134 }
1135 }
1136}
1137
1138impl GpuBufferImpl for CpuBuffer {
1139 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1140 let mut_self = self as *const Self as *mut Self;
1141 let data_ptr = (*mut_self).data.as_mut_ptr();
1142 std::ptr::copy_nonoverlapping(data, data_ptr, size);
1143 }
1144
1145 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1146 let data_ptr = self.data.as_ptr();
1147 std::ptr::copy_nonoverlapping(data_ptr, data, size);
1148 }
1149
1150 fn as_any(&self) -> &dyn std::any::Any {
1151 self
1152 }
1153
1154 fn size(&self) -> usize {
1155 self.data.len()
1156 }
1157
1158 fn device_ptr(&self) -> u64 {
1159 self.data.as_ptr() as u64
1160 }
1161}
1162
1163struct CpuCompiler;
1165
1166impl GpuCompilerImpl for CpuCompiler {
1167 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
1168 Ok(Arc::new(CpuKernel))
1171 }
1172
1173 fn compile_typed(
1174 &self,
1175 _name: &str,
1176 _input_type: std::any::TypeId,
1177 _output_type: std::any::TypeId,
1178 ) -> Arc<dyn GpuKernelImpl> {
1179 Arc::new(CpuKernel)
1182 }
1183}
1184
1185struct CpuKernel;
1187
1188impl GpuKernelImpl for CpuKernel {
1189 fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
1190 }
1192
1193 fn set_u32(&self, _name: &str, value: u32) {
1194 }
1196
1197 fn set_i32(&self, _name: &str, value: i32) {
1198 }
1200
1201 fn set_f32(&self, _name: &str, value: f32) {
1202 }
1204
1205 fn set_f64(&self, _name: &str, value: f64) {
1206 }
1208
1209 fn dispatch(&self, workgroups: [u32; 3]) {
1210 }
1212}
1213
1214#[cfg(test)]
1218mod tests {
1219 use super::*;
1220
    // `preferred()` must yield one of the known variants; the exhaustive match
    // doubles as a compile-time reminder when new variants are added.
    #[test]
    fn test_gpu_backend_preferred() {
        let backend = GpuBackend::preferred();
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    // `Default` must agree with `preferred()`.
    #[test]
    fn test_gpu_backend_default() {
        let backend = GpuBackend::default();
        assert_eq!(backend, GpuBackend::preferred());
    }
1240
    #[test]
    fn test_gpu_backend_is_available() {
        // CPU is unconditionally available.
        let backend = GpuBackend::Cpu;
        assert!(backend.is_available());

        #[cfg(feature = "cuda")]
        {
            // Runtime probe result is machine-dependent; just ensure no panic.
            let _ = GpuBackend::Cuda.is_available();
        }
        #[cfg(not(feature = "cuda"))]
        assert!(!GpuBackend::Cuda.is_available());

        #[cfg(feature = "rocm")]
        {
            let _ = GpuBackend::Rocm.is_available();
        }
        #[cfg(not(feature = "rocm"))]
        assert!(!GpuBackend::Rocm.is_available());

        // Metal availability is decided purely at compile time.
        #[cfg(all(feature = "metal", target_os = "macos"))]
        assert!(GpuBackend::Metal.is_available());
        #[cfg(not(all(feature = "metal", target_os = "macos")))]
        assert!(!GpuBackend::Metal.is_available());
    }
1268
    // Display strings are part of the public error-message surface.
    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
        assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
        assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
        assert_eq!(GpuBackend::Metal.to_string(), "Metal");
        assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
        assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
    }
1278
    // Each GpuError variant must map to its designated CoreError category.
    #[test]
    fn test_gpuerror_from_conversion() {
        let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::ComputationError(_) => {}
            _ => panic!("Expected ComputationError"),
        }

        let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::MemoryError(_) => {}
            _ => panic!("Expected MemoryError"),
        }

        let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::InvalidArgument(_) => {}
            _ => panic!("Expected InvalidArgument"),
        }

        let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::TypeError(_) => {}
            _ => panic!("Expected TypeError"),
        }
    }
1309
    // Compile-time check that all primitive numeric types implement GpuDataType.
    #[test]
    fn test_gpu_datatype_trait() {
        fn assert_gpu_datatype<T: GpuDataType>() {}

        assert_gpu_datatype::<f32>();
        assert_gpu_datatype::<f64>();
        assert_gpu_datatype::<i32>();
        assert_gpu_datatype::<u32>();
        assert_gpu_datatype::<u8>();
        assert_gpu_datatype::<i8>();
        assert_gpu_datatype::<u16>();
        assert_gpu_datatype::<i16>();
        assert_gpu_datatype::<u64>();
        assert_gpu_datatype::<i64>();
    }
1326
    // Byte capacity (100) and element count (25) are tracked independently.
    #[test]
    fn test_gpu_buffer_creation() {
        let inner = Arc::new(CpuBuffer::new(100));
        let buffer = GpuBuffer::<f32>::new(inner, 25);

        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_empty() {
        let inner = Arc::new(CpuBuffer::new(0));
        let buffer = GpuBuffer::<f32>::new(inner, 0);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
    }

    // Round-trip: host -> CPU buffer -> host must preserve the data.
    #[test]
    fn test_gpu_buffer_copy_operations() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = buffer.copy_to_host(&mut result);

        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_buffer_to_vec() {
        let inner = Arc::new(CpuBuffer::new(12));
        let buffer = GpuBuffer::<f32>::new(inner, 3);

        let data = vec![5.0f32, 6.0, 7.0];
        let _ = buffer.copy_from_host(&data);

        let result = buffer.to_vec();
        assert_eq!(result, data);
    }
1370
    // Oversized uploads must be rejected; `.expect` turns the error into the
    // panic the `should_panic` attribute looks for.
    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_from_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let data = vec![1.0f32, 2.0, 3.0];
        buffer.copy_from_host(&data).expect("Operation failed");
    }

    // Same contract for oversized downloads.
    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_to_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let mut data = vec![0.0f32; 3];
        buffer.copy_to_host(&mut data).expect("Operation failed");
    }
1390
    // Smoke test: every setter plus dispatch must be callable on the CPU stub
    // kernel without panicking (all are no-ops).
    #[test]
    fn test_gpu_kernel_handle() {
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 100);
        handle.set_i32("offset", -5);
        handle.set_f32("scale", 2.5);
        handle.set_f64("precision", 0.0001);

        handle.dispatch([16, 8, 1]);
    }
1407
    // The CPU backend must always construct, and the memory stubs report their
    // hard-coded values (1 GiB available, 4 GiB total off-wasm).
    #[test]
    fn test_gpu_context_cpu_backend() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
        assert_eq!(context.backend(), GpuBackend::Cpu);
        assert_eq!(context.backend_name(), "CPU");

        assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
        assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
    }

    // Buffer allocation and slice upload round-trip through the CPU context.
    #[test]
    fn test_gpu_context_buffer_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let buffer = context.create_buffer::<f32>(100);
        assert_eq!(buffer.len(), 100);

        let data = vec![1.0f32; 50];
        let buffer_from_slice = context.create_buffer_from_slice(&data);
        assert_eq!(buffer_from_slice.len(), 50);

        let result = buffer_from_slice.to_vec();
        assert_eq!(result, data);
    }

    // Without the cuda feature, constructing a CUDA context must fail with one
    // of the two availability errors (which one depends on detection order).
    #[test]
    fn test_gpu_context_unsupported_backend() {
        #[cfg(not(feature = "cuda"))]
        {
            let result = GpuContext::new(GpuBackend::Cuda);
            assert!(result.is_err());
            match result {
                Err(GpuError::UnsupportedBackend(_)) => {}
                Err(GpuError::BackendNotAvailable(_)) => {}
                Err(e) => panic!(
                    "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
                    e
                ),
                Ok(_) => panic!("Expected error, got Ok"),
            }
        }
    }
1452
1453 #[test]
1454 fn test_gpu_compiler() {
1455 let compiler_impl = Arc::new(CpuCompiler);
1456 let compiler = GpuCompiler::new(compiler_impl);
1457
1458 let kernel = compiler
1460 .compile("dummy kernel source")
1461 .expect("Operation failed");
1462 kernel.dispatch([1, 1, 1]);
1463
1464 let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
1466 typed_kernel.dispatch([32, 1, 1]);
1467 }
1468
1469 #[test]
1470 fn test_gpu_context_execute() {
1471 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1472
1473 let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());
1474
1475 assert!(result);
1476 }
1477
1478 #[test]
1479 fn test_gpu_context_kernel_registry() {
1480 let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1481
1482 let result = context.get_kernel("non_existent_kernel");
1484 assert!(result.is_err());
1485 match result {
1486 Err(GpuError::KernelNotFound(_)) => {}
1487 _ => panic!("Expected KernelNotFound error"),
1488 }
1489 }
1490
1491 #[test]
1492 fn test_cpu_buffer_implementation() {
1493 let buffer = CpuBuffer::new(256);
1494 assert_eq!(buffer.data.len(), 256);
1495
1496 assert!(buffer.data.iter().all(|&b| b == 0));
1498 }
1499
1500 #[test]
1501 fn test_gpuerror_display() {
1502 let error = GpuError::BackendNotAvailable("CUDA".to_string());
1503 assert_eq!(error.to_string(), "GPU backend CUDA is not available");
1504
1505 let error = GpuError::OutOfMemory("allocation failed".to_string());
1506 assert_eq!(error.to_string(), "GPU out of memory: allocation failed");
1507
1508 let error = GpuError::KernelCompilationError("syntax error".to_string());
1509 assert_eq!(error.to_string(), "Kernel compilation error: syntax error");
1510
1511 let error = GpuError::KernelNotFound("gemm".to_string());
1512 assert_eq!(error.to_string(), "Kernel not found: gemm");
1513 }
1514
1515 #[test]
1516 fn test_backend_equality() {
1517 assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
1518 assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);
1519
1520 let backend = GpuBackend::Metal;
1522 let cloned = backend;
1523 let copied = backend;
1524 assert_eq!(backend, cloned);
1525 assert_eq!(backend, copied);
1526 }
1527
1528 #[test]
1529 fn test_backend_hash() {
1530 use std::collections::HashSet;
1531
1532 let mut set = HashSet::new();
1533 set.insert(GpuBackend::Cuda);
1534 set.insert(GpuBackend::Rocm);
1535 set.insert(GpuBackend::Cuda); assert_eq!(set.len(), 2); assert!(set.contains(&GpuBackend::Cuda));
1539 assert!(set.contains(&GpuBackend::Rocm));
1540 }
1541
1542 #[test]
1543 fn test_gpu_buffer_debug_clone() {
1544 let inner = Arc::new(CpuBuffer::new(16));
1545 let buffer = GpuBuffer::<f32>::new(inner, 4);
1546
1547 let debug_str = format!("{:?}", buffer);
1549 assert!(debug_str.contains("GpuBuffer"));
1550 assert!(debug_str.contains("size"));
1551
1552 let cloned = buffer.clone();
1554 assert_eq!(cloned.len(), buffer.len());
1555 assert_eq!(cloned.len(), 4);
1556
1557 let data = vec![1.0f32, 2.0, 3.0, 4.0];
1559 let _ = buffer.copy_from_host(&data);
1560
1561 let mut result = vec![0.0f32; 4];
1562 let _ = cloned.copy_to_host(&mut result);
1563 assert_eq!(result, data);
1564 }
1565
1566 #[test]
1567 fn test_gpu_context_debug() {
1568 let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
1569
1570 let debug_str = format!("{:?}", context);
1572 assert!(debug_str.contains("GpuContext"));
1573 assert!(debug_str.contains("backend"));
1574 assert!(debug_str.contains("Cpu"));
1575 }
1576}