1use std::fmt::Write;
2
3use crate::{
4 Body, Error, Version,
5 client::{IncomingResponse, OutgoingRequest},
6 url::Url,
7 ws::{AgentRole, WebSocket},
8};
9
/// Builder for a client-side WebSocket handshake.
///
/// Collects the handshake parameters and produces both the pending
/// [`ClientHandshake`] and the HTTP upgrade request that has to be sent to the
/// server.
pub struct ClientHandshakeBuilder {
    /// Sub-protocols offered to the server, in the order they were added.
    protocols: Vec<String>,
    /// Value sent in the `Sec-WebSocket-Key` request header field.
    key: String,
    /// Input buffer capacity handed over to the resulting WebSocket.
    input_buffer_capacity: usize,
}
16
17impl ClientHandshakeBuilder {
18 fn new() -> Self {
20 Self {
21 key: super::create_key(),
22 protocols: Vec::new(),
23 input_buffer_capacity: 65_536,
24 }
25 }
26
27 pub fn protocol<T>(mut self, protocol: T) -> Self
29 where
30 T: Into<String>,
31 {
32 self.protocols.push(protocol.into());
33 self
34 }
35
36 #[inline]
38 pub fn input_buffer_capacity(mut self, capacity: usize) -> Self {
39 self.input_buffer_capacity = capacity;
40 self
41 }
42
43 pub fn build(self, url: Url) -> (ClientHandshake, OutgoingRequest) {
45 let handshake = ClientHandshake {
46 accept: super::create_accept_token(self.key.as_bytes()),
47 protocols: self.protocols,
48 input_buffer_capacity: self.input_buffer_capacity,
49 };
50
51 let mut builder = OutgoingRequest::get(url)
52 .unwrap()
53 .set_version(Version::Version11)
54 .add_header_field(("Connection", "upgrade"))
55 .add_header_field(("Upgrade", "websocket"))
56 .add_header_field(("Sec-WebSocket-Version", "13"))
57 .add_header_field(("Sec-WebSocket-Key", self.key));
58
59 if !handshake.protocols.is_empty() {
60 builder = builder
61 .add_header_field(("Sec-WebSocket-Protocol", join_tokens(&handshake.protocols)));
62 }
63
64 (handshake, builder.body(Body::empty()))
65 }
66}
67
/// Pending client-side WebSocket handshake.
///
/// Created by [`ClientHandshakeBuilder::build`] and finished by passing the
/// server response to [`ClientHandshake::complete`].
pub struct ClientHandshake {
    /// Sub-protocols offered to the server.
    protocols: Vec<String>,
    /// Expected value of the `Sec-WebSocket-Accept` response header field.
    accept: String,
    /// Input buffer capacity handed over to the resulting WebSocket.
    input_buffer_capacity: usize,
}
74
75impl ClientHandshake {
76 #[inline]
78 pub fn builder() -> ClientHandshakeBuilder {
79 ClientHandshakeBuilder::new()
80 }
81
82 pub fn complete(self, response: IncomingResponse) -> Result<WebSocket, Error> {
84 assert_eq!(response.status_code(), 101);
85
86 let is_upgrade = response
87 .get_header_field_value("connection")
88 .map(|v| v.as_ref())
89 .unwrap_or(b"")
90 .split(|&b| b == b',')
91 .map(|kw| kw.trim_ascii())
92 .filter(|kw| !kw.is_empty())
93 .any(|kw| kw.eq_ignore_ascii_case(b"upgrade"));
94
95 if !is_upgrade {
96 return Err(Error::from_static_msg("not a connection upgrade"));
97 }
98
99 let is_websocket = response
100 .get_header_field_value("upgrade")
101 .map(|v| v.as_ref())
102 .unwrap_or(b"")
103 .trim_ascii()
104 .eq_ignore_ascii_case(b"websocket");
105
106 if !is_websocket {
107 return Err(Error::from_static_msg("not a WebSocket upgrade"));
108 }
109
110 let accept = response
111 .get_header_field_value("sec-websocket-accept")
112 .ok_or_else(|| Error::from_static_msg("missing WS accept header"))?
113 .as_ref();
114
115 if accept != self.accept.as_bytes() {
116 return Err(Error::from_static_msg("invalid WS accept header"));
117 }
118
119 let protocol = response
120 .get_header_field_value("sec-websocket-protocol")
121 .map(|p| p.trim_ascii());
122
123 if let Some(protocol) = protocol {
124 let is_valid_protocol = self.protocols.iter().any(|p| p.as_bytes() == protocol);
125
126 if !is_valid_protocol {
127 return Err(Error::from_static_msg("invalid WS sub-protocol"));
128 }
129 }
130
131 let extensions = response
132 .get_header_fields("sec-websocket-extensions")
133 .flat_map(|field| {
134 field
135 .value()
136 .map(|v| v.as_ref())
137 .unwrap_or(b"")
138 .split(|&b| b == b',')
139 .map(|p| p.trim_ascii())
140 .filter(|p| !p.is_empty())
141 })
142 .count();
143
144 if extensions != 0 {
145 return Err(Error::from_static_msg("unknown WS extensions"));
146 }
147
148 let upgraded = response
149 .upgrade()
150 .ok_or_else(|| Error::from_static_msg("unable to upgrade the HTTP connection"))?;
151
152 let res = WebSocket::new(upgraded, AgentRole::Client, self.input_buffer_capacity);
153
154 Ok(res)
155 }
156}
157
/// Join given tokens into a comma-separated list suitable for a list-valued
/// HTTP header field.
///
/// Each token is trimmed. Empty tokens are skipped because a sender must not
/// generate empty list elements (RFC 7230, section 7). An empty string is
/// returned if there are no non-empty tokens.
fn join_tokens(tokens: &[String]) -> String {
    let mut res = String::new();

    let mut tokens = tokens
        .iter()
        .map(|token| token.trim())
        .filter(|token| !token.is_empty());

    if let Some(first) = tokens.next() {
        res.push_str(first);
    }

    for token in tokens {
        // Writing into a `String` cannot fail.
        let _ = write!(res, ",{}", token);
    }

    res
}