Skip to main content

solana_short_vec/
lib.rs

1//! Compact serde-encoding of vectors with small length.
2#![no_std]
3#![cfg_attr(feature = "frozen-abi", feature(min_specialization))]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![allow(clippy::arithmetic_side_effects)]
6
7extern crate alloc;
8#[cfg(feature = "frozen-abi")]
9extern crate std;
10
11#[cfg(feature = "frozen-abi")]
12use solana_frozen_abi_macro::AbiExample;
13use {alloc::vec::Vec, core::convert::TryFrom};
14#[cfg(feature = "serde")]
15use {
16    core::{fmt, marker::PhantomData},
17    serde_core::{
18        de::{self, Deserializer, SeqAccess, Visitor},
19        ser::{self, SerializeTuple, Serializer},
20        Deserialize, Serialize,
21    },
22};
23
24#[cfg(feature = "wincode")]
25mod wincode;
26
/// A `u16` that serializes to between one and three bytes.
///
/// The value is emitted seven bits at a time, least-significant group first,
/// each group in the low bits of its byte. The high bit (`0x80`) of a byte is
/// set when another byte follows. A third byte, if present, may only use its
/// two least-significant bits, because anything more would overflow a `u16`.
/// Decoding rejects non-minimal (aliased) encodings, so every value has
/// exactly one accepted byte representation.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
pub struct ShortU16(pub u16);
34
#[cfg(feature = "serde")]
impl Serialize for ShortU16 {
    /// Emits the value as a tuple of one to three `u8` elements: seven bits
    /// per byte, least-significant group first, with `0x80` set on every byte
    /// except the last.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // A non-zero tuple length makes serde_json emit an opening bracket.
        let mut tuple = serializer.serialize_tuple(1)?;

        let mut remaining = self.0;
        loop {
            let group = (remaining & 0x7f) as u8;
            remaining >>= 7;
            let is_last = remaining == 0;
            let byte = if is_last { group } else { group | 0x80 };
            tuple.serialize_element(&byte)?;
            if is_last {
                break;
            }
        }
        tuple.end()
    }
}
60
/// Result of folding a single encoded byte into the running value.
enum VisitStatus {
    /// The byte terminated the encoding; carries the fully decoded value.
    Done(u16),
    /// Another byte is required; carries the value decoded so far.
    More(u16),
}
65
/// Reasons a byte sequence is rejected as a `ShortU16` encoding.
#[derive(Debug)]
#[cfg_attr(not(feature = "serde"), allow(dead_code))]
enum VisitError {
    /// The encoding ran past the maximum byte count; carries the count reached.
    TooLong(usize),
    /// The input ended before a terminating byte; carries the reported length.
    TooShort(usize),
    /// The decoded value does not fit in a `u16`; carries the oversized value.
    Overflow(u32),
    /// A zero byte after the first position, i.e. a non-minimal encoding.
    Alias,
    /// The third byte still has its continuation bit set.
    ByteThreeContinues,
}
75
#[cfg(feature = "serde")]
impl VisitError {
    /// Converts this decoding failure into the error type of the sequence
    /// access `A` that supplied the bytes.
    fn into_de_error<'de, A>(self) -> A::Error
    where
        A: SeqAccess<'de>,
    {
        match self {
            Self::TooLong(len) => {
                <A::Error as de::Error>::invalid_length(len, &"three or fewer bytes")
            }
            Self::TooShort(len) => <A::Error as de::Error>::invalid_length(len, &"more bytes"),
            Self::Overflow(val) => <A::Error as de::Error>::invalid_value(
                de::Unexpected::Unsigned(u64::from(val)),
                &"a value in the range [0, 65535]",
            ),
            Self::Alias => <A::Error as de::Error>::invalid_value(
                de::Unexpected::Other("alias encoding"),
                &"strict form encoding",
            ),
            Self::ByteThreeContinues => <A::Error as de::Error>::invalid_value(
                de::Unexpected::Other("continue signal on byte-three"),
                &"a terminal signal on or before byte-three",
            ),
        }
    }
}
100
101type VisitResult = Result<VisitStatus, VisitError>;
102
103const MAX_ENCODING_LENGTH: usize = 3;
104fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
105    if elem == 0 && nth_byte != 0 {
106        return Err(VisitError::Alias);
107    }
108
109    let val = u32::from(val);
110    let elem = u32::from(elem);
111    let elem_val = elem & 0x7f;
112    let elem_done = (elem & 0x80) == 0;
113
114    if nth_byte >= MAX_ENCODING_LENGTH {
115        return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
116    } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
117        return Err(VisitError::ByteThreeContinues);
118    }
119
120    let shift = u32::try_from(nth_byte)
121        .unwrap_or(u32::MAX)
122        .saturating_mul(7);
123    let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX);
124
125    let new_val = val | elem_val;
126    let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
127
128    if elem_done {
129        Ok(VisitStatus::Done(val))
130    } else {
131        Ok(VisitStatus::More(val))
132    }
133}
134
#[cfg(feature = "serde")]
struct ShortU16Visitor;

#[cfg(feature = "serde")]
impl<'de> Visitor<'de> for ShortU16Visitor {
    type Value = ShortU16;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a ShortU16")
    }

    /// Pulls up to `MAX_ENCODING_LENGTH` bytes from `seq`, folding each one in
    /// with `visit_byte` until a byte without the continuation bit arrives.
    fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
    where
        A: SeqAccess<'de>,
    {
        // Accepted encodings (exactly one per value):
        // 1 byte  : 0xxxxxxx                   => 00000000 0xxxxxxx :      0 -    127
        // 2 bytes : 1xxxxxxx 0yyyyyyy          => 00yyyyyy yxxxxxxx :    128 - 16,383
        // 3 bytes : 1xxxxxxx 1yyyyyyy 000000zz => zzyyyyyy yxxxxxxx : 16,384 - 65,535
        let mut partial: u16 = 0;
        for position in 0..MAX_ENCODING_LENGTH {
            let byte = match seq.next_element::<u8>()? {
                Some(byte) => byte,
                None => {
                    let err = VisitError::TooShort(position.saturating_add(1));
                    return Err(err.into_de_error::<A>());
                }
            };
            match visit_byte(byte, partial, position) {
                Ok(VisitStatus::Done(value)) => return Ok(ShortU16(value)),
                Ok(VisitStatus::More(value)) => partial = value,
                Err(err) => return Err(err.into_de_error::<A>()),
            }
        }

        // `visit_byte` already rejects a continuation bit on the last
        // permitted byte, so this is only a defensive fallback.
        Err(VisitError::ByteThreeContinues.into_de_error::<A>())
    }
}
168
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ShortU16 {
    /// Reads the compact encoding as a tuple of at most three bytes.
    fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = ShortU16Visitor;
        deserializer.deserialize_tuple(MAX_ENCODING_LENGTH, visitor)
    }
}
178
/// Serializes `elements` as a `ShortU16` length prefix followed by each
/// element in order.
///
/// If you don't want to use the [`ShortVec`] newtype, you can do `ShortVec`
/// serialization on an ordinary vector with the following field annotation,
/// where `short_vec` is a path to this crate:
///
/// ```ignore
/// #[serde(with = "short_vec")]
/// ```
///
/// # Errors
///
/// Returns a serializer error if `elements.len()` exceeds `u16::MAX`, or if
/// the serializer fails to write the prefix or any element.
#[cfg(feature = "serde")]
pub fn serialize<S: Serializer, T: Serialize>(
    elements: &[T],
    serializer: S,
) -> Result<S::Ok, S::Error> {
    // Validate the length before touching the serializer, so a streaming
    // format (e.g. serde_json writing "[") emits nothing for an input that
    // cannot be encoded.
    let short_len = match u16::try_from(elements.len()) {
        Ok(len) => ShortU16(len),
        Err(_) => return Err(ser::Error::custom("length larger than u16")),
    };

    // Pass a non-zero value to serialize_tuple() so that serde_json will
    // generate an open bracket.
    let mut seq = serializer.serialize_tuple(1)?;
    seq.serialize_element(&short_len)?;
    for element in elements {
        seq.serialize_element(element)?;
    }
    seq.end()
}
205
#[cfg(feature = "serde")]
struct ShortVecVisitor<T> {
    _t: PhantomData<T>,
}

#[cfg(feature = "serde")]
impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
where
    T: Deserialize<'de>,
{
    type Value = Vec<T>;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a Vec with a multi-byte length")
    }

    /// Reads a `ShortU16` element count followed by exactly that many
    /// elements.
    ///
    /// The count comes from the input, so the up-front reservation is capped
    /// at `MAX_PREALLOC_BYTES`. A lying prefix (e.g. 65535 followed by a short
    /// payload) can therefore no longer force a large allocation before any
    /// element has been read. Genuinely long vectors still grow normally as
    /// elements are pushed.
    fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
    where
        A: SeqAccess<'de>,
    {
        /// Upper bound, in bytes, on memory reserved before elements are read.
        const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

        let short_len: ShortU16 = seq
            .next_element()?
            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
        let len = usize::from(short_len.0);

        let elem_size = core::mem::size_of::<T>();
        let capacity = if elem_size == 0 {
            // Vec never allocates for zero-sized elements; reserve the full count.
            len
        } else {
            len.min(MAX_PREALLOC_BYTES / elem_size)
        };

        let mut result = Vec::with_capacity(capacity);
        for i in 0..len {
            let elem = seq
                .next_element()?
                .ok_or_else(|| de::Error::invalid_length(i, &self))?;
            result.push(elem);
        }
        Ok(result)
    }
}
241
/// Deserializes a vector written by [`serialize`]: a `ShortU16` length
/// prefix followed by that many elements.
///
/// If you don't want to use the [`ShortVec`] newtype, you can do `ShortVec`
/// deserialization on an ordinary vector with the following field
/// annotation, where `short_vec` is a path to this crate:
///
/// ```ignore
/// #[serde(with = "short_vec")]
/// ```
#[cfg(feature = "serde")]
pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
where
    D: Deserializer<'de>,
    T: Deserialize<'de>,
{
    // The element count is not known until the prefix is decoded. The tuple
    // length is therefore given as usize::MAX, and the visitor reads only as
    // many elements as the prefix announces.
    deserializer.deserialize_tuple(usize::MAX, ShortVecVisitor { _t: PhantomData })
}
256
/// A `Vec<T>` whose serialized form uses a compact `ShortU16` length prefix
/// instead of the serializer's default sequence length.
pub struct ShortVec<T>(pub Vec<T>);
258
#[cfg(feature = "serde")]
impl<T: Serialize> Serialize for ShortVec<T> {
    /// Delegates to the crate-level [`serialize`] on the wrapped slice.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let elements: &[T] = self.0.as_slice();
        serialize(elements, serializer)
    }
}
268
#[cfg(feature = "serde")]
impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
    /// Delegates to the crate-level [`deserialize`] and wraps the result.
    fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let elements = deserialize(deserializer)?;
        Ok(ShortVec(elements))
    }
}
278
279/// Return the decoded value and how many bytes it consumed.
280#[allow(clippy::result_unit_err)]
281pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
282    let mut val = 0;
283    for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
284        match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
285            VisitStatus::More(new_val) => val = new_val,
286            VisitStatus::Done(new_val) => {
287                return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
288            }
289        }
290    }
291    Err(())
292}
293
#[cfg(all(test, feature = "serde"))]
mod tests {
    use {
        super::*,
        alloc::vec,
        assert_matches::assert_matches,
        bincode::{deserialize, serialize},
    };

    /// Serialize `len` as a `ShortU16` with bincode and return the encoded bytes.
    fn encode_len(len: u16) -> Vec<u8> {
        bincode::serialize(&ShortU16(len)).unwrap()
    }

    /// Assert that `len` encodes to exactly `bytes`, and that
    /// `decode_shortu16_len` maps `bytes` back to `len`, consuming all of them.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            (usize::from(len), bytes.len()),
            "unexpected usize decoding"
        );
    }

    #[test]
    fn test_short_vec_encode_len() {
        assert_len_encoding(0x0, &[0x0]);
        assert_len_encoding(0x7f, &[0x7f]);
        assert_len_encoding(0x80, &[0x80, 0x01]);
        assert_len_encoding(0xff, &[0xff, 0x01]);
        assert_len_encoding(0x100, &[0x80, 0x02]);
        assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]);
        assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]);
    }

    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        assert_eq!(value, deserialize::<ShortU16>(bytes).unwrap().0);
    }

    fn assert_bad_deserialized_value(bytes: &[u8]) {
        assert!(deserialize::<ShortU16>(bytes).is_err());
    }

    #[test]
    fn test_deserialize() {
        assert_good_deserialized_value(0x0000, &[0x00]);
        assert_good_deserialized_value(0x007f, &[0x7f]);
        assert_good_deserialized_value(0x0080, &[0x80, 0x01]);
        assert_good_deserialized_value(0x00ff, &[0xff, 0x01]);
        assert_good_deserialized_value(0x0100, &[0x80, 0x02]);
        assert_good_deserialized_value(0x07ff, &[0xff, 0x0f]);
        assert_good_deserialized_value(0x3fff, &[0xff, 0x7f]);
        assert_good_deserialized_value(0x4000, &[0x80, 0x80, 0x01]);
        assert_good_deserialized_value(0xffff, &[0xff, 0xff, 0x03]);

        // aliases
        // 0x0000
        assert_bad_deserialized_value(&[0x80, 0x00]);
        assert_bad_deserialized_value(&[0x80, 0x80, 0x00]);
        // 0x007f
        assert_bad_deserialized_value(&[0xff, 0x00]);
        assert_bad_deserialized_value(&[0xff, 0x80, 0x00]);
        // 0x0080
        assert_bad_deserialized_value(&[0x80, 0x81, 0x00]);
        // 0x00ff
        assert_bad_deserialized_value(&[0xff, 0x81, 0x00]);
        // 0x0100
        assert_bad_deserialized_value(&[0x80, 0x82, 0x00]);
        // 0x07ff
        assert_bad_deserialized_value(&[0xff, 0x8f, 0x00]);
        // 0x3fff
        assert_bad_deserialized_value(&[0xff, 0xff, 0x00]);

        // too short
        assert_bad_deserialized_value(&[]);
        assert_bad_deserialized_value(&[0x80]);

        // too long
        assert_bad_deserialized_value(&[0x80, 0x80, 0x80, 0x00]);

        // too large
        // 0x0001_0000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x04]);
        // 0x0001_8000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x06]);
    }

    #[test]
    fn test_decode_shortu16_len_rejects_bad_input() {
        // truncated before a terminating byte
        assert_eq!(decode_shortu16_len(&[]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x80]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x80, 0x80]), Err(()));
        // aliased (non-minimal) encodings
        assert_eq!(decode_shortu16_len(&[0x80, 0x00]), Err(()));
        assert_eq!(decode_shortu16_len(&[0xff, 0x80, 0x00]), Err(()));
        // continuation bit on the third byte
        assert_eq!(decode_shortu16_len(&[0x80, 0x80, 0x80, 0x00]), Err(()));
        // value wider than 16 bits
        assert_eq!(decode_shortu16_len(&[0x80, 0x80, 0x04]), Err(()));
    }

    #[test]
    fn test_decode_shortu16_len_ignores_trailing_bytes() {
        assert_eq!(decode_shortu16_len(&[0x7f, 0xff, 0xff]), Ok((0x7f, 1)));
        assert_eq!(decode_shortu16_len(&[0x80, 0x01, 0x00]), Ok((0x80, 2)));
    }

    #[test]
    fn test_short_vec_u8() {
        let vec = ShortVec(vec![4u8; 32]);
        let bytes = serialize(&vec).unwrap();
        assert_eq!(bytes.len(), vec.0.len() + 1);

        let vec1: ShortVec<u8> = deserialize(&bytes).unwrap();
        assert_eq!(vec.0, vec1.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        let vec = ShortVec(vec![4u8; u16::MAX as usize]);
        assert_matches!(serialize(&vec), Ok(_));

        let vec = ShortVec(vec![4u8; u16::MAX as usize + 1]);
        assert_matches!(serialize(&vec), Err(_));
    }

    #[test]
    fn test_short_vec_json() {
        let vec = ShortVec(vec![0, 1, 2]);
        let s = serde_json::to_string(&vec).unwrap();
        assert_eq!(s, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        let bytes = [
            0x81, 0x80, 0x00, // 3-byte alias of 1
            0x00,
        ];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
    }

    #[test]
    fn test_short_vec_truncated_payload() {
        // The prefix announces 3 elements but only 2 follow.
        let bytes = [0x03, 0x00, 0x01];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());

        // A maximal prefix with no payload must fail cleanly.
        let bytes = [0xff, 0xff, 0x03];
        assert!(deserialize::<ShortVec<u64>>(&bytes).is_err());
    }
}