use super::service::{layer_fn, BoxedIo, ServiceBuilderExt};
#[cfg(feature = "tls")]
use super::{
service::TlsAcceptor,
tls::{Identity, TlsProvider},
Certificate,
};
use crate::body::BoxBody;
use futures_core::Stream;
use futures_util::{ready, try_future::MapErr, TryFutureExt, TryStreamExt};
use http::{Request, Response};
use hyper::{
server::{accept::Accept, conn},
Body,
};
use std::{
fmt,
future::Future,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{
layer::{util::Stack, Layer},
limit::concurrency::ConcurrencyLimitLayer,
Service,
ServiceBuilder,
};
use tower_make::MakeService;
#[cfg(feature = "tls")]
use tracing::error;
/// Type-erased service that takes a `Request<Body>`, answers with a
/// `Response<BoxBody>` and fails with `crate::Error`.
type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
/// Shared, thread-safe layer that turns one `BoxService` into another; this is
/// the form in which a user-supplied interceptor is stored.
type Interceptor = Arc<dyn Layer<BoxService, Service = BoxService> + Send + Sync + 'static>;
/// HTTP/2 server configuration.
///
/// A value is created with `Server::builder`, adjusted through the setter
/// methods, and finally consumed by `serve`, which runs the server.
#[derive(Default, Clone)]
pub struct Server {
// Optional layer applied around every service produced for a connection.
interceptor: Option<Interceptor>,
// Optional limit handed to a `ConcurrencyLimitLayer` for each produced service.
concurrency_limit: Option<usize>,
// Optional TLS acceptor; when set, `serve` passes accepted streams through it.
#[cfg(feature = "tls")]
tls: Option<TlsAcceptor>,
// Forwarded to hyper's `http2_initial_stream_window_size`.
init_stream_window_size: Option<u32>,
// Forwarded to hyper's `http2_initial_connection_window_size`.
init_connection_window_size: Option<u32>,
// Forwarded to hyper's `http2_max_concurrent_streams`.
max_concurrent_streams: Option<u32>,
}
impl Server {
    /// Returns a `Server` with every setting left at its `Default` value,
    /// ready to be customised through the setter methods.
    pub fn builder() -> Self {
        Self::default()
    }
}
impl Server {
    /// Enables TLS for incoming connections, using an acceptor built from
    /// `tls_config`.
    ///
    /// # Panics
    ///
    /// Panics if no TLS acceptor can be built from `tls_config`. The
    /// underlying error is included in the panic message.
    #[cfg(feature = "tls")]
    pub fn tls_config(&mut self, tls_config: &ServerTlsConfig) -> &mut Self {
        // The signature cannot report failure, so at least make the panic say
        // what went wrong instead of a bare `unwrap` message.
        let acceptor = tls_config
            .tls_acceptor()
            .expect("failed to build a TLS acceptor from the supplied ServerTlsConfig");
        self.tls = Some(acceptor);
        self
    }

    /// Sets the limit that is handed to a `ConcurrencyLimitLayer` wrapped
    /// around each service the server creates.
    pub fn concurrency_limit_per_connection(&mut self, limit: usize) -> &mut Self {
        self.concurrency_limit = Some(limit);
        self
    }

    /// Sets the value forwarded to hyper's `http2_initial_stream_window_size`
    /// when the server is started. Passing `None` clears a previous value.
    pub fn initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
        self.init_stream_window_size = sz.into();
        self
    }

    /// Sets the value forwarded to hyper's
    /// `http2_initial_connection_window_size` when the server is started.
    /// Passing `None` clears a previous value.
    pub fn initial_connection_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
        self.init_connection_window_size = sz.into();
        self
    }

    /// Sets the value forwarded to hyper's `http2_max_concurrent_streams`
    /// when the server is started. Passing `None` clears a previous value.
    pub fn max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self {
        self.max_concurrent_streams = max.into();
        self
    }

    /// Installs `f` as the interceptor, replacing any interceptor set earlier.
    ///
    /// For every request, `f` receives a mutable reference to the boxed inner
    /// service together with the request, and returns a future resolving to
    /// the response (or an error).
    pub fn interceptor_fn<F, Out>(&mut self, f: F) -> &mut Self
    where
        F: Fn(&mut BoxService, Request<Body>) -> Out + Send + Sync + 'static,
        Out: Future<Output = Result<Response<BoxBody>, crate::Error>> + Send + 'static,
    {
        // `f` is shared between all services the layer produces.
        let f = Arc::new(f);
        let interceptor = layer_fn(move |mut s| {
            let f = f.clone();
            tower::service_fn(move |req| f(&mut s, req))
        });
        // Box the wrapped service again so the layer maps BoxService -> BoxService.
        let layer = Stack::new(interceptor, layer_fn(BoxService::new));
        self.interceptor = Some(Arc::new(layer));
        self
    }

    /// Binds a TCP listener on `addr` and serves HTTP/2 (only) connections
    /// with services produced by `svc`.
    ///
    /// When the `tls` feature is enabled and a TLS acceptor has been
    /// configured, every accepted stream is passed through it first; a stream
    /// whose TLS setup fails is logged and skipped.
    ///
    /// # Errors
    ///
    /// An error returned by the hyper server is wrapped in a `super::Error`
    /// of kind `ErrorKind::Server`.
    pub async fn serve<M, S>(self, addr: SocketAddr, svc: M) -> Result<(), super::Error>
    where
        M: Service<(), Response = S>,
        M::Error: Into<crate::Error> + Send + 'static,
        M::Future: Send + 'static,
        S: Service<Request<Body>, Response = Response<BoxBody>> + Send + 'static,
        S::Future: Send + 'static,
        S::Error: Into<crate::Error> + Send,
    {
        // Read everything needed from `self` up front; the accept stream
        // below refers to `self.tls` when the `tls` feature is on.
        let interceptor = self.interceptor.clone();
        let concurrency_limit = self.concurrency_limit;
        let init_connection_window_size = self.init_connection_window_size;
        let init_stream_window_size = self.init_stream_window_size;
        let max_concurrent_streams = self.max_concurrent_streams;

        let incoming = hyper::server::accept::from_stream(async_stream::try_stream! {
            let mut tcp = TcpIncoming::bind(addr)?;
            while let Some(stream) = tcp.try_next().await? {
                #[cfg(feature = "tls")]
                {
                    if let Some(tls) = &self.tls {
                        // NOTE(review): this is awaited inside the accept
                        // loop, so the next connection is not accepted until
                        // it finishes.
                        let io = match tls.connect(stream.into_inner()).await {
                            Ok(io) => io,
                            Err(error) => {
                                error!(message = "Unable to accept incoming connection.", %error);
                                continue
                            },
                        };
                        yield BoxedIo::new(io);
                        continue;
                    }
                }
                yield BoxedIo::new(stream);
            }
        });

        let svc = MakeSvc {
            inner: svc,
            interceptor,
            concurrency_limit,
        };

        hyper::Server::builder(incoming)
            .http2_only(true)
            .http2_initial_connection_window_size(init_connection_window_size)
            .http2_initial_stream_window_size(init_stream_window_size)
            .http2_max_concurrent_streams(max_concurrent_streams)
            .serve(svc)
            .await
            .map_err(map_err)?;
        Ok(())
    }
}
fn map_err(e: impl Into<crate::Error>) -> super::Error {
super::Error::from_source(super::ErrorKind::Server, e.into())
}
impl fmt::Debug for Server {
    /// Formats the server under its own type name and lists the plain
    /// numeric settings.
    ///
    /// The interceptor is a trait object with no `Debug` bound, so it is not
    /// printed; the TLS acceptor is left out as well.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Server")
            .field("concurrency_limit", &self.concurrency_limit)
            .field("init_stream_window_size", &self.init_stream_window_size)
            .field(
                "init_connection_window_size",
                &self.init_connection_window_size,
            )
            .field("max_concurrent_streams", &self.max_concurrent_streams)
            .finish()
    }
}
/// TLS settings from which a server-side TLS acceptor is built.
///
/// Either an `Identity` (optionally with a client CA root) or a raw,
/// backend-specific configuration is used; a raw configuration, when present
/// for the selected provider, takes precedence.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct ServerTlsConfig {
// Selects which TLS backend builds the acceptor.
provider: TlsProvider,
// Server identity handed to the backend when no raw configuration is set.
identity: Option<Identity>,
// Optional certificate handed to the backend alongside the identity;
// presumably the root used to verify client certificates — TODO confirm.
client_ca_root: Option<Certificate>,
// Pre-built OpenSSL acceptor, used as-is when present.
#[cfg(feature = "openssl")]
openssl_raw: Option<openssl1::ssl::SslAcceptor>,
// Pre-built rustls server configuration, used as-is when present.
#[cfg(feature = "rustls")]
rustls_raw: Option<tokio_rustls::rustls::ServerConfig>,
}
#[cfg(feature = "tls")]
impl fmt::Debug for ServerTlsConfig {
    /// Prints only the selected provider; the identity, certificate and raw
    /// backend configuration are not included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ServerTlsConfig");
        out.field("provider", &self.provider);
        out.finish()
    }
}
#[cfg(feature = "tls")]
impl ServerTlsConfig {
    /// Creates an empty configuration that uses the OpenSSL provider.
    #[cfg(feature = "openssl")]
    pub fn with_openssl() -> Self {
        Self::new(TlsProvider::OpenSsl)
    }

    /// Creates an empty configuration that uses the rustls provider.
    #[cfg(feature = "rustls")]
    pub fn with_rustls() -> Self {
        Self::new(TlsProvider::Rustls)
    }

    /// Creates a configuration for `provider` with nothing else set.
    fn new(provider: TlsProvider) -> Self {
        ServerTlsConfig {
            provider,
            identity: None,
            client_ca_root: None,
            #[cfg(feature = "openssl")]
            openssl_raw: None,
            #[cfg(feature = "rustls")]
            rustls_raw: None,
        }
    }

    /// Sets the server identity used when no raw backend configuration is
    /// supplied.
    pub fn identity(&mut self, identity: Identity) -> &mut Self {
        self.identity = Some(identity);
        self
    }

    /// Sets the certificate passed to the backend together with the identity.
    pub fn client_ca_root(&mut self, cert: Certificate) -> &mut Self {
        self.client_ca_root = Some(cert);
        self
    }

    /// Supplies a pre-built OpenSSL acceptor. When the OpenSSL provider is
    /// selected it is used as-is and the identity settings are ignored.
    #[cfg(feature = "openssl")]
    pub fn openssl_connector(&mut self, acceptor: openssl1::ssl::SslAcceptor) -> &mut Self {
        self.openssl_raw = Some(acceptor);
        self
    }

    /// Supplies a pre-built rustls server configuration. When the rustls
    /// provider is selected it is used as-is and the identity settings are
    /// ignored.
    #[cfg(feature = "rustls")]
    pub fn rustls_server_config(
        &mut self,
        config: tokio_rustls::rustls::ServerConfig,
    ) -> &mut Self {
        self.rustls_raw = Some(config);
        self
    }

    /// Returns a clone of the configured identity.
    ///
    /// Panics with an explanatory message when no identity has been set; it
    /// is only called on the path where no raw configuration is available.
    fn required_identity(&self) -> Identity {
        self.identity.clone().expect(
            "ServerTlsConfig needs an identity (set with `identity`) \
             unless a raw TLS configuration is supplied",
        )
    }

    /// Builds the TLS acceptor for the selected provider.
    ///
    /// A raw configuration for that provider is used when present; otherwise
    /// the acceptor is built from the identity and optional client CA root.
    ///
    /// # Panics
    ///
    /// Panics if neither a raw configuration nor an identity has been set.
    fn tls_acceptor(&self) -> Result<TlsAcceptor, crate::Error> {
        match self.provider {
            #[cfg(feature = "openssl")]
            TlsProvider::OpenSsl => match &self.openssl_raw {
                None => TlsAcceptor::new_with_openssl_identity(
                    self.required_identity(),
                    self.client_ca_root.clone(),
                ),
                Some(acceptor) => TlsAcceptor::new_with_openssl_raw(acceptor.clone()),
            },
            #[cfg(feature = "rustls")]
            TlsProvider::Rustls => match &self.rustls_raw {
                None => TlsAcceptor::new_with_rustls_identity(
                    self.required_identity(),
                    self.client_ca_root.clone(),
                ),
                Some(config) => TlsAcceptor::new_with_rustls_raw(config.clone()),
            },
        }
    }
}
/// Wrapper around hyper's `AddrIncoming` that is exposed as a `Stream` of
/// accepted connections.
#[derive(Debug)]
struct TcpIncoming {
// The hyper listener that connections are accepted from.
inner: conn::AddrIncoming,
}
impl TcpIncoming {
fn bind(addr: SocketAddr) -> Result<Self, crate::Error> {
let inner = conn::AddrIncoming::bind(&addr).map_err(Box::new)?;
Ok(Self { inner })
}
}
impl Stream for TcpIncoming {
    type Item = Result<conn::AddrStream, crate::Error>;

    /// Polls the wrapped listener for the next connection.
    ///
    /// Accepted streams are passed through unchanged, accept errors are
    /// converted into `crate::Error`, and the end of the listener ends the
    /// stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let listener = Pin::new(&mut self.inner);
        // `ready!` returns `Poll::Pending` from this function when nothing
        // has been accepted yet.
        let accepted = ready!(Accept::poll_accept(listener, cx));
        Poll::Ready(accepted.map(|outcome| outcome.map_err(|e| e.into())))
    }
}
/// Newtype around a service; its `Service` implementation converts the
/// wrapped service's error type into `crate::Error`.
#[derive(Debug)]
struct Svc<S>(S);
impl<S> Service<Request<Body>> for Svc<S>
where
    S: Service<Request<Body>, Response = Response<BoxBody>>,
    S::Error: Into<crate::Error>,
{
    type Response = Response<BoxBody>;
    type Error = crate::Error;
    type Future = MapErr<S::Future, fn(S::Error) -> crate::Error>;

    /// Reports the wrapped service's readiness, converting a readiness
    /// error into `crate::Error`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.0.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Forwards `req` to the wrapped service and maps the error of the
    /// returned future into `crate::Error`.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let response = self.0.call(req);
        response.map_err(|err| err.into())
    }
}
/// Service factory handed to hyper: for each call it asks `inner` for a new
/// service and wraps it with the optional concurrency limit and interceptor.
struct MakeSvc<M> {
// Optional layer applied around each produced service.
interceptor: Option<Interceptor>,
// Optional limit used to build a `ConcurrencyLimitLayer` per produced service.
concurrency_limit: Option<usize>,
// The user-provided factory that produces the underlying services.
inner: M,
}
impl<M, S, T> Service<T> for MakeSvc<M>
where
    M: Service<(), Response = S>,
    M::Error: Into<crate::Error> + Send,
    M::Future: Send + 'static,
    S: Service<Request<Body>, Response = Response<BoxBody>> + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<crate::Error> + Send,
{
    type Response = BoxService;
    type Error = crate::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Delegates readiness to the inner factory, converting its error into
    /// `crate::Error`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        MakeService::poll_ready(&mut self.inner, cx).map_err(Into::into)
    }

    /// Produces a boxed service for one target (the target value itself is
    /// ignored).
    ///
    /// The service obtained from the inner factory is wrapped with the
    /// optional concurrency limit and then, if an interceptor is configured,
    /// with the interceptor layer.
    fn call(&mut self, _: T) -> Self::Future {
        let interceptor = self.interceptor.clone();
        let pending = self.inner.make_service(());
        let limit = self.concurrency_limit;

        Box::pin(async move {
            let made = pending.await.map_err(Into::into)?;
            let limited = ServiceBuilder::new()
                .optional_layer(limit.map(ConcurrencyLimitLayer::new))
                .service(made);

            let boxed = match interceptor {
                // The interceptor works on `BoxService`, so the service is
                // boxed before layering and once more afterwards.
                Some(layer) => {
                    let layered = layer.layer(BoxService::new(Svc(limited)));
                    BoxService::new(Svc(layered))
                }
                None => BoxService::new(Svc(limited)),
            };
            Ok(boxed)
        })
    }
}