1use std::io::{Read, Write};
2
3use http;
4use crate::{
5 codec::{apply_mask, FrameConfig, Split},
6 errors::{ProtocolError, WsError},
7 frame::{ctor_header, OpCode, OwnedFrame, SimplifiedHeader},
8 protocol::standard_handshake_resp_check,
9};
10use bytes::BytesMut;
11use rand::random;
12
13use super::{DeflateReadState, DeflateWriteState, PMDConfig};
14
15impl DeflateWriteState {
16 pub fn send_owned_frame<S: Write>(
18 &mut self,
19 stream: &mut S,
20 mut frame: OwnedFrame,
21 ) -> Result<(), WsError> {
22 if !frame.header().opcode().is_data() {
23 return self
24 .write_state
25 .send_owned_frame(stream, frame)
26 .map_err(WsError::IOError);
27 }
28 let prev_mask = frame.unmask();
29 let header = frame.header();
30 let frame: Result<OwnedFrame, WsError> = header
31 .opcode()
32 .is_data()
33 .then(|| self.com.as_mut())
34 .flatten()
35 .map(|handler| {
36 let mut compressed = Vec::with_capacity(frame.payload().len());
37 handler
38 .com
39 .compress(&[frame.payload()], &mut compressed)
40 .map_err(|code| WsError::CompressFailed(code.to_string()))?;
41 compressed.truncate(compressed.len() - 4);
42 let mut new = OwnedFrame::new(header.opcode(), prev_mask, &compressed);
43 let header = new.header_mut();
44 header.set_rsv1(true);
45 header.set_fin(header.fin());
46
47 if (self.is_server && handler.config.server_no_context_takeover)
48 || (!self.is_server && handler.config.client_no_context_takeover)
49 {
50 handler
51 .com
52 .reset()
53 .map_err(|code| WsError::CompressFailed(code.to_string()))?;
54 tracing::trace!("reset compressor");
55 }
56 Ok(new)
57 })
58 .unwrap_or_else(|| {
59 if let Some(mask) = prev_mask {
60 frame.mask(mask);
61 }
62 Ok(frame)
63 });
64 self.write_state
65 .send_owned_frame(stream, frame?)
66 .map_err(WsError::IOError)
67 }
68
69 pub fn send<S: Write>(
73 &mut self,
74 stream: &mut S,
75 code: OpCode,
76 payload: &[u8],
77 ) -> Result<(), WsError> {
78 let mask_send = self.config.mask_send_frame;
79 let mask_fn = || {
80 if mask_send {
81 Some(random())
82 } else {
83 None
84 }
85 };
86 if payload.is_empty() {
87 let mask = mask_fn();
88 let frame = OwnedFrame::new(code, mask, &[]);
89 return self.send_owned_frame(stream, frame);
90 }
91
92 let chunk_size = if self.config.auto_fragment_size > 0 {
93 self.config.auto_fragment_size
94 } else {
95 payload.len()
96 };
97 let parts: Vec<&[u8]> = payload.chunks(chunk_size).collect();
98 let total = parts.len();
99 for (idx, chunk) in parts.into_iter().enumerate() {
100 let fin = idx + 1 == total;
101 let mask = mask_fn();
102 match (self.com.as_mut(), code.is_data()) {
103 (Some(handler), true) => {
104 let mut output = vec![];
105 handler
106 .com
107 .compress(&[chunk], &mut output)
108 .map_err(|code| WsError::CompressFailed(code.to_string()))?;
109 output.truncate(output.len() - 4);
110 let header = ctor_header(
111 &mut self.header_buf,
112 fin,
113 true,
114 false,
115 false,
116 mask,
117 code,
118 output.len() as u64,
119 );
120 stream.write_all(header)?;
121 if let Some(mask) = mask {
122 apply_mask(&mut output, mask)
123 };
124 stream.write_all(&output)?;
125 if (self.is_server && handler.config.server_no_context_takeover)
126 || (!self.is_server && handler.config.client_no_context_takeover)
127 {
128 handler
129 .com
130 .reset()
131 .map_err(|code| WsError::CompressFailed(code.to_string()))?;
132 tracing::trace!("reset compressor");
133 }
134 }
135 _ => {
136 let header = ctor_header(
137 &mut self.header_buf,
138 fin,
139 false,
140 false,
141 false,
142 mask,
143 code,
144 chunk.len() as u64,
145 );
146 stream.write_all(header)?;
147 if let Some(mask) = mask {
148 let mut data = BytesMut::from_iter(chunk);
149 apply_mask(&mut data, mask);
150 stream.write_all(&data)?;
151 } else {
152 stream.write_all(chunk)?;
153 }
154 }
155 }
156 }
157 Ok(())
158 }
159}
160
161impl DeflateReadState {
162 fn receive_one<S: Read>(
163 &mut self,
164 stream: &mut S,
165 ) -> Result<(SimplifiedHeader, Vec<u8>), WsError> {
166 let (mut header, data) = self.read_state.receive(stream)?;
167 let data = data.to_vec();
168 let compressed = header.rsv1;
169 let is_data_frame = header.code.is_data();
170 if compressed && !is_data_frame {
171 return Err(WsError::ProtocolError {
172 close_code: 1002,
173 error: ProtocolError::CompressedControlFrame,
174 });
175 }
176 if !is_data_frame || !compressed {
177 return Ok((header, data));
178 }
179 let frame = match self.de.as_mut() {
180 Some(handler) => {
181 let mut de_data = vec![];
182 handler
183 .de
184 .de_compress(&[&data, &[0, 0, 255, 255]], &mut de_data)
185 .map_err(|code| WsError::DeCompressFailed(code.to_string()))?;
186 if (self.is_server && handler.config.server_no_context_takeover)
187 || (!self.is_server && handler.config.client_no_context_takeover)
188 {
189 handler
190 .de
191 .reset()
192 .map_err(|code| WsError::DeCompressFailed(code.to_string()))?;
193 tracing::trace!("reset decompressor state");
194 }
195 de_data
196 }
197 None => {
198 if header.rsv1 {
199 return Err(WsError::DeCompressFailed(
200 "extension not enabled but got compressed frame".into(),
201 ));
202 } else {
203 data
204 }
205 }
206 };
207 header.rsv1 = false;
208 Ok((header, frame))
209 }
210
211 pub fn receive<S: Read>(
213 &mut self,
214 stream: &mut S,
215 ) -> Result<(SimplifiedHeader, &[u8]), WsError> {
216 loop {
217 let (mut header, mut data) = self.receive_one(stream)?;
218 if !self.config.merge_frame {
219 self.fragmented_data.clear();
220 self.fragmented_data.append(&mut data);
221 break Ok((header, &self.fragmented_data));
222 }
223 match header.code {
224 OpCode::Continue => {
225 if !self.fragmented {
226 return Err(WsError::ProtocolError {
227 close_code: 1002,
228 error: ProtocolError::MissInitialFragmentedFrame,
229 });
230 }
231 let fin = header.fin;
232 self.fragmented_data.extend_from_slice(&data);
233 if fin {
234 self.fragmented = false;
235 header.code = self.fragmented_type;
236 break Ok((header, &self.fragmented_data));
237 } else {
238 continue;
239 }
240 }
241 OpCode::Text | OpCode::Binary => {
242 if self.fragmented {
243 return Err(WsError::ProtocolError {
244 close_code: 1002,
245 error: ProtocolError::NotContinueFrameAfterFragmented,
246 });
247 }
248 if !header.fin {
249 self.fragmented = true;
250 self.fragmented_type = header.code;
251 if header.code == OpCode::Text
252 && self.config.validate_utf8.is_fast_fail()
253 && simdutf8::basic::from_utf8(&data).is_err()
254 {
255 return Err(WsError::ProtocolError {
256 close_code: 1007,
257 error: ProtocolError::InvalidUtf8,
258 });
259 }
260 self.fragmented_data.clear();
261 self.fragmented_data.extend_from_slice(&data);
262 continue;
263 } else {
264 if header.code == OpCode::Text
265 && self.config.validate_utf8.should_check()
266 && simdutf8::basic::from_utf8(&data).is_err()
267 {
268 return Err(WsError::ProtocolError {
269 close_code: 1007,
270 error: ProtocolError::InvalidUtf8,
271 });
272 }
273 self.fragmented_data.clear();
274 self.fragmented_data.extend_from_slice(&data);
275 break Ok((header, &self.fragmented_data));
276 }
277 }
278 OpCode::Close | OpCode::Ping | OpCode::Pong => {
279 self.control_buf = data;
280 break Ok((header, &self.control_buf));
281 }
282 _ => break Err(WsError::UnsupportedFrame(header.code)),
283 }
284 }
285 }
286}
287
288pub struct DeflateCodec<S: Read + Write> {
290 read_state: DeflateReadState,
291 write_state: DeflateWriteState,
292 stream: S,
293}
294
295impl<S: Read + Write> DeflateCodec<S> {
296 pub fn new(
298 stream: S,
299 frame_config: FrameConfig,
300 pmd_config: Option<PMDConfig>,
301 is_server: bool,
302 ) -> Self {
303 let read_state =
304 DeflateReadState::with_config(frame_config.clone(), pmd_config.clone(), is_server);
305 let write_state = DeflateWriteState::with_config(frame_config, pmd_config, is_server);
306 Self {
307 read_state,
308 write_state,
309 stream,
310 }
311 }
312
313 pub fn factory(req: http::Request<()>, stream: S) -> Result<Self, WsError> {
315 let mut pmd_confs: Vec<PMDConfig> = vec![];
316 for (k, v) in req.headers() {
317 if k.as_str().to_lowercase() == "sec-websocket-extensions" {
318 if let Ok(s) = v.to_str() {
319 match PMDConfig::parse_str(s) {
320 Ok(mut conf) => {
321 pmd_confs.append(&mut conf);
322 }
323 Err(e) => return Err(WsError::HandShakeFailed(e)),
324 }
325 }
326 }
327 }
328 let mut pmd_conf = pmd_confs.pop();
329 if let Some(conf) = pmd_conf.as_mut() {
330 let min = conf.client_max_window_bits.min(conf.server_max_window_bits);
331 conf.client_max_window_bits = min;
332 conf.server_max_window_bits = min;
333 }
334 tracing::debug!("use deflate config {:?}", pmd_conf);
335
336 let frame_conf = FrameConfig {
337 mask_send_frame: false,
338 ..Default::default()
339 };
340 let codec = DeflateCodec::new(stream, frame_conf, pmd_conf, true);
341 Ok(codec)
342 }
343
344 pub fn check_fn(key: String, resp: http::Response<()>, stream: S) -> Result<Self, WsError> {
346 standard_handshake_resp_check(key.as_bytes(), &resp)?;
347 let mut pmd_confs: Vec<PMDConfig> = vec![];
348 for (k, v) in resp.headers() {
349 if k.as_str().to_lowercase() == "sec-websocket-extensions" {
350 if let Ok(s) = v.to_str() {
351 match PMDConfig::parse_str(s) {
352 Ok(mut conf) => {
353 pmd_confs.append(&mut conf);
354 }
355 Err(e) => return Err(WsError::HandShakeFailed(e)),
356 }
357 }
358 }
359 }
360 let mut pmd_conf = pmd_confs.pop();
361 if let Some(conf) = pmd_conf.as_mut() {
362 let min = conf.client_max_window_bits.min(conf.server_max_window_bits);
363 conf.client_max_window_bits = min;
364 conf.server_max_window_bits = min;
365 }
366 tracing::debug!("use deflate config: {:?}", pmd_conf);
367 let codec = DeflateCodec::new(stream, Default::default(), pmd_conf, false);
368 Ok(codec)
369 }
370
371 pub fn stream_mut(&mut self) -> &mut S {
373 &mut self.stream
374 }
375
376 pub fn receive(&mut self) -> Result<(SimplifiedHeader, &[u8]), WsError> {
378 self.read_state.receive(&mut self.stream)
379 }
380
381 pub fn send_owned_frame(&mut self, frame: OwnedFrame) -> Result<(), WsError> {
383 self.write_state.send_owned_frame(&mut self.stream, frame)
384 }
385
386 pub fn send(&mut self, code: OpCode, payload: &[u8]) -> Result<(), WsError> {
390 self.write_state.send(&mut self.stream, code, payload)
391 }
392
393 pub fn text(&mut self, text: &str) -> Result<(), WsError> {
395 self.write_state
396 .send(&mut self.stream, OpCode::Text, text.as_bytes())
397 }
398
399 pub fn binary(&mut self, data: &[u8]) -> Result<(), WsError> {
401 self.send(OpCode::Binary, data)
402 }
403
404 pub fn ping(&mut self, data: &[u8]) -> Result<(), WsError> {
406 self.send(OpCode::Ping, data)
407 }
408
409 pub fn pong(&mut self, data: &[u8]) -> Result<(), WsError> {
411 self.send(OpCode::Pong, data)
412 }
413
414 pub fn close(&mut self, code: u16, msg: &[u8]) -> Result<(), WsError> {
416 let mut data = code.to_be_bytes().to_vec();
417 data.extend_from_slice(msg);
418 self.send(OpCode::Close, &data)
419 }
420
421 pub fn flush(&mut self) -> Result<(), WsError> {
423 self.stream.flush().map_err(WsError::IOError)
424 }
425}
426
427pub struct DeflateRecv<S: Read> {
429 stream: S,
430 read_state: DeflateReadState,
431}
432
433impl<S: Read> DeflateRecv<S> {
434 pub fn new(stream: S, read_state: DeflateReadState) -> Self {
436 Self { stream, read_state }
437 }
438
439 pub fn stream_mut(&mut self) -> &mut S {
441 &mut self.stream
442 }
443
444 pub fn receive(&mut self) -> Result<(SimplifiedHeader, &[u8]), WsError> {
446 self.read_state.receive(&mut self.stream)
447 }
448}
449
450pub struct DeflateSend<S: Write> {
452 stream: S,
453 write_state: DeflateWriteState,
454}
455
456impl<S: Write> DeflateSend<S> {
457 pub fn new(stream: S, write_state: DeflateWriteState) -> Self {
459 Self {
460 stream,
461 write_state,
462 }
463 }
464
465 pub fn stream_mut(&mut self) -> &mut S {
467 &mut self.stream
468 }
469
470 pub fn send_owned_frame(&mut self, frame: OwnedFrame) -> Result<(), WsError> {
472 self.write_state.send_owned_frame(&mut self.stream, frame)
473 }
474
475 pub fn send(&mut self, code: OpCode, payload: &[u8]) -> Result<(), WsError> {
479 self.write_state.send(&mut self.stream, code, payload)
480 }
481
482 pub fn text(&mut self, text: &str) -> Result<(), WsError> {
484 self.write_state
485 .send(&mut self.stream, OpCode::Text, text.as_bytes())
486 }
487
488 pub fn binary(&mut self, data: &[u8]) -> Result<(), WsError> {
490 self.send(OpCode::Binary, data)
491 }
492
493 pub fn ping(&mut self, data: &[u8]) -> Result<(), WsError> {
495 self.send(OpCode::Ping, data)
496 }
497
498 pub fn pong(&mut self, data: &[u8]) -> Result<(), WsError> {
500 self.send(OpCode::Pong, data)
501 }
502
503 pub fn close(&mut self, code: u16, msg: &[u8]) -> Result<(), WsError> {
505 let mut data = code.to_be_bytes().to_vec();
506 data.extend_from_slice(msg);
507 self.send(OpCode::Close, &data)
508 }
509
510 pub fn flush(&mut self) -> Result<(), WsError> {
512 self.stream.flush().map_err(WsError::IOError)
513 }
514}
515
516impl<R, W, S> DeflateCodec<S>
517where
518 R: Read,
519 W: Write,
520 S: Read + Write + Split<R = R, W = W>,
521{
522 pub fn split(self) -> (DeflateRecv<R>, DeflateSend<W>) {
524 let DeflateCodec {
525 stream,
526 read_state,
527 write_state,
528 } = self;
529 let (read, write) = stream.split();
530 (
531 DeflateRecv::new(read, read_state),
532 DeflateSend::new(write, write_state),
533 )
534 }
535}