tfhe/high_level_api/array/cpu/
integers.rs

//! This module contains the implementations of the FheUint array and FheInt array backends,
//! where the values are always stored, and the computations always performed, on the CPU.
3use super::super::helpers::{create_sub_mut_slice_with_bound, create_sub_slice_with_bound};
4use super::super::traits::{ArithmeticArrayBackend, BitwiseArrayBackend, ClearBitwiseArrayBackend};
5use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric};
6use crate::high_level_api::array::{
7    ArrayBackend, FheArrayBase, FheBackendArray, FheBackendArraySlice, FheBackendArraySliceMut,
8};
9
10use crate::array::traits::{
11    BackendDataContainer, BackendDataContainerMut, ClearArithmeticArrayBackend, TensorSlice,
12};
13use crate::high_level_api::global_state;
14use crate::high_level_api::integers::{FheIntId, FheUintId};
15use crate::integer::block_decomposition::{
16    DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
17};
18use crate::integer::server_key::radix_parallel::scalar_div_mod::SignedReciprocable;
19use crate::integer::server_key::{Reciprocable, ScalarMultiplier};
20use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, SignedRadixCiphertext};
21use crate::prelude::{FheDecrypt, FheTryEncrypt};
22use crate::{ClientKey, Error};
23use rayon::prelude::*;
24use std::marker::PhantomData;
25use std::ops::RangeBounds;
26
/// Array backend whose elements are integer radix ciphertexts of type `T`.
///
/// This is a type-level marker only (it wraps a `PhantomData<T>`). Its element-wise
/// operations run through the CPU server key obtained from
/// `global_state::with_cpu_internal_keys`. Its view and owned container types are chosen
/// by its `ArrayBackend` impl.
pub struct CpuIntegerArrayBackend<T>(PhantomData<T>);

/// CPU backend for arrays of encrypted unsigned integers (`RadixCiphertext` elements).
pub type CpuUintArrayBackend = CpuIntegerArrayBackend<RadixCiphertext>;
/// CPU backend for arrays of encrypted signed integers (`SignedRadixCiphertext` elements).
pub type CpuIntArrayBackend = CpuIntegerArrayBackend<SignedRadixCiphertext>;

// Base aliases for arrays of unsigned integers on the CPU-only backend.
// `Id` is the integer type id; it determines how many radix blocks each element uses.
/// Owned array of encrypted unsigned integers.
pub type CpuFheUintArray<Id> = FheBackendArray<CpuUintArrayBackend, Id>;
/// Immutable borrowed view into an array of encrypted unsigned integers.
pub type CpuFheUintSlice<'a, Id> = FheBackendArraySlice<'a, CpuUintArrayBackend, Id>;
/// Mutable borrowed view into an array of encrypted unsigned integers.
pub type CpuFheUintSliceMut<'a, Id> = FheBackendArraySliceMut<'a, CpuUintArrayBackend, Id>;

// Base aliases for arrays of signed integers on the CPU-only backend.
/// Owned array of encrypted signed integers.
pub type CpuFheIntArray<Id> = FheBackendArray<CpuIntArrayBackend, Id>;
/// Immutable borrowed view into an array of encrypted signed integers.
pub type CpuFheIntSlice<'a, Id> = FheBackendArraySlice<'a, CpuIntArrayBackend, Id>;
/// Mutable borrowed view into an array of encrypted signed integers.
pub type CpuFheIntSliceMut<'a, Id> = FheBackendArraySliceMut<'a, CpuIntArrayBackend, Id>;
41
// The CPU backend uses plain Rust slices as views and a `Vec` as owned storage.
impl<T> ArrayBackend for CpuIntegerArrayBackend<T>
where
    T: IntegerRadixCiphertext,
{
    /// Immutable view: a borrowed slice of ciphertexts.
    type Slice<'a>
        = &'a [T]
    where
        Self: 'a;
    /// Mutable view: a mutably borrowed slice of ciphertexts.
    type SliceMut<'a>
        = &'a mut [T]
    where
        Self: 'a;
    /// Owned storage: a vector of ciphertexts.
    type Owned = Vec<T>;
}
56
57#[inline]
58#[track_caller]
59fn par_map_sks_op_on_pair_of_elements<'a, T, F>(
60    lhs: TensorSlice<'a, &'a [T]>,
61    rhs: TensorSlice<'a, &'a [T]>,
62    op: F,
63) -> Vec<T>
64where
65    T: IntegerRadixCiphertext,
66    F: Send + Sync + Fn(&crate::integer::ServerKey, &T, &T) -> T,
67{
68    global_state::with_cpu_internal_keys(|cpu_key| {
69        lhs.par_iter()
70            .zip(rhs.par_iter())
71            .map(|(lhs, rhs)| op(cpu_key.pbs_key(), lhs, rhs))
72            .collect::<Vec<_>>()
73    })
74}
75
76impl<T> ArithmeticArrayBackend for CpuIntegerArrayBackend<T>
77where
78    T: IntegerRadixCiphertext,
79{
80    fn add_slices<'a>(
81        lhs: TensorSlice<'_, Self::Slice<'a>>,
82        rhs: TensorSlice<'_, Self::Slice<'a>>,
83    ) -> Self::Owned {
84        par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::add_parallelized)
85    }
86
87    fn sub_slices<'a>(
88        lhs: TensorSlice<'_, Self::Slice<'a>>,
89        rhs: TensorSlice<'_, Self::Slice<'a>>,
90    ) -> Self::Owned {
91        par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::sub_parallelized)
92    }
93
94    fn mul_slices<'a>(
95        lhs: TensorSlice<'_, Self::Slice<'a>>,
96        rhs: TensorSlice<'_, Self::Slice<'a>>,
97    ) -> Self::Owned {
98        par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::mul_parallelized)
99    }
100
101    fn div_slices<'a>(
102        lhs: TensorSlice<'_, Self::Slice<'a>>,
103        rhs: TensorSlice<'_, Self::Slice<'a>>,
104    ) -> Self::Owned {
105        par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::div_parallelized)
106    }
107
108    fn rem_slices<'a>(
109        lhs: TensorSlice<'_, Self::Slice<'a>>,
110        rhs: TensorSlice<'_, Self::Slice<'a>>,
111    ) -> Self::Owned {
112        par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::rem_parallelized)
113    }
114}
115
116#[inline]
117#[track_caller]
118fn par_map_sks_scalar_op_on_pair_of_elements<'a, T, Clear, F>(
119    lhs: TensorSlice<'a, &'a [T]>,
120    rhs: TensorSlice<'a, &'a [Clear]>,
121    op: F,
122) -> Vec<T>
123where
124    T: IntegerRadixCiphertext,
125    Clear: Copy + Send + Sync,
126    F: Send + Sync + Fn(&crate::integer::ServerKey, &T, Clear) -> T,
127{
128    global_state::with_cpu_internal_keys(|cpu_key| {
129        lhs.par_iter()
130            .zip(rhs.par_iter())
131            .map(|(lhs, rhs)| op(cpu_key.pbs_key(), lhs, *rhs))
132            .collect::<Vec<_>>()
133    })
134}
135
136impl<Clear> ClearArithmeticArrayBackend<Clear> for CpuIntegerArrayBackend<RadixCiphertext>
137where
138    Clear: DecomposableInto<u8>
139        + std::ops::Not<Output = Clear>
140        + std::ops::Add<Clear, Output = Clear>
141        + ScalarMultiplier
142        + Reciprocable,
143{
144    fn add_slices(
145        lhs: TensorSlice<'_, Self::Slice<'_>>,
146        rhs: TensorSlice<'_, &'_ [Clear]>,
147    ) -> Self::Owned {
148        par_map_sks_scalar_op_on_pair_of_elements(
149            lhs,
150            rhs,
151            crate::integer::ServerKey::scalar_add_parallelized,
152        )
153    }
154
155    fn sub_slices(
156        lhs: TensorSlice<'_, Self::Slice<'_>>,
157        rhs: TensorSlice<'_, &'_ [Clear]>,
158    ) -> Self::Owned {
159        par_map_sks_scalar_op_on_pair_of_elements(
160            lhs,
161            rhs,
162            crate::integer::ServerKey::scalar_sub_parallelized,
163        )
164    }
165
166    fn mul_slices(
167        lhs: TensorSlice<'_, Self::Slice<'_>>,
168        rhs: TensorSlice<'_, &'_ [Clear]>,
169    ) -> Self::Owned {
170        par_map_sks_scalar_op_on_pair_of_elements(
171            lhs,
172            rhs,
173            crate::integer::ServerKey::scalar_mul_parallelized,
174        )
175    }
176
177    fn div_slices(
178        lhs: TensorSlice<'_, Self::Slice<'_>>,
179        rhs: TensorSlice<'_, &'_ [Clear]>,
180    ) -> Self::Owned {
181        par_map_sks_scalar_op_on_pair_of_elements(
182            lhs,
183            rhs,
184            crate::integer::ServerKey::scalar_div_parallelized,
185        )
186    }
187
188    fn rem_slices(
189        lhs: TensorSlice<'_, Self::Slice<'_>>,
190        rhs: TensorSlice<'_, &'_ [Clear]>,
191    ) -> Self::Owned {
192        par_map_sks_scalar_op_on_pair_of_elements(
193            lhs,
194            rhs,
195            crate::integer::ServerKey::scalar_rem_parallelized,
196        )
197    }
198}
199
200impl<Clear> ClearArithmeticArrayBackend<Clear> for CpuIntegerArrayBackend<SignedRadixCiphertext>
201where
202    Clear: DecomposableInto<u8>
203        + std::ops::Not<Output = Clear>
204        + std::ops::Add<Clear, Output = Clear>
205        + ScalarMultiplier
206        + SignedReciprocable,
207{
208    fn add_slices(
209        lhs: TensorSlice<'_, Self::Slice<'_>>,
210        rhs: TensorSlice<'_, &'_ [Clear]>,
211    ) -> Self::Owned {
212        par_map_sks_scalar_op_on_pair_of_elements(
213            lhs,
214            rhs,
215            crate::integer::ServerKey::scalar_add_parallelized,
216        )
217    }
218
219    fn sub_slices(
220        lhs: TensorSlice<'_, Self::Slice<'_>>,
221        rhs: TensorSlice<'_, &'_ [Clear]>,
222    ) -> Self::Owned {
223        par_map_sks_scalar_op_on_pair_of_elements(
224            lhs,
225            rhs,
226            crate::integer::ServerKey::scalar_sub_parallelized,
227        )
228    }
229
230    fn mul_slices(
231        lhs: TensorSlice<'_, Self::Slice<'_>>,
232        rhs: TensorSlice<'_, &'_ [Clear]>,
233    ) -> Self::Owned {
234        par_map_sks_scalar_op_on_pair_of_elements(
235            lhs,
236            rhs,
237            crate::integer::ServerKey::scalar_mul_parallelized,
238        )
239    }
240
241    fn div_slices(
242        lhs: TensorSlice<'_, Self::Slice<'_>>,
243        rhs: TensorSlice<'_, &'_ [Clear]>,
244    ) -> Self::Owned {
245        par_map_sks_scalar_op_on_pair_of_elements(
246            lhs,
247            rhs,
248            crate::integer::ServerKey::signed_scalar_div_parallelized,
249        )
250    }
251
252    fn rem_slices(
253        lhs: TensorSlice<'_, Self::Slice<'_>>,
254        rhs: TensorSlice<'_, &'_ [Clear]>,
255    ) -> Self::Owned {
256        par_map_sks_scalar_op_on_pair_of_elements(
257            lhs,
258            rhs,
259            crate::integer::ServerKey::signed_scalar_rem_parallelized,
260        )
261    }
262}
263
264impl<T> BitwiseArrayBackend for CpuIntegerArrayBackend<T>
265where
266    T: IntegerRadixCiphertext,
267{
268    fn bitand<'a>(
269        lhs: TensorSlice<'_, Self::Slice<'a>>,
270        rhs: TensorSlice<'_, Self::Slice<'a>>,
271    ) -> Self::Owned {
272        par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::bitand_parallelized)
273    }
274
275    fn bitor<'a>(
276        lhs: TensorSlice<'_, Self::Slice<'a>>,
277        rhs: TensorSlice<'_, Self::Slice<'a>>,
278    ) -> Self::Owned {
279        par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::bitor_parallelized)
280    }
281
282    fn bitxor<'a>(
283        lhs: TensorSlice<'_, Self::Slice<'a>>,
284        rhs: TensorSlice<'_, Self::Slice<'a>>,
285    ) -> Self::Owned {
286        par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::bitxor_parallelized)
287    }
288
289    fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
290        global_state::with_cpu_internal_keys(|cpu_key| {
291            lhs.par_iter()
292                .map(|lhs| cpu_key.pbs_key().bitnot(lhs))
293                .collect::<Vec<_>>()
294        })
295    }
296}
297
298impl<Clear, T> ClearBitwiseArrayBackend<Clear> for CpuIntegerArrayBackend<T>
299where
300    T: IntegerRadixCiphertext,
301    Clear: DecomposableInto<u8>,
302{
303    fn bitand_slice(
304        lhs: TensorSlice<'_, Self::Slice<'_>>,
305        rhs: TensorSlice<'_, &'_ [Clear]>,
306    ) -> Self::Owned {
307        par_map_sks_scalar_op_on_pair_of_elements(
308            lhs,
309            rhs,
310            crate::integer::ServerKey::scalar_bitand_parallelized,
311        )
312    }
313
314    fn bitor_slice(
315        lhs: TensorSlice<'_, Self::Slice<'_>>,
316        rhs: TensorSlice<'_, &'_ [Clear]>,
317    ) -> Self::Owned {
318        par_map_sks_scalar_op_on_pair_of_elements(
319            lhs,
320            rhs,
321            crate::integer::ServerKey::scalar_bitor_parallelized,
322        )
323    }
324
325    fn bitxor_slice(
326        lhs: TensorSlice<'_, Self::Slice<'_>>,
327        rhs: TensorSlice<'_, &'_ [Clear]>,
328    ) -> Self::Owned {
329        par_map_sks_scalar_op_on_pair_of_elements(
330            lhs,
331            rhs,
332            crate::integer::ServerKey::scalar_bitxor_parallelized,
333        )
334    }
335}
336
337impl<T> BackendDataContainer for Vec<T>
338where
339    T: IntegerRadixCiphertext,
340{
341    type Backend = CpuIntegerArrayBackend<T>;
342
343    fn len(&self) -> usize {
344        self.len()
345    }
346
347    fn as_sub_slice(
348        &self,
349        range: impl RangeBounds<usize>,
350    ) -> <Self::Backend as ArrayBackend>::Slice<'_> {
351        create_sub_slice_with_bound(Self::as_slice(self), range)
352    }
353
354    fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
355        self
356    }
357}
358
359impl<T> BackendDataContainerMut for Vec<T>
360where
361    T: IntegerRadixCiphertext,
362{
363    fn as_sub_slice_mut(
364        &mut self,
365        range: impl RangeBounds<usize>,
366    ) -> <Self::Backend as ArrayBackend>::SliceMut<'_> {
367        create_sub_mut_slice_with_bound(self.as_mut_slice(), range)
368    }
369}
370
371impl<T> BackendDataContainer for &[T]
372where
373    T: IntegerRadixCiphertext,
374{
375    type Backend = CpuIntegerArrayBackend<T>;
376
377    fn len(&self) -> usize {
378        <[T]>::len(self)
379    }
380
381    fn as_sub_slice(
382        &self,
383        range: impl RangeBounds<usize>,
384    ) -> <Self::Backend as ArrayBackend>::Slice<'_> {
385        create_sub_slice_with_bound(*self, range)
386    }
387
388    fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
389        self.to_vec()
390    }
391}
392
393impl<T> BackendDataContainer for &mut [T]
394where
395    T: IntegerRadixCiphertext,
396{
397    type Backend = CpuIntegerArrayBackend<T>;
398
399    fn len(&self) -> usize {
400        <[T]>::len(self)
401    }
402
403    fn as_sub_slice(
404        &self,
405        range: impl RangeBounds<usize>,
406    ) -> <Self::Backend as ArrayBackend>::Slice<'_> {
407        create_sub_slice_with_bound(*self, range)
408    }
409
410    fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
411        self.to_vec()
412    }
413}
414
415impl<T> BackendDataContainerMut for &mut [T]
416where
417    T: IntegerRadixCiphertext,
418{
419    fn as_sub_slice_mut(
420        &mut self,
421        range: impl RangeBounds<usize>,
422    ) -> <Self::Backend as ArrayBackend>::SliceMut<'_> {
423        create_sub_mut_slice_with_bound(*self, range)
424    }
425}
426
427impl<'a, Clear, Id> FheTryEncrypt<&'a [Clear], ClientKey> for FheArrayBase<Vec<RadixCiphertext>, Id>
428where
429    Id: FheUintId,
430    Clear: DecomposableInto<u64> + UnsignedNumeric,
431{
432    type Error = Error;
433
434    fn try_encrypt(clears: &'a [Clear], key: &ClientKey) -> Result<Self, Self::Error> {
435        let num_blocks = Id::num_blocks(key.message_modulus());
436        Ok(Self::new(
437            clears
438                .iter()
439                .copied()
440                .map(|clear| key.key.key.encrypt_radix(clear, num_blocks))
441                .collect::<Vec<_>>(),
442            vec![clears.len()],
443        ))
444    }
445}
446
447impl<'a, Clear, Id> FheTryEncrypt<(&'a [Clear], Vec<usize>), ClientKey>
448    for FheArrayBase<Vec<RadixCiphertext>, Id>
449where
450    Id: FheUintId,
451    Clear: DecomposableInto<u64> + UnsignedNumeric,
452{
453    type Error = Error;
454
455    fn try_encrypt(
456        (clears, shape): (&'a [Clear], Vec<usize>),
457        key: &ClientKey,
458    ) -> Result<Self, Self::Error> {
459        if clears.len() != shape.iter().copied().product::<usize>() {
460            return Err(crate::Error::new(
461                "Shape does not matches the number of elements given".to_string(),
462            ));
463        }
464        let num_blocks = Id::num_blocks(key.message_modulus());
465        let elems = clears
466            .iter()
467            .copied()
468            .map(|clear| key.key.key.encrypt_radix(clear, num_blocks))
469            .collect::<Vec<_>>();
470        let data = Self::new(elems, shape);
471        Ok(data)
472    }
473}
474
475impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheUintArray<Id>
476where
477    Id: FheUintId,
478    Clear: RecomposableFrom<u64> + UnsignedNumeric,
479{
480    fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
481        self.as_slice().decrypt(key)
482    }
483}
484
485impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheUintSliceMut<'_, Id>
486where
487    Id: FheUintId,
488    Clear: RecomposableFrom<u64> + UnsignedNumeric,
489{
490    fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
491        self.as_slice().decrypt(key)
492    }
493}
494
495impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheUintSlice<'_, Id>
496where
497    Id: FheUintId,
498    Clear: RecomposableFrom<u64> + UnsignedNumeric,
499{
500    fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
501        self.as_tensor_slice()
502            .iter()
503            .map(|ct| key.key.key.decrypt_radix(ct))
504            .collect()
505    }
506}
507
508impl<'a, Clear, Id> FheTryEncrypt<&'a [Clear], ClientKey> for CpuFheIntArray<Id>
509where
510    Id: FheIntId,
511    Clear: DecomposableInto<u64> + SignedNumeric,
512{
513    type Error = Error;
514
515    fn try_encrypt(clears: &'a [Clear], key: &ClientKey) -> Result<Self, Self::Error> {
516        let num_blocks = Id::num_blocks(key.message_modulus());
517        Ok(Self::new(
518            clears
519                .iter()
520                .copied()
521                .map(|clear| key.key.key.encrypt_signed_radix(clear, num_blocks))
522                .collect::<Vec<_>>(),
523            vec![clears.len()],
524        ))
525    }
526}
527
528impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheIntArray<Id>
529where
530    Id: FheIntId,
531    Clear: RecomposableSignedInteger,
532{
533    fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
534        self.as_slice().decrypt(key)
535    }
536}
537
538impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheIntSliceMut<'_, Id>
539where
540    Id: FheIntId,
541    Clear: RecomposableSignedInteger,
542{
543    fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
544        self.as_slice().decrypt(key)
545    }
546}
547
548impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheIntSlice<'_, Id>
549where
550    Id: FheIntId,
551    Clear: RecomposableSignedInteger,
552{
553    fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
554        self.elems
555            .iter()
556            .map(|ct| key.key.key.decrypt_signed_radix(ct))
557            .collect()
558    }
559}