1use triton_vm::prelude::*;
2
3use crate::prelude::*;
4
/// Exponentiation of two `u32`s, `base^exponent`, that crashes the VM instead
/// of wrapping when the result does not fit in a `u32`.
///
/// An `exponent` of zero yields `1` (including `0^0`). Overflow is reported via
/// assertion error id 120 or 121.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct SafePow;
8
impl BasicSnippet for SafePow {
    /// Two `u32` operands; `exponent` is expected on top of the stack.
    fn parameters(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::U32, "base".to_owned()),
            (DataType::U32, "exponent".to_owned()),
        ]
    }

    /// A single `u32`: `base^exponent`.
    fn return_values(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U32, "result".to_owned())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_u32_safe_pow".to_string()
    }

    /// Square-and-multiply over the bits of `exponent`, least significant bit first.
    ///
    /// The loop keeps `_ bpow2_hi bpow2_lo exp acc` on the stack, where after `i`
    /// iterations `bpow2 = base^(2^i)` (as the high/low halves produced by `split`),
    /// `exp` is the not-yet-consumed part of the exponent (`exponent >> i`), and
    /// `acc` is the partial product.
    ///
    /// Crashes with
    /// - error 120 if `bpow2` no longer fits in a `u32` while exponent bits remain;
    /// - error 121 if `acc * bpow2` does not fit in a `u32`.
    ///
    /// NOTE(review): the operands are assumed to be valid `u32`s as declared in
    /// `parameters`; the snippet does not range-check `base` itself.
    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        let while_acc_label = format!("{entrypoint}_while_acc");
        let mul_acc_with_bpow2_label = format!("{entrypoint}_mul_acc_with_bpow2");
        triton_asm!(
            // BEFORE: _ base exponent
            // AFTER:  _ result
            {entrypoint}:
                push 0
                // _ base exponent 0
                swap 2
                swap 1
                // _ 0 base exponent
                push 1
                // _ bpow2_hi bpow2_lo exp acc   (bpow2 = base, acc = 1)
                call {while_acc_label}
                // _ bpow2_hi bpow2_lo 0 acc
                swap 3
                // _ acc bpow2_lo 0 bpow2_hi
                pop 3
                // _ acc
                return

            // INVARIANT: _ bpow2_hi bpow2_lo exp acc
            {while_acc_label}:
                // Done once every exponent bit has been consumed.
                dup 1 push 0 eq
                skiz
                    return
                // Exponent bits remain, so exponent >= 2^i and, for base >= 2, the
                // result is at least bpow2 = base^(2^i). If bpow2 already exceeds a
                // u32, so does the result. (For base 0 or 1, bpow2 never overflows.)
                dup 3 push 0 eq assert error_id 120

                // _ bpow2_hi bpow2_lo exp acc
                dup 1
                push 1
                and
                // _ bpow2_hi bpow2_lo exp acc (exp & 1)
                skiz
                    call {mul_acc_with_bpow2_label}
                // _ bpow2_hi bpow2_lo exp acc'

                // Square bpow2. Only the low half is squared: the high half was
                // asserted to be zero above. Relies on the square of a u32 not
                // wrapping around the field modulus.
                swap 2
                // _ bpow2_hi acc' exp bpow2_lo
                dup 0 mul split
                // _ bpow2_hi acc' exp new_hi new_lo
                swap 3
                // _ bpow2_hi new_lo exp new_hi acc'
                swap 1
                // _ bpow2_hi new_lo exp acc' new_hi
                swap 4 pop 1
                // _ new_hi new_lo exp acc'

                // Consume the lowest exponent bit: exp := exp / 2.
                push 2
                // _ new_hi new_lo exp acc' 2
                dup 2
                // _ new_hi new_lo exp acc' 2 exp
                div_mod
                // _ new_hi new_lo exp acc' (exp / 2) (exp % 2)
                pop 1 swap 2 pop 1
                // _ new_hi new_lo (exp / 2) acc'
                recurse

            // acc := acc * bpow2_lo, asserting the product fits in a u32.
            // BEFORE: _ bpow2_hi bpow2_lo exp acc
            // AFTER:  _ bpow2_hi bpow2_lo exp (acc * bpow2_lo)
            {mul_acc_with_bpow2_label}:
                dup 2
                // _ bpow2_hi bpow2_lo exp acc bpow2_lo
                mul
                // _ bpow2_hi bpow2_lo exp product
                split swap 1 push 0 eq assert error_id 121
                // _ bpow2_hi bpow2_lo exp product_lo
                return
        )
    }
}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for SafePow {
        type Args = (u32, u32);

        /// Reference implementation: `base.checked_pow(exponent)`, failing with
        /// `ArithmeticOverflow` when the result does not fit in a `u32`.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let (base, exponent) = pop_encodable::<Self::Args>(stack)?;
            let pow = base
                .checked_pow(exponent)
                .ok_or(RustShadowError::ArithmeticOverflow)?;
            push_encodable(stack, &pow);
            Ok(())
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                // Small ranges keep every random input in bounds: 15^7 < 2^32.
                let mut seeded_rng = StdRng::from_seed(seed);
                let base = seeded_rng.random_range(0..0x10);
                let exponent = seeded_rng.random_range(0..0x8);
                return (base, exponent);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (10, 5),
                BenchmarkCase::WorstCase => (2, 31),
            }
        }

        /// Boundary inputs whose result still fits in a `u32`.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            vec![
                (0, 0),
                (0, 1),
                (1, 0),
                // Exercise the full 32 iterations of the square-and-multiply loop.
                (0, u32::MAX),
                (1, u32::MAX),
                // Largest base: fine for exponents 0 and 1.
                (u32::MAX, 0),
                (u32::MAX, 1),
                // Largest power of two that fits in a u32.
                (2, 31),
            ]
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedClosure::new(SafePow).test()
    }

    /// Inputs whose result fits in a `u32` must match the Rust shadow exactly.
    #[macro_rules_attr::apply(test)]
    fn u32_pow_unit_test() {
        for (base, exp) in [
            (0, 0),
            (0, 1),
            (1, 0),
            (1, 1),
            (2, 30),
            (2, 31),
            (3, 20),
            (4, 15),
            (5, 13),
            (6, 12),
            (7, 11),
            (8, 10),
            (9, 10),
            (10, 9),
            (11, 9),
            (12, 8),
            (u32::MAX, 0),
            (u32::MAX, 1),
            (1, u32::MAX),
            (0, u32::MAX),
            (1, u32::MAX - 1),
            (0, u32::MAX - 1),
            (1, u32::MAX - 2),
            (0, u32::MAX - 2),
            (1, u32::MAX - 3),
            (0, u32::MAX - 3),
        ] {
            let initial_stack = SafePow.set_up_test_stack((base, exp));
            let mut expected_final_stack = initial_stack.clone();
            SafePow.rust_shadow(&mut expected_final_stack).unwrap();

            let _vm_output_state = test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(SafePow),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_final_stack),
            );
        }
    }

    /// Inputs whose result overflows a `u32` must crash with error id 120
    /// (`base^(2^i)` overflow) or 121 (accumulator overflow).
    #[macro_rules_attr::apply(test)]
    fn u32_pow_negative_test() {
        for (base, exp) in [
            (2, 32),
            (3, 21),
            (4, 16),
            (5, 14),
            (6, 13),
            (7, 12),
            (8, 11),
            (9, 11),
            (10, 10),
            (11, 10),
            (12, 10),
            (u32::MAX, 2),
            (u32::MAX, 3),
            (u32::MAX, 4),
            (u32::MAX, 5),
            (u32::MAX, 6),
            (u32::MAX, 7),
            (u32::MAX, 8),
            (u32::MAX, 9),
            (1 << 16, 2),
            (1 << 16, 3),
            (1 << 16, 4),
            (1 << 16, 5),
            (1 << 16, 6),
            (1 << 16, 7),
            (1 << 16, 8),
            (1 << 8, 4),
            (1 << 8, 8),
            (1 << 8, 16),
            (1 << 8, 32),
        ] {
            test_assertion_failure(
                &ShadowedClosure::new(SafePow),
                InitVmState::with_stack(SafePow.set_up_test_stack((base, exp))),
                &[120, 121],
            );
        }
    }
}
275
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark harness for `SafePow`.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed = ShadowedClosure::new(SafePow);
        shadowed.bench()
    }
}