Skip to main content

rust_web_server/router/
mod.rs

1#[cfg(test)]
2mod tests;
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::request::Request;
8use crate::response::Response;
9use crate::server::ConnectionInfo;
10
/// Named path-segment values extracted from a matched route pattern.
///
/// Given the pattern `/users/:id/posts/:post_id` matched against
/// `/users/42/posts/7`, `params.get("id")` returns `Some("42")` and
/// `params.get("post_id")` returns `Some("7")`.
///
/// Wildcard segments (`*name`) capture everything after the prefix:
/// `/files/*path` matched against `/files/a/b/c` gives `path = "a/b/c"`.
///
/// The router in this module stores segment text verbatim; no
/// percent-decoding is applied to the captured values.
#[derive(Debug, Clone)]
pub struct PathParams {
    params: HashMap<String, String>,
}

impl PathParams {
    /// Create an empty parameter set.
    fn new() -> Self {
        PathParams { params: HashMap::new() }
    }

    /// Build a `PathParams` from an existing map. Used by `AsyncAppWithState`.
    pub(crate) fn from_map(params: HashMap<String, String>) -> Self {
        PathParams { params }
    }

    /// Returns the value for the named parameter, or `None` if absent.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }

    /// Record a captured value; a later insert for the same name replaces
    /// the earlier one.
    fn insert(&mut self, key: String, value: String) {
        self.params.insert(key, value);
    }
}
42
/// One component of a parsed route pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// A fixed path component that must match exactly (e.g. `users`).
    Literal(String),
    /// A `:name` component capturing exactly one path segment.
    Param(String),
    /// A `*name` component capturing all remaining segments joined by `/`.
    Wildcard(String),
}
49
50type HandlerFn =
51    Arc<dyn Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static>;
52
53#[derive(Clone)]
54struct Route {
55    method: String,
56    segments: Vec<Segment>,
57    handler: HandlerFn,
58}
59
60/// A path-based HTTP router with named parameter extraction.
61///
62/// Register routes with [`Router::get`], [`Router::post`], etc. Each handler
63/// receives the parsed [`PathParams`] alongside the raw [`Request`] and
64/// [`ConnectionInfo`]. Call [`Router::handle`] from inside a [`Controller`]
65/// or an [`Application::execute`] implementation.
66///
67/// # Example
68///
69/// ```rust,no_run
70/// use rust_web_server::router::{Router, PathParams};
71/// use rust_web_server::request::Request;
72/// use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
73/// use rust_web_server::range::Range;
74/// use rust_web_server::mime_type::MimeType;
75/// use rust_web_server::server::ConnectionInfo;
76/// use rust_web_server::core::New;
77///
78/// let router = Router::new()
79///     .get("/hello", |_req, _params, _conn| {
80///         let mut r = Response::new();
81///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
82///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
83///         r.content_range_list = vec![Range::get_content_range(b"hello".to_vec(), MimeType::TEXT_PLAIN.to_string())];
84///         r
85///     })
86///     .get("/users/:id", |_req, params, _conn| {
87///         let id = params.get("id").unwrap_or("unknown");
88///         let mut r = Response::new();
89///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
90///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
91///         r.content_range_list = vec![Range::get_content_range(
92///             format!("user {}", id).into_bytes(),
93///             MimeType::TEXT_PLAIN.to_string(),
94///         )];
95///         r
96///     });
97/// ```
98/// A registered route entry returned by [`Router::route_entries`].
99#[derive(Clone)]
100pub struct RouteInfo {
101    pub method: String,
102    pub pattern: String,
103}
104
105#[derive(Clone)]
106pub struct Router {
107    routes: Vec<Route>,
108    /// When set, `handle()` only matches if the request's SNI hostname (or
109    /// `Host` header for plain HTTP) equals this value.
110    host: Option<String>,
111}
112
113impl Router {
114    pub fn new() -> Self {
115        Router { routes: Vec::new(), host: None }
116    }
117
118    /// Restrict this router to requests whose SNI hostname (TLS) or `Host`
119    /// header (plain HTTP) matches `host`.  Call before registering routes.
120    pub fn with_host(mut self, host: &str) -> Self {
121        self.host = Some(host.to_string());
122        self
123    }
124
125    /// Register a `GET` handler for `pattern`.
126    pub fn get<F>(self, pattern: &str, handler: F) -> Self
127    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
128        self.add("GET", pattern, handler)
129    }
130
131    /// Register a `POST` handler for `pattern`.
132    pub fn post<F>(self, pattern: &str, handler: F) -> Self
133    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
134        self.add("POST", pattern, handler)
135    }
136
137    /// Register a `PUT` handler for `pattern`.
138    pub fn put<F>(self, pattern: &str, handler: F) -> Self
139    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
140        self.add("PUT", pattern, handler)
141    }
142
143    /// Register a `PATCH` handler for `pattern`.
144    pub fn patch<F>(self, pattern: &str, handler: F) -> Self
145    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
146        self.add("PATCH", pattern, handler)
147    }
148
149    /// Register a `DELETE` handler for `pattern`.
150    pub fn delete<F>(self, pattern: &str, handler: F) -> Self
151    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
152        self.add("DELETE", pattern, handler)
153    }
154
155    fn add<F>(mut self, method: &str, pattern: &str, handler: F) -> Self
156    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
157        self.routes.push(Route {
158            method: method.to_string(),
159            segments: Self::parse_pattern(pattern),
160            handler: Arc::new(handler),
161        });
162        self
163    }
164
165    fn parse_pattern(pattern: &str) -> Vec<Segment> {
166        if pattern == "/" {
167            return vec![];
168        }
169        pattern
170            .split('/')
171            .filter(|s| !s.is_empty())
172            .map(|seg| {
173                if let Some(name) = seg.strip_prefix(':') {
174                    Segment::Param(name.to_string())
175                } else if let Some(name) = seg.strip_prefix('*') {
176                    Segment::Wildcard(name.to_string())
177                } else {
178                    Segment::Literal(seg.to_string())
179                }
180            })
181            .collect()
182    }
183
184    /// Return a snapshot of all registered routes as `(method, pattern)` pairs.
185    ///
186    /// Patterns are reconstructed from parsed segments, so the output exactly
187    /// matches what was passed to `.get()`, `.post()`, etc. at registration time.
188    pub fn route_entries(&self) -> Vec<RouteInfo> {
189        self.routes.iter().map(|r| RouteInfo {
190            method: r.method.clone(),
191            pattern: Self::segments_to_pattern(&r.segments),
192        }).collect()
193    }
194
195    fn segments_to_pattern(segs: &[Segment]) -> String {
196        if segs.is_empty() {
197            return "/".to_string();
198        }
199        let parts: Vec<String> = segs.iter().map(|s| match s {
200            Segment::Literal(l) => l.clone(),
201            Segment::Param(n) => format!(":{}", n),
202            Segment::Wildcard(n) => format!("*{}", n),
203        }).collect();
204        format!("/{}", parts.join("/"))
205    }
206
207    /// Try to match `request` against registered routes in registration order.
208    ///
209    /// Returns `Some(response)` on the first match, `None` if no route matches.
210    /// The query string is stripped before matching; only the path is used.
211    ///
212    /// When `.with_host()` is set, this returns `None` immediately unless the
213    /// request's SNI hostname (TLS) or `Host` header (plain HTTP) matches.
214    pub fn handle(&self, request: &Request, connection: &ConnectionInfo) -> Option<Response> {
215        if let Some(required_host) = &self.host {
216            let actual = connection.sni_hostname.as_deref().or_else(|| {
217                request.headers.iter()
218                    .find(|h| h.name.eq_ignore_ascii_case("host"))
219                    .map(|h| h.value.as_str())
220            });
221            if actual != Some(required_host.as_str()) {
222                return None;
223            }
224        }
225
226        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
227        let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
228
229        for route in &self.routes {
230            if route.method != request.method {
231                continue;
232            }
233            if let Some(params) = Self::try_match(&route.segments, &path_segs) {
234                return Some((route.handler)(request, &params, connection));
235            }
236        }
237        None
238    }
239
240    fn try_match(pattern: &[Segment], path: &[&str]) -> Option<PathParams> {
241        let mut params = PathParams::new();
242        let mut pi = 0;
243
244        for (si, seg) in pattern.iter().enumerate() {
245            match seg {
246                Segment::Literal(lit) => {
247                    if pi >= path.len() || path[pi] != lit.as_str() {
248                        return None;
249                    }
250                    pi += 1;
251                }
252                Segment::Param(name) => {
253                    if pi >= path.len() {
254                        return None;
255                    }
256                    params.insert(name.clone(), path[pi].to_string());
257                    pi += 1;
258                }
259                Segment::Wildcard(name) => {
260                    if si != pattern.len() - 1 {
261                        return None; // wildcard must be the last segment
262                    }
263                    params.insert(name.clone(), path[pi..].join("/"));
264                    pi = path.len();
265                }
266            }
267        }
268
269        if pi == path.len() { Some(params) } else { None }
270    }
271}