wasmcloud_provider_http_server/
lib.rs

1//! The httpserver capability provider allows wasmcloud components to receive
2//! and process http(s) messages from web browsers, command-line tools
3//! such as curl, and other http clients. The server is fully asynchronous,
4//! and built on Rust's high-performance axum library, which is in turn based
5//! on hyper, and can process a large number of simultaneous connections.
6//!
7//! ## Features:
8//!
9//! - HTTP/1 and HTTP/2
10//! - TLS
//! - CORS support (configurable `allowed_origins`, `allowed_methods`, and
//!   `allowed_headers`). CORS has sensible defaults, so it should
//!   work as-is for development purposes, and may need refinement
//!   for production if a more secure configuration is required.
//! - All settings can be specified at runtime, using per-component link settings:
//!   - bind path/address
//!   - TLS
//!   - CORS
19//! - Flexible configuration loading: from host, or from local toml or json file.
20//! - Fully asynchronous, using tokio lightweight "green" threads
21//! - Thread pool (for managing a pool of OS threads). The default
22//!   thread pool has one thread per cpu core.
23//!
24
25use core::future::Future;
26use core::pin::Pin;
27use core::str::FromStr as _;
28use core::task::{ready, Context, Poll};
29use core::time::Duration;
30
31use std::net::{SocketAddr, TcpListener};
32
33use anyhow::{anyhow, bail, Context as _};
34use axum::extract;
35use bytes::Bytes;
36use futures::Stream;
37use pin_project_lite::pin_project;
38use tokio::task::JoinHandle;
39use tokio::{spawn, time};
40use tower_http::cors::{self, CorsLayer};
41use tracing::{debug, info, trace};
42use wasmcloud_core::http::{load_settings, ServiceSettings};
43use wasmcloud_provider_sdk::provider::WrpcClient;
44use wasmcloud_provider_sdk::{initialize_observability, load_host_data, run_provider};
45use wrpc_interface_http::InvokeIncomingHandler as _;
46
47mod address;
48mod host;
49mod path;
50
51pub async fn run() -> anyhow::Result<()> {
52    initialize_observability!(
53        "http-server-provider",
54        std::env::var_os("PROVIDER_HTTP_SERVER_FLAMEGRAPH_PATH")
55    );
56
57    let host_data = load_host_data().context("failed to load host data")?;
58    match host_data.config.get("routing_mode").map(String::as_str) {
59        // Run provider in address mode by default
60        Some("address") | None => run_provider(
61            address::HttpServerProvider::new(host_data).context(
62                "failed to create address-mode HTTP server provider from hostdata configuration",
63            )?,
64            "http-server-provider",
65        )
66        .await?
67        .await,
68        // Run provider in path mode
69        Some("path") => {
70            run_provider(
71                path::HttpServerProvider::new(host_data).await.context(
72                    "failed to create path-mode HTTP server provider from hostdata configuration",
73                )?,
74                "http-server-provider",
75            )
76            .await?
77            .await;
78        }
79        Some("host") => {
80            run_provider(
81                host::HttpServerProvider::new(host_data).await.context(
82                    "failed to create host-mode HTTP server provider from hostdata configuration",
83                )?,
84                "http-server-provider",
85            )
86            .await?
87            .await;
88        }
89        Some(other) => bail!("unknown routing_mode: {other}"),
90    };
91
92    Ok(())
93}
94
95/// Build a request to send to the component from the incoming request
96pub(crate) fn build_request(
97    request: extract::Request,
98    scheme: http::uri::Scheme,
99    authority: String,
100    settings: &ServiceSettings,
101) -> Result<http::Request<axum::body::Body>, axum::response::ErrorResponse> {
102    let method = request.method();
103    if let Some(readonly_mode) = settings.readonly_mode {
104        if readonly_mode
105            && method != http::method::Method::GET
106            && method != http::method::Method::HEAD
107        {
108            debug!("only GET and HEAD allowed in read-only mode");
109            Err((
110                http::StatusCode::METHOD_NOT_ALLOWED,
111                "only GET and HEAD allowed in read-only mode",
112            ))?;
113        }
114    }
115    let (
116        http::request::Parts {
117            method,
118            uri,
119            headers,
120            ..
121        },
122        body,
123    ) = request.into_parts();
124    let http::uri::Parts { path_and_query, .. } = uri.into_parts();
125
126    let mut uri = http::Uri::builder().scheme(scheme);
127    if !authority.is_empty() {
128        uri = uri.authority(authority);
129    }
130    if let Some(path_and_query) = path_and_query {
131        uri = uri.path_and_query(path_and_query);
132    }
133    let uri = uri
134        .build()
135        .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
136    let mut req = http::Request::builder();
137    *req.headers_mut().ok_or((
138        http::StatusCode::INTERNAL_SERVER_ERROR,
139        "invalid request generated",
140    ))? = headers;
141    let req = req
142        .uri(uri)
143        .method(method)
144        .body(body)
145        .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
146
147    Ok(req)
148}
149
150/// Invoke a component with the given request
151pub(crate) async fn invoke_component(
152    wrpc: &WrpcClient,
153    target: &str,
154    req: http::Request<axum::body::Body>,
155    timeout: Option<Duration>,
156    cache_control: Option<&String>,
157) -> impl axum::response::IntoResponse {
158    // Create a new wRPC client with all headers from the current span injected
159    let mut cx = async_nats::HeaderMap::new();
160    for (k, v) in
161        wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::new_with_extractor(
162            &wasmcloud_provider_sdk::wasmcloud_tracing::http::HeaderExtractor(req.headers()),
163        )
164        .iter()
165    {
166        cx.insert(k.as_str(), v.as_str());
167    }
168
169    trace!(?req, component_id = target, "httpserver calling component");
170    let fut = wrpc.invoke_handle_http(Some(cx), req);
171    let res = if let Some(timeout) = timeout {
172        let Ok(res) = time::timeout(timeout, fut).await else {
173            Err(http::StatusCode::REQUEST_TIMEOUT)?
174        };
175        res
176    } else {
177        fut.await
178    };
179    let (res, errors, io) =
180        res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:#}")))?;
181    let io = io.map(spawn);
182    let errors: Box<dyn Stream<Item = _> + Send + Unpin> = Box::new(errors);
183    // TODO: Convert this to http status code
184    let mut res =
185        res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:?}")))?;
186    if let Some(cache_control) = cache_control {
187        let cache_control = http::HeaderValue::from_str(cache_control)
188            .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
189        res.headers_mut().append("Cache-Control", cache_control);
190    };
191    axum::response::Result::<_, axum::response::ErrorResponse>::Ok(res.map(|body| ResponseBody {
192        body,
193        errors,
194        io,
195    }))
196}
197
198/// Helper function to construct a [`CorsLayer`] according to the [`ServiceSettings`].
199pub(crate) fn get_cors_layer(settings: &ServiceSettings) -> anyhow::Result<CorsLayer> {
200    let allow_origin = settings.cors_allowed_origins.as_ref();
201    let allow_origin: Vec<_> = allow_origin
202        .map(|origins| {
203            origins
204                .iter()
205                .map(AsRef::as_ref)
206                .map(http::HeaderValue::from_str)
207                .collect::<Result<_, _>>()
208                .context("failed to parse allowed origins")
209        })
210        .transpose()?
211        .unwrap_or_default();
212    let allow_origin = if allow_origin.is_empty() {
213        cors::AllowOrigin::any()
214    } else {
215        cors::AllowOrigin::list(allow_origin)
216    };
217    let allow_headers = settings.cors_allowed_headers.as_ref();
218    let allow_headers: Vec<_> = allow_headers
219        .map(|headers| {
220            headers
221                .iter()
222                .map(AsRef::as_ref)
223                .map(http::HeaderName::from_str)
224                .collect::<Result<_, _>>()
225                .context("failed to parse allowed header names")
226        })
227        .transpose()?
228        .unwrap_or_default();
229    let allow_headers = if allow_headers.is_empty() {
230        cors::AllowHeaders::any()
231    } else {
232        cors::AllowHeaders::list(allow_headers)
233    };
234    let allow_methods = settings.cors_allowed_methods.as_ref();
235    let allow_methods: Vec<_> = allow_methods
236        .map(|methods| {
237            methods
238                .iter()
239                .map(AsRef::as_ref)
240                .map(http::Method::from_str)
241                .collect::<Result<_, _>>()
242                .context("failed to parse allowed methods")
243        })
244        .transpose()?
245        .unwrap_or_default();
246    let allow_methods = if allow_methods.is_empty() {
247        cors::AllowMethods::any()
248    } else {
249        cors::AllowMethods::list(allow_methods)
250    };
251    let expose_headers = settings.cors_exposed_headers.as_ref();
252    let expose_headers: Vec<_> = expose_headers
253        .map(|headers| {
254            headers
255                .iter()
256                .map(AsRef::as_ref)
257                .map(http::HeaderName::from_str)
258                .collect::<Result<_, _>>()
259                .context("failed to parse exposeed header names")
260        })
261        .transpose()?
262        .unwrap_or_default();
263    let expose_headers = if expose_headers.is_empty() {
264        cors::ExposeHeaders::any()
265    } else {
266        cors::ExposeHeaders::list(expose_headers)
267    };
268    let mut cors = CorsLayer::new()
269        .allow_origin(allow_origin)
270        .allow_headers(allow_headers)
271        .allow_methods(allow_methods)
272        .expose_headers(expose_headers);
273    if let Some(max_age) = settings.cors_max_age_secs {
274        cors = cors.max_age(Duration::from_secs(max_age));
275    }
276
277    Ok(cors)
278}
279
280/// Helper function to create and listen on a [`TcpListener`] from the given [`ServiceSettings`].
281///
282/// Note that this function actually calls the `bind` method on the [`TcpSocket`], it's up to the
283/// caller to ensure that the address is not already in use (or to handle the error if it is).
284pub(crate) fn get_tcp_listener(settings: &ServiceSettings) -> anyhow::Result<TcpListener> {
285    let socket = match &settings.address {
286        SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(),
287        SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(),
288    }
289    .context("Unable to open socket")?;
290    // Copied this option from
291    // https://github.com/bytecodealliance/wasmtime/blob/05095c18680927ce0cf6c7b468f9569ec4d11bd7/src/commands/serve.rs#L319.
292    // This does increase throughput by 10-15% which is why we're creating the socket. We're
293    // using the tokio one because it exposes the `reuseaddr` option.
294    socket
295        .set_reuseaddr(!cfg!(windows))
296        .context("Error when setting socket to reuseaddr")?;
297    socket
298        .set_nodelay(true)
299        .context("failed to set `TCP_NODELAY`")?;
300
301    match settings.disable_keepalive {
302        Some(false) => {
303            info!("disabling TCP keepalive");
304            socket
305                .set_keepalive(false)
306                .context("failed to disable TCP keepalive")?
307        }
308        None | Some(true) => socket
309            .set_keepalive(true)
310            .context("failed to enable TCP keepalive")?,
311    }
312
313    socket
314        .bind(settings.address)
315        .context("Unable to bind to address")?;
316    let listener = socket.listen(1024).context("unable to listen on socket")?;
317    let listener = listener.into_std().context("Unable to get listener")?;
318
319    Ok(listener)
320}
321
pin_project! {
    /// Response body returned to the HTTP client for a component invocation.
    ///
    /// Bundles the body returned by the component with two other pieces, so that the
    /// `http_body::Body` implementation can surface failures from either of them as body
    /// errors:
    ///
    /// - the stream of body-processing errors;
    /// - the optional spawned task that drives the invocation's async I/O.
    struct ResponseBody {
        // Body returned by the component; frames are forwarded to the client as-is
        #[pin]
        body: wrpc_interface_http::HttpBody,
        // Errors reported while processing the body; any item here is surfaced as a body error
        #[pin]
        errors: Box<dyn Stream<Item = wrpc_interface_http::HttpBodyError<axum::Error>> + Send + Unpin>,
        // Spawned task driving the invocation's async I/O, if the invocation produced one
        #[pin]
        io: Option<JoinHandle<anyhow::Result<()>>>,
    }
}
332
333impl http_body::Body for ResponseBody {
334    type Data = Bytes;
335    type Error = anyhow::Error;
336
337    fn poll_frame(
338        mut self: Pin<&mut Self>,
339        cx: &mut Context<'_>,
340    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
341        let mut this = self.as_mut().project();
342        if let Some(io) = this.io.as_mut().as_pin_mut() {
343            match io.poll(cx) {
344                Poll::Ready(Ok(Ok(()))) => {
345                    this.io.take();
346                }
347                Poll::Ready(Ok(Err(err))) => {
348                    return Poll::Ready(Some(Err(
349                        anyhow!(err).context("failed to complete async I/O")
350                    )))
351                }
352                Poll::Ready(Err(err)) => {
353                    return Poll::Ready(Some(Err(anyhow!(err).context("I/O task failed"))))
354                }
355                Poll::Pending => {}
356            }
357        }
358        match this.errors.poll_next(cx) {
359            Poll::Ready(Some(err)) => {
360                if let Some(io) = this.io.as_pin_mut() {
361                    io.abort();
362                }
363                return Poll::Ready(Some(Err(anyhow!(err).context("failed to process body"))));
364            }
365            Poll::Ready(None) | Poll::Pending => {}
366        }
367        match ready!(this.body.poll_frame(cx)) {
368            Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
369            Some(Err(err)) => {
370                if let Some(io) = this.io.as_pin_mut() {
371                    io.abort();
372                }
373                Poll::Ready(Some(Err(err)))
374            }
375            None => {
376                if let Some(io) = this.io.as_pin_mut() {
377                    io.abort();
378                }
379                Poll::Ready(None)
380            }
381        }
382    }
383}
384
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use anyhow::Result;
    use futures::StreamExt;
    use wasmcloud_provider_sdk::{
        provider::initialize_host_data, run_provider, HostData, InterfaceLinkDefinition,
    };
    use wasmcloud_test_util::testcontainers::{AsyncRunner, NatsServer};

    use crate::{address, path};

    /// Identifier of the provider under test: its key, its name, and the source of its links.
    const PROVIDER_ID: &str = "http-server-provider-test";

    /// Collects borrowed key/value pairs into an owned `String` map.
    fn string_map<const N: usize>(pairs: [(&str, &str); N]) -> HashMap<String, String> {
        pairs
            .into_iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// Builds a `wasi:http/incoming-handler` link from the provider under test to `target`.
    fn incoming_handler_link(
        target: &str,
        source_config: HashMap<String, String>,
    ) -> InterfaceLinkDefinition {
        InterfaceLinkDefinition {
            source_id: PROVIDER_ID.to_string(),
            target: target.to_string(),
            name: "default".to_string(),
            wit_namespace: "wasi".to_string(),
            wit_package: "http".to_string(),
            interfaces: vec!["incoming-handler".to_string()],
            source_config,
            target_config: HashMap::new(),
            source_secrets: None,
            target_secrets: None,
        }
    }

    /// GETs `url`, expects the invocation to time out (408), and then expects `subscriber`
    /// to observe a message whose subject mentions `component`.
    async fn assert_times_out_and_reaches(
        url: &str,
        subscriber: &mut async_nats::Subscriber,
        component: &str,
    ) {
        let resp = reqwest::get(url)
            .await
            .expect("should be able to make request");
        // Nothing answers the invocation, so the configured timeout must fire
        assert_eq!(resp.status(), 408);
        let msg = subscriber
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains(component));
    }

    // Ignored by default because it needs a container runtime to start the NATS
    // testcontainer; in GitHub Actions CI this only works on `linux`.
    #[ignore]
    #[tokio::test]
    async fn can_listen_and_invoke_with_timeout() -> Result<()> {
        let nats_container = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats_container
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_address = format!("nats://127.0.0.1:{nats_port}");

        let host_data = HostData {
            lattice_rpc_url: nats_address.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: PROVIDER_ID.to_string(),
            config: string_map([
                ("default_address", "0.0.0.0:8080"),
                ("routing_mode", "address"),
            ]),
            link_definitions: vec![incoming_handler_link(
                "test-component",
                string_map([("timeout_ms", "100")]),
            )],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            address::HttpServerProvider::new(&host_data)
                .expect("should be able to create provider"),
            PROVIDER_ID,
        )
        .await
        .expect("should be able to run provider");

        // Observe the invocations the provider sends to the component over NATS
        let conn = async_nats::connect(nats_address)
            .await
            .expect("should be able to connect");
        let mut subscriber = conn
            .subscribe("lattice.test-component.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_handle = tokio::spawn(provider);

        // Give the provider a second to set up its listener
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        assert_times_out_and_reaches("http://127.0.0.1:8080", &mut subscriber, "test-component")
            .await;

        provider_handle.abort();
        let _ = nats_container.stop().await;

        Ok(())
    }

    // Ignored by default because it needs a container runtime to start the NATS
    // testcontainer; in GitHub Actions CI this only works on `linux`.
    #[ignore]
    #[tokio::test]
    async fn can_support_path_based_routing() -> Result<()> {
        let nats_container = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats_container
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_address = format!("nats://127.0.0.1:{nats_port}");

        let host_data = HostData {
            lattice_rpc_url: nats_address.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: PROVIDER_ID.to_string(),
            config: string_map([
                ("default_address", "0.0.0.0:8081"),
                ("routing_mode", "path"),
                ("timeout_ms", "100"),
            ]),
            link_definitions: vec![
                incoming_handler_link("test-component-one", string_map([("path", "/foo")])),
                incoming_handler_link("test-component-two", string_map([("path", "/bar")])),
            ],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            path::HttpServerProvider::new(&host_data)
                .await
                .expect("should be able to create provider"),
            PROVIDER_ID,
        )
        .await
        .expect("should be able to run provider");

        // Observe the invocations the provider sends to each component over NATS
        let conn = async_nats::connect(nats_address)
            .await
            .expect("should be able to connect");
        let mut subscriber_one = conn
            .subscribe("lattice.test-component-one.wrpc.>")
            .await
            .expect("should be able to subscribe");
        let mut subscriber_two = conn
            .subscribe("lattice.test-component-two.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_handle = tokio::spawn(provider);
        // Give the provider a second to set up its listeners
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // `/foo` routes to component one
        assert_times_out_and_reaches(
            "http://127.0.0.1:8081/foo",
            &mut subscriber_one,
            "test-component-one",
        )
        .await;
        // `/bar` routes to component two
        assert_times_out_and_reaches(
            "http://127.0.0.1:8081/bar",
            &mut subscriber_two,
            "test-component-two",
        )
        .await;
        // A query string does not affect routing
        assert_times_out_and_reaches(
            "http://127.0.0.1:8081/bar?someparam=foo",
            &mut subscriber_two,
            "test-component-two",
        )
        .await;

        // Unknown path should return 404
        let resp = reqwest::get("http://127.0.0.1:8081/some/other/route/idk")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 404);

        // No other messages may arrive: waiting on either subscriber must time out
        let quiet_period = tokio::time::Duration::from_secs(1);
        assert!(tokio::time::timeout(quiet_period, subscriber_one.next())
            .await
            .is_err());
        assert!(tokio::time::timeout(quiet_period, subscriber_two.next())
            .await
            .is_err());

        provider_handle.abort();
        let _ = nats_container.stop().await;

        Ok(())
    }
}