Skip to main content

rust_web_server/router/
mod.rs

1pub(crate) mod matcher;
2#[cfg(test)]
3mod tests;
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::request::Request;
9use crate::response::Response;
10use crate::server::ConnectionInfo;
11use matcher::Segment;
12
/// Named path-segment values extracted from a matched route pattern.
///
/// Given the pattern `/users/:id/posts/:post_id` matched against
/// `/users/42/posts/7`, `params.get("id")` returns `Some("42")` and
/// `params.get("post_id")` returns `Some("7")`.
///
/// Wildcard segments (`*name`) capture everything after the prefix:
/// `/files/*path` matched against `/files/a/b/c` gives `path = "a/b/c"`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PathParams {
    params: HashMap<String, String>,
}

impl PathParams {
    /// Build a `PathParams` from an existing map — used to adapt
    /// [`matcher::try_match`]'s output for both `Router` and `AsyncAppWithState`.
    pub(crate) fn from_map(params: HashMap<String, String>) -> Self {
        PathParams { params }
    }

    /// Returns the value for the named parameter, or `None` if absent.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }

    /// Returns `true` if a parameter called `name` was captured.
    pub fn contains(&self, name: &str) -> bool {
        self.params.contains_key(name)
    }

    /// Number of captured parameters.
    pub fn len(&self) -> usize {
        self.params.len()
    }

    /// Returns `true` when no parameters were captured (e.g. a fully static
    /// pattern such as `/health`).
    pub fn is_empty(&self) -> bool {
        self.params.is_empty()
    }

    /// Iterates over `(name, value)` pairs.
    ///
    /// Iteration order is unspecified (the backing store is a `HashMap`).
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> + '_ {
        self.params.iter().map(|(k, v)| (k.as_str(), v.as_str()))
    }
}
37
38type HandlerFn =
39    Arc<dyn Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static>;
40
41#[derive(Clone)]
42struct Route {
43    method: String,
44    segments: Vec<Segment>,
45    handler: HandlerFn,
46}
47
48/// A path-based HTTP router with named parameter extraction.
49///
50/// Register routes with [`Router::get`], [`Router::post`], etc. Each handler
51/// receives the parsed [`PathParams`] alongside the raw [`Request`] and
52/// [`ConnectionInfo`]. Call [`Router::handle`] from inside a [`Controller`]
53/// or an [`Application::execute`] implementation.
54///
55/// # Example
56///
57/// ```rust,no_run
58/// use rust_web_server::router::{Router, PathParams};
59/// use rust_web_server::request::Request;
60/// use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
61/// use rust_web_server::range::Range;
62/// use rust_web_server::mime_type::MimeType;
63/// use rust_web_server::server::ConnectionInfo;
64/// use rust_web_server::core::New;
65///
66/// let router = Router::new()
67///     .get("/hello", |_req, _params, _conn| {
68///         let mut r = Response::new();
69///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
70///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
71///         r.content_range_list = vec![Range::get_content_range(b"hello".to_vec(), MimeType::TEXT_PLAIN.to_string())];
72///         r
73///     })
74///     .get("/users/:id", |_req, params, _conn| {
75///         let id = params.get("id").unwrap_or("unknown");
76///         let mut r = Response::new();
77///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
78///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
79///         r.content_range_list = vec![Range::get_content_range(
80///             format!("user {}", id).into_bytes(),
81///             MimeType::TEXT_PLAIN.to_string(),
82///         )];
83///         r
84///     });
85/// ```
86/// A registered route entry returned by [`Router::route_entries`].
87#[derive(Clone)]
88pub struct RouteInfo {
89    pub method: String,
90    pub pattern: String,
91}
92
93#[derive(Clone)]
94pub struct Router {
95    routes: Vec<Route>,
96    /// When set, `handle()` only matches if the request's SNI hostname (or
97    /// `Host` header for plain HTTP) equals this value.
98    host: Option<String>,
99}
100
101impl Router {
102    pub fn new() -> Self {
103        Router { routes: Vec::new(), host: None }
104    }
105
106    /// Restrict this router to requests whose SNI hostname (TLS) or `Host`
107    /// header (plain HTTP) matches `host`.  Call before registering routes.
108    pub fn with_host(mut self, host: &str) -> Self {
109        self.host = Some(host.to_string());
110        self
111    }
112
113    /// Register a `GET` handler for `pattern`.
114    pub fn get<F>(self, pattern: &str, handler: F) -> Self
115    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
116        self.add("GET", pattern, handler)
117    }
118
119    /// Register a `POST` handler for `pattern`.
120    pub fn post<F>(self, pattern: &str, handler: F) -> Self
121    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
122        self.add("POST", pattern, handler)
123    }
124
125    /// Register a `PUT` handler for `pattern`.
126    pub fn put<F>(self, pattern: &str, handler: F) -> Self
127    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
128        self.add("PUT", pattern, handler)
129    }
130
131    /// Register a `PATCH` handler for `pattern`.
132    pub fn patch<F>(self, pattern: &str, handler: F) -> Self
133    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
134        self.add("PATCH", pattern, handler)
135    }
136
137    /// Register a `DELETE` handler for `pattern`.
138    pub fn delete<F>(self, pattern: &str, handler: F) -> Self
139    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
140        self.add("DELETE", pattern, handler)
141    }
142
143    fn add<F>(mut self, method: &str, pattern: &str, handler: F) -> Self
144    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
145        self.routes.push(Route {
146            method: method.to_string(),
147            segments: matcher::parse_pattern(pattern),
148            handler: Arc::new(handler),
149        });
150        self
151    }
152
153    /// Return a snapshot of all registered routes as `(method, pattern)` pairs.
154    ///
155    /// Patterns are reconstructed from parsed segments, so the output exactly
156    /// matches what was passed to `.get()`, `.post()`, etc. at registration time.
157    pub fn route_entries(&self) -> Vec<RouteInfo> {
158        self.routes.iter().map(|r| RouteInfo {
159            method: r.method.clone(),
160            pattern: matcher::segments_to_pattern(&r.segments),
161        }).collect()
162    }
163
164    /// Try to match `request` against registered routes in registration order.
165    ///
166    /// Returns `Some(response)` on the first match, `None` if no route matches.
167    /// The query string is stripped before matching; only the path is used.
168    ///
169    /// When `.with_host()` is set, this returns `None` immediately unless the
170    /// request's SNI hostname (TLS) or `Host` header (plain HTTP) matches.
171    pub fn handle(&self, request: &Request, connection: &ConnectionInfo) -> Option<Response> {
172        if let Some(required_host) = &self.host {
173            let actual = connection.sni_hostname.as_deref().or_else(|| {
174                request.headers.iter()
175                    .find(|h| h.name.eq_ignore_ascii_case("host"))
176                    .map(|h| h.value.as_str())
177            });
178            if actual != Some(required_host.as_str()) {
179                return None;
180            }
181        }
182
183        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
184        let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
185
186        for route in &self.routes {
187            if route.method != request.method {
188                continue;
189            }
190            if let Some(params) = matcher::try_match(&route.segments, &path_segs) {
191                let params = PathParams::from_map(params);
192                return Some((route.handler)(request, &params, connection));
193            }
194        }
195        None
196    }
197}