Skip to main content

tasm_lib/arithmetic/u160/
safe_add.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::u160::overflowing_add::OverflowingAdd;
4use crate::prelude::*;
5
/// Sum two u160 values, crashing the VM on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [rhs: u160] [lhs: u160]
/// AFTER:  _ [sum: u160]
/// ```
///
/// Addition is commutative, so the relative stack position of the two
/// operands does not affect the result.
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
///
/// ### Postconditions
///
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// - if `lhs + rhs` does not fit in 160 bits; the VM then halts with error
///   ID 570 (`SafeAdd::OVERFLOW_ERROR_ID`)
#[derive(Debug, Clone)]
pub struct SafeAdd;

impl SafeAdd {
    /// Error ID of the `assert` that fires when the sum overflows 160 bits.
    pub(crate) const OVERFLOW_ERROR_ID: i128 = 570;
}
28
29impl BasicSnippet for SafeAdd {
30    fn parameters(&self) -> Vec<(DataType, String)> {
31        vec![
32            (DataType::U160, "lhs".to_owned()),
33            (DataType::U160, "rhs".to_owned()),
34        ]
35    }
36
37    fn return_values(&self) -> Vec<(DataType, String)> {
38        vec![(DataType::U160, "sum".to_owned())]
39    }
40
41    fn entrypoint(&self) -> String {
42        "tasmlib_arithmetic_u160_safe_add".to_string()
43    }
44
45    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46        let add_code = OverflowingAdd::addition_code();
47
48        triton_asm! {
49            // BEFORE: _ rhs_4 rhs_3 rhs_2 rhs_1 rhs_0 lhs_4 lhs_3 lhs_2 lhs_1 lhs_0
50            // AFTER:  _ sum_4 sum_3 sum_2 sum_1 sum_0
51            {self.entrypoint()}:
52                {&add_code}
53                // _ sum_4 sum_3 sum_2 sum_1 sum_0 overflow
54
55                push 0
56                eq
57                assert error_id {Self::OVERFLOW_ERROR_ID}
58                // _ sum_4 sum_3 sum_2 sum_1 sum_0
59                // _ [sum]
60
61                return
62        }
63    }
64}
65
#[cfg(test)]
mod tests {
    use num::BigUint;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u160::u128_to_u160_shl_32;
    use crate::arithmetic::u160::u128_to_u160_shl_32_lower_limb_filled;
    use crate::test_prelude::*;

    /// Reference u160 addition on five `u32` limbs.
    ///
    /// Limbs are interpreted by [`BigUint::new`], i.e. little-endian (index 0
    /// is the least significant limb). Returns `None` if the sum needs more
    /// than five limbs (overflows 160 bits); otherwise returns the sum,
    /// zero-padded back to exactly five limbs.
    fn u160_checked_add(lhs: [u32; 5], rhs: [u32; 5]) -> Option<[u32; 5]> {
        let sum = BigUint::new(lhs.to_vec()) + BigUint::new(rhs.to_vec());
        let mut limbs = sum.to_u32_digits();
        if limbs.len() > 5 {
            return None;
        }

        // `to_u32_digits` drops leading zero limbs; restore them.
        limbs.resize(5, 0);
        Some(limbs.try_into().expect("limb vector was resized to length 5"))
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedClosure::new(SafeAdd).test()
    }

    // Every pair below must make the snippet crash with the overflow error.
    #[macro_rules_attr::apply(test)]
    fn overflow_test() {
        for (left, right) in [
            (1 << 127, 1 << 127),
            (u128::MAX, u128::MAX),
            (u128::MAX, 1),
            (u128::MAX, 1 << 31),
            (u128::MAX, 1 << 32),
            (u128::MAX, 1 << 33),
            (u128::MAX, 1 << 63),
            (u128::MAX, 1 << 64),
            (u128::MAX, 1 << 65),
            (u128::MAX, 1 << 95),
            (u128::MAX, 1 << 96),
            (u128::MAX, 1 << 97),
            (u128::MAX - 1, 2),
        ]
        .into_iter()
        .flat_map(|(left, right)| [(left, right), (right, left)])
        {
            let left = u128_to_u160_shl_32_lower_limb_filled(left);
            let right = u128_to_u160_shl_32(right);
            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack((left, right))),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }

        // Pairs whose u128 sum wraps to exactly zero.
        for i in 0..128 {
            let left = 1 << i;
            let right = u128::MAX - left + 1;

            assert_eq!(
                (0, true),
                left.overflowing_add(right),
                "i = {i}. a = {left}, b = {right}"
            );

            let left = u128_to_u160_shl_32_lower_limb_filled(left);
            let right = u128_to_u160_shl_32(right);

            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack((left, right))),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }
    }

    impl Closure for SafeAdd {
        type Args = <OverflowingAdd as Closure>::Args;

        /// Pops both operands (`lhs` first, being on top), pushes their sum, and
        /// reports [`RustShadowError::ArithmeticOverflow`] if it exceeds 160 bits.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let left: [u32; 5] = pop_encodable(stack)?;
            let right: [u32; 5] = pop_encodable(stack)?;
            let sum =
                u160_checked_add(left, right).ok_or(RustShadowError::ArithmeticOverflow)?;

            push_encodable(stack, &sum);
            Ok(())
        }

        /// Random `(lhs, rhs)` whose sum is guaranteed to fit in 160 bits.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: [u32; 5] = rng.random();
            let lhs_as_biguint = BigUint::new(lhs.to_vec());

            // Largest `rhs` for which `lhs + rhs` still fits in 160 bits.
            let u160_max = BigUint::from_bytes_be(&[0xFF; 20]);
            let max_rhs = &u160_max - &lhs_as_biguint;

            // Reduce a single random 160-bit draw into `[0, max_rhs]`. This
            // always terminates, whereas rejection sampling over the full
            // 160-bit range needs ~2^160 / max_rhs draws and never finishes
            // when `lhs` is u160::MAX. The slight modulo bias is irrelevant
            // for test inputs.
            let mut rhs_bytes = [0u8; 20];
            rng.fill(&mut rhs_bytes);
            let rhs = BigUint::from_bytes_be(&rhs_bytes) % (max_rhs + 1u32);

            let mut rhs = rhs.to_u32_digits();
            rhs.resize(5, 0);

            (lhs, rhs.try_into().expect("rhs <= u160::MAX fits in five limbs"))
        }

        /// All pairs of `OverflowingAdd`'s edge-case points, except those whose
        /// sum overflows 160 bits (on which `SafeAdd` would crash).
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let edge_case_points = OverflowingAdd::edge_case_points();

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| u160_checked_add(l, r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }
}
203
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    // Benchmark the snippet via its Rust-shadowed closure wrapper.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        ShadowedClosure::new(SafeAdd).bench();
    }
}