tunnelbana_hidepaths/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
//! # tunnelbana-hidepaths
//! Hide specific paths in tower services by routing requests for them to a
//! not-found (404) service.
//!
//! Part of the [tunnelbana](https://github.com/randomairborne/tunnelbana) project.
//!
//! # Example
//! ```rust
//! use tower_http::services::ServeDir;
//! use tower::{ServiceBuilder, ServiceExt};
//! use http::Response;
//! use tunnelbana_hidepaths::HidePathsLayer;
//!
//! let hidepaths_middleware = HidePathsLayer::builder()
//!     .hide("/_redirects")
//!     .hide_all(["/.htaccess", "/.well-known/{*hide}"])
//!     .build()
//!     .expect("Failed to build path hide router");
//! let serve_dir = ServeDir::new("/var/www/html").append_index_html_on_directories(true);
//! let service = ServiceBuilder::new()
//!     .layer(hidepaths_middleware)
//!     .service(serve_dir);
//! ```
24use std::{
25    convert::Infallible,
26    future::Future,
27    pin::Pin,
28    sync::Arc,
29    task::{Context, Poll},
30};
31
32use bytes::Bytes;
33use http::{Request, Response, StatusCode};
34use http_body_util::Either;
35pub use matchit::InsertError;
36use tower::{Layer, Service};
37
38#[derive(Clone)]
39/// Build a [`matchit::Router`] of paths which should be routed to
40/// the not found service.
41///
42/// The not found service defaults to [`DefaultNotFoundService`],
43/// however it is very barebones, so it is recommended to supply your own with [`Self::with_not_found_service`].
44pub struct HidePathsLayerBuilder<N = DefaultNotFoundService> {
45    hidden: matchit::Router<()>,
46    notfound: N,
47    errors: Vec<(String, InsertError)>,
48}
49
50impl<N> HidePathsLayerBuilder<N> {
51    #[must_use]
52    /// Create a new builder with the [`DefaultNotFoundService`].
53    pub fn new() -> HidePathsLayerBuilder<DefaultNotFoundService> {
54        HidePathsLayerBuilder {
55            hidden: matchit::Router::new(),
56            notfound: DefaultNotFoundService,
57            errors: Vec::new(),
58        }
59    }
60
61    /// Use a different service for 404'd files than the [`DefaultNotFoundService`].
62    pub fn with_not_found_service<T>(self, notfound: T) -> HidePathsLayerBuilder<T> {
63        HidePathsLayerBuilder {
64            notfound,
65            hidden: self.hidden,
66            errors: self.errors,
67        }
68    }
69
70    #[must_use]
71    /// All [`matchit`] routes passed to this method will be routed to the not found service.
72    pub fn hide(mut self, route: impl Into<String>) -> Self {
73        let route = route.into();
74        if let Err(err) = self.hidden.insert(&route, ()) {
75            self.errors.push((route, err));
76        }
77        self
78    }
79
80    #[must_use]
81    /// Convenience method for calling [`Self::hide`] in a loop.
82    pub fn hide_all<IS: Into<String>>(mut self, routes: impl IntoIterator<Item = IS>) -> Self {
83        for route in routes {
84            self = self.hide(route);
85        }
86        self
87    }
88
89    /// Get a list of errors which have occured inside the builder.
90    pub const fn errors(&self) -> &[(String, InsertError)] {
91        self.errors.as_slice()
92    }
93
94    /// Build this [`HidePathsLayer`].
95    /// # Errors
96    /// This function errors if matchit has had any errors while inserting-
97    /// you get the path that was inserted, and the error.
98    pub fn build(self) -> Result<HidePathsLayer<N>, HidePathsLayerBuilderError> {
99        if !self.errors.is_empty() {
100            return Err(HidePathsLayerBuilderError(self.errors));
101        }
102        Ok(HidePathsLayer {
103            hidden: Arc::new(self.hidden),
104            notfound: self.notfound,
105        })
106    }
107}
108
109#[derive(Debug)]
110pub struct HidePathsLayerBuilderError(pub Vec<(String, InsertError)>);
111
112impl std::fmt::Display for HidePathsLayerBuilderError {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        write!(f, "Could not hide the following paths due to errors: ")?;
115        for (path, err) in &self.0 {
116            write!(f, "`{path}` due to `{err}`, ")?;
117        }
118        Ok(())
119    }
120}
121
122impl std::error::Error for HidePathsLayerBuilderError {}
123
124#[derive(Clone)]
125/// A [`tower::Layer`] for use with a [`tower::ServiceBuilder`] to reply with a fallback
126/// service to any routes found internally.
127pub struct HidePathsLayer<N = DefaultNotFoundService> {
128    hidden: Arc<matchit::Router<()>>,
129    notfound: N,
130}
131
132impl HidePathsLayer<DefaultNotFoundService> {
133    #[must_use]
134    pub fn builder() -> HidePathsLayerBuilder<DefaultNotFoundService> {
135        HidePathsLayerBuilder::<DefaultNotFoundService>::new()
136    }
137}
138
139impl<S, N> Layer<S> for HidePathsLayer<N>
140where
141    N: Clone,
142{
143    type Service = HidePath<S, N>;
144
145    fn layer(&self, inner: S) -> HidePath<S, N> {
146        HidePath {
147            hidden: self.hidden.clone(),
148            notfound: self.notfound.clone(),
149            inner,
150        }
151    }
152}
153
154#[derive(Clone)]
155/// A wrapper service which forwards to one of two inner services based on if the requested
156/// path is contained within its internal router.
157pub struct HidePath<S, N> {
158    hidden: Arc<matchit::Router<()>>,
159    notfound: N,
160    inner: S,
161}
162
163#[pin_project::pin_project(project = PinResponseSource)]
164/// Future which always delegates the whole response to either the default service, or
165/// a not-found fallback, and returns the service response unmodified.
166pub enum ResponseFuture<S, N> {
167    Child(#[pin] S),
168    NotFound(#[pin] N),
169}
170
171impl<S, N, SB, NB, SBE, NBE> std::future::Future for ResponseFuture<S, N>
172where
173    S: Future<Output = Result<Response<SB>, Infallible>>,
174    N: Future<Output = Result<Response<NB>, Infallible>>,
175    SB: http_body::Body<Data = Bytes, Error = SBE> + Send + 'static,
176    NB: http_body::Body<Data = Bytes, Error = NBE> + Send + 'static,
177{
178    type Output = Result<Response<Either<SB, NB>>, Infallible>;
179
180    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
181        match self.project() {
182            PinResponseSource::Child(s) => s.poll(cx).map(|v| {
183                v.map(|resp| {
184                    let (parts, body) = resp.into_parts();
185                    Response::from_parts(parts, Either::Left(body))
186                })
187            }),
188            PinResponseSource::NotFound(s) => s.poll(cx).map(|v| {
189                v.map(|resp| {
190                    let (parts, body) = resp.into_parts();
191                    Response::from_parts(parts, Either::Right(body))
192                })
193            }),
194        }
195    }
196}
197
198impl<ReqBody, S, SResBody, SResBodyError, N, NResBody, NResBodyError> Service<Request<ReqBody>>
199    for HidePath<S, N>
200where
201    S: Service<Request<ReqBody>, Response = Response<SResBody>, Error = Infallible> + Clone,
202    S::Future: Send + 'static,
203    SResBody: http_body::Body<Data = Bytes, Error = SResBodyError> + Send + 'static,
204    SResBodyError: Into<Box<dyn std::error::Error + Send + Sync>>,
205    N: Service<Request<ReqBody>, Response = Response<NResBody>, Error = Infallible> + Clone,
206    N::Future: Send + 'static,
207    NResBody: http_body::Body<Data = Bytes, Error = NResBodyError> + Send + 'static,
208    NResBodyError: Into<Box<dyn std::error::Error + Send + Sync>>,
209{
210    type Error = Infallible;
211    type Future = ResponseFuture<S::Future, N::Future>;
212    type Response = Response<http_body_util::Either<SResBody, NResBody>>;
213
214    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215        self.inner.poll_ready(cx)
216    }
217
218    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
219        let path = req.uri().path();
220        if self.hidden.at(path).is_ok() {
221            tracing::info!(?path, "Blocked request");
222            ResponseFuture::NotFound(self.notfound.call(req))
223        } else {
224            ResponseFuture::Child(self.inner.call(req))
225        }
226    }
227}
228
/// Minimal not-found service with no configuration: every request, whatever
/// its type, is answered with an HTTP 404 whose body is empty.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultNotFoundService;
232
/// Future returned by [`DefaultNotFoundService`]. It completes on its first
/// poll with an HTTP 404 response that has an empty body.
#[derive(Debug)]
pub struct DefaultNotFoundFuture;
235
236impl<T> Service<T> for DefaultNotFoundService {
237    type Error = Infallible;
238    type Future = DefaultNotFoundFuture;
239    type Response = Response<http_body_util::Empty<Bytes>>;
240
241    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
242        Poll::Ready(Ok(()))
243    }
244
245    fn call(&mut self, _: T) -> Self::Future {
246        DefaultNotFoundFuture
247    }
248}
249
250impl std::future::Future for DefaultNotFoundFuture {
251    type Output = Result<Response<http_body_util::Empty<Bytes>>, Infallible>;
252
253    fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
254        let mut resp = Response::new(http_body_util::Empty::new());
255        *resp.status_mut() = StatusCode::NOT_FOUND;
256        Poll::Ready(Ok(resp))
257    }
258}
259
#[cfg(test)]
mod tests {
    use http::Request;
    use http_body_util::{BodyExt, Empty, Full};
    use tower::ServiceExt;

    use super::*;

    /// Body returned by the inner service for non-hidden paths.
    const BODY: &str = "test string";

    /// Build an empty-bodied request for `uri`.
    fn request(uri: &str) -> Request<Empty<Bytes>> {
        Request::builder().uri(uri).body(Empty::new()).unwrap()
    }

    /// Layer under test: only `/example.html` is hidden.
    fn example_layer() -> HidePathsLayer {
        HidePathsLayer::builder()
            .hide("/example.html")
            .build()
            .unwrap()
    }

    /// Inner service: always answers 200 with [`BODY`].
    fn ok_handler(
        _: Request<Empty<Bytes>>,
    ) -> std::future::Ready<Result<Response<Full<Bytes>>, Infallible>> {
        std::future::ready(Ok(Response::new(Full::new(Bytes::from(BODY)))))
    }

    /// Send a request for `uri` through [`example_layer`] wrapping
    /// [`ok_handler`], returning the status and the collected body.
    async fn fetch(uri: &str) -> (StatusCode, Bytes) {
        let response = tower::ServiceBuilder::new()
            .layer(example_layer())
            .service_fn(ok_handler)
            .oneshot(request(uri))
            .await
            .unwrap();
        let status = response.status();
        let body = response.into_body().collect().await.unwrap().to_bytes();
        (status, body)
    }

    #[tokio::test]
    async fn path_hidden() {
        let (status, body) = fetch("/example.html").await;
        assert_eq!(status, StatusCode::NOT_FOUND);
        assert!(body.is_empty());
    }

    #[tokio::test]
    async fn path_not_hidden() {
        let (status, body) = fetch("/example.htmlb").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, BODY);
    }
}
320}