tfhe/high_level_api/integers/signed/
squashed_noise.rs1use super::base::{FheInt, FheIntId};
2use crate::backward_compatibility::integers::{
3 InnerSquashedNoiseSignedRadixCiphertextVersionedOwned, SquashedNoiseFheIntVersions,
4};
5use crate::high_level_api::details::MaybeCloned;
6use crate::high_level_api::errors::UninitializedNoiseSquashing;
7use crate::high_level_api::global_state::{self, with_internal_keys};
8#[cfg(feature = "gpu")]
9use crate::high_level_api::global_state::{
10 with_cuda_internal_keys, with_thread_local_cuda_streams_for_gpu_indexes,
11};
12use crate::high_level_api::keys::InternalServerKey;
13use crate::high_level_api::traits::{FheDecrypt, SquashNoise};
14use crate::high_level_api::SquashedNoiseCiphertextState;
15use crate::integer::block_decomposition::{RecomposableFrom, SignExtendable};
16#[cfg(feature = "gpu")]
17use crate::integer::gpu::ciphertext::squashed_noise::CudaSquashedNoiseSignedRadixCiphertext;
18use crate::named::Named;
19use crate::prelude::Tagged;
20use crate::{ClientKey, Device, Tag};
21use serde::{Deserializer, Serializer};
22use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
23
/// Storage backing a [`SquashedNoiseFheInt`]: the squashed-noise signed radix
/// ciphertext either in host memory or, with the `gpu` feature, on CUDA GPU(s).
pub(in crate::high_level_api) enum InnerSquashedNoiseSignedRadixCiphertext {
    /// Ciphertext held in host (CPU) memory.
    Cpu(crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext),
    /// Ciphertext held on CUDA device(s); see `gpu_indexes()` on the inner value.
    #[cfg(feature = "gpu")]
    Cuda(CudaSquashedNoiseSignedRadixCiphertext),
}
30
31impl Clone for InnerSquashedNoiseSignedRadixCiphertext {
32 fn clone(&self) -> Self {
33 match self {
34 Self::Cpu(inner) => Self::Cpu(inner.clone()),
35 #[cfg(feature = "gpu")]
36 Self::Cuda(inner) => {
37 with_thread_local_cuda_streams_for_gpu_indexes(inner.gpu_indexes(), |streams| {
38 Self::Cuda(inner.duplicate(streams))
39 })
40 }
41 }
42 }
43}
44impl serde::Serialize for InnerSquashedNoiseSignedRadixCiphertext {
45 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
46 where
47 S: Serializer,
48 {
49 self.on_cpu().serialize(serializer)
50 }
51}
52
53impl<'de> serde::Deserialize<'de> for InnerSquashedNoiseSignedRadixCiphertext {
54 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
55 where
56 D: Deserializer<'de>,
57 {
58 let mut deserialized = Self::Cpu(
59 crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext::deserialize(
60 deserializer,
61 )?,
62 );
63 deserialized.move_to_device_of_server_key_if_set();
64 Ok(deserialized)
65 }
66}
67
/// Owned versioned form of the CPU ciphertext, used as the `V0` payload of
/// `InnerSquashedNoiseSignedRadixCiphertextVersionedOwned`.
///
/// Only the CPU representation is versioned; the `Versionize`/`VersionizeOwned`
/// impls of `InnerSquashedNoiseSignedRadixCiphertext` convert to it first.
// This type is itself part of the versioned representation, so it derives serde
// without implementing `Versionize`; hence the `serialize_without_versionize` exception.
#[derive(serde::Serialize, serde::Deserialize)]
#[cfg_attr(dylint_lib = "tfhe_lints", allow(serialize_without_versionize))]
pub(crate) struct InnerSquashedNoiseSignedRadixCiphertextVersionOwned(
    <crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext as VersionizeOwned>::VersionedOwned,
);
74
75impl Versionize for InnerSquashedNoiseSignedRadixCiphertext {
76 type Versioned<'vers> = InnerSquashedNoiseSignedRadixCiphertextVersionedOwned;
77
78 fn versionize(&self) -> Self::Versioned<'_> {
79 let data = self.on_cpu();
80 let versioned = data.into_owned().versionize_owned();
81 InnerSquashedNoiseSignedRadixCiphertextVersionedOwned::V0(
82 InnerSquashedNoiseSignedRadixCiphertextVersionOwned(versioned),
83 )
84 }
85}
86impl VersionizeOwned for InnerSquashedNoiseSignedRadixCiphertext {
87 type VersionedOwned = InnerSquashedNoiseSignedRadixCiphertextVersionedOwned;
88
89 fn versionize_owned(self) -> Self::VersionedOwned {
90 let cpu_data = self.on_cpu();
91 InnerSquashedNoiseSignedRadixCiphertextVersionedOwned::V0(
92 InnerSquashedNoiseSignedRadixCiphertextVersionOwned(
93 cpu_data.into_owned().versionize_owned(),
94 ),
95 )
96 }
97}
98
99impl Unversionize for InnerSquashedNoiseSignedRadixCiphertext {
100 fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
101 match versioned {
102 InnerSquashedNoiseSignedRadixCiphertextVersionedOwned::V0(v0) => {
103 let mut unversioned = Self::Cpu(
104 crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext::unversionize(
105 v0.0,
106 )?,
107 );
108 unversioned.move_to_device_of_server_key_if_set();
109 Ok(unversioned)
110 }
111 }
112 }
113}
114
115impl InnerSquashedNoiseSignedRadixCiphertext {
116 pub(crate) fn on_cpu(
119 &self,
120 ) -> MaybeCloned<'_, crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext> {
121 match self {
122 Self::Cpu(ct) => MaybeCloned::Borrowed(ct),
123 #[cfg(feature = "gpu")]
124 Self::Cuda(ct) => {
125 with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| {
126 MaybeCloned::Cloned(ct.to_squashed_noise_signed_radix_ciphertext(streams))
127 })
128 }
129 }
130 }
131 fn current_device(&self) -> crate::Device {
132 match self {
133 Self::Cpu(_) => crate::Device::Cpu,
134 #[cfg(feature = "gpu")]
135 Self::Cuda(_) => crate::Device::CudaGpu,
136 }
137 }
138
139 #[allow(clippy::needless_pass_by_ref_mut)]
140 fn move_to_device(&mut self, target_device: Device) {
141 let current_device = self.current_device();
142
143 if current_device == target_device {
144 #[cfg(feature = "gpu")]
145 if let Self::Cuda(cuda_ct) = self {
147 with_cuda_internal_keys(|keys| {
148 let streams = &keys.streams;
149 if cuda_ct.gpu_indexes() != streams.gpu_indexes() {
150 *cuda_ct = cuda_ct.duplicate(streams);
151 }
152 })
153 }
154 return;
155 }
156
157 let cpu_ct = self.on_cpu();
162
163 match target_device {
165 Device::Cpu => {
166 let _ = cpu_ct;
167 }
168 #[cfg(feature = "gpu")]
169 Device::CudaGpu => {
170 let new_inner = with_cuda_internal_keys(|keys| {
171 let streams = &keys.streams;
172 CudaSquashedNoiseSignedRadixCiphertext::from_squashed_noise_signed_radix_ciphertext(&cpu_ct, streams)
173 });
174 *self = Self::Cuda(new_inner);
175 }
176 #[cfg(feature = "hpu")]
177 Device::Hpu => {
178 panic!("HPU does not support noise squashing compression");
179 }
180 }
181 }
182
183 #[inline]
184 pub(crate) fn move_to_device_of_server_key_if_set(&mut self) {
185 if let Some(device) = global_state::device_of_internal_keys() {
186 self.move_to_device(device);
187 }
188 }
189}
190
/// Signed FHE integer produced by [`SquashNoise::squash_noise`] on a [`FheInt`].
///
/// Decryption uses the client key's noise-squashing private key, selected
/// according to `state` (see the `FheDecrypt` impl for this type).
#[derive(Clone, serde::Deserialize, serde::Serialize, Versionize)]
#[versionize(SquashedNoiseFheIntVersions)]
pub struct SquashedNoiseFheInt {
    /// Ciphertext data, on the CPU or (with the `gpu` feature) on CUDA GPU(s).
    pub(in crate::high_level_api) inner: InnerSquashedNoiseSignedRadixCiphertext,
    /// Ciphertext state, passed to the client key to pick the decryption key.
    pub(in crate::high_level_api) state: SquashedNoiseCiphertextState,
    /// Tag attached to the ciphertext; `squash_noise` copies it from the server key.
    tag: Tag,
}
198
impl Named for SquashedNoiseFheInt {
    /// Type name reported through the [`Named`] trait.
    const NAME: &'static str = "high_level_api::SquashedNoiseFheInt";
}
202
203impl SquashedNoiseFheInt {
204 pub(in crate::high_level_api) fn new(
205 inner: InnerSquashedNoiseSignedRadixCiphertext,
206 state: SquashedNoiseCiphertextState,
207 tag: Tag,
208 ) -> Self {
209 Self { inner, state, tag }
210 }
211
212 pub fn underlying_squashed_noise_ciphertext(
213 &self,
214 ) -> MaybeCloned<'_, crate::integer::ciphertext::SquashedNoiseSignedRadixCiphertext> {
215 self.inner.on_cpu()
216 }
217
218 pub fn num_bits(&self) -> usize {
219 match &self.inner {
220 InnerSquashedNoiseSignedRadixCiphertext::Cpu(on_cpu) => {
221 on_cpu.original_block_count
222 * on_cpu.packed_blocks[0].message_modulus().0.ilog2() as usize
223 }
224 #[cfg(feature = "gpu")]
225 InnerSquashedNoiseSignedRadixCiphertext::Cuda(gpu_ct) => {
226 gpu_ct.ciphertext.original_block_count
227 * gpu_ct
228 .ciphertext
229 .info
230 .blocks
231 .first()
232 .unwrap()
233 .message_modulus
234 .0
235 .ilog2() as usize
236 }
237 }
238 }
239}
240
241impl<Clear> FheDecrypt<Clear> for SquashedNoiseFheInt
242where
243 Clear: RecomposableFrom<u128> + SignExtendable,
244{
245 fn decrypt(&self, key: &ClientKey) -> Clear {
246 let noise_squashing_private_key = key.private_noise_squashing_decryption_key(self.state);
247
248 noise_squashing_private_key
249 .decrypt_signed_radix(&self.inner.on_cpu())
250 .unwrap()
251 }
252}
253
254impl Tagged for SquashedNoiseFheInt {
255 fn tag(&self) -> &Tag {
256 &self.tag
257 }
258
259 fn tag_mut(&mut self) -> &mut Tag {
260 &mut self.tag
261 }
262}
263
264impl<Id: FheIntId> SquashNoise for FheInt<Id> {
265 type Output = SquashedNoiseFheInt;
266
267 fn squash_noise(&self) -> crate::Result<Self::Output> {
268 with_internal_keys(|keys| match keys {
269 InternalServerKey::Cpu(server_key) => {
270 let noise_squashing_key = server_key
271 .key
272 .noise_squashing_key
273 .as_ref()
274 .ok_or(UninitializedNoiseSquashing)?;
275
276 Ok(SquashedNoiseFheInt {
277 inner: InnerSquashedNoiseSignedRadixCiphertext::Cpu(
278 noise_squashing_key.squash_signed_radix_ciphertext_noise(
279 server_key.key.pbs_key(),
280 &self.ciphertext.on_cpu(),
281 )?,
282 ),
283 state: SquashedNoiseCiphertextState::Normal,
284 tag: server_key.tag.clone(),
285 })
286 }
287 #[cfg(feature = "gpu")]
288 InternalServerKey::Cuda(cuda_key) => {
289 let streams = &cuda_key.streams;
290 let noise_squashing_key = cuda_key
291 .key
292 .noise_squashing_key
293 .as_ref()
294 .ok_or(UninitializedNoiseSquashing)?;
295
296 let cuda_squashed_ct = noise_squashing_key.squash_signed_radix_ciphertext_noise(
297 cuda_key.pbs_key(),
298 &self.ciphertext.on_gpu(streams),
299 streams,
300 )?;
301
302 let cpu_squashed_ct =
303 cuda_squashed_ct.to_squashed_noise_signed_radix_ciphertext(streams);
304 Ok(SquashedNoiseFheInt {
305 inner: InnerSquashedNoiseSignedRadixCiphertext::Cpu(cpu_squashed_ct),
306 state: SquashedNoiseCiphertextState::Normal,
307 tag: cuda_key.tag.clone(),
308 })
309 }
310 #[cfg(feature = "hpu")]
311 InternalServerKey::Hpu(_device) => {
312 Err(crate::error!("Hpu devices do not support noise squashing"))
313 }
314 })
315 }
316}