use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use crate::gpu::{GpuBufferImpl, GpuCompilerImpl, GpuContextImpl, GpuError, GpuKernelImpl};

#[cfg(feature = "wgpu_backend")]
#[allow(unused_imports)]
use wgpu::{
    util::DeviceExt, Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
    BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, Buffer,
    BufferBindingType, BufferDescriptor, BufferUsages, ComputePipeline, Device, DeviceDescriptor,
    Features, Instance, InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
    ShaderModuleDescriptor, ShaderSource, ShaderStages, StorageTextureAccess, TextureFormat,
    TextureSampleType, TextureViewDimension,
};

// Opaque stand-ins used when the `wgpu_backend` feature is disabled.
#[cfg(not(feature = "wgpu_backend"))]
type WgpuDevice = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuQueue = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuBuffer = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuComputePipeline = *mut std::ffi::c_void;

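// WGSL compute shader implementing one Adam step per element:
//   m <- beta1*m + (1 - beta1)*g
//   v <- beta2*v + (1 - beta2)*g^2
//   theta <- theta - lr * (m / bc1) / (sqrt(v / bc2) + eps)
// where bc1 = 1 - beta1^t and bc2 = 1 - beta2^t are the bias corrections,
// which the host is expected to compute and pass in via `AdamUniforms`.
// Optional L2-style weight decay is folded into the gradient first.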
#[allow(dead_code)]
const ADAM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> params: array<f32>;
@group(0) @binding(1) var<storage, read> grads: array<f32>;
@group(0) @binding(2) var<storage, read_write> m: array<f32>;
@group(0) @binding(3) var<storage, read_write> v: array<f32>;

struct AdamUniforms {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    bias_correction1: f32,
    bias_correction2: f32,
    n: u32,
};

@group(0) @binding(4) var<uniform> uniforms: AdamUniforms;

@compute @workgroup_size(64)
fn adam_update(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;

    if (idx >= uniforms.n) {
        return;
    }

    var grad = grads[idx];

    // Apply weight decay
    if (uniforms.weight_decay > 0.0) {
        grad += uniforms.weight_decay * params[idx];
    }

    // Update biased first moment estimate
    m[idx] = uniforms.beta1 * m[idx] + (1.0 - uniforms.beta1) * grad;

    // Update biased second raw moment estimate
    v[idx] = uniforms.beta2 * v[idx] + (1.0 - uniforms.beta2) * grad * grad;

    // Compute bias-corrected moment estimates
    let m_hat = m[idx] / uniforms.bias_correction1;
    let v_hat = v[idx] / uniforms.bias_correction2;

    // Update parameters
    params[idx] -= uniforms.lr * m_hat / (sqrt(v_hat) + uniforms.eps);
}
"#;

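// WGSL compute shader for single-precision GEMM, C <- alpha*(A x B) + beta*C,
// with A (MxK), B (KxN), and C (MxN) stored row-major. Each invocation
// computes one element of C, so a dispatch should cover ceil(M/8) x ceil(N/8)
// workgroups of 8x8 invocations.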
#[allow(dead_code)]
const GEMM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> matrix_c: array<f32>;

struct GemmUniforms {
    M: u32,
    N: u32,
    K: u32,
    alpha: f32,
    beta: f32,
};

@group(0) @binding(3) var<uniform> uniforms: GemmUniforms;

@compute @workgroup_size(8, 8)
fn gemm(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row = global_id.x;
    let col = global_id.y;

    if (row >= uniforms.M || col >= uniforms.N) {
        return;
    }

    var sum = 0.0;
    for (var k = 0u; k < uniforms.K; k++) {
        sum += matrix_a[row * uniforms.K + k] * matrix_b[k * uniforms.N + col];
    }

    let idx = row * uniforms.N + col;
    matrix_c[idx] = uniforms.alpha * sum + uniforms.beta * matrix_c[idx];
}
"#;

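/// WebGPU execution context: a `wgpu` device/queue pair (or opaque stubs when
/// the `wgpu_backend` feature is disabled), a cache of compiled shaders, and a
/// size-bucketed memory pool shared with the buffers it creates.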
pub struct WebGPUContext {
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(not(feature = "wgpu_backend"))]
    device: WgpuDevice,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

// SAFETY: the raw-pointer stub fields are never dereferenced; the wgpu types
// are themselves Send + Sync.
unsafe impl Send for WebGPUContext {}
unsafe impl Sync for WebGPUContext {}

impl WebGPUContext {
    pub fn new() -> Result<Self, GpuError> {
        #[cfg(feature = "wgpu_backend")]
        {
            let instance_desc = InstanceDescriptor {
                backends: Backends::all(),
                ..Default::default()
            };
            let instance = Instance::new(&instance_desc);

            let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
                power_preference: PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            }))
            .map_err(|e| GpuError::Other(format!("Failed to find WebGPU adapter: {e}")))?;

            let device_descriptor = DeviceDescriptor {
                label: Some("SciRS2 WebGPU Device"),
                required_features: Features::empty(),
                required_limits: Limits::default(),
                ..Default::default()
            };

            let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor))
                .map_err(|e| GpuError::Other(format!("Failed to create WebGPU device: {e}")))?;

            Ok(Self {
                device: Arc::new(device),
                queue: Arc::new(queue),
                compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
                memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))),
            })
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            let device = Self::initialize_webgpu()?;
            let queue = Self::create_queue(device)?;

            Ok(Self {
                device,
                queue,
                compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
                memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))),
            })
        }
    }

    pub fn is_available() -> bool {
        #[cfg(feature = "wgpu_backend")]
        {
            let instance_desc = InstanceDescriptor {
                backends: Backends::all(),
                ..Default::default()
            };
            let instance = Instance::new(&instance_desc);

            pollster::block_on(async {
                instance
                    .request_adapter(&RequestAdapterOptions {
                        power_preference: PowerPreference::default(),
                        compatible_surface: None,
                        force_fallback_adapter: false,
                    })
                    .await
                    .is_ok()
            })
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            false
        }
    }

    fn compile_shader_internal(&self, source: &str, name: &str) -> Result<WebGPUShader, GpuError> {
        #[cfg(feature = "wgpu_backend")]
        {
            let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
                label: Some(name),
                source: ShaderSource::Wgsl(source.into()),
            });

            let entry_point = Self::extract_entry_point(source).unwrap_or("main");

            let (bind_group_layout, binding_infos) =
                self.create_bind_group_layout_from_source(source, name)?;

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{}_layout", name)),
                        bind_group_layouts: &[&bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let compute_pipeline =
                self.device
                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some(&format!("{}_pipeline", name)),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some(entry_point),
                        compilation_options: Default::default(),
                        cache: None,
                    });

            Ok(WebGPUShader {
                pipeline: compute_pipeline,
                bind_group_layout,
                name: name.to_string(),
                binding_infos,
            })
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            let pipeline = Self::compile_wgsl_source(source, name)?;

            Ok(WebGPUShader {
                pipeline,
                bind_group_layout: std::ptr::null_mut(),
                name: name.to_string(),
                binding_infos: Vec::new(),
            })
        }
    }

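    /// Derives a bind-group layout by heuristically scanning the WGSL source.
    /// `@group(...)`/`@binding(...)` attributes are remembered until the next
    /// `var<...>` declaration, which is then classified as read-only storage,
    /// read-write storage, or uniform. Only `@group(0)` declarations are
    /// recorded, and duplicate binding indices keep their first occurrence.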
    #[cfg(feature = "wgpu_backend")]
    fn create_bind_group_layout_from_source(
        &self,
        source: &str,
        name: &str,
    ) -> Result<(BindGroupLayout, Vec<BindingInfo>), GpuError> {
        #[derive(Default)]
        struct PendingAttr {
            group: Option<u32>,
            binding: Option<u32>,
        }
        let mut pending = PendingAttr::default();
        let mut entries: Vec<BindGroupLayoutEntry> = Vec::new();
        let mut infos: Vec<BindingInfo> = Vec::new();

        fn strip_comment(line: &str) -> &str {
            line.split_once("//").map(|(a, _)| a).unwrap_or(line)
        }

        for raw_line in source.lines() {
            let line = strip_comment(raw_line).trim();
            if line.is_empty() {
                continue;
            }

            if let Some(i) = line.find("@group(") {
                if let Some(end) = line[i + 7..].find(')') {
                    if let Ok(g) = line[i + 7..i + 7 + end].parse::<u32>() {
                        pending.group = Some(g);
                    }
                }
            }
            if let Some(i) = line.find("@binding(") {
                if let Some(end) = line[i + 9..].find(')') {
                    if let Ok(b) = line[i + 9..i + 9 + end].parse::<u32>() {
                        pending.binding = Some(b);
                    }
                }
            }

            if line.contains("var<") {
                if pending.group.unwrap_or(0) == 0 {
                    let binding_num = pending.binding.unwrap_or_else(|| entries.len() as u32);
                    let name = extract_var_name(line).unwrap_or("");
                    let storage = line.contains("var<storage");
                    let uniform = line.contains("var<uniform");
                    let read_only = storage
                        && (line.contains(", read>")
                            || line.contains("var<storage, read>")
                            || line.contains("var<storage, read,"));
                    if storage {
                        entries.push(BindGroupLayoutEntry {
                            binding: binding_num,
                            visibility: ShaderStages::COMPUTE,
                            ty: BindingType::Buffer {
                                ty: BufferBindingType::Storage { read_only },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        });
                        infos.push(BindingInfo {
                            binding: binding_num,
                            name: name.to_string(),
                            kind: if read_only {
                                BindingKind::StorageRead
                            } else {
                                BindingKind::StorageRw
                            },
                        });
                    } else if uniform {
                        entries.push(BindGroupLayoutEntry {
                            binding: binding_num,
                            visibility: ShaderStages::COMPUTE,
                            ty: BindingType::Buffer {
                                ty: BufferBindingType::Uniform,
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        });
                        infos.push(BindingInfo {
                            binding: binding_num,
                            name: name.to_string(),
                            kind: BindingKind::Uniform,
                        });
                    }
                }
                pending = PendingAttr::default();
            }
        }

        if entries.is_empty() {
            entries.push(BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            });
            infos.push(BindingInfo {
                binding: 0,
                name: "_unnamed".into(),
                kind: BindingKind::StorageRw,
            });
        }

        let mut seen = std::collections::HashSet::new();
        let mut dedup_entries = Vec::new();
        let mut dedup_infos = Vec::new();
        for (e, info) in entries.into_iter().zip(infos.into_iter()) {
            if seen.insert(e.binding) {
                dedup_entries.push(e);
                dedup_infos.push(info);
            }
        }

        let bind_group_layout = self
            .device
            .create_bind_group_layout(&BindGroupLayoutDescriptor {
                label: Some(&format!("{}_bind_group_layout", name)),
                entries: &dedup_entries,
            });
        Ok((bind_group_layout, dedup_infos))
    }

    #[cfg(feature = "wgpu_backend")]
    pub fn allocate_device_memory(&self, size: usize) -> Result<Buffer, GpuError> {
        let buffer = self.device.create_buffer(&BufferDescriptor {
            label: Some("SciRS2 Buffer"),
            size: size as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        Ok(buffer)
    }

    #[cfg(not(feature = "wgpu_backend"))]
    pub fn allocate_device_memory(&self, size: usize) -> Result<WgpuBuffer, GpuError> {
        // Stub: a fabricated, never-dereferenced pointer value.
        Ok((0x1000 + size) as WgpuBuffer)
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn initialize_webgpu() -> Result<WgpuDevice, GpuError> {
        Ok(0x1 as WgpuDevice)
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn create_queue(_device: WgpuDevice) -> Result<WgpuQueue, GpuError> {
        Ok(0x2 as WgpuQueue)
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn compile_wgsl_source(_source: &str, _name: &str) -> Result<WgpuComputePipeline, GpuError> {
        Ok(0x3 as WgpuComputePipeline)
    }

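    /// Finds the shader entry point: locates the first `@compute` attribute
    /// and returns the name of the `fn` on the same line or the line after.
    /// The call site falls back to `"main"` when nothing is found.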
    fn extract_entry_point(source: &str) -> Option<&str> {
        let lines: Vec<&str> = source.lines().collect();

        for (i, line) in lines.iter().enumerate() {
            let trimmed = line.trim();

            if trimmed.contains("@compute") {
                let mut search_line = trimmed;
                let mut search_idx = i;

                // The `fn` may be on the line after the attribute.
                if !search_line.contains("fn ") && search_idx + 1 < lines.len() {
                    search_idx += 1;
                    search_line = lines[search_idx].trim();
                }

                if let Some(start) = search_line.find("fn ") {
                    let remaining = &search_line[start + 3..];
                    if let Some(end) = remaining.find('(') {
                        let funcname = remaining[..end].trim();
                        return Some(funcname);
                    }
                }
            }
        }

        None
    }
}

impl GpuContextImpl for WebGPUContext {
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
        // Try to reuse a pooled buffer of the same size first.
        if let Ok(mut pool) = self.memory_pool.lock() {
            if let Some(device_buffer) = pool.allocate(size) {
                return Arc::new(WebGPUBuffer {
                    device_buffer: Some(device_buffer),
                    #[cfg(feature = "wgpu_backend")]
                    queue: Arc::clone(&self.queue),
                    #[cfg(feature = "wgpu_backend")]
                    device: Arc::clone(&self.device),
                    #[cfg(not(feature = "wgpu_backend"))]
                    queue: self.queue,
                    size,
                    memory_pool: Arc::clone(&self.memory_pool),
                });
            }
        }

        let device_buffer = match self.allocate_device_memory(size) {
            Ok(buffer) => buffer,
            Err(e) => {
                eprintln!(
                    "Warning: WebGPU buffer allocation failed ({}), creating CPU fallback buffer",
                    e
                );

                #[cfg(feature = "wgpu_backend")]
                {
                    return Arc::new(WebGPUCpuFallbackBuffer {
                        data: vec![0u8; size],
                        size,
                        memory_pool: Arc::clone(&self.memory_pool),
                    });
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    (0x2000 + size) as WgpuBuffer
                }
            }
        };

        Arc::new(WebGPUBuffer {
            device_buffer: Some(device_buffer),
            #[cfg(feature = "wgpu_backend")]
            queue: Arc::clone(&self.queue),
            #[cfg(feature = "wgpu_backend")]
            device: Arc::clone(&self.device),
            #[cfg(not(feature = "wgpu_backend"))]
            queue: self.queue,
            size,
            memory_pool: Arc::clone(&self.memory_pool),
        })
    }

    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
        Arc::new(WebGPUCompiler {
            context: Arc::new(WebGPUContext {
                memory_pool: Arc::clone(&self.memory_pool),
                compiled_shaders: Arc::clone(&self.compiled_shaders),
                #[cfg(feature = "wgpu_backend")]
                device: Arc::clone(&self.device),
                #[cfg(feature = "wgpu_backend")]
                queue: Arc::clone(&self.queue),
                #[cfg(not(feature = "wgpu_backend"))]
                device: self.device,
                #[cfg(not(feature = "wgpu_backend"))]
                queue: self.queue,
            }),
        })
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

struct WebGPUShader {
    #[cfg(feature = "wgpu_backend")]
    pipeline: ComputePipeline,
    #[cfg(not(feature = "wgpu_backend"))]
    pipeline: WgpuComputePipeline,
    #[cfg(feature = "wgpu_backend")]
    #[allow(dead_code)]
    bind_group_layout: BindGroupLayout,
    #[cfg(not(feature = "wgpu_backend"))]
    #[allow(dead_code)]
    bind_group_layout: *mut std::ffi::c_void,
    #[allow(dead_code)]
    name: String,
    #[allow(dead_code)]
    binding_infos: Vec<BindingInfo>,
}

// SAFETY: the stub `bind_group_layout` pointer is never dereferenced.
unsafe impl Send for WebGPUShader {}
unsafe impl Sync for WebGPUShader {}

struct WebGPUCompiler {
    context: Arc<WebGPUContext>,
}

impl GpuCompilerImpl for WebGPUCompiler {
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
        let shader = self.context.compile_shader_internal(source, "shader")?;
        let shader_name = shader.name.clone();
        // Cache the compiled shader so `dispatch` can find it by name.
        if let Ok(mut shaders) = self.context.compiled_shaders.lock() {
            shaders.insert(shader_name.clone(), shader);
        }
        Ok(Arc::new(WebGPUKernelHandle {
            shader_name,
            compiled_shaders: Arc::clone(&self.context.compiled_shaders),
            params: Arc::new(Mutex::new(HashMap::new())),
            #[cfg(feature = "wgpu_backend")]
            device: Arc::clone(&self.context.device),
            #[cfg(feature = "wgpu_backend")]
            queue: Arc::clone(&self.context.queue),
            #[cfg(feature = "wgpu_backend")]
            ephemeral_uniforms: Mutex::new(Vec::new()),
            #[cfg(not(feature = "wgpu_backend"))]
            device: self.context.device,
            #[cfg(not(feature = "wgpu_backend"))]
            queue: self.context.queue,
        }))
    }

    fn compile_typed(
        &self,
        name: &str,
        _input_type: std::any::TypeId,
        _output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl> {
        Arc::new(WebGPUKernelHandle {
            shader_name: name.to_string(),
            compiled_shaders: Arc::clone(&self.context.compiled_shaders),
            params: Arc::new(Mutex::new(HashMap::new())),
            #[cfg(feature = "wgpu_backend")]
            device: Arc::clone(&self.context.device),
            #[cfg(feature = "wgpu_backend")]
            queue: Arc::clone(&self.context.queue),
            #[cfg(feature = "wgpu_backend")]
            ephemeral_uniforms: Mutex::new(Vec::new()),
            #[cfg(not(feature = "wgpu_backend"))]
            device: self.context.device,
            #[cfg(not(feature = "wgpu_backend"))]
            queue: self.context.queue,
        })
    }
}

struct WebGPUKernelHandle {
    shader_name: String,
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    params: Arc<Mutex<HashMap<String, KernelParam>>>,
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(feature = "wgpu_backend")]
    ephemeral_uniforms: Mutex<Vec<wgpu::Buffer>>,
    #[cfg(not(feature = "wgpu_backend"))]
    device: WgpuDevice,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
}

// SAFETY: mirrors the other handle types; the stub pointers used without
// `wgpu_backend` are never dereferenced.
unsafe impl Send for WebGPUKernelHandle {}
unsafe impl Sync for WebGPUKernelHandle {}

enum KernelParam {
    #[allow(dead_code)]
    Buffer(Arc<dyn GpuBufferImpl>),
    #[allow(dead_code)]
    U32(u32),
    #[allow(dead_code)]
    I32(i32),
    #[allow(dead_code)]
    F32(f32),
    #[allow(dead_code)]
    F64(f64),
    #[allow(dead_code)]
    Bytes(Vec<u8>),
}

#[derive(Clone, Debug)]
enum BindingKind {
    StorageRw,
    StorageRead,
    Uniform,
}

#[derive(Clone, Debug)]
struct BindingInfo {
    binding: u32,
    name: String,
    kind: BindingKind,
}

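/// Extracts the variable name from a WGSL `var<...>` declaration, e.g.
/// `var<storage, read> grads: array<f32>;` yields `grads`.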
fn extract_var_name(line: &str) -> Option<&str> {
    if let Some(var_start) = line.find("var<") {
        let after_var = &line[var_start..];
        if let Some(close) = after_var.find('>') {
            let after = &after_var[close + 1..];
            let after = after.trim_start();
            if let Some(colon) = after.find(':') {
                let name_part = after[..colon].trim();
                if !name_part.is_empty() {
                    return Some(name_part);
                }
            }
        }
    }
    None
}

impl GpuKernelImpl for WebGPUKernelHandle {
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
        let mut params = self.params.lock().expect("params mutex poisoned");
        params.insert(name.to_string(), KernelParam::Buffer(Arc::clone(buffer)));
    }

    fn set_u32(&self, name: &str, value: u32) {
        let mut params = self.params.lock().expect("params mutex poisoned");
        params.insert(name.to_string(), KernelParam::U32(value));
    }

    fn set_i32(&self, name: &str, value: i32) {
        let mut params = self.params.lock().expect("params mutex poisoned");
        params.insert(name.to_string(), KernelParam::I32(value));
    }

    fn set_f32(&self, name: &str, value: f32) {
        let mut params = self.params.lock().expect("params mutex poisoned");
        params.insert(name.to_string(), KernelParam::F32(value));
    }

    fn set_f64(&self, name: &str, value: f64) {
        let mut params = self.params.lock().expect("params mutex poisoned");
        params.insert(name.to_string(), KernelParam::F64(value));
    }

    fn dispatch(&self, workgroups: [u32; 3]) {
        #[cfg(feature = "wgpu_backend")]
        {
            let shaders = self
                .compiled_shaders
                .lock()
                .expect("shader cache mutex poisoned");
            if let Some(shader) = shaders.get(&self.shader_name) {
                let params = self.params.lock().expect("params mutex poisoned");

                let mut encoder =
                    self.device
                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("Compute Command Encoder"),
                        });

                {
                    let mut compute_pass =
                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                            label: Some("Compute Pass"),
                            timestamp_writes: None,
                        });

                    compute_pass.set_pipeline(&shader.pipeline);

                    // Only dispatch if the bind group could be assembled;
                    // dispatching without one would fail validation anyway.
                    match self.create_bind_group_from_params(shader, &params) {
                        Ok(bind_group) => {
                            compute_pass.set_bind_group(0, &bind_group, &[]);
                            compute_pass.dispatch_workgroups(
                                workgroups[0],
                                workgroups[1],
                                workgroups[2],
                            );
                        }
                        Err(_) => {
                            eprintln!(
                                "Warning: Failed to create bind group for shader {}; dispatch skipped",
                                self.shader_name
                            );
                        }
                    }
                }

                let command_buffer = encoder.finish();
                self.queue.submit(std::iter::once(command_buffer));

                eprintln!(
                    "WebGPU compute shader {} dispatched with workgroups: {:?}",
                    self.shader_name, workgroups
                );
            } else {
                eprintln!(
                    "Warning: shader {} was never compiled; dispatch skipped",
                    self.shader_name
                );
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            eprintln!("Executing WebGPU shader {} (simulated)", self.shader_name);
            eprintln!("Work groups: {:?}", workgroups);
        }
    }
}

struct WebGPUBuffer {
    #[cfg(feature = "wgpu_backend")]
    device_buffer: Option<Buffer>,
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    #[cfg(not(feature = "wgpu_backend"))]
    device_buffer: Option<WgpuBuffer>,
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
    size: usize,
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

// SAFETY: the stub pointer fields are never dereferenced; wgpu resources are
// internally reference-counted and thread-safe.
unsafe impl Send for WebGPUBuffer {}
unsafe impl Sync for WebGPUBuffer {}

impl GpuBufferImpl for WebGPUBuffer {
    fn size(&self) -> usize {
        self.size
    }

    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        #[cfg(feature = "wgpu_backend")]
        {
            if size > self.size {
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
                return;
            }

            let data_slice = std::slice::from_raw_parts(data, size);

            if let Some(ref buffer) = self.device_buffer {
                self.queue.write_buffer(buffer, 0, data_slice);
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            // Stub: there is no device memory to copy into.
            let _ = data;
            if size > self.size {
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
            }
        }
    }

    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        #[cfg(feature = "wgpu_backend")]
        {
            if size > self.size {
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
                return;
            }

            if let Some(ref buffer) = self.device_buffer {
                // Read back through a MAP_READ staging buffer.
                let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("scirs2-readback"),
                    size: size as u64,
                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });
                let mut encoder =
                    self.device
                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("scirs2-readback-enc"),
                        });
                encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
                self.queue.submit(Some(encoder.finish()));
                let slice = staging.slice(0..size as u64);
                let (tx, rx) = std::sync::mpsc::channel();
                slice.map_async(wgpu::MapMode::Read, move |r| {
                    let _ = tx.send(r);
                });
                // On native backends the device must be polled, or the map
                // callback never runs and `recv` blocks forever.
                let _ = self.device.poll(wgpu::PollType::Wait);
                if let Ok(Ok(())) = rx.recv() {
                    let mapped = slice.get_mapped_range();
                    let dst = std::slice::from_raw_parts_mut(data, size);
                    dst.copy_from_slice(&mapped);
                    drop(mapped);
                    staging.unmap();
                } else {
                    eprintln!("Warning: map_async failed for readback");
                }
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            if size > self.size {
                eprintln!(
                    "Warning: Data size {} exceeds buffer size {}",
                    size, self.size
                );
            }

            // Stub: no device memory, so zero-fill the destination.
            let data_slice = std::slice::from_raw_parts_mut(data, size);
            data_slice.fill(0);
        }
    }

    fn device_ptr(&self) -> u64 {
        #[cfg(feature = "wgpu_backend")]
        {
            // Not a real GPU address: wgpu does not expose raw device
            // pointers, so this is only a stable host-side identifier.
            &self.device_buffer as *const _ as u64
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            self.device_buffer.map(|p| p as u64).unwrap_or(0)
        }
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

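// Bind-group assembly for dispatch: each storage binding is matched to a
// buffer param with the same name as its WGSL variable, and scalar params
// named `<uniform>` or `<uniform>.<field>` are packed into a single uniform
// buffer padded to 16 bytes. Note that the packing walks a HashMap, so the
// field order is unspecified; passing one pre-packed `Bytes` value avoids that.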
#[cfg(feature = "wgpu_backend")]
impl WebGPUKernelHandle {
    fn create_bind_group_from_params(
        &self,
        shader: &WebGPUShader,
        params: &HashMap<String, KernelParam>,
    ) -> Result<wgpu::BindGroup, GpuError> {
        let mut entries: Vec<wgpu::BindGroupEntry> = Vec::new();
        let mut owned_uniform_buffers: Vec<wgpu::Buffer> = Vec::new();
        let mut uniform_bytes: Vec<u8> = Vec::new();
        for info in &shader.binding_infos {
            match info.kind {
                BindingKind::StorageRw | BindingKind::StorageRead => {
                    if let Some(KernelParam::Buffer(buf)) = params.get(&info.name) {
                        if let Some(wbuf) = buf.as_any().downcast_ref::<WebGPUBuffer>() {
                            if let Some(ref inner) = wbuf.device_buffer {
                                entries.push(wgpu::BindGroupEntry {
                                    binding: info.binding,
                                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                                        buffer: inner,
                                        offset: 0,
                                        size: None,
                                    }),
                                });
                            }
                        }
                    } else {
                        return Err(GpuError::InvalidParameter(format!(
                            "Missing buffer param '{}'",
                            info.name
                        )));
                    }
                }
                BindingKind::Uniform => {
                    for (k, v) in params.iter() {
                        if k == &info.name || k.starts_with(&(info.name.clone() + ".")) {
                            match v {
                                KernelParam::U32(u) => {
                                    uniform_bytes.extend_from_slice(&u.to_le_bytes())
                                }
                                KernelParam::I32(i) => {
                                    uniform_bytes.extend_from_slice(&i.to_le_bytes())
                                }
                                KernelParam::F32(f) => {
                                    uniform_bytes.extend_from_slice(&f.to_le_bytes())
                                }
                                KernelParam::F64(f) => {
                                    uniform_bytes.extend_from_slice(&f.to_le_bytes())
                                }
                                KernelParam::Bytes(b) => uniform_bytes.extend_from_slice(b),
                                KernelParam::Buffer(_) => {}
                            }
                        }
                    }
                }
            }
        }
        if !uniform_bytes.is_empty() {
            // Pad to a 16-byte multiple to match WGSL uniform struct sizing.
            while uniform_bytes.len() % 16 != 0 {
                uniform_bytes.push(0);
            }
            if let Some(uinfo) = shader
                .binding_infos
                .iter()
                .find(|b| matches!(b.kind, BindingKind::Uniform))
            {
                if let Ok(mut list) = self.ephemeral_uniforms.lock() {
                    list.clear();
                    let ubuf = self
                        .device
                        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                            label: Some("scirs2-uniforms"),
                            contents: &uniform_bytes,
                            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                        });
                    // Keep one clone alive across the dispatch; bind the other.
                    list.push(ubuf.clone());
                    owned_uniform_buffers.push(ubuf);
                    let buf_ref = owned_uniform_buffers
                        .last()
                        .expect("just pushed a uniform buffer");
                    entries.push(wgpu::BindGroupEntry {
                        binding: uinfo.binding,
                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                            buffer: buf_ref,
                            offset: 0,
                            size: None,
                        }),
                    });
                }
            }
        } else if let Ok(mut list) = self.ephemeral_uniforms.lock() {
            list.clear();
        }
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("scirs2-bind-group"),
            layout: &shader.bind_group_layout,
            entries: &entries,
        });
        Ok(bind_group)
    }
}

impl Drop for WebGPUBuffer {
    fn drop(&mut self) {
        // Identical in both configurations: return the buffer to the pool.
        if let Ok(mut pool) = self.memory_pool.lock() {
            if let Some(buffer) = self.device_buffer.take() {
                pool.deallocate(buffer);
            }
        }
    }
}

struct WebGPUCpuFallbackBuffer {
    data: Vec<u8>,
    size: usize,
    #[allow(dead_code)]
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

impl GpuBufferImpl for WebGPUCpuFallbackBuffer {
    fn size(&self) -> usize {
        self.size
    }

    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        if size > self.size {
            eprintln!("Warning: WebGPU CPU fallback buffer copy_from_host size mismatch");
            return;
        }

        // The backing Vec is not behind interior mutability, so the write
        // cannot be stored through `&self`; the data is dropped with a warning.
        let _ = std::slice::from_raw_parts(data, size);
        eprintln!(
            "Warning: CPU fallback buffer copy_from_host called (size: {})",
            size
        );
    }

    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        if size > self.size {
            eprintln!("Warning: WebGPU CPU fallback buffer copy_to_host size mismatch");
            return;
        }

        let data_slice = std::slice::from_raw_parts_mut(data, size);
        let copy_size = size.min(self.data.len());
        data_slice[..copy_size].copy_from_slice(&self.data[..copy_size]);

        eprintln!(
            "Warning: CPU fallback buffer copy_to_host called (size: {})",
            size
        );
    }

    fn device_ptr(&self) -> u64 {
        self.data.as_ptr() as u64
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

unsafe impl Send for WebGPUCpuFallbackBuffer {}
unsafe impl Sync for WebGPUCpuFallbackBuffer {}

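/// Size-bucketed buffer pool: deallocated buffers are kept in per-size free
/// lists and reused verbatim for later allocations of exactly the same size.
/// `total_size` is a nominal capacity used only for usage reporting.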
struct WebGPUMemoryPool {
    #[cfg(feature = "wgpu_backend")]
    available_buffers: HashMap<usize, Vec<Buffer>>,
    #[cfg(not(feature = "wgpu_backend"))]
    available_buffers: HashMap<usize, Vec<WgpuBuffer>>,
    #[allow(dead_code)]
    total_size: usize,
    used_size: usize,
}

impl WebGPUMemoryPool {
    fn new(total_size: usize) -> Self {
        Self {
            available_buffers: HashMap::new(),
            total_size,
            used_size: 0,
        }
    }

    #[cfg(feature = "wgpu_backend")]
    fn allocate(&mut self, size: usize) -> Option<Buffer> {
        if let Some(buffers) = self.available_buffers.get_mut(&size) {
            if let Some(buffer) = buffers.pop() {
                self.used_size += size;
                return Some(buffer);
            }
        }
        None
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn allocate(&mut self, size: usize) -> Option<WgpuBuffer> {
        if let Some(buffers) = self.available_buffers.get_mut(&size) {
            if let Some(buffer) = buffers.pop() {
                self.used_size += size;
                return Some(buffer);
            }
        }
        None
    }

    #[cfg(feature = "wgpu_backend")]
    fn deallocate(&mut self, buffer: Buffer) {
        let size = buffer.size() as usize;
        self.available_buffers.entry(size).or_default().push(buffer);
        self.used_size = self.used_size.saturating_sub(size);
    }

    #[cfg(not(feature = "wgpu_backend"))]
    fn deallocate(&mut self, buffer: WgpuBuffer) {
        // Stub buffers carry no size information; assume a fixed bucket.
        let size = 1024;
        self.available_buffers.entry(size).or_default().push(buffer);
        self.used_size = self.used_size.saturating_sub(size);
    }

    #[allow(dead_code)]
    fn get_memory_usage(&self) -> (usize, usize) {
        (self.used_size, self.total_size)
    }
}
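
// A minimal usage sketch, kept as an ignored test because it needs a real
// WebGPU adapter at runtime. It assumes only the items defined in this module
// plus the `GpuContextImpl`/`GpuBufferImpl` traits imported above.
#[cfg(all(test, feature = "wgpu_backend"))]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires a WebGPU adapter"]
    fn buffer_roundtrip() {
        if !WebGPUContext::is_available() {
            return;
        }
        let ctx = WebGPUContext::new().expect("failed to create WebGPU context");
        let buffer = ctx.create_buffer(16);

        // Write four f32 values, then read them back through the staging path.
        let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let mut output = [0.0f32; 4];
        unsafe {
            buffer.copy_from_host(input.as_ptr() as *const u8, 16);
            buffer.copy_to_host(output.as_mut_ptr() as *mut u8, 16);
        }
        assert_eq!(input, output);
    }
}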