Skip to main content

tasm_lib/arithmetic/u192/
safe_add.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::u192::overflowing_add::OverflowingAdd;
4use crate::prelude::*;
5
/// Sum two u192 values, crashing the VM on overflow.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [rhs: u192] [lhs: u192]
/// AFTER:  _ [sum: u192]
/// ```
///
/// ### Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
///
/// ### Postconditions
///
/// - the output is properly [`BFieldCodec`] encoded
///
/// ### Crashes
///
/// - if `lhs + rhs` does not fit into a u192; the VM then fails with error ID
///   `SafeAdd::OVERFLOW_ERROR_ID`
#[derive(Debug, Clone)]
pub struct SafeAdd;

impl SafeAdd {
    /// Error ID with which the VM crashes when the sum overflows a u192.
    pub(crate) const OVERFLOW_ERROR_ID: i128 = 600;
}
28
29impl BasicSnippet for SafeAdd {
30    fn parameters(&self) -> Vec<(DataType, String)> {
31        vec![
32            (DataType::U192, "l".to_owned()),
33            (DataType::U192, "r".to_owned()),
34        ]
35    }
36
37    fn return_values(&self) -> Vec<(DataType, String)> {
38        vec![(DataType::U192, "sum".to_owned())]
39    }
40
41    fn entrypoint(&self) -> String {
42        "tasmlib_arithmetic_u192_safe_add".to_string()
43    }
44
45    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46        let add_code = OverflowingAdd::addition_code();
47
48        triton_asm! {
49            // BEFORE: _ r5 r4 r3 r2 r1 r0 l5 l4 l3 l2 l1 l0
50            // AFTER:  _ sum_5 sum_4 sum_3 sum_2 sum_1 sum_0
51            {self.entrypoint()}:
52                {&add_code}
53                // _ sum_5 sum_4 sum_3 sum_2 sum_1 sum_0 overflow
54
55                push 0
56                eq
57                assert error_id {Self::OVERFLOW_ERROR_ID}
58                // _ sum_5 sum_4 sum_3 sum_2 sum_1 sum_0
59                // _ [sum]
60
61                return
62        }
63    }
64}
65
#[cfg(test)]
mod tests {
    use num::BigUint;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u192::U192;
    use crate::arithmetic::u192::to_u192;
    use crate::arithmetic::u192::u128_to_u192_shl64;
    use crate::test_prelude::*;

    /// Number of 32-bit limbs making up a [`U192`].
    const NUM_LIMBS: usize = 6;

    /// Convert `value` into a [`U192`], or return `None` if it needs more than
    /// [`NUM_LIMBS`] 32-bit limbs, i.e., more than 192 bits.
    fn u192_from_biguint(value: &BigUint) -> Option<U192> {
        let mut limbs = value.to_u32_digits();
        if limbs.len() > NUM_LIMBS {
            return None;
        }

        // `to_u32_digits` omits most-significant zero limbs; pad back to the
        // full width before converting.
        limbs.resize(NUM_LIMBS, 0);
        Some(limbs.try_into().unwrap())
    }

    /// Add two [`U192`]s, returning `None` if the sum does not fit in 192 bits.
    fn u192_checked_add(lhs: U192, rhs: U192) -> Option<U192> {
        let sum = BigUint::new(lhs.to_vec()) + BigUint::new(rhs.to_vec());
        u192_from_biguint(&sum)
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedClosure::new(SafeAdd).test();
    }

    #[macro_rules_attr::apply(test)]
    fn overflow_test() {
        // Run the snippet on `args` and require a crash with the overflow error.
        fn assert_overflow_crash(args: <SafeAdd as Closure>::Args) {
            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack(args)),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }

        for (left, right) in [
            (1 << 127, 1 << 127),
            (u128::MAX, u128::MAX),
            (u128::MAX, 1),
            (u128::MAX, 1 << 31),
            (u128::MAX, 1 << 32),
            (u128::MAX, 1 << 33),
            (u128::MAX, 1 << 63),
            (u128::MAX, 1 << 64),
            (u128::MAX, 1 << 65),
            (u128::MAX, 1 << 95),
            (u128::MAX, 1 << 96),
            (u128::MAX, 1 << 97),
            (u128::MAX - 1, 2),
        ]
        .into_iter()
        .flat_map(|(left, right)| [(left, right), (right, left)])
        {
            let left = to_u192(left, u64::MAX);
            let right = u128_to_u192_shl64(right);
            assert_overflow_crash((left, right));
        }

        for i in 0..128 {
            let left = 1 << i;
            let right = u128::MAX - left + 1;

            assert_eq!(
                (0, true),
                left.overflowing_add(right),
                "i = {i}. a = {left}, b = {right}"
            );

            let left = to_u192(left, u64::MAX);
            let right = u128_to_u192_shl64(right);
            assert_overflow_crash((left, right));
        }
    }

    impl Closure for SafeAdd {
        type Args = <OverflowingAdd as Closure>::Args;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let left: U192 = pop_encodable(stack)?;
            let right: U192 = pop_encodable(stack)?;
            let sum =
                u192_checked_add(left, right).ok_or(RustShadowError::ArithmeticOverflow)?;

            push_encodable(stack, &sum);
            Ok(())
        }

        /// A uniformly random `lhs` and an `rhs` in `[0, u192::MAX - lhs]`, so
        /// the sum never overflows.
        ///
        /// `rhs` is a random 192-bit value reduced modulo the size of that
        /// range. This is slightly non-uniform, but it always finishes after a
        /// single draw. Rejection sampling would never terminate if `lhs` were
        /// `u192::MAX`, and it degrades badly when `lhs` is close to it.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: U192 = rng.random();
            let lhs_as_biguint = BigUint::new(lhs.to_vec());

            let u192_max = BigUint::from_bytes_be(&[0xFF; 24]);
            let max_rhs = &u192_max - &lhs_as_biguint;

            let mut rhs_bytes = [0u8; 24];
            rng.fill(&mut rhs_bytes);
            let rhs = BigUint::from_bytes_be(&rhs_bytes) % (max_rhs + 1u32);

            let rhs = u192_from_biguint(&rhs).expect("rhs <= u192::MAX - lhs fits in 192 bits");
            (lhs, rhs)
        }

        fn corner_case_args(&self) -> Vec<Self::Args> {
            let edge_case_points = OverflowingAdd::edge_case_points();

            // Keep only pairs whose sum fits in 192 bits; the snippet crashes
            // on all other pairs.
            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| u192_checked_add(l, r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }
}
204
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the `SafeAdd` snippet through its shadowed-closure harness.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = ShadowedClosure::new(SafeAdd);
        snippet.bench();
    }
}