1#![cfg_attr(feature = "frozen-abi", feature(min_specialization))]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![allow(clippy::arithmetic_side_effects)]
5#[cfg(feature = "frozen-abi")]
6use solana_frozen_abi_macro::AbiExample;
7use {
8 serde_core::{
9 de::{self, Deserializer, SeqAccess, Visitor},
10 ser::{self, SerializeTuple, Serializer},
11 Deserialize, Serialize,
12 },
13 std::{convert::TryFrom, fmt, marker::PhantomData},
14};
15
/// A `u16` that is serialized in 1 to 3 bytes.
///
/// Each byte carries seven bits of the value, least-significant group first.
/// The high bit (`0x80`) of a byte is set when another byte follows. A third
/// byte may only use its low two bits, since the value must fit in 16 bits.
/// Decoding rejects non-minimal encodings, i.e. a zero byte after the first.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
pub struct ShortU16(pub u16);
23
24impl Serialize for ShortU16 {
25 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
26 where
27 S: Serializer,
28 {
29 let mut seq = serializer.serialize_tuple(1)?;
32
33 let mut rem_val = self.0;
34 loop {
35 let mut elem = (rem_val & 0x7f) as u8;
36 rem_val >>= 7;
37 if rem_val == 0 {
38 seq.serialize_element(&elem)?;
39 break;
40 } else {
41 elem |= 0x80;
42 seq.serialize_element(&elem)?;
43 }
44 }
45 seq.end()
46 }
47}
48
/// Progress after folding one encoded byte into a partially decoded value.
enum VisitStatus {
    /// The byte's continuation bit (`0x80`) was clear, so the value is complete.
    Done(u16),
    /// The byte's continuation bit was set. The value decoded so far is
    /// carried forward, and another byte is required.
    More(u16),
}
53
/// Reasons an encoded `ShortU16` is rejected.
#[derive(Debug)]
enum VisitError {
    /// A byte was presented at or beyond `MAX_ENCODING_LENGTH`. Carries the
    /// 1-based position of that byte.
    TooLong(usize),
    /// The input ended before a terminating byte. Carries the 1-based
    /// position of the byte that was expected but missing.
    TooShort(usize),
    /// The decoded value does not fit in a `u16`. Carries the out-of-range
    /// value.
    Overflow(u32),
    /// A zero byte appeared after the first byte, i.e. a non-minimal
    /// (aliased) encoding.
    Alias,
    /// The last permitted (third) byte still had its continuation bit set.
    ByteThreeContinues,
}
62
63impl VisitError {
64 fn into_de_error<'de, A>(self) -> A::Error
65 where
66 A: SeqAccess<'de>,
67 {
68 match self {
69 VisitError::TooLong(len) => de::Error::invalid_length(len, &"three or fewer bytes"),
70 VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
71 VisitError::Overflow(val) => de::Error::invalid_value(
72 de::Unexpected::Unsigned(val as u64),
73 &"a value in the range [0, 65535]",
74 ),
75 VisitError::Alias => de::Error::invalid_value(
76 de::Unexpected::Other("alias encoding"),
77 &"strict form encoding",
78 ),
79 VisitError::ByteThreeContinues => de::Error::invalid_value(
80 de::Unexpected::Other("continue signal on byte-three"),
81 &"a terminal signal on or before byte-three",
82 ),
83 }
84 }
85}
86
/// Result of folding one encoded byte into a partially decoded `ShortU16`.
type VisitResult = Result<VisitStatus, VisitError>;

/// Maximum number of bytes in a `ShortU16` encoding (7 + 7 + 2 value bits).
const MAX_ENCODING_LENGTH: usize = 3;
90fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
91 if elem == 0 && nth_byte != 0 {
92 return Err(VisitError::Alias);
93 }
94
95 let val = u32::from(val);
96 let elem = u32::from(elem);
97 let elem_val = elem & 0x7f;
98 let elem_done = (elem & 0x80) == 0;
99
100 if nth_byte >= MAX_ENCODING_LENGTH {
101 return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
102 } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
103 return Err(VisitError::ByteThreeContinues);
104 }
105
106 let shift = u32::try_from(nth_byte)
107 .unwrap_or(u32::MAX)
108 .saturating_mul(7);
109 let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX);
110
111 let new_val = val | elem_val;
112 let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
113
114 if elem_done {
115 Ok(VisitStatus::Done(val))
116 } else {
117 Ok(VisitStatus::More(val))
118 }
119}
120
121struct ShortU16Visitor;
122
123impl<'de> Visitor<'de> for ShortU16Visitor {
124 type Value = ShortU16;
125
126 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
127 formatter.write_str("a ShortU16")
128 }
129
130 fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
131 where
132 A: SeqAccess<'de>,
133 {
134 let mut val: u16 = 0;
139 for nth_byte in 0..MAX_ENCODING_LENGTH {
140 let elem: u8 = seq.next_element()?.ok_or_else(|| {
141 VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
142 })?;
143 match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
144 VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
145 VisitStatus::More(new_val) => val = new_val,
146 }
147 }
148
149 Err(VisitError::ByteThreeContinues.into_de_error::<A>())
150 }
151}
152
153impl<'de> Deserialize<'de> for ShortU16 {
154 fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
155 where
156 D: Deserializer<'de>,
157 {
158 deserializer.deserialize_tuple(3, ShortU16Visitor)
159 }
160}
161
162pub fn serialize<S: Serializer, T: Serialize>(
168 elements: &[T],
169 serializer: S,
170) -> Result<S::Ok, S::Error> {
171 let mut seq = serializer.serialize_tuple(1)?;
174
175 let len = elements.len();
176 if len > u16::MAX as usize {
177 return Err(ser::Error::custom("length larger than u16"));
178 }
179 let short_len = ShortU16(len as u16);
180 seq.serialize_element(&short_len)?;
181
182 for element in elements {
183 seq.serialize_element(element)?;
184 }
185 seq.end()
186}
187
/// Serde visitor that reads a `ShortU16` length prefix followed by that many
/// `T` elements.
struct ShortVecVisitor<T> {
    // Only ties the visitor to `T`; no `T` is stored.
    _t: PhantomData<T>,
}
191
192impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
193where
194 T: Deserialize<'de>,
195{
196 type Value = Vec<T>;
197
198 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
199 formatter.write_str("a Vec with a multi-byte length")
200 }
201
202 fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
203 where
204 A: SeqAccess<'de>,
205 {
206 let short_len: ShortU16 = seq
207 .next_element()?
208 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
209 let len = short_len.0 as usize;
210
211 let mut result = Vec::with_capacity(len);
212 for i in 0..len {
213 let elem = seq
214 .next_element()?
215 .ok_or_else(|| de::Error::invalid_length(i, &self))?;
216 result.push(elem);
217 }
218 Ok(result)
219 }
220}
221
222pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
228where
229 D: Deserializer<'de>,
230 T: Deserialize<'de>,
231{
232 let visitor = ShortVecVisitor { _t: PhantomData };
233 deserializer.deserialize_tuple(usize::MAX, visitor)
234}
235
/// Wrapper around `Vec<T>` whose serde form is a `ShortU16` length prefix
/// followed by the elements (see this module's `serialize`/`deserialize`).
pub struct ShortVec<T>(pub Vec<T>);
237
238impl<T: Serialize> Serialize for ShortVec<T> {
239 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
240 where
241 S: Serializer,
242 {
243 serialize(&self.0, serializer)
244 }
245}
246
247impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
248 fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
249 where
250 D: Deserializer<'de>,
251 {
252 deserialize(deserializer).map(ShortVec)
253 }
254}
255
256#[allow(clippy::result_unit_err)]
258pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
259 let mut val = 0;
260 for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
261 match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
262 VisitStatus::More(new_val) => val = new_val,
263 VisitStatus::Done(new_val) => {
264 return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
265 }
266 }
267 }
268 Err(())
269}
270
#[cfg(test)]
mod tests {
    use {
        super::*,
        assert_matches::assert_matches,
        bincode::{deserialize, serialize},
    };

    /// Encodes `len` as a `ShortU16` with bincode.
    fn encode_len(len: u16) -> Vec<u8> {
        bincode::serialize(&ShortU16(len)).unwrap()
    }

    /// Asserts that `len` encodes to exactly `bytes`, and that
    /// `decode_shortu16_len` maps `bytes` back to `len` while consuming all of
    /// them.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            (usize::from(len), bytes.len()),
            "unexpected usize decoding"
        );
    }

    #[test]
    fn test_short_vec_encode_len() {
        assert_len_encoding(0x0, &[0x0]);
        assert_len_encoding(0x7f, &[0x7f]);
        assert_len_encoding(0x80, &[0x80, 0x01]);
        assert_len_encoding(0xff, &[0xff, 0x01]);
        assert_len_encoding(0x100, &[0x80, 0x02]);
        assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]);
        assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]);
    }

    /// Asserts that `bytes` deserializes (via bincode) to `value`.
    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        assert_eq!(value, deserialize::<ShortU16>(bytes).unwrap().0);
    }

    /// Asserts that `bytes` is rejected when deserialized as a `ShortU16`.
    fn assert_bad_deserialized_value(bytes: &[u8]) {
        assert!(deserialize::<ShortU16>(bytes).is_err());
    }

    #[test]
    fn test_deserialize() {
        assert_good_deserialized_value(0x0000, &[0x00]);
        assert_good_deserialized_value(0x007f, &[0x7f]);
        assert_good_deserialized_value(0x0080, &[0x80, 0x01]);
        assert_good_deserialized_value(0x00ff, &[0xff, 0x01]);
        assert_good_deserialized_value(0x0100, &[0x80, 0x02]);
        assert_good_deserialized_value(0x07ff, &[0xff, 0x0f]);
        assert_good_deserialized_value(0x3fff, &[0xff, 0x7f]);
        assert_good_deserialized_value(0x4000, &[0x80, 0x80, 0x01]);
        assert_good_deserialized_value(0xffff, &[0xff, 0xff, 0x03]);

        // Aliases: a trailing zero byte adds no value, so each of these is a
        // non-minimal encoding of a value that has a shorter form.
        assert_bad_deserialized_value(&[0x80, 0x00]); // 0x0000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x00]); // 0x0000
        assert_bad_deserialized_value(&[0xff, 0x00]); // 0x007f
        assert_bad_deserialized_value(&[0xff, 0x80, 0x00]); // 0x007f
        assert_bad_deserialized_value(&[0x80, 0x81, 0x00]); // 0x0080
        assert_bad_deserialized_value(&[0xff, 0x81, 0x00]); // 0x00ff
        assert_bad_deserialized_value(&[0x80, 0x82, 0x00]); // 0x0100
        assert_bad_deserialized_value(&[0xff, 0x8f, 0x00]); // 0x07ff
        assert_bad_deserialized_value(&[0xff, 0xff, 0x00]); // 0x3fff

        // Too short: the input ends before a terminating byte.
        assert_bad_deserialized_value(&[]);
        assert_bad_deserialized_value(&[0x80]);

        // Too long: the third byte still signals continuation.
        assert_bad_deserialized_value(&[0x80, 0x80, 0x80, 0x00]);

        // Too large: the third byte carries bits beyond 16.
        assert_bad_deserialized_value(&[0x80, 0x80, 0x04]); // 0x10000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x06]); // 0x18000
    }

    #[test]
    fn test_short_vec_u8() {
        let vec = ShortVec(vec![4u8; 32]);
        let bytes = serialize(&vec).unwrap();
        // One length byte (32 < 0x80) followed by the 32 elements.
        assert_eq!(bytes.len(), vec.0.len() + 1);

        let vec1: ShortVec<u8> = deserialize(&bytes).unwrap();
        assert_eq!(vec.0, vec1.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        // u16::MAX elements is the largest length a ShortU16 can express.
        let vec = ShortVec(vec![4u8; u16::MAX as usize]);
        assert_matches!(serialize(&vec), Ok(_));

        let vec = ShortVec(vec![4u8; u16::MAX as usize + 1]);
        assert_matches!(serialize(&vec), Err(_));
    }

    #[test]
    fn test_short_vec_json() {
        // The length prefix appears as a nested one-element array.
        let vec = ShortVec(vec![0, 1, 2]);
        let s = serde_json::to_string(&vec).unwrap();
        assert_eq!(s, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        // 0x81, 0x80, 0x00 is a 3-byte alias of length 1; the final 0x00 is
        // the single element. The aliased prefix must be rejected.
        let bytes = [
            0x81, 0x80, 0x00, 0x00,
        ];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
    }
}