tfhe/core_crypto/commons/
dispersion.rs

//! Module containing noise distribution primitives.
//!
//! When dealing with noise, we tend to use different representations for the same value. In
//! general, the noise is specified by the standard deviation of a Gaussian distribution, which
//! is of the form $\sigma = 2^p$, with $p$ typically a negative integer. Depending on the use
//! case though, we rely on different representations for this quantity:
//!
//! + $\sigma$ can be encoded in the [`StandardDev`] type.
//! + $p$ can be encoded in the [`LogStandardDev`] type.
//! + $\sigma^2$ can be encoded in the [`Variance`] type.
//!
//! In any of those cases, the corresponding type implements the [`DispersionParameter`] trait,
//! which makes it possible to use any of those representations generically when noise must be
//! defined.
15
16use serde::{Deserialize, Serialize};
17use tfhe_versionable::Versionize;
18
19use crate::core_crypto::backward_compatibility::commons::dispersion::{
20    StandardDevVersions, VarianceVersions,
21};
22
23/// A trait for types representing distribution parameters, for a given unsigned integer type.
24//  Warning:
25//  DispersionParameter type should ONLY wrap a single native type.
26//  As long as Variance wraps a native type (f64) it is ok to derive it from Copy instead of
27//  Clone because f64 is itself Copy and stored in register.
28pub trait DispersionParameter: Copy {
29    /// Return the standard deviation of the distribution, i.e. $\sigma = 2^p$.
30    fn get_standard_dev(&self) -> StandardDev;
31    /// Return the variance of the distribution, i.e. $\sigma^2 = 2^{2p}$.
32    fn get_variance(&self) -> Variance;
33    /// Return base 2 logarithm of the standard deviation of the distribution, i.e.
34    /// $\log\_2(\sigma)=p$
35    fn get_log_standard_dev(&self) -> LogStandardDev;
36    /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$.
37    fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev;
38
39    /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$.
40    fn get_modular_variance(&self, modulus: f64) -> ModularVariance;
41
42    /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$.
43    fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev;
44}
45
/// Dispersion parameter stored as the base-2 logarithm of the standard deviation,
/// $p = \log\_2(\sigma)$.
///
/// Converting to the other representations exponentiates `p`; see the
/// `DispersionParameter` implementation for this type.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, LogStandardDev};
/// let params = LogStandardDev::from_log_standard_dev(-25.);
/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
///     params.get_modular_standard_dev(2_f64.powi(32)).value,
///     2_f64.powf(32. - 25.)
/// );
/// assert_eq!(
///     params.get_modular_log_standard_dev(2_f64.powi(32)).value,
///     32. - 25.
/// );
/// assert_eq!(
///     params.get_modular_variance(2_f64.powi(32)).value,
///     2_f64.powf(32. - 25.).powi(2)
/// );
///
/// // A log standard deviation of 22 relative to a 2^32 modulus is 2^-10 in absolute terms.
/// let modular_params = LogStandardDev::from_modular_log_standard_dev(22., 32);
/// assert_eq!(modular_params.get_standard_dev().0, 2_f64.powf(-10.));
/// ```
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct LogStandardDev(pub f64);
75
/// Base-2 logarithm of a standard deviation expressed relative to a modulus.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct ModularLogStandardDev {
    /// $\log\_2$ of the standard deviation once it has been scaled by `modulus`.
    pub value: f64,
    /// The modulus `value` refers to (the modulus itself, not its logarithm).
    pub modulus: f64,
}
81
82impl LogStandardDev {
83    pub fn from_log_standard_dev(log_std: f64) -> Self {
84        Self(log_std)
85    }
86
87    pub fn from_modular_log_standard_dev(log_std: f64, log2_modulus: u32) -> Self {
88        Self(log_std - log2_modulus as f64)
89    }
90}
91
92impl DispersionParameter for LogStandardDev {
93    fn get_standard_dev(&self) -> StandardDev {
94        StandardDev(f64::powf(2., self.0))
95    }
96    fn get_variance(&self) -> Variance {
97        Variance(f64::powf(2., self.0 * 2.))
98    }
99    fn get_log_standard_dev(&self) -> Self {
100        Self(self.0)
101    }
102    fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
103        ModularStandardDev {
104            value: 2_f64.powf(self.0) * modulus,
105            modulus,
106        }
107    }
108    fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
109        let std_dev = 2_f64.powf(self.0) * modulus;
110
111        ModularVariance {
112            value: std_dev * std_dev,
113            modulus,
114        }
115    }
116    fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
117        ModularLogStandardDev {
118            value: modulus.log2() + self.0,
119            modulus,
120        }
121    }
122}
123
124/// A distribution parameter that uses the standard deviation as representation.
125///
126/// # Example:
127///
128/// ```rust
129/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, StandardDev};
130/// let params = StandardDev::from_standard_dev(2_f64.powf(-25.));
131/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
132/// assert_eq!(params.get_log_standard_dev().0, -25.);
133/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
134/// assert_eq!(
135///     params.get_modular_standard_dev(2_f64.powi(32)).value,
136///     2_f64.powf(32. - 25.)
137/// );
138/// assert_eq!(
139///     params.get_modular_log_standard_dev(2_f64.powi(32)).value,
140///     32. - 25.
141/// );
142/// assert_eq!(
143///     params.get_modular_variance(2_f64.powi(32)).value,
144///     2_f64.powf(32. - 25.).powi(2)
145/// );
146/// ```
147#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
148#[versionize(StandardDevVersions)]
149pub struct StandardDev(pub f64);
150
/// Standard deviation expressed relative to a modulus, i.e. $\sigma$ scaled by `modulus`.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct ModularStandardDev {
    /// The scaled standard deviation.
    pub value: f64,
    /// The modulus used for the scaling (the modulus itself, not its logarithm).
    pub modulus: f64,
}
156
157impl StandardDev {
158    pub fn from_standard_dev(std: f64) -> Self {
159        Self(std)
160    }
161
162    pub fn from_modular_standard_dev(std: f64, log2_modulus: u32) -> Self {
163        Self(std / 2_f64.powf(log2_modulus as f64))
164    }
165}
166
167impl DispersionParameter for StandardDev {
168    fn get_standard_dev(&self) -> Self {
169        Self(self.0)
170    }
171    fn get_variance(&self) -> Variance {
172        Variance(self.0.powi(2))
173    }
174    fn get_log_standard_dev(&self) -> LogStandardDev {
175        LogStandardDev(self.0.log2())
176    }
177    fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
178        ModularStandardDev {
179            value: self.0 * modulus,
180            modulus,
181        }
182    }
183    fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
184        let std_dev = self.0 * modulus;
185
186        ModularVariance {
187            value: std_dev * std_dev,
188            modulus,
189        }
190    }
191    fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
192        ModularLogStandardDev {
193            value: modulus.log2() + self.0.log2(),
194            modulus,
195        }
196    }
197}
198
199/// A distribution parameter that uses the variance as representation
200///
201/// # Example:
202///
203/// ```rust
204/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, Variance};
205/// let params = Variance::from_variance(2_f64.powi(-50));
206/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
207/// assert_eq!(params.get_log_standard_dev().0, -25.);
208/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
209/// assert_eq!(
210///     params.get_modular_standard_dev(2_f64.powi(32)).value,
211///     2_f64.powf(32. - 25.)
212/// );
213/// assert_eq!(
214///     params.get_modular_log_standard_dev(2_f64.powi(32)).value,
215///     32. - 25.
216/// );
217/// assert_eq!(
218///     params.get_modular_variance(2_f64.powi(32)).value,
219///     2_f64.powf(32. - 25.).powi(2)
220/// );
221/// ```
222#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
223#[versionize(VarianceVersions)]
224pub struct Variance(pub f64);
225
/// Variance expressed relative to a modulus, i.e. $\sigma^2$ scaled by `modulus` squared.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct ModularVariance {
    /// The scaled variance.
    pub value: f64,
    /// The modulus used for the scaling (the modulus itself, not its logarithm).
    pub modulus: f64,
}
231
232impl Variance {
233    pub fn from_variance(var: f64) -> Self {
234        Self(var)
235    }
236
237    pub fn from_modular_variance(var: f64, modulus: f64) -> Self {
238        Self(var / (modulus * modulus))
239    }
240}
241
242impl DispersionParameter for Variance {
243    fn get_standard_dev(&self) -> StandardDev {
244        StandardDev(self.0.sqrt())
245    }
246    fn get_variance(&self) -> Self {
247        Self(self.0)
248    }
249    fn get_log_standard_dev(&self) -> LogStandardDev {
250        LogStandardDev(self.0.sqrt().log2())
251    }
252    fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
253        ModularStandardDev {
254            value: self.0.sqrt() * modulus,
255            modulus,
256        }
257    }
258    fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
259        ModularVariance {
260            value: self.0 * modulus * modulus,
261            modulus,
262        }
263    }
264    fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
265        ModularLogStandardDev {
266            value: modulus.log2() + self.0.sqrt().log2(),
267            modulus,
268        }
269    }
270}