use crate::{
messages::{ChallengeRequest, ChallengeResponse, DisconnectReason, Message, MessageCodec, MessageTrait},
Peer,
Router,
};
use snarkos_node_tcp::{ConnectionSide, Tcp, P2P};
use snarkvm::{
ledger::narwhal::Data,
prelude::{block::Header, error, Address, Network},
};
use anyhow::{bail, Result};
use futures::SinkExt;
use rand::{rngs::OsRng, Rng};
use std::{io, net::SocketAddr};
use tokio::net::TcpStream;
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
/// Plugs the router into the `P2P` trait from `snarkos_node_tcp` by exposing its TCP stack.
impl<N: Network> P2P for Router<N> {
    /// Returns a reference to the router's `Tcp` instance (the `tcp` field).
    fn tcp(&self) -> &Tcp {
        &self.tcp
    }
}
/// Awaits the next message on `$framed` and evaluates to its payload if it is a `$msg_ty`.
///
/// Every other outcome makes the *enclosing* function return early with an `Err` built by
/// `error(...)`:
/// - the peer sent a `Disconnect`;
/// - the peer sent a different message;
/// - the stream ended before any message arrived.
///
/// Stream errors are propagated with `?`. Because the expansion uses `.await`, it must be
/// used inside an `async` function.
#[macro_export]
macro_rules! expect_message {
    ($msg_ty:path, $framed:expr, $peer_addr:expr) => {{
        // The stream ending here means the peer left in the middle of the handshake.
        let Some(received) = $framed.try_next().await? else {
            return Err(error(format!("'{}' disconnected before sending {:?}", $peer_addr, stringify!($msg_ty),)));
        };
        match received {
            // The expected message: hand its payload back to the caller.
            $msg_ty(data) => {
                trace!("Received '{}' from '{}'", data.name(), $peer_addr);
                data
            }
            // The peer explicitly aborted the handshake.
            Message::Disconnect(reason) => return Err(error(format!("'{}' disconnected: {reason:?}", $peer_addr))),
            // Anything else is a protocol violation.
            unexpected => {
                return Err(error(format!(
                    "'{}' did not follow the handshake protocol: received {:?} instead of {}",
                    $peer_addr,
                    unexpected.name(),
                    stringify!($msg_ty),
                )))
            }
        }
    }};
}
async fn send<N: Network>(
framed: &mut Framed<&mut TcpStream, MessageCodec<N>>,
peer_addr: SocketAddr,
message: Message<N>,
) -> io::Result<()> {
trace!("Sending '{}' to '{peer_addr}'", message.name());
framed.send(message).await
}
impl<N: Network> Router<N> {
    /// Performs the handshake protocol with the peer at `peer_addr`.
    ///
    /// `peer_side` is the role of the *peer*:
    /// - When it is `ConnectionSide::Initiator`, the peer connected to us and this node runs the
    ///   responder half of the protocol.
    /// - Otherwise this node initiated the connection, and `peer_addr` is already the peer's
    ///   listening address.
    ///
    /// On success, returns the peer's listening address together with the framed stream used for
    /// the handshake.
    ///
    /// On every outcome, if the peer's listening address is known by the end of the handshake, it
    /// is removed from `connecting_peers`. It is always known for outbound connections; for inbound
    /// connections it is known only once the peer has passed `ensure_peer_is_allowed`.
    ///
    /// # Errors
    /// Returns an error if the stream fails, the peer deviates from the protocol, or the peer is
    /// rejected by one of the checks performed during the handshake.
    pub async fn handshake<'a>(
        &'a self,
        peer_addr: SocketAddr,
        stream: &'a mut TcpStream,
        peer_side: ConnectionSide,
        genesis_header: Header<N>,
    ) -> io::Result<(SocketAddr, Framed<&mut TcpStream, MessageCodec<N>>)> {
        // For an inbound connection, `peer_addr` is the address of the incoming socket. The peer's
        // listening address is only learned from its `ChallengeRequest`, so it starts out unknown.
        let mut peer_ip = if peer_side == ConnectionSide::Initiator {
            debug!("Received a connection request from '{peer_addr}'");
            None
        } else {
            debug!("Connecting to {peer_addr}...");
            Some(peer_addr)
        };

        // `peer_ip` is passed by mutable reference so the responder can fill it in. The cleanup
        // below then still runs if the handshake fails partway through.
        let handshake_result = if peer_side == ConnectionSide::Responder {
            self.handshake_inner_initiator(peer_addr, &mut peer_ip, stream, genesis_header).await
        } else {
            self.handshake_inner_responder(peer_addr, &mut peer_ip, stream, genesis_header).await
        };

        // Release this handshake's `connecting_peers` entry, if the peer's address is known by now.
        if let Some(ip) = peer_ip {
            self.connecting_peers.lock().remove(&ip);
        }

        if let Ok((ref peer_ip, _)) = handshake_result {
            info!("Connected to '{peer_ip}'");
        }

        handshake_result
    }

    /// Runs the handshake as the side that opened the connection.
    ///
    /// Message order:
    /// 1. Send our `ChallengeRequest`.
    /// 2. Receive the peer's `ChallengeResponse`.
    /// 3. Receive the peer's `ChallengeRequest`.
    /// 4. Verify both messages.
    /// 5. Reply with our `ChallengeResponse` (a signature over the peer's nonce).
    ///
    /// On success, the peer is inserted into the connected peers.
    async fn handshake_inner_initiator<'a>(
        &'a self,
        peer_addr: SocketAddr,
        peer_ip: &mut Option<SocketAddr>,
        stream: &'a mut TcpStream,
        genesis_header: Header<N>,
    ) -> io::Result<(SocketAddr, Framed<&mut TcpStream, MessageCodec<N>>)> {
        // `handshake` always supplies the listening address for outbound connections; report an
        // error rather than panicking if that invariant is ever broken.
        let Some(peer_ip) = *peer_ip else {
            return Err(error(format!("Unknown listening address for '{peer_addr}'")));
        };
        let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());

        // Challenge the peer with a fresh random nonce.
        let rng = &mut OsRng;
        let our_nonce = rng.gen();
        let our_request = ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce);
        send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;

        // The responder answers our challenge first, then sends its own.
        let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
        let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);

        // The response must be signed over our nonce by the address the peer claims in its request.
        if let Some(reason) = self
            .verify_challenge_response(peer_addr, peer_request.address, peer_response, genesis_header, our_nonce)
            .await
        {
            send(&mut framed, peer_addr, reason.into()).await?;
            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
        }
        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
            send(&mut framed, peer_addr, reason.into()).await?;
            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
        }

        // Answer the peer's challenge by signing its nonce.
        let Ok(our_signature) = self.account.sign_bytes(&peer_request.nonce.to_le_bytes(), rng) else {
            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
        };
        let our_response = ChallengeResponse { genesis_header, signature: Data::Object(our_signature) };
        send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;

        self.insert_connected_peer(Peer::new(peer_ip, &peer_request), peer_addr);
        Ok((peer_ip, framed))
    }

    /// Runs the handshake as the side that accepted the connection.
    ///
    /// Message order:
    /// 1. Receive the peer's `ChallengeRequest` and learn its listening address.
    /// 2. Admit the peer via `ensure_peer_is_allowed`.
    /// 3. Reply with our `ChallengeResponse` and our own `ChallengeRequest`.
    /// 4. Verify the peer's `ChallengeResponse`.
    ///
    /// On success, the peer is inserted into the connected peers.
    ///
    /// `peer_ip` is set only after the peer has been admitted, i.e. only once this handshake owns
    /// the `connecting_peers` entry that `handshake` later removes. A rejected peer therefore
    /// never releases an entry held by a concurrent handshake with the same address.
    async fn handshake_inner_responder<'a>(
        &'a self,
        peer_addr: SocketAddr,
        peer_ip: &mut Option<SocketAddr>,
        stream: &'a mut TcpStream,
        genesis_header: Header<N>,
    ) -> io::Result<(SocketAddr, Framed<&mut TcpStream, MessageCodec<N>>)> {
        let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());

        let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);

        // The peer's listening address: its IP combined with the port it says it listens on.
        let listener_ip = SocketAddr::new(peer_addr.ip(), peer_request.listener_port);

        // On failure nothing was reserved, so `peer_ip` must stay unset to avoid removing an
        // entry that belongs to another handshake.
        if let Err(forbidden_message) = self.ensure_peer_is_allowed(listener_ip) {
            return Err(error(format!("{forbidden_message}")));
        }
        // From here on this handshake owns the reservation; let `handshake` release it.
        *peer_ip = Some(listener_ip);
        let peer_ip = listener_ip;

        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
            send(&mut framed, peer_addr, reason.into()).await?;
            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
        }

        // Answer the peer's challenge by signing its nonce.
        let rng = &mut OsRng;
        let Ok(our_signature) = self.account.sign_bytes(&peer_request.nonce.to_le_bytes(), rng) else {
            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
        };
        let our_response = ChallengeResponse { genesis_header, signature: Data::Object(our_signature) };
        send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;

        // Challenge the peer in turn with a fresh random nonce.
        let our_nonce = rng.gen();
        let our_request = ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce);
        send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;

        let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
        if let Some(reason) = self
            .verify_challenge_response(peer_addr, peer_request.address, peer_response, genesis_header, our_nonce)
            .await
        {
            send(&mut framed, peer_addr, reason.into()).await?;
            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
        }

        self.insert_connected_peer(Peer::new(peer_ip, &peer_request), peer_addr);
        Ok((peer_ip, framed))
    }

    /// Decides whether an inbound peer, identified by its listening address `peer_ip`, may
    /// continue the handshake.
    ///
    /// On `Ok`, `peer_ip` has been inserted into `connecting_peers`, and the caller is responsible
    /// for removing it.
    ///
    /// On `Err`, `connecting_peers` is left as it was before the call.
    /// - If the address was already present (e.g. an outbound handshake to the same peer), that
    ///   entry is untouched.
    /// - If this call inserted it and a later check rejects the peer, the insertion is undone.
    ///
    /// Non-loopback peers are also counted in the inbound-connection cache. Once they exceed
    /// `MAXIMUM_CONNECTION_FAILURES` attempts, they are restricted.
    fn ensure_peer_is_allowed(&self, peer_ip: SocketAddr) -> Result<()> {
        if self.is_local_ip(&peer_ip) {
            bail!("Dropping connection request from '{peer_ip}' (attempted to self-connect)")
        }
        // Reserve the address; a failed insert means another handshake with it is in progress.
        if !self.connecting_peers.lock().insert(peer_ip) {
            bail!("Dropping connection request from '{peer_ip}' (already shaking hands as the initiator)")
        }

        // The remaining checks run while the reservation is held.
        let remaining_checks = || -> Result<()> {
            if self.is_connected(&peer_ip) {
                bail!("Dropping connection request from '{peer_ip}' (already connected)")
            }
            if self.is_restricted(&peer_ip) {
                bail!("Dropping connection request from '{peer_ip}' (restricted)")
            }
            // Loopback peers are exempt from the attempt counter.
            if !peer_ip.ip().is_loopback() {
                let num_attempts =
                    self.cache.insert_inbound_connection(peer_ip.ip(), Self::RADIO_SILENCE_IN_SECS as i64);
                if num_attempts > Self::MAXIMUM_CONNECTION_FAILURES {
                    self.insert_restricted_peer(peer_ip);
                    bail!("Dropping connection request from '{peer_ip}' (tried {num_attempts} times)")
                }
            }
            Ok(())
        };

        // Roll back our own reservation if any remaining check rejected the peer.
        let outcome = remaining_checks();
        if outcome.is_err() {
            self.connecting_peers.lock().remove(&peer_ip);
        }
        outcome
    }

    /// Validates a peer's `ChallengeRequest`.
    ///
    /// Returns `Some(DisconnectReason::OutdatedClientVersion)` if the peer's `version` is lower
    /// than `Message::<N>::VERSION`, otherwise `None`.
    fn verify_challenge_request(
        &self,
        peer_addr: SocketAddr,
        message: &ChallengeRequest<N>,
    ) -> Option<DisconnectReason> {
        // The exhaustive destructure makes newly added fields a compile error here, forcing a review.
        let &ChallengeRequest { version, listener_port: _, node_type: _, address: _, nonce: _ } = message;
        if version < Message::<N>::VERSION {
            warn!("Dropping '{peer_addr}' on version {version} (outdated)");
            return Some(DisconnectReason::OutdatedClientVersion);
        }
        None
    }

    /// Validates a peer's `ChallengeResponse`.
    ///
    /// The response is accepted only if both of these hold:
    /// - its genesis header equals `expected_genesis_header`;
    /// - its signature deserializes and verifies, under `peer_address`, over the little-endian
    ///   bytes of `expected_nonce`.
    ///
    /// Returns `Some(DisconnectReason::InvalidChallengeResponse)` on any failure, otherwise `None`.
    async fn verify_challenge_response(
        &self,
        peer_addr: SocketAddr,
        peer_address: Address<N>,
        response: ChallengeResponse<N>,
        expected_genesis_header: Header<N>,
        expected_nonce: u64,
    ) -> Option<DisconnectReason> {
        let ChallengeResponse { genesis_header, signature } = response;
        if genesis_header != expected_genesis_header {
            warn!("Handshake with '{peer_addr}' failed (incorrect block header)");
            return Some(DisconnectReason::InvalidChallengeResponse);
        }
        let Ok(signature) = signature.deserialize().await else {
            warn!("Handshake with '{peer_addr}' failed (cannot deserialize the signature)");
            return Some(DisconnectReason::InvalidChallengeResponse);
        };
        if !signature.verify_bytes(&peer_address, &expected_nonce.to_le_bytes()) {
            warn!("Handshake with '{peer_addr}' failed (invalid signature)");
            return Some(DisconnectReason::InvalidChallengeResponse);
        }
        None
    }
}