// tfhe/shortint/server_key/modulus_switch_noise_reduction.rs

1use super::{PBSConformanceParams, PbsTypeConformanceParams};
2use crate::conformance::ParameterSetConformant;
3use crate::core_crypto::algorithms::*;
4use crate::core_crypto::commons::math::random::{CompressionSeed, DynamicDistribution, Uniform};
5use crate::core_crypto::commons::parameters::{
6    LweDimension, NoiseEstimationMeasureBound, PlaintextCount, RSigmaFactor,
7};
8use crate::core_crypto::commons::traits::*;
9use crate::core_crypto::entities::*;
10use crate::core_crypto::prelude::modulus_switch_noise_reduction::improve_lwe_ciphertext_modulus_switch_noise_for_binary_key;
11use crate::core_crypto::prelude::{
12    CiphertextModulus as CoreCiphertextModulus, CiphertextModulusLog, Variance,
13};
14use crate::shortint::backward_compatibility::server_key::modulus_switch_noise_reduction::*;
15use crate::shortint::engine::ShortintEngine;
16use crate::shortint::parameters::ModulusSwitchNoiseReductionParams;
17
18use serde::{Deserialize, Serialize};
19use std::fmt::Debug;
20use tfhe_versionable::Versionize;
21
22#[derive(Copy, Clone)]
23pub struct ModulusSwitchNoiseReductionKeyConformanceParams {
24    pub modulus_switch_noise_reduction_params: ModulusSwitchNoiseReductionParams,
25    pub lwe_dimension: LweDimension,
26}
27
28impl TryFrom<&PBSConformanceParams> for ModulusSwitchNoiseReductionKeyConformanceParams {
29    type Error = ();
30
31    fn try_from(value: &PBSConformanceParams) -> Result<Self, ()> {
32        match &value.pbs_type {
33            PbsTypeConformanceParams::Classic {
34                modulus_switch_noise_reduction,
35            } => modulus_switch_noise_reduction.map_or(Err(()), |modulus_switch_noise_reduction| {
36                Ok(Self {
37                    modulus_switch_noise_reduction_params: modulus_switch_noise_reduction,
38                    lwe_dimension: value.in_lwe_dimension,
39                })
40            }),
41            PbsTypeConformanceParams::MultiBit { .. } => Err(()),
42        }
43    }
44}
45
46/// Before applying a modulus switch to a ciphertext, it's possible to modify it (but not the value
47/// it encrypts) in a way that decreases the noise added by the subsequent modulus switch.
48///
49/// A [ModulusSwitchNoiseReductionKey] is needed to perform this modification.
50/// [improve_modulus_switch_noise](ModulusSwitchNoiseReductionKey::improve_modulus_switch_noise)
51/// method can then be called on the target ciphertext.
52///
53/// The lower level primitive is
54/// [improve_lwe_ciphertext_modulus_switch_noise_for_binary_key](crate::core_crypto::algorithms::modulus_switch_noise_reduction::improve_lwe_ciphertext_modulus_switch_noise_for_binary_key)
55#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
56#[versionize(ModulusSwitchNoiseReductionKeyVersions)]
57pub struct ModulusSwitchNoiseReductionKey<InputScalar>
58where
59    InputScalar: UnsignedInteger,
60{
61    pub modulus_switch_zeros: LweCiphertextListOwned<InputScalar>,
62    pub ms_bound: NoiseEstimationMeasureBound,
63    pub ms_r_sigma_factor: RSigmaFactor,
64    pub ms_input_variance: Variance,
65}
66
67impl<InputScalar> ParameterSetConformant for ModulusSwitchNoiseReductionKey<InputScalar>
68where
69    InputScalar: UnsignedInteger,
70{
71    type ParameterSet = ModulusSwitchNoiseReductionKeyConformanceParams;
72
73    fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
74        let Self {
75            modulus_switch_zeros,
76            ms_bound,
77            ms_r_sigma_factor,
78            ms_input_variance,
79        } = self;
80
81        let ModulusSwitchNoiseReductionKeyConformanceParams {
82            modulus_switch_noise_reduction_params,
83            lwe_dimension,
84        } = parameter_set;
85
86        let ModulusSwitchNoiseReductionParams {
87            modulus_switch_zeros_count: param_modulus_switch_zeros_count,
88            ms_bound: param_ms_bound,
89            ms_r_sigma_factor: param_ms_r_sigma_factor,
90            ms_input_variance: param_ms_input_variance,
91        } = modulus_switch_noise_reduction_params;
92
93        ms_bound == param_ms_bound
94            && ms_r_sigma_factor == param_ms_r_sigma_factor
95            && ms_input_variance == param_ms_input_variance
96            && modulus_switch_zeros.entity_count() == param_modulus_switch_zeros_count.0
97            && modulus_switch_zeros.lwe_size().to_lwe_dimension() == *lwe_dimension
98    }
99}
100
101impl<InputScalar> ModulusSwitchNoiseReductionKey<InputScalar>
102where
103    InputScalar: UnsignedInteger,
104{
105    pub fn improve_modulus_switch_noise<Cont>(
106        &self,
107        input: &mut LweCiphertext<Cont>,
108        log_modulus: CiphertextModulusLog,
109    ) where
110        Cont: ContainerMut<Element = InputScalar>,
111    {
112        improve_lwe_ciphertext_modulus_switch_noise_for_binary_key(
113            input,
114            &self.modulus_switch_zeros,
115            self.ms_r_sigma_factor,
116            self.ms_bound,
117            self.ms_input_variance,
118            log_modulus,
119        );
120    }
121}
122
123#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
124#[versionize(CompressedModulusSwitchNoiseReductionKeyVersions)]
125pub struct CompressedModulusSwitchNoiseReductionKey<InputScalar>
126where
127    InputScalar: UnsignedInteger,
128{
129    pub modulus_switch_zeros: SeededLweCiphertextListOwned<InputScalar>,
130    pub ms_bound: NoiseEstimationMeasureBound,
131    pub ms_r_sigma_factor: RSigmaFactor,
132    pub ms_input_variance: Variance,
133}
134
135impl<InputScalar> ParameterSetConformant for CompressedModulusSwitchNoiseReductionKey<InputScalar>
136where
137    InputScalar: UnsignedInteger,
138{
139    type ParameterSet = ModulusSwitchNoiseReductionKeyConformanceParams;
140
141    fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
142        let Self {
143            modulus_switch_zeros,
144            ms_bound,
145            ms_r_sigma_factor,
146            ms_input_variance,
147        } = self;
148
149        let ModulusSwitchNoiseReductionKeyConformanceParams {
150            modulus_switch_noise_reduction_params,
151            lwe_dimension,
152        } = parameter_set;
153
154        let ModulusSwitchNoiseReductionParams {
155            modulus_switch_zeros_count: param_modulus_switch_zeros_count,
156            ms_bound: param_ms_bound,
157            ms_r_sigma_factor: param_ms_r_sigma_factor,
158            ms_input_variance: param_ms_input_variance,
159        } = modulus_switch_noise_reduction_params;
160
161        ms_bound == param_ms_bound
162            && ms_r_sigma_factor == param_ms_r_sigma_factor
163            && ms_input_variance == param_ms_input_variance
164            && modulus_switch_zeros.entity_count() == param_modulus_switch_zeros_count.0
165            && modulus_switch_zeros.lwe_size().to_lwe_dimension() == *lwe_dimension
166    }
167}
168
169impl<InputScalar> ModulusSwitchNoiseReductionKey<InputScalar>
170where
171    InputScalar: Encryptable<Uniform, DynamicDistribution<InputScalar>>,
172{
173    pub fn new<KeyCont: Container<Element = InputScalar> + Sync>(
174        modulus_switch_noise_reduction_params: ModulusSwitchNoiseReductionParams,
175        secret_key: &LweSecretKey<KeyCont>,
176        engine: &mut ShortintEngine,
177        ciphertext_modulus: CoreCiphertextModulus<InputScalar>,
178        lwe_noise_distribution: DynamicDistribution<InputScalar>,
179    ) -> Self {
180        let ModulusSwitchNoiseReductionParams {
181            modulus_switch_zeros_count: count,
182            ms_bound,
183            ms_r_sigma_factor,
184            ms_input_variance,
185        } = modulus_switch_noise_reduction_params;
186
187        let lwe_size = secret_key.lwe_dimension().to_lwe_size();
188
189        let mut modulus_switch_zeros =
190            LweCiphertextList::new(InputScalar::ZERO, lwe_size, count, ciphertext_modulus);
191
192        let plaintext_list = PlaintextList::new(InputScalar::ZERO, PlaintextCount(count.0));
193
194        // Parallelism allowed
195        #[cfg(any(not(feature = "__wasm_api"), feature = "parallel-wasm-api"))]
196        par_encrypt_lwe_ciphertext_list(
197            secret_key,
198            &mut modulus_switch_zeros,
199            &plaintext_list,
200            lwe_noise_distribution,
201            &mut engine.encryption_generator,
202        );
203
204        // No parallelism allowed
205        #[cfg(all(feature = "__wasm_api", not(feature = "parallel-wasm-api")))]
206        encrypt_lwe_ciphertext_list(
207            secret_key,
208            &mut modulus_switch_zeros,
209            &plaintext_list,
210            lwe_noise_distribution,
211            &mut engine.encryption_generator,
212        );
213
214        Self {
215            modulus_switch_zeros,
216            ms_bound,
217            ms_r_sigma_factor,
218            ms_input_variance,
219        }
220    }
221}
222
223impl<InputScalar> CompressedModulusSwitchNoiseReductionKey<InputScalar>
224where
225    InputScalar: Encryptable<Uniform, DynamicDistribution<InputScalar>>,
226{
227    pub fn new<KeyCont: Container<Element = InputScalar> + Sync>(
228        modulus_switch_noise_reduction_params: ModulusSwitchNoiseReductionParams,
229        secret_key: &LweSecretKey<KeyCont>,
230        engine: &mut ShortintEngine,
231        ciphertext_modulus: CoreCiphertextModulus<InputScalar>,
232        lwe_noise_distribution: DynamicDistribution<InputScalar>,
233        compression_seed: CompressionSeed,
234    ) -> Self {
235        let ModulusSwitchNoiseReductionParams {
236            modulus_switch_zeros_count: count,
237            ms_bound,
238            ms_r_sigma_factor,
239            ms_input_variance,
240        } = modulus_switch_noise_reduction_params;
241
242        let lwe_size = secret_key.lwe_dimension().to_lwe_size();
243
244        let mut modulus_switch_zeros = SeededLweCiphertextList::new(
245            InputScalar::ZERO,
246            lwe_size,
247            count,
248            compression_seed,
249            ciphertext_modulus,
250        );
251
252        let plaintext_list = PlaintextList::new(InputScalar::ZERO, PlaintextCount(count.0));
253
254        // Parallelism allowed
255        #[cfg(any(not(feature = "__wasm_api"), feature = "parallel-wasm-api"))]
256        par_encrypt_seeded_lwe_ciphertext_list(
257            secret_key,
258            &mut modulus_switch_zeros,
259            &plaintext_list,
260            lwe_noise_distribution,
261            &mut engine.seeder,
262        );
263
264        // No parallelism allowed
265        #[cfg(all(feature = "__wasm_api", not(feature = "parallel-wasm-api")))]
266        encrypt_seeded_lwe_ciphertext_list(
267            secret_key,
268            &mut modulus_switch_zeros,
269            &plaintext_list,
270            lwe_noise_distribution,
271            &mut engine.seeder,
272        );
273
274        Self {
275            modulus_switch_zeros,
276            ms_bound,
277            ms_r_sigma_factor,
278            ms_input_variance,
279        }
280    }
281}
282
283impl<InputScalar> CompressedModulusSwitchNoiseReductionKey<InputScalar>
284where
285    InputScalar: UnsignedTorus,
286{
287    pub fn decompress(&self) -> ModulusSwitchNoiseReductionKey<InputScalar> {
288        ModulusSwitchNoiseReductionKey {
289            modulus_switch_zeros: self
290                .modulus_switch_zeros
291                .as_view()
292                .decompress_into_lwe_ciphertext_list(),
293            ms_bound: self.ms_bound,
294            ms_r_sigma_factor: self.ms_r_sigma_factor,
295            ms_input_variance: self.ms_input_variance,
296        }
297    }
298}