1use super::super::helpers::{create_sub_mut_slice_with_bound, create_sub_slice_with_bound};
4use super::super::traits::{ArithmeticArrayBackend, BitwiseArrayBackend, ClearBitwiseArrayBackend};
5use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric};
6use crate::high_level_api::array::{
7 ArrayBackend, FheArrayBase, FheBackendArray, FheBackendArraySlice, FheBackendArraySliceMut,
8};
9
10use crate::array::traits::{
11 BackendDataContainer, BackendDataContainerMut, ClearArithmeticArrayBackend, TensorSlice,
12};
13use crate::high_level_api::global_state;
14use crate::high_level_api::integers::{FheIntId, FheUintId};
15use crate::integer::block_decomposition::{
16 DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
17};
18use crate::integer::server_key::radix_parallel::scalar_div_mod::SignedReciprocable;
19use crate::integer::server_key::{Reciprocable, ScalarMultiplier};
20use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, SignedRadixCiphertext};
21use crate::prelude::{FheDecrypt, FheTryEncrypt};
22use crate::{ClientKey, Error};
23use rayon::prelude::*;
24use std::marker::PhantomData;
25use std::ops::RangeBounds;
26
27pub struct CpuIntegerArrayBackend<T>(PhantomData<T>);
28
29pub type CpuUintArrayBackend = CpuIntegerArrayBackend<RadixCiphertext>;
30pub type CpuIntArrayBackend = CpuIntegerArrayBackend<SignedRadixCiphertext>;
31
32pub type CpuFheUintArray<Id> = FheBackendArray<CpuUintArrayBackend, Id>;
34pub type CpuFheUintSlice<'a, Id> = FheBackendArraySlice<'a, CpuUintArrayBackend, Id>;
35pub type CpuFheUintSliceMut<'a, Id> = FheBackendArraySliceMut<'a, CpuUintArrayBackend, Id>;
36
37pub type CpuFheIntArray<Id> = FheBackendArray<CpuIntArrayBackend, Id>;
39pub type CpuFheIntSlice<'a, Id> = FheBackendArraySlice<'a, CpuIntArrayBackend, Id>;
40pub type CpuFheIntSliceMut<'a, Id> = FheBackendArraySliceMut<'a, CpuIntArrayBackend, Id>;
41
42impl<T> ArrayBackend for CpuIntegerArrayBackend<T>
43where
44 T: IntegerRadixCiphertext,
45{
46 type Slice<'a>
47 = &'a [T]
48 where
49 Self: 'a;
50 type SliceMut<'a>
51 = &'a mut [T]
52 where
53 Self: 'a;
54 type Owned = Vec<T>;
55}
56
57#[inline]
58#[track_caller]
59fn par_map_sks_op_on_pair_of_elements<'a, T, F>(
60 lhs: TensorSlice<'a, &'a [T]>,
61 rhs: TensorSlice<'a, &'a [T]>,
62 op: F,
63) -> Vec<T>
64where
65 T: IntegerRadixCiphertext,
66 F: Send + Sync + Fn(&crate::integer::ServerKey, &T, &T) -> T,
67{
68 global_state::with_cpu_internal_keys(|cpu_key| {
69 lhs.par_iter()
70 .zip(rhs.par_iter())
71 .map(|(lhs, rhs)| op(cpu_key.pbs_key(), lhs, rhs))
72 .collect::<Vec<_>>()
73 })
74}
75
76impl<T> ArithmeticArrayBackend for CpuIntegerArrayBackend<T>
77where
78 T: IntegerRadixCiphertext,
79{
80 fn add_slices<'a>(
81 lhs: TensorSlice<'_, Self::Slice<'a>>,
82 rhs: TensorSlice<'_, Self::Slice<'a>>,
83 ) -> Self::Owned {
84 par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::add_parallelized)
85 }
86
87 fn sub_slices<'a>(
88 lhs: TensorSlice<'_, Self::Slice<'a>>,
89 rhs: TensorSlice<'_, Self::Slice<'a>>,
90 ) -> Self::Owned {
91 par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::sub_parallelized)
92 }
93
94 fn mul_slices<'a>(
95 lhs: TensorSlice<'_, Self::Slice<'a>>,
96 rhs: TensorSlice<'_, Self::Slice<'a>>,
97 ) -> Self::Owned {
98 par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::mul_parallelized)
99 }
100
101 fn div_slices<'a>(
102 lhs: TensorSlice<'_, Self::Slice<'a>>,
103 rhs: TensorSlice<'_, Self::Slice<'a>>,
104 ) -> Self::Owned {
105 par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::div_parallelized)
106 }
107
108 fn rem_slices<'a>(
109 lhs: TensorSlice<'_, Self::Slice<'a>>,
110 rhs: TensorSlice<'_, Self::Slice<'a>>,
111 ) -> Self::Owned {
112 par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::rem_parallelized)
113 }
114}
115
116#[inline]
117#[track_caller]
118fn par_map_sks_scalar_op_on_pair_of_elements<'a, T, Clear, F>(
119 lhs: TensorSlice<'a, &'a [T]>,
120 rhs: TensorSlice<'a, &'a [Clear]>,
121 op: F,
122) -> Vec<T>
123where
124 T: IntegerRadixCiphertext,
125 Clear: Copy + Send + Sync,
126 F: Send + Sync + Fn(&crate::integer::ServerKey, &T, Clear) -> T,
127{
128 global_state::with_cpu_internal_keys(|cpu_key| {
129 lhs.par_iter()
130 .zip(rhs.par_iter())
131 .map(|(lhs, rhs)| op(cpu_key.pbs_key(), lhs, *rhs))
132 .collect::<Vec<_>>()
133 })
134}
135
136impl<Clear> ClearArithmeticArrayBackend<Clear> for CpuIntegerArrayBackend<RadixCiphertext>
137where
138 Clear: DecomposableInto<u8>
139 + std::ops::Not<Output = Clear>
140 + std::ops::Add<Clear, Output = Clear>
141 + ScalarMultiplier
142 + Reciprocable,
143{
144 fn add_slices(
145 lhs: TensorSlice<'_, Self::Slice<'_>>,
146 rhs: TensorSlice<'_, &'_ [Clear]>,
147 ) -> Self::Owned {
148 par_map_sks_scalar_op_on_pair_of_elements(
149 lhs,
150 rhs,
151 crate::integer::ServerKey::scalar_add_parallelized,
152 )
153 }
154
155 fn sub_slices(
156 lhs: TensorSlice<'_, Self::Slice<'_>>,
157 rhs: TensorSlice<'_, &'_ [Clear]>,
158 ) -> Self::Owned {
159 par_map_sks_scalar_op_on_pair_of_elements(
160 lhs,
161 rhs,
162 crate::integer::ServerKey::scalar_sub_parallelized,
163 )
164 }
165
166 fn mul_slices(
167 lhs: TensorSlice<'_, Self::Slice<'_>>,
168 rhs: TensorSlice<'_, &'_ [Clear]>,
169 ) -> Self::Owned {
170 par_map_sks_scalar_op_on_pair_of_elements(
171 lhs,
172 rhs,
173 crate::integer::ServerKey::scalar_mul_parallelized,
174 )
175 }
176
177 fn div_slices(
178 lhs: TensorSlice<'_, Self::Slice<'_>>,
179 rhs: TensorSlice<'_, &'_ [Clear]>,
180 ) -> Self::Owned {
181 par_map_sks_scalar_op_on_pair_of_elements(
182 lhs,
183 rhs,
184 crate::integer::ServerKey::scalar_div_parallelized,
185 )
186 }
187
188 fn rem_slices(
189 lhs: TensorSlice<'_, Self::Slice<'_>>,
190 rhs: TensorSlice<'_, &'_ [Clear]>,
191 ) -> Self::Owned {
192 par_map_sks_scalar_op_on_pair_of_elements(
193 lhs,
194 rhs,
195 crate::integer::ServerKey::scalar_rem_parallelized,
196 )
197 }
198}
199
200impl<Clear> ClearArithmeticArrayBackend<Clear> for CpuIntegerArrayBackend<SignedRadixCiphertext>
201where
202 Clear: DecomposableInto<u8>
203 + std::ops::Not<Output = Clear>
204 + std::ops::Add<Clear, Output = Clear>
205 + ScalarMultiplier
206 + SignedReciprocable,
207{
208 fn add_slices(
209 lhs: TensorSlice<'_, Self::Slice<'_>>,
210 rhs: TensorSlice<'_, &'_ [Clear]>,
211 ) -> Self::Owned {
212 par_map_sks_scalar_op_on_pair_of_elements(
213 lhs,
214 rhs,
215 crate::integer::ServerKey::scalar_add_parallelized,
216 )
217 }
218
219 fn sub_slices(
220 lhs: TensorSlice<'_, Self::Slice<'_>>,
221 rhs: TensorSlice<'_, &'_ [Clear]>,
222 ) -> Self::Owned {
223 par_map_sks_scalar_op_on_pair_of_elements(
224 lhs,
225 rhs,
226 crate::integer::ServerKey::scalar_sub_parallelized,
227 )
228 }
229
230 fn mul_slices(
231 lhs: TensorSlice<'_, Self::Slice<'_>>,
232 rhs: TensorSlice<'_, &'_ [Clear]>,
233 ) -> Self::Owned {
234 par_map_sks_scalar_op_on_pair_of_elements(
235 lhs,
236 rhs,
237 crate::integer::ServerKey::scalar_mul_parallelized,
238 )
239 }
240
241 fn div_slices(
242 lhs: TensorSlice<'_, Self::Slice<'_>>,
243 rhs: TensorSlice<'_, &'_ [Clear]>,
244 ) -> Self::Owned {
245 par_map_sks_scalar_op_on_pair_of_elements(
246 lhs,
247 rhs,
248 crate::integer::ServerKey::signed_scalar_div_parallelized,
249 )
250 }
251
252 fn rem_slices(
253 lhs: TensorSlice<'_, Self::Slice<'_>>,
254 rhs: TensorSlice<'_, &'_ [Clear]>,
255 ) -> Self::Owned {
256 par_map_sks_scalar_op_on_pair_of_elements(
257 lhs,
258 rhs,
259 crate::integer::ServerKey::signed_scalar_rem_parallelized,
260 )
261 }
262}
263
264impl<T> BitwiseArrayBackend for CpuIntegerArrayBackend<T>
265where
266 T: IntegerRadixCiphertext,
267{
268 fn bitand<'a>(
269 lhs: TensorSlice<'_, Self::Slice<'a>>,
270 rhs: TensorSlice<'_, Self::Slice<'a>>,
271 ) -> Self::Owned {
272 par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::bitand_parallelized)
273 }
274
275 fn bitor<'a>(
276 lhs: TensorSlice<'_, Self::Slice<'a>>,
277 rhs: TensorSlice<'_, Self::Slice<'a>>,
278 ) -> Self::Owned {
279 par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::bitor_parallelized)
280 }
281
282 fn bitxor<'a>(
283 lhs: TensorSlice<'_, Self::Slice<'a>>,
284 rhs: TensorSlice<'_, Self::Slice<'a>>,
285 ) -> Self::Owned {
286 par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::ServerKey::bitxor_parallelized)
287 }
288
289 fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
290 global_state::with_cpu_internal_keys(|cpu_key| {
291 lhs.par_iter()
292 .map(|lhs| cpu_key.pbs_key().bitnot(lhs))
293 .collect::<Vec<_>>()
294 })
295 }
296}
297
298impl<Clear, T> ClearBitwiseArrayBackend<Clear> for CpuIntegerArrayBackend<T>
299where
300 T: IntegerRadixCiphertext,
301 Clear: DecomposableInto<u8>,
302{
303 fn bitand_slice(
304 lhs: TensorSlice<'_, Self::Slice<'_>>,
305 rhs: TensorSlice<'_, &'_ [Clear]>,
306 ) -> Self::Owned {
307 par_map_sks_scalar_op_on_pair_of_elements(
308 lhs,
309 rhs,
310 crate::integer::ServerKey::scalar_bitand_parallelized,
311 )
312 }
313
314 fn bitor_slice(
315 lhs: TensorSlice<'_, Self::Slice<'_>>,
316 rhs: TensorSlice<'_, &'_ [Clear]>,
317 ) -> Self::Owned {
318 par_map_sks_scalar_op_on_pair_of_elements(
319 lhs,
320 rhs,
321 crate::integer::ServerKey::scalar_bitor_parallelized,
322 )
323 }
324
325 fn bitxor_slice(
326 lhs: TensorSlice<'_, Self::Slice<'_>>,
327 rhs: TensorSlice<'_, &'_ [Clear]>,
328 ) -> Self::Owned {
329 par_map_sks_scalar_op_on_pair_of_elements(
330 lhs,
331 rhs,
332 crate::integer::ServerKey::scalar_bitxor_parallelized,
333 )
334 }
335}
336
337impl<T> BackendDataContainer for Vec<T>
338where
339 T: IntegerRadixCiphertext,
340{
341 type Backend = CpuIntegerArrayBackend<T>;
342
343 fn len(&self) -> usize {
344 self.len()
345 }
346
347 fn as_sub_slice(
348 &self,
349 range: impl RangeBounds<usize>,
350 ) -> <Self::Backend as ArrayBackend>::Slice<'_> {
351 create_sub_slice_with_bound(Self::as_slice(self), range)
352 }
353
354 fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
355 self
356 }
357}
358
359impl<T> BackendDataContainerMut for Vec<T>
360where
361 T: IntegerRadixCiphertext,
362{
363 fn as_sub_slice_mut(
364 &mut self,
365 range: impl RangeBounds<usize>,
366 ) -> <Self::Backend as ArrayBackend>::SliceMut<'_> {
367 create_sub_mut_slice_with_bound(self.as_mut_slice(), range)
368 }
369}
370
371impl<T> BackendDataContainer for &[T]
372where
373 T: IntegerRadixCiphertext,
374{
375 type Backend = CpuIntegerArrayBackend<T>;
376
377 fn len(&self) -> usize {
378 <[T]>::len(self)
379 }
380
381 fn as_sub_slice(
382 &self,
383 range: impl RangeBounds<usize>,
384 ) -> <Self::Backend as ArrayBackend>::Slice<'_> {
385 create_sub_slice_with_bound(*self, range)
386 }
387
388 fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
389 self.to_vec()
390 }
391}
392
393impl<T> BackendDataContainer for &mut [T]
394where
395 T: IntegerRadixCiphertext,
396{
397 type Backend = CpuIntegerArrayBackend<T>;
398
399 fn len(&self) -> usize {
400 <[T]>::len(self)
401 }
402
403 fn as_sub_slice(
404 &self,
405 range: impl RangeBounds<usize>,
406 ) -> <Self::Backend as ArrayBackend>::Slice<'_> {
407 create_sub_slice_with_bound(*self, range)
408 }
409
410 fn into_owned(self) -> <Self::Backend as ArrayBackend>::Owned {
411 self.to_vec()
412 }
413}
414
415impl<T> BackendDataContainerMut for &mut [T]
416where
417 T: IntegerRadixCiphertext,
418{
419 fn as_sub_slice_mut(
420 &mut self,
421 range: impl RangeBounds<usize>,
422 ) -> <Self::Backend as ArrayBackend>::SliceMut<'_> {
423 create_sub_mut_slice_with_bound(*self, range)
424 }
425}
426
427impl<'a, Clear, Id> FheTryEncrypt<&'a [Clear], ClientKey> for FheArrayBase<Vec<RadixCiphertext>, Id>
428where
429 Id: FheUintId,
430 Clear: DecomposableInto<u64> + UnsignedNumeric,
431{
432 type Error = Error;
433
434 fn try_encrypt(clears: &'a [Clear], key: &ClientKey) -> Result<Self, Self::Error> {
435 let num_blocks = Id::num_blocks(key.message_modulus());
436 Ok(Self::new(
437 clears
438 .iter()
439 .copied()
440 .map(|clear| key.key.key.encrypt_radix(clear, num_blocks))
441 .collect::<Vec<_>>(),
442 vec![clears.len()],
443 ))
444 }
445}
446
447impl<'a, Clear, Id> FheTryEncrypt<(&'a [Clear], Vec<usize>), ClientKey>
448 for FheArrayBase<Vec<RadixCiphertext>, Id>
449where
450 Id: FheUintId,
451 Clear: DecomposableInto<u64> + UnsignedNumeric,
452{
453 type Error = Error;
454
455 fn try_encrypt(
456 (clears, shape): (&'a [Clear], Vec<usize>),
457 key: &ClientKey,
458 ) -> Result<Self, Self::Error> {
459 if clears.len() != shape.iter().copied().product::<usize>() {
460 return Err(crate::Error::new(
461 "Shape does not matches the number of elements given".to_string(),
462 ));
463 }
464 let num_blocks = Id::num_blocks(key.message_modulus());
465 let elems = clears
466 .iter()
467 .copied()
468 .map(|clear| key.key.key.encrypt_radix(clear, num_blocks))
469 .collect::<Vec<_>>();
470 let data = Self::new(elems, shape);
471 Ok(data)
472 }
473}
474
475impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheUintArray<Id>
476where
477 Id: FheUintId,
478 Clear: RecomposableFrom<u64> + UnsignedNumeric,
479{
480 fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
481 self.as_slice().decrypt(key)
482 }
483}
484
485impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheUintSliceMut<'_, Id>
486where
487 Id: FheUintId,
488 Clear: RecomposableFrom<u64> + UnsignedNumeric,
489{
490 fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
491 self.as_slice().decrypt(key)
492 }
493}
494
495impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheUintSlice<'_, Id>
496where
497 Id: FheUintId,
498 Clear: RecomposableFrom<u64> + UnsignedNumeric,
499{
500 fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
501 self.as_tensor_slice()
502 .iter()
503 .map(|ct| key.key.key.decrypt_radix(ct))
504 .collect()
505 }
506}
507
508impl<'a, Clear, Id> FheTryEncrypt<&'a [Clear], ClientKey> for CpuFheIntArray<Id>
509where
510 Id: FheIntId,
511 Clear: DecomposableInto<u64> + SignedNumeric,
512{
513 type Error = Error;
514
515 fn try_encrypt(clears: &'a [Clear], key: &ClientKey) -> Result<Self, Self::Error> {
516 let num_blocks = Id::num_blocks(key.message_modulus());
517 Ok(Self::new(
518 clears
519 .iter()
520 .copied()
521 .map(|clear| key.key.key.encrypt_signed_radix(clear, num_blocks))
522 .collect::<Vec<_>>(),
523 vec![clears.len()],
524 ))
525 }
526}
527
528impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheIntArray<Id>
529where
530 Id: FheIntId,
531 Clear: RecomposableSignedInteger,
532{
533 fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
534 self.as_slice().decrypt(key)
535 }
536}
537
538impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheIntSliceMut<'_, Id>
539where
540 Id: FheIntId,
541 Clear: RecomposableSignedInteger,
542{
543 fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
544 self.as_slice().decrypt(key)
545 }
546}
547
548impl<Clear, Id> FheDecrypt<Vec<Clear>> for CpuFheIntSlice<'_, Id>
549where
550 Id: FheIntId,
551 Clear: RecomposableSignedInteger,
552{
553 fn decrypt(&self, key: &ClientKey) -> Vec<Clear> {
554 self.elems
555 .iter()
556 .map(|ct| key.key.key.decrypt_signed_radix(ct))
557 .collect()
558 }
559}