ws_tool/codec/deflate/blocking.rs

1use std::io::{Read, Write};
2
3use http;
4use crate::{
5    codec::{apply_mask, FrameConfig, Split},
6    errors::{ProtocolError, WsError},
7    frame::{ctor_header, OpCode, OwnedFrame, SimplifiedHeader},
8    protocol::standard_handshake_resp_check,
9};
10use bytes::BytesMut;
11use rand::random;
12
13use super::{DeflateReadState, DeflateWriteState, PMDConfig};
14
15impl DeflateWriteState {
16    /// send a read frame, **this method will not check validation of frame and do not fragment**
17    pub fn send_owned_frame<S: Write>(
18        &mut self,
19        stream: &mut S,
20        mut frame: OwnedFrame,
21    ) -> Result<(), WsError> {
22        if !frame.header().opcode().is_data() {
23            return self
24                .write_state
25                .send_owned_frame(stream, frame)
26                .map_err(WsError::IOError);
27        }
28        let prev_mask = frame.unmask();
29        let header = frame.header();
30        let frame: Result<OwnedFrame, WsError> = header
31            .opcode()
32            .is_data()
33            .then(|| self.com.as_mut())
34            .flatten()
35            .map(|handler| {
36                let mut compressed = Vec::with_capacity(frame.payload().len());
37                handler
38                    .com
39                    .compress(&[frame.payload()], &mut compressed)
40                    .map_err(|code| WsError::CompressFailed(code.to_string()))?;
41                compressed.truncate(compressed.len() - 4);
42                let mut new = OwnedFrame::new(header.opcode(), prev_mask, &compressed);
43                let header = new.header_mut();
44                header.set_rsv1(true);
45                header.set_fin(header.fin());
46
47                if (self.is_server && handler.config.server_no_context_takeover)
48                    || (!self.is_server && handler.config.client_no_context_takeover)
49                {
50                    handler
51                        .com
52                        .reset()
53                        .map_err(|code| WsError::CompressFailed(code.to_string()))?;
54                    tracing::trace!("reset compressor");
55                }
56                Ok(new)
57            })
58            .unwrap_or_else(|| {
59                if let Some(mask) = prev_mask {
60                    frame.mask(mask);
61                }
62                Ok(frame)
63            });
64        self.write_state
65            .send_owned_frame(stream, frame?)
66            .map_err(WsError::IOError)
67    }
68
69    /// send payload
70    ///
71    /// will auto fragment **before compression** if auto_fragment_size > 0
72    pub fn send<S: Write>(
73        &mut self,
74        stream: &mut S,
75        code: OpCode,
76        payload: &[u8],
77    ) -> Result<(), WsError> {
78        let mask_send = self.config.mask_send_frame;
79        let mask_fn = || {
80            if mask_send {
81                Some(random())
82            } else {
83                None
84            }
85        };
86        if payload.is_empty() {
87            let mask = mask_fn();
88            let frame = OwnedFrame::new(code, mask, &[]);
89            return self.send_owned_frame(stream, frame);
90        }
91
92        let chunk_size = if self.config.auto_fragment_size > 0 {
93            self.config.auto_fragment_size
94        } else {
95            payload.len()
96        };
97        let parts: Vec<&[u8]> = payload.chunks(chunk_size).collect();
98        let total = parts.len();
99        for (idx, chunk) in parts.into_iter().enumerate() {
100            let fin = idx + 1 == total;
101            let mask = mask_fn();
102            match (self.com.as_mut(), code.is_data()) {
103                (Some(handler), true) => {
104                    let mut output = vec![];
105                    handler
106                        .com
107                        .compress(&[chunk], &mut output)
108                        .map_err(|code| WsError::CompressFailed(code.to_string()))?;
109                    output.truncate(output.len() - 4);
110                    let header = ctor_header(
111                        &mut self.header_buf,
112                        fin,
113                        true,
114                        false,
115                        false,
116                        mask,
117                        code,
118                        output.len() as u64,
119                    );
120                    stream.write_all(header)?;
121                    if let Some(mask) = mask {
122                        apply_mask(&mut output, mask)
123                    };
124                    stream.write_all(&output)?;
125                    if (self.is_server && handler.config.server_no_context_takeover)
126                        || (!self.is_server && handler.config.client_no_context_takeover)
127                    {
128                        handler
129                            .com
130                            .reset()
131                            .map_err(|code| WsError::CompressFailed(code.to_string()))?;
132                        tracing::trace!("reset compressor");
133                    }
134                }
135                _ => {
136                    let header = ctor_header(
137                        &mut self.header_buf,
138                        fin,
139                        false,
140                        false,
141                        false,
142                        mask,
143                        code,
144                        chunk.len() as u64,
145                    );
146                    stream.write_all(header)?;
147                    if let Some(mask) = mask {
148                        let mut data = BytesMut::from_iter(chunk);
149                        apply_mask(&mut data, mask);
150                        stream.write_all(&data)?;
151                    } else {
152                        stream.write_all(chunk)?;
153                    }
154                }
155            }
156        }
157        Ok(())
158    }
159}
160
161impl DeflateReadState {
162    fn receive_one<S: Read>(
163        &mut self,
164        stream: &mut S,
165    ) -> Result<(SimplifiedHeader, Vec<u8>), WsError> {
166        let (mut header, data) = self.read_state.receive(stream)?;
167        let data = data.to_vec();
168        let compressed = header.rsv1;
169        let is_data_frame = header.code.is_data();
170        if compressed && !is_data_frame {
171            return Err(WsError::ProtocolError {
172                close_code: 1002,
173                error: ProtocolError::CompressedControlFrame,
174            });
175        }
176        if !is_data_frame || !compressed {
177            return Ok((header, data));
178        }
179        let frame = match self.de.as_mut() {
180            Some(handler) => {
181                let mut de_data = vec![];
182                handler
183                    .de
184                    .de_compress(&[&data, &[0, 0, 255, 255]], &mut de_data)
185                    .map_err(|code| WsError::DeCompressFailed(code.to_string()))?;
186                if (self.is_server && handler.config.server_no_context_takeover)
187                    || (!self.is_server && handler.config.client_no_context_takeover)
188                {
189                    handler
190                        .de
191                        .reset()
192                        .map_err(|code| WsError::DeCompressFailed(code.to_string()))?;
193                    tracing::trace!("reset decompressor state");
194                }
195                de_data
196            }
197            None => {
198                if header.rsv1 {
199                    return Err(WsError::DeCompressFailed(
200                        "extension not enabled but got compressed frame".into(),
201                    ));
202                } else {
203                    data
204                }
205            }
206        };
207        header.rsv1 = false;
208        Ok((header, frame))
209    }
210
211    /// receive a message
212    pub fn receive<S: Read>(
213        &mut self,
214        stream: &mut S,
215    ) -> Result<(SimplifiedHeader, &[u8]), WsError> {
216        loop {
217            let (mut header, mut data) = self.receive_one(stream)?;
218            if !self.config.merge_frame {
219                self.fragmented_data.clear();
220                self.fragmented_data.append(&mut data);
221                break Ok((header, &self.fragmented_data));
222            }
223            match header.code {
224                OpCode::Continue => {
225                    if !self.fragmented {
226                        return Err(WsError::ProtocolError {
227                            close_code: 1002,
228                            error: ProtocolError::MissInitialFragmentedFrame,
229                        });
230                    }
231                    let fin = header.fin;
232                    self.fragmented_data.extend_from_slice(&data);
233                    if fin {
234                        self.fragmented = false;
235                        header.code = self.fragmented_type;
236                        break Ok((header, &self.fragmented_data));
237                    } else {
238                        continue;
239                    }
240                }
241                OpCode::Text | OpCode::Binary => {
242                    if self.fragmented {
243                        return Err(WsError::ProtocolError {
244                            close_code: 1002,
245                            error: ProtocolError::NotContinueFrameAfterFragmented,
246                        });
247                    }
248                    if !header.fin {
249                        self.fragmented = true;
250                        self.fragmented_type = header.code;
251                        if header.code == OpCode::Text
252                            && self.config.validate_utf8.is_fast_fail()
253                            && simdutf8::basic::from_utf8(&data).is_err()
254                        {
255                            return Err(WsError::ProtocolError {
256                                close_code: 1007,
257                                error: ProtocolError::InvalidUtf8,
258                            });
259                        }
260                        self.fragmented_data.clear();
261                        self.fragmented_data.extend_from_slice(&data);
262                        continue;
263                    } else {
264                        if header.code == OpCode::Text
265                            && self.config.validate_utf8.should_check()
266                            && simdutf8::basic::from_utf8(&data).is_err()
267                        {
268                            return Err(WsError::ProtocolError {
269                                close_code: 1007,
270                                error: ProtocolError::InvalidUtf8,
271                            });
272                        }
273                        self.fragmented_data.clear();
274                        self.fragmented_data.extend_from_slice(&data);
275                        break Ok((header, &self.fragmented_data));
276                    }
277                }
278                OpCode::Close | OpCode::Ping | OpCode::Pong => {
279                    self.control_buf = data;
280                    break Ok((header, &self.control_buf));
281                }
282                _ => break Err(WsError::UnsupportedFrame(header.code)),
283            }
284        }
285    }
286}
287
288/// recv/send deflate message
289pub struct DeflateCodec<S: Read + Write> {
290    read_state: DeflateReadState,
291    write_state: DeflateWriteState,
292    stream: S,
293}
294
295impl<S: Read + Write> DeflateCodec<S> {
296    /// construct method
297    pub fn new(
298        stream: S,
299        frame_config: FrameConfig,
300        pmd_config: Option<PMDConfig>,
301        is_server: bool,
302    ) -> Self {
303        let read_state =
304            DeflateReadState::with_config(frame_config.clone(), pmd_config.clone(), is_server);
305        let write_state = DeflateWriteState::with_config(frame_config, pmd_config, is_server);
306        Self {
307            read_state,
308            write_state,
309            stream,
310        }
311    }
312
313    /// used for server side to construct a new server
314    pub fn factory(req: http::Request<()>, stream: S) -> Result<Self, WsError> {
315        let mut pmd_confs: Vec<PMDConfig> = vec![];
316        for (k, v) in req.headers() {
317            if k.as_str().to_lowercase() == "sec-websocket-extensions" {
318                if let Ok(s) = v.to_str() {
319                    match PMDConfig::parse_str(s) {
320                        Ok(mut conf) => {
321                            pmd_confs.append(&mut conf);
322                        }
323                        Err(e) => return Err(WsError::HandShakeFailed(e)),
324                    }
325                }
326            }
327        }
328        let mut pmd_conf = pmd_confs.pop();
329        if let Some(conf) = pmd_conf.as_mut() {
330            let min = conf.client_max_window_bits.min(conf.server_max_window_bits);
331            conf.client_max_window_bits = min;
332            conf.server_max_window_bits = min;
333        }
334        tracing::debug!("use deflate config {:?}", pmd_conf);
335
336        let frame_conf = FrameConfig {
337            mask_send_frame: false,
338            ..Default::default()
339        };
340        let codec = DeflateCodec::new(stream, frame_conf, pmd_conf, true);
341        Ok(codec)
342    }
343
344    /// used for client side to construct a new client
345    pub fn check_fn(key: String, resp: http::Response<()>, stream: S) -> Result<Self, WsError> {
346        standard_handshake_resp_check(key.as_bytes(), &resp)?;
347        let mut pmd_confs: Vec<PMDConfig> = vec![];
348        for (k, v) in resp.headers() {
349            if k.as_str().to_lowercase() == "sec-websocket-extensions" {
350                if let Ok(s) = v.to_str() {
351                    match PMDConfig::parse_str(s) {
352                        Ok(mut conf) => {
353                            pmd_confs.append(&mut conf);
354                        }
355                        Err(e) => return Err(WsError::HandShakeFailed(e)),
356                    }
357                }
358            }
359        }
360        let mut pmd_conf = pmd_confs.pop();
361        if let Some(conf) = pmd_conf.as_mut() {
362            let min = conf.client_max_window_bits.min(conf.server_max_window_bits);
363            conf.client_max_window_bits = min;
364            conf.server_max_window_bits = min;
365        }
366        tracing::debug!("use deflate config: {:?}", pmd_conf);
367        let codec = DeflateCodec::new(stream, Default::default(), pmd_conf, false);
368        Ok(codec)
369    }
370
371    /// get mutable underlying stream
372    pub fn stream_mut(&mut self) -> &mut S {
373        &mut self.stream
374    }
375
376    /// receive a message
377    pub fn receive(&mut self) -> Result<(SimplifiedHeader, &[u8]), WsError> {
378        self.read_state.receive(&mut self.stream)
379    }
380
381    /// send a read frame, **this method will not check validation of frame and do not fragment**
382    pub fn send_owned_frame(&mut self, frame: OwnedFrame) -> Result<(), WsError> {
383        self.write_state.send_owned_frame(&mut self.stream, frame)
384    }
385
386    /// send payload
387    ///
388    /// will auto fragment **before compression** if auto_fragment_size > 0
389    pub fn send(&mut self, code: OpCode, payload: &[u8]) -> Result<(), WsError> {
390        self.write_state.send(&mut self.stream, code, payload)
391    }
392
393    /// helper function to send text message
394    pub fn text(&mut self, text: &str) -> Result<(), WsError> {
395        self.write_state
396            .send(&mut self.stream, OpCode::Text, text.as_bytes())
397    }
398
399    /// helper function to send binary message
400    pub fn binary(&mut self, data: &[u8]) -> Result<(), WsError> {
401        self.send(OpCode::Binary, data)
402    }
403
404    /// helper function to send ping message
405    pub fn ping(&mut self, data: &[u8]) -> Result<(), WsError> {
406        self.send(OpCode::Ping, data)
407    }
408
409    /// helper function to send ping message
410    pub fn pong(&mut self, data: &[u8]) -> Result<(), WsError> {
411        self.send(OpCode::Pong, data)
412    }
413
414    /// helper method to send close message
415    pub fn close(&mut self, code: u16, msg: &[u8]) -> Result<(), WsError> {
416        let mut data = code.to_be_bytes().to_vec();
417        data.extend_from_slice(msg);
418        self.send(OpCode::Close, &data)
419    }
420
421    /// flush stream to ensure all data are send
422    pub fn flush(&mut self) -> Result<(), WsError> {
423        self.stream.flush().map_err(WsError::IOError)
424    }
425}
426
427/// recv part of deflate message
428pub struct DeflateRecv<S: Read> {
429    stream: S,
430    read_state: DeflateReadState,
431}
432
433impl<S: Read> DeflateRecv<S> {
434    /// construct method
435    pub fn new(stream: S, read_state: DeflateReadState) -> Self {
436        Self { stream, read_state }
437    }
438
439    /// get mutable underlying stream
440    pub fn stream_mut(&mut self) -> &mut S {
441        &mut self.stream
442    }
443
444    /// receive a frame
445    pub fn receive(&mut self) -> Result<(SimplifiedHeader, &[u8]), WsError> {
446        self.read_state.receive(&mut self.stream)
447    }
448}
449
450/// send part of deflate message
451pub struct DeflateSend<S: Write> {
452    stream: S,
453    write_state: DeflateWriteState,
454}
455
456impl<S: Write> DeflateSend<S> {
457    /// construct method
458    pub fn new(stream: S, write_state: DeflateWriteState) -> Self {
459        Self {
460            stream,
461            write_state,
462        }
463    }
464
465    /// get mutable underlying stream
466    pub fn stream_mut(&mut self) -> &mut S {
467        &mut self.stream
468    }
469
470    /// send a read frame, **this method will not check validation of frame and do not fragment**
471    pub fn send_owned_frame(&mut self, frame: OwnedFrame) -> Result<(), WsError> {
472        self.write_state.send_owned_frame(&mut self.stream, frame)
473    }
474
475    /// send payload
476    ///
477    /// will auto fragment **before compression** if auto_fragment_size > 0
478    pub fn send(&mut self, code: OpCode, payload: &[u8]) -> Result<(), WsError> {
479        self.write_state.send(&mut self.stream, code, payload)
480    }
481
482    /// helper function to send text message
483    pub fn text(&mut self, text: &str) -> Result<(), WsError> {
484        self.write_state
485            .send(&mut self.stream, OpCode::Text, text.as_bytes())
486    }
487
488    /// helper function to send binary message
489    pub fn binary(&mut self, data: &[u8]) -> Result<(), WsError> {
490        self.send(OpCode::Binary, data)
491    }
492
493    /// helper function to send ping message
494    pub fn ping(&mut self, data: &[u8]) -> Result<(), WsError> {
495        self.send(OpCode::Ping, data)
496    }
497
498    /// helper function to send ping message
499    pub fn pong(&mut self, data: &[u8]) -> Result<(), WsError> {
500        self.send(OpCode::Pong, data)
501    }
502
503    /// helper method to send close message
504    pub fn close(&mut self, code: u16, msg: &[u8]) -> Result<(), WsError> {
505        let mut data = code.to_be_bytes().to_vec();
506        data.extend_from_slice(msg);
507        self.send(OpCode::Close, &data)
508    }
509
510    /// flush stream to ensure all data are send
511    pub fn flush(&mut self) -> Result<(), WsError> {
512        self.stream.flush().map_err(WsError::IOError)
513    }
514}
515
516impl<R, W, S> DeflateCodec<S>
517where
518    R: Read,
519    W: Write,
520    S: Read + Write + Split<R = R, W = W>,
521{
522    /// split codec to recv and send parts
523    pub fn split(self) -> (DeflateRecv<R>, DeflateSend<W>) {
524        let DeflateCodec {
525            stream,
526            read_state,
527            write_state,
528        } = self;
529        let (read, write) = stream.split();
530        (
531            DeflateRecv::new(read, read_state),
532            DeflateSend::new(write, write_state),
533        )
534    }
535}