// sci_form/gpu/context.rs

//! GPU compute context — wgpu device initialization with CPU fallback.
//!
//! Wraps `wgpu::Device` + `Queue` for native GPU compute, or falls back
//! to CPU-only mode when no GPU is available or the feature is disabled.

#[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
use std::borrow::Cow;

#[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
use std::sync::mpsc;

#[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
use wgpu::util::DeviceExt;

use super::backend_report::{GpuActivationReport, GpuActivationState};
16
/// Capabilities detected from the GPU adapter.
#[derive(Debug, Clone)]
pub struct ComputeCapabilities {
    /// Human-readable backend name (e.g. "Vulkan", "Metal", or "CPU-fallback").
    pub backend: String,
    /// Maximum compute workgroup size along X.
    pub max_workgroup_size_x: u32,
    /// Maximum compute workgroup size along Y.
    pub max_workgroup_size_y: u32,
    /// Maximum total invocations per workgroup.
    pub max_workgroup_invocations: u32,
    /// Maximum size of a single storage-buffer binding, in bytes.
    pub max_storage_buffer_size: u64,
    /// True only when a real wgpu device backs this context.
    pub gpu_available: bool,
}
27
28impl Default for ComputeCapabilities {
29    fn default() -> Self {
30        Self {
31            backend: "CPU-fallback".to_string(),
32            max_workgroup_size_x: 256,
33            max_workgroup_size_y: 256,
34            max_workgroup_invocations: 256,
35            max_storage_buffer_size: u64::MAX,
36            gpu_available: false,
37        }
38    }
39}
40
/// Buffer access mode for compute bindings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputeBindingKind {
    /// Small read-only parameter block (`var<uniform>` in WGSL).
    Uniform,
    /// Storage buffer the shader only reads (`var<storage, read>`).
    StorageReadOnly,
    /// Storage buffer the shader writes (`var<storage, read_write>`);
    /// its contents are copied back to the host after dispatch.
    StorageReadWrite,
}
48
/// One logical binding passed to a compute kernel.
#[derive(Debug, Clone)]
pub struct ComputeBindingDescriptor {
    /// Debug label applied to the created GPU buffer.
    pub label: String,
    /// Access mode; determines buffer usage flags and bind-group layout.
    pub kind: ComputeBindingKind,
    /// Initial buffer contents; its length fixes the buffer size.
    pub bytes: Vec<u8>,
}
56
/// Generic compute dispatch description.
#[derive(Debug, Clone)]
pub struct ComputeDispatchDescriptor {
    /// Debug label used for the shader, pipeline, encoder, and pass.
    pub label: String,
    /// WGSL source compiled for this dispatch.
    pub shader_source: String,
    /// Name of the `@compute` entry point inside `shader_source`.
    pub entry_point: String,
    /// Workgroup counts along x, y, and z.
    pub workgroup_count: [u32; 3],
    /// Buffers bound at `@group(0)` in declaration order (binding index =
    /// position in this vector).
    pub bindings: Vec<ComputeBindingDescriptor>,
}
66
/// Output produced by a compute dispatch.
#[derive(Debug, Clone)]
pub struct ComputeDispatchResult {
    /// Backend that executed the dispatch (e.g. "Vulkan").
    pub backend: String,
    /// Final contents of each `StorageReadWrite` binding, in binding order.
    pub outputs: Vec<Vec<u8>>,
}
73
/// GPU compute context handle.
///
/// Construct via [`GpuContext::best_available`], [`GpuContext::try_create`],
/// or [`GpuContext::cpu_fallback`].
pub struct GpuContext {
    /// Detected device limits, or CPU-fallback defaults.
    pub capabilities: ComputeCapabilities,
    /// Why GPU initialization failed, when it was attempted and failed.
    runtime_error: Option<String>,
    /// Live wgpu device/queue; `None` until GPU initialization succeeds.
    #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
    runtime: Option<NativeGpuRuntime>,
}
81
/// Owned wgpu objects backing a live GPU context.
#[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
struct NativeGpuRuntime {
    // Instance and adapter are never read directly; they are kept here so
    // they outlive the device and queue created from them.
    _instance: wgpu::Instance,
    _adapter: wgpu::Adapter,
    device: wgpu::Device,
    queue: wgpu::Queue,
}
89
90impl std::fmt::Debug for GpuContext {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        f.debug_struct("GpuContext")
93            .field("capabilities", &self.capabilities)
94            .field("runtime_error", &self.runtime_error)
95            .finish()
96    }
97}
98
99impl GpuContext {
100    /// Create a CPU-fallback context (always available).
101    pub fn cpu_fallback() -> Self {
102        Self {
103            capabilities: ComputeCapabilities::default(),
104            runtime_error: None,
105            #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
106            runtime: None,
107        }
108    }
109
110    /// Attempt to initialize a real wgpu compute device.
111    pub fn try_create() -> Result<Self, String> {
112        #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
113        {
114            let instance = wgpu::Instance::default();
115            let adapter =
116                pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
117                    power_preference: wgpu::PowerPreference::HighPerformance,
118                    compatible_surface: None,
119                    force_fallback_adapter: false,
120                }))
121                .ok_or_else(|| "No GPU adapter found".to_string())?;
122
123            let adapter_info = adapter.get_info();
124            let limits = adapter.limits();
125            let required_limits = wgpu::Limits::default().using_resolution(limits.clone());
126
127            let (device, queue) = pollster::block_on(adapter.request_device(
128                &wgpu::DeviceDescriptor {
129                    label: Some("sci-form gpu"),
130                    required_features: wgpu::Features::empty(),
131                    required_limits,
132                },
133                None,
134            ))
135            .map_err(|err| format!("Failed to create wgpu device: {err}"))?;
136
137            Ok(Self {
138                capabilities: ComputeCapabilities {
139                    backend: format!("{:?}", adapter_info.backend),
140                    max_workgroup_size_x: limits.max_compute_workgroup_size_x,
141                    max_workgroup_size_y: limits.max_compute_workgroup_size_y,
142                    max_workgroup_invocations: limits.max_compute_invocations_per_workgroup,
143                    max_storage_buffer_size: limits.max_storage_buffer_binding_size as u64,
144                    gpu_available: true,
145                },
146                runtime_error: None,
147                runtime: Some(NativeGpuRuntime {
148                    _instance: instance,
149                    _adapter: adapter,
150                    device,
151                    queue,
152                }),
153            })
154        }
155
156        #[cfg(not(all(feature = "experimental-gpu", not(target_arch = "wasm32"))))]
157        {
158            Err("experimental-gpu feature not enabled".to_string())
159        }
160    }
161
162    /// Best available backend: GPU if possible, CPU fallback otherwise.
163    pub fn best_available() -> Self {
164        match Self::try_create() {
165            Ok(ctx) => ctx,
166            Err(reason) => {
167                let mut ctx = Self::cpu_fallback();
168                ctx.runtime_error = Some(reason);
169                ctx
170            }
171        }
172    }
173
174    /// Build an activation report describing the current runtime state.
175    pub fn activation_report(&self) -> GpuActivationReport {
176        if self.capabilities.gpu_available {
177            GpuActivationReport {
178                backend: self.capabilities.backend.clone(),
179                feature_enabled: true,
180                gpu_available: true,
181                runtime_ready: true,
182                state: GpuActivationState::Ready,
183                reason: "GPU runtime available".to_string(),
184            }
185        } else if cfg!(feature = "experimental-gpu") {
186            GpuActivationReport {
187                backend: self.capabilities.backend.clone(),
188                feature_enabled: true,
189                gpu_available: false,
190                runtime_ready: false,
191                state: GpuActivationState::NoAdapter,
192                reason: self
193                    .runtime_error
194                    .clone()
195                    .unwrap_or_else(|| "experimental-gpu enabled but no adapter found".to_string()),
196            }
197        } else {
198            GpuActivationReport {
199                backend: "CPU-fallback".to_string(),
200                feature_enabled: false,
201                gpu_available: false,
202                runtime_ready: false,
203                state: GpuActivationState::FeatureDisabled,
204                reason: "experimental-gpu feature not enabled".to_string(),
205            }
206        }
207    }
208
209    /// Whether the GPU backend is available.
210    pub fn is_gpu_available(&self) -> bool {
211        self.capabilities.gpu_available
212    }
213
214    #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
215    fn runtime(&self) -> Result<&NativeGpuRuntime, String> {
216        self.runtime.as_ref().ok_or_else(|| {
217            self.runtime_error
218                .clone()
219                .unwrap_or_else(|| "GPU runtime not initialized".to_string())
220        })
221    }
222
223    /// Dispatch an arbitrary WGSL compute kernel.
224    pub fn run_compute(
225        &self,
226        descriptor: &ComputeDispatchDescriptor,
227    ) -> Result<ComputeDispatchResult, String> {
228        #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
229        {
230            let runtime = self.runtime()?;
231
232            let shader = runtime
233                .device
234                .create_shader_module(wgpu::ShaderModuleDescriptor {
235                    label: Some(&descriptor.label),
236                    source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(&descriptor.shader_source)),
237                });
238
239            let mut layout_entries = Vec::with_capacity(descriptor.bindings.len());
240            let mut buffers = Vec::with_capacity(descriptor.bindings.len());
241            let mut readbacks = Vec::new();
242
243            for (index, binding) in descriptor.bindings.iter().enumerate() {
244                let usage = match binding.kind {
245                    ComputeBindingKind::Uniform => {
246                        wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST
247                    }
248                    ComputeBindingKind::StorageReadOnly => {
249                        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST
250                    }
251                    ComputeBindingKind::StorageReadWrite => {
252                        wgpu::BufferUsages::STORAGE
253                            | wgpu::BufferUsages::COPY_DST
254                            | wgpu::BufferUsages::COPY_SRC
255                    }
256                };
257
258                let buffer = runtime
259                    .device
260                    .create_buffer_init(&wgpu::util::BufferInitDescriptor {
261                        label: Some(&binding.label),
262                        contents: &binding.bytes,
263                        usage,
264                    });
265
266                layout_entries.push(wgpu::BindGroupLayoutEntry {
267                    binding: index as u32,
268                    visibility: wgpu::ShaderStages::COMPUTE,
269                    ty: wgpu::BindingType::Buffer {
270                        ty: match binding.kind {
271                            ComputeBindingKind::Uniform => wgpu::BufferBindingType::Uniform,
272                            ComputeBindingKind::StorageReadOnly => {
273                                wgpu::BufferBindingType::Storage { read_only: true }
274                            }
275                            ComputeBindingKind::StorageReadWrite => {
276                                wgpu::BufferBindingType::Storage { read_only: false }
277                            }
278                        },
279                        has_dynamic_offset: false,
280                        min_binding_size: None,
281                    },
282                    count: None,
283                });
284                if matches!(binding.kind, ComputeBindingKind::StorageReadWrite) {
285                    readbacks.push((buffers.len(), binding.bytes.len()));
286                }
287                buffers.push(buffer);
288            }
289
290            let bind_group_entries: Vec<_> = buffers
291                .iter()
292                .enumerate()
293                .map(|(index, buffer)| wgpu::BindGroupEntry {
294                    binding: index as u32,
295                    resource: buffer.as_entire_binding(),
296                })
297                .collect();
298
299            let bind_group_layout =
300                runtime
301                    .device
302                    .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
303                        label: Some(&format!("{} layout", descriptor.label)),
304                        entries: &layout_entries,
305                    });
306            let pipeline_layout =
307                runtime
308                    .device
309                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
310                        label: Some(&format!("{} pipeline", descriptor.label)),
311                        bind_group_layouts: &[&bind_group_layout],
312                        push_constant_ranges: &[],
313                    });
314            let pipeline =
315                runtime
316                    .device
317                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
318                        label: Some(&descriptor.label),
319                        layout: Some(&pipeline_layout),
320                        module: &shader,
321                        entry_point: &descriptor.entry_point,
322                    });
323            let bind_group = runtime
324                .device
325                .create_bind_group(&wgpu::BindGroupDescriptor {
326                    label: Some(&format!("{} bind group", descriptor.label)),
327                    layout: &bind_group_layout,
328                    entries: &bind_group_entries,
329                });
330
331            let mut encoder =
332                runtime
333                    .device
334                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
335                        label: Some(&format!("{} encoder", descriptor.label)),
336                    });
337            {
338                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
339                    label: Some(&format!("{} pass", descriptor.label)),
340                    timestamp_writes: None,
341                });
342                pass.set_pipeline(&pipeline);
343                pass.set_bind_group(0, &bind_group, &[]);
344                pass.dispatch_workgroups(
345                    descriptor.workgroup_count[0],
346                    descriptor.workgroup_count[1],
347                    descriptor.workgroup_count[2],
348                );
349            }
350
351            let mut staging_buffers = Vec::with_capacity(readbacks.len());
352            for (buffer_index, size_bytes) in &readbacks {
353                let staging = runtime.device.create_buffer(&wgpu::BufferDescriptor {
354                    label: Some("readback staging"),
355                    size: *size_bytes as u64,
356                    usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
357                    mapped_at_creation: false,
358                });
359                encoder.copy_buffer_to_buffer(
360                    &buffers[*buffer_index],
361                    0,
362                    &staging,
363                    0,
364                    *size_bytes as u64,
365                );
366                staging_buffers.push(staging);
367            }
368
369            runtime.queue.submit(Some(encoder.finish()));
370            runtime.device.poll(wgpu::Maintain::Wait);
371
372            let mut outputs = Vec::with_capacity(staging_buffers.len());
373            for staging in staging_buffers {
374                let slice = staging.slice(..);
375                let (sender, receiver) = mpsc::channel();
376                slice.map_async(wgpu::MapMode::Read, move |result| {
377                    let _ = sender.send(result);
378                });
379                runtime.device.poll(wgpu::Maintain::Wait);
380                receiver
381                    .recv()
382                    .map_err(|_| "GPU readback channel error".to_string())?
383                    .map_err(|err| format!("GPU buffer map failed: {err}"))?;
384
385                let bytes = slice.get_mapped_range().to_vec();
386                staging.unmap();
387                outputs.push(bytes);
388            }
389
390            Ok(ComputeDispatchResult {
391                backend: self.capabilities.backend.clone(),
392                outputs,
393            })
394        }
395
396        #[cfg(not(all(feature = "experimental-gpu", not(target_arch = "wasm32"))))]
397        {
398            let _ = descriptor;
399            Err("experimental-gpu feature not enabled".to_string())
400        }
401    }
402
403    /// Validate that a WGSL shader compiles successfully on the current device.
404    pub fn validate_shader(&self, label: &str, source: &str) -> Result<String, String> {
405        #[cfg(all(feature = "experimental-gpu", not(target_arch = "wasm32")))]
406        {
407            let runtime = self.runtime()?;
408            // Shader creation will fail if WGSL is invalid
409            let _module = runtime
410                .device
411                .create_shader_module(wgpu::ShaderModuleDescriptor {
412                    label: Some(label),
413                    source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
414                });
415            Ok(format!(
416                "Shader '{}' compiled on {}",
417                label, self.capabilities.backend
418            ))
419        }
420
421        #[cfg(not(all(feature = "experimental-gpu", not(target_arch = "wasm32"))))]
422        {
423            let _ = (label, source);
424            Err("experimental-gpu feature not enabled".to_string())
425        }
426    }
427
428    /// Convenience: f32 vector addition for GPU validation.
429    pub fn vector_add_f32(&self, lhs: &[f32], rhs: &[f32]) -> Result<Vec<f32>, String> {
430        if lhs.len() != rhs.len() {
431            return Err("Vectors must have the same length".to_string());
432        }
433
434        let params = VectorAddParams {
435            len: lhs.len() as u32,
436            _pad: [0; 3],
437        };
438        let dispatch = (lhs.len() as u32).div_ceil(64);
439        let output_seed = vec![0.0f32; lhs.len()];
440
441        let descriptor = ComputeDispatchDescriptor {
442            label: "vector add".to_string(),
443            shader_source: VECTOR_ADD_SHADER.to_string(),
444            entry_point: "main".to_string(),
445            workgroup_count: [dispatch.max(1), 1, 1],
446            bindings: vec![
447                ComputeBindingDescriptor {
448                    label: "lhs".to_string(),
449                    kind: ComputeBindingKind::StorageReadOnly,
450                    bytes: f32_slice_to_bytes(lhs),
451                },
452                ComputeBindingDescriptor {
453                    label: "rhs".to_string(),
454                    kind: ComputeBindingKind::StorageReadOnly,
455                    bytes: f32_slice_to_bytes(rhs),
456                },
457                ComputeBindingDescriptor {
458                    label: "params".to_string(),
459                    kind: ComputeBindingKind::Uniform,
460                    bytes: vector_add_params_to_bytes(&params),
461                },
462                ComputeBindingDescriptor {
463                    label: "output".to_string(),
464                    kind: ComputeBindingKind::StorageReadWrite,
465                    bytes: f32_slice_to_bytes(&output_seed),
466                },
467            ],
468        };
469
470        let mut result = self.run_compute(&descriptor)?.outputs;
471        let bytes = result.pop().ok_or("No output from GPU kernel")?;
472        Ok(bytes_to_f32_vec(&bytes))
473    }
474}
475
// ─── Helper types and functions ──────────────────────────────────────────────
477
/// Uniform parameters for `VECTOR_ADD_SHADER`.
///
/// `#[repr(C)]` plus three explicit pad words keeps the struct at 16 bytes,
/// matching the `Params` struct layout declared in the WGSL source.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct VectorAddParams {
    /// Number of valid elements in the input/output arrays.
    len: u32,
    /// Padding to a 16-byte uniform block.
    _pad: [u32; 3],
}
484
/// Serialize an `f32` slice to raw native-endian bytes (4 bytes per value).
pub fn f32_slice_to_bytes(values: &[f32]) -> Vec<u8> {
    values
        .iter()
        .flat_map(|value| value.to_ne_bytes())
        .collect()
}
492
/// Deserialize raw native-endian bytes into `f32` values.
///
/// Trailing bytes that do not form a complete 4-byte word are ignored.
pub fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_ne_bytes(raw));
    }
    values
}
499
/// A single 32-bit word destined for a uniform buffer.
#[derive(Debug, Clone, Copy)]
pub enum UniformValue {
    U32(u32),
    F32(f32),
}

/// Pack 32-bit words into native-endian bytes, preserving order.
pub fn pack_uniform_values(values: &[UniformValue]) -> Vec<u8> {
    values
        .iter()
        .flat_map(|value| match value {
            UniformValue::U32(word) => word.to_ne_bytes(),
            UniformValue::F32(word) => word.to_ne_bytes(),
        })
        .collect()
}
516
/// Pack `f64` 3-vectors as `f32` with a 16-byte stride per element:
/// x, y, z, then one zero pad word (the padding a shader-side
/// `vec3<f32>`-style layout expects — confirm against the consuming kernel).
pub fn pack_vec3_positions_f32(positions: &[[f64; 3]]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(positions.len() * 16);
    for &[x, y, z] in positions {
        for component in [x as f32, y as f32, z as f32, 0.0f32] {
            bytes.extend_from_slice(&component.to_ne_bytes());
        }
    }
    bytes
}
527
/// Decode native-endian `f32` bytes and widen each value to `f64`.
///
/// Trailing bytes that do not form a complete 4-byte word are ignored.
pub fn bytes_to_f64_vec_from_f32(bytes: &[u8]) -> Vec<f64> {
    bytes
        .chunks_exact(4)
        .map(|word| f32::from_ne_bytes(word.try_into().expect("4-byte chunk")) as f64)
        .collect()
}
534
/// Ceiling division of a `usize` element count by a `u32` chunk size.
///
/// The division is done in `usize` so counts above `u32::MAX` are not
/// truncated before dividing (casting first silently produced wrong
/// workgroup counts for very large inputs). A quotient that still does
/// not fit in `u32` saturates to `u32::MAX`.
///
/// # Panics
/// Panics if `chunk` is zero.
pub fn ceil_div_u32(value: usize, chunk: u32) -> u32 {
    u32::try_from(value.div_ceil(chunk as usize)).unwrap_or(u32::MAX)
}
538
539fn vector_add_params_to_bytes(params: &VectorAddParams) -> Vec<u8> {
540    let mut bytes = Vec::with_capacity(16);
541    bytes.extend_from_slice(&params.len.to_ne_bytes());
542    for v in params._pad {
543        bytes.extend_from_slice(&v.to_ne_bytes());
544    }
545    bytes
546}
547
/// WGSL kernel computing `out[i] = lhs[i] + rhs[i]`.
///
/// The binding layout must stay in sync with `vector_add_f32`:
/// 0 = lhs (read-only), 1 = rhs (read-only), 2 = Params uniform,
/// 3 = out (read-write). `@workgroup_size(64, 1, 1)` matches the
/// `div_ceil(64)` dispatch computation in `vector_add_f32`.
const VECTOR_ADD_SHADER: &str = r#"
struct Params {
    len: u32, _pad0: u32, _pad1: u32, _pad2: u32,
};

@group(0) @binding(0) var<storage, read> lhs: array<f32>;
@group(0) @binding(1) var<storage, read> rhs: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.len) { return; }
    out[idx] = lhs[idx] + rhs[idx];
}
"#;
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Read a native-endian `f32` from `bytes` at `offset`.
    fn f32_at(bytes: &[u8], offset: usize) -> f32 {
        f32::from_ne_bytes(bytes[offset..offset + 4].try_into().unwrap())
    }

    /// Read a native-endian `u32` from `bytes` at `offset`.
    fn u32_at(bytes: &[u8], offset: usize) -> u32 {
        u32::from_ne_bytes(bytes[offset..offset + 4].try_into().unwrap())
    }

    #[test]
    fn test_cpu_fallback_creation() {
        let context = GpuContext::cpu_fallback();
        assert!(!context.is_gpu_available());
        assert_eq!(context.capabilities.backend, "CPU-fallback");
    }

    #[test]
    fn test_best_available_never_panics() {
        let report = GpuContext::best_available().activation_report();
        assert!(!report.backend.is_empty());
    }

    #[test]
    fn test_activation_report_feature_disabled() {
        if cfg!(feature = "experimental-gpu") {
            return; // only meaningful without the GPU feature compiled in
        }
        let report = GpuContext::cpu_fallback().activation_report();
        assert_eq!(report.state, GpuActivationState::FeatureDisabled);
        assert!(!report.feature_enabled);
    }

    #[test]
    fn test_compute_capabilities_default() {
        let caps = ComputeCapabilities::default();
        assert!(!caps.gpu_available);
        assert!(caps.max_workgroup_size_x > 0);
    }

    #[test]
    fn test_f32_roundtrip() {
        let original = vec![1.0f32, 2.5, -std::f32::consts::PI, 0.0];
        let decoded = bytes_to_f32_vec(&f32_slice_to_bytes(&original));
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_uniform_word_packing() {
        let bytes = pack_uniform_values(&[
            UniformValue::U32(7),
            UniformValue::F32(1.5),
            UniformValue::U32(9),
            UniformValue::F32(-2.0),
        ]);
        assert_eq!(bytes.len(), 16);
        assert_eq!(u32_at(&bytes, 0), 7);
        assert!((f32_at(&bytes, 4) - 1.5).abs() < 1e-6);
        assert_eq!(u32_at(&bytes, 8), 9);
        assert!((f32_at(&bytes, 12) + 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_pack_vec3_positions_f32_layout() {
        let bytes = pack_vec3_positions_f32(&[[1.0, -2.0, 3.5]]);
        assert_eq!(bytes.len(), 16);
        assert!((f32_at(&bytes, 0) - 1.0).abs() < 1e-6);
        assert!((f32_at(&bytes, 4) + 2.0).abs() < 1e-6);
        assert!((f32_at(&bytes, 8) - 3.5).abs() < 1e-6);
        assert_eq!(f32_at(&bytes, 12), 0.0);
    }
}