1use std::collections::HashMap;
6#[cfg(feature = "wgpu_backend")]
7use std::sync::{Arc, Mutex};
9
10use crate::gpu::{GpuBufferImpl, GpuCompilerImpl, GpuContextImpl, GpuError, GpuKernelImpl};
11
12#[cfg(feature = "wgpu_backend")]
13#[allow(unused_imports)]
14use wgpu::{
15 util::DeviceExt, Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
16 BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, Buffer,
17 BufferBindingType, BufferDescriptor, BufferUsages, ComputePipeline, Device, DeviceDescriptor,
18 Features, Instance, InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
19 ShaderModuleDescriptor, ShaderSource, ShaderStages, StorageTextureAccess, TextureFormat,
20 TextureSampleType, TextureViewDimension,
21};
22
// Placeholder handle types used when the `wgpu_backend` feature is disabled.
// Each alias is an untyped raw pointer; the stub code paths in this file only
// ever store small integer sentinels in them and never dereference them.

/// Stand-in for a device handle in builds without `wgpu_backend`.
#[cfg(not(feature = "wgpu_backend"))]
type WgpuDevice = *mut std::ffi::c_void;
/// Stand-in for a queue handle in builds without `wgpu_backend`.
#[cfg(not(feature = "wgpu_backend"))]
type WgpuQueue = *mut std::ffi::c_void;
/// Stand-in for a buffer handle in builds without `wgpu_backend`.
#[cfg(not(feature = "wgpu_backend"))]
type WgpuBuffer = *mut std::ffi::c_void;
/// Stand-in for a compute-pipeline handle in builds without `wgpu_backend`.
#[cfg(not(feature = "wgpu_backend"))]
type WgpuComputePipeline = *mut std::ffi::c_void;
32
/// A compiled compute pipeline bundled with the bind-group layout it was
/// created against and the workgroup size parsed from its shader source.
#[cfg(feature = "wgpu_backend")]
pub struct WgpuComputePipeline {
    /// The compiled wgpu compute pipeline.
    pub pipeline: ComputePipeline,
    /// Bind-group layout that was passed to the pipeline layout at creation.
    pub bind_group_layout: BindGroupLayout,
    /// Workgroup dimensions `[x, y, z]` as read from `@workgroup_size(...)`.
    pub workgroup_size: [u32; 3],
}

// NOTE(review): these impls assert that the wrapped wgpu handles may be moved
// to and shared between threads; confirm this holds on every target the crate
// supports.
#[cfg(feature = "wgpu_backend")]
unsafe impl Send for WgpuComputePipeline {}
#[cfg(feature = "wgpu_backend")]
unsafe impl Sync for WgpuComputePipeline {}
52
/// Compiles a WGSL compute shader on a freshly created [`WebGPUContext`].
///
/// # Errors
/// Returns the [`GpuError`] produced by context creation or by
/// [`WebGPUContext::compile_to_pipeline`].
#[cfg(feature = "wgpu_backend")]
pub fn try_compile_wgsl(source: &str) -> Result<WgpuComputePipeline, GpuError> {
    WebGPUContext::new().and_then(|context| context.compile_to_pipeline(source))
}
78
/// Adds `a` and `b` element-wise on the GPU using a freshly created
/// [`WebGPUContext`].
///
/// # Errors
/// Returns the [`GpuError`] produced by context creation or by
/// [`WebGPUContext::run_vector_add`].
#[cfg(feature = "wgpu_backend")]
pub fn run_vector_add_wgsl(a: &[f32], b: &[f32]) -> Result<Vec<f32>, GpuError> {
    WebGPUContext::new().and_then(|context| context.run_vector_add(a, b))
}
88
/// WGSL source of a general matrix-multiply compute kernel.
///
/// Each invocation computes one element of `matrix_c`:
/// `C[row, col] = alpha * sum_k A[row, k] * B[k, col] + beta * C[row, col]`,
/// indexing `A` as `row * K + k`, `B` as `k * N + col` and `C` as
/// `row * N + col`. Bindings 0-2 of group 0 are the storage buffers for A, B
/// and C; binding 3 is the `GemmUniforms` uniform block (`M`, `N`, `K`,
/// `alpha`, `beta`). The workgroup size is 8x8 and invocations outside the
/// `M x N` range return immediately.
pub const GEMM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> matrix_c: array<f32>;

struct GemmUniforms {
 M: u32,
 N: u32,
 K: u32,
 alpha: f32,
 beta: f32,
};

@group(0) @binding(3) var<uniform> uniforms: GemmUniforms;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
 let row = global_id.x;
 let col = global_id.y;

 if row >= uniforms.M || col >= uniforms.N { return; }

 var sum = 0.0f;
 for (var k = 0u; k < uniforms.K; k++) {
 sum += matrix_a[row * uniforms.K + k] * matrix_b[k * uniforms.N + col];
 }

 let idx = row * uniforms.N + col;
 matrix_c[idx] = uniforms.alpha * sum + uniforms.beta * matrix_c[idx];
}
"#;
129
/// Owns the device/queue pair used by this backend together with the shared
/// shader table and buffer pool.
///
/// With the `wgpu_backend` feature the handles are real wgpu objects; without
/// it they are placeholder raw pointers.
pub struct WebGPUContext {
    /// wgpu device used to create shaders, pipelines and buffers.
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    /// wgpu queue used to submit command buffers.
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    /// Placeholder device handle for builds without `wgpu_backend`.
    #[cfg(not(feature = "wgpu_backend"))]
    device: Arc<WgpuDevice>,
    /// Placeholder queue handle for builds without `wgpu_backend`.
    #[cfg(not(feature = "wgpu_backend"))]
    queue: Arc<WgpuQueue>,
    /// Shader table keyed by name; shared with compilers and kernel handles.
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    /// Buffer pool consulted by `create_buffer` before allocating new memory.
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

// NOTE(review): these impls assert thread-safety for every field, including
// the raw-pointer placeholders of the stub build; confirm that is intended.
unsafe impl Send for WebGPUContext {}
unsafe impl Sync for WebGPUContext {}
147
148impl WebGPUContext {
149 pub fn new() -> Result<Self, GpuError> {
151 #[cfg(feature = "wgpu_backend")]
152 {
153 let instance_desc = InstanceDescriptor {
155 backends: Backends::all(),
156 flags: wgpu::InstanceFlags::default(),
157 memory_budget_thresholds: Default::default(),
158 backend_options: Default::default(),
159 display: None,
160 };
161 let instance = Instance::new(instance_desc);
162
163 let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
164 power_preference: PowerPreference::HighPerformance,
165 compatible_surface: None,
166 force_fallback_adapter: false,
167 }))
168 .map_err(|e| GpuError::Other(format!("Failed to find WebGPU adapter: {e}")))?;
169
170 let device_descriptor = DeviceDescriptor {
171 label: Some("SciRS2 WebGPU Device"),
172 required_features: Features::empty(),
173 required_limits: Limits::default(),
174 ..Default::default()
176 };
177
178 let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor))
179 .map_err(|e| GpuError::Other(format!("{e}")))?;
180
181 Ok(Self {
182 device: Arc::new(device),
183 queue: Arc::new(queue),
184 compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
185 memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), })
187 }
188 #[cfg(not(feature = "wgpu_backend"))]
189 {
190 let device = Self::initialize_webgpu()?;
192 let queue = Self::create_queue(device)?;
193
194 Ok(Self {
195 device,
196 queue,
197 compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
198 memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))), })
200 }
201 }
202
203 pub fn is_available() -> bool {
205 #[cfg(feature = "wgpu_backend")]
206 {
207 let instance_desc = InstanceDescriptor {
209 backends: Backends::all(),
210 flags: wgpu::InstanceFlags::default(),
211 memory_budget_thresholds: Default::default(),
212 backend_options: Default::default(),
213 display: None,
214 };
215 let instance = Instance::new(instance_desc);
216
217 pollster::block_on(async {
219 instance
220 .request_adapter(&RequestAdapterOptions {
221 power_preference: PowerPreference::default(),
222 compatible_surface: None,
223 force_fallback_adapter: false,
224 })
225 .await
226 .is_ok()
227 })
228 }
229 #[cfg(not(feature = "wgpu_backend"))]
230 {
231 false
233 }
234 }
235
236 fn compile_shader_internal(&self, source: &str, name: &str) -> Result<WebGPUShader, GpuError> {
238 #[cfg(feature = "wgpu_backend")]
239 {
240 let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
242 label: Some(name),
243 source: ShaderSource::Wgsl(source.into()),
244 });
245
246 let entry_point = Self::extract_entry_point(source).unwrap_or("main");
248
249 let (bind_group_layout, binding_infos) =
251 self.create_bind_group_layout_from_source(source, name)?;
252
253 let pipeline_layout =
255 self.device
256 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
257 label: Some(&format!("{}_layout", name)),
258 bind_group_layouts: &[Some(&bind_group_layout)],
259 ..Default::default()
261 });
262
263 let compute_pipeline =
264 self.device
265 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
266 label: Some(&format!("{}_pipeline", name)),
267 layout: Some(&pipeline_layout),
268 module: &shader_module,
269 entry_point: Some(entry_point),
270 compilation_options: Default::default(),
271 cache: None,
272 });
273
274 let workgroup_size = extract_workgroup_size(source);
275
276 Ok(WebGPUShader {
277 pipeline: compute_pipeline,
278 bind_group_layout,
279 name: name.to_string(),
280 binding_infos,
281 workgroup_size,
282 })
283 }
284 #[cfg(not(feature = "wgpu_backend"))]
285 {
286 let pipeline = Self::compile_wgsl_source(source, name)?;
288
289 Ok(WebGPUShader {
290 pipeline,
291 bind_group_layout: std::ptr::null_mut(),
292 name: name.to_string(),
293 binding_infos: Vec::new(),
294 workgroup_size: [64, 1, 1],
295 })
296 }
297 }
298
299 #[cfg(feature = "wgpu_backend")]
301 fn create_bind_group_layout_from_source(
302 &self,
303 source: &str,
304 name: &str,
305 ) -> Result<(BindGroupLayout, Vec<BindingInfo>), GpuError> {
306 #[derive(Default)]
307 struct PendingAttr {
308 group: Option<u32>,
309 binding: Option<u32>,
310 }
311 let mut pending = PendingAttr::default();
312 let mut entries: Vec<BindGroupLayoutEntry> = Vec::new();
313 let mut infos: Vec<BindingInfo> = Vec::new();
314
315 fn strip_comment(line: &str) -> &str {
316 line.split_once("//").map(|(a, _)| a).unwrap_or(line)
317 }
318
319 for raw_line in source.lines() {
320 let line = strip_comment(raw_line).trim();
321 if line.is_empty() {
322 continue;
323 }
324
325 if let Some(i) = line.find("@group(") {
326 if let Some(end) = line[i + 7..].find(')') {
327 if let Ok(g) = line[i + 7..i + 7 + end].parse::<u32>() {
328 pending.group = Some(g);
329 }
330 }
331 }
332 if let Some(i) = line.find("@binding(") {
333 if let Some(end) = line[i + 9..].find(')') {
334 if let Ok(b) = line[i + 9..i + 9 + end].parse::<u32>() {
335 pending.binding = Some(b);
336 }
337 }
338 }
339
340 if line.contains("var<") {
341 if pending.group.unwrap_or(0) == 0 {
343 let binding_num = pending.binding.unwrap_or_else(|| entries.len() as u32);
345 let name = extract_var_name(line).unwrap_or("");
346 let storage = line.contains("var<storage");
347 let uniform = line.contains("var<uniform");
348 let read_only = storage
349 && (line.contains(", read>")
350 || line.contains("var<storage, read>")
351 || line.contains("var<storage, read,"));
352 if storage {
353 entries.push(BindGroupLayoutEntry {
354 binding: binding_num,
355 visibility: ShaderStages::COMPUTE,
356 ty: BindingType::Buffer {
357 ty: BufferBindingType::Storage { read_only },
358 has_dynamic_offset: false,
359 min_binding_size: None,
360 },
361 count: None,
362 });
363 infos.push(BindingInfo {
364 binding: binding_num,
365 name: name.to_string(),
366 kind: if read_only {
367 BindingKind::StorageRead
368 } else {
369 BindingKind::StorageRw
370 },
371 });
372 } else if uniform {
373 entries.push(BindGroupLayoutEntry {
374 binding: binding_num,
375 visibility: ShaderStages::COMPUTE,
376 ty: BindingType::Buffer {
377 ty: BufferBindingType::Uniform,
378 has_dynamic_offset: false,
379 min_binding_size: None,
380 },
381 count: None,
382 });
383 infos.push(BindingInfo {
384 binding: binding_num,
385 name: name.to_string(),
386 kind: BindingKind::Uniform,
387 });
388 }
389 }
390 pending = PendingAttr::default();
391 }
392 }
393
394 if entries.is_empty() {
395 entries.push(BindGroupLayoutEntry {
396 binding: 0,
397 visibility: ShaderStages::COMPUTE,
398 ty: BindingType::Buffer {
399 ty: BufferBindingType::Storage { read_only: false },
400 has_dynamic_offset: false,
401 min_binding_size: None,
402 },
403 count: None,
404 });
405 infos.push(BindingInfo {
406 binding: 0,
407 name: "_unnamed".into(),
408 kind: BindingKind::StorageRw,
409 });
410 }
411
412 let mut seen = std::collections::HashSet::new();
414 let mut dedup_entries = Vec::new();
415 let mut dedup_infos = Vec::new();
416 for (e, info) in entries.into_iter().zip(infos) {
417 if seen.insert(e.binding) {
418 dedup_entries.push(e);
419 dedup_infos.push(info);
420 }
421 }
422
423 let bind_group_layout = self
424 .device
425 .create_bind_group_layout(&BindGroupLayoutDescriptor {
426 label: Some(&format!("{}_bind_group_layout", name)),
427 entries: &dedup_entries,
428 });
429 Ok((bind_group_layout, dedup_infos))
430 }
431
432 #[cfg(feature = "wgpu_backend")]
434 pub fn device(&self) -> &Device {
435 &self.device
436 }
437
438 #[cfg(feature = "wgpu_backend")]
440 pub fn queue(&self) -> &Queue {
441 &self.queue
442 }
443
444 #[cfg(feature = "wgpu_backend")]
446 pub fn allocate_device_memory(&self, size: usize) -> Result<Buffer, GpuError> {
447 let buffer = self.device.create_buffer(&BufferDescriptor {
448 label: Some("SciRS2 Buffer"),
449 size: size as u64,
450 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
451 mapped_at_creation: false,
452 });
453
454 Ok(buffer)
455 }
456
457 #[cfg(not(feature = "wgpu_backend"))]
459 pub fn allocate_device_memory_2(&self, size: usize) -> Result<WgpuBuffer, GpuError> {
460 Ok((0x1000 + size) as WgpuBuffer)
462 }
463
464 #[cfg(not(feature = "wgpu_backend"))]
466 fn initialize_webgpu() -> Result<WgpuDevice, GpuError> {
467 Ok(0x1 as WgpuDevice)
469 }
470
471 #[cfg(not(feature = "wgpu_backend"))]
472 fn create_queue(device: WgpuDevice) -> Result<WgpuQueue, GpuError> {
473 Ok(0x2 as WgpuQueue)
475 }
476
477 #[cfg(not(feature = "wgpu_backend"))]
478 fn compile_wgsl_source(source: &str, name: &str) -> Result<WgpuComputePipeline, GpuError> {
479 Ok(0x3 as WgpuComputePipeline)
481 }
482
483 #[cfg(feature = "wgpu_backend")]
489 pub fn compile_to_pipeline(&self, source: &str) -> Result<WgpuComputePipeline, GpuError> {
490 let shader = self.compile_shader_internal(source, "scirs2-pipeline")?;
491 Ok(WgpuComputePipeline {
492 pipeline: shader.pipeline,
493 bind_group_layout: shader.bind_group_layout,
494 workgroup_size: shader.workgroup_size,
495 })
496 }
497
498 #[cfg(feature = "wgpu_backend")]
500 pub fn run_vector_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, GpuError> {
501 use wgpu::{util::DeviceExt as _, BufferUsages};
502
503 let n = a.len();
504 if n != b.len() {
505 return Err(GpuError::InvalidParameter(
506 "vectors must have equal length".into(),
507 ));
508 }
509
510 const VECTOR_ADD_WGSL: &str = r#"
511@group(0) @binding(0) var<storage, read> a : array<f32>;
512@group(0) @binding(1) var<storage, read> b : array<f32>;
513@group(0) @binding(2) var<storage, read_write> result : array<f32>;
514
515@compute @workgroup_size(64)
516fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
517 let idx = global_id.x;
518 if idx < arrayLength(&result) {
519 result[idx] = a[idx] + b[idx];
520 }
521}
522"#;
523
524 let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
526 label: Some("vector-add"),
527 source: ShaderSource::Wgsl(VECTOR_ADD_WGSL.into()),
528 });
529
530 let bgl = self
532 .device
533 .create_bind_group_layout(&BindGroupLayoutDescriptor {
534 label: Some("vector-add-bgl"),
535 entries: &[
536 BindGroupLayoutEntry {
538 binding: 0,
539 visibility: ShaderStages::COMPUTE,
540 ty: BindingType::Buffer {
541 ty: BufferBindingType::Storage { read_only: true },
542 has_dynamic_offset: false,
543 min_binding_size: None,
544 },
545 count: None,
546 },
547 BindGroupLayoutEntry {
549 binding: 1,
550 visibility: ShaderStages::COMPUTE,
551 ty: BindingType::Buffer {
552 ty: BufferBindingType::Storage { read_only: true },
553 has_dynamic_offset: false,
554 min_binding_size: None,
555 },
556 count: None,
557 },
558 BindGroupLayoutEntry {
560 binding: 2,
561 visibility: ShaderStages::COMPUTE,
562 ty: BindingType::Buffer {
563 ty: BufferBindingType::Storage { read_only: false },
564 has_dynamic_offset: false,
565 min_binding_size: None,
566 },
567 count: None,
568 },
569 ],
570 });
571
572 let pipeline_layout = self
573 .device
574 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
575 label: Some("vector-add-layout"),
576 bind_group_layouts: &[Some(&bgl)],
577 ..Default::default()
578 });
579
580 let pipeline = self
581 .device
582 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
583 label: Some("vector-add-pipeline"),
584 layout: Some(&pipeline_layout),
585 module: &shader_module,
586 entry_point: Some("main"),
587 compilation_options: Default::default(),
588 cache: None,
589 });
590
591 let a_bytes: Vec<u8> = a.iter().flat_map(|f| f.to_le_bytes()).collect();
593 let b_bytes: Vec<u8> = b.iter().flat_map(|f| f.to_le_bytes()).collect();
594 let result_size = std::mem::size_of_val(a) as u64;
595
596 let buf_a = self
597 .device
598 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
599 label: Some("vector-add-a"),
600 contents: &a_bytes,
601 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
602 });
603 let buf_b = self
604 .device
605 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
606 label: Some("vector-add-b"),
607 contents: &b_bytes,
608 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
609 });
610 let buf_result = self.device.create_buffer(&BufferDescriptor {
611 label: Some("vector-add-result"),
612 size: result_size,
613 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
614 mapped_at_creation: false,
615 });
616
617 let bind_group = self.device.create_bind_group(&BindGroupDescriptor {
619 label: Some("vector-add-bg"),
620 layout: &bgl,
621 entries: &[
622 BindGroupEntry {
623 binding: 0,
624 resource: buf_a.as_entire_binding(),
625 },
626 BindGroupEntry {
627 binding: 1,
628 resource: buf_b.as_entire_binding(),
629 },
630 BindGroupEntry {
631 binding: 2,
632 resource: buf_result.as_entire_binding(),
633 },
634 ],
635 });
636
637 let mut encoder = self
639 .device
640 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
641 label: Some("vector-add-encoder"),
642 });
643 {
644 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
645 label: Some("vector-add-pass"),
646 timestamp_writes: None,
647 });
648 cpass.set_pipeline(&pipeline);
649 cpass.set_bind_group(0, &bind_group, &[]);
650 let workgroups = (n as u32 + 63) / 64;
651 cpass.dispatch_workgroups(workgroups, 1, 1);
652 }
653
654 let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
656 label: Some("vector-add-staging"),
657 size: result_size,
658 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
659 mapped_at_creation: false,
660 });
661 encoder.copy_buffer_to_buffer(&buf_result, 0, &staging, 0, result_size);
662 self.queue.submit(Some(encoder.finish()));
663
664 self.device
666 .poll(wgpu::PollType::wait_indefinitely())
667 .map_err(|e| GpuError::Other(format!("GPU poll error: {e:?}")))?;
668
669 let slice = staging.slice(0..result_size);
670 let (tx, rx) = std::sync::mpsc::channel();
671 slice.map_async(wgpu::MapMode::Read, move |r| {
672 let _ = tx.send(r);
673 });
674
675 self.device
677 .poll(wgpu::PollType::wait_indefinitely())
678 .map_err(|e| GpuError::Other(format!("GPU poll error during map: {e:?}")))?;
679
680 rx.recv()
681 .map_err(|_| GpuError::Other("Channel closed during map_async".into()))?
682 .map_err(|e| GpuError::Other(format!("map_async failed: {e:?}")))?;
683
684 let mapped = slice.get_mapped_range();
685 let result: Vec<f32> = mapped
686 .chunks_exact(4)
687 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
688 .collect();
689 drop(mapped);
690 staging.unmap();
691
692 Ok(result)
693 }
694
695 fn extract_entry_point(source: &str) -> Option<&str> {
697 let lines: Vec<&str> = source.lines().collect();
698
699 for (i, line) in lines.iter().enumerate() {
700 let trimmed = line.trim();
701
702 if trimmed.contains("@compute") {
704 let mut search_line = trimmed;
706 let mut search_idx = 0;
707
708 if !search_line.contains("fn ") && search_idx + 1 < lines.len() {
710 search_idx += 1;
711 search_line = lines[search_idx].trim();
712 }
713
714 if let Some(start) = search_line.find("fn ") {
716 let remaining = &search_line[start + 3..];
717 if let Some(end) = remaining.find('(') {
718 let funcname = remaining[..end].trim();
719 return Some(funcname);
720 }
721 }
722 }
723 }
724
725 None
726 }
727}
728
729impl GpuContextImpl for WebGPUContext {
730 fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
731 if let Ok(mut pool) = self.memory_pool.lock() {
733 if let Some(device_buffer) = pool.allocate(size) {
734 return Arc::new(WebGPUBuffer {
735 device_buffer: Some(device_buffer),
736 #[cfg(feature = "wgpu_backend")]
737 queue: Arc::clone(&self.queue),
738 #[cfg(feature = "wgpu_backend")]
739 device: Arc::clone(&self.device),
740 #[cfg(not(feature = "wgpu_backend"))]
741 queue: self.queue,
742 size,
743 memory_pool: Arc::clone(&self.memory_pool),
744 });
745 }
746 }
747
748 let device_buffer = match self.allocate_device_memory(size) {
750 Ok(buffer) => buffer,
751 Err(e) => {
752 eprintln!(
754 "Warning: WebGPU buffer allocation failed ({}), creating CPU fallback buffer",
755 e
756 );
757
758 #[cfg(feature = "wgpu_backend")]
759 {
760 return Arc::new(WebGPUCpuFallbackBuffer {
763 data: vec![0u8; size],
764 size,
765 memory_pool: Arc::clone(&self.memory_pool),
766 });
767 }
768 #[cfg(not(feature = "wgpu_backend"))]
769 {
770 (0x2000 + size) as WgpuBuffer
771 }
772 }
773 };
774
775 Arc::new(WebGPUBuffer {
776 device_buffer: Some(device_buffer),
777 #[cfg(feature = "wgpu_backend")]
778 queue: Arc::clone(&self.queue),
779 #[cfg(feature = "wgpu_backend")]
780 device: Arc::clone(&self.device),
781 #[cfg(not(feature = "wgpu_backend"))]
782 queue: self.queue,
783 size,
784 memory_pool: Arc::clone(&self.memory_pool),
785 })
786 }
787
788 fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
789 Arc::new(WebGPUCompiler {
790 context: Arc::new(WebGPUContext {
791 memory_pool: Arc::clone(&self.memory_pool),
792 compiled_shaders: Arc::clone(&self.compiled_shaders),
793 #[cfg(feature = "wgpu_backend")]
794 device: Arc::clone(&self.device),
795 #[cfg(feature = "wgpu_backend")]
796 queue: Arc::clone(&self.queue),
797 #[cfg(not(feature = "wgpu_backend"))]
798 device: Arc::clone(&self.device),
799 #[cfg(not(feature = "wgpu_backend"))]
800 queue: Arc::clone(&self.queue),
801 }),
802 })
803 }
804
805 fn as_any(&self) -> &dyn std::any::Any {
806 self
807 }
808}
809
/// A compiled compute shader plus the metadata extracted from its source.
struct WebGPUShader {
    /// Compiled wgpu compute pipeline.
    #[cfg(feature = "wgpu_backend")]
    pipeline: ComputePipeline,
    /// Placeholder pipeline handle for builds without `wgpu_backend`.
    #[cfg(not(feature = "wgpu_backend"))]
    pipeline: WgpuComputePipeline,
    /// Layout of bind group 0 derived from the shader source.
    #[cfg(feature = "wgpu_backend")]
    #[allow(dead_code)]
    bind_group_layout: BindGroupLayout,
    /// Placeholder layout pointer for builds without `wgpu_backend`.
    #[cfg(not(feature = "wgpu_backend"))]
    #[allow(dead_code)]
    bind_group_layout: *mut std::ffi::c_void,
    /// Name the shader was compiled under.
    #[allow(dead_code)]
    name: String,
    /// One entry per binding in `bind_group_layout`.
    #[allow(dead_code)]
    binding_infos: Vec<BindingInfo>,
    /// Workgroup dimensions `[x, y, z]` read from the shader source.
    #[allow(dead_code)]
    workgroup_size: [u32; 3],
}

// NOTE(review): these impls assert thread-safety for the wrapped handles,
// including the raw pointers of the stub build; confirm that is intended.
unsafe impl Send for WebGPUShader {}
unsafe impl Sync for WebGPUShader {}
833
834struct WebGPUCompiler {
836 context: Arc<WebGPUContext>,
837}
838
839impl GpuCompilerImpl for WebGPUCompiler {
840 fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
841 let shader = self.context.compile_shader_internal(source, "shader")?;
842 Ok(Arc::new(WebGPUKernelHandle {
843 shader_name: shader.name.clone(),
844 compiled_shaders: Arc::clone(&self.context.compiled_shaders),
845 params: Arc::new(Mutex::new(HashMap::new())),
846 #[cfg(feature = "wgpu_backend")]
847 device: Arc::clone(&self.context.device),
848 #[cfg(feature = "wgpu_backend")]
849 queue: Arc::clone(&self.context.queue),
850 #[cfg(feature = "wgpu_backend")]
851 ephemeral_uniforms: Mutex::new(Vec::new()),
852 #[cfg(not(feature = "wgpu_backend"))]
853 device: self.context.device,
854 #[cfg(not(feature = "wgpu_backend"))]
855 queue: self.context.queue,
856 }))
857 }
858
859 fn compile_typed(
860 &self,
861 name: &str,
862 _input_type: std::any::TypeId,
863 _output_type: std::any::TypeId,
864 ) -> Arc<dyn GpuKernelImpl> {
865 Arc::new(WebGPUKernelHandle {
866 shader_name: name.to_string(),
867 compiled_shaders: Arc::clone(&self.context.compiled_shaders),
868 params: Arc::new(Mutex::new(HashMap::new())),
869 #[cfg(feature = "wgpu_backend")]
870 device: Arc::clone(&self.context.device),
871 #[cfg(feature = "wgpu_backend")]
872 queue: Arc::clone(&self.context.queue),
873 #[cfg(feature = "wgpu_backend")]
874 ephemeral_uniforms: Mutex::new(Vec::new()),
875 #[cfg(not(feature = "wgpu_backend"))]
876 device: self.context.device,
877 #[cfg(not(feature = "wgpu_backend"))]
878 queue: self.context.queue,
879 })
880 }
881}
882
/// Handle to a named shader together with the parameters set for its next
/// dispatch.
struct WebGPUKernelHandle {
    /// Key of the shader in `compiled_shaders`.
    shader_name: String,
    /// Shader table shared with the context that created this handle.
    compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
    /// Parameters set through the `set_*` methods, keyed by name.
    params: Arc<Mutex<HashMap<String, KernelParam>>>,
    /// wgpu device used to create encoders, bind groups and uniform buffers.
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    /// wgpu queue that dispatches are submitted to.
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    /// Uniform buffers created for the most recent bind group.
    #[cfg(feature = "wgpu_backend")]
    ephemeral_uniforms: Mutex<Vec<wgpu::Buffer>>,
    /// Placeholder device handle for builds without `wgpu_backend`.
    #[cfg(not(feature = "wgpu_backend"))]
    device: WgpuDevice,
    /// Placeholder queue handle for builds without `wgpu_backend`.
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
}
899
/// A value bound to a kernel parameter name.
enum KernelParam {
    /// A GPU buffer, bound to a storage binding of the same name.
    #[allow(dead_code)]
    Buffer(Arc<dyn GpuBufferImpl>),
    /// Unsigned 32-bit scalar.
    #[allow(dead_code)]
    U32(u32),
    /// Signed 32-bit scalar.
    #[allow(dead_code)]
    I32(i32),
    /// 32-bit float scalar.
    #[allow(dead_code)]
    F32(f32),
    /// 64-bit float scalar.
    #[allow(dead_code)]
    F64(f64),
    /// Raw bytes appended verbatim to the uniform data.
    Bytes(Vec<u8>),
}
913
/// Kind of buffer binding discovered in a shader's source.
#[derive(Clone, Debug)]
enum BindingKind {
    /// Storage buffer bound as writable.
    StorageRw,
    /// Storage buffer bound as read-only.
    StorageRead,
    /// Uniform buffer.
    Uniform,
}
920
/// Description of one binding in bind group 0 of a compiled shader.
#[derive(Clone, Debug)]
struct BindingInfo {
    /// Binding index used in the bind-group layout.
    binding: u32,
    /// Variable name from the shader; kernel parameters are matched against it.
    name: String,
    /// Whether the binding is a storage (read/write) or uniform buffer.
    kind: BindingKind,
}
927
/// Reads the first `@workgroup_size(...)` attribute in WGSL `source`.
///
/// One, two or three integer literals are accepted (extra components are
/// ignored); missing `y`/`z` components default to 1. Literals may carry the
/// WGSL `u` or `i` suffix, and text after `//` on a line is ignored.
///
/// Returns `[64, 1, 1]` when no attribute is found, when it has no
/// components, or when any component is not an integer literal (for example
/// an override constant) — in that case the real size cannot be known from
/// the text, so no partial guess is made.
fn extract_workgroup_size(source: &str) -> [u32; 3] {
    const DEFAULT: [u32; 3] = [64, 1, 1];
    const ATTRIBUTE: &str = "@workgroup_size(";

    /// Parses one integer literal, tolerating a trailing `u` or `i` suffix.
    fn parse_dimension(text: &str) -> Option<u32> {
        let digits = text
            .strip_suffix('u')
            .or_else(|| text.strip_suffix('i'))
            .unwrap_or(text);
        digits.parse().ok()
    }

    for line in source.lines() {
        let code = line.split_once("//").map_or(line, |(code, _)| code);
        let args = match code.find(ATTRIBUTE) {
            Some(start) => &code[start + ATTRIBUTE.len()..],
            None => continue,
        };
        let inner = match args.find(')') {
            Some(end) => &args[..end],
            None => continue,
        };

        // Collecting into `Option` keeps positions intact: a single bad
        // component invalidates the whole attribute instead of shifting
        // the components after it.
        let dimensions: Option<Vec<u32>> = inner
            .split(',')
            .map(str::trim)
            .filter(|part| !part.is_empty())
            .map(parse_dimension)
            .collect();

        return match dimensions.as_deref() {
            Some([x]) => [*x, 1, 1],
            Some([x, y]) => [*x, *y, 1],
            Some([x, y, z, ..]) => [*x, *y, *z],
            _ => DEFAULT,
        };
    }
    DEFAULT
}
952
/// Extracts the variable name from a `var<...> name: type` declaration.
///
/// The name is the trimmed text between the first `>` after `var<` and the
/// next `:`. Returns `None` when the line has no `var<`, no closing `>`, no
/// `:` after it, or when the text in between is blank.
fn extract_var_name(line: &str) -> Option<&str> {
    let (_, template_and_rest) = line.split_once("var<")?;
    let (_, declaration) = template_and_rest.split_once('>')?;
    let (candidate, _) = declaration.split_once(':')?;

    match candidate.trim() {
        "" => None,
        identifier => Some(identifier),
    }
}
969
970impl GpuKernelImpl for WebGPUKernelHandle {
971 fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
972 if let Ok(mut params) = self.params.lock() {
973 params.insert(name.to_string(), KernelParam::Buffer(Arc::clone(buffer)));
974 }
975 }
976
977 fn set_u32(&self, name: &str, value: u32) {
978 if let Ok(mut params) = self.params.lock() {
979 params.insert(name.to_string(), KernelParam::U32(value));
980 }
981 }
982
983 fn set_i32(&self, name: &str, value: i32) {
984 if let Ok(mut params) = self.params.lock() {
985 params.insert(name.to_string(), KernelParam::I32(value));
986 }
987 }
988
989 fn set_f32(&self, name: &str, value: f32) {
990 if let Ok(mut params) = self.params.lock() {
991 params.insert(name.to_string(), KernelParam::F32(value));
992 }
993 }
994
995 fn set_f64(&self, name: &str, value: f64) {
996 if let Ok(mut params) = self.params.lock() {
997 params.insert(name.to_string(), KernelParam::F64(value));
998 }
999 }
1000
1001 #[allow(dead_code)]
1002 fn dispatch(&self, workgroups: [u32; 3]) {
1005 #[cfg(feature = "wgpu_backend")]
1006 {
1007 let shaders = match self.compiled_shaders.lock() {
1009 Ok(g) => g,
1010 Err(_) => return,
1011 };
1012 if let Some(shader) = shaders.get(&self.shader_name) {
1013 let params = match self.params.lock() {
1014 Ok(g) => g,
1015 Err(_) => return,
1016 };
1017
1018 let mut encoder =
1020 self.device
1021 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1022 label: Some("Compute Command Encoder"),
1023 });
1024
1025 {
1027 let mut compute_pass =
1028 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1029 label: Some("Compute Pass"),
1030 timestamp_writes: None,
1031 });
1032
1033 compute_pass.set_pipeline(&shader.pipeline);
1035
1036 if let Ok(bind_group) = self.create_bind_group_from_params(shader, ¶ms) {
1037 compute_pass.set_bind_group(0, &bind_group, &[]);
1038 }
1039 compute_pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
1043 }
1044
1045 let command_buffer = encoder.finish();
1047 self.queue.submit(std::iter::once(command_buffer));
1048 }
1049 }
1050 #[cfg(not(feature = "wgpu_backend"))]
1051 {
1052 let _ = workgroups;
1054 let _ = &self.shader_name;
1055 }
1056 }
1057}
1058
/// A device buffer that is handed back to the memory pool when dropped.
struct WebGPUBuffer {
    /// The wgpu buffer; taken out (left as `None`) when the value is dropped.
    #[cfg(feature = "wgpu_backend")]
    device_buffer: Option<Buffer>,
    /// Queue used for host-to-device writes and readback submissions.
    #[cfg(feature = "wgpu_backend")]
    queue: Arc<Queue>,
    /// Device used to create staging buffers and encoders for readback.
    #[cfg(feature = "wgpu_backend")]
    device: Arc<Device>,
    /// Placeholder buffer handle for builds without `wgpu_backend`.
    #[cfg(not(feature = "wgpu_backend"))]
    device_buffer: Option<WgpuBuffer>,
    /// Placeholder queue handle for builds without `wgpu_backend`.
    #[cfg(not(feature = "wgpu_backend"))]
    queue: WgpuQueue,
    /// Size in bytes that was requested for this buffer.
    size: usize,
    /// Pool that receives the device buffer on drop.
    memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}

// NOTE(review): these impls assert thread-safety for the wrapped handles,
// including the raw pointers of the stub build; confirm that is intended.
unsafe impl Send for WebGPUBuffer {}
unsafe impl Sync for WebGPUBuffer {}
1080
1081impl GpuBufferImpl for WebGPUBuffer {
1082 fn size(&self) -> usize {
1083 self.size
1084 }
1085
1086 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1087 #[cfg(feature = "wgpu_backend")]
1088 {
1089 if size > self.size {
1091 eprintln!(
1093 "Warning: Data size {} exceeds buffer size {}",
1094 size, self.size
1095 );
1096 return;
1097 }
1098
1099 let data_slice = std::slice::from_raw_parts(data, size);
1101
1102 if let Some(ref buffer) = self.device_buffer {
1104 self.queue.write_buffer(buffer, 0, data_slice);
1105 }
1106 }
1107 #[cfg(not(feature = "wgpu_backend"))]
1108 {
1109 if size > self.size {
1111 eprintln!(
1112 "Warning: Data size {} exceeds buffer size {}",
1113 size, self.size
1114 );
1115 }
1116 }
1118 }
1119
1120 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1121 #[cfg(feature = "wgpu_backend")]
1122 {
1123 if size > self.size {
1125 eprintln!(
1126 "Warning: Data size {} exceeds buffer size {}",
1127 size, self.size
1128 );
1129 return;
1130 }
1131
1132 if let Some(ref buffer) = self.device_buffer {
1133 let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
1134 label: Some("scirs2-readback"),
1135 size: size as u64,
1136 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1137 mapped_at_creation: false,
1138 });
1139 let mut encoder =
1140 self.device
1141 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1142 label: Some("scirs2-readback-enc"),
1143 });
1144 encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
1145 self.queue.submit(Some(encoder.finish()));
1146
1147 let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
1151
1152 let slice = staging.slice(0..size as u64);
1153 let (tx, rx) = std::sync::mpsc::channel();
1154 slice.map_async(wgpu::MapMode::Read, move |r| {
1155 let _ = tx.send(r);
1156 });
1157 let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
1159 if let Ok(Ok(())) = rx.recv() {
1160 let mapped = slice.get_mapped_range();
1161 let dst = std::slice::from_raw_parts_mut(data, size);
1162 dst.copy_from_slice(&mapped);
1163 drop(mapped);
1164 staging.unmap();
1165 } else {
1166 eprintln!("Warning: map_async failed for readback");
1167 }
1168 }
1169 }
1170 #[cfg(not(feature = "wgpu_backend"))]
1171 {
1172 if size > self.size {
1174 eprintln!(
1175 "Warning: Data size {} exceeds buffer size {}",
1176 size, self.size
1177 );
1178 }
1179
1180 let data_slice = std::slice::from_raw_parts_mut(data, size);
1182 data_slice.fill(0);
1183 }
1184 }
1185
1186 fn device_ptr(&self) -> u64 {
1187 #[cfg(feature = "wgpu_backend")]
1188 {
1189 &self.device_buffer as *const _ as u64
1192 }
1193 #[cfg(not(feature = "wgpu_backend"))]
1194 {
1195 self.device_buffer as u64
1196 }
1197 }
1198
1199 fn as_any(&self) -> &dyn std::any::Any {
1200 self
1201 }
1202}
1203
#[cfg(feature = "wgpu_backend")]
impl WebGPUKernelHandle {
    /// Builds the bind group for `shader` from the parameters in `params`.
    ///
    /// * Each storage binding is resolved by looking up a
    ///   [`KernelParam::Buffer`] under the binding's variable name.
    /// * Scalar and byte parameters whose key equals a uniform binding's
    ///   name, or starts with `"<name>."`, are serialised little-endian into
    ///   one uniform buffer, padded to a multiple of 16 bytes and bound to
    ///   the shader's first uniform binding.
    ///
    /// Uniform fields are packed in ascending key order so the byte layout
    /// is the same on every call.
    /// NOTE(review): key order is not necessarily the WGSL struct's field
    /// order; callers must choose keys accordingly or pass pre-packed bytes.
    ///
    /// The uniform buffer of the most recent call is also kept in
    /// `ephemeral_uniforms`, replacing whatever was stored there.
    ///
    /// # Errors
    /// Returns [`GpuError::InvalidParameter`] when a storage binding has no
    /// buffer parameter, when that parameter is not a `WebGPUBuffer` holding
    /// a device buffer, or when the shader has a uniform binding but no
    /// uniform data was supplied.
    fn create_bind_group_from_params(
        &self,
        shader: &WebGPUShader,
        params: &HashMap<String, KernelParam>,
    ) -> Result<wgpu::BindGroup, GpuError> {
        // Declared before `entries` so the buffer outlives the references
        // that `entries` holds to it.
        let uniform_buffer: Option<(u32, wgpu::Buffer)>;
        let mut entries: Vec<wgpu::BindGroupEntry> =
            Vec::with_capacity(shader.binding_infos.len());
        let mut uniform_bytes: Vec<u8> = Vec::new();

        for info in &shader.binding_infos {
            match info.kind {
                BindingKind::StorageRw | BindingKind::StorageRead => {
                    let buffer = match params.get(&info.name) {
                        Some(KernelParam::Buffer(buffer)) => buffer,
                        _ => {
                            return Err(GpuError::InvalidParameter(format!(
                                "Missing buffer param '{}'",
                                info.name
                            )))
                        }
                    };
                    // Only buffers created by this backend can be bound; a
                    // missing entry would fail wgpu validation later on.
                    let device_buffer = buffer
                        .as_any()
                        .downcast_ref::<WebGPUBuffer>()
                        .and_then(|wrapper| wrapper.device_buffer.as_ref())
                        .ok_or_else(|| {
                            GpuError::InvalidParameter(format!(
                                "Buffer param '{}' is not backed by a WebGPU device buffer",
                                info.name
                            ))
                        })?;
                    entries.push(wgpu::BindGroupEntry {
                        binding: info.binding,
                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                            buffer: device_buffer,
                            offset: 0,
                            size: None,
                        }),
                    });
                }
                BindingKind::Uniform => {
                    let field_prefix = format!("{}.", info.name);
                    let mut fields: Vec<(&String, &KernelParam)> = params
                        .iter()
                        .filter(|(key, _)| **key == info.name || key.starts_with(&field_prefix))
                        .collect();
                    // HashMap iteration order is arbitrary; sort for a
                    // deterministic byte layout.
                    fields.sort_unstable_by(|left, right| left.0.cmp(right.0));

                    for (_, value) in fields {
                        match value {
                            KernelParam::U32(u) => uniform_bytes.extend_from_slice(&u.to_le_bytes()),
                            KernelParam::I32(i) => uniform_bytes.extend_from_slice(&i.to_le_bytes()),
                            KernelParam::F32(f) => uniform_bytes.extend_from_slice(&f.to_le_bytes()),
                            KernelParam::F64(f) => uniform_bytes.extend_from_slice(&f.to_le_bytes()),
                            KernelParam::Bytes(b) => uniform_bytes.extend_from_slice(b),
                            KernelParam::Buffer(_) => {}
                        }
                    }
                }
            }
        }

        let uniform_binding = shader
            .binding_infos
            .iter()
            .find(|info| matches!(info.kind, BindingKind::Uniform))
            .map(|info| info.binding);

        uniform_buffer = match uniform_binding {
            Some(binding) if !uniform_bytes.is_empty() => {
                // Pad to a 16-byte multiple, as the original packing does.
                let padded_len = (uniform_bytes.len() + 15) / 16 * 16;
                uniform_bytes.resize(padded_len, 0);
                let buffer = self
                    .device
                    .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some("scirs2-uniforms"),
                        contents: &uniform_bytes,
                        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                    });
                Some((binding, buffer))
            }
            Some(binding) => {
                return Err(GpuError::InvalidParameter(format!(
                    "Missing uniform params for binding {}",
                    binding
                )))
            }
            None => None,
        };

        // Replace the previously kept uniform buffer; a poisoned lock only
        // skips this bookkeeping.
        if let Ok(mut kept) = self.ephemeral_uniforms.lock() {
            kept.clear();
            if let Some((_, buffer)) = &uniform_buffer {
                kept.push(buffer.clone());
            }
        }

        if let Some((binding, buffer)) = &uniform_buffer {
            entries.push(wgpu::BindGroupEntry {
                binding: *binding,
                resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                    buffer,
                    offset: 0,
                    size: None,
                }),
            });
        }

        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("scirs2-bind-group"),
            layout: &shader.bind_group_layout,
            entries: &entries,
        });
        Ok(bind_group)
    }
}
1306
1307impl Drop for WebGPUBuffer {
1308 fn drop(&mut self) {
1309 if let Ok(mut pool) = self.memory_pool.lock() {
1311 #[cfg(feature = "wgpu_backend")]
1312 {
1313 if let Some(buffer) = self.device_buffer.take() {
1315 pool.deallocate(buffer);
1316 }
1317 }
1318 #[cfg(not(feature = "wgpu_backend"))]
1319 {
1320 if let Some(buffer) = self.device_buffer.take() {
1321 pool.deallocate(buffer);
1322 }
1323 }
1324 }
1325 }
1326}
1327
1328struct WebGPUCpuFallbackBuffer {
1331 data: Vec<u8>,
1332 size: usize,
1333 #[allow(dead_code)]
1334 memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
1335}
1336
1337impl GpuBufferImpl for WebGPUCpuFallbackBuffer {
1338 fn size(&self) -> usize {
1339 self.size
1340 }
1341
1342 unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
1343 if size > self.size {
1344 eprintln!("Warning: WebGPU CPU fallback buffer copy_from_host size mismatch");
1345 return;
1346 }
1347
1348 let data_slice = std::slice::from_raw_parts(data, size);
1350 eprintln!(
1353 "Warning: CPU fallback buffer copy_from_host called (size: {})",
1354 size
1355 );
1356 }
1357
1358 unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
1359 if size > self.size {
1360 eprintln!("Warning: WebGPU CPU fallback buffer copy_to_host size mismatch");
1361 return;
1362 }
1363
1364 let data_slice = std::slice::from_raw_parts_mut(data, size);
1366 let copy_size = size.min(self.data.len());
1367 data_slice[..copy_size].copy_from_slice(&self.data[..copy_size]);
1368
1369 eprintln!(
1370 "Warning: CPU fallback buffer copy_to_host called (size: {})",
1371 size
1372 );
1373 }
1374
1375 fn device_ptr(&self) -> u64 {
1376 self.data.as_ptr() as u64
1377 }
1378
1379 fn as_any(&self) -> &dyn std::any::Any {
1380 self
1381 }
1382}
1383
1384unsafe impl Send for WebGPUCpuFallbackBuffer {}
1386unsafe impl Sync for WebGPUCpuFallbackBuffer {}
1387
1388struct WebGPUMemoryPool {
1390 #[cfg(feature = "wgpu_backend")]
1391 available_buffers: HashMap<usize, Vec<Buffer>>,
1392 #[cfg(not(feature = "wgpu_backend"))]
1393 available_buffers: HashMap<usize, Vec<WgpuBuffer>>,
1394 #[allow(dead_code)]
1395 total_size: usize,
1396 used_size: usize,
1397}
1398
1399impl WebGPUMemoryPool {
1400 fn new(totalsize: usize) -> Self {
1401 Self {
1402 available_buffers: HashMap::new(),
1403 total_size: totalsize,
1404 used_size: 0,
1405 }
1406 }
1407
1408 #[cfg(feature = "wgpu_backend")]
1409 fn allocate(&mut self, size: usize) -> Option<Buffer> {
1410 if let Some(buffers) = self.available_buffers.get_mut(&size) {
1412 if let Some(buffer) = buffers.pop() {
1413 self.used_size += size;
1414 return Some(buffer);
1415 }
1416 }
1417 None
1418 }
1419
1420 #[cfg(not(feature = "wgpu_backend"))]
1421 fn allocate(&mut self, size: usize) -> Option<WgpuBuffer> {
1422 if let Some(buffers) = self.available_buffers.get_mut(&size) {
1424 if let Some(buffer) = buffers.pop() {
1425 self.used_size += size;
1426 return Some(buffer);
1427 }
1428 }
1429 None
1430 }
1431
1432 #[cfg(feature = "wgpu_backend")]
1433 fn deallocate(&mut self, buffer: Buffer) {
1434 let size = buffer.size() as usize;
1436 self.available_buffers
1437 .entry(size)
1438 .or_insert_with(Vec::new)
1439 .push(buffer);
1440 self.used_size = self.used_size.saturating_sub(size);
1441 }
1442
1443 #[cfg(not(feature = "wgpu_backend"))]
1444 fn deallocate(&mut self, buffer: WgpuBuffer) {
1445 let size = 1024; self.available_buffers
1448 .entry(size)
1449 .or_insert_with(Vec::new)
1450 .push(buffer);
1451 self.used_size = self.used_size.saturating_sub(size);
1452 }
1453
1454 #[allow(dead_code)]
1455 fn get_memory_usage(&self) -> (usize, usize) {
1456 (self.used_size, self.total_size)
1457 }
1458}