wasmcloud_provider_http_server/
lib.rs

1//! The httpserver capability provider allows wasmcloud components to receive
2//! and process http(s) messages from web browsers, command-line tools
3//! such as curl, and other http clients. The server is fully asynchronous,
4//! and built on Rust's high-performance axum library, which is in turn based
5//! on hyper, and can process a large number of simultaneous connections.
6//!
7//! ## Features:
8//!
9//! - HTTP/1 and HTTP/2
10//! - TLS
//! - CORS support (configurable `allowed_origins`, `allowed_methods`,
//!   `allowed_headers`, exposed headers and max age). CORS has sensible
//!   defaults, so it should work as-is for development, but may need
//!   refinement for production if a more restrictive configuration is required.
15//! - All settings can be specified at runtime, using per-component link settings:
16//!   - bind path/address
17//!   - TLS
18//!   - Cors
19//! - Flexible configuration loading: from host, or from local toml or json file.
20//! - Fully asynchronous, using tokio lightweight "green" threads
21//! - Thread pool (for managing a pool of OS threads). The default
22//!   thread pool has one thread per cpu core.
23//!
24
25use core::future::Future;
26use core::pin::Pin;
27use core::str::FromStr as _;
28use core::task::{ready, Context, Poll};
29use core::time::Duration;
30
31use std::net::{SocketAddr, TcpListener};
32
33use anyhow::{anyhow, bail, Context as _};
34use axum::extract;
35use bytes::Bytes;
36use futures::Stream;
37use pin_project_lite::pin_project;
38use tokio::task::JoinHandle;
39use tokio::{spawn, time};
40use tower_http::cors::{self, CorsLayer};
41use tracing::{debug, info, trace};
42use wasmcloud_provider_sdk::provider::WrpcClient;
43use wasmcloud_provider_sdk::{initialize_observability, load_host_data, run_provider};
44use wrpc_interface_http::InvokeIncomingHandler as _;
45
46mod address;
47mod path;
48mod settings;
49pub use settings::{default_listen_address, load_settings, ServiceSettings};
50
51pub async fn run() -> anyhow::Result<()> {
52    initialize_observability!(
53        "http-server-provider",
54        std::env::var_os("PROVIDER_HTTP_SERVER_FLAMEGRAPH_PATH")
55    );
56
57    let host_data = load_host_data().context("failed to load host data")?;
58    match host_data.config.get("routing_mode").map(String::as_str) {
59        // Run provider in address mode by default
60        Some("address") | None => run_provider(
61            address::HttpServerProvider::new(host_data).context(
62                "failed to create address-mode HTTP server provider from hostdata configuration",
63            )?,
64            "http-server-provider",
65        )
66        .await?
67        .await,
68        // Run provider in path mode
69        Some("path") => {
70            run_provider(
71                path::HttpServerProvider::new(host_data).await.context(
72                    "failed to create path-mode HTTP server provider from hostdata configuration",
73                )?,
74                "http-server-provider",
75            )
76            .await?
77            .await;
78        }
79        Some(other) => bail!("unknown routing_mode: {other}"),
80    };
81
82    Ok(())
83}
84
85/// Build a request to send to the component from the incoming request
86pub(crate) fn build_request(
87    request: extract::Request,
88    scheme: http::uri::Scheme,
89    authority: String,
90    settings: &ServiceSettings,
91) -> Result<http::Request<axum::body::Body>, axum::response::ErrorResponse> {
92    let method = request.method();
93    if let Some(readonly_mode) = settings.readonly_mode {
94        if readonly_mode
95            && method != http::method::Method::GET
96            && method != http::method::Method::HEAD
97        {
98            debug!("only GET and HEAD allowed in read-only mode");
99            Err((
100                http::StatusCode::METHOD_NOT_ALLOWED,
101                "only GET and HEAD allowed in read-only mode",
102            ))?;
103        }
104    }
105    let (
106        http::request::Parts {
107            method,
108            uri,
109            headers,
110            ..
111        },
112        body,
113    ) = request.into_parts();
114    let http::uri::Parts { path_and_query, .. } = uri.into_parts();
115
116    let mut uri = http::Uri::builder().scheme(scheme);
117    if !authority.is_empty() {
118        uri = uri.authority(authority);
119    }
120    if let Some(path_and_query) = path_and_query {
121        uri = uri.path_and_query(path_and_query);
122    }
123    let uri = uri
124        .build()
125        .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
126    let mut req = http::Request::builder();
127    *req.headers_mut().ok_or((
128        http::StatusCode::INTERNAL_SERVER_ERROR,
129        "invalid request generated",
130    ))? = headers;
131    let req = req
132        .uri(uri)
133        .method(method)
134        .body(body)
135        .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
136
137    Ok(req)
138}
139
140/// Invoke a component with the given request
141pub(crate) async fn invoke_component(
142    wrpc: &WrpcClient,
143    target: &str,
144    req: http::Request<axum::body::Body>,
145    timeout: Option<Duration>,
146    cache_control: Option<&String>,
147) -> impl axum::response::IntoResponse {
148    // Create a new wRPC client with all headers from the current span injected
149    let mut cx = async_nats::HeaderMap::new();
150    for (k, v) in
151        wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::new_with_extractor(
152            &wasmcloud_provider_sdk::wasmcloud_tracing::http::HeaderExtractor(req.headers()),
153        )
154        .iter()
155    {
156        cx.insert(k.as_str(), v.as_str());
157    }
158
159    trace!(?req, component_id = target, "httpserver calling component");
160    let fut = wrpc.invoke_handle_http(Some(cx), req);
161    let res = if let Some(timeout) = timeout {
162        let Ok(res) = time::timeout(timeout, fut).await else {
163            Err(http::StatusCode::REQUEST_TIMEOUT)?
164        };
165        res
166    } else {
167        fut.await
168    };
169    let (res, errors, io) =
170        res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:#}")))?;
171    let io = io.map(spawn);
172    let errors: Box<dyn Stream<Item = _> + Send + Unpin> = Box::new(errors);
173    // TODO: Convert this to http status code
174    let mut res =
175        res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:?}")))?;
176    if let Some(cache_control) = cache_control {
177        let cache_control = http::HeaderValue::from_str(cache_control)
178            .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
179        res.headers_mut().append("Cache-Control", cache_control);
180    };
181    axum::response::Result::<_, axum::response::ErrorResponse>::Ok(res.map(|body| ResponseBody {
182        body,
183        errors,
184        io,
185    }))
186}
187
188/// Helper function to construct a [`CorsLayer`] according to the [`ServiceSettings`].
189pub(crate) fn get_cors_layer(settings: &ServiceSettings) -> anyhow::Result<CorsLayer> {
190    let allow_origin = settings.cors_allowed_origins.as_ref();
191    let allow_origin: Vec<_> = allow_origin
192        .map(|origins| {
193            origins
194                .iter()
195                .map(AsRef::as_ref)
196                .map(http::HeaderValue::from_str)
197                .collect::<Result<_, _>>()
198                .context("failed to parse allowed origins")
199        })
200        .transpose()?
201        .unwrap_or_default();
202    let allow_origin = if allow_origin.is_empty() {
203        cors::AllowOrigin::any()
204    } else {
205        cors::AllowOrigin::list(allow_origin)
206    };
207    let allow_headers = settings.cors_allowed_headers.as_ref();
208    let allow_headers: Vec<_> = allow_headers
209        .map(|headers| {
210            headers
211                .iter()
212                .map(AsRef::as_ref)
213                .map(http::HeaderName::from_str)
214                .collect::<Result<_, _>>()
215                .context("failed to parse allowed header names")
216        })
217        .transpose()?
218        .unwrap_or_default();
219    let allow_headers = if allow_headers.is_empty() {
220        cors::AllowHeaders::any()
221    } else {
222        cors::AllowHeaders::list(allow_headers)
223    };
224    let allow_methods = settings.cors_allowed_methods.as_ref();
225    let allow_methods: Vec<_> = allow_methods
226        .map(|methods| {
227            methods
228                .iter()
229                .map(AsRef::as_ref)
230                .map(http::Method::from_str)
231                .collect::<Result<_, _>>()
232                .context("failed to parse allowed methods")
233        })
234        .transpose()?
235        .unwrap_or_default();
236    let allow_methods = if allow_methods.is_empty() {
237        cors::AllowMethods::any()
238    } else {
239        cors::AllowMethods::list(allow_methods)
240    };
241    let expose_headers = settings.cors_exposed_headers.as_ref();
242    let expose_headers: Vec<_> = expose_headers
243        .map(|headers| {
244            headers
245                .iter()
246                .map(AsRef::as_ref)
247                .map(http::HeaderName::from_str)
248                .collect::<Result<_, _>>()
249                .context("failed to parse exposeed header names")
250        })
251        .transpose()?
252        .unwrap_or_default();
253    let expose_headers = if expose_headers.is_empty() {
254        cors::ExposeHeaders::any()
255    } else {
256        cors::ExposeHeaders::list(expose_headers)
257    };
258    let mut cors = CorsLayer::new()
259        .allow_origin(allow_origin)
260        .allow_headers(allow_headers)
261        .allow_methods(allow_methods)
262        .expose_headers(expose_headers);
263    if let Some(max_age) = settings.cors_max_age_secs {
264        cors = cors.max_age(Duration::from_secs(max_age));
265    }
266
267    Ok(cors)
268}
269
270/// Helper function to create and listen on a [`TcpListener`] from the given [`ServiceSettings`].
271///
272/// Note that this function actually calls the `bind` method on the [`TcpSocket`], it's up to the
273/// caller to ensure that the address is not already in use (or to handle the error if it is).
274pub(crate) fn get_tcp_listener(settings: &ServiceSettings) -> anyhow::Result<TcpListener> {
275    let socket = match &settings.address {
276        SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(),
277        SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(),
278    }
279    .context("Unable to open socket")?;
280    // Copied this option from
281    // https://github.com/bytecodealliance/wasmtime/blob/05095c18680927ce0cf6c7b468f9569ec4d11bd7/src/commands/serve.rs#L319.
282    // This does increase throughput by 10-15% which is why we're creating the socket. We're
283    // using the tokio one because it exposes the `reuseaddr` option.
284    socket
285        .set_reuseaddr(!cfg!(windows))
286        .context("Error when setting socket to reuseaddr")?;
287    socket
288        .set_nodelay(true)
289        .context("failed to set `TCP_NODELAY`")?;
290
291    match settings.disable_keepalive {
292        Some(false) => {
293            info!("disabling TCP keepalive");
294            socket
295                .set_keepalive(false)
296                .context("failed to disable TCP keepalive")?
297        }
298        None | Some(true) => socket
299            .set_keepalive(true)
300            .context("failed to enable TCP keepalive")?,
301    }
302
303    socket
304        .bind(settings.address)
305        .context("Unable to bind to address")?;
306    let listener = socket.listen(1024).context("unable to listen on socket")?;
307    let listener = listener.into_std().context("Unable to get listener")?;
308
309    Ok(listener)
310}
311
pin_project! {
    /// Response body handed back to the HTTP client for a component invocation.
    ///
    /// Bundles the component's response body with the invocation's error stream and the
    /// optional spawned task driving its async I/O, so that failures from either can be
    /// reported to the client as body errors.
    struct ResponseBody {
        // Response body received from the component.
        #[pin]
        body: wrpc_interface_http::HttpBody,
        // Errors reported for the body while it is being transferred.
        #[pin]
        errors: Box<dyn Stream<Item = wrpc_interface_http::HttpBodyError<axum::Error>> + Send + Unpin>,
        // Handle of the spawned task performing the invocation's async I/O, if there is one.
        #[pin]
        io: Option<JoinHandle<anyhow::Result<()>>>,
    }
}
322
323impl http_body::Body for ResponseBody {
324    type Data = Bytes;
325    type Error = anyhow::Error;
326
327    fn poll_frame(
328        mut self: Pin<&mut Self>,
329        cx: &mut Context<'_>,
330    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
331        let mut this = self.as_mut().project();
332        if let Some(io) = this.io.as_mut().as_pin_mut() {
333            match io.poll(cx) {
334                Poll::Ready(Ok(Ok(()))) => {
335                    this.io.take();
336                }
337                Poll::Ready(Ok(Err(err))) => {
338                    return Poll::Ready(Some(Err(
339                        anyhow!(err).context("failed to complete async I/O")
340                    )))
341                }
342                Poll::Ready(Err(err)) => {
343                    return Poll::Ready(Some(Err(anyhow!(err).context("I/O task failed"))))
344                }
345                Poll::Pending => {}
346            }
347        }
348        match this.errors.poll_next(cx) {
349            Poll::Ready(Some(err)) => {
350                if let Some(io) = this.io.as_pin_mut() {
351                    io.abort();
352                }
353                return Poll::Ready(Some(Err(anyhow!(err).context("failed to process body"))));
354            }
355            Poll::Ready(None) | Poll::Pending => {}
356        }
357        match ready!(this.body.poll_frame(cx)) {
358            Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
359            Some(Err(err)) => {
360                if let Some(io) = this.io.as_pin_mut() {
361                    io.abort();
362                }
363                Poll::Ready(Some(Err(err)))
364            }
365            None => {
366                if let Some(io) = this.io.as_pin_mut() {
367                    io.abort();
368                }
369                Poll::Ready(None)
370            }
371        }
372    }
373}
374
#[cfg(test)]
mod test {
    // Integration tests that run a real provider against a NATS server started through
    // testcontainers and drive it with HTTP requests from `reqwest`. Nothing replies to the
    // wRPC invocations the provider publishes, so every request routed to a component is
    // expected to hit the configured `timeout_ms` (100ms) and come back as
    // `408 Request Timeout`.
    use std::collections::HashMap;

    use anyhow::Result;
    use futures::StreamExt;
    use wasmcloud_provider_sdk::{
        provider::initialize_host_data, run_provider, HostData, InterfaceLinkDefinition,
    };
    use wasmcloud_test_util::testcontainers::{AsyncRunner, NatsServer};

    use crate::{address, path};

    // This test is ignored by default as it requires a container runtime to be installed
    // to run the testcontainer. In GitHub Actions CI, this only works on Linux.
    //
    // Address mode: one link, whose `timeout_ms` comes from the link's source config.
    #[ignore]
    #[tokio::test]
    async fn can_listen_and_invoke_with_timeout() -> Result<()> {
        let nats_container = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats_container
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_address = format!("nats://127.0.0.1:{nats_port}");

        let default_address = "0.0.0.0:8080";
        let host_data = HostData {
            lattice_rpc_url: nats_address.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: "http-server-provider-test".to_string(),
            config: std::collections::HashMap::from([
                ("default_address".to_string(), default_address.to_string()),
                ("routing_mode".to_string(), "address".to_string()),
            ]),
            link_definitions: vec![InterfaceLinkDefinition {
                source_id: "http-server-provider-test".to_string(),
                target: "test-component".to_string(),
                name: "default".to_string(),
                wit_namespace: "wasi".to_string(),
                wit_package: "http".to_string(),
                interfaces: vec!["incoming-handler".to_string()],
                source_config: std::collections::HashMap::from([(
                    "timeout_ms".to_string(),
                    "100".to_string(),
                )]),
                target_config: HashMap::new(),
                source_secrets: None,
                target_secrets: None,
            }],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            address::HttpServerProvider::new(&host_data)
                .expect("should be able to create provider"),
            "http-server-provider-test",
        )
        .await
        .expect("should be able to run provider");

        // Subscribe to the component's wRPC subjects so the test can observe the invocation
        let conn = async_nats::connect(nats_address)
            .await
            .expect("should be able to connect");
        let mut subscriber = conn
            .subscribe("lattice.test-component.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_handle = tokio::spawn(provider);

        // Let the provider have a second to setup the listener
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let resp = reqwest::get("http://127.0.0.1:8080")
            .await
            .expect("should be able to make request");

        // Should have timed out: nothing answers the invocation within `timeout_ms`
        assert_eq!(resp.status(), 408);
        // Ensure component received the message
        let msg = subscriber
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component"));
        provider_handle.abort();
        let _ = nats_container.stop().await;

        Ok(())
    }

    // This test is ignored by default as it requires a container runtime to be installed
    // to run the testcontainer. In GitHub Actions CI, this only works on Linux.
    //
    // Path mode: two links routed by path (`/foo`, `/bar`) on one listener. `timeout_ms` comes
    // from the provider's host config. Unknown paths must return 404 without invoking anything.
    #[ignore]
    #[tokio::test]
    async fn can_support_path_based_routing() -> Result<()> {
        let nats_container = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats_container
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_address = format!("nats://127.0.0.1:{nats_port}");

        let default_address = "0.0.0.0:8081";
        let host_data = HostData {
            lattice_rpc_url: nats_address.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: "http-server-provider-test".to_string(),
            config: std::collections::HashMap::from([
                ("default_address".to_string(), default_address.to_string()),
                ("routing_mode".to_string(), "path".to_string()),
                ("timeout_ms".to_string(), "100".to_string()),
            ]),
            link_definitions: vec![
                InterfaceLinkDefinition {
                    source_id: "http-server-provider-test".to_string(),
                    target: "test-component-one".to_string(),
                    name: "default".to_string(),
                    wit_namespace: "wasi".to_string(),
                    wit_package: "http".to_string(),
                    interfaces: vec!["incoming-handler".to_string()],
                    source_config: std::collections::HashMap::from([(
                        "path".to_string(),
                        "/foo".to_string(),
                    )]),
                    target_config: HashMap::new(),
                    source_secrets: None,
                    target_secrets: None,
                },
                InterfaceLinkDefinition {
                    source_id: "http-server-provider-test".to_string(),
                    target: "test-component-two".to_string(),
                    name: "default".to_string(),
                    wit_namespace: "wasi".to_string(),
                    wit_package: "http".to_string(),
                    interfaces: vec!["incoming-handler".to_string()],
                    source_config: std::collections::HashMap::from([(
                        "path".to_string(),
                        "/bar".to_string(),
                    )]),
                    target_config: HashMap::new(),
                    source_secrets: None,
                    target_secrets: None,
                },
            ],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            path::HttpServerProvider::new(&host_data)
                .await
                .expect("should be able to create provider"),
            "http-server-provider-test",
        )
        .await
        .expect("should be able to run provider");

        // Subscribe to each component's wRPC subjects so the test can observe which one is invoked
        let conn = async_nats::connect(nats_address)
            .await
            .expect("should be able to connect");
        let mut subscriber_one = conn
            .subscribe("lattice.test-component-one.wrpc.>")
            .await
            .expect("should be able to subscribe");
        let mut subscriber_two = conn
            .subscribe("lattice.test-component-two.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_handle = tokio::spawn(provider);
        // Let the provider have a second to setup the listeners
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // Invoke component one
        let resp = reqwest::get("http://127.0.0.1:8081/foo")
            .await
            .expect("should be able to make request");
        // Should have timed out: nothing answers the invocation within `timeout_ms`
        assert_eq!(resp.status(), 408);
        let msg = subscriber_one
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-one"));

        // Invoke component two
        let resp = reqwest::get("http://127.0.0.1:8081/bar")
            .await
            .expect("should be able to make request");
        // Should have timed out
        assert_eq!(resp.status(), 408);
        let msg = subscriber_two
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-two"));

        // Invoke component two with a query parameter (the query must not affect routing)
        let resp = reqwest::get("http://127.0.0.1:8081/bar?someparam=foo")
            .await
            .expect("should be able to make request");
        // Should have timed out
        assert_eq!(resp.status(), 408);
        let msg = subscriber_two
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-two"));

        // Unknown path should return 404
        let resp = reqwest::get("http://127.0.0.1:8081/some/other/route/idk")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 404);

        // No other messages should have been received
        // (the assertion is that the operation timed out)
        assert!(
            tokio::time::timeout(tokio::time::Duration::from_secs(1), subscriber_one.next())
                .await
                .is_err(),
        );
        assert!(
            tokio::time::timeout(tokio::time::Duration::from_secs(1), subscriber_two.next())
                .await
                .is_err(),
        );

        provider_handle.abort();
        let _ = nats_container.stop().await;

        Ok(())
    }
}