1use std::cmp;
2use std::future::Future;
3use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use bytes::{BufMut, Bytes};
8use futures_io::AsyncRead;
9use js_sys::Uint8Array;
10use web_sys::WebTransportReceiveStream;
11
12use crate::Error;
13use web_streams::Reader;
14
15type ReadFuture = Pin<Box<dyn Future<Output = (Reader<Uint8Array>, io::Result<Option<Bytes>>)>>>;
16
17enum ReadState {
18 Idle,
19 Reading(ReadFuture),
20}
21
22pub struct RecvStream {
26 reader: Option<Reader<Uint8Array>>,
27 buffer: Bytes,
28 read_state: ReadState,
29}
30
31impl RecvStream {
32 pub(super) fn new(stream: WebTransportReceiveStream) -> Result<Self, Error> {
33 let reader = Reader::new(&stream)?;
34
35 Ok(Self {
36 reader: Some(reader),
37 buffer: Bytes::new(),
38 read_state: ReadState::Idle,
39 })
40 }
41
42 pub async fn read(&mut self, max: usize) -> Result<Option<Bytes>, Error> {
46 if !self.buffer.is_empty() {
47 let size = cmp::min(max, self.buffer.len());
48 let data = self.buffer.split_to(size);
49 return Ok(Some(data));
50 }
51
52 let reader = self
53 .reader
54 .as_mut()
55 .ok_or_else(|| Error::Unknown("reader is unavailable".into()))?;
56
57 let mut data: Bytes = match reader.read().await? {
58 Some(data) => Bytes::from(data.to_vec()),
59 None => return Ok(None),
60 };
61
62 if data.len() > max {
63 self.buffer = data.split_off(max);
65 }
66
67 Ok(Some(data))
68 }
69
70 pub async fn read_buf<B: BufMut>(&mut self, buf: &mut B) -> Result<Option<usize>, Error> {
75 let chunk = match self.read(buf.remaining_mut()).await? {
76 Some(chunk) => chunk,
77 None => return Ok(None),
78 };
79
80 let size = chunk.len();
81 buf.put(chunk);
82
83 Ok(Some(size))
84 }
85
86 pub fn stop(&mut self, reason: &str) {
88 if let Some(reader) = self.reader.as_mut() {
89 reader.abort(reason);
90 }
91 }
92
93 pub async fn closed(&self) -> Result<Option<u8>, Error> {
95 let reader = match self.reader.as_ref() {
96 Some(reader) => reader,
97 None => return Err(Error::Unknown("reader is unavailable".into())),
98 };
99
100 let err = match reader.closed().await {
101 Ok(()) => return Ok(None),
102 Err(err) => Error::from(err),
103 };
104
105 if let Error::Stream(err) = &err {
107 if let Some(code) = err.stream_error_code() {
108 return Ok(Some(code));
109 }
110 }
111
112 Err(err)
113 }
114}
115
116impl Drop for RecvStream {
117 fn drop(&mut self) {
118 if let Some(reader) = self.reader.as_mut() {
119 reader.abort("dropped");
120 }
121 }
122}
123
124impl RecvStream {
125 fn poll_inflight_read(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<Bytes>>> {
126 match &mut self.read_state {
127 ReadState::Idle => Poll::Ready(Ok(None)),
128 ReadState::Reading(fut) => match fut.as_mut().poll(cx) {
129 Poll::Pending => Poll::Pending,
130 Poll::Ready((reader, result)) => {
131 self.reader = Some(reader);
132 self.read_state = ReadState::Idle;
133 Poll::Ready(result)
134 }
135 },
136 }
137 }
138
139 fn error_unavailable() -> io::Error {
140 io::Error::new(io::ErrorKind::Other, "reader is unavailable")
141 }
142
143 fn to_io_error(error: Error) -> io::Error {
144 io::Error::new(io::ErrorKind::Other, error.to_string())
145 }
146}
147
148impl AsyncRead for RecvStream {
149 fn poll_read(
150 mut self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 buf: &mut [u8],
153 ) -> Poll<io::Result<usize>> {
154 if buf.is_empty() {
155 return Poll::Ready(Ok(0));
156 }
157
158 loop {
159 if !self.buffer.is_empty() {
160 let size = cmp::min(buf.len(), self.buffer.len());
161 buf[..size].copy_from_slice(&self.buffer.split_to(size));
162 return Poll::Ready(Ok(size));
163 }
164
165 if matches!(self.read_state, ReadState::Idle) {
166 let mut reader = match self.reader.take() {
167 Some(reader) => reader,
168 None => return Poll::Ready(Err(Self::error_unavailable())),
169 };
170
171 let fut = Box::pin(async move {
172 let result = reader
173 .read()
174 .await
175 .map(|data| data.map(|value| Bytes::from(value.to_vec())))
176 .map_err(|err| Self::to_io_error(err.into()));
177 (reader, result)
178 });
179 self.read_state = ReadState::Reading(fut);
180 }
181
182 match self.poll_inflight_read(cx) {
183 Poll::Pending => return Poll::Pending,
184 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
185 Poll::Ready(Ok(None)) => return Poll::Ready(Ok(0)),
186 Poll::Ready(Ok(Some(chunk))) => {
187 self.buffer = chunk;
188 }
189 }
190 }
191 }
192}
193
#[cfg(target_family = "wasm")]
impl webtrans_trait::RecvStream for RecvStream {
    type Error = Error;

    /// Copies up to `dst.len()` bytes into `dst`.
    ///
    /// Returns how many bytes were written, or `None` at end of stream.
    async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
        let Some(chunk) = Self::read(self, dst.len()).await? else {
            return Ok(None);
        };

        let written = chunk.len();
        dst[..written].copy_from_slice(&chunk);
        Ok(Some(written))
    }

    /// Stops the stream, passing the numeric `code` as the abort reason text.
    fn stop(&mut self, code: u32) {
        let reason = code.to_string();
        Self::stop(self, &reason);
    }

    /// Resolves `Ok(())` on a clean close; a close carrying an error code becomes an error.
    async fn closed(&mut self) -> Result<(), Self::Error> {
        if let Some(code) = Self::closed(self).await? {
            return Err(Error::Unknown(
                format!("stream closed with code {code}").into(),
            ));
        }
        Ok(())
    }
}