Skip to main content

scirs2_core/gpu/
mod.rs

1//! GPU acceleration module for scirs2-core
2//!
3//! This module provides hardware acceleration support across different GPU backends.
4
5use std::fmt;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9pub mod async_execution;
10pub mod async_transfer;
11pub mod auto_tuning;
12pub mod backends;
13pub mod benchmarks;
14mod cpu_ops;
15pub mod device_info;
16pub mod heterogeneous;
17pub mod kernels;
18pub mod memory_management;
19pub mod stream_allocator;
20pub mod tensor_cores;
21
22pub use async_transfer::{
23    AsyncTransferError, AsyncTransferPipeline, TransferDirection, TransferHandle,
24};
25pub use device_info::GpuDeviceInfo;
26pub use memory_management::unified_memory::{SyncState, UnifiedAllocator, UnifiedBuffer};
27pub use stream_allocator::{StreamAllocator, StreamId};
28
/// Identifies which compute backend a GPU object targets.
///
/// The value is a plain tag: it is `Copy`, comparable and hashable, and carries
/// no device state of its own.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA.
    Cuda,
    /// AMD ROCm.
    Rocm,
    /// WebGPU.
    Wgpu,
    /// Apple Metal.
    Metal,
    /// OpenCL.
    OpenCL,
    /// Host CPU, used as the fallback backend.
    Cpu,
}
45
46impl Default for GpuBackend {
47    fn default() -> Self {
48        Self::preferred()
49    }
50}
51
52impl GpuBackend {
53    /// Get the preferred GPU backend for the current system
54    pub fn preferred() -> Self {
55        // Use the backend detection system to find the optimal backend
56        // This will properly detect available GPUs and fall back to CPU if needed
57        match backends::initialize_optimal_backend() {
58            Ok(backend) => {
59                // If we get a non-CPU backend, verify it's actually usable
60                if backend != GpuBackend::Cpu {
61                    // Check if we can actually create a context with this backend
62                    // For now, since implementations are stubs, fall back to CPU
63                    #[cfg(not(test))]
64                    {
65                        // In non-test environments, we don't have real GPU implementations yet
66                        return GpuBackend::Cpu;
67                    }
68                    #[cfg(test)]
69                    {
70                        // In tests, we can pretend the backend works
71                        return backend;
72                    }
73                }
74                backend
75            }
76            Err(_) => {
77                // If detection fails entirely, use CPU
78                GpuBackend::Cpu
79            }
80        }
81    }
82
83    /// Check if this backend is available on the current system
84    pub fn is_available(&self) -> bool {
85        match self {
86            // Check runtime availability for GPU backends
87            GpuBackend::Cuda => false, // CUDA backend retired from scirs2-core in 0.6.x — use oxicuda-* crates
88            GpuBackend::Rocm => cfg!(feature = "rocm"), // Would use ROCm runtime check
89            GpuBackend::Wgpu => {
90                #[cfg(feature = "wgpu")]
91                {
92                    use crate::gpu::backends::wgpu::WebGPUContext;
93                    WebGPUContext::is_available()
94                }
95                #[cfg(not(feature = "wgpu"))]
96                {
97                    false
98                }
99            }
100            GpuBackend::Metal => {
101                #[cfg(all(feature = "metal", target_os = "macos"))]
102                {
103                    // Metal is always available on macOS if the feature is enabled
104                    true
105                }
106                #[cfg(not(all(feature = "metal", target_os = "macos")))]
107                {
108                    false
109                }
110            }
111            GpuBackend::OpenCL => {
112                #[cfg(feature = "opencl")]
113                {
114                    use crate::gpu::backends::opencl::OpenCLContext;
115                    OpenCLContext::is_available()
116                }
117                #[cfg(not(feature = "opencl"))]
118                {
119                    false
120                }
121            }
122            GpuBackend::Cpu => true,
123        }
124    }
125}
126
127impl fmt::Display for GpuBackend {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        match self {
130            GpuBackend::Cuda => write!(f, "CUDA"),
131            GpuBackend::Rocm => write!(f, "ROCm"),
132            GpuBackend::Wgpu => write!(f, "WebGPU"),
133            GpuBackend::Metal => write!(f, "Metal"),
134            GpuBackend::OpenCL => write!(f, "OpenCL"),
135            GpuBackend::Cpu => write!(f, "CPU"),
136        }
137    }
138}
139
140use crate::error::{CoreError, ErrorContext, ErrorLocation};
141
142/// Error type for GPU operations
143#[derive(Debug, thiserror::Error)]
144pub enum GpuError {
145    /// Backend is not available
146    #[error("GPU backend {0} is not available")]
147    BackendNotAvailable(String),
148
149    /// Backend is not supported
150    #[error("GPU backend {0} is not supported")]
151    UnsupportedBackend(GpuBackend),
152
153    /// Backend is not supported for a kernel
154    #[error("GPU backend {0:?} is not supported for this kernel")]
155    BackendNotSupported(GpuBackend),
156
157    /// Backend is not implemented yet
158    #[error("GPU backend {0} is not implemented yet")]
159    BackendNotImplemented(GpuBackend),
160
161    /// Out of memory
162    #[error("GPU out of memory: {0}")]
163    OutOfMemory(String),
164
165    /// Kernel compilation error
166    #[error("Kernel compilation error: {0}")]
167    KernelCompilationError(String),
168
169    /// Kernel execution error
170    #[error("Kernel execution error: {0}")]
171    KernelExecutionError(String),
172
173    /// Invalid parameter
174    #[error("Invalid parameter: {0}")]
175    InvalidParameter(String),
176
177    /// Kernel not found
178    #[error("Kernel not found: {0}")]
179    KernelNotFound(String),
180
181    /// Specialization not supported
182    #[error("Kernel specialization not supported")]
183    SpecializationNotSupported,
184
185    /// Unsupported data type
186    #[error("Unsupported data type: {0:?}")]
187    UnsupportedDataType(kernels::DataType),
188
189    /// Other error
190    #[error("{0}")]
191    Other(String),
192}
193
194/// GPU device abstraction
195#[derive(Debug, Clone, Copy, PartialEq, Eq)]
196pub struct GpuDevice {
197    backend: GpuBackend,
198    device_id: usize,
199}
200
201impl GpuDevice {
202    /// Create a new GPU device
203    pub fn new(backend: GpuBackend, device_id: usize) -> Self {
204        Self { backend, device_id }
205    }
206
207    /// Get the backend type
208    pub fn backend(&self) -> GpuBackend {
209        self.backend
210    }
211
212    /// Get the device ID
213    pub fn device_id(&self) -> usize {
214        self.device_id
215    }
216
217    /// Compile a kernel from source
218    pub fn compile_kernel(&self, _source: &str, entrypoint: &str) -> Result<GpuKernel, GpuError> {
219        // Placeholder implementation
220        Ok(GpuKernel {
221            backend: self.backend,
222            entry_point: entrypoint.to_string(),
223        })
224    }
225
226    /// Query the capabilities of this device.
227    ///
228    /// Returns a [`GpuDeviceInfo`] describing the device name, type, memory,
229    /// work-group limits and supported floating-point precisions. The values are
230    /// derived deterministically from this device's [`backend`](Self::backend).
231    ///
232    /// In the default Pure-Rust build no GPU SDK is linked, so the values for
233    /// hardware backends are conservative placeholders (see
234    /// [`GpuDeviceInfo`]); the method itself never fails in that configuration.
235    /// The `Result` return type reserves the ability to surface real adapter
236    /// query failures once live introspection is wired in.
237    ///
238    /// # Errors
239    ///
240    /// Currently always returns `Ok`. Reserved for future backends that perform
241    /// a fallible live adapter query.
242    pub fn get_info(&self) -> Result<GpuDeviceInfo, GpuError> {
243        Ok(GpuDeviceInfo::for_backend(self.backend))
244    }
245}
246
247/// GPU kernel abstraction
248pub struct GpuKernel {
249    backend: GpuBackend,
250    entry_point: String,
251}
252
253impl GpuKernel {
254    /// Get the backend type
255    pub fn backend(&self) -> GpuBackend {
256        self.backend
257    }
258
259    /// Get the entry point name
260    pub fn entry_point(&self) -> &str {
261        &self.entry_point
262    }
263}
264
265/// Convert GPU errors to core errors with semantic preservation
266impl From<GpuError> for CoreError {
267    fn from(err: GpuError) -> Self {
268        match err {
269            GpuError::BackendNotAvailable(backend) => CoreError::ComputationError(
270                ErrorContext::new(format!("GPU backend {backend} is not available"))
271                    .with_location(ErrorLocation::new(file!(), line!())),
272            ),
273            GpuError::UnsupportedBackend(backend) => CoreError::NotImplementedError(
274                ErrorContext::new(format!("GPU backend {backend} is not supported"))
275                    .with_location(ErrorLocation::new(file!(), line!())),
276            ),
277            GpuError::BackendNotSupported(backend) => CoreError::NotImplementedError(
278                ErrorContext::new(format!(
279                    "GPU backend {backend:?} is not supported for this kernel"
280                ))
281                .with_location(ErrorLocation::new(file!(), line!())),
282            ),
283            GpuError::BackendNotImplemented(backend) => CoreError::NotImplementedError(
284                ErrorContext::new(format!("GPU backend {backend} is not implemented yet"))
285                    .with_location(ErrorLocation::new(file!(), line!())),
286            ),
287            GpuError::OutOfMemory(details) => CoreError::MemoryError(
288                ErrorContext::new(details.to_string())
289                    .with_location(ErrorLocation::new(file!(), line!())),
290            ),
291            GpuError::KernelCompilationError(msg) => CoreError::ComputationError(
292                ErrorContext::new(msg.to_string())
293                    .with_location(ErrorLocation::new(file!(), line!())),
294            ),
295            GpuError::KernelExecutionError(msg) => CoreError::ComputationError(
296                ErrorContext::new(msg.to_string())
297                    .with_location(ErrorLocation::new(file!(), line!())),
298            ),
299            GpuError::InvalidParameter(msg) => CoreError::InvalidArgument(
300                ErrorContext::new(msg.to_string())
301                    .with_location(ErrorLocation::new(file!(), line!())),
302            ),
303            GpuError::KernelNotFound(name) => CoreError::ComputationError(
304                ErrorContext::new(name.to_string())
305                    .with_location(ErrorLocation::new(file!(), line!())),
306            ),
307            GpuError::SpecializationNotSupported => CoreError::NotImplementedError(
308                ErrorContext::new("Kernel specialization not supported".to_string())
309                    .with_location(ErrorLocation::new(file!(), line!())),
310            ),
311            GpuError::UnsupportedDataType(dtype) => CoreError::TypeError(
312                ErrorContext::new(format!("{dtype:?}"))
313                    .with_location(ErrorLocation::new(file!(), line!())),
314            ),
315            GpuError::Other(msg) => CoreError::ComputationError(
316                ErrorContext::new(msg).with_location(ErrorLocation::new(file!(), line!())),
317            ),
318        }
319    }
320}
321
/// Marker trait for element types that may be stored in GPU buffers and passed
/// to GPU operations.
///
/// It declares no methods; implementors only need to be `Copy`, `Send`, `Sync`
/// and `'static`.
pub trait GpuDataType: Copy + Send + Sync + 'static {}
324
325/// GPU memory pointer abstraction
326#[derive(Debug)]
327pub struct GpuPtr<T: GpuDataType> {
328    ptr: u64,
329    size: usize,
330    phantom: PhantomData<T>,
331}
332
333impl<T: GpuDataType> GpuPtr<T> {
334    /// Allocate GPU memory
335    pub fn allocate(size: usize) -> Result<Self, GpuError> {
336        Ok(GpuPtr {
337            ptr: 0x1000_0000, // Placeholder address
338            size,
339            phantom: PhantomData,
340        })
341    }
342
343    /// Get the raw pointer value
344    pub fn as_ptr(&self) -> u64 {
345        self.ptr
346    }
347
348    /// Get the size in elements
349    pub fn len(&self) -> usize {
350        self.size
351    }
352
353    /// Check if the pointer is empty (size is 0)
354    pub fn is_empty(&self) -> bool {
355        self.size == 0
356    }
357}
358
359/// Kernel argument types for GPU kernel execution
360#[derive(Debug, Clone)]
361pub enum KernelArg<'a, T: GpuDataType> {
362    /// Buffer argument
363    Buffer(&'a GpuPtr<T>),
364    /// Scalar argument
365    Scalar(T),
366}
367
/// A kernel argument with its type carried at runtime, so that a single
/// argument list can mix element types.
#[derive(Clone, Debug)]
pub enum DynamicKernelArg {
    /// Type-erased buffer, given as a raw pointer value.
    Buffer(u64),
    /// `f32` scalar.
    F32(f32),
    /// `f64` scalar.
    F64(f64),
    /// `i32` scalar.
    I32(i32),
    /// `u32` scalar.
    U32(u32),
    /// `usize` scalar.
    Usize(usize),
}
384
/// Describes a communication link between two devices for multi-GPU work.
///
/// The fields are not read anywhere in this part of the file, hence the
/// `dead_code` allowances.
pub struct GpuChannel {
    // Index of the sending device.
    #[allow(dead_code)]
    source_device: usize,
    // Index of the receiving device.
    #[allow(dead_code)]
    target_device: usize,
    // Link bandwidth in GB/s.
    #[allow(dead_code)]
    bandwidth: f64,
}
394
395// Implement for common types
396impl GpuDataType for f32 {}
397impl GpuDataType for f64 {}
398impl GpuDataType for i32 {}
399impl GpuDataType for u32 {}
400impl GpuDataType for u8 {}
401impl GpuDataType for i8 {}
402impl GpuDataType for u16 {}
403impl GpuDataType for i16 {}
404impl GpuDataType for u64 {}
405impl GpuDataType for i64 {}
406impl GpuDataType for usize {}
407impl GpuDataType for isize {}
408
409/// GPU buffer
410pub struct GpuBuffer<T: GpuDataType> {
411    inner: Arc<dyn GpuBufferImpl>,
412    size: usize,
413    phantom: PhantomData<T>,
414}
415
416impl<T: GpuDataType> GpuBuffer<T> {
417    /// Create a new buffer with the given size
418    pub(crate) fn new(inner: Arc<dyn GpuBufferImpl>, size: usize) -> Self {
419        Self {
420            inner,
421            size,
422            phantom: PhantomData,
423        }
424    }
425
426    /// Get the size of the buffer in elements
427    pub fn len(&self) -> usize {
428        self.size
429    }
430
431    /// Check if the buffer is empty
432    pub fn is_empty(&self) -> bool {
433        self.size == 0
434    }
435
436    /// Copy data from the host to the device
437    pub fn copy_from_host(&self, data: &[T]) -> Result<(), GpuError> {
438        if data.len() > self.size {
439            return Err(GpuError::InvalidParameter(
440                "Data size exceeds buffer size".to_string(),
441            ));
442        }
443        unsafe {
444            self.inner
445                .copy_from_host(data.as_ptr() as *const u8, std::mem::size_of_val(data));
446        }
447        Ok(())
448    }
449
450    /// Copy data from the device to the host
451    pub fn copy_to_host(&self, data: &mut [T]) -> Result<(), GpuError> {
452        if data.len() > self.size {
453            return Err(GpuError::InvalidParameter(
454                "Data size exceeds buffer size".to_string(),
455            ));
456        }
457        unsafe {
458            self.inner
459                .copy_to_host(data.as_mut_ptr() as *mut u8, std::mem::size_of_val(data));
460        }
461        Ok(())
462    }
463
464    /// Convert the buffer contents to a vector
465    pub fn to_vec(&self) -> Vec<T> {
466        let mut result = vec![unsafe { std::mem::zeroed() }; self.size];
467        let _ = self.copy_to_host(&mut result);
468        result
469    }
470}
471
472impl<T: GpuDataType> fmt::Debug for GpuBuffer<T> {
473    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
474        f.debug_struct("GpuBuffer")
475            .field("size", &self.size)
476            .finish()
477    }
478}
479
480impl<T: GpuDataType> Clone for GpuBuffer<T> {
481    fn clone(&self) -> Self {
482        Self {
483            inner: Arc::clone(&self.inner),
484            size: self.size,
485            phantom: PhantomData,
486        }
487    }
488}
489
490/// GPU kernel handle
491#[derive(Clone)]
492pub struct GpuKernelHandle {
493    inner: Arc<dyn GpuKernelImpl>,
494}
495
496impl GpuKernelHandle {
497    /// Create a new kernel handle
498    pub(crate) fn new(inner: Arc<dyn GpuKernelImpl>) -> Self {
499        Self { inner }
500    }
501
502    /// Set a buffer parameter
503    pub fn set_buffer<T: GpuDataType>(&self, name: &str, buffer: &GpuBuffer<T>) {
504        self.inner.set_buffer(name, &buffer.inner);
505    }
506
507    /// Set a u32 parameter
508    pub fn set_u32(&self, name: &str, value: u32) {
509        self.inner.set_u32(name, value);
510    }
511
512    /// Set an i32 parameter
513    pub fn set_i32(&self, name: &str, value: i32) {
514        self.inner.set_i32(name, value);
515    }
516
517    /// Set an f32 parameter
518    pub fn set_f32(&self, name: &str, value: f32) {
519        self.inner.set_f32(name, value);
520    }
521
522    /// Set an f64 parameter
523    pub fn set_f64(&self, name: &str, value: f64) {
524        self.inner.set_f64(name, value);
525    }
526
527    /// Dispatch the kernel with the given work group counts.
528    ///
529    /// If batch mode is active (see [`GpuContext::begin_batch`]), the
530    /// dispatch is encoded into the shared command buffer instead of
531    /// submitting immediately.
532    pub fn dispatch(&self, workgroups: [u32; 3]) {
533        if !self.inner.try_batch_dispatch(workgroups) {
534            self.inner.dispatch(workgroups);
535        }
536    }
537
538    /// Dispatch the kernel without waiting for GPU completion.
539    ///
540    /// This submits work to the GPU command queue but returns immediately.
541    /// Use [`GpuContext::gpu_sync()`] to wait for all pending dispatches.
542    pub fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
543        self.inner.dispatch_no_wait(workgroups);
544    }
545}
546
547/// GPU compiler for dynamic kernels
548pub struct GpuCompiler {
549    inner: Arc<dyn GpuCompilerImpl>,
550}
551
552impl GpuCompiler {
553    /// Create a new compiler
554    pub(crate) fn new(inner: Arc<dyn GpuCompilerImpl>) -> Self {
555        Self { inner }
556    }
557
558    /// Compile a kernel from source
559    pub fn compile(&self, source: &str) -> Result<GpuKernelHandle, GpuError> {
560        let kernel = self.inner.compile(source)?;
561        Ok(GpuKernelHandle::new(kernel))
562    }
563
564    /// Compile a kernel for the specified input and output types
565    pub fn compile_kernel<I: GpuDataType, O: GpuDataType>(&self, name: &str) -> GpuKernelHandle {
566        let kernel = self.inner.compile_typed(
567            name,
568            std::any::TypeId::of::<I>(),
569            std::any::TypeId::of::<O>(),
570        );
571        GpuKernelHandle::new(kernel)
572    }
573}
574
575/// GPU context for managing GPU resources and operations
576pub struct GpuContext {
577    inner: Arc<dyn GpuContextImpl>,
578    backend: GpuBackend,
579    kernel_registry: kernels::KernelRegistry,
580}
581
582impl GpuContext {
583    /// Create a new GPU context with the specified backend
584    pub fn new(backend: GpuBackend) -> Result<Self, GpuError> {
585        // CUDA was retired from scirs2-core in 0.6.x (moved to the per-crate oxicuda-*
586        // backends). Intercept it before the generic availability checks so callers
587        // receive explicit migration guidance instead of a bare "not available".
588        if backend == GpuBackend::Cuda {
589            return Err(GpuError::BackendNotAvailable(
590                "CUDA backend was removed from scirs2-core in 0.6.x; use the oxicuda-* crates directly (per-crate `cuda` feature). See SCIRS2_POLICY.md".to_string(),
591            ));
592        }
593
594        // First check if the backend is available at compile time
595        if !backend.is_available() {
596            return Err(GpuError::BackendNotAvailable(backend.to_string()));
597        }
598
599        // For non-CPU backends, also check runtime availability
600        if backend != GpuBackend::Cpu {
601            let detection_result = backends::detect_gpu_backends();
602            let backend_available = detection_result
603                .devices
604                .iter()
605                .any(|d| d.backend == backend && d.backend != GpuBackend::Cpu);
606
607            if !backend_available {
608                return Err(GpuError::BackendNotAvailable(format!(
609                    "{backend} (no devices detected at runtime)"
610                )));
611            }
612        }
613
614        let inner = match backend {
615            GpuBackend::Cuda => {
616                return Err(GpuError::BackendNotAvailable(
617                    "CUDA backend was removed from scirs2-core in 0.6.x; use the oxicuda-* crates directly (per-crate `cuda` feature). See SCIRS2_POLICY.md".to_string(),
618                ));
619            }
620            GpuBackend::Rocm => {
621                #[cfg(feature = "rocm")]
622                {
623                    // This is just a stub - in a real implementation, we would use the hip-sys crate
624                    // to create a ROCm context and return it
625                    #[cfg(test)]
626                    {
627                        // For testing, we can use a mock implementation
628                        Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>
629                    }
630                    #[cfg(not(test))]
631                    {
632                        return Err(GpuError::BackendNotImplemented(backend));
633                    }
634                }
635                #[cfg(not(feature = "rocm"))]
636                {
637                    return Err(GpuError::UnsupportedBackend(backend));
638                }
639            }
640            GpuBackend::Wgpu => {
641                #[cfg(feature = "wgpu")]
642                {
643                    use crate::gpu::backends::wgpu::WebGPUContext;
644                    match WebGPUContext::new() {
645                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
646                        Err(e) => return Err(e),
647                    }
648                }
649                #[cfg(not(feature = "wgpu"))]
650                {
651                    return Err(GpuError::UnsupportedBackend(backend));
652                }
653            }
654            GpuBackend::Metal => {
655                #[cfg(all(feature = "metal", target_os = "macos"))]
656                {
657                    use crate::gpu::backends::metal::MetalContext;
658                    match MetalContext::new() {
659                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
660                        Err(e) => return Err(e),
661                    }
662                }
663                #[cfg(not(all(feature = "metal", target_os = "macos")))]
664                {
665                    return Err(GpuError::UnsupportedBackend(backend));
666                }
667            }
668            GpuBackend::OpenCL => {
669                #[cfg(feature = "opencl")]
670                {
671                    use crate::gpu::backends::opencl::OpenCLContext;
672                    match OpenCLContext::new() {
673                        Ok(ctx) => Arc::new(ctx) as Arc<dyn GpuContextImpl>,
674                        Err(e) => return Err(e),
675                    }
676                }
677                #[cfg(not(feature = "opencl"))]
678                {
679                    return Err(GpuError::UnsupportedBackend(backend));
680                }
681            }
682            GpuBackend::Cpu => Arc::new(CpuContext::new()) as Arc<dyn GpuContextImpl>,
683        };
684
685        Ok(Self {
686            inner,
687            backend,
688            kernel_registry: kernels::KernelRegistry::with_default_kernels(),
689        })
690    }
691
692    /// Get the backend type
693    pub fn backend(&self) -> GpuBackend {
694        self.backend
695    }
696
697    /// Get the backend name
698    pub fn backend_name(&self) -> &str {
699        match self.backend {
700            GpuBackend::Cuda => "CUDA",
701            GpuBackend::Rocm => "ROCm",
702            GpuBackend::Wgpu => "WebGPU",
703            GpuBackend::Metal => "Metal",
704            GpuBackend::OpenCL => "OpenCL",
705            GpuBackend::Cpu => "CPU",
706        }
707    }
708
709    /// Wait for all pending GPU dispatches to complete.
710    ///
711    /// After calling [`GpuKernelHandle::dispatch_no_wait()`], results are
712    /// not guaranteed to be available until this method returns.
713    pub fn gpu_sync(&self) -> Result<(), GpuError> {
714        self.inner.gpu_sync()
715    }
716
717    /// Begin batch dispatch mode.
718    ///
719    /// While active, kernel dispatches are encoded into a shared command
720    /// buffer instead of creating individual ones.  Call [`end_batch`](Self::end_batch)
721    /// to submit all encoded work and wait for completion.
722    pub fn begin_batch(&self) -> Result<(), GpuError> {
723        self.inner.begin_batch()
724    }
725
726    /// End batch dispatch mode.
727    ///
728    /// Submits the shared command buffer containing all dispatches that
729    /// were encoded since [`begin_batch`](Self::begin_batch), then waits
730    /// for the GPU to finish.
731    pub fn end_batch(&self) -> Result<(), GpuError> {
732        self.inner.end_batch()
733    }
734
735    pub fn create_buffer<T: GpuDataType>(&self, size: usize) -> GpuBuffer<T> {
736        let byte_size = size.saturating_mul(std::mem::size_of::<T>());
737        let inner = self.inner.create_buffer(byte_size);
738        GpuBuffer::new(inner, size)
739    }
740
741    /// Create a buffer from a slice
742    pub fn create_buffer_from_slice<T: GpuDataType>(&self, data: &[T]) -> GpuBuffer<T> {
743        let buffer = self.create_buffer::<T>(data.len());
744        let _ = buffer.copy_from_host(data);
745        buffer
746    }
747
748    /// Execute a function with a compiler
749    pub fn execute<F, R>(&self, f: F) -> R
750    where
751        F: FnOnce(&GpuCompiler) -> R,
752    {
753        let compiler = GpuCompiler::new(self.inner.create_compiler());
754        f(&compiler)
755    }
756
757    /// Get a kernel from the registry
758    pub fn get_kernel(&self, name: &str) -> Result<GpuKernelHandle, GpuError> {
759        let kernel = self
760            .kernel_registry
761            .get(name)
762            .ok_or_else(|| GpuError::KernelNotFound(name.to_string()))?;
763
764        let kernel_source = kernel.source_for_backend(self.backend)?;
765        let metadata = kernel.metadata();
766
767        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
768        Ok(handle)
769    }
770
771    /// Get a specialized kernel from the registry
772    pub fn get_specialized_kernel(
773        &self,
774        name: &str,
775        params: &kernels::KernelParams,
776    ) -> Result<GpuKernelHandle, GpuError> {
777        let specialized = self.kernel_registry.get_specialized(name, params)?;
778        let kernel_source = specialized.source_for_backend(self.backend)?;
779        let metadata = specialized.metadata();
780
781        let handle = self.compile_kernel_with_metadata(&kernel_source, &metadata)?;
782        Ok(handle)
783    }
784
785    /// Compile a kernel with metadata
786    fn compile_kernel_with_metadata(
787        &self,
788        source: &str,
789        _metadata: &kernels::KernelMetadata,
790    ) -> Result<GpuKernelHandle, GpuError> {
791        self.execute(|compiler| compiler.compile(source))
792    }
793
794    /// Get available memory on the device
795    pub fn get_available_memory(&self) -> Option<usize> {
796        // In a real implementation, this would query the device
797        // For now, return a placeholder value
798        Some(1024 * 1024 * 1024) // 1GB
799    }
800
801    /// Get total memory on the device
802    pub fn get_total_memory(&self) -> Option<usize> {
803        // In a real implementation, this would query the device
804        // For now, return a placeholder value
805        #[cfg(target_arch = "wasm32")]
806        return Some(512 * 1024 * 1024); // 512MB for WASM32
807
808        #[cfg(not(target_arch = "wasm32"))]
809        Some((4u64 * 1024 * 1024 * 1024) as usize) // 4GB for native
810    }
811
812    /// Launch a kernel with the given parameters
813    pub fn launch_kernel(
814        &self,
815        kernel_name: &str,
816        grid_size: (usize, usize, usize),
817        block_size: (usize, usize, usize),
818        args: &[DynamicKernelArg],
819    ) -> Result<(), GpuError> {
820        // Placeholder implementation
821        let _ = (kernel_name, grid_size, block_size, args);
822        Ok(())
823    }
824
825    /// Transfer data from host to device asynchronously
826    pub fn transfer_async_host_to_device<T: GpuDataType>(
827        &self,
828        ptr: &GpuPtr<T>,
829        data: &[T],
830    ) -> Result<(), GpuError> {
831        // Placeholder implementation
832        let _ = (ptr, data);
833        Ok(())
834    }
835
836    /// Transfer data from host to device synchronously
837    pub fn transfer_host_to_device<T: GpuDataType>(
838        &self,
839        ptr: &GpuPtr<T>,
840        data: &[T],
841    ) -> Result<(), GpuError> {
842        // Placeholder implementation
843        let _ = (ptr, data);
844        Ok(())
845    }
846
847    /// Transfer data from device to host asynchronously
848    pub fn transfer_async_device_to_host<T: GpuDataType>(
849        &self,
850        ptr: &GpuPtr<T>,
851        data: &mut [T],
852    ) -> Result<(), GpuError> {
853        // Placeholder implementation
854        let _ = (ptr, data);
855        Ok(())
856    }
857
858    /// Transfer data from device to host synchronously
859    pub fn transfer_device_to_host<T: GpuDataType>(
860        &self,
861        ptr: &GpuPtr<T>,
862        data: &mut [T],
863    ) -> Result<(), GpuError> {
864        // Placeholder implementation
865        let _ = (ptr, data);
866        Ok(())
867    }
868
869    /// Execute a kernel with dynamic compilation and parameter passing
870    /// This method is expected by scirs2-vision for GPU operations
871    pub fn execute_kernel(
872        &self,
873        source: &str,
874        buffers: &[GpuBuffer<f32>],
875        work_groups: (u32, u32, u32),
876        int_params: &[u32],
877        float_params: &[f32],
878    ) -> Result<(), GpuError> {
879        // For now, provide a basic implementation that logs the execution
880        // In a real implementation, this would compile and execute the kernel
881        eprintln!(
882            "GPU kernel execution (source length: {}, buffers: {}, workgroups: {:?})",
883            source.len(),
884            buffers.len(),
885            work_groups
886        );
887        eprintln!("Int params: {int_params:?}");
888        eprintln!("Float params: {float_params:?}");
889        Ok(())
890    }
891
892    /// Read data from a GPU buffer
893    /// This method is expected by scirs2-vision for reading GPU results
894    pub fn read_buffer<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<Vec<T>, GpuError> {
895        Ok(buffer.to_vec())
896    }
897
898    /// Global sum reduction
899    pub fn sum_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
900        self.sum_all_cpu_fallback(buffer)
901    }
902
903    /// Global mean reduction
904    pub fn mean_all<T: GpuDataType>(
905        &self,
906        buffer: &GpuBuffer<T>,
907    ) -> Result<GpuBuffer<T>, GpuError> {
908        self.mean_all_cpu_fallback(buffer)
909    }
910
911    /// Global max reduction
912    pub fn max_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
913        self.max_all_cpu_fallback(buffer)
914    }
915
916    /// Global min reduction
917    pub fn min_all<T: GpuDataType>(&self, buffer: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
918        self.min_all_cpu_fallback(buffer)
919    }
920
921    /// Sum reduction along an axis
922    pub fn sum_axis<T: GpuDataType>(
923        &self,
924        buffer: &GpuBuffer<T>,
925        shape: &[usize],
926        axis: usize,
927    ) -> Result<GpuBuffer<T>, GpuError> {
928        self.sum_axis_cpu_fallback(buffer, shape, axis)
929    }
930
931    /// Mean reduction along an axis
932    pub fn mean_axis<T: GpuDataType>(
933        &self,
934        buffer: &GpuBuffer<T>,
935        shape: &[usize],
936        axis: usize,
937    ) -> Result<GpuBuffer<T>, GpuError> {
938        self.mean_axis_cpu_fallback(buffer, shape, axis)
939    }
940
941    /// Max reduction along an axis
942    pub fn max_axis<T: GpuDataType>(
943        &self,
944        buffer: &GpuBuffer<T>,
945        shape: &[usize],
946        axis: usize,
947    ) -> Result<GpuBuffer<T>, GpuError> {
948        self.max_axis_cpu_fallback(buffer, shape, axis)
949    }
950
951    /// Min reduction along an axis
952    pub fn min_axis<T: GpuDataType>(
953        &self,
954        buffer: &GpuBuffer<T>,
955        shape: &[usize],
956        axis: usize,
957    ) -> Result<GpuBuffer<T>, GpuError> {
958        self.min_axis_cpu_fallback(buffer, shape, axis)
959    }
960
961    /// Broadcast a buffer to a different shape
962    pub fn broadcast<T: GpuDataType>(
963        &self,
964        buffer: &GpuBuffer<T>,
965        from_shape: &[usize],
966        to_shape: &[usize],
967    ) -> Result<GpuBuffer<T>, GpuError> {
968        self.broadcast_cpu_fallback(buffer, from_shape, to_shape)
969    }
970
971    /// Scale a buffer by a scalar value
972    pub fn scale<T: GpuDataType>(
973        &self,
974        buffer: &GpuBuffer<T>,
975        scalar: T,
976    ) -> Result<GpuBuffer<T>, GpuError> {
977        self.scale_cpu_fallback(buffer, scalar)
978    }
979
980    /// General matrix multiplication: C = A @ B
981    pub fn gemm<T: GpuDataType>(
982        &self,
983        a: &GpuBuffer<T>,
984        b: &GpuBuffer<T>,
985        m: usize,
986        k: usize,
987        n: usize,
988    ) -> Result<GpuBuffer<T>, GpuError> {
989        self.gemm_cpu_fallback(a, b, m, k, n)
990    }
991
992    /// GEMM with transposed B: C = A @ B^T
993    pub fn gemm_transpose_b<T: GpuDataType>(
994        &self,
995        a: &GpuBuffer<T>,
996        b: &GpuBuffer<T>,
997        m: usize,
998        k: usize,
999        n: usize,
1000    ) -> Result<GpuBuffer<T>, GpuError> {
1001        self.gemm_transpose_b_cpu_fallback(a, b, m, k, n)
1002    }
1003
1004    /// GEMM with transposed A: C = A^T @ B
1005    pub fn gemm_transpose_a<T: GpuDataType>(
1006        &self,
1007        a: &GpuBuffer<T>,
1008        b: &GpuBuffer<T>,
1009        m: usize,
1010        k: usize,
1011        n: usize,
1012    ) -> Result<GpuBuffer<T>, GpuError> {
1013        self.gemm_transpose_a_cpu_fallback(a, b, m, k, n)
1014    }
1015
1016    /// ReLU activation forward pass
1017    pub fn relu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1018        self.relu_cpu_fallback(input)
1019    }
1020
1021    /// ReLU backward pass
1022    pub fn relu_backward<T: GpuDataType>(
1023        &self,
1024        grad_output: &GpuBuffer<T>,
1025        input: &GpuBuffer<T>,
1026    ) -> Result<GpuBuffer<T>, GpuError> {
1027        self.relu_backward_cpu_fallback(grad_output, input)
1028    }
1029
1030    /// Sigmoid activation forward pass
1031    pub fn sigmoid<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1032        self.sigmoid_cpu_fallback(input)
1033    }
1034
1035    /// Sigmoid backward pass
1036    pub fn sigmoid_backward<T: GpuDataType>(
1037        &self,
1038        grad_output: &GpuBuffer<T>,
1039        input: &GpuBuffer<T>,
1040    ) -> Result<GpuBuffer<T>, GpuError> {
1041        self.sigmoid_backward_cpu_fallback(grad_output, input)
1042    }
1043
1044    /// Tanh activation forward pass
1045    pub fn tanh<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1046        self.tanh_cpu_fallback(input)
1047    }
1048
1049    /// Tanh backward pass
1050    pub fn tanh_backward<T: GpuDataType>(
1051        &self,
1052        grad_output: &GpuBuffer<T>,
1053        input: &GpuBuffer<T>,
1054    ) -> Result<GpuBuffer<T>, GpuError> {
1055        self.tanh_backward_cpu_fallback(grad_output, input)
1056    }
1057
1058    /// GELU activation forward pass
1059    pub fn gelu<T: GpuDataType>(&self, input: &GpuBuffer<T>) -> Result<GpuBuffer<T>, GpuError> {
1060        self.gelu_cpu_fallback(input)
1061    }
1062
1063    /// GELU backward pass
1064    pub fn gelu_backward<T: GpuDataType>(
1065        &self,
1066        grad_output: &GpuBuffer<T>,
1067        input: &GpuBuffer<T>,
1068    ) -> Result<GpuBuffer<T>, GpuError> {
1069        self.gelu_backward_cpu_fallback(grad_output, input)
1070    }
1071}
1072
1073impl fmt::Debug for GpuContext {
1074    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1075        f.debug_struct("GpuContext")
1076            .field("backend", &self.backend)
1077            .finish()
1078    }
1079}
1080
1081// The following trait definitions would be implemented by backend-specific
1082// code in a real implementation
1083
/// Backend-side storage behind a GPU buffer.
///
/// Implementors provide raw, byte-level transfers between host memory and
/// the backend's storage, plus a few optional introspection helpers that
/// have zero-returning defaults.
pub(crate) trait GpuBufferImpl: Send + Sync {
    /// Copies `size` bytes from the host pointer `data` into the buffer.
    ///
    /// # Safety
    ///
    /// `data` must be valid for reads of `size` bytes and `size` must not
    /// exceed what the buffer can hold; the CPU implementation in this file
    /// hands both values directly to `std::ptr::copy_nonoverlapping`.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize);

    /// Copies `size` bytes from the buffer to the host pointer `data`.
    ///
    /// # Safety
    ///
    /// `data` must be valid for writes of `size` bytes and `size` must not
    /// exceed what the buffer holds; the CPU implementation in this file
    /// hands both values directly to `std::ptr::copy_nonoverlapping`.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize);

    /// Exposes the implementor as `&dyn Any` so callers can downcast to the
    /// concrete buffer type.
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn std::any::Any;

    /// Size of the buffer in bytes.
    ///
    /// Implementations that do not override this report `0`; the default
    /// exists for backward compatibility.
    #[allow(dead_code)]
    fn size(&self) -> usize {
        0
    }

    /// Device pointer of the buffer, for backends that use device pointers.
    ///
    /// Implementations that do not override this report `0`; the default
    /// exists for backward compatibility.
    #[allow(dead_code)]
    fn device_ptr(&self) -> u64 {
        0
    }
}
1108
1109/// GPU kernel implementation trait
1110pub(crate) trait GpuKernelImpl: Send + Sync {
1111    /// Set a buffer parameter
1112    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>);
1113
1114    /// Set a u32 parameter
1115    fn set_u32(&self, name: &str, value: u32);
1116
1117    /// Set an i32 parameter
1118    fn set_i32(&self, name: &str, value: i32);
1119
1120    /// Set an f32 parameter
1121    fn set_f32(&self, name: &str, value: f32);
1122
1123    /// Set an f64 parameter
1124    fn set_f64(&self, name: &str, value: f64);
1125
1126    /// Dispatch the kernel
1127    fn dispatch(&self, workgroups: [u32; 3]);
1128
1129    /// Dispatch the kernel without waiting for completion.
1130    /// The GPU work is submitted and will execute asynchronously.
1131    /// Call `gpu_sync()` on the context to wait for all pending work.
1132    fn dispatch_no_wait(&self, workgroups: [u32; 3]) {
1133        // Default: fall back to synchronous dispatch
1134        self.dispatch(workgroups);
1135    }
1136
1137    /// Try to encode a dispatch into the active batch (if any).
1138    ///
1139    /// Returns `true` if the dispatch was batched, `false` if no batch is
1140    /// active and the caller should use the normal `dispatch()` path.
1141    fn try_batch_dispatch(&self, _workgroups: [u32; 3]) -> bool {
1142        false
1143    }
1144}
1145
1146/// GPU compiler implementation trait
1147pub(crate) trait GpuCompilerImpl: Send + Sync {
1148    /// Compile a kernel from source
1149    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError>;
1150
1151    /// Compile a typed kernel
1152    fn compile_typed(
1153        &self,
1154        name: &str,
1155        input_type: std::any::TypeId,
1156        output_type: std::any::TypeId,
1157    ) -> Arc<dyn GpuKernelImpl>;
1158}
1159
1160/// GPU context implementation trait
1161pub(crate) trait GpuContextImpl: Send + Sync {
1162    /// Create a buffer
1163    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl>;
1164
1165    /// Create a compiler
1166    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl>;
1167
1168    /// Wait for all pending GPU operations to complete.
1169    fn gpu_sync(&self) -> Result<(), GpuError> {
1170        Ok(()) // Default: no-op (synchronous backends already block)
1171    }
1172
1173    /// Begin batch dispatch mode.
1174    ///
1175    /// While active, kernel dispatches are encoded into a shared command
1176    /// buffer instead of creating individual ones.
1177    fn begin_batch(&self) -> Result<(), GpuError> {
1178        Ok(()) // Default: no-op (backends without batch support)
1179    }
1180
1181    /// End batch dispatch mode.
1182    ///
1183    /// Submits the shared command buffer and waits for all encoded
1184    /// dispatches to complete on the GPU.
1185    fn end_batch(&self) -> Result<(), GpuError> {
1186        Ok(()) // Default: no-op
1187    }
1188
1189    /// Support dynamic downcasting of concrete context implementations
1190    fn as_any(&self) -> &dyn std::any::Any
1191    where
1192        Self: 'static + Sized,
1193    {
1194        self
1195    }
1196}
1197
1198// CPU fallback implementation
1199
/// Context implementation used by the CPU fallback.
///
/// The type is a field-less unit struct and therefore carries no state.
struct CpuContext;

impl CpuContext {
    /// Builds the stateless CPU context.
    fn new() -> Self {
        CpuContext
    }
}
1209
1210impl GpuContextImpl for CpuContext {
1211    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
1212        Arc::new(CpuBuffer::new(size))
1213    }
1214
1215    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
1216        Arc::new(CpuCompiler)
1217    }
1218}
1219
/// Host-memory buffer used by the CPU fallback.
struct CpuBuffer {
    /// Raw byte storage of the buffer.
    data: Vec<u8>,
}

impl CpuBuffer {
    /// Allocates a buffer of `size` bytes, all initialised to zero.
    fn new(size: usize) -> Self {
        let data = vec![0u8; size];
        CpuBuffer { data }
    }
}
1233
1234impl GpuBufferImpl for CpuBuffer {
1235    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1236        let mut_self = self as *const Self as *mut Self;
1237        let data_ptr = (*mut_self).data.as_mut_ptr();
1238        std::ptr::copy_nonoverlapping(data, data_ptr, size);
1239    }
1240
1241    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1242        let data_ptr = self.data.as_ptr();
1243        std::ptr::copy_nonoverlapping(data_ptr, data, size);
1244    }
1245
1246    fn as_any(&self) -> &dyn std::any::Any {
1247        self
1248    }
1249
1250    fn size(&self) -> usize {
1251        self.data.len()
1252    }
1253
1254    fn device_ptr(&self) -> u64 {
1255        self.data.as_ptr() as u64
1256    }
1257}
1258
1259/// CPU compiler implementation
1260struct CpuCompiler;
1261
1262impl GpuCompilerImpl for CpuCompiler {
1263    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
1264        // In a real implementation, we would parse and execute the kernel
1265        // For now, just return a dummy implementation
1266        Ok(Arc::new(CpuKernel))
1267    }
1268
1269    fn compile_typed(
1270        &self,
1271        _name: &str,
1272        _input_type: std::any::TypeId,
1273        _output_type: std::any::TypeId,
1274    ) -> Arc<dyn GpuKernelImpl> {
1275        // In a real implementation, we would select an appropriate implementation
1276        // For now, just return a dummy implementation
1277        Arc::new(CpuKernel)
1278    }
1279}
1280
1281/// CPU kernel implementation
1282struct CpuKernel;
1283
1284impl GpuKernelImpl for CpuKernel {
1285    fn set_buffer(&self, _name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
1286        // In a real implementation, we would store the buffer
1287    }
1288
1289    fn set_u32(&self, _name: &str, value: u32) {
1290        // In a real implementation, we would store the value
1291    }
1292
1293    fn set_i32(&self, _name: &str, value: i32) {
1294        // In a real implementation, we would store the value
1295    }
1296
1297    fn set_f32(&self, _name: &str, value: f32) {
1298        // In a real implementation, we would store the value
1299    }
1300
1301    fn set_f64(&self, _name: &str, value: f64) {
1302        // In a real implementation, we would store the value
1303    }
1304
1305    fn dispatch(&self, workgroups: [u32; 3]) {
1306        // In a real implementation, we would execute the kernel
1307    }
1308}
1309
1310// In a real implementation, we would have implementations for other backends
1311// such as CUDA, WebGPU, Metal, and OpenCL.
1312
#[cfg(test)]
mod tests {
    use super::*;

    /// `GpuBackend::preferred` must yield one of the known variants.
    #[test]
    fn test_gpu_backend_preferred() {
        let backend = GpuBackend::preferred();
        // The match is exhaustive on purpose: adding a variant to
        // `GpuBackend` forces this test to be revisited.
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    /// `Default` must agree with `preferred`.
    #[test]
    fn test_gpu_backend_default() {
        let backend = GpuBackend::default();
        assert_eq!(backend, GpuBackend::preferred());
    }

    /// Availability checks for the individual backends.
    #[test]
    fn test_gpu_backend_is_available() {
        let backend = GpuBackend::Cpu;
        assert!(backend.is_available());

        // Test other backends - availability depends on runtime, not just feature flags
        // CUDA backend was retired from scirs2-core in 0.6.x — always unavailable now.
        assert!(!GpuBackend::Cuda.is_available());

        #[cfg(feature = "rocm")]
        {
            // ROCm feature enabled doesn't guarantee runtime availability
            let _ = GpuBackend::Rocm.is_available(); // Just check without asserting
        }
        #[cfg(not(feature = "rocm"))]
        assert!(!GpuBackend::Rocm.is_available());

        #[cfg(all(feature = "metal", target_os = "macos"))]
        assert!(GpuBackend::Metal.is_available());
        #[cfg(not(all(feature = "metal", target_os = "macos")))]
        assert!(!GpuBackend::Metal.is_available());
    }

    /// `Display` output of every backend variant.
    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(GpuBackend::Cuda.to_string(), "CUDA");
        assert_eq!(GpuBackend::Rocm.to_string(), "ROCm");
        assert_eq!(GpuBackend::Wgpu.to_string(), "WebGPU");
        assert_eq!(GpuBackend::Metal.to_string(), "Metal");
        assert_eq!(GpuBackend::OpenCL.to_string(), "OpenCL");
        assert_eq!(GpuBackend::Cpu.to_string(), "CPU");
    }

    /// `GpuError` variants must map onto the expected `CoreError` variants.
    #[test]
    fn test_gpuerror_from_conversion() {
        let gpuerror = GpuError::BackendNotAvailable("CUDA".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::ComputationError(_) => {}
            _ => panic!("Expected ComputationError"),
        }

        let gpuerror = GpuError::OutOfMemory("8GB required".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::MemoryError(_) => {}
            _ => panic!("Expected MemoryError"),
        }

        let gpuerror = GpuError::InvalidParameter("batch_size must be > 0".to_string());
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::InvalidArgument(_) => {}
            _ => panic!("Expected InvalidArgument"),
        }

        let gpuerror = GpuError::UnsupportedDataType(kernels::DataType::Float16);
        let coreerror: CoreError = gpuerror.into();
        match coreerror {
            CoreError::TypeError(_) => {}
            _ => panic!("Expected TypeError"),
        }
    }

    /// Compile-time check that the primitive numeric types are `GpuDataType`.
    #[test]
    fn test_gpu_datatype_trait() {
        // Test that various types implement GpuDataType
        fn assert_gpu_datatype<T: GpuDataType>() {}

        assert_gpu_datatype::<f32>();
        assert_gpu_datatype::<f64>();
        assert_gpu_datatype::<i32>();
        assert_gpu_datatype::<u32>();
        assert_gpu_datatype::<u8>();
        assert_gpu_datatype::<i8>();
        assert_gpu_datatype::<u16>();
        assert_gpu_datatype::<i16>();
        assert_gpu_datatype::<u64>();
        assert_gpu_datatype::<i64>();
    }

    /// A buffer reports the element count it was created with.
    #[test]
    fn test_gpu_buffer_creation() {
        let inner = Arc::new(CpuBuffer::new(100));
        let buffer = GpuBuffer::<f32>::new(inner, 25);

        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
    }

    /// A zero-length buffer is empty.
    #[test]
    fn test_gpu_buffer_empty() {
        let inner = Arc::new(CpuBuffer::new(0));
        let buffer = GpuBuffer::<f32>::new(inner, 0);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
    }

    /// Data written with `copy_from_host` is read back by `copy_to_host`.
    #[test]
    fn test_gpu_buffer_copy_operations() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        // Fail loudly if the transfer itself errors instead of surfacing it
        // later as a confusing data mismatch.
        buffer
            .copy_from_host(&data)
            .expect("copy_from_host should succeed for an exactly-sized buffer");

        let mut result = vec![0.0f32; 4];
        buffer
            .copy_to_host(&mut result)
            .expect("copy_to_host should succeed for an exactly-sized slice");

        assert_eq!(result, data);
    }

    /// `to_vec` returns the data previously copied into the buffer.
    #[test]
    fn test_gpu_buffer_to_vec() {
        let inner = Arc::new(CpuBuffer::new(12));
        let buffer = GpuBuffer::<f32>::new(inner, 3);

        let data = vec![5.0f32, 6.0, 7.0];
        buffer
            .copy_from_host(&data)
            .expect("copy_from_host should succeed for an exactly-sized buffer");

        let result = buffer.to_vec();
        assert_eq!(result, data);
    }

    /// Copying more elements than the buffer holds must be rejected.
    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_from_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let data = vec![1.0f32, 2.0, 3.0]; // 3 elements > 2 buffer size
        buffer.copy_from_host(&data).expect("Operation failed");
    }

    /// Reading into a slice larger than the buffer must be rejected.
    #[test]
    #[should_panic(expected = "Data size exceeds buffer size")]
    fn test_gpu_buffer_copy_to_host_overflow() {
        let inner = Arc::new(CpuBuffer::new(8));
        let buffer = GpuBuffer::<f32>::new(inner, 2);

        let mut data = vec![0.0f32; 3]; // 3 elements > 2 buffer size
        buffer.copy_to_host(&mut data).expect("Operation failed");
    }

    /// Parameter setters and dispatch on a kernel handle must not panic.
    #[test]
    fn test_gpu_kernel_handle() {
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        // Test setting various parameter types
        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 100);
        handle.set_i32("offset", -5);
        handle.set_f32("scale", 2.5);
        handle.set_f64("precision", 0.0001);

        // Test dispatch
        handle.dispatch([16, 8, 1]);
    }

    /// Basic properties of a CPU-backed context.
    #[test]
    fn test_gpu_context_cpu_backend() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");
        assert_eq!(context.backend(), GpuBackend::Cpu);
        assert_eq!(context.backend_name(), "CPU");

        // Test memory query methods
        assert_eq!(context.get_available_memory(), Some(1024 * 1024 * 1024));
        assert_eq!(context.get_total_memory(), Some(4 * 1024 * 1024 * 1024));
    }

    /// Buffers created through the context have the requested length and
    /// keep the data they were initialised with.
    #[test]
    fn test_gpu_context_buffer_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let buffer = context.create_buffer::<f32>(100);
        assert_eq!(buffer.len(), 100);

        let data = vec![1.0f32; 50];
        let buffer_from_slice = context.create_buffer_from_slice(&data);
        assert_eq!(buffer_from_slice.len(), 50);

        let result = buffer_from_slice.to_vec();
        assert_eq!(result, data);
    }

    /// Creating a context for the retired CUDA backend must fail.
    #[test]
    fn test_gpu_context_unsupported_backend() {
        // CUDA was retired from scirs2-core in 0.6.x; constructing a context for it
        // must fail with explicit oxicuda migration guidance.
        let result = GpuContext::new(GpuBackend::Cuda);
        match result {
            Err(GpuError::BackendNotAvailable(msg)) => {
                assert!(
                    msg.contains("oxicuda"),
                    "expected oxicuda guidance, got: {msg}"
                );
            }
            other => {
                panic!("expected BackendNotAvailable for retired CUDA backend, got: {other:?}")
            }
        }
    }

    /// Source and typed compilation both yield dispatchable kernels.
    #[test]
    fn test_gpu_compiler() {
        let compiler_impl = Arc::new(CpuCompiler);
        let compiler = GpuCompiler::new(compiler_impl);

        // Test compiling from source
        let kernel = compiler
            .compile("dummy kernel source")
            .expect("Operation failed");
        kernel.dispatch([1, 1, 1]);

        // Test typed compilation
        let typed_kernel = compiler.compile_kernel::<f32, f32>("vector_add");
        typed_kernel.dispatch([32, 1, 1]);
    }

    /// `execute` hands a working compiler to the closure.
    #[test]
    fn test_gpu_context_execute() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        let result = context.execute(|compiler| compiler.compile("test kernel").is_ok());

        assert!(result);
    }

    /// Looking up an unknown kernel reports `KernelNotFound`.
    #[test]
    fn test_gpu_context_kernel_registry() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Operation failed");

        // Test getting a non-existent kernel
        let result = context.get_kernel("non_existent_kernel");
        assert!(result.is_err());
        match result {
            Err(GpuError::KernelNotFound(_)) => {}
            _ => panic!("Expected KernelNotFound error"),
        }
    }

    /// A fresh `CpuBuffer` has the requested size and is zero-filled.
    #[test]
    fn test_cpu_buffer_implementation() {
        let buffer = CpuBuffer::new(256);
        assert_eq!(buffer.data.len(), 256);

        // Test that initial data is zeroed
        assert!(buffer.data.iter().all(|&b| b == 0));
    }

    /// `Display` output of selected `GpuError` variants.
    #[test]
    fn test_gpuerror_display() {
        let error = GpuError::BackendNotAvailable("CUDA".to_string());
        assert_eq!(error.to_string(), "GPU backend CUDA is not available");

        let error = GpuError::OutOfMemory("allocation failed".to_string());
        assert_eq!(error.to_string(), "GPU out of memory: allocation failed");

        let error = GpuError::KernelCompilationError("syntax error".to_string());
        assert_eq!(error.to_string(), "Kernel compilation error: syntax error");

        let error = GpuError::KernelNotFound("gemm".to_string());
        assert_eq!(error.to_string(), "Kernel not found: gemm");
    }

    /// Equality and copy semantics of `GpuBackend`.
    #[test]
    fn test_backend_equality() {
        assert_eq!(GpuBackend::Cuda, GpuBackend::Cuda);
        assert_ne!(GpuBackend::Cuda, GpuBackend::Rocm);

        // `GpuBackend` is `Copy`: plain assignment duplicates the value and
        // leaves `backend` usable afterwards.
        let backend = GpuBackend::Metal;
        let cloned = backend;
        let copied = backend;
        assert_eq!(backend, cloned);
        assert_eq!(backend, copied);
    }

    /// `GpuBackend` works as a hash-set key and deduplicates correctly.
    #[test]
    fn test_backend_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(GpuBackend::Cuda);
        set.insert(GpuBackend::Rocm);
        set.insert(GpuBackend::Cuda); // Duplicate

        assert_eq!(set.len(), 2); // Should only have 2 unique entries
        assert!(set.contains(&GpuBackend::Cuda));
        assert!(set.contains(&GpuBackend::Rocm));
    }

    /// `Debug` and `Clone` behaviour of `GpuBuffer`.
    #[test]
    fn test_gpu_buffer_debug_clone() {
        let inner = Arc::new(CpuBuffer::new(16));
        let buffer = GpuBuffer::<f32>::new(inner, 4);

        // Test Debug implementation
        let debug_str = format!("{buffer:?}");
        assert!(debug_str.contains("GpuBuffer"));
        assert!(debug_str.contains("size"));

        // Test Clone implementation
        let cloned = buffer.clone();
        assert_eq!(cloned.len(), buffer.len());
        assert_eq!(cloned.len(), 4);

        // The clone shares the underlying storage with the original: data
        // written through `buffer` must be readable through `cloned`.
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        buffer
            .copy_from_host(&data)
            .expect("copy_from_host should succeed for an exactly-sized buffer");

        let mut result = vec![0.0f32; 4];
        cloned
            .copy_to_host(&mut result)
            .expect("copy_to_host should succeed for an exactly-sized slice");
        assert_eq!(result, data);
    }

    /// `Debug` output of `GpuContext` names the type and its backend.
    #[test]
    fn test_gpu_context_debug() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");

        // Test Debug implementation
        let debug_str = format!("{context:?}");
        assert!(debug_str.contains("GpuContext"));
        assert!(debug_str.contains("backend"));
        assert!(debug_str.contains("Cpu"));
    }

    /// Batch mode can be entered, used and left on the CPU backend.
    #[test]
    fn test_gpu_context_batch_dispatch() {
        // Create context with CPU backend (always available)
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");

        // Begin batch mode
        let begin_result = context.begin_batch();
        assert!(
            begin_result.is_ok(),
            "begin_batch should succeed on CPU backend"
        );

        // Compile a kernel and dispatch it inside the batch
        let dispatch_result = context.execute(|compiler| {
            compiler.compile("dummy kernel source").map(|kernel| {
                kernel.dispatch([4, 1, 1]);
            })
        });
        assert!(
            dispatch_result.is_ok(),
            "kernel dispatch inside batch should succeed"
        );

        // End batch — submits and waits
        let end_result = context.end_batch();
        assert!(
            end_result.is_ok(),
            "end_batch should succeed on CPU backend"
        );
    }

    /// `gpu_sync` succeeds on the CPU backend.
    #[test]
    fn test_gpu_context_gpu_sync() {
        // Create context with CPU backend
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create CPU context");

        // gpu_sync should complete without error on CPU backend
        let result = context.gpu_sync();
        assert!(result.is_ok(), "gpu_sync should return Ok on CPU backend");
    }

    /// `dispatch_no_wait` on a CPU kernel must not panic.
    #[test]
    fn test_gpu_kernel_dispatch_no_wait() {
        // Create a kernel handle backed by CpuKernel
        let kernel = Arc::new(CpuKernel);
        let handle = GpuKernelHandle::new(kernel);

        // Bind a buffer and scalar so the kernel has state
        let buffer = GpuBuffer::<f32>::new(Arc::new(CpuBuffer::new(16)), 4);
        handle.set_buffer("input", &buffer);
        handle.set_u32("size", 4);

        // dispatch_no_wait returns () — just verify no panic
        handle.dispatch_no_wait([4, 1, 1]);
    }

    /// Device info reported for the CPU backend.
    #[test]
    fn test_gpu_device_get_info_cpu() {
        let device = GpuDevice::new(GpuBackend::Cpu, 0);
        let info = device
            .get_info()
            .expect("get_info should succeed for the CPU backend");

        // Backend and basic identity must round-trip.
        assert_eq!(info.backend, GpuBackend::Cpu);
        assert_eq!(info.device_name, "CPU");
        assert_eq!(info.device_type, "CPU");

        // The struct must be well-formed: non-empty descriptors and a valid
        // work-group size. The CPU fallback emulates both precisions.
        assert!(!info.device_name.is_empty());
        assert!(!info.compute_capability.is_empty());
        assert!(info.max_work_group_size >= 1);
        assert!(info.supports_fp64);
        assert!(info.supports_fp16);
    }

    /// Device info is well-formed and stable across repeated queries.
    #[test]
    fn test_gpu_device_info_is_deterministic_per_backend() {
        // Every backend must yield a well-formed, non-empty, deterministic info.
        for backend in [
            GpuBackend::Cpu,
            GpuBackend::Cuda,
            GpuBackend::Rocm,
            GpuBackend::Wgpu,
            GpuBackend::Metal,
            GpuBackend::OpenCL,
        ] {
            let device = GpuDevice::new(backend, 0);
            let info = device.get_info().expect("get_info should not fail");
            let info_again = device.get_info().expect("get_info should not fail");

            assert_eq!(info.backend, backend);
            assert!(!info.device_name.is_empty());
            assert!(!info.device_type.is_empty());
            assert!(!info.compute_capability.is_empty());
            assert!(info.max_work_group_size >= 1);

            // Deterministic: two queries produce identical descriptors.
            assert_eq!(info.device_name, info_again.device_name);
            assert_eq!(info.max_work_group_size, info_again.max_work_group_size);
            assert_eq!(info.supports_fp64, info_again.supports_fp64);
        }
    }
}