1use crate::bindings::http::types;
4use crate::types::FieldMap;
5use bytes::Bytes;
6use http_body::{Body, Frame};
7use http_body_util::BodyExt;
8use http_body_util::combinators::UnsyncBoxBody;
9use std::future::Future;
10use std::mem;
11use std::task::{Context, Poll};
12use std::{pin::Pin, sync::Arc, time::Duration};
13use tokio::sync::{mpsc, oneshot};
14use wasmtime::format_err;
15use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
16use wasmtime_wasi::runtime::{AbortOnDropJoinHandle, poll_noop};
17
18pub type HyperIncomingBody = UnsyncBoxBody<Bytes, types::ErrorCode>;
20
21pub type HyperOutgoingBody = UnsyncBoxBody<Bytes, types::ErrorCode>;
23
24#[derive(Debug)]
26pub struct HostIncomingBody {
27 body: IncomingBodyState,
28 field_size_limit: usize,
29 worker: Option<AbortOnDropJoinHandle<()>>,
33}
34
35impl HostIncomingBody {
36 pub fn new(
38 body: HyperIncomingBody,
39 between_bytes_timeout: Duration,
40 field_size_limit: usize,
41 ) -> HostIncomingBody {
42 let body = BodyWithTimeout::new(body, between_bytes_timeout);
43 HostIncomingBody {
44 body: IncomingBodyState::Start(body),
45 field_size_limit,
46 worker: None,
47 }
48 }
49
50 pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
52 assert!(self.worker.is_none());
53 self.worker = Some(worker);
54 }
55
56 pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
58 match &mut self.body {
59 IncomingBodyState::Start(_) => {}
60 IncomingBodyState::InBodyStream(_) => return None,
61 }
62 let (tx, rx) = oneshot::channel();
63 let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
64 IncomingBodyState::Start(b) => b,
65 IncomingBodyState::InBodyStream(_) => unreachable!(),
66 };
67 Some(HostIncomingBodyStream {
68 state: IncomingBodyStreamState::Open { body, tx },
69 buffer: Bytes::new(),
70 error: None,
71 })
72 }
73
74 pub fn into_future_trailers(self) -> HostFutureTrailers {
76 HostFutureTrailers::Waiting(self)
77 }
78}
79
80#[derive(Debug)]
82enum IncomingBodyState {
83 Start(BodyWithTimeout),
86
87 InBodyStream(oneshot::Receiver<StreamEnd>),
91}
92
93#[derive(Debug)]
95struct BodyWithTimeout {
96 inner: HyperIncomingBody,
98 timeout: Pin<Box<tokio::time::Sleep>>,
100 reset_sleep: bool,
103 between_bytes_timeout: Duration,
106}
107
108impl BodyWithTimeout {
109 fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
110 BodyWithTimeout {
111 inner,
112 between_bytes_timeout,
113 reset_sleep: true,
114 timeout: Box::pin(wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
115 tokio::time::sleep(Duration::new(0, 0))
116 })),
117 }
118 }
119}
120
121impl Body for BodyWithTimeout {
122 type Data = Bytes;
123 type Error = types::ErrorCode;
124
125 fn poll_frame(
126 self: Pin<&mut Self>,
127 cx: &mut Context<'_>,
128 ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
129 let me = Pin::into_inner(self);
130
131 if me.reset_sleep {
135 me.timeout
136 .as_mut()
137 .reset(tokio::time::Instant::now() + me.between_bytes_timeout);
138 me.reset_sleep = false;
139 }
140
141 if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) {
144 return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
145 }
146
147 let result = Pin::new(&mut me.inner).poll_frame(cx);
150 me.reset_sleep = result.is_ready();
151 result
152 }
153}
154
155#[derive(Debug)]
158enum StreamEnd {
159 Remaining(BodyWithTimeout),
162
163 Trailers(Option<http::HeaderMap>),
166}
167
168#[derive(Debug)]
171pub struct HostIncomingBodyStream {
172 state: IncomingBodyStreamState,
173 buffer: Bytes,
174 error: Option<wasmtime::Error>,
175}
176
177impl HostIncomingBodyStream {
178 fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
179 match frame {
180 Some(Ok(frame)) => match frame.into_data() {
181 Ok(bytes) => {
184 assert!(self.buffer.is_empty());
185 self.buffer = bytes;
186 }
187
188 Err(trailers) => {
192 let trailers = trailers.into_trailers().unwrap();
193 let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
194 IncomingBodyStreamState::Open { body: _, tx } => tx,
195 IncomingBodyStreamState::Closed => unreachable!(),
196 };
197
198 let _ = tx.send(StreamEnd::Trailers(Some(trailers)));
201 }
202 },
203
204 Some(Err(e)) => {
208 self.error = Some(e.into());
209 self.state = IncomingBodyStreamState::Closed;
210 }
211
212 None => {
216 self.state = IncomingBodyStreamState::Closed;
217 }
218 }
219 }
220}
221
222#[derive(Debug)]
223enum IncomingBodyStreamState {
224 Open {
232 body: BodyWithTimeout,
233 tx: oneshot::Sender<StreamEnd>,
234 },
235
236 Closed,
239}
240
241#[async_trait::async_trait]
242impl InputStream for HostIncomingBodyStream {
243 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
244 loop {
245 if !self.buffer.is_empty() {
247 let len = size.min(self.buffer.len());
248 let chunk = self.buffer.split_to(len);
249 return Ok(chunk);
250 }
251
252 if let Some(e) = self.error.take() {
253 return Err(StreamError::LastOperationFailed(e));
254 }
255
256 let body = match &mut self.state {
262 IncomingBodyStreamState::Open { body, .. } => body,
263 IncomingBodyStreamState::Closed => return Err(StreamError::Closed),
264 };
265
266 let future = body.frame();
267 futures::pin_mut!(future);
268 match poll_noop(future) {
269 Some(result) => {
270 self.record_frame(result);
271 }
272 None => return Ok(Bytes::new()),
273 }
274 }
275 }
276}
277
278#[async_trait::async_trait]
279impl Pollable for HostIncomingBodyStream {
280 async fn ready(&mut self) {
281 if !self.buffer.is_empty() || self.error.is_some() {
282 return;
283 }
284
285 if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
286 let frame = body.frame().await;
287 self.record_frame(frame);
288 }
289 }
290}
291
292impl Drop for HostIncomingBodyStream {
293 fn drop(&mut self) {
294 let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed);
300 if let IncomingBodyStreamState::Open { body, tx } = prev {
301 let _ = tx.send(StreamEnd::Remaining(body));
302 }
303 }
304}
305
306#[derive(Debug)]
308pub enum HostFutureTrailers {
309 Waiting(HostIncomingBody),
323
324 Done(Result<Option<FieldMap>, types::ErrorCode>),
329
330 Consumed,
332}
333
334#[async_trait::async_trait]
335impl Pollable for HostFutureTrailers {
336 async fn ready(&mut self) {
337 let body = match self {
338 HostFutureTrailers::Waiting(body) => body,
339 HostFutureTrailers::Done(_) => return,
340 HostFutureTrailers::Consumed => return,
341 };
342
343 if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
346 match rx.await {
347 Ok(StreamEnd::Trailers(Some(t))) => {
350 *self = Self::Done(Ok(Some(FieldMap::new(t, body.field_size_limit))));
351 }
352 Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),
355
356 Ok(StreamEnd::Trailers(None)) | Err(_) => {
358 *self = HostFutureTrailers::Done(Ok(None));
359 }
360 }
361 }
362
363 let body = match self {
366 HostFutureTrailers::Waiting(body) => body,
367 HostFutureTrailers::Done(_) => return,
368 HostFutureTrailers::Consumed => return,
369 };
370 let hyper_body = match &mut body.body {
371 IncomingBodyState::Start(body) => body,
372 IncomingBodyState::InBodyStream(_) => unreachable!(),
373 };
374 let result = loop {
375 match hyper_body.frame().await {
376 None => break Ok(None),
377 Some(Err(e)) => break Err(e),
378 Some(Ok(frame)) => {
379 if let Ok(header_map) = frame.into_trailers() {
382 break Ok(Some(FieldMap::new(header_map, body.field_size_limit)));
383 }
384 }
385 }
386 };
387 *self = HostFutureTrailers::Done(result);
388 }
389}
390
/// Shared counter of body bytes written, checked against an expected total.
///
/// The expected total is the `size` passed to `HostOutgoingBody::new`.
/// Clones share the same running total.
#[derive(Debug, Clone)]
struct WrittenState {
    /// Number of bytes the body is expected to contain.
    expected: u64,
    /// Running total of bytes written, shared between clones.
    written: Arc<std::sync::atomic::AtomicU64>,
}

impl WrittenState {
    /// Creates a tracker expecting `expected_size` bytes, with nothing written yet.
    fn new(expected_size: u64) -> Self {
        let counter = Arc::new(std::sync::atomic::AtomicU64::new(0));
        Self {
            expected: expected_size,
            written: counter,
        }
    }

    /// Total bytes recorded so far.
    fn written(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.written.load(Ordering::Relaxed)
    }

    /// Adds `len` bytes to the running total.
    ///
    /// Returns `false` when the new total exceeds `expected`. The bytes are
    /// counted either way.
    fn update(&self, len: usize) -> bool {
        use std::sync::atomic::Ordering;
        let len = len as u64;
        let before = self.written.fetch_add(len, Ordering::Relaxed);
        let after = before + len;
        after <= self.expected
    }
}
420
421pub struct HostOutgoingBody {
423 body_output_stream: Option<Box<dyn OutputStream>>,
425 context: StreamContext,
426 written: Option<WrittenState>,
427 finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
428}
429
430impl HostOutgoingBody {
431 pub fn new(
433 context: StreamContext,
434 size: Option<u64>,
435 buffer_chunks: usize,
436 chunk_size: usize,
437 ) -> (Self, HyperOutgoingBody) {
438 assert!(buffer_chunks >= 1);
439
440 let written = size.map(WrittenState::new);
441
442 use tokio::sync::oneshot::error::RecvError;
443 struct BodyImpl {
444 body_receiver: mpsc::Receiver<Bytes>,
445 finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
446 }
447 impl Body for BodyImpl {
448 type Data = Bytes;
449 type Error = types::ErrorCode;
450 fn poll_frame(
451 mut self: Pin<&mut Self>,
452 cx: &mut Context<'_>,
453 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
454 match self.as_mut().body_receiver.poll_recv(cx) {
455 Poll::Pending => Poll::Pending,
456 Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),
457
458 Poll::Ready(None) => {
460 if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
461 match Pin::new(&mut finish_receiver).poll(cx) {
462 Poll::Pending => {
463 self.as_mut().finish_receiver = Some(finish_receiver);
464 Poll::Pending
465 }
466 Poll::Ready(Ok(message)) => match message {
467 FinishMessage::Finished => Poll::Ready(None),
468 FinishMessage::Trailers(trailers) => {
469 Poll::Ready(Some(Ok(Frame::trailers(trailers))))
470 }
471 FinishMessage::Abort => {
472 Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
473 }
474 },
475 Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
476 }
477 } else {
478 Poll::Ready(None)
479 }
480 }
481 }
482 }
483 }
484
485 let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
487 let (finish_sender, finish_receiver) = oneshot::channel();
488 let body_impl = BodyImpl {
489 body_receiver,
490 finish_receiver: Some(finish_receiver),
491 }
492 .boxed_unsync();
493
494 let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());
495
496 (
497 Self {
498 body_output_stream: Some(Box::new(output_stream)),
499 context,
500 written,
501 finish_sender: Some(finish_sender),
502 },
503 body_impl,
504 )
505 }
506
507 pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
509 self.body_output_stream.take()
510 }
511
512 pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
514 drop(self.body_output_stream);
517
518 let sender = self
519 .finish_sender
520 .take()
521 .expect("outgoing-body trailer_sender consumed by a non-owning function");
522
523 if let Some(w) = self.written {
524 let written = w.written();
525 if written != w.expected {
526 let _ = sender.send(FinishMessage::Abort);
527 return Err(self.context.as_body_size_error(written));
528 }
529 }
530
531 let message = if let Some(ts) = trailers {
532 FinishMessage::Trailers(ts.into_inner())
533 } else {
534 FinishMessage::Finished
535 };
536
537 let _ = sender.send(message);
539
540 Ok(())
541 }
542
543 pub fn abort(mut self) {
545 drop(self.body_output_stream);
548
549 let sender = self
550 .finish_sender
551 .take()
552 .expect("outgoing-body trailer_sender consumed by a non-owning function");
553
554 let _ = sender.send(FinishMessage::Abort);
555 }
556}
557
558#[derive(Debug)]
560enum FinishMessage {
561 Finished,
562 Trailers(hyper::HeaderMap),
563 Abort,
564}
565
566#[derive(Clone, Copy, Debug, Eq, PartialEq)]
568pub enum StreamContext {
569 Request,
571 Response,
573}
574
575impl StreamContext {
576 pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
578 match self {
579 StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
580 StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
581 }
582 }
583}
584
585#[derive(Debug)]
587struct BodyWriteStream {
588 context: StreamContext,
589 writer: mpsc::Sender<Bytes>,
590 write_budget: usize,
591 written: Option<WrittenState>,
592}
593
594impl BodyWriteStream {
595 fn new(
597 context: StreamContext,
598 write_budget: usize,
599 writer: mpsc::Sender<Bytes>,
600 written: Option<WrittenState>,
601 ) -> Self {
602 assert!(writer.max_capacity() >= 1);
604 BodyWriteStream {
605 context,
606 writer,
607 write_budget,
608 written,
609 }
610 }
611}
612
613#[async_trait::async_trait]
614impl OutputStream for BodyWriteStream {
615 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
616 let len = bytes.len();
617 match self.writer.try_send(bytes) {
618 Ok(()) => {
621 if let Some(written) = self.written.as_ref() {
622 if !written.update(len) {
623 let total = written.written();
624 return Err(StreamError::LastOperationFailed(format_err!(
625 self.context.as_body_size_error(total)
626 )));
627 }
628 }
629
630 Ok(())
631 }
632
633 Err(mpsc::error::TrySendError::Full(_)) => {
637 Err(StreamError::Trap(format_err!("write exceeded budget")))
638 }
639
640 Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
642 }
643 }
644
645 fn flush(&mut self) -> Result<(), StreamError> {
646 if self.writer.is_closed() {
649 Err(StreamError::Closed)
650 } else {
651 Ok(())
652 }
653 }
654
655 fn check_write(&mut self) -> Result<usize, StreamError> {
656 if self.writer.is_closed() {
657 Err(StreamError::Closed)
658 } else if self.writer.capacity() == 0 {
659 Ok(0)
667 } else {
668 Ok(self.write_budget)
669 }
670 }
671}
672
673#[async_trait::async_trait]
674impl Pollable for BodyWriteStream {
675 async fn ready(&mut self) {
676 let _ = self.writer.reserve().await;
680 }
681}