tfhe/core_crypto/commons/dispersion.rs
1//! Module containing noise distribution primitives.
2//!
//! When dealing with noise, we tend to use different representations for the same value. In
//! general, the noise is specified by the standard deviation of a Gaussian distribution, which
//! is of the form $\sigma = 2^p$, with $p$ a negative real number (not necessarily an integer).
//! Depending on the use case though, we rely on different representations for this quantity:
7//!
8//! + $\sigma$ can be encoded in the [`StandardDev`] type.
9//! + $p$ can be encoded in the [`LogStandardDev`] type.
10//! + $\sigma^2$ can be encoded in the [`Variance`] type.
11//!
//! In any of those cases, the corresponding type implements the [`DispersionParameter`] trait,
//! which makes it possible to use any of those representations generically when noise must be
//! defined.
15
16use serde::{Deserialize, Serialize};
17use tfhe_versionable::Versionize;
18
19use crate::core_crypto::backward_compatibility::commons::dispersion::{
20 StandardDevVersions, VarianceVersions,
21};
22
23/// A trait for types representing distribution parameters, for a given unsigned integer type.
24// Warning:
25// DispersionParameter type should ONLY wrap a single native type.
26// As long as Variance wraps a native type (f64) it is ok to derive it from Copy instead of
27// Clone because f64 is itself Copy and stored in register.
28pub trait DispersionParameter: Copy {
29 /// Return the standard deviation of the distribution, i.e. $\sigma = 2^p$.
30 fn get_standard_dev(&self) -> StandardDev;
31 /// Return the variance of the distribution, i.e. $\sigma^2 = 2^{2p}$.
32 fn get_variance(&self) -> Variance;
33 /// Return base 2 logarithm of the standard deviation of the distribution, i.e.
34 /// $\log\_2(\sigma)=p$
35 fn get_log_standard_dev(&self) -> LogStandardDev;
36 /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$.
37 fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev;
38
39 /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$.
40 fn get_modular_variance(&self, modulus: f64) -> ModularVariance;
41
42 /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$.
43 fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev;
44}
45
/// A distribution parameter that uses the base-2 logarithm of the standard deviation as
/// representation.
///
/// # Examples
///
/// ```rust
/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, LogStandardDev};
/// let params = LogStandardDev::from_log_standard_dev(-25.);
/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
///     params.get_modular_standard_dev(2_f64.powi(32)).value,
///     2_f64.powf(32. - 25.)
/// );
/// assert_eq!(
///     params.get_modular_log_standard_dev(2_f64.powi(32)).value,
///     32. - 25.
/// );
/// assert_eq!(
///     params.get_modular_variance(2_f64.powi(32)).value,
///     2_f64.powf(32. - 25.).powi(2)
/// );
///
/// let modular_params = LogStandardDev::from_modular_log_standard_dev(22., 32);
/// assert_eq!(modular_params.get_standard_dev().0, 2_f64.powf(-10.));
/// ```
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct LogStandardDev(pub f64);

/// The base-2 logarithm of a standard deviation scaled by a modulus.
///
/// `value` holds $\log\_2(\sigma \cdot m)$ and `modulus` holds $m$ itself (not its logarithm), as
/// returned by [`DispersionParameter::get_modular_log_standard_dev`].
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularLogStandardDev {
    /// $\log\_2(\sigma \cdot m) = \log\_2(m) + \log\_2(\sigma)$.
    pub value: f64,
    /// The modulus $m$ (not its logarithm) the value is expressed relative to.
    pub modulus: f64,
}
81
82impl LogStandardDev {
83 pub fn from_log_standard_dev(log_std: f64) -> Self {
84 Self(log_std)
85 }
86
87 pub fn from_modular_log_standard_dev(log_std: f64, log2_modulus: u32) -> Self {
88 Self(log_std - log2_modulus as f64)
89 }
90}
91
92impl DispersionParameter for LogStandardDev {
93 fn get_standard_dev(&self) -> StandardDev {
94 StandardDev(f64::powf(2., self.0))
95 }
96 fn get_variance(&self) -> Variance {
97 Variance(f64::powf(2., self.0 * 2.))
98 }
99 fn get_log_standard_dev(&self) -> Self {
100 Self(self.0)
101 }
102 fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
103 ModularStandardDev {
104 value: 2_f64.powf(self.0) * modulus,
105 modulus,
106 }
107 }
108 fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
109 let std_dev = 2_f64.powf(self.0) * modulus;
110
111 ModularVariance {
112 value: std_dev * std_dev,
113 modulus,
114 }
115 }
116 fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
117 ModularLogStandardDev {
118 value: modulus.log2() + self.0,
119 modulus,
120 }
121 }
122}
123
124/// A distribution parameter that uses the standard deviation as representation.
125///
126/// # Example:
127///
128/// ```rust
129/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, StandardDev};
130/// let params = StandardDev::from_standard_dev(2_f64.powf(-25.));
131/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
132/// assert_eq!(params.get_log_standard_dev().0, -25.);
133/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
134/// assert_eq!(
135/// params.get_modular_standard_dev(2_f64.powi(32)).value,
136/// 2_f64.powf(32. - 25.)
137/// );
138/// assert_eq!(
139/// params.get_modular_log_standard_dev(2_f64.powi(32)).value,
140/// 32. - 25.
141/// );
142/// assert_eq!(
143/// params.get_modular_variance(2_f64.powi(32)).value,
144/// 2_f64.powf(32. - 25.).powi(2)
145/// );
146/// ```
147#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
148#[versionize(StandardDevVersions)]
149pub struct StandardDev(pub f64);
150
151#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
152pub struct ModularStandardDev {
153 pub value: f64,
154 pub modulus: f64,
155}
156
157impl StandardDev {
158 pub fn from_standard_dev(std: f64) -> Self {
159 Self(std)
160 }
161
162 pub fn from_modular_standard_dev(std: f64, log2_modulus: u32) -> Self {
163 Self(std / 2_f64.powf(log2_modulus as f64))
164 }
165}
166
167impl DispersionParameter for StandardDev {
168 fn get_standard_dev(&self) -> Self {
169 Self(self.0)
170 }
171 fn get_variance(&self) -> Variance {
172 Variance(self.0.powi(2))
173 }
174 fn get_log_standard_dev(&self) -> LogStandardDev {
175 LogStandardDev(self.0.log2())
176 }
177 fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
178 ModularStandardDev {
179 value: self.0 * modulus,
180 modulus,
181 }
182 }
183 fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
184 let std_dev = self.0 * modulus;
185
186 ModularVariance {
187 value: std_dev * std_dev,
188 modulus,
189 }
190 }
191 fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
192 ModularLogStandardDev {
193 value: modulus.log2() + self.0.log2(),
194 modulus,
195 }
196 }
197}
198
199/// A distribution parameter that uses the variance as representation
200///
201/// # Example:
202///
203/// ```rust
204/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, Variance};
205/// let params = Variance::from_variance(2_f64.powi(-50));
206/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
207/// assert_eq!(params.get_log_standard_dev().0, -25.);
208/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
209/// assert_eq!(
210/// params.get_modular_standard_dev(2_f64.powi(32)).value,
211/// 2_f64.powf(32. - 25.)
212/// );
213/// assert_eq!(
214/// params.get_modular_log_standard_dev(2_f64.powi(32)).value,
215/// 32. - 25.
216/// );
217/// assert_eq!(
218/// params.get_modular_variance(2_f64.powi(32)).value,
219/// 2_f64.powf(32. - 25.).powi(2)
220/// );
221/// ```
222#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
223#[versionize(VarianceVersions)]
224pub struct Variance(pub f64);
225
226#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
227pub struct ModularVariance {
228 pub value: f64,
229 pub modulus: f64,
230}
231
232impl Variance {
233 pub fn from_variance(var: f64) -> Self {
234 Self(var)
235 }
236
237 pub fn from_modular_variance(var: f64, modulus: f64) -> Self {
238 Self(var / (modulus * modulus))
239 }
240}
241
242impl DispersionParameter for Variance {
243 fn get_standard_dev(&self) -> StandardDev {
244 StandardDev(self.0.sqrt())
245 }
246 fn get_variance(&self) -> Self {
247 Self(self.0)
248 }
249 fn get_log_standard_dev(&self) -> LogStandardDev {
250 LogStandardDev(self.0.sqrt().log2())
251 }
252 fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
253 ModularStandardDev {
254 value: self.0.sqrt() * modulus,
255 modulus,
256 }
257 }
258 fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
259 ModularVariance {
260 value: self.0 * modulus * modulus,
261 modulus,
262 }
263 }
264 fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
265 ModularLogStandardDev {
266 value: modulus.log2() + self.0.sqrt().log2(),
267 modulus,
268 }
269 }
270}