Skip to main content

typeway_server/
router.rs

//! The runtime [`Router`] that matches incoming requests to handlers.
//!
//! Routes are dispatched via a per-method radix trie (`matchit`). Patterns
//! that conflict with already-registered routes fall back to a linear scan
//! within their method bucket, so registration never silently drops a route.
//! The `match_fn` produced by `typeway_path!` then runs on the candidate
//! route as a type-validation step (e.g. confirming `{}` parses as `u32`).
8
9use std::collections::HashSet;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use http::StatusCode;
15
16use crate::body::{body_from_bytes, body_from_string, BoxBody};
17use crate::extract::PathPrefixOffset;
18use crate::handler::BoxedHandler;
19
20pub(crate) type MatchFn = Box<dyn Fn(&[&str]) -> bool + Send + Sync>;
21type StateInjector = Arc<dyn Fn(&mut http::Extensions) + Send + Sync>;
22pub(crate) type FallbackService = Arc<
23    dyn Fn(
24            http::Request<hyper::body::Incoming>,
25        ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>>
26        + Send
27        + Sync,
28>;
29
/// Request-body limit a freshly created router starts with: 2 MiB
/// (2 097 152 bytes).
pub const DEFAULT_MAX_BODY_SIZE: usize = 2 << 20;
32
33/// A runtime HTTP router.
34///
35/// Routes are dispatched through a per-method radix trie (`matchit`). The
36/// trie matches on the request path; a per-route `match_fn` then validates
37/// typed captures (e.g. that `{}` parses as `u32`). Patterns that conflict
38/// in the trie fall back to a linear scan within the same method bucket,
39/// so registration never silently drops a route.
40pub struct Router {
41    /// All mutable state behind RwLock so the router can be configured
42    /// after Arc is shared (e.g., when LayeredServer wraps it).
43    inner: parking_lot::RwLock<RouterInner>,
44}
45
46struct RouterInner {
47    routes: Vec<RouteEntry>,
48    method_index: MethodIndex,
49    /// Unified path-shape trie across all methods. Used to short-circuit
50    /// 404 detection without walking each method bucket.
51    any_method_trie: matchit::Router<()>,
52    /// Patterns already mirrored into `any_method_trie`; needed because
53    /// matchit rejects duplicate inserts (same pattern from another method)
54    /// the same way it rejects real structural conflicts.
55    any_method_seen: HashSet<String>,
56    /// Set when a pattern conflicted with `any_method_trie` (different shape
57    /// collapsing onto an existing entry). When true, miss detection falls
58    /// back to the per-bucket walk so we don't false-negative.
59    any_method_has_fallback: bool,
60    /// Bitmap of which method buckets have routes registered. When the
61    /// request's method bit is the *only* one set, a miss in that bucket
62    /// is a definite 404 with no need to consult `any_method_trie` (saves
63    /// a trie walk on the miss path for single-method APIs).
64    methods_present: u8,
65    state_injector: Option<StateInjector>,
66    fallback: Option<FallbackService>,
67    max_body_size: usize,
68    prefix: Option<Vec<String>>,
69    /// Cached `"/seg1/seg2"` form of `prefix`, for byte-level path stripping.
70    prefix_str: Option<String>,
71}
72
73struct RouteEntry {
74    // method/pattern are only read by `find_handler_by_pattern` (gRPC feature).
75    #[cfg_attr(not(feature = "grpc"), allow(dead_code))]
76    method: http::Method,
77    #[cfg_attr(not(feature = "grpc"), allow(dead_code))]
78    pattern: String,
79    match_fn: MatchFn,
80    handler: BoxedHandler,
81}
82
83/// Per-method radix tries plus a linear fallback for patterns matchit rejects.
84#[derive(Default)]
85struct MethodIndex {
86    get: MethodBucket,
87    post: MethodBucket,
88    put: MethodBucket,
89    delete: MethodBucket,
90    patch: MethodBucket,
91    head: MethodBucket,
92    options: MethodBucket,
93    other: MethodBucket,
94}
95
96#[derive(Default)]
97struct MethodBucket {
98    /// Radix trie of patterns -> route index. Most routes live here.
99    trie: matchit::Router<usize>,
100    /// Routes whose patterns conflicted with the trie (linear fallback).
101    fallback: Vec<usize>,
102}
103
104impl MethodIndex {
105    fn bucket(&self, method: &http::Method) -> &MethodBucket {
106        match *method {
107            http::Method::GET => &self.get,
108            http::Method::POST => &self.post,
109            http::Method::PUT => &self.put,
110            http::Method::DELETE => &self.delete,
111            http::Method::PATCH => &self.patch,
112            http::Method::HEAD => &self.head,
113            http::Method::OPTIONS => &self.options,
114            _ => &self.other,
115        }
116    }
117
118    fn bucket_mut(&mut self, method: &http::Method) -> &mut MethodBucket {
119        match *method {
120            http::Method::GET => &mut self.get,
121            http::Method::POST => &mut self.post,
122            http::Method::PUT => &mut self.put,
123            http::Method::DELETE => &mut self.delete,
124            http::Method::PATCH => &mut self.patch,
125            http::Method::HEAD => &mut self.head,
126            http::Method::OPTIONS => &mut self.options,
127            _ => &mut self.other,
128        }
129    }
130
131    fn all_buckets(&self) -> [&MethodBucket; 8] {
132        [
133            &self.get,
134            &self.post,
135            &self.put,
136            &self.delete,
137            &self.patch,
138            &self.head,
139            &self.options,
140            &self.other,
141        ]
142    }
143}
144
145/// Bit position for `RouterInner::methods_present` for a given HTTP method.
146fn method_bit(method: &http::Method) -> u8 {
147    match *method {
148        http::Method::GET => 1 << 0,
149        http::Method::POST => 1 << 1,
150        http::Method::PUT => 1 << 2,
151        http::Method::DELETE => 1 << 3,
152        http::Method::PATCH => 1 << 4,
153        http::Method::HEAD => 1 << 5,
154        http::Method::OPTIONS => 1 << 6,
155        _ => 1 << 7,
156    }
157}
158
/// Convert a typeway pattern (`/users/{}/posts/{*rest}`) into matchit's
/// syntax (`/users/{p0}/posts/{*rest}`).
///
/// matchit requires every capture to be named, but typeway emits an empty
/// `{}` for each typed capture. Empty captures are therefore renamed
/// `{p0}`, `{p1}`, ... in order of appearance. Named captures, such as the
/// `{*rest}` catch-all, are copied through unchanged. So is an unterminated
/// `{` together with everything after it.
///
/// Literal text is copied as whole `&str` slices, not byte-by-byte, so
/// non-ASCII characters in a pattern survive the conversion intact.
fn to_matchit_pattern(pat: &str) -> String {
    let mut out = String::with_capacity(pat.len() + 8);
    let mut counter: u32 = 0;
    let mut rest = pat;
    while let Some(open) = rest.find('{') {
        // `{` and `}` are ASCII, so these byte offsets are char boundaries.
        out.push_str(&rest[..open]);
        let after_open = &rest[open + 1..];
        match after_open.find('}') {
            Some(close) => {
                let name = &after_open[..close];
                if name.is_empty() {
                    out.push_str(&format!("{{p{counter}}}"));
                    counter += 1;
                } else {
                    out.push('{');
                    out.push_str(name);
                    out.push('}');
                }
                rest = &after_open[close + 1..];
            }
            None => {
                // Unterminated `{`: keep the remainder verbatim.
                out.push_str(&rest[open..]);
                return out;
            }
        }
    }
    out.push_str(rest);
    out
}
196
/// Remove the configured prefix from a request path.
///
/// Returns the remaining path, which always starts with `/`, or `None` when
/// `path` is not under `prefix_str`. Without a prefix the path is returned
/// as-is, except that an empty path becomes `/`.
fn strip_prefix<'a>(prefix_str: Option<&str>, path: &'a str) -> Option<&'a str> {
    let prefix = match prefix_str {
        None if path.is_empty() => return Some("/"),
        None => return Some(path),
        Some(prefix) => prefix,
    };
    let remainder = path.strip_prefix(prefix)?;
    if remainder.is_empty() {
        // The path is exactly the prefix: treat it as the root.
        Some("/")
    } else if remainder.starts_with('/') {
        Some(remainder)
    } else {
        // e.g. prefix "/api" vs path "/apix": not a segment boundary.
        None
    }
}
210
211/// Lazy holder for the path segments needed by `match_fn`.
212///
213/// Splitting `/a/b/c` into a `SmallVec<[&str; 8]>` is cheap but not free, and
214/// the most common 404 (path doesn't exist in any method) never needs the
215/// segments at all — the trie miss alone is conclusive. Building them on first
216/// `get()` saves the allocation in that case.
217struct LazySegments<'a> {
218    path: &'a str,
219    cache: Option<smallvec::SmallVec<[&'a str; 8]>>,
220}
221
222impl<'a> LazySegments<'a> {
223    fn new(path: &'a str) -> Self {
224        Self { path, cache: None }
225    }
226
227    fn get(&mut self) -> &[&'a str] {
228        self.cache
229            .get_or_insert_with(|| self.path.split('/').filter(|s| !s.is_empty()).collect())
230            .as_slice()
231    }
232}
233
234/// Find the matching route index for a given path within a method bucket.
235/// Tries the trie first, then falls back to a linear scan over conflicts.
236///
237/// Segments are produced lazily: on a clean trie miss with an empty fallback
238/// bucket, we never allocate them at all (the common 404 path).
239fn lookup_in_bucket(
240    bucket: &MethodBucket,
241    routes: &[RouteEntry],
242    lookup_path: &str,
243    segments: &mut LazySegments,
244) -> Option<usize> {
245    if let Ok(m) = bucket.trie.at(lookup_path) {
246        let idx = *m.value;
247        if (routes[idx].match_fn)(segments.get()) {
248            return Some(idx);
249        }
250    }
251    if bucket.fallback.is_empty() {
252        return None;
253    }
254    let segs = segments.get();
255    bucket
256        .fallback
257        .iter()
258        .copied()
259        .find(|&i| (routes[i].match_fn)(segs))
260}
261
262fn make_status_response(status: StatusCode, body: &'static [u8]) -> http::Response<BoxBody> {
263    let mut res = http::Response::new(body_from_bytes(bytes::Bytes::from_static(body)));
264    *res.status_mut() = status;
265    res
266}
267
268fn not_found_response() -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
269    Box::pin(std::future::ready(make_status_response(
270        StatusCode::NOT_FOUND,
271        b"Not Found",
272    )))
273}
274
275fn method_not_allowed_response() -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
276    Box::pin(std::future::ready(make_status_response(
277        StatusCode::METHOD_NOT_ALLOWED,
278        b"Method Not Allowed",
279    )))
280}
281
/// Result of matching a request against the routing table.
enum LookupOutcome {
    /// No route of any method matches the path.
    NotFound,
    /// The path matches some route, but only under a different method.
    MethodNotAllowed,
    /// Index into the router's route list of the route accepting the request.
    Hit(usize),
}
288
289/// Run the full lookup (per-method trie, fallback, then 404-vs-405 disambiguation)
290/// for `lookup_path` under `method`. Builds the segments slice lazily, so a clean
291/// 404 with no fallback never allocates one.
292fn resolve(inner: &RouterInner, method: &http::Method, lookup_path: &str) -> LookupOutcome {
293    let mut segments = LazySegments::new(lookup_path);
294
295    let bucket = inner.method_index.bucket(method);
296    if let Some(i) = lookup_in_bucket(bucket, &inner.routes, lookup_path, &mut segments) {
297        return LookupOutcome::Hit(i);
298    }
299
300    // Fast path: if our method bucket is the only populated one, no other
301    // method can possibly accept this path, so a bucket miss is a guaranteed
302    // 404. Skips the unified-trie lookup AND the per-bucket walk.
303    let bit = method_bit(method);
304    if inner.methods_present == bit {
305        return LookupOutcome::NotFound;
306    }
307
308    // General case: if no pattern in any method matches the path *shape*, it's
309    // definitely a 404. The unified `any_method_trie` lets us check this in one
310    // trie lookup. If a structural conflict was recorded at registration, the
311    // unified trie may have a false negative, so we skip the optimization.
312    let any_path_shape =
313        inner.any_method_has_fallback || inner.any_method_trie.at(lookup_path).is_ok();
314    let path_matched = any_path_shape
315        && inner
316            .method_index
317            .all_buckets()
318            .iter()
319            .any(|b| lookup_in_bucket(b, &inner.routes, lookup_path, &mut segments).is_some());
320
321    if path_matched {
322        LookupOutcome::MethodNotAllowed
323    } else {
324        LookupOutcome::NotFound
325    }
326}
327
328impl Router {
329    /// Create an empty router.
330    pub fn new() -> Self {
331        Router {
332            inner: parking_lot::RwLock::new(RouterInner {
333                routes: Vec::new(),
334                method_index: MethodIndex::default(),
335                any_method_trie: matchit::Router::new(),
336                any_method_seen: HashSet::new(),
337                any_method_has_fallback: false,
338                methods_present: 0,
339                state_injector: None,
340                fallback: None,
341                max_body_size: DEFAULT_MAX_BODY_SIZE,
342                prefix: None,
343                prefix_str: None,
344            }),
345        }
346    }
347
348    /// Set a path prefix for all routes.
349    pub(crate) fn set_prefix(&self, prefix: &str) {
350        let segments: Vec<String> = prefix
351            .split('/')
352            .filter(|s| !s.is_empty())
353            .map(|s| s.to_string())
354            .collect();
355        if segments.is_empty() {
356            return;
357        }
358        let mut joined = String::with_capacity(segments.iter().map(|s| s.len() + 1).sum());
359        for seg in &segments {
360            joined.push('/');
361            joined.push_str(seg);
362        }
363        let mut inner = self.inner.write();
364        inner.prefix = Some(segments);
365        inner.prefix_str = Some(joined);
366    }
367
368    /// Set the maximum request body size in bytes.
369    pub(crate) fn set_max_body_size(&self, max: usize) {
370        self.inner.write().max_body_size = max;
371    }
372
373    /// Register a route with a method, pattern, match function, and handler.
374    pub(crate) fn add_route(
375        &self,
376        method: http::Method,
377        pattern: String,
378        match_fn: MatchFn,
379        handler: BoxedHandler,
380    ) {
381        let mut inner = self.inner.write();
382        let idx = inner.routes.len();
383
384        inner.methods_present |= method_bit(&method);
385
386        let matchit_pattern = to_matchit_pattern(&pattern);
387        let bucket = inner.method_index.bucket_mut(&method);
388        if bucket.trie.insert(matchit_pattern.clone(), idx).is_err() {
389            // Pattern conflicts with an already-registered one (e.g. two routes
390            // collapse to the same matchit shape). Keep it in the linear
391            // fallback so registration never silently drops a route.
392            bucket.fallback.push(idx);
393        }
394
395        // Mirror the pattern into the unified path-existence trie. We only
396        // attempt the insert if we haven't seen this exact pattern before,
397        // since matchit can't distinguish "duplicate" from "conflict".
398        if inner.any_method_seen.insert(matchit_pattern.clone())
399            && inner.any_method_trie.insert(matchit_pattern, ()).is_err()
400        {
401            // Genuine structural conflict — fall back to the per-bucket
402            // walk for miss detection so we don't false-negative.
403            inner.any_method_has_fallback = true;
404        }
405
406        inner.routes.push(RouteEntry {
407            method,
408            pattern,
409            match_fn,
410            handler,
411        });
412    }
413
414    /// Set the state injector function.
415    pub(crate) fn set_state_injector(
416        &self,
417        injector: Arc<dyn Fn(&mut http::Extensions) + Send + Sync>,
418    ) {
419        self.inner.write().state_injector = Some(injector);
420    }
421
422    /// Set a fallback service for requests that don't match any typeway route.
423    pub(crate) fn set_fallback(&self, fallback: FallbackService) {
424        self.inner.write().fallback = Some(fallback);
425    }
426
427    /// Look up a handler by HTTP method and route pattern string.
428    ///
429    /// Returns a clone of the handler if found. Since `BoxedHandler` is
430    /// `Arc`-wrapped, this clone is cheap (reference count increment).
431    ///
432    /// Used by the native gRPC server to build its own dispatch table
433    /// from the already-registered REST handlers.
434    #[cfg(feature = "grpc")]
435    pub(crate) fn find_handler_by_pattern(
436        &self,
437        method: &http::Method,
438        pattern: &str,
439    ) -> Option<BoxedHandler> {
440        let inner = self.inner.read();
441        inner
442            .routes
443            .iter()
444            .find(|e| e.method == *method && e.pattern == pattern)
445            .map(|e| e.handler.clone())
446    }
447
448    /// Get a clone of the state injector, if one is set.
449    #[cfg(feature = "grpc")]
450    pub(crate) fn state_injector(&self) -> Option<StateInjector> {
451        self.inner.read().state_injector.clone()
452    }
453
454    /// Route a request to the appropriate handler.
455    ///
456    /// Must be called on `Arc<Router>` so the router outlives the returned future.
457    /// The body is collected into bytes before handler dispatch, enabling
458    /// both Hyper and Axum body types to be handled uniformly.
459    pub fn route(
460        self: &Arc<Self>,
461        req: http::Request<hyper::body::Incoming>,
462    ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
463        let inner = self.inner.read();
464
465        // Consume req into parts up front so we can borrow path from parts.uri
466        // without conflicting with the later move into the handler. If we hit
467        // a fallback path we reassemble the request below.
468        let (mut parts, body) = req.into_parts();
469
470        let path: &str = parts.uri.path();
471        let lookup_path = match strip_prefix(inner.prefix_str.as_deref(), path) {
472            Some(p) => p,
473            None => {
474                // Path doesn't fall under the configured prefix.
475                return if let Some(ref fallback) = inner.fallback {
476                    fallback(http::Request::from_parts(parts, body))
477                } else {
478                    not_found_response()
479                };
480            }
481        };
482
483        match resolve(&inner, &parts.method, lookup_path) {
484            LookupOutcome::Hit(i) => {
485                // Tell `Path<T>` how many bytes of the URI path are prefix.
486                // Skipped when there's no prefix to keep the common path allocation-free.
487                if let Some(ref prefix_str) = inner.prefix_str {
488                    parts.extensions.insert(PathPrefixOffset(prefix_str.len()));
489                }
490                if let Some(ref injector) = inner.state_injector {
491                    injector(&mut parts.extensions);
492                }
493                let router = self.clone();
494                let max_body = inner.max_body_size;
495                drop(inner);
496                Box::pin(async move {
497                    let body_bytes = match collect_body_limited(body, max_body).await {
498                        Ok(bytes) => bytes,
499                        Err(resp) => return resp,
500                    };
501                    let fut = {
502                        let inner = router.inner.read();
503                        (inner.routes[i].handler)(parts, body_bytes)
504                    };
505                    fut.await
506                })
507            }
508            LookupOutcome::MethodNotAllowed => method_not_allowed_response(),
509            LookupOutcome::NotFound => {
510                if let Some(ref fallback) = inner.fallback {
511                    fallback(http::Request::from_parts(parts, body))
512                } else {
513                    not_found_response()
514                }
515            }
516        }
517    }
518
519    /// Route a request whose body has already been collected into [`bytes::Bytes`].
520    ///
521    /// Used by adapters that have a different body type from `hyper::body::Incoming`
522    /// (e.g. the Axum interop layer), and by anything that has already buffered
523    /// a body for an unrelated reason. Bypasses the `max_body_size` check;
524    /// the caller is responsible for any size limiting.
525    pub fn route_with_bytes(
526        self: &Arc<Self>,
527        mut parts: http::request::Parts,
528        body_bytes: bytes::Bytes,
529    ) -> Pin<Box<dyn Future<Output = http::Response<BoxBody>> + Send>> {
530        let inner = self.inner.read();
531
532        let path: &str = parts.uri.path();
533        let lookup_path = match strip_prefix(inner.prefix_str.as_deref(), path) {
534            Some(p) => p,
535            None => return not_found_response(),
536        };
537
538        match resolve(&inner, &parts.method, lookup_path) {
539            LookupOutcome::Hit(i) => {
540                if let Some(ref prefix_str) = inner.prefix_str {
541                    parts.extensions.insert(PathPrefixOffset(prefix_str.len()));
542                }
543                if let Some(ref injector) = inner.state_injector {
544                    injector(&mut parts.extensions);
545                }
546                let fut = (inner.routes[i].handler)(parts, body_bytes);
547                drop(inner);
548                fut
549            }
550            LookupOutcome::MethodNotAllowed => method_not_allowed_response(),
551            LookupOutcome::NotFound => not_found_response(),
552        }
553    }
554}
555
556impl Default for Router {
557    fn default() -> Self {
558        Self::new()
559    }
560}
561
562// ---------------------------------------------------------------------------
563// Body collection with size limit
564// ---------------------------------------------------------------------------
565
566/// Collect a hyper body into bytes, enforcing a size limit.
567///
568/// Returns 413 Payload Too Large if the body exceeds `max_bytes`.
569async fn collect_body_limited(
570    body: hyper::body::Incoming,
571    max_bytes: usize,
572) -> Result<bytes::Bytes, http::Response<BoxBody>> {
573    use http_body_util::BodyExt;
574
575    let limited = http_body_util::Limited::new(body, max_bytes);
576    match limited.collect().await {
577        Ok(collected) => Ok(collected.to_bytes()),
578        Err(_) => {
579            let mut res = http::Response::new(body_from_string(format!(
580                "request body too large (max {max_bytes} bytes)"
581            )));
582            *res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
583            Err(res)
584        }
585    }
586}
587
588// ---------------------------------------------------------------------------
589// Tower Service implementation
590// ---------------------------------------------------------------------------
591
592/// A [`tower::Service`] wrapper around a shared [`Router`].
593///
594/// This enables applying Tower middleware layers (tracing, CORS, compression,
595/// timeouts, etc.) to the typeway router.
596#[derive(Clone)]
597pub struct RouterService {
598    router: Arc<Router>,
599}
600
601impl RouterService {
602    /// Wrap a router in a Tower service.
603    pub fn new(router: Arc<Router>) -> Self {
604        RouterService { router }
605    }
606}
607
608impl tower_service::Service<http::Request<hyper::body::Incoming>> for RouterService {
609    type Response = http::Response<BoxBody>;
610    type Error = std::convert::Infallible;
611    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
612
613    fn poll_ready(
614        &mut self,
615        _cx: &mut std::task::Context<'_>,
616    ) -> std::task::Poll<Result<(), Self::Error>> {
617        std::task::Poll::Ready(Ok(()))
618    }
619
620    fn call(&mut self, req: http::Request<hyper::body::Incoming>) -> Self::Future {
621        let router = self.router.clone();
622        Box::pin(async move {
623            use futures::FutureExt;
624
625            let result = std::panic::AssertUnwindSafe(router.route(req))
626                .catch_unwind()
627                .await;
628
629            match result {
630                Ok(response) => Ok(response),
631                Err(panic_info) => {
632                    let message = if let Some(s) = panic_info.downcast_ref::<&str>() {
633                        (*s).to_string()
634                    } else if let Some(s) = panic_info.downcast_ref::<String>() {
635                        s.clone()
636                    } else {
637                        "unknown panic".to_string()
638                    };
639
640                    tracing::error!("handler panicked: {message}");
641
642                    let mut res =
643                        http::Response::new(body_from_string("Internal Server Error".to_string()));
644                    *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
645                    Ok(res)
646                }
647            }
648        })
649    }
650}