tunnelbana_headers/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2//! # tunnelbana-headers
3//! A tower middleware to add headers to specific routes, or route groups.
4//!
5//! Part of the [tunnelbana](https://github.com/randomairborne/tunnelbana) project.
6//!
7//! # Example
8//! ```rust
9//! use tower_http::services::ServeDir;
10//! use tower::{ServiceBuilder, ServiceExt};
11//! use http::Response;
12//! use tunnelbana_headers::HeadersLayer;
13//!
14//! let config = r#"
15//!/example
16//!  X-Example-Header: example.org
17//!/subpath/{other}
18//!  X-Header-One: h1
19//!  X-Header-Two: h2
20//!/wildcard/{*wildcard}
21//!  X-Header-A: ha
22//!  X-Header-B: hb
23//!"#;
24//! let headers = tunnelbana_headers::parse(config).expect("Failed to parse headers");
25//! let headers_mw = HeadersLayer::new(headers).expect("Failed to route headers");
26//! let serve_dir = ServeDir::new("/var/www/html").append_index_html_on_directories(true);
27//! let service = ServiceBuilder::new()
28//!    .layer(headers_mw)
29//!    .service(serve_dir);
30//! ```
31use std::{
32    convert::Infallible,
33    future::Future,
34    pin::Pin,
35    sync::Arc,
36    task::{Context, Poll},
37};
38
39use bytes::Bytes;
40use http::{
41    HeaderName, HeaderValue, Request, Response,
42    header::{InvalidHeaderName, InvalidHeaderValue},
43};
44pub use matchit::InsertError;
45use matchit::Router;
46use tower::{Layer, Service};
47
/// Immutable list of extra `(name, value)` headers attached to one route.
/// Wrapped in an `Arc` so it can be shared with each response future via a reference-count bump.
type BonusHeaders = Arc<[(HeaderName, HeaderValue)]>;
49
50#[macro_use]
51extern crate tracing;
52
53#[derive(Clone, Debug)]
54pub struct HeaderGroup {
55    pub path: String,
56    pub targets: Vec<(HeaderName, HeaderValue)>,
57}
58
59/// Parse a list of [`HeaderGroup`]s from a cloudflare-style _headers string.
60/// # Errors
61/// This function errors if you have an orphaned header definition, if you have an invalid header name or value,
62/// or if your name cannot be a matchit path.
63pub fn parse(header_file: &str) -> Result<Vec<HeaderGroup>, HeaderParseError> {
64    if header_file.is_empty() {
65        return Ok(Vec::new());
66    }
67    let mut headers = Vec::new();
68    let mut current_ctx: Option<HeaderGroup> = None;
69    for (idx, line) in header_file.lines().enumerate() {
70        if line.is_empty() || line.trim().starts_with('#') {
71            // handle comments
72            continue;
73        }
74        if line.starts_with(['\t', ' ']) {
75            let Some(ctx) = current_ctx.as_mut() else {
76                return Err(HeaderParseError::new(HeaderParseErrorKind::NoParseCtx, idx));
77            };
78            let (name, value) = line
79                .trim()
80                .split_once(':')
81                .ok_or_else(|| HeaderParseError::new(HeaderParseErrorKind::NoHeaderColon, idx))?;
82            let name = match HeaderName::from_bytes(name.trim().as_bytes()) {
83                Ok(v) => v,
84                Err(e) => {
85                    return Err(HeaderParseError::new(
86                        HeaderParseErrorKind::HeaderNameParse(e),
87                        idx,
88                    ));
89                }
90            };
91            let value = match HeaderValue::from_bytes(value.trim().as_bytes()) {
92                Ok(v) => v,
93                Err(e) => {
94                    return Err(HeaderParseError::new(
95                        HeaderParseErrorKind::HeaderValueParse(e),
96                        idx,
97                    ));
98                }
99            };
100
101            ctx.targets.push((name, value));
102        } else {
103            let mut group = Some(HeaderGroup {
104                path: line.trim().to_string(),
105                targets: Vec::new(),
106            });
107            std::mem::swap(&mut current_ctx, &mut group);
108            if let Some(group) = group {
109                group_add(&mut headers, group);
110            }
111        }
112    }
113    if let Some(group) = current_ctx {
114        group_add(&mut headers, group);
115    }
116    info!(?headers, "Got headers");
117    Ok(headers)
118}
119
120fn group_add(headers: &mut Vec<HeaderGroup>, group: HeaderGroup) {
121    // A * character will register for all subpaths, and also the `/` path above it
122    if group.path.ends_with('*') {
123        let end_idx = group.path.len() - 1;
124        let base_path = &group.path[0..end_idx];
125        trace!("Generating wildcard for {base_path}");
126        headers.push(HeaderGroup {
127            path: base_path.to_owned(),
128            targets: group.targets.clone(),
129        });
130        headers.push(HeaderGroup {
131            path: format!("{base_path}{{*all}}"),
132            targets: group.targets.clone(),
133        });
134    } else {
135        headers.push(group);
136    }
137}
138
#[derive(Debug, thiserror::Error)]
#[error("at line {row}: {kind}")]
/// Describes the location and type of a header parsing problem.
pub struct HeaderParseError {
    /// 1-based line number in the parsed input where the problem was found.
    pub row: usize,
    /// What went wrong on that line. Included in this error's message and also
    /// exposed as its `source()`.
    #[source]
    pub kind: HeaderParseErrorKind,
}
147
148impl HeaderParseError {
149    const fn new(kind: HeaderParseErrorKind, idx: usize) -> Self {
150        Self { row: idx + 1, kind }
151    }
152}
153
154#[derive(Debug, thiserror::Error)]
155/// Types of header parsing errors. These can come from the [`http`]
156/// crate, or internally from `tunnelbana-headers`.
157pub enum HeaderParseErrorKind {
158    #[error("Header name invalid: {0}")]
159    HeaderNameParse(#[from] InvalidHeaderName),
160    #[error("Header name value: {0}")]
161    HeaderValueParse(#[from] InvalidHeaderValue),
162    #[error("You must specify an unindented path before specifying headers")]
163    NoParseCtx,
164    #[error("You must put a colon at the end of the header name")]
165    NoHeaderColon,
166}
167
168#[derive(Clone)]
169/// a [`tower::Layer`] to add to a [`tower::ServiceBuilder`] to add headers.
170pub struct HeadersLayer {
171    headers: Arc<matchit::Router<BonusHeaders>>,
172}
173
174impl HeadersLayer {
175    /// Create a new [`HeadersLayer`]. The header groups are naively added
176    /// to a matchit router internally.
177    /// # Errors
178    /// If two [`HeaderGroup`]s are the same, or would illgally overlap
179    /// an error can be returned
180    pub fn new(header_list: Vec<HeaderGroup>) -> Result<Self, InsertError> {
181        let mut headers = Router::new();
182        for header in header_list {
183            headers.insert(header.path, header.targets.into())?;
184        }
185
186        info!(?headers, "Built auto header map");
187
188        Ok(Self {
189            headers: Arc::new(headers),
190        })
191    }
192}
193
194impl<S> Layer<S> for HeadersLayer {
195    type Service = Headers<S>;
196
197    fn layer(&self, inner: S) -> Headers<S> {
198        Headers {
199            headers: self.headers.clone(),
200            inner,
201        }
202    }
203}
204
205#[derive(Clone)]
206/// a [`tower::Service`] which adds headers to a wrapped S.
207pub struct Headers<S> {
208    headers: Arc<matchit::Router<BonusHeaders>>,
209    inner: S,
210}
211
#[pin_project::pin_project]
/// Future returned by [`Headers`]. It resolves to the inner service's response with the
/// matched route's headers inserted, replacing any same-named headers already present.
/// Nothing is added when the request path matched no route.
pub struct ResponseFuture<F> {
    /// The inner service's response future (structurally pinned).
    #[pin]
    src: F,
    /// Headers for the matched route, or `None` if no route matched.
    additional_headers: Option<BonusHeaders>,
}
219
220impl<F, B, BE> std::future::Future for ResponseFuture<F>
221where
222    F: Future<Output = Result<Response<B>, Infallible>>,
223    B: http_body::Body<Data = Bytes, Error = BE> + Send + 'static,
224{
225    type Output = F::Output;
226
227    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
228        let bonus_headers = self.additional_headers.clone();
229        self.project()
230            .src
231            .poll(cx)
232            .map(|v| add_headers(v, bonus_headers))
233    }
234}
235
236#[allow(clippy::unnecessary_wraps)]
237fn add_headers<B>(
238    res: Result<Response<B>, Infallible>,
239    bonus_headers: Option<BonusHeaders>,
240) -> Result<Response<B>, Infallible> {
241    let Ok(mut inner) = res;
242    let resp_headers = inner.headers_mut();
243    if let Some(bonus_headers) = bonus_headers {
244        for (name, value) in bonus_headers.iter() {
245            resp_headers.insert(name.clone(), value.clone());
246        }
247    }
248    Ok(inner)
249}
250
251impl<ReqBody, F, FResBody, FResBodyError> Service<Request<ReqBody>> for Headers<F>
252where
253    F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone,
254    F::Future: Send + 'static,
255    FResBody: http_body::Body<Data = Bytes, Error = FResBodyError> + Send + 'static,
256    FResBodyError: Into<Box<dyn std::error::Error + Send + Sync>>,
257{
258    type Error = Infallible;
259    type Future = ResponseFuture<F::Future>;
260    type Response = Response<FResBody>;
261
262    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
263        self.inner.poll_ready(cx)
264    }
265
266    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
267        let path = req.uri().path();
268        let additional_headers = self.headers.at(path).ok().map(|v| v.value.clone());
269        ResponseFuture {
270            src: self.inner.call(req),
271            additional_headers,
272        }
273    }
274}