tasm_lib/arithmetic/u192/
safe_add.rs1use triton_vm::prelude::*;
2
3use crate::arithmetic::u192::overflowing_add::OverflowingAdd;
4use crate::prelude::*;
5
/// Snippet for adding two `u192` values.
///
/// The generated code asserts that the addition did not overflow; see the
/// `BasicSnippet` implementation for the emitted instructions.
#[derive(Clone, Debug)]
pub struct SafeAdd;
8
9impl SafeAdd {
26 pub(crate) const OVERFLOW_ERROR_ID: i128 = 600;
27}
28
29impl BasicSnippet for SafeAdd {
30 fn parameters(&self) -> Vec<(DataType, String)> {
31 vec![
32 (DataType::U192, "l".to_owned()),
33 (DataType::U192, "r".to_owned()),
34 ]
35 }
36
37 fn return_values(&self) -> Vec<(DataType, String)> {
38 vec![(DataType::U192, "sum".to_owned())]
39 }
40
41 fn entrypoint(&self) -> String {
42 "tasmlib_arithmetic_u192_safe_add".to_string()
43 }
44
45 fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
46 let add_code = OverflowingAdd::addition_code();
47
48 triton_asm! {
49 {self.entrypoint()}:
52 {&add_code}
53 push 0
56 eq
57 assert error_id {Self::OVERFLOW_ERROR_ID}
58 return
62 }
63 }
64}
65
#[cfg(test)]
mod tests {
    use num::BigUint;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u192::U192;
    use crate::arithmetic::u192::to_u192;
    use crate::arithmetic::u192::u128_to_u192_shl64;
    use crate::test_prelude::*;

    /// Number of `u32` limbs that make up a [`U192`].
    const NUM_LIMBS: usize = 6;

    /// Converts a [`U192`] into a [`BigUint`].
    ///
    /// `BigUint::new` reads base-2^32 digits least significant first, which is
    /// also the order `BigUint::to_u32_digits` produces on the way back.
    fn to_biguint(value: U192) -> BigUint {
        BigUint::new(value.to_vec())
    }

    /// Adds two [`U192`]s, returning `None` if the sum needs more than
    /// [`NUM_LIMBS`] limbs.
    fn u192_checked_add(l: U192, r: U192) -> Option<U192> {
        let mut sum = (to_biguint(l) + to_biguint(r)).to_u32_digits();
        if sum.len() > NUM_LIMBS {
            return None;
        }

        // `to_u32_digits` drops leading zero limbs; pad back to full width.
        sum.resize(NUM_LIMBS, 0);

        Some(
            sum.try_into()
                .expect("sum was resized to exactly `NUM_LIMBS` limbs"),
        )
    }

    /// Runs [`SafeAdd`] on the given operands and expects the VM to fail with
    /// [`SafeAdd::OVERFLOW_ERROR_ID`].
    fn assert_overflow_is_caught(left: U192, right: U192) {
        test_assertion_failure(
            &ShadowedClosure::new(SafeAdd),
            InitVmState::with_stack(SafeAdd.set_up_test_stack((left, right))),
            &[SafeAdd::OVERFLOW_ERROR_ID],
        );
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedClosure::new(SafeAdd).test();
    }

    /// Checks that a range of overflowing operand pairs trips the overflow
    /// assertion.
    ///
    /// The `u128` pairs are widened with `to_u192(_, u64::MAX)` and
    /// `u128_to_u192_shl64` before being fed to the snippet.
    // NOTE(review): presumably `to_u192` places `u64::MAX` in the top 64 bits
    // and `u128_to_u192_shl64` shifts its argument left by 64 bits; confirm
    // against the helpers in `arithmetic::u192`.
    #[macro_rules_attr::apply(test)]
    fn overflow_test() {
        for (left, right) in [
            (1 << 127, 1 << 127),
            (u128::MAX, u128::MAX),
            (u128::MAX, 1),
            (u128::MAX, 1 << 31),
            (u128::MAX, 1 << 32),
            (u128::MAX, 1 << 33),
            (u128::MAX, 1 << 63),
            (u128::MAX, 1 << 64),
            (u128::MAX, 1 << 65),
            (u128::MAX, 1 << 95),
            (u128::MAX, 1 << 96),
            (u128::MAX, 1 << 97),
            (u128::MAX - 1, 2),
        ]
        .into_iter()
        // Exercise every pair in both operand orders.
        .flat_map(|(left, right)| [(left, right), (right, left)])
        {
            assert_overflow_is_caught(to_u192(left, u64::MAX), u128_to_u192_shl64(right));
        }

        for i in 0..128 {
            let left = 1_u128 << i;
            let right = u128::MAX - left + 1;

            // Sanity check: as `u128`s, the pair wraps around to exactly 0.
            assert_eq!(
                (0, true),
                left.overflowing_add(right),
                "i = {i}. a = {left}, b = {right}"
            );

            assert_overflow_is_caught(to_u192(left, u64::MAX), u128_to_u192_shl64(right));
        }
    }

    impl Closure for SafeAdd {
        type Args = <OverflowingAdd as Closure>::Args;

        /// Pops two [`U192`]s, pushes their sum, and reports
        /// `RustShadowError::ArithmeticOverflow` if the sum does not fit.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let left: U192 = pop_encodable(stack)?;
            let right: U192 = pop_encodable(stack)?;
            let sum = u192_checked_add(left, right).ok_or(RustShadowError::ArithmeticOverflow)?;

            push_encodable(stack, &sum);
            Ok(())
        }

        /// Produces a pseudorandom pair of operands whose sum does not
        /// overflow.
        ///
        /// The left-hand side is drawn freely; the right-hand side is found by
        /// rejection sampling over all 192-bit values.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: U192 = rng.random();

            // Largest right-hand side for which `lhs + rhs` still fits.
            let u192_max = BigUint::from_bytes_be(&[0xFF; 24]);
            let max = u192_max - to_biguint(lhs);

            let mut rhs_bytes = [0u8; 24];
            let rhs = loop {
                rng.fill(&mut rhs_bytes);
                let candidate = BigUint::from_bytes_be(&rhs_bytes);

                // The bound is inclusive: `lhs + max` is exactly the largest
                // representable value. A strict comparison could never
                // terminate for `lhs == U192::MAX`, where `max` is 0.
                if candidate <= max {
                    break candidate;
                }
            };

            let mut rhs = rhs.to_u32_digits();
            rhs.resize(NUM_LIMBS, 0);

            (lhs, rhs.try_into().unwrap())
        }

        /// All pairs of [`OverflowingAdd`]'s edge case points whose sum does
        /// not overflow.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let edge_case_points = OverflowingAdd::edge_case_points();

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| u192_checked_add(l, r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }
}
204
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark of [`SafeAdd`] through the `ShadowedClosure` harness.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(SafeAdd);
        shadowed_snippet.bench();
    }
}