tasm_lib/arithmetic/u128/shift_right.rs
1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Unsigned 128-bit [shift right][shr], i.e., `arg >> shift_amount`.
///
/// # Behavior
///
/// ```text
/// BEFORE: _ [arg: u128] shift_amount
/// AFTER:  _ [result: u128]
/// ```
///
/// # Preconditions
///
/// - `arg` is properly [`BFieldCodec`] encoded
/// - `shift_amount` lies in `0..128`; a larger `shift_amount` makes the VM
///   crash with error ID `ShiftRight::SHIFT_AMOUNT_TOO_BIG_ERROR_ID`
///
/// # Postconditions
///
/// - the output is `arg` shifted right by `shift_amount` bits
/// - the output is properly [`BFieldCodec`] encoded
///
/// [shr]: core::ops::Shr
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ShiftRight;
32
33impl ShiftRight {
34 pub const SHIFT_AMOUNT_TOO_BIG_ERROR_ID: i128 = 540;
35}
36
37impl BasicSnippet for ShiftRight {
38 fn inputs(&self) -> Vec<(DataType, String)> {
39 let arg = (DataType::U128, "arg".to_string());
40 let shift_amount = (DataType::U32, "shift_amount".to_string());
41
42 vec![arg, shift_amount]
43 }
44
45 fn outputs(&self) -> Vec<(DataType, String)> {
46 vec![(DataType::U128, "shifted_arg".to_string())]
47 }
48
49 fn entrypoint(&self) -> String {
50 "tasmlib_arithmetic_u128_shift_right".to_string()
51 }
52
53 fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
54 let entrypoint = self.entrypoint();
55 let shift_amount_gt_32 = format!("{entrypoint}_shift_amount_gt_32");
56
57 triton_asm!(
58 // BEFORE: _ v_3 v_2 v_1 v_0 s
59 // AFTER: _ (v >> s)_3 (v >> s)_2 (v >> s)_1 (v >> s)_0
60 {entrypoint}:
61 /* bounds check */
62 push 128
63 dup 1
64 lt
65 assert error_id {Self::SHIFT_AMOUNT_TOO_BIG_ERROR_ID}
66 // _ v_3 v_2 v_1 v_0 s
67
68 /* special case if shift amount is greater than 32 */
69 dup 0
70 push 32
71 lt // _ v_3 v_2 v_1 v_0 s (s > 32)
72 skiz
73 call {shift_amount_gt_32}
74 // _ v_3 v_2 v_1 v_0 s
75
76 /* for an explanation, see snippet “u64::ShiftRight” */
77 push -1
78 mul
79 addi 32 // _ v_3 v_2 v_1 v_0 (32 - s)
80 push 2
81 pow // _ v_3 v_2 v_1 v_0 (2^(32 - s))
82
83 dup 0
84 pick 5
85 mul // _ v_2 v_1 v_0 (2^(32 - s)) v_3s
86 place 4
87 xb_mul // _ v_3s v_2s v_1s v_0s
88
89 pick 3
90 split // _ v_2s v_1s v_0s (v >> s)_3 c_2
91 pick 4
92 split // _ v_1s v_0s (v >> s)_3 c_2 (v_2 >> s) c_1
93 pick 5
94 split // _ v_0s (v >> s)_3 c_2 (v_2 >> s) c_1 (v_1 >> s) c_0
95 pick 6
96 split // _ (v >> s)_3 c_2 (v_2 >> s) c_1 (v_1 >> s) c_0 (v_0 >> s) trash
97
98 pop 1 // _ (v >> s)_3 c_2 (v_2 >> s) c_1 (v_1 >> s) c_0 (v_0 >> s)
99 add // _ (v >> s)_3 c_2 (v_2 >> s) c_1 (v_1 >> s) (v >> s)_0
100 place 4 // _ (v >> s)_3 (v >> s)_0 c_2 (v_2 >> s) c_1 (v_1 >> s)
101 add // _ (v >> s)_3 (v >> s)_0 c_2 (v_2 >> s) (v >> s)_1
102 place 3 // _ (v >> s)_3 (v >> s)_1 (v >> s)_0 c_2 (v_2 >> s)
103 add // _ (v >> s)_3 (v >> s)_1 (v >> s)_0 (v_2 >> s)_2
104 place 2 // _ (v >> s)_3 (v_2 >> s)_2 (v >> s)_1 (v >> s)_0
105
106 return
107
108 // BEFORE: _ [v: u128] s
109 // AFTER: _ [v >> i·32: u128] (s - i·32)
110 // such that i·32 <= s < (i+1)·32
111 {shift_amount_gt_32}:
112 addi -32 // _ v_3 v_2 v_1 v_0 (s - 32)
113 pick 1
114 pop 1 // _ v_3 v_2 v_1 (s - 32)
115 push 0
116 place 4 // _ 0 v_3 v_2 v_1 (s - 32)
117
118 dup 0
119 push 32
120 lt
121 skiz
122 recurse
123 return
124 )
125 }
126
127 fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
128 let mut sign_offs = HashMap::new();
129 sign_offs.insert(Reviewer("ferdinand"), 0x9875596d880d6dd0.into());
130 sign_offs
131 }
132}
133
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;

    use super::*;
    use crate::test_prelude::*;

    impl Closure for ShiftRight {
        type Args = (u128, u32);

        /// Reference implementation: pops `(arg, shift_amount)` and pushes
        /// `arg >> shift_amount`, rejecting shift amounts of 128 or more.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (value, amount) = pop_encodable::<Self::Args>(stack);
            // mirror the snippet's bounds check
            assert!(amount < 128);
            let shifted = value >> amount;
            push_encodable(stack, &shifted);
        }

        /// Fixed inputs for the benchmark cases; otherwise a random `u128`
        /// together with a random in-range shift amount.
        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => (0x642, 20),
                Some(BenchmarkCase::WorstCase) => (0x123, 127),
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    (rng.random(), rng.random_range(0..128))
                }
            }
        }

        /// Every interesting argument combined with every legal shift amount.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let interesting_args: [u128; 6] =
                [0, 1 << 3, 1 << 64, u64::MAX.into(), 1 << 127, u128::MAX];

            interesting_args
                .into_iter()
                .cartesian_product(0..128)
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(ShiftRight).test();
    }

    /// Any shift amount of 128 or more must trip the snippet's bounds check.
    #[proptest]
    fn too_large_shift_crashes_vm(arg: u128, #[strategy(128_u32..)] shift_amount: u32) {
        let initial_stack = ShiftRight.set_up_test_stack((arg, shift_amount));
        let expected_errors = [ShiftRight::SHIFT_AMOUNT_TOO_BIG_ERROR_ID];

        test_assertion_failure(
            &ShadowedClosure::new(ShiftRight),
            InitVmState::with_stack(initial_stack),
            &expected_errors,
        )
    }
}
185
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the snippet's benchmark through its shadowed closure, presumably
    /// over the `BenchmarkCase` inputs from `pseudorandom_args`.
    #[test]
    fn benchmark() {
        let snippet = ShiftRight;
        ShadowedClosure::new(snippet).bench();
    }
}
195}