trillium_server_common/
config_ext.rs1use crate::{Acceptor, CloneCounterObserver, Config, Server, Stopper, Transport};
2use futures_lite::prelude::*;
3use std::{
4 io::ErrorKind,
5 net::{SocketAddr, TcpListener, ToSocketAddrs},
6};
7use trillium::Handler;
8use trillium_http::{transport::BoxedTransport, Conn as HttpConn, Error, SERVICE_UNAVAILABLE};
/// Server-facing accessors and per-connection behavior derived from a server
/// configuration.
///
/// `ServerType` is the [`Server`] adapter whose transports are handled, and
/// `AcceptorType` is the [`Acceptor`] applied to each incoming transport. This
/// module implements the trait for [`Config`].
#[trillium::async_trait]
pub trait ConfigExt<ServerType, AcceptorType>
where
    ServerType: Server,
{
    /// The TCP port to listen on.
    ///
    /// For [`Config`]: the explicitly configured port, else a parseable `PORT`
    /// environment variable, else `8080`.
    fn port(&self) -> u16;

    /// The host to bind to.
    ///
    /// For [`Config`]: the explicitly configured host, else the `HOST`
    /// environment variable, else `"localhost"`.
    fn host(&self) -> String;

    /// Resolves [`host`](Self::host) and [`port`](Self::port) into socket
    /// addresses.
    ///
    /// The [`Config`] implementation panics if resolution fails.
    fn socket_addrs(&self) -> Vec<SocketAddr>;

    /// Whether the server should register signal handlers; for [`Config`]
    /// this is the configured `register_signals` flag.
    fn should_register_signals(&self) -> bool;

    /// The value passed to the transport's `set_nodelay` for each accepted
    /// stream.
    fn nodelay(&self) -> bool;

    /// A clone of this configuration's [`Stopper`]; the [`Config`]
    /// implementation also hands a clone to every HTTP connection.
    fn stopper(&self) -> Stopper;

    /// The [`Acceptor`] every incoming transport is passed through before HTTP
    /// processing begins.
    fn acceptor(&self) -> &AcceptorType;

    /// The [`CloneCounterObserver`] used to track open connections; it is
    /// consulted by [`over_capacity`](Self::over_capacity) and
    /// [`graceful_shutdown`](Self::graceful_shutdown).
    fn counter_observer(&self) -> &CloneCounterObserver;

    /// Waits until every connection tracked by
    /// [`counter_observer`](Self::counter_observer) has closed, logging
    /// progress. Returns immediately when none are open.
    async fn graceful_shutdown(&self);

    /// Serves a single accepted transport with `handler`.
    ///
    /// In the [`Config`] implementation, when
    /// [`over_capacity`](Self::over_capacity) is true the stream is answered
    /// with [`SERVICE_UNAVAILABLE`] and dropped. Otherwise `nodelay` is
    /// applied, the stream goes through the [`acceptor`](Self::acceptor), and
    /// each HTTP conn is run through `handler.run` then `handler.before_send`.
    /// If the connection ends in an upgrade that `handler.has_upgrade`
    /// accepts, `handler.upgrade` takes over the (boxed) transport. Errors are
    /// logged rather than returned.
    async fn handle_stream(&self, stream: ServerType::Transport, handler: impl Handler);

    /// Creates a non-blocking TCP listener and converts it into `Listener`.
    ///
    /// In the [`Config`] implementation on unix, a `LISTEN_FD` environment
    /// variable holding a parseable file-descriptor number is adopted as the
    /// listening socket. Otherwise, and on other platforms, a listener is bound
    /// to [`host`](Self::host):[`port`](Self::port).
    ///
    /// # Panics
    ///
    /// The [`Config`] implementation panics if binding, switching to
    /// non-blocking mode, or the conversion into `Listener` fails.
    fn build_listener<Listener>(&self) -> Listener
    where
        Listener: TryFrom<TcpListener>,
        <Listener as TryFrom<TcpListener>>::Error: std::fmt::Debug;

    /// Whether the number of open connections has reached the configured
    /// `max_connections`. For [`Config`] this is always `false` when no limit
    /// is set.
    fn over_capacity(&self) -> bool;
}
84
85#[trillium::async_trait]
86impl<ServerType, AcceptorType> ConfigExt<ServerType, AcceptorType>
87 for Config<ServerType, AcceptorType>
88where
89 ServerType: Server + Send + ?Sized,
90 AcceptorType: Acceptor<<ServerType as Server>::Transport>,
91{
92 fn port(&self) -> u16 {
93 self.port
94 .or_else(|| std::env::var("PORT").ok().and_then(|p| p.parse().ok()))
95 .unwrap_or(8080)
96 }
97
98 fn host(&self) -> String {
99 self.host
100 .as_ref()
101 .map(String::from)
102 .or_else(|| std::env::var("HOST").ok())
103 .unwrap_or_else(|| String::from("localhost"))
104 }
105
106 fn socket_addrs(&self) -> Vec<SocketAddr> {
107 (self.host(), self.port())
108 .to_socket_addrs()
109 .unwrap()
110 .collect()
111 }
112
113 fn should_register_signals(&self) -> bool {
114 self.register_signals
115 }
116
117 fn nodelay(&self) -> bool {
118 self.nodelay
119 }
120
121 fn stopper(&self) -> Stopper {
122 self.stopper.clone()
123 }
124
125 fn acceptor(&self) -> &AcceptorType {
126 &self.acceptor
127 }
128
129 fn counter_observer(&self) -> &CloneCounterObserver {
130 &self.observer
131 }
132
133 async fn graceful_shutdown(&self) {
134 let current = self.observer.current();
135 if current > 0 {
136 log::info!(
137 "waiting for {} open connection{} to close",
138 current,
139 if current == 1 { "" } else { "s" }
140 );
141 self.observer.clone().await;
142 log::info!("all done!")
143 }
144 }
145
146 async fn handle_stream(&self, mut stream: ServerType::Transport, handler: impl Handler) {
147 if self.over_capacity() {
148 let mut byte = [0u8]; trillium::log_error!(stream.read(&mut byte).await);
150 trillium::log_error!(stream.write_all(SERVICE_UNAVAILABLE).await);
151 return;
152 }
153
154 let counter = self.observer.counter();
155
156 trillium::log_error!(stream.set_nodelay(self.nodelay));
157
158 let peer_ip = stream.peer_addr().ok().flatten().map(|addr| addr.ip());
159
160 let stream = match self.acceptor.accept(stream).await {
161 Ok(stream) => stream,
162 Err(e) => {
163 log::error!("acceptor error: {:?}", e);
164 return;
165 }
166 };
167
168 let handler = &handler;
169 let result = HttpConn::map_with_config(
170 self.http_config,
171 stream,
172 self.stopper.clone(),
173 |mut conn| async {
174 conn.set_peer_ip(peer_ip);
175 let conn = handler.run(conn.into()).await;
176 let conn = handler.before_send(conn).await;
177
178 conn.into_inner()
179 },
180 )
181 .await;
182
183 match result {
184 Ok(Some(upgrade)) => {
185 let upgrade = upgrade.map_transport(BoxedTransport::new);
186 if handler.has_upgrade(&upgrade) {
187 log::debug!("upgrading...");
188 handler.upgrade(upgrade).await;
189 } else {
190 log::error!("upgrade specified but no upgrade handler provided");
191 }
192 }
193
194 Err(Error::Closed) | Ok(None) => {
195 log::debug!("closing connection");
196 }
197
198 Err(Error::Io(e))
199 if e.kind() == ErrorKind::ConnectionReset || e.kind() == ErrorKind::BrokenPipe =>
200 {
201 log::debug!("closing connection");
202 }
203
204 Err(e) => {
205 log::error!("http error: {:?}", e);
206 }
207 };
208
209 drop(counter);
210 }
211
212 fn build_listener<Listener>(&self) -> Listener
213 where
214 Listener: TryFrom<TcpListener>,
215 <Listener as TryFrom<TcpListener>>::Error: std::fmt::Debug,
216 {
217 #[cfg(unix)]
218 let listener = {
219 use std::os::unix::prelude::FromRawFd;
220
221 if let Some(fd) = std::env::var("LISTEN_FD")
222 .ok()
223 .and_then(|fd| fd.parse().ok())
224 {
225 log::debug!("using fd {} from LISTEN_FD", fd);
226 unsafe { TcpListener::from_raw_fd(fd) }
227 } else {
228 TcpListener::bind((self.host(), self.port())).unwrap()
229 }
230 };
231
232 #[cfg(not(unix))]
233 let listener = TcpListener::bind((self.host(), self.port())).unwrap();
234
235 listener.set_nonblocking(true).unwrap();
236 listener.try_into().unwrap()
237 }
238
239 fn over_capacity(&self) -> bool {
240 self.max_connections
241 .map_or(false, |m| self.observer.current() >= m)
242 }
243}