use crate::{Multiaddr, Transport, transport::{TransportError, ListenerEvent}};
use futures::{prelude::*, task::Context, task::Poll};
use log::debug;
use smallvec::SmallVec;
use std::{collections::VecDeque, fmt, pin::Pin};
/// Manages a set of listeners created from a single transport and merges their events
/// into one stream of [`ListenersEvent`]s.
///
/// Every listener is identified by a [`ListenerId`] returned from `listen_on`. The
/// addresses each listener has reported (and not yet expired or lost by closing) are
/// tracked so they can be queried through `listen_addrs`.
pub struct ListenersStream<TTrans>
where
TTrans: Transport,
{
/// Transport used to create new listeners; a clone of it is consumed by every `listen_on` call.
transport: TTrans,
/// Active listeners. `poll` takes them from the back and re-inserts them at the front,
/// so they are polled in round-robin order.
listeners: VecDeque<Pin<Box<Listener<TTrans>>>>,
/// Identifier that the next successful `listen_on` call will hand out.
next_id: ListenerId
}
/// Opaque identifier of a listener inside a `ListenersStream`.
///
/// A given `ListenersStream` allocates these sequentially, starting at 1, and never hands
/// out the same value twice.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ListenerId(u64);
/// A single transport listener plus the bookkeeping `ListenersStream` keeps for it.
#[pin_project::pin_project]
#[derive(Debug)]
struct Listener<TTrans>
where
TTrans: Transport,
{
/// Identifier returned to the caller by `ListenersStream::listen_on`.
id: ListenerId,
/// The listener stream produced by the transport. Structurally pinned so that the
/// pin projection yields a `Pin<&mut _>` suitable for polling it in place.
#[pin]
listener: TTrans::Listener,
/// Addresses reported through `NewAddress` events and not yet expired.
addresses: SmallVec<[Multiaddr; 4]>
}
/// Event produced by `ListenersStream`.
pub enum ListenersEvent<TTrans>
where
TTrans: Transport,
{
/// A listener reported a new address it is listening on.
NewAddress {
/// The listener that reported the address.
listener_id: ListenerId,
/// The reported address.
listen_addr: Multiaddr
},
/// A listener reported that one of its addresses is no longer in use.
AddressExpired {
/// The listener that reported the expiration.
listener_id: ListenerId,
/// The address that expired.
listen_addr: Multiaddr
},
/// A listener produced an incoming connection.
Incoming {
/// The listener that produced the connection.
listener_id: ListenerId,
/// The transport's `ListenerUpgrade` for this connection, forwarded unchanged.
upgrade: TTrans::ListenerUpgrade,
/// Local address the connection arrived on, as reported by the transport.
local_addr: Multiaddr,
/// Remote address as reported by the transport (its `remote_addr`).
send_back_addr: Multiaddr,
},
/// A listener's stream ended (`reason` is `Ok(())`) or returned a fatal error
/// (`reason` is `Err(_)`). The listener has been removed from the stream.
Closed {
/// The listener that closed.
listener_id: ListenerId,
/// Addresses the listener still had registered when it closed.
addresses: Vec<Multiaddr>,
/// Why the listener closed.
reason: Result<(), TTrans::Error>,
},
/// A listener reported a non-fatal error; the listener stays active.
Error {
/// The listener that reported the error.
listener_id: ListenerId,
/// The reported error.
error: TTrans::Error,
}
}
impl<TTrans> ListenersStream<TTrans>
where
    TTrans: Transport,
{
    /// Creates an empty stream that will create listeners from `transport`.
    ///
    /// Listener IDs are allocated starting at 1.
    pub fn new(transport: TTrans) -> Self {
        ListenersStream {
            transport,
            listeners: VecDeque::new(),
            next_id: ListenerId(1),
        }
    }

    /// Same as [`ListenersStream::new`], but pre-allocates room for `capacity` listeners.
    pub fn with_capacity(transport: TTrans, capacity: usize) -> Self {
        ListenersStream {
            transport,
            listeners: VecDeque::with_capacity(capacity),
            next_id: ListenerId(1),
        }
    }

    /// Starts listening on `addr` using a clone of the underlying transport.
    ///
    /// On success the new listener is appended to the set and its freshly allocated
    /// [`ListenerId`] is returned. The ID counter only advances when the transport
    /// accepts the address.
    ///
    /// # Errors
    ///
    /// Propagates the `TransportError` returned by the transport's `listen_on`.
    pub fn listen_on(&mut self, addr: Multiaddr) -> Result<ListenerId, TransportError<TTrans::Error>>
    where
        TTrans: Clone,
    {
        let listener = self.transport.clone().listen_on(addr)?;
        let id = self.next_id;
        self.listeners.push_back(Box::pin(Listener {
            id,
            listener,
            addresses: SmallVec::new(),
        }));
        self.next_id = ListenerId(id.0 + 1);
        Ok(id)
    }

    /// Removes and drops the listener with the given `id`.
    ///
    /// No `ListenersEvent::Closed` is emitted for a listener removed this way.
    /// Returns `Err(())` if no listener with that ID is present.
    pub fn remove_listener(&mut self, id: ListenerId) -> Result<(), ()> {
        if let Some(i) = self.listeners.iter().position(|l| l.id == id) {
            self.listeners.remove(i);
            Ok(())
        } else {
            Err(())
        }
    }

    /// Returns a reference to the transport used to create listeners.
    pub fn transport(&self) -> &TTrans {
        &self.transport
    }

    /// Iterates over every address currently registered by any listener.
    pub fn listen_addrs(&self) -> impl Iterator<Item = &Multiaddr> {
        self.listeners.iter().flat_map(|l| l.addresses.iter())
    }

    /// Polls the listeners in round-robin order and returns the first event produced.
    ///
    /// Each listener is polled at most once per call. Listeners are taken from the back of
    /// the queue and put back at the front, so a listener that just produced an event is
    /// polled last on the next call.
    ///
    /// A listener whose stream ends or yields an `Err` is dropped and reported as
    /// [`ListenersEvent::Closed`]. An in-band `ListenerEvent::Error` is not fatal: it is
    /// reported as [`ListenersEvent::Error`] and the listener stays active.
    ///
    /// Returns `Poll::Pending` once every listener has returned `Pending`. With no
    /// listeners at all it also returns `Pending` without registering any waker, so the
    /// caller must poll again after calling `listen_on`.
    pub fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<ListenersEvent<TTrans>> {
        let mut remaining = self.listeners.len();
        while let Some(mut listener) = self.listeners.pop_back() {
            let mut listener_project = listener.as_mut().project();
            match TryStream::try_poll_next(listener_project.listener.as_mut(), cx) {
                Poll::Pending => {
                    self.listeners.push_front(listener);
                    remaining -= 1;
                    // Every listener has now been polled once in this call.
                    if remaining == 0 {
                        break;
                    }
                }
                Poll::Ready(Some(Ok(ListenerEvent::Upgrade { upgrade, local_addr, remote_addr }))) => {
                    let id = *listener_project.id;
                    self.listeners.push_front(listener);
                    return Poll::Ready(ListenersEvent::Incoming {
                        listener_id: id,
                        upgrade,
                        local_addr,
                        send_back_addr: remote_addr,
                    });
                }
                Poll::Ready(Some(Ok(ListenerEvent::NewAddress(a)))) => {
                    // Record each address once (a single scan of the list), but still
                    // forward the event so the caller sees exactly what the transport
                    // reported.
                    if listener_project.addresses.contains(&a) {
                        debug!("Transport has reported address {} multiple times", a)
                    } else {
                        listener_project.addresses.push(a.clone());
                    }
                    let id = *listener_project.id;
                    self.listeners.push_front(listener);
                    return Poll::Ready(ListenersEvent::NewAddress {
                        listener_id: id,
                        listen_addr: a,
                    });
                }
                Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(a)))) => {
                    listener_project.addresses.retain(|x| x != &a);
                    let id = *listener_project.id;
                    self.listeners.push_front(listener);
                    return Poll::Ready(ListenersEvent::AddressExpired {
                        listener_id: id,
                        listen_addr: a,
                    });
                }
                Poll::Ready(Some(Ok(ListenerEvent::Error(error)))) => {
                    // Non-fatal: the listener goes back into the queue.
                    let id = *listener_project.id;
                    self.listeners.push_front(listener);
                    return Poll::Ready(ListenersEvent::Error {
                        listener_id: id,
                        error,
                    });
                }
                // In both terminal cases below the listener is not re-queued; it is dropped
                // when this scope ends.
                Poll::Ready(None) => {
                    return Poll::Ready(ListenersEvent::Closed {
                        listener_id: *listener_project.id,
                        addresses: listener_project.addresses.drain(..).collect(),
                        reason: Ok(()),
                    });
                }
                Poll::Ready(Some(Err(err))) => {
                    return Poll::Ready(ListenersEvent::Closed {
                        listener_id: *listener_project.id,
                        addresses: listener_project.addresses.drain(..).collect(),
                        reason: Err(err),
                    });
                }
            }
        }
        Poll::Pending
    }
}
impl<TTrans> Stream for ListenersStream<TTrans>
where
    TTrans: Transport,
{
    type Item = ListenersEvent<TTrans>;

    /// Delegates to the inherent `ListenersStream::poll`.
    ///
    /// The stream never yields `None`: closed listeners are reported as
    /// `ListenersEvent::Closed` items, and an empty set of listeners yields `Pending`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Self::poll(self, cx) {
            Poll::Ready(event) => Poll::Ready(Some(event)),
            Poll::Pending => Poll::Pending,
        }
    }
}
// `ListenersStream` never relies on its own pinning. Every listener already lives behind
// its own `Pin<Box<_>>`, and `transport` is only ever used through `&self` or `clone()`.
// Moving the stream is therefore fine even when `TTrans` is not `Unpin`. This impl also
// lets `poll` reach `self.listeners` mutably through `Pin<&mut Self>`.
impl<TTrans: Transport> Unpin for ListenersStream<TTrans> {}
impl<TTrans> fmt::Debug for ListenersStream<TTrans>
where
    TTrans: Transport + fmt::Debug,
{
    /// Prints the transport and the flattened list of addresses currently registered by
    /// all listeners; the listener streams themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let addrs = self.listen_addrs().collect::<Vec<_>>();
        let mut out = f.debug_struct("ListenersStream");
        out.field("transport", &self.transport);
        out.field("listen_addrs", &addrs);
        out.finish()
    }
}
impl<TTrans> fmt::Debug for ListenersEvent<TTrans>
where
    TTrans: Transport,
    TTrans::Error: fmt::Debug,
{
    /// Formats the event as a struct named after its variant.
    ///
    /// For `Incoming`, only `upgrade` is skipped, because `TTrans::ListenerUpgrade` is not
    /// required to implement `Debug`. Both the local and the remote (`send_back_addr`)
    /// addresses are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match self {
            ListenersEvent::NewAddress { listener_id, listen_addr } => f
                .debug_struct("ListenersEvent::NewAddress")
                .field("listener_id", listener_id)
                .field("listen_addr", listen_addr)
                .finish(),
            ListenersEvent::AddressExpired { listener_id, listen_addr } => f
                .debug_struct("ListenersEvent::AddressExpired")
                .field("listener_id", listener_id)
                .field("listen_addr", listen_addr)
                .finish(),
            ListenersEvent::Incoming { listener_id, local_addr, send_back_addr, .. } => f
                .debug_struct("ListenersEvent::Incoming")
                .field("listener_id", listener_id)
                .field("local_addr", local_addr)
                .field("send_back_addr", send_back_addr)
                .finish(),
            ListenersEvent::Closed { listener_id, addresses, reason } => f
                .debug_struct("ListenersEvent::Closed")
                .field("listener_id", listener_id)
                .field("addresses", addresses)
                .field("reason", reason)
                .finish(),
            ListenersEvent::Error { listener_id, error } => f
                .debug_struct("ListenersEvent::Error")
                .field("listener_id", listener_id)
                .field("error", error)
                .finish(),
        }
    }
}
// Tests for `ListenersStream`: incoming connections and how in-band listener errors
// differ from stream-level errors.
#[cfg(test)]
mod tests {
use super::*;
use crate::transport;
// A dial to the reported listen address must surface as an `Incoming` event whose
// `local_addr` is that address and whose `send_back_addr` differs from it.
#[test]
fn incoming_event() {
async_std::task::block_on(async move {
let mem_transport = transport::MemoryTransport::default();
let mut listeners = ListenersStream::new(mem_transport);
listeners.listen_on("/memory/0".parse().unwrap()).unwrap();
// The first event is expected to announce the concrete listen address.
let address = {
let event = listeners.next().await.unwrap();
if let ListenersEvent::NewAddress { listen_addr, .. } = event {
listen_addr
} else {
panic!("Was expecting the listen address to be reported")
}
};
let address2 = address.clone();
// Dial from a separate task so the listener side can be polled concurrently.
// `mem_transport` is used again here after being passed to `ListenersStream::new`,
// which presumably relies on `MemoryTransport` being `Copy` — verify if this changes.
async_std::task::spawn(async move {
mem_transport.dial(address2).unwrap().await.unwrap();
});
match listeners.next().await.unwrap() {
ListenersEvent::Incoming { local_addr, send_back_addr, .. } => {
assert_eq!(local_addr, address);
assert!(send_back_addr != address);
},
_ => panic!()
}
});
}
// A listener that only ever yields `Ok(ListenerEvent::Error(_))` must keep producing
// `ListenersEvent::Error` events instead of being closed.
#[test]
fn listener_event_error_isnt_fatal() {
// Minimal transport whose listener stream endlessly yields in-band errors.
#[derive(Clone)]
struct DummyTrans;
impl transport::Transport for DummyTrans {
type Output = ();
type Error = std::io::Error;
type Listener = Pin<Box<dyn Stream<Item = Result<ListenerEvent<Self::ListenerUpgrade, std::io::Error>, std::io::Error>>>>;
type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>>>>;
type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>>>>;
fn listen_on(self, _: Multiaddr) -> Result<Self::Listener, transport::TransportError<Self::Error>> {
Ok(Box::pin(stream::unfold((), |()| async move {
Some((Ok(ListenerEvent::Error(std::io::Error::from(std::io::ErrorKind::Other))), ()))
})))
}
fn dial(self, _: Multiaddr) -> Result<Self::Dial, transport::TransportError<Self::Error>> {
panic!()
}
fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option<Multiaddr> { None }
}
async_std::task::block_on(async move {
let transport = DummyTrans;
let mut listeners = ListenersStream::new(transport);
listeners.listen_on("/memory/0".parse().unwrap()).unwrap();
// Several consecutive errors prove the listener was re-queued each time.
for _ in 0..10 {
match listeners.next().await.unwrap() {
ListenersEvent::Error { .. } => {},
_ => panic!()
}
}
});
}
// A listener whose stream yields `Err(_)` must be reported as `Closed` immediately.
#[test]
fn listener_error_is_fatal() {
// Minimal transport whose listener stream yields stream-level errors.
#[derive(Clone)]
struct DummyTrans;
impl transport::Transport for DummyTrans {
type Output = ();
type Error = std::io::Error;
type Listener = Pin<Box<dyn Stream<Item = Result<ListenerEvent<Self::ListenerUpgrade, std::io::Error>, std::io::Error>>>>;
type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>>>>;
type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>>>>;
fn listen_on(self, _: Multiaddr) -> Result<Self::Listener, transport::TransportError<Self::Error>> {
Ok(Box::pin(stream::unfold((), |()| async move {
Some((Err(std::io::Error::from(std::io::ErrorKind::Other)), ()))
})))
}
fn dial(self, _: Multiaddr) -> Result<Self::Dial, transport::TransportError<Self::Error>> {
panic!()
}
fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option<Multiaddr> { None }
}
async_std::task::block_on(async move {
let transport = DummyTrans;
let mut listeners = ListenersStream::new(transport);
listeners.listen_on("/memory/0".parse().unwrap()).unwrap();
match listeners.next().await.unwrap() {
ListenersEvent::Closed { .. } => {},
_ => panic!()
}
});
}
}