scirs2_core/gpu/backends/
wgpu.rs

//! WebGPU backend implementation for GPU operations
//!
//! This module provides WebGPU-specific implementations for cross-platform GPU operations.

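//! A minimal usage sketch (assumes the `wgpu_backend` feature and the crate's
//! `GpuContextImpl`/`GpuCompilerImpl` traits as used below):
//!
//! ```ignore
//! let ctx = WebGPUContext::new()?;
//! let compiler = ctx.create_compiler();
//! let kernel = compiler.compile(GEMM_SHADER_WGSL)?;
//! kernel.dispatch([8, 8, 1]);
//! ```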
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use crate::gpu::{GpuBufferImpl, GpuCompilerImpl, GpuContextImpl, GpuError, GpuKernelImpl};

// Note: wgpu 26 removed the earlier `Maintain`/poll enum re-exports used for
// explicit polling; `Device::poll` still exists, but we avoid explicit polling
// for now.
#[cfg(feature = "wgpu_backend")]
#[allow(unused_imports)]
use wgpu::{
    util::DeviceExt, Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, Buffer,
    BufferBindingType, BufferDescriptor, BufferUsages, ComputePipeline, Device, DeviceDescriptor,
    Features, Instance, InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
    ShaderModuleDescriptor, ShaderSource, ShaderStages, StorageTextureAccess, TextureFormat,
    TextureSampleType, TextureViewDimension,
};

// Fallback types for when WebGPU is not available
#[cfg(not(feature = "wgpu_backend"))]
type WgpuDevice = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuQueue = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuBuffer = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuComputePipeline = *mut std::ffi::c_void;

// WebGPU shader source templates
#[allow(dead_code)]
const ADAM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> params: array<f32>;
@group(0) @binding(1) var<storage, read> grads: array<f32>;
@group(0) @binding(2) var<storage, read_write> m: array<f32>;
@group(0) @binding(3) var<storage, read_write> v: array<f32>;

struct AdamUniforms {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    bias_correction1: f32,
    bias_correction2: f32,
    n: u32,
};

@group(0) @binding(4) var<uniform> uniforms: AdamUniforms;

@compute @workgroup_size(64)
fn adam_update(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;

    if (idx >= uniforms.n) {
        return;
    }

    var grad = grads[idx];

    // Apply weight decay
    if (uniforms.weight_decay > 0.0) {
        grad += uniforms.weight_decay * params[idx];
    }

    // Update biased first moment estimate
    m[idx] = uniforms.beta1 * m[idx] + (1.0 - uniforms.beta1) * grad;

    // Update biased second raw moment estimate
    v[idx] = uniforms.beta2 * v[idx] + (1.0 - uniforms.beta2) * grad * grad;

    // Compute bias-corrected moment estimates
    let m_hat = m[idx] / uniforms.bias_correction1;
    let v_hat = v[idx] / uniforms.bias_correction2;

    // Update parameters
    params[idx] -= uniforms.lr * m_hat / (sqrt(v_hat) + uniforms.eps);
}
"#;

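// Dispatch sizing sketch for the Adam kernel above (illustrative only; nothing
// else in this backend calls it): with @workgroup_size(64), a caller launches
// ceil(n / 64) workgroups on the x axis.
#[allow(dead_code)]
fn adam_workgroup_count(n: u32) -> [u32; 3] {
    [n.div_ceil(64), 1, 1]
}
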
#[allow(dead_code)]
const GEMM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> matrix_c: array<f32>;

struct GemmUniforms {
    M: u32,
    N: u32,
    K: u32,
    alpha: f32,
    beta: f32,
};

@group(0) @binding(3) var<uniform> uniforms: GemmUniforms;

@compute @workgroup_size(8, 8)
fn gemm(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row = global_id.x;
    let col = global_id.y;

    if (row >= uniforms.M || col >= uniforms.N) {
        return;
    }

    var sum = 0.0;
    for (var k = 0u; k < uniforms.K; k++) {
        sum += matrix_a[row * uniforms.K + k] * matrix_b[k * uniforms.N + col];
    }

    let idx = row * uniforms.N + col;
    matrix_c[idx] = uniforms.alpha * sum + uniforms.beta * matrix_c[idx];
}
"#;

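// Dispatch sizing sketch for the GEMM kernel above (illustrative only): with
// @workgroup_size(8, 8) tiling the M x N output, a caller launches
// ceil(M / 8) x ceil(N / 8) workgroups.
#[allow(dead_code)]
fn gemm_workgroup_count(m: u32, n: u32) -> [u32; 3] {
    [m.div_ceil(8), n.div_ceil(8), 1]
}
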
/// WebGPU context wrapper
pub struct WebGPUContext {
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(not(feature = "wgpu_backend"))]
    device: Arc<WgpuDevice>,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: Arc<WgpuQueue>,
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

// WebGPU handles are safe to send between threads when properly synchronized
unsafe impl Send for WebGPUContext {}
unsafe impl Sync for WebGPUContext {}

impl WebGPUContext {
    /// Create a new WebGPU context
    pub fn new() -> Result<Self, GpuError> {
        #[cfg(feature = "wgpu_backend")]
        {
            // Real WebGPU implementation
            let instance_desc = InstanceDescriptor {
                backends: Backends::all(),
                ..Default::default()
            };
            let instance = Instance::new(&instance_desc);

            let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
                power_preference: PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            }))
            .map_err(|e| GpuError::Other(format!("Failed to find WebGPU adapter: {e}")))?;

            let device_descriptor = DeviceDescriptor {
                label: Some("SciRS2 WebGPU Device"),
                required_features: Features::empty(),
                required_limits: Limits::default(),
                // Newer wgpu versions removed/changed some fields (e.g. trace Option). Use defaults for the rest.
                ..Default::default()
            };

            let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor))
                .map_err(|e| GpuError::Other(format!("{e}")))?;

            Ok(Self {
                device: Arc::new(device),
                queue: Arc::new(queue),
                compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
                memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), // 1GB pool
            })
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Fallback implementation
            let device = Self::initialize_webgpu()?;
            let queue = Self::create_queue(device)?;

            Ok(Self {
                device: Arc::new(device),
                queue: Arc::new(queue),
                compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
                memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), // 1GB pool
            })
        }
    }

    /// Check if WebGPU is available and working
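    ///
    /// A gating sketch (illustrative usage, not an API contract defined
    /// elsewhere in this crate):
    ///
    /// ```ignore
    /// if WebGPUContext::is_available() {
    ///     let ctx = WebGPUContext::new()?;
    ///     let buffer = ctx.create_buffer(1024);
    /// }
    /// ```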
    pub fn is_available() -> bool {
        #[cfg(feature = "wgpu_backend")]
        {
            // Real WebGPU implementation - try to create an instance and adapter
            let instance_desc = InstanceDescriptor {
                backends: Backends::all(),
                ..Default::default()
            };
            let instance = Instance::new(&instance_desc);

            // Try to get an adapter (this is async, so we block on it here)
            pollster::block_on(async {
                instance
                    .request_adapter(&RequestAdapterOptions {
                        power_preference: PowerPreference::default(),
                        compatible_surface: None,
                        force_fallback_adapter: false,
                    })
                    .await
                    .is_ok()
            })
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Fallback: return false since we don't have real WebGPU
            false
        }
    }

    /// Compile a shader from WGSL source
    fn compile_shader_internal(&self, source: &str, name: &str) -> Result<WebGPUShader, GpuError> {
        #[cfg(feature = "wgpu_backend")]
        {
            // Real WebGPU implementation
            let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
                label: Some(name),
                source: ShaderSource::Wgsl(source.into()),
            });

            // Extract entry point from source or use default
            let entry_point = Self::extract_entry_point(source).unwrap_or("main");

            // Create bind group layout + reflection info
            let (bind_group_layout, binding_infos) =
                self.create_bind_group_layout_from_source(source, name)?;

            // Create pipeline layout with our bind group layout
            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{}_layout", name)),
                        bind_group_layouts: &[&bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let compute_pipeline =
                self.device
                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some(&format!("{}_pipeline", name)),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some(entry_point),
                        compilation_options: Default::default(),
                        cache: None,
                    });

            Ok(WebGPUShader {
                pipeline: compute_pipeline,
                bind_group_layout,
                name: name.to_string(),
                binding_infos,
            })
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Fallback implementation
            let pipeline = Self::compile_wgsl_source(source, name)?;

            Ok(WebGPUShader {
                pipeline,
                bind_group_layout: std::ptr::null_mut(),
                name: name.to_string(),
                binding_infos: Vec::new(),
            })
        }
    }

    /// Create bind group layout from WGSL source analysis
    #[cfg(feature = "wgpu_backend")]
    fn create_bind_group_layout_from_source(
        &self,
        source: &str,
        name: &str,
    ) -> Result<(BindGroupLayout, Vec<BindingInfo>), GpuError> {
        #[derive(Default)]
        struct PendingAttr {
            group: Option<u32>,
            binding: Option<u32>,
        }
        let mut pending = PendingAttr::default();
        let mut entries: Vec<BindGroupLayoutEntry> = Vec::new();
        let mut infos: Vec<BindingInfo> = Vec::new();

        fn strip_comment(line: &str) -> &str {
            line.split_once("//").map(|(a, _)| a).unwrap_or(line)
        }

        for raw_line in source.lines() {
            let line = strip_comment(raw_line).trim();
            if line.is_empty() {
                continue;
            }

            if let Some(i) = line.find("@group(") {
                if let Some(end) = line[i + 7..].find(')') {
                    if let Ok(g) = line[i + 7..i + 7 + end].parse::<u32>() {
                        pending.group = Some(g);
                    }
                }
            }
            if let Some(i) = line.find("@binding(") {
                if let Some(end) = line[i + 9..].find(')') {
                    if let Ok(b) = line[i + 9..i + 9 + end].parse::<u32>() {
                        pending.binding = Some(b);
                    }
                }
            }

            if line.contains("var<") {
                // Variable declaration
                if pending.group.unwrap_or(0) == 0 {
                    // Only group 0 is supported for now
                    let binding_num = pending.binding.unwrap_or_else(|| entries.len() as u32);
                    let name = extract_var_name(line).unwrap_or("");
                    let storage = line.contains("var<storage");
                    let uniform = line.contains("var<uniform");
                    let read_only = storage
                        && (line.contains(", read>")
                            || line.contains("var<storage, read>")
                            || line.contains("var<storage, read,"));
                    if storage {
                        entries.push(BindGroupLayoutEntry {
                            binding: binding_num,
                            visibility: ShaderStages::COMPUTE,
                            ty: BindingType::Buffer {
                                ty: BufferBindingType::Storage { read_only },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        });
                        infos.push(BindingInfo {
                            binding: binding_num,
                            name: name.to_string(),
                            kind: if read_only {
                                BindingKind::StorageRead
                            } else {
                                BindingKind::StorageRw
                            },
                        });
                    } else if uniform {
                        entries.push(BindGroupLayoutEntry {
                            binding: binding_num,
                            visibility: ShaderStages::COMPUTE,
                            ty: BindingType::Buffer {
                                ty: BufferBindingType::Uniform,
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        });
                        infos.push(BindingInfo {
                            binding: binding_num,
                            name: name.to_string(),
                            kind: BindingKind::Uniform,
                        });
                    }
                }
                pending = PendingAttr::default();
            }
        }

        if entries.is_empty() {
            entries.push(BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            });
            infos.push(BindingInfo {
                binding: 0,
                name: "_unnamed".into(),
                kind: BindingKind::StorageRw,
            });
        }

        // Deduplicate by binding number
        let mut seen = std::collections::HashSet::new();
        let mut dedup_entries = Vec::new();
        let mut dedup_infos = Vec::new();
        for (e, info) in entries.into_iter().zip(infos.into_iter()) {
            if seen.insert(e.binding) {
                dedup_entries.push(e);
                dedup_infos.push(info);
            }
        }

        let bind_group_layout = self
            .device
            .create_bind_group_layout(&BindGroupLayoutDescriptor {
                label: Some(&format!("{}_bind_group_layout", name)),
                entries: &dedup_entries,
            });
        Ok((bind_group_layout, dedup_infos))
    }

    /// Allocate device memory
    #[cfg(feature = "wgpu_backend")]
    pub fn allocate_device_memory(&self, size: usize) -> Result<Buffer, GpuError> {
        let buffer = self.device.create_buffer(&BufferDescriptor {
            label: Some("SciRS2 Buffer"),
            size: size as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        Ok(buffer)
    }

    /// Allocate device memory (fallback)
    #[cfg(not(feature = "wgpu_backend"))]
    pub fn allocate_device_memory(&self, size: usize) -> Result<WgpuBuffer, GpuError> {
        // Fallback implementation: return a simulated buffer handle
        Ok((0x1000 + size) as WgpuBuffer)
    }

    // Fallback methods for when WebGPU is not available
    #[cfg(not(feature = "wgpu_backend"))]
    fn initialize_webgpu() -> Result<WgpuDevice, GpuError> {
        // Stub implementation
        Ok(0x1 as WgpuDevice)
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn create_queue(_device: WgpuDevice) -> Result<WgpuQueue, GpuError> {
        // Stub implementation
        Ok(0x2 as WgpuQueue)
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn compile_wgsl_source(_source: &str, _name: &str) -> Result<WgpuComputePipeline, GpuError> {
        // Stub implementation
        Ok(0x3 as WgpuComputePipeline)
    }

    /// Extract the entry point function name from WGSL source code
    fn extract_entry_point(source: &str) -> Option<&str> {
        let lines: Vec<&str> = source.lines().collect();

        for (i, line) in lines.iter().enumerate() {
            let trimmed = line.trim();

            // Check if this line contains @compute
            if trimmed.contains("@compute") {
                // The function might be on the same line or the next line
                let mut search_line = trimmed;

                // If @compute and the function are not on the same line, check the next line
                if !search_line.contains("fn ") && i + 1 < lines.len() {
                    search_line = lines[i + 1].trim();
                }

                // Extract the function name between "fn " and the opening parenthesis
                if let Some(start) = search_line.find("fn ") {
                    let remaining = &search_line[start + 3..];
                    if let Some(end) = remaining.find('(') {
                        let funcname = remaining[..end].trim();
                        return Some(funcname);
                    }
                }
            }
        }

        None
    }
}

impl GpuContextImpl for WebGPUContext {
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
        // Try to allocate from the memory pool first
        if let Ok(mut pool) = self.memory_pool.lock() {
            if let Some(device_buffer) = pool.allocate(size) {
                return Arc::new(WebGPUBuffer {
                    device_buffer: Some(device_buffer),
                    #[cfg(feature = "wgpu_backend")]
                    queue: Arc::clone(&self.queue),
                    #[cfg(feature = "wgpu_backend")]
                    device: Arc::clone(&self.device),
                    #[cfg(not(feature = "wgpu_backend"))]
                    queue: *self.queue,
                    size,
                    memory_pool: Arc::clone(&self.memory_pool),
                });
            }
        }

        // Fall back to direct allocation
        let device_buffer = match self.allocate_device_memory(size) {
            Ok(buffer) => buffer,
            Err(e) => {
                // Log the WebGPU allocation failure and create a CPU fallback
                eprintln!(
                    "Warning: WebGPU buffer allocation failed ({}), creating CPU fallback buffer",
                    e
                );

                #[cfg(feature = "wgpu_backend")]
                {
                    // Create a CPU fallback buffer as a last resort when GPU
                    // memory is exhausted
                    return Arc::new(WebGPUCpuFallbackBuffer {
                        data: Mutex::new(vec![0u8; size]),
                        size,
                        memory_pool: Arc::clone(&self.memory_pool),
                    });
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    (0x2000 + size) as WgpuBuffer
                }
            }
        };

        Arc::new(WebGPUBuffer {
            device_buffer: Some(device_buffer),
            #[cfg(feature = "wgpu_backend")]
            queue: Arc::clone(&self.queue),
            #[cfg(feature = "wgpu_backend")]
            device: Arc::clone(&self.device),
            #[cfg(not(feature = "wgpu_backend"))]
            queue: *self.queue,
            size,
            memory_pool: Arc::clone(&self.memory_pool),
        })
    }

    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
        Arc::new(WebGPUCompiler {
            context: Arc::new(WebGPUContext {
                memory_pool: Arc::clone(&self.memory_pool),
                compiled_shaders: Arc::clone(&self.compiled_shaders),
                // Both cfg branches cloned the same Arcs, so no cfg split is needed here
                device: Arc::clone(&self.device),
                queue: Arc::clone(&self.queue),
            }),
        })
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

/// WebGPU shader wrapper (augmented with basic reflection info)
struct WebGPUShader {
    #[cfg(feature = "wgpu_backend")]
    pipeline: ComputePipeline,
    #[cfg(not(feature = "wgpu_backend"))]
    pipeline: WgpuComputePipeline,
    #[cfg(feature = "wgpu_backend")]
    #[allow(dead_code)]
    bind_group_layout: BindGroupLayout,
    #[cfg(not(feature = "wgpu_backend"))]
    #[allow(dead_code)]
    bind_group_layout: *mut std::ffi::c_void,
    #[allow(dead_code)]
    name: String,
    #[allow(dead_code)]
    binding_infos: Vec<BindingInfo>, // basic reflection info (names may be synthetic when the parser can't extract them)
}

// WebGPU shader handles are safe to send between threads when properly synchronized
unsafe impl Send for WebGPUShader {}
unsafe impl Sync for WebGPUShader {}

/// WebGPU compiler implementation
struct WebGPUCompiler {
    context: Arc<WebGPUContext>,
}

impl GpuCompilerImpl for WebGPUCompiler {
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
        let shader = self.context.compile_shader_internal(source, "shader")?;
        let shader_name = shader.name.clone();
        // Register the compiled shader so dispatch() can look it up by name.
        // Note: kernels compiled through this path currently share the name
        // "shader", so a later compile replaces the earlier pipeline.
        if let Ok(mut shaders) = self.context.compiled_shaders.lock() {
            shaders.insert(shader_name.clone(), shader);
        }
        Ok(Arc::new(WebGPUKernelHandle {
            shader_name,
            compiled_shaders: Arc::clone(&self.context.compiled_shaders),
            params: Arc::new(Mutex::new(HashMap::new())),
            #[cfg(feature = "wgpu_backend")]
            device: Arc::clone(&self.context.device),
            #[cfg(feature = "wgpu_backend")]
            queue: Arc::clone(&self.context.queue),
            #[cfg(feature = "wgpu_backend")]
            ephemeral_uniforms: Mutex::new(Vec::new()),
            #[cfg(not(feature = "wgpu_backend"))]
            device: *self.context.device,
            #[cfg(not(feature = "wgpu_backend"))]
            queue: *self.context.queue,
        }))
    }

    fn compile_typed(
        &self,
        name: &str,
        _input_type: std::any::TypeId,
        _output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl> {
        Arc::new(WebGPUKernelHandle {
            shader_name: name.to_string(),
            compiled_shaders: Arc::clone(&self.context.compiled_shaders),
            params: Arc::new(Mutex::new(HashMap::new())),
            #[cfg(feature = "wgpu_backend")]
            device: Arc::clone(&self.context.device),
            #[cfg(feature = "wgpu_backend")]
            queue: Arc::clone(&self.context.queue),
            #[cfg(feature = "wgpu_backend")]
            ephemeral_uniforms: Mutex::new(Vec::new()),
            #[cfg(not(feature = "wgpu_backend"))]
            device: *self.context.device,
            #[cfg(not(feature = "wgpu_backend"))]
            queue: *self.context.queue,
        })
    }
}

/// WebGPU kernel handle for execution
struct WebGPUKernelHandle {
    shader_name: String,
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    params: Arc<Mutex<HashMap<String, KernelParam>>>,
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(feature = "wgpu_backend")]
    ephemeral_uniforms: Mutex<Vec<wgpu::Buffer>>,
    #[cfg(not(feature = "wgpu_backend"))]
    device: WgpuDevice,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
}

// As with the other handle types, kernel handles are treated as safe to share
// once externally synchronized (the fallback variants hold raw pointers)
unsafe impl Send for WebGPUKernelHandle {}
unsafe impl Sync for WebGPUKernelHandle {}

enum KernelParam {
    #[allow(dead_code)]
    Buffer(Arc<dyn GpuBufferImpl>),
    #[allow(dead_code)]
    U32(u32),
    #[allow(dead_code)]
    I32(i32),
    #[allow(dead_code)]
    F32(f32),
    #[allow(dead_code)]
    F64(f64),
    #[allow(dead_code)]
    Bytes(Vec<u8>),
}
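
// Uniform packing note and sketch (a usage assumption, not an enforced contract):
// create_bind_group_from_params packs scalar params whose keys are the uniform's
// name or "<name>.<field>" in key-sorted order, which need not match the WGSL
// declaration order. Callers needing an exact layout can pre-pack one Bytes blob
// in declaration order instead, e.g. for AdamUniforms (hypothetical locals):
//
//     let mut blob = Vec::new();
//     for f in [lr, beta1, beta2, eps, weight_decay, bc1, bc2] {
//         blob.extend_from_slice(&f.to_le_bytes()); // seven f32 fields
//     }
//     blob.extend_from_slice(&n.to_le_bytes());     // trailing u32 field
//     params.insert("uniforms".into(), KernelParam::Bytes(blob));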

#[derive(Clone, Debug)]
enum BindingKind {
    StorageRw,
    StorageRead,
    Uniform,
}

#[derive(Clone, Debug)]
struct BindingInfo {
    binding: u32,
    name: String,
    kind: BindingKind,
}

fn extract_var_name(line: &str) -> Option<&str> {
    if let Some(var_start) = line.find("var<") {
        let after_var = &line[var_start..];
        if let Some(close) = after_var.find('>') {
            let after = &after_var[close + 1..];
            let after = after.trim_start();
            if let Some(colon) = after.find(':') {
                let name_part = after[..colon].trim();
                if !name_part.is_empty() {
                    return Some(name_part);
                }
            }
        }
    }
    None
}

impl GpuKernelImpl for WebGPUKernelHandle {
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::Buffer(Arc::clone(buffer)));
    }

    fn set_u32(&self, name: &str, value: u32) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::U32(value));
    }

    fn set_i32(&self, name: &str, value: i32) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::I32(value));
    }

    fn set_f32(&self, name: &str, value: f32) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::F32(value));
    }

    fn set_f64(&self, name: &str, value: f64) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::F64(value));
    }

    // The raw-bytes setter was removed from the trait; uniforms are packed from
    // the scalar setters above (or from a KernelParam::Bytes built internally).

    fn dispatch(&self, workgroups: [u32; 3]) {
        #[cfg(feature = "wgpu_backend")]
        {
            // Real WebGPU compute dispatch
            let shaders = self
                .compiled_shaders
                .lock()
                .expect("shader cache mutex poisoned");
            if let Some(shader) = shaders.get(&self.shader_name) {
                let params = self.params.lock().expect("kernel params mutex poisoned");

                // Create command encoder
                let mut encoder =
                    self.device
                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("Compute Command Encoder"),
                        });

                // Begin compute pass
                {
                    let mut compute_pass =
                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                            label: Some("Compute Pass"),
                            timestamp_writes: None,
                        });

                    // Set the compute pipeline
                    compute_pass.set_pipeline(&shader.pipeline);

                    match self.create_bind_group_from_params(shader, &params) {
                        Ok(bind_group) => {
                            compute_pass.set_bind_group(0, &bind_group, &[]);
                            // Dispatch the compute shader
                            compute_pass.dispatch_workgroups(
                                workgroups[0],
                                workgroups[1],
                                workgroups[2],
                            );
                        }
                        Err(_) => {
                            // Skip the dispatch: dispatching without the required
                            // bindings would trigger a wgpu validation error
                            eprintln!(
                                "Warning: Failed to create bind group for shader {}; skipping dispatch",
                                self.shader_name
                            );
                        }
                    }
                }

                // Submit the command buffer
                let command_buffer = encoder.finish();
                self.queue.submit(std::iter::once(command_buffer));

                eprintln!(
                    "WebGPU compute shader {} dispatched with workgroups: {:?}",
                    self.shader_name, workgroups
                );
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Fallback implementation - just log the execution
            eprintln!("Executing WebGPU shader {} (simulated)", self.shader_name);
            eprintln!("Work groups: {:?}", workgroups);
        }
    }
}

/// WebGPU buffer implementation
struct WebGPUBuffer {
    #[cfg(feature = "wgpu_backend")]
    device_buffer: Option<Buffer>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(not(feature = "wgpu_backend"))]
    device_buffer: Option<WgpuBuffer>,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
    size: usize,
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

// WebGPU buffer handles are safe to send between threads when properly synchronized
// The real wgpu types (Buffer, Queue) are Send + Sync
// For fallback types (raw pointers), we assume proper synchronization is handled externally
unsafe impl Send for WebGPUBuffer {}
unsafe impl Sync for WebGPUBuffer {}

impl GpuBufferImpl for WebGPUBuffer {
    fn size(&self) -> usize {
        self.size
    }

    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        #[cfg(feature = "wgpu_backend")]
        {
            // Validate data size
            if size > self.size {
                // In unsafe context, we can't return an error, so we'll just log and return
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
                return;
            }

            // Convert raw pointer to slice for WebGPU API
            let data_slice = std::slice::from_raw_parts(data, size);

            // Real WebGPU implementation - write data to buffer
            if let Some(ref buffer) = self.device_buffer {
                self.queue.write_buffer(buffer, 0, data_slice);
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Fallback implementation - just validate
            if size > self.size {
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
            }
            // In fallback mode, we just simulate the operation
        }
    }

    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        #[cfg(feature = "wgpu_backend")]
        {
            // Validate data size
            if size > self.size {
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
                return;
            }

            if let Some(ref buffer) = self.device_buffer {
                // Copy into a MAP_READ staging buffer, then map it for readback
                let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("scirs2-readback"),
                    size: size as u64,
                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });
                let mut encoder =
                    self.device
                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("scirs2-readback-enc"),
                        });
                encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
                self.queue.submit(Some(encoder.finish()));
                let slice = staging.slice(0..size as u64);
                let (tx, rx) = std::sync::mpsc::channel();
                slice.map_async(wgpu::MapMode::Read, move |r| {
                    let _ = tx.send(r);
                });
                // Note: on native targets map_async callbacks are normally driven
                // by polling the device; without an explicit device.poll here (see
                // the wgpu 26 note at the top of this file) this recv() relies on
                // the submission above to drive the callback and may stall on some
                // platforms.
                if let Ok(Ok(())) = rx.recv() {
                    let mapped = slice.get_mapped_range();
                    let dst = std::slice::from_raw_parts_mut(data, size);
                    dst.copy_from_slice(&mapped);
                    drop(mapped);
                    staging.unmap();
                } else {
                    eprintln!("Warning: map_async failed for readback");
                }
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Fallback implementation - just validate and zero out
            if size > self.size {
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
            }

            // Zero out the data as a placeholder
            let data_slice = std::slice::from_raw_parts_mut(data, size);
            data_slice.fill(0);
        }
    }

    fn device_ptr(&self) -> u64 {
        #[cfg(feature = "wgpu_backend")]
        {
            // WebGPU doesn't expose raw device pointers, so we return a placeholder
            // In a real implementation, this might return a handle or ID
            &self.device_buffer as *const _ as u64
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // The fallback handle is an Option<raw pointer>, so unwrap it explicitly
            self.device_buffer.map(|ptr| ptr as u64).unwrap_or(0)
        }
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

#[cfg(feature = "wgpu_backend")]
impl WebGPUKernelHandle {
    fn create_bind_group_from_params(
        &self,
        shader: &WebGPUShader,
        params: &HashMap<String, KernelParam>,
    ) -> Result<wgpu::BindGroup, GpuError> {
        let mut entries: Vec<wgpu::BindGroupEntry> = Vec::new();
        // Hold uniform buffers so their lifetime extends until after bind_group creation
        let mut owned_uniform_buffers: Vec<wgpu::Buffer> = Vec::new();
        let mut uniform_bytes: Vec<u8> = Vec::new();
        for info in &shader.binding_infos {
            match info.kind {
                BindingKind::StorageRw | BindingKind::StorageRead => {
                    if let Some(KernelParam::Buffer(buf)) = params.get(&info.name) {
                        if let Some(wbuf) = buf.as_any().downcast_ref::<WebGPUBuffer>() {
                            if let Some(ref inner) = wbuf.device_buffer {
                                entries.push(wgpu::BindGroupEntry {
                                    binding: info.binding,
                                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                                        buffer: inner,
                                        offset: 0,
                                        size: None,
                                    }),
                                });
                            }
                        }
                    } else {
                        return Err(GpuError::InvalidParameter(format!(
                            "Missing buffer param '{}'",
                            info.name
                        )));
                    }
                }
                BindingKind::Uniform => {
                    // Collect scalars/bytes keyed by the uniform name or a
                    // "<name>." prefix. HashMap iteration order is nondeterministic,
                    // so sort the keys to pack fields in a stable (lexicographic)
                    // order; callers that need an exact struct layout should pass a
                    // single pre-packed Bytes blob instead.
                    let mut keys: Vec<&String> = params
                        .keys()
                        .filter(|k| {
                            k.as_str() == info.name
                                || k.starts_with(&(info.name.clone() + "."))
                        })
                        .collect();
                    keys.sort();
                    for k in keys {
                        match &params[k] {
                            KernelParam::U32(u) => {
                                uniform_bytes.extend_from_slice(&u.to_le_bytes())
                            }
                            KernelParam::I32(i) => {
                                uniform_bytes.extend_from_slice(&i.to_le_bytes())
                            }
                            KernelParam::F32(f) => {
                                uniform_bytes.extend_from_slice(&f.to_le_bytes())
                            }
                            KernelParam::F64(f) => {
                                uniform_bytes.extend_from_slice(&f.to_le_bytes())
                            }
                            KernelParam::Bytes(b) => uniform_bytes.extend_from_slice(b),
                            KernelParam::Buffer(_) => {}
                        }
                    }
                }
            }
        }
        if !uniform_bytes.is_empty() {
            // Pad to a 16-byte multiple to satisfy uniform buffer alignment rules
            while uniform_bytes.len() % 16 != 0 {
                uniform_bytes.push(0);
            }
            if let Some(uinfo) = shader
                .binding_infos
                .iter()
                .find(|b| matches!(b.kind, BindingKind::Uniform))
            {
                if let Ok(mut list) = self.ephemeral_uniforms.lock() {
                    list.clear();
                    let ubuf = self
                        .device
                        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                            label: Some("scirs2-uniforms"),
                            contents: &uniform_bytes,
                            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                        });
                    // One clone stays alive across dispatches, one backs this bind group
                    list.push(ubuf.clone());
                    owned_uniform_buffers.push(ubuf);
                    let buf_ref = owned_uniform_buffers
                        .last()
                        .expect("uniform buffer was just pushed");
                    entries.push(wgpu::BindGroupEntry {
                        binding: uinfo.binding,
                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                            buffer: buf_ref,
                            offset: 0,
                            size: None,
                        }),
                    });
                }
            }
        } else if let Ok(mut list) = self.ephemeral_uniforms.lock() {
            list.clear();
        }
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("scirs2-bind-group"),
            layout: &shader.bind_group_layout,
            entries: &entries,
        });
        Ok(bind_group)
    }
}

impl Drop for WebGPUBuffer {
    fn drop(&mut self) {
        // Return the buffer to the memory pool if possible (the real and
        // fallback builds take the same path here)
        if let Ok(mut pool) = self.memory_pool.lock() {
            if let Some(buffer) = self.device_buffer.take() {
                pool.deallocate(buffer);
            }
        }
    }
}

/// CPU fallback buffer for when WebGPU buffer allocation fails
/// This provides a graceful degradation when GPU memory is exhausted
struct WebGPUCpuFallbackBuffer {
    // Mutex provides the interior mutability needed to write through &self
    data: Mutex<Vec<u8>>,
    size: usize,
    #[allow(dead_code)]
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

impl GpuBufferImpl for WebGPUCpuFallbackBuffer {
    fn size(&self) -> usize {
        self.size
    }

    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        if size > self.size {
            eprintln!("Warning: WebGPU CPU fallback buffer copy_from_host size mismatch");
            return;
        }

        // This is a CPU fallback, so the copy is a plain memcpy guarded by the mutex
        let src = std::slice::from_raw_parts(data, size);
        if let Ok(mut buf) = self.data.lock() {
            buf[..size].copy_from_slice(src);
        }
    }

    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        if size > self.size {
            eprintln!("Warning: WebGPU CPU fallback buffer copy_to_host size mismatch");
            return;
        }

        // Copy from the CPU-side buffer back to the host pointer
        let dst = std::slice::from_raw_parts_mut(data, size);
        if let Ok(buf) = self.data.lock() {
            let copy_size = size.min(buf.len());
            dst[..copy_size].copy_from_slice(&buf[..copy_size]);
        }
    }

    fn device_ptr(&self) -> u64 {
        // No device pointer exists for a CPU buffer; expose the host address as a stand-in
        self.data
            .lock()
            .map(|buf| buf.as_ptr() as u64)
            .unwrap_or(0)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

// Safety: WebGPUCpuFallbackBuffer only contains owned, mutex-guarded data
unsafe impl Send for WebGPUCpuFallbackBuffer {}
unsafe impl Sync for WebGPUCpuFallbackBuffer {}

/// WebGPU memory pool for efficient buffer management
struct WebGPUMemoryPool {
    #[cfg(feature = "wgpu_backend")]
    available_buffers: HashMap<usize, Vec<Buffer>>,
    #[cfg(not(feature = "wgpu_backend"))]
    available_buffers: HashMap<usize, Vec<WgpuBuffer>>,
    #[allow(dead_code)]
    total_size: usize,
    used_size: usize,
}

impl WebGPUMemoryPool {
    fn new(total_size: usize) -> Self {
        Self {
            available_buffers: HashMap::new(),
            total_size,
            used_size: 0,
        }
    }

    #[cfg(feature = "wgpu_backend")]
    fn allocate(&mut self, size: usize) -> Option<Buffer> {
        // Try to find a suitable buffer in the pool
        if let Some(buffers) = self.available_buffers.get_mut(&size) {
            if let Some(buffer) = buffers.pop() {
                self.used_size += size;
                return Some(buffer);
            }
        }
        None
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn allocate(&mut self, size: usize) -> Option<WgpuBuffer> {
        // Try to find a suitable buffer in the pool
        if let Some(buffers) = self.available_buffers.get_mut(&size) {
            if let Some(buffer) = buffers.pop() {
                self.used_size += size;
                return Some(buffer);
            }
        }
        None
    }

    #[cfg(feature = "wgpu_backend")]
    fn deallocate(&mut self, buffer: Buffer) {
        // Return the buffer to the pool, keyed by its actual size
        let size = buffer.size() as usize;
        self.available_buffers.entry(size).or_default().push(buffer);
        self.used_size = self.used_size.saturating_sub(size);
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn deallocate(&mut self, buffer: WgpuBuffer) {
        // Fallback implementation: simulated handles carry no size, so use a
        // placeholder bucket
        let size = 1024;
        self.available_buffers.entry(size).or_default().push(buffer);
        self.used_size = self.used_size.saturating_sub(size);
    }

    #[allow(dead_code)]
    fn get_memory_usage(&self) -> (usize, usize) {
        (self.used_size, self.total_size)
    }
}
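
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity checks for the WGSL string helpers above. These exercise pure
    // string handling only, so they run without a GPU and without the
    // wgpu_backend feature.
    #[test]
    fn extracts_compute_entry_point() {
        assert_eq!(
            WebGPUContext::extract_entry_point(ADAM_SHADER_WGSL),
            Some("adam_update")
        );
        assert_eq!(
            WebGPUContext::extract_entry_point(GEMM_SHADER_WGSL),
            Some("gemm")
        );
        assert_eq!(WebGPUContext::extract_entry_point("fn no_attr() {}"), None);
    }

    #[test]
    fn extracts_var_names_from_declarations() {
        let storage = "@group(0) @binding(1) var<storage, read> grads: array<f32>;";
        assert_eq!(extract_var_name(storage), Some("grads"));
        let uniform = "@group(0) @binding(4) var<uniform> uniforms: AdamUniforms;";
        assert_eq!(extract_var_name(uniform), Some("uniforms"));
        assert_eq!(extract_var_name("let x = 1.0;"), None);
    }
}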