1use super::server_key::{
2 apply_programmable_bootstrap_no_ms_noise_reduction, GenericServerKey, LookupTableSize,
3 ShortintBootstrappingKey,
4};
5use super::Ciphertext;
6use crate::core_crypto::fft_impl::common::modulus_switch;
7use crate::core_crypto::prelude::{
8 lwe_ciphertext_plaintext_add_assign, CastFrom, CastInto, CiphertextModulus,
9 CiphertextModulusLog, LweCiphertext, LweCiphertextOwned, LweSize, Plaintext, UnsignedInteger,
10 UnsignedTorus,
11};
12use crate::shortint::atomic_pattern::AtomicPattern;
13use crate::shortint::ciphertext::Degree;
14use crate::shortint::engine::ShortintEngine;
15use crate::shortint::parameters::NoiseLevel;
16use crate::shortint::server_key::generate_lookup_table_no_encode;
17use tfhe_csprng::seeders::Seed;
18
19pub fn sha3_hash<Scalar>(values: &mut [Scalar], seed: Seed)
20where
21 Scalar: UnsignedInteger,
22{
23 use sha3::digest::{ExtendableOutput, Update, XofReader};
24
25 let mut hasher = sha3::Shake256::default();
26
27 let bytes = seed.0.to_le_bytes();
28
29 hasher.update(bytes.as_slice());
30
31 let mut reader = hasher.finalize_xof();
32
33 for value in values {
34 let bytes = bytemuck::bytes_of_mut(value);
35 reader.read(bytes);
36 *value = value.to_le();
38 }
39}
40pub fn create_random_from_seed<Scalar>(
41 seed: Seed,
42 lwe_size: LweSize,
43 ciphertext_modulus: CiphertextModulus<Scalar>,
44) -> LweCiphertext<Vec<Scalar>>
45where
46 Scalar: UnsignedInteger,
47{
48 let mut ct = LweCiphertext::new(Scalar::ZERO, lwe_size, ciphertext_modulus);
49
50 sha3_hash(ct.get_mut_mask().as_mut(), seed);
51
52 ct
53}
54
55pub fn create_random_from_seed_modulus_switched<Scalar>(
56 seed: Seed,
57 lwe_size: LweSize,
58 log_modulus: CiphertextModulusLog,
59 ciphertext_modulus: CiphertextModulus<Scalar>,
60) -> LweCiphertext<Vec<Scalar>>
61where
62 Scalar: UnsignedInteger,
63{
64 let mut ct = create_random_from_seed(seed, lwe_size, ciphertext_modulus);
65
66 for i in ct.as_mut() {
67 *i = modulus_switch(*i, log_modulus) << (Scalar::BITS - log_modulus.0);
68 }
69
70 ct
71}
72
73pub(crate) fn generate_pseudo_random_from_pbs<InputScalar>(
83 bootstrapping_key: &ShortintBootstrappingKey<InputScalar>,
84 seed: Seed,
85 random_bits_count: u64,
86 full_bits_count: u64,
87 ciphertext_modulus: CiphertextModulus<u64>,
88) -> (LweCiphertextOwned<u64>, Degree)
89where
90 InputScalar: UnsignedTorus + CastFrom<usize> + CastInto<usize>,
91{
92 assert!(
93 random_bits_count <= full_bits_count,
94 "The number of random bits asked for (={random_bits_count}) is bigger than full_bits_count (={full_bits_count})"
95 );
96
97 let in_lwe_size = bootstrapping_key.input_lwe_dimension().to_lwe_size();
98
99 let seeded = create_random_from_seed_modulus_switched(
100 seed,
101 in_lwe_size,
102 bootstrapping_key
103 .polynomial_size()
104 .to_blind_rotation_input_modulus_log(),
105 CiphertextModulus::new_native(),
106 );
107
108 let p = 1 << random_bits_count;
109 let degree = p - 1;
110
111 let delta = 1_u64 << (64 - full_bits_count);
112
113 let poly_delta = 2 * bootstrapping_key.polynomial_size().0 as u64 / p;
114
115 let lut_size = LookupTableSize::new(
116 bootstrapping_key.glwe_size(),
117 bootstrapping_key.polynomial_size(),
118 );
119 let acc = generate_lookup_table_no_encode(lut_size, ciphertext_modulus, |x| {
120 (2 * (x / poly_delta) + 1) * delta / 2
121 });
122
123 let out_lwe_size = bootstrapping_key.output_lwe_dimension().to_lwe_size();
124
125 let mut ct = LweCiphertext::new(0, out_lwe_size, ciphertext_modulus);
126
127 ShortintEngine::with_thread_local_mut(|engine| {
128 let buffers = engine.get_computation_buffers();
129
130 apply_programmable_bootstrap_no_ms_noise_reduction(
131 bootstrapping_key,
132 &seeded,
133 &mut ct,
134 &acc,
135 buffers,
136 );
137 });
138
139 lwe_ciphertext_plaintext_add_assign(&mut ct, Plaintext(degree * delta / 2));
140 (ct, Degree(degree))
141}
142
143impl<AP: AtomicPattern> GenericServerKey<AP> {
144 pub fn generate_oblivious_pseudo_random(
148 &self,
149 seed: Seed,
150 random_bits_count: u64,
151 ) -> Ciphertext {
152 assert!(
153 random_bits_count < 64,
154 "random_bits_count >= 64 is not supported",
155 );
156 assert!(
157 1 << random_bits_count <= self.message_modulus.0,
158 "The range asked for a random value (=[0, 2^{}[) does not fit in the available range [0, {}[",
159 random_bits_count, self.message_modulus.0
160 );
161
162 self.generate_oblivious_pseudo_random_message_and_carry(seed, random_bits_count)
163 }
164
165 pub(crate) fn generate_oblivious_pseudo_random_message_and_carry(
168 &self,
169 seed: Seed,
170 random_bits_count: u64,
171 ) -> Ciphertext {
172 assert!(
173 self.message_modulus.0.is_power_of_two(),
174 "The message modulus(={}), must be a power of 2 to use the OPRF",
175 self.message_modulus.0
176 );
177 let message_bits_count = self.message_modulus.0.ilog2() as u64;
178
179 assert!(
180 self.carry_modulus.0.is_power_of_two(),
181 "The carry modulus(={}), must be a power of 2 to use the OPRF",
182 self.carry_modulus.0
183 );
184 let carry_bits_count = self.carry_modulus.0.ilog2() as u64;
185
186 assert!(
187 random_bits_count <= carry_bits_count + message_bits_count,
188 "The number of random bits asked for (={random_bits_count}) is bigger than carry_bits_count (={carry_bits_count}) + message_bits_count(={message_bits_count})",
189 );
190
191 let (ct, degree) = self.atomic_pattern.generate_oblivious_pseudo_random(
192 seed,
193 random_bits_count,
194 1 + carry_bits_count + message_bits_count,
195 );
196
197 Ciphertext::new(
198 ct,
199 degree,
200 NoiseLevel::NOMINAL,
201 self.message_modulus,
202 self.carry_modulus,
203 self.atomic_pattern.kind(),
204 )
205 }
206}
207
#[cfg(test)]
pub(crate) mod test {
    use crate::core_crypto::prelude::{decrypt_lwe_ciphertext, LweSecretKey};
    use crate::shortint::{ClientKey, ServerKey, ShortintParameterSet};

    use super::*;

    use rayon::prelude::*;
    use statrs::distribution::ContinuousCDF;
    use std::collections::HashMap;
    use tfhe_csprng::seeders::Seed;

    /// For 1000 seeds and two parameter sets (checked with `u64` and `u32` PBS-input scalars
    /// respectively), compares the decrypted OPRF output with a plaintext model of it.
    #[test]
    fn oprf_compare_plain_ci_run_filter() {
        use crate::shortint::gen_keys;
        use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
        use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

        let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
        for seed in 0..1000_u128 {
            oprf_compare_plain_from_seed::<u64>(Seed(seed), &client_key, &server_key);
        }

        let (client_key, server_key) =
            gen_keys(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128);
        for seed in 0..1000_u128 {
            oprf_compare_plain_from_seed::<u32>(Seed(seed), &client_key, &server_key);
        }
    }

    /// Checks the server-side OPRF output for `seed` against a plaintext model.
    ///
    /// The model rebuilds the same seeded, modulus-switched LWE and decrypts it with the small
    /// LWE key to obtain the lookup-table input. It then evaluates the negacyclic lookup table
    /// and the final re-centring by hand.
    fn oprf_compare_plain_from_seed<Scalar: UnsignedInteger + CastFrom<u64> + CastInto<u64>>(
        seed: Seed,
        ck: &ClientKey,
        sk: &ServerKey,
    ) {
        let params = ck.parameters;
        let random_bits_count = 2;

        // Encrypted side: what the server produces.
        let img = sk.generate_oblivious_pseudo_random(seed, random_bits_count);

        // Plaintext model parameters.
        let polynomial_size = params.polynomial_size();
        let input_p = 2 * polynomial_size.0 as u64;
        let log_input_p = input_p.ilog2() as usize;
        let p_prime = 1 << random_bits_count;
        let output_p = 2 * params.carry_modulus().0 * params.message_modulus().0;
        let poly_delta = input_p / p_prime;

        let seeded_ct = create_random_from_seed_modulus_switched(
            seed,
            params.lwe_dimension().to_lwe_size(),
            polynomial_size.to_blind_rotation_input_modulus_log(),
            CiphertextModulus::new_native(),
        );

        let small_key = LweSecretKey::from_container(
            ck.small_lwe_secret_key()
                .as_ref()
                .iter()
                .map(|&key_bit| Scalar::cast_from(key_bit))
                .collect::<Vec<_>>(),
        );

        // Round the decrypted phase to the nearest multiple of 2^(BITS - log_input_p) and keep
        // the top `log_input_p` bits: this is the lookup-table input modelled below.
        let plain_prf_input = CastInto::<u64>::cast_into(
            decrypt_lwe_ciphertext(&small_key, &seeded_ct)
                .0
                .wrapping_add(Scalar::ONE << (Scalar::BITS - log_input_p - 1))
                >> (Scalar::BITS - log_input_p),
        );

        // Lookup table over the first half of the input domain, in units of delta / 2.
        let half_negacyclic_part = |x: u64| 2 * (x / poly_delta) + 1;

        // Negacyclic extension over [0, input_p): the second half is the negation
        // (modulo 2 * output_p) of the first half.
        let negacyclic_part = |x: u64| {
            assert!(x < input_p);
            match x.checked_sub(input_p / 2) {
                None => half_negacyclic_part(x),
                Some(offset) => 2 * output_p - half_negacyclic_part(offset),
            }
        };

        // Re-centring by (p_prime - 1) * delta / 2 must yield a whole multiple of delta.
        let a = (negacyclic_part(plain_prf_input) + p_prime - 1) % (2 * output_p);
        assert!(a % 2 == 0);
        let expected_output = a / 2;

        let output = ck.decrypt_message_and_carry(&img);

        assert!(output < p_prime);
        assert_eq!(output, expected_output);
    }

    /// Checks with a chi-squared test that OPRF outputs are uniformly distributed, for two
    /// parameter sets.
    #[test]
    fn oprf_test_uniformity_ci_run_filter() {
        use crate::shortint::gen_keys;
        use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
        use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

        let sample_count: usize = 100_000;
        let p_value_limit: f64 = 0.000_01;
        let random_bits_count = 2;

        let parameter_sets = [
            ShortintParameterSet::from(PARAM_MESSAGE_2_CARRY_2_KS_PBS),
            ShortintParameterSet::from(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128),
        ];

        for params in parameter_sets {
            let (ck, sk) = gen_keys(params);

            test_uniformity(
                sample_count,
                p_value_limit,
                1 << random_bits_count,
                |seed: usize| {
                    let img =
                        sk.generate_oblivious_pseudo_random(Seed(seed as u128), random_bits_count);
                    ck.decrypt_message_and_carry(&img)
                },
            );
        }
    }

    /// Evaluates `f(0), ..., f(sample_count - 1)` and asserts, via a chi-squared
    /// goodness-of-fit test, that the samples are compatible with a uniform distribution over
    /// `[0, distinct_values)`.
    ///
    /// # Panics
    ///
    /// Panics if a sample is outside `[0, distinct_values)` or if the p-value is not strictly
    /// greater than `p_value_limit`.
    pub fn test_uniformity<F>(sample_count: usize, p_value_limit: f64, distinct_values: u64, f: F)
    where
        F: Sync + Fn(usize) -> u64,
    {
        let p_value = uniformity_p_value(f, sample_count, distinct_values);

        assert!(
            p_value > p_value_limit,
            "p_value (={p_value}) expected to be bigger than {p_value_limit}"
        );
    }

    /// Returns the chi-squared survival-function value (p-value) of the samples
    /// `f(0..sample_count)` against a uniform distribution over `[0, distinct_values)`.
    fn uniformity_p_value<F>(f: F, sample_count: usize, distinct_values: u64) -> f64
    where
        F: Sync + Fn(usize) -> u64,
    {
        let samples: Vec<u64> = (0..sample_count).into_par_iter().map(&f).collect();

        let mut histogram: HashMap<u64, usize> = HashMap::new();
        for &sample in &samples {
            assert!(sample < distinct_values, "i {} dv{}", sample, distinct_values);

            *histogram.entry(sample).or_insert(0) += 1;
        }

        let expected_count = sample_count as f64 / distinct_values as f64;

        // Pearson's statistic: sum over all values of (observed - expected)^2 / expected.
        let chi_square_statistic: f64 = (0..distinct_values)
            .map(|value| histogram.get(&value).copied().unwrap_or(0))
            .map(|observed| {
                let deviation = observed as f64 - expected_count;
                deviation * deviation / expected_count
            })
            .sum();

        statrs::distribution::ChiSquared::new((distinct_values - 1) as f64)
            .unwrap()
            .sf(chi_square_statistic)
    }
}
384}