tetsy_scale_codec/depth_limit.rs

// Copyright 2017, 2018 Parity Technologies
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use crate::{Error, Decode, Input};
16
/// Text of the error (converted into an `Error` via `.into()`) that the
/// depth-tracking input reports once decoding nests deeper than allowed.
const DECODE_MAX_DEPTH_MSG: &str = "Maximum recursion depth reached when decoding";
19
20/// Extension trait to [`Decode`] for decoding with a maximum recursion depth.
21pub trait DecodeLimit: Sized {
22	/// Decode `Self` with the given maximum recursion depth.
23	///
24	/// If `limit` is hit, an error is returned.
25	fn decode_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error>;
26
27	/// Decode `Self` and consume all of the given input data.
28	///
29	/// If not all data is consumed or `limit` is hit, an error is returned.
30	fn decode_all_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error>;
31}
32
33
34struct DepthTrackingInput<'a, I> {
35	input: &'a mut I,
36	depth: u32,
37	max_depth: u32,
38}
39
40impl<'a, I:Input> Input for DepthTrackingInput<'a, I> {
41	fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
42		self.input.remaining_len()
43	}
44
45	fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
46		self.input.read(into)
47	}
48
49	fn read_byte(&mut self) -> Result<u8, Error> {
50		self.input.read_byte()
51	}
52
53	fn descend_ref(&mut self) -> Result<(), Error> {
54		self.input.descend_ref()?;
55		self.depth += 1;
56		if self.depth > self.max_depth {
57			Err(DECODE_MAX_DEPTH_MSG.into())
58		} else {
59			Ok(())
60		}
61	}
62
63	fn ascend_ref(&mut self) {
64		self.input.ascend_ref();
65		self.depth -= 1;
66	}
67}
68
69impl<T: Decode> DecodeLimit for T {
70	fn decode_all_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error> {
71		let mut input = DepthTrackingInput {
72			input: &mut &input[..],
73			depth: 0,
74			max_depth: limit,
75		};
76		let res = T::decode(&mut input)?;
77
78		if input.input.is_empty() {
79			Ok(res)
80		} else {
81			Err(crate::decode_all::DECODE_ALL_ERR_MSG.into())
82		}
83	}
84
85	fn decode_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error> {
86		let mut input = DepthTrackingInput {
87			input: &mut &input[..],
88			depth: 0,
89			max_depth: limit,
90		};
91		T::decode(&mut input)
92	}
93}
94
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Encode;

	/// Four nested vectors; the test below shows this needs a depth limit of 3.
	type NestedVec = Vec<Vec<Vec<Vec<u8>>>>;

	#[test]
	fn decode_limit_works() {
		let nested: NestedVec = vec![vec![vec![vec![1]]]];
		let encoded = nested.encode();

		let decoded = NestedVec::decode_with_depth_limit(3, &encoded).unwrap();
		assert_eq!(decoded, nested);
		assert!(NestedVec::decode_with_depth_limit(2, &encoded).is_err());
	}

	#[test]
	fn decode_limit_ignores_trailing_bytes() {
		let nested: NestedVec = vec![vec![vec![vec![1]]]];
		let mut encoded = nested.encode();
		encoded.push(0);

		// Only `decode_all_with_depth_limit` insists on consuming everything.
		let decoded = NestedVec::decode_with_depth_limit(3, &encoded).unwrap();
		assert_eq!(decoded, nested);
	}

	#[test]
	fn decode_all_limit_works() {
		let nested: NestedVec = vec![vec![vec![vec![1]]]];
		let encoded = nested.encode();

		let decoded = NestedVec::decode_all_with_depth_limit(3, &encoded).unwrap();
		assert_eq!(decoded, nested);
		assert!(NestedVec::decode_all_with_depth_limit(2, &encoded).is_err());

		// Trailing data must be rejected even when the depth limit is satisfied.
		let mut with_trailing = encoded.clone();
		with_trailing.push(0);
		assert!(NestedVec::decode_all_with_depth_limit(3, &with_trailing).is_err());
	}
}
110}