1use group::ff::PrimeField;
5use sigma_proofs::errors::Error;
6use subtle::Choice;
7
8pub fn bit_decomp_vartime<S: PrimeField>(mut s: S) -> Option<(u128, u32)> {
15 let mut val = 0u128;
16 let mut bitnum = 0u32;
17 let mut bitval = 1u128; while bitnum < 127 && !s.is_zero_vartime() {
19 if s.is_odd().into() {
20 val += bitval;
21 s -= S::ONE;
22 }
23 bitnum += 1;
24 bitval <<= 1;
25 s *= S::TWO_INV;
26 }
27 if s.is_zero_vartime() {
28 Some((val, bitnum))
29 } else {
30 None
31 }
32}
33
34pub fn bit_decomp<S: PrimeField>(mut s: S, nbits: u32) -> Vec<Choice> {
40 let mut bits = Vec::with_capacity(nbits as usize);
41 let mut bitnum = 0u32;
42 while bitnum < nbits && bitnum < 127 {
43 let lowbit = s.is_odd();
44 s -= S::conditional_select(&S::ZERO, &S::ONE, lowbit);
45 s *= S::TWO_INV;
46 bits.push(lowbit);
47 bitnum += 1;
48 }
49 bits
50}
51
52pub fn bitrep_scalars_vartime<S: PrimeField>(upper: S) -> Result<Vec<S>, Error> {
71 let (upper_val, mut nbits) = bit_decomp_vartime(upper).ok_or(Error::VerificationFailure)?;
73
74 if nbits < 2 {
76 return Err(Error::VerificationFailure);
77 }
78
79 if upper_val == 1u128 << (nbits - 1) {
81 nbits -= 1;
82 }
83
84 Ok((0..nbits)
87 .map(|i| {
88 if i < nbits - 1 {
89 S::from_u128(1u128 << i)
90 } else {
91 S::from_u128(upper_val - (1u128 << (nbits - 1)))
93 }
94 })
95 .collect())
96}
97
98pub fn compute_bitrep<S: PrimeField>(mut x: S, bitrep_scalars: &[S]) -> Vec<Choice> {
110 let nbits: u32 = bitrep_scalars.len().try_into().unwrap();
112
113 let x_raw_bits = bit_decomp(x, nbits);
117 let high_bit = x_raw_bits[(nbits as usize) - 1];
118
119 x -= S::conditional_select(&S::ZERO, &bitrep_scalars[(nbits as usize) - 1], high_bit);
126
127 let mut x_bits = bit_decomp(x, nbits - 1);
129
130 x_bits.push(high_bit);
132
133 x_bits
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use curve25519_dalek::scalar::Scalar;
    use std::ops::Neg;
    use subtle::ConditionallySelectable;

    /// Asserts that `bit_decomp(s, nbits)`, rendered least-significant bit
    /// first as a string of `'0'`/`'1'`, equals `expect_bitstr`.
    fn bit_decomp_tester(s: Scalar, nbits: u32, expect_bitstr: &str) {
        let rendered: String = bit_decomp(s, nbits)
            .into_iter()
            .map(|c| char::from(u8::conditional_select(&b'0', &b'1', c)))
            .collect();
        assert_eq!(rendered, expect_bitstr);
    }

    #[test]
    fn bit_decomp_test() {
        assert_eq!(bit_decomp_vartime(Scalar::from(0u32)), Some((0, 0)));
        assert_eq!(bit_decomp_vartime(Scalar::from(1u32)), Some((1, 1)));
        assert_eq!(bit_decomp_vartime(Scalar::from(2u32)), Some((2, 2)));
        assert_eq!(bit_decomp_vartime(Scalar::from(3u32)), Some((3, 2)));
        assert_eq!(bit_decomp_vartime(Scalar::from(4u32)), Some((4, 3)));
        assert_eq!(bit_decomp_vartime(Scalar::from(5u32)), Some((5, 3)));
        assert_eq!(bit_decomp_vartime(Scalar::from(6u32)), Some((6, 3)));
        assert_eq!(bit_decomp_vartime(Scalar::from(7u32)), Some((7, 3)));
        assert_eq!(bit_decomp_vartime(Scalar::from(8u32)), Some((8, 4)));
        assert_eq!(bit_decomp_vartime(Scalar::from(1u32).neg()), None);
        assert_eq!(
            bit_decomp_vartime(Scalar::from((1u128 << 127) - 2)),
            Some(((i128::MAX - 1) as u128, 127))
        );
        assert_eq!(
            bit_decomp_vartime(Scalar::from((1u128 << 127) - 1)),
            Some((i128::MAX as u128, 127))
        );
        assert_eq!(bit_decomp_vartime(Scalar::from(1u128 << 127)), None);

        bit_decomp_tester(Scalar::from(0u32), 0, "");
        bit_decomp_tester(Scalar::from(0u32), 5, "00000");
        bit_decomp_tester(Scalar::from(1u32), 0, "");
        bit_decomp_tester(Scalar::from(1u32), 1, "1");
        bit_decomp_tester(Scalar::from(2u32), 1, "0");
        bit_decomp_tester(Scalar::from(2u32), 2, "01");
        bit_decomp_tester(Scalar::from(3u32), 1, "1");
        bit_decomp_tester(Scalar::from(3u32), 2, "11");
        bit_decomp_tester(Scalar::from(5u32), 8, "10100000");
        // Low 32 bits of the group order minus one.
        bit_decomp_tester(
            Scalar::from(1u32).neg(),
            32,
            "00110111110010111010111100111010",
        );
        // 2^127 - 2: low bit clear, the next 126 bits set.
        bit_decomp_tester(
            Scalar::from((1u128 << 127) - 2),
            127,
            &format!("0{}", "1".repeat(126)),
        );
        // 2^127 - 1: all 127 bits set.
        bit_decomp_tester(Scalar::from((1u128 << 127) - 1), 127, &"1".repeat(127));
        // 2^127 needs a 128th bit, which lies outside the window.
        bit_decomp_tester(Scalar::from(1u128 << 127), 127, &"0".repeat(127));
        // Requests beyond 127 bits are truncated to 127.
        bit_decomp_tester(Scalar::from(1u128 << 127), 128, &"0".repeat(127));
    }

    /// Builds weights for `upper`, decomposes `x` against them, and checks
    /// whether the weighted sum reproduces `x` exactly when `expected` says
    /// it should. Errors from weight construction are propagated.
    fn bitrep_tester(upper: Scalar, x: Scalar, expected: bool) -> Result<(), Error> {
        let rep_scalars = bitrep_scalars_vartime(upper)?;
        let bitrep = compute_bitrep(x, &rep_scalars);
        assert_eq!(bitrep.len(), rep_scalars.len());

        let x_out = rep_scalars
            .iter()
            .zip(&bitrep)
            .map(|(weight, &bit)| Scalar::conditional_select(&Scalar::ZERO, weight, bit))
            .fold(Scalar::ZERO, |acc, term| acc + term);

        if (x == x_out) != expected {
            return Err(Error::VerificationFailure);
        }
        Ok(())
    }

    #[test]
    fn bitrep_test() {
        bitrep_tester(Scalar::from(0u32), Scalar::from(0u32), false).unwrap_err();
        bitrep_tester(Scalar::from(1u32), Scalar::from(0u32), true).unwrap_err();
        bitrep_tester(Scalar::from(2u32), Scalar::from(1u32), true).unwrap();
        bitrep_tester(Scalar::from(3u32), Scalar::from(1u32), true).unwrap();
        bitrep_tester(Scalar::from(100u32), Scalar::from(99u32), true).unwrap();
        bitrep_tester(Scalar::from(127u32), Scalar::from(126u32), true).unwrap();
        bitrep_tester(Scalar::from(128u32), Scalar::from(127u32), true).unwrap();
        bitrep_tester(Scalar::from(128u32), Scalar::from(128u32), false).unwrap();
        bitrep_tester(Scalar::from(129u32), Scalar::from(128u32), true).unwrap();
        bitrep_tester(Scalar::from(129u32), Scalar::from(0u32), true).unwrap();
        bitrep_tester(Scalar::from(129u32), Scalar::from(129u32), false).unwrap();
    }

    #[test]
    #[should_panic]
    fn compute_bitrep_rejects_empty_scalars() {
        let _ = compute_bitrep(Scalar::from(0u32), &[]);
    }

    #[test]
    #[should_panic]
    fn compute_bitrep_rejects_too_many_scalars() {
        let weights = vec![Scalar::from(1u32); 128];
        let _ = compute_bitrep(Scalar::from(0u32), &weights);
    }
}