1use std::io;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use bytes::BytesMut;
10use futures_core::Stream;
11use futures_sink::Sink;
12use pin_project_lite::pin_project;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14
15use crate::Config;
16use crate::cork::CorkBuffer;
17use crate::error::{CloseReason, Error, Result};
18use crate::protocol::{Message, Protocol, Role};
19
/// Default write-buffer size (64 KiB) above which a stream reports backpressure.
const DEFAULT_HIGH_WATER_MARK: usize = 64 << 10;

/// Default write-buffer size (16 KiB) at or below which the write buffer counts as "low".
const DEFAULT_LOW_WATER_MARK: usize = 16 << 10;
25
26pin_project! {
27 pub struct WebSocketStream<S> {
61 #[pin]
62 inner: S,
63 protocol: Protocol,
64 read_buf: BytesMut,
65 write_buf: CorkBuffer,
66 state: StreamState,
67 config: Config,
68 pending_messages: Vec<Message>,
70 pending_index: usize,
71 high_water_mark: usize,
73 low_water_mark: usize,
74 }
75}
76
/// Where a connection is in its lifecycle / closing handshake.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Normal operation.
    Open,
    /// NOTE(review): never assigned anywhere in this module — confirm it is still needed.
    Flushing,
    /// Our own Close frame has been queued.
    CloseSent,
    /// The connection is finished; the stream yields no further messages.
    Closed,
}
88
89impl<S> WebSocketStream<S>
90where
91 S: AsyncRead + AsyncWrite + Unpin,
92{
93 pub fn from_raw(inner: S, role: Role, config: Config) -> Self {
95 let protocol = Protocol::new(role, config.max_frame_size, config.max_message_size);
96
97 Self {
98 inner,
99 protocol,
100 read_buf: BytesMut::with_capacity(crate::RECV_BUFFER_SIZE),
101 write_buf: CorkBuffer::with_capacity(config.write_buffer_size),
102 state: StreamState::Open,
103 config,
104 pending_messages: Vec::new(),
105 pending_index: 0,
106 high_water_mark: DEFAULT_HIGH_WATER_MARK,
107 low_water_mark: DEFAULT_LOW_WATER_MARK,
108 }
109 }
110
111 pub fn server(inner: S, config: Config) -> Self {
113 Self::from_raw(inner, Role::Server, config)
114 }
115
116 pub fn client(inner: S, config: Config) -> Self {
118 Self::from_raw(inner, Role::Client, config)
119 }
120
121 pub fn get_ref(&self) -> &S {
123 &self.inner
124 }
125
126 pub fn get_mut(&mut self) -> &mut S {
128 &mut self.inner
129 }
130
131 pub fn into_inner(self) -> S {
133 self.inner
134 }
135
136 pub fn is_closed(&self) -> bool {
138 self.state == StreamState::Closed
139 }
140
141 #[inline]
160 pub fn is_backpressured(&self) -> bool {
161 self.write_buf.pending_bytes() > self.high_water_mark
162 }
163
164 #[inline]
169 pub fn is_write_buffer_low(&self) -> bool {
170 self.write_buf.pending_bytes() <= self.low_water_mark
171 }
172
173 #[inline]
177 pub fn write_buffer_len(&self) -> usize {
178 self.write_buf.pending_bytes()
179 }
180
181 #[inline]
185 pub fn read_buffer_len(&self) -> usize {
186 self.read_buf.len()
187 }
188
189 #[inline]
194 pub fn set_high_water_mark(&mut self, size: usize) {
195 self.high_water_mark = size;
196 }
197
198 #[inline]
203 pub fn set_low_water_mark(&mut self, size: usize) {
204 self.low_water_mark = size;
205 }
206
207 #[inline]
209 pub fn high_water_mark(&self) -> usize {
210 self.high_water_mark
211 }
212
213 #[inline]
215 pub fn low_water_mark(&self) -> usize {
216 self.low_water_mark
217 }
218
219 pub async fn close(&mut self, code: u16, reason: &str) -> Result<()> {
221 if self.state != StreamState::Open {
222 return Ok(());
223 }
224
225 let close = Message::Close(Some(CloseReason::new(code, reason)));
226 self.protocol
227 .encode_message(&close, self.write_buf.buffer_mut())?;
228 self.state = StreamState::CloseSent;
229
230 self.flush_write_buf().await?;
232 Ok(())
233 }
234
235 async fn flush_write_buf(&mut self) -> Result<()> {
237 use tokio::io::AsyncWriteExt;
238
239 while self.write_buf.has_data() {
240 let slices = self.write_buf.get_write_slices();
241 if slices.is_empty() {
242 break;
243 }
244
245 let n = self.inner.write_vectored(&slices).await?;
246 if n == 0 {
247 return Err(Error::ConnectionClosed);
248 }
249 self.write_buf.consume(n);
250 }
251
252 self.inner.flush().await?;
253 Ok(())
254 }
255
256 fn poll_read_more(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
258 let this = self.project();
259
260 if this.read_buf.capacity() - this.read_buf.len() < 4096 {
262 this.read_buf.reserve(8192);
263 }
264
265 let buf_len = this.read_buf.len();
267 let buf_cap = this.read_buf.capacity();
268
269 unsafe {
271 this.read_buf.set_len(buf_cap);
272 }
273
274 let mut read_buf = ReadBuf::new(&mut this.read_buf[buf_len..]);
275
276 match this.inner.poll_read(cx, &mut read_buf) {
277 Poll::Ready(Ok(())) => {
278 let n = read_buf.filled().len();
279 unsafe {
280 this.read_buf.set_len(buf_len + n);
281 }
282 if n == 0 {
283 Poll::Ready(Ok(0))
284 } else {
285 Poll::Ready(Ok(n))
286 }
287 }
288 Poll::Ready(Err(e)) => {
289 unsafe {
290 this.read_buf.set_len(buf_len);
291 }
292 Poll::Ready(Err(e))
293 }
294 Poll::Pending => {
295 unsafe {
296 this.read_buf.set_len(buf_len);
297 }
298 Poll::Pending
299 }
300 }
301 }
302
303 fn process_read_buf(&mut self) -> Result<()> {
305 if self.read_buf.is_empty() {
306 return Ok(());
307 }
308
309 let messages = self.protocol.process(&mut self.read_buf)?;
310
311 if !messages.is_empty() {
312 self.pending_messages = messages;
313 self.pending_index = 0;
314 }
315
316 Ok(())
317 }
318
319 fn next_pending_message(&mut self) -> Option<Message> {
321 if self.pending_index < self.pending_messages.len() {
322 let msg = self.pending_messages[self.pending_index].clone();
323 self.pending_index += 1;
324
325 if self.pending_index >= self.pending_messages.len() {
327 self.pending_messages.clear();
328 self.pending_index = 0;
329 }
330
331 Some(msg)
332 } else {
333 None
334 }
335 }
336}
337
338impl<S> Stream for WebSocketStream<S>
339where
340 S: AsyncRead + AsyncWrite + Unpin,
341{
342 type Item = Result<Message>;
343
344 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
345 loop {
346 if self.state == StreamState::Closed {
348 return Poll::Ready(None);
349 }
350
351 if let Some(msg) = self.as_mut().get_mut().next_pending_message() {
353 match &msg {
355 Message::Ping(data) => {
356 let this = self.as_mut().get_mut();
358 this.protocol.encode_pong(data, this.write_buf.buffer_mut());
359 }
360 Message::Close(reason) => {
361 let this = self.as_mut().get_mut();
362 if this.state == StreamState::Open {
363 this.protocol
365 .encode_close_response(this.write_buf.buffer_mut());
366 this.state = StreamState::Closed;
367 }
368 return Poll::Ready(Some(Ok(Message::Close(reason.clone()))));
369 }
370 _ => {}
371 }
372
373 return Poll::Ready(Some(Ok(msg)));
374 }
375
376 match self.as_mut().poll_read_more(cx) {
378 Poll::Ready(Ok(0)) => {
379 self.as_mut().get_mut().state = StreamState::Closed;
381 return Poll::Ready(None);
382 }
383 Poll::Ready(Ok(_n)) => {
384 match self.as_mut().get_mut().process_read_buf() {
386 Ok(()) => continue, Err(e) => return Poll::Ready(Some(Err(e))),
388 }
389 }
390 Poll::Ready(Err(e)) => {
391 return Poll::Ready(Some(Err(e.into())));
392 }
393 Poll::Pending => {
394 return Poll::Pending;
396 }
397 }
398 }
399 }
400}
401
402impl<S> Sink<Message> for WebSocketStream<S>
403where
404 S: AsyncRead + AsyncWrite + Unpin,
405{
406 type Error = Error;
407
408 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
409 if self.state == StreamState::Closed {
410 return Poll::Ready(Err(Error::ConnectionClosed));
411 }
412 Poll::Ready(Ok(()))
413 }
414
415 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<()> {
416 let this = self.get_mut();
417
418 if this.state == StreamState::Closed {
419 return Err(Error::ConnectionClosed);
420 }
421
422 if item.is_close() {
424 this.state = StreamState::CloseSent;
425 }
426
427 this.protocol
429 .encode_message(&item, this.write_buf.buffer_mut())?;
430 Ok(())
431 }
432
433 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
434 let this = self.as_mut().get_mut();
435
436 while this.write_buf.has_data() {
438 let slices = this.write_buf.get_write_slices();
439 if slices.is_empty() {
440 break;
441 }
442
443 match Pin::new(&mut this.inner).poll_write_vectored(cx, &slices) {
444 Poll::Ready(Ok(0)) => {
445 return Poll::Ready(Err(Error::ConnectionClosed));
446 }
447 Poll::Ready(Ok(n)) => {
448 this.write_buf.consume(n);
449 }
450 Poll::Ready(Err(e)) => {
451 return Poll::Ready(Err(e.into()));
452 }
453 Poll::Pending => {
454 return Poll::Pending;
455 }
456 }
457 }
458
459 match Pin::new(&mut self.as_mut().get_mut().inner).poll_flush(cx) {
461 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
462 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
463 Poll::Pending => Poll::Pending,
464 }
465 }
466
467 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
468 if self.state == StreamState::Open {
470 let close = Message::Close(Some(CloseReason::new(1000, "")));
471 if let Err(e) = self.as_mut().start_send(close) {
472 return Poll::Ready(Err(e));
473 }
474 }
475
476 match self.as_mut().poll_flush(cx) {
478 Poll::Ready(Ok(())) => {}
479 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
480 Poll::Pending => return Poll::Pending,
481 }
482
483 match Pin::new(&mut self.as_mut().get_mut().inner).poll_shutdown(cx) {
485 Poll::Ready(Ok(())) => {
486 self.as_mut().get_mut().state = StreamState::Closed;
487 Poll::Ready(Ok(()))
488 }
489 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
490 Poll::Pending => Poll::Pending,
491 }
492 }
493}
494
495pub struct WebSocketStreamBuilder {
497 config: Config,
498 role: Role,
499 high_water_mark: usize,
500 low_water_mark: usize,
501}
502
503impl WebSocketStreamBuilder {
504 pub fn new() -> Self {
506 Self {
507 config: Config::default(),
508 role: Role::Server,
509 high_water_mark: DEFAULT_HIGH_WATER_MARK,
510 low_water_mark: DEFAULT_LOW_WATER_MARK,
511 }
512 }
513
514 pub fn role(mut self, role: Role) -> Self {
516 self.role = role;
517 self
518 }
519
520 pub fn max_message_size(mut self, size: usize) -> Self {
522 self.config.max_message_size = size;
523 self
524 }
525
526 pub fn max_frame_size(mut self, size: usize) -> Self {
528 self.config.max_frame_size = size;
529 self
530 }
531
532 pub fn write_buffer_size(mut self, size: usize) -> Self {
534 self.config.write_buffer_size = size;
535 self
536 }
537
538 pub fn high_water_mark(mut self, size: usize) -> Self {
543 self.high_water_mark = size;
544 self
545 }
546
547 pub fn low_water_mark(mut self, size: usize) -> Self {
552 self.low_water_mark = size;
553 self
554 }
555
556 pub fn build<S>(self, stream: S) -> WebSocketStream<S>
558 where
559 S: AsyncRead + AsyncWrite + Unpin,
560 {
561 let mut ws = WebSocketStream::from_raw(stream, self.role, self.config);
562 ws.high_water_mark = self.high_water_mark;
563 ws.low_water_mark = self.low_water_mark;
564 ws
565 }
566}
567
568impl Default for WebSocketStreamBuilder {
569 fn default() -> Self {
570 Self::new()
571 }
572}
573
574use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
589use tokio::sync::mpsc;
590
591#[derive(Debug, Clone)]
593enum ControlRequest {
594 Pong(bytes::Bytes),
596 CloseResponse,
598}
599
600pub struct SplitReader<S> {
606 reader: ReadHalf<S>,
608 protocol: Protocol,
610 read_buf: BytesMut,
612 pending_messages: Vec<Message>,
614 pending_index: usize,
615 control_tx: mpsc::UnboundedSender<ControlRequest>,
617 closed: bool,
619}
620
621pub struct SplitWriter<S> {
627 writer: WriteHalf<S>,
629 protocol: Protocol,
631 write_buf: BytesMut,
633 control_rx: mpsc::UnboundedReceiver<ControlRequest>,
635 closed: bool,
637}
638
639impl<S> WebSocketStream<S>
640where
641 S: AsyncRead + AsyncWrite + Unpin,
642{
643 pub fn split(self) -> (SplitReader<S>, SplitWriter<S>) {
665 let (reader, writer) = tokio::io::split(self.inner);
667
668 let (control_tx, control_rx) = mpsc::unbounded_channel();
670
671 let reader_protocol = Protocol::new(
673 self.protocol.role,
674 self.config.max_frame_size,
675 self.config.max_message_size,
676 );
677 let writer_protocol = self.protocol;
678
679 (
680 SplitReader {
681 reader,
682 protocol: reader_protocol,
683 read_buf: self.read_buf,
684 pending_messages: self.pending_messages,
685 pending_index: self.pending_index,
686 control_tx,
687 closed: self.state == StreamState::Closed,
688 },
689 SplitWriter {
690 writer,
691 protocol: writer_protocol,
692 write_buf: BytesMut::with_capacity(1024),
693 control_rx,
694 closed: self.state == StreamState::Closed,
695 },
696 )
697 }
698}
699
700impl<S> SplitReader<S>
701where
702 S: AsyncRead + AsyncWrite + Unpin,
703{
704 pub async fn next(&mut self) -> Option<Result<Message>> {
709 loop {
710 if self.closed {
712 return None;
713 }
714
715 if self.pending_index < self.pending_messages.len() {
717 let msg = self.pending_messages[self.pending_index].clone();
718 self.pending_index += 1;
719
720 if self.pending_index >= self.pending_messages.len() {
721 self.pending_messages.clear();
722 self.pending_index = 0;
723 }
724
725 match &msg {
727 Message::Ping(data) => {
728 let _ = self.control_tx.send(ControlRequest::Pong(data.clone()));
730 continue;
732 }
733 Message::Close(reason) => {
734 if !self.closed {
735 let _ = self.control_tx.send(ControlRequest::CloseResponse);
737 self.closed = true;
738 }
739 return Some(Ok(Message::Close(reason.clone())));
740 }
741 Message::Pong(_) => {
742 continue;
744 }
745 _ => {}
746 }
747
748 return Some(Ok(msg));
749 }
750
751 if self.read_buf.capacity() - self.read_buf.len() < 4096 {
754 self.read_buf.reserve(8192);
755 }
756
757 match self.reader.read_buf(&mut self.read_buf).await {
758 Ok(0) => {
759 self.closed = true;
761 return None;
762 }
763 Ok(_n) => {
764 match self.protocol.process(&mut self.read_buf) {
766 Ok(messages) => {
767 if !messages.is_empty() {
768 self.pending_messages = messages;
769 self.pending_index = 0;
770 }
771 }
772 Err(e) => return Some(Err(e)),
773 }
774 }
776 Err(e) => {
777 return Some(Err(e.into()));
778 }
779 }
780 }
781 }
782
783 pub fn is_closed(&self) -> bool {
785 self.closed
786 }
787}
788
789impl<S> SplitWriter<S>
790where
791 S: AsyncRead + AsyncWrite + Unpin,
792{
793 pub async fn send(&mut self, msg: Message) -> Result<()> {
797 if self.closed {
798 return Err(Error::ConnectionClosed);
799 }
800
801 self.process_control_requests().await?;
803
804 if msg.is_close() {
805 self.closed = true;
806 }
807
808 self.write_buf.clear();
810 self.protocol.encode_message(&msg, &mut self.write_buf)?;
811
812 self.writer.write_all(&self.write_buf).await?;
814 self.writer.flush().await?;
815 Ok(())
816 }
817
818 async fn process_control_requests(&mut self) -> Result<()> {
820 while let Ok(req) = self.control_rx.try_recv() {
822 self.write_buf.clear();
823
824 match req {
825 ControlRequest::Pong(data) => {
826 self.protocol.encode_pong(&data, &mut self.write_buf);
827 }
828 ControlRequest::CloseResponse => {
829 self.protocol.encode_close_response(&mut self.write_buf);
830 self.closed = true;
831 }
832 }
833
834 if !self.write_buf.is_empty() {
835 self.writer.write_all(&self.write_buf).await?;
836 }
837 }
838
839 Ok(())
840 }
841
842 pub async fn send_text(&mut self, text: impl Into<String>) -> Result<()> {
844 self.send(Message::text(text)).await
845 }
846
847 pub async fn send_binary(&mut self, data: bytes::Bytes) -> Result<()> {
849 self.send(Message::Binary(data)).await
850 }
851
852 pub async fn close(&mut self, code: u16, reason: &str) -> Result<()> {
854 self.send(Message::Close(Some(CloseReason::new(code, reason))))
855 .await
856 }
857
858 pub fn is_closed(&self) -> bool {
860 self.closed
861 }
862
863 pub async fn flush(&mut self) -> Result<()> {
865 self.process_control_requests().await?;
866 self.writer.flush().await.map_err(Into::into)
867 }
868}
869
#[cfg(feature = "permessage-deflate")]
pin_project! {
    /// A permessage-deflate WebSocket connection over an async byte transport `S`.
    ///
    /// Mirrors [`WebSocketStream`], but encodes and decodes through
    /// `crate::protocol::CompressedProtocol`.
    pub struct CompressedWebSocketStream<S> {
        // Underlying transport; the only structurally pinned field.
        #[pin]
        inner: S,
        // Compressing frame encoder/decoder.
        protocol: crate::protocol::CompressedProtocol,
        // Received bytes not yet consumed by `protocol`.
        read_buf: BytesMut,
        // Encoded outgoing frames waiting to be written to `inner`.
        write_buf: CorkBuffer,
        // Progress of the closing handshake.
        state: StreamState,
        // Limits and buffer sizes the stream was created with.
        config: Config,
        // Decoded messages not yet handed out; `pending_index` is the next one.
        pending_messages: Vec<Message>,
        pending_index: usize,
        // Backpressure thresholds (bytes queued in `write_buf`).
        high_water_mark: usize,
        low_water_mark: usize,
    }
}
894
#[cfg(feature = "permessage-deflate")]
impl<S> CompressedWebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps `inner` as the server side of a permessage-deflate connection.
    pub fn server(inner: S, config: Config, deflate_config: crate::deflate::DeflateConfig) -> Self {
        let protocol = crate::protocol::CompressedProtocol::server(
            config.max_frame_size,
            config.max_message_size,
            deflate_config,
        );
        Self::with_protocol(inner, protocol, config)
    }

    /// Wraps `inner` as the client side of a permessage-deflate connection.
    pub fn client(inner: S, config: Config, deflate_config: crate::deflate::DeflateConfig) -> Self {
        let protocol = crate::protocol::CompressedProtocol::client(
            config.max_frame_size,
            config.max_message_size,
            deflate_config,
        );
        Self::with_protocol(inner, protocol, config)
    }

    /// Shared constructor: fresh buffers, `Open` state, default water marks.
    fn with_protocol(
        inner: S,
        protocol: crate::protocol::CompressedProtocol,
        config: Config,
    ) -> Self {
        Self {
            inner,
            protocol,
            read_buf: BytesMut::with_capacity(crate::RECV_BUFFER_SIZE),
            write_buf: CorkBuffer::with_capacity(config.write_buffer_size),
            state: StreamState::Open,
            config,
            pending_messages: Vec::new(),
            pending_index: 0,
            high_water_mark: DEFAULT_HIGH_WATER_MARK,
            low_water_mark: DEFAULT_LOW_WATER_MARK,
        }
    }

    /// Returns `true` if the stream reached `Closed` or the codec reports itself closed.
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.state == StreamState::Closed || self.protocol.is_closed()
    }

    /// Returns `true` while more than `high_water_mark` bytes are queued for writing.
    #[inline]
    pub fn is_backpressured(&self) -> bool {
        self.write_buf.pending_bytes() > self.high_water_mark
    }

    /// Number of encoded bytes queued but not yet written to the transport.
    #[inline]
    pub fn write_buffer_len(&self) -> usize {
        self.write_buf.pending_bytes()
    }

    /// Starts the closing handshake by sending a Close frame with `code` and `reason`.
    ///
    /// Does nothing unless the stream is `Open`. The Close frame (and anything
    /// queued before it) is written and the transport flushed.
    ///
    /// # Errors
    /// Propagates encoding errors and transport write/flush errors.
    pub async fn close(&mut self, code: u16, reason: &str) -> Result<()> {
        if self.state != StreamState::Open {
            return Ok(());
        }

        let close = Message::Close(Some(CloseReason::new(code, reason)));
        self.protocol
            .encode_message(&close, self.write_buf.buffer_mut())?;
        self.state = StreamState::CloseSent;

        self.flush_write_buf().await
    }

    /// Writes the whole write buffer with vectored writes, then flushes the transport.
    ///
    /// # Errors
    /// [`Error::ConnectionClosed`] if the transport accepts zero bytes, or the I/O error.
    async fn flush_write_buf(&mut self) -> Result<()> {
        use tokio::io::AsyncWriteExt;

        while self.write_buf.has_data() {
            let slices = self.write_buf.get_write_slices();
            if slices.is_empty() {
                break;
            }

            let n = self.inner.write_vectored(&slices).await?;
            if n == 0 {
                return Err(Error::ConnectionClosed);
            }
            self.write_buf.consume(n);
        }

        self.inner.flush().await?;
        Ok(())
    }

    /// Reads more bytes from the transport into `read_buf`.
    ///
    /// Returns `Ready(Ok(0))` at end of stream, otherwise the number of bytes appended.
    fn poll_read_more(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        let this = self.project();

        if this.read_buf.capacity() - this.read_buf.len() < 4096 {
            this.read_buf.reserve(8192);
        }

        let filled_before = this.read_buf.len();

        // Read into the *uninitialised* spare capacity. `len` is only extended
        // over bytes the reader actually wrote, as `BytesMut::set_len` requires.
        let mut dst = ReadBuf::uninit(this.read_buf.spare_capacity_mut());
        let start = dst.filled().as_ptr();

        match this.inner.poll_read(cx, &mut dst) {
            Poll::Ready(Ok(())) => {}
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Pending => return Poll::Pending,
        }

        // Guard against a reader that replaced the `ReadBuf` with another buffer.
        assert_eq!(
            start,
            dst.filled().as_ptr(),
            "the read buffer was replaced during poll_read"
        );
        let n = dst.filled().len();

        // SAFETY: `poll_read` returned `Ready(Ok(()))`, so the first `n` bytes of
        // the spare capacity (the always-initialised `filled` region) were written,
        // and we checked they start right at `filled_before`.
        unsafe {
            this.read_buf.set_len(filled_before + n);
        }

        Poll::Ready(Ok(n))
    }

    /// Decodes as many complete messages as `read_buf` holds and queues them.
    ///
    /// # Errors
    /// Propagates protocol / decompression errors from the decoder.
    fn process_read_buf(&mut self) -> Result<()> {
        if self.read_buf.is_empty() {
            return Ok(());
        }

        let messages = self.protocol.process(&mut self.read_buf)?;

        if self.pending_messages.is_empty() {
            self.pending_messages = messages;
            self.pending_index = 0;
        } else {
            // Never discard decoded messages that have not been handed out yet.
            self.pending_messages.extend(messages);
        }

        Ok(())
    }

    /// Pops the next queued message, resetting the queue once it is drained.
    fn next_pending_message(&mut self) -> Option<Message> {
        let remaining = self.pending_messages.len().checked_sub(self.pending_index)?;
        match remaining {
            0 => None,
            1 => {
                // Last queued message: move it out instead of cloning, then reset.
                let msg = self.pending_messages.pop();
                self.pending_messages.clear();
                self.pending_index = 0;
                msg
            }
            _ => {
                let msg = self.pending_messages[self.pending_index].clone();
                self.pending_index += 1;
                Some(msg)
            }
        }
    }
}
1075
#[cfg(feature = "permessage-deflate")]
impl<S> Stream for CompressedWebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Result<Message>;

    /// Yields the next decoded message, reading from the transport as needed.
    ///
    /// A Ping queues a Pong in the write buffer, written on the next sink
    /// flush/close, and is still yielded.
    ///
    /// A Close frame:
    /// - queues a Close echo, but only if we had not sent our own Close;
    /// - moves the stream to `Closed`;
    /// - is yielded to the caller.
    ///
    /// After a Close frame, or at end of stream, the stream ends.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            if self.state == StreamState::Closed {
                return Poll::Ready(None);
            }

            if let Some(msg) = self.as_mut().get_mut().next_pending_message() {
                let this = self.as_mut().get_mut();
                match &msg {
                    Message::Ping(data) => {
                        this.protocol.encode_pong(data, this.write_buf.buffer_mut());
                    }
                    Message::Close(_) => {
                        if this.state == StreamState::Open {
                            // Peer initiated the close: echo it back.
                            this.protocol
                                .encode_close_response(this.write_buf.buffer_mut());
                        }
                        // Whether the peer started the handshake (Open) or answered
                        // ours (CloseSent), both Close frames have now been
                        // exchanged, so the connection is closed.
                        this.state = StreamState::Closed;
                    }
                    _ => {}
                }

                return Poll::Ready(Some(Ok(msg)));
            }

            match self.as_mut().poll_read_more(cx) {
                Poll::Ready(Ok(0)) => {
                    self.as_mut().get_mut().state = StreamState::Closed;
                    return Poll::Ready(None);
                }
                Poll::Ready(Ok(_)) => {
                    if let Err(e) = self.as_mut().get_mut().process_read_buf() {
                        return Poll::Ready(Some(Err(e)));
                    }
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
1129
#[cfg(feature = "permessage-deflate")]
impl<S> Sink<Message> for CompressedWebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = Error;

    /// Ready only while the stream is `Open`.
    ///
    /// Once our Close has been queued, or the connection has closed, further
    /// sends are rejected.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        if self.state == StreamState::Open {
            Poll::Ready(Ok(()))
        } else {
            Poll::Ready(Err(Error::ConnectionClosed))
        }
    }

    /// Encodes `item` into the write buffer; nothing is written until a flush.
    ///
    /// A Close message moves the stream to `CloseSent` once it has been encoded.
    ///
    /// # Errors
    /// [`Error::ConnectionClosed`] if the stream is no longer `Open`, or the
    /// encoder's error.
    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<()> {
        let this = self.get_mut();

        // RFC 6455 §5.5.1: no data frames may follow our Close frame.
        if this.state != StreamState::Open {
            return Err(Error::ConnectionClosed);
        }

        this.protocol
            .encode_message(&item, this.write_buf.buffer_mut())?;

        // Only advance the handshake after the Close frame was actually queued.
        if item.is_close() {
            this.state = StreamState::CloseSent;
        }
        Ok(())
    }

    /// Writes the entire write buffer (vectored), then flushes the transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let this = self.get_mut();

        while this.write_buf.has_data() {
            let slices = this.write_buf.get_write_slices();
            if slices.is_empty() {
                break;
            }

            match Pin::new(&mut this.inner).poll_write_vectored(cx, &slices) {
                Poll::Ready(Ok(0)) => return Poll::Ready(Err(Error::ConnectionClosed)),
                Poll::Ready(Ok(n)) => {
                    this.write_buf.consume(n);
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
                Poll::Pending => return Poll::Pending,
            }
        }

        Pin::new(&mut this.inner).poll_flush(cx).map_err(Into::into)
    }

    /// Finishes the connection and marks the stream `Closed`.
    ///
    /// Queues a normal-closure (1000) Close frame if the stream is still `Open`,
    /// flushes everything, then shuts the transport down.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        if self.state == StreamState::Open {
            let close = Message::Close(Some(CloseReason::new(1000, "")));
            if let Err(e) = self.as_mut().start_send(close) {
                return Poll::Ready(Err(e));
            }
        }

        match self.as_mut().poll_flush(cx) {
            Poll::Ready(Ok(())) => {}
            other => return other,
        }

        let this = self.get_mut();
        match Pin::new(&mut this.inner).poll_shutdown(cx) {
            Poll::Ready(Ok(())) => {
                this.state = StreamState::Closed;
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
            Poll::Pending => Poll::Pending,
        }
    }
}
1216
#[cfg(feature = "permessage-deflate")]
/// Read half produced by [`CompressedWebSocketStream::split`].
///
/// Decodes (and decompresses) incoming frames. Pong/Close replies are forwarded
/// to the matching [`CompressedSplitWriter`] through `control_tx`.
pub struct CompressedSplitReader<S> {
    /// Read side of the transport obtained from `tokio::io::split`.
    reader: ReadHalf<S>,
    /// Reader part of the split compressed codec.
    protocol: crate::protocol::CompressedReaderProtocol,
    /// Bytes read but not yet decoded.
    read_buf: BytesMut,
    /// Decoded messages awaiting delivery, starting at `pending_index`.
    pending_messages: Vec<Message>,
    pending_index: usize,
    /// Channel to the writer half for control-frame replies.
    control_tx: mpsc::UnboundedSender<ControlRequest>,
    /// Set on Close/EOF, or if the stream was already closed when split.
    closed: bool,
}
1242
#[cfg(feature = "permessage-deflate")]
/// Write half produced by [`CompressedWebSocketStream::split`].
///
/// Encodes (and compresses) outgoing messages. It also sends the control replies
/// that the paired [`CompressedSplitReader`] queues on `control_rx`.
pub struct CompressedSplitWriter<S> {
    /// Write side of the transport obtained from `tokio::io::split`.
    writer: WriteHalf<S>,
    /// Writer part of the split compressed codec.
    protocol: crate::protocol::CompressedWriterProtocol,
    /// Scratch buffer reused for each encoded frame.
    write_buf: BytesMut,
    /// Control replies requested by the reader half.
    control_rx: mpsc::UnboundedReceiver<ControlRequest>,
    /// Set once this half has sent (or echoed) a Close frame.
    closed: bool,
}
1261
#[cfg(feature = "permessage-deflate")]
impl<S> CompressedWebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Splits the connection into a [`CompressedSplitReader`] and a
    /// [`CompressedSplitWriter`].
    ///
    /// The compressed codec is divided by its own `split`, using the configured
    /// frame/message limits. Already-read bytes and decoded-but-undelivered
    /// messages move to the reader. The halves are linked by an unbounded channel
    /// that carries Pong/Close replies from reader to writer.
    ///
    /// NOTE(review): bytes still queued in this stream's write buffer are dropped
    /// here, not carried into the writer, so flush before splitting.
    pub fn split(self) -> (CompressedSplitReader<S>, CompressedSplitWriter<S>) {
        let closed = self.state == StreamState::Closed;
        let (read_half, write_half) = tokio::io::split(self.inner);
        let (control_tx, control_rx) = mpsc::unbounded_channel();
        let (reader_protocol, writer_protocol) = self
            .protocol
            .split(self.config.max_frame_size, self.config.max_message_size);

        let reader = CompressedSplitReader {
            reader: read_half,
            protocol: reader_protocol,
            read_buf: self.read_buf,
            pending_messages: self.pending_messages,
            pending_index: self.pending_index,
            control_tx,
            closed,
        };
        let writer = CompressedSplitWriter {
            writer: write_half,
            protocol: writer_protocol,
            write_buf: BytesMut::with_capacity(1024),
            control_rx,
            closed,
        };

        (reader, writer)
    }
}
1324
#[cfg(feature = "permessage-deflate")]
impl<S> CompressedSplitReader<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Waits for the next data or Close message from the peer.
    ///
    /// Control frames are handled internally:
    /// - a Ping asks the writer half to send a Pong and is not returned;
    /// - a Pong is discarded;
    /// - a Close asks the writer half to echo it, marks this half closed, and is returned.
    ///
    /// Returns `None` once closed or at end of stream. Transport and protocol
    /// errors are returned as `Some(Err(_))` and do not mark the reader closed.
    pub async fn next(&mut self) -> Option<Result<Message>> {
        while !self.closed {
            match self.pop_pending() {
                Some(Message::Ping(payload)) => {
                    // Best effort: if the writer half is gone nobody can reply anyway.
                    let _ = self.control_tx.send(ControlRequest::Pong(payload));
                }
                Some(Message::Pong(_)) => {}
                Some(Message::Close(reason)) => {
                    let _ = self.control_tx.send(ControlRequest::CloseResponse);
                    self.closed = true;
                    return Some(Ok(Message::Close(reason)));
                }
                Some(message) => return Some(Ok(message)),
                None => {
                    if self.read_buf.capacity() - self.read_buf.len() < 4096 {
                        self.read_buf.reserve(8192);
                    }

                    match self.reader.read_buf(&mut self.read_buf).await {
                        // EOF: end the stream.
                        Ok(0) => self.closed = true,
                        Ok(_) => match self.protocol.process(&mut self.read_buf) {
                            Ok(decoded) => {
                                if !decoded.is_empty() {
                                    self.pending_messages = decoded;
                                    self.pending_index = 0;
                                }
                            }
                            Err(e) => return Some(Err(e)),
                        },
                        Err(e) => return Some(Err(e.into())),
                    }
                }
            }
        }

        None
    }

    /// Takes the next queued message, resetting the queue once it is drained.
    fn pop_pending(&mut self) -> Option<Message> {
        let msg = self.pending_messages.get(self.pending_index)?.clone();
        self.pending_index += 1;
        if self.pending_index == self.pending_messages.len() {
            self.pending_messages.clear();
            self.pending_index = 0;
        }
        Some(msg)
    }

    /// Returns `true` once a Close frame or end of stream has been seen.
    pub fn is_closed(&self) -> bool {
        self.closed
    }
}
1414
#[cfg(feature = "permessage-deflate")]
impl<S> CompressedSplitWriter<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Encodes (compressing as configured) and sends `msg`, then flushes.
    ///
    /// Control replies queued by the reader half (Pongs, a Close echo) are written
    /// first, so they keep their place on the wire. Sending a Close message marks
    /// this half closed.
    ///
    /// # Errors
    /// - [`Error::ConnectionClosed`] if this half is already closed.
    /// - [`Error::ConnectionClosed`] if echoing the peer's Close just closed it and
    ///   `msg` is not a Close. If `msg` is a Close, the handshake is already
    ///   complete and `Ok(())` is returned instead.
    /// - Encoding and I/O errors are propagated.
    pub async fn send(&mut self, msg: Message) -> Result<()> {
        if self.closed {
            return Err(Error::ConnectionClosed);
        }

        self.process_control_requests().await?;

        if self.closed {
            // The peer's Close was just echoed. Nothing may follow our Close frame
            // (RFC 6455 §5.5.1), so push the echo out and stop here.
            self.writer.flush().await?;
            return if msg.is_close() {
                Ok(())
            } else {
                Err(Error::ConnectionClosed)
            };
        }

        self.write_buf.clear();
        self.protocol.encode_message(&msg, &mut self.write_buf)?;
        // Mark closed only once the Close frame was actually encoded.
        if msg.is_close() {
            self.closed = true;
        }

        self.writer.write_all(&self.write_buf).await?;
        self.writer.flush().await?;
        Ok(())
    }

    /// Writes every control reply the reader half has queued, without flushing.
    ///
    /// A Close echo is skipped if this half already sent its own Close, so at
    /// most one Close frame is ever written.
    async fn process_control_requests(&mut self) -> Result<()> {
        while let Ok(req) = self.control_rx.try_recv() {
            self.write_buf.clear();

            match req {
                ControlRequest::Pong(data) => {
                    self.protocol.encode_pong(&data, &mut self.write_buf);
                }
                ControlRequest::CloseResponse => {
                    if !self.closed {
                        self.protocol.encode_close_response(&mut self.write_buf);
                        self.closed = true;
                    }
                }
            }

            if !self.write_buf.is_empty() {
                self.writer.write_all(&self.write_buf).await?;
            }
        }

        Ok(())
    }

    /// Sends a text message.
    pub async fn send_text(&mut self, text: impl Into<String>) -> Result<()> {
        self.send(Message::text(text)).await
    }

    /// Sends a binary message.
    pub async fn send_binary(&mut self, data: bytes::Bytes) -> Result<()> {
        self.send(Message::Binary(data)).await
    }

    /// Sends a Close frame with `code` and `reason`, closing this half.
    pub async fn close(&mut self, code: u16, reason: &str) -> Result<()> {
        self.send(Message::Close(Some(CloseReason::new(code, reason))))
            .await
    }

    /// Returns `true` once this half has sent or echoed a Close frame.
    pub fn is_closed(&self) -> bool {
        self.closed
    }

    /// Writes any pending control replies and flushes the transport.
    pub async fn flush(&mut self) -> Result<()> {
        self.process_control_requests().await?;
        self.writer.flush().await.map_err(Into::into)
    }
}
1496
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder settings must be reflected on the built stream.
    #[test]
    fn test_builder() {
        let (stream, _peer) = tokio::io::duplex(1024);
        let ws = WebSocketStreamBuilder::new()
            .role(Role::Server)
            .max_message_size(1024 * 1024)
            .max_frame_size(64 * 1024)
            .high_water_mark(128 * 1024)
            .low_water_mark(32 * 1024)
            .build(stream);

        assert_eq!(ws.high_water_mark(), 128 * 1024);
        assert_eq!(ws.low_water_mark(), 32 * 1024);
        assert!(!ws.is_closed());
        assert_eq!(ws.read_buffer_len(), 0);
        assert_eq!(ws.write_buffer_len(), 0);
        assert!(!ws.is_backpressured());
        assert!(ws.is_write_buffer_low());
    }

    /// A default builder uses the module's default water marks.
    #[test]
    fn test_builder_defaults() {
        let (stream, _peer) = tokio::io::duplex(1024);
        let ws = WebSocketStreamBuilder::default().build(stream);

        assert_eq!(ws.high_water_mark(), DEFAULT_HIGH_WATER_MARK);
        assert_eq!(ws.low_water_mark(), DEFAULT_LOW_WATER_MARK);
    }

    /// Water-mark setters update the values reported by the getters.
    #[test]
    fn test_water_mark_setters() {
        let (stream, _peer) = tokio::io::duplex(1024);
        let mut ws = WebSocketStream::server(stream, Config::default());

        ws.set_high_water_mark(10);
        ws.set_low_water_mark(5);
        assert_eq!(ws.high_water_mark(), 10);
        assert_eq!(ws.low_water_mark(), 5);
    }
}