Skip to main content

snarkos_node_tcp/
tcp.rs

1// Copyright (c) 2019-2026 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{
17    collections::HashSet,
18    fmt,
19    io,
20    net::{IpAddr, SocketAddr},
21    ops::Deref,
22    sync::{
23        Arc,
24        atomic::{AtomicUsize, Ordering::*},
25    },
26    time::{Duration, Instant},
27};
28
29use anyhow::anyhow;
30#[cfg(feature = "locktick")]
31use locktick::parking_lot::Mutex;
32use once_cell::sync::OnceCell;
33#[cfg(not(feature = "locktick"))]
34use parking_lot::Mutex;
35use tokio::{
36    io::split,
37    net::{TcpListener, TcpSocket, TcpStream},
38    sync::oneshot,
39    task::{JoinHandle, JoinSet},
40    time::timeout,
41};
42use tracing::*;
43
44use crate::{
45    BannedPeers,
46    Config,
47    KnownPeers,
48    Stats,
49    connections::{Connection, ConnectionSide, Connections, create_connection_span},
50    protocols::{Protocol, Protocols},
51};
52
/// Process-wide counter used to give a numeric name to every `Tcp` that was created
/// without a name in its `Config`; each unnamed `Tcp` takes the next value.
static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0_usize);
55
/// The central object responsible for handling connections.
///
/// This is a cheaply clonable handle: all clones share the same [`InnerTcp`] state through
/// an `Arc`, and the inner state's methods and fields are reachable via `Deref`.
#[derive(Clone)]
pub struct Tcp(Arc<InnerTcp>);
59
60impl Deref for Tcp {
61    type Target = Arc<InnerTcp>;
62
63    fn deref(&self) -> &Self::Target {
64        &self.0
65    }
66}
67
/// A custom application error that can be returned by the `Tcp` stack.
///
/// Implementors must be printable both for diagnostics (`Debug`) and for users (`Display`),
/// must be shareable across threads (`Send + Sync`), and must not borrow data (`'static`).
/// The trait has no methods; it only bundles these bounds.
pub trait ApplicationError: fmt::Debug + fmt::Display + Send + Sync + 'static {}
70
71/// Error types for the `Tcp::connect` function.
72#[allow(missing_docs)]
73#[derive(thiserror::Error, Debug)]
74pub enum ConnectError {
75    #[error("already reached the maximum number of {limit} connections")]
76    MaximumConnectionsReached { limit: u16 },
77    #[error("already connecting to node at {address:?}")]
78    AlreadyConnecting { address: SocketAddr },
79    #[error("already connected to node at {address:?}")]
80    AlreadyConnected { address: SocketAddr },
81    #[error("attempt to self-connect (at address {address:?}")]
82    SelfConnect { address: SocketAddr },
83    #[error("rejected a connection attempt from a banned IP '{ip}'")]
84    BannedIp { ip: IpAddr },
85    // Socket errors, such as "connection refused".
86    #[error(transparent)]
87    IoError(std::io::Error),
88    // An application-specific reason to reject the connection or abort the handshake.
89    // For snarkOS, this is either a `DisconnectReason` or a `PeeringError`, which do not fully implement `std::error::Error`.
90    #[error("{0}")]
91    ApplicationError(Box<dyn ApplicationError>),
92    /// An unexpected error at the application layer and certain deserialization errors.
93    /// TODO(kaimast): (some of) these should be treated with higher severity, as they indicate a bug or corrupted state,
94    ///                and deserialization errors should not be included in this "other" category.
95    #[error(transparent)]
96    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
97}
98
99impl ConnectError {
100    /// Pass an application-level error to the `Tcp` stack.
101    pub fn application<E: ApplicationError>(err: E) -> Self {
102        Self::ApplicationError(Box::new(err))
103    }
104
105    /// A generic error that can be returned by the `Tcp` stack.
106    pub fn other<E: Into<Box<dyn std::error::Error + Send + Sync>>>(err: E) -> Self {
107        Self::Other(err.into())
108    }
109}
110
111impl From<ConnectError> for std::io::Error {
112    fn from(err: ConnectError) -> Self {
113        match err {
114            ConnectError::IoError(err) => err,
115            ConnectError::Other(err) => std::io::Error::other(err),
116            err => std::io::Error::other(err.to_string()),
117        }
118    }
119}
120
121impl From<std::io::Error> for ConnectError {
122    fn from(err: std::io::Error) -> Self {
123        // Other error are usually checks that fail when snarkVM deserializes a message.
124        if err.kind() == std::io::ErrorKind::Other {
125            // This unwrap should always succeed.
126            let inner = err.into_inner().unwrap_or_else(|| anyhow!("Unknown error").into());
127            ConnectError::other(inner)
128        } else {
129            ConnectError::IoError(err)
130        }
131    }
132}
133
/// The state shared by all clones of a [`Tcp`] handle; reached through `Tcp`'s `Deref` impl.
#[doc(hidden)]
pub struct InnerTcp {
    /// The tracing span.
    span: Span,
    /// The node's configuration.
    config: Config,
    /// The node's listening address; set once by `enable_listener`, unset if the node never listens.
    listening_addr: OnceCell<SocketAddr>,
    /// Contains objects used by the protocols implemented by the node.
    pub(crate) protocols: Protocols,
    /// Addresses of connections that are being set up but have not been finalized yet
    /// (i.e. not yet moved into `connections`).
    connecting: Mutex<HashSet<SocketAddr>>,
    /// Contains objects related to the node's active connections.
    connections: Connections,
    /// Collects statistics related to the node's peers.
    known_peers: KnownPeers,
    /// Contains the set of currently banned peers.
    banned_peers: BannedPeers,
    /// Collects statistics related to the node itself.
    stats: Stats,
    /// Handles of the node's background tasks (e.g. the listening task); they are taken and
    /// aborted by `shut_down`. Being `pub(crate)`, other modules of this crate may add tasks too.
    pub(crate) tasks: Mutex<Vec<JoinHandle<()>>>,
}
157
158impl Tcp {
159    /// Creates a new [`Tcp`] using the given [`Config`].
160    pub fn new(mut config: Config) -> Self {
161        // If there is no pre-configured name, assign a sequential numeric identifier.
162        if config.name.is_none() {
163            config.name = Some(SEQUENTIAL_NODE_ID.fetch_add(1, Relaxed).to_string());
164        }
165
166        // Create a tracing span containing the node's name.
167        let span = crate::helpers::create_span(config.name.as_deref().unwrap());
168
169        // Initialize the Tcp stack.
170        let tcp = Tcp(Arc::new(InnerTcp {
171            span,
172            config,
173            listening_addr: Default::default(),
174            protocols: Default::default(),
175            connecting: Default::default(),
176            connections: Default::default(),
177            known_peers: Default::default(),
178            banned_peers: Default::default(),
179            stats: Stats::new(Instant::now()),
180            tasks: Default::default(),
181        }));
182
183        debug!(parent: tcp.span(), "The node is ready");
184
185        tcp
186    }
187
188    /// How long has this node accepting connections?
189    pub fn uptime(&self) -> Duration {
190        self.stats.timestamp().elapsed()
191    }
192
193    /// Returns the name assigned.
194    #[inline]
195    pub fn name(&self) -> &str {
196        // safe; can be set as None in Config, but receives a default value on Tcp creation
197        self.config.name.as_deref().unwrap()
198    }
199
200    /// Returns a reference to the configuration.
201    #[inline]
202    pub fn config(&self) -> &Config {
203        &self.config
204    }
205
206    /// Returns the listening address; returns an error if Tcp was not configured
207    /// to listen for inbound connections.
208    pub fn listening_addr(&self) -> io::Result<SocketAddr> {
209        self.listening_addr.get().copied().ok_or_else(|| io::ErrorKind::AddrNotAvailable.into())
210    }
211
212    /// Checks whether the provided address is connected.
213    pub fn is_connected(&self, addr: SocketAddr) -> bool {
214        self.connections.is_connected(addr)
215    }
216
217    /// Checks if Tcp is currently setting up a connection with the provided address.
218    pub fn is_connecting(&self, addr: SocketAddr) -> bool {
219        self.connecting.lock().contains(&addr)
220    }
221
222    /// Returns the number of active connections.
223    pub fn num_connected(&self) -> usize {
224        self.connections.num_connected()
225    }
226
227    /// Returns the number of connections that are currently being set up.
228    pub fn num_connecting(&self) -> usize {
229        self.connecting.lock().len()
230    }
231
232    /// Returns a list containing addresses of active connections.
233    pub fn connected_addrs(&self) -> Vec<SocketAddr> {
234        self.connections.addrs()
235    }
236
237    /// Returns a list containing addresses of pending connections.
238    pub fn connecting_addrs(&self) -> Vec<SocketAddr> {
239        self.connecting.lock().iter().copied().collect()
240    }
241
242    /// Returns a reference to the collection of statistics of known peers.
243    #[inline]
244    pub fn known_peers(&self) -> &KnownPeers {
245        &self.known_peers
246    }
247
248    /// Returns a reference to the set of currently banned peers.
249    #[inline]
250    pub fn banned_peers(&self) -> &BannedPeers {
251        &self.banned_peers
252    }
253
254    /// Returns a reference to the statistics.
255    #[inline]
256    pub fn stats(&self) -> &Stats {
257        &self.stats
258    }
259
260    /// Returns the tracing [`Span`] associated with Tcp.
261    #[inline]
262    pub fn span(&self) -> &Span {
263        &self.span
264    }
265
266    /// Gracefully shuts down the stack.
267    pub async fn shut_down(&self) {
268        debug!(parent: self.span(), "Shutting down the TCP stack");
269
270        // Retrieve all tasks.
271        let mut tasks = std::mem::take(&mut *self.tasks.lock()).into_iter();
272
273        // Abort the listening task first.
274        if let Some(listening_task) = tasks.next() {
275            listening_task.abort(); // abort the listening task first
276        }
277
278        // Disconnect from all connected peers.
279        let mut disconnect_tasks = JoinSet::new();
280        for addr in self.connected_addrs() {
281            let node = self.clone();
282            disconnect_tasks.spawn(async move {
283                node.disconnect(addr).await;
284            });
285        }
286        while disconnect_tasks.join_next().await.is_some() {}
287
288        // Abort all remaining tasks.
289        for handle in tasks {
290            handle.abort();
291        }
292    }
293}
294
295impl Tcp {
296    /// Connects to the provided `SocketAddr`.
297    pub async fn connect(&self, addr: SocketAddr) -> Result<(), ConnectError> {
298        if let Ok(listening_addr) = self.listening_addr() {
299            // TODO(nkls): maybe this first check can be dropped; though it might be best to keep just in case.
300            if addr == listening_addr || self.is_self_connect(addr) {
301                error!(parent: self.span(), "Attempted to self-connect ({addr})");
302                return Err(ConnectError::SelfConnect { address: addr });
303            }
304        }
305
306        if !self.can_add_connection() {
307            error!(parent: self.span(), "Too many connections; refusing to connect to {addr}");
308            return Err(ConnectError::MaximumConnectionsReached { limit: self.config.max_connections });
309        }
310
311        if self.is_connected(addr) {
312            trace!(parent: self.span(), "Already connected to {addr}");
313            return Err(ConnectError::AlreadyConnected { address: addr });
314        }
315
316        if !self.connecting.lock().insert(addr) {
317            debug!(parent: self.span(), "Already connecting to {addr}");
318            return Err(ConnectError::AlreadyConnecting { address: addr });
319        }
320
321        let timeout_duration = Duration::from_millis(self.config().connection_timeout_ms.into());
322
323        // Bind the tcp socket to the configured listener ip if it's set.
324        // Otherwise default to the system's default interface.
325        let res = if let Some(listen_ip) = self.config().listener_ip {
326            timeout(timeout_duration, self.connect_with_specific_interface(listen_ip, addr)).await
327        } else {
328            timeout(timeout_duration, TcpStream::connect(addr)).await
329        };
330
331        let stream = match res {
332            Ok(Ok(stream)) => Ok(stream),
333            Ok(err) => {
334                self.connecting.lock().remove(&addr);
335                err
336            }
337            Err(err) => {
338                self.connecting.lock().remove(&addr);
339                error!("connection timeout error: {}", err);
340                Err(io::ErrorKind::TimedOut.into())
341            }
342        }?;
343
344        let ret = self.adapt_stream(stream, addr, ConnectionSide::Initiator).await;
345
346        if let Err(ref e) = ret {
347            self.connecting.lock().remove(&addr);
348            self.known_peers().register_failure(addr.ip());
349            error!(parent: self.span(), "Unable to initiate a connection with {addr}: {e}");
350        }
351
352        ret.map_err(|err| err.into())
353    }
354
355    async fn connect_with_specific_interface(&self, listen_ip: IpAddr, addr: SocketAddr) -> io::Result<TcpStream> {
356        let sock = if listen_ip.is_ipv4() { TcpSocket::new_v4()? } else { TcpSocket::new_v6()? };
357        // Lock the socket to a specific interface.
358        sock.bind(SocketAddr::new(listen_ip, 0))?;
359        sock.connect(addr).await
360    }
361
362    /// Disconnects from the provided `SocketAddr`.
363    ///
364    /// Returns true if the we were connected to the given address.
365    pub async fn disconnect(&self, addr: SocketAddr) -> bool {
366        // claim the disconnect to avoid duplicate executions, or return early if already claimed
367        if let Some(conn) = self.connections.0.read().get(&addr) {
368            if conn.disconnecting.swap(true, Relaxed) {
369                // valid connection, but someone else is already disconnecting it
370                return false;
371            }
372        } else {
373            // not connected
374            return false;
375        };
376
377        if let Some(handler) = self.protocols.disconnect.get() {
378            let (sender, receiver) = oneshot::channel();
379            handler.trigger((addr, sender)).await;
380            if let Ok((handle, waiter)) = receiver.await {
381                // register the associated task with the connection, in case
382                // it gets terminated before its completion
383                if let Some(conn) = self.connections.0.write().get_mut(&addr) {
384                    conn.tasks.push(handle);
385                }
386                // wait for the OnDisconnect protocol to perform its specified actions
387                let _ = waiter.await;
388            }
389        }
390
391        let conn = self.connections.remove(addr);
392        let disconnected = conn.is_some();
393
394        if let Some(conn) = conn {
395            debug!(parent: self.span(), "Disconnecting from {addr}");
396
397            // Shut down the associated tasks of the peer.
398            drop(conn);
399
400            debug!(parent: self.span(), "Disconnected from {addr}");
401        } else {
402            warn!(parent: self.span(), "Failed to disconnect, was not connected to {addr}");
403        }
404
405        disconnected
406    }
407}
408
409impl Tcp {
410    /// Spawns a task that listens for incoming connections.
411    pub async fn enable_listener(&self) -> io::Result<SocketAddr> {
412        // Retrieve the listening IP address, which must be set.
413        let listener_ip =
414            self.config().listener_ip.expect("Tcp::enable_listener was called, but Config::listener_ip is not set");
415
416        // Initialize the TCP listener.
417        let listener = self.create_listener(listener_ip).await?;
418
419        // Discover the port, if it was unspecified.
420        let port = listener.local_addr()?.port();
421
422        // Set the listening IP address.
423        let listening_addr = (listener_ip, port).into();
424        self.listening_addr.set(listening_addr).expect("The node's listener was started more than once");
425
426        // Use a channel to know when the listening task is ready.
427        let (tx, rx) = oneshot::channel();
428
429        let tcp = self.clone();
430        let listening_task = tokio::spawn(async move {
431            trace!(parent: tcp.span(), "Spawned the listening task");
432            tx.send(()).unwrap(); // safe; the channel was just opened
433
434            loop {
435                // Await for a new connection.
436                match listener.accept().await {
437                    Ok((stream, addr)) => tcp.handle_connection(stream, addr),
438                    Err(e) => {
439                        error!(parent: tcp.span(), "Failed to accept a connection: {e}");
440                        // if we ran out of FDs, sleep to avoid spinning 100% CPU
441                        // while waiting for a slot to free up
442                        tokio::time::sleep(Duration::from_millis(500)).await;
443                    }
444                }
445            }
446        });
447        self.tasks.lock().push(listening_task);
448        let _ = rx.await;
449        debug!(parent: self.span(), "Listening on {listening_addr}");
450
451        Ok(listening_addr)
452    }
453
454    /// Creates an instance of `TcpListener` based on the node's configuration.
455    async fn create_listener(&self, listener_ip: IpAddr) -> io::Result<TcpListener> {
456        debug!("Creating a TCP listener on {listener_ip}...");
457        let listener = if let Some(port) = self.config().desired_listening_port {
458            // Construct the desired listening IP address.
459            let desired_listening_addr = SocketAddr::new(listener_ip, port);
460            // If a desired listening port is set, try to bind to it.
461            match TcpListener::bind(desired_listening_addr).await {
462                Ok(listener) => listener,
463                Err(e) => {
464                    if self.config().allow_random_port {
465                        warn!(
466                            parent: self.span(),
467                            "Trying any listening port, as the desired port is unavailable: {e}"
468                        );
469                        let random_available_addr = SocketAddr::new(listener_ip, 0);
470                        TcpListener::bind(random_available_addr).await?
471                    } else {
472                        error!(parent: self.span(), "The desired listening port is unavailable: {e}");
473                        return Err(e);
474                    }
475                }
476            }
477        } else if self.config().allow_random_port {
478            let random_available_addr = SocketAddr::new(listener_ip, 0);
479            TcpListener::bind(random_available_addr).await?
480        } else {
481            panic!("As 'listener_ip' is set, either 'desired_listening_port' or 'allow_random_port' must be set");
482        };
483
484        Ok(listener)
485    }
486
487    /// Handles a new inbound connection.
488    fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) {
489        debug!(parent: self.span(), "Received a connection from {addr}");
490
491        if !self.can_add_connection() || self.is_self_connect(addr) {
492            debug!(parent: self.span(), "Rejecting the connection from {addr}");
493            return;
494        }
495
496        self.connecting.lock().insert(addr);
497
498        let tcp = self.clone();
499        tokio::spawn(async move {
500            if let Err(e) = tcp.adapt_stream(stream, addr, ConnectionSide::Responder).await {
501                tcp.connecting.lock().remove(&addr);
502                tcp.known_peers().register_failure(addr.ip());
503                error!(parent: tcp.span(), "Failed to connect with {addr}: {e}");
504            }
505        });
506    }
507
508    /// Checks if the given IP address is the same as the listening address of this `Tcp`.
509    fn is_self_connect(&self, addr: SocketAddr) -> bool {
510        // SAFETY: if we're opening connections, this should never fail.
511        let listening_addr = self.listening_addr().unwrap();
512
513        match listening_addr.ip().is_loopback() {
514            // If localhost, check the ports, this only works on outbound connections, since we
515            // don't know the ephemeral port a peer might be using if they initiate the connection.
516            true => listening_addr.port() == addr.port(),
517            // If it's not localhost, matching IPs indicate a self-connect in both directions.
518            false => listening_addr.ip() == addr.ip(),
519        }
520    }
521
522    /// Checks whether the `Tcp` can handle an additional connection.
523    fn can_add_connection(&self) -> bool {
524        // Retrieve the number of connected peers.
525        let num_connected = self.num_connected();
526        // Retrieve the maximum number of connected peers.
527        let limit = self.config.max_connections as usize;
528
529        if num_connected >= limit {
530            warn!(parent: self.span(), "Maximum number of active connections ({limit}) reached");
531            false
532        } else if num_connected + self.num_connecting() >= limit {
533            warn!(parent: self.span(), "Maximum number of active & pending connections ({limit}) reached");
534            false
535        } else {
536            true
537        }
538    }
539
540    /// Prepares the freshly acquired connection to handle the protocols the Tcp implements.
541    async fn adapt_stream(&self, stream: TcpStream, peer_addr: SocketAddr, own_side: ConnectionSide) -> io::Result<()> {
542        self.known_peers.add(peer_addr.ip());
543
544        // Register the port seen by the peer.
545        if own_side == ConnectionSide::Initiator {
546            if let Ok(addr) = stream.local_addr() {
547                debug!(
548                    parent: self.span(), "establishing connection with {}; the peer is connected on port {}",
549                    peer_addr, addr.port()
550                );
551            } else {
552                warn!(parent: self.span(), "couldn't determine the peer's port");
553            }
554        }
555
556        let conn_span = create_connection_span(peer_addr, self.span());
557        let connection = Connection::new(peer_addr, stream, !own_side, conn_span);
558
559        // Enact the enabled protocols.
560        let mut connection = self.enable_protocols(connection).await?;
561
562        // if Reading is enabled, we'll notify the related task when the connection is fully ready.
563        let conn_ready_tx = connection.readiness_notifier.take();
564
565        self.connections.add(connection);
566        self.connecting.lock().remove(&peer_addr);
567
568        // Send the aforementioned notification so that reading from the socket can commence.
569        if let Some(tx) = conn_ready_tx {
570            let _ = tx.send(());
571        }
572
573        // If enabled, enact OnConnect.
574        if let Some(handler) = self.protocols.on_connect.get() {
575            let (sender, receiver) = oneshot::channel();
576            handler.trigger((peer_addr, sender)).await;
577            // Receive the handle for the running task.
578            if let Ok(handle) = receiver.await {
579                // Add the task to the connection so it gets aborted on disconnect.
580                if let Some(conn) = self.connections.0.write().get_mut(&peer_addr) {
581                    conn.tasks.push(handle);
582                } else {
583                    // The connection has just been terminated; abort the OnConnect work.
584                    handle.abort();
585                }
586            }
587        }
588
589        Ok(())
590    }
591
592    /// Enacts the enabled protocols on the provided connection.
593    async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> {
594        /// A helper macro to enable a protocol on a connection.
595        macro_rules! enable_protocol {
596            ($handler_type: ident, $node:expr, $conn: expr) => {
597                if let Some(handler) = $node.protocols.$handler_type.get() {
598                    let (conn_returner, conn_retriever) = oneshot::channel();
599
600                    handler.trigger(($conn, conn_returner)).await;
601
602                    match conn_retriever.await {
603                        Ok(Ok(conn)) => conn,
604                        Err(_) => return Err(io::ErrorKind::BrokenPipe.into()),
605                        Ok(e) => return e,
606                    }
607                } else {
608                    $conn
609                }
610            };
611        }
612
613        let mut conn = enable_protocol!(handshake, self, conn);
614
615        // Split the stream after the handshake (if not done before).
616        if let Some(stream) = conn.stream.take() {
617            let (reader, writer) = split(stream);
618            conn.reader = Some(Box::new(reader));
619            conn.writer = Some(Box::new(writer));
620        }
621
622        let conn = enable_protocol!(reading, self, conn);
623        let conn = enable_protocol!(writing, self, conn);
624
625        Ok(conn)
626    }
627}
628
629impl fmt::Debug for Tcp {
630    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
631        write!(f, "The TCP stack config: {:?}", self.config)
632    }
633}
634
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        net::{IpAddr, Ipv4Addr},
        str::FromStr,
    };

    /// Creates a `Tcp` peer that listens on an OS-assigned localhost port and allows a single
    /// connection; returns it together with its listening address.
    async fn spawn_listening_peer() -> (Tcp, SocketAddr) {
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_addr = peer.enable_listener().await.unwrap();
        (peer, peer_addr)
    }

    #[tokio::test]
    async fn test_new() {
        let tcp = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            max_connections: 200,
            ..Default::default()
        });

        assert_eq!(tcp.config.max_connections, 200);
        assert_eq!(tcp.config.listener_ip, Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));
        assert_eq!(tcp.enable_listener().await.unwrap().ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));

        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
    }

    #[tokio::test]
    async fn test_connect() {
        let tcp = Tcp::new(Config::default());
        let node_ip = tcp.enable_listener().await.unwrap();

        // Ensure self-connecting is not possible.
        let result = tcp.connect(node_ip).await;
        assert!(matches!(result, Err(ConnectError::SelfConnect { .. })));

        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(node_ip));
        assert!(!tcp.is_connecting(node_ip));

        // Initialize the peer.
        let (_peer, peer_ip) = spawn_listening_peer().await;

        // Connect to the peer.
        tcp.connect(peer_ip).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }

    #[tokio::test]
    async fn test_disconnect() {
        let tcp = Tcp::new(Config::default());
        let _node_ip = tcp.enable_listener().await.unwrap();

        // Initialize the peer.
        let (_peer, peer_ip) = spawn_listening_peer().await;

        // Connect to the peer.
        tcp.connect(peer_ip).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));

        // Disconnect from the peer.
        let has_disconnected = tcp.disconnect(peer_ip).await;
        assert!(has_disconnected);
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));

        // Ensure disconnecting from the peer a second time is okay.
        let has_disconnected = tcp.disconnect(peer_ip).await;
        assert!(!has_disconnected);
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }

    #[tokio::test]
    async fn test_can_add_connection() {
        let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });

        // Initialize the peer.
        let (_peer, peer_ip) = spawn_listening_peer().await;

        assert!(tcp.can_add_connection());

        // Simulate an active connection.
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Initiator, Span::none()));
        assert!(!tcp.can_add_connection());

        // Ensure that we cannot invoke connect() successfully in this case.
        // Use a non-local IP, to ensure it is never equal to the peer IP.
        let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap();
        let result = tcp.connect(another_ip).await;
        assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. })));

        // Remove the active connection.
        tcp.connections.remove(peer_ip);
        assert!(tcp.can_add_connection());

        // Simulate a pending connection.
        tcp.connecting.lock().insert(peer_ip);
        assert!(!tcp.can_add_connection());

        // Ensure that we cannot invoke connect() successfully in this case either.
        let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap();
        let result = tcp.connect(another_ip).await;
        assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. })));

        // Remove the pending connection.
        tcp.connecting.lock().remove(&peer_ip);
        assert!(tcp.can_add_connection());

        // Simulate an active and a pending connection (this case should never occur).
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Responder, Span::none()));
        tcp.connecting.lock().insert(peer_ip);
        assert!(!tcp.can_add_connection());

        // Remove the active and pending connection.
        tcp.connections.remove(peer_ip);
        tcp.connecting.lock().remove(&peer_ip);
        assert!(tcp.can_add_connection());
    }

    #[tokio::test]
    async fn test_handle_connection() {
        let tcp = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            max_connections: 1,
            ..Default::default()
        });

        // Initialize peer 1.
        let (_peer1, peer1_ip) = spawn_listening_peer().await;

        // Simulate an active connection.
        let stream = TcpStream::connect(peer1_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer1_ip, stream, ConnectionSide::Responder, Span::none()));
        assert!(!tcp.can_add_connection());
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer1_ip));
        assert!(!tcp.is_connecting(peer1_ip));

        // Initialize peer 2.
        let (_peer2, peer2_ip) = spawn_listening_peer().await;

        // Handle the connection; it must be rejected, as the connection limit is reached.
        let stream = TcpStream::connect(peer2_ip).await.unwrap();
        tcp.handle_connection(stream, peer2_ip);
        assert!(!tcp.can_add_connection());
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer1_ip));
        assert!(!tcp.is_connected(peer2_ip));
        assert!(!tcp.is_connecting(peer1_ip));
        assert!(!tcp.is_connecting(peer2_ip));
    }

    #[tokio::test]
    async fn test_adapt_stream() {
        let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });

        // Initialize the peer.
        let (_peer, peer_ip) = spawn_listening_peer().await;

        // Simulate a pending connection.
        tcp.connecting.lock().insert(peer_ip);
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 1);
        assert!(!tcp.is_connected(peer_ip));
        assert!(tcp.is_connecting(peer_ip));

        // Simulate a new connection.
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.adapt_stream(stream, peer_ip, ConnectionSide::Responder).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }
}