1use axum::{
2 Router,
3 body::Body,
4 http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header},
5 response::{IntoResponse, Response},
6};
7use chrono::TimeDelta;
8use std::{
9 collections::{BTreeMap, HashMap},
10 convert::Infallible,
11 net::SocketAddr,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15 time::Instant,
16};
17use tower::Service;
18use tracing::trace;
19
20use crate::{
21 ClientInfo, RenderContext, Rendered, RenderedBody,
22 servable::{Servable, ServableWithRoute},
23};
24
25struct Default404 {}
26
27impl Servable for Default404 {
28 fn head<'a>(
29 &'a self,
30 _ctx: &'a RenderContext,
31 ) -> Pin<Box<dyn Future<Output = Rendered<()>> + 'a + Send + Sync>> {
32 Box::pin(async {
33 return Rendered {
34 code: StatusCode::NOT_FOUND,
35 body: (),
36 ttl: Some(TimeDelta::days(1)),
37 headers: HeaderMap::new(),
38 mime: Some(mime::TEXT_HTML),
39 private: false,
40 };
41 })
42 }
43
44 fn render<'a>(
45 &'a self,
46 ctx: &'a RenderContext,
47 ) -> Pin<Box<dyn Future<Output = Rendered<RenderedBody>> + 'a + Send + Sync>> {
48 Box::pin(async { self.head(ctx).await.with_body(RenderedBody::Empty) })
49 }
50}
51
52#[derive(Clone)]
84pub struct ServableRouter {
85 pages: Arc<HashMap<String, Arc<dyn Servable>>>,
86 notfound: Arc<dyn Servable>,
87}
88
89impl ServableRouter {
90 #[inline(always)]
92 pub fn new() -> Self {
93 Self {
94 pages: Arc::new(HashMap::new()),
95 notfound: Arc::new(Default404 {}),
96 }
97 }
98
99 #[inline(always)]
101 pub fn with_404<S: Servable + 'static>(mut self, page: S) -> Self {
102 self.notfound = Arc::new(page);
103 self
104 }
105
106 #[inline(always)]
113 pub fn add_page<S: Servable + 'static>(mut self, route: impl Into<String>, page: S) -> Self {
114 let route = route.into();
115
116 if !route.starts_with("/") {
117 panic!("route must start with /")
118 };
119
120 if route.ends_with("/") && route != "/" {
121 panic!("route must not end with /")
122 };
123
124 if route.contains("//") {
125 panic!("route must not contain //")
126 };
127
128 #[expect(clippy::expect_used)]
129 Arc::get_mut(&mut self.pages)
130 .expect("add_pages called after service was started")
131 .insert(route, Arc::new(page));
132
133 self
134 }
135
136 #[inline(always)]
139 pub fn add_page_with_route<S: Servable + 'static>(
140 self,
141 servable_with_route: &'static ServableWithRoute<S>,
142 ) -> Self {
143 self.add_page(servable_with_route.route(), servable_with_route)
144 }
145
146 #[inline(always)]
154 pub fn into_router<T: Clone + Send + Sync + 'static>(self) -> Router<T> {
155 Router::new().fallback_service(self)
156 }
157}
158
159impl Service<Request<Body>> for ServableRouter {
164 type Response = Response;
165 type Error = Infallible;
166 type Future =
167 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
168
169 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170 Poll::Ready(Ok(()))
171 }
172
173 fn call(&mut self, req: Request<Body>) -> Self::Future {
174 if req.method() != Method::GET && req.method() != Method::HEAD {
175 let mut headers = HeaderMap::with_capacity(1);
176 headers.insert(header::ACCEPT, HeaderValue::from_static("GET,HEAD"));
177 return Box::pin(async {
178 Ok((StatusCode::METHOD_NOT_ALLOWED, headers).into_response())
179 });
180 }
181
182 let pages = self.pages.clone();
183 let notfound = self.notfound.clone();
184 Box::pin(async move {
185 let addr = req.extensions().get::<SocketAddr>().copied();
186 let route = req.uri().path().to_owned();
187 let headers = req.headers().clone();
188 let query: BTreeMap<String, String> =
189 serde_urlencoded::from_str(req.uri().query().unwrap_or("")).unwrap_or_default();
190
191 let start = Instant::now();
192 let client_info = ClientInfo::from_headers(&headers);
193 let ua = headers
194 .get("user-agent")
195 .and_then(|x| x.to_str().ok())
196 .unwrap_or("");
197
198 trace!(
199 message = "Serving route",
200 route,
201 addr = ?addr,
202 user_agent = ua,
203 device_type = ?client_info.device_type
204 );
205
206 if (route.ends_with('/') && route != "/") || route.contains("//") {
208 let mut new_route = route.clone();
209 while new_route.contains("//") {
210 new_route = new_route.replace("//", "/");
211 }
212 let new_route = new_route.trim_matches('/');
213
214 trace!(
215 message = "Redirecting",
216 route,
217 new_route,
218 addr = ?addr,
219 user_agent = ua,
220 device_type = ?client_info.device_type
221 );
222
223 let mut headers = HeaderMap::with_capacity(1);
224 match HeaderValue::from_str(&format!("/{new_route}")) {
225 Ok(x) => headers.append(header::LOCATION, x),
226 Err(_) => return Ok(StatusCode::BAD_REQUEST.into_response()),
227 };
228 return Ok((StatusCode::PERMANENT_REDIRECT, headers).into_response());
229 }
230
231 let ctx = RenderContext {
232 client_info,
233 route,
234 query,
235 };
236
237 let page = pages.get(&ctx.route).unwrap_or(¬found);
238 let mut rend = match req.method() == Method::HEAD {
239 true => page.head(&ctx).await.with_body(RenderedBody::Empty),
240 false => page.render(&ctx).await,
241 };
242
243 {
245 if !rend.headers.contains_key(header::CACHE_CONTROL) {
246 let max_age = rend.ttl.map(|x| x.num_seconds()).unwrap_or(0).max(0);
247
248 let mut value = String::new();
249
250 value.push_str(match rend.private {
251 true => "private, ",
252 false => "public, ",
253 });
254
255 value.push_str(&format!("max-age={}, ", max_age));
256
257 #[expect(clippy::unwrap_used)]
258 rend.headers.insert(
259 header::CACHE_CONTROL,
260 HeaderValue::from_str(value.trim().trim_end_matches(',')).unwrap(),
261 );
262 }
263
264 if !rend.headers.contains_key("Accept-CH") {
265 rend.headers
266 .insert("Accept-CH", HeaderValue::from_static("Sec-CH-UA-Mobile"));
267 }
268
269 if !rend.headers.contains_key(header::CONTENT_TYPE)
270 && let Some(mime) = &rend.mime
271 {
272 #[expect(clippy::unwrap_used)]
273 rend.headers.insert(
274 header::CONTENT_TYPE,
275 HeaderValue::from_str(mime.as_ref()).unwrap(),
276 );
277 }
278 }
279
280 trace!(
281 message = "Served route",
282 route = ctx.route,
283 addr = ?addr,
284 user_agent = ua,
285 device_type = ?client_info.device_type,
286 time_ns = start.elapsed().as_nanos()
287 );
288
289 Ok(match rend.body {
290 RenderedBody::Static(d) => (rend.code, rend.headers, d).into_response(),
291 RenderedBody::Bytes(d) => (rend.code, rend.headers, d).into_response(),
292 RenderedBody::String(s) => (rend.code, rend.headers, s).into_response(),
293 RenderedBody::Empty => (rend.code, rend.headers).into_response(),
294 })
295 })
296 }
297}