tfhe/shortint/server_key/
modulus_switch_noise_reduction.rs1use super::{PBSConformanceParams, PbsTypeConformanceParams};
2use crate::conformance::ParameterSetConformant;
3use crate::core_crypto::algorithms::*;
4use crate::core_crypto::commons::math::random::{CompressionSeed, DynamicDistribution, Uniform};
5use crate::core_crypto::commons::parameters::{
6 LweDimension, NoiseEstimationMeasureBound, PlaintextCount, RSigmaFactor,
7};
8use crate::core_crypto::commons::traits::*;
9use crate::core_crypto::entities::*;
10use crate::core_crypto::prelude::modulus_switch_noise_reduction::improve_lwe_ciphertext_modulus_switch_noise_for_binary_key;
11use crate::core_crypto::prelude::{
12 CiphertextModulus as CoreCiphertextModulus, CiphertextModulusLog, Variance,
13};
14use crate::shortint::backward_compatibility::server_key::modulus_switch_noise_reduction::*;
15use crate::shortint::engine::ShortintEngine;
16use crate::shortint::parameters::ModulusSwitchNoiseReductionParams;
17
18use serde::{Deserialize, Serialize};
19use std::fmt::Debug;
20use tfhe_versionable::Versionize;
21
/// Expected parameters against which a modulus switch noise reduction key is checked by
/// `ParameterSetConformant::is_conformant`.
///
/// Obtained from `PBSConformanceParams` through `TryFrom`, which only succeeds for the classic
/// PBS type when noise reduction parameters are present.
#[derive(Copy, Clone)]
pub struct ModulusSwitchNoiseReductionKeyConformanceParams {
    /// Expected zero count, bound, r-sigma factor and input variance of the key.
    pub modulus_switch_noise_reduction_params: ModulusSwitchNoiseReductionParams,
    /// Expected LWE dimension of the encryptions of zero stored in the key
    /// (taken from the PBS input LWE dimension by the `TryFrom` conversion).
    pub lwe_dimension: LweDimension,
}
27
28impl TryFrom<&PBSConformanceParams> for ModulusSwitchNoiseReductionKeyConformanceParams {
29 type Error = ();
30
31 fn try_from(value: &PBSConformanceParams) -> Result<Self, ()> {
32 match &value.pbs_type {
33 PbsTypeConformanceParams::Classic {
34 modulus_switch_noise_reduction,
35 } => modulus_switch_noise_reduction.map_or(Err(()), |modulus_switch_noise_reduction| {
36 Ok(Self {
37 modulus_switch_noise_reduction_params: modulus_switch_noise_reduction,
38 lwe_dimension: value.in_lwe_dimension,
39 })
40 }),
41 PbsTypeConformanceParams::MultiBit { .. } => Err(()),
42 }
43 }
44}
45
/// Key material used to improve the noise of an LWE ciphertext before a modulus switch.
///
/// It stores a list of LWE encryptions of zero (generated by `new`) together with the
/// parameters forwarded to `improve_lwe_ciphertext_modulus_switch_noise_for_binary_key`.
///
/// NOTE(review): the derived serde impls serialize fields in declaration order, so reordering
/// the fields would presumably break previously serialized keys.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(ModulusSwitchNoiseReductionKeyVersions)]
pub struct ModulusSwitchNoiseReductionKey<InputScalar>
where
    InputScalar: UnsignedInteger,
{
    /// LWE encryptions of zero, one per `modulus_switch_zeros_count` in the parameters.
    pub modulus_switch_zeros: LweCiphertextListOwned<InputScalar>,
    /// Noise estimation bound forwarded to the noise improvement routine.
    pub ms_bound: NoiseEstimationMeasureBound,
    /// r-sigma factor forwarded to the noise improvement routine.
    pub ms_r_sigma_factor: RSigmaFactor,
    /// Input variance forwarded to the noise improvement routine.
    pub ms_input_variance: Variance,
}
66
67impl<InputScalar> ParameterSetConformant for ModulusSwitchNoiseReductionKey<InputScalar>
68where
69 InputScalar: UnsignedInteger,
70{
71 type ParameterSet = ModulusSwitchNoiseReductionKeyConformanceParams;
72
73 fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
74 let Self {
75 modulus_switch_zeros,
76 ms_bound,
77 ms_r_sigma_factor,
78 ms_input_variance,
79 } = self;
80
81 let ModulusSwitchNoiseReductionKeyConformanceParams {
82 modulus_switch_noise_reduction_params,
83 lwe_dimension,
84 } = parameter_set;
85
86 let ModulusSwitchNoiseReductionParams {
87 modulus_switch_zeros_count: param_modulus_switch_zeros_count,
88 ms_bound: param_ms_bound,
89 ms_r_sigma_factor: param_ms_r_sigma_factor,
90 ms_input_variance: param_ms_input_variance,
91 } = modulus_switch_noise_reduction_params;
92
93 ms_bound == param_ms_bound
94 && ms_r_sigma_factor == param_ms_r_sigma_factor
95 && ms_input_variance == param_ms_input_variance
96 && modulus_switch_zeros.entity_count() == param_modulus_switch_zeros_count.0
97 && modulus_switch_zeros.lwe_size().to_lwe_dimension() == *lwe_dimension
98 }
99}
100
101impl<InputScalar> ModulusSwitchNoiseReductionKey<InputScalar>
102where
103 InputScalar: UnsignedInteger,
104{
105 pub fn improve_modulus_switch_noise<Cont>(
106 &self,
107 input: &mut LweCiphertext<Cont>,
108 log_modulus: CiphertextModulusLog,
109 ) where
110 Cont: ContainerMut<Element = InputScalar>,
111 {
112 improve_lwe_ciphertext_modulus_switch_noise_for_binary_key(
113 input,
114 &self.modulus_switch_zeros,
115 self.ms_r_sigma_factor,
116 self.ms_bound,
117 self.ms_input_variance,
118 log_modulus,
119 );
120 }
121}
122
/// Compressed (seeded) form of `ModulusSwitchNoiseReductionKey`.
///
/// The encryptions of zero are stored as a seeded LWE ciphertext list; `decompress` expands
/// them into a regular key and copies the remaining parameters unchanged.
///
/// NOTE(review): the derived serde impls serialize fields in declaration order, so reordering
/// the fields would presumably break previously serialized keys.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedModulusSwitchNoiseReductionKeyVersions)]
pub struct CompressedModulusSwitchNoiseReductionKey<InputScalar>
where
    InputScalar: UnsignedInteger,
{
    /// Seeded LWE encryptions of zero, one per `modulus_switch_zeros_count` in the parameters.
    pub modulus_switch_zeros: SeededLweCiphertextListOwned<InputScalar>,
    /// Noise estimation bound, copied as-is on decompression.
    pub ms_bound: NoiseEstimationMeasureBound,
    /// r-sigma factor, copied as-is on decompression.
    pub ms_r_sigma_factor: RSigmaFactor,
    /// Input variance, copied as-is on decompression.
    pub ms_input_variance: Variance,
}
134
135impl<InputScalar> ParameterSetConformant for CompressedModulusSwitchNoiseReductionKey<InputScalar>
136where
137 InputScalar: UnsignedInteger,
138{
139 type ParameterSet = ModulusSwitchNoiseReductionKeyConformanceParams;
140
141 fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
142 let Self {
143 modulus_switch_zeros,
144 ms_bound,
145 ms_r_sigma_factor,
146 ms_input_variance,
147 } = self;
148
149 let ModulusSwitchNoiseReductionKeyConformanceParams {
150 modulus_switch_noise_reduction_params,
151 lwe_dimension,
152 } = parameter_set;
153
154 let ModulusSwitchNoiseReductionParams {
155 modulus_switch_zeros_count: param_modulus_switch_zeros_count,
156 ms_bound: param_ms_bound,
157 ms_r_sigma_factor: param_ms_r_sigma_factor,
158 ms_input_variance: param_ms_input_variance,
159 } = modulus_switch_noise_reduction_params;
160
161 ms_bound == param_ms_bound
162 && ms_r_sigma_factor == param_ms_r_sigma_factor
163 && ms_input_variance == param_ms_input_variance
164 && modulus_switch_zeros.entity_count() == param_modulus_switch_zeros_count.0
165 && modulus_switch_zeros.lwe_size().to_lwe_dimension() == *lwe_dimension
166 }
167}
168
169impl<InputScalar> ModulusSwitchNoiseReductionKey<InputScalar>
170where
171 InputScalar: Encryptable<Uniform, DynamicDistribution<InputScalar>>,
172{
173 pub fn new<KeyCont: Container<Element = InputScalar> + Sync>(
174 modulus_switch_noise_reduction_params: ModulusSwitchNoiseReductionParams,
175 secret_key: &LweSecretKey<KeyCont>,
176 engine: &mut ShortintEngine,
177 ciphertext_modulus: CoreCiphertextModulus<InputScalar>,
178 lwe_noise_distribution: DynamicDistribution<InputScalar>,
179 ) -> Self {
180 let ModulusSwitchNoiseReductionParams {
181 modulus_switch_zeros_count: count,
182 ms_bound,
183 ms_r_sigma_factor,
184 ms_input_variance,
185 } = modulus_switch_noise_reduction_params;
186
187 let lwe_size = secret_key.lwe_dimension().to_lwe_size();
188
189 let mut modulus_switch_zeros =
190 LweCiphertextList::new(InputScalar::ZERO, lwe_size, count, ciphertext_modulus);
191
192 let plaintext_list = PlaintextList::new(InputScalar::ZERO, PlaintextCount(count.0));
193
194 #[cfg(any(not(feature = "__wasm_api"), feature = "parallel-wasm-api"))]
196 par_encrypt_lwe_ciphertext_list(
197 secret_key,
198 &mut modulus_switch_zeros,
199 &plaintext_list,
200 lwe_noise_distribution,
201 &mut engine.encryption_generator,
202 );
203
204 #[cfg(all(feature = "__wasm_api", not(feature = "parallel-wasm-api")))]
206 encrypt_lwe_ciphertext_list(
207 secret_key,
208 &mut modulus_switch_zeros,
209 &plaintext_list,
210 lwe_noise_distribution,
211 &mut engine.encryption_generator,
212 );
213
214 Self {
215 modulus_switch_zeros,
216 ms_bound,
217 ms_r_sigma_factor,
218 ms_input_variance,
219 }
220 }
221}
222
223impl<InputScalar> CompressedModulusSwitchNoiseReductionKey<InputScalar>
224where
225 InputScalar: Encryptable<Uniform, DynamicDistribution<InputScalar>>,
226{
227 pub fn new<KeyCont: Container<Element = InputScalar> + Sync>(
228 modulus_switch_noise_reduction_params: ModulusSwitchNoiseReductionParams,
229 secret_key: &LweSecretKey<KeyCont>,
230 engine: &mut ShortintEngine,
231 ciphertext_modulus: CoreCiphertextModulus<InputScalar>,
232 lwe_noise_distribution: DynamicDistribution<InputScalar>,
233 compression_seed: CompressionSeed,
234 ) -> Self {
235 let ModulusSwitchNoiseReductionParams {
236 modulus_switch_zeros_count: count,
237 ms_bound,
238 ms_r_sigma_factor,
239 ms_input_variance,
240 } = modulus_switch_noise_reduction_params;
241
242 let lwe_size = secret_key.lwe_dimension().to_lwe_size();
243
244 let mut modulus_switch_zeros = SeededLweCiphertextList::new(
245 InputScalar::ZERO,
246 lwe_size,
247 count,
248 compression_seed,
249 ciphertext_modulus,
250 );
251
252 let plaintext_list = PlaintextList::new(InputScalar::ZERO, PlaintextCount(count.0));
253
254 #[cfg(any(not(feature = "__wasm_api"), feature = "parallel-wasm-api"))]
256 par_encrypt_seeded_lwe_ciphertext_list(
257 secret_key,
258 &mut modulus_switch_zeros,
259 &plaintext_list,
260 lwe_noise_distribution,
261 &mut engine.seeder,
262 );
263
264 #[cfg(all(feature = "__wasm_api", not(feature = "parallel-wasm-api")))]
266 encrypt_seeded_lwe_ciphertext_list(
267 secret_key,
268 &mut modulus_switch_zeros,
269 &plaintext_list,
270 lwe_noise_distribution,
271 &mut engine.seeder,
272 );
273
274 Self {
275 modulus_switch_zeros,
276 ms_bound,
277 ms_r_sigma_factor,
278 ms_input_variance,
279 }
280 }
281}
282
283impl<InputScalar> CompressedModulusSwitchNoiseReductionKey<InputScalar>
284where
285 InputScalar: UnsignedTorus,
286{
287 pub fn decompress(&self) -> ModulusSwitchNoiseReductionKey<InputScalar> {
288 ModulusSwitchNoiseReductionKey {
289 modulus_switch_zeros: self
290 .modulus_switch_zeros
291 .as_view()
292 .decompress_into_lwe_ciphertext_list(),
293 ms_bound: self.ms_bound,
294 ms_r_sigma_factor: self.ms_r_sigma_factor,
295 ms_input_variance: self.ms_input_variance,
296 }
297 }
298}