1use crate::config::WebSocketConfig;
2use crate::encoder::Encoder;
3use crate::error::Error;
4use crate::frame::{Frame, OpCode};
5use crate::message::Message;
6use crate::write::Writer;
7use bytes::BytesMut;
8use futures::Stream;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13use tokio::sync::Mutex;
14use tokio::time::sleep;
15use tokio_stream::wrappers::ReceiverStream;
16
/// Size threshold, in bytes, for permessage-deflate compression: only payloads
/// strictly longer than this are compressed (when the extension is enabled);
/// payloads of this size or smaller are always sent uncompressed.
const PAYLOAD_SIZE_COMPRESSION_ENABLE: usize = 1;
18
/// Read half of a WebSocket connection, exposed as a `Stream` of messages.
///
/// Items are taken as-is from the channel behind `read_rx`. NOTE(review):
/// presumably a separate read task decodes incoming frames and forwards them
/// into that channel — confirm against the code that constructs this type.
pub struct WSReader {
    /// Receiving end of the channel whose items this reader yields.
    read_rx: ReceiverStream<Result<Message, Error>>,
}
22
23impl WSReader {
24 pub fn new(read_rx: ReceiverStream<Result<Message, Error>>) -> Self {
25 Self { read_rx }
26 }
27}
28
29impl Stream for WSReader {
30 type Item = Result<Message, Error>;
31 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
32 let this = self.get_mut();
33 Pin::new(&mut this.read_rx).poll_next(cx)
34 }
35}
36
/// Write half of a WebSocket connection.
///
/// Converts messages into frames (compressing and fragmenting them according
/// to `web_socket_config`) and writes them through a `Writer` that is shared
/// behind an `Arc<Mutex<_>>`.
pub struct WSWriter {
    /// Frame writer; shared, so other holders of the `Arc` may write to the same connection.
    pub writer: Arc<Mutex<Writer>>,
    /// Supplies `max_frame_size`, `max_message_size` and the extension settings used when sending.
    pub web_socket_config: WebSocketConfig,
    /// Compresses payloads when permessage-deflate is enabled in `web_socket_config`.
    encoder: Encoder,
}
42
43impl WSWriter {
44 pub fn new(
45 writer: Arc<Mutex<Writer>>,
46 web_socket_config: WebSocketConfig,
47 encoder: Encoder,
48 ) -> Self {
49 Self {
50 writer,
51 web_socket_config,
52 encoder,
53 }
54 }
55
56 pub async fn close_connection(&mut self) -> Result<(), Error> {
62 self.write_frames(vec![Frame::new(true, OpCode::Close, Vec::new(), false)])
63 .await?;
64
65 sleep(Duration::from_millis(500)).await;
66
67 Ok(())
68
69 }
74
75 pub async fn send_message(&mut self, message: Message) -> Result<(), Error> {
76 self.write_message(message).await
77 }
78
79 pub async fn send(&mut self, data: Vec<u8>) -> Result<(), Error> {
82 self.write_message(Message::Text(String::from_utf8(data)?))
83 .await
84 }
85
86 pub async fn send_as_binary(&mut self, data: Vec<u8>) -> Result<(), Error> {
87 self.write_message(Message::Binary(data)).await
88 }
89
90 pub async fn send_as_text(&mut self, data: String) -> Result<(), Error> {
91 self.write_message(Message::Text(data)).await
92 }
93
94 pub async fn send_ping(&mut self) -> Result<(), Error> {
96 self.write_frames(vec![Frame::new(true, OpCode::Ping, Vec::new(), false)])
97 .await
98 }
99
100 pub async fn send_large_data_fragmented(
103 &mut self,
104 mut data: Vec<u8>,
105 fragment_size: usize,
106 ) -> Result<(), Error> {
107 if fragment_size > self.web_socket_config.max_frame_size.unwrap_or_default() {
111 return Err(Error::CustomFragmentSizeExceeded(
112 fragment_size,
113 self.web_socket_config.max_frame_size.unwrap_or_default(),
114 ));
115 }
116
117 if data.len() > self.web_socket_config.max_message_size.unwrap_or_default() {
118 return Err(Error::MaxMessageSize);
119 }
120
121 let compressed = self.check_compression(&mut data)?;
123
124 let chunks = data.chunks(fragment_size);
125 let total_chunks = chunks.len();
126
127 for (i, chunk) in chunks.enumerate() {
128 let is_final = i == total_chunks - 1;
129 let opcode = if i == 0 {
130 OpCode::Text
131 } else {
132 OpCode::Continue
133 };
134
135 self.write_frames(vec![Frame::new(
136 is_final,
137 opcode,
138 Vec::from(chunk),
139 compressed,
140 )])
141 .await?
142 }
143
144 Ok(())
145 }
146
147 pub(crate) fn check_compression(&mut self, data: &mut Vec<u8>) -> Result<bool, Error> {
148 let mut compressed = false;
149 if self
151 .web_socket_config
152 .extensions
153 .clone()
154 .unwrap_or_default()
155 .permessage_deflate
156 && data.len() > PAYLOAD_SIZE_COMPRESSION_ENABLE
157 {
158 *data = self.encoder.compress(&mut BytesMut::from(&data[..]))?;
159 compressed = true;
160 }
161
162 Ok(compressed)
163 }
164
165 pub(crate) fn convert_to_frames(&mut self, message: Message) -> Result<Vec<Frame>, Error> {
166 let opcode = match message {
167 Message::Text(_) => OpCode::Text,
168 Message::Binary(_) => OpCode::Binary,
169 };
170
171 let mut payload = match message {
172 Message::Text(text) => text.into_bytes(),
173 Message::Binary(data) => data,
174 };
175
176 if payload.is_empty() {
178 return Ok(vec![Frame {
179 final_fragment: true,
180 opcode,
181 payload,
182 compressed: false,
183 }]);
184 }
185
186 let max_frame_size = self.web_socket_config.max_frame_size.unwrap_or_default();
187 let mut frames = Vec::new();
188 let compressed = self.check_compression(&mut payload)?;
190
191 for chunk in payload.chunks(max_frame_size) {
192 frames.push(Frame {
193 final_fragment: false,
194 opcode: if frames.is_empty() {
195 opcode.clone()
196 } else {
197 OpCode::Continue
198 },
199 payload: chunk.to_vec(),
200 compressed,
201 });
202 }
203
204 if let Some(last_frame) = frames.last_mut() {
205 last_frame.final_fragment = true;
206 }
207
208 Ok(frames)
209 }
210
211 pub(crate) async fn write_message(&mut self, message: Message) -> Result<(), Error> {
212 if message.as_binary().len() > self.web_socket_config.max_message_size.unwrap_or_default() {
213 return Err(Error::MaxMessageSize);
214 }
215
216 let frames = self.convert_to_frames(message)?;
217 self.write_frames(frames).await
218 }
219
220 pub(crate) async fn write_frames(&mut self, frames: Vec<Frame>) -> Result<(), Error> {
221 let mut set_rsv1_first_frame = !frames.is_empty() && frames[0].compressed;
224
225 for frame in frames {
226 self.writer
227 .lock()
228 .await
229 .write_frame(frame, set_rsv1_first_frame)
230 .await?;
231 set_rsv1_first_frame = false;
235 }
236 Ok(())
237 }
238}