1use axum::{
2 Router,
3 body::Body,
4 http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header},
5 response::{IntoResponse, Response},
6};
7use chrono::TimeDelta;
8use std::{
9 collections::{BTreeMap, HashMap},
10 convert::Infallible,
11 net::SocketAddr,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15 time::Instant,
16};
17use tower::Service;
18use tracing::trace;
19
20use crate::{
21 ClientInfo, RenderContext, Rendered, RenderedBody,
22 mime::MimeType,
23 servable::{Servable, ServableWithRoute},
24};
25
26struct Default404 {}
27
28impl Servable for Default404 {
29 fn head<'a>(
30 &'a self,
31 _ctx: &'a RenderContext,
32 ) -> Pin<Box<dyn Future<Output = Rendered<()>> + 'a + Send + Sync>> {
33 Box::pin(async {
34 return Rendered {
35 code: StatusCode::NOT_FOUND,
36 body: (),
37 ttl: Some(TimeDelta::days(1)),
38 headers: HeaderMap::new(),
39 mime: Some(MimeType::Html),
40 private: false,
41 };
42 })
43 }
44
45 fn render<'a>(
46 &'a self,
47 ctx: &'a RenderContext,
48 ) -> Pin<Box<dyn Future<Output = Rendered<RenderedBody>> + 'a + Send + Sync>> {
49 Box::pin(async { self.head(ctx).await.with_body(RenderedBody::Empty) })
50 }
51}
52
53#[derive(Clone)]
85pub struct ServableRouter {
86 pages: Arc<HashMap<String, Arc<dyn Servable>>>,
87 notfound: Arc<dyn Servable>,
88}
89
90impl ServableRouter {
91 #[inline(always)]
93 pub fn new() -> Self {
94 Self {
95 pages: Arc::new(HashMap::new()),
96 notfound: Arc::new(Default404 {}),
97 }
98 }
99
100 #[inline(always)]
102 pub fn with_404<S: Servable + 'static>(mut self, page: S) -> Self {
103 self.notfound = Arc::new(page);
104 self
105 }
106
107 #[inline(always)]
114 pub fn add_page<S: Servable + 'static>(mut self, route: impl Into<String>, page: S) -> Self {
115 let route = route.into();
116
117 if !route.starts_with("/") {
118 panic!("route must start with /")
119 };
120
121 if route.ends_with("/") && route != "/" {
122 panic!("route must not end with /")
123 };
124
125 if route.contains("//") {
126 panic!("route must not contain //")
127 };
128
129 #[expect(clippy::expect_used)]
130 Arc::get_mut(&mut self.pages)
131 .expect("add_pages called after service was started")
132 .insert(route, Arc::new(page));
133
134 self
135 }
136
137 #[inline(always)]
140 pub fn add_page_with_route<S: Servable + 'static>(
141 self,
142 servable_with_route: &'static ServableWithRoute<S>,
143 ) -> Self {
144 self.add_page(servable_with_route.route(), servable_with_route)
145 }
146
147 #[inline(always)]
155 pub fn into_router<T: Clone + Send + Sync + 'static>(self) -> Router<T> {
156 Router::new().fallback_service(self)
157 }
158}
159
160impl Service<Request<Body>> for ServableRouter {
165 type Response = Response;
166 type Error = Infallible;
167 type Future =
168 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
169
170 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
171 Poll::Ready(Ok(()))
172 }
173
174 fn call(&mut self, req: Request<Body>) -> Self::Future {
175 if req.method() != Method::GET && req.method() != Method::HEAD {
176 let mut headers = HeaderMap::with_capacity(1);
177 headers.insert(header::ACCEPT, HeaderValue::from_static("GET,HEAD"));
178 return Box::pin(async {
179 Ok((StatusCode::METHOD_NOT_ALLOWED, headers).into_response())
180 });
181 }
182
183 let pages = self.pages.clone();
184 let notfound = self.notfound.clone();
185 Box::pin(async move {
186 let addr = req.extensions().get::<SocketAddr>().copied();
187 let route = req.uri().path().to_owned();
188 let headers = req.headers().clone();
189 let query: BTreeMap<String, String> =
190 serde_urlencoded::from_str(req.uri().query().unwrap_or("")).unwrap_or_default();
191
192 let start = Instant::now();
193 let client_info = ClientInfo::from_headers(&headers);
194 let ua = headers
195 .get("user-agent")
196 .and_then(|x| x.to_str().ok())
197 .unwrap_or("");
198
199 trace!(
200 message = "Serving route",
201 route,
202 addr = ?addr,
203 user_agent = ua,
204 device_type = ?client_info.device_type
205 );
206
207 if (route.ends_with('/') && route != "/") || route.contains("//") {
209 let mut new_route = route.clone();
210 while new_route.contains("//") {
211 new_route = new_route.replace("//", "/");
212 }
213 let new_route = new_route.trim_matches('/');
214
215 trace!(
216 message = "Redirecting",
217 route,
218 new_route,
219 addr = ?addr,
220 user_agent = ua,
221 device_type = ?client_info.device_type
222 );
223
224 let mut headers = HeaderMap::with_capacity(1);
225 match HeaderValue::from_str(&format!("/{new_route}")) {
226 Ok(x) => headers.append(header::LOCATION, x),
227 Err(_) => return Ok(StatusCode::BAD_REQUEST.into_response()),
228 };
229 return Ok((StatusCode::PERMANENT_REDIRECT, headers).into_response());
230 }
231
232 let ctx = RenderContext {
233 client_info,
234 route,
235 query,
236 };
237
238 let page = pages.get(&ctx.route).unwrap_or(¬found);
239 let mut rend = match req.method() == Method::HEAD {
240 true => page.head(&ctx).await.with_body(RenderedBody::Empty),
241 false => page.render(&ctx).await,
242 };
243
244 {
246 if !rend.headers.contains_key(header::CACHE_CONTROL) {
247 let max_age = rend.ttl.map(|x| x.num_seconds()).unwrap_or(0).max(0);
248
249 let mut value = String::new();
250
251 value.push_str(match rend.private {
252 true => "private, ",
253 false => "public, ",
254 });
255
256 value.push_str(&format!("max-age={}, ", max_age));
257
258 #[expect(clippy::unwrap_used)]
259 rend.headers.insert(
260 header::CACHE_CONTROL,
261 HeaderValue::from_str(value.trim().trim_end_matches(',')).unwrap(),
262 );
263 }
264
265 if !rend.headers.contains_key("Accept-CH") {
266 rend.headers
267 .insert("Accept-CH", HeaderValue::from_static("Sec-CH-UA-Mobile"));
268 }
269
270 if !rend.headers.contains_key(header::CONTENT_TYPE)
271 && let Some(mime) = &rend.mime
272 {
273 #[expect(clippy::unwrap_used)]
274 rend.headers.insert(
275 header::CONTENT_TYPE,
276 HeaderValue::from_str(&mime.to_string()).unwrap(),
277 );
278 }
279 }
280
281 trace!(
282 message = "Served route",
283 route = ctx.route,
284 addr = ?addr,
285 user_agent = ua,
286 device_type = ?client_info.device_type,
287 time_ns = start.elapsed().as_nanos()
288 );
289
290 Ok(match rend.body {
291 RenderedBody::Static(d) => (rend.code, rend.headers, d).into_response(),
292 RenderedBody::Bytes(d) => (rend.code, rend.headers, d).into_response(),
293 RenderedBody::String(s) => (rend.code, rend.headers, s).into_response(),
294 RenderedBody::Empty => (rend.code, rend.headers).into_response(),
295 })
296 })
297 }
298}