Skip to main content

specter/websocket/
connection.rs

1use std::future::Future;
2use std::time::Duration;
3
4use crate::url::Url;
5use bytes::{Bytes, BytesMut};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
7use tokio::time::timeout as tokio_timeout;
8
9use crate::transport::connector::MaybeHttpsStream;
10use crate::websocket::error::{WebSocketError, WebSocketResult};
11use crate::websocket::frame::{
12    decode_frame, encode_frame_append, encode_frame_into, Frame, FrameConfig, FrameDecoder,
13    MaskRng, OpCode,
14};
15use crate::websocket::message::{CloseFrame, Message, PreparedMessage};
16use crate::websocket::WebSocketConfig;
17
/// Spare capacity (16 KiB) reserved in a read buffer before each socket read.
/// The connection also uses it as the initial write-buffer capacity.
const READ_CHUNK_SIZE: usize = 16 << 10;
/// Minimum capacity (16 KiB) pre-allocated for a connection's read buffer.
const INITIAL_READ_CAPACITY: usize = 16 << 10;
20
/// Opcode of a single frame returned by the raw `next_frame` APIs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WebSocketFrameOpcode {
    /// Continuation of a fragmented message.
    Continuation,
    /// First (or only) frame of a text message.
    Text,
    /// First (or only) frame of a binary message.
    Binary,
    /// Close control frame.
    Close,
    /// Ping control frame.
    Ping,
    /// Pong control frame.
    Pong,
}
30
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct WebSocketFrame {
33    pub fin: bool,
34    pub opcode: WebSocketFrameOpcode,
35    pub payload: Bytes,
36}
37
38impl From<OpCode> for WebSocketFrameOpcode {
39    fn from(value: OpCode) -> Self {
40        match value {
41            OpCode::Continuation => Self::Continuation,
42            OpCode::Text => Self::Text,
43            OpCode::Binary => Self::Binary,
44            OpCode::Close => Self::Close,
45            OpCode::Ping => Self::Ping,
46            OpCode::Pong => Self::Pong,
47        }
48    }
49}
50
51impl From<Frame> for WebSocketFrame {
52    fn from(frame: Frame) -> Self {
53        Self {
54            fin: frame.fin,
55            opcode: WebSocketFrameOpcode::from(frame.opcode),
56            payload: frame.payload,
57        }
58    }
59}
60
61#[derive(Debug)]
62pub struct WebSocket {
63    stream: MaybeHttpsStream,
64    url: Url,
65    protocol: Option<String>,
66    read_buffer: BytesMut,
67    write_buffer: BytesMut,
68    frame_config: FrameConfig,
69    read_timeout: Option<Duration>,
70    write_timeout: Option<Duration>,
71    decoder: FrameDecoder,
72    mask_rng: MaskRng,
73    close_sent: bool,
74    close_received: bool,
75}
76
77#[derive(Debug)]
78pub struct WebSocketReader {
79    stream: ReadHalf<MaybeHttpsStream>,
80    url: Url,
81    read_buffer: BytesMut,
82    frame_config: FrameConfig,
83    read_timeout: Option<Duration>,
84    decoder: FrameDecoder,
85    close_received: bool,
86}
87
88#[derive(Debug)]
89pub struct WebSocketWriter {
90    stream: WriteHalf<MaybeHttpsStream>,
91    url: Url,
92    write_buffer: BytesMut,
93    frame_config: FrameConfig,
94    write_timeout: Option<Duration>,
95    mask_rng: MaskRng,
96    close_sent: bool,
97}
98
99impl WebSocket {
100    pub(crate) fn new(
101        stream: MaybeHttpsStream,
102        url: Url,
103        protocol: Option<String>,
104        config: WebSocketConfig,
105        initial_read_buffer: Bytes,
106    ) -> Self {
107        // Pre-allocate the read buffer so the first frame doesn't pay the
108        // grow-from-zero cost. Carries over any bytes left in the handshake
109        // buffer (typically empty).
110        let mut read_buffer =
111            BytesMut::with_capacity(INITIAL_READ_CAPACITY.max(initial_read_buffer.len()));
112        read_buffer.extend_from_slice(&initial_read_buffer);
113        Self {
114            stream,
115            url,
116            protocol,
117            read_buffer,
118            write_buffer: BytesMut::with_capacity(READ_CHUNK_SIZE),
119            frame_config: FrameConfig::new(config.max_frame_size, config.max_message_size),
120            read_timeout: config.read_timeout,
121            write_timeout: config.write_timeout,
122            decoder: FrameDecoder::new(),
123            mask_rng: MaskRng::new(),
124            close_sent: false,
125            close_received: false,
126        }
127    }
128
129    pub fn url(&self) -> &Url {
130        &self.url
131    }
132
133    pub fn protocol(&self) -> Option<&str> {
134        self.protocol.as_deref()
135    }
136
137    pub fn split(self) -> (WebSocketReader, WebSocketWriter) {
138        let (read_stream, write_stream) = tokio::io::split(self.stream);
139        let reader = WebSocketReader {
140            stream: read_stream,
141            url: self.url.clone(),
142            read_buffer: self.read_buffer,
143            frame_config: self.frame_config,
144            read_timeout: self.read_timeout,
145            decoder: self.decoder,
146            close_received: self.close_received,
147        };
148        let writer = WebSocketWriter {
149            stream: write_stream,
150            url: self.url,
151            write_buffer: self.write_buffer,
152            frame_config: self.frame_config,
153            write_timeout: self.write_timeout,
154            mask_rng: self.mask_rng,
155            close_sent: self.close_sent,
156        };
157        (reader, writer)
158    }
159
160    pub async fn send(&mut self, msg: Message) -> WebSocketResult<()> {
161        if self.close_sent && !matches!(msg, Message::Close(_)) {
162            return Err(WebSocketError::protocol(
163                &self.url,
164                "cannot send data after close frame",
165            ));
166        }
167
168        match msg {
169            Message::Text(text) => self.write_frame(OpCode::Text, text.as_bytes()).await,
170            Message::Binary(bytes) => self.write_frame(OpCode::Binary, &bytes).await,
171            Message::Ping(bytes) => self.write_control(OpCode::Ping, &bytes).await,
172            Message::Pong(bytes) => self.write_control(OpCode::Pong, &bytes).await,
173            Message::Close(frame) => self.close(frame).await,
174        }
175    }
176
177    pub async fn send_text(&mut self, text: impl Into<String>) -> WebSocketResult<()> {
178        self.send(Message::Text(text.into())).await
179    }
180
181    pub async fn send_binary(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
182        self.send(Message::Binary(bytes.into())).await
183    }
184
185    pub async fn send_prepared(&mut self, message: &PreparedMessage) -> WebSocketResult<()> {
186        match message {
187            PreparedMessage::Text(bytes) => self.write_frame(OpCode::Text, bytes).await,
188            PreparedMessage::Binary(bytes) => self.write_frame(OpCode::Binary, bytes).await,
189        }
190    }
191
192    pub async fn send_prepared_batch<'a>(
193        &mut self,
194        messages: impl IntoIterator<Item = &'a PreparedMessage>,
195    ) -> WebSocketResult<()> {
196        self.write_prepared_batch(messages).await
197    }
198
199    pub async fn next_frame(&mut self) -> WebSocketResult<Option<WebSocketFrame>> {
200        Self::read_next_frame(
201            &self.url,
202            self.read_timeout,
203            &mut self.stream,
204            &mut self.read_buffer,
205            self.frame_config,
206        )
207        .await
208    }
209
210    pub async fn next(&mut self) -> WebSocketResult<Option<Message>> {
211        loop {
212            let frame = match decode_frame(&self.url, &mut self.read_buffer, self.frame_config) {
213                Ok(frame) => frame,
214                Err(error) => return Err(self.best_effort_close_for_error(error).await),
215            };
216
217            if let Some(frame) = frame {
218                let message = match self
219                    .decoder
220                    .decode_message(&self.url, frame, self.frame_config)
221                {
222                    Ok(message) => message,
223                    Err(error) => return Err(self.best_effort_close_for_error(error).await),
224                };
225
226                match message {
227                    Some(Message::Ping(payload)) => {
228                        if !self.close_received {
229                            self.write_control(OpCode::Pong, &payload).await?;
230                        }
231                        return Ok(Some(Message::Ping(payload)));
232                    }
233                    Some(Message::Close(frame)) => {
234                        self.close_received = true;
235                        if !self.close_sent {
236                            self.send_close_raw(frame.clone()).await?;
237                        }
238                        return Ok(None);
239                    }
240                    Some(other) => return Ok(Some(other)),
241                    None => {}
242                }
243            } else {
244                self.read_buffer.reserve(READ_CHUNK_SIZE);
245                let n = Self::io_with_timeout(
246                    &self.url,
247                    self.read_timeout,
248                    "read",
249                    self.stream.read_buf(&mut self.read_buffer),
250                )
251                .await?;
252                if n == 0 {
253                    return if self.close_sent || self.close_received {
254                        Ok(None)
255                    } else {
256                        Err(WebSocketError::connection_closed(&self.url))
257                    };
258                }
259            }
260        }
261    }
262
263    pub async fn close(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
264        if !self.close_sent {
265            self.send_close_raw(frame).await?;
266        }
267        Ok(())
268    }
269
270    async fn write_frame(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
271        validate_outbound_payload(&self.url, self.frame_config, opcode, payload)?;
272        encode_frame_into(opcode, payload, &mut self.mask_rng, &mut self.write_buffer);
273        Self::io_with_timeout(
274            &self.url,
275            self.write_timeout,
276            "write",
277            self.stream.write_all(&self.write_buffer),
278        )
279        .await
280    }
281
282    async fn write_prepared_batch<'a>(
283        &mut self,
284        messages: impl IntoIterator<Item = &'a PreparedMessage>,
285    ) -> WebSocketResult<()> {
286        // Encode all frames into the per-connection write_buffer in one pass.
287        // The earlier implementation built a fresh BytesMut for the batch and
288        // copied each encoded frame into it; with encode_frame_append we
289        // write into the live buffer directly, saving an allocation per call
290        // and a full memcpy per batched frame.
291        self.write_buffer.clear();
292        for message in messages {
293            let (opcode, payload) = match message {
294                PreparedMessage::Text(bytes) => (OpCode::Text, bytes.as_ref()),
295                PreparedMessage::Binary(bytes) => (OpCode::Binary, bytes.as_ref()),
296            };
297            validate_outbound_payload(&self.url, self.frame_config, opcode, payload)?;
298            encode_frame_append(opcode, payload, &mut self.mask_rng, &mut self.write_buffer);
299        }
300        if self.write_buffer.is_empty() {
301            return Ok(());
302        }
303        Self::io_with_timeout(
304            &self.url,
305            self.write_timeout,
306            "write",
307            self.stream.write_all(&self.write_buffer),
308        )
309        .await
310    }
311
312    async fn write_control(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
313        if payload.len() > 125 {
314            return Err(WebSocketError::protocol(
315                &self.url,
316                "control frame payload exceeds 125 bytes",
317            ));
318        }
319        self.write_frame(opcode, payload).await?;
320        Self::io_with_timeout(&self.url, self.write_timeout, "flush", self.stream.flush()).await
321    }
322
323    async fn send_close_raw(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
324        let payload = match frame {
325            Some(frame) => frame.encode(&self.url)?,
326            None => Vec::new(),
327        };
328        self.write_control(OpCode::Close, &payload).await?;
329        self.close_sent = true;
330        Ok(())
331    }
332
333    async fn best_effort_close_for_error(&mut self, error: WebSocketError) -> WebSocketError {
334        if let Some(code) = error.close_code() {
335            if !self.close_sent {
336                let frame = CloseFrame {
337                    code,
338                    reason: String::new(),
339                };
340                let _ = self.send_close_raw(Some(frame)).await;
341            }
342        }
343        error
344    }
345
346    async fn io_with_timeout<T, F>(
347        url: &Url,
348        timeout: Option<Duration>,
349        operation: &'static str,
350        future: F,
351    ) -> WebSocketResult<T>
352    where
353        F: Future<Output = std::io::Result<T>>,
354    {
355        let result = match timeout {
356            Some(duration) => {
357                tokio_timeout(duration, future)
358                    .await
359                    .map_err(|_| WebSocketError::Timeout {
360                        url: url.to_string(),
361                        operation: format!("{operation} after {:?}", duration),
362                    })?
363            }
364            None => future.await,
365        };
366
367        result.map_err(|error| WebSocketError::io(url, error))
368    }
369
370    async fn read_next_frame<S>(
371        url: &Url,
372        read_timeout: Option<Duration>,
373        stream: &mut S,
374        read_buffer: &mut BytesMut,
375        frame_config: FrameConfig,
376    ) -> WebSocketResult<Option<WebSocketFrame>>
377    where
378        S: tokio::io::AsyncRead + Unpin,
379    {
380        loop {
381            if let Some(frame) = decode_frame(url, read_buffer, frame_config)? {
382                return Ok(Some(WebSocketFrame {
383                    fin: frame.fin,
384                    opcode: frame.opcode.into(),
385                    payload: frame.payload,
386                }));
387            }
388            read_buffer.reserve(READ_CHUNK_SIZE);
389            let n = Self::io_with_timeout(url, read_timeout, "read", stream.read_buf(read_buffer))
390                .await?;
391            if n == 0 {
392                return Ok(None);
393            }
394        }
395    }
396}
397
398impl WebSocketReader {
399    pub async fn next_frame(&mut self) -> WebSocketResult<Option<WebSocketFrame>> {
400        WebSocket::read_next_frame(
401            &self.url,
402            self.read_timeout,
403            &mut self.stream,
404            &mut self.read_buffer,
405            self.frame_config,
406        )
407        .await
408    }
409
410    pub async fn next(&mut self) -> WebSocketResult<Option<Message>> {
411        loop {
412            let frame = decode_frame(&self.url, &mut self.read_buffer, self.frame_config)?;
413            if let Some(frame) = frame {
414                let message = self
415                    .decoder
416                    .decode_message(&self.url, frame, self.frame_config)?;
417                match message {
418                    Some(Message::Close(_)) => {
419                        self.close_received = true;
420                        return Ok(None);
421                    }
422                    Some(other) => return Ok(Some(other)),
423                    None => {}
424                }
425            } else {
426                self.read_buffer.reserve(READ_CHUNK_SIZE);
427                let n = WebSocket::io_with_timeout(
428                    &self.url,
429                    self.read_timeout,
430                    "read",
431                    self.stream.read_buf(&mut self.read_buffer),
432                )
433                .await?;
434                if n == 0 {
435                    return if self.close_received {
436                        Ok(None)
437                    } else {
438                        Err(WebSocketError::connection_closed(&self.url))
439                    };
440                }
441            }
442        }
443    }
444}
445
446impl WebSocketWriter {
447    pub async fn send(&mut self, msg: Message) -> WebSocketResult<()> {
448        if self.close_sent && !matches!(msg, Message::Close(_)) {
449            return Err(WebSocketError::protocol(
450                &self.url,
451                "cannot send data after close frame",
452            ));
453        }
454
455        match msg {
456            Message::Text(text) => self.write_frame(OpCode::Text, text.as_bytes()).await,
457            Message::Binary(bytes) => self.write_frame(OpCode::Binary, &bytes).await,
458            Message::Ping(bytes) => self.write_control(OpCode::Ping, &bytes).await,
459            Message::Pong(bytes) => self.write_control(OpCode::Pong, &bytes).await,
460            Message::Close(frame) => self.close(frame).await,
461        }
462    }
463
464    pub async fn send_text(&mut self, text: impl Into<String>) -> WebSocketResult<()> {
465        self.send(Message::Text(text.into())).await
466    }
467
468    pub async fn send_binary(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
469        self.send(Message::Binary(bytes.into())).await
470    }
471
472    pub async fn send_prepared(&mut self, message: &PreparedMessage) -> WebSocketResult<()> {
473        match message {
474            PreparedMessage::Text(bytes) => self.write_frame(OpCode::Text, bytes).await,
475            PreparedMessage::Binary(bytes) => self.write_frame(OpCode::Binary, bytes).await,
476        }
477    }
478
479    pub async fn send_prepared_batch<'a>(
480        &mut self,
481        messages: impl IntoIterator<Item = &'a PreparedMessage>,
482    ) -> WebSocketResult<()> {
483        self.write_prepared_batch(messages).await
484    }
485
486    pub async fn send_ping(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
487        self.send(Message::Ping(bytes.into())).await
488    }
489
490    pub async fn send_pong(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
491        self.send(Message::Pong(bytes.into())).await
492    }
493
494    pub async fn close(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
495        if !self.close_sent {
496            self.send_close_raw(frame).await?;
497        }
498        Ok(())
499    }
500
501    async fn write_frame(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
502        validate_outbound_payload(&self.url, self.frame_config, opcode, payload)?;
503        encode_frame_into(opcode, payload, &mut self.mask_rng, &mut self.write_buffer);
504        WebSocket::io_with_timeout(
505            &self.url,
506            self.write_timeout,
507            "write",
508            self.stream.write_all(&self.write_buffer),
509        )
510        .await
511    }
512
513    async fn write_prepared_batch<'a>(
514        &mut self,
515        messages: impl IntoIterator<Item = &'a PreparedMessage>,
516    ) -> WebSocketResult<()> {
517        // See WebSocket::write_prepared_batch for the rationale: encode all
518        // frames into the per-writer write_buffer in one pass via
519        // encode_frame_append, saving the batch allocation and per-frame
520        // memcpy that the prior implementation paid.
521        self.write_buffer.clear();
522        for message in messages {
523            let (opcode, payload) = match message {
524                PreparedMessage::Text(bytes) => (OpCode::Text, bytes.as_ref()),
525                PreparedMessage::Binary(bytes) => (OpCode::Binary, bytes.as_ref()),
526            };
527            validate_outbound_payload(&self.url, self.frame_config, opcode, payload)?;
528            encode_frame_append(opcode, payload, &mut self.mask_rng, &mut self.write_buffer);
529        }
530        if self.write_buffer.is_empty() {
531            return Ok(());
532        }
533        WebSocket::io_with_timeout(
534            &self.url,
535            self.write_timeout,
536            "write",
537            self.stream.write_all(&self.write_buffer),
538        )
539        .await
540    }
541
542    async fn write_control(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
543        if payload.len() > 125 {
544            return Err(WebSocketError::protocol(
545                &self.url,
546                "control frame payload exceeds 125 bytes",
547            ));
548        }
549        self.write_frame(opcode, payload).await?;
550        WebSocket::io_with_timeout(&self.url, self.write_timeout, "flush", self.stream.flush())
551            .await
552    }
553
554    async fn send_close_raw(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
555        let payload = match frame {
556            Some(frame) => frame.encode(&self.url)?,
557            None => Vec::new(),
558        };
559        self.write_control(OpCode::Close, &payload).await?;
560        self.close_sent = true;
561        Ok(())
562    }
563}
564
565fn validate_outbound_payload(
566    url: &Url,
567    frame_config: FrameConfig,
568    opcode: OpCode,
569    payload: &[u8],
570) -> WebSocketResult<()> {
571    if payload.len() > frame_config.max_frame_size {
572        return Err(WebSocketError::limit_exceeded(
573            url,
574            format!("frame exceeds {} bytes", frame_config.max_frame_size),
575        ));
576    }
577    if matches!(opcode, OpCode::Text | OpCode::Binary)
578        && payload.len() > frame_config.max_message_size
579    {
580        return Err(WebSocketError::limit_exceeded(
581            url,
582            format!("message exceeds {} bytes", frame_config.max_message_size),
583        ));
584    }
585    Ok(())
586}