use crate::mediator;
use crate::wrappers::*;
use crate::Error;
use crate::MakeTransport;
use futures_core::{
future::Future,
ready,
stream::TryStream,
task::{Context, Poll},
};
use futures_sink::Sink;
use pin_project::pin_project;
use std::collections::VecDeque;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{atomic, Arc};
use std::{error, fmt};
use tower_service::Service;
#[cfg(feature = "tracing")]
use tracing::Level;
/// A factory service that turns a transport maker (`NT`) into new `Client`s.
///
/// `Request` is only recorded at the type level (via `PhantomData<fn(Request)>`)
/// so that `Maker` can name the request type its clients will accept.
pub struct Maker<NT, Request> {
    t_maker: NT,
    _req: PhantomData<fn(Request)>,
}

impl<NT, Request> Maker<NT, Request> {
    /// Creates a `Maker` that will build transports with `t`.
    pub fn new(t: NT) -> Self {
        let t_maker = t;
        Self {
            t_maker,
            _req: PhantomData,
        }
    }
}
/// Errors reported by `Maker` when it fails to produce a new `Client`.
#[derive(Debug)]
pub enum SpawnError<E> {
    /// The client could not be spawned (not produced by `Maker::call` as written).
    SpawnFailed,
    /// The underlying transport maker failed; carries its error.
    Inner(E),
}
/// `Maker` is a `Service` whose responses are freshly built `Client`s.
///
/// Each call asks the wrapped `MakeTransport` for a transport and wraps it with
/// `Client::new`, which spawns the client's background worker. Failures from the
/// transport maker surface as `SpawnError::Inner`; this impl never yields
/// `SpawnError::SpawnFailed`.
impl<NT, Target, Request> Service<Target> for Maker<NT, Request>
where
    NT: MakeTransport<Target, Request>,
    NT::Transport: 'static + Send,
    Request: 'static + Send,
    NT::Item: 'static + Send,
    NT::SinkError: 'static + Send + Sync,
    NT::Error: 'static + Send + Sync,
    NT::Future: 'static + Send,
{
    type Error = SpawnError<NT::MakeError>;
    type Response = Client<NT::Transport, Error<NT::Transport, Request>, Request>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Ready exactly when the wrapped transport maker is ready.
    fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.t_maker.poll_ready(cx).map_err(SpawnError::Inner)
    }

    /// Starts building a transport for `target`; the returned future awaits it
    /// and wraps it in a `Client`.
    fn call(&mut self, target: Target) -> Self::Future {
        // `make_transport` is invoked synchronously, here in `call`.
        let pending_transport = self.t_maker.make_transport(target);
        Box::pin(async move {
            let transport = pending_transport.await.map_err(SpawnError::Inner)?;
            Ok(Client::new(transport))
        })
    }
}
/// A `Maker` always reports a load of zero.
impl<NT, Request> tower_load::Load for Maker<NT, Request> {
    type Metric = u8;

    fn load(&self) -> u8 {
        0
    }
}
/// Handle for issuing requests over a transport `T` that is driven by a
/// background worker.
///
/// `E` is the error type surfaced to callers. It is only tracked at the type
/// level.
pub struct Client<T: Sink<Request> + TryStream, E, Request> {
    /// Channel used to hand requests (and their reply channels) to the worker.
    mediator: mediator::Sender<ClientRequest<T, Request>>,
    /// Count of requests sent but not yet answered; shared with the worker.
    in_flight: Arc<atomic::AtomicUsize>,
    _error: PhantomData<fn(E)>,
}
/// A caller waiting for the response to a request that has already been sent.
struct Pending<Item> {
    /// Delivers the matching response back to the caller.
    tx: tokio_sync::oneshot::Sender<ClientResponse<Item>>,
    /// The caller's span, carried along so the response is traced in its context.
    #[cfg(feature = "tracing")]
    span: tracing::Span,
}
/// Background worker that owns the transport on behalf of `Client` handles.
///
/// Constructed and spawned by `Client::with_error_handler`, and driven through
/// its `Future` impl.
#[pin_project]
struct ClientInner<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    // Receiving end of the channel that `Client::call` sends requests into.
    mediator: mediator::Receiver<ClientRequest<T, Request>>,
    // Callers awaiting responses, in the order their requests were sent;
    // responses are handed out from the front of this queue.
    responses: VecDeque<Pending<T::Ok>>,
    #[pin]
    transport: T,
    // Number of sent-but-unanswered requests; shared with `Client`, which reports it as its load.
    in_flight: Arc<atomic::AtomicUsize>,
    // Set once the mediator yields `None`, i.e. no further requests will arrive.
    finish: bool,
    // Set once the worker stops driving the sink and only waits for responses.
    rx_only: bool,
    // The field exists only to tie the error type `E` to this struct; it is never read.
    #[allow(unused)]
    error: PhantomData<fn(E)>,
}
impl<T, E, Request> Client<T, E, Request>
where
T: Sink<Request> + TryStream + Send + 'static,
E: From<Error<T, Request>>,
E: 'static + Send,
Request: 'static + Send,
T::Ok: 'static + Send,
{
pub fn new(transport: T) -> Self where {
Self::with_error_handler(transport, |_| {})
}
pub fn with_error_handler<F>(transport: T, on_service_error: F) -> Self
where
F: FnOnce(E) + Send + 'static,
{
let (tx, rx) = mediator::new();
let in_flight = Arc::new(atomic::AtomicUsize::new(0));
tokio_executor::spawn({
let c = ClientInner {
mediator: rx,
responses: Default::default(),
transport,
in_flight: in_flight.clone(),
error: PhantomData::<fn(E)>,
finish: false,
rx_only: false,
};
async move {
if let Err(e) = c.await {
on_service_error(e);
}
}
});
Client {
mediator: tx,
in_flight,
_error: PhantomData,
}
}
}
/// Worker loop for the client.
///
/// Each poll runs these phases:
/// 1. **Send**: while the sink is ready, take queued requests from the mediator,
///    `start_send` them, and remember the caller in `responses`.
/// 2. **Flush/close**: flush what has been sent. Once the mediator is exhausted,
///    close the sink instead.
/// 3. **Receive**: while requests are in flight, read responses from the stream
///    and hand each one to the oldest pending caller (FIFO).
/// 4. **Finish**: once the mediator is exhausted and nothing is in flight, make
///    sure the sink is closed and resolve.
///
/// Resolves to `Ok(())` after phase 4. Resolves to `Err` if the sink or stream
/// reports an error, or if the stream ends while responses are still outstanding.
///
/// # Panics
///
/// Panics if the transport yields a response while no request is waiting for one.
impl<T, E, Request> Future for ClientInner<T, E, Request>
where
    T: Sink<Request> + TryStream,
    E: From<Error<T, Request>>,
    E: 'static + Send,
    Request: 'static + Send,
    T::Ok: 'static + Send,
{
    type Output = Result<(), E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.project();
        // `transport` is structurally pinned; it is only ever used through this `Pin`.
        let mut transport: Pin<_> = this.transport;

        // Phase 1: only take a request off the mediator when the sink can accept it.
        while let Poll::Ready(r) = transport.as_mut().poll_ready(cx) {
            if let Err(e) = r {
                return Poll::Ready(Err(E::from(Error::from_sink_error(e))));
            }

            match this.mediator.try_recv(cx) {
                Poll::Ready(Some(ClientRequest {
                    req,
                    span: _span,
                    res,
                })) => {
                    #[cfg(feature = "tracing")]
                    let guard = _span.enter();
                    #[cfg(feature = "tracing")]
                    tracing::event!(Level::TRACE, "request received by worker; sending to Sink");

                    transport
                        .as_mut()
                        .start_send(req)
                        .map_err(Error::from_sink_error)?;

                    #[cfg(feature = "tracing")]
                    tracing::event!(Level::TRACE, "request sent");
                    #[cfg(feature = "tracing")]
                    drop(guard);

                    this.responses.push_back(Pending {
                        tx: res,
                        #[cfg(feature = "tracing")]
                        span: _span,
                    });
                    this.in_flight.fetch_add(1, atomic::Ordering::AcqRel);
                }
                Poll::Ready(None) => {
                    // The mediator will produce no more requests
                    // (presumably every `Client` handle has been dropped).
                    *this.finish = true;
                    break;
                }
                Poll::Pending => {
                    break;
                }
            }
        }

        // Phase 2: push sent requests out. A `Pending` here must not return early,
        // because responses still have to be checked below.
        if this.in_flight.load(atomic::Ordering::Acquire) != 0 && !*this.rx_only {
            if *this.finish {
                // `poll_close` also flushes. Only switch to receive-only mode once the
                // close has actually completed. While it is still pending, it must be
                // polled again on a later wakeup; otherwise buffered requests may never
                // be flushed and the sink may never finish closing.
                let closed = transport
                    .as_mut()
                    .poll_close(cx)
                    .map_err(Error::from_sink_error)?;
                if closed.is_ready() {
                    *this.rx_only = true;
                }
            } else {
                let _ = transport
                    .as_mut()
                    .poll_flush(cx)
                    .map_err(Error::from_sink_error)?;
            }
        }

        // Phase 3: only poll the stream while a response is actually expected.
        while this.in_flight.load(atomic::Ordering::Acquire) != 0 {
            match ready!(transport.as_mut().try_poll_next(cx))
                .transpose()
                .map_err(Error::from_stream_error)?
            {
                Some(r) => {
                    let pending = this
                        .responses
                        .pop_front()
                        .expect("got a request with no sender?");
                    event!(pending.span, Level::TRACE, "response arrived; forwarding");
                    // A failed send only means the caller stopped waiting; ignore it.
                    let sender = pending.tx;
                    let _ = sender.send(ClientResponse {
                        response: r,
                        #[cfg(feature = "tracing")]
                        span: pending.span,
                    });
                    this.in_flight.fetch_sub(1, atomic::Ordering::AcqRel);
                }
                None => {
                    // The stream ended while responses were still owed.
                    return Poll::Ready(Err(E::from(Error::BrokenTransportRecv(None))));
                }
            }
        }

        // Phase 4: nothing more to send or receive.
        if *this.finish && this.in_flight.load(atomic::Ordering::Acquire) == 0 {
            if !*this.rx_only {
                // The sink has not been fully closed yet; finish closing it first.
                ready!(transport.poll_close(cx)).map_err(Error::from_sink_error)?;
            }
            return Poll::Ready(Ok(()));
        }

        Poll::Pending
    }
}
impl<T, E, Request> Service<Request> for Client<T, E, Request>
where
T: Sink<Request> + TryStream,
E: From<Error<T, Request>>,
E: 'static + Send,
Request: 'static + Send,
T: 'static,
T::Ok: 'static + Send,
{
type Response = T::Ok;
type Error = E;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), E>> {
Poll::Ready(ready!(self.mediator.poll_ready(cx)).map_err(|_| E::from(Error::ClientDropped)))
}
fn call(&mut self, req: Request) -> Self::Future {
let (tx, rx) = tokio_sync::oneshot::channel();
#[cfg(feature = "tracing")]
let span = tracing::Span::current();
#[cfg(not(feature = "tracing"))]
let span = ();
event!(span, Level::TRACE, "issuing request");
let req = ClientRequest { req, span, res: tx };
let r = self.mediator.try_send(req);
Box::pin(async move {
match r {
Ok(()) => match rx.await {
Ok(r) => {
event!(r.span, tracing::Level::TRACE, "response returned");
Ok(r.response)
}
Err(_) => Err(E::from(Error::ClientDropped)),
},
Err(_) => Err(E::from(Error::TransportFull)),
}
})
}
}
/// A `Client`'s load is the number of requests the worker has sent to the
/// transport that have not yet been answered.
impl<T: Sink<Request> + TryStream, E, Request> tower_load::Load for Client<T, E, Request> {
    type Metric = usize;

    fn load(&self) -> usize {
        self.in_flight.load(atomic::Ordering::Acquire)
    }
}
impl<T> fmt::Display for SpawnError<T>
where
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
SpawnError::SpawnFailed => write!(f, "error spawning multiplex client"),
SpawnError::Inner(ref te) => {
write!(f, "error making new multiplex transport: {:?}", te)
}
}
}
}
impl<T> error::Error for SpawnError<T>
where
T: error::Error,
{
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
SpawnError::SpawnFailed => None,
SpawnError::Inner(ref te) => Some(te),
}
}
fn description(&self) -> &str {
"error creating new multiplex client"
}
}