ws_tool/codec/frame/
blocking.rs

1use super::{FrameConfig, FrameReadState, FrameWriteState};
2use http;
3use crate::{
4    codec::{apply_mask, Split},
5    errors::WsError,
6    frame::{ctor_header, header_len, OpCode, OwnedFrame, SimplifiedHeader},
7    protocol::standard_handshake_resp_check,
8};
9use bytes::BytesMut;
10use std::{
11    io::{IoSlice, Read, Write},
12    ops::Range,
13};
14
15type IOResult<T> = std::io::Result<T>;
16
17impl FrameReadState {
18    /// **NOTE** masked frame has already been unmasked
19    pub fn receive<S: Read>(
20        &mut self,
21        stream: &mut S,
22    ) -> Result<(SimplifiedHeader, &[u8]), WsError> {
23        if self.config.merge_frame {
24            loop {
25                let (mut header, range) = self.read_one_frame(stream)?;
26                if let Some(merged) = self
27                    .check_frame(header, range.clone())
28                    .and_then(|_| self.merge_frame(header, range.clone()))?
29                {
30                    if merged {
31                        header.code = self.fragmented_type;
32                        break Ok((header, &self.fragmented_data));
33                    } else {
34                        break Ok((header, &self.buf.buf[range]));
35                    }
36                }
37            }
38        } else {
39            let (header, range) = self.read_one_frame(stream)?;
40            self.check_frame(header, range.clone())?;
41            Ok((header, &self.buf.buf[range]))
42        }
43    }
44
45    #[inline]
46    fn read_one_frame<S: Read>(
47        &mut self,
48        stream: &mut S,
49    ) -> Result<(SimplifiedHeader, Range<usize>), WsError> {
50        while !self.is_header_ok() {
51            self.poll(stream)?;
52        }
53        let (header_len, payload_len, total_len) = self.parse_frame_header()?;
54        self.poll_one_frame(stream, total_len)?;
55        Ok(self.consume_frame(header_len, payload_len, total_len))
56    }
57
58    #[inline]
59    fn poll<S: Read>(&mut self, stream: &mut S) -> std::io::Result<usize> {
60        let buf = self.buf.prepare(self.config.resize_size);
61        let count = stream.read(buf)?;
62        self.buf.produce(count);
63        if count == 0 {
64            return Err(std::io::Error::new(
65                std::io::ErrorKind::ConnectionAborted,
66                "read eof",
67            ));
68        }
69        Ok(count)
70    }
71
72    #[inline]
73    fn poll_one_frame<S: Read>(&mut self, stream: &mut S, size: usize) -> std::io::Result<()> {
74        let read_len = self.buf.ava_data().len();
75        if read_len < size {
76            let buf = self.buf.prepare(size - read_len);
77            stream.read_exact(buf)?;
78            self.buf.produce(size - read_len);
79        }
80        Ok(())
81    }
82}
83
84impl FrameWriteState {
85    // DOUBT return error if payload len data >= 126 ?
86    /// send immutable payload
87    ///
88    /// if need to mask, copy data to inner buffer and then apply mask
89    ///
90    /// will auto fragment if auto_fragment_size > 0
91    pub fn send<S: Write>(
92        &mut self,
93        stream: &mut S,
94        opcode: OpCode,
95        payload: &[u8],
96    ) -> IOResult<()> {
97        if payload.is_empty() {
98            let mask = if self.config.mask_send_frame {
99                Some(rand::random())
100            } else {
101                None
102            };
103            let header = ctor_header(
104                &mut self.header_buf,
105                true,
106                false,
107                false,
108                false,
109                mask,
110                opcode,
111                0,
112            );
113            stream.write_all(header)?;
114            return Ok(());
115        }
116        if self.config.auto_fragment_size > 0 && self.config.auto_fragment_size < payload.len() {
117            let chunk_size = self.config.auto_fragment_size;
118            let parts: Vec<&[u8]> = payload.chunks(chunk_size).collect();
119            let total = parts.len();
120            let single_header_len = header_len(self.config.mask_send_frame, chunk_size as u64);
121            let mut header_buf_len = single_header_len * parts.len();
122            let left = payload.len() % chunk_size;
123            if left != 0 {
124                header_buf_len += header_len(self.config.mask_send_frame, left as u64);
125            }
126            let total_bytes = payload.len() + header_buf_len;
127            if self.config.mask_send_frame {
128                if self.buf.len() < total_bytes {
129                    self.buf.resize(total_bytes, 0);
130                }
131                parts.iter().enumerate().for_each(|(idx, chunk)| {
132                    let fin = idx + 1 == total;
133                    let s_idx = idx * single_header_len;
134                    let mask = rand::random();
135                    let header_len = ctor_header(
136                        &mut self.buf[s_idx..],
137                        fin,
138                        false,
139                        false,
140                        false,
141                        mask,
142                        opcode,
143                        chunk.len() as u64,
144                    )
145                    .len();
146                    let slice = &mut self.buf[(s_idx + header_len)..];
147                    slice.copy_from_slice(chunk);
148                    apply_mask(slice, mask);
149                });
150                stream.write_all(&self.buf[..total_bytes])?;
151            } else {
152                if self.buf.len() < header_buf_len {
153                    self.buf.resize(header_buf_len, 0);
154                }
155                let mut slices = Vec::with_capacity(total * 2);
156                parts.iter().enumerate().for_each(|(idx, chunk)| {
157                    let fin = idx + 1 == total;
158                    let s_idx = idx * chunk_size;
159                    ctor_header(
160                        &mut self.buf[s_idx..],
161                        fin,
162                        false,
163                        false,
164                        false,
165                        None,
166                        opcode,
167                        chunk.len() as u64,
168                    );
169                });
170                parts.iter().enumerate().for_each(|(idx, chunk)| {
171                    let fin = idx + 1 == total;
172                    if fin {
173                        slices.push(IoSlice::new(&self.buf[(idx * single_header_len)..]))
174                    } else {
175                        slices.push(IoSlice::new(
176                            &self.buf[(idx * single_header_len)..(idx + 1) * single_header_len],
177                        ))
178                    }
179                    slices.push(IoSlice::new(chunk));
180                });
181                let num = stream.write_vectored(&slices)?;
182                let remain = total_bytes - num;
183                if remain > 0 {
184                    if let Some(buf) = slices.last() {
185                        stream.write_all(&buf[(buf.len() - remain)..])?;
186                    }
187                }
188            }
189        } else if self.config.mask_send_frame {
190            let total_bytes = header_len(true, payload.len() as u64) + payload.len();
191            let mask: [u8; 4] = rand::random();
192            let header = ctor_header(
193                &mut self.header_buf,
194                true,
195                false,
196                false,
197                false,
198                mask,
199                opcode,
200                payload.len() as u64,
201            );
202            if self.buf.len() < payload.len() {
203                self.buf.resize(payload.len(), 0)
204            }
205            self.buf[..(payload.len())].copy_from_slice(payload);
206            apply_mask(&mut self.buf[..(payload.len())], mask);
207            let num = stream.write_vectored(&[
208                IoSlice::new(header),
209                IoSlice::new(&self.buf[..(payload.len())]),
210            ])?;
211            let remain = total_bytes - num;
212            if remain > 0 {
213                stream.write_all(&self.buf[(payload.len() - remain)..(payload.len())])?;
214            }
215        } else {
216            let total_bytes = header_len(false, payload.len() as u64) + payload.len();
217            let header = ctor_header(
218                &mut self.header_buf,
219                true,
220                false,
221                false,
222                false,
223                None,
224                opcode,
225                payload.len() as u64,
226            );
227            // if self.buf.len() < payload.len() {
228            //     self.buf.resize(payload.len(), 0)
229            // }
230            let num = stream.write_vectored(&[IoSlice::new(header), IoSlice::new(payload)])?;
231            let remain = total_bytes - num;
232            if remain > 0 {
233                stream.write_all(&payload[(payload.len() - remain)..])?;
234            }
235        };
236
237        if self.config.renew_buf_on_write {
238            self.buf = BytesMut::new()
239        }
240        Ok(())
241    }
242
243    pub(crate) fn send_owned_frame<S: Write>(
244        &mut self,
245        stream: &mut S,
246        frame: OwnedFrame,
247    ) -> IOResult<()> {
248        let header = IoSlice::new(&frame.header().0);
249        let body = IoSlice::new(frame.payload());
250        let total = header.len() + body.len();
251        let num = stream.write_vectored(&[header, body])?;
252        let remain = total - num;
253        if remain > 0 {
254            stream.write_all(&body[(body.len() - remain)..])?
255        }
256        Ok(())
257    }
258}
259
260/// recv part of websocket stream
261pub struct FrameRecv<S: Read> {
262    stream: S,
263    read_state: FrameReadState,
264}
265
266impl<S: Read> FrameRecv<S> {
267    /// construct method
268    pub fn new(stream: S, read_state: FrameReadState) -> Self {
269        Self { stream, read_state }
270    }
271
272    /// receive a frame
273    pub fn receive(&mut self) -> Result<(SimplifiedHeader, &[u8]), WsError> {
274        self.read_state.receive(&mut self.stream)
275    }
276}
277
278/// send part of websocket frame
279pub struct FrameSend<S: Write> {
280    stream: S,
281    write_state: FrameWriteState,
282}
283
284impl<S: Write> FrameSend<S> {
285    /// construct method
286    pub fn new(stream: S, write_state: FrameWriteState) -> Self {
287        Self {
288            stream,
289            write_state,
290        }
291    }
292
293    /// send payload
294    ///
295    /// will auto fragment if auto_fragment_size > 0
296    pub fn send(&mut self, code: OpCode, payload: &[u8]) -> Result<(), WsError> {
297        self.write_state
298            .send(&mut self.stream, code, payload)
299            .map_err(WsError::IOError)
300    }
301
302    /// send a read frame, **this method will not check validation of frame and do not fragment**
303    pub fn send_owned_frame(&mut self, frame: OwnedFrame) -> Result<(), WsError> {
304        self.write_state
305            .send_owned_frame(&mut self.stream, frame)
306            .map_err(WsError::IOError)
307    }
308
309    /// flush stream to ensure all data are send
310    pub fn flush(&mut self) -> Result<(), WsError> {
311        self.stream.flush().map_err(WsError::IOError)
312    }
313}
314
315/// recv/send websocket frame
316pub struct FrameCodec<S: Read + Write> {
317    /// underlying transport stream
318    pub stream: S,
319    /// read state
320    pub read_state: FrameReadState,
321    /// write state
322    pub write_state: FrameWriteState,
323}
324
325impl<S: Read + Write> FrameCodec<S> {
326    /// construct method
327    pub fn new(stream: S) -> Self {
328        Self {
329            stream,
330            read_state: FrameReadState::default(),
331            write_state: FrameWriteState::default(),
332        }
333    }
334
335    /// construct with stream and config
336    pub fn new_with(stream: S, config: FrameConfig) -> Self {
337        Self {
338            stream,
339            read_state: FrameReadState::with_config(config.clone()),
340            write_state: FrameWriteState::with_config(config),
341        }
342    }
343
344    /// get mutable underlying stream
345    pub fn stream_mut(&mut self) -> &mut S {
346        &mut self.stream
347    }
348
349    /// used for server side to construct a new server
350    pub fn factory(_req: http::Request<()>, stream: S) -> Result<Self, WsError> {
351        let config = FrameConfig {
352            mask_send_frame: false,
353            ..Default::default()
354        };
355        Ok(Self::new_with(stream, config))
356    }
357
358    /// used to client side to construct a new client
359    pub fn check_fn(key: String, resp: http::Response<()>, stream: S) -> Result<Self, WsError> {
360        standard_handshake_resp_check(key.as_bytes(), &resp)?;
361        Ok(Self::new_with(stream, Default::default()))
362    }
363
364    /// receive a frame
365    pub fn receive(&mut self) -> Result<(SimplifiedHeader, &[u8]), WsError> {
366        self.read_state.receive(&mut self.stream)
367    }
368
369    /// send data, **will copy data if need mask**
370    pub fn send(&mut self, code: OpCode, payload: &[u8]) -> Result<(), WsError> {
371        self.write_state
372            .send(&mut self.stream, code, payload)
373            .map_err(WsError::IOError)
374    }
375
376    /// send a read frame, **this method will not check validation of frame and do not fragment**
377    pub fn send_owned_frame(&mut self, frame: OwnedFrame) -> Result<(), WsError> {
378        self.write_state
379            .send_owned_frame(&mut self.stream, frame)
380            .map_err(WsError::IOError)
381    }
382
383    /// flush stream to ensure all data are send
384    pub fn flush(&mut self) -> Result<(), WsError> {
385        self.stream.flush().map_err(WsError::IOError)
386    }
387}
388
389impl<R, W, S> FrameCodec<S>
390where
391    R: Read,
392    W: Write,
393    S: Read + Write + Split<R = R, W = W>,
394{
395    /// split codec to recv and send parts
396    pub fn split(self) -> (FrameRecv<R>, FrameSend<W>) {
397        let FrameCodec {
398            stream,
399            read_state,
400            write_state,
401        } = self;
402        let (read, write) = stream.split();
403        (
404            FrameRecv::new(read, read_state),
405            FrameSend::new(write, write_state),
406        )
407    }
408}