tasm_lib/arithmetic/u128/safe_add.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::u128::overflowing_add::OverflowingAdd;
4use crate::prelude::*;
5
/// Adds two `u128`s and crashes the VM if the sum does not fit in a `u128`.
///
/// On overflow, execution fails with error id [`SafeAdd::OVERFLOW_ERROR_ID`].
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SafeAdd;

impl SafeAdd {
    /// Error id of the assertion that fails when `lhs + rhs` overflows a `u128`.
    pub(crate) const OVERFLOW_ERROR_ID: i128 = 170;
}
12
13impl BasicSnippet for SafeAdd {
14    fn inputs(&self) -> Vec<(DataType, String)> {
15        vec![
16            (DataType::U128, "lhs".to_owned()),
17            (DataType::U128, "rhs".to_owned()),
18        ]
19    }
20
21    fn outputs(&self) -> Vec<(DataType, String)> {
22        vec![(DataType::U128, "sum".to_owned())]
23    }
24
25    fn entrypoint(&self) -> String {
26        "tasmlib_arithmetic_u128_safe_add".to_string()
27    }
28
29    /// Four top elements of stack are assumed to be valid u32s. So to have
30    /// a value that's less than 2^32.
31    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
32        let add_code = OverflowingAdd::addition_code();
33
34        triton_asm! {
35            // BEFORE: _ rhs_3 rhs_2 rhs_1 rhs_0 lhs_3 lhs_2 lhs_1 lhs_0
36            // AFTER:  _ sum_3 sum_2 sum_1 sum_0
37            {self.entrypoint()}:
38                {&add_code}
39                // _ sum_3 sum_2 sum_1 sum_0 overflow
40
41                push 0
42                eq
43                assert error_id {Self::OVERFLOW_ERROR_ID}
44                return
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;
    use rand::rngs::StdRng;

    impl SafeAdd {
        /// Runs the snippet on `(lhs, rhs)` and checks that the resulting
        /// stack equals the one produced by the Rust shadow.
        ///
        /// Callers must pick arguments whose sum fits in a `u128`; otherwise
        /// the shadow's `unwrap` panics.
        fn assert_expected_add_behavior(&self, lhs: u128, rhs: u128) {
            let initial_stack = self.set_up_test_stack((lhs, rhs));

            let mut expected_stack = initial_stack.clone();
            self.rust_shadow(&mut expected_stack);

            test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(Self),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_stack),
            );
        }
    }

    impl Closure for SafeAdd {
        type Args = <OverflowingAdd as Closure>::Args;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (left, right) = pop_encodable::<Self::Args>(stack);
            let sum = left.checked_add(right).unwrap();
            push_encodable(stack, &sum);
        }

        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs = rng.random();
            // Bound `rhs` so that `lhs + rhs` never overflows.
            let rhs = rng.random_range(0..=u128::MAX - lhs);

            (lhs, rhs)
        }

        fn corner_case_args(&self) -> Vec<Self::Args> {
            let edge_case_points = OverflowingAdd::edge_case_points();

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| l.checked_add(r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeAdd).test()
    }

    #[test]
    fn unit_test() {
        SafeAdd.assert_expected_add_behavior(1 << 67, 1 << 67)
    }

    /// The largest non-overflowing sums, `u128::MAX`, must be accepted for
    /// every single-bit split of the operands and in both operand orders.
    /// Together with `overflow_test`, this pins the overflow boundary exactly.
    #[test]
    fn sum_equal_to_max_does_not_overflow() {
        for i in 0..128 {
            let a: u128 = 1 << i;
            let b = u128::MAX - a;

            assert_eq!(Some(u128::MAX), a.checked_add(b), "i = {i}. a = {a}, b = {b}");
            SafeAdd.assert_expected_add_behavior(a, b);
            SafeAdd.assert_expected_add_behavior(b, a);
        }
    }

    #[test]
    fn overflow_test() {
        for args in [
            (1 << 127, 1 << 127),
            (u128::MAX, u128::MAX),
            (u128::MAX, 1),
            (u128::MAX, 1 << 31),
            (u128::MAX, 1 << 32),
            (u128::MAX, 1 << 33),
            (u128::MAX, 1 << 63),
            (u128::MAX, 1 << 64),
            (u128::MAX, 1 << 65),
            (u128::MAX, 1 << 95),
            (u128::MAX, 1 << 96),
            (u128::MAX, 1 << 97),
            (u128::MAX - 1, 2),
        ]
        .into_iter()
        .flat_map(|(left, right)| [(left, right), (right, left)])
        {
            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack(args)),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }

        // Smallest overflowing sums: exactly 2^128 for every single-bit `a`.
        for i in 0..128 {
            let a = 1 << i;
            let b = u128::MAX - a + 1;

            assert_eq!((0, true), a.overflowing_add(b), "i = {i}. a = {a}, b = {b}");
            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack((a, b))),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }
    }
}
152
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks `SafeAdd` via its Rust-shadowed closure wrapper.
    #[test]
    fn benchmark() {
        let snippet = SafeAdd;
        ShadowedClosure::new(snippet).bench();
    }
}
162}