
scirs2_core/gpu/mod.rs

//! GPU acceleration module for scirs2-core
//!
//! This module provides hardware acceleration support across different GPU backends.
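//!
//! # Example
//!
//! A minimal, illustrative sketch of typical usage (assumes the crate is
//! imported as `scirs2_core` and relies on the always-available CPU fallback):
//!
//! ```no_run
//! use scirs2_core::gpu::{GpuBackend, GpuContext};
//!
//! # fn main() -> Result<(), scirs2_core::gpu::GpuError> {
//! // The CPU backend always succeeds and is used when no GPU is present.
//! let ctx = GpuContext::new(GpuBackend::Cpu)?;
//! let buffer = ctx.create_buffer_from_slice(&[1.0f32, 2.0, 3.0]);
//! assert_eq!(buffer.to_vec(), vec![1.0, 2.0, 3.0]);
//! # Ok(())
//! # }
//! ```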

use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;

pub mod async_execution;
pub mod auto_tuning;
pub mod backends;
pub mod benchmarks;
pub mod heterogeneous;
pub mod kernels;
pub mod tensor_cores;

/// GPU backend type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA backend
    Cuda,
    /// AMD ROCm backend
    Rocm,
    /// WebGPU backend
    Wgpu,
    /// Apple Metal backend
    Metal,
    /// OpenCL backend
    OpenCL,
    /// CPU fallback
    Cpu,
}

impl Default for GpuBackend {
    fn default() -> Self {
        Self::preferred()
    }
}

impl GpuBackend {
    /// Get the preferred GPU backend for the current system
    pub fn preferred() -> Self {
        // Use the backend detection system to find the optimal backend
        // This will properly detect available GPUs and fall back to CPU if needed
        match backends::initialize_optimal_backend() {
            Ok(backend) => {
                // If we get a non-CPU backend, verify it's actually usable
                if backend != GpuBackend::Cpu {
                    // Check if we can actually create a context with this backend
                    // For now, since implementations are stubs, fall back to CPU
                    #[cfg(not(test))]
                    {
                        // In non-test environments, we don't have real GPU implementations yet
                        return GpuBackend::Cpu;
                    }
                    #[cfg(test)]
                    {
                        // In tests, we can pretend the backend works
                        return backend;
                    }
                }
                backend
            }
            Err(_) => {
                // If detection fails entirely, use CPU
                GpuBackend::Cpu
            }
        }
    }

    /// Check if this backend is available on the current system
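    ///
    /// # Example
    ///
    /// An illustrative check (assumes the crate is imported as `scirs2_core`);
    /// the CPU fallback reports itself as always available:
    ///
    /// ```no_run
    /// use scirs2_core::gpu::GpuBackend;
    ///
    /// assert!(GpuBackend::Cpu.is_available());
    /// // GPU backends additionally depend on feature flags and runtime detection.
    /// let _cuda_present = GpuBackend::Cuda.is_available();
    /// ```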
    pub fn is_available(&self) -> bool {
        match self {
            // Check runtime availability for GPU backends
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    CudaContext::is_available()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    false
                }
            }
            GpuBackend::Rocm => cfg!(feature = "rocm"), // Would use ROCm runtime check
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    WebGPUContext::is_available()
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    false
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    // Metal is always available on macOS if the feature is enabled
                    true
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    false
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    OpenCLContext::is_available()
                }
                #[cfg(not(feature = "opencl"))]
                {
                    false
                }
            }
            GpuBackend::Cpu => true,
        }
    }
}

impl fmt::Display for GpuBackend {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GpuBackend::Cuda => write!(f, "CUDA"),
            GpuBackend::Rocm => write!(f, "ROCm"),
            GpuBackend::Wgpu => write!(f, "WebGPU"),
            GpuBackend::Metal => write!(f, "Metal"),
            GpuBackend::OpenCL => write!(f, "OpenCL"),
            GpuBackend::Cpu => write!(f, "CPU"),
        }
    }
}

use crate::error::{CoreError, ErrorContext, ErrorLocation};

/// Error type for GPU operations
#[derive(Debug, thiserror::Error)]
pub enum GpuError {
    /// Backend is not available
    #[error("GPU backend {0} is not available")]
    BackendNotAvailable(String),

    /// Backend is not supported
    #[error("GPU backend {0} is not supported")]
    UnsupportedBackend(GpuBackend),

    /// Backend is not supported for a kernel
    #[error("GPU backend {0:?} is not supported for this kernel")]
    BackendNotSupported(GpuBackend),

    /// Backend is not implemented yet
    #[error("GPU backend {0} is not implemented yet")]
    BackendNotImplemented(GpuBackend),

    /// Out of memory
    #[error("GPU out of memory: {0}")]
    OutOfMemory(String),

    /// Kernel compilation error
    #[error("Kernel compilation error: {0}")]
    KernelCompilationError(String),

    /// Kernel execution error
    #[error("Kernel execution error: {0}")]
    KernelExecutionError(String),

    /// Invalid parameter
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// Kernel not found
    #[error("Kernel not found: {0}")]
    KernelNotFound(String),

    /// Specialization not supported
    #[error("Kernel specialization not supported")]
    SpecializationNotSupported,

    /// Unsupported data type
    #[error("Unsupported data type: {0:?}")]
    UnsupportedDataType(kernels::DataType),

    /// Other error
    #[error("{0}")]
    Other(String),
}

/// GPU device abstraction
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuDevice {
    backend: GpuBackend,
    device_id: usize,
}

impl GpuDevice {
    /// Create a new GPU device
    pub fn new(backend: GpuBackend, device_id: usize) -> Self {
        Self { backend, device_id }
    }

    /// Get the backend type
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Get the device ID
    pub fn device_id(&self) -> usize {
        self.device_id
    }

    /// Compile a kernel from source
    pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
        // Placeholder implementation
        Ok(GpuKernel {
            backend: self.backend,
            entry_point: entrypoint.to_string(),
        })
    }
}

/// GPU kernel abstraction
pub struct GpuKernel {
    backend: GpuBackend,
    entry_point: String,
}

impl GpuKernel {
    /// Get the backend type
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Get the entry point name
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }
}

/// Convert GPU errors to core errors with semantic preservation
impl From<GpuError> for CoreError {
    fn from(err: GpuError) -> Self {
        match err {
            GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
                ErrorContext::new(format!("GPU backend {backend} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
                ErrorContext::new(format!("GPU backend {backend} is not supported"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
                ErrorContext::new(format!(
                    "GPU backend {backend:?} is not supported for this kernel"
                ))
                .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
                ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::OutOfMemory(details) => CoreError::MemoryError(
                ErrorContext::new(details.to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
                ErrorContext::new(msg.to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
                ErrorContext::new(msg.to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
                ErrorContext::new(msg.to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::KernelNotFound(name) => CoreError::ComputationError(
                ErrorContext::new(name.to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
                ErrorContext::new("Kernel specialization not supported".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
                ErrorContext::new(format!("{dtype:?}"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ),
            GpuError::Other(msg) => CoreError::ComputationError(
                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
            ),
        }
    }
}

/// Trait for types that can be used with GPU operations
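///
/// This is a marker trait: any `Copy + Send + Sync + 'static` type may opt in.
/// A hypothetical illustration for a user-defined element type (`Complex32` is
/// not part of this module):
///
/// ```no_run
/// use scirs2_core::gpu::GpuDataType;
///
/// // Hypothetical two-component element type.
/// #[derive(Clone, Copy)]
/// #[repr(C)]
/// struct Complex32 {
///     re: f32,
///     im: f32,
/// }
///
/// impl GpuDataType for Complex32 {}
/// ```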
pub trait GpuDataType: Copy + Send + Sync + 'static {}

/// GPU memory pointer abstraction
#[derive(Debug)]
pub struct GpuPtr<T: GpuDataType> {
    ptr: u64,
    size: usize,
    phantom: PhantomData<T>,
}

impl<T: GpuDataType> GpuPtr<T> {
    /// Allocate GPU memory
    pub fn allocate(size: usize) -> Result<Self, GpuError> {
        Ok(GpuPtr {
            ptr: 0x1000_0000, // Placeholder address
            size,
            phantom: PhantomData,
        })
    }

    /// Get the raw pointer value
    pub fn as_ptr(&self) -> u64 {
        self.ptr
    }

    /// Get the size in elements
    pub fn len(&self) -> usize {
        self.size
    }

    /// Check if the pointer is empty (size is 0)
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}

/// Kernel argument types for GPU kernel execution
#[derive(Debug, Clone)]
pub enum KernelArg<'a, T: GpuDataType> {
    /// Buffer argument
    Buffer(&'a GpuPtr<T>),
    /// Scalar argument
    Scalar(T),
}

/// Non-generic kernel argument for mixed-type kernel calls
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    /// Buffer argument (type-erased)
    Buffer(u64), // Raw pointer
    /// f32 scalar
    F32(f32),
    /// f64 scalar
    F64(f64),
    /// i32 scalar
    I32(i32),
    /// u32 scalar
    U32(u32),
    /// usize scalar
    Usize(usize),
}

/// GPU communication channel for multi-GPU operations
pub struct GpuChannel {
    #[allow(dead_code)]
    source_device: usize,
    #[allow(dead_code)]
    target_device: usize,
    #[allow(dead_code)]
    bandwidth: f64, // GB/s
}

// Implement for common types
impl GpuDataType for f32 {}
impl GpuDataType for f64 {}
impl GpuDataType for i32 {}
impl GpuDataType for u32 {}
impl GpuDataType for u8 {}
impl GpuDataType for i8 {}
impl GpuDataType for u16 {}
impl GpuDataType for i16 {}
impl GpuDataType for u64 {}
impl GpuDataType for i64 {}
impl GpuDataType for usize {}
impl GpuDataType for isize {}

/// GPU buffer
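///
/// # Example
///
/// A minimal host/device round trip (a sketch; assumes the crate is imported as
/// `scirs2_core` and uses the CPU fallback backend):
///
/// ```no_run
/// use scirs2_core::gpu::{GpuBackend, GpuContext};
///
/// # fn main() -> Result<(), scirs2_core::gpu::GpuError> {
/// let ctx = GpuContext::new(GpuBackend::Cpu)?;
/// let buffer = ctx.create_buffer::<f32>(4);
/// buffer.copy_from_host(&[1.0, 2.0, 3.0, 4.0])?;
///
/// let mut out = vec![0.0f32; 4];
/// buffer.copy_to_host(&mut out)?;
/// assert_eq!(out, vec![1.0, 2.0, 3.0, 4.0]);
/// # Ok(())
/// # }
/// ```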
pub struct GpuBuffer<T: GpuDataType> {
    inner: Arc<dyn GpuBufferImpl>,
    size: usize,
    phantom: PhantomData<T>,
}

impl<T: GpuDataType> GpuBuffer<T> {
    /// Create a new buffer with the given size
    pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
        Self {
            inner,
            size,
            phantom: PhantomData,
        }
    }

    /// Get the size of the buffer in elements
    pub fn len(&self) -> usize {
        self.size
    }

    /// Check if the buffer is empty
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Copy data from the host to the device
    pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
        if data.len() > self.size {
            return Err(GpuError::InvalidParameter(
                "Data size exceeds buffer size".to_string(),
            ));
        }
        unsafe {
            self.inner
                .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
        }
        Ok(())
    }

    /// Copy data from the device to the host
    pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
        if data.len() > self.size {
            return Err(GpuError::InvalidParameter(
                "Data size exceeds buffer size".to_string(),
            ));
        }
        unsafe {
            self.inner
                .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
        }
        Ok(())
    }

    /// Convert the buffer contents to a vector
    pub fn to_vec(&self) -> Vec<T> {
        let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
        let _ = self.copy_to_host(&mut result);
        result
    }
}

impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GpuBuffer")
            .field("size", &self.size)
            .finish()
    }
}

impl<T: GpuDataType> Clone for GpuBuffer<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            size: self.size,
            phantom: PhantomData,
        }
    }
}

/// GPU kernel handle
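///
/// # Example
///
/// A sketch of setting parameters and dispatching a compiled kernel (assumes the
/// crate is imported as `scirs2_core`; the kernel source here is hypothetical):
///
/// ```no_run
/// use scirs2_core::gpu::{GpuBackend, GpuContext};
///
/// # fn main() -> Result<(), scirs2_core::gpu::GpuError> {
/// let ctx = GpuContext::new(GpuBackend::Cpu)?;
/// let handle = ctx.execute(|compiler| compiler.compile("/* kernel source */"))?;
///
/// let input = ctx.create_buffer_from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
/// handle.set_buffer("input", &input);
/// handle.set_u32("size", 4);
/// handle.dispatch([1, 1, 1]);
/// # Ok(())
/// # }
/// ```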
#[derive(Clone)]
pub struct GpuKernelHandle {
    inner: Arc<dyn GpuKernelImpl>,
}

impl GpuKernelHandle {
    /// Create a new kernel handle
    pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
        Self { inner }
    }

    /// Set a buffer parameter
    pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
        self.inner.set_buffer(name, &buffer.inner);
    }

    /// Set a u32 parameter
    pub fn set_u32(&self, name: &str, value: u32) {
        self.inner.set_u32(name, value);
    }

    /// Set an i32 parameter
    pub fn set_i32(&self, name: &str, value: i32) {
        self.inner.set_i32(name, value);
    }

    /// Set an f32 parameter
    pub fn set_f32(&self, name: &str, value: f32) {
        self.inner.set_f32(name, value);
    }

    /// Set an f64 parameter
    pub fn set_f64(&self, name: &str, value: f64) {
        self.inner.set_f64(name, value);
    }

    /// Dispatch the kernel with the given work group counts
    pub fn dispatch(&self, workgroups: [u32; 3]) {
        self.inner.dispatch(workgroups);
    }
}

/// GPU compiler for dynamic kernels
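///
/// # Example
///
/// A compiler is obtained through [`GpuContext::execute`]; this sketch assumes
/// the crate is imported as `scirs2_core` and uses a hypothetical kernel name:
///
/// ```no_run
/// use scirs2_core::gpu::{GpuBackend, GpuContext};
///
/// # fn main() -> Result<(), scirs2_core::gpu::GpuError> {
/// let ctx = GpuContext::new(GpuBackend::Cpu)?;
/// let kernel = ctx.execute(|compiler| {
///     // Typed compilation keyed by a (hypothetical) kernel name.
///     compiler.compile_kernel::<f32, f32>("vector_add")
/// });
/// kernel.dispatch([32, 1, 1]);
/// # Ok(())
/// # }
/// ```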
pub struct GpuCompiler {
    inner: Arc<dyn GpuCompilerImpl>,
}

impl GpuCompiler {
    /// Create a new compiler
    pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
        Self { inner }
    }

    /// Compile a kernel from source
    pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
        let kernel = self.inner.compile(source)?;
        Ok(GpuKernelHandle::new(kernel))
    }

    /// Compile a kernel for the specified input and output types
    pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
        let kernel = self.inner.compile_typed(
            name,
            std::any::TypeId::of::<I>(),
            std::any::TypeId::of::<O>(),
        );
        GpuKernelHandle::new(kernel)
    }
}

/// GPU context for managing GPU resources and operations
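///
/// # Example
///
/// Creating a context on the always-available CPU fallback and querying it
/// (a sketch; assumes the crate is imported as `scirs2_core`):
///
/// ```no_run
/// use scirs2_core::gpu::{GpuBackend, GpuContext};
///
/// # fn main() -> Result<(), scirs2_core::gpu::GpuError> {
/// let ctx = GpuContext::new(GpuBackend::Cpu)?;
/// assert_eq!(ctx.backend(), GpuBackend::Cpu);
/// assert_eq!(ctx.backend_name(), "CPU");
/// // Memory queries currently return placeholder values.
/// let _total = ctx.get_total_memory();
/// # Ok(())
/// # }
/// ```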
pub struct GpuContext {
    inner: Arc<dyn GpuContextImpl>,
    backend: GpuBackend,
    kernel_registry: kernels::KernelRegistry,
}

impl GpuContext {
    /// Create a new GPU context with the specified backend
    pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
        // First check whether the backend is compiled in and nominally available
        if !backend.is_available() {
            return Err(GpuError::BackendNotAvailable(backend.to_string()));
        }

        // For non-CPU backends, also check runtime availability
        if backend != GpuBackend::Cpu {
            let detection_result = backends::detect_gpu_backends();
            let backend_available = detection_result
                .devices
                .iter()
                .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);

            if !backend_available {
                return Err(GpuError::BackendNotAvailable(format!(
                    "{backend} (no devices detected at runtime)"
                )));
            }
        }

        let inner = match backend {
            GpuBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    use crate::gpu::backends::cuda::CudaContext;
                    match CudaContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    // This is just a stub - in a real implementation, we would use the hip-sys crate
                    // to create a ROCm context and return it
                    #[cfg(test)]
                    {
                        // For testing, we can use a mock implementation
                        Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
                    }
                    #[cfg(not(test))]
                    {
                        return Err(GpuError::BackendNotImplemented(backend));
                    }
                }
                #[cfg(not(feature = "rocm"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Wgpu => {
                #[cfg(feature = "wgpu_backend")]
                {
                    use crate::gpu::backends::wgpu::WebGPUContext;
                    match WebGPUContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Metal => {
                #[cfg(all(feature = "metal", target_os = "macos"))]
                {
                    use crate::gpu::backends::metal::MetalContext;
                    match MetalContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(all(feature = "metal", target_os = "macos")))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::OpenCL => {
                #[cfg(feature = "opencl")]
                {
                    use crate::gpu::backends::opencl::OpenCLContext;
                    match OpenCLContext::new() {
                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
                        Err(e) => return Err(e),
                    }
                }
                #[cfg(not(feature = "opencl"))]
                {
                    return Err(GpuError::UnsupportedBackend(backend));
                }
            }
            GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
        };

        Ok(Self {
            inner,
            backend,
            kernel_registry: kernels::KernelRegistry::with_default_kernels(),
        })
    }

    /// Get the backend type
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Get the backend name
    pub fn backend_name(&self) -> &str {
        match self.backend {
            GpuBackend::Cuda => "CUDA",
            GpuBackend::Rocm => "ROCm",
            GpuBackend::Wgpu => "WebGPU",
            GpuBackend::Metal => "Metal",
            GpuBackend::OpenCL => "OpenCL",
            GpuBackend::Cpu => "CPU",
        }
    }

    /// Create a buffer with the given size
    pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
        let byte_size = size.saturating_mul(std::mem::size_of::<T>());
        let inner = self.inner.create_buffer(byte_size);
        GpuBuffer::new(inner, size)
    }

    /// Create a buffer from a slice
    pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
        let buffer = self.create_buffer::<T>(data.len());
        let _ = buffer.copy_from_host(data);
        buffer
    }

    /// Execute a function with a compiler
    pub fn execute<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&GpuCompiler) -> R,
    {
        let compiler = GpuCompiler::new(self.inner.create_compiler());
        f(&compiler)
    }

    /// Get a kernel from the registry
    pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
        let kernel = self
            .kernel_registry
            .get(name)
            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;

        let kernel_source = kernel.source_for_backend(self.backend)?;
        let metadata = kernel.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Get a specialized kernel from the registry
    pub fn get_specialized_kernel(
        &self,
        name: &str,
        params: &kernels::KernelParams,
    ) -> Result<GpuKernelHandle, GpuError> {
        let specialized = self.kernel_registry.get_specialized(name, params)?;
        let kernel_source = specialized.source_for_backend(self.backend)?;
        let metadata = specialized.metadata();

        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
        Ok(handle)
    }

    /// Compile a kernel with metadata
    fn compile_kernel_with_metadata(
        &self,
        source: &str,
        _metadata: &kernels::KernelMetadata,
    ) -> Result<GpuKernelHandle, GpuError> {
        self.execute(|compiler| compiler.compile(source))
    }

    /// Get available memory on the device
    pub fn get_available_memory(&self) -> Option<usize> {
        // In a real implementation, this would query the device
        // For now, return a placeholder value
        Some(1024 * 1024 * 1024) // 1GB
    }

    /// Get total memory on the device
    pub fn get_total_memory(&self) -> Option<usize> {
        // In a real implementation, this would query the device
        // For now, return a placeholder value
        #[cfg(target_arch = "wasm32")]
        return Some(512 * 1024 * 1024); // 512MB for WASM32

        #[cfg(not(target_arch = "wasm32"))]
        Some((4u64 * 1024 * 1024 * 1024) as usize) // 4GB for native
    }

    /// Launch a kernel with the given parameters
    pub fn launch_kernel(
        &self,
        kernel_name: &str,
        grid_size: (usize, usize, usize),
        block_size: (usize, usize, usize),
        args: &[DynamicKernelArg],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (kernel_name, grid_size, block_size, args);
        Ok(())
    }

    /// Transfer data from host to device asynchronously
    pub fn transfer_async_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from host to device synchronously
    pub fn transfer_host_to_device<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &[T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from device to host asynchronously
    pub fn transfer_async_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Transfer data from device to host synchronously
    pub fn transfer_device_to_host<T: GpuDataType>(
        &self,
        ptr: &GpuPtr<T>,
        data: &mut [T],
    ) -> Result<(), GpuError> {
        // Placeholder implementation
        let _ = (ptr, data);
        Ok(())
    }

    /// Execute a kernel with dynamic compilation and parameter passing
    /// This method is expected by scirs2-vision for GPU operations
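    ///
    /// # Example
    ///
    /// Call-shape sketch (assumes the crate is imported as `scirs2_core`; the
    /// kernel source is hypothetical and the current implementation only logs
    /// the call):
    ///
    /// ```no_run
    /// use scirs2_core::gpu::{GpuBackend, GpuContext};
    ///
    /// # fn main() -> Result<(), scirs2_core::gpu::GpuError> {
    /// let ctx = GpuContext::new(GpuBackend::Cpu)?;
    /// let input = ctx.create_buffer_from_slice(&[1.0f32; 16]);
    /// let output = ctx.create_buffer::<f32>(16);
    ///
    /// ctx.execute_kernel(
    ///     "/* hypothetical kernel source */",
    ///     &[input, output],
    ///     (4, 4, 1),
    ///     &[16],  // int params, e.g. the element count
    ///     &[0.5], // float params, e.g. a scale factor
    /// )?;
    /// # Ok(())
    /// # }
    /// ```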
    pub fn execute_kernel(
        &self,
        source: &str,
        buffers: &[GpuBuffer<f32>],
        work_groups: (u32, u32, u32),
        int_params: &[u32],
        float_params: &[f32],
    ) -> Result<(), GpuError> {
        // For now, provide a basic implementation that logs the execution
        // In a real implementation, this would compile and execute the kernel
        eprintln!(
            "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
            source.len(),
            buffers.len(),
            work_groups
        );
        eprintln!("Int params: {int_params:?}");
        eprintln!("Float params: {float_params:?}");
        Ok(())
    }

    /// Read data from a GPU buffer
    /// This method is expected by scirs2-vision for reading GPU results
    pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
        Ok(buffer.to_vec())
    }

    /// Global sum reduction
    pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: returns a single zero-initialized element and does
        // not yet compute the actual sum
        let _ = buffer;
        let sum: T = unsafe { std::mem::zeroed() }; // Placeholder value
        let result = self.create_buffer::<T>(1);
        let _ = result.copy_from_host(&[sum]);
        Ok(result)
    }

    /// Global mean reduction
    pub fn mean_all<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: delegates to the (stub) sum reduction
        self.sum_all(buffer)
    }

    /// Global max reduction
    pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: returns a zero-initialized single-element buffer
        let _ = buffer;
        let result = self.create_buffer::<T>(1);
        Ok(result)
    }

    /// Global min reduction
    pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: returns a zero-initialized single-element buffer
        let _ = buffer;
        let result = self.create_buffer::<T>(1);
        Ok(result)
    }

    /// Sum reduction along an axis
    pub fn sum_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: validates the axis and allocates the reduced-shape
        // output without performing the reduction
        let _ = buffer;
        let mut output_shape = shape.to_vec();
        if axis >= output_shape.len() {
            return Err(GpuError::InvalidParameter(format!(
                "Axis {} out of bounds for shape {:?}",
                axis, shape
            )));
        }
        output_shape[axis] = 1;
        let output_size: usize = output_shape.iter().product();
        let result = self.create_buffer::<T>(output_size);
        Ok(result)
    }

    /// Mean reduction along an axis
    pub fn mean_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation
        self.sum_axis(buffer, shape, axis)
    }

    /// Max reduction along an axis
    pub fn max_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: validates the axis and allocates the reduced-shape
        // output without performing the reduction
        let _ = buffer;
        let mut output_shape = shape.to_vec();
        if axis >= output_shape.len() {
            return Err(GpuError::InvalidParameter(format!(
                "Axis {} out of bounds for shape {:?}",
                axis, shape
            )));
        }
        output_shape[axis] = 1;
        let output_size: usize = output_shape.iter().product();
        let result = self.create_buffer::<T>(output_size);
        Ok(result)
    }

    /// Min reduction along an axis
    pub fn min_axis<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        shape: &[usize],
        axis: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: validates the axis and allocates the reduced-shape
        // output without performing the reduction
        let _ = buffer;
        let mut output_shape = shape.to_vec();
        if axis >= output_shape.len() {
            return Err(GpuError::InvalidParameter(format!(
                "Axis {} out of bounds for shape {:?}",
                axis, shape
            )));
        }
        output_shape[axis] = 1;
        let output_size: usize = output_shape.iter().product();
        let result = self.create_buffer::<T>(output_size);
        Ok(result)
    }

    /// Broadcast a buffer to a different shape
    pub fn broadcast<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        from_shape: &[usize],
        to_shape: &[usize],
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: the source shape is not yet used
        let _ = from_shape;
        let output_size: usize = to_shape.iter().product();
        let result = self.create_buffer::<T>(output_size);

        // For now, just copy the data if the sizes match; otherwise return the
        // zero-initialized buffer
        if buffer.len() == output_size {
            let data = buffer.to_vec();
            let _ = result.copy_from_host(&data);
        }

        Ok(result)
    }

    /// Scale a buffer by a scalar value
    pub fn scale<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
        scalar: T,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: copies the input unchanged; the scalar is not yet applied
        let _ = scalar;
        let result = self.create_buffer::<T>(buffer.len());
        let data = buffer.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// General matrix multiplication: C = A @ B
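    ///
    /// Presumably `a` holds an `m x k` matrix and `b` a `k x n` matrix; the stub
    /// only uses `m` and `n` to size the `m * n` output. A call-shape sketch
    /// (assumes the crate is imported as `scirs2_core`; the current
    /// implementation returns a zero-initialized result):
    ///
    /// ```no_run
    /// use scirs2_core::gpu::{GpuBackend, GpuContext};
    ///
    /// # fn main() -> Result<(), scirs2_core::gpu::GpuError> {
    /// let ctx = GpuContext::new(GpuBackend::Cpu)?;
    /// let a = ctx.create_buffer_from_slice(&[1.0f32; 2 * 3]); // 2 x 3
    /// let b = ctx.create_buffer_from_slice(&[1.0f32; 3 * 4]); // 3 x 4
    /// let c = ctx.gemm(&a, &b, 2, 3, 4)?;
    /// assert_eq!(c.len(), 2 * 4);
    /// # Ok(())
    /// # }
    /// ```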
    pub fn gemm<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: allocates the m x n output without computing the product
        let _ = (a, b, k);
        let result = self.create_buffer::<T>(m * n);
        Ok(result)
    }

    /// GEMM with transposed B: C = A @ B^T
    pub fn gemm_transpose_b<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: allocates the m x n output without computing the product
        let _ = (a, b, k);
        let result = self.create_buffer::<T>(m * n);
        Ok(result)
    }

    /// GEMM with transposed A: C = A^T @ B
    pub fn gemm_transpose_a<T: GpuDataType>(
        &self,
        a: &GpuBuffer<T>,
        b: &GpuBuffer<T>,
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: allocates the m x n output without computing the product
        let _ = (a, b, k);
        let result = self.create_buffer::<T>(m * n);
        Ok(result)
    }

    /// ReLU activation forward pass
    pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: passes the input through unchanged
        let result = self.create_buffer::<T>(input.len());
        let data = input.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// ReLU backward pass
    pub fn relu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: passes the upstream gradient through unchanged
        let _ = input;
        let result = self.create_buffer::<T>(grad_output.len());
        let data = grad_output.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// Sigmoid activation forward pass
    pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: passes the input through unchanged
        let result = self.create_buffer::<T>(input.len());
        let data = input.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// Sigmoid backward pass
    pub fn sigmoid_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: passes the upstream gradient through unchanged
        let _ = input;
        let result = self.create_buffer::<T>(grad_output.len());
        let data = grad_output.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// Tanh activation forward pass
    pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: passes the input through unchanged
        let result = self.create_buffer::<T>(input.len());
        let data = input.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// Tanh backward pass
    pub fn tanh_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: passes the upstream gradient through unchanged
        let _ = input;
        let result = self.create_buffer::<T>(grad_output.len());
        let data = grad_output.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// GELU activation forward pass
    pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: passes the input through unchanged
        let result = self.create_buffer::<T>(input.len());
        let data = input.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }

    /// GELU backward pass
    pub fn gelu_backward<T: GpuDataType>(
        &self,
        grad_output: &GpuBuffer<T>,
        input: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        // Stub implementation: passes the upstream gradient through unchanged
        let _ = input;
        let result = self.create_buffer::<T>(grad_output.len());
        let data = grad_output.to_vec();
        let _ = result.copy_from_host(&data);
        Ok(result)
    }
}

impl fmt::Debug for GpuContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GpuContext")
            .field("backend", &self.backend)
            .finish()
    }
}

// The following trait definitions would be implemented by backend-specific
// code in a real implementation

/// GPU buffer implementation trait
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copy data from host to device
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copy data from device to host
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Get a reference to self as Any for downcasting
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Get the size of the buffer in bytes
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0 // Default implementation for backward compatibility
    }

    /// Get the device pointer (for backends that use device pointers)
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0 // Default implementation for backward compatibility
    }
}

/// GPU kernel implementation trait
pub(crate) trait GpuKernelImpl: Send + Sync {
    /// Set a buffer parameter
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);

    /// Set a u32 parameter
    fn set_u32(&self, name: &str, value: u32);

    /// Set an i32 parameter
    fn set_i32(&self, name: &str, value: i32);

    /// Set an f32 parameter
    fn set_f32(&self, name: &str, value: f32);

    /// Set an f64 parameter
    fn set_f64(&self, name: &str, value: f64);

    /// Dispatch the kernel
    fn dispatch(&self, workgroups: [u32; 3]);
}

/// GPU compiler implementation trait
pub(crate) trait GpuCompilerImpl: Send + Sync {
    /// Compile a kernel from source
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;

    /// Compile a typed kernel
    fn compile_typed(
        &self,
        name: &str,
        input_type: std::any::TypeId,
        output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl>;
}

/// GPU context implementation trait
pub(crate) trait GpuContextImpl: Send + Sync {
    /// Create a buffer
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;

    /// Create a compiler
    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;

    /// Support dynamic downcasting of concrete context implementations
    fn as_any(&self) -> &dyn std::any::Any
    where
        Self: 'static + Sized,
    {
        self
    }
}

// CPU fallback implementation

/// CPU context implementation
struct CpuContext;

impl CpuContext {
    /// Create a new CPU context
    fn new() -> Self {
        Self
    }
}

impl GpuContextImpl for CpuContext {
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
        Arc::new(CpuBuffer::new(size))
    }

    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
        Arc::new(CpuCompiler)
    }
}

/// CPU buffer implementation
struct CpuBuffer {
    // A mutex provides sound interior mutability: the buffer trait only hands
    // out shared references, but host copies need to mutate the storage.
    data: std::sync::Mutex<Vec<u8>>,
}

impl CpuBuffer {
    /// Create a new CPU buffer
    fn new(size: usize) -> Self {
        Self {
            data: std::sync::Mutex::new(vec![0; size]),
        }
    }
}

impl GpuBufferImpl for CpuBuffer {
    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        let mut dst = self.data.lock().expect("CPU buffer mutex poisoned");
        // Clamp to the buffer length to avoid writing out of bounds
        let len = size.min(dst.len());
        std::ptr::copy_nonoverlapping(data, dst.as_mut_ptr(), len);
    }

    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        let src = self.data.lock().expect("CPU buffer mutex poisoned");
        // Clamp to the buffer length to avoid reading out of bounds
        let len = size.min(src.len());
        std::ptr::copy_nonoverlapping(src.as_ptr(), data, len);
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn size(&self) -> usize {
        self.data.lock().map(|data| data.len()).unwrap_or(0)
    }

    fn device_ptr(&self) -> u64 {
        self.data.lock().map(|data| data.as_ptr() as u64).unwrap_or(0)
    }
}

/// CPU compiler implementation
struct CpuCompiler;

impl GpuCompilerImpl for CpuCompiler {
    fn compile(&self, _source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
        // In a real implementation, we would parse and execute the kernel
        // For now, just return a dummy implementation
        Ok(Arc::new(CpuKernel))
    }

    fn compile_typed(
        &self,
        _name: &str,
        _input_type: std::any::TypeId,
        _output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl> {
        // In a real implementation, we would select an appropriate implementation
        // For now, just return a dummy implementation
        Arc::new(CpuKernel)
    }
}

/// CPU kernel implementation
struct CpuKernel;

impl GpuKernelImpl for CpuKernel {
    fn set_buffer(&self, _name: &str, _buffer: &Arc<dyn GpuBufferImpl>) {
        // In a real implementation, we would store the buffer
    }

    fn set_u32(&self, _name: &str, _value: u32) {
        // In a real implementation, we would store the value
    }

    fn set_i32(&self, _name: &str, _value: i32) {
        // In a real implementation, we would store the value
    }

    fn set_f32(&self, _name: &str, _value: f32) {
        // In a real implementation, we would store the value
    }

    fn set_f64(&self, _name: &str, _value: f64) {
        // In a real implementation, we would store the value
    }

    fn dispatch(&self, _workgroups: [u32; 3]) {
        // In a real implementation, we would execute the kernel
    }
}

// In a real implementation, we would have implementations for other backends
// such as CUDA, WebGPU, Metal, and OpenCL.

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_backend_preferred() {
        let backend = GpuBackend::preferred();
        // Should return a valid backend
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_gpu_backend_default() {
        let backend = GpuBackend::default();
        assert_eq!(backend, GpuBackend::preferred());
    }

    #[test]
    fn test_gpu_backend_is_available() {
        let backend = GpuBackend::Cpu;
        assert!(backend.is_available());

        // Test other backends - availability depends on runtime, not just feature flags
        #[cfg(feature = "cuda")]
        {
            // CUDA feature enabled doesn't guarantee runtime availability
            let _ = GpuBackend::Cuda.is_available(); // Just check without asserting
        }
        #[cfg(not(feature = "cuda"))]
        assert!(!GpuBackend::Cuda.is_available());

        #[cfg(feature = "rocm")]
        {
            // ROCm feature enabled doesn't guarantee runtime availability
            let _ = GpuBackend::Rocm.is_available(); // Just check without asserting
        }
        #[cfg(not(feature = "rocm"))]
        assert!(!GpuBackend::Rocm.is_available());

        #[cfg(all(feature = "metal", target_os = "macos"))]
        assert!(GpuBackend::Metal.is_available());
        #[cfg(not(all(feature = "metal", target_os = "macos")))]
        assert!(!GpuBackend::Metal.is_available());
    }

    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
        assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
        assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
        assert_eq!(GpuBackend::Metal.to_string(), "Metal");
        assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
        assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
    }

    #[test]
    fn test_gpuerror_from_conversion() {
        let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::ComputationError(_) => {}
            _ => panic!("Expected ComputationError"),
        }

        let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::MemoryError(_) => {}
            _ => panic!("Expected MemoryError"),
        }

        let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::InvalidArgument(_) => {}
            _ => panic!("Expected InvalidArgument"),
        }

        let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::TypeError(_) => {}
            _ => panic!("Expected TypeError"),
        }
    }

    #[test]
    fn test_gpu_datatype_trait() {
        // Test that various types implement GpuDataType
        fn assert_gpu_datatype<T: GpuDataType>() {}

        assert_gpu_datatype::<f32>();
        assert_gpu_datatype::<f64>();
        assert_gpu_datatype::<i32>();
        assert_gpu_datatype::<u32>();
        assert_gpu_datatype::<u8>();
        assert_gpu_datatype::<i8>();
        assert_gpu_datatype::<u16>();
        assert_gpu_datatype::<i16>();
        assert_gpu_datatype::<u64>();
        assert_gpu_datatype::<i64>();
    }

    #[test]
    fn test_gpu_buffer_creation() {
        let inner = Arc::new(CpuBuffer::new(100));
        let buffer = GpuBuffer::<f32>::new(inner, 25);

        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_empty() {
        let inner = Arc::new(CpuBuffer::new(0));
        let buffer = GpuBuffer::<f32>::new(inner, 0);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_gpu_buffer_copy_operations() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = buffer.copy_to_host(&mut result);

        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_buffer_to_vec() {
        let inner = Arc::new(CpuBuffer::new(12));
        let buffer = GpuBuffer::<f32>::new(inner, 3);

        let data = vec![5.0f32, 6.0, 7.0];
        let _ = buffer.copy_from_host(&data);

        let result = buffer.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_from_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let data = vec![1.0f32, 2.0, 3.0]; // 3 elements > 2 buffer size
        buffer.copy_from_host(&data).expect("Operation failed");
    }

    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_to_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let mut data = vec![0.0f32; 3]; // 3 elements > 2 buffer size
        buffer.copy_to_host(&mut data).expect("Operation failed");
    }

    #[test]
    fn test_gpu_kernel_handle() {
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        // Test setting various parameter types
        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 100);
        handle.set_i32("offset", -5);
        handle.set_f32("scale", 2.5);
        handle.set_f64("precision", 0.0001);

        // Test dispatch
        handle.dispatch([16, 8, 1]);
    }

    #[test]
    fn test_gpu_context_cpu_backend() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
        assert_eq!(context.backend(), GpuBackend::Cpu);
        assert_eq!(context.backend_name(), "CPU");

        // Test memory query methods
        assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
        assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
    }

    #[test]
    fn test_gpu_context_buffer_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let buffer = context.create_buffer::<f32>(100);
        assert_eq!(buffer.len(), 100);

        let data = vec![1.0f32; 50];
        let buffer_from_slice = context.create_buffer_from_slice(&data);
        assert_eq!(buffer_from_slice.len(), 50);

        let result = buffer_from_slice.to_vec();
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_unsupported_backend() {
        // Test a backend that's not available
        #[cfg(not(feature = "cuda"))]
        {
            let result = GpuContext::new(GpuBackend::Cuda);
            assert!(result.is_err());
            match result {
                Err(GpuError::UnsupportedBackend(_)) => {}
                Err(GpuError::BackendNotAvailable(_)) => {} // Also accept this error
                Err(e) => panic!(
                    "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
                    e
                ),
                Ok(_) => panic!("Expected error, got Ok"),
            }
        }
    }

    #[test]
    fn test_gpu_compiler() {
        let compiler_impl = Arc::new(CpuCompiler);
        let compiler = GpuCompiler::new(compiler_impl);

        // Test compiling from source
        let kernel = compiler
            .compile("dummy kernel source")
            .expect("Operation failed");
        kernel.dispatch([1, 1, 1]);

        // Test typed compilation
        let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
        typed_kernel.dispatch([32, 1, 1]);
    }

    #[test]
    fn test_gpu_context_execute() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());

        assert!(result);
    }

    #[test]
    fn test_gpu_context_kernel_registry() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        // Test getting a non-existent kernel
        let result = context.get_kernel("non_existent_kernel");
        assert!(result.is_err());
        match result {
            Err(GpuError::KernelNotFound(_)) => {}
            _ => panic!("Expected KernelNotFound error"),
        }
    }

    #[test]
    fn test_cpu_buffer_implementation() {
        let buffer = CpuBuffer::new(256);
        let data = buffer.data.lock().expect("CPU buffer mutex poisoned");
        assert_eq!(data.len(), 256);

        // Test that initial data is zeroed
        assert!(data.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_gpuerror_display() {
        let error = GpuError::BackendNotAvailable("CUDA".to_string());
        assert_eq!(error.to_string(), "GPU backend CUDA is not available");

        let error = GpuError::OutOfMemory("allocation failed".to_string());
        assert_eq!(error.to_string(), "GPU out of memory: allocation failed");

        let error = GpuError::KernelCompilationError("syntax error".to_string());
        assert_eq!(error.to_string(), "Kernel compilation error: syntax error");

        let error = GpuError::KernelNotFound("gemm".to_string());
        assert_eq!(error.to_string(), "Kernel not found: gemm");
    }

    #[test]
    fn test_backend_equality() {
        assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
        assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);

        // Test Clone and Copy
        let backend = GpuBackend::Metal;
        let cloned = backend;
        let copied = backend;
        assert_eq!(backend, cloned);
        assert_eq!(backend, copied);
    }

    #[test]
    fn test_backend_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(GpuBackend::Cuda);
        set.insert(GpuBackend::Rocm);
        set.insert(GpuBackend::Cuda); // Duplicate

        assert_eq!(set.len(), 2); // Should only have 2 unique entries
        assert!(set.contains(&GpuBackend::Cuda));
        assert!(set.contains(&GpuBackend::Rocm));
    }

    #[test]
    fn test_gpu_buffer_debug_clone() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        // Test Debug implementation
        let debug_str = format!("{:?}", buffer);
        assert!(debug_str.contains("GpuBuffer"));
        assert!(debug_str.contains("size"));

        // Test Clone implementation
        let cloned = buffer.clone();
        assert_eq!(cloned.len(), buffer.len());
        assert_eq!(cloned.len(), 4);

        // Verify the clone shares the same underlying storage (same Arc)
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let _ = buffer.copy_from_host(&data);

        let mut result = vec![0.0f32; 4];
        let _ = cloned.copy_to_host(&mut result);
        assert_eq!(result, data);
    }

    #[test]
    fn test_gpu_context_debug() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");

        // Test Debug implementation
        let debug_str = format!("{:?}", context);
        assert!(debug_str.contains("GpuContext"));
        assert!(debug_str.contains("backend"));
        assert!(debug_str.contains("Cpu"));
    }
}