stedy/
rng.rs

1use crate::chacha::ChaCha20;
2
3#[cfg(feature = "getrandom")]
4use crate::Error;
5
6pub struct Rng {
7    cipher: ChaCha20,
8    buffer: [u8; 64],
9    index: usize,
10}
11
12impl Rng {
13    #[cfg(feature = "getrandom")]
14    pub fn new() -> Result<Self, Error> {
15        let mut seed = [0u8; 32];
16        getrandom::getrandom(&mut seed).or(Err(Error::Entropy))?;
17        Ok(Self::from(seed))
18    }
19
20    #[cfg(not(feature = "getrandom"))]
21    pub fn new(seed: [u8; 32]) -> Self {
22        Self::from(seed)
23    }
24
25    fn refill_buffer(&mut self) {
26        self.cipher.apply_keystream(&mut self.buffer);
27        self.index = 0;
28    }
29
30    pub fn next_u32(&mut self) -> u32 {
31        if self.index + 4 > 64 {
32            self.refill_buffer();
33        }
34        let result =
35            u32::from_le_bytes(self.buffer[self.index..self.index + 4].try_into().unwrap());
36        self.index += 4;
37        result
38    }
39
40    pub fn next_u64(&mut self) -> u64 {
41        if self.index + 8 > 64 {
42            self.refill_buffer();
43        }
44        let result =
45            u64::from_le_bytes(self.buffer[self.index..self.index + 8].try_into().unwrap());
46        self.index += 8;
47        result
48    }
49
50    pub fn fill(&mut self, bytes: &mut [u8]) {
51        for chunk in bytes.chunks_mut(64) {
52            if self.index == 64 {
53                self.refill_buffer();
54            }
55            let remaining = 64 - self.index;
56            let size = chunk.len().min(remaining);
57            chunk.copy_from_slice(&self.buffer[self.index..self.index + size]);
58            self.index += size;
59        }
60    }
61}
62
63impl From<&[u8; 32]> for Rng {
64    fn from(value: &[u8; 32]) -> Self {
65        Self {
66            cipher: ChaCha20::from(value, &[0u8; 12]),
67            buffer: [0u8; 64],
68            index: 64,
69        }
70    }
71}
72
73impl From<[u8; 32]> for Rng {
74    fn from(value: [u8; 32]) -> Self {
75        Self::from(&value)
76    }
77}
78
79#[cfg(test)]
80mod tests {
81    use super::*;
82
83    #[test]
84    fn test_next_u32() {
85        let seed = [0u8; 32];
86        let mut rng = Rng::from(seed);
87        let result = rng.next_u32();
88        assert_eq!(result, 2917185654);
89    }
90
91    #[test]
92    fn test_next_u64() {
93        let seed = [0u8; 32];
94        let mut rng = Rng::from(seed);
95        let result = rng.next_u64();
96        assert_eq!(result, 10393729187455219830);
97    }
98
99    #[test]
100    fn test_fill() {
101        let seed = [0u8; 32];
102        let mut rng = Rng::from(seed);
103        let mut bytes = [0u8; 32];
104        rng.fill(&mut bytes);
105        assert_eq!(
106            bytes,
107            [
108                118, 184, 224, 173, 160, 241, 61, 144, 64, 93, 106, 229, 83, 134, 189, 40, 189,
109                210, 25, 184, 160, 141, 237, 26, 168, 54, 239, 204, 139, 119, 13, 199,
110            ]
111        );
112    }
113
114    #[test]
115    fn test_fill_multiple_blocks() {
116        let seed = [0u8; 32];
117        let mut rng = Rng::from(seed);
118        let mut bytes = [0u8; 96];
119        rng.fill(&mut bytes);
120        assert_eq!(
121            bytes,
122            [
123                118, 184, 224, 173, 160, 241, 61, 144, 64, 93, 106, 229, 83, 134, 189, 40, 189,
124                210, 25, 184, 160, 141, 237, 26, 168, 54, 239, 204, 139, 119, 13, 199, 218, 65, 89,
125                124, 81, 87, 72, 141, 119, 36, 224, 63, 184, 216, 74, 55, 106, 67, 184, 244, 21,
126                24, 161, 28, 195, 135, 182, 105, 178, 238, 101, 134, 233, 191, 7, 19, 245, 160, 5,
127                234, 216, 231, 253, 153, 32, 171, 181, 37, 118, 221, 48, 24, 232, 110, 136, 115,
128                186, 240, 188, 242, 185, 153, 119, 42
129            ]
130        );
131    }
132
133    #[test]
134    fn test_getrandom_fill() {
135        let mut rng = Rng::new().unwrap();
136        let mut bytes = [0u8; 32];
137        rng.fill(&mut bytes);
138        assert_ne!(bytes, [0; 32]);
139        assert_ne!(
140            bytes,
141            [
142                118, 184, 224, 173, 160, 241, 61, 144, 64, 93, 106, 229, 83, 134, 189, 40, 189,
143                210, 25, 184, 160, 141, 237, 26, 168, 54, 239, 204, 139, 119, 13, 199,
144            ]
145        );
146    }
147}