1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2use std::{
32 convert::Infallible,
33 future::Future,
34 pin::Pin,
35 sync::Arc,
36 task::{Context, Poll},
37};
38
39use bytes::Bytes;
40use http::{
41 HeaderName, HeaderValue, Request, Response,
42 header::{InvalidHeaderName, InvalidHeaderValue},
43};
44pub use matchit::InsertError;
45use matchit::Router;
46use tower::{Layer, Service};
47
48type BonusHeaders = Arc<[(HeaderName, HeaderValue)]>;
49
50#[macro_use]
51extern crate tracing;
52
53#[derive(Clone, Debug)]
54pub struct HeaderGroup {
55 pub path: String,
56 pub targets: Vec<(HeaderName, HeaderValue)>,
57}
58
59pub fn parse(header_file: &str) -> Result<Vec<HeaderGroup>, HeaderParseError> {
64 if header_file.is_empty() {
65 return Ok(Vec::new());
66 }
67 let mut headers = Vec::new();
68 let mut current_ctx: Option<HeaderGroup> = None;
69 for (idx, line) in header_file.lines().enumerate() {
70 if line.is_empty() || line.trim().starts_with('#') {
71 continue;
73 }
74 if line.starts_with(['\t', ' ']) {
75 let Some(ctx) = current_ctx.as_mut() else {
76 return Err(HeaderParseError::new(HeaderParseErrorKind::NoParseCtx, idx));
77 };
78 let (name, value) = line
79 .trim()
80 .split_once(':')
81 .ok_or_else(|| HeaderParseError::new(HeaderParseErrorKind::NoHeaderColon, idx))?;
82 let name = match HeaderName::from_bytes(name.trim().as_bytes()) {
83 Ok(v) => v,
84 Err(e) => {
85 return Err(HeaderParseError::new(
86 HeaderParseErrorKind::HeaderNameParse(e),
87 idx,
88 ));
89 }
90 };
91 let value = match HeaderValue::from_bytes(value.trim().as_bytes()) {
92 Ok(v) => v,
93 Err(e) => {
94 return Err(HeaderParseError::new(
95 HeaderParseErrorKind::HeaderValueParse(e),
96 idx,
97 ));
98 }
99 };
100
101 ctx.targets.push((name, value));
102 } else {
103 let mut group = Some(HeaderGroup {
104 path: line.trim().to_string(),
105 targets: Vec::new(),
106 });
107 std::mem::swap(&mut current_ctx, &mut group);
108 if let Some(group) = group {
109 group_add(&mut headers, group);
110 }
111 }
112 }
113 if let Some(group) = current_ctx {
114 group_add(&mut headers, group);
115 }
116 info!(?headers, "Got headers");
117 Ok(headers)
118}
119
120fn group_add(headers: &mut Vec<HeaderGroup>, group: HeaderGroup) {
121 if group.path.ends_with('*') {
123 let end_idx = group.path.len() - 1;
124 let base_path = &group.path[0..end_idx];
125 trace!("Generating wildcard for {base_path}");
126 headers.push(HeaderGroup {
127 path: base_path.to_owned(),
128 targets: group.targets.clone(),
129 });
130 headers.push(HeaderGroup {
131 path: format!("{base_path}{{*all}}"),
132 targets: group.targets.clone(),
133 });
134 } else {
135 headers.push(group);
136 }
137}
138
139#[derive(Debug, thiserror::Error)]
140#[error("at line {row}: {kind}")]
141pub struct HeaderParseError {
143 pub row: usize,
144 #[source]
145 pub kind: HeaderParseErrorKind,
146}
147
148impl HeaderParseError {
149 const fn new(kind: HeaderParseErrorKind, idx: usize) -> Self {
150 Self { row: idx + 1, kind }
151 }
152}
153
154#[derive(Debug, thiserror::Error)]
155pub enum HeaderParseErrorKind {
158 #[error("Header name invalid: {0}")]
159 HeaderNameParse(#[from] InvalidHeaderName),
160 #[error("Header name value: {0}")]
161 HeaderValueParse(#[from] InvalidHeaderValue),
162 #[error("You must specify an unindented path before specifying headers")]
163 NoParseCtx,
164 #[error("You must put a colon at the end of the header name")]
165 NoHeaderColon,
166}
167
168#[derive(Clone)]
169pub struct HeadersLayer {
171 headers: Arc<matchit::Router<BonusHeaders>>,
172}
173
174impl HeadersLayer {
175 pub fn new(header_list: Vec<HeaderGroup>) -> Result<Self, InsertError> {
181 let mut headers = Router::new();
182 for header in header_list {
183 headers.insert(header.path, header.targets.into())?;
184 }
185
186 info!(?headers, "Built auto header map");
187
188 Ok(Self {
189 headers: Arc::new(headers),
190 })
191 }
192}
193
194impl<S> Layer<S> for HeadersLayer {
195 type Service = Headers<S>;
196
197 fn layer(&self, inner: S) -> Headers<S> {
198 Headers {
199 headers: self.headers.clone(),
200 inner,
201 }
202 }
203}
204
205#[derive(Clone)]
206pub struct Headers<S> {
208 headers: Arc<matchit::Router<BonusHeaders>>,
209 inner: S,
210}
211
212#[pin_project::pin_project]
213pub struct ResponseFuture<F> {
215 #[pin]
216 src: F,
217 additional_headers: Option<BonusHeaders>,
218}
219
220impl<F, B, BE> std::future::Future for ResponseFuture<F>
221where
222 F: Future<Output = Result<Response<B>, Infallible>>,
223 B: http_body::Body<Data = Bytes, Error = BE> + Send + 'static,
224{
225 type Output = F::Output;
226
227 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
228 let bonus_headers = self.additional_headers.clone();
229 self.project()
230 .src
231 .poll(cx)
232 .map(|v| add_headers(v, bonus_headers))
233 }
234}
235
236#[allow(clippy::unnecessary_wraps)]
237fn add_headers<B>(
238 res: Result<Response<B>, Infallible>,
239 bonus_headers: Option<BonusHeaders>,
240) -> Result<Response<B>, Infallible> {
241 let Ok(mut inner) = res;
242 let resp_headers = inner.headers_mut();
243 if let Some(bonus_headers) = bonus_headers {
244 for (name, value) in bonus_headers.iter() {
245 resp_headers.insert(name.clone(), value.clone());
246 }
247 }
248 Ok(inner)
249}
250
251impl<ReqBody, F, FResBody, FResBodyError> Service<Request<ReqBody>> for Headers<F>
252where
253 F: Service<Request<ReqBody>, Response = Response<FResBody>, Error = Infallible> + Clone,
254 F::Future: Send + 'static,
255 FResBody: http_body::Body<Data = Bytes, Error = FResBodyError> + Send + 'static,
256 FResBodyError: Into<Box<dyn std::error::Error + Send + Sync>>,
257{
258 type Error = Infallible;
259 type Future = ResponseFuture<F::Future>;
260 type Response = Response<FResBody>;
261
262 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
263 self.inner.poll_ready(cx)
264 }
265
266 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
267 let path = req.uri().path();
268 let additional_headers = self.headers.at(path).ok().map(|v| v.value.clone());
269 ResponseFuture {
270 src: self.inner.call(req),
271 additional_headers,
272 }
273 }
274}