1use std::collections::HashMap;
6#[cfg(feature = "wgpu_backend")]
7use std::sync::{Arc, Mutex};
9
10use crate::gpu::{GpuBufferImpl, GpuCompilerImpl, GpuContextImpl, GpuError, GpuKernelImpl};
11
12#[cfg(feature = "wgpu_backend")]
13#[allow(unused_imports)]
14use wgpu::{
15 util::DeviceExt, Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
16 BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, Buffer,
17 BufferBindingType, BufferDescriptor, BufferUsages, ComputePipeline, Device, DeviceDescriptor,
18 Features, Instance, InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
19 ShaderModuleDescriptor, ShaderSource, ShaderStages, StorageTextureAccess, TextureFormat,
20 TextureSampleType, TextureViewDimension,
21};
22
23#[cfg(not(feature = "wgpu_backend"))]
25type WgpuDevice = *mut std::ffi::c_void;
26#[cfg(not(feature = "wgpu_backend"))]
27type WgpuQueue = *mut std::ffi::c_void;
28#[cfg(not(feature = "wgpu_backend"))]
29type WgpuBuffer = *mut std::ffi::c_void;
30#[cfg(not(feature = "wgpu_backend"))]
31type WgpuComputePipeline = *mut std::ffi::c_void;
32
#[allow(dead_code)]
/// WGSL compute shader implementing one Adam optimizer step over flat f32
/// arrays. Bias corrections are precomputed on the host and passed in the
/// uniform block. (A stray Rust `#[allow(dead_code)]` attribute previously
/// embedded before `fn adam_update` was invalid WGSL and broke parsing.)
const ADAM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> params: array<f32>;
@group(0) @binding(1) var<storage, read> grads: array<f32>;
@group(0) @binding(2) var<storage, read_write> m: array<f32>;
@group(0) @binding(3) var<storage, read_write> v: array<f32>;

struct AdamUniforms {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    bias_correction1: f32,
    bias_correction2: f32,
    n: u32,
};

@group(0) @binding(4) var<uniform> uniforms: AdamUniforms;

@compute @workgroup_size(64)
fn adam_update(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;

    if (idx >= uniforms.n) {
        return;
    }

    var grad = grads[idx];

    // Apply weight decay
    if (uniforms.weight_decay > 0.0) {
        grad += uniforms.weight_decay * params[idx];
    }

    // Update biased first moment estimate
    m[idx] = uniforms.beta1 * m[idx] + (1.0 - uniforms.beta1) * grad;

    // Update biased second raw moment estimate
    v[idx] = uniforms.beta2 * v[idx] + (1.0 - uniforms.beta2) * grad * grad;

    // Compute bias-corrected moment estimates
    let m_hat = m[idx] / uniforms.bias_correction1;
    let v_hat = v[idx] / uniforms.bias_correction2;

    // Update parameters
    params[idx] -= uniforms.lr * m_hat / (sqrt(v_hat) + uniforms.eps);
}
"#;
84
#[allow(dead_code)]
/// WGSL compute shader for C = alpha * A * B + beta * C (row-major), one
/// invocation per output element in an 8x8 workgroup. (A stray Rust
/// `#[allow(dead_code)]` attribute previously embedded before `fn gemm` was
/// invalid WGSL and broke parsing.)
const GEMM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> matrix_c: array<f32>;

struct GemmUniforms {
    M: u32,
    N: u32,
    K: u32,
    alpha: f32,
    beta: f32,
};

@group(0) @binding(3) var<uniform> uniforms: GemmUniforms;

@compute @workgroup_size(8, 8)
fn gemm(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row = global_id.x;
    let col = global_id.y;

    if (row >= uniforms.M || col >= uniforms.N) {
        return;
    }

    var sum = 0.0;
    for (var k = 0u; k < uniforms.K; k++) {
        sum += matrix_a[row * uniforms.K + k] * matrix_b[k * uniforms.N + col];
    }

    let idx = row * uniforms.N + col;
    matrix_c[idx] = uniforms.alpha * sum + uniforms.beta * matrix_c[idx];
}
"#;
120
/// WebGPU-backed GPU context: owns the device/queue handles plus a shader
/// cache and a size-bucketed buffer pool shared with every object it creates.
pub struct WebGPUContext {
    // Real wgpu handles when the backend is compiled in.
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    // Raw-pointer sentinel handles for the stub build; never dereferenced.
    #[cfg(not(feature = "wgpu_backend"))]
    device: Arc<WgpuDevice>,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: Arc<WgpuQueue>,
    // Shader cache keyed by name; written by the compiler, read by dispatch().
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    // Pool of previously freed buffers, reused by exact size match.
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
134
// SAFETY: with the wgpu backend all fields are Send + Sync wgpu/std types.
// In the stub build the device/queue are inert sentinel pointers that are
// never dereferenced, so sharing them across threads cannot cause UB here.
unsafe impl Send for WebGPUContext {}
unsafe impl Sync for WebGPUContext {}
138
139impl WebGPUContext {
140 pub fn new() -> Result<Self, GpuError> {
142 #[cfg(feature = "wgpu_backend")]
143 {
144 let instance_desc = InstanceDescriptor {
146 backends: Backends::all(),
147 ..Default::default()
148 };
149 let instance = Instance::new(&instance_desc);
150
151 let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
152 power_preference: PowerPreference::HighPerformance,
153 compatible_surface: None,
154 force_fallback_adapter: false,
155 }))
156 .map_err(|e| GpuError::Other(format!("Failed to find WebGPU adapter: {e}")))?;
157
158 let device_descriptor = DeviceDescriptor {
159 label: Some("SciRS2 WebGPU Device"),
160 required_features: Features::empty(),
161 required_limits: Limits::default(),
162 ..Default::default()
164 };
165
166 let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor))
167 .map_err(|e| GpuError::Other(format!("{e}")))?;
168
169 Ok(Self {
170 device: Arc::new(device),
171 queue: Arc::new(queue),
172 compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
173 memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), })
175 }
176 #[cfg(not(feature = "wgpu_backend"))]
177 {
178 let device = Self::initialize_webgpu()?;
180 let queue = Self::create_queue(device)?;
181
182 Ok(Self {
183 device,
184 queue,
185 compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
186 memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), })
188 }
189 }
190
    /// Probe whether a WebGPU adapter can be acquired on this machine.
    ///
    /// Always `false` in the stub build. Note this creates a throwaway
    /// `Instance` and blocks on the adapter request.
    pub fn is_available() -> bool {
        #[cfg(feature = "wgpu_backend")]
        {
            let instance_desc = InstanceDescriptor {
                backends: Backends::all(),
                ..Default::default()
            };
            let instance = Instance::new(&instance_desc);

            // Any adapter (default power preference) counts as available.
            pollster::block_on(async {
                instance
                    .request_adapter(&RequestAdapterOptions {
                        power_preference: PowerPreference::default(),
                        compatible_surface: None,
                        force_fallback_adapter: false,
                    })
                    .await
                    .is_ok()
            })
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            false
        }
    }
220
    /// Compile WGSL `source` into a named compute pipeline.
    ///
    /// The bind group layout is reconstructed by textually scanning the WGSL
    /// (see `create_bind_group_layout_from_source`), and the entry point is
    /// taken from the first `@compute` function, defaulting to "main".
    fn compile_shader_internal(&self, source: &str, name: &str) -> Result<WebGPUShader, GpuError> {
        #[cfg(feature = "wgpu_backend")]
        {
            let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
                label: Some(name),
                source: ShaderSource::Wgsl(source.into()),
            });

            let entry_point = Self::extract_entry_point(source).unwrap_or("main");

            let (bind_group_layout, binding_infos) =
                self.create_bind_group_layout_from_source(source, name)?;

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{}_layout", name)),
                        bind_group_layouts: &[&bind_group_layout],
                        ..Default::default()
                    });

            let compute_pipeline =
                self.device
                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some(&format!("{}_pipeline", name)),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some(entry_point),
                        compilation_options: Default::default(),
                        cache: None,
                    });

            Ok(WebGPUShader {
                pipeline: compute_pipeline,
                bind_group_layout,
                name: name.to_string(),
                binding_infos,
            })
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Stub: a sentinel pipeline with no reflected bindings.
            let pipeline = Self::compile_wgsl_source(source, name)?;

            Ok(WebGPUShader {
                pipeline,
                bind_group_layout: std::ptr::null_mut(),
                name: name.to_string(),
                binding_infos: Vec::new(),
            })
        }
    }
279
    #[cfg(feature = "wgpu_backend")]
    /// Reconstruct a `BindGroupLayout` (group 0 only) by textually scanning
    /// WGSL `source` for `@group(..) @binding(..) var<storage|uniform>`
    /// declarations.
    ///
    /// This is a line-based heuristic, not a real WGSL parser: attributes and
    /// the `var<...>` declaration are expected on the same line (which holds
    /// for the shaders in this file). Returns the layout plus per-binding
    /// metadata used later to match kernel parameters by variable name.
    fn create_bind_group_layout_from_source(
        &self,
        source: &str,
        name: &str,
    ) -> Result<(BindGroupLayout, Vec<BindingInfo>), GpuError> {
        // Attribute values seen so far on the current declaration line.
        #[derive(Default)]
        struct PendingAttr {
            group: Option<u32>,
            binding: Option<u32>,
        }
        let mut pending = PendingAttr::default();
        let mut entries: Vec<BindGroupLayoutEntry> = Vec::new();
        let mut infos: Vec<BindingInfo> = Vec::new();

        // Drop a trailing `// ...` comment so it cannot confuse the scan.
        fn strip_comment(line: &str) -> &str {
            line.split_once("//").map(|(a, _)| a).unwrap_or(line)
        }

        for raw_line in source.lines() {
            let line = strip_comment(raw_line).trim();
            if line.is_empty() {
                continue;
            }

            // Parse `@group(N)` if present on this line.
            if let Some(i) = line.find("@group(") {
                if let Some(end) = line[i + 7..].find(')') {
                    if let Ok(g) = line[i + 7..i + 7 + end].parse::<u32>() {
                        pending.group = Some(g);
                    }
                }
            }
            // Parse `@binding(N)` if present on this line.
            if let Some(i) = line.find("@binding(") {
                if let Some(end) = line[i + 9..].find(')') {
                    if let Ok(b) = line[i + 9..i + 9 + end].parse::<u32>() {
                        pending.binding = Some(b);
                    }
                }
            }

            if line.contains("var<") {
                // Only group 0 is supported; other groups are ignored.
                if pending.group.unwrap_or(0) == 0 {
                    // Fall back to positional numbering when no @binding given.
                    let binding_num = pending.binding.unwrap_or_else(|| entries.len() as u32);
                    let name = extract_var_name(line).unwrap_or("");
                    let storage = line.contains("var<storage");
                    let uniform = line.contains("var<uniform");
                    // Detect the read-only access mode in its common spellings.
                    let read_only = storage
                        && (line.contains(", read>")
                            || line.contains("var<storage, read>")
                            || line.contains("var<storage, read,"));
                    if storage {
                        entries.push(BindGroupLayoutEntry {
                            binding: binding_num,
                            visibility: ShaderStages::COMPUTE,
                            ty: BindingType::Buffer {
                                ty: BufferBindingType::Storage { read_only },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        });
                        infos.push(BindingInfo {
                            binding: binding_num,
                            name: name.to_string(),
                            kind: if read_only {
                                BindingKind::StorageRead
                            } else {
                                BindingKind::StorageRw
                            },
                        });
                    } else if uniform {
                        entries.push(BindGroupLayoutEntry {
                            binding: binding_num,
                            visibility: ShaderStages::COMPUTE,
                            ty: BindingType::Buffer {
                                ty: BufferBindingType::Uniform,
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        });
                        infos.push(BindingInfo {
                            binding: binding_num,
                            name: name.to_string(),
                            kind: BindingKind::Uniform,
                        });
                    }
                }
                // A var declaration consumes the pending attributes.
                pending = PendingAttr::default();
            }
        }

        // A layout needs at least one entry; synthesize a RW storage slot 0.
        if entries.is_empty() {
            entries.push(BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            });
            infos.push(BindingInfo {
                binding: 0,
                name: "_unnamed".into(),
                kind: BindingKind::StorageRw,
            });
        }

        // Keep only the first declaration per binding index (wgpu rejects
        // duplicate binding numbers in one layout).
        let mut seen = std::collections::HashSet::new();
        let mut dedup_entries = Vec::new();
        let mut dedup_infos = Vec::new();
        for (e, info) in entries.into_iter().zip(infos) {
            if seen.insert(e.binding) {
                dedup_entries.push(e);
                dedup_infos.push(info);
            }
        }

        let bind_group_layout = self
            .device
            .create_bind_group_layout(&BindGroupLayoutDescriptor {
                label: Some(&format!("{}_bind_group_layout", name)),
                entries: &dedup_entries,
            });
        Ok((bind_group_layout, dedup_infos))
    }
412
413 #[cfg(feature = "wgpu_backend")]
415 pub fn allocate_device_memory(&self, size: usize) -> Result<Buffer, GpuError> {
416 let buffer = self.device.create_buffer(&BufferDescriptor {
417 label: Some("SciRS2 Buffer"),
418 size: size as u64,
419 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
420 mapped_at_creation: false,
421 });
422
423 Ok(buffer)
424 }
425
426 #[cfg(not(feature = "wgpu_backend"))]
428 pub fn allocate_device_memory_2(&self, size: usize) -> Result<WgpuBuffer, GpuError> {
429 Ok((0x1000 + size) as WgpuBuffer)
431 }
432
433 #[cfg(not(feature = "wgpu_backend"))]
435 fn initialize_webgpu() -> Result<WgpuDevice, GpuError> {
436 Ok(0x1 as WgpuDevice)
438 }
439
440 #[cfg(not(feature = "wgpu_backend"))]
441 fn create_queue(device: WgpuDevice) -> Result<WgpuQueue, GpuError> {
442 Ok(0x2 as WgpuQueue)
444 }
445
446 #[cfg(not(feature = "wgpu_backend"))]
447 fn compile_wgsl_source(source: &str, name: &str) -> Result<WgpuComputePipeline, GpuError> {
448 Ok(0x3 as WgpuComputePipeline)
450 }
451
452 fn extract_entry_point(source: &str) -> Option<&str> {
454 let lines: Vec<&str> = source.lines().collect();
455
456 for (i, line) in lines.iter().enumerate() {
457 let trimmed = line.trim();
458
459 if trimmed.contains("@compute") {
461 let mut search_line = trimmed;
463 let mut search_idx = 0;
464
465 if !search_line.contains("fn ") && search_idx + 1 < lines.len() {
467 search_idx += 1;
468 search_line = lines[search_idx].trim();
469 }
470
471 if let Some(start) = search_line.find("fn ") {
473 let remaining = &search_line[start + 3..];
474 if let Some(end) = remaining.find('(') {
475 let funcname = remaining[..end].trim();
476 return Some(funcname);
477 }
478 }
479 }
480 }
481
482 None
483 }
484}
485
486impl GpuContextImpl for WebGPUContext {
487 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
488 if let Ok(mut pool) = self.memory_pool.lock() {
490 if let Some(device_buffer) = pool.allocate(size) {
491 return Arc::new(WebGPUBuffer {
492 device_buffer: Some(device_buffer),
493 #[cfg(feature = "wgpu_backend")]
494 queue: Arc::clone(&self.queue),
495 #[cfg(feature = "wgpu_backend")]
496 device: Arc::clone(&self.device),
497 #[cfg(not(feature = "wgpu_backend"))]
498 queue: self.queue,
499 size,
500 memory_pool: Arc::clone(&self.memory_pool),
501 });
502 }
503 }
504
505 let device_buffer = match self.allocate_device_memory(size) {
507 Ok(buffer) => buffer,
508 Err(e) => {
509 eprintln!(
511 "Warning: WebGPU buffer allocation failed ({}), creating CPU fallback buffer",
512 e
513 );
514
515 #[cfg(feature = "wgpu_backend")]
516 {
517 return Arc::new(WebGPUCpuFallbackBuffer {
520 data: vec![0u8; size],
521 size,
522 memory_pool: Arc::clone(&self.memory_pool),
523 });
524 }
525 #[cfg(not(feature = "wgpu_backend"))]
526 {
527 (0x2000 + size) as WgpuBuffer
528 }
529 }
530 };
531
532 Arc::new(WebGPUBuffer {
533 device_buffer: Some(device_buffer),
534 #[cfg(feature = "wgpu_backend")]
535 queue: Arc::clone(&self.queue),
536 #[cfg(feature = "wgpu_backend")]
537 device: Arc::clone(&self.device),
538 #[cfg(not(feature = "wgpu_backend"))]
539 queue: self.queue,
540 size,
541 memory_pool: Arc::clone(&self.memory_pool),
542 })
543 }
544
545 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
546 Arc::new(WebGPUCompiler {
547 context: Arc::new(WebGPUContext {
548 memory_pool: Arc::clone(&self.memory_pool),
549 compiled_shaders: Arc::clone(&self.compiled_shaders),
550 #[cfg(feature = "wgpu_backend")]
551 device: Arc::clone(&self.device),
552 #[cfg(feature = "wgpu_backend")]
553 queue: Arc::clone(&self.queue),
554 #[cfg(not(feature = "wgpu_backend"))]
555 device: Arc::clone(&self.device),
556 #[cfg(not(feature = "wgpu_backend"))]
557 queue: Arc::clone(&self.queue),
558 }),
559 })
560 }
561
    // Enables downcasting back to the concrete WebGPUContext type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
565}
566
/// A compiled compute shader: pipeline, its bind group layout, and the
/// binding metadata reflected from the WGSL source.
struct WebGPUShader {
    #[cfg(feature = "wgpu_backend")]
    pipeline: ComputePipeline,
    // Stub build: sentinel pipeline handle.
    #[cfg(not(feature = "wgpu_backend"))]
    pipeline: WgpuComputePipeline,
    #[cfg(feature = "wgpu_backend")]
    #[allow(dead_code)]
    bind_group_layout: BindGroupLayout,
    #[cfg(not(feature = "wgpu_backend"))]
    #[allow(dead_code)]
    bind_group_layout: *mut std::ffi::c_void,
    // Cache key used by WebGPUKernelHandle::dispatch.
    #[allow(dead_code)]
    name: String,
    // Per-binding name/kind info used to match kernel params to bindings.
    #[allow(dead_code)]
    binding_infos: Vec<BindingInfo>, }
584
// SAFETY: wgpu pipeline/layout objects are Send + Sync; in the stub build the
// raw pointer fields are inert sentinels that are never dereferenced.
unsafe impl Send for WebGPUShader {}
unsafe impl Sync for WebGPUShader {}
588
/// Shader compiler that writes compiled shaders into the shared context's
/// `compiled_shaders` cache.
struct WebGPUCompiler {
    context: Arc<WebGPUContext>,
}
593
594impl GpuCompilerImpl for WebGPUCompiler {
595 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
596 let shader = self.context.compile_shader_internal(source, "shader")?;
597 Ok(Arc::new(WebGPUKernelHandle {
598 shader_name: shader.name.clone(),
599 compiled_shaders: Arc::clone(&self.context.compiled_shaders),
600 params: Arc::new(Mutex::new(HashMap::new())),
601 #[cfg(feature = "wgpu_backend")]
602 device: Arc::clone(&self.context.device),
603 #[cfg(feature = "wgpu_backend")]
604 queue: Arc::clone(&self.context.queue),
605 #[cfg(feature = "wgpu_backend")]
606 ephemeral_uniforms: Mutex::new(Vec::new()),
607 #[cfg(not(feature = "wgpu_backend"))]
608 device: self.context.device,
609 #[cfg(not(feature = "wgpu_backend"))]
610 queue: self.context.queue,
611 }))
612 }
613
614 fn compile_typed(
615 &self,
616 name: &str,
617 _input_type: std::any::TypeId,
618 _output_type: std::any::TypeId,
619 ) -> Arc<dyn GpuKernelImpl> {
620 Arc::new(WebGPUKernelHandle {
621 shader_name: name.to_string(),
622 compiled_shaders: Arc::clone(&self.context.compiled_shaders),
623 params: Arc::new(Mutex::new(HashMap::new())),
624 #[cfg(feature = "wgpu_backend")]
625 device: Arc::clone(&self.context.device),
626 #[cfg(feature = "wgpu_backend")]
627 queue: Arc::clone(&self.context.queue),
628 #[cfg(feature = "wgpu_backend")]
629 ephemeral_uniforms: Mutex::new(Vec::new()),
630 #[cfg(not(feature = "wgpu_backend"))]
631 device: self.context.device,
632 #[cfg(not(feature = "wgpu_backend"))]
633 queue: self.context.queue,
634 })
635 }
636}
637
/// Per-kernel dispatch handle: looks up its shader by name in the shared
/// cache and carries the parameter set for the next dispatch.
struct WebGPUKernelHandle {
    // Cache key into `compiled_shaders`.
    shader_name: String,
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    // Named parameters (buffers and scalars) set via the GpuKernelImpl API.
    params: Arc<Mutex<HashMap<String, KernelParam>>>,
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    // Keeps per-dispatch uniform buffers alive until the next dispatch.
    #[cfg(feature = "wgpu_backend")]
    ephemeral_uniforms: Mutex<Vec<wgpu::Buffer>>,
    #[cfg(not(feature = "wgpu_backend"))]
    device: WgpuDevice,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
}
654
/// A value bound to a kernel parameter name before dispatch.
enum KernelParam {
    #[allow(dead_code)]
    Buffer(Arc<dyn GpuBufferImpl>),
    #[allow(dead_code)]
    U32(u32),
    #[allow(dead_code)]
    I32(i32),
    #[allow(dead_code)]
    F32(f32),
    #[allow(dead_code)]
    F64(f64),
    // Pre-packed bytes appended verbatim to the uniform buffer.
    Bytes(Vec<u8>),
}
668
/// Category of a WGSL binding as reflected from the shader source.
#[derive(Clone, Debug)]
enum BindingKind {
    StorageRw,
    StorageRead,
    Uniform,
}
675
/// One reflected binding: slot index, WGSL variable name, and kind.
#[derive(Clone, Debug)]
struct BindingInfo {
    binding: u32,
    name: String,
    kind: BindingKind,
}
682
/// Pull the variable identifier out of a WGSL `var<...> name: type;` line.
///
/// Returns `None` when the line has no `var<...>` declaration, no `:` after
/// it, or an empty name.
fn extract_var_name(line: &str) -> Option<&str> {
    let var_pos = line.find("var<")?;
    let rest = &line[var_pos..];
    // First '>' closes the address-space/access qualifier list.
    let close = rest.find('>')?;
    let decl = rest[close + 1..].trim_start();
    let colon = decl.find(':')?;
    let name = decl[..colon].trim();
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}
699
700impl GpuKernelImpl for WebGPUKernelHandle {
701 fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
702 let mut params = self.params.lock().expect("Operation failed");
703 params.insert(name.to_string(), KernelParam::Buffer(Arc::clone(buffer)));
704 }
705
706 fn set_u32(&self, name: &str, value: u32) {
707 let mut params = self.params.lock().expect("Operation failed");
708 params.insert(name.to_string(), KernelParam::U32(value));
709 }
710
711 fn set_i32(&self, name: &str, value: i32) {
712 let mut params = self.params.lock().expect("Operation failed");
713 params.insert(name.to_string(), KernelParam::I32(value));
714 }
715
716 fn set_f32(&self, name: &str, value: f32) {
717 let mut params = self.params.lock().expect("Operation failed");
718 params.insert(name.to_string(), KernelParam::F32(value));
719 }
720
721 fn set_f64(&self, name: &str, value: f64) {
722 let mut params = self.params.lock().expect("Operation failed");
723 params.insert(name.to_string(), KernelParam::F64(value));
724 }
725
726 #[allow(dead_code)]
727 fn dispatch(&self, workgroups: [u32; 3]) {
730 #[cfg(feature = "wgpu_backend")]
731 {
732 let shaders = self.compiled_shaders.lock().expect("Operation failed");
734 if let Some(shader) = shaders.get(&self.shader_name) {
735 let params = self.params.lock().expect("Operation failed");
736
737 let mut encoder =
739 self.device
740 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
741 label: Some("Compute Command Encoder"),
742 });
743
744 {
746 let mut compute_pass =
747 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
748 label: Some("Compute Pass"),
749 timestamp_writes: None,
750 });
751
752 compute_pass.set_pipeline(&shader.pipeline);
754
755 if let Ok(bind_group) = self.create_bind_group_from_params(shader, ¶ms) {
756 compute_pass.set_bind_group(0, &bind_group, &[]);
757 } else {
758 eprintln!(
759 "Warning: Failed to create bind group for shader {}",
760 self.shader_name
761 );
762 }
763
764 compute_pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
766 }
767
768 let command_buffer = encoder.finish();
770 self.queue.submit(std::iter::once(command_buffer));
771
772 eprintln!(
773 "WebGPU compute shader {} dispatched with workgroups: {:?}",
774 self.shader_name, workgroups
775 );
776 }
777 }
778 #[cfg(not(feature = "wgpu_backend"))]
779 {
780 eprintln!("Executing WebGPU shader {} (simulated)", self.shader_name);
782 eprintln!("Work groups: {:?}", workgroups);
783 }
784 }
785}
786
/// A device buffer that returns itself to the shared memory pool on drop.
struct WebGPUBuffer {
    // None after the buffer has been taken back by the pool in Drop.
    #[cfg(feature = "wgpu_backend")]
    device_buffer: Option<Buffer>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(not(feature = "wgpu_backend"))]
    device_buffer: Option<WgpuBuffer>,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
    // Capacity in bytes, as requested at creation.
    size: usize,
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
802
// SAFETY: wgpu buffer/queue/device objects are Send + Sync; in the stub build
// the raw pointer fields are inert sentinels that are never dereferenced.
unsafe impl Send for WebGPUBuffer {}
unsafe impl Sync for WebGPUBuffer {}
808
809impl GpuBufferImpl for WebGPUBuffer {
810 fn size(&self) -> usize {
811 self.size
812 }
813
814 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
815 #[cfg(feature = "wgpu_backend")]
816 {
817 if size > self.size {
819 eprintln!(
821 "Warning: Data size {} exceeds buffer size {}",
822 size, self.size
823 );
824 return;
825 }
826
827 let data_slice = std::slice::from_raw_parts(data, size);
829
830 if let Some(ref buffer) = self.device_buffer {
832 self.queue.write_buffer(buffer, 0, data_slice);
833 }
834 }
835 #[cfg(not(feature = "wgpu_backend"))]
836 {
837 if size > self.size {
839 eprintln!(
840 "Warning: Data size {} exceeds buffer size {}",
841 size, self.size
842 );
843 }
844 }
846 }
847
848 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
849 #[cfg(feature = "wgpu_backend")]
850 {
851 if size > self.size {
853 eprintln!(
854 "Warning: Data size {} exceeds buffer size {}",
855 size, self.size
856 );
857 return;
858 }
859
860 if let Some(ref buffer) = self.device_buffer {
861 let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
862 label: Some("scirs2-readback"),
863 size: size as u64,
864 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
865 mapped_at_creation: false,
866 });
867 let mut encoder =
868 self.device
869 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
870 label: Some("scirs2-readback-enc"),
871 });
872 encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
873 self.queue.submit(Some(encoder.finish()));
874 let slice = staging.slice(0..size as u64);
875 let (tx, rx) = std::sync::mpsc::channel();
876 slice.map_async(wgpu::MapMode::Read, move |r| {
877 let _ = tx.send(r);
878 });
879 if let Ok(Ok(())) = rx.recv() {
881 let mapped = slice.get_mapped_range();
882 let dst = std::slice::from_raw_parts_mut(data, size);
883 dst.copy_from_slice(&mapped);
884 drop(mapped);
885 staging.unmap();
886 } else {
887 eprintln!("Warning: map_async failed for readback");
888 }
889 }
890 }
891 #[cfg(not(feature = "wgpu_backend"))]
892 {
893 if size > self.size {
895 eprintln!(
896 "Warning: Data size {} exceeds buffer size {}",
897 size, self.size
898 );
899 }
900
901 let data_slice = std::slice::from_raw_parts_mut(data, size);
903 data_slice.fill(0);
904 }
905 }
906
907 fn device_ptr(&self) -> u64 {
908 #[cfg(feature = "wgpu_backend")]
909 {
910 &self.device_buffer as *const _ as u64
913 }
914 #[cfg(not(feature = "wgpu_backend"))]
915 {
916 self.device_buffer as u64
917 }
918 }
919
920 fn as_any(&self) -> &dyn std::any::Any {
921 self
922 }
923}
924
#[cfg(feature = "wgpu_backend")]
impl WebGPUKernelHandle {
    /// Build a bind group matching `shader`'s reflected bindings from the
    /// currently-set kernel parameters.
    ///
    /// Storage bindings are resolved by WGSL variable name; scalar params are
    /// packed into one uniform buffer, padded to a 16-byte multiple, and
    /// bound at the first reflected uniform slot.
    fn create_bind_group_from_params(
        &self,
        shader: &WebGPUShader,
        params: &HashMap<String, KernelParam>,
    ) -> Result<wgpu::BindGroup, GpuError> {
        let mut entries: Vec<wgpu::BindGroupEntry> = Vec::new();
        // Keeps the freshly created uniform buffer alive locally; `entries`
        // borrows from it until the bind group is created below.
        let mut owned_uniform_buffers: Vec<wgpu::Buffer> = Vec::new();
        let mut uniform_bytes: Vec<u8> = Vec::new();
        for info in &shader.binding_infos {
            match info.kind {
                BindingKind::StorageRw | BindingKind::StorageRead => {
                    // Storage bindings require a buffer param of the same name.
                    if let Some(KernelParam::Buffer(buf)) = params.get(&info.name) {
                        if let Some(wbuf) = buf.as_any().downcast_ref::<WebGPUBuffer>() {
                            if let Some(ref inner) = wbuf.device_buffer {
                                entries.push(wgpu::BindGroupEntry {
                                    binding: info.binding,
                                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                                        buffer: inner,
                                        offset: 0,
                                        size: None,
                                    }),
                                });
                            }
                        }
                    } else {
                        return Err(GpuError::InvalidParameter(format!(
                            "Missing buffer param '{}'",
                            info.name
                        )));
                    }
                }
                BindingKind::Uniform => {
                    // Collect scalars named `uniforms` or `uniforms.<field>`.
                    // NOTE(review): params is a HashMap, so the packing order
                    // of multiple matching scalars is unspecified and may not
                    // match the WGSL struct layout — confirm callers pass one
                    // pre-packed Bytes value rather than individual fields.
                    for (k, v) in params.iter() {
                        if k == &info.name || k.starts_with(&(info.name.clone() + ".")) {
                            match v {
                                KernelParam::U32(u) => {
                                    uniform_bytes.extend_from_slice(&u.to_le_bytes())
                                }
                                KernelParam::I32(i) => {
                                    uniform_bytes.extend_from_slice(&i.to_le_bytes())
                                }
                                KernelParam::F32(f) => {
                                    uniform_bytes.extend_from_slice(&f.to_le_bytes())
                                }
                                KernelParam::F64(f) => {
                                    uniform_bytes.extend_from_slice(&f.to_le_bytes())
                                }
                                KernelParam::Bytes(b) => uniform_bytes.extend_from_slice(b),
                                KernelParam::Buffer(_) => {}
                            }
                        }
                    }
                }
            }
        }
        if !uniform_bytes.is_empty() {
            // Pad to the 16-byte alignment WGSL uniform buffers require.
            while uniform_bytes.len() % 16 != 0 {
                uniform_bytes.push(0);
            }
            if let Some(uinfo) = shader
                .binding_infos
                .iter()
                .find(|b| matches!(b.kind, BindingKind::Uniform))
            {
                if let Ok(mut list) = self.ephemeral_uniforms.lock() {
                    // Replace the previous dispatch's uniform buffer; the clone
                    // stored here keeps the GPU resource alive past this call.
                    list.clear();
                    let ubuf = self
                        .device
                        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                            label: Some("scirs2-uniforms"),
                            contents: &uniform_bytes,
                            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                        });
                    list.push(ubuf.clone());
                    owned_uniform_buffers.push(ubuf.clone());
                    let idx = owned_uniform_buffers.len() - 1;
                    let buf_ref = &owned_uniform_buffers[idx];
                    entries.push(wgpu::BindGroupEntry {
                        binding: uinfo.binding,
                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                            buffer: buf_ref,
                            offset: 0,
                            size: None,
                        }),
                    });
                }
            }
        } else if let Ok(mut list) = self.ephemeral_uniforms.lock() {
            // No uniforms this dispatch: release last dispatch's buffer.
            list.clear();
        }
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("scirs2-bind-group"),
            layout: &shader.bind_group_layout,
            entries: &entries,
        });
        Ok(bind_group)
    }
}
1027
1028impl Drop for WebGPUBuffer {
1029 fn drop(&mut self) {
1030 if let Ok(mut pool) = self.memory_pool.lock() {
1032 #[cfg(feature = "wgpu_backend")]
1033 {
1034 if let Some(buffer) = self.device_buffer.take() {
1036 pool.deallocate(buffer);
1037 }
1038 }
1039 #[cfg(not(feature = "wgpu_backend"))]
1040 {
1041 if let Some(buffer) = self.device_buffer.take() {
1042 pool.deallocate(buffer);
1043 }
1044 }
1045 }
1046 }
1047}
1048
/// Host-memory stand-in used when device allocation fails.
///
/// NOTE(review): `data` is behind `&self` with no interior mutability, so
/// writes cannot actually be stored (see `copy_from_host`) — reads always see
/// the initial zeros.
struct WebGPUCpuFallbackBuffer {
    data: Vec<u8>,
    size: usize,
    #[allow(dead_code)]
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
1057
1058impl GpuBufferImpl for WebGPUCpuFallbackBuffer {
1059 fn size(&self) -> usize {
1060 self.size
1061 }
1062
1063 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1064 if size > self.size {
1065 eprintln!("Warning: WebGPU CPU fallback buffer copy_from_host size mismatch");
1066 return;
1067 }
1068
1069 let data_slice = std::slice::from_raw_parts(data, size);
1071 eprintln!(
1074 "Warning: CPU fallback buffer copy_from_host called (size: {})",
1075 size
1076 );
1077 }
1078
1079 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1080 if size > self.size {
1081 eprintln!("Warning: WebGPU CPU fallback buffer copy_to_host size mismatch");
1082 return;
1083 }
1084
1085 let data_slice = std::slice::from_raw_parts_mut(data, size);
1087 let copy_size = size.min(self.data.len());
1088 data_slice[..copy_size].copy_from_slice(&self.data[..copy_size]);
1089
1090 eprintln!(
1091 "Warning: CPU fallback buffer copy_to_host called (size: {})",
1092 size
1093 );
1094 }
1095
1096 fn device_ptr(&self) -> u64 {
1097 self.data.as_ptr() as u64
1098 }
1099
1100 fn as_any(&self) -> &dyn std::any::Any {
1101 self
1102 }
1103}
1104
// SAFETY: all fields are owned std types (Vec, usize, Arc<Mutex<..>>), which
// are Send + Sync themselves.
unsafe impl Send for WebGPUCpuFallbackBuffer {}
unsafe impl Sync for WebGPUCpuFallbackBuffer {}
1108
/// Size-bucketed free list of device buffers: freed buffers are kept by exact
/// byte size and handed back to later allocations of the same size.
struct WebGPUMemoryPool {
    #[cfg(feature = "wgpu_backend")]
    available_buffers: HashMap<usize, Vec<Buffer>>,
    #[cfg(not(feature = "wgpu_backend"))]
    available_buffers: HashMap<usize, Vec<WgpuBuffer>>,
    // Nominal capacity budget; currently informational only.
    #[allow(dead_code)]
    total_size: usize,
    // Bytes currently handed out from the pool (best-effort accounting).
    used_size: usize,
}
1119
1120impl WebGPUMemoryPool {
1121 fn new(totalsize: usize) -> Self {
1122 Self {
1123 available_buffers: HashMap::new(),
1124 total_size: totalsize,
1125 used_size: 0,
1126 }
1127 }
1128
1129 #[cfg(feature = "wgpu_backend")]
1130 fn allocate(&mut self, size: usize) -> Option<Buffer> {
1131 if let Some(buffers) = self.available_buffers.get_mut(&size) {
1133 if let Some(buffer) = buffers.pop() {
1134 self.used_size += size;
1135 return Some(buffer);
1136 }
1137 }
1138 None
1139 }
1140
1141 #[cfg(not(feature = "wgpu_backend"))]
1142 fn allocate(&mut self, size: usize) -> Option<WgpuBuffer> {
1143 if let Some(buffers) = self.available_buffers.get_mut(&size) {
1145 if let Some(buffer) = buffers.pop() {
1146 self.used_size += size;
1147 return Some(buffer);
1148 }
1149 }
1150 None
1151 }
1152
1153 #[cfg(feature = "wgpu_backend")]
1154 fn deallocate(&mut self, buffer: Buffer) {
1155 let size = buffer.size() as usize;
1157 self.available_buffers
1158 .entry(size)
1159 .or_insert_with(Vec::new)
1160 .push(buffer);
1161 self.used_size = self.used_size.saturating_sub(size);
1162 }
1163
1164 #[cfg(not(feature = "wgpu_backend"))]
1165 fn deallocate(&mut self, buffer: WgpuBuffer) {
1166 let size = 1024; self.available_buffers
1169 .entry(size)
1170 .or_insert_with(Vec::new)
1171 .push(buffer);
1172 self.used_size = self.used_size.saturating_sub(size);
1173 }
1174
1175 #[allow(dead_code)]
1176 fn get_memory_usage(&self) -> (usize, usize) {
1177 (self.used_size, self.total_size)
1178 }
1179}