// tasm_lib/arithmetic/u128/safe_add.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::u128::overflowing_add::OverflowingAdd;
4use crate::prelude::*;
5
/// Snippet that adds two `u128` values and fails the VM's `assert` with
/// [`SafeAdd::OVERFLOW_ERROR_ID`] if the sum overflows.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct SafeAdd;
8
9impl SafeAdd {
10    pub(crate) const OVERFLOW_ERROR_ID: i128 = 170;
11}
12
13impl BasicSnippet for SafeAdd {
14    fn parameters(&self) -> Vec<(DataType, String)> {
15        vec![
16            (DataType::U128, "lhs".to_owned()),
17            (DataType::U128, "rhs".to_owned()),
18        ]
19    }
20
21    fn return_values(&self) -> Vec<(DataType, String)> {
22        vec![(DataType::U128, "sum".to_owned())]
23    }
24
25    fn entrypoint(&self) -> String {
26        "tasmlib_arithmetic_u128_safe_add".to_string()
27    }
28
29    /// Four top elements of stack are assumed to be valid u32s. So to have
30    /// a value that's less than 2^32.
31    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
32        let add_code = OverflowingAdd::addition_code();
33
34        triton_asm! {
35            // BEFORE: _ rhs_3 rhs_2 rhs_1 rhs_0 lhs_3 lhs_2 lhs_1 lhs_0
36            // AFTER:  _ sum_3 sum_2 sum_1 sum_0
37            {self.entrypoint()}:
38                {&add_code}
39                // _ sum_3 sum_2 sum_1 sum_0 overflow
40
41                push 0
42                eq
43                assert error_id {Self::OVERFLOW_ERROR_ID}
44                return
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;

    use super::*;
    use crate::test_prelude::*;

    /// Asserts, via `test_assertion_failure`, that running the snippet on
    /// `args` fails with [`SafeAdd::OVERFLOW_ERROR_ID`].
    fn assert_overflow_is_caught(args: (u128, u128)) {
        let snippet = ShadowedClosure::new(SafeAdd);
        let initial_state = InitVmState::with_stack(SafeAdd.set_up_test_stack(args));
        test_assertion_failure(&snippet, initial_state, &[SafeAdd::OVERFLOW_ERROR_ID]);
    }

    impl SafeAdd {
        /// Runs the snippet on `(lhs, rhs)` and checks that the resulting
        /// stack equals the one produced by the Rust shadow.
        fn assert_expected_add_behavior(&self, lhs: u128, rhs: u128) {
            let init_stack = self.set_up_test_stack((lhs, rhs));
            let expected = {
                let mut stack = init_stack.clone();
                self.rust_shadow(&mut stack);
                stack
            };

            test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(Self),
                &init_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected),
            );
        }
    }

    impl Closure for SafeAdd {
        type Args = <OverflowingAdd as Closure>::Args;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (lhs, rhs) = pop_encodable::<Self::Args>(stack);
            let total = lhs.checked_add(rhs).unwrap();
            push_encodable(stack, &total);
        }

        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: u128 = rng.random();

            // Draw `rhs` from the range that keeps `lhs + rhs` from overflowing.
            let headroom = u128::MAX - lhs;
            let rhs = rng.random_range(0..=headroom);

            (lhs, rhs)
        }

        fn corner_case_args(&self) -> Vec<Self::Args> {
            let points = OverflowingAdd::edge_case_points();

            points
                .iter()
                .cartesian_product(&points)
                .map(|(&lhs, &rhs)| (lhs, rhs))
                .filter(|&(lhs, rhs)| lhs.checked_add(rhs).is_some())
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeAdd).test()
    }

    #[test]
    fn unit_test() {
        SafeAdd.assert_expected_add_behavior(1 << 67, 1 << 67)
    }

    #[test]
    fn overflow_test() {
        let overflowing_pairs: [(u128, u128); 13] = [
            (1 << 127, 1 << 127),
            (u128::MAX, u128::MAX),
            (u128::MAX, 1),
            (u128::MAX, 1 << 31),
            (u128::MAX, 1 << 32),
            (u128::MAX, 1 << 33),
            (u128::MAX, 1 << 63),
            (u128::MAX, 1 << 64),
            (u128::MAX, 1 << 65),
            (u128::MAX, 1 << 95),
            (u128::MAX, 1 << 96),
            (u128::MAX, 1 << 97),
            (u128::MAX - 1, 2),
        ];

        // Exercise both operand orders for every pair.
        for (left, right) in overflowing_pairs {
            assert_overflow_is_caught((left, right));
            assert_overflow_is_caught((right, left));
        }

        // 2^i + (2^128 - 2^i) wraps around to exactly zero for every i.
        for i in 0..128 {
            let a: u128 = 1 << i;
            let b = u128::MAX - a + 1;

            assert_eq!((0, true), a.overflowing_add(b), "i = {i}. a = {a}, b = {b}");
            assert_overflow_is_caught((a, b));
        }
    }
}
153
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet through its `ShadowedClosure` wrapper.
    #[test]
    fn benchmark() {
        let shadowed = ShadowedClosure::new(SafeAdd);
        shadowed.bench()
    }
}
163}