Skip to main content

scirs2_core/gpu/backends/
wgpu.rs

1//! WebGPU backend implementation for GPU operations
2//!
3//! This module provides WebGPU-specific implementations for cross-platform GPU operations.
4
use std::collections::HashMap;
// `Arc`/`Mutex` are used by both the real and the stub backend, so this import must not be
// feature-gated. Device polling goes through fully qualified `wgpu::PollType` paths instead of
// an import here.
use std::sync::{Arc, Mutex};

use crate::gpu::{GpuBufferImpl, GpuCompilerImpl, GpuContextImpl, GpuError, GpuKernelImpl};

#[cfg(feature = "wgpu_backend")]
#[allow(unused_imports)]
use wgpu::{
    util::DeviceExt, Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, Buffer,
    BufferBindingType, BufferDescriptor, BufferUsages, ComputePipeline, Device, DeviceDescriptor,
    Features, Instance, InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
    ShaderModuleDescriptor, ShaderSource, ShaderStages, StorageTextureAccess, TextureFormat,
    TextureSampleType, TextureViewDimension,
};
22
// Stand-in handle types used when the `wgpu_backend` feature is disabled. All four are the same
// opaque untyped pointer; the stub backend only ever stores placeholder values in them.
#[cfg(not(feature = "wgpu_backend"))]
type OpaqueHandle = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuDevice = OpaqueHandle;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuQueue = OpaqueHandle;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuBuffer = OpaqueHandle;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuComputePipeline = OpaqueHandle;
32
/// A compiled WebGPU compute pipeline, containing all state needed to dispatch a compute shader.
///
/// Created by [`try_compile_wgsl`] or [`WebGPUContext::compile_to_pipeline`]. On hosts without a
/// GPU adapter this is never constructed and those functions return an error instead.
#[cfg(feature = "wgpu_backend")]
pub struct WgpuComputePipeline {
    /// The underlying wgpu compute pipeline.
    pub pipeline: ComputePipeline,
    /// The bind group layout inferred from a line-based scan of the WGSL source (group 0 only).
    pub bind_group_layout: BindGroupLayout,
    /// Workgroup size extracted from the `@workgroup_size(...)` attribute; defaults to `[64, 1, 1]`.
    pub workgroup_size: [u32; 3],
}

#[cfg(feature = "wgpu_backend")]
// SAFETY: wgpu's `ComputePipeline` and `BindGroupLayout` are `Send + Sync` on all native backends.
// NOTE(review): on native targets these impls are presumably redundant (the auto traits already
// apply); on wasm32 wgpu handles are not thread-safe, so confirm these impls are never compiled
// for such targets.
unsafe impl Send for WgpuComputePipeline {}
#[cfg(feature = "wgpu_backend")]
// SAFETY: same justification as the `Send` impl — the pipeline and layout are immutable wgpu
// handles, so shared references to them may be used from several threads.
unsafe impl Sync for WgpuComputePipeline {}
52
/// Compile `source` as a WGSL compute shader into a [`WgpuComputePipeline`].
///
/// Every call creates a fresh [`WebGPUContext`] (adapter + device). When compiling several
/// shaders, create one context and call [`WebGPUContext::compile_to_pipeline`] on it instead.
///
/// # Errors
/// Returns an error if no wgpu adapter or device is available on the host (e.g. headless CI
/// without a GPU).
///
/// # Panics
/// WGSL problems are reported by wgpu itself, not through the returned `Result`. Truly invalid
/// WGSL makes wgpu panic. Syntactically valid but semantically broken shaders fail at
/// pipeline creation through the same wgpu error path.
///
/// # Example
/// ```rust,no_run
/// # #[cfg(feature = "wgpu_backend")]
/// # {
/// use scirs2_core::gpu::backends::try_compile_wgsl;
/// let pipeline = try_compile_wgsl(r#"
///     @group(0) @binding(0) var<storage, read_write> out: array<f32>;
///     @compute @workgroup_size(64)
///     fn main(@builtin(global_invocation_id) gid: vec3<u32>) { out[gid.x] = f32(gid.x); }
/// "#).expect("shader compiled");
/// let _ = pipeline;
/// # }
/// ```
#[cfg(feature = "wgpu_backend")]
pub fn try_compile_wgsl(source: &str) -> Result<WgpuComputePipeline, GpuError> {
    // The temporary context lives until the end of this expression; the returned pipeline is
    // an owned value and does not borrow from it.
    WebGPUContext::new()?.compile_to_pipeline(source)
}
78
/// Add two `f32` vectors element-wise on the GPU, end to end.
///
/// Creates a fresh [`WebGPUContext`] and delegates to [`WebGPUContext::run_vector_add`]. That
/// call uploads `a` and `b`, dispatches the WGSL kernel, and reads the sum back.
///
/// # Errors
/// Returns `Err` when no adapter/device is available, and also propagates every error of
/// [`WebGPUContext::run_vector_add`] (for example, `a` and `b` having different lengths).
#[cfg(feature = "wgpu_backend")]
pub fn run_vector_add_wgsl(a: &[f32], b: &[f32]) -> Result<Vec<f32>, GpuError> {
    WebGPUContext::new().and_then(|ctx| ctx.run_vector_add(a, b))
}
88
// WebGPU shader source templates — used by the kernel registry in kernels/mod.rs
// for the GEMM BLAS kernel (exposed here as a named constant for external use).

/// WGSL source for a straightforward (untiled) GEMM kernel.
///
/// Computes `C = alpha * A * B + beta * C`, where A is M×K, B is K×N and C is M×N. All three are
/// flat row-major `f32` arrays. Each invocation produces one element of C with a plain dot
/// product over K: `global_id.x` selects the row and `global_id.y` the column. Workgroups are
/// 8×8 and there is no shared-memory tiling, so dispatch `ceil(M / 8) × ceil(N / 8)` workgroups.
/// Out-of-range invocations return early.
///
/// Following the BLAS convention, C is not read when `beta == 0`. In that case C does not need
/// to be initialised, and stale NaN/Inf values in it cannot leak into the result.
///
/// Bindings (group 0):
/// - 0 → `matrix_a` (storage, read)
/// - 1 → `matrix_b` (storage, read)
/// - 2 → `matrix_c` (storage, read_write)
/// - 3 → `uniforms` (uniform): `M`, `N`, `K` (`u32`), `alpha`, `beta` (`f32`)
pub const GEMM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> matrix_c: array<f32>;

struct GemmUniforms {
    M: u32,
    N: u32,
    K: u32,
    alpha: f32,
    beta: f32,
};

@group(0) @binding(3) var<uniform> uniforms: GemmUniforms;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row = global_id.x;
    let col = global_id.y;

    if row >= uniforms.M || col >= uniforms.N { return; }

    var sum = 0.0f;
    for (var k = 0u; k < uniforms.K; k++) {
        sum += matrix_a[row * uniforms.K + k] * matrix_b[k * uniforms.N + col];
    }

    let idx = row * uniforms.N + col;
    var out_value = uniforms.alpha * sum;
    // BLAS convention: with beta == 0, C is output-only and its old contents are ignored.
    if uniforms.beta != 0.0 {
        out_value += uniforms.beta * matrix_c[idx];
    }
    matrix_c[idx] = out_value;
}
"#;
129
/// WebGPU context wrapper.
///
/// Bundles the device and queue handles with a cache of compiled shaders and a buffer memory
/// pool. Every member sits behind an `Arc`, so the context handed out by `create_compiler`
/// shares the same device, queue, cache and pool as the original.
pub struct WebGPUContext {
    /// Real wgpu device (`wgpu_backend` enabled).
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    /// Real wgpu submission queue (`wgpu_backend` enabled).
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    /// Opaque stub device handle (`wgpu_backend` disabled).
    #[cfg(not(feature = "wgpu_backend"))]
    device: Arc<WgpuDevice>,
    /// Opaque stub queue handle (`wgpu_backend` disabled).
    #[cfg(not(feature = "wgpu_backend"))]
    queue: Arc<WgpuQueue>,
    /// Cache of compiled shaders keyed by a `String`.
    /// NOTE(review): presumably the kernel name — confirm at the insertion site.
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    /// Buffer pool that `create_buffer` tries before allocating directly.
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

// SAFETY: Mutable shared state (shader cache, memory pool) is only reachable through a `Mutex`.
// With `wgpu_backend`, `Device`/`Queue` are relied upon to be thread-safe wgpu handles
// (NOTE(review): true for native wgpu targets; confirm for wasm32). Without the feature the
// device/queue are raw stub pointers holding placeholder values; these impls are what make that
// configuration `Send + Sync` at all (NOTE(review): confirm nothing ever dereferences them).
unsafe impl Send for WebGPUContext {}
unsafe impl Sync for WebGPUContext {}
147
148impl WebGPUContext {
149    /// Create a new WebGPU context
150    pub fn new() -> Result<Self, GpuError> {
151        #[cfg(feature = "wgpu_backend")]
152        {
153            // Real WebGPU implementation
154            let instance_desc = InstanceDescriptor {
155                backends: Backends::all(),
156                flags: wgpu::InstanceFlags::default(),
157                memory_budget_thresholds: Default::default(),
158                backend_options: Default::default(),
159                display: None,
160            };
161            let instance = Instance::new(instance_desc);
162
163            let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
164                power_preference: PowerPreference::HighPerformance,
165                compatible_surface: None,
166                force_fallback_adapter: false,
167            }))
168            .map_err(|e| GpuError::Other(format!("Failed to find WebGPU adapter: {e}")))?;
169
170            let device_descriptor = DeviceDescriptor {
171                label: Some("SciRS2 WebGPU Device"),
172                required_features: Features::empty(),
173                required_limits: Limits::default(),
174                // Newer wgpu versions removed/changed some fields (e.g. trace Option). Use defaults for the rest.
175                ..Default::default()
176            };
177
178            let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor))
179                .map_err(|e| GpuError::Other(format!("{e}")))?;
180
181            Ok(Self {
182                device: Arc::new(device),
183                queue: Arc::new(queue),
184                compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
185                memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), // 1GB pool
186            })
187        }
188        #[cfg(not(feature = "wgpu_backend"))]
189        {
190            // Fallback implementation
191            let device = Self::initialize_webgpu()?;
192            let queue = Self::create_queue(device)?;
193
194            Ok(Self {
195                device,
196                queue,
197                compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
198                memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), // 1GB pool
199            })
200        }
201    }
202
203    /// Check if WebGPU is available and working
204    pub fn is_available() -> bool {
205        #[cfg(feature = "wgpu_backend")]
206        {
207            // Real WebGPU implementation - try to create an instance and adapter
208            let instance_desc = InstanceDescriptor {
209                backends: Backends::all(),
210                flags: wgpu::InstanceFlags::default(),
211                memory_budget_thresholds: Default::default(),
212                backend_options: Default::default(),
213                display: None,
214            };
215            let instance = Instance::new(instance_desc);
216
217            // Try to get an adapter (this is async, so we use a simple runtime check)
218            pollster::block_on(async {
219                instance
220                    .request_adapter(&RequestAdapterOptions {
221                        power_preference: PowerPreference::default(),
222                        compatible_surface: None,
223                        force_fallback_adapter: false,
224                    })
225                    .await
226                    .is_ok()
227            })
228        }
229        #[cfg(not(feature = "wgpu_backend"))]
230        {
231            // Fallback: return false since we don't have real WebGPU
232            false
233        }
234    }
235
236    /// Compile a shader from WGSL source
237    fn compile_shader_internal(&self, source: &str, name: &str) -> Result<WebGPUShader, GpuError> {
238        #[cfg(feature = "wgpu_backend")]
239        {
240            // Real WebGPU implementation
241            let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
242                label: Some(name),
243                source: ShaderSource::Wgsl(source.into()),
244            });
245
246            // Extract entry point from source or use default
247            let entry_point = Self::extract_entry_point(source).unwrap_or("main");
248
249            // Create bind group layout + reflection infos
250            let (bind_group_layout, binding_infos) =
251                self.create_bind_group_layout_from_source(source, name)?;
252
253            // Create pipeline layout with our bind group layout
254            let pipeline_layout =
255                self.device
256                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
257                        label: Some(&format!("{}_layout", name)),
258                        bind_group_layouts: &[Some(&bind_group_layout)],
259                        // wgpu 28+: immediate_size replaces push_constant_ranges
260                        ..Default::default()
261                    });
262
263            let compute_pipeline =
264                self.device
265                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
266                        label: Some(&format!("{}_pipeline", name)),
267                        layout: Some(&pipeline_layout),
268                        module: &shader_module,
269                        entry_point: Some(entry_point),
270                        compilation_options: Default::default(),
271                        cache: None,
272                    });
273
274            let workgroup_size = extract_workgroup_size(source);
275
276            Ok(WebGPUShader {
277                pipeline: compute_pipeline,
278                bind_group_layout,
279                name: name.to_string(),
280                binding_infos,
281                workgroup_size,
282            })
283        }
284        #[cfg(not(feature = "wgpu_backend"))]
285        {
286            // Fallback implementation
287            let pipeline = Self::compile_wgsl_source(source, name)?;
288
289            Ok(WebGPUShader {
290                pipeline,
291                bind_group_layout: std::ptr::null_mut(),
292                name: name.to_string(),
293                binding_infos: Vec::new(),
294                workgroup_size: [64, 1, 1],
295            })
296        }
297    }
298
299    /// Create bind group layout from WGSL source analysis
300    #[cfg(feature = "wgpu_backend")]
301    fn create_bind_group_layout_from_source(
302        &self,
303        source: &str,
304        name: &str,
305    ) -> Result<(BindGroupLayout, Vec<BindingInfo>), GpuError> {
306        #[derive(Default)]
307        struct PendingAttr {
308            group: Option<u32>,
309            binding: Option<u32>,
310        }
311        let mut pending = PendingAttr::default();
312        let mut entries: Vec<BindGroupLayoutEntry> = Vec::new();
313        let mut infos: Vec<BindingInfo> = Vec::new();
314
315        fn strip_comment(line: &str) -> &str {
316            line.split_once("//").map(|(a, _)| a).unwrap_or(line)
317        }
318
319        for raw_line in source.lines() {
320            let line = strip_comment(raw_line).trim();
321            if line.is_empty() {
322                continue;
323            }
324
325            if let Some(i) = line.find("@group(") {
326                if let Some(end) = line[i + 7..].find(')') {
327                    if let Ok(g) = line[i + 7..i + 7 + end].parse::<u32>() {
328                        pending.group = Some(g);
329                    }
330                }
331            }
332            if let Some(i) = line.find("@binding(") {
333                if let Some(end) = line[i + 9..].find(')') {
334                    if let Ok(b) = line[i + 9..i + 9 + end].parse::<u32>() {
335                        pending.binding = Some(b);
336                    }
337                }
338            }
339
340            if line.contains("var<") {
341                // variable declaration
342                if pending.group.unwrap_or(0) == 0 {
343                    // only group 0 for now
344                    let binding_num = pending.binding.unwrap_or_else(|| entries.len() as u32);
345                    let name = extract_var_name(line).unwrap_or("");
346                    let storage = line.contains("var<storage");
347                    let uniform = line.contains("var<uniform");
348                    let read_only = storage
349                        && (line.contains(", read>")
350                            || line.contains("var<storage, read>")
351                            || line.contains("var<storage, read,"));
352                    if storage {
353                        entries.push(BindGroupLayoutEntry {
354                            binding: binding_num,
355                            visibility: ShaderStages::COMPUTE,
356                            ty: BindingType::Buffer {
357                                ty: BufferBindingType::Storage { read_only },
358                                has_dynamic_offset: false,
359                                min_binding_size: None,
360                            },
361                            count: None,
362                        });
363                        infos.push(BindingInfo {
364                            binding: binding_num,
365                            name: name.to_string(),
366                            kind: if read_only {
367                                BindingKind::StorageRead
368                            } else {
369                                BindingKind::StorageRw
370                            },
371                        });
372                    } else if uniform {
373                        entries.push(BindGroupLayoutEntry {
374                            binding: binding_num,
375                            visibility: ShaderStages::COMPUTE,
376                            ty: BindingType::Buffer {
377                                ty: BufferBindingType::Uniform,
378                                has_dynamic_offset: false,
379                                min_binding_size: None,
380                            },
381                            count: None,
382                        });
383                        infos.push(BindingInfo {
384                            binding: binding_num,
385                            name: name.to_string(),
386                            kind: BindingKind::Uniform,
387                        });
388                    }
389                }
390                pending = PendingAttr::default();
391            }
392        }
393
394        if entries.is_empty() {
395            entries.push(BindGroupLayoutEntry {
396                binding: 0,
397                visibility: ShaderStages::COMPUTE,
398                ty: BindingType::Buffer {
399                    ty: BufferBindingType::Storage { read_only: false },
400                    has_dynamic_offset: false,
401                    min_binding_size: None,
402                },
403                count: None,
404            });
405            infos.push(BindingInfo {
406                binding: 0,
407                name: "_unnamed".into(),
408                kind: BindingKind::StorageRw,
409            });
410        }
411
412        // Deduplicate by binding number
413        let mut seen = std::collections::HashSet::new();
414        let mut dedup_entries = Vec::new();
415        let mut dedup_infos = Vec::new();
416        for (e, info) in entries.into_iter().zip(infos) {
417            if seen.insert(e.binding) {
418                dedup_entries.push(e);
419                dedup_infos.push(info);
420            }
421        }
422
423        let bind_group_layout = self
424            .device
425            .create_bind_group_layout(&BindGroupLayoutDescriptor {
426                label: Some(&format!("{}_bind_group_layout", name)),
427                entries: &dedup_entries,
428            });
429        Ok((bind_group_layout, dedup_infos))
430    }
431
432    /// Return a reference to the underlying `wgpu::Device`.
433    #[cfg(feature = "wgpu_backend")]
434    pub fn device(&self) -> &Device {
435        &self.device
436    }
437
438    /// Return a reference to the underlying `wgpu::Queue`.
439    #[cfg(feature = "wgpu_backend")]
440    pub fn queue(&self) -> &Queue {
441        &self.queue
442    }
443
444    /// Allocate device memory
445    #[cfg(feature = "wgpu_backend")]
446    pub fn allocate_device_memory(&self, size: usize) -> Result<Buffer, GpuError> {
447        let buffer = self.device.create_buffer(&BufferDescriptor {
448            label: Some("SciRS2 Buffer"),
449            size: size as u64,
450            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
451            mapped_at_creation: false,
452        });
453
454        Ok(buffer)
455    }
456
457    /// Allocate device memory (fallback)
458    #[cfg(not(feature = "wgpu_backend"))]
459    pub fn allocate_device_memory_2(&self, size: usize) -> Result<WgpuBuffer, GpuError> {
460        // Fallback implementation: return a simulated buffer handle
461        Ok((0x1000 + size) as WgpuBuffer)
462    }
463
464    // Fallback methods for when WebGPU is not available
465    #[cfg(not(feature = "wgpu_backend"))]
466    fn initialize_webgpu() -> Result<WgpuDevice, GpuError> {
467        // Stub implementation
468        Ok(0x1 as WgpuDevice)
469    }
470
471    #[cfg(not(feature = "wgpu_backend"))]
472    fn create_queue(device: WgpuDevice) -> Result<WgpuQueue, GpuError> {
473        // Stub implementation
474        Ok(0x2 as WgpuQueue)
475    }
476
477    #[cfg(not(feature = "wgpu_backend"))]
478    fn compile_wgsl_source(source: &str, name: &str) -> Result<WgpuComputePipeline, GpuError> {
479        // Stub implementation
480        Ok(0x3 as WgpuComputePipeline)
481    }
482
483    /// Compile WGSL source into a [`WgpuComputePipeline`] (real-wgpu path only).
484    ///
485    /// This exposes the same compilation path as [`try_compile_wgsl`] but operates
486    /// on an already-created context so the adapter/device creation overhead is
487    /// incurred only once.
488    #[cfg(feature = "wgpu_backend")]
489    pub fn compile_to_pipeline(&self, source: &str) -> Result<WgpuComputePipeline, GpuError> {
490        let shader = self.compile_shader_internal(source, "scirs2-pipeline")?;
491        Ok(WgpuComputePipeline {
492            pipeline: shader.pipeline,
493            bind_group_layout: shader.bind_group_layout,
494            workgroup_size: shader.workgroup_size,
495        })
496    }
497
498    /// Run a vector-add end-to-end: upload `a` and `b`, dispatch, read back `result`.
499    #[cfg(feature = "wgpu_backend")]
500    pub fn run_vector_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, GpuError> {
501        use wgpu::{util::DeviceExt as _, BufferUsages};
502
503        let n = a.len();
504        if n != b.len() {
505            return Err(GpuError::InvalidParameter(
506                "vectors must have equal length".into(),
507            ));
508        }
509
510        const VECTOR_ADD_WGSL: &str = r#"
511@group(0) @binding(0) var<storage, read>       a      : array<f32>;
512@group(0) @binding(1) var<storage, read>       b      : array<f32>;
513@group(0) @binding(2) var<storage, read_write> result : array<f32>;
514
515@compute @workgroup_size(64)
516fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
517    let idx = global_id.x;
518    if idx < arrayLength(&result) {
519        result[idx] = a[idx] + b[idx];
520    }
521}
522"#;
523
524        // Compile shader
525        let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
526            label: Some("vector-add"),
527            source: ShaderSource::Wgsl(VECTOR_ADD_WGSL.into()),
528        });
529
530        // Build bind group layout explicitly (3 storage bindings)
531        let bgl = self
532            .device
533            .create_bind_group_layout(&BindGroupLayoutDescriptor {
534                label: Some("vector-add-bgl"),
535                entries: &[
536                    // binding 0: a (read-only storage)
537                    BindGroupLayoutEntry {
538                        binding: 0,
539                        visibility: ShaderStages::COMPUTE,
540                        ty: BindingType::Buffer {
541                            ty: BufferBindingType::Storage { read_only: true },
542                            has_dynamic_offset: false,
543                            min_binding_size: None,
544                        },
545                        count: None,
546                    },
547                    // binding 1: b (read-only storage)
548                    BindGroupLayoutEntry {
549                        binding: 1,
550                        visibility: ShaderStages::COMPUTE,
551                        ty: BindingType::Buffer {
552                            ty: BufferBindingType::Storage { read_only: true },
553                            has_dynamic_offset: false,
554                            min_binding_size: None,
555                        },
556                        count: None,
557                    },
558                    // binding 2: result (read-write storage)
559                    BindGroupLayoutEntry {
560                        binding: 2,
561                        visibility: ShaderStages::COMPUTE,
562                        ty: BindingType::Buffer {
563                            ty: BufferBindingType::Storage { read_only: false },
564                            has_dynamic_offset: false,
565                            min_binding_size: None,
566                        },
567                        count: None,
568                    },
569                ],
570            });
571
572        let pipeline_layout = self
573            .device
574            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
575                label: Some("vector-add-layout"),
576                bind_group_layouts: &[Some(&bgl)],
577                ..Default::default()
578            });
579
580        let pipeline = self
581            .device
582            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
583                label: Some("vector-add-pipeline"),
584                layout: Some(&pipeline_layout),
585                module: &shader_module,
586                entry_point: Some("main"),
587                compilation_options: Default::default(),
588                cache: None,
589            });
590
591        // Upload input buffers
592        let a_bytes: Vec<u8> = a.iter().flat_map(|f| f.to_le_bytes()).collect();
593        let b_bytes: Vec<u8> = b.iter().flat_map(|f| f.to_le_bytes()).collect();
594        let result_size = std::mem::size_of_val(a) as u64;
595
596        let buf_a = self
597            .device
598            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
599                label: Some("vector-add-a"),
600                contents: &a_bytes,
601                usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
602            });
603        let buf_b = self
604            .device
605            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
606                label: Some("vector-add-b"),
607                contents: &b_bytes,
608                usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
609            });
610        let buf_result = self.device.create_buffer(&BufferDescriptor {
611            label: Some("vector-add-result"),
612            size: result_size,
613            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
614            mapped_at_creation: false,
615        });
616
617        // Bind group
618        let bind_group = self.device.create_bind_group(&BindGroupDescriptor {
619            label: Some("vector-add-bg"),
620            layout: &bgl,
621            entries: &[
622                BindGroupEntry {
623                    binding: 0,
624                    resource: buf_a.as_entire_binding(),
625                },
626                BindGroupEntry {
627                    binding: 1,
628                    resource: buf_b.as_entire_binding(),
629                },
630                BindGroupEntry {
631                    binding: 2,
632                    resource: buf_result.as_entire_binding(),
633                },
634            ],
635        });
636
637        // Encode and dispatch
638        let mut encoder = self
639            .device
640            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
641                label: Some("vector-add-encoder"),
642            });
643        {
644            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
645                label: Some("vector-add-pass"),
646                timestamp_writes: None,
647            });
648            cpass.set_pipeline(&pipeline);
649            cpass.set_bind_group(0, &bind_group, &[]);
650            let workgroups = (n as u32 + 63) / 64;
651            cpass.dispatch_workgroups(workgroups, 1, 1);
652        }
653
654        // Readback via staging buffer
655        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
656            label: Some("vector-add-staging"),
657            size: result_size,
658            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
659            mapped_at_creation: false,
660        });
661        encoder.copy_buffer_to_buffer(&buf_result, 0, &staging, 0, result_size);
662        self.queue.submit(Some(encoder.finish()));
663
664        // Poll until GPU work completes (required on native backends before map_async fires)
665        self.device
666            .poll(wgpu::PollType::wait_indefinitely())
667            .map_err(|e| GpuError::Other(format!("GPU poll error: {e:?}")))?;
668
669        let slice = staging.slice(0..result_size);
670        let (tx, rx) = std::sync::mpsc::channel();
671        slice.map_async(wgpu::MapMode::Read, move |r| {
672            let _ = tx.send(r);
673        });
674
675        // Poll again to drive the map callback to completion
676        self.device
677            .poll(wgpu::PollType::wait_indefinitely())
678            .map_err(|e| GpuError::Other(format!("GPU poll error during map: {e:?}")))?;
679
680        rx.recv()
681            .map_err(|_| GpuError::Other("Channel closed during map_async".into()))?
682            .map_err(|e| GpuError::Other(format!("map_async failed: {e:?}")))?;
683
684        let mapped = slice.get_mapped_range();
685        let result: Vec<f32> = mapped
686            .chunks_exact(4)
687            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
688            .collect();
689        drop(mapped);
690        staging.unmap();
691
692        Ok(result)
693    }
694
695    /// Extract the entry point function name from WGSL source code
696    fn extract_entry_point(source: &str) -> Option<&str> {
697        let lines: Vec<&str> = source.lines().collect();
698
699        for (i, line) in lines.iter().enumerate() {
700            let trimmed = line.trim();
701
702            // Check if this line contains @compute
703            if trimmed.contains("@compute") {
704                // The function might be on the same line or the next line
705                let mut search_line = trimmed;
706                let mut search_idx = 0;
707
708                // If @compute and function are not on the same line, check next line
709                if !search_line.contains("fn ") && search_idx + 1 < lines.len() {
710                    search_idx += 1;
711                    search_line = lines[search_idx].trim();
712                }
713
714                // Extract function name
715                if let Some(start) = search_line.find("fn ") {
716                    let remaining = &search_line[start + 3..];
717                    if let Some(end) = remaining.find('(') {
718                        let funcname = remaining[..end].trim();
719                        return Some(funcname);
720                    }
721                }
722            }
723        }
724
725        None
726    }
727}
728
729impl GpuContextImpl for WebGPUContext {
730    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
731        // Try to allocate from memory pool first
732        if let Ok(mut pool) = self.memory_pool.lock() {
733            if let Some(device_buffer) = pool.allocate(size) {
734                return Arc::new(WebGPUBuffer {
735                    device_buffer: Some(device_buffer),
736                    #[cfg(feature = "wgpu_backend")]
737                    queue: Arc::clone(&self.queue),
738                    #[cfg(feature = "wgpu_backend")]
739                    device: Arc::clone(&self.device),
740                    #[cfg(not(feature = "wgpu_backend"))]
741                    queue: self.queue,
742                    size,
743                    memory_pool: Arc::clone(&self.memory_pool),
744                });
745            }
746        }
747
748        // Fallback to direct allocation
749        let device_buffer = match self.allocate_device_memory(size) {
750            Ok(buffer) => buffer,
751            Err(e) => {
752                // Log the WebGPU allocation failure and create a CPU fallback
753                eprintln!(
754                    "Warning: WebGPU buffer allocation failed ({}), creating CPU fallback buffer",
755                    e
756                );
757
758                #[cfg(feature = "wgpu_backend")]
759                {
760                    // Create a CPU fallback buffer with minimal size for WebGPU compatibility
761                    // This is a last resort when GPU memory is exhausted
762                    return Arc::new(WebGPUCpuFallbackBuffer {
763                        data: vec![0u8; size],
764                        size,
765                        memory_pool: Arc::clone(&self.memory_pool),
766                    });
767                }
768                #[cfg(not(feature = "wgpu_backend"))]
769                {
770                    (0x2000 + size) as WgpuBuffer
771                }
772            }
773        };
774
775        Arc::new(WebGPUBuffer {
776            device_buffer: Some(device_buffer),
777            #[cfg(feature = "wgpu_backend")]
778            queue: Arc::clone(&self.queue),
779            #[cfg(feature = "wgpu_backend")]
780            device: Arc::clone(&self.device),
781            #[cfg(not(feature = "wgpu_backend"))]
782            queue: self.queue,
783            size,
784            memory_pool: Arc::clone(&self.memory_pool),
785        })
786    }
787
788    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
789        Arc::new(WebGPUCompiler {
790            context: Arc::new(WebGPUContext {
791                memory_pool: Arc::clone(&self.memory_pool),
792                compiled_shaders: Arc::clone(&self.compiled_shaders),
793                #[cfg(feature = "wgpu_backend")]
794                device: Arc::clone(&self.device),
795                #[cfg(feature = "wgpu_backend")]
796                queue: Arc::clone(&self.queue),
797                #[cfg(not(feature = "wgpu_backend"))]
798                device: Arc::clone(&self.device),
799                #[cfg(not(feature = "wgpu_backend"))]
800                queue: Arc::clone(&self.queue),
801            }),
802        })
803    }
804
805    fn as_any(&self) -> &dyn std::any::Any {
806        self
807    }
808}
809
/// WebGPU shader wrapper (augmented with basic reflection info).
///
/// Instances live in the shared `compiled_shaders` registry
/// (`Arc<Mutex<HashMap<String, WebGPUShader>>>`) and are looked up by name when a
/// `WebGPUKernelHandle` dispatches.
struct WebGPUShader {
    /// Compute pipeline set on the compute pass at dispatch time.
    #[cfg(feature = "wgpu_backend")]
    pipeline: ComputePipeline,
    /// Placeholder pipeline handle used when the wgpu backend is disabled.
    #[cfg(not(feature = "wgpu_backend"))]
    pipeline: WgpuComputePipeline,
    /// Layout of bind group 0; used when building the per-dispatch bind group.
    #[cfg(feature = "wgpu_backend")]
    #[allow(dead_code)]
    bind_group_layout: BindGroupLayout,
    /// Placeholder layout handle used when the wgpu backend is disabled.
    #[cfg(not(feature = "wgpu_backend"))]
    #[allow(dead_code)]
    bind_group_layout: *mut std::ffi::c_void,
    /// Shader name; `WebGPUCompiler::compile` gives it to the kernel handle as
    /// the registry lookup key.
    #[allow(dead_code)]
    name: String,
    /// Basic reflection info, one record per binding (names may be synthetic
    /// when the parser can't extract them).
    #[allow(dead_code)]
    binding_infos: Vec<BindingInfo>,
    /// Workgroup dimensions; presumably parsed from the `@workgroup_size`
    /// attribute — TODO confirm at the construction site.
    #[allow(dead_code)]
    workgroup_size: [u32; 3],
}

// WebGPU shader handles are safe to send between threads when properly synchronized.
// Shared access in this file goes through `Arc<Mutex<HashMap<String, WebGPUShader>>>`.
// NOTE(review): without `wgpu_backend` the fields are raw pointers, which is what
// makes these manual impls necessary; confirm those pointers are never dereferenced.
unsafe impl Send for WebGPUShader {}
unsafe impl Sync for WebGPUShader {}
833
/// WebGPU compiler implementation.
///
/// Compiles WGSL through `context`, which `WebGPUContext::create_compiler` builds
/// to share the parent context's memory pool, shader registry, device and queue.
struct WebGPUCompiler {
    context: Arc<WebGPUContext>,
}
838
839impl GpuCompilerImpl for WebGPUCompiler {
840    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
841        let shader = self.context.compile_shader_internal(source, "shader")?;
842        Ok(Arc::new(WebGPUKernelHandle {
843            shader_name: shader.name.clone(),
844            compiled_shaders: Arc::clone(&self.context.compiled_shaders),
845            params: Arc::new(Mutex::new(HashMap::new())),
846            #[cfg(feature = "wgpu_backend")]
847            device: Arc::clone(&self.context.device),
848            #[cfg(feature = "wgpu_backend")]
849            queue: Arc::clone(&self.context.queue),
850            #[cfg(feature = "wgpu_backend")]
851            ephemeral_uniforms: Mutex::new(Vec::new()),
852            #[cfg(not(feature = "wgpu_backend"))]
853            device: self.context.device,
854            #[cfg(not(feature = "wgpu_backend"))]
855            queue: self.context.queue,
856        }))
857    }
858
859    fn compile_typed(
860        &self,
861        name: &str,
862        _input_type: std::any::TypeId,
863        _output_type: std::any::TypeId,
864    ) -> Arc<dyn GpuKernelImpl> {
865        Arc::new(WebGPUKernelHandle {
866            shader_name: name.to_string(),
867            compiled_shaders: Arc::clone(&self.context.compiled_shaders),
868            params: Arc::new(Mutex::new(HashMap::new())),
869            #[cfg(feature = "wgpu_backend")]
870            device: Arc::clone(&self.context.device),
871            #[cfg(feature = "wgpu_backend")]
872            queue: Arc::clone(&self.context.queue),
873            #[cfg(feature = "wgpu_backend")]
874            ephemeral_uniforms: Mutex::new(Vec::new()),
875            #[cfg(not(feature = "wgpu_backend"))]
876            device: self.context.device,
877            #[cfg(not(feature = "wgpu_backend"))]
878            queue: self.context.queue,
879        })
880    }
881}
882
/// WebGPU kernel handle for execution.
///
/// Refers to a compiled shader by name and accumulates named parameters (via the
/// `set_*` methods) that are turned into bind group 0 when `dispatch` runs.
struct WebGPUKernelHandle {
    /// Key into `compiled_shaders` selecting the shader to dispatch.
    shader_name: String,
    /// Shared shader registry (the same map as the creating context's).
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    /// Parameters recorded by name through the `set_*` methods.
    params: Arc<Mutex<HashMap<String, KernelParam>>>,
    /// Device used for command encoders, uniform buffers and bind groups.
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    /// Queue that dispatches are submitted to.
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    /// Uniform buffers from the most recent bind-group build; cleared and
    /// refilled on every build.
    #[cfg(feature = "wgpu_backend")]
    ephemeral_uniforms: Mutex<Vec<wgpu::Buffer>>,
    /// Placeholder device handle when the wgpu backend is disabled.
    #[cfg(not(feature = "wgpu_backend"))]
    device: WgpuDevice,
    /// Placeholder queue handle when the wgpu backend is disabled.
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
}
899
/// A kernel argument recorded by name on a `WebGPUKernelHandle`.
///
/// `Buffer` values are bound to storage bindings; scalar and `Bytes` values are
/// serialized little-endian into uniform data when the bind group is built.
enum KernelParam {
    /// Buffer for a storage binding (must be a device-backed `WebGPUBuffer` to bind).
    #[allow(dead_code)]
    Buffer(Arc<dyn GpuBufferImpl>),
    /// 32-bit unsigned uniform value.
    #[allow(dead_code)]
    U32(u32),
    /// 32-bit signed uniform value.
    #[allow(dead_code)]
    I32(i32),
    /// 32-bit float uniform value.
    #[allow(dead_code)]
    F32(f32),
    /// 64-bit float uniform value (written as 8 little-endian bytes).
    #[allow(dead_code)]
    F64(f64),
    /// Raw bytes appended verbatim to the uniform data.
    Bytes(Vec<u8>),
}
913
/// Declared kind of a shader binding; decides how the bind-group builder fills it.
#[derive(Clone, Debug)]
enum BindingKind {
    /// Read-write storage buffer (presumably `var<storage, read_write>`).
    StorageRw,
    /// Read-only storage buffer (presumably `var<storage, read>`).
    StorageRead,
    /// Uniform buffer, filled from scalar/byte kernel params.
    Uniform,
}
920
/// Reflection record for one binding of bind group 0.
#[derive(Clone, Debug)]
struct BindingInfo {
    /// Binding index used for the `BindGroupEntry`.
    binding: u32,
    /// Variable name, used as the kernel-parameter key (may be synthetic).
    name: String,
    /// How the binding is declared.
    kind: BindingKind,
}
927
/// Extract the `@workgroup_size(x [, y [, z]])` values from WGSL source.
///
/// Scans `source` line by line, ignoring `//` line comments, and parses the first
/// `@workgroup_size(...)` attribute whose closing parenthesis is on the same line.
/// Each dimension may be a decimal or `0x` hexadecimal integer literal with an
/// optional WGSL `u`/`i` suffix, and one trailing comma is accepted. Missing `y`/`z`
/// dimensions default to 1; only the first three dimensions are used.
///
/// Returns `[64, 1, 1]` as a sensible default if the attribute is not present, or if
/// any dimension is not an integer literal (e.g. a named `const` or `override`).
/// Skipping such an entry instead would silently shift the remaining dimensions.
fn extract_workgroup_size(source: &str) -> [u32; 3] {
    const DEFAULT: [u32; 3] = [64, 1, 1];
    const ATTR: &str = "@workgroup_size(";

    /// Parses a WGSL integer literal such as `8`, `256u`, `4i` or `0x10`.
    fn parse_dim(token: &str) -> Option<u32> {
        let token = token.trim();
        let digits = token
            .strip_suffix(|c: char| c == 'u' || c == 'i')
            .unwrap_or(token);
        match digits.strip_prefix("0x").or_else(|| digits.strip_prefix("0X")) {
            Some(hex) => u32::from_str_radix(hex, 16).ok(),
            None => digits.parse::<u32>().ok(),
        }
    }

    for line in source.lines() {
        // Drop any line-comment tail so commented-out attributes are not matched.
        let code = line.split("//").next().unwrap_or(line);
        if let Some(start) = code.find(ATTR) {
            let after = &code[start + ATTR.len()..];
            if let Some(end) = after.find(')') {
                let inner = after[..end].trim();
                // WGSL permits a trailing comma in attribute argument lists.
                let inner = inner.strip_suffix(',').unwrap_or(inner);
                // `None` if any dimension fails to parse.
                let dims: Option<Vec<u32>> = inner.split(',').map(parse_dim).collect();
                return match dims.as_deref() {
                    Some([x]) => [*x, 1, 1],
                    Some([x, y]) => [*x, *y, 1],
                    Some([x, y, z, ..]) => [*x, *y, *z],
                    _ => DEFAULT,
                };
            }
        }
    }
    DEFAULT
}
952
/// Returns the variable name declared by a WGSL `var<...> name: Type` line,
/// i.e. the trimmed text between the closing `>` of `var<` and the next `:`.
/// Returns `None` if the line has no `var<`, no `>`, no `:`, or an empty name.
fn extract_var_name(line: &str) -> Option<&str> {
    let decl = &line[line.find("var<")?..];
    let close = decl.find('>')?;
    let rest = decl[close + 1..].trim_start();
    let colon = rest.find(':')?;
    let name = rest[..colon].trim();
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}
969
970impl GpuKernelImpl for WebGPUKernelHandle {
971    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
972        if let Ok(mut params) = self.params.lock() {
973            params.insert(name.to_string(), KernelParam::Buffer(Arc::clone(buffer)));
974        }
975    }
976
977    fn set_u32(&self, name: &str, value: u32) {
978        if let Ok(mut params) = self.params.lock() {
979            params.insert(name.to_string(), KernelParam::U32(value));
980        }
981    }
982
983    fn set_i32(&self, name: &str, value: i32) {
984        if let Ok(mut params) = self.params.lock() {
985            params.insert(name.to_string(), KernelParam::I32(value));
986        }
987    }
988
989    fn set_f32(&self, name: &str, value: f32) {
990        if let Ok(mut params) = self.params.lock() {
991            params.insert(name.to_string(), KernelParam::F32(value));
992        }
993    }
994
995    fn set_f64(&self, name: &str, value: f64) {
996        if let Ok(mut params) = self.params.lock() {
997            params.insert(name.to_string(), KernelParam::F64(value));
998        }
999    }
1000
1001    #[allow(dead_code)]
1002    // raw bytes helper removed from trait; use internal helper if needed
1003
1004    fn dispatch(&self, workgroups: [u32; 3]) {
1005        #[cfg(feature = "wgpu_backend")]
1006        {
1007            // Real WebGPU compute dispatch
1008            let shaders = match self.compiled_shaders.lock() {
1009                Ok(g) => g,
1010                Err(_) => return,
1011            };
1012            if let Some(shader) = shaders.get(&self.shader_name) {
1013                let params = match self.params.lock() {
1014                    Ok(g) => g,
1015                    Err(_) => return,
1016                };
1017
1018                // Create command encoder
1019                let mut encoder =
1020                    self.device
1021                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1022                            label: Some("Compute Command Encoder"),
1023                        });
1024
1025                // Begin compute pass
1026                {
1027                    let mut compute_pass =
1028                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1029                            label: Some("Compute Pass"),
1030                            timestamp_writes: None,
1031                        });
1032
1033                    // Set the compute pipeline
1034                    compute_pass.set_pipeline(&shader.pipeline);
1035
1036                    if let Ok(bind_group) = self.create_bind_group_from_params(shader, &params) {
1037                        compute_pass.set_bind_group(0, &bind_group, &[]);
1038                    }
1039                    // else: bind group creation failed; dispatch will proceed but produce undefined results
1040
1041                    // Dispatch the compute shader
1042                    compute_pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
1043                }
1044
1045                // Submit the command buffer
1046                let command_buffer = encoder.finish();
1047                self.queue.submit(std::iter::once(command_buffer));
1048            }
1049        }
1050        #[cfg(not(feature = "wgpu_backend"))]
1051        {
1052            // Fallback: no GPU available; dispatch is a no-op
1053            let _ = workgroups;
1054            let _ = &self.shader_name;
1055        }
1056    }
1057}
1058
/// WebGPU buffer implementation.
///
/// Wraps a device buffer together with the queue (and, with wgpu, the device)
/// needed for host transfers. On drop, the device buffer is handed back to
/// `memory_pool` for reuse.
struct WebGPUBuffer {
    /// Device-side storage; `Some` when created by `create_buffer`, taken in `Drop`.
    #[cfg(feature = "wgpu_backend")]
    device_buffer: Option<Buffer>,
    /// Queue used for `write_buffer` uploads and readback copies.
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    /// Device used to create readback staging buffers and to poll for completion.
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    /// Placeholder buffer handle when the wgpu backend is disabled.
    #[cfg(not(feature = "wgpu_backend"))]
    device_buffer: Option<WgpuBuffer>,
    /// Placeholder queue handle when the wgpu backend is disabled.
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
    /// Logical size in bytes; host transfers larger than this are rejected.
    size: usize,
    /// Pool that receives the device buffer on drop.
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

// WebGPU buffer handles are safe to send between threads when properly synchronized
// The real wgpu types (Buffer, Queue) are Send + Sync
// For fallback types (raw pointers), we assume proper synchronization is handled externally
// NOTE(review): wgpu types are only Send + Sync on native targets (and on wasm with
// the appropriate wgpu feature) — confirm for every supported target.
unsafe impl Send for WebGPUBuffer {}
unsafe impl Sync for WebGPUBuffer {}
1080
1081impl GpuBufferImpl for WebGPUBuffer {
1082    fn size(&self) -> usize {
1083        self.size
1084    }
1085
1086    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1087        #[cfg(feature = "wgpu_backend")]
1088        {
1089            // Validate data size
1090            if size > self.size {
1091                // In unsafe context, we can't return an error, so we'll just log and return
1092                eprintln!(
1093                    "Warning: Data size {} exceeds buffer size {}",
1094                    size, self.size
1095                );
1096                return;
1097            }
1098
1099            // Convert raw pointer to slice for WebGPU API
1100            let data_slice = std::slice::from_raw_parts(data, size);
1101
1102            // Real WebGPU implementation - write data to buffer
1103            if let Some(ref buffer) = self.device_buffer {
1104                self.queue.write_buffer(buffer, 0, data_slice);
1105            }
1106        }
1107        #[cfg(not(feature = "wgpu_backend"))]
1108        {
1109            // Fallback implementation - just validate
1110            if size > self.size {
1111                eprintln!(
1112                    "Warning: Data size {} exceeds buffer size {}",
1113                    size, self.size
1114                );
1115            }
1116            // In fallback mode, we just simulate the operation
1117        }
1118    }
1119
1120    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1121        #[cfg(feature = "wgpu_backend")]
1122        {
1123            // Validate data size
1124            if size > self.size {
1125                eprintln!(
1126                    "Warning: Data size {} exceeds buffer size {}",
1127                    size, self.size
1128                );
1129                return;
1130            }
1131
1132            if let Some(ref buffer) = self.device_buffer {
1133                let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
1134                    label: Some("scirs2-readback"),
1135                    size: size as u64,
1136                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1137                    mapped_at_creation: false,
1138                });
1139                let mut encoder =
1140                    self.device
1141                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1142                            label: Some("scirs2-readback-enc"),
1143                        });
1144                encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
1145                self.queue.submit(Some(encoder.finish()));
1146
1147                // Poll the device until all submitted work completes before mapping.
1148                // This is required on all native wgpu backends (Vulkan, Metal, DX12)
1149                // to ensure the copy completes before the slice can be mapped.
1150                let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
1151
1152                let slice = staging.slice(0..size as u64);
1153                let (tx, rx) = std::sync::mpsc::channel();
1154                slice.map_async(wgpu::MapMode::Read, move |r| {
1155                    let _ = tx.send(r);
1156                });
1157                // Drive map callback to completion
1158                let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
1159                if let Ok(Ok(())) = rx.recv() {
1160                    let mapped = slice.get_mapped_range();
1161                    let dst = std::slice::from_raw_parts_mut(data, size);
1162                    dst.copy_from_slice(&mapped);
1163                    drop(mapped);
1164                    staging.unmap();
1165                } else {
1166                    eprintln!("Warning: map_async failed for readback");
1167                }
1168            }
1169        }
1170        #[cfg(not(feature = "wgpu_backend"))]
1171        {
1172            // Fallback implementation - just validate and zero out
1173            if size > self.size {
1174                eprintln!(
1175                    "Warning: Data size {} exceeds buffer size {}",
1176                    size, self.size
1177                );
1178            }
1179
1180            // Zero out the data as a placeholder
1181            let data_slice = std::slice::from_raw_parts_mut(data, size);
1182            data_slice.fill(0);
1183        }
1184    }
1185
1186    fn device_ptr(&self) -> u64 {
1187        #[cfg(feature = "wgpu_backend")]
1188        {
1189            // WebGPU doesn't expose raw device pointers, so we return a placeholder
1190            // In a real implementation, this might return a handle or ID
1191            &self.device_buffer as *const _ as u64
1192        }
1193        #[cfg(not(feature = "wgpu_backend"))]
1194        {
1195            self.device_buffer as u64
1196        }
1197    }
1198
1199    fn as_any(&self) -> &dyn std::any::Any {
1200        self
1201    }
1202}
1203
#[cfg(feature = "wgpu_backend")]
impl WebGPUKernelHandle {
    /// Builds bind group 0 for `shader` from the parameters set on this handle.
    ///
    /// - Storage bindings (`StorageRw` / `StorageRead`) are bound to the
    ///   device-backed `WebGPUBuffer` registered under the binding's name.
    /// - Each uniform binding gets its own freshly created uniform buffer. The
    ///   buffer is filled by [`Self::pack_uniform_bytes`] and zero-padded to a
    ///   multiple of 16 bytes.
    ///
    /// This call's uniform buffers replace those stored in `ephemeral_uniforms`
    /// by the previous call.
    ///
    /// # Errors
    /// Returns `GpuError::InvalidParameter` if any of the following holds:
    /// - a storage binding has no buffer param;
    /// - that buffer is not a device-backed `WebGPUBuffer` (e.g. a CPU fallback buffer);
    /// - a uniform binding has no scalar/byte params.
    ///
    /// Each of these would otherwise yield a bind group with a missing entry.
    fn create_bind_group_from_params(
        &self,
        shader: &WebGPUShader,
        params: &HashMap<String, KernelParam>,
    ) -> Result<wgpu::BindGroup, GpuError> {
        // Pass 1: one uniform buffer per uniform binding, in `binding_infos` order.
        // This must complete before any entry borrows from `uniform_buffers`.
        let mut uniform_buffers: Vec<wgpu::Buffer> = Vec::new();
        for info in shader
            .binding_infos
            .iter()
            .filter(|b| matches!(b.kind, BindingKind::Uniform))
        {
            let mut bytes = Self::pack_uniform_bytes(&info.name, params);
            if bytes.is_empty() {
                return Err(GpuError::InvalidParameter(format!(
                    "Missing uniform param '{}'",
                    info.name
                )));
            }
            // Pad to a 16-byte multiple, matching WGSL uniform struct size rounding.
            while bytes.len() % 16 != 0 {
                bytes.push(0);
            }
            uniform_buffers.push(self.device.create_buffer_init(
                &wgpu::util::BufferInitDescriptor {
                    label: Some("scirs2-uniforms"),
                    contents: &bytes,
                    usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                },
            ));
        }

        // Pass 2: one entry per binding. Uniform bindings consume the buffers from
        // pass 1 in the same order they were created.
        let mut next_uniform = uniform_buffers.iter();
        let mut entries: Vec<wgpu::BindGroupEntry> =
            Vec::with_capacity(shader.binding_infos.len());
        for info in &shader.binding_infos {
            let buffer: &wgpu::Buffer = match info.kind {
                BindingKind::StorageRw | BindingKind::StorageRead => {
                    Self::storage_buffer(info, params)?
                }
                BindingKind::Uniform => next_uniform.next().ok_or_else(|| {
                    GpuError::InvalidParameter(format!(
                        "Missing uniform param '{}'",
                        info.name
                    ))
                })?,
            };
            entries.push(wgpu::BindGroupEntry {
                binding: info.binding,
                resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                    buffer,
                    offset: 0,
                    size: None,
                }),
            });
        }

        // Keep this build's uniform buffers reachable from the handle, replacing
        // the previous build's (cleared even when there are none).
        if let Ok(mut list) = self.ephemeral_uniforms.lock() {
            list.clear();
            list.extend(uniform_buffers.iter().cloned());
        }

        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("scirs2-bind-group"),
            layout: &shader.bind_group_layout,
            entries: &entries,
        });
        Ok(bind_group)
    }

    /// Resolves the device buffer for storage binding `info` from `params`.
    fn storage_buffer<'p>(
        info: &BindingInfo,
        params: &'p HashMap<String, KernelParam>,
    ) -> Result<&'p wgpu::Buffer, GpuError> {
        let buf = match params.get(&info.name) {
            Some(KernelParam::Buffer(buf)) => buf,
            _ => {
                return Err(GpuError::InvalidParameter(format!(
                    "Missing buffer param '{}'",
                    info.name
                )))
            }
        };
        buf.as_any()
            .downcast_ref::<WebGPUBuffer>()
            .and_then(|w| w.device_buffer.as_ref())
            .ok_or_else(|| {
                GpuError::InvalidParameter(format!(
                    "Buffer param '{}' is not backed by a WebGPU device buffer",
                    info.name
                ))
            })
    }

    /// Concatenates the little-endian bytes of every scalar/byte param whose key is
    /// `name` or starts with `name.`, in lexicographic key order (so `name` itself
    /// comes first).
    ///
    /// Sorting makes the layout deterministic; `HashMap` iteration order is not.
    /// Name fields so their lexical order matches the WGSL struct field order, or
    /// pass a single pre-packed `Bytes` param. No per-field alignment padding is
    /// inserted, and `Buffer` params are ignored.
    fn pack_uniform_bytes(name: &str, params: &HashMap<String, KernelParam>) -> Vec<u8> {
        let prefix = format!("{}.", name);
        let mut fields: Vec<(&String, &KernelParam)> = params
            .iter()
            .filter(|(k, _)| k.as_str() == name || k.starts_with(&prefix))
            .collect();
        fields.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let mut bytes = Vec::new();
        for (_, param) in fields {
            match param {
                KernelParam::U32(v) => bytes.extend_from_slice(&v.to_le_bytes()),
                KernelParam::I32(v) => bytes.extend_from_slice(&v.to_le_bytes()),
                KernelParam::F32(v) => bytes.extend_from_slice(&v.to_le_bytes()),
                KernelParam::F64(v) => bytes.extend_from_slice(&v.to_le_bytes()),
                KernelParam::Bytes(b) => bytes.extend_from_slice(b),
                KernelParam::Buffer(_) => {}
            }
        }
        bytes
    }
}
1306
1307impl Drop for WebGPUBuffer {
1308    fn drop(&mut self) {
1309        // Return buffer to memory pool if possible
1310        if let Ok(mut pool) = self.memory_pool.lock() {
1311            #[cfg(feature = "wgpu_backend")]
1312            {
1313                // In real implementation, would return buffer to pool
1314                if let Some(buffer) = self.device_buffer.take() {
1315                    pool.deallocate(buffer);
1316                }
1317            }
1318            #[cfg(not(feature = "wgpu_backend"))]
1319            {
1320                if let Some(buffer) = self.device_buffer.take() {
1321                    pool.deallocate(buffer);
1322                }
1323            }
1324        }
1325    }
1326}
1327
/// CPU fallback buffer for when WebGPU buffer allocation fails
/// This provides a graceful degradation when GPU memory is exhausted
///
/// NOTE(review): `copy_from_host` receives `&self` and `data` has no interior
/// mutability, so uploads into this buffer are not stored. Reads return whatever
/// `data` was initialized with.
struct WebGPUCpuFallbackBuffer {
    /// Host-side backing storage.
    data: Vec<u8>,
    /// Logical size in bytes reported by `size()`.
    size: usize,
    /// Pool of the creating context; held but not otherwise used here.
    #[allow(dead_code)]
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
1336
1337impl GpuBufferImpl for WebGPUCpuFallbackBuffer {
1338    fn size(&self) -> usize {
1339        self.size
1340    }
1341
1342    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1343        if size > self.size {
1344            eprintln!("Warning: WebGPU CPU fallback buffer copy_from_host size mismatch");
1345            return;
1346        }
1347
1348        // Since this is a CPU fallback, we can use safe Rust internally
1349        let data_slice = std::slice::from_raw_parts(data, size);
1350        // We can't mutate self.data directly since &self is immutable
1351        // In a real implementation, this would require interior mutability
1352        eprintln!(
1353            "Warning: CPU fallback buffer copy_from_host called (size: {})",
1354            size
1355        );
1356    }
1357
1358    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1359        if size > self.size {
1360            eprintln!("Warning: WebGPU CPU fallback buffer copy_to_host size mismatch");
1361            return;
1362        }
1363
1364        // Copy from CPU buffer to host
1365        let data_slice = std::slice::from_raw_parts_mut(data, size);
1366        let copy_size = size.min(self.data.len());
1367        data_slice[..copy_size].copy_from_slice(&self.data[..copy_size]);
1368
1369        eprintln!(
1370            "Warning: CPU fallback buffer copy_to_host called (size: {})",
1371            size
1372        );
1373    }
1374
1375    fn device_ptr(&self) -> u64 {
1376        self.data.as_ptr() as u64
1377    }
1378
1379    fn as_any(&self) -> &dyn std::any::Any {
1380        self
1381    }
1382}
1383
1384// Safety: WebGPUCpuFallbackBuffer is thread-safe since it only contains owned data
1385unsafe impl Send for WebGPUCpuFallbackBuffer {}
1386unsafe impl Sync for WebGPUCpuFallbackBuffer {}
1387
1388/// WebGPU memory pool for efficient buffer management
1389struct WebGPUMemoryPool {
1390    #[cfg(feature = "wgpu_backend")]
1391    available_buffers: HashMap<usize, Vec<Buffer>>,
1392    #[cfg(not(feature = "wgpu_backend"))]
1393    available_buffers: HashMap<usize, Vec<WgpuBuffer>>,
1394    #[allow(dead_code)]
1395    total_size: usize,
1396    used_size: usize,
1397}
1398
1399impl WebGPUMemoryPool {
1400    fn new(totalsize: usize) -> Self {
1401        Self {
1402            available_buffers: HashMap::new(),
1403            total_size: totalsize,
1404            used_size: 0,
1405        }
1406    }
1407
1408    #[cfg(feature = "wgpu_backend")]
1409    fn allocate(&mut self, size: usize) -> Option<Buffer> {
1410        // Try to find a suitable buffer in the pool
1411        if let Some(buffers) = self.available_buffers.get_mut(&size) {
1412            if let Some(buffer) = buffers.pop() {
1413                self.used_size += size;
1414                return Some(buffer);
1415            }
1416        }
1417        None
1418    }
1419
1420    #[cfg(not(feature = "wgpu_backend"))]
1421    fn allocate(&mut self, size: usize) -> Option<WgpuBuffer> {
1422        // Try to find a suitable buffer in the pool
1423        if let Some(buffers) = self.available_buffers.get_mut(&size) {
1424            if let Some(buffer) = buffers.pop() {
1425                self.used_size += size;
1426                return Some(buffer);
1427            }
1428        }
1429        None
1430    }
1431
1432    #[cfg(feature = "wgpu_backend")]
1433    fn deallocate(&mut self, buffer: Buffer) {
1434        // Return buffer to pool
1435        let size = buffer.size() as usize;
1436        self.available_buffers
1437            .entry(size)
1438            .or_insert_with(Vec::new)
1439            .push(buffer);
1440        self.used_size = self.used_size.saturating_sub(size);
1441    }
1442
1443    #[cfg(not(feature = "wgpu_backend"))]
1444    fn deallocate(&mut self, buffer: WgpuBuffer) {
1445        // Fallback implementation - track the buffer
1446        let size = 1024; // Placeholder size
1447        self.available_buffers
1448            .entry(size)
1449            .or_insert_with(Vec::new)
1450            .push(buffer);
1451        self.used_size = self.used_size.saturating_sub(size);
1452    }
1453
1454    #[allow(dead_code)]
1455    fn get_memory_usage(&self) -> (usize, usize) {
1456        (self.used_size, self.total_size)
1457    }
1458}