1#[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
7use std::borrow::Cow;
8
9#[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
10use std::sync::mpsc;
11
12#[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
13use wgpu::util::DeviceExt;
14
15use super::backend_report::{GpuActivationReport, GpuActivationState};
16
/// Compute limits advertised by the active backend, as surfaced to callers.
#[derive(Debug, Clone)]
pub struct ComputeCapabilities {
    /// Human-readable backend label (e.g. "Vulkan", "Metal", or "CPU-fallback").
    pub backend: String,
    /// Maximum workgroup size along X.
    pub max_workgroup_size_x: u32,
    /// Maximum workgroup size along Y.
    pub max_workgroup_size_y: u32,
    /// Maximum total invocations per workgroup.
    pub max_workgroup_invocations: u32,
    /// Maximum storage-buffer binding size in bytes.
    pub max_storage_buffer_size: u64,
    /// True only when a native GPU runtime was successfully initialized.
    pub gpu_available: bool,
}

impl Default for ComputeCapabilities {
    /// CPU-fallback capabilities: modest workgroup limits and an effectively
    /// unbounded storage size (host memory is the only constraint).
    fn default() -> Self {
        let workgroup_limit = 256;
        Self {
            backend: String::from("CPU-fallback"),
            max_workgroup_size_x: workgroup_limit,
            max_workgroup_size_y: workgroup_limit,
            max_workgroup_invocations: workgroup_limit,
            max_storage_buffer_size: u64::MAX,
            gpu_available: false,
        }
    }
}
40
/// How a buffer binding is exposed to the compute shader.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputeBindingKind {
    /// WGSL `var<uniform>`: small read-only parameter block.
    Uniform,
    /// WGSL `var<storage, read>`: read-only storage buffer.
    StorageReadOnly,
    /// WGSL `var<storage, read_write>`: writable storage buffer whose
    /// contents are read back after dispatch (see `GpuContext::run_compute`).
    StorageReadWrite,
}
48
/// CPU-side description of one shader buffer binding.
#[derive(Debug, Clone)]
pub struct ComputeBindingDescriptor {
    /// Debug label attached to the created wgpu buffer.
    pub label: String,
    /// Binding flavor: uniform, read-only storage, or read-write storage.
    pub kind: ComputeBindingKind,
    /// Initial buffer contents; the length also fixes the buffer size.
    pub bytes: Vec<u8>,
}
56
/// Everything needed to run one compute pass: shader, entry point,
/// grid size, and the ordered list of buffer bindings (binding index
/// equals position in `bindings`).
#[derive(Debug, Clone)]
pub struct ComputeDispatchDescriptor {
    /// Label used for the pipeline, pass, encoder, and bind group.
    pub label: String,
    /// WGSL source compiled at dispatch time.
    pub shader_source: String,
    /// Name of the `@compute` entry function in the shader.
    pub entry_point: String,
    /// Workgroup counts along x, y, z.
    pub workgroup_count: [u32; 3],
    /// Buffer bindings, mapped to `@group(0) @binding(i)` in order.
    pub bindings: Vec<ComputeBindingDescriptor>,
}
66
/// Result of a dispatch: the backend label plus the raw bytes of every
/// `StorageReadWrite` binding, in binding-declaration order.
#[derive(Debug, Clone)]
pub struct ComputeDispatchResult {
    /// Backend that executed the dispatch (copied from capabilities).
    pub backend: String,
    /// One byte vector per read-write binding, in declaration order.
    pub outputs: Vec<Vec<u8>>,
}
73
/// Entry point for compute work: holds the capability report shown to
/// callers, the reason GPU initialization failed (if it did), and — when
/// the feature is compiled in — the live wgpu runtime.
pub struct GpuContext {
    /// Capability summary for the active backend.
    pub capabilities: ComputeCapabilities,
    /// Why GPU initialization failed when `best_available` fell back to CPU.
    runtime_error: Option<String>,
    /// Native wgpu objects; `None` when running on the CPU fallback.
    #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
    runtime: Option<NativeGpuRuntime>,
}
81
/// Owned wgpu objects for the native GPU path.
///
/// `_instance` and `_adapter` are never read after construction —
/// presumably retained to keep the device's parent objects alive for the
/// context's lifetime; TODO(review): confirm against wgpu's ownership rules.
#[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
struct NativeGpuRuntime {
    _instance: wgpu::Instance,
    _adapter: wgpu::Adapter,
    device: wgpu::Device,
    queue: wgpu::Queue,
}
89
90impl std::fmt::Debug for GpuContext {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("GpuContext")
93 .field("capabilities", &self.capabilities)
94 .field("runtime_error", &self.runtime_error)
95 .finish()
96 }
97}
98
impl GpuContext {
    /// Builds a context that performs no GPU work: default (CPU) capabilities,
    /// no recorded error, and no native runtime.
    pub fn cpu_fallback() -> Self {
        Self {
            capabilities: ComputeCapabilities::default(),
            runtime_error: None,
            #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
            runtime: None,
        }
    }

    /// Tries to initialize a native wgpu runtime (instance, adapter, device,
    /// queue) and capture the adapter's limits as `ComputeCapabilities`.
    ///
    /// # Errors
    /// Returns a human-readable message when the `experimental-gpu` feature is
    /// compiled out, no adapter is found, or device creation fails.
    pub fn try_create() -> Result<Self, String> {
        #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
        {
            let instance = wgpu::Instance::default();
            // Headless compute: no surface, prefer the high-performance adapter.
            let adapter =
                pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::HighPerformance,
                    compatible_surface: None,
                    force_fallback_adapter: false,
                }))
                .ok_or_else(|| "No GPU adapter found".to_string())?;

            let adapter_info = adapter.get_info();
            let limits = adapter.limits();
            // Start from wgpu's default limits, clamped to what the adapter
            // actually supports.
            let required_limits = wgpu::Limits::default().using_resolution(limits.clone());

            let (device, queue) = pollster::block_on(adapter.request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("sci-form gpu"),
                    required_features: wgpu::Features::empty(),
                    required_limits,
                },
                None,
            ))
            .map_err(|err| format!("Failed to create wgpu device: {err}"))?;

            Ok(Self {
                capabilities: ComputeCapabilities {
                    backend: format!("{:?}", adapter_info.backend),
                    max_workgroup_size_x: limits.max_compute_workgroup_size_x,
                    max_workgroup_size_y: limits.max_compute_workgroup_size_y,
                    max_workgroup_invocations: limits.max_compute_invocations_per_workgroup,
                    max_storage_buffer_size: limits.max_storage_buffer_binding_size as u64,
                    gpu_available: true,
                },
                runtime_error: None,
                runtime: Some(NativeGpuRuntime {
                    _instance: instance,
                    _adapter: adapter,
                    device,
                    queue,
                }),
            })
        }

        #[cfg(not(all(feature = "experimental-gpu", not(target_arch = "wasm32"))))]
        {
            Err("experimental-gpu feature not enabled".to_string())
        }
    }

    /// Returns a GPU-backed context when possible, otherwise the CPU fallback
    /// with the initialization failure recorded in `runtime_error`.
    /// Never panics.
    pub fn best_available() -> Self {
        match Self::try_create() {
            Ok(ctx) => ctx,
            Err(reason) => {
                let mut ctx = Self::cpu_fallback();
                ctx.runtime_error = Some(reason);
                ctx
            }
        }
    }

    /// Summarizes GPU activation for diagnostics: `Ready` when a runtime is
    /// live, `NoAdapter` when the feature is compiled in but initialization
    /// failed (carrying the stored failure reason), `FeatureDisabled` when
    /// `experimental-gpu` is compiled out.
    pub fn activation_report(&self) -> GpuActivationReport {
        if self.capabilities.gpu_available {
            GpuActivationReport {
                backend: self.capabilities.backend.clone(),
                feature_enabled: true,
                gpu_available: true,
                runtime_ready: true,
                state: GpuActivationState::Ready,
                reason: "GPU runtime available".to_string(),
            }
        } else if cfg!(feature = "experimental-gpu") {
            GpuActivationReport {
                backend: self.capabilities.backend.clone(),
                feature_enabled: true,
                gpu_available: false,
                runtime_ready: false,
                state: GpuActivationState::NoAdapter,
                // Prefer the concrete failure captured by `best_available`.
                reason: self
                    .runtime_error
                    .clone()
                    .unwrap_or_else(|| "experimental-gpu enabled but no adapter found".to_string()),
            }
        } else {
            GpuActivationReport {
                backend: "CPU-fallback".to_string(),
                feature_enabled: false,
                gpu_available: false,
                runtime_ready: false,
                state: GpuActivationState::FeatureDisabled,
                reason: "experimental-gpu feature not enabled".to_string(),
            }
        }
    }

    /// True when a native GPU runtime was initialized.
    pub fn is_gpu_available(&self) -> bool {
        self.capabilities.gpu_available
    }

    /// Borrows the native runtime, or reproduces the stored initialization
    /// error (falling back to a generic message) when there is none.
    #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
    fn runtime(&self) -> Result<&NativeGpuRuntime, String> {
        self.runtime.as_ref().ok_or_else(|| {
            self.runtime_error
                .clone()
                .unwrap_or_else(|| "GPU runtime not initialized".to_string())
        })
    }

    /// Executes one compute dispatch synchronously and returns the contents
    /// of every `StorageReadWrite` binding, in binding-declaration order.
    ///
    /// # Errors
    /// Fails when the GPU runtime is unavailable or buffer readback fails.
    /// NOTE(review): invalid WGSL or mismatched bindings go through wgpu's
    /// validation machinery, which may panic or log rather than return `Err`
    /// here — confirm behavior against the wgpu version in use.
    pub fn run_compute(
        &self,
        descriptor: &ComputeDispatchDescriptor,
    ) -> Result<ComputeDispatchResult, String> {
        #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
        {
            let runtime = self.runtime()?;

            let shader = runtime
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&descriptor.label),
                    source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(&descriptor.shader_source)),
                });

            let mut layout_entries = Vec::with_capacity(descriptor.bindings.len());
            let mut buffers = Vec::with_capacity(descriptor.bindings.len());
            // (buffer index, size in bytes) for each read-write binding that
            // must be copied back to the host after the pass.
            let mut readbacks = Vec::new();

            for (index, binding) in descriptor.bindings.iter().enumerate() {
                // Usage flags follow the binding kind; only read-write buffers
                // need COPY_SRC so their contents can be copied to staging.
                let usage = match binding.kind {
                    ComputeBindingKind::Uniform => {
                        wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST
                    }
                    ComputeBindingKind::StorageReadOnly => {
                        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST
                    }
                    ComputeBindingKind::StorageReadWrite => {
                        wgpu::BufferUsages::STORAGE
                            | wgpu::BufferUsages::COPY_DST
                            | wgpu::BufferUsages::COPY_SRC
                    }
                };

                // Buffer is created pre-filled with the descriptor's bytes.
                let buffer = runtime
                    .device
                    .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some(&binding.label),
                        contents: &binding.bytes,
                        usage,
                    });

                layout_entries.push(wgpu::BindGroupLayoutEntry {
                    binding: index as u32,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: match binding.kind {
                            ComputeBindingKind::Uniform => wgpu::BufferBindingType::Uniform,
                            ComputeBindingKind::StorageReadOnly => {
                                wgpu::BufferBindingType::Storage { read_only: true }
                            }
                            ComputeBindingKind::StorageReadWrite => {
                                wgpu::BufferBindingType::Storage { read_only: false }
                            }
                        },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                });
                // `buffers.len()` is this binding's index: the buffer is
                // pushed immediately below.
                if matches!(binding.kind, ComputeBindingKind::StorageReadWrite) {
                    readbacks.push((buffers.len(), binding.bytes.len()));
                }
                buffers.push(buffer);
            }

            let bind_group_entries: Vec<_> = buffers
                .iter()
                .enumerate()
                .map(|(index, buffer)| wgpu::BindGroupEntry {
                    binding: index as u32,
                    resource: buffer.as_entire_binding(),
                })
                .collect();

            let bind_group_layout =
                runtime
                    .device
                    .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some(&format!("{} layout", descriptor.label)),
                        entries: &layout_entries,
                    });
            let pipeline_layout =
                runtime
                    .device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} pipeline", descriptor.label)),
                        bind_group_layouts: &[&bind_group_layout],
                        push_constant_ranges: &[],
                    });
            let pipeline =
                runtime
                    .device
                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some(&descriptor.label),
                        layout: Some(&pipeline_layout),
                        module: &shader,
                        entry_point: &descriptor.entry_point,
                    });
            let bind_group = runtime
                .device
                .create_bind_group(&wgpu::BindGroupDescriptor {
                    label: Some(&format!("{} bind group", descriptor.label)),
                    layout: &bind_group_layout,
                    entries: &bind_group_entries,
                });

            let mut encoder =
                runtime
                    .device
                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                        label: Some(&format!("{} encoder", descriptor.label)),
                    });
            // Scope the pass so the encoder can be used again for the
            // buffer-to-buffer copies below.
            {
                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: Some(&format!("{} pass", descriptor.label)),
                    timestamp_writes: None,
                });
                pass.set_pipeline(&pipeline);
                pass.set_bind_group(0, &bind_group, &[]);
                pass.dispatch_workgroups(
                    descriptor.workgroup_count[0],
                    descriptor.workgroup_count[1],
                    descriptor.workgroup_count[2],
                );
            }

            // One MAP_READ staging buffer per read-write binding; the copy is
            // recorded on the same encoder, after the compute pass.
            let mut staging_buffers = Vec::with_capacity(readbacks.len());
            for (buffer_index, size_bytes) in &readbacks {
                let staging = runtime.device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("readback staging"),
                    size: *size_bytes as u64,
                    usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
                    mapped_at_creation: false,
                });
                encoder.copy_buffer_to_buffer(
                    &buffers[*buffer_index],
                    0,
                    &staging,
                    0,
                    *size_bytes as u64,
                );
                staging_buffers.push(staging);
            }

            // Submit everything and block until the GPU is done.
            runtime.queue.submit(Some(encoder.finish()));
            runtime.device.poll(wgpu::Maintain::Wait);

            let mut outputs = Vec::with_capacity(staging_buffers.len());
            for staging in staging_buffers {
                let slice = staging.slice(..);
                // map_async completes during poll(Wait); the channel relays
                // the mapping result back to this thread.
                let (sender, receiver) = mpsc::channel();
                slice.map_async(wgpu::MapMode::Read, move |result| {
                    let _ = sender.send(result);
                });
                runtime.device.poll(wgpu::Maintain::Wait);
                receiver
                    .recv()
                    .map_err(|_| "GPU readback channel error".to_string())?
                    .map_err(|err| format!("GPU buffer map failed: {err}"))?;

                // Copy out while mapped; the mapped-range view is dropped at
                // the end of the statement, before unmap.
                let bytes = slice.get_mapped_range().to_vec();
                staging.unmap();
                outputs.push(bytes);
            }

            Ok(ComputeDispatchResult {
                backend: self.capabilities.backend.clone(),
                outputs,
            })
        }

        #[cfg(not(all(feature = "experimental-gpu", not(target_arch = "wasm32"))))]
        {
            let _ = descriptor;
            Err("experimental-gpu feature not enabled".to_string())
        }
    }

    /// Compiles `source` as a WGSL module on the device and reports success.
    ///
    /// NOTE(review): `create_shader_module` may surface WGSL errors through
    /// wgpu's error scopes / validation layer rather than synchronously, so
    /// `Ok` here does not fully guarantee a valid shader — confirm for the
    /// wgpu version in use.
    pub fn validate_shader(&self, label: &str, source: &str) -> Result<String, String> {
        #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
        {
            let runtime = self.runtime()?;
            let _module = runtime
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(label),
                    source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
                });
            Ok(format!(
                "Shader '{}' compiled on {}",
                label, self.capabilities.backend
            ))
        }

        #[cfg(not(all(feature = "experimental-gpu", not(target_arch = "wasm32"))))]
        {
            let _ = (label, source);
            Err("experimental-gpu feature not enabled".to_string())
        }
    }

    /// Adds two equal-length f32 vectors on the GPU using the embedded
    /// `VECTOR_ADD_SHADER` (one invocation per element, workgroups of 64).
    ///
    /// # Errors
    /// Fails when the lengths differ or the GPU path is unavailable.
    pub fn vector_add_f32(&self, lhs: &[f32], rhs: &[f32]) -> Result<Vec<f32>, String> {
        if lhs.len() != rhs.len() {
            return Err("Vectors must have the same length".to_string());
        }

        let params = VectorAddParams {
            len: lhs.len() as u32,
            _pad: [0; 3],
        };
        // Workgroup count matching the shader's @workgroup_size(64, 1, 1).
        let dispatch = (lhs.len() as u32).div_ceil(64);
        // Zero-filled seed for the read-write output buffer.
        let output_seed = vec![0.0f32; lhs.len()];

        let descriptor = ComputeDispatchDescriptor {
            label: "vector add".to_string(),
            shader_source: VECTOR_ADD_SHADER.to_string(),
            entry_point: "main".to_string(),
            // At least one workgroup even for empty input.
            workgroup_count: [dispatch.max(1), 1, 1],
            bindings: vec![
                ComputeBindingDescriptor {
                    label: "lhs".to_string(),
                    kind: ComputeBindingKind::StorageReadOnly,
                    bytes: f32_slice_to_bytes(lhs),
                },
                ComputeBindingDescriptor {
                    label: "rhs".to_string(),
                    kind: ComputeBindingKind::StorageReadOnly,
                    bytes: f32_slice_to_bytes(rhs),
                },
                ComputeBindingDescriptor {
                    label: "params".to_string(),
                    kind: ComputeBindingKind::Uniform,
                    bytes: vector_add_params_to_bytes(&params),
                },
                ComputeBindingDescriptor {
                    label: "output".to_string(),
                    kind: ComputeBindingKind::StorageReadWrite,
                    bytes: f32_slice_to_bytes(&output_seed),
                },
            ],
        };

        // The sole read-write binding ("output") is the last element of
        // `outputs`, so `pop` retrieves it.
        let mut result = self.run_compute(&descriptor)?.outputs;
        let bytes = result.pop().ok_or("No output from GPU kernel")?;
        Ok(bytes_to_f32_vec(&bytes))
    }
}
475
/// Uniform block for `VECTOR_ADD_SHADER`; matches the WGSL `Params` struct.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct VectorAddParams {
    /// Number of valid elements in each input vector.
    len: u32,
    // 12 bytes of padding so the serialized block is 16 bytes, mirroring the
    // shader-side `_pad0.._pad2` fields.
    _pad: [u32; 3],
}
484
/// Serializes `values` into native-endian bytes, 4 per element, in order.
pub fn f32_slice_to_bytes(values: &[f32]) -> Vec<u8> {
    values
        .iter()
        .flat_map(|value| value.to_ne_bytes())
        .collect()
}
492
/// Deserializes native-endian bytes back into `f32`s, 4 bytes per value.
/// Any trailing partial chunk (fewer than 4 bytes) is silently ignored.
pub fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        let word: [u8; 4] = chunk.try_into().expect("chunk is exactly 4 bytes");
        values.push(f32::from_ne_bytes(word));
    }
    values
}
499
/// One 4-byte word destined for a WGSL uniform buffer.
#[derive(Debug, Clone, Copy)]
pub enum UniformValue {
    U32(u32),
    F32(f32),
}

/// Packs the words in order, 4 native-endian bytes each, with no padding
/// between them.
pub fn pack_uniform_values(values: &[UniformValue]) -> Vec<u8> {
    values
        .iter()
        .flat_map(|value| match value {
            UniformValue::U32(word) => word.to_ne_bytes(),
            UniformValue::F32(word) => word.to_ne_bytes(),
        })
        .collect()
}
516
/// Narrows f64 triples to f32 and writes them in WGSL `vec3<f32>` layout:
/// x, y, z plus one zero pad word, giving a 16-byte stride per position.
pub fn pack_vec3_positions_f32(positions: &[[f64; 3]]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(positions.len() * 16);
    for &[x, y, z] in positions {
        let padded = [x as f32, y as f32, z as f32, 0.0f32];
        for component in padded {
            bytes.extend_from_slice(&component.to_ne_bytes());
        }
    }
    bytes
}
527
/// Reads native-endian `f32` words from `bytes` and widens each to `f64`.
/// A trailing partial chunk (fewer than 4 bytes) is ignored.
pub fn bytes_to_f64_vec_from_f32(bytes: &[u8]) -> Vec<f64> {
    bytes
        .chunks_exact(4)
        .map(|chunk| {
            let word: [u8; 4] = chunk.try_into().expect("4 bytes");
            f64::from(f32::from_ne_bytes(word))
        })
        .collect()
}
534
/// Ceiling division for sizing dispatch grids: `ceil(value / chunk)`.
///
/// The previous implementation cast `value as u32` before dividing, which
/// silently wrapped for inputs above `u32::MAX` on 64-bit targets. The
/// division is now performed in `u64`, and a result that still exceeds
/// `u32::MAX` saturates instead of wrapping. Results are identical for all
/// in-range inputs.
///
/// # Panics
/// Panics if `chunk` is zero (division by zero), as before.
pub fn ceil_div_u32(value: usize, chunk: u32) -> u32 {
    let quotient = (value as u64).div_ceil(u64::from(chunk));
    u32::try_from(quotient).unwrap_or(u32::MAX)
}
538
539fn vector_add_params_to_bytes(params: &VectorAddParams) -> Vec<u8> {
540 let mut bytes = Vec::with_capacity(16);
541 bytes.extend_from_slice(¶ms.len.to_ne_bytes());
542 for v in params._pad {
543 bytes.extend_from_slice(&v.to_ne_bytes());
544 }
545 bytes
546}
547
/// WGSL kernel for element-wise f32 addition.
///
/// Binding layout matches `GpuContext::vector_add_f32`: lhs (read-only
/// storage), rhs (read-only storage), params (uniform), out (read-write
/// storage). One invocation per element, 64 invocations per workgroup;
/// invocations with an index at or past `params.len` exit early.
const VECTOR_ADD_SHADER: &str = r#"
struct Params {
    len: u32, _pad0: u32, _pad1: u32, _pad2: u32,
};

@group(0) @binding(0) var<storage, read> lhs: array<f32>;
@group(0) @binding(1) var<storage, read> rhs: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.len) { return; }
    out[idx] = lhs[idx] + rhs[idx];
}
"#;
565
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU fallback must report no GPU and the fallback backend label.
    #[test]
    fn test_cpu_fallback_creation() {
        let ctx = GpuContext::cpu_fallback();
        assert!(!ctx.is_gpu_available());
        assert_eq!(ctx.capabilities.backend, "CPU-fallback");
    }

    /// `best_available` must always yield a usable context with a
    /// non-empty backend label, whether or not a GPU exists.
    #[test]
    fn test_best_available_never_panics() {
        let ctx = GpuContext::best_available();
        let report = ctx.activation_report();
        assert!(!report.backend.is_empty());
    }

    /// Without the feature, the report must say FeatureDisabled; the check
    /// is guarded so the test also passes when the feature is compiled in.
    #[test]
    fn test_activation_report_feature_disabled() {
        let ctx = GpuContext::cpu_fallback();
        let report = ctx.activation_report();
        if !cfg!(feature = "experimental-gpu") {
            assert_eq!(report.state, GpuActivationState::FeatureDisabled);
            assert!(!report.feature_enabled);
        }
    }

    /// Default capabilities describe the CPU path with sane limits.
    #[test]
    fn test_compute_capabilities_default() {
        let caps = ComputeCapabilities::default();
        assert!(!caps.gpu_available);
        assert!(caps.max_workgroup_size_x > 0);
    }

    /// f32 -> bytes -> f32 must be lossless (native endianness both ways).
    #[test]
    fn test_f32_roundtrip() {
        let values = vec![1.0f32, 2.5, -std::f32::consts::PI, 0.0];
        let bytes = f32_slice_to_bytes(&values);
        let result = bytes_to_f32_vec(&bytes);
        assert_eq!(values, result);
    }

    /// Mixed u32/f32 words pack in order, 4 bytes each, with no padding.
    #[test]
    fn test_uniform_word_packing() {
        let bytes = pack_uniform_values(&[
            UniformValue::U32(7),
            UniformValue::F32(1.5),
            UniformValue::U32(9),
            UniformValue::F32(-2.0),
        ]);
        assert_eq!(bytes.len(), 16);
        assert_eq!(u32::from_ne_bytes(bytes[0..4].try_into().unwrap()), 7);
        assert!((f32::from_ne_bytes(bytes[4..8].try_into().unwrap()) - 1.5).abs() < 1e-6);
        assert_eq!(u32::from_ne_bytes(bytes[8..12].try_into().unwrap()), 9);
        assert!((f32::from_ne_bytes(bytes[12..16].try_into().unwrap()) + 2.0).abs() < 1e-6);
    }

    /// Each position occupies 16 bytes: x, y, z, then a zero pad word.
    #[test]
    fn test_pack_vec3_positions_f32_layout() {
        let bytes = pack_vec3_positions_f32(&[[1.0, -2.0, 3.5]]);
        assert_eq!(bytes.len(), 16);
        assert!((f32::from_ne_bytes(bytes[0..4].try_into().unwrap()) - 1.0).abs() < 1e-6);
        assert!((f32::from_ne_bytes(bytes[4..8].try_into().unwrap()) + 2.0).abs() < 1e-6);
        assert!((f32::from_ne_bytes(bytes[8..12].try_into().unwrap()) - 3.5).abs() < 1e-6);
        assert_eq!(f32::from_ne_bytes(bytes[12..16].try_into().unwrap()), 0.0);
    }
}