scirs2_core/gpu/
mod.rs
1//! GPU acceleration module for scirs2-core
2//!
3//! This module provides hardware acceleration support across different GPU backends.
4
5use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod auto_tuning;
11pub mod backends;
12pub mod benchmarks;
13mod cpu_ops;
14pub mod heterogeneous;
15pub mod kernels;
16pub mod memory_management;
17pub mod tensor_cores;
18
/// GPU backend type
///
/// Identifies which hardware/API stack a device or context targets; `Cpu`
/// is the universal fallback used when no accelerator is usable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA backend
    Cuda,
    /// AMD ROCm backend
    Rocm,
    /// WebGPU backend
    Wgpu,
    /// Apple Metal backend
    Metal,
    /// OpenCL backend
    OpenCL,
    /// CPU fallback
    Cpu,
}
35
impl Default for GpuBackend {
    /// Defaults to the system-preferred backend (see [`GpuBackend::preferred`]).
    fn default() -> Self {
        Self::preferred()
    }
}
41
impl GpuBackend {
    /// Get the preferred GPU backend for the current system
    ///
    /// Asks the backend-detection system for the optimal backend, then —
    /// because the GPU context implementations are still stubs — downgrades
    /// any non-CPU result to `Cpu` outside of test builds.
    pub fn preferred() -> Self {
        // Use the backend detection system to find the optimal backend
        // This will properly detect available GPUs and fall back to CPU if needed
        match backends::initialize_optimal_backend() {
            Ok(backend) => {
                // If we get a non-CPU backend, verify it's actually usable
                if backend != GpuBackend::Cpu {
                    // Check if we can actually create a context with this backend
                    // For now, since implementations are stubs, fall back to CPU
                    #[cfg(not(test))]
                    {
                        // In non-test environments, we don't have real GPU implementations yet
                        return GpuBackend::Cpu;
                    }
                    #[cfg(test)]
                    {
                        // In tests, we can pretend the backend works
                        return backend;
                    }
                }
                // Reached only when detection itself selected the CPU backend.
                backend
            }
            Err(_) => {
                // If detection fails entirely, use CPU
                GpuBackend::Cpu
            }
        }
    }

    /// Check if this backend is available on the current system
    ///
    /// Combines compile-time feature gating with a runtime probe where the
    /// backend exposes one (CUDA, WebGPU, OpenCL). ROCm currently only
    /// checks the feature flag, and the CPU fallback is always available.
    pub fn is_available(&self) -> bool {
        match self {
            // Check runtime availability for GPU backends
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    CudaContext::is_available()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    false
                }
            }
            GpuBackend::Rocm => cfg!(feature = "rocm"), // Would use ROCm runtime check
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    WebGPUContext::is_available()
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    false
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    // Metal is always available on macOS if the feature is enabled
                    true
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    false
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    OpenCLContext::is_available()
                }
                #[cfg(not(feature = "opencl"))]
                {
                    false
                }
            }
            GpuBackend::Cpu => true,
        }
    }
}
126
127impl fmt::Display for GpuBackend {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        match self {
130            GpuBackend::Cuda => write!(f, "CUDA"),
131            GpuBackend::Rocm => write!(f, "ROCm"),
132            GpuBackend::Wgpu => write!(f, "WebGPU"),
133            GpuBackend::Metal => write!(f, "Metal"),
134            GpuBackend::OpenCL => write!(f, "OpenCL"),
135            GpuBackend::Cpu => write!(f, "CPU"),
136        }
137    }
138}
139
140use crate::error::{CoreError, ErrorContext, ErrorLocation};
141
/// Error type for GPU operations
///
/// Display messages come from `thiserror`'s `#[error]` attributes; the
/// `From<GpuError> for CoreError` impl in this module maps these variants
/// onto the crate-wide error type.
#[derive(Debug, thiserror::Error)]
pub enum GpuError {
    /// Backend is not available
    #[error("GPU backend {0} is not available")]
    BackendNotAvailable(String),

    /// Backend is not supported
    #[error("GPU backend {0} is not supported")]
    UnsupportedBackend(GpuBackend),

    /// Backend is not supported for a kernel
    #[error("GPU backend {0:?} is not supported for this kernel")]
    BackendNotSupported(GpuBackend),

    /// Backend is not implemented yet
    #[error("GPU backend {0} is not implemented yet")]
    BackendNotImplemented(GpuBackend),

    /// Out of memory
    #[error("GPU out of memory: {0}")]
    OutOfMemory(String),

    /// Kernel compilation error
    #[error("Kernel compilation error: {0}")]
    KernelCompilationError(String),

    /// Kernel execution error
    #[error("Kernel execution error: {0}")]
    KernelExecutionError(String),

    /// Invalid parameter
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// Kernel not found
    #[error("Kernel not found: {0}")]
    KernelNotFound(String),

    /// Specialization not supported
    #[error("Kernel specialization not supported")]
    SpecializationNotSupported,

    /// Unsupported data type
    #[error("Unsupported data type: {0:?}")]
    UnsupportedDataType(kernels::DataType),

    /// Other error
    #[error("{0}")]
    Other(String),
}
193
/// GPU device abstraction
///
/// A cheap, copyable handle identifying one device (`device_id`) within a
/// given backend; it does not own any driver resources.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuDevice {
    backend: GpuBackend,
    // Index of the device within its backend's enumeration order.
    device_id: usize,
}
200
201impl GpuDevice {
202    /// Create a new GPU device
203    pub fn new(backend: GpuBackend, device_id: usize) -> Self {
204        Self { backend, device_id }
205    }
206
207    /// Get the backend type
208    pub fn backend(&self) -> GpuBackend {
209        self.backend
210    }
211
212    /// Get the device ID
213    pub fn device_id(&self) -> usize {
214        self.device_id
215    }
216
217    /// Compile a kernel from source
218    pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
219        // Placeholder implementation
220        Ok(GpuKernel {
221            backend: self.backend,
222            entry_point: entrypoint.to_string(),
223        })
224    }
225}
226
/// GPU kernel abstraction
///
/// Produced by [`GpuDevice::compile_kernel`]; records which backend the
/// kernel targets and the name of its entry-point function.
pub struct GpuKernel {
    backend: GpuBackend,
    entry_point: String,
}
232
233impl GpuKernel {
234    /// Get the backend type
235    pub fn backend(&self) -> GpuBackend {
236        self.backend
237    }
238
239    /// Get the entry point name
240    pub fn entry_point(&self) -> &str {
241        &self.entry_point
242    }
243}
244
245/// Convert GPU errors to core errors with semantic preservation
246impl From<GpuError> for CoreError {
247    fn from(err: GpuError) -> Self {
248        match err {
249            GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
250                ErrorContext::new(format!("GPU backend {backend} is not available"))
251                    .with_location(ErrorLocation::new(file!(), line!())),
252            ),
253            GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
254                ErrorContext::new(format!("GPU backend {backend} is not supported"))
255                    .with_location(ErrorLocation::new(file!(), line!())),
256            ),
257            GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
258                ErrorContext::new(format!(
259                    "GPU backend {backend:?} is not supported for this kernel"
260                ))
261                .with_location(ErrorLocation::new(file!(), line!())),
262            ),
263            GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
264                ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
265                    .with_location(ErrorLocation::new(file!(), line!())),
266            ),
267            GpuError::OutOfMemory(details) => CoreError::MemoryError(
268                ErrorContext::new(details.to_string())
269                    .with_location(ErrorLocation::new(file!(), line!())),
270            ),
271            GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
272                ErrorContext::new(msg.to_string())
273                    .with_location(ErrorLocation::new(file!(), line!())),
274            ),
275            GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
276                ErrorContext::new(msg.to_string())
277                    .with_location(ErrorLocation::new(file!(), line!())),
278            ),
279            GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
280                ErrorContext::new(msg.to_string())
281                    .with_location(ErrorLocation::new(file!(), line!())),
282            ),
283            GpuError::KernelNotFound(name) => CoreError::ComputationError(
284                ErrorContext::new(name.to_string())
285                    .with_location(ErrorLocation::new(file!(), line!())),
286            ),
287            GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
288                ErrorContext::new("Kernel specialization not supported".to_string())
289                    .with_location(ErrorLocation::new(file!(), line!())),
290            ),
291            GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
292                ErrorContext::new(format!("{dtype:?}"))
293                    .with_location(ErrorLocation::new(file!(), line!())),
294            ),
295            GpuError::Other(msg) => CoreError::ComputationError(
296                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
297            ),
298        }
299    }
300}
301
/// Trait for types that can be used with GPU operations
///
/// Marker trait; implemented below for the primitive numeric types whose
/// byte representation can be copied to and from device memory.
pub trait GpuDataType: Copy + Send + Sync + 'static {}
304
/// GPU memory pointer abstraction
///
/// `ptr` is a raw device address and `size` a length in elements of `T`;
/// `PhantomData` ties the element type to the handle without storing it.
#[derive(Debug)]
pub struct GpuPtr<T: GpuDataType> {
    ptr: u64,
    size: usize,
    phantom: PhantomData<T>,
}
312
313impl<T: GpuDataType> GpuPtr<T> {
314    /// Allocate GPU memory
315    pub fn allocate(size: usize) -> Result<Self, GpuError> {
316        Ok(GpuPtr {
317            ptr: 0x1000_0000, // Placeholder address
318            size,
319            phantom: PhantomData,
320        })
321    }
322
323    /// Get the raw pointer value
324    pub fn as_ptr(&self) -> u64 {
325        self.ptr
326    }
327
328    /// Get the size in elements
329    pub fn len(&self) -> usize {
330        self.size
331    }
332
333    /// Check if the pointer is empty (size is 0)
334    pub fn is_empty(&self) -> bool {
335        self.size == 0
336    }
337}
338
/// Kernel argument types for GPU kernel execution
///
/// Either a borrowed device buffer or a by-value scalar, both tied to a
/// single element type `T`.
#[derive(Debug, Clone)]
pub enum KernelArg<'a, T: GpuDataType> {
    /// Buffer argument
    Buffer(&'a GpuPtr<T>),
    /// Scalar argument
    Scalar(T),
}
347
/// Non-generic kernel argument for mixed-type kernel calls
///
/// Unlike [`KernelArg`], this carries no element type parameter, so
/// arguments of different scalar types can share one argument slice.
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    /// Buffer argument (type-erased)
    Buffer(u64), // Raw pointer
    /// f32 scalar
    F32(f32),
    /// f64 scalar
    F64(f64),
    /// i32 scalar
    I32(i32),
    /// u32 scalar
    U32(u32),
    /// usize scalar
    Usize(usize),
}
364
/// GPU communication channel for multi-GPU operations
///
/// NOTE(review): not used anywhere visible in this module yet; the fields
/// are retained under `allow(dead_code)`, presumably for future
/// peer-to-peer transfer scheduling — confirm before removing.
pub struct GpuChannel {
    #[allow(dead_code)]
    source_device: usize,
    #[allow(dead_code)]
    target_device: usize,
    #[allow(dead_code)]
    bandwidth: f64, // GB/s
}
374
375// Implement for common types
376impl GpuDataType for f32 {}
377impl GpuDataType for f64 {}
378impl GpuDataType for i32 {}
379impl GpuDataType for u32 {}
380impl GpuDataType for u8 {}
381impl GpuDataType for i8 {}
382impl GpuDataType for u16 {}
383impl GpuDataType for i16 {}
384impl GpuDataType for u64 {}
385impl GpuDataType for i64 {}
386impl GpuDataType for usize {}
387impl GpuDataType for isize {}
388
/// GPU buffer
///
/// Typed, reference-counted view over a backend-specific allocation;
/// `size` is a length in elements of `T`, not bytes.
pub struct GpuBuffer<T: GpuDataType> {
    inner: Arc<dyn GpuBufferImpl>,
    size: usize,
    phantom: PhantomData<T>,
}
395
impl<T: GpuDataType> GpuBuffer<T> {
    /// Create a new buffer with the given size
    ///
    /// `size` is in elements of `T`; `inner` is the backend storage.
    pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
        Self {
            inner,
            size,
            phantom: PhantomData,
        }
    }

    /// Get the size of the buffer in elements
    pub fn len(&self) -> usize {
        self.size
    }

    /// Check if the buffer is empty
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Copy data from the host to the device
    ///
    /// Copies `data.len()` elements; fails if `data` is larger than the
    /// buffer. Copying fewer elements than the buffer holds is allowed.
    pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
        if data.len() > self.size {
            return Err(GpuError::InvalidParameter(
                "Data size exceeds buffer size".to_string(),
            ));
        }
        // SAFETY: `data` is a valid slice, so the pointer/length pair covers
        // exactly the bytes of `data`, and the length check above guarantees
        // it fits within this buffer.
        unsafe {
            self.inner
                .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
        }
        Ok(())
    }

    /// Copy data from the device to the host
    ///
    /// Fills at most `data.len()` elements; fails if `data` is larger than
    /// the buffer.
    pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
        if data.len() > self.size {
            return Err(GpuError::InvalidParameter(
                "Data size exceeds buffer size".to_string(),
            ));
        }
        // SAFETY: `data` is a valid mutable slice and the length check above
        // ensures the backend writes no more bytes than `data` can hold.
        unsafe {
            self.inner
                .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
        }
        Ok(())
    }

    /// Convert the buffer contents to a vector
    ///
    /// NOTE(review): `mem::zeroed` is sound for the primitive numeric types
    /// this module implements `GpuDataType` for, but would be UB for a
    /// downstream type whose all-zero byte pattern is invalid — confirm
    /// before widening the trait. The copy result is deliberately ignored:
    /// `result.len() == self.size`, so the size check cannot fail.
    pub fn to_vec(&self) -> Vec<T> {
        let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
        let _ = self.copy_to_host(&mut result);
        result
    }
}
451
impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
    /// Shows only the element count; the backend storage is opaque.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GpuBuffer")
            .field("size", &self.size)
            .finish()
    }
}
459
impl<T: GpuDataType> Clone for GpuBuffer<T> {
    /// Cheap clone: bumps the `Arc` refcount on the shared backend storage;
    /// device memory is NOT copied, so clones alias the same buffer.
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            size: self.size,
            phantom: PhantomData,
        }
    }
}
469
/// GPU kernel handle
///
/// Cheap, cloneable handle over a backend-specific compiled kernel.
#[derive(Clone)]
pub struct GpuKernelHandle {
    inner: Arc<dyn GpuKernelImpl>,
}
475
impl GpuKernelHandle {
    /// Create a new kernel handle
    pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
        Self { inner }
    }

    /// Set a buffer parameter
    ///
    /// Binds `buffer`'s backend storage to the kernel parameter `name`.
    pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
        self.inner.set_buffer(name, &buffer.inner);
    }

    /// Set a u32 parameter
    pub fn set_u32(&self, name: &str, value: u32) {
        self.inner.set_u32(name, value);
    }

    /// Set an i32 parameter
    pub fn set_i32(&self, name: &str, value: i32) {
        self.inner.set_i32(name, value);
    }

    /// Set an f32 parameter
    pub fn set_f32(&self, name: &str, value: f32) {
        self.inner.set_f32(name, value);
    }

    /// Set an f64 parameter
    pub fn set_f64(&self, name: &str, value: f64) {
        self.inner.set_f64(name, value);
    }

    /// Dispatch the kernel with the given work group counts
    ///
    /// `workgroups` is the number of work groups along x, y and z.
    pub fn dispatch(&self, workgroups: [u32; 3]) {
        self.inner.dispatch(workgroups);
    }
}
512
/// GPU compiler for dynamic kernels
///
/// Thin wrapper over the backend-specific compiler implementation.
pub struct GpuCompiler {
    inner: Arc<dyn GpuCompilerImpl>,
}
517
518impl GpuCompiler {
519    /// Create a new compiler
520    pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
521        Self { inner }
522    }
523
524    /// Compile a kernel from source
525    pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
526        let kernel = self.inner.compile(source)?;
527        Ok(GpuKernelHandle::new(kernel))
528    }
529
530    /// Compile a kernel for the specified input and output types
531    pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
532        let kernel = self.inner.compile_typed(
533            name,
534            std::any::TypeId::of::<I>(),
535            std::any::TypeId::of::<O>(),
536        );
537        GpuKernelHandle::new(kernel)
538    }
539}
540
/// GPU context for managing GPU resources and operations
///
/// Owns the backend-specific context plus a registry of predefined kernels.
pub struct GpuContext {
    inner: Arc<dyn GpuContextImpl>,
    backend: GpuBackend,
    // Named kernels available through `get_kernel`/`get_specialized_kernel`.
    kernel_registry: kernels::KernelRegistry,
}
547
548impl GpuContext {
    /// Create a new GPU context with the specified backend
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BackendNotAvailable`] when the backend's feature
    /// is disabled or no matching device is detected at runtime;
    /// [`GpuError::UnsupportedBackend`] when the backend was compiled out;
    /// [`GpuError::BackendNotImplemented`] for backends that are still stubs.
    pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
        // First check if the backend is available at compile time
        if !backend.is_available() {
            return Err(GpuError::BackendNotAvailable(backend.to_string()));
        }

        // For non-CPU backends, also check runtime availability
        if backend != GpuBackend::Cpu {
            let detection_result = backends::detect_gpu_backends();
            let backend_available = detection_result
                .devices
                .iter()
                .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);

            if !backend_available {
                return Err(GpuError::BackendNotAvailable(format!(
                    "{backend} (no devices detected at runtime)"
                )));
            }
        }

        // Instantiate the backend-specific context. Each arm is feature-gated
        // so builds without a backend's feature fail fast at runtime instead
        // of failing to link.
        let inner = match backend {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    match CudaContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    // This is just a stub - in a real implementation, we would use the hip-sys crate
                    // to create a ROCm context and return it
                    #[cfg(test)]
                    {
                        // For testing, we can use a mock implementation
                        Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
                    }
                    #[cfg(not(test))]
                    {
                        return Err(GpuError::BackendNotImplemented(backend));
                    }
                }
                #[cfg(not(feature = "rocm"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    match WebGPUContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    use crate::gpu::backends::metal::MetalContext;
                    match MetalContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    match OpenCLContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "opencl"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
        };

        Ok(Self {
            inner,
            backend,
            kernel_registry: kernels::KernelRegistry::with_default_kernels(),
        })
    }
657
    /// Get the backend type
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Get the backend name
    ///
    /// NOTE(review): this table duplicates the `Display` impl for
    /// `GpuBackend`; keep the two in sync if a backend is added.
    pub fn backend_name(&self) -> &str {
        match self.backend {
            GpuBackend::Cuda => "CUDA",
            GpuBackend::Rocm => "ROCm",
            GpuBackend::Wgpu => "WebGPU",
            GpuBackend::Metal => "Metal",
            GpuBackend::OpenCL => "OpenCL",
            GpuBackend::Cpu => "CPU",
        }
    }
674
    /// Create a buffer with the given size
    ///
    /// `size` is in elements of `T`; the byte size uses a saturating
    /// multiply so a pathological request cannot overflow.
    pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
        let byte_size = size.saturating_mul(std::mem::size_of::<T>());
        let inner = self.inner.create_buffer(byte_size);
        GpuBuffer::new(inner, size)
    }

    /// Create a buffer from a slice
    ///
    /// The copy result is deliberately ignored: the buffer is sized to
    /// `data.len()`, so the host-to-device size check cannot fail.
    pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
        let buffer = self.create_buffer::<T>(data.len());
        let _ = buffer.copy_from_host(data);
        buffer
    }
688
689    /// Execute a function with a compiler
690    pub fn execute<F, R>(&self, f: F) -> R
691    where
692        F: FnOnce(&GpuCompiler) -> R,
693    {
694        let compiler = GpuCompiler::new(self.inner.create_compiler());
695        f(&compiler)
696    }
697
    /// Get a kernel from the registry
    ///
    /// Looks the kernel up by `name`, selects the source variant for this
    /// context's backend, and compiles it into a handle.
    pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
        let kernel = self
            .kernel_registry
            .get(name)
            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;

        let kernel_source = kernel.source_for_backend(self.backend)?;
        let metadata = kernel.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Get a specialized kernel from the registry
    ///
    /// Like [`Self::get_kernel`], but first specializes the kernel for
    /// `params` before selecting and compiling the backend source.
    pub fn get_specialized_kernel(
        &self,
        name: &str,
        params: &kernels::KernelParams,
    ) -> Result<GpuKernelHandle, GpuError> {
        let specialized = self.kernel_registry.get_specialized(name, params)?;
        let kernel_source = specialized.source_for_backend(self.backend)?;
        let metadata = specialized.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Compile a kernel with metadata
    ///
    /// Metadata is currently unused; compilation relies on the source alone.
    fn compile_kernel_with_metadata(
        &self,
        source: &str,
        _metadata: &kernels::KernelMetadata,
    ) -> Result<GpuKernelHandle, GpuError> {
        self.execute(|compiler| compiler.compile(source))
    }
734
    /// Get available memory on the device
    ///
    /// Placeholder: always reports a fixed 1 GiB instead of querying the
    /// device.
    pub fn get_available_memory(&self) -> Option<usize> {
        // In a real implementation, this would query the device
        // For now, return a placeholder value
        Some(1024 * 1024 * 1024) // 1GB
    }

    /// Get total memory on the device
    ///
    /// Placeholder: fixed 512 MiB on wasm32 targets, 4 GiB elsewhere.
    pub fn get_total_memory(&self) -> Option<usize> {
        // In a real implementation, this would query the device
        // For now, return a placeholder value
        #[cfg(target_arch = "wasm32")]
        return Some(512 * 1024 * 1024); // 512MB for WASM32

        #[cfg(not(target_arch = "wasm32"))]
        Some((4u64 * 1024 * 1024 * 1024) as usize) // 4GB for native
    }
752
    /// Launch a kernel with the given parameters
    ///
    /// Placeholder: all arguments are ignored and `Ok(())` is returned
    /// unconditionally; no kernel is actually launched.
    pub fn launch_kernel(
        &self,
        kernel_name: &str,
        grid_size: (usize, usize, usize),
        block_size: (usize, usize, usize),
        args: &[DynamicKernelArg],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (kernel_name, grid_size, block_size, args);
        Ok(())
    }

    /// Transfer data from host to device asynchronously
    ///
    /// Placeholder: no transfer is performed; always returns `Ok(())`.
    pub fn transfer_async_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from host to device synchronously
    ///
    /// Placeholder: no transfer is performed; always returns `Ok(())`.
    pub fn transfer_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from device to host asynchronously
    ///
    /// Placeholder: `data` is left untouched; always returns `Ok(())`.
    pub fn transfer_async_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from device to host synchronously
    ///
    /// Placeholder: `data` is left untouched; always returns `Ok(())`.
    pub fn transfer_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }
809
    /// Execute a kernel with dynamic compilation and parameter passing
    /// This method is expected by scirs2-vision for GPU operations
    ///
    /// NOTE(review): placeholder — it only logs the request to stderr and
    /// returns `Ok(())`; nothing is compiled or executed. The `eprintln!`
    /// calls should move to a proper logging facility before this is relied
    /// on in production.
    pub fn execute_kernel(
        &self,
        source: &str,
        buffers: &[GpuBuffer<f32>],
        work_groups: (u32, u32, u32),
        int_params: &[u32],
        float_params: &[f32],
    ) -> Result<(), GpuError> {
        // For now, provide a basic implementation that logs the execution
        // In a real implementation, this would compile and execute the kernel
        eprintln!(
            "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
            source.len(),
            buffers.len(),
            work_groups
        );
        eprintln!("Int params: {int_params:?}");
        eprintln!("Float params: {float_params:?}");
        Ok(())
    }

    /// Read data from a GPU buffer
    /// This method is expected by scirs2-vision for reading GPU results
    ///
    /// Delegates to [`GpuBuffer::to_vec`]; currently infallible in practice.
    pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
        Ok(buffer.to_vec())
    }
838
    /// Global sum reduction
    ///
    /// NOTE: the reduction methods below all delegate to `*_cpu_fallback`
    /// helpers defined elsewhere in this module (presumably `cpu_ops` —
    /// confirm) until native GPU kernels are wired in.
    pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_all_cpu_fallback(buffer)
    }

    /// Global mean reduction
    pub fn mean_all<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.mean_all_cpu_fallback(buffer)
    }

    /// Global max reduction
    pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.max_all_cpu_fallback(buffer)
    }

    /// Global min reduction
    pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.min_all_cpu_fallback(buffer)
    }

    /// Sum reduction along an axis
    ///
    /// `shape` describes the logical dimensions of `buffer`; `axis` is the
    /// dimension to reduce over.
    pub fn sum_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Mean reduction along an axis
    pub fn mean_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.mean_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Max reduction along an axis
    pub fn max_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.max_axis_cpu_fallback(buffer, shape, axis)
    }

    /// Min reduction along an axis
    pub fn min_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.min_axis_cpu_fallback(buffer, shape, axis)
    }
901
902    /// Broadcast a buffer to a different shape
903    pub fn broadcast<T: GpuDataType>(
904        &self,
905        buffer: &GpuBuffer<T>,
906        from_shape: &[usize],
907        to_shape: &[usize],
908    ) -> Result<GpuBuffer<T>, GpuError> {
909        self.broadcast_cpu_fallback(buffer, from_shape, to_shape)
910    }
911
912    /// Scale a buffer by a scalar value
913    pub fn scale<T: GpuDataType>(
914        &self,
915        buffer: &GpuBuffer<T>,
916        scalar: T,
917    ) -> Result<GpuBuffer<T>, GpuError> {
918        self.scale_cpu_fallback(buffer, scalar)
919    }
920
921    /// General matrix multiplication: C = A @ B
922    pub fn gemm<T: GpuDataType>(
923        &self,
924        a: &GpuBuffer<T>,
925        b: &GpuBuffer<T>,
926        m: usize,
927        k: usize,
928        n: usize,
929    ) -> Result<GpuBuffer<T>, GpuError> {
930        self.gemm_cpu_fallback(a, b, m, k, n)
931    }
932
    /// GEMM with transposed B: C = A @ B^T
    ///
    /// Dimension conventions follow [`Self::gemm`]; see
    /// `gemm_transpose_b_cpu_fallback` for the exact layout of `b`.
    /// Currently always delegates to the CPU fallback implementation.
    pub fn gemm_transpose_b<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_transpose_b_cpu_fallback(a, b, m, k, n)
    }
944
    /// GEMM with transposed A: C = A^T @ B
    ///
    /// Dimension conventions follow [`Self::gemm`]; see
    /// `gemm_transpose_a_cpu_fallback` for the exact layout of `a`.
    /// Currently always delegates to the CPU fallback implementation.
    pub fn gemm_transpose_a<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gemm_transpose_a_cpu_fallback(a, b, m, k, n)
    }
956
    /// ReLU activation forward pass
    ///
    /// Element-wise activation over `input`; currently always executed by
    /// the CPU fallback (no GPU kernel is dispatched yet).
    pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.relu_cpu_fallback(input)
    }
961
    /// ReLU backward pass
    ///
    /// Combines `grad_output` (the upstream gradient) with `input` (the
    /// forward-pass input) to produce the input gradient; delegates to the
    /// CPU fallback implementation.
    pub fn relu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.relu_backward_cpu_fallback(grad_output, input)
    }
970
    /// Sigmoid activation forward pass
    ///
    /// Element-wise activation over `input`; currently always executed by
    /// the CPU fallback (no GPU kernel is dispatched yet).
    pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.sigmoid_cpu_fallback(input)
    }
975
    /// Sigmoid backward pass
    ///
    /// Combines `grad_output` (the upstream gradient) with `input` (the
    /// forward-pass input) to produce the input gradient; delegates to the
    /// CPU fallback implementation.
    pub fn sigmoid_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.sigmoid_backward_cpu_fallback(grad_output, input)
    }
984
    /// Tanh activation forward pass
    ///
    /// Element-wise activation over `input`; currently always executed by
    /// the CPU fallback (no GPU kernel is dispatched yet).
    pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.tanh_cpu_fallback(input)
    }
989
    /// Tanh backward pass
    ///
    /// Combines `grad_output` (the upstream gradient) with `input` (the
    /// forward-pass input) to produce the input gradient; delegates to the
    /// CPU fallback implementation.
    pub fn tanh_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.tanh_backward_cpu_fallback(grad_output, input)
    }
998
    /// GELU activation forward pass
    ///
    /// Element-wise activation over `input`; currently always executed by
    /// the CPU fallback (no GPU kernel is dispatched yet).
    pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.gelu_cpu_fallback(input)
    }
1003
    /// GELU backward pass
    ///
    /// Combines `grad_output` (the upstream gradient) with `input` (the
    /// forward-pass input) to produce the input gradient; delegates to the
    /// CPU fallback implementation.
    pub fn gelu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.gelu_backward_cpu_fallback(grad_output, input)
    }
1012}
1013
1014impl fmt::Debug for GpuContext {
1015    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1016        f.debug_struct("GpuContext")
1017            .field("backend", &self.backend)
1018            .finish()
1019    }
1020}
1021
1022// The following trait definitions would be implemented by backend-specific
1023// code in a real implementation
1024
/// GPU buffer implementation trait
///
/// Backend-specific buffer types implement this to move raw bytes between
/// host memory and the (possibly emulated) device allocation.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copy data from host to device
    ///
    /// # Safety
    ///
    /// `data` must be valid for reads of `size` bytes, and `size` must not
    /// exceed the buffer's allocation; implementations copy raw bytes
    /// without bounds checking.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copy data from device to host
    ///
    /// # Safety
    ///
    /// `data` must be valid for writes of `size` bytes, and `size` must not
    /// exceed the buffer's allocation; implementations copy raw bytes
    /// without bounds checking.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Get a reference to self as Any for downcasting
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Get the size of the buffer in bytes
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0 // Default implementation for backward compatibility
    }

    /// Get the device pointer (for backends that use device pointers)
    ///
    /// Backends without a meaningful device address keep the default of 0.
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0 // Default implementation for backward compatibility
    }
}
1049
/// GPU kernel implementation trait
///
/// A compiled kernel instance: parameters are bound by name via the
/// `set_*` methods, then the kernel is launched with `dispatch`.
pub(crate) trait GpuKernelImpl: Send + Sync {
    /// Set a buffer parameter
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);

    /// Set a u32 parameter
    fn set_u32(&self, name: &str, value: u32);

    /// Set an i32 parameter
    fn set_i32(&self, name: &str, value: i32);

    /// Set an f32 parameter
    fn set_f32(&self, name: &str, value: f32);

    /// Set an f64 parameter
    fn set_f64(&self, name: &str, value: f64);

    /// Dispatch the kernel
    ///
    /// `workgroups` gives the launch size in each of the three grid
    /// dimensions.
    fn dispatch(&self, workgroups: [u32; 3]);
}
1070
/// GPU compiler implementation trait
///
/// Turns kernel source (or a registered kernel name) into an executable
/// [`GpuKernelImpl`] for the owning backend.
pub(crate) trait GpuCompilerImpl: Send + Sync {
    /// Compile a kernel from source
    ///
    /// Returns a [`GpuError`] if compilation fails.
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;

    /// Compile a typed kernel
    ///
    /// Produces a kernel identified by `name`, specialized for the given
    /// input/output element types. Infallible by signature — backends are
    /// expected to always return some kernel here.
    fn compile_typed(
        &self,
        name: &str,
        input_type: std::any::TypeId,
        output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl>;
}
1084
/// GPU context implementation trait
///
/// The backend-specific entry point: creates buffers and compilers for a
/// particular backend.
pub(crate) trait GpuContextImpl: Send + Sync {
    /// Create a buffer
    ///
    /// `size` is in bytes.
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;

    /// Create a compiler
    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;

    /// Support dynamic downcasting of concrete context implementations
    ///
    /// Default implementation is only available where the concrete type is
    /// `Sized + 'static` (i.e. not through a trait object).
    fn as_any(&self) -> &dyn std::any::Any
    where
        Self: 'static + Sized,
    {
        self
    }
}
1101
1102// CPU fallback implementation
1103
/// CPU context implementation
///
/// Zero-sized fallback context used when no GPU backend is available; all
/// "GPU" resources it creates live in ordinary host memory.
struct CpuContext;
1106
1107impl CpuContext {
1108    /// Create a new CPU context
1109    fn new() -> Self {
1110        Self
1111    }
1112}
1113
1114impl GpuContextImpl for CpuContext {
1115    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
1116        Arc::new(CpuBuffer::new(size))
1117    }
1118
1119    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
1120        Arc::new(CpuCompiler)
1121    }
1122}
1123
/// CPU buffer implementation
///
/// Backs a "device" buffer with plain host memory for the CPU fallback
/// backend.
struct CpuBuffer {
    // Raw byte storage; zero-initialized at construction (see `new`).
    data: Vec<u8>,
}
1128
1129impl CpuBuffer {
1130    /// Create a new CPU buffer
1131    fn new(size: usize) -> Self {
1132        Self {
1133            data: vec![0; size],
1134        }
1135    }
1136}
1137
impl GpuBufferImpl for CpuBuffer {
    /// Copy `size` bytes from `data` into this buffer.
    ///
    /// FIXME(review): this mutates `self.data` through a `&self` receiver by
    /// casting `*const Self` to `*mut Self`. Writing through a pointer
    /// derived from a shared reference to non-`UnsafeCell` data is undefined
    /// behavior (Miri rejects it). The field should be wrapped in
    /// `UnsafeCell`/`Mutex`, or the trait should take `&mut self`. Note also
    /// that no bounds check is performed here — callers must guarantee
    /// `size <= self.data.len()` or the copy writes past the allocation.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        let mut_self = self as *const Self as *mut Self;
        let data_ptr = (*mut_self).data.as_mut_ptr();
        std::ptr::copy_nonoverlapping(data, data_ptr, size);
    }

    /// Copy `size` bytes out of this buffer into `data`.
    ///
    /// NOTE(review): unchecked — callers must guarantee
    /// `size <= self.data.len()` and that `data` is writable for `size`
    /// bytes.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        let data_ptr = self.data.as_ptr();
        std::ptr::copy_nonoverlapping(data_ptr, data, size);
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Buffer capacity in bytes.
    fn size(&self) -> usize {
        self.data.len()
    }

    /// Host address of the backing storage, reinterpreted as a "device"
    /// pointer for API symmetry with real backends.
    fn device_ptr(&self) -> u64 {
        self.data.as_ptr() as u64
    }
}
1162
/// CPU compiler implementation
///
/// Stub compiler for the CPU fallback backend; every request yields the
/// no-op [`CpuKernel`].
struct CpuCompiler;
1165
1166impl GpuCompilerImpl for CpuCompiler {
1167    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
1168        // In a real implementation, we would parse and execute the kernel
1169        // For now, just return a dummy implementation
1170        Ok(Arc::new(CpuKernel))
1171    }
1172
1173    fn compile_typed(
1174        &self,
1175        _name: &str,
1176        _input_type: std::any::TypeId,
1177        _output_type: std::any::TypeId,
1178    ) -> Arc<dyn GpuKernelImpl> {
1179        // In a real implementation, we would select an appropriate implementation
1180        // For now, just return a dummy implementation
1181        Arc::new(CpuKernel)
1182    }
1183}
1184
/// CPU kernel implementation
///
/// No-op kernel stub: parameters are accepted and discarded, dispatch does
/// nothing. Exists so the kernel API is callable on the CPU fallback.
struct CpuKernel;
1187
1188impl GpuKernelImpl for CpuKernel {
1189    fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
1190        // In a real implementation, we would store the buffer
1191    }
1192
1193    fn set_u32(&self, _name: &str, value: u32) {
1194        // In a real implementation, we would store the value
1195    }
1196
1197    fn set_i32(&self, _name: &str, value: i32) {
1198        // In a real implementation, we would store the value
1199    }
1200
1201    fn set_f32(&self, _name: &str, value: f32) {
1202        // In a real implementation, we would store the value
1203    }
1204
1205    fn set_f64(&self, _name: &str, value: f64) {
1206        // In a real implementation, we would store the value
1207    }
1208
1209    fn dispatch(&self, workgroups: [u32; 3]) {
1210        // In a real implementation, we would execute the kernel
1211    }
1212}
1213
1214// In a real implementation, we would have implementations for other backends
1215// such as CUDA, WebGPU, Metal, and OpenCL.
1216
// Unit tests for the GPU abstraction layer. Everything here runs against the
// CPU fallback backend, so the public API is exercised without real hardware.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_backend_preferred() {
        let backend = GpuBackend::preferred();
        // Should return a valid backend
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_gpu_backend_default() {
        // `Default` is defined to delegate to `preferred()`.
        let backend = GpuBackend::default();
        assert_eq!(backend, GpuBackend::preferred());
    }

    #[test]
    fn test_gpu_backend_is_available() {
        let backend = GpuBackend::Cpu;
        assert!(backend.is_available());

        // Test other backends - availability depends on runtime, not just feature flags
        #[cfg(feature = "cuda")]
        {
            // CUDA feature enabled doesn't guarantee runtime availability
            let _ = GpuBackend::Cuda.is_available(); // Just check without asserting
        }
        #[cfg(not(feature = "cuda"))]
        assert!(!GpuBackend::Cuda.is_available());

        #[cfg(feature = "rocm")]
        {
            // ROCm feature enabled doesn't guarantee runtime availability
            let _ = GpuBackend::Rocm.is_available(); // Just check without asserting
        }
        #[cfg(not(feature = "rocm"))]
        assert!(!GpuBackend::Rocm.is_available());

        #[cfg(all(feature = "metal", target_os = "macos"))]
        assert!(GpuBackend::Metal.is_available());
        #[cfg(not(all(feature = "metal", target_os = "macos")))]
        assert!(!GpuBackend::Metal.is_available());
    }

    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
        assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
        assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
        assert_eq!(GpuBackend::Metal.to_string(), "Metal");
        assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
        assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
    }

    // Checks the GpuError -> CoreError mapping: each GPU error category must
    // convert to its corresponding core error category.
    #[test]
    fn test_gpuerror_from_conversion() {
        let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::ComputationError(_) => {}
            _ => panic!("Expected ComputationError"),
        }

        let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::MemoryError(_) => {}
            _ => panic!("Expected MemoryError"),
        }

        let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::InvalidArgument(_) => {}
            _ => panic!("Expected InvalidArgument"),
        }

        let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::TypeError(_) => {}
            _ => panic!("Expected TypeError"),
        }
    }

    #[test]
    fn test_gpu_datatype_trait() {
        // Test that various types implement GpuDataType
        fn assert_gpu_datatype<T: GpuDataType>() {}

        assert_gpu_datatype::<f32>();
        assert_gpu_datatype::<f64>();
        assert_gpu_datatype::<i32>();
        assert_gpu_datatype::<u32>();
        assert_gpu_datatype::<u8>();
        assert_gpu_datatype::<i8>();
        assert_gpu_datatype::<u16>();
        assert_gpu_datatype::<i16>();
        assert_gpu_datatype::<u64>();
        assert_gpu_datatype::<i64>();
    }

    #[test]
    fn test_gpu_buffer_creation() {
        // 100 bytes of raw storage = 25 f32 elements.
        let inner = Arc::new(CpuBuffer::new(100));
        let buffer = GpuBuffer::<f32>::new(inner, 25);

        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_empty() {
        let inner = Arc::new(CpuBuffer::new(0));
        let buffer = GpuBuffer::<f32>::new(inner, 0);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_copy_operations() {
        // Round-trip: host -> buffer -> host must preserve the data.
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = buffer.copy_to_host(&mut result);

        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_buffer_to_vec() {
        let inner = Arc::new(CpuBuffer::new(12));
        let buffer = GpuBuffer::<f32>::new(inner, 3);

        let data = vec![5.0f32, 6.0, 7.0];
        let _ = buffer.copy_from_host(&data);

        let result = buffer.to_vec();
        assert_eq!(result, data);
    }

    // NOTE(review): the expected substring presumably originates from a size
    // check inside GpuBuffer's copy path (should_panic matches substrings,
    // so the `expect` wrapper message does not interfere) — confirm against
    // GpuBuffer::copy_from_host.
    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_from_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let data = vec![1.0f32, 2.0, 3.0]; // 3 elements > 2 buffer size
        buffer.copy_from_host(&data).expect("Operation failed");
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_to_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let mut data = vec![0.0f32; 3]; // 3 elements > 2 buffer size
        buffer.copy_to_host(&mut data).expect("Operation failed");
    }

    #[test]
    fn test_gpu_kernel_handle() {
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        // Test setting various parameter types
        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 100);
        handle.set_i32("offset", -5);
        handle.set_f32("scale", 2.5);
        handle.set_f64("precision", 0.0001);

        // Test dispatch
        handle.dispatch([16, 8, 1]);
    }

    #[test]
    fn test_gpu_context_cpu_backend() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
        assert_eq!(context.backend(), GpuBackend::Cpu);
        assert_eq!(context.backend_name(), "CPU");

        // Test memory query methods
        assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
        assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
    }

    #[test]
    fn test_gpu_context_buffer_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let buffer = context.create_buffer::<f32>(100);
        assert_eq!(buffer.len(), 100);

        let data = vec![1.0f32; 50];
        let buffer_from_slice = context.create_buffer_from_slice(&data);
        assert_eq!(buffer_from_slice.len(), 50);

        let result = buffer_from_slice.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_unsupported_backend() {
        // Test a backend that's not available
        #[cfg(not(feature = "cuda"))]
        {
            let result = GpuContext::new(GpuBackend::Cuda);
            assert!(result.is_err());
            match result {
                Err(GpuError::UnsupportedBackend(_)) => {}
                Err(GpuError::BackendNotAvailable(_)) => {} // Also accept this error
                Err(e) => panic!(
                    "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
                    e
                ),
                Ok(_) => panic!("Expected error, got Ok"),
            }
        }
    }

    #[test]
    fn test_gpu_compiler() {
        let compiler_impl = Arc::new(CpuCompiler);
        let compiler = GpuCompiler::new(compiler_impl);

        // Test compiling from source
        let kernel = compiler
            .compile("dummy kernel source")
            .expect("Operation failed");
        kernel.dispatch([1, 1, 1]);

        // Test typed compilation
        let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
        typed_kernel.dispatch([32, 1, 1]);
    }

    #[test]
    fn test_gpu_context_execute() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());

        assert!(result);
    }

    #[test]
    fn test_gpu_context_kernel_registry() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        // Test getting a non-existent kernel
        let result = context.get_kernel("non_existent_kernel");
        assert!(result.is_err());
        match result {
            Err(GpuError::KernelNotFound(_)) => {}
            _ => panic!("Expected KernelNotFound error"),
        }
    }

    #[test]
    fn test_cpu_buffer_implementation() {
        let buffer = CpuBuffer::new(256);
        assert_eq!(buffer.data.len(), 256);

        // Test that initial data is zeroed
        assert!(buffer.data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_gpuerror_display() {
        let error = GpuError::BackendNotAvailable("CUDA".to_string());
        assert_eq!(error.to_string(), "GPU backend CUDA is not available");

        let error = GpuError::OutOfMemory("allocation failed".to_string());
        assert_eq!(error.to_string(), "GPU out of memory: allocation failed");

        let error = GpuError::KernelCompilationError("syntax error".to_string());
        assert_eq!(error.to_string(), "Kernel compilation error: syntax error");

        let error = GpuError::KernelNotFound("gemm".to_string());
        assert_eq!(error.to_string(), "Kernel not found: gemm");
    }

    #[test]
    fn test_backend_equality() {
        assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
        assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);

        // Test Clone and Copy
        let backend = GpuBackend::Metal;
        let cloned = backend;
        let copied = backend;
        assert_eq!(backend, cloned);
        assert_eq!(backend, copied);
    }

    #[test]
    fn test_backend_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(GpuBackend::Cuda);
        set.insert(GpuBackend::Rocm);
        set.insert(GpuBackend::Cuda); // Duplicate

        assert_eq!(set.len(), 2); // Should only have 2 unique entries
        assert!(set.contains(&GpuBackend::Cuda));
        assert!(set.contains(&GpuBackend::Rocm));
    }

    #[test]
    fn test_gpu_buffer_debug_clone() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        // Test Debug implementation
        let debug_str = format!("{:?}", buffer);
        assert!(debug_str.contains("GpuBuffer"));
        assert!(debug_str.contains("size"));

        // Test Clone implementation
        let cloned = buffer.clone();
        assert_eq!(cloned.len(), buffer.len());
        assert_eq!(cloned.len(), 4);

        // The clone shares the same underlying Arc'd storage, so data written
        // through `buffer` must be readable through `cloned`.
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = cloned.copy_to_host(&mut result);
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_debug() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");

        // Test Debug implementation
        let debug_str = format!("{:?}", context);
        assert!(debug_str.contains("GpuContext"));
        assert!(debug_str.contains("backend"));
        assert!(debug_str.contains("Cpu"));
    }
}