tfhe/high_level_api/
tag.rs

1use crate::high_level_api::backward_compatibility::tag::TagVersions;
2use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
3
/// Number of bytes a [`SmallVec`] can hold inline, without a heap allocation.
///
/// This is the size of a `Vec<u8>` minus one byte: the inline variant also stores a
/// one-byte length, so the array plus its length occupy as many bytes as a `Vec<u8>`.
const STACK_ARRAY_SIZE: usize = std::mem::size_of::<Vec<u8>>() - std::mem::size_of::<u8>();
5
/// Simple short-vector optimization: when the data is small enough
/// (`<= std::mem::size_of::<Vec<u8>>() - 1` bytes, i.e. `STACK_ARRAY_SIZE`) it is stored
/// inline in a fixed-size array, without any heap allocation.
///
/// Once a true heap-allocated `Vec` has been needed, later calls to `set_data` and
/// `clear` keep using it: the value does not switch back to inline storage, even if
/// the data becomes small again.
#[derive(Clone, Debug)]
pub(in crate::high_level_api) enum SmallVec {
    /// Inline storage for data of at most `STACK_ARRAY_SIZE` bytes.
    Stack {
        /// Backing array; only the first `len` bytes are meaningful.
        bytes: [u8; STACK_ARRAY_SIZE],
        /// Number of bytes of `bytes` actually in use.
        ///
        /// The array has a fixed size, but the user may not use all of it,
        /// so the actual length is tracked separately.
        len: u8,
    },
    /// Heap storage, used when the data did not fit in the inline array.
    Heap(Vec<u8>),
}
21
22impl Default for SmallVec {
23    fn default() -> Self {
24        Self::Stack {
25            bytes: Default::default(),
26            len: 0,
27        }
28    }
29}
30impl PartialEq for SmallVec {
31    fn eq(&self, other: &Self) -> bool {
32        match (self, other) {
33            (
34                Self::Stack {
35                    bytes: l_bytes,
36                    len: l_len,
37                },
38                Self::Stack {
39                    bytes: r_bytes,
40                    len: r_len,
41                },
42            ) => l_len == r_len && l_bytes[..usize::from(*l_len)] == r_bytes[..usize::from(*l_len)],
43            (Self::Heap(l_vec), Self::Heap(r_vec)) => l_vec == r_vec,
44            (
45                Self::Heap(l_vec),
46                Self::Stack {
47                    bytes: r_bytes,
48                    len: r_len,
49                },
50            ) => l_vec.len() == usize::from(*r_len) && l_vec == &r_bytes[..usize::from(*r_len)],
51            (
52                Self::Stack {
53                    bytes: l_bytes,
54                    len: l_len,
55                },
56                Self::Heap(r_vec),
57            ) => usize::from(*l_len) == r_vec.len() && &l_bytes[..usize::from(*l_len)] == r_vec,
58        }
59    }
60}
61
62impl Eq for SmallVec {}
63
64impl SmallVec {
65    /// Returns a slice to the bytes stored
66    pub fn data(&self) -> &[u8] {
67        match self {
68            Self::Stack { bytes, len } => &bytes[..usize::from(*len)],
69            Self::Heap(vec) => vec.as_slice(),
70        }
71    }
72
73    /// Returns a slice to the bytes stored (same a [Self::data])
74    pub fn as_slice(&self) -> &[u8] {
75        self.data()
76    }
77
78    /// Returns a mutable slice to the bytes stored
79    pub fn as_mut_slice(&mut self) -> &mut [u8] {
80        match self {
81            Self::Stack { bytes, len } => &mut bytes[..usize::from(*len)],
82            Self::Heap(vec) => vec.as_mut_slice(),
83        }
84    }
85
86    /// Returns the len, i.e. the number of bytes stored
87    pub fn len(&self) -> usize {
88        match self {
89            Self::Stack { len, .. } => usize::from(*len),
90            Self::Heap(vec) => vec.len(),
91        }
92    }
93
94    /// Returns whether self is empty
95    pub fn is_empty(&self) -> bool {
96        match self {
97            Self::Stack { len, .. } => *len == 0,
98            Self::Heap(vec) => vec.is_empty(),
99        }
100    }
101
102    /// Return the u64 value when interpreting the bytes as a `u64`
103    ///
104    /// * Bytes are interpreted in little endian
105    /// * Bytes above the 8th are ignored
106    pub fn as_u64(&self) -> u64 {
107        let mut le_bytes = [0u8; u64::BITS as usize / 8];
108        let data = self.data();
109        let smallest = le_bytes.len().min(data.len());
110        le_bytes[..smallest].copy_from_slice(&data[..smallest]);
111
112        u64::from_le_bytes(le_bytes)
113    }
114
115    /// Return the u128 value when interpreting the bytes as a `u128`
116    ///
117    /// * Bytes are interpreted in little endian
118    /// * Bytes above the 16th are ignored
119    pub fn as_u128(&self) -> u128 {
120        let mut le_bytes = [0u8; u128::BITS as usize / 8];
121        let data = self.data();
122        let smallest = le_bytes.len().min(data.len());
123        le_bytes[..smallest].copy_from_slice(&data[..smallest]);
124
125        u128::from_le_bytes(le_bytes)
126    }
127
128    /// Sets the data stored in the tag
129    ///
130    /// This overwrites existing data stored
131    pub fn set_data(&mut self, data: &[u8]) {
132        match self {
133            Self::Stack { bytes, len } => {
134                if data.len() > bytes.len() {
135                    // There is not enough space, so we have to allocate
136                    // a Vec
137                    *self = Self::Heap(data.to_vec());
138                } else {
139                    bytes[..data.len()].copy_from_slice(data);
140                    *len = data.len() as u8;
141                }
142            }
143            Self::Heap(vec) => {
144                // Even if the data could fit in the Stack array,
145                // Since, we already have a vec allocated we use it instead.
146                //
147                // And in that case, there won't be any allocations since,
148                // to have a vec in the first place, the allocated size is >
149                // size_of::<Vec<T>>
150                //
151                // But of course, if the new data is larger than the vec, a new
152                // allocation will be made
153                vec.clear();
154                vec.extend_from_slice(data);
155            }
156        }
157    }
158
159    /// Sets the tag with the given u64 value
160    ///
161    /// * Bytes are stored in little endian
162    /// * This overwrites existing data stored
163    pub fn set_u64(&mut self, value: u64) {
164        let le_bytes = value.to_le_bytes();
165        self.set_data(le_bytes.as_slice());
166    }
167
168    /// Sets the tag with the given u128 value
169    ///
170    /// * Bytes are stored in little endian
171    /// * This overwrites existing data stored
172    pub fn set_u128(&mut self, value: u128) {
173        let le_bytes = value.to_le_bytes();
174        self.set_data(le_bytes.as_slice());
175    }
176
177    /// Clears the vector, removing all values.
178    ///
179    /// Note that this method has no effect on the allocated capacity of the vector.
180    pub fn clear(&mut self) {
181        match self {
182            Self::Stack { bytes: _, len } => *len = 0,
183            Self::Heap(items) => items.clear(),
184        }
185    }
186
187    // Creates a SmallVec from the vec, but, only re-uses the vec
188    // if its len would not fit on the stack part.
189    //
190    // Meant for versioning and deserializing
191    fn from_vec_conservative(vec: Vec<u8>) -> Self {
192        // We only re-use the versioned vec, if the SmallVec would actually
193        // have had its data on the heap, otherwise we prefer to keep data on stack
194        // as its cheaper in memory and copies
195        if vec.len() > STACK_ARRAY_SIZE {
196            Self::Heap(vec)
197        } else {
198            let mut data = Self::default();
199            data.set_data(vec.as_slice());
200            data
201        }
202    }
203}
204
205impl serde::Serialize for SmallVec {
206    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
207    where
208        S: serde::Serializer,
209    {
210        serializer.serialize_bytes(self.data())
211    }
212}
213
214struct SmallVecVisitor;
215
216impl serde::de::Visitor<'_> for SmallVecVisitor {
217    type Value = SmallVec;
218
219    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
220        formatter.write_str("a slice of bytes (&[u8]) or Vec<u8>")
221    }
222
223    fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
224    where
225        E: serde::de::Error,
226    {
227        let mut vec = SmallVec::default();
228        vec.set_data(bytes);
229        Ok(vec)
230    }
231
232    fn visit_byte_buf<E>(self, bytes: Vec<u8>) -> Result<Self::Value, E>
233    where
234        E: serde::de::Error,
235    {
236        Ok(SmallVec::from_vec_conservative(bytes))
237    }
238}
239
240impl Versionize for SmallVec {
241    type Versioned<'vers>
242        = &'vers [u8]
243    where
244        Self: 'vers;
245
246    fn versionize(&self) -> Self::Versioned<'_> {
247        self.data()
248    }
249}
250
251impl VersionizeOwned for SmallVec {
252    type VersionedOwned = Vec<u8>;
253
254    fn versionize_owned(self) -> Self::VersionedOwned {
255        match self {
256            Self::Stack { bytes, len } => bytes[..usize::from(len)].to_vec(),
257            Self::Heap(vec) => vec,
258        }
259    }
260}
261
262impl Unversionize for SmallVec {
263    fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
264        Ok(Self::from_vec_conservative(versioned))
265    }
266}
267
268impl<'de> serde::Deserialize<'de> for SmallVec {
269    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
270    where
271        D: serde::Deserializer<'de>,
272    {
273        deserializer.deserialize_bytes(SmallVecVisitor)
274    }
275}
276
/// Tag
///
/// The `Tag` makes it possible to store bytes alongside entities (keys and ciphertexts).
/// The main purpose of this system is to `tag` / identify ciphertexts with their keys.
///
/// TFHE-rs generally does not interpret or check this data, it only stores it and passes it around.
///
/// The [crate::upgrade::UpgradeKeyChain] uses the tag to differentiate keys.
///
/// The rules for how the Tag is passed around are:
/// * When encrypted, a ciphertext gets the tag of the key used to encrypt it.
/// * Ciphertexts resulting from operations (add, sub, etc.) get the tag from the ServerKey used
/// * PublicKey gets its tag from the ClientKey that was used to create it
/// * ServerKey gets its tag from the ClientKey that was used to create it
///
/// Users can change the tag of any entity at any point.
///
/// # Example
///
/// ```
/// use rand::random;
/// use tfhe::prelude::*;
/// use tfhe::{ClientKey, ConfigBuilder, FheUint32, ServerKey};
///
/// // Generate the client key then set its tag
/// let mut cks = ClientKey::generate(ConfigBuilder::default());
/// let tag_value = random();
/// cks.tag_mut().set_u64(tag_value);
/// assert_eq!(cks.tag().as_u64(), tag_value);
///
/// // The server key inherits the client key tag
/// let sks = ServerKey::new(&cks);
/// assert_eq!(sks.tag(), cks.tag());
///
/// // Encrypted data inherits the tag of the encryption key
/// let a = FheUint32::encrypt(32832u32, &cks);
/// assert_eq!(a.tag(), cks.tag());
/// ```
#[derive(
    Default, Clone, Debug, serde::Serialize, serde::Deserialize, Versionize, PartialEq, Eq,
)]
#[versionize(TagVersions)]
pub struct Tag {
    // The bytes are kept in a private field because we don't want the
    // `SmallVec` enum (and its variants) to be part of the public API
    inner: SmallVec,
}
323
324impl Tag {
325    /// Returns a slice to the bytes stored
326    pub fn data(&self) -> &[u8] {
327        self.inner.data()
328    }
329
330    /// Returns a slice to the bytes stored (same a [Self::data])
331    pub fn as_slice(&self) -> &[u8] {
332        self.inner.as_slice()
333    }
334
335    /// Returns a mutable slice to the bytes stored
336    pub fn as_mut_slice(&mut self) -> &mut [u8] {
337        self.inner.as_mut_slice()
338    }
339
340    /// Returns the len, i.e. the number of bytes stored in the tag
341    pub fn len(&self) -> usize {
342        self.inner.len()
343    }
344
345    /// Returns whether the tag is empty
346    pub fn is_empty(&self) -> bool {
347        self.inner.is_empty()
348    }
349
350    /// Return the u64 value when interpreting the bytes as a `u64`
351    ///
352    /// * Bytes are interpreted in little endian
353    /// * Bytes above the 8th are ignored
354    pub fn as_u64(&self) -> u64 {
355        self.inner.as_u64()
356    }
357
358    /// Return the u128 value when interpreting the bytes as a `u128`
359    ///
360    /// * Bytes are interpreted in little endian
361    /// * Bytes above the 16th are ignored
362    pub fn as_u128(&self) -> u128 {
363        self.inner.as_u128()
364    }
365
366    /// Sets the data stored in the tag
367    ///
368    /// This overwrites existing data stored
369    pub fn set_data(&mut self, data: &[u8]) {
370        self.inner.set_data(data);
371    }
372
373    /// Sets the tag with the given u64 value
374    ///
375    /// * Bytes are stored in little endian
376    /// * This overwrites existing data stored
377    pub fn set_u64(&mut self, value: u64) {
378        self.inner.set_u64(value);
379    }
380
381    /// Sets the tag with the given u128 value
382    ///
383    /// * Bytes are stored in little endian
384    /// * This overwrites existing data stored
385    pub fn set_u128(&mut self, value: u128) {
386        self.inner.set_u128(value);
387    }
388}
389
390impl From<u64> for Tag {
391    fn from(value: u64) -> Self {
392        let mut s = Self::default();
393        s.set_u64(value);
394        s
395    }
396}
397
398impl From<&str> for Tag {
399    fn from(value: &str) -> Self {
400        let mut tag = Self::default();
401        tag.set_data(value.as_bytes());
402        tag
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Checks the storage variant, length and equality of `SmallVec` as data of
    /// different sizes is set.
    #[test]
    fn test_small_vec() {
        let mut vec_1 = SmallVec::default();
        vec_1.set_data(&[1, 2, 3, 4, 5]);

        let mut vec_2 = SmallVec::default();
        vec_2.set_data(vec_1.data());

        // Small data: both are expected to use the inline storage
        assert!(matches!(vec_1, SmallVec::Stack { .. }));
        assert!(matches!(vec_2, SmallVec::Stack { .. }));
        assert_eq!(vec_2.len(), vec_1.len());
        assert_eq!(vec_1.len(), 5);
        assert_eq!(vec_1, vec_2); // Test both ways
        assert_eq!(vec_2, vec_1);

        // Put something big in vec_1, we expect the data to be on the heap now
        let big_data = (0..500u64).map(|x| (x % 256) as u8).collect::<Vec<_>>();
        vec_1.set_data(&big_data);
        assert!(matches!(vec_1, SmallVec::Heap(_)));
        assert!(matches!(vec_2, SmallVec::Stack { .. }));
        assert_ne!(vec_2.len(), vec_1.len());
        assert_eq!(vec_1.len(), big_data.len());
        assert_ne!(vec_1, vec_2);
        assert_ne!(vec_2, vec_1);

        // Put the same big data in vec_2,
        // we also expect the data to be on the heap now
        vec_2.set_data(&big_data);
        assert!(matches!(vec_1, SmallVec::Heap(_)));
        assert!(matches!(vec_2, SmallVec::Heap(_)));
        assert_eq!(vec_2.len(), vec_1.len());
        assert_eq!(vec_1.len(), big_data.len());
        assert_eq!(vec_1, vec_2); // Test both ways
        assert_eq!(vec_2, vec_1);

        // Now put back something small in vec_1.
        // We expect the data to still be on the heap, since
        // the heap was allocated to store the previous big data
        vec_1.set_data(&[1, 2, 3, 4, 5]);
        assert!(matches!(vec_1, SmallVec::Heap(_)));
        assert_eq!(vec_1.len(), 5);
        assert_eq!(vec_1.data(), &[1, 2, 3, 4, 5]);
        assert_ne!(vec_1, vec_2);
        assert_ne!(vec_2, vec_1);
    }

    /// Checks that `set_u64` / `set_u128` round-trip through `as_u64` / `as_u128`,
    /// and how a value written with one width reads back with the other.
    #[test]
    fn test_small_vec_u64_u128() {
        let mut rng = rand::thread_rng();

        let mut vec = SmallVec::default();
        {
            let value = rng.gen();
            vec.set_u64(value);
            assert_eq!(vec.as_u64(), value);

            // Only 8 bytes are stored: the upper bytes read back as zero
            assert_eq!(vec.as_u128(), u128::from(value));
        }

        {
            let value = rng.gen();
            vec.set_u128(value);
            assert_eq!(vec.as_u128(), value);

            // Reading as u64 keeps the first 8 little-endian bytes, i.e. the low 64 bits
            assert_eq!(vec.as_u64(), value as u64);
        }
    }
}