// suika_server/router.rs
1use crate::error::HttpError;
2use crate::middleware::{Middleware, MiddlewareFuture, Next};
3use crate::request::Request;
4use crate::response::Response;
5use regex::Regex;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::Arc;
9
10/// Represents a route in the router.
11pub struct Route {
12 pub method: Option<String>,
13 pub pattern: Regex,
14 pub handler: Arc<
15 dyn for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a> + Send + Sync,
16 >,
17}
18
19/// A router for handling HTTP requests and routing them to appropriate handlers.
20///
21/// The `Router` can handle routes with or without parameters, and it supports mounting sub-routers.
22///
23/// # Examples
24///
25/// ```
26/// use suika_server::request::Request;
27/// use suika_server::response::{Response, Body};
28/// use suika_server::middleware::{Middleware, Next, MiddlewareFuture};
29/// use suika_server::router::Router;
30/// use regex::Regex;
31/// use std::collections::HashMap;
32/// use std::sync::{Arc, Mutex};
33/// use tokio::sync::Mutex as TokioMutex;
34///
35/// #[derive(Clone)]
36/// struct MockNextMiddleware {
37/// called: Arc<TokioMutex<bool>>,
38/// }
39///
40/// impl MockNextMiddleware {
41/// fn new() -> Self {
42/// Self {
43/// called: Arc::new(TokioMutex::new(false)),
44/// }
45/// }
46/// }
47///
48/// impl Middleware for MockNextMiddleware {
49/// fn handle<'a>(
50/// &'a self,
51/// _req: &'a mut Request,
52/// _res: &'a mut Response,
53/// _next: Next<'a>,
54/// ) -> MiddlewareFuture<'a> {
55/// let called = Arc::clone(&self.called);
56/// Box::pin(async move {
57/// let mut called_lock = called.lock().await;
58/// *called_lock = true;
59/// Ok(())
60/// })
61/// }
62/// }
63///
64/// #[tokio::main]
65/// async fn main() {
66/// let mut router = Router::new("/api");
67///
68/// router.add_route(Some("GET"), "/test", |req, res| {
69/// Box::pin(async move {
70/// res.set_status(200).await;
71/// res.body("Test route".to_string()).await;
72/// Ok(())
73/// })
74/// });
75///
76/// let mut req = Request::new(
77/// "GET /api/test HTTP/1.1\r\n\r\n",
78/// Arc::new(Mutex::new(HashMap::new())),
79/// ).unwrap();
80///
81/// let mut res = Response::new(None);
82///
83/// let next_middleware = MockNextMiddleware::new();
84/// let middleware_stack: Vec<Arc<dyn Middleware + Send + Sync>> = vec![Arc::new(next_middleware.clone())];
85/// let next = Next::new(middleware_stack.as_slice());
86///
87/// router.handle(&mut req, &mut res, next.clone()).await.unwrap();
88///
89/// let inner = res.get_inner().await;
90/// assert_eq!(inner.status_code(), Some(200));
91/// assert_eq!(inner.body(), &Some(Body::Text("Test route".to_string())));
92///
93/// let next_called = *next_middleware.called.lock().await;
94/// assert!(!next_called);
95/// }
96/// ```
97pub struct Router {
98 pub base_path: String,
99 pub routes: Vec<Route>,
100 pub sub_routers: Vec<Router>,
101}
102
103impl Router {
104 /// Creates a new `Router` with the specified base path.
105 ///
106 /// # Arguments
107 ///
108 /// * `base_path` - The base path for the router.
109 ///
110 /// # Examples
111 ///
112 /// ```
113 /// use suika_server::router::Router;
114 ///
115 /// let router = Router::new("/api");
116 /// ```
117 pub fn new(base_path: &str) -> Self {
118 Self {
119 base_path: base_path.to_string(),
120 routes: Vec::new(),
121 sub_routers: Vec::new(),
122 }
123 }
124
125 /// Adds a GET route to the router
126 ///
127 /// # Arguments
128 ///
129 /// * `pattern` - The URL pattern for the route, which can include named parameters.
130 /// * `handler` - The handler function for the route.
131 ///
132 /// # Examples
133 ///
134 /// ```
135 /// use suika_server::request::Request;
136 /// use suika_server::response::Response;
137 /// use suika_server::middleware::MiddlewareFuture;
138 /// use suika_server::router::Router;
139 /// use std::sync::Arc;
140 ///
141 /// let mut router = Router::new("/api");
142 ///
143 /// router.get("/test", |req, res| {
144 /// Box::pin(async move {
145 /// res.set_status(200).await;
146 /// res.body("Test route".to_string()).await;
147 /// Ok(())
148 /// })
149 /// });
150 /// ```
151 pub fn get<F>(&mut self, pattern: &str, handler: F)
152 where
153 F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
154 + Send
155 + Sync
156 + 'static,
157 {
158 self.add_route(Some("GET"), pattern, handler);
159 }
160
161 /// Adds a PUT route to the router
162 ///
163 /// # Arguments
164 ///
165 /// * `pattern` - The URL pattern for the route, which can include named parameters.
166 /// * `handler` - The handler function for the route.
167 ///
168 /// # Examples
169 ///
170 /// ```
171 /// use suika_server::request::Request;
172 /// use suika_server::response::Response;
173 /// use suika_server::middleware::MiddlewareFuture;
174 /// use suika_server::router::Router;
175 /// use std::sync::Arc;
176 ///
177 /// let mut router = Router::new("/api");
178 ///
179 /// router.put("/test", |req, res| {
180 /// Box::pin(async move {
181 /// res.set_status(200).await;
182 /// res.body("Test route".to_string()).await;
183 /// Ok(())
184 /// })
185 /// });
186 /// ```
187 pub fn put<F>(&mut self, pattern: &str, handler: F)
188 where
189 F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
190 + Send
191 + Sync
192 + 'static,
193 {
194 self.add_route(Some("PUT"), pattern, handler);
195 }
196
197 /// Adds a POST route to the router
198 ///
199 /// # Arguments
200 ///
201 /// * `pattern` - The URL pattern for the route, which can include named parameters.
202 /// * `handler` - The handler function for the route.
203 ///
204 /// # Examples
205 ///
206 /// ```
207 /// use suika_server::request::Request;
208 /// use suika_server::response::Response;
209 /// use suika_server::middleware::MiddlewareFuture;
210 /// use suika_server::router::Router;
211 /// use std::sync::Arc;
212 ///
213 /// let mut router = Router::new("/api");
214 ///
215 /// router.post("/test", |req, res| {
216 /// Box::pin(async move {
217 /// res.set_status(200).await;
218 /// res.body("Test route".to_string()).await;
219 /// Ok(())
220 /// })
221 /// });
222 /// ```
223 pub fn post<F>(&mut self, pattern: &str, handler: F)
224 where
225 F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
226 + Send
227 + Sync
228 + 'static,
229 {
230 self.add_route(Some("POST"), pattern, handler);
231 }
232
233 /// Adds a DELETE route to the router
234 ///
235 /// # Arguments
236 ///
237 /// * `pattern` - The URL pattern for the route, which can include named parameters.
238 /// * `handler` - The handler function for the route.
239 ///
240 /// # Examples
241 ///
242 /// ```
243 /// use suika_server::request::Request;
244 /// use suika_server::response::Response;
245 /// use suika_server::middleware::MiddlewareFuture;
246 /// use suika_server::router::Router;
247 /// use std::sync::Arc;
248 ///
249 /// let mut router = Router::new("/api");
250 ///
251 /// router.delete("/test", |req, res| {
252 /// Box::pin(async move {
253 /// res.set_status(200).await;
254 /// res.body("Test route".to_string()).await;
255 /// Ok(())
256 /// })
257 /// });
258 /// ```
259 pub fn delete<F>(&mut self, pattern: &str, handler: F)
260 where
261 F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
262 + Send
263 + Sync
264 + 'static,
265 {
266 self.add_route(Some("DELETE"), pattern, handler);
267 }
268
269 /// Adds a route to the router.
270 ///
271 /// # Arguments
272 ///
273 /// * `method` - The HTTP method for the route (e.g., "GET", "POST").
274 /// * `pattern` - The URL pattern for the route, which can include named parameters.
275 /// * `handler` - The handler function for the route.
276 ///
277 /// # Examples
278 ///
279 /// ```
280 /// use suika_server::request::Request;
281 /// use suika_server::response::Response;
282 /// use suika_server::middleware::MiddlewareFuture;
283 /// use suika_server::router::Router;
284 /// use std::sync::Arc;
285 ///
286 /// let mut router = Router::new("/api");
287 ///
288 /// router.add_route(Some("GET"), "/test", |req, res| {
289 /// Box::pin(async move {
290 /// res.set_status(200).await;
291 /// res.body("Test route".to_string()).await;
292 /// Ok(())
293 /// })
294 /// });
295 /// ```
296 pub fn add_route<F>(&mut self, method: Option<&str>, pattern: &str, handler: F)
297 where
298 F: for<'a> Fn(&'a mut Request, &'a mut Response) -> MiddlewareFuture<'a>
299 + Send
300 + Sync
301 + 'static,
302 {
303 let full_pattern = format!("{}{}", self.base_path.trim_end_matches('/'), pattern);
304 let rgx = Regex::new(&full_pattern).expect("Invalid regex pattern");
305
306 self.routes.push(Route {
307 method: method.map(|m| m.to_string()),
308 pattern: rgx,
309 handler: Arc::new(handler),
310 });
311 }
312
313 /// Mounts a sub-router onto this router.
314 ///
315 /// # Arguments
316 ///
317 /// * `sub_router` - The sub-router to mount.
318 ///
319 /// # Examples
320 ///
321 /// ```
322 /// use suika_server::router::Router;
323 ///
324 /// let mut router = Router::new("/api");
325 /// let sub_router = Router::new("/sub");
326 ///
327 /// router.mount(sub_router);
328 /// ```
329 pub fn mount(&mut self, mut sub_router: Router) {
330 let combined = format!(
331 "{}{}",
332 self.base_path.trim_end_matches('/'),
333 sub_router.base_path
334 );
335 sub_router.base_path = combined;
336 self.sub_routers.push(sub_router);
337 }
338
339 /// Handles an incoming HTTP request by matching it to a route.
340 ///
341 /// This method is called internally by the `Router`'s `Middleware` implementation.
342 ///
343 /// # Arguments
344 ///
345 /// * `req` - A mutable reference to the incoming request.
346 /// * `res` - A mutable reference to the response to be sent.
347 ///
348 /// # Returns
349 ///
350 /// A future that resolves to `Result<bool, HttpError>`, indicating whether a route was matched.
351 fn handle_internal<'a>(
352 &'a self,
353 req: &'a mut Request,
354 res: &'a mut Response,
355 ) -> Pin<Box<dyn futures::Future<Output = Result<bool, HttpError>> + Send + 'a>> {
356 Box::pin(async move {
357 for route in &self.routes {
358 if let Some(ref route_method) = route.method {
359 if route_method.to_uppercase() != req.method().to_uppercase() {
360 continue;
361 }
362 }
363 if let Some(caps) = route.pattern.captures(req.path()) {
364 let mut params = HashMap::new();
365 for name in route.pattern.capture_names().flatten() {
366 if let Some(value) = caps.name(name) {
367 params.insert(name.to_string(), value.as_str().to_string());
368 }
369 }
370
371 req.set_params(params);
372
373 if let Err(e) = (route.handler)(req, res).await {
374 res.error(e).await;
375 return Ok(true);
376 }
377 return Ok(true);
378 }
379 }
380
381 for subr in &self.sub_routers {
382 if subr.handle_internal(req, res).await? {
383 return Ok(true);
384 }
385 }
386 Ok(false)
387 })
388 }
389}
390
391impl Middleware for Router {
392 /// Handles an incoming HTTP request by routing it to the appropriate handler.
393 ///
394 /// # Arguments
395 ///
396 /// * `req` - A mutable reference to the incoming request.
397 /// * `res` - A mutable reference to the response.
398 /// * `next` - The next middleware in the stack.
399 ///
400 /// # Returns
401 ///
402 /// A future that resolves to a `Result<(), HttpError>`.
403 ///
404 /// # Examples
405 ///
406 /// ```
407 /// use suika_server::request::Request;
408 /// use suika_server::response::{Response, Body};
409 /// use suika_server::middleware::{Middleware, Next, MiddlewareFuture};
410 /// use suika_server::router::Router;
411 /// use regex::Regex;
412 /// use std::sync::{Arc, Mutex};
413 /// use tokio::sync::Mutex as TokioMutex;
414 /// use std::collections::HashMap;
415 ///
416 ///
417 /// #[derive(Clone)]
418 /// struct MockNextMiddleware {
419 /// called: Arc<TokioMutex<bool>>,
420 /// }
421 ///
422 /// impl MockNextMiddleware {
423 /// fn new() -> Self {
424 /// Self {
425 /// called: Arc::new(TokioMutex::new(false)),
426 /// }
427 /// }
428 /// }
429 ///
430 /// impl Middleware for MockNextMiddleware {
431 /// fn handle<'a>(
432 /// &'a self,
433 /// _req: &'a mut Request,
434 /// _res: &'a mut Response,
435 /// _next: Next<'a>,
436 /// ) -> MiddlewareFuture<'a> {
437 /// let called = Arc::clone(&self.called);
438 /// Box::pin(async move {
439 /// let mut called_lock = called.lock().await;
440 /// *called_lock = true;
441 /// Ok(())
442 /// })
443 /// }
444 /// }
445 ///
446 /// #[tokio::main]
447 /// async fn main() {
448 /// let mut router = Router::new("/api");
449 ///
450 /// router.add_route(Some("GET"), "/test", |req, res| {
451 /// Box::pin(async move {
452 /// res.set_status(200).await;
453 /// res.body("Test route".to_string()).await;
454 /// Ok(())
455 /// })
456 /// });
457 ///
458 /// let mut req = Request::new(
459 /// "GET /api/test HTTP/1.1\r\n\r\n",
460 /// Arc::new(Mutex::new(HashMap::new()))
461 /// ).unwrap();
462 ///
463 /// let mut res = Response::new(None);
464 ///
465 /// let next_middleware = MockNextMiddleware::new();
466 /// let middleware_stack: Vec<Arc<dyn Middleware + Send + Sync>> = vec![Arc::new(next_middleware.clone())];
467 /// let next = Next::new(middleware_stack.as_slice());
468 ///
469 /// router.handle(&mut req, &mut res, next.clone()).await.unwrap();
470 ///
471 /// let inner = res.get_inner().await;
472 /// assert_eq!(inner.status_code(), Some(200));
473 /// assert_eq!(inner.body(), &Some(Body::Text("Test route".to_string())));
474 ///
475 /// let next_called = *next_middleware.called.lock().await;
476 /// assert!(!next_called);
477 /// }
478 /// ```
479 fn handle<'a>(
480 &'a self,
481 req: &'a mut Request,
482 res: &'a mut Response,
483 mut next: Next<'a>,
484 ) -> MiddlewareFuture<'a> {
485 Box::pin(async move {
486 let matched_route = self.handle_internal(req, res).await;
487 if let Err(e) = matched_route {
488 res.error(e).await;
489 } else if !matched_route.unwrap_or(false) {
490 next.run(req, res).await?;
491 }
492 Ok(())
493 })
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::{Middleware, Next};
    use crate::request::Request;
    use crate::response::{Body, Response};
    use std::sync::{Arc, Mutex};
    use tokio::sync::Mutex as TokioMutex;

    /// Stand-in for the downstream middleware; it only records that it was invoked.
    #[derive(Clone)]
    struct MockNextMiddleware {
        called: Arc<TokioMutex<bool>>,
    }

    impl MockNextMiddleware {
        fn new() -> Self {
            let called = Arc::new(TokioMutex::new(false));
            Self { called }
        }
    }

    impl Middleware for MockNextMiddleware {
        fn handle<'a>(
            &'a self,
            _req: &'a mut Request,
            _res: &'a mut Response,
            _next: Next<'a>,
        ) -> MiddlewareFuture<'a> {
            let flag = Arc::clone(&self.called);
            Box::pin(async move {
                *flag.lock().await = true;
                Ok(())
            })
        }
    }

    /// Sends `raw_request` through `router`, with a `MockNextMiddleware` queued behind it,
    /// and checks that the matched route replied `200` with `expected_body` without
    /// falling through to the next middleware.
    async fn assert_route_replies(router: &Router, raw_request: &str, expected_body: &str) {
        let mut req = Request::new(raw_request, Arc::new(Mutex::new(HashMap::new()))).unwrap();
        let mut res = Response::new(None);

        let downstream = MockNextMiddleware::new();
        let stack: Vec<Arc<dyn Middleware + Send + Sync>> = vec![Arc::new(downstream.clone())];
        let next = Next::new(stack.as_slice());

        router.handle(&mut req, &mut res, next).await.unwrap();

        let inner = res.get_inner().await;
        assert_eq!(inner.status_code(), Some(200));
        assert_eq!(inner.body(), &Some(Body::Text(expected_body.to_string())));

        let next_called = *downstream.called.lock().await;
        assert!(!next_called);
    }

    #[tokio::test]
    async fn test_router_handles_get_route() {
        let mut router = Router::new("/api");
        router.get("/test", |_req, res| {
            Box::pin(async move {
                res.set_status(200).await;
                res.body("GET route".to_string()).await;
                Ok(())
            })
        });

        assert_route_replies(&router, "GET /api/test HTTP/1.1\r\n\r\n", "GET route").await;
    }

    #[tokio::test]
    async fn test_router_handles_put_route() {
        let mut router = Router::new("/api");
        router.put("/test", |_req, res| {
            Box::pin(async move {
                res.set_status(200).await;
                res.body("PUT route".to_string()).await;
                Ok(())
            })
        });

        assert_route_replies(&router, "PUT /api/test HTTP/1.1\r\n\r\n", "PUT route").await;
    }

    #[tokio::test]
    async fn test_router_handles_post_route() {
        let mut router = Router::new("/api");
        router.post("/test", |_req, res| {
            Box::pin(async move {
                res.set_status(200).await;
                res.body("POST route".to_string()).await;
                Ok(())
            })
        });

        assert_route_replies(&router, "POST /api/test HTTP/1.1\r\n\r\n", "POST route").await;
    }

    #[tokio::test]
    async fn test_router_handles_delete_route() {
        let mut router = Router::new("/api");
        router.delete("/test", |_req, res| {
            Box::pin(async move {
                res.set_status(200).await;
                res.body("DELETE route".to_string()).await;
                Ok(())
            })
        });

        assert_route_replies(
            &router,
            "DELETE /api/test HTTP/1.1\r\n\r\n",
            "DELETE route",
        )
        .await;
    }
}