viz_router/
router.rs

1use viz_core::{
2    BoxHandler, Handler, HandlerExt, IntoResponse, Next, Request, Response, Result, Transform,
3};
4
5use crate::{Resources, Route};
6
// Generates one `Router` builder method per HTTP verb.
//
// Invoked as `export_verb!(get GET)`: `$method` becomes both the name of the
// generated `Router` method and the `Route` method it forwards to, while
// `$verb` is only used to render the doc string.
macro_rules! export_verb {
    ($method:ident $verb:ty) => {
        #[doc = concat!(" Adds a handler with a path and HTTP `", stringify!($verb), "` verb pair.")]
        #[must_use]
        pub fn $method<P, F, T>(self, path: P, handler: F) -> Self
        where
            P: AsRef<str>,
            F: Handler<Request, Output = Result<T>> + Clone,
            T: IntoResponse + Send + 'static,
        {
            // Build a single-verb route and hand it to `route`, which merges it
            // with any route already registered under the same path.
            let verb_route = Route::new().$method(handler);
            self.route(path, verb_route)
        }
    };
}
21
22/// A routes collection.
23#[derive(Clone, Debug, Default)]
24pub struct Router {
25    pub(crate) routes: Option<Vec<(String, Route)>>,
26}
27
28impl Router {
29    /// Creates an empty `Router`.
30    #[must_use]
31    pub const fn new() -> Self {
32        Self { routes: None }
33    }
34
35    fn push<S>(routes: &mut Vec<(String, Route)>, path: S, route: Route)
36    where
37        S: AsRef<str>,
38    {
39        let path = path.as_ref();
40        match routes
41            .iter_mut()
42            .find_map(|(p, r)| if p == path { Some(r) } else { None })
43        {
44            Some(r) => {
45                *r = route.into_iter().fold(
46                    // original route
47                    r.clone().into_iter().collect(),
48                    |or: Route, (method, handler)| or.on(method, handler),
49                );
50            }
51            None => routes.push((path.to_string(), route)),
52        }
53    }
54
55    /// Inserts a path-route pair into the router.
56    #[must_use]
57    pub fn route<S>(mut self, path: S, route: Route) -> Self
58    where
59        S: AsRef<str>,
60    {
61        Self::push(
62            self.routes.get_or_insert_with(Vec::new),
63            path.as_ref().trim_start_matches('/'),
64            route,
65        );
66        self
67    }
68
69    /// Nested resources with a path.
70    #[must_use]
71    pub fn resources<S>(self, path: S, resource: Resources) -> Self
72    where
73        S: AsRef<str>,
74    {
75        let mut path = path.as_ref().to_string();
76        if !path.ends_with('/') {
77            path.push('/');
78        }
79
80        resource.into_iter().fold(self, |router, (mut sp, route)| {
81            let is_empty = sp.is_empty();
82            sp = path.clone() + &sp;
83            if is_empty {
84                sp = sp.trim_end_matches('/').to_string();
85            }
86            router.route(sp, route)
87        })
88    }
89
90    /// Nested sub-router with a path.
91    #[allow(clippy::similar_names)]
92    #[must_use]
93    pub fn nest<S>(self, path: S, router: Self) -> Self
94    where
95        S: AsRef<str>,
96    {
97        let mut path = path.as_ref().to_string();
98        if !path.ends_with('/') {
99            path.push('/');
100        }
101
102        match router.routes {
103            Some(routes) => routes.into_iter().fold(self, |router, (mut sp, route)| {
104                let is_empty = sp.is_empty();
105                sp = path.clone() + &sp;
106                if is_empty {
107                    sp = sp.trim_end_matches('/').to_string();
108                }
109                router.route(sp, route)
110            }),
111            None => self,
112        }
113    }
114
115    repeat!(
116        export_verb
117        get GET
118        post POST
119        put PUT
120        delete DELETE
121        head HEAD
122        options OPTIONS
123        connect CONNECT
124        patch PATCH
125        trace TRACE
126    );
127
128    /// Adds a handler with a path and any HTTP verbs."
129    #[must_use]
130    pub fn any<S, H, O>(self, path: S, handler: H) -> Self
131    where
132        S: AsRef<str>,
133        H: Handler<Request, Output = Result<O>> + Clone,
134        O: IntoResponse + Send + 'static,
135    {
136        self.route(path, Route::new().any(handler))
137    }
138
139    /// Takes a closure and creates an iterator which calls that closure on each handler.
140    #[must_use]
141    pub fn map_handler<F>(self, f: F) -> Self
142    where
143        F: Fn(BoxHandler<Request, Result<Response>>) -> BoxHandler<Request, Result<Response>>,
144    {
145        Self {
146            routes: self.routes.map(|routes| {
147                routes
148                    .into_iter()
149                    .map(|(path, route)| {
150                        (
151                            path,
152                            route
153                                .into_iter()
154                                .map(|(method, handler)| (method, f(handler)))
155                                .collect(),
156                        )
157                    })
158                    .collect()
159            }),
160        }
161    }
162
163    /// Transforms the types to a middleware and adds it.
164    #[must_use]
165    pub fn with<T>(self, t: T) -> Self
166    where
167        T: Transform<BoxHandler>,
168        T::Output: Handler<Request, Output = Result<Response>> + Clone,
169    {
170        self.map_handler(|handler| t.transform(handler).boxed())
171    }
172
173    /// Adds a middleware for the routes.
174    #[must_use]
175    pub fn with_handler<H>(self, f: H) -> Self
176    where
177        H: Handler<Next<Request, BoxHandler>, Output = Result<Response>> + Clone,
178    {
179        self.map_handler(|handler| handler.around(f.clone()).boxed())
180    }
181}
182
#[cfg(test)]
// Most handlers below never `.await`; they are `async fn`s only so they match
// the handler shape the router accepts.
#[allow(clippy::unused_async)]
mod tests {
    use http_body_util::{BodyExt, Full};
    use std::sync::Arc;
    use viz_core::{
        Body, Error, Handler, HandlerExt, IntoResponse, Method, Next, Request, RequestExt,
        Response, ResponseExt, Result, StatusCode, Transform, async_trait,
        types::{Params, RouteInfo},
    };

    use crate::{Resources, Route, Router, Tree, any, get};

    /// Test middleware used to exercise `Router::with`.
    ///
    /// Despite the name it logs nothing. `transform` only wraps the handler in
    /// `LoggerHandler`, which forwards every call unchanged.
    #[derive(Clone)]
    struct Logger;

    impl Logger {
        const fn new() -> Self {
            Self
        }
    }

    impl<H: Clone> Transform<H> for Logger {
        type Output = LoggerHandler<H>;

        fn transform(&self, h: H) -> Self::Output {
            LoggerHandler(h)
        }
    }

    /// Handler produced by `Logger`; delegates straight to the wrapped handler.
    #[derive(Clone)]
    struct LoggerHandler<H>(H);

    #[async_trait]
    impl<H> Handler<Request> for LoggerHandler<H>
    where
        H: Handler<Request>,
    {
        type Output = H::Output;

        async fn call(&self, req: Request) -> Self::Output {
            self.0.call(req).await
        }
    }

    /// End-to-end routing check.
    ///
    /// Builds a router from plain routes, resources, a nested sub-router and a
    /// middleware, then converts it into a `Tree`. For each method/path pair it
    /// calls the handler the tree resolves and compares the response body.
    #[tokio::test]
    async fn router() -> anyhow::Result<()> {
        async fn index(_: Request) -> Result<Response> {
            Ok(Response::text("index"))
        }

        async fn all(_: Request) -> Result<Response> {
            Ok(Response::text("any"))
        }

        async fn not_found(_: Request) -> Result<impl IntoResponse> {
            Ok(StatusCode::NOT_FOUND)
        }

        async fn search(_: Request) -> Result<Response> {
            Ok(Response::text("search"))
        }

        // show/update/delete echo their path params (space-separated), so the
        // assertions can check what was extracted from the matched route.
        async fn show(req: Request) -> Result<Response> {
            let ids: Vec<String> = req.params()?;
            let items = ids.into_iter().fold(String::new(), |mut s, id| {
                s.push(' ');
                s.push_str(&id);
                s
            });
            Ok(Response::text("show".to_string() + &items))
        }

        async fn create(_: Request) -> Result<Response> {
            Ok(Response::text("create"))
        }

        async fn update(req: Request) -> Result<Response> {
            let ids: Vec<String> = req.params()?;
            let items = ids.into_iter().fold(String::new(), |mut s, id| {
                s.push(' ');
                s.push_str(&id);
                s
            });
            Ok(Response::text("update".to_string() + &items))
        }

        async fn delete(req: Request) -> Result<Response> {
            let ids: Vec<String> = req.params()?;
            let items = ids.into_iter().fold(String::new(), |mut s, id| {
                s.push(' ');
                s.push_str(&id);
                s
            });
            Ok(Response::text("delete".to_string() + &items))
        }

        // Pass-through `around` middleware: forwards the request to the next handler.
        async fn middle<H>((req, h): Next<Request, H>) -> Result<Response>
        where
            H: Handler<Request, Output = Result<Response>>,
        {
            h.call(req).await
        }

        // `users` resource. The `before`/`around` hooks on `create` are identity
        // hooks that exercise handler composition. `map_handler` prefixes every
        // response body with "users: " so the assertions can tell which
        // resource answered.
        let users = Resources::default()
            .named("user")
            .index(index)
            .create(create.before(|r: Request| async { Ok(r) }).around(middle))
            .show(show)
            .update(update)
            .destroy(delete)
            .map_handler(|h| {
                h.and_then(|res: Response| async {
                    let (parts, body) = res.into_parts();

                    let mut buf = bytes::BytesMut::new();
                    buf.extend(b"users: ");
                    buf.extend(body.collect().await.map_err(Error::boxed)?.to_bytes());

                    Ok(Response::from_parts(parts, Full::from(buf.freeze()).into()))
                })
                .boxed()
            });

        // Sub-router later nested at "posts". It holds a plain "search" route
        // plus the post resource mounted at its root (empty path). The post
        // resource has no `index` handler, and its bodies are prefixed with
        // "posts: ".
        let posts = Router::new().route("search", get(search)).resources(
            "",
            Resources::default()
                .named("post")
                .create(create)
                .show(show)
                .update(update)
                .destroy(delete)
                .map_handler(|h| {
                    h.and_then(|res: Response| async {
                        let (parts, body) = res.into_parts();

                        let mut buf = bytes::BytesMut::new();
                        buf.extend(b"posts: ");
                        buf.extend(body.collect().await.map_err(Error::boxed)?.to_bytes());

                        Ok(Response::from_parts(parts, Full::from(buf.freeze()).into()))
                    })
                    .boxed()
                }),
        );

        // `users` is mounted twice: at "users" and, inside `posts`, at
        // ":post_id/users". `with(Logger)` comes last, so it wraps every route above.
        let router = Router::new()
            // Equivalent to `.route("", get(index))`.
            .get("", index)
            .resources("users", users.clone())
            .nest("posts", posts.resources(":post_id/users", users))
            .route("search", any(all))
            .route("*", Route::new().any(not_found))
            .with(Logger::new());

        let tree: Tree = router.into();

        // GET /posts
        // No GET handler is registered for "posts" (the post resource has no
        // index). This presumably falls through to the "*" catch-all, whose
        // 404 response has an empty body.
        let (req, method, path) = client(Method::GET, "/posts");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, _) = node.unwrap();
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            ""
        );

        // POST /posts
        let (req, method, path) = client(Method::POST, "/posts");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, _) = node.unwrap();
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "posts: create"
        );

        // GET /posts/foo
        // Handlers that call `req.params()` need a `RouteInfo` extension built
        // from the matched route. It is inserted by hand here (presumably the
        // server does this outside of tests).
        let (mut req, method, path) = client(Method::GET, "/posts/foo");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, route) = node.unwrap();
        req.extensions_mut().insert(Arc::from(RouteInfo {
            id: *route.id,
            pattern: route.pattern(),
            params: route.params().into(),
        }));
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "posts: show foo"
        );

        // PUT /posts/foo
        let (mut req, method, path) = client(Method::PUT, "/posts/foo");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, route) = node.unwrap();
        req.extensions_mut().insert(Arc::from(RouteInfo {
            id: *route.id,
            pattern: route.pattern(),
            params: Into::<Params>::into(route.params()),
        }));
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "posts: update foo"
        );

        // DELETE /posts/foo
        let (mut req, method, path) = client(Method::DELETE, "/posts/foo");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, route) = node.unwrap();
        req.extensions_mut().insert(Arc::from(RouteInfo {
            id: *route.id,
            pattern: route.pattern(),
            params: route.params().into(),
        }));
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "posts: delete foo"
        );

        // GET /posts/foo/users
        let (req, method, path) = client(Method::GET, "/posts/foo/users");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, _) = node.unwrap();
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "users: index"
        );

        // POST /posts/foo/users
        let (req, method, path) = client(Method::POST, "/posts/foo/users");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, _) = node.unwrap();
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "users: create"
        );

        // GET /posts/foo/users/bar
        let (mut req, method, path) = client(Method::GET, "/posts/foo/users/bar");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, route) = node.unwrap();
        req.extensions_mut().insert(Arc::from(RouteInfo {
            id: *route.id,
            pattern: route.pattern(),
            params: route.params().into(),
        }));
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "users: show foo bar"
        );

        // PUT /posts/foo/users/bar
        // Also checks the full pattern produced by nesting resources in a sub-router.
        let (mut req, method, path) = client(Method::PUT, "/posts/foo/users/bar");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, route) = node.unwrap();
        let route_info = Arc::from(RouteInfo {
            id: *route.id,
            pattern: route.pattern(),
            params: route.params().into(),
        });
        assert_eq!(route.pattern(), "/posts/:post_id/users/:user_id");
        assert_eq!(route_info.pattern, "/posts/:post_id/users/:user_id");
        req.extensions_mut().insert(route_info);
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "users: update foo bar"
        );

        // DELETE /posts/foo/users/bar
        let (mut req, method, path) = client(Method::DELETE, "/posts/foo/users/bar");
        let node = tree.find(&method, &path);
        assert!(node.is_some());
        let (h, route) = node.unwrap();
        req.extensions_mut().insert(Arc::from(RouteInfo {
            id: *route.id,
            pattern: route.pattern(),
            params: route.params().into(),
        }));
        assert_eq!(
            h.call(req).await?.into_body().collect().await?.to_bytes(),
            "users: delete foo bar"
        );

        Ok(())
    }

    /// Checks the `Debug` rendering of a `Tree` built from plain routes,
    /// resources and nested routers.
    #[test]
    fn debug() {
        let search = Route::new().get(|_: Request| async { Ok(Response::text("search")) });

        let orgs = Resources::default()
            .index(|_: Request| async { Ok(Response::text("list posts")) })
            .create(|_: Request| async { Ok(Response::text("create post")) })
            .show(|_: Request| async { Ok(Response::text("show post")) });

        // The leading `/` is stripped by `Router::route`, so "/" here ends up
        // mounted at "settings" itself once nested.
        let settings = Router::new()
            .get("/", |_: Request| async { Ok(Response::text("settings")) })
            .get("/:page", |_: Request| async {
                Ok(Response::text("setting page"))
            });

        let app = Router::new()
            .get("/", |_: Request| async { Ok(Response::text("index")) })
            .route("search", search.clone())
            .resources(":org", orgs)
            .nest("settings", settings)
            .nest("api", Router::new().route("/search", search));

        let tree: Tree = app.into();

        // Each `•N` is a route id. The ids appear to be numbered per method in
        // registration order (GET 0..=6, POST restarting at 0), so reordering
        // the builder calls above changes this output. Note that each
        // `paths: ` line ends with a trailing space.
        assert_eq!(
            format!("{tree:#?}"),
            "Tree {
    method: GET,
    paths: 
    / •0
    ├── api/search •6
    ├── se
    │   ├── arch •1
    │   └── ttings •4
    │       └── /
    │           └── : •5
    └── : •2
        └── /
            └── : •3
    ,
    method: POST,
    paths: 
    /
    └── : •0
    ,
}"
        );
    }

    /// Builds an empty-bodied request for `method` and `path`. The method and
    /// path are also returned separately, ready to pass to `Tree::find`.
    fn client(method: Method, path: &str) -> (Request, Method, String) {
        (
            Request::builder()
                .method(method.clone())
                .uri(path.to_owned())
                .body(Body::Empty)
                .unwrap(),
            method,
            path.to_string(),
        )
    }
}