winter_crypto/random/
default.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6use alloc::vec::Vec;
7
8use math::{FieldElement, StarkField};
9
10use crate::{errors::RandomCoinError, Digest, ElementHasher, RandomCoin};
11
12// DEFAULT RANDOM COIN IMPLEMENTATION
13// ================================================================================================
14
15/// Pseudo-random element generator for finite fields, which is a default implementation of the
16/// RandomCoin trait.
17///
18/// A random coin can be used to draw elements uniformly at random from the specified base field
19/// or from any extension of the base field.
20///
21/// Internally we use a cryptographic hash function (which is specified via the `H` type parameter),
22/// to draw elements from the field. The coin works roughly as follows:
23/// - The internal state of the coin consists of a `seed` and a `counter`. At instantiation time,
24///   the `seed` is set to a hash of the provided bytes, and the `counter` is set to 0.
25/// - To draw the next element, we increment the `counter` and compute hash(`seed` || `counter`). If
26///   the resulting value is a valid field element, we return the result; otherwise we try again
27///   until a valid element is found or the number of allowed tries is exceeded.
28/// - We can also re-seed the coin with a new value. During the reseeding procedure, the seed is set
29///   to hash(`old_seed` || `new_seed`), and the counter is reset to 0.
30///
31/// # Examples
32/// ```
33/// # use winter_crypto::{RandomCoin, DefaultRandomCoin, Hasher, hashers::Blake3_256};
34/// # use math::fields::f128::BaseElement;
35/// // initial elements for seeding the random coin
36/// let seed = &[
37///     BaseElement::new(1),
38///     BaseElement::new(2),
39///     BaseElement::new(3),
40///     BaseElement::new(4),
41/// ];
42///
43/// // instantiate a random coin using BLAKE3 as the hash function
44/// let mut coin = DefaultRandomCoin::<Blake3_256<BaseElement>>::new(seed);
45///
46/// // should draw different elements each time
47/// let e1 = coin.draw::<BaseElement>().unwrap();
48/// let e2 = coin.draw::<BaseElement>().unwrap();
49/// assert_ne!(e1, e2);
50///
51/// let e3 = coin.draw::<BaseElement>().unwrap();
52/// assert_ne!(e1, e3);
53/// assert_ne!(e2, e3);
54///
55/// // should draw same elements for the same seed
56/// let mut coin2 = DefaultRandomCoin::<Blake3_256<BaseElement>>::new(seed);
57/// let mut coin1 = DefaultRandomCoin::<Blake3_256<BaseElement>>::new(seed);
58/// let e1 = coin1.draw::<BaseElement>().unwrap();
59/// let e2 = coin2.draw::<BaseElement>().unwrap();
60/// assert_eq!(e1, e2);
61///
62/// // should draw different elements based on seed
63/// let mut coin1 = DefaultRandomCoin::<Blake3_256<BaseElement>>::new(seed);
64/// let seed = &[
65///     BaseElement::new(2),
66///     BaseElement::new(3),
67///     BaseElement::new(4),
68///     BaseElement::new(5),
69/// ];
70/// let mut coin2 = DefaultRandomCoin::<Blake3_256<BaseElement>>::new(seed);
71/// let e1 = coin1.draw::<BaseElement>().unwrap();
72/// let e2 = coin2.draw::<BaseElement>().unwrap();
73/// assert_ne!(e1, e2);
74/// ```
75pub struct DefaultRandomCoin<H: ElementHasher> {
76    seed: H::Digest,
77    counter: u64,
78}
79
80impl<H: ElementHasher> DefaultRandomCoin<H> {
81    /// Updates the state by incrementing the counter and returns hash(seed || counter)
82    fn next(&mut self) -> H::Digest {
83        self.counter += 1;
84        H::merge_with_int(self.seed, self.counter)
85    }
86}
87
88impl<B: StarkField, H: ElementHasher<BaseField = B>> RandomCoin for DefaultRandomCoin<H> {
89    type BaseField = B;
90    type Hasher = H;
91
92    // CONSTRUCTOR
93    // --------------------------------------------------------------------------------------------
94    /// Returns a new random coin instantiated with the provided `seed`.
95    fn new(seed: &[Self::BaseField]) -> Self {
96        let seed = H::hash_elements(seed);
97        Self { seed, counter: 0 }
98    }
99
100    // RESEEDING
101    // --------------------------------------------------------------------------------------------
102
103    /// Reseeds the coin with the specified data by setting the new seed to hash(`seed` || `data`).
104    ///
105    /// # Examples
106    /// ```
107    /// # use winter_crypto::{RandomCoin, DefaultRandomCoin, Hasher, hashers::Blake3_256};
108    /// # use math::fields::f128::BaseElement;
109    /// // initial elements for seeding the random coin
110    /// let seed = &[
111    ///     BaseElement::new(1),
112    ///     BaseElement::new(2),
113    ///     BaseElement::new(3),
114    ///     BaseElement::new(4),
115    /// ];
116    ///
117    /// let mut coin1 = DefaultRandomCoin::<Blake3_256<BaseElement>>::new(seed);
118    /// let mut coin2 = DefaultRandomCoin::<Blake3_256<BaseElement>>::new(seed);
119    ///
120    /// // should draw the same element form both coins
121    /// let e1 = coin1.draw::<BaseElement>().unwrap();
122    /// let e2 = coin2.draw::<BaseElement>().unwrap();
123    /// assert_eq!(e1, e2);
124    ///
125    /// // after reseeding should draw different elements
126    /// coin2.reseed(Blake3_256::<BaseElement>::hash(&[2, 3, 4, 5]));
127    /// let e1 = coin1.draw::<BaseElement>().unwrap();
128    /// let e2 = coin2.draw::<BaseElement>().unwrap();
129    /// assert_ne!(e1, e2);
130    /// ```
131    fn reseed(&mut self, data: H::Digest) {
132        self.seed = H::merge(&[self.seed, data]);
133        self.counter = 0;
134    }
135
136    // PUBLIC ACCESSORS
137    // --------------------------------------------------------------------------------------------
138
139    /// Computes hash(`seed` || `value`) and returns the number of leading zeros in the resulting
140    /// value if it is interpreted as an integer in big-endian byte order.
141    fn check_leading_zeros(&self, value: u64) -> u32 {
142        let new_seed = H::merge_with_int(self.seed, value);
143        let bytes = new_seed.as_bytes();
144        let seed_head = u64::from_le_bytes(bytes[..8].try_into().unwrap());
145        seed_head.trailing_zeros()
146    }
147
148    // DRAW METHODS
149    // --------------------------------------------------------------------------------------------
150
151    /// Returns the next pseudo-random field element.
152    ///
153    /// # Errors
154    /// Returns an error if a valid field element could not be generated after 1000 calls to the
155    /// PRNG.
156    fn draw<E: FieldElement>(&mut self) -> Result<E, RandomCoinError> {
157        for _ in 0..1000 {
158            // get the next pseudo-random value and take the first ELEMENT_BYTES from it
159            let value = self.next();
160            let bytes = &value.as_bytes()[..E::ELEMENT_BYTES];
161
162            // check if the bytes can be converted into a valid field element; if they can,
163            // return; otherwise try again
164            if let Some(element) = E::from_random_bytes(bytes) {
165                return Ok(element);
166            }
167        }
168
169        Err(RandomCoinError::FailedToDrawFieldElement(1000))
170    }
171
172    /// Returns a vector of integers selected from the range [0, domain_size) after reseeding
173    /// the PRNG with the specified `nonce` by setting the new seed to hash(`seed` || `nonce`).
174    ///
175    /// # Errors
176    /// Returns an error if the specified number of integers could not be generated after 1000
177    /// calls to the PRNG.
178    ///
179    /// # Panics
180    /// Panics if:
181    /// - `domain_size` is not a power of two.
182    /// - `num_values` is greater than or equal to `domain_size`.
183    ///
184    /// # Examples
185    /// ```
186    /// # use std::collections::HashSet;
187    /// # use winter_crypto::{RandomCoin, DefaultRandomCoin, Hasher, hashers::Blake3_256};
188    /// # use math::fields::f128::BaseElement;
189    /// // initial elements for seeding the random coin
190    /// let seed = &[
191    ///     BaseElement::new(1),
192    ///     BaseElement::new(2),
193    ///     BaseElement::new(3),
194    ///     BaseElement::new(4),
195    /// ];
196    ///
197    /// let mut coin = DefaultRandomCoin::<Blake3_256<BaseElement>>::new(seed);
198    ///
199    /// let num_values = 20;
200    /// let domain_size = 64;
201    /// let nonce = 0;
202    /// let values = coin.draw_integers(num_values, domain_size, nonce).unwrap();
203    ///
204    /// assert_eq!(num_values, values.len());
205    ///
206    /// for value in values {
207    ///     assert!(value < domain_size);
208    /// }
209    /// ```
210    fn draw_integers(
211        &mut self,
212        num_values: usize,
213        domain_size: usize,
214        nonce: u64,
215    ) -> Result<Vec<usize>, RandomCoinError> {
216        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
217        assert!(num_values < domain_size, "number of values must be smaller than domain size");
218
219        // reseed with nonce
220        self.seed = H::merge_with_int(self.seed, nonce);
221        self.counter = 0;
222
223        // determine how many bits are needed to represent valid values in the domain
224        let v_mask = (domain_size - 1) as u64;
225
226        // draw values from PRNG until we get as many unique values as specified by num_queries
227        let mut values = Vec::new();
228        for _ in 0..1000 {
229            // get the next pseudo-random value and read the first 8 bytes from it
230            let bytes: [u8; 8] = self.next().as_bytes()[..8].try_into().unwrap();
231
232            // convert to integer and limit the integer to the number of bits which can fit
233            // into the specified domain
234            let value = (u64::from_le_bytes(bytes) & v_mask) as usize;
235
236            values.push(value);
237            if values.len() == num_values {
238                break;
239            }
240        }
241
242        if values.len() < num_values {
243            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
244        }
245
246        Ok(values)
247    }
248}