1use crate::{PayloadLength, Opcode, FrameInfo, masking};
2
3use nonmax::NonMaxU8;
4
/// Error type returned by [`WebsocketFrameDecoder::add_data`].
///
/// With the `large_frames` feature every payload length code is accepted, so
/// decoding cannot fail and the error type is uninhabited.
#[cfg(feature = "large_frames")]
pub type FrameDecoderError = core::convert::Infallible;

/// Error type returned by [`WebsocketFrameDecoder::add_data`].
#[cfg(not(feature = "large_frames"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum FrameDecoderError {
    /// A frame header used the 64-bit extended payload length (length code
    /// 127), which is only supported when the `large_frames` feature is enabled.
    ExceededFrameSize,
}
17
/// Fixed-capacity byte accumulator for header fields that may arrive split
/// across several `add_data` calls.
///
/// `len` counts how many leading bytes of `data` have been filled and never
/// exceeds `C`; since `len` is a `u8`, `C` has to stay below 256.
#[derive(Clone, Copy, Debug)]
struct SmallBufWithLen<const C: usize> {
    /// Number of bytes of `data` filled so far.
    len: u8,
    /// Backing storage; only `data[..len]` is meaningful until the buffer is full.
    data: [u8; C],
}

impl<const C: usize> SmallBufWithLen<C> {
    /// Copies as many bytes as still fit from the front of `data` into the
    /// buffer and hands back the part of `data` that was not taken.
    fn slurp<'a, 'c>(&'c mut self, data: &'a mut [u8]) -> &'a mut [u8] {
        let filled = usize::from(self.len);
        let take = data.len().min(C - filled);
        let (head, tail) = data.split_at_mut(take);
        self.data[filled..filled + take].copy_from_slice(head);
        self.len += take as u8;
        tail
    }

    /// Reports whether all `C` bytes have been received.
    fn is_full(&self) -> bool {
        C == usize::from(self.len)
    }

    /// Creates an empty buffer.
    const fn new() -> SmallBufWithLen<C> {
        Self { len: 0, data: [0; C] }
    }
}
43
44#[derive(Clone, Copy, Debug)]
46enum FrameDecodingState {
47 HeaderBeginning(SmallBufWithLen<2>),
48 PayloadLength16(SmallBufWithLen<2>),
49 #[cfg(feature="large_frames")]
50 PayloadLength64(SmallBufWithLen<8>),
51 MaskingKey(SmallBufWithLen<4>),
52 PayloadData {
53 phase: Option<NonMaxU8>,
54 remaining: PayloadLength,
55 },
56}
57
58impl Default for FrameDecodingState {
59 fn default() -> Self {
60 FrameDecodingState::HeaderBeginning(SmallBufWithLen::new())
61 }
62}
63
64#[doc=include_str!("../examples/decode_frame.rs")]
75#[derive(Clone, Copy, Debug, Default)]
82pub struct WebsocketFrameDecoder {
83 state: FrameDecodingState,
84 mask: [u8; 4],
85 basic_header: [u8; 2],
86 payload_length: PayloadLength,
87 original_opcode: Opcode,
88}
89
90#[derive(Debug,Clone)]
92pub struct WebsocketFrameDecoderAddDataResult {
93 pub consumed_bytes: usize,
99 pub event: Option<WebsocketFrameEvent>,
101}
102
103#[allow(missing_docs)]
104#[derive(Debug, PartialEq, Eq, Clone)]
106pub enum WebsocketFrameEvent {
107 Start{frame_info: FrameInfo, original_opcode: Opcode},
113
114 PayloadChunk{ original_opcode: Opcode},
122
123 End{frame_info: FrameInfo, original_opcode: Opcode},
131}
132
133impl WebsocketFrameDecoder {
134 fn get_opcode(&self) -> Opcode {
135 use Opcode::*;
136 match self.basic_header[0] & 0xF {
137 0 => Continuation,
138 1 => Text,
139 2 => Binary,
140 3 => ReservedData3,
141 4 => ReservedData4,
142 5 => ReservedData5,
143 6 => ReservedData6,
144 7 => ReservedData7,
145 8 => ConnectionClose,
146 9 => Ping,
147 0xA => Pong,
148 0xB => ReservedControlB,
149 0xC => ReservedControlC,
150 0xD => ReservedControlD,
151 0xE => ReservedControlE,
152 0xF => ReservedControlF,
153 _ => unreachable!(),
154 }
155 }
156
157 fn get_frame_info(&self, masked: bool) -> (FrameInfo, Opcode) {
159 let fi = FrameInfo {
160 opcode: self.get_opcode(),
161 payload_length: self.payload_length,
162 mask: if masked { Some(self.mask) } else { None },
163 fin: self.basic_header[0] & 0x80 == 0x80,
164 reserved: (self.basic_header[0] & 0x70) >> 4,
165 };
166 let mut original_opcode = fi.opcode;
167 if original_opcode==Opcode::Continuation {
168 original_opcode = self.original_opcode;
169 }
170 (fi, original_opcode)
171 }
172
173 pub fn add_data<'a, 'b>(
186 &'a mut self,
187 mut data: &'b mut [u8],
188 ) -> Result<WebsocketFrameDecoderAddDataResult, FrameDecoderError> {
189 let original_data_len = data.len();
190 loop {
191 macro_rules! return_dummy {
192 () => {
193 return Ok(WebsocketFrameDecoderAddDataResult {
194 consumed_bytes: original_data_len - data.len(),
195 event: None,
196 });
197 };
198 }
199 if data.len() == 0 && ! matches!(self.state, FrameDecodingState::PayloadData{remaining: 0, ..}) {
200 return_dummy!();
201 }
202 macro_rules! try_to_fill_buffer_or_return {
203 ($v:ident) => {
204 data = $v.slurp(data);
205 if !$v.is_full() {
206 assert!(data.is_empty());
207 return_dummy!();
208 }
209 let $v = $v.data;
210 };
211 }
212 let mut length_is_ready = false;
213 match self.state {
214 FrameDecodingState::HeaderBeginning(ref mut v) => {
215 try_to_fill_buffer_or_return!(v);
216 self.basic_header = v;
217 let opcode = self.get_opcode();
218 if opcode.is_data() && opcode != Opcode::Continuation {
219 self.original_opcode = opcode;
220 }
221 match self.basic_header[1] & 0x7F {
222 0x7E => {
223 self.state = FrameDecodingState::PayloadLength16(SmallBufWithLen::new())
224 }
225 #[cfg(feature="large_frames")]
226 0x7F => {
227 self.state = FrameDecodingState::PayloadLength64(SmallBufWithLen::new())
228 }
229 #[cfg(not(feature="large_frames"))] 0x7F => {
230 return Err(FrameDecoderError::ExceededFrameSize);
231 }
232 x => {
233 self.payload_length = x.into();
234 length_is_ready = true;
235 }
236 };
237 }
238 FrameDecodingState::PayloadLength16(ref mut v) => {
239 try_to_fill_buffer_or_return!(v);
240 self.payload_length = u16::from_be_bytes(v).into();
241 length_is_ready = true;
242 }
243 #[cfg(feature="large_frames")]
244 FrameDecodingState::PayloadLength64(ref mut v) => {
245 try_to_fill_buffer_or_return!(v);
246 self.payload_length = u64::from_be_bytes(v);
247 length_is_ready = true;
248 }
249 FrameDecodingState::MaskingKey(ref mut v) => {
250 try_to_fill_buffer_or_return!(v);
251 self.mask = v;
252 self.state = FrameDecodingState::PayloadData {
253 phase: Some(NonMaxU8::default()),
254 remaining: self.payload_length,
255 };
256 let (frame_info, original_opcode) = self.get_frame_info(true);
257 return Ok(WebsocketFrameDecoderAddDataResult {
258 consumed_bytes: original_data_len - data.len(),
259 event: Some(WebsocketFrameEvent::Start{frame_info, original_opcode}),
260 });
261 }
262 FrameDecodingState::PayloadData {
263 phase,
264 remaining: 0,
265 } => {
266 self.state = FrameDecodingState::HeaderBeginning(SmallBufWithLen::new());
267 let (fi, original_opcode) = self.get_frame_info(phase.is_some());
268 if fi.opcode.is_data() && fi.fin {
269 self.original_opcode = Opcode::Continuation;
270 }
271 return Ok(WebsocketFrameDecoderAddDataResult {
272 consumed_bytes: original_data_len - data.len(),
273 event: Some(WebsocketFrameEvent::End{frame_info: fi, original_opcode}
274 ),
275 });
276 }
277 FrameDecodingState::PayloadData {
278 ref mut phase,
279 ref mut remaining,
280 } => {
281 let start_offset = original_data_len - data.len();
282 let mut max_len = data.len();
283 if let Ok(remaining_usize) = usize::try_from(*remaining) {
284 max_len = max_len.min(remaining_usize);
285 }
286 let (payload_chunk, _rest) = data.split_at_mut(max_len);
287
288 if let Some(phase) = phase {
289 let mut ph = phase.get();
290 masking::apply_mask(self.mask, payload_chunk, ph);
291 ph += payload_chunk.len() as u8;
292 *phase = NonMaxU8::new(ph & 0x03).unwrap();
293 }
294
295 *remaining -= max_len as PayloadLength;
296 let mut original_opcode = self.get_opcode();
297 if original_opcode == Opcode::Continuation {
298 original_opcode = self.original_opcode;
299 }
300 assert_eq!(start_offset, 0);
301 return Ok(WebsocketFrameDecoderAddDataResult {
302 consumed_bytes: max_len,
303 event: Some(WebsocketFrameEvent::PayloadChunk{original_opcode}),
304 });
305 }
306 }
307 if length_is_ready {
308 if self.basic_header[1] & 0x80 == 0x80 {
309 self.state = FrameDecodingState::MaskingKey(SmallBufWithLen::new());
310 } else {
311 self.state = FrameDecodingState::PayloadData {
312 phase: None,
313 remaining: self.payload_length,
314 };
315 let (frame_info, original_opcode) = self.get_frame_info(false);
316 return Ok(WebsocketFrameDecoderAddDataResult {
317 consumed_bytes: original_data_len - data.len(),
318 event: Some(WebsocketFrameEvent::Start{frame_info, original_opcode}),
319 });
320 }
321 }
322 }
323 }
324
325 #[inline]
329 pub fn eof_valid(&self) -> bool {
330 matches!(self.state, FrameDecodingState::HeaderBeginning(..))
331 }
332
333 #[inline]
335 pub const fn new() -> Self {
336 WebsocketFrameDecoder {
337 state: FrameDecodingState::HeaderBeginning(SmallBufWithLen::new()),
338 mask: [0; 4],
339 basic_header: [0; 2],
340 payload_length: 0,
341 original_opcode: Opcode::Continuation,
342 }
343 }
344}