1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use hyper::Method;
6
7use crate::error::Error;
8use crate::http::{Request, Response};
9use crate::middleware::{MiddlewareFn, Next};
10
11type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
12type HandlerFn =
13 Arc<dyn Fn(Request, Params) -> BoxFuture<Result<Response, Error>> + Send + Sync>;
14
/// Path parameters captured while matching a route pattern such as
/// `/users/:id` against a request path.
///
/// Pairs are stored as `(name, value)` in the order the parameter segments
/// appear in the pattern.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Params {
    pairs: Vec<(String, String)>,
}

impl Params {
    /// Creates a `Params` with no captured values.
    fn empty() -> Self {
        Self { pairs: Vec::new() }
    }

    /// Returns the value captured for `name`, or `None` if there is none.
    ///
    /// If the same name was captured more than once, the first capture wins.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.pairs
            .iter()
            .find_map(|(k, v)| (k == name).then_some(v.as_str()))
    }

    /// Iterates over every captured `(name, value)` pair in capture order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.pairs.iter().map(|(k, v)| (k.as_str(), v.as_str()))
    }

    /// Number of captured parameters.
    pub fn len(&self) -> usize {
        self.pairs.len()
    }

    /// `true` when no parameters were captured.
    pub fn is_empty(&self) -> bool {
        self.pairs.is_empty()
    }
}
38
/// One `/`-separated piece of a parsed route pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Must equal the corresponding request-path segment exactly.
    Literal(String),
    /// Matches any request-path segment and captures it under this name.
    Param(String),
}
43
/// A registered route: the method, the parsed path pattern, and the handler
/// invoked when a request matches both.
struct Route {
    /// Method the request must have for this route to be selected.
    method: Method,
    /// Pattern produced by `parse_path`, compared segment-by-segment.
    segments: Vec<Segment>,
    /// Type-erased handler, called with the request and the captured params.
    handler: HandlerFn,
}
49
/// Request router: matches method + path against registered routes and runs
/// the selected handler through the middleware stack.
pub struct Router {
    /// Routes in registration order; dispatch picks the first full match.
    routes: Vec<Route>,
    /// Middleware in registration order; the first entry becomes outermost.
    middlewares: Vec<MiddlewareFn>,
}
54
55impl Router {
56 pub fn new() -> Self {
57 Self {
58 routes: Vec::new(),
59 middlewares: Vec::new(),
60 }
61 }
62
63 pub fn wrap<F, Fut>(mut self, middleware: F) -> Self
64 where
65 F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
66 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
67 {
68 self.middlewares
69 .push(Arc::new(move |req, next| Box::pin(middleware(req, next))));
70 self
71 }
72
73 pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
74 where
75 F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
76 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
77 {
78 self.route(Method::GET, path, handler)
79 }
80
81 pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
82 where
83 F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
84 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
85 {
86 self.route(Method::POST, path, handler)
87 }
88
89 fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
90 where
91 F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
92 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
93 {
94 let handler: HandlerFn = Arc::new(move |req, params| Box::pin(handler(req, params)));
95 self.routes.push(Route {
96 method,
97 segments: parse_path(path),
98 handler,
99 });
100 self
101 }
102
103 pub async fn dispatch(&self, req: Request) -> Response {
104 let path = req.uri().path().to_owned();
105 let actual: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
106 let method = req.method().clone();
107
108 let mut found: Option<(HandlerFn, Params)> = None;
109 let mut path_matched = false;
110
111 for route in &self.routes {
112 if let Some(params) = match_segments(&route.segments, &actual) {
113 path_matched = true;
114 if route.method == method {
115 found = Some((route.handler.clone(), params));
116 break;
117 }
118 }
119 }
120
121 let (handler, params) = found.unwrap_or_else(|| {
122 let method_not_allowed = path_matched;
123 let fallback: HandlerFn = Arc::new(move |_req, _params| {
124 let err = if method_not_allowed {
125 Error::MethodNotAllowed
126 } else {
127 Error::NotFound
128 };
129 Box::pin(async move { Err(err) })
130 });
131 (fallback, Params::empty())
132 });
133
134 let chain = build_chain(&self.middlewares, handler, params);
135 match chain(req).await {
136 Ok(resp) => resp,
137 Err(err) => err.into_response(),
138 }
139 }
140}
141
142impl Default for Router {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148fn build_chain(
149 middlewares: &[MiddlewareFn],
150 handler: HandlerFn,
151 params: Params,
152) -> Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> {
153 let mut chain: Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> =
154 Box::new(move |req| handler(req, params));
155
156 for mw in middlewares.iter().rev() {
157 let mw = mw.clone();
158 let inner = chain;
159 chain = Box::new(move |req| {
160 let next = Next::new(inner);
161 mw(req, next)
162 });
163 }
164 chain
165}
166
167fn parse_path(path: &str) -> Vec<Segment> {
168 path.split('/')
169 .filter(|s| !s.is_empty())
170 .map(|s| match s.strip_prefix(':') {
171 Some(name) => Segment::Param(name.to_owned()),
172 None => Segment::Literal(s.to_owned()),
173 })
174 .collect()
175}
176
177fn match_segments(patterns: &[Segment], actual: &[&str]) -> Option<Params> {
178 if patterns.len() != actual.len() {
179 return None;
180 }
181 let mut params = Params::empty();
182 for (pat, seg) in patterns.iter().zip(actual.iter()) {
183 match pat {
184 Segment::Literal(lit) => {
185 if lit != seg {
186 return None;
187 }
188 }
189 Segment::Param(name) => {
190 params.pairs.push((name.clone(), (*seg).to_owned()));
191 }
192 }
193 }
194 Some(params)
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a route pattern exactly as route registration does.
    fn segs(path: &str) -> Vec<Segment> {
        parse_path(path)
    }

    /// Splits a request path into its non-empty segments, mirroring how the
    /// router prepares request paths for matching.
    fn parts(path: &str) -> Vec<&str> {
        path.split('/').filter(|s| !s.is_empty()).collect()
    }

    #[test]
    fn root_path_is_empty_segment_list() {
        assert!(parse_path("/").is_empty());
    }

    #[test]
    fn literal_match() {
        assert!(match_segments(&segs("/users"), &parts("/users")).is_some());
        assert!(match_segments(&segs("/users"), &parts("/posts")).is_none());
    }

    #[test]
    fn param_captures_value() {
        let params = match_segments(&segs("/users/:id"), &parts("/users/42")).unwrap();
        assert_eq!(params.get("id"), Some("42"));
    }

    #[test]
    fn length_mismatch_does_not_match() {
        assert!(match_segments(&segs("/users/:id"), &parts("/users")).is_none());
        assert!(match_segments(&segs("/users"), &parts("/users/42")).is_none());
    }

    #[test]
    fn multiple_params_captured_by_name() {
        let params = match_segments(
            &segs("/a/:x/b/:y"),
            &parts("/a/first/b/second"),
        )
        .unwrap();
        assert_eq!(params.get("x"), Some("first"));
        assert_eq!(params.get("y"), Some("second"));
    }

    #[test]
    fn root_route_matches_only_root() {
        assert!(match_segments(&segs("/"), &parts("/")).is_some());
        assert!(match_segments(&segs("/"), &parts("")).is_some());
        assert!(match_segments(&segs("/"), &parts("/users")).is_none());
    }

    #[test]
    fn empty_segments_are_ignored() {
        let params = match_segments(&segs("/users/:id/"), &parts("//users//42/")).unwrap();
        assert_eq!(params.get("id"), Some("42"));
    }

    #[test]
    fn literal_only_match_captures_nothing() {
        let params = match_segments(&segs("/users/list"), &parts("/users/list")).unwrap();
        assert!(params.is_empty());
        assert_eq!(params.get("users"), None);
    }

    #[test]
    fn params_len_counts_each_capture() {
        let params = match_segments(&segs("/a/:x/b/:y"), &parts("/a/1/b/2")).unwrap();
        assert_eq!(params.len(), 2);
        assert!(!params.is_empty());
        assert_eq!(params.get("z"), None);
    }

    #[test]
    fn literals_are_case_sensitive() {
        assert!(match_segments(&segs("/users/:id"), &parts("/Users/42")).is_none());
    }
}