tfhe/shortint/atomic_pattern/standard.rs

1use serde::{Deserialize, Serialize};
2use tfhe_csprng::seeders::Seed;
3use tfhe_versionable::Versionize;
4
5use super::{
6    apply_ms_blind_rotate, apply_programmable_bootstrap, AtomicPattern, AtomicPatternKind,
7    AtomicPatternMut,
8};
9use crate::conformance::ParameterSetConformant;
10use crate::core_crypto::prelude::{
11    allocate_and_generate_new_lwe_keyswitch_key, extract_lwe_sample_from_glwe_ciphertext,
12    keyswitch_lwe_ciphertext, LweCiphertext, LweCiphertextOwned, LweDimension,
13    LweKeyswitchKeyOwned, MonomialDegree, MsDecompressionType,
14};
15use crate::shortint::backward_compatibility::atomic_pattern::StandardAtomicPatternServerKeyVersions;
16use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, Degree, NoiseLevel};
17use crate::shortint::client_key::atomic_pattern::StandardAtomicPatternClientKey;
18use crate::shortint::engine::ShortintEngine;
19use crate::shortint::oprf::generate_pseudo_random_from_pbs;
20use crate::shortint::server_key::{
21    decompress_and_apply_lookup_table, switch_modulus_and_compress, LookupTableOwned,
22    LookupTableSize, ManyLookupTableOwned, ShortintBootstrappingKey,
23};
24use crate::shortint::{
25    Ciphertext, CiphertextModulus, EncryptionKeyChoice, PBSOrder, PBSParameters,
26};
27
28/// The definition of the server key elements used in the [`Standard`](AtomicPatternKind::Standard)
29/// atomic pattern
30#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
31#[versionize(StandardAtomicPatternServerKeyVersions)]
32pub struct StandardAtomicPatternServerKey {
33    pub key_switching_key: LweKeyswitchKeyOwned<u64>,
34    pub bootstrapping_key: ShortintBootstrappingKey<u64>,
35    pub pbs_order: PBSOrder,
36}
37
38impl ParameterSetConformant for StandardAtomicPatternServerKey {
39    type ParameterSet = PBSParameters;
40
41    fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
42        let Self {
43            key_switching_key,
44            bootstrapping_key,
45            pbs_order,
46        } = self;
47
48        let pbs_conformance_params = parameter_set.into();
49
50        let pbs_key_ok = bootstrapping_key.is_conformant(&pbs_conformance_params);
51
52        let ks_conformance_params = parameter_set.into();
53
54        let ks_key_ok = key_switching_key.is_conformant(&ks_conformance_params);
55
56        let params_pbs_order: PBSOrder = parameter_set.encryption_key_choice().into();
57        let pbs_order_ok = *pbs_order == params_pbs_order;
58
59        pbs_key_ok && ks_key_ok && pbs_order_ok
60    }
61}
62
63impl StandardAtomicPatternServerKey {
64    pub fn new(cks: &StandardAtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
65        let params = &cks.parameters;
66
67        let in_key = &cks.small_lwe_secret_key();
68
69        let out_key = &cks.glwe_secret_key;
70
71        let bootstrapping_key_base = engine.new_bootstrapping_key(*params, in_key, out_key);
72
73        // Creation of the key switching key
74        let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
75            &cks.large_lwe_secret_key(),
76            &cks.small_lwe_secret_key(),
77            params.ks_base_log(),
78            params.ks_level(),
79            params.lwe_noise_distribution(),
80            params.ciphertext_modulus(),
81            &mut engine.encryption_generator,
82        );
83
84        Self::from_raw_parts(
85            key_switching_key,
86            bootstrapping_key_base,
87            params.encryption_key_choice().into(),
88        )
89    }
90
91    pub fn from_raw_parts(
92        key_switching_key: LweKeyswitchKeyOwned<u64>,
93        bootstrapping_key: ShortintBootstrappingKey<u64>,
94        pbs_order: PBSOrder,
95    ) -> Self {
96        assert_eq!(
97            key_switching_key.input_key_lwe_dimension(),
98            bootstrapping_key.output_lwe_dimension(),
99            "Mismatch between the input LweKeyswitchKey LweDimension ({:?}) \
100            and the ShortintBootstrappingKey output LweDimension ({:?})",
101            key_switching_key.input_key_lwe_dimension(),
102            bootstrapping_key.output_lwe_dimension()
103        );
104
105        assert_eq!(
106            key_switching_key.output_key_lwe_dimension(),
107            bootstrapping_key.input_lwe_dimension(),
108            "Mismatch between the output LweKeyswitchKey LweDimension ({:?}) \
109            and the ShortintBootstrappingKey input LweDimension ({:?})",
110            key_switching_key.output_key_lwe_dimension(),
111            bootstrapping_key.input_lwe_dimension()
112        );
113
114        Self {
115            key_switching_key,
116            bootstrapping_key,
117            pbs_order,
118        }
119    }
120
121    pub fn intermediate_lwe_dimension(&self) -> LweDimension {
122        match self.pbs_order {
123            PBSOrder::KeyswitchBootstrap => {
124                self.ciphertext_lwe_dimension_for_key(EncryptionKeyChoice::Small)
125            }
126            PBSOrder::BootstrapKeyswitch => {
127                self.ciphertext_lwe_dimension_for_key(EncryptionKeyChoice::Big)
128            }
129        }
130    }
131}
132
133impl AtomicPattern for StandardAtomicPatternServerKey {
134    fn ciphertext_lwe_dimension_for_key(&self, key_choice: EncryptionKeyChoice) -> LweDimension {
135        match key_choice {
136            EncryptionKeyChoice::Big => self.bootstrapping_key.output_lwe_dimension(),
137            EncryptionKeyChoice::Small => self.bootstrapping_key.input_lwe_dimension(),
138        }
139    }
140
141    fn ciphertext_modulus_for_key(&self, _key_choice: EncryptionKeyChoice) -> CiphertextModulus {
142        // Both keys use the same modulus
143        self.key_switching_key.ciphertext_modulus()
144    }
145
146    fn ciphertext_decompression_method(&self) -> MsDecompressionType {
147        match &self.bootstrapping_key {
148            ShortintBootstrappingKey::Classic { .. } => MsDecompressionType::ClassicPbs,
149            ShortintBootstrappingKey::MultiBit { fourier_bsk, .. } => {
150                MsDecompressionType::MultiBitPbs(fourier_bsk.grouping_factor())
151            }
152        }
153    }
154
155    fn apply_lookup_table_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) {
156        ShortintEngine::with_thread_local_mut(|engine| {
157            let (mut ciphertext_buffer, buffers) =
158                engine.get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
159
160            match self.pbs_order {
161                PBSOrder::KeyswitchBootstrap => {
162                    keyswitch_lwe_ciphertext(
163                        &self.key_switching_key,
164                        &ct.ct,
165                        &mut ciphertext_buffer,
166                    );
167
168                    apply_programmable_bootstrap(
169                        &self.bootstrapping_key,
170                        &ciphertext_buffer,
171                        &mut ct.ct,
172                        &acc.acc,
173                        buffers,
174                    );
175                }
176                PBSOrder::BootstrapKeyswitch => {
177                    apply_programmable_bootstrap(
178                        &self.bootstrapping_key,
179                        &ct.ct,
180                        &mut ciphertext_buffer,
181                        &acc.acc,
182                        buffers,
183                    );
184
185                    keyswitch_lwe_ciphertext(
186                        &self.key_switching_key,
187                        &ciphertext_buffer,
188                        &mut ct.ct,
189                    );
190                }
191            }
192        });
193    }
194
195    fn apply_many_lookup_table(
196        &self,
197        ct: &Ciphertext,
198        acc: &ManyLookupTableOwned,
199    ) -> Vec<Ciphertext> {
200        match self.pbs_order {
201            PBSOrder::KeyswitchBootstrap => self.keyswitch_programmable_bootstrap_many_lut(ct, acc),
202            PBSOrder::BootstrapKeyswitch => self.programmable_bootstrap_keyswitch_many_lut(ct, acc),
203        }
204    }
205
206    fn lookup_table_size(&self) -> LookupTableSize {
207        LookupTableSize::new(
208            self.bootstrapping_key.glwe_size(),
209            self.bootstrapping_key.polynomial_size(),
210        )
211    }
212
213    fn kind(&self) -> AtomicPatternKind {
214        AtomicPatternKind::Standard(self.pbs_order)
215    }
216
217    fn deterministic_execution(&self) -> bool {
218        self.bootstrapping_key.deterministic_pbs_execution()
219    }
220
221    fn generate_oblivious_pseudo_random(
222        &self,
223        seed: Seed,
224        random_bits_count: u64,
225        full_bits_count: u64,
226    ) -> (LweCiphertextOwned<u64>, Degree) {
227        let (ct, degree) = generate_pseudo_random_from_pbs(
228            &self.bootstrapping_key,
229            seed,
230            random_bits_count,
231            full_bits_count,
232            self.ciphertext_modulus(),
233        );
234
235        match self.pbs_order {
236            PBSOrder::KeyswitchBootstrap => (ct, degree),
237            PBSOrder::BootstrapKeyswitch => {
238                let mut ct_ksed = LweCiphertext::new(
239                    0,
240                    self.bootstrapping_key.input_lwe_dimension().to_lwe_size(),
241                    self.ciphertext_modulus(),
242                );
243
244                keyswitch_lwe_ciphertext(&self.key_switching_key, &ct, &mut ct_ksed);
245
246                (ct_ksed, degree)
247            }
248        }
249    }
250
251    fn switch_modulus_and_compress(&self, ct: &Ciphertext) -> CompressedModulusSwitchedCiphertext {
252        let compressed_modulus_switched_lwe_ciphertext =
253            ShortintEngine::with_thread_local_mut(|engine| {
254                let (mut ciphertext_buffer, _) = engine
255                    .get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
256
257                let input_ct = match self.pbs_order {
258                    PBSOrder::KeyswitchBootstrap => {
259                        keyswitch_lwe_ciphertext(
260                            &self.key_switching_key,
261                            &ct.ct,
262                            &mut ciphertext_buffer,
263                        );
264                        ciphertext_buffer.as_view()
265                    }
266                    PBSOrder::BootstrapKeyswitch => ct.ct.as_view(),
267                };
268
269                switch_modulus_and_compress(input_ct, &self.bootstrapping_key)
270            });
271
272        CompressedModulusSwitchedCiphertext {
273            compressed_modulus_switched_lwe_ciphertext,
274            degree: ct.degree,
275            message_modulus: ct.message_modulus,
276            carry_modulus: ct.carry_modulus,
277            atomic_pattern: ct.atomic_pattern,
278        }
279    }
280
281    fn decompress_and_apply_lookup_table(
282        &self,
283        compressed_ct: &CompressedModulusSwitchedCiphertext,
284        lut: &LookupTableOwned,
285    ) -> Ciphertext {
286        let mut output = LweCiphertext::new(
287            0,
288            self.ciphertext_lwe_dimension().to_lwe_size(),
289            self.ciphertext_modulus(),
290        );
291
292        ShortintEngine::with_thread_local_mut(|engine| {
293            let (mut ciphertext_buffer, buffers) =
294                engine.get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
295
296            match self.pbs_order {
297                PBSOrder::KeyswitchBootstrap => {
298                    decompress_and_apply_lookup_table(
299                        compressed_ct,
300                        &lut.acc,
301                        &self.bootstrapping_key,
302                        &mut output.as_mut_view(),
303                        buffers,
304                    );
305                }
306                PBSOrder::BootstrapKeyswitch => {
307                    decompress_and_apply_lookup_table(
308                        compressed_ct,
309                        &lut.acc,
310                        &self.bootstrapping_key,
311                        &mut ciphertext_buffer,
312                        buffers,
313                    );
314
315                    keyswitch_lwe_ciphertext(
316                        &self.key_switching_key,
317                        &ciphertext_buffer,
318                        &mut output,
319                    );
320                }
321            }
322        });
323
324        Ciphertext::new(
325            output,
326            lut.degree,
327            NoiseLevel::NOMINAL,
328            compressed_ct.message_modulus,
329            compressed_ct.carry_modulus,
330            compressed_ct.atomic_pattern,
331        )
332    }
333}
334
335impl AtomicPatternMut for StandardAtomicPatternServerKey {
336    fn set_deterministic_execution(&mut self, new_deterministic_execution: bool) {
337        self.bootstrapping_key
338            .set_deterministic_pbs_execution(new_deterministic_execution)
339    }
340}
341
342impl StandardAtomicPatternServerKey {
343    pub(crate) fn keyswitch_programmable_bootstrap_many_lut(
344        &self,
345        ct: &Ciphertext,
346        lut: &ManyLookupTableOwned,
347    ) -> Vec<Ciphertext> {
348        let mut acc = lut.acc.clone();
349
350        ShortintEngine::with_thread_local_mut(|engine| {
351            let (mut ciphertext_buffer, buffers) =
352                engine.get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
353
354            // Compute a key switch
355            keyswitch_lwe_ciphertext(&self.key_switching_key, &ct.ct, &mut ciphertext_buffer);
356
357            apply_ms_blind_rotate(
358                &self.bootstrapping_key,
359                &ciphertext_buffer.as_view(),
360                &mut acc,
361                buffers,
362            );
363        });
364
365        // The accumulator has been rotated, we can now proceed with the various sample extractions
366        let function_count = lut.function_count();
367        let mut outputs = Vec::with_capacity(function_count);
368
369        for (fn_idx, output_degree) in lut.per_function_output_degree.iter().enumerate() {
370            let monomial_degree = MonomialDegree(fn_idx * lut.sample_extraction_stride);
371            let mut output_shortint_ct = ct.clone();
372
373            extract_lwe_sample_from_glwe_ciphertext(
374                &acc,
375                &mut output_shortint_ct.ct,
376                monomial_degree,
377            );
378
379            output_shortint_ct.degree = *output_degree;
380            output_shortint_ct.set_noise_level_to_nominal();
381            outputs.push(output_shortint_ct);
382        }
383
384        outputs
385    }
386
387    pub(crate) fn programmable_bootstrap_keyswitch_many_lut(
388        &self,
389        ct: &Ciphertext,
390        lut: &ManyLookupTableOwned,
391    ) -> Vec<Ciphertext> {
392        let mut acc = lut.acc.clone();
393
394        ShortintEngine::with_thread_local_mut(|engine| {
395            // Compute the programmable bootstrapping with fixed test polynomial
396            let buffers = engine.get_computation_buffers();
397
398            apply_ms_blind_rotate(&self.bootstrapping_key, &ct.ct, &mut acc, buffers);
399        });
400
401        // The accumulator has been rotated, we can now proceed with the various sample extractions
402        let function_count = lut.function_count();
403        let mut outputs = Vec::with_capacity(function_count);
404
405        let mut tmp_lwe_ciphertext = LweCiphertext::new(
406            0u64,
407            self.key_switching_key
408                .input_key_lwe_dimension()
409                .to_lwe_size(),
410            self.key_switching_key.ciphertext_modulus(),
411        );
412
413        for (fn_idx, output_degree) in lut.per_function_output_degree.iter().enumerate() {
414            let monomial_degree = MonomialDegree(fn_idx * lut.sample_extraction_stride);
415            extract_lwe_sample_from_glwe_ciphertext(&acc, &mut tmp_lwe_ciphertext, monomial_degree);
416
417            let mut output_shortint_ct = ct.clone();
418
419            // Compute a key switch
420            keyswitch_lwe_ciphertext(
421                &self.key_switching_key,
422                &tmp_lwe_ciphertext,
423                &mut output_shortint_ct.ct,
424            );
425
426            output_shortint_ct.degree = *output_degree;
427            output_shortint_ct.set_noise_level_to_nominal();
428            outputs.push(output_shortint_ct);
429        }
430
431        outputs
432    }
433}