1use crate::errors::{ProtocolError, WsError};
2use crate::frame::{get_bit, HeaderView, OpCode, SimplifiedHeader};
3use http;
4use crate::protocol::{cal_accept_key, standard_handshake_req_check};
5use bytes::BytesMut;
6use std::fmt::Debug;
7use std::ops::Range;
8
9#[cfg(feature = "sync")]
10mod blocking;
11
12#[cfg(feature = "sync")]
13pub use blocking::*;
14
15#[cfg(feature = "async")]
16mod non_blocking;
17
18#[cfg(feature = "async")]
19pub use non_blocking::*;
20
/// How strictly UTF-8 is validated in the payload of incoming text frames.
#[derive(Debug, Clone)]
pub enum ValidateUtf8Policy {
    /// Never validate text payloads.
    Off,
    /// Validate as early as possible: besides final text frames, the first
    /// fragment of a fragmented text message is checked as soon as it arrives.
    FastFail,
    /// Validate final text frames, but not the leading fragment of a
    /// fragmented text message.
    On,
}

impl ValidateUtf8Policy {
    /// Returns `true` for every policy except [`ValidateUtf8Policy::Off`].
    pub fn should_check(&self) -> bool {
        match self {
            ValidateUtf8Policy::Off => false,
            ValidateUtf8Policy::FastFail | ValidateUtf8Policy::On => true,
        }
    }

    /// Returns `true` only for [`ValidateUtf8Policy::FastFail`].
    pub fn is_fast_fail(&self) -> bool {
        match self {
            ValidateUtf8Policy::FastFail => true,
            ValidateUtf8Policy::Off | ValidateUtf8Policy::On => false,
        }
    }
}
42
43#[derive(Debug, Clone)]
45pub struct FrameConfig {
46 pub check_rsv: bool,
48 pub mask_send_frame: bool,
50 pub renew_buf_on_write: bool,
52 pub auto_unmask: bool,
54 pub max_frame_payload_size: usize,
56 pub auto_fragment_size: usize,
58 pub merge_frame: bool,
60 pub validate_utf8: ValidateUtf8Policy,
62 pub resize_size: usize,
64 pub resize_thresh: usize,
66}
67
68impl Default for FrameConfig {
69 fn default() -> Self {
70 Self {
71 check_rsv: true,
72 mask_send_frame: true,
73 renew_buf_on_write: false,
74 auto_unmask: true,
75 max_frame_payload_size: 0,
76 auto_fragment_size: 0,
77 merge_frame: true,
78 validate_utf8: ValidateUtf8Policy::FastFail,
79 resize_size: 4096,
80 resize_thresh: 1024,
81 }
82 }
83}
84
/// XORs `buf` in place with the repeating 4-byte `mask` (RFC 6455 §5.3).
///
/// Masking is an involution: applying the same mask twice restores the input.
#[inline]
pub fn apply_mask(buf: &mut [u8], mask: [u8; 4]) {
    apply_mask_array_chunk(buf, mask)
}

/// Word-at-a-time masking: XORs each full 4-byte chunk as a `u32`, then the
/// remaining 0..=3 tail bytes individually.
///
/// The chunks of an arbitrary byte slice are not guaranteed to be 4-byte
/// aligned, so words are assembled with `from_ne_bytes` / `to_ne_bytes` rather
/// than by reinterpreting the pointer as `&mut u32`. Reinterpreting would create
/// a misaligned reference, which is undefined behavior. The compiler lowers this
/// form to unaligned loads and stores.
#[inline]
fn apply_mask_array_chunk(buf: &mut [u8], mask: [u8; 4]) {
    let mask32 = u32::from_ne_bytes(mask);
    let mut chunks = buf.chunks_exact_mut(4);
    for chunk in &mut chunks {
        // Native-endian round trip keeps byte i of the word paired with mask[i].
        let word = u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) ^ mask32;
        chunk.copy_from_slice(&word.to_ne_bytes());
    }
    // Every full chunk starts at a multiple of 4, so the tail restarts at mask[0].
    for (byte, m) in chunks.into_remainder().iter_mut().zip(mask.iter()) {
        *byte ^= *m;
    }
}
103
104pub struct FrameReadState {
106 fragmented: bool,
107 config: FrameConfig,
108 fragmented_data: Vec<u8>,
109 fragmented_type: OpCode,
110 buf: FrameBuffer,
111}
112
113impl Default for FrameReadState {
114 fn default() -> Self {
115 Self {
116 fragmented: false,
117 config: Default::default(),
118 fragmented_data: vec![],
119 fragmented_type: OpCode::default(),
120 buf: FrameBuffer::new(),
121 }
122 }
123}
124
125impl FrameReadState {
126 pub fn with_config(config: FrameConfig) -> Self {
128 Self {
129 config,
130 ..Self::default()
131 }
132 }
133
134 pub fn is_header_ok(&self) -> bool {
136 let ava_data = self.buf.ava_data();
137 if ava_data.len() < 2 {
138 false
139 } else {
140 let len = ava_data[1] & 0b01111111;
141 let mask = get_bit(&ava_data, 1, 0);
142 let mut min_len = match len {
143 0..=125 => 2,
144 126 => 4,
145 127 => 10,
146 _ => unreachable!(),
147 };
148 if mask {
149 min_len += 4;
150 }
151 ava_data.len() >= min_len
152 }
153 }
154
155 #[inline]
157 pub fn get_leading_bits(&self) -> u8 {
158 self.buf.ava_data()[0] >> 4
159 }
160
161 #[inline]
163 pub fn parse_frame_header(&mut self) -> Result<(usize, usize, usize), WsError> {
164 let ava_data = self.buf.ava_data();
165 let leading_bits = self.get_leading_bits();
166 let max_payload_size = self.config.max_frame_payload_size;
167 let check_rsv = self.config.check_rsv;
168
169 fn parse_payload_len(source: &[u8]) -> Result<(usize, usize), ProtocolError> {
170 match source[1] {
171 len @ (0..=125 | 128..=253) => Ok((1, (len & 127) as usize)),
172 126 | 254 => {
173 if source.len() < 4 {
174 return Err(ProtocolError::InsufficientLen(source.len()));
175 }
176 Ok((
177 1 + 2,
178 u16::from_be_bytes((&source[2..4]).try_into().unwrap()) as usize,
179 ))
180 }
181 127 | 255 => {
182 if source.len() < 10 {
183 return Err(ProtocolError::InsufficientLen(source.len()));
184 }
185 Ok((
186 1 + 8,
187 usize::from_be_bytes((&source[2..(8 + 2)]).try_into().unwrap()),
188 ))
189 }
190 }
191 }
192
193 if check_rsv && !(leading_bits == 0b00001000 || leading_bits == 0b00000000) {
194 return Err(WsError::ProtocolError {
195 close_code: 1008,
196 error: ProtocolError::InvalidLeadingBits(leading_bits),
197 });
198 }
199 let (len_occ_bytes, payload_len) =
200 parse_payload_len(ava_data).map_err(|e| WsError::ProtocolError {
201 close_code: 1008,
202 error: e,
203 })?;
204
205 if max_payload_size > 0 && payload_len > max_payload_size {
206 return Err(WsError::ProtocolError {
207 close_code: 1008,
208 error: ProtocolError::PayloadTooLarge(max_payload_size),
209 });
210 }
211 let mask = get_bit(ava_data, 1, 0);
212 let header_len = 1 + len_occ_bytes + if mask { 4 } else { 0 };
213 Ok((header_len, payload_len, header_len + payload_len))
214 }
215
216 #[inline]
218 pub fn consume_frame(
219 &mut self,
220 header_len: usize,
221 payload_len: usize,
222 total_len: usize,
223 ) -> (SimplifiedHeader, Range<usize>) {
224 let buf = &mut self.buf;
225 let auto_unmask = self.config.auto_unmask;
226
227 let ava_data = buf.ava_mut_data();
228 let (header_data, remain) = ava_data.split_at_mut(header_len);
229 let header = HeaderView(header_data);
230 let payload = remain.split_at_mut(payload_len).0;
231 if auto_unmask {
232 if let Some(mask) = header.masking_key() {
233 apply_mask(payload, mask)
234 }
235 }
236 let header: SimplifiedHeader = header.into();
237 let s_idx = buf.consume_idx + header_len;
238 let e_idx = s_idx + payload_len;
239 buf.consume(total_len);
240 (header, s_idx..e_idx)
241 }
242
243 fn check_frame(
244 &mut self,
245 header: SimplifiedHeader,
246 range: Range<usize>,
247 ) -> Result<(), WsError> {
248 let fragmented = &mut self.fragmented;
249 let utf8_policy = &self.config.validate_utf8;
250 let payload = &self.buf.buf[range];
251 match header.code {
252 OpCode::Continue => {
253 if !*fragmented {
254 return Err(WsError::ProtocolError {
255 close_code: 1002,
256 error: ProtocolError::MissInitialFragmentedFrame,
257 });
258 }
259 if header.fin {
260 *fragmented = false;
261 }
262 Ok(())
263 }
264 OpCode::Binary => {
265 if *fragmented {
266 return Err(WsError::ProtocolError {
267 close_code: 1002,
268 error: ProtocolError::NotContinueFrameAfterFragmented,
269 });
270 }
271 *fragmented = !header.fin;
272 Ok(())
273 }
274 OpCode::Text => {
275 if *fragmented {
276 return Err(WsError::ProtocolError {
277 close_code: 1002,
278 error: ProtocolError::NotContinueFrameAfterFragmented,
279 });
280 }
281 if !header.fin {
282 *fragmented = true;
283 if header.code == OpCode::Text
284 && utf8_policy.is_fast_fail()
285 && simdutf8::basic::from_utf8(payload).is_err()
286 {
287 return Err(WsError::ProtocolError {
288 close_code: 1007,
289 error: ProtocolError::InvalidUtf8,
290 });
291 }
292
293 Ok(())
294 } else {
295 if header.code == OpCode::Text
296 && utf8_policy.should_check()
297 && simdutf8::basic::from_utf8(payload).is_err()
298 {
299 return Err(WsError::ProtocolError {
300 close_code: 1007,
301 error: ProtocolError::InvalidUtf8,
302 });
303 }
304 Ok(())
305 }
306 }
307 OpCode::Close | OpCode::Ping | OpCode::Pong => {
308 if !header.fin {
309 return Err(WsError::ProtocolError {
310 close_code: 1002,
311 error: ProtocolError::FragmentedControlFrame,
312 });
313 }
314 let payload_len = payload.len();
315 if payload.len() > 125 {
316 let error = ProtocolError::ControlFrameTooBig(payload_len);
317 return Err(WsError::ProtocolError {
318 close_code: 1002,
319 error,
320 });
321 }
322 if header.code == OpCode::Close {
323 if payload_len == 1 {
324 let error = ProtocolError::InvalidCloseFramePayload;
325 return Err(WsError::ProtocolError {
326 close_code: 1002,
327 error,
328 });
329 }
330 if payload_len >= 2 {
331 let mut code_byte = [0u8; 2];
333 code_byte.copy_from_slice(&payload[..2]);
334 let code = u16::from_be_bytes(code_byte);
335 if code < 1000
336 || (1004..=1006).contains(&code)
337 || (1015..=2999).contains(&code)
338 || code >= 5000
339 {
340 let error = ProtocolError::InvalidCloseCode(code);
341 return Err(WsError::ProtocolError {
342 close_code: 1002,
343 error,
344 });
345 }
346
347 if String::from_utf8(payload[2..].to_vec()).is_err() {
349 let error = ProtocolError::InvalidUtf8;
350 return Err(WsError::ProtocolError {
351 close_code: 1007,
352 error,
353 });
354 }
355 }
356 }
357 Ok(())
358 }
359 _ => Err(WsError::UnsupportedFrame(header.code)),
360 }
361 }
362
363 #[doc(hidden)]
365 #[inline]
366 pub fn merge_frame(
367 &mut self,
368 header: SimplifiedHeader,
369 range: Range<usize>,
370 ) -> Result<Option<bool>, WsError> {
371 let fragmented = &mut self.fragmented;
372 let fragmented_data = &mut self.fragmented_data;
373 let fragmented_type = &mut self.fragmented_type;
374 let payload = &self.buf.buf[range];
375 match header.code {
376 OpCode::Continue => {
377 fragmented_data.extend_from_slice(payload);
378 if header.fin {
379 *fragmented = false;
380 Ok(Some(true))
381 } else {
382 Ok(None)
383 }
384 }
385 OpCode::Text | OpCode::Binary => {
386 *fragmented_type = header.code;
387 if !header.fin {
388 *fragmented = true;
389 *fragmented_type = header.code;
390 fragmented_data.clear();
391 fragmented_data.extend_from_slice(payload);
392 Ok(None)
393 } else {
394 Ok(Some(false))
395 }
396 }
397 OpCode::Close | OpCode::Ping | OpCode::Pong => Ok(Some(false)),
398 _ => unreachable!(),
399 }
400 }
401}
402
/// Linear byte buffer with separate read (`consume_idx`) and write
/// (`produce_idx`) cursors.
///
/// The bytes in `consume_idx..produce_idx` are pending (written but not yet
/// consumed). When the tail has too little room, `prepare` rewinds, compacts or
/// grows the storage.
pub(crate) struct FrameBuffer {
    /// Backing storage; its length (not its capacity) is the usable size.
    pub(crate) buf: Vec<u8>,
    /// Scratch space `prepare` uses to relocate pending bytes to the front.
    tmp: Vec<u8>,
    /// End of the pending region.
    produce_idx: usize,
    /// Start of the pending region; everything before it has been consumed.
    consume_idx: usize,
}

impl FrameBuffer {
    /// Creates a buffer with 8192 zeroed bytes of storage and of scratch space.
    pub(crate) fn new() -> Self {
        const INITIAL_LEN: usize = 8192;
        FrameBuffer {
            buf: vec![0; INITIAL_LEN],
            tmp: vec![0; INITIAL_LEN],
            produce_idx: 0,
            consume_idx: 0,
        }
    }

    /// Returns a writable slice of exactly `payload_size` bytes placed directly
    /// after the pending data.
    ///
    /// If the tail is too short, the buffer is adjusted first:
    /// - With nothing pending, both cursors rewind to 0, and the storage grows
    ///   if it is smaller than `payload_size`.
    /// - Otherwise the pending bytes are moved to the front and the storage
    ///   grows if it still cannot hold them plus `payload_size`.
    ///
    /// The cursors are left in place; `produce` advances the write cursor.
    pub(crate) fn prepare(&mut self, payload_size: usize) -> &mut [u8] {
        let free_tail = self.buf.len() - self.produce_idx;
        if free_tail < payload_size {
            if self.produce_idx == self.consume_idx {
                // Nothing pending: rewind to the start instead of copying.
                if self.buf.len() < payload_size {
                    self.buf.resize(payload_size, 0);
                }
                self.consume_idx = 0;
                self.produce_idx = 0;
            } else {
                // Stage the pending bytes in `tmp`, then copy them to the front.
                let pending = self.produce_idx - self.consume_idx;
                self.tmp.resize(pending, 0);
                self.tmp
                    .copy_from_slice(&self.buf[self.consume_idx..self.produce_idx]);
                let needed = pending + payload_size;
                if needed > self.buf.len() {
                    self.buf.resize(needed, 0);
                }
                self.buf[..pending].copy_from_slice(&self.tmp);
                self.consume_idx = 0;
                self.produce_idx = pending;
            }
        }
        let start = self.produce_idx;
        &mut self.buf[start..start + payload_size]
    }

    /// Pending bytes: written but not yet consumed.
    pub(crate) fn ava_data(&self) -> &[u8] {
        let (start, end) = (self.consume_idx, self.produce_idx);
        &self.buf[start..end]
    }

    /// Mutable view of the pending bytes.
    pub(crate) fn ava_mut_data(&mut self) -> &mut [u8] {
        let (start, end) = (self.consume_idx, self.produce_idx);
        &mut self.buf[start..end]
    }

    /// Marks `num` more bytes after the pending region as written.
    pub(crate) fn produce(&mut self, num: usize) {
        self.produce_idx += num;
    }

    /// Marks `num` bytes at the front of the pending region as consumed.
    pub(crate) fn consume(&mut self, num: usize) {
        self.consume_idx += num;
    }
}
463
464#[allow(dead_code)]
466#[derive(Debug, Clone, Default)]
467pub struct FrameWriteState {
468 config: FrameConfig,
469 header_buf: [u8; 14],
470 buf: BytesMut,
471}
472
473impl FrameWriteState {
474 pub fn with_config(config: FrameConfig) -> Self {
476 Self {
477 config,
478 header_buf: [0; 14],
479 buf: BytesMut::new(),
480 }
481 }
482}
483
484pub fn default_handshake_handler(
486 req: http::Request<()>,
487) -> Result<(http::Request<()>, http::Response<String>), (http::Response<String>, WsError)> {
488 match standard_handshake_req_check(&req) {
489 Ok(_) => {
490 let key = req.headers().get("sec-websocket-key").unwrap();
491 let resp = http::Response::builder()
492 .version(http::Version::HTTP_11)
493 .status(http::StatusCode::SWITCHING_PROTOCOLS)
494 .header("Upgrade", "WebSocket")
495 .header("Connection", "Upgrade")
496 .header("Sec-WebSocket-Accept", cal_accept_key(key.as_bytes()))
497 .body(String::new())
498 .unwrap();
499 Ok((req, resp))
500 }
501 Err(e) => {
502 let resp = http::Response::builder()
503 .version(http::Version::HTTP_11)
504 .status(http::StatusCode::BAD_REQUEST)
505 .header("Content-Type", "text/html")
506 .body(e.to_string())
507 .unwrap();
508 Err((resp, e))
509 }
510 }
511}