wisp_mux/
lib.rs

1#![deny(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3//! A library for easily creating [Wisp] clients and servers.
4//!
5//! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol
6
7pub mod extensions;
8#[cfg(feature = "fastwebsockets")]
9#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
10mod fastwebsockets;
11mod packet;
12mod sink_unfold;
13mod stream;
14pub mod ws;
15
16pub use crate::{packet::*, stream::*};
17
18use bytes::{Bytes, BytesMut};
19use dashmap::DashMap;
20use event_listener::Event;
21use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
22use flume as mpsc;
23use futures::{channel::oneshot, select, Future, FutureExt};
24use futures_timer::Delay;
25use std::{
26	sync::{
27		atomic::{AtomicBool, AtomicU32, Ordering},
28		Arc,
29	},
30	time::Duration,
31};
32use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload};
33
34/// Wisp version supported by this crate.
35pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
36
/// The role of the multiplexor.
///
/// Roles are plain values: they are `Copy`, compare with `==`, and (being `Eq + Hash`) can be
/// used as keys in sets and maps.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum Role {
	/// Client side, can create new channels to proxy.
	Client,
	/// Server side, can listen for channels to proxy.
	Server,
}
45
/// Every error that the Wisp multiplexor can report.
#[derive(Debug)]
pub enum WispError {
	/// A received packet was shorter than its packet type requires.
	PacketTooSmall,
	/// A received packet carried an unknown packet type byte.
	InvalidPacketType,
	/// A packet referred to a stream ID that is not valid here.
	InvalidStreamId,
	/// A close packet carried an unknown close reason.
	InvalidCloseReason,
	/// A received URI could not be parsed.
	InvalidUri,
	/// A received URI did not contain a host.
	UriHasNoHost,
	/// A received URI did not contain a port.
	UriHasNoPort,
	/// No more stream IDs are available.
	MaxStreamCountReached,
	/// The peer speaks an incompatible Wisp protocol version.
	IncompatibleProtocolVersion,
	/// The stream was already closed.
	StreamAlreadyClosed,
	/// A websocket frame had an unexpected opcode.
	WsFrameInvalidType,
	/// A websocket frame was received without its FIN bit set.
	WsFrameNotFinished,
	/// An error raised by the underlying websocket implementation.
	WsImplError(Box<dyn std::error::Error + Sync + Send>),
	/// The underlying websocket was closed.
	WsImplSocketClosed,
	/// The underlying websocket implementation does not support the requested action.
	WsImplNotSupported,
	/// An error raised by a protocol extension implementation.
	ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
	/// A protocol extension implementation does not support the requested action.
	ExtensionImplNotSupported,
	/// The listed protocol extension IDs are not supported by the other side.
	ExtensionsNotSupported(Vec<u8>),
	/// Bytes that were expected to be UTF-8 were not.
	Utf8Error(std::str::Utf8Error),
	/// A numeric conversion failed.
	TryFromIntError(std::num::TryFromIntError),
	/// Any other error.
	Other(Box<dyn std::error::Error + Sync + Send>),
	/// A message could not be sent to the multiplexor task.
	MuxMessageFailedToSend,
	/// A reply could not be received from the multiplexor task.
	MuxMessageFailedToRecv,
	/// The multiplexor task has already finished.
	MuxTaskEnded,
}
98
99impl From<std::str::Utf8Error> for WispError {
100	fn from(err: std::str::Utf8Error) -> Self {
101		Self::Utf8Error(err)
102	}
103}
104
105impl From<std::num::TryFromIntError> for WispError {
106	fn from(value: std::num::TryFromIntError) -> Self {
107		Self::TryFromIntError(value)
108	}
109}
110
111impl std::fmt::Display for WispError {
112	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
113		match self {
114			Self::PacketTooSmall => write!(f, "Packet too small"),
115			Self::InvalidPacketType => write!(f, "Invalid packet type"),
116			Self::InvalidStreamId => write!(f, "Invalid stream id"),
117			Self::InvalidCloseReason => write!(f, "Invalid close reason"),
118			Self::InvalidUri => write!(f, "Invalid URI"),
119			Self::UriHasNoHost => write!(f, "URI has no host"),
120			Self::UriHasNoPort => write!(f, "URI has no port"),
121			Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
122			Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
123			Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
124			Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
125			Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
126			Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
127			Self::WsImplSocketClosed => {
128				write!(f, "Websocket implementation error: websocket closed")
129			}
130			Self::WsImplNotSupported => {
131				write!(f, "Websocket implementation error: unsupported feature")
132			}
133			Self::ExtensionImplError(err) => {
134				write!(f, "Protocol extension implementation error: {}", err)
135			}
136			Self::ExtensionImplNotSupported => {
137				write!(
138					f,
139					"Protocol extension implementation error: unsupported feature"
140				)
141			}
142			Self::ExtensionsNotSupported(list) => {
143				write!(f, "Protocol extensions {:?} not supported", list)
144			}
145			Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
146			Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err),
147			Self::Other(err) => write!(f, "Other error: {}", err),
148			Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
149			Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
150			Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
151		}
152	}
153}
154
155impl std::error::Error for WispError {}
156
157struct MuxMapValue {
158	stream: mpsc::Sender<Bytes>,
159	stream_type: StreamType,
160
161	flow_control: Arc<AtomicU32>,
162	flow_control_event: Arc<Event>,
163
164	is_closed: Arc<AtomicBool>,
165	close_reason: Arc<AtomicCloseReason>,
166	is_closed_event: Arc<Event>,
167}
168
169struct MuxInner {
170	tx: ws::LockedWebSocketWrite,
171	stream_map: DashMap<u32, MuxMapValue>,
172	buffer_size: u32,
173	fut_exited: Arc<AtomicBool>,
174}
175
176impl MuxInner {
177	pub async fn server_into_future<R>(
178		self,
179		rx: R,
180		extensions: Vec<AnyProtocolExtension>,
181		close_rx: mpsc::Receiver<WsEvent>,
182		muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
183		close_tx: mpsc::Sender<WsEvent>,
184	) -> Result<(), WispError>
185	where
186		R: ws::WebSocketRead + Send,
187	{
188		self.as_future(
189			close_rx,
190			close_tx.clone(),
191			self.server_loop(rx, extensions, muxstream_sender, close_tx),
192		)
193		.await
194	}
195
196	pub async fn client_into_future<R>(
197		self,
198		rx: R,
199		extensions: Vec<AnyProtocolExtension>,
200		close_rx: mpsc::Receiver<WsEvent>,
201		close_tx: mpsc::Sender<WsEvent>,
202	) -> Result<(), WispError>
203	where
204		R: ws::WebSocketRead + Send,
205	{
206		self.as_future(close_rx, close_tx, self.client_loop(rx, extensions))
207			.await
208	}
209
210	async fn as_future(
211		&self,
212		close_rx: mpsc::Receiver<WsEvent>,
213		close_tx: mpsc::Sender<WsEvent>,
214		wisp_fut: impl Future<Output = Result<(), WispError>>,
215	) -> Result<(), WispError> {
216		let ret = futures::select! {
217			_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
218			x = wisp_fut.fuse() => x,
219		};
220		self.fut_exited.store(true, Ordering::Release);
221		for x in self.stream_map.iter_mut() {
222			x.is_closed.store(true, Ordering::Release);
223			x.is_closed_event.notify(usize::MAX);
224		}
225		self.stream_map.clear();
226		let _ = self.tx.close().await;
227		ret
228	}
229
230	async fn create_new_stream(
231		&self,
232		stream_id: u32,
233		stream_type: StreamType,
234		role: Role,
235		stream_tx: mpsc::Sender<WsEvent>,
236		tx: LockedWebSocketWrite,
237		target_buffer_size: u32,
238	) -> Result<(MuxMapValue, MuxStream), WispError> {
239		let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
240
241		let flow_control_event: Arc<Event> = Event::new().into();
242		let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
243
244		let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
245		let close_reason: Arc<AtomicCloseReason> =
246			AtomicCloseReason::new(CloseReason::Unknown).into();
247		let is_closed_event: Arc<Event> = Event::new().into();
248
249		Ok((
250			MuxMapValue {
251				stream: ch_tx,
252				stream_type,
253
254				flow_control: flow_control.clone(),
255				flow_control_event: flow_control_event.clone(),
256
257				is_closed: is_closed.clone(),
258				close_reason: close_reason.clone(),
259				is_closed_event: is_closed_event.clone(),
260			},
261			MuxStream::new(
262				stream_id,
263				role,
264				stream_type,
265				ch_rx,
266				stream_tx,
267				tx,
268				is_closed,
269				is_closed_event,
270				close_reason,
271				flow_control,
272				flow_control_event,
273				target_buffer_size,
274			),
275		))
276	}
277
278	async fn stream_loop(
279		&self,
280		stream_rx: mpsc::Receiver<WsEvent>,
281		stream_tx: mpsc::Sender<WsEvent>,
282	) {
283		let mut next_free_stream_id: u32 = 1;
284		while let Ok(msg) = stream_rx.recv_async().await {
285			match msg {
286				WsEvent::CreateStream(stream_type, host, port, channel) => {
287					let ret: Result<MuxStream, WispError> = async {
288						let stream_id = next_free_stream_id;
289						let next_stream_id = next_free_stream_id
290							.checked_add(1)
291							.ok_or(WispError::MaxStreamCountReached)?;
292
293						let (map_value, stream) = self
294							.create_new_stream(
295								stream_id,
296								stream_type,
297								Role::Client,
298								stream_tx.clone(),
299								self.tx.clone(),
300								0,
301							)
302							.await?;
303
304						self.tx
305							.write_frame(
306								Packet::new_connect(stream_id, stream_type, port, host).into(),
307							)
308							.await?;
309
310						self.stream_map.insert(stream_id, map_value);
311
312						next_free_stream_id = next_stream_id;
313
314						Ok(stream)
315					}
316					.await;
317					let _ = channel.send(ret);
318				}
319				WsEvent::Close(packet, channel) => {
320					if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
321						if let PacketType::Close(close) = packet.packet_type {
322							self.close_stream(packet.stream_id, close);
323						}
324						let _ = channel.send(self.tx.write_frame(packet.into()).await);
325						drop(stream.stream)
326					} else {
327						let _ = channel.send(Err(WispError::InvalidStreamId));
328					}
329				}
330				WsEvent::EndFut(x) => {
331					if let Some(reason) = x {
332						let _ = self
333							.tx
334							.write_frame(Packet::new_close(0, reason).into())
335							.await;
336					}
337					break;
338				}
339			}
340		}
341	}
342
343	fn close_stream(&self, stream_id: u32, close_packet: ClosePacket) {
344		if let Some((_, stream)) = self.stream_map.remove(&stream_id) {
345			stream
346				.close_reason
347				.store(close_packet.reason, Ordering::Release);
348			stream.is_closed.store(true, Ordering::Release);
349			stream.is_closed_event.notify(usize::MAX);
350			stream.flow_control.store(u32::MAX, Ordering::Release);
351			stream.flow_control_event.notify(usize::MAX);
352			drop(stream.stream)
353		}
354	}
355
356	async fn server_loop<R>(
357		&self,
358		mut rx: R,
359		mut extensions: Vec<AnyProtocolExtension>,
360		muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
361		stream_tx: mpsc::Sender<WsEvent>,
362	) -> Result<(), WispError>
363	where
364		R: ws::WebSocketRead + Send,
365	{
366		// will send continues once flow_control is at 10% of max
367		let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
368
369		loop {
370			let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
371			if frame.opcode == ws::OpCode::Close {
372				break Ok(());
373			}
374
375			if let Some(ref extra_frame) = optional_frame {
376				if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
377					let mut payload = BytesMut::from(frame.payload);
378					payload.extend_from_slice(&extra_frame.payload);
379					frame.payload = Payload::Bytes(payload);
380				}
381			}
382
383			if let Some(packet) =
384				Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
385			{
386				use PacketType::*;
387				match packet.packet_type {
388					Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
389					Connect(inner_packet) => {
390						let (map_value, stream) = self
391							.create_new_stream(
392								packet.stream_id,
393								inner_packet.stream_type,
394								Role::Server,
395								stream_tx.clone(),
396								self.tx.clone(),
397								target_buffer_size,
398							)
399							.await?;
400						muxstream_sender
401							.send_async((inner_packet, stream))
402							.await
403							.map_err(|_| WispError::MuxMessageFailedToSend)?;
404						self.stream_map.insert(packet.stream_id, map_value);
405					}
406					Data(data) => {
407						let mut data = BytesMut::from(data);
408						if let Some(stream) = self.stream_map.get(&packet.stream_id) {
409							if let Some(extra_frame) = optional_frame {
410								if data.is_empty() {
411									data = extra_frame.payload.into();
412								} else {
413									data.extend_from_slice(&extra_frame.payload);
414								}
415							}
416							let _ = stream.stream.try_send(data.freeze());
417							if stream.stream_type == StreamType::Tcp {
418								stream.flow_control.store(
419									stream
420										.flow_control
421										.load(Ordering::Acquire)
422										.saturating_sub(1),
423									Ordering::Release,
424								);
425							}
426						}
427					}
428					Close(inner_packet) => {
429						if packet.stream_id == 0 {
430							break Ok(());
431						}
432						self.close_stream(packet.stream_id, inner_packet)
433					}
434				}
435			}
436		}
437	}
438
439	async fn client_loop<R>(
440		&self,
441		mut rx: R,
442		mut extensions: Vec<AnyProtocolExtension>,
443	) -> Result<(), WispError>
444	where
445		R: ws::WebSocketRead + Send,
446	{
447		loop {
448			let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
449			if frame.opcode == ws::OpCode::Close {
450				break Ok(());
451			}
452
453			if let Some(ref extra_frame) = optional_frame {
454				if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
455					let mut payload = BytesMut::from(frame.payload);
456					payload.extend_from_slice(&extra_frame.payload);
457					frame.payload = Payload::Bytes(payload);
458				}
459			}
460
461			if let Some(packet) =
462				Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
463			{
464				use PacketType::*;
465				match packet.packet_type {
466					Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
467					Data(data) => {
468						let mut data = BytesMut::from(data);
469						if let Some(stream) = self.stream_map.get(&packet.stream_id) {
470							if let Some(extra_frame) = optional_frame {
471								if data.is_empty() {
472									data = extra_frame.payload.into();
473								} else {
474									data.extend_from_slice(&extra_frame.payload);
475								}
476							}
477							let _ = stream.stream.send_async(data.freeze()).await;
478						}
479					}
480					Continue(inner_packet) => {
481						if let Some(stream) = self.stream_map.get(&packet.stream_id) {
482							if stream.stream_type == StreamType::Tcp {
483								stream
484									.flow_control
485									.store(inner_packet.buffer_remaining, Ordering::Release);
486								let _ = stream.flow_control_event.notify(u32::MAX);
487							}
488						}
489					}
490					Close(inner_packet) => {
491						if packet.stream_id == 0 {
492							break Ok(());
493						}
494						self.close_stream(packet.stream_id, inner_packet);
495					}
496				}
497			}
498		}
499	}
500}
501
502async fn maybe_wisp_v2<R>(
503	read: &mut R,
504	write: &LockedWebSocketWrite,
505	builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
506) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
507where
508	R: ws::WebSocketRead + Send,
509{
510	let mut supported_extensions = Vec::new();
511	let mut extra_packet: Option<ws::Frame<'static>> = None;
512	let mut downgraded = true;
513
514	let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
515	if let Some(frame) = select! {
516		x = read.wisp_read_frame(write).fuse() => Some(x?),
517		_ = Delay::new(Duration::from_secs(5)).fuse() => None
518	} {
519		let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
520		if let PacketType::Info(info) = packet.packet_type {
521			supported_extensions = info
522				.extensions
523				.into_iter()
524				.filter(|x| extension_ids.contains(&x.get_id()))
525				.collect();
526			downgraded = false;
527		} else {
528			extra_packet.replace(ws::Frame::from(packet).clone());
529		}
530	}
531
532	for extension in supported_extensions.iter_mut() {
533		extension.handle_handshake(read, write).await?;
534	}
535	Ok((supported_extensions, extra_packet, downgraded))
536}
537
538/// Server-side multiplexor.
539///
540/// # Example
541/// ```
542/// use wisp_mux::ServerMux;
543///
544/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some([]));
545/// tokio::spawn(async move {
546///     if let Err(e) = fut.await {
547///         println!("error in multiplexor: {:?}", e);
548///     }
549/// });
550/// while let Some((packet, stream)) = mux.server_new_stream().await {
551///     tokio::spawn(async move {
552///         let url = format!("{}:{}", packet.destination_hostname, packet.destination_port);
553///         // do something with `url` and `packet.stream_type`
554///     });
555/// }
556/// ```
557pub struct ServerMux {
558	/// Whether the connection was downgraded to Wisp v1.
559	///
560	/// If this variable is true you must assume no extensions are supported.
561	pub downgraded: bool,
562	/// Extensions that are supported by both sides.
563	pub supported_extension_ids: Vec<u8>,
564	close_tx: mpsc::Sender<WsEvent>,
565	muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
566	tx: ws::LockedWebSocketWrite,
567	fut_exited: Arc<AtomicBool>,
568}
569
570impl ServerMux {
571	/// Create a new server-side multiplexor.
572	///
573	/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
574	/// **It is not guaranteed that all extensions you specify are available.** You must manually check
575	/// if the extensions you need are available after the multiplexor has been created.
576	pub async fn create<R, W>(
577		mut read: R,
578		write: W,
579		buffer_size: u32,
580		extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
581	) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
582	where
583		R: ws::WebSocketRead + Send,
584		W: ws::WebSocketWrite + Send + 'static,
585	{
586		let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
587		let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
588		let write = ws::LockedWebSocketWrite::new(Box::new(write));
589		let fut_exited = Arc::new(AtomicBool::new(false));
590
591		write
592			.write_frame(Packet::new_continue(0, buffer_size).into())
593			.await?;
594
595		let (supported_extensions, extra_packet, downgraded) =
596			if let Some(builders) = extension_builders {
597				write
598					.write_frame(
599						Packet::new_info(
600							builders
601								.iter()
602								.map(|x| x.build_to_extension(Role::Client))
603								.collect(),
604						)
605						.into(),
606					)
607					.await?;
608				maybe_wisp_v2(&mut read, &write, builders).await?
609			} else {
610				(Vec::new(), None, true)
611			};
612
613		Ok(ServerMuxResult(
614			Self {
615				muxstream_recv: rx,
616				close_tx: close_tx.clone(),
617				downgraded,
618				supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
619				tx: write.clone(),
620				fut_exited: fut_exited.clone(),
621			},
622			MuxInner {
623				tx: write,
624				stream_map: DashMap::new(),
625				buffer_size,
626				fut_exited,
627			}
628			.server_into_future(
629				AppendingWebSocketRead(extra_packet, read),
630				supported_extensions,
631				close_rx,
632				tx,
633				close_tx,
634			),
635		))
636	}
637
638	/// Wait for a stream to be created.
639	pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
640		if self.fut_exited.load(Ordering::Acquire) {
641			return None;
642		}
643		self.muxstream_recv.recv_async().await.ok()
644	}
645
646	async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
647		if self.fut_exited.load(Ordering::Acquire) {
648			return Err(WispError::MuxTaskEnded);
649		}
650		self.close_tx
651			.send_async(WsEvent::EndFut(reason))
652			.await
653			.map_err(|_| WispError::MuxMessageFailedToSend)
654	}
655
656	/// Close all streams.
657	///
658	/// Also terminates the multiplexor future.
659	pub async fn close(&self) -> Result<(), WispError> {
660		self.close_internal(None).await
661	}
662
663	/// Close all streams and send an extension incompatibility error to the client.
664	///
665	/// Also terminates the multiplexor future.
666	pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
667		self.close_internal(Some(CloseReason::IncompatibleExtensions))
668			.await
669	}
670
671	/// Get a protocol extension stream for sending packets with stream id 0.
672	pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
673		MuxProtocolExtensionStream {
674			stream_id: 0,
675			tx: self.tx.clone(),
676			is_closed: self.fut_exited.clone(),
677		}
678	}
679}
680
681impl Drop for ServerMux {
682	fn drop(&mut self) {
683		let _ = self.close_tx.send(WsEvent::EndFut(None));
684	}
685}
686
687/// Result of `ServerMux::new`.
688pub struct ServerMuxResult<F>(ServerMux, F)
689where
690	F: Future<Output = Result<(), WispError>> + Send;
691
692impl<F> ServerMuxResult<F>
693where
694	F: Future<Output = Result<(), WispError>> + Send,
695{
696	/// Require no protocol extensions.
697	pub fn with_no_required_extensions(self) -> (ServerMux, F) {
698		(self.0, self.1)
699	}
700
701	/// Require protocol extensions by their ID. Will close the multiplexor connection if
702	/// extensions are not supported.
703	pub async fn with_required_extensions(
704		self,
705		extensions: &[u8],
706	) -> Result<(ServerMux, F), WispError> {
707		let mut unsupported_extensions = Vec::new();
708		for extension in extensions {
709			if !self.0.supported_extension_ids.contains(extension) {
710				unsupported_extensions.push(*extension);
711			}
712		}
713		if unsupported_extensions.is_empty() {
714			Ok((self.0, self.1))
715		} else {
716			self.0.close_extension_incompat().await?;
717			self.1.await?;
718			Err(WispError::ExtensionsNotSupported(unsupported_extensions))
719		}
720	}
721
722	/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
723	pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
724		self.with_required_extensions(&[UdpProtocolExtension::ID])
725			.await
726	}
727}
728
729/// Client side multiplexor.
730///
731/// # Example
732/// ```
733/// use wisp_mux::{ClientMux, StreamType};
734///
735/// let (mux, fut) = ClientMux::new(rx, tx, Some([])).await?;
736/// tokio::spawn(async move {
737///     if let Err(e) = fut.await {
738///         println!("error in multiplexor: {:?}", e);
739///     }
740/// });
741/// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80);
742/// ```
743pub struct ClientMux {
744	/// Whether the connection was downgraded to Wisp v1.
745	///
746	/// If this variable is true you must assume no extensions are supported.
747	pub downgraded: bool,
748	/// Extensions that are supported by both sides.
749	pub supported_extension_ids: Vec<u8>,
750	stream_tx: mpsc::Sender<WsEvent>,
751	tx: ws::LockedWebSocketWrite,
752	fut_exited: Arc<AtomicBool>,
753}
754
755impl ClientMux {
756	/// Create a new client side multiplexor.
757	///
758	/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
759	/// **It is not guaranteed that all extensions you specify are available.** You must manually check
760	/// if the extensions you need are available after the multiplexor has been created.
761	pub async fn create<R, W>(
762		mut read: R,
763		write: W,
764		extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
765	) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
766	where
767		R: ws::WebSocketRead + Send,
768		W: ws::WebSocketWrite + Send + 'static,
769	{
770		let write = ws::LockedWebSocketWrite::new(Box::new(write));
771		let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
772		let fut_exited = Arc::new(AtomicBool::new(false));
773
774		if first_packet.stream_id != 0 {
775			return Err(WispError::InvalidStreamId);
776		}
777		if let PacketType::Continue(packet) = first_packet.packet_type {
778			let (supported_extensions, extra_packet, downgraded) =
779				if let Some(builders) = extension_builders {
780					let x = maybe_wisp_v2(&mut read, &write, builders).await?;
781					// if not downgraded
782					if !x.2 {
783						write
784							.write_frame(
785								Packet::new_info(
786									builders
787										.iter()
788										.map(|x| x.build_to_extension(Role::Client))
789										.collect(),
790								)
791								.into(),
792							)
793							.await?;
794					}
795					x
796				} else {
797					(Vec::new(), None, true)
798				};
799
800			let (tx, rx) = mpsc::bounded::<WsEvent>(256);
801			Ok(ClientMuxResult(
802				Self {
803					stream_tx: tx.clone(),
804					downgraded,
805					supported_extension_ids: supported_extensions
806						.iter()
807						.map(|x| x.get_id())
808						.collect(),
809					tx: write.clone(),
810					fut_exited: fut_exited.clone(),
811				},
812				MuxInner {
813					tx: write,
814					stream_map: DashMap::new(),
815					buffer_size: packet.buffer_remaining,
816					fut_exited,
817				}
818				.client_into_future(
819					AppendingWebSocketRead(extra_packet, read),
820					supported_extensions,
821					rx,
822					tx,
823				),
824			))
825		} else {
826			Err(WispError::InvalidPacketType)
827		}
828	}
829
830	/// Create a new stream, multiplexed through Wisp.
831	pub async fn client_new_stream(
832		&self,
833		stream_type: StreamType,
834		host: String,
835		port: u16,
836	) -> Result<MuxStream, WispError> {
837		if self.fut_exited.load(Ordering::Acquire) {
838			return Err(WispError::MuxTaskEnded);
839		}
840		if stream_type == StreamType::Udp
841			&& !self
842				.supported_extension_ids
843				.iter()
844				.any(|x| *x == UdpProtocolExtension::ID)
845		{
846			return Err(WispError::ExtensionsNotSupported(vec![
847				UdpProtocolExtension::ID,
848			]));
849		}
850		let (tx, rx) = oneshot::channel();
851		self.stream_tx
852			.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
853			.await
854			.map_err(|_| WispError::MuxMessageFailedToSend)?;
855		rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
856	}
857
858	async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
859		if self.fut_exited.load(Ordering::Acquire) {
860			return Err(WispError::MuxTaskEnded);
861		}
862		self.stream_tx
863			.send_async(WsEvent::EndFut(reason))
864			.await
865			.map_err(|_| WispError::MuxMessageFailedToSend)
866	}
867
868	/// Close all streams.
869	///
870	/// Also terminates the multiplexor future.
871	pub async fn close(&self) -> Result<(), WispError> {
872		self.close_internal(None).await
873	}
874
875	/// Close all streams and send an extension incompatibility error to the client.
876	///
877	/// Also terminates the multiplexor future.
878	pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
879		self.close_internal(Some(CloseReason::IncompatibleExtensions))
880			.await
881	}
882
883	/// Get a protocol extension stream for sending packets with stream id 0.
884	pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
885		MuxProtocolExtensionStream {
886			stream_id: 0,
887			tx: self.tx.clone(),
888			is_closed: self.fut_exited.clone(),
889		}
890	}
891}
892
893impl Drop for ClientMux {
894	fn drop(&mut self) {
895		let _ = self.stream_tx.send(WsEvent::EndFut(None));
896	}
897}
898
899/// Result of `ClientMux::new`.
900pub struct ClientMuxResult<F>(ClientMux, F)
901where
902	F: Future<Output = Result<(), WispError>> + Send;
903
904impl<F> ClientMuxResult<F>
905where
906	F: Future<Output = Result<(), WispError>> + Send,
907{
908	/// Require no protocol extensions.
909	pub fn with_no_required_extensions(self) -> (ClientMux, F) {
910		(self.0, self.1)
911	}
912
913	/// Require protocol extensions by their ID.
914	pub async fn with_required_extensions(
915		self,
916		extensions: &[u8],
917	) -> Result<(ClientMux, F), WispError> {
918		let mut unsupported_extensions = Vec::new();
919		for extension in extensions {
920			if !self.0.supported_extension_ids.contains(extension) {
921				unsupported_extensions.push(*extension);
922			}
923		}
924		if unsupported_extensions.is_empty() {
925			Ok((self.0, self.1))
926		} else {
927			self.0.close_extension_incompat().await?;
928			self.1.await?;
929			Err(WispError::ExtensionsNotSupported(unsupported_extensions))
930		}
931	}
932
933	/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
934	pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
935		self.with_required_extensions(&[UdpProtocolExtension::ID])
936			.await
937	}
938}