use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;

pub mod async_execution;
pub mod auto_tuning;
pub mod backends;
pub mod benchmarks;
pub mod heterogeneous;
pub mod kernels;
pub mod tensor_cores;

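/// Identifies a GPU compute backend (or the CPU fallback).
///
/// A minimal usage sketch, assuming this module is reachable as `gpu` from
/// the crate root (the exact path depends on how the crate re-exports it):
///
/// ```ignore
/// use crate::gpu::GpuBackend;
///
/// // Pick a backend conservatively, then check what is actually usable.
/// let backend = GpuBackend::preferred();
/// assert!(GpuBackend::Cpu.is_available()); // the CPU fallback always works
/// println!("selected backend: {backend}");
/// ```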
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    Cuda,
    Rocm,
    Wgpu,
    Metal,
    OpenCL,
    Cpu,
}

impl Default for GpuBackend {
    fn default() -> Self {
        Self::preferred()
    }
}

impl GpuBackend {
    /// Returns the preferred backend for the current build.
    ///
    /// Even when a GPU backend is detected, non-test builds conservatively
    /// fall back to `Cpu`; test builds return the detected backend so the
    /// GPU code paths can be exercised.
    pub fn preferred() -> Self {
        match backends::initialize_optimal_backend() {
            Ok(backend) => {
                if backend != GpuBackend::Cpu {
                    #[cfg(not(test))]
                    {
                        return GpuBackend::Cpu;
                    }
                    #[cfg(test)]
                    {
                        return backend;
                    }
                }
                backend
            }
            Err(_) => GpuBackend::Cpu,
        }
    }

    /// Returns `true` if this backend can be used in the current build
    /// and on the current platform.
    pub fn is_available(&self) -> bool {
        match self {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    CudaContext::is_available()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    false
                }
            }
            GpuBackend::Rocm => cfg!(feature = "rocm"),
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    WebGPUContext::is_available()
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    false
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    true
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    false
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    OpenCLContext::is_available()
                }
                #[cfg(not(feature = "opencl"))]
                {
                    false
                }
            }
            GpuBackend::Cpu => true,
        }
    }
}

impl fmt::Display for GpuBackend {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GpuBackend::Cuda => write!(f, "CUDA"),
            GpuBackend::Rocm => write!(f, "ROCm"),
            GpuBackend::Wgpu => write!(f, "WebGPU"),
            GpuBackend::Metal => write!(f, "Metal"),
            GpuBackend::OpenCL => write!(f, "OpenCL"),
            GpuBackend::Cpu => write!(f, "CPU"),
        }
    }
}

use crate::error::{CoreError, ErrorContext, ErrorLocation};

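/// Errors produced by the GPU abstraction layer.
///
/// `GpuError` converts into [`CoreError`], so GPU calls compose with `?` in
/// functions returning `Result<_, CoreError>`. A sketch (the function name is
/// illustrative):
///
/// ```ignore
/// fn make_context() -> Result<GpuContext, CoreError> {
///     // `?` uses the `From<GpuError> for CoreError` impl below.
///     Ok(GpuContext::new(GpuBackend::preferred())?)
/// }
/// ```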
#[derive(Debug, thiserror::Error)]
pub enum GpuError {
    #[error("GPU backend {0} is not available")]
    BackendNotAvailable(String),

    #[error("GPU backend {0} is not supported")]
    UnsupportedBackend(GpuBackend),

    #[error("GPU backend {0:?} is not supported for this kernel")]
    BackendNotSupported(GpuBackend),

    #[error("GPU backend {0} is not implemented yet")]
    BackendNotImplemented(GpuBackend),

    #[error("GPU out of memory: {0}")]
    OutOfMemory(String),

    #[error("Kernel compilation error: {0}")]
    KernelCompilationError(String),

    #[error("Kernel execution error: {0}")]
    KernelExecutionError(String),

    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    #[error("Kernel not found: {0}")]
    KernelNotFound(String),

    #[error("Kernel specialization not supported")]
    SpecializationNotSupported,

    #[error("Unsupported data type: {0:?}")]
    UnsupportedDataType(kernels::DataType),

    #[error("{0}")]
    Other(String),
}

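/// A handle to a single device under a specific backend.
///
/// A sketch of the intended flow (note that `compile_kernel` is currently a
/// stub and does not validate the source):
///
/// ```ignore
/// let device = GpuDevice::new(GpuBackend::Cuda, 0);
/// let kernel = device.compile_kernel("__global__ void k() {}", "k")?;
/// assert_eq!(kernel.entry_point(), "k");
/// ```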
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuDevice {
    backend: GpuBackend,
    device_id: usize,
}

impl GpuDevice {
    pub fn new(backend: GpuBackend, device_id: usize) -> Self {
        Self { backend, device_id }
    }

    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    pub fn device_id(&self) -> usize {
        self.device_id
    }

    /// Compiles a kernel for this device.
    ///
    /// This is currently a stub: the source is ignored and a handle carrying
    /// only the backend and entry point is returned.
    pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
        Ok(GpuKernel {
            backend: self.backend,
            entry_point: entrypoint.to_string(),
        })
    }
}

pub struct GpuKernel {
    backend: GpuBackend,
    entry_point: String,
}

impl GpuKernel {
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }
}

impl From<GpuError> for CoreError {
    fn from(err: GpuError) -> Self {
        match err {
            GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
                ErrorContext::new(format!("GPU backend {backend} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
                ErrorContext::new(format!("GPU backend {backend} is not supported"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
                ErrorContext::new(format!(
                    "GPU backend {backend:?} is not supported for this kernel"
                ))
                .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
                ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::OutOfMemory(details) => CoreError::MemoryError(
                ErrorContext::new(details).with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::KernelNotFound(name) => CoreError::ComputationError(
                ErrorContext::new(name).with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
                ErrorContext::new("Kernel specialization not supported".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
                ErrorContext::new(format!("{dtype:?}"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::Other(msg) => CoreError::ComputationError(
                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
            ),
        }
    }
}

/// Marker trait for element types that can live in GPU buffers.
pub trait GpuDataType: Copy + Send + Sync + 'static {}

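/// A typed raw device pointer paired with its element count.
///
/// A sketch pairing it with the transfer stubs on [`GpuContext`] (both
/// `allocate` and the transfers are placeholders today, so no real device
/// memory is touched; `context` and `host_data` are illustrative):
///
/// ```ignore
/// let ptr = GpuPtr::<f32>::allocate(1024)?;
/// assert_eq!(ptr.len(), 1024);
/// context.transfer_host_to_device(&ptr, &host_data)?;
/// ```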
#[derive(Debug)]
pub struct GpuPtr<T: GpuDataType> {
    ptr: u64,
    size: usize,
    phantom: PhantomData<T>,
}

impl<T: GpuDataType> GpuPtr<T> {
    /// Allocates `size` elements of device memory.
    ///
    /// This is currently a stub that returns a fixed placeholder address
    /// instead of performing a real device allocation.
    pub fn allocate(size: usize) -> Result<Self, GpuError> {
        Ok(GpuPtr {
            ptr: 0x1000_0000,
            size,
            phantom: PhantomData,
        })
    }

    pub fn as_ptr(&self) -> u64 {
        self.ptr
    }

    pub fn len(&self) -> usize {
        self.size
    }

    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}

#[derive(Debug, Clone)]
pub enum KernelArg<'a, T: GpuDataType> {
    Buffer(&'a GpuPtr<T>),
    Scalar(T),
}

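/// A type-erased kernel argument for argument lists built at runtime, e.g.
/// with [`GpuContext::launch_kernel`]. A sketch (the `"axpy"` kernel and the
/// `ptr`, `n`, and `alpha` bindings are illustrative):
///
/// ```ignore
/// let args = [
///     DynamicKernelArg::Buffer(ptr.as_ptr()),
///     DynamicKernelArg::U32(n as u32),
///     DynamicKernelArg::F32(alpha),
/// ];
/// context.launch_kernel("axpy", (64, 1, 1), (256, 1, 1), &args)?;
/// ```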
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    Buffer(u64),
    F32(f32),
    F64(f64),
    I32(i32),
    U32(u32),
    Usize(usize),
}

/// A peer-to-peer transfer channel between two devices (placeholder).
pub struct GpuChannel {
    #[allow(dead_code)]
    source_device: usize,
    #[allow(dead_code)]
    target_device: usize,
    #[allow(dead_code)]
    bandwidth: f64,
}

impl GpuDataType for f32 {}
impl GpuDataType for f64 {}
impl GpuDataType for i32 {}
impl GpuDataType for u32 {}
impl GpuDataType for u8 {}
impl GpuDataType for i8 {}
impl GpuDataType for u16 {}
impl GpuDataType for i16 {}
impl GpuDataType for u64 {}
impl GpuDataType for i64 {}
impl GpuDataType for usize {}
impl GpuDataType for isize {}

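/// A typed buffer of device memory, reference-counted and cheap to clone
/// (clones share the same underlying allocation).
///
/// Host round-trip sketch:
///
/// ```ignore
/// let ctx = GpuContext::new(GpuBackend::Cpu)?;
/// let buf = ctx.create_buffer_from_slice(&[1.0f32, 2.0, 3.0]);
/// let mut out = vec![0.0f32; 3];
/// buf.copy_to_host(&mut out)?;
/// assert_eq!(out, [1.0, 2.0, 3.0]);
/// ```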
pub struct GpuBuffer<T: GpuDataType> {
    inner: Arc<dyn GpuBufferImpl>,
    size: usize,
    phantom: PhantomData<T>,
}

impl<T: GpuDataType> GpuBuffer<T> {
    pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
        Self {
            inner,
            size,
            phantom: PhantomData,
        }
    }

    /// Returns the number of elements in the buffer.
    pub fn len(&self) -> usize {
        self.size
    }

    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Copies `data` from host memory into the buffer.
    ///
    /// Fails if `data` holds more elements than the buffer.
    pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
        if data.len() > self.size {
            return Err(GpuError::InvalidParameter(
                "Data size exceeds buffer size".to_string(),
            ));
        }
        unsafe {
            self.inner
                .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
        }
        Ok(())
    }

    /// Copies the buffer's contents into `data` on the host.
    pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
        if data.len() > self.size {
            return Err(GpuError::InvalidParameter(
                "Data size exceeds buffer size".to_string(),
            ));
        }
        unsafe {
            self.inner
                .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
        }
        Ok(())
    }

    /// Reads the buffer back into a freshly allocated `Vec`.
    ///
    /// Zero-initialization via `mem::zeroed` is sound here because every
    /// `GpuDataType` implementor is a plain integer or float type.
    pub fn to_vec(&self) -> Vec<T> {
        let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
        let _ = self.copy_to_host(&mut result);
        result
    }
}

impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GpuBuffer")
            .field("size", &self.size)
            .finish()
    }
}

// Cloning shares the underlying device allocation rather than copying it.
impl<T: GpuDataType> Clone for GpuBuffer<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            size: self.size,
            phantom: PhantomData,
        }
    }
}

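/// A compiled kernel with named, settable parameters.
///
/// A dispatch sketch (parameter names like `"input"` and `"size"` are
/// illustrative; real names come from the kernel source):
///
/// ```ignore
/// let kernel = ctx.get_kernel("vector_add")?;
/// kernel.set_buffer("input", &buf);
/// kernel.set_u32("size", buf.len() as u32);
/// kernel.dispatch([(buf.len() as u32 + 255) / 256, 1, 1]);
/// ```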
#[derive(Clone)]
pub struct GpuKernelHandle {
    inner: Arc<dyn GpuKernelImpl>,
}

impl GpuKernelHandle {
    pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
        Self { inner }
    }

    /// Binds a buffer to the named kernel parameter.
    pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
        self.inner.set_buffer(name, &buffer.inner);
    }

    pub fn set_u32(&self, name: &str, value: u32) {
        self.inner.set_u32(name, value);
    }

    pub fn set_i32(&self, name: &str, value: i32) {
        self.inner.set_i32(name, value);
    }

    pub fn set_f32(&self, name: &str, value: f32) {
        self.inner.set_f32(name, value);
    }

    pub fn set_f64(&self, name: &str, value: f64) {
        self.inner.set_f64(name, value);
    }

    /// Launches the kernel with the given `[x, y, z]` workgroup counts.
    pub fn dispatch(&self, workgroups: [u32; 3]) {
        self.inner.dispatch(workgroups);
    }
}

pub struct GpuCompiler {
    inner: Arc<dyn GpuCompilerImpl>,
}

impl GpuCompiler {
    pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
        Self { inner }
    }

    /// Compiles kernel source for the current backend.
    pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
        let kernel = self.inner.compile(source)?;
        Ok(GpuKernelHandle::new(kernel))
    }

    /// Compiles a named kernel specialized for input type `I` and output type `O`.
    pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
        let kernel = self.inner.compile_typed(
            name,
            std::any::TypeId::of::<I>(),
            std::any::TypeId::of::<O>(),
        );
        GpuKernelHandle::new(kernel)
    }
}

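/// The main entry point: owns a backend context and a kernel registry, and
/// hands out buffers and compilers.
///
/// Creation sketch with a CPU fallback (a common pattern, since GPU backends
/// may be compiled out or missing at runtime):
///
/// ```ignore
/// let ctx = GpuContext::new(GpuBackend::preferred())
///     .or_else(|_| GpuContext::new(GpuBackend::Cpu))?;
/// println!("running on {}", ctx.backend_name());
/// ```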
pub struct GpuContext {
    inner: Arc<dyn GpuContextImpl>,
    backend: GpuBackend,
    kernel_registry: kernels::KernelRegistry,
}

impl GpuContext {
    /// Creates a context for the requested backend.
    ///
    /// Fails with `BackendNotAvailable` if the backend is compiled out or no
    /// matching device is detected at runtime.
    pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
        if !backend.is_available() {
            return Err(GpuError::BackendNotAvailable(backend.to_string()));
        }

        // Feature flags only prove the backend was compiled in; also verify
        // that a matching device actually exists at runtime.
        if backend != GpuBackend::Cpu {
            let detection_result = backends::detect_gpu_backends();
            let backend_available = detection_result
                .devices
                .iter()
                .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);

            if !backend_available {
                return Err(GpuError::BackendNotAvailable(format!(
                    "{backend} (no devices detected at runtime)"
                )));
            }
        }

        let inner = match backend {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    match CudaContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    // ROCm has no real context yet; tests fall back to the
                    // CPU implementation so the rest of the API is exercised.
                    #[cfg(test)]
                    {
                        Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
                    }
                    #[cfg(not(test))]
                    {
                        return Err(GpuError::BackendNotImplemented(backend));
                    }
                }
                #[cfg(not(feature = "rocm"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    match WebGPUContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    use crate::gpu::backends::metal::MetalContext;
                    match MetalContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    match OpenCLContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "opencl"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
        };

        Ok(Self {
            inner,
            backend,
            kernel_registry: kernels::KernelRegistry::with_default_kernels(),
        })
    }

    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    pub fn backend_name(&self) -> &str {
        match self.backend {
            GpuBackend::Cuda => "CUDA",
            GpuBackend::Rocm => "ROCm",
            GpuBackend::Wgpu => "WebGPU",
            GpuBackend::Metal => "Metal",
            GpuBackend::OpenCL => "OpenCL",
            GpuBackend::Cpu => "CPU",
        }
    }

    /// Allocates an uninitialized device buffer holding `size` elements of `T`.
    pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
        let byte_size = size.saturating_mul(std::mem::size_of::<T>());
        let inner = self.inner.create_buffer(byte_size);
        GpuBuffer::new(inner, size)
    }

    /// Allocates a device buffer and initializes it from a host slice.
    pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
        let buffer = self.create_buffer::<T>(data.len());
        let _ = buffer.copy_from_host(data);
        buffer
    }

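    /// Runs `f` with a freshly created compiler for this context's backend.
    ///
    /// A compile-and-dispatch sketch using this closure API (the
    /// WGSL-flavored source is illustrative; the CPU fallback accepts any
    /// source and its dispatch is a no-op):
    ///
    /// ```ignore
    /// let handle = ctx.execute(|compiler| compiler.compile("@compute fn main() {}"))?;
    /// handle.dispatch([1, 1, 1]);
    /// ```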
    pub fn execute<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&GpuCompiler) -> R,
    {
        let compiler = GpuCompiler::new(self.inner.create_compiler());
        f(&compiler)
    }

    /// Looks up a registered kernel by name and compiles it for this backend.
    pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
        let kernel = self
            .kernel_registry
            .get(name)
            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;

        let kernel_source = kernel.source_for_backend(self.backend)?;
        let metadata = kernel.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Like `get_kernel`, but specializes the kernel with the given
    /// parameters first.
    pub fn get_specialized_kernel(
        &self,
        name: &str,
        params: &kernels::KernelParams,
    ) -> Result<GpuKernelHandle, GpuError> {
        let specialized = self.kernel_registry.get_specialized(name, params)?;
        let kernel_source = specialized.source_for_backend(self.backend)?;
        let metadata = specialized.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    // Metadata is accepted for future use but not consulted yet.
    fn compile_kernel_with_metadata(
        &self,
        source: &str,
        _metadata: &kernels::KernelMetadata,
    ) -> Result<GpuKernelHandle, GpuError> {
        self.execute(|compiler| compiler.compile(source))
    }

    /// Returns the available device memory in bytes.
    ///
    /// Placeholder: always reports 1 GiB until backends expose real queries.
    pub fn get_available_memory(&self) -> Option<usize> {
        Some(1024 * 1024 * 1024)
    }

    /// Returns the total device memory in bytes (placeholder values).
    pub fn get_total_memory(&self) -> Option<usize> {
        #[cfg(target_arch = "wasm32")]
        return Some(512 * 1024 * 1024);
        #[cfg(not(target_arch = "wasm32"))]
        Some((4u64 * 1024 * 1024 * 1024) as usize)
    }

    /// Launches a named kernel (no-op stub).
    pub fn launch_kernel(
        &self,
        kernel_name: &str,
        grid_size: (usize, usize, usize),
        block_size: (usize, usize, usize),
        args: &[DynamicKernelArg],
    ) -> Result<(), GpuError> {
        let _ = (kernel_name, grid_size, block_size, args);
        Ok(())
    }

    /// Asynchronously copies host data to device memory (no-op stub).
    pub fn transfer_async_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Copies host data to device memory (no-op stub).
    pub fn transfer_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Asynchronously copies device memory back to the host (no-op stub).
    pub fn transfer_async_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Copies device memory back to the host (no-op stub).
    pub fn transfer_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        let _ = (ptr, data);
        Ok(())
    }

    /// Executes raw kernel source (stub: logs the request and returns `Ok`).
    pub fn execute_kernel(
        &self,
        source: &str,
        buffers: &[GpuBuffer<f32>],
        work_groups: (u32, u32, u32),
        int_params: &[u32],
        float_params: &[f32],
    ) -> Result<(), GpuError> {
        eprintln!(
            "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
            source.len(),
            buffers.len(),
            work_groups
        );
        eprintln!("Int params: {int_params:?}");
        eprintln!("Float params: {float_params:?}");
        Ok(())
    }

    /// Reads a buffer back to the host.
    pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
        Ok(buffer.to_vec())
    }

    /// Reduces the buffer to its sum (stub: returns a zeroed one-element buffer).
    pub fn sum_all<T: GpuDataType>(&self, _buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        // `zeroed` is sound for the plain numeric types that implement GpuDataType.
        let sum: T = unsafe { std::mem::zeroed() };
        let result = self.create_buffer::<T>(1);
        let _ = result.copy_from_host(&[sum]);
        Ok(result)
    }

    /// Reduces the buffer to its mean (stub: currently delegates to `sum_all`).
    pub fn mean_all<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_all(buffer)
    }

    /// Reduces the buffer to its maximum (stub: returns an uninitialized scalar).
    pub fn max_all<T: GpuDataType>(&self, _buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(1);
        Ok(result)
    }

    /// Reduces the buffer to its minimum (stub: returns an uninitialized scalar).
    pub fn min_all<T: GpuDataType>(&self, _buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(1);
        Ok(result)
    }

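    /// Sums along one axis of a row-major tensor (stub: shape checking and
    /// output allocation are real, but the returned buffer is uninitialized).
    ///
    /// The reduced axis is kept with size 1. A shape sketch:
    ///
    /// ```ignore
    /// // A 2x3 matrix reduced over axis 1 yields shape [2, 1] (2 elements).
    /// let buf = ctx.create_buffer_from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let reduced = ctx.sum_axis(&buf, &[2, 3], 1)?;
    /// assert_eq!(reduced.len(), 2);
    /// ```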
    pub fn sum_axis<T: GpuDataType>(
        &self,
        _buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let mut output_shape = shape.to_vec();
        if axis >= output_shape.len() {
            return Err(GpuError::InvalidParameter(format!(
                "Axis {axis} out of bounds for shape {shape:?}"
            )));
        }
        output_shape[axis] = 1;
        let output_size: usize = output_shape.iter().product();
        let result = self.create_buffer::<T>(output_size);
        Ok(result)
    }

    /// Mean along `axis` (stub: currently delegates to `sum_axis`).
    pub fn mean_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_axis(buffer, shape, axis)
    }

    /// Maximum along `axis` (stub: returns an uninitialized reduced buffer).
    pub fn max_axis<T: GpuDataType>(
        &self,
        _buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let mut output_shape = shape.to_vec();
        if axis >= output_shape.len() {
            return Err(GpuError::InvalidParameter(format!(
                "Axis {axis} out of bounds for shape {shape:?}"
            )));
        }
        output_shape[axis] = 1;
        let output_size: usize = output_shape.iter().product();
        let result = self.create_buffer::<T>(output_size);
        Ok(result)
    }

    /// Minimum along `axis` (stub: returns an uninitialized reduced buffer).
    pub fn min_axis<T: GpuDataType>(
        &self,
        _buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let mut output_shape = shape.to_vec();
        if axis >= output_shape.len() {
            return Err(GpuError::InvalidParameter(format!(
                "Axis {axis} out of bounds for shape {shape:?}"
            )));
        }
        output_shape[axis] = 1;
        let output_size: usize = output_shape.iter().product();
        let result = self.create_buffer::<T>(output_size);
        Ok(result)
    }

    /// Broadcasts a buffer to a larger shape (stub: only the trivial
    /// same-size case copies data; other cases return an uninitialized buffer).
    pub fn broadcast<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        _from_shape: &[usize],
        to_shape: &[usize],
    ) -> Result<GpuBuffer<T>, GpuError> {
        let output_size: usize = to_shape.iter().product();
        let result = self.create_buffer::<T>(output_size);

        if buffer.len() == output_size {
            let data = buffer.to_vec();
            let _ = result.copy_from_host(&data);
        }

        Ok(result)
    }

    /// Scales a buffer by `scalar` (stub: copies the input unchanged).
    pub fn scale<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        _scalar: T,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(buffer.len());
        let data = buffer.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// C = A (m x k) * B (k x n) (stub: returns an uninitialized m x n buffer).
    pub fn gemm<T: GpuDataType>(
        &self,
        _a: &GpuBuffer<T>,
        _b: &GpuBuffer<T>,
        m: usize,
        _k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(m * n);
        Ok(result)
    }

    /// C = A (m x k) * B^T (n x k) (stub: returns an uninitialized m x n buffer).
    pub fn gemm_transpose_b<T: GpuDataType>(
        &self,
        _a: &GpuBuffer<T>,
        _b: &GpuBuffer<T>,
        m: usize,
        _k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(m * n);
        Ok(result)
    }

    /// C = A^T (k x m) * B (k x n) (stub: returns an uninitialized m x n buffer).
    pub fn gemm_transpose_a<T: GpuDataType>(
        &self,
        _a: &GpuBuffer<T>,
        _b: &GpuBuffer<T>,
        m: usize,
        _k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(m * n);
        Ok(result)
    }

    /// ReLU activation (stub: copies the input unchanged).
    pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(input.len());
        let data = input.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// ReLU backward pass (stub: copies `grad_output` unchanged).
    pub fn relu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        _input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(grad_output.len());
        let data = grad_output.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// Sigmoid activation (stub: copies the input unchanged).
    pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(input.len());
        let data = input.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// Sigmoid backward pass (stub: copies `grad_output` unchanged).
    pub fn sigmoid_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        _input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(grad_output.len());
        let data = grad_output.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// Tanh activation (stub: copies the input unchanged).
    pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(input.len());
        let data = input.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// Tanh backward pass (stub: copies `grad_output` unchanged).
    pub fn tanh_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        _input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(grad_output.len());
        let data = grad_output.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// GELU activation (stub: copies the input unchanged).
    pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(input.len());
        let data = input.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// GELU backward pass (stub: copies `grad_output` unchanged).
    pub fn gelu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        _input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        let result = self.create_buffer::<T>(grad_output.len());
        let data = grad_output.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }
}

impl fmt::Debug for GpuContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GpuContext")
            .field("backend", &self.backend)
            .finish()
    }
}

/// Backend-specific buffer storage.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copies `size` bytes from host memory into the buffer.
    ///
    /// # Safety
    /// `data` must be valid for reads of `size` bytes, and `size` must not
    /// exceed the buffer's capacity.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copies `size` bytes from the buffer into host memory.
    ///
    /// # Safety
    /// `data` must be valid for writes of `size` bytes, and `size` must not
    /// exceed the buffer's capacity.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Buffer size in bytes (backends may override the default of 0).
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0
    }

    /// Raw device pointer (backends may override the default of 0).
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0
    }
}

/// Backend-specific kernel with named parameters.
pub(crate) trait GpuKernelImpl: Send + Sync {
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);

    fn set_u32(&self, name: &str, value: u32);

    fn set_i32(&self, name: &str, value: i32);

    fn set_f32(&self, name: &str, value: f32);

    fn set_f64(&self, name: &str, value: f64);

    fn dispatch(&self, workgroups: [u32; 3]);
}

/// Backend-specific kernel compiler.
pub(crate) trait GpuCompilerImpl: Send + Sync {
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;

    fn compile_typed(
        &self,
        name: &str,
        input_type: std::any::TypeId,
        output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl>;
}

/// Backend-specific context: the root factory for buffers and compilers.
pub(crate) trait GpuContextImpl: Send + Sync {
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;

    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;

    fn as_any(&self) -> &dyn std::any::Any
    where
        Self: 'static + Sized,
    {
        self
    }
}

/// CPU fallback context used when no GPU backend is available.
struct CpuContext;

impl CpuContext {
    fn new() -> Self {
        Self
    }
}

impl GpuContextImpl for CpuContext {
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
        Arc::new(CpuBuffer::new(size))
    }

    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
        Arc::new(CpuCompiler)
    }
}

/// Host-memory buffer backing the CPU fallback.
///
/// The bytes live behind a `Mutex` so `copy_from_host` can mutate them
/// through `&self`; the previous cast from `&self` to `*mut Self` was
/// undefined behavior.
struct CpuBuffer {
    data: std::sync::Mutex<Vec<u8>>,
}

impl CpuBuffer {
    fn new(size: usize) -> Self {
        Self {
            data: std::sync::Mutex::new(vec![0; size]),
        }
    }
}

impl GpuBufferImpl for CpuBuffer {
    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        let mut guard = self.data.lock().expect("CpuBuffer mutex poisoned");
        std::ptr::copy_nonoverlapping(data, guard.as_mut_ptr(), size);
    }

    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        let guard = self.data.lock().expect("CpuBuffer mutex poisoned");
        std::ptr::copy_nonoverlapping(guard.as_ptr(), data, size);
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn size(&self) -> usize {
        self.data.lock().expect("CpuBuffer mutex poisoned").len()
    }

    fn device_ptr(&self) -> u64 {
        self.data.lock().expect("CpuBuffer mutex poisoned").as_ptr() as u64
    }
}

/// Compiler for the CPU fallback; produces no-op kernels.
struct CpuCompiler;

impl GpuCompilerImpl for CpuCompiler {
    fn compile(&self, _source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
        Ok(Arc::new(CpuKernel))
    }

    fn compile_typed(
        &self,
        _name: &str,
        _input_type: std::any::TypeId,
        _output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl> {
        Arc::new(CpuKernel)
    }
}

/// No-op kernel used by the CPU fallback: parameters are accepted and
/// ignored, and `dispatch` does nothing.
struct CpuKernel;

impl GpuKernelImpl for CpuKernel {
    fn set_buffer(&self, _name: &str, _buffer: &Arc<dyn GpuBufferImpl>) {}

    fn set_u32(&self, _name: &str, _value: u32) {}

    fn set_i32(&self, _name: &str, _value: i32) {}

    fn set_f32(&self, _name: &str, _value: f32) {}

    fn set_f64(&self, _name: &str, _value: f64) {}

    fn dispatch(&self, _workgroups: [u32; 3]) {}
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_backend_preferred() {
        // `preferred` must return one of the known variants; the exhaustive
        // match doubles as a compile-time check when variants are added.
        let backend = GpuBackend::preferred();
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_gpu_backend_default() {
        let backend = GpuBackend::default();
        assert_eq!(backend, GpuBackend::preferred());
    }

    #[test]
    fn test_gpu_backend_is_available() {
        let backend = GpuBackend::Cpu;
        assert!(backend.is_available());

        #[cfg(feature = "cuda")]
        {
            // Availability depends on the host; just exercise the call.
            let _ = GpuBackend::Cuda.is_available();
        }
        #[cfg(not(feature = "cuda"))]
        assert!(!GpuBackend::Cuda.is_available());

        #[cfg(feature = "rocm")]
        {
            let _ = GpuBackend::Rocm.is_available();
        }
        #[cfg(not(feature = "rocm"))]
        assert!(!GpuBackend::Rocm.is_available());

        #[cfg(all(feature = "metal", target_os = "macos"))]
        assert!(GpuBackend::Metal.is_available());
        #[cfg(not(all(feature = "metal", target_os = "macos")))]
        assert!(!GpuBackend::Metal.is_available());
    }

    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
        assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
        assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
        assert_eq!(GpuBackend::Metal.to_string(), "Metal");
        assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
        assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
    }

    #[test]
    fn test_gpuerror_from_conversion() {
        let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::ComputationError(_) => {}
            _ => panic!("Expected ComputationError"),
        }

        let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::MemoryError(_) => {}
            _ => panic!("Expected MemoryError"),
        }

        let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::InvalidArgument(_) => {}
            _ => panic!("Expected InvalidArgument"),
        }

        let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::TypeError(_) => {}
            _ => panic!("Expected TypeError"),
        }
    }

    #[test]
    fn test_gpu_datatype_trait() {
        fn assert_gpu_datatype<T: GpuDataType>() {}

        assert_gpu_datatype::<f32>();
        assert_gpu_datatype::<f64>();
        assert_gpu_datatype::<i32>();
        assert_gpu_datatype::<u32>();
        assert_gpu_datatype::<u8>();
        assert_gpu_datatype::<i8>();
        assert_gpu_datatype::<u16>();
        assert_gpu_datatype::<i16>();
        assert_gpu_datatype::<u64>();
        assert_gpu_datatype::<i64>();
    }

    #[test]
    fn test_gpu_buffer_creation() {
        let inner = Arc::new(CpuBuffer::new(100));
        let buffer = GpuBuffer::<f32>::new(inner, 25);

        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_empty() {
        let inner = Arc::new(CpuBuffer::new(0));
        let buffer = GpuBuffer::<f32>::new(inner, 0);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_copy_operations() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = buffer.copy_to_host(&mut result);

        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_buffer_to_vec() {
        let inner = Arc::new(CpuBuffer::new(12));
        let buffer = GpuBuffer::<f32>::new(inner, 3);

        let data = vec![5.0f32, 6.0, 7.0];
        let _ = buffer.copy_from_host(&data);

        let result = buffer.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_from_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        // Three elements into a two-element buffer must fail.
        let data = vec![1.0f32, 2.0, 3.0];
        buffer.copy_from_host(&data).expect("Operation failed");
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_to_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        // Reading three elements out of a two-element buffer must fail.
        let mut data = vec![0.0f32; 3];
        buffer.copy_to_host(&mut data).expect("Operation failed");
    }

    #[test]
    fn test_gpu_kernel_handle() {
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 100);
        handle.set_i32("offset", -5);
        handle.set_f32("scale", 2.5);
        handle.set_f64("precision", 0.0001);

        handle.dispatch([16, 8, 1]);
    }

    #[test]
    fn test_gpu_context_cpu_backend() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
        assert_eq!(context.backend(), GpuBackend::Cpu);
        assert_eq!(context.backend_name(), "CPU");

        // Placeholder memory figures reported by the stub implementation.
        assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
        assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
    }

    #[test]
    fn test_gpu_context_buffer_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let buffer = context.create_buffer::<f32>(100);
        assert_eq!(buffer.len(), 100);

        let data = vec![1.0f32; 50];
        let buffer_from_slice = context.create_buffer_from_slice(&data);
        assert_eq!(buffer_from_slice.len(), 50);

        let result = buffer_from_slice.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_unsupported_backend() {
        #[cfg(not(feature = "cuda"))]
        {
            let result = GpuContext::new(GpuBackend::Cuda);
            assert!(result.is_err());
            match result {
                Err(GpuError::UnsupportedBackend(_)) => {}
                Err(GpuError::BackendNotAvailable(_)) => {}
                Err(e) => panic!(
                    "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
                    e
                ),
                Ok(_) => panic!("Expected error, got Ok"),
            }
        }
    }

    #[test]
    fn test_gpu_compiler() {
        let compiler_impl = Arc::new(CpuCompiler);
        let compiler = GpuCompiler::new(compiler_impl);

        let kernel = compiler
            .compile("dummy kernel source")
            .expect("Operation failed");
        kernel.dispatch([1, 1, 1]);

        let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
        typed_kernel.dispatch([32, 1, 1]);
    }

    #[test]
    fn test_gpu_context_execute() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());

        assert!(result);
    }

    #[test]
    fn test_gpu_context_kernel_registry() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let result = context.get_kernel("non_existent_kernel");
        assert!(result.is_err());
        match result {
            Err(GpuError::KernelNotFound(_)) => {}
            _ => panic!("Expected KernelNotFound error"),
        }
    }

    #[test]
    fn test_cpu_buffer_implementation() {
        let buffer = CpuBuffer::new(256);
        let data = buffer.data.lock().expect("CpuBuffer mutex poisoned");
        assert_eq!(data.len(), 256);

        // Freshly created buffers are zero-initialized.
        assert!(data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_gpuerror_display() {
        let error = GpuError::BackendNotAvailable("CUDA".to_string());
        assert_eq!(error.to_string(), "GPU backend CUDA is not available");

        let error = GpuError::OutOfMemory("allocation failed".to_string());
        assert_eq!(error.to_string(), "GPU out of memory: allocation failed");

        let error = GpuError::KernelCompilationError("syntax error".to_string());
        assert_eq!(error.to_string(), "Kernel compilation error: syntax error");

        let error = GpuError::KernelNotFound("gemm".to_string());
        assert_eq!(error.to_string(), "Kernel not found: gemm");
    }

    #[test]
    fn test_backend_equality() {
        assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
        assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);

        let backend = GpuBackend::Metal;
        let cloned = backend;
        let copied = backend;
        assert_eq!(backend, cloned);
        assert_eq!(backend, copied);
    }

    #[test]
    fn test_backend_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(GpuBackend::Cuda);
        set.insert(GpuBackend::Rocm);
        // Duplicate insert must not grow the set.
        set.insert(GpuBackend::Cuda);
        assert_eq!(set.len(), 2);
        assert!(set.contains(&GpuBackend::Cuda));
        assert!(set.contains(&GpuBackend::Rocm));
    }

    #[test]
    fn test_gpu_buffer_debug_clone() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        let debug_str = format!("{:?}", buffer);
        assert!(debug_str.contains("GpuBuffer"));
        assert!(debug_str.contains("size"));

        let cloned = buffer.clone();
        assert_eq!(cloned.len(), buffer.len());
        assert_eq!(cloned.len(), 4);

        // Clones share the same underlying storage, so data written through
        // the original is visible through the clone.
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = cloned.copy_to_host(&mut result);
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_debug() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");

        let debug_str = format!("{:?}", context);
        assert!(debug_str.contains("GpuContext"));
        assert!(debug_str.contains("backend"));
        assert!(debug_str.contains("Cpu"));
    }
}
1673}