1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2use std::{
25 convert::Infallible,
26 future::Future,
27 pin::Pin,
28 sync::Arc,
29 task::{Context, Poll},
30};
31
32use bytes::Bytes;
33use http::{Request, Response, StatusCode};
34use http_body_util::Either;
35pub use matchit::InsertError;
36use tower::{Layer, Service};
37
38#[derive(Clone)]
39pub struct HidePathsLayerBuilder<N = DefaultNotFoundService> {
45 hidden: matchit::Router<()>,
46 notfound: N,
47 errors: Vec<(String, InsertError)>,
48}
49
50impl<N> HidePathsLayerBuilder<N> {
51 #[must_use]
52 pub fn new() -> HidePathsLayerBuilder<DefaultNotFoundService> {
54 HidePathsLayerBuilder {
55 hidden: matchit::Router::new(),
56 notfound: DefaultNotFoundService,
57 errors: Vec::new(),
58 }
59 }
60
61 pub fn with_not_found_service<T>(self, notfound: T) -> HidePathsLayerBuilder<T> {
63 HidePathsLayerBuilder {
64 notfound,
65 hidden: self.hidden,
66 errors: self.errors,
67 }
68 }
69
70 #[must_use]
71 pub fn hide(mut self, route: impl Into<String>) -> Self {
73 let route = route.into();
74 if let Err(err) = self.hidden.insert(&route, ()) {
75 self.errors.push((route, err));
76 }
77 self
78 }
79
80 #[must_use]
81 pub fn hide_all<IS: Into<String>>(mut self, routes: impl IntoIterator<Item = IS>) -> Self {
83 for route in routes {
84 self = self.hide(route);
85 }
86 self
87 }
88
89 pub const fn errors(&self) -> &[(String, InsertError)] {
91 self.errors.as_slice()
92 }
93
94 pub fn build(self) -> Result<HidePathsLayer<N>, HidePathsLayerBuilderError> {
99 if !self.errors.is_empty() {
100 return Err(HidePathsLayerBuilderError(self.errors));
101 }
102 Ok(HidePathsLayer {
103 hidden: Arc::new(self.hidden),
104 notfound: self.notfound,
105 })
106 }
107}
108
109#[derive(Debug)]
110pub struct HidePathsLayerBuilderError(pub Vec<(String, InsertError)>);
111
112impl std::fmt::Display for HidePathsLayerBuilderError {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(f, "Could not hide the following paths due to errors: ")?;
115 for (path, err) in &self.0 {
116 write!(f, "`{path}` due to `{err}`, ")?;
117 }
118 Ok(())
119 }
120}
121
122impl std::error::Error for HidePathsLayerBuilderError {}
123
124#[derive(Clone)]
125pub struct HidePathsLayer<N = DefaultNotFoundService> {
128 hidden: Arc<matchit::Router<()>>,
129 notfound: N,
130}
131
132impl HidePathsLayer<DefaultNotFoundService> {
133 #[must_use]
134 pub fn builder() -> HidePathsLayerBuilder<DefaultNotFoundService> {
135 HidePathsLayerBuilder::<DefaultNotFoundService>::new()
136 }
137}
138
139impl<S, N> Layer<S> for HidePathsLayer<N>
140where
141 N: Clone,
142{
143 type Service = HidePath<S, N>;
144
145 fn layer(&self, inner: S) -> HidePath<S, N> {
146 HidePath {
147 hidden: self.hidden.clone(),
148 notfound: self.notfound.clone(),
149 inner,
150 }
151 }
152}
153
154#[derive(Clone)]
155pub struct HidePath<S, N> {
158 hidden: Arc<matchit::Router<()>>,
159 notfound: N,
160 inner: S,
161}
162
163#[pin_project::pin_project(project = PinResponseSource)]
164pub enum ResponseFuture<S, N> {
167 Child(#[pin] S),
168 NotFound(#[pin] N),
169}
170
171impl<S, N, SB, NB, SBE, NBE> std::future::Future for ResponseFuture<S, N>
172where
173 S: Future<Output = Result<Response<SB>, Infallible>>,
174 N: Future<Output = Result<Response<NB>, Infallible>>,
175 SB: http_body::Body<Data = Bytes, Error = SBE> + Send + 'static,
176 NB: http_body::Body<Data = Bytes, Error = NBE> + Send + 'static,
177{
178 type Output = Result<Response<Either<SB, NB>>, Infallible>;
179
180 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
181 match self.project() {
182 PinResponseSource::Child(s) => s.poll(cx).map(|v| {
183 v.map(|resp| {
184 let (parts, body) = resp.into_parts();
185 Response::from_parts(parts, Either::Left(body))
186 })
187 }),
188 PinResponseSource::NotFound(s) => s.poll(cx).map(|v| {
189 v.map(|resp| {
190 let (parts, body) = resp.into_parts();
191 Response::from_parts(parts, Either::Right(body))
192 })
193 }),
194 }
195 }
196}
197
198impl<ReqBody, S, SResBody, SResBodyError, N, NResBody, NResBodyError> Service<Request<ReqBody>>
199 for HidePath<S, N>
200where
201 S: Service<Request<ReqBody>, Response = Response<SResBody>, Error = Infallible> + Clone,
202 S::Future: Send + 'static,
203 SResBody: http_body::Body<Data = Bytes, Error = SResBodyError> + Send + 'static,
204 SResBodyError: Into<Box<dyn std::error::Error + Send + Sync>>,
205 N: Service<Request<ReqBody>, Response = Response<NResBody>, Error = Infallible> + Clone,
206 N::Future: Send + 'static,
207 NResBody: http_body::Body<Data = Bytes, Error = NResBodyError> + Send + 'static,
208 NResBodyError: Into<Box<dyn std::error::Error + Send + Sync>>,
209{
210 type Error = Infallible;
211 type Future = ResponseFuture<S::Future, N::Future>;
212 type Response = Response<http_body_util::Either<SResBody, NResBody>>;
213
214 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215 self.inner.poll_ready(cx)
216 }
217
218 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
219 let path = req.uri().path();
220 if self.hidden.at(path).is_ok() {
221 tracing::info!(?path, "Blocked request");
222 ResponseFuture::NotFound(self.notfound.call(req))
223 } else {
224 ResponseFuture::Child(self.inner.call(req))
225 }
226 }
227}
228
/// Fallback used for hidden paths when no other service is configured.
///
/// Its futures resolve to an empty-bodied `404 Not Found` response.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultNotFoundService;
232
/// Future returned by [`DefaultNotFoundService`].
///
/// It completes on its first poll with an empty-bodied `404 Not Found` response. `Debug` is
/// derived to match the sibling `DefaultNotFoundService` and so that it can appear in
/// diagnostics like any other public type.
#[derive(Debug)]
pub struct DefaultNotFoundFuture;
235
236impl<T> Service<T> for DefaultNotFoundService {
237 type Error = Infallible;
238 type Future = DefaultNotFoundFuture;
239 type Response = Response<http_body_util::Empty<Bytes>>;
240
241 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
242 Poll::Ready(Ok(()))
243 }
244
245 fn call(&mut self, _: T) -> Self::Future {
246 DefaultNotFoundFuture
247 }
248}
249
250impl std::future::Future for DefaultNotFoundFuture {
251 type Output = Result<Response<http_body_util::Empty<Bytes>>, Infallible>;
252
253 fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
254 let mut resp = Response::new(http_body_util::Empty::new());
255 *resp.status_mut() = StatusCode::NOT_FOUND;
256 Poll::Ready(Ok(resp))
257 }
258}
259
#[cfg(test)]
mod tests {
    use http::Request;
    use http_body_util::{BodyExt, Empty, Full};
    use tower::ServiceExt;

    use super::*;

    /// Body returned by the stand-in wrapped service.
    const BODY: &str = "test string";

    /// Builds a bodiless request for `url`.
    fn request(url: &str) -> Request<Empty<Bytes>> {
        Request::builder().uri(url).body(Empty::new()).unwrap()
    }

    /// Stand-in for the wrapped service: always answers `200 OK` with [`BODY`].
    async fn ok_handler(_: Request<Empty<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
        Ok(Response::new(Full::new(Bytes::from(BODY))))
    }

    /// Layer under test: hides exactly `/example.html`.
    fn example_layer() -> HidePathsLayer {
        HidePathsLayer::builder()
            .hide("/example.html")
            .build()
            .expect("a single static route always inserts")
    }

    #[tokio::test]
    async fn path_hidden() {
        let svc = tower::ServiceBuilder::new()
            .layer(example_layer())
            .service_fn(ok_handler);
        let response = svc.oneshot(request("/example.html")).await.unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        assert!(bytes.is_empty());
    }

    #[tokio::test]
    async fn path_not_hidden() {
        let svc = tower::ServiceBuilder::new()
            .layer(example_layer())
            .service_fn(ok_handler);
        let response = svc.oneshot(request("/example.htmlb")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(bytes, BODY);
    }

    #[test]
    fn conflicting_route_is_reported() {
        // Inserting the same route twice makes the router reject the second insert.
        let builder = HidePathsLayer::builder()
            .hide("/example.html")
            .hide("/example.html");
        assert_eq!(builder.errors().len(), 1);
        assert_eq!(builder.errors()[0].0, "/example.html");
        let err = builder
            .build()
            .err()
            .expect("a conflicting route must fail the build");
        assert_eq!(err.0.len(), 1);
    }
}