1#![deny(
9 missing_docs,
10 unused_must_use,
11 unused_mut,
12 unused_imports,
13 unused_import_braces
14)]
15
16pub use tungstenite;
17
18mod compat;
19#[cfg(feature = "connect")]
20mod connect;
21mod handshake;
22
23use std::io::{Read, Write};
24
25use compat::{AllowStd, ContextWaker, cvt};
26use futures_util::{
27 sink::{Sink, SinkExt},
28 stream::{FusedStream, Stream},
29};
30use std::{
31 pin::Pin,
32 task::{Context, Poll},
33};
34use wstd::io::{AsyncRead, AsyncWrite};
35
36#[cfg(feature = "handshake")]
37use tungstenite::{
38 client::IntoClientRequest,
39 handshake::{
40 HandshakeError,
41 client::{ClientHandshake, Response},
42 server::{Callback, NoCallback},
43 },
44};
45use tungstenite::{
46 error::Error as WsError,
47 protocol::{Message, Role, WebSocket, WebSocketConfig},
48};
49
50#[cfg(feature = "connect")]
51pub use connect::{connect_async, connect_async_with_config};
52
53use tungstenite::protocol::CloseFrame;
54
/// Performs a WebSocket client handshake over an already-established stream.
///
/// This is a convenience wrapper around [`client_async_with_config`] that
/// supplies `None` as the configuration.
///
/// On success the returned tuple holds the ready-to-use [`WebSocketStream`]
/// together with the handshake [`Response`].
///
/// # Errors
///
/// Propagates any error reported by [`client_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn client_async<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // No explicit configuration is requested by this entry point.
    let no_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, no_config).await
}
78
/// Performs a WebSocket client handshake over `stream`, passing `config`
/// through to the tungstenite client handshake.
///
/// `request` is converted with [`IntoClientRequest::into_client_request`]
/// inside the handshake closure, so a conversion failure is reported through
/// the same error path as a handshake failure.
///
/// # Errors
///
/// * A [`HandshakeError::Failure`] is unwrapped and its inner error returned
///   as-is.
/// * Any other handshake error is rendered with `to_string()` and wrapped in
///   [`WsError::Io`].
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |io| {
        let client_request = request.into_client_request()?;
        ClientHandshake::start(io, client_request, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(established) => Ok(established),
        // A definite failure already carries the protocol-level error.
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        // Anything else is flattened into an I/O error carrying its message.
        Err(other) => Err(WsError::Io(std::io::Error::other(other.to_string()))),
    }
}
101
/// Accepts a new WebSocket connection on the provided stream.
///
/// This is a convenience wrapper around [`accept_hdr_async`] that uses
/// [`NoCallback`] as the header callback.
///
/// # Errors
///
/// Propagates any error reported by [`accept_hdr_async`].
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async(stream, callback).await
}
120
/// Accepts a new WebSocket connection on the provided stream, passing
/// `config` through to the server handshake.
///
/// This is a convenience wrapper around [`accept_hdr_async_with_config`] that
/// uses [`NoCallback`] as the header callback.
///
/// # Errors
///
/// Propagates any error reported by [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
133
/// Accepts a new WebSocket connection on the provided stream, handing
/// `callback` to the server handshake.
///
/// This is a convenience wrapper around [`accept_hdr_async_with_config`] that
/// supplies `None` as the configuration.
///
/// # Errors
///
/// Propagates any error reported by [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    // No explicit configuration is requested by this entry point.
    let no_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, no_config).await
}
147
/// Accepts a new WebSocket connection on the provided stream, forwarding both
/// `callback` and `config` to `tungstenite::accept_hdr_with_config`.
///
/// # Errors
///
/// * A [`HandshakeError::Failure`] is unwrapped and its inner error returned
///   as-is.
/// * Any other handshake error is rendered with `to_string()` and wrapped in
///   [`WsError::Io`].
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |io| {
        tungstenite::accept_hdr_with_config(io, callback, config)
    })
    .await;

    match outcome {
        Ok(accepted) => Ok(accepted),
        // A definite failure already carries the protocol-level error.
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        // Anything else is flattened into an I/O error carrying its message.
        Err(other) => Err(WsError::Io(std::io::Error::other(other.to_string()))),
    }
}
168
/// A WebSocket connection layered over an underlying stream `S`.
///
/// The stream is wrapped in the crate's `AllowStd` adapter and driven by a
/// tungstenite [`WebSocket`]. The remaining fields are bookkeeping flags used
/// by this type's `Stream` and `Sink` implementations.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The tungstenite `WebSocket` wrapping the stream through `AllowStd`.
    inner: WebSocket<AllowStd<S>>,
    /// Set once closing hit `WouldBlock` in `poll_close`; later `poll_close`
    /// calls then flush instead of issuing another close.
    closing: bool,
    /// Set once reading returned an error; `poll_next` then yields `None`
    /// and `is_terminated` reports `true`.
    ended: bool,
    /// `false` while a write that hit `WouldBlock` has not yet been flushed;
    /// `poll_ready` flushes before reporting readiness in that case.
    ready: bool,
}
189
190impl<S> WebSocketStream<S> {
191 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
194 where
195 S: AsyncRead + AsyncWrite + Unpin,
196 {
197 handshake::without_handshake(stream, move |allow_std| {
198 WebSocket::from_raw_socket(allow_std, role, config)
199 })
200 .await
201 }
202
203 pub async fn from_partially_read(
206 stream: S,
207 part: Vec<u8>,
208 role: Role,
209 config: Option<WebSocketConfig>,
210 ) -> Self
211 where
212 S: AsyncRead + AsyncWrite + Unpin,
213 {
214 handshake::without_handshake(stream, move |allow_std| {
215 WebSocket::from_partially_read(allow_std, part, role, config)
216 })
217 .await
218 }
219
220 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
221 Self {
222 inner: ws,
223 closing: false,
224 ended: false,
225 ready: true,
226 }
227 }
228
229 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
230 where
231 S: Unpin,
232 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
233 AllowStd<S>: Read + Write,
234 {
235 if let Some((kind, ctx)) = ctx {
236 self.inner.get_mut().set_waker(kind, ctx.waker());
237 }
238 f(&mut self.inner)
239 }
240
241 pub fn into_inner(self) -> S {
243 self.inner.into_inner().into_inner()
244 }
245
246 pub fn get_ref(&self) -> &S
248 where
249 S: AsyncRead + AsyncWrite + Unpin,
250 {
251 self.inner.get_ref().get_ref()
252 }
253
254 pub fn get_mut(&mut self) -> &mut S
256 where
257 S: AsyncRead + AsyncWrite + Unpin,
258 {
259 self.inner.get_mut().get_mut()
260 }
261
262 pub fn get_config(&self) -> &WebSocketConfig {
264 self.inner.get_config()
265 }
266
267 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
269 where
270 S: AsyncRead + AsyncWrite + Unpin,
271 {
272 self.send(Message::Close(msg)).await
273 }
274}
275
276impl<T> Stream for WebSocketStream<T>
277where
278 T: AsyncRead + AsyncWrite + Unpin,
279{
280 type Item = Result<Message, WsError>;
281
282 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
283 if self.ended {
287 return Poll::Ready(None);
288 }
289
290 match futures_util::ready!(
291 self.with_context(Some((ContextWaker::Read, cx)), |s| { cvt(s.read()) })
292 ) {
293 Ok(v) => Poll::Ready(Some(Ok(v))),
294 Err(e) => {
295 self.ended = true;
296 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
297 Poll::Ready(None)
298 } else {
299 Poll::Ready(Some(Err(e)))
300 }
301 }
302 }
303 }
304}
305
306impl<T> FusedStream for WebSocketStream<T>
307where
308 T: AsyncRead + AsyncWrite + Unpin,
309{
310 fn is_terminated(&self) -> bool {
311 self.ended
312 }
313}
314
315impl<T> Sink<Message> for WebSocketStream<T>
316where
317 T: AsyncRead + AsyncWrite + Unpin,
318{
319 type Error = WsError;
320
321 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
322 if self.ready {
323 Poll::Ready(Ok(()))
324 } else {
325 (*self)
327 .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
328 .map(|r| {
329 self.ready = true;
330 r
331 })
332 }
333 }
334
335 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
336 match (*self).with_context(None, |s| s.write(item)) {
337 Ok(()) => {
338 self.ready = true;
339 Ok(())
340 }
341 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
342 self.ready = false;
345 Ok(())
346 }
347 Err(e) => {
348 self.ready = true;
349 Err(e)
350 }
351 }
352 }
353
354 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355 (*self)
356 .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
357 .map(|r| {
358 self.ready = true;
359 match r {
360 Err(WsError::ConnectionClosed) => Ok(()),
362 other => other,
363 }
364 })
365 }
366
367 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
368 self.ready = true;
369 let res = if self.closing {
370 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
372 } else {
373 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
374 };
375
376 match res {
377 Ok(()) => Poll::Ready(Ok(())),
378 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
379 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
380 self.closing = true;
381 Poll::Pending
382 }
383 Err(err) => Poll::Ready(Err(err)),
384 }
385 }
386}
387
/// Extracts the host name from the request URI as an owned `String`.
///
/// # Errors
///
/// Returns `WsError::Url(UrlError::NoHostName)` when the URI has no host.
#[cfg(feature = "connect")]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
    request
        .uri()
        .host()
        .map(|host| host.to_string())
        .ok_or(WsError::Url(tungstenite::error::UrlError::NoHostName))
}
397
#[cfg(test)]
mod tests {
    use crate::{WebSocketStream, compat::AllowStd};
    use std::io::{Read, Write};

    /// Stream type used to instantiate the generic wrappers under test.
    type Tcp = wstd::net::TcpStream;

    // Compile-time probes: each only type-checks if `T` satisfies the bound.
    fn assert_read<T: Read>() {}
    fn assert_write<T: Write>() {}
    fn assert_unpin<T: Unpin>() {}

    /// Checks at compile time that the adapter is `Read + Write` and that the
    /// stream wrapper is `Unpin`.
    #[test]
    fn web_socket_stream_has_traits() {
        assert_read::<AllowStd<Tcp>>();
        assert_write::<AllowStd<Tcp>>();
        assert_unpin::<WebSocketStream<Tcp>>();
    }
}