// sklears_svm/gpu_kernels.rs

1//! GPU-accelerated kernel computations using WGPU
2//!
3//! This module provides GPU-accelerated implementations of common SVM kernels
4//! including RBF, Polynomial, and Linear kernels. It uses WGPU for cross-platform
5//! GPU acceleration and compute shaders for high-performance kernel matrix computation.
6
7use crate::kernels::KernelType;
8use scirs2_core::ndarray::{s, Array2};
9use thiserror::Error;
10
11#[cfg(feature = "gpu")]
12use std::collections::HashMap;
13#[cfg(feature = "gpu")]
14use wgpu::{
15    util::DeviceExt, Adapter, BufferDescriptor, BufferUsages, ComputePassDescriptor,
16    ComputePipeline, Device, DeviceDescriptor, Features, Instance, Limits, PowerPreference, Queue,
17    RequestAdapterOptions, ShaderModule, ShaderModuleDescriptor, ShaderSource,
18};
19
/// Errors that can occur during GPU kernel computation
#[derive(Error, Debug)]
pub enum GpuKernelError {
    /// No suitable GPU adapter/device could be acquired.
    #[error("GPU device not available")]
    DeviceNotAvailable,
    /// The GPU could not provide the memory needed for the computation.
    #[error("Insufficient GPU memory")]
    InsufficientMemory,
    /// Device setup, submission, or result readback failed; the payload
    /// carries the underlying error text.
    #[error("GPU computation failed: {0}")]
    ComputationFailed(String),
    /// A WGSL shader failed to compile; the payload carries the details.
    #[error("Shader compilation failed: {0}")]
    ShaderCompilationFailed(String),
    /// A GPU buffer could not be created.
    #[error("Buffer creation failed")]
    BufferCreationFailed,
    /// The requested capability (feature flag, kernel type, or pipeline)
    /// is not available; the payload describes what was missing.
    #[error("GPU feature not supported: {0}")]
    FeatureNotSupported(String),
    /// The two input matrices do not share the same feature dimension.
    #[error("Kernel matrix dimensions mismatch")]
    DimensionMismatch,
}
38
/// Result type for GPU kernel operations, with [`GpuKernelError`] as the
/// error variant.
pub type GpuKernelResult<T> = Result<T, GpuKernelError>;
41
/// GPU-accelerated kernel matrix computation
///
/// Owns the WGPU device/queue pair plus a cache of pre-compiled compute
/// pipelines, keyed by kernel name ("rbf", "polynomial", "linear", "sigmoid").
#[cfg(feature = "gpu")]
pub struct GpuKernelComputer {
    /// Logical device used to create buffers, shader modules, and pipelines.
    device: Device,
    /// Queue used to submit recorded command buffers to the device.
    queue: Queue,
    /// Physical adapter handle, kept for capability and info queries.
    adapter: Adapter,
    /// Compiled compute pipelines, keyed by kernel name.
    pipelines: HashMap<String, ComputePipeline>,
    /// Compiled shader modules, keyed by kernel name.
    shader_modules: HashMap<String, ShaderModule>,
}
51
#[cfg(feature = "gpu")]
impl GpuKernelComputer {
    /// Create a new GPU kernel computer.
    ///
    /// Acquires a high-performance adapter and a default device, then
    /// compiles and registers the four built-in kernel pipelines
    /// (RBF, polynomial, linear, sigmoid).
    ///
    /// # Errors
    ///
    /// Returns [`GpuKernelError::DeviceNotAvailable`] when no adapter is
    /// found, or [`GpuKernelError::ComputationFailed`] when the device
    /// request fails.
    pub async fn new() -> GpuKernelResult<Self> {
        let instance = Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            ..Default::default()
        });

        let adapter = instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or(GpuKernelError::DeviceNotAvailable)?;

        let (device, queue) = adapter
            .request_device(
                &DeviceDescriptor {
                    label: None,
                    required_features: Features::empty(),
                    required_limits: Limits::default(),
                    memory_hints: wgpu::MemoryHints::Performance,
                },
                None,
            )
            .await
            .map_err(|e| GpuKernelError::ComputationFailed(e.to_string()))?;

        let mut computer = Self {
            device,
            queue,
            adapter,
            pipelines: HashMap::new(),
            shader_modules: HashMap::new(),
        };

        // Pre-compile every supported kernel shader up front so that
        // `compute_kernel_matrix` never builds pipelines lazily.
        computer.init_rbf_shader()?;
        computer.init_polynomial_shader()?;
        computer.init_linear_shader()?;
        computer.init_sigmoid_shader()?;

        Ok(computer)
    }

    /// Compile a WGSL compute shader and register its module and pipeline
    /// under `key`.
    ///
    /// `base_label` is only used for debug labels ("<base> Shader",
    /// "<base> Pipeline"); `entry_point` must name the `@compute` function
    /// inside `shader_source`.
    fn register_pipeline(
        &mut self,
        key: &str,
        base_label: &str,
        entry_point: &str,
        shader_source: &str,
    ) -> GpuKernelResult<()> {
        let shader_label = format!("{base_label} Shader");
        let shader = self.device.create_shader_module(ShaderModuleDescriptor {
            label: Some(&shader_label),
            source: ShaderSource::Wgsl(shader_source.into()),
        });

        let pipeline_label = format!("{base_label} Pipeline");
        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(&pipeline_label),
                // `layout: None` lets wgpu derive the bind group layout from
                // the shader's @group/@binding declarations.
                layout: None,
                module: &shader,
                entry_point: Some(entry_point),
                compilation_options: Default::default(),
                cache: None,
            });

        self.shader_modules.insert(key.to_string(), shader);
        self.pipelines.insert(key.to_string(), pipeline);

        Ok(())
    }

    /// Initialize the RBF kernel shader: K[i,j] = exp(-gamma * ||x_i - y_j||^2).
    fn init_rbf_shader(&mut self) -> GpuKernelResult<()> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read> X: array<f32>;
            @group(0) @binding(1) var<storage, read> Y: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<storage, read> params: array<f32>;

            @compute @workgroup_size(16, 16)
            fn rbf_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let n_x = u32(params[0]);
                let n_y = u32(params[1]);
                let n_features = u32(params[2]);
                let gamma = params[3];

                let i = global_id.x;
                let j = global_id.y;

                if (i >= n_x || j >= n_y) {
                    return;
                }

                var sum_sq_diff = 0.0;
                for (var k = 0u; k < n_features; k++) {
                    let diff = X[i * n_features + k] - Y[j * n_features + k];
                    sum_sq_diff += diff * diff;
                }

                result[i * n_y + j] = exp(-gamma * sum_sq_diff);
            }
        "#;

        self.register_pipeline("rbf", "RBF Kernel", "rbf_kernel", shader_source)
    }

    /// Initialize the polynomial kernel shader:
    /// K[i,j] = (gamma * <x_i, y_j> + coef0)^degree.
    fn init_polynomial_shader(&mut self) -> GpuKernelResult<()> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read> X: array<f32>;
            @group(0) @binding(1) var<storage, read> Y: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<storage, read> params: array<f32>;

            @compute @workgroup_size(16, 16)
            fn polynomial_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let n_x = u32(params[0]);
                let n_y = u32(params[1]);
                let n_features = u32(params[2]);
                let gamma = params[3];
                let coef0 = params[4];
                let degree = params[5];

                let i = global_id.x;
                let j = global_id.y;

                if (i >= n_x || j >= n_y) {
                    return;
                }

                var dot_product = 0.0;
                for (var k = 0u; k < n_features; k++) {
                    dot_product += X[i * n_features + k] * Y[j * n_features + k];
                }

                result[i * n_y + j] = pow(gamma * dot_product + coef0, degree);
            }
        "#;

        self.register_pipeline(
            "polynomial",
            "Polynomial Kernel",
            "polynomial_kernel",
            shader_source,
        )
    }

    /// Initialize the linear kernel shader: K[i,j] = <x_i, y_j>.
    fn init_linear_shader(&mut self) -> GpuKernelResult<()> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read> X: array<f32>;
            @group(0) @binding(1) var<storage, read> Y: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<storage, read> params: array<f32>;

            @compute @workgroup_size(16, 16)
            fn linear_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let n_x = u32(params[0]);
                let n_y = u32(params[1]);
                let n_features = u32(params[2]);

                let i = global_id.x;
                let j = global_id.y;

                if (i >= n_x || j >= n_y) {
                    return;
                }

                var dot_product = 0.0;
                for (var k = 0u; k < n_features; k++) {
                    dot_product += X[i * n_features + k] * Y[j * n_features + k];
                }

                result[i * n_y + j] = dot_product;
            }
        "#;

        self.register_pipeline("linear", "Linear Kernel", "linear_kernel", shader_source)
    }

    /// Initialize the sigmoid kernel shader:
    /// K[i,j] = tanh(gamma * <x_i, y_j> + coef0).
    fn init_sigmoid_shader(&mut self) -> GpuKernelResult<()> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read> X: array<f32>;
            @group(0) @binding(1) var<storage, read> Y: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<storage, read> params: array<f32>;

            @compute @workgroup_size(16, 16)
            fn sigmoid_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let n_x = u32(params[0]);
                let n_y = u32(params[1]);
                let n_features = u32(params[2]);
                let gamma = params[3];
                let coef0 = params[4];

                let i = global_id.x;
                let j = global_id.y;

                if (i >= n_x || j >= n_y) {
                    return;
                }

                var dot_product = 0.0;
                for (var k = 0u; k < n_features; k++) {
                    dot_product += X[i * n_features + k] * Y[j * n_features + k];
                }

                result[i * n_y + j] = tanh(gamma * dot_product + coef0);
            }
        "#;

        self.register_pipeline("sigmoid", "Sigmoid Kernel", "sigmoid_kernel", shader_source)
    }

    /// Compute the kernel matrix `K[i, j] = k(X[i], Y[j])` on the GPU.
    ///
    /// Uploads both matrices and the kernel parameters, dispatches the
    /// pre-compiled pipeline for `kernel_type`, and reads the `(n_x, n_y)`
    /// result back to host memory.
    ///
    /// # Errors
    ///
    /// * [`GpuKernelError::DimensionMismatch`] when the feature dimensions differ.
    /// * [`GpuKernelError::FeatureNotSupported`] for kernel types without a GPU shader.
    /// * [`GpuKernelError::ComputationFailed`] when submission or readback fails.
    pub async fn compute_kernel_matrix(
        &self,
        X: &Array2<f32>,
        Y: &Array2<f32>,
        kernel_type: &KernelType,
    ) -> GpuKernelResult<Array2<f32>> {
        let (n_x, n_features_x) = X.dim();
        let (n_y, n_features_y) = Y.dim();

        if n_features_x != n_features_y {
            return Err(GpuKernelError::DimensionMismatch);
        }

        // Zero-sized GPU buffers are invalid, so short-circuit empty inputs
        // on the host instead of dispatching.
        if n_x == 0 || n_y == 0 {
            return Ok(Array2::zeros((n_x, n_y)));
        }

        // The params layout must match the `params` storage buffer declared
        // in each WGSL shader: [n_x, n_y, n_features, kernel-specific...].
        let (pipeline_name, params) = match kernel_type {
            KernelType::Rbf { gamma } => (
                "rbf",
                vec![n_x as f32, n_y as f32, n_features_x as f32, *gamma as f32],
            ),
            KernelType::Polynomial {
                gamma,
                coef0,
                degree,
            } => (
                "polynomial",
                vec![
                    n_x as f32,
                    n_y as f32,
                    n_features_x as f32,
                    *gamma as f32,
                    *coef0 as f32,
                    *degree as f32,
                ],
            ),
            KernelType::Linear => ("linear", vec![n_x as f32, n_y as f32, n_features_x as f32]),
            KernelType::Sigmoid { gamma, coef0 } => (
                "sigmoid",
                vec![
                    n_x as f32,
                    n_y as f32,
                    n_features_x as f32,
                    *gamma as f32,
                    *coef0 as f32,
                ],
            ),
            _ => {
                return Err(GpuKernelError::FeatureNotSupported(
                    "Kernel type not supported on GPU".to_string(),
                ))
            }
        };

        let pipeline = self.pipelines.get(pipeline_name).ok_or_else(|| {
            GpuKernelError::FeatureNotSupported(format!("Pipeline {pipeline_name} not found"))
        })?;

        // Ensure a contiguous row-major layout before the flat upload below;
        // `as_slice()` would return None (previously a panic via `unwrap`)
        // on sliced/transposed views.
        let x_host = X.as_standard_layout();
        let y_host = Y.as_standard_layout();
        let x_slice = x_host
            .as_slice()
            .ok_or(GpuKernelError::BufferCreationFailed)?;
        let y_slice = y_host
            .as_slice()
            .ok_or(GpuKernelError::BufferCreationFailed)?;

        let result_bytes = (n_x * n_y * std::mem::size_of::<f32>()) as u64;

        // Create input/output buffers.
        let x_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("X Buffer"),
                contents: bytemuck::cast_slice(x_slice),
                usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
            });

        let y_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Y Buffer"),
                contents: bytemuck::cast_slice(y_slice),
                usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
            });

        let result_buffer = self.device.create_buffer(&BufferDescriptor {
            label: Some("Result Buffer"),
            size: result_bytes,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let params_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Params Buffer"),
                contents: bytemuck::cast_slice(&params),
                usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
            });

        // Bind buffers in the order declared by the shaders (bindings 0-3).
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Kernel Bind Group"),
            layout: &pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: x_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: y_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

        // Record the compute dispatch.
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Kernel Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("Kernel Compute Pass"),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // Must match @workgroup_size(16, 16) in the WGSL shaders; round
            // up so partial edge tiles are still covered (shaders bounds-check).
            let workgroup_size = 16usize;
            let num_workgroups_x = n_x.div_ceil(workgroup_size);
            let num_workgroups_y = n_y.div_ceil(workgroup_size);

            compute_pass.dispatch_workgroups(num_workgroups_x as u32, num_workgroups_y as u32, 1);
        }

        // Stage the result into a mappable buffer for host readback.
        let staging_buffer = self.device.create_buffer(&BufferDescriptor {
            label: Some("Staging Buffer"),
            size: result_bytes,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        encoder.copy_buffer_to_buffer(&result_buffer, 0, &staging_buffer, 0, result_bytes);

        self.queue.submit(Some(encoder.finish()));

        // Map the staging buffer and block until the GPU work completes.
        let buffer_slice = staging_buffer.slice(..);
        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            // Receiver lives until `rx.receive()` below resolves, so the
            // send can only fail if the future was dropped mid-flight.
            let _ = tx.send(result);
        });

        self.device.poll(wgpu::Maintain::Wait);
        rx.receive()
            .await
            .ok_or_else(|| {
                GpuKernelError::ComputationFailed("buffer map callback was dropped".to_string())
            })?
            .map_err(|e| GpuKernelError::ComputationFailed(e.to_string()))?;

        // Copy out of the mapped range, then unmap so the staging buffer
        // (and its host mapping) is released cleanly.
        let result_vec = {
            let data = buffer_slice.get_mapped_range();
            bytemuck::cast_slice::<u8, f32>(&data).to_vec()
        };
        staging_buffer.unmap();

        Array2::from_shape_vec((n_x, n_y), result_vec)
            .map_err(|e| GpuKernelError::ComputationFailed(e.to_string()))
    }

    /// Get a human-readable description of the selected GPU.
    pub fn device_info(&self) -> String {
        format!("GPU: {}", self.adapter.get_info().name)
    }

    /// Check if the adapter can actually run compute shaders.
    ///
    /// The previous check, `features().contains(Features::empty())`, is
    /// vacuously true for every adapter; this queries the downlevel
    /// compute-shader capability instead.
    pub fn supports_compute(&self) -> bool {
        self.adapter
            .get_downlevel_capabilities()
            .flags
            .contains(wgpu::DownlevelFlags::COMPUTE_SHADERS)
    }

    /// Get adapter information (name, vendor, backend, device type).
    ///
    /// NOTE(review): despite the original "memory info" name, `AdapterInfo`
    /// carries no memory statistics — confirm callers only need identity info.
    pub fn memory_info(&self) -> wgpu::AdapterInfo {
        self.adapter.get_info()
    }
}
506
/// CPU-only stand-in for the GPU computer when the `gpu` feature is
/// disabled; every operation fails with `FeatureNotSupported`.
#[cfg(not(feature = "gpu"))]
pub struct GpuKernelComputer;
509
510#[cfg(not(feature = "gpu"))]
511impl GpuKernelComputer {
512    pub async fn new() -> GpuKernelResult<Self> {
513        Err(GpuKernelError::FeatureNotSupported(
514            "GPU support not enabled".to_string(),
515        ))
516    }
517
518    pub async fn compute_kernel_matrix(
519        &self,
520        _x: &Array2<f32>,
521        _y: &Array2<f32>,
522        _kernel_type: &KernelType,
523    ) -> GpuKernelResult<Array2<f32>> {
524        Err(GpuKernelError::FeatureNotSupported(
525            "GPU support not enabled".to_string(),
526        ))
527    }
528}
529
/// GPU-accelerated kernel function wrapper
///
/// Wraps a kernel type and, when the `gpu` feature is enabled and
/// initialized, a `GpuKernelComputer` used for GPU evaluation with a
/// transparent CPU fallback.
pub struct GpuKernel {
    /// Lazily-initialized GPU backend; `None` until `init_gpu` succeeds.
    #[cfg(feature = "gpu")]
    computer: Option<GpuKernelComputer>,
    /// The kernel function to evaluate.
    kernel_type: KernelType,
    /// Whether `init_gpu` should attempt to acquire a GPU device.
    use_gpu: bool,
}
537
538impl GpuKernel {
539    /// Create a new GPU-accelerated kernel
540    pub fn new(kernel_type: KernelType, use_gpu: bool) -> Self {
541        Self {
542            #[cfg(feature = "gpu")]
543            computer: None,
544            kernel_type,
545            use_gpu,
546        }
547    }
548
549    /// Initialize GPU acceleration
550    #[cfg(feature = "gpu")]
551    pub async fn init_gpu(&mut self) -> GpuKernelResult<()> {
552        if self.use_gpu {
553            self.computer = Some(GpuKernelComputer::new().await?);
554        }
555        Ok(())
556    }
557
558    #[cfg(not(feature = "gpu"))]
559    pub async fn init_gpu(&mut self) -> GpuKernelResult<()> {
560        if self.use_gpu {
561            return Err(GpuKernelError::FeatureNotSupported(
562                "GPU support not enabled".to_string(),
563            ));
564        }
565        Ok(())
566    }
567
568    /// Compute kernel matrix with GPU acceleration if available
569    pub async fn compute_matrix(&self, x: &Array2<f32>, y: &Array2<f32>) -> Array2<f32> {
570        #[cfg(feature = "gpu")]
571        if let Some(computer) = &self.computer {
572            if let Ok(result) = computer
573                .compute_kernel_matrix(x, y, &self.kernel_type)
574                .await
575            {
576                return result;
577            }
578        }
579
580        // Fallback to CPU computation
581        self.compute_cpu_kernel_matrix(x, y)
582    }
583
584    /// Compute kernel matrix on CPU
585    pub fn compute_cpu_kernel_matrix(&self, x: &Array2<f32>, y: &Array2<f32>) -> Array2<f32> {
586        let (n_x, _n_features) = x.dim();
587        let (n_y, _) = y.dim();
588        let mut result = Array2::zeros((n_x, n_y));
589
590        for i in 0..n_x {
591            for j in 0..n_y {
592                let x_i = x.row(i);
593                let y_j = y.row(j);
594
595                let kernel_value = match &self.kernel_type {
596                    KernelType::Linear => x_i.dot(&y_j) as f64,
597                    KernelType::Rbf { gamma } => {
598                        let diff = &x_i - &y_j;
599                        let squared_distance = diff.dot(&diff) as f64;
600                        (-gamma * squared_distance).exp()
601                    }
602                    KernelType::Polynomial {
603                        gamma,
604                        coef0,
605                        degree,
606                    } => {
607                        let dot_product = x_i.dot(&y_j) as f64;
608                        (gamma * dot_product + coef0).powf(*degree)
609                    }
610                    KernelType::Sigmoid { gamma, coef0 } => {
611                        let dot_product = x_i.dot(&y_j) as f64;
612                        (gamma * dot_product + coef0).tanh()
613                    }
614                    _ => 0.0, // Unsupported kernel types default to 0
615                };
616
617                result[(i, j)] = kernel_value as f32;
618            }
619        }
620
621        result
622    }
623
624    /// Check if GPU is available and initialized
625    pub fn is_gpu_available(&self) -> bool {
626        #[cfg(feature = "gpu")]
627        return self.computer.is_some();
628        #[cfg(not(feature = "gpu"))]
629        false
630    }
631
632    /// Get device information
633    pub fn device_info(&self) -> String {
634        #[cfg(feature = "gpu")]
635        if let Some(computer) = &self.computer {
636            return computer.device_info();
637        }
638        "CPU".to_string()
639    }
640}
641
/// Benchmark GPU vs CPU kernel computation
pub struct GpuKernelBenchmark {
    /// Wall-clock time of the GPU run; `None` when no GPU was available.
    pub gpu_time: Option<std::time::Duration>,
    /// Wall-clock time of the CPU run.
    pub cpu_time: std::time::Duration,
    /// CPU time divided by GPU time; `None` when no GPU was available.
    pub speedup: Option<f64>,
    /// 1 - RMSE between CPU and GPU results (0.0 when no GPU run happened).
    pub accuracy: f64,
}
649
impl GpuKernelBenchmark {
    /// Run benchmark comparing GPU and CPU kernel computation
    ///
    /// Times a CPU evaluation of the kernel matrix, then (when the `gpu`
    /// feature is enabled and a device can be created) a GPU evaluation of
    /// the same matrix, and reports timings, speedup, and agreement.
    ///
    /// # Errors
    ///
    /// Propagates errors from the GPU kernel computation itself. Failure to
    /// *create* a GPU computer is not an error — it simply yields
    /// `gpu_time == None` and `speedup == None`.
    pub async fn run(
        x: &Array2<f32>,
        y: &Array2<f32>,
        kernel_type: KernelType,
    ) -> GpuKernelResult<Self> {
        // CPU benchmark
        let cpu_start = std::time::Instant::now();
        let cpu_kernel = GpuKernel::new(kernel_type.clone(), false);
        // `cpu_result` is only read by the GPU accuracy comparison below,
        // hence the allow on CPU-only builds.
        #[cfg_attr(not(feature = "gpu"), allow(unused_variables))]
        let cpu_result = cpu_kernel.compute_cpu_kernel_matrix(x, y);
        let cpu_time = cpu_start.elapsed();

        // GPU benchmark
        #[cfg(feature = "gpu")]
        let (gpu_time, speedup, accuracy) = {
            if let Ok(computer) = GpuKernelComputer::new().await {
                let gpu_start = std::time::Instant::now();
                let gpu_result = computer.compute_kernel_matrix(x, y, &kernel_type).await?;
                let gpu_time = gpu_start.elapsed();

                // Compute accuracy as 1 - RMSE between CPU and GPU results.
                // NOTE(review): this can go negative when RMSE > 1 — confirm
                // callers treat it as an agreement score, not a probability.
                let diff = &cpu_result - &gpu_result;
                let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
                let accuracy = 1.0 - (mse as f64).sqrt();

                let speedup = cpu_time.as_secs_f64() / gpu_time.as_secs_f64();

                (Some(gpu_time), Some(speedup), accuracy)
            } else {
                // No usable GPU: leave the GPU-side fields empty.
                (None, None, 0.0)
            }
        };

        #[cfg(not(feature = "gpu"))]
        let (gpu_time, speedup, accuracy) = (None, None, 0.0);

        Ok(GpuKernelBenchmark {
            gpu_time,
            cpu_time,
            speedup,
            accuracy,
        })
    }
}
696
697/// Utilities for GPU kernel optimization
698pub mod gpu_utils {
699    use super::*;
700
701    /// Optimal batch size for GPU computation
702    pub fn optimal_batch_size(n_samples: usize, n_features: usize) -> usize {
703        // Heuristic: balance memory usage and parallelism
704        let memory_limit = 1024 * 1024 * 1024; // 1GB limit
705        let sample_size = n_features * std::mem::size_of::<f32>();
706        let max_batch = memory_limit / sample_size;
707
708        (max_batch.min(n_samples)).max(1)
709    }
710
711    /// Check if GPU acceleration is beneficial
712    pub fn should_use_gpu(n_samples: usize, n_features: usize) -> bool {
713        // Use GPU for large datasets where parallel computation is beneficial
714        let computation_size = n_samples * n_samples * n_features;
715        computation_size > 1_000_000 // Threshold for GPU acceleration
716    }
717
718    /// Batch kernel matrix computation for large datasets
719    pub async fn compute_kernel_matrix_batched(
720        computer: &GpuKernelComputer,
721        x: &Array2<f32>,
722        y: &Array2<f32>,
723        kernel_type: &KernelType,
724        batch_size: usize,
725    ) -> GpuKernelResult<Array2<f32>> {
726        let (n_x, _n_features) = x.dim();
727        let (n_y, _) = y.dim();
728
729        let mut result = Array2::zeros((n_x, n_y));
730
731        for i in (0..n_x).step_by(batch_size) {
732            let end_i = (i + batch_size).min(n_x);
733            let x_batch = x.slice(s![i..end_i, ..]);
734
735            for j in (0..n_y).step_by(batch_size) {
736                let end_j = (j + batch_size).min(n_y);
737                let y_batch = y.slice(s![j..end_j, ..]);
738
739                let batch_result = computer
740                    .compute_kernel_matrix(&x_batch.to_owned(), &y_batch.to_owned(), kernel_type)
741                    .await?;
742
743                result
744                    .slice_mut(s![i..end_i, j..end_j])
745                    .assign(&batch_result);
746            }
747        }
748
749        Ok(result)
750    }
751}
752
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_kernel_creation() {
        // A freshly built kernel has no GPU computer until `init_gpu` runs.
        let kernel = GpuKernel::new(KernelType::Rbf { gamma: 1.0 }, true);
        assert!(!kernel.is_gpu_available());
    }

    #[test]
    fn test_gpu_kernel_sync() {
        // Without initialization the wrapper reports the CPU device.
        let kernel = GpuKernel::new(KernelType::Linear, true);
        assert_eq!(kernel.device_info(), "CPU");
    }

    #[test]
    fn test_gpu_utils() {
        // The batch-size heuristic always yields a usable (non-zero) batch.
        assert!(gpu_utils::optimal_batch_size(1000, 100) > 0);

        // Large problems should prefer the GPU, tiny ones should not.
        assert!(gpu_utils::should_use_gpu(1000, 1000));
        assert!(!gpu_utils::should_use_gpu(10, 10));
    }

    #[test]
    fn test_benchmark_sync() {
        // Shape sanity for the fixtures a benchmark run would consume.
        let x_data =
            Array2::from_shape_vec((10, 5), (0..50).map(|v| v as f32).collect()).unwrap();
        let y_data =
            Array2::from_shape_vec((8, 5), (0..40).map(|v| v as f32).collect()).unwrap();

        assert_eq!(x_data.dim(), (10, 5));
        assert_eq!(y_data.dim(), (8, 5));
    }
}