// scirs2_core/gpu/backends/wgpu.rs

1//! WebGPU backend implementation for GPU operations
2//!
3//! This module provides WebGPU-specific implementations for cross-platform GPU operations.
4
5use std::collections::HashMap;
6#[cfg(feature = "wgpu_backend")]
7// wgpu 26 removed earlier Poll enum; Device::poll exists but Maintain enum not re-exported here; we avoid explicit polling for now.
8use std::sync::{Arc, Mutex};
9
10use crate::gpu::{GpuBufferImpl, GpuCompilerImpl, GpuContextImpl, GpuError, GpuKernelImpl};
11
12#[cfg(feature = "wgpu_backend")]
13#[allow(unused_imports)]
14use wgpu::{
15    util::DeviceExt, Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
16    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, Buffer,
17    BufferBindingType, BufferDescriptor, BufferUsages, ComputePipeline, Device, DeviceDescriptor,
18    Features, Instance, InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
19    ShaderModuleDescriptor, ShaderSource, ShaderStages, StorageTextureAccess, TextureFormat,
20    TextureSampleType, TextureViewDimension,
21};
22
// Fallback types for when WebGPU is not available.
// These are opaque simulated handles (small integers cast to pointers);
// they are never dereferenced anywhere in this module.
#[cfg(not(feature = "wgpu_backend"))]
type WgpuDevice = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuQueue = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuBuffer = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuComputePipeline = *mut std::ffi::c_void;
32
33// WebGPU shader source templates
/// WGSL compute shader implementing one Adam optimizer step per element.
///
/// Bindings (group 0): params (rw), grads (ro), m (rw), v (rw),
/// and a uniform struct with the hyperparameters.
///
/// BUGFIX: the source previously contained a Rust attribute
/// `#[allow(dead_code)]` between `@compute` and `fn adam_update` — that is
/// not valid WGSL and would make shader module creation fail.
#[allow(dead_code)]
const ADAM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> params: array<f32>;
@group(0) @binding(1) var<storage, read> grads: array<f32>;
@group(0) @binding(2) var<storage, read_write> m: array<f32>;
@group(0) @binding(3) var<storage, read_write> v: array<f32>;

struct AdamUniforms {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    bias_correction1: f32,
    bias_correction2: f32,
    n: u32,
};

@group(0) @binding(4) var<uniform> uniforms: AdamUniforms;

@compute @workgroup_size(64)
fn adam_update(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;

    if (idx >= uniforms.n) {
        return;
    }

    var grad = grads[idx];

    // Apply weight decay
    if (uniforms.weight_decay > 0.0) {
        grad += uniforms.weight_decay * params[idx];
    }

    // Update biased first moment estimate
    m[idx] = uniforms.beta1 * m[idx] + (1.0 - uniforms.beta1) * grad;

    // Update biased second raw moment estimate
    v[idx] = uniforms.beta2 * v[idx] + (1.0 - uniforms.beta2) * grad * grad;

    // Compute bias-corrected moment estimates
    let m_hat = m[idx] / uniforms.bias_correction1;
    let v_hat = v[idx] / uniforms.bias_correction2;

    // Update parameters
    params[idx] -= uniforms.lr * m_hat / (sqrt(v_hat) + uniforms.eps);
}
"#;
84
/// WGSL compute shader computing C = alpha * A*B + beta * C (row-major GEMM).
///
/// BUGFIX: the source previously contained a Rust attribute
/// `#[allow(dead_code)]` between `@compute` and `fn gemm` — invalid WGSL
/// that would make shader module creation fail.
#[allow(dead_code)]
const GEMM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> matrix_c: array<f32>;

struct GemmUniforms {
    M: u32,
    N: u32,
    K: u32,
    alpha: f32,
    beta: f32,
};

@group(0) @binding(3) var<uniform> uniforms: GemmUniforms;

@compute @workgroup_size(8, 8)
fn gemm(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row = global_id.x;
    let col = global_id.y;

    if (row >= uniforms.M || col >= uniforms.N) {
        return;
    }

    var sum = 0.0;
    for (var k = 0u; k < uniforms.K; k++) {
        sum += matrix_a[row * uniforms.K + k] * matrix_b[k * uniforms.N + col];
    }

    let idx = row * uniforms.N + col;
    matrix_c[idx] = uniforms.alpha * sum + uniforms.beta * matrix_c[idx];
}
"#;
120
/// WebGPU context wrapper
///
/// Owns the device/queue handles (real `wgpu` objects when the
/// `wgpu_backend` feature is enabled, simulated raw pointers otherwise),
/// a cache of compiled shaders keyed by name, and a shared memory pool.
pub struct WebGPUContext {
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(not(feature = "wgpu_backend"))]
    device: Arc<WgpuDevice>,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: Arc<WgpuQueue>,
    // Shader cache shared with compiler/kernel handles created from this context.
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    // Pool consulted by `create_buffer` before direct device allocation.
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
134
// WebGPU handles are safe to send between threads when properly synchronized
// SAFETY: with `wgpu_backend` the wrapped wgpu types are thread-safe; in the
// fallback build the raw-pointer handles are opaque simulated values that are
// never dereferenced. NOTE(review): fallback safety relies on the handles
// staying opaque — confirm no code ever dereferences them.
unsafe impl Send for WebGPUContext {}
unsafe impl Sync for WebGPUContext {}
138
139impl WebGPUContext {
140    /// Create a new WebGPU context
141    pub fn new() -> Result<Self, GpuError> {
142        #[cfg(feature = "wgpu_backend")]
143        {
144            // Real WebGPU implementation
145            let instance_desc = InstanceDescriptor {
146                backends: Backends::all(),
147                ..Default::default()
148            };
149            let instance = Instance::new(&instance_desc);
150
151            let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
152                power_preference: PowerPreference::HighPerformance,
153                compatible_surface: None,
154                force_fallback_adapter: false,
155            }))
156            .map_err(|e| GpuError::Other(format!("Failed to find WebGPU adapter: {e}")))?;
157
158            let device_descriptor = DeviceDescriptor {
159                label: Some("SciRS2 WebGPU Device"),
160                required_features: Features::empty(),
161                required_limits: Limits::default(),
162                // Newer wgpu versions removed/changed some fields (e.g. trace Option). Use defaults for the rest.
163                ..Default::default()
164            };
165
166            let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor))
167                .map_err(|e| GpuError::Other(format!("{e}")))?;
168
169            Ok(Self {
170                device: Arc::new(device),
171                queue: Arc::new(queue),
172                compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
173                memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), // 1GB pool
174            })
175        }
176        #[cfg(not(feature = "wgpu_backend"))]
177        {
178            // Fallback implementation
179            let device = Self::initialize_webgpu()?;
180            let queue = Self::create_queue(device)?;
181
182            Ok(Self {
183                device,
184                queue,
185                compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
186                memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), // 1GB pool
187            })
188        }
189    }
190
191    /// Check if WebGPU is available and working
192    pub fn is_available() -> bool {
193        #[cfg(feature = "wgpu_backend")]
194        {
195            // Real WebGPU implementation - try to create an instance and adapter
196            let instance_desc = InstanceDescriptor {
197                backends: Backends::all(),
198                ..Default::default()
199            };
200            let instance = Instance::new(&instance_desc);
201
202            // Try to get an adapter (this is async, so we use a simple runtime check)
203            pollster::block_on(async {
204                instance
205                    .request_adapter(&RequestAdapterOptions {
206                        power_preference: PowerPreference::default(),
207                        compatible_surface: None,
208                        force_fallback_adapter: false,
209                    })
210                    .await
211                    .is_ok()
212            })
213        }
214        #[cfg(not(feature = "wgpu_backend"))]
215        {
216            // Fallback: return false since we don't have real WebGPU
217            false
218        }
219    }
220
221    /// Compile a shader from WGSL source
222    fn compile_shader_internal(&self, source: &str, name: &str) -> Result<WebGPUShader, GpuError> {
223        #[cfg(feature = "wgpu_backend")]
224        {
225            // Real WebGPU implementation
226            let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
227                label: Some(name),
228                source: ShaderSource::Wgsl(source.into()),
229            });
230
231            // Extract entry point from source or use default
232            let entry_point = Self::extract_entry_point(source).unwrap_or("main");
233
234            // Create bind group layout + reflection infos
235            let (bind_group_layout, binding_infos) =
236                self.create_bind_group_layout_from_source(source, name)?;
237
238            // Create pipeline layout with our bind group layout
239            let pipeline_layout =
240                self.device
241                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
242                        label: Some(&format!("{}_layout", name)),
243                        bind_group_layouts: &[&bind_group_layout],
244                        // wgpu 28+: immediate_size replaces push_constant_ranges
245                        ..Default::default()
246                    });
247
248            let compute_pipeline =
249                self.device
250                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
251                        label: Some(&format!("{}_pipeline", name)),
252                        layout: Some(&pipeline_layout),
253                        module: &shader_module,
254                        entry_point: Some(entry_point),
255                        compilation_options: Default::default(),
256                        cache: None,
257                    });
258
259            Ok(WebGPUShader {
260                pipeline: compute_pipeline,
261                bind_group_layout,
262                name: name.to_string(),
263                binding_infos,
264            })
265        }
266        #[cfg(not(feature = "wgpu_backend"))]
267        {
268            // Fallback implementation
269            let pipeline = Self::compile_wgsl_source(source, name)?;
270
271            Ok(WebGPUShader {
272                pipeline,
273                bind_group_layout: std::ptr::null_mut(),
274                name: name.to_string(),
275                binding_infos: Vec::new(),
276            })
277        }
278    }
279
280    /// Create bind group layout from WGSL source analysis
281    #[cfg(feature = "wgpu_backend")]
282    fn create_bind_group_layout_from_source(
283        &self,
284        source: &str,
285        name: &str,
286    ) -> Result<(BindGroupLayout, Vec<BindingInfo>), GpuError> {
287        #[derive(Default)]
288        struct PendingAttr {
289            group: Option<u32>,
290            binding: Option<u32>,
291        }
292        let mut pending = PendingAttr::default();
293        let mut entries: Vec<BindGroupLayoutEntry> = Vec::new();
294        let mut infos: Vec<BindingInfo> = Vec::new();
295
296        fn strip_comment(line: &str) -> &str {
297            line.split_once("//").map(|(a, _)| a).unwrap_or(line)
298        }
299
300        for raw_line in source.lines() {
301            let line = strip_comment(raw_line).trim();
302            if line.is_empty() {
303                continue;
304            }
305
306            if let Some(i) = line.find("@group(") {
307                if let Some(end) = line[i + 7..].find(')') {
308                    if let Ok(g) = line[i + 7..i + 7 + end].parse::<u32>() {
309                        pending.group = Some(g);
310                    }
311                }
312            }
313            if let Some(i) = line.find("@binding(") {
314                if let Some(end) = line[i + 9..].find(')') {
315                    if let Ok(b) = line[i + 9..i + 9 + end].parse::<u32>() {
316                        pending.binding = Some(b);
317                    }
318                }
319            }
320
321            if line.contains("var<") {
322                // variable declaration
323                if pending.group.unwrap_or(0) == 0 {
324                    // only group 0 for now
325                    let binding_num = pending.binding.unwrap_or_else(|| entries.len() as u32);
326                    let name = extract_var_name(line).unwrap_or("");
327                    let storage = line.contains("var<storage");
328                    let uniform = line.contains("var<uniform");
329                    let read_only = storage
330                        && (line.contains(", read>")
331                            || line.contains("var<storage, read>")
332                            || line.contains("var<storage, read,"));
333                    if storage {
334                        entries.push(BindGroupLayoutEntry {
335                            binding: binding_num,
336                            visibility: ShaderStages::COMPUTE,
337                            ty: BindingType::Buffer {
338                                ty: BufferBindingType::Storage { read_only },
339                                has_dynamic_offset: false,
340                                min_binding_size: None,
341                            },
342                            count: None,
343                        });
344                        infos.push(BindingInfo {
345                            binding: binding_num,
346                            name: name.to_string(),
347                            kind: if read_only {
348                                BindingKind::StorageRead
349                            } else {
350                                BindingKind::StorageRw
351                            },
352                        });
353                    } else if uniform {
354                        entries.push(BindGroupLayoutEntry {
355                            binding: binding_num,
356                            visibility: ShaderStages::COMPUTE,
357                            ty: BindingType::Buffer {
358                                ty: BufferBindingType::Uniform,
359                                has_dynamic_offset: false,
360                                min_binding_size: None,
361                            },
362                            count: None,
363                        });
364                        infos.push(BindingInfo {
365                            binding: binding_num,
366                            name: name.to_string(),
367                            kind: BindingKind::Uniform,
368                        });
369                    }
370                }
371                pending = PendingAttr::default();
372            }
373        }
374
375        if entries.is_empty() {
376            entries.push(BindGroupLayoutEntry {
377                binding: 0,
378                visibility: ShaderStages::COMPUTE,
379                ty: BindingType::Buffer {
380                    ty: BufferBindingType::Storage { read_only: false },
381                    has_dynamic_offset: false,
382                    min_binding_size: None,
383                },
384                count: None,
385            });
386            infos.push(BindingInfo {
387                binding: 0,
388                name: "_unnamed".into(),
389                kind: BindingKind::StorageRw,
390            });
391        }
392
393        // Deduplicate by binding number
394        let mut seen = std::collections::HashSet::new();
395        let mut dedup_entries = Vec::new();
396        let mut dedup_infos = Vec::new();
397        for (e, info) in entries.into_iter().zip(infos) {
398            if seen.insert(e.binding) {
399                dedup_entries.push(e);
400                dedup_infos.push(info);
401            }
402        }
403
404        let bind_group_layout = self
405            .device
406            .create_bind_group_layout(&BindGroupLayoutDescriptor {
407                label: Some(&format!("{}_bind_group_layout", name)),
408                entries: &dedup_entries,
409            });
410        Ok((bind_group_layout, dedup_infos))
411    }
412
413    /// Allocate device memory
414    #[cfg(feature = "wgpu_backend")]
415    pub fn allocate_device_memory(&self, size: usize) -> Result<Buffer, GpuError> {
416        let buffer = self.device.create_buffer(&BufferDescriptor {
417            label: Some("SciRS2 Buffer"),
418            size: size as u64,
419            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
420            mapped_at_creation: false,
421        });
422
423        Ok(buffer)
424    }
425
426    /// Allocate device memory (fallback)
427    #[cfg(not(feature = "wgpu_backend"))]
428    pub fn allocate_device_memory_2(&self, size: usize) -> Result<WgpuBuffer, GpuError> {
429        // Fallback implementation: return a simulated buffer handle
430        Ok((0x1000 + size) as WgpuBuffer)
431    }
432
433    // Fallback methods for when WebGPU is not available
434    #[cfg(not(feature = "wgpu_backend"))]
435    fn initialize_webgpu() -> Result<WgpuDevice, GpuError> {
436        // Stub implementation
437        Ok(0x1 as WgpuDevice)
438    }
439
440    #[cfg(not(feature = "wgpu_backend"))]
441    fn create_queue(device: WgpuDevice) -> Result<WgpuQueue, GpuError> {
442        // Stub implementation
443        Ok(0x2 as WgpuQueue)
444    }
445
446    #[cfg(not(feature = "wgpu_backend"))]
447    fn compile_wgsl_source(source: &str, name: &str) -> Result<WgpuComputePipeline, GpuError> {
448        // Stub implementation
449        Ok(0x3 as WgpuComputePipeline)
450    }
451
452    /// Extract the entry point function name from WGSL source code
453    fn extract_entry_point(source: &str) -> Option<&str> {
454        let lines: Vec<&str> = source.lines().collect();
455
456        for (i, line) in lines.iter().enumerate() {
457            let trimmed = line.trim();
458
459            // Check if this line contains @compute
460            if trimmed.contains("@compute") {
461                // The function might be on the same line or the next line
462                let mut search_line = trimmed;
463                let mut search_idx = 0;
464
465                // If @compute and function are not on the same line, check next line
466                if !search_line.contains("fn ") && search_idx + 1 < lines.len() {
467                    search_idx += 1;
468                    search_line = lines[search_idx].trim();
469                }
470
471                // Extract function name
472                if let Some(start) = search_line.find("fn ") {
473                    let remaining = &search_line[start + 3..];
474                    if let Some(end) = remaining.find('(') {
475                        let funcname = remaining[..end].trim();
476                        return Some(funcname);
477                    }
478                }
479            }
480        }
481
482        None
483    }
484}
485
486impl GpuContextImpl for WebGPUContext {
487    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
488        // Try to allocate from memory pool first
489        if let Ok(mut pool) = self.memory_pool.lock() {
490            if let Some(device_buffer) = pool.allocate(size) {
491                return Arc::new(WebGPUBuffer {
492                    device_buffer: Some(device_buffer),
493                    #[cfg(feature = "wgpu_backend")]
494                    queue: Arc::clone(&self.queue),
495                    #[cfg(feature = "wgpu_backend")]
496                    device: Arc::clone(&self.device),
497                    #[cfg(not(feature = "wgpu_backend"))]
498                    queue: self.queue,
499                    size,
500                    memory_pool: Arc::clone(&self.memory_pool),
501                });
502            }
503        }
504
505        // Fallback to direct allocation
506        let device_buffer = match self.allocate_device_memory(size) {
507            Ok(buffer) => buffer,
508            Err(e) => {
509                // Log the WebGPU allocation failure and create a CPU fallback
510                eprintln!(
511                    "Warning: WebGPU buffer allocation failed ({}), creating CPU fallback buffer",
512                    e
513                );
514
515                #[cfg(feature = "wgpu_backend")]
516                {
517                    // Create a CPU fallback buffer with minimal size for WebGPU compatibility
518                    // This is a last resort when GPU memory is exhausted
519                    return Arc::new(WebGPUCpuFallbackBuffer {
520                        data: vec![0u8; size],
521                        size,
522                        memory_pool: Arc::clone(&self.memory_pool),
523                    });
524                }
525                #[cfg(not(feature = "wgpu_backend"))]
526                {
527                    (0x2000 + size) as WgpuBuffer
528                }
529            }
530        };
531
532        Arc::new(WebGPUBuffer {
533            device_buffer: Some(device_buffer),
534            #[cfg(feature = "wgpu_backend")]
535            queue: Arc::clone(&self.queue),
536            #[cfg(feature = "wgpu_backend")]
537            device: Arc::clone(&self.device),
538            #[cfg(not(feature = "wgpu_backend"))]
539            queue: self.queue,
540            size,
541            memory_pool: Arc::clone(&self.memory_pool),
542        })
543    }
544
545    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
546        Arc::new(WebGPUCompiler {
547            context: Arc::new(WebGPUContext {
548                memory_pool: Arc::clone(&self.memory_pool),
549                compiled_shaders: Arc::clone(&self.compiled_shaders),
550                #[cfg(feature = "wgpu_backend")]
551                device: Arc::clone(&self.device),
552                #[cfg(feature = "wgpu_backend")]
553                queue: Arc::clone(&self.queue),
554                #[cfg(not(feature = "wgpu_backend"))]
555                device: Arc::clone(&self.device),
556                #[cfg(not(feature = "wgpu_backend"))]
557                queue: Arc::clone(&self.queue),
558            }),
559        })
560    }
561
562    fn as_any(&self) -> &dyn std::any::Any {
563        self
564    }
565}
566
/// WebGPU shader wrapper (augmented with basic reflection info)
struct WebGPUShader {
    #[cfg(feature = "wgpu_backend")]
    pipeline: ComputePipeline,
    #[cfg(not(feature = "wgpu_backend"))]
    pipeline: WgpuComputePipeline,
    // Layout for bind group 0, derived from the WGSL source at compile time.
    #[cfg(feature = "wgpu_backend")]
    #[allow(dead_code)]
    bind_group_layout: BindGroupLayout,
    #[cfg(not(feature = "wgpu_backend"))]
    #[allow(dead_code)]
    bind_group_layout: *mut std::ffi::c_void,
    // Label/cache key used when the shader was compiled.
    #[allow(dead_code)]
    name: String,
    #[allow(dead_code)]
    binding_infos: Vec<BindingInfo>, // basic reflection info (names may be synthetic when parser can't extract)
}
584
// WebGPU shader handles are safe to send between threads when properly synchronized
// SAFETY: with `wgpu_backend` the wrapped wgpu objects are thread-safe; in the
// fallback build the raw pointer is an opaque simulated handle that is never
// dereferenced.
unsafe impl Send for WebGPUShader {}
unsafe impl Sync for WebGPUShader {}
588
/// WebGPU compiler implementation
///
/// Compiles WGSL through the wrapped context and hands out kernel handles
/// that share the context's shader cache.
struct WebGPUCompiler {
    context: Arc<WebGPUContext>,
}
593
594impl GpuCompilerImpl for WebGPUCompiler {
595    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
596        let shader = self.context.compile_shader_internal(source, "shader")?;
597        Ok(Arc::new(WebGPUKernelHandle {
598            shader_name: shader.name.clone(),
599            compiled_shaders: Arc::clone(&self.context.compiled_shaders),
600            params: Arc::new(Mutex::new(HashMap::new())),
601            #[cfg(feature = "wgpu_backend")]
602            device: Arc::clone(&self.context.device),
603            #[cfg(feature = "wgpu_backend")]
604            queue: Arc::clone(&self.context.queue),
605            #[cfg(feature = "wgpu_backend")]
606            ephemeral_uniforms: Mutex::new(Vec::new()),
607            #[cfg(not(feature = "wgpu_backend"))]
608            device: self.context.device,
609            #[cfg(not(feature = "wgpu_backend"))]
610            queue: self.context.queue,
611        }))
612    }
613
614    fn compile_typed(
615        &self,
616        name: &str,
617        _input_type: std::any::TypeId,
618        _output_type: std::any::TypeId,
619    ) -> Arc<dyn GpuKernelImpl> {
620        Arc::new(WebGPUKernelHandle {
621            shader_name: name.to_string(),
622            compiled_shaders: Arc::clone(&self.context.compiled_shaders),
623            params: Arc::new(Mutex::new(HashMap::new())),
624            #[cfg(feature = "wgpu_backend")]
625            device: Arc::clone(&self.context.device),
626            #[cfg(feature = "wgpu_backend")]
627            queue: Arc::clone(&self.context.queue),
628            #[cfg(feature = "wgpu_backend")]
629            ephemeral_uniforms: Mutex::new(Vec::new()),
630            #[cfg(not(feature = "wgpu_backend"))]
631            device: self.context.device,
632            #[cfg(not(feature = "wgpu_backend"))]
633            queue: self.context.queue,
634        })
635    }
636}
637
/// WebGPU kernel handle for execution
///
/// Stores the name of a shader in the shared cache plus staged parameters;
/// `dispatch` looks the shader up by name and records a compute pass.
struct WebGPUKernelHandle {
    // Key into `compiled_shaders`.
    shader_name: String,
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    // Arguments staged via the `set_*` trait methods until dispatch.
    params: Arc<Mutex<HashMap<String, KernelParam>>>,
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    // Presumably keeps short-lived uniform buffers alive across a dispatch;
    // the use site is not visible here — confirm.
    #[cfg(feature = "wgpu_backend")]
    ephemeral_uniforms: Mutex<Vec<wgpu::Buffer>>,
    #[cfg(not(feature = "wgpu_backend"))]
    device: WgpuDevice,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
}
654
/// One kernel argument staged by the `set_*` methods until `dispatch`.
enum KernelParam {
    #[allow(dead_code)]
    Buffer(Arc<dyn GpuBufferImpl>),
    #[allow(dead_code)]
    U32(u32),
    #[allow(dead_code)]
    I32(i32),
    #[allow(dead_code)]
    F32(f32),
    #[allow(dead_code)]
    F64(f64),
    // Raw byte payload (no public setter in this trait impl).
    Bytes(Vec<u8>),
}
668
/// Kind of WGSL binding discovered by the source parser.
#[derive(Clone, Debug)]
enum BindingKind {
    // `var<storage, read_write>`
    StorageRw,
    // `var<storage, read>`
    StorageRead,
    // `var<uniform>`
    Uniform,
}
675
/// Basic reflection info for one `@binding` parsed from WGSL source.
#[derive(Clone, Debug)]
struct BindingInfo {
    // Binding index (within group 0).
    binding: u32,
    // Variable name from the declaration; empty when parsing failed.
    name: String,
    kind: BindingKind,
}
682
/// Pull the variable name out of a WGSL `var<...>` declaration line.
///
/// Given e.g. `@group(0) @binding(1) var<storage, read> grads: array<f32>;`,
/// returns `Some("grads")`. Returns `None` when the line has no `var<`,
/// no closing `>`, no `:` after the name, or the name would be empty.
fn extract_var_name(line: &str) -> Option<&str> {
    let var_pos = line.find("var<")?;
    let from_var = &line[var_pos..];
    let close_pos = from_var.find('>')?;
    let decl = from_var[close_pos + 1..].trim_start();
    let colon_pos = decl.find(':')?;
    let name = decl[..colon_pos].trim();
    (!name.is_empty()).then_some(name)
}
699
700impl GpuKernelImpl for WebGPUKernelHandle {
701    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
702        let mut params = self.params.lock().expect("Operation failed");
703        params.insert(name.to_string(), KernelParam::Buffer(Arc::clone(buffer)));
704    }
705
706    fn set_u32(&self, name: &str, value: u32) {
707        let mut params = self.params.lock().expect("Operation failed");
708        params.insert(name.to_string(), KernelParam::U32(value));
709    }
710
711    fn set_i32(&self, name: &str, value: i32) {
712        let mut params = self.params.lock().expect("Operation failed");
713        params.insert(name.to_string(), KernelParam::I32(value));
714    }
715
716    fn set_f32(&self, name: &str, value: f32) {
717        let mut params = self.params.lock().expect("Operation failed");
718        params.insert(name.to_string(), KernelParam::F32(value));
719    }
720
721    fn set_f64(&self, name: &str, value: f64) {
722        let mut params = self.params.lock().expect("Operation failed");
723        params.insert(name.to_string(), KernelParam::F64(value));
724    }
725
726    #[allow(dead_code)]
727    // raw bytes helper removed from trait; use internal helper if needed
728
729    fn dispatch(&self, workgroups: [u32; 3]) {
730        #[cfg(feature = "wgpu_backend")]
731        {
732            // Real WebGPU compute dispatch
733            let shaders = self.compiled_shaders.lock().expect("Operation failed");
734            if let Some(shader) = shaders.get(&self.shader_name) {
735                let params = self.params.lock().expect("Operation failed");
736
737                // Create command encoder
738                let mut encoder =
739                    self.device
740                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
741                            label: Some("Compute Command Encoder"),
742                        });
743
744                // Begin compute pass
745                {
746                    let mut compute_pass =
747                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
748                            label: Some("Compute Pass"),
749                            timestamp_writes: None,
750                        });
751
752                    // Set the compute pipeline
753                    compute_pass.set_pipeline(&shader.pipeline);
754
755                    if let Ok(bind_group) = self.create_bind_group_from_params(shader, &params) {
756                        compute_pass.set_bind_group(0, &bind_group, &[]);
757                    } else {
758                        eprintln!(
759                            "Warning: Failed to create bind group for shader {}",
760                            self.shader_name
761                        );
762                    }
763
764                    // Dispatch the compute shader
765                    compute_pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
766                }
767
768                // Submit the command buffer
769                let command_buffer = encoder.finish();
770                self.queue.submit(std::iter::once(command_buffer));
771
772                eprintln!(
773                    "WebGPU compute shader {} dispatched with workgroups: {:?}",
774                    self.shader_name, workgroups
775                );
776            }
777        }
778        #[cfg(not(feature = "wgpu_backend"))]
779        {
780            // Fallback implementation - just log the execution
781            eprintln!("Executing WebGPU shader {} (simulated)", self.shader_name);
782            eprintln!("Work groups: {:?}", workgroups);
783        }
784    }
785}
786
/// WebGPU buffer implementation
struct WebGPUBuffer {
    // The underlying device allocation; all visible constructors pass `Some`.
    #[cfg(feature = "wgpu_backend")]
    device_buffer: Option<Buffer>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(not(feature = "wgpu_backend"))]
    device_buffer: Option<WgpuBuffer>,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
    // Logical size in bytes; used to validate host<->device copies.
    size: usize,
    // Pool this buffer came from; presumably used to return capacity on
    // drop — the Drop impl is not visible here, confirm.
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
802
// WebGPU buffer handles are safe to send between threads when properly synchronized
// The real wgpu types (Buffer, Queue) are Send + Sync
// For fallback types (raw pointers), we assume proper synchronization is handled externally
// SAFETY: fallback raw handles are opaque simulated values that are never dereferenced.
unsafe impl Send for WebGPUBuffer {}
unsafe impl Sync for WebGPUBuffer {}
808
809impl GpuBufferImpl for WebGPUBuffer {
    /// Total capacity of the buffer in bytes.
    fn size(&self) -> usize {
        self.size
    }
813
    /// Copy `size` bytes from host memory at `data` into the device buffer.
    ///
    /// Oversized copies are rejected with a warning rather than truncated.
    ///
    /// # Safety
    /// `data` must be valid for reads of `size` bytes.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        #[cfg(feature = "wgpu_backend")]
        {
            // Validate data size
            if size > self.size {
                // In unsafe context, we can't return an error, so we'll just log and return
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
                return;
            }

            // Convert raw pointer to slice for WebGPU API
            let data_slice = std::slice::from_raw_parts(data, size);

            // Real WebGPU implementation - write data to buffer
            if let Some(ref buffer) = self.device_buffer {
                self.queue.write_buffer(buffer, 0, data_slice);
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Fallback implementation - just validate
            if size > self.size {
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
            }
            // In fallback mode, we just simulate the operation
        }
    }
847
848    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
849        #[cfg(feature = "wgpu_backend")]
850        {
851            // Validate data size
852            if size > self.size {
853                eprintln!(
854                    "Warning: Data size {} exceeds buffer size {}",
855                    size, self.size
856                );
857                return;
858            }
859
860            if let Some(ref buffer) = self.device_buffer {
861                let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
862                    label: Some("scirs2-readback"),
863                    size: size as u64,
864                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
865                    mapped_at_creation: false,
866                });
867                let mut encoder =
868                    self.device
869                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
870                            label: Some("scirs2-readback-enc"),
871                        });
872                encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
873                self.queue.submit(Some(encoder.finish()));
874                let slice = staging.slice(0..size as u64);
875                let (tx, rx) = std::sync::mpsc::channel();
876                slice.map_async(wgpu::MapMode::Read, move |r| {
877                    let _ = tx.send(r);
878                });
879                // TODO: explicit device.poll if necessary for certain platforms
880                if let Ok(Ok(())) = rx.recv() {
881                    let mapped = slice.get_mapped_range();
882                    let dst = std::slice::from_raw_parts_mut(data, size);
883                    dst.copy_from_slice(&mapped);
884                    drop(mapped);
885                    staging.unmap();
886                } else {
887                    eprintln!("Warning: map_async failed for readback");
888                }
889            }
890        }
891        #[cfg(not(feature = "wgpu_backend"))]
892        {
893            // Fallback implementation - just validate and zero out
894            if size > self.size {
895                eprintln!(
896                    "Warning: Data size {} exceeds buffer size {}",
897                    size, self.size
898                );
899            }
900
901            // Zero out the data as a placeholder
902            let data_slice = std::slice::from_raw_parts_mut(data, size);
903            data_slice.fill(0);
904        }
905    }
906
907    fn device_ptr(&self) -> u64 {
908        #[cfg(feature = "wgpu_backend")]
909        {
910            // WebGPU doesn't expose raw device pointers, so we return a placeholder
911            // In a real implementation, this might return a handle or ID
912            &self.device_buffer as *const _ as u64
913        }
914        #[cfg(not(feature = "wgpu_backend"))]
915        {
916            self.device_buffer as u64
917        }
918    }
919
920    fn as_any(&self) -> &dyn std::any::Any {
921        self
922    }
923}
924
#[cfg(feature = "wgpu_backend")]
impl WebGPUKernelHandle {
    /// Builds a bind group for `shader` from the supplied kernel parameters.
    ///
    /// Storage bindings are resolved by name against `params`. Scalar/byte
    /// parameters whose key equals a uniform binding's name, or is prefixed
    /// with `"<name>."`, are packed into a single uniform buffer, padded to a
    /// 16-byte multiple as WGSL uniform layout requires.
    ///
    /// Matching uniform params are sorted by key before packing so the byte
    /// layout is deterministic (iterating a `HashMap` directly produces a
    /// different order on every run).
    ///
    /// # Errors
    /// Returns `GpuError::InvalidParameter` when a storage binding has no
    /// matching `KernelParam::Buffer` entry.
    fn create_bind_group_from_params(
        &self,
        shader: &WebGPUShader,
        params: &HashMap<String, KernelParam>,
    ) -> Result<wgpu::BindGroup, GpuError> {
        let mut entries: Vec<wgpu::BindGroupEntry> = Vec::new();
        // Uniform buffers created here must stay alive until after
        // `create_bind_group` below, so they are owned by this Vec.
        let mut owned_uniform_buffers: Vec<wgpu::Buffer> = Vec::new();
        let mut uniform_bytes: Vec<u8> = Vec::new();

        for info in &shader.binding_infos {
            match info.kind {
                BindingKind::StorageRw | BindingKind::StorageRead => {
                    if let Some(KernelParam::Buffer(buf)) = params.get(&info.name) {
                        if let Some(wbuf) = buf.as_any().downcast_ref::<WebGPUBuffer>() {
                            if let Some(ref inner) = wbuf.device_buffer {
                                entries.push(wgpu::BindGroupEntry {
                                    binding: info.binding,
                                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                                        buffer: inner,
                                        offset: 0,
                                        size: None,
                                    }),
                                });
                            }
                        }
                    } else {
                        return Err(GpuError::InvalidParameter(format!(
                            "Missing buffer param '{}'",
                            info.name
                        )));
                    }
                }
                BindingKind::Uniform => {
                    // Collect matching params first, then sort by key so the
                    // packing order does not depend on HashMap iteration order.
                    let mut matched: Vec<(&String, &KernelParam)> = params
                        .iter()
                        .filter(|(k, _)| {
                            k.as_str() == info.name.as_str()
                                || k.strip_prefix(info.name.as_str())
                                    .map_or(false, |rest| rest.starts_with('.'))
                        })
                        .collect();
                    matched.sort_by(|a, b| a.0.cmp(b.0));
                    for (_, v) in matched {
                        match v {
                            KernelParam::U32(u) => {
                                uniform_bytes.extend_from_slice(&u.to_le_bytes())
                            }
                            KernelParam::I32(i) => {
                                uniform_bytes.extend_from_slice(&i.to_le_bytes())
                            }
                            KernelParam::F32(f) => {
                                uniform_bytes.extend_from_slice(&f.to_le_bytes())
                            }
                            KernelParam::F64(f) => {
                                uniform_bytes.extend_from_slice(&f.to_le_bytes())
                            }
                            KernelParam::Bytes(b) => uniform_bytes.extend_from_slice(b),
                            KernelParam::Buffer(_) => {}
                        }
                    }
                }
            }
        }

        if !uniform_bytes.is_empty() {
            // WGSL uniform buffers must be sized in multiples of 16 bytes.
            while uniform_bytes.len() % 16 != 0 {
                uniform_bytes.push(0);
            }
            if let Some(uinfo) = shader
                .binding_infos
                .iter()
                .find(|b| matches!(b.kind, BindingKind::Uniform))
            {
                if let Ok(mut list) = self.ephemeral_uniforms.lock() {
                    list.clear();
                    let ubuf = self
                        .device
                        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                            label: Some("scirs2-uniforms"),
                            contents: &uniform_bytes,
                            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                        });
                    // Keep one handle alive past this call (in
                    // `ephemeral_uniforms`) and one locally so `entries` can
                    // borrow it until the bind group is created.
                    list.push(ubuf.clone());
                    owned_uniform_buffers.push(ubuf);
                    let buf_ref = owned_uniform_buffers
                        .last()
                        .expect("uniform buffer was just pushed");
                    entries.push(wgpu::BindGroupEntry {
                        binding: uinfo.binding,
                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                            buffer: buf_ref,
                            offset: 0,
                            size: None,
                        }),
                    });
                }
            }
        } else if let Ok(mut list) = self.ephemeral_uniforms.lock() {
            // No scalar params this dispatch: drop uniforms from the last one.
            list.clear();
        }

        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("scirs2-bind-group"),
            layout: &shader.bind_group_layout,
            entries: &entries,
        });
        Ok(bind_group)
    }
}
1027
1028impl Drop for WebGPUBuffer {
1029    fn drop(&mut self) {
1030        // Return buffer to memory pool if possible
1031        if let Ok(mut pool) = self.memory_pool.lock() {
1032            #[cfg(feature = "wgpu_backend")]
1033            {
1034                // In real implementation, would return buffer to pool
1035                if let Some(buffer) = self.device_buffer.take() {
1036                    pool.deallocate(buffer);
1037                }
1038            }
1039            #[cfg(not(feature = "wgpu_backend"))]
1040            {
1041                if let Some(buffer) = self.device_buffer.take() {
1042                    pool.deallocate(buffer);
1043                }
1044            }
1045        }
1046    }
1047}
1048
/// CPU fallback buffer for when WebGPU buffer allocation fails
/// This provides a graceful degradation when GPU memory is exhausted
struct WebGPUCpuFallbackBuffer {
    // CPU-side backing store read by `copy_to_host`.
    data: Vec<u8>,
    // Declared capacity in bytes; used for copy-size validation.
    size: usize,
    #[allow(dead_code)]
    // Retained so the pool outlives this buffer; not otherwise used here.
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
1057
1058impl GpuBufferImpl for WebGPUCpuFallbackBuffer {
1059    fn size(&self) -> usize {
1060        self.size
1061    }
1062
1063    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1064        if size > self.size {
1065            eprintln!("Warning: WebGPU CPU fallback buffer copy_from_host size mismatch");
1066            return;
1067        }
1068
1069        // Since this is a CPU fallback, we can use safe Rust internally
1070        let data_slice = std::slice::from_raw_parts(data, size);
1071        // We can't mutate self.data directly since &self is immutable
1072        // In a real implementation, this would require interior mutability
1073        eprintln!(
1074            "Warning: CPU fallback buffer copy_from_host called (size: {})",
1075            size
1076        );
1077    }
1078
1079    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1080        if size > self.size {
1081            eprintln!("Warning: WebGPU CPU fallback buffer copy_to_host size mismatch");
1082            return;
1083        }
1084
1085        // Copy from CPU buffer to host
1086        let data_slice = std::slice::from_raw_parts_mut(data, size);
1087        let copy_size = size.min(self.data.len());
1088        data_slice[..copy_size].copy_from_slice(&self.data[..copy_size]);
1089
1090        eprintln!(
1091            "Warning: CPU fallback buffer copy_to_host called (size: {})",
1092            size
1093        );
1094    }
1095
1096    fn device_ptr(&self) -> u64 {
1097        self.data.as_ptr() as u64
1098    }
1099
1100    fn as_any(&self) -> &dyn std::any::Any {
1101        self
1102    }
1103}
1104
// SAFETY: WebGPUCpuFallbackBuffer contains only owned data (`Vec<u8>`, sizes,
// and an `Arc<Mutex<_>>`), all of which are Send + Sync, so these impls are
// sound.
unsafe impl Send for WebGPUCpuFallbackBuffer {}
unsafe impl Sync for WebGPUCpuFallbackBuffer {}
1108
/// WebGPU memory pool for efficient buffer management
///
/// Recycled buffers are bucketed by exact byte size; `allocate` only reuses a
/// buffer whose size matches the request exactly.
struct WebGPUMemoryPool {
    #[cfg(feature = "wgpu_backend")]
    // size-in-bytes -> free buffers of exactly that size.
    available_buffers: HashMap<usize, Vec<Buffer>>,
    #[cfg(not(feature = "wgpu_backend"))]
    available_buffers: HashMap<usize, Vec<WgpuBuffer>>,
    #[allow(dead_code)]
    // Nominal pool capacity; reported by `get_memory_usage`, not enforced.
    total_size: usize,
    // Bytes currently handed out to live allocations (approximate in
    // fallback mode, where sizes are placeholders).
    used_size: usize,
}
1119
1120impl WebGPUMemoryPool {
1121    fn new(totalsize: usize) -> Self {
1122        Self {
1123            available_buffers: HashMap::new(),
1124            total_size: totalsize,
1125            used_size: 0,
1126        }
1127    }
1128
1129    #[cfg(feature = "wgpu_backend")]
1130    fn allocate(&mut self, size: usize) -> Option<Buffer> {
1131        // Try to find a suitable buffer in the pool
1132        if let Some(buffers) = self.available_buffers.get_mut(&size) {
1133            if let Some(buffer) = buffers.pop() {
1134                self.used_size += size;
1135                return Some(buffer);
1136            }
1137        }
1138        None
1139    }
1140
1141    #[cfg(not(feature = "wgpu_backend"))]
1142    fn allocate(&mut self, size: usize) -> Option<WgpuBuffer> {
1143        // Try to find a suitable buffer in the pool
1144        if let Some(buffers) = self.available_buffers.get_mut(&size) {
1145            if let Some(buffer) = buffers.pop() {
1146                self.used_size += size;
1147                return Some(buffer);
1148            }
1149        }
1150        None
1151    }
1152
1153    #[cfg(feature = "wgpu_backend")]
1154    fn deallocate(&mut self, buffer: Buffer) {
1155        // Return buffer to pool
1156        let size = buffer.size() as usize;
1157        self.available_buffers
1158            .entry(size)
1159            .or_insert_with(Vec::new)
1160            .push(buffer);
1161        self.used_size = self.used_size.saturating_sub(size);
1162    }
1163
1164    #[cfg(not(feature = "wgpu_backend"))]
1165    fn deallocate(&mut self, buffer: WgpuBuffer) {
1166        // Fallback implementation - track the buffer
1167        let size = 1024; // Placeholder size
1168        self.available_buffers
1169            .entry(size)
1170            .or_insert_with(Vec::new)
1171            .push(buffer);
1172        self.used_size = self.used_size.saturating_sub(size);
1173    }
1174
1175    #[allow(dead_code)]
1176    fn get_memory_usage(&self) -> (usize, usize) {
1177        (self.used_size, self.total_size)
1178    }
1179}