tasm_lib/arithmetic/u32/safe_pow.rs

1use triton_vm::prelude::*;
2
3use crate::prelude::*;
4
/// A `u32` exponentiation that crashes the VM on overflow.
///
/// Computes `base^exponent` via square-and-multiply. The result agrees with
/// Rust's [`u32::checked_pow`] followed by an `unwrap`. Note the difference to
/// Rust's [`u32::pow`]: that one only panics on overflow if overflow checks are
/// enabled (as in debug builds) and silently wraps otherwise. This snippet
/// crashes on overflow unconditionally. As in Rust, `0^0 == 1`.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ [base: u32] [exponent: u32]
/// AFTER:  _ [result: u32]
/// ```
///
/// ### Crashes
///
/// - with error id 120 if some square `base^(2^k)` with `2^k <= exponent`
///   exceeds `u32::MAX`
/// - with error id 121 if multiplying the running product by the current
///   square exceeds `u32::MAX`
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafePow;
8
9impl BasicSnippet for SafePow {
10    fn inputs(&self) -> Vec<(DataType, String)> {
11        vec![
12            (DataType::U32, "base".to_owned()),
13            (DataType::U32, "exponent".to_owned()),
14        ]
15    }
16
17    fn outputs(&self) -> Vec<(DataType, String)> {
18        vec![(DataType::U32, "result".to_owned())]
19    }
20
21    fn entrypoint(&self) -> String {
22        "tasmlib_arithmetic_u32_safe_pow".to_string()
23    }
24
25    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
26        // This algorithm is implemented below. `bpow2` has type
27        // `u64` because it would otherwise erroneously overflow
28        // in the last iteration of the loop when e.g. calculating
29        // 2.pow(31).
30
31        // fn safe_pow(base: u32, exponent: u64) -> Self {
32        //     let mut bpow2: u64 = base as u64;
33        //     let mut acc = 1u32;
34        //     let mut i = exponent;
35
36        //     while i != 0 {
37        //         if i & 1 == 1 {
38        //             acc *= bpow2;
39        //         }
40
41        //         bpow2 *= bpow2;
42        //         i >>= 1;
43        //     }
44
45        //     acc
46        // }
47
48        let entrypoint = self.entrypoint();
49        let while_acc_label = format!("{entrypoint}_while_acc");
50        let mul_acc_with_bpow2_label = format!("{entrypoint}_mul_acc_with_bpow2");
51        triton_asm!(
52            {entrypoint}:
53                // _ base exponent
54
55                push 0
56                swap 2
57                swap 1
58                // _ 0 base exponent
59
60                push 1
61                // _ [base_u64] exponent acc
62
63                // rename: `exponent` -> `i`, `base_u64` -> `bpow2_u64`
64
65                // _ [bpow2_u64] i acc
66                call {while_acc_label}
67                // _ [bpow2_u64] 0 acc
68
69                swap 3
70                pop 3
71                return
72
73            // INVARIANT: _ [bpow2_u64] i acc
74            {while_acc_label}:
75                // check condition
76                dup 1 push 0 eq
77                skiz
78                    return
79                // _ [bpow2_u64] i acc
80
81                // Verify that `bpow2_u64` does not exceed `u32::MAX`
82                dup 3 push 0 eq assert error_id 120
83
84                // _ 0 bpow2 i acc
85                dup 1
86                push 1
87                and
88                // _ 0 bpow2 i acc (i & 1)
89                skiz
90                    call {mul_acc_with_bpow2_label}
91
92                // _ 0 bpow2 i acc
93
94                swap 2
95                // _ 0 acc i bpow2
96
97                dup 0 mul split
98                // _ 0 acc i bpow2_next_hi bpow2_next_lo
99
100                // GOAL: // _ [bpow_u64] i acc
101
102                swap 3
103                // _ 0 bpow2_next_lo i bpow2_next_hi acc
104
105                swap 1
106                // _ 0 bpow2_next_lo i acc bpow2_next_hi
107
108                swap 4 pop 1
109                // _ bpow2_next_hi bpow2_next_lo i acc
110
111                // _ [bpow2_next_u64] i acc
112
113                push 2
114                // _ [bpow2_next_u64] i acc 2
115
116                dup 2
117                // _ [bpow2_next_u64] i acc 2 i
118
119                div_mod
120                // _ [bpow2_u64] i acc (i >> 2) (i % 2)
121
122                pop 1 swap 2 pop 1
123                // _ [bpow2_u64] (i >> 2) acc
124
125
126                recurse
127
128            {mul_acc_with_bpow2_label}:
129                // _ 0 bpow2 i acc
130
131                dup 2
132                mul
133                // _ 0 bpow2 i (acc * bpow2)
134
135                split swap 1 push 0 eq assert error_id 121
136                // _ 0 bpow2 i new_acc
137
138                return
139        )
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for SafePow {
        type Args = (u32, u32);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (base, exponent) = pop_encodable::<Self::Args>(stack);

            // `u32::pow` only panics on overflow if overflow checks are enabled
            // (e.g., in debug builds) and silently wraps otherwise. The snippet
            // crashes on overflow unconditionally, so the shadow must as well,
            // independent of the build profile the tests run under.
            let result = base
                .checked_pow(exponent)
                .expect("`base^exponent` must not overflow u32");
            push_encodable(stack, &result);
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(bench_case) = bench_case else {
                // 15^7 < 2^32, so these random arguments never overflow.
                let mut seeded_rng = StdRng::from_seed(seed);
                let base = seeded_rng.random_range(0..0x10);
                let exponent = seeded_rng.random_range(0..0x8);
                return (base, exponent);
            };

            match bench_case {
                BenchmarkCase::CommonCase => (10, 5),
                BenchmarkCase::WorstCase => (2, 31),
            }
        }

        /// Boundary inputs, none of which overflow.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            vec![
                (0, 0),
                (0, 1),
                (1, 0),
                (0, u32::MAX),
                (1, u32::MAX),
                (u32::MAX, 0),
                (u32::MAX, 1),
                (2, 31),
            ]
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafePow).test()
    }

    /// Non-overflowing inputs, including the largest results representable
    /// for small bases, must agree between the VM and the Rust shadow.
    #[test]
    fn u32_pow_unit_test() {
        for (base, exp) in [
            (0, 0),
            (0, 1),
            (1, 0),
            (1, 1),
            (2, 30),
            (2, 31),
            (3, 20),
            (4, 15),
            (5, 13),
            (6, 12),
            (7, 11),
            (8, 10),
            (9, 10),
            (10, 9),
            (11, 9),
            (12, 8),
            (u32::MAX, 0),
            (u32::MAX, 1),
            (1, u32::MAX),
            (0, u32::MAX),
            (1, u32::MAX - 1),
            (0, u32::MAX - 1),
            (1, u32::MAX - 2),
            (0, u32::MAX - 2),
            (1, u32::MAX - 3),
            (0, u32::MAX - 3),
        ] {
            let initial_stack = SafePow.set_up_test_stack((base, exp));
            let mut expected_final_stack = initial_stack.clone();
            SafePow.rust_shadow(&mut expected_final_stack);

            let _vm_output_state = test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(SafePow),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_final_stack),
            );
        }
    }

    /// Overflowing inputs must crash with one of the snippet's overflow errors.
    #[test]
    fn u32_pow_negative_test() {
        for (base, exp) in [
            (2, 32),
            (3, 21),
            (4, 16),
            (5, 14),
            (6, 13),
            (7, 12),
            (8, 11),
            (9, 11),
            (10, 10),
            (11, 10),
            (12, 10),
            (u32::MAX, 2),
            (u32::MAX, 3),
            (u32::MAX, 4),
            (u32::MAX, 5),
            (u32::MAX, 6),
            (u32::MAX, 7),
            (u32::MAX, 8),
            (u32::MAX, 9),
            (1 << 16, 2),
            (1 << 16, 3),
            (1 << 16, 4),
            (1 << 16, 5),
            (1 << 16, 6),
            (1 << 16, 7),
            (1 << 16, 8),
            (1 << 8, 4),
            (1 << 8, 8),
            (1 << 8, 16),
            (1 << 8, 32),
        ] {
            test_assertion_failure(
                &ShadowedClosure::new(SafePow),
                InitVmState::with_stack(SafePow.set_up_test_stack((base, exp))),
                &[120, 121],
            );
        }
    }
}
271
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark harness on the shadowed `SafePow` snippet.
    #[test]
    fn benchmark() {
        ShadowedClosure::new(SafePow).bench();
    }
}