Skip to main content

specter/websocket/
connection.rs

1use std::future::Future;
2use std::time::Duration;
3
4use bytes::{Bytes, BytesMut};
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::time::timeout as tokio_timeout;
7use url::Url;
8
9use crate::transport::connector::MaybeHttpsStream;
10use crate::websocket::error::{WebSocketError, WebSocketResult};
11use crate::websocket::frame::{decode_frame, encode_frame, FrameConfig, FrameDecoder, OpCode};
12use crate::websocket::message::{CloseFrame, Message};
13use crate::websocket::WebSocketConfig;
14
15#[derive(Debug)]
16pub struct WebSocket {
17    stream: MaybeHttpsStream,
18    url: Url,
19    protocol: Option<String>,
20    read_buffer: BytesMut,
21    frame_config: FrameConfig,
22    read_timeout: Option<Duration>,
23    write_timeout: Option<Duration>,
24    decoder: FrameDecoder,
25    close_sent: bool,
26    close_received: bool,
27}
28
29impl WebSocket {
30    pub(crate) fn new(
31        stream: MaybeHttpsStream,
32        url: Url,
33        protocol: Option<String>,
34        config: WebSocketConfig,
35        initial_read_buffer: Bytes,
36    ) -> Self {
37        Self {
38            stream,
39            url,
40            protocol,
41            read_buffer: BytesMut::from(&initial_read_buffer[..]),
42            frame_config: FrameConfig::new(config.max_frame_size, config.max_message_size),
43            read_timeout: config.read_timeout,
44            write_timeout: config.write_timeout,
45            decoder: FrameDecoder::new(),
46            close_sent: false,
47            close_received: false,
48        }
49    }
50
51    pub fn url(&self) -> &Url {
52        &self.url
53    }
54
55    pub fn protocol(&self) -> Option<&str> {
56        self.protocol.as_deref()
57    }
58
59    pub async fn send(&mut self, msg: Message) -> WebSocketResult<()> {
60        if self.close_sent && !matches!(msg, Message::Close(_)) {
61            return Err(WebSocketError::protocol(
62                &self.url,
63                "cannot send data after close frame",
64            ));
65        }
66
67        match msg {
68            Message::Text(text) => self.write_frame(OpCode::Text, text.as_bytes()).await,
69            Message::Binary(bytes) => self.write_frame(OpCode::Binary, &bytes).await,
70            Message::Ping(bytes) => self.write_control(OpCode::Ping, &bytes).await,
71            Message::Pong(bytes) => self.write_control(OpCode::Pong, &bytes).await,
72            Message::Close(frame) => self.close(frame).await,
73        }
74    }
75
76    pub async fn send_text(&mut self, text: impl Into<String>) -> WebSocketResult<()> {
77        self.send(Message::Text(text.into())).await
78    }
79
80    pub async fn send_binary(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
81        self.send(Message::Binary(bytes.into())).await
82    }
83
84    pub async fn next(&mut self) -> WebSocketResult<Option<Message>> {
85        loop {
86            let frame = match decode_frame(&self.url, &mut self.read_buffer, self.frame_config) {
87                Ok(frame) => frame,
88                Err(error) => return Err(self.best_effort_close_for_error(error).await),
89            };
90
91            if let Some(frame) = frame {
92                let message = match self
93                    .decoder
94                    .decode_message(&self.url, frame, self.frame_config)
95                {
96                    Ok(message) => message,
97                    Err(error) => return Err(self.best_effort_close_for_error(error).await),
98                };
99
100                match message {
101                    Some(Message::Ping(payload)) => {
102                        if !self.close_received {
103                            self.write_control(OpCode::Pong, &payload).await?;
104                        }
105                        return Ok(Some(Message::Ping(payload)));
106                    }
107                    Some(Message::Close(frame)) => {
108                        self.close_received = true;
109                        if !self.close_sent {
110                            self.send_close_raw(frame.clone()).await?;
111                        }
112                        return Ok(None);
113                    }
114                    Some(other) => return Ok(Some(other)),
115                    None => {}
116                }
117            } else {
118                let mut scratch = [0_u8; 8192];
119                let n = Self::io_with_timeout(
120                    &self.url,
121                    self.read_timeout,
122                    "read",
123                    self.stream.read(&mut scratch),
124                )
125                .await?;
126                if n == 0 {
127                    return if self.close_sent || self.close_received {
128                        Ok(None)
129                    } else {
130                        Err(WebSocketError::connection_closed(&self.url))
131                    };
132                }
133                self.read_buffer.extend_from_slice(&scratch[..n]);
134            }
135        }
136    }
137
138    pub async fn close(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
139        if !self.close_sent {
140            self.send_close_raw(frame).await?;
141        }
142        Ok(())
143    }
144
145    async fn write_frame(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
146        if payload.len() > self.frame_config.max_frame_size {
147            return Err(WebSocketError::limit_exceeded(
148                &self.url,
149                format!("frame exceeds {} bytes", self.frame_config.max_frame_size),
150            ));
151        }
152        if matches!(opcode, OpCode::Text | OpCode::Binary)
153            && payload.len() > self.frame_config.max_message_size
154        {
155            return Err(WebSocketError::limit_exceeded(
156                &self.url,
157                format!(
158                    "message exceeds {} bytes",
159                    self.frame_config.max_message_size
160                ),
161            ));
162        }
163        let bytes = encode_frame(opcode, payload, true)?;
164        Self::io_with_timeout(
165            &self.url,
166            self.write_timeout,
167            "write",
168            self.stream.write_all(&bytes),
169        )
170        .await?;
171        Self::io_with_timeout(&self.url, self.write_timeout, "flush", self.stream.flush()).await
172    }
173
174    async fn write_control(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
175        if payload.len() > 125 {
176            return Err(WebSocketError::protocol(
177                &self.url,
178                "control frame payload exceeds 125 bytes",
179            ));
180        }
181        self.write_frame(opcode, payload).await
182    }
183
184    async fn send_close_raw(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
185        let payload = match frame {
186            Some(frame) => frame.encode(&self.url)?,
187            None => Vec::new(),
188        };
189        self.write_control(OpCode::Close, &payload).await?;
190        self.close_sent = true;
191        Ok(())
192    }
193
194    async fn best_effort_close_for_error(&mut self, error: WebSocketError) -> WebSocketError {
195        if let Some(code) = error.close_code() {
196            if !self.close_sent {
197                let frame = CloseFrame {
198                    code,
199                    reason: String::new(),
200                };
201                let _ = self.send_close_raw(Some(frame)).await;
202            }
203        }
204        error
205    }
206
207    async fn io_with_timeout<T, F>(
208        url: &Url,
209        timeout: Option<Duration>,
210        operation: &'static str,
211        future: F,
212    ) -> WebSocketResult<T>
213    where
214        F: Future<Output = std::io::Result<T>>,
215    {
216        let result = match timeout {
217            Some(duration) => {
218                tokio_timeout(duration, future)
219                    .await
220                    .map_err(|_| WebSocketError::Timeout {
221                        url: url.to_string(),
222                        operation: format!("{operation} after {:?}", duration),
223                    })?
224            }
225            None => future.await,
226        };
227
228        result.map_err(|error| WebSocketError::io(url, error))
229    }
230}