sklears_ensemble/gpu_acceleration.rs

//! GPU acceleration for ensemble methods
//!
//! This module provides GPU acceleration capabilities for ensemble training and inference,
//! with support for multiple GPU backends and fallback to CPU when GPU is unavailable.
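//!
//! # Example
//!
//! A minimal usage sketch (illustrative only: the module path is assumed from
//! the file location, and the default configuration selects the CPU fallback):
//!
//! ```ignore
//! use sklears_ensemble::gpu_acceleration::{detect_available_backends, GpuConfig, GpuEnsembleTrainer};
//!
//! // Pick the first reported backend; CpuFallback is always present.
//! let backend = detect_available_backends()[0];
//! let config = GpuConfig { backend, ..GpuConfig::default() };
//! let trainer = GpuEnsembleTrainer::new(config).unwrap();
//! assert!(trainer.context().device_info().compute_units >= 1);
//! ```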

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use sklears_core::types::{Float, Int};
use std::sync::Arc;

/// GPU backend enumeration
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GpuBackend {
    /// CUDA backend for NVIDIA GPUs
    Cuda,
    /// OpenCL backend for cross-platform GPU support
    OpenCL,
    /// Metal backend for Apple GPUs
    Metal,
    /// Vulkan backend for cross-platform graphics and compute
    Vulkan,
    /// CPU fallback when no GPU is available
    CpuFallback,
}

/// GPU configuration
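///
/// # Example
///
/// A minimal sketch of overriding a few fields with struct-update syntax
/// (illustrative only; the module path is assumed from the file location):
///
/// ```ignore
/// use sklears_ensemble::gpu_acceleration::{GpuBackend, GpuConfig};
///
/// let config = GpuConfig {
///     backend: GpuBackend::Cuda,
///     batch_size: 4096,
///     mixed_precision: true,
///     ..GpuConfig::default()
/// };
/// assert_eq!(config.device_id, 0); // untouched fields keep their defaults
/// ```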
#[derive(Debug, Clone)]
pub struct GpuConfig {
    /// GPU backend to use
    pub backend: GpuBackend,
    /// Device ID (for systems with multiple GPUs)
    pub device_id: usize,
    /// Memory limit for GPU usage (in MB)
    pub memory_limit_mb: Option<usize>,
    /// Batch size for GPU operations
    pub batch_size: usize,
    /// Number of GPU streams for parallel execution
    pub n_streams: usize,
    /// Enable mixed precision (FP16/FP32)
    pub mixed_precision: bool,
    /// Enable tensor cores (for supported hardware)
    pub tensor_cores: bool,
    /// Memory pool size for efficient allocation (in MB)
    pub memory_pool_size_mb: usize,
    /// Enable profiling for performance analysis
    pub enable_profiling: bool,
}

impl Default for GpuConfig {
    fn default() -> Self {
        Self {
            backend: GpuBackend::CpuFallback,
            device_id: 0,
            memory_limit_mb: None,
            batch_size: 1024,
            n_streams: 4,
            mixed_precision: false,
            tensor_cores: false,
            memory_pool_size_mb: 1024,
            enable_profiling: false,
        }
    }
}

/// GPU acceleration context
pub struct GpuContext {
    config: GpuConfig,
    device_info: GpuDeviceInfo,
    memory_manager: GpuMemoryManager,
    profiler: Option<GpuProfiler>,
}

/// GPU device information
#[derive(Debug, Clone)]
pub struct GpuDeviceInfo {
    /// Device name
    pub name: String,
    /// Total memory in MB
    pub total_memory_mb: usize,
    /// Available memory in MB
    pub available_memory_mb: usize,
    /// Number of compute units/cores
    pub compute_units: usize,
    /// Maximum work group size
    pub max_work_group_size: usize,
    /// Supports mixed precision
    pub supports_mixed_precision: bool,
    /// Supports tensor cores
    pub supports_tensor_cores: bool,
}

/// GPU memory manager
pub struct GpuMemoryManager {
    allocated_bytes: usize,
    peak_allocated_bytes: usize,
    pool_size_bytes: usize,
    free_blocks: Vec<GpuMemoryBlock>,
}

/// GPU memory block
#[derive(Debug)]
pub struct GpuMemoryBlock {
    pub ptr: usize, // GPU memory pointer (abstracted)
    pub size_bytes: usize,
    pub in_use: bool,
}

/// GPU profiler for performance analysis
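///
/// # Example
///
/// A minimal sketch (illustrative only; the module path is assumed from the
/// file location). Events are only recorded between `start` and `stop`:
///
/// ```ignore
/// use sklears_ensemble::gpu_acceleration::GpuProfiler;
///
/// let mut profiler = GpuProfiler::new();
/// profiler.record_kernel("dropped".to_string(), 1.0); // ignored: not started yet
/// profiler.start();
/// profiler.record_kernel("histogram".to_string(), 2.0);
/// let results = profiler.stop();
/// assert_eq!(results.kernel_times.len(), 1);
/// ```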
pub struct GpuProfiler {
    enabled: bool,
    kernel_times: Vec<(String, f64)>,
    memory_transfers: Vec<(String, usize, f64)>,
}

/// GPU kernel operations
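///
/// # Example
///
/// A minimal sketch of querying a kernel's cost model before launching it
/// (illustrative only; the module path is assumed from the file location):
///
/// ```ignore
/// use sklears_ensemble::gpu_acceleration::{GpuKernel, HistogramKernel};
///
/// let kernel = HistogramKernel { n_features: 100, n_bins: 256, n_samples: 10_000 };
/// println!(
///     "~{:.2} ms, ~{} MB",
///     kernel.estimated_time_ms(),
///     kernel.memory_requirements_mb()
/// );
/// ```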
pub trait GpuKernel {
    /// Execute kernel on GPU
    fn execute(&self, context: &GpuContext) -> Result<()>;

    /// Get estimated execution time
    fn estimated_time_ms(&self) -> f64;

    /// Get memory requirements
    fn memory_requirements_mb(&self) -> usize;
}

/// GPU-accelerated gradient boosting kernels
pub struct GradientBoostingKernels {
    /// Histogram computation kernel
    pub histogram_kernel: HistogramKernel,
    /// Split finding kernel
    pub split_kernel: SplitFindingKernel,
    /// Tree update kernel
    pub tree_update_kernel: TreeUpdateKernel,
    /// Prediction kernel
    pub prediction_kernel: PredictionKernel,
}

/// Histogram computation kernel
#[derive(Debug)]
pub struct HistogramKernel {
    pub n_features: usize,
    pub n_bins: usize,
    pub n_samples: usize,
}

/// Split finding kernel
#[derive(Debug)]
pub struct SplitFindingKernel {
    pub n_features: usize,
    pub n_bins: usize,
    pub regularization: Float,
}

/// Tree update kernel
#[derive(Debug)]
pub struct TreeUpdateKernel {
    pub max_depth: usize,
    pub learning_rate: Float,
}

/// Prediction kernel
#[derive(Debug)]
pub struct PredictionKernel {
    pub n_trees: usize,
    pub n_classes: usize,
}

/// GPU tensor operations
pub struct GpuTensorOps {
    context: Arc<GpuContext>,
}

/// GPU-accelerated ensemble trainer
pub struct GpuEnsembleTrainer {
    context: Arc<GpuContext>,
    kernels: GradientBoostingKernels,
    tensor_ops: GpuTensorOps,
}

impl GpuContext {
    /// Create new GPU context
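    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative only; the module path is assumed from
    /// the file location). The default config uses the CPU fallback, so no
    /// GPU driver is required:
    ///
    /// ```ignore
    /// use sklears_ensemble::gpu_acceleration::{GpuConfig, GpuContext};
    ///
    /// let ctx = GpuContext::new(GpuConfig::default()).unwrap();
    /// assert!(!ctx.is_gpu_available());
    /// assert_eq!(ctx.device_info().name, "CPU Fallback");
    /// ```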
    pub fn new(config: GpuConfig) -> Result<Self> {
        let device_info = Self::detect_device(&config)?;
        let memory_manager = GpuMemoryManager::new(config.memory_pool_size_mb * 1024 * 1024);
        let profiler = if config.enable_profiling {
            Some(GpuProfiler::new())
        } else {
            None
        };

        Ok(Self {
            config,
            device_info,
            memory_manager,
            profiler,
        })
    }

    /// Detect and initialize GPU device
    fn detect_device(config: &GpuConfig) -> Result<GpuDeviceInfo> {
        match config.backend {
            GpuBackend::Cuda => Self::detect_cuda_device(config.device_id),
            GpuBackend::OpenCL => Self::detect_opencl_device(config.device_id),
            GpuBackend::Metal => Self::detect_metal_device(config.device_id),
            GpuBackend::Vulkan => Self::detect_vulkan_device(config.device_id),
            GpuBackend::CpuFallback => Ok(Self::create_cpu_fallback_info()),
        }
    }

    /// Detect CUDA device
    fn detect_cuda_device(device_id: usize) -> Result<GpuDeviceInfo> {
        // In a real implementation, this would use CUDA APIs
        // For now, return a mock device info
        Ok(GpuDeviceInfo {
            name: format!("CUDA Device {}", device_id),
            total_memory_mb: 8192,
            available_memory_mb: 7168,
            compute_units: 80,
            max_work_group_size: 1024,
            supports_mixed_precision: true,
            supports_tensor_cores: true,
        })
    }

    /// Detect OpenCL device
    fn detect_opencl_device(device_id: usize) -> Result<GpuDeviceInfo> {
        // In a real implementation, this would use OpenCL APIs
        Ok(GpuDeviceInfo {
            name: format!("OpenCL Device {}", device_id),
            total_memory_mb: 4096,
            available_memory_mb: 3584,
            compute_units: 64,
            max_work_group_size: 256,
            supports_mixed_precision: false,
            supports_tensor_cores: false,
        })
    }

    /// Detect Metal device
    fn detect_metal_device(device_id: usize) -> Result<GpuDeviceInfo> {
        // In a real implementation, this would use Metal APIs
        Ok(GpuDeviceInfo {
            name: format!("Metal Device {}", device_id),
            total_memory_mb: 16384, // Unified memory on Apple Silicon
            available_memory_mb: 14336,
            compute_units: 32,
            max_work_group_size: 1024,
            supports_mixed_precision: true,
            supports_tensor_cores: false,
        })
    }

    /// Detect Vulkan device
    fn detect_vulkan_device(device_id: usize) -> Result<GpuDeviceInfo> {
        // In a real implementation, this would use Vulkan APIs
        Ok(GpuDeviceInfo {
            name: format!("Vulkan Device {}", device_id),
            total_memory_mb: 6144,
            available_memory_mb: 5376,
            compute_units: 56,
            max_work_group_size: 512,
            supports_mixed_precision: true,
            supports_tensor_cores: false,
        })
    }

    /// Create CPU fallback device info
    fn create_cpu_fallback_info() -> GpuDeviceInfo {
        GpuDeviceInfo {
            name: "CPU Fallback".to_string(),
            total_memory_mb: 8192, // Assume 8GB system RAM
            available_memory_mb: 6144,
            compute_units: std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4),
            max_work_group_size: 1,
            supports_mixed_precision: false,
            supports_tensor_cores: false,
        }
    }

    /// Check if GPU is available
    pub fn is_gpu_available(&self) -> bool {
        self.config.backend != GpuBackend::CpuFallback
    }

    /// Get device information
    pub fn device_info(&self) -> &GpuDeviceInfo {
        &self.device_info
    }

    /// Allocate GPU memory
    pub fn allocate_memory(&mut self, size_bytes: usize) -> Result<usize> {
        self.memory_manager.allocate(size_bytes)
    }

    /// Free GPU memory
    pub fn free_memory(&mut self, ptr: usize) -> Result<()> {
        self.memory_manager.free(ptr)
    }

    /// Start profiling
    pub fn start_profiling(&mut self) {
        if let Some(ref mut profiler) = self.profiler {
            profiler.start();
        }
    }

    /// Stop profiling and get results
    pub fn stop_profiling(&mut self) -> Option<ProfilingResults> {
        self.profiler.as_mut().map(|p| p.stop())
    }
}

impl GpuMemoryManager {
    /// Create new memory manager
    pub fn new(pool_size_bytes: usize) -> Self {
        Self {
            allocated_bytes: 0,
            peak_allocated_bytes: 0,
            pool_size_bytes,
            free_blocks: Vec::new(),
        }
    }

    /// Allocate memory block
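    ///
    /// # Example
    ///
    /// A minimal sketch of the simulated allocate/free cycle (illustrative
    /// only; the module path is assumed from the file location):
    ///
    /// ```ignore
    /// use sklears_ensemble::gpu_acceleration::GpuMemoryManager;
    ///
    /// let mut pool = GpuMemoryManager::new(1024 * 1024); // 1 MB pool
    /// let ptr = pool.allocate(4096).unwrap();
    /// pool.free(ptr).unwrap();
    /// assert!(pool.allocate(2 * 1024 * 1024).is_err()); // exceeds the pool
    /// ```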
    pub fn allocate(&mut self, size_bytes: usize) -> Result<usize> {
        if self.allocated_bytes + size_bytes > self.pool_size_bytes {
            return Err(SklearsError::InvalidInput(
                "GPU memory allocation failed: out of memory".to_string(),
            ));
        }

        // Reuse a suitable free block if one exists
        for block in &mut self.free_blocks {
            if !block.in_use && block.size_bytes >= size_bytes {
                block.in_use = true;
                // Account for the whole block so the books balance with `free`,
                // which subtracts `block.size_bytes` rather than the requested size
                self.allocated_bytes += block.size_bytes;
                self.peak_allocated_bytes = self.peak_allocated_bytes.max(self.allocated_bytes);
                return Ok(block.ptr);
            }
        }

        // Allocate new block
        let ptr = self.free_blocks.len(); // Simple pointer simulation
        self.free_blocks.push(GpuMemoryBlock {
            ptr,
            size_bytes,
            in_use: true,
        });

        self.allocated_bytes += size_bytes;
        self.peak_allocated_bytes = self.peak_allocated_bytes.max(self.allocated_bytes);

        Ok(ptr)
    }

    /// Free memory block
    pub fn free(&mut self, ptr: usize) -> Result<()> {
        if let Some(block) = self.free_blocks.get_mut(ptr) {
            if block.in_use {
                block.in_use = false;
                self.allocated_bytes = self.allocated_bytes.saturating_sub(block.size_bytes);
                Ok(())
            } else {
                Err(SklearsError::InvalidInput(
                    "Attempted to free already freed memory".to_string(),
                ))
            }
        } else {
            Err(SklearsError::InvalidInput(
                "Invalid memory pointer".to_string(),
            ))
        }
    }

    /// Get memory usage statistics as `(allocated, peak, pool_size)` in bytes
    pub fn memory_stats(&self) -> (usize, usize, usize) {
        (
            self.allocated_bytes,
            self.peak_allocated_bytes,
            self.pool_size_bytes,
        )
    }
}

impl Default for GpuProfiler {
    fn default() -> Self {
        Self::new()
    }
}

impl GpuProfiler {
    /// Create new profiler
    pub fn new() -> Self {
        Self {
            enabled: false,
            kernel_times: Vec::new(),
            memory_transfers: Vec::new(),
        }
    }

    /// Start profiling
    pub fn start(&mut self) {
        self.enabled = true;
        self.kernel_times.clear();
        self.memory_transfers.clear();
    }

    /// Stop profiling and return results
    pub fn stop(&mut self) -> ProfilingResults {
        self.enabled = false;
        ProfilingResults {
            kernel_times: self.kernel_times.clone(),
            memory_transfers: self.memory_transfers.clone(),
            total_kernel_time: self.kernel_times.iter().map(|(_, t)| t).sum(),
            total_memory_transfer_time: self.memory_transfers.iter().map(|(_, _, t)| t).sum(),
        }
    }

    /// Record kernel execution time
    pub fn record_kernel(&mut self, name: String, time_ms: f64) {
        if self.enabled {
            self.kernel_times.push((name, time_ms));
        }
    }

    /// Record memory transfer
    pub fn record_memory_transfer(&mut self, name: String, bytes: usize, time_ms: f64) {
        if self.enabled {
            self.memory_transfers.push((name, bytes, time_ms));
        }
    }
}

/// Profiling results
#[derive(Debug, Clone)]
pub struct ProfilingResults {
    pub kernel_times: Vec<(String, f64)>,
    pub memory_transfers: Vec<(String, usize, f64)>,
    pub total_kernel_time: f64,
    pub total_memory_transfer_time: f64,
}

impl GpuKernel for HistogramKernel {
    fn execute(&self, _context: &GpuContext) -> Result<()> {
        // In a real implementation, this would execute GPU kernel code
        // For now, simulate execution
        std::thread::sleep(std::time::Duration::from_millis(
            (self.n_features * self.n_bins / 1000) as u64,
        ));
        Ok(())
    }

    fn estimated_time_ms(&self) -> f64 {
        (self.n_features * self.n_bins * self.n_samples) as f64 / 1_000_000.0
    }

    fn memory_requirements_mb(&self) -> usize {
        (self.n_features * self.n_bins * 8) / (1024 * 1024) // 8 bytes per histogram entry
    }
}

impl GpuKernel for SplitFindingKernel {
    fn execute(&self, _context: &GpuContext) -> Result<()> {
        // Simulate split finding computation
        std::thread::sleep(std::time::Duration::from_millis(
            (self.n_features * self.n_bins / 100) as u64,
        ));
        Ok(())
    }

    fn estimated_time_ms(&self) -> f64 {
        (self.n_features * self.n_bins) as f64 / 100_000.0
    }

    fn memory_requirements_mb(&self) -> usize {
        (self.n_features * self.n_bins * 4) / (1024 * 1024) // 4 bytes per split candidate
    }
}

impl GpuKernel for TreeUpdateKernel {
    fn execute(&self, _context: &GpuContext) -> Result<()> {
        // Simulate tree update computation
        std::thread::sleep(std::time::Duration::from_millis(1));
        Ok(())
    }

    fn estimated_time_ms(&self) -> f64 {
        self.max_depth as f64 * 0.1
    }

    fn memory_requirements_mb(&self) -> usize {
        (2_usize.pow(self.max_depth as u32) * 64) / (1024 * 1024) // Tree node storage
    }
}

impl GpuKernel for PredictionKernel {
    fn execute(&self, _context: &GpuContext) -> Result<()> {
        // Simulate prediction computation
        std::thread::sleep(std::time::Duration::from_millis((self.n_trees / 10) as u64));
        Ok(())
    }

    fn estimated_time_ms(&self) -> f64 {
        self.n_trees as f64 * 0.01
    }

    fn memory_requirements_mb(&self) -> usize {
        (self.n_trees * self.n_classes * 4) / (1024 * 1024) // Prediction storage
    }
}

impl GpuTensorOps {
    /// Create new GPU tensor operations
    pub fn new(context: Arc<GpuContext>) -> Self {
        Self { context }
    }

    /// Matrix multiplication on GPU
    pub fn matmul(&self, a: &Array2<Float>, b: &Array2<Float>) -> Result<Array2<Float>> {
        // In a real implementation, this would use GPU BLAS libraries
        // For now, fallback to CPU computation
        Ok(a.dot(b))
    }

    /// Element-wise addition on GPU
    pub fn elementwise_add(&self, a: &Array2<Float>, b: &Array2<Float>) -> Result<Array2<Float>> {
        // GPU element-wise addition
        Ok(a + b)
    }

    /// Sum reduction on GPU, over the given axis or over all elements
    pub fn reduce_sum(&self, array: &Array2<Float>, axis: Option<usize>) -> Result<Array1<Float>> {
        match axis {
            Some(ax) => Ok(array.sum_axis(scirs2_core::ndarray::Axis(ax))),
            None => Ok(Array1::from_elem(1, array.sum())),
        }
    }

    /// Softmax on GPU
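    ///
    /// # Example
    ///
    /// A minimal sketch showing that each row of the result is a probability
    /// distribution (illustrative only; the module path is assumed from the
    /// file location):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use scirs2_core::ndarray::array;
    /// use sklears_ensemble::gpu_acceleration::{GpuConfig, GpuContext, GpuTensorOps};
    ///
    /// let ctx = Arc::new(GpuContext::new(GpuConfig::default()).unwrap());
    /// let ops = GpuTensorOps::new(ctx);
    /// let probs = ops.softmax(&array![[1.0, 2.0], [3.0, 3.0]]).unwrap();
    /// assert!((probs.row(0).sum() - 1.0).abs() < 1e-12); // rows sum to 1
    /// ```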
    pub fn softmax(&self, array: &Array2<Float>) -> Result<Array2<Float>> {
        let mut result = array.clone();

        for mut row in result.rows_mut() {
            // Subtract the row maximum for numerical stability before exponentiating
            let max_val = row.fold(Float::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|x| (x - max_val).exp());
            let sum = row.sum();
            row /= sum;
        }

        Ok(result)
    }
}

impl GpuEnsembleTrainer {
    /// Create new GPU ensemble trainer
    pub fn new(config: GpuConfig) -> Result<Self> {
        let context = Arc::new(GpuContext::new(config)?);
        // Placeholder kernel shapes; a real implementation would size these
        // from the training data and hyperparameters
        let kernels = GradientBoostingKernels {
            histogram_kernel: HistogramKernel {
                n_features: 100,
                n_bins: 256,
                n_samples: 10000,
            },
            split_kernel: SplitFindingKernel {
                n_features: 100,
                n_bins: 256,
                regularization: 0.01,
            },
            tree_update_kernel: TreeUpdateKernel {
                max_depth: 6,
                learning_rate: 0.1,
            },
            prediction_kernel: PredictionKernel {
                n_trees: 100,
                n_classes: 2,
            },
        };
        let tensor_ops = GpuTensorOps::new(context.clone());

        Ok(Self {
            context,
            kernels,
            tensor_ops,
        })
    }

    /// Train gradient boosting on GPU
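    ///
    /// # Example
    ///
    /// A minimal sketch returning one weight vector per boosting round
    /// (illustrative only; the module path is assumed from the file location):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{Array1, Array2};
    /// use sklears_ensemble::gpu_acceleration::{GpuConfig, GpuEnsembleTrainer};
    ///
    /// let trainer = GpuEnsembleTrainer::new(GpuConfig::default()).unwrap();
    /// let x = Array2::zeros((8, 4));
    /// let y = Array1::zeros(8);
    /// let models = trainer.train_gradient_boosting(&x, &y, 3).unwrap();
    /// assert_eq!(models.len(), 3);
    /// ```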
    pub fn train_gradient_boosting(
        &self,
        x: &Array2<Float>,
        y: &Array1<Int>,
        n_estimators: usize,
    ) -> Result<Vec<Array1<Float>>> {
        // Simplified return type: one weight vector per boosting round
        let mut models = Vec::new();

        for _ in 0..n_estimators {
            // Compute histograms on GPU
            self.kernels.histogram_kernel.execute(&self.context)?;

            // Find best splits on GPU
            self.kernels.split_kernel.execute(&self.context)?;

            // Update tree on GPU
            self.kernels.tree_update_kernel.execute(&self.context)?;

            // Create a simple model representation
            models.push(Array1::zeros(x.ncols()));
        }

        Ok(models)
    }

    /// Predict using GPU-accelerated ensemble
    pub fn predict_ensemble(
        &self,
        models: &[Array1<Float>],
        x: &Array2<Float>,
    ) -> Result<Array1<Int>> {
        // Execute prediction kernel
        self.kernels.prediction_kernel.execute(&self.context)?;

        // Simplified prediction logic: sign of the summed linear scores
        let mut predictions = Array1::zeros(x.nrows());

        for (i, row) in x.rows().into_iter().enumerate() {
            let mut sum = 0.0;
            for model in models {
                sum += row.dot(model);
            }
            predictions[i] = if sum > 0.0 { 1 } else { 0 };
        }

        Ok(predictions)
    }

    /// Get GPU context
    pub fn context(&self) -> &GpuContext {
        &self.context
    }

    /// Check GPU availability
    pub fn is_gpu_available(&self) -> bool {
        self.context.is_gpu_available()
    }
}

/// Detect which GPU backends are available on this system
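///
/// # Example
///
/// A minimal sketch of backend selection (illustrative only; the module path
/// is assumed from the file location):
///
/// ```ignore
/// use sklears_ensemble::gpu_acceleration::{detect_available_backends, GpuBackend};
///
/// let backends = detect_available_backends();
/// assert!(backends.contains(&GpuBackend::CpuFallback)); // always present
/// let chosen = backends[0]; // GPU backends are listed before the CPU fallback
/// ```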
pub fn detect_available_backends() -> Vec<GpuBackend> {
    let mut backends = Vec::new();

    // In a real implementation, these would check for actual GPU support

    // Check for CUDA
    if is_cuda_available() {
        backends.push(GpuBackend::Cuda);
    }

    // Check for OpenCL
    if is_opencl_available() {
        backends.push(GpuBackend::OpenCL);
    }

    // Check for Metal (always false off macOS)
    if is_metal_available() {
        backends.push(GpuBackend::Metal);
    }

    // Check for Vulkan
    if is_vulkan_available() {
        backends.push(GpuBackend::Vulkan);
    }

    // Always have CPU fallback
    backends.push(GpuBackend::CpuFallback);

    backends
}

/// Check CUDA availability
fn is_cuda_available() -> bool {
    // In a real implementation, this would check for the CUDA runtime
    false // Placeholder
}

/// Check OpenCL availability
fn is_opencl_available() -> bool {
    // In a real implementation, this would check for OpenCL drivers
    false // Placeholder
}

/// Check Metal availability
#[cfg(target_os = "macos")]
fn is_metal_available() -> bool {
    // In a real implementation, this would check for the Metal framework
    true // Assume available on macOS
}

/// Check Metal availability (always false off macOS)
#[cfg(not(target_os = "macos"))]
fn is_metal_available() -> bool {
    false
}

/// Check Vulkan availability
fn is_vulkan_available() -> bool {
    // In a real implementation, this would check for Vulkan drivers
    false // Placeholder
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_gpu_config_default() {
        let config = GpuConfig::default();
        assert_eq!(config.backend, GpuBackend::CpuFallback);
        assert_eq!(config.device_id, 0);
        assert_eq!(config.batch_size, 1024);
    }

    #[test]
    fn test_gpu_context_creation() {
        let config = GpuConfig::default();
        let context = GpuContext::new(config).unwrap();
        assert!(!context.is_gpu_available()); // Should be CPU fallback
    }

    #[test]
    fn test_memory_manager() {
        let mut manager = GpuMemoryManager::new(1024 * 1024); // 1MB

        let ptr1 = manager.allocate(1024).unwrap();
        let ptr2 = manager.allocate(2048).unwrap();

        assert_ne!(ptr1, ptr2);

        manager.free(ptr1).unwrap();
        manager.free(ptr2).unwrap();

        let (allocated, _, total) = manager.memory_stats();
        assert_eq!(allocated, 0);
        assert_eq!(total, 1024 * 1024);
    }

    #[test]
    fn test_gpu_profiler() {
        let mut profiler = GpuProfiler::new();
        profiler.start();

        profiler.record_kernel("test_kernel".to_string(), 1.5);
        profiler.record_memory_transfer("test_transfer".to_string(), 1024, 0.5);

        let results = profiler.stop();
        assert_eq!(results.kernel_times.len(), 1);
        assert_eq!(results.memory_transfers.len(), 1);
        assert_eq!(results.total_kernel_time, 1.5);
    }

    #[test]
    fn test_histogram_kernel() {
        let kernel = HistogramKernel {
            n_features: 10,
            n_bins: 32,
            n_samples: 1000,
        };

        assert!(kernel.estimated_time_ms() > 0.0);
        // memory_requirements_mb returns usize, so `>= 0` would be vacuous;
        // just check the call succeeds
        let _ = kernel.memory_requirements_mb();
    }

    #[test]
    fn test_gpu_tensor_ops() {
        let config = GpuConfig::default();
        let context = Arc::new(GpuContext::new(config).unwrap());
        let ops = GpuTensorOps::new(context);

        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];

        let result = ops.elementwise_add(&a, &b).unwrap();
        let expected = array![[6.0, 8.0], [10.0, 12.0]];

        assert_eq!(result, expected);
    }

    #[test]
    fn test_available_backends() {
        let backends = detect_available_backends();
        assert!(!backends.is_empty());
        assert!(backends.contains(&GpuBackend::CpuFallback));
    }

    #[test]
    fn test_gpu_ensemble_trainer() {
        let config = GpuConfig::default();
        let trainer = GpuEnsembleTrainer::new(config).unwrap();

        assert!(!trainer.is_gpu_available()); // Should be CPU fallback
        assert_eq!(trainer.context().device_info().name, "CPU Fallback");
    }
}