Skip to main content

tfhe/integer/server_key/radix_parallel/
shift.rs

1use crate::integer::ciphertext::{IntegerRadixCiphertext, RadixCiphertext};
2use crate::integer::server_key::radix_parallel::bit_extractor::BitExtractor;
3use crate::integer::ServerKey;
4use rayon::prelude::*;
5
6#[derive(Clone, Copy, PartialEq, Eq)]
7pub(super) enum BarrelShifterOperation {
8    LeftRotate,
9    LeftShift,
10    RightShift,
11    RightRotate,
12}
13
14impl BarrelShifterOperation {
15    pub(super) fn invert_direction(self) -> Self {
16        match self {
17            Self::LeftRotate => Self::RightRotate,
18            Self::LeftShift => Self::RightShift,
19            Self::RightShift => Self::LeftShift,
20            Self::RightRotate => Self::LeftRotate,
21        }
22    }
23}
24
25impl ServerKey {
26    //======================================================================
27    //                Shift Right
28    //======================================================================
29
30    pub fn unchecked_right_shift_parallelized<T>(&self, ct_left: &T, shift: &RadixCiphertext) -> T
31    where
32        T: IntegerRadixCiphertext,
33    {
34        let mut result = ct_left.clone();
35        self.unchecked_right_shift_assign_parallelized(&mut result, shift);
36        result
37    }
38
39    pub fn unchecked_right_shift_assign_parallelized<T>(&self, ct: &mut T, shift: &RadixCiphertext)
40    where
41        T: IntegerRadixCiphertext,
42    {
43        self.unchecked_shift_rotate_bits_assign(ct, shift, BarrelShifterOperation::RightShift);
44    }
45
46    pub fn smart_right_shift_assign_parallelized<T>(&self, ct: &mut T, shift: &mut RadixCiphertext)
47    where
48        T: IntegerRadixCiphertext,
49    {
50        rayon::join(
51            || {
52                if !ct.block_carries_are_empty() {
53                    self.full_propagate_parallelized(ct);
54                }
55            },
56            || {
57                if !shift.block_carries_are_empty() {
58                    self.full_propagate_parallelized(shift);
59                }
60            },
61        );
62        self.unchecked_right_shift_assign_parallelized(ct, shift);
63    }
64
65    pub fn smart_right_shift_parallelized<T>(&self, ct: &mut T, shift: &mut RadixCiphertext) -> T
66    where
67        T: IntegerRadixCiphertext,
68    {
69        rayon::join(
70            || {
71                if !ct.block_carries_are_empty() {
72                    self.full_propagate_parallelized(ct);
73                }
74            },
75            || {
76                if !shift.block_carries_are_empty() {
77                    self.full_propagate_parallelized(shift);
78                }
79            },
80        );
81        self.unchecked_right_shift_parallelized(ct, shift)
82    }
83
84    pub fn right_shift_assign_parallelized<T>(&self, ct: &mut T, shift: &RadixCiphertext)
85    where
86        T: IntegerRadixCiphertext,
87    {
88        let mut tmp_rhs;
89
90        let (lhs, rhs) = match (
91            ct.block_carries_are_empty(),
92            shift.block_carries_are_empty(),
93        ) {
94            (true, true) => (ct, shift),
95            (true, false) => {
96                tmp_rhs = shift.clone();
97                self.full_propagate_parallelized(&mut tmp_rhs);
98                (ct, &tmp_rhs)
99            }
100            (false, true) => {
101                self.full_propagate_parallelized(ct);
102                (ct, shift)
103            }
104            (false, false) => {
105                tmp_rhs = shift.clone();
106                rayon::join(
107                    || self.full_propagate_parallelized(ct),
108                    || self.full_propagate_parallelized(&mut tmp_rhs),
109                );
110                (ct, &tmp_rhs)
111            }
112        };
113
114        self.unchecked_right_shift_assign_parallelized(lhs, rhs);
115    }
116
117    /// Computes homomorphically a right shift by an encrypted amount
118    ///
119    /// The result is returned as a new ciphertext.
120    ///
121    /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will
122    /// check that the input ciphertexts block carries are empty and clears them if it's not the
123    /// case and the operation requires it. It outputs a ciphertext whose block carries are always
124    /// empty.
125    ///
126    /// This means that when using only "default" operations, a given operation (like add for
127    /// example) has always the same performance characteristics from one call to another and
128    /// guarantees correctness by pre-emptively clearing carries of output ciphertexts.
129    ///
130    /// # Example
131    ///
132    /// ```rust
133    /// use tfhe::integer::gen_keys_radix;
134    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
135    ///
136    /// // We have 4 * 2 = 8 bits of message
137    /// let size = 4;
138    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, size);
139    ///
140    /// let msg = 128;
141    /// let shift = 2;
142    ///
143    /// let ct = cks.encrypt(msg);
144    /// let shift_ct = cks.encrypt(shift as u64);
145    ///
146    /// // Compute homomorphically a right shift:
147    /// let ct_res = sks.right_shift_parallelized(&ct, &shift_ct);
148    ///
149    /// // Decrypt:
150    /// let dec: u64 = cks.decrypt(&ct_res);
151    /// assert_eq!(msg >> shift, dec);
152    /// ```
153    pub fn right_shift_parallelized<T>(&self, ct: &T, shift: &RadixCiphertext) -> T
154    where
155        T: IntegerRadixCiphertext,
156    {
157        let mut ct_res = ct.clone();
158        self.right_shift_assign_parallelized(&mut ct_res, shift);
159        ct_res
160    }
161
162    //======================================================================
163    //                Shift Left
164    //======================================================================
165
166    /// left shift by and encrypted amount
167    ///
168    /// This requires:
169    /// - ct to have clean carries
170    /// - shift to have clean carries
171    /// - the number of bits in the block to be >= 3
172    pub fn unchecked_left_shift_parallelized<T>(&self, ct_left: &T, shift: &RadixCiphertext) -> T
173    where
174        T: IntegerRadixCiphertext,
175    {
176        let mut result = ct_left.clone();
177        self.unchecked_left_shift_assign_parallelized(&mut result, shift);
178        result
179    }
180
181    /// left shift by and encrypted amount
182    ///
183    /// This requires:
184    /// - ct to have clean carries
185    /// - shift to have clean carries
186    /// - the number of bits in the block to be >= 3
187    pub fn unchecked_left_shift_assign_parallelized<T>(&self, ct: &mut T, shift: &RadixCiphertext)
188    where
189        T: IntegerRadixCiphertext,
190    {
191        self.unchecked_shift_rotate_bits_assign(ct, shift, BarrelShifterOperation::LeftShift);
192    }
193
194    pub fn smart_left_shift_assign_parallelized<T>(&self, ct: &mut T, shift: &mut RadixCiphertext)
195    where
196        T: IntegerRadixCiphertext,
197    {
198        rayon::join(
199            || {
200                if !ct.block_carries_are_empty() {
201                    self.full_propagate_parallelized(ct);
202                }
203            },
204            || {
205                if !shift.block_carries_are_empty() {
206                    self.full_propagate_parallelized(shift);
207                }
208            },
209        );
210        self.unchecked_left_shift_assign_parallelized(ct, shift);
211    }
212
213    pub fn smart_left_shift_parallelized<T>(&self, ct: &mut T, shift: &mut RadixCiphertext) -> T
214    where
215        T: IntegerRadixCiphertext,
216    {
217        rayon::join(
218            || {
219                if !ct.block_carries_are_empty() {
220                    self.full_propagate_parallelized(ct);
221                }
222            },
223            || {
224                if !shift.block_carries_are_empty() {
225                    self.full_propagate_parallelized(shift);
226                }
227            },
228        );
229        self.unchecked_left_shift_parallelized(ct, shift)
230    }
231
232    pub fn left_shift_assign_parallelized<T>(&self, ct: &mut T, shift: &RadixCiphertext)
233    where
234        T: IntegerRadixCiphertext,
235    {
236        let mut tmp_rhs;
237
238        let (lhs, rhs) = match (
239            ct.block_carries_are_empty(),
240            shift.block_carries_are_empty(),
241        ) {
242            (true, true) => (ct, shift),
243            (true, false) => {
244                tmp_rhs = shift.clone();
245                self.full_propagate_parallelized(&mut tmp_rhs);
246                (ct, &tmp_rhs)
247            }
248            (false, true) => {
249                self.full_propagate_parallelized(ct);
250                (ct, shift)
251            }
252            (false, false) => {
253                tmp_rhs = shift.clone();
254                rayon::join(
255                    || self.full_propagate_parallelized(ct),
256                    || self.full_propagate_parallelized(&mut tmp_rhs),
257                );
258                (ct, &tmp_rhs)
259            }
260        };
261
262        self.unchecked_left_shift_assign_parallelized(lhs, rhs);
263    }
264
265    /// Computes homomorphically a left shift by an encrypted amount.
266    ///
267    /// The result is returned as a new ciphertext.
268    ///
269    /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will
270    /// check that the input ciphertexts block carries are empty and clears them if it's not the
271    /// case and the operation requires it. It outputs a ciphertext whose block carries are always
272    /// empty.
273    ///
274    /// This means that when using only "default" operations, a given operation (like add for
275    /// example) has always the same performance characteristics from one call to another and
276    /// guarantees correctness by pre-emptively clearing carries of output ciphertexts.
277    ///
278    /// # Example
279    ///
280    /// ```rust
281    /// use tfhe::integer::gen_keys_radix;
282    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
283    ///
284    /// // We have 4 * 2 = 8 bits of message
285    /// let size = 4;
286    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, size);
287    ///
288    /// let msg = 21;
289    /// let shift = 2;
290    ///
291    /// let ct1 = cks.encrypt(msg);
292    /// let ct2 = cks.encrypt(shift as u64);
293    ///
294    /// // Compute homomorphically a left shift:
295    /// let ct_res = sks.left_shift_parallelized(&ct1, &ct2);
296    ///
297    /// // Decrypt:
298    /// let dec: u64 = cks.decrypt(&ct_res);
299    /// assert_eq!(msg << shift, dec);
300    /// ```
301    pub fn left_shift_parallelized<T>(&self, ct: &T, shift: &RadixCiphertext) -> T
302    where
303        T: IntegerRadixCiphertext,
304    {
305        let mut ct_res = ct.clone();
306        self.left_shift_assign_parallelized(&mut ct_res, shift);
307        ct_res
308    }
309
310    /// Does a rotation/shift of bits of the `ct` by the specified `amount`
311    ///
312    /// Input must not have carries
313    pub(super) fn unchecked_shift_rotate_bits_assign<T>(
314        &self,
315        ct: &mut T,
316        amount: &RadixCiphertext,
317        operation: BarrelShifterOperation,
318    ) where
319        T: IntegerRadixCiphertext,
320    {
321        let message_bits_per_block = self.key.message_modulus.0.ilog2() as u64;
322        let carry_bits_per_block = self.key.carry_modulus.0.ilog2() as u64;
323        assert!(carry_bits_per_block >= message_bits_per_block);
324
325        let num_bits = ct.blocks().len() * message_bits_per_block as usize;
326        let mut max_num_bits_that_tell_shift = num_bits.ilog2() as u64;
327        // This effectively means, that if the block parameters
328        // give a total_nb_bits that is not a power of two,
329        // then the behaviour of shifting won't be the same
330        // if shift >= total_nb_bits compared to when total_nb_bits
331        // is a power of two, as will 'capture' more bits in `shift_bits`
332        if !num_bits.is_power_of_two() {
333            max_num_bits_that_tell_shift += 1;
334        }
335
336        if message_bits_per_block == 1 {
337            let mut shift_bit_extractor = BitExtractor::with_final_offset(
338                &amount.blocks,
339                self,
340                message_bits_per_block as usize,
341                message_bits_per_block as usize,
342            );
343
344            // If blocks encrypt one bit, then shifting bits is just shifting blocks
345            let result = self.block_barrel_shifter_impl(
346                ct,
347                &mut shift_bit_extractor,
348                0..max_num_bits_that_tell_shift as usize,
349                // Our blocks are stored in little endian order
350                operation.invert_direction(),
351            );
352
353            *ct = result;
354        } else if message_bits_per_block.is_power_of_two() {
355            let result = self.barrel_shift_bits_pow2_block_modulus(
356                ct,
357                amount,
358                operation,
359                max_num_bits_that_tell_shift as usize,
360            );
361            *ct = result;
362        } else {
363            self.bit_barrel_shifter(ct, amount, operation);
364        }
365    }
366
/// Rotates or shifts the bits of `ct` by the encrypted `amount` and returns the result
/// as a new ciphertext (`ct` is not modified).
///
/// Uses a barrel shifter implementation in two stages:
/// 1. The first block of `amount` is applied directly with bivariate lookup tables.
///    Its low bits give the shift within a block (`first_shift_block %
///    message_bits_per_block`). The bit worth `message_bits_per_block` says whether
///    everything additionally moves by one block. Each input block yields up to three
///    contributions (to its own position, the next block and the block after that).
///    These are rotated into place and summed.
/// 2. The remaining bits of `amount` are fed to `block_barrel_shifter_impl`, which
///    shifts/rotates whole blocks.
///
/// `max_num_bits_that_tell_shift` bounds how many low-order bits of `amount` influence
/// the result; higher bits are ignored.
///
/// For signed `T` with `BarrelShifterOperation::RightShift`, the most significant block
/// is padded with copies of its sign bit so the shift is arithmetic.
///
/// If `amount` or `ct` has no blocks, a clone of `ct` is returned.
///
/// # Panics
///
/// Panics if the carry modulus holds fewer bits than the message modulus, or if the
/// number of message bits per block is not a power of two.
///
/// # Note
///
/// This only works for parameters where blocks encrypt a number of bits
/// of message that is a power of 2 (e.g. 1 bit, 2 bits, 4 bits, but not 3 bits)
pub(super) fn barrel_shift_bits_pow2_block_modulus<T>(
    &self,
    ct: &T,
    amount: &RadixCiphertext,
    operation: BarrelShifterOperation,
    max_num_bits_that_tell_shift: usize,
) -> T
where
    T: IntegerRadixCiphertext,
{
    if amount.blocks.is_empty() || ct.blocks().is_empty() {
        return ct.clone();
    }

    let message_bits_per_block = self.key.message_modulus.0.ilog2() as u64;
    let carry_bits_per_block = self.key.carry_modulus.0.ilog2() as u64;
    assert!(carry_bits_per_block >= message_bits_per_block);
    assert!(message_bits_per_block.is_power_of_two());

    // Single block: everything happens inside that block, and only the within-block part
    // of the first shift block (`first_shift_block % message_bits_per_block`) is used.
    if ct.blocks().len() == 1 {
        let lut = self
            .key
            .generate_lookup_table_bivariate(|input, first_shift_block| {
                let shift_within_block = first_shift_block % message_bits_per_block;

                match operation {
                    BarrelShifterOperation::LeftShift => {
                        (input << shift_within_block) % self.message_modulus().0
                    }
                    BarrelShifterOperation::LeftRotate => {
                        // Bits pushed out on the left re-enter on the right
                        let shifted = (input << shift_within_block) % self.message_modulus().0;
                        let wrapped = input >> (message_bits_per_block - shift_within_block);
                        shifted | wrapped
                    }
                    BarrelShifterOperation::RightRotate => {
                        // Bits pushed out on the right re-enter on the left
                        let shifted = input >> shift_within_block;
                        let wrapped = (input << (message_bits_per_block - shift_within_block))
                            % self.message_modulus().0;
                        wrapped | shifted
                    }
                    BarrelShifterOperation::RightShift => {
                        if T::IS_SIGNED {
                            let sign_bit_pos = message_bits_per_block - 1;
                            let sign_bit = (input >> sign_bit_pos) & 1;
                            // A whole block made of copies of the sign bit
                            let padding_block = (self.message_modulus().0 - 1) * sign_bit;

                            // Pad with sign bits to 'simulate' an arithmetic shift
                            let input = (padding_block << message_bits_per_block) | input;
                            (input >> shift_within_block) % self.message_modulus().0
                        } else {
                            input >> shift_within_block
                        }
                    }
                }
            });

        let block = self.key.unchecked_apply_lookup_table_bivariate(
            &ct.blocks()[0],
            &amount.blocks[0],
            &lut,
        );

        return T::from_blocks(vec![block]);
    }

    // LUT giving the part of a block's message that stays in that same block after the
    // within-block shift. It is 0 when the first shift block also asks for a one-block
    // move (`shift_to_next_block == 1`), since then nothing stays in place.
    let message_for_block =
        self.key
            .generate_lookup_table_bivariate(|input, first_shift_block| {
                let shift_within_block = first_shift_block % message_bits_per_block;
                let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;

                let b = match operation {
                    BarrelShifterOperation::LeftShift | BarrelShifterOperation::LeftRotate => {
                        (input << shift_within_block) % self.message_modulus().0
                    }
                    BarrelShifterOperation::RightShift
                    | BarrelShifterOperation::RightRotate => {
                        (input >> shift_within_block) % self.message_modulus().0
                    }
                };

                if shift_to_next_block == 1 {
                    0
                } else {
                    b
                }
            });

    // When doing right shift of a signed ciphertext, we do an arithmetic shift
    // Thus, we need some special luts to be used on the last block
    // (which has the sign bit)
    let message_for_block_right_shift_signed =
        if T::IS_SIGNED && operation == BarrelShifterOperation::RightShift {
            let lut = self
                .key
                .generate_lookup_table_bivariate(|input, first_shift_block| {
                    let shift_within_block = first_shift_block % message_bits_per_block;
                    let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;

                    let sign_bit_pos = message_bits_per_block - 1;
                    let sign_bit = (input >> sign_bit_pos) & 1;
                    let padding_block = (self.message_modulus().0 - 1) * sign_bit;

                    if shift_to_next_block == 1 {
                        // The sign block moved down entirely: refill it with sign bits
                        padding_block
                    } else {
                        // Pad with sign bits to 'simulate' an arithmetic shift
                        let input = (padding_block << message_bits_per_block) | input;
                        (input >> shift_within_block) % self.message_modulus().0
                    }
                });
            Some(lut)
        } else {
            None
        };

    // Bit extractor over the blocks of `amount`, created with a final offset of
    // `message_bits_per_block`. The intent is that each extracted bit is already aligned
    // with the cmux input used by `block_barrel_shifter_impl`, which reduces noise growth.
    // NOTE(review): an earlier comment said "bit index 2 (=> bit number 3)", which
    // matches this offset only when message_bits_per_block == 2; confirm the exact
    // meaning of the arguments against `BitExtractor::with_final_offset`.
    let mut shift_bit_extractor = BitExtractor::with_final_offset(
        &amount.blocks,
        self,
        message_bits_per_block as usize,
        message_bits_per_block as usize,
    );

    // LUT giving the contribution of a block (`previous`) to its neighbour in the shift
    // direction. With a one-block move it is the within-block-shifted message. Otherwise
    // it is the bits that overflow out of the block.
    let message_for_next_block =
        self.key
            .generate_lookup_table_bivariate(|previous, first_shift_block| {
                let shift_within_block = first_shift_block % message_bits_per_block;
                let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;

                if shift_to_next_block == 1 {
                    // We get the message part of the previous block
                    match operation {
                        BarrelShifterOperation::LeftShift
                        | BarrelShifterOperation::LeftRotate => {
                            (previous << shift_within_block) % self.message_modulus().0
                        }
                        BarrelShifterOperation::RightShift
                        | BarrelShifterOperation::RightRotate => {
                            (previous >> shift_within_block) % self.message_modulus().0
                        }
                    }
                } else {
                    // We get the carry part of the previous block
                    match operation {
                        BarrelShifterOperation::LeftShift
                        | BarrelShifterOperation::LeftRotate => {
                            previous >> (message_bits_per_block - shift_within_block)
                        }
                        BarrelShifterOperation::RightShift
                        | BarrelShifterOperation::RightRotate => {
                            (previous << (message_bits_per_block - shift_within_block))
                                % self.message_modulus().0
                        }
                    }
                }
            });

    // LUT giving the contribution of a block to the block two positions away in the shift
    // direction. This only happens with a one-block move, where the overflowing bits land
    // one block further.
    let message_for_next_next_block =
        self.key
            .generate_lookup_table_bivariate(|previous_previous, first_shift_block| {
                let shift_within_block = first_shift_block % message_bits_per_block;
                let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;

                if shift_to_next_block == 1 {
                    // We get the carry part of the previous block
                    match operation {
                        BarrelShifterOperation::LeftShift
                        | BarrelShifterOperation::LeftRotate => {
                            previous_previous >> (message_bits_per_block - shift_within_block)
                        }
                        BarrelShifterOperation::RightShift
                        | BarrelShifterOperation::RightRotate => {
                            (previous_previous << (message_bits_per_block - shift_within_block))
                                % self.message_modulus().0
                        }
                    }
                } else {
                    // Nothing reaches that block
                    0
                }
            });

    // Signed right shift only: variant of `message_for_next_block` for the sign block. It
    // sign-pads the message part so the block below receives sign bits in its high
    // positions.
    let message_for_next_block_right_shift_signed = if T::IS_SIGNED
        && operation == BarrelShifterOperation::RightShift
    {
        let lut = self
            .key
            .generate_lookup_table_bivariate(|previous, first_shift_block| {
                let shift_within_block = first_shift_block % message_bits_per_block;
                let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;

                let sign_bit_pos = message_bits_per_block - 1;
                let sign_bit = (previous >> sign_bit_pos) & 1;
                let padding_block = (self.message_modulus().0 - 1) * sign_bit;

                if shift_to_next_block == 1 {
                    // Pad with sign bits to 'simulate' an arithmetic shift
                    let previous = (padding_block << message_bits_per_block) | previous;
                    // We get the message part of the previous block
                    (previous >> shift_within_block) % self.message_modulus().0
                } else {
                    // We get the carry part of the previous block
                    (previous << (message_bits_per_block - shift_within_block))
                        % self.message_modulus().0
                }
            });
        Some(lut)
    } else {
        None
    };

    // One copy of the input blocks per kind of contribution (same block, next block,
    // next-next block). All three LUT passes, plus advancing the bit extractor past the
    // bits consumed here, run concurrently inside the rayon scope below.
    let mut messages = ct.blocks().to_vec();
    let mut messages_for_next_blocks = ct.blocks().to_vec();
    let mut messages_for_next_next_blocks = ct.blocks().to_vec();
    let first_block = &amount.blocks[0];
    let num_blocks = ct.blocks().len();
    rayon::scope(|s| {
        // Contributions kept in place; the sign block uses the signed LUT when needed
        s.spawn(|_| {
            messages.par_iter_mut().enumerate().for_each(|(i, block)| {
                let lut = if T::IS_SIGNED
                    && operation == BarrelShifterOperation::RightShift
                    && i == num_blocks - 1
                {
                    message_for_block_right_shift_signed.as_ref().unwrap()
                } else {
                    &message_for_block
                };
                self.key
                    .unchecked_apply_lookup_table_bivariate_assign(block, first_block, lut);
            });
        });

        // Contributions to the neighbouring block. For non-rotating shifts, the block
        // whose contribution would wrap around is replaced by a trivial 0.
        s.spawn(|_| {
            let range = match operation {
                BarrelShifterOperation::RightShift => {
                    messages_for_next_blocks[0] = self.key.create_trivial(0);
                    1..num_blocks
                }
                BarrelShifterOperation::LeftShift => {
                    messages_for_next_blocks[num_blocks - 1] = self.key.create_trivial(0);
                    0..num_blocks - 1
                }
                BarrelShifterOperation::LeftRotate | BarrelShifterOperation::RightRotate => {
                    0..num_blocks
                }
            };

            let range_len = range.len();
            messages_for_next_blocks[range]
                .par_iter_mut()
                .enumerate()
                .for_each(|(i, block)| {
                    // For RightShift the range ends at the most significant (sign) block
                    let lut = if T::IS_SIGNED
                        && operation == BarrelShifterOperation::RightShift
                        && i == range_len - 1
                    {
                        message_for_next_block_right_shift_signed.as_ref().unwrap()
                    } else {
                        &message_for_next_block
                    };
                    self.key.unchecked_apply_lookup_table_bivariate_assign(
                        block,
                        first_block,
                        lut,
                    );
                });
        });

        // Contributions two blocks away. For non-rotating shifts, the two blocks whose
        // contribution would wrap around are replaced by trivial 0s.
        s.spawn(|_| {
            let range = match operation {
                BarrelShifterOperation::RightShift => {
                    messages_for_next_next_blocks[0] = self.key.create_trivial(0);
                    messages_for_next_next_blocks[1] = self.key.create_trivial(0);
                    2..num_blocks
                }
                BarrelShifterOperation::LeftShift => {
                    messages_for_next_next_blocks[num_blocks - 1] = self.key.create_trivial(0);
                    messages_for_next_next_blocks[num_blocks - 2] = self.key.create_trivial(0);
                    0..num_blocks - 2
                }
                BarrelShifterOperation::LeftRotate | BarrelShifterOperation::RightRotate => {
                    0..num_blocks
                }
            };
            messages_for_next_next_blocks[range]
                .par_iter_mut()
                .for_each(|block| {
                    self.key.unchecked_apply_lookup_table_bivariate_assign(
                        block,
                        first_block,
                        &message_for_next_next_block,
                    );
                });
        });

        // Advance the bit extractor past the bits of `amount` consumed by the LUTs above:
        // log2(message_bits_per_block) within-block bits plus the one-block-move bit.
        s.spawn(|_| {
            let num_bit_that_tells_shift_within_blocks = message_bits_per_block.ilog2();
            let num_bits_already_done = num_bit_that_tells_shift_within_blocks + 1;
            if u64::from(num_bits_already_done) == message_bits_per_block {
                // The whole first block was consumed: start extracting from the second one
                shift_bit_extractor.set_source_blocks(&amount.blocks[1..]);
                shift_bit_extractor.prepare_next_batch();
            } else {
                shift_bit_extractor.prepare_next_batch();
                assert!(
                    shift_bit_extractor.current_buffer_len() > num_bits_already_done as usize
                );
                // Now remove bits that were used for the 'shift within blocks'
                for _ in 0..num_bits_already_done {
                    let _ = shift_bit_extractor.next().unwrap();
                }
            }
        });
    });

    // 0 should never be possible
    // NOTE(review): when `amount` has a single block and that block was fully consumed
    // above, the extractor is re-sourced from an empty slice; this assert then presumably
    // relies on callers always passing more than one block in `amount` — verify.
    assert!(shift_bit_extractor.current_buffer_len() >= 1);

    // Move each contribution to its destination block (blocks are little endian, so a
    // left shift moves data towards higher indices).
    match operation {
        BarrelShifterOperation::LeftShift | BarrelShifterOperation::LeftRotate => {
            messages_for_next_blocks.rotate_right(1);
            messages_for_next_next_blocks.rotate_right(2);
        }
        BarrelShifterOperation::RightShift | BarrelShifterOperation::RightRotate => {
            messages_for_next_blocks.rotate_left(1);
            messages_for_next_next_blocks.rotate_left(2);
        }
    }

    // Combine the three contributions. Each comes from a LUT output (or a trivial 0), so
    // these plain additions are expected to stay within the message space.
    for (m0, (m1, m2)) in messages.iter_mut().zip(
        messages_for_next_blocks
            .iter()
            .zip(messages_for_next_next_blocks.iter()),
    ) {
        self.key.unchecked_add_assign(m0, m1);
        self.key.unchecked_add_assign(m0, m2);
    }

    let radix = T::from_blocks(messages);

    // Second stage: remaining bits of `amount` move whole blocks
    let num_bit_that_tells_shift_within_blocks = message_bits_per_block.ilog2();
    let num_bits_already_done = num_bit_that_tells_shift_within_blocks + 1;
    self.block_barrel_shifter_impl(
        &radix,
        &mut shift_bit_extractor,
        // We already did the first block rotation so we start at 1
        // And do + 1 as the range is exclusive
        1..max_num_bits_that_tell_shift - num_bits_already_done as usize + 1,
        // blocks are in little endian order which is the opposite
        // of how bits are textually represented
        operation.invert_direction(),
    )
}
730
731    /// This implements a "barrel shifter".
732    ///
733    /// This construct is what is used in hardware to
734    /// implement left/right shift/rotate
735    ///
736    /// This requires:
737    /// - ct to have clean carries
738    /// - shift to have clean carries
739    /// - the number of bits in the block to be >= 3
740    ///
741    /// Similarly to rust `wrapping_shl/shr` functions
742    /// it removes any high-order bits of `shift`
743    /// that would cause the shift to exceed the bitwidth of the type.
744    ///
745    /// **However**, when the total number of bits represented by the
746    /// radix ciphertext is not a power of two (eg a ciphertext with 12 bits)
747    /// then, it removes bit that are higher than the closest higher power of two.
748    /// So for a 12 bits radix ciphertext, its closest higher power of two is 16,
749    /// thus, any bit that are higher than log2(16) will be removed
750    ///
751    /// `ct` will be assigned the result, and it will be in a fresh state
752    pub(super) fn bit_barrel_shifter<T>(
753        &self,
754        ct: &mut T,
755        shift: &RadixCiphertext,
756        operation: BarrelShifterOperation,
757    ) where
758        T: IntegerRadixCiphertext,
759    {
760        // What matters is the len of the ct to shift, not the `shift` len
761        let num_blocks = ct.blocks().len();
762        let message_bits_per_block = self.key.message_modulus.0.ilog2() as u64;
763        let carry_bits_per_block = self.key.carry_modulus.0.ilog2() as u64;
764        let total_nb_bits = message_bits_per_block * num_blocks as u64;
765
766        assert!(
767            (message_bits_per_block + carry_bits_per_block) >= 3,
768            "Blocks must have at least 3 bits"
769        );
770
771        let (bits, shift_bits) = rayon::join(
772            || {
773                let mut bit_extractor =
774                    BitExtractor::new(ct.blocks(), self, message_bits_per_block as usize);
775                bit_extractor.extract_all_bits()
776            },
777            || {
778                let mut max_num_bits_that_tell_shift = total_nb_bits.ilog2() as u64;
779                // This effectively means, that if the block parameters
780                // give a total_nb_bits that is not a power of two,
781                // then the behaviour of shifting won't be the same
782                // if shift >= total_nb_bits compared to when total_nb_bits
783                // is a power of two, as will 'capture' more bits in `shift_bits`
784                if !total_nb_bits.is_power_of_two() {
785                    max_num_bits_that_tell_shift += 1;
786                }
787
788                // Extracts bits and put them in the bit index 2 (=> bit number 3)
789                // so that it is already aligned to the correct position of the cmux input
790                // and we reduce noise growth
791                let mut bit_extractor = BitExtractor::with_final_offset(
792                    &shift.blocks,
793                    self,
794                    message_bits_per_block as usize,
795                    2,
796                );
797                bit_extractor.extract_n_bits(max_num_bits_that_tell_shift as usize)
798            },
799        );
800
801        let mux_lut = self.key.generate_lookup_table(|x| {
802            // x is expected to be x = 0bcba
803            // where
804            // - c is the control bit
805            // - b the bit value returned if c is 1
806            // - a the bit value returned if c is 0
807            // (any bit above c is ignored)
808            let x = x & 7;
809            let control_bit = x >> 2;
810            let previous_bit = (x & 2) >> 1;
811            let current_bit = x & 1;
812
813            if control_bit == 1 {
814                previous_bit
815            } else {
816                current_bit
817            }
818        });
819
820        let offset = match operation {
821            BarrelShifterOperation::LeftShift | BarrelShifterOperation::LeftRotate => 0,
822            BarrelShifterOperation::RightShift | BarrelShifterOperation::RightRotate => {
823                total_nb_bits
824            }
825        };
826
827        let is_right_shift = matches!(operation, BarrelShifterOperation::RightShift);
828        let padding_bit = if T::IS_SIGNED && is_right_shift {
829            // Do an "arithmetic shift" by padding with the sign bit
830            bits.last().unwrap().clone()
831        } else {
832            self.key.create_trivial(0)
833        };
834
835        let mut input_bits_a = bits;
836        let mut input_bits_b = input_bits_a.clone();
837        // Buffer used to hold inputs for a bitwise cmux gate, simulated using a PBS
838        let mut mux_inputs = input_bits_a.clone();
839
840        for (d, shift_bit) in shift_bits.iter().enumerate() {
841            for i in 0..total_nb_bits as usize {
842                input_bits_b[i].clone_from(&input_bits_a[i]);
843                self.key.create_trivial_assign(&mut mux_inputs[i], 0);
844            }
845
846            match operation {
847                BarrelShifterOperation::LeftShift => {
848                    input_bits_b.rotate_right(1 << d);
849                    for bit_that_wrapped in &mut input_bits_b[..1 << d] {
850                        bit_that_wrapped.clone_from(&padding_bit);
851                    }
852                }
853                BarrelShifterOperation::RightShift => {
854                    input_bits_b.rotate_left(1 << d);
855                    let bits_that_wrapped = &mut input_bits_b[total_nb_bits as usize - (1 << d)..];
856                    for bit_that_wrapped in bits_that_wrapped {
857                        bit_that_wrapped.clone_from(&padding_bit);
858                    }
859                }
860                BarrelShifterOperation::LeftRotate => {
861                    input_bits_b.rotate_right(1 << d);
862                }
863                BarrelShifterOperation::RightRotate => {
864                    input_bits_b.rotate_left(1 << d);
865                }
866            }
867
868            input_bits_a
869                .par_iter_mut()
870                .zip_eq(mux_inputs.par_iter_mut())
871                .enumerate()
872                .for_each(|(i, (a, mux_gate_input))| {
873                    let b = &input_bits_b[((i as u64 + offset) % total_nb_bits) as usize];
874
875                    // pack bits into one block so that we have
876                    // control_bit|b|a
877
878                    self.key.unchecked_add_assign(mux_gate_input, b);
879                    self.key.unchecked_scalar_mul_assign(mux_gate_input, 2);
880                    self.key.unchecked_add_assign(mux_gate_input, &*a);
881                    // The shift bit is already properly aligned/positioned
882                    self.key.unchecked_add_assign(mux_gate_input, shift_bit);
883
884                    // we have
885                    //
886                    // control_bit|b|a
887                    self.key.apply_lookup_table_assign(mux_gate_input, &mux_lut);
888                    (*a).clone_from(mux_gate_input);
889                });
890        }
891
892        // rename for clarity
893        let mut output_bits = input_bits_a;
894        assert_eq!(
895            output_bits.len(),
896            message_bits_per_block as usize * num_blocks
897        );
898        // We have to reconstruct blocks from the individual bits
899        output_bits
900            .as_mut_slice()
901            .par_chunks_exact_mut(message_bits_per_block as usize)
902            .zip_eq(ct.blocks_mut().par_iter_mut())
903            .for_each(|(grouped_bits, block)| {
904                let (head, last) = grouped_bits.split_at_mut(message_bits_per_block as usize - 1);
905                for bit in head.iter().rev() {
906                    self.key.unchecked_scalar_mul_assign(&mut last[0], 2);
907                    self.key.unchecked_add_assign(&mut last[0], bit);
908                }
909                // To give back a clean ciphertext
910                self.key.message_extract_assign(&mut last[0]);
911                std::mem::swap(block, &mut last[0]);
912            });
913    }
914}