Skip to main content

tasm_lib/arithmetic/u32/
safe_pow.rs

1use triton_vm::prelude::*;
2
3use crate::prelude::*;
4
/// A `u32` exponentiation that behaves like [`u32::checked_pow`], except that it
/// crashes the VM instead of returning `None` when the result overflows.
///
/// Note that Rust's [`u32::pow`] only panics on overflow when overflow checks
/// are enabled (e.g., in debug builds) and wraps otherwise. This snippet
/// *always* crashes on overflow.
///
/// As with `u32::pow`, `0^0` evaluates to `1`.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [base: u32] [exponent: u32]
/// AFTER:  _ [base^exponent: u32]
/// ```
///
/// ### Crashes
///
/// - with error id 120 if a repeated square of `base` that is still needed
///   exceeds `u32::MAX`;
/// - with error id 121 if the accumulated result exceeds `u32::MAX`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafePow;
8
9impl BasicSnippet for SafePow {
10    fn parameters(&self) -> Vec<(DataType, String)> {
11        vec![
12            (DataType::U32, "base".to_owned()),
13            (DataType::U32, "exponent".to_owned()),
14        ]
15    }
16
17    fn return_values(&self) -> Vec<(DataType, String)> {
18        vec![(DataType::U32, "result".to_owned())]
19    }
20
21    fn entrypoint(&self) -> String {
22        "tasmlib_arithmetic_u32_safe_pow".to_string()
23    }
24
25    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
26        // This algorithm is implemented below. `bpow2` has type
27        // `u64` because it would otherwise erroneously overflow
28        // in the last iteration of the loop when e.g. calculating
29        // 2.pow(31).
30
31        // fn safe_pow(base: u32, exponent: u64) -> Self {
32        //     let mut bpow2: u64 = base as u64;
33        //     let mut acc = 1u32;
34        //     let mut i = exponent;
35
36        //     while i != 0 {
37        //         if i & 1 == 1 {
38        //             acc *= bpow2;
39        //         }
40
41        //         bpow2 *= bpow2;
42        //         i >>= 1;
43        //     }
44
45        //     acc
46        // }
47
48        let entrypoint = self.entrypoint();
49        let while_acc_label = format!("{entrypoint}_while_acc");
50        let mul_acc_with_bpow2_label = format!("{entrypoint}_mul_acc_with_bpow2");
51        triton_asm!(
52            {entrypoint}:
53                // _ base exponent
54
55                push 0
56                swap 2
57                swap 1
58                // _ 0 base exponent
59
60                push 1
61                // _ [base_u64] exponent acc
62
63                // rename: `exponent` -> `i`, `base_u64` -> `bpow2_u64`
64
65                // _ [bpow2_u64] i acc
66                call {while_acc_label}
67                // _ [bpow2_u64] 0 acc
68
69                swap 3
70                pop 3
71                return
72
73            // INVARIANT: _ [bpow2_u64] i acc
74            {while_acc_label}:
75                // check condition
76                dup 1 push 0 eq
77                skiz
78                    return
79                // _ [bpow2_u64] i acc
80
81                // Verify that `bpow2_u64` does not exceed `u32::MAX`
82                dup 3 push 0 eq assert error_id 120
83
84                // _ 0 bpow2 i acc
85                dup 1
86                push 1
87                and
88                // _ 0 bpow2 i acc (i & 1)
89                skiz
90                    call {mul_acc_with_bpow2_label}
91
92                // _ 0 bpow2 i acc
93
94                swap 2
95                // _ 0 acc i bpow2
96
97                dup 0 mul split
98                // _ 0 acc i bpow2_next_hi bpow2_next_lo
99
100                // GOAL: // _ [bpow_u64] i acc
101
102                swap 3
103                // _ 0 bpow2_next_lo i bpow2_next_hi acc
104
105                swap 1
106                // _ 0 bpow2_next_lo i acc bpow2_next_hi
107
108                swap 4 pop 1
109                // _ bpow2_next_hi bpow2_next_lo i acc
110
111                // _ [bpow2_next_u64] i acc
112
113                push 2
114                // _ [bpow2_next_u64] i acc 2
115
116                dup 2
117                // _ [bpow2_next_u64] i acc 2 i
118
119                div_mod
120                // _ [bpow2_u64] i acc (i >> 2) (i % 2)
121
122                pop 1 swap 2 pop 1
123                // _ [bpow2_u64] (i >> 2) acc
124
125
126                recurse
127
128            {mul_acc_with_bpow2_label}:
129                // _ 0 bpow2 i acc
130
131                dup 2
132                mul
133                // _ 0 bpow2 i (acc * bpow2)
134
135                split swap 1 push 0 eq assert error_id 121
136                // _ 0 bpow2 i new_acc
137
138                return
139        )
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for SafePow {
        type Args = (u32, u32);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            let (base, exponent) = pop_encodable::<Self::Args>(stack)?;
            let pow = base
                .checked_pow(exponent)
                .ok_or(RustShadowError::ArithmeticOverflow)?;
            push_encodable(stack, &pow);
            Ok(())
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                // Small ranges guarantee that the result never overflows:
                // 15^7 < 2^32.
                let mut seeded_rng = StdRng::from_seed(seed);
                let base = seeded_rng.random_range(0..0x10);
                let exponent = seeded_rng.random_range(0..0x8);
                return (base, exponent);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (10, 5),
                BenchmarkCase::WorstCase => (2, 31),
            }
        }

        fn corner_case_args(&self) -> Vec<Self::Args> {
            // Boundary inputs whose result still fits into a `u32`:
            // - `0^0 == 1`,
            // - the maximal exponent (all 32 loop iterations multiply `acc`),
            // - the largest power of two, and
            // - the largest base.
            vec![(0, 0), (1, u32::MAX), (2, 31), (u32::MAX, 1)]
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedClosure::new(SafePow).test()
    }

    #[macro_rules_attr::apply(test)]
    fn u32_pow_unit_test() {
        for (base, exp) in [
            (0, 0),
            (0, 1),
            (1, 0),
            (1, 1),
            (2, 30),
            (2, 31),
            (3, 20),
            (4, 15),
            (5, 13),
            (6, 12),
            (7, 11),
            (8, 10),
            (9, 10),
            (10, 9),
            (11, 9),
            (12, 8),
            (u32::MAX, 0),
            (u32::MAX, 1),
            (1, u32::MAX),
            (0, u32::MAX),
            (1, u32::MAX - 1),
            (0, u32::MAX - 1),
            (1, u32::MAX - 2),
            (0, u32::MAX - 2),
            (1, u32::MAX - 3),
            (0, u32::MAX - 3),
        ] {
            let initial_stack = SafePow.set_up_test_stack((base, exp));
            let mut expected_final_stack = initial_stack.clone();
            SafePow.rust_shadow(&mut expected_final_stack).unwrap();

            let _vm_output_state = test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(SafePow),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_final_stack),
            );
        }
    }

    #[macro_rules_attr::apply(test)]
    fn u32_pow_negative_test() {
        for (base, exp) in [
            (2, 32),
            (3, 21),
            (4, 16),
            (5, 14),
            (6, 13),
            (7, 12),
            (8, 11),
            (9, 11),
            (10, 10),
            (11, 10),
            (12, 10),
            (u32::MAX, 2),
            (u32::MAX, 3),
            (u32::MAX, 4),
            (u32::MAX, 5),
            (u32::MAX, 6),
            (u32::MAX, 7),
            (u32::MAX, 8),
            (u32::MAX, 9),
            (1 << 16, 2),
            (1 << 16, 3),
            (1 << 16, 4),
            (1 << 16, 5),
            (1 << 16, 6),
            (1 << 16, 7),
            (1 << 16, 8),
            (1 << 8, 4),
            (1 << 8, 8),
            (1 << 8, 16),
            (1 << 8, 32),
            (2, u32::MAX),
            (u32::MAX, u32::MAX),
        ] {
            // Either the squared base (120) or the accumulator (121) may be the
            // first value detected to overflow, depending on the inputs.
            test_assertion_failure(
                &ShadowedClosure::new(SafePow),
                InitVmState::with_stack(SafePow.set_up_test_stack((base, exp))),
                &[120, 121],
            );
        }
    }
}
275
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    // Runs the benchmark harness for `SafePow`, wrapped in its Rust shadow.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(SafePow);
        shadowed_snippet.bench();
    }
}