scale_decode/visitor/
decode.rs

1// Copyright (C) 2022 Parity Technologies (UK) Ltd. (admin@parity.io)
2// This file is a part of the scale-decode crate.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//         http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15use crate::visitor::{
16    Array, BitSequence, Composite, DecodeAsTypeResult, DecodeError, Sequence, Str, Tuple,
17    TypeIdFor, Variant, Visitor,
18};
19use crate::Field;
20use alloc::format;
21use alloc::string::ToString;
22use codec::{self, Decode};
23use scale_type_resolver::{
24    BitsOrderFormat, BitsStoreFormat, FieldIter, PathIter, Primitive, ResolvedTypeVisitor,
25    TypeResolver, UnhandledKind, VariantIter,
26};
27
28/// Decode data according to the type ID and type resolver provided.
29/// The provided pointer to the data slice will be moved forwards as needed
30/// depending on what was decoded, and a method on the provided [`Visitor`]
31/// will be called depending on the type that needs to be decoded.
32pub fn decode_with_visitor<'scale, 'resolver, V: Visitor>(
33    data: &mut &'scale [u8],
34    ty_id: TypeIdFor<V>,
35    types: &'resolver V::TypeResolver,
36    visitor: V,
37) -> Result<V::Value<'scale, 'resolver>, V::Error> {
38    decode_with_visitor_maybe_compact(data, ty_id, types, visitor, false)
39}
40
41pub fn decode_with_visitor_maybe_compact<'scale, 'resolver, V: Visitor>(
42    data: &mut &'scale [u8],
43    ty_id: TypeIdFor<V>,
44    types: &'resolver V::TypeResolver,
45    visitor: V,
46    is_compact: bool,
47) -> Result<V::Value<'scale, 'resolver>, V::Error> {
48    // Provide option to "bail out" and do something custom first.
49    let visitor = match visitor.unchecked_decode_as_type(data, ty_id.clone(), types) {
50        DecodeAsTypeResult::Decoded(r) => return r,
51        DecodeAsTypeResult::Skipped(v) => v,
52    };
53
54    let decoder = Decoder::new(data, types, ty_id.clone(), visitor, is_compact);
55    let res = types.resolve_type(ty_id, decoder);
56
57    match res {
58        // We got a value back; return it
59        Ok(Ok(val)) => Ok(val),
60        // We got a visitor error back; return it
61        Ok(Err(e)) => Err(e),
62        // We got a TypeResolver error back; turn it into a DecodeError and then visitor error to return.
63        Err(resolve_type_error) => {
64            Err(DecodeError::TypeResolvingError(resolve_type_error.to_string()).into())
65        }
66    }
67}
68
69/// This struct implements `ResolvedTypeVisitor`. One of those methods fired depending on the type that
70/// we resolve from the given TypeId, and then based on the information handed to that method we decode
71/// the SCALE encoded bytes as needed and then call the relevant method on the `scale_decode::Visitor` to
72/// hand back the decoded value (or some nice interface to allow the user to decode the value).
73struct Decoder<'a, 'scale, 'resolver, V: Visitor> {
74    data: &'a mut &'scale [u8],
75    type_id: TypeIdFor<V>,
76    types: &'resolver V::TypeResolver,
77    visitor: V,
78    is_compact: bool,
79}
80
81impl<'a, 'scale, 'resolver, V: Visitor> Decoder<'a, 'scale, 'resolver, V> {
82    fn new(
83        data: &'a mut &'scale [u8],
84        types: &'resolver V::TypeResolver,
85        type_id: TypeIdFor<V>,
86        visitor: V,
87        is_compact: bool,
88    ) -> Self {
89        Decoder { data, type_id, types, is_compact, visitor }
90    }
91}
92
// Composite/Variant/Sequence/Array/Tuple all share one way of stepping over bytes
// that the visitor left undecoded, so that logic lives in this macro.
//
// - `$self`: the decoder whose `data` cursor should be advanced.
// - `$visit_result`: the result handed back by the visitor.
// - `$visitor_ty`: the helper that was given to the visitor; it must provide
//   `skip_decoding()` and `bytes_from_undecoded()`.
macro_rules! skip_decoding_and_return {
    ($self:ident, $visit_result:ident, $visitor_ty:ident) => {{
        // Step over whatever the visitor chose not to decode. The cursor is only
        // moved when that succeeded.
        let skipped = $visitor_ty.skip_decoding();
        if skipped.is_ok() {
            *$self.data = $visitor_ty.bytes_from_undecoded();
        }

        // A visitor error takes priority over a skip_decoding error.
        match $visit_result {
            Err(visit_err) => Err(visit_err),
            Ok(value) => match skipped {
                Ok(_) => Ok(value),
                Err(skip_err) => Err(skip_err.into()),
            },
        }
    }};
}
112
113impl<'temp, 'scale, 'resolver, V: Visitor> ResolvedTypeVisitor<'resolver>
114    for Decoder<'temp, 'scale, 'resolver, V>
115{
116    type TypeId = TypeIdFor<V>;
117    type Value = Result<V::Value<'scale, 'resolver>, V::Error>;
118
119    fn visit_unhandled(self, kind: UnhandledKind) -> Self::Value {
120        let type_id = self.type_id;
121        Err(DecodeError::TypeIdNotFound(format!(
122            "Kind {kind:?} (type ID {type_id:?}) has not been properly handled"
123        ))
124        .into())
125    }
126
127    fn visit_not_found(self) -> Self::Value {
128        let type_id = self.type_id;
129        Err(DecodeError::TypeIdNotFound(format!("{type_id:?}")).into())
130    }
131
132    fn visit_composite<Path, Fields>(self, path: Path, mut fields: Fields) -> Self::Value
133    where
134        Path: PathIter<'resolver>,
135        Fields: FieldIter<'resolver, Self::TypeId>,
136    {
137        // guard against invalid compact types: only composites with 1 field can be compact encoded
138        if self.is_compact && fields.len() != 1 {
139            return Err(DecodeError::CannotDecodeCompactIntoType.into());
140        }
141
142        let mut items = Composite::new(path, self.data, &mut fields, self.types, self.is_compact);
143        let res = self.visitor.visit_composite(&mut items, self.type_id);
144
145        skip_decoding_and_return!(self, res, items)
146    }
147
148    fn visit_variant<Path, Fields, Var>(self, path: Path, variants: Var) -> Self::Value
149    where
150        Path: PathIter<'resolver>,
151        Fields: FieldIter<'resolver, Self::TypeId>,
152        Var: VariantIter<'resolver, Fields>,
153    {
154        if self.is_compact {
155            return Err(DecodeError::CannotDecodeCompactIntoType.into());
156        }
157
158        let mut variant = Variant::new(path, self.data, variants, self.types)?;
159        let res = self.visitor.visit_variant(&mut variant, self.type_id);
160
161        skip_decoding_and_return!(self, res, variant)
162    }
163
164    fn visit_sequence<Path>(self, path: Path, inner_type_id: Self::TypeId) -> Self::Value
165    where
166        Path: PathIter<'resolver>,
167    {
168        if self.is_compact {
169            return Err(DecodeError::CannotDecodeCompactIntoType.into());
170        }
171
172        let mut items = Sequence::new(path, self.data, inner_type_id, self.types)?;
173        let res = self.visitor.visit_sequence(&mut items, self.type_id);
174
175        skip_decoding_and_return!(self, res, items)
176    }
177
178    fn visit_array(self, inner_type_id: Self::TypeId, len: usize) -> Self::Value {
179        if self.is_compact {
180            return Err(DecodeError::CannotDecodeCompactIntoType.into());
181        }
182
183        let mut arr = Array::new(self.data, inner_type_id, len, self.types);
184        let res = self.visitor.visit_array(&mut arr, self.type_id);
185
186        skip_decoding_and_return!(self, res, arr)
187    }
188
189    fn visit_tuple<TypeIds>(self, type_ids: TypeIds) -> Self::Value
190    where
191        TypeIds: ExactSizeIterator<Item = Self::TypeId>,
192    {
193        // guard against invalid compact types: only composites with 1 field can be compact encoded
194        if self.is_compact && type_ids.len() != 1 {
195            return Err(DecodeError::CannotDecodeCompactIntoType.into());
196        }
197
198        let mut fields = type_ids.map(Field::unnamed);
199        let mut items = Tuple::new(self.data, &mut fields, self.types, self.is_compact);
200        let res = self.visitor.visit_tuple(&mut items, self.type_id);
201
202        skip_decoding_and_return!(self, res, items)
203    }
204
205    fn visit_primitive(self, primitive: Primitive) -> Self::Value {
206        macro_rules! err_if_compact {
207            ($is_compact:expr) => {
208                if $is_compact {
209                    return Err(DecodeError::CannotDecodeCompactIntoType.into());
210                }
211            };
212        }
213
214        fn decode_32_bytes<'scale>(
215            data: &mut &'scale [u8],
216        ) -> Result<&'scale [u8; 32], DecodeError> {
217            // Pull an array from the data if we can, preserving the lifetime.
218            let arr: &'scale [u8; 32] = match (*data).try_into() {
219                Ok(arr) => arr,
220                Err(_) => return Err(DecodeError::NotEnoughInput),
221            };
222            // If we successfully read the bytes, then advance the pointer past them.
223            *data = &data[32..];
224            Ok(arr)
225        }
226
227        let data = self.data;
228        let is_compact = self.is_compact;
229        let visitor = self.visitor;
230        let type_id = self.type_id;
231
232        match primitive {
233            Primitive::Bool => {
234                err_if_compact!(is_compact);
235                let b = bool::decode(data).map_err(|e| e.into())?;
236                visitor.visit_bool(b, type_id)
237            }
238            Primitive::Char => {
239                err_if_compact!(is_compact);
240                // Treat chars as u32's
241                let val = u32::decode(data).map_err(|e| e.into())?;
242                let c = char::from_u32(val).ok_or(DecodeError::InvalidChar(val))?;
243                visitor.visit_char(c, type_id)
244            }
245            Primitive::Str => {
246                err_if_compact!(is_compact);
247                // Avoid allocating; don't decode into a String. instead, pull the bytes
248                // and let the visitor decide whether to use them or not.
249                let mut s = Str::new(data)?;
250                // Since we aren't decoding here, shift our bytes along to after the str:
251                *data = s.bytes_after()?;
252                visitor.visit_str(&mut s, type_id)
253            }
254            Primitive::U8 => {
255                let n = if is_compact {
256                    codec::Compact::<u8>::decode(data).map(|c| c.0)
257                } else {
258                    u8::decode(data)
259                }
260                .map_err(Into::into)?;
261                visitor.visit_u8(n, type_id)
262            }
263            Primitive::U16 => {
264                let n = if is_compact {
265                    codec::Compact::<u16>::decode(data).map(|c| c.0)
266                } else {
267                    u16::decode(data)
268                }
269                .map_err(Into::into)?;
270                visitor.visit_u16(n, type_id)
271            }
272            Primitive::U32 => {
273                let n = if is_compact {
274                    codec::Compact::<u32>::decode(data).map(|c| c.0)
275                } else {
276                    u32::decode(data)
277                }
278                .map_err(Into::into)?;
279                visitor.visit_u32(n, type_id)
280            }
281            Primitive::U64 => {
282                let n = if is_compact {
283                    codec::Compact::<u64>::decode(data).map(|c| c.0)
284                } else {
285                    u64::decode(data)
286                }
287                .map_err(Into::into)?;
288                visitor.visit_u64(n, type_id)
289            }
290            Primitive::U128 => {
291                let n = if is_compact {
292                    codec::Compact::<u128>::decode(data).map(|c| c.0)
293                } else {
294                    u128::decode(data)
295                }
296                .map_err(Into::into)?;
297                visitor.visit_u128(n, type_id)
298            }
299            Primitive::U256 => {
300                err_if_compact!(is_compact);
301                let arr = decode_32_bytes(data)?;
302                visitor.visit_u256(arr, type_id)
303            }
304            Primitive::I8 => {
305                err_if_compact!(is_compact);
306                let n = i8::decode(data).map_err(|e| e.into())?;
307                visitor.visit_i8(n, type_id)
308            }
309            Primitive::I16 => {
310                err_if_compact!(is_compact);
311                let n = i16::decode(data).map_err(|e| e.into())?;
312                visitor.visit_i16(n, type_id)
313            }
314            Primitive::I32 => {
315                err_if_compact!(is_compact);
316                let n = i32::decode(data).map_err(|e| e.into())?;
317                visitor.visit_i32(n, type_id)
318            }
319            Primitive::I64 => {
320                err_if_compact!(is_compact);
321                let n = i64::decode(data).map_err(|e| e.into())?;
322                visitor.visit_i64(n, type_id)
323            }
324            Primitive::I128 => {
325                err_if_compact!(is_compact);
326                let n = i128::decode(data).map_err(|e| e.into())?;
327                visitor.visit_i128(n, type_id)
328            }
329            Primitive::I256 => {
330                err_if_compact!(is_compact);
331                let arr = decode_32_bytes(data)?;
332                visitor.visit_i256(arr, type_id)
333            }
334        }
335    }
336
337    fn visit_compact(self, inner_type_id: Self::TypeId) -> Self::Value {
338        decode_with_visitor_maybe_compact(self.data, inner_type_id, self.types, self.visitor, true)
339    }
340
341    fn visit_bit_sequence(
342        self,
343        store_format: BitsStoreFormat,
344        order_format: BitsOrderFormat,
345    ) -> Self::Value {
346        if self.is_compact {
347            return Err(DecodeError::CannotDecodeCompactIntoType.into());
348        }
349
350        let format = scale_bits::Format::new(store_format, order_format);
351        let mut bitseq = BitSequence::new(format, self.data);
352        let res = self.visitor.visit_bitsequence(&mut bitseq, self.type_id);
353
354        // Move to the bytes after the bit sequence.
355        *self.data = bitseq.bytes_after()?;
356
357        res
358    }
359}