Skip to main content

webtrans_wasm/
recv.rs

1use std::cmp;
2use std::future::Future;
3use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use bytes::{BufMut, Bytes};
8use futures_io::AsyncRead;
9use js_sys::Uint8Array;
10use web_sys::WebTransportReceiveStream;
11
12use crate::Error;
13use web_streams::Reader;
14
15type ReadFuture = Pin<Box<dyn Future<Output = (Reader<Uint8Array>, io::Result<Option<Bytes>>)>>>;
16
17enum ReadState {
18    Idle,
19    Reading(ReadFuture),
20}
21
22/// A byte stream received from the remote peer.
23///
24/// Either side may close with an error code, or the peer may close with a FIN.
25pub struct RecvStream {
26    reader: Option<Reader<Uint8Array>>,
27    buffer: Bytes,
28    read_state: ReadState,
29}
30
31impl RecvStream {
32    pub(super) fn new(stream: WebTransportReceiveStream) -> Result<Self, Error> {
33        let reader = Reader::new(&stream)?;
34
35        Ok(Self {
36            reader: Some(reader),
37            buffer: Bytes::new(),
38            read_state: ReadState::Idle,
39        })
40    }
41
42    /// Read the next chunk of data with the provided maximum size.
43    ///
44    /// This returns a chunk of data instead of copying, which can be more efficient.
45    pub async fn read(&mut self, max: usize) -> Result<Option<Bytes>, Error> {
46        if !self.buffer.is_empty() {
47            let size = cmp::min(max, self.buffer.len());
48            let data = self.buffer.split_to(size);
49            return Ok(Some(data));
50        }
51
52        let reader = self
53            .reader
54            .as_mut()
55            .ok_or_else(|| Error::Unknown("reader is unavailable".into()))?;
56
57        let mut data: Bytes = match reader.read().await? {
58            Some(data) => Bytes::from(data.to_vec()),
59            None => return Ok(None),
60        };
61
62        if data.len() > max {
63            // The chunk is too large; buffer the remainder for the next read.
64            self.buffer = data.split_off(max);
65        }
66
67        Ok(Some(data))
68    }
69
70    /// Read some data into the provided buffer.
71    ///
72    /// Returns the (non-zero) number of bytes read, or `None` if the stream is closed.
73    /// Advances the buffer by the number of bytes read.
74    pub async fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Result<Option<usize>, Error> {
75        let chunk = match self.read(buf.remaining_mut()).await? {
76            Some(chunk) => chunk,
77            None => return Ok(None),
78        };
79
80        let size = chunk.len();
81        buf.put(chunk);
82
83        Ok(Some(size))
84    }
85
86    /// Abort reading from the stream with the given reason.
87    pub fn stop(&mut self, reason: &str) {
88        if let Some(reader) = self.reader.as_mut() {
89            reader.abort(reason);
90        }
91    }
92
93    /// Block until the stream has closed and return the error code, if any.
94    pub async fn closed(&self) -> Result<Option<u8>, Error> {
95        let reader = match self.reader.as_ref() {
96            Some(reader) => reader,
97            None => return Err(Error::Unknown("reader is unavailable".into())),
98        };
99
100        let err = match reader.closed().await {
101            Ok(()) => return Ok(None),
102            Err(err) => Error::from(err),
103        };
104
105        // If this is a WebTransportError, extract the error code when available.
106        if let Error::Stream(err) = &err {
107            if let Some(code) = err.stream_error_code() {
108                return Ok(Some(code));
109            }
110        }
111
112        Err(err)
113    }
114}
115
116impl Drop for RecvStream {
117    fn drop(&mut self) {
118        if let Some(reader) = self.reader.as_mut() {
119            reader.abort("dropped");
120        }
121    }
122}
123
124impl RecvStream {
125    fn poll_inflight_read(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<Bytes>>> {
126        match &mut self.read_state {
127            ReadState::Idle => Poll::Ready(Ok(None)),
128            ReadState::Reading(fut) => match fut.as_mut().poll(cx) {
129                Poll::Pending => Poll::Pending,
130                Poll::Ready((reader, result)) => {
131                    self.reader = Some(reader);
132                    self.read_state = ReadState::Idle;
133                    Poll::Ready(result)
134                }
135            },
136        }
137    }
138
139    fn error_unavailable() -> io::Error {
140        io::Error::new(io::ErrorKind::Other, "reader is unavailable")
141    }
142
143    fn to_io_error(error: Error) -> io::Error {
144        io::Error::new(io::ErrorKind::Other, error.to_string())
145    }
146}
147
148impl AsyncRead for RecvStream {
149    fn poll_read(
150        mut self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        buf: &mut [u8],
153    ) -> Poll<io::Result<usize>> {
154        if buf.is_empty() {
155            return Poll::Ready(Ok(0));
156        }
157
158        loop {
159            if !self.buffer.is_empty() {
160                let size = cmp::min(buf.len(), self.buffer.len());
161                buf[..size].copy_from_slice(&self.buffer.split_to(size));
162                return Poll::Ready(Ok(size));
163            }
164
165            if matches!(self.read_state, ReadState::Idle) {
166                let mut reader = match self.reader.take() {
167                    Some(reader) => reader,
168                    None => return Poll::Ready(Err(Self::error_unavailable())),
169                };
170
171                let fut = Box::pin(async move {
172                    let result = reader
173                        .read()
174                        .await
175                        .map(|data| data.map(|value| Bytes::from(value.to_vec())))
176                        .map_err(|err| Self::to_io_error(err.into()));
177                    (reader, result)
178                });
179                self.read_state = ReadState::Reading(fut);
180            }
181
182            match self.poll_inflight_read(cx) {
183                Poll::Pending => return Poll::Pending,
184                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
185                Poll::Ready(Ok(None)) => return Poll::Ready(Ok(0)),
186                Poll::Ready(Ok(Some(chunk))) => {
187                    self.buffer = chunk;
188                }
189            }
190        }
191    }
192}
193
#[cfg(target_family = "wasm")]
impl webtrans_trait::RecvStream for RecvStream {
    type Error = Error;

    /// Copy the next chunk, at most `dst.len()` bytes, into the front of `dst`.
    ///
    /// Returns the number of bytes written, or `None` once the inherent `read`
    /// reports that the stream has ended.
    async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
        let Some(chunk) = Self::read(self, dst.len()).await? else {
            return Ok(None);
        };

        let written = chunk.len();
        dst[..written].copy_from_slice(&chunk);

        Ok(Some(written))
    }

    /// Abort reading, passing the decimal text of `code` as the abort reason.
    fn stop(&mut self, code: u32) {
        let reason = code.to_string();
        Self::stop(self, &reason);
    }

    /// Wait for the stream to close.
    ///
    /// A close that carries an error code is reported as `Error::Unknown` with the
    /// code in its message; a clean close yields `Ok(())`.
    async fn closed(&mut self) -> Result<(), Self::Error> {
        if let Some(code) = Self::closed(self).await? {
            return Err(Error::Unknown(
                format!("stream closed with code {code}").into(),
            ));
        }

        Ok(())
    }
}