1#![deny(
38 missing_docs,
39 unused_must_use,
40 unused_mut,
41 unused_imports,
42 unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51 feature = "async-tls",
52 feature = "async-native-tls",
53 feature = "tokio-native-tls",
54 feature = "tokio-rustls-manual-roots",
55 feature = "tokio-rustls-native-certs",
56 feature = "tokio-rustls-webpki-roots",
57 feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::io::{Read, Write};
62
63use compat::{cvt, AllowStd, ContextWaker};
64use futures_io::{AsyncRead, AsyncWrite};
65use futures_util::{
66 sink::{Sink, SinkExt},
67 stream::{FusedStream, Stream},
68};
69use log::*;
70use std::pin::Pin;
71use std::task::{Context, Poll};
72
73#[cfg(feature = "handshake")]
74use tungstenite::{
75 client::IntoClientRequest,
76 handshake::{
77 client::{ClientHandshake, Response},
78 server::{Callback, NoCallback},
79 HandshakeError,
80 },
81};
82use tungstenite::{
83 error::Error as WsError,
84 protocol::{Message, Role, WebSocket, WebSocketConfig},
85};
86
87#[cfg(feature = "async-std-runtime")]
88pub mod async_std;
89#[cfg(feature = "async-tls")]
90pub mod async_tls;
91#[cfg(feature = "gio-runtime")]
92pub mod gio;
93#[cfg(feature = "tokio-runtime")]
94pub mod tokio;
95
96use tungstenite::protocol::CloseFrame;
97
#[cfg(feature = "handshake")]
/// Performs the client side of the WebSocket handshake over `stream` using the default
/// configuration.
///
/// `request` may be anything implementing [`IntoClientRequest`], for example a `ws://` URL
/// string. `stream` is used as-is: this function does not open a connection by itself.
///
/// This is shorthand for [`client_async_with_config`] with `config` set to `None`.
///
/// # Errors
///
/// See [`client_async_with_config`].
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, default_config).await
}
121
#[cfg(feature = "handshake")]
/// Performs the client side of the WebSocket handshake over `stream`, with an optional
/// [`WebSocketConfig`].
///
/// Inside the handshake, `request` is converted with
/// [`IntoClientRequest::into_client_request`] and a tungstenite [`ClientHandshake`] is
/// started on the wrapped stream. On success the connected [`WebSocketStream`] is returned
/// together with the server's handshake [`Response`].
///
/// # Errors
///
/// A [`HandshakeError::Failure`] is unwrapped and its inner [`WsError`] is returned unchanged.
/// Any other handshake error is turned into a [`WsError::Io`] of kind
/// [`std::io::ErrorKind::Other`] that carries the error's text.
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let pending = handshake::client_handshake(stream, move |wrapped| {
        let client_request = request.into_client_request()?;
        ClientHandshake::start(wrapped, client_request, config)?.handshake()
    });

    match pending.await {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
147
#[cfg(feature = "handshake")]
/// Accepts a WebSocket connection on `stream`. It performs the server side of the handshake
/// with the default configuration and no header callback.
///
/// Shorthand for [`accept_hdr_async`] with [`NoCallback`].
///
/// # Errors
///
/// See [`accept_hdr_async_with_config`].
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_callback = NoCallback;
    accept_hdr_async(stream, no_callback).await
}
166
#[cfg(feature = "handshake")]
/// Accepts a WebSocket connection on `stream` using an optional [`WebSocketConfig`] and no
/// header callback.
///
/// Shorthand for [`accept_hdr_async_with_config`] with [`NoCallback`].
///
/// # Errors
///
/// See [`accept_hdr_async_with_config`].
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config::<S, NoCallback>(stream, NoCallback, config).await
}
179
#[cfg(feature = "handshake")]
/// Accepts a WebSocket connection on `stream` with the default configuration, forwarding
/// `callback` to the server handshake.
///
/// Shorthand for [`accept_hdr_async_with_config`] with `config` set to `None`.
///
/// # Errors
///
/// See [`accept_hdr_async_with_config`].
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    accept_hdr_async_with_config::<S, C>(stream, callback, None).await
}
193
#[cfg(feature = "handshake")]
/// Performs the server side of the WebSocket handshake over `stream`.
///
/// `callback` and `config` are forwarded to [`tungstenite::accept_hdr_with_config`]. See
/// tungstenite's [`Callback`] trait for when the callback runs and what it may change.
///
/// # Errors
///
/// A [`HandshakeError::Failure`] is unwrapped and its inner [`WsError`] is returned unchanged.
/// Any other handshake error is turned into a [`WsError::Io`] of kind
/// [`std::io::ErrorKind::Other`] that carries the error's text.
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let pending = handshake::server_handshake(stream, move |wrapped| {
        tungstenite::accept_hdr_with_config(wrapped, callback, config)
    });

    match pending.await {
        Ok(ws) => Ok(ws),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
217
#[derive(Debug)]
/// A WebSocket connection over the stream `S`, driven through tungstenite.
///
/// Incoming messages are read through this type's [`Stream`] implementation. Outgoing
/// messages are written through its [`Sink`] implementation.
pub struct WebSocketStream<S> {
    /// The tungstenite protocol state. It operates on `S` wrapped in the crate's `AllowStd`
    /// adapter.
    inner: WebSocket<AllowStd<S>>,
    /// Set by `poll_close` once a close attempt returns `WouldBlock`. From then on,
    /// `poll_close` only flushes instead of starting the close again.
    closing: bool,
    /// Set once a read has returned an error, including the close-related ones. After that
    /// the `Stream` yields only `None`.
    ended: bool,
    /// Whether the sink may accept another message without flushing first.
    ///
    /// It starts as `true` and is cleared when `start_send` hits `WouldBlock`. It is set
    /// again when a flush resolves, when `poll_close` runs, or on any other `start_send`
    /// outcome.
    ready: bool,
}
238
239impl<S> WebSocketStream<S> {
240 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
243 where
244 S: AsyncRead + AsyncWrite + Unpin,
245 {
246 handshake::without_handshake(stream, move |allow_std| {
247 WebSocket::from_raw_socket(allow_std, role, config)
248 })
249 .await
250 }
251
252 pub async fn from_partially_read(
255 stream: S,
256 part: Vec<u8>,
257 role: Role,
258 config: Option<WebSocketConfig>,
259 ) -> Self
260 where
261 S: AsyncRead + AsyncWrite + Unpin,
262 {
263 handshake::without_handshake(stream, move |allow_std| {
264 WebSocket::from_partially_read(allow_std, part, role, config)
265 })
266 .await
267 }
268
269 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
270 Self {
271 inner: ws,
272 closing: false,
273 ended: false,
274 ready: true,
275 }
276 }
277
278 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
279 where
280 S: Unpin,
281 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
282 AllowStd<S>: Read + Write,
283 {
284 #[cfg(feature = "verbose-logging")]
285 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
286 if let Some((kind, ctx)) = ctx {
287 self.inner.get_mut().set_waker(kind, ctx.waker());
288 }
289 f(&mut self.inner)
290 }
291
292 pub fn get_ref(&self) -> &S
294 where
295 S: AsyncRead + AsyncWrite + Unpin,
296 {
297 self.inner.get_ref().get_ref()
298 }
299
300 pub fn get_mut(&mut self) -> &mut S
302 where
303 S: AsyncRead + AsyncWrite + Unpin,
304 {
305 self.inner.get_mut().get_mut()
306 }
307
308 pub fn get_config(&self) -> &WebSocketConfig {
310 self.inner.get_config()
311 }
312
313 pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError>
315 where
316 S: AsyncRead + AsyncWrite + Unpin,
317 {
318 let msg = msg.map(|msg| msg.into_owned());
319 self.send(Message::Close(msg)).await
320 }
321}
322
323impl<T> Stream for WebSocketStream<T>
324where
325 T: AsyncRead + AsyncWrite + Unpin,
326{
327 type Item = Result<Message, WsError>;
328
329 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330 #[cfg(feature = "verbose-logging")]
331 trace!("{}:{} Stream.poll_next", file!(), line!());
332
333 if self.ended {
337 return Poll::Ready(None);
338 }
339
340 match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
341 #[cfg(feature = "verbose-logging")]
342 trace!(
343 "{}:{} Stream.with_context poll_next -> read()",
344 file!(),
345 line!()
346 );
347 cvt(s.read())
348 })) {
349 Ok(v) => Poll::Ready(Some(Ok(v))),
350 Err(e) => {
351 self.ended = true;
352 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
353 Poll::Ready(None)
354 } else {
355 Poll::Ready(Some(Err(e)))
356 }
357 }
358 }
359 }
360}
361
362impl<T> FusedStream for WebSocketStream<T>
363where
364 T: AsyncRead + AsyncWrite + Unpin,
365{
366 fn is_terminated(&self) -> bool {
367 self.ended
368 }
369}
370
371impl<T> Sink<Message> for WebSocketStream<T>
372where
373 T: AsyncRead + AsyncWrite + Unpin,
374{
375 type Error = WsError;
376
377 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
378 if self.ready {
379 Poll::Ready(Ok(()))
380 } else {
381 (*self)
383 .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
384 .map(|r| {
385 self.ready = true;
386 r
387 })
388 }
389 }
390
391 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
392 match (*self).with_context(None, |s| s.write(item)) {
393 Ok(()) => {
394 self.ready = true;
395 Ok(())
396 }
397 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
398 self.ready = false;
401 Ok(())
402 }
403 Err(e) => {
404 self.ready = true;
405 debug!("websocket start_send error: {}", e);
406 Err(e)
407 }
408 }
409 }
410
411 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
412 (*self)
413 .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
414 .map(|r| {
415 self.ready = true;
416 match r {
417 Err(WsError::ConnectionClosed) => Ok(()),
419 other => other,
420 }
421 })
422 }
423
424 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
425 self.ready = true;
426 let res = if self.closing {
427 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
429 } else {
430 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
431 };
432
433 match res {
434 Ok(()) => Poll::Ready(Ok(())),
435 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
436 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
437 trace!("WouldBlock");
438 self.closing = true;
439 Poll::Pending
440 }
441 Err(err) => {
442 debug!("websocket close error: {}", err);
443 Poll::Ready(Err(err))
444 }
445 }
446 }
447}
448
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
/// Extracts the host name from `request`'s URI as an owned `String`.
///
/// URIs write IPv6 literals inside square brackets (`[::1]`). The brackets are not part of
/// the address, so they are removed.
///
/// # Errors
///
/// Returns `Error::Url(UrlError::NoHostName)` when the URI has no host.
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Strip the brackets only when both are present. Unlike slicing with
    // `&host[1..host.len() - 1]`, this cannot panic on a degenerate host such as a lone `[`.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
481
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
/// Returns the port to connect to for `request`.
///
/// This is the URI's explicit port if it has one. Otherwise it is 443 for `wss` and 80 for
/// `ws`.
///
/// # Errors
///
/// Returns `Error::Url(UrlError::UnsupportedUrlScheme)` when there is no explicit port and the
/// scheme is neither `ws` nor `wss`.
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }
    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
504
#[cfg(test)]
mod tests {
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        // `domain` must return the bare IPv6 address, without the URI's brackets.
        let request = "ws://[::1]:80".into_client_request().unwrap();
        let host = crate::domain(&request).unwrap();
        assert_eq!(host, "::1");
    }

    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        // Malformed bracketed hosts must be rejected while the request is being built.
        for &uri in &["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(uri.into_client_request().is_err());
        }
    }
}