1use std::collections::HashSet;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use http::StatusCode;
15
16use crate::body::{body_from_bytes, body_from_string, BoxBody};
17use crate::extract::PathPrefixOffset;
18use crate::handler::BoxedHandler;
19
/// Route predicate. It receives the request path (after any mount prefix has been
/// stripped), split on `/` with empty segments removed, and reports whether the
/// route accepts it.
pub(crate) type MatchFn = Box<dyn Fn(&[&str]) -> bool + Send + Sync>;
/// Hook run on a matched request's extensions just before its handler is invoked.
type StateInjector = Arc<dyn Fn(&mut http::Extensions) + Send + Sync>;
/// Service that `Router::route` hands unmatched requests to, with their body still unread.
pub(crate) type FallbackService = Arc<
    dyn Fn(
            http::Request<hyper::body::Incoming>,
        ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>>
        + Send
        + Sync,
>;

/// Default upper bound, in bytes (2 MiB), on the request body that `Router::route` buffers.
pub const DEFAULT_MAX_BODY_SIZE: usize = 2 * 1024 * 1024;
32
/// Request router: a method-indexed route table plus an optional mount prefix,
/// fallback service and state injector.
///
/// All state sits behind a single `parking_lot::RwLock`. The registration and
/// configuration methods take the write lock through `&self`; dispatch takes the
/// read lock.
pub struct Router {
    inner: parking_lot::RwLock<RouterInner>,
}
45
/// Routing state guarded by `Router::inner`.
struct RouterInner {
    /// Every registered route, in registration order. The `usize` values stored in the
    /// per-method tries and fallback lists are indices into this vector.
    routes: Vec<RouteEntry>,
    /// Per-method lookup structures.
    method_index: MethodIndex,
    /// Method-agnostic trie holding one entry per distinct converted pattern. `resolve`
    /// uses it as a cheap pre-check before scanning every method bucket.
    any_method_trie: matchit::Router<()>,
    /// Patterns already offered to `any_method_trie`, so each is inserted at most once.
    any_method_seen: HashSet<String>,
    /// Set once some pattern could not be inserted into `any_method_trie`. After that,
    /// a trie miss is not conclusive and `resolve` always performs the full bucket scan.
    any_method_has_fallback: bool,
    /// Bitwise OR of `method_bit` over every method that has at least one route.
    methods_present: u8,
    /// Optional hook run on a matched request's extensions before its handler.
    state_injector: Option<StateInjector>,
    /// Optional service used by `Router::route` for requests that match no route or
    /// fall outside the mount prefix.
    fallback: Option<FallbackService>,
    /// Maximum request body size, in bytes, that `Router::route` will buffer.
    max_body_size: usize,
    /// Mount-prefix segments recorded by `set_prefix` (`None` when unmounted).
    prefix: Option<Vec<String>>,
    /// The same prefix joined as `/seg1/seg2`. Request paths are stripped of it before
    /// lookup; paths that do not start with it are treated as unmatched.
    prefix_str: Option<String>,
}
72
/// One registered route.
struct RouteEntry {
    /// Method the route was registered under.
    #[cfg_attr(not(feature = "grpc"), allow(dead_code))]
    method: http::Method,
    /// Pattern exactly as passed to `Router::add_route`, before matchit conversion.
    #[cfg_attr(not(feature = "grpc"), allow(dead_code))]
    pattern: String,
    /// Final say on whether a request path matches this route.
    match_fn: MatchFn,
    /// Handler invoked with the request parts and buffered body.
    handler: BoxedHandler,
}
82
/// Route lookup structures bucketed by HTTP method.
///
/// The seven common methods each have a dedicated bucket. Every other method
/// (for example TRACE, CONNECT or custom extension methods) shares `other`.
#[derive(Default)]
struct MethodIndex {
    get: MethodBucket,
    post: MethodBucket,
    put: MethodBucket,
    delete: MethodBucket,
    patch: MethodBucket,
    head: MethodBucket,
    options: MethodBucket,
    other: MethodBucket,
}
95
/// Routes for one method bucket.
#[derive(Default)]
struct MethodBucket {
    /// Converted pattern -> index into `RouterInner::routes`.
    trie: matchit::Router<usize>,
    /// Indices of routes whose pattern the trie rejected on insert. They are tried
    /// linearly, in registration order, after the trie.
    fallback: Vec<usize>,
}
103
104impl MethodIndex {
105 fn bucket(&self, method: &http::Method) -> &MethodBucket {
106 match *method {
107 http::Method::GET => &self.get,
108 http::Method::POST => &self.post,
109 http::Method::PUT => &self.put,
110 http::Method::DELETE => &self.delete,
111 http::Method::PATCH => &self.patch,
112 http::Method::HEAD => &self.head,
113 http::Method::OPTIONS => &self.options,
114 _ => &self.other,
115 }
116 }
117
118 fn bucket_mut(&mut self, method: &http::Method) -> &mut MethodBucket {
119 match *method {
120 http::Method::GET => &mut self.get,
121 http::Method::POST => &mut self.post,
122 http::Method::PUT => &mut self.put,
123 http::Method::DELETE => &mut self.delete,
124 http::Method::PATCH => &mut self.patch,
125 http::Method::HEAD => &mut self.head,
126 http::Method::OPTIONS => &mut self.options,
127 _ => &mut self.other,
128 }
129 }
130
131 fn all_buckets(&self) -> [&MethodBucket; 8] {
132 [
133 &self.get,
134 &self.post,
135 &self.put,
136 &self.delete,
137 &self.patch,
138 &self.head,
139 &self.options,
140 &self.other,
141 ]
142 }
143}
144
145fn method_bit(method: &http::Method) -> u8 {
147 match *method {
148 http::Method::GET => 1 << 0,
149 http::Method::POST => 1 << 1,
150 http::Method::PUT => 1 << 2,
151 http::Method::DELETE => 1 << 3,
152 http::Method::PATCH => 1 << 4,
153 http::Method::HEAD => 1 << 5,
154 http::Method::OPTIONS => 1 << 6,
155 _ => 1 << 7,
156 }
157}
158
/// Converts a route pattern into matchit's pattern syntax.
///
/// Anonymous `{}` placeholders become uniquely named `{p0}`, `{p1}`, ..., numbered
/// left to right. Named placeholders such as `{id}` or `{*rest}`, and all literal
/// text, are copied unchanged. An unterminated `{` and everything after it is also
/// copied verbatim.
///
/// The text is copied as `&str` slices cut at ASCII brace positions, which are always
/// char boundaries. This preserves non-ASCII UTF-8 text; converting each byte with
/// `as char` would turn e.g. "é" into "Ã©".
fn to_matchit_pattern(pat: &str) -> String {
    let mut out = String::with_capacity(pat.len() + 8);
    let mut counter: u32 = 0;
    let mut rest = pat;
    while let Some(open) = rest.find('{') {
        let after_open = &rest[open + 1..];
        let close = match after_open.find('}') {
            Some(close) => close,
            // Unterminated `{`: the remainder (including the brace) is copied below.
            None => break,
        };
        out.push_str(&rest[..open]);
        let inner = &after_open[..close];
        if inner.is_empty() {
            out.push_str(&format!("{{p{counter}}}"));
            counter += 1;
        } else {
            out.push('{');
            out.push_str(inner);
            out.push('}');
        }
        rest = &after_open[close + 1..];
    }
    out.push_str(rest);
    out
}
196
/// Computes the path used for route lookup.
///
/// With no prefix, `path` is returned as-is, except that an empty path becomes `/`.
/// With a prefix, the prefix must be followed by either the end of the path (which
/// maps to `/`) or a `/` boundary. Otherwise the request is outside the mount point
/// and `None` is returned.
fn strip_prefix<'a>(prefix_str: Option<&str>, path: &'a str) -> Option<&'a str> {
    let prefix = match prefix_str {
        None if path.is_empty() => return Some("/"),
        None => return Some(path),
        Some(prefix) => prefix,
    };
    let remainder = path.strip_prefix(prefix)?;
    match remainder.as_bytes().first() {
        None => Some("/"),
        Some(b'/') => Some(remainder),
        Some(_) => None,
    }
}
210
/// A request path whose `/`-separated segments are split on first use and then cached.
///
/// Lookups that never need to run a `MatchFn` never pay for the split.
struct LazySegments<'a> {
    /// Path being routed (already stripped of any mount prefix by the caller).
    path: &'a str,
    /// Non-empty segments, filled in by the first `get` call.
    cache: Option<smallvec::SmallVec<[&'a str; 8]>>,
}
221
222impl<'a> LazySegments<'a> {
223 fn new(path: &'a str) -> Self {
224 Self { path, cache: None }
225 }
226
227 fn get(&mut self) -> &[&'a str] {
228 self.cache
229 .get_or_insert_with(|| self.path.split('/').filter(|s| !s.is_empty()).collect())
230 .as_slice()
231 }
232}
233
234fn lookup_in_bucket(
240 bucket: &MethodBucket,
241 routes: &[RouteEntry],
242 lookup_path: &str,
243 segments: &mut LazySegments,
244) -> Option<usize> {
245 if let Ok(m) = bucket.trie.at(lookup_path) {
246 let idx = *m.value;
247 if (routes[idx].match_fn)(segments.get()) {
248 return Some(idx);
249 }
250 }
251 if bucket.fallback.is_empty() {
252 return None;
253 }
254 let segs = segments.get();
255 bucket
256 .fallback
257 .iter()
258 .copied()
259 .find(|&i| (routes[i].match_fn)(segs))
260}
261
262fn make_status_response(status: StatusCode, body: &'static [u8]) -> http::Response<BoxBody> {
263 let mut res = http::Response::new(body_from_bytes(bytes::Bytes::from_static(body)));
264 *res.status_mut() = status;
265 res
266}
267
268fn not_found_response() -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
269 Box::pin(std::future::ready(make_status_response(
270 StatusCode::NOT_FOUND,
271 b"Not Found",
272 )))
273}
274
275fn method_not_allowed_response() -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
276 Box::pin(std::future::ready(make_status_response(
277 StatusCode::METHOD_NOT_ALLOWED,
278 b"Method Not Allowed",
279 )))
280}
281
/// Result of resolving a request against the route table.
enum LookupOutcome {
    /// Matched route, as an index into `RouterInner::routes`.
    Hit(usize),
    /// Some route matches the path, but not under the request's method.
    MethodNotAllowed,
    /// No route matches the path under any method.
    NotFound,
}
288
289fn resolve(inner: &RouterInner, method: &http::Method, lookup_path: &str) -> LookupOutcome {
293 let mut segments = LazySegments::new(lookup_path);
294
295 let bucket = inner.method_index.bucket(method);
296 if let Some(i) = lookup_in_bucket(bucket, &inner.routes, lookup_path, &mut segments) {
297 return LookupOutcome::Hit(i);
298 }
299
300 let bit = method_bit(method);
304 if inner.methods_present == bit {
305 return LookupOutcome::NotFound;
306 }
307
308 let any_path_shape =
313 inner.any_method_has_fallback || inner.any_method_trie.at(lookup_path).is_ok();
314 let path_matched = any_path_shape
315 && inner
316 .method_index
317 .all_buckets()
318 .iter()
319 .any(|b| lookup_in_bucket(b, &inner.routes, lookup_path, &mut segments).is_some());
320
321 if path_matched {
322 LookupOutcome::MethodNotAllowed
323 } else {
324 LookupOutcome::NotFound
325 }
326}
327
impl Router {
    /// Creates an empty router with no prefix, fallback or state injector, and a body
    /// limit of `DEFAULT_MAX_BODY_SIZE`.
    pub fn new() -> Self {
        Router {
            inner: parking_lot::RwLock::new(RouterInner {
                routes: Vec::new(),
                method_index: MethodIndex::default(),
                any_method_trie: matchit::Router::new(),
                any_method_seen: HashSet::new(),
                any_method_has_fallback: false,
                methods_present: 0,
                state_injector: None,
                fallback: None,
                max_body_size: DEFAULT_MAX_BODY_SIZE,
                prefix: None,
                prefix_str: None,
            }),
        }
    }

    /// Mounts the router under `prefix`.
    ///
    /// The prefix is normalised by splitting on `/` and dropping empty segments, then
    /// stored both as segments and re-joined as `/seg1/seg2`. A prefix with no segments
    /// (e.g. `""` or `"/"`) is a no-op; in particular, it leaves any previously set
    /// prefix in place.
    pub(crate) fn set_prefix(&self, prefix: &str) {
        let segments: Vec<String> = prefix
            .split('/')
            .filter(|s| !s.is_empty())
            .map(|s| s.to_string())
            .collect();
        if segments.is_empty() {
            return;
        }
        let mut joined = String::with_capacity(segments.iter().map(|s| s.len() + 1).sum());
        for seg in &segments {
            joined.push('/');
            joined.push_str(seg);
        }
        let mut inner = self.inner.write();
        inner.prefix = Some(segments);
        inner.prefix_str = Some(joined);
    }

    /// Sets the maximum request body size, in bytes, that `route` will buffer.
    pub(crate) fn set_max_body_size(&self, max: usize) {
        self.inner.write().max_body_size = max;
    }

    /// Registers a route for `method`.
    ///
    /// The pattern is converted with `to_matchit_pattern` and inserted into the method
    /// bucket's trie. If matchit rejects the insert, the route's index goes to the
    /// bucket's `fallback` list instead. The route index is the position the entry will
    /// take in `routes`.
    pub(crate) fn add_route(
        &self,
        method: http::Method,
        pattern: String,
        match_fn: MatchFn,
        handler: BoxedHandler,
    ) {
        let mut inner = self.inner.write();
        // Index the entry will have once pushed at the end of this function.
        let idx = inner.routes.len();

        inner.methods_present |= method_bit(&method);

        let matchit_pattern = to_matchit_pattern(&pattern);
        let bucket = inner.method_index.bucket_mut(&method);
        if bucket.trie.insert(matchit_pattern.clone(), idx).is_err() {
            // The trie refused this pattern; keep the route reachable via linear scan.
            bucket.fallback.push(idx);
        }

        // Offer each distinct pattern to the method-agnostic trie once. If it is
        // refused, a miss in that trie is no longer conclusive for `resolve`.
        if inner.any_method_seen.insert(matchit_pattern.clone())
            && inner.any_method_trie.insert(matchit_pattern, ()).is_err()
        {
            inner.any_method_has_fallback = true;
        }

        inner.routes.push(RouteEntry {
            method,
            pattern,
            match_fn,
            handler,
        });
    }

    /// Installs the hook run on every matched request's extensions before its handler.
    pub(crate) fn set_state_injector(
        &self,
        injector: Arc<dyn Fn(&mut http::Extensions) + Send + Sync>,
    ) {
        self.inner.write().state_injector = Some(injector);
    }

    /// Installs the service used by `route` for unmatched requests.
    pub(crate) fn set_fallback(&self, fallback: FallbackService) {
        self.inner.write().fallback = Some(fallback);
    }

    /// Returns the handler of the first route registered with exactly this method and
    /// original (unconverted) pattern string.
    #[cfg(feature = "grpc")]
    pub(crate) fn find_handler_by_pattern(
        &self,
        method: &http::Method,
        pattern: &str,
    ) -> Option<BoxedHandler> {
        let inner = self.inner.read();
        inner
            .routes
            .iter()
            .find(|e| e.method == *method && e.pattern == pattern)
            .map(|e| e.handler.clone())
    }

    /// Returns a clone of the configured state injector, if any.
    #[cfg(feature = "grpc")]
    pub(crate) fn state_injector(&self) -> Option<StateInjector> {
        self.inner.read().state_injector.clone()
    }

    /// Dispatches a request that still carries its streaming body.
    ///
    /// * Path outside the mount prefix: goes to the fallback service, or gets a 404.
    /// * Match:
    ///   - the prefix length is recorded as a `PathPrefixOffset` extension;
    ///   - the state injector runs;
    ///   - the read lock is released;
    ///   - the returned future buffers the body (capped at `max_body_size`) and then
    ///     invokes the handler.
    ///
    ///   If buffering fails, the error response from `collect_body_limited` is returned
    ///   and the handler is not called.
    /// * Path served only under other methods: 405. The fallback is not consulted.
    /// * Anything else: goes to the fallback service, or gets a 404.
    pub fn route(
        self: &Arc<Self>,
        req: http::Request<hyper::body::Incoming>,
    ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
        let inner = self.inner.read();

        let (mut parts, body) = req.into_parts();

        let path: &str = parts.uri.path();
        let lookup_path = match strip_prefix(inner.prefix_str.as_deref(), path) {
            Some(p) => p,
            None => {
                // Outside the mount point: treat like an unmatched route.
                return if let Some(ref fallback) = inner.fallback {
                    fallback(http::Request::from_parts(parts, body))
                } else {
                    not_found_response()
                };
            }
        };

        match resolve(&inner, &parts.method, lookup_path) {
            LookupOutcome::Hit(i) => {
                if let Some(ref prefix_str) = inner.prefix_str {
                    // Byte length of the stripped prefix; presumably lets extractors
                    // skip the mount prefix in the raw URI path (TODO confirm).
                    parts.extensions.insert(PathPrefixOffset(prefix_str.len()));
                }
                if let Some(ref injector) = inner.state_injector {
                    injector(&mut parts.extensions);
                }
                let router = self.clone();
                let max_body = inner.max_body_size;
                // Release the read lock before any `.await` (body collection).
                drop(inner);
                Box::pin(async move {
                    let body_bytes = match collect_body_limited(body, max_body).await {
                        Ok(bytes) => bytes,
                        Err(resp) => return resp,
                    };
                    // Re-take the read lock only to call the handler. `i` remains valid
                    // because `add_route` only ever appends to `routes`. The guard is
                    // dropped at the end of this block, before the handler's future is
                    // awaited.
                    let fut = {
                        let inner = router.inner.read();
                        (inner.routes[i].handler)(parts, body_bytes)
                    };
                    fut.await
                })
            }
            LookupOutcome::MethodNotAllowed => method_not_allowed_response(),
            LookupOutcome::NotFound => {
                if let Some(ref fallback) = inner.fallback {
                    fallback(http::Request::from_parts(parts, body))
                } else {
                    not_found_response()
                }
            }
        }
    }

    /// Dispatches a request whose body has already been buffered.
    ///
    /// Unlike `route`, this never consults the fallback service and does not apply
    /// `max_body_size`. Unmatched paths (including ones outside the mount prefix) get a
    /// 404, and method mismatches get a 405. The handler is called while the read lock
    /// is held; only its returned future outlives the lock.
    pub fn route_with_bytes(
        self: &Arc<Self>,
        mut parts: http::request::Parts,
        body_bytes: bytes::Bytes,
    ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
        let inner = self.inner.read();

        let path: &str = parts.uri.path();
        let lookup_path = match strip_prefix(inner.prefix_str.as_deref(), path) {
            Some(p) => p,
            None => return not_found_response(),
        };

        match resolve(&inner, &parts.method, lookup_path) {
            LookupOutcome::Hit(i) => {
                if let Some(ref prefix_str) = inner.prefix_str {
                    parts.extensions.insert(PathPrefixOffset(prefix_str.len()));
                }
                if let Some(ref injector) = inner.state_injector {
                    injector(&mut parts.extensions);
                }
                let fut = (inner.routes[i].handler)(parts, body_bytes);
                drop(inner);
                fut
            }
            LookupOutcome::MethodNotAllowed => method_not_allowed_response(),
            LookupOutcome::NotFound => not_found_response(),
        }
    }
}
555
556impl Default for Router {
557 fn default() -> Self {
558 Self::new()
559 }
560}
561
562async fn collect_body_limited(
570 body: hyper::body::Incoming,
571 max_bytes: usize,
572) -> Result<bytes::Bytes, http::Response<BoxBody>> {
573 use http_body_util::BodyExt;
574
575 let limited = http_body_util::Limited::new(body, max_bytes);
576 match limited.collect().await {
577 Ok(collected) => Ok(collected.to_bytes()),
578 Err(_) => {
579 let mut res = http::Response::new(body_from_string(format!(
580 "request body too large (max {max_bytes} bytes)"
581 )));
582 *res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
583 Err(res)
584 }
585 }
586}
587
/// `tower_service::Service` adapter around a shared [`Router`].
///
/// Cloning only clones the inner `Arc`, so every clone routes through the same router.
#[derive(Clone)]
pub struct RouterService {
    router: Arc<Router>,
}
600
601impl RouterService {
602 pub fn new(router: Arc<Router>) -> Self {
604 RouterService { router }
605 }
606}
607
608impl tower_service::Service<http::Request<hyper::body::Incoming>> for RouterService {
609 type Response = http::Response<BoxBody>;
610 type Error = std::convert::Infallible;
611 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
612
613 fn poll_ready(
614 &mut self,
615 _cx: &mut std::task::Context<'_>,
616 ) -> std::task::Poll<Result<(), Self::Error>> {
617 std::task::Poll::Ready(Ok(()))
618 }
619
620 fn call(&mut self, req: http::Request<hyper::body::Incoming>) -> Self::Future {
621 let router = self.router.clone();
622 Box::pin(async move {
623 use futures::FutureExt;
624
625 let result = std::panic::AssertUnwindSafe(router.route(req))
626 .catch_unwind()
627 .await;
628
629 match result {
630 Ok(response) => Ok(response),
631 Err(panic_info) => {
632 let message = if let Some(s) = panic_info.downcast_ref::<&str>() {
633 (*s).to_string()
634 } else if let Some(s) = panic_info.downcast_ref::<String>() {
635 s.clone()
636 } else {
637 "unknown panic".to_string()
638 };
639
640 tracing::error!("handler panicked: {message}");
641
642 let mut res =
643 http::Response::new(body_from_string("Internal Server Error".to_string()));
644 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
645 Ok(res)
646 }
647 }
648 })
649 }
650}