springtime_web_axum/
router.rs

1//! Controller routing handling. By default, routing is based on gathering existing controllers and
2//! their request handlers.
3
4use crate::controller::Controller;
5use axum::Router;
6#[cfg(test)]
7use mockall::automock;
8use springtime_di::component_registry::conditional::unregistered_component;
9use springtime_di::instance_provider::{ComponentInstancePtr, ErrorPtr};
10use springtime_di::{component_alias, injectable, Component};
11use tracing::debug;
12
13/// Trait for configuring [Router] created by [RouterBootstrap]. Multiple such components can be
14/// present and each one will be called with the current router instance.
15#[injectable]
16#[cfg_attr(test, automock)]
17pub trait RouterConfigure {
18    /// Configure and return existing [Router].
19    fn configure(&self, router: Router) -> Result<Router, ErrorPtr>;
20}
21
22/// Trait for creating a [Router], usually based on injected
23/// [Controller](crate::controller::Controller)s.
24#[injectable]
25pub trait RouterBootstrap {
26    /// Creates a new [Router].
27    fn bootstrap_router(&self, server_name: &str) -> Result<Router, ErrorPtr>;
28}
29
30#[derive(Component)]
31#[component(priority = -128, condition = "unregistered_component::<dyn RouterBootstrap + Send + Sync>")]
32struct ControllerRouterBootstrap {
33    controllers: Vec<ComponentInstancePtr<dyn Controller + Send + Sync>>,
34    configure_components: Vec<ComponentInstancePtr<dyn RouterConfigure + Send + Sync>>,
35}
36
37#[component_alias]
38impl RouterBootstrap for ControllerRouterBootstrap {
39    fn bootstrap_router(&self, server_name: &str) -> Result<Router, ErrorPtr> {
40        self.controllers
41            .iter()
42            .filter(|controller| {
43                controller
44                    .server_names()
45                    .map(|server_names| server_names.contains(server_name))
46                    .unwrap_or(true)
47            })
48            .try_fold(Router::new(), |router, controller| {
49                let path = controller.path().unwrap_or_else(|| "/".to_string());
50                let inner_router = controller.create_router()?;
51
52                debug!(path, "Registering new controller routes.");
53
54                controller
55                    .configure_router(inner_router, controller.clone())
56                    .and_then(|inner_router| controller.post_configure_router(inner_router))
57                    .map(|inner_router| {
58                        if path.is_empty() || path == "/" {
59                            // cannot nest root-level routers
60                            router.merge(inner_router)
61                        } else {
62                            router.nest(&path, inner_router)
63                        }
64                    })
65            })
66            .and_then(|router| {
67                self.configure_components
68                    .iter()
69                    .try_fold(router, |router, configure| configure.configure(router))
70            })
71    }
72}
73
#[cfg(test)]
mod tests {
    use crate::controller::MockController;
    use crate::router::{ControllerRouterBootstrap, MockRouterConfigure, RouterBootstrap};
    use axum::Router;
    use rustc_hash::FxHashSet;
    use springtime_di::instance_provider::ComponentInstancePtr;

    /// Builds the set of server names a mocked controller reports as its targets.
    fn server_names(names: &[&str]) -> FxHashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn should_configure_router_with_filtering() {
        let mut mock_controller = MockController::new();
        mock_controller
            .expect_server_names()
            .once()
            .return_const(server_names(&["1", "2"]));
        mock_controller.expect_path().return_const(None);
        mock_controller
            .expect_create_router()
            .return_const(Ok(Router::new()));
        mock_controller
            .expect_configure_router()
            .once()
            .return_const(Ok(Router::new()));
        mock_controller
            .expect_post_configure_router()
            .returning(|configured| Ok(configured));

        let bootstrap = ControllerRouterBootstrap {
            controllers: vec![ComponentInstancePtr::new(mock_controller)],
            configure_components: Vec::new(),
        };

        let outcome = bootstrap.bootstrap_router("1");
        assert!(outcome.is_ok());
    }

    #[test]
    fn should_not_configure_router_with_filtering() {
        let mut mock_controller = MockController::new();
        mock_controller
            .expect_server_names()
            .once()
            .return_const(server_names(&["1", "2"]));
        // the controller is filtered out, so it must never be asked for its router
        mock_controller
            .expect_configure_router()
            .never()
            .return_const(Ok(Router::new()));

        let bootstrap = ControllerRouterBootstrap {
            controllers: vec![ComponentInstancePtr::new(mock_controller)],
            configure_components: Vec::new(),
        };

        let outcome = bootstrap.bootstrap_router("3");
        assert!(outcome.is_ok());
    }

    #[test]
    fn should_pass_existing_router_for_configuration() {
        let mut mock_configure = MockRouterConfigure::new();
        mock_configure
            .expect_configure()
            .once()
            .returning(|current| Ok(current));

        let bootstrap = ControllerRouterBootstrap {
            controllers: Vec::new(),
            configure_components: vec![ComponentInstancePtr::new(mock_configure)],
        };

        let outcome = bootstrap.bootstrap_router("1");
        assert!(outcome.is_ok());
    }
}