// tfhe/shortint/oprf.rs

1use super::server_key::{
2    apply_programmable_bootstrap_no_ms_noise_reduction, GenericServerKey, LookupTableSize,
3    ShortintBootstrappingKey,
4};
5use super::Ciphertext;
6use crate::core_crypto::fft_impl::common::modulus_switch;
7use crate::core_crypto::prelude::{
8    lwe_ciphertext_plaintext_add_assign, CastFrom, CastInto, CiphertextModulus,
9    CiphertextModulusLog, LweCiphertext, LweCiphertextOwned, LweSize, Plaintext, UnsignedInteger,
10    UnsignedTorus,
11};
12use crate::shortint::atomic_pattern::AtomicPattern;
13use crate::shortint::ciphertext::Degree;
14use crate::shortint::engine::ShortintEngine;
15use crate::shortint::parameters::NoiseLevel;
16use crate::shortint::server_key::generate_lookup_table_no_encode;
17use tfhe_csprng::seeders::Seed;
18
19pub fn sha3_hash<Scalar>(values: &mut [Scalar], seed: Seed)
20where
21    Scalar: UnsignedInteger,
22{
23    use sha3::digest::{ExtendableOutput, Update, XofReader};
24
25    let mut hasher = sha3::Shake256::default();
26
27    let bytes = seed.0.to_le_bytes();
28
29    hasher.update(bytes.as_slice());
30
31    let mut reader = hasher.finalize_xof();
32
33    for value in values {
34        let bytes = bytemuck::bytes_of_mut(value);
35        reader.read(bytes);
36        // On little endian machine this is a no op, on big endian it will swap the bytes
37        *value = value.to_le();
38    }
39}
40pub fn create_random_from_seed<Scalar>(
41    seed: Seed,
42    lwe_size: LweSize,
43    ciphertext_modulus: CiphertextModulus<Scalar>,
44) -> LweCiphertext<Vec<Scalar>>
45where
46    Scalar: UnsignedInteger,
47{
48    let mut ct = LweCiphertext::new(Scalar::ZERO, lwe_size, ciphertext_modulus);
49
50    sha3_hash(ct.get_mut_mask().as_mut(), seed);
51
52    ct
53}
54
55pub fn create_random_from_seed_modulus_switched<Scalar>(
56    seed: Seed,
57    lwe_size: LweSize,
58    log_modulus: CiphertextModulusLog,
59    ciphertext_modulus: CiphertextModulus<Scalar>,
60) -> LweCiphertext<Vec<Scalar>>
61where
62    Scalar: UnsignedInteger,
63{
64    let mut ct = create_random_from_seed(seed, lwe_size, ciphertext_modulus);
65
66    for i in ct.as_mut() {
67        *i = modulus_switch(*i, log_modulus) << (Scalar::BITS - log_modulus.0);
68    }
69
70    ct
71}
72
73/// Uniformly generates a random encrypted value in `[0, 2^random_bits_count[`, using a PBS.
74///
75/// `full_bits_count` is the size of the lwe message, ie the shortint message + carry + padding
76/// bit.
77/// The output in in the form 0000rrr000noise (rbc=3, fbc=7)
78/// The encryted value is oblivious to the server.
79///
80/// It is the reponsiblity of the calling AP to transform this into a shortint ciphertext. The
81/// returned LWE is in the post PBS state, so a Keyswitch might be needed if the order is PBS-KS.
82pub(crate) fn generate_pseudo_random_from_pbs<InputScalar>(
83    bootstrapping_key: &ShortintBootstrappingKey<InputScalar>,
84    seed: Seed,
85    random_bits_count: u64,
86    full_bits_count: u64,
87    ciphertext_modulus: CiphertextModulus<u64>,
88) -> (LweCiphertextOwned<u64>, Degree)
89where
90    InputScalar: UnsignedTorus + CastFrom<usize> + CastInto<usize>,
91{
92    assert!(
93        random_bits_count <= full_bits_count,
94        "The number of random bits asked for (={random_bits_count}) is bigger than full_bits_count (={full_bits_count})"
95    );
96
97    let in_lwe_size = bootstrapping_key.input_lwe_dimension().to_lwe_size();
98
99    let seeded = create_random_from_seed_modulus_switched(
100        seed,
101        in_lwe_size,
102        bootstrapping_key
103            .polynomial_size()
104            .to_blind_rotation_input_modulus_log(),
105        CiphertextModulus::new_native(),
106    );
107
108    let p = 1 << random_bits_count;
109    let degree = p - 1;
110
111    let delta = 1_u64 << (64 - full_bits_count);
112
113    let poly_delta = 2 * bootstrapping_key.polynomial_size().0 as u64 / p;
114
115    let lut_size = LookupTableSize::new(
116        bootstrapping_key.glwe_size(),
117        bootstrapping_key.polynomial_size(),
118    );
119    let acc = generate_lookup_table_no_encode(lut_size, ciphertext_modulus, |x| {
120        (2 * (x / poly_delta) + 1) * delta / 2
121    });
122
123    let out_lwe_size = bootstrapping_key.output_lwe_dimension().to_lwe_size();
124
125    let mut ct = LweCiphertext::new(0, out_lwe_size, ciphertext_modulus);
126
127    ShortintEngine::with_thread_local_mut(|engine| {
128        let buffers = engine.get_computation_buffers();
129
130        apply_programmable_bootstrap_no_ms_noise_reduction(
131            bootstrapping_key,
132            &seeded,
133            &mut ct,
134            &acc,
135            buffers,
136        );
137    });
138
139    lwe_ciphertext_plaintext_add_assign(&mut ct, Plaintext(degree * delta / 2));
140    (ct, Degree(degree))
141}
142
143impl<AP: AtomicPattern> GenericServerKey<AP> {
144    /// Uniformly generates a random encrypted value in `[0, 2^random_bits_count[`
145    /// `2^random_bits_count` must be smaller than the message modulus
146    /// The encryted value is oblivious to the server
147    pub fn generate_oblivious_pseudo_random(
148        &self,
149        seed: Seed,
150        random_bits_count: u64,
151    ) -> Ciphertext {
152        assert!(
153            random_bits_count < 64,
154            "random_bits_count >= 64 is not supported",
155        );
156        assert!(
157            1 << random_bits_count <= self.message_modulus.0,
158            "The range asked for a random value (=[0, 2^{}[) does not fit in the available range [0, {}[",
159            random_bits_count, self.message_modulus.0
160        );
161
162        self.generate_oblivious_pseudo_random_message_and_carry(seed, random_bits_count)
163    }
164
165    /// Uniformly generates a random value in `[0, 2^random_bits_count[`
166    /// The encryted value is oblivious to the server
167    pub(crate) fn generate_oblivious_pseudo_random_message_and_carry(
168        &self,
169        seed: Seed,
170        random_bits_count: u64,
171    ) -> Ciphertext {
172        assert!(
173            self.message_modulus.0.is_power_of_two(),
174            "The message modulus(={}), must be a power of 2 to use the OPRF",
175            self.message_modulus.0
176        );
177        let message_bits_count = self.message_modulus.0.ilog2() as u64;
178
179        assert!(
180            self.carry_modulus.0.is_power_of_two(),
181            "The carry modulus(={}), must be a power of 2 to use the OPRF",
182            self.carry_modulus.0
183        );
184        let carry_bits_count = self.carry_modulus.0.ilog2() as u64;
185
186        assert!(
187            random_bits_count <= carry_bits_count + message_bits_count,
188            "The number of random bits asked for (={random_bits_count}) is bigger than carry_bits_count (={carry_bits_count}) + message_bits_count(={message_bits_count})",
189        );
190
191        let (ct, degree) = self.atomic_pattern.generate_oblivious_pseudo_random(
192            seed,
193            random_bits_count,
194            1 + carry_bits_count + message_bits_count,
195        );
196
197        Ciphertext::new(
198            ct,
199            degree,
200            NoiseLevel::NOMINAL,
201            self.message_modulus,
202            self.carry_modulus,
203            self.atomic_pattern.kind(),
204        )
205    }
206}
207
#[cfg(test)]
pub(crate) mod test {
    use crate::core_crypto::prelude::{decrypt_lwe_ciphertext, LweSecretKey};
    use crate::shortint::{ClientKey, ServerKey, ShortintParameterSet};

    use super::*;

    use rayon::prelude::*;
    use statrs::distribution::ContinuousCDF;
    use std::collections::HashMap;
    use tfhe_csprng::seeders::Seed;

    /// Checks, for seeds 0..1000 and two parameter sets, that the decrypted OPRF output matches
    /// a cleartext model of the computation. The u64 path uses the KS-PBS parameters; the u32
    /// path uses the KS32 parameters.
    #[test]
    fn oprf_compare_plain_ci_run_filter() {
        use crate::shortint::gen_keys;
        use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
        use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

        let (ck, sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
        (0..1000).for_each(|seed| oprf_compare_plain_from_seed::<u64>(Seed(seed), &ck, &sk));

        let (ck, sk) = gen_keys(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128);
        (0..1000).for_each(|seed| oprf_compare_plain_from_seed::<u32>(Seed(seed), &ck, &sk));
    }

    /// Compares the decrypted OPRF output for `seed` with a cleartext model.
    ///
    /// The modulus-switched PBS input is rebuilt from `seed` and decrypted with the small LWE
    /// secret key (cast to `Scalar`). The decrypted value is then fed to a plain model of the
    /// negacyclic lookup table plus the constant offset.
    fn oprf_compare_plain_from_seed<Scalar: UnsignedInteger + CastFrom<u64> + CastInto<u64>>(
        seed: Seed,
        ck: &ClientKey,
        sk: &ServerKey,
    ) {
        let params = ck.parameters;
        let random_bits_count = 2;

        let encrypted_output = sk.generate_oblivious_pseudo_random(seed, random_bits_count);

        // Size of the modulus-switched input space
        let input_p = 2 * params.polynomial_size().0 as u64;
        let log_input_p = input_p.ilog2() as usize;
        // Number of distinct random values
        let p_prime = 1 << random_bits_count;
        // Plaintext space including the padding bit
        let output_p = 2 * params.carry_modulus().0 * params.message_modulus().0;
        let poly_delta = 2 * params.polynomial_size().0 as u64 / p_prime;

        let modulus_switched_input = create_random_from_seed_modulus_switched(
            seed,
            params.lwe_dimension().to_lwe_size(),
            params.polynomial_size().to_blind_rotation_input_modulus_log(),
            CiphertextModulus::new_native(),
        );

        let small_key = LweSecretKey::from_container(
            ck.small_lwe_secret_key()
                .as_ref()
                .iter()
                .map(|&key_element| Scalar::cast_from(key_element))
                .collect::<Vec<_>>(),
        );

        // Round the decrypted value to the nearest of the `input_p` possible inputs
        let phase = decrypt_lwe_ciphertext(&small_key, &modulus_switched_input).0;
        let rounding = Scalar::ONE << (Scalar::BITS - log_input_p - 1);
        let plain_prf_input =
            CastInto::<u64>::cast_into(phase.wrapping_add(rounding) >> (Scalar::BITS - log_input_p));

        // First half of the inputs: box `k` (of width `poly_delta`) encodes `2k + 1` half-steps.
        // Second half: negacyclic image of the first half.
        let lut_half = |x: u64| 2 * (x / poly_delta) + 1;
        let negacyclic_lut = |x: u64| {
            assert!(x < input_p);
            let half_input = input_p / 2;
            if x >= half_input {
                2 * output_p - lut_half(x - half_input)
            } else {
                lut_half(x)
            }
        };

        // Add the constant offset of `p_prime - 1` half-steps, then convert half-steps back into
        // message units
        let expected_output = {
            let half_steps = (negacyclic_lut(plain_prf_input) + p_prime - 1) % (2 * output_p);
            assert!(half_steps % 2 == 0);
            half_steps / 2
        };

        let output = ck.decrypt_message_and_carry(&encrypted_output);

        assert!(output < p_prime);
        assert_eq!(output, expected_output);
    }

    /// Checks with a chi-squared test that the OPRF outputs are uniformly distributed over
    /// `[0, 2^random_bits_count[` for two parameter sets.
    #[test]
    fn oprf_test_uniformity_ci_run_filter() {
        use crate::shortint::gen_keys;
        use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
        use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

        let sample_count: usize = 100_000;
        let p_value_limit: f64 = 0.000_01;
        let random_bits_count = 2;

        let parameter_sets = [
            ShortintParameterSet::from(PARAM_MESSAGE_2_CARRY_2_KS_PBS),
            ShortintParameterSet::from(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128),
        ];

        for params in parameter_sets {
            let (ck, sk) = gen_keys(params);

            test_uniformity(sample_count, p_value_limit, 1 << random_bits_count, |seed| {
                let img =
                    sk.generate_oblivious_pseudo_random(Seed(seed as u128), random_bits_count);

                ck.decrypt_message_and_carry(&img)
            });
        }
    }

    /// Draws `sample_count` values with `f` (called with indices `0..sample_count`) and asserts
    /// that the chi-squared uniformity p-value over `[0, distinct_values[` is above
    /// `p_value_limit`.
    ///
    /// # Panics
    ///
    /// If a sample is not smaller than `distinct_values`, or if the p-value is too low.
    pub fn test_uniformity<F>(sample_count: usize, p_value_limit: f64, distinct_values: u64, f: F)
    where
        F: Sync + Fn(usize) -> u64,
    {
        let p_value = uniformity_p_value(f, sample_count, distinct_values);

        assert!(
            p_value > p_value_limit,
            "p_value (={p_value}) expected to be bigger than {p_value_limit}"
        );
    }

    /// Returns the p-value of Pearson's chi-squared test for the hypothesis that the values
    /// produced by `f` over `0..sample_count` are uniform over `[0, distinct_values[`.
    fn uniformity_p_value<F>(f: F, sample_count: usize, distinct_values: u64) -> f64
    where
        F: Sync + Fn(usize) -> u64,
    {
        let samples: Vec<u64> = (0..sample_count).into_par_iter().map(&f).collect();

        let mut histogram: HashMap<u64, usize> = HashMap::new();
        for &sample in &samples {
            assert!(sample < distinct_values, "i {} dv{}", sample, distinct_values);

            *histogram.entry(sample).or_default() += 1;
        }

        let expected_count = sample_count as f64 / distinct_values as f64;

        // https://en.wikipedia.org/wiki/Pearson's_chi-squared_test
        let statistic: f64 = (0..distinct_values)
            .map(|value| {
                let observed = histogram.get(&value).copied().unwrap_or(0) as f64;
                let deviation = observed - expected_count;
                deviation * deviation / expected_count
            })
            .sum();

        statrs::distribution::ChiSquared::new((distinct_values - 1) as f64)
            .unwrap()
            .sf(statistic)
    }
}
384}