tetsy_scale_codec/
depth_limit.rs1use crate::{Error, Decode, Input};
16
/// Error message reported when a value nests deeper than the configured depth limit.
const DECODE_MAX_DEPTH_MSG: &str = "Maximum recursion depth reached when decoding";
19
20pub trait DecodeLimit: Sized {
22 fn decode_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error>;
26
27 fn decode_all_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error>;
31}
32
33
/// An [`Input`] adapter that counts how deeply the decoder has nested.
///
/// Reads are forwarded to the wrapped input; the `descend_ref` hook increments
/// `depth` and reports an error once it exceeds `max_depth`.
struct DepthTrackingInput<'a, I> {
    /// Highest nesting level that is still accepted.
    max_depth: u32,
    /// Number of currently open `descend_ref` levels.
    depth: u32,
    /// The wrapped input that every read is forwarded to.
    input: &'a mut I,
}
39
40impl<'a, I:Input> Input for DepthTrackingInput<'a, I> {
41 fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
42 self.input.remaining_len()
43 }
44
45 fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
46 self.input.read(into)
47 }
48
49 fn read_byte(&mut self) -> Result<u8, Error> {
50 self.input.read_byte()
51 }
52
53 fn descend_ref(&mut self) -> Result<(), Error> {
54 self.input.descend_ref()?;
55 self.depth += 1;
56 if self.depth > self.max_depth {
57 Err(DECODE_MAX_DEPTH_MSG.into())
58 } else {
59 Ok(())
60 }
61 }
62
63 fn ascend_ref(&mut self) {
64 self.input.ascend_ref();
65 self.depth -= 1;
66 }
67}
68
69impl<T: Decode> DecodeLimit for T {
70 fn decode_all_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error> {
71 let mut input = DepthTrackingInput {
72 input: &mut &input[..],
73 depth: 0,
74 max_depth: limit,
75 };
76 let res = T::decode(&mut input)?;
77
78 if input.input.is_empty() {
79 Ok(res)
80 } else {
81 Err(crate::decode_all::DECODE_ALL_ERR_MSG.into())
82 }
83 }
84
85 fn decode_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error> {
86 let mut input = DepthTrackingInput {
87 input: &mut &input[..],
88 depth: 0,
89 max_depth: limit,
90 };
91 T::decode(&mut input)
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Encode;

    #[test]
    fn decode_limit_works() {
        type NestedVec = Vec<Vec<Vec<Vec<u8>>>>;
        let nested: NestedVec = vec![vec![vec![vec![1]]]];
        let encoded = nested.encode();

        let decoded = NestedVec::decode_with_depth_limit(3, &encoded).unwrap();
        assert_eq!(decoded, nested);
        assert!(NestedVec::decode_with_depth_limit(2, &encoded).is_err());
    }

    #[test]
    fn decode_all_limit_works() {
        type NestedVec = Vec<Vec<Vec<Vec<u8>>>>;
        let nested: NestedVec = vec![vec![vec![vec![1]]]];
        let encoded = nested.encode();

        // Same depth requirement as `decode_with_depth_limit`.
        let decoded = NestedVec::decode_all_with_depth_limit(3, &encoded).unwrap();
        assert_eq!(decoded, nested);
        assert!(NestedVec::decode_all_with_depth_limit(2, &encoded).is_err());

        // Trailing bytes are rejected by `decode_all_*` but ignored otherwise.
        let mut with_trailing = encoded.clone();
        with_trailing.push(0);
        assert!(NestedVec::decode_all_with_depth_limit(3, &with_trailing).is_err());
        assert_eq!(
            NestedVec::decode_with_depth_limit(3, &with_trailing).unwrap(),
            nested
        );
    }

    #[test]
    fn depth_is_restored_between_siblings() {
        // One `Vec` layer shallower than `NestedVec`, but with several siblings:
        // each sibling must `ascend_ref` before the next descends, otherwise the
        // counter would accumulate and the limit would be hit spuriously.
        type Wide = Vec<Vec<Vec<u8>>>;
        let wide: Wide = vec![vec![vec![1]], vec![vec![2]], vec![vec![3]]];
        let encoded = wide.encode();

        assert_eq!(Wide::decode_with_depth_limit(2, &encoded).unwrap(), wide);
        assert_eq!(Wide::decode_all_with_depth_limit(2, &encoded).unwrap(), wide);
        assert!(Wide::decode_with_depth_limit(1, &encoded).is_err());
    }
}