wasmcloud_provider_http_server/
lib.rs1use core::future::Future;
26use core::pin::Pin;
27use core::str::FromStr as _;
28use core::task::{ready, Context, Poll};
29use core::time::Duration;
30
31use std::net::{SocketAddr, TcpListener};
32
33use anyhow::{anyhow, bail, Context as _};
34use axum::extract;
35use bytes::Bytes;
36use futures::Stream;
37use pin_project_lite::pin_project;
38use tokio::task::JoinHandle;
39use tokio::{spawn, time};
40use tower_http::cors::{self, CorsLayer};
41use tracing::{debug, info, trace};
42use wasmcloud_core::http::{load_settings, ServiceSettings};
43use wasmcloud_provider_sdk::provider::WrpcClient;
44use wasmcloud_provider_sdk::{initialize_observability, load_host_data, run_provider};
45use wrpc_interface_http::InvokeIncomingHandler as _;
46
47mod address;
48mod host;
49mod path;
50
51pub async fn run() -> anyhow::Result<()> {
52 initialize_observability!(
53 "http-server-provider",
54 std::env::var_os("PROVIDER_HTTP_SERVER_FLAMEGRAPH_PATH")
55 );
56
57 let host_data = load_host_data().context("failed to load host data")?;
58 match host_data.config.get("routing_mode").map(String::as_str) {
59 Some("address") | None => run_provider(
61 address::HttpServerProvider::new(host_data).context(
62 "failed to create address-mode HTTP server provider from hostdata configuration",
63 )?,
64 "http-server-provider",
65 )
66 .await?
67 .await,
68 Some("path") => {
70 run_provider(
71 path::HttpServerProvider::new(host_data).await.context(
72 "failed to create path-mode HTTP server provider from hostdata configuration",
73 )?,
74 "http-server-provider",
75 )
76 .await?
77 .await;
78 }
79 Some("host") => {
80 run_provider(
81 host::HttpServerProvider::new(host_data).await.context(
82 "failed to create host-mode HTTP server provider from hostdata configuration",
83 )?,
84 "http-server-provider",
85 )
86 .await?
87 .await;
88 }
89 Some(other) => bail!("unknown routing_mode: {other}"),
90 };
91
92 Ok(())
93}
94
95pub(crate) fn build_request(
97 request: extract::Request,
98 scheme: http::uri::Scheme,
99 authority: String,
100 settings: &ServiceSettings,
101) -> Result<http::Request<axum::body::Body>, axum::response::ErrorResponse> {
102 let method = request.method();
103 if let Some(readonly_mode) = settings.readonly_mode {
104 if readonly_mode
105 && method != http::method::Method::GET
106 && method != http::method::Method::HEAD
107 {
108 debug!("only GET and HEAD allowed in read-only mode");
109 Err((
110 http::StatusCode::METHOD_NOT_ALLOWED,
111 "only GET and HEAD allowed in read-only mode",
112 ))?;
113 }
114 }
115 let (
116 http::request::Parts {
117 method,
118 uri,
119 headers,
120 ..
121 },
122 body,
123 ) = request.into_parts();
124 let http::uri::Parts { path_and_query, .. } = uri.into_parts();
125
126 let mut uri = http::Uri::builder().scheme(scheme);
127 if !authority.is_empty() {
128 uri = uri.authority(authority);
129 }
130 if let Some(path_and_query) = path_and_query {
131 uri = uri.path_and_query(path_and_query);
132 }
133 let uri = uri
134 .build()
135 .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
136 let mut req = http::Request::builder();
137 *req.headers_mut().ok_or((
138 http::StatusCode::INTERNAL_SERVER_ERROR,
139 "invalid request generated",
140 ))? = headers;
141 let req = req
142 .uri(uri)
143 .method(method)
144 .body(body)
145 .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
146
147 Ok(req)
148}
149
150pub(crate) async fn invoke_component(
152 wrpc: &WrpcClient,
153 target: &str,
154 req: http::Request<axum::body::Body>,
155 timeout: Option<Duration>,
156 cache_control: Option<&String>,
157) -> impl axum::response::IntoResponse {
158 let mut cx = async_nats::HeaderMap::new();
160 for (k, v) in
161 wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::new_with_extractor(
162 &wasmcloud_provider_sdk::wasmcloud_tracing::http::HeaderExtractor(req.headers()),
163 )
164 .iter()
165 {
166 cx.insert(k.as_str(), v.as_str());
167 }
168
169 trace!(?req, component_id = target, "httpserver calling component");
170 let fut = wrpc.invoke_handle_http(Some(cx), req);
171 let res = if let Some(timeout) = timeout {
172 let Ok(res) = time::timeout(timeout, fut).await else {
173 Err(http::StatusCode::REQUEST_TIMEOUT)?
174 };
175 res
176 } else {
177 fut.await
178 };
179 let (res, errors, io) =
180 res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:#}")))?;
181 let io = io.map(spawn);
182 let errors: Box<dyn Stream<Item = _> + Send + Unpin> = Box::new(errors);
183 let mut res =
185 res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:?}")))?;
186 if let Some(cache_control) = cache_control {
187 let cache_control = http::HeaderValue::from_str(cache_control)
188 .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
189 res.headers_mut().append("Cache-Control", cache_control);
190 };
191 axum::response::Result::<_, axum::response::ErrorResponse>::Ok(res.map(|body| ResponseBody {
192 body,
193 errors,
194 io,
195 }))
196}
197
198pub(crate) fn get_cors_layer(settings: &ServiceSettings) -> anyhow::Result<CorsLayer> {
200 let allow_origin = settings.cors_allowed_origins.as_ref();
201 let allow_origin: Vec<_> = allow_origin
202 .map(|origins| {
203 origins
204 .iter()
205 .map(AsRef::as_ref)
206 .map(http::HeaderValue::from_str)
207 .collect::<Result<_, _>>()
208 .context("failed to parse allowed origins")
209 })
210 .transpose()?
211 .unwrap_or_default();
212 let allow_origin = if allow_origin.is_empty() {
213 cors::AllowOrigin::any()
214 } else {
215 cors::AllowOrigin::list(allow_origin)
216 };
217 let allow_headers = settings.cors_allowed_headers.as_ref();
218 let allow_headers: Vec<_> = allow_headers
219 .map(|headers| {
220 headers
221 .iter()
222 .map(AsRef::as_ref)
223 .map(http::HeaderName::from_str)
224 .collect::<Result<_, _>>()
225 .context("failed to parse allowed header names")
226 })
227 .transpose()?
228 .unwrap_or_default();
229 let allow_headers = if allow_headers.is_empty() {
230 cors::AllowHeaders::any()
231 } else {
232 cors::AllowHeaders::list(allow_headers)
233 };
234 let allow_methods = settings.cors_allowed_methods.as_ref();
235 let allow_methods: Vec<_> = allow_methods
236 .map(|methods| {
237 methods
238 .iter()
239 .map(AsRef::as_ref)
240 .map(http::Method::from_str)
241 .collect::<Result<_, _>>()
242 .context("failed to parse allowed methods")
243 })
244 .transpose()?
245 .unwrap_or_default();
246 let allow_methods = if allow_methods.is_empty() {
247 cors::AllowMethods::any()
248 } else {
249 cors::AllowMethods::list(allow_methods)
250 };
251 let expose_headers = settings.cors_exposed_headers.as_ref();
252 let expose_headers: Vec<_> = expose_headers
253 .map(|headers| {
254 headers
255 .iter()
256 .map(AsRef::as_ref)
257 .map(http::HeaderName::from_str)
258 .collect::<Result<_, _>>()
259 .context("failed to parse exposeed header names")
260 })
261 .transpose()?
262 .unwrap_or_default();
263 let expose_headers = if expose_headers.is_empty() {
264 cors::ExposeHeaders::any()
265 } else {
266 cors::ExposeHeaders::list(expose_headers)
267 };
268 let mut cors = CorsLayer::new()
269 .allow_origin(allow_origin)
270 .allow_headers(allow_headers)
271 .allow_methods(allow_methods)
272 .expose_headers(expose_headers);
273 if let Some(max_age) = settings.cors_max_age_secs {
274 cors = cors.max_age(Duration::from_secs(max_age));
275 }
276
277 Ok(cors)
278}
279
280pub(crate) fn get_tcp_listener(settings: &ServiceSettings) -> anyhow::Result<TcpListener> {
285 let socket = match &settings.address {
286 SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(),
287 SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(),
288 }
289 .context("Unable to open socket")?;
290 socket
295 .set_reuseaddr(!cfg!(windows))
296 .context("Error when setting socket to reuseaddr")?;
297 socket
298 .set_nodelay(true)
299 .context("failed to set `TCP_NODELAY`")?;
300
301 match settings.disable_keepalive {
302 Some(false) => {
303 info!("disabling TCP keepalive");
304 socket
305 .set_keepalive(false)
306 .context("failed to disable TCP keepalive")?
307 }
308 None | Some(true) => socket
309 .set_keepalive(true)
310 .context("failed to enable TCP keepalive")?,
311 }
312
313 socket
314 .bind(settings.address)
315 .context("Unable to bind to address")?;
316 let listener = socket.listen(1024).context("unable to listen on socket")?;
317 let listener = listener.into_std().context("Unable to get listener")?;
318
319 Ok(listener)
320}
321
pin_project! {
    /// HTTP response body handed back to axum for a component invocation.
    ///
    /// It bundles the component's response body with the stream of body errors for the
    /// invocation and the handle of the spawned I/O task (if any). The `http_body::Body`
    /// implementation polls all three.
    struct ResponseBody {
        // The component's response body; frames are forwarded as-is.
        #[pin]
        body: wrpc_interface_http::HttpBody,
        // Body-processing errors for the invocation, boxed to erase the concrete stream type.
        #[pin]
        errors: Box<dyn Stream<Item = wrpc_interface_http::HttpBodyError<axum::Error>> + Send + Unpin>,
        // Handle of the spawned I/O task. `None` when the invocation produced no I/O
        // future, or once the task has been cleared after completing.
        #[pin]
        io: Option<JoinHandle<anyhow::Result<()>>>,
    }
}
332
333impl http_body::Body for ResponseBody {
334 type Data = Bytes;
335 type Error = anyhow::Error;
336
337 fn poll_frame(
338 mut self: Pin<&mut Self>,
339 cx: &mut Context<'_>,
340 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
341 let mut this = self.as_mut().project();
342 if let Some(io) = this.io.as_mut().as_pin_mut() {
343 match io.poll(cx) {
344 Poll::Ready(Ok(Ok(()))) => {
345 this.io.take();
346 }
347 Poll::Ready(Ok(Err(err))) => {
348 return Poll::Ready(Some(Err(
349 anyhow!(err).context("failed to complete async I/O")
350 )))
351 }
352 Poll::Ready(Err(err)) => {
353 return Poll::Ready(Some(Err(anyhow!(err).context("I/O task failed"))))
354 }
355 Poll::Pending => {}
356 }
357 }
358 match this.errors.poll_next(cx) {
359 Poll::Ready(Some(err)) => {
360 if let Some(io) = this.io.as_pin_mut() {
361 io.abort();
362 }
363 return Poll::Ready(Some(Err(anyhow!(err).context("failed to process body"))));
364 }
365 Poll::Ready(None) | Poll::Pending => {}
366 }
367 match ready!(this.body.poll_frame(cx)) {
368 Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
369 Some(Err(err)) => {
370 if let Some(io) = this.io.as_pin_mut() {
371 io.abort();
372 }
373 Poll::Ready(Some(Err(err)))
374 }
375 None => {
376 if let Some(io) = this.io.as_pin_mut() {
377 io.abort();
378 }
379 Poll::Ready(None)
380 }
381 }
382 }
383}
384
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use anyhow::Result;
    use futures::StreamExt;
    use wasmcloud_provider_sdk::{
        provider::initialize_host_data, run_provider, HostData, InterfaceLinkDefinition,
    };
    use wasmcloud_test_util::testcontainers::{AsyncRunner, NatsServer};

    use crate::{address, path};

    // NOTE(review): presumably ignored by default because these tests start a NATS container
    // via testcontainers and bind fixed local ports (8080/8081) — confirm.
    // NOTE(review): both tests call `initialize_host_data`. If that can only succeed once per
    // process, running them together (e.g. `cargo test -- --ignored`) would make the second
    // `expect` panic — confirm.
    /// Address-mode routing: a request to the default address is forwarded to the linked
    /// component. Because nothing answers the invocation, the 100ms link timeout yields a 408.
    #[ignore]
    #[tokio::test]
    async fn can_listen_and_invoke_with_timeout() -> Result<()> {
        let nats_container = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats_container
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_address = format!("nats://127.0.0.1:{nats_port}");

        let default_address = "0.0.0.0:8080";
        let host_data = HostData {
            lattice_rpc_url: nats_address.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: "http-server-provider-test".to_string(),
            config: std::collections::HashMap::from([
                ("default_address".to_string(), default_address.to_string()),
                ("routing_mode".to_string(), "address".to_string()),
            ]),
            link_definitions: vec![InterfaceLinkDefinition {
                source_id: "http-server-provider-test".to_string(),
                target: "test-component".to_string(),
                name: "default".to_string(),
                wit_namespace: "wasi".to_string(),
                wit_package: "http".to_string(),
                interfaces: vec!["incoming-handler".to_string()],
                // Short per-link timeout so the unanswered invocation fails quickly.
                source_config: std::collections::HashMap::from([(
                    "timeout_ms".to_string(),
                    "100".to_string(),
                )]),
                target_config: HashMap::new(),
                source_secrets: None,
                target_secrets: None,
            }],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            address::HttpServerProvider::new(&host_data)
                .expect("should be able to create provider"),
            "http-server-provider-test",
        )
        .await
        .expect("should be able to run provider");

        // Subscribe to the component's wRPC subjects before the provider starts serving,
        // so the invocation message is observed.
        let conn = async_nats::connect(nats_address)
            .await
            .expect("should be able to connect");
        let mut subscriber = conn
            .subscribe("lattice.test-component.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_handle = tokio::spawn(provider);

        // Give the provider time to start listening.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let resp = reqwest::get("http://127.0.0.1:8080")
            .await
            .expect("should be able to make request");

        // No component replies, so the invocation times out.
        assert_eq!(resp.status(), 408);
        // The invocation was still published on the target component's subject.
        let msg = subscriber
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component"));
        provider_handle.abort();
        let _ = nats_container.stop().await;

        Ok(())
    }

    /// Path-mode routing: `/foo` and `/bar` (with or without a query string) are forwarded
    /// to their linked components, and an unmatched path returns 404 without invoking
    /// either component.
    #[ignore]
    #[tokio::test]
    async fn can_support_path_based_routing() -> Result<()> {
        let nats_container = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats_container
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_address = format!("nats://127.0.0.1:{nats_port}");

        let default_address = "0.0.0.0:8081";
        let host_data = HostData {
            lattice_rpc_url: nats_address.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: "http-server-provider-test".to_string(),
            // In path mode the timeout is set in the provider config, not on the links.
            config: std::collections::HashMap::from([
                ("default_address".to_string(), default_address.to_string()),
                ("routing_mode".to_string(), "path".to_string()),
                ("timeout_ms".to_string(), "100".to_string()),
            ]),
            link_definitions: vec![
                InterfaceLinkDefinition {
                    source_id: "http-server-provider-test".to_string(),
                    target: "test-component-one".to_string(),
                    name: "default".to_string(),
                    wit_namespace: "wasi".to_string(),
                    wit_package: "http".to_string(),
                    interfaces: vec!["incoming-handler".to_string()],
                    source_config: std::collections::HashMap::from([(
                        "path".to_string(),
                        "/foo".to_string(),
                    )]),
                    target_config: HashMap::new(),
                    source_secrets: None,
                    target_secrets: None,
                },
                InterfaceLinkDefinition {
                    source_id: "http-server-provider-test".to_string(),
                    target: "test-component-two".to_string(),
                    name: "default".to_string(),
                    wit_namespace: "wasi".to_string(),
                    wit_package: "http".to_string(),
                    interfaces: vec!["incoming-handler".to_string()],
                    source_config: std::collections::HashMap::from([(
                        "path".to_string(),
                        "/bar".to_string(),
                    )]),
                    target_config: HashMap::new(),
                    source_secrets: None,
                    target_secrets: None,
                },
            ],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            path::HttpServerProvider::new(&host_data)
                .await
                .expect("should be able to create provider"),
            "http-server-provider-test",
        )
        .await
        .expect("should be able to run provider");

        // One subscription per component, created before the provider starts serving.
        let conn = async_nats::connect(nats_address)
            .await
            .expect("should be able to connect");
        let mut subscriber_one = conn
            .subscribe("lattice.test-component-one.wrpc.>")
            .await
            .expect("should be able to subscribe");
        let mut subscriber_two = conn
            .subscribe("lattice.test-component-two.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_handle = tokio::spawn(provider);
        // Give the provider time to start listening.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // `/foo` is routed to component one; the unanswered invocation times out.
        let resp = reqwest::get("http://127.0.0.1:8081/foo")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 408);
        let msg = subscriber_one
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-one"));

        // `/bar` is routed to component two.
        let resp = reqwest::get("http://127.0.0.1:8081/bar")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 408);
        let msg = subscriber_two
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-two"));

        // A query string does not affect path matching.
        let resp = reqwest::get("http://127.0.0.1:8081/bar?someparam=foo")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 408);
        let msg = subscriber_two
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-two"));

        // An unmatched path is rejected with 404...
        let resp = reqwest::get("http://127.0.0.1:8081/some/other/route/idk")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 404);

        // ...and neither component receives a message within one second.
        assert!(
            tokio::time::timeout(tokio::time::Duration::from_secs(1), subscriber_one.next())
                .await
                .is_err(),
        );
        assert!(
            tokio::time::timeout(tokio::time::Duration::from_secs(1), subscriber_two.next())
                .await
                .is_err(),
        );

        provider_handle.abort();
        let _ = nats_container.stop().await;

        Ok(())
    }
}