servable/
router.rs

1use axum::{
2	Router,
3	body::Body,
4	http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header},
5	response::{IntoResponse, Response},
6};
7use chrono::TimeDelta;
8use std::{
9	collections::{BTreeMap, HashMap},
10	convert::Infallible,
11	net::SocketAddr,
12	pin::Pin,
13	sync::Arc,
14	task::{Context, Poll},
15	time::Instant,
16};
17use tower::Service;
18use tracing::trace;
19
20use crate::{
21	ClientInfo, RenderContext, Rendered, RenderedBody,
22	mime::MimeType,
23	servable::{Servable, ServableWithRoute},
24};
25
26struct Default404 {}
27
28impl Servable for Default404 {
29	fn head<'a>(
30		&'a self,
31		_ctx: &'a RenderContext,
32	) -> Pin<Box<dyn Future<Output = Rendered<()>> + 'a + Send + Sync>> {
33		Box::pin(async {
34			return Rendered {
35				code: StatusCode::NOT_FOUND,
36				body: (),
37				ttl: Some(TimeDelta::days(1)),
38				headers: HeaderMap::new(),
39				mime: Some(MimeType::Html),
40				private: false,
41			};
42		})
43	}
44
45	fn render<'a>(
46		&'a self,
47		ctx: &'a RenderContext,
48	) -> Pin<Box<dyn Future<Output = Rendered<RenderedBody>> + 'a + Send + Sync>> {
49		Box::pin(async { self.head(ctx).await.with_body(RenderedBody::Empty) })
50	}
51}
52
53/// A set of related [Servable]s under one route.
54///
55/// Use as follows:
56/// ```rust
57/// use servable::{ServableRouter, StaticAsset, mime::MimeType};
58/// use axum::Router;
59/// use tower_http::compression::{CompressionLayer, predicate::DefaultPredicate};
60///
61/// // Add compression, for example.
62/// // Also consider CORS and timeout.
63/// let compression: CompressionLayer = CompressionLayer::new()
64/// 	.br(true)
65/// 	.deflate(true)
66/// 	.gzip(true)
67/// 	.zstd(true)
68/// 	.compress_when(DefaultPredicate::new());
69///
70/// let route = ServableRouter::new()
71/// 	.add_page(
72/// 		"/page",
73/// 		StaticAsset {
74/// 			bytes: "I am a page".as_bytes(),
75/// 			mime: MimeType::Text,
76/// 			ttl: StaticAsset::DEFAULT_TTL
77/// 		},
78/// 	);
79///
80/// let router: Router<()> = route
81/// 	.into_router()
82/// 	.layer(compression.clone());
83/// ```
84#[derive(Clone)]
85pub struct ServableRouter {
86	pages: Arc<HashMap<String, Arc<dyn Servable>>>,
87	notfound: Arc<dyn Servable>,
88}
89
90impl ServableRouter {
91	/// Create a new, empty [ServableRouter]
92	#[inline(always)]
93	pub fn new() -> Self {
94		Self {
95			pages: Arc::new(HashMap::new()),
96			notfound: Arc::new(Default404 {}),
97		}
98	}
99
100	/// Set this server's "not found" page
101	#[inline(always)]
102	pub fn with_404<S: Servable + 'static>(mut self, page: S) -> Self {
103		self.notfound = Arc::new(page);
104		self
105	}
106
107	/// Add a [Servable] to this server at the given route.
108	/// - panics if route does not start with a `/`, ends with a `/`, or contains `//`.
109	///   - urls are normalized, routes that violate this condition will never be served.
110	///   - `/` is an exception, it is valid.
111	/// - panics if called after this service is started
112	/// - overwrites existing pages
113	#[inline(always)]
114	pub fn add_page<S: Servable + 'static>(mut self, route: impl Into<String>, page: S) -> Self {
115		let route = route.into();
116
117		if !route.starts_with("/") {
118			panic!("route must start with /")
119		};
120
121		if route.ends_with("/") && route != "/" {
122			panic!("route must not end with /")
123		};
124
125		if route.contains("//") {
126			panic!("route must not contain //")
127		};
128
129		#[expect(clippy::expect_used)]
130		Arc::get_mut(&mut self.pages)
131			.expect("add_pages called after service was started")
132			.insert(route, Arc::new(page));
133
134		self
135	}
136
137	/// Add a [ServableWithRoute] to this server.
138	/// Behaves exactly like [Self::add_page].
139	#[inline(always)]
140	pub fn add_page_with_route<S: Servable + 'static>(
141		self,
142		servable_with_route: &'static ServableWithRoute<S>,
143	) -> Self {
144		self.add_page(servable_with_route.route(), servable_with_route)
145	}
146
147	/// Convenience method.
148	/// Turns this service into a router.
149	///
150	/// Equivalent to:
151	/// ```ignore
152	/// Router::new().fallback_service(self)
153	/// ```
154	#[inline(always)]
155	pub fn into_router<T: Clone + Send + Sync + 'static>(self) -> Router<T> {
156		Router::new().fallback_service(self)
157	}
158}
159
160//
161// MARK: impl Service
162//
163
164impl Service<Request<Body>> for ServableRouter {
165	type Response = Response;
166	type Error = Infallible;
167	type Future =
168		Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
169
170	fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
171		Poll::Ready(Ok(()))
172	}
173
174	fn call(&mut self, req: Request<Body>) -> Self::Future {
175		if req.method() != Method::GET && req.method() != Method::HEAD {
176			let mut headers = HeaderMap::with_capacity(1);
177			headers.insert(header::ACCEPT, HeaderValue::from_static("GET,HEAD"));
178			return Box::pin(async {
179				Ok((StatusCode::METHOD_NOT_ALLOWED, headers).into_response())
180			});
181		}
182
183		let pages = self.pages.clone();
184		let notfound = self.notfound.clone();
185		Box::pin(async move {
186			let addr = req.extensions().get::<SocketAddr>().copied();
187			let route = req.uri().path().to_owned();
188			let headers = req.headers().clone();
189			let query: BTreeMap<String, String> =
190				serde_urlencoded::from_str(req.uri().query().unwrap_or("")).unwrap_or_default();
191
192			let start = Instant::now();
193			let client_info = ClientInfo::from_headers(&headers);
194			let ua = headers
195				.get("user-agent")
196				.and_then(|x| x.to_str().ok())
197				.unwrap_or("");
198
199			trace!(
200				message = "Serving route",
201				route,
202				addr = ?addr,
203				user_agent = ua,
204				device_type = ?client_info.device_type
205			);
206
207			// Normalize url with redirect
208			if (route.ends_with('/') && route != "/") || route.contains("//") {
209				let mut new_route = route.clone();
210				while new_route.contains("//") {
211					new_route = new_route.replace("//", "/");
212				}
213				let new_route = new_route.trim_matches('/');
214
215				trace!(
216					message = "Redirecting",
217					route,
218					new_route,
219					addr = ?addr,
220					user_agent = ua,
221					device_type = ?client_info.device_type
222				);
223
224				let mut headers = HeaderMap::with_capacity(1);
225				match HeaderValue::from_str(&format!("/{new_route}")) {
226					Ok(x) => headers.append(header::LOCATION, x),
227					Err(_) => return Ok(StatusCode::BAD_REQUEST.into_response()),
228				};
229				return Ok((StatusCode::PERMANENT_REDIRECT, headers).into_response());
230			}
231
232			let ctx = RenderContext {
233				client_info,
234				route,
235				query,
236			};
237
238			let page = pages.get(&ctx.route).unwrap_or(&notfound);
239			let mut rend = match req.method() == Method::HEAD {
240				true => page.head(&ctx).await.with_body(RenderedBody::Empty),
241				false => page.render(&ctx).await,
242			};
243
244			// Tweak headers
245			{
246				if !rend.headers.contains_key(header::CACHE_CONTROL) {
247					let max_age = rend.ttl.map(|x| x.num_seconds()).unwrap_or(0).max(0);
248
249					let mut value = String::new();
250
251					value.push_str(match rend.private {
252						true => "private, ",
253						false => "public, ",
254					});
255
256					value.push_str(&format!("max-age={}, ", max_age));
257
258					#[expect(clippy::unwrap_used)]
259					rend.headers.insert(
260						header::CACHE_CONTROL,
261						HeaderValue::from_str(value.trim().trim_end_matches(',')).unwrap(),
262					);
263				}
264
265				if !rend.headers.contains_key("Accept-CH") {
266					rend.headers
267						.insert("Accept-CH", HeaderValue::from_static("Sec-CH-UA-Mobile"));
268				}
269
270				if !rend.headers.contains_key(header::CONTENT_TYPE)
271					&& let Some(mime) = &rend.mime
272				{
273					#[expect(clippy::unwrap_used)]
274					rend.headers.insert(
275						header::CONTENT_TYPE,
276						HeaderValue::from_str(&mime.to_string()).unwrap(),
277					);
278				}
279			}
280
281			trace!(
282				message = "Served route",
283				route = ctx.route,
284				addr = ?addr,
285				user_agent = ua,
286				device_type = ?client_info.device_type,
287				time_ns = start.elapsed().as_nanos()
288			);
289
290			Ok(match rend.body {
291				RenderedBody::Static(d) => (rend.code, rend.headers, d).into_response(),
292				RenderedBody::Bytes(d) => (rend.code, rend.headers, d).into_response(),
293				RenderedBody::String(s) => (rend.code, rend.headers, s).into_response(),
294				RenderedBody::Empty => (rend.code, rend.headers).into_response(),
295			})
296		})
297	}
298}