1use serde::{Deserialize, Serialize};
2use tfhe_csprng::seeders::Seed;
3use tfhe_versionable::Versionize;
4
5use super::{
6 apply_ms_blind_rotate, apply_programmable_bootstrap, AtomicPattern, AtomicPatternKind,
7 AtomicPatternMut,
8};
9use crate::conformance::ParameterSetConformant;
10use crate::core_crypto::prelude::{
11 allocate_and_generate_new_lwe_keyswitch_key, extract_lwe_sample_from_glwe_ciphertext,
12 keyswitch_lwe_ciphertext, LweCiphertext, LweCiphertextOwned, LweDimension,
13 LweKeyswitchKeyOwned, MonomialDegree, MsDecompressionType,
14};
15use crate::shortint::backward_compatibility::atomic_pattern::StandardAtomicPatternServerKeyVersions;
16use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, Degree, NoiseLevel};
17use crate::shortint::client_key::atomic_pattern::StandardAtomicPatternClientKey;
18use crate::shortint::engine::ShortintEngine;
19use crate::shortint::oprf::generate_pseudo_random_from_pbs;
20use crate::shortint::server_key::{
21 decompress_and_apply_lookup_table, switch_modulus_and_compress, LookupTableOwned,
22 LookupTableSize, ManyLookupTableOwned, ShortintBootstrappingKey,
23};
24use crate::shortint::{
25 Ciphertext, CiphertextModulus, EncryptionKeyChoice, PBSOrder, PBSParameters,
26};
27
/// Server key for the "standard" shortint atomic pattern: one LWE keyswitch and one
/// programmable bootstrap (PBS), executed in the order selected by `pbs_order`.
///
/// The field order below is part of the serialized/versioned layout; do not reorder fields.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(StandardAtomicPatternServerKeyVersions)]
pub struct StandardAtomicPatternServerKey {
    /// Keyswitching key whose input dimension equals the bootstrapping key's output dimension
    /// and whose output dimension equals the bootstrapping key's input dimension
    /// (checked by `from_raw_parts`).
    pub key_switching_key: LweKeyswitchKeyOwned<u64>,
    /// Bootstrapping key used for every PBS of this pattern (classic or multi-bit).
    pub bootstrapping_key: ShortintBootstrappingKey<u64>,
    /// Whether a PBS is preceded by the keyswitch (`KeyswitchBootstrap`) or followed by it
    /// (`BootstrapKeyswitch`).
    pub pbs_order: PBSOrder,
}
37
38impl ParameterSetConformant for StandardAtomicPatternServerKey {
39 type ParameterSet = PBSParameters;
40
41 fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
42 let Self {
43 key_switching_key,
44 bootstrapping_key,
45 pbs_order,
46 } = self;
47
48 let pbs_conformance_params = parameter_set.into();
49
50 let pbs_key_ok = bootstrapping_key.is_conformant(&pbs_conformance_params);
51
52 let ks_conformance_params = parameter_set.into();
53
54 let ks_key_ok = key_switching_key.is_conformant(&ks_conformance_params);
55
56 let params_pbs_order: PBSOrder = parameter_set.encryption_key_choice().into();
57 let pbs_order_ok = *pbs_order == params_pbs_order;
58
59 pbs_key_ok && ks_key_ok && pbs_order_ok
60 }
61}
62
63impl StandardAtomicPatternServerKey {
64 pub fn new(cks: &StandardAtomicPatternClientKey, engine: &mut ShortintEngine) -> Self {
65 let params = &cks.parameters;
66
67 let in_key = &cks.small_lwe_secret_key();
68
69 let out_key = &cks.glwe_secret_key;
70
71 let bootstrapping_key_base = engine.new_bootstrapping_key(*params, in_key, out_key);
72
73 let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
75 &cks.large_lwe_secret_key(),
76 &cks.small_lwe_secret_key(),
77 params.ks_base_log(),
78 params.ks_level(),
79 params.lwe_noise_distribution(),
80 params.ciphertext_modulus(),
81 &mut engine.encryption_generator,
82 );
83
84 Self::from_raw_parts(
85 key_switching_key,
86 bootstrapping_key_base,
87 params.encryption_key_choice().into(),
88 )
89 }
90
91 pub fn from_raw_parts(
92 key_switching_key: LweKeyswitchKeyOwned<u64>,
93 bootstrapping_key: ShortintBootstrappingKey<u64>,
94 pbs_order: PBSOrder,
95 ) -> Self {
96 assert_eq!(
97 key_switching_key.input_key_lwe_dimension(),
98 bootstrapping_key.output_lwe_dimension(),
99 "Mismatch between the input LweKeyswitchKey LweDimension ({:?}) \
100 and the ShortintBootstrappingKey output LweDimension ({:?})",
101 key_switching_key.input_key_lwe_dimension(),
102 bootstrapping_key.output_lwe_dimension()
103 );
104
105 assert_eq!(
106 key_switching_key.output_key_lwe_dimension(),
107 bootstrapping_key.input_lwe_dimension(),
108 "Mismatch between the output LweKeyswitchKey LweDimension ({:?}) \
109 and the ShortintBootstrappingKey input LweDimension ({:?})",
110 key_switching_key.output_key_lwe_dimension(),
111 bootstrapping_key.input_lwe_dimension()
112 );
113
114 Self {
115 key_switching_key,
116 bootstrapping_key,
117 pbs_order,
118 }
119 }
120
121 pub fn intermediate_lwe_dimension(&self) -> LweDimension {
122 match self.pbs_order {
123 PBSOrder::KeyswitchBootstrap => {
124 self.ciphertext_lwe_dimension_for_key(EncryptionKeyChoice::Small)
125 }
126 PBSOrder::BootstrapKeyswitch => {
127 self.ciphertext_lwe_dimension_for_key(EncryptionKeyChoice::Big)
128 }
129 }
130 }
131}
132
133impl AtomicPattern for StandardAtomicPatternServerKey {
134 fn ciphertext_lwe_dimension_for_key(&self, key_choice: EncryptionKeyChoice) -> LweDimension {
135 match key_choice {
136 EncryptionKeyChoice::Big => self.bootstrapping_key.output_lwe_dimension(),
137 EncryptionKeyChoice::Small => self.bootstrapping_key.input_lwe_dimension(),
138 }
139 }
140
141 fn ciphertext_modulus_for_key(&self, _key_choice: EncryptionKeyChoice) -> CiphertextModulus {
142 self.key_switching_key.ciphertext_modulus()
144 }
145
146 fn ciphertext_decompression_method(&self) -> MsDecompressionType {
147 match &self.bootstrapping_key {
148 ShortintBootstrappingKey::Classic { .. } => MsDecompressionType::ClassicPbs,
149 ShortintBootstrappingKey::MultiBit { fourier_bsk, .. } => {
150 MsDecompressionType::MultiBitPbs(fourier_bsk.grouping_factor())
151 }
152 }
153 }
154
155 fn apply_lookup_table_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) {
156 ShortintEngine::with_thread_local_mut(|engine| {
157 let (mut ciphertext_buffer, buffers) =
158 engine.get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
159
160 match self.pbs_order {
161 PBSOrder::KeyswitchBootstrap => {
162 keyswitch_lwe_ciphertext(
163 &self.key_switching_key,
164 &ct.ct,
165 &mut ciphertext_buffer,
166 );
167
168 apply_programmable_bootstrap(
169 &self.bootstrapping_key,
170 &ciphertext_buffer,
171 &mut ct.ct,
172 &acc.acc,
173 buffers,
174 );
175 }
176 PBSOrder::BootstrapKeyswitch => {
177 apply_programmable_bootstrap(
178 &self.bootstrapping_key,
179 &ct.ct,
180 &mut ciphertext_buffer,
181 &acc.acc,
182 buffers,
183 );
184
185 keyswitch_lwe_ciphertext(
186 &self.key_switching_key,
187 &ciphertext_buffer,
188 &mut ct.ct,
189 );
190 }
191 }
192 });
193 }
194
195 fn apply_many_lookup_table(
196 &self,
197 ct: &Ciphertext,
198 acc: &ManyLookupTableOwned,
199 ) -> Vec<Ciphertext> {
200 match self.pbs_order {
201 PBSOrder::KeyswitchBootstrap => self.keyswitch_programmable_bootstrap_many_lut(ct, acc),
202 PBSOrder::BootstrapKeyswitch => self.programmable_bootstrap_keyswitch_many_lut(ct, acc),
203 }
204 }
205
206 fn lookup_table_size(&self) -> LookupTableSize {
207 LookupTableSize::new(
208 self.bootstrapping_key.glwe_size(),
209 self.bootstrapping_key.polynomial_size(),
210 )
211 }
212
213 fn kind(&self) -> AtomicPatternKind {
214 AtomicPatternKind::Standard(self.pbs_order)
215 }
216
217 fn deterministic_execution(&self) -> bool {
218 self.bootstrapping_key.deterministic_pbs_execution()
219 }
220
221 fn generate_oblivious_pseudo_random(
222 &self,
223 seed: Seed,
224 random_bits_count: u64,
225 full_bits_count: u64,
226 ) -> (LweCiphertextOwned<u64>, Degree) {
227 let (ct, degree) = generate_pseudo_random_from_pbs(
228 &self.bootstrapping_key,
229 seed,
230 random_bits_count,
231 full_bits_count,
232 self.ciphertext_modulus(),
233 );
234
235 match self.pbs_order {
236 PBSOrder::KeyswitchBootstrap => (ct, degree),
237 PBSOrder::BootstrapKeyswitch => {
238 let mut ct_ksed = LweCiphertext::new(
239 0,
240 self.bootstrapping_key.input_lwe_dimension().to_lwe_size(),
241 self.ciphertext_modulus(),
242 );
243
244 keyswitch_lwe_ciphertext(&self.key_switching_key, &ct, &mut ct_ksed);
245
246 (ct_ksed, degree)
247 }
248 }
249 }
250
251 fn switch_modulus_and_compress(&self, ct: &Ciphertext) -> CompressedModulusSwitchedCiphertext {
252 let compressed_modulus_switched_lwe_ciphertext =
253 ShortintEngine::with_thread_local_mut(|engine| {
254 let (mut ciphertext_buffer, _) = engine
255 .get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
256
257 let input_ct = match self.pbs_order {
258 PBSOrder::KeyswitchBootstrap => {
259 keyswitch_lwe_ciphertext(
260 &self.key_switching_key,
261 &ct.ct,
262 &mut ciphertext_buffer,
263 );
264 ciphertext_buffer.as_view()
265 }
266 PBSOrder::BootstrapKeyswitch => ct.ct.as_view(),
267 };
268
269 switch_modulus_and_compress(input_ct, &self.bootstrapping_key)
270 });
271
272 CompressedModulusSwitchedCiphertext {
273 compressed_modulus_switched_lwe_ciphertext,
274 degree: ct.degree,
275 message_modulus: ct.message_modulus,
276 carry_modulus: ct.carry_modulus,
277 atomic_pattern: ct.atomic_pattern,
278 }
279 }
280
281 fn decompress_and_apply_lookup_table(
282 &self,
283 compressed_ct: &CompressedModulusSwitchedCiphertext,
284 lut: &LookupTableOwned,
285 ) -> Ciphertext {
286 let mut output = LweCiphertext::new(
287 0,
288 self.ciphertext_lwe_dimension().to_lwe_size(),
289 self.ciphertext_modulus(),
290 );
291
292 ShortintEngine::with_thread_local_mut(|engine| {
293 let (mut ciphertext_buffer, buffers) =
294 engine.get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
295
296 match self.pbs_order {
297 PBSOrder::KeyswitchBootstrap => {
298 decompress_and_apply_lookup_table(
299 compressed_ct,
300 &lut.acc,
301 &self.bootstrapping_key,
302 &mut output.as_mut_view(),
303 buffers,
304 );
305 }
306 PBSOrder::BootstrapKeyswitch => {
307 decompress_and_apply_lookup_table(
308 compressed_ct,
309 &lut.acc,
310 &self.bootstrapping_key,
311 &mut ciphertext_buffer,
312 buffers,
313 );
314
315 keyswitch_lwe_ciphertext(
316 &self.key_switching_key,
317 &ciphertext_buffer,
318 &mut output,
319 );
320 }
321 }
322 });
323
324 Ciphertext::new(
325 output,
326 lut.degree,
327 NoiseLevel::NOMINAL,
328 compressed_ct.message_modulus,
329 compressed_ct.carry_modulus,
330 compressed_ct.atomic_pattern,
331 )
332 }
333}
334
335impl AtomicPatternMut for StandardAtomicPatternServerKey {
336 fn set_deterministic_execution(&mut self, new_deterministic_execution: bool) {
337 self.bootstrapping_key
338 .set_deterministic_pbs_execution(new_deterministic_execution)
339 }
340}
341
342impl StandardAtomicPatternServerKey {
343 pub(crate) fn keyswitch_programmable_bootstrap_many_lut(
344 &self,
345 ct: &Ciphertext,
346 lut: &ManyLookupTableOwned,
347 ) -> Vec<Ciphertext> {
348 let mut acc = lut.acc.clone();
349
350 ShortintEngine::with_thread_local_mut(|engine| {
351 let (mut ciphertext_buffer, buffers) =
352 engine.get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
353
354 keyswitch_lwe_ciphertext(&self.key_switching_key, &ct.ct, &mut ciphertext_buffer);
356
357 apply_ms_blind_rotate(
358 &self.bootstrapping_key,
359 &ciphertext_buffer.as_view(),
360 &mut acc,
361 buffers,
362 );
363 });
364
365 let function_count = lut.function_count();
367 let mut outputs = Vec::with_capacity(function_count);
368
369 for (fn_idx, output_degree) in lut.per_function_output_degree.iter().enumerate() {
370 let monomial_degree = MonomialDegree(fn_idx * lut.sample_extraction_stride);
371 let mut output_shortint_ct = ct.clone();
372
373 extract_lwe_sample_from_glwe_ciphertext(
374 &acc,
375 &mut output_shortint_ct.ct,
376 monomial_degree,
377 );
378
379 output_shortint_ct.degree = *output_degree;
380 output_shortint_ct.set_noise_level_to_nominal();
381 outputs.push(output_shortint_ct);
382 }
383
384 outputs
385 }
386
387 pub(crate) fn programmable_bootstrap_keyswitch_many_lut(
388 &self,
389 ct: &Ciphertext,
390 lut: &ManyLookupTableOwned,
391 ) -> Vec<Ciphertext> {
392 let mut acc = lut.acc.clone();
393
394 ShortintEngine::with_thread_local_mut(|engine| {
395 let buffers = engine.get_computation_buffers();
397
398 apply_ms_blind_rotate(&self.bootstrapping_key, &ct.ct, &mut acc, buffers);
399 });
400
401 let function_count = lut.function_count();
403 let mut outputs = Vec::with_capacity(function_count);
404
405 let mut tmp_lwe_ciphertext = LweCiphertext::new(
406 0u64,
407 self.key_switching_key
408 .input_key_lwe_dimension()
409 .to_lwe_size(),
410 self.key_switching_key.ciphertext_modulus(),
411 );
412
413 for (fn_idx, output_degree) in lut.per_function_output_degree.iter().enumerate() {
414 let monomial_degree = MonomialDegree(fn_idx * lut.sample_extraction_stride);
415 extract_lwe_sample_from_glwe_ciphertext(&acc, &mut tmp_lwe_ciphertext, monomial_degree);
416
417 let mut output_shortint_ct = ct.clone();
418
419 keyswitch_lwe_ciphertext(
421 &self.key_switching_key,
422 &tmp_lwe_ciphertext,
423 &mut output_shortint_ct.ct,
424 );
425
426 output_shortint_ct.degree = *output_degree;
427 output_shortint_ct.set_noise_level_to_nominal();
428 outputs.push(output_shortint_ct);
429 }
430
431 outputs
432 }
433}