use std::pin::Pin;
use std::task::{Context, Poll};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::stream::{Peekable, Stream, StreamExt};
use futures::{Future, FutureExt};
use log::{debug, warn};
use crate::error::*;
use crate::xfer::{
DnsRequest, DnsRequestSender, DnsRequestStreamHandle, DnsResponse, OneshotDnsRequest,
};
/// Drives a [`DnsRequestSender`] connection, handing it requests that arrive on an
/// unbounded channel.
///
/// This type is itself a future; requests are only sent while it is being polled.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchange<S, R>
where
    S: DnsRequestSender<DnsResponseFuture = R>,
    R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
    /// The connection that requests are sent through.
    io_stream: S,
    /// Incoming requests waiting to be sent on `io_stream`.
    outbound_messages: Peekable<UnboundedReceiver<OneshotDnsRequest<R>>>,
}
impl<S, R> DnsExchange<S, R>
where
    S: DnsRequestSender<DnsResponseFuture = R>,
    R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
    /// Wraps an already-established `stream` in a new exchange.
    ///
    /// Returns the exchange together with a handle whose requests are delivered to it over a
    /// freshly created unbounded channel.
    pub fn from_stream(stream: S) -> (Self, DnsRequestStreamHandle<R>) {
        let (request_tx, request_rx) = unbounded();
        let handle = DnsRequestStreamHandle::<R>::new(request_tx);
        let exchange = Self::from_stream_with_receiver(stream, request_rx);
        (exchange, handle)
    }

    /// Builds an exchange from a connected `stream` and the receiving half of a request
    /// channel; the caller keeps whatever sends into `receiver`.
    pub fn from_stream_with_receiver(
        stream: S,
        receiver: UnboundedReceiver<OneshotDnsRequest<R>>,
    ) -> Self {
        let io_stream = stream;
        let outbound_messages = receiver.peekable();
        DnsExchange {
            io_stream,
            outbound_messages,
        }
    }

    /// Starts connecting via `connect_future`.
    ///
    /// Returns a future that yields the [`DnsExchange`] once connected, plus a handle that can
    /// queue requests right away; the receiving side is held until the connection resolves.
    pub fn connect<F>(connect_future: F) -> (DnsExchangeConnect<F, S, R>, DnsRequestStreamHandle<R>)
    where
        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    {
        let (request_tx, request_rx) = unbounded();
        let pending_exchange = DnsExchangeConnect::connect(connect_future, request_rx);
        let handle = DnsRequestStreamHandle::<R>::new(request_tx);
        (pending_exchange, handle)
    }

    /// Splits `self` into disjoint mutable borrows of the stream and the request queue so both
    /// can be polled within the same loop.
    fn pollable_split(
        &mut self,
    ) -> (
        &mut S,
        &mut Peekable<UnboundedReceiver<OneshotDnsRequest<R>>>,
    ) {
        let DnsExchange {
            io_stream,
            outbound_messages,
        } = self;
        (io_stream, outbound_messages)
    }
}
impl<S, R> Future for DnsExchange<S, R>
where
S: DnsRequestSender<DnsResponseFuture = R>,
R: Future<Output = Result<DnsResponse, ProtoError>> + 'static + Send + Unpin,
{
type Output = Result<(), ProtoError>;
#[allow(clippy::unused_unit)]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let (io_stream, outbound_messages) = self.pollable_split();
let mut io_stream = Pin::new(io_stream);
let mut outbound_messages = Pin::new(outbound_messages);
loop {
match io_stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(()))) => (),
Poll::Pending => {
if io_stream.is_shutdown() {
return Poll::Pending;
}
()
}
Poll::Ready(None) => {
debug!("io_stream is done, shutting down");
return Poll::Ready(Ok(()));
}
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)),
}
match outbound_messages.as_mut().poll_next(cx) {
Poll::Ready(Some(dns_request)) => {
let (dns_request, serial_response): (DnsRequest, _) = dns_request.unwrap();
debug!("sending message via: {}", io_stream);
match serial_response.send_response(io_stream.send_message(dns_request, cx)) {
Ok(()) => (),
Err(_) => {
warn!("failed to associate send_message response to the sender");
return Poll::Ready(Err(
"failed to associate send_message response to the sender".into(),
));
}
}
}
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
debug!("all handles closed, shutting down: {}", io_stream);
io_stream.shutdown();
}
}
}
}
}
/// A future that resolves to a [`DnsExchange`] once `F` has produced a connected stream.
///
/// If the connection attempt fails, requests arriving on the channel are answered with an
/// error response until the channel is exhausted. The future then resolves to the
/// connection error.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeConnect<F, S, R>(DnsExchangeConnectInner<F, S, R>)
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
    S: DnsRequestSender<DnsResponseFuture = R>,
    R: Future<Output = Result<DnsResponse, ProtoError>> + 'static + Send + Unpin;
impl<F, S, R> DnsExchangeConnect<F, S, R>
where
    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    S: DnsRequestSender<DnsResponseFuture = R>,
    R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
    /// Creates the future in its initial `Connecting` state.
    ///
    /// `outbound_messages` is held until `connect_future` resolves, so requests can already be
    /// queued before the connection exists.
    fn connect(
        connect_future: F,
        outbound_messages: UnboundedReceiver<OneshotDnsRequest<R>>,
    ) -> Self {
        let held_receiver = Some(outbound_messages);
        let initial_state = DnsExchangeConnectInner::Connecting {
            connect_future,
            outbound_messages: held_receiver,
        };
        Self(initial_state)
    }
}
impl<F, S, R> Future for DnsExchangeConnect<F, S, R>
where
F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
S: DnsRequestSender<DnsResponseFuture = R>,
R: Future<Output = Result<DnsResponse, ProtoError>> + 'static + Send + Unpin,
{
type Output = Result<DnsExchange<S, R>, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.0.poll_unpin(cx)
}
}
/// State machine behind [`DnsExchangeConnect`].
enum DnsExchangeConnectInner<F, S, R>
where
    F: Future<Output = Result<S, ProtoError>> + Send + 'static,
    S: DnsRequestSender<DnsResponseFuture = R>,
    R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
    /// Waiting on `connect_future`.
    ///
    /// The request receiver is held as `Some` until the outcome is known and is then moved
    /// out with `take()`.
    Connecting {
        connect_future: F,
        outbound_messages: Option<UnboundedReceiver<OneshotDnsRequest<R>>>,
    },
    /// The connection failed with `error`; requests still arriving are answered with it.
    FailAll {
        error: ProtoError,
        outbound_messages: UnboundedReceiver<OneshotDnsRequest<R>>,
    },
}
impl<F, S, R> Future for DnsExchangeConnectInner<F, S, R>
where
    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
    S: DnsRequestSender<DnsResponseFuture = R>,
    R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
{
    type Output = Result<DnsExchange<S, R>, ProtoError>;

    /// Advances the connect state machine.
    ///
    /// In `Connecting`, polls `connect_future`:
    /// - on success, the held receiver is handed to a new [`DnsExchange`];
    /// - on failure, the state becomes `FailAll` and is polled right away.
    ///
    /// In `FailAll`, every request received is answered with `S::error_response(error)`.
    /// Once the receiver yields `None`, the connection error is returned.
    ///
    /// # Panics
    ///
    /// Panics with "cannot poll after complete" if the receiver held in `Connecting` has
    /// already been taken. This can only happen after the future has resolved successfully.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        loop {
            let failed_state = match &mut *self {
                DnsExchangeConnectInner::Connecting {
                    connect_future,
                    outbound_messages,
                } => {
                    let error = match Pin::new(connect_future).poll(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Ok(stream)) => {
                            debug!("connection established: {}", stream);
                            let receiver = outbound_messages
                                .take()
                                .expect("cannot poll after complete");
                            return Poll::Ready(Ok(DnsExchange::from_stream_with_receiver(
                                stream, receiver,
                            )));
                        }
                        Poll::Ready(Err(error)) => error,
                    };

                    debug!("stream errored while connecting: {:?}", error);
                    DnsExchangeConnectInner::FailAll {
                        error,
                        outbound_messages: outbound_messages
                            .take()
                            .expect("cannot poll after complete"),
                    }
                }
                DnsExchangeConnectInner::FailAll {
                    error,
                    outbound_messages,
                } => loop {
                    let request = match outbound_messages.poll_next_unpin(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(None) => return Poll::Ready(Err(error.clone())),
                        Poll::Ready(Some(request)) => request,
                    };
                    // Best effort: a requester that has gone away is simply ignored.
                    let response = S::error_response(error.clone());
                    let (_, responder) = request.unwrap();
                    let _ = responder.send_response(response);
                },
            };

            *self = failed_state;
        }
    }
}