suika_server/
router.rs

1use crate::error::HttpError;
2use crate::middleware::{Middleware, MiddlewareFuture, Next};
3use crate::request::Request;
4use crate::response::Response;
5use regex::Regex;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::Arc;
9
10/// Represents a route in the router.
11pub struct Route {
12    pub method: Option<String>,
13    pub pattern: Regex,
14    pub handler: Arc<
15        dyn for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a> + Send + Sync,
16    >,
17}
18
19/// A router for handling HTTP requests and routing them to appropriate handlers.
20///
21/// The `Router` can handle routes with or without parameters, and it supports mounting sub-routers.
22///
23/// # Examples
24///
25/// ```
26/// use suika_server::request::Request;
27/// use suika_server::response::{Response, Body};
28/// use suika_server::middleware::{Middleware, Next, MiddlewareFuture};
29/// use suika_server::router::Router;
30/// use regex::Regex;
31/// use std::collections::HashMap;
32/// use std::sync::{Arc, Mutex};
33/// use tokio::sync::Mutex as TokioMutex;
34///
35/// #[derive(Clone)]
36/// struct MockNextMiddleware {
37///     called: Arc<TokioMutex<bool>>,
38/// }
39///
40/// impl MockNextMiddleware {
41///     fn new() -> Self {
42///         Self {
43///             called: Arc::new(TokioMutex::new(false)),
44///         }
45///     }
46/// }
47///
48/// impl Middleware for MockNextMiddleware {
49///     fn handle<'a>(
50///         &'a self,
51///         _req: &'a mut Request,
52///         _res: &'a mut Response,
53///         _next: Next<'a>,
54///     ) -> MiddlewareFuture<'a> {
55///         let called = Arc::clone(&self.called);
56///         Box::pin(async move {
57///             let mut called_lock = called.lock().await;
58///             *called_lock = true;
59///             Ok(())
60///         })
61///     }
62/// }
63///
64/// #[tokio::main]
65/// async fn main() {
66///     let mut router = Router::new("/api");
67///
68///     router.add_route(Some("GET"), "/test", |req, res| {
69///         Box::pin(async move {
70///             res.set_status(200).await;
71///             res.body("Test route".to_string()).await;
72///             Ok(())
73///         })
74///     });
75///
76///     let mut req = Request::new(
77///         "GET /api/test HTTP/1.1\r\n\r\n",
78///         Arc::new(Mutex::new(HashMap::new())),
79///     ).unwrap();
80///
81///     let mut res = Response::new(None);
82///
83///     let next_middleware = MockNextMiddleware::new();
84///     let middleware_stack: Vec<Arc<dyn Middleware + Send + Sync>> = vec![Arc::new(next_middleware.clone())];
85///     let next = Next::new(middleware_stack.as_slice());
86///
87///     router.handle(&mut req, &mut res, next.clone()).await.unwrap();
88///
89///     let inner = res.get_inner().await;
90///     assert_eq!(inner.status_code(), Some(200));
91///     assert_eq!(inner.body(), &Some(Body::Text("Test route".to_string())));
92///
93///     let next_called = *next_middleware.called.lock().await;
94///     assert!(!next_called);
95/// }
96/// ```
97pub struct Router {
98    pub base_path: String,
99    pub routes: Vec<Route>,
100    pub sub_routers: Vec<Router>,
101}
102
103impl Router {
104    /// Creates a new `Router` with the specified base path.
105    ///
106    /// # Arguments
107    ///
108    /// * `base_path` - The base path for the router.
109    ///
110    /// # Examples
111    ///
112    /// ```
113    /// use suika_server::router::Router;
114    ///
115    /// let router = Router::new("/api");
116    /// ```
117    pub fn new(base_path: &str) -> Self {
118        Self {
119            base_path: base_path.to_string(),
120            routes: Vec::new(),
121            sub_routers: Vec::new(),
122        }
123    }
124
125    /// Adds a GET route to the router
126    ///
127    /// # Arguments
128    ///
129    /// * `pattern` - The URL pattern for the route, which can include named parameters.
130    /// * `handler` - The handler function for the route.
131    ///
132    /// # Examples
133    ///
134    /// ```
135    /// use suika_server::request::Request;
136    /// use suika_server::response::Response;
137    /// use suika_server::middleware::MiddlewareFuture;
138    /// use suika_server::router::Router;
139    /// use std::sync::Arc;
140    ///
141    /// let mut router = Router::new("/api");
142    ///
143    /// router.get("/test", |req, res| {
144    ///     Box::pin(async move {
145    ///         res.set_status(200).await;
146    ///         res.body("Test route".to_string()).await;
147    ///         Ok(())
148    ///     })
149    /// });
150    /// ```
151    pub fn get<F>(&mut self, pattern: &str, handler: F)
152    where
153        F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
154            + Send
155            + Sync
156            + 'static,
157    {
158        self.add_route(Some("GET"), pattern, handler);
159    }
160
161    /// Adds a PUT route to the router
162    ///
163    /// # Arguments
164    ///
165    /// * `pattern` - The URL pattern for the route, which can include named parameters.
166    /// * `handler` - The handler function for the route.
167    ///
168    /// # Examples
169    ///
170    /// ```
171    /// use suika_server::request::Request;
172    /// use suika_server::response::Response;
173    /// use suika_server::middleware::MiddlewareFuture;
174    /// use suika_server::router::Router;
175    /// use std::sync::Arc;
176    ///
177    /// let mut router = Router::new("/api");
178    ///
179    /// router.put("/test", |req, res| {
180    ///     Box::pin(async move {
181    ///         res.set_status(200).await;
182    ///         res.body("Test route".to_string()).await;
183    ///         Ok(())
184    ///     })
185    /// });
186    /// ```
187    pub fn put<F>(&mut self, pattern: &str, handler: F)
188    where
189        F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
190            + Send
191            + Sync
192            + 'static,
193    {
194        self.add_route(Some("PUT"), pattern, handler);
195    }
196
197    /// Adds a POST route to the router
198    ///
199    /// # Arguments
200    ///
201    /// * `pattern` - The URL pattern for the route, which can include named parameters.
202    /// * `handler` - The handler function for the route.
203    ///
204    /// # Examples
205    ///
206    /// ```
207    /// use suika_server::request::Request;
208    /// use suika_server::response::Response;
209    /// use suika_server::middleware::MiddlewareFuture;
210    /// use suika_server::router::Router;
211    /// use std::sync::Arc;
212    ///
213    /// let mut router = Router::new("/api");
214    ///
215    /// router.post("/test", |req, res| {
216    ///     Box::pin(async move {
217    ///         res.set_status(200).await;
218    ///         res.body("Test route".to_string()).await;
219    ///         Ok(())
220    ///     })
221    /// });
222    /// ```
223    pub fn post<F>(&mut self, pattern: &str, handler: F)
224    where
225        F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
226            + Send
227            + Sync
228            + 'static,
229    {
230        self.add_route(Some("POST"), pattern, handler);
231    }
232
233    /// Adds a DELETE route to the router
234    ///
235    /// # Arguments
236    ///
237    /// * `pattern` - The URL pattern for the route, which can include named parameters.
238    /// * `handler` - The handler function for the route.
239    ///
240    /// # Examples
241    ///
242    /// ```
243    /// use suika_server::request::Request;
244    /// use suika_server::response::Response;
245    /// use suika_server::middleware::MiddlewareFuture;
246    /// use suika_server::router::Router;
247    /// use std::sync::Arc;
248    ///
249    /// let mut router = Router::new("/api");
250    ///
251    /// router.delete("/test", |req, res| {
252    ///     Box::pin(async move {
253    ///         res.set_status(200).await;
254    ///         res.body("Test route".to_string()).await;
255    ///         Ok(())
256    ///     })
257    /// });
258    /// ```
259    pub fn delete<F>(&mut self, pattern: &str, handler: F)
260    where
261        F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
262            + Send
263            + Sync
264            + 'static,
265    {
266        self.add_route(Some("DELETE"), pattern, handler);
267    }
268
269    /// Adds a route to the router.
270    ///
271    /// # Arguments
272    ///
273    /// * `method` - The HTTP method for the route (e.g., "GET", "POST").
274    /// * `pattern` - The URL pattern for the route, which can include named parameters.
275    /// * `handler` - The handler function for the route.
276    ///
277    /// # Examples
278    ///
279    /// ```
280    /// use suika_server::request::Request;
281    /// use suika_server::response::Response;
282    /// use suika_server::middleware::MiddlewareFuture;
283    /// use suika_server::router::Router;
284    /// use std::sync::Arc;
285    ///
286    /// let mut router = Router::new("/api");
287    ///
288    /// router.add_route(Some("GET"), "/test", |req, res| {
289    ///     Box::pin(async move {
290    ///         res.set_status(200).await;
291    ///         res.body("Test route".to_string()).await;
292    ///         Ok(())
293    ///     })
294    /// });
295    /// ```
296    pub fn add_route<F>(&mut self, method: Option<&str>, pattern: &str, handler: F)
297    where
298        F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
299            + Send
300            + Sync
301            + 'static,
302    {
303        let full_pattern = format!("{}{}", self.base_path.trim_end_matches('/'), pattern);
304        let rgx = Regex::new(&full_pattern).expect("Invalid regex pattern");
305
306        self.routes.push(Route {
307            method: method.map(|m| m.to_string()),
308            pattern: rgx,
309            handler: Arc::new(handler),
310        });
311    }
312
313    /// Mounts a sub-router onto this router.
314    ///
315    /// # Arguments
316    ///
317    /// * `sub_router` - The sub-router to mount.
318    ///
319    /// # Examples
320    ///
321    /// ```
322    /// use suika_server::router::Router;
323    ///
324    /// let mut router = Router::new("/api");
325    /// let sub_router = Router::new("/sub");
326    ///
327    /// router.mount(sub_router);
328    /// ```
329    pub fn mount(&mut self, mut sub_router: Router) {
330        let combined = format!(
331            "{}{}",
332            self.base_path.trim_end_matches('/'),
333            sub_router.base_path
334        );
335        sub_router.base_path = combined;
336        self.sub_routers.push(sub_router);
337    }
338
339    /// Handles an incoming HTTP request by matching it to a route.
340    ///
341    /// This method is called internally by the `Router`'s `Middleware` implementation.
342    ///
343    /// # Arguments
344    ///
345    /// * `req` - A mutable reference to the incoming request.
346    /// * `res` - A mutable reference to the response to be sent.
347    ///
348    /// # Returns
349    ///
350    /// A future that resolves to `Result<bool, HttpError>`, indicating whether a route was matched.
351    fn handle_internal<'a>(
352        &'a self,
353        req: &'a mut Request,
354        res: &'a mut Response,
355    ) -> Pin<Box<dyn futures::Future<Output = Result<bool, HttpError>> + Send + 'a>> {
356        Box::pin(async move {
357            for route in &self.routes {
358                if let Some(ref route_method) = route.method {
359                    if route_method.to_uppercase() != req.method().to_uppercase() {
360                        continue;
361                    }
362                }
363                if let Some(caps) = route.pattern.captures(req.path()) {
364                    let mut params = HashMap::new();
365                    for name in route.pattern.capture_names().flatten() {
366                        if let Some(value) = caps.name(name) {
367                            params.insert(name.to_string(), value.as_str().to_string());
368                        }
369                    }
370
371                    req.set_params(params);
372
373                    if let Err(e) = (route.handler)(req, res).await {
374                        res.error(e).await;
375                        return Ok(true);
376                    }
377                    return Ok(true);
378                }
379            }
380
381            for subr in &self.sub_routers {
382                if subr.handle_internal(req, res).await? {
383                    return Ok(true);
384                }
385            }
386            Ok(false)
387        })
388    }
389}
390
391impl Middleware for Router {
392    /// Handles an incoming HTTP request by routing it to the appropriate handler.
393    ///
394    /// # Arguments
395    ///
396    /// * `req` - A mutable reference to the incoming request.
397    /// * `res` - A mutable reference to the response.
398    /// * `next` - The next middleware in the stack.
399    ///
400    /// # Returns
401    ///
402    /// A future that resolves to a `Result<(), HttpError>`.
403    ///
404    /// # Examples
405    ///
406    /// ```
407    /// use suika_server::request::Request;
408    /// use suika_server::response::{Response, Body};
409    /// use suika_server::middleware::{Middleware, Next, MiddlewareFuture};
410    /// use suika_server::router::Router;
411    /// use regex::Regex;
412    /// use std::sync::{Arc, Mutex};
413    /// use tokio::sync::Mutex as TokioMutex;
414    /// use std::collections::HashMap;
415    ///
416    ///
417    /// #[derive(Clone)]
418    /// struct MockNextMiddleware {
419    ///     called: Arc<TokioMutex<bool>>,
420    /// }
421    ///
422    /// impl MockNextMiddleware {
423    ///     fn new() -> Self {
424    ///         Self {
425    ///             called: Arc::new(TokioMutex::new(false)),
426    ///         }
427    ///     }
428    /// }
429    ///
430    /// impl Middleware for MockNextMiddleware {
431    ///     fn handle<'a>(
432    ///         &'a self,
433    ///         _req: &'a mut Request,
434    ///         _res: &'a mut Response,
435    ///         _next: Next<'a>,
436    ///     ) -> MiddlewareFuture<'a> {
437    ///         let called = Arc::clone(&self.called);
438    ///         Box::pin(async move {
439    ///             let mut called_lock = called.lock().await;
440    ///             *called_lock = true;
441    ///             Ok(())
442    ///         })
443    ///     }
444    /// }
445    ///
446    /// #[tokio::main]
447    /// async fn main() {
448    ///     let mut router = Router::new("/api");
449    ///
450    ///     router.add_route(Some("GET"), "/test", |req, res| {
451    ///         Box::pin(async move {
452    ///             res.set_status(200).await;
453    ///             res.body("Test route".to_string()).await;
454    ///             Ok(())
455    ///         })
456    ///     });
457    ///
458    ///     let mut req = Request::new(
459    ///         "GET /api/test HTTP/1.1\r\n\r\n",
460    ///         Arc::new(Mutex::new(HashMap::new()))
461    ///     ).unwrap();
462    ///
463    ///     let mut res = Response::new(None);
464    ///
465    ///     let next_middleware = MockNextMiddleware::new();
466    ///     let middleware_stack: Vec<Arc<dyn Middleware + Send + Sync>> = vec![Arc::new(next_middleware.clone())];
467    ///     let next = Next::new(middleware_stack.as_slice());
468    ///
469    ///     router.handle(&mut req, &mut res, next.clone()).await.unwrap();
470    ///
471    ///     let inner = res.get_inner().await;
472    ///     assert_eq!(inner.status_code(), Some(200));
473    ///     assert_eq!(inner.body(), &Some(Body::Text("Test route".to_string())));
474    ///
475    ///     let next_called = *next_middleware.called.lock().await;
476    ///     assert!(!next_called);
477    /// }
478    /// ```
479    fn handle<'a>(
480        &'a self,
481        req: &'a mut Request,
482        res: &'a mut Response,
483        mut next: Next<'a>,
484    ) -> MiddlewareFuture<'a> {
485        Box::pin(async move {
486            let matched_route = self.handle_internal(req, res).await;
487            if let Err(e) = matched_route {
488                res.error(e).await;
489            } else if !matched_route.unwrap_or(false) {
490                next.run(req, res).await?;
491            }
492            Ok(())
493        })
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::{Middleware, Next};
    use crate::request::Request;
    use crate::response::{Body, Response};
    use std::sync::{Arc, Mutex};
    use tokio::sync::Mutex as TokioMutex;

    /// Stand-in for the rest of the middleware stack; records whether the
    /// router fell through to it.
    #[derive(Clone)]
    struct MockNextMiddleware {
        called: Arc<TokioMutex<bool>>,
    }

    impl MockNextMiddleware {
        fn new() -> Self {
            Self {
                called: Arc::new(TokioMutex::new(false)),
            }
        }
    }

    impl Middleware for MockNextMiddleware {
        fn handle<'a>(
            &'a self,
            _req: &'a mut Request,
            _res: &'a mut Response,
            _next: Next<'a>,
        ) -> MiddlewareFuture<'a> {
            let flag = Arc::clone(&self.called);
            Box::pin(async move {
                *flag.lock().await = true;
                Ok(())
            })
        }
    }

    /// Sends `raw_request` through `router` with a mock next middleware, then
    /// asserts that the matched route answered 200 with `expected_body` and
    /// that the request did not fall through to the next middleware.
    async fn assert_route_responds(router: &Router, raw_request: &str, expected_body: &str) {
        let mut req =
            Request::new(raw_request, Arc::new(Mutex::new(HashMap::new()))).unwrap();
        let mut res = Response::new(None);

        let next_middleware = MockNextMiddleware::new();
        let middleware_stack: Vec<Arc<dyn Middleware + Send + Sync>> =
            vec![Arc::new(next_middleware.clone())];
        let next = Next::new(middleware_stack.as_slice());

        router
            .handle(&mut req, &mut res, next.clone())
            .await
            .unwrap();

        let inner = res.get_inner().await;
        assert_eq!(inner.status_code(), Some(200));
        assert_eq!(inner.body(), &Some(Body::Text(expected_body.to_string())));

        let next_called = *next_middleware.called.lock().await;
        assert!(!next_called);
    }

    #[tokio::test]
    async fn test_router_handles_get_route() {
        let mut router = Router::new("/api");
        router.get("/test", |_req, res| {
            Box::pin(async move {
                res.set_status(200).await;
                res.body("GET route".to_string()).await;
                Ok(())
            })
        });

        assert_route_responds(&router, "GET /api/test HTTP/1.1\r\n\r\n", "GET route").await;
    }

    #[tokio::test]
    async fn test_router_handles_put_route() {
        let mut router = Router::new("/api");
        router.put("/test", |_req, res| {
            Box::pin(async move {
                res.set_status(200).await;
                res.body("PUT route".to_string()).await;
                Ok(())
            })
        });

        assert_route_responds(&router, "PUT /api/test HTTP/1.1\r\n\r\n", "PUT route").await;
    }

    #[tokio::test]
    async fn test_router_handles_post_route() {
        let mut router = Router::new("/api");
        router.post("/test", |_req, res| {
            Box::pin(async move {
                res.set_status(200).await;
                res.body("POST route".to_string()).await;
                Ok(())
            })
        });

        assert_route_responds(&router, "POST /api/test HTTP/1.1\r\n\r\n", "POST route").await;
    }

    #[tokio::test]
    async fn test_router_handles_delete_route() {
        let mut router = Router::new("/api");
        router.delete("/test", |_req, res| {
            Box::pin(async move {
                res.set_status(200).await;
                res.body("DELETE route".to_string()).await;
                Ok(())
            })
        });

        assert_route_responds(
            &router,
            "DELETE /api/test HTTP/1.1\r\n\r\n",
            "DELETE route",
        )
        .await;
    }
}