typeway_server/
secure_headers.rs1use std::convert::Infallible;
28use std::future::Future;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31
32use crate::body::BoxBody;
33
/// A middleware layer that adds a fixed set of security-related response
/// headers to every response produced by the wrapped service.
///
/// Start from [`SecureHeadersLayer::new`] (or `Default`) to get the baseline
/// set, then adjust it with the builder methods. Header names are stored in
/// lowercase, and setting a header that is already present replaces its value
/// instead of adding a duplicate entry.
#[derive(Clone, Debug)]
pub struct SecureHeadersLayer {
    /// Header name/value pairs in insertion order; names are lowercase.
    headers: Vec<(String, String)>,
}

impl SecureHeadersLayer {
    /// Creates a layer with the baseline header set:
    ///
    /// - `x-content-type-options: nosniff`
    /// - `x-frame-options: DENY`
    /// - `x-xss-protection: 0`
    /// - `referrer-policy: strict-origin-when-cross-origin`
    /// - `content-security-policy: default-src 'self'`
    /// - `permissions-policy: camera=(), microphone=(), geolocation=()`
    ///
    /// `strict-transport-security` is not included by default; opt in with
    /// [`hsts`](Self::hsts) or [`hsts_with_options`](Self::hsts_with_options).
    #[must_use]
    pub fn new() -> Self {
        SecureHeadersLayer {
            headers: vec![
                ("x-content-type-options".to_string(), "nosniff".to_string()),
                ("x-frame-options".to_string(), "DENY".to_string()),
                ("x-xss-protection".to_string(), "0".to_string()),
                (
                    "referrer-policy".to_string(),
                    "strict-origin-when-cross-origin".to_string(),
                ),
                (
                    "content-security-policy".to_string(),
                    "default-src 'self'".to_string(),
                ),
                (
                    "permissions-policy".to_string(),
                    "camera=(), microphone=(), geolocation=()".to_string(),
                ),
            ],
        }
    }

    /// Sets `strict-transport-security` to
    /// `max-age=<max_age_secs>; includeSubDomains; preload`.
    ///
    /// The `preload` directive requests inclusion in browser HSTS preload
    /// lists, which is slow to undo. Use
    /// [`hsts_with_options`](Self::hsts_with_options) to omit it or
    /// `includeSubDomains`.
    #[must_use]
    pub fn hsts(self, max_age_secs: u64) -> Self {
        self.hsts_with_options(max_age_secs, true, true)
    }

    /// Sets `strict-transport-security` with explicit control over its
    /// optional directives.
    ///
    /// The value is `max-age=<max_age_secs>`, followed by `; includeSubDomains`
    /// when `include_subdomains` is true and `; preload` when `preload` is
    /// true. Any previously configured HSTS value is replaced.
    #[must_use]
    pub fn hsts_with_options(
        mut self,
        max_age_secs: u64,
        include_subdomains: bool,
        preload: bool,
    ) -> Self {
        let mut value = format!("max-age={max_age_secs}");
        if include_subdomains {
            value.push_str("; includeSubDomains");
        }
        if preload {
            value.push_str("; preload");
        }
        self.set_header("strict-transport-security", value);
        self
    }

    /// Replaces the `x-frame-options` value (baseline: `DENY`).
    #[must_use]
    pub fn frame_options(mut self, value: impl Into<String>) -> Self {
        self.set_header("x-frame-options", value.into());
        self
    }

    /// Replaces the `content-security-policy` value, re-adding the header if
    /// it was removed with [`disable_csp`](Self::disable_csp).
    #[must_use]
    pub fn content_security_policy(mut self, value: impl Into<String>) -> Self {
        self.set_header("content-security-policy", value.into());
        self
    }

    /// Removes the `content-security-policy` header from the set.
    #[must_use]
    pub fn disable_csp(mut self) -> Self {
        self.headers
            .retain(|(name, _)| name != "content-security-policy");
        self
    }

    /// Adds or replaces an arbitrary header.
    ///
    /// `name` is lowercased (ASCII) before storage, so it matches and
    /// overrides the built-in entries case-insensitively. Names or values that
    /// are not valid HTTP headers are not rejected here.
    #[must_use]
    pub fn custom(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let name = name.into().to_ascii_lowercase();
        let value = value.into();
        self.set_header(&name, value);
        self
    }

    /// Replaces the value of the entry named `name`, or appends a new entry.
    /// `name` is compared exactly, so callers must pass it in lowercase.
    fn set_header(&mut self, name: &str, value: String) {
        if let Some(entry) = self.headers.iter_mut().find(|(n, _)| n == name) {
            entry.1 = value;
        } else {
            self.headers.push((name.to_string(), value));
        }
    }
}

impl Default for SecureHeadersLayer {
    /// Equivalent to [`SecureHeadersLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
138
139impl<S> tower_layer::Layer<S> for SecureHeadersLayer {
140 type Service = SecureHeadersService<S>;
141
142 fn layer(&self, inner: S) -> Self::Service {
143 let parsed: Vec<(http::HeaderName, http::HeaderValue)> = self
145 .headers
146 .iter()
147 .filter_map(|(name, value)| {
148 let header_name = http::HeaderName::from_bytes(name.as_bytes()).ok()?;
149 let header_value = http::HeaderValue::from_str(value).ok()?;
150 Some((header_name, header_value))
151 })
152 .collect();
153
154 SecureHeadersService {
155 inner,
156 headers: std::sync::Arc::new(parsed),
157 }
158 }
159}
160
161#[derive(Clone)]
165pub struct SecureHeadersService<S> {
166 inner: S,
167 headers: std::sync::Arc<Vec<(http::HeaderName, http::HeaderValue)>>,
168}
169
170impl<S, B> tower_service::Service<http::Request<B>> for SecureHeadersService<S>
171where
172 S: tower_service::Service<
173 http::Request<B>,
174 Response = http::Response<BoxBody>,
175 Error = Infallible,
176 > + Clone
177 + Send
178 + 'static,
179 S::Future: Send + 'static,
180 B: Send + 'static,
181{
182 type Response = http::Response<BoxBody>;
183 type Error = Infallible;
184 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
185
186 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
187 self.inner.poll_ready(cx)
188 }
189
190 fn call(&mut self, req: http::Request<B>) -> Self::Future {
191 let mut inner = self.inner.clone();
192 let headers = self.headers.clone();
193 Box::pin(async move {
194 let mut resp = inner.call(req).await?;
195 for (name, value) in headers.iter() {
196 resp.headers_mut().insert(name.clone(), value.clone());
197 }
198 Ok(resp)
199 })
200 }
201}