1#![warn(clippy::all, clippy::nursery)]
2#[cfg(feature = "axum")]
10mod axum;
11pub mod csp;
12pub mod headers;
13
14#[cfg(test)]
15mod tests;
16
17use std::{
18 future::Future,
19 sync::Arc,
20 task::{Context, Poll},
21};
22
23use futures_util::future::BoxFuture;
24use http::{
25 header::{CONTENT_SECURITY_POLICY, CONTENT_SECURITY_POLICY_REPORT_ONLY},
26 HeaderMap, HeaderName, HeaderValue, Request, Response,
27};
28use rand::{distributions::Alphanumeric, Rng};
29use tower_layer::Layer;
30use tower_service::Service;
31
32use crate::{
33 csp::{CspNonce, BAD_CSP_MESSAGE},
34 headers::{
35 ContentSecurityPolicy, CrossOriginEmbedderPolicy, CrossOriginOpenerPolicy,
36 CrossOriginResourcePolicy, Header, OriginAgentCluster, ReferrerPolicy,
37 StrictTransportSecurity, XContentTypeOptions, XDnsPrefetchControl, XDownloadOptions,
38 XFrameOptions, XPermittedCrossDomainPolicies, XXssProtection,
39 },
40};
41
/// Configuration of the security headers added to every response.
///
/// Each field holds the value of one response header; `None` means that header
/// is not added. Start from [`Sombrero::default`] (a preset) or
/// [`Sombrero::new_empty`] (nothing enabled) and adjust with the builder methods.
/// `Sombrero` implements [`Layer`], wrapping a service in a [`SombreroService`].
#[derive(Debug, Clone)]
pub struct Sombrero {
    // The two CSP policies are rendered per request with a fresh nonce. They are
    // stored behind `Arc`, so cloning `Sombrero` (done once per request in
    // `SombreroService::call`) only bumps a reference count for them.
    content_security_policy: Option<Arc<ContentSecurityPolicy>>,
    content_security_policy_report_only: Option<Arc<ContentSecurityPolicy>>,
    cross_origin_embedder_policy: Option<CrossOriginEmbedderPolicy>,
    cross_origin_opener_policy: Option<CrossOriginOpenerPolicy>,
    cross_origin_resource_policy: Option<CrossOriginResourcePolicy>,
    origin_agent_cluster: Option<OriginAgentCluster>,
    referrer_policy: Option<ReferrerPolicy>,
    strict_transport_security: Option<StrictTransportSecurity>,
    x_content_type_options: Option<XContentTypeOptions>,
    x_dns_prefetch_control: Option<XDnsPrefetchControl>,
    x_download_options: Option<XDownloadOptions>,
    x_frame_options: Option<XFrameOptions>,
    x_permitted_cross_domain_policies: Option<XPermittedCrossDomainPolicies>,
    x_xss_protection: Option<XXssProtection>,
}
62
// Expands to a consuming builder method named `$field` that stores
// `Some(k)` in the `$field` option and leaves every other field as it was.
macro_rules! builder_add {
    ($field:ident, $kind:ty) => {
        #[must_use]
        pub fn $field(self, k: $kind) -> Self {
            let mut updated = self;
            updated.$field = ::std::option::Option::Some(k);
            updated
        }
    };
}
74
// Expands to a consuming builder method named `$field` that wraps `k` in an
// `Arc` and stores it as `Some(..)` in the `$field` option; all other fields
// are kept unchanged.
macro_rules! builder_add_arc {
    ($field:ident, $kind:ty) => {
        #[must_use]
        pub fn $field(self, k: $kind) -> Self {
            let shared = ::std::sync::Arc::new(k);
            let mut updated = self;
            updated.$field = ::std::option::Option::Some(shared);
            updated
        }
    };
}
86
// Expands to a consuming builder method named `$name` that resets the
// `$field` option to `None`; all other fields are kept unchanged.
macro_rules! builder_remove {
    ($field:ident, $name:ident) => {
        #[must_use]
        pub fn $name(self) -> Self {
            let mut updated = self;
            updated.$field = ::std::option::Option::None;
            updated
        }
    };
}
98
99impl Sombrero {
100 pub const fn new_empty() -> Self {
101 Self {
102 content_security_policy: None,
103 content_security_policy_report_only: None,
104 cross_origin_embedder_policy: None,
105 cross_origin_opener_policy: None,
106 cross_origin_resource_policy: None,
107 origin_agent_cluster: None,
108 referrer_policy: None,
109 strict_transport_security: None,
110 x_content_type_options: None,
111 x_dns_prefetch_control: None,
112 x_download_options: None,
113 x_frame_options: None,
114 x_permitted_cross_domain_policies: None,
115 x_xss_protection: None,
116 }
117 }
118}
119
120#[rustfmt::skip]
121impl Sombrero {
122 builder_remove!(content_security_policy, remove_content_security_policy);
123 builder_remove!(content_security_policy_report_only, remove_content_security_policy_report_only);
124 builder_remove!(cross_origin_embedder_policy, remove_cross_origin_embedder_policy);
125 builder_remove!(cross_origin_opener_policy, remove_cross_origin_opener_policy);
126 builder_remove!(cross_origin_resource_policy, remove_cross_origin_resource_policy);
127 builder_remove!(origin_agent_cluster, remove_origin_agent_cluster);
128 builder_remove!(referrer_policy, remove_referrer_policy);
129 builder_remove!(strict_transport_security, remove_strict_transport_security);
130 builder_remove!(x_content_type_options, remove_x_content_type_options);
131 builder_remove!(x_dns_prefetch_control, remove_x_dns_prefetch_control);
132 builder_remove!(x_download_options, remove_x_download_options);
133 builder_remove!(x_frame_options, remove_x_frame_options);
134 builder_remove!(x_permitted_cross_domain_policies, remove_x_permitted_cross_domain_policies);
135 builder_remove!(x_xss_protection, remove_x_xss_protection);
136 builder_add_arc!(content_security_policy, ContentSecurityPolicy);
137 builder_add_arc!(content_security_policy_report_only, ContentSecurityPolicy);
138 builder_add!(cross_origin_embedder_policy, CrossOriginEmbedderPolicy);
139 builder_add!(cross_origin_opener_policy, CrossOriginOpenerPolicy);
140 builder_add!(cross_origin_resource_policy, CrossOriginResourcePolicy);
141 builder_add!(origin_agent_cluster, OriginAgentCluster);
142 builder_add!(referrer_policy, ReferrerPolicy);
143 builder_add!(strict_transport_security, StrictTransportSecurity);
144 builder_add!(x_content_type_options, XContentTypeOptions);
145 builder_add!(x_dns_prefetch_control, XDnsPrefetchControl);
146 builder_add!(x_download_options, XDownloadOptions);
147 builder_add!(x_frame_options, XFrameOptions);
148 builder_add!(x_permitted_cross_domain_policies, XPermittedCrossDomainPolicies);
149 builder_add!(x_xss_protection, XXssProtection);
150}
151
152impl Default for Sombrero {
153 fn default() -> Self {
154 Self {
155 content_security_policy: Some(Arc::new(ContentSecurityPolicy::strict_default())),
156 content_security_policy_report_only: None,
157 cross_origin_embedder_policy: None,
158 cross_origin_opener_policy: Some(CrossOriginOpenerPolicy::SameOrigin),
159 cross_origin_resource_policy: Some(CrossOriginResourcePolicy::SameOrigin),
160 origin_agent_cluster: Some(OriginAgentCluster),
161 referrer_policy: Some(ReferrerPolicy::NoReferrer),
162 strict_transport_security: Some(StrictTransportSecurity::DEFAULT),
163 x_content_type_options: Some(XContentTypeOptions),
164 x_dns_prefetch_control: None,
165 x_download_options: Some(XDownloadOptions),
166 x_frame_options: Some(XFrameOptions::Sameorigin),
167 x_permitted_cross_domain_policies: Some(XPermittedCrossDomainPolicies::None),
168 x_xss_protection: Some(XXssProtection::False),
169 }
170 }
171}
172
173impl<S> Layer<S> for Sombrero {
174 type Service = SombreroService<S>;
175
176 fn layer(&self, inner: S) -> Self::Service {
177 SombreroService {
178 sombrero: self.clone(),
179 inner,
180 }
181 }
182}
183
/// The [`Service`] produced by applying the [`Sombrero`] layer.
///
/// For each request it generates a nonce, renders the configured CSP header(s)
/// with it, stores it in the request extensions as [`CspNonce`], calls `inner`,
/// and adds the configured headers to the response.
#[derive(Debug, Clone)]
pub struct SombreroService<S> {
    // Header configuration cloned from the layer.
    sombrero: Sombrero,
    // The wrapped service.
    inner: S,
}
189
190impl<S, Body> Service<Request<Body>> for SombreroService<S>
191where
192 S: Service<Request<Body>, Response = Response<Body>>,
193 S::Future: Send + 'static,
194 S::Error: 'static,
195 Body: Send + 'static,
196{
197 type Error = S::Error;
198 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
199 type Response = Response<Body>;
200
201 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202 self.inner.poll_ready(cx)
203 }
204
205 fn call(&mut self, mut request: Request<Body>) -> Self::Future {
206 let nonce = random_string(32);
207 let csp = self
208 .sombrero
209 .content_security_policy
210 .as_ref()
211 .map(|csp| csp.value(&nonce).expect(BAD_CSP_MESSAGE));
212 let csp_ro = self
213 .sombrero
214 .content_security_policy_report_only
215 .as_ref()
216 .map(|csp| csp.value(&nonce).expect(BAD_CSP_MESSAGE));
217 request.extensions_mut().insert(CspNonce(nonce));
218
219 let future = self.inner.call(request);
220 Box::pin(sombrero_svc_middleware(
221 self.sombrero.clone(),
222 csp,
223 csp_ro,
224 future,
225 ))
226 }
227}
228
229fn add_opt_header(map: &mut HeaderMap, header: Option<impl Header>) {
230 if let Some(header) = header {
231 map.insert(header.name(), header.value());
232 }
233}
234
235fn add_opt_header_raw(
236 map: &mut HeaderMap,
237 header_name: HeaderName,
238 header_value: Option<HeaderValue>,
239) {
240 if let Some(header_value) = header_value {
241 map.insert(header_name, header_value);
242 }
243}
244
245async fn sombrero_svc_middleware<F, B, E>(
246 h: Sombrero,
247 content_security_policy: Option<HeaderValue>,
248 content_security_policy_report_only: Option<HeaderValue>,
249 response_fut: F,
250) -> Result<Response<B>, E>
251where
252 F: Future<Output = Result<Response<B>, E>> + Send,
253{
254 let mut response = response_fut.await?;
255 let m = response.headers_mut();
256 add_opt_header_raw(m, CONTENT_SECURITY_POLICY, content_security_policy);
257 add_opt_header_raw(
258 m,
259 CONTENT_SECURITY_POLICY_REPORT_ONLY,
260 content_security_policy_report_only,
261 );
262 add_opt_header(m, h.cross_origin_embedder_policy);
263 add_opt_header(m, h.cross_origin_opener_policy);
264 add_opt_header(m, h.cross_origin_resource_policy);
265 add_opt_header(m, h.origin_agent_cluster);
266 add_opt_header(m, h.referrer_policy);
267 add_opt_header(m, h.strict_transport_security);
268 add_opt_header(m, h.x_content_type_options);
269 add_opt_header(m, h.x_dns_prefetch_control);
270 add_opt_header(m, h.x_download_options);
271 add_opt_header(m, h.x_frame_options);
272 add_opt_header(m, h.x_permitted_cross_domain_policies);
273 add_opt_header(m, h.x_xss_protection);
274 Ok(response)
275}
276
277pub async fn middleware_add_raw_header<F, B, E>(
278 header_name: HeaderName,
279 header_value: HeaderValue,
280 response_fut: F,
281) -> Result<Response<B>, E>
282where
283 F: Future<Output = Result<Response<B>, E>> + Send,
284{
285 let mut response = response_fut.await?;
286 response.headers_mut().insert(header_name, header_value);
287 Ok(response)
288}
289
290pub fn random_string(length: usize) -> String {
291 let rng = rand::thread_rng();
292 rng.sample_iter(Alphanumeric)
293 .take(length)
294 .map(char::from)
295 .collect()
296}
297
/// Errors produced by this crate.
///
/// Without the `axum` feature this enum has no variants.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Converted from `axum::NonceNotFoundError`. Per its message, this is raised
    /// when the `CspNonce` extractor is used but the `Sombrero` middleware is not
    /// enabled.
    #[cfg(feature = "axum")]
    #[error("`Sombrero` middleware (required for `CspNonce` extractor) not enabled!")]
    NonceMiddlewareNotEnabled(#[from] axum::NonceNotFoundError),
}