Skip to main content

scirs2_core/gpu/
mod.rs

//! GPU acceleration module for scirs2-core
//!
//! This module provides hardware acceleration support across different GPU backends.
4
5use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod async_transfer;
11pub mod auto_tuning;
12pub mod backends;
13pub mod benchmarks;
14mod cpu_ops;
15pub mod device_info;
16pub mod heterogeneous;
17pub mod kernels;
18pub mod memory_management;
19pub mod stream_allocator;
20pub mod tensor_cores;
21
22pub use async_transfer::{
23    AsyncTransferError, AsyncTransferPipeline, TransferDirection, TransferHandle,
24};
25pub use device_info::GpuDeviceInfo;
26pub use memory_management::unified_memory::{SyncState, UnifiedAllocator, UnifiedBuffer};
27pub use stream_allocator::{StreamAllocator, StreamId};
28
/// GPU backend type
///
/// Identifies which compute backend a device, kernel or context targets.
/// The enum is a payload-free `Copy` tag, so it can be compared, hashed and
/// passed by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA backend
    Cuda,
    /// AMD ROCm backend
    Rocm,
    /// WebGPU backend
    Wgpu,
    /// Apple Metal backend
    Metal,
    /// OpenCL backend
    OpenCL,
    /// CPU fallback
    Cpu,
}
45
46impl Default for GpuBackend {
47    fn default() -> Self {
48        Self::preferred()
49    }
50}
51
52impl GpuBackend {
53    /// Get the preferred GPU backend for the current system
54    pub fn preferred() -> Self {
55        // Use the backend detection system to find the optimal backend
56        // This will properly detect available GPUs and fall back to CPU if needed
57        match backends::initialize_optimal_backend() {
58            Ok(backend) => {
59                // If we get a non-CPU backend, verify it's actually usable
60                if backend != GpuBackend::Cpu {
61                    // Check if we can actually create a context with this backend
62                    // For now, since implementations are stubs, fall back to CPU
63                    #[cfg(not(test))]
64                    {
65                        // In non-test environments, we don't have real GPU implementations yet
66                        return GpuBackend::Cpu;
67                    }
68                    #[cfg(test)]
69                    {
70                        // In tests, we can pretend the backend works
71                        return backend;
72                    }
73                }
74                backend
75            }
76            Err(_) => {
77                // If detection fails entirely, use CPU
78                GpuBackend::Cpu
79            }
80        }
81    }
82
83    /// Check if this backend is available on the current system
84    pub fn is_available(&self) -> bool {
85        match self {
86            // Check runtime availability for GPU backends
87            GpuBackend::Cuda => {
88                #[cfg(feature = "cuda")]
89                {
90                    use crate::gpu::backends::cuda::CudaContext;
91                    CudaContext::is_available()
92                }
93                #[cfg(not(feature = "cuda"))]
94                {
95                    false
96                }
97            }
98            GpuBackend::Rocm => cfg!(feature = "rocm"), // Would use ROCm runtime check
99            GpuBackend::Wgpu => {
100                #[cfg(feature = "wgpu_backend")]
101                {
102                    use crate::gpu::backends::wgpu::WebGPUContext;
103                    WebGPUContext::is_available()
104                }
105                #[cfg(not(feature = "wgpu_backend"))]
106                {
107                    false
108                }
109            }
110            GpuBackend::Metal => {
111                #[cfg(all(feature = "metal", target_os = "macos"))]
112                {
113                    // Metal is always available on macOS if the feature is enabled
114                    true
115                }
116                #[cfg(not(all(feature = "metal", target_os = "macos")))]
117                {
118                    false
119                }
120            }
121            GpuBackend::OpenCL => {
122                #[cfg(feature = "opencl")]
123                {
124                    use crate::gpu::backends::opencl::OpenCLContext;
125                    OpenCLContext::is_available()
126                }
127                #[cfg(not(feature = "opencl"))]
128                {
129                    false
130                }
131            }
132            GpuBackend::Cpu => true,
133        }
134    }
135}
136
137impl fmt::Display for GpuBackend {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        match self {
140            GpuBackend::Cuda => write!(f, "CUDA"),
141            GpuBackend::Rocm => write!(f, "ROCm"),
142            GpuBackend::Wgpu => write!(f, "WebGPU"),
143            GpuBackend::Metal => write!(f, "Metal"),
144            GpuBackend::OpenCL => write!(f, "OpenCL"),
145            GpuBackend::Cpu => write!(f, "CPU"),
146        }
147    }
148}
149
150use crate::error::{CoreError, ErrorContext, ErrorLocation};
151
/// Error type for GPU operations
///
/// `Display` and `std::error::Error` are derived through `thiserror`; the
/// `#[error(...)]` attribute on each variant is its display message.
#[derive(Debug, thiserror::Error)]
pub enum GpuError {
    /// Backend is not available
    ///
    /// Carries a free-form description of the backend (and optionally why it
    /// is unavailable) rather than a [`GpuBackend`] value.
    #[error("GPU backend {0} is not available")]
    BackendNotAvailable(String),

    /// Backend is not supported
    #[error("GPU backend {0} is not supported")]
    UnsupportedBackend(GpuBackend),

    /// Backend is not supported for a kernel
    ///
    /// Unlike [`GpuError::UnsupportedBackend`], the message formats the
    /// backend with `Debug` and mentions the kernel context.
    #[error("GPU backend {0:?} is not supported for this kernel")]
    BackendNotSupported(GpuBackend),

    /// Backend is not implemented yet
    #[error("GPU backend {0} is not implemented yet")]
    BackendNotImplemented(GpuBackend),

    /// Out of memory
    #[error("GPU out of memory: {0}")]
    OutOfMemory(String),

    /// Kernel compilation error
    #[error("Kernel compilation error: {0}")]
    KernelCompilationError(String),

    /// Kernel execution error
    #[error("Kernel execution error: {0}")]
    KernelExecutionError(String),

    /// Invalid parameter
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// Kernel not found
    ///
    /// The payload is the name that was looked up.
    #[error("Kernel not found: {0}")]
    KernelNotFound(String),

    /// Specialization not supported
    #[error("Kernel specialization not supported")]
    SpecializationNotSupported,

    /// Unsupported data type
    #[error("Unsupported data type: {0:?}")]
    UnsupportedDataType(kernels::DataType),

    /// Other error
    ///
    /// The payload is used verbatim as the display message.
    #[error("{0}")]
    Other(String),
}
203
204/// GPU device abstraction
205#[derive(Debug, Clone, Copy, PartialEq, Eq)]
206pub struct GpuDevice {
207    backend: GpuBackend,
208    device_id: usize,
209}
210
211impl GpuDevice {
212    /// Create a new GPU device
213    pub fn new(backend: GpuBackend, device_id: usize) -> Self {
214        Self { backend, device_id }
215    }
216
217    /// Get the backend type
218    pub fn backend(&self) -> GpuBackend {
219        self.backend
220    }
221
222    /// Get the device ID
223    pub fn device_id(&self) -> usize {
224        self.device_id
225    }
226
227    /// Compile a kernel from source
228    pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
229        // Placeholder implementation
230        Ok(GpuKernel {
231            backend: self.backend,
232            entry_point: entrypoint.to_string(),
233        })
234    }
235
236    /// Query the capabilities of this device.
237    ///
238    /// Returns a [`GpuDeviceInfo`] describing the device name, type, memory,
239    /// work-group limits and supported floating-point precisions. The values are
240    /// derived deterministically from this device's [`backend`](Self::backend).
241    ///
242    /// In the default Pure-Rust build no GPU SDK is linked, so the values for
243    /// hardware backends are conservative placeholders (see
244    /// [`GpuDeviceInfo`]); the method itself never fails in that configuration.
245    /// The `Result` return type reserves the ability to surface real adapter
246    /// query failures once live introspection is wired in.
247    ///
248    /// # Errors
249    ///
250    /// Currently always returns `Ok`. Reserved for future backends that perform
251    /// a fallible live adapter query.
252    pub fn get_info(&self) -> Result<GpuDeviceInfo, GpuError> {
253        Ok(GpuDeviceInfo::for_backend(self.backend))
254    }
255}
256
257/// GPU kernel abstraction
258pub struct GpuKernel {
259    backend: GpuBackend,
260    entry_point: String,
261}
262
263impl GpuKernel {
264    /// Get the backend type
265    pub fn backend(&self) -> GpuBackend {
266        self.backend
267    }
268
269    /// Get the entry point name
270    pub fn entry_point(&self) -> &str {
271        &self.entry_point
272    }
273}
274
275/// Convert GPU errors to core errors with semantic preservation
276impl From<GpuError> for CoreError {
277    fn from(err: GpuError) -> Self {
278        match err {
279            GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
280                ErrorContext::new(format!("GPU backend {backend} is not available"))
281                    .with_location(ErrorLocation::new(file!(), line!())),
282            ),
283            GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
284                ErrorContext::new(format!("GPU backend {backend} is not supported"))
285                    .with_location(ErrorLocation::new(file!(), line!())),
286            ),
287            GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
288                ErrorContext::new(format!(
289                    "GPU backend {backend:?} is not supported for this kernel"
290                ))
291                .with_location(ErrorLocation::new(file!(), line!())),
292            ),
293            GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
294                ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
295                    .with_location(ErrorLocation::new(file!(), line!())),
296            ),
297            GpuError::OutOfMemory(details) => CoreError::MemoryError(
298                ErrorContext::new(details.to_string())
299                    .with_location(ErrorLocation::new(file!(), line!())),
300            ),
301            GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
302                ErrorContext::new(msg.to_string())
303                    .with_location(ErrorLocation::new(file!(), line!())),
304            ),
305            GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
306                ErrorContext::new(msg.to_string())
307                    .with_location(ErrorLocation::new(file!(), line!())),
308            ),
309            GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
310                ErrorContext::new(msg.to_string())
311                    .with_location(ErrorLocation::new(file!(), line!())),
312            ),
313            GpuError::KernelNotFound(name) => CoreError::ComputationError(
314                ErrorContext::new(name.to_string())
315                    .with_location(ErrorLocation::new(file!(), line!())),
316            ),
317            GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
318                ErrorContext::new("Kernel specialization not supported".to_string())
319                    .with_location(ErrorLocation::new(file!(), line!())),
320            ),
321            GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
322                ErrorContext::new(format!("{dtype:?}"))
323                    .with_location(ErrorLocation::new(file!(), line!())),
324            ),
325            GpuError::Other(msg) => CoreError::ComputationError(
326                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
327            ),
328        }
329    }
330}
331
/// Trait for types that can be used with GPU operations
///
/// This is a marker trait with no methods. Implementors must be `Copy`,
/// `Send`, `Sync` and `'static`; nothing else is required or checked.
pub trait GpuDataType: Copy + Send + Sync + 'static {}
334
335/// GPU memory pointer abstraction
336#[derive(Debug)]
337pub struct GpuPtr<T: GpuDataType> {
338    ptr: u64,
339    size: usize,
340    phantom: PhantomData<T>,
341}
342
343impl<T: GpuDataType> GpuPtr<T> {
344    /// Allocate GPU memory
345    pub fn allocate(size: usize) -> Result<Self, GpuError> {
346        Ok(GpuPtr {
347            ptr: 0x1000_0000, // Placeholder address
348            size,
349            phantom: PhantomData,
350        })
351    }
352
353    /// Get the raw pointer value
354    pub fn as_ptr(&self) -> u64 {
355        self.ptr
356    }
357
358    /// Get the size in elements
359    pub fn len(&self) -> usize {
360        self.size
361    }
362
363    /// Check if the pointer is empty (size is 0)
364    pub fn is_empty(&self) -> bool {
365        self.size == 0
366    }
367}
368
/// Kernel argument types for GPU kernel execution
///
/// All arguments built from one `KernelArg<T>` share the element type `T`;
/// see [`DynamicKernelArg`] for calls that mix scalar types.
#[derive(Debug, Clone)]
pub enum KernelArg<'a, T: GpuDataType> {
    /// Buffer argument
    ///
    /// Borrows the device pointer for the lifetime `'a`.
    Buffer(&'a GpuPtr<T>),
    /// Scalar argument
    ///
    /// Held by value.
    Scalar(T),
}
377
/// Non-generic kernel argument for mixed-type kernel calls
///
/// Each variant carries one value of a fixed type, so a single argument list
/// can combine buffers with scalars of different types.
#[derive(Debug, Clone)]
pub enum DynamicKernelArg {
    /// Buffer argument (type-erased)
    Buffer(u64), // Raw pointer
    /// f32 scalar
    F32(f32),
    /// f64 scalar
    F64(f64),
    /// i32 scalar
    I32(i32),
    /// u32 scalar
    U32(u32),
    /// usize scalar
    Usize(usize),
}
394
/// GPU communication channel for multi-GPU operations
///
/// NOTE(review): every field is private and marked `#[allow(dead_code)]`, and
/// no constructor or accessor is visible in this part of the file — confirm
/// where (or whether) the type is used.
pub struct GpuChannel {
    /// Index of the device data is sent from.
    #[allow(dead_code)]
    source_device: usize,
    /// Index of the device data is sent to.
    #[allow(dead_code)]
    target_device: usize,
    /// Link bandwidth between the two devices.
    #[allow(dead_code)]
    bandwidth: f64, // GB/s
}
404
405// Implement for common types
406impl GpuDataType for f32 {}
407impl GpuDataType for f64 {}
408impl GpuDataType for i32 {}
409impl GpuDataType for u32 {}
410impl GpuDataType for u8 {}
411impl GpuDataType for i8 {}
412impl GpuDataType for u16 {}
413impl GpuDataType for i16 {}
414impl GpuDataType for u64 {}
415impl GpuDataType for i64 {}
416impl GpuDataType for usize {}
417impl GpuDataType for isize {}
418
419/// GPU buffer
420pub struct GpuBuffer<T: GpuDataType> {
421    inner: Arc<dyn GpuBufferImpl>,
422    size: usize,
423    phantom: PhantomData<T>,
424}
425
426impl<T: GpuDataType> GpuBuffer<T> {
427    /// Create a new buffer with the given size
428    pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
429        Self {
430            inner,
431            size,
432            phantom: PhantomData,
433        }
434    }
435
436    /// Get the size of the buffer in elements
437    pub fn len(&self) -> usize {
438        self.size
439    }
440
441    /// Check if the buffer is empty
442    pub fn is_empty(&self) -> bool {
443        self.size == 0
444    }
445
446    /// Copy data from the host to the device
447    pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
448        if data.len() > self.size {
449            return Err(GpuError::InvalidParameter(
450                "Data size exceeds buffer size".to_string(),
451            ));
452        }
453        unsafe {
454            self.inner
455                .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
456        }
457        Ok(())
458    }
459
460    /// Copy data from the device to the host
461    pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
462        if data.len() > self.size {
463            return Err(GpuError::InvalidParameter(
464                "Data size exceeds buffer size".to_string(),
465            ));
466        }
467        unsafe {
468            self.inner
469                .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
470        }
471        Ok(())
472    }
473
474    /// Convert the buffer contents to a vector
475    pub fn to_vec(&self) -> Vec<T> {
476        let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
477        let _ = self.copy_to_host(&mut result);
478        result
479    }
480}
481
482impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
483    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
484        f.debug_struct("GpuBuffer")
485            .field("size", &self.size)
486            .finish()
487    }
488}
489
490impl<T: GpuDataType> Clone for GpuBuffer<T> {
491    fn clone(&self) -> Self {
492        Self {
493            inner: Arc::clone(&self.inner),
494            size: self.size,
495            phantom: PhantomData,
496        }
497    }
498}
499
500/// GPU kernel handle
501#[derive(Clone)]
502pub struct GpuKernelHandle {
503    inner: Arc<dyn GpuKernelImpl>,
504}
505
506impl GpuKernelHandle {
507    /// Create a new kernel handle
508    pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
509        Self { inner }
510    }
511
512    /// Set a buffer parameter
513    pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
514        self.inner.set_buffer(name, &buffer.inner);
515    }
516
517    /// Set a u32 parameter
518    pub fn set_u32(&self, name: &str, value: u32) {
519        self.inner.set_u32(name, value);
520    }
521
522    /// Set an i32 parameter
523    pub fn set_i32(&self, name: &str, value: i32) {
524        self.inner.set_i32(name, value);
525    }
526
527    /// Set an f32 parameter
528    pub fn set_f32(&self, name: &str, value: f32) {
529        self.inner.set_f32(name, value);
530    }
531
532    /// Set an f64 parameter
533    pub fn set_f64(&self, name: &str, value: f64) {
534        self.inner.set_f64(name, value);
535    }
536
537    /// Dispatch the kernel with the given work group counts.
538    ///
539    /// If batch mode is active (see [`GpuContext::begin_batch`]), the
540    /// dispatch is encoded into the shared command buffer instead of
541    /// submitting immediately.
542    pub fn dispatch(&self, workgroups: [u32; 3]) {
543        if !self.inner.try_batch_dispatch(workgroups) {
544            self.inner.dispatch(workgroups);
545        }
546    }
547
548    /// Dispatch the kernel without waiting for GPU completion.
549    ///
550    /// This submits work to the GPU command queue but returns immediately.
551    /// Use [`GpuContext::gpu_sync()`] to wait for all pending dispatches.
552    pub fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
553        self.inner.dispatch_no_wait(workgroups);
554    }
555}
556
557/// GPU compiler for dynamic kernels
558pub struct GpuCompiler {
559    inner: Arc<dyn GpuCompilerImpl>,
560}
561
562impl GpuCompiler {
563    /// Create a new compiler
564    pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
565        Self { inner }
566    }
567
568    /// Compile a kernel from source
569    pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
570        let kernel = self.inner.compile(source)?;
571        Ok(GpuKernelHandle::new(kernel))
572    }
573
574    /// Compile a kernel for the specified input and output types
575    pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
576        let kernel = self.inner.compile_typed(
577            name,
578            std::any::TypeId::of::<I>(),
579            std::any::TypeId::of::<O>(),
580        );
581        GpuKernelHandle::new(kernel)
582    }
583}
584
/// GPU context for managing GPU resources and operations
pub struct GpuContext {
    // Backend-specific context implementation that all operations delegate to.
    inner: Arc<dyn GpuContextImpl>,
    // Backend this context was created for.
    backend: GpuBackend,
    // Kernels that can be looked up by name through this context.
    kernel_registry: kernels::KernelRegistry,
}
591
592impl GpuContext {
593    /// Create a new GPU context with the specified backend
594    pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
595        // First check if the backend is available at compile time
596        if !backend.is_available() {
597            return Err(GpuError::BackendNotAvailable(backend.to_string()));
598        }
599
600        // For non-CPU backends, also check runtime availability
601        if backend != GpuBackend::Cpu {
602            let detection_result = backends::detect_gpu_backends();
603            let backend_available = detection_result
604                .devices
605                .iter()
606                .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);
607
608            if !backend_available {
609                return Err(GpuError::BackendNotAvailable(format!(
610                    "{backend} (no devices detected at runtime)"
611                )));
612            }
613        }
614
615        let inner = match backend {
616            GpuBackend::Cuda => {
617                #[cfg(feature = "cuda")]
618                {
619                    use crate::gpu::backends::cuda::CudaContext;
620                    match CudaContext::new() {
621                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
622                        Err(e) => return Err(e),
623                    }
624                }
625                #[cfg(not(feature = "cuda"))]
626                {
627                    return Err(GpuError::UnsupportedBackend(backend));
628                }
629            }
630            GpuBackend::Rocm => {
631                #[cfg(feature = "rocm")]
632                {
633                    // This is just a stub - in a real implementation, we would use the hip-sys crate
634                    // to create a ROCm context and return it
635                    #[cfg(test)]
636                    {
637                        // For testing, we can use a mock implementation
638                        Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
639                    }
640                    #[cfg(not(test))]
641                    {
642                        return Err(GpuError::BackendNotImplemented(backend));
643                    }
644                }
645                #[cfg(not(feature = "rocm"))]
646                {
647                    return Err(GpuError::UnsupportedBackend(backend));
648                }
649            }
650            GpuBackend::Wgpu => {
651                #[cfg(feature = "wgpu_backend")]
652                {
653                    use crate::gpu::backends::wgpu::WebGPUContext;
654                    match WebGPUContext::new() {
655                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
656                        Err(e) => return Err(e),
657                    }
658                }
659                #[cfg(not(feature = "wgpu_backend"))]
660                {
661                    return Err(GpuError::UnsupportedBackend(backend));
662                }
663            }
664            GpuBackend::Metal => {
665                #[cfg(all(feature = "metal", target_os = "macos"))]
666                {
667                    use crate::gpu::backends::metal::MetalContext;
668                    match MetalContext::new() {
669                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
670                        Err(e) => return Err(e),
671                    }
672                }
673                #[cfg(not(all(feature = "metal", target_os = "macos")))]
674                {
675                    return Err(GpuError::UnsupportedBackend(backend));
676                }
677            }
678            GpuBackend::OpenCL => {
679                #[cfg(feature = "opencl")]
680                {
681                    use crate::gpu::backends::opencl::OpenCLContext;
682                    match OpenCLContext::new() {
683                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
684                        Err(e) => return Err(e),
685                    }
686                }
687                #[cfg(not(feature = "opencl"))]
688                {
689                    return Err(GpuError::UnsupportedBackend(backend));
690                }
691            }
692            GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
693        };
694
695        Ok(Self {
696            inner,
697            backend,
698            kernel_registry: kernels::KernelRegistry::with_default_kernels(),
699        })
700    }
701
    /// Get the backend type
    ///
    /// Returns the backend stored in this context by value.
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }
706
    /// Get the backend name
    ///
    /// Returns the same labels that `GpuBackend`'s `Display` implementation
    /// writes (`"CUDA"`, `"ROCm"`, `"WebGPU"`, `"Metal"`, `"OpenCL"`, `"CPU"`),
    /// but as a string slice, so no allocation is needed.
    pub fn backend_name(&self) -> &str {
        match self.backend {
            GpuBackend::Cuda => "CUDA",
            GpuBackend::Rocm => "ROCm",
            GpuBackend::Wgpu => "WebGPU",
            GpuBackend::Metal => "Metal",
            GpuBackend::OpenCL => "OpenCL",
            GpuBackend::Cpu => "CPU",
        }
    }
718
    /// Wait for all pending GPU dispatches to complete.
    ///
    /// After calling [`GpuKernelHandle::dispatch_no_wait()`], results are
    /// not guaranteed to be available until this method returns.
    ///
    /// # Errors
    ///
    /// Forwards whatever error the backend context's `gpu_sync` reports.
    pub fn gpu_sync(&self) -> Result<(), GpuError> {
        self.inner.gpu_sync()
    }
726
    /// Begin batch dispatch mode.
    ///
    /// While active, kernel dispatches are encoded into a shared command
    /// buffer instead of creating individual ones.  Call [`end_batch`](Self::end_batch)
    /// to submit all encoded work and wait for completion.
    ///
    /// # Errors
    ///
    /// Forwards whatever error the backend context's `begin_batch` reports.
    pub fn begin_batch(&self) -> Result<(), GpuError> {
        self.inner.begin_batch()
    }
735
    /// End batch dispatch mode.
    ///
    /// Submits the shared command buffer containing all dispatches that
    /// were encoded since [`begin_batch`](Self::begin_batch), then waits
    /// for the GPU to finish.
    ///
    /// # Errors
    ///
    /// Forwards whatever error the backend context's `end_batch` reports.
    pub fn end_batch(&self) -> Result<(), GpuError> {
        self.inner.end_batch()
    }
744
745    pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
746        let byte_size = size.saturating_mul(std::mem::size_of::<T>());
747        let inner = self.inner.create_buffer(byte_size);
748        GpuBuffer::new(inner, size)
749    }
750
751    /// Create a buffer from a slice
752    pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
753        let buffer = self.create_buffer::<T>(data.len());
754        let _ = buffer.copy_from_host(data);
755        buffer
756    }
757
758    /// Execute a function with a compiler
759    pub fn execute<F, R>(&self, f: F) -> R
760    where
761        F: FnOnce(&GpuCompiler) -> R,
762    {
763        let compiler = GpuCompiler::new(self.inner.create_compiler());
764        f(&compiler)
765    }
766
767    /// Get a kernel from the registry
768    pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
769        let kernel = self
770            .kernel_registry
771            .get(name)
772            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;
773
774        let kernel_source = kernel.source_for_backend(self.backend)?;
775        let metadata = kernel.metadata();
776
777        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
778        Ok(handle)
779    }
780
781    /// Get a specialized kernel from the registry
782    pub fn get_specialized_kernel(
783        &self,
784        name: &str,
785        params: &kernels::KernelParams,
786    ) -> Result<GpuKernelHandle, GpuError> {
787        let specialized = self.kernel_registry.get_specialized(name, params)?;
788        let kernel_source = specialized.source_for_backend(self.backend)?;
789        let metadata = specialized.metadata();
790
791        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
792        Ok(handle)
793    }
794
795    /// Compile a kernel with metadata
796    fn compile_kernel_with_metadata(
797        &self,
798        source: &str,
799        _metadata: &kernels::KernelMetadata,
800    ) -> Result<GpuKernelHandle, GpuError> {
801        self.execute(|compiler| compiler.compile(source))
802    }
803
804    /// Get available memory on the device
805    pub fn get_available_memory(&self) -> Option<usize> {
806        // In a real implementation, this would query the device
807        // For now, return a placeholder value
808        Some(1024 * 1024 * 1024) // 1GB
809    }
810
811    /// Get total memory on the device
812    pub fn get_total_memory(&self) -> Option<usize> {
813        // In a real implementation, this would query the device
814        // For now, return a placeholder value
815        #[cfg(target_arch = "wasm32")]
816        return Some(512 * 1024 * 1024); // 512MB for WASM32
817
818        #[cfg(not(target_arch = "wasm32"))]
819        Some((4u64 * 1024 * 1024 * 1024) as usize) // 4GB for native
820    }
821
822    /// Launch a kernel with the given parameters
823    pub fn launch_kernel(
824        &self,
825        kernel_name: &str,
826        grid_size: (usize, usize, usize),
827        block_size: (usize, usize, usize),
828        args: &[DynamicKernelArg],
829    ) -> Result<(), GpuError> {
830        // Placeholder implementation
831        let _ = (kernel_name, grid_size, block_size, args);
832        Ok(())
833    }
834
835    /// Transfer data from host to device asynchronously
836    pub fn transfer_async_host_to_device<T: GpuDataType>(
837        &self,
838        ptr: &GpuPtr<T>,
839        data: &[T],
840    ) -> Result<(), GpuError> {
841        // Placeholder implementation
842        let _ = (ptr, data);
843        Ok(())
844    }
845
846    /// Transfer data from host to device synchronously
847    pub fn transfer_host_to_device<T: GpuDataType>(
848        &self,
849        ptr: &GpuPtr<T>,
850        data: &[T],
851    ) -> Result<(), GpuError> {
852        // Placeholder implementation
853        let _ = (ptr, data);
854        Ok(())
855    }
856
857    /// Transfer data from device to host asynchronously
858    pub fn transfer_async_device_to_host<T: GpuDataType>(
859        &self,
860        ptr: &GpuPtr<T>,
861        data: &mut [T],
862    ) -> Result<(), GpuError> {
863        // Placeholder implementation
864        let _ = (ptr, data);
865        Ok(())
866    }
867
868    /// Transfer data from device to host synchronously
869    pub fn transfer_device_to_host<T: GpuDataType>(
870        &self,
871        ptr: &GpuPtr<T>,
872        data: &mut [T],
873    ) -> Result<(), GpuError> {
874        // Placeholder implementation
875        let _ = (ptr, data);
876        Ok(())
877    }
878
879    /// Execute a kernel with dynamic compilation and parameter passing
880    /// This method is expected by scirs2-vision for GPU operations
881    pub fn execute_kernel(
882        &self,
883        source: &str,
884        buffers: &[GpuBuffer<f32>],
885        work_groups: (u32, u32, u32),
886        int_params: &[u32],
887        float_params: &[f32],
888    ) -> Result<(), GpuError> {
889        // For now, provide a basic implementation that logs the execution
890        // In a real implementation, this would compile and execute the kernel
891        eprintln!(
892            "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
893            source.len(),
894            buffers.len(),
895            work_groups
896        );
897        eprintln!("Int params: {int_params:?}");
898        eprintln!("Float params: {float_params:?}");
899        Ok(())
900    }
901
902    /// Read data from a GPU buffer
903    /// This method is expected by scirs2-vision for reading GPU results
904    pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
905        Ok(buffer.to_vec())
906    }
907
    /// Global sum reduction
    ///
    /// Thin wrapper that forwards to `sum_all_cpu_fallback`, which is defined
    /// outside this part of the file.
    /// NOTE(review): the helper's name suggests the reduction runs on the host
    /// for every backend — confirm against its implementation.
    ///
    /// # Errors
    ///
    /// Returns whatever error the fallback helper produces.
    pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
        self.sum_all_cpu_fallback(buffer)
    }
912
    /// Global mean reduction
    ///
    /// Thin wrapper that forwards to `mean_all_cpu_fallback`, which is defined
    /// outside this part of the file.
    /// NOTE(review): the helper's name suggests the reduction runs on the host
    /// for every backend — confirm against its implementation.
    ///
    /// # Errors
    ///
    /// Returns whatever error the fallback helper produces.
    pub fn mean_all<T: GpuDataType>(
        &self,
        buffer: &GpuBuffer<T>,
    ) -> Result<GpuBuffer<T>, GpuError> {
        self.mean_all_cpu_fallback(buffer)
    }
920
921    /// Global max reduction
922    pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
923        self.max_all_cpu_fallback(buffer)
924    }
925
926    /// Global min reduction
927    pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
928        self.min_all_cpu_fallback(buffer)
929    }
930
931    /// Sum reduction along an axis
932    pub fn sum_axis<T: GpuDataType>(
933        &self,
934        buffer: &GpuBuffer<T>,
935        shape: &[usize],
936        axis: usize,
937    ) -> Result<GpuBuffer<T>, GpuError> {
938        self.sum_axis_cpu_fallback(buffer, shape, axis)
939    }
940
941    /// Mean reduction along an axis
942    pub fn mean_axis<T: GpuDataType>(
943        &self,
944        buffer: &GpuBuffer<T>,
945        shape: &[usize],
946        axis: usize,
947    ) -> Result<GpuBuffer<T>, GpuError> {
948        self.mean_axis_cpu_fallback(buffer, shape, axis)
949    }
950
951    /// Max reduction along an axis
952    pub fn max_axis<T: GpuDataType>(
953        &self,
954        buffer: &GpuBuffer<T>,
955        shape: &[usize],
956        axis: usize,
957    ) -> Result<GpuBuffer<T>, GpuError> {
958        self.max_axis_cpu_fallback(buffer, shape, axis)
959    }
960
961    /// Min reduction along an axis
962    pub fn min_axis<T: GpuDataType>(
963        &self,
964        buffer: &GpuBuffer<T>,
965        shape: &[usize],
966        axis: usize,
967    ) -> Result<GpuBuffer<T>, GpuError> {
968        self.min_axis_cpu_fallback(buffer, shape, axis)
969    }
970
971    /// Broadcast a buffer to a different shape
972    pub fn broadcast<T: GpuDataType>(
973        &self,
974        buffer: &GpuBuffer<T>,
975        from_shape: &[usize],
976        to_shape: &[usize],
977    ) -> Result<GpuBuffer<T>, GpuError> {
978        self.broadcast_cpu_fallback(buffer, from_shape, to_shape)
979    }
980
981    /// Scale a buffer by a scalar value
982    pub fn scale<T: GpuDataType>(
983        &self,
984        buffer: &GpuBuffer<T>,
985        scalar: T,
986    ) -> Result<GpuBuffer<T>, GpuError> {
987        self.scale_cpu_fallback(buffer, scalar)
988    }
989
990    /// General matrix multiplication: C = A @ B
991    pub fn gemm<T: GpuDataType>(
992        &self,
993        a: &GpuBuffer<T>,
994        b: &GpuBuffer<T>,
995        m: usize,
996        k: usize,
997        n: usize,
998    ) -> Result<GpuBuffer<T>, GpuError> {
999        self.gemm_cpu_fallback(a, b, m, k, n)
1000    }
1001
1002    /// GEMM with transposed B: C = A @ B^T
1003    pub fn gemm_transpose_b<T: GpuDataType>(
1004        &self,
1005        a: &GpuBuffer<T>,
1006        b: &GpuBuffer<T>,
1007        m: usize,
1008        k: usize,
1009        n: usize,
1010    ) -> Result<GpuBuffer<T>, GpuError> {
1011        self.gemm_transpose_b_cpu_fallback(a, b, m, k, n)
1012    }
1013
1014    /// GEMM with transposed A: C = A^T @ B
1015    pub fn gemm_transpose_a<T: GpuDataType>(
1016        &self,
1017        a: &GpuBuffer<T>,
1018        b: &GpuBuffer<T>,
1019        m: usize,
1020        k: usize,
1021        n: usize,
1022    ) -> Result<GpuBuffer<T>, GpuError> {
1023        self.gemm_transpose_a_cpu_fallback(a, b, m, k, n)
1024    }
1025
1026    /// ReLU activation forward pass
1027    pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1028        self.relu_cpu_fallback(input)
1029    }
1030
1031    /// ReLU backward pass
1032    pub fn relu_backward<T: GpuDataType>(
1033        &self,
1034        grad_output: &GpuBuffer<T>,
1035        input: &GpuBuffer<T>,
1036    ) -> Result<GpuBuffer<T>, GpuError> {
1037        self.relu_backward_cpu_fallback(grad_output, input)
1038    }
1039
1040    /// Sigmoid activation forward pass
1041    pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1042        self.sigmoid_cpu_fallback(input)
1043    }
1044
1045    /// Sigmoid backward pass
1046    pub fn sigmoid_backward<T: GpuDataType>(
1047        &self,
1048        grad_output: &GpuBuffer<T>,
1049        input: &GpuBuffer<T>,
1050    ) -> Result<GpuBuffer<T>, GpuError> {
1051        self.sigmoid_backward_cpu_fallback(grad_output, input)
1052    }
1053
1054    /// Tanh activation forward pass
1055    pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1056        self.tanh_cpu_fallback(input)
1057    }
1058
1059    /// Tanh backward pass
1060    pub fn tanh_backward<T: GpuDataType>(
1061        &self,
1062        grad_output: &GpuBuffer<T>,
1063        input: &GpuBuffer<T>,
1064    ) -> Result<GpuBuffer<T>, GpuError> {
1065        self.tanh_backward_cpu_fallback(grad_output, input)
1066    }
1067
1068    /// GELU activation forward pass
1069    pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1070        self.gelu_cpu_fallback(input)
1071    }
1072
1073    /// GELU backward pass
1074    pub fn gelu_backward<T: GpuDataType>(
1075        &self,
1076        grad_output: &GpuBuffer<T>,
1077        input: &GpuBuffer<T>,
1078    ) -> Result<GpuBuffer<T>, GpuError> {
1079        self.gelu_backward_cpu_fallback(grad_output, input)
1080    }
1081}
1082
1083impl fmt::Debug for GpuContext {
1084    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1085        f.debug_struct("GpuContext")
1086            .field("backend", &self.backend)
1087            .finish()
1088    }
1089}
1090
1091// The following trait definitions would be implemented by backend-specific
1092// code in a real implementation
1093
/// GPU buffer implementation trait.
///
/// Backend-specific buffer types implement this trait; it is used through
/// `Arc<dyn GpuBufferImpl>`, so implementors must be `Send + Sync`.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copy `size` bytes from host memory at `data` into the buffer.
    ///
    /// # Safety
    ///
    /// `data` must be valid for reads of `size` bytes, and `size` must not
    /// exceed the capacity of the underlying buffer. Implementations may
    /// perform a raw memory copy without further validation.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copy `size` bytes from the buffer into host memory at `data`.
    ///
    /// # Safety
    ///
    /// `data` must be valid for writes of `size` bytes, and `size` must not
    /// exceed the capacity of the underlying buffer. Implementations may
    /// perform a raw memory copy without further validation.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Get a reference to self as `Any` for downcasting to the concrete
    /// backend buffer type.
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Get the size of the buffer in bytes.
    ///
    /// The default returns `0` and exists only for backward compatibility;
    /// implementations that track their size should override it.
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0 // Default implementation for backward compatibility
    }

    /// Get the device pointer (for backends that use device pointers).
    ///
    /// The default returns `0` and exists only for backward compatibility.
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0 // Default implementation for backward compatibility
    }
}
1118
1119/// GPU kernel implementation trait
1120pub(crate) trait GpuKernelImpl: Send + Sync {
1121    /// Set a buffer parameter
1122    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);
1123
1124    /// Set a u32 parameter
1125    fn set_u32(&self, name: &str, value: u32);
1126
1127    /// Set an i32 parameter
1128    fn set_i32(&self, name: &str, value: i32);
1129
1130    /// Set an f32 parameter
1131    fn set_f32(&self, name: &str, value: f32);
1132
1133    /// Set an f64 parameter
1134    fn set_f64(&self, name: &str, value: f64);
1135
1136    /// Dispatch the kernel
1137    fn dispatch(&self, workgroups: [u32; 3]);
1138
1139    /// Dispatch the kernel without waiting for completion.
1140    /// The GPU work is submitted and will execute asynchronously.
1141    /// Call `gpu_sync()` on the context to wait for all pending work.
1142    fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
1143        // Default: fall back to synchronous dispatch
1144        self.dispatch(workgroups);
1145    }
1146
1147    /// Try to encode a dispatch into the active batch (if any).
1148    ///
1149    /// Returns `true` if the dispatch was batched, `false` if no batch is
1150    /// active and the caller should use the normal `dispatch()` path.
1151    fn try_batch_dispatch(&self, _workgroups: [u32; 3]) -> bool {
1152        false
1153    }
1154}
1155
1156/// GPU compiler implementation trait
1157pub(crate) trait GpuCompilerImpl: Send + Sync {
1158    /// Compile a kernel from source
1159    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;
1160
1161    /// Compile a typed kernel
1162    fn compile_typed(
1163        &self,
1164        name: &str,
1165        input_type: std::any::TypeId,
1166        output_type: std::any::TypeId,
1167    ) -> Arc<dyn GpuKernelImpl>;
1168}
1169
1170/// GPU context implementation trait
1171pub(crate) trait GpuContextImpl: Send + Sync {
1172    /// Create a buffer
1173    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;
1174
1175    /// Create a compiler
1176    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;
1177
1178    /// Wait for all pending GPU operations to complete.
1179    fn gpu_sync(&self) -> Result<(), GpuError> {
1180        Ok(()) // Default: no-op (synchronous backends already block)
1181    }
1182
1183    /// Begin batch dispatch mode.
1184    ///
1185    /// While active, kernel dispatches are encoded into a shared command
1186    /// buffer instead of creating individual ones.
1187    fn begin_batch(&self) -> Result<(), GpuError> {
1188        Ok(()) // Default: no-op (backends without batch support)
1189    }
1190
1191    /// End batch dispatch mode.
1192    ///
1193    /// Submits the shared command buffer and waits for all encoded
1194    /// dispatches to complete on the GPU.
1195    fn end_batch(&self) -> Result<(), GpuError> {
1196        Ok(()) // Default: no-op
1197    }
1198
1199    /// Support dynamic downcasting of concrete context implementations
1200    fn as_any(&self) -> &dyn std::any::Any
1201    where
1202        Self: 'static + Sized,
1203    {
1204        self
1205    }
1206}
1207
1208// CPU fallback implementation
1209
1210/// CPU context implementation
1211struct CpuContext;
1212
1213impl CpuContext {
1214    /// Create a new CPU context
1215    fn new() -> Self {
1216        Self
1217    }
1218}
1219
1220impl GpuContextImpl for CpuContext {
1221    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
1222        Arc::new(CpuBuffer::new(size))
1223    }
1224
1225    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
1226        Arc::new(CpuCompiler)
1227    }
1228}
1229
/// CPU buffer implementation backed by a plain byte vector in host memory.
struct CpuBuffer {
    // Raw bytes of the buffer; its length is the buffer size in bytes.
    data: Vec<u8>,
}

impl CpuBuffer {
    /// Create a new CPU buffer of `size` bytes, all initialised to zero.
    fn new(size: usize) -> Self {
        let data = vec![0u8; size];
        Self { data }
    }
}
1243
1244impl GpuBufferImpl for CpuBuffer {
1245    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1246        let mut_self = self as *const Self as *mut Self;
1247        let data_ptr = (*mut_self).data.as_mut_ptr();
1248        std::ptr::copy_nonoverlapping(data, data_ptr, size);
1249    }
1250
1251    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1252        let data_ptr = self.data.as_ptr();
1253        std::ptr::copy_nonoverlapping(data_ptr, data, size);
1254    }
1255
1256    fn as_any(&self) -> &dyn std::any::Any {
1257        self
1258    }
1259
1260    fn size(&self) -> usize {
1261        self.data.len()
1262    }
1263
1264    fn device_ptr(&self) -> u64 {
1265        self.data.as_ptr() as u64
1266    }
1267}
1268
1269/// CPU compiler implementation
1270struct CpuCompiler;
1271
1272impl GpuCompilerImpl for CpuCompiler {
1273    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
1274        // In a real implementation, we would parse and execute the kernel
1275        // For now, just return a dummy implementation
1276        Ok(Arc::new(CpuKernel))
1277    }
1278
1279    fn compile_typed(
1280        &self,
1281        _name: &str,
1282        _input_type: std::any::TypeId,
1283        _output_type: std::any::TypeId,
1284    ) -> Arc<dyn GpuKernelImpl> {
1285        // In a real implementation, we would select an appropriate implementation
1286        // For now, just return a dummy implementation
1287        Arc::new(CpuKernel)
1288    }
1289}
1290
1291/// CPU kernel implementation
1292struct CpuKernel;
1293
1294impl GpuKernelImpl for CpuKernel {
1295    fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
1296        // In a real implementation, we would store the buffer
1297    }
1298
1299    fn set_u32(&self, _name: &str, value: u32) {
1300        // In a real implementation, we would store the value
1301    }
1302
1303    fn set_i32(&self, _name: &str, value: i32) {
1304        // In a real implementation, we would store the value
1305    }
1306
1307    fn set_f32(&self, _name: &str, value: f32) {
1308        // In a real implementation, we would store the value
1309    }
1310
1311    fn set_f64(&self, _name: &str, value: f64) {
1312        // In a real implementation, we would store the value
1313    }
1314
1315    fn dispatch(&self, workgroups: [u32; 3]) {
1316        // In a real implementation, we would execute the kernel
1317    }
1318}
1319
1320// In a real implementation, we would have implementations for other backends
1321// such as CUDA, WebGPU, Metal, and OpenCL.
1322
1323#[cfg(test)]
1324mod tests {
1325    use super::*;
1326
1327    #[test]
1328    fn test_gpu_backend_preferred() {
1329        let backend = GpuBackend::preferred();
1330        // Should return a valid backend
1331        match backend {
1332            GpuBackend::Cuda
1333            | GpuBackend::Rocm
1334            | GpuBackend::Wgpu
1335            | GpuBackend::Metal
1336            | GpuBackend::OpenCL
1337            | GpuBackend::Cpu => {}
1338        }
1339    }
1340
1341    #[test]
1342    fn test_gpu_backend_default() {
1343        let backend = GpuBackend::default();
1344        assert_eq!(backend, GpuBackend::preferred());
1345    }
1346
1347    #[test]
1348    fn test_gpu_backend_is_available() {
1349        let backend = GpuBackend::Cpu;
1350        assert!(backend.is_available());
1351
1352        // Test other backends - availability depends on runtime, not just feature flags
1353        #[cfg(feature = "cuda")]
1354        {
1355            // CUDA feature enabled doesn't guarantee runtime availability
1356            let _ = GpuBackend::Cuda.is_available(); // Just check without asserting
1357        }
1358        #[cfg(not(feature = "cuda"))]
1359        assert!(!GpuBackend::Cuda.is_available());
1360
1361        #[cfg(feature = "rocm")]
1362        {
1363            // ROCm feature enabled doesn't guarantee runtime availability
1364            let _ = GpuBackend::Rocm.is_available(); // Just check without asserting
1365        }
1366        #[cfg(not(feature = "rocm"))]
1367        assert!(!GpuBackend::Rocm.is_available());
1368
1369        #[cfg(all(feature = "metal", target_os = "macos"))]
1370        assert!(GpuBackend::Metal.is_available());
1371        #[cfg(not(all(feature = "metal", target_os = "macos")))]
1372        assert!(!GpuBackend::Metal.is_available());
1373    }
1374
1375    #[test]
1376    fn test_gpu_backend_display() {
1377        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
1378        assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
1379        assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
1380        assert_eq!(GpuBackend::Metal.to_string(), "Metal");
1381        assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
1382        assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
1383    }
1384
1385    #[test]
1386    fn test_gpuerror_from_conversion() {
1387        let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
1388        let coreerror: CoreError = gpuerror.into();
1389        match coreerror {
1390            CoreError::ComputationError(_) => {}
1391            _ => panic!("Expected ComputationError"),
1392        }
1393
1394        let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
1395        let coreerror: CoreError = gpuerror.into();
1396        match coreerror {
1397            CoreError::MemoryError(_) => {}
1398            _ => panic!("Expected MemoryError"),
1399        }
1400
1401        let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
1402        let coreerror: CoreError = gpuerror.into();
1403        match coreerror {
1404            CoreError::InvalidArgument(_) => {}
1405            _ => panic!("Expected InvalidArgument"),
1406        }
1407
1408        let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
1409        let coreerror: CoreError = gpuerror.into();
1410        match coreerror {
1411            CoreError::TypeError(_) => {}
1412            _ => panic!("Expected TypeError"),
1413        }
1414    }
1415
1416    #[test]
1417    fn test_gpu_datatype_trait() {
1418        // Test that various types implement GpuDataType
1419        fn assert_gpu_datatype<T: GpuDataType>() {}
1420
1421        assert_gpu_datatype::<f32>();
1422        assert_gpu_datatype::<f64>();
1423        assert_gpu_datatype::<i32>();
1424        assert_gpu_datatype::<u32>();
1425        assert_gpu_datatype::<u8>();
1426        assert_gpu_datatype::<i8>();
1427        assert_gpu_datatype::<u16>();
1428        assert_gpu_datatype::<i16>();
1429        assert_gpu_datatype::<u64>();
1430        assert_gpu_datatype::<i64>();
1431    }
1432
1433    #[test]
1434    fn test_gpu_buffer_creation() {
1435        let inner = Arc::new(CpuBuffer::new(100));
1436        let buffer = GpuBuffer::<f32>::new(inner, 25);
1437
1438        assert_eq!(buffer.len(), 25);
1439        assert!(!buffer.is_empty());
1440    }
1441
1442    #[test]
1443    fn test_gpu_buffer_empty() {
1444        let inner = Arc::new(CpuBuffer::new(0));
1445        let buffer = GpuBuffer::<f32>::new(inner, 0);
1446
1447        assert_eq!(buffer.len(), 0);
1448        assert!(buffer.is_empty());
1449    }
1450
1451    #[test]
1452    fn test_gpu_buffer_copy_operations() {
1453        let inner = Arc::new(CpuBuffer::new(16));
1454        let buffer = GpuBuffer::<f32>::new(inner, 4);
1455
1456        let data = vec![1.0f32, 2.0, 3.0, 4.0];
1457        let _ = buffer.copy_from_host(&data);
1458
1459        let mut result = vec![0.0f32; 4];
1460        let _ = buffer.copy_to_host(&mut result);
1461
1462        assert_eq!(result, data);
1463    }
1464
1465    #[test]
1466    fn test_gpu_buffer_to_vec() {
1467        let inner = Arc::new(CpuBuffer::new(12));
1468        let buffer = GpuBuffer::<f32>::new(inner, 3);
1469
1470        let data = vec![5.0f32, 6.0, 7.0];
1471        let _ = buffer.copy_from_host(&data);
1472
1473        let result = buffer.to_vec();
1474        assert_eq!(result, data);
1475    }
1476
1477    #[test]
1478    #[should_panic(expected = "Data size exceeds buffer size")]
1479    fn test_gpu_buffer_copy_from_host_overflow() {
1480        let inner = Arc::new(CpuBuffer::new(8));
1481        let buffer = GpuBuffer::<f32>::new(inner, 2);
1482
1483        let data = vec![1.0f32, 2.0, 3.0]; // 3 elements > 2 buffer size
1484        buffer.copy_from_host(&data).expect("Operation failed");
1485    }
1486
1487    #[test]
1488    #[should_panic(expected = "Data size exceeds buffer size")]
1489    fn test_gpu_buffer_copy_to_host_overflow() {
1490        let inner = Arc::new(CpuBuffer::new(8));
1491        let buffer = GpuBuffer::<f32>::new(inner, 2);
1492
1493        let mut data = vec![0.0f32; 3]; // 3 elements > 2 buffer size
1494        buffer.copy_to_host(&mut data).expect("Operation failed");
1495    }
1496
1497    #[test]
1498    fn test_gpu_kernel_handle() {
1499        let kernel = Arc::new(CpuKernel);
1500        let handle = GpuKernelHandle::new(kernel);
1501
1502        // Test setting various parameter types
1503        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
1504        handle.set_buffer("input", &buffer);
1505        handle.set_u32("size", 100);
1506        handle.set_i32("offset", -5);
1507        handle.set_f32("scale", 2.5);
1508        handle.set_f64("precision", 0.0001);
1509
1510        // Test dispatch
1511        handle.dispatch([16, 8, 1]);
1512    }
1513
1514    #[test]
1515    fn test_gpu_context_cpu_backend() {
1516        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1517        assert_eq!(context.backend(), GpuBackend::Cpu);
1518        assert_eq!(context.backend_name(), "CPU");
1519
1520        // Test memory query methods
1521        assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
1522        assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
1523    }
1524
1525    #[test]
1526    fn test_gpu_context_buffer_creation() {
1527        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1528
1529        let buffer = context.create_buffer::<f32>(100);
1530        assert_eq!(buffer.len(), 100);
1531
1532        let data = vec![1.0f32; 50];
1533        let buffer_from_slice = context.create_buffer_from_slice(&data);
1534        assert_eq!(buffer_from_slice.len(), 50);
1535
1536        let result = buffer_from_slice.to_vec();
1537        assert_eq!(result, data);
1538    }
1539
1540    #[test]
1541    fn test_gpu_context_unsupported_backend() {
1542        // Test a backend that's not available
1543        #[cfg(not(feature = "cuda"))]
1544        {
1545            let result = GpuContext::new(GpuBackend::Cuda);
1546            assert!(result.is_err());
1547            match result {
1548                Err(GpuError::UnsupportedBackend(_)) => {}
1549                Err(GpuError::BackendNotAvailable(_)) => {} // Also accept this error
1550                Err(e) => panic!(
1551                    "Expected UnsupportedBackend or BackendNotAvailable error, got: {:?}",
1552                    e
1553                ),
1554                Ok(_) => panic!("Expected error, got Ok"),
1555            }
1556        }
1557    }
1558
1559    #[test]
1560    fn test_gpu_compiler() {
1561        let compiler_impl = Arc::new(CpuCompiler);
1562        let compiler = GpuCompiler::new(compiler_impl);
1563
1564        // Test compiling from source
1565        let kernel = compiler
1566            .compile("dummy kernel source")
1567            .expect("Operation failed");
1568        kernel.dispatch([1, 1, 1]);
1569
1570        // Test typed compilation
1571        let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
1572        typed_kernel.dispatch([32, 1, 1]);
1573    }
1574
1575    #[test]
1576    fn test_gpu_context_execute() {
1577        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1578
1579        let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());
1580
1581        assert!(result);
1582    }
1583
1584    #[test]
1585    fn test_gpu_context_kernel_registry() {
1586        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
1587
1588        // Test getting a non-existent kernel
1589        let result = context.get_kernel("non_existent_kernel");
1590        assert!(result.is_err());
1591        match result {
1592            Err(GpuError::KernelNotFound(_)) => {}
1593            _ => panic!("Expected KernelNotFound error"),
1594        }
1595    }
1596
1597    #[test]
1598    fn test_cpu_buffer_implementation() {
1599        let buffer = CpuBuffer::new(256);
1600        assert_eq!(buffer.data.len(), 256);
1601
1602        // Test that initial data is zeroed
1603        assert!(buffer.data.iter().all(|&b| b == 0));
1604    }
1605
1606    #[test]
1607    fn test_gpuerror_display() {
1608        let error = GpuError::BackendNotAvailable("CUDA".to_string());
1609        assert_eq!(error.to_string(), "GPU backend CUDA is not available");
1610
1611        let error = GpuError::OutOfMemory("allocation failed".to_string());
1612        assert_eq!(error.to_string(), "GPU out of memory: allocation failed");
1613
1614        let error = GpuError::KernelCompilationError("syntax error".to_string());
1615        assert_eq!(error.to_string(), "Kernel compilation error: syntax error");
1616
1617        let error = GpuError::KernelNotFound("gemm".to_string());
1618        assert_eq!(error.to_string(), "Kernel not found: gemm");
1619    }
1620
1621    #[test]
1622    fn test_backend_equality() {
1623        assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
1624        assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);
1625
1626        // Test Clone and Copy
1627        let backend = GpuBackend::Metal;
1628        let cloned = backend;
1629        let copied = backend;
1630        assert_eq!(backend, cloned);
1631        assert_eq!(backend, copied);
1632    }
1633
1634    #[test]
1635    fn test_backend_hash() {
1636        use std::collections::HashSet;
1637
1638        let mut set = HashSet::new();
1639        set.insert(GpuBackend::Cuda);
1640        set.insert(GpuBackend::Rocm);
1641        set.insert(GpuBackend::Cuda); // Duplicate
1642
1643        assert_eq!(set.len(), 2); // Should only have 2 unique entries
1644        assert!(set.contains(&GpuBackend::Cuda));
1645        assert!(set.contains(&GpuBackend::Rocm));
1646    }
1647
1648    #[test]
1649    fn test_gpu_buffer_debug_clone() {
1650        let inner = Arc::new(CpuBuffer::new(16));
1651        let buffer = GpuBuffer::<f32>::new(inner, 4);
1652
1653        // Test Debug implementation
1654        let debug_str = format!("{:?}", buffer);
1655        assert!(debug_str.contains("GpuBuffer"));
1656        assert!(debug_str.contains("size"));
1657
1658        // Test Clone implementation
1659        let cloned = buffer.clone();
1660        assert_eq!(cloned.len(), buffer.len());
1661        assert_eq!(cloned.len(), 4);
1662
1663        // Verify the clone is independent (shares the same Arc)
1664        let data = vec![1.0f32, 2.0, 3.0, 4.0];
1665        let _ = buffer.copy_from_host(&data);
1666
1667        let mut result = vec![0.0f32; 4];
1668        let _ = cloned.copy_to_host(&mut result);
1669        assert_eq!(result, data);
1670    }
1671
1672    #[test]
1673    fn test_gpu_context_debug() {
1674        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
1675
1676        // Test Debug implementation
1677        let debug_str = format!("{:?}", context);
1678        assert!(debug_str.contains("GpuContext"));
1679        assert!(debug_str.contains("backend"));
1680        assert!(debug_str.contains("Cpu"));
1681    }
1682
1683    #[test]
1684    fn test_gpu_context_batch_dispatch() {
1685        // Create context with CPU backend (always available)
1686        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");
1687
1688        // Begin batch mode
1689        let begin_result = context.begin_batch();
1690        assert!(
1691            begin_result.is_ok(),
1692            "begin_batch should succeed on CPU backend"
1693        );
1694
1695        // Compile a kernel and dispatch it inside the batch
1696        let dispatch_result = context.execute(|compiler| {
1697            compiler.compile("dummy kernel source").map(|kernel| {
1698                kernel.dispatch([4, 1, 1]);
1699            })
1700        });
1701        assert!(
1702            dispatch_result.is_ok(),
1703            "kernel dispatch inside batch should succeed"
1704        );
1705
1706        // End batch — submits and waits
1707        let end_result = context.end_batch();
1708        assert!(
1709            end_result.is_ok(),
1710            "end_batch should succeed on CPU backend"
1711        );
1712    }
1713
1714    #[test]
1715    fn test_gpu_context_gpu_sync() {
1716        // Create context with CPU backend
1717        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");
1718
1719        // gpu_sync should complete without error on CPU backend
1720        let result = context.gpu_sync();
1721        assert!(result.is_ok(), "gpu_sync should return Ok on CPU backend");
1722    }
1723
1724    #[test]
1725    fn test_gpu_kernel_dispatch_no_wait() {
1726        // Create a kernel handle backed by CpuKernel
1727        let kernel = Arc::new(CpuKernel);
1728        let handle = GpuKernelHandle::new(kernel);
1729
1730        // Bind a buffer and scalar so the kernel has state
1731        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
1732        handle.set_buffer("input", &buffer);
1733        handle.set_u32("size", 4);
1734
1735        // dispatch_no_wait returns () — just verify no panic
1736        handle.dispatch_no_wait([4, 1, 1]);
1737    }
1738
1739    #[test]
1740    fn test_gpu_device_get_info_cpu() {
1741        let device = GpuDevice::new(GpuBackend::Cpu, 0);
1742        let info = device
1743            .get_info()
1744            .expect("get_info should succeed for the CPU backend");
1745
1746        // Backend and basic identity must round-trip.
1747        assert_eq!(info.backend, GpuBackend::Cpu);
1748        assert_eq!(info.device_name, "CPU");
1749        assert_eq!(info.device_type, "CPU");
1750
1751        // The struct must be well-formed: non-empty descriptors and a valid
1752        // work-group size. The CPU fallback emulates both precisions.
1753        assert!(!info.device_name.is_empty());
1754        assert!(!info.compute_capability.is_empty());
1755        assert!(info.max_work_group_size >= 1);
1756        assert!(info.supports_fp64);
1757        assert!(info.supports_fp16);
1758    }
1759
1760    #[test]
1761    fn test_gpu_device_info_is_deterministic_per_backend() {
1762        // Every backend must yield a well-formed, non-empty, deterministic info.
1763        for backend in [
1764            GpuBackend::Cpu,
1765            GpuBackend::Cuda,
1766            GpuBackend::Rocm,
1767            GpuBackend::Wgpu,
1768            GpuBackend::Metal,
1769            GpuBackend::OpenCL,
1770        ] {
1771            let device = GpuDevice::new(backend, 0);
1772            let info = device.get_info().expect("get_info should not fail");
1773            let info_again = device.get_info().expect("get_info should not fail");
1774
1775            assert_eq!(info.backend, backend);
1776            assert!(!info.device_name.is_empty());
1777            assert!(!info.device_type.is_empty());
1778            assert!(!info.compute_capability.is_empty());
1779            assert!(info.max_work_group_size >= 1);
1780
1781            // Deterministic: two queries produce identical descriptors.
1782            assert_eq!(info.device_name, info_again.device_name);
1783            assert_eq!(info.max_work_group_size, info_again.max_work_group_size);
1784            assert_eq!(info.supports_fp64, info_again.supports_fp64);
1785        }
1786    }
1787}