websocket_lite/
client.rs

1use std::fmt;
2use std::io::{Read, Write};
3use std::net::{SocketAddr, TcpStream as StdTcpStream};
4use std::result;
5use std::str;
6
7use futures::StreamExt;
8use tokio::{
9    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
10    net::TcpStream as TokioTcpStream,
11};
12use tokio_util::codec::{Decoder, Framed};
13use url::{self, Url};
14use websocket_codec::UpgradeCodec;
15
16use crate::ssl;
17use crate::sync;
18use crate::{AsyncClient, AsyncNetworkStream, Client, MessageCodec, NetworkStream, Result};
19
20fn replace_codec<T, C1, C2>(framed: Framed<T, C1>, codec: C2) -> Framed<T, C2>
21where
22    T: AsyncRead + AsyncWrite,
23{
24    // TODO improve this? https://github.com/tokio-rs/tokio/issues/717
25    let parts1 = framed.into_parts();
26    let mut parts2 = Framed::new(parts1.io, codec).into_parts();
27    parts2.read_buf = parts1.read_buf;
28    parts2.write_buf = parts1.write_buf;
29    Framed::from_parts(parts2)
30}
31
/// Appends formatted text to a `fmt::Write` destination and discards the `fmt::Result`.
///
/// Every call site in this file writes into a `String`, whose `fmt::Write` implementation never
/// reports an error, so ignoring the result is deliberate.
macro_rules! writeok {
    ($dst:expr, $($arg:tt)*) => {
        let _ = ::std::fmt::Write::write_fmt(&mut $dst, format_args!($($arg)*));
    };
}
37
38fn resolve(url: &Url) -> Result<SocketAddr> {
39    url.socket_addrs(|| None)?
40        .into_iter()
41        .next()
42        .ok_or_else(|| "can't resolve host".to_owned().into())
43}
44
45fn make_key(key: Option<[u8; 16]>, key_base64: &mut [u8; 24]) -> &str {
46    let key_bytes = key.unwrap_or_else(rand::random);
47    assert_eq!(
48        24,
49        base64::encode_config_slice(&key_bytes, base64::STANDARD, key_base64)
50    );
51
52    str::from_utf8(key_base64).unwrap()
53}
54
55fn build_request(url: &Url, key: &str, headers: &[(String, String)]) -> String {
56    let mut s = String::new();
57    writeok!(s, "GET {path}", path = url.path());
58    if let Some(query) = url.query() {
59        writeok!(s, "?{query}", query = query);
60    }
61
62    s += " HTTP/1.1\r\n";
63
64    if let Some(host) = url.host() {
65        writeok!(s, "Host: {host}", host = host);
66        if let Some(port) = url.port_or_known_default() {
67            writeok!(s, ":{port}", port = port);
68        }
69
70        s += "\r\n";
71    }
72
73    writeok!(
74        s,
75        "Upgrade: websocket\r\n\
76         Connection: Upgrade\r\n\
77         Sec-WebSocket-Key: {key}\r\n\
78         Sec-WebSocket-Version: 13\r\n",
79        key = key
80    );
81
82    for (name, value) in headers {
83        writeok!(s, "{name}: {value}\r\n", name = name, value = value);
84    }
85
86    writeok!(s, "\r\n");
87    s
88}
89
90/// Establishes a WebSocket connection.
91///
92/// `ws://...` and `wss://...` URLs are supported.
93pub struct ClientBuilder {
94    url: Url,
95    key: Option<[u8; 16]>,
96    headers: Vec<(String, String)>,
97}
98
99impl ClientBuilder {
100    /// Creates a `ClientBuilder` that connects to a given WebSocket URL.
101    ///
102    /// This method returns an `Err` result if URL parsing fails.
103    pub fn new(url: &str) -> result::Result<Self, url::ParseError> {
104        Ok(Self::from_url(Url::parse(url)?))
105    }
106
107    /// Creates a `ClientBuilder` that connects to a given WebSocket URL.
108    ///
109    /// This method never fails as the URL has already been parsed.
110    pub fn from_url(url: Url) -> Self {
111        ClientBuilder {
112            url,
113            key: None,
114            headers: Vec::new(),
115        }
116    }
117
118    /// Adds an extra HTTP header for client
119    ///
120    pub fn add_header(&mut self, name: String, value: String) {
121        self.headers.push((name, value));
122    }
123
124    /// Establishes a connection to the WebSocket server.
125    ///
126    /// `wss://...` URLs are not supported by this method. Use `async_connect` if you need to be able to handle
127    /// both `ws://...` and `wss://...` URLs.
128    pub async fn async_connect_insecure(self) -> Result<AsyncClient<TokioTcpStream>> {
129        let addr = resolve(&self.url)?;
130        let stream = TokioTcpStream::connect(&addr).await?;
131        self.async_connect_on(stream).await
132    }
133
134    /// Establishes a connection to the WebSocket server.
135    ///
136    /// `wss://...` URLs are not supported by this method. Use `connect` if you need to be able to handle
137    /// both `ws://...` and `wss://...` URLs.
138    pub fn connect_insecure(self) -> Result<Client<StdTcpStream>> {
139        let addr = resolve(&self.url)?;
140        let stream = StdTcpStream::connect(&addr)?;
141        self.connect_on(stream)
142    }
143
144    /// Establishes a connection to the WebSocket server.
145    pub async fn async_connect(
146        self,
147    ) -> Result<AsyncClient<Box<dyn AsyncNetworkStream + Sync + Send + Unpin + 'static>>> {
148        let addr = resolve(&self.url)?;
149        let stream = TokioTcpStream::connect(&addr).await?;
150
151        let stream: Box<dyn AsyncNetworkStream + Sync + Send + Unpin + 'static> = if self.url.scheme() == "wss" {
152            let domain = self.url.domain().unwrap_or("").to_owned();
153            let stream = ssl::async_wrap(domain, stream).await?;
154            Box::new(stream)
155        } else {
156            Box::new(stream)
157        };
158
159        self.async_connect_on(stream).await
160    }
161
162    /// Establishes a connection to the WebSocket server.
163    pub fn connect(self) -> Result<Client<Box<dyn NetworkStream + Sync + Send + 'static>>> {
164        let addr = resolve(&self.url)?;
165        let stream = StdTcpStream::connect(&addr)?;
166
167        let stream: Box<dyn NetworkStream + Sync + Send + 'static> = if self.url.scheme() == "wss" {
168            let domain = self.url.domain().unwrap_or("");
169            let stream = ssl::wrap(domain, stream)?;
170            Box::new(stream)
171        } else {
172            Box::new(stream)
173        };
174
175        self.connect_on(stream)
176    }
177
178    /// Takes over an already established stream and uses it to send and receive WebSocket messages.
179    ///
180    /// This method assumes that the TLS connection has already been established, if needed. It sends an HTTP
181    /// `Connection: Upgrade` request and waits for an HTTP OK response before proceeding.
182    pub async fn async_connect_on<S: AsyncRead + AsyncWrite + Unpin>(self, mut stream: S) -> Result<AsyncClient<S>> {
183        let mut key_base64 = [0; 24];
184        let key = make_key(self.key, &mut key_base64);
185        let upgrade_codec = UpgradeCodec::new(key);
186        let request = build_request(&self.url, key, &self.headers);
187        AsyncWriteExt::write_all(&mut stream, request.as_bytes()).await?;
188
189        let (opt, framed) = upgrade_codec.framed(stream).into_future().await;
190        opt.ok_or_else(|| "no HTTP Upgrade response".to_owned())??;
191        Ok(replace_codec(framed, MessageCodec::client()))
192    }
193
194    /// Takes over an already established stream and uses it to send and receive WebSocket messages.
195    ///
196    /// This method assumes that the TLS connection has already been established, if needed. It sends an HTTP
197    /// `Connection: Upgrade` request and waits for an HTTP OK response before proceeding.
198    pub fn connect_on<S: Read + Write>(self, mut stream: S) -> Result<Client<S>> {
199        let mut key_base64 = [0; 24];
200        let key = make_key(self.key, &mut key_base64);
201        let upgrade_codec = UpgradeCodec::new(key);
202        let request = build_request(&self.url, key, &self.headers);
203        Write::write_all(&mut stream, request.as_bytes())?;
204
205        let mut framed = sync::Framed::new(stream, upgrade_codec);
206        framed.receive()?.ok_or_else(|| "no HTTP Upgrade response".to_owned())?;
207        Ok(framed.replace_codec(MessageCodec::client()))
208    }
209
210    // Not pub - used by the tests
211    #[cfg(test)]
212    fn key(mut self, key: &[u8]) -> Self {
213        let mut a = [0; 16];
214        a.copy_from_slice(key);
215        self.key = Some(a);
216        self
217    }
218}
219
#[cfg(test)]
mod tests {
    use std::fmt;
    use std::io::{self, Cursor, Read, Write};
    use std::pin::Pin;
    use std::result;
    use std::str;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use crate::ClientBuilder;

    type Result<T> = result::Result<T, crate::Error>;

    /// Test stream that reads from `R` and writes to `W`, so a canned server response can be fed
    /// in while the client's request is captured.
    pub struct ReadWritePair<R, W>(pub R, pub W);

    impl<R: Read, W> Read for ReadWritePair<R, W> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.0.read(buf)
        }

        fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
            self.0.read_to_end(buf)
        }

        fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
            self.0.read_to_string(buf)
        }

        fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
            self.0.read_exact(buf)
        }
    }

    impl<R, W: Write> Write for ReadWritePair<R, W> {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.1.write(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            self.1.flush()
        }

        fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
            self.1.write_all(buf)
        }

        fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
            self.1.write_fmt(fmt)
        }
    }

    impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for ReadWritePair<R, W> {
        fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
        }
    }

    impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for ReadWritePair<R, W> {
        fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
            Pin::new(&mut self.get_mut().1).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().1).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.get_mut().1).poll_shutdown(cx)
        }
    }

    /// Request expected for `ws://localhost:8000/stream?query` with the RFC 6455 sample key.
    static REQUEST: &str = "GET /stream?query HTTP/1.1\r\n\
                            Host: localhost:8000\r\n\
                            Upgrade: websocket\r\n\
                            Connection: Upgrade\r\n\
                            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
                            Sec-WebSocket-Version: 13\r\n\
                            \r\n";

    /// Server upgrade response whose accept value matches the sample key above.
    static RESPONSE: &str = "HTTP/1.1 101 Switching Protocols\r\n\
                             Upgrade: websocket\r\n\
                             Connection: Upgrade\r\n\
                             sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
                             \r\n";

    #[tokio::test]
    async fn can_async_connect_on() -> Result<()> {
        let mut input = Cursor::new(RESPONSE);
        let mut output = Vec::new();

        ClientBuilder::new("ws://localhost:8000/stream?query")?
            .key(&base64::decode(b"dGhlIHNhbXBsZSBub25jZQ==")?)
            .async_connect_on(ReadWritePair(&mut input, &mut output))
            .await?;

        assert_eq!(REQUEST, str::from_utf8(&output)?);
        Ok(())
    }

    #[test]
    fn can_connect_on() -> Result<()> {
        let mut input = Cursor::new(RESPONSE);
        let mut output = Vec::new();

        ClientBuilder::new("ws://localhost:8000/stream?query")?
            .key(&base64::decode(b"dGhlIHNhbXBsZSBub25jZQ==")?)
            .connect_on(ReadWritePair(&mut input, &mut output))?;

        assert_eq!(REQUEST, str::from_utf8(&output)?);
        Ok(())
    }

    /// Extra headers must be sent after the handshake headers and before the terminating blank line.
    #[test]
    fn can_connect_on_with_extra_header() -> Result<()> {
        let mut input = Cursor::new(RESPONSE);
        let mut output = Vec::new();

        let mut builder = ClientBuilder::new("ws://localhost:8000/stream?query")?;
        builder.add_header("X-Test".to_owned(), "value".to_owned());
        builder
            .key(&base64::decode(b"dGhlIHNhbXBsZSBub25jZQ==")?)
            .connect_on(ReadWritePair(&mut input, &mut output))?;

        let expected = REQUEST.replacen("\r\n\r\n", "\r\nX-Test: value\r\n\r\n", 1);
        assert_eq!(expected, str::from_utf8(&output)?);
        Ok(())
    }
}