wasmcloud_provider_http_server/
lib.rs1use core::future::Future;
26use core::pin::Pin;
27use core::str::FromStr as _;
28use core::task::{ready, Context, Poll};
29use core::time::Duration;
30
31use std::net::{SocketAddr, TcpListener};
32
33use anyhow::{anyhow, bail, Context as _};
34use axum::extract;
35use bytes::Bytes;
36use futures::Stream;
37use pin_project_lite::pin_project;
38use tokio::task::JoinHandle;
39use tokio::{spawn, time};
40use tower_http::cors::{self, CorsLayer};
41use tracing::{debug, info, trace};
42use wasmcloud_provider_sdk::provider::WrpcClient;
43use wasmcloud_provider_sdk::{initialize_observability, load_host_data, run_provider};
44use wrpc_interface_http::InvokeIncomingHandler as _;
45
46mod address;
47mod path;
48mod settings;
49pub use settings::{default_listen_address, load_settings, ServiceSettings};
50
51pub async fn run() -> anyhow::Result<()> {
52 initialize_observability!(
53 "http-server-provider",
54 std::env::var_os("PROVIDER_HTTP_SERVER_FLAMEGRAPH_PATH")
55 );
56
57 let host_data = load_host_data().context("failed to load host data")?;
58 match host_data.config.get("routing_mode").map(String::as_str) {
59 Some("address") | None => run_provider(
61 address::HttpServerProvider::new(host_data).context(
62 "failed to create address-mode HTTP server provider from hostdata configuration",
63 )?,
64 "http-server-provider",
65 )
66 .await?
67 .await,
68 Some("path") => {
70 run_provider(
71 path::HttpServerProvider::new(host_data).await.context(
72 "failed to create path-mode HTTP server provider from hostdata configuration",
73 )?,
74 "http-server-provider",
75 )
76 .await?
77 .await;
78 }
79 Some(other) => bail!("unknown routing_mode: {other}"),
80 };
81
82 Ok(())
83}
84
85pub(crate) fn build_request(
87 request: extract::Request,
88 scheme: http::uri::Scheme,
89 authority: String,
90 settings: &ServiceSettings,
91) -> Result<http::Request<axum::body::Body>, axum::response::ErrorResponse> {
92 let method = request.method();
93 if let Some(readonly_mode) = settings.readonly_mode {
94 if readonly_mode
95 && method != http::method::Method::GET
96 && method != http::method::Method::HEAD
97 {
98 debug!("only GET and HEAD allowed in read-only mode");
99 Err((
100 http::StatusCode::METHOD_NOT_ALLOWED,
101 "only GET and HEAD allowed in read-only mode",
102 ))?;
103 }
104 }
105 let (
106 http::request::Parts {
107 method,
108 uri,
109 headers,
110 ..
111 },
112 body,
113 ) = request.into_parts();
114 let http::uri::Parts { path_and_query, .. } = uri.into_parts();
115
116 let mut uri = http::Uri::builder().scheme(scheme);
117 if !authority.is_empty() {
118 uri = uri.authority(authority);
119 }
120 if let Some(path_and_query) = path_and_query {
121 uri = uri.path_and_query(path_and_query);
122 }
123 let uri = uri
124 .build()
125 .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
126 let mut req = http::Request::builder();
127 *req.headers_mut().ok_or((
128 http::StatusCode::INTERNAL_SERVER_ERROR,
129 "invalid request generated",
130 ))? = headers;
131 let req = req
132 .uri(uri)
133 .method(method)
134 .body(body)
135 .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
136
137 Ok(req)
138}
139
140pub(crate) async fn invoke_component(
142 wrpc: &WrpcClient,
143 target: &str,
144 req: http::Request<axum::body::Body>,
145 timeout: Option<Duration>,
146 cache_control: Option<&String>,
147) -> impl axum::response::IntoResponse {
148 let mut cx = async_nats::HeaderMap::new();
150 for (k, v) in
151 wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::new_with_extractor(
152 &wasmcloud_provider_sdk::wasmcloud_tracing::http::HeaderExtractor(req.headers()),
153 )
154 .iter()
155 {
156 cx.insert(k.as_str(), v.as_str());
157 }
158
159 trace!(?req, component_id = target, "httpserver calling component");
160 let fut = wrpc.invoke_handle_http(Some(cx), req);
161 let res = if let Some(timeout) = timeout {
162 let Ok(res) = time::timeout(timeout, fut).await else {
163 Err(http::StatusCode::REQUEST_TIMEOUT)?
164 };
165 res
166 } else {
167 fut.await
168 };
169 let (res, errors, io) =
170 res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:#}")))?;
171 let io = io.map(spawn);
172 let errors: Box<dyn Stream<Item = _> + Send + Unpin> = Box::new(errors);
173 let mut res =
175 res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:?}")))?;
176 if let Some(cache_control) = cache_control {
177 let cache_control = http::HeaderValue::from_str(cache_control)
178 .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
179 res.headers_mut().append("Cache-Control", cache_control);
180 };
181 axum::response::Result::<_, axum::response::ErrorResponse>::Ok(res.map(|body| ResponseBody {
182 body,
183 errors,
184 io,
185 }))
186}
187
188pub(crate) fn get_cors_layer(settings: &ServiceSettings) -> anyhow::Result<CorsLayer> {
190 let allow_origin = settings.cors_allowed_origins.as_ref();
191 let allow_origin: Vec<_> = allow_origin
192 .map(|origins| {
193 origins
194 .iter()
195 .map(AsRef::as_ref)
196 .map(http::HeaderValue::from_str)
197 .collect::<Result<_, _>>()
198 .context("failed to parse allowed origins")
199 })
200 .transpose()?
201 .unwrap_or_default();
202 let allow_origin = if allow_origin.is_empty() {
203 cors::AllowOrigin::any()
204 } else {
205 cors::AllowOrigin::list(allow_origin)
206 };
207 let allow_headers = settings.cors_allowed_headers.as_ref();
208 let allow_headers: Vec<_> = allow_headers
209 .map(|headers| {
210 headers
211 .iter()
212 .map(AsRef::as_ref)
213 .map(http::HeaderName::from_str)
214 .collect::<Result<_, _>>()
215 .context("failed to parse allowed header names")
216 })
217 .transpose()?
218 .unwrap_or_default();
219 let allow_headers = if allow_headers.is_empty() {
220 cors::AllowHeaders::any()
221 } else {
222 cors::AllowHeaders::list(allow_headers)
223 };
224 let allow_methods = settings.cors_allowed_methods.as_ref();
225 let allow_methods: Vec<_> = allow_methods
226 .map(|methods| {
227 methods
228 .iter()
229 .map(AsRef::as_ref)
230 .map(http::Method::from_str)
231 .collect::<Result<_, _>>()
232 .context("failed to parse allowed methods")
233 })
234 .transpose()?
235 .unwrap_or_default();
236 let allow_methods = if allow_methods.is_empty() {
237 cors::AllowMethods::any()
238 } else {
239 cors::AllowMethods::list(allow_methods)
240 };
241 let expose_headers = settings.cors_exposed_headers.as_ref();
242 let expose_headers: Vec<_> = expose_headers
243 .map(|headers| {
244 headers
245 .iter()
246 .map(AsRef::as_ref)
247 .map(http::HeaderName::from_str)
248 .collect::<Result<_, _>>()
249 .context("failed to parse exposeed header names")
250 })
251 .transpose()?
252 .unwrap_or_default();
253 let expose_headers = if expose_headers.is_empty() {
254 cors::ExposeHeaders::any()
255 } else {
256 cors::ExposeHeaders::list(expose_headers)
257 };
258 let mut cors = CorsLayer::new()
259 .allow_origin(allow_origin)
260 .allow_headers(allow_headers)
261 .allow_methods(allow_methods)
262 .expose_headers(expose_headers);
263 if let Some(max_age) = settings.cors_max_age_secs {
264 cors = cors.max_age(Duration::from_secs(max_age));
265 }
266
267 Ok(cors)
268}
269
270pub(crate) fn get_tcp_listener(settings: &ServiceSettings) -> anyhow::Result<TcpListener> {
275 let socket = match &settings.address {
276 SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(),
277 SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(),
278 }
279 .context("Unable to open socket")?;
280 socket
285 .set_reuseaddr(!cfg!(windows))
286 .context("Error when setting socket to reuseaddr")?;
287 socket
288 .set_nodelay(true)
289 .context("failed to set `TCP_NODELAY`")?;
290
291 match settings.disable_keepalive {
292 Some(false) => {
293 info!("disabling TCP keepalive");
294 socket
295 .set_keepalive(false)
296 .context("failed to disable TCP keepalive")?
297 }
298 None | Some(true) => socket
299 .set_keepalive(true)
300 .context("failed to enable TCP keepalive")?,
301 }
302
303 socket
304 .bind(settings.address)
305 .context("Unable to bind to address")?;
306 let listener = socket.listen(1024).context("unable to listen on socket")?;
307 let listener = listener.into_std().context("Unable to get listener")?;
308
309 Ok(listener)
310}
311
312pin_project! {
313 struct ResponseBody {
314 #[pin]
315 body: wrpc_interface_http::HttpBody,
316 #[pin]
317 errors: Box<dyn Stream<Item = wrpc_interface_http::HttpBodyError<axum::Error>> + Send + Unpin>,
318 #[pin]
319 io: Option<JoinHandle<anyhow::Result<()>>>,
320 }
321}
322
323impl http_body::Body for ResponseBody {
324 type Data = Bytes;
325 type Error = anyhow::Error;
326
327 fn poll_frame(
328 mut self: Pin<&mut Self>,
329 cx: &mut Context<'_>,
330 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
331 let mut this = self.as_mut().project();
332 if let Some(io) = this.io.as_mut().as_pin_mut() {
333 match io.poll(cx) {
334 Poll::Ready(Ok(Ok(()))) => {
335 this.io.take();
336 }
337 Poll::Ready(Ok(Err(err))) => {
338 return Poll::Ready(Some(Err(
339 anyhow!(err).context("failed to complete async I/O")
340 )))
341 }
342 Poll::Ready(Err(err)) => {
343 return Poll::Ready(Some(Err(anyhow!(err).context("I/O task failed"))))
344 }
345 Poll::Pending => {}
346 }
347 }
348 match this.errors.poll_next(cx) {
349 Poll::Ready(Some(err)) => {
350 if let Some(io) = this.io.as_pin_mut() {
351 io.abort();
352 }
353 return Poll::Ready(Some(Err(anyhow!(err).context("failed to process body"))));
354 }
355 Poll::Ready(None) | Poll::Pending => {}
356 }
357 match ready!(this.body.poll_frame(cx)) {
358 Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
359 Some(Err(err)) => {
360 if let Some(io) = this.io.as_pin_mut() {
361 io.abort();
362 }
363 Poll::Ready(Some(Err(err)))
364 }
365 None => {
366 if let Some(io) = this.io.as_pin_mut() {
367 io.abort();
368 }
369 Poll::Ready(None)
370 }
371 }
372 }
373}
374
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use anyhow::Result;
    use futures::StreamExt;
    use wasmcloud_provider_sdk::{
        provider::initialize_host_data, run_provider, HostData, InterfaceLinkDefinition,
    };
    use wasmcloud_test_util::testcontainers::{AsyncRunner, NatsServer};

    use crate::{address, path};

    /// Provider key shared by the host data and every link definition in these tests.
    const PROVIDER_KEY: &str = "http-server-provider-test";

    /// Builds a `wasi:http/incoming-handler` link from the test provider to `target`.
    fn incoming_handler_link(
        target: &str,
        source_config: HashMap<String, String>,
    ) -> InterfaceLinkDefinition {
        InterfaceLinkDefinition {
            source_id: PROVIDER_KEY.to_string(),
            target: target.to_string(),
            name: "default".to_string(),
            wit_namespace: "wasi".to_string(),
            wit_package: "http".to_string(),
            interfaces: vec!["incoming-handler".to_string()],
            source_config,
            target_config: HashMap::new(),
            source_secrets: None,
            target_secrets: None,
        }
    }

    /// Address-mode provider: a request to a linked component that never answers
    /// must time out with 408, and the invocation must be visible on NATS.
    ///
    /// Ignored by default; it starts a nats-server container and binds port 8080.
    #[ignore]
    #[tokio::test]
    async fn can_listen_and_invoke_with_timeout() -> Result<()> {
        let nats = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let port = nats
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_url = format!("nats://127.0.0.1:{port}");

        let listen_on = "0.0.0.0:8080";
        let host_data = HostData {
            lattice_rpc_url: nats_url.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: PROVIDER_KEY.to_string(),
            config: HashMap::from([
                ("default_address".to_string(), listen_on.to_string()),
                ("routing_mode".to_string(), "address".to_string()),
            ]),
            link_definitions: vec![incoming_handler_link(
                "test-component",
                HashMap::from([("timeout_ms".to_string(), "100".to_string())]),
            )],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            address::HttpServerProvider::new(&host_data)
                .expect("should be able to create provider"),
            "http-server-provider-test",
        )
        .await
        .expect("should be able to run provider");

        // Subscribe before the provider runs so the invocation cannot be missed.
        let client = async_nats::connect(nats_url)
            .await
            .expect("should be able to connect");
        let mut invocations = client
            .subscribe("lattice.test-component.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_task = tokio::spawn(provider);

        // Give the provider a moment to start listening.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let response = reqwest::get("http://127.0.0.1:8080")
            .await
            .expect("should be able to make request");

        // Nothing answers the invocation, so the configured timeout elapses.
        assert_eq!(response.status(), 408);
        let invocation = invocations
            .next()
            .await
            .expect("should be able to get a message");
        assert!(invocation.subject.contains("test-component"));

        provider_task.abort();
        let _ = nats.stop().await;

        Ok(())
    }

    /// Path-mode provider: `/foo` and `/bar` must reach their respective
    /// components (query strings included) and unknown paths must yield 404
    /// without invoking anything.
    ///
    /// Ignored by default; it starts a nats-server container and binds port 8081.
    #[ignore]
    #[tokio::test]
    async fn can_support_path_based_routing() -> Result<()> {
        let nats = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let port = nats
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_url = format!("nats://127.0.0.1:{port}");

        let listen_on = "0.0.0.0:8081";
        let host_data = HostData {
            lattice_rpc_url: nats_url.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: PROVIDER_KEY.to_string(),
            config: HashMap::from([
                ("default_address".to_string(), listen_on.to_string()),
                ("routing_mode".to_string(), "path".to_string()),
                ("timeout_ms".to_string(), "100".to_string()),
            ]),
            link_definitions: vec![
                incoming_handler_link(
                    "test-component-one",
                    HashMap::from([("path".to_string(), "/foo".to_string())]),
                ),
                incoming_handler_link(
                    "test-component-two",
                    HashMap::from([("path".to_string(), "/bar".to_string())]),
                ),
            ],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            path::HttpServerProvider::new(&host_data)
                .await
                .expect("should be able to create provider"),
            "http-server-provider-test",
        )
        .await
        .expect("should be able to run provider");

        // Subscribe before the provider runs so no invocation can be missed.
        let client = async_nats::connect(nats_url)
            .await
            .expect("should be able to connect");
        let mut invocations_one = client
            .subscribe("lattice.test-component-one.wrpc.>")
            .await
            .expect("should be able to subscribe");
        let mut invocations_two = client
            .subscribe("lattice.test-component-two.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_task = tokio::spawn(provider);
        // Give the provider a moment to start listening.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // `/foo` is routed to the first component.
        let response = reqwest::get("http://127.0.0.1:8081/foo")
            .await
            .expect("should be able to make request");
        assert_eq!(response.status(), 408);
        let invocation = invocations_one
            .next()
            .await
            .expect("should be able to get a message");
        assert!(invocation.subject.contains("test-component-one"));

        // `/bar` is routed to the second component.
        let response = reqwest::get("http://127.0.0.1:8081/bar")
            .await
            .expect("should be able to make request");
        assert_eq!(response.status(), 408);
        let invocation = invocations_two
            .next()
            .await
            .expect("should be able to get a message");
        assert!(invocation.subject.contains("test-component-two"));

        // A query string does not affect which component is selected.
        let response = reqwest::get("http://127.0.0.1:8081/bar?someparam=foo")
            .await
            .expect("should be able to make request");
        assert_eq!(response.status(), 408);
        let invocation = invocations_two
            .next()
            .await
            .expect("should be able to get a message");
        assert!(invocation.subject.contains("test-component-two"));

        // An unregistered path is rejected with 404.
        let response = reqwest::get("http://127.0.0.1:8081/some/other/route/idk")
            .await
            .expect("should be able to make request");
        assert_eq!(response.status(), 404);

        // Neither component may have been invoked for the unknown path.
        let wait = tokio::time::Duration::from_secs(1);
        assert!(
            tokio::time::timeout(wait, invocations_one.next())
                .await
                .is_err(),
        );
        assert!(
            tokio::time::timeout(wait, invocations_two.next())
                .await
                .is_err(),
        );

        provider_task.abort();
        let _ = nats.stop().await;

        Ok(())
    }
}