springtime_web_axum/
server.rs

1//! Core server-related functionality.
2
3use crate::config::{ServerConfig, WebConfig, WebConfigProvider};
4use crate::router::RouterBootstrap;
5use futures::future::try_join_all;
6use springtime::future::{BoxFuture, FutureExt};
7use springtime::runner::ApplicationRunner;
8use springtime_di::component_registry::conditional::unregistered_component;
9use springtime_di::instance_provider::{ComponentInstancePtr, ErrorPtr};
10use springtime_di::{component_alias, injectable, Component};
11use std::future::{Future, IntoFuture};
12use std::sync::Arc;
13use thiserror::Error;
14use tokio::net::TcpListener;
15use tokio::select;
16use tokio::sync::watch::{channel, Receiver, Sender};
17use tracing::{debug, info};
18
19pub type ShutdownSignalSender = Sender<()>;
20
21/// Errors related to bootstrapping servers.
22#[derive(Error, Debug)]
23pub enum ServerBootstrapError {
24    #[error("Error binding server: {0}")]
25    BindError(#[source] tokio::io::Error),
26    #[error("Error configuring router: {0}")]
27    RouterError(#[source] ErrorPtr),
28}
29
30/// Trait for components responsible for creating web servers from
31/// [ServerConfig](crate::config::ServerConfig). Create a component implementing this trait to
32/// override the default bootstrap.
33#[injectable]
34pub trait ServerBootstrap {
35    /// Create a [Builder] which will them be used to create a web server.
36    fn bootstrap_server(
37        &self,
38        config: &ServerConfig,
39    ) -> BoxFuture<'_, Result<TcpListener, ServerBootstrapError>>;
40}
41
42#[derive(Component)]
43#[component(priority = -128, condition = "unregistered_component::<dyn ServerBootstrap + Send + Sync>")]
44struct DefaultServerBootstrap;
45
46#[component_alias]
47impl ServerBootstrap for DefaultServerBootstrap {
48    fn bootstrap_server(
49        &self,
50        config: &ServerConfig,
51    ) -> BoxFuture<'_, Result<TcpListener, ServerBootstrapError>> {
52        let listen_address = config.listen_address.clone();
53
54        async move {
55            TcpListener::bind(&listen_address)
56                .await
57                .map_err(ServerBootstrapError::BindError)
58        }
59        .boxed()
60    }
61}
62
63#[derive(Component)]
64struct ServerRunner {
65    server_bootstrap: ComponentInstancePtr<dyn ServerBootstrap + Send + Sync>,
66    router_bootstrap: ComponentInstancePtr<dyn RouterBootstrap + Send + Sync>,
67    config_provider: ComponentInstancePtr<dyn WebConfigProvider + Send + Sync>,
68    shutdown_signal_source: Option<ComponentInstancePtr<dyn ShutdownSignalSource + Send + Sync>>,
69}
70
71#[component_alias]
72impl ApplicationRunner for ServerRunner {
73    fn run(&self) -> BoxFuture<'_, Result<(), ErrorPtr>> {
74        async {
75            info!("Starting servers...");
76
77            let (tx, rx) = channel(());
78            if let Some(shutdown_signal_source) = &self.shutdown_signal_source {
79                shutdown_signal_source.register_shutdown(tx)?;
80            }
81
82            let config = self.config_provider.config().await?;
83            let servers = self
84                .create_servers(config, rx)
85                .await
86                .map_err(|error| Arc::new(error) as ErrorPtr)?;
87
88            info!("Running {} servers...", servers.len());
89
90            try_join_all(servers.into_iter()).await?;
91
92            info!("All servers stopped.");
93
94            Ok(())
95        }
96        .boxed()
97    }
98}
99
100impl ServerRunner {
101    async fn create_server(
102        &self,
103        config: &ServerConfig,
104        server_name: &str,
105        mut shutdown_receiver: Receiver<()>,
106    ) -> Result<impl Future<Output = Result<(), ErrorPtr>>, ServerBootstrapError> {
107        debug!(server_name, "Creating new server.");
108
109        let router = self
110            .router_bootstrap
111            .bootstrap_router(server_name)
112            .map_err(ServerBootstrapError::RouterError)?;
113
114        self.server_bootstrap
115            .bootstrap_server(config)
116            .await
117            .map(move |listener| async move {
118                let serve = axum::serve(listener, router.into_make_service()).into_future();
119
120                select! {
121                    result = serve => {
122                        result.map_err(|error| Arc::new(error) as ErrorPtr)
123                    }
124                    _ = shutdown_receiver.changed() => {
125                        Ok(())
126                    }
127                }
128            })
129    }
130
131    async fn create_servers(
132        &self,
133        config: &WebConfig,
134        shutdown_receiver: Receiver<()>,
135    ) -> Result<Vec<impl Future<Output = Result<(), ErrorPtr>>>, ServerBootstrapError> {
136        let mut result = Vec::with_capacity(config.servers.len());
137        for (server_name, config) in config.servers.iter() {
138            result.push(
139                self.create_server(config, server_name, shutdown_receiver.clone())
140                    .await?,
141            );
142        }
143
144        Ok(result)
145    }
146}
147
148/// Source for gracefully shutting down the server runner with all running servers. Only the primary
149/// instance is taken into account.
150#[injectable]
151pub trait ShutdownSignalSource {
152    /// Takes the given signal sender to add custom shutdown signaling logic.
153    fn register_shutdown(&self, shutdown_sender: ShutdownSignalSender) -> Result<(), ErrorPtr>;
154}