Skip to main content

tasm_lib/arithmetic/u128/
safe_add.rs

use triton_vm::prelude::*;

use crate::arithmetic::u128::overflowing_add::OverflowingAdd;
use crate::prelude::*;

/// Snippet that adds two `u128`s and fails an assertion in the VM if the sum
/// does not fit into a `u128`.
///
/// The type carries no state; it only serves as the handle on which the
/// snippet's traits are implemented.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SafeAdd;

9impl SafeAdd {
10    pub(crate) const OVERFLOW_ERROR_ID: i128 = 170;
11}
12
13impl BasicSnippet for SafeAdd {
14    fn parameters(&self) -> Vec<(DataType, String)> {
15        vec![
16            (DataType::U128, "lhs".to_owned()),
17            (DataType::U128, "rhs".to_owned()),
18        ]
19    }
20
21    fn return_values(&self) -> Vec<(DataType, String)> {
22        vec![(DataType::U128, "sum".to_owned())]
23    }
24
25    fn entrypoint(&self) -> String {
26        "tasmlib_arithmetic_u128_safe_add".to_string()
27    }
28
29    /// Four top elements of stack are assumed to be valid u32s. So to have
30    /// a value that's less than 2^32.
31    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
32        let add_code = OverflowingAdd::addition_code();
33
34        triton_asm! {
35            // BEFORE: _ rhs_3 rhs_2 rhs_1 rhs_0 lhs_3 lhs_2 lhs_1 lhs_0
36            // AFTER:  _ sum_3 sum_2 sum_1 sum_0
37            {self.entrypoint()}:
38                {&add_code}
39                // _ sum_3 sum_2 sum_1 sum_0 overflow
40
41                push 0
42                eq
43                assert error_id {Self::OVERFLOW_ERROR_ID}
44                return
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;

    use super::*;
    use crate::test_prelude::*;

    impl SafeAdd {
        /// Checks that running the snippet on `(lhs, rhs)` yields the same
        /// final stack as the Rust shadow.
        ///
        /// Panics if the Rust shadow reports an error, which it does when
        /// `lhs + rhs` overflows.
        fn assert_expected_add_behavior(&self, lhs: u128, rhs: u128) {
            let initial_stack = self.set_up_test_stack((lhs, rhs));

            // The expected final stack is whatever the Rust shadow produces
            // from the same starting point.
            let mut expected_stack = initial_stack.clone();
            self.rust_shadow(&mut expected_stack).unwrap();

            test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(Self),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_stack),
            );
        }
    }

    impl Closure for SafeAdd {
        type Args = <OverflowingAdd as Closure>::Args;

        /// Reference implementation: pops both operands, pushes their sum,
        /// and reports an arithmetic-overflow error if the sum does not fit.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let (left, right) = pop_encodable::<Self::Args>(stack)?;
            let sum = left
                .checked_add(right)
                .ok_or(RustShadowError::ArithmeticOverflow)?;
            push_encodable(stack, &sum);
            Ok(())
        }

        /// Generates a pair of operands whose sum is guaranteed not to
        /// overflow: `rhs` is drawn from `0..=u128::MAX - lhs`.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs = rng.random();
            let rhs = rng.random_range(0..=u128::MAX - lhs);

            (lhs, rhs)
        }

        /// All pairs of [`OverflowingAdd`]'s edge-case points whose sum does
        /// not overflow.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let edge_case_points = OverflowingAdd::edge_case_points();

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| l.checked_add(r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedClosure::new(SafeAdd).test()
    }

    #[macro_rules_attr::apply(test)]
    fn unit_test() {
        SafeAdd.assert_expected_add_behavior(1 << 67, 1 << 67)
    }

    /// Every overflowing pair of operands must trip the overflow assertion.
    #[macro_rules_attr::apply(test)]
    fn overflow_test() {
        // Hand-picked overflowing pairs, each tried in both operand orders.
        for args in [
            (1 << 127, 1 << 127),
            (u128::MAX, u128::MAX),
            (u128::MAX, 1),
            (u128::MAX, 1 << 31),
            (u128::MAX, 1 << 32),
            (u128::MAX, 1 << 33),
            (u128::MAX, 1 << 63),
            (u128::MAX, 1 << 64),
            (u128::MAX, 1 << 65),
            (u128::MAX, 1 << 95),
            (u128::MAX, 1 << 96),
            (u128::MAX, 1 << 97),
            (u128::MAX - 1, 2),
        ]
        .into_iter()
        .flat_map(|(left, right)| [(left, right), (right, left)])
        {
            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack(args)),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }

        // For every power of two `a`, pick `b` such that `a + b` is exactly
        // 2^128, i.e., the smallest overflowing sum: it wraps around to 0.
        for i in 0..128 {
            let a = 1_u128 << i;
            let b = u128::MAX - a + 1;

            assert_eq!((0, true), a.overflowing_add(b), "i = {i}. a = {a}, b = {b}");
            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack((a, b))),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }
    }
}

#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the snippet's benchmark through the shadowed-closure harness.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        ShadowedClosure::new(SafeAdd).bench()
    }
}
166}