solana_short_vec/
lib.rs

1//! Compact serde-encoding of vectors with small length.
2#![cfg_attr(feature = "frozen-abi", feature(min_specialization))]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![allow(clippy::arithmetic_side_effects)]
5#[cfg(feature = "frozen-abi")]
6use solana_frozen_abi_macro::AbiExample;
7use {
8    serde_core::{
9        de::{self, Deserializer, SeqAccess, Visitor},
10        ser::{self, SerializeTuple, Serializer},
11        Deserialize, Serialize,
12    },
13    std::{convert::TryFrom, fmt, marker::PhantomData},
14};
15
/// A `u16` whose serialized form occupies one to three bytes.
///
/// Each byte carries seven bits of the value, least-significant group first,
/// and sets its top bit (0x80) when another byte follows. A third byte may
/// only use its two low-order bits; anything more would not fit in a `u16`.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
pub struct ShortU16(pub u16);
23
24impl Serialize for ShortU16 {
25    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
26    where
27        S: Serializer,
28    {
29        // Pass a non-zero value to serialize_tuple() so that serde_json will
30        // generate an open bracket.
31        let mut seq = serializer.serialize_tuple(1)?;
32
33        let mut rem_val = self.0;
34        loop {
35            let mut elem = (rem_val & 0x7f) as u8;
36            rem_val >>= 7;
37            if rem_val == 0 {
38                seq.serialize_element(&elem)?;
39                break;
40            } else {
41                elem |= 0x80;
42                seq.serialize_element(&elem)?;
43            }
44        }
45        seq.end()
46    }
47}
48
/// Progress reported by `visit_byte` after it consumes one encoded byte.
enum VisitStatus {
    /// The byte ended the encoding; carries the fully decoded value.
    Done(u16),
    /// The byte had its continuation bit set; carries the value so far.
    More(u16),
}
53
/// Reasons a `ShortU16` byte sequence is rejected during decoding.
#[derive(Debug)]
enum VisitError {
    /// A byte was offered at or beyond the maximum encoding length; carries
    /// the number of bytes seen including the offending one.
    TooLong(usize),
    /// The input ran out before a terminating byte; carries the byte count
    /// that is reported in the resulting error.
    TooShort(usize),
    /// The decoded value does not fit in a `u16`; carries the wide value.
    Overflow(u32),
    /// A zero byte after the first, which would allow more than one encoding
    /// of the same value.
    Alias,
    /// The last permitted byte still had its continuation bit set.
    ByteThreeContinues,
}
62
63impl VisitError {
64    fn into_de_error<'de, A>(self) -> A::Error
65    where
66        A: SeqAccess<'de>,
67    {
68        match self {
69            VisitError::TooLong(len) => de::Error::invalid_length(len, &"three or fewer bytes"),
70            VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
71            VisitError::Overflow(val) => de::Error::invalid_value(
72                de::Unexpected::Unsigned(val as u64),
73                &"a value in the range [0, 65535]",
74            ),
75            VisitError::Alias => de::Error::invalid_value(
76                de::Unexpected::Other("alias encoding"),
77                &"strict form encoding",
78            ),
79            VisitError::ByteThreeContinues => de::Error::invalid_value(
80                de::Unexpected::Other("continue signal on byte-three"),
81                &"a terminal signal on or before byte-three",
82            ),
83        }
84    }
85}
86
87type VisitResult = Result<VisitStatus, VisitError>;
88
89const MAX_ENCODING_LENGTH: usize = 3;
90fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
91    if elem == 0 && nth_byte != 0 {
92        return Err(VisitError::Alias);
93    }
94
95    let val = u32::from(val);
96    let elem = u32::from(elem);
97    let elem_val = elem & 0x7f;
98    let elem_done = (elem & 0x80) == 0;
99
100    if nth_byte >= MAX_ENCODING_LENGTH {
101        return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
102    } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
103        return Err(VisitError::ByteThreeContinues);
104    }
105
106    let shift = u32::try_from(nth_byte)
107        .unwrap_or(u32::MAX)
108        .saturating_mul(7);
109    let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX);
110
111    let new_val = val | elem_val;
112    let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
113
114    if elem_done {
115        Ok(VisitStatus::Done(val))
116    } else {
117        Ok(VisitStatus::More(val))
118    }
119}
120
/// Serde visitor that rebuilds a [`ShortU16`] from its encoded bytes.
struct ShortU16Visitor;
122
123impl<'de> Visitor<'de> for ShortU16Visitor {
124    type Value = ShortU16;
125
126    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
127        formatter.write_str("a ShortU16")
128    }
129
130    fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
131    where
132        A: SeqAccess<'de>,
133    {
134        // Decodes an unsigned 16 bit integer one-to-one encoded as follows:
135        // 1 byte  : 0xxxxxxx                   => 00000000 0xxxxxxx :      0 -    127
136        // 2 bytes : 1xxxxxxx 0yyyyyyy          => 00yyyyyy yxxxxxxx :    128 - 16,383
137        // 3 bytes : 1xxxxxxx 1yyyyyyy 000000zz => zzyyyyyy yxxxxxxx : 16,384 - 65,535
138        let mut val: u16 = 0;
139        for nth_byte in 0..MAX_ENCODING_LENGTH {
140            let elem: u8 = seq.next_element()?.ok_or_else(|| {
141                VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
142            })?;
143            match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
144                VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
145                VisitStatus::More(new_val) => val = new_val,
146            }
147        }
148
149        Err(VisitError::ByteThreeContinues.into_de_error::<A>())
150    }
151}
152
153impl<'de> Deserialize<'de> for ShortU16 {
154    fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
155    where
156        D: Deserializer<'de>,
157    {
158        deserializer.deserialize_tuple(3, ShortU16Visitor)
159    }
160}
161
162/// If you don't want to use the ShortVec newtype, you can do ShortVec
163/// serialization on an ordinary vector with the following field annotation:
164///
165/// #[serde(with = "short_vec")]
166///
167pub fn serialize<S: Serializer, T: Serialize>(
168    elements: &[T],
169    serializer: S,
170) -> Result<S::Ok, S::Error> {
171    // Pass a non-zero value to serialize_tuple() so that serde_json will
172    // generate an open bracket.
173    let mut seq = serializer.serialize_tuple(1)?;
174
175    let len = elements.len();
176    if len > u16::MAX as usize {
177        return Err(ser::Error::custom("length larger than u16"));
178    }
179    let short_len = ShortU16(len as u16);
180    seq.serialize_element(&short_len)?;
181
182    for element in elements {
183        seq.serialize_element(element)?;
184    }
185    seq.end()
186}
187
/// Serde visitor that reads a [`ShortU16`] length prefix followed by that
/// many `T` elements.
struct ShortVecVisitor<T> {
    _t: PhantomData<T>,
}
191
192impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
193where
194    T: Deserialize<'de>,
195{
196    type Value = Vec<T>;
197
198    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
199        formatter.write_str("a Vec with a multi-byte length")
200    }
201
202    fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
203    where
204        A: SeqAccess<'de>,
205    {
206        let short_len: ShortU16 = seq
207            .next_element()?
208            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
209        let len = short_len.0 as usize;
210
211        let mut result = Vec::with_capacity(len);
212        for i in 0..len {
213            let elem = seq
214                .next_element()?
215                .ok_or_else(|| de::Error::invalid_length(i, &self))?;
216            result.push(elem);
217        }
218        Ok(result)
219    }
220}
221
222/// If you don't want to use the ShortVec newtype, you can do ShortVec
223/// deserialization on an ordinary vector with the following field annotation:
224///
225/// #[serde(with = "short_vec")]
226///
227pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
228where
229    D: Deserializer<'de>,
230    T: Deserialize<'de>,
231{
232    let visitor = ShortVecVisitor { _t: PhantomData };
233    deserializer.deserialize_tuple(usize::MAX, visitor)
234}
235
/// Newtype around `Vec<T>` whose serde form is a [`ShortU16`] length prefix
/// followed by the elements.
pub struct ShortVec<T>(pub Vec<T>);
237
238impl<T: Serialize> Serialize for ShortVec<T> {
239    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
240    where
241        S: Serializer,
242    {
243        serialize(&self.0, serializer)
244    }
245}
246
247impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
248    fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
249    where
250        D: Deserializer<'de>,
251    {
252        deserialize(deserializer).map(ShortVec)
253    }
254}
255
256/// Return the decoded value and how many bytes it consumed.
257#[allow(clippy::result_unit_err)]
258pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
259    let mut val = 0;
260    for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
261        match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
262            VisitStatus::More(new_val) => val = new_val,
263            VisitStatus::Done(new_val) => {
264                return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
265            }
266        }
267    }
268    Err(())
269}
270
#[cfg(test)]
mod tests {
    use {
        super::*,
        assert_matches::assert_matches,
        bincode::{deserialize, serialize},
    };

    /// Return the bincode encoding of `len` as a `ShortU16`.
    fn encode_len(len: u16) -> Vec<u8> {
        bincode::serialize(&ShortU16(len)).unwrap()
    }

    /// Assert that `len` encodes to exactly `bytes`, and that
    /// `decode_shortu16_len` maps `bytes` back to `len` while consuming all of it.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            (usize::from(len), bytes.len()),
            "unexpected usize decoding"
        );
    }

    #[test]
    fn test_short_vec_encode_len() {
        assert_len_encoding(0x0, &[0x0]);
        assert_len_encoding(0x7f, &[0x7f]);
        assert_len_encoding(0x80, &[0x80, 0x01]);
        assert_len_encoding(0xff, &[0xff, 0x01]);
        assert_len_encoding(0x100, &[0x80, 0x02]);
        assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]);
        assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]);
    }

    #[test]
    fn test_decode_shortu16_len_rejects_malformed() {
        // empty or truncated input
        assert_eq!(decode_shortu16_len(&[]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x80]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x80, 0x80]), Err(()));
        // aliases
        assert_eq!(decode_shortu16_len(&[0x80, 0x00]), Err(()));
        assert_eq!(decode_shortu16_len(&[0x81, 0x80, 0x00]), Err(()));
        // continuation bit on the third byte
        assert_eq!(decode_shortu16_len(&[0x80, 0x80, 0x80, 0x00]), Err(()));
        // too large
        assert_eq!(decode_shortu16_len(&[0x80, 0x80, 0x04]), Err(()));
        // bytes after the terminating byte are not consumed
        assert_eq!(decode_shortu16_len(&[0x05, 0xff]), Ok((5, 1)));
    }

    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        assert_eq!(value, deserialize::<ShortU16>(bytes).unwrap().0);
    }

    fn assert_bad_deserialized_value(bytes: &[u8]) {
        assert!(deserialize::<ShortU16>(bytes).is_err());
    }

    #[test]
    fn test_deserialize() {
        assert_good_deserialized_value(0x0000, &[0x00]);
        assert_good_deserialized_value(0x007f, &[0x7f]);
        assert_good_deserialized_value(0x0080, &[0x80, 0x01]);
        assert_good_deserialized_value(0x00ff, &[0xff, 0x01]);
        assert_good_deserialized_value(0x0100, &[0x80, 0x02]);
        assert_good_deserialized_value(0x07ff, &[0xff, 0x0f]);
        assert_good_deserialized_value(0x3fff, &[0xff, 0x7f]);
        assert_good_deserialized_value(0x4000, &[0x80, 0x80, 0x01]);
        assert_good_deserialized_value(0xffff, &[0xff, 0xff, 0x03]);

        // aliases
        // 0x0000
        assert_bad_deserialized_value(&[0x80, 0x00]);
        assert_bad_deserialized_value(&[0x80, 0x80, 0x00]);
        // 0x007f
        assert_bad_deserialized_value(&[0xff, 0x00]);
        assert_bad_deserialized_value(&[0xff, 0x80, 0x00]);
        // 0x0080
        assert_bad_deserialized_value(&[0x80, 0x81, 0x00]);
        // 0x00ff
        assert_bad_deserialized_value(&[0xff, 0x81, 0x00]);
        // 0x0100
        assert_bad_deserialized_value(&[0x80, 0x82, 0x00]);
        // 0x07ff
        assert_bad_deserialized_value(&[0xff, 0x8f, 0x00]);
        // 0x3fff
        assert_bad_deserialized_value(&[0xff, 0xff, 0x00]);

        // too short
        assert_bad_deserialized_value(&[]);
        assert_bad_deserialized_value(&[0x80]);

        // too long
        assert_bad_deserialized_value(&[0x80, 0x80, 0x80, 0x00]);

        // too large
        // 0x0001_0000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x04]);
        // 0x0001_8000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x06]);
    }

    #[test]
    fn test_short_vec_u8() {
        let vec = ShortVec(vec![4u8; 32]);
        let bytes = serialize(&vec).unwrap();
        assert_eq!(bytes.len(), vec.0.len() + 1);

        let vec1: ShortVec<u8> = deserialize(&bytes).unwrap();
        assert_eq!(vec.0, vec1.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        let vec = ShortVec(vec![4u8; u16::MAX as usize]);
        assert_matches!(serialize(&vec), Ok(_));

        let vec = ShortVec(vec![4u8; u16::MAX as usize + 1]);
        assert_matches!(serialize(&vec), Err(_));
    }

    #[test]
    fn test_short_vec_json() {
        let vec = ShortVec(vec![0, 1, 2]);
        let s = serde_json::to_string(&vec).unwrap();
        assert_eq!(s, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        let bytes = [
            0x81, 0x80, 0x00, // 3-byte alias of 1
            0x00,
        ];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
    }
}