ttpkit_http/ws/
client.rs

1use std::fmt::Write;
2
3use crate::{
4    Body, Error, Version,
5    client::{IncomingResponse, OutgoingRequest},
6    url::Url,
7    ws::{AgentRole, WebSocket},
8};
9
/// Builder for a client handshake.
///
/// Obtained via `ClientHandshake::builder()`; consumed by `build`, which
/// produces the handshake state together with the outgoing HTTP request.
#[derive(Debug)]
pub struct ClientHandshakeBuilder {
    /// Value sent in the `Sec-WebSocket-Key` request header; the expected
    /// accept token is derived from it when the handshake is built.
    key: String,
    /// Sub-protocols offered to the server via `Sec-WebSocket-Protocol`.
    protocols: Vec<String>,
    /// Maximum input buffer capacity handed over to the resulting WebSocket.
    input_buffer_capacity: usize,
}
16
17impl ClientHandshakeBuilder {
18    /// Create a new builder.
19    fn new() -> Self {
20        Self {
21            key: super::create_key(),
22            protocols: Vec::new(),
23            input_buffer_capacity: 65_536,
24        }
25    }
26
27    /// Indicate that a given WS sub-protocol is being supported by the client.
28    pub fn protocol<T>(mut self, protocol: T) -> Self
29    where
30        T: Into<String>,
31    {
32        self.protocols.push(protocol.into());
33        self
34    }
35
36    /// Set the maximum input buffer capacity (default is 65_536).
37    #[inline]
38    pub fn input_buffer_capacity(mut self, capacity: usize) -> Self {
39        self.input_buffer_capacity = capacity;
40        self
41    }
42
43    /// Build the handshake and prepare an outgoing client request.
44    pub fn build(self, url: Url) -> (ClientHandshake, OutgoingRequest) {
45        let handshake = ClientHandshake {
46            accept: super::create_accept_token(self.key.as_bytes()),
47            protocols: self.protocols,
48            input_buffer_capacity: self.input_buffer_capacity,
49        };
50
51        let mut builder = OutgoingRequest::get(url)
52            .unwrap()
53            .set_version(Version::Version11)
54            .add_header_field(("Connection", "upgrade"))
55            .add_header_field(("Upgrade", "websocket"))
56            .add_header_field(("Sec-WebSocket-Version", "13"))
57            .add_header_field(("Sec-WebSocket-Key", self.key));
58
59        if !handshake.protocols.is_empty() {
60            builder = builder
61                .add_header_field(("Sec-WebSocket-Protocol", join_tokens(&handshake.protocols)));
62        }
63
64        (handshake, builder.body(Body::empty()))
65    }
66}
67
/// Client WS handshake.
///
/// Holds what is needed to validate the server's response to the upgrade
/// request produced by `ClientHandshakeBuilder::build`.
#[derive(Debug)]
pub struct ClientHandshake {
    /// Expected value of the `Sec-WebSocket-Accept` response header.
    accept: String,
    /// Sub-protocols that were offered to the server.
    protocols: Vec<String>,
    /// Maximum input buffer capacity handed over to the resulting WebSocket.
    input_buffer_capacity: usize,
}
74
75impl ClientHandshake {
76    /// Get a builder for the handshake.
77    #[inline]
78    pub fn builder() -> ClientHandshakeBuilder {
79        ClientHandshakeBuilder::new()
80    }
81
82    /// Complete the handshake using a given incoming HTTP response.
83    pub fn complete(self, response: IncomingResponse) -> Result<WebSocket, Error> {
84        assert_eq!(response.status_code(), 101);
85
86        let is_upgrade = response
87            .get_header_field_value("connection")
88            .map(|v| v.as_ref())
89            .unwrap_or(b"")
90            .split(|&b| b == b',')
91            .map(|kw| kw.trim_ascii())
92            .filter(|kw| !kw.is_empty())
93            .any(|kw| kw.eq_ignore_ascii_case(b"upgrade"));
94
95        if !is_upgrade {
96            return Err(Error::from_static_msg("not a connection upgrade"));
97        }
98
99        let is_websocket = response
100            .get_header_field_value("upgrade")
101            .map(|v| v.as_ref())
102            .unwrap_or(b"")
103            .trim_ascii()
104            .eq_ignore_ascii_case(b"websocket");
105
106        if !is_websocket {
107            return Err(Error::from_static_msg("not a WebSocket upgrade"));
108        }
109
110        let accept = response
111            .get_header_field_value("sec-websocket-accept")
112            .ok_or_else(|| Error::from_static_msg("missing WS accept header"))?
113            .as_ref();
114
115        if accept != self.accept.as_bytes() {
116            return Err(Error::from_static_msg("invalid WS accept header"));
117        }
118
119        let protocol = response
120            .get_header_field_value("sec-websocket-protocol")
121            .map(|p| p.trim_ascii());
122
123        if let Some(protocol) = protocol {
124            let is_valid_protocol = self.protocols.iter().any(|p| p.as_bytes() == protocol);
125
126            if !is_valid_protocol {
127                return Err(Error::from_static_msg("invalid WS sub-protocol"));
128            }
129        }
130
131        let extensions = response
132            .get_header_fields("sec-websocket-extensions")
133            .flat_map(|field| {
134                field
135                    .value()
136                    .map(|v| v.as_ref())
137                    .unwrap_or(b"")
138                    .split(|&b| b == b',')
139                    .map(|p| p.trim_ascii())
140                    .filter(|p| !p.is_empty())
141            })
142            .count();
143
144        if extensions != 0 {
145            return Err(Error::from_static_msg("unknown WS extensions"));
146        }
147
148        let upgraded = response
149            .upgrade()
150            .ok_or_else(|| Error::from_static_msg("unable to upgrade the HTTP connection"))?;
151
152        let res = WebSocket::new(upgraded, AgentRole::Client, self.input_buffer_capacity);
153
154        Ok(res)
155    }
156}
157
/// Join a given slice of tokens into a comma separated list.
///
/// Every token is trimmed first. Tokens that are empty after trimming are
/// skipped so that the result never contains empty list elements (no leading,
/// trailing or doubled commas). An empty string is returned when no usable
/// token remains.
fn join_tokens(tokens: &[String]) -> String {
    let mut res = String::new();

    let mut tokens = tokens
        .iter()
        .map(|token| token.trim())
        .filter(|token| !token.is_empty());

    if let Some(token) = tokens.next() {
        res.push_str(token);
    }

    for token in tokens {
        // Writing into a `String` cannot fail, the result is safe to ignore.
        let _ = write!(res, ",{token}");
    }

    res
}