Skip to main content

scirs2_special/gpu_kernels/
wgsl.rs

1//! WGSL compute-shader kernels for WebGPU-backed dispatch of batch special functions.
2//!
3//! Each constant holds the WGSL source for a `@compute` shader that evaluates a batch
4//! of special-function values.  The shaders operate on `array<f32>` — inputs must be
5//! cast from `f64` before upload and cast back after download.
6//!
//! The host-side dispatch functions (`gamma_batch_wgpu`, `erf_batch_wgpu`,
//! `erfc_batch_wgpu`, `erfinv_batch_wgpu`, `bessel_j0_batch_wgpu`, `lgamma_batch_wgpu`)
//! perform real GPU dispatch when the `wgpu_kernels` feature is enabled and a wgpu
//! adapter is found.  When the feature is disabled, or no adapter is available at
//! runtime, they return [`WgslDispatchError::GpuNotAvailable`] so the caller can fall
//! back to CPU.
12//!
13//! # Feature gating
14//!
15//! The WGSL shader source constants are always compiled (useful for documentation and
16//! validation tooling).  The GPU dispatch paths are gated behind
17//! `#[cfg(feature = "wgpu_kernels")]`.
18
19// ---------------------------------------------------------------------------
20// WGSL shader sources
21// ---------------------------------------------------------------------------
22
/// WGSL compute shader for batch Gamma evaluation (Lanczos approximation, g = 7, 9 terms).
///
/// Workgroup size 64.  Each invocation reads one `f32` from `input` and writes the
/// approximated `Γ(x)` into `output`; invocations past `arrayLength(&input)` return early.
/// The reflection formula `Γ(x) = π / (sin(πx) · Γ(1 - x))` is applied when `x < 0.5`.
///
/// The factor `t^(x + 0.5) · e^(-t)` is evaluated as `exp((x + 0.5)·ln t - t)`.  Raising
/// `t` to the `x + 0.5` power on its own overflows `f32` once `x ≳ 26`, although `Γ(x)`
/// itself stays finite up to `x ≈ 35`.  Beyond that point `Γ(x)` exceeds the `f32` range
/// and the output is not meaningful.
pub const GAMMA_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

const PI: f32 = 3.14159265358979323846;

// Lanczos approximation with g = 7 and 9 coefficients.
fn lanczos_gamma(x_in: f32) -> f32 {
    var x = x_in;
    // Reflection: Gamma(x) = pi / (sin(pi x) * Gamma(1 - x)) for x < 0.5.
    let use_reflection = x < 0.5;
    var reflection_factor: f32 = 1.0;
    if use_reflection {
        reflection_factor = PI / sin(PI * x);
        x = 1.0 - x;
    }
    let g: f32 = 7.0;
    x = x - 1.0;

    let c0: f32 =  0.99999999999980993;
    let c1: f32 =  676.5203681218851;
    let c2: f32 = -1259.1392167224028;
    let c3: f32 =  771.32342877765313;
    let c4: f32 = -176.61502916214059;
    let c5: f32 =  12.507343278686905;
    let c6: f32 = -0.13857109526572012;
    let c7: f32 =  9.9843695780195716e-6;
    let c8: f32 =  1.5056327351493116e-7;

    let s = c0
        + c1 / (x + 1.0)
        + c2 / (x + 2.0)
        + c3 / (x + 3.0)
        + c4 / (x + 4.0)
        + c5 / (x + 5.0)
        + c6 / (x + 6.0)
        + c7 / (x + 7.0)
        + c8 / (x + 8.0);

    let t = x + g + 0.5;
    // Combine t^(x + 0.5) and exp(-t) in log space: raising t to the (x + 0.5) power
    // on its own overflows f32 for x >= ~26 even though the final product still fits.
    let result = sqrt(2.0 * PI) * exp((x + 0.5) * log(t) - t) * s;
    if use_reflection { return reflection_factor / result; }
    return result;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = lanczos_gamma(input[idx]);
}
"#;
79
/// WGSL compute shader for batch `erf` evaluation.
///
/// Uses the Abramowitz & Stegun 7.1.26 approximation
/// `erf(|x|) ≈ 1 - t·(a₁ + a₂t + a₃t² + a₄t³ + a₅t⁴)·exp(-x²)` with
/// `t = 1 / (1 + 0.3275911·|x|)`, and mirrors the result for negative inputs
/// (`erf(-x) = -erf(x)`, via `select`).  A&S quote a maximum absolute error of
/// ≈ 1.5 × 10⁻⁷ for the approximation itself; `f32` rounding comes on top of that.
///
/// Because the value is formed as `1 - (…)` where `(…)` is close to 1 for small `|x|`,
/// results near zero keep a small *absolute* error but poor *relative* accuracy.
///
/// Bindings: `@group(0) @binding(0)` read-only input, `@group(0) @binding(1)` output.
/// Workgroup size 64; invocations past `arrayLength(&input)` return without writing.
pub const ERF_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

fn approx_erf(x: f32) -> f32 {
    let t = 1.0 / (1.0 + 0.3275911 * abs(x));
    let y = 1.0 - (((((
          1.061405429 * t
        - 1.453152027) * t
        + 1.421413741) * t
        - 0.284496736) * t
        + 0.254829592) * t * exp(-x * x));
    return select(-y, y, x >= 0.0);
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = approx_erf(input[idx]);
}
"#;
105
/// WGSL compute shader for batch Bessel J₀ evaluation.
///
/// Evaluates `J₀(|x|)` (the shader takes `abs` of the input first) in two regimes:
///
/// - `|x| < 8`: a rational approximation `P(x²) / Q(x²)` with degree-5 numerator and
///   denominator polynomials.
/// - `|x| ≥ 8`: the asymptotic form
///   `sqrt(0.636619772 / |x|) · (cos(ξ)·P₀(z) - z·sin(ξ)·Q₀(z))` with `z = 8/|x|` and
///   `ξ = |x| - 0.785398164`; those constants are `2/π` and `π/4`.
///
/// NOTE(review): this was previously described as the A&S §9.4 polynomial; the
/// coefficients look like the `bessj0` rational approximation from Numerical Recipes
/// instead. Confirm the source before citing it.
/// NOTE(review): the `PI` constant declared in the shader is unused. For large `|x|`,
/// accuracy also depends on the GPU's `f32` `sin`/`cos` range reduction; check against
/// a CPU reference if large arguments matter.
pub const BESSEL_J0_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

const PI: f32 = 3.14159265358979323846;

fn bessel_j0(x_in: f32) -> f32 {
    let x = abs(x_in);
    if x < 8.0 {
        let y = x * x;
        let p1: f32 =  57568490574.0;
        let p2: f32 = -13362590354.0;
        let p3: f32 =  651619640.7;
        let p4: f32 = -11214424.18;
        let p5: f32 =  77392.33017;
        let p6: f32 = -184.9052456;
        let q1: f32 =  57568490411.0;
        let q2: f32 =  1029532985.0;
        let q3: f32 =  9494680.718;
        let q4: f32 =  59272.64853;
        let q5: f32 =  267.8532712;
        let p = p1 + y * (p2 + y * (p3 + y * (p4 + y * (p5 + y * p6))));
        let q = q1 + y * (q2 + y * (q3 + y * (q4 + y * (q5 + y))));
        return p / q;
    } else {
        let z = 8.0 / x;
        let y = z * z;
        let xx = x - 0.785398164;
        let pv = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4
                 + y * (-0.2073370639e-5 + y * 0.2093887211e-6)));
        let qv = -0.1562499995e-1 + y * (0.1430488765e-3
                 + y * (-0.6911147651e-5 + y * (0.7621095161e-6
                 - y * 0.934945152e-7)));
        return sqrt(0.636619772 / x) * (cos(xx) * pv - z * sin(xx) * qv);
    }
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = bessel_j0(input[idx]);
}
"#;
154
/// WGSL compute shader for batch `erfc` evaluation.
///
/// Uses the same Abramowitz & Stegun 7.1.26 approximation as [`ERF_WGSL`], but evaluates
/// the tail `erfc(|x|) ≈ t·P(t)·exp(-x²)` (with `t = 1 / (1 + 0.3275911·|x|)`) directly
/// instead of forming `1 - erf(x)`.  The subtraction form cancels catastrophically in
/// `f32`.  For `x ≳ 3.9`, `erf(x)` rounds to `1.0`, so `1 - erf(x)` collapses to exactly
/// `0`, even though e.g. `erfc(4) ≈ 1.5 × 10⁻⁸` is easily representable.  Negative inputs
/// use `erfc(-x) = 2 - erfc(x)`.
///
/// For `|x| > 10` the true value is at most ≈ 2 × 10⁻⁴⁵ (the bottom of the `f32`
/// subnormal range), so the shader returns the limits directly: `0` for positive `x`
/// and `2` for negative `x`.
pub const ERFC_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

// A&S 7.1.26 tail for ax = |x| >= 0: erfc(ax) ~= t * P(t) * exp(-ax^2),
// with t = 1 / (1 + p * ax). Evaluated directly to avoid computing 1 - erf(ax).
fn erfc_tail(ax: f32) -> f32 {
    let t = 1.0 / (1.0 + 0.3275911 * ax);
    let poly = ((((
          1.061405429 * t
        - 1.453152027) * t
        + 1.421413741) * t
        - 0.284496736) * t
        + 0.254829592) * t;
    return poly * exp(-ax * ax);
}

fn approx_erfc(x: f32) -> f32 {
    // Past |x| = 10 erfc is below f32 resolution; return the limits directly.
    if abs(x) > 10.0 {
        return select(0.0, 2.0, x < 0.0);
    }
    let tail = erfc_tail(abs(x));
    // erfc(-x) = 2 - erfc(x)
    return select(2.0 - tail, tail, x >= 0.0);
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = approx_erfc(input[idx]);
}
"#;
190
/// WGSL compute shader for batch inverse-erf evaluation.
///
/// Uses Winitzki's (2008) closed-form approximation with `a = 0.147`:
/// `erfinv(p) ≈ sign(p) · sqrt( sqrt(c² - ln(1 - p²)/a) - c )`, where
/// `c = 2/(π·a) + ln(1 - p²)/2`.
///
/// Special cases handled in the shader:
/// - `|p| ≥ 1` returns `1 × 10¹⁰` (or `-1 × 10¹⁰` for negative `p`), a large finite
///   stand-in for `±∞`.  Inputs with `|p| > 1`, where `erfinv` is undefined, get the
///   same value rather than NaN.
/// - `p == 0` returns `0` directly.
/// - Both square-root arguments are clamped at `0` so small negative values from `f32`
///   rounding cannot produce NaN.
///
/// NOTE(review): a maximum absolute error of about 5 × 10⁻⁴ for `|p| < 1` has been
/// quoted for this shader.  It is not verified here, and since `erfinv(p)` diverges as
/// `|p| → 1` an absolute bound near the endpoints is doubtful.  Confirm against a CPU
/// reference.
pub const ERFINV_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

const PI_F: f32 = 3.14159265358979323846;
const WINITZKI_A: f32 = 0.147;
const INV_WINITZKI_A: f32 = 6.802721088;  // 1.0 / 0.147

fn approx_erfinv(p: f32) -> f32 {
    let ap = abs(p);
    if ap >= 1.0 {
        // Return signed large value for |p| = 1 boundary
        return select(1e10, -1e10, p < 0.0);
    }
    if p == 0.0 {
        return 0.0;
    }

    let sign_p = select(-1.0f, 1.0f, p >= 0.0);
    // Winitzki (2008): erfinv(p) ≈ sign(p) * sqrt(sqrt(c^2 - ln(1-p^2)/a) - c)
    // where c = 2/(π·a) + ln(1-p^2)/2
    let ln_term = log(1.0 - p * p);
    let two_over_pia = 2.0 / (PI_F * WINITZKI_A);
    let c = two_over_pia + ln_term * 0.5;
    let discriminant = c * c - ln_term * INV_WINITZKI_A;
    // discriminant is always non-negative for |p| < 1
    let inner = sqrt(max(discriminant, 0.0)) - c;
    return sign_p * sqrt(max(inner, 0.0));
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = approx_erfinv(input[idx]);
}
"#;
233
/// WGSL compute shader for batch log-gamma evaluation.
///
/// Uses the same Lanczos g=7 coefficients as [`GAMMA_WGSL`], but evaluates
/// `0.5·ln(2π) + (x + 0.5)·ln(t) - t + ln(s)` directly in log space.  Large `x`
/// therefore does not overflow the way `Γ(x)` itself does in `f32`.
///
/// For `x < 0.5` the shader applies the reflection formula
/// `ln|Γ(x)| = ln(π/|sin(πx)|) - ln Γ(1-x)`.  Because of the `abs`, negative inputs
/// yield `ln|Γ(x)|` and the sign of `Γ(x)` is not reported.
///
/// Inside the shader, `log_sign` holds `ln(π/|sin(πx)|)`, not a sign.  The value `0.0`
/// marks "no reflection".  That is safe because `π/|sin(πx)| ≥ π`, so the logarithm is
/// always positive when reflection happens.
///
/// NOTE(review): at non-positive integers `sin(πx)` is only approximately zero in `f32`,
/// so the result is a large finite value rather than `+∞`.  Confirm callers handle poles.
pub const LGAMMA_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

const PI: f32 = 3.14159265358979323846;

fn lanczos_lgamma(x_in: f32) -> f32 {
    var x = x_in;
    var log_sign: f32 = 0.0;
    if x < 0.5 {
        log_sign = log(PI / abs(sin(PI * x)));
        x = 1.0 - x;
    }
    let g: f32 = 7.0;
    x = x - 1.0;
    let c0: f32 =  0.99999999999980993;
    let c1: f32 =  676.5203681218851;
    let c2: f32 = -1259.1392167224028;
    let c3: f32 =  771.32342877765313;
    let c4: f32 = -176.61502916214059;
    let c5: f32 =  12.507343278686905;
    let c6: f32 = -0.13857109526572012;
    let c7: f32 =  9.9843695780195716e-6;
    let c8: f32 =  1.5056327351493116e-7;
    let s = c0 + c1/(x+1.0) + c2/(x+2.0) + c3/(x+3.0) + c4/(x+4.0)
              + c5/(x+5.0) + c6/(x+6.0) + c7/(x+7.0) + c8/(x+8.0);
    let t = x + g + 0.5;
    let lgamma = 0.5 * log(2.0 * PI) + (x + 0.5) * log(t) - t + log(s);
    if log_sign != 0.0 {
        return log_sign - lgamma;
    }
    return lgamma;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = lanczos_lgamma(input[idx]);
}
"#;
280
281// ---------------------------------------------------------------------------
282// Dispatch error
283// ---------------------------------------------------------------------------
284
/// Error type for WGSL/WebGPU dispatch.
///
/// Returned by every `*_batch_wgpu` function in this module.
/// [`WgslDispatchError::GpuNotAvailable`] signals that the caller should fall back to
/// a CPU implementation.  [`WgslDispatchError::RuntimeError`] reports a failure that
/// happened after a GPU path was chosen.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WgslDispatchError {
    /// No GPU path is available.  Either the crate was built without the `wgpu_kernels`
    /// feature, or no wgpu adapter could be acquired at runtime (e.g. a headless host).
    GpuNotAvailable,
    /// The wgpu device setup or execution failed.  The payload describes the
    /// underlying error.
    RuntimeError(String),
}

impl std::fmt::Display for WgslDispatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WgslDispatchError::GpuNotAvailable => write!(f, "wgpu GPU device not available"),
            WgslDispatchError::RuntimeError(msg) => write!(f, "wgpu runtime error: {msg}"),
        }
    }
}

/// `WgslDispatchError` wraps no underlying error source; its message comes from the
/// `Display` impl.  This impl lets it be boxed as `Box<dyn std::error::Error>` and
/// propagated with `?` by callers that use trait-object errors.
impl std::error::Error for WgslDispatchError {}
306
307// ---------------------------------------------------------------------------
308// Real GPU dispatch (wgpu_kernels feature)
309// ---------------------------------------------------------------------------
310
/// Invocations per workgroup; must match the `@workgroup_size(64)` declared by every
/// shader in this module.
#[cfg(feature = "wgpu_kernels")]
const WGSL_WORKGROUP_SIZE: usize = 64;

/// Upload `xs_f32` to the GPU, run `shader_src` over it, and return the resulting `f32`s.
///
/// This is the shared implementation behind every `*_batch_wgpu` function in this module.
/// The shader is expected to declare a `main` entry point and exactly two bindings:
///   - `@group(0) @binding(0)` — read-only input `array<f32>`
///   - `@group(0) @binding(1)` — read-write output `array<f32>`
///
/// The shaders index only along `global_invocation_id.x` with a workgroup size of 64.
/// A single dispatch therefore covers at most `max_compute_workgroups_per_dimension × 64`
/// elements.  Larger inputs are split into chunks that respect that limit and the
/// storage-buffer binding size, and the chunks are dispatched one after another.
/// Results are returned in input order.
///
/// A new instance, adapter and device are created on every call.
///
/// # Errors
///
/// - [`WgslDispatchError::GpuNotAvailable`] if no wgpu adapter is found.
/// - [`WgslDispatchError::RuntimeError`] if device creation, device polling, or
///   staging-buffer mapping fails.
///
/// NOTE(review): wgpu validation failures (e.g. a shader that fails to compile) appear to
/// go to the device's uncaptured-error handler rather than being returned here.  Confirm
/// whether they should be captured with an error scope.
#[cfg(feature = "wgpu_kernels")]
fn dispatch_unary_f32(shader_src: &str, xs_f32: &[f32]) -> Result<Vec<f32>, WgslDispatchError> {
    if xs_f32.is_empty() {
        return Ok(Vec::new());
    }

    let (device, queue) = acquire_wgpu_device()?;
    let (pipeline, bind_group_layout) = build_unary_pipeline(&device, shader_src);

    // One dispatch covers at most `max_compute_workgroups_per_dimension * 64` elements,
    // and each chunk's buffers must fit in one storage-buffer binding (4 bytes per f32).
    let limits = device.limits();
    let per_dispatch =
        limits.max_compute_workgroups_per_dimension as usize * WGSL_WORKGROUP_SIZE;
    let per_binding = limits.max_storage_buffer_binding_size as usize / 4;
    let chunk_len = per_dispatch.min(per_binding).max(1);

    let mut results = Vec::with_capacity(xs_f32.len());
    for chunk in xs_f32.chunks(chunk_len) {
        run_unary_chunk(&device, &queue, &pipeline, &bind_group_layout, chunk, &mut results)?;
    }
    Ok(results)
}

/// Create a wgpu instance and request a high-performance adapter plus device/queue.
#[cfg(feature = "wgpu_kernels")]
fn acquire_wgpu_device() -> Result<(wgpu::Device, wgpu::Queue), WgslDispatchError> {
    use wgpu::{
        Backends, DeviceDescriptor, Features, Instance, InstanceDescriptor, Limits,
        PowerPreference, RequestAdapterOptions,
    };

    let instance = Instance::new(InstanceDescriptor {
        backends: Backends::all(),
        flags: wgpu::InstanceFlags::default(),
        memory_budget_thresholds: Default::default(),
        backend_options: Default::default(),
        display: None,
    });

    let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
        power_preference: PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    }))
    .map_err(|_| WgslDispatchError::GpuNotAvailable)?;

    pollster::block_on(adapter.request_device(&DeviceDescriptor {
        label: Some("scirs2-special"),
        required_features: Features::empty(),
        required_limits: Limits::default(),
        ..Default::default()
    }))
    .map_err(|e| WgslDispatchError::RuntimeError(e.to_string()))
}

/// Compile `shader_src` and build the compute pipeline plus its two-binding layout.
#[cfg(feature = "wgpu_kernels")]
fn build_unary_pipeline(
    device: &wgpu::Device,
    shader_src: &str,
) -> (wgpu::ComputePipeline, wgpu::BindGroupLayout) {
    use wgpu::{
        BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
        ShaderModuleDescriptor, ShaderSource, ShaderStages,
    };

    let shader_module = device.create_shader_module(ShaderModuleDescriptor {
        label: Some("scirs2-special-shader"),
        source: ShaderSource::Wgsl(shader_src.into()),
    });

    // Both bindings are compute-visible storage buffers; only their access mode differs.
    let storage_entry = |binding: u32, read_only: bool| BindGroupLayoutEntry {
        binding,
        visibility: ShaderStages::COMPUTE,
        ty: BindingType::Buffer {
            ty: BufferBindingType::Storage { read_only },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    };

    let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("scirs2-special-bgl"),
        entries: &[storage_entry(0, true), storage_entry(1, false)],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("scirs2-special-layout"),
        bind_group_layouts: &[Some(&bind_group_layout)],
        ..Default::default()
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("scirs2-special-pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader_module,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });

    (pipeline, bind_group_layout)
}

/// Run the pipeline over one chunk that fits a single dispatch, appending its outputs
/// to `out`.
#[cfg(feature = "wgpu_kernels")]
fn run_unary_chunk(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    pipeline: &wgpu::ComputePipeline,
    layout: &wgpu::BindGroupLayout,
    chunk: &[f32],
    out: &mut Vec<f32>,
) -> Result<(), WgslDispatchError> {
    use wgpu::{
        util::BufferInitDescriptor, util::DeviceExt as _, BindGroupDescriptor, BindGroupEntry,
        BufferDescriptor, BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor,
        MapMode,
    };

    // Encode the f32 slice to raw little-endian bytes manually (avoids a bytemuck dependency).
    let input_bytes: Vec<u8> = chunk.iter().flat_map(|v| v.to_le_bytes()).collect();
    let byte_len = input_bytes.len() as u64;

    // The caller sizes chunks to the device's workgroup limit, so this cannot overflow u32.
    let workgroups = u32::try_from(chunk.len().div_ceil(WGSL_WORKGROUP_SIZE)).map_err(|_| {
        WgslDispatchError::RuntimeError(format!(
            "chunk of {} elements exceeds the u32 workgroup count",
            chunk.len()
        ))
    })?;

    let buf_input = device.create_buffer_init(&BufferInitDescriptor {
        label: Some("scirs2-special-input"),
        contents: &input_bytes,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
    });

    let buf_output = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-special-output"),
        size: byte_len,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    // Storage buffers cannot be mapped directly; results are copied into this one for readback.
    let buf_staging = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-special-staging"),
        size: byte_len,
        usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let bind_group = device.create_bind_group(&BindGroupDescriptor {
        label: Some("scirs2-special-bg"),
        layout,
        entries: &[
            BindGroupEntry {
                binding: 0,
                resource: buf_input.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 1,
                resource: buf_output.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor {
        label: Some("scirs2-special-encoder"),
    });
    {
        let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
            label: Some("scirs2-special-pass"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.dispatch_workgroups(workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&buf_output, 0, &buf_staging, 0, byte_len);
    queue.submit(Some(encoder.finish()));

    // Request the mapping, then block until the submitted work and the map have completed.
    let slice = buf_staging.slice(0..byte_len);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(MapMode::Read, move |r| {
        let _ = tx.send(r);
    });

    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| WgslDispatchError::RuntimeError(format!("GPU poll during map: {e:?}")))?;

    rx.recv()
        .map_err(|_| WgslDispatchError::RuntimeError("channel closed in map_async".into()))?
        .map_err(|e| WgslDispatchError::RuntimeError(format!("map_async failed: {e:?}")))?;

    {
        let mapped = slice.get_mapped_range();
        out.extend(
            mapped
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
        );
    }
    buf_staging.unmap();

    Ok(())
}
492
493// ---------------------------------------------------------------------------
494// Host-side dispatch functions — real (wgpu_kernels) path
495// ---------------------------------------------------------------------------
496
/// Narrow `xs` to `f32`, run `shader_src` through [`dispatch_unary_f32`], and widen the
/// results back to `f64`.
///
/// Shared by every `*_batch_wgpu` function below so the conversion policy lives in one
/// place.  `f64 → f32` narrowing rounds to nearest and saturates to `±inf` for magnitudes
/// beyond `f32::MAX`.  Every result carries only `f32` precision.
#[cfg(feature = "wgpu_kernels")]
fn dispatch_unary_f64_via_f32(
    shader_src: &str,
    xs: &[f64],
) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&x| x as f32).collect();
    let widened = dispatch_unary_f32(shader_src, &narrowed)?
        .into_iter()
        .map(f64::from)
        .collect();
    Ok(widened)
}

/// Evaluate `Γ(x)` for every element of `xs` on a WebGPU device.
///
/// Uploads `xs` as `f32`, dispatches the Lanczos shader [`GAMMA_WGSL`], and returns the
/// results widened to `f64`.
///
/// # Errors
///
/// [`WgslDispatchError::GpuNotAvailable`] if no wgpu adapter is found;
/// [`WgslDispatchError::RuntimeError`] if the GPU work fails.
#[cfg(feature = "wgpu_kernels")]
pub fn gamma_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    dispatch_unary_f64_via_f32(GAMMA_WGSL, xs)
}

/// Evaluate `erf(x)` for every element of `xs` on a WebGPU device.
///
/// Uploads `xs` as `f32`, dispatches the A&S 7.1.26 shader [`ERF_WGSL`], and returns the
/// results widened to `f64`.
///
/// # Errors
///
/// [`WgslDispatchError::GpuNotAvailable`] if no wgpu adapter is found;
/// [`WgslDispatchError::RuntimeError`] if the GPU work fails.
#[cfg(feature = "wgpu_kernels")]
pub fn erf_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    dispatch_unary_f64_via_f32(ERF_WGSL, xs)
}

/// Evaluate the Bessel function `J₀(x)` for every element of `xs` on a WebGPU device.
///
/// Uploads `xs` as `f32`, dispatches [`BESSEL_J0_WGSL`], and returns the results widened
/// to `f64`.  That shader uses a rational approximation for `|x| < 8` and an asymptotic
/// form beyond.
///
/// # Errors
///
/// [`WgslDispatchError::GpuNotAvailable`] if no wgpu adapter is found;
/// [`WgslDispatchError::RuntimeError`] if the GPU work fails.
#[cfg(feature = "wgpu_kernels")]
pub fn bessel_j0_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    dispatch_unary_f64_via_f32(BESSEL_J0_WGSL, xs)
}

/// Evaluate `ln|Γ(x)|` for every element of `xs` on a WebGPU device.
///
/// Uploads `xs` as `f32`, dispatches the Lanczos log-gamma shader [`LGAMMA_WGSL`], and
/// returns the results widened to `f64`.  For negative inputs the sign of `Γ(x)` is not
/// reported.
///
/// # Errors
///
/// [`WgslDispatchError::GpuNotAvailable`] if no wgpu adapter is found;
/// [`WgslDispatchError::RuntimeError`] if the GPU work fails.
#[cfg(feature = "wgpu_kernels")]
pub fn lgamma_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    dispatch_unary_f64_via_f32(LGAMMA_WGSL, xs)
}

/// Evaluate `erfc(x)` for every element of `xs` on a WebGPU device.
///
/// Uploads `xs` as `f32`, dispatches the A&S 7.1.26-based shader [`ERFC_WGSL`], and
/// returns the results widened to `f64`.
///
/// # Errors
///
/// [`WgslDispatchError::GpuNotAvailable`] if no wgpu adapter is found;
/// [`WgslDispatchError::RuntimeError`] if the GPU work fails.
#[cfg(feature = "wgpu_kernels")]
pub fn erfc_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    dispatch_unary_f64_via_f32(ERFC_WGSL, xs)
}

/// Evaluate `erfinv(p)` for every element of `xs` on a WebGPU device.
///
/// Uploads `xs` as `f32`, dispatches the Winitzki shader [`ERFINV_WGSL`], and returns the
/// results widened to `f64`.  Inputs with `|p| ≥ 1` come back as `±1 × 10¹⁰`.
///
/// # Errors
///
/// [`WgslDispatchError::GpuNotAvailable`] if no wgpu adapter is found;
/// [`WgslDispatchError::RuntimeError`] if the GPU work fails.
#[cfg(feature = "wgpu_kernels")]
pub fn erfinv_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    dispatch_unary_f64_via_f32(ERFINV_WGSL, xs)
}
565
566// ---------------------------------------------------------------------------
567// Host-side dispatch functions — stub (no wgpu_kernels) path
568// ---------------------------------------------------------------------------
569
570/// Stub: returns [`WgslDispatchError::GpuNotAvailable`] when `wgpu_kernels` is off.
571#[cfg(not(feature = "wgpu_kernels"))]
572pub fn gamma_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
573    Err(WgslDispatchError::GpuNotAvailable)
574}
575
576/// Stub: returns [`WgslDispatchError::GpuNotAvailable`] when `wgpu_kernels` is off.
577#[cfg(not(feature = "wgpu_kernels"))]
578pub fn erf_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
579    Err(WgslDispatchError::GpuNotAvailable)
580}
581
582/// Stub: returns [`WgslDispatchError::GpuNotAvailable`] when `wgpu_kernels` is off.
583#[cfg(not(feature = "wgpu_kernels"))]
584pub fn bessel_j0_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
585    Err(WgslDispatchError::GpuNotAvailable)
586}
587
588/// Stub: returns [`WgslDispatchError::GpuNotAvailable`] when `wgpu_kernels` is off.
589#[cfg(not(feature = "wgpu_kernels"))]
590pub fn lgamma_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
591    Err(WgslDispatchError::GpuNotAvailable)
592}
593
594/// Stub: returns [`WgslDispatchError::GpuNotAvailable`] when `wgpu_kernels` is off.
595#[cfg(not(feature = "wgpu_kernels"))]
596pub fn erfc_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
597    Err(WgslDispatchError::GpuNotAvailable)
598}
599
600/// Stub: returns [`WgslDispatchError::GpuNotAvailable`] when `wgpu_kernels` is off.
601#[cfg(not(feature = "wgpu_kernels"))]
602pub fn erfinv_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
603    Err(WgslDispatchError::GpuNotAvailable)
604}
605
606// ---------------------------------------------------------------------------
607// Tests
608// ---------------------------------------------------------------------------
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature shared by every batch dispatch entry point.
    type BatchFn = fn(&[f64]) -> Result<Vec<f64>, WgslDispatchError>;

    /// Structural checks every shader in this module must satisfy: the two storage
    /// bindings, a 64-wide compute `main`, the named helper, and balanced delimiters.
    fn assert_shader_shape(src: &str, helper_fn: &str) {
        assert!(!src.is_empty());
        assert!(src.contains("@compute"));
        assert!(src.contains("workgroup_size(64)"));
        assert!(src.contains("@group(0) @binding(0) var<storage, read> input: array<f32>;"));
        assert!(src
            .contains("@group(0) @binding(1) var<storage, read_write> output: array<f32>;"));
        assert!(src.contains("fn main("));
        assert!(src.contains(helper_fn), "missing helper `{helper_fn}`");
        assert_eq!(src.matches('{').count(), src.matches('}').count(), "unbalanced braces");
        assert_eq!(src.matches('(').count(), src.matches(')').count(), "unbalanced parens");
    }

    /// Accept `GpuNotAvailable` (feature off, or no adapter on this host).  When a GPU
    /// actually ran the kernel, require each value to be within `tol` of `expected`.
    fn assert_batch_close(
        result: Result<Vec<f64>, WgslDispatchError>,
        expected: &[f64],
        tol: f64,
    ) {
        match result {
            Ok(values) => {
                assert_eq!(values.len(), expected.len(), "result length mismatch");
                for (i, (&got, &want)) in values.iter().zip(expected).enumerate() {
                    assert!(
                        (got - want).abs() <= tol,
                        "element {i}: got {got}, expected {want} (tol {tol})"
                    );
                }
            }
            Err(WgslDispatchError::GpuNotAvailable) => {}
            Err(e) => panic!("unexpected error: {e}"),
        }
    }

    #[test]
    fn test_gamma_wgsl_source_shape() {
        assert_shader_shape(GAMMA_WGSL, "fn lanczos_gamma");
    }

    #[test]
    fn test_erf_wgsl_source_shape() {
        assert_shader_shape(ERF_WGSL, "fn approx_erf");
    }

    #[test]
    fn test_bessel_j0_wgsl_source_shape() {
        assert_shader_shape(BESSEL_J0_WGSL, "fn bessel_j0");
    }

    #[test]
    fn test_lgamma_wgsl_source_shape() {
        assert_shader_shape(LGAMMA_WGSL, "fn lanczos_lgamma");
    }

    #[test]
    fn test_erfc_wgsl_source_shape() {
        assert_shader_shape(ERFC_WGSL, "fn approx_erfc");
    }

    #[test]
    fn test_erfinv_wgsl_source_shape() {
        assert_shader_shape(ERFINV_WGSL, "fn approx_erfinv");
    }

    #[test]
    fn test_gamma_batch_wgpu_values_or_not_available() {
        // Γ(1) = Γ(2) = 1, Γ(3) = 2.
        assert_batch_close(gamma_batch_wgpu(&[1.0, 2.0, 3.0]), &[1.0, 1.0, 2.0], 1e-3);
    }

    #[test]
    fn test_erf_batch_wgpu_values_or_not_available() {
        assert_batch_close(erf_batch_wgpu(&[0.0, 1.0]), &[0.0, 0.842_700_792_949_715], 1e-4);
    }

    #[test]
    fn test_bessel_j0_batch_wgpu_values_or_not_available() {
        // 2.405 sits just past the first zero of J₀ (≈ 2.404 825 6).
        assert_batch_close(bessel_j0_batch_wgpu(&[0.0, 2.405]), &[1.0, -9.056e-5], 1e-3);
    }

    #[test]
    fn test_lgamma_batch_wgpu_values_or_not_available() {
        assert_batch_close(
            lgamma_batch_wgpu(&[1.0, 2.0, 3.0]),
            &[0.0, 0.0, std::f64::consts::LN_2],
            1e-3,
        );
    }

    #[test]
    fn test_erfc_batch_wgpu_values_or_not_available() {
        assert_batch_close(
            erfc_batch_wgpu(&[0.0, 1.0, -1.0]),
            &[1.0, 0.157_299_207_050_285, 1.842_700_792_949_715],
            1e-4,
        );
    }

    #[test]
    fn test_erfinv_batch_wgpu_values_or_not_available() {
        assert_batch_close(
            erfinv_batch_wgpu(&[0.0, 0.5, -0.5]),
            &[0.0, 0.476_936_276_204_470, -0.476_936_276_204_470],
            1e-3,
        );
    }

    #[test]
    fn test_batch_wgpu_empty_input() {
        // Empty input must yield an empty result (or GpuNotAvailable), never an error.
        let kernels: [(&str, BatchFn); 6] = [
            ("gamma", gamma_batch_wgpu),
            ("erf", erf_batch_wgpu),
            ("bessel_j0", bessel_j0_batch_wgpu),
            ("lgamma", lgamma_batch_wgpu),
            ("erfc", erfc_batch_wgpu),
            ("erfinv", erfinv_batch_wgpu),
        ];
        for (name, kernel) in kernels {
            match kernel(&[]) {
                Ok(values) => assert!(values.is_empty(), "{name}: expected empty output"),
                Err(WgslDispatchError::GpuNotAvailable) => {}
                Err(e) => panic!("{name}: unexpected error: {e}"),
            }
        }
    }

    #[test]
    fn test_wgsl_dispatch_error_display() {
        let e = WgslDispatchError::GpuNotAvailable;
        assert!(e.to_string().contains("not available"));
        let e2 = WgslDispatchError::RuntimeError("buffer overflow".into());
        assert!(e2.to_string().contains("buffer overflow"));
    }
}