Skip to main content

rust_web_server/router/
mod.rs

1pub(crate) mod matcher;
2#[cfg(test)]
3mod tests;
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::request::Request;
9use crate::response::Response;
10use crate::server::ConnectionInfo;
11use matcher::Segment;
12
/// Named path-segment values extracted from a matched route pattern.
///
/// Given the pattern `/users/:id/posts/:post_id` matched against
/// `/users/42/posts/7`, `params.get("id")` returns `Some("42")` and
/// `params.get("post_id")` returns `Some("7")`.
///
/// Wildcard segments (`*name`) capture everything after the prefix:
/// `/files/*path` matched against `/files/a/b/c` gives `path = "a/b/c"`.
///
/// `PathParams::default()` yields an empty set, for which every lookup
/// returns `None`.
#[derive(Debug, Clone, Default)]
pub struct PathParams {
    /// Parameter name (without the leading `:` / `*`) mapped to the captured text.
    params: HashMap<String, String>,
}

impl PathParams {
    /// Build a `PathParams` from an existing map — used to adapt
    /// [`matcher::try_match`]'s output for both `Router` and `AsyncAppWithState`.
    ///
    /// The map is taken by value and stored as-is; no copying or validation
    /// happens here.
    pub(crate) fn from_map(params: HashMap<String, String>) -> Self {
        Self { params }
    }

    /// Returns the value for the named parameter, or `None` if absent.
    ///
    /// The returned slice borrows from `self`, so no allocation takes place.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }
}
38
39type HandlerFn =
40    Arc<dyn Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static>;
41
42#[derive(Clone)]
43struct Route {
44    method: String,
45    segments: Vec<Segment>,
46    handler: HandlerFn,
47}
48
49/// A path-based HTTP router with named parameter extraction.
50///
51/// Register routes with [`Router::get`], [`Router::post`], etc. Each handler
52/// receives the parsed [`PathParams`] alongside the raw [`Request`] and
53/// [`ConnectionInfo`]. Call [`Router::handle`] from inside a [`Controller`]
54/// or an [`Application::execute`] implementation.
55///
56/// # Example
57///
58/// ```rust,no_run
59/// use rust_web_server::router::{Router, PathParams};
60/// use rust_web_server::request::Request;
61/// use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
62/// use rust_web_server::range::Range;
63/// use rust_web_server::mime_type::MimeType;
64/// use rust_web_server::server::ConnectionInfo;
65/// use rust_web_server::core::New;
66///
67/// let router = Router::new()
68///     .get("/hello", |_req, _params, _conn| {
69///         let mut r = Response::new();
70///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
71///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
72///         r.content_range_list = vec![Range::get_content_range(b"hello".to_vec(), MimeType::TEXT_PLAIN.to_string())];
73///         r
74///     })
75///     .get("/users/:id", |_req, params, _conn| {
76///         let id = params.get("id").unwrap_or("unknown");
77///         let mut r = Response::new();
78///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
79///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
80///         r.content_range_list = vec![Range::get_content_range(
81///             format!("user {}", id).into_bytes(),
82///             MimeType::TEXT_PLAIN.to_string(),
83///         )];
84///         r
85///     });
86/// ```
87/// A registered route entry returned by [`Router::route_entries`].
88#[derive(Clone)]
89pub struct RouteInfo {
90    pub method: String,
91    pub pattern: String,
92}
93
94#[derive(Clone)]
95pub struct Router {
96    routes: Vec<Route>,
97    /// When set, `handle()` only matches if the request's SNI hostname (or
98    /// `Host` header for plain HTTP) equals this value.
99    host: Option<String>,
100}
101
102impl Router {
103    pub fn new() -> Self {
104        Router { routes: Vec::new(), host: None }
105    }
106
107    /// Restrict this router to requests whose SNI hostname (TLS) or `Host`
108    /// header (plain HTTP) matches `host`.  Call before registering routes.
109    pub fn with_host(mut self, host: &str) -> Self {
110        self.host = Some(host.to_string());
111        self
112    }
113
114    /// Register a `GET` handler for `pattern`.
115    pub fn get<F>(self, pattern: &str, handler: F) -> Self
116    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
117        self.add("GET", pattern, handler)
118    }
119
120    /// Register a `POST` handler for `pattern`.
121    pub fn post<F>(self, pattern: &str, handler: F) -> Self
122    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
123        self.add("POST", pattern, handler)
124    }
125
126    /// Register a `PUT` handler for `pattern`.
127    pub fn put<F>(self, pattern: &str, handler: F) -> Self
128    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
129        self.add("PUT", pattern, handler)
130    }
131
132    /// Register a `PATCH` handler for `pattern`.
133    pub fn patch<F>(self, pattern: &str, handler: F) -> Self
134    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
135        self.add("PATCH", pattern, handler)
136    }
137
138    /// Register a `DELETE` handler for `pattern`.
139    pub fn delete<F>(self, pattern: &str, handler: F) -> Self
140    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
141        self.add("DELETE", pattern, handler)
142    }
143
144    fn add<F>(mut self, method: &str, pattern: &str, handler: F) -> Self
145    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
146        self.routes.push(Route {
147            method: method.to_string(),
148            segments: matcher::parse_pattern(pattern),
149            handler: Arc::new(handler),
150        });
151        self
152    }
153
154    /// Return a snapshot of all registered routes as `(method, pattern)` pairs.
155    ///
156    /// Patterns are reconstructed from parsed segments, so the output exactly
157    /// matches what was passed to `.get()`, `.post()`, etc. at registration time.
158    pub fn route_entries(&self) -> Vec<RouteInfo> {
159        self.routes.iter().map(|r| RouteInfo {
160            method: r.method.clone(),
161            pattern: matcher::segments_to_pattern(&r.segments),
162        }).collect()
163    }
164
165    /// Try to match `request` against registered routes in registration order.
166    ///
167    /// Returns `Some(response)` on the first match, `None` if no route matches.
168    /// The query string is stripped before matching; only the path is used.
169    ///
170    /// When `.with_host()` is set, this returns `None` immediately unless the
171    /// request's SNI hostname (TLS) or `Host` header (plain HTTP) matches.
172    pub fn handle(&self, request: &Request, connection: &ConnectionInfo) -> Option<Response> {
173        if let Some(required_host) = &self.host {
174            let actual = connection.sni_hostname.as_deref().or_else(|| {
175                request.headers.iter()
176                    .find(|h| h.name.eq_ignore_ascii_case("host"))
177                    .map(|h| h.value.as_str())
178            });
179            if actual != Some(required_host.as_str()) {
180                return None;
181            }
182        }
183
184        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
185        let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
186
187        for route in &self.routes {
188            if route.method != request.method {
189                continue;
190            }
191            if let Some(params) = matcher::try_match(&route.segments, &path_segs) {
192                let params = PathParams::from_map(params);
193                return Some((route.handler)(request, &params, connection));
194            }
195        }
196        None
197    }
198}