1use crate::{
17 CONTEXT,
18 MAX_BATCH_DELAY_IN_MS,
19 MEMORY_POOL_PORT,
20 Worker,
21 events::{EventCodec, PrimaryPing},
22 helpers::{Cache, PrimarySender, Resolver, Storage, SyncSender, WorkerSender, assign_to_worker},
23 spawn_blocking,
24};
25use snarkos_account::Account;
26use snarkos_node_bft_events::{
27 BlockRequest,
28 BlockResponse,
29 CertificateRequest,
30 CertificateResponse,
31 ChallengeRequest,
32 ChallengeResponse,
33 DataBlocks,
34 DisconnectReason,
35 Event,
36 EventTrait,
37 TransmissionRequest,
38 TransmissionResponse,
39 ValidatorsRequest,
40 ValidatorsResponse,
41};
42use snarkos_node_bft_ledger_service::LedgerService;
43use snarkos_node_sync::{MAX_BLOCKS_BEHIND, communication_service::CommunicationService};
44use snarkos_node_tcp::{
45 Config,
46 Connection,
47 ConnectionSide,
48 P2P,
49 Tcp,
50 is_bogon_ip,
51 is_unspecified_or_broadcast_ip,
52 protocols::{Disconnect, Handshake, OnConnect, Reading, Writing},
53};
54use snarkvm::{
55 console::prelude::*,
56 ledger::{
57 committee::Committee,
58 narwhal::{BatchHeader, Data},
59 },
60 prelude::{Address, Field},
61};
62
63use colored::Colorize;
64use futures::SinkExt;
65use indexmap::{IndexMap, IndexSet};
66use parking_lot::{Mutex, RwLock};
67use rand::seq::{IteratorRandom, SliceRandom};
68#[cfg(not(any(test)))]
69use std::net::IpAddr;
70use std::{collections::HashSet, future::Future, io, net::SocketAddr, sync::Arc, time::Duration};
71use tokio::{
72 net::TcpStream,
73 sync::{OnceCell, oneshot},
74 task::{self, JoinHandle},
75};
76use tokio_stream::StreamExt;
77use tokio_util::codec::Framed;
78
79const CACHE_EVENTS_INTERVAL: i64 = (MAX_BATCH_DELAY_IN_MS / 1000) as i64; const CACHE_REQUESTS_INTERVAL: i64 = (MAX_BATCH_DELAY_IN_MS / 1000) as i64; const MAX_CONNECTION_ATTEMPTS: usize = 10;
86const RESTRICTED_INTERVAL: i64 = (MAX_CONNECTION_ATTEMPTS as u64 * MAX_BATCH_DELAY_IN_MS / 1000) as i64; const MIN_CONNECTED_VALIDATORS: usize = 175;
91const MAX_VALIDATORS_TO_SEND: usize = 200;
93
94#[cfg(not(any(test)))]
96const CONNECTION_ATTEMPTS_SINCE_SECS: i64 = 10;
97const IP_BAN_TIME_IN_SECS: u64 = 300;
99
100#[async_trait]
103pub trait Transport<N: Network>: Send + Sync {
104 async fn send(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>>;
105 fn broadcast(&self, event: Event<N>);
106}
107
/// The gateway of the memory pool: owns the validator-to-validator TCP stack, tracks
/// connected/connecting peers, and routes inbound events to the primary, workers, and sync module.
///
/// Cloning is cheap for the shared state: the peer sets, cache, resolver, senders and task
/// handles are behind `Arc`s, so clones observe the same state.
#[derive(Clone)]
pub struct Gateway<N: Network> {
    /// The account of this node (its address is compared against advertised validators).
    account: Account<N>,
    /// The storage service; used here for the current round.
    storage: Storage<N>,
    /// The ledger service; used for committee lookups, block retrieval and restrictions ID.
    ledger: Arc<dyn LedgerService<N>>,
    /// The TCP stack.
    tcp: Tcp,
    /// Counters used for inbound/outbound rate limiting and request deduplication.
    cache: Arc<Cache<N>>,
    /// Maps between listener IPs, connection addresses and validator addresses.
    resolver: Arc<Resolver<N>>,
    /// Validators that are always considered authorized and that the heartbeat keeps connecting to.
    trusted_validators: IndexSet<SocketAddr>,
    /// Listener IPs of the currently connected peers.
    connected_peers: Arc<RwLock<IndexSet<SocketAddr>>>,
    /// Listener IPs of peers with a handshake in progress.
    connecting_peers: Arc<Mutex<IndexSet<SocketAddr>>>,
    /// Sender to the primary; set once in `run`.
    primary_sender: Arc<OnceCell<PrimarySender<N>>>,
    /// Senders to the workers, keyed by worker ID; set once in `run`.
    worker_senders: Arc<OnceCell<IndexMap<u8, WorkerSender<N>>>>,
    /// Sender to the sync module; optionally set once in `run`.
    sync_sender: Arc<OnceCell<SyncSender<N>>>,
    /// Handles of tasks started via `spawn`; aborted in `shut_down`.
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    /// Development-mode identifier, if running in dev mode.
    dev: Option<u16>,
}
142
143impl<N: Network> Gateway<N> {
144 pub fn new(
146 account: Account<N>,
147 storage: Storage<N>,
148 ledger: Arc<dyn LedgerService<N>>,
149 ip: Option<SocketAddr>,
150 trusted_validators: &[SocketAddr],
151 dev: Option<u16>,
152 ) -> Result<Self> {
153 let ip = match (ip, dev) {
155 (None, Some(dev)) => SocketAddr::from_str(&format!("127.0.0.1:{}", MEMORY_POOL_PORT + dev))?,
156 (None, None) => SocketAddr::from_str(&format!("0.0.0.0:{}", MEMORY_POOL_PORT))?,
157 (Some(ip), _) => ip,
158 };
159 let tcp = Tcp::new(Config::new(ip, Committee::<N>::MAX_COMMITTEE_SIZE));
161 Ok(Self {
163 account,
164 storage,
165 ledger,
166 tcp,
167 cache: Default::default(),
168 resolver: Default::default(),
169 trusted_validators: trusted_validators.iter().copied().collect(),
170 connected_peers: Default::default(),
171 connecting_peers: Default::default(),
172 primary_sender: Default::default(),
173 worker_senders: Default::default(),
174 sync_sender: Default::default(),
175 handles: Default::default(),
176 dev,
177 })
178 }
179
180 pub async fn run(
182 &self,
183 primary_sender: PrimarySender<N>,
184 worker_senders: IndexMap<u8, WorkerSender<N>>,
185 sync_sender: Option<SyncSender<N>>,
186 ) {
187 debug!("Starting the gateway for the memory pool...");
188
189 self.primary_sender.set(primary_sender).expect("Primary sender already set in gateway");
191
192 self.worker_senders.set(worker_senders).expect("The worker senders are already set");
194
195 if let Some(sync_sender) = sync_sender {
197 self.sync_sender.set(sync_sender).expect("Sync sender already set in gateway");
198 }
199
200 self.enable_handshake().await;
202 self.enable_reading().await;
203 self.enable_writing().await;
204 self.enable_disconnect().await;
205 self.enable_on_connect().await;
206 let _listening_addr = self.tcp.enable_listener().await.expect("Failed to enable the TCP listener");
208
209 self.initialize_heartbeat();
211
212 info!("Started the gateway for the memory pool at '{}'", self.local_ip());
213 }
214}
215
216impl<N: Network> Gateway<N> {
218 fn max_committee_size(&self) -> usize {
220 self.ledger
221 .current_committee()
222 .map_or_else(|_e| Committee::<N>::MAX_COMMITTEE_SIZE as usize, |committee| committee.num_members())
223 }
224
225 fn max_cache_events(&self) -> usize {
227 self.max_cache_transmissions()
228 }
229
230 fn max_cache_certificates(&self) -> usize {
232 2 * BatchHeader::<N>::MAX_GC_ROUNDS * self.max_committee_size()
233 }
234
235 fn max_cache_transmissions(&self) -> usize {
237 self.max_cache_certificates() * BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH
238 }
239
240 fn max_cache_duplicates(&self) -> usize {
242 self.max_committee_size().pow(2)
243 }
244}
245
246#[async_trait]
247impl<N: Network> CommunicationService for Gateway<N> {
248 type Message = Event<N>;
250
251 fn prepare_block_request(start_height: u32, end_height: u32) -> Self::Message {
253 debug_assert!(start_height < end_height, "Invalid block request format");
254 Event::BlockRequest(BlockRequest { start_height, end_height })
255 }
256
257 async fn send(&self, peer_ip: SocketAddr, message: Self::Message) -> Option<oneshot::Receiver<io::Result<()>>> {
263 Transport::send(self, peer_ip, message).await
264 }
265}
266
// Core gateway behavior: accessors, peer authorization, connection bookkeeping,
// inbound event handling, and task lifecycle.
impl<N: Network> Gateway<N> {
    /// Returns the account of this node.
    pub const fn account(&self) -> &Account<N> {
        &self.account
    }

    /// Returns the development-mode identifier, if set.
    pub const fn dev(&self) -> Option<u16> {
        self.dev
    }

    /// Returns the listening address of this node.
    ///
    /// # Panics
    /// Panics if the TCP listener has not been enabled yet (it is enabled in `run`).
    pub fn local_ip(&self) -> SocketAddr {
        self.tcp.listening_addr().expect("The TCP listener is not enabled")
    }

    /// Returns `true` if `ip` refers to this node: either exactly the listening address, or an
    /// unspecified/loopback address with the same port.
    pub fn is_local_ip(&self, ip: SocketAddr) -> bool {
        ip == self.local_ip()
            || (ip.ip().is_unspecified() || ip.ip().is_loopback()) && ip.port() == self.local_ip().port()
    }

    /// Returns `true` if `ip` is not this node, not a bogon address, and not unspecified/broadcast.
    pub fn is_valid_peer_ip(&self, ip: SocketAddr) -> bool {
        !self.is_local_ip(ip) && !is_bogon_ip(ip.ip()) && !is_unspecified_or_broadcast_ip(ip.ip())
    }

    /// Returns the resolver.
    pub fn resolver(&self) -> &Resolver<N> {
        &self.resolver
    }

    /// Returns the primary sender.
    ///
    /// # Panics
    /// Panics if called before `run` has set the primary sender.
    pub fn primary_sender(&self) -> &PrimarySender<N> {
        self.primary_sender.get().expect("Primary sender not set in gateway")
    }

    /// Returns the number of workers.
    ///
    /// # Panics
    /// Panics if the worker senders are not set yet, or if there are more than `u8::MAX` of them.
    pub fn num_workers(&self) -> u8 {
        u8::try_from(self.worker_senders.get().expect("Missing worker senders in gateway").len())
            .expect("Too many workers")
    }

    /// Returns the sender for `worker_id`, or `None` if the senders are unset or the ID is unknown.
    pub fn get_worker_sender(&self, worker_id: u8) -> Option<&WorkerSender<N>> {
        self.worker_senders.get().and_then(|senders| senders.get(&worker_id))
    }

    /// Returns `true` if `address` resolves to a listener IP that is currently connected.
    pub fn is_connected_address(&self, address: Address<N>) -> bool {
        // Look up the listener IP recorded for this address.
        match self.resolver.get_peer_ip_for_address(address) {
            // Known address: check whether that IP is connected.
            Some(peer_ip) => self.is_connected_ip(peer_ip),
            // Unknown address: it cannot be connected.
            None => false,
        }
    }

    /// Returns `true` if `ip` is in the connected-peer set.
    pub fn is_connected_ip(&self, ip: SocketAddr) -> bool {
        self.connected_peers.read().contains(&ip)
    }

    /// Returns `true` if a handshake with `ip` is currently in progress.
    pub fn is_connecting_ip(&self, ip: SocketAddr) -> bool {
        self.connecting_peers.lock().contains(&ip)
    }

    /// Returns `true` if `ip` is a trusted validator, or resolves to an authorized validator address.
    pub fn is_authorized_validator_ip(&self, ip: SocketAddr) -> bool {
        // Trusted validators are always authorized.
        if self.trusted_validators.contains(&ip) {
            return true;
        }
        // Otherwise authorize by the address the resolver associates with this IP.
        match self.resolver.get_address(ip) {
            Some(address) => self.is_authorized_validator_address(address),
            None => false,
        }
    }

    /// Returns `true` if `validator_address` is a member of any of:
    /// 1. the committee lookback for the storage's current round,
    /// 2. the ledger's current committee,
    /// 3. the committee lookback of every second round from the round of the block
    ///    `MAX_BLOCKS_BEHIND` below the latest height, up to (excluding) the current round.
    ///
    /// Lookup errors count as "not a member".
    pub fn is_authorized_validator_address(&self, validator_address: Address<N>) -> bool {
        // Check the committee lookback for the current round.
        if self
            .ledger
            .get_committee_lookback_for_round(self.storage.current_round())
            .map_or(false, |committee| committee.is_committee_member(validator_address))
        {
            return true;
        }

        // Check the current committee.
        if self.ledger.current_committee().map_or(false, |committee| committee.is_committee_member(validator_address)) {
            return true;
        }

        // Fall back to recent rounds. NOTE(review): presumably this keeps validators that recently left
        // the committee (or that this node is lagging behind) connected — confirm intent.
        let previous_block_height = self.ledger.latest_block_height().saturating_sub(MAX_BLOCKS_BEHIND);
        // If the round of that block cannot be loaded, the address is not authorized.
        match self.ledger.get_block_round(previous_block_height) {
            Ok(block_round) => (block_round..self.storage.current_round()).step_by(2).any(|round| {
                self.ledger
                    .get_committee_lookback_for_round(round)
                    .map_or(false, |committee| committee.is_committee_member(validator_address))
            }),
            Err(_) => false,
        }
    }

    /// Returns the maximum number of connections configured on the TCP stack.
    pub fn max_connected_peers(&self) -> usize {
        self.tcp.config().max_connections as usize
    }

    /// Returns the number of connected peers.
    pub fn number_of_connected_peers(&self) -> usize {
        self.connected_peers.read().len()
    }

    /// Returns the addresses of the connected peers that the resolver can map to an address.
    pub fn connected_addresses(&self) -> HashSet<Address<N>> {
        self.connected_peers.read().iter().filter_map(|peer_ip| self.resolver.get_address(*peer_ip)).collect()
    }

    /// Returns the lock-protected set of connected peer IPs.
    pub fn connected_peers(&self) -> &RwLock<IndexSet<SocketAddr>> {
        &self.connected_peers
    }

    /// Attempts to connect to `peer_ip` in a spawned task.
    ///
    /// Returns `None`, after logging a warning, if `check_connection_attempt` rejects the attempt.
    /// Otherwise it returns the handle of the spawned connection task.
    pub fn connect(&self, peer_ip: SocketAddr) -> Option<JoinHandle<()>> {
        // Return early if the attempt is against the connection rules.
        if let Err(forbidden_error) = self.check_connection_attempt(peer_ip) {
            warn!("{forbidden_error}");
            return None;
        }

        let self_ = self.clone();
        Some(tokio::spawn(async move {
            debug!("Connecting to validator {peer_ip}...");
            // On failure, make sure the peer is no longer marked as connecting.
            if let Err(error) = self_.tcp.connect(peer_ip).await {
                self_.connecting_peers.lock().shift_remove(&peer_ip);
                warn!("Unable to connect to '{peer_ip}' - {error}");
            }
        }))
    }

    /// Validates an outbound connection attempt to `peer_ip`.
    ///
    /// # Errors
    /// Fails if `peer_ip` is this node, if the connection limit is reached, or if the peer is
    /// already connected or connecting.
    fn check_connection_attempt(&self, peer_ip: SocketAddr) -> Result<()> {
        if self.is_local_ip(peer_ip) {
            bail!("{CONTEXT} Dropping connection attempt to '{peer_ip}' (attempted to self-connect)")
        }
        if self.number_of_connected_peers() >= self.max_connected_peers() {
            bail!("{CONTEXT} Dropping connection attempt to '{peer_ip}' (maximum peers reached)")
        }
        if self.is_connected_ip(peer_ip) {
            bail!("{CONTEXT} Dropping connection attempt to '{peer_ip}' (already connected)")
        }
        if self.is_connecting_ip(peer_ip) {
            bail!("{CONTEXT} Dropping connection attempt to '{peer_ip}' (already connecting)")
        }
        Ok(())
    }

    /// Validates an inbound connection request whose listener address is `peer_ip`.
    ///
    /// Side effect: inserts `peer_ip` into `connecting_peers` (if not already present) before the
    /// remaining checks. It does not remove the entry on failure. NOTE(review): presumably the
    /// handshake caller removes it — confirm.
    ///
    /// # Errors
    /// Fails if `peer_ip` is this node, is already being connected to, is already connected, or
    /// (for non-loopback IPs) has made more than `MAX_CONNECTION_ATTEMPTS` attempts within `RESTRICTED_INTERVAL`.
    fn ensure_peer_is_allowed(&self, peer_ip: SocketAddr) -> Result<()> {
        if self.is_local_ip(peer_ip) {
            bail!("{CONTEXT} Dropping connection request from '{peer_ip}' (attempted to self-connect)")
        }
        // `insert` returns false if a handshake with this IP is already in progress.
        if !self.connecting_peers.lock().insert(peer_ip) {
            bail!("{CONTEXT} Dropping connection request from '{peer_ip}' (already shaking hands as the initiator)")
        }
        if self.is_connected_ip(peer_ip) {
            bail!("{CONTEXT} Dropping connection request from '{peer_ip}' (already connected)")
        }
        // Loopback peers (e.g. local dev setups) are exempt from the attempt limit.
        if !peer_ip.ip().is_loopback() {
            // Record this attempt and get the number of attempts within the window.
            let num_attempts = self.cache.insert_inbound_connection(peer_ip.ip(), RESTRICTED_INTERVAL);
            if num_attempts > MAX_CONNECTION_ATTEMPTS {
                // NOTE(review): unlike the other messages, this one lacks the `{CONTEXT}` prefix.
                bail!("Dropping connection request from '{peer_ip}' (tried {num_attempts} times)")
            }
        }
        Ok(())
    }

    /// Returns `true` if the TCP stack currently bans `ip`.
    #[cfg(not(any(test)))]
    fn is_ip_banned(&self, ip: IpAddr) -> bool {
        self.tcp.banned_peers().is_ip_banned(&ip)
    }

    /// Records or refreshes a ban for `ip` in the TCP stack.
    #[cfg(not(any(test)))]
    fn update_ip_ban(&self, ip: IpAddr) {
        self.tcp.banned_peers().update_ip_ban(ip);
    }

    /// Publishes the connected and connecting peer counts as gauges.
    #[cfg(feature = "metrics")]
    fn update_metrics(&self) {
        metrics::gauge(metrics::bft::CONNECTED, self.connected_peers.read().len() as f64);
        metrics::gauge(metrics::bft::CONNECTING, self.connecting_peers.lock().len() as f64);
    }

    /// Registers a connected peer: records it in the resolver and adds `peer_ip` to the connected set.
    #[cfg(not(test))]
    fn insert_connected_peer(&self, peer_ip: SocketAddr, peer_addr: SocketAddr, address: Address<N>) {
        self.resolver.insert_peer(peer_ip, peer_addr, address);
        self.connected_peers.write().insert(peer_ip);
        #[cfg(feature = "metrics")]
        self.update_metrics();
    }

    /// Test-only public variant of `insert_connected_peer` (it does not update metrics).
    #[cfg(test)]
    pub fn insert_connected_peer(&self, peer_ip: SocketAddr, peer_addr: SocketAddr, address: Address<N>) {
        self.resolver.insert_peer(peer_ip, peer_addr, address);
        self.connected_peers.write().insert(peer_ip);
    }

    /// Unregisters `peer_ip`.
    ///
    /// This notifies the sync module (if present) from a spawned task, removes the peer from the
    /// resolver and the connected set, and updates metrics.
    fn remove_connected_peer(&self, peer_ip: SocketAddr) {
        if let Some(sync_sender) = self.sync_sender.get() {
            let tx_block_sync_remove_peer_ = sync_sender.tx_block_sync_remove_peer.clone();
            // Notify asynchronously so this (non-async) method does not block on the channel.
            tokio::spawn(async move {
                if let Err(e) = tx_block_sync_remove_peer_.send(peer_ip).await {
                    warn!("Unable to remove '{peer_ip}' from the sync module - {e}");
                }
            });
        }
        self.resolver.remove_peer(peer_ip);
        self.connected_peers.write().shift_remove(&peer_ip);
        #[cfg(feature = "metrics")]
        self.update_metrics();
    }

    /// Writes `event` to the connection behind listener IP `peer_ip`, without rate limiting.
    ///
    /// Returns `None` if the IP cannot be resolved to a connection address, or if the write cannot
    /// be queued. In the latter case the peer is also disconnected.
    fn send_inner(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>> {
        // Map the listener IP to the actual connection address.
        let Some(peer_addr) = self.resolver.get_ambiguous(peer_ip) else {
            warn!("Unable to resolve the listener IP address '{peer_ip}'");
            return None;
        };
        let name = event.name();
        trace!("{CONTEXT} Sending '{name}' to '{peer_ip}'");
        let result = self.unicast(peer_addr, event);
        // A failure to queue the event is treated as a broken connection.
        if let Err(e) = &result {
            warn!("{CONTEXT} Failed to send '{name}' to '{peer_ip}': {e}");
            debug!("{CONTEXT} Disconnecting from '{peer_ip}' (unable to send)");
            self.disconnect(peer_ip);
        }
        result.ok()
    }

    /// Handles one inbound `event` received on connection `peer_addr`.
    ///
    /// Processing order:
    /// 1. resolve the listener IP and require an authorized validator;
    /// 2. count the event against the peer's budget (`max_cache_events`), erroring if exceeded;
    /// 3. for certificate/transmission/block requests and responses, silently ignore the event once
    ///    its duplicate count reaches `max_cache_duplicates`;
    /// 4. dispatch to the primary, workers, or sync module, or answer directly.
    ///
    /// # Errors
    /// Returns an error for unresolvable/unauthorized peers, spam, protocol violations, malformed
    /// or unsolicited responses, and peer-sent disconnects. `process_message` reacts to any error by
    /// disconnecting the peer.
    async fn inbound(&self, peer_addr: SocketAddr, event: Event<N>) -> Result<()> {
        let Some(peer_ip) = self.resolver.get_listener(peer_addr) else {
            bail!("{CONTEXT} Unable to resolve the (ambiguous) peer address '{peer_addr}'")
        };
        if !self.is_authorized_validator_ip(peer_ip) {
            bail!("{CONTEXT} Dropping '{}' from '{peer_ip}' (not authorized)", event.name())
        }
        // Per-peer event budget within `CACHE_EVENTS_INTERVAL`.
        let num_events = self.cache.insert_inbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
        if num_events >= self.max_cache_events() {
            bail!("Dropping '{peer_ip}' for spamming events (num_events = {num_events})")
        }
        // Deduplication: past the duplicate limit, the event is dropped without an error.
        match event {
            Event::CertificateRequest(_) | Event::CertificateResponse(_) => {
                // Extract the certificate ID from whichever variant this is.
                let certificate_id = match &event {
                    Event::CertificateRequest(CertificateRequest { certificate_id }) => *certificate_id,
                    Event::CertificateResponse(CertificateResponse { certificate }) => certificate.id(),
                    _ => unreachable!(),
                };
                let num_events = self.cache.insert_inbound_certificate(certificate_id, CACHE_REQUESTS_INTERVAL);
                if num_events >= self.max_cache_duplicates() {
                    return Ok(());
                }
            }
            Event::TransmissionRequest(TransmissionRequest { transmission_id })
            | Event::TransmissionResponse(TransmissionResponse { transmission_id, .. }) => {
                let num_events = self.cache.insert_inbound_transmission(transmission_id, CACHE_REQUESTS_INTERVAL);
                if num_events >= self.max_cache_duplicates() {
                    return Ok(());
                }
            }
            Event::BlockRequest(_) => {
                // Block requests are counted per peer rather than per item.
                let num_events = self.cache.insert_inbound_block_request(peer_ip, CACHE_REQUESTS_INTERVAL);
                if num_events >= self.max_cache_duplicates() {
                    return Ok(());
                }
            }
            _ => {}
        }
        trace!("{CONTEXT} Received '{}' from '{peer_ip}'", event.name());

        match event {
            // Batch events go to the primary. Channel send errors are ignored.
            Event::BatchPropose(batch_propose) => {
                let _ = self.primary_sender().tx_batch_propose.send((peer_ip, batch_propose)).await;
                Ok(())
            }
            Event::BatchSignature(batch_signature) => {
                let _ = self.primary_sender().tx_batch_signature.send((peer_ip, batch_signature)).await;
                Ok(())
            }
            Event::BatchCertified(batch_certified) => {
                let _ = self.primary_sender().tx_batch_certified.send((peer_ip, batch_certified.certificate)).await;
                Ok(())
            }
            Event::BlockRequest(block_request) => {
                let BlockRequest { start_height, end_height } = block_request;

                // Reject empty or inverted ranges.
                if start_height >= end_height {
                    bail!("Block request from '{peer_ip}' has an invalid range ({start_height}..{end_height})")
                }
                // Reject ranges larger than a single response may carry.
                if end_height - start_height > DataBlocks::<N>::MAXIMUM_NUMBER_OF_BLOCKS as u32 {
                    bail!("Block request from '{peer_ip}' has an excessive range ({start_height}..{end_height})")
                }

                // Load the blocks on the blocking thread pool to keep ledger I/O off the async executor.
                let self_ = self.clone();
                let blocks = match task::spawn_blocking(move || {
                    match self_.ledger.get_blocks(start_height..end_height) {
                        Ok(blocks) => Ok(Data::Object(DataBlocks(blocks))),
                        Err(error) => bail!("Missing blocks {start_height} to {end_height} from ledger - {error}"),
                    }
                })
                .await
                {
                    Ok(Ok(blocks)) => blocks,
                    Ok(Err(error)) => return Err(error),
                    // The blocking task itself failed (e.g. panicked or was cancelled).
                    Err(error) => return Err(anyhow!("[BlockRequest] {error}")),
                };

                // Reply from a separate task so the read loop is not held up by outbound rate limiting.
                let self_ = self.clone();
                tokio::spawn(async move {
                    let event = Event::BlockResponse(BlockResponse { request: block_request, blocks });
                    Transport::send(&self_, peer_ip, event).await;
                });
                Ok(())
            }
            Event::BlockResponse(block_response) => {
                // Block responses are only processed when a sync module is attached.
                if let Some(sync_sender) = self.sync_sender.get() {
                    let BlockResponse { request, blocks } = block_response;
                    // Only accept responses that match an outstanding request to this peer.
                    if !self.cache.remove_outbound_block_request(peer_ip, &request) {
                        bail!("Unsolicited block response from '{peer_ip}'")
                    }
                    let blocks = blocks.deserialize().await.map_err(|error| anyhow!("[BlockResponse] {error}"))?;
                    blocks.ensure_response_is_well_formed(peer_ip, request.start_height, request.end_height)?;
                    // Sync failures are logged, not treated as a protocol violation.
                    if let Err(e) = sync_sender.advance_with_sync_blocks(peer_ip, blocks.0).await {
                        warn!("Unable to process block response from '{peer_ip}' - {e}");
                    }
                }
                Ok(())
            }
            Event::CertificateRequest(certificate_request) => {
                if let Some(sync_sender) = self.sync_sender.get() {
                    let _ = sync_sender.tx_certificate_request.send((peer_ip, certificate_request)).await;
                }
                Ok(())
            }
            Event::CertificateResponse(certificate_response) => {
                if let Some(sync_sender) = self.sync_sender.get() {
                    let _ = sync_sender.tx_certificate_response.send((peer_ip, certificate_response)).await;
                }
                Ok(())
            }
            // Challenge events belong to the handshake; receiving one afterwards is a violation.
            Event::ChallengeRequest(..) | Event::ChallengeResponse(..) => {
                bail!("{CONTEXT} Peer '{peer_ip}' is not following the protocol")
            }
            // A peer-initiated disconnect is surfaced as an error so the caller disconnects too.
            Event::Disconnect(disconnect) => {
                bail!("{CONTEXT} {:?}", disconnect.reason)
            }
            Event::PrimaryPing(ping) => {
                let PrimaryPing { version, block_locators, primary_certificate } = ping;

                if version < Event::<N>::VERSION {
                    bail!("Dropping '{peer_ip}' on event version {version} (outdated)");
                }

                if let Some(sync_sender) = self.sync_sender.get() {
                    if let Err(error) = sync_sender.update_peer_locators(peer_ip, block_locators).await {
                        bail!("Validator '{peer_ip}' sent invalid block locators - {error}");
                    }
                }

                let _ = self.primary_sender().tx_primary_ping.send((peer_ip, primary_certificate)).await;
                Ok(())
            }
            Event::TransmissionRequest(request) => {
                // Route to the worker responsible for this transmission ID.
                let Ok(worker_id) = assign_to_worker(request.transmission_id, self.num_workers()) else {
                    warn!("{CONTEXT} Unable to assign transmission ID '{}' to a worker", request.transmission_id);
                    return Ok(());
                };
                if let Some(sender) = self.get_worker_sender(worker_id) {
                    let _ = sender.tx_transmission_request.send((peer_ip, request)).await;
                }
                Ok(())
            }
            Event::TransmissionResponse(response) => {
                let Ok(worker_id) = assign_to_worker(response.transmission_id, self.num_workers()) else {
                    warn!("{CONTEXT} Unable to assign transmission ID '{}' to a worker", response.transmission_id);
                    return Ok(());
                };
                if let Some(sender) = self.get_worker_sender(worker_id) {
                    let _ = sender.tx_transmission_response.send((peer_ip, response)).await;
                }
                Ok(())
            }
            Event::ValidatorsRequest(_) => {
                // In dev mode every connected peer is shared; otherwise only publicly valid IPs are.
                let mut connected_peers: Vec<_> = match self.dev.is_some() {
                    true => self.connected_peers.read().iter().copied().collect(),
                    false => {
                        self.connected_peers.read().iter().copied().filter(|ip| self.is_valid_peer_ip(*ip)).collect()
                    }
                };
                // Shuffle so that truncation to `MAX_VALIDATORS_TO_SEND` picks a random subset.
                connected_peers.shuffle(&mut rand::thread_rng());

                let self_ = self.clone();
                tokio::spawn(async move {
                    let mut validators = IndexMap::with_capacity(MAX_VALIDATORS_TO_SEND);
                    for validator_ip in connected_peers.into_iter().take(MAX_VALIDATORS_TO_SEND) {
                        // Peers the resolver has no address for are skipped.
                        if let Some(validator_address) = self_.resolver.get_address(validator_ip) {
                            validators.insert(validator_ip, validator_address);
                        }
                    }
                    let event = Event::ValidatorsResponse(ValidatorsResponse { validators });
                    Transport::send(&self_, peer_ip, event).await;
                });
                Ok(())
            }
            Event::ValidatorsResponse(response) => {
                let ValidatorsResponse { validators } = response;
                ensure!(validators.len() <= MAX_VALIDATORS_TO_SEND, "{CONTEXT} Received too many validators");
                // Only accept responses to a request this node actually sent.
                if !self.cache.contains_outbound_validators_request(peer_ip) {
                    bail!("{CONTEXT} Received validators response from '{peer_ip}' without a validators request")
                }
                self.cache.decrement_outbound_validators_requests(peer_ip);

                // Only try new connections while below the minimum.
                if self.number_of_connected_peers() < MIN_CONNECTED_VALIDATORS {
                    let self_ = self.clone();
                    tokio::spawn(async move {
                        for (validator_ip, validator_address) in validators {
                            // Dev mode only excludes this node; otherwise the IP must be publicly valid.
                            if self_.dev.is_some() {
                                if self_.is_local_ip(validator_ip) {
                                    continue;
                                }
                            } else {
                                if !self_.is_valid_peer_ip(validator_ip) {
                                    continue;
                                }
                            }

                            // Skip ourselves, known peers, and addresses that are not authorized validators.
                            if self_.account.address() == validator_address {
                                continue;
                            }
                            if self_.is_connected_ip(validator_ip) || self_.is_connecting_ip(validator_ip) {
                                continue;
                            }
                            if self_.is_connected_address(validator_address) {
                                continue;
                            }
                            if !self_.is_authorized_validator_address(validator_address) {
                                continue;
                            }
                            self_.connect(validator_ip);
                        }
                    });
                }
                Ok(())
            }
            Event::WorkerPing(ping) => {
                ensure!(
                    ping.transmission_ids.len() <= Worker::<N>::MAX_TRANSMISSIONS_PER_WORKER_PING,
                    "{CONTEXT} Received too many transmissions"
                );
                let num_workers = self.num_workers();
                // Fan each transmission ID out to its responsible worker.
                for transmission_id in ping.transmission_ids.into_iter() {
                    let Ok(worker_id) = assign_to_worker(transmission_id, num_workers) else {
                        warn!("{CONTEXT} Unable to assign transmission ID '{transmission_id}' to a worker");
                        continue;
                    };
                    if let Some(sender) = self.get_worker_sender(worker_id) {
                        let _ = sender.tx_worker_ping.send((peer_ip, transmission_id)).await;
                    }
                }
                Ok(())
            }
        }
    }

    /// Disconnects from listener IP `peer_ip` in a spawned task and returns the task handle.
    ///
    /// If the IP cannot be resolved to a connection address, the task does nothing.
    pub fn disconnect(&self, peer_ip: SocketAddr) -> JoinHandle<()> {
        let gateway = self.clone();
        tokio::spawn(async move {
            if let Some(peer_addr) = gateway.resolver.get_ambiguous(peer_ip) {
                let _disconnected = gateway.tcp.disconnect(peer_addr).await;
                debug_assert!(_disconnected);
            }
        })
    }

    /// Starts the heartbeat task.
    ///
    /// After a 1 second delay, the task runs `heartbeat` repeatedly, sleeping 15 seconds between passes.
    fn initialize_heartbeat(&self) {
        let self_clone = self.clone();
        self.spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
            info!("Starting the heartbeat of the gateway...");
            loop {
                self_clone.heartbeat();
                tokio::time::sleep(Duration::from_secs(15)).await;
            }
        });
    }

    /// Spawns `future` and keeps its handle so that `shut_down` can abort it.
    // NOTE(review): `spawn` is called by `initialize_heartbeat`, so this `allow` looks unnecessary.
    #[allow(dead_code)]
    fn spawn<T: Future<Output = ()> + Send + 'static>(&self, future: T) {
        self.handles.lock().push(tokio::spawn(future));
    }

    /// Shuts the gateway down: aborts all tasks registered via `spawn`, then shuts down the TCP stack.
    pub async fn shut_down(&self) {
        info!("Shutting down the gateway...");
        self.handles.lock().iter().for_each(|handle| handle.abort());
        self.tcp.shut_down().await;
    }
}
897
898impl<N: Network> Gateway<N> {
899 fn heartbeat(&self) {
901 self.log_connected_validators();
902 self.handle_trusted_validators();
904 self.handle_unauthorized_validators();
906 self.handle_min_connected_validators();
908 self.handle_banned_ips();
910 }
911
912 fn log_connected_validators(&self) {
914 let validators = self.connected_peers().read().clone();
916 let validators_total = self.ledger.current_committee().map_or(0, |c| c.num_members().saturating_sub(1));
918 let total_validators = format!("(of {validators_total} bonded validators)").dimmed();
920 let connections_msg = match validators.len() {
922 0 => "No connected validators".to_string(),
923 num_connected => format!("Connected to {num_connected} validators {total_validators}"),
924 };
925 info!("{connections_msg}");
927 for peer_ip in validators {
928 let address = self.resolver.get_address(peer_ip).map_or("Unknown".to_string(), |a| a.to_string());
929 debug!("{}", format!(" {peer_ip} - {address}").dimmed());
930 }
931 }
932
933 fn handle_trusted_validators(&self) {
935 for validator_ip in &self.trusted_validators {
937 if !self.is_local_ip(*validator_ip)
939 && !self.is_connecting_ip(*validator_ip)
940 && !self.is_connected_ip(*validator_ip)
941 {
942 self.connect(*validator_ip);
944 }
945 }
946 }
947
948 fn handle_unauthorized_validators(&self) {
950 let self_ = self.clone();
951 tokio::spawn(async move {
952 let validators = self_.connected_peers().read().clone();
954 for peer_ip in validators {
956 if !self_.is_authorized_validator_ip(peer_ip) {
958 warn!("{CONTEXT} Disconnecting from '{peer_ip}' - Validator is not in the current committee");
959 Transport::send(&self_, peer_ip, DisconnectReason::ProtocolViolation.into()).await;
960 self_.disconnect(peer_ip);
962 }
963 }
964 });
965 }
966
967 fn handle_min_connected_validators(&self) {
970 if self.number_of_connected_peers() < MIN_CONNECTED_VALIDATORS {
972 let validators = self.connected_peers().read().clone();
974 if validators.is_empty() {
976 return;
977 }
978 if let Some(validator_ip) = validators.into_iter().choose(&mut rand::thread_rng()) {
980 let self_ = self.clone();
981 tokio::spawn(async move {
982 self_.cache.increment_outbound_validators_requests(validator_ip);
984 let _ = Transport::send(&self_, validator_ip, Event::ValidatorsRequest(ValidatorsRequest)).await;
986 });
987 }
988 }
989 }
990
991 fn handle_banned_ips(&self) {
993 self.tcp.banned_peers().remove_old_bans(IP_BAN_TIME_IN_SECS);
994 }
995}
996
997#[async_trait]
998impl<N: Network> Transport<N> for Gateway<N> {
999 async fn send(&self, peer_ip: SocketAddr, event: Event<N>) -> Option<oneshot::Receiver<io::Result<()>>> {
1007 macro_rules! send {
1008 ($self:ident, $cache_map:ident, $interval:expr, $freq:ident) => {{
1009 while $self.cache.$cache_map(peer_ip, $interval) > $self.$freq() {
1011 tokio::time::sleep(Duration::from_millis(10)).await;
1013 }
1014 $self.send_inner(peer_ip, event)
1016 }};
1017 }
1018
1019 match event {
1021 Event::CertificateRequest(_) | Event::CertificateResponse(_) => {
1022 self.cache.insert_outbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
1024 send!(self, insert_outbound_certificate, CACHE_REQUESTS_INTERVAL, max_cache_certificates)
1026 }
1027 Event::TransmissionRequest(_) | Event::TransmissionResponse(_) => {
1028 self.cache.insert_outbound_event(peer_ip, CACHE_EVENTS_INTERVAL);
1030 send!(self, insert_outbound_transmission, CACHE_REQUESTS_INTERVAL, max_cache_transmissions)
1032 }
1033 Event::BlockRequest(request) => {
1034 self.cache.insert_outbound_block_request(peer_ip, request);
1036 send!(self, insert_outbound_event, CACHE_EVENTS_INTERVAL, max_cache_events)
1038 }
1039 _ => {
1040 send!(self, insert_outbound_event, CACHE_EVENTS_INTERVAL, max_cache_events)
1042 }
1043 }
1044 }
1045
1046 fn broadcast(&self, event: Event<N>) {
1050 if self.number_of_connected_peers() > 0 {
1052 let self_ = self.clone();
1053 let connected_peers = self.connected_peers.read().clone();
1054 tokio::spawn(async move {
1055 for peer_ip in connected_peers {
1057 let _ = Transport::send(&self_, peer_ip, event.clone()).await;
1059 }
1060 });
1061 }
1062 }
1063}
1064
1065impl<N: Network> P2P for Gateway<N> {
1066 fn tcp(&self) -> &Tcp {
1068 &self.tcp
1069 }
1070}
1071
1072#[async_trait]
1073impl<N: Network> Reading for Gateway<N> {
1074 type Codec = EventCodec<N>;
1075 type Message = Event<N>;
1076
1077 const MESSAGE_QUEUE_DEPTH: usize = 2
1079 * BatchHeader::<N>::MAX_GC_ROUNDS
1080 * Committee::<N>::MAX_COMMITTEE_SIZE as usize
1081 * BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH;
1082
1083 fn codec(&self, _peer_addr: SocketAddr, _side: ConnectionSide) -> Self::Codec {
1086 Default::default()
1087 }
1088
1089 async fn process_message(&self, peer_addr: SocketAddr, message: Self::Message) -> io::Result<()> {
1091 if let Err(error) = self.inbound(peer_addr, message).await {
1093 if let Some(peer_ip) = self.resolver.get_listener(peer_addr) {
1094 warn!("{CONTEXT} Disconnecting from '{peer_ip}' - {error}");
1095 let self_ = self.clone();
1096 tokio::spawn(async move {
1097 Transport::send(&self_, peer_ip, DisconnectReason::ProtocolViolation.into()).await;
1098 self_.disconnect(peer_ip);
1100 });
1101 }
1102 }
1103 Ok(())
1104 }
1105}
1106
1107#[async_trait]
1108impl<N: Network> Writing for Gateway<N> {
1109 type Codec = EventCodec<N>;
1110 type Message = Event<N>;
1111
1112 const MESSAGE_QUEUE_DEPTH: usize = 2
1114 * BatchHeader::<N>::MAX_GC_ROUNDS
1115 * Committee::<N>::MAX_COMMITTEE_SIZE as usize
1116 * BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH;
1117
1118 fn codec(&self, _peer_addr: SocketAddr, _side: ConnectionSide) -> Self::Codec {
1121 Default::default()
1122 }
1123}
1124
1125#[async_trait]
1126impl<N: Network> Disconnect for Gateway<N> {
1127 async fn handle_disconnect(&self, peer_addr: SocketAddr) {
1129 if let Some(peer_ip) = self.resolver.get_listener(peer_addr) {
1130 self.remove_connected_peer(peer_ip);
1131
1132 self.cache.clear_outbound_validators_requests(peer_ip);
1136 self.cache.clear_outbound_block_requests(peer_ip);
1137 }
1138 }
1139}
1140
1141#[async_trait]
1142impl<N: Network> OnConnect for Gateway<N> {
1143 async fn on_connect(&self, _peer_addr: SocketAddr) {
1144 return;
1145 }
1146}
1147
1148#[async_trait]
1149impl<N: Network> Handshake for Gateway<N> {
1150 async fn perform_handshake(&self, mut connection: Connection) -> io::Result<Connection> {
1152 let peer_addr = connection.addr();
1154 let peer_side = connection.side();
1155
1156 #[cfg(not(any(test)))]
1158 if self.dev().is_none() && peer_side == ConnectionSide::Initiator {
1159 if self.is_ip_banned(peer_addr.ip()) {
1161 trace!("{CONTEXT} Gateway rejected a connection request from banned IP '{}'", peer_addr.ip());
1162 return Err(error(format!("'{}' is a banned IP address", peer_addr.ip())));
1163 }
1164
1165 let num_attempts = self.cache.insert_inbound_connection(peer_addr.ip(), CONNECTION_ATTEMPTS_SINCE_SECS);
1166
1167 debug!("Number of connection attempts from '{}': {}", peer_addr.ip(), num_attempts);
1168 if num_attempts > MAX_CONNECTION_ATTEMPTS {
1169 self.update_ip_ban(peer_addr.ip());
1170 trace!("{CONTEXT} Gateway rejected a consecutive connection request from IP '{}'", peer_addr.ip());
1171 return Err(error(format!("'{}' appears to be spamming connections", peer_addr.ip())));
1172 }
1173 }
1174
1175 let stream = self.borrow_stream(&mut connection);
1176
1177 let mut peer_ip = if peer_side == ConnectionSide::Initiator {
1180 debug!("{CONTEXT} Gateway received a connection request from '{peer_addr}'");
1181 None
1182 } else {
1183 debug!("{CONTEXT} Gateway is connecting to {peer_addr}...");
1184 Some(peer_addr)
1185 };
1186
1187 let restrictions_id = self.ledger.latest_restrictions_id();
1189
1190 let handshake_result = if peer_side == ConnectionSide::Responder {
1192 self.handshake_inner_initiator(peer_addr, peer_ip, restrictions_id, stream).await
1193 } else {
1194 self.handshake_inner_responder(peer_addr, &mut peer_ip, restrictions_id, stream).await
1195 };
1196
1197 if let Some(ip) = peer_ip {
1199 self.connecting_peers.lock().shift_remove(&ip);
1200 }
1201 let (ref peer_ip, _) = handshake_result?;
1202 info!("{CONTEXT} Gateway is connected to '{peer_ip}'");
1203
1204 Ok(connection)
1205 }
1206}
1207
/// Reads the next handshake event from `$framed` and evaluates to its payload if it is `$event_ty`.
///
/// Must be used inside an `async` function returning `io::Result<_>`. Apart from the decoder
/// error it propagates with `?`, it `return`s an `Err` from the enclosing function when:
/// - the peer sends `Event::Disconnect` instead;
/// - the peer sends any other event type;
/// - the stream ends before an event arrives.
macro_rules! expect_event {
    ($event_ty:path, $framed:expr, $peer_addr:expr) => {
        match $framed.try_next().await? {
            // The expected event: log it and yield its payload.
            Some($event_ty(data)) => {
                trace!("{CONTEXT} Gateway received '{}' from '{}'", data.name(), $peer_addr);
                data
            }
            // The peer aborted the handshake.
            Some(Event::Disconnect(reason)) => {
                return Err(error(format!("{CONTEXT} '{}' disconnected: {reason:?}", $peer_addr)));
            }
            // Any other event is a handshake protocol violation.
            Some(ty) => {
                return Err(error(format!(
                    "{CONTEXT} '{}' did not follow the handshake protocol: received {:?} instead of {}",
                    $peer_addr,
                    ty.name(),
                    stringify!($event_ty),
                )))
            }
            // The stream closed before the expected event arrived.
            None => {
                return Err(error(format!(
                    "{CONTEXT} '{}' disconnected before sending {:?}",
                    $peer_addr,
                    stringify!($event_ty)
                )))
            }
        }
    };
}
1241
1242async fn send_event<N: Network>(
1244 framed: &mut Framed<&mut TcpStream, EventCodec<N>>,
1245 peer_addr: SocketAddr,
1246 event: Event<N>,
1247) -> io::Result<()> {
1248 trace!("{CONTEXT} Gateway is sending '{}' to '{peer_addr}'", event.name());
1249 framed.send(event).await
1250}
1251
1252impl<N: Network> Gateway<N> {
1253 async fn handshake_inner_initiator<'a>(
1255 &'a self,
1256 peer_addr: SocketAddr,
1257 peer_ip: Option<SocketAddr>,
1258 restrictions_id: Field<N>,
1259 stream: &'a mut TcpStream,
1260 ) -> io::Result<(SocketAddr, Framed<&mut TcpStream, EventCodec<N>>)> {
1261 let peer_ip = peer_ip.unwrap();
1263
1264 let mut framed = Framed::new(stream, EventCodec::<N>::handshake());
1266
1267 let rng = &mut rand::rngs::OsRng;
1269
1270 let our_nonce = rng.gen();
1274 let our_request = ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce);
1276 send_event(&mut framed, peer_addr, Event::ChallengeRequest(our_request)).await?;
1277
1278 let peer_response = expect_event!(Event::ChallengeResponse, framed, peer_addr);
1282 let peer_request = expect_event!(Event::ChallengeRequest, framed, peer_addr);
1284
1285 if let Some(reason) = self
1287 .verify_challenge_response(peer_addr, peer_request.address, peer_response, restrictions_id, our_nonce)
1288 .await
1289 {
1290 send_event(&mut framed, peer_addr, reason.into()).await?;
1291 return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
1292 }
1293 if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
1295 send_event(&mut framed, peer_addr, reason.into()).await?;
1296 return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
1297 }
1298
1299 let response_nonce: u64 = rng.gen();
1303 let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
1304 let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
1305 return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
1306 };
1307 let our_response =
1309 ChallengeResponse { restrictions_id, signature: Data::Object(our_signature), nonce: response_nonce };
1310 send_event(&mut framed, peer_addr, Event::ChallengeResponse(our_response)).await?;
1311
1312 self.insert_connected_peer(peer_ip, peer_addr, peer_request.address);
1314
1315 Ok((peer_ip, framed))
1316 }
1317
1318 async fn handshake_inner_responder<'a>(
1320 &'a self,
1321 peer_addr: SocketAddr,
1322 peer_ip: &mut Option<SocketAddr>,
1323 restrictions_id: Field<N>,
1324 stream: &'a mut TcpStream,
1325 ) -> io::Result<(SocketAddr, Framed<&mut TcpStream, EventCodec<N>>)> {
1326 let mut framed = Framed::new(stream, EventCodec::<N>::handshake());
1328
1329 let peer_request = expect_event!(Event::ChallengeRequest, framed, peer_addr);
1333
1334 if self.account.address() == peer_request.address {
1336 return Err(error("Skipping request to connect to self".to_string()));
1337 }
1338
1339 *peer_ip = Some(SocketAddr::new(peer_addr.ip(), peer_request.listener_port));
1341 let peer_ip = peer_ip.unwrap();
1342
1343 if let Err(forbidden_message) = self.ensure_peer_is_allowed(peer_ip) {
1345 return Err(error(format!("{forbidden_message}")));
1346 }
1347 if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
1349 send_event(&mut framed, peer_addr, reason.into()).await?;
1350 return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
1351 }
1352
1353 let rng = &mut rand::rngs::OsRng;
1357
1358 let response_nonce: u64 = rng.gen();
1360 let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
1361 let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
1362 return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
1363 };
1364 let our_response =
1366 ChallengeResponse { restrictions_id, signature: Data::Object(our_signature), nonce: response_nonce };
1367 send_event(&mut framed, peer_addr, Event::ChallengeResponse(our_response)).await?;
1368
1369 let our_nonce = rng.gen();
1371 let our_request = ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce);
1373 send_event(&mut framed, peer_addr, Event::ChallengeRequest(our_request)).await?;
1374
1375 let peer_response = expect_event!(Event::ChallengeResponse, framed, peer_addr);
1379 if let Some(reason) = self
1381 .verify_challenge_response(peer_addr, peer_request.address, peer_response, restrictions_id, our_nonce)
1382 .await
1383 {
1384 send_event(&mut framed, peer_addr, reason.into()).await?;
1385 return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
1386 }
1387 self.insert_connected_peer(peer_ip, peer_addr, peer_request.address);
1389
1390 Ok((peer_ip, framed))
1391 }
1392
1393 fn verify_challenge_request(&self, peer_addr: SocketAddr, event: &ChallengeRequest<N>) -> Option<DisconnectReason> {
1395 let &ChallengeRequest { version, listener_port: _, address, nonce: _ } = event;
1397 if version < Event::<N>::VERSION {
1399 warn!("{CONTEXT} Gateway is dropping '{peer_addr}' on version {version} (outdated)");
1400 return Some(DisconnectReason::OutdatedClientVersion);
1401 }
1402 if !self.is_authorized_validator_address(address) {
1404 warn!("{CONTEXT} Gateway is dropping '{peer_addr}' for being an unauthorized validator ({address})");
1405 return Some(DisconnectReason::ProtocolViolation);
1406 }
1407 if self.is_connected_address(address) {
1409 warn!("{CONTEXT} Gateway is dropping '{peer_addr}' for being already connected ({address})");
1410 return Some(DisconnectReason::ProtocolViolation);
1411 }
1412 None
1413 }
1414
1415 async fn verify_challenge_response(
1417 &self,
1418 peer_addr: SocketAddr,
1419 peer_address: Address<N>,
1420 response: ChallengeResponse<N>,
1421 expected_restrictions_id: Field<N>,
1422 expected_nonce: u64,
1423 ) -> Option<DisconnectReason> {
1424 let ChallengeResponse { restrictions_id, signature, nonce } = response;
1426
1427 if restrictions_id != expected_restrictions_id {
1429 warn!("{CONTEXT} Gateway handshake with '{peer_addr}' failed (incorrect restrictions ID)");
1430 return Some(DisconnectReason::InvalidChallengeResponse);
1431 }
1432 let Ok(signature) = spawn_blocking!(signature.deserialize_blocking()) else {
1434 warn!("{CONTEXT} Gateway handshake with '{peer_addr}' failed (cannot deserialize the signature)");
1435 return Some(DisconnectReason::InvalidChallengeResponse);
1436 };
1437 if !signature.verify_bytes(&peer_address, &[expected_nonce.to_le_bytes(), nonce.to_le_bytes()].concat()) {
1439 warn!("{CONTEXT} Gateway handshake with '{peer_addr}' failed (invalid signature)");
1440 return Some(DisconnectReason::InvalidChallengeResponse);
1441 }
1442 None
1443 }
1444}
1445
#[cfg(test)]
mod prop_tests {
    use crate::{
        Gateway,
        MAX_WORKERS,
        MEMORY_POOL_PORT,
        Worker,
        gateway::prop_tests::GatewayAddress::{Dev, Prod},
        helpers::{Storage, init_primary_channels, init_worker_channels},
    };
    use snarkos_account::Account;
    use snarkos_node_bft_ledger_service::MockLedgerService;
    use snarkos_node_bft_storage_service::BFTMemoryService;
    use snarkos_node_tcp::P2P;
    use snarkvm::{
        ledger::{
            committee::{
                Committee,
                prop_tests::{CommitteeContext, ValidatorSet},
                test_helpers::sample_committee_for_round_and_members,
            },
            narwhal::{BatchHeader, batch_certificate::test_helpers::sample_batch_certificate_for_round},
        },
        prelude::{MainnetV0, PrivateKey},
        utilities::TestRng,
    };

    use indexmap::{IndexMap, IndexSet};
    use proptest::{
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any, any_with},
        sample::Selector,
    };
    use std::{
        fmt::{Debug, Formatter},
        net::{IpAddr, Ipv4Addr, SocketAddr},
        sync::Arc,
    };
    use test_strategy::proptest;

    type CurrentNetwork = MainnetV0;

    impl Debug for Gateway<CurrentNetwork> {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            // The validator address and TCP configuration are enough to identify a failing case.
            f.debug_tuple("Gateway").field(&self.account.address()).field(&self.tcp.config()).finish()
        }
    }

    /// How a test gateway is addressed.
    ///
    /// `Dev(n)` selects development mode, where the tests below expect the listener on localhost at
    /// `MEMORY_POOL_PORT + n`. `Prod(ip)` passes an optional explicit listener address.
    #[derive(Debug, test_strategy::Arbitrary)]
    enum GatewayAddress {
        Dev(u8),
        Prod(Option<SocketAddr>),
    }

    impl GatewayAddress {
        /// The explicit listener address; only `Prod` carries one.
        fn ip(&self) -> Option<SocketAddr> {
            match self {
                Prod(ip) => *ip,
                Dev(_) => None,
            }
        }

        /// The dev-mode port offset; only `Dev` carries one.
        fn port(&self) -> Option<u16> {
            match self {
                Dev(port) => Some(u16::from(*port)),
                Prod(_) => None,
            }
        }
    }

    impl Arbitrary for Gateway<CurrentNetwork> {
        type Parameters = ();
        type Strategy = BoxedStrategy<Gateway<CurrentNetwork>>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_valid_dev_gateway()
                .prop_map(|(storage, _, private_key, address)| {
                    Gateway::new(
                        Account::try_from(private_key).unwrap(),
                        storage.clone(),
                        storage.ledger().clone(),
                        address.ip(),
                        &[],
                        address.port(),
                    )
                    .unwrap()
                })
                .boxed()
        }
    }

    type GatewayInput = (Storage<CurrentNetwork>, CommitteeContext, PrivateKey<CurrentNetwork>, GatewayAddress);

    /// Samples a committee, matching storage, one committee member's key, and a dev-mode address.
    fn any_valid_dev_gateway() -> BoxedStrategy<GatewayInput> {
        (any::<CommitteeContext>(), any::<Selector>())
            .prop_flat_map(|(context, account_selector)| {
                let CommitteeContext(_, ValidatorSet(validators)) = context.clone();
                (
                    any_with::<Storage<CurrentNetwork>>(context.clone()),
                    Just(context),
                    Just(account_selector.select(validators)),
                    0u8..,
                )
                    .prop_map(|(a, b, c, d)| (a, b, c.private_key, Dev(d)))
            })
            .boxed()
    }

    /// Samples a committee, matching storage, one committee member's key, and a production address.
    fn any_valid_prod_gateway() -> BoxedStrategy<GatewayInput> {
        (any::<CommitteeContext>(), any::<Selector>())
            .prop_flat_map(|(context, account_selector)| {
                let CommitteeContext(_, ValidatorSet(validators)) = context.clone();
                (
                    any_with::<Storage<CurrentNetwork>>(context.clone()),
                    Just(context),
                    Just(account_selector.select(validators)),
                    any::<Option<SocketAddr>>(),
                )
                    .prop_map(|(a, b, c, d)| (a, b, c.private_key, Prod(d)))
            })
            .boxed()
    }

    #[proptest]
    fn gateway_dev_initialization(#[strategy(any_valid_dev_gateway())] input: GatewayInput) {
        let (storage, _, private_key, dev) = input;
        let account = Account::try_from(private_key).unwrap();

        let gateway =
            Gateway::new(account.clone(), storage.clone(), storage.ledger().clone(), dev.ip(), &[], dev.port())
                .unwrap();
        let tcp_config = gateway.tcp().config();
        assert_eq!(tcp_config.listener_ip, Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));
        assert_eq!(tcp_config.desired_listening_port, Some(MEMORY_POOL_PORT + dev.port().unwrap()));
        assert_eq!(tcp_config.max_connections, Committee::<CurrentNetwork>::MAX_COMMITTEE_SIZE);
        assert_eq!(gateway.account().address(), account.address());
    }

    #[proptest]
    fn gateway_prod_initialization(#[strategy(any_valid_prod_gateway())] input: GatewayInput) {
        let (storage, _, private_key, prod) = input;
        let account = Account::try_from(private_key).unwrap();

        let gateway =
            Gateway::new(account.clone(), storage.clone(), storage.ledger().clone(), prod.ip(), &[], prod.port())
                .unwrap();
        let tcp_config = gateway.tcp().config();
        if let Some(socket_addr) = prod.ip() {
            assert_eq!(tcp_config.listener_ip, Some(socket_addr.ip()));
            assert_eq!(tcp_config.desired_listening_port, Some(socket_addr.port()));
        } else {
            assert_eq!(tcp_config.listener_ip, Some(IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
            assert_eq!(tcp_config.desired_listening_port, Some(MEMORY_POOL_PORT));
        }
        assert_eq!(tcp_config.max_connections, Committee::<CurrentNetwork>::MAX_COMMITTEE_SIZE);
        assert_eq!(gateway.account().address(), account.address());
    }

    #[proptest(async = "tokio")]
    async fn gateway_start(
        #[strategy(any_valid_dev_gateway())] input: GatewayInput,
        #[strategy(0..MAX_WORKERS)] workers_count: u8,
    ) {
        let (storage, committee, private_key, dev) = input;
        let committee = committee.0;
        let worker_storage = storage.clone();
        let account = Account::try_from(private_key).unwrap();

        let gateway =
            Gateway::new(account, storage.clone(), storage.ledger().clone(), dev.ip(), &[], dev.port()).unwrap();

        let (primary_sender, _) = init_primary_channels();

        // Spawn `workers_count` workers and keep each one's sender, keyed by worker id.
        let (workers, worker_senders) = {
            let mut tx_workers = IndexMap::new();
            let mut workers = IndexMap::new();

            for id in 0..workers_count {
                let (tx_worker, rx_worker) = init_worker_channels();
                let ledger = Arc::new(MockLedgerService::new(committee.clone()));
                let worker =
                    Worker::new(id, Arc::new(gateway.clone()), worker_storage.clone(), ledger, Default::default())
                        .unwrap();
                worker.run(rx_worker);

                workers.insert(id, worker);
                tx_workers.insert(id, tx_worker);
            }
            (workers, tx_workers)
        };

        gateway.run(primary_sender, worker_senders, None).await;
        assert_eq!(
            gateway.local_ip(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), MEMORY_POOL_PORT + dev.port().unwrap())
        );
        assert_eq!(gateway.num_workers(), workers.len() as u8);
    }

    #[proptest]
    fn test_is_authorized_validator(#[strategy(any_valid_dev_gateway())] input: GatewayInput) {
        let rng = &mut TestRng::default();

        let current_round = 2;
        let committee_size = 4;
        let max_gc_rounds = BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64;
        let (_, _, private_key, dev) = input;
        let account = Account::try_from(private_key).unwrap();

        // The first `committee_size` certificates are authored by the committee members...
        let mut certificates = IndexSet::new();
        for _ in 0..committee_size {
            certificates.insert(sample_batch_certificate_for_round(current_round, rng));
        }
        let addresses: Vec<_> = certificates.iter().map(|certificate| certificate.author()).collect();
        let committee = sample_committee_for_round_and_members(current_round, addresses, rng);
        // ...and the next `committee_size` are authored by outsiders.
        for _ in 0..committee_size {
            certificates.insert(sample_batch_certificate_for_round(current_round, rng));
        }
        let ledger = Arc::new(MockLedgerService::new(committee.clone()));
        let storage = Storage::new(ledger.clone(), Arc::new(BFTMemoryService::new()), max_gc_rounds);
        let gateway =
            Gateway::new(account.clone(), storage.clone(), ledger.clone(), dev.ip(), &[], dev.port()).unwrap();
        for certificate in certificates.iter() {
            storage.testing_only_insert_certificate_testing_only(certificate.clone());
        }
        // `IndexSet` preserves insertion order, so only the first `committee_size` authors are authorized.
        for (i, certificate) in certificates.iter().enumerate() {
            let is_authorized = gateway.is_authorized_validator_address(certificate.author());
            assert_eq!(is_authorized, i < committee_size);
        }
    }
}