// tasm_lib/arithmetic/u160/div_mod.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic;
4use crate::prelude::*;
5
/// Division with remainder for `u160` values.
///
/// Given a numerator `n` and a denominator `d`, produces the quotient `q` and
/// the remainder `r`. The snippet does not compute them itself. Instead it
/// divines both values from the secret-input stream (quotient first, then
/// remainder) and verifies them.
///
/// The VM crashes if
/// - the denominator is zero,
/// - a divined value is not a `BFieldCodec`-encoded u160, or
/// - the divined values violate `r < d` or `n == q * d + r`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DivMod;
12
13impl DivMod {
14    pub const DIVISION_BY_ZERO_ERROR_ID: i128 = 590;
15    pub const REMAINDER_TOO_BIG: i128 = 591;
16    pub const EULER_EQUATION_ERROR: i128 = 592;
17}
18
impl BasicSnippet for DivMod {
    fn parameters(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::U160, "numerator".to_owned()),
            (DataType::U160, "denominator".to_owned()),
        ]
    }

    fn return_values(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::U160, "quotient".to_owned()),
            (DataType::U160, "remainder".to_owned()),
        ]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_u160_div_mod".to_string()
    }

    /// Emits the TASM for `div_mod`.
    ///
    /// The quotient `q` and remainder `r` are divined (quotient first), checked
    /// to consist of u32 limbs, and then verified against `r < d` and
    /// `q * d + r == n`. Copies of `r` and `q` are parked in two five-word
    /// static-memory slots so they can be returned once the checks have
    /// consumed the stack.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        // Library snippets used to check `r < d` and to recompute `q * d + r`.
        let lt = library.import(Box::new(arithmetic::u160::lt::Lt));
        let safe_add = library.import(Box::new(arithmetic::u160::safe_add::SafeAdd));
        let safe_mul = library.import(Box::new(arithmetic::u160::safe_mul::SafeMul));
        // Replaces the two u160s on top of the stack with their equality bit.
        let compare = DataType::U160.compare();

        // Five-word static-memory scratch slots for the remainder and the
        // quotient.
        let remainder_pointer = library.kmalloc(5);
        let quotient_pointer = library.kmalloc(5);

        // Reads one u160 (five words) from the secret-input stream.
        let divine_u160 = triton_asm!(
            divine 5
        );
        // Stack-neutral well-formedness check for the u160 on top of the stack.
        // `pop_count` fails on any element that is not a u32, so each limb is
        // copied, run through `pop_count`, and the five results are discarded.
        let input_sanitation = triton_asm!(
            dup 4
            pop_count
            // _ [x] pop_count(x_4)

            dup 4
            pop_count
            // _ [x] p_cnt(x_4) p_cnt(x_3)

            dup 4
            pop_count
            // _ [x] p_cnt(x_4) p_cnt(x_3) p_cnt(x_2)

            dup 4
            pop_count
            // _ [x] p_cnt(x_4) p_cnt(x_3) p_cnt(x_2) p_cnt(x_1)

            dup 4
            pop_count
            // _ [x] p_cnt(x_4) p_cnt(x_3) p_cnt(x_2) p_cnt(x_1) p_cnt(x_0)

            pop 5
            // _ [x]
        );

        triton_asm!(
            // BEFORE: _ n_4 n_3 n_2 n_1 n_0 d_4 d_3 d_2 d_1 d_0
            // AFTER:  _ q_4 q_3 q_2 q_1 q_0 r_4 r_3 r_2 r_1 r_0
            {self.entrypoint()}:
                // _ [n] [d]


                /* Assert denominator ≠ 0 */
                // Copy [d] and compare it against the all-zero u160.
                {&triton_asm![dup 4; 5]}
                {&triton_asm![push 0; 5]}
                // _ [n] [d] [d] [0]

                {&compare}
                // _ [n] [d] (d == 0)

                // Negate the equality bit: the assertion must hold for d ≠ 0.
                push 0
                eq
                assert error_id {Self::DIVISION_BY_ZERO_ERROR_ID}
                // _ [n] [d]


                /* Read in q and r, and assert both are well-formed u160s */
                {&divine_u160}
                {&input_sanitation}
                // _ [n] [d] [q]

                {&divine_u160}
                {&input_sanitation}
                // _ [n] [d] [q] [r]


                /* Verify r < d */
                // [d] occupies stack positions 10..=14, so five `dup 14`s copy it.
                {&triton_asm![dup 14; 5]}
                // After that, [r] occupies positions 5..=9, so five `dup 9`s copy it.
                {&triton_asm![dup 9; 5]}
                // _ [n] [d] [q] [r] [d] [r]

                call {lt}
                // _ [n] [d] [q] [r] (r < d)

                assert error_id {Self::REMAINDER_TOO_BIG}
                // _ [n] [d] [q] [r]


                /* Verify q * d + r == n */
                // Park r in static memory; `pop 1` discards the pointer that
                // `write_mem` leaves on the stack.
                push {remainder_pointer.write_address()}
                write_mem 5
                pop 1
                // _ [n] [d] [q]

                // Park a copy of q as well: `safe_mul` consumes the stack copy,
                // but q must still be returned at the end.
                {&triton_asm![dup 4; 5]}
                // _ [n] [d] [q] [q]

                push {quotient_pointer.write_address()}
                write_mem 5
                pop 1
                // _ [n] [d] [q]

                call {safe_mul}
                // _ [n] [d * q]

                push {remainder_pointer.read_address()}
                read_mem 5
                pop 1
                // _ [n] [d * q] [r]

                call {safe_add}
                // _ [n] [d * q + r]

                {&compare}
                // _ (n == d * q + r)

                assert error_id {Self::EULER_EQUATION_ERROR}
                // _


                /* Put return value on stack */
                // Both values were verified above; reload them from static memory.
                push {quotient_pointer.read_address()}
                read_mem 5
                pop 1
                // _ [q]

                push {remainder_pointer.read_address()}
                read_mem 5
                pop 1
                // _ [q] [r]

                return
        )
    }
}
163
#[cfg(test)]
mod tests {
    use num::BigUint;
    use num::Integer;
    use num_traits::Zero;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u160::u128_to_u160;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::test_helpers::negative_test;
    use crate::test_prelude::Algorithm;
    use crate::test_prelude::*;

    #[test]
    fn std_test() {
        ShadowedAlgorithm::new(DivMod).test()
    }

    /// Divined quotient and remainder violate `r < d`.
    #[test]
    fn too_big_remainder() {
        let numerator = 10;
        let denominator = 3;
        let stack = DivMod.stack(u128_to_u160(numerator), u128_to_u160(denominator));

        let too_big_remainder = 4;
        let too_small_quotient = 2;
        let nondeterminism = DivMod::nondeterminism_from_limbs(
            u128_to_u160(too_small_quotient),
            u128_to_u160(too_big_remainder),
        );

        test_assertion_failure(
            &ShadowedAlgorithm::new(DivMod),
            InitVmState {
                nondeterminism,
                stack,
                ..Default::default()
            },
            &[DivMod::REMAINDER_TOO_BIG],
        )
    }

    /// Divined values satisfy `r < d` but violate `q * d + r == n`.
    #[test]
    fn euler_equation_error() {
        let numerator = 1u128 << 99;
        let denominator = 1u128 << 45;
        let stack = DivMod.stack(u128_to_u160(numerator), u128_to_u160(denominator));

        let correct_remainder = 0;
        let bad_quotient = 1u128 << 43;
        let nondeterminism = DivMod::nondeterminism_from_limbs(
            u128_to_u160(bad_quotient),
            u128_to_u160(correct_remainder),
        );

        test_assertion_failure(
            &ShadowedAlgorithm::new(DivMod),
            InitVmState {
                nondeterminism,
                stack,
                ..Default::default()
            },
            &[DivMod::EULER_EQUATION_ERROR],
        )
    }

    #[test]
    fn division_by_zero() {
        let numerator = u128_to_u160(52);
        let denominator = u128_to_u160(0);
        let state = DivMod.prepare_state(numerator, denominator);
        test_assertion_failure(
            &ShadowedAlgorithm::new(DivMod),
            state.into(),
            &[DivMod::DIVISION_BY_ZERO_ERROR_ID],
        );
    }

    /// Each of the ten divined words, in turn, is replaced by a non-u32 value.
    #[test]
    fn bad_encoding() {
        let not_u32 = bfe!(1u64 << 32);
        let valid_u32 = bfe!(14);
        let mut nondeterminism = DivMod::nondeterminism([valid_u32; 5], [valid_u32; 5]);
        let stack = DivMod.stack(u128_to_u160(4), u128_to_u160(4));

        for i in 0..10 {
            nondeterminism.individual_tokens[i] = not_u32;
            negative_test(
                &ShadowedAlgorithm::new(DivMod),
                InitVmState {
                    nondeterminism: nondeterminism.clone(),
                    stack: stack.clone(),
                    ..Default::default()
                },
                &[InstructionError::OpStackError(
                    OpStackError::FailedU32Conversion(not_u32),
                )],
            );

            nondeterminism.individual_tokens[i] = valid_u32;
        }
    }

    impl DivMod {
        /// Secret input consisting of the raw quotient tokens followed by the
        /// raw remainder tokens, exactly as given.
        fn nondeterminism(
            quotient: [BFieldElement; 5],
            remainder: [BFieldElement; 5],
        ) -> NonDeterminism {
            let individual_tokens = [quotient, remainder].concat();

            NonDeterminism {
                individual_tokens,
                digests: Default::default(),
                ram: Default::default(),
            }
        }

        /// Secret input that divines `quotient` and `remainder`, both given as
        /// little-endian `u32` limbs (the limb order `prepare_state` and
        /// `rust_shadow` use for `[u32; 5]`).
        ///
        /// The limbs are reversed before encoding. `rust_shadow` reverses them
        /// back after decoding, so every test builds its secret input the same
        /// way.
        fn nondeterminism_from_limbs(quotient: [u32; 5], remainder: [u32; 5]) -> NonDeterminism {
            let to_tokens = |mut limbs: [u32; 5]| -> [BFieldElement; 5] {
                limbs.reverse();
                limbs.encode().try_into().unwrap()
            };

            Self::nondeterminism(to_tokens(quotient), to_tokens(remainder))
        }

        fn stack(&self, numerator: [u32; 5], denominator: [u32; 5]) -> Vec<BFieldElement> {
            let mut stack = self.init_stack_for_isolated_run();
            push_encodable(&mut stack, &numerator);
            push_encodable(&mut stack, &denominator);

            stack
        }

        /// Little-endian `u32` limbs of `value`, zero-padded to five limbs.
        ///
        /// # Panics
        /// Panics if `value` does not fit in 160 bits.
        fn to_u160_limbs(value: &BigUint) -> [u32; 5] {
            let mut limbs = value.to_u32_digits();
            assert!(limbs.len() <= 5, "value does not fit in a u160");
            limbs.resize(5, 0);
            limbs.try_into().unwrap()
        }

        /// Decodes five divined tokens (as produced by
        /// `nondeterminism_from_limbs`) back into a number.
        fn decode_divined_u160(tokens: &[BFieldElement]) -> BigUint {
            let mut limbs: [u32; 5] =
                *TasmObject::decode_iter(&mut tokens.iter().copied()).unwrap();
            limbs.reverse();
            BigUint::new(limbs.to_vec())
        }

        /// Initial state with the honest quotient and remainder as secret input.
        fn prepare_state(
            &self,
            numerator: [u32; 5],
            denominator: [u32; 5],
        ) -> AlgorithmInitialState {
            let stack = self.stack(numerator, denominator);

            let numerator = BigUint::new(numerator.to_vec());
            let denominator = BigUint::new(denominator.to_vec());

            let (quotient, remainder) = if denominator.is_zero() {
                // Special-casing to prevent crash in state-preparation.
                // Dividing by zero always crashes execution, so the values used
                // here are irrelevant.
                (BigUint::zero(), BigUint::zero())
            } else {
                numerator.div_rem(&denominator)
            };

            let nondeterminism = Self::nondeterminism_from_limbs(
                Self::to_u160_limbs(&quotient),
                Self::to_u160_limbs(&remainder),
            );

            AlgorithmInitialState {
                stack,
                nondeterminism,
            }
        }

        fn edge_case_points() -> Vec<[u32; 5]> {
            [2, 1 << 32, 1 << 96, u128::MAX]
                .into_iter()
                .flat_map(|p| [p.checked_sub(1), Some(p), p.checked_add(1)])
                .flatten()
                .map(u128_to_u160)
                .chain([[u32::MAX; 5]])
                .collect()
        }
    }

    impl Algorithm for DivMod {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut std::collections::HashMap<BFieldElement, BFieldElement>,
            nondeterminism: &NonDeterminism,
        ) {
            let denominator: [u32; 5] = pop_encodable(stack);
            let denominator = BigUint::new(denominator.to_vec());
            assert!(!denominator.is_zero());

            let numerator: [u32; 5] = pop_encodable(stack);
            let numerator = BigUint::new(numerator.to_vec());

            let quotient = Self::decode_divined_u160(&nondeterminism.individual_tokens[0..5]);
            let remainder = Self::decode_divined_u160(&nondeterminism.individual_tokens[5..10]);

            assert!(remainder < denominator);
            assert_eq!(numerator, &quotient * &denominator + &remainder);

            let quotient = Self::to_u160_limbs(&quotient);
            let remainder = Self::to_u160_limbs(&remainder);

            // Imitate use of static memory
            encode_to_memory(memory, STATIC_MEMORY_FIRST_ADDRESS - bfe!(4), &remainder);
            encode_to_memory(memory, STATIC_MEMORY_FIRST_ADDRESS - bfe!(9), &quotient);

            push_encodable(stack, &quotient);
            push_encodable(stack, &remainder);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _bench_case: Option<crate::test_prelude::BenchmarkCase>,
        ) -> AlgorithmInitialState {
            let mut rng = StdRng::from_seed(seed);
            let numerator: [u32; 5] = rng.random();
            let mut denominator: [u32; 5] = rng.random();

            // Limit size of denominator by shrinking its most significant limb.
            // `next_u32` can return 0; clamp it so the division cannot panic.
            let denominator_div = rng.next_u32().max(1);
            denominator[4] /= denominator_div;

            // A zero denominator is not a valid input for the happy path.
            if denominator.iter().all(|&limb| limb == 0) {
                denominator[0] = 1;
            }

            self.prepare_state(numerator, denominator)
        }

        fn corner_case_initial_states(&self) -> Vec<AlgorithmInitialState> {
            let edge_case_points = Self::edge_case_points();

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .map(|(l, r)| self.prepare_state(*l, *r))
                .collect()
        }
    }
}
413
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark suite for the `DivMod` snippet.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedAlgorithm::new(DivMod);
        shadowed_snippet.bench();
    }
}
423}