Skip to main content

tower_http/cors/allow_private_network.rs

1use std::{fmt, sync::Arc};
2
3use http::{
4    header::{HeaderName, HeaderValue},
5    request::Parts as RequestParts,
6};
7
8/// Holds configuration for how to set the [`Access-Control-Allow-Private-Network`][wicg] header.
9///
10/// See [`CorsLayer::allow_private_network`] for more details.
11///
12/// [wicg]: https://wicg.github.io/private-network-access/
13/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
14#[derive(Clone, Default)]
15#[must_use]
16pub struct AllowPrivateNetwork(AllowPrivateNetworkInner);
17
18impl AllowPrivateNetwork {
19    /// Allow requests via a more private network than the one used to access the origin
20    ///
21    /// See [`CorsLayer::allow_private_network`] for more details.
22    ///
23    /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
24    pub fn yes() -> Self {
25        Self(AllowPrivateNetworkInner::Yes)
26    }
27
28    /// Allow requests via private network for some requests, based on a given predicate
29    ///
30    /// The first argument to the predicate is the request origin.
31    ///
32    /// See [`CorsLayer::allow_private_network`] for more details.
33    ///
34    /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
35    pub fn predicate<F>(f: F) -> Self
36    where
37        F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
38    {
39        Self(AllowPrivateNetworkInner::Predicate(Arc::new(f)))
40    }
41
42    #[allow(
43        clippy::declare_interior_mutable_const,
44        clippy::borrow_interior_mutable_const
45    )]
46    pub(super) fn to_header(
47        &self,
48        origin: Option<&HeaderValue>,
49        parts: &RequestParts,
50    ) -> Option<(HeaderName, HeaderValue)> {
51        #[allow(clippy::declare_interior_mutable_const)]
52        const REQUEST_PRIVATE_NETWORK: HeaderName =
53            HeaderName::from_static("access-control-request-private-network");
54
55        #[allow(clippy::declare_interior_mutable_const)]
56        const ALLOW_PRIVATE_NETWORK: HeaderName =
57            HeaderName::from_static("access-control-allow-private-network");
58
59        const TRUE: HeaderValue = HeaderValue::from_static("true");
60
61        // Cheapest fallback: allow_private_network hasn't been set
62        if let AllowPrivateNetworkInner::No = &self.0 {
63            return None;
64        }
65
66        // Access-Control-Allow-Private-Network is only relevant if the request
67        // has the Access-Control-Request-Private-Network header set, else skip
68        if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) {
69            return None;
70        }
71
72        let allow_private_network = match &self.0 {
73            AllowPrivateNetworkInner::Yes => true,
74            AllowPrivateNetworkInner::No => false, // unreachable, but not harmful
75            AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts),
76        };
77
78        allow_private_network.then_some((ALLOW_PRIVATE_NETWORK, TRUE))
79    }
80}
81
82impl From<bool> for AllowPrivateNetwork {
83    fn from(v: bool) -> Self {
84        match v {
85            true => Self(AllowPrivateNetworkInner::Yes),
86            false => Self(AllowPrivateNetworkInner::No),
87        }
88    }
89}
90
91impl fmt::Debug for AllowPrivateNetwork {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        match self.0 {
94            AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(),
95            AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(),
96            AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
97        }
98    }
99}
100
101#[derive(Clone, Default)]
102enum AllowPrivateNetworkInner {
103    Yes,
104    #[default]
105    No,
106    Predicate(
107        Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
108    ),
109}
110
#[cfg(test)]
mod tests {
    // The header-name/value consts below, and borrowing them in assertions, trip clippy's
    // interior-mutability lints. They are only ever read, so the lints are silenced here.
    #![allow(
        clippy::declare_interior_mutable_const,
        clippy::borrow_interior_mutable_const
    )]

    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    use super::AllowPrivateNetwork;
    use crate::cors::CorsLayer;

    use crate::test_helpers::Body;
    use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response};
    use tower::{BoxError, ServiceBuilder, ServiceExt};
    use tower_service::Service;

    const REQUEST_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-request-private-network");

    const ALLOW_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-allow-private-network");

    const TRUE: HeaderValue = HeaderValue::from_static("true");

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(true))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_is_not_added_when_disabled() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(false))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_requires_true_request_value() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(true))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, "false")
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly_with_predicate() {
        let allow_private_network =
            AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| {
                parts.uri.path() == "/allow-private" && origin == "localhost"
            });
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(allow_private_network))
            .service_fn(echo);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/other")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());

        let req = Request::builder()
            .header(ORIGIN, "not-localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_predicate_is_skipped_without_origin() {
        // Record whether the predicate ever runs. Without an `Origin` header it must not.
        let called = Arc::new(AtomicBool::new(false));
        let called_in_predicate = Arc::clone(&called);
        let allow_private_network =
            AllowPrivateNetwork::predicate(move |_origin: &HeaderValue, _parts: &Parts| {
                called_in_predicate.store(true, Ordering::SeqCst);
                true
            });
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(allow_private_network))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
        assert!(!called.load(Ordering::SeqCst));
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}