tensorlogic_oxicuda_rng/engine.rs
1//! Core RNG engine: a unified GPU + CPU dual-path random number generator.
2//!
3//! # Design
4//!
5//! [`RngEngine`] exposes a small, ergonomic surface (`uniform_f32`,
6//! `normal_f32`, `bernoulli`, `uniform_f64`, `normal_f64`, streaming variants)
7//! that dispatches at runtime to either:
8//!
9//! * The **CPU path** — a minimal PCG-XSH-RR 64-bit generator with Box-Muller
10//! transform, implemented entirely in pure Rust with zero external
11//! dependencies. This is always available when the `cpu` feature is enabled
12//! (the default).
13//!
14//! * The **GPU path** — `oxicuda-rand`'s `RngGenerator`, which compiles and
15//! launches PTX kernels on an NVIDIA CUDA device. Activated by the `gpu`
16//! feature at compile time and then conditionally at runtime via
17//! `gpu_available()`.
18//!
19//! # Thread safety
20//!
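//! Thread safety is fixed at compile time by the `gpu` Cargo feature, not by
//! the path selected at runtime. Without the `gpu` feature, `RngEngine` is
//! both [`Send`] and [`Sync`] — the state is plain integers with no shared
//! mutable references. With the `gpu` feature enabled, `RngEngine` is `Send`
//! but NOT `Sync` (even when it falls back to the CPU path at runtime),
//! because a CUDA stream cannot be shared across threads.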
25//!
26//! # Policy compliance
27//!
28//! This file does **not** import `rand`, `rand_distr`, or `ndarray`.
29//! The PCG generator and Box-Muller transform are implemented from scratch.
30
31use crate::error::RngError;
32
33// ---------------------------------------------------------------------------
34// Public kind enum
35// ---------------------------------------------------------------------------
36
/// Selects which RNG algorithm family an engine is asked to use.
///
/// The CPU path drives every variant with the same PCG-XSH-RR state machine;
/// the variant is still recorded so that callers do not need to change when
/// the GPU path is active.
///
/// On the GPU path each variant is translated to the matching
/// `oxicuda_rand::generator::RngEngine` variant (`Philox`, `Xorwow`,
/// `Mrg32k3a`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RngEngineKind {
    /// Philox-4x32-10 counter-based PRNG (cuRAND default).
    Philox,
    /// XORWOW with Weyl sequence addition.
    Xorwow,
    /// MRG32k3a combined multiple recursive generator (highest statistical quality).
    Mrg32k3a,
}

impl RngEngineKind {
    /// Lower-case identifier for this kind: `"philox"`, `"xorwow"` or
    /// `"mrg32k3a"`.
    ///
    /// The same text is produced by the [`Display`](std::fmt::Display) impl.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Mrg32k3a => "mrg32k3a",
            Self::Xorwow => "xorwow",
            Self::Philox => "philox",
        }
    }
}

impl std::fmt::Display for RngEngineKind {
    /// Writes exactly the text returned by [`RngEngineKind::as_str`]; width
    /// and padding flags of the surrounding format spec are not applied.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.as_str();
        write!(formatter, "{}", label)
    }
}
72
73// ---------------------------------------------------------------------------
74// GPU path helpers
75// ---------------------------------------------------------------------------
76
/// Returns `true` when a CUDA device is accessible at runtime.
///
/// In Round 6 this will call `oxicuda_driver::init()` + `Device::count()`.
/// For now it always returns `false` so the CPU path acts as the universal
/// fallback.
///
/// Only compiled when the `gpu` feature is enabled. Because the current body
/// is the constant `false`, any caller that gates GPU construction on this
/// probe will never take the GPU branch until the stub is replaced.
#[cfg(feature = "gpu")]
fn gpu_available() -> bool {
    false // Round 6 will replace this with a real driver probe
}
86
87// ---------------------------------------------------------------------------
88// CPU path — PCG-XSH-RR 64-bit generator
89// ---------------------------------------------------------------------------
90
/// Pure-Rust PCG-XSH-RR generator: 64 bits of LCG state, 32 bits of output
/// per step.
///
/// Follows the PCG construction described by M. E. O'Neill (2014); no
/// external crate is involved.
///
/// Each step advances the state with
/// ```text
/// state' = state * PCG_MULT + inc      (mod 2^64)
/// ```
/// and derives the output from the *previous* state with the XSH-RR
/// permutation:
/// ```text
/// xorshifted = ((state >> 18) ^ state) >> 27   (truncated to 32 bits)
/// rot        = state >> 59
/// out        = rotate_right(xorshifted, rot)
/// ```
#[cfg(feature = "cpu")]
struct CpuRngState {
    /// Current LCG state.
    state: u64,
    /// Additive constant (stream selector); always odd.
    inc: u64,
}

#[cfg(feature = "cpu")]
impl CpuRngState {
    /// LCG multiplier used by the state transition.
    const PCG_MULT: u64 = 6_364_136_223_846_793_005_u64;

    /// Builds a generator from `seed`.
    ///
    /// The same `seed` is used twice: shifted left by one with the low bit
    /// set it becomes the (odd) stream increment, and it is added to the
    /// state between two warm-up steps so the first real output no longer
    /// depends on the all-zero starting state.
    fn new(seed: u64) -> Self {
        // `| 1` forces the increment to be odd.
        let stream = (seed << 1) | 1;
        let mut rng = Self {
            state: 0,
            inc: stream,
        };
        // First warm-up step: moves the state off zero.
        rng.next_u32();
        rng.state = rng.state.wrapping_add(seed);
        // Second warm-up step: mixes the seed into the state.
        rng.next_u32();
        rng
    }

    /// Advances the generator by one step and returns 32 output bits.
    #[inline]
    fn next_u32(&mut self) -> u32 {
        let prev = self.state;
        // LCG transition (wrapping arithmetic == mod 2^64).
        self.state = prev.wrapping_mul(Self::PCG_MULT).wrapping_add(self.inc);
        // XSH-RR: xorshift the high bits down, then rotate by the top 5 bits.
        let rotation = (prev >> 59) as u32;
        let mixed = ((prev ^ (prev >> 18)) >> 27) as u32;
        mixed.rotate_right(rotation)
    }

    /// Returns a uniform `f32` in `[0.0, 1.0)` with 23 random mantissa bits.
    #[inline]
    fn next_f32(&mut self) -> f32 {
        // Bit pattern of 1.0_f32: exponent 127, empty mantissa.
        const ONE_BITS: u32 = 0x3f80_0000_u32;
        // The top 23 output bits become the mantissa, giving a value in
        // [1.0, 2.0); subtracting 1.0 shifts it to [0.0, 1.0).
        let mantissa = self.next_u32() >> 9;
        f32::from_bits(mantissa | ONE_BITS) - 1.0_f32
    }

    /// Returns two standard-normal `f32` samples produced by one Box-Muller
    /// step.
    ///
    /// ```text
    /// r     = sqrt(-2 * ln(u1))
    /// theta = 2 * PI * u2
    /// z0    = r * cos(theta)
    /// z1    = r * sin(theta)
    /// ```
    /// `u1` and `u2` are two consecutive [`next_f32`](Self::next_f32) draws;
    /// `u1` is raised to at least `f32::EPSILON` so that `ln` never sees zero.
    #[inline]
    fn next_normal_pair(&mut self) -> (f32, f32) {
        // Draw order matters for reproducibility: u1 first, then u2.
        let u1 = self.next_f32().max(f32::EPSILON);
        let u2 = self.next_f32();

        let radius = (-2.0_f32 * u1.ln()).sqrt();
        let angle = std::f32::consts::TAU * u2; // TAU = 2π
        let z0 = radius * angle.cos();
        let z1 = radius * angle.sin();
        (z0, z1)
    }

    /// Returns 64 bits assembled from two consecutive 32-bit outputs.
    ///
    /// The first draw supplies the high half and the second the low half, so
    /// one call advances the stream by exactly two `next_u32` steps.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        let high = u64::from(self.next_u32());
        let low = u64::from(self.next_u32());
        low | (high << 32)
    }

    /// Returns a uniform `f64` in `[0.0, 1.0)` with 52 random mantissa bits.
    ///
    /// ```text
    /// bits = next_u64() >> 12
    /// x    = f64::from_bits(0x3FF0…0 | bits) - 1.0
    /// ```
    /// The exponent field `0x3FF` places the intermediate value in
    /// `[1.0, 2.0)`; subtracting 1.0 maps it to `[0.0, 1.0)`.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        // Bit pattern of 1.0_f64: sign 0, biased exponent 1023, empty mantissa.
        const ONE_BITS: u64 = 0x3FF0_0000_0000_0000_u64;
        let mantissa = self.next_u64() >> 12;
        f64::from_bits(ONE_BITS | mantissa) - 1.0_f64
    }

    /// Returns two standard-normal `f64` samples produced by one Box-Muller
    /// step.
    ///
    /// The logarithm is taken of `1.0 - u1`, which lies in `(0.0, 1.0]` for
    /// `u1` in `[0.0, 1.0)`, so `ln(0)` cannot occur.
    #[inline]
    fn next_normal_pair_f64(&mut self) -> (f64, f64) {
        // Draw order matters for reproducibility: u1 first, then u2.
        let u1 = self.next_f64();
        let u2 = self.next_f64();

        // Defensive fallback in case u1 ever reaches 1.0.
        let log_arg = if u1 >= 1.0 { f64::EPSILON } else { 1.0 - u1 };
        let radius = (-2.0_f64 * log_arg.ln()).sqrt();
        let angle = std::f64::consts::TAU * u2; // TAU = 2π
        let z0 = radius * angle.cos();
        let z1 = radius * angle.sin();
        (z0, z1)
    }
}
232
233// ---------------------------------------------------------------------------
234// Inner state enum
235// ---------------------------------------------------------------------------
236
237#[cfg(feature = "gpu")]
238use oxicuda_rand::generator::{RngEngine as OxiRngEngine, RngGenerator};
239
240#[cfg(feature = "gpu")]
241use std::sync::Arc;
242
/// Polymorphic inner state for the dual-path engine.
///
/// Each variant exists only when its Cargo feature is compiled in, so every
/// `match` over this enum must gate its arms with the same `cfg` attributes.
/// With neither `cpu` nor `gpu` enabled the enum has no variants at all.
enum RngEngineInner {
    /// Pure-Rust PCG generator (always available when `cpu` feature is on).
    #[cfg(feature = "cpu")]
    Cpu(CpuRngState),

    /// GPU-backed generator using `oxicuda-rand`.
    #[cfg(feature = "gpu")]
    Gpu(GpuRngState),
}
253
/// All GPU-related state bundled together.
///
/// Only compiled with the `gpu` feature. Currently wraps a single
/// `oxicuda_rand::generator::RngGenerator`.
#[cfg(feature = "gpu")]
struct GpuRngState {
    /// The `oxicuda-rand` generator that produces samples into device buffers.
    generator: RngGenerator,
}
259
260// ---------------------------------------------------------------------------
261// Public RngEngine
262// ---------------------------------------------------------------------------
263
/// A seeded, dual-path random number generator.
///
/// Constructed via [`RngEngine::new`]. The `gpu` Cargo feature enables the
/// GPU path; if CUDA is not available at runtime the constructor transparently
/// falls back to the CPU path.
///
/// # Thread safety
///
/// Thread-safety is decided at compile time by the `gpu` feature, not by the
/// path that is active at runtime.
///
/// Without the `gpu` feature (`default`), `RngEngine` is both [`Send`] and
/// [`Sync`] — the state is a pair of `u64` integers and carries no shared
/// references.
///
/// With `feature = "gpu"`, `RngEngine` is [`Send`] but NOT [`Sync`], even for
/// an instance that fell back to the CPU path. A CUDA stream cannot be shared
/// across threads; the `PhantomData<*const ()>` field enforces that statically.
pub struct RngEngine {
    /// The engine kind (preserved for introspection and GPU dispatch).
    kind: RngEngineKind,
    /// The inner state — either CPU or GPU.
    inner: RngEngineInner,
    /// Makes `RngEngine` non-`Sync` whenever the `gpu` feature is compiled in.
    ///
    /// Without the `gpu` feature this field is absent, allowing the compiler
    /// to auto-derive `Sync` from the plain-integer fields.
    #[cfg(feature = "gpu")]
    _not_sync: std::marker::PhantomData<*const ()>,
}
290
// SAFETY: `RngEngine` owns its state exclusively (no shared references).
// The CPU path is a plain `u64` pair — both `Send` and `Sync` are safe.
// The GPU path holds a `RngGenerator` which owns a CUDA stream. Streams are
// safe to *move* across threads (`Send`) but must not be shared (`!Sync`).
// We provide an explicit `Send` impl because the PhantomData on the GPU path
// would otherwise block the auto-derived `Send` as well.
//
// NOTE(review): this impl is unconditional, so on `gpu` builds it asserts
// `Send` regardless of what `RngGenerator` itself implements — confirm
// against `oxicuda-rand` that the generator is safe to move between threads.
unsafe impl Send for RngEngine {}
// `Sync` is intentionally NOT implemented when the `gpu` feature is enabled.
// `PhantomData<*const ()>` prevents the auto-derived impl there.
// Without the `gpu` feature (no PhantomData) the compiler auto-derives `Sync`
// because the remaining fields are a field-less enum and an enum wrapping two
// `u64` values (all `Send + Sync`).
302
303impl RngEngine {
304 /// Constructs a new RNG engine of the requested `kind` and `seed`.
305 ///
306 /// When the `gpu` feature is enabled **and** a CUDA device is reachable at
307 /// runtime, the GPU path is chosen; otherwise the CPU path is used.
308 ///
309 /// # Errors
310 ///
311 /// Currently infallible on the CPU path. Returns [`RngError::GpuError`]
312 /// if CUDA initialisation fails and there is no CPU fallback compiled in.
313 pub fn new(kind: RngEngineKind, seed: u64) -> Result<Self, RngError> {
314 // ----------------------------------------------------------------
315 // GPU path: attempt to acquire a CUDA context and build a generator.
316 // ----------------------------------------------------------------
317 #[cfg(feature = "gpu")]
318 if gpu_available() {
319 return Self::new_gpu(kind, seed);
320 }
321
322 // ----------------------------------------------------------------
323 // CPU path.
324 // ----------------------------------------------------------------
325 #[cfg(feature = "cpu")]
326 {
327 Ok(Self {
328 kind,
329 inner: RngEngineInner::Cpu(CpuRngState::new(seed)),
330 #[cfg(feature = "gpu")]
331 _not_sync: std::marker::PhantomData,
332 })
333 }
334
335 // If neither feature is compiled in this is unreachable, but the
336 // compiler needs a return expression in all branches.
337 #[cfg(not(any(feature = "cpu", feature = "gpu")))]
338 Err(RngError::GpuError(
339 "no backend compiled: enable the `cpu` or `gpu` feature".to_string(),
340 ))
341 }
342
343 /// Constructs a GPU-backed generator.
344 #[cfg(feature = "gpu")]
345 fn new_gpu(kind: RngEngineKind, seed: u64) -> Result<Self, RngError> {
346 use oxicuda_driver::{context::Context, Device};
347
348 oxicuda_driver::init().map_err(|e| RngError::GpuError(e.to_string()))?;
349 let device = Device::get(0).map_err(|e| RngError::GpuError(e.to_string()))?;
350 let ctx = Arc::new(Context::new(&device).map_err(|e| RngError::GpuError(e.to_string()))?);
351
352 let oxi_kind = match kind {
353 RngEngineKind::Philox => OxiRngEngine::Philox,
354 RngEngineKind::Xorwow => OxiRngEngine::Xorwow,
355 RngEngineKind::Mrg32k3a => OxiRngEngine::Mrg32k3a,
356 };
357
358 let generator = RngGenerator::new(oxi_kind, seed, &ctx)
359 .map_err(|e| RngError::GpuError(e.to_string()))?;
360
361 Ok(Self {
362 kind,
363 inner: RngEngineInner::Gpu(GpuRngState { generator }),
364 _not_sync: std::marker::PhantomData::<*const ()>,
365 })
366 }
367
    /// Returns the engine kind that was requested at construction.
    ///
    /// This is the value passed to the constructor, independent of whether
    /// the CPU or the GPU path is active.
    #[inline]
    pub fn kind(&self) -> RngEngineKind {
        self.kind
    }
373
374 /// Returns `true` when the active path is the GPU.
375 pub fn is_gpu(&self) -> bool {
376 match &self.inner {
377 #[cfg(feature = "cpu")]
378 RngEngineInner::Cpu(_) => false,
379 #[cfg(feature = "gpu")]
380 RngEngineInner::Gpu(_) => true,
381 }
382 }
383
384 // -----------------------------------------------------------------------
385 // Uniform f32
386 // -----------------------------------------------------------------------
387
388 /// Fills `out` with independent uniform samples drawn from `[0.0, 1.0)`.
389 ///
390 /// # Errors
391 ///
392 /// * [`RngError::EmptyBuffer`] — `out` is empty.
393 /// * [`RngError::GpuError`] — CUDA operation failed (GPU path only).
394 pub fn uniform_f32(&mut self, out: &mut [f32]) -> Result<(), RngError> {
395 if out.is_empty() {
396 return Err(RngError::EmptyBuffer);
397 }
398 match &mut self.inner {
399 #[cfg(feature = "cpu")]
400 RngEngineInner::Cpu(state) => {
401 for slot in out.iter_mut() {
402 *slot = state.next_f32();
403 }
404 Ok(())
405 }
406 #[cfg(feature = "gpu")]
407 RngEngineInner::Gpu(gs) => {
408 use oxicuda_memory::DeviceBuffer;
409 let n = out.len();
410 let mut dev_buf =
411 DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
412 gs.generator
413 .generate_uniform_f32(&mut dev_buf)
414 .map_err(|e| RngError::GpuError(e.to_string()))?;
415 dev_buf
416 .copy_to_host(out)
417 .map_err(|e| RngError::GpuError(e.to_string()))?;
418 Ok(())
419 }
420 }
421 }
422
423 // -----------------------------------------------------------------------
424 // Normal f32
425 // -----------------------------------------------------------------------
426
427 /// Fills `out` with independent normal samples from `N(mean, std_dev²)`.
428 ///
429 /// Uses Box-Muller on the CPU path and the engine's native Gaussian kernel
430 /// on the GPU path.
431 ///
432 /// # Errors
433 ///
434 /// * [`RngError::EmptyBuffer`] — `out` is empty.
435 /// * [`RngError::InvalidParam`] — `std_dev < 0` or not finite.
436 /// * [`RngError::GpuError`] — CUDA failure (GPU path).
437 pub fn normal_f32(&mut self, out: &mut [f32], mean: f32, std_dev: f32) -> Result<(), RngError> {
438 if out.is_empty() {
439 return Err(RngError::EmptyBuffer);
440 }
441 if !std_dev.is_finite() || std_dev < 0.0 {
442 return Err(RngError::InvalidParam(format!(
443 "std_dev must be finite and >= 0, got {std_dev}"
444 )));
445 }
446 if !mean.is_finite() {
447 return Err(RngError::InvalidParam(format!(
448 "mean must be finite, got {mean}"
449 )));
450 }
451
452 match &mut self.inner {
453 #[cfg(feature = "cpu")]
454 RngEngineInner::Cpu(state) => {
455 let n = out.len();
456 let mut i = 0usize;
457 // Consume pairs from Box-Muller; handle the odd element.
458 while i + 1 < n {
459 let (z0, z1) = state.next_normal_pair();
460 out[i] = mean + std_dev * z0;
461 out[i + 1] = mean + std_dev * z1;
462 i += 2;
463 }
464 if i < n {
465 let (z0, _) = state.next_normal_pair();
466 out[i] = mean + std_dev * z0;
467 }
468 Ok(())
469 }
470 #[cfg(feature = "gpu")]
471 RngEngineInner::Gpu(gs) => {
472 use oxicuda_memory::DeviceBuffer;
473 let n = out.len();
474 let mut dev_buf =
475 DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
476 gs.generator
477 .generate_normal_f32(&mut dev_buf, mean, std_dev)
478 .map_err(|e| RngError::GpuError(e.to_string()))?;
479 dev_buf
480 .copy_to_host(out)
481 .map_err(|e| RngError::GpuError(e.to_string()))?;
482 Ok(())
483 }
484 }
485 }
486
487 // -----------------------------------------------------------------------
488 // Bernoulli
489 // -----------------------------------------------------------------------
490
491 /// Fills `out` with Bernoulli(p) samples: each element is `1u8` with
492 /// probability `p` and `0u8` otherwise.
493 ///
494 /// # Errors
495 ///
496 /// * [`RngError::EmptyBuffer`] — `out` is empty.
497 /// * [`RngError::InvalidParam`] — `p` is not in `[0.0, 1.0]`.
498 /// * [`RngError::GpuError`] — CUDA failure (GPU path).
499 pub fn bernoulli(&mut self, out: &mut [u8], p: f32) -> Result<(), RngError> {
500 if out.is_empty() {
501 return Err(RngError::EmptyBuffer);
502 }
503 if !p.is_finite() || !(0.0..=1.0).contains(&p) {
504 return Err(RngError::InvalidParam(format!(
505 "p must be in [0.0, 1.0], got {p}"
506 )));
507 }
508
509 match &mut self.inner {
510 #[cfg(feature = "cpu")]
511 RngEngineInner::Cpu(state) => {
512 for slot in out.iter_mut() {
513 *slot = u8::from(state.next_f32() < p);
514 }
515 Ok(())
516 }
517 #[cfg(feature = "gpu")]
518 RngEngineInner::Gpu(gs) => {
519 // GPU path: generate uniform f32 on device, threshold on host.
520 // A future optimisation can do the threshold in a PTX kernel.
521 use oxicuda_memory::DeviceBuffer;
522 let n = out.len();
523 let mut dev_buf =
524 DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
525 gs.generator
526 .generate_uniform_f32(&mut dev_buf)
527 .map_err(|e| RngError::GpuError(e.to_string()))?;
528
529 let mut host_buf = vec![0f32; n];
530 dev_buf
531 .copy_to_host(&mut host_buf)
532 .map_err(|e| RngError::GpuError(e.to_string()))?;
533 for (slot, &u) in out.iter_mut().zip(host_buf.iter()) {
534 *slot = u8::from(u < p);
535 }
536 Ok(())
537 }
538 }
539 }
540
541 // -----------------------------------------------------------------------
542 // Uniform f64
543 // -----------------------------------------------------------------------
544
545 /// Fills `out` with independent uniform samples drawn from `[0.0, 1.0)`
546 /// with 52-bit mantissa precision.
547 ///
548 /// Each value is constructed from a 64-bit PCG output using the IEEE 754
549 /// exponent-field trick: the top 52 bits are inserted into the mantissa of
550 /// a double with exponent bias 1023 (∈ `[1.0, 2.0)`), then 1.0 is
551 /// subtracted to shift to `[0.0, 1.0)`.
552 ///
553 /// # Errors
554 ///
555 /// * [`RngError::EmptyBuffer`] — `out` is empty.
556 pub fn uniform_f64(&mut self, out: &mut [f64]) -> Result<(), RngError> {
557 if out.is_empty() {
558 return Err(RngError::EmptyBuffer);
559 }
560 match &mut self.inner {
561 #[cfg(feature = "cpu")]
562 RngEngineInner::Cpu(state) => {
563 for slot in out.iter_mut() {
564 *slot = state.next_f64();
565 }
566 Ok(())
567 }
568 #[cfg(feature = "gpu")]
569 RngEngineInner::Gpu(_gs) => {
570 // GPU path: no native f64 cuRAND kernel wired yet; use CPU
571 // emulation on the host side for correctness.
572 Err(RngError::GpuError(
573 "uniform_f64 on GPU path not yet implemented".to_string(),
574 ))
575 }
576 }
577 }
578
579 // -----------------------------------------------------------------------
580 // Normal f64
581 // -----------------------------------------------------------------------
582
583 /// Fills `out` with independent normal samples from `N(mean, std_dev²)`
584 /// with double precision.
585 ///
586 /// Uses Box-Muller on the CPU path. Each pair of output values consumes
587 /// two `uniform_f64` draws; an odd-length buffer consumes one additional
588 /// pair (discarding the second normal from the last Box-Muller step).
589 ///
590 /// # Errors
591 ///
592 /// * [`RngError::EmptyBuffer`] — `out` is empty.
593 /// * [`RngError::InvalidParam`] — `std_dev < 0` or not finite, or `mean`
594 /// is not finite.
595 pub fn normal_f64(&mut self, out: &mut [f64], mean: f64, std_dev: f64) -> Result<(), RngError> {
596 if out.is_empty() {
597 return Err(RngError::EmptyBuffer);
598 }
599 if !std_dev.is_finite() || std_dev < 0.0 {
600 return Err(RngError::InvalidParam(format!(
601 "std_dev must be finite and >= 0, got {std_dev}"
602 )));
603 }
604 if !mean.is_finite() {
605 return Err(RngError::InvalidParam(format!(
606 "mean must be finite, got {mean}"
607 )));
608 }
609
610 match &mut self.inner {
611 #[cfg(feature = "cpu")]
612 RngEngineInner::Cpu(state) => {
613 let n = out.len();
614 let mut i = 0usize;
615 // Consume pairs from Box-Muller; handle the odd trailing element.
616 while i + 1 < n {
617 let (z0, z1) = state.next_normal_pair_f64();
618 out[i] = mean + std_dev * z0;
619 out[i + 1] = mean + std_dev * z1;
620 i += 2;
621 }
622 if i < n {
623 let (z0, _) = state.next_normal_pair_f64();
624 out[i] = mean + std_dev * z0;
625 }
626 Ok(())
627 }
628 #[cfg(feature = "gpu")]
629 RngEngineInner::Gpu(_gs) => Err(RngError::GpuError(
630 "normal_f64 on GPU path not yet implemented".to_string(),
631 )),
632 }
633 }
634
635 // -----------------------------------------------------------------------
636 // Streaming API
637 // -----------------------------------------------------------------------
638
639 /// Generates `total` f32 uniform samples and delivers them in chunks of at
640 /// most `chunk_size` elements, calling `consumer` once per chunk.
641 ///
642 /// The final chunk may be smaller than `chunk_size` when `total` is not a
643 /// multiple of `chunk_size`.
644 ///
645 /// # Determinism
646 ///
647 /// Given the same seed and the same `total`, the complete sequence of
648 /// generated values is identical regardless of `chunk_size`. The chunk
649 /// size only affects how many values are presented per callback.
650 ///
651 /// # Errors
652 ///
653 /// * [`RngError::EmptyBuffer`] — `total == 0` or `chunk_size == 0`.
654 pub fn fill_uniform_chunked<F: FnMut(&[f32])>(
655 &mut self,
656 total: usize,
657 chunk_size: usize,
658 consumer: &mut F,
659 ) -> Result<(), RngError> {
660 if total == 0 || chunk_size == 0 {
661 return Err(RngError::EmptyBuffer);
662 }
663
664 let mut buf = vec![0f32; chunk_size];
665 let mut remaining = total;
666
667 while remaining > 0 {
668 let n = remaining.min(chunk_size);
669 self.uniform_f32(&mut buf[..n])?;
670 consumer(&buf[..n]);
671 remaining -= n;
672 }
673 Ok(())
674 }
675
676 /// Generates `total` f64 uniform samples and delivers them in chunks of at
677 /// most `chunk_size` elements, calling `consumer` once per chunk.
678 ///
679 /// The final chunk may be smaller than `chunk_size` when `total` is not a
680 /// multiple of `chunk_size`.
681 ///
682 /// # Determinism
683 ///
684 /// Given the same seed and the same `total`, the complete sequence of
685 /// generated values is identical regardless of `chunk_size`.
686 ///
687 /// # Errors
688 ///
689 /// * [`RngError::EmptyBuffer`] — `total == 0` or `chunk_size == 0`.
690 pub fn fill_uniform_chunked_f64<F: FnMut(&[f64])>(
691 &mut self,
692 total: usize,
693 chunk_size: usize,
694 consumer: &mut F,
695 ) -> Result<(), RngError> {
696 if total == 0 || chunk_size == 0 {
697 return Err(RngError::EmptyBuffer);
698 }
699
700 let mut buf = vec![0f64; chunk_size];
701 let mut remaining = total;
702
703 while remaining > 0 {
704 let n = remaining.min(chunk_size);
705 self.uniform_f64(&mut buf[..n])?;
706 consumer(&buf[..n]);
707 remaining -= n;
708 }
709 Ok(())
710 }
711
712 /// Generates `total` f32 normal samples from `N(mean, std_dev²)` and
713 /// delivers them in chunks of at most `chunk_size` elements, calling
714 /// `consumer` once per chunk.
715 ///
716 /// The final chunk may be smaller than `chunk_size`.
717 ///
718 /// # Determinism
719 ///
720 /// Given the same seed and the same `total`, the full sequence is identical
721 /// regardless of `chunk_size`. Note: because Box-Muller consumes values in
722 /// pairs, chunk boundaries that split a pair internally will advance the
723 /// stream by a full pair — the global sequence is determined by `total`, not
724 /// chunk boundaries.
725 ///
726 /// # Errors
727 ///
728 /// * [`RngError::EmptyBuffer`] — `total == 0` or `chunk_size == 0`.
729 /// * [`RngError::InvalidParam`] — `std_dev < 0` or not finite, or `mean`
730 /// is not finite.
731 pub fn fill_normal_chunked<F: FnMut(&[f32])>(
732 &mut self,
733 total: usize,
734 chunk_size: usize,
735 mean: f32,
736 std_dev: f32,
737 consumer: &mut F,
738 ) -> Result<(), RngError> {
739 if total == 0 || chunk_size == 0 {
740 return Err(RngError::EmptyBuffer);
741 }
742 if !std_dev.is_finite() || std_dev < 0.0 {
743 return Err(RngError::InvalidParam(format!(
744 "std_dev must be finite and >= 0, got {std_dev}"
745 )));
746 }
747 if !mean.is_finite() {
748 return Err(RngError::InvalidParam(format!(
749 "mean must be finite, got {mean}"
750 )));
751 }
752
753 let mut buf = vec![0f32; chunk_size];
754 let mut remaining = total;
755
756 while remaining > 0 {
757 let n = remaining.min(chunk_size);
758 self.normal_f32(&mut buf[..n], mean, std_dev)?;
759 consumer(&buf[..n]);
760 remaining -= n;
761 }
762 Ok(())
763 }
764}
765
766// ---------------------------------------------------------------------------
767// Unit tests (CPU-path only, no CUDA device required)
768// ---------------------------------------------------------------------------
769
// Unit tests for the engine. Tests that touch `CpuRngState` directly are
// gated on the `cpu` feature; the remaining tests go through the public
// `RngEngine` API.
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // PCG internals
    // -----------------------------------------------------------------------

    // The stream increment must have its low bit set for every seed.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_inc_is_odd() {
        // The inc field must be odd for full-period PCG.
        for seed in [0u64, 1, 42, u64::MAX, u64::MAX / 2] {
            let state = CpuRngState::new(seed);
            assert_eq!(state.inc & 1, 1, "inc must be odd for seed={seed}");
        }
    }

    // Every f32 sample must fall in the half-open unit interval.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_uniform_in_range() {
        let mut state = CpuRngState::new(12345);
        for _ in 0..10_000 {
            let v = state.next_f32();
            assert!(
                (0.0..1.0).contains(&v),
                "uniform sample {v} not in [0.0, 1.0)"
            );
        }
    }

    // Two generators built from the same seed must emit the same stream.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_deterministic_replay() {
        let mut a = CpuRngState::new(777);
        let mut b = CpuRngState::new(777);
        for _ in 0..1000 {
            assert_eq!(a.next_u32(), b.next_u32());
        }
    }

    // Different seeds must not yield the same first 100 outputs.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_different_seeds_differ() {
        let mut a = CpuRngState::new(0);
        let mut b = CpuRngState::new(1);
        // Extremely unlikely that 100 consecutive u32 outputs are identical.
        let outputs_a: Vec<u32> = (0..100).map(|_| a.next_u32()).collect();
        let outputs_b: Vec<u32> = (0..100).map(|_| b.next_u32()).collect();
        assert_ne!(
            outputs_a, outputs_b,
            "different seeds should produce different sequences"
        );
    }

    // -----------------------------------------------------------------------
    // Box-Muller
    // -----------------------------------------------------------------------

    // Box-Muller must never produce NaN or infinity (guards the ln(0) case).
    #[test]
    #[cfg(feature = "cpu")]
    fn box_muller_pair_is_finite() {
        let mut state = CpuRngState::new(42);
        for _ in 0..10_000 {
            let (z0, z1) = state.next_normal_pair();
            assert!(z0.is_finite(), "z0 is not finite: {z0}");
            assert!(z1.is_finite(), "z1 is not finite: {z1}");
        }
    }

    // -----------------------------------------------------------------------
    // RngEngineKind
    // -----------------------------------------------------------------------

    // `as_str` must return the stable lower-case names.
    #[test]
    fn engine_kind_as_str() {
        assert_eq!(RngEngineKind::Philox.as_str(), "philox");
        assert_eq!(RngEngineKind::Xorwow.as_str(), "xorwow");
        assert_eq!(RngEngineKind::Mrg32k3a.as_str(), "mrg32k3a");
    }

    // `Display` must agree with `as_str`.
    #[test]
    fn engine_kind_display() {
        assert_eq!(format!("{}", RngEngineKind::Philox), "philox");
        assert_eq!(format!("{}", RngEngineKind::Xorwow), "xorwow");
        assert_eq!(format!("{}", RngEngineKind::Mrg32k3a), "mrg32k3a");
    }

    // -----------------------------------------------------------------------
    // RngEngine construction & properties
    // -----------------------------------------------------------------------

    // Construction must succeed for every engine kind.
    #[test]
    fn engine_new_returns_ok() {
        for kind in [
            RngEngineKind::Philox,
            RngEngineKind::Xorwow,
            RngEngineKind::Mrg32k3a,
        ] {
            assert!(
                RngEngine::new(kind, 0).is_ok(),
                "construction failed for {kind}"
            );
        }
    }

    // `kind()` must echo the kind passed to the constructor.
    #[test]
    fn engine_kind_accessor() {
        let eng = RngEngine::new(RngEngineKind::Mrg32k3a, 1).unwrap();
        assert_eq!(eng.kind(), RngEngineKind::Mrg32k3a);
    }

    // Without a reachable CUDA device the engine must report the CPU path.
    #[test]
    fn engine_is_not_gpu_in_ci() {
        // GPU path always falls back to CPU in CI (no CUDA device).
        let eng = RngEngine::new(RngEngineKind::Philox, 42).unwrap();
        assert!(!eng.is_gpu());
    }

    // -----------------------------------------------------------------------
    // Error handling
    // -----------------------------------------------------------------------

    // An empty output slice is rejected by `uniform_f32`.
    #[test]
    fn uniform_empty_buffer_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out: Vec<f32> = vec![];
        assert!(matches!(
            eng.uniform_f32(&mut out),
            Err(RngError::EmptyBuffer)
        ));
    }

    // An empty output slice is rejected by `normal_f32`.
    #[test]
    fn normal_empty_buffer_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out: Vec<f32> = vec![];
        assert!(matches!(
            eng.normal_f32(&mut out, 0.0, 1.0),
            Err(RngError::EmptyBuffer)
        ));
    }

    // An empty output slice is rejected by `bernoulli`.
    #[test]
    fn bernoulli_empty_buffer_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out: Vec<u8> = vec![];
        assert!(matches!(
            eng.bernoulli(&mut out, 0.5),
            Err(RngError::EmptyBuffer)
        ));
    }

    // A negative standard deviation is an invalid parameter.
    #[test]
    fn normal_negative_stddev_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out = vec![0f32; 10];
        assert!(matches!(
            eng.normal_f32(&mut out, 0.0, -1.0),
            Err(RngError::InvalidParam(_))
        ));
    }

    // A NaN mean is an invalid parameter.
    #[test]
    fn normal_nan_mean_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out = vec![0f32; 10];
        assert!(matches!(
            eng.normal_f32(&mut out, f32::NAN, 1.0),
            Err(RngError::InvalidParam(_))
        ));
    }

    // Probabilities below 0, above 1, or NaN are invalid parameters.
    #[test]
    fn bernoulli_invalid_p_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out = vec![0u8; 10];
        assert!(matches!(
            eng.bernoulli(&mut out, -0.1),
            Err(RngError::InvalidParam(_))
        ));
        assert!(matches!(
            eng.bernoulli(&mut out, 1.1),
            Err(RngError::InvalidParam(_))
        ));
        assert!(matches!(
            eng.bernoulli(&mut out, f32::NAN),
            Err(RngError::InvalidParam(_))
        ));
    }

    // -----------------------------------------------------------------------
    // Statistical sanity — quick checks (small N, loose tolerances)
    // -----------------------------------------------------------------------

    // Samples obtained through the public API stay within [0, 1).
    #[test]
    fn uniform_in_range() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 42).unwrap();
        let mut out = vec![0f32; 1_000];
        eng.uniform_f32(&mut out).unwrap();
        for &v in &out {
            assert!((0.0..1.0).contains(&v), "uniform sample {v} out of [0,1)");
        }
    }

    // The buffer is pre-filled with NaN so an unwritten slot would be caught.
    #[test]
    fn normal_odd_length_fills_all_elements() {
        // Exercises the trailing odd-element branch in normal_f32.
        let mut eng = RngEngine::new(RngEngineKind::Xorwow, 99).unwrap();
        let mut out = vec![f32::NAN; 7]; // odd length
        eng.normal_f32(&mut out, 0.0, 1.0).unwrap();
        for (i, &v) in out.iter().enumerate() {
            assert!(v.is_finite(), "element {i} is not finite: {v}");
        }
    }

    // The buffer is pre-filled with 255 so an unwritten slot would be caught.
    #[test]
    fn bernoulli_outputs_only_zero_or_one() {
        let mut eng = RngEngine::new(RngEngineKind::Mrg32k3a, 555).unwrap();
        let mut out = vec![255u8; 1_000];
        eng.bernoulli(&mut out, 0.5).unwrap();
        for &b in &out {
            assert!(b == 0 || b == 1, "bernoulli output {b} is not 0 or 1");
        }
    }

    // p = 0 can never succeed: no uniform sample is strictly below 0.
    #[test]
    fn bernoulli_p_zero_produces_all_zeros() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 1).unwrap();
        let mut out = vec![1u8; 500];
        eng.bernoulli(&mut out, 0.0).unwrap();
        assert!(out.iter().all(|&b| b == 0));
    }

    // p = 1 always succeeds: every uniform sample is strictly below 1.
    #[test]
    fn bernoulli_p_one_produces_all_ones() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 2).unwrap();
        let mut out = vec![0u8; 500];
        eng.bernoulli(&mut out, 1.0).unwrap();
        assert!(out.iter().all(|&b| b == 1));
    }
}
1012
1013// ---------------------------------------------------------------------------
1014// Compile-time Send+Sync assertions for the CPU path
1015// ---------------------------------------------------------------------------
1016
1017/// Verifies at compile time that [`RngEngine`] is both [`Send`] and [`Sync`]
1018/// on the CPU path (no `gpu` feature).
1019///
1020/// If the type bounds fail, this module fails to compile — no runtime test
1021/// needed.
1022#[cfg(not(feature = "gpu"))]
1023mod send_sync_assertions {
1024 use super::RngEngine;
1025
1026 fn _assert_send<T: Send>() {}
1027 fn _assert_sync<T: Sync>() {}
1028
1029 fn _check_rng_engine_send_sync() {
1030 _assert_send::<RngEngine>();
1031 _assert_sync::<RngEngine>();
1032 }
1033}