Skip to main content

tfhe/high_level_api/integers/signed/squashed_noise.rs

1use super::base::{FheInt, FheIntId};
2use crate::backward_compatibility::integers::{
3    InnerSquashedNoiseSignedRadixCiphertextVersionedOwned, SquashedNoiseFheIntVersions,
4};
5use crate::high_level_api::details::MaybeCloned;
6use crate::high_level_api::errors::UninitializedNoiseSquashing;
7use crate::high_level_api::global_state::{self, with_internal_keys};
8#[cfg(feature = "gpu")]
9use crate::high_level_api::global_state::{
10    with_cuda_internal_keys, with_thread_local_cuda_streams_for_gpu_indexes,
11};
12use crate::high_level_api::keys::InternalServerKey;
13use crate::high_level_api::traits::{FheDecrypt, SquashNoise};
14use crate::high_level_api::SquashedNoiseCiphertextState;
15use crate::integer::block_decomposition::{RecomposableFrom, SignExtendable};
16#[cfg(feature = "gpu")]
17use crate::integer::gpu::ciphertext::squashed_noise::CudaSquashedNoiseSignedRadixCiphertext;
18use crate::named::Named;
19use crate::prelude::Tagged;
20use crate::{ClientKey, Device, Tag};
21use serde::{Deserializer, Serializer};
22use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
23
24/// Enum that manages the current inner representation of a squashed noise FheInt .
25pub(in crate::high_level_api) enum InnerSquashedNoiseSignedRadixCiphertext {
26    Cpu(crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext),
27    #[cfg(feature = "gpu")]
28    Cuda(CudaSquashedNoiseSignedRadixCiphertext),
29}
30
31impl Clone for InnerSquashedNoiseSignedRadixCiphertext {
32    fn clone(&self) -> Self {
33        match self {
34            Self::Cpu(inner) => Self::Cpu(inner.clone()),
35            #[cfg(feature = "gpu")]
36            Self::Cuda(inner) => {
37                with_thread_local_cuda_streams_for_gpu_indexes(inner.gpu_indexes(), |streams| {
38                    Self::Cuda(inner.duplicate(streams))
39                })
40            }
41        }
42    }
43}
44impl serde::Serialize for InnerSquashedNoiseSignedRadixCiphertext {
45    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
46    where
47        S: Serializer,
48    {
49        self.on_cpu().serialize(serializer)
50    }
51}
52
53impl<'de> serde::Deserialize<'de> for InnerSquashedNoiseSignedRadixCiphertext {
54    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
55    where
56        D: Deserializer<'de>,
57    {
58        let mut deserialized = Self::Cpu(
59            crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext::deserialize(
60                deserializer,
61            )?,
62        );
63        deserialized.move_to_device_of_server_key_if_set();
64        Ok(deserialized)
65    }
66}
67
68// Only CPU data are serialized so we only versionize the CPU type.
69#[derive(serde::Serialize, serde::Deserialize)]
70#[cfg_attr(dylint_lib = "tfhe_lints", allow(serialize_without_versionize))]
71pub(crate) struct InnerSquashedNoiseSignedRadixCiphertextVersionOwned(
72    <crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext as VersionizeOwned>::VersionedOwned,
73);
74
75impl Versionize for InnerSquashedNoiseSignedRadixCiphertext {
76    type Versioned<'vers> = InnerSquashedNoiseSignedRadixCiphertextVersionedOwned;
77
78    fn versionize(&self) -> Self::Versioned<'_> {
79        let data = self.on_cpu();
80        let versioned = data.into_owned().versionize_owned();
81        InnerSquashedNoiseSignedRadixCiphertextVersionedOwned::V0(
82            InnerSquashedNoiseSignedRadixCiphertextVersionOwned(versioned),
83        )
84    }
85}
86impl VersionizeOwned for InnerSquashedNoiseSignedRadixCiphertext {
87    type VersionedOwned = InnerSquashedNoiseSignedRadixCiphertextVersionedOwned;
88
89    fn versionize_owned(self) -> Self::VersionedOwned {
90        let cpu_data = self.on_cpu();
91        InnerSquashedNoiseSignedRadixCiphertextVersionedOwned::V0(
92            InnerSquashedNoiseSignedRadixCiphertextVersionOwned(
93                cpu_data.into_owned().versionize_owned(),
94            ),
95        )
96    }
97}
98
99impl Unversionize for InnerSquashedNoiseSignedRadixCiphertext {
100    fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
101        match versioned {
102            InnerSquashedNoiseSignedRadixCiphertextVersionedOwned::V0(v0) => {
103                let mut unversioned = Self::Cpu(
104                    crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext::unversionize(
105                        v0.0,
106                    )?,
107                );
108                unversioned.move_to_device_of_server_key_if_set();
109                Ok(unversioned)
110            }
111        }
112    }
113}
114
115impl InnerSquashedNoiseSignedRadixCiphertext {
116    /// Returns the inner cpu ciphertext if self is on the CPU, otherwise, returns a copy
117    /// that is on the CPU
118    pub(crate) fn on_cpu(
119        &self,
120    ) -> MaybeCloned<'_, crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext> {
121        match self {
122            Self::Cpu(ct) => MaybeCloned::Borrowed(ct),
123            #[cfg(feature = "gpu")]
124            Self::Cuda(ct) => {
125                with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| {
126                    MaybeCloned::Cloned(ct.to_squashed_noise_signed_radix_ciphertext(streams))
127                })
128            }
129        }
130    }
131    fn current_device(&self) -> crate::Device {
132        match self {
133            Self::Cpu(_) => crate::Device::Cpu,
134            #[cfg(feature = "gpu")]
135            Self::Cuda(_) => crate::Device::CudaGpu,
136        }
137    }
138
139    #[allow(clippy::needless_pass_by_ref_mut)]
140    fn move_to_device(&mut self, target_device: Device) {
141        let current_device = self.current_device();
142
143        if current_device == target_device {
144            #[cfg(feature = "gpu")]
145            // We may not be on the correct Cuda device
146            if let Self::Cuda(cuda_ct) = self {
147                with_cuda_internal_keys(|keys| {
148                    let streams = &keys.streams;
149                    if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
150                        *cuda_ct = cuda_ct.duplicate(streams);
151                    }
152                })
153            }
154            return;
155        }
156
157        // The logic is that the common device is the CPU, all other devices
158        // know how to transfer from and to CPU.
159
160        // So we first transfer to CPU
161        let cpu_ct = self.on_cpu();
162
163        // Then we can transfer the desired device
164        match target_device {
165            Device::Cpu => {
166                let _ = cpu_ct;
167            }
168            #[cfg(feature = "gpu")]
169            Device::CudaGpu => {
170                let new_inner = with_cuda_internal_keys(|keys| {
171                    let streams = &keys.streams;
172                    CudaSquashedNoiseSignedRadixCiphertext::from_squashed_noise_signed_radix_ciphertext(&cpu_ct, streams)
173                });
174                *self = Self::Cuda(new_inner);
175            }
176            #[cfg(feature = "hpu")]
177            Device::Hpu => {
178                panic!("HPU does not support noise squashing compression");
179            }
180        }
181    }
182
183    #[inline]
184    pub(crate) fn move_to_device_of_server_key_if_set(&mut self) {
185        if let Some(device) = global_state::device_of_internal_keys() {
186            self.move_to_device(device);
187        }
188    }
189}
190
191#[derive(Clone, serde::Deserialize, serde::Serialize, Versionize)]
192#[versionize(SquashedNoiseFheIntVersions)]
193pub struct SquashedNoiseFheInt {
194    pub(in crate::high_level_api) inner: InnerSquashedNoiseSignedRadixCiphertext,
195    pub(in crate::high_level_api) state: SquashedNoiseCiphertextState,
196    tag: Tag,
197}
198
199impl Named for SquashedNoiseFheInt {
200    const NAME: &'static str = "high_level_api::SquashedNoiseFheInt";
201}
202
203impl SquashedNoiseFheInt {
204    pub(in crate::high_level_api) fn new(
205        inner: InnerSquashedNoiseSignedRadixCiphertext,
206        state: SquashedNoiseCiphertextState,
207        tag: Tag,
208    ) -> Self {
209        Self { inner, state, tag }
210    }
211
212    pub fn underlying_squashed_noise_ciphertext(
213        &self,
214    ) -> MaybeCloned<'_, crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext> {
215        self.inner.on_cpu()
216    }
217
218    pub fn num_bits(&self) -> usize {
219        match &self.inner {
220            InnerSquashedNoiseSignedRadixCiphertext::Cpu(on_cpu) => {
221                on_cpu.original_block_count
222                    * on_cpu.packed_blocks[0].message_modulus().0.ilog2() as usize
223            }
224            #[cfg(feature = "gpu")]
225            InnerSquashedNoiseSignedRadixCiphertext::Cuda(gpu_ct) => {
226                gpu_ct.ciphertext.original_block_count
227                    * gpu_ct
228                        .ciphertext
229                        .info
230                        .blocks
231                        .first()
232                        .unwrap()
233                        .message_modulus
234                        .0
235                        .ilog2() as usize
236            }
237        }
238    }
239}
240
241impl<Clear> FheDecrypt<Clear> for SquashedNoiseFheInt
242where
243    Clear: RecomposableFrom<u128> + SignExtendable,
244{
245    fn decrypt(&self, key: &ClientKey) -> Clear {
246        let noise_squashing_private_key = key.private_noise_squashing_decryption_key(self.state);
247
248        noise_squashing_private_key
249            .decrypt_signed_radix(&self.inner.on_cpu())
250            .unwrap()
251    }
252}
253
254impl Tagged for SquashedNoiseFheInt {
255    fn tag(&self) -> &Tag {
256        &self.tag
257    }
258
259    fn tag_mut(&mut self) -> &mut Tag {
260        &mut self.tag
261    }
262}
263
264impl<Id: FheIntId> SquashNoise for FheInt<Id> {
265    type Output = SquashedNoiseFheInt;
266
267    fn squash_noise(&self) -> crate::Result<Self::Output> {
268        with_internal_keys(|keys| match keys {
269            InternalServerKey::Cpu(server_key) => {
270                let noise_squashing_key = server_key
271                    .key
272                    .noise_squashing_key
273                    .as_ref()
274                    .ok_or(UninitializedNoiseSquashing)?;
275
276                Ok(SquashedNoiseFheInt {
277                    inner: InnerSquashedNoiseSignedRadixCiphertext::Cpu(
278                        noise_squashing_key.squash_signed_radix_ciphertext_noise(
279                            server_key.key.pbs_key(),
280                            &self.ciphertext.on_cpu(),
281                        )?,
282                    ),
283                    state: SquashedNoiseCiphertextState::Normal,
284                    tag: server_key.tag.clone(),
285                })
286            }
287            #[cfg(feature = "gpu")]
288            InternalServerKey::Cuda(cuda_key) => {
289                let streams = &cuda_key.streams;
290                let noise_squashing_key = cuda_key
291                    .key
292                    .noise_squashing_key
293                    .as_ref()
294                    .ok_or(UninitializedNoiseSquashing)?;
295
296                let cuda_squashed_ct = noise_squashing_key.squash_signed_radix_ciphertext_noise(
297                    cuda_key.pbs_key(),
298                    &self.ciphertext.on_gpu(streams),
299                    streams,
300                )?;
301
302                let cpu_squashed_ct =
303                    cuda_squashed_ct.to_squashed_noise_signed_radix_ciphertext(streams);
304                Ok(SquashedNoiseFheInt {
305                    inner: InnerSquashedNoiseSignedRadixCiphertext::Cpu(cpu_squashed_ct),
306                    state: SquashedNoiseCiphertextState::Normal,
307                    tag: cuda_key.tag.clone(),
308                })
309            }
310            #[cfg(feature = "hpu")]
311            InternalServerKey::Hpu(_device) => {
312                Err(crate::error!("Hpu devices do not support noise squashing"))
313            }
314        })
315    }
316}