trillium_server_common/
config_ext.rs

1use crate::{Acceptor, CloneCounterObserver, Config, Server, Stopper, Transport};
2use futures_lite::prelude::*;
3use std::{
4    io::ErrorKind,
5    net::{SocketAddr, TcpListener, ToSocketAddrs},
6};
7use trillium::Handler;
8use trillium_http::{transport::BoxedTransport, Conn as HttpConn, Error, SERVICE_UNAVAILABLE};
9/// # Server-implementer interfaces to Config
10///
11/// These functions are intended for use by authors of trillium servers,
12/// and should not be necessary to build an application. Please open
13/// an issue if you find yourself using this trait directly in an
14/// application.
15
16#[trillium::async_trait]
17pub trait ConfigExt<ServerType, AcceptorType>
18where
19    ServerType: Server,
20{
21    /// resolve a port for this application, either directly
22    /// configured, from the environmental variable `PORT`, or a default
23    /// of `8080`
24    fn port(&self) -> u16;
25
26    /// resolve the host for this application, either directly from
27    /// configuration, from the `HOST` env var, or `"localhost"`
28    fn host(&self) -> String;
29
30    /// use the [`ConfigExt::port`] and [`ConfigExt::host`] to resolve
31    /// a vec of potential socket addrs
32    fn socket_addrs(&self) -> Vec<SocketAddr>;
33
34    /// returns whether this server should register itself for
35    /// operating system signals. this flag does nothing aside from
36    /// communicating to the server implementer that this is
37    /// desired. defaults to true on `cfg(unix)` systems, and false
38    /// elsewhere.
39    fn should_register_signals(&self) -> bool;
40
41    /// returns whether the server should set TCP_NODELAY on the
42    /// TcpListener, if that is applicable
43    fn nodelay(&self) -> bool;
44
45    /// returns a clone of the [`Stopper`] associated with
46    /// this server, to be used in conjunction with signals or other
47    /// service interruption methods
48    fn stopper(&self) -> Stopper;
49
50    /// returns the tls acceptor for this server
51    fn acceptor(&self) -> &AcceptorType;
52
53    /// returns the [`CloneCounterObserver`] for this server
54    fn counter_observer(&self) -> &CloneCounterObserver;
55
56    /// waits for the last clone of the [`CloneCounter`][crate::CloneCounter] in this
57    /// config to drop, indicating that all outstanding requests are
58    /// complete
59    async fn graceful_shutdown(&self);
60
61    /// apply the provided handler to the transport, using
62    /// [`trillium_http`]'s http implementation. this is the default inner
63    /// loop for most trillium servers
64    async fn handle_stream(&self, stream: ServerType::Transport, handler: impl Handler);
65
66    /// builds any type that is TryFrom<std::net::TcpListener> and
67    /// configures it for use. most trillium servers should use this if
68    /// possible instead of using [`ConfigExt::port`],
69    /// [`ConfigExt::host`], or [`ConfigExt::socket_addrs`].
70    ///
71    /// this function also contains logic that sets nonblocking to
72    /// true and on unix systems will build a tcp listener from the
73    /// `LISTEN_FD` env var.
74    fn build_listener<Listener>(&self) -> Listener
75    where
76        Listener: TryFrom<TcpListener>,
77        <Listener as TryFrom<TcpListener>>::Error: std::fmt::Debug;
78
79    /// determines if the server is currently responding to more than
80    /// the maximum number of connections set by
81    /// `Config::with_max_connections`.
82    fn over_capacity(&self) -> bool;
83}
84
85#[trillium::async_trait]
86impl<ServerType, AcceptorType> ConfigExt<ServerType, AcceptorType>
87    for Config<ServerType, AcceptorType>
88where
89    ServerType: Server + Send + ?Sized,
90    AcceptorType: Acceptor<<ServerType as Server>::Transport>,
91{
92    fn port(&self) -> u16 {
93        self.port
94            .or_else(|| std::env::var("PORT").ok().and_then(|p| p.parse().ok()))
95            .unwrap_or(8080)
96    }
97
98    fn host(&self) -> String {
99        self.host
100            .as_ref()
101            .map(String::from)
102            .or_else(|| std::env::var("HOST").ok())
103            .unwrap_or_else(|| String::from("localhost"))
104    }
105
106    fn socket_addrs(&self) -> Vec<SocketAddr> {
107        (self.host(), self.port())
108            .to_socket_addrs()
109            .unwrap()
110            .collect()
111    }
112
113    fn should_register_signals(&self) -> bool {
114        self.register_signals
115    }
116
117    fn nodelay(&self) -> bool {
118        self.nodelay
119    }
120
121    fn stopper(&self) -> Stopper {
122        self.stopper.clone()
123    }
124
125    fn acceptor(&self) -> &AcceptorType {
126        &self.acceptor
127    }
128
129    fn counter_observer(&self) -> &CloneCounterObserver {
130        &self.observer
131    }
132
133    async fn graceful_shutdown(&self) {
134        let current = self.observer.current();
135        if current > 0 {
136            log::info!(
137                "waiting for {} open connection{} to close",
138                current,
139                if current == 1 { "" } else { "s" }
140            );
141            self.observer.clone().await;
142            log::info!("all done!")
143        }
144    }
145
146    async fn handle_stream(&self, mut stream: ServerType::Transport, handler: impl Handler) {
147        if self.over_capacity() {
148            let mut byte = [0u8]; // wait for the client to start requesting
149            trillium::log_error!(stream.read(&mut byte).await);
150            trillium::log_error!(stream.write_all(SERVICE_UNAVAILABLE).await);
151            return;
152        }
153
154        let counter = self.observer.counter();
155
156        trillium::log_error!(stream.set_nodelay(self.nodelay));
157
158        let peer_ip = stream.peer_addr().ok().flatten().map(|addr| addr.ip());
159
160        let stream = match self.acceptor.accept(stream).await {
161            Ok(stream) => stream,
162            Err(e) => {
163                log::error!("acceptor error: {:?}", e);
164                return;
165            }
166        };
167
168        let handler = &handler;
169        let result = HttpConn::map_with_config(
170            self.http_config,
171            stream,
172            self.stopper.clone(),
173            |mut conn| async {
174                conn.set_peer_ip(peer_ip);
175                let conn = handler.run(conn.into()).await;
176                let conn = handler.before_send(conn).await;
177
178                conn.into_inner()
179            },
180        )
181        .await;
182
183        match result {
184            Ok(Some(upgrade)) => {
185                let upgrade = upgrade.map_transport(BoxedTransport::new);
186                if handler.has_upgrade(&upgrade) {
187                    log::debug!("upgrading...");
188                    handler.upgrade(upgrade).await;
189                } else {
190                    log::error!("upgrade specified but no upgrade handler provided");
191                }
192            }
193
194            Err(Error::Closed) | Ok(None) => {
195                log::debug!("closing connection");
196            }
197
198            Err(Error::Io(e))
199                if e.kind() == ErrorKind::ConnectionReset || e.kind() == ErrorKind::BrokenPipe =>
200            {
201                log::debug!("closing connection");
202            }
203
204            Err(e) => {
205                log::error!("http error: {:?}", e);
206            }
207        };
208
209        drop(counter);
210    }
211
212    fn build_listener<Listener>(&self) -> Listener
213    where
214        Listener: TryFrom<TcpListener>,
215        <Listener as TryFrom<TcpListener>>::Error: std::fmt::Debug,
216    {
217        #[cfg(unix)]
218        let listener = {
219            use std::os::unix::prelude::FromRawFd;
220
221            if let Some(fd) = std::env::var("LISTEN_FD")
222                .ok()
223                .and_then(|fd| fd.parse().ok())
224            {
225                log::debug!("using fd {} from LISTEN_FD", fd);
226                unsafe { TcpListener::from_raw_fd(fd) }
227            } else {
228                TcpListener::bind((self.host(), self.port())).unwrap()
229            }
230        };
231
232        #[cfg(not(unix))]
233        let listener = TcpListener::bind((self.host(), self.port())).unwrap();
234
235        listener.set_nonblocking(true).unwrap();
236        listener.try_into().unwrap()
237    }
238
239    fn over_capacity(&self) -> bool {
240        self.max_connections
241            .map_or(false, |m| self.observer.current() >= m)
242    }
243}