use std::{convert::TryInto, mem::size_of};
use serde::de::{
self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess,
Visitor,
};
use serde::Deserialize;
use crate::error::{Error, Result};
/// Index expected on the first field of a message: message field index bytes start at 1
/// (the first field of an encoded message is tagged `1`, the second `2`, and so on).
const BEBOP_STARTING_INDEX: usize = 1;
/// Serde deserializer that reads Bebop-encoded values from a borrowed byte slice.
///
/// Strings produced through `deserialize_str` borrow directly from `input` for `'de`.
#[derive(Debug)]
pub struct Deserializer<'de> {
    /// Bytes not yet consumed; every successful read advances this slice.
    input: &'de [u8],
    /// Set while decoding a message field whose index byte did not match the expected field
    /// index, i.e. the field is absent from the encoded message.
    skipped_index: bool,
    /// Set to `true` once a message (as opposed to a plain struct) has been detected.
    is_message: bool,
}
impl<'de> Deserializer<'de> {
    /// Creates a deserializer positioned at the start of `input`, with no message state.
    pub fn from_bytes(input: &'de [u8]) -> Self {
        Self {
            input,
            skipped_index: false,
            is_message: false,
        }
    }
}
pub fn from_bytes<'a, T>(s: &'a [u8]) -> Result<T>
where
T: Deserialize<'a>,
{
let mut deserializer = Deserializer::from_bytes(s);
let t = T::deserialize(&mut deserializer)?;
if deserializer.input.is_empty() {
Ok(t)
} else {
Err(Error::TrailingBytes)
}
}
impl<'de> Deserializer<'de> {
fn parse_string(&mut self) -> Result<&'de str> {
let str_len = self.parse_object_size()?;
if str_len > self.input.len() {
return Err(Error::Eof);
}
let (data, remaining) = self.input.split_at(str_len);
self.input = remaining;
Ok(std::str::from_utf8(data).map_err(|_| Error::InvalidUtf8)?)
}
fn parse_object_size(&mut self) -> Result<usize> {
let size = size_of::<u32>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
Ok(u32::from_le_bytes(raw.try_into().map_err(|_| Error::InvalidNumberBytes)?) as usize)
}
}
impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
type Error = Error;
fn is_human_readable(&self) -> bool {
false
}
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
Err(Error::Message(
"Bebop does not support deserializer_any".to_string(),
))
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let byte = self.input.first().ok_or(Error::Eof)?;
let val = if *byte == 0u8 {
false
} else if *byte == 1u8 {
true
} else {
return Err(Error::InvalidBool);
};
self.input = &self.input[1..];
visitor.visit_bool(val)
}
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<i8>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_i8(i8::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<i16>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_i16(i16::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<i32>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_i32(i32::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<i64>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_i64(i64::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let byte = self.input.first().ok_or(Error::Eof)?;
self.input = &self.input[1..];
visitor.visit_u8(*byte)
}
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<u16>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_u16(u16::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<u32>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_u32(u32::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<u64>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_u64(u64::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<f32>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_f32(f32::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<f64>();
if size > self.input.len() {
return Err(Error::Eof);
}
let (raw, remaining) = self.input.split_at(size);
self.input = remaining;
visitor.visit_f64(f64::from_le_bytes(
raw.try_into().map_err(|_| Error::InvalidNumberBytes)?,
))
}
fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let data = self.parse_string()?;
if data.chars().count() != 1 {
return Err(Error::InvalidChar);
}
visitor.visit_char(data.chars().next().unwrap())
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_borrowed_str(self.parse_string()?)
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_string(self.parse_string()?.to_owned())
}
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let len = self.parse_object_size()?;
if len > self.input.len() {
return Err(Error::Eof);
}
let (data, remaining) = self.input.split_at(len);
self.input = remaining;
visitor.visit_bytes(data)
}
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let len = self.parse_object_size()?;
if len > self.input.len() {
return Err(Error::Eof);
}
let (data, remaining) = self.input.split_at(len);
self.input = remaining;
visitor.visit_byte_buf(data.to_owned())
}
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
if self.is_message && self.skipped_index {
visitor.visit_none()
} else {
visitor.visit_some(self)
}
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
if self.is_message && self.skipped_index {
Err(Error::UnexpectedData)
} else if self.is_message && !self.skipped_index {
visitor.visit_unit()
} else {
Err(Error::InvalidUnit)
}
}
fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_unit(visitor)
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let len = self.parse_object_size()?;
visitor.visit_seq(List::new(&mut self, len))
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let len = self.parse_object_size()?;
visitor.visit_map(List::new(&mut self, len))
}
fn deserialize_struct<V>(
mut self,
_name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
let size = size_of::<u32>();
if size > self.input.len() {
return Err(Error::Eof);
}
let bytes = &self.input[..size];
let len =
u32::from_le_bytes(bytes.try_into().map_err(|_| Error::InvalidNumberBytes)?) as usize;
let is_message = if self.input[size..][len - 1] == 0u8 {
self.input = &self.input[size..];
self.is_message = true;
true
} else {
false
};
visitor.visit_seq(StructAccess::new(&mut self, fields.len(), is_message))
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_enum(self)
}
fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
Err(Error::Message(
"Bebop does not support deserialize identifier".to_string(),
))
}
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_any(visitor)
}
}
/// Reader for a length-prefixed sequence or map: hands out exactly `expected_len` items
/// from the shared deserializer.
struct List<'a, 'de: 'a> {
    /// Item count announced by the length prefix.
    expected_len: usize,
    /// Items handed out so far.
    current_len: usize,
    /// Underlying deserializer the items are read from.
    de: &'a mut Deserializer<'de>,
}
impl<'a, 'de> List<'a, 'de> {
    /// Creates a reader for `expected_len` items, none of which have been read yet.
    fn new(de: &'a mut Deserializer<'de>, expected_len: usize) -> Self {
        Self {
            current_len: 0,
            expected_len,
            de,
        }
    }
}
impl<'de, 'a> SeqAccess<'de> for List<'a, 'de> {
    type Error = Error;

    /// Yields elements until `expected_len` of them have been produced, then `None`.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: DeserializeSeed<'de>,
    {
        if self.current_len == self.expected_len {
            Ok(None)
        } else {
            self.current_len += 1;
            seed.deserialize(&mut *self.de).map(Some)
        }
    }

    /// Reports how many elements remain, so collection visitors can preallocate.
    fn size_hint(&self) -> Option<usize> {
        Some(self.expected_len.saturating_sub(self.current_len))
    }
}
impl<'de, 'a> MapAccess<'de> for List<'a, 'de> {
    type Error = Error;

    /// Yields keys until `expected_len` entries have been started, then `None`.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: DeserializeSeed<'de>,
    {
        if self.current_len == self.expected_len {
            Ok(None)
        } else {
            self.current_len += 1;
            seed.deserialize(&mut *self.de).map(Some)
        }
    }

    /// Reads the value belonging to the key most recently returned.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: DeserializeSeed<'de>,
    {
        seed.deserialize(&mut *self.de)
    }

    /// Reports how many entries remain, so map visitors can preallocate.
    fn size_hint(&self) -> Option<usize> {
        Some(self.expected_len.saturating_sub(self.current_len))
    }
}
/// Field-by-field reader for a plain struct or a message, driven through `SeqAccess`.
struct StructAccess<'a, 'de: 'a> {
    /// Whether the value was detected as a message (index-tagged fields).
    is_message: bool,
    /// Number of fields declared by the target type.
    expected_fields: usize,
    /// Index the next field should carry; starts at `BEBOP_STARTING_INDEX`.
    next_index: usize,
    /// Underlying deserializer the fields are read from.
    de: &'a mut Deserializer<'de>,
}
impl<'a, 'de> StructAccess<'a, 'de> {
    /// Creates a reader positioned before the first field.
    fn new(de: &'a mut Deserializer<'de>, expected_fields: usize, is_message: bool) -> Self {
        Self {
            next_index: BEBOP_STARTING_INDEX,
            is_message,
            expected_fields,
            de,
        }
    }
}
impl<'de, 'a> SeqAccess<'de> for StructAccess<'a, 'de> {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: DeserializeSeed<'de>,
{
if self.next_index > self.expected_fields && !self.is_message {
Ok(None)
} else if self.next_index > self.expected_fields && self.is_message {
self.de.input = &self.de.input[1..];
Ok(None)
} else {
let expected_index = self.next_index;
self.next_index = expected_index + 1;
if self.is_message {
let possible_index = *self.de.input.first().ok_or(Error::Eof)? as usize;
if expected_index == possible_index {
self.de.input = &self.de.input[1..];
self.de.skipped_index = false;
} else {
self.de.skipped_index = true;
}
}
let res = seed.deserialize(&mut *self.de).map(Some)?;
if self.is_message && self.next_index > self.expected_fields {
self.de.input = &self.de.input[1..];
}
Ok(res)
}
}
}
impl<'a, 'de> EnumAccess<'de> for &'a mut Deserializer<'de> {
    type Error = Error;
    type Variant = Self;

    /// Reads the variant discriminant as a u32 (widened to `usize`) and hands it to `seed`
    /// as an integer index. The deserializer itself is returned to read the payload.
    fn variant_seed<V: DeserializeSeed<'de>>(self, seed: V) -> Result<(V::Value, Self::Variant)> {
        let discriminant = self.parse_object_size()?;
        let variant = seed.deserialize(discriminant.into_deserializer())?;
        Ok((variant, self))
    }
}
impl<'a, 'de> VariantAccess<'de> for &'a mut Deserializer<'de> {
    type Error = Error;

    /// Unit variants carry nothing beyond the discriminant, so there is nothing to read.
    fn unit_variant(self) -> Result<()> {
        Ok(())
    }

    /// Rejected: only payload-free (unit) variants can be decoded.
    fn newtype_variant_seed<T: DeserializeSeed<'de>>(self, _seed: T) -> Result<T::Value> {
        Err(Error::VariantDataNotAllowed)
    }

    /// Rejected: only payload-free (unit) variants can be decoded.
    fn tuple_variant<V: Visitor<'de>>(self, _len: usize, _visitor: V) -> Result<V::Value> {
        Err(Error::VariantDataNotAllowed)
    }

    /// Rejected: only payload-free (unit) variants can be decoded.
    fn struct_variant<V: Visitor<'de>>(
        self,
        _fields: &'static [&'static str],
        _visitor: V,
    ) -> Result<V::Value> {
        Err(Error::VariantDataNotAllowed)
    }
}
#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct SimpleStruct {
        name: String,
        age: u16,
    }

    #[test]
    fn test_valid_struct() {
        let data: Vec<u8> = vec![7, 0, 0, 0, 67, 104, 97, 114, 108, 105, 101, 28, 0];
        let deserialized: SimpleStruct = from_bytes(&data).expect("Unable to deserialize");
        let expected = SimpleStruct {
            name: "Charlie".to_string(),
            age: 28,
        };
        assert_eq!(deserialized, expected);
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct SimpleMessage {
        name: String,
        age: Option<u16>,
    }

    #[test]
    fn test_valid_message_all_fields() {
        let data: Vec<u8> = vec![
            16, 0, 0, 0, 1, 7, 0, 0, 0, 67, 104, 97, 114, 108, 105, 101, 2, 28, 0, 0,
        ];
        let deserialized: SimpleMessage = from_bytes(&data).expect("Unable to deserialize");
        let expected = SimpleMessage {
            name: "Charlie".to_string(),
            age: Some(28),
        };
        assert_eq!(deserialized, expected);
    }

    #[test]
    fn test_valid_message_some_fields() {
        let data: Vec<u8> = vec![
            13, 0, 0, 0, 1, 7, 0, 0, 0, 67, 104, 97, 114, 108, 105, 101, 0,
        ];
        let deserialized: SimpleMessage = from_bytes(&data).expect("Unable to deserialize");
        let expected = SimpleMessage {
            name: "Charlie".to_string(),
            age: None,
        };
        assert_eq!(deserialized, expected);
    }

    #[test]
    fn test_trailing_bytes_rejected() {
        // One byte decodes the u8; the second byte is left over.
        let result = from_bytes::<u8>(&[1, 2]);
        assert!(matches!(result, Err(Error::TrailingBytes)));
    }

    #[test]
    fn test_truncated_number_is_eof() {
        let result = from_bytes::<u32>(&[1, 2, 3]);
        assert!(matches!(result, Err(Error::Eof)));
    }

    #[test]
    fn test_invalid_bool() {
        let result = from_bytes::<bool>(&[2]);
        assert!(matches!(result, Err(Error::InvalidBool)));
    }

    #[test]
    fn test_char() {
        let c: char = from_bytes(&[1, 0, 0, 0, b'x']).expect("Unable to deserialize");
        assert_eq!(c, 'x');
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[allow(dead_code)] // `Not` and `Really` are never constructed by these tests.
    enum Fun {
        Not,
        Somewhat,
        Really,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct Complex {
        name: Option<String>,
        fun_level: Fun,
        map: HashMap<String, SimpleStruct>,
        message_map: HashMap<String, SimpleMessage>,
        list: Vec<f32>,
        boolean: bool,
        int16: i16,
        int32: i32,
        int64: i64,
        uint16: u16,
        uint32: u32,
        uint64: u64,
        byte: u8,
        float64: f64,
    }

    #[test]
    fn test_complex() {
        let data: Vec<u8> = vec![
            124, 0, 0, 0, 1, 7, 0, 0, 0, 67, 104, 97, 114, 108, 105, 101, 2, 1, 0, 0, 0, 3, 1, 0,
            0, 0, 3, 0, 0, 0, 111, 110, 101, 3, 0, 0, 0, 79, 110, 101, 16, 0, 4, 1, 0, 0, 0, 3, 0,
            0, 0, 111, 110, 101, 9, 0, 0, 0, 1, 3, 0, 0, 0, 79, 110, 101, 0, 5, 2, 0, 0, 0, 218,
            15, 73, 64, 77, 248, 45, 64, 6, 1, 7, 253, 255, 8, 42, 0, 0, 0, 9, 21, 205, 91, 7, 0,
            0, 0, 0, 10, 3, 0, 11, 42, 0, 0, 0, 12, 21, 205, 91, 7, 0, 0, 0, 0, 13, 17, 14, 74,
            216, 18, 77, 251, 33, 9, 64, 0,
        ];
        let deserialized: Complex = from_bytes(&data).expect("Unable to deserialize");
        let mut map = HashMap::new();
        map.insert(
            "one".to_string(),
            SimpleStruct {
                name: "One".to_string(),
                age: 16,
            },
        );
        let mut message_map = HashMap::new();
        message_map.insert(
            "one".to_string(),
            SimpleMessage {
                name: "One".to_string(),
                age: None,
            },
        );
        // The literals below intentionally match the encoded test data, not std constants.
        #[allow(clippy::approx_constant, clippy::excessive_precision)]
        let expected = Complex {
            name: Some("Charlie".to_string()),
            fun_level: Fun::Somewhat,
            map,
            message_map,
            list: vec![3.1415926, 2.71828],
            boolean: true,
            int16: -3,
            int32: 42,
            int64: 123456789,
            uint16: 3,
            uint32: 42,
            uint64: 123456789,
            byte: 17,
            float64: 3.1415926,
        };
        assert_eq!(deserialized, expected);
    }
}