tasm_lib/arithmetic/i128/shift_right.rs
1use triton_vm::prelude::*;
2
3use crate::arithmetic::u32::is_u32::IsU32;
4use crate::arithmetic::u32::shift_left::ShiftLeft as ShlU32;
5use crate::arithmetic::u32::shift_right::ShiftRight as ShrU32;
6use crate::prelude::*;
7
/// Arithmetic (sign-extending) right shift for 128-bit signed integers, i.e.,
/// the `i128` flavor of [right-shift][shr].
///
/// # Behavior
///
/// ```text
/// BEFORE: _ arg3 arg2 arg1 arg0 shamt
/// AFTER:  _ res3 res2 res1 res0
/// ```
///
/// Here `res == arg >> shamt`, where both `arg` and `res` are read as `i128`s
/// whose limbs are listed most significant first (`arg3` holds the sign bit).
///
/// # Preconditions
///
/// - every limb of `arg` is a `u32`
/// - `shamt` lies in `[0:128)`
///
/// # Postconditions
///
/// - every limb of `res` is a `u32`
///
/// # Panics
///
/// - If any precondition is violated.
///
/// [shr]: core::ops::Shr
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShiftRight;
35
36impl ShiftRight {
37 pub const ARGUMENT_LIMB_3_NOT_U32_ERROR_ID: i128 = 323;
38 pub const ARGUMENT_LIMB_2_NOT_U32_ERROR_ID: i128 = 322;
39 pub const ARGUMENT_LIMB_1_NOT_U32_ERROR_ID: i128 = 321;
40 pub const ARGUMENT_LIMB_0_NOT_U32_ERROR_ID: i128 = 320;
41 pub const SHAMT_NOT_U32_ERROR_ID: i128 = 324;
42}
43
44impl BasicSnippet for ShiftRight {
45 fn inputs(&self) -> Vec<(DataType, String)> {
46 vec![
47 (DataType::I128, "arg".to_string()),
48 (DataType::U32, "shamt".to_string()),
49 ]
50 }
51
52 fn outputs(&self) -> Vec<(DataType, String)> {
53 vec![(DataType::I128, "res".to_string())]
54 }
55
56 fn entrypoint(&self) -> String {
57 "tasmlib_arithmetic_i128_shift_right".to_string()
58 }
59
60 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
61 let entrypoint = self.entrypoint();
62 let shr_i128_by_32n = format!("{entrypoint}_by_32n");
63 let clean_up_for_early_return = format!("{entrypoint}_early_return");
64 let entrypoint = self.entrypoint();
65
66 let is_u32 = library.import(Box::new(IsU32));
67 let shr_u32 = library.import(Box::new(ShrU32));
68 let shl_u32 = library.import(Box::new(ShlU32));
69
70 triton_asm! {
71 // BEFORE: _ arg3 arg2 arg1 arg0 shamt
72 // AFTER: _ res3 res2 res1 res0
73 {entrypoint}:
74
75 /* assert preconditions */
76
77 dup 4 dup 4 dup 4 dup 4
78 // _ arg3 arg2 arg1 arg0 shamt arg3 arg2 arg1 arg0
79
80 push 128 dup 5
81 // _ arg3 arg2 arg1 arg0 shamt arg3 arg2 arg1 arg0 128 shamt
82
83 lt
84 // _ arg3 arg2 arg1 arg0 shamt arg3 arg2 arg1 arg0 (shamt < 128)
85
86 assert error_id {Self::SHAMT_NOT_U32_ERROR_ID}
87 // _ arg3 arg2 arg1 arg0 shamt arg3 arg2 arg1 arg0
88
89 call {is_u32} assert error_id {Self::ARGUMENT_LIMB_0_NOT_U32_ERROR_ID}
90 call {is_u32} assert error_id {Self::ARGUMENT_LIMB_1_NOT_U32_ERROR_ID}
91 call {is_u32} assert error_id {Self::ARGUMENT_LIMB_2_NOT_U32_ERROR_ID}
92 call {is_u32} assert error_id {Self::ARGUMENT_LIMB_3_NOT_U32_ERROR_ID}
93 // _ arg3 arg2 arg1 arg0 shamt
94
95
96 /* extract top bit */
97
98 dup 4 push 31 call {shr_u32}
99 hint msb = stack[0]
100 // _ arg3 arg2 arg1 arg0 shamt msb
101
102
103 /* shift right by multiple of 32 */
104
105 call {shr_i128_by_32n}
106 // _ arg3' arg2' arg1' arg0' (shamt % 32) msb
107 // _ arg3' arg2' arg1' arg0' shamt' msb
108
109
110 /* early return if possible */
111 dup 1 push 0 eq dup 0
112 // _ arg3' arg2' arg1' arg0' shamt' msb (shamt' == 0) (shamt' == 0)
113
114 skiz call {clean_up_for_early_return}
115 skiz return
116 // _ arg3' arg2' arg1' arg0' shamt' msb
117
118
119 /* shift right by the remainder modulo 32 */
120
121 push 32 dup 2 push -1 mul add
122 // _ arg3' arg2' arg1' arg0' shamt' msb (32-shamt')
123 // _ arg3' arg2' arg1' arg0' shamt' msb compl'
124
125 push {u32::MAX} dup 2 mul
126 // _ arg3' arg2' arg1' arg0' shamt' msb compl' (u32::MAX * msb)
127
128 dup 1 call {shl_u32}
129 // _ arg3' arg2' arg1' arg0' shamt' msb compl' ((u32::MAX * msb) << compl')
130 // _ arg3' arg2' arg1' arg0' shamt' msb compl' new_ms_limb
131
132 pick 7 dup 0
133 // _ arg2' arg1' arg0' shamt' msb compl' new_ms_limb arg3' arg3'
134
135 dup 3 call {shl_u32}
136 // _ arg2' arg1' arg0' shamt' msb compl' new_ms_limb arg3' (arg3' << compl')
137 // _ arg2' arg1' arg0' shamt' msb compl' new_ms_limb arg3' arg3'_lo
138
139 place 2
140 // _ arg2' arg1' arg0' shamt' msb compl' arg3'_lo new_ms_limb arg3'
141
142 dup 5 call {shr_u32}
143 // _ arg2' arg1' arg0' shamt' msb compl' arg3'_lo new_ms_limb (arg3' >> shamt')
144 // _ arg2' arg1' arg0' shamt' msb compl' arg3'_lo new_ms_limb arg3_hi
145
146 add
147 // _ arg2' arg1' arg0' shamt' msb compl' arg3'_lo arg3''
148
149 swap 7 dup 0
150 // _ arg3'' arg1' arg0' shamt' msb compl' arg3'_lo arg2' arg2'
151
152 dup 3 call {shl_u32}
153 // _ arg3'' arg1' arg0' shamt' msb compl' arg3'_lo arg2' (arg2' << compl')
154 // _ arg3'' arg1' arg0' shamt' msb compl' arg3'_lo arg2' arg2'_lo
155
156 place 2
157 // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo arg3'_lo arg2'
158
159 dup 5 call {shr_u32}
160 // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo arg3'_lo (arg2' >> shamt')
161 // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo arg3'_lo arg2'_hi
162
163 add
164 // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo (arg3'_lo + arg2'_hi)
165 // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo arg2''
166
167 swap 6 dup 0
168 // _ arg3'' arg2'' arg0' shamt' msb compl' arg2'_lo arg1' arg1'
169
170 dup 3 call {shl_u32}
171 // _ arg3'' arg2'' arg0' shamt' msb compl' arg2'_lo arg1' (arg1' << compl')
172 // _ arg3'' arg2'' arg0' shamt' msb compl' arg2'_lo arg1' arg1'_lo
173
174 place 2
175 // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo arg2'_lo arg1'
176
177 dup 5 call {shr_u32}
178 // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo arg2'_lo (arg1' >> shamt')
179 // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo arg2'_lo arg1'_hi
180
181 add
182 // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo (arg2'_lo+ arg1'_hi)
183 // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo arg1''
184
185 swap 5
186 // _ arg3'' arg2'' arg1'' shamt' msb compl' arg1'_lo arg0'
187
188 pick 4
189 // _ arg3'' arg2'' arg1'' msb compl' arg1'_lo arg0' shamt'
190
191 call {shr_u32}
192 // _ arg3'' arg2'' arg1'' msb compl' arg1'_lo (arg0' >> shamt')
193 // _ arg3'' arg2'' arg1'' msb compl' arg1'_lo arg0'_hi
194
195 add
196 // _ arg3'' arg2'' arg1'' msb compl' (arg1'_lo + arg0'_hi)
197 // _ arg3'' arg2'' arg1'' msb compl' argo0''
198
199 place 2 pop 2
200 // _ arg3'' arg2'' arg1'' argo0''
201
202 return
203
204 // BEFORE: _ arg3 arg2 arg1 arg0 shamt msb
205 // AFTER: _ arg3' arg2' arg1' arg0' shamt' msb
206 // where `arg >> shamt == arg' >> shamt'` and `shamt' < 32`
207 {shr_i128_by_32n}:
208
209 /* evaluate termination condition */
210
211 push 32 dup 2 lt
212 // _ arg3 arg2 arg1 arg0 shamt msb (shamt < 32)
213
214 skiz return
215
216
217 /* apply one limb-shift */
218
219 push {u32::MAX} dup 1 mul
220 // _ arg3 arg2 arg1 arg0 shamt msb (u32::MAX * msb)
221 // _ arg3 arg2 arg1 arg0 shamt msb ms_limb
222
223 place 6
224 // _ ms_limb arg3 arg2 arg1 arg0 shamt msb
225
226 pick 2 pop 1
227 // _ ms_limb arg3 arg2 arg1 shamt msb
228
229 pick 1 addi -32 place 1
230 // _ ms_limb arg3 arg2 arg1 (shamt-32) msb
231
232 recurse
233
234 // BEFORE: _ arg3' arg2' arg1' arg0' shamt' msb b
235 // AFTER: _ arg3' arg2' arg1' arg0' b
236 {clean_up_for_early_return}:
237 place 2
238 pop 2
239 return
240
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::tasm_final_state;
    use crate::test_prelude::*;

    impl ShiftRight {
        /// Runs the snippet on `(arg, shamt)` and checks that its final stack
        /// matches the one produced by the Rust shadow.
        fn assert_expected_shift_behavior(&self, arg: i128, shamt: u32) {
            let initial_stack = self.set_up_test_stack((arg, shamt));

            let mut expected_stack = initial_stack.clone();
            self.rust_shadow(&mut expected_stack);

            test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(Self),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_stack),
            );
        }
    }

    impl Closure for ShiftRight {
        type Args = (i128, u32);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (arg, shift_amount) = pop_encodable::<Self::Args>(stack);
            push_encodable(stack, &(arg >> shift_amount));
        }

        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            (rng.random(), rng.random_range(0..128))
        }
    }

    #[test]
    fn standard_test() {
        ShadowedClosure::new(ShiftRight).test()
    }

    #[proptest]
    fn proptest(#[strategy(arb())] arg: i128, #[strategy(0u32..128)] shamt: u32) {
        ShiftRight.assert_expected_shift_behavior(arg, shamt);
    }

    #[test]
    fn test_edge_cases() {
        // every i128 of the form l0 + l1*2^32 + l2*2^64 + l3*2^96 with each
        // li in {-1, 0, 1}; this includes 0, +-1, and values whose negative
        // terms borrow from (sign-extend into) the more significant limbs
        let arguments = (0..4)
            .map(|_| [-1, 0, 1])
            .multi_cartesian_product()
            .map(|limbs| <[i128; 4]>::try_from(limbs).unwrap())
            .map(|[l0, l1, l2, l3]| l0 + (l1 << 32) + (l2 << 64) + (l3 << 96));

        // 0, 1, 16, and 31 bits beyond each multiple of 32 below 128
        let shift_amounts = [0, 1, 16, 31]
            .into_iter()
            .cartesian_product(0..4)
            .map(|(l, r)| l + 32 * r);

        arguments
            .cartesian_product(shift_amounts)
            .for_each(|(arg, shamt)| ShiftRight.assert_expected_shift_behavior(arg, shamt));
    }

    /// Shifting right by 127 must produce either 0xff..f (all ones) for a
    /// negative i128-argument, or 0x00..0 for a non-negative one (including 0).
    #[proptest(cases = 50)]
    fn shifting_right_by_127_is_zero_or_minus_1(arg: i128) {
        let mut final_state = tasm_final_state(
            &ShadowedClosure::new(ShiftRight),
            &ShiftRight.set_up_test_stack((arg, 127)),
            &[],
            NonDeterminism::default(),
            &None,
        );

        let final_stack = &mut final_state.op_stack.stack;
        let num_bits_in_result = pop_encodable::<i128>(final_stack).count_ones();

        // branch on `is_negative`, not `is_positive`: `0 >> 127 == 0`, so zero
        // belongs to the all-zeros case
        let expected_num_bits = if arg.is_negative() { i128::BITS } else { 0 };
        prop_assert_eq!(expected_num_bits, num_bits_in_result);
    }
}
334
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the `i128` right-shift snippet through its `ShadowedClosure`
    /// wrapper.
    #[test]
    fn benchmark() {
        let snippet = ShiftRight;
        ShadowedClosure::new(snippet).bench();
    }
}