Skip to main content

tensorlogic_oxicuda_rng/
engine.rs

1//! Core RNG engine: a unified GPU + CPU dual-path random number generator.
2//!
3//! # Design
4//!
5//! [`RngEngine`] exposes a small, ergonomic surface (`uniform_f32`,
6//! `normal_f32`, `bernoulli`, `uniform_f64`, `normal_f64`, streaming variants)
7//! that dispatches at runtime to either:
8//!
9//! * The **CPU path** — a minimal PCG-XSH-RR 64-bit generator with Box-Muller
10//!   transform, implemented entirely in pure Rust with zero external
11//!   dependencies.  This is always available when the `cpu` feature is enabled
12//!   (the default).
13//!
14//! * The **GPU path** — `oxicuda-rand`'s `RngGenerator`, which compiles and
15//!   launches PTX kernels on an NVIDIA CUDA device.  Activated by the `gpu`
16//!   feature at compile time and then conditionally at runtime via
17//!   `gpu_available()`.
18//!
19//! # Thread safety
20//!
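//! Thread safety is fixed at compile time by the enabled Cargo features, not
//! by the path that is active at runtime.  In a build **without** the `gpu`
//! feature, `RngEngine` is both [`Send`] and [`Sync`] — the state is plain
//! integers with no shared mutable references.  In a build **with** the `gpu`
//! feature, `RngEngine` is `Send` but NOT `Sync` (even when it falls back to
//! the CPU path) because a CUDA stream cannot be shared across threads.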
25//!
26//! # Policy compliance
27//!
28//! This file does **not** import `rand`, `rand_distr`, or `ndarray`.
29//! The PCG generator and Box-Muller transform are implemented from scratch.
30
31use crate::error::RngError;
32
33// ---------------------------------------------------------------------------
34// Public kind enum
35// ---------------------------------------------------------------------------
36
/// Selects which RNG algorithm family the engine should use.
///
/// The CPU path runs every variant on the same PCG-XSH-RR 64-bit generator;
/// the variant is still recorded so that code written against the CPU path
/// needs no changes when it later runs on the GPU path, where Philox, XORWOW
/// and MRG32k3a map to distinct kernels.
///
/// On the GPU path each variant selects the `oxicuda-rand` engine of the same
/// name (for example [`RngEngineKind::Philox`] → Philox-4x32-10).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RngEngineKind {
    /// Counter-based Philox-4x32-10 generator (the cuRAND default).
    Philox,
    /// XORWOW generator with Weyl-sequence addition.
    Xorwow,
    /// MRG32k3a combined multiple recursive generator (highest statistical quality).
    Mrg32k3a,
}

impl RngEngineKind {
    /// Returns the stable, lower-case name of this engine kind.
    pub fn as_str(self) -> &'static str {
        use RngEngineKind::{Mrg32k3a, Philox, Xorwow};

        match self {
            Philox => "philox",
            Xorwow => "xorwow",
            Mrg32k3a => "mrg32k3a",
        }
    }
}

impl std::fmt::Display for RngEngineKind {
    /// Writes the same name as [`RngEngineKind::as_str`], ignoring any
    /// width/precision flags on the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        f.write_str(name)
    }
}
72
73// ---------------------------------------------------------------------------
74// GPU path helpers
75// ---------------------------------------------------------------------------
76
/// Reports whether a CUDA device can be used at runtime.
///
/// This is currently a stub: it reports `false` unconditionally, so the CPU
/// path serves as the universal fallback.  Round 6 is planned to replace it
/// with a real probe (`oxicuda_driver::init()` + `Device::count()`).
#[cfg(feature = "gpu")]
fn gpu_available() -> bool {
    // Placeholder answer until the driver probe is wired in.
    const CUDA_DEVICE_PRESENT: bool = false;
    CUDA_DEVICE_PRESENT
}
86
87// ---------------------------------------------------------------------------
88// CPU path — PCG-XSH-RR 64-bit generator
89// ---------------------------------------------------------------------------
90
/// A small PCG-XSH-RR generator: 64 bits of state, 32 bits of output per step.
///
/// Implements the PCG scheme described by M. E. O'Neill (2014) without any
/// external crate.
///
/// Every step advances a 64-bit linear congruential generator:
/// ```text
///   state' = state * PCG_MULT + inc        (all mod 2^64)
/// ```
/// and derives the output from the *previous* state with the XSH-RR
/// permutation:
/// ```text
///   xorshifted = ((state >> 18) ^ state) >> 27   (32-bit result)
///   rot        = state >> 59
///   out        = rotate_right(xorshifted, rot)
/// ```
#[cfg(feature = "cpu")]
struct CpuRngState {
    /// Current LCG state.
    state: u64,
    /// Additive LCG constant (stream selector) — always odd.
    inc: u64,
}

#[cfg(feature = "cpu")]
impl CpuRngState {
    /// Multiplier of the underlying 64-bit LCG.
    const PCG_MULT: u64 = 6_364_136_223_846_793_005_u64;

    /// Builds a generator from `seed`.
    ///
    /// The stream selector is `(seed << 1) | 1`, which is odd for every seed.
    /// Seeding starts from `state = 0`, advances once, adds `seed` to the
    /// state, then advances a second time so the seed is mixed in before the
    /// first output is handed to a caller.
    fn new(seed: u64) -> Self {
        let mut rng = Self {
            state: 0,
            inc: seed.wrapping_shl(1) | 1,
        };
        rng.next_u32(); // first step, starting from the zero state
        rng.state = rng.state.wrapping_add(seed);
        rng.next_u32(); // second step, now with the seed absorbed
        rng
    }

    /// Advances the generator one step and returns 32 pseudorandom bits.
    #[inline]
    fn next_u32(&mut self) -> u32 {
        let prev = self.state;
        self.state = prev.wrapping_mul(Self::PCG_MULT).wrapping_add(self.inc);

        // XSH-RR: xorshift the high bits down, then rotate by the top 5 bits.
        let mixed = ((prev ^ (prev >> 18)) >> 27) as u32;
        mixed.rotate_right((prev >> 59) as u32)
    }

    /// Returns a uniform `f32` in `[0.0, 1.0)` with 23 bits of resolution.
    ///
    /// The top 23 bits of one 32-bit output become the mantissa of a float
    /// whose exponent encodes `[1.0, 2.0)`; subtracting `1.0` shifts the value
    /// into `[0.0, 1.0)`.
    #[inline]
    fn next_f32(&mut self) -> f32 {
        /// Bit pattern of `1.0_f32` (sign 0, biased exponent 127, mantissa 0).
        const ONE_BITS: u32 = 0x3f80_0000_u32;

        let mantissa = self.next_u32() >> 9;
        f32::from_bits(ONE_BITS | mantissa) - 1.0_f32
    }

    /// Returns two standard-normal `f32` samples from one Box-Muller step.
    ///
    /// ```text
    ///   r     = sqrt(-2 * ln(u1))
    ///   theta = 2 * PI * u2
    ///   z0    = r * cos(theta)
    ///   z1    = r * sin(theta)
    /// ```
    /// `u1` is drawn first and raised to at least `f32::EPSILON` so that
    /// `ln(0)` can never be evaluated; `u2` is drawn second.
    #[inline]
    fn next_normal_pair(&mut self) -> (f32, f32) {
        // `next_f32` never yields NaN, so `max` is a plain lower clamp here.
        let u1 = self.next_f32().max(f32::EPSILON);
        let u2 = self.next_f32();

        let radius = (-2.0_f32 * u1.ln()).sqrt();
        let angle = std::f32::consts::TAU * u2; // TAU = 2π
        (radius * angle.cos(), radius * angle.sin())
    }

    /// Returns 64 pseudorandom bits assembled from two consecutive 32-bit
    /// outputs.
    ///
    /// The first output fills the high half and the second the low half, so
    /// one `next_u64` call consumes exactly the two values that two
    /// `next_u32` calls would have returned.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        let high = u64::from(self.next_u32());
        let low = u64::from(self.next_u32());
        (high << 32) | low
    }

    /// Returns a uniform `f64` in `[0.0, 1.0)` with 52 bits of resolution.
    ///
    /// ```text
    ///   bits = next_u64() >> 12               (top 52 bits)
    ///   x    = f64::from_bits(0x3FF0…0 | bits) - 1.0
    /// ```
    /// The biased exponent `0x3FF` (1023) places the intermediate value in
    /// `[1.0, 2.0)`; subtracting `1.0` maps it to `[0.0, 1.0)`.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        /// Bit pattern of `1.0_f64` (sign 0, biased exponent 1023, mantissa 0).
        const ONE_BITS: u64 = 0x3FF0_0000_0000_0000_u64;

        let mantissa = self.next_u64() >> 12;
        f64::from_bits(ONE_BITS | mantissa) - 1.0_f64
    }

    /// Returns two standard-normal `f64` samples from one Box-Muller step.
    ///
    /// The logarithm is taken of `1.0 - u1` instead of `u1`, which maps the
    /// half-open range `[0, 1)` onto `(0, 1]` and so keeps the argument away
    /// from zero.
    #[inline]
    fn next_normal_pair_f64(&mut self) -> (f64, f64) {
        let u1 = self.next_f64();
        let u2 = self.next_f64();

        // Defensive fallback: should `u1` ever reach 1.0, use EPSILON rather
        // than feeding zero to `ln`.
        let log_arg = if u1 < 1.0 { 1.0 - u1 } else { f64::EPSILON };
        let radius = (-2.0_f64 * log_arg.ln()).sqrt();
        let angle = std::f64::consts::TAU * u2; // TAU = 2π
        (radius * angle.cos(), radius * angle.sin())
    }
}
232
233// ---------------------------------------------------------------------------
234// Inner state enum
235// ---------------------------------------------------------------------------
236
237#[cfg(feature = "gpu")]
238use oxicuda_rand::generator::{RngEngine as OxiRngEngine, RngGenerator};
239
240#[cfg(feature = "gpu")]
241use std::sync::Arc;
242
/// Polymorphic inner state for the dual-path engine.
///
/// Each variant exists only when its Cargo feature is compiled in, so the set
/// of possible backends is decided at build time; which variant a given
/// [`RngEngine`] holds is decided when it is constructed.
enum RngEngineInner {
    /// Pure-Rust PCG generator (always available when `cpu` feature is on).
    #[cfg(feature = "cpu")]
    Cpu(CpuRngState),

    /// GPU-backed generator using `oxicuda-rand`.
    #[cfg(feature = "gpu")]
    Gpu(GpuRngState),
}
253
/// All GPU-related state bundled together.
///
/// Only compiled when the `gpu` feature is enabled.
#[cfg(feature = "gpu")]
struct GpuRngState {
    /// The `oxicuda-rand` generator that every GPU-path sampling call goes
    /// through.
    generator: RngGenerator,
}
259
260// ---------------------------------------------------------------------------
261// Public RngEngine
262// ---------------------------------------------------------------------------
263
/// A seeded, dual-path random number generator.
///
/// Constructed via [`RngEngine::new`].  The `gpu` Cargo feature enables the
/// GPU path; if CUDA is not available at runtime the constructor transparently
/// falls back to the CPU path.
///
/// # Thread safety
///
/// Thread safety is determined by the Cargo features the crate is built with,
/// not by the path that is active at runtime.
///
/// In a build **without** the `gpu` feature, `RngEngine` is both [`Send`] and
/// [`Sync`] — the state is a pair of `u64` integers and carries no shared
/// references.
///
/// In a build **with** the `gpu` feature, `RngEngine` is [`Send`] but NOT
/// [`Sync`], even for an engine that fell back to the CPU path.  A CUDA stream
/// cannot be shared across threads; the `PhantomData<*const ()>` field
/// enforces that statically.
pub struct RngEngine {
    /// The engine kind (preserved for introspection and GPU dispatch).
    kind: RngEngineKind,
    /// The inner state — either CPU or GPU.
    inner: RngEngineInner,
    /// Makes `RngEngine` non-`Sync` whenever the `gpu` feature is compiled in.
    ///
    /// Without the `gpu` feature this field is absent, allowing the compiler
    /// to auto-derive `Sync` from the plain-integer fields.
    #[cfg(feature = "gpu")]
    _not_sync: std::marker::PhantomData<*const ()>,
}
290
// SAFETY: `RngEngine` owns its state exclusively (no shared references).
// The CPU state is a plain `u64` pair — both `Send` and `Sync` are safe.
// The GPU state holds a `RngGenerator` which owns a CUDA stream.  Streams are
// safe to *move* across threads (`Send`) but must not be shared (`!Sync`).
// We provide an explicit `Send` impl because `PhantomData<*const ()>` (present
// whenever the `gpu` feature is on) is neither `Send` nor `Sync`, and would
// otherwise block the auto-derived `Send` as well.
// NOTE(review): soundness with the `gpu` feature relies on `RngGenerator`
// being safe to move between threads — confirm against `oxicuda-rand`.
unsafe impl Send for RngEngine {}
// `Sync` is intentionally NOT implemented when the `gpu` feature is on.
// `PhantomData<*const ()>` prevents the auto-derived impl there.
// Without the `gpu` feature (no PhantomData) the compiler auto-derives `Sync`
// because the remaining fields are a fieldless enum and `u64` state (which
// are `Send + Sync`).
302
303impl RngEngine {
304    /// Constructs a new RNG engine of the requested `kind` and `seed`.
305    ///
306    /// When the `gpu` feature is enabled **and** a CUDA device is reachable at
307    /// runtime, the GPU path is chosen; otherwise the CPU path is used.
308    ///
309    /// # Errors
310    ///
311    /// Currently infallible on the CPU path.  Returns [`RngError::GpuError`]
312    /// if CUDA initialisation fails and there is no CPU fallback compiled in.
313    pub fn new(kind: RngEngineKind, seed: u64) -> Result<Self, RngError> {
314        // ----------------------------------------------------------------
315        // GPU path: attempt to acquire a CUDA context and build a generator.
316        // ----------------------------------------------------------------
317        #[cfg(feature = "gpu")]
318        if gpu_available() {
319            return Self::new_gpu(kind, seed);
320        }
321
322        // ----------------------------------------------------------------
323        // CPU path.
324        // ----------------------------------------------------------------
325        #[cfg(feature = "cpu")]
326        {
327            Ok(Self {
328                kind,
329                inner: RngEngineInner::Cpu(CpuRngState::new(seed)),
330                #[cfg(feature = "gpu")]
331                _not_sync: std::marker::PhantomData,
332            })
333        }
334
335        // If neither feature is compiled in this is unreachable, but the
336        // compiler needs a return expression in all branches.
337        #[cfg(not(any(feature = "cpu", feature = "gpu")))]
338        Err(RngError::GpuError(
339            "no backend compiled: enable the `cpu` or `gpu` feature".to_string(),
340        ))
341    }
342
343    /// Constructs a GPU-backed generator.
344    #[cfg(feature = "gpu")]
345    fn new_gpu(kind: RngEngineKind, seed: u64) -> Result<Self, RngError> {
346        use oxicuda_driver::{context::Context, Device};
347
348        oxicuda_driver::init().map_err(|e| RngError::GpuError(e.to_string()))?;
349        let device = Device::get(0).map_err(|e| RngError::GpuError(e.to_string()))?;
350        let ctx = Arc::new(Context::new(&device).map_err(|e| RngError::GpuError(e.to_string()))?);
351
352        let oxi_kind = match kind {
353            RngEngineKind::Philox => OxiRngEngine::Philox,
354            RngEngineKind::Xorwow => OxiRngEngine::Xorwow,
355            RngEngineKind::Mrg32k3a => OxiRngEngine::Mrg32k3a,
356        };
357
358        let generator = RngGenerator::new(oxi_kind, seed, &ctx)
359            .map_err(|e| RngError::GpuError(e.to_string()))?;
360
361        Ok(Self {
362            kind,
363            inner: RngEngineInner::Gpu(GpuRngState { generator }),
364            _not_sync: std::marker::PhantomData::<*const ()>,
365        })
366    }
367
    /// Returns the engine kind that was requested at construction.
    ///
    /// This is the stored `kind` value exactly as passed to
    /// [`RngEngine::new`]; it is returned by value because
    /// [`RngEngineKind`] is `Copy`.
    #[inline]
    pub fn kind(&self) -> RngEngineKind {
        self.kind
    }
373
374    /// Returns `true` when the active path is the GPU.
375    pub fn is_gpu(&self) -> bool {
376        match &self.inner {
377            #[cfg(feature = "cpu")]
378            RngEngineInner::Cpu(_) => false,
379            #[cfg(feature = "gpu")]
380            RngEngineInner::Gpu(_) => true,
381        }
382    }
383
384    // -----------------------------------------------------------------------
385    // Uniform f32
386    // -----------------------------------------------------------------------
387
388    /// Fills `out` with independent uniform samples drawn from `[0.0, 1.0)`.
389    ///
390    /// # Errors
391    ///
392    /// * [`RngError::EmptyBuffer`] — `out` is empty.
393    /// * [`RngError::GpuError`]   — CUDA operation failed (GPU path only).
394    pub fn uniform_f32(&mut self, out: &mut [f32]) -> Result<(), RngError> {
395        if out.is_empty() {
396            return Err(RngError::EmptyBuffer);
397        }
398        match &mut self.inner {
399            #[cfg(feature = "cpu")]
400            RngEngineInner::Cpu(state) => {
401                for slot in out.iter_mut() {
402                    *slot = state.next_f32();
403                }
404                Ok(())
405            }
406            #[cfg(feature = "gpu")]
407            RngEngineInner::Gpu(gs) => {
408                use oxicuda_memory::DeviceBuffer;
409                let n = out.len();
410                let mut dev_buf =
411                    DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
412                gs.generator
413                    .generate_uniform_f32(&mut dev_buf)
414                    .map_err(|e| RngError::GpuError(e.to_string()))?;
415                dev_buf
416                    .copy_to_host(out)
417                    .map_err(|e| RngError::GpuError(e.to_string()))?;
418                Ok(())
419            }
420        }
421    }
422
423    // -----------------------------------------------------------------------
424    // Normal f32
425    // -----------------------------------------------------------------------
426
427    /// Fills `out` with independent normal samples from `N(mean, std_dev²)`.
428    ///
429    /// Uses Box-Muller on the CPU path and the engine's native Gaussian kernel
430    /// on the GPU path.
431    ///
432    /// # Errors
433    ///
434    /// * [`RngError::EmptyBuffer`]                  — `out` is empty.
435    /// * [`RngError::InvalidParam`]                 — `std_dev < 0` or not finite.
436    /// * [`RngError::GpuError`]                     — CUDA failure (GPU path).
437    pub fn normal_f32(&mut self, out: &mut [f32], mean: f32, std_dev: f32) -> Result<(), RngError> {
438        if out.is_empty() {
439            return Err(RngError::EmptyBuffer);
440        }
441        if !std_dev.is_finite() || std_dev < 0.0 {
442            return Err(RngError::InvalidParam(format!(
443                "std_dev must be finite and >= 0, got {std_dev}"
444            )));
445        }
446        if !mean.is_finite() {
447            return Err(RngError::InvalidParam(format!(
448                "mean must be finite, got {mean}"
449            )));
450        }
451
452        match &mut self.inner {
453            #[cfg(feature = "cpu")]
454            RngEngineInner::Cpu(state) => {
455                let n = out.len();
456                let mut i = 0usize;
457                // Consume pairs from Box-Muller; handle the odd element.
458                while i + 1 < n {
459                    let (z0, z1) = state.next_normal_pair();
460                    out[i] = mean + std_dev * z0;
461                    out[i + 1] = mean + std_dev * z1;
462                    i += 2;
463                }
464                if i < n {
465                    let (z0, _) = state.next_normal_pair();
466                    out[i] = mean + std_dev * z0;
467                }
468                Ok(())
469            }
470            #[cfg(feature = "gpu")]
471            RngEngineInner::Gpu(gs) => {
472                use oxicuda_memory::DeviceBuffer;
473                let n = out.len();
474                let mut dev_buf =
475                    DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
476                gs.generator
477                    .generate_normal_f32(&mut dev_buf, mean, std_dev)
478                    .map_err(|e| RngError::GpuError(e.to_string()))?;
479                dev_buf
480                    .copy_to_host(out)
481                    .map_err(|e| RngError::GpuError(e.to_string()))?;
482                Ok(())
483            }
484        }
485    }
486
487    // -----------------------------------------------------------------------
488    // Bernoulli
489    // -----------------------------------------------------------------------
490
491    /// Fills `out` with Bernoulli(p) samples: each element is `1u8` with
492    /// probability `p` and `0u8` otherwise.
493    ///
494    /// # Errors
495    ///
496    /// * [`RngError::EmptyBuffer`]  — `out` is empty.
497    /// * [`RngError::InvalidParam`] — `p` is not in `[0.0, 1.0]`.
498    /// * [`RngError::GpuError`]     — CUDA failure (GPU path).
499    pub fn bernoulli(&mut self, out: &mut [u8], p: f32) -> Result<(), RngError> {
500        if out.is_empty() {
501            return Err(RngError::EmptyBuffer);
502        }
503        if !p.is_finite() || !(0.0..=1.0).contains(&p) {
504            return Err(RngError::InvalidParam(format!(
505                "p must be in [0.0, 1.0], got {p}"
506            )));
507        }
508
509        match &mut self.inner {
510            #[cfg(feature = "cpu")]
511            RngEngineInner::Cpu(state) => {
512                for slot in out.iter_mut() {
513                    *slot = u8::from(state.next_f32() < p);
514                }
515                Ok(())
516            }
517            #[cfg(feature = "gpu")]
518            RngEngineInner::Gpu(gs) => {
519                // GPU path: generate uniform f32 on device, threshold on host.
520                // A future optimisation can do the threshold in a PTX kernel.
521                use oxicuda_memory::DeviceBuffer;
522                let n = out.len();
523                let mut dev_buf =
524                    DeviceBuffer::<f32>::alloc(n).map_err(|e| RngError::GpuError(e.to_string()))?;
525                gs.generator
526                    .generate_uniform_f32(&mut dev_buf)
527                    .map_err(|e| RngError::GpuError(e.to_string()))?;
528
529                let mut host_buf = vec![0f32; n];
530                dev_buf
531                    .copy_to_host(&mut host_buf)
532                    .map_err(|e| RngError::GpuError(e.to_string()))?;
533                for (slot, &u) in out.iter_mut().zip(host_buf.iter()) {
534                    *slot = u8::from(u < p);
535                }
536                Ok(())
537            }
538        }
539    }
540
541    // -----------------------------------------------------------------------
542    // Uniform f64
543    // -----------------------------------------------------------------------
544
545    /// Fills `out` with independent uniform samples drawn from `[0.0, 1.0)`
546    /// with 52-bit mantissa precision.
547    ///
548    /// Each value is constructed from a 64-bit PCG output using the IEEE 754
549    /// exponent-field trick: the top 52 bits are inserted into the mantissa of
550    /// a double with exponent bias 1023 (∈ `[1.0, 2.0)`), then 1.0 is
551    /// subtracted to shift to `[0.0, 1.0)`.
552    ///
553    /// # Errors
554    ///
555    /// * [`RngError::EmptyBuffer`] — `out` is empty.
556    pub fn uniform_f64(&mut self, out: &mut [f64]) -> Result<(), RngError> {
557        if out.is_empty() {
558            return Err(RngError::EmptyBuffer);
559        }
560        match &mut self.inner {
561            #[cfg(feature = "cpu")]
562            RngEngineInner::Cpu(state) => {
563                for slot in out.iter_mut() {
564                    *slot = state.next_f64();
565                }
566                Ok(())
567            }
568            #[cfg(feature = "gpu")]
569            RngEngineInner::Gpu(_gs) => {
570                // GPU path: no native f64 cuRAND kernel wired yet; use CPU
571                // emulation on the host side for correctness.
572                Err(RngError::GpuError(
573                    "uniform_f64 on GPU path not yet implemented".to_string(),
574                ))
575            }
576        }
577    }
578
579    // -----------------------------------------------------------------------
580    // Normal f64
581    // -----------------------------------------------------------------------
582
583    /// Fills `out` with independent normal samples from `N(mean, std_dev²)`
584    /// with double precision.
585    ///
586    /// Uses Box-Muller on the CPU path.  Each pair of output values consumes
587    /// two `uniform_f64` draws; an odd-length buffer consumes one additional
588    /// pair (discarding the second normal from the last Box-Muller step).
589    ///
590    /// # Errors
591    ///
592    /// * [`RngError::EmptyBuffer`]  — `out` is empty.
593    /// * [`RngError::InvalidParam`] — `std_dev < 0` or not finite, or `mean`
594    ///   is not finite.
595    pub fn normal_f64(&mut self, out: &mut [f64], mean: f64, std_dev: f64) -> Result<(), RngError> {
596        if out.is_empty() {
597            return Err(RngError::EmptyBuffer);
598        }
599        if !std_dev.is_finite() || std_dev < 0.0 {
600            return Err(RngError::InvalidParam(format!(
601                "std_dev must be finite and >= 0, got {std_dev}"
602            )));
603        }
604        if !mean.is_finite() {
605            return Err(RngError::InvalidParam(format!(
606                "mean must be finite, got {mean}"
607            )));
608        }
609
610        match &mut self.inner {
611            #[cfg(feature = "cpu")]
612            RngEngineInner::Cpu(state) => {
613                let n = out.len();
614                let mut i = 0usize;
615                // Consume pairs from Box-Muller; handle the odd trailing element.
616                while i + 1 < n {
617                    let (z0, z1) = state.next_normal_pair_f64();
618                    out[i] = mean + std_dev * z0;
619                    out[i + 1] = mean + std_dev * z1;
620                    i += 2;
621                }
622                if i < n {
623                    let (z0, _) = state.next_normal_pair_f64();
624                    out[i] = mean + std_dev * z0;
625                }
626                Ok(())
627            }
628            #[cfg(feature = "gpu")]
629            RngEngineInner::Gpu(_gs) => Err(RngError::GpuError(
630                "normal_f64 on GPU path not yet implemented".to_string(),
631            )),
632        }
633    }
634
635    // -----------------------------------------------------------------------
636    // Streaming API
637    // -----------------------------------------------------------------------
638
639    /// Generates `total` f32 uniform samples and delivers them in chunks of at
640    /// most `chunk_size` elements, calling `consumer` once per chunk.
641    ///
642    /// The final chunk may be smaller than `chunk_size` when `total` is not a
643    /// multiple of `chunk_size`.
644    ///
645    /// # Determinism
646    ///
647    /// Given the same seed and the same `total`, the complete sequence of
648    /// generated values is identical regardless of `chunk_size`.  The chunk
649    /// size only affects how many values are presented per callback.
650    ///
651    /// # Errors
652    ///
653    /// * [`RngError::EmptyBuffer`] — `total == 0` or `chunk_size == 0`.
654    pub fn fill_uniform_chunked<F: FnMut(&[f32])>(
655        &mut self,
656        total: usize,
657        chunk_size: usize,
658        consumer: &mut F,
659    ) -> Result<(), RngError> {
660        if total == 0 || chunk_size == 0 {
661            return Err(RngError::EmptyBuffer);
662        }
663
664        let mut buf = vec![0f32; chunk_size];
665        let mut remaining = total;
666
667        while remaining > 0 {
668            let n = remaining.min(chunk_size);
669            self.uniform_f32(&mut buf[..n])?;
670            consumer(&buf[..n]);
671            remaining -= n;
672        }
673        Ok(())
674    }
675
676    /// Generates `total` f64 uniform samples and delivers them in chunks of at
677    /// most `chunk_size` elements, calling `consumer` once per chunk.
678    ///
679    /// The final chunk may be smaller than `chunk_size` when `total` is not a
680    /// multiple of `chunk_size`.
681    ///
682    /// # Determinism
683    ///
684    /// Given the same seed and the same `total`, the complete sequence of
685    /// generated values is identical regardless of `chunk_size`.
686    ///
687    /// # Errors
688    ///
689    /// * [`RngError::EmptyBuffer`] — `total == 0` or `chunk_size == 0`.
690    pub fn fill_uniform_chunked_f64<F: FnMut(&[f64])>(
691        &mut self,
692        total: usize,
693        chunk_size: usize,
694        consumer: &mut F,
695    ) -> Result<(), RngError> {
696        if total == 0 || chunk_size == 0 {
697            return Err(RngError::EmptyBuffer);
698        }
699
700        let mut buf = vec![0f64; chunk_size];
701        let mut remaining = total;
702
703        while remaining > 0 {
704            let n = remaining.min(chunk_size);
705            self.uniform_f64(&mut buf[..n])?;
706            consumer(&buf[..n]);
707            remaining -= n;
708        }
709        Ok(())
710    }
711
712    /// Generates `total` f32 normal samples from `N(mean, std_dev²)` and
713    /// delivers them in chunks of at most `chunk_size` elements, calling
714    /// `consumer` once per chunk.
715    ///
716    /// The final chunk may be smaller than `chunk_size`.
717    ///
718    /// # Determinism
719    ///
720    /// Given the same seed and the same `total`, the full sequence is identical
721    /// regardless of `chunk_size`.  Note: because Box-Muller consumes values in
722    /// pairs, chunk boundaries that split a pair internally will advance the
723    /// stream by a full pair — the global sequence is determined by `total`, not
724    /// chunk boundaries.
725    ///
726    /// # Errors
727    ///
728    /// * [`RngError::EmptyBuffer`] — `total == 0` or `chunk_size == 0`.
729    /// * [`RngError::InvalidParam`] — `std_dev < 0` or not finite, or `mean`
730    ///   is not finite.
731    pub fn fill_normal_chunked<F: FnMut(&[f32])>(
732        &mut self,
733        total: usize,
734        chunk_size: usize,
735        mean: f32,
736        std_dev: f32,
737        consumer: &mut F,
738    ) -> Result<(), RngError> {
739        if total == 0 || chunk_size == 0 {
740            return Err(RngError::EmptyBuffer);
741        }
742        if !std_dev.is_finite() || std_dev < 0.0 {
743            return Err(RngError::InvalidParam(format!(
744                "std_dev must be finite and >= 0, got {std_dev}"
745            )));
746        }
747        if !mean.is_finite() {
748            return Err(RngError::InvalidParam(format!(
749                "mean must be finite, got {mean}"
750            )));
751        }
752
753        let mut buf = vec![0f32; chunk_size];
754        let mut remaining = total;
755
756        while remaining > 0 {
757            let n = remaining.min(chunk_size);
758            self.normal_f32(&mut buf[..n], mean, std_dev)?;
759            consumer(&buf[..n]);
760            remaining -= n;
761        }
762        Ok(())
763    }
764}
765
766// ---------------------------------------------------------------------------
767// Unit tests (CPU-path only, no CUDA device required)
768// ---------------------------------------------------------------------------
769
#[cfg(test)]
mod tests {
    // Unit tests for the CPU generator internals and the public `RngEngine`
    // surface.  Tests that touch `CpuRngState` directly are gated on the `cpu`
    // feature; the remaining tests only go through the public API.
    use super::*;

    // -----------------------------------------------------------------------
    // PCG internals
    // -----------------------------------------------------------------------

    /// The stream increment must have its low bit set for every seed.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_inc_is_odd() {
        // The inc field must be odd for full-period PCG.
        for seed in [0u64, 1, 42, u64::MAX, u64::MAX / 2] {
            let state = CpuRngState::new(seed);
            assert_eq!(state.inc & 1, 1, "inc must be odd for seed={seed}");
        }
    }

    /// Every `next_f32` sample must fall in the half-open interval `[0, 1)`.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_uniform_in_range() {
        let mut state = CpuRngState::new(12345);
        for _ in 0..10_000 {
            let v = state.next_f32();
            assert!(
                (0.0..1.0).contains(&v),
                "uniform sample {v} not in [0.0, 1.0)"
            );
        }
    }

    /// Two states built from the same seed must emit identical `u32` streams.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_deterministic_replay() {
        let mut a = CpuRngState::new(777);
        let mut b = CpuRngState::new(777);
        for _ in 0..1000 {
            assert_eq!(a.next_u32(), b.next_u32());
        }
    }

    /// States built from different seeds must not emit the same sequence.
    #[test]
    #[cfg(feature = "cpu")]
    fn pcg_different_seeds_differ() {
        let mut a = CpuRngState::new(0);
        let mut b = CpuRngState::new(1);
        // Extremely unlikely that 100 consecutive u32 outputs are identical.
        let outputs_a: Vec<u32> = (0..100).map(|_| a.next_u32()).collect();
        let outputs_b: Vec<u32> = (0..100).map(|_| b.next_u32()).collect();
        assert_ne!(
            outputs_a, outputs_b,
            "different seeds should produce different sequences"
        );
    }

    // -----------------------------------------------------------------------
    // Box-Muller
    // -----------------------------------------------------------------------

    /// Both halves of every generated normal pair must be finite (no NaN/inf).
    #[test]
    #[cfg(feature = "cpu")]
    fn box_muller_pair_is_finite() {
        let mut state = CpuRngState::new(42);
        for _ in 0..10_000 {
            let (z0, z1) = state.next_normal_pair();
            assert!(z0.is_finite(), "z0 is not finite: {z0}");
            assert!(z1.is_finite(), "z1 is not finite: {z1}");
        }
    }

    // -----------------------------------------------------------------------
    // RngEngineKind
    // -----------------------------------------------------------------------

    /// `as_str` yields the lowercase algorithm name for each variant.
    #[test]
    fn engine_kind_as_str() {
        for (kind, expected) in [
            (RngEngineKind::Philox, "philox"),
            (RngEngineKind::Xorwow, "xorwow"),
            (RngEngineKind::Mrg32k3a, "mrg32k3a"),
        ] {
            assert_eq!(kind.as_str(), expected);
        }
    }

    /// The `Display` output matches the lowercase algorithm name.
    #[test]
    fn engine_kind_display() {
        for (kind, expected) in [
            (RngEngineKind::Philox, "philox"),
            (RngEngineKind::Xorwow, "xorwow"),
            (RngEngineKind::Mrg32k3a, "mrg32k3a"),
        ] {
            assert_eq!(format!("{kind}"), expected);
        }
    }

    // -----------------------------------------------------------------------
    // RngEngine construction & properties
    // -----------------------------------------------------------------------

    /// Construction succeeds for every algorithm kind.
    #[test]
    fn engine_new_returns_ok() {
        for kind in [
            RngEngineKind::Philox,
            RngEngineKind::Xorwow,
            RngEngineKind::Mrg32k3a,
        ] {
            assert!(
                RngEngine::new(kind, 0).is_ok(),
                "construction failed for {kind}"
            );
        }
    }

    /// `kind()` reports the variant the engine was constructed with.
    #[test]
    fn engine_kind_accessor() {
        let eng = RngEngine::new(RngEngineKind::Mrg32k3a, 1).unwrap();
        assert_eq!(eng.kind(), RngEngineKind::Mrg32k3a);
    }

    /// Without the `gpu` feature the engine must report the CPU path.
    ///
    /// Only compiled when `gpu` is disabled: with that feature enabled the
    /// outcome would depend on whether the host has a usable CUDA device, so
    /// an unconditional assertion could fail on a developer machine.
    #[test]
    #[cfg(not(feature = "gpu"))]
    fn engine_is_not_gpu_in_ci() {
        let eng = RngEngine::new(RngEngineKind::Philox, 42).unwrap();
        assert!(!eng.is_gpu());
    }

    // -----------------------------------------------------------------------
    // Error handling
    // -----------------------------------------------------------------------

    /// `uniform_f32` rejects a zero-length output buffer.
    #[test]
    fn uniform_empty_buffer_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out: Vec<f32> = vec![];
        assert!(matches!(
            eng.uniform_f32(&mut out),
            Err(RngError::EmptyBuffer)
        ));
    }

    /// `normal_f32` rejects a zero-length output buffer.
    #[test]
    fn normal_empty_buffer_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out: Vec<f32> = vec![];
        assert!(matches!(
            eng.normal_f32(&mut out, 0.0, 1.0),
            Err(RngError::EmptyBuffer)
        ));
    }

    /// `bernoulli` rejects a zero-length output buffer.
    #[test]
    fn bernoulli_empty_buffer_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out: Vec<u8> = vec![];
        assert!(matches!(
            eng.bernoulli(&mut out, 0.5),
            Err(RngError::EmptyBuffer)
        ));
    }

    /// A negative standard deviation is reported as `InvalidParam`.
    #[test]
    fn normal_negative_stddev_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out = vec![0f32; 10];
        assert!(matches!(
            eng.normal_f32(&mut out, 0.0, -1.0),
            Err(RngError::InvalidParam(_))
        ));
    }

    /// A NaN mean is reported as `InvalidParam`.
    #[test]
    fn normal_nan_mean_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out = vec![0f32; 10];
        assert!(matches!(
            eng.normal_f32(&mut out, f32::NAN, 1.0),
            Err(RngError::InvalidParam(_))
        ));
    }

    /// Probabilities below 0, above 1, or NaN are reported as `InvalidParam`.
    #[test]
    fn bernoulli_invalid_p_error() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 0).unwrap();
        let mut out = vec![0u8; 10];
        for p in [-0.1, 1.1, f32::NAN] {
            assert!(
                matches!(
                    eng.bernoulli(&mut out, p),
                    Err(RngError::InvalidParam(_))
                ),
                "p={p} should be rejected as an invalid parameter"
            );
        }
    }

    // -----------------------------------------------------------------------
    // Statistical sanity — quick checks (small N, loose tolerances)
    // -----------------------------------------------------------------------

    /// Every sample written by `uniform_f32` lies in `[0, 1)`.
    #[test]
    fn uniform_in_range() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 42).unwrap();
        let mut out = vec![0f32; 1_000];
        eng.uniform_f32(&mut out).unwrap();
        for &v in &out {
            assert!((0.0..1.0).contains(&v), "uniform sample {v} out of [0,1)");
        }
    }

    /// An odd-length buffer is fully overwritten by `normal_f32`.
    #[test]
    fn normal_odd_length_fills_all_elements() {
        // Exercises the trailing odd-element branch in normal_f32.
        let mut eng = RngEngine::new(RngEngineKind::Xorwow, 99).unwrap();
        // Pre-fill with NaN so any element left untouched fails the check.
        let mut out = vec![f32::NAN; 7]; // odd length
        eng.normal_f32(&mut out, 0.0, 1.0).unwrap();
        for (i, &v) in out.iter().enumerate() {
            assert!(v.is_finite(), "element {i} is not finite: {v}");
        }
    }

    /// `bernoulli` writes only the byte values 0 and 1.
    #[test]
    fn bernoulli_outputs_only_zero_or_one() {
        let mut eng = RngEngine::new(RngEngineKind::Mrg32k3a, 555).unwrap();
        // Pre-fill with 255 so any element left untouched fails the check.
        let mut out = vec![255u8; 1_000];
        eng.bernoulli(&mut out, 0.5).unwrap();
        for &b in &out {
            assert!(b == 0 || b == 1, "bernoulli output {b} is not 0 or 1");
        }
    }

    /// With `p = 0.0` every output byte is 0.
    #[test]
    fn bernoulli_p_zero_produces_all_zeros() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 1).unwrap();
        // Start from all ones so untouched elements would be detected.
        let mut out = vec![1u8; 500];
        eng.bernoulli(&mut out, 0.0).unwrap();
        assert!(out.iter().all(|&b| b == 0));
    }

    /// With `p = 1.0` every output byte is 1.
    #[test]
    fn bernoulli_p_one_produces_all_ones() {
        let mut eng = RngEngine::new(RngEngineKind::Philox, 2).unwrap();
        // Start from all zeros so untouched elements would be detected.
        let mut out = vec![0u8; 500];
        eng.bernoulli(&mut out, 1.0).unwrap();
        assert!(out.iter().all(|&b| b == 1));
    }
}
1012
1013// ---------------------------------------------------------------------------
1014// Compile-time Send+Sync assertions for the CPU path
1015// ---------------------------------------------------------------------------
1016
1017/// Verifies at compile time that [`RngEngine`] is both [`Send`] and [`Sync`]
1018/// on the CPU path (no `gpu` feature).
1019///
1020/// If the type bounds fail, this module fails to compile — no runtime test
1021/// needed.
1022#[cfg(not(feature = "gpu"))]
1023mod send_sync_assertions {
1024    use super::RngEngine;
1025
1026    fn _assert_send<T: Send>() {}
1027    fn _assert_sync<T: Sync>() {}
1028
1029    fn _check_rng_engine_send_sync() {
1030        _assert_send::<RngEngine>();
1031        _assert_sync::<RngEngine>();
1032    }
1033}