1use std::fmt;
2use std::io::{Read, Write};
3use std::net::{SocketAddr, TcpStream as StdTcpStream};
4use std::result;
5use std::str;
6
7use futures::StreamExt;
8use tokio::{
9 io::{AsyncRead, AsyncWrite, AsyncWriteExt},
10 net::TcpStream as TokioTcpStream,
11};
12use tokio_util::codec::{Decoder, Framed};
13use url::{self, Url};
14use websocket_codec::UpgradeCodec;
15
16use crate::ssl;
17use crate::sync;
18use crate::{AsyncClient, AsyncNetworkStream, Client, MessageCodec, NetworkStream, Result};
19
20fn replace_codec<T, C1, C2>(framed: Framed<T, C1>, codec: C2) -> Framed<T, C2>
21where
22 T: AsyncRead + AsyncWrite,
23{
24 let parts1 = framed.into_parts();
26 let mut parts2 = Framed::new(parts1.io, codec).into_parts();
27 parts2.read_buf = parts1.read_buf;
28 parts2.write_buf = parts1.write_buf;
29 Framed::from_parts(parts2)
30}
31
/// Appends formatted text to a `std::fmt::Write` sink, discarding the
/// `fmt::Result`.
///
/// Invoked like `write!`: `writeok!(dst, "GET {path}", path = p)`. In this
/// file the destination is always a `String`, whose `fmt::Write` impl does not
/// report errors, so ignoring the result loses nothing.
///
/// The trait is named through an absolute path so the macro works regardless
/// of whether `fmt` happens to be imported at the call site.
macro_rules! writeok {
    ($dst:expr, $($arg:tt)*) => {
        let _ = ::std::fmt::Write::write_fmt(&mut $dst, format_args!($($arg)*));
    };
}
37
38fn resolve(url: &Url) -> Result<SocketAddr> {
39 url.socket_addrs(|| None)?
40 .into_iter()
41 .next()
42 .ok_or_else(|| "can't resolve host".to_owned().into())
43}
44
45fn make_key(key: Option<[u8; 16]>, key_base64: &mut [u8; 24]) -> &str {
46 let key_bytes = key.unwrap_or_else(rand::random);
47 assert_eq!(
48 24,
49 base64::encode_config_slice(&key_bytes, base64::STANDARD, key_base64)
50 );
51
52 str::from_utf8(key_base64).unwrap()
53}
54
55fn build_request(url: &Url, key: &str, headers: &[(String, String)]) -> String {
56 let mut s = String::new();
57 writeok!(s, "GET {path}", path = url.path());
58 if let Some(query) = url.query() {
59 writeok!(s, "?{query}", query = query);
60 }
61
62 s += " HTTP/1.1\r\n";
63
64 if let Some(host) = url.host() {
65 writeok!(s, "Host: {host}", host = host);
66 if let Some(port) = url.port_or_known_default() {
67 writeok!(s, ":{port}", port = port);
68 }
69
70 s += "\r\n";
71 }
72
73 writeok!(
74 s,
75 "Upgrade: websocket\r\n\
76 Connection: Upgrade\r\n\
77 Sec-WebSocket-Key: {key}\r\n\
78 Sec-WebSocket-Version: 13\r\n",
79 key = key
80 );
81
82 for (name, value) in headers {
83 writeok!(s, "{name}: {value}\r\n", name = name, value = value);
84 }
85
86 writeok!(s, "\r\n");
87 s
88}
89
90pub struct ClientBuilder {
94 url: Url,
95 key: Option<[u8; 16]>,
96 headers: Vec<(String, String)>,
97}
98
99impl ClientBuilder {
100 pub fn new(url: &str) -> result::Result<Self, url::ParseError> {
104 Ok(Self::from_url(Url::parse(url)?))
105 }
106
107 pub fn from_url(url: Url) -> Self {
111 ClientBuilder {
112 url,
113 key: None,
114 headers: Vec::new(),
115 }
116 }
117
118 pub fn add_header(&mut self, name: String, value: String) {
121 self.headers.push((name, value));
122 }
123
124 pub async fn async_connect_insecure(self) -> Result<AsyncClient<TokioTcpStream>> {
129 let addr = resolve(&self.url)?;
130 let stream = TokioTcpStream::connect(&addr).await?;
131 self.async_connect_on(stream).await
132 }
133
134 pub fn connect_insecure(self) -> Result<Client<StdTcpStream>> {
139 let addr = resolve(&self.url)?;
140 let stream = StdTcpStream::connect(&addr)?;
141 self.connect_on(stream)
142 }
143
144 pub async fn async_connect(
146 self,
147 ) -> Result<AsyncClient<Box<dyn AsyncNetworkStream + Sync + Send + Unpin + 'static>>> {
148 let addr = resolve(&self.url)?;
149 let stream = TokioTcpStream::connect(&addr).await?;
150
151 let stream: Box<dyn AsyncNetworkStream + Sync + Send + Unpin + 'static> = if self.url.scheme() == "wss" {
152 let domain = self.url.domain().unwrap_or("").to_owned();
153 let stream = ssl::async_wrap(domain, stream).await?;
154 Box::new(stream)
155 } else {
156 Box::new(stream)
157 };
158
159 self.async_connect_on(stream).await
160 }
161
162 pub fn connect(self) -> Result<Client<Box<dyn NetworkStream + Sync + Send + 'static>>> {
164 let addr = resolve(&self.url)?;
165 let stream = StdTcpStream::connect(&addr)?;
166
167 let stream: Box<dyn NetworkStream + Sync + Send + 'static> = if self.url.scheme() == "wss" {
168 let domain = self.url.domain().unwrap_or("");
169 let stream = ssl::wrap(domain, stream)?;
170 Box::new(stream)
171 } else {
172 Box::new(stream)
173 };
174
175 self.connect_on(stream)
176 }
177
178 pub async fn async_connect_on<S: AsyncRead + AsyncWrite + Unpin>(self, mut stream: S) -> Result<AsyncClient<S>> {
183 let mut key_base64 = [0; 24];
184 let key = make_key(self.key, &mut key_base64);
185 let upgrade_codec = UpgradeCodec::new(key);
186 let request = build_request(&self.url, key, &self.headers);
187 AsyncWriteExt::write_all(&mut stream, request.as_bytes()).await?;
188
189 let (opt, framed) = upgrade_codec.framed(stream).into_future().await;
190 opt.ok_or_else(|| "no HTTP Upgrade response".to_owned())??;
191 Ok(replace_codec(framed, MessageCodec::client()))
192 }
193
194 pub fn connect_on<S: Read + Write>(self, mut stream: S) -> Result<Client<S>> {
199 let mut key_base64 = [0; 24];
200 let key = make_key(self.key, &mut key_base64);
201 let upgrade_codec = UpgradeCodec::new(key);
202 let request = build_request(&self.url, key, &self.headers);
203 Write::write_all(&mut stream, request.as_bytes())?;
204
205 let mut framed = sync::Framed::new(stream, upgrade_codec);
206 framed.receive()?.ok_or_else(|| "no HTTP Upgrade response".to_owned())?;
207 Ok(framed.replace_codec(MessageCodec::client()))
208 }
209
210 #[cfg(test)]
212 fn key(mut self, key: &[u8]) -> Self {
213 let mut a = [0; 16];
214 a.copy_from_slice(key);
215 self.key = Some(a);
216 self
217 }
218}
219
#[cfg(test)]
mod tests {
    use std::fmt;
    use std::io::{self, Cursor, Read, Write};
    use std::pin::Pin;
    use std::result;
    use std::str;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use crate::ClientBuilder;

    type Result<T> = result::Result<T, crate::Error>;

    /// Test stream that reads from `R` and writes to `W`, letting a canned
    /// server response be fed in while the client's request is captured.
    pub struct ReadWritePair<R, W>(pub R, pub W);

    impl<R: Read, W> Read for ReadWritePair<R, W> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.0.read(buf)
        }

        fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
            self.0.read_to_end(buf)
        }

        fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
            self.0.read_to_string(buf)
        }

        fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
            self.0.read_exact(buf)
        }
    }

    impl<R, W: Write> Write for ReadWritePair<R, W> {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.1.write(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            self.1.flush()
        }

        fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
            self.1.write_all(buf)
        }

        fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
            self.1.write_fmt(fmt)
        }
    }

    impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for ReadWritePair<R, W> {
        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
        }
    }

    impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for ReadWritePair<R, W> {
        fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.get_mut().1).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().1).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().1).poll_shutdown(cx)
        }
    }

    /// Expected upgrade request for `ws://localhost:8000/stream?query` with the
    /// sample key from RFC 6455.
    static REQUEST: &str = "GET /stream?query HTTP/1.1\r\n\
                            Host: localhost:8000\r\n\
                            Upgrade: websocket\r\n\
                            Connection: Upgrade\r\n\
                            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
                            Sec-WebSocket-Version: 13\r\n\
                            \r\n";

    /// Server response whose accept value matches the sample key above.
    static RESPONSE: &str = "HTTP/1.1 101 Switching Protocols\r\n\
                             Upgrade: websocket\r\n\
                             Connection: Upgrade\r\n\
                             sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
                             \r\n";

    #[tokio::test]
    async fn can_async_connect_on() -> Result<()> {
        let mut input = Cursor::new(RESPONSE);
        let mut output = Vec::new();

        ClientBuilder::new("ws://localhost:8000/stream?query")?
            .key(&base64::decode(b"dGhlIHNhbXBsZSBub25jZQ==")?)
            .async_connect_on(ReadWritePair(&mut input, &mut output))
            .await?;

        assert_eq!(REQUEST, str::from_utf8(&output)?);
        Ok(())
    }

    #[test]
    fn can_connect_on() -> Result<()> {
        let mut input = Cursor::new(RESPONSE);
        let mut output = Vec::new();

        ClientBuilder::new("ws://localhost:8000/stream?query")?
            .key(&base64::decode(b"dGhlIHNhbXBsZSBub25jZQ==")?)
            .connect_on(ReadWritePair(&mut input, &mut output))?;

        assert_eq!(REQUEST, str::from_utf8(&output)?);
        Ok(())
    }

    /// Headers added with `add_header` must appear after the standard
    /// handshake headers and before the terminating blank line.
    #[test]
    fn sends_custom_headers() -> Result<()> {
        let mut input = Cursor::new(RESPONSE);
        let mut output = Vec::new();

        let mut builder = ClientBuilder::new("ws://localhost:8000/stream?query")?;
        builder.add_header("Origin".to_owned(), "http://localhost".to_owned());
        builder
            .key(&base64::decode(b"dGhlIHNhbXBsZSBub25jZQ==")?)
            .connect_on(ReadWritePair(&mut input, &mut output))?;

        let expected = REQUEST.replace("\r\n\r\n", "\r\nOrigin: http://localhost\r\n\r\n");
        assert_eq!(expected, str::from_utf8(&output)?);
        Ok(())
    }
}