// tasm_lib/arithmetic/i128/shift_right.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::u32::is_u32::IsU32;
4use crate::arithmetic::u32::shift_left::ShiftLeft as ShlU32;
5use crate::arithmetic::u32::shift_right::ShiftRight as ShrU32;
6use crate::prelude::*;
7
/// Arithmetic (i.e., sign-extending) right shift of a 128-bit signed integer –
/// the [right-shift operation][shr] on `i128`.
///
/// # Behavior
///
/// ```text
/// BEFORE: _ arg3 arg2 arg1 arg0 shamt
/// AFTER:  _ res3 res2 res1 res0
/// ```
///
/// where, interpreting both `arg` and `res` as `i128`s, `res == arg >> shamt`.
/// Limb `arg0` is the least significant one; the sign bit is the top bit of
/// `arg3`.
///
/// # Preconditions
///
///  - every one of the 4 limbs of `arg` is a `u32`
///  - `shamt` lies in `[0:128)`
///
/// # Postconditions
///
///  - every one of the 4 limbs of `res` is a `u32`
///
/// # Panics
///
///  - If any of the preconditions is violated.
///
/// [shr]: core::ops::Shr
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShiftRight;
35
36impl ShiftRight {
37    pub const ARGUMENT_LIMB_3_NOT_U32_ERROR_ID: i128 = 323;
38    pub const ARGUMENT_LIMB_2_NOT_U32_ERROR_ID: i128 = 322;
39    pub const ARGUMENT_LIMB_1_NOT_U32_ERROR_ID: i128 = 321;
40    pub const ARGUMENT_LIMB_0_NOT_U32_ERROR_ID: i128 = 320;
41    pub const SHAMT_NOT_U32_ERROR_ID: i128 = 324;
42}
43
44impl BasicSnippet for ShiftRight {
45    fn inputs(&self) -> Vec<(DataType, String)> {
46        vec![
47            (DataType::I128, "arg".to_string()),
48            (DataType::U32, "shamt".to_string()),
49        ]
50    }
51
52    fn outputs(&self) -> Vec<(DataType, String)> {
53        vec![(DataType::I128, "res".to_string())]
54    }
55
56    fn entrypoint(&self) -> String {
57        "tasmlib_arithmetic_i128_shift_right".to_string()
58    }
59
60    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
61        let entrypoint = self.entrypoint();
62        let shr_i128_by_32n = format!("{entrypoint}_by_32n");
63        let clean_up_for_early_return = format!("{entrypoint}_early_return");
64        let entrypoint = self.entrypoint();
65
66        let is_u32 = library.import(Box::new(IsU32));
67        let shr_u32 = library.import(Box::new(ShrU32));
68        let shl_u32 = library.import(Box::new(ShlU32));
69
70        triton_asm! {
71            // BEFORE: _ arg3 arg2 arg1 arg0 shamt
72            // AFTER: _ res3 res2 res1 res0
73            {entrypoint}:
74
75                /* assert preconditions */
76
77                dup 4 dup 4 dup 4 dup 4
78                // _ arg3 arg2 arg1 arg0 shamt arg3 arg2 arg1 arg0
79
80                push 128 dup 5
81                // _ arg3 arg2 arg1 arg0 shamt arg3 arg2 arg1 arg0 128 shamt
82
83                lt
84                // _ arg3 arg2 arg1 arg0 shamt arg3 arg2 arg1 arg0 (shamt < 128)
85
86                assert error_id {Self::SHAMT_NOT_U32_ERROR_ID}
87                // _ arg3 arg2 arg1 arg0 shamt arg3 arg2 arg1 arg0
88
89                call {is_u32} assert error_id {Self::ARGUMENT_LIMB_0_NOT_U32_ERROR_ID}
90                call {is_u32} assert error_id {Self::ARGUMENT_LIMB_1_NOT_U32_ERROR_ID}
91                call {is_u32} assert error_id {Self::ARGUMENT_LIMB_2_NOT_U32_ERROR_ID}
92                call {is_u32} assert error_id {Self::ARGUMENT_LIMB_3_NOT_U32_ERROR_ID}
93                // _ arg3 arg2 arg1 arg0 shamt
94
95
96                /* extract top bit */
97
98                dup 4 push 31 call {shr_u32}
99                hint msb = stack[0]
100                // _ arg3 arg2 arg1 arg0 shamt msb
101
102
103                /* shift right by multiple of 32 */
104
105                call {shr_i128_by_32n}
106                // _ arg3' arg2' arg1' arg0' (shamt % 32) msb
107                // _ arg3' arg2' arg1' arg0' shamt' msb
108
109
110                /* early return if possible */
111                dup 1 push 0 eq dup 0
112                // _ arg3' arg2' arg1' arg0' shamt' msb (shamt' == 0) (shamt' == 0)
113
114                skiz call {clean_up_for_early_return}
115                skiz return
116                // _ arg3' arg2' arg1' arg0' shamt' msb
117
118
119                /* shift right by the remainder modulo 32 */
120
121                push 32 dup 2 push -1 mul add
122                // _ arg3' arg2' arg1' arg0' shamt' msb (32-shamt')
123                // _ arg3' arg2' arg1' arg0' shamt' msb compl'
124
125                push {u32::MAX} dup 2 mul
126                // _ arg3' arg2' arg1' arg0' shamt' msb compl' (u32::MAX * msb)
127
128                dup 1 call {shl_u32}
129                // _ arg3' arg2' arg1' arg0' shamt' msb compl' ((u32::MAX * msb) << compl')
130                // _ arg3' arg2' arg1' arg0' shamt' msb compl' new_ms_limb
131
132                pick 7 dup 0
133                // _ arg2' arg1' arg0' shamt' msb compl' new_ms_limb arg3' arg3'
134
135                dup 3 call {shl_u32}
136                // _ arg2' arg1' arg0' shamt' msb compl' new_ms_limb arg3' (arg3' << compl')
137                // _ arg2' arg1' arg0' shamt' msb compl' new_ms_limb arg3' arg3'_lo
138
139                place 2
140                // _ arg2' arg1' arg0' shamt' msb compl' arg3'_lo new_ms_limb arg3'
141
142                dup 5 call {shr_u32}
143                // _ arg2' arg1' arg0' shamt' msb compl' arg3'_lo new_ms_limb (arg3' >> shamt')
144                // _ arg2' arg1' arg0' shamt' msb compl' arg3'_lo new_ms_limb arg3_hi
145
146                add
147                // _ arg2' arg1' arg0' shamt' msb compl' arg3'_lo arg3''
148
149                swap 7 dup 0
150                // _ arg3'' arg1' arg0' shamt' msb compl' arg3'_lo arg2' arg2'
151
152                dup 3 call {shl_u32}
153                // _ arg3'' arg1' arg0' shamt' msb compl' arg3'_lo arg2' (arg2' << compl')
154                // _ arg3'' arg1' arg0' shamt' msb compl' arg3'_lo arg2' arg2'_lo
155
156                place 2
157                // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo arg3'_lo arg2'
158
159                dup 5 call {shr_u32}
160                // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo arg3'_lo (arg2' >> shamt')
161                // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo arg3'_lo arg2'_hi
162
163                add
164                // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo (arg3'_lo + arg2'_hi)
165                // _ arg3'' arg1' arg0' shamt' msb compl' arg2'_lo arg2''
166
167                swap 6 dup 0
168                // _ arg3'' arg2'' arg0' shamt' msb compl' arg2'_lo arg1' arg1'
169
170                dup 3 call {shl_u32}
171                // _ arg3'' arg2'' arg0' shamt' msb compl' arg2'_lo arg1' (arg1' << compl')
172                // _ arg3'' arg2'' arg0' shamt' msb compl' arg2'_lo arg1' arg1'_lo
173
174                place 2
175                // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo arg2'_lo arg1'
176
177                dup 5 call {shr_u32}
178                // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo arg2'_lo (arg1' >> shamt')
179                // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo arg2'_lo arg1'_hi
180
181                add
182                // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo (arg2'_lo+ arg1'_hi)
183                // _ arg3'' arg2'' arg0' shamt' msb compl' arg1'_lo arg1''
184
185                swap 5
186                // _ arg3'' arg2'' arg1'' shamt' msb compl' arg1'_lo arg0'
187
188                pick 4
189                // _ arg3'' arg2'' arg1'' msb compl' arg1'_lo arg0' shamt'
190
191                call {shr_u32}
192                // _ arg3'' arg2'' arg1'' msb compl' arg1'_lo (arg0' >> shamt')
193                // _ arg3'' arg2'' arg1'' msb compl' arg1'_lo arg0'_hi
194
195                add
196                // _ arg3'' arg2'' arg1'' msb compl' (arg1'_lo + arg0'_hi)
197                // _ arg3'' arg2'' arg1'' msb compl' argo0''
198
199                place 2 pop 2
200                // _ arg3'' arg2'' arg1'' argo0''
201
202                return
203
204            // BEFORE: _ arg3  arg2  arg1  arg0  shamt  msb
205            // AFTER:  _ arg3' arg2' arg1' arg0' shamt' msb
206            // where `arg >> shamt == arg' >> shamt'` and `shamt' < 32`
207            {shr_i128_by_32n}:
208
209                /* evaluate termination condition */
210
211                push 32 dup 2 lt
212                // _ arg3 arg2 arg1 arg0 shamt msb (shamt < 32)
213
214                skiz return
215
216
217                /* apply one limb-shift */
218
219                push {u32::MAX} dup 1 mul
220                // _ arg3 arg2 arg1 arg0 shamt msb (u32::MAX * msb)
221                // _ arg3 arg2 arg1 arg0 shamt msb ms_limb
222
223                place 6
224                // _ ms_limb arg3 arg2 arg1 arg0 shamt msb
225
226                pick 2 pop 1
227                // _ ms_limb arg3 arg2 arg1 shamt msb
228
229                pick 1 addi -32 place 1
230                // _ ms_limb arg3 arg2 arg1 (shamt-32) msb
231
232                recurse
233
234            // BEFORE: _ arg3' arg2' arg1' arg0' shamt' msb b
235            // AFTER:  _ arg3' arg2' arg1' arg0' b
236            {clean_up_for_early_return}:
237                place 2
238                pop 2
239                return
240
241        }
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::tasm_final_state;
    use crate::test_prelude::*;

    impl ShiftRight {
        /// Runs the snippet on `(arg, shamt)` and asserts that the resulting
        /// stack equals the one produced by the Rust shadow.
        fn assert_expected_shift_behavior(&self, arg: i128, shamt: u32) {
            let initial_stack = self.set_up_test_stack((arg, shamt));

            let mut expected_stack = initial_stack.clone();
            self.rust_shadow(&mut expected_stack);

            test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(Self),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_stack),
            );
        }
    }

    impl Closure for ShiftRight {
        type Args = (i128, u32);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (arg, shift_amount) = pop_encodable::<Self::Args>(stack);
            push_encodable(stack, &(arg >> shift_amount));
        }

        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            (rng.random(), rng.random_range(0..128))
        }
    }

    #[test]
    fn standard_test() {
        ShadowedClosure::new(ShiftRight).test()
    }

    #[proptest]
    fn proptest(#[strategy(arb())] arg: i128, #[strategy(0u32..128)] shamt: u32) {
        ShiftRight.assert_expected_shift_behavior(arg, shamt);
    }

    #[test]
    fn test_edge_cases() {
        // all i128s from all combinations of {-1, 0, 1} as their limbs
        let arguments = (0..4)
            .map(|_| [-1, 0, 1])
            .multi_cartesian_product()
            .map(|limbs| <[i128; 4]>::try_from(limbs).unwrap())
            .map(|[l0, l1, l2, l3]| l0 + (l1 << 32) + (l2 << 64) + (l3 << 96));

        let shift_amounts = [0, 1, 16, 31]
            .into_iter()
            .cartesian_product(0..4)
            .map(|(l, r)| l + 32 * r);

        arguments
            .cartesian_product(shift_amounts)
            .for_each(|(arg, shamt)| ShiftRight.assert_expected_shift_behavior(arg, shamt));
    }

    /// Shifting right by 127 must produce 0x00..0 for non-negative arguments
    /// (including zero) and 0xff..f, i.e., -1, for negative arguments.
    #[proptest(cases = 50)]
    fn shifting_right_by_127_is_zero_or_minus_1(arg: i128) {
        let mut final_state = tasm_final_state(
            &ShadowedClosure::new(ShiftRight),
            &ShiftRight.set_up_test_stack((arg, 127)),
            &[],
            NonDeterminism::default(),
            &None,
        );

        let final_stack = &mut final_state.op_stack.stack;
        let num_bits_in_result = pop_encodable::<i128>(final_stack).count_ones();

        // Branch on the sign bit: `0 >> 127 == 0`, so zero must land in the
        // "no bits set" case (branching on `is_positive` would misclassify it).
        let expected_num_bits = if arg.is_negative() { i128::BITS } else { 0 };
        prop_assert_eq!(expected_num_bits, num_bits_in_result);
    }
}
334
#[cfg(test)]
mod benches {
    use crate::test_prelude::*;

    use super::*;

    /// Benchmarks [`ShiftRight`] through the shadowed-closure harness.
    #[test]
    fn benchmark() {
        ShadowedClosure::new(ShiftRight).bench();
    }
}
344}