use std::error::Error as StdError;
use std::future::Future;
use std::net::SocketAddr;
#[cfg(feature = "tls")]
use std::path::Path;
use std::sync::Arc;
use crate::http::Mime;
use futures::{TryStream, TryStreamExt};
use hyper::server::accept::{self, Accept};
use hyper::server::conn::AddrIncoming;
use hyper::Server as HyperServer;
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tls")]
use crate::tls::{TlsAcceptor, TlsConfigBuilder};
use crate::transport::LiftIo;
use crate::{Catcher, Router, Service};
/// Creates a hyper server builder that accepts connections from `incoming`.
///
/// This is a thin wrapper around `hyper::Server::builder`; the server types in this module
/// use it to construct their underlying hyper servers.
pub fn builder<I>(incoming: I) -> hyper::server::Builder<I> {
    hyper::Server::builder(incoming)
}
/// An HTTP server that dispatches incoming requests through a [`Service`] built from a [`Router`].
///
/// Every `bind*` method consumes the server and only returns once the underlying hyper server
/// future has completed (or failed).
pub struct Server {
    service: Service,
}
impl Server {
    /// Creates a server whose service is built from `router` via `Service::new`.
    pub fn new(router: Router) -> Server {
        Server {
            service: Service::new(router),
        }
    }
    /// Replaces the service's catchers with `catchers`.
    pub fn with_catchers(mut self, catchers: Vec<Box<dyn Catcher>>) -> Self {
        self.service.catchers = Arc::new(catchers);
        self
    }
    /// Replaces the service's list of allowed media types with `allowed_media_types`.
    pub fn with_allowed_media_types(mut self, allowed_media_types: Vec<Mime>) -> Self {
        self.service.allowed_media_types = Arc::new(allowed_media_types);
        self
    }
    /// Binds a TCP listener on `addr` and builds the hyper server on top of it.
    ///
    /// The returned address is the one the listener is *actually* bound to. This differs from
    /// the requested address when, for example, port `0` is requested and the OS assigns an
    /// ephemeral port.
    ///
    /// # Errors
    ///
    /// Returns the hyper error if the listener cannot be bound.
    fn create_bind_hyper_server(
        self,
        addr: impl Into<SocketAddr>,
    ) -> Result<(SocketAddr, hyper::Server<AddrIncoming, Service>), hyper::Error> {
        let mut incoming = AddrIncoming::bind(&addr.into())?;
        // Sets TCP_NODELAY on accepted connections.
        incoming.set_nodelay(true);
        let local_addr = incoming.local_addr();
        Ok((local_addr, builder(incoming).serve(self.service)))
    }
    /// Builds a hyper server that accepts connections from a caller-supplied stream.
    ///
    /// Each successfully yielded IO object is wrapped in [`LiftIo`] before it is handed to hyper.
    fn create_bind_incoming_hyper_server<S>(
        self,
        incoming: S,
    ) -> hyper::Server<impl Accept<Conn = LiftIo<S::Ok>, Error = S::Error>, Service>
    where
        S: TryStream + Send,
        S::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let accept = accept::from_stream(incoming.map_ok(LiftIo).into_stream());
        builder(accept).serve(self.service)
    }
    /// Logs a server error and returns it unchanged, for use with `map_err`.
    fn log_server_error(err: hyper::Error) -> hyper::Error {
        tracing::error!("server error: {}", err);
        err
    }
    /// Binds to `addr` and serves requests until the server stops.
    ///
    /// # Panics
    ///
    /// Panics if binding fails or the server terminates with an error. Use [`Server::try_bind`]
    /// to handle those cases instead.
    pub async fn bind(self, addr: impl Into<SocketAddr> + 'static) {
        self.try_bind(addr).await.unwrap();
    }
    /// Binds to `addr` and serves requests until the server stops.
    ///
    /// On success, returns the address the listener was bound to. The value is only available
    /// after the server has finished running.
    ///
    /// # Errors
    ///
    /// Returns the hyper error if binding fails or the server terminates with an error.
    pub async fn try_bind(self, addr: impl Into<SocketAddr>) -> Result<SocketAddr, hyper::Error> {
        let (addr, srv) = self.create_bind_hyper_server(addr)?;
        tracing::info!("listening with socket addr: {}", addr);
        srv.await.map_err(Self::log_server_error)?;
        Ok(addr)
    }
    /// Binds to `addr` and serves requests using hyper's graceful shutdown, triggered when
    /// `signal` completes.
    ///
    /// # Panics
    ///
    /// Panics if binding fails or the server terminates with an error.
    pub async fn bind_with_graceful_shutdown(
        self,
        addr: impl Into<SocketAddr> + 'static,
        signal: impl Future<Output = ()> + Send + 'static,
    ) {
        self.try_bind_with_graceful_shutdown(addr, signal).await.unwrap();
    }
    /// Binds to `addr` and serves requests using hyper's graceful shutdown, triggered when
    /// `signal` completes.
    ///
    /// On success, returns the address the listener was bound to.
    ///
    /// # Errors
    ///
    /// Returns the hyper error if binding fails or the server terminates with an error.
    pub async fn try_bind_with_graceful_shutdown(
        self,
        addr: impl Into<SocketAddr> + 'static,
        signal: impl Future<Output = ()> + Send + 'static,
    ) -> Result<SocketAddr, hyper::Error> {
        let (addr, srv) = self.create_bind_hyper_server(addr)?;
        tracing::info!("listening with socket addr: {}", addr);
        srv.with_graceful_shutdown(signal)
            .await
            .map_err(Self::log_server_error)?;
        Ok(addr)
    }
    /// Serves connections yielded by a custom `incoming` stream until the server stops.
    ///
    /// # Panics
    ///
    /// Panics if the server terminates with an error.
    pub async fn bind_incoming<I>(self, incoming: I)
    where
        I: TryStream + Send,
        I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
        I::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        self.try_bind_incoming(incoming).await.unwrap();
    }
    /// Serves connections yielded by a custom `incoming` stream until the server stops.
    ///
    /// # Errors
    ///
    /// Returns the hyper error if the server terminates with an error.
    pub async fn try_bind_incoming<I>(self, incoming: I) -> Result<(), hyper::Error>
    where
        I: TryStream + Send,
        I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
        I::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let srv = self.create_bind_incoming_hyper_server(incoming);
        tracing::info!("listening with custom incoming");
        srv.await.map_err(Self::log_server_error)
    }
    /// Serves connections from a custom `incoming` stream using hyper's graceful shutdown,
    /// triggered when `signal` completes.
    ///
    /// # Panics
    ///
    /// Panics if the server terminates with an error.
    pub async fn bind_incoming_with_graceful_shutdown<I>(
        self,
        incoming: I,
        signal: impl Future<Output = ()> + Send + 'static,
    ) where
        I: TryStream + Send,
        I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
        I::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        self.try_bind_incoming_with_graceful_shutdown(incoming, signal)
            .await
            .unwrap();
    }
    /// Serves connections from a custom `incoming` stream using hyper's graceful shutdown,
    /// triggered when `signal` completes.
    ///
    /// # Errors
    ///
    /// Returns the hyper error if the server terminates with an error.
    pub async fn try_bind_incoming_with_graceful_shutdown<I>(
        self,
        incoming: I,
        signal: impl Future<Output = ()> + Send + 'static,
    ) -> Result<(), hyper::Error>
    where
        I: TryStream + Send,
        I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
        I::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let srv = self.create_bind_incoming_hyper_server(incoming);
        tracing::info!("listening with custom incoming");
        srv.with_graceful_shutdown(signal)
            .await
            .map_err(Self::log_server_error)
    }
    /// Converts this server into a [`TlsServer`] that keeps the same service and starts with a
    /// new `TlsConfigBuilder`.
    #[cfg(feature = "tls")]
    pub fn tls(self) -> TlsServer {
        TlsServer {
            service: self.service,
            config: TlsConfigBuilder::new(),
        }
    }
}
/// A TLS-enabled server, normally created from a plain server's `tls()` method.
///
/// Key, certificate and client-authentication material is configured through the builder-style
/// methods before calling one of the `start`/`bind` methods.
#[cfg(feature = "tls")]
pub struct TlsServer {
    service: Service,
    config: TlsConfigBuilder,
}
#[cfg(feature = "tls")]
impl TlsServer {
    /// Creates a hyper server builder for `incoming`.
    ///
    /// This is equivalent to the module-level `builder` function.
    pub fn builder<I>(incoming: I) -> hyper::server::Builder<I> {
        builder(incoming)
    }
    /// Sets the private-key file path on the TLS config (`TlsConfigBuilder::key_path`).
    pub fn key_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.key_path(path))
    }
    /// Sets the certificate file path on the TLS config (`TlsConfigBuilder::cert_path`).
    pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.cert_path(path))
    }
    /// Sets the trust-anchor file path for optional client authentication
    /// (`TlsConfigBuilder::client_auth_optional_path`).
    pub fn client_auth_optional_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.client_auth_optional_path(path))
    }
    /// Sets the trust-anchor file path for required client authentication
    /// (`TlsConfigBuilder::client_auth_required_path`).
    pub fn client_auth_required_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.client_auth_required_path(path))
    }
    /// Sets the private key from in-memory bytes (`TlsConfigBuilder::key`).
    pub fn key(self, key: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.key(key.as_ref()))
    }
    /// Sets the certificate from in-memory bytes (`TlsConfigBuilder::cert`).
    pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.cert(cert.as_ref()))
    }
    /// Sets an in-memory trust anchor for optional client authentication
    /// (`TlsConfigBuilder::client_auth_optional`).
    pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref()))
    }
    /// Sets an in-memory trust anchor for required client authentication
    /// (`TlsConfigBuilder::client_auth_required`).
    pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref()))
    }
    /// Sets the OCSP response bytes on the TLS config (`TlsConfigBuilder::ocsp_resp`).
    pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.ocsp_resp(resp.as_ref()))
    }
    /// Applies `func` to the TLS config builder and returns the updated server.
    fn with_tls<Func>(self, func: Func) -> Self
    where
        Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
    {
        let TlsServer { service, config } = self;
        let config = func(config);
        TlsServer { service, config }
    }
    /// Builds the TLS config, binds a TCP listener on `addr`, and wraps it in a `TlsAcceptor`.
    ///
    /// The returned address is the one the listener is *actually* bound to. This differs from
    /// the requested address when, for example, port `0` is requested.
    ///
    /// # Errors
    ///
    /// Returns an error if the TLS config fails to build or the listener cannot be bound.
    fn create_bind_hyper_server(
        self,
        addr: impl Into<SocketAddr>,
    ) -> Result<(SocketAddr, hyper::Server<TlsAcceptor, Service>), crate::Error> {
        let addr = addr.into();
        let TlsServer { service, config } = self;
        let tls = config.build().map_err(crate::Error::new)?;
        let mut incoming = AddrIncoming::bind(&addr).map_err(crate::Error::new)?;
        // Sets TCP_NODELAY on accepted connections.
        incoming.set_nodelay(true);
        // Read the bound address before `incoming` is moved into the acceptor.
        let local_addr = incoming.local_addr();
        let srv = builder(TlsAcceptor::new(tls, incoming)).serve(service);
        Ok((local_addr, srv))
    }
    /// Logs a server error and converts it into a `crate::Error`, for use with `map_err`.
    fn log_server_error(err: hyper::Error) -> crate::Error {
        tracing::error!("server error: {}", err);
        crate::Error::new(err)
    }
    /// Creates a runtime with `num_cpus::get()` threads and serves on `addr` until the server
    /// stops.
    ///
    /// # Panics
    ///
    /// Panics on TLS configuration, bind, or server errors.
    pub fn start(self, addr: impl Into<SocketAddr> + 'static) {
        self.start_with_threads(addr, num_cpus::get())
    }
    /// Creates a runtime via `crate::new_runtime(threads)` and blocks on serving `addr` until the
    /// server stops.
    ///
    /// # Panics
    ///
    /// Panics on TLS configuration, bind, or server errors.
    pub fn start_with_threads(self, addr: impl Into<SocketAddr> + 'static, threads: usize) {
        let runtime = crate::new_runtime(threads);
        runtime.block_on(self.bind(addr));
    }
    /// Binds to `addr` over TLS and serves requests until the server stops.
    ///
    /// # Panics
    ///
    /// Panics on TLS configuration, bind, or server errors. Use [`TlsServer::try_bind`] to
    /// handle them.
    pub async fn bind(self, addr: impl Into<SocketAddr> + 'static) {
        self.try_bind(addr).await.unwrap();
    }
    /// Binds to `addr` over TLS and serves requests until the server stops.
    ///
    /// On success, returns the bound address once the server has finished running.
    ///
    /// # Errors
    ///
    /// Returns an error on TLS configuration, bind, or server failure.
    pub async fn try_bind(self, addr: impl Into<SocketAddr>) -> Result<SocketAddr, crate::Error> {
        let (addr, srv) = self.create_bind_hyper_server(addr)?;
        tracing::info!("tls listening with socket addr: {}", addr);
        srv.await.map_err(Self::log_server_error)?;
        Ok(addr)
    }
    /// Binds to `addr` over TLS and serves requests using hyper's graceful shutdown, triggered
    /// when `signal` completes.
    ///
    /// # Panics
    ///
    /// Panics on TLS configuration, bind, or server errors.
    pub async fn bind_with_graceful_shutdown(
        self,
        addr: impl Into<SocketAddr> + 'static,
        signal: impl Future<Output = ()> + Send + 'static,
    ) {
        self.try_bind_with_graceful_shutdown(addr, signal).await.unwrap();
    }
    /// Binds to `addr` over TLS and serves requests using hyper's graceful shutdown, triggered
    /// when `signal` completes.
    ///
    /// # Errors
    ///
    /// Returns an error on TLS configuration, bind, or server failure.
    pub async fn try_bind_with_graceful_shutdown(
        self,
        addr: impl Into<SocketAddr> + 'static,
        signal: impl Future<Output = ()> + Send + 'static,
    ) -> Result<SocketAddr, crate::Error> {
        let (addr, srv) = self.create_bind_hyper_server(addr)?;
        tracing::info!("tls listening with socket addr: {}", addr);
        srv.with_graceful_shutdown(signal)
            .await
            .map_err(Self::log_server_error)?;
        Ok(addr)
    }
}