tasm_lib/arithmetic/u32/trailing_zeros.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Counts the trailing zeros of a `u32`, i.e., the number of zero bits below
/// its least significant set bit. Mirrors [`u32::trailing_zeros`].
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ arg
/// AFTER:  _ u32::trailing_zeros(arg)
/// ```
///
/// ### Preconditions
///
/// - `arg` is a valid `u32`
///
/// ### Postconditions
///
/// - the output equals `u32::trailing_zeros(arg)`; in particular, the output
///   is `32` if `arg` is `0`
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TrailingZeros;
29
impl BasicSnippet for TrailingZeros {
    /// One argument: the `u32` whose trailing zeros are counted.
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U32, "arg".to_string())]
    }

    /// One result: the number of trailing zeros of `arg`, as a `u32`.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U32, "trailing_zeros(arg)".to_string())]
    }

    /// The snippet's label. `code` also uses it as the prefix for the labels of
    /// the snippet's internal subroutines.
    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_u32_trailing_zeros".to_string()
    }

    // The basic idea for the algorithm below is taken from “Count the consecutive
    // zero bits (trailing) on the right in parallel” [0]. For example, consider
    // input 1010100₂:
    //
    // input:                         1010100₂
    // bitwise negation:         11…110101011₂
    // (wrapping) add one:       11…110101100₂
    // bitwise `and` with input:          100₂
    // base-2 integer logarithm:            2
    //
    // In other words, `arg & (!arg + 1)` isolates the lowest set bit of `arg`.
    // The position of that bit is the number of trailing zeros.
    //
    // The edge case “arg == 0” is handled by a separate subroutine that returns
    // 32, like `u32::trailing_zeros`. For every other input:
    // - The bitwise negation of the input is never 11…11₂. Adding 1 therefore
    //   cannot overflow, and the operand of `and` stays a valid u32.
    // - The isolated lowest set bit is non-zero. The instruction `log_2_floor`,
    //   which fails for an input of 0, therefore never crashes.
    //
    // [0] https://graphics.stanford.edu/~seander/bithacks.html#ZerosOnRightParallel
    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        // labels of the two mutually exclusive branches
        let arg_eq_0 = format!("{entrypoint}_arg_eq_0");
        let arg_neq_0 = format!("{entrypoint}_arg_neq_0");

        triton_asm! {
            // BEFORE: _ arg
            // AFTER:  _ trailing_zeros(arg)
            {entrypoint}:
                push 1
                dup 1
                push 0
                eq
                // _ arg 1 (arg == 0)

                // Branch selection via two consecutive `skiz`:
                // - If arg == 0, the first call turns `_ arg 1` into `_ 32 0`. The
                //   second `skiz` then pops the 0 and skips its call.
                // - Otherwise, the first call is skipped and the stack is still
                //   `_ arg 1`. The second `skiz` pops the 1 and performs its call.
                skiz call {arg_eq_0}
                skiz call {arg_neq_0}
                // _ trailing_zeros(arg)

                return

            // BEFORE: _ 0 1
            // AFTER:  _ 32 0
            // The trailing 0 makes the caller's second `skiz` skip the `arg != 0`
            // branch.
            {arg_eq_0}:
                pop 2
                push 32
                push 0
                return

            // BEFORE: _ arg
            // AFTER:  _ trailing_zeros(arg)
            // where arg != 0
            {arg_neq_0}:
                dup 0
                // xor with u32::MAX is bitwise negation of a u32
                push {u32::MAX}
                    hint u32_max: u32 = stack[0]
                xor
                    hint bitwise_negated_arg: u32 = stack[0]
                // _ arg bitwise_negated_arg

                addi 1
                // _ arg (bitwise_negated_arg + 1); no overflow since arg != 0
                and
                // _ lowest_set_bit(arg); non-zero since arg != 0
                log_2_floor
                // _ trailing_zeros(arg)

                return
        }
    }

    /// Reviewer sign-offs for this snippet.
    // NOTE(review): the fingerprint presumably tracks the instructions produced
    // by `code`, so changing them likely requires a fresh review — confirm.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0xc7e78a3074304156.into());
        sign_offs
    }
}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for TrailingZeros {
        type Args = u32;

        /// Reference implementation: Rust's own [`u32::trailing_zeros`].
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let arg = pop_encodable::<Self::Args>(stack);
            push_encodable(stack, &arg.trailing_zeros());
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => 0b1111_1111 << 3,
                Some(BenchmarkCase::WorstCase) => 1 << 31,
                None => StdRng::from_seed(seed).random(),
            }
        }

        /// Covers the extremes of the `u32` range plus every power of two.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            // values around 0, 2^31, and u32::MAX
            let around_extremes = [1, 1 << 31, u32::MAX - 1]
                .into_iter()
                .flat_map(|i| [i - 1, i, i + 1]);

            // Uniformly random inputs rarely have many trailing zeros: the
            // probability of at least k trailing zeros is 2^-k. Each power of two
            // 2^k yields the result k. Including all of them therefore exercises
            // every possible output for a non-zero input.
            let powers_of_two = (0..u32::BITS).map(|k| 1_u32 << k);

            around_extremes.chain(powers_of_two).collect()
        }
    }

    #[test]
    fn unit() {
        ShadowedClosure::new(TrailingZeros).test();
    }
}
152
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark for this snippet through its shadowed closure.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(TrailingZeros);
        shadowed_snippet.bench()
    }
}
162}