sp1_recursion_compiler/ir/
bits.rs

1use p3_field::{AbstractField, Field};
2use sp1_recursion_core::runtime::NUM_BITS;
3
4use super::{Array, Builder, Config, DslIr, Felt, Usize, Var};
5
6impl<C: Config> Builder<C> {
7    /// Converts a variable to LE bits.
8    pub fn num2bits_v(&mut self, num: Var<C::N>) -> Array<C, Var<C::N>> {
9        // This function is only used when the native field is Babybear.
10        assert!(C::N::bits() == NUM_BITS);
11
12        let output = self.dyn_array::<Var<_>>(NUM_BITS);
13        self.push_op(DslIr::HintBitsV(output.clone(), num));
14
15        let sum: Var<_> = self.eval(C::N::zero());
16        for i in 0..NUM_BITS {
17            let bit = self.get(&output, i);
18            self.assert_var_eq(bit * (bit - C::N::one()), C::N::zero());
19            self.assign(sum, sum + bit * C::N::from_canonical_u32(1 << i));
20        }
21
22        self.assert_var_eq(sum, num);
23
24        self.less_than_bb_modulus(output.clone());
25
26        output
27    }
28
29    /// Range checks a variable to a certain number of bits.
30    pub fn range_check_v(&mut self, num: Var<C::N>, num_bits: usize) {
31        let bits = self.num2bits_v(num);
32        self.range(num_bits, bits.len()).for_each(|i, builder| {
33            let bit = builder.get(&bits, i);
34            builder.assert_var_eq(bit, C::N::zero());
35        });
36    }
37
38    /// Converts a variable to bits inside a circuit.
39    pub fn num2bits_v_circuit(&mut self, num: Var<C::N>, bits: usize) -> Vec<Var<C::N>> {
40        let mut output = Vec::new();
41        for _ in 0..bits {
42            output.push(self.uninit());
43        }
44
45        self.push_op(DslIr::CircuitNum2BitsV(num, bits, output.clone()));
46
47        output
48    }
49
50    /// Range checks a felt to a certain number of bits.
51    pub fn range_check_f(&mut self, num: Felt<C::F>, num_bits: usize) {
52        let bits = self.num2bits_f(num);
53        self.range(num_bits, bits.len()).for_each(|i, builder| {
54            let bit = builder.get(&bits, i);
55            builder.assert_var_eq(bit, C::N::zero());
56        });
57    }
58
59    /// Converts a felt to bits.
60    pub fn num2bits_f(&mut self, num: Felt<C::F>) -> Array<C, Var<C::N>> {
61        let output = self.dyn_array::<Var<_>>(NUM_BITS);
62        self.push_op(DslIr::HintBitsF(output.clone(), num));
63
64        let sum: Felt<_> = self.eval(C::F::zero());
65        for i in 0..NUM_BITS {
66            let bit = self.get(&output, i);
67            self.assert_var_eq(bit * (bit - C::N::one()), C::N::zero());
68            self.if_eq(bit, C::N::one()).then(|builder| {
69                builder.assign(sum, sum + C::F::from_canonical_u32(1 << i));
70            });
71        }
72
73        self.assert_felt_eq(sum, num);
74
75        self.less_than_bb_modulus(output.clone());
76
77        output
78    }
79
80    /// Converts a felt to bits inside a circuit.
81    pub fn num2bits_f_circuit(&mut self, num: Felt<C::F>) -> Vec<Var<C::N>> {
82        let mut output = Vec::new();
83        for _ in 0..NUM_BITS {
84            output.push(self.uninit());
85        }
86
87        self.push_op(DslIr::CircuitNum2BitsF(num, output.clone()));
88
89        output
90    }
91
92    /// Convert bits to a variable.
93    pub fn bits2num_v(&mut self, bits: &Array<C, Var<C::N>>) -> Var<C::N> {
94        let num: Var<_> = self.eval(C::N::zero());
95        let power: Var<_> = self.eval(C::N::one());
96        self.range(0, bits.len()).for_each(|i, builder| {
97            let bit = builder.get(bits, i);
98            builder.assign(num, num + bit * power);
99            builder.assign(power, power * C::N::from_canonical_u32(2));
100        });
101        num
102    }
103
104    /// Convert bits to a variable inside a circuit.
105    pub fn bits2num_v_circuit(&mut self, bits: &[Var<C::N>]) -> Var<C::N> {
106        let result: Var<_> = self.eval(C::N::zero());
107        for i in 0..bits.len() {
108            self.assign(result, result + bits[i] * C::N::from_canonical_u32(1 << i));
109        }
110        result
111    }
112
113    /// Convert bits to a felt.
114    pub fn bits2num_f(&mut self, bits: &Array<C, Var<C::N>>) -> Felt<C::F> {
115        let num: Felt<_> = self.eval(C::F::zero());
116        for i in 0..NUM_BITS {
117            let bit = self.get(bits, i);
118            // Add `bit * 2^i` to the sum.
119            self.if_eq(bit, C::N::one()).then(|builder| {
120                builder.assign(num, num + C::F::from_canonical_u32(1 << i));
121            });
122        }
123        num
124    }
125
126    /// Reverse a list of bits.
127    ///
128    /// SAFETY: calling this function with `bit_len` greater [`NUM_BITS`] will result in undefined
129    /// behavior.
130    ///
131    /// Reference: [p3_util::reverse_bits_len]
132    pub fn reverse_bits_len(
133        &mut self,
134        index_bits: &Array<C, Var<C::N>>,
135        bit_len: impl Into<Usize<C::N>>,
136    ) -> Array<C, Var<C::N>> {
137        let bit_len = bit_len.into();
138
139        let mut result_bits = self.dyn_array::<Var<_>>(NUM_BITS);
140        self.range(0, bit_len).for_each(|i, builder| {
141            let index: Var<C::N> = builder.eval(bit_len - i - C::N::one());
142            let entry = builder.get(index_bits, index);
143            builder.set_value(&mut result_bits, i, entry);
144        });
145
146        let zero = self.eval(C::N::zero());
147        self.range(bit_len, NUM_BITS).for_each(|i, builder| {
148            builder.set_value(&mut result_bits, i, zero);
149        });
150
151        result_bits
152    }
153
154    /// Reverse a list of bits inside a circuit.
155    ///
156    /// SAFETY: calling this function with `bit_len` greater [`NUM_BITS`] will result in undefined
157    /// behavior.
158    ///
159    /// Reference: [p3_util::reverse_bits_len]
160    pub fn reverse_bits_len_circuit(
161        &mut self,
162        index_bits: Vec<Var<C::N>>,
163        bit_len: usize,
164    ) -> Vec<Var<C::N>> {
165        assert!(bit_len <= NUM_BITS);
166        let mut result_bits = Vec::new();
167        for i in 0..bit_len {
168            let idx = bit_len - i - 1;
169            result_bits.push(index_bits[idx]);
170        }
171        result_bits
172    }
173
174    /// Checks that the LE bit decomposition of a number is less than the babybear modulus.
175    ///
176    /// SAFETY: This function assumes that the num_bits values are already verified to be boolean.
177    ///
178    /// The babybear modulus in LE bits is: 100_000_000_000_000_000_000_000_000_111_1.
179    /// To check that the num_bits array is less than that value, we first check if the most
180    /// significant bits are all 1.  If it is, then we assert that the other bits are all 0.
181    fn less_than_bb_modulus(&mut self, num_bits: Array<C, Var<C::N>>) {
182        let one: Var<_> = self.eval(C::N::one());
183        let zero: Var<_> = self.eval(C::N::zero());
184
185        let mut most_sig_4_bits = one;
186        for i in (NUM_BITS - 4)..NUM_BITS {
187            let bit = self.get(&num_bits, i);
188            most_sig_4_bits = self.eval(bit * most_sig_4_bits);
189        }
190
191        let mut sum_least_sig_bits = zero;
192        for i in 0..(NUM_BITS - 4) {
193            let bit = self.get(&num_bits, i);
194            sum_least_sig_bits = self.eval(bit + sum_least_sig_bits);
195        }
196
197        // If the most significant 4 bits are all 1, then check the sum of the least significant
198        // bits, else return zero.
199        let check: Var<_> =
200            self.eval(most_sig_4_bits * sum_least_sig_bits + (one - most_sig_4_bits) * zero);
201        self.assert_var_eq(check, zero);
202    }
203}