tower_http/cors/
allow_private_network.rs1use std::{fmt, sync::Arc};
2
3use http::{
4 header::{HeaderName, HeaderValue},
5 request::Parts as RequestParts,
6};
7
8#[derive(Clone, Default)]
15#[must_use]
16pub struct AllowPrivateNetwork(AllowPrivateNetworkInner);
17
18impl AllowPrivateNetwork {
19 pub fn yes() -> Self {
25 Self(AllowPrivateNetworkInner::Yes)
26 }
27
28 pub fn predicate<F>(f: F) -> Self
36 where
37 F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
38 {
39 Self(AllowPrivateNetworkInner::Predicate(Arc::new(f)))
40 }
41
42 #[allow(
43 clippy::declare_interior_mutable_const,
44 clippy::borrow_interior_mutable_const
45 )]
46 pub(super) fn to_header(
47 &self,
48 origin: Option<&HeaderValue>,
49 parts: &RequestParts,
50 ) -> Option<(HeaderName, HeaderValue)> {
51 #[allow(clippy::declare_interior_mutable_const)]
52 const REQUEST_PRIVATE_NETWORK: HeaderName =
53 HeaderName::from_static("access-control-request-private-network");
54
55 #[allow(clippy::declare_interior_mutable_const)]
56 const ALLOW_PRIVATE_NETWORK: HeaderName =
57 HeaderName::from_static("access-control-allow-private-network");
58
59 const TRUE: HeaderValue = HeaderValue::from_static("true");
60
61 if let AllowPrivateNetworkInner::No = &self.0 {
63 return None;
64 }
65
66 if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) {
69 return None;
70 }
71
72 let allow_private_network = match &self.0 {
73 AllowPrivateNetworkInner::Yes => true,
74 AllowPrivateNetworkInner::No => false, AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts),
76 };
77
78 allow_private_network.then_some((ALLOW_PRIVATE_NETWORK, TRUE))
79 }
80}
81
82impl From<bool> for AllowPrivateNetwork {
83 fn from(v: bool) -> Self {
84 match v {
85 true => Self(AllowPrivateNetworkInner::Yes),
86 false => Self(AllowPrivateNetworkInner::No),
87 }
88 }
89}
90
91impl fmt::Debug for AllowPrivateNetwork {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 match self.0 {
94 AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(),
95 AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(),
96 AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
97 }
98 }
99}
100
101#[derive(Clone, Default)]
102enum AllowPrivateNetworkInner {
103 Yes,
104 #[default]
105 No,
106 Predicate(
107 Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
108 ),
109}
110
#[cfg(test)]
mod tests {
    // `HeaderName`/`HeaderValue` consts below trip clippy's interior-mutability lints.
    #![allow(
        clippy::declare_interior_mutable_const,
        clippy::borrow_interior_mutable_const
    )]

    use super::AllowPrivateNetwork;
    use crate::cors::CorsLayer;

    use crate::test_helpers::Body;
    use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response};
    use tower::{BoxError, ServiceBuilder, ServiceExt};
    use tower_service::Service;

    const REQUEST_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-request-private-network");

    const ALLOW_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-allow-private-network");

    const TRUE: HeaderValue = HeaderValue::from_static("true");

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(true))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        // No request header: nothing to answer.
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());

        // The request header must be exactly `true`; any other value is ignored.
        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, "false")
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_is_not_added_when_disabled() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(false))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly_with_predicate() {
        let allow_private_network =
            AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| {
                parts.uri.path() == "/allow-private" && origin == "localhost"
            });
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(allow_private_network))
            .service_fn(echo);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/other")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());

        let req = Request::builder()
            .header(ORIGIN, "not-localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());

        // Without an origin the predicate cannot be consulted, so no header is sent.
        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[test]
    fn debug_output_names_the_variant() {
        assert_eq!(format!("{:?}", AllowPrivateNetwork::yes()), "Yes");
        assert_eq!(format!("{:?}", AllowPrivateNetwork::from(false)), "No");
        assert_eq!(format!("{:?}", AllowPrivateNetwork::default()), "No");
        assert_eq!(
            format!("{:?}", AllowPrivateNetwork::predicate(|_, _| true)),
            "Predicate"
        );
    }

    /// Minimal inner service: returns the request body unchanged.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}