Skip to main content

typeway_server/
secure_headers.rs

//! Security headers middleware — adds standard security headers to every response.
//!
//! Applies a configurable set of HTTP security headers. The defaults follow
//! OWASP recommendations and can be overridden or disabled individually.
//!
//! # Example
//!
//! ```ignore
//! use typeway_server::secure_headers::SecureHeadersLayer;
//!
//! Server::<API>::new(handlers)
//!     .layer(SecureHeadersLayer::new())
//!     .serve(addr)
//!     .await?;
//! ```
//!
//! # Customization
//!
//! ```ignore
//! SecureHeadersLayer::new()
//!     .hsts(63_072_000)                          // enable HSTS (TLS only)
//!     .frame_options("SAMEORIGIN")               // allow same-origin framing
//!     .content_security_policy("default-src 'self'; script-src 'self' cdn.example.com")
//!     .custom("X-Custom-Header", "value")
//! ```
26
27use std::convert::Infallible;
28use std::future::Future;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31
32use crate::body::BoxBody;
33
/// A Tower layer that adds security headers to every response.
///
/// Created via [`SecureHeadersLayer::new()`], which sets sensible defaults.
/// Individual headers can be overridden with the builder methods, removed with
/// [`SecureHeadersLayer::remove`], or extended with [`SecureHeadersLayer::custom`].
///
/// Header names are stored lowercase, so the name-taking builder methods
/// treat names case-insensitively. Names or values that are not valid HTTP
/// header syntax are silently skipped when the layer is applied to a service.
#[derive(Clone, Debug)]
pub struct SecureHeadersLayer {
    /// Lowercase header name / value pairs, in insertion order.
    headers: Vec<(String, String)>,
}

impl SecureHeadersLayer {
    /// Create a new layer with all default security headers enabled.
    ///
    /// Default headers:
    /// - `X-Content-Type-Options: nosniff`
    /// - `X-Frame-Options: DENY`
    /// - `X-XSS-Protection: 0`
    /// - `Referrer-Policy: strict-origin-when-cross-origin`
    /// - `Content-Security-Policy: default-src 'self'`
    /// - `Permissions-Policy: camera=(), microphone=(), geolocation=()`
    pub fn new() -> Self {
        SecureHeadersLayer {
            headers: vec![
                ("x-content-type-options".to_string(), "nosniff".to_string()),
                ("x-frame-options".to_string(), "DENY".to_string()),
                ("x-xss-protection".to_string(), "0".to_string()),
                (
                    "referrer-policy".to_string(),
                    "strict-origin-when-cross-origin".to_string(),
                ),
                (
                    "content-security-policy".to_string(),
                    "default-src 'self'".to_string(),
                ),
                (
                    "permissions-policy".to_string(),
                    "camera=(), microphone=(), geolocation=()".to_string(),
                ),
            ],
        }
    }

    /// Enable HTTP Strict Transport Security (HSTS) with the given max-age in seconds.
    ///
    /// This header should only be enabled when the server is behind TLS.
    /// The `includeSubDomains` and `preload` directives are included automatically.
    /// Note that browser preload lists only accept a `preload` policy with a
    /// sufficiently long max-age (hstspreload.org currently requires at least
    /// one year), so short values are only useful for testing.
    ///
    /// # Example
    ///
    /// ```ignore
    /// SecureHeadersLayer::new().hsts(63_072_000) // 2 years
    /// ```
    #[must_use]
    pub fn hsts(mut self, max_age_secs: u64) -> Self {
        let value = format!("max-age={max_age_secs}; includeSubDomains; preload");
        self.set_header("strict-transport-security", value);
        self
    }

    /// Override the `X-Frame-Options` header value.
    ///
    /// Common values: `"DENY"`, `"SAMEORIGIN"`.
    #[must_use]
    pub fn frame_options(mut self, value: impl Into<String>) -> Self {
        self.set_header("x-frame-options", value.into());
        self
    }

    /// Override the `Content-Security-Policy` header value.
    #[must_use]
    pub fn content_security_policy(mut self, value: impl Into<String>) -> Self {
        self.set_header("content-security-policy", value.into());
        self
    }

    /// Remove the `Content-Security-Policy` header entirely.
    #[must_use]
    pub fn disable_csp(self) -> Self {
        self.remove("content-security-policy")
    }

    /// Remove a header (default or custom) from the set, if present.
    ///
    /// The name is matched case-insensitively. Removing a header that is not
    /// in the set is a no-op.
    #[must_use]
    pub fn remove(mut self, name: &str) -> Self {
        let name = name.to_ascii_lowercase();
        self.headers.retain(|(existing, _)| *existing != name);
        self
    }

    /// Add an arbitrary header name/value pair.
    ///
    /// If the header already exists in the set (compared case-insensitively),
    /// its value is replaced.
    #[must_use]
    pub fn custom(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let name = name.into().to_ascii_lowercase();
        let value = value.into();
        self.set_header(&name, value);
        self
    }

    /// Internal helper: set or replace a header by name.
    ///
    /// `name` must already be lowercase; every caller passes either a
    /// lowercase literal or a name lowered via `to_ascii_lowercase`.
    fn set_header(&mut self, name: &str, value: String) {
        if let Some(entry) = self.headers.iter_mut().find(|(n, _)| n == name) {
            entry.1 = value;
        } else {
            self.headers.push((name.to_string(), value));
        }
    }
}

impl Default for SecureHeadersLayer {
    /// Equivalent to [`SecureHeadersLayer::new()`].
    fn default() -> Self {
        Self::new()
    }
}
138
139impl<S> tower_layer::Layer<S> for SecureHeadersLayer {
140    type Service = SecureHeadersService<S>;
141
142    fn layer(&self, inner: S) -> Self::Service {
143        // Pre-parse header name/value pairs into http types for fast insertion.
144        let parsed: Vec<(http::HeaderName, http::HeaderValue)> = self
145            .headers
146            .iter()
147            .filter_map(|(name, value)| {
148                let header_name = http::HeaderName::from_bytes(name.as_bytes()).ok()?;
149                let header_value = http::HeaderValue::from_str(value).ok()?;
150                Some((header_name, header_value))
151            })
152            .collect();
153
154        SecureHeadersService {
155            inner,
156            headers: std::sync::Arc::new(parsed),
157        }
158    }
159}
160
161/// The Tower service produced by [`SecureHeadersLayer`].
162///
163/// Wraps an inner service and appends security headers to every response.
164#[derive(Clone)]
165pub struct SecureHeadersService<S> {
166    inner: S,
167    headers: std::sync::Arc<Vec<(http::HeaderName, http::HeaderValue)>>,
168}
169
170impl<S, B> tower_service::Service<http::Request<B>> for SecureHeadersService<S>
171where
172    S: tower_service::Service<
173            http::Request<B>,
174            Response = http::Response<BoxBody>,
175            Error = Infallible,
176        > + Clone
177        + Send
178        + 'static,
179    S::Future: Send + 'static,
180    B: Send + 'static,
181{
182    type Response = http::Response<BoxBody>;
183    type Error = Infallible;
184    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
185
186    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
187        self.inner.poll_ready(cx)
188    }
189
190    fn call(&mut self, req: http::Request<B>) -> Self::Future {
191        let mut inner = self.inner.clone();
192        let headers = self.headers.clone();
193        Box::pin(async move {
194            let mut resp = inner.call(req).await?;
195            for (name, value) in headers.iter() {
196                resp.headers_mut().insert(name.clone(), value.clone());
197            }
198            Ok(resp)
199        })
200    }
201}