1use std::{
6 collections::VecDeque,
7 io::{self, IoSlice},
8 mem::{replace, take},
9 pin::Pin,
10 task::{ready, Context, Poll, Waker},
11};
12
13use bytes::{Buf, BytesMut};
14use futures_core::Stream;
15use futures_sink::Sink;
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio_util::{codec::FramedRead, io::poll_write_buf};
18
19#[cfg(any(feature = "client", feature = "server"))]
20use super::types::Role;
21use super::{
22 codec::WebSocketProtocol,
23 types::{Frame, Message, OpCode, Payload, StreamState},
24 Config, Limits,
25};
26use crate::{CloseCode, Error};
27
/// An outgoing frame whose header has already been serialized, waiting in the
/// write queue.
#[derive(Debug)]
struct EncodedFrame {
    /// Serialized frame header. Only the first `header_len()` bytes are
    /// meaningful; the rest of the array is unused.
    header: [u8; 14],
    /// Frame payload, already masked when the stream runs in the client role
    /// (masking happens before the frame is queued).
    payload: Payload,
}
36
37impl EncodedFrame {
38 #[inline]
40 fn is_masked(&self) -> bool {
41 self.header[1] >> 7 != 0
42 }
43
44 #[inline]
46 fn header_len(&self) -> usize {
47 let mask_bytes = if self.is_masked() { 4 } else { 0 };
48 match self.header[1] & 127 {
49 127 => 10 + mask_bytes,
50 126 => 4 + mask_bytes,
51 _ => 2 + mask_bytes,
52 }
53 }
54
55 fn len(&self) -> usize {
57 self.header_len() + self.payload.len()
58 }
59}
60
/// FIFO of encoded frames that exposes their concatenated bytes through the
/// [`Buf`] trait, so they can be handed to the writer in one go.
#[derive(Debug)]
struct FrameQueue {
    /// Frames waiting to be written, oldest first.
    queue: VecDeque<EncodedFrame>,
    /// Number of bytes (header and payload) of the front frame that have
    /// already been consumed.
    bytes_written: usize,
    /// Number of bytes across all queued frames that have not been consumed.
    pending_bytes: usize,
}
72
73impl FrameQueue {
74 #[cfg(any(feature = "client", feature = "server"))]
76 fn new() -> Self {
77 Self {
78 queue: VecDeque::with_capacity(1),
79 bytes_written: 0,
80 pending_bytes: 0,
81 }
82 }
83
84 fn push(&mut self, item: EncodedFrame) {
86 self.pending_bytes += item.len();
87 self.queue.push_back(item);
88 }
89}
90
91impl Buf for FrameQueue {
92 fn remaining(&self) -> usize {
93 self.pending_bytes
94 }
95
96 fn chunk(&self) -> &[u8] {
97 if let Some(frame) = self.queue.front() {
98 if self.bytes_written >= frame.header_len() {
99 unsafe {
100 frame
101 .payload
102 .get_unchecked(self.bytes_written - frame.header_len()..)
103 }
104 } else {
105 &frame.header[self.bytes_written..frame.header_len()]
106 }
107 } else {
108 &[]
109 }
110 }
111
112 fn advance(&mut self, mut cnt: usize) {
113 self.pending_bytes -= cnt;
114 cnt += self.bytes_written;
115
116 while cnt > 0 {
117 let item = self
118 .queue
119 .front()
120 .expect("advance called with too long count");
121 let item_len = item.len();
122
123 if cnt >= item_len {
124 self.queue.pop_front();
125 self.bytes_written = 0;
126 cnt -= item_len;
127 } else {
128 self.bytes_written = cnt;
129 return;
130 }
131 }
132 }
133
134 fn chunks_vectored<'a>(&'a self, dst: &mut [io::IoSlice<'a>]) -> usize {
135 let mut n = 0;
136 for (idx, frame) in self.queue.iter().enumerate() {
137 if n >= dst.len() {
138 break;
139 }
140
141 if idx == 0 {
142 if frame.header_len() > self.bytes_written {
143 dst[n] = IoSlice::new(&frame.header[self.bytes_written..frame.header_len()]);
144 n += 1;
145 }
146
147 if !frame.payload.is_empty() && n < dst.len() {
148 dst[n] = IoSlice::new(unsafe {
149 frame
150 .payload
151 .get_unchecked(self.bytes_written.saturating_sub(frame.header_len())..)
152 });
153 n += 1;
154 }
155 } else {
156 dst[n] = IoSlice::new(&frame.header[..frame.header_len()]);
157 n += 1;
158 if !frame.payload.is_empty() && n < dst.len() {
159 dst[n] = IoSlice::new(&frame.payload);
160 n += 1;
161 }
162 }
163 }
164
165 n
166 }
167}
168
/// A WebSocket connection over an I/O stream `T`.
///
/// It can be used as a [`Stream`] of incoming [`Message`]s and as a [`Sink`]
/// for outgoing ones.
// Allowed: the public type name intentionally repeats its module's name.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct WebSocketStream<T> {
    /// The I/O stream, wrapped in a `FramedRead` that decodes incoming frames
    /// with `WebSocketProtocol`.
    inner: FramedRead<T, WebSocketProtocol>,

    /// Stream configuration; used here for `frame_size` (fragmentation of
    /// outgoing messages) and `flush_threshold` (backpressure in `poll_ready`).
    config: Config,

    /// Progress of the close handshake.
    state: StreamState,

    /// Payload of the fragmented message accumulated so far.
    partial_payload: BytesMut,
    /// Opcode of the fragmented message being accumulated;
    /// `OpCode::Continuation` when no fragmented message is in progress.
    partial_opcode: OpCode,

    /// Scratch buffer that frame headers are encoded into before being queued.
    header_buf: [u8; 14],

    /// Encoded frames waiting to be written to the I/O stream.
    frame_queue: FrameQueue,

    /// Waker of the task last left pending in `poll_flush`. It is reused when
    /// reading drives pending writes.
    flushing_waker: Option<Waker>,
}
209
impl<T> WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps a raw I/O stream in a new `WebSocketStream`.
    ///
    /// A fresh `FramedRead` is created around `stream`, decoding with
    /// `WebSocketProtocol` configured for `role` and `limits`. The connection
    /// starts in the `Active` state with an empty write queue.
    #[cfg(any(feature = "client", feature = "server"))]
    pub(crate) fn from_raw_stream(stream: T, role: Role, config: Config, limits: Limits) -> Self {
        Self {
            inner: FramedRead::new(stream, WebSocketProtocol::new(role, limits)),
            config,
            state: StreamState::Active,
            partial_payload: BytesMut::new(),
            partial_opcode: OpCode::Continuation,
            header_buf: [0; 14],
            frame_queue: FrameQueue::new(),
            flushing_waker: None,
        }
    }

    /// Builds a `WebSocketStream` from an existing `FramedRead`.
    ///
    /// Its decoder is swapped for `WebSocketProtocol` via
    /// `FramedRead::map_decoder`. Per the tokio-util documentation,
    /// `map_decoder` preserves the read buffer, so bytes `framed` has already
    /// buffered are decoded as WebSocket frames instead of being lost.
    #[cfg(any(feature = "client", feature = "server"))]
    pub(crate) fn from_framed<U>(
        framed: FramedRead<T, U>,
        role: Role,
        config: Config,
        limits: Limits,
    ) -> Self {
        Self {
            inner: framed.map_decoder(|_| WebSocketProtocol::new(role, limits)),
            config,
            state: StreamState::Active,
            partial_payload: BytesMut::new(),
            partial_opcode: OpCode::Continuation,
            header_buf: [0; 14],
            frame_queue: FrameQueue::new(),
            flushing_waker: None,
        }
    }

    /// Returns a shared reference to the underlying I/O stream.
    pub fn get_ref(&self) -> &T {
        self.inner.get_ref()
    }

    /// Returns a mutable reference to the underlying I/O stream.
    ///
    /// Reading from or writing to it directly bypasses the decoder's read
    /// buffer and the queue of outgoing frames, which can corrupt the frame
    /// stream.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }

    /// Returns the limits enforced by the frame decoder.
    pub fn limits(&self) -> &Limits {
        &self.inner.decoder().limits
    }

    /// Returns a mutable reference to the decoder's limits, so they can be
    /// adjusted on a live connection.
    pub fn limits_mut(&mut self) -> &mut Limits {
        &mut self.inner.decoder_mut().limits
    }

    /// Polls for the next single frame and performs the protocol side effects
    /// of reading it.
    ///
    /// - Once the close handshake is acknowledged this returns `None`. When
    ///   the peer closed first, the queued reply is flushed before `None` is
    ///   returned.
    /// - Queued outgoing frames are flushed opportunistically before reading.
    /// - Decoder errors update `state`. When a close frame is still owed to
    ///   the peer, one describing the error is queued.
    /// - A close frame from the peer is echoed if we had not sent one yet.
    /// - A ping received while `Active` is answered with a pong carrying the
    ///   same payload.
    ///
    /// Every successfully decoded frame, control frames included, is returned
    /// to the caller.
    fn poll_next_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame, Error>>> {
        if self.state == StreamState::CloseAcknowledged {
            return Poll::Ready(None);
        } else if self.state == StreamState::ClosedByPeer {
            // Make sure our close reply reaches the peer before ending.
            ready!(self.as_mut().poll_flush(cx))?;
            self.state = StreamState::CloseAcknowledged;
            return Poll::Ready(None);
        }

        if self.frame_queue.has_remaining() {
            // Make progress on pending writes while reading. The waker of the
            // task last left pending in `poll_flush` is used when there is one,
            // so that task is woken once writing can continue. A `Pending`
            // result is ignored so reading is not blocked by writing; errors
            // are returned.
            let waker = self.flushing_waker.clone();
            _ = self.as_mut().poll_flush(&mut Context::from_waker(
                waker.as_ref().unwrap_or(cx.waker()),
            ))?;
        }

        let frame = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
            Some(Ok(frame)) => frame,
            Some(Err(e)) => {
                if matches!(e, Error::Io(_)) || self.state == StreamState::ClosedByUs {
                    // The transport failed, or we already sent a close frame:
                    // end the stream without sending anything further.
                    self.state = StreamState::CloseAcknowledged;
                } else {
                    // Tell the peer why we are closing. The stream ends after
                    // this reply has been flushed (see the `ClosedByPeer` check
                    // at the top).
                    self.state = StreamState::ClosedByPeer;

                    match &e {
                        Error::Protocol(e) => self.queue_frame(Frame::from(e)),
                        Error::PayloadTooLong { max_len, .. } => self.queue_frame(
                            Message::close(
                                Some(CloseCode::MESSAGE_TOO_BIG),
                                &format!("max length: {max_len}"),
                            )
                            .into(),
                        ),
                        _ => {}
                    }
                }
                return Poll::Ready(Some(Err(e)));
            }
            None => return Poll::Ready(None),
        };

        match frame.opcode {
            OpCode::Close => match self.state {
                StreamState::Active => {
                    // The peer initiated the close. Echo back only its status
                    // code (the first two payload bytes).
                    self.state = StreamState::ClosedByPeer;

                    let mut frame = frame.clone();
                    frame.payload.truncate(2);

                    self.queue_frame(frame);
                }
                StreamState::ClosedByPeer | StreamState::CloseAcknowledged => {
                    // The early returns at the top of this function make these
                    // states unreachable here.
                    debug_assert!(false, "unexpected StreamState");
                }
                StreamState::ClosedByUs => {
                    // This is the peer's reply to our close frame; the
                    // handshake is complete.
                    self.state = StreamState::CloseAcknowledged;
                }
            },
            OpCode::Ping if self.state == StreamState::Active => {
                // Answer with a pong carrying the same payload.
                let mut frame = frame.clone();
                frame.opcode = OpCode::Pong;

                self.queue_frame(frame);
            }
            _ => {}
        }

        Poll::Ready(Some(Ok(frame)))
    }

    /// Encodes `frame` and appends it to the outgoing frame queue.
    ///
    /// Queuing a close frame marks the close as initiated by us, unless it is
    /// a reply to the peer's close (state `ClosedByPeer`). With the `client`
    /// feature enabled and the decoder in the client role, the payload is
    /// masked and the MASK bit is set in the header.
    ///
    /// Nothing is written to the I/O stream here; that happens in `poll_flush`.
    fn queue_frame(
        &mut self,
        #[cfg_attr(not(feature = "client"), allow(unused_mut))] mut frame: Frame,
    ) {
        if frame.opcode == OpCode::Close && self.state != StreamState::ClosedByPeer {
            self.state = StreamState::ClosedByUs;
        }

        // `encode` writes the header into `header_buf`. The returned `mask` is
        // used only on the client masking path below.
        // NOTE(review): presumably it refers to the masking-key bytes inside
        // `header_buf` — confirm in `Frame::encode`.
        #[cfg_attr(not(feature = "client"), allow(unused_variables))]
        let mask = frame.encode(&mut self.header_buf);

        #[cfg(feature = "client")]
        {
            if self.inner.decoder().role == Role::Client {
                let mut payload = BytesMut::from(frame.payload);
                // NOTE(review): presumably fills `mask` with random key bytes.
                crate::rand::get_mask(mask);
                // `mask::frame` receives a copy, so the key stored for the header
                // is not touched by it. NOTE(review): presumably it mutates the
                // key it is given — confirm.
                let mut mask_copy = *mask;
                crate::mask::frame(&mut mask_copy, &mut payload);
                frame.payload = Payload::from(payload);
                // Set the MASK bit (high bit of the second header byte).
                self.header_buf[1] |= 1 << 7;
            }
        }

        // `header_buf` is copied by value, so it can be reused for the next frame.
        let item = EncodedFrame {
            header: self.header_buf,
            payload: frame.payload,
        };
        self.frame_queue.push(item);
    }

    /// Stores `waker` as the flushing waker.
    ///
    /// `poll_next_frame` uses this waker when it drives pending writes. The
    /// clone is skipped when the stored waker would already wake the same task.
    fn set_flushing_waker(&mut self, waker: &Waker) {
        if !self
            .flushing_waker
            .as_ref()
            .is_some_and(|w| w.will_wake(waker))
        {
            self.flushing_waker = Some(waker.clone());
        }
    }
}
413
414impl<T> Stream for WebSocketStream<T>
415where
416 T: AsyncRead + AsyncWrite + Unpin,
417{
418 type Item = Result<Message, Error>;
419
420 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
421 let max_len = self.inner.decoder().limits.max_payload_len;
422
423 loop {
424 let (opcode, payload, fin) = match ready!(self.as_mut().poll_next_frame(cx)?) {
425 Some(frame) => (frame.opcode, frame.payload, frame.is_final),
426 None => return Poll::Ready(None),
427 };
428 let len = self.partial_payload.len() + payload.len();
429
430 if opcode != OpCode::Continuation {
431 if fin {
432 return Poll::Ready(Some(Ok(Message { opcode, payload })));
433 }
434 self.partial_opcode = opcode;
435 self.partial_payload = BytesMut::from(payload);
436 } else if len > max_len {
437 return Poll::Ready(Some(Err(Error::PayloadTooLong { len, max_len })));
438 } else {
439 self.partial_payload.extend_from_slice(&payload);
440 }
441
442 if fin {
443 break;
444 }
445 }
446
447 let opcode = replace(&mut self.partial_opcode, OpCode::Continuation);
448 let mut payload = Payload::from(take(&mut self.partial_payload));
449 payload.set_utf8_validated(opcode == OpCode::Text);
450
451 Poll::Ready(Some(Ok(Message { opcode, payload })))
452 }
453}
454
455impl<T> Sink<Message> for WebSocketStream<T>
461where
462 T: AsyncRead + AsyncWrite + Unpin,
463{
464 type Error = Error;
465
466 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
467 if self.frame_queue.remaining() >= self.config.flush_threshold {
470 self.as_mut().poll_flush(cx)
471 } else {
472 Poll::Ready(Ok(()))
473 }
474 }
475
476 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
477 if self.state != StreamState::Active {
478 return Err(Error::AlreadyClosed);
479 }
480
481 if item.opcode.is_control() || item.payload.len() <= self.config.frame_size {
482 let frame: Frame = item.into();
483 self.queue_frame(frame);
484 } else {
485 for frame in item.into_frames(self.config.frame_size) {
487 self.queue_frame(frame);
488 }
489 }
490
491 Ok(())
492 }
493
494 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
495 let this = self.get_mut();
498 let frame_queue = &mut this.frame_queue;
499 let io = this.inner.get_mut();
500 let flushing_waker = &mut this.flushing_waker;
501
502 while frame_queue.has_remaining() {
503 let n = match poll_write_buf(Pin::new(io), cx, frame_queue) {
504 Poll::Ready(Ok(n)) => n,
505 Poll::Ready(Err(e)) => {
506 *flushing_waker = None;
507 this.state = StreamState::CloseAcknowledged;
508 return Poll::Ready(Err(Error::Io(e)));
509 }
510 Poll::Pending => {
511 this.set_flushing_waker(cx.waker());
512 return Poll::Pending;
513 }
514 };
515
516 if n == 0 {
517 *flushing_waker = None;
518 this.state = StreamState::CloseAcknowledged;
519 return Poll::Ready(Err(Error::Io(io::ErrorKind::WriteZero.into())));
520 }
521 }
522
523 match Pin::new(io).poll_flush(cx) {
524 Poll::Ready(Ok(())) => {
525 *flushing_waker = None;
526 Poll::Ready(Ok(()))
527 }
528 Poll::Ready(Err(e)) => {
529 *flushing_waker = None;
530 this.state = StreamState::CloseAcknowledged;
531 Poll::Ready(Err(Error::Io(e)))
532 }
533 Poll::Pending => {
534 this.set_flushing_waker(cx.waker());
535 Poll::Pending
536 }
537 }
538 }
539
540 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
541 if self.state == StreamState::Active {
542 self.queue_frame(Frame::DEFAULT_CLOSE);
543 }
544 while ready!(self.as_mut().poll_next(cx)).is_some() {}
545
546 ready!(self.as_mut().poll_flush(cx))?;
547 Pin::new(self.inner.get_mut())
548 .poll_shutdown(cx)
549 .map_err(Error::Io)
550 }
551}