// tfhe_csprng/generators/mod.rs

//! A module containing random generator objects.
//!
//! See [crate-level](`crate`) explanations.
use crate::seeders::SeedKind;
use std::error::Error;
use std::fmt::{Display, Formatter};

/// The number of children created when a generator is forked.
///
/// Counts can be compared and ordered, which makes them convenient to use in assertions.
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)]
pub struct ChildrenCount(pub u64);

/// The number of bytes each child can generate, when a generator is forked.
///
/// Budgets can be compared and ordered, which makes them convenient to use in assertions.
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)]
pub struct BytesPerChild(pub u64);

/// Number of bytes separating two table indices.
///
/// The count is held in a `u128`; equality and ordering are those of the wrapped integer.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
pub struct ByteCount(pub u128);

/// Computes the exact product of two `u64` values as a `u128`.
///
/// Two 64-bit operands produce at most a 128-bit product, so the multiplication cannot
/// overflow.
pub(crate) fn widening_mul(a: u64, b: u64) -> u128 {
    let (lhs, rhs) = (u128::from(a), u128::from(b));
    lhs * rhs
}

/// An error occurring during a generator fork.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForkError {
    /// The children would, together, output bytes past the parent generator's bound.
    ForkTooLarge,
    /// The fork was requested with zero children.
    ZeroChildrenCount,
    /// The fork was requested with zero bytes per child.
    ZeroBytesPerChild,
}

impl Display for ForkError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Every message is a constant, so no formatting machinery is needed.
        let message = match self {
            ForkError::ForkTooLarge => {
                "The children generators would output bytes after the parent bound."
            }
            ForkError::ZeroChildrenCount => {
                "The number of children in the fork must be greater than zero."
            }
            ForkError::ZeroBytesPerChild => {
                "The number of bytes per child must be greater than zero."
            }
        };
        f.write_str(message)
    }
}

impl Error for ForkError {}

59/// A trait for cryptographically secure pseudo-random generators.
60///
61/// See the [crate-level](#crate) documentation for details.
62pub trait RandomGenerator: Iterator<Item = u8> {
63    /// The iterator over children generators, returned by `try_fork` in case of success.
64    type ChildrenIter: Iterator<Item = Self>;
65
66    /// Creates a new generator from a seed.
67    ///
68    /// This operation is usually costly to perform, as the aes round keys need to be generated from
69    /// the seed.
70    fn new(seed: impl Into<SeedKind>) -> Self;
71
72    /// Returns the number of bytes that can still be outputted by the generator before reaching its
73    /// bound.
74    ///
75    /// Note:
76    /// -----
77    ///
78    /// A fresh generator can generate 2¹³² bytes. Unfortunately, no rust integer type in is able
79    /// to encode such a large number. Consequently [`ByteCount`] uses the largest integer type
80    /// available to encode this value: the `u128` type. For this reason, this method does not
81    /// effectively return the number of remaining bytes, but instead
82    /// `min(2¹²⁸-1, remaining_bytes)`.
83    fn remaining_bytes(&self) -> ByteCount;
84
85    /// Returns the next byte of the stream, if the generator did not yet reach its bound.
86    fn next_byte(&mut self) -> Option<u8> {
87        self.next()
88    }
89
90    /// Tries to fork the generator into an iterator of `n_children` new generators, each able to
91    /// output `n_bytes` bytes.
92    ///
93    /// Note:
94    /// -----
95    ///
96    /// To be successful, the number of remaining bytes for the parent generator must be larger than
97    /// `n_children*n_bytes`.
98    fn try_fork(
99        &mut self,
100        n_children: ChildrenCount,
101        n_bytes: BytesPerChild,
102    ) -> Result<Self::ChildrenIter, ForkError>;
103}
104
/// A trait extending [`RandomGenerator`] to the parallel iterators of `rayon`.
#[cfg(feature = "parallel")]
pub trait ParallelRandomGenerator: RandomGenerator + Send {
    /// The parallel iterator over children generators, returned by `par_try_fork` in case of
    /// success.
    type ParChildrenIter: rayon::prelude::IndexedParallelIterator<Item = Self>;

    /// Tries to fork the generator into a parallel iterator of `n_children` new generators, each
    /// able to output `n_bytes` bytes.
    ///
    /// This is the parallel counterpart of [`RandomGenerator::try_fork`].
    ///
    /// Note:
    /// -----
    ///
    /// To be successful, the parent generator must have at least `n_children * n_bytes`
    /// remaining bytes, so that the children never output bytes past the parent bound.
    ///
    /// # Errors
    ///
    /// Returns a [`ForkError`] when the fork cannot be performed.
    fn par_try_fork(
        &mut self,
        n_children: ChildrenCount,
        n_bytes: BytesPerChild,
    ) -> Result<Self::ParChildrenIter, ForkError>;
}

126mod aes_ctr;
127
128mod implem;
129pub use implem::*;
130
131pub mod default;
132/// Convenience alias for the most efficient CSPRNG implementation available.
133pub use default::DefaultRandomGenerator;
134
#[cfg(test)]
#[allow(unused)] // to please clippy when tests are not activated
pub mod generator_generic_test {
    //! Test routines that are generic over the [`RandomGenerator`] implementation under test.
    use super::*;
    use crate::seeders::{Seed, XofSeed};

    /// Number of repetitions performed by the randomized tests.
    const REPEATS: usize = 1_000;

    // Randomness is drawn with `rand::random` instead of `Rng::gen`, and no local is named
    // `gen`: `gen` is a reserved keyword starting with the Rust 2024 edition.

    /// Endless stream of random seeds.
    fn any_seed() -> impl Iterator<Item = Seed> {
        std::iter::repeat_with(|| Seed(rand::random()))
    }

    /// Endless stream of random children counts in `1..=16`.
    fn some_children_count() -> impl Iterator<Item = ChildrenCount> {
        std::iter::repeat_with(|| ChildrenCount(rand::random::<u64>() % 16 + 1))
    }

    /// Endless stream of random per-child byte budgets in `1..=128`.
    fn some_bytes_per_child() -> impl Iterator<Item = BytesPerChild> {
        std::iter::repeat_with(|| BytesPerChild(rand::random::<u64>() % 128 + 1))
    }

    /// Checks that the PRNG roughly generates uniform numbers.
    ///
    /// To do that, we build a histogram of the occurrences of each byte value over a fixed
    /// number of samples, and check that the empirical probability of every bin is close to
    /// the theoretical one.
    pub fn test_roughly_uniform<G: RandomGenerator>() {
        // Number of bins to use for the histogram: one per possible byte value.
        const N_BINS: usize = u8::MAX as usize + 1;
        // Number of samples to use for the histogram.
        let n_samples = 10_000_000_usize;
        // Theoretical probability of each bin.
        let expected_prob: f64 = 1. / N_BINS as f64;
        // Absolute error allowed on the empirical probabilities.
        // This value was tuned to make the test pass on an arguably correct state of the
        // implementation. 10^-3 precision is arguably pretty fine for this rough test, but it
        // would be interesting to improve this test.
        let precision = 10f64.powi(-3);

        for _ in 0..REPEATS {
            // We instantiate a new generator.
            let seed = any_seed().next().unwrap();
            let mut generator = G::new(seed);
            // We fill a fresh histogram.
            let mut counts = [0usize; N_BINS];
            for _ in 0..n_samples {
                counts[generator.next_byte().unwrap() as usize] += 1;
            }
            // We check that the empirical probabilities are close enough to the theoretical one.
            for (byte, &count) in counts.iter().enumerate() {
                let empirical_prob = count as f64 / n_samples as f64;
                assert!(
                    (empirical_prob - expected_prob).abs() < precision,
                    "byte {}: empirical probability {} is not within {} of {}",
                    byte,
                    empirical_prob,
                    precision,
                    expected_prob
                );
            }
        }
    }

    /// Checks that, given the same seed, two generators output the same stream.
    pub fn test_generator_determinism<G: RandomGenerator>() {
        for _ in 0..REPEATS {
            let seed = any_seed().next().unwrap();
            let mut first_generator = G::new(seed);
            let mut second_generator = G::new(seed);
            for _ in 0..1024 {
                assert_eq!(first_generator.next(), second_generator.next());
            }
        }
    }

    /// Checks that forks return a bounded child, and that exactly the requested number of bytes
    /// can be generated from it.
    pub fn test_fork_children<G: RandomGenerator>() {
        for _ in 0..REPEATS {
            let ((seed, n_children), n_bytes) = any_seed()
                .zip(some_children_count())
                .zip(some_bytes_per_child())
                .next()
                .unwrap();
            let mut parent = G::new(seed);
            let mut bounded = parent.try_fork(n_children, n_bytes).unwrap().next().unwrap();
            assert_eq!(bounded.remaining_bytes(), ByteCount(u128::from(n_bytes.0)));
            for _ in 0..n_bytes.0 {
                bounded.next().unwrap();
            }

            // Assert we are at the bound
            assert!(bounded.next().is_none());
        }
    }

    /// Checks that a bounded prng returns none when exceeding the allowed number of bytes.
    ///
    /// To properly check for panic use `#[should_panic(expected = "expected test panic")]` as an
    /// attribute on the test function.
    pub fn test_bounded_none_should_panic<G: RandomGenerator>() {
        let ((seed, n_children), n_bytes) = any_seed()
            .zip(some_children_count())
            .zip(some_bytes_per_child())
            .next()
            .unwrap();
        let mut parent = G::new(seed);
        let mut bounded = parent.try_fork(n_children, n_bytes).unwrap().next().unwrap();
        assert_eq!(bounded.remaining_bytes(), ByteCount(u128::from(n_bytes.0)));
        for _ in 0..n_bytes.0 {
            assert!(bounded.next().is_some());
        }

        // One call too many, should panic
        bounded.next().ok_or("expected test panic").unwrap();
    }

    /// Checks the output of a generator seeded with `Seed(1)` against reference bytes.
    pub fn test_vectors<G: RandomGenerator>() {
        // Number of random bytes to generate: 2 batches of 8 AES blocks of 16 bytes each.
        const N_BYTES: usize = 16 * 2 * 8;

        const EXPECTED_BYTE: [u8; N_BYTES] = [
            14, 216, 93, 249, 97, 26, 187, 114, 73, 205, 209, 104, 197, 70, 126, 250, 235, 1, 136,
            141, 46, 146, 174, 231, 14, 204, 28, 99, 139, 246, 214, 112, 253, 151, 34, 114, 235, 7,
            76, 37, 36, 154, 226, 148, 68, 238, 117, 87, 212, 183, 174, 200, 222, 153, 62, 48, 166,
            134, 27, 97, 230, 206, 78, 128, 151, 166, 15, 156, 120, 158, 35, 41, 121, 55, 180, 184,
            108, 160, 33, 208, 255, 147, 246, 159, 10, 239, 6, 103, 124, 123, 83, 72, 189, 237,
            225, 36, 30, 151, 134, 94, 211, 181, 108, 239, 137, 18, 246, 237, 233, 59, 61, 24, 111,
            198, 76, 92, 86, 129, 171, 50, 124, 2, 72, 143, 160, 223, 32, 187, 175, 239, 111, 51,
            85, 110, 134, 45, 193, 113, 247, 249, 78, 230, 103, 123, 66, 48, 31, 169, 228, 140,
            202, 168, 202, 199, 147, 89, 135, 104, 254, 198, 72, 31, 103, 236, 207, 138, 24, 100,
            230, 168, 233, 214, 130, 195, 0, 25, 220, 136, 128, 173, 40, 154, 116, 87, 114, 187,
            170, 150, 131, 163, 155, 98, 217, 198, 238, 178, 165, 214, 168, 252, 107, 123, 214, 33,
            17, 114, 35, 23, 172, 145, 5, 39, 16, 33, 92, 163, 132, 240, 167, 128, 226, 165, 80, 9,
            153, 252, 139, 0, 139, 0, 54, 188, 253, 141, 2, 78, 97, 53, 214, 173, 155, 84, 98, 51,
            70, 110, 91, 181, 229, 231, 27, 225, 185, 143, 63, 238,
        ];

        let seed = Seed(1u128);

        let bytes = G::new(seed).take(N_BYTES).collect::<Vec<_>>();
        assert_eq!(bytes, EXPECTED_BYTE);
    }

    /// Checks the output of a generator seeded with a `u128`-based [`XofSeed`] against
    /// reference bytes.
    pub fn test_vectors_xof_seed<G: RandomGenerator>() {
        // Number of random bytes to generate: 2 batches of 8 AES blocks of 16 bytes each.
        const N_BYTES: usize = 16 * 2 * 8;

        const EXPECTED_BYTE: [u8; N_BYTES] = [
            134, 231, 117, 200, 60, 174, 158, 95, 80, 64, 236, 147, 204, 196, 251, 198, 110, 155,
            74, 69, 162, 251, 224, 46, 46, 83, 209, 224, 89, 108, 68, 240, 37, 16, 109, 194, 92, 3,
            164, 21, 167, 224, 205, 31, 90, 178, 59, 150, 142, 238, 113, 144, 181, 118, 160, 72,
            187, 38, 29, 61, 189, 229, 66, 22, 4, 38, 210, 63, 232, 182, 115, 49, 96, 6, 120, 226,
            40, 51, 144, 59, 136, 224, 252, 195, 50, 250, 134, 45, 149, 220, 32, 27, 35, 225, 190,
            73, 161, 182, 250, 149, 153, 131, 220, 143, 181, 152, 187, 25, 62, 197, 24, 10, 142,
            57, 172, 15, 17, 244, 242, 232, 51, 50, 244, 85, 58, 69, 28, 113, 151, 143, 138, 166,
            198, 16, 210, 46, 234, 138, 32, 124, 98, 167, 141, 251, 60, 13, 158, 106, 29, 86, 63,
            73, 42, 138, 174, 195, 192, 72, 122, 74, 54, 134, 107, 144, 241, 12, 33, 70, 27, 116,
            154, 123, 1, 252, 141, 73, 79, 30, 162, 43, 57, 8, 99, 62, 222, 117, 232, 147, 81, 189,
            54, 17, 233, 33, 41, 132, 155, 246, 185, 189, 17, 77, 32, 107, 134, 61, 174, 64, 174,
            80, 229, 239, 243, 143, 152, 249, 254, 125, 42, 0, 170, 253, 34, 57, 100, 82, 244, 9,
            101, 126, 138, 218, 215, 55, 58, 177, 154, 5, 28, 113, 89, 123, 129, 254, 212, 191,
            162, 44, 120, 67, 241, 157, 31, 162, 113, 91,
        ];

        let seed = 1u128;
        let xof_seed = XofSeed::new_u128(seed, *b"abcdefgh");

        let bytes = G::new(xof_seed).take(N_BYTES).collect::<Vec<_>>();
        assert_eq!(bytes, EXPECTED_BYTE);
    }

    /// Checks the output of a generator seeded with a byte-vector-based [`XofSeed`] against
    /// reference bytes.
    pub fn test_vectors_xof_seed_bytes<G: RandomGenerator>() {
        // Number of random bytes to generate: 2 batches of 8 AES blocks of 16 bytes each.
        const N_BYTES: usize = 16 * 2 * 8;

        const EXPECTED_BYTE: [u8; N_BYTES] = [
            21, 82, 236, 82, 18, 196, 63, 129, 54, 134, 70, 114, 199, 200, 11, 5, 52, 170, 218, 49,
            127, 45, 5, 252, 214, 82, 127, 196, 241, 83, 161, 79, 139, 183, 33, 122, 126, 177, 23,
            36, 161, 122, 7, 112, 237, 154, 195, 90, 202, 218, 64, 90, 86, 190, 139, 169, 192, 105,
            248, 220, 126, 133, 60, 124, 81, 72, 183, 238, 253, 138, 141, 144, 167, 168, 94, 19,
            172, 92, 235, 113, 185, 31, 150, 143, 165, 220, 115, 83, 180, 1, 10, 130, 140, 32, 74,
            132, 76, 22, 120, 126, 68, 154, 95, 61, 202, 79, 126, 38, 217, 181, 243, 6, 218, 75,
            232, 235, 194, 255, 254, 184, 18, 122, 51, 222, 61, 167, 175, 97, 188, 186, 217, 105,
            72, 205, 130, 3, 204, 157, 252, 27, 20, 212, 136, 70, 65, 215, 164, 130, 242, 107, 214,
            150, 211, 59, 92, 13, 148, 219, 96, 181, 5, 38, 170, 48, 218, 111, 131, 246, 102, 169,
            17, 182, 253, 41, 209, 185, 79, 245, 30, 142, 192, 127, 78, 178, 68, 223, 89, 210, 27,
            84, 164, 163, 216, 188, 190, 128, 154, 224, 160, 53, 249, 10, 250, 95, 160, 94, 28, 41,
            34, 254, 232, 137, 185, 82, 82, 192, 74, 197, 19, 46, 180, 169, 182, 216, 221, 127,
            196, 185, 156, 82, 32, 133, 97, 140, 183, 67, 37, 110, 31, 210, 197, 27, 81, 197, 132,
            136, 98, 78, 218, 252, 247, 239, 205, 21, 166, 218,
        ];

        let seed = vec![
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 25, 26, 27, 28, 29, 30, 31,
        ];
        let xof_seed = XofSeed::new(seed, *b"abcdefgh");

        let bytes = G::new(xof_seed).take(N_BYTES).collect::<Vec<_>>();
        assert_eq!(bytes, EXPECTED_BYTE);
    }
}