1use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod auto_tuning;
11pub mod backends;
12pub mod benchmarks;
13pub mod heterogeneous;
14pub mod kernels;
15pub mod tensor_cores;
16
/// Compute backends that the GPU abstraction layer can target.
///
/// The enum is `Copy` and hashable so it can be passed by value and used as a
/// map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// CUDA backend.
    Cuda,
    /// ROCm backend.
    Rocm,
    /// WebGPU backend (via `wgpu`).
    Wgpu,
    /// Metal backend.
    Metal,
    /// OpenCL backend.
    OpenCL,
    /// CPU backend; always available and used as the fallback.
    Cpu,
}
33
34impl Default for GpuBackend {
35 fn default() -> Self {
36 Self::preferred()
37 }
38}
39
40impl GpuBackend {
41 pub fn preferred() -> Self {
43 match backends::initialize_optimal_backend() {
46 Ok(backend) => {
47 if backend != GpuBackend::Cpu {
49 #[cfg(not(test))]
52 {
53 return GpuBackend::Cpu;
55 }
56 #[cfg(test)]
57 {
58 return backend;
60 }
61 }
62 backend
63 }
64 Err(_) => {
65 GpuBackend::Cpu
67 }
68 }
69 }
70
71 pub fn is_available(&self) -> bool {
73 match self {
74 GpuBackend::Cuda => {
76 #[cfg(feature = "cuda")]
77 {
78 use crate::gpu::backends::cuda::CudaContext;
79 CudaContext::is_available()
80 }
81 #[cfg(not(feature = "cuda"))]
82 {
83 false
84 }
85 }
86 GpuBackend::Rocm => cfg!(feature = "rocm"), GpuBackend::Wgpu => {
88 #[cfg(feature = "wgpu_backend")]
89 {
90 use crate::gpu::backends::wgpu::WebGPUContext;
91 WebGPUContext::is_available()
92 }
93 #[cfg(not(feature = "wgpu_backend"))]
94 {
95 false
96 }
97 }
98 GpuBackend::Metal => {
99 #[cfg(all(feature = "metal", target_os = "macos"))]
100 {
101 true
103 }
104 #[cfg(not(all(feature = "metal", target_os = "macos")))]
105 {
106 false
107 }
108 }
109 GpuBackend::OpenCL => {
110 #[cfg(feature = "opencl")]
111 {
112 use crate::gpu::backends::opencl::OpenCLContext;
113 OpenCLContext::is_available()
114 }
115 #[cfg(not(feature = "opencl"))]
116 {
117 false
118 }
119 }
120 GpuBackend::Cpu => true,
121 }
122 }
123}
124
125impl fmt::Display for GpuBackend {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 match self {
128 GpuBackend::Cuda => write!(f, "CUDA"),
129 GpuBackend::Rocm => write!(f, "ROCm"),
130 GpuBackend::Wgpu => write!(f, "WebGPU"),
131 GpuBackend::Metal => write!(f, "Metal"),
132 GpuBackend::OpenCL => write!(f, "OpenCL"),
133 GpuBackend::Cpu => write!(f, "CPU"),
134 }
135 }
136}
137
138use crate::error::{CoreError, ErrorContext, ErrorLocation};
139
140#[derive(Debug, thiserror::Error)]
142pub enum GpuError {
143 #[error("GPU backend {0} is not available")]
145 BackendNotAvailable(String),
146
147 #[error("GPU backend {0} is not supported")]
149 UnsupportedBackend(GpuBackend),
150
151 #[error("GPU backend {0:?} is not supported for this kernel")]
153 BackendNotSupported(GpuBackend),
154
155 #[error("GPU backend {0} is not implemented yet")]
157 BackendNotImplemented(GpuBackend),
158
159 #[error("GPU out of memory: {0}")]
161 OutOfMemory(String),
162
163 #[error("Kernel compilation error: {0}")]
165 KernelCompilationError(String),
166
167 #[error("Kernel execution error: {0}")]
169 KernelExecutionError(String),
170
171 #[error("Invalid parameter: {0}")]
173 InvalidParameter(String),
174
175 #[error("Kernel not found: {0}")]
177 KernelNotFound(String),
178
179 #[error("Kernel specialization not supported")]
181 SpecializationNotSupported,
182
183 #[error("Unsupported data type: {0:?}")]
185 UnsupportedDataType(kernels::DataType),
186
187 #[error("{0}")]
189 Other(String),
190}
191
192#[derive(Debug, Clone, Copy, PartialEq, Eq)]
194pub struct GpuDevice {
195 backend: GpuBackend,
196 device_id: usize,
197}
198
199impl GpuDevice {
200 pub fn new(backend: GpuBackend, device_id: usize) -> Self {
202 Self { backend, device_id }
203 }
204
205 pub fn backend(&self) -> GpuBackend {
207 self.backend
208 }
209
210 pub fn device_id(&self) -> usize {
212 self.device_id
213 }
214
215 pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
217 Ok(GpuKernel {
219 backend: self.backend,
220 entry_point: entrypoint.to_string(),
221 })
222 }
223}
224
225pub struct GpuKernel {
227 backend: GpuBackend,
228 entry_point: String,
229}
230
231impl GpuKernel {
232 pub fn backend(&self) -> GpuBackend {
234 self.backend
235 }
236
237 pub fn entry_point(&self) -> &str {
239 &self.entry_point
240 }
241}
242
243impl From<GpuError> for CoreError {
245 fn from(err: GpuError) -> Self {
246 match err {
247 GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
248 ErrorContext::new(format!("GPU backend {backend} is not available"))
249 .with_location(ErrorLocation::new(file!(), line!())),
250 ),
251 GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
252 ErrorContext::new(format!("GPU backend {backend} is not supported"))
253 .with_location(ErrorLocation::new(file!(), line!())),
254 ),
255 GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
256 ErrorContext::new(format!(
257 "GPU backend {backend:?} is not supported for this kernel"
258 ))
259 .with_location(ErrorLocation::new(file!(), line!())),
260 ),
261 GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
262 ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
263 .with_location(ErrorLocation::new(file!(), line!())),
264 ),
265 GpuError::OutOfMemory(details) => CoreError::MemoryError(
266 ErrorContext::new(details.to_string())
267 .with_location(ErrorLocation::new(file!(), line!())),
268 ),
269 GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
270 ErrorContext::new(msg.to_string())
271 .with_location(ErrorLocation::new(file!(), line!())),
272 ),
273 GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
274 ErrorContext::new(msg.to_string())
275 .with_location(ErrorLocation::new(file!(), line!())),
276 ),
277 GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
278 ErrorContext::new(msg.to_string())
279 .with_location(ErrorLocation::new(file!(), line!())),
280 ),
281 GpuError::KernelNotFound(name) => CoreError::ComputationError(
282 ErrorContext::new(name.to_string())
283 .with_location(ErrorLocation::new(file!(), line!())),
284 ),
285 GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
286 ErrorContext::new("Kernel specialization not supported".to_string())
287 .with_location(ErrorLocation::new(file!(), line!())),
288 ),
289 GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
290 ErrorContext::new(format!("{dtype:?}"))
291 .with_location(ErrorLocation::new(file!(), line!())),
292 ),
293 GpuError::Other(msg) => CoreError::ComputationError(
294 ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
295 ),
296 }
297 }
298}
299
/// Marker trait for element types that may be stored in GPU buffers.
///
/// Implementors must be plain `Copy` values that are safe to share across
/// threads and contain no borrowed data (`'static`).
pub trait GpuDataType: Copy + Send + Sync + 'static {}
302
303#[derive(Debug)]
305pub struct GpuPtr<T: GpuDataType> {
306 ptr: u64,
307 size: usize,
308 phantom: PhantomData<T>,
309}
310
311impl<T: GpuDataType> GpuPtr<T> {
312 pub fn allocate(size: usize) -> Result<Self, GpuError> {
314 Ok(GpuPtr {
315 ptr: 0x1000_0000, size,
317 phantom: PhantomData,
318 })
319 }
320
321 pub fn as_ptr(&self) -> u64 {
323 self.ptr
324 }
325
326 pub fn len(&self) -> usize {
328 self.size
329 }
330
331 pub fn is_empty(&self) -> bool {
333 self.size == 0
334 }
335}
336
337#[derive(Debug, Clone)]
339pub enum KernelArg<'a, T: GpuDataType> {
340 Buffer(&'a GpuPtr<T>),
342 Scalar(T),
344}
345
/// Kernel argument whose type is chosen at run time.
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    /// A buffer, given as a raw `u64` value (presumably a device address — confirm against callers).
    Buffer(u64),
    /// A 32-bit float scalar.
    F32(f32),
    /// A 64-bit float scalar.
    F64(f64),
    /// A signed 32-bit integer scalar.
    I32(i32),
    /// An unsigned 32-bit integer scalar.
    U32(u32),
    /// A pointer-sized unsigned integer scalar.
    Usize(usize),
}
362
/// Describes a link between two devices.
///
/// The fields are not read anywhere in this file yet, hence the
/// `dead_code` allowances.
pub struct GpuChannel {
    /// Index of the device at the sending end.
    #[allow(dead_code)]
    source_device: usize,
    /// Index of the device at the receiving end.
    #[allow(dead_code)]
    target_device: usize,
    /// Link bandwidth (unit not established in this file — TODO confirm).
    #[allow(dead_code)]
    bandwidth: f64,
}
372
373impl GpuDataType for f32 {}
375impl GpuDataType for f64 {}
376impl GpuDataType for i32 {}
377impl GpuDataType for u32 {}
378impl GpuDataType for u8 {}
379impl GpuDataType for i8 {}
380impl GpuDataType for u16 {}
381impl GpuDataType for i16 {}
382impl GpuDataType for u64 {}
383impl GpuDataType for i64 {}
384impl GpuDataType for usize {}
385impl GpuDataType for isize {}
386
387pub struct GpuBuffer<T: GpuDataType> {
389 inner: Arc<dyn GpuBufferImpl>,
390 size: usize,
391 phantom: PhantomData<T>,
392}
393
394impl<T: GpuDataType> GpuBuffer<T> {
395 pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
397 Self {
398 inner,
399 size,
400 phantom: PhantomData,
401 }
402 }
403
404 pub fn len(&self) -> usize {
406 self.size
407 }
408
409 pub fn is_empty(&self) -> bool {
411 self.size == 0
412 }
413
414 pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
416 if data.len() > self.size {
417 return Err(GpuError::InvalidParameter(
418 "Data size exceeds buffer size".to_string(),
419 ));
420 }
421 unsafe {
422 self.inner
423 .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
424 }
425 Ok(())
426 }
427
428 pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
430 if data.len() > self.size {
431 return Err(GpuError::InvalidParameter(
432 "Data size exceeds buffer size".to_string(),
433 ));
434 }
435 unsafe {
436 self.inner
437 .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
438 }
439 Ok(())
440 }
441
442 pub fn to_vec(&self) -> Vec<T> {
444 let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
445 let _ = self.copy_to_host(&mut result);
446 result
447 }
448}
449
450#[derive(Clone)]
452pub struct GpuKernelHandle {
453 inner: Arc<dyn GpuKernelImpl>,
454}
455
456impl GpuKernelHandle {
457 pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
459 Self { inner }
460 }
461
462 pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
464 self.inner.set_buffer(name, &buffer.inner);
465 }
466
467 pub fn set_u32(&self, name: &str, value: u32) {
469 self.inner.set_u32(name, value);
470 }
471
472 pub fn set_i32(&self, name: &str, value: i32) {
474 self.inner.set_i32(name, value);
475 }
476
477 pub fn set_f32(&self, name: &str, value: f32) {
479 self.inner.set_f32(name, value);
480 }
481
482 pub fn set_f64(&self, name: &str, value: f64) {
484 self.inner.set_f64(name, value);
485 }
486
487 pub fn dispatch(&self, workgroups: [u32; 3]) {
489 self.inner.dispatch(workgroups);
490 }
491}
492
493pub struct GpuCompiler {
495 inner: Arc<dyn GpuCompilerImpl>,
496}
497
498impl GpuCompiler {
499 pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
501 Self { inner }
502 }
503
504 pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
506 let kernel = self.inner.compile(source)?;
507 Ok(GpuKernelHandle::new(kernel))
508 }
509
510 pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
512 let kernel = self.inner.compile_typed(
513 name,
514 std::any::TypeId::of::<I>(),
515 std::any::TypeId::of::<O>(),
516 );
517 GpuKernelHandle::new(kernel)
518 }
519}
520
521pub struct GpuContext {
523 inner: Arc<dyn GpuContextImpl>,
524 backend: GpuBackend,
525 kernel_registry: kernels::KernelRegistry,
526}
527
528impl GpuContext {
529 pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
531 if !backend.is_available() {
533 return Err(GpuError::BackendNotAvailable(backend.to_string()));
534 }
535
536 if backend != GpuBackend::Cpu {
538 let detection_result = backends::detect_gpu_backends();
539 let backend_available = detection_result
540 .devices
541 .iter()
542 .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);
543
544 if !backend_available {
545 return Err(GpuError::BackendNotAvailable(format!(
546 "{backend} (no devices detected at runtime)"
547 )));
548 }
549 }
550
551 let inner = match backend {
552 GpuBackend::Cuda => {
553 #[cfg(feature = "cuda")]
554 {
555 use crate::gpu::backends::cuda::CudaContext;
556 match CudaContext::new() {
557 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
558 Err(e) => return Err(e),
559 }
560 }
561 #[cfg(not(feature = "cuda"))]
562 {
563 return Err(GpuError::UnsupportedBackend(backend));
564 }
565 }
566 GpuBackend::Rocm => {
567 #[cfg(feature = "rocm")]
568 {
569 #[cfg(test)]
572 {
573 Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
575 }
576 #[cfg(not(test))]
577 {
578 return Err(GpuError::BackendNotImplemented(backend));
579 }
580 }
581 #[cfg(not(feature = "rocm"))]
582 {
583 return Err(GpuError::UnsupportedBackend(backend));
584 }
585 }
586 GpuBackend::Wgpu => {
587 #[cfg(feature = "wgpu_backend")]
588 {
589 use crate::gpu::backends::wgpu::WebGPUContext;
590 match WebGPUContext::new() {
591 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
592 Err(e) => return Err(e),
593 }
594 }
595 #[cfg(not(feature = "wgpu_backend"))]
596 {
597 return Err(GpuError::UnsupportedBackend(backend));
598 }
599 }
600 GpuBackend::Metal => {
601 #[cfg(all(feature = "metal", target_os = "macos"))]
602 {
603 use crate::gpu::backends::metal::MetalContext;
604 match MetalContext::new() {
605 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
606 Err(e) => return Err(e),
607 }
608 }
609 #[cfg(not(all(feature = "metal", target_os = "macos")))]
610 {
611 return Err(GpuError::UnsupportedBackend(backend));
612 }
613 }
614 GpuBackend::OpenCL => {
615 #[cfg(feature = "opencl")]
616 {
617 use crate::gpu::backends::opencl::OpenCLContext;
618 match OpenCLContext::new() {
619 Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
620 Err(e) => return Err(e),
621 }
622 }
623 #[cfg(not(feature = "opencl"))]
624 {
625 return Err(GpuError::UnsupportedBackend(backend));
626 }
627 }
628 GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
629 };
630
631 Ok(Self {
632 inner,
633 backend,
634 kernel_registry: kernels::KernelRegistry::with_default_kernels(),
635 })
636 }
637
638 pub fn backend(&self) -> GpuBackend {
640 self.backend
641 }
642
643 pub fn backend_name(&self) -> &str {
645 match self.backend {
646 GpuBackend::Cuda => "CUDA",
647 GpuBackend::Rocm => "ROCm",
648 GpuBackend::Wgpu => "WebGPU",
649 GpuBackend::Metal => "Metal",
650 GpuBackend::OpenCL => "OpenCL",
651 GpuBackend::Cpu => "CPU",
652 }
653 }
654
655 pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
657 let byte_size = size.saturating_mul(std::mem::size_of::<T>());
658 let inner = self.inner.create_buffer(byte_size);
659 GpuBuffer::new(inner, size)
660 }
661
662 pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
664 let buffer = self.create_buffer::<T>(data.len());
665 let _ = buffer.copy_from_host(data);
666 buffer
667 }
668
669 pub fn execute<F, R>(&self, f: F) -> R
671 where
672 F: FnOnce(&GpuCompiler) -> R,
673 {
674 let compiler = GpuCompiler::new(self.inner.create_compiler());
675 f(&compiler)
676 }
677
678 pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
680 let kernel = self
681 .kernel_registry
682 .get(name)
683 .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;
684
685 let kernel_source = kernel.source_for_backend(self.backend)?;
686 let metadata = kernel.metadata();
687
688 let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
689 Ok(handle)
690 }
691
692 pub fn get_specialized_kernel(
694 &self,
695 name: &str,
696 params: &kernels::KernelParams,
697 ) -> Result<GpuKernelHandle, GpuError> {
698 let specialized = self.kernel_registry.get_specialized(name, params)?;
699 let kernel_source = specialized.source_for_backend(self.backend)?;
700 let metadata = specialized.metadata();
701
702 let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
703 Ok(handle)
704 }
705
706 fn compile_kernel_with_metadata(
708 &self,
709 source: &str,
710 _metadata: &kernels::KernelMetadata,
711 ) -> Result<GpuKernelHandle, GpuError> {
712 self.execute(|compiler| compiler.compile(source))
713 }
714
715 pub fn get_available_memory(&self) -> Option<usize> {
717 Some(1024 * 1024 * 1024) }
721
722 pub fn get_total_memory(&self) -> Option<usize> {
724 Some(4 * 1024 * 1024 * 1024) }
728
729 pub fn launch_kernel(
731 &self,
732 kernel_name: &str,
733 grid_size: (usize, usize, usize),
734 block_size: (usize, usize, usize),
735 args: &[DynamicKernelArg],
736 ) -> Result<(), GpuError> {
737 let _ = (kernel_name, grid_size, block_size, args);
739 Ok(())
740 }
741
742 pub fn transfer_async_host_to_device<T: GpuDataType>(
744 &self,
745 ptr: &GpuPtr<T>,
746 data: &[T],
747 ) -> Result<(), GpuError> {
748 let _ = (ptr, data);
750 Ok(())
751 }
752
753 pub fn transfer_host_to_device<T: GpuDataType>(
755 &self,
756 ptr: &GpuPtr<T>,
757 data: &[T],
758 ) -> Result<(), GpuError> {
759 let _ = (ptr, data);
761 Ok(())
762 }
763
764 pub fn transfer_async_device_to_host<T: GpuDataType>(
766 &self,
767 ptr: &GpuPtr<T>,
768 data: &mut [T],
769 ) -> Result<(), GpuError> {
770 let _ = (ptr, data);
772 Ok(())
773 }
774
775 pub fn transfer_device_to_host<T: GpuDataType>(
777 &self,
778 ptr: &GpuPtr<T>,
779 data: &mut [T],
780 ) -> Result<(), GpuError> {
781 let _ = (ptr, data);
783 Ok(())
784 }
785
786 pub fn execute_kernel(
789 &self,
790 source: &str,
791 buffers: &[GpuBuffer<f32>],
792 work_groups: (u32, u32, u32),
793 int_params: &[u32],
794 float_params: &[f32],
795 ) -> Result<(), GpuError> {
796 eprintln!(
799 "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
800 source.len(),
801 buffers.len(),
802 work_groups
803 );
804 eprintln!("Int params: {int_params:?}");
805 eprintln!("Float params: {float_params:?}");
806 Ok(())
807 }
808
809 pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
812 Ok(buffer.to_vec())
813 }
814}
815
/// Backend-side storage behind a `GpuBuffer`.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copies `size` bytes from host memory at `data` into the buffer.
    ///
    /// # Safety
    ///
    /// `data` must be valid for reads of `size` bytes, and `size` must not
    /// exceed the buffer's capacity; implementations may copy without
    /// further checks.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copies `size` bytes from the buffer into host memory at `data`.
    ///
    /// # Safety
    ///
    /// `data` must be valid for writes of `size` bytes, and `size` must not
    /// exceed the buffer's capacity; implementations may copy without
    /// further checks.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Exposes the concrete buffer type for downcasting.
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Buffer size as reported by the implementation; defaults to 0.
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0
    }

    /// Device address as reported by the implementation; defaults to 0.
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0
    }
}
843
844pub(crate) trait GpuKernelImpl: Send + Sync {
846 fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);
848
849 fn set_u32(&self, name: &str, value: u32);
851
852 fn set_i32(&self, name: &str, value: i32);
854
855 fn set_f32(&self, name: &str, value: f32);
857
858 fn set_f64(&self, name: &str, value: f64);
860
861 fn dispatch(&self, workgroups: [u32; 3]);
863}
864
865pub(crate) trait GpuCompilerImpl: Send + Sync {
867 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;
869
870 fn compile_typed(
872 &self,
873 name: &str,
874 input_type: std::any::TypeId,
875 output_type: std::any::TypeId,
876 ) -> Arc<dyn GpuKernelImpl>;
877}
878
879pub(crate) trait GpuContextImpl: Send + Sync {
881 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;
883
884 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;
886
887 fn as_any(&self) -> &dyn std::any::Any
889 where
890 Self: 'static + Sized,
891 {
892 self
893 }
894}
895
/// Context of the CPU fallback backend; carries no state.
struct CpuContext;
900
901impl CpuContext {
902 fn new() -> Self {
904 Self
905 }
906}
907
908impl GpuContextImpl for CpuContext {
909 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
910 Arc::new(CpuBuffer::new(size))
911 }
912
913 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
914 Arc::new(CpuCompiler)
915 }
916}
917
/// Buffer of the CPU fallback backend, stored in host memory.
struct CpuBuffer {
    /// Raw bytes of the buffer.
    data: Vec<u8>,
}
922
923impl CpuBuffer {
924 fn new(size: usize) -> Self {
926 Self {
927 data: vec![0; size],
928 }
929 }
930}
931
932impl GpuBufferImpl for CpuBuffer {
933 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
934 let mut_self = self as *const Self as *mut Self;
935 let data_ptr = (*mut_self).data.as_mut_ptr();
936 std::ptr::copy_nonoverlapping(data, data_ptr, size);
937 }
938
939 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
940 let data_ptr = self.data.as_ptr();
941 std::ptr::copy_nonoverlapping(data_ptr, data, size);
942 }
943
944 fn as_any(&self) -> &dyn std::any::Any {
945 self
946 }
947
948 fn size(&self) -> usize {
949 self.data.len()
950 }
951
952 fn device_ptr(&self) -> u64 {
953 self.data.as_ptr() as u64
954 }
955}
956
/// Compiler of the CPU fallback backend; carries no state.
struct CpuCompiler;
959
960impl GpuCompilerImpl for CpuCompiler {
961 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
962 Ok(Arc::new(CpuKernel))
965 }
966
967 fn compile_typed(
968 &self,
969 _name: &str,
970 _input_type: std::any::TypeId,
971 _output_type: std::any::TypeId,
972 ) -> Arc<dyn GpuKernelImpl> {
973 Arc::new(CpuKernel)
976 }
977}
978
/// Kernel of the CPU fallback backend; carries no state.
struct CpuKernel;
981
982impl GpuKernelImpl for CpuKernel {
983 fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
984 }
986
987 fn set_u32(&self, _name: &str, value: u32) {
988 }
990
991 fn set_i32(&self, _name: &str, value: i32) {
992 }
994
995 fn set_f32(&self, _name: &str, value: f32) {
996 }
998
999 fn set_f64(&self, _name: &str, value: f64) {
1000 }
1002
1003 fn dispatch(&self, workgroups: [u32; 3]) {
1004 }
1006}
1007
/// Unit tests for the backend enum, error conversions, buffers, kernels and
/// the CPU-backed context.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_backend_preferred() {
        let backend = GpuBackend::preferred();
        // Any variant is acceptable; the exhaustive match keeps this test in
        // sync with the enum when variants are added.
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_gpu_backend_default() {
        let backend = GpuBackend::default();
        assert_eq!(backend, GpuBackend::preferred());
    }

    #[test]
    fn test_gpu_backend_is_available() {
        let backend = GpuBackend::Cpu;
        assert!(backend.is_available());

        // With the feature enabled the answer depends on the machine, so the
        // call is only exercised, not asserted.
        #[cfg(feature = "cuda")]
        {
            let _ = GpuBackend::Cuda.is_available();
        }
        #[cfg(not(feature = "cuda"))]
        assert!(!GpuBackend::Cuda.is_available());

        #[cfg(feature = "rocm")]
        {
            let _ = GpuBackend::Rocm.is_available();
        }
        #[cfg(not(feature = "rocm"))]
        assert!(!GpuBackend::Rocm.is_available());

        #[cfg(all(feature = "metal", target_os = "macos"))]
        assert!(GpuBackend::Metal.is_available());
        #[cfg(not(all(feature = "metal", target_os = "macos")))]
        assert!(!GpuBackend::Metal.is_available());
    }

    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
        assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
        assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
        assert_eq!(GpuBackend::Metal.to_string(), "Metal");
        assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
        assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
    }

    #[test]
    fn test_gpuerror_from_conversion() {
        let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::ComputationError(_) => {}
            _ => panic!("Expected ComputationError"),
        }

        let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::MemoryError(_) => {}
            _ => panic!("Expected MemoryError"),
        }

        let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::InvalidArgument(_) => {}
            _ => panic!("Expected InvalidArgument"),
        }

        let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::TypeError(_) => {}
            _ => panic!("Expected TypeError"),
        }
    }

    #[test]
    fn test_gpu_datatype_trait() {
        // Compiles only if `T` implements `GpuDataType`.
        fn assert_gpu_datatype<T: GpuDataType>() {}

        assert_gpu_datatype::<f32>();
        assert_gpu_datatype::<f64>();
        assert_gpu_datatype::<i32>();
        assert_gpu_datatype::<u32>();
        assert_gpu_datatype::<u8>();
        assert_gpu_datatype::<i8>();
        assert_gpu_datatype::<u16>();
        assert_gpu_datatype::<i16>();
        assert_gpu_datatype::<u64>();
        assert_gpu_datatype::<i64>();
        assert_gpu_datatype::<usize>();
        assert_gpu_datatype::<isize>();
    }

    #[test]
    fn test_gpu_buffer_creation() {
        let inner = Arc::new(CpuBuffer::new(100));
        let buffer = GpuBuffer::<f32>::new(inner, 25);

        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_empty() {
        let inner = Arc::new(CpuBuffer::new(0));
        let buffer = GpuBuffer::<f32>::new(inner, 0);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_copy_operations() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        // Copy failures are surfaced directly instead of being discarded and
        // showing up later as a confusing data mismatch.
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        buffer
            .copy_from_host(&data)
            .expect("in-bounds upload should succeed");

        let mut result = vec![0.0f32; 4];
        buffer
            .copy_to_host(&mut result)
            .expect("in-bounds download should succeed");

        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_buffer_to_vec() {
        let inner = Arc::new(CpuBuffer::new(12));
        let buffer = GpuBuffer::<f32>::new(inner, 3);

        let data = vec![5.0f32, 6.0, 7.0];
        buffer
            .copy_from_host(&data)
            .expect("in-bounds upload should succeed");

        let result = buffer.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_from_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        // Three elements into a two-element buffer must be rejected.
        let data = vec![1.0f32, 2.0, 3.0];
        buffer.copy_from_host(&data).expect("Operation failed");
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_to_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        // Reading three elements out of a two-element buffer must be rejected.
        let mut data = vec![0.0f32; 3];
        buffer.copy_to_host(&mut data).expect("Operation failed");
    }

    #[test]
    fn test_gpu_kernel_handle() {
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        // The CPU kernel ignores everything; this only checks nothing panics.
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 100);
        handle.set_i32("offset", -5);
        handle.set_f32("scale", 2.5);
        handle.set_f64("precision", 0.0001);

        handle.dispatch([16, 8, 1]);
    }

    #[test]
    fn test_gpu_context_cpu_backend() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
        assert_eq!(context.backend(), GpuBackend::Cpu);
        assert_eq!(context.backend_name(), "CPU");

        assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
        assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
    }

    #[test]
    fn test_gpu_context_buffer_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let buffer = context.create_buffer::<f32>(100);
        assert_eq!(buffer.len(), 100);

        let data = vec![1.0f32; 50];
        let buffer_from_slice = context.create_buffer_from_slice(&data);
        assert_eq!(buffer_from_slice.len(), 50);

        let result = buffer_from_slice.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_unsupported_backend() {
        #[cfg(not(feature = "cuda"))]
        {
            match GpuContext::new(GpuBackend::Cuda) {
                Err(GpuError::UnsupportedBackend(_)) | Err(GpuError::BackendNotAvailable(_)) => {}
                Err(e) => panic!(
                    "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
                    e
                ),
                Ok(_) => panic!("Expected error, got Ok"),
            }
        }
    }

    #[test]
    fn test_gpu_compiler() {
        let compiler_impl = Arc::new(CpuCompiler);
        let compiler = GpuCompiler::new(compiler_impl);

        let kernel = compiler
            .compile("dummy kernel source")
            .expect("Operation failed");
        kernel.dispatch([1, 1, 1]);

        let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
        typed_kernel.dispatch([32, 1, 1]);
    }

    #[test]
    fn test_gpu_context_execute() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());

        assert!(result);
    }

    #[test]
    fn test_gpu_context_kernel_registry() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let result = context.get_kernel("non_existent_kernel");
        assert!(
            matches!(result, Err(GpuError::KernelNotFound(_))),
            "Expected KernelNotFound error"
        );
    }

    #[test]
    fn test_cpu_buffer_implementation() {
        let buffer = CpuBuffer::new(256);
        assert_eq!(buffer.data.len(), 256);

        // Fresh buffers are zero-initialised.
        assert!(buffer.data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_gpuerror_display() {
        let error = GpuError::BackendNotAvailable("CUDA".to_string());
        assert_eq!(error.to_string(), "GPU backend CUDA is not available");

        let error = GpuError::OutOfMemory("allocation failed".to_string());
        assert_eq!(error.to_string(), "GPU out of memory: allocation failed");

        let error = GpuError::KernelCompilationError("syntax error".to_string());
        assert_eq!(error.to_string(), "Kernel compilation error: syntax error");

        let error = GpuError::KernelNotFound("gemm".to_string());
        assert_eq!(error.to_string(), "Kernel not found: gemm");
    }

    #[test]
    fn test_backend_equality() {
        assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
        assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);

        // `GpuBackend` is `Copy`, so both bindings leave `backend` usable.
        let backend = GpuBackend::Metal;
        let cloned = backend;
        let copied = backend;
        assert_eq!(backend, cloned);
        assert_eq!(backend, copied);
    }

    #[test]
    fn test_backend_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(GpuBackend::Cuda);
        set.insert(GpuBackend::Rocm);
        // Inserting a duplicate must not grow the set.
        set.insert(GpuBackend::Cuda);

        assert_eq!(set.len(), 2);
        assert!(set.contains(&GpuBackend::Cuda));
        assert!(set.contains(&GpuBackend::Rocm));
    }
}