tasm_lib/arithmetic/u192/
safe_add.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::u192::overflowing_add::OverflowingAdd;
4use crate::prelude::*;
5
/// Sum two u192 values, crashing the VM on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [rhs: u192] [lhs: u192]
/// AFTER:  _ [sum: u192]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
///
/// ### Postconditions
///
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// - if `lhs + rhs` does not fit in a u192, the VM crashes with error id
///   `OVERFLOW_ERROR_ID` (600)
#[derive(Debug, Clone)]
pub struct SafeAdd;

impl SafeAdd {
    /// Error id used by the `assert` that rejects a u192 overflow.
    pub(crate) const OVERFLOW_ERROR_ID: i128 = 600;
}
28
29impl BasicSnippet for SafeAdd {
30    fn parameters(&self) -> Vec<(DataType, String)> {
31        vec![
32            (DataType::U192, "l".to_owned()),
33            (DataType::U192, "r".to_owned()),
34        ]
35    }
36
37    fn return_values(&self) -> Vec<(DataType, String)> {
38        vec![(DataType::U192, "sum".to_owned())]
39    }
40
41    fn entrypoint(&self) -> String {
42        "tasmlib_arithmetic_u192_safe_add".to_string()
43    }
44
45    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46        let add_code = OverflowingAdd::addition_code();
47
48        triton_asm! {
49            // BEFORE: _ r5 r4 r3 r2 r1 r0 l5 l4 l3 l2 l1 l0
50            // AFTER:  _ sum_5 sum_4 sum_3 sum_2 sum_1 sum_0
51            {self.entrypoint()}:
52                {&add_code}
53                // _ sum_5 sum_4 sum_3 sum_2 sum_1 sum_0 overflow
54
55                push 0
56                eq
57                assert error_id {Self::OVERFLOW_ERROR_ID}
58                // _ sum_5 sum_4 sum_3 sum_2 sum_1 sum_0
59                // _ [sum]
60
61                return
62        }
63    }
64}
65
#[cfg(test)]
mod tests {
    use num::BigUint;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u192::U192;
    use crate::arithmetic::u192::to_u192;
    use crate::arithmetic::u192::u128_to_u192_shl64;
    use crate::test_prelude::*;

    /// Number of `u32` limbs making up a [`U192`].
    const U192_LIMBS: usize = 6;

    /// Reference u192 addition: returns `None` if `lhs + rhs` needs more than
    /// [`U192_LIMBS`] 32-bit limbs.
    ///
    /// Shared by the Rust shadow and the corner-case filter so that both agree
    /// on exactly which inputs overflow.
    fn u192_checked_add(lhs: U192, rhs: U192) -> Option<U192> {
        let sum = BigUint::new(lhs.to_vec()) + BigUint::new(rhs.to_vec());
        let mut limbs = sum.to_u32_digits();
        if limbs.len() > U192_LIMBS {
            return None;
        }

        // `to_u32_digits` omits leading zero limbs; pad back to full width.
        limbs.resize(U192_LIMBS, 0);
        Some(limbs.try_into().unwrap())
    }

    /// Runs `SafeAdd` on the u192 values built from `left` (via `to_u192`) and
    /// `right` (via `u128_to_u192_shl64`), expecting failure with
    /// `SafeAdd::OVERFLOW_ERROR_ID`.
    fn assert_overflow_crash(left: u128, right: u128) {
        let left = to_u192(left, u64::MAX);
        let right = u128_to_u192_shl64(right);
        test_assertion_failure(
            &ShadowedClosure::new(SafeAdd),
            InitVmState::with_stack(SafeAdd.set_up_test_stack((left, right))),
            &[SafeAdd::OVERFLOW_ERROR_ID],
        );
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeAdd).test()
    }

    #[test]
    fn overflow_test() {
        for (left, right) in [
            (1 << 127, 1 << 127),
            (u128::MAX, u128::MAX),
            (u128::MAX, 1),
            (u128::MAX, 1 << 31),
            (u128::MAX, 1 << 32),
            (u128::MAX, 1 << 33),
            (u128::MAX, 1 << 63),
            (u128::MAX, 1 << 64),
            (u128::MAX, 1 << 65),
            (u128::MAX, 1 << 95),
            (u128::MAX, 1 << 96),
            (u128::MAX, 1 << 97),
            (u128::MAX - 1, 2),
        ]
        .into_iter()
        .flat_map(|(left, right)| [(left, right), (right, left)])
        {
            assert_overflow_crash(left, right);
        }

        for i in 0..128 {
            let left = 1 << i;
            let right = u128::MAX - left + 1;

            // Sanity check: the two halves wrap to exactly zero as u128s.
            assert_eq!(
                (0, true),
                left.overflowing_add(right),
                "i = {i}. a = {left}, b = {right}"
            );

            assert_overflow_crash(left, right);
        }
    }

    impl Closure for SafeAdd {
        type Args = <OverflowingAdd as Closure>::Args;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let left: U192 = pop_encodable(stack);
            let right: U192 = pop_encodable(stack);
            let Some(sum) = u192_checked_add(left, right) else {
                panic!("Overflow");
            };

            push_encodable(stack, &sum);
        }

        /// Random `(lhs, rhs)` pair whose sum is guaranteed to fit in a u192.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: U192 = rng.random();
            let lhs_as_biguint: BigUint = BigUint::new(lhs.to_vec());

            let u192_max = BigUint::from_bytes_be(&[0xFF; 24]);
            // Largest rhs for which `lhs + rhs` still fits in a u192.
            let max_rhs = &u192_max - &lhs_as_biguint;

            // Reduce one random 192-bit draw into `[0, max_rhs]`. Unlike
            // rejection sampling, this always terminates after a single draw,
            // even when `lhs` is at or near the u192 maximum.
            let mut rhs_bytes = [0u8; 24];
            rng.fill(&mut rhs_bytes);
            let rhs = BigUint::from_bytes_be(&rhs_bytes) % (max_rhs + 1u32);

            let mut rhs_limbs = rhs.to_u32_digits();
            rhs_limbs.resize(U192_LIMBS, 0);

            (lhs, rhs_limbs.try_into().unwrap())
        }

        /// All pairs of `OverflowingAdd` edge-case points whose sum does not
        /// overflow.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let edge_case_points = OverflowingAdd::edge_case_points();

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| u192_checked_add(l, r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }
}
201
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks `SafeAdd` through the `ShadowedClosure` harness.
    #[test]
    fn benchmark() {
        ShadowedClosure::new(SafeAdd).bench();
    }
}
211}