// tasm_lib/arithmetic/u128/shift_right.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Bit-wise [right shift][shr] of a `u128` by a shift amount taken from the stack.
///
/// # Behavior
///
/// ```text
/// BEFORE: _ [arg: u128] shift_amount
/// AFTER:  _ [result: u128]
/// ```
///
/// # Preconditions
///
/// - `arg` is a properly [`BFieldCodec`]-encoded `u128`
/// - `shift_amount` lies in `0..128`
///
/// # Postconditions
///
/// - `result` equals `arg` shifted to the right by `shift_amount` bits
/// - `result` is properly [`BFieldCodec`] encoded
///
/// [shr]: core::ops::Shr
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct ShiftRight;

impl ShiftRight {
    /// Error ID attached to the assertion that rejects shift amounts of 128 or
    /// more.
    pub const SHIFT_AMOUNT_TOO_BIG_ERROR_ID: i128 = 540;
}
36
37impl BasicSnippet for ShiftRight {
38    fn inputs(&self) -> Vec<(DataType, String)> {
39        let arg = (DataType::U128, "arg".to_string());
40        let shift_amount = (DataType::U32, "shift_amount".to_string());
41
42        vec![arg, shift_amount]
43    }
44
45    fn outputs(&self) -> Vec<(DataType, String)> {
46        vec![(DataType::U128, "shifted_arg".to_string())]
47    }
48
49    fn entrypoint(&self) -> String {
50        "tasmlib_arithmetic_u128_shift_right".to_string()
51    }
52
53    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
54        let entrypoint = self.entrypoint();
55        let shift_amount_gt_32 = format!("{entrypoint}_shift_amount_gt_32");
56
57        triton_asm!(
58            // BEFORE: _ v_3 v_2 v_1 v_0 s
59            // AFTER:  _ (v >> s)_3 (v >> s)_2 (v >> s)_1 (v >> s)_0
60            {entrypoint}:
61                /* bounds check */
62                push 128
63                dup 1
64                lt
65                assert error_id {Self::SHIFT_AMOUNT_TOO_BIG_ERROR_ID}
66                            // _ v_3 v_2 v_1 v_0 s
67
68                /* special case if shift amount is greater than 32 */
69                dup 0
70                push 32
71                lt          // _ v_3 v_2 v_1 v_0 s (s > 32)
72                skiz
73                    call {shift_amount_gt_32}
74                            // _ v_3 v_2 v_1 v_0 s
75
76                /* for an explanation, see snippet “u64::ShiftRight” */
77                push -1
78                mul
79                addi 32     // _ v_3 v_2 v_1 v_0 (32 - s)
80                push 2
81                pow         // _ v_3 v_2 v_1 v_0 (2^(32 - s))
82
83                dup 0
84                pick 5
85                mul         // _ v_2 v_1 v_0 (2^(32 - s)) v_3s
86                place 4
87                xb_mul      // _ v_3s v_2s v_1s v_0s
88
89                pick 3
90                split       // _ v_2s v_1s v_0s (v >> s)_3 c_2
91                pick 4
92                split       // _ v_1s v_0s (v >> s)_3 c_2 (v_2 >> s) c_1
93                pick 5
94                split       // _ v_0s (v >> s)_3 c_2 (v_2 >> s) c_1 (v_1 >> s) c_0
95                pick 6
96                split       // _ (v >> s)_3 c_2 (v_2 >> s) c_1 (v_1 >> s) c_0 (v_0 >> s) trash
97
98                pop 1       // _ (v >> s)_3 c_2 (v_2 >> s) c_1 (v_1 >> s) c_0 (v_0 >> s)
99                add         // _ (v >> s)_3 c_2 (v_2 >> s) c_1 (v_1 >> s) (v >> s)_0
100                place 4     // _ (v >> s)_3 (v >> s)_0 c_2 (v_2 >> s) c_1 (v_1 >> s)
101                add         // _ (v >> s)_3 (v >> s)_0 c_2 (v_2 >> s) (v >> s)_1
102                place 3     // _ (v >> s)_3 (v >> s)_1 (v >> s)_0 c_2 (v_2 >> s)
103                add         // _ (v >> s)_3 (v >> s)_1 (v >> s)_0 (v_2 >> s)_2
104                place 2     // _ (v >> s)_3 (v_2 >> s)_2 (v >> s)_1 (v >> s)_0
105
106                return
107
108            // BEFORE: _ [v: u128] s
109            // AFTER:  _ [v >> i·32: u128] (s - i·32)
110            // such that i·32 <= s < (i+1)·32
111            {shift_amount_gt_32}:
112                addi -32    // _ v_3 v_2 v_1 v_0 (s - 32)
113                pick 1
114                pop 1       // _ v_3 v_2 v_1 (s - 32)
115                push 0
116                place 4     // _ 0 v_3 v_2 v_1 (s - 32)
117
118                dup 0
119                push 32
120                lt
121                skiz
122                    recurse
123                return
124        )
125    }
126
127    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
128        let mut sign_offs = HashMap::new();
129        sign_offs.insert(Reviewer("ferdinand"), 0x9875596d880d6dd0.into());
130        sign_offs
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;
    use rand::rngs::StdRng;

    impl Closure for ShiftRight {
        type Args = (u128, u32);

        /// Reference behavior: pops `(arg, shift_amount)` and pushes
        /// `arg >> shift_amount`; panics for shift amounts of 128 or more.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (arg, shift_amount): Self::Args = pop_encodable(stack);
            assert!(shift_amount < 128);

            let shifted_arg = arg >> shift_amount;
            push_encodable(stack, &shifted_arg);
        }

        /// Fixed arguments for the benchmark cases, seeded random arguments
        /// with a shift amount in `0..128` otherwise.
        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            if let Some(case) = bench_case {
                return match case {
                    BenchmarkCase::CommonCase => (0x642, 20),
                    BenchmarkCase::WorstCase => (0x123, 127),
                };
            }

            // Draw the argument before the shift amount so that a given seed
            // always maps to the same pair.
            let mut rng = StdRng::from_seed(seed);
            (rng.random(), rng.random_range(0..128))
        }

        /// Every interesting argument paired with every legal shift amount.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let interesting_args = [0, 1 << 3, 1 << 64, u64::MAX.into(), 1 << 127, u128::MAX];

            interesting_args
                .into_iter()
                .flat_map(|arg| (0..128).map(move |shift_amount| (arg, shift_amount)))
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        let shadowed_snippet = ShadowedClosure::new(ShiftRight);
        shadowed_snippet.test();
    }

    #[proptest]
    fn too_large_shift_crashes_vm(arg: u128, #[strategy(128_u32..)] shift_amount: u32) {
        let shadowed_snippet = ShadowedClosure::new(ShiftRight);
        let initial_stack = ShiftRight.set_up_test_stack((arg, shift_amount));
        let initial_state = InitVmState::with_stack(initial_stack);
        let expected_error_ids = [ShiftRight::SHIFT_AMOUNT_TOO_BIG_ERROR_ID];

        test_assertion_failure(&shadowed_snippet, initial_state, &expected_error_ids)
    }
}
185
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the snippet's benchmark through its shadowed closure.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(ShiftRight);
        shadowed_snippet.bench();
    }
}