1use crate::{
5 bindings::http::types::{self, ErrorCode, Method, Scheme},
6 body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
7};
8use bytes::Bytes;
9use http::header::{HeaderMap, HeaderName, HeaderValue};
10use http_body_util::BodyExt;
11use hyper::body::Body;
12use std::any::Any;
13use std::fmt;
14use std::time::Duration;
15use wasmtime::component::{Resource, ResourceTable};
16use wasmtime::{Result, bail};
17use wasmtime_wasi::p2::Pollable;
18use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
19
20#[cfg(feature = "default-send-request")]
21use {
22 crate::io::TokioIo,
23 crate::{error::dns_error, hyper_request_error},
24 tokio::net::TcpStream,
25 tokio::time::timeout,
26};
27
/// Default upper bound, in bytes, for the tracked size of a header field map
/// (128 KiB). Used by [`WasiHttpCtx::new`].
const DEFAULT_FIELD_SIZE_LIMIT: usize = 128 * 1024;

/// Host-side configuration shared by the WASI HTTP implementation.
#[derive(Debug)]
pub struct WasiHttpCtx {
    /// Size limit handed to every field map created on behalf of this context.
    pub(crate) field_size_limit: usize,
}

impl WasiHttpCtx {
    /// Creates a context whose field size limit is `DEFAULT_FIELD_SIZE_LIMIT`.
    pub fn new() -> Self {
        Self {
            field_size_limit: DEFAULT_FIELD_SIZE_LIMIT,
        }
    }

    /// Replaces the field size limit with `limit` (in bytes).
    ///
    /// Only values read from the context after this call observe the new
    /// limit; nothing in this method revisits existing field maps.
    pub fn set_field_size_limit(&mut self, limit: usize) {
        self.field_size_limit = limit;
    }
}

impl Default for WasiHttpCtx {
    /// Equivalent to [`WasiHttpCtx::new`].
    fn default() -> Self {
        Self::new()
    }
}
63
64pub trait WasiHttpView {
106 fn ctx(&mut self) -> &mut WasiHttpCtx;
108
109 fn table(&mut self) -> &mut ResourceTable;
111
112 fn new_incoming_request<B>(
114 &mut self,
115 scheme: Scheme,
116 req: hyper::Request<B>,
117 ) -> wasmtime::Result<Resource<HostIncomingRequest>>
118 where
119 B: Body<Data = Bytes> + Send + 'static,
120 B::Error: Into<ErrorCode>,
121 Self: Sized,
122 {
123 let field_size_limit = self.ctx().field_size_limit;
124 let (parts, body) = req.into_parts();
125 let body = body.map_err(Into::into).boxed_unsync();
126 let body = HostIncomingBody::new(
127 body,
128 std::time::Duration::from_millis(600 * 1000),
130 field_size_limit,
131 );
132 let incoming_req =
133 HostIncomingRequest::new(self, parts, scheme, Some(body), field_size_limit)?;
134 Ok(self.table().push(incoming_req)?)
135 }
136
137 fn new_response_outparam(
139 &mut self,
140 result: tokio::sync::oneshot::Sender<
141 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
142 >,
143 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
144 let id = self.table().push(HostResponseOutparam { result })?;
145 Ok(id)
146 }
147
148 #[cfg(feature = "default-send-request")]
150 fn send_request(
151 &mut self,
152 request: hyper::Request<HyperOutgoingBody>,
153 config: OutgoingRequestConfig,
154 ) -> crate::HttpResult<HostFutureIncomingResponse> {
155 Ok(default_send_request(request, config))
156 }
157
158 #[cfg(not(feature = "default-send-request"))]
160 fn send_request(
161 &mut self,
162 request: hyper::Request<HyperOutgoingBody>,
163 config: OutgoingRequestConfig,
164 ) -> crate::HttpResult<HostFutureIncomingResponse>;
165
166 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
168 DEFAULT_FORBIDDEN_HEADERS.contains(name)
169 }
170
171 fn outgoing_body_buffer_chunks(&mut self) -> usize {
175 DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS
176 }
177
178 fn outgoing_body_chunk_size(&mut self) -> usize {
181 DEFAULT_OUTGOING_BODY_CHUNK_SIZE
182 }
183}
184
/// Value returned by the default `WasiHttpView::outgoing_body_buffer_chunks`.
pub const DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS: usize = 1;
/// Value returned by the default `WasiHttpView::outgoing_body_chunk_size`
/// (1024 * 1024, i.e. 1 MiB if the unit is bytes — TODO confirm unit).
pub const DEFAULT_OUTGOING_BODY_CHUNK_SIZE: usize = 1024 * 1024;
189
190impl<T: ?Sized + WasiHttpView> WasiHttpView for &mut T {
191 fn ctx(&mut self) -> &mut WasiHttpCtx {
192 T::ctx(self)
193 }
194
195 fn table(&mut self) -> &mut ResourceTable {
196 T::table(self)
197 }
198
199 fn new_response_outparam(
200 &mut self,
201 result: tokio::sync::oneshot::Sender<
202 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
203 >,
204 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
205 T::new_response_outparam(self, result)
206 }
207
208 fn send_request(
209 &mut self,
210 request: hyper::Request<HyperOutgoingBody>,
211 config: OutgoingRequestConfig,
212 ) -> crate::HttpResult<HostFutureIncomingResponse> {
213 T::send_request(self, request, config)
214 }
215
216 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
217 T::is_forbidden_header(self, name)
218 }
219
220 fn outgoing_body_buffer_chunks(&mut self) -> usize {
221 T::outgoing_body_buffer_chunks(self)
222 }
223
224 fn outgoing_body_chunk_size(&mut self) -> usize {
225 T::outgoing_body_chunk_size(self)
226 }
227}
228
229impl<T: ?Sized + WasiHttpView> WasiHttpView for Box<T> {
230 fn ctx(&mut self) -> &mut WasiHttpCtx {
231 T::ctx(self)
232 }
233
234 fn table(&mut self) -> &mut ResourceTable {
235 T::table(self)
236 }
237
238 fn new_response_outparam(
239 &mut self,
240 result: tokio::sync::oneshot::Sender<
241 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
242 >,
243 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
244 T::new_response_outparam(self, result)
245 }
246
247 fn send_request(
248 &mut self,
249 request: hyper::Request<HyperOutgoingBody>,
250 config: OutgoingRequestConfig,
251 ) -> crate::HttpResult<HostFutureIncomingResponse> {
252 T::send_request(self, request, config)
253 }
254
255 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
256 T::is_forbidden_header(self, name)
257 }
258
259 fn outgoing_body_buffer_chunks(&mut self) -> usize {
260 T::outgoing_body_buffer_chunks(self)
261 }
262
263 fn outgoing_body_chunk_size(&mut self) -> usize {
264 T::outgoing_body_chunk_size(self)
265 }
266}
267
/// Transparent newtype around a value `T`.
///
/// `#[repr(transparent)]` gives it the same layout as `T`; the wrapped value
/// is publicly accessible as field `.0`.
#[repr(transparent)]
pub struct WasiHttpImpl<T>(pub T);
282
283impl<T: WasiHttpView> WasiHttpView for WasiHttpImpl<T> {
284 fn ctx(&mut self) -> &mut WasiHttpCtx {
285 self.0.ctx()
286 }
287
288 fn table(&mut self) -> &mut ResourceTable {
289 self.0.table()
290 }
291
292 fn new_response_outparam(
293 &mut self,
294 result: tokio::sync::oneshot::Sender<
295 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
296 >,
297 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
298 self.0.new_response_outparam(result)
299 }
300
301 fn send_request(
302 &mut self,
303 request: hyper::Request<HyperOutgoingBody>,
304 config: OutgoingRequestConfig,
305 ) -> crate::HttpResult<HostFutureIncomingResponse> {
306 self.0.send_request(request, config)
307 }
308
309 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
310 self.0.is_forbidden_header(name)
311 }
312
313 fn outgoing_body_buffer_chunks(&mut self) -> usize {
314 self.0.outgoing_body_buffer_chunks()
315 }
316
317 fn outgoing_body_chunk_size(&mut self) -> usize {
318 self.0.outgoing_body_chunk_size()
319 }
320}
321
/// Header names treated as forbidden by the default
/// `WasiHttpView::is_forbidden_header` implementation.
pub const DEFAULT_FORBIDDEN_HEADERS: [http::header::HeaderName; 9] = [
    hyper::header::CONNECTION,
    HeaderName::from_static("keep-alive"),
    hyper::header::PROXY_AUTHENTICATE,
    hyper::header::PROXY_AUTHORIZATION,
    HeaderName::from_static("proxy-connection"),
    hyper::header::TRANSFER_ENCODING,
    hyper::header::UPGRADE,
    hyper::header::HOST,
    HeaderName::from_static("http2-settings"),
];
335
336pub(crate) fn remove_forbidden_headers(view: &mut dyn WasiHttpView, headers: &mut FieldMap) {
338 let forbidden_keys = Vec::from_iter(headers.as_ref().keys().filter_map(|name| {
339 if view.is_forbidden_header(name) {
340 Some(name.clone())
341 } else {
342 None
343 }
344 }));
345
346 for name in forbidden_keys {
347 headers.remove_all(&name);
348 }
349}
350
/// Configuration for sending an outgoing request.
pub struct OutgoingRequestConfig {
    /// Whether the connection is wrapped in TLS. In the default handler this
    /// also selects the default port (443 with TLS, 80 without).
    pub use_tls: bool,
    /// Timeout the default handler applies to the TCP connect and, separately,
    /// to the HTTP/1 handshake.
    pub connect_timeout: Duration,
    /// Timeout the default handler applies to `send_request`, i.e. to waiting
    /// for the response.
    pub first_byte_timeout: Duration,
    /// Timeout carried into the resulting `IncomingResponse` unchanged by the
    /// default handler.
    pub between_bytes_timeout: Duration,
}
362
/// Spawns [`default_send_request_handler`] on the wasmtime-wasi runtime and
/// returns a pending [`HostFutureIncomingResponse`] wrapping the task handle.
#[cfg(feature = "default-send-request")]
pub fn default_send_request(
    request: hyper::Request<HyperOutgoingBody>,
    config: OutgoingRequestConfig,
) -> HostFutureIncomingResponse {
    HostFutureIncomingResponse::pending(wasmtime_wasi::runtime::spawn(async move {
        let outcome = default_send_request_handler(request, config).await;
        Ok(outcome)
    }))
}
377
/// Sends `request` over a fresh HTTP/1 connection and returns the response.
///
/// Steps performed:
/// 1. Build `host:port` from the request URI's authority, defaulting the port
///    to 443 (TLS) or 80 (plain) when the authority has none.
/// 2. Open a TCP connection, bounded by `connect_timeout`.
/// 3. Optionally wrap the stream in TLS using the `webpki_roots` trust anchors
///    and no client authentication.
/// 4. Perform the HTTP/1 handshake (again bounded by `connect_timeout`) and
///    spawn a task that drives the connection.
/// 5. Rewrite the request URI to its path and query only, then send it,
///    bounded by `first_byte_timeout`.
///
/// # Errors
///
/// * `HttpRequestUriInvalid` if the URI has no authority.
/// * `ConnectionTimeout` if the TCP connect or HTTP handshake times out.
/// * A DNS error if the connect fails with `AddrNotAvailable`, if the error
///   text starts with "failed to lookup address information", or if the host
///   is not a valid TLS server name.
/// * `ConnectionRefused` for any other connect failure.
/// * `TlsProtocolError` if the TLS handshake fails.
/// * `ConnectionReadTimeout` if no response arrives within
///   `first_byte_timeout`.
/// * Whatever `hyper_request_error` maps hyper failures to.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request_handler(
    mut request: hyper::Request<HyperOutgoingBody>,
    OutgoingRequestConfig {
        use_tls,
        connect_timeout,
        first_byte_timeout,
        between_bytes_timeout,
    }: OutgoingRequestConfig,
) -> Result<IncomingResponse, types::ErrorCode> {
    let authority = if let Some(authority) = request.uri().authority() {
        if authority.port().is_some() {
            authority.to_string()
        } else {
            let port = if use_tls { 443 } else { 80 };
            format!("{authority}:{port}")
        }
    } else {
        return Err(types::ErrorCode::HttpRequestUriInvalid);
    };
    let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(|e| match e.kind() {
            std::io::ErrorKind::AddrNotAvailable => {
                dns_error("address not available".to_string(), 0)
            }

            _ => {
                // Resolution failures are only recognizable by their message.
                if e.to_string()
                    .starts_with("failed to lookup address information")
                {
                    dns_error("address not available".to_string(), 0)
                } else {
                    types::ErrorCode::ConnectionRefused
                }
            }
        })?;

    let (mut sender, worker) = if use_tls {
        use rustls::pki_types::ServerName;

        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));

        // Derive the TLS server name from the authority. Splitting at the
        // *first* ':' would turn a bracketed IPv6 literal such as `[::1]:443`
        // into `[`, so strip the trailing `:port` instead and then remove the
        // brackets, which are URI syntax rather than part of the address.
        let host = match authority.rsplit_once(':') {
            Some((host, port)) if !port.contains(']') => host,
            _ => authority.as_str(),
        };
        let host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        let domain = ServerName::try_from(host)
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();
        // NOTE(review): the TLS handshake below is not wrapped in a timeout,
        // unlike the TCP connect and the HTTP handshake — confirm intended.
        let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
            tracing::warn!("tls protocol error: {e:?}");
            types::ErrorCode::TlsProtocolError
        })?;
        let stream = TokioIo::new(stream);

        let (sender, conn) = timeout(
            connect_timeout,
            hyper::client::conn::http1::handshake(stream),
        )
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(hyper_request_error)?;

        // Drive the connection in the background; errors are only logged.
        let worker = wasmtime_wasi::runtime::spawn(async move {
            match conn.await {
                Ok(()) => {}
                Err(e) => tracing::warn!("dropping error {e}"),
            }
        });

        (sender, worker)
    } else {
        let tcp_stream = TokioIo::new(tcp_stream);
        let (sender, conn) = timeout(
            connect_timeout,
            hyper::client::conn::http1::handshake(tcp_stream),
        )
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(hyper_request_error)?;

        // Drive the connection in the background; errors are only logged.
        let worker = wasmtime_wasi::runtime::spawn(async move {
            match conn.await {
                Ok(()) => {}
                Err(e) => tracing::warn!("dropping error {e}"),
            }
        });

        (sender, worker)
    };

    // Reduce the URI to its path and query, falling back to "/".
    *request.uri_mut() = http::Uri::builder()
        .path_and_query(
            request
                .uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let resp = timeout(first_byte_timeout, sender.send_request(request))
        .await
        .map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
        .map_err(hyper_request_error)?
        .map(|body| body.map_err(hyper_request_error).boxed_unsync());

    Ok(IncomingResponse {
        resp,
        worker: Some(worker),
        between_bytes_timeout,
    })
}
512
513impl From<http::Method> for types::Method {
514 fn from(method: http::Method) -> Self {
515 if method == http::Method::GET {
516 types::Method::Get
517 } else if method == hyper::Method::HEAD {
518 types::Method::Head
519 } else if method == hyper::Method::POST {
520 types::Method::Post
521 } else if method == hyper::Method::PUT {
522 types::Method::Put
523 } else if method == hyper::Method::DELETE {
524 types::Method::Delete
525 } else if method == hyper::Method::CONNECT {
526 types::Method::Connect
527 } else if method == hyper::Method::OPTIONS {
528 types::Method::Options
529 } else if method == hyper::Method::TRACE {
530 types::Method::Trace
531 } else if method == hyper::Method::PATCH {
532 types::Method::Patch
533 } else {
534 types::Method::Other(method.to_string())
535 }
536 }
537}
538
539impl TryInto<http::Method> for types::Method {
540 type Error = http::method::InvalidMethod;
541
542 fn try_into(self) -> Result<http::Method, Self::Error> {
543 match self {
544 Method::Get => Ok(http::Method::GET),
545 Method::Head => Ok(http::Method::HEAD),
546 Method::Post => Ok(http::Method::POST),
547 Method::Put => Ok(http::Method::PUT),
548 Method::Delete => Ok(http::Method::DELETE),
549 Method::Connect => Ok(http::Method::CONNECT),
550 Method::Options => Ok(http::Method::OPTIONS),
551 Method::Trace => Ok(http::Method::TRACE),
552 Method::Patch => Ok(http::Method::PATCH),
553 Method::Other(s) => http::Method::from_bytes(s.as_bytes()),
554 }
555 }
556}
557
/// Host representation of an incoming HTTP request.
#[derive(Debug)]
pub struct HostIncomingRequest {
    /// Request method, taken from the request parts.
    pub(crate) method: http::method::Method,
    /// Request URI, taken from the request parts.
    pub(crate) uri: http::uri::Uri,
    /// Request headers with forbidden headers already removed.
    pub(crate) headers: FieldMap,
    /// Scheme supplied by the caller of `HostIncomingRequest::new`.
    pub(crate) scheme: Scheme,
    /// Authority from the URI, or from the `Host` header when the URI has none.
    pub(crate) authority: String,
    /// Request body, if any.
    pub body: Option<HostIncomingBody>,
}
569
570impl HostIncomingRequest {
571 pub fn new(
573 view: &mut dyn WasiHttpView,
574 parts: http::request::Parts,
575 scheme: Scheme,
576 body: Option<HostIncomingBody>,
577 field_size_limit: usize,
578 ) -> wasmtime::Result<Self> {
579 let authority = match parts.uri.authority() {
580 Some(authority) => authority.to_string(),
581 None => match parts.headers.get(http::header::HOST) {
582 Some(host) => host.to_str()?.to_string(),
583 None => bail!("invalid HTTP request missing authority in URI and host header"),
584 },
585 };
586
587 let mut headers = FieldMap::new(parts.headers, field_size_limit);
588 remove_forbidden_headers(view, &mut headers);
589
590 Ok(Self {
591 method: parts.method,
592 uri: parts.uri,
593 headers,
594 authority,
595 scheme,
596 body,
597 })
598 }
599}
600
/// Host representation of a response out-parameter.
pub struct HostResponseOutparam {
    /// One-shot channel sender that carries either the response or an error
    /// code to whoever holds the receiver.
    pub result:
        tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
}
607
/// Host representation of an outgoing HTTP response.
pub struct HostOutgoingResponse {
    /// Response status code.
    pub status: http::StatusCode,
    /// Response headers.
    pub headers: FieldMap,
    /// Response body; `None` converts to an empty body in the `TryFrom` impl
    /// targeting `hyper::Response`.
    pub body: Option<HyperOutgoingBody>,
}
617
618impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
619 type Error = http::Error;
620
621 fn try_from(
622 resp: HostOutgoingResponse,
623 ) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
624 use http_body_util::Empty;
625
626 let mut builder = hyper::Response::builder().status(resp.status);
627
628 *builder.headers_mut().unwrap() = resp.headers.map;
629
630 match resp.body {
631 Some(body) => builder.body(body),
632 None => builder.body(
633 Empty::<bytes::Bytes>::new()
634 .map_err(|_| unreachable!("Infallible error"))
635 .boxed_unsync(),
636 ),
637 }
638 }
639}
640
/// Host representation of an outgoing HTTP request.
#[derive(Debug)]
pub struct HostOutgoingRequest {
    /// Request method.
    pub method: Method,
    /// Request scheme, if one has been set.
    pub scheme: Option<Scheme>,
    /// Request authority, if one has been set.
    pub authority: Option<String>,
    /// Request path and query string, if set.
    pub path_with_query: Option<String>,
    /// Request headers.
    pub headers: FieldMap,
    /// Request body, if any.
    pub body: Option<HyperOutgoingBody>,
}
657
/// Optional per-request timeouts; every field defaults to `None`.
// NOTE(review): presumably these feed `OutgoingRequestConfig`'s timeouts of
// the same names — the mapping is not visible in this file; confirm.
#[derive(Debug, Default)]
pub struct HostRequestOptions {
    /// Connect timeout, if specified.
    pub connect_timeout: Option<std::time::Duration>,
    /// First-byte timeout, if specified.
    pub first_byte_timeout: Option<std::time::Duration>,
    /// Between-bytes timeout, if specified.
    pub between_bytes_timeout: Option<std::time::Duration>,
}
668
/// Host representation of an incoming HTTP response.
#[derive(Debug)]
pub struct HostIncomingResponse {
    /// Response status code as a raw integer.
    pub status: u16,
    /// Response headers.
    pub headers: FieldMap,
    /// Response body, if any.
    pub body: Option<HostIncomingBody>,
}
679
/// Host representation of a set of header fields, either borrowed from
/// another value or owned outright.
#[derive(Debug)]
pub enum HostFields {
    /// Fields that live inside another value.
    Ref {
        /// Identifier of the value holding the fields — presumably its
        /// resource-table index; TODO confirm against callers.
        parent: u32,

        /// Projects the `FieldMap` out of the type-erased parent element.
        get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
    },
    /// Fields stored directly in this value.
    Owned {
        /// The owned field map.
        fields: FieldMap,
    },
}
701
/// A `HeaderMap` paired with a running size total and a size limit.
///
/// `size` is maintained by the methods on this type; `limit` is compared
/// against it when appending.
#[derive(Debug, Clone)]
pub struct FieldMap {
    /// The underlying header storage.
    map: HeaderMap,
    /// Maximum tracked size allowed by `append`.
    limit: usize,
    /// Current tracked size of all names and values in `map`.
    size: usize,
}
710
/// Error reported when a field map would grow past its size limit.
#[derive(Debug)]
pub struct FieldSizeLimitError {
    /// The size that was rejected.
    pub(crate) size: usize,
    /// The limit that `size` exceeded.
    pub(crate) limit: usize,
}

impl fmt::Display for FieldSizeLimitError {
    /// Writes the limit followed by the offending size.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { size, limit } = self;
        write!(f, "Field size limit {} exceeded: {}", limit, size)
    }
}

impl std::error::Error for FieldSizeLimitError {}
725
726impl FieldMap {
727 pub fn new(map: HeaderMap, limit: usize) -> Self {
733 let size = Self::content_size(&map);
734 Self { map, size, limit }
735 }
736 pub fn empty(limit: usize) -> Self {
738 Self {
739 map: HeaderMap::new(),
740 size: 0,
741 limit,
742 }
743 }
744 pub fn into_inner(self) -> HeaderMap {
746 self.map
747 }
748 pub(crate) fn content_size(map: &HeaderMap) -> usize {
751 let mut sum = 0;
752 for key in map.keys() {
753 sum += header_name_size(key);
754 }
755 for value in map.values() {
756 sum += header_value_size(value);
757 }
758 sum
759 }
760 pub fn remove_all(&mut self, key: &HeaderName) -> Vec<HeaderValue> {
764 use http::header::Entry;
765 match self.map.try_entry(key) {
766 Ok(Entry::Vacant { .. }) | Err(_) => Vec::new(),
767 Ok(Entry::Occupied(e)) => {
768 let (name, value_drain) = e.remove_entry_mult();
769 let mut removed = header_name_size(&name);
770 let values = value_drain.collect::<Vec<_>>();
771 for v in values.iter() {
772 removed += header_value_size(v);
773 }
774 self.size -= removed;
775 values
776 }
777 }
778 }
779 pub fn append(&mut self, key: &HeaderName, value: HeaderValue) -> Result<bool> {
784 let key_size = header_name_size(key);
785 let val_size = header_value_size(&value);
786 let new_size = if !self.map.contains_key(key) {
787 self.size + key_size + val_size
788 } else {
789 self.size + val_size
790 };
791 if new_size > self.limit {
792 bail!(FieldSizeLimitError {
793 limit: self.limit,
794 size: new_size
795 })
796 }
797 self.size = new_size;
798 Ok(self.map.try_append(key, value)?)
799 }
800}
801
802fn header_name_size(name: &HeaderName) -> usize {
807 name.as_str().len() + size_of::<HeaderName>()
808}
809
810fn header_value_size(value: &HeaderValue) -> usize {
816 value.len() + size_of::<HeaderValue>()
817}
818
/// Read-only access to the underlying `HeaderMap`.
impl AsRef<HeaderMap> for FieldMap {
    /// Returns a shared reference to the wrapped map.
    fn as_ref(&self) -> &HeaderMap {
        &self.map
    }
}
826
/// Join handle of the task that produces an incoming response: the outer
/// `wasmtime::Result` wraps an inner result that is either the response or an
/// HTTP error code.
pub type FutureIncomingResponseHandle =
    AbortOnDropJoinHandle<wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>>;
830
/// A response received for an outgoing request.
#[derive(Debug)]
pub struct IncomingResponse {
    /// The response itself.
    pub resp: hyper::Response<HyperIncomingBody>,
    /// Optional background task associated with the response; the default
    /// handler stores the task that drives the HTTP connection here.
    pub worker: Option<AbortOnDropJoinHandle<()>>,
    /// Between-bytes timeout carried along with the response.
    pub between_bytes_timeout: std::time::Duration,
}
841
/// Host representation of a future incoming response.
#[derive(Debug)]
pub enum HostFutureIncomingResponse {
    /// The request task is still running.
    Pending(FutureIncomingResponseHandle),
    /// The task has finished and its result is available.
    Ready(wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>),
    /// No result is held — presumably it was already taken by the guest; this
    /// variant is not assigned anywhere in this file, TODO confirm.
    Consumed,
}
854
855impl HostFutureIncomingResponse {
856 pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
858 Self::Pending(handle)
859 }
860
861 pub fn ready(result: wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
863 Self::Ready(result)
864 }
865
866 pub fn is_ready(&self) -> bool {
868 matches!(self, Self::Ready(_))
869 }
870
871 pub fn unwrap_ready(self) -> wasmtime::Result<Result<IncomingResponse, types::ErrorCode>> {
873 match self {
874 Self::Ready(res) => res,
875 Self::Pending(_) | Self::Consumed => {
876 panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
877 }
878 }
879 }
880}
881
882#[async_trait::async_trait]
883impl Pollable for HostFutureIncomingResponse {
884 async fn ready(&mut self) {
885 if let Self::Pending(handle) = self {
886 *self = Self::Ready(handle.await);
887 }
888 }
889}