1#![deny(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3pub mod extensions;
8#[cfg(feature = "fastwebsockets")]
9#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
10mod fastwebsockets;
11mod packet;
12mod sink_unfold;
13mod stream;
14pub mod ws;
15
16pub use crate::{packet::*, stream::*};
17
18use bytes::{Bytes, BytesMut};
19use dashmap::DashMap;
20use event_listener::Event;
21use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
22use flume as mpsc;
23use futures::{channel::oneshot, select, Future, FutureExt};
24use futures_timer::Delay;
25use std::{
26 sync::{
27 atomic::{AtomicBool, AtomicU32, Ordering},
28 Arc,
29 },
30 time::Duration,
31};
32use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload};
33
34pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
36
/// Which side of a Wisp connection a multiplexor or stream is acting as.
///
/// `PartialEq` is paired with `Eq` and `Hash` so roles can be used as keys in
/// hashed collections.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum Role {
    /// The side that requests new streams.
    Client,
    /// The side that accepts stream requests.
    Server,
}
45
/// Errors produced by the Wisp multiplexor and its streams.
#[derive(Debug)]
pub enum WispError {
    /// A packet was too small to be parsed.
    PacketTooSmall,
    /// A packet carried an unknown or unexpected packet type.
    InvalidPacketType,
    /// A stream id was invalid for the requested operation.
    InvalidStreamId,
    /// A close packet carried an unknown close reason.
    InvalidCloseReason,
    /// A URI could not be parsed.
    InvalidUri,
    /// A URI did not contain a host.
    UriHasNoHost,
    /// A URI did not contain a port.
    UriHasNoPort,
    /// No more stream ids can be allocated.
    MaxStreamCountReached,
    /// The peer speaks an incompatible Wisp protocol version.
    IncompatibleProtocolVersion,
    /// The stream has already been closed.
    StreamAlreadyClosed,
    /// A websocket frame had an invalid type.
    WsFrameInvalidType,
    /// A websocket frame was not finished.
    WsFrameNotFinished,
    /// Error reported by the websocket implementation.
    WsImplError(Box<dyn std::error::Error + Sync + Send>),
    /// The websocket implementation reported that the socket is closed.
    WsImplSocketClosed,
    /// The websocket implementation does not support a required feature.
    WsImplNotSupported,
    /// Error reported by a protocol extension implementation.
    ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
    /// A protocol extension implementation does not support a required feature.
    ExtensionImplNotSupported,
    /// The listed protocol extension ids were required but not negotiated.
    ExtensionsNotSupported(Vec<u8>),
    /// Invalid UTF-8 was encountered.
    Utf8Error(std::str::Utf8Error),
    /// An integer conversion failed.
    TryFromIntError(std::num::TryFromIntError),
    /// Any other error.
    Other(Box<dyn std::error::Error + Sync + Send>),
    /// A message could not be sent to the multiplexor task.
    MuxMessageFailedToSend,
    /// A reply could not be received from the multiplexor task.
    MuxMessageFailedToRecv,
    /// The multiplexor task has already ended.
    MuxTaskEnded,
}

impl From<std::str::Utf8Error> for WispError {
    fn from(source: std::str::Utf8Error) -> Self {
        WispError::Utf8Error(source)
    }
}

impl From<std::num::TryFromIntError> for WispError {
    fn from(source: std::num::TryFromIntError) -> Self {
        WispError::TryFromIntError(source)
    }
}

impl std::fmt::Display for WispError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        // Variants that embed another value are formatted directly; every other
        // variant maps to a fixed message written in one go below.
        let message: &str = match self {
            Self::WsImplError(err) => {
                return write!(f, "Websocket implementation error: {}", err)
            }
            Self::ExtensionImplError(err) => {
                return write!(f, "Protocol extension implementation error: {}", err)
            }
            Self::ExtensionsNotSupported(list) => {
                return write!(f, "Protocol extensions {:?} not supported", list)
            }
            Self::Utf8Error(err) => return write!(f, "UTF-8 error: {}", err),
            Self::TryFromIntError(err) => return write!(f, "Integer conversion error: {}", err),
            Self::Other(err) => return write!(f, "Other error: {}", err),
            Self::PacketTooSmall => "Packet too small",
            Self::InvalidPacketType => "Invalid packet type",
            Self::InvalidStreamId => "Invalid stream id",
            Self::InvalidCloseReason => "Invalid close reason",
            Self::InvalidUri => "Invalid URI",
            Self::UriHasNoHost => "URI has no host",
            Self::UriHasNoPort => "URI has no port",
            Self::MaxStreamCountReached => "Maximum stream count reached",
            Self::IncompatibleProtocolVersion => "Incompatible Wisp protocol version",
            Self::StreamAlreadyClosed => "Stream already closed",
            Self::WsFrameInvalidType => "Invalid websocket frame type",
            Self::WsFrameNotFinished => "Unfinished websocket frame",
            Self::WsImplSocketClosed => "Websocket implementation error: websocket closed",
            Self::WsImplNotSupported => "Websocket implementation error: unsupported feature",
            Self::ExtensionImplNotSupported => {
                "Protocol extension implementation error: unsupported feature"
            }
            Self::MuxMessageFailedToSend => "Failed to send multiplexor message",
            Self::MuxMessageFailedToRecv => "Failed to receive multiplexor message",
            Self::MuxTaskEnded => "Multiplexor task ended",
        };
        f.write_str(message)
    }
}

impl std::error::Error for WispError {}
156
157struct MuxMapValue {
158 stream: mpsc::Sender<Bytes>,
159 stream_type: StreamType,
160
161 flow_control: Arc<AtomicU32>,
162 flow_control_event: Arc<Event>,
163
164 is_closed: Arc<AtomicBool>,
165 close_reason: Arc<AtomicCloseReason>,
166 is_closed_event: Arc<Event>,
167}
168
169struct MuxInner {
170 tx: ws::LockedWebSocketWrite,
171 stream_map: DashMap<u32, MuxMapValue>,
172 buffer_size: u32,
173 fut_exited: Arc<AtomicBool>,
174}
175
176impl MuxInner {
177 pub async fn server_into_future<R>(
178 self,
179 rx: R,
180 extensions: Vec<AnyProtocolExtension>,
181 close_rx: mpsc::Receiver<WsEvent>,
182 muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
183 close_tx: mpsc::Sender<WsEvent>,
184 ) -> Result<(), WispError>
185 where
186 R: ws::WebSocketRead + Send,
187 {
188 self.as_future(
189 close_rx,
190 close_tx.clone(),
191 self.server_loop(rx, extensions, muxstream_sender, close_tx),
192 )
193 .await
194 }
195
196 pub async fn client_into_future<R>(
197 self,
198 rx: R,
199 extensions: Vec<AnyProtocolExtension>,
200 close_rx: mpsc::Receiver<WsEvent>,
201 close_tx: mpsc::Sender<WsEvent>,
202 ) -> Result<(), WispError>
203 where
204 R: ws::WebSocketRead + Send,
205 {
206 self.as_future(close_rx, close_tx, self.client_loop(rx, extensions))
207 .await
208 }
209
210 async fn as_future(
211 &self,
212 close_rx: mpsc::Receiver<WsEvent>,
213 close_tx: mpsc::Sender<WsEvent>,
214 wisp_fut: impl Future<Output = Result<(), WispError>>,
215 ) -> Result<(), WispError> {
216 let ret = futures::select! {
217 _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
218 x = wisp_fut.fuse() => x,
219 };
220 self.fut_exited.store(true, Ordering::Release);
221 for x in self.stream_map.iter_mut() {
222 x.is_closed.store(true, Ordering::Release);
223 x.is_closed_event.notify(usize::MAX);
224 }
225 self.stream_map.clear();
226 let _ = self.tx.close().await;
227 ret
228 }
229
230 async fn create_new_stream(
231 &self,
232 stream_id: u32,
233 stream_type: StreamType,
234 role: Role,
235 stream_tx: mpsc::Sender<WsEvent>,
236 tx: LockedWebSocketWrite,
237 target_buffer_size: u32,
238 ) -> Result<(MuxMapValue, MuxStream), WispError> {
239 let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
240
241 let flow_control_event: Arc<Event> = Event::new().into();
242 let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
243
244 let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
245 let close_reason: Arc<AtomicCloseReason> =
246 AtomicCloseReason::new(CloseReason::Unknown).into();
247 let is_closed_event: Arc<Event> = Event::new().into();
248
249 Ok((
250 MuxMapValue {
251 stream: ch_tx,
252 stream_type,
253
254 flow_control: flow_control.clone(),
255 flow_control_event: flow_control_event.clone(),
256
257 is_closed: is_closed.clone(),
258 close_reason: close_reason.clone(),
259 is_closed_event: is_closed_event.clone(),
260 },
261 MuxStream::new(
262 stream_id,
263 role,
264 stream_type,
265 ch_rx,
266 stream_tx,
267 tx,
268 is_closed,
269 is_closed_event,
270 close_reason,
271 flow_control,
272 flow_control_event,
273 target_buffer_size,
274 ),
275 ))
276 }
277
278 async fn stream_loop(
279 &self,
280 stream_rx: mpsc::Receiver<WsEvent>,
281 stream_tx: mpsc::Sender<WsEvent>,
282 ) {
283 let mut next_free_stream_id: u32 = 1;
284 while let Ok(msg) = stream_rx.recv_async().await {
285 match msg {
286 WsEvent::CreateStream(stream_type, host, port, channel) => {
287 let ret: Result<MuxStream, WispError> = async {
288 let stream_id = next_free_stream_id;
289 let next_stream_id = next_free_stream_id
290 .checked_add(1)
291 .ok_or(WispError::MaxStreamCountReached)?;
292
293 let (map_value, stream) = self
294 .create_new_stream(
295 stream_id,
296 stream_type,
297 Role::Client,
298 stream_tx.clone(),
299 self.tx.clone(),
300 0,
301 )
302 .await?;
303
304 self.tx
305 .write_frame(
306 Packet::new_connect(stream_id, stream_type, port, host).into(),
307 )
308 .await?;
309
310 self.stream_map.insert(stream_id, map_value);
311
312 next_free_stream_id = next_stream_id;
313
314 Ok(stream)
315 }
316 .await;
317 let _ = channel.send(ret);
318 }
319 WsEvent::Close(packet, channel) => {
320 if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
321 if let PacketType::Close(close) = packet.packet_type {
322 self.close_stream(packet.stream_id, close);
323 }
324 let _ = channel.send(self.tx.write_frame(packet.into()).await);
325 drop(stream.stream)
326 } else {
327 let _ = channel.send(Err(WispError::InvalidStreamId));
328 }
329 }
330 WsEvent::EndFut(x) => {
331 if let Some(reason) = x {
332 let _ = self
333 .tx
334 .write_frame(Packet::new_close(0, reason).into())
335 .await;
336 }
337 break;
338 }
339 }
340 }
341 }
342
343 fn close_stream(&self, stream_id: u32, close_packet: ClosePacket) {
344 if let Some((_, stream)) = self.stream_map.remove(&stream_id) {
345 stream
346 .close_reason
347 .store(close_packet.reason, Ordering::Release);
348 stream.is_closed.store(true, Ordering::Release);
349 stream.is_closed_event.notify(usize::MAX);
350 stream.flow_control.store(u32::MAX, Ordering::Release);
351 stream.flow_control_event.notify(usize::MAX);
352 drop(stream.stream)
353 }
354 }
355
356 async fn server_loop<R>(
357 &self,
358 mut rx: R,
359 mut extensions: Vec<AnyProtocolExtension>,
360 muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
361 stream_tx: mpsc::Sender<WsEvent>,
362 ) -> Result<(), WispError>
363 where
364 R: ws::WebSocketRead + Send,
365 {
366 let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
368
369 loop {
370 let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
371 if frame.opcode == ws::OpCode::Close {
372 break Ok(());
373 }
374
375 if let Some(ref extra_frame) = optional_frame {
376 if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
377 let mut payload = BytesMut::from(frame.payload);
378 payload.extend_from_slice(&extra_frame.payload);
379 frame.payload = Payload::Bytes(payload);
380 }
381 }
382
383 if let Some(packet) =
384 Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
385 {
386 use PacketType::*;
387 match packet.packet_type {
388 Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
389 Connect(inner_packet) => {
390 let (map_value, stream) = self
391 .create_new_stream(
392 packet.stream_id,
393 inner_packet.stream_type,
394 Role::Server,
395 stream_tx.clone(),
396 self.tx.clone(),
397 target_buffer_size,
398 )
399 .await?;
400 muxstream_sender
401 .send_async((inner_packet, stream))
402 .await
403 .map_err(|_| WispError::MuxMessageFailedToSend)?;
404 self.stream_map.insert(packet.stream_id, map_value);
405 }
406 Data(data) => {
407 let mut data = BytesMut::from(data);
408 if let Some(stream) = self.stream_map.get(&packet.stream_id) {
409 if let Some(extra_frame) = optional_frame {
410 if data.is_empty() {
411 data = extra_frame.payload.into();
412 } else {
413 data.extend_from_slice(&extra_frame.payload);
414 }
415 }
416 let _ = stream.stream.try_send(data.freeze());
417 if stream.stream_type == StreamType::Tcp {
418 stream.flow_control.store(
419 stream
420 .flow_control
421 .load(Ordering::Acquire)
422 .saturating_sub(1),
423 Ordering::Release,
424 );
425 }
426 }
427 }
428 Close(inner_packet) => {
429 if packet.stream_id == 0 {
430 break Ok(());
431 }
432 self.close_stream(packet.stream_id, inner_packet)
433 }
434 }
435 }
436 }
437 }
438
439 async fn client_loop<R>(
440 &self,
441 mut rx: R,
442 mut extensions: Vec<AnyProtocolExtension>,
443 ) -> Result<(), WispError>
444 where
445 R: ws::WebSocketRead + Send,
446 {
447 loop {
448 let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
449 if frame.opcode == ws::OpCode::Close {
450 break Ok(());
451 }
452
453 if let Some(ref extra_frame) = optional_frame {
454 if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
455 let mut payload = BytesMut::from(frame.payload);
456 payload.extend_from_slice(&extra_frame.payload);
457 frame.payload = Payload::Bytes(payload);
458 }
459 }
460
461 if let Some(packet) =
462 Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
463 {
464 use PacketType::*;
465 match packet.packet_type {
466 Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
467 Data(data) => {
468 let mut data = BytesMut::from(data);
469 if let Some(stream) = self.stream_map.get(&packet.stream_id) {
470 if let Some(extra_frame) = optional_frame {
471 if data.is_empty() {
472 data = extra_frame.payload.into();
473 } else {
474 data.extend_from_slice(&extra_frame.payload);
475 }
476 }
477 let _ = stream.stream.send_async(data.freeze()).await;
478 }
479 }
480 Continue(inner_packet) => {
481 if let Some(stream) = self.stream_map.get(&packet.stream_id) {
482 if stream.stream_type == StreamType::Tcp {
483 stream
484 .flow_control
485 .store(inner_packet.buffer_remaining, Ordering::Release);
486 let _ = stream.flow_control_event.notify(u32::MAX);
487 }
488 }
489 }
490 Close(inner_packet) => {
491 if packet.stream_id == 0 {
492 break Ok(());
493 }
494 self.close_stream(packet.stream_id, inner_packet);
495 }
496 }
497 }
498 }
499 }
500}
501
502async fn maybe_wisp_v2<R>(
503 read: &mut R,
504 write: &LockedWebSocketWrite,
505 builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
506) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
507where
508 R: ws::WebSocketRead + Send,
509{
510 let mut supported_extensions = Vec::new();
511 let mut extra_packet: Option<ws::Frame<'static>> = None;
512 let mut downgraded = true;
513
514 let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
515 if let Some(frame) = select! {
516 x = read.wisp_read_frame(write).fuse() => Some(x?),
517 _ = Delay::new(Duration::from_secs(5)).fuse() => None
518 } {
519 let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
520 if let PacketType::Info(info) = packet.packet_type {
521 supported_extensions = info
522 .extensions
523 .into_iter()
524 .filter(|x| extension_ids.contains(&x.get_id()))
525 .collect();
526 downgraded = false;
527 } else {
528 extra_packet.replace(ws::Frame::from(packet).clone());
529 }
530 }
531
532 for extension in supported_extensions.iter_mut() {
533 extension.handle_handshake(read, write).await?;
534 }
535 Ok((supported_extensions, extra_packet, downgraded))
536}
537
538pub struct ServerMux {
558 pub downgraded: bool,
562 pub supported_extension_ids: Vec<u8>,
564 close_tx: mpsc::Sender<WsEvent>,
565 muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
566 tx: ws::LockedWebSocketWrite,
567 fut_exited: Arc<AtomicBool>,
568}
569
570impl ServerMux {
571 pub async fn create<R, W>(
577 mut read: R,
578 write: W,
579 buffer_size: u32,
580 extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
581 ) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
582 where
583 R: ws::WebSocketRead + Send,
584 W: ws::WebSocketWrite + Send + 'static,
585 {
586 let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
587 let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
588 let write = ws::LockedWebSocketWrite::new(Box::new(write));
589 let fut_exited = Arc::new(AtomicBool::new(false));
590
591 write
592 .write_frame(Packet::new_continue(0, buffer_size).into())
593 .await?;
594
595 let (supported_extensions, extra_packet, downgraded) =
596 if let Some(builders) = extension_builders {
597 write
598 .write_frame(
599 Packet::new_info(
600 builders
601 .iter()
602 .map(|x| x.build_to_extension(Role::Client))
603 .collect(),
604 )
605 .into(),
606 )
607 .await?;
608 maybe_wisp_v2(&mut read, &write, builders).await?
609 } else {
610 (Vec::new(), None, true)
611 };
612
613 Ok(ServerMuxResult(
614 Self {
615 muxstream_recv: rx,
616 close_tx: close_tx.clone(),
617 downgraded,
618 supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
619 tx: write.clone(),
620 fut_exited: fut_exited.clone(),
621 },
622 MuxInner {
623 tx: write,
624 stream_map: DashMap::new(),
625 buffer_size,
626 fut_exited,
627 }
628 .server_into_future(
629 AppendingWebSocketRead(extra_packet, read),
630 supported_extensions,
631 close_rx,
632 tx,
633 close_tx,
634 ),
635 ))
636 }
637
638 pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
640 if self.fut_exited.load(Ordering::Acquire) {
641 return None;
642 }
643 self.muxstream_recv.recv_async().await.ok()
644 }
645
646 async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
647 if self.fut_exited.load(Ordering::Acquire) {
648 return Err(WispError::MuxTaskEnded);
649 }
650 self.close_tx
651 .send_async(WsEvent::EndFut(reason))
652 .await
653 .map_err(|_| WispError::MuxMessageFailedToSend)
654 }
655
656 pub async fn close(&self) -> Result<(), WispError> {
660 self.close_internal(None).await
661 }
662
663 pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
667 self.close_internal(Some(CloseReason::IncompatibleExtensions))
668 .await
669 }
670
671 pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
673 MuxProtocolExtensionStream {
674 stream_id: 0,
675 tx: self.tx.clone(),
676 is_closed: self.fut_exited.clone(),
677 }
678 }
679}
680
681impl Drop for ServerMux {
682 fn drop(&mut self) {
683 let _ = self.close_tx.send(WsEvent::EndFut(None));
684 }
685}
686
687pub struct ServerMuxResult<F>(ServerMux, F)
689where
690 F: Future<Output = Result<(), WispError>> + Send;
691
692impl<F> ServerMuxResult<F>
693where
694 F: Future<Output = Result<(), WispError>> + Send,
695{
696 pub fn with_no_required_extensions(self) -> (ServerMux, F) {
698 (self.0, self.1)
699 }
700
701 pub async fn with_required_extensions(
704 self,
705 extensions: &[u8],
706 ) -> Result<(ServerMux, F), WispError> {
707 let mut unsupported_extensions = Vec::new();
708 for extension in extensions {
709 if !self.0.supported_extension_ids.contains(extension) {
710 unsupported_extensions.push(*extension);
711 }
712 }
713 if unsupported_extensions.is_empty() {
714 Ok((self.0, self.1))
715 } else {
716 self.0.close_extension_incompat().await?;
717 self.1.await?;
718 Err(WispError::ExtensionsNotSupported(unsupported_extensions))
719 }
720 }
721
722 pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
724 self.with_required_extensions(&[UdpProtocolExtension::ID])
725 .await
726 }
727}
728
729pub struct ClientMux {
744 pub downgraded: bool,
748 pub supported_extension_ids: Vec<u8>,
750 stream_tx: mpsc::Sender<WsEvent>,
751 tx: ws::LockedWebSocketWrite,
752 fut_exited: Arc<AtomicBool>,
753}
754
755impl ClientMux {
756 pub async fn create<R, W>(
762 mut read: R,
763 write: W,
764 extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
765 ) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
766 where
767 R: ws::WebSocketRead + Send,
768 W: ws::WebSocketWrite + Send + 'static,
769 {
770 let write = ws::LockedWebSocketWrite::new(Box::new(write));
771 let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
772 let fut_exited = Arc::new(AtomicBool::new(false));
773
774 if first_packet.stream_id != 0 {
775 return Err(WispError::InvalidStreamId);
776 }
777 if let PacketType::Continue(packet) = first_packet.packet_type {
778 let (supported_extensions, extra_packet, downgraded) =
779 if let Some(builders) = extension_builders {
780 let x = maybe_wisp_v2(&mut read, &write, builders).await?;
781 if !x.2 {
783 write
784 .write_frame(
785 Packet::new_info(
786 builders
787 .iter()
788 .map(|x| x.build_to_extension(Role::Client))
789 .collect(),
790 )
791 .into(),
792 )
793 .await?;
794 }
795 x
796 } else {
797 (Vec::new(), None, true)
798 };
799
800 let (tx, rx) = mpsc::bounded::<WsEvent>(256);
801 Ok(ClientMuxResult(
802 Self {
803 stream_tx: tx.clone(),
804 downgraded,
805 supported_extension_ids: supported_extensions
806 .iter()
807 .map(|x| x.get_id())
808 .collect(),
809 tx: write.clone(),
810 fut_exited: fut_exited.clone(),
811 },
812 MuxInner {
813 tx: write,
814 stream_map: DashMap::new(),
815 buffer_size: packet.buffer_remaining,
816 fut_exited,
817 }
818 .client_into_future(
819 AppendingWebSocketRead(extra_packet, read),
820 supported_extensions,
821 rx,
822 tx,
823 ),
824 ))
825 } else {
826 Err(WispError::InvalidPacketType)
827 }
828 }
829
830 pub async fn client_new_stream(
832 &self,
833 stream_type: StreamType,
834 host: String,
835 port: u16,
836 ) -> Result<MuxStream, WispError> {
837 if self.fut_exited.load(Ordering::Acquire) {
838 return Err(WispError::MuxTaskEnded);
839 }
840 if stream_type == StreamType::Udp
841 && !self
842 .supported_extension_ids
843 .iter()
844 .any(|x| *x == UdpProtocolExtension::ID)
845 {
846 return Err(WispError::ExtensionsNotSupported(vec![
847 UdpProtocolExtension::ID,
848 ]));
849 }
850 let (tx, rx) = oneshot::channel();
851 self.stream_tx
852 .send_async(WsEvent::CreateStream(stream_type, host, port, tx))
853 .await
854 .map_err(|_| WispError::MuxMessageFailedToSend)?;
855 rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
856 }
857
858 async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
859 if self.fut_exited.load(Ordering::Acquire) {
860 return Err(WispError::MuxTaskEnded);
861 }
862 self.stream_tx
863 .send_async(WsEvent::EndFut(reason))
864 .await
865 .map_err(|_| WispError::MuxMessageFailedToSend)
866 }
867
868 pub async fn close(&self) -> Result<(), WispError> {
872 self.close_internal(None).await
873 }
874
875 pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
879 self.close_internal(Some(CloseReason::IncompatibleExtensions))
880 .await
881 }
882
883 pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
885 MuxProtocolExtensionStream {
886 stream_id: 0,
887 tx: self.tx.clone(),
888 is_closed: self.fut_exited.clone(),
889 }
890 }
891}
892
893impl Drop for ClientMux {
894 fn drop(&mut self) {
895 let _ = self.stream_tx.send(WsEvent::EndFut(None));
896 }
897}
898
899pub struct ClientMuxResult<F>(ClientMux, F)
901where
902 F: Future<Output = Result<(), WispError>> + Send;
903
904impl<F> ClientMuxResult<F>
905where
906 F: Future<Output = Result<(), WispError>> + Send,
907{
908 pub fn with_no_required_extensions(self) -> (ClientMux, F) {
910 (self.0, self.1)
911 }
912
913 pub async fn with_required_extensions(
915 self,
916 extensions: &[u8],
917 ) -> Result<(ClientMux, F), WispError> {
918 let mut unsupported_extensions = Vec::new();
919 for extension in extensions {
920 if !self.0.supported_extension_ids.contains(extension) {
921 unsupported_extensions.push(*extension);
922 }
923 }
924 if unsupported_extensions.is_empty() {
925 Ok((self.0, self.1))
926 } else {
927 self.0.close_extension_incompat().await?;
928 self.1.await?;
929 Err(WispError::ExtensionsNotSupported(unsupported_extensions))
930 }
931 }
932
933 pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
935 self.with_required_extensions(&[UdpProtocolExtension::ID])
936 .await
937 }
938}