tfhe/core_crypto/entities/packed_integers.rs
use tfhe_versionable::Versionize;

use crate::conformance::ParameterSetConformant;
use crate::core_crypto::backward_compatibility::entities::packed_integers::PackedIntegersVersions;
use crate::core_crypto::prelude::*;

#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(PackedIntegersVersions)]
pub struct PackedIntegers<Scalar: UnsignedInteger> {
    packed_coeffs: Vec<Scalar>,
    log_modulus: CiphertextModulusLog,
    initial_len: usize,
}

impl<Scalar: UnsignedInteger> PackedIntegers<Scalar> {
    pub(crate) fn from_raw_parts(
        packed_coeffs: Vec<Scalar>,
        log_modulus: CiphertextModulusLog,
        initial_len: usize,
    ) -> Self {
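        // The packed length must be exactly the number of Scalar words needed to
        // hold initial_len values of log_modulus bits each; e.g. (illustrative)
        // 100 values of 12 bits each need 1200_usize.div_ceil(64) = 19 u64 words.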
        let required_bits_packed = initial_len * log_modulus.0;
        let expected_len = required_bits_packed.div_ceil(Scalar::BITS);

        assert_eq!(
            packed_coeffs.len(),
            expected_len,
            "Invalid size for the packed coeffs, got {}, expected {}",
            packed_coeffs.len(),
            expected_len
        );

        Self {
            packed_coeffs,
            log_modulus,
            initial_len,
        }
    }

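    /// Packs each element of `slice` on `log_modulus` bits, filling `Scalar`
    /// words from their least significant bits up. Every input value is
    /// expected to already fit in `log_modulus` bits.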
    pub fn pack(slice: &[Scalar], log_modulus: CiphertextModulusLog) -> Self {
        let log_modulus = log_modulus.0;

        let in_len = slice.len();

        let number_bits_to_pack = in_len * log_modulus;

        let len = number_bits_to_pack.div_ceil(Scalar::BITS);

        // Lowest bits are on the right
        //
        // Target mapping:
        //                                  log_modulus
        //                                   |-------|
        //
        // slice         :   |  k+2  |  k+1  |   k   |
        // packed_coeffs : i+1 |        i        | i-1
        //
        //                     |-----------------|
        //                        Scalar::BITS
        //
        //                                       |---|
        //                                    start_shift
        //
        //                                   |---|
        //                                  shift1
        //                           (1st loop iteration)
        //
        //                           |-----------|
        //                              shift2
        //                       (2nd loop iteration)
        //
        // packed_coeffs[i] = slice[k]   >> start_shift
        //                  | slice[k+1] << shift1
        //                  | slice[k+2] << shift2
        //
        // In the lowest bits of packed_coeffs[i], we want the highest bits of slice[k],
        // hence the right shift
        // The next bits should be the bits of slice[k+1], which we must left shift to
        // avoid overlapping
        // This goes on
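        //
        // Worked example (illustrative values): Scalar::BITS = 64, log_modulus = 12,
        // i = 1 gives k = 64 / 12 = 5 and start_shift = 64 - 5 * 12 = 4, so
        // packed_coeffs[1] = slice[5] >> 4        (the 8 highest bits of slice[5])
        //                  | slice[6] << 8  | slice[7]  << 20 | slice[8] << 32
        //                  | slice[9] << 44 | slice[10] << 56 (the 8 lowest bits of slice[10])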
        let packed_coeffs = (0..len)
            .map(|i| {
                let k = Scalar::BITS * i / log_modulus;
                let mut j = k;

                let start_shift = i * Scalar::BITS - j * log_modulus;

                debug_assert_eq!(slice[j] >> log_modulus, Scalar::ZERO);

                let mut value = slice[j] >> start_shift;
                j += 1;

                while j * log_modulus < ((i + 1) * Scalar::BITS) && j < slice.len() {
                    let shift = j * log_modulus - i * Scalar::BITS;

                    debug_assert_eq!(slice[j] >> log_modulus, Scalar::ZERO);

                    value |= slice[j] << shift;

                    j += 1;
                }
                value
            })
            .collect();

        let log_modulus = CiphertextModulusLog(log_modulus);

        Self {
            packed_coeffs,
            log_modulus,
            initial_len: slice.len(),
        }
    }

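    /// Unpacks the stored coefficients, yielding the `initial_len` original
    /// values of `log_modulus` bits each, in packing order.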
    pub fn unpack(&self) -> impl Iterator<Item = Scalar> + '_ {
        let log_modulus = self.log_modulus.0;

        // log_modulus lowest bits set to 1
        let mask = (Scalar::ONE << log_modulus) - Scalar::ONE;
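        // e.g. (illustrative) log_modulus = 12 on u64 words gives mask = 0xFFF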

        (0..self.initial_len).map(move |i| {
            let start = i * log_modulus;
            let end = (i + 1) * log_modulus;

            let start_block = start / Scalar::BITS;
            let start_remainder = start % Scalar::BITS;

            let end_block_inclusive = (end - 1) / Scalar::BITS;

            if start_block == end_block_inclusive {
                // Lowest bits are on the right
                //
                // Target mapping:
                //                                    Scalar::BITS
                //                                 |---------------|
                //
                // packed_coeffs : | start_block+1 |  start_block  |
                // container     :                 | i+1 |  i  | i-1 |
                //
                //                                       |-----|
                //                                     log_modulus
                //
                //                                             |---|
                //                                        start_remainder
                //
                // In container[i] we want the bits of packed_coeffs[start_block] starting from
                // index start_remainder
                //
                // container[i] = lowest_bits of single_part
                //
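                // e.g. (illustrative) Scalar::BITS = 64, log_modulus = 12, i = 2:
                // start = 24 and end = 36 both fall in packed_coeffs[0], so
                // container[2] = (packed_coeffs[0] >> 24) & 0xFFF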
                let single_part = self.packed_coeffs[start_block] >> start_remainder;

                single_part & mask
            } else {
                // Lowest bits are on the right
                //
                // Target mapping:
                //                                    Scalar::BITS
                //                                 |---------------|
                //
                // packed_coeffs : | start_block+1 |  start_block  |
                // container     :         | i+1 |  i  | i-1 |
                //
                //                               |-----|
                //                             log_modulus
                //
                //                                     |-----------|
                //                                    start_remainder
                //
                //                                 |---|
                //                    Scalar::BITS - start_remainder
                //
                // In the lowest bits of container[i] we want the highest bits of
                // packed_coeffs[start_block] starting from index start_remainder
                //
                // In the next bits, we want the lowest bits of packed_coeffs[start_block + 1]
                // left shifted to avoid overlapping
                //
                // container[i] = lowest_bits of (first_part|second_part)
                //
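                // e.g. (illustrative) Scalar::BITS = 64, log_modulus = 12, i = 5:
                // start = 60 and end = 72, so container[5] combines the 4 highest
                // bits of packed_coeffs[0] with the 8 lowest bits of packed_coeffs[1]:
                // container[5] = (packed_coeffs[0] >> 60 | packed_coeffs[1] << 4) & 0xFFF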
                assert_eq!(end_block_inclusive, start_block + 1);

                let first_part = self.packed_coeffs[start_block] >> start_remainder;

                let second_part =
                    self.packed_coeffs[start_block + 1] << (Scalar::BITS - start_remainder);

                (first_part | second_part) & mask
            }
        })
    }

    pub fn log_modulus(&self) -> CiphertextModulusLog {
        self.log_modulus
    }

    pub fn packed_coeffs(&self) -> &[Scalar] {
        &self.packed_coeffs
    }

    pub fn initial_len(&self) -> usize {
        self.initial_len
    }
}

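// Conformance is a size-only check: `len` is the expected number of packed
// integers, from which the expected packed word count is rederived.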
impl<Scalar: UnsignedInteger> ParameterSetConformant for PackedIntegers<Scalar> {
    type ParameterSet = usize;

    fn is_conformant(&self, len: &usize) -> bool {
        let Self {
            packed_coeffs,
            log_modulus,
            initial_len,
        } = self;

        let number_packed_bits = *len * log_modulus.0;

        let packed_len = number_packed_bits.div_ceil(Scalar::BITS);

        *len == *initial_len && packed_coeffs.len() == packed_len
    }
}
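
// A minimal round-trip sketch: values that fit in log_modulus bits should
// survive pack followed by unpack unchanged.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pack_unpack_roundtrip() {
        let log_modulus = CiphertextModulusLog(12);

        // Hypothetical input: 12-bit values stored in u64 words
        let values: Vec<u64> = (0..100u64).map(|v| (v * 37) & 0xFFF).collect();

        let packed = PackedIntegers::pack(&values, log_modulus);

        // 100 values of 12 bits need 1200_usize.div_ceil(64) = 19 u64 words
        assert_eq!(packed.packed_coeffs().len(), 19);

        let unpacked: Vec<u64> = packed.unpack().collect();
        assert_eq!(unpacked, values);
    }
}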