1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
//! Module with the definition of the ServerKey.
//!
//! This module implements the generation of the server public key, together with all the
//! available homomorphic integer operations.
pub mod comparator;
mod crt;
mod crt_parallel;
pub(crate) mod radix;
pub(crate) mod radix_parallel;

use crate::integer::client_key::ClientKey;
use crate::shortint::ciphertext::MaxDegree;
use serde::{Deserialize, Serialize};

/// Error returned when the carry buffer is full.
pub use crate::shortint::CheckError;
use crate::shortint::{CarryModulus, MessageModulus};
pub use radix::scalar_mul::ScalarMultiplier;
pub use radix::scalar_sub::TwosComplementNegation;
pub use radix_parallel::{MiniUnsignedInteger, Reciprocable};

/// The server public key used to evaluate homomorphic integer operations.
///
/// It is generated on the client side and is meant to be published: the client hands it to the
/// server, which then uses it to compute homomorphic integer circuits.
///
/// Internally this is a thin wrapper around a shortint server key.
#[derive(Clone, Serialize, Deserialize)]
pub struct ServerKey {
    pub(crate) key: crate::shortint::ServerKey,
}

impl From<ServerKey> for crate::shortint::ServerKey {
    fn from(key: ServerKey) -> Self {
        key.key
    }
}

impl MaxDegree {
    /// Computes the [`MaxDegree`] of an integer server key (compressed or not) meant for
    /// [`RadixCiphertext`](`crate::integer::RadixCiphertext`).
    ///
    /// Carry propagation in a radix ciphertext adds the carry extracted from one shortint block
    /// into the next block. The returned degree therefore stops short of the full
    /// `message_modulus * carry_modulus - 1` capacity by the carry's maximum degree,
    /// `carry_modulus - 1`, which leaves room for that addition.
    pub(crate) fn integer_radix_server_key(
        message_modulus: MessageModulus,
        carry_modulus: CarryModulus,
    ) -> Self {
        // Highest degree a block can reach when message and carry spaces are both used.
        let block_capacity = message_modulus.0 * carry_modulus.0 - 1;
        // Headroom kept free so a carry coming from another block can still be added.
        let carry_headroom = carry_modulus.0 - 1;

        Self::new(block_capacity - carry_headroom)
    }
}

impl MaxDegree {
    /// Computes the [`MaxDegree`] of an integer server key (compressed or not) meant for
    /// [`CrtCiphertext`](`crate::integer::CrtCiphertext`).
    ///
    /// Unlike the radix variant, no carry headroom is reserved: the whole
    /// `message_modulus * carry_modulus - 1` capacity is allowed. This makes the value unsuitable
    /// for [`RadixCiphertext`](`crate::integer::RadixCiphertext`).
    fn integer_crt_server_key(
        message_modulus: MessageModulus,
        carry_modulus: CarryModulus,
    ) -> Self {
        Self::new(message_modulus.0 * carry_modulus.0 - 1)
    }
}

impl ServerKey {
    /// Generates a server key.
    ///
    /// # Example
    ///
    /// ```rust
    /// use tfhe::integer::{ClientKey, ServerKey};
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
    ///
    /// // Generate the client key:
    /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
    ///
    /// // Generate the server key:
    /// let sks = ServerKey::new_radix_server_key(&cks);
    /// ```
    pub fn new_radix_server_key<C>(cks: C) -> Self
    where
        C: AsRef<ClientKey>,
    {
        // It should remain just enough space to add a carry
        let client_key = cks.as_ref();
        let max_degree = MaxDegree::integer_radix_server_key(
            client_key.key.parameters.message_modulus(),
            client_key.key.parameters.carry_modulus(),
        );

        let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
            &client_key.key,
            max_degree,
        );

        Self { key: sks }
    }

    pub fn new_crt_server_key<C>(cks: C) -> Self
    where
        C: AsRef<ClientKey>,
    {
        let client_key = cks.as_ref();
        let max_degree = MaxDegree::integer_crt_server_key(
            client_key.key.parameters.message_modulus(),
            client_key.key.parameters.carry_modulus(),
        );

        let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
            &client_key.key,
            max_degree,
        );

        Self { key: sks }
    }

    /// Creates a ServerKey destined to be used with
    /// [`RadixCiphertext`](`crate::integer::RadixCiphertext`) from an already generated
    /// shortint::ServerKey.
    ///
    /// # Example
    ///
    /// ```rust
    /// use tfhe::integer::{ClientKey, ServerKey};
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
    /// use tfhe::shortint::ServerKey as ShortintServerKey;
    ///
    /// let size = 4;
    ///
    /// // Generate the client key:
    /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
    ///
    /// // Generate the shortint server key:
    /// let shortint_sks = ShortintServerKey::new(cks.as_ref());
    ///
    /// // Generate the server key:
    /// let sks = ServerKey::new_radix_server_key_from_shortint(shortint_sks);
    /// ```
    pub fn new_radix_server_key_from_shortint(
        mut key: crate::shortint::server_key::ServerKey,
    ) -> Self {
        // It should remain just enough space add a carry
        let max_degree =
            MaxDegree::integer_radix_server_key(key.message_modulus, key.carry_modulus);

        key.max_degree = max_degree;
        Self { key }
    }

    /// Creates a ServerKey destined to be used with
    /// [`CrtCiphertext`](`crate::integer::CrtCiphertext`) from an already generated
    /// shortint::ServerKey.
    ///
    /// # Example
    ///
    /// ```rust
    /// use tfhe::integer::{ClientKey, ServerKey};
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
    /// use tfhe::shortint::ServerKey as ShortintServerKey;
    ///
    /// let size = 4;
    ///
    /// // Generate the client key:
    /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
    ///
    /// // Generate the shortint server key:
    /// let shortint_sks = ShortintServerKey::new(cks.as_ref());
    ///
    /// // Generate the server key:
    /// let sks = ServerKey::new_crt_server_key_from_shortint(shortint_sks);
    /// ```
    pub fn new_crt_server_key_from_shortint(
        mut key: crate::shortint::server_key::ServerKey,
    ) -> Self {
        key.max_degree = MaxDegree::integer_crt_server_key(key.message_modulus, key.carry_modulus);
        Self { key }
    }

    /// Deconstruct a [`ServerKey`] into its constituents.
    pub fn into_raw_parts(self) -> crate::shortint::ServerKey {
        self.key
    }

    /// Construct a [`ServerKey`] from its constituents.
    pub fn from_raw_parts(key: crate::shortint::ServerKey) -> Self {
        Self { key }
    }

    pub fn deterministic_pbs_execution(&self) -> bool {
        self.key.deterministic_pbs_execution()
    }

    pub fn set_deterministic_pbs_execution(&mut self, new_deterministic_execution: bool) {
        self.key
            .set_deterministic_pbs_execution(new_deterministic_execution);
    }

    pub fn message_modulus(&self) -> MessageModulus {
        self.key.message_modulus
    }

    pub fn carry_modulus(&self) -> CarryModulus {
        self.key.carry_modulus
    }
}

impl AsRef<crate::shortint::ServerKey> for ServerKey {
    /// Borrows the shortint server key wrapped by this integer server key.
    fn as_ref(&self) -> &crate::shortint::ServerKey {
        let Self { key } = self;
        key
    }
}

/// Compressed form of an integer [`ServerKey`].
///
/// It wraps a shortint compressed server key and can be turned back into a [`ServerKey`] with
/// [`CompressedServerKey::decompress`].
#[derive(Serialize, Deserialize, Clone)]
pub struct CompressedServerKey {
    pub(crate) key: crate::shortint::CompressedServerKey,
}

impl CompressedServerKey {
    pub fn new_radix_compressed_server_key(client_key: &ClientKey) -> Self {
        let max_degree = MaxDegree::integer_radix_server_key(
            client_key.key.parameters.message_modulus(),
            client_key.key.parameters.carry_modulus(),
        );

        let key =
            crate::shortint::CompressedServerKey::new_with_max_degree(&client_key.key, max_degree);
        Self { key }
    }

    pub fn new_crt_compressed_server_key(client_key: &ClientKey) -> Self {
        let key = crate::shortint::CompressedServerKey::new(&client_key.key);
        Self { key }
    }

    /// Decompress a [`CompressedServerKey`] into a [`ServerKey`].
    pub fn decompress(&self) -> ServerKey {
        ServerKey {
            key: self.key.decompress(),
        }
    }

    /// Deconstruct a [`CompressedServerKey`] into its constituents.
    pub fn into_raw_parts(self) -> crate::shortint::CompressedServerKey {
        self.key
    }

    /// Construct a [`CompressedServerKey`] from its constituents.
    pub fn from_raw_parts(key: crate::shortint::CompressedServerKey) -> Self {
        Self { key }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::integer::RadixClientKey;
    use crate::shortint::parameters::{PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_2_CARRY_2_KS_PBS};

    /// Checks that `sks`, `csks` and the key obtained by decompressing `csks` all report
    /// `expected` as their max degree.
    fn assert_all_max_degrees(sks: &ServerKey, csks: &CompressedServerKey, expected: MaxDegree) {
        assert_eq!(sks.key.max_degree, expected);
        assert_eq!(csks.key.max_degree, expected);

        let decompressed_sks: ServerKey = csks.decompress();
        assert_eq!(decompressed_sks.key.max_degree, expected);
    }

    /// Regression test for https://github.com/zama-ai/tfhe-rs/issues/460
    ///
    /// The CompressedServerKey max degree used to be the plain shortint one, which does not
    /// account for the carry space needed by e.g. radix carry propagation.
    #[test]
    fn test_compressed_server_key_max_degree() {
        // Radix keys: msg_mod = 4, carry_mod = 4,
        // (msg_mod * carry_mod - 1) - (carry_mod - 1) = 12
        {
            let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
            let sks = ServerKey::new_radix_server_key(&cks);
            let csks = CompressedServerKey::new_radix_compressed_server_key(&cks);
            assert_all_max_degrees(&sks, &csks, MaxDegree::new(12));
        }

        // CRT keys: msg_mod = 4, carry_mod = 4, msg_mod * carry_mod - 1 = 15
        {
            let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
            let sks = ServerKey::new_crt_server_key(&cks);
            let csks = CompressedServerKey::new_crt_compressed_server_key(&cks);
            assert_all_max_degrees(&sks, &csks, MaxDegree::new(15));
        }

        // Repro case from the user: repeated additions on a decompressed radix key must give the
        // correctly wrapped result.
        {
            let client_key = RadixClientKey::new(PARAM_MESSAGE_2_CARRY_2, 14);
            let compressed_eval_key =
                CompressedServerKey::new_radix_compressed_server_key(client_key.as_ref());
            let evaluation_key = compressed_eval_key.decompress();

            let block_modulus = client_key.parameters().message_modulus().0 as u128;
            let modulus = block_modulus.pow(client_key.num_blocks() as u32);

            let mut ct = client_key.encrypt(modulus - 1);
            let mut acc = ct.clone();
            // acc ends up holding 6 * (modulus - 1), i.e. modulus - 6 once reduced.
            for _ in 0..5 {
                acc = evaluation_key.smart_add_parallelized(&mut acc, &mut ct);
            }

            let res: u128 = client_key.decrypt(&acc);
            assert_eq!(modulus - 6, res);
        }
    }
}