1use crate::high_level_api::backward_compatibility::tag::TagVersions;
2use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned};
3
4const STACK_ARRAY_SIZE: usize = std::mem::size_of::<Vec<u8>>() - 1;
5
6#[derive(Clone, Debug)]
12pub(in crate::high_level_api) enum SmallVec {
13 Stack {
14 bytes: [u8; STACK_ARRAY_SIZE],
15 len: u8,
18 },
19 Heap(Vec<u8>),
20}
21
22impl Default for SmallVec {
23 fn default() -> Self {
24 Self::Stack {
25 bytes: Default::default(),
26 len: 0,
27 }
28 }
29}
30impl PartialEq for SmallVec {
31 fn eq(&self, other: &Self) -> bool {
32 match (self, other) {
33 (
34 Self::Stack {
35 bytes: l_bytes,
36 len: l_len,
37 },
38 Self::Stack {
39 bytes: r_bytes,
40 len: r_len,
41 },
42 ) => l_len == r_len && l_bytes[..usize::from(*l_len)] == r_bytes[..usize::from(*l_len)],
43 (Self::Heap(l_vec), Self::Heap(r_vec)) => l_vec == r_vec,
44 (
45 Self::Heap(l_vec),
46 Self::Stack {
47 bytes: r_bytes,
48 len: r_len,
49 },
50 ) => l_vec.len() == usize::from(*r_len) && l_vec == &r_bytes[..usize::from(*r_len)],
51 (
52 Self::Stack {
53 bytes: l_bytes,
54 len: l_len,
55 },
56 Self::Heap(r_vec),
57 ) => usize::from(*l_len) == r_vec.len() && &l_bytes[..usize::from(*l_len)] == r_vec,
58 }
59 }
60}
61
62impl Eq for SmallVec {}
63
64impl SmallVec {
65 pub fn data(&self) -> &[u8] {
67 match self {
68 Self::Stack { bytes, len } => &bytes[..usize::from(*len)],
69 Self::Heap(vec) => vec.as_slice(),
70 }
71 }
72
73 pub fn as_slice(&self) -> &[u8] {
75 self.data()
76 }
77
78 pub fn as_mut_slice(&mut self) -> &mut [u8] {
80 match self {
81 Self::Stack { bytes, len } => &mut bytes[..usize::from(*len)],
82 Self::Heap(vec) => vec.as_mut_slice(),
83 }
84 }
85
86 pub fn len(&self) -> usize {
88 match self {
89 Self::Stack { len, .. } => usize::from(*len),
90 Self::Heap(vec) => vec.len(),
91 }
92 }
93
94 pub fn is_empty(&self) -> bool {
96 match self {
97 Self::Stack { len, .. } => *len == 0,
98 Self::Heap(vec) => vec.is_empty(),
99 }
100 }
101
102 pub fn as_u64(&self) -> u64 {
107 let mut le_bytes = [0u8; u64::BITS as usize / 8];
108 let data = self.data();
109 let smallest = le_bytes.len().min(data.len());
110 le_bytes[..smallest].copy_from_slice(&data[..smallest]);
111
112 u64::from_le_bytes(le_bytes)
113 }
114
115 pub fn as_u128(&self) -> u128 {
120 let mut le_bytes = [0u8; u128::BITS as usize / 8];
121 let data = self.data();
122 let smallest = le_bytes.len().min(data.len());
123 le_bytes[..smallest].copy_from_slice(&data[..smallest]);
124
125 u128::from_le_bytes(le_bytes)
126 }
127
128 pub fn set_data(&mut self, data: &[u8]) {
132 match self {
133 Self::Stack { bytes, len } => {
134 if data.len() > bytes.len() {
135 *self = Self::Heap(data.to_vec());
138 } else {
139 bytes[..data.len()].copy_from_slice(data);
140 *len = data.len() as u8;
141 }
142 }
143 Self::Heap(vec) => {
144 vec.clear();
154 vec.extend_from_slice(data);
155 }
156 }
157 }
158
159 pub fn set_u64(&mut self, value: u64) {
164 let le_bytes = value.to_le_bytes();
165 self.set_data(le_bytes.as_slice());
166 }
167
168 pub fn set_u128(&mut self, value: u128) {
173 let le_bytes = value.to_le_bytes();
174 self.set_data(le_bytes.as_slice());
175 }
176
177 pub fn clear(&mut self) {
181 match self {
182 Self::Stack { bytes: _, len } => *len = 0,
183 Self::Heap(items) => items.clear(),
184 }
185 }
186
187 fn from_vec_conservative(vec: Vec<u8>) -> Self {
192 if vec.len() > STACK_ARRAY_SIZE {
196 Self::Heap(vec)
197 } else {
198 let mut data = Self::default();
199 data.set_data(vec.as_slice());
200 data
201 }
202 }
203}
204
205impl serde::Serialize for SmallVec {
206 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
207 where
208 S: serde::Serializer,
209 {
210 serializer.serialize_bytes(self.data())
211 }
212}
213
214struct SmallVecVisitor;
215
216impl serde::de::Visitor<'_> for SmallVecVisitor {
217 type Value = SmallVec;
218
219 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
220 formatter.write_str("a slice of bytes (&[u8]) or Vec<u8>")
221 }
222
223 fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
224 where
225 E: serde::de::Error,
226 {
227 let mut vec = SmallVec::default();
228 vec.set_data(bytes);
229 Ok(vec)
230 }
231
232 fn visit_byte_buf<E>(self, bytes: Vec<u8>) -> Result<Self::Value, E>
233 where
234 E: serde::de::Error,
235 {
236 Ok(SmallVec::from_vec_conservative(bytes))
237 }
238}
239
240impl Versionize for SmallVec {
241 type Versioned<'vers>
242 = &'vers [u8]
243 where
244 Self: 'vers;
245
246 fn versionize(&self) -> Self::Versioned<'_> {
247 self.data()
248 }
249}
250
251impl VersionizeOwned for SmallVec {
252 type VersionedOwned = Vec<u8>;
253
254 fn versionize_owned(self) -> Self::VersionedOwned {
255 match self {
256 Self::Stack { bytes, len } => bytes[..usize::from(len)].to_vec(),
257 Self::Heap(vec) => vec,
258 }
259 }
260}
261
262impl Unversionize for SmallVec {
263 fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
264 Ok(Self::from_vec_conservative(versioned))
265 }
266}
267
268impl<'de> serde::Deserialize<'de> for SmallVec {
269 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
270 where
271 D: serde::Deserializer<'de>,
272 {
273 deserializer.deserialize_bytes(SmallVecVisitor)
274 }
275}
276
/// Byte tag backed by a [`SmallVec`]: short tags are stored inline, longer
/// ones on the heap.
///
/// NOTE(review): the serde and `Versionize` derives (dispatched through
/// `TagVersions`) define the stored format, which likely depends on the field
/// name `inner` — confirm before renaming or reordering anything here.
#[derive(
    Default, Clone, Debug, serde::Serialize, serde::Deserialize, Versionize, PartialEq, Eq,
)]
#[versionize(TagVersions)]
pub struct Tag {
    // Raw tag bytes; every accessor on `Tag` delegates to this value.
    inner: SmallVec,
}
323
324impl Tag {
325 pub fn data(&self) -> &[u8] {
327 self.inner.data()
328 }
329
330 pub fn as_slice(&self) -> &[u8] {
332 self.inner.as_slice()
333 }
334
335 pub fn as_mut_slice(&mut self) -> &mut [u8] {
337 self.inner.as_mut_slice()
338 }
339
340 pub fn len(&self) -> usize {
342 self.inner.len()
343 }
344
345 pub fn is_empty(&self) -> bool {
347 self.inner.is_empty()
348 }
349
350 pub fn as_u64(&self) -> u64 {
355 self.inner.as_u64()
356 }
357
358 pub fn as_u128(&self) -> u128 {
363 self.inner.as_u128()
364 }
365
366 pub fn set_data(&mut self, data: &[u8]) {
370 self.inner.set_data(data);
371 }
372
373 pub fn set_u64(&mut self, value: u64) {
378 self.inner.set_u64(value);
379 }
380
381 pub fn set_u128(&mut self, value: u128) {
386 self.inner.set_u128(value);
387 }
388}
389
390impl From<u64> for Tag {
391 fn from(value: u64) -> Self {
392 let mut s = Self::default();
393 s.set_u64(value);
394 s
395 }
396}
397
398impl From<&str> for Tag {
399 fn from(value: &str) -> Self {
400 let mut tag = Self::default();
401 tag.set_data(value.as_bytes());
402 tag
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Stack/heap transitions through `set_data`, and equality across variants.
    #[test]
    fn test_small_vec() {
        let mut vec_1 = SmallVec::default();
        vec_1.set_data(&[1, 2, 3, 4, 5]);

        let mut vec_2 = SmallVec::default();
        vec_2.set_data(vec_1.data());

        assert!(matches!(vec_1, SmallVec::Stack { .. }));
        assert!(matches!(vec_2, SmallVec::Stack { .. }));
        assert_eq!(vec_2.len(), vec_1.len());
        assert_eq!(vec_1.len(), 5);
        assert_eq!(vec_1, vec_2);
        assert_eq!(vec_2, vec_1);

        let big_data = (0..500u64).map(|x| (x % 256) as u8).collect::<Vec<_>>();
        vec_1.set_data(&big_data);
        assert!(matches!(vec_1, SmallVec::Heap(_)));
        assert!(matches!(vec_2, SmallVec::Stack { .. }));
        assert_ne!(vec_2.len(), vec_1.len());
        assert_eq!(vec_1.len(), big_data.len());
        assert_ne!(vec_1, vec_2);
        assert_ne!(vec_2, vec_1);

        vec_2.set_data(&big_data);
        assert!(matches!(vec_1, SmallVec::Heap(_)));
        assert!(matches!(vec_2, SmallVec::Heap(_)));
        assert_eq!(vec_2.len(), vec_1.len());
        assert_eq!(vec_1.len(), big_data.len());
        assert_eq!(vec_1, vec_2);
        assert_eq!(vec_2, vec_1);

        // Heap storage stays on the heap even when the new data would fit inline.
        vec_1.set_data(&[1, 2, 3, 4, 5]);
        assert!(matches!(vec_1, SmallVec::Heap(_)));
        assert_eq!(vec_1.len(), 5);
        assert_eq!(vec_1.data(), &[1, 2, 3, 4, 5]);
        assert_ne!(vec_1, vec_2);
        assert_ne!(vec_2, vec_1);
    }

    /// Round-trips random integers through the little-endian setters/getters.
    #[test]
    fn test_small_vec_u64_u128() {
        let mut rng = rand::thread_rng();

        let mut vec = SmallVec::default();
        {
            let value = rng.gen();
            vec.set_u64(value);
            assert_eq!(vec.as_u64(), value);

            assert_eq!(vec.as_u128(), u128::from(value));
        }

        {
            let value = rng.gen();
            vec.set_u128(value);
            assert_eq!(vec.as_u128(), value);

            assert_eq!(vec.as_u64(), value as u64);
        }
    }

    /// Data of exactly the inline capacity stays inline; one more byte spills.
    #[test]
    fn test_small_vec_stack_heap_boundary() {
        let fits = vec![0xABu8; STACK_ARRAY_SIZE];
        let inline = SmallVec::from_vec_conservative(fits.clone());
        assert!(matches!(inline, SmallVec::Stack { .. }));
        assert_eq!(inline.data(), fits.as_slice());

        let spills = vec![0xCDu8; STACK_ARRAY_SIZE + 1];
        let heap = SmallVec::from_vec_conservative(spills.clone());
        assert!(matches!(heap, SmallVec::Heap(_)));
        assert_eq!(heap.data(), spills.as_slice());

        let mut vec = SmallVec::default();
        vec.set_data(&fits);
        assert!(matches!(vec, SmallVec::Stack { .. }));
        vec.set_data(&spills);
        assert!(matches!(vec, SmallVec::Heap(_)));
    }

    /// Integer views zero-extend short data and ignore bytes past their width.
    #[test]
    fn test_small_vec_integer_views_pad_and_truncate() {
        let mut vec = SmallVec::default();
        assert_eq!(vec.as_u64(), 0);
        assert_eq!(vec.as_u128(), 0);

        vec.set_data(&[0x01, 0x02]);
        assert_eq!(vec.as_u64(), 0x0201);
        assert_eq!(vec.as_u128(), 0x0201);

        let bytes: Vec<u8> = (1..=20).collect();
        vec.set_data(&bytes);
        let mut expected_64 = [0u8; 8];
        expected_64.copy_from_slice(&bytes[..8]);
        let mut expected_128 = [0u8; 16];
        expected_128.copy_from_slice(&bytes[..16]);
        assert_eq!(vec.as_u64(), u64::from_le_bytes(expected_64));
        assert_eq!(vec.as_u128(), u128::from_le_bytes(expected_128));
    }

    /// `clear` empties either variant without switching storage.
    #[test]
    fn test_small_vec_clear() {
        let mut vec = SmallVec::default();
        vec.set_data(&[1, 2, 3]);
        vec.clear();
        assert!(vec.is_empty());
        assert!(matches!(vec, SmallVec::Stack { .. }));
        assert_eq!(vec, SmallVec::default());

        let big = vec![7u8; STACK_ARRAY_SIZE + 5];
        vec.set_data(&big);
        vec.clear();
        assert!(vec.is_empty());
        assert!(matches!(vec, SmallVec::Heap(_)));
        assert_eq!(vec, SmallVec::default());
    }
}