tasm_lib/arithmetic/u32/
next_power_of_two.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Returns the smallest power of two greater than or equal to argument `arg`.
/// Behaves like the [`rustc` method in debug mode][rustc_pow] for all inputs
/// of type `u32`.
///
/// In particular, the result never wraps around: if the smallest power of two
/// greater than or equal to `arg` is 2^32, the assertion with error ID
/// [`NextPowerOfTwo::INPUT_TOO_LARGE_ERROR_ID`] fails instead.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ arg
/// AFTER:  _ u32::next_power_of_two(arg)
/// ```
///
/// ### Pre-conditions
///
/// - `arg` is a valid u32
/// - `arg` is smaller than or equal to 2^31
///
/// ### Post-conditions
///
/// - the output is the smallest power of two greater than or equal to `arg`
/// - in particular, the output is 1 for both `arg == 0` and `arg == 1`
///
/// [rustc_pow]: u32::next_power_of_two
#[derive(Debug, Clone, Copy)]
pub struct NextPowerOfTwo;
32
impl NextPowerOfTwo {
    /// Error ID of the assertion that fails when the computed power of two
    /// equals 2^32, which is the case for `u32` inputs larger than 2^31.
    pub const INPUT_TOO_LARGE_ERROR_ID: i128 = 130;
}
36
37impl BasicSnippet for NextPowerOfTwo {
38    fn inputs(&self) -> Vec<(DataType, String)> {
39        vec![(DataType::U32, "self".to_owned())]
40    }
41
42    fn outputs(&self) -> Vec<(DataType, String)> {
43        vec![(DataType::U32, "next_power_of_two".to_owned())]
44    }
45
46    fn entrypoint(&self) -> String {
47        "tasmlib_arithmetic_u32_next_power_of_two".to_string()
48    }
49
50    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
51        let entrypoint = self.entrypoint();
52        let zero_or_one_label = format!("{entrypoint}_zero_or_one_label");
53        let greater_than_one_label = format!("{entrypoint}_greater_than_one");
54        triton_asm!(
55            {entrypoint}:
56                // _ arg
57
58                push 1
59                push 2
60                dup 2
61                lt
62                // _ arg 1 (arg < 2)
63
64                skiz
65                    call {zero_or_one_label}
66                // if arg < 2:  _ 1   0
67                // if arg >= 2: _ arg 1
68
69                skiz
70                    call {greater_than_one_label}
71                // _
72
73                return
74
75            {zero_or_one_label}:
76                // _ arg 1
77
78                pop 2
79                push 1
80                push 0
81                // _ 1 0
82
83                return
84
85            {greater_than_one_label}:
86                // _ arg
87
88                addi -1
89                log_2_floor
90                addi 1
91                // _ log₂(⌊value - 1⌋ + 1)
92
93                push 2
94                pow
95                // _ 2^log₂(⌊value - 1⌋ + 1)
96
97                /* Assert result *not* 2^{32} */
98                dup 0
99                push {1u64 << 32}
100                eq
101                push 0
102                eq
103                assert error_id {Self::INPUT_TOO_LARGE_ERROR_ID}
104
105                return
106        )
107    }
108
109    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
110        let mut sign_offs = HashMap::new();
111        sign_offs.insert(Reviewer("ferdinand"), 0x131c49afe5bf05af.into());
112        sign_offs
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for NextPowerOfTwo {
        type Args = u32;

        /// Reference behavior: pops the argument and pushes
        /// [`u32::next_power_of_two`] of it.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let input = pop_encodable::<Self::Args>(stack);
            let expected = input.next_power_of_two();
            push_encodable(stack, &expected);
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            if let Some(case) = bench_case {
                return match case {
                    BenchmarkCase::CommonCase => (1 << 27) - 1,
                    BenchmarkCase::WorstCase => (1 << 31) - 1,
                };
            }

            // halving keeps the random argument below 2^31
            StdRng::from_seed(seed).next_u32() / 2
        }

        /// All arguments from 0 through 66, followed by the six largest
        /// valid arguments in descending order, starting at 2^31.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let mut args: Vec<Self::Args> = (0..=66).collect();
            args.extend((0..=5).map(|offset| (1_u32 << 31) - offset));

            args
        }
    }

    impl NextPowerOfTwo {
        /// The stack for an isolated run with `arg` placed on top.
        fn prepare_vm_stack(&self, arg: u32) -> Vec<BFieldElement> {
            let mut stack = self.init_stack_for_isolated_run();
            stack.extend(bfe_vec![arg]);

            stack
        }
    }

    #[test]
    fn next_power_of_two_pbt() {
        let shadowed_snippet = ShadowedClosure::new(NextPowerOfTwo);
        shadowed_snippet.test()
    }

    /// Every argument larger than 2^31 must trip the "input too large"
    /// assertion.
    #[test]
    fn npo2_overflow_negative_test() {
        let shadowed_snippet = ShadowedClosure::new(NextPowerOfTwo);
        let expected_error_ids = [NextPowerOfTwo::INPUT_TOO_LARGE_ERROR_ID];

        let just_above_limit = (1..=5).map(|offset| (1_u32 << 31) + offset);
        let largest_u32s = (0..=2).map(|offset| u32::MAX - offset);

        for arg in just_above_limit.chain(largest_u32s) {
            let initial_stack = NextPowerOfTwo.prepare_vm_stack(arg);
            test_assertion_failure(
                &shadowed_snippet,
                InitVmState::with_stack(initial_stack),
                &expected_error_ids,
            );
        }
    }
}
174
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark of the shadowed snippet.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(NextPowerOfTwo);
        shadowed_snippet.bench()
    }
}