/// WGSL compute shader evaluating the gamma function `Γ(x)` element-wise on
/// `array<f32>` data.
///
/// Bindings: `@group(0) @binding(0)` is the read-only input array and
/// `@binding(1)` the read-write output array; entry point `main` runs with
/// `@workgroup_size(64)`, one invocation per element (invocations past
/// `arrayLength(&input)` return early).
///
/// Uses the Lanczos approximation (g = 7, 9 coefficients) for `x >= 0.5` and the
/// reflection formula `Γ(x) = π / (sin(πx) Γ(1 - x))` below that. The factor
/// `t^(x + 0.5)` is evaluated as the product of two half-powers, each scaled
/// separately, so the intermediate does not overflow f32 for arguments whose Γ
/// is still representable (roughly 26 <= x <= 35).
pub const GAMMA_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

const PI: f32 = 3.14159265358979323846;

// Lanczos approximation with g = 7 and 9 coefficients.
fn lanczos_gamma(x_in: f32) -> f32 {
    var x = x_in;
    // For x < 0.5 evaluate Γ(1 - x) and apply the reflection formula at the end.
    let reflected = x_in < 0.5;
    var reflection_factor: f32 = 1.0;
    if reflected {
        reflection_factor = PI / sin(PI * x_in);
        x = 1.0 - x_in;
    }
    let g: f32 = 7.0;
    x = x - 1.0;

    let c0: f32 = 0.99999999999980993;
    let c1: f32 = 676.5203681218851;
    let c2: f32 = -1259.1392167224028;
    let c3: f32 = 771.32342877765313;
    let c4: f32 = -176.61502916214059;
    let c5: f32 = 12.507343278686905;
    let c6: f32 = -0.13857109526572012;
    let c7: f32 = 9.9843695780195716e-6;
    let c8: f32 = 1.5056327351493116e-7;

    let s = c0
        + c1 / (x + 1.0)
        + c2 / (x + 2.0)
        + c3 / (x + 3.0)
        + c4 / (x + 4.0)
        + c5 / (x + 5.0)
        + c6 / (x + 6.0)
        + c7 / (x + 7.0)
        + c8 / (x + 8.0);

    let t = x + g + 0.5;
    // t^(x+0.5) alone overflows f32 long before Γ does. Split it into two
    // half-powers and damp one of them by exp(-t) first so every intermediate
    // stays in range.
    let half_pow = pow(t, 0.5 * (x + 0.5));
    let result = sqrt(2.0 * PI) * half_pow * (half_pow * exp(-t)) * s;
    if reflected {
        return reflection_factor / result;
    }
    return result;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = lanczos_gamma(input[idx]);
}
"#;
79
/// WGSL compute shader evaluating the error function `erf(x)` element-wise on
/// `array<f32>` data.
///
/// Bindings: `@group(0) @binding(0)` is the read-only input array and
/// `@binding(1)` the read-write output array; entry point `main` runs with
/// `@workgroup_size(64)`, one invocation per element (invocations past
/// `arrayLength(&input)` return early).
///
/// `approx_erf` evaluates `1 - t·P(t)·exp(-x²)` with `t = 1 / (1 + 0.3275911·|x|)`.
/// The coefficients are those of Abramowitz & Stegun 7.1.26 (published maximum
/// absolute error 1.5e-7, before f32 rounding). Odd symmetry is restored with
/// `select` on the sign of `x`.
pub const ERF_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

fn approx_erf(x: f32) -> f32 {
 let t = 1.0 / (1.0 + 0.3275911 * abs(x));
 let y = 1.0 - (((((
 1.061405429 * t
 - 1.453152027) * t
 + 1.421413741) * t
 - 0.284496736) * t
 + 0.254829592) * t * exp(-x * x));
 return select(-y, y, x >= 0.0);
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x;
 if idx >= arrayLength(&input) { return; }
 output[idx] = approx_erf(input[idx]);
}
"#;
105
/// WGSL compute shader evaluating the Bessel function of the first kind `J0(x)`
/// element-wise on `array<f32>` data.
///
/// Bindings: `@group(0) @binding(0)` is the read-only input array and
/// `@binding(1)` the read-write output array; entry point `main` runs with
/// `@workgroup_size(64)`, one invocation per element (invocations past
/// `arrayLength(&input)` return early).
///
/// `bessel_j0` works on `|x|`, since J0 is even, and has two regimes:
/// - for `|x| < 8`, a rational approximation `P(x²)/Q(x²)`;
/// - otherwise, the asymptotic form `sqrt(2/(π|x|))·(cos(ξ)·P0 − z·sin(ξ)·Q0)`
///   with `z = 8/|x|` and `ξ = |x| − π/4`. The literals `0.636619772 ≈ 2/π` and
///   `0.785398164 ≈ π/4` supply these constants.
///
/// The coefficients match Numerical Recipes' `bessj0`. Near the zeros of J0 the
/// f32 evaluation of the large-coefficient rational form limits relative accuracy.
/// NOTE(review): the `PI` constant in the shader is declared but not referenced.
pub const BESSEL_J0_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

const PI: f32 = 3.14159265358979323846;

fn bessel_j0(x_in: f32) -> f32 {
 let x = abs(x_in);
 if x < 8.0 {
 let y = x * x;
 let p1: f32 = 57568490574.0;
 let p2: f32 = -13362590354.0;
 let p3: f32 = 651619640.7;
 let p4: f32 = -11214424.18;
 let p5: f32 = 77392.33017;
 let p6: f32 = -184.9052456;
 let q1: f32 = 57568490411.0;
 let q2: f32 = 1029532985.0;
 let q3: f32 = 9494680.718;
 let q4: f32 = 59272.64853;
 let q5: f32 = 267.8532712;
 let p = p1 + y * (p2 + y * (p3 + y * (p4 + y * (p5 + y * p6))));
 let q = q1 + y * (q2 + y * (q3 + y * (q4 + y * (q5 + y))));
 return p / q;
 } else {
 let z = 8.0 / x;
 let y = z * z;
 let xx = x - 0.785398164;
 let pv = 1.0 + y * (-0.1098628627e-2 + y * (0.2734510407e-4
 + y * (-0.2073370639e-5 + y * 0.2093887211e-6)));
 let qv = -0.1562499995e-1 + y * (0.1430488765e-3
 + y * (-0.6911147651e-5 + y * (0.7621095161e-6
 - y * 0.934945152e-7)));
 return sqrt(0.636619772 / x) * (cos(xx) * pv - z * sin(xx) * qv);
 }
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x;
 if idx >= arrayLength(&input) { return; }
 output[idx] = bessel_j0(input[idx]);
}
"#;
154
/// WGSL compute shader evaluating the complementary error function `erfc(x)`
/// element-wise on `array<f32>` data.
///
/// Bindings: `@group(0) @binding(0)` is the read-only input array and
/// `@binding(1)` the read-write output array; entry point `main` runs with
/// `@workgroup_size(64)`, one invocation per element (invocations past
/// `arrayLength(&input)` return early).
///
/// `approx_erfc` evaluates the Abramowitz & Stegun 7.1.26 tail
/// `erfc(|x|) ≈ t·P(t)·exp(-x²)` directly instead of computing `1 - erf(x)`.
/// The subtraction form cancels catastrophically in f32 once `erf(x)` is close
/// to 1 and returns 0 for `x ≳ 4`. Negative arguments use `erfc(x) = 2 - erfc(-x)`.
/// For `|x| > 6` the result saturates to `0` (x > 6) or `2` (x < -6).
pub const ERFC_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

fn approx_erfc(x: f32) -> f32 {
    // For |x| > 6, erfc(|x|) < 3e-17, far below the ~1.5e-7 absolute accuracy
    // of this approximation, so return the limits 0 (x > 6) / 2 (x < -6).
    if abs(x) > 6.0 {
        return select(0.0, 2.0, x < 0.0);
    }
    let ax = abs(x);
    let t = 1.0 / (1.0 + 0.3275911 * ax);
    // A&S 7.1.26: erfc(|x|) ~= t * P(t) * exp(-x^2), evaluated without the
    // cancelling `1 - erf(x)` subtraction.
    let poly = ((((1.061405429 * t
        - 1.453152027) * t
        + 1.421413741) * t
        - 0.284496736) * t
        + 0.254829592) * t;
    let tail = poly * exp(-ax * ax);
    // erfc(x) = 2 - erfc(-x) for negative arguments.
    return select(2.0 - tail, tail, x >= 0.0);
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = approx_erfc(input[idx]);
}
"#;
190
/// WGSL compute shader evaluating the inverse error function `erfinv(p)`
/// element-wise on `array<f32>` data.
///
/// Bindings: `@group(0) @binding(0)` is the read-only input array and
/// `@binding(1)` the read-write output array; entry point `main` runs with
/// `@workgroup_size(64)`, one invocation per element (invocations past
/// `arrayLength(&input)` return early).
///
/// `approx_erfinv` uses M. Giles' single-precision approximation
/// ("Approximating the erfinv function", GPU Computing Gems, 2011):
/// `erfinv(p) ≈ p · P(w)` with `w = -ln((1 - p)(1 + p))`.
///
/// Because the result is a multiple of `p`, it keeps full relative accuracy for
/// small `|p|`. The previous Winitzki form returned 0 or noise there, because
/// `1 - p²` rounds to 1 in f32 and `sqrt(c² + d) - c` cancels.
///
/// Boundary behaviour is unchanged:
/// - `|p| >= 1` returns the sentinel `±1e10`;
/// - `p == 0` (either sign) returns `+0`;
/// - NaN input propagates as NaN.
pub const ERFINV_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

fn approx_erfinv(p: f32) -> f32 {
    if abs(p) >= 1.0 {
        // Signed large sentinel for the |p| >= 1 boundary (erfinv(+-1) = +-inf).
        return select(1e10, -1e10, p < 0.0);
    }
    if p == 0.0 {
        return 0.0;
    }

    // (1 - p)(1 + p) is more accurate than 1 - p*p near |p| = 1. For tiny |p|
    // it rounds to 1 (w = 0), which is harmless because the result is q * p.
    var w = -log((1.0 - p) * (1.0 + p));
    var q: f32;
    if w < 5.0 {
        // Central region, |p| < ~0.9966.
        w = w - 2.5;
        q = 2.81022636e-08;
        q = 3.43273939e-07 + q * w;
        q = -3.5233877e-06 + q * w;
        q = -4.39150654e-06 + q * w;
        q = 0.00021858087 + q * w;
        q = -0.00125372503 + q * w;
        q = -0.00417768164 + q * w;
        q = 0.246640727 + q * w;
        q = 1.50140941 + q * w;
    } else {
        // Tail region, |p| close to 1.
        w = sqrt(w) - 3.0;
        q = -0.000200214257;
        q = 0.000100950558 + q * w;
        q = 0.00134934322 + q * w;
        q = -0.00367342844 + q * w;
        q = 0.00573950773 + q * w;
        q = -0.0076224613 + q * w;
        q = 0.00943887047 + q * w;
        q = 1.00167406 + q * w;
        q = 2.83297682 + q * w;
    }
    return q * p;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if idx >= arrayLength(&input) { return; }
    output[idx] = approx_erfinv(input[idx]);
}
"#;
233
/// WGSL compute shader evaluating the log-gamma function element-wise on
/// `array<f32>` data.
///
/// Bindings: `@group(0) @binding(0)` is the read-only input array and
/// `@binding(1)` the read-write output array; entry point `main` runs with
/// `@workgroup_size(64)`, one invocation per element (invocations past
/// `arrayLength(&input)` return early).
///
/// `lanczos_lgamma` evaluates the Lanczos approximation (g = 7, 9 coefficients)
/// directly in log space, so it does not overflow where `Γ(x)` itself would.
/// For `x < 0.5` it applies the log form of the reflection formula,
/// `ln(π / |sin(πx)|) − lgamma(1 − x)`.
///
/// Because of the `abs(sin(..))`, the result is `ln|Γ(x)|`; the sign of `Γ(x)` is
/// not reported. The `log_sign != 0.0` check is always true on the reflection
/// path, since `π / |sin(πx)| >= π > 1`.
pub const LGAMMA_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

const PI: f32 = 3.14159265358979323846;

fn lanczos_lgamma(x_in: f32) -> f32 {
 var x = x_in;
 var log_sign: f32 = 0.0;
 if x < 0.5 {
 log_sign = log(PI / abs(sin(PI * x)));
 x = 1.0 - x;
 }
 let g: f32 = 7.0;
 x = x - 1.0;
 let c0: f32 = 0.99999999999980993;
 let c1: f32 = 676.5203681218851;
 let c2: f32 = -1259.1392167224028;
 let c3: f32 = 771.32342877765313;
 let c4: f32 = -176.61502916214059;
 let c5: f32 = 12.507343278686905;
 let c6: f32 = -0.13857109526572012;
 let c7: f32 = 9.9843695780195716e-6;
 let c8: f32 = 1.5056327351493116e-7;
 let s = c0 + c1/(x+1.0) + c2/(x+2.0) + c3/(x+3.0) + c4/(x+4.0)
 + c5/(x+5.0) + c6/(x+6.0) + c7/(x+7.0) + c8/(x+8.0);
 let t = x + g + 0.5;
 let lgamma = 0.5 * log(2.0 * PI) + (x + 0.5) * log(t) - t + log(s);
 if log_sign != 0.0 {
 return log_sign - lgamma;
 }
 return lgamma;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x;
 if idx >= arrayLength(&input) { return; }
 output[idx] = lanczos_lgamma(input[idx]);
}
"#;
280
/// Errors returned by the `*_batch_wgpu` GPU dispatch functions.
///
/// Implements [`std::error::Error`], so it composes with `?` into
/// `Box<dyn Error>` and other error wrappers. It also implements `PartialEq`/`Eq`,
/// so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WgslDispatchError {
    /// No GPU adapter could be obtained, or the crate was built without the
    /// `wgpu_kernels` feature.
    GpuNotAvailable,
    /// A wgpu operation failed after an adapter was found (device request,
    /// polling, or buffer mapping); the payload describes the failure.
    RuntimeError(String),
}

impl std::fmt::Display for WgslDispatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WgslDispatchError::GpuNotAvailable => {
                write!(f, "wgpu GPU device not available")
            }
            WgslDispatchError::RuntimeError(msg) => {
                write!(f, "wgpu runtime error: {msg}")
            }
        }
    }
}

// No underlying source error is retained, so the default `source()` (None) applies.
impl std::error::Error for WgslDispatchError {}
306
#[cfg(feature = "wgpu_kernels")]
/// Threads per workgroup declared by every `@compute @workgroup_size(64)` shader
/// in this module; dispatch sizes are derived from it.
const UNARY_WORKGROUP_SIZE: usize = 64;

#[cfg(feature = "wgpu_kernels")]
/// Run a unary element-wise WGSL compute shader over `xs_f32` on the GPU and
/// return the per-element results in input order.
///
/// `shader_src` must declare:
/// - a read-only `array<f32>` at `@group(0) @binding(0)`;
/// - a read-write `array<f32>` at `@binding(1)`;
/// - a `main` entry point with `@workgroup_size(64)` that bounds-checks against
///   `arrayLength(&input)`.
///
/// A new instance, adapter and device are acquired on every call. Inputs longer
/// than one dispatch can cover are processed in consecutive chunks instead of
/// issuing an oversized dispatch. The cover is limited by the device's
/// `max_compute_workgroups_per_dimension` × 64 threads and by its storage-binding
/// and buffer size limits.
///
/// # Errors
/// - [`WgslDispatchError::GpuNotAvailable`] if no adapter can be obtained.
/// - [`WgslDispatchError::RuntimeError`] if device creation, polling or buffer
///   mapping fails.
fn dispatch_unary_f32(shader_src: &str, xs_f32: &[f32]) -> Result<Vec<f32>, WgslDispatchError> {
    use wgpu::{
        Backends, DeviceDescriptor, Features, Instance, InstanceDescriptor, Limits,
        PowerPreference, RequestAdapterOptions,
    };

    if xs_f32.is_empty() {
        return Ok(Vec::new());
    }

    let instance = Instance::new(InstanceDescriptor {
        backends: Backends::all(),
        flags: wgpu::InstanceFlags::default(),
        memory_budget_thresholds: Default::default(),
        backend_options: Default::default(),
        display: None,
    });

    let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
        power_preference: PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    }))
    .map_err(|_| WgslDispatchError::GpuNotAvailable)?;

    let (device, queue) = pollster::block_on(adapter.request_device(&DeviceDescriptor {
        label: Some("scirs2-special"),
        required_features: Features::empty(),
        required_limits: Limits::default(),
        ..Default::default()
    }))
    .map_err(|e| WgslDispatchError::RuntimeError(e.to_string()))?;

    let (pipeline, bgl) = build_unary_pipeline(&device, shader_src);
    let chunk_len = max_unary_chunk_len(&device.limits());

    let mut results = Vec::with_capacity(xs_f32.len());
    for chunk in xs_f32.chunks(chunk_len) {
        results.extend(run_unary_chunk(&device, &queue, &pipeline, &bgl, chunk)?);
    }
    Ok(results)
}

#[cfg(feature = "wgpu_kernels")]
/// Largest number of `f32` elements one dispatch of a unary shader may cover
/// under `limits`. The bound is the smallest of the workgroup count per dimension
/// × 64, the storage binding size, and the maximum buffer size.
fn max_unary_chunk_len(limits: &wgpu::Limits) -> usize {
    let elem = std::mem::size_of::<f32>();
    let by_dispatch = (limits.max_compute_workgroups_per_dimension as usize)
        .saturating_mul(UNARY_WORKGROUP_SIZE);
    let by_binding = limits.max_storage_buffer_binding_size as usize / elem;
    let by_buffer = limits.max_buffer_size as usize / elem;
    // Never return 0: `slice::chunks` panics on a zero chunk size.
    by_dispatch.min(by_binding).min(by_buffer).max(1)
}

#[cfg(feature = "wgpu_kernels")]
/// Compile `shader_src` and build its compute pipeline, together with the
/// two-entry layout (binding 0 read-only storage, binding 1 read-write storage).
fn build_unary_pipeline(
    device: &wgpu::Device,
    shader_src: &str,
) -> (wgpu::ComputePipeline, wgpu::BindGroupLayout) {
    use wgpu::{
        BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
        ShaderModuleDescriptor, ShaderSource, ShaderStages,
    };

    let shader_module = device.create_shader_module(ShaderModuleDescriptor {
        label: Some("scirs2-special-shader"),
        source: ShaderSource::Wgsl(shader_src.into()),
    });

    // One compute-visible storage-buffer layout entry.
    let storage_entry = |binding: u32, read_only: bool| BindGroupLayoutEntry {
        binding,
        visibility: ShaderStages::COMPUTE,
        ty: BindingType::Buffer {
            ty: BufferBindingType::Storage { read_only },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    };

    let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("scirs2-special-bgl"),
        entries: &[storage_entry(0, true), storage_entry(1, false)],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("scirs2-special-layout"),
        bind_group_layouts: &[Some(&bgl)],
        ..Default::default()
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("scirs2-special-pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader_module,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });

    (pipeline, bgl)
}

#[cfg(feature = "wgpu_kernels")]
/// Upload `chunk`, run `pipeline` over it with a single dispatch, and read the
/// results back from a mappable staging buffer.
///
/// `chunk` must be non-empty and no longer than `max_unary_chunk_len` permits.
fn run_unary_chunk(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    pipeline: &wgpu::ComputePipeline,
    bgl: &wgpu::BindGroupLayout,
    chunk: &[f32],
) -> Result<Vec<f32>, WgslDispatchError> {
    use wgpu::{
        util::BufferInitDescriptor, util::DeviceExt as _, BindGroupDescriptor, BindGroupEntry,
        BufferDescriptor, BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor, MapMode,
    };

    let input_bytes: Vec<u8> = chunk.iter().flat_map(|v| v.to_le_bytes()).collect();
    let byte_len = input_bytes.len() as u64;
    // Checked conversion instead of `n as u32 + 63`, which could wrap or overflow.
    let workgroups = u32::try_from(chunk.len().div_ceil(UNARY_WORKGROUP_SIZE)).map_err(|_| {
        WgslDispatchError::RuntimeError("dispatch size exceeds u32 workgroup count".into())
    })?;

    let buf_input = device.create_buffer_init(&BufferInitDescriptor {
        label: Some("scirs2-special-input"),
        contents: &input_bytes,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
    });

    let buf_output = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-special-output"),
        size: byte_len,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let buf_staging = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-special-staging"),
        size: byte_len,
        usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let bind_group = device.create_bind_group(&BindGroupDescriptor {
        label: Some("scirs2-special-bg"),
        layout: bgl,
        entries: &[
            BindGroupEntry {
                binding: 0,
                resource: buf_input.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 1,
                resource: buf_output.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor {
        label: Some("scirs2-special-encoder"),
    });
    {
        let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
            label: Some("scirs2-special-pass"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.dispatch_workgroups(workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&buf_output, 0, &buf_staging, 0, byte_len);
    queue.submit(Some(encoder.finish()));

    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| WgslDispatchError::RuntimeError(format!("GPU poll error: {e:?}")))?;

    let slice = buf_staging.slice(0..byte_len);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(MapMode::Read, move |r| {
        let _ = tx.send(r);
    });

    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| WgslDispatchError::RuntimeError(format!("GPU poll during map: {e:?}")))?;

    rx.recv()
        .map_err(|_| WgslDispatchError::RuntimeError("channel closed in map_async".into()))?
        .map_err(|e| WgslDispatchError::RuntimeError(format!("map_async failed: {e:?}")))?;

    let mapped = slice.get_mapped_range();
    let result: Vec<f32> = mapped
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    drop(mapped);
    buf_staging.unmap();

    Ok(result)
}
492
#[cfg(feature = "wgpu_kernels")]
/// Evaluate the gamma function `Γ(x)` for every element of `xs` on the GPU.
///
/// Each input is narrowed to `f32` before dispatch and each result widened back
/// to `f64`, so accuracy is single precision.
///
/// # Errors
/// Propagates any [`WgslDispatchError`] from the GPU dispatch.
pub fn gamma_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    let widened: Vec<f64> = dispatch_unary_f32(GAMMA_WGSL, &narrowed)?
        .into_iter()
        .map(f64::from)
        .collect();
    Ok(widened)
}
508
#[cfg(feature = "wgpu_kernels")]
/// Evaluate the error function `erf(x)` for every element of `xs` on the GPU.
///
/// Each input is narrowed to `f32` before dispatch and each result widened back
/// to `f64`, so accuracy is single precision.
///
/// # Errors
/// Propagates any [`WgslDispatchError`] from the GPU dispatch.
pub fn erf_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    let widened: Vec<f64> = dispatch_unary_f32(ERF_WGSL, &narrowed)?
        .into_iter()
        .map(f64::from)
        .collect();
    Ok(widened)
}
519
#[cfg(feature = "wgpu_kernels")]
/// Evaluate the Bessel function `J0(x)` for every element of `xs` on the GPU.
///
/// Each input is narrowed to `f32` before dispatch and each result widened back
/// to `f64`, so accuracy is single precision.
///
/// # Errors
/// Propagates any [`WgslDispatchError`] from the GPU dispatch.
pub fn bessel_j0_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    let widened: Vec<f64> = dispatch_unary_f32(BESSEL_J0_WGSL, &narrowed)?
        .into_iter()
        .map(f64::from)
        .collect();
    Ok(widened)
}
530
#[cfg(feature = "wgpu_kernels")]
/// Evaluate `ln|Γ(x)|` for every element of `xs` on the GPU.
///
/// Each input is narrowed to `f32` before dispatch and each result widened back
/// to `f64`, so accuracy is single precision.
///
/// # Errors
/// Propagates any [`WgslDispatchError`] from the GPU dispatch.
pub fn lgamma_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    let widened: Vec<f64> = dispatch_unary_f32(LGAMMA_WGSL, &narrowed)?
        .into_iter()
        .map(f64::from)
        .collect();
    Ok(widened)
}
541
#[cfg(feature = "wgpu_kernels")]
/// Evaluate the complementary error function `erfc(x)` for every element of `xs`
/// on the GPU.
///
/// Each input is narrowed to `f32` before dispatch and each result widened back
/// to `f64`, so accuracy is single precision.
///
/// # Errors
/// Propagates any [`WgslDispatchError`] from the GPU dispatch.
pub fn erfc_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    let widened: Vec<f64> = dispatch_unary_f32(ERFC_WGSL, &narrowed)?
        .into_iter()
        .map(f64::from)
        .collect();
    Ok(widened)
}
553
#[cfg(feature = "wgpu_kernels")]
/// Evaluate the inverse error function `erfinv(x)` for every element of `xs` on
/// the GPU.
///
/// Each input is narrowed to `f32` before dispatch and each result widened back
/// to `f64`, so accuracy is single precision. Elements with `|x| >= 1` yield the
/// `±1e10` sentinel produced by `ERFINV_WGSL`.
///
/// # Errors
/// Propagates any [`WgslDispatchError`] from the GPU dispatch.
pub fn erfinv_batch_wgpu(xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
    let narrowed: Vec<f32> = xs.iter().map(|&value| value as f32).collect();
    let widened: Vec<f64> = dispatch_unary_f32(ERFINV_WGSL, &narrowed)?
        .into_iter()
        .map(f64::from)
        .collect();
    Ok(widened)
}
565
566#[cfg(not(feature = "wgpu_kernels"))]
572pub fn gamma_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
573 Err(WgslDispatchError::GpuNotAvailable)
574}
575
576#[cfg(not(feature = "wgpu_kernels"))]
578pub fn erf_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
579 Err(WgslDispatchError::GpuNotAvailable)
580}
581
582#[cfg(not(feature = "wgpu_kernels"))]
584pub fn bessel_j0_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
585 Err(WgslDispatchError::GpuNotAvailable)
586}
587
588#[cfg(not(feature = "wgpu_kernels"))]
590pub fn lgamma_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
591 Err(WgslDispatchError::GpuNotAvailable)
592}
593
594#[cfg(not(feature = "wgpu_kernels"))]
596pub fn erfc_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
597 Err(WgslDispatchError::GpuNotAvailable)
598}
599
600#[cfg(not(feature = "wgpu_kernels"))]
602pub fn erfinv_batch_wgpu(_xs: &[f64]) -> Result<Vec<f64>, WgslDispatchError> {
603 Err(WgslDispatchError::GpuNotAvailable)
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `actual` matches `expected` element-wise within absolute `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() <= tol,
                "element {i}: got {a}, expected {e} (tol {tol})"
            );
        }
    }

    /// Treat "no GPU" as a skip. Any successful result must be numerically close
    /// to `expected`; any other error fails the test.
    fn check_batch(result: Result<Vec<f64>, WgslDispatchError>, expected: &[f64], tol: f64) {
        match result {
            Ok(values) => assert_close(&values, expected, tol),
            Err(WgslDispatchError::GpuNotAvailable) => {}
            Err(e) => panic!("unexpected error: {e}"),
        }
    }

    /// Tolerance covering f32 evaluation plus GPU transcendental error bounds.
    const TOL: f64 = 1e-3;

    #[test]
    fn test_gamma_wgsl_source_is_non_empty() {
        assert!(!GAMMA_WGSL.is_empty());
        assert!(GAMMA_WGSL.contains("@compute"));
        assert!(GAMMA_WGSL.contains("workgroup_size"));
        assert!(GAMMA_WGSL.contains("lanczos_gamma"));
    }

    #[test]
    fn test_erf_wgsl_source_is_non_empty() {
        assert!(!ERF_WGSL.is_empty());
        assert!(ERF_WGSL.contains("@compute"));
        assert!(ERF_WGSL.contains("approx_erf"));
    }

    #[test]
    fn test_bessel_j0_wgsl_source_is_non_empty() {
        assert!(!BESSEL_J0_WGSL.is_empty());
        assert!(BESSEL_J0_WGSL.contains("@compute"));
        assert!(BESSEL_J0_WGSL.contains("bessel_j0"));
    }

    #[test]
    fn test_lgamma_wgsl_source_is_non_empty() {
        assert!(!LGAMMA_WGSL.is_empty());
        assert!(LGAMMA_WGSL.contains("@compute"));
        assert!(LGAMMA_WGSL.contains("lanczos_lgamma"));
    }

    #[test]
    fn test_erfc_wgsl_source_is_non_empty() {
        assert!(!ERFC_WGSL.is_empty());
        assert!(ERFC_WGSL.contains("@compute"));
        assert!(ERFC_WGSL.contains("approx_erfc"));
        assert!(ERFC_WGSL.contains("workgroup_size"));
    }

    #[test]
    fn test_erfinv_wgsl_source_is_non_empty() {
        assert!(!ERFINV_WGSL.is_empty());
        assert!(ERFINV_WGSL.contains("@compute"));
        assert!(ERFINV_WGSL.contains("approx_erfinv"));
        assert!(ERFINV_WGSL.contains("workgroup_size"));
    }

    #[test]
    fn test_gamma_batch_wgpu_returns_not_available() {
        // Γ(1) = Γ(2) = 1, Γ(3) = 2.
        let xs = vec![1.0_f64, 2.0, 3.0];
        check_batch(gamma_batch_wgpu(&xs), &[1.0, 1.0, 2.0], TOL);
    }

    #[test]
    fn test_erf_batch_wgpu_returns_not_available() {
        let xs = vec![0.0_f64, 1.0];
        check_batch(erf_batch_wgpu(&xs), &[0.0, 0.842_700_792_9], TOL);
    }

    #[test]
    fn test_bessel_j0_batch_wgpu_returns_not_available() {
        // J0(0) = 1; 2.405 is next to the first zero of J0 (2.404825...).
        let xs = vec![0.0_f64, 2.405];
        check_batch(bessel_j0_batch_wgpu(&xs), &[1.0, -9.055e-5], TOL);
    }

    #[test]
    fn test_lgamma_batch_wgpu_returns_not_available() {
        let xs = vec![1.0_f64, 2.0, 3.0];
        check_batch(
            lgamma_batch_wgpu(&xs),
            &[0.0, 0.0, std::f64::consts::LN_2],
            TOL,
        );
    }

    #[test]
    fn test_erfc_batch_wgpu_returns_not_available() {
        let xs = vec![0.0_f64, 1.0, -1.0];
        check_batch(
            erfc_batch_wgpu(&xs),
            &[1.0, 0.157_299_207_1, 1.842_700_792_9],
            TOL,
        );
    }

    #[test]
    fn test_erfinv_batch_wgpu_returns_not_available() {
        let xs = vec![0.0_f64, 0.5, -0.5];
        check_batch(
            erfinv_batch_wgpu(&xs),
            &[0.0, 0.476_936_276_2, -0.476_936_276_2],
            TOL,
        );
    }

    #[test]
    fn test_batch_wgpu_empty_input() {
        // Empty input yields an empty result (or GpuNotAvailable without the feature).
        check_batch(gamma_batch_wgpu(&[]), &[], TOL);
        check_batch(erfinv_batch_wgpu(&[]), &[], TOL);
    }

    #[test]
    fn test_wgsl_dispatch_error_display() {
        let e = WgslDispatchError::GpuNotAvailable;
        assert!(e.to_string().contains("not available"));
        let e2 = WgslDispatchError::RuntimeError("buffer overflow".into());
        assert!(e2.to_string().contains("buffer overflow"));
    }
}