1use std::collections::HashMap;
2use http::Method;
3use crate::{Request, Response, HandlerFn};
4
5pub struct Router {
7 routes: HashMap<Method, Vec<Route>>,
8 not_found_handler: Option<HandlerFn>,
9}
10
11#[derive(Clone)]
13struct Route {
14 pattern: RoutePattern,
15 handler: HandlerFn,
16}
17
18#[derive(Debug, Clone)]
20struct RoutePattern {
21 segments: Vec<Segment>,
22}
23
/// One component of a `RoutePattern`.
#[derive(Debug, Clone, PartialEq)]
enum Segment {
    /// Literal text that must equal the corresponding path piece exactly.
    Static(String),
    /// Named placeholder (written `:name`) that captures any single path piece.
    Param(String),
    /// `*`: matches whatever remains of the path, including nothing.
    Wildcard,
}
31
32impl Router {
33 pub fn new() -> Self {
35 Self {
36 routes: HashMap::new(),
37 not_found_handler: None,
38 }
39 }
40
41 pub fn route(&mut self, method: Method, path: &str, handler: HandlerFn) {
43 let pattern = RoutePattern::parse(path);
44 let route = Route { pattern, handler };
45
46 self.routes
47 .entry(method)
48 .or_insert_with(Vec::new)
49 .push(route);
50 }
51
52 pub fn get(&mut self, path: &str, handler: HandlerFn) {
54 self.route(Method::GET, path, handler);
55 }
56
57 pub fn post(&mut self, path: &str, handler: HandlerFn) {
59 self.route(Method::POST, path, handler);
60 }
61
62 pub fn put(&mut self, path: &str, handler: HandlerFn) {
64 self.route(Method::PUT, path, handler);
65 }
66
67 pub fn delete(&mut self, path: &str, handler: HandlerFn) {
69 self.route(Method::DELETE, path, handler);
70 }
71
72 pub fn patch(&mut self, path: &str, handler: HandlerFn) {
74 self.route(Method::PATCH, path, handler);
75 }
76
77 pub fn not_found(&mut self, handler: HandlerFn) {
79 self.not_found_handler = Some(handler);
80 }
81
82 pub(crate) fn get_all_routes(&self) -> Vec<(Method, String, HandlerFn)> {
84 let mut all_routes = Vec::new();
85
86 for (method, routes) in &self.routes {
87 for route in routes {
88 let path = route.pattern.to_string();
90 all_routes.push((method.clone(), path, route.handler.clone()));
91 }
92 }
93
94 all_routes
95 }
96
97 pub async fn route_request(&self, mut req: Request) -> Response {
99 if let Some(routes) = self.routes.get(req.method()) {
100 for route in routes {
101 if let Some(params) = route.pattern.matches(req.path()) {
102 for (name, value) in params {
104 req.set_param(name, value);
105 }
106 return (route.handler)(req).await;
107 }
108 }
109 }
110
111 if let Some(handler) = &self.not_found_handler {
113 handler(req).await
114 } else {
115 Response::not_found()
116 }
117 }
118}
119
120impl Default for Router {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl Clone for Router {
127 fn clone(&self) -> Self {
128 Self {
129 routes: self.routes.clone(),
130 not_found_handler: self.not_found_handler.clone(),
131 }
132 }
133}
134
135impl RoutePattern {
136 fn to_string(&self) -> String {
138 let mut result = String::from("/");
139 for segment in &self.segments {
140 match segment {
141 Segment::Static(s) => {
142 result.push_str(s);
143 result.push('/');
144 }
145 Segment::Param(name) => {
146 result.push(':');
147 result.push_str(name);
148 result.push('/');
149 }
150 Segment::Wildcard => {
151 result.push('*');
152 result.push('/');
153 }
154 }
155 }
156 if result.len() > 1 && result.ends_with('/') {
158 result.pop();
159 }
160 result
161 }
162
163 fn parse(pattern: &str) -> Self {
165 let mut segments = Vec::new();
166
167 for segment in pattern.split('/').filter(|s| !s.is_empty()) {
168 if segment.starts_with(':') {
169 let param_name = segment[1..].to_string();
170 segments.push(Segment::Param(param_name));
171 } else if segment == "*" {
172 segments.push(Segment::Wildcard);
173 } else {
174 segments.push(Segment::Static(segment.to_string()));
175 }
176 }
177
178 Self { segments }
179 }
180
181 fn matches(&self, path: &str) -> Option<HashMap<String, String>> {
183 let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
184
185 if path == "/" && self.segments.is_empty() {
187 return Some(HashMap::new());
188 }
189
190 let mut params = HashMap::new();
191 let mut path_idx = 0;
192 let mut pattern_idx = 0;
193
194 while pattern_idx < self.segments.len() && path_idx < path_segments.len() {
195 match &self.segments[pattern_idx] {
196 Segment::Static(expected) => {
197 if path_segments[path_idx] != expected {
198 return None;
199 }
200 path_idx += 1;
201 pattern_idx += 1;
202 }
203 Segment::Param(name) => {
204 params.insert(name.clone(), path_segments[path_idx].to_string());
205 path_idx += 1;
206 pattern_idx += 1;
207 }
208 Segment::Wildcard => {
209 return Some(params);
211 }
212 }
213 }
214
215 if pattern_idx == self.segments.len() && path_idx == path_segments.len() {
217 Some(params)
218 } else if pattern_idx < self.segments.len()
219 && matches!(self.segments[pattern_idx], Segment::Wildcard) {
220 Some(params)
221 } else {
222 None
223 }
224 }
225}
226
// Unit tests for `RoutePattern` parsing, matching and rendering. They only use items
// defined in this file, so they are compiled and run by a normal `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_route_pattern_parsing() {
        let pattern = RoutePattern::parse("/users/:id/posts/:post_id");
        assert_eq!(pattern.segments.len(), 4);
        assert_eq!(pattern.segments[0], Segment::Static("users".to_string()));
        assert_eq!(pattern.segments[1], Segment::Param("id".to_string()));
        assert_eq!(pattern.segments[2], Segment::Static("posts".to_string()));
        assert_eq!(pattern.segments[3], Segment::Param("post_id".to_string()));
    }

    #[test]
    fn test_route_pattern_matching() {
        let pattern = RoutePattern::parse("/users/:id");
        let params = pattern.matches("/users/123").unwrap();
        assert_eq!(params.get("id"), Some(&"123".to_string()));

        assert!(pattern.matches("/users").is_none());
        assert!(pattern.matches("/users/123/extra").is_none());

        // Empty pieces are ignored, so a trailing slash still matches.
        let params = pattern.matches("/users/42/").unwrap();
        assert_eq!(params.get("id"), Some(&"42".to_string()));
    }

    #[test]
    fn test_wildcard_matching() {
        let pattern = RoutePattern::parse("/files/*");
        let params = pattern.matches("/files/path/to/file.txt");
        assert!(params.is_some());

        // The wildcard may absorb zero segments, but the static prefix is still required.
        assert!(pattern.matches("/files").is_some());
        assert!(pattern.matches("/").is_none());
        assert!(pattern.matches("/other/file.txt").is_none());
    }

    #[test]
    fn test_root_pattern() {
        let pattern = RoutePattern::parse("/");
        assert!(pattern.segments.is_empty());
        assert!(pattern.matches("/").unwrap().is_empty());
        assert!(pattern.matches("/anything").is_none());
    }

    #[test]
    fn test_pattern_string_round_trip() {
        assert_eq!(RoutePattern::parse("/").to_string(), "/");
        assert_eq!(RoutePattern::parse("/users/:id/*").to_string(), "/users/:id/*");
        // Redundant slashes are normalised away.
        assert_eq!(RoutePattern::parse("//users//:id/").to_string(), "/users/:id");
    }
}

// NOTE(review): `disabled_for_now` is not a cfg flag that any build sets, so this module
// is never compiled. It relies on `Request::from_hyper`, `Response::body_bytes` and
// `#[tokio::test]`, none of which are defined in this file. Switch it to `#[cfg(test)]`
// once those APIs are confirmed to match this usage.
#[cfg(disabled_for_now)]
mod router_tests {
    use super::*;
    use crate::Response;

    #[tokio::test]
    async fn test_router_basic_routing() {
        let mut router = Router::new();

        router.get("/", std::sync::Arc::new(|_| Box::pin(async {
            Response::ok().body("Home")
        })));

        router.get("/users/:id", std::sync::Arc::new(|req| Box::pin(async move {
            let id = req.param("id").unwrap_or("unknown");
            Response::ok().body(format!("User: {}", id))
        })));

        let req = Request::from_hyper(
            http::Request::builder()
                .method("GET")
                .uri("/")
                .body(())
                .unwrap()
                .into_parts()
                .0,
            Vec::new(),
        )
        .await
        .unwrap();

        let response = router.route_request(req).await;
        assert_eq!(response.body_bytes(), b"Home");
    }
}