socket_flow/
split.rs

1use crate::config::WebSocketConfig;
2use crate::encoder::Encoder;
3use crate::error::Error;
4use crate::frame::{Frame, OpCode};
5use crate::message::Message;
6use crate::write::Writer;
7use bytes::BytesMut;
8use futures::Stream;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13use tokio::sync::Mutex;
14use tokio::time::sleep;
15use tokio_stream::wrappers::ReceiverStream;
16
/// Compression threshold, in bytes, for outgoing payloads.
///
/// When the `permessage_deflate` extension is enabled, only payloads strictly longer than this
/// many bytes are compressed. With the current value of 1, every payload of two bytes or more
/// is compressed.
const PAYLOAD_SIZE_COMPRESSION_ENABLE: usize = 1;
18
19pub struct WSReader {
20    read_rx: ReceiverStream<Result<Message, Error>>,
21}
22
23impl WSReader {
24    pub fn new(read_rx: ReceiverStream<Result<Message, Error>>) -> Self {
25        Self { read_rx }
26    }
27}
28
29impl Stream for WSReader {
30    type Item = Result<Message, Error>;
31    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
32        let this = self.get_mut();
33        Pin::new(&mut this.read_rx).poll_next(cx)
34    }
35}
36
37pub struct WSWriter {
38    pub writer: Arc<Mutex<Writer>>,
39    pub web_socket_config: WebSocketConfig,
40    encoder: Encoder,
41}
42
43impl WSWriter {
44    pub fn new(
45        writer: Arc<Mutex<Writer>>,
46        web_socket_config: WebSocketConfig,
47        encoder: Encoder,
48    ) -> Self {
49        Self {
50            writer,
51            web_socket_config,
52            encoder,
53        }
54    }
55
56    /// This function will be used for closing the connection between two instances, mainly it will
57    /// be used by a client,
58    /// to request disconnection with a server.It first sends a close frame
59    /// through the socket, and waits until it receives the confirmation in a channel
60    /// executing it inside a timeout, to avoid a long waiting time
61    pub async fn close_connection(&mut self) -> Result<(), Error> {
62        self.write_frames(vec![Frame::new(true, OpCode::Close, Vec::new(), false)])
63            .await?;
64
65        sleep(Duration::from_millis(500)).await;
66
67        Ok(())
68
69        // match timeout(Duration::from_secs(CLOSE_TIMEOUT), self.write.lock().await.closed()).await {
70        //     Err(err) => Err(err)?,
71        //     _ => Ok(()),
72        // }
73    }
74
75    pub async fn send_message(&mut self, message: Message) -> Result<(), Error> {
76        self.write_message(message).await
77    }
78
79    // This function will be used to send general data as a Vector of bytes, and by default will
80    // be sent as a text opcode
81    pub async fn send(&mut self, data: Vec<u8>) -> Result<(), Error> {
82        self.write_message(Message::Text(String::from_utf8(data)?))
83            .await
84    }
85
86    pub async fn send_as_binary(&mut self, data: Vec<u8>) -> Result<(), Error> {
87        self.write_message(Message::Binary(data)).await
88    }
89
90    pub async fn send_as_text(&mut self, data: String) -> Result<(), Error> {
91        self.write_message(Message::Text(data)).await
92    }
93
94    // It will send a ping frame through the socket
95    pub async fn send_ping(&mut self) -> Result<(), Error> {
96        self.write_frames(vec![Frame::new(true, OpCode::Ping, Vec::new(), false)])
97            .await
98    }
99
100    // This function can be used to send large payloads, that will be divided in chunks using fragmented
101    // messages, and Continue opcode
102    pub async fn send_large_data_fragmented(
103        &mut self,
104        mut data: Vec<u8>,
105        fragment_size: usize,
106    ) -> Result<(), Error> {
107        // Each fragment size will be limited by max_frame_size config,
108        // that had been given by the user,
109        // or it will use the default max frame size which is 16 MiB.
110        if fragment_size > self.web_socket_config.max_frame_size.unwrap_or_default() {
111            return Err(Error::CustomFragmentSizeExceeded(
112                fragment_size,
113                self.web_socket_config.max_frame_size.unwrap_or_default(),
114            ));
115        }
116
117        if data.len() > self.web_socket_config.max_message_size.unwrap_or_default() {
118            return Err(Error::MaxMessageSize);
119        }
120
121        // This function will check if compression is enabled, and apply if needed
122        let compressed = self.check_compression(&mut data)?;
123
124        let chunks = data.chunks(fragment_size);
125        let total_chunks = chunks.len();
126
127        for (i, chunk) in chunks.enumerate() {
128            let is_final = i == total_chunks - 1;
129            let opcode = if i == 0 {
130                OpCode::Text
131            } else {
132                OpCode::Continue
133            };
134
135            self.write_frames(vec![Frame::new(
136                is_final,
137                opcode,
138                Vec::from(chunk),
139                compressed,
140            )])
141            .await?
142        }
143
144        Ok(())
145    }
146
147    pub(crate) fn check_compression(&mut self, data: &mut Vec<u8>) -> Result<bool, Error> {
148        let mut compressed = false;
149        // If compression is enabled, and the payload is greater than 8KB, compress the payload
150        if self
151            .web_socket_config
152            .extensions
153            .clone()
154            .unwrap_or_default()
155            .permessage_deflate
156            && data.len() > PAYLOAD_SIZE_COMPRESSION_ENABLE
157        {
158            *data = self.encoder.compress(&mut BytesMut::from(&data[..]))?;
159            compressed = true;
160        }
161
162        Ok(compressed)
163    }
164
165    pub(crate) fn convert_to_frames(&mut self, message: Message) -> Result<Vec<Frame>, Error> {
166        let opcode = match message {
167            Message::Text(_) => OpCode::Text,
168            Message::Binary(_) => OpCode::Binary,
169        };
170
171        let mut payload = match message {
172            Message::Text(text) => text.into_bytes(),
173            Message::Binary(data) => data,
174        };
175
176        // Empty payloads aren't compressed
177        if payload.is_empty() {
178            return Ok(vec![Frame {
179                final_fragment: true,
180                opcode,
181                payload,
182                compressed: false,
183            }]);
184        }
185
186        let max_frame_size = self.web_socket_config.max_frame_size.unwrap_or_default();
187        let mut frames = Vec::new();
188        // This function will check if compression is enabled, and apply if needed
189        let compressed = self.check_compression(&mut payload)?;
190
191        for chunk in payload.chunks(max_frame_size) {
192            frames.push(Frame {
193                final_fragment: false,
194                opcode: if frames.is_empty() {
195                    opcode.clone()
196                } else {
197                    OpCode::Continue
198                },
199                payload: chunk.to_vec(),
200                compressed,
201            });
202        }
203
204        if let Some(last_frame) = frames.last_mut() {
205            last_frame.final_fragment = true;
206        }
207
208        Ok(frames)
209    }
210
211    pub(crate) async fn write_message(&mut self, message: Message) -> Result<(), Error> {
212        if message.as_binary().len() > self.web_socket_config.max_message_size.unwrap_or_default() {
213            return Err(Error::MaxMessageSize);
214        }
215
216        let frames = self.convert_to_frames(message)?;
217        self.write_frames(frames).await
218    }
219
220    pub(crate) async fn write_frames(&mut self, frames: Vec<Frame>) -> Result<(), Error> {
221        // For compressed messages, regardless if it's fragmented or not, we always set the RSV1 bit
222        // for the first frame.
223        let mut set_rsv1_first_frame = !frames.is_empty() && frames[0].compressed;
224
225        for frame in frames {
226            self.writer
227                .lock()
228                .await
229                .write_frame(frame, set_rsv1_first_frame)
230                .await?;
231            // Setting it to false,
232            // since we only need
233            // to set RSV1 bit for the first frame if compression is enabled
234            set_rsv1_first_frame = false;
235        }
236        Ok(())
237    }
238}