tork_core/router/
matcher.rs1use std::collections::HashMap;
9
10use http::Method;
11
12use crate::error::{Error, Result};
13use crate::extract::PathParams;
14use crate::router::Route;
15
16pub enum Match<'a> {
18 Found {
20 route: &'a Route,
22 params: PathParams,
24 },
25 MethodNotAllowed,
27 NotFound,
29}
30
31pub struct Matcher {
33 by_method: HashMap<Method, matchit::Router<usize>>,
34 all_paths: matchit::Router<()>,
35 routes: Vec<Route>,
36}
37
38impl Matcher {
39 pub fn build(routes: Vec<Route>) -> Result<Self> {
46 let mut by_method: HashMap<Method, matchit::Router<usize>> = HashMap::new();
47 let mut all_paths: matchit::Router<()> = matchit::Router::new();
48
49 for (index, route) in routes.iter().enumerate() {
50 let method_router = by_method.entry(route.method().clone()).or_default();
51
52 method_router.insert(route.path(), index).map_err(|error| {
53 Error::internal(format!(
54 "failed to register route {} {}: {error}",
55 route.method(),
56 route.path()
57 ))
58 })?;
59
60 let _ = all_paths.insert(route.path(), ());
63 }
64
65 Ok(Self {
66 by_method,
67 all_paths,
68 routes,
69 })
70 }
71
72 pub fn find(&self, method: &Method, path: &str) -> Match<'_> {
74 if path.contains('\0') {
78 return Match::NotFound;
79 }
80 if let Some(method_router) = self.by_method.get(method) {
81 if let Ok(matched) = method_router.at(path) {
82 let mut params = PathParams::new();
83 for (name, value) in matched.params.iter() {
84 params.push(name.to_owned(), value.to_owned());
85 }
86 return Match::Found {
87 route: &self.routes[*matched.value],
88 params,
89 };
90 }
91
92 if let Some(normalized) = normalized_request_path(path) {
93 if let Ok(matched) = method_router.at(normalized) {
94 let mut params = PathParams::new();
95 for (name, value) in matched.params.iter() {
96 params.push(name.to_owned(), value.to_owned());
97 }
98 return Match::Found {
99 route: &self.routes[*matched.value],
100 params,
101 };
102 }
103 }
104
105 if let Some(collapsed) = collapse_double_slashes(path) {
106 if let Ok(matched) = method_router.at(&collapsed) {
107 let mut params = PathParams::new();
108 for (name, value) in matched.params.iter() {
109 params.push(name.to_owned(), value.to_owned());
110 }
111 return Match::Found {
112 route: &self.routes[*matched.value],
113 params,
114 };
115 }
116 }
117 }
118
119 if self.all_paths.at(path).is_ok() {
120 Match::MethodNotAllowed
121 } else if let Some(normalized) = normalized_request_path(path) {
122 if self.all_paths.at(normalized).is_ok() {
123 Match::MethodNotAllowed
124 } else {
125 Match::NotFound
126 }
127 } else if let Some(collapsed) = collapse_double_slashes(path) {
128 if self.all_paths.at(&collapsed).is_ok() {
129 Match::MethodNotAllowed
130 } else {
131 Match::NotFound
132 }
133 } else {
134 Match::NotFound
135 }
136 }
137
138 pub fn routes(&self) -> &[Route] {
140 &self.routes
141 }
142}
143
/// Strips trailing slashes from `path` so that `/users/` can match a route registered as
/// `/users`.
///
/// Returns `None` when there is nothing to normalize: the path is the root `/` or does not
/// end with a slash. A path made only of slashes normalizes to `/`.
fn normalized_request_path(path: &str) -> Option<&str> {
    let has_trailing_slash = path.ends_with('/');
    if path == "/" || !has_trailing_slash {
        return None;
    }

    match path.trim_end_matches('/') {
        "" => Some("/"),
        stripped => Some(stripped),
    }
}
157
/// Collapses runs of consecutive slashes in `path` into a single `/`.
///
/// Returns `None` when `path` contains no `//`, so the common case allocates nothing.
/// Otherwise the result starts with exactly one `/`, has no empty segments and no trailing
/// slash (e.g. `/users//42/` becomes `/users/42`); a path made only of slashes becomes `/`.
fn collapse_double_slashes(path: &str) -> Option<String> {
    if !path.contains("//") {
        return None;
    }

    // Re-emit a leading '/' before every non-empty segment. Joining the segments with "/"
    // alone would drop the leading slash, and the result could never match a route,
    // because registered routes are rooted at "/".
    let mut collapsed = String::with_capacity(path.len());
    for segment in path.split('/').filter(|segment| !segment.is_empty()) {
        collapsed.push('/');
        collapsed.push_str(segment);
    }

    if collapsed.is_empty() {
        collapsed.push('/');
    }
    Some(collapsed)
}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;
    use crate::extract::RequestContext;
    use crate::response::{empty, Response};
    use crate::router::{BoxFuture, HandlerFn};
    use http::StatusCode;
    use std::sync::Arc;

    /// A handler that ignores its request and resolves to an empty `200 OK`.
    fn dummy_handler() -> HandlerFn {
        Arc::new(
            |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
                Box::pin(async { Ok(empty(StatusCode::OK)) })
            },
        )
    }

    /// A matcher holding a single `GET /users/{user_id}` route.
    fn user_route_matcher() -> Matcher {
        let routes = vec![Route::new(
            Method::GET,
            "/users/{user_id}",
            dummy_handler(),
        )];
        Matcher::build(routes).unwrap()
    }

    #[test]
    fn matches_and_captures_params() {
        let matcher = user_route_matcher();
        if let Match::Found { params, .. } = matcher.find(&Method::GET, "/users/42") {
            assert_eq!(params.get("user_id"), Some("42"));
        } else {
            panic!("expected a match");
        }
    }

    #[test]
    fn trailing_slash_is_ignored() {
        let matcher = user_route_matcher();
        let outcome = matcher.find(&Method::GET, "/users/42/");
        assert!(matches!(outcome, Match::Found { .. }));
    }

    #[test]
    fn wrong_method_is_method_not_allowed() {
        let matcher = user_route_matcher();
        let outcome = matcher.find(&Method::POST, "/users/42");
        assert!(matches!(outcome, Match::MethodNotAllowed));
    }

    #[test]
    fn unknown_path_is_not_found() {
        let matcher = user_route_matcher();
        let outcome = matcher.find(&Method::GET, "/unknown");
        assert!(matches!(outcome, Match::NotFound));
    }

    #[test]
    fn build_rejects_duplicate_same_method_and_path() {
        let duplicate = || Route::new(Method::GET, "/users/{user_id}", dummy_handler());
        let err = Matcher::build(vec![duplicate(), duplicate()])
            .err()
            .expect("expected duplicate route registration to fail");
        let message = err.to_string();
        assert!(message.contains("failed to register route GET /users/{user_id}"));
    }

    #[test]
    fn normalized_request_path_covers_root_and_trailing_slashes() {
        let cases = [
            ("/", None),
            ("/users", None),
            ("/users/", Some("/users")),
            ("/users///", Some("/users")),
        ];
        for (input, expected) in cases {
            assert_eq!(normalized_request_path(input), expected, "input: {input}");
        }
    }

    #[test]
    fn root_path_matches_and_method_not_allowed_uses_all_paths() {
        let matcher = Matcher::build(vec![
            Route::new(Method::GET, "/", dummy_handler()),
            Route::new(Method::POST, "/users", dummy_handler()),
        ])
        .unwrap();

        let get_root = matcher.find(&Method::GET, "/");
        assert!(matches!(get_root, Match::Found { .. }));

        let post_root = matcher.find(&Method::POST, "/");
        assert!(matches!(post_root, Match::MethodNotAllowed));

        let get_users = matcher.find(&Method::GET, "/users/");
        assert!(matches!(get_users, Match::MethodNotAllowed));
    }
}