// scirs2_core/gpu/mod.rs

1//! GPU acceleration module for scirs2-core
2//!
3//! This module provides hardware acceleration support across different GPU backends.
4
5use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod async_transfer;
11pub mod auto_tuning;
12pub mod backends;
13pub mod benchmarks;
14mod cpu_ops;
15pub mod heterogeneous;
16pub mod kernels;
17pub mod memory_management;
18pub mod stream_allocator;
19pub mod tensor_cores;
20
21pub use async_transfer::{
22    AsyncTransferError, AsyncTransferPipeline, TransferDirection, TransferHandle,
23};
24pub use memory_management::unified_memory::{SyncState, UnifiedAllocator, UnifiedBuffer};
25pub use stream_allocator::{StreamAllocator, StreamId};
26
/// GPU backend type
///
/// Identifies which compute API a context, device, or kernel targets.
/// `Cpu` is the universal software fallback and is always available.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA backend
    Cuda,
    /// AMD ROCm backend
    Rocm,
    /// WebGPU backend
    Wgpu,
    /// Apple Metal backend
    Metal,
    /// OpenCL backend
    OpenCL,
    /// CPU fallback
    Cpu,
}
43
impl Default for GpuBackend {
    /// Defaults to the best backend detected for this system; see
    /// [`GpuBackend::preferred`] for the detection/fallback policy.
    fn default() -> Self {
        Self::preferred()
    }
}
49
impl GpuBackend {
    /// Get the preferred GPU backend for the current system
    ///
    /// Runs the backend detection system, then conservatively downgrades:
    /// outside of `cfg(test)` builds any detected GPU backend is currently
    /// replaced by `Cpu`, because the GPU implementations are still stubs.
    /// In test builds the detected backend is returned as-is.
    pub fn preferred() -> Self {
        // Use the backend detection system to find the optimal backend
        // This will properly detect available GPUs and fall back to CPU if needed
        match backends::initialize_optimal_backend() {
            Ok(backend) => {
                // If we get a non-CPU backend, verify it's actually usable
                if backend != GpuBackend::Cpu {
                    // Check if we can actually create a context with this backend
                    // For now, since implementations are stubs, fall back to CPU
                    #[cfg(not(test))]
                    {
                        // In non-test environments, we don't have real GPU implementations yet
                        return GpuBackend::Cpu;
                    }
                    #[cfg(test)]
                    {
                        // In tests, we can pretend the backend works
                        return backend;
                    }
                }
                // Detection reported CPU: fall through and return it.
                backend
            }
            Err(_) => {
                // If detection fails entirely, use CPU
                GpuBackend::Cpu
            }
        }
    }

    /// Check if this backend is available on the current system
    ///
    /// Combines compile-time feature gating with a runtime probe where one
    /// exists. `Cpu` always reports `true`. NOTE(review): `Rocm` only checks
    /// the feature flag — there is no runtime probe for it yet.
    pub fn is_available(&self) -> bool {
        match self {
            // Check runtime availability for GPU backends
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    CudaContext::is_available()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    false
                }
            }
            GpuBackend::Rocm => cfg!(feature = "rocm"), // Would use ROCm runtime check
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    WebGPUContext::is_available()
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    false
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    // Metal is always available on macOS if the feature is enabled
                    true
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    false
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    OpenCLContext::is_available()
                }
                #[cfg(not(feature = "opencl"))]
                {
                    false
                }
            }
            GpuBackend::Cpu => true,
        }
    }
}
134
135impl fmt::Display for GpuBackend {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137        match self {
138            GpuBackend::Cuda => write!(f, "CUDA"),
139            GpuBackend::Rocm => write!(f, "ROCm"),
140            GpuBackend::Wgpu => write!(f, "WebGPU"),
141            GpuBackend::Metal => write!(f, "Metal"),
142            GpuBackend::OpenCL => write!(f, "OpenCL"),
143            GpuBackend::Cpu => write!(f, "CPU"),
144        }
145    }
146}
147
148use crate::error::{CoreError, ErrorContext, ErrorLocation};
149
/// Error type for GPU operations
///
/// Convertible into `CoreError` via the `From` impl below, which maps each
/// variant to a semantically similar core error category.
#[derive(Debug, thiserror::Error)]
pub enum GpuError {
    /// Backend is not available
    #[error("GPU backend {0} is not available")]
    BackendNotAvailable(String),

    /// Backend is not supported
    #[error("GPU backend {0} is not supported")]
    UnsupportedBackend(GpuBackend),

    /// Backend is not supported for a kernel
    #[error("GPU backend {0:?} is not supported for this kernel")]
    BackendNotSupported(GpuBackend),

    /// Backend is not implemented yet
    #[error("GPU backend {0} is not implemented yet")]
    BackendNotImplemented(GpuBackend),

    /// Out of memory
    #[error("GPU out of memory: {0}")]
    OutOfMemory(String),

    /// Kernel compilation error
    #[error("Kernel compilation error: {0}")]
    KernelCompilationError(String),

    /// Kernel execution error
    #[error("Kernel execution error: {0}")]
    KernelExecutionError(String),

    /// Invalid parameter
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// Kernel not found
    #[error("Kernel not found: {0}")]
    KernelNotFound(String),

    /// Specialization not supported
    #[error("Kernel specialization not supported")]
    SpecializationNotSupported,

    /// Unsupported data type
    #[error("Unsupported data type: {0:?}")]
    UnsupportedDataType(kernels::DataType),

    /// Other error
    #[error("{0}")]
    Other(String),
}
201
/// GPU device abstraction
///
/// A lightweight `Copy` handle identifying a device by backend and index;
/// it does not own any driver resources.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuDevice {
    // Which compute API this device belongs to.
    backend: GpuBackend,
    // Index of the device within that backend.
    device_id: usize,
}
208
209impl GpuDevice {
210    /// Create a new GPU device
211    pub fn new(backend: GpuBackend, device_id: usize) -> Self {
212        Self { backend, device_id }
213    }
214
215    /// Get the backend type
216    pub fn backend(&self) -> GpuBackend {
217        self.backend
218    }
219
220    /// Get the device ID
221    pub fn device_id(&self) -> usize {
222        self.device_id
223    }
224
225    /// Compile a kernel from source
226    pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
227        // Placeholder implementation
228        Ok(GpuKernel {
229            backend: self.backend,
230            entry_point: entrypoint.to_string(),
231        })
232    }
233}
234
/// GPU kernel abstraction
///
/// Produced by [`GpuDevice::compile_kernel`]; records the backend and
/// entry-point name (currently a placeholder with no compiled artifact).
pub struct GpuKernel {
    // Backend this kernel was compiled for.
    backend: GpuBackend,
    // Name of the kernel entry-point function.
    entry_point: String,
}
240
impl GpuKernel {
    /// Get the backend type
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Get the entry point name
    ///
    /// Borrowed from this kernel; clone if the caller needs ownership.
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }
}
252
253/// Convert GPU errors to core errors with semantic preservation
254impl From<GpuError> for CoreError {
255    fn from(err: GpuError) -> Self {
256        match err {
257            GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
258                ErrorContext::new(format!("GPU backend {backend} is not available"))
259                    .with_location(ErrorLocation::new(file!(), line!())),
260            ),
261            GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
262                ErrorContext::new(format!("GPU backend {backend} is not supported"))
263                    .with_location(ErrorLocation::new(file!(), line!())),
264            ),
265            GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
266                ErrorContext::new(format!(
267                    "GPU backend {backend:?} is not supported for this kernel"
268                ))
269                .with_location(ErrorLocation::new(file!(), line!())),
270            ),
271            GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
272                ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
273                    .with_location(ErrorLocation::new(file!(), line!())),
274            ),
275            GpuError::OutOfMemory(details) => CoreError::MemoryError(
276                ErrorContext::new(details.to_string())
277                    .with_location(ErrorLocation::new(file!(), line!())),
278            ),
279            GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
280                ErrorContext::new(msg.to_string())
281                    .with_location(ErrorLocation::new(file!(), line!())),
282            ),
283            GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
284                ErrorContext::new(msg.to_string())
285                    .with_location(ErrorLocation::new(file!(), line!())),
286            ),
287            GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
288                ErrorContext::new(msg.to_string())
289                    .with_location(ErrorLocation::new(file!(), line!())),
290            ),
291            GpuError::KernelNotFound(name) => CoreError::ComputationError(
292                ErrorContext::new(name.to_string())
293                    .with_location(ErrorLocation::new(file!(), line!())),
294            ),
295            GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
296                ErrorContext::new("Kernel specialization not supported".to_string())
297                    .with_location(ErrorLocation::new(file!(), line!())),
298            ),
299            GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
300                ErrorContext::new(format!("{dtype:?}"))
301                    .with_location(ErrorLocation::new(file!(), line!())),
302            ),
303            GpuError::Other(msg) => CoreError::ComputationError(
304                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
305            ),
306        }
307    }
308}
309
/// Trait for types that can be used with GPU operations
///
/// Marker trait: implementors must be `Copy` (bitwise-transferable between
/// host and device) and thread-safe. Implemented below for the primitive
/// numeric types.
pub trait GpuDataType: Copy + Send + Sync + 'static {}
312
/// GPU memory pointer abstraction
///
/// Typed handle to a device allocation: a raw address plus an element count.
/// `PhantomData<T>` records the element type without storing a `T`.
#[derive(Debug)]
pub struct GpuPtr<T: GpuDataType> {
    // Raw device address (placeholder value in the current stub).
    ptr: u64,
    // Number of elements of `T` (not bytes) — see `len()`.
    size: usize,
    phantom: PhantomData<T>,
}
320
321impl<T: GpuDataType> GpuPtr<T> {
322    /// Allocate GPU memory
323    pub fn allocate(size: usize) -> Result<Self, GpuError> {
324        Ok(GpuPtr {
325            ptr: 0x1000_0000, // Placeholder address
326            size,
327            phantom: PhantomData,
328        })
329    }
330
331    /// Get the raw pointer value
332    pub fn as_ptr(&self) -> u64 {
333        self.ptr
334    }
335
336    /// Get the size in elements
337    pub fn len(&self) -> usize {
338        self.size
339    }
340
341    /// Check if the pointer is empty (size is 0)
342    pub fn is_empty(&self) -> bool {
343        self.size == 0
344    }
345}
346
/// Kernel argument types for GPU kernel execution
///
/// Typed argument; the lifetime ties a `Buffer` argument to the `GpuPtr`
/// it borrows. For mixed-type argument lists see [`DynamicKernelArg`].
#[derive(Debug, Clone)]
pub enum KernelArg<'a, T: GpuDataType> {
    /// Buffer argument
    Buffer(&'a GpuPtr<T>),
    /// Scalar argument
    Scalar(T),
}
355
/// Non-generic kernel argument for mixed-type kernel calls
///
/// Type-erased counterpart of [`KernelArg`], used by
/// [`GpuContext::launch_kernel`] where argument types vary per slot.
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    /// Buffer argument (type-erased)
    Buffer(u64), // Raw pointer
    /// f32 scalar
    F32(f32),
    /// f64 scalar
    F64(f64),
    /// i32 scalar
    I32(i32),
    /// u32 scalar
    U32(u32),
    /// usize scalar
    Usize(usize),
}
372
/// GPU communication channel for multi-GPU operations
///
/// Describes a peer-to-peer link between two devices. All fields are
/// currently unused (`dead_code` allowed) — this is scaffolding for future
/// multi-GPU support.
pub struct GpuChannel {
    #[allow(dead_code)]
    source_device: usize,
    #[allow(dead_code)]
    target_device: usize,
    #[allow(dead_code)]
    bandwidth: f64, // GB/s
}
382
383// Implement for common types
384impl GpuDataType for f32 {}
385impl GpuDataType for f64 {}
386impl GpuDataType for i32 {}
387impl GpuDataType for u32 {}
388impl GpuDataType for u8 {}
389impl GpuDataType for i8 {}
390impl GpuDataType for u16 {}
391impl GpuDataType for i16 {}
392impl GpuDataType for u64 {}
393impl GpuDataType for i64 {}
394impl GpuDataType for usize {}
395impl GpuDataType for isize {}
396
/// GPU buffer
///
/// Typed view over a backend-specific device allocation. Clones share the
/// underlying allocation via `Arc` (see the `Clone` impl below).
pub struct GpuBuffer<T: GpuDataType> {
    // Backend-specific allocation backing this buffer.
    inner: Arc<dyn GpuBufferImpl>,
    // Capacity in elements of `T` (not bytes).
    size: usize,
    phantom: PhantomData<T>,
}
403
impl<T: GpuDataType> GpuBuffer<T> {
    /// Create a new buffer with the given size
    ///
    /// `size` is in elements of `T`; `inner` is the backend allocation that
    /// backs the buffer.
    pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
        Self {
            inner,
            size,
            phantom: PhantomData,
        }
    }

    /// Get the size of the buffer in elements
    pub fn len(&self) -> usize {
        self.size
    }

    /// Check if the buffer is empty
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Copy data from the host to the device
    ///
    /// Errors with `InvalidParameter` if `data` has more elements than the
    /// buffer. Copying fewer elements than the buffer holds is allowed and
    /// leaves the tail untouched.
    pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
        if data.len() > self.size {
            return Err(GpuError::InvalidParameter(
                "Data size exceeds buffer size".to_string(),
            ));
        }
        // SAFETY: `data` is a valid slice of `T: Copy`, so reading
        // `size_of_val(data)` bytes from its base pointer stays in bounds;
        // the length check above guarantees the byte count fits the buffer.
        unsafe {
            self.inner
                .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
        }
        Ok(())
    }

    /// Copy data from the device to the host
    ///
    /// Errors with `InvalidParameter` if `data` has more elements than the
    /// buffer; only `data.len()` elements are read back.
    pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
        if data.len() > self.size {
            return Err(GpuError::InvalidParameter(
                "Data size exceeds buffer size".to_string(),
            ));
        }
        // SAFETY: `data` is a valid, exclusively borrowed slice, so writing
        // `size_of_val(data)` bytes through its base pointer stays in
        // bounds; the length check above bounds the read on the device side.
        unsafe {
            self.inner
                .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
        }
        Ok(())
    }

    /// Convert the buffer contents to a vector
    ///
    /// NOTE(review): a failed device read is silently ignored (`let _`),
    /// in which case the affected elements remain zeroed.
    pub fn to_vec(&self) -> Vec<T> {
        // SAFETY: sound for the primitive implementors of `GpuDataType` in
        // this module, whose all-zero bit pattern is a valid value. A
        // downstream impl for a type with no valid zero representation
        // (e.g. containing NonZero*) would make this UB — TODO confirm the
        // trait is not meant to be implemented externally.
        let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
        let _ = self.copy_to_host(&mut result);
        result
    }
}
459
460impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
461    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462        f.debug_struct("GpuBuffer")
463            .field("size", &self.size)
464            .finish()
465    }
466}
467
468impl<T: GpuDataType> Clone for GpuBuffer<T> {
469    fn clone(&self) -> Self {
470        Self {
471            inner: Arc::clone(&self.inner),
472            size: self.size,
473            phantom: PhantomData,
474        }
475    }
476}
477
/// GPU kernel handle
///
/// Cloneable handle to a compiled kernel; clones share the underlying
/// backend kernel object via `Arc`.
#[derive(Clone)]
pub struct GpuKernelHandle {
    // Backend-specific compiled kernel.
    inner: Arc<dyn GpuKernelImpl>,
}
483
impl GpuKernelHandle {
    /// Create a new kernel handle
    pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
        Self { inner }
    }

    /// Set a buffer parameter
    ///
    /// Binds `buffer` to the kernel parameter named `name`; the binding is
    /// delegated to the backend implementation.
    pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
        self.inner.set_buffer(name, &buffer.inner);
    }

    /// Set a u32 parameter
    pub fn set_u32(&self, name: &str, value: u32) {
        self.inner.set_u32(name, value);
    }

    /// Set an i32 parameter
    pub fn set_i32(&self, name: &str, value: i32) {
        self.inner.set_i32(name, value);
    }

    /// Set an f32 parameter
    pub fn set_f32(&self, name: &str, value: f32) {
        self.inner.set_f32(name, value);
    }

    /// Set an f64 parameter
    pub fn set_f64(&self, name: &str, value: f64) {
        self.inner.set_f64(name, value);
    }

    /// Dispatch the kernel with the given work group counts.
    ///
    /// If batch mode is active (see [`GpuContext::begin_batch`]), the
    /// dispatch is encoded into the shared command buffer instead of
    /// submitting immediately.
    pub fn dispatch(&self, workgroups: [u32; 3]) {
        // `try_batch_dispatch` returns false when no batch is active; fall
        // back to an immediate (blocking) dispatch in that case.
        if !self.inner.try_batch_dispatch(workgroups) {
            self.inner.dispatch(workgroups);
        }
    }

    /// Dispatch the kernel without waiting for GPU completion.
    ///
    /// This submits work to the GPU command queue but returns immediately.
    /// Use [`GpuContext::gpu_sync()`] to wait for all pending dispatches.
    pub fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
        self.inner.dispatch_no_wait(workgroups);
    }
}
534
/// GPU compiler for dynamic kernels
///
/// Thin wrapper over a backend-specific compiler implementation; obtained
/// via [`GpuContext::execute`].
pub struct GpuCompiler {
    // Backend-specific compiler.
    inner: Arc<dyn GpuCompilerImpl>,
}
539
540impl GpuCompiler {
541    /// Create a new compiler
542    pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
543        Self { inner }
544    }
545
546    /// Compile a kernel from source
547    pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
548        let kernel = self.inner.compile(source)?;
549        Ok(GpuKernelHandle::new(kernel))
550    }
551
552    /// Compile a kernel for the specified input and output types
553    pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
554        let kernel = self.inner.compile_typed(
555            name,
556            std::any::TypeId::of::<I>(),
557            std::any::TypeId::of::<O>(),
558        );
559        GpuKernelHandle::new(kernel)
560    }
561}
562
/// GPU context for managing GPU resources and operations
///
/// Owns a backend implementation plus a kernel registry; create with
/// [`GpuContext::new`].
pub struct GpuContext {
    // Backend-specific context (CUDA, WebGPU, Metal, OpenCL, or CPU).
    inner: Arc<dyn GpuContextImpl>,
    // Which backend `inner` implements.
    backend: GpuBackend,
    // Registry of named kernels available through `get_kernel`.
    kernel_registry: kernels::KernelRegistry,
}
569
570impl GpuContext {
    /// Create a new GPU context with the specified backend
    ///
    /// Performs two availability checks — the compile-time/feature check via
    /// [`GpuBackend::is_available`], then a runtime device scan for non-CPU
    /// backends — before constructing the backend-specific context.
    ///
    /// # Errors
    /// Returns `BackendNotAvailable` if either check fails,
    /// `UnsupportedBackend` if the matching feature was not compiled in,
    /// `BackendNotImplemented` for ROCm outside tests, or any error from the
    /// backend's own constructor.
    pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
        // First check if the backend is available at compile time
        if !backend.is_available() {
            return Err(GpuError::BackendNotAvailable(backend.to_string()));
        }

        // For non-CPU backends, also check runtime availability
        if backend != GpuBackend::Cpu {
            let detection_result = backends::detect_gpu_backends();
            let backend_available = detection_result
                .devices
                .iter()
                .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);

            if !backend_available {
                return Err(GpuError::BackendNotAvailable(format!(
                    "{backend} (no devices detected at runtime)"
                )));
            }
        }

        // Construct the backend-specific context. Each arm compiles only
        // when its feature (and, for Metal, target OS) is enabled.
        let inner = match backend {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    match CudaContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    // This is just a stub - in a real implementation, we would use the hip-sys crate
                    // to create a ROCm context and return it
                    #[cfg(test)]
                    {
                        // For testing, we can use a mock implementation
                        Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
                    }
                    #[cfg(not(test))]
                    {
                        return Err(GpuError::BackendNotImplemented(backend));
                    }
                }
                #[cfg(not(feature = "rocm"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    match WebGPUContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    use crate::gpu::backends::metal::MetalContext;
                    match MetalContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    match OpenCLContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "opencl"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
        };

        Ok(Self {
            inner,
            backend,
            kernel_registry: kernels::KernelRegistry::with_default_kernels(),
        })
    }
679
    /// Get the backend type
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Get the backend name
    ///
    /// NOTE(review): duplicates the strings in the `fmt::Display` impl for
    /// `GpuBackend` — keep the two in sync. (This returns a `&'static str`
    /// without allocating, unlike `to_string()` via Display.)
    pub fn backend_name(&self) -> &str {
        match self.backend {
            GpuBackend::Cuda => "CUDA",
            GpuBackend::Rocm => "ROCm",
            GpuBackend::Wgpu => "WebGPU",
            GpuBackend::Metal => "Metal",
            GpuBackend::OpenCL => "OpenCL",
            GpuBackend::Cpu => "CPU",
        }
    }
696
    /// Wait for all pending GPU dispatches to complete.
    ///
    /// After calling [`GpuKernelHandle::dispatch_no_wait()`], results are
    /// not guaranteed to be available until this method returns.
    ///
    /// # Errors
    /// Propagates any error reported by the backend implementation.
    pub fn gpu_sync(&self) -> Result<(), GpuError> {
        self.inner.gpu_sync()
    }

    /// Begin batch dispatch mode.
    ///
    /// While active, kernel dispatches are encoded into a shared command
    /// buffer instead of creating individual ones.  Call [`end_batch`](Self::end_batch)
    /// to submit all encoded work and wait for completion.
    ///
    /// # Errors
    /// Propagates any error reported by the backend implementation.
    pub fn begin_batch(&self) -> Result<(), GpuError> {
        self.inner.begin_batch()
    }

    /// End batch dispatch mode.
    ///
    /// Submits the shared command buffer containing all dispatches that
    /// were encoded since [`begin_batch`](Self::begin_batch), then waits
    /// for the GPU to finish.
    ///
    /// # Errors
    /// Propagates any error reported by the backend implementation.
    pub fn end_batch(&self) -> Result<(), GpuError> {
        self.inner.end_batch()
    }
722
    /// Create a device buffer with room for `size` elements of `T`.
    ///
    /// Initial contents are backend-defined. The byte size is computed with
    /// `saturating_mul`, so an oversized request saturates instead of
    /// overflowing `usize`.
    pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
        let byte_size = size.saturating_mul(std::mem::size_of::<T>());
        let inner = self.inner.create_buffer(byte_size);
        GpuBuffer::new(inner, size)
    }

    /// Create a buffer from a slice
    ///
    /// NOTE(review): an error from `copy_from_host` is silently ignored
    /// here, so on failure the returned buffer may not contain `data` —
    /// consider propagating the error.
    pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
        let buffer = self.create_buffer::<T>(data.len());
        let _ = buffer.copy_from_host(data);
        buffer
    }
735
    /// Execute a function with a compiler
    ///
    /// Creates a fresh backend compiler, hands it to `f` by reference, and
    /// returns whatever `f` returns; the compiler is dropped afterwards.
    pub fn execute<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&GpuCompiler) -> R,
    {
        let compiler = GpuCompiler::new(self.inner.create_compiler());
        f(&compiler)
    }
744
    /// Get a kernel from the registry
    ///
    /// Looks the kernel up by name, selects the source variant for this
    /// context's backend, and compiles it into a handle.
    ///
    /// # Errors
    /// Returns `KernelNotFound` if `name` is not registered; backend
    /// selection and compilation errors are propagated.
    pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
        let kernel = self
            .kernel_registry
            .get(name)
            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;

        let kernel_source = kernel.source_for_backend(self.backend)?;
        let metadata = kernel.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Get a specialized kernel from the registry
    ///
    /// Like [`get_kernel`](Self::get_kernel), but first specializes the
    /// registered kernel with `params` before backend selection and
    /// compilation.
    pub fn get_specialized_kernel(
        &self,
        name: &str,
        params: &kernels::KernelParams,
    ) -> Result<GpuKernelHandle, GpuError> {
        let specialized = self.kernel_registry.get_specialized(name, params)?;
        let kernel_source = specialized.source_for_backend(self.backend)?;
        let metadata = specialized.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Compile a kernel with metadata
    ///
    /// The metadata is currently unused; compilation goes through a
    /// transient compiler created by [`execute`](Self::execute).
    fn compile_kernel_with_metadata(
        &self,
        source: &str,
        _metadata: &kernels::KernelMetadata,
    ) -> Result<GpuKernelHandle, GpuError> {
        self.execute(|compiler| compiler.compile(source))
    }
781
    /// Get available memory on the device
    ///
    /// Placeholder: always reports 1 GiB; does not query the device.
    pub fn get_available_memory(&self) -> Option<usize> {
        // In a real implementation, this would query the device
        // For now, return a placeholder value
        Some(1024 * 1024 * 1024) // 1GB
    }
788
789    /// Get total memory on the device
790    pub fn get_total_memory(&self) -> Option<usize> {
791        // In a real implementation, this would query the device
792        // For now, return a placeholder value
793        #[cfg(target_arch = "wasm32")]
794        return Some(512 * 1024 * 1024); // 512MB for WASM32
795
796        #[cfg(not(target_arch = "wasm32"))]
797        Some((4u64 * 1024 * 1024 * 1024) as usize) // 4GB for native
798    }
799
    /// Launch a kernel with the given parameters
    ///
    /// Placeholder: all arguments are ignored and `Ok(())` is returned
    /// without launching anything.
    pub fn launch_kernel(
        &self,
        kernel_name: &str,
        grid_size: (usize, usize, usize),
        block_size: (usize, usize, usize),
        args: &[DynamicKernelArg],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (kernel_name, grid_size, block_size, args);
        Ok(())
    }
812
    /// Transfer data from host to device asynchronously
    ///
    /// Placeholder: a no-op that always returns `Ok(())`; no data moves.
    pub fn transfer_async_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from host to device synchronously
    ///
    /// Placeholder: a no-op that always returns `Ok(())`; no data moves.
    pub fn transfer_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from device to host asynchronously
    ///
    /// Placeholder: a no-op that always returns `Ok(())`; `data` is left
    /// unmodified.
    pub fn transfer_async_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from device to host synchronously
    ///
    /// Placeholder: a no-op that always returns `Ok(())`; `data` is left
    /// unmodified.
    pub fn transfer_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }
856
    /// Execute a kernel with dynamic compilation and parameter passing
    /// This method is expected by scirs2-vision for GPU operations
    ///
    /// NOTE(review): currently only logs the call to stderr and returns
    /// `Ok(())` — it does not compile or execute anything.
    pub fn execute_kernel(
        &self,
        source: &str,
        buffers: &[GpuBuffer<f32>],
        work_groups: (u32, u32, u32),
        int_params: &[u32],
        float_params: &[f32],
    ) -> Result<(), GpuError> {
        // For now, provide a basic implementation that logs the execution
        // In a real implementation, this would compile and execute the kernel
        eprintln!(
            "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
            source.len(),
            buffers.len(),
            work_groups
        );
        eprintln!("Int params: {int_params:?}");
        eprintln!("Float params: {float_params:?}");
        Ok(())
    }
879
    /// Read data from a GPU buffer
    /// This method is expected by scirs2-vision for reading GPU results
    ///
    /// Delegates to [`GpuBuffer::to_vec`]; note that method silently
    /// ignores a failed device read (elements stay zeroed).
    pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
        Ok(buffer.to_vec())
    }
885
886    /// Global sum reduction
887    pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
888        self.sum_all_cpu_fallback(buffer)
889    }
890
891    /// Global mean reduction
892    pub fn mean_all<T: GpuDataType>(
893        &self,
894        buffer: &GpuBuffer<T>,
895    ) -> Result<GpuBuffer<T>, GpuError> {
896        self.mean_all_cpu_fallback(buffer)
897    }
898
899    /// Global max reduction
900    pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
901        self.max_all_cpu_fallback(buffer)
902    }
903
904    /// Global min reduction
905    pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
906        self.min_all_cpu_fallback(buffer)
907    }
908
909    /// Sum reduction along an axis
910    pub fn sum_axis<T: GpuDataType>(
911        &self,
912        buffer: &GpuBuffer<T>,
913        shape: &[usize],
914        axis: usize,
915    ) -> Result<GpuBuffer<T>, GpuError> {
916        self.sum_axis_cpu_fallback(buffer, shape, axis)
917    }
918
    /// Mean reduction along an axis
    ///
    /// Same parameter conventions as [`Self::sum_axis`]. Delegates to the
    /// CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn mean_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.mean_axis_cpu_fallback(buffer, shape, axis)
    }
928
    /// Max reduction along an axis
    ///
    /// Same parameter conventions as [`Self::sum_axis`]. Delegates to the
    /// CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn max_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.max_axis_cpu_fallback(buffer, shape, axis)
    }
938
    /// Min reduction along an axis
    ///
    /// Same parameter conventions as [`Self::sum_axis`]. Delegates to the
    /// CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn min_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.min_axis_cpu_fallback(buffer, shape, axis)
    }
948
    /// Broadcast a buffer to a different shape
    ///
    /// Expands `buffer` from `from_shape` to `to_shape`; the exact broadcast
    /// semantics (and compatibility checks) are defined by the CPU fallback
    /// implementation this delegates to.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn broadcast<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        from_shape: &[usize],
        to_shape: &[usize],
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.broadcast_cpu_fallback(buffer, from_shape, to_shape)
    }
958
    /// Scale a buffer by a scalar value
    ///
    /// Delegates to the CPU fallback implementation (presumably an
    /// element-wise multiply by `scalar` — confirm against the fallback).
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn scale<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        scalar: T,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.scale_cpu_fallback(buffer, scalar)
    }
967
    /// General matrix multiplication: C = A @ B
    ///
    /// `m`, `k`, `n` follow the usual GEMM convention (`a` is m×k, `b` is
    /// k×n, result is m×n — verify against the fallback implementation).
    /// Delegates to the CPU fallback.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn gemm<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_cpu_fallback(a, b, m, k, n)
    }
979
    /// GEMM with transposed B: C = A @ B^T
    ///
    /// Same dimension conventions as [`Self::gemm`], with `b` transposed.
    /// Delegates to the CPU fallback.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn gemm_transpose_b<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_transpose_b_cpu_fallback(a, b, m, k, n)
    }
991
    /// GEMM with transposed A: C = A^T @ B
    ///
    /// Same dimension conventions as [`Self::gemm`], with `a` transposed.
    /// Delegates to the CPU fallback.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn gemm_transpose_a<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_transpose_a_cpu_fallback(a, b, m, k, n)
    }
1003
    /// ReLU activation forward pass
    ///
    /// Delegates to the CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.relu_cpu_fallback(input)
    }
1008
    /// ReLU backward pass
    ///
    /// Computes the input gradient from `grad_output` and the forward-pass
    /// `input`, via the CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn relu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.relu_backward_cpu_fallback(grad_output, input)
    }
1017
    /// Sigmoid activation forward pass
    ///
    /// Delegates to the CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.sigmoid_cpu_fallback(input)
    }
1022
    /// Sigmoid backward pass
    ///
    /// Computes the input gradient from `grad_output` and the forward-pass
    /// `input`, via the CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn sigmoid_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sigmoid_backward_cpu_fallback(grad_output, input)
    }
1031
    /// Tanh activation forward pass
    ///
    /// Delegates to the CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.tanh_cpu_fallback(input)
    }
1036
    /// Tanh backward pass
    ///
    /// Computes the input gradient from `grad_output` and the forward-pass
    /// `input`, via the CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn tanh_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.tanh_backward_cpu_fallback(grad_output, input)
    }
1045
    /// GELU activation forward pass
    ///
    /// Delegates to the CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.gelu_cpu_fallback(input)
    }
1050
    /// GELU backward pass
    ///
    /// Computes the input gradient from `grad_output` and the forward-pass
    /// `input`, via the CPU fallback implementation.
    ///
    /// # Errors
    /// Propagates any error from the fallback.
    pub fn gelu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gelu_backward_cpu_fallback(grad_output, input)
    }
1059}
1060
1061impl fmt::Debug for GpuContext {
1062    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1063        f.debug_struct("GpuContext")
1064            .field("backend", &self.backend)
1065            .finish()
1066    }
1067}
1068
1069// The following trait definitions would be implemented by backend-specific
1070// code in a real implementation
1071
/// GPU buffer implementation trait
///
/// Backend-specific buffer types (CUDA, WebGPU, Metal, CPU fallback, …)
/// implement this to expose raw host<->device copies to the typed
/// `GpuBuffer<T>` wrapper.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copy data from host to device
    ///
    /// # Safety
    /// `data` must be valid for reads of `size` bytes, and `size` must not
    /// exceed the buffer's capacity.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copy data from device to host
    ///
    /// # Safety
    /// `data` must be valid for writes of `size` bytes, and `size` must not
    /// exceed the buffer's capacity.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Get a reference to self as Any for downcasting
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Get the size of the buffer in bytes
    ///
    /// Defaults to 0 so pre-existing implementations keep compiling.
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0 // Default implementation for backward compatibility
    }

    /// Get the device pointer (for backends that use device pointers)
    ///
    /// Defaults to 0 for backends without raw device addresses.
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0 // Default implementation for backward compatibility
    }
}
1096
/// GPU kernel implementation trait
///
/// Backend-specific compiled kernels implement this; parameters are bound by
/// name via the setters before `dispatch` is called.
pub(crate) trait GpuKernelImpl: Send + Sync {
    /// Set a buffer parameter
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);

    /// Set a u32 parameter
    fn set_u32(&self, name: &str, value: u32);

    /// Set an i32 parameter
    fn set_i32(&self, name: &str, value: i32);

    /// Set an f32 parameter
    fn set_f32(&self, name: &str, value: f32);

    /// Set an f64 parameter
    fn set_f64(&self, name: &str, value: f64);

    /// Dispatch the kernel
    fn dispatch(&self, workgroups: [u32; 3]);

    /// Dispatch the kernel without waiting for completion.
    /// The GPU work is submitted and will execute asynchronously.
    /// Call `gpu_sync()` on the context to wait for all pending work.
    fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
        // Default: fall back to synchronous dispatch
        self.dispatch(workgroups);
    }

    /// Try to encode a dispatch into the active batch (if any).
    ///
    /// Returns `true` if the dispatch was batched, `false` if no batch is
    /// active and the caller should use the normal `dispatch()` path.
    fn try_batch_dispatch(&self, _workgroups: [u32; 3]) -> bool {
        // Default: report "not batched" so callers use the dispatch() path.
        false
    }
}
1133
/// GPU compiler implementation trait
///
/// Turns kernel source (or a registered typed kernel) into a dispatchable
/// [`GpuKernelImpl`].
pub(crate) trait GpuCompilerImpl: Send + Sync {
    /// Compile a kernel from source
    ///
    /// # Errors
    /// Returns a `GpuError` if compilation fails (backend-specific).
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;

    /// Compile a typed kernel
    ///
    /// Infallible by signature; backends select an implementation based on
    /// the kernel `name` and the input/output `TypeId`s.
    fn compile_typed(
        &self,
        name: &str,
        input_type: std::any::TypeId,
        output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl>;
}
1147
/// GPU context implementation trait
///
/// One implementation exists per backend; the public `GpuContext` wraps a
/// trait object of this type.
pub(crate) trait GpuContextImpl: Send + Sync {
    /// Create a buffer
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;

    /// Create a compiler
    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;

    /// Wait for all pending GPU operations to complete.
    ///
    /// # Errors
    /// Backend implementations may fail; the default never does.
    fn gpu_sync(&self) -> Result<(), GpuError> {
        Ok(()) // Default: no-op (synchronous backends already block)
    }

    /// Begin batch dispatch mode.
    ///
    /// While active, kernel dispatches are encoded into a shared command
    /// buffer instead of creating individual ones.
    fn begin_batch(&self) -> Result<(), GpuError> {
        Ok(()) // Default: no-op (backends without batch support)
    }

    /// End batch dispatch mode.
    ///
    /// Submits the shared command buffer and waits for all encoded
    /// dispatches to complete on the GPU.
    fn end_batch(&self) -> Result<(), GpuError> {
        Ok(()) // Default: no-op
    }

    /// Support dynamic downcasting of concrete context implementations
    ///
    /// Requires `Sized`, so this is not callable through a
    /// `dyn GpuContextImpl` trait object.
    fn as_any(&self) -> &dyn std::any::Any
    where
        Self: 'static + Sized,
    {
        self
    }
}
1185
1186// CPU fallback implementation
1187
1188/// CPU context implementation
1189struct CpuContext;
1190
1191impl CpuContext {
1192    /// Create a new CPU context
1193    fn new() -> Self {
1194        Self
1195    }
1196}
1197
1198impl GpuContextImpl for CpuContext {
1199    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
1200        Arc::new(CpuBuffer::new(size))
1201    }
1202
1203    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
1204        Arc::new(CpuCompiler)
1205    }
1206}
1207
/// CPU buffer implementation
///
/// Backing storage is a plain byte vector on the host heap.
struct CpuBuffer {
    // Tests access this field directly, so its name must stay stable.
    data: Vec<u8>,
}

impl CpuBuffer {
    /// Create a new CPU buffer of `size` bytes, zero-initialized.
    fn new(size: usize) -> Self {
        let data = vec![0u8; size];
        CpuBuffer { data }
    }
}
1221
1222impl GpuBufferImpl for CpuBuffer {
1223    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1224        let mut_self = self as *const Self as *mut Self;
1225        let data_ptr = (*mut_self).data.as_mut_ptr();
1226        std::ptr::copy_nonoverlapping(data, data_ptr, size);
1227    }
1228
1229    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1230        let data_ptr = self.data.as_ptr();
1231        std::ptr::copy_nonoverlapping(data_ptr, data, size);
1232    }
1233
1234    fn as_any(&self) -> &dyn std::any::Any {
1235        self
1236    }
1237
1238    fn size(&self) -> usize {
1239        self.data.len()
1240    }
1241
1242    fn device_ptr(&self) -> u64 {
1243        self.data.as_ptr() as u64
1244    }
1245}
1246
1247/// CPU compiler implementation
1248struct CpuCompiler;
1249
1250impl GpuCompilerImpl for CpuCompiler {
1251    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
1252        // In a real implementation, we would parse and execute the kernel
1253        // For now, just return a dummy implementation
1254        Ok(Arc::new(CpuKernel))
1255    }
1256
1257    fn compile_typed(
1258        &self,
1259        _name: &str,
1260        _input_type: std::any::TypeId,
1261        _output_type: std::any::TypeId,
1262    ) -> Arc<dyn GpuKernelImpl> {
1263        // In a real implementation, we would select an appropriate implementation
1264        // For now, just return a dummy implementation
1265        Arc::new(CpuKernel)
1266    }
1267}
1268
1269/// CPU kernel implementation
1270struct CpuKernel;
1271
1272impl GpuKernelImpl for CpuKernel {
1273    fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
1274        // In a real implementation, we would store the buffer
1275    }
1276
1277    fn set_u32(&self, _name: &str, value: u32) {
1278        // In a real implementation, we would store the value
1279    }
1280
1281    fn set_i32(&self, _name: &str, value: i32) {
1282        // In a real implementation, we would store the value
1283    }
1284
1285    fn set_f32(&self, _name: &str, value: f32) {
1286        // In a real implementation, we would store the value
1287    }
1288
1289    fn set_f64(&self, _name: &str, value: f64) {
1290        // In a real implementation, we would store the value
1291    }
1292
1293    fn dispatch(&self, workgroups: [u32; 3]) {
1294        // In a real implementation, we would execute the kernel
1295    }
1296}
1297
1298// In a real implementation, we would have implementations for other backends
1299// such as CUDA, WebGPU, Metal, and OpenCL.
1300
/// Unit tests for the GPU abstraction layer: backend selection, buffer
/// round-trips, kernel handles, context lifecycle, and error conversions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_backend_preferred() {
        let backend = GpuBackend::preferred();
        // Should return a valid backend
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_gpu_backend_default() {
        let backend = GpuBackend::default();
        assert_eq!(backend, GpuBackend::preferred());
    }

    #[test]
    fn test_gpu_backend_is_available() {
        let backend = GpuBackend::Cpu;
        assert!(backend.is_available());

        // Test other backends - availability depends on runtime, not just feature flags
        #[cfg(feature = "cuda")]
        {
            // CUDA feature enabled doesn't guarantee runtime availability
            let _ = GpuBackend::Cuda.is_available(); // Just check without asserting
        }
        #[cfg(not(feature = "cuda"))]
        assert!(!GpuBackend::Cuda.is_available());

        #[cfg(feature = "rocm")]
        {
            // ROCm feature enabled doesn't guarantee runtime availability
            let _ = GpuBackend::Rocm.is_available(); // Just check without asserting
        }
        #[cfg(not(feature = "rocm"))]
        assert!(!GpuBackend::Rocm.is_available());

        #[cfg(all(feature = "metal", target_os = "macos"))]
        assert!(GpuBackend::Metal.is_available());
        #[cfg(not(all(feature = "metal", target_os = "macos")))]
        assert!(!GpuBackend::Metal.is_available());
    }

    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
        assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
        assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
        assert_eq!(GpuBackend::Metal.to_string(), "Metal");
        assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
        assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
    }

    // Each GpuError variant must map to the matching CoreError category.
    #[test]
    fn test_gpuerror_from_conversion() {
        let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::ComputationError(_) => {}
            _ => panic!("Expected ComputationError"),
        }

        let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::MemoryError(_) => {}
            _ => panic!("Expected MemoryError"),
        }

        let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::InvalidArgument(_) => {}
            _ => panic!("Expected InvalidArgument"),
        }

        let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::TypeError(_) => {}
            _ => panic!("Expected TypeError"),
        }
    }

    #[test]
    fn test_gpu_datatype_trait() {
        // Test that various types implement GpuDataType
        fn assert_gpu_datatype<T: GpuDataType>() {}

        assert_gpu_datatype::<f32>();
        assert_gpu_datatype::<f64>();
        assert_gpu_datatype::<i32>();
        assert_gpu_datatype::<u32>();
        assert_gpu_datatype::<u8>();
        assert_gpu_datatype::<i8>();
        assert_gpu_datatype::<u16>();
        assert_gpu_datatype::<i16>();
        assert_gpu_datatype::<u64>();
        assert_gpu_datatype::<i64>();
    }

    #[test]
    fn test_gpu_buffer_creation() {
        // 100 bytes of backing storage = 25 f32 elements.
        let inner = Arc::new(CpuBuffer::new(100));
        let buffer = GpuBuffer::<f32>::new(inner, 25);

        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_empty() {
        let inner = Arc::new(CpuBuffer::new(0));
        let buffer = GpuBuffer::<f32>::new(inner, 0);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_copy_operations() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = buffer.copy_to_host(&mut result);

        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_buffer_to_vec() {
        let inner = Arc::new(CpuBuffer::new(12));
        let buffer = GpuBuffer::<f32>::new(inner, 3);

        let data = vec![5.0f32, 6.0, 7.0];
        let _ = buffer.copy_from_host(&data);

        let result = buffer.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_from_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let data = vec![1.0f32, 2.0, 3.0]; // 3 elements > 2 buffer size
        buffer.copy_from_host(&data).expect("Operation failed");
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_to_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let mut data = vec![0.0f32; 3]; // 3 elements > 2 buffer size
        buffer.copy_to_host(&mut data).expect("Operation failed");
    }

    #[test]
    fn test_gpu_kernel_handle() {
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        // Test setting various parameter types
        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 100);
        handle.set_i32("offset", -5);
        handle.set_f32("scale", 2.5);
        handle.set_f64("precision", 0.0001);

        // Test dispatch
        handle.dispatch([16, 8, 1]);
    }

    #[test]
    fn test_gpu_context_cpu_backend() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
        assert_eq!(context.backend(), GpuBackend::Cpu);
        assert_eq!(context.backend_name(), "CPU");

        // Test memory query methods
        assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
        assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
    }

    #[test]
    fn test_gpu_context_buffer_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let buffer = context.create_buffer::<f32>(100);
        assert_eq!(buffer.len(), 100);

        let data = vec![1.0f32; 50];
        let buffer_from_slice = context.create_buffer_from_slice(&data);
        assert_eq!(buffer_from_slice.len(), 50);

        let result = buffer_from_slice.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_unsupported_backend() {
        // Test a backend that's not available
        #[cfg(not(feature = "cuda"))]
        {
            let result = GpuContext::new(GpuBackend::Cuda);
            assert!(result.is_err());
            match result {
                Err(GpuError::UnsupportedBackend(_)) => {}
                Err(GpuError::BackendNotAvailable(_)) => {} // Also accept this error
                Err(e) => panic!(
                    "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
                    e
                ),
                Ok(_) => panic!("Expected error, got Ok"),
            }
        }
    }

    #[test]
    fn test_gpu_compiler() {
        let compiler_impl = Arc::new(CpuCompiler);
        let compiler = GpuCompiler::new(compiler_impl);

        // Test compiling from source
        let kernel = compiler
            .compile("dummy kernel source")
            .expect("Operation failed");
        kernel.dispatch([1, 1, 1]);

        // Test typed compilation
        let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
        typed_kernel.dispatch([32, 1, 1]);
    }

    #[test]
    fn test_gpu_context_execute() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());

        assert!(result);
    }

    #[test]
    fn test_gpu_context_kernel_registry() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        // Test getting a non-existent kernel
        let result = context.get_kernel("non_existent_kernel");
        assert!(result.is_err());
        match result {
            Err(GpuError::KernelNotFound(_)) => {}
            _ => panic!("Expected KernelNotFound error"),
        }
    }

    #[test]
    fn test_cpu_buffer_implementation() {
        let buffer = CpuBuffer::new(256);
        assert_eq!(buffer.data.len(), 256);

        // Test that initial data is zeroed
        assert!(buffer.data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_gpuerror_display() {
        let error = GpuError::BackendNotAvailable("CUDA".to_string());
        assert_eq!(error.to_string(), "GPU backend CUDA is not available");

        let error = GpuError::OutOfMemory("allocation failed".to_string());
        assert_eq!(error.to_string(), "GPU out of memory: allocation failed");

        let error = GpuError::KernelCompilationError("syntax error".to_string());
        assert_eq!(error.to_string(), "Kernel compilation error: syntax error");

        let error = GpuError::KernelNotFound("gemm".to_string());
        assert_eq!(error.to_string(), "Kernel not found: gemm");
    }

    #[test]
    fn test_backend_equality() {
        assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
        assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);

        // Test Clone and Copy
        let backend = GpuBackend::Metal;
        let cloned = backend;
        let copied = backend;
        assert_eq!(backend, cloned);
        assert_eq!(backend, copied);
    }

    #[test]
    fn test_backend_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(GpuBackend::Cuda);
        set.insert(GpuBackend::Rocm);
        set.insert(GpuBackend::Cuda); // Duplicate

        assert_eq!(set.len(), 2); // Should only have 2 unique entries
        assert!(set.contains(&GpuBackend::Cuda));
        assert!(set.contains(&GpuBackend::Rocm));
    }

    #[test]
    fn test_gpu_buffer_debug_clone() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        // Test Debug implementation
        let debug_str = format!("{:?}", buffer);
        assert!(debug_str.contains("GpuBuffer"));
        assert!(debug_str.contains("size"));

        // Test Clone implementation
        let cloned = buffer.clone();
        assert_eq!(cloned.len(), buffer.len());
        assert_eq!(cloned.len(), 4);

        // Verify the clone shares the same underlying Arc-backed storage:
        // a write through the original is visible through the clone.
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = cloned.copy_to_host(&mut result);
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_debug() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");

        // Test Debug implementation
        let debug_str = format!("{:?}", context);
        assert!(debug_str.contains("GpuContext"));
        assert!(debug_str.contains("backend"));
        assert!(debug_str.contains("Cpu"));
    }

    #[test]
    fn test_gpu_context_batch_dispatch() {
        // Create context with CPU backend (always available)
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");

        // Begin batch mode
        let begin_result = context.begin_batch();
        assert!(
            begin_result.is_ok(),
            "begin_batch should succeed on CPU backend"
        );

        // Compile a kernel and dispatch it inside the batch
        let dispatch_result = context.execute(|compiler| {
            compiler.compile("dummy kernel source").map(|kernel| {
                kernel.dispatch([4, 1, 1]);
            })
        });
        assert!(
            dispatch_result.is_ok(),
            "kernel dispatch inside batch should succeed"
        );

        // End batch — submits and waits
        let end_result = context.end_batch();
        assert!(
            end_result.is_ok(),
            "end_batch should succeed on CPU backend"
        );
    }

    #[test]
    fn test_gpu_context_gpu_sync() {
        // Create context with CPU backend
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");

        // gpu_sync should complete without error on CPU backend
        let result = context.gpu_sync();
        assert!(result.is_ok(), "gpu_sync should return Ok on CPU backend");
    }

    #[test]
    fn test_gpu_kernel_dispatch_no_wait() {
        // Create a kernel handle backed by CpuKernel
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        // Bind a buffer and scalar so the kernel has state
        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 4);

        // dispatch_no_wait returns () — just verify no panic
        handle.dispatch_no_wait([4, 1, 1]);
    }
}