tfhe/core_crypto/entities/
packed_integers.rs

1use tfhe_versionable::Versionize;
2
3use crate::conformance::ParameterSetConformant;
4use crate::core_crypto::backward_compatibility::entities::packed_integers::PackedIntegersVersions;
5use crate::core_crypto::prelude::*;
6
7#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
8#[versionize(PackedIntegersVersions)]
9pub struct PackedIntegers<Scalar: UnsignedInteger> {
10    packed_coeffs: Vec<Scalar>,
11    log_modulus: CiphertextModulusLog,
12    initial_len: usize,
13}
14
15impl<Scalar: UnsignedInteger> PackedIntegers<Scalar> {
16    pub(crate) fn from_raw_parts(
17        packed_coeffs: Vec<Scalar>,
18        log_modulus: CiphertextModulusLog,
19        initial_len: usize,
20    ) -> Self {
21        let required_bits_packed = initial_len * log_modulus.0;
22        let expected_len = required_bits_packed.div_ceil(Scalar::BITS);
23
24        assert_eq!(
25            packed_coeffs.len(),
26            expected_len,
27            "Invalid size for the packed coeffs, got {}, expected {}",
28            packed_coeffs.len(),
29            expected_len
30        );
31
32        Self {
33            packed_coeffs,
34            log_modulus,
35            initial_len,
36        }
37    }
38
39    pub fn pack(slice: &[Scalar], log_modulus: CiphertextModulusLog) -> Self {
40        let log_modulus = log_modulus.0;
41
42        let in_len = slice.len();
43
44        let number_bits_to_pack = in_len * log_modulus;
45
46        let len = number_bits_to_pack.div_ceil(Scalar::BITS);
47
48        // Lowest bits are on the right
49        //
50        // Target mapping:
51        //                          log_modulus
52        //                           |-------|
53        //
54        // slice        :    |  k+2  |  k+1  |   k   |
55        // packed_coeffs:  i+1   |       i       |     i-1
56        //
57        //                       |---------------|
58        //                         Scalar::BITS
59        //
60        //                                       |---|
61        //                                    start_shift
62        //
63        //                                   |---|
64        //                                   shift1
65        //                             (1st loop iteration)
66        //
67        //                           |-----------|
68        //                               shift2
69        //                        (2nd loop iteration)
70        //
71        // packed_coeffs[i] =
72        //                    slice[k] >> start_shift
73        //                  | slice[k+1] << shift1
74        //                  | slice[k+2] << shift2
75        //
76        // In the lowest bits of packed_coeffs[i], we want the highest bits of slice[k],
77        // hence the right shift
78        // The next bits should be the bits of slice[k+1] which we must left shifted to avoid
79        // overlapping
80        // This goes on
81        let packed_coeffs = (0..len)
82            .map(|i| {
83                let k = Scalar::BITS * i / log_modulus;
84                let mut j = k;
85
86                let start_shift = i * Scalar::BITS - j * log_modulus;
87
88                debug_assert_eq!(slice[j] >> log_modulus, Scalar::ZERO);
89
90                let mut value = slice[j] >> start_shift;
91                j += 1;
92
93                while j * log_modulus < ((i + 1) * Scalar::BITS) && j < slice.len() {
94                    let shift = j * log_modulus - i * Scalar::BITS;
95
96                    debug_assert_eq!(slice[j] >> log_modulus, Scalar::ZERO);
97
98                    value |= slice[j] << shift;
99
100                    j += 1;
101                }
102                value
103            })
104            .collect();
105
106        let log_modulus = CiphertextModulusLog(log_modulus);
107
108        Self {
109            packed_coeffs,
110            log_modulus,
111            initial_len: slice.len(),
112        }
113    }
114
115    pub fn unpack(&self) -> impl Iterator<Item = Scalar> + '_ {
116        let log_modulus = self.log_modulus.0;
117
118        // log_modulus lowest bits set to 1
119        let mask = (Scalar::ONE << log_modulus) - Scalar::ONE;
120
121        (0..self.initial_len).map(move |i| {
122            let start = i * log_modulus;
123            let end = (i + 1) * log_modulus;
124
125            let start_block = start / Scalar::BITS;
126            let start_remainder = start % Scalar::BITS;
127
128            let end_block_inclusive = (end - 1) / Scalar::BITS;
129
130            if start_block == end_block_inclusive {
131                // Lowest bits are on the right
132                //
133                // Target mapping:
134                //                                   Scalar::BITS
135                //                                |---------------|
136                //
137                // packed_coeffs: | start_block+1 |  start_block  |
138                // container    :             |  i+1  |   i   |  i-1  |
139                //
140                //                                    |-------|
141                //                                   log_modulus
142                //
143                //                                            |---|
144                //                                       start_remainder
145                //
146                // In container[i] we want the bits of packed_coeffs[start_block] starting from
147                // index start_remainder
148                //
149                // container[i] = lowest_bits of single_part
150                //
151                let single_part = self.packed_coeffs[start_block] >> start_remainder;
152
153                single_part & mask
154            } else {
155                // Lowest bits are on the right
156                //
157                // Target mapping:
158                //                                   Scalar::BITS
159                //                                 |---------------|
160                //
161                // packed_coeffs:  | start_block+1 |  start_block  |
162                // container    :      |  i+1  |   i   |  i-1  |
163                //
164                //                             |-------|
165                //                            log_modulus
166                //
167                //                                     |-----------|
168                //                                    start_remainder
169                //
170                //                                 |---|
171                //                     Scalar::BITS - start_remainder
172                //
173                // In the lowest bits of container[i] we want the highest bits of
174                // packed_coeffs[start_block] starting from index start_remainder
175                //
176                // In the next bits, we want the lowest bits of packed_coeffs[start_block + 1]
177                // left shifted to avoid overlapping
178                //
179                // container[i] = lowest_bits of (first_part|second_part)
180                //
181                assert_eq!(end_block_inclusive, start_block + 1);
182
183                let first_part = self.packed_coeffs[start_block] >> start_remainder;
184
185                let second_part =
186                    self.packed_coeffs[start_block + 1] << (Scalar::BITS - start_remainder);
187
188                (first_part | second_part) & mask
189            }
190        })
191    }
192
193    pub fn log_modulus(&self) -> CiphertextModulusLog {
194        self.log_modulus
195    }
196
197    pub fn packed_coeffs(&self) -> &[Scalar] {
198        &self.packed_coeffs
199    }
200
201    pub fn initial_len(&self) -> usize {
202        self.initial_len
203    }
204}
205
206impl<Scalar: UnsignedInteger> ParameterSetConformant for PackedIntegers<Scalar> {
207    type ParameterSet = usize;
208
209    fn is_conformant(&self, len: &usize) -> bool {
210        let Self {
211            packed_coeffs,
212            log_modulus,
213            initial_len,
214        } = self;
215
216        let number_packed_bits = *len * log_modulus.0;
217
218        let packed_len = number_packed_bits.div_ceil(Scalar::BITS);
219
220        *len == *initial_len && packed_coeffs.len() == packed_len
221    }
222}