Skip to main content

zhc_builder/iops/
count.rs

1use zhc_crypto::integer_semantics::CiphertextSpec;
2use zhc_langs::ioplang::{Lut1Def, Lut2Def};
3use zhc_utils::{
4    iter::{ChunkIt, CollectInVec, UnwrapChunks},
5    n_bits_to_encode,
6};
7
8use crate::{
9    BitType, CiphertextBlock, NU,
10    builder::{Builder, Ciphertext},
11};
12
13/// Creates an IR that counts the number of **zero** bits in an encrypted integer.
14///
15/// The returned [`Builder`] declares one ciphertext input of `spec.int_size()`
16/// bits and one ciphertext output whose width is `⌈log₂(int_size + 1)⌉` bits
17/// — just enough to represent every possible count from 0 to `int_size`.
18/// Internally delegates to [`Builder::iop_count`] with [`BitType::Zero`].
19///
20/// The `spec` parameter describes the integer encoding (bit-width, message
21/// bits, carry bits) and determines the number of blocks in the
22/// decomposition.
23///
24/// # Examples
25///
26/// ```rust,no_run
27/// # use zhc_builder::{CiphertextSpec, count_0};
28/// # let spec = CiphertextSpec::new(16, 2, 2);
29/// let builder = count_0(spec);
30/// let ir = builder.into_ir();
31/// ```
32pub fn count_0(spec: CiphertextSpec) -> Builder {
33    let mut builder = Builder::new(spec.block_spec());
34    let src_a = builder.ciphertext_input(spec.int_size());
35    let res = builder.iop_count(&src_a, BitType::Zero);
36    builder.ciphertext_output(res);
37    builder
38}
39
40/// Creates an IR that counts the number of **one** bits in an encrypted integer.
41///
42/// The returned [`Builder`] declares one ciphertext input of `spec.int_size()`
43/// bits and one ciphertext output whose width is `⌈log₂(int_size + 1)⌉` bits
44/// — just enough to represent every possible count from 0 to `int_size`.
45/// Internally delegates to [`Builder::iop_count`] with [`BitType::One`].
46///
47/// The `spec` parameter describes the integer encoding (bit-width, message
48/// bits, carry bits) and determines the number of blocks in the
49/// decomposition.
50///
51/// # Examples
52///
53/// ```rust,no_run
54/// # use zhc_builder::{CiphertextSpec, count_1};
55/// # let spec = CiphertextSpec::new(16, 2, 2);
56/// let builder = count_1(spec);
57/// let ir = builder.into_ir();
58/// ```
59pub fn count_1(spec: CiphertextSpec) -> Builder {
60    let mut builder = Builder::new(spec.block_spec());
61    let src_a = builder.ciphertext_input(spec.int_size());
62    let res = builder.iop_count(&src_a, BitType::One);
63    builder.ciphertext_output(res);
64    builder
65}
66
67type Column = Vec<CiphertextBlock>;
68
69impl Builder {
70    /// Counts the number of zero or one bits in an encrypted integer.
71    ///
72    /// The operation splits `inp` into individual bits, then performs a
73    /// recursive column-based reduction that sums them with carry
74    /// propagation. When `kind` is [`BitType::Zero`], the first
75    /// reduction pass uses inverted look-up tables so that each bit
76    /// contributes 1 when it is zero. The returned [`Ciphertext`] has a
77    /// width of `⌈log₂(int_size + 1)⌉` bits, just enough to represent
78    /// every possible count from 0 to `int_size`.
79    ///
80    /// # Examples
81    ///
82    /// ```rust,no_run
83    /// # use zhc_builder::{CiphertextSpec, Builder, BitType};
84    /// # let spec = CiphertextSpec::new(16, 2, 2);
85    /// # let mut builder = Builder::new(spec.block_spec());
86    /// # let a = builder.ciphertext_input(spec.int_size());
87    /// let pop = builder.iop_count(&a, BitType::One);
88    /// ```
89    pub fn iop_count(&mut self, inp: &Ciphertext, kind: BitType) -> Ciphertext {
90        assert!(
91            inp.spec().int_size().is_multiple_of(2),
92            "Non-multiple-of-two integer size not supported."
93        );
94        self.with_comment("iop_count", || {
95            let blocks = self.ciphertext_split(inp);
96            let bits = self
97                .comment("extract bits")
98                .vector_lookup2(blocks, Lut2Def::ManyMsgSplit)
99                .into_iter()
100                .flat_map(|(l, r)| [l, r].into_iter())
101                .take(inp.spec().int_size() as usize)
102                .collect::<Vec<_>>();
103            let res = self.count_from_bits(bits, kind);
104            let output_size: u16 = n_bits_to_encode(inp.spec().int_size());
105            let n_blocks = output_size.div_ceil(self.spec().message_size() as u16) as usize;
106            self.comment("output")
107                .ciphertext_join(&res[..n_blocks], Some(output_size))
108        })
109    }
110
111    pub(crate) fn count_reduce_recursive(
112        &self,
113        inp: impl AsRef<[Column]>,
114        kind: BitType,
115    ) -> Vec<Column> {
116        // The workhorse recursive reduction implementation.
117        //
118        // Works on columns. The last column contains only bits and as such can be added for longer
119        // before pbs. At the first iteration, there is only a single column, with all bits
120        // inside. Then as the reduction goes, carries are propagated to the column above.
121        // The reduction is finished when there are only one message per column. The columns are
122        // then turned to a ciphertext.
123        let inp = inp.as_ref();
124
125        self.push_comment("count_reduce_recursive");
126
127        if inp.iter().all(|col| col.len() <= 1) {
128            // Reduction is finished, can return.
129            return inp.to_vec();
130        } else {
131            let op_nb = NU;
132            let op_nb_bool = 1 << ((op_nb as f64).log2().ceil() as usize);
133            let op_nb_single = op_nb_bool - 1;
134            let reduction_iteration = inp.len();
135
136            let mut output: Vec<Column> = vec![Column::new(); inp.len() + 1];
137            for (col_idx, col) in inp.iter().enumerate() {
138                if col.len() == 1 {
139                    output[col_idx].push(col[0]);
140                } else if col_idx == inp.len() - 1 {
141                    self.push_comment(format!("last Column {col_idx}"));
142                    col.iter()
143                        .cloned()
144                        .chunk(op_nb_single)
145                        .unwrap_chunks()
146                        .map(|chunk| {
147                            let sum = self.vector_add_reduce(&chunk);
148                            if kind == BitType::Zero && reduction_iteration == 1 {
149                                self.with_comment("zero and first reduction", || {
150                                    match chunk.len() {
151                                        1 => self.block_lookup2(sum, Lut2Def::ManyInv1CarryMsg),
152                                        2 => self.block_lookup2(sum, Lut2Def::ManyInv2CarryMsg),
153                                        3 => self.block_lookup2(sum, Lut2Def::ManyInv3CarryMsg),
154                                        4 => self.block_lookup2(sum, Lut2Def::ManyInv4CarryMsg),
155                                        5 => self.block_lookup2(sum, Lut2Def::ManyInv5CarryMsg),
156                                        6 => self.block_lookup2(sum, Lut2Def::ManyInv6CarryMsg),
157                                        7 => self.block_lookup2(sum, Lut2Def::ManyInv7CarryMsg),
158                                        _ => unreachable!(),
159                                    }
160                                })
161                            } else {
162                                self.with_comment("common branch", || {
163                                    self.block_lookup2(sum, Lut2Def::ManyCarryMsg)
164                                })
165                            }
166                        })
167                        .for_each(|(msg, carry)| {
168                            output[col_idx].push(msg);
169                            output[col_idx + 1].push(carry);
170                        });
171                    self.pop_comment();
172                } else {
173                    self.push_comment(format!("regular Column {col_idx}"));
174                    col.iter()
175                        .cloned()
176                        .chunk(op_nb)
177                        .unwrap_chunks()
178                        .map(|chunk| {
179                            let sum = self.vector_add_reduce(&chunk);
180                            if chunk.len() <= 2 {
181                                // We have enough room to use a 2lookup in this case
182                                self.block_lookup2(sum, Lut2Def::ManyCarryMsg)
183                            } else {
184                                // We don't have enough room. We must do two pbses
185                                (
186                                    self.block_lookup(sum, Lut1Def::MsgOnly),
187                                    self.block_lookup(sum, Lut1Def::CarryInMsg),
188                                )
189                            }
190                        })
191                        .for_each(|(msg, carry)| {
192                            output[col_idx].push(msg);
193                            output[col_idx + 1].push(carry);
194                        });
195                    self.pop_comment();
196                }
197            }
198            let output = self.comment("output").count_reduce_recursive(output, kind);
199            self.pop_comment();
200            output
201        }
202    }
203
204    pub(crate) fn count_from_bits(
205        &self,
206        bits: impl AsRef<[CiphertextBlock]>,
207        kind: BitType,
208    ) -> Vec<CiphertextBlock> {
209        // Count bits of the given type.
210        // The input is a set of blocks each encoding a single bit.
211        self.with_comment("count_from_bits", || {
212            let bits = bits.as_ref().to_vec();
213            let res: Vec<Column> = self.count_reduce_recursive(vec![bits], kind);
214            let output = res
215                .into_iter()
216                .filter(|col| !col.is_empty())
217                .map(|col| col[0])
218                .covec();
219            self.comment("output").vector_inspect(output)
220        })
221    }
222}
223
// Tests for the count IR builders: a golden IR snapshot for `count_0` and
// randomized end-to-end checks of `count_0` / `count_1` against plain-integer
// semantics.
224#[cfg(test)]
225mod test {
226    use super::*;
227    use zhc_langs::ioplang::IopValue;
228    use zhc_utils::assert_display_is;
229
    // Golden snapshot of the linearized IR (with comments, without types) that
    // `count_0` emits for an 18-bit input with 2-bit messages and 2-bit carries.
    // Any change to the reduction schedule or comment labels shows up here.
230    #[test]
231    fn test_count0() {
232        let spec = CiphertextSpec::new(18, 2, 2);
233        let ir = count_0(spec).into_ir();
234        assert_display_is!(
235            ir.format()
236                .with_walker(zhc_ir::PrintWalker::Linear)
237                .show_comments(true)
238                .show_types(false),
239            r#"
240                                                                                                                                                                                                   | %0 = input_ciphertext<0, 18>();
241                // iop_count                                                                                                                                                                       | %1 = extract_ct_block<0>(%0);
242                // iop_count                                                                                                                                                                       | %2 = extract_ct_block<1>(%0);
243                // iop_count                                                                                                                                                                       | %3 = extract_ct_block<2>(%0);
244                // iop_count                                                                                                                                                                       | %4 = extract_ct_block<3>(%0);
245                // iop_count                                                                                                                                                                       | %5 = extract_ct_block<4>(%0);
246                // iop_count                                                                                                                                                                       | %6 = extract_ct_block<5>(%0);
247                // iop_count                                                                                                                                                                       | %7 = extract_ct_block<6>(%0);
248                // iop_count                                                                                                                                                                       | %8 = extract_ct_block<7>(%0);
249                // iop_count                                                                                                                                                                       | %9 = extract_ct_block<8>(%0);
250                // iop_count / extract bits                                                                                                                                                        | %10, %11 = pbs2<ManyMsgSplit>(%1);
251                // iop_count / extract bits                                                                                                                                                        | %12, %13 = pbs2<ManyMsgSplit>(%2);
252                // iop_count / extract bits                                                                                                                                                        | %14, %15 = pbs2<ManyMsgSplit>(%3);
253                // iop_count / extract bits                                                                                                                                                        | %16, %17 = pbs2<ManyMsgSplit>(%4);
254                // iop_count / extract bits                                                                                                                                                        | %18, %19 = pbs2<ManyMsgSplit>(%5);
255                // iop_count / extract bits                                                                                                                                                        | %20, %21 = pbs2<ManyMsgSplit>(%6);
256                // iop_count / extract bits                                                                                                                                                        | %22, %23 = pbs2<ManyMsgSplit>(%7);
257                // iop_count / extract bits                                                                                                                                                        | %24, %25 = pbs2<ManyMsgSplit>(%8);
258                // iop_count / extract bits                                                                                                                                                        | %26, %27 = pbs2<ManyMsgSplit>(%9);
259                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %28 = add_ct(%10, %11);
260                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %29 = add_ct(%28, %12);
261                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %30 = add_ct(%29, %13);
262                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %31 = add_ct(%30, %14);
263                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %32 = add_ct(%31, %15);
264                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %33 = add_ct(%32, %16);
265                // iop_count / count_from_bits / count_reduce_recursive / last Column 0 / zero and first reduction                                                                                 | %34, %35 = pbs2<ManyInv7CarryMsg>(%33);
266                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %36 = add_ct(%17, %18);
267                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %37 = add_ct(%36, %19);
268                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %38 = add_ct(%37, %20);
269                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %39 = add_ct(%38, %21);
270                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %40 = add_ct(%39, %22);
271                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %41 = add_ct(%40, %23);
272                // iop_count / count_from_bits / count_reduce_recursive / last Column 0 / zero and first reduction                                                                                 | %42, %43 = pbs2<ManyInv7CarryMsg>(%41);
273                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %44 = add_ct(%24, %25);
274                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %45 = add_ct(%44, %26);
275                // iop_count / count_from_bits / count_reduce_recursive / last Column 0                                                                                                            | %46 = add_ct(%45, %27);
276                // iop_count / count_from_bits / count_reduce_recursive / last Column 0 / zero and first reduction                                                                                 | %47, %48 = pbs2<ManyInv4CarryMsg>(%46);
277                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / regular Column 0                                                                       | %49 = add_ct(%34, %42);
278                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / regular Column 0                                                                       | %50 = add_ct(%49, %47);
279                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / regular Column 0                                                                       | %51 = pbs<Protect, MsgOnly>(%50);
280                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / regular Column 0                                                                       | %52 = pbs<Protect, CarryInMsg>(%50);
281                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / last Column 1                                                                          | %53 = add_ct(%35, %43);
282                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / last Column 1                                                                          | %54 = add_ct(%53, %48);
283                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / last Column 1 / common branch                                                          | %55, %56 = pbs2<ManyCarryMsg>(%54);
284                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / regular Column 1                                     | %57 = add_ct(%52, %55);
285                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / regular Column 1                                     | %58, %59 = pbs2<ManyCarryMsg>(%57);
286                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / regular Column 2   | %60 = add_ct(%59, %56);
287                // iop_count / count_from_bits / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / output / count_reduce_recursive / regular Column 2   | %61, %62 = pbs2<ManyCarryMsg>(%60);
288                // iop_count / output                                                                                                                                                              | %67 = decl_ct<5>();
289                // iop_count / output                                                                                                                                                              | %68 = store_ct_block<0>(%51, %67);
290                // iop_count / output                                                                                                                                                              | %69 = store_ct_block<1>(%58, %68);
291                // iop_count / output                                                                                                                                                              | %70 = store_ct_block<2>(%61, %69);
292                                                                                                                                                                                                   | output<0>(%70);
293            "#
294        );
295    }
296
    // Randomized end-to-end check of `count_0` against a plain count of zero
    // bits, for every even width from 2 to 126 bits. `test_random(100, ...)`
    // presumably runs 100 random inputs per width — confirm against its definition.
297    #[test]
298    fn correctness_count0() {
299        fn semantic(inp: &[IopValue]) -> Vec<IopValue> {
300            let [IopValue::Ciphertext(inp)] = inp else {
301                unreachable!()
302            };
            // Assumes `as_storage()` is a u128 whose bits above `int_size` are
            // zero: those high zeros are subtracted so only the integer's own
            // bits are counted.
303            let res = inp.as_storage().count_zeros() - (u128::BITS - inp.spec().int_size() as u32);
304            let output_size: u16 = n_bits_to_encode(inp.spec().int_size());
305            vec![IopValue::Ciphertext(
306                inp.spec()
307                    .block_spec()
308                    .ciphertext_spec(output_size)
309                    .from_int(res as u128),
310            )]
311        }
312
313        for size in (2..128).step_by(2) {
314            count_0(CiphertextSpec::new(size, 2, 2)).test_random(100, semantic);
315        }
316    }
317
    // Same width sweep for `count_1`, compared against `count_ones()` of the
    // input storage.
318    #[test]
319    fn correctness_count1() {
320        fn semantic(inp: &[IopValue]) -> Vec<IopValue> {
321            let [IopValue::Ciphertext(inp)] = inp else {
322                unreachable!()
323            };
324            let res = inp.as_storage().count_ones();
325            let output_size: u16 = n_bits_to_encode(inp.spec().int_size());
326            vec![IopValue::Ciphertext(
327                inp.spec()
328                    .block_spec()
329                    .ciphertext_spec(output_size)
330                    .from_int(res as u128),
331            )]
332        }
333
334        for size in (2..128).step_by(2) {
335            count_1(CiphertextSpec::new(size, 2, 2)).test_random(100, semantic);
336        }
337    }
338}
338}