use crate::{
CONTEXT,
MAX_BATCH_DELAY_IN_MS,
MEMORY_POOL_PORT,
Worker,
events::{EventCodec, PrimaryPing},
helpers::{Cache, PrimarySender, Resolver, Storage, SyncSender, WorkerSender, assign_to_worker},
spawn_blocking,
};
use snarkos_account::Account;
use snarkos_node_bft_events::{
BlockRequest,
BlockResponse,
CertificateRequest,
CertificateResponse,
ChallengeRequest,
ChallengeResponse,
DataBlocks,
DisconnectReason,
Event,
EventTrait,
TransmissionRequest,
TransmissionResponse,
ValidatorsRequest,
ValidatorsResponse,
};
use snarkos_node_bft_ledger_service::LedgerService;
use snarkos_node_sync::{MAX_BLOCKS_BEHIND, communication_service::CommunicationService};
use snarkos_node_tcp::{
Config,
Connection,
ConnectionSide,
P2P,
Tcp,
is_bogon_ip,
is_unspecified_or_broadcast_ip,
protocols::{Disconnect, Handshake, OnConnect, Reading, Writing},
};
use snarkvm::{
console::prelude::*,
ledger::{
committee::Committee,
narwhal::{BatchHeader, Data},
},
prelude::{Address, Field},
};
use colored::Colorize;
use futures::SinkExt;
use indexmap::{IndexMap, IndexSet};
use parking_lot::{Mutex, RwLock};
use rand::seq::{IteratorRandom, SliceRandom};
#[cfg(not(any(test)))]
use std::net::IpAddr;
use std::{collections::HashSet, future::Future, io, net::SocketAddr, sync::Arc, time::Duration};
use tokio::{
net::TcpStream,
sync::{OnceCell, oneshot},
task::{self, JoinHandle},
};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
/// The window over which cached events are counted (`MAX_BATCH_DELAY_IN_MS` converted to seconds).
const CACHE_EVENTS_INTERVAL: i64 = (MAX_BATCH_DELAY_IN_MS / 1000) as i64;
/// The window over which cached requests are counted (`MAX_BATCH_DELAY_IN_MS` converted to seconds).
const CACHE_REQUESTS_INTERVAL: i64 = (MAX_BATCH_DELAY_IN_MS / 1000) as i64;
/// The maximum number of connection attempts tolerated from a peer within one window.
const MAX_CONNECTION_ATTEMPTS: usize = 10;
/// The window used when counting inbound connection attempts of a peer that completed the first handshake step.
const RESTRICTED_INTERVAL: i64 = (MAX_CONNECTION_ATTEMPTS as u64 * MAX_BATCH_DELAY_IN_MS / 1000) as i64;
/// Below this number of connected peers, the gateway asks peers for more validators.
const MIN_CONNECTED_VALIDATORS: usize = 175;
/// The maximum number of validators included in (and accepted from) a validators response.
const MAX_VALIDATORS_TO_SEND: usize = 200;
/// The window, in seconds, used when counting raw inbound connection attempts per IP during the handshake.
#[cfg(not(any(test)))]
const CONNECTION_ATTEMPTS_SINCE_SECS: i64 = 10;
/// The value passed to `remove_old_bans` on every heartbeat (presumably the ban lifetime in seconds).
const IP_BAN_TIME_IN_SECS: u64 = 300;
/// Abstraction over the delivery of events to peers.
#[async_trait]
pub trait Transport<N: Network>: Send + Sync {
/// Sends `event` to the peer identified by `peer_ip`.
///
/// Returns an optional receiver that yields the I/O result of the send.
async fn send(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>>;
/// Sends `event` to a group of peers without reporting per-peer results to the caller.
fn broadcast(&self, event: Event<N>);
}
/// The networking front of the memory pool: it owns the TCP stack, performs handshakes with
/// validators, rate-limits traffic and routes inbound events to the primary, workers and sync module.
#[derive(Clone)]
pub struct Gateway<N: Network> {
// The account of this node; used to sign handshake challenges and to detect self-connections.
account: Account<N>,
// The storage; queried for the current round when authorizing validators.
storage: Storage<N>,
// The ledger service; queried for committees, blocks and the restrictions ID.
ledger: Arc<dyn LedgerService<N>>,
// The underlying TCP stack.
tcp: Tcp,
// The cache holding the per-peer and per-item counters used for rate limiting.
cache: Arc<Cache<N>>,
// The resolver translating between listener IPs, connection addresses and validator addresses.
resolver: Arc<Resolver<N>>,
// The validators that are always considered authorized and are reconnected on every heartbeat.
trusted_validators: IndexSet<SocketAddr>,
// The listener IPs of the peers that completed the handshake.
connected_peers: Arc<RwLock<IndexSet<SocketAddr>>>,
// The listener IPs of the peers whose handshake is in progress.
connecting_peers: Arc<Mutex<IndexSet<SocketAddr>>>,
// The channel to the primary; set once in `run`.
primary_sender: Arc<OnceCell<PrimarySender<N>>>,
// The channels to the workers, keyed by worker ID; set once in `run`.
worker_senders: Arc<OnceCell<IndexMap<u8, WorkerSender<N>>>>,
// The channel to the sync module; only set when one is passed to `run`.
sync_sender: Arc<OnceCell<SyncSender<N>>>,
// The handles of the tasks registered via `spawn`; aborted in `shut_down`.
handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
// The development mode identifier, if any; it relaxes several IP validity checks.
dev: Option<u16>,
}
impl<N: Network> Gateway<N> {
pub fn new(
account: Account<N>,
storage: Storage<N>,
ledger: Arc<dyn LedgerService<N>>,
ip: Option<SocketAddr>,
trusted_validators: &[SocketAddr],
dev: Option<u16>,
) -> Result<Self> {
let ip = match (ip, dev) {
(None, Some(dev)) => SocketAddr::from_str(&format!("127.0.0.1:{}", MEMORY_POOL_PORT + dev))?,
(None, None) => SocketAddr::from_str(&format!("0.0.0.0:{}", MEMORY_POOL_PORT))?,
(Some(ip), _) => ip,
};
let tcp = Tcp::new(Config::new(ip, Committee::<N>::MAX_COMMITTEE_SIZE));
Ok(Self {
account,
storage,
ledger,
tcp,
cache: Default::default(),
resolver: Default::default(),
trusted_validators: trusted_validators.iter().copied().collect(),
connected_peers: Default::default(),
connecting_peers: Default::default(),
primary_sender: Default::default(),
worker_senders: Default::default(),
sync_sender: Default::default(),
handles: Default::default(),
dev,
})
}
/// Starts the gateway: registers the channel senders, enables the TCP protocols and the
/// listener, and launches the heartbeat.
///
/// # Panics
/// Panics if a sender was already registered or if the TCP listener cannot be enabled.
pub async fn run(
    &self,
    primary_sender: PrimarySender<N>,
    worker_senders: IndexMap<u8, WorkerSender<N>>,
    sync_sender: Option<SyncSender<N>>,
) {
    debug!("Starting the gateway for the memory pool...");

    // Register the senders; every cell can be written only once.
    self.primary_sender.set(primary_sender).expect("Primary sender already set in gateway");
    self.worker_senders.set(worker_senders).expect("The worker senders are already set");
    if let Some(sender) = sync_sender {
        self.sync_sender.set(sender).expect("Sync sender already set in gateway");
    }

    // Enable the TCP protocols before the listener starts accepting connections.
    self.enable_handshake().await;
    self.enable_reading().await;
    self.enable_writing().await;
    self.enable_disconnect().await;
    self.enable_on_connect().await;

    // Start listening; the returned address is not needed here.
    let _ = self.tcp.enable_listener().await.expect("Failed to enable the TCP listener");

    // Start the periodic maintenance task.
    self.initialize_heartbeat();

    info!("Started the gateway for the memory pool at '{}'", self.local_ip());
}
}
impl<N: Network> Gateway<N> {
    /// Returns the number of members in the current committee, falling back to the maximum
    /// committee size when the current committee cannot be retrieved.
    fn max_committee_size(&self) -> usize {
        match self.ledger.current_committee() {
            Ok(committee) => committee.num_members(),
            Err(_) => Committee::<N>::MAX_COMMITTEE_SIZE as usize,
        }
    }

    /// Returns the limit used for the event counters; it equals the transmissions limit.
    fn max_cache_events(&self) -> usize {
        self.max_cache_transmissions()
    }

    /// Returns the limit used for the certificate counters.
    fn max_cache_certificates(&self) -> usize {
        let committee_size = self.max_committee_size();
        2 * BatchHeader::<N>::MAX_GC_ROUNDS * committee_size
    }

    /// Returns the limit used for the transmission counters.
    fn max_cache_transmissions(&self) -> usize {
        let max_certificates = self.max_cache_certificates();
        max_certificates * BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH
    }

    /// Returns the limit used for the duplicate-request counters (the committee size squared).
    fn max_cache_duplicates(&self) -> usize {
        let committee_size = self.max_committee_size();
        committee_size.pow(2)
    }
}
#[async_trait]
impl<N: Network> CommunicationService for Gateway<N> {
type Message = Event<N>;
fn prepare_block_request(start_height: u32, end_height: u32) -> Self::Message {
debug_assert!(start_height < end_height, "Invalid block request format");
Event::BlockRequest(BlockRequest { start_height, end_height })
}
async fn send(&self, peer_ip: SocketAddr, message: Self::Message) -> Option<oneshot::Receiver<io::Result<()>>> {
Transport::send(self, peer_ip, message).await
}
}
impl<N: Network> Gateway<N> {
/// Returns the account of this node.
pub const fn account(&self) -> &Account<N> {
&self.account
}
/// Returns the development mode identifier, if the gateway was created with one.
pub const fn dev(&self) -> Option<u16> {
self.dev
}
/// Returns the address the TCP listener is bound to.
///
/// # Panics
/// Panics if the TCP listener has not been enabled yet.
pub fn local_ip(&self) -> SocketAddr {
self.tcp.listening_addr().expect("The TCP listener is not enabled")
}
/// Returns `true` if `ip` refers to this node: either it equals the listener address, or it is
/// an unspecified/loopback address that uses the listener's port.
pub fn is_local_ip(&self, ip: SocketAddr) -> bool {
    let local_ip = self.local_ip();
    if ip == local_ip {
        return true;
    }
    let host = ip.ip();
    let is_wildcard_or_loopback = host.is_unspecified() || host.is_loopback();
    is_wildcard_or_loopback && ip.port() == local_ip.port()
}
/// Returns `true` if `ip` is acceptable as a peer address: it is neither this node, nor a bogon
/// address, nor an unspecified or broadcast address.
pub fn is_valid_peer_ip(&self, ip: SocketAddr) -> bool {
    let is_rejected =
        self.is_local_ip(ip) || is_bogon_ip(ip.ip()) || is_unspecified_or_broadcast_ip(ip.ip());
    !is_rejected
}
/// Returns the resolver.
pub fn resolver(&self) -> &Resolver<N> {
&self.resolver
}
/// Returns the primary sender.
///
/// # Panics
/// Panics if the primary sender has not been set (it is set in `run`).
pub fn primary_sender(&self) -> &PrimarySender<N> {
self.primary_sender.get().expect("Primary sender not set in gateway")
}
/// Returns the number of workers.
///
/// # Panics
/// Panics if the worker senders have not been set, or if there are more workers than fit in a `u8`.
pub fn num_workers(&self) -> u8 {
    let senders = self.worker_senders.get().expect("Missing worker senders in gateway");
    let count = senders.len();
    u8::try_from(count).expect("Too many workers")
}
/// Returns the sender of the worker with the given ID, or `None` if the worker senders are
/// not set or no worker has this ID.
pub fn get_worker_sender(&self, worker_id: u8) -> Option<&WorkerSender<N>> {
    let senders = self.worker_senders.get()?;
    senders.get(&worker_id)
}
/// Returns `true` if the resolver knows a peer IP for `address` and that IP is connected.
pub fn is_connected_address(&self, address: Address<N>) -> bool {
    self.resolver
        .get_peer_ip_for_address(address)
        .map(|peer_ip| self.is_connected_ip(peer_ip))
        .unwrap_or(false)
}
/// Returns `true` if `ip` is in the set of connected peers.
pub fn is_connected_ip(&self, ip: SocketAddr) -> bool {
self.connected_peers.read().contains(&ip)
}
/// Returns `true` if `ip` is in the set of peers whose handshake is in progress.
pub fn is_connecting_ip(&self, ip: SocketAddr) -> bool {
self.connecting_peers.lock().contains(&ip)
}
/// Returns `true` if `ip` belongs to a trusted validator, or if the resolver maps it to an
/// address that is an authorized validator.
pub fn is_authorized_validator_ip(&self, ip: SocketAddr) -> bool {
    // Trusted validators are authorized unconditionally.
    self.trusted_validators.contains(&ip)
        || self.resolver.get_address(ip).map_or(false, |address| self.is_authorized_validator_address(address))
}
/// Returns `true` if `validator_address` is a member of any of the following committees:
/// - the committee lookback for the current storage round,
/// - the current committee of the ledger,
/// - the committee lookback for every second round, starting at the round of the block that is
///   `MAX_BLOCKS_BEHIND` below the latest height and ending before the current storage round.
///
/// Committees that cannot be retrieved are treated as not containing the address.
pub fn is_authorized_validator_address(&self, validator_address: Address<N>) -> bool {
    // Membership test against the committee lookback of a single round.
    let is_member_for_round = |round| {
        self.ledger
            .get_committee_lookback_for_round(round)
            .map_or(false, |committee| committee.is_committee_member(validator_address))
    };

    // Check the committee lookback for the current round.
    if is_member_for_round(self.storage.current_round()) {
        return true;
    }

    // Check the latest committee on the ledger.
    let in_current_committee =
        self.ledger.current_committee().map_or(false, |committee| committee.is_committee_member(validator_address));
    if in_current_committee {
        return true;
    }

    // Check the committee lookbacks of the recent rounds.
    let previous_block_height = self.ledger.latest_block_height().saturating_sub(MAX_BLOCKS_BEHIND);
    let Ok(block_round) = self.ledger.get_block_round(previous_block_height) else {
        return false;
    };
    (block_round..self.storage.current_round()).step_by(2).any(|round| is_member_for_round(round))
}
/// Returns the maximum number of connections allowed by the TCP configuration.
pub fn max_connected_peers(&self) -> usize {
self.tcp.config().max_connections as usize
}
/// Returns the number of connected peers.
pub fn number_of_connected_peers(&self) -> usize {
self.connected_peers.read().len()
}
/// Returns the addresses of the connected peers; peers the resolver has no address for are skipped.
pub fn connected_addresses(&self) -> HashSet<Address<N>> {
    let mut addresses = HashSet::new();
    for peer_ip in self.connected_peers.read().iter() {
        if let Some(address) = self.resolver.get_address(*peer_ip) {
            addresses.insert(address);
        }
    }
    addresses
}
/// Returns the lock guarding the set of connected peers.
pub fn connected_peers(&self) -> &RwLock<IndexSet<SocketAddr>> {
&self.connected_peers
}
/// Attempts to connect to `peer_ip` in a background task.
///
/// Returns `None` (after logging a warning) if the attempt is rejected by
/// `check_connection_attempt`; otherwise returns the handle of the spawned task.
pub fn connect(&self, peer_ip: SocketAddr) -> Option<JoinHandle<()>> {
    // Refuse attempts that violate the connection rules.
    match self.check_connection_attempt(peer_ip) {
        Ok(()) => {}
        Err(forbidden_error) => {
            warn!("{forbidden_error}");
            return None;
        }
    }

    let gateway = self.clone();
    let task = tokio::spawn(async move {
        debug!("Connecting to validator {peer_ip}...");
        if let Err(error) = gateway.tcp.connect(peer_ip).await {
            // Make sure a failed attempt does not leave the peer marked as connecting.
            gateway.connecting_peers.lock().shift_remove(&peer_ip);
            warn!("Unable to connect to '{peer_ip}' - {error}");
        }
    });
    Some(task)
}
fn check_connection_attempt(&self, peer_ip: SocketAddr) -> Result<()> {
if self.is_local_ip(peer_ip) {
bail!("{CONTEXT} Dropping connection attempt to '{peer_ip}' (attempted to self-connect)")
}
if self.number_of_connected_peers() >= self.max_connected_peers() {
bail!("{CONTEXT} Dropping connection attempt to '{peer_ip}' (maximum peers reached)")
}
if self.is_connected_ip(peer_ip) {
bail!("{CONTEXT} Dropping connection attempt to '{peer_ip}' (already connected)")
}
if self.is_connecting_ip(peer_ip) {
bail!("{CONTEXT} Dropping connection attempt to '{peer_ip}' (already connecting)")
}
Ok(())
}
fn ensure_peer_is_allowed(&self, peer_ip: SocketAddr) -> Result<()> {
if self.is_local_ip(peer_ip) {
bail!("{CONTEXT} Dropping connection request from '{peer_ip}' (attempted to self-connect)")
}
if !self.connecting_peers.lock().insert(peer_ip) {
bail!("{CONTEXT} Dropping connection request from '{peer_ip}' (already shaking hands as the initiator)")
}
if self.is_connected_ip(peer_ip) {
bail!("{CONTEXT} Dropping connection request from '{peer_ip}' (already connected)")
}
if !peer_ip.ip().is_loopback() {
let num_attempts = self.cache.insert_inbound_connection(peer_ip.ip(), RESTRICTED_INTERVAL);
if num_attempts > MAX_CONNECTION_ATTEMPTS {
bail!("Dropping connection request from '{peer_ip}' (tried {num_attempts} times)")
}
}
Ok(())
}
/// Returns `true` if the TCP stack lists `ip` as banned. Not compiled in test builds.
#[cfg(not(any(test)))]
fn is_ip_banned(&self, ip: IpAddr) -> bool {
self.tcp.banned_peers().is_ip_banned(&ip)
}
/// Records (or refreshes) a ban for `ip` in the TCP stack. Not compiled in test builds.
#[cfg(not(any(test)))]
fn update_ip_ban(&self, ip: IpAddr) {
self.tcp.banned_peers().update_ip_ban(ip);
}
/// Publishes the numbers of connected and connecting peers as gauges.
#[cfg(feature = "metrics")]
fn update_metrics(&self) {
    let num_connected = self.connected_peers.read().len();
    metrics::gauge(metrics::bft::CONNECTED, num_connected as f64);

    let num_connecting = self.connecting_peers.lock().len();
    metrics::gauge(metrics::bft::CONNECTING, num_connecting as f64);
}
/// Registers a peer that completed the handshake: stores its listener IP, connection address and
/// validator address in the resolver, adds it to the connected peers and, when the `metrics`
/// feature is enabled, refreshes the gauges. This is the non-test variant.
#[cfg(not(test))]
fn insert_connected_peer(&self, peer_ip: SocketAddr, peer_addr: SocketAddr, address: Address<N>) {
self.resolver.insert_peer(peer_ip, peer_addr, address);
self.connected_peers.write().insert(peer_ip);
#[cfg(feature = "metrics")]
self.update_metrics();
}
/// Test-only, public variant of `insert_connected_peer`: registers the peer in the resolver and
/// in the connected peers, without touching the metrics.
#[cfg(test)]
pub fn insert_connected_peer(&self, peer_ip: SocketAddr, peer_addr: SocketAddr, address: Address<N>) {
self.resolver.insert_peer(peer_ip, peer_addr, address);
self.connected_peers.write().insert(peer_ip);
}
/// Unregisters a peer: notifies the sync module (if present) in a background task, then removes
/// the peer from the resolver and from the connected peers.
fn remove_connected_peer(&self, peer_ip: SocketAddr) {
    // Tell the sync module about the removal without blocking the caller.
    if let Some(sync_sender) = self.sync_sender.get() {
        let tx_remove_peer = sync_sender.tx_block_sync_remove_peer.clone();
        tokio::spawn(async move {
            let outcome = tx_remove_peer.send(peer_ip).await;
            if let Err(e) = outcome {
                warn!("Unable to remove '{peer_ip}' from the sync module - {e}");
            }
        });
    }

    // Drop the bookkeeping for this peer.
    self.resolver.remove_peer(peer_ip);
    self.connected_peers.write().shift_remove(&peer_ip);

    #[cfg(feature = "metrics")]
    self.update_metrics();
}
/// Sends `event` to the peer with the listener address `peer_ip`.
///
/// Returns `None` if the peer's connection address cannot be resolved or if the send fails;
/// in the latter case a disconnect from the peer is started.
fn send_inner(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>> {
    // Translate the listener address into the address of the actual connection.
    let peer_addr = match self.resolver.get_ambiguous(peer_ip) {
        Some(peer_addr) => peer_addr,
        None => {
            warn!("Unable to resolve the listener IP address '{peer_ip}'");
            return None;
        }
    };

    // Keep the name around, as the event is consumed by the send.
    let name = event.name();
    trace!("{CONTEXT} Sending '{name}' to '{peer_ip}'");

    match self.unicast(peer_addr, event) {
        Ok(receiver) => Some(receiver),
        Err(e) => {
            warn!("{CONTEXT} Failed to send '{name}' to '{peer_ip}': {e}");
            debug!("{CONTEXT} Disconnecting from '{peer_ip}' (unable to send)");
            self.disconnect(peer_ip);
            None
        }
    }
}
/// Handles an event received from the connection at `peer_addr`.
///
/// The event is first checked against the authorization and rate-limit rules and then routed to
/// the primary, a worker, the sync module, or answered directly.
///
/// # Errors
/// Returns an error if the peer cannot be resolved, is not authorized, exceeds the event limit,
/// or sends a malformed or unexpected event. The caller (`process_message`) reacts to an error
/// by disconnecting from the peer.
async fn inbound(&self, peer_addr: SocketAddr, event: Event<N>) -> Result<()> {
// Resolve the connection address into the peer's listener address.
let Some(peer_ip) = self.resolver.get_listener(peer_addr) else {
bail!("{CONTEXT} Unable to resolve the (ambiguous) peer address '{peer_addr}'")
};
// Only authorized validators may send events.
if !self.is_authorized_validator_ip(peer_ip) {
bail!("{CONTEXT} Dropping '{}' from '{peer_ip}' (not authorized)", event.name())
}
// Count this event for the peer and drop the peer once the event limit is reached.
let num_events = self.cache.insert_inbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
if num_events >= self.max_cache_events() {
bail!("Dropping '{peer_ip}' for spamming events (num_events = {num_events})")
}
// Count certificate, transmission and block requests separately; once the duplicates limit
// is reached the event is silently ignored (the peer is not disconnected).
match event {
Event::CertificateRequest(_) | Event::CertificateResponse(_) => {
// Both variants are keyed by the certificate ID.
let certificate_id = match &event {
Event::CertificateRequest(CertificateRequest { certificate_id }) => *certificate_id,
Event::CertificateResponse(CertificateResponse { certificate }) => certificate.id(),
_ => unreachable!(),
};
let num_events = self.cache.insert_inbound_certificate(certificate_id, CACHE_REQUESTS_INTERVAL);
if num_events >= self.max_cache_duplicates() {
return Ok(());
}
}
Event::TransmissionRequest(TransmissionRequest { transmission_id })
| Event::TransmissionResponse(TransmissionResponse { transmission_id, .. }) => {
let num_events = self.cache.insert_inbound_transmission(transmission_id, CACHE_REQUESTS_INTERVAL);
if num_events >= self.max_cache_duplicates() {
return Ok(());
}
}
Event::BlockRequest(_) => {
// Block requests are counted per peer.
let num_events = self.cache.insert_inbound_block_request(peer_ip, CACHE_REQUESTS_INTERVAL);
if num_events >= self.max_cache_duplicates() {
return Ok(());
}
}
_ => {}
}
trace!("{CONTEXT} Received '{}' from '{peer_ip}'", event.name());
// Route the event. Failures to forward over a channel are deliberately ignored.
match event {
Event::BatchPropose(batch_propose) => {
let _ = self.primary_sender().tx_batch_propose.send((peer_ip, batch_propose)).await;
Ok(())
}
Event::BatchSignature(batch_signature) => {
let _ = self.primary_sender().tx_batch_signature.send((peer_ip, batch_signature)).await;
Ok(())
}
Event::BatchCertified(batch_certified) => {
let _ = self.primary_sender().tx_batch_certified.send((peer_ip, batch_certified.certificate)).await;
Ok(())
}
Event::BlockRequest(block_request) => {
let BlockRequest { start_height, end_height } = block_request;
// Reject empty or inverted ranges; this also makes the subtraction below safe.
if start_height >= end_height {
bail!("Block request from '{peer_ip}' has an invalid range ({start_height}..{end_height})")
}
// Reject ranges that exceed the maximum number of blocks per response.
if end_height - start_height > DataBlocks::<N>::MAXIMUM_NUMBER_OF_BLOCKS as u32 {
bail!("Block request from '{peer_ip}' has an excessive range ({start_height}..{end_height})")
}
// Fetch the blocks on a blocking thread.
let self_ = self.clone();
let blocks = match task::spawn_blocking(move || {
match self_.ledger.get_blocks(start_height..end_height) {
Ok(blocks) => Ok(Data::Object(DataBlocks(blocks))),
Err(error) => bail!("Missing blocks {start_height} to {end_height} from ledger - {error}"),
}
})
.await
{
Ok(Ok(blocks)) => blocks,
// The ledger lookup failed.
Ok(Err(error)) => return Err(error),
// The blocking task itself failed to complete.
Err(error) => return Err(anyhow!("[BlockRequest] {error}")),
};
// Send the response in the background, as sending may wait on the outbound rate limit.
let self_ = self.clone();
tokio::spawn(async move {
let event = Event::BlockResponse(BlockResponse { request: block_request, blocks });
Transport::send(&self_, peer_ip, event).await;
});
Ok(())
}
Event::BlockResponse(block_response) => {
// Without a sync sender, block responses are ignored.
if let Some(sync_sender) = self.sync_sender.get() {
let BlockResponse { request, blocks } = block_response;
// The response must match a request this node sent to the peer.
if !self.cache.remove_outbound_block_request(peer_ip, &request) {
bail!("Unsolicited block response from '{peer_ip}'")
}
let blocks = blocks.deserialize().await.map_err(|error| anyhow!("[BlockResponse] {error}"))?;
blocks.ensure_response_is_well_formed(peer_ip, request.start_height, request.end_height)?;
// A failure to process the blocks is logged but does not disconnect the peer.
if let Err(e) = sync_sender.advance_with_sync_blocks(peer_ip, blocks.0).await {
warn!("Unable to process block response from '{peer_ip}' - {e}");
}
}
Ok(())
}
Event::CertificateRequest(certificate_request) => {
if let Some(sync_sender) = self.sync_sender.get() {
let _ = sync_sender.tx_certificate_request.send((peer_ip, certificate_request)).await;
}
Ok(())
}
Event::CertificateResponse(certificate_response) => {
if let Some(sync_sender) = self.sync_sender.get() {
let _ = sync_sender.tx_certificate_response.send((peer_ip, certificate_response)).await;
}
Ok(())
}
// Challenge events are only valid during the handshake.
Event::ChallengeRequest(..) | Event::ChallengeResponse(..) => {
bail!("{CONTEXT} Peer '{peer_ip}' is not following the protocol")
}
// Returning an error makes the caller disconnect from the peer.
Event::Disconnect(disconnect) => {
bail!("{CONTEXT} {:?}", disconnect.reason)
}
Event::PrimaryPing(ping) => {
let PrimaryPing { version, block_locators, primary_certificate } = ping;
// Reject peers on an older event version.
if version < Event::<N>::VERSION {
bail!("Dropping '{peer_ip}' on event version {version} (outdated)");
}
// Forward the block locators to the sync module, if present.
if let Some(sync_sender) = self.sync_sender.get() {
if let Err(error) = sync_sender.update_peer_locators(peer_ip, block_locators).await {
bail!("Validator '{peer_ip}' sent invalid block locators - {error}");
}
}
let _ = self.primary_sender().tx_primary_ping.send((peer_ip, primary_certificate)).await;
Ok(())
}
Event::TransmissionRequest(request) => {
// Route the request to the worker responsible for this transmission ID.
let Ok(worker_id) = assign_to_worker(request.transmission_id, self.num_workers()) else {
warn!("{CONTEXT} Unable to assign transmission ID '{}' to a worker", request.transmission_id);
return Ok(());
};
if let Some(sender) = self.get_worker_sender(worker_id) {
let _ = sender.tx_transmission_request.send((peer_ip, request)).await;
}
Ok(())
}
Event::TransmissionResponse(response) => {
// Route the response to the worker responsible for this transmission ID.
let Ok(worker_id) = assign_to_worker(response.transmission_id, self.num_workers()) else {
warn!("{CONTEXT} Unable to assign transmission ID '{}' to a worker", response.transmission_id);
return Ok(());
};
if let Some(sender) = self.get_worker_sender(worker_id) {
let _ = sender.tx_transmission_response.send((peer_ip, response)).await;
}
Ok(())
}
Event::ValidatorsRequest(_) => {
// Collect the connected peers; outside of development mode only valid peer IPs are shared.
let mut connected_peers: Vec<_> = match self.dev.is_some() {
true => self.connected_peers.read().iter().copied().collect(),
false => {
self.connected_peers.read().iter().copied().filter(|ip| self.is_valid_peer_ip(*ip)).collect()
}
};
// Shuffle, so that the subset taken below is random.
connected_peers.shuffle(&mut rand::thread_rng());
let self_ = self.clone();
tokio::spawn(async move {
let mut validators = IndexMap::with_capacity(MAX_VALIDATORS_TO_SEND);
// Peers the resolver has no address for are skipped.
for validator_ip in connected_peers.into_iter().take(MAX_VALIDATORS_TO_SEND) {
if let Some(validator_address) = self_.resolver.get_address(validator_ip) {
validators.insert(validator_ip, validator_address);
}
}
let event = Event::ValidatorsResponse(ValidatorsResponse { validators });
Transport::send(&self_, peer_ip, event).await;
});
Ok(())
}
Event::ValidatorsResponse(response) => {
let ValidatorsResponse { validators } = response;
ensure!(validators.len() <= MAX_VALIDATORS_TO_SEND, "{CONTEXT} Received too many validators");
// The response must answer a request this node sent to the peer.
if !self.cache.contains_outbound_validators_request(peer_ip) {
bail!("{CONTEXT} Received validators response from '{peer_ip}' without a validators request")
}
self.cache.decrement_outbound_validators_requests(peer_ip);
// Only connect to more validators while below the minimum.
if self.number_of_connected_peers() < MIN_CONNECTED_VALIDATORS {
let self_ = self.clone();
tokio::spawn(async move {
for (validator_ip, validator_address) in validators {
// In development mode only self-connections are filtered; otherwise the IP must be valid.
if self_.dev.is_some() {
if self_.is_local_ip(validator_ip) {
continue;
}
} else {
if !self_.is_valid_peer_ip(validator_ip) {
continue;
}
}
// Skip this node's own address.
if self_.account.address() == validator_address {
continue;
}
// Skip peers that are already connected or connecting, by IP or by address.
if self_.is_connected_ip(validator_ip) || self_.is_connecting_ip(validator_ip) {
continue;
}
if self_.is_connected_address(validator_address) {
continue;
}
// Skip addresses that are not authorized validators.
if !self_.is_authorized_validator_address(validator_address) {
continue;
}
self_.connect(validator_ip);
}
});
}
Ok(())
}
Event::WorkerPing(ping) => {
ensure!(
ping.transmission_ids.len() <= Worker::<N>::MAX_TRANSMISSIONS_PER_WORKER_PING,
"{CONTEXT} Received too many transmissions"
);
let num_workers = self.num_workers();
// Forward every transmission ID to the worker responsible for it.
for transmission_id in ping.transmission_ids.into_iter() {
let Ok(worker_id) = assign_to_worker(transmission_id, num_workers) else {
warn!("{CONTEXT} Unable to assign transmission ID '{transmission_id}' to a worker");
continue;
};
if let Some(sender) = self.get_worker_sender(worker_id) {
let _ = sender.tx_worker_ping.send((peer_ip, transmission_id)).await;
}
}
Ok(())
}
}
}
/// Disconnects from the peer with the listener address `peer_ip` in a background task and
/// returns the handle of that task. Nothing happens if the peer cannot be resolved.
pub fn disconnect(&self, peer_ip: SocketAddr) -> JoinHandle<()> {
    let self_ = self.clone();
    tokio::spawn(async move {
        // Translate the listener address into the address of the actual connection.
        let Some(peer_addr) = self_.resolver.get_ambiguous(peer_ip) else {
            return;
        };
        let _disconnected = self_.tcp.disconnect(peer_addr).await;
        debug_assert!(_disconnected);
    })
}
/// Spawns the heartbeat task: after an initial delay of one second it runs `heartbeat` every
/// 15 seconds until the task is aborted.
fn initialize_heartbeat(&self) {
    let gateway = self.clone();
    self.spawn(async move {
        // Wait briefly before the first heartbeat.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        info!("Starting the heartbeat of the gateway...");
        loop {
            gateway.heartbeat();
            tokio::time::sleep(Duration::from_secs(15)).await;
        }
    });
}
/// Spawns `future` as a task and keeps its handle, so that `shut_down` can abort it.
#[allow(dead_code)]
fn spawn<T: Future<Output = ()> + Send + 'static>(&self, future: T) {
    let mut handles = self.handles.lock();
    handles.push(tokio::spawn(future));
}
/// Shuts down the gateway: aborts the tasks registered via `spawn`, then shuts down the TCP stack.
pub async fn shut_down(&self) {
    info!("Shutting down the gateway...");
    for handle in self.handles.lock().iter() {
        handle.abort();
    }
    self.tcp.shut_down().await;
}
}
impl<N: Network> Gateway<N> {
    /// Runs one round of periodic maintenance.
    fn heartbeat(&self) {
        self.log_connected_validators();
        self.handle_trusted_validators();
        self.handle_unauthorized_validators();
        self.handle_min_connected_validators();
        self.handle_banned_ips();
    }

    /// Logs a summary of the connected validators, followed by one debug line per validator.
    fn log_connected_validators(&self) {
        // Work on a snapshot, so the lock is not held while logging.
        let validators = self.connected_peers().read().clone();
        // The total excludes one member (presumably this node).
        let validators_total = match self.ledger.current_committee() {
            Ok(c) => c.num_members().saturating_sub(1),
            Err(_) => 0,
        };
        let total_validators = format!("(of {validators_total} bonded validators)").dimmed();
        let connections_msg = if validators.is_empty() {
            "No connected validators".to_string()
        } else {
            let num_connected = validators.len();
            format!("Connected to {num_connected} validators {total_validators}")
        };
        info!("{connections_msg}");
        for peer_ip in validators {
            let address = match self.resolver.get_address(peer_ip) {
                Some(a) => a.to_string(),
                None => "Unknown".to_string(),
            };
            debug!("{}", format!(" {peer_ip} - {address}").dimmed());
        }
    }

    /// Attempts to connect to every trusted validator that is neither this node nor already
    /// connecting or connected.
    fn handle_trusted_validators(&self) {
        for &validator_ip in &self.trusted_validators {
            let is_handled = self.is_local_ip(validator_ip)
                || self.is_connecting_ip(validator_ip)
                || self.is_connected_ip(validator_ip);
            if !is_handled {
                self.connect(validator_ip);
            }
        }
    }

    /// Disconnects, in a background task, from every connected peer that is not authorized.
    fn handle_unauthorized_validators(&self) {
        let self_ = self.clone();
        tokio::spawn(async move {
            let validators = self_.connected_peers().read().clone();
            for peer_ip in validators {
                if self_.is_authorized_validator_ip(peer_ip) {
                    continue;
                }
                warn!("{CONTEXT} Disconnecting from '{peer_ip}' - Validator is not in the current committee");
                // Tell the peer why it is dropped before closing the connection.
                Transport::send(&self_, peer_ip, DisconnectReason::ProtocolViolation.into()).await;
                self_.disconnect(peer_ip);
            }
        });
    }

    /// Requests more validators from a random connected peer while the number of connected peers
    /// is below `MIN_CONNECTED_VALIDATORS`.
    fn handle_min_connected_validators(&self) {
        if self.number_of_connected_peers() >= MIN_CONNECTED_VALIDATORS {
            return;
        }
        let validators = self.connected_peers().read().clone();
        if validators.is_empty() {
            return;
        }
        let Some(validator_ip) = validators.into_iter().choose(&mut rand::thread_rng()) else {
            return;
        };
        let self_ = self.clone();
        tokio::spawn(async move {
            // Record the request first, so that the response can be matched against it.
            self_.cache.increment_outbound_validators_requests(validator_ip);
            let _ = Transport::send(&self_, validator_ip, Event::ValidatorsRequest(ValidatorsRequest)).await;
        });
    }

    /// Asks the TCP stack to remove old bans, passing `IP_BAN_TIME_IN_SECS`.
    fn handle_banned_ips(&self) {
        let banned_peers = self.tcp.banned_peers();
        banned_peers.remove_old_bans(IP_BAN_TIME_IN_SECS);
    }
}
#[async_trait]
impl<N: Network> Transport<N> for Gateway<N> {
/// Sends `event` to `peer_ip`, waiting as long as the outbound counter for the event's category
/// exceeds its limit.
///
/// Returns `None` if the event could not be handed to the connection (see `send_inner`).
async fn send(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>> {
// Waits in 10 ms steps while the counter returned by `$cache_map` exceeds the limit returned
// by `$freq`, then sends the event. Note that the counter is updated on every iteration.
// The macro uses `peer_ip` and `event` from the enclosing function.
macro_rules! send {
($self:ident, $cache_map:ident, $interval:expr, $freq:ident) => {{
while $self.cache.$cache_map(peer_ip, $interval) > $self.$freq() {
tokio::time::sleep(Duration::from_millis(10)).await;
}
$self.send_inner(peer_ip, event)
}};
}
match event {
Event::CertificateRequest(_) | Event::CertificateResponse(_) => {
// Also count the event in the general outbound event counter.
self.cache.insert_outbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
send!(self, insert_outbound_certificate, CACHE_REQUESTS_INTERVAL, max_cache_certificates)
}
Event::TransmissionRequest(_) | Event::TransmissionResponse(_) => {
// Also count the event in the general outbound event counter.
self.cache.insert_outbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
send!(self, insert_outbound_transmission, CACHE_REQUESTS_INTERVAL, max_cache_transmissions)
}
Event::BlockRequest(request) => {
// Remember the request, so that the response can be matched against it.
self.cache.insert_outbound_block_request(peer_ip, request);
send!(self, insert_outbound_event, CACHE_EVENTS_INTERVAL, max_cache_events)
}
_ => {
send!(self, insert_outbound_event, CACHE_EVENTS_INTERVAL, max_cache_events)
}
}
}
/// Sends `event` to every currently connected peer from a single background task, one peer
/// after another. The results of the individual sends are ignored.
fn broadcast(&self, event: Event<N>) {
if self.number_of_connected_peers() > 0 {
let self_ = self.clone();
// Take a snapshot of the peers, so the lock is not held while sending.
let connected_peers = self.connected_peers.read().clone();
tokio::spawn(async move {
for peer_ip in connected_peers {
let _ = Transport::send(&self_, peer_ip, event.clone()).await;
}
});
}
}
}
impl<N: Network> P2P for Gateway<N> {
/// Returns the underlying TCP stack.
fn tcp(&self) -> &Tcp {
&self.tcp
}
}
#[async_trait]
impl<N: Network> Reading for Gateway<N> {
    type Codec = EventCodec<N>;
    type Message = Event<N>;

    /// The depth of the inbound message queue.
    const MESSAGE_QUEUE_DEPTH: usize = 2
        * BatchHeader::<N>::MAX_GC_ROUNDS
        * Committee::<N>::MAX_COMMITTEE_SIZE as usize
        * BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH;

    /// Creates a codec with the default settings for every connection.
    fn codec(&self, _peer_addr: SocketAddr, _side: ConnectionSide) -> Self::Codec {
        Default::default()
    }

    /// Processes a message received from `peer_addr`.
    ///
    /// If handling the message fails and the peer can be resolved, the peer is notified and
    /// disconnected in a background task. This method itself always returns `Ok(())`.
    async fn process_message(&self, peer_addr: SocketAddr, message: Self::Message) -> io::Result<()> {
        // Nothing more to do if the message was handled successfully.
        let Err(error) = self.inbound(peer_addr, message).await else {
            return Ok(());
        };
        // Without a known listener address there is no peer to disconnect from.
        let Some(peer_ip) = self.resolver.get_listener(peer_addr) else {
            return Ok(());
        };
        warn!("{CONTEXT} Disconnecting from '{peer_ip}' - {error}");
        let self_ = self.clone();
        tokio::spawn(async move {
            Transport::send(&self_, peer_ip, DisconnectReason::ProtocolViolation.into()).await;
            self_.disconnect(peer_ip);
        });
        Ok(())
    }
}
#[async_trait]
impl<N: Network> Writing for Gateway<N> {
type Codec = EventCodec<N>;
type Message = Event<N>;
// The depth of the outbound message queue; the same expression as for `Reading`.
const MESSAGE_QUEUE_DEPTH: usize = 2
* BatchHeader::<N>::MAX_GC_ROUNDS
* Committee::<N>::MAX_COMMITTEE_SIZE as usize
* BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH;
/// Creates a codec with the default settings for every connection.
fn codec(&self, _peer_addr: SocketAddr, _side: ConnectionSide) -> Self::Codec {
Default::default()
}
}
#[async_trait]
impl<N: Network> Disconnect for Gateway<N> {
    /// Cleans up after the connection at `peer_addr` was closed: unregisters the peer and clears
    /// its outbound validators and block requests from the cache. Nothing happens if the peer
    /// cannot be resolved.
    async fn handle_disconnect(&self, peer_addr: SocketAddr) {
        let Some(peer_ip) = self.resolver.get_listener(peer_addr) else {
            return;
        };
        self.remove_connected_peer(peer_ip);
        self.cache.clear_outbound_validators_requests(peer_ip);
        self.cache.clear_outbound_block_requests(peer_ip);
    }
}
#[async_trait]
impl<N: Network> OnConnect for Gateway<N> {
    /// Intentionally does nothing when a connection is established.
    async fn on_connect(&self, _peer_addr: SocketAddr) {}
}
#[async_trait]
impl<N: Network> Handshake for Gateway<N> {
/// Performs the handshake for a new connection.
///
/// `connection.side()` describes the role of the peer: if the peer is the `Initiator`, this node
/// runs the responder side of the handshake, and vice versa.
///
/// # Errors
/// Returns an error if the peer's IP is banned or spamming connections (non-test, non-dev
/// builds only), or if the handshake itself fails.
async fn perform_handshake(&self, mut connection: Connection) -> io::Result<Connection> {
let peer_addr = connection.addr();
let peer_side = connection.side();
// For inbound connections outside of development mode, enforce the IP ban list and the
// limit on connection attempts per IP.
#[cfg(not(any(test)))]
if self.dev().is_none() && peer_side == ConnectionSide::Initiator {
if self.is_ip_banned(peer_addr.ip()) {
trace!("{CONTEXT} Gateway rejected a connection request from banned IP '{}'", peer_addr.ip());
return Err(error(format!("'{}' is a banned IP address", peer_addr.ip())));
}
let num_attempts = self.cache.insert_inbound_connection(peer_addr.ip(), CONNECTION_ATTEMPTS_SINCE_SECS);
debug!("Number of connection attempts from '{}': {}", peer_addr.ip(), num_attempts);
// Too many attempts: ban the IP and reject the connection.
if num_attempts > MAX_CONNECTION_ATTEMPTS {
self.update_ip_ban(peer_addr.ip());
trace!("{CONTEXT} Gateway rejected a consecutive connection request from IP '{}'", peer_addr.ip());
return Err(error(format!("'{}' appears to be spamming connections", peer_addr.ip())));
}
}
let stream = self.borrow_stream(&mut connection);
// For outbound connections the peer's listener address is the address that was dialed; for
// inbound connections it only becomes known once the peer sent its challenge request.
let mut peer_ip = if peer_side == ConnectionSide::Initiator {
debug!("{CONTEXT} Gateway received a connection request from '{peer_addr}'");
None
} else {
debug!("{CONTEXT} Gateway is connecting to {peer_addr}...");
Some(peer_addr)
};
let restrictions_id = self.ledger.latest_restrictions_id();
// `peer_ip` is passed mutably to the responder side, so that it is available for the
// cleanup below even if the handshake fails midway.
let handshake_result = if peer_side == ConnectionSide::Responder {
self.handshake_inner_initiator(peer_addr, peer_ip, restrictions_id, stream).await
} else {
self.handshake_inner_responder(peer_addr, &mut peer_ip, restrictions_id, stream).await
};
// The handshake is over (successfully or not): the peer is no longer connecting.
// NOTE(review): when the responder side failed because the peer was already present in
// `connecting_peers`, this appears to also remove the entry that belongs to the other
// in-flight handshake - confirm that this is intended.
if let Some(ip) = peer_ip {
self.connecting_peers.lock().shift_remove(&ip);
}
let (ref peer_ip, _) = handshake_result?;
info!("{CONTEXT} Gateway is connected to '{peer_ip}'");
Ok(connection)
}
}
// Reads the next event from `$framed` and evaluates to its payload if it is of the variant
// `$event_ty`. Otherwise it returns an error from the enclosing function: a read error is
// propagated with `?`, and a disconnect event, any other event, or the end of the stream are
// turned into an error that mentions `$peer_addr`.
macro_rules! expect_event {
($event_ty:path, $framed:expr, $peer_addr:expr) => {
match $framed.try_next().await? {
// The expected event was received.
Some($event_ty(data)) => {
trace!("{CONTEXT} Gateway received '{}' from '{}'", data.name(), $peer_addr);
data
}
// The peer chose to disconnect.
Some(Event::Disconnect(reason)) => {
return Err(error(format!("{CONTEXT} '{}' disconnected: {reason:?}", $peer_addr)));
}
// Any other event violates the handshake protocol.
Some(ty) => {
return Err(error(format!(
"{CONTEXT} '{}' did not follow the handshake protocol: received {:?} instead of {}",
$peer_addr,
ty.name(),
stringify!($event_ty),
)))
}
// The stream ended before the expected event arrived.
None => {
return Err(error(format!(
"{CONTEXT} '{}' disconnected before sending {:?}",
$peer_addr,
stringify!($event_ty)
)))
}
}
};
}
/// Logs and sends `event` over the framed stream during the handshake.
///
/// `peer_addr` is only used for logging.
///
/// # Errors
/// Returns the I/O error reported by the framed stream.
async fn send_event<N: Network>(
framed: &mut Framed<&mut TcpStream, EventCodec<N>>,
peer_addr: SocketAddr,
event: Event<N>,
) -> io::Result<()> {
trace!("{CONTEXT} Gateway is sending '{}' to '{peer_addr}'", event.name());
framed.send(event).await
}
impl<N: Network> Gateway<N> {
/// Runs the handshake for a connection that this node initiated.
///
/// On success, the peer is registered as connected and its listener address is returned together
/// with the framed stream.
///
/// # Panics
/// Panics if `peer_ip` is `None`.
///
/// # Errors
/// Returns an error if the stream fails, the peer deviates from the protocol, or the peer's
/// challenge request or response is rejected.
async fn handshake_inner_initiator<'a>(
&'a self,
peer_addr: SocketAddr,
peer_ip: Option<SocketAddr>,
restrictions_id: Field<N>,
stream: &'a mut TcpStream,
) -> io::Result<(SocketAddr, Framed<&mut TcpStream, EventCodec<N>>)> {
// The caller provides the listener address for connections initiated by this node.
let peer_ip = peer_ip.unwrap();
let mut framed = Framed::new(stream, EventCodec::<N>::handshake());
let rng = &mut rand::rngs::OsRng;
// Step 1: send our challenge request with a fresh nonce.
let our_nonce = rng.gen();
let our_request = ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce);
send_event(&mut framed, peer_addr, Event::ChallengeRequest(our_request)).await?;
// Step 2: receive the peer's challenge response, followed by its challenge request.
let peer_response = expect_event!(Event::ChallengeResponse, framed, peer_addr);
let peer_request = expect_event!(Event::ChallengeRequest, framed, peer_addr);
// Verify the response; on failure, tell the peer the reason and abort.
if let Some(reason) = self
.verify_challenge_response(peer_addr, peer_request.address, peer_response, restrictions_id, our_nonce)
.await
{
send_event(&mut framed, peer_addr, reason.into()).await?;
return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
}
// Verify the request; on failure, tell the peer the reason and abort.
if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
send_event(&mut framed, peer_addr, reason.into()).await?;
return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
}
// Step 3: sign the peer's nonce concatenated with a fresh response nonce and send our response.
let response_nonce: u64 = rng.gen();
let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
};
let our_response =
ChallengeResponse { restrictions_id, signature: Data::Object(our_signature), nonce: response_nonce };
send_event(&mut framed, peer_addr, Event::ChallengeResponse(our_response)).await?;
// The handshake succeeded: register the peer.
self.insert_connected_peer(peer_ip, peer_addr, peer_request.address);
Ok((peer_ip, framed))
}
/// Runs the handshake for a connection that the peer initiated.
///
/// `peer_ip` is an out-parameter: it is set to the peer's listener address as soon as the peer's
/// challenge request was received, so that the caller knows it even if a later step fails.
///
/// On success, the peer is registered as connected and its listener address is returned together
/// with the framed stream.
///
/// # Errors
/// Returns an error if the stream fails, the peer deviates from the protocol, the peer is this
/// node or is not allowed to connect, or the peer's challenge request or response is rejected.
async fn handshake_inner_responder<'a>(
&'a self,
peer_addr: SocketAddr,
peer_ip: &mut Option<SocketAddr>,
restrictions_id: Field<N>,
stream: &'a mut TcpStream,
) -> io::Result<(SocketAddr, Framed<&mut TcpStream, EventCodec<N>>)> {
let mut framed = Framed::new(stream, EventCodec::<N>::handshake());
// Step 1: receive the peer's challenge request.
let peer_request = expect_event!(Event::ChallengeRequest, framed, peer_addr);
// Refuse connections from a peer that uses this node's own address.
if self.account.address() == peer_request.address {
return Err(error("Skipping request to connect to self".to_string()));
}
// The listener address is the IP of the connection combined with the advertised listener port.
*peer_ip = Some(SocketAddr::new(peer_addr.ip(), peer_request.listener_port));
let peer_ip = peer_ip.unwrap();
// Ensure the peer is allowed to connect; this also marks it as connecting.
if let Err(forbidden_message) = self.ensure_peer_is_allowed(peer_ip) {
return Err(error(format!("{forbidden_message}")));
}
// Verify the request; on failure, tell the peer the reason and abort.
if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
send_event(&mut framed, peer_addr, reason.into()).await?;
return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
}
// Step 2: sign the peer's nonce concatenated with a fresh response nonce and send our response,
// followed by our own challenge request.
let rng = &mut rand::rngs::OsRng;
let response_nonce: u64 = rng.gen();
let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
};
let our_response =
ChallengeResponse { restrictions_id, signature: Data::Object(our_signature), nonce: response_nonce };
send_event(&mut framed, peer_addr, Event::ChallengeResponse(our_response)).await?;
let our_nonce = rng.gen();
let our_request = ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce);
send_event(&mut framed, peer_addr, Event::ChallengeRequest(our_request)).await?;
// Step 3: receive and verify the peer's challenge response.
let peer_response = expect_event!(Event::ChallengeResponse, framed, peer_addr);
if let Some(reason) = self
.verify_challenge_response(peer_addr, peer_request.address, peer_response, restrictions_id, our_nonce)
.await
{
send_event(&mut framed, peer_addr, reason.into()).await?;
return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
}
// The handshake succeeded: register the peer.
self.insert_connected_peer(peer_ip, peer_addr, peer_request.address);
Ok((peer_ip, framed))
}
/// Checks a peer's challenge request and returns the reason to disconnect, if there is one.
///
/// The checks run in this order: the event version must not be older than ours, the address
/// must be an authorized validator address, and the address must not already be connected.
/// Each failure is logged. `None` means the request passed all three checks.
fn verify_challenge_request(&self, peer_addr: SocketAddr, event: &ChallengeRequest<N>) -> Option<DisconnectReason> {
    // Only the version and the address are inspected here.
    let (version, address) = (event.version, event.address);

    if version < Event::<N>::VERSION {
        warn!("{CONTEXT} Gateway is dropping '{peer_addr}' on version {version} (outdated)");
        Some(DisconnectReason::OutdatedClientVersion)
    } else if !self.is_authorized_validator_address(address) {
        warn!("{CONTEXT} Gateway is dropping '{peer_addr}' for being an unauthorized validator ({address})");
        Some(DisconnectReason::ProtocolViolation)
    } else if self.is_connected_address(address) {
        warn!("{CONTEXT} Gateway is dropping '{peer_addr}' for being already connected ({address})");
        Some(DisconnectReason::ProtocolViolation)
    } else {
        None
    }
}
/// Checks a peer's challenge response and returns the reason to disconnect, if there is one.
///
/// The response must carry the expected restrictions ID, and its signature must deserialize
/// and verify under `peer_address` over `expected_nonce` followed by the response's own nonce
/// (both little-endian). Each failure is logged and reported as an invalid challenge response.
async fn verify_challenge_response(
    &self,
    peer_addr: SocketAddr,
    peer_address: Address<N>,
    response: ChallengeResponse<N>,
    expected_restrictions_id: Field<N>,
    expected_nonce: u64,
) -> Option<DisconnectReason> {
    let ChallengeResponse { restrictions_id, signature, nonce: peer_nonce } = response;

    // Both sides must agree on the restrictions ID.
    if restrictions_id != expected_restrictions_id {
        warn!("{CONTEXT} Gateway handshake with '{peer_addr}' failed (incorrect restrictions ID)");
        return Some(DisconnectReason::InvalidChallengeResponse);
    }

    // Deserialize the signature through the blocking helper.
    let signature = match spawn_blocking!(signature.deserialize_blocking()) {
        Ok(signature) => signature,
        Err(_) => {
            warn!("{CONTEXT} Gateway handshake with '{peer_addr}' failed (cannot deserialize the signature)");
            return Some(DisconnectReason::InvalidChallengeResponse);
        }
    };

    // The signed payload is our nonce followed by the nonce chosen by the peer.
    let signed_bytes = [expected_nonce.to_le_bytes(), peer_nonce.to_le_bytes()].concat();
    if signature.verify_bytes(&peer_address, &signed_bytes) {
        None
    } else {
        warn!("{CONTEXT} Gateway handshake with '{peer_addr}' failed (invalid signature)");
        Some(DisconnectReason::InvalidChallengeResponse)
    }
}
}
#[cfg(test)]
mod prop_tests {
use crate::{
Gateway,
MAX_WORKERS,
MEMORY_POOL_PORT,
Worker,
gateway::prop_tests::GatewayAddress::{Dev, Prod},
helpers::{Storage, init_primary_channels, init_worker_channels},
};
use snarkos_account::Account;
use snarkos_node_bft_ledger_service::MockLedgerService;
use snarkos_node_bft_storage_service::BFTMemoryService;
use snarkos_node_tcp::P2P;
use snarkvm::{
ledger::{
committee::{
Committee,
prop_tests::{CommitteeContext, ValidatorSet},
test_helpers::sample_committee_for_round_and_members,
},
narwhal::{BatchHeader, batch_certificate::test_helpers::sample_batch_certificate_for_round},
},
prelude::{MainnetV0, PrivateKey},
utilities::TestRng,
};
use indexmap::{IndexMap, IndexSet};
use proptest::{
prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any, any_with},
sample::Selector,
};
use std::{
fmt::{Debug, Formatter},
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
};
use test_strategy::proptest;
/// The concrete network that the property tests in this module are instantiated with.
type CurrentNetwork = MainnetV0;
/// Test-only `Debug` rendering of a gateway: a `Gateway` tuple holding the account address
/// and the TCP configuration.
impl Debug for Gateway<CurrentNetwork> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let address = self.account.address();
        let mut rendering = f.debug_tuple("Gateway");
        rendering.field(&address);
        rendering.field(&self.tcp.config());
        rendering.finish()
    }
}
/// The address inputs that the property tests pass on to `Gateway::new`.
#[derive(Debug, test_strategy::Arbitrary)]
enum GatewayAddress {
/// A development setup identified by a small number, which `port()` exposes as a `u16`.
Dev(u8),
/// A production setup carrying the optional socket address that `ip()` exposes.
Prod(Option<SocketAddr>),
}
impl GatewayAddress {
fn ip(&self) -> Option<SocketAddr> {
if let GatewayAddress::Prod(ip) = self {
return *ip;
}
None
}
fn port(&self) -> Option<u16> {
if let GatewayAddress::Dev(port) = self {
return Some(*port as u16);
}
None
}
}
impl Arbitrary for Gateway<CurrentNetwork> {
type Parameters = ();
type Strategy = BoxedStrategy<Gateway<CurrentNetwork>>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
any_valid_dev_gateway()
.prop_map(|(storage, _, private_key, address)| {
Gateway::new(
Account::try_from(private_key).unwrap(),
storage.clone(),
storage.ledger().clone(),
address.ip(),
&[],
address.port(),
)
.unwrap()
})
.boxed()
}
}
/// The tuple yielded by the gateway strategies: storage, committee context, the selected
/// validator's private key, and the gateway address.
type GatewayInput = (Storage<CurrentNetwork>, CommitteeContext, PrivateKey<CurrentNetwork>, GatewayAddress);
/// Strategy yielding a `GatewayInput` with a `Dev` address.
///
/// A committee context and a selector are generated first; the selector picks one validator
/// from the context, and the storage is generated from the same context.
fn any_valid_dev_gateway() -> BoxedStrategy<GatewayInput> {
    let seeds = (any::<CommitteeContext>(), any::<Selector>());
    seeds
        .prop_flat_map(|(context, account_selector)| {
            let CommitteeContext(_, ValidatorSet(validators)) = context.clone();
            let storage = any_with::<Storage<CurrentNetwork>>(context.clone());
            let validator = Just(account_selector.select(validators));
            // Any `u8` is accepted as the dev number.
            (storage, Just(context), validator, 0u8..).prop_map(|(storage, context, validator, dev)| {
                (storage, context, validator.private_key, Dev(dev))
            })
        })
        .boxed()
}
/// Strategy yielding a `GatewayInput` with a `Prod` address.
///
/// A committee context and a selector are generated first; the selector picks one validator
/// from the context, and the storage is generated from the same context.
fn any_valid_prod_gateway() -> BoxedStrategy<GatewayInput> {
    let seeds = (any::<CommitteeContext>(), any::<Selector>());
    seeds
        .prop_flat_map(|(context, account_selector)| {
            let CommitteeContext(_, ValidatorSet(validators)) = context.clone();
            let storage = any_with::<Storage<CurrentNetwork>>(context.clone());
            let validator = Just(account_selector.select(validators));
            // The socket address is optional, so both the present and absent cases are generated.
            let socket_addr = any::<Option<SocketAddr>>();
            (storage, Just(context), validator, socket_addr).prop_map(|(storage, context, validator, addr)| {
                (storage, context, validator.private_key, Prod(addr))
            })
        })
        .boxed()
}
// A gateway built from a `Dev` address must listen on localhost at `MEMORY_POOL_PORT` plus the
// dev number, allow up to the maximum committee size of connections, and keep the given account.
#[proptest]
fn gateway_dev_initialization(#[strategy(any_valid_dev_gateway())] input: GatewayInput) {
    let (storage, _, private_key, dev) = input;
    let account = Account::try_from(private_key).unwrap();
    let gateway = {
        let built = Gateway::new(account.clone(), storage.clone(), storage.ledger().clone(), dev.ip(), &[], dev.port());
        built.unwrap()
    };

    // Listener address.
    let listener_config = gateway.tcp().config();
    let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);
    assert_eq!(listener_config.listener_ip, Some(localhost));
    let dev_port = MEMORY_POOL_PORT + dev.port().unwrap();
    assert_eq!(listener_config.desired_listening_port, Some(dev_port));

    // Connection limit and identity.
    assert_eq!(gateway.tcp().config().max_connections, Committee::<CurrentNetwork>::MAX_COMMITTEE_SIZE);
    assert_eq!(gateway.account().address(), account.address());
}
// A gateway built from a `Prod` address must listen on the given socket address, or on the
// unspecified IPv4 address at `MEMORY_POOL_PORT` when none was given. It must also allow up to
// the maximum committee size of connections and keep the given account.
#[proptest]
fn gateway_prod_initialization(#[strategy(any_valid_prod_gateway())] input: GatewayInput) {
    let (storage, _, private_key, prod) = input;
    let account = Account::try_from(private_key).unwrap();
    let gateway = {
        let built =
            Gateway::new(account.clone(), storage.clone(), storage.ledger().clone(), prod.ip(), &[], prod.port());
        built.unwrap()
    };

    // Listener address.
    let listener_config = gateway.tcp().config();
    match prod.ip() {
        Some(socket_addr) => {
            assert_eq!(listener_config.listener_ip, Some(socket_addr.ip()));
            assert_eq!(listener_config.desired_listening_port, Some(socket_addr.port()));
        }
        None => {
            assert_eq!(listener_config.listener_ip, Some(IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
            assert_eq!(listener_config.desired_listening_port, Some(MEMORY_POOL_PORT));
        }
    }

    // Connection limit and identity.
    assert_eq!(gateway.tcp().config().max_connections, Committee::<CurrentNetwork>::MAX_COMMITTEE_SIZE);
    assert_eq!(gateway.account().address(), account.address());
}
#[proptest(async = "tokio")]
async fn gateway_start(
#[strategy(any_valid_dev_gateway())] input: GatewayInput,
#[strategy(0..MAX_WORKERS)] workers_count: u8,
) {
let (storage, committee, private_key, dev) = input;
let committee = committee.0;
let worker_storage = storage.clone();
let account = Account::try_from(private_key).unwrap();
let gateway =
Gateway::new(account, storage.clone(), storage.ledger().clone(), dev.ip(), &[], dev.port()).unwrap();
let (primary_sender, _) = init_primary_channels();
let (workers, worker_senders) = {
let mut tx_workers = IndexMap::new();
let mut workers = IndexMap::new();
for id in 0..workers_count {
let (tx_worker, rx_worker) = init_worker_channels();
let ledger = Arc::new(MockLedgerService::new(committee.clone()));
let worker =
Worker::new(id, Arc::new(gateway.clone()), worker_storage.clone(), ledger, Default::default())
.unwrap();
worker.run(rx_worker);
workers.insert(id, worker);
tx_workers.insert(id, tx_worker);
}
(workers, tx_workers)
};
gateway.run(primary_sender, worker_senders, None).await;
assert_eq!(
gateway.local_ip(),
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), MEMORY_POOL_PORT + dev.port().unwrap())
);
assert_eq!(gateway.num_workers(), workers.len() as u8);
}
// Checks `is_authorized_validator_address` against a sampled committee.
//
// The first `committee_size` certificates supply the committee members; the next
// `committee_size` certificates are sampled afterwards and are expected not to be authorized.
#[proptest]
fn test_is_authorized_validator(#[strategy(any_valid_dev_gateway())] input: GatewayInput) {
    let rng = &mut TestRng::default();

    // Test parameters.
    let current_round = 2;
    let committee_size = 4;
    let max_gc_rounds = BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64;

    // Only the private key and the address of the generated input are used.
    let (_, _, private_key, dev) = input;
    let account = Account::try_from(private_key).unwrap();

    // Sample the certificates whose authors become the committee members.
    let mut certificates = IndexSet::new();
    for _ in 0..committee_size {
        certificates.insert(sample_batch_certificate_for_round(current_round, rng));
    }
    let addresses: Vec<_> = certificates.iter().map(|certificate| certificate.author()).collect();
    let committee = sample_committee_for_round_and_members(current_round, addresses, rng);

    // Sample further certificates after the committee has been fixed.
    for _ in 0..committee_size {
        certificates.insert(sample_batch_certificate_for_round(current_round, rng));
    }

    // Build the ledger, storage and gateway on top of the sampled committee.
    let ledger = Arc::new(MockLedgerService::new(committee.clone()));
    let storage = Storage::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds);
    let gateway = Gateway::new(account, storage.clone(), ledger, dev.ip(), &[], dev.port()).unwrap();

    // Insert every certificate into the storage.
    for certificate in certificates.iter() {
        storage.testing_only_insert_certificate_testing_only(certificate.clone());
    }

    // The set keeps insertion order, so the position tells whether the author is a committee member.
    for (index, certificate) in certificates.iter().enumerate() {
        let is_authorized = gateway.is_authorized_validator_address(certificate.author());
        assert_eq!(
            is_authorized,
            index < committee_size,
            "unexpected authorization result for the author of certificate #{index}"
        );
    }
}
}