use crate::visitor::{
Array, BitSequence, Composite, DecodeAsTypeResult, DecodeError, Sequence, Str, Tuple,
TypeIdFor, Variant, Visitor,
};
use crate::Field;
use alloc::format;
use alloc::string::ToString;
use codec::{self, Decode};
use scale_type_resolver::{
BitsOrderFormat, BitsStoreFormat, FieldIter, PathIter, Primitive, ResolvedTypeVisitor,
TypeResolver, UnhandledKind, VariantIter,
};
/// Decode the bytes at the front of `data` as the type identified by `ty_id`,
/// handing whatever is found to `visitor`.
///
/// This is the non-compact entry point: it forwards to
/// [`decode_with_visitor_maybe_compact`] with the compact flag switched off.
///
/// # Errors
///
/// Returns whatever error [`decode_with_visitor_maybe_compact`] returns.
pub fn decode_with_visitor<'scale, 'resolver, V>(
    data: &mut &'scale [u8],
    ty_id: TypeIdFor<V>,
    types: &'resolver V::TypeResolver,
    visitor: V,
) -> Result<V::Value<'scale, 'resolver>, V::Error>
where
    V: Visitor,
{
    // A top-level decode never starts out inside a compact wrapper.
    let is_compact = false;
    decode_with_visitor_maybe_compact(data, ty_id, types, visitor, is_compact)
}
/// Decode the bytes at the front of `data` as the type identified by `ty_id`,
/// optionally treating the value as compact encoded.
///
/// The visitor is first offered the chance to handle the type itself via
/// `unchecked_decode_as_type`; if it returns `Decoded`, that result is returned
/// as-is. Otherwise the type is resolved through `types` and dispatched to the
/// visitor by a [`Decoder`].
///
/// # Errors
///
/// - Any error produced by the visitor or by the decoding steps is returned.
/// - A failure of the type resolver itself is stringified and returned as
///   `DecodeError::TypeResolvingError`, converted into `V::Error`.
pub fn decode_with_visitor_maybe_compact<'scale, 'resolver, V>(
    data: &mut &'scale [u8],
    ty_id: TypeIdFor<V>,
    types: &'resolver V::TypeResolver,
    visitor: V,
    is_compact: bool,
) -> Result<V::Value<'scale, 'resolver>, V::Error>
where
    V: Visitor,
{
    // Let the visitor short-circuit the whole process if it wants to.
    let visitor = match visitor.unchecked_decode_as_type(data, ty_id.clone(), types) {
        DecodeAsTypeResult::Decoded(outcome) => return outcome,
        DecodeAsTypeResult::Skipped(returned_visitor) => returned_visitor,
    };

    let decoder = Decoder::new(data, types, ty_id.clone(), visitor, is_compact);

    match types.resolve_type(ty_id, decoder) {
        // The resolver ran our decoder; its result (value or visitor error) is
        // already in the shape we need to return.
        Ok(visit_outcome) => visit_outcome,
        // The resolver itself failed; surface that as a decode error.
        Err(resolve_type_error) => {
            Err(DecodeError::TypeResolvingError(resolve_type_error.to_string()).into())
        }
    }
}
/// Bundles everything needed to decode a single value once its type has been
/// resolved; it is handed to `TypeResolver::resolve_type` as the
/// `ResolvedTypeVisitor`.
struct Decoder<'a, 'scale, 'resolver, V>
where
    V: Visitor,
{
    /// Cursor over the input bytes; it is moved forward as bytes are consumed.
    data: &'a mut &'scale [u8],
    /// ID of the type being decoded; passed on to the visitor callbacks and
    /// used in error messages.
    type_id: TypeIdFor<V>,
    /// Resolver used to look up this type and any nested types.
    types: &'resolver V::TypeResolver,
    /// The user visitor that receives the decoded value.
    visitor: V,
    /// Whether the value being decoded sits inside a compact wrapper.
    is_compact: bool,
}
impl<'a, 'scale, 'resolver, V> Decoder<'a, 'scale, 'resolver, V>
where
    V: Visitor,
{
    /// Build a decoder from its parts. Note that the argument order (`types`
    /// before `type_id`) differs from the field order of the struct.
    fn new(
        data: &'a mut &'scale [u8],
        types: &'resolver V::TypeResolver,
        type_id: TypeIdFor<V>,
        visitor: V,
        is_compact: bool,
    ) -> Self {
        Self {
            data,
            types,
            type_id,
            visitor,
            is_compact,
        }
    }
}
// Finish off a visit to a multi-part value.
//
// `$visitor_ty` is asked to `skip_decoding()`; when that succeeds, the cursor
// in `$self.data` is set to `$visitor_ty.bytes_from_undecoded()`. The macro
// then evaluates to the visit result, with this priority:
//   1. an error from the visitor,
//   2. otherwise an error from skipping (converted with `.into()`),
//   3. otherwise the visitor's value.
macro_rules! skip_decoding_and_return {
    ($self:ident, $visit_result:ident, $visitor_ty:ident) => {{
        let skip_outcome = $visitor_ty.skip_decoding();
        if skip_outcome.is_ok() {
            // Only advance the cursor when skipping worked.
            *$self.data = $visitor_ty.bytes_from_undecoded();
        }

        match $visit_result {
            // Visitor errors always win, whatever skipping did.
            Err(visit_err) => Err(visit_err),
            Ok(value) => match skip_outcome {
                Err(skip_err) => Err(skip_err.into()),
                Ok(_) => Ok(value),
            },
        }
    }};
}
/// Dispatches on the shape of the resolved type, decodes (or wraps) the
/// corresponding bytes and hands them to the matching visitor callback.
///
/// Compact handling, as implemented below:
/// - unsigned integer primitives `u8`..`u128` are decoded via `codec::Compact`
///   when `is_compact` is set;
/// - composites and tuples are accepted in compact mode only when they have
///   exactly one field, in which case the flag is passed on to their wrapper;
/// - every other shape returns `DecodeError::CannotDecodeCompactIntoType` in
///   compact mode.
impl<'temp, 'scale, 'resolver, V: Visitor> ResolvedTypeVisitor<'resolver>
    for Decoder<'temp, 'scale, 'resolver, V>
{
    type TypeId = TypeIdFor<V>;
    type Value = Result<V::Value<'scale, 'resolver>, V::Error>;

    /// The resolver reported a kind of type that none of the other callbacks
    /// covers; report it as an error naming the kind and the type ID.
    fn visit_unhandled(self, kind: UnhandledKind) -> Self::Value {
        let type_id = self.type_id;
        Err(DecodeError::TypeIdNotFound(format!(
            "Kind {kind:?} (type ID {type_id:?}) has not been properly handled"
        ))
        .into())
    }

    /// The resolver could not find the type ID at all.
    fn visit_not_found(self) -> Self::Value {
        let type_id = self.type_id;
        Err(DecodeError::TypeIdNotFound(format!("{type_id:?}")).into())
    }

    /// Hand a composite (struct-like) type to the visitor.
    fn visit_composite<Path, Fields>(self, path: Path, mut fields: Fields) -> Self::Value
    where
        Path: PathIter<'resolver>,
        Fields: FieldIter<'resolver, Self::TypeId>,
    {
        // In compact mode only single-field composites are accepted.
        if self.is_compact && fields.len() != 1 {
            return Err(DecodeError::CannotDecodeCompactIntoType.into());
        }

        let mut items = Composite::new(path, self.data, &mut fields, self.types, self.is_compact);
        let res = self.visitor.visit_composite(&mut items, self.type_id);

        // Skip anything left undecoded and move the cursor forward.
        skip_decoding_and_return!(self, res, items)
    }

    /// Hand a variant (enum-like) type to the visitor.
    fn visit_variant<Path, Fields, Var>(self, _path: Path, variants: Var) -> Self::Value
    where
        Path: PathIter<'resolver>,
        Fields: FieldIter<'resolver, Self::TypeId>,
        Var: VariantIter<'resolver, Fields>,
    {
        if self.is_compact {
            return Err(DecodeError::CannotDecodeCompactIntoType.into());
        }

        let mut variant = Variant::new(self.data, variants, self.types)?;
        let res = self.visitor.visit_variant(&mut variant, self.type_id);

        // Skip anything left undecoded and move the cursor forward.
        skip_decoding_and_return!(self, res, variant)
    }

    /// Hand a sequence type (elements of `inner_type_id`) to the visitor.
    fn visit_sequence<Path>(self, _path: Path, inner_type_id: Self::TypeId) -> Self::Value
    where
        Path: PathIter<'resolver>,
    {
        if self.is_compact {
            return Err(DecodeError::CannotDecodeCompactIntoType.into());
        }

        let mut items = Sequence::new(self.data, inner_type_id, self.types)?;
        let res = self.visitor.visit_sequence(&mut items, self.type_id);

        // Skip anything left undecoded and move the cursor forward.
        skip_decoding_and_return!(self, res, items)
    }

    /// Hand an array type (`len` elements of `inner_type_id`) to the visitor.
    fn visit_array(self, inner_type_id: Self::TypeId, len: usize) -> Self::Value {
        if self.is_compact {
            return Err(DecodeError::CannotDecodeCompactIntoType.into());
        }

        let mut arr = Array::new(self.data, inner_type_id, len, self.types);
        let res = self.visitor.visit_array(&mut arr, self.type_id);

        // Skip anything left undecoded and move the cursor forward.
        skip_decoding_and_return!(self, res, arr)
    }

    /// Hand a tuple type to the visitor; each type ID becomes an unnamed field.
    fn visit_tuple<TypeIds>(self, type_ids: TypeIds) -> Self::Value
    where
        TypeIds: ExactSizeIterator<Item = Self::TypeId>,
    {
        // In compact mode only single-field tuples are accepted.
        if self.is_compact && type_ids.len() != 1 {
            return Err(DecodeError::CannotDecodeCompactIntoType.into());
        }

        let mut fields = type_ids.map(Field::unnamed);
        let mut items = Tuple::new(self.data, &mut fields, self.types, self.is_compact);
        let res = self.visitor.visit_tuple(&mut items, self.type_id);

        // Skip anything left undecoded and move the cursor forward.
        skip_decoding_and_return!(self, res, items)
    }

    /// Decode a primitive value and pass it to the matching visitor callback.
    ///
    /// Only `u8`..`u128` may be compact encoded; every other primitive returns
    /// `DecodeError::CannotDecodeCompactIntoType` when `is_compact` is set.
    fn visit_primitive(self, primitive: Primitive) -> Self::Value {
        // Bail out of `visit_primitive` if the primitive was compact encoded
        // but cannot be.
        macro_rules! err_if_compact {
            ($is_compact:expr) => {
                if $is_compact {
                    return Err(DecodeError::CannotDecodeCompactIntoType.into());
                }
            };
        }

        /// Take the first 32 bytes off the front of `data` as an array
        /// reference that keeps the `'scale` lifetime, and advance `data`
        /// past them.
        ///
        /// Returns `DecodeError::NotEnoughInput` when fewer than 32 bytes
        /// remain, in which case `data` is left untouched.
        fn decode_32_bytes<'scale>(
            data: &mut &'scale [u8],
        ) -> Result<&'scale [u8; 32], DecodeError> {
            let bytes: &'scale [u8] = *data;

            // Converting a slice into `&[u8; 32]` only succeeds when the slice
            // is *exactly* 32 bytes long, so cut out the leading 32 bytes
            // first. Converting the whole remaining input would reject any
            // value that is followed by further data.
            let head = bytes.get(..32).ok_or(DecodeError::NotEnoughInput)?;
            let arr: &'scale [u8; 32] =
                head.try_into().map_err(|_| DecodeError::NotEnoughInput)?;

            // Shift the cursor past the bytes we just took.
            *data = &bytes[32..];
            Ok(arr)
        }

        let data = self.data;
        let is_compact = self.is_compact;
        let visitor = self.visitor;
        let type_id = self.type_id;

        match primitive {
            Primitive::Bool => {
                err_if_compact!(is_compact);
                let b = bool::decode(data).map_err(|e| e.into())?;
                visitor.visit_bool(b, type_id)
            }
            Primitive::Char => {
                err_if_compact!(is_compact);
                // Chars are read as a u32 and must then be a valid `char`.
                let val = u32::decode(data).map_err(|e| e.into())?;
                let c = char::from_u32(val).ok_or(DecodeError::InvalidChar(val))?;
                visitor.visit_char(c, type_id)
            }
            Primitive::Str => {
                err_if_compact!(is_compact);
                let mut s = Str::new(data)?;
                // Move the cursor past the string before the visitor sees it.
                *data = s.bytes_after()?;
                visitor.visit_str(&mut s, type_id)
            }
            Primitive::U8 => {
                let n = if is_compact {
                    codec::Compact::<u8>::decode(data).map(|c| c.0)
                } else {
                    u8::decode(data)
                }
                .map_err(Into::into)?;
                visitor.visit_u8(n, type_id)
            }
            Primitive::U16 => {
                let n = if is_compact {
                    codec::Compact::<u16>::decode(data).map(|c| c.0)
                } else {
                    u16::decode(data)
                }
                .map_err(Into::into)?;
                visitor.visit_u16(n, type_id)
            }
            Primitive::U32 => {
                let n = if is_compact {
                    codec::Compact::<u32>::decode(data).map(|c| c.0)
                } else {
                    u32::decode(data)
                }
                .map_err(Into::into)?;
                visitor.visit_u32(n, type_id)
            }
            Primitive::U64 => {
                let n = if is_compact {
                    codec::Compact::<u64>::decode(data).map(|c| c.0)
                } else {
                    u64::decode(data)
                }
                .map_err(Into::into)?;
                visitor.visit_u64(n, type_id)
            }
            Primitive::U128 => {
                let n = if is_compact {
                    codec::Compact::<u128>::decode(data).map(|c| c.0)
                } else {
                    u128::decode(data)
                }
                .map_err(Into::into)?;
                visitor.visit_u128(n, type_id)
            }
            Primitive::U256 => {
                err_if_compact!(is_compact);
                let arr = decode_32_bytes(data)?;
                visitor.visit_u256(arr, type_id)
            }
            Primitive::I8 => {
                err_if_compact!(is_compact);
                let n = i8::decode(data).map_err(|e| e.into())?;
                visitor.visit_i8(n, type_id)
            }
            Primitive::I16 => {
                err_if_compact!(is_compact);
                let n = i16::decode(data).map_err(|e| e.into())?;
                visitor.visit_i16(n, type_id)
            }
            Primitive::I32 => {
                err_if_compact!(is_compact);
                let n = i32::decode(data).map_err(|e| e.into())?;
                visitor.visit_i32(n, type_id)
            }
            Primitive::I64 => {
                err_if_compact!(is_compact);
                let n = i64::decode(data).map_err(|e| e.into())?;
                visitor.visit_i64(n, type_id)
            }
            Primitive::I128 => {
                err_if_compact!(is_compact);
                let n = i128::decode(data).map_err(|e| e.into())?;
                visitor.visit_i128(n, type_id)
            }
            Primitive::I256 => {
                err_if_compact!(is_compact);
                let arr = decode_32_bytes(data)?;
                visitor.visit_i256(arr, type_id)
            }
        }
    }

    /// A compact wrapper: decode the inner type with the compact flag set.
    fn visit_compact(self, inner_type_id: Self::TypeId) -> Self::Value {
        decode_with_visitor_maybe_compact(self.data, inner_type_id, self.types, self.visitor, true)
    }

    /// Hand a bit sequence with the given store/order format to the visitor.
    fn visit_bit_sequence(
        self,
        store_format: BitsStoreFormat,
        order_format: BitsOrderFormat,
    ) -> Self::Value {
        if self.is_compact {
            return Err(DecodeError::CannotDecodeCompactIntoType.into());
        }

        let format = scale_bits::Format::new(store_format, order_format);
        let mut bitseq = BitSequence::new(format, self.data);
        let res = self.visitor.visit_bitsequence(&mut bitseq, self.type_id);

        // Move the cursor past the bit sequence. Note that a failure here is
        // returned in preference to whatever the visitor returned.
        *self.data = bitseq.bytes_after()?;

        res
    }
}