1use triton_vm::prelude::*;
2
3use crate::prelude::*;
4
/// Computes `base ** exponent` for two `u32` values, crashing the VM instead of
/// wrapping whenever the result would not fit in a `u32`.
///
/// ### Crashes
///
/// - error id 120: squaring the running power of `base` overflowed while that
///   power was still needed for the remaining exponent bits
/// - error id 121: multiplying the running power into the accumulator overflowed
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SafePow;
8
impl BasicSnippet for SafePow {
    /// Inputs listed bottom to top: `exponent` is on top of the stack on entry.
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::U32, "base".to_owned()),
            (DataType::U32, "exponent".to_owned()),
        ]
    }

    /// A single `u32`: `base ** exponent`.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U32, "result".to_owned())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_u32_safe_pow".to_string()
    }

    /// Square-and-multiply exponentiation over `u32`.
    ///
    /// The exponent is consumed one bit at a time, least significant bit first.
    /// Each loop iteration does three things:
    /// 1. If the current bit is set, it multiplies the accumulator `acc` by the
    ///    running power `bpow2` (= base^(2^i)).
    /// 2. It squares `bpow2`.
    /// 3. It halves `exp`.
    ///
    /// Every product is `split` into u32 limbs so overflow can be detected:
    /// - An overflowing `acc * bpow2` crashes immediately (error id 121).
    /// - An overflowing square is remembered in `bpow2_hi`. It only crashes
    ///   (error id 120) if another iteration still needs that power, i.e. if
    ///   `exp` is still non-zero. A last, unused squaring may therefore overflow
    ///   harmlessly.
    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        let while_acc_label = format!("{entrypoint}_while_acc");
        let mul_acc_with_bpow2_label = format!("{entrypoint}_mul_acc_with_bpow2");
        triton_asm!(
            // BEFORE: _ base exponent
            // AFTER:  _ result
            {entrypoint}:
                push 0
                swap 2
                swap 1
                push 1
                // _ 0 base exponent 1
                // i.e. _ bpow2_hi bpow2 exp acc  with bpow2_hi = 0, bpow2 = base, acc = 1
                call {while_acc_label}
                // _ bpow2_hi bpow2 0 acc
                swap 3
                pop 3
                // _ acc
                return

            // Loop body, stack layout on entry and on every `recurse`:
            // _ bpow2_hi bpow2 exp acc
            {while_acc_label}:
                // stop once every exponent bit has been consumed
                dup 1 push 0 eq
                skiz
                return
                // exp != 0, so bpow2 is still needed: its last squaring must not
                // have overflowed (high limb must be zero)
                dup 3 push 0 eq assert error_id 120

                // _ bpow2_hi bpow2 exp acc
                dup 1
                push 1
                and
                // _ bpow2_hi bpow2 exp acc (exp & 1)
                skiz
                call {mul_acc_with_bpow2_label}
                // _ bpow2_hi bpow2 exp acc

                swap 2
                // _ bpow2_hi acc exp bpow2
                dup 0 mul split
                // _ bpow2_hi acc exp sq_hi sq_lo
                swap 3
                // _ bpow2_hi sq_lo exp sq_hi acc
                swap 1
                // _ bpow2_hi sq_lo exp acc sq_hi
                swap 4 pop 1
                // _ sq_hi sq_lo exp acc
                push 2
                dup 2
                // _ sq_hi sq_lo exp acc 2 exp
                div_mod
                // _ sq_hi sq_lo exp acc (exp / 2) (exp % 2)
                pop 1 swap 2 pop 1
                // _ sq_hi sq_lo (exp / 2) acc
                recurse

            // acc <- acc * bpow2, crashing if the product does not fit in a u32.
            // BEFORE: _ bpow2_hi bpow2 exp acc
            // AFTER:  _ bpow2_hi bpow2 exp (acc * bpow2)
            {mul_acc_with_bpow2_label}:
                dup 2
                // _ bpow2_hi bpow2 exp acc bpow2
                mul
                // Both factors are u32, so the field product equals the integer
                // product; `split` leaves its low limb on top of its high limb.
                // Swap the high limb to the top and require it to be zero.
                split swap 1 push 0 eq assert error_id 121
                // _ bpow2_hi bpow2 exp (acc * bpow2)
                return
        )
    }
}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for SafePow {
        type Args = (u32, u32);

        /// Reference implementation: `base ** exponent`, panicking on overflow.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (base, exponent) = pop_encodable::<Self::Args>(stack);

            // Use `checked_pow` rather than `pow`. `u32::pow` only panics on
            // overflow when overflow checks are enabled (e.g. debug builds) and
            // silently wraps otherwise. With `checked_pow`, the shadow fails on
            // exactly the inputs for which the VM code crashes (error ids
            // 120/121), regardless of the build profile the tests run under.
            let result = base
                .checked_pow(exponent)
                .expect("u32 overflow in safe_pow");
            push_encodable(stack, &result);
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                let mut seeded_rng = StdRng::from_seed(seed);
                // base < 16 and exponent < 8 guarantee no overflow:
                // 15^7 = 170_859_375 < 2^32.
                let base = seeded_rng.random_range(0..0x10);
                let exponent = seeded_rng.random_range(0..0x8);
                return (base, exponent);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (10, 5),
                // 31 = 0b11111: every exponent bit triggers a multiplication.
                BenchmarkCase::WorstCase => (2, 31),
            }
        }

        /// Boundary inputs that must all succeed:
        /// - 0^0 = 1
        /// - 0^1 = 0
        /// - the longest loop (1^u32::MAX)
        /// - a harmless overflow of the final, unused squaring (u32::MAX^1)
        /// - the largest power of two
        fn corner_case_args(&self) -> Vec<Self::Args> {
            vec![(0, 0), (0, 1), (1, u32::MAX), (u32::MAX, 1), (2, 31)]
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafePow).test()
    }

    #[test]
    fn u32_pow_unit_test() {
        for (base, exp) in [
            (0, 0),
            (0, 1),
            (1, 0),
            (1, 1),
            (2, 30),
            (2, 31),
            (3, 20),
            (4, 15),
            (5, 13),
            (6, 12),
            (7, 11),
            (8, 10),
            (9, 10),
            (10, 9),
            (11, 9),
            (12, 8),
            (u32::MAX, 0),
            (u32::MAX, 1),
            (1, u32::MAX),
            (0, u32::MAX),
            (1, u32::MAX - 1),
            (0, u32::MAX - 1),
            (1, u32::MAX - 2),
            (0, u32::MAX - 2),
            (1, u32::MAX - 3),
            (0, u32::MAX - 3),
        ] {
            let initial_stack = SafePow.set_up_test_stack((base, exp));
            let mut expected_final_stack = initial_stack.clone();
            SafePow.rust_shadow(&mut expected_final_stack);

            let _vm_output_state = test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(SafePow),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_final_stack),
            );
        }
    }

    /// Every pair here overflows `u32`; the VM must crash with error id 120
    /// (overflowing square still in use) or 121 (overflowing accumulator product).
    #[test]
    fn u32_pow_negative_test() {
        for (base, exp) in [
            (2, 32),
            (3, 21),
            (4, 16),
            (5, 14),
            (6, 13),
            (7, 12),
            (8, 11),
            (9, 11),
            (10, 10),
            (11, 10),
            (12, 10),
            (u32::MAX, 2),
            (u32::MAX, 3),
            (u32::MAX, 4),
            (u32::MAX, 5),
            (u32::MAX, 6),
            (u32::MAX, 7),
            (u32::MAX, 8),
            (u32::MAX, 9),
            (1 << 16, 2),
            (1 << 16, 3),
            (1 << 16, 4),
            (1 << 16, 5),
            (1 << 16, 6),
            (1 << 16, 7),
            (1 << 16, 8),
            (1 << 8, 4),
            (1 << 8, 8),
            (1 << 8, 16),
            (1 << 8, 32),
        ] {
            test_assertion_failure(
                &ShadowedClosure::new(SafePow),
                InitVmState::with_stack(SafePow.set_up_test_stack((base, exp))),
                &[120, 121],
            );
        }
    }
}
271
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark harness for `SafePow`.
    #[test]
    fn benchmark() {
        let shadowed = ShadowedClosure::new(SafePow);
        shadowed.bench();
    }
}