1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures_util::{Stream, TryStreamExt};
7use js_sys::{BigInt, Uint8Array};
8use pin_project::pin_project;
9use wasm_bindgen::{JsCast, JsValue};
10use wasm_streams::readable::IntoStream;
11use web_sys::ReadableStream;
12use worker_sys::FixedLengthStream as FixedLengthStreamSys;
13
14use crate::{Error, Result};
15
/// A Rust [`Stream`] of byte chunks read from a JS `ReadableStream`.
///
/// Each successful chunk is converted to a `Uint8Array` and yielded as an owned `Vec<u8>`.
#[pin_project]
#[derive(Debug)]
pub struct ByteStream {
    /// The JS stream, adapted into a Rust stream by `wasm_streams`.
    #[pin]
    pub(crate) inner: IntoStream<'static>,
}
22
23impl From<ReadableStream> for ByteStream {
24 fn from(stream: ReadableStream) -> Self {
25 Self {
26 inner: wasm_streams::ReadableStream::from_raw(stream.unchecked_into()).into_stream(),
27 }
28 }
29}
30
31impl Stream for ByteStream {
32 type Item = Result<Vec<u8>>;
33
34 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
35 let this = self.project();
36 let item = match futures_util::ready!(this.inner.poll_next(cx)) {
37 Some(res) => res.map(Uint8Array::from).map_err(Error::from),
38 None => return Poll::Ready(None),
39 };
40
41 Poll::Ready(match item {
42 Ok(value) => Some(Ok(value.to_vec())),
43 Err(e) if e.to_string() == "Error: aborted" => None,
44 Err(e) => Some(Err(e)),
45 })
46 }
47}
48
/// A stream wrapper that checks the wrapped stream produces exactly `length` bytes.
///
/// Chunks are forwarded unchanged. A length-mismatch error is yielded instead if the running
/// total exceeds `length`, or if the wrapped stream ends with a total different from `length`.
#[pin_project]
pub struct FixedLengthStream {
    /// Total number of bytes the wrapped stream is expected to produce.
    length: u64,
    /// Running count of bytes received from the wrapped stream (successful chunks only).
    // NOTE(review): `#[pin]` is unnecessary here. `u64` is `Unpin`, so structural pinning adds
    // nothing; dropping it would make the projected field a plain `&mut u64`.
    #[pin]
    bytes_read: u64,
    /// The wrapped source stream, boxed as a trait object.
    #[pin]
    inner: Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + 'static>>,
}
57
58impl core::fmt::Debug for FixedLengthStream {
59 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
60 f.debug_struct("FixedLengthStream")
61 .field("length", &self.length)
62 .field("bytes_read", &self.bytes_read)
63 .finish()
64 }
65}
66
67impl FixedLengthStream {
68 pub fn wrap(stream: impl Stream<Item = Result<Vec<u8>>> + 'static, length: u64) -> Self {
69 Self {
70 length,
71 bytes_read: 0,
72 inner: Box::pin(stream),
73 }
74 }
75}
76
77impl Stream for FixedLengthStream {
78 type Item = Result<Vec<u8>>;
79
80 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 let mut this = self.project();
82 let item = if let Some(res) = futures_util::ready!(this.inner.poll_next(cx)) {
83 let chunk = match res {
84 Ok(chunk) => chunk,
85 Err(err) => return Poll::Ready(Some(Err(err))),
86 };
87
88 *this.bytes_read += chunk.len() as u64;
89
90 if *this.bytes_read > *this.length {
91 let err = Error::from(format!(
92 "fixed length stream had different length than expected (expected {}, got {})",
93 *this.length, *this.bytes_read,
94 ));
95 Some(Err(err))
96 } else {
97 Some(Ok(chunk))
98 }
99 } else if *this.bytes_read != *this.length {
100 let err = Error::from(format!(
101 "fixed length stream had different length than expected (expected {}, got {})",
102 *this.length, *this.bytes_read,
103 ));
104 Some(Err(err))
105 } else {
106 None
107 };
108
109 Poll::Ready(item)
110 }
111}
112
113impl From<FixedLengthStream> for FixedLengthStreamSys {
114 fn from(stream: FixedLengthStream) -> Self {
115 let raw = if stream.length < u32::MAX as u64 {
116 FixedLengthStreamSys::new(stream.length as u32).unwrap()
117 } else {
118 FixedLengthStreamSys::new_big_int(BigInt::from(stream.length)).unwrap()
119 };
120
121 let js_stream = stream
122 .map_ok(|item| -> Vec<u8> { item })
123 .map_ok(|chunk| {
124 let array = Uint8Array::new_with_length(chunk.len() as _);
125 array.copy_from(&chunk);
126
127 array.into()
128 })
129 .map_err(JsValue::from);
130
131 let stream: ReadableStream = wasm_streams::ReadableStream::from_stream(js_stream)
132 .as_raw()
133 .clone()
134 .unchecked_into();
135 let _ = stream.pipe_to(&raw.writable());
136
137 raw
138 }
139}