Skip to main content

tfhe/core_crypto/commons/math/random/
mod.rs

1//! A module containing random sampling functions.
2//!
3//! This module contains a [`RandomGenerator`] type, which exposes methods to sample numeric values
4//! randomly according to a given distribution, for instance:
5//!
6//! + [`RandomGenerator::random_uniform`] samples a random unsigned integer with uniform probability
7//!   over the set of representable values.
//! + [`RandomGenerator::random_gaussian`] samples a random float using a gaussian
//!   distribution.
10//!
//! The implementation relies on the [`RandomGenerable`] trait, which gives a type the ability to
//! be randomly generated according to a given distribution. The module contains multiple
//! implementations of this trait, for different distributions. Note, though, that you should use
//! the various methods exposed by [`RandomGenerator`] rather than calling the
//! [`RandomGenerable`] methods directly.
16use crate::core_crypto::backward_compatibility::commons::math::random::DynamicDistributionVersions;
17use crate::core_crypto::commons::dispersion::{DispersionParameter, StandardDev, Variance};
18use crate::core_crypto::commons::numeric::{FloatingPoint, UnsignedInteger};
19use std::ops::Bound;
20
21use crate::core_crypto::prelude::{CastInto, Numeric};
22pub use gaussian::*;
23pub use generator::*;
24pub use t_uniform::*;
25pub use tfhe_csprng::generators::DefaultRandomGenerator;
26use tfhe_versionable::Versionize;
27pub use uniform::*;
28pub use uniform_binary::*;
29pub use uniform_ternary::*;
30
31#[cfg(test)]
32pub(crate) mod tests;
33
34mod gaussian;
35mod generator;
36mod t_uniform;
37mod uniform;
38mod uniform_binary;
39mod uniform_ternary;
40
41/// A trait giving a type the ability to be randomly generated according to a given distribution.
42pub trait RandomGenerable<D: Distribution>
43where
44    Self: Sized,
45{
46    // This is required as e.g. Gaussian can generate pairs of Torus elements and we can't use a
47    // pair of elements as custom modulus
48    type CustomModulus: Copy;
49
50    /// Generate a value from `distribution`.
51    fn generate_one<G: ByteRandomGenerator>(
52        generator: &mut RandomGenerator<G>,
53        distribution: D,
54    ) -> Self;
55
56    /// Generate a value from `distribution` modulo the given `custom_modulus`.
57    ///
58    /// This `custom_modulus` must be able to represent all possible values from `distribution`.
59    ///
60    /// Implementations are allowed to panic if the given `custom_modulus` cannot represent all
61    /// possible values from `distribution` as the effective distribution would differ from the
62    /// desired distribution.
63    fn generate_one_custom_modulus<G: ByteRandomGenerator>(
64        generator: &mut RandomGenerator<G>,
65        distribution: D,
66        custom_modulus: Self::CustomModulus,
67    ) -> Self {
68        let _ = generator;
69        let _ = distribution;
70        let _ = custom_modulus;
71        todo!("This distribution does not support custom modulus generation at this time.");
72    }
73
74    fn fill_slice<G: ByteRandomGenerator>(
75        generator: &mut RandomGenerator<G>,
76        distribution: D,
77        slice: &mut [Self],
78    ) {
79        for s in slice.iter_mut() {
80            *s = Self::generate_one(generator, distribution);
81        }
82    }
83
84    fn fill_slice_custom_mod<G: ByteRandomGenerator>(
85        generator: &mut RandomGenerator<G>,
86        distribution: D,
87        slice: &mut [Self],
88        custom_modulus: Self::CustomModulus,
89    ) {
90        for s in slice.iter_mut() {
91            *s = Self::generate_one_custom_modulus(generator, distribution, custom_modulus);
92        }
93    }
94
95    /// Return the probability to successfully generate a sample from the given distribution for the
96    /// type the trait is implemented on.
97    ///
98    /// If the generation can never fail it should return 1.0, otherwise it returns a value in
99    /// ]0; 1[.
100    ///
101    /// If None is passed for modulus, then the native modulus of the type (e.g. $2^{64}$ for u64)
102    /// or no modulus for floating points values is used.
103    ///
104    /// Otherwise the given modulus is interpreted as being the one used for a call to
105    /// [`RandomGenerable::generate_one_custom_modulus`].
106    fn single_sample_success_probability(
107        distribution: D,
108        modulus: Option<Self::CustomModulus>,
109    ) -> f64;
110
111    /// Return how many bytes coming from a CSPRNG are required to generate one sample even if that
112    /// generation can fail.
113    ///
114    /// If None is passed for modulus, then the native modulus of the type (e.g. $2^{64}$ for u64)
115    /// or no modulus for floating points values is used.
116    ///
117    /// Otherwise the given modulus is interpreted as being the one used for a call to
118    /// [`RandomGenerable::generate_one_custom_modulus`].
119    fn single_sample_required_random_byte_count(
120        distribution: D,
121        modulus: Option<Self::CustomModulus>,
122    ) -> usize;
123}
124
125/// A marker trait for types representing distributions.
126pub trait Distribution: seal::Sealed + Copy {}
127mod seal {
128    use crate::core_crypto::commons::numeric::{FloatingPoint, UnsignedInteger};
129
130    pub trait Sealed {}
131    impl Sealed for super::Uniform {}
132    impl Sealed for super::UniformBinary {}
133    impl Sealed for super::UniformTernary {}
134    impl<T: FloatingPoint> Sealed for super::Gaussian<T> {}
135    impl<T: UnsignedInteger> Sealed for super::TUniform<T> {}
136    impl<T: UnsignedInteger> Sealed for super::DynamicDistribution<T> {}
137}
138impl Distribution for Uniform {}
139impl Distribution for UniformBinary {}
140impl Distribution for UniformTernary {}
141impl<T: FloatingPoint> Distribution for Gaussian<T> {}
142impl<T: UnsignedInteger> Distribution for TUniform<T> {}
143
144pub trait BoundedDistribution<T>: Distribution {
145    fn low_bound(&self) -> Bound<T>;
146    fn high_bound(&self) -> Bound<T>;
147
148    fn contains(self, value: T) -> bool
149    where
150        T: Numeric,
151    {
152        {
153            match self.low_bound() {
154                Bound::Included(inclusive_low) => {
155                    if value < inclusive_low {
156                        return false;
157                    }
158                }
159                Bound::Excluded(exclusive_low) => {
160                    if value <= exclusive_low {
161                        return false;
162                    }
163                }
164                Bound::Unbounded => {}
165            }
166        }
167
168        {
169            match self.high_bound() {
170                Bound::Included(inclusive_high) => {
171                    if value > inclusive_high {
172                        return false;
173                    }
174                }
175                Bound::Excluded(exclusive_high) => {
176                    if value >= exclusive_high {
177                        return false;
178                    }
179                }
180                Bound::Unbounded => {}
181            }
182        }
183
184        true
185    }
186}
187
188impl<T> BoundedDistribution<T::Signed> for TUniform<T>
189where
190    T: UnsignedInteger,
191{
192    fn low_bound(&self) -> Bound<T::Signed> {
193        Bound::Included(self.min_value_inclusive())
194    }
195
196    fn high_bound(&self) -> Bound<T::Signed> {
197        Bound::Included(self.max_value_inclusive())
198    }
199}
200
201impl<T> BoundedDistribution<T::Signed> for DynamicDistribution<T>
202where
203    T: UnsignedInteger,
204{
205    fn low_bound(&self) -> Bound<T::Signed> {
206        match self {
207            Self::Gaussian(_) => Bound::Unbounded,
208            Self::TUniform(tu) => tu.low_bound(),
209        }
210    }
211
212    fn high_bound(&self) -> Bound<T::Signed> {
213        match self {
214            Self::Gaussian(_) => Bound::Unbounded,
215            Self::TUniform(tu) => tu.high_bound(),
216        }
217    }
218}
219
220#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Versionize)]
221#[versionize(DynamicDistributionVersions)]
222pub enum DynamicDistribution<T: UnsignedInteger> {
223    Gaussian(Gaussian<f64>),
224    TUniform(TUniform<T>),
225}
226
227impl<T: UnsignedInteger> DynamicDistribution<T> {
228    pub const fn new_gaussian_from_std_dev(std: StandardDev) -> Self {
229        Self::Gaussian(Gaussian::from_standard_dev(std, 0.0))
230    }
231
232    pub fn new_gaussian(dispersion: impl DispersionParameter) -> Self {
233        Self::Gaussian(Gaussian::from_dispersion_parameter(dispersion, 0.0))
234    }
235
236    #[track_caller]
237    pub const fn new_t_uniform(bound_log2: u32) -> Self {
238        Self::TUniform(TUniform::new(bound_log2))
239    }
240
241    #[track_caller]
242    pub const fn try_new_t_uniform(bound_log2: u32) -> Result<Self, &'static str> {
243        match TUniform::try_new(bound_log2) {
244            Ok(ok) => Ok(Self::TUniform(ok)),
245            Err(e) => Err(e),
246        }
247    }
248
249    #[track_caller]
250    pub const fn gaussian_std_dev(&self) -> StandardDev {
251        match self {
252            Self::Gaussian(gaussian) => StandardDev(gaussian.std),
253            Self::TUniform(_) => {
254                panic!("Tried to get gaussian variance from a non gaussian distribution")
255            }
256        }
257    }
258
259    #[track_caller]
260    pub fn gaussian_variance(&self) -> Variance {
261        match self {
262            Self::Gaussian(gaussian) => StandardDev::from_standard_dev(gaussian.std).get_variance(),
263            Self::TUniform(_) => {
264                panic!("Tried to get gaussian variance from a non gaussian distribution")
265            }
266        }
267    }
268}
269
270impl DynamicDistribution<u32> {
271    pub const fn to_u64_distribution(self) -> DynamicDistribution<u64> {
272        // Depending on how the Scalar type is used, converting it from u32 to u64
273        // might affect the underlying distribution in subtle ways. When adding support for
274        // new distributions, make sure that the result is still correct.
275        match self {
276            Self::Gaussian(gaussian) => DynamicDistribution::Gaussian(gaussian),
277            Self::TUniform(tuniform) => {
278                // Ok because an u32 bound is always also a valid u64 bound
279                DynamicDistribution::new_t_uniform(tuniform.bound_log2())
280            }
281        }
282    }
283}
284
285impl<T: UnsignedInteger> std::fmt::Display for DynamicDistribution<T> {
286    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        match self {
288            Self::Gaussian(Gaussian { std, mean }) => {
289                write!(f, "Gaussian(ยต={mean},sigma={std})")
290            }
291            Self::TUniform(t_uniform) => {
292                write!(f, "TUniform({})", t_uniform.bound_log2())
293            }
294        }
295    }
296}
297
298impl<T: UnsignedInteger> Distribution for DynamicDistribution<T> {}
299
300impl<
301        T: UnsignedInteger
302            + RandomGenerable<Gaussian<f64>, CustomModulus = T>
303            + RandomGenerable<TUniform<T>, CustomModulus = T>,
304    > RandomGenerable<DynamicDistribution<T>> for T
305{
306    type CustomModulus = Self;
307
308    fn generate_one<G: ByteRandomGenerator>(
309        generator: &mut RandomGenerator<G>,
310        distribution: DynamicDistribution<T>,
311    ) -> Self {
312        match distribution {
313            DynamicDistribution::Gaussian(gaussian) => Self::generate_one(generator, gaussian),
314            DynamicDistribution::TUniform(t_uniform) => Self::generate_one(generator, t_uniform),
315        }
316    }
317
318    fn generate_one_custom_modulus<G: ByteRandomGenerator>(
319        generator: &mut RandomGenerator<G>,
320        distribution: DynamicDistribution<T>,
321        custom_modulus: Self::CustomModulus,
322    ) -> Self {
323        match distribution {
324            DynamicDistribution::Gaussian(gaussian) => {
325                Self::generate_one_custom_modulus(generator, gaussian, custom_modulus)
326            }
327            DynamicDistribution::TUniform(t_uniform) => {
328                Self::generate_one_custom_modulus(generator, t_uniform, custom_modulus)
329            }
330        }
331    }
332
333    fn single_sample_success_probability(
334        distribution: DynamicDistribution<T>,
335        modulus: Option<Self::CustomModulus>,
336    ) -> f64 {
337        match distribution {
338            DynamicDistribution::Gaussian(gaussian) => {
339                <Self as RandomGenerable<Gaussian<f64>>>::single_sample_success_probability(
340                    gaussian, modulus,
341                )
342            }
343            DynamicDistribution::TUniform(t_uniform) => {
344                <Self as RandomGenerable<TUniform<T>>>::single_sample_success_probability(
345                    t_uniform, modulus,
346                )
347            }
348        }
349    }
350
351    fn single_sample_required_random_byte_count(
352        distribution: DynamicDistribution<T>,
353        modulus: Option<Self::CustomModulus>,
354    ) -> usize {
355        match distribution {
356            DynamicDistribution::Gaussian(gaussian) => {
357                <Self as RandomGenerable<Gaussian<f64>>>::single_sample_required_random_byte_count(
358                    gaussian, modulus,
359                )
360            }
361            DynamicDistribution::TUniform(t_uniform) => {
362                <Self as RandomGenerable<TUniform<T>>>::single_sample_required_random_byte_count(
363                    t_uniform, modulus,
364                )
365            }
366        }
367    }
368}