spiral_rs/
client.rs

1use crate::{
2    arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
3};
4use rand::{Rng, SeedableRng};
5use rand_chacha::ChaCha20Rng;
6use std::{iter::once, mem::size_of};
7use subtle::ConditionallySelectable;
8use subtle::ConstantTimeEq;
9
10pub type Seed = <ChaCha20Rng as SeedableRng>::Seed;
11pub const SEED_LENGTH: usize = 32;
12
13pub const DEFAULT_PARAMS: &'static str = r#"
14    {"n": 2,
15    "nu_1": 10,
16    "nu_2": 6,
17    "p": 512,
18    "q2_bits": 21,
19    "s_e": 85.83255142749422,
20    "t_gsw": 10,
21    "t_conv": 4,
22    "t_exp_left": 16,
23    "t_exp_right": 56,
24    "instances": 11,
25    "db_item_size": 100000 }
26"#;
27
28const UUID_V4_LEN: usize = 36;
29
30fn new_vec_raw<'a>(
31    params: &'a Params,
32    num: usize,
33    rows: usize,
34    cols: usize,
35) -> Vec<PolyMatrixRaw<'a>> {
36    let mut v = Vec::with_capacity(num);
37    for _ in 0..num {
38        v.push(PolyMatrixRaw::zero(params, rows, cols));
39    }
40    v
41}
42
43fn get_inv_from_rng(params: &Params, rng: &mut ChaCha20Rng) -> u64 {
44    params.modulus - (rng.gen::<u64>() % params.modulus)
45}
46
47fn mat_sz_bytes_excl_first_row(a: &PolyMatrixRaw) -> usize {
48    (a.rows - 1) * a.cols * a.params.poly_len * size_of::<u64>()
49}
50
51fn serialize_polymatrix_for_rng(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
52    let offs = a.cols * a.params.poly_len; // skip the first row
53    for i in 0..(a.rows - 1) * a.cols * a.params.poly_len {
54        vec.extend_from_slice(&u64::to_ne_bytes(a.data[offs + i]));
55    }
56}
57
58fn serialize_vec_polymatrix_for_rng(vec: &mut Vec<u8>, a: &Vec<PolyMatrixRaw>) {
59    for i in 0..a.len() {
60        serialize_polymatrix_for_rng(vec, &a[i]);
61    }
62}
63
64fn deserialize_polymatrix_rng(a: &mut PolyMatrixRaw, data: &[u8], rng: &mut ChaCha20Rng) -> usize {
65    let (first_row, rest) = a
66        .data
67        .as_mut_slice()
68        .split_at_mut(a.cols * a.params.poly_len);
69    for i in 0..first_row.len() {
70        first_row[i] = get_inv_from_rng(a.params, rng);
71    }
72    for (i, chunk) in data.chunks(size_of::<u64>()).enumerate() {
73        rest[i] = u64::from_ne_bytes(chunk.try_into().unwrap());
74    }
75    mat_sz_bytes_excl_first_row(a)
76}
77
78fn deserialize_vec_polymatrix_rng(
79    a: &mut Vec<PolyMatrixRaw>,
80    data: &[u8],
81    rng: &mut ChaCha20Rng,
82) -> usize {
83    let mut chunks = data.chunks(mat_sz_bytes_excl_first_row(&a[0]));
84    let mut bytes_read = 0;
85    for i in 0..a.len() {
86        bytes_read += deserialize_polymatrix_rng(&mut a[i], chunks.next().unwrap(), rng);
87    }
88    bytes_read
89}
90
/// Returns the odd-indexed words of `v_buf` (indices 1, 3, 5, ...).
///
/// In the interleaved Regev buffer the even-indexed words are regenerated from the seed on the
/// receiving side (see `interleave_rng_data`), so only the odd-indexed ones are transmitted.
fn extract_excl_rng_data(v_buf: &[u64]) -> Vec<u64> {
    v_buf.iter().skip(1).step_by(2).copied().collect()
}
100
101fn interleave_rng_data(params: &Params, v_buf: &[u64], rng: &mut ChaCha20Rng) -> Vec<u64> {
102    let mut out = Vec::new();
103
104    let mut reg_cts = Vec::new();
105    for _ in 0..params.num_expanded() {
106        let mut sigma = PolyMatrixRaw::zero(&params, 2, 1);
107        for z in 0..params.poly_len {
108            sigma.data[z] = get_inv_from_rng(params, rng);
109        }
110        reg_cts.push(sigma.ntt());
111    }
112    // reorient into server's preferred indexing
113    let reg_cts_buf_words = params.num_expanded() * 2 * params.poly_len;
114    let mut reg_cts_buf = vec![0u64; reg_cts_buf_words];
115    reorient_reg_ciphertexts(params, reg_cts_buf.as_mut_slice(), &reg_cts);
116
117    assert_eq!(reg_cts_buf_words, 2 * v_buf.len());
118
119    for i in 0..v_buf.len() {
120        out.push(reg_cts_buf[2 * i]);
121        out.push(v_buf[i]);
122    }
123    out
124}
125
/// Public key material produced by `Client::generate_keys` and sent to the server.
///
/// The expansion and conversion keys are present only when `params.expand_queries` is set
/// (see `PublicParameters::init`).
pub struct PublicParameters<'a> {
    /// Packing keys: `params.n` encrypted matrices, one per index `i in 0..params.n`.
    pub v_packing: Vec<PolyMatrixNTT<'a>>, // Ws
    /// Left query-expansion keys (`params.g()` matrices); `Some` only with query expansion.
    pub v_expansion_left: Option<Vec<PolyMatrixNTT<'a>>>,
    /// Right query-expansion keys (`params.stop_round() + 1` matrices); `Some` only with
    /// query expansion.
    pub v_expansion_right: Option<Vec<PolyMatrixNTT<'a>>>,
    /// Conversion key: a single `2 x (2 * t_conv)` matrix; `Some` only with query expansion.
    pub v_conversion: Option<Vec<PolyMatrixNTT<'a>>>, // V
    /// Seed of the public RNG from which every matrix's first row is derived; it is written
    /// first when serializing so the receiver can regenerate those rows.
    pub seed: Option<Seed>,
}
133
134impl<'a> PublicParameters<'a> {
135    pub fn init(params: &'a Params) -> Self {
136        if params.expand_queries {
137            PublicParameters {
138                v_packing: Vec::new(),
139                v_expansion_left: Some(Vec::new()),
140                v_expansion_right: Some(Vec::new()),
141                v_conversion: Some(Vec::new()),
142                seed: None,
143            }
144        } else {
145            PublicParameters {
146                v_packing: Vec::new(),
147                v_expansion_left: None,
148                v_expansion_right: None,
149                v_conversion: None,
150                seed: None,
151            }
152        }
153    }
154
155    fn from_ntt_alloc_vec(v: &Vec<PolyMatrixNTT<'a>>) -> Option<Vec<PolyMatrixRaw<'a>>> {
156        Some(v.iter().map(from_ntt_alloc).collect())
157    }
158
159    fn from_ntt_alloc_opt_vec(
160        v: &Option<Vec<PolyMatrixNTT<'a>>>,
161    ) -> Option<Vec<PolyMatrixRaw<'a>>> {
162        Some(v.as_ref()?.iter().map(from_ntt_alloc).collect())
163    }
164
165    fn to_ntt_alloc_vec(v: &Vec<PolyMatrixRaw<'a>>) -> Option<Vec<PolyMatrixNTT<'a>>> {
166        Some(v.iter().map(to_ntt_alloc).collect())
167    }
168
169    pub fn to_raw(&self) -> Vec<Option<Vec<PolyMatrixRaw>>> {
170        vec![
171            Self::from_ntt_alloc_vec(&self.v_packing),
172            Self::from_ntt_alloc_opt_vec(&self.v_expansion_left),
173            Self::from_ntt_alloc_opt_vec(&self.v_expansion_right),
174            Self::from_ntt_alloc_opt_vec(&self.v_conversion),
175        ]
176    }
177
178    pub fn serialize(&self) -> Vec<u8> {
179        let mut data = Vec::new();
180        if self.seed.is_some() {
181            let seed = self.seed.as_ref().unwrap();
182            data.extend(seed);
183        }
184        for v in self.to_raw().iter() {
185            if v.is_some() {
186                serialize_vec_polymatrix_for_rng(&mut data, v.as_ref().unwrap());
187            }
188        }
189        data
190    }
191
192    pub fn deserialize(params: &'a Params, data: &[u8]) -> Self {
193        assert_eq!(params.setup_bytes(), data.len());
194
195        let mut idx = 0;
196
197        let seed = data[0..SEED_LENGTH].try_into().unwrap();
198        let mut rng = ChaCha20Rng::from_seed(seed);
199        idx += SEED_LENGTH;
200
201        let mut v_packing = new_vec_raw(params, params.n, params.n + 1, params.t_conv);
202        idx += deserialize_vec_polymatrix_rng(&mut v_packing, &data[idx..], &mut rng);
203
204        if params.expand_queries {
205            let mut v_expansion_left = new_vec_raw(params, params.g(), 2, params.t_exp_left);
206            idx += deserialize_vec_polymatrix_rng(&mut v_expansion_left, &data[idx..], &mut rng);
207
208            let mut v_expansion_right =
209                new_vec_raw(params, params.stop_round() + 1, 2, params.t_exp_right);
210            idx += deserialize_vec_polymatrix_rng(&mut v_expansion_right, &data[idx..], &mut rng);
211
212            let mut v_conversion = new_vec_raw(params, 1, 2, 2 * params.t_conv);
213            _ = deserialize_vec_polymatrix_rng(&mut v_conversion, &data[idx..], &mut rng);
214
215            Self {
216                v_packing: Self::to_ntt_alloc_vec(&v_packing).unwrap(),
217                v_expansion_left: Self::to_ntt_alloc_vec(&v_expansion_left),
218                v_expansion_right: Self::to_ntt_alloc_vec(&v_expansion_right),
219                v_conversion: Self::to_ntt_alloc_vec(&v_conversion),
220                seed: Some(seed),
221            }
222        } else {
223            Self {
224                v_packing: Self::to_ntt_alloc_vec(&v_packing).unwrap(),
225                v_expansion_left: None,
226                v_expansion_right: None,
227                v_conversion: None,
228                seed: Some(seed),
229            }
230        }
231    }
232}
233
/// A PIR query. Which fields are set depends on `params.expand_queries`:
/// - with expansion, only `ct` (plus `seed`) is set;
/// - without expansion, `v_buf` and `v_ct` (plus `seed`) are set.
pub struct Query<'a> {
    /// Single packed Regev ciphertext (query-expansion mode only).
    pub ct: Option<PolyMatrixRaw<'a>>,
    /// Reoriented, interleaved Regev ciphertext words for the first database dimension
    /// (no-expansion mode only).
    pub v_buf: Option<Vec<u64>>,
    /// One `2 x (2 * t_gsw)` ciphertext matrix per further database dimension
    /// (no-expansion mode only).
    pub v_ct: Option<Vec<PolyMatrixRaw<'a>>>,
    /// Seed of the public RNG from which the ciphertexts' first rows are derived.
    pub seed: Option<Seed>,
}
240
241impl<'a> Query<'a> {
242    pub fn empty() -> Self {
243        Query {
244            ct: None,
245            v_ct: None,
246            v_buf: None,
247            seed: None,
248        }
249    }
250
251    pub fn serialize(&self) -> Vec<u8> {
252        let mut data = Vec::new();
253        if self.seed.is_some() {
254            let seed = self.seed.as_ref().unwrap();
255            data.extend(seed);
256        }
257        if self.ct.is_some() {
258            let ct = self.ct.as_ref().unwrap();
259            serialize_polymatrix_for_rng(&mut data, &ct);
260        }
261        if self.v_buf.is_some() {
262            let v_buf = self.v_buf.as_ref().unwrap();
263            let v_buf_extracted = extract_excl_rng_data(&v_buf);
264            data.extend(v_buf_extracted.iter().map(|x| x.to_ne_bytes()).flatten());
265        }
266        if self.v_ct.is_some() {
267            let v_ct = self.v_ct.as_ref().unwrap();
268            for x in v_ct {
269                serialize_polymatrix_for_rng(&mut data, x);
270            }
271        }
272        data
273    }
274
275    pub fn deserialize(params: &'a Params, mut data: &[u8]) -> Self {
276        assert_eq!(params.query_bytes(), data.len());
277
278        let mut out = Query::empty();
279        let seed = data[0..SEED_LENGTH].try_into().unwrap();
280        out.seed = Some(seed);
281        let mut rng = ChaCha20Rng::from_seed(seed);
282        data = &data[SEED_LENGTH..];
283        if params.expand_queries {
284            let mut ct = PolyMatrixRaw::zero(params, 2, 1);
285            deserialize_polymatrix_rng(&mut ct, data, &mut rng);
286            out.ct = Some(ct);
287        } else {
288            let v_buf_bytes = params.query_v_buf_bytes();
289            let v_buf: Vec<u64> = (&data[..v_buf_bytes])
290                .chunks(size_of::<u64>())
291                .map(|x| u64::from_ne_bytes(x.try_into().unwrap()))
292                .collect();
293            let v_buf_interleaved = interleave_rng_data(params, &v_buf, &mut rng);
294            out.v_buf = Some(v_buf_interleaved);
295
296            let mut v_ct = new_vec_raw(params, params.db_dim_2, 2, 2 * params.t_gsw);
297            deserialize_vec_polymatrix_rng(&mut v_ct, &data[v_buf_bytes..], &mut rng);
298            out.v_ct = Some(v_ct);
299        }
300        out
301    }
302}
303
304fn matrix_with_identity<'a>(p: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
305    assert_eq!(p.cols, 1);
306    let mut r = PolyMatrixRaw::zero(p.params, p.rows, p.rows + 1);
307    r.copy_into(p, 0, 0);
308    r.copy_into(&PolyMatrixRaw::identity(p.params, p.rows, p.rows), 0, 1);
309    r
310}
311
312fn params_with_moduli(params: &Params, moduli: &Vec<u64>) -> Params {
313    Params::init(
314        params.poly_len,
315        moduli,
316        params.noise_width,
317        params.n,
318        params.pt_modulus,
319        params.q2_bits,
320        params.t_conv,
321        params.t_exp_left,
322        params.t_exp_right,
323        params.t_gsw,
324        params.expand_queries,
325        params.db_dim_1,
326        params.db_dim_2,
327        params.instances,
328        params.db_item_size,
329    )
330}
331
/// Client-side PIR state: parameters, secret keys and the noise sampler.
pub struct Client<'a> {
    params: &'a Params,
    /// Secret key used for GSW-style encryption and for decoding responses.
    sk_gsw: PolyMatrixRaw<'a>,
    /// Secret key used for Regev encryption.
    sk_reg: PolyMatrixRaw<'a>,
    /// `sk_gsw` extended with an identity block (`[sk_gsw | I]`), used by `decrypt_matrix_gsw`.
    sk_gsw_full: PolyMatrixRaw<'a>,
    /// `sk_reg` extended with an identity block (`[sk_reg | I]`), used by `decrypt_matrix_reg`.
    sk_reg_full: PolyMatrixRaw<'a>,
    /// Discrete Gaussian sampler (width `params.noise_width`) for secret keys and noise.
    dg: DiscreteGaussian,
}
340
341impl<'a> Client<'a> {
342    pub fn init(params: &'a Params) -> Self {
343        let sk_gsw_dims = params.get_sk_gsw();
344        let sk_reg_dims = params.get_sk_reg();
345        let sk_gsw = PolyMatrixRaw::zero(params, sk_gsw_dims.0, sk_gsw_dims.1);
346        let sk_reg = PolyMatrixRaw::zero(params, sk_reg_dims.0, sk_reg_dims.1);
347        let sk_gsw_full = matrix_with_identity(&sk_gsw);
348        let sk_reg_full = matrix_with_identity(&sk_reg);
349
350        let dg = DiscreteGaussian::init(params.noise_width);
351
352        Self {
353            params,
354            sk_gsw,
355            sk_reg,
356            sk_gsw_full,
357            sk_reg_full,
358            dg,
359        }
360    }
361
362    #[allow(dead_code)]
363    pub(crate) fn get_sk_reg(&self) -> &PolyMatrixRaw<'a> {
364        &self.sk_reg
365    }
366
367    fn get_fresh_gsw_public_key(
368        &self,
369        m: usize,
370        rng: &mut ChaCha20Rng,
371        rng_pub: &mut ChaCha20Rng,
372    ) -> PolyMatrixRaw<'a> {
373        let params = self.params;
374        let n = params.n;
375
376        let a = PolyMatrixRaw::random_rng(params, 1, m, rng_pub);
377        let e = PolyMatrixRaw::noise(params, n, m, &self.dg, rng);
378        let a_inv = -&a;
379        let b_p = &self.sk_gsw.ntt() * &a.ntt();
380        let b = &e.ntt() + &b_p;
381        let p = stack(&a_inv, &b.raw());
382        p
383    }
384
385    fn get_regev_sample(
386        &self,
387        rng: &mut ChaCha20Rng,
388        rng_pub: &mut ChaCha20Rng,
389    ) -> PolyMatrixNTT<'a> {
390        let params = self.params;
391        let a = PolyMatrixRaw::random_rng(params, 1, 1, rng_pub);
392        let e = PolyMatrixRaw::noise(params, 1, 1, &self.dg, rng);
393        let b_p = &self.sk_reg.ntt() * &a.ntt();
394        let b = &e.ntt() + &b_p;
395        let mut p = PolyMatrixNTT::zero(params, 2, 1);
396        p.copy_into(&(-&a).ntt(), 0, 0);
397        p.copy_into(&b, 1, 0);
398        p
399    }
400
401    fn get_fresh_reg_public_key(
402        &self,
403        m: usize,
404        rng: &mut ChaCha20Rng,
405        rng_pub: &mut ChaCha20Rng,
406    ) -> PolyMatrixNTT<'a> {
407        let params = self.params;
408
409        let mut p = PolyMatrixNTT::zero(params, 2, m);
410
411        for i in 0..m {
412            p.copy_into(&self.get_regev_sample(rng, rng_pub), 0, i);
413        }
414        p
415    }
416
417    fn encrypt_matrix_gsw(
418        &self,
419        ag: &PolyMatrixNTT<'a>,
420        rng: &mut ChaCha20Rng,
421        rng_pub: &mut ChaCha20Rng,
422    ) -> PolyMatrixNTT<'a> {
423        let mx = ag.cols;
424        let p = self.get_fresh_gsw_public_key(mx, rng, rng_pub);
425        let res = &(p.ntt()) + &(ag.pad_top(1));
426        res
427    }
428
429    pub fn encrypt_matrix_reg(
430        &self,
431        a: &PolyMatrixNTT<'a>,
432        rng: &mut ChaCha20Rng,
433        rng_pub: &mut ChaCha20Rng,
434    ) -> PolyMatrixNTT<'a> {
435        let m = a.cols;
436        let p = self.get_fresh_reg_public_key(m, rng, rng_pub);
437        &p + &a.pad_top(1)
438    }
439
440    pub fn decrypt_matrix_reg(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
441        &self.sk_reg_full.ntt() * a
442    }
443
444    pub fn decrypt_matrix_gsw(&self, a: &PolyMatrixNTT<'a>) -> PolyMatrixNTT<'a> {
445        &self.sk_gsw_full.ntt() * a
446    }
447
448    fn generate_expansion_params(
449        &self,
450        num_exp: usize,
451        m_exp: usize,
452        rng: &mut ChaCha20Rng,
453        rng_pub: &mut ChaCha20Rng,
454    ) -> Vec<PolyMatrixNTT<'a>> {
455        let params = self.params;
456        let g_exp = build_gadget(params, 1, m_exp);
457        let g_exp_ntt = g_exp.ntt();
458        let mut res = Vec::new();
459
460        for i in 0..num_exp {
461            let t = (params.poly_len / (1 << i)) + 1;
462            let tau_sk_reg = automorph_alloc(&self.sk_reg, t);
463            let prod = &tau_sk_reg.ntt() * &g_exp_ntt;
464            let w_exp_i = self.encrypt_matrix_reg(&prod, rng, rng_pub);
465            res.push(w_exp_i);
466        }
467        res
468    }
469
470    pub fn generate_keys_from_seed(&mut self, seed: Seed) -> PublicParameters<'a> {
471        self.generate_keys_impl(&mut ChaCha20Rng::from_seed(seed))
472    }
473
474    pub fn generate_keys(&mut self) -> PublicParameters<'a> {
475        self.generate_keys_impl(&mut ChaCha20Rng::from_entropy())
476    }
477
478    pub fn generate_secret_keys_from_seed(&mut self, seed: Seed) {
479        self.generate_secret_keys_impl(&mut ChaCha20Rng::from_seed(seed))
480    }
481
482    pub fn generate_secret_keys(&mut self) {
483        self.generate_secret_keys_impl(&mut ChaCha20Rng::from_entropy())
484    }
485
486    pub fn generate_keys_optional(
487        &mut self,
488        seed: Seed,
489        generate_pub_params: bool,
490    ) -> Option<Vec<u8>> {
491        if generate_pub_params {
492            Some(self.generate_keys_from_seed(seed).serialize())
493        } else {
494            self.generate_secret_keys_from_seed(seed);
495            None
496        }
497    }
498
499    fn generate_secret_keys_impl(&mut self, rng: &mut ChaCha20Rng) {
500        self.dg.sample_matrix(&mut self.sk_gsw, rng);
501        self.dg.sample_matrix(&mut self.sk_reg, rng);
502        self.sk_gsw_full = matrix_with_identity(&self.sk_gsw);
503        self.sk_reg_full = matrix_with_identity(&self.sk_reg);
504    }
505
506    fn generate_keys_impl(&mut self, rng: &mut ChaCha20Rng) -> PublicParameters<'a> {
507        let params = self.params;
508
509        self.generate_secret_keys_impl(rng);
510        let sk_reg_ntt = to_ntt_alloc(&self.sk_reg);
511
512        let mut rng = ChaCha20Rng::from_entropy();
513        let mut pp = PublicParameters::init(params);
514        let pp_seed = rng.gen();
515        pp.seed = Some(pp_seed);
516        let mut rng_pub = ChaCha20Rng::from_seed(pp_seed);
517
518        // Params for packing
519        let gadget_conv = build_gadget(params, 1, params.t_conv);
520        let gadget_conv_ntt = to_ntt_alloc(&gadget_conv);
521        for i in 0..params.n {
522            let scaled = scalar_multiply_alloc(&sk_reg_ntt, &gadget_conv_ntt);
523            let mut ag = PolyMatrixNTT::zero(params, params.n, params.t_conv);
524            ag.copy_into(&scaled, i, 0);
525            let w = self.encrypt_matrix_gsw(&ag, &mut rng, &mut rng_pub);
526            pp.v_packing.push(w);
527        }
528
529        if params.expand_queries {
530            // Params for expansion
531            pp.v_expansion_left = Some(self.generate_expansion_params(
532                params.g(),
533                params.t_exp_left,
534                &mut rng,
535                &mut rng_pub,
536            ));
537            pp.v_expansion_right = Some(self.generate_expansion_params(
538                params.stop_round() + 1,
539                params.t_exp_right,
540                &mut rng,
541                &mut rng_pub,
542            ));
543
544            // Params for converison
545            let g_conv = build_gadget(params, 2, 2 * params.t_conv);
546            let sk_reg_ntt = self.sk_reg.ntt();
547            let sk_reg_squared_ntt = &sk_reg_ntt * &sk_reg_ntt;
548            pp.v_conversion = Some(Vec::from_iter(once(PolyMatrixNTT::zero(
549                params,
550                2,
551                2 * params.t_conv,
552            ))));
553            for i in 0..2 * params.t_conv {
554                let sigma;
555                if i % 2 == 0 {
556                    let val = g_conv.get_poly(0, i)[0];
557                    sigma = &sk_reg_squared_ntt * &single_poly(params, val).ntt();
558                } else {
559                    let val = g_conv.get_poly(1, i)[0];
560                    sigma = &sk_reg_ntt * &single_poly(params, val).ntt();
561                }
562                let ct = self.encrypt_matrix_reg(&sigma, &mut rng, &mut rng_pub);
563                pp.v_conversion.as_mut().unwrap()[0].copy_into(&ct, 0, i);
564            }
565        }
566
567        pp
568    }
569
570    pub fn generate_query(&self, idx_target: usize) -> Query<'a> {
571        let params = self.params;
572        let further_dims = params.db_dim_2;
573        let idx_dim0 = idx_target / (1 << further_dims);
574        let idx_further = idx_target % (1 << further_dims);
575        let scale_k = params.modulus / params.pt_modulus;
576        let bits_per = get_bits_per(params, params.t_gsw);
577
578        let mut rng = ChaCha20Rng::from_entropy();
579
580        let mut query = Query::empty();
581        let query_seed = ChaCha20Rng::from_entropy().gen();
582        query.seed = Some(query_seed);
583        let mut rng_pub = ChaCha20Rng::from_seed(query_seed);
584        if params.expand_queries {
585            // pack query into single ciphertext
586            let mut sigma = PolyMatrixRaw::zero(params, 1, 1);
587            let inv_2_g_first = invert_uint_mod(1 << params.g(), params.modulus).unwrap();
588            let inv_2_g_rest =
589                invert_uint_mod(1 << (params.stop_round() + 1), params.modulus).unwrap();
590
591            if params.db_dim_2 == 0 {
592                for i in 0..(1 << params.db_dim_1) {
593                    sigma.data[i].conditional_assign(&scale_k, (i as u64).ct_eq(&(idx_dim0 as u64)))
594                }
595
596                for i in 0..params.poly_len {
597                    sigma.data[i] = multiply_uint_mod(sigma.data[i], inv_2_g_first, params.modulus);
598                }
599            } else {
600                for i in 0..(1 << params.db_dim_1) {
601                    sigma.data[2 * i]
602                        .conditional_assign(&scale_k, (i as u64).ct_eq(&(idx_dim0 as u64)))
603                }
604
605                for i in 0..further_dims as u64 {
606                    let mask = 1 << i;
607                    let bit = ((idx_further as u64) & mask).ct_eq(&mask);
608                    for j in 0..params.t_gsw {
609                        let val = u64::conditional_select(&0, &(1u64 << (bits_per * j)), bit);
610                        let idx = (i as usize) * params.t_gsw + (j as usize);
611                        sigma.data[2 * idx + 1] = val;
612                    }
613                }
614
615                for i in 0..params.poly_len / 2 {
616                    sigma.data[2 * i] =
617                        multiply_uint_mod(sigma.data[2 * i], inv_2_g_first, params.modulus);
618                    sigma.data[2 * i + 1] =
619                        multiply_uint_mod(sigma.data[2 * i + 1], inv_2_g_rest, params.modulus);
620                }
621            }
622
623            query.ct = Some(from_ntt_alloc(&self.encrypt_matrix_reg(
624                &to_ntt_alloc(&sigma),
625                &mut rng,
626                &mut rng_pub,
627            )));
628        } else {
629            let num_expanded = 1 << params.db_dim_1;
630            let mut sigma_v = Vec::<PolyMatrixNTT>::new();
631
632            // generate regev ciphertexts
633            let reg_cts_buf_words = num_expanded * 2 * params.poly_len;
634            let mut reg_cts_buf = vec![0u64; reg_cts_buf_words];
635            let mut reg_cts = Vec::<PolyMatrixNTT>::new();
636            for i in 0..num_expanded {
637                let value = ((i == idx_dim0) as u64) * scale_k;
638                let sigma = PolyMatrixRaw::single_value(&params, value);
639                reg_cts.push(self.encrypt_matrix_reg(
640                    &to_ntt_alloc(&sigma),
641                    &mut rng,
642                    &mut rng_pub,
643                ));
644            }
645            // reorient into server's preferred indexing
646            reorient_reg_ciphertexts(self.params, reg_cts_buf.as_mut_slice(), &reg_cts);
647
648            // generate GSW ciphertexts
649            for i in 0..further_dims {
650                let bit = ((idx_further as u64) & (1 << (i as u64))) >> (i as u64);
651                let mut ct_gsw = PolyMatrixNTT::zero(&params, 2, 2 * params.t_gsw);
652
653                for j in 0..params.t_gsw {
654                    let value = (1u64 << (bits_per * j)) * bit;
655                    let sigma = PolyMatrixRaw::single_value(&params, value);
656                    let sigma_ntt = to_ntt_alloc(&sigma);
657
658                    // important to rng in the right order here
659                    let prod = &to_ntt_alloc(&self.sk_reg) * &sigma_ntt;
660                    let ct = &self.encrypt_matrix_reg(&prod, &mut rng, &mut rng_pub);
661                    ct_gsw.copy_into(ct, 0, 2 * j);
662
663                    let ct = &self.encrypt_matrix_reg(&sigma_ntt, &mut rng, &mut rng_pub);
664                    ct_gsw.copy_into(ct, 0, 2 * j + 1);
665                }
666                sigma_v.push(ct_gsw);
667            }
668
669            query.v_buf = Some(reg_cts_buf);
670            query.v_ct = Some(sigma_v.iter().map(|x| from_ntt_alloc(x)).collect());
671        }
672        query
673    }
674
675    pub fn generate_full_query(&self, id: &str, idx_target: usize) -> Vec<u8> {
676        assert_eq!(id.len(), UUID_V4_LEN);
677        let query = self.generate_query(idx_target);
678        let mut query_buf = query.serialize();
679        let mut full_query_buf = id.as_bytes().to_vec();
680        full_query_buf.append(&mut query_buf);
681        full_query_buf
682    }
683
684    pub fn decode_response(&self, data: &[u8]) -> Vec<u8> {
685        /*
686            0. NTT over q2 the secret key
687
688            1. read first row in q2_bit chunks
689            2. read rest in q1_bit chunks
690            3. NTT over q2 the first row
691            4. Multiply the results of (0) and (3)
692            5. Divide and round correctly
693        */
694        let params = self.params;
695        let p = params.pt_modulus;
696        let p_bits = log2_ceil(params.pt_modulus);
697        let q1 = 4 * params.pt_modulus;
698        let q1_bits = log2_ceil(q1) as usize;
699        let q2 = Q2_VALUES[params.q2_bits as usize];
700        let q2_bits = params.q2_bits as usize;
701
702        let q2_params = params_with_moduli(params, &vec![q2]);
703
704        // this only needs to be done during keygen
705        let mut sk_gsw_q2 = PolyMatrixRaw::zero(&q2_params, params.n, 1);
706        for i in 0..params.poly_len * params.n {
707            sk_gsw_q2.data[i] = recenter(self.sk_gsw.data[i], params.modulus, q2);
708        }
709        let mut sk_gsw_q2_ntt = PolyMatrixNTT::zero(&q2_params, params.n, 1);
710        to_ntt(&mut sk_gsw_q2_ntt, &sk_gsw_q2);
711
712        let mut result = PolyMatrixRaw::zero(&params, params.instances * params.n, params.n);
713
714        let mut bit_offs = 0;
715        for instance in 0..params.instances {
716            // this must be done during decoding
717            let mut first_row = PolyMatrixRaw::zero(&q2_params, 1, params.n);
718            let mut rest_rows = PolyMatrixRaw::zero(&params, params.n, params.n);
719            for i in 0..params.n * params.poly_len {
720                first_row.data[i] = read_arbitrary_bits(data, bit_offs, q2_bits);
721                bit_offs += q2_bits;
722            }
723            for i in 0..params.n * params.n * params.poly_len {
724                rest_rows.data[i] = read_arbitrary_bits(data, bit_offs, q1_bits);
725                bit_offs += q1_bits;
726            }
727
728            let mut first_row_q2 = PolyMatrixNTT::zero(&q2_params, 1, params.n);
729            to_ntt(&mut first_row_q2, &first_row);
730
731            let sk_prod = (&sk_gsw_q2_ntt * &first_row_q2).raw();
732
733            let q1_i64 = q1 as i64;
734            let q2_i64 = q2 as i64;
735            let p_i128 = p as i128;
736            for i in 0..params.n * params.n * params.poly_len {
737                let mut val_first = sk_prod.data[i] as i64;
738                if val_first >= q2_i64 / 2 {
739                    val_first -= q2_i64;
740                }
741                let mut val_rest = rest_rows.data[i] as i64;
742                if val_rest >= q1_i64 / 2 {
743                    val_rest -= q1_i64;
744                }
745
746                let denom = (q2 * (q1 / p)) as i64;
747
748                let mut r = val_first * q1_i64;
749                r += val_rest * q2_i64;
750
751                // divide r by q2, rounding
752                let sign: i64 = if r >= 0 { 1 } else { -1 };
753                let mut res = ((r + sign * (denom / 2)) as i128) / (denom as i128);
754                res = (res + (denom as i128 / p_i128) * (p_i128) + 2 * (p_i128)) % (p_i128);
755                let idx = instance * params.n * params.n * params.poly_len + i;
756                result.data[idx] = res as u64;
757            }
758        }
759
760        // println!("{:?}", result.data.as_slice().to_vec());
761        result.to_vec(p_bits as usize, params.modp_words_per_chunk())
762    }
763}
764
#[cfg(test)]
mod test {
    use super::*;

    /// Parameter set shared by most tests (`get_short_keygen_params`).
    fn get_params() -> Params {
        get_short_keygen_params()
    }

    #[test]
    fn init_is_correct() {
        let params = get_params();
        let client = Client::init(&params);
        assert_eq!(*client.params, params);
    }

    #[test]
    fn keygen_is_correct() {
        let params = get_params();
        let mut client = Client::init(&params);
        _ = client.generate_keys();

        // Every GSW secret-key coefficient must be within the threshold of zero, either
        // directly or as a negative value represented as `modulus - x`.
        let threshold = (10.0 * params.noise_width) as u64;
        let coeff_count = client.sk_gsw.data.len();
        for i in 0..coeff_count {
            let val = client.sk_gsw.data[i];
            assert!((val < threshold) || ((params.modulus - val) < threshold));
        }
    }

    /// Flattens the coefficient slices of all matrices in `v` into one vector.
    fn get_vec(v: &Vec<PolyMatrixNTT>) -> Vec<u64> {
        v.iter().flat_map(|d| d.as_slice().to_vec()).collect()
    }

    /// Round-trips freshly generated public parameters through serialize/deserialize and
    /// checks that both the bytes and every matrix component survive unchanged.
    fn public_parameters_serialization_is_correct_for_params(params: Params) {
        let mut client = Client::init(&params);
        let pub_params = client.generate_keys();

        let serialized1 = pub_params.serialize();
        let deserialized1 = PublicParameters::deserialize(&params, &serialized1);
        let serialized2 = deserialized1.serialize();
        assert_eq!(serialized1, serialized2);

        let packing = get_vec(&pub_params.v_packing);
        assert_eq!(packing, get_vec(&deserialized1.v_packing));

        println!("packing mats (bytes) {}", packing.len() * 8);
        println!("total size   (bytes) {}", serialized1.len());
        if let Some(orig) = &pub_params.v_conversion {
            let l1 = get_vec(orig);
            assert_eq!(l1, get_vec(deserialized1.v_conversion.as_ref().unwrap()));
            println!("conv mats (bytes) {}", l1.len() * 8);
        }
        if let Some(orig) = &pub_params.v_expansion_left {
            let l1 = get_vec(orig);
            assert_eq!(l1, get_vec(deserialized1.v_expansion_left.as_ref().unwrap()));
            println!("exp left (bytes) {}", l1.len() * 8);
        }
        if let Some(orig) = &pub_params.v_expansion_right {
            let l1 = get_vec(orig);
            assert_eq!(l1, get_vec(deserialized1.v_expansion_right.as_ref().unwrap()));
            println!("exp right (bytes) {}", l1.len() * 8);
        }
    }

    #[test]
    fn public_parameters_serialization_is_correct() {
        public_parameters_serialization_is_correct_for_params(get_params())
    }

    #[test]
    fn real_public_parameters_serialization_is_correct() {
        let cfg_expand = r#"
            {'n': 2,
            'nu_1': 10,
            'nu_2': 6,
            'p': 512,
            'q2_bits': 21,
            's_e': 85.83255142749422,
            't_gsw': 10,
            't_conv': 4,
            't_exp_left': 16,
            't_exp_right': 56,
            'instances': 11,
            'db_item_size': 100000 }
        "#;
        // The literal uses single quotes for readability; JSON needs double quotes.
        let cfg = cfg_expand.replace("'", "\"");
        public_parameters_serialization_is_correct_for_params(params_from_json(&cfg))
    }

    #[test]
    fn real_public_parameters_2_serialization_is_correct() {
        let cfg = r#"
            { "n": 4,
            "nu_1": 9,
            "nu_2": 5,
            "p": 256,
            "q2_bits": 20,
            "t_gsw": 8,
            "t_conv": 4,
            "t_exp_left": 8,
            "t_exp_right": 56,
            "instances": 2,
            "db_item_size": 65536 }
        "#;
        public_parameters_serialization_is_correct_for_params(params_from_json(&cfg))
    }

    #[test]
    fn no_expansion_public_parameters_serialization_is_correct() {
        public_parameters_serialization_is_correct_for_params(get_no_expansion_testing_params())
    }

    /// Round-trips a query for row 1 and checks the re-serialized bytes match exactly.
    fn query_serialization_is_correct_for_params(params: Params) {
        let mut client = Client::init(&params);
        _ = client.generate_keys();
        let query = client.generate_query(1);

        let serialized1 = query.serialize();
        let serialized2 = Query::deserialize(&params, &serialized1).serialize();

        assert_eq!(serialized1.len(), serialized2.len());
        for (i, (a, b)) in serialized1.iter().zip(serialized2.iter()).enumerate() {
            assert_eq!(a, b, "at {}", i);
        }
    }

    #[test]
    fn query_serialization_is_correct() {
        query_serialization_is_correct_for_params(get_params())
    }

    #[test]
    fn no_expansion_query_serialization_is_correct() {
        query_serialization_is_correct_for_params(get_no_expansion_testing_params())
    }
}