Skip to main content

sapling_crypto_ce/circuit/
float_point.rs

1use bellman::pairing::{Engine};
2use bellman::pairing::ff::{Field, PrimeField};
3use bellman::{ConstraintSystem, SynthesisError};
4use super::boolean::{Boolean};
5use super::num::{AllocatedNum, Num};
6use super::Assignment;
7
8/// Takes a bit decomposition, parses and packs into an AllocatedNum
9/// If exponent is equal to zero, then exponent multiplier is equal to 1
10pub fn parse_with_exponent_le<E: Engine, CS: ConstraintSystem<E>>(
11    mut cs: CS,
12    bits: &[Boolean],
13    exponent_length: usize,
14    mantissa_length: usize,
15    exponent_base: u64
16) -> Result<AllocatedNum<E>, SynthesisError>
17{
18    assert!(bits.len() == exponent_length + mantissa_length);
19
20    let one_allocated = AllocatedNum::alloc(
21        cs.namespace(|| "allocate one"),
22        || Ok(E::Fr::one())
23    )?;
24
25    let mut exponent_result = AllocatedNum::alloc(
26        cs.namespace(|| "allocate exponent result"),
27        || Ok(E::Fr::one())
28    )?;
29
30    let exponent_base_string = exponent_base.to_string();
31    let exponent_base_value = E::Fr::from_str(&exponent_base_string.clone()).unwrap();
32
33    let mut exponent_base = AllocatedNum::alloc(
34        cs.namespace(|| "allocate exponent base"), 
35        || Ok(exponent_base_value)
36    )?;
37
38    let exponent_value = exponent_base_value;
39
40    for i in 0..exponent_length {
41        let thisbit = &bits[i];
42
43        let multiplier = AllocatedNum::conditionally_select(
44            cs.namespace(|| format!("select exponent multiplier {}", i)),
45            &exponent_base, 
46            &one_allocated, 
47            &thisbit
48        )?;
49
50        exponent_result = exponent_result.mul(
51            cs.namespace(|| format!("make exponent result {}", i)),
52            &multiplier
53        )?;
54
55        exponent_base = exponent_base.clone().square(
56            cs.namespace(|| format!("make exponent base {}", i))
57        )?;
58
59        // exponent_base = exponent_base.mul(
60        //     cs.namespace(|| format!("make exponent base {}", i)), 
61        //     &exponent_base.clone()
62        // )?;
63    }
64
65    let mut mantissa_result = Num::<E>::zero();
66    let mut mantissa_base = E::Fr::one();
67
68    for i in exponent_length..(exponent_length+mantissa_length)
69    {
70        let thisbit = &bits[i];
71        mantissa_result = mantissa_result.add_bool_with_coeff(CS::one(), &thisbit, mantissa_base);
72        mantissa_base.double();
73    }
74
75    let mantissa = AllocatedNum::alloc(
76        cs.namespace(|| "allocating mantissa"),
77        || Ok(*mantissa_result.get_value().get()?)
78    )?;
79
80    mantissa.mul(
81        cs.namespace(|| "calculate floating point result"),
82        &exponent_result
83    )
84}
85
86pub fn convert_to_float(
87    integer: u128,
88    exponent_length: usize,
89    mantissa_length: usize,
90    exponent_base: u32
91) -> Result<Vec<bool>, SynthesisError>
92{
93    let exponent_base = u128::from(exponent_base);
94    let mut max_exponent = 1u128;
95    let max_power = (1 << exponent_length) - 1;
96
97    for _ in 0..max_power
98    {
99        max_exponent = max_exponent * exponent_base;
100    }
101
102    let max_mantissa = (1u128 << mantissa_length) - 1;
103    
104    if integer > (max_mantissa * max_exponent) {
105        return Err(SynthesisError::Unsatisfiable)
106    }
107
108    let mut exponent: usize = 0;
109    let mut mantissa = integer;
110
111    if integer > max_mantissa {
112    // always try best precision
113        let exponent_guess = integer / max_mantissa;
114        let mut exponent_temp = exponent_guess;
115
116        loop {
117            if exponent_temp < exponent_base {
118                break
119            }
120            exponent_temp = exponent_temp / exponent_base;
121            exponent += 1;
122        }
123
124        exponent_temp = 1u128;
125        for _ in 0..exponent 
126        {
127            exponent_temp = exponent_temp * exponent_base;
128        }    
129
130        if exponent_temp * max_mantissa < integer 
131        {
132            exponent += 1;
133            exponent_temp = exponent_temp * exponent_base;
134        }
135
136        mantissa = integer / exponent_temp;
137    }
138
139    // encode into bits. First bits of mantissa in LE order
140
141    let mut encoding = vec![];
142
143    for i in 0..exponent_length {
144        if exponent & (1 << i) != 0 {
145            encoding.push(true);
146        } else {
147            encoding.push(false);
148        }
149    }
150
151    for i in 0..mantissa_length {
152        if mantissa & (1 << i) != 0 {
153            encoding.push(true);
154        } else {
155            encoding.push(false);
156        }
157    }
158
159    assert!(encoding.len() == exponent_length + mantissa_length);
160
161    Ok(encoding)
162}
163
164
165pub fn parse_float_to_u128(
166    encoding: Vec<bool>,
167    exponent_length: usize,
168    mantissa_length: usize,
169    exponent_base: u32
170) -> Result<u128, SynthesisError>
171{
172    assert!(exponent_length + mantissa_length == encoding.len());
173
174    let exponent_base = u128::from(exponent_base);
175    let mut exponent_multiplier = exponent_base;
176    let mut exponent = 1u128;
177    let bitslice: &[bool] = &encoding;
178    for i in 0..exponent_length
179    {
180        if bitslice[i] {
181            let max_exponent = (u128::max_value() / exponent_multiplier) + 1;
182            if exponent >= max_exponent {
183                return Err(SynthesisError::Unsatisfiable)
184            }
185            exponent = exponent * exponent_multiplier;
186        }
187        exponent_multiplier = exponent_multiplier * exponent_multiplier;
188    }
189
190    let mut max_mantissa = u128::max_value();
191    if exponent != 1u128 {
192        max_mantissa = (u128::max_value() / exponent) + 1;
193    }
194
195    let mut mantissa_power = 1u128;
196    let mut mantissa = 0u128;
197    for i in exponent_length..(exponent_length + mantissa_length)
198    {
199        if bitslice[i] {
200            let max_mant = (max_mantissa / 2u128) + 1;
201            if mantissa >= max_mant {
202                return Err(SynthesisError::Unsatisfiable)
203            }
204            mantissa = mantissa + mantissa_power;
205        }
206        mantissa_power = mantissa_power * 2u128;
207    }
208
209    let result = mantissa * exponent;
210
211    Ok(result)
212}
213
/// Parses several 5-bit-exponent / 1-bit-mantissa encodings (base 10) inside a
/// test constraint system and checks both the resulting value and that the
/// constraint system is satisfied.
#[test]
fn test_parsing() {
    use bellman::{ConstraintSystem};
    use bellman::pairing::bn256::{Bn256, Fr};
    use ::circuit::test::*;
    use super::boolean::{AllocatedBit, Boolean};

    // (encoding, decimal exponent). The mantissa bit is always set, so the
    // expected value is 10^exponent.
    let cases: Vec<(Vec<bool>, usize)> = vec![
        (vec![false, false, false, false, false, true], 0),
        (vec![true, false, false, false, false, true], 1),
        (vec![true, true, false, false, false, true], 3),
        (vec![true, true, true, false, false, true], 7),
        (vec![true, true, true, true, false, true], 15),
        (vec![true, true, true, true, true, true], 31),
    ];

    for (bits, exponent) in cases {
        let mut cs = TestConstraintSystem::<Bn256>::new();

        let circuit_bits = bits.iter().enumerate()
                                .map(|(i, &b)| {
                                    Boolean::from(
                                        AllocatedBit::alloc(
                                        cs.namespace(|| format!("bit {}", i)),
                                        Some(b)
                                        ).unwrap()
                                    )
                                })
                                .collect::<Vec<_>>();

        let exp_result = parse_with_exponent_le(cs.namespace(|| "parse"), &circuit_bits, 5, 1, 10).unwrap();

        // "1" followed by `exponent` zeros is the decimal form of 10^exponent.
        let expected = Fr::from_str(&format!("1{}", "0".repeat(exponent))).unwrap();

        print!("{}\n", exp_result.get_value().unwrap().into_repr());
        assert_eq!(exp_result.get_value().unwrap(), expected);
        assert!(cs.is_satisfied());
        print!("constraints for float parsing = {}\n", cs.num_constraints());
    }
}
260
/// Round-trips 1000 pseudo-random 112-bit integers through the 5/11-bit base-10
/// float format and checks that decoding rounds down within the format's
/// precision.
#[test]
fn test_encoding() {
    use rand::{SeedableRng, Rng, XorShiftRng};
    let mut rng = XorShiftRng::from_seed([0x3dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]);

    // max encoded value is 10^31 * 2047 ~ 10^34 ~ 112 bits

    for _ in 0..1000 {
        let top_word = rng.next_u64() & 0x0000ffffffffffff;
        let bottom_word = rng.next_u64();
        let integer = (u128::from(top_word) << 64) + u128::from(bottom_word);

        let encoding = convert_to_float(integer, 5, 11, 10)
            .expect("every 112-bit integer is within the representable range");

        let dec = parse_float_to_u128(encoding, 5, 11, 10)
            .expect("an encoding produced by convert_to_float must decode");

        // Encoding truncates the mantissa, so the decoded value never exceeds
        // the input.
        assert!(dec <= integer);

        // With a non-zero exponent the mantissa is at least 2047 / 10 = 204 and
        // the truncation error is below one unit of the last mantissa digit,
        // hence error * 204 <= decoded value.
        assert!((integer - dec) * 204 <= dec);
    }
}
287
/// A power of two that fits into the mantissa must be encoded with a zero
/// exponent and exactly one mantissa bit set.
#[test]
fn test_encoding_powers_of_two() {
    let exponent_length = 5;
    let mantissa_length = 11;

    for i in 0..mantissa_length {
        let mantissa = 1u128 << i;
        let encoding = convert_to_float(mantissa, exponent_length, mantissa_length, 10).unwrap();
        assert_eq!(encoding.len(), exponent_length + mantissa_length);

        // The mantissa bits follow the exponent bits, so bit `i` of the
        // mantissa sits at position `exponent_length + i`.
        for (j, bit) in encoding.into_iter().enumerate() {
            assert_eq!(bit, j == exponent_length + i);
        }
    }
}
307
/// Numbers below 20 fit into the mantissa, so they must be encoded exactly:
/// zero exponent and the plain little-endian binary form as mantissa.
/// Each encoding is also printed as a string of `0`/`1` characters.
#[test]
fn test_encoding_small_numbers() {
    let exponent_length = 5;
    let mantissa_length = 11;

    for i in 0..20u128 {
        let encoding = convert_to_float(i, exponent_length, mantissa_length, 10).unwrap();

        let rendered: String = encoding
            .iter()
            .map(|&bit| if bit { '1' } else { '0' })
            .collect();
        print!("{}\n", rendered);

        for (j, &bit) in encoding.iter().enumerate() {
            let expected = j >= exponent_length && (i >> (j - exponent_length)) & 1 == 1;
            assert_eq!(bit, expected);
        }
    }
}
327
/// 20400 = 2040 * 10^1 is exactly representable with an 11-bit mantissa, so it
/// must survive an encode/decode round trip unchanged.
#[test]
fn test_encoding_specific() {
    let encoding = convert_to_float(20400, 5, 11, 10).unwrap();

    let rendered: String = encoding
        .iter()
        .map(|&bit| if bit { '1' } else { '0' })
        .collect();
    print!("{}\n", rendered);

    let decoded = parse_float_to_u128(encoding, 5, 11, 10).unwrap();
    println!("Decode = {}", decoded);
    assert_eq!(decoded, 20400);
}
341}