1use super::{FrameConfig, FrameReadState, FrameWriteState};
2use http;
3use crate::{
4 codec::{apply_mask, Split},
5 errors::WsError,
6 frame::{ctor_header, header_len, OpCode, OwnedFrame, SimplifiedHeader},
7 protocol::standard_handshake_resp_check,
8};
9use bytes::BytesMut;
10use std::{
11 io::{IoSlice, Read, Write},
12 ops::Range,
13};
14
/// Shorthand for results whose error type is `std::io::Error`.
type IOResult<T> = Result<T, std::io::Error>;
16
17impl FrameReadState {
18 pub fn receive<S: Read>(
20 &mut self,
21 stream: &mut S,
22 ) -> Result<(SimplifiedHeader, &[u8]), WsError> {
23 if self.config.merge_frame {
24 loop {
25 let (mut header, range) = self.read_one_frame(stream)?;
26 if let Some(merged) = self
27 .check_frame(header, range.clone())
28 .and_then(|_| self.merge_frame(header, range.clone()))?
29 {
30 if merged {
31 header.code = self.fragmented_type;
32 break Ok((header, &self.fragmented_data));
33 } else {
34 break Ok((header, &self.buf.buf[range]));
35 }
36 }
37 }
38 } else {
39 let (header, range) = self.read_one_frame(stream)?;
40 self.check_frame(header, range.clone())?;
41 Ok((header, &self.buf.buf[range]))
42 }
43 }
44
45 #[inline]
46 fn read_one_frame<S: Read>(
47 &mut self,
48 stream: &mut S,
49 ) -> Result<(SimplifiedHeader, Range<usize>), WsError> {
50 while !self.is_header_ok() {
51 self.poll(stream)?;
52 }
53 let (header_len, payload_len, total_len) = self.parse_frame_header()?;
54 self.poll_one_frame(stream, total_len)?;
55 Ok(self.consume_frame(header_len, payload_len, total_len))
56 }
57
58 #[inline]
59 fn poll<S: Read>(&mut self, stream: &mut S) -> std::io::Result<usize> {
60 let buf = self.buf.prepare(self.config.resize_size);
61 let count = stream.read(buf)?;
62 self.buf.produce(count);
63 if count == 0 {
64 return Err(std::io::Error::new(
65 std::io::ErrorKind::ConnectionAborted,
66 "read eof",
67 ));
68 }
69 Ok(count)
70 }
71
72 #[inline]
73 fn poll_one_frame<S: Read>(&mut self, stream: &mut S, size: usize) -> std::io::Result<()> {
74 let read_len = self.buf.ava_data().len();
75 if read_len < size {
76 let buf = self.buf.prepare(size - read_len);
77 stream.read_exact(buf)?;
78 self.buf.produce(size - read_len);
79 }
80 Ok(())
81 }
82}
83
84impl FrameWriteState {
85 pub fn send<S: Write>(
92 &mut self,
93 stream: &mut S,
94 opcode: OpCode,
95 payload: &[u8],
96 ) -> IOResult<()> {
97 if payload.is_empty() {
98 let mask = if self.config.mask_send_frame {
99 Some(rand::random())
100 } else {
101 None
102 };
103 let header = ctor_header(
104 &mut self.header_buf,
105 true,
106 false,
107 false,
108 false,
109 mask,
110 opcode,
111 0,
112 );
113 stream.write_all(header)?;
114 return Ok(());
115 }
116 if self.config.auto_fragment_size > 0 && self.config.auto_fragment_size < payload.len() {
117 let chunk_size = self.config.auto_fragment_size;
118 let parts: Vec<&[u8]> = payload.chunks(chunk_size).collect();
119 let total = parts.len();
120 let single_header_len = header_len(self.config.mask_send_frame, chunk_size as u64);
121 let mut header_buf_len = single_header_len * parts.len();
122 let left = payload.len() % chunk_size;
123 if left != 0 {
124 header_buf_len += header_len(self.config.mask_send_frame, left as u64);
125 }
126 let total_bytes = payload.len() + header_buf_len;
127 if self.config.mask_send_frame {
128 if self.buf.len() < total_bytes {
129 self.buf.resize(total_bytes, 0);
130 }
131 parts.iter().enumerate().for_each(|(idx, chunk)| {
132 let fin = idx + 1 == total;
133 let s_idx = idx * single_header_len;
134 let mask = rand::random();
135 let header_len = ctor_header(
136 &mut self.buf[s_idx..],
137 fin,
138 false,
139 false,
140 false,
141 mask,
142 opcode,
143 chunk.len() as u64,
144 )
145 .len();
146 let slice = &mut self.buf[(s_idx + header_len)..];
147 slice.copy_from_slice(chunk);
148 apply_mask(slice, mask);
149 });
150 stream.write_all(&self.buf[..total_bytes])?;
151 } else {
152 if self.buf.len() < header_buf_len {
153 self.buf.resize(header_buf_len, 0);
154 }
155 let mut slices = Vec::with_capacity(total * 2);
156 parts.iter().enumerate().for_each(|(idx, chunk)| {
157 let fin = idx + 1 == total;
158 let s_idx = idx * chunk_size;
159 ctor_header(
160 &mut self.buf[s_idx..],
161 fin,
162 false,
163 false,
164 false,
165 None,
166 opcode,
167 chunk.len() as u64,
168 );
169 });
170 parts.iter().enumerate().for_each(|(idx, chunk)| {
171 let fin = idx + 1 == total;
172 if fin {
173 slices.push(IoSlice::new(&self.buf[(idx * single_header_len)..]))
174 } else {
175 slices.push(IoSlice::new(
176 &self.buf[(idx * single_header_len)..(idx + 1) * single_header_len],
177 ))
178 }
179 slices.push(IoSlice::new(chunk));
180 });
181 let num = stream.write_vectored(&slices)?;
182 let remain = total_bytes - num;
183 if remain > 0 {
184 if let Some(buf) = slices.last() {
185 stream.write_all(&buf[(buf.len() - remain)..])?;
186 }
187 }
188 }
189 } else if self.config.mask_send_frame {
190 let total_bytes = header_len(true, payload.len() as u64) + payload.len();
191 let mask: [u8; 4] = rand::random();
192 let header = ctor_header(
193 &mut self.header_buf,
194 true,
195 false,
196 false,
197 false,
198 mask,
199 opcode,
200 payload.len() as u64,
201 );
202 if self.buf.len() < payload.len() {
203 self.buf.resize(payload.len(), 0)
204 }
205 self.buf[..(payload.len())].copy_from_slice(payload);
206 apply_mask(&mut self.buf[..(payload.len())], mask);
207 let num = stream.write_vectored(&[
208 IoSlice::new(header),
209 IoSlice::new(&self.buf[..(payload.len())]),
210 ])?;
211 let remain = total_bytes - num;
212 if remain > 0 {
213 stream.write_all(&self.buf[(payload.len() - remain)..(payload.len())])?;
214 }
215 } else {
216 let total_bytes = header_len(false, payload.len() as u64) + payload.len();
217 let header = ctor_header(
218 &mut self.header_buf,
219 true,
220 false,
221 false,
222 false,
223 None,
224 opcode,
225 payload.len() as u64,
226 );
227 let num = stream.write_vectored(&[IoSlice::new(header), IoSlice::new(payload)])?;
231 let remain = total_bytes - num;
232 if remain > 0 {
233 stream.write_all(&payload[(payload.len() - remain)..])?;
234 }
235 };
236
237 if self.config.renew_buf_on_write {
238 self.buf = BytesMut::new()
239 }
240 Ok(())
241 }
242
243 pub(crate) fn send_owned_frame<S: Write>(
244 &mut self,
245 stream: &mut S,
246 frame: OwnedFrame,
247 ) -> IOResult<()> {
248 let header = IoSlice::new(&frame.header().0);
249 let body = IoSlice::new(frame.payload());
250 let total = header.len() + body.len();
251 let num = stream.write_vectored(&[header, body])?;
252 let remain = total - num;
253 if remain > 0 {
254 stream.write_all(&body[(body.len() - remain)..])?
255 }
256 Ok(())
257 }
258}
259
260pub struct FrameRecv<S: Read> {
262 stream: S,
263 read_state: FrameReadState,
264}
265
266impl<S: Read> FrameRecv<S> {
267 pub fn new(stream: S, read_state: FrameReadState) -> Self {
269 Self { stream, read_state }
270 }
271
272 pub fn receive(&mut self) -> Result<(SimplifiedHeader, &[u8]), WsError> {
274 self.read_state.receive(&mut self.stream)
275 }
276}
277
278pub struct FrameSend<S: Write> {
280 stream: S,
281 write_state: FrameWriteState,
282}
283
284impl<S: Write> FrameSend<S> {
285 pub fn new(stream: S, write_state: FrameWriteState) -> Self {
287 Self {
288 stream,
289 write_state,
290 }
291 }
292
293 pub fn send(&mut self, code: OpCode, payload: &[u8]) -> Result<(), WsError> {
297 self.write_state
298 .send(&mut self.stream, code, payload)
299 .map_err(WsError::IOError)
300 }
301
302 pub fn send_owned_frame(&mut self, frame: OwnedFrame) -> Result<(), WsError> {
304 self.write_state
305 .send_owned_frame(&mut self.stream, frame)
306 .map_err(WsError::IOError)
307 }
308
309 pub fn flush(&mut self) -> Result<(), WsError> {
311 self.stream.flush().map_err(WsError::IOError)
312 }
313}
314
315pub struct FrameCodec<S: Read + Write> {
317 pub stream: S,
319 pub read_state: FrameReadState,
321 pub write_state: FrameWriteState,
323}
324
325impl<S: Read + Write> FrameCodec<S> {
326 pub fn new(stream: S) -> Self {
328 Self {
329 stream,
330 read_state: FrameReadState::default(),
331 write_state: FrameWriteState::default(),
332 }
333 }
334
335 pub fn new_with(stream: S, config: FrameConfig) -> Self {
337 Self {
338 stream,
339 read_state: FrameReadState::with_config(config.clone()),
340 write_state: FrameWriteState::with_config(config),
341 }
342 }
343
344 pub fn stream_mut(&mut self) -> &mut S {
346 &mut self.stream
347 }
348
349 pub fn factory(_req: http::Request<()>, stream: S) -> Result<Self, WsError> {
351 let config = FrameConfig {
352 mask_send_frame: false,
353 ..Default::default()
354 };
355 Ok(Self::new_with(stream, config))
356 }
357
358 pub fn check_fn(key: String, resp: http::Response<()>, stream: S) -> Result<Self, WsError> {
360 standard_handshake_resp_check(key.as_bytes(), &resp)?;
361 Ok(Self::new_with(stream, Default::default()))
362 }
363
364 pub fn receive(&mut self) -> Result<(SimplifiedHeader, &[u8]), WsError> {
366 self.read_state.receive(&mut self.stream)
367 }
368
369 pub fn send(&mut self, code: OpCode, payload: &[u8]) -> Result<(), WsError> {
371 self.write_state
372 .send(&mut self.stream, code, payload)
373 .map_err(WsError::IOError)
374 }
375
376 pub fn send_owned_frame(&mut self, frame: OwnedFrame) -> Result<(), WsError> {
378 self.write_state
379 .send_owned_frame(&mut self.stream, frame)
380 .map_err(WsError::IOError)
381 }
382
383 pub fn flush(&mut self) -> Result<(), WsError> {
385 self.stream.flush().map_err(WsError::IOError)
386 }
387}
388
389impl<R, W, S> FrameCodec<S>
390where
391 R: Read,
392 W: Write,
393 S: Read + Write + Split<R = R, W = W>,
394{
395 pub fn split(self) -> (FrameRecv<R>, FrameSend<W>) {
397 let FrameCodec {
398 stream,
399 read_state,
400 write_state,
401 } = self;
402 let (read, write) = stream.split();
403 (
404 FrameRecv::new(read, read_state),
405 FrameSend::new(write, write_state),
406 )
407 }
408}