wisp_mux/
stream.rs

1use crate::{
2	sink_unfold,
3	ws::{Frame, LockedWebSocketWrite, Payload},
4	AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError,
5};
6
7use bytes::{BufMut, Bytes, BytesMut};
8use event_listener::Event;
9use flume as mpsc;
10use futures::{
11	channel::oneshot,
12	ready, select,
13	stream::{self, IntoAsyncRead},
14	task::{noop_waker_ref, Context, Poll},
15	AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, TryStreamExt,
16};
17use pin_project_lite::pin_project;
18use std::{
19	pin::Pin,
20	sync::{
21		atomic::{AtomicBool, AtomicU32, Ordering},
22		Arc,
23	},
24};
25
26pub(crate) enum WsEvent {
27	Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
28	CreateStream(
29		StreamType,
30		String,
31		u16,
32		oneshot::Sender<Result<MuxStream, WispError>>,
33	),
34	EndFut(Option<CloseReason>),
35}
36
37/// Read side of a multiplexor stream.
38pub struct MuxStreamRead {
39	/// ID of the stream.
40	pub stream_id: u32,
41	/// Type of the stream.
42	pub stream_type: StreamType,
43
44	role: Role,
45
46	tx: LockedWebSocketWrite,
47	rx: mpsc::Receiver<Bytes>,
48
49	is_closed: Arc<AtomicBool>,
50	is_closed_event: Arc<Event>,
51	close_reason: Arc<AtomicCloseReason>,
52
53	flow_control: Arc<AtomicU32>,
54	flow_control_read: AtomicU32,
55	target_flow_control: u32,
56}
57
58impl MuxStreamRead {
59	/// Read an event from the stream.
60	pub async fn read(&self) -> Option<Bytes> {
61		if self.is_closed.load(Ordering::Acquire) {
62			return None;
63		}
64		let bytes = select! {
65			x = self.rx.recv_async() => x.ok()?,
66			_ = self.is_closed_event.listen().fuse() => return None
67		};
68		if self.role == Role::Server && self.stream_type == StreamType::Tcp {
69			let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
70			if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) {
71				self.tx
72					.write_frame(
73						Packet::new_continue(
74							self.stream_id,
75							self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
76						)
77						.into(),
78					)
79					.await
80					.ok()?;
81				self.flow_control_read.store(0, Ordering::Release);
82			}
83		}
84		Some(bytes)
85	}
86
87	pub(crate) fn into_inner_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
88		Box::pin(stream::unfold(self, |rx| async move {
89			Some((rx.read().await?, rx))
90		}))
91	}
92
93	/// Turn the read half into one that implements futures `Stream`, consuming it.
94	pub fn into_stream(self) -> MuxStreamIoStream {
95		MuxStreamIoStream {
96			rx: self.into_inner_stream(),
97		}
98	}
99
100	/// Get the stream's close reason, if it was closed.
101	pub fn get_close_reason(&self) -> Option<CloseReason> {
102		if self.is_closed.load(Ordering::Acquire) {
103			Some(self.close_reason.load(Ordering::Acquire))
104		} else {
105			None
106		}
107	}
108}
109
110/// Write side of a multiplexor stream.
111pub struct MuxStreamWrite {
112	/// ID of the stream.
113	pub stream_id: u32,
114	/// Type of the stream.
115	pub stream_type: StreamType,
116
117	role: Role,
118	mux_tx: mpsc::Sender<WsEvent>,
119	tx: LockedWebSocketWrite,
120
121	is_closed: Arc<AtomicBool>,
122	close_reason: Arc<AtomicCloseReason>,
123
124	continue_recieved: Arc<Event>,
125	flow_control: Arc<AtomicU32>,
126}
127
128impl MuxStreamWrite {
129	pub(crate) async fn write_payload_internal<'a>(
130		&self,
131		header: Frame<'static>,
132		body: Frame<'a>,
133	) -> Result<(), WispError> {
134		if self.role == Role::Client
135			&& self.stream_type == StreamType::Tcp
136			&& self.flow_control.load(Ordering::Acquire) == 0
137		{
138			self.continue_recieved.listen().await;
139		}
140		if self.is_closed.load(Ordering::Acquire) {
141			return Err(WispError::StreamAlreadyClosed);
142		}
143
144		self.tx.write_split(header, body).await?;
145
146		if self.role == Role::Client && self.stream_type == StreamType::Tcp {
147			self.flow_control.store(
148				self.flow_control.load(Ordering::Acquire).saturating_sub(1),
149				Ordering::Release,
150			);
151		}
152		Ok(())
153	}
154
155	/// Write a payload to the stream.
156	pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
157		let frame: Frame<'static> = Frame::from(Packet::new_data(
158			self.stream_id,
159			Payload::Bytes(BytesMut::new()),
160		));
161		self.write_payload_internal(frame, Frame::binary(data))
162			.await
163	}
164
165	/// Write data to the stream.
166	pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
167		self.write_payload(Payload::Borrowed(data.as_ref())).await
168	}
169
170	/// Get a handle to close the connection.
171	///
172	/// Useful to close the connection without having access to the stream.
173	///
174	/// # Example
175	/// ```
176	/// let handle = stream.get_close_handle();
177	/// if let Err(error) = handle_stream(stream) {
178	///     handle.close(0x01);
179	/// }
180	/// ```
181	pub fn get_close_handle(&self) -> MuxStreamCloser {
182		MuxStreamCloser {
183			stream_id: self.stream_id,
184			close_channel: self.mux_tx.clone(),
185			is_closed: self.is_closed.clone(),
186			close_reason: self.close_reason.clone(),
187		}
188	}
189
190	/// Get a protocol extension stream to send protocol extension packets.
191	pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
192		MuxProtocolExtensionStream {
193			stream_id: self.stream_id,
194			tx: self.tx.clone(),
195			is_closed: self.is_closed.clone(),
196		}
197	}
198
199	/// Close the stream. You will no longer be able to write or read after this has been called.
200	pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
201		if self.is_closed.load(Ordering::Acquire) {
202			return Err(WispError::StreamAlreadyClosed);
203		}
204		self.is_closed.store(true, Ordering::Release);
205
206		let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
207		self.mux_tx
208			.send_async(WsEvent::Close(
209				Packet::new_close(self.stream_id, reason),
210				tx,
211			))
212			.await
213			.map_err(|_| WispError::MuxMessageFailedToSend)?;
214		rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
215
216		Ok(())
217	}
218
219	/// Get the stream's close reason, if it was closed.
220	pub fn get_close_reason(&self) -> Option<CloseReason> {
221		if self.is_closed.load(Ordering::Acquire) {
222			Some(self.close_reason.load(Ordering::Acquire))
223		} else {
224			None
225		}
226	}
227
228	pub(crate) fn into_inner_sink(
229		self,
230	) -> Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>> {
231		let handle = self.get_close_handle();
232		Box::pin(sink_unfold::unfold(
233			self,
234			|tx, data| async move {
235				tx.write_payload(data).await?;
236				Ok(tx)
237			},
238			handle,
239			|handle| async move {
240				handle.close(CloseReason::Unknown).await?;
241				Ok(handle)
242			},
243		))
244	}
245
246	/// Turn the write half into one that implements futures `Sink`, consuming it.
247	pub fn into_sink(self) -> MuxStreamIoSink {
248		MuxStreamIoSink {
249			tx: self.into_inner_sink(),
250		}
251	}
252}
253
254impl Drop for MuxStreamWrite {
255	fn drop(&mut self) {
256		if !self.is_closed.load(Ordering::Acquire) {
257			self.is_closed.store(true, Ordering::Release);
258			let (tx, _) = oneshot::channel();
259			let _ = self.mux_tx.send(WsEvent::Close(
260				Packet::new_close(self.stream_id, CloseReason::Unknown),
261				tx,
262			));
263		}
264	}
265}
266
267/// Multiplexor stream.
268pub struct MuxStream {
269	/// ID of the stream.
270	pub stream_id: u32,
271	rx: MuxStreamRead,
272	tx: MuxStreamWrite,
273}
274
275impl MuxStream {
276	#[allow(clippy::too_many_arguments)]
277	pub(crate) fn new(
278		stream_id: u32,
279		role: Role,
280		stream_type: StreamType,
281		rx: mpsc::Receiver<Bytes>,
282		mux_tx: mpsc::Sender<WsEvent>,
283		tx: LockedWebSocketWrite,
284		is_closed: Arc<AtomicBool>,
285		is_closed_event: Arc<Event>,
286		close_reason: Arc<AtomicCloseReason>,
287		flow_control: Arc<AtomicU32>,
288		continue_recieved: Arc<Event>,
289		target_flow_control: u32,
290	) -> Self {
291		Self {
292			stream_id,
293			rx: MuxStreamRead {
294				stream_id,
295				stream_type,
296				role,
297				tx: tx.clone(),
298				rx,
299				is_closed: is_closed.clone(),
300				is_closed_event: is_closed_event.clone(),
301				close_reason: close_reason.clone(),
302				flow_control: flow_control.clone(),
303				flow_control_read: AtomicU32::new(0),
304				target_flow_control,
305			},
306			tx: MuxStreamWrite {
307				stream_id,
308				stream_type,
309				role,
310				mux_tx,
311				tx,
312				is_closed: is_closed.clone(),
313				close_reason: close_reason.clone(),
314				flow_control: flow_control.clone(),
315				continue_recieved: continue_recieved.clone(),
316			},
317		}
318	}
319
320	/// Read an event from the stream.
321	pub async fn read(&self) -> Option<Bytes> {
322		self.rx.read().await
323	}
324
325	/// Write a payload to the stream.
326	pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
327		self.tx.write_payload(data).await
328	}
329
330	/// Write data to the stream.
331	pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
332		self.tx.write(data).await
333	}
334
335	/// Get a handle to close the connection.
336	///
337	/// Useful to close the connection without having access to the stream.
338	///
339	/// # Example
340	/// ```
341	/// let handle = stream.get_close_handle();
342	/// if let Err(error) = handle_stream(stream) {
343	///     handle.close(0x01);
344	/// }
345	/// ```
346	pub fn get_close_handle(&self) -> MuxStreamCloser {
347		self.tx.get_close_handle()
348	}
349
350	/// Get a protocol extension stream to send protocol extension packets.
351	pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
352		self.tx.get_protocol_extension_stream()
353	}
354
355	/// Close the stream. You will no longer be able to write or read after this has been called.
356	pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
357		self.tx.close(reason).await
358	}
359
360	/// Split the stream into read and write parts, consuming it.
361	pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) {
362		(self.rx, self.tx)
363	}
364
365	/// Turn the stream into one that implements futures `Stream + Sink`, consuming it.
366	pub fn into_io(self) -> MuxStreamIo {
367		MuxStreamIo {
368			rx: self.rx.into_stream(),
369			tx: self.tx.into_sink(),
370		}
371	}
372}
373
374/// Close handle for a multiplexor stream.
375#[derive(Clone)]
376pub struct MuxStreamCloser {
377	/// ID of the stream.
378	pub stream_id: u32,
379	close_channel: mpsc::Sender<WsEvent>,
380	is_closed: Arc<AtomicBool>,
381	close_reason: Arc<AtomicCloseReason>,
382}
383
384impl MuxStreamCloser {
385	/// Close the stream. You will no longer be able to write or read after this has been called.
386	pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
387		if self.is_closed.load(Ordering::Acquire) {
388			return Err(WispError::StreamAlreadyClosed);
389		}
390		self.is_closed.store(true, Ordering::Release);
391
392		let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
393		self.close_channel
394			.send_async(WsEvent::Close(
395				Packet::new_close(self.stream_id, reason),
396				tx,
397			))
398			.await
399			.map_err(|_| WispError::MuxMessageFailedToSend)?;
400		rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
401
402		Ok(())
403	}
404
405	/// Get the stream's close reason, if it was closed.
406	pub fn get_close_reason(&self) -> Option<CloseReason> {
407		if self.is_closed.load(Ordering::Acquire) {
408			Some(self.close_reason.load(Ordering::Acquire))
409		} else {
410			None
411		}
412	}
413}
414
415/// Stream for sending arbitrary protocol extension packets.
416pub struct MuxProtocolExtensionStream {
417	/// ID of the stream.
418	pub stream_id: u32,
419	pub(crate) tx: LockedWebSocketWrite,
420	pub(crate) is_closed: Arc<AtomicBool>,
421}
422
423impl MuxProtocolExtensionStream {
424	/// Send a protocol extension packet with this stream's ID.
425	pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
426		if self.is_closed.load(Ordering::Acquire) {
427			return Err(WispError::StreamAlreadyClosed);
428		}
429		let mut encoded = BytesMut::with_capacity(1 + 4 + data.len());
430		encoded.put_u8(packet_type);
431		encoded.put_u32_le(self.stream_id);
432		encoded.extend(data);
433		self.tx
434			.write_frame(Frame::binary(Payload::Bytes(encoded)))
435			.await
436	}
437}
438
439pin_project! {
440	/// Multiplexor stream that implements futures `Stream + Sink`.
441	pub struct MuxStreamIo {
442		#[pin]
443		rx: MuxStreamIoStream,
444		#[pin]
445		tx: MuxStreamIoSink,
446	}
447}
448
449impl MuxStreamIo {
450	/// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`.
451	pub fn into_asyncrw(self) -> MuxStreamAsyncRW {
452		MuxStreamAsyncRW {
453			rx: self.rx.into_asyncread(),
454			tx: self.tx.into_asyncwrite(),
455		}
456	}
457
458	/// Split the stream into read and write parts, consuming it.
459	pub fn into_split(self) -> (MuxStreamIoStream, MuxStreamIoSink) {
460		(self.rx, self.tx)
461	}
462}
463
464impl Stream for MuxStreamIo {
465	type Item = Result<Bytes, std::io::Error>;
466	fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
467		self.project().rx.poll_next(cx)
468	}
469}
470
471impl Sink<&[u8]> for MuxStreamIo {
472	type Error = std::io::Error;
473	fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
474		self.project().tx.poll_ready(cx)
475	}
476	fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
477		self.project().tx.start_send(item)
478	}
479	fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
480		self.project().tx.poll_flush(cx)
481	}
482	fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
483		self.project().tx.poll_close(cx)
484	}
485}
486
487pin_project! {
488	/// Read side of a multiplexor stream that implements futures `Stream`.
489	pub struct MuxStreamIoStream {
490		#[pin]
491		rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>,
492	}
493}
494
495impl MuxStreamIoStream {
496	/// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead`.
497	pub fn into_asyncread(self) -> MuxStreamAsyncRead {
498		MuxStreamAsyncRead::new(self)
499	}
500}
501
502impl Stream for MuxStreamIoStream {
503	type Item = Result<Bytes, std::io::Error>;
504	fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
505		self.project().rx.poll_next(cx).map(|x| x.map(Ok))
506	}
507}
508
509pin_project! {
510	/// Write side of a multiplexor stream that implements futures `Sink`.
511	pub struct MuxStreamIoSink {
512		#[pin]
513		tx: Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>>,
514	}
515}
516
517impl MuxStreamIoSink {
518	/// Turn the sink into one that implements futures `AsyncWrite`.
519	pub fn into_asyncwrite(self) -> MuxStreamAsyncWrite {
520		MuxStreamAsyncWrite::new(self)
521	}
522}
523
524impl Sink<&[u8]> for MuxStreamIoSink {
525	type Error = std::io::Error;
526	fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
527		self.project()
528			.tx
529			.poll_ready(cx)
530			.map_err(std::io::Error::other)
531	}
532	fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
533		self.project()
534			.tx
535			.start_send(Payload::Bytes(BytesMut::from(item)))
536			.map_err(std::io::Error::other)
537	}
538	fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
539		self.project()
540			.tx
541			.poll_flush(cx)
542			.map_err(std::io::Error::other)
543	}
544	fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
545		self.project()
546			.tx
547			.poll_close(cx)
548			.map_err(std::io::Error::other)
549	}
550}
551
552pin_project! {
553	/// Multiplexor stream that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`.
554	pub struct MuxStreamAsyncRW {
555		#[pin]
556		rx: MuxStreamAsyncRead,
557		#[pin]
558		tx: MuxStreamAsyncWrite,
559	}
560}
561
562impl MuxStreamAsyncRW {
563	/// Split the stream into read and write parts, consuming it.
564	pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) {
565		(self.rx, self.tx)
566	}
567}
568
569impl AsyncRead for MuxStreamAsyncRW {
570	fn poll_read(
571		self: Pin<&mut Self>,
572		cx: &mut Context<'_>,
573		buf: &mut [u8],
574	) -> Poll<std::io::Result<usize>> {
575		self.project().rx.poll_read(cx, buf)
576	}
577
578	fn poll_read_vectored(
579		self: Pin<&mut Self>,
580		cx: &mut Context<'_>,
581		bufs: &mut [std::io::IoSliceMut<'_>],
582	) -> Poll<std::io::Result<usize>> {
583		self.project().rx.poll_read_vectored(cx, bufs)
584	}
585}
586
587impl AsyncBufRead for MuxStreamAsyncRW {
588	fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
589		self.project().rx.poll_fill_buf(cx)
590	}
591
592	fn consume(self: Pin<&mut Self>, amt: usize) {
593		self.project().rx.consume(amt)
594	}
595}
596
597impl AsyncWrite for MuxStreamAsyncRW {
598	fn poll_write(
599		self: Pin<&mut Self>,
600		cx: &mut Context<'_>,
601		buf: &[u8],
602	) -> Poll<std::io::Result<usize>> {
603		self.project().tx.poll_write(cx, buf)
604	}
605
606	fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
607		self.project().tx.poll_flush(cx)
608	}
609
610	fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
611		self.project().tx.poll_close(cx)
612	}
613}
614
615pin_project! {
616	/// Read side of a multiplexor stream that implements futures `AsyncRead + AsyncBufRead`.
617	pub struct MuxStreamAsyncRead {
618		#[pin]
619		rx: IntoAsyncRead<MuxStreamIoStream>,
620		// state: Option<MuxStreamAsyncReadState>
621	}
622}
623
624impl MuxStreamAsyncRead {
625	pub(crate) fn new(stream: MuxStreamIoStream) -> Self {
626		Self {
627			rx: stream.into_async_read(),
628			// state: None,
629		}
630	}
631}
632
633impl AsyncRead for MuxStreamAsyncRead {
634	fn poll_read(
635		self: Pin<&mut Self>,
636		cx: &mut Context<'_>,
637		buf: &mut [u8],
638	) -> Poll<std::io::Result<usize>> {
639		self.project().rx.poll_read(cx, buf)
640	}
641}
642impl AsyncBufRead for MuxStreamAsyncRead {
643	fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
644		self.project().rx.poll_fill_buf(cx)
645	}
646	fn consume(self: Pin<&mut Self>, amt: usize) {
647		self.project().rx.consume(amt)
648	}
649}
650
651pin_project! {
652	/// Write side of a multiplexor stream that implements futures `AsyncWrite`.
653	pub struct MuxStreamAsyncWrite {
654		#[pin]
655		tx: MuxStreamIoSink,
656		error: Option<std::io::Error>
657	}
658}
659
660impl MuxStreamAsyncWrite {
661	pub(crate) fn new(sink: MuxStreamIoSink) -> Self {
662		Self {
663			tx: sink,
664			error: None,
665		}
666	}
667}
668
669impl AsyncWrite for MuxStreamAsyncWrite {
670	fn poll_write(
671		mut self: Pin<&mut Self>,
672		cx: &mut Context<'_>,
673		buf: &[u8],
674	) -> Poll<std::io::Result<usize>> {
675		if let Some(err) = self.error.take() {
676			return Poll::Ready(Err(err));
677		}
678
679		let mut this = self.as_mut().project();
680
681		ready!(this.tx.as_mut().poll_ready(cx))?;
682		match this.tx.as_mut().start_send(buf) {
683			Ok(()) => {
684				let mut cx = Context::from_waker(noop_waker_ref());
685				let cx = &mut cx;
686
687				match this.tx.poll_flush(cx) {
688					Poll::Ready(Err(err)) => {
689						self.error = Some(err);
690					}
691					Poll::Ready(Ok(_)) | Poll::Pending => {}
692				}
693
694				Poll::Ready(Ok(buf.len()))
695			}
696			Err(e) => Poll::Ready(Err(e)),
697		}
698	}
699
700	fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
701		self.project().tx.poll_flush(cx)
702	}
703
704	fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
705		self.project().tx.poll_close(cx)
706	}
707}