sftp_protocol/
parser.rs

1use std::io;
2use std::io::Write;
3
4use circular::Buffer;
5use nom_derive::Parse;
6
7use crate::packet::Packet;
8use crate::packet::PacketHeader;
9use crate::Error;
10
11#[derive(Clone, Debug)]
12pub struct Parser {
13	buffer: Buffer,
14	#[allow(dead_code)]
15	current_header: Option<PacketHeader>
16}
17
18impl Default for Parser {
19	fn default() -> Self {
20		#[allow(clippy::identity_op)] // Including the 1 makes it more clear that this is 1MiB
21		Self::with_capacity(1 * 1024 * 1024)
22	}
23}
24
25impl Parser {
26	pub fn with_capacity(capacity: usize) -> Self {
27		Self{
28			buffer: Buffer::with_capacity(capacity),
29			current_header: None
30		}
31	}
32
33	#[inline]
34	pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
35		self.buffer.write(buf)
36	}
37
38	fn check_packet(&mut self) -> bool {
39		match &self.current_header {
40			None => {
41				if(self.buffer.available_data() >= 5) {
42					let header_bytes = &self.buffer.data()[..5];
43					self.current_header = Some(PacketHeader::parse(header_bytes).unwrap().1);
44					self.check_packet()
45				} else {
46					false
47				}
48			},
49			Some(header) => {
50				self.buffer.available_data() >= header.length as usize + 4
51			}
52		}
53	}
54
55	pub fn get_packet(&mut self) -> Result<Option<Packet>, Error> {
56		match (self.check_packet(), &self.current_header) {
57			(false, _) => Ok(None),
58			(true, None) => Err(Error::NoHeader),
59			(true, Some(header)) => {
60				let len = header.length as usize + 4; // Exclude type byte, we already have it
61				let data = &self.buffer.data()[..len];
62				let packet = Packet::parse(data)?.1;
63				self.buffer.consume(len);
64				self.current_header = None;
65				Ok(Some(packet))
66			}
67		}
68	}
69}
70
/// Test-only helper: serializes `packet` with bincode configured for
/// big-endian, fixed-width integers. Panics if serialization fails.
#[cfg(test)]
pub(crate) fn encode(packet: &Packet) -> Vec<u8> {
	use bincode::Options;
	bincode::DefaultOptions::new()
		.with_big_endian()
		.with_fixint_encoding()
		.serialize(packet)
		.unwrap()
}
77
78#[cfg(test)]
79mod test {
80	use rand::Rng;
81	use test_strategy::proptest;
82	use crate::packet::*;
83	use super::*;
84
85	fn random_slices<T>(slice: &[T]) -> Vec<&[T]> /* {{{ */ {
86		let mut rng = rand::thread_rng();
87		let magic = rng.gen_range(2..std::cmp::min(16, slice.len() / 2));
88		let count = rng.gen_range(1..(slice.len() / magic));
89		let approx_size = slice.len() / count;
90		let mut lengths = (0..count).map(|_| rng.gen_range((approx_size / 2)..((approx_size * 3) / 2))).collect::<Vec<_>>();
91		while(lengths.iter().sum::<usize>() != slice.len()) {
92			if(lengths.iter().sum::<usize>() < slice.len()) {
93				for len in lengths.iter_mut() {
94					if(rng.gen::<f32>() < 0.01) {
95						(*len) += 1;
96					}
97				}
98			} else {
99				for len in lengths.iter_mut() {
100					if(*len > 1 && rng.gen::<f32>() < 0.01) {
101						(*len) -= 1;
102					}
103				}
104			}
105		}
106		let mut subslices = Vec::with_capacity(lengths.len());
107		let mut current_pos = 0;
108		for len in lengths {
109			subslices.push(&slice[current_pos..(len + current_pos)]);
110			current_pos += len;
111		}
112		subslices
113	} // }}}
114
115	/*
116	#[test]
117	fn test_random_slices() {
118		let pile_of_bytes = vec![0u8; 1024];
119		let slices = random_slices(&pile_of_bytes);
120		assert_eq!(slices.iter().map(|s| s.len()).sum::<usize>(), pile_of_bytes.len());
121		let pile_of_bytes = vec![0u8; 32];
122		let slices = random_slices(&pile_of_bytes);
123		assert_eq!(slices.iter().map(|s| s.len()).sum::<usize>(), pile_of_bytes.len());
124		let pile_of_bytes = vec![0u8; 1024 * 1024];
125		let slices = random_slices(&pile_of_bytes);
126		assert_eq!(slices.iter().map(|s| s.len()).sum::<usize>(), pile_of_bytes.len());
127	}
128	*/
129
130	#[ignore] // It's broken and doesn't have any actual functional impact
131	#[test]
132	fn validate_packet_result_size() {
133		assert_eq!(std::mem::size_of::<Result<Packet, Error>>(), std::mem::size_of::<Packet>() + std::mem::size_of::<*const ()>());
134		assert_eq!(std::mem::size_of::<Result<Option<Packet>, Error>>(), std::mem::size_of::<Result<Packet, Error>>());
135	}
136
137	#[test]
138	fn single_init() {
139		let mut stream = Parser::default();
140		assert_eq!(stream.get_packet(), Ok(None));
141		let packet = Payload::init(1, vec![]).into_packet();
142		stream.write(&encode(&packet)).unwrap();
143		assert_eq!(stream.get_packet(), Ok(Some(packet)));
144		assert_eq!(stream.get_packet(), Ok(None));
145	}
146
147	#[test]
148	fn multipart_init() {
149		let mut stream = Parser::default();
150		let packet = Payload::init(2, vec![]).into_packet();
151		let bytes = encode(&packet);
152		stream.write(&bytes[0..3]).unwrap();
153		assert_eq!(stream.get_packet(), Ok(None));
154		stream.write(&bytes[3..bytes.len()]).unwrap();
155		assert_eq!(stream.get_packet(), Ok(Some(packet)));
156	}
157
158	#[test]
159	fn handshake() {
160		let mut stream = Parser::default();
161		let init = Payload::init(32768, (0..100).collect()).into_packet();
162		stream.write(&encode(&init)).unwrap();
163		assert_eq!(stream.get_packet(), Ok(Some(init)));
164		let version = Payload::version(3, (100..150).collect()).into_packet();
165		stream.write(&encode(&version)).unwrap();
166		assert_eq!(stream.get_packet(), Ok(Some(version)));
167		assert_eq!(stream.get_packet(), Ok(None));
168	}
169
170	#[test]
171	fn handshake_queued() {
172		let mut stream = Parser::default();
173		let init = Payload::init(32768, (0..100).collect()).into_packet();
174		stream.write(&encode(&init)).unwrap();
175		let version = Payload::version(3, (100..150).collect()).into_packet();
176		stream.write(&encode(&version)).unwrap();
177		assert_eq!(stream.get_packet(), Ok(Some(init)));
178		assert_eq!(stream.get_packet(), Ok(Some(version)));
179		assert_eq!(stream.get_packet(), Ok(None));
180	}
181
182	#[proptest]
183	fn arbitrary_sequence(input: Vec<Packet>) {
184		let mut stream = Parser::default();
185		assert_eq!(stream.get_packet(), Ok(None));
186		for packet in input {
187			stream.write(&encode(&packet)).unwrap();
188			assert_eq!(stream.get_packet(), Ok(Some(packet)));
189			assert_eq!(stream.get_packet(), Ok(None));
190		}
191		assert_eq!(stream.get_packet(), Ok(None));
192	}
193
194	#[proptest]
195	fn arbitrary_sequence_queued(input: Vec<Packet>) {
196		let mut stream = Parser::default();
197		assert_eq!(stream.get_packet(), Ok(None));
198		for packet in &input {
199			stream.write(&encode(packet)).unwrap();
200		}
201		for packet in input {
202			assert_eq!(stream.get_packet(), Ok(Some(packet)));
203		}
204		assert_eq!(stream.get_packet(), Ok(None));
205	}
206
207	#[proptest]
208	fn arbitrary_sequence_multipart(input: Vec<Packet>) {
209		let mut stream = Parser::default();
210		assert_eq!(stream.get_packet(), Ok(None));
211		for packet in input {
212			let bytes = encode(&packet);
213			let slices = random_slices(&bytes);
214			for slice in slices {
215				stream.write(slice).unwrap();
216			}
217			assert_eq!(stream.get_packet(), Ok(Some(packet)));
218			assert_eq!(stream.get_packet(), Ok(None));
219		}
220		assert_eq!(stream.get_packet(), Ok(None));
221	}
222
223	#[proptest]
224	fn arbitrary_sequence_multipart_queued(input: Vec<Packet>) {
225		let mut stream = Parser::default();
226		assert_eq!(stream.get_packet(), Ok(None));
227		for packet in &input {
228			let bytes = encode(packet);
229			let slices = random_slices(&bytes);
230			for slice in slices {
231				stream.write(slice).unwrap();
232			}
233		}
234		for packet in input {
235			assert_eq!(stream.get_packet(), Ok(Some(packet)));
236		}
237		assert_eq!(stream.get_packet(), Ok(None));
238	}
239}
240