use std::{
collections::HashSet,
fmt,
io,
net::{IpAddr, SocketAddr},
ops::Deref,
sync::{
atomic::{AtomicUsize, Ordering::*},
Arc,
},
time::Duration,
};
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
use tokio::{
io::split,
net::{TcpListener, TcpStream},
sync::oneshot,
task::JoinHandle,
time::timeout,
};
use tracing::*;
use crate::{
connections::{Connection, ConnectionSide, Connections},
protocols::{Protocol, Protocols},
Config,
KnownPeers,
Stats,
};
/// Process-wide counter used by `Tcp::new` to name nodes whose `Config::name` is unset
/// (the first such node is named "0", the next "1", and so on).
static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0);
/// A handle to a TCP stack.
///
/// Cloning is cheap: all clones share the same `InnerTcp` behind an `Arc`.
#[derive(Clone)]
pub struct Tcp(Arc<InnerTcp>);
impl Deref for Tcp {
type Target = Arc<InnerTcp>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// The shared state behind every clone of a `Tcp` handle.
#[doc(hidden)]
pub struct InnerTcp {
    /// Tracing span used as the parent of this node's log events.
    span: Span,
    /// The configuration the stack was created with (its `name` is always populated by `Tcp::new`).
    config: Config,
    /// The listener's address; set at most once, by `Tcp::enable_listener`.
    listening_addr: OnceCell<SocketAddr>,
    /// Optional protocol handlers (handshake, reading, writing, on_connect, disconnect).
    pub(crate) protocols: Protocols,
    /// Addresses with a connection attempt in progress (outbound or inbound).
    connecting: Mutex<HashSet<SocketAddr>>,
    /// Fully established connections.
    connections: Connections,
    /// Peers this node has interacted with.
    known_peers: KnownPeers,
    /// Node-level statistics.
    stats: Stats,
    /// Handles of background tasks; `enable_listener` pushes the listening task here and
    /// `shut_down` aborts them.
    pub(crate) tasks: Mutex<Vec<JoinHandle<()>>>,
}
impl Tcp {
    /// Creates a new TCP stack from `config`.
    ///
    /// When `config.name` is `None`, the node is named after the process-wide
    /// `SEQUENTIAL_NODE_ID` counter. The resulting name is used to create the node's tracing span.
    pub fn new(mut config: Config) -> Self {
        let name = config
            .name
            .get_or_insert_with(|| SEQUENTIAL_NODE_ID.fetch_add(1, Relaxed).to_string());
        let span = crate::helpers::create_span(name.as_str());

        let inner = InnerTcp {
            span,
            config,
            listening_addr: Default::default(),
            protocols: Default::default(),
            connecting: Default::default(),
            connections: Default::default(),
            known_peers: Default::default(),
            stats: Default::default(),
            tasks: Default::default(),
        };
        let tcp = Tcp(Arc::new(inner));

        debug!(parent: tcp.span(), "The node is ready");
        tcp
    }

    /// Returns the node's name.
    #[inline]
    pub fn name(&self) -> &str {
        // `Tcp::new` always populates the name, so this cannot panic.
        self.config().name.as_deref().unwrap()
    }

    /// Returns the stack's configuration.
    #[inline]
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Returns the listener's address, or an `AddrNotAvailable` error if `enable_listener`
    /// has not completed successfully.
    pub fn listening_addr(&self) -> io::Result<SocketAddr> {
        match self.listening_addr.get() {
            Some(&addr) => Ok(addr),
            None => Err(io::ErrorKind::AddrNotAvailable.into()),
        }
    }

    /// Returns `true` if a connection with `addr` is established.
    pub fn is_connected(&self, addr: SocketAddr) -> bool {
        self.connections.is_connected(addr)
    }

    /// Returns `true` if a connection attempt with `addr` is in progress.
    pub fn is_connecting(&self, addr: SocketAddr) -> bool {
        let pending = self.connecting.lock();
        pending.contains(&addr)
    }

    /// Returns the number of established connections.
    pub fn num_connected(&self) -> usize {
        self.connections.num_connected()
    }

    /// Returns the number of connection attempts in progress.
    pub fn num_connecting(&self) -> usize {
        let pending = self.connecting.lock();
        pending.len()
    }

    /// Returns the addresses of all established connections.
    pub fn connected_addrs(&self) -> Vec<SocketAddr> {
        self.connections.addrs()
    }

    /// Returns the addresses of all connection attempts in progress.
    pub fn connecting_addrs(&self) -> Vec<SocketAddr> {
        let pending = self.connecting.lock();
        pending.iter().cloned().collect()
    }

    /// Returns the collection of known peers.
    #[inline]
    pub fn known_peers(&self) -> &KnownPeers {
        &self.known_peers
    }

    /// Returns the node's statistics.
    #[inline]
    pub fn stats(&self) -> &Stats {
        &self.stats
    }

    /// Returns the node's tracing span.
    #[inline]
    pub fn span(&self) -> &Span {
        &self.span
    }

    /// Shuts the stack down: aborts the first registered task, disconnects from every
    /// connected peer, then aborts all remaining tasks.
    ///
    /// The first task is expected to be the listener, so no new inbound connections are
    /// accepted while peers are being disconnected.
    /// NOTE(review): this holds only if `enable_listener` ran before any other task was pushed
    /// onto `tasks`; otherwise a different task is aborted first.
    pub async fn shut_down(&self) {
        debug!(parent: self.span(), "Shutting down the TCP stack");

        // The lock guard is a temporary, so it's released before any `.await` below.
        let mut handles = std::mem::take(&mut *self.tasks.lock());

        if !handles.is_empty() {
            handles.remove(0).abort();
        }

        for peer in self.connected_addrs() {
            self.disconnect(peer).await;
        }

        handles.iter().for_each(|handle| handle.abort());
    }
}
impl Tcp {
pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
if let Ok(listening_addr) = self.listening_addr() {
if addr == listening_addr || self.is_self_connect(addr) {
error!(parent: self.span(), "Attempted to self-connect ({addr})");
return Err(io::ErrorKind::AddrInUse.into());
}
}
if !self.can_add_connection() {
error!(parent: self.span(), "Too many connections; refusing to connect to {addr}");
return Err(io::ErrorKind::ConnectionRefused.into());
}
if self.is_connected(addr) {
warn!(parent: self.span(), "Already connected to {addr}");
return Err(io::ErrorKind::AlreadyExists.into());
}
if !self.connecting.lock().insert(addr) {
warn!(parent: self.span(), "Already connecting to {addr}");
return Err(io::ErrorKind::AlreadyExists.into());
}
let timeout_duration = Duration::from_millis(self.config().connection_timeout_ms.into());
let res = if let Some(listen_ip) = self.config().listener_ip {
let sock =
if listen_ip.is_ipv4() { tokio::net::TcpSocket::new_v4()? } else { tokio::net::TcpSocket::new_v6()? };
sock.bind(SocketAddr::new(listen_ip, 0))?;
timeout(timeout_duration, sock.connect(addr)).await
} else {
timeout(timeout_duration, TcpStream::connect(addr)).await
};
let stream = match res {
Ok(Ok(stream)) => Ok(stream),
Ok(err) => {
self.connecting.lock().remove(&addr);
err
}
Err(err) => {
self.connecting.lock().remove(&addr);
error!("connection timeout error: {}", err);
Err(io::ErrorKind::TimedOut.into())
}
}?;
let ret = self.adapt_stream(stream, addr, ConnectionSide::Initiator).await;
if let Err(ref e) = ret {
self.connecting.lock().remove(&addr);
self.known_peers().register_failure(addr);
error!(parent: self.span(), "Unable to initiate a connection with {addr}: {e}");
}
ret
}
pub async fn disconnect(&self, addr: SocketAddr) -> bool {
if let Some(handler) = self.protocols.disconnect.get() {
if self.is_connected(addr) {
let (sender, receiver) = oneshot::channel();
handler.trigger((addr, sender));
let _ = receiver.await; }
}
let conn = self.connections.remove(addr);
if let Some(ref conn) = conn {
debug!(parent: self.span(), "Disconnecting from {}", conn.addr());
for task in conn.tasks.iter().rev() {
task.abort();
}
if conn.side() == ConnectionSide::Initiator {
self.known_peers().remove(conn.addr());
}
debug!(parent: self.span(), "Disconnected from {}", conn.addr());
} else {
warn!(parent: self.span(), "Failed to disconnect, was not connected to {addr}");
}
conn.is_some()
}
}
impl Tcp {
pub async fn enable_listener(&self) -> io::Result<SocketAddr> {
let listener_ip =
self.config().listener_ip.expect("Tcp::enable_listener was called, but Config::listener_ip is not set");
let listener = self.create_listener(listener_ip).await?;
let port = listener.local_addr()?.port();
let listening_addr = (listener_ip, port).into();
self.listening_addr.set(listening_addr).expect("The node's listener was started more than once");
let (tx, rx) = oneshot::channel();
let tcp = self.clone();
let listening_task = tokio::spawn(async move {
trace!(parent: tcp.span(), "Spawned the listening task");
tx.send(()).unwrap(); loop {
match listener.accept().await {
Ok((stream, addr)) => tcp.handle_connection(stream, addr),
Err(e) => error!(parent: tcp.span(), "Failed to accept a connection: {e}"),
}
}
});
self.tasks.lock().push(listening_task);
let _ = rx.await;
debug!(parent: self.span(), "Listening on {listening_addr}");
Ok(listening_addr)
}
async fn create_listener(&self, listener_ip: IpAddr) -> io::Result<TcpListener> {
debug!("Creating a TCP listener on {listener_ip}...");
let listener = if let Some(port) = self.config().desired_listening_port {
let desired_listening_addr = SocketAddr::new(listener_ip, port);
match TcpListener::bind(desired_listening_addr).await {
Ok(listener) => listener,
Err(e) => {
if self.config().allow_random_port {
warn!(
parent: self.span(),
"Trying any listening port, as the desired port is unavailable: {e}"
);
let random_available_addr = SocketAddr::new(listener_ip, 0);
TcpListener::bind(random_available_addr).await?
} else {
error!(parent: self.span(), "The desired listening port is unavailable: {e}");
return Err(e);
}
}
}
} else if self.config().allow_random_port {
let random_available_addr = SocketAddr::new(listener_ip, 0);
TcpListener::bind(random_available_addr).await?
} else {
panic!("As 'listener_ip' is set, either 'desired_listening_port' or 'allow_random_port' must be set");
};
Ok(listener)
}
fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) {
debug!(parent: self.span(), "Received a connection from {addr}");
if !self.can_add_connection() || self.is_self_connect(addr) {
debug!(parent: self.span(), "Rejecting the connection from {addr}");
return;
}
self.connecting.lock().insert(addr);
let tcp = self.clone();
tokio::spawn(async move {
if let Err(e) = tcp.adapt_stream(stream, addr, ConnectionSide::Responder).await {
tcp.connecting.lock().remove(&addr);
tcp.known_peers().register_failure(addr);
error!(parent: tcp.span(), "Failed to connect with {addr}: {e}");
}
});
}
fn is_self_connect(&self, addr: SocketAddr) -> bool {
let listening_addr = self.listening_addr().unwrap();
match listening_addr.ip().is_loopback() {
true => listening_addr.port() == addr.port(),
false => listening_addr.ip() == addr.ip(),
}
}
fn can_add_connection(&self) -> bool {
let num_connected = self.num_connected();
let limit = self.config.max_connections as usize;
if num_connected >= limit {
warn!(parent: self.span(), "Maximum number of active connections ({limit}) reached");
false
} else if num_connected + self.num_connecting() >= limit {
warn!(parent: self.span(), "Maximum number of active & pending connections ({limit}) reached");
false
} else {
true
}
}
async fn adapt_stream(&self, stream: TcpStream, peer_addr: SocketAddr, own_side: ConnectionSide) -> io::Result<()> {
self.known_peers.add(peer_addr);
if own_side == ConnectionSide::Initiator {
if let Ok(addr) = stream.local_addr() {
debug!(
parent: self.span(), "establishing connection with {}; the peer is connected on port {}",
peer_addr, addr.port()
);
} else {
warn!(parent: self.span(), "couldn't determine the peer's port");
}
}
let connection = Connection::new(peer_addr, stream, !own_side);
let mut connection = self.enable_protocols(connection).await?;
let conn_ready_tx = connection.readiness_notifier.take();
self.connections.add(connection);
self.connecting.lock().remove(&peer_addr);
if let Some(tx) = conn_ready_tx {
let _ = tx.send(());
}
if let Some(handler) = self.protocols.on_connect.get() {
let (sender, receiver) = oneshot::channel();
handler.trigger((peer_addr, sender));
let _ = receiver.await; }
Ok(())
}
async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> {
macro_rules! enable_protocol {
($handler_type: ident, $node:expr, $conn: expr) => {
if let Some(handler) = $node.protocols.$handler_type.get() {
let (conn_returner, conn_retriever) = oneshot::channel();
handler.trigger(($conn, conn_returner));
match conn_retriever.await {
Ok(Ok(conn)) => conn,
Err(_) => return Err(io::ErrorKind::BrokenPipe.into()),
Ok(e) => return e,
}
} else {
$conn
}
};
}
let mut conn = enable_protocol!(handshake, self, conn);
if let Some(stream) = conn.stream.take() {
let (reader, writer) = split(stream);
conn.reader = Some(Box::new(reader));
conn.writer = Some(Box::new(writer));
}
let conn = enable_protocol!(reading, self, conn);
let conn = enable_protocol!(writing, self, conn);
Ok(conn)
}
}
impl fmt::Debug for Tcp {
    /// Describes the stack by its configuration only; runtime state is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let config = &self.config;
        f.write_fmt(format_args!("The TCP stack config: {config:?}"))
    }
}
// Tests for the TCP stack. They open real loopback sockets, so each one builds its own nodes.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    // Construction keeps the config and starts with no connections; the listener binds to the
    // configured IP.
    #[tokio::test]
    async fn test_new() {
        let tcp = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            max_connections: 200,
            ..Default::default()
        });
        assert_eq!(tcp.config.max_connections, 200);
        assert_eq!(tcp.config.listener_ip, Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));
        assert_eq!(tcp.enable_listener().await.unwrap().ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
    }
    // A self-connect is rejected without leaving state behind; a connect to a real peer succeeds.
    #[tokio::test]
    async fn test_connect() {
        let tcp = Tcp::new(Config::default());
        let node_ip = tcp.enable_listener().await.unwrap();
        tcp.connect(node_ip).await.unwrap_err();
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(node_ip));
        assert!(!tcp.is_connecting(node_ip));
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_ip = peer.enable_listener().await.unwrap();
        tcp.connect(peer_ip).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }
    // Disconnecting clears the connection; a second disconnect of the same address is a no-op.
    #[tokio::test]
    async fn test_disconnect() {
        let tcp = Tcp::new(Config::default());
        let _node_ip = tcp.enable_listener().await.unwrap();
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_ip = peer.enable_listener().await.unwrap();
        tcp.connect(peer_ip).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
        tcp.disconnect(peer_ip).await;
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
        tcp.disconnect(peer_ip).await;
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }
    // With max_connections = 1, a single established OR pending connection exhausts the limit.
    #[tokio::test]
    async fn test_can_add_connection() {
        let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_ip = peer.enable_listener().await.unwrap();
        assert!(tcp.can_add_connection());
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Initiator));
        assert!(!tcp.can_add_connection());
        tcp.connections.remove(peer_ip);
        assert!(tcp.can_add_connection());
        tcp.connecting.lock().insert(peer_ip);
        assert!(!tcp.can_add_connection());
        tcp.connecting.lock().remove(&peer_ip);
        assert!(tcp.can_add_connection());
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Responder));
        tcp.connecting.lock().insert(peer_ip);
        assert!(!tcp.can_add_connection());
        tcp.connections.remove(peer_ip);
        tcp.connecting.lock().remove(&peer_ip);
        assert!(tcp.can_add_connection());
    }
    // An inbound stream is rejected synchronously when the limit is already reached, so no
    // pending entry is created for it. (`tcp` never enables its listener here.)
    #[tokio::test]
    async fn test_handle_connection() {
        let tcp = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            max_connections: 1,
            ..Default::default()
        });
        let peer1 = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer1_ip = peer1.enable_listener().await.unwrap();
        let stream = TcpStream::connect(peer1_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer1_ip, stream, ConnectionSide::Responder));
        assert!(!tcp.can_add_connection());
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer1_ip));
        assert!(!tcp.is_connecting(peer1_ip));
        let peer2 = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer2_ip = peer2.enable_listener().await.unwrap();
        let stream = TcpStream::connect(peer2_ip).await.unwrap();
        tcp.handle_connection(stream, peer2_ip);
        assert!(!tcp.can_add_connection());
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer1_ip));
        assert!(!tcp.is_connected(peer2_ip));
        assert!(!tcp.is_connecting(peer1_ip));
        assert!(!tcp.is_connecting(peer2_ip));
    }
    // A successful adapt_stream moves the address from `connecting` to the connection set.
    #[tokio::test]
    async fn test_adapt_stream() {
        let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_ip = peer.enable_listener().await.unwrap();
        tcp.connecting.lock().insert(peer_ip);
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 1);
        assert!(!tcp.is_connected(peer_ip));
        assert!(tcp.is_connecting(peer_ip));
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.adapt_stream(stream, peer_ip, ConnectionSide::Responder).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }
}