torch_web/
router.rs

1use std::collections::HashMap;
2use http::Method;
3use crate::{Request, Response, HandlerFn};
4
5/// A fast, lightweight router for HTTP requests
6pub struct Router {
7    routes: HashMap<Method, Vec<Route>>,
8    not_found_handler: Option<HandlerFn>,
9}
10
11/// Represents a single route with its pattern and handler
12#[derive(Clone)]
13struct Route {
14    pattern: RoutePattern,
15    handler: HandlerFn,
16}
17
18/// Pattern matching for routes
19#[derive(Debug, Clone)]
20struct RoutePattern {
21    segments: Vec<Segment>,
22}
23
/// One `/`-separated piece of a route pattern.
#[derive(Debug, Clone, PartialEq)]
enum Segment {
    /// Literal text that must equal the corresponding path segment exactly.
    Static(String),
    /// `:name` in the pattern: captures one path segment under `name` (stored without the `:`).
    Param(String),
    /// `*` in the pattern: matches whatever remains of the path, including nothing.
    Wildcard,
}
31
32impl Router {
33    /// Create a new router
34    pub fn new() -> Self {
35        Self {
36            routes: HashMap::new(),
37            not_found_handler: None,
38        }
39    }
40
41    /// Add a route for any HTTP method
42    pub fn route(&mut self, method: Method, path: &str, handler: HandlerFn) {
43        let pattern = RoutePattern::parse(path);
44        let route = Route { pattern, handler };
45        
46        self.routes
47            .entry(method)
48            .or_insert_with(Vec::new)
49            .push(route);
50    }
51
52    /// Add a GET route
53    pub fn get(&mut self, path: &str, handler: HandlerFn) {
54        self.route(Method::GET, path, handler);
55    }
56
57    /// Add a POST route
58    pub fn post(&mut self, path: &str, handler: HandlerFn) {
59        self.route(Method::POST, path, handler);
60    }
61
62    /// Add a PUT route
63    pub fn put(&mut self, path: &str, handler: HandlerFn) {
64        self.route(Method::PUT, path, handler);
65    }
66
67    /// Add a DELETE route
68    pub fn delete(&mut self, path: &str, handler: HandlerFn) {
69        self.route(Method::DELETE, path, handler);
70    }
71
72    /// Add a PATCH route
73    pub fn patch(&mut self, path: &str, handler: HandlerFn) {
74        self.route(Method::PATCH, path, handler);
75    }
76
77    /// Set a custom 404 handler
78    pub fn not_found(&mut self, handler: HandlerFn) {
79        self.not_found_handler = Some(handler);
80    }
81
82    /// Get all routes for mounting (internal use)
83    pub(crate) fn get_all_routes(&self) -> Vec<(Method, String, HandlerFn)> {
84        let mut all_routes = Vec::new();
85
86        for (method, routes) in &self.routes {
87            for route in routes {
88                // Convert the pattern back to a string representation
89                let path = route.pattern.to_string();
90                all_routes.push((method.clone(), path, route.handler.clone()));
91            }
92        }
93
94        all_routes
95    }
96
97    /// Route a request to the appropriate handler
98    pub async fn route_request(&self, mut req: Request) -> Response {
99        if let Some(routes) = self.routes.get(req.method()) {
100            for route in routes {
101                if let Some(params) = route.pattern.matches(req.path()) {
102                    // Set path parameters in the request
103                    for (name, value) in params {
104                        req.set_param(name, value);
105                    }
106                    return (route.handler)(req).await;
107                }
108            }
109        }
110
111        // No route found, use 404 handler or default
112        if let Some(handler) = &self.not_found_handler {
113            handler(req).await
114        } else {
115            Response::not_found()
116        }
117    }
118}
119
120impl Default for Router {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl Clone for Router {
127    fn clone(&self) -> Self {
128        Self {
129            routes: self.routes.clone(),
130            not_found_handler: self.not_found_handler.clone(),
131        }
132    }
133}
134
135impl RoutePattern {
136    /// Convert pattern back to string representation
137    fn to_string(&self) -> String {
138        let mut result = String::from("/");
139        for segment in &self.segments {
140            match segment {
141                Segment::Static(s) => {
142                    result.push_str(s);
143                    result.push('/');
144                }
145                Segment::Param(name) => {
146                    result.push(':');
147                    result.push_str(name);
148                    result.push('/');
149                }
150                Segment::Wildcard => {
151                    result.push('*');
152                    result.push('/');
153                }
154            }
155        }
156        // Remove trailing slash unless it's the root path
157        if result.len() > 1 && result.ends_with('/') {
158            result.pop();
159        }
160        result
161    }
162
163    /// Parse a route pattern string into segments
164    fn parse(pattern: &str) -> Self {
165        let mut segments = Vec::new();
166
167        for segment in pattern.split('/').filter(|s| !s.is_empty()) {
168            if segment.starts_with(':') {
169                let param_name = segment[1..].to_string();
170                segments.push(Segment::Param(param_name));
171            } else if segment == "*" {
172                segments.push(Segment::Wildcard);
173            } else {
174                segments.push(Segment::Static(segment.to_string()));
175            }
176        }
177
178        Self { segments }
179    }
180
181    /// Check if this pattern matches the given path and extract parameters
182    fn matches(&self, path: &str) -> Option<HashMap<String, String>> {
183        let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
184        
185        // Handle root path
186        if path == "/" && self.segments.is_empty() {
187            return Some(HashMap::new());
188        }
189
190        let mut params = HashMap::new();
191        let mut path_idx = 0;
192        let mut pattern_idx = 0;
193
194        while pattern_idx < self.segments.len() && path_idx < path_segments.len() {
195            match &self.segments[pattern_idx] {
196                Segment::Static(expected) => {
197                    if path_segments[path_idx] != expected {
198                        return None;
199                    }
200                    path_idx += 1;
201                    pattern_idx += 1;
202                }
203                Segment::Param(name) => {
204                    params.insert(name.clone(), path_segments[path_idx].to_string());
205                    path_idx += 1;
206                    pattern_idx += 1;
207                }
208                Segment::Wildcard => {
209                    // Wildcard matches everything remaining
210                    return Some(params);
211                }
212            }
213        }
214
215        // Check if we consumed all segments
216        if pattern_idx == self.segments.len() && path_idx == path_segments.len() {
217            Some(params)
218        } else if pattern_idx < self.segments.len() 
219            && matches!(self.segments[pattern_idx], Segment::Wildcard) {
220            Some(params)
221        } else {
222            None
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_route_pattern_parsing() {
        let pattern = RoutePattern::parse("/users/:id/posts/:post_id");
        assert_eq!(pattern.segments.len(), 4);
        assert_eq!(pattern.segments[0], Segment::Static("users".to_string()));
        assert_eq!(pattern.segments[1], Segment::Param("id".to_string()));
        assert_eq!(pattern.segments[2], Segment::Static("posts".to_string()));
        assert_eq!(pattern.segments[3], Segment::Param("post_id".to_string()));
    }

    #[test]
    fn test_route_pattern_matching() {
        let pattern = RoutePattern::parse("/users/:id");
        let params = pattern.matches("/users/123").unwrap();
        assert_eq!(params.get("id"), Some(&"123".to_string()));

        assert!(pattern.matches("/users").is_none());
        assert!(pattern.matches("/users/123/extra").is_none());
    }

    #[test]
    fn test_root_pattern_matching() {
        let pattern = RoutePattern::parse("/");
        assert!(pattern.segments.is_empty());
        assert!(pattern.matches("/").is_some());
        assert!(pattern.matches("/users").is_none());
    }

    #[test]
    fn test_wildcard_matching() {
        let pattern = RoutePattern::parse("/files/*");
        let params = pattern.matches("/files/path/to/file.txt");
        assert!(params.is_some());
        assert!(pattern.matches("/other/file.txt").is_none());
    }

    #[test]
    fn test_pattern_to_string() {
        assert_eq!(RoutePattern::parse("/").to_string(), "/");
        assert_eq!(RoutePattern::parse("/users/:id/").to_string(), "/users/:id");
        assert_eq!(RoutePattern::parse("/files/*").to_string(), "/files/*");
    }

    // End-to-end routing test, still compiled out. It relies on `Request::from_hyper`,
    // `Response::body_bytes` and the `tokio::test` macro, and their availability in this
    // crate's test build is unconfirmed — NOTE(review): verify, then delete the `cfg(any())`
    // line to re-enable. `cfg(any())` is always false. Unlike an ad-hoc cfg name, it does
    // not trigger rustc's `unexpected_cfgs` lint.
    #[cfg(any())]
    #[tokio::test]
    async fn test_router_basic_routing() {
        let mut router = Router::new();

        router.get("/", std::sync::Arc::new(|_| Box::pin(async {
            Response::ok().body("Home")
        })));

        router.get("/users/:id", std::sync::Arc::new(|req| Box::pin(async move {
            let id = req.param("id").unwrap_or("unknown");
            Response::ok().body(format!("User: {}", id))
        })));

        // Test root route
        let req = Request::from_hyper(
            http::Request::builder()
                .method("GET")
                .uri("/")
                .body(())
                .unwrap()
                .into_parts()
                .0,
            Vec::new(),
        )
        .await
        .unwrap();

        let response = router.route_request(req).await;
        assert_eq!(response.body_bytes(), b"Home");
    }
}