Skip to main content

wasmtime_wasi_http/
body.rs

1//! Implementation of the `wasi:http/types` interface's various body types.
2
3use crate::bindings::http::types;
4use crate::types::FieldMap;
5use bytes::Bytes;
6use http_body::{Body, Frame};
7use http_body_util::BodyExt;
8use http_body_util::combinators::UnsyncBoxBody;
9use std::future::Future;
10use std::mem;
11use std::task::{Context, Poll};
12use std::{pin::Pin, sync::Arc, time::Duration};
13use tokio::sync::{mpsc, oneshot};
14use wasmtime::format_err;
15use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
16use wasmtime_wasi::runtime::{AbortOnDropJoinHandle, poll_noop};
17
18/// Common type for incoming bodies.
19pub type HyperIncomingBody = UnsyncBoxBody<Bytes, types::ErrorCode>;
20
21/// Common type for outgoing bodies.
22pub type HyperOutgoingBody = UnsyncBoxBody<Bytes, types::ErrorCode>;
23
24/// The concrete type behind a `was:http/types.incoming-body` resource.
25#[derive(Debug)]
26pub struct HostIncomingBody {
27    body: IncomingBodyState,
28    field_size_limit: usize,
29    /// An optional worker task to keep alive while this body is being read.
30    /// This ensures that if the parent of this body is dropped before the body
31    /// then the backing data behind this worker is kept alive.
32    worker: Option<AbortOnDropJoinHandle<()>>,
33}
34
35impl HostIncomingBody {
36    /// Create a new `HostIncomingBody` with the given `body` and a per-frame timeout
37    pub fn new(
38        body: HyperIncomingBody,
39        between_bytes_timeout: Duration,
40        field_size_limit: usize,
41    ) -> HostIncomingBody {
42        let body = BodyWithTimeout::new(body, between_bytes_timeout);
43        HostIncomingBody {
44            body: IncomingBodyState::Start(body),
45            field_size_limit,
46            worker: None,
47        }
48    }
49
50    /// Retain a worker task that needs to be kept alive while this body is being read.
51    pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
52        assert!(self.worker.is_none());
53        self.worker = Some(worker);
54    }
55
56    /// Try taking the stream of this body, if it's available.
57    pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
58        match &mut self.body {
59            IncomingBodyState::Start(_) => {}
60            IncomingBodyState::InBodyStream(_) => return None,
61        }
62        let (tx, rx) = oneshot::channel();
63        let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
64            IncomingBodyState::Start(b) => b,
65            IncomingBodyState::InBodyStream(_) => unreachable!(),
66        };
67        Some(HostIncomingBodyStream {
68            state: IncomingBodyStreamState::Open { body, tx },
69            buffer: Bytes::new(),
70            error: None,
71        })
72    }
73
74    /// Convert this body into a `HostFutureTrailers` resource.
75    pub fn into_future_trailers(self) -> HostFutureTrailers {
76        HostFutureTrailers::Waiting(self)
77    }
78}
79
80/// Internal state of a [`HostIncomingBody`].
81#[derive(Debug)]
82enum IncomingBodyState {
83    /// The body is stored here meaning that within `HostIncomingBody` the
84    /// `take_stream` method can be called for example.
85    Start(BodyWithTimeout),
86
87    /// The body is within a `HostIncomingBodyStream` meaning that it's not
88    /// currently owned here. The body will be sent back over this channel when
89    /// it's done, however.
90    InBodyStream(oneshot::Receiver<StreamEnd>),
91}
92
93/// Small wrapper around [`HyperIncomingBody`] which adds a timeout to every frame.
94#[derive(Debug)]
95struct BodyWithTimeout {
96    /// Underlying stream that frames are coming from.
97    inner: HyperIncomingBody,
98    /// Currently active timeout that's reset between frames.
99    timeout: Pin<Box<tokio::time::Sleep>>,
100    /// Whether or not `timeout` needs to be reset on the next call to
101    /// `poll_frame`.
102    reset_sleep: bool,
103    /// Maximal duration between when a frame is first requested and when it's
104    /// allowed to arrive.
105    between_bytes_timeout: Duration,
106}
107
108impl BodyWithTimeout {
109    fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
110        BodyWithTimeout {
111            inner,
112            between_bytes_timeout,
113            reset_sleep: true,
114            timeout: Box::pin(wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
115                tokio::time::sleep(Duration::new(0, 0))
116            })),
117        }
118    }
119}
120
121impl Body for BodyWithTimeout {
122    type Data = Bytes;
123    type Error = types::ErrorCode;
124
125    fn poll_frame(
126        self: Pin<&mut Self>,
127        cx: &mut Context<'_>,
128    ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
129        let me = Pin::into_inner(self);
130
131        // If the timeout timer needs to be reset, do that now relative to the
132        // current instant. Otherwise test the timeout timer and see if it's
133        // fired yet and if so we've timed out and return an error.
134        if me.reset_sleep {
135            me.timeout
136                .as_mut()
137                .reset(tokio::time::Instant::now() + me.between_bytes_timeout);
138            me.reset_sleep = false;
139        }
140
141        // Register interest in this context on the sleep timer, and if the
142        // sleep elapsed that means that we've timed out.
143        if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) {
144            return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
145        }
146
147        // Without timeout business now handled check for the frame. If a frame
148        // arrives then the sleep timer will be reset on the next frame.
149        let result = Pin::new(&mut me.inner).poll_frame(cx);
150        me.reset_sleep = result.is_ready();
151        result
152    }
153}
154
155/// Message sent when a `HostIncomingBodyStream` is done to the
156/// `HostFutureTrailers` state.
157#[derive(Debug)]
158enum StreamEnd {
159    /// The body wasn't completely read and was dropped early. May still have
160    /// trailers, but requires reading more frames.
161    Remaining(BodyWithTimeout),
162
163    /// Body was completely read and trailers were read. Here are the trailers.
164    /// Note that `None` means that the body finished without trailers.
165    Trailers(Option<http::HeaderMap>),
166}
167
168/// The concrete type behind the `wasi:io/streams.input-stream` resource returned
169/// by `wasi:http/types.incoming-body`'s `stream` method.
170#[derive(Debug)]
171pub struct HostIncomingBodyStream {
172    state: IncomingBodyStreamState,
173    buffer: Bytes,
174    error: Option<wasmtime::Error>,
175}
176
177impl HostIncomingBodyStream {
178    fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
179        match frame {
180            Some(Ok(frame)) => match frame.into_data() {
181                // A data frame was received, so queue up the buffered data for
182                // the next `read` call.
183                Ok(bytes) => {
184                    assert!(self.buffer.is_empty());
185                    self.buffer = bytes;
186                }
187
188                // Trailers were received meaning that this was the final frame.
189                // Throw away the body and send the trailers along the
190                // `tx` channel to make them available.
191                Err(trailers) => {
192                    let trailers = trailers.into_trailers().unwrap();
193                    let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
194                        IncomingBodyStreamState::Open { body: _, tx } => tx,
195                        IncomingBodyStreamState::Closed => unreachable!(),
196                    };
197
198                    // NB: ignore send failures here because if this fails then
199                    // no one was interested in the trailers.
200                    let _ = tx.send(StreamEnd::Trailers(Some(trailers)));
201                }
202            },
203
204            // An error was received meaning that the stream is now done.
205            // Destroy the body to terminate the stream while enqueueing the
206            // error to get returned from the next call to `read`.
207            Some(Err(e)) => {
208                self.error = Some(e.into());
209                self.state = IncomingBodyStreamState::Closed;
210            }
211
212            // No more frames are going to be received again, so drop the `body`
213            // and the `tx` channel we'd send the body back onto because it's
214            // not needed as frames are done.
215            None => {
216                self.state = IncomingBodyStreamState::Closed;
217            }
218        }
219    }
220}
221
222#[derive(Debug)]
223enum IncomingBodyStreamState {
224    /// The body is currently open for reading and present here.
225    ///
226    /// When trailers are read, or when this is dropped, the body is sent along
227    /// `tx`.
228    ///
229    /// This state is transitioned to `Closed` when an error happens, EOF
230    /// happens, or when trailers are read.
231    Open {
232        body: BodyWithTimeout,
233        tx: oneshot::Sender<StreamEnd>,
234    },
235
236    /// This body is closed and no longer available for reading, no more data
237    /// will come.
238    Closed,
239}
240
241#[async_trait::async_trait]
242impl InputStream for HostIncomingBodyStream {
243    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
244        loop {
245            // Handle buffered data/errors if any
246            if !self.buffer.is_empty() {
247                let len = size.min(self.buffer.len());
248                let chunk = self.buffer.split_to(len);
249                return Ok(chunk);
250            }
251
252            if let Some(e) = self.error.take() {
253                return Err(StreamError::LastOperationFailed(e));
254            }
255
256            // Extract the body that we're reading from. If present perform a
257            // non-blocking poll to see if a frame is already here. If it is
258            // then turn the loop again to operate on the results. If it's not
259            // here then return an empty buffer as no data is available at this
260            // time.
261            let body = match &mut self.state {
262                IncomingBodyStreamState::Open { body, .. } => body,
263                IncomingBodyStreamState::Closed => return Err(StreamError::Closed),
264            };
265
266            let future = body.frame();
267            futures::pin_mut!(future);
268            match poll_noop(future) {
269                Some(result) => {
270                    self.record_frame(result);
271                }
272                None => return Ok(Bytes::new()),
273            }
274        }
275    }
276}
277
278#[async_trait::async_trait]
279impl Pollable for HostIncomingBodyStream {
280    async fn ready(&mut self) {
281        if !self.buffer.is_empty() || self.error.is_some() {
282            return;
283        }
284
285        if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
286            let frame = body.frame().await;
287            self.record_frame(frame);
288        }
289    }
290}
291
292impl Drop for HostIncomingBodyStream {
293    fn drop(&mut self) {
294        // When a body stream is dropped, for whatever reason, attempt to send
295        // the body back to the `tx` which will provide the trailers if desired.
296        // This isn't necessary if the state is already closed. Additionally,
297        // like `record_frame` above, `send` errors are ignored as they indicate
298        // that the body/trailers aren't actually needed.
299        let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed);
300        if let IncomingBodyStreamState::Open { body, tx } = prev {
301            let _ = tx.send(StreamEnd::Remaining(body));
302        }
303    }
304}
305
306/// The concrete type behind a `wasi:http/types.future-trailers` resource.
307#[derive(Debug)]
308pub enum HostFutureTrailers {
309    /// Trailers aren't here yet.
310    ///
311    /// This state represents two similar states:
312    ///
313    /// * The body is here and ready for reading and we're waiting to read
314    ///   trailers. This can happen for example when the actual body wasn't read
315    ///   or if the body was only partially read.
316    ///
317    /// * The body is being read by something else and we're waiting for that to
318    ///   send us the trailers (or the body itself). This state will get entered
319    ///   when the body stream is dropped for example. If the body stream reads
320    ///   the trailers itself it will also send a message over here with the
321    ///   trailers.
322    Waiting(HostIncomingBody),
323
324    /// Trailers are ready and here they are.
325    ///
326    /// Note that `Ok(None)` means that there were no trailers for this request
327    /// while `Ok(Some(_))` means that trailers were found in the request.
328    Done(Result<Option<FieldMap>, types::ErrorCode>),
329
330    /// Trailers have been consumed by `future-trailers.get`.
331    Consumed,
332}
333
334#[async_trait::async_trait]
335impl Pollable for HostFutureTrailers {
336    async fn ready(&mut self) {
337        let body = match self {
338            HostFutureTrailers::Waiting(body) => body,
339            HostFutureTrailers::Done(_) => return,
340            HostFutureTrailers::Consumed => return,
341        };
342
343        // If the body is itself being read by a body stream then we need to
344        // wait for that to be done.
345        if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
346            match rx.await {
347                // Trailers were read for us and here they are, so store the
348                // result.
349                Ok(StreamEnd::Trailers(Some(t))) => {
350                    *self = Self::Done(Ok(Some(FieldMap::new(t, body.field_size_limit))));
351                }
352                // The body wasn't fully read and was dropped before trailers
353                // were reached. It's up to us now to complete the body.
354                Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),
355
356                // This means there were no trailers present.
357                Ok(StreamEnd::Trailers(None)) | Err(_) => {
358                    *self = HostFutureTrailers::Done(Ok(None));
359                }
360            }
361        }
362
363        // Here it should be guaranteed that `InBodyStream` is now gone, so if
364        // we have the body ourselves then read frames until trailers are found.
365        let body = match self {
366            HostFutureTrailers::Waiting(body) => body,
367            HostFutureTrailers::Done(_) => return,
368            HostFutureTrailers::Consumed => return,
369        };
370        let hyper_body = match &mut body.body {
371            IncomingBodyState::Start(body) => body,
372            IncomingBodyState::InBodyStream(_) => unreachable!(),
373        };
374        let result = loop {
375            match hyper_body.frame().await {
376                None => break Ok(None),
377                Some(Err(e)) => break Err(e),
378                Some(Ok(frame)) => {
379                    // If this frame is a data frame ignore it as we're only
380                    // interested in trailers.
381                    if let Ok(header_map) = frame.into_trailers() {
382                        break Ok(Some(FieldMap::new(header_map, body.field_size_limit)));
383                    }
384                }
385            }
386        };
387        *self = HostFutureTrailers::Done(result);
388    }
389}
390
/// Byte counter for an outgoing body whose size was declared up front.
///
/// The counter lives behind an `Arc`, so every clone shares it. One clone can
/// record writes while another reads the running total.
#[derive(Debug, Clone)]
struct WrittenState {
    /// Number of bytes the body was declared to contain.
    expected: u64,
    /// Running total of bytes recorded through `update`.
    written: Arc<std::sync::atomic::AtomicU64>,
}

impl WrittenState {
    /// Start a counter at zero for a body declared to be `expected_size` bytes.
    fn new(expected_size: u64) -> Self {
        let counter = std::sync::atomic::AtomicU64::new(0);
        WrittenState {
            expected: expected_size,
            written: Arc::new(counter),
        }
    }

    /// Total number of bytes recorded so far.
    fn written(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.written.load(Ordering::Relaxed)
    }

    /// Record `len` more bytes.
    ///
    /// Returns `true` while the new total is still within the declared size,
    /// and `false` once it exceeds it. The bytes are counted either way.
    fn update(&self, len: usize) -> bool {
        use std::sync::atomic::Ordering;
        let added = len as u64;
        let before = self.written.fetch_add(added, Ordering::Relaxed);
        let after = before + added;
        after <= self.expected
    }
}
420
421/// The concrete type behind a `wasi:http/types.outgoing-body` resource.
422pub struct HostOutgoingBody {
423    /// The output stream that the body is written to.
424    body_output_stream: Option<Box<dyn OutputStream>>,
425    context: StreamContext,
426    written: Option<WrittenState>,
427    finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
428}
429
430impl HostOutgoingBody {
431    /// Create a new `HostOutgoingBody`
432    pub fn new(
433        context: StreamContext,
434        size: Option<u64>,
435        buffer_chunks: usize,
436        chunk_size: usize,
437    ) -> (Self, HyperOutgoingBody) {
438        assert!(buffer_chunks >= 1);
439
440        let written = size.map(WrittenState::new);
441
442        use tokio::sync::oneshot::error::RecvError;
443        struct BodyImpl {
444            body_receiver: mpsc::Receiver<Bytes>,
445            finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
446        }
447        impl Body for BodyImpl {
448            type Data = Bytes;
449            type Error = types::ErrorCode;
450            fn poll_frame(
451                mut self: Pin<&mut Self>,
452                cx: &mut Context<'_>,
453            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
454                match self.as_mut().body_receiver.poll_recv(cx) {
455                    Poll::Pending => Poll::Pending,
456                    Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),
457
458                    // This means that the `body_sender` end of the channel has been dropped.
459                    Poll::Ready(None) => {
460                        if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
461                            match Pin::new(&mut finish_receiver).poll(cx) {
462                                Poll::Pending => {
463                                    self.as_mut().finish_receiver = Some(finish_receiver);
464                                    Poll::Pending
465                                }
466                                Poll::Ready(Ok(message)) => match message {
467                                    FinishMessage::Finished => Poll::Ready(None),
468                                    FinishMessage::Trailers(trailers) => {
469                                        Poll::Ready(Some(Ok(Frame::trailers(trailers))))
470                                    }
471                                    FinishMessage::Abort => {
472                                        Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
473                                    }
474                                },
475                                Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
476                            }
477                        } else {
478                            Poll::Ready(None)
479                        }
480                    }
481                }
482            }
483        }
484
485        // always add 1 buffer here because one empty slot is required
486        let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
487        let (finish_sender, finish_receiver) = oneshot::channel();
488        let body_impl = BodyImpl {
489            body_receiver,
490            finish_receiver: Some(finish_receiver),
491        }
492        .boxed_unsync();
493
494        let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());
495
496        (
497            Self {
498                body_output_stream: Some(Box::new(output_stream)),
499                context,
500                written,
501                finish_sender: Some(finish_sender),
502            },
503            body_impl,
504        )
505    }
506
507    /// Take the output stream, if it's available.
508    pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
509        self.body_output_stream.take()
510    }
511
512    /// Finish the body, optionally with trailers.
513    pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
514        // Make sure that the output stream has been dropped, so that the BodyImpl poll function
515        // will immediately pick up the finish sender.
516        drop(self.body_output_stream);
517
518        let sender = self
519            .finish_sender
520            .take()
521            .expect("outgoing-body trailer_sender consumed by a non-owning function");
522
523        if let Some(w) = self.written {
524            let written = w.written();
525            if written != w.expected {
526                let _ = sender.send(FinishMessage::Abort);
527                return Err(self.context.as_body_size_error(written));
528            }
529        }
530
531        let message = if let Some(ts) = trailers {
532            FinishMessage::Trailers(ts.into_inner())
533        } else {
534            FinishMessage::Finished
535        };
536
537        // Ignoring failure: receiver died sending body, but we can't report that here.
538        let _ = sender.send(message);
539
540        Ok(())
541    }
542
543    /// Abort the body.
544    pub fn abort(mut self) {
545        // Make sure that the output stream has been dropped, so that the BodyImpl poll function
546        // will immediately pick up the finish sender.
547        drop(self.body_output_stream);
548
549        let sender = self
550            .finish_sender
551            .take()
552            .expect("outgoing-body trailer_sender consumed by a non-owning function");
553
554        let _ = sender.send(FinishMessage::Abort);
555    }
556}
557
558/// Message sent to end the `[HostOutgoingBody]` stream.
559#[derive(Debug)]
560enum FinishMessage {
561    Finished,
562    Trailers(hyper::HeaderMap),
563    Abort,
564}
565
/// Which side of an HTTP exchange a body belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamContext {
    /// A request body.
    Request,
    /// A response body.
    Response,
}
574
575impl StreamContext {
576    /// Construct the correct [`types::ErrorCode`] body size error.
577    pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
578        match self {
579            StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
580            StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
581        }
582    }
583}
584
585/// Provides a [`HostOutputStream`] impl from a [`tokio::sync::mpsc::Sender`].
586#[derive(Debug)]
587struct BodyWriteStream {
588    context: StreamContext,
589    writer: mpsc::Sender<Bytes>,
590    write_budget: usize,
591    written: Option<WrittenState>,
592}
593
594impl BodyWriteStream {
595    /// Create a [`BodyWriteStream`].
596    fn new(
597        context: StreamContext,
598        write_budget: usize,
599        writer: mpsc::Sender<Bytes>,
600        written: Option<WrittenState>,
601    ) -> Self {
602        // at least one capacity is required to send a message
603        assert!(writer.max_capacity() >= 1);
604        BodyWriteStream {
605            context,
606            writer,
607            write_budget,
608            written,
609        }
610    }
611}
612
613#[async_trait::async_trait]
614impl OutputStream for BodyWriteStream {
615    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
616        let len = bytes.len();
617        match self.writer.try_send(bytes) {
618            // If the message was sent then it's queued up now in hyper to get
619            // received.
620            Ok(()) => {
621                if let Some(written) = self.written.as_ref() {
622                    if !written.update(len) {
623                        let total = written.written();
624                        return Err(StreamError::LastOperationFailed(format_err!(
625                            self.context.as_body_size_error(total)
626                        )));
627                    }
628                }
629
630                Ok(())
631            }
632
633            // If this channel is full then that means `check_write` wasn't
634            // called. The call to `check_write` always guarantees that there's
635            // at least one capacity if a write is allowed.
636            Err(mpsc::error::TrySendError::Full(_)) => {
637                Err(StreamError::Trap(format_err!("write exceeded budget")))
638            }
639
640            // Hyper is gone so this stream is now closed.
641            Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
642        }
643    }
644
645    fn flush(&mut self) -> Result<(), StreamError> {
646        // Flushing doesn't happen in this body stream since we're currently
647        // only tracking sending bytes over to hyper.
648        if self.writer.is_closed() {
649            Err(StreamError::Closed)
650        } else {
651            Ok(())
652        }
653    }
654
655    fn check_write(&mut self) -> Result<usize, StreamError> {
656        if self.writer.is_closed() {
657            Err(StreamError::Closed)
658        } else if self.writer.capacity() == 0 {
659            // If there is no more capacity in this sender channel then don't
660            // allow any more writes because the hyper task needs to catch up
661            // now.
662            //
663            // Note that this relies on this task being the only one sending
664            // data to ensure that no one else can steal a write into this
665            // channel.
666            Ok(0)
667        } else {
668            Ok(self.write_budget)
669        }
670    }
671}
672
673#[async_trait::async_trait]
674impl Pollable for BodyWriteStream {
675    async fn ready(&mut self) {
676        // Attempt to perform a reservation for a send. If there's capacity in
677        // the channel or it's already closed then this will return immediately.
678        // If the channel is full this will block until capacity opens up.
679        let _ = self.writer.reserve().await;
680    }
681}