1#![deny(
9 missing_docs,
10 unused_must_use,
11 unused_mut,
12 unused_imports,
13 unused_import_braces
14)]
15
16pub use tungstenite;
17
18mod compat;
19mod connect;
20mod handshake;
21
22use std::io::{Read, Write};
23
24use compat::{AllowStd, ContextWaker, cvt};
25use futures_util::{
26 sink::{Sink, SinkExt},
27 stream::{FusedStream, Stream},
28};
29use std::{
30 pin::Pin,
31 task::{Context, Poll},
32};
33use wstd::io::{AsyncRead, AsyncWrite};
34
35#[cfg(feature = "handshake")]
36use tungstenite::{
37 client::IntoClientRequest,
38 handshake::{
39 HandshakeError,
40 client::{ClientHandshake, Response},
41 server::{Callback, NoCallback},
42 },
43};
44use tungstenite::{
45 error::Error as WsError,
46 protocol::{Message, Role, WebSocket, WebSocketConfig},
47};
48
49#[cfg(feature = "connect")]
50pub use connect::{connect_async, connect_async_with_config};
51
52use tungstenite::protocol::CloseFrame;
53
#[cfg(feature = "handshake")]
/// Performs the client side of the WebSocket opening handshake over an already-connected
/// `stream`.
///
/// `request` may be anything implementing [`IntoClientRequest`]. This is
/// [`client_async_with_config`] with `config` set to `None`.
///
/// On success returns the connected [`WebSocketStream`] together with the handshake
/// [`Response`].
///
/// # Errors
///
/// Returns a [`WsError`] if the request cannot be built or the handshake fails.
pub async fn client_async<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    client_async_with_config(request, stream, None).await
}
77
#[cfg(feature = "handshake")]
/// Like [`client_async`], but with an optional [`WebSocketConfig`] for the connection.
///
/// On success returns the connected [`WebSocketStream`] together with the handshake
/// [`Response`].
///
/// # Errors
///
/// When the handshake ends in [`HandshakeError::Failure`] the wrapped [`WsError`] is
/// returned as-is. This includes errors from converting `request` or starting the
/// handshake. Any other [`HandshakeError`] is converted to [`WsError::Io`] carrying that
/// error's message.
pub async fn client_async_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // The closure runs tungstenite's synchronous client handshake on the `AllowStd`
    // adapter; `?` lifts plain `WsError`s into `HandshakeError::Failure`.
    let f = handshake::client_handshake(stream, move |allow_std| {
        let request = request.into_client_request()?;
        let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
        cli_handshake.handshake()
    });
    f.await.map_err(|e| match e {
        HandshakeError::Failure(e) => e,
        e => WsError::Io(std::io::Error::other(e.to_string())),
    })
}
100
#[cfg(feature = "handshake")]
/// Performs the server side of the WebSocket opening handshake on an accepted `stream`.
///
/// This is [`accept_hdr_async`] with [`NoCallback`], i.e. no custom request/response
/// callback is installed.
///
/// # Errors
///
/// Returns a [`WsError`] if the handshake fails.
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async(stream, NoCallback).await
}
119
#[cfg(feature = "handshake")]
/// Like [`accept_async`], but with an optional [`WebSocketConfig`] for the connection.
///
/// This is [`accept_hdr_async_with_config`] with [`NoCallback`].
///
/// # Errors
///
/// Returns a [`WsError`] if the handshake fails.
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, config).await
}
132
#[cfg(feature = "handshake")]
/// Performs the server side of the WebSocket opening handshake, passing `callback`
/// through to tungstenite's [`Callback`] hook.
///
/// This is [`accept_hdr_async_with_config`] with `config` set to `None`.
///
/// # Errors
///
/// Returns a [`WsError`] if the handshake fails.
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    accept_hdr_async_with_config(stream, callback, None).await
}
146
#[cfg(feature = "handshake")]
/// Like [`accept_hdr_async`], but with an optional [`WebSocketConfig`] for the connection.
///
/// # Errors
///
/// When the handshake ends in [`HandshakeError::Failure`] the wrapped [`WsError`] is
/// returned as-is. Any other [`HandshakeError`] is converted to [`WsError::Io`] carrying
/// that error's message.
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    // Runs tungstenite's synchronous server handshake on the `AllowStd` adapter.
    let f = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });
    f.await.map_err(|e| match e {
        HandshakeError::Failure(e) => e,
        e => WsError::Io(std::io::Error::other(e.to_string())),
    })
}
167
168#[derive(Debug)]
178pub struct WebSocketStream<S> {
179 inner: WebSocket<AllowStd<S>>,
180 closing: bool,
181 ended: bool,
182 ready: bool,
187}
188
189impl<S> WebSocketStream<S> {
190 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
193 where
194 S: AsyncRead + AsyncWrite + Unpin,
195 {
196 handshake::without_handshake(stream, move |allow_std| {
197 WebSocket::from_raw_socket(allow_std, role, config)
198 })
199 .await
200 }
201
202 pub async fn from_partially_read(
205 stream: S,
206 part: Vec<u8>,
207 role: Role,
208 config: Option<WebSocketConfig>,
209 ) -> Self
210 where
211 S: AsyncRead + AsyncWrite + Unpin,
212 {
213 handshake::without_handshake(stream, move |allow_std| {
214 WebSocket::from_partially_read(allow_std, part, role, config)
215 })
216 .await
217 }
218
219 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
220 Self {
221 inner: ws,
222 closing: false,
223 ended: false,
224 ready: true,
225 }
226 }
227
228 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
229 where
230 S: Unpin,
231 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
232 AllowStd<S>: Read + Write,
233 {
234 if let Some((kind, ctx)) = ctx {
235 self.inner.get_mut().set_waker(kind, ctx.waker());
236 }
237 f(&mut self.inner)
238 }
239
240 pub fn into_inner(self) -> S {
242 self.inner.into_inner().into_inner()
243 }
244
245 pub fn get_ref(&self) -> &S
247 where
248 S: AsyncRead + AsyncWrite + Unpin,
249 {
250 self.inner.get_ref().get_ref()
251 }
252
253 pub fn get_mut(&mut self) -> &mut S
255 where
256 S: AsyncRead + AsyncWrite + Unpin,
257 {
258 self.inner.get_mut().get_mut()
259 }
260
261 pub fn get_config(&self) -> &WebSocketConfig {
263 self.inner.get_config()
264 }
265
266 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
268 where
269 S: AsyncRead + AsyncWrite + Unpin,
270 {
271 self.send(Message::Close(msg)).await
272 }
273}
274
275impl<T> Stream for WebSocketStream<T>
276where
277 T: AsyncRead + AsyncWrite + Unpin,
278{
279 type Item = Result<Message, WsError>;
280
281 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
282 if self.ended {
286 return Poll::Ready(None);
287 }
288
289 match futures_util::ready!(
290 self.with_context(Some((ContextWaker::Read, cx)), |s| { cvt(s.read()) })
291 ) {
292 Ok(v) => Poll::Ready(Some(Ok(v))),
293 Err(e) => {
294 self.ended = true;
295 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
296 Poll::Ready(None)
297 } else {
298 Poll::Ready(Some(Err(e)))
299 }
300 }
301 }
302 }
303}
304
305impl<T> FusedStream for WebSocketStream<T>
306where
307 T: AsyncRead + AsyncWrite + Unpin,
308{
309 fn is_terminated(&self) -> bool {
310 self.ended
311 }
312}
313
314impl<T> Sink<Message> for WebSocketStream<T>
315where
316 T: AsyncRead + AsyncWrite + Unpin,
317{
318 type Error = WsError;
319
320 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
321 if self.ready {
322 Poll::Ready(Ok(()))
323 } else {
324 (*self)
326 .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
327 .map(|r| {
328 self.ready = true;
329 r
330 })
331 }
332 }
333
334 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
335 match (*self).with_context(None, |s| s.write(item)) {
336 Ok(()) => {
337 self.ready = true;
338 Ok(())
339 }
340 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
341 self.ready = false;
344 Ok(())
345 }
346 Err(e) => {
347 self.ready = true;
348 Err(e)
349 }
350 }
351 }
352
353 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
354 (*self)
355 .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
356 .map(|r| {
357 self.ready = true;
358 match r {
359 Err(WsError::ConnectionClosed) => Ok(()),
361 other => other,
362 }
363 })
364 }
365
366 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
367 self.ready = true;
368 let res = if self.closing {
369 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
371 } else {
372 (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
373 };
374
375 match res {
376 Ok(()) => Poll::Ready(Ok(())),
377 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
378 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
379 self.closing = true;
380 Poll::Pending
381 }
382 Err(err) => Poll::Ready(Err(err)),
383 }
384 }
385}
386
#[cfg(feature = "connect")]
/// Returns the host component of `request`'s URI as an owned `String`.
///
/// NOTE(review): `http`'s `Uri::host` keeps the square brackets around IPv6 literals, and
/// they are not stripped here — presumably callers append `:port` to this value; confirm.
///
/// # Errors
///
/// Returns `WsError::Url(UrlError::NoHostName)` when the URI has no host.
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
    request
        .uri()
        .host()
        .map(str::to_string)
        .ok_or_else(|| WsError::Url(tungstenite::error::UrlError::NoHostName))
}
396
#[cfg(test)]
mod tests {
    use crate::{WebSocketStream, compat::AllowStd};
    use std::io::{Read, Write};

    type Tcp = wstd::net::TcpStream;

    // Each helper only type-checks when `T` implements the named trait, so calling them
    // turns the trait requirements into compile-time assertions.
    fn assert_read<T: Read>() {}
    fn assert_write<T: Write>() {}
    fn assert_unpin<T: Unpin>() {}

    #[test]
    fn web_socket_stream_has_traits() {
        assert_read::<AllowStd<Tcp>>();
        assert_write::<AllowStd<Tcp>>();
        assert_unpin::<WebSocketStream<Tcp>>();
    }
}