Skip to main content

wincode/
len.rs

//! Support for heterogeneous sequence length encoding.
2use crate::{
3    error::{pointer_sized_decode_error, preallocation_size_limit, ReadResult, WriteResult},
4    io::{Reader, Writer},
5    schema::{SchemaRead, SchemaWrite},
6};
7
/// Behavior to support heterogeneous sequence length encoding.
///
/// Sequences may encode their length prefix using different schemes.
/// This trait abstracts over that possibility, allowing users to specify
/// the length encoding scheme for a sequence.
pub trait SeqLen {
    /// Read the length of a sequence from the reader, where
    /// `T` is the type of the sequence elements. This can be used to
    /// enforce size constraints for preallocations.
    ///
    /// # Errors
    ///
    /// May return an error if some length condition is not met
    /// (e.g., size constraints, overflow, etc.).
    fn read<'de, T>(reader: &mut impl Reader<'de>) -> ReadResult<usize>;
    /// Write the length of a sequence to the writer.
    ///
    /// # Errors
    ///
    /// May return an error if `len` cannot be represented by the
    /// encoding scheme, or if the writer fails.
    fn write(writer: &mut impl Writer, len: usize) -> WriteResult<()>;
    /// Calculate the number of bytes needed to write the given length.
    ///
    /// Useful for variable length encoding schemes.
    ///
    /// # Errors
    ///
    /// May return an error if `len` cannot be represented by the
    /// encoding scheme.
    fn write_bytes_needed(len: usize) -> WriteResult<usize>;
}
28
/// Default preallocation limit, in bytes, used by [`BincodeLen`] when no
/// explicit `MAX_SIZE` is supplied.
const DEFAULT_BINCODE_LEN_MAX_SIZE: usize = 4 << 20; // 4 MiB
/// [`SeqLen`] implementation for bincode's default fixint encoding.
///
/// The `MAX_SIZE` constant is a limit on the maximum preallocation size
/// (in bytes) for heap allocated structures. This is a safety precaution
/// against malicious input causing OOM. The default is 4 MiB. Users are
/// free to override this limit by passing a different constant or by
/// providing their own [`SeqLen`] implementation.
pub struct BincodeLen<const MAX_SIZE: usize = DEFAULT_BINCODE_LEN_MAX_SIZE>;
38
39impl<const MAX_SIZE: usize> SeqLen for BincodeLen<MAX_SIZE> {
40    #[inline(always)]
41    fn read<'de, T>(reader: &mut impl Reader<'de>) -> ReadResult<usize> {
42        // Bincode's default fixint encoding writes lengths as `u64`.
43        let len = u64::get(reader)
44            .and_then(|len| usize::try_from(len).map_err(|_| pointer_sized_decode_error()))?;
45        let needed = len
46            .checked_mul(size_of::<T>())
47            .ok_or_else(|| preallocation_size_limit(usize::MAX, MAX_SIZE))?;
48        if needed > MAX_SIZE {
49            return Err(preallocation_size_limit(needed, MAX_SIZE));
50        }
51        Ok(len)
52    }
53
54    #[inline(always)]
55    fn write(writer: &mut impl Writer, len: usize) -> WriteResult<()> {
56        u64::write(writer, &(len as u64))
57    }
58
59    #[inline(always)]
60    fn write_bytes_needed(_len: usize) -> WriteResult<usize> {
61        Ok(size_of::<u64>())
62    }
63}
64
65#[cfg(feature = "solana-short-vec")]
66pub mod short_vec {
67    use {
68        super::*,
69        crate::error::{read_length_encoding_overflow, write_length_encoding_overflow},
70        core::{
71            mem::{transmute, MaybeUninit},
72            ptr,
73        },
74        solana_short_vec::{decode_shortu16_len, ShortU16},
75    };
76
77    impl<'de> SchemaRead<'de> for ShortU16 {
78        type Dst = Self;
79
80        fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
81            let Ok((len, read)) = decode_shortu16_len(reader.fill_buf(3)?) else {
82                return Err(read_length_encoding_overflow("u16::MAX"));
83            };
84
85            // SAFETY: `read` is the number of bytes visited by `decode_shortu16_len` to decode the length,
86            // which implies the reader had at least `read` bytes available.
87            unsafe { reader.consume_unchecked(read) };
88
89            // SAFETY: `dst` is a valid pointer to a `MaybeUninit<ShortU16>`.
90            let slot = unsafe { &mut *(&raw mut (*dst.as_mut_ptr()).0).cast::<MaybeUninit<u16>>() };
91            // SAFETY: `len` is always a valid u16. `decode_shortu16_len` casts it to a usize before returning,
92            // so no risk of overflow.
93            slot.write(len as u16);
94            Ok(())
95        }
96    }
97
98    impl SchemaWrite for ShortU16 {
99        type Src = Self;
100
101        fn size_of(src: &Self::Src) -> WriteResult<usize> {
102            Ok(short_u16_bytes_needed(src.0))
103        }
104
105        fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
106            let val = src.0;
107            let needed = short_u16_bytes_needed(val);
108            let mut buf = [MaybeUninit::<u8>::uninit(); 3];
109            // SAFETY: short_u16 uses a maximum of 3 bytes, so the buffer is always large enough.
110            unsafe { encode_short_u16(buf.as_mut_ptr().cast::<u8>(), needed, val) };
111            // SAFETY: encode_short_u16 writes exactly `needed` bytes.
112            let buf =
113                unsafe { transmute::<&[MaybeUninit<u8>], &[u8]>(buf.get_unchecked(..needed)) };
114            writer.write(buf)?;
115            Ok(())
116        }
117    }
118
    /// [`SeqLen`] selector for the short-u16 length encoding; an alias of [`ShortU16`].
    pub type ShortU16Len = ShortU16;
120
121    /// Branchless computation of the number of bytes needed to encode a short u16.
122    ///
123    /// See [`solana_short_vec::ShortU16`] for more details.
124    #[inline(always)]
125    #[allow(clippy::arithmetic_side_effects)]
126    fn short_u16_bytes_needed(len: u16) -> usize {
127        1 + (len >= 0x80) as usize + (len >= 0x4000) as usize
128    }
129
130    #[inline(always)]
131    fn try_short_u16_bytes_needed<T: TryInto<u16>>(len: T) -> WriteResult<usize> {
132        match len.try_into() {
133            Ok(len) => Ok(short_u16_bytes_needed(len)),
134            Err(_) => Err(write_length_encoding_overflow("u16::MAX")),
135        }
136    }
137
138    /// Encode a short u16 into the given buffer.
139    ///
140    /// See [`solana_short_vec::ShortU16`] for more details.
141    ///
142    /// # Safety
143    ///
144    /// - `dst` must be a valid for writes.
145    /// - `dst` must be valid for `needed` bytes.
146    #[inline(always)]
147    unsafe fn encode_short_u16(dst: *mut u8, needed: usize, len: u16) {
148        // From `solana_short_vec`:
149        //
150        // u16 serialized with 1 to 3 bytes. If the value is above
151        // 0x7f, the top bit is set and the remaining value is stored in the next
152        // bytes. Each byte follows the same pattern until the 3rd byte. The 3rd
153        // byte may only have the 2 least-significant bits set, otherwise the encoded
154        // value will overflow the u16.
155        match needed {
156            1 => ptr::write(dst, len as u8),
157            2 => {
158                ptr::write(dst, ((len & 0x7f) as u8) | 0x80);
159                ptr::write(dst.add(1), (len >> 7) as u8);
160            }
161            3 => {
162                ptr::write(dst, ((len & 0x7f) as u8) | 0x80);
163                ptr::write(dst.add(1), (((len >> 7) & 0x7f) as u8) | 0x80);
164                ptr::write(dst.add(2), (len >> 14) as u8);
165            }
166            _ => unreachable!(),
167        }
168    }
169
170    impl SeqLen for ShortU16Len {
171        #[inline(always)]
172        fn read<'de, T>(reader: &mut impl Reader<'de>) -> ReadResult<usize> {
173            let Ok((len, read)) = decode_shortu16_len(reader.fill_buf(3)?) else {
174                return Err(read_length_encoding_overflow("u16::MAX"));
175            };
176            unsafe { reader.consume_unchecked(read) };
177            Ok(len)
178        }
179
180        #[inline(always)]
181        fn write(writer: &mut impl Writer, len: usize) -> WriteResult<()> {
182            if len > u16::MAX as usize {
183                return Err(write_length_encoding_overflow("u16::MAX"));
184            }
185
186            <ShortU16 as SchemaWrite>::write(writer, &ShortU16(len as u16))
187        }
188
189        #[inline(always)]
190        fn write_bytes_needed(len: usize) -> WriteResult<usize> {
191            try_short_u16_bytes_needed(len)
192        }
193    }
194
195    #[cfg(all(test, feature = "alloc", feature = "derive"))]
196    mod tests {
197        use {
198            super::*,
199            crate::{
200                containers::{self, Pod},
201                proptest_config::proptest_cfg,
202            },
203            alloc::vec::Vec,
204            proptest::prelude::*,
205            solana_short_vec::ShortU16,
206            wincode_derive::{SchemaRead, SchemaWrite},
207        };
208
209        fn our_short_u16_encode(len: u16) -> Vec<u8> {
210            let needed = short_u16_bytes_needed(len);
211            let mut buf = Vec::with_capacity(needed);
212            unsafe {
213                encode_short_u16(buf.as_mut_ptr(), needed, len);
214                buf.set_len(needed);
215            }
216            buf
217        }
218
        /// Fixture whose two vector fields are length-prefixed with a short u16:
        /// serde goes through `solana_short_vec`, wincode through
        /// `containers::Vec` with `ShortU16Len` as the length encoding.
        #[derive(
            serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, SchemaWrite, SchemaRead,
        )]
        #[wincode(internal)]
        struct ShortVecStruct {
            // Vector of single bytes.
            #[serde(with = "solana_short_vec")]
            #[wincode(with = "containers::Vec<Pod<u8>, ShortU16Len>")]
            bytes: Vec<u8>,
            // Vector of 32-byte arrays.
            #[serde(with = "solana_short_vec")]
            #[wincode(with = "containers::Vec<Pod<[u8; 32]>, ShortU16Len>")]
            ar: Vec<[u8; 32]>,
        }
231
        /// Fixture that uses `ShortU16` directly as a field, exercising its own
        /// `SchemaRead`/`SchemaWrite` implementations.
        #[derive(SchemaWrite, SchemaRead, serde::Serialize, serde::Deserialize)]
        #[wincode(internal)]
        struct ShortVecAsSchema {
            short_u16: ShortU16,
        }
237
238        fn strat_short_vec_struct() -> impl Strategy<Value = ShortVecStruct> {
239            (
240                proptest::collection::vec(any::<u8>(), 0..=100),
241                proptest::collection::vec(any::<[u8; 32]>(), 0..=16),
242            )
243                .prop_map(|(bytes, ar)| ShortVecStruct { bytes, ar })
244        }
245
        proptest! {
            #![proptest_config(proptest_cfg())]

            // This module's encoder must produce the same bytes as serializing
            // `ShortU16` with bincode, for every `u16` value.
            #[test]
            fn encode_u16_equivalence(len in 0..=u16::MAX) {
                let our = our_short_u16_encode(len);
                let bincode = bincode::serialize(&ShortU16(len)).unwrap();
                prop_assert_eq!(our, bincode);
            }

            // Short-vec prefixed vectors: the serialized bytes must match bincode's,
            // and both deserializers must round-trip the original value.
            #[test]
            fn test_short_vec_struct(short_vec_struct in strat_short_vec_struct()) {
                let bincode_serialized = bincode::serialize(&short_vec_struct).unwrap();
                let schema_serialized = crate::serialize(&short_vec_struct).unwrap();
                prop_assert_eq!(&bincode_serialized, &schema_serialized);
                let bincode_deserialized: ShortVecStruct = bincode::deserialize(&bincode_serialized).unwrap();
                let schema_deserialized: ShortVecStruct = crate::deserialize(&schema_serialized).unwrap();
                prop_assert_eq!(&short_vec_struct, &bincode_deserialized);
                prop_assert_eq!(short_vec_struct, schema_deserialized);
            }

            // `ShortU16` used directly as a field: same byte-level and round-trip
            // checks, comparing the inner `u16` values.
            #[test]
            fn test_short_vec_as_schema(sv in any::<u16>()) {
                let val = ShortVecAsSchema { short_u16: ShortU16(sv) };
                let bincode_serialized = bincode::serialize(&val).unwrap();
                let wincode_serialized = crate::serialize(&val).unwrap();
                prop_assert_eq!(&bincode_serialized, &wincode_serialized);
                let bincode_deserialized: ShortVecAsSchema = bincode::deserialize(&bincode_serialized).unwrap();
                let wincode_deserialized: ShortVecAsSchema = crate::deserialize(&wincode_serialized).unwrap();
                prop_assert_eq!(val.short_u16.0, bincode_deserialized.short_u16.0);
                prop_assert_eq!(val.short_u16.0, wincode_deserialized.short_u16.0);
            }
        }
279    }
280}
281
282#[cfg(feature = "solana-short-vec")]
283pub use short_vec::*;