Skip to main content

rust_web_server/router/
mod.rs

1#[cfg(test)]
2mod tests;
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::request::Request;
8use crate::response::Response;
9use crate::server::ConnectionInfo;
10
/// Named path-segment values extracted from a matched route pattern.
///
/// Given the pattern `/users/:id/posts/:post_id` matched against
/// `/users/42/posts/7`, `params.get("id")` returns `Some("42")` and
/// `params.get("post_id")` returns `Some("7")`.
///
/// Wildcard segments (`*name`) capture everything after the prefix:
/// `/files/*path` matched against `/files/a/b/c` gives `path = "a/b/c"`.
///
/// The [`Router`] in this module stores each value exactly as it appears in
/// the request path (no percent-decoding is applied).
#[derive(Debug, Clone, Default)]
pub struct PathParams {
    /// Parameter name (without the leading `:` / `*`) -> captured value.
    params: HashMap<String, String>,
}

impl PathParams {
    /// Creates an empty parameter set.
    fn new() -> Self {
        Self::default()
    }

    /// Build a `PathParams` from an existing map. Used by `AsyncAppWithState`.
    pub(crate) fn from_map(params: HashMap<String, String>) -> Self {
        PathParams { params }
    }

    /// Returns the value for the named parameter, or `None` if absent.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }

    /// Records `value` under `key`, replacing any value previously stored
    /// for the same key.
    fn insert(&mut self, key: String, value: String) {
        self.params.insert(key, value);
    }
}
42
/// One component of a parsed route pattern.
#[derive(Clone)]
enum Segment {
    /// Plain text that must equal the corresponding path segment exactly.
    Literal(String),
    /// `:name` — captures exactly one path segment under `name`.
    Param(String),
    /// `*name` — captures all remaining path segments (joined with `/`)
    /// under `name`. Only honored as the final segment of a pattern; in any
    /// other position the route never matches.
    Wildcard(String),
}
49
50type HandlerFn =
51    Arc<dyn Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static>;
52
53#[derive(Clone)]
54struct Route {
55    method: String,
56    segments: Vec<Segment>,
57    handler: HandlerFn,
58}
59
60/// A path-based HTTP router with named parameter extraction.
61///
62/// Register routes with [`Router::get`], [`Router::post`], etc. Each handler
63/// receives the parsed [`PathParams`] alongside the raw [`Request`] and
64/// [`ConnectionInfo`]. Call [`Router::handle`] from inside a [`Controller`]
65/// or an [`Application::execute`] implementation.
66///
67/// # Example
68///
69/// ```rust,no_run
70/// use rust_web_server::router::{Router, PathParams};
71/// use rust_web_server::request::Request;
72/// use rust_web_server::response::{Response, STATUS_CODE_REASON_PHRASE};
73/// use rust_web_server::range::Range;
74/// use rust_web_server::mime_type::MimeType;
75/// use rust_web_server::server::ConnectionInfo;
76/// use rust_web_server::core::New;
77///
78/// let router = Router::new()
79///     .get("/hello", |_req, _params, _conn| {
80///         let mut r = Response::new();
81///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
82///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
83///         r.content_range_list = vec![Range::get_content_range(b"hello".to_vec(), MimeType::TEXT_PLAIN.to_string())];
84///         r
85///     })
86///     .get("/users/:id", |_req, params, _conn| {
87///         let id = params.get("id").unwrap_or("unknown");
88///         let mut r = Response::new();
89///         r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
90///         r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
91///         r.content_range_list = vec![Range::get_content_range(
92///             format!("user {}", id).into_bytes(),
93///             MimeType::TEXT_PLAIN.to_string(),
94///         )];
95///         r
96///     });
97/// ```
98#[derive(Clone)]
99pub struct Router {
100    routes: Vec<Route>,
101}
102
103impl Router {
104    pub fn new() -> Self {
105        Router { routes: Vec::new() }
106    }
107
108    /// Register a `GET` handler for `pattern`.
109    pub fn get<F>(self, pattern: &str, handler: F) -> Self
110    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
111        self.add("GET", pattern, handler)
112    }
113
114    /// Register a `POST` handler for `pattern`.
115    pub fn post<F>(self, pattern: &str, handler: F) -> Self
116    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
117        self.add("POST", pattern, handler)
118    }
119
120    /// Register a `PUT` handler for `pattern`.
121    pub fn put<F>(self, pattern: &str, handler: F) -> Self
122    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
123        self.add("PUT", pattern, handler)
124    }
125
126    /// Register a `PATCH` handler for `pattern`.
127    pub fn patch<F>(self, pattern: &str, handler: F) -> Self
128    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
129        self.add("PATCH", pattern, handler)
130    }
131
132    /// Register a `DELETE` handler for `pattern`.
133    pub fn delete<F>(self, pattern: &str, handler: F) -> Self
134    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
135        self.add("DELETE", pattern, handler)
136    }
137
138    fn add<F>(mut self, method: &str, pattern: &str, handler: F) -> Self
139    where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
140        self.routes.push(Route {
141            method: method.to_string(),
142            segments: Self::parse_pattern(pattern),
143            handler: Arc::new(handler),
144        });
145        self
146    }
147
148    fn parse_pattern(pattern: &str) -> Vec<Segment> {
149        if pattern == "/" {
150            return vec![];
151        }
152        pattern
153            .split('/')
154            .filter(|s| !s.is_empty())
155            .map(|seg| {
156                if let Some(name) = seg.strip_prefix(':') {
157                    Segment::Param(name.to_string())
158                } else if let Some(name) = seg.strip_prefix('*') {
159                    Segment::Wildcard(name.to_string())
160                } else {
161                    Segment::Literal(seg.to_string())
162                }
163            })
164            .collect()
165    }
166
167    /// Try to match `request` against registered routes in registration order.
168    ///
169    /// Returns `Some(response)` on the first match, `None` if no route matches.
170    /// The query string is stripped before matching; only the path is used.
171    pub fn handle(&self, request: &Request, connection: &ConnectionInfo) -> Option<Response> {
172        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
173        let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
174
175        for route in &self.routes {
176            if route.method != request.method {
177                continue;
178            }
179            if let Some(params) = Self::try_match(&route.segments, &path_segs) {
180                return Some((route.handler)(request, &params, connection));
181            }
182        }
183        None
184    }
185
186    fn try_match(pattern: &[Segment], path: &[&str]) -> Option<PathParams> {
187        let mut params = PathParams::new();
188        let mut pi = 0;
189
190        for (si, seg) in pattern.iter().enumerate() {
191            match seg {
192                Segment::Literal(lit) => {
193                    if pi >= path.len() || path[pi] != lit.as_str() {
194                        return None;
195                    }
196                    pi += 1;
197                }
198                Segment::Param(name) => {
199                    if pi >= path.len() {
200                        return None;
201                    }
202                    params.insert(name.clone(), path[pi].to_string());
203                    pi += 1;
204                }
205                Segment::Wildcard(name) => {
206                    if si != pattern.len() - 1 {
207                        return None; // wildcard must be the last segment
208                    }
209                    params.insert(name.clone(), path[pi..].join("/"));
210                    pi = path.len();
211                }
212            }
213        }
214
215        if pi == path.len() { Some(params) } else { None }
216    }
217}