zhc_builder/iops/count.rs
1use zhc_crypto::integer_semantics::CiphertextSpec;
2use zhc_langs::ioplang::{Lut1Def, Lut2Def};
3use zhc_utils::{
4 iter::{ChunkIt, CollectInVec, UnwrapChunks},
5 n_bits_to_encode,
6};
7
8use crate::{
9 BitType, CiphertextBlock, NU,
10 builder::{Builder, Ciphertext},
11};
12
13/// Creates an IR that counts the number of **zero** bits in an encrypted integer.
14///
15/// The returned [`Builder`] declares one ciphertext input of `spec.int_size()`
16/// bits and one ciphertext output whose width is `⌈log₂(int_size + 1)⌉` bits
17/// — just enough to represent every possible count from 0 to `int_size`.
18/// Internally delegates to [`Builder::iop_count`] with [`BitType::Zero`].
19///
20/// The `spec` parameter describes the integer encoding (bit-width, message
21/// bits, carry bits) and determines the number of blocks in the
22/// decomposition.
23///
24/// # Examples
25///
26/// ```rust,no_run
27/// # use zhc_builder::{CiphertextSpec, count_0};
28/// # let spec = CiphertextSpec::new(16, 2, 2);
29/// let builder = count_0(spec);
30/// let ir = builder.into_ir();
31/// ```
32pub fn count_0(spec: CiphertextSpec) -> Builder {
33 let mut builder = Builder::new(spec.block_spec());
34 let src_a = builder.ciphertext_input(spec.int_size());
35 let res = builder.iop_count(&src_a, BitType::Zero);
36 builder.ciphertext_output(res);
37 builder
38}
39
40/// Creates an IR that counts the number of **one** bits in an encrypted integer.
41///
42/// The returned [`Builder`] declares one ciphertext input of `spec.int_size()`
43/// bits and one ciphertext output whose width is `⌈log₂(int_size + 1)⌉` bits
44/// — just enough to represent every possible count from 0 to `int_size`.
45/// Internally delegates to [`Builder::iop_count`] with [`BitType::One`].
46///
47/// The `spec` parameter describes the integer encoding (bit-width, message
48/// bits, carry bits) and determines the number of blocks in the
49/// decomposition.
50///
51/// # Examples
52///
53/// ```rust,no_run
54/// # use zhc_builder::{CiphertextSpec, count_1};
55/// # let spec = CiphertextSpec::new(16, 2, 2);
56/// let builder = count_1(spec);
57/// let ir = builder.into_ir();
58/// ```
59pub fn count_1(spec: CiphertextSpec) -> Builder {
60 let mut builder = Builder::new(spec.block_spec());
61 let src_a = builder.ciphertext_input(spec.int_size());
62 let res = builder.iop_count(&src_a, BitType::One);
63 builder.ciphertext_output(res);
64 builder
65}
66
67type Column = Vec<CiphertextBlock>;
68
69impl Builder {
70 /// Counts the number of zero or one bits in an encrypted integer.
71 ///
72 /// The operation splits `inp` into individual bits, then performs a
73 /// recursive column-based reduction that sums them with carry
74 /// propagation. When `kind` is [`BitType::Zero`], the first
75 /// reduction pass uses inverted look-up tables so that each bit
76 /// contributes 1 when it is zero. The returned [`Ciphertext`] has a
77 /// width of `⌈log₂(int_size + 1)⌉` bits, just enough to represent
78 /// every possible count from 0 to `int_size`.
79 ///
80 /// # Examples
81 ///
82 /// ```rust,no_run
83 /// # use zhc_builder::{CiphertextSpec, Builder, BitType};
84 /// # let spec = CiphertextSpec::new(16, 2, 2);
85 /// # let mut builder = Builder::new(spec.block_spec());
86 /// # let a = builder.ciphertext_input(spec.int_size());
87 /// let pop = builder.iop_count(&a, BitType::One);
88 /// ```
89 pub fn iop_count(&mut self, inp: &Ciphertext, kind: BitType) -> Ciphertext {
90 assert!(
91 inp.spec().int_size().is_multiple_of(2),
92 "Non-multiple-of-two integer size not supported."
93 );
94 self.with_comment("iop_count", || {
95 let blocks = self.ciphertext_split(inp);
96 let bits = self
97 .comment("extract bits")
98 .vector_lookup2(blocks, Lut2Def::ManyMsgSplit)
99 .into_iter()
100 .flat_map(|(l, r)| [l, r].into_iter())
101 .take(inp.spec().int_size() as usize)
102 .collect::<Vec<_>>();
103 let res = self.count_from_bits(bits, kind);
104 let output_size: u16 = n_bits_to_encode(inp.spec().int_size());
105 let n_blocks = output_size.div_ceil(self.spec().message_size() as u16) as usize;
106 self.comment("output")
107 .ciphertext_join(&res[..n_blocks], Some(output_size))
108 })
109 }
110
111 pub(crate) fn count_reduce_recursive(
112 &self,
113 inp: impl AsRef<[Column]>,
114 kind: BitType,
115 ) -> Vec<Column> {
116 // The workhorse recursive reduction implementation.
117 //
118 // Works on columns. The last column contains only bits and as such can be added for longer
119 // before pbs. At the first iteration, there is only a single column, with all bits
120 // inside. Then as the reduction goes, carries are propagated to the column above.
121 // The reduction is finished when there are only one message per column. The columns are
122 // then turned to a ciphertext.
123 let inp = inp.as_ref();
124
125 self.push_comment("count_reduce_recursive");
126
127 if inp.iter().all(|col| col.len() <= 1) {
128 // Reduction is finished, can return.
129 return inp.to_vec();
130 } else {
131 let op_nb = NU;
132 let op_nb_bool = 1 << ((op_nb as f64).log2().ceil() as usize);
133 let op_nb_single = op_nb_bool - 1;
134 let reduction_iteration = inp.len();
135
136 let mut output: Vec<Column> = vec![Column::new(); inp.len() + 1];
137 for (col_idx, col) in inp.iter().enumerate() {
138 if col.len() == 1 {
139 output[col_idx].push(col[0]);
140 } else if col_idx == inp.len() - 1 {
141 self.push_comment(format!("last Column {col_idx}"));
142 col.iter()
143 .cloned()
144 .chunk(op_nb_single)
145 .unwrap_chunks()
146 .map(|chunk| {
147 let sum = self.vector_add_reduce(&chunk);
148 if kind == BitType::Zero && reduction_iteration == 1 {
149 self.with_comment("zero and first reduction", || {
150 match chunk.len() {
151 1 => self.block_lookup2(sum, Lut2Def::ManyInv1CarryMsg),
152 2 => self.block_lookup2(sum, Lut2Def::ManyInv2CarryMsg),
153 3 => self.block_lookup2(sum, Lut2Def::ManyInv3CarryMsg),
154 4 => self.block_lookup2(sum, Lut2Def::ManyInv4CarryMsg),
155 5 => self.block_lookup2(sum, Lut2Def::ManyInv5CarryMsg),
156 6 => self.block_lookup2(sum, Lut2Def::ManyInv6CarryMsg),
157 7 => self.block_lookup2(sum, Lut2Def::ManyInv7CarryMsg),
158 _ => unreachable!(),
159 }
160 })
161 } else {
162 self.with_comment("common branch", || {
163 self.block_lookup2(sum, Lut2Def::ManyCarryMsg)
164 })
165 }
166 })
167 .for_each(|(msg, carry)| {
168 output[col_idx].push(msg);
169 output[col_idx + 1].push(carry);
170 });
171 self.pop_comment();
172 } else {
173 self.push_comment(format!("regular Column {col_idx}"));
174 col.iter()
175 .cloned()
176 .chunk(op_nb)
177 .unwrap_chunks()
178 .map(|chunk| {
179 let sum = self.vector_add_reduce(&chunk);
180 if chunk.len() <= 2 {
181 // We have enough room to use a 2lookup in this case
182 self.block_lookup2(sum, Lut2Def::ManyCarryMsg)
183 } else {
184 // We don't have enough room. We must do two pbses
185 (
186 self.block_lookup(sum, Lut1Def::MsgOnly),
187 self.block_lookup(sum, Lut1Def::CarryInMsg),
188 )
189 }
190 })
191 .for_each(|(msg, carry)| {
192 output[col_idx].push(msg);
193 output[col_idx + 1].push(carry);
194 });
195 self.pop_comment();
196 }
197 }
198 let output = self.comment("output").count_reduce_recursive(output, kind);
199 self.pop_comment();
200 output
201 }
202 }
203
204 pub(crate) fn count_from_bits(
205 &self,
206 bits: impl AsRef<[CiphertextBlock]>,
207 kind: BitType,
208 ) -> Vec<CiphertextBlock> {
209 // Count bits of the given type.
210 // The input is a set of blocks each encoding a single bit.
211 self.with_comment("count_from_bits", || {
212 let bits = bits.as_ref().to_vec();
213 let res: Vec<Column> = self.count_reduce_recursive(vec![bits], kind);
214 let output = res
215 .into_iter()
216 .filter(|col| !col.is_empty())
217 .map(|col| col[0])
218 .covec();
219 self.comment("output").vector_inspect(output)
220 })
221 }
222}
223
#[cfg(test)]
mod test {
    use super::*;
    use zhc_langs::ioplang::IopValue;
    use zhc_utils::assert_display_is;

    /// Snapshot test: pins the exact IR, including comment scopes, that
    /// `count_0` emits for an 18-bit integer with 2 message bits and 2 carry
    /// bits.
    #[test]
    fn test_count0() {
        let spec = CiphertextSpec::new(18, 2, 2);
        let ir = count_0(spec).into_ir();
        assert_display_is!(
            ir.format()
                .with_walker(zhc_ir::PrintWalker::Linear)
                .show_comments(true)
                .show_types(false),
            r#"
            | %0 = input_ciphertext<0, 18>();
            // iop_count | %1 = extract_ct_block<0>(%0);
            // iop_count | %2 = extract_ct_block<1>(%0);
            // iop_count | %3 = extract_ct_block<2>(%0);
            // iop_count | %4 = extract_ct_block<3>(%0);
            // iop_count | %5 = extract_ct_block<4>(%0);
            // iop_count | %6 = extract_ct_block<5>(%0);
            // iop_count | %7 = extract_ct_block<6>(%0);
            // iop_count | %8 = extract_ct_block<7>(%0);
            // iop_count | %9 = extract_ct_block<8>(%0);
            // iop_count / extract bits | %10, %11 = pbs2<ManyMsgSplit>(%1);
            // iop_count / extract bits | %12, %13 = pbs2<ManyMsgSplit>(%2);
            // iop_count / extract bits | %14, %15 = pbs2<ManyMsgSplit>(%3);
            // iop_count / extract bits | %16, %17 = pbs2<ManyMsgSplit>(%4);
            // iop_count / extract bits | %18, %19 = pbs2<ManyMsgSplit>(%5);
            // iop_count / extract bits | %20, %21 = pbs2<ManyMsgSplit>(%6);
            // iop_count / extract bits | %22, %23 = pbs2<ManyMsgSplit>(%7);
            // iop_count / extract bits | %24, %25 = pbs2<ManyMsgSplit>(%8);
            // iop_count / extract bits | %26, %27 = pbs2<ManyMsgSplit>(%9);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %28 = add_ct(%10, %11);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %29 = add_ct(%28, %12);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %30 = add_ct(%29, %13);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %31 = add_ct(%30, %14);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %32 = add_ct(%31, %15);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %33 = add_ct(%32, %16);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 / zero and first reduction | %34, %35 = pbs2<ManyInv7CarryMsg>(%33);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %36 = add_ct(%17, %18);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %37 = add_ct(%36, %19);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %38 = add_ct(%37, %20);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %39 = add_ct(%38, %21);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %40 = add_ct(%39, %22);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %41 = add_ct(%40, %23);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 / zero and first reduction | %42, %43 = pbs2<ManyInv7CarryMsg>(%41);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %44 = add_ct(%24, %25);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %45 = add_ct(%44, %26);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 | %46 = add_ct(%45, %27);
            // iop_count / count_from_bits / count_reduce_recursive / last Column 0 / zero and first reduction | %47, %48 = pbs2<ManyInv4CarryMsg>(%46);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / regular Column 0 | %49 = add_ct(%34, %42);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / regular Column 0 | %50 = add_ct(%49, %47);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / regular Column 0 | %51 = pbs<Protect, MsgOnly>(%50);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / regular Column 0 | %52 = pbs<Protect, CarryInMsg>(%50);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / last Column 1 | %53 = add_ct(%35, %43);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / last Column 1 | %54 = add_ct(%53, %48);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / last Column 1 / common branch | %55, %56 = pbs2<ManyCarryMsg>(%54);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / regular Column 1 | %57 = add_ct(%52, %55);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / regular Column 1 | %58, %59 = pbs2<ManyCarryMsg>(%57);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / regular Column 2 | %60 = add_ct(%59, %56);
            // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / regular Column 2 | %61, %62 = pbs2<ManyCarryMsg>(%60);
            // iop_count / output | %67 = decl_ct<5>();
            // iop_count / output | %68 = store_ct_block<0>(%51, %67);
            // iop_count / output | %69 = store_ct_block<1>(%58, %68);
            // iop_count / output | %70 = store_ct_block<2>(%61, %69);
            | output<0>(%70);
            "#
        );
    }

    /// Checks `count_0` against a plaintext reference on random inputs, for
    /// every even integer size from 2 to 126 (2 message bits, 2 carry bits).
    #[test]
    fn correctness_count0() {
        // Reference semantics: the number of zero bits among the low
        // `int_size` bits of the input.
        fn semantic(inp: &[IopValue]) -> Vec<IopValue> {
            let [IopValue::Ciphertext(inp)] = inp else {
                unreachable!()
            };
            // `count_zeros` also counts the `u128::BITS - int_size` unused high
            // bits, which are subtracted here. NOTE(review): assumes the storage
            // is a `u128` whose unused high bits are zero — confirm.
            let res = inp.as_storage().count_zeros() - (u128::BITS - inp.spec().int_size() as u32);
            // Output width matches `iop_count`: `n_bits_to_encode(int_size)` bits.
            let output_size: u16 = n_bits_to_encode(inp.spec().int_size());
            vec![IopValue::Ciphertext(
                inp.spec()
                    .block_spec()
                    .ciphertext_spec(output_size)
                    .from_int(res as u128),
            )]
        }

        for size in (2..128).step_by(2) {
            count_0(CiphertextSpec::new(size, 2, 2)).test_random(100, semantic);
        }
    }

    /// Checks `count_1` against a plaintext reference on random inputs, for
    /// every even integer size from 2 to 126 (2 message bits, 2 carry bits).
    #[test]
    fn correctness_count1() {
        // Reference semantics: the number of one bits in the input.
        fn semantic(inp: &[IopValue]) -> Vec<IopValue> {
            let [IopValue::Ciphertext(inp)] = inp else {
                unreachable!()
            };
            let res = inp.as_storage().count_ones();
            // Output width matches `iop_count`: `n_bits_to_encode(int_size)` bits.
            let output_size: u16 = n_bits_to_encode(inp.spec().int_size());
            vec![IopValue::Ciphertext(
                inp.spec()
                    .block_spec()
                    .ciphertext_spec(output_size)
                    .from_int(res as u128),
            )]
        }

        for size in (2..128).step_by(2) {
            count_1(CiphertextSpec::new(size, 2, 2)).test_random(100, semantic);
        }
    }
}