// worker/streams.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use futures_util::{Stream, TryStreamExt};
7use js_sys::{BigInt, Uint8Array};
8use pin_project::pin_project;
9use wasm_bindgen::{JsCast, JsValue};
10use wasm_streams::readable::IntoStream;
11use web_sys::ReadableStream;
12use worker_sys::FixedLengthStream as FixedLengthStreamSys;
13
14use crate::{Error, Result};
15
16#[pin_project]
17#[derive(Debug)]
18pub struct ByteStream {
19    #[pin]
20    pub(crate) inner: IntoStream<'static>,
21}
22
23impl From<ReadableStream> for ByteStream {
24    fn from(stream: ReadableStream) -> Self {
25        Self {
26            inner: wasm_streams::ReadableStream::from_raw(stream.unchecked_into()).into_stream(),
27        }
28    }
29}
30
31impl Stream for ByteStream {
32    type Item = Result<Vec<u8>>;
33
34    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
35        let this = self.project();
36        let item = match futures_util::ready!(this.inner.poll_next(cx)) {
37            Some(res) => res.map(Uint8Array::from).map_err(Error::from),
38            None => return Poll::Ready(None),
39        };
40
41        Poll::Ready(match item {
42            Ok(value) => Some(Ok(value.to_vec())),
43            Err(e) if e.to_string() == "Error: aborted" => None,
44            Err(e) => Some(Err(e)),
45        })
46    }
47}
48
49#[pin_project]
50pub struct FixedLengthStream {
51    length: u64,
52    #[pin]
53    bytes_read: u64,
54    #[pin]
55    inner: Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + 'static>>,
56}
57
58impl core::fmt::Debug for FixedLengthStream {
59    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
60        f.debug_struct("FixedLengthStream")
61            .field("length", &self.length)
62            .field("bytes_read", &self.bytes_read)
63            .finish()
64    }
65}
66
67impl FixedLengthStream {
68    pub fn wrap(stream: impl Stream<Item = Result<Vec<u8>>> + 'static, length: u64) -> Self {
69        Self {
70            length,
71            bytes_read: 0,
72            inner: Box::pin(stream),
73        }
74    }
75}
76
77impl Stream for FixedLengthStream {
78    type Item = Result<Vec<u8>>;
79
80    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81        let mut this = self.project();
82        let item = if let Some(res) = futures_util::ready!(this.inner.poll_next(cx)) {
83            let chunk = match res {
84                Ok(chunk) => chunk,
85                Err(err) => return Poll::Ready(Some(Err(err))),
86            };
87
88            *this.bytes_read += chunk.len() as u64;
89
90            if *this.bytes_read > *this.length {
91                let err = Error::from(format!(
92                    "fixed length stream had different length than expected (expected {}, got {})",
93                    *this.length, *this.bytes_read,
94                ));
95                Some(Err(err))
96            } else {
97                Some(Ok(chunk))
98            }
99        } else if *this.bytes_read != *this.length {
100            let err = Error::from(format!(
101                "fixed length stream had different length than expected (expected {}, got {})",
102                *this.length, *this.bytes_read,
103            ));
104            Some(Err(err))
105        } else {
106            None
107        };
108
109        Poll::Ready(item)
110    }
111}
112
113impl From<FixedLengthStream> for FixedLengthStreamSys {
114    fn from(stream: FixedLengthStream) -> Self {
115        let raw = if stream.length < u32::MAX as u64 {
116            FixedLengthStreamSys::new(stream.length as u32).unwrap()
117        } else {
118            FixedLengthStreamSys::new_big_int(BigInt::from(stream.length)).unwrap()
119        };
120
121        let js_stream = stream
122            .map_ok(|item| -> Vec<u8> { item })
123            .map_ok(|chunk| {
124                let array = Uint8Array::new_with_length(chunk.len() as _);
125                array.copy_from(&chunk);
126
127                array.into()
128            })
129            .map_err(JsValue::from);
130
131        let stream: ReadableStream = wasm_streams::ReadableStream::from_stream(js_stream)
132            .as_raw()
133            .clone()
134            .unchecked_into();
135        let _ = stream.pipe_to(&raw.writable());
136
137        raw
138    }
139}