Skip to main content

servable/
router.rs

1use axum::{
2	Router,
3	body::Body,
4	http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header},
5	response::{IntoResponse, Response},
6};
7use chrono::TimeDelta;
8use std::{
9	collections::{BTreeMap, HashMap},
10	convert::Infallible,
11	net::SocketAddr,
12	pin::Pin,
13	sync::Arc,
14	task::{Context, Poll},
15	time::Instant,
16};
17use tower::Service;
18use tracing::trace;
19
20use crate::{
21	ClientInfo, RenderContext, Rendered, RenderedBody,
22	servable::{Servable, ServableWithRoute},
23};
24
25struct Default404 {}
26
27impl Servable for Default404 {
28	fn head<'a>(
29		&'a self,
30		_ctx: &'a RenderContext,
31	) -> Pin<Box<dyn Future<Output = Rendered<()>> + 'a + Send + Sync>> {
32		Box::pin(async {
33			return Rendered {
34				code: StatusCode::NOT_FOUND,
35				body: (),
36				ttl: Some(TimeDelta::days(1)),
37				headers: HeaderMap::new(),
38				mime: Some(mime::TEXT_HTML),
39				private: false,
40			};
41		})
42	}
43
44	fn render<'a>(
45		&'a self,
46		ctx: &'a RenderContext,
47	) -> Pin<Box<dyn Future<Output = Rendered<RenderedBody>> + 'a + Send + Sync>> {
48		Box::pin(async { self.head(ctx).await.with_body(RenderedBody::Empty) })
49	}
50}
51
52/// A set of related [Servable]s under one route.
53///
54/// Use as follows:
55/// ```rust
56/// use servable::{ServableRouter, StaticAsset};
57/// use axum::Router;
58/// use tower_http::compression::{CompressionLayer, predicate::DefaultPredicate};
59///
60/// // Add compression, for example.
61/// // Also consider CORS and timeout.
62/// let compression: CompressionLayer = CompressionLayer::new()
63/// 	.br(true)
64/// 	.deflate(true)
65/// 	.gzip(true)
66/// 	.zstd(true)
67/// 	.compress_when(DefaultPredicate::new());
68///
69/// let route = ServableRouter::new()
70/// 	.add_page(
71/// 		"/page",
72/// 		StaticAsset {
73/// 			bytes: "I am a page".as_bytes(),
74/// 			mime: mime::TEXT_PLAIN,
75/// 			ttl: StaticAsset::DEFAULT_TTL
76/// 		},
77/// 	);
78///
79/// let router: Router<()> = route
80/// 	.into_router()
81/// 	.layer(compression.clone());
82/// ```
83#[derive(Clone)]
84pub struct ServableRouter {
85	pages: Arc<HashMap<String, Arc<dyn Servable>>>,
86	notfound: Arc<dyn Servable>,
87}
88
89impl ServableRouter {
90	/// Create a new, empty [ServableRouter]
91	#[inline(always)]
92	pub fn new() -> Self {
93		Self {
94			pages: Arc::new(HashMap::new()),
95			notfound: Arc::new(Default404 {}),
96		}
97	}
98
99	/// Set this server's "not found" page
100	#[inline(always)]
101	pub fn with_404<S: Servable + 'static>(mut self, page: S) -> Self {
102		self.notfound = Arc::new(page);
103		self
104	}
105
106	/// Add a [Servable] to this server at the given route.
107	/// - panics if route does not start with a `/`, ends with a `/`, or contains `//`.
108	///   - urls are normalized, routes that violate this condition will never be served.
109	///   - `/` is an exception, it is valid.
110	/// - panics if called after this service is started
111	/// - overwrites existing pages
112	#[inline(always)]
113	pub fn add_page<S: Servable + 'static>(mut self, route: impl Into<String>, page: S) -> Self {
114		let route = route.into();
115
116		if !route.starts_with("/") {
117			panic!("route must start with /")
118		};
119
120		if route.ends_with("/") && route != "/" {
121			panic!("route must not end with /")
122		};
123
124		if route.contains("//") {
125			panic!("route must not contain //")
126		};
127
128		#[expect(clippy::expect_used)]
129		Arc::get_mut(&mut self.pages)
130			.expect("add_pages called after service was started")
131			.insert(route, Arc::new(page));
132
133		self
134	}
135
136	/// Add a [ServableWithRoute] to this server.
137	/// Behaves exactly like [Self::add_page].
138	#[inline(always)]
139	pub fn add_page_with_route<S: Servable + 'static>(
140		self,
141		servable_with_route: &'static ServableWithRoute<S>,
142	) -> Self {
143		self.add_page(servable_with_route.route(), servable_with_route)
144	}
145
146	/// Convenience method.
147	/// Turns this service into a router.
148	///
149	/// Equivalent to:
150	/// ```ignore
151	/// Router::new().fallback_service(self)
152	/// ```
153	#[inline(always)]
154	pub fn into_router<T: Clone + Send + Sync + 'static>(self) -> Router<T> {
155		Router::new().fallback_service(self)
156	}
157}
158
159//
160// MARK: impl Service
161//
162
163impl Service<Request<Body>> for ServableRouter {
164	type Response = Response;
165	type Error = Infallible;
166	type Future =
167		Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
168
169	fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170		Poll::Ready(Ok(()))
171	}
172
173	fn call(&mut self, req: Request<Body>) -> Self::Future {
174		if req.method() != Method::GET && req.method() != Method::HEAD {
175			let mut headers = HeaderMap::with_capacity(1);
176			headers.insert(header::ACCEPT, HeaderValue::from_static("GET,HEAD"));
177			return Box::pin(async {
178				Ok((StatusCode::METHOD_NOT_ALLOWED, headers).into_response())
179			});
180		}
181
182		let pages = self.pages.clone();
183		let notfound = self.notfound.clone();
184		Box::pin(async move {
185			let addr = req.extensions().get::<SocketAddr>().copied();
186			let route = req.uri().path().to_owned();
187			let headers = req.headers().clone();
188			let query: BTreeMap<String, String> =
189				serde_urlencoded::from_str(req.uri().query().unwrap_or("")).unwrap_or_default();
190
191			let start = Instant::now();
192			let client_info = ClientInfo::from_headers(&headers);
193			let ua = headers
194				.get("user-agent")
195				.and_then(|x| x.to_str().ok())
196				.unwrap_or("");
197
198			trace!(
199				message = "Serving route",
200				route,
201				addr = ?addr,
202				user_agent = ua,
203				device_type = ?client_info.device_type
204			);
205
206			// Normalize url with redirect
207			if (route.ends_with('/') && route != "/") || route.contains("//") {
208				let mut new_route = route.clone();
209				while new_route.contains("//") {
210					new_route = new_route.replace("//", "/");
211				}
212				let new_route = new_route.trim_matches('/');
213
214				trace!(
215					message = "Redirecting",
216					route,
217					new_route,
218					addr = ?addr,
219					user_agent = ua,
220					device_type = ?client_info.device_type
221				);
222
223				let mut headers = HeaderMap::with_capacity(1);
224				match HeaderValue::from_str(&format!("/{new_route}")) {
225					Ok(x) => headers.append(header::LOCATION, x),
226					Err(_) => return Ok(StatusCode::BAD_REQUEST.into_response()),
227				};
228				return Ok((StatusCode::PERMANENT_REDIRECT, headers).into_response());
229			}
230
231			let ctx = RenderContext {
232				client_info,
233				route,
234				query,
235			};
236
237			let page = pages.get(&ctx.route).unwrap_or(&notfound);
238			let mut rend = match req.method() == Method::HEAD {
239				true => page.head(&ctx).await.with_body(RenderedBody::Empty),
240				false => page.render(&ctx).await,
241			};
242
243			// Tweak headers
244			{
245				if !rend.headers.contains_key(header::CACHE_CONTROL) {
246					let max_age = rend.ttl.map(|x| x.num_seconds()).unwrap_or(0).max(0);
247
248					let mut value = String::new();
249
250					value.push_str(match rend.private {
251						true => "private, ",
252						false => "public, ",
253					});
254
255					value.push_str(&format!("max-age={}, ", max_age));
256
257					#[expect(clippy::unwrap_used)]
258					rend.headers.insert(
259						header::CACHE_CONTROL,
260						HeaderValue::from_str(value.trim().trim_end_matches(',')).unwrap(),
261					);
262				}
263
264				if !rend.headers.contains_key("Accept-CH") {
265					rend.headers
266						.insert("Accept-CH", HeaderValue::from_static("Sec-CH-UA-Mobile"));
267				}
268
269				if !rend.headers.contains_key(header::CONTENT_TYPE)
270					&& let Some(mime) = &rend.mime
271				{
272					#[expect(clippy::unwrap_used)]
273					rend.headers.insert(
274						header::CONTENT_TYPE,
275						HeaderValue::from_str(mime.as_ref()).unwrap(),
276					);
277				}
278			}
279
280			trace!(
281				message = "Served route",
282				route = ctx.route,
283				addr = ?addr,
284				user_agent = ua,
285				device_type = ?client_info.device_type,
286				time_ns = start.elapsed().as_nanos()
287			);
288
289			Ok(match rend.body {
290				RenderedBody::Static(d) => (rend.code, rend.headers, d).into_response(),
291				RenderedBody::Bytes(d) => (rend.code, rend.headers, d).into_response(),
292				RenderedBody::String(s) => (rend.code, rend.headers, s).into_response(),
293				RenderedBody::Empty => (rend.code, rend.headers).into_response(),
294			})
295		})
296	}
297}