1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use hyper::Method;
12
13use crate::error::Error;
14use crate::http::{Request, Response};
15use crate::middleware::{MiddlewareFn, Next};
16
17type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
18type HandlerFn = Arc<dyn Fn(Request, Params) -> BoxFuture<Result<Response, Error>> + Send + Sync>;
19
/// Path parameters captured while matching a request path against a route pattern.
///
/// Entries are `(name, value)` pairs in pattern order. Values are copied verbatim from
/// the request path segments; this module performs no percent-decoding.
#[derive(Debug, Clone, Default)]
pub struct Params {
    pairs: Vec<(String, String)>,
}

impl Params {
    /// Creates an empty parameter set (used for literal-only routes and unmatched requests).
    fn empty() -> Self {
        Self::default()
    }

    /// Returns the value captured for `name`, or `None` if the matched pattern has no
    /// parameter with that name. If a name occurs more than once, the first capture wins.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.pairs
            .iter()
            .find_map(|(k, v)| (k == name).then_some(v.as_str()))
    }

    /// Iterates over all captured `(name, value)` pairs in pattern order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> + '_ {
        self.pairs.iter().map(|(k, v)| (k.as_str(), v.as_str()))
    }

    /// Number of captured parameters.
    pub fn len(&self) -> usize {
        self.pairs.len()
    }

    /// Returns `true` when nothing was captured, e.g. for a pattern made only of literals.
    pub fn is_empty(&self) -> bool {
        self.pairs.is_empty()
    }
}
43
/// One `/`-separated piece of a route pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Plain text; matches only a request segment equal to it.
    Literal(String),
    /// Written `:name` in the pattern; matches any single request segment and captures
    /// it under `name` (the stored string is the name without the leading `:`).
    Param(String),
}
48
/// A single registered endpoint: the HTTP method, the parsed path pattern and the
/// handler invoked when a request matches both.
struct Route {
    /// Method a request must carry for this route to be selected.
    method: Method,
    /// Pattern produced by `parse_path`, compared segment-by-segment with the request path.
    segments: Vec<Segment>,
    /// Type-erased handler; `HandlerFn` is an `Arc`, so dispatch clones it cheaply.
    handler: HandlerFn,
}
54
/// Request router: an ordered list of routes plus a middleware stack that wraps
/// every dispatched request.
pub struct Router {
    /// Routes in registration order; during dispatch the first matching route wins.
    routes: Vec<Route>,
    /// Middleware in registration order; `build_chain` makes the first entry the
    /// outermost layer of the chain.
    middlewares: Vec<MiddlewareFn>,
}
59
60impl Router {
61 pub fn new() -> Self {
62 Self {
63 routes: Vec::new(),
64 middlewares: Vec::new(),
65 }
66 }
67
68 pub fn wrap<F, Fut>(mut self, middleware: F) -> Self
69 where
70 F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
71 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
72 {
73 self.middlewares
74 .push(Arc::new(move |req, next| Box::pin(middleware(req, next))));
75 self
76 }
77
78 pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
79 where
80 F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
81 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
82 {
83 self.route(Method::GET, path, handler)
84 }
85
86 pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
87 where
88 F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
89 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
90 {
91 self.route(Method::POST, path, handler)
92 }
93
94 fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
95 where
96 F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
97 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
98 {
99 let handler: HandlerFn = Arc::new(move |req, params| Box::pin(handler(req, params)));
100 self.routes.push(Route {
101 method,
102 segments: parse_path(path),
103 handler,
104 });
105 self
106 }
107
108 pub fn has_route(&self, method: &Method, path: &str) -> bool {
116 let target = parse_path(path);
117 self.routes
118 .iter()
119 .any(|r| r.method == *method && segments_equal(&r.segments, &target))
120 }
121
122 pub async fn dispatch(&self, req: Request) -> Response {
123 let path = req.uri().path().to_owned();
124 let actual: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
125 let method = req.method().clone();
126
127 let mut found: Option<(HandlerFn, Params)> = None;
128 let mut path_matched = false;
129
130 for route in &self.routes {
131 if let Some(params) = match_segments(&route.segments, &actual) {
132 path_matched = true;
133 if route.method == method {
134 found = Some((route.handler.clone(), params));
135 break;
136 }
137 }
138 }
139
140 let (handler, params) = found.unwrap_or_else(|| {
141 let method_not_allowed = path_matched;
142 let fallback: HandlerFn = Arc::new(move |_req, _params| {
143 let err = if method_not_allowed {
144 Error::MethodNotAllowed
145 } else {
146 Error::NotFound
147 };
148 Box::pin(async move { Err(err) })
149 });
150 (fallback, Params::empty())
151 });
152
153 let chain = build_chain(&self.middlewares, handler, params);
154 match chain(req).await {
155 Ok(resp) => resp,
156 Err(err) => err.into_response(),
157 }
158 }
159}
160
161impl Default for Router {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
167fn build_chain(
168 middlewares: &[MiddlewareFn],
169 handler: HandlerFn,
170 params: Params,
171) -> Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> {
172 let mut chain: Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> =
173 Box::new(move |req| handler(req, params));
174
175 for mw in middlewares.iter().rev() {
176 let mw = mw.clone();
177 let inner = chain;
178 chain = Box::new(move |req| {
179 let next = Next::new(inner);
180 mw(req, next)
181 });
182 }
183 chain
184}
185
186fn parse_path(path: &str) -> Vec<Segment> {
187 path.split('/')
188 .filter(|s| !s.is_empty())
189 .map(|s| match s.strip_prefix(':') {
190 Some(name) => Segment::Param(name.to_owned()),
191 None => Segment::Literal(s.to_owned()),
192 })
193 .collect()
194}
195
196fn segments_equal(a: &[Segment], b: &[Segment]) -> bool {
197 if a.len() != b.len() {
198 return false;
199 }
200 a.iter().zip(b.iter()).all(|(x, y)| match (x, y) {
201 (Segment::Literal(u), Segment::Literal(v)) => u == v,
202 (Segment::Param(u), Segment::Param(v)) => u == v,
203 _ => false,
204 })
205}
206
207fn match_segments(patterns: &[Segment], actual: &[&str]) -> Option<Params> {
208 if patterns.len() != actual.len() {
209 return None;
210 }
211 let mut params = Params::empty();
212 for (pat, seg) in patterns.iter().zip(actual.iter()) {
213 match pat {
214 Segment::Literal(lit) => {
215 if lit != seg {
216 return None;
217 }
218 }
219 Segment::Param(name) => {
220 params.pairs.push((name.clone(), (*seg).to_owned()));
221 }
222 }
223 }
224 Some(params)
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a route pattern (shorthand for `parse_path`).
    fn segs(path: &str) -> Vec<Segment> {
        parse_path(path)
    }

    /// Splits a request path into its non-empty segments, the same way the router
    /// splits incoming paths before matching.
    fn parts(path: &str) -> Vec<&str> {
        path.split('/').filter(|s| !s.is_empty()).collect()
    }

    #[test]
    fn root_path_is_empty_segment_list() {
        assert!(parse_path("/").is_empty());
        assert!(parse_path("").is_empty());
    }

    #[test]
    fn colon_prefix_marks_a_param() {
        assert!(matches!(
            segs("/users/:id").as_slice(),
            [Segment::Literal(lit), Segment::Param(name)] if lit == "users" && name == "id"
        ));
    }

    #[test]
    fn literal_match() {
        assert!(match_segments(&segs("/users"), &parts("/users")).is_some());
        assert!(match_segments(&segs("/users"), &parts("/posts")).is_none());
    }

    #[test]
    fn literal_only_route_captures_nothing() {
        let params = match_segments(&segs("/users"), &parts("/users")).unwrap();
        assert!(params.is_empty());
    }

    #[test]
    fn param_captures_value() {
        let params = match_segments(&segs("/users/:id"), &parts("/users/42")).unwrap();
        assert_eq!(params.get("id"), Some("42"));
        assert_eq!(params.get("missing"), None);
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn extra_slashes_are_ignored() {
        assert!(match_segments(&segs("/users/:id"), &parts("//users/42/")).is_some());
    }

    #[test]
    fn length_mismatch_does_not_match() {
        assert!(match_segments(&segs("/users/:id"), &parts("/users")).is_none());
        assert!(match_segments(&segs("/users"), &parts("/users/42")).is_none());
    }

    #[test]
    fn multiple_params_captured_by_name() {
        let params = match_segments(&segs("/a/:x/b/:y"), &parts("/a/first/b/second")).unwrap();
        assert_eq!(params.get("x"), Some("first"));
        assert_eq!(params.get("y"), Some("second"));
    }

    #[test]
    fn segments_equal_requires_same_param_names() {
        assert!(segments_equal(&segs("/users/:id"), &segs("/users/:id")));
        assert!(!segments_equal(&segs("/users/:id"), &segs("/users/:other")));
        assert!(!segments_equal(&segs("/users/:id"), &segs("/users/id")));
    }

    #[test]
    fn has_route_detects_registered_path() {
        let router = Router::new()
            .get("/", |_req, _p| async {
                Ok::<Response, Error>(crate::http::text("home"))
            })
            .get("/users/:id", |_req, _p| async {
                Ok::<Response, Error>(crate::http::text("user"))
            });
        assert!(router.has_route(&Method::GET, "/"));
        assert!(router.has_route(&Method::GET, "/users/:id"));
        assert!(!router.has_route(&Method::GET, "/missing"));
        assert!(!router.has_route(&Method::POST, "/"));
        // Patterns are compared, not request paths: the parameter name must match too,
        // and a concrete path does not count as the pattern it would match.
        assert!(!router.has_route(&Method::GET, "/users/:other"));
        assert!(!router.has_route(&Method::GET, "/users/42"));
    }
}
286}