Skip to main content

snarkos_node_tcp/
tcp.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{
17    collections::HashSet,
18    fmt,
19    io,
20    net::{IpAddr, SocketAddr},
21    ops::Deref,
22    sync::{
23        Arc,
24        atomic::{AtomicUsize, Ordering::*},
25    },
26    time::{Duration, Instant},
27};
28
29#[cfg(feature = "locktick")]
30use locktick::parking_lot::Mutex;
31use once_cell::sync::OnceCell;
32#[cfg(not(feature = "locktick"))]
33use parking_lot::Mutex;
34use tokio::{
35    io::split,
36    net::{TcpListener, TcpSocket, TcpStream},
37    sync::oneshot,
38    task::JoinHandle,
39    time::timeout,
40};
41use tracing::*;
42
43use crate::{
44    BannedPeers,
45    Config,
46    KnownPeers,
47    Stats,
48    connections::{Connection, ConnectionSide, Connections},
49    protocols::{Protocol, Protocols},
50};
51
/// Process-wide counter handing out numeric default names to every [`Tcp`] created
/// without an explicit `Config::name`.
static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0_usize);
54
55/// The central object responsible for handling connections.
56#[derive(Clone)]
57pub struct Tcp(Arc<InnerTcp>);
58
59impl Deref for Tcp {
60    type Target = Arc<InnerTcp>;
61
62    fn deref(&self) -> &Self::Target {
63        &self.0
64    }
65}
66
67/// Error types for the `Tcp::connect` function.
68#[allow(missing_docs)]
69#[derive(thiserror::Error, Debug)]
70pub enum ConnectError {
71    #[error("already reached the maximum number of {limit} connections")]
72    MaximumConnectionsReached { limit: u16 },
73    #[error("already connecting to node at {address:?}")]
74    AlreadyConnecting { address: SocketAddr },
75    #[error("already connected to node at {address:?}")]
76    AlreadyConnected { address: SocketAddr },
77    #[error("attempt to self-connect (at address {address:?}")]
78    SelfConnect { address: SocketAddr },
79    #[error("I/O error: {0}")]
80    IoError(std::io::Error),
81}
82
83impl From<std::io::Error> for ConnectError {
84    fn from(inner: std::io::Error) -> Self {
85        Self::IoError(inner)
86    }
87}
88
/// The shared state behind every [`Tcp`] handle.
///
/// Hidden from the generated docs via `#[doc(hidden)]`; its fields are private and are
/// exposed through `Tcp`'s accessor methods.
///
/// NOTE: Rust drops struct fields in declaration order, so reordering these fields
/// changes the order in which they are torn down.
#[doc(hidden)]
pub struct InnerTcp {
    /// The tracing span (named after the node) used as the parent of this stack's log events.
    span: Span,
    /// The node's configuration; `Tcp::new` guarantees that its `name` is populated.
    config: Config,
    /// The node's listening address; set at most once, by `Tcp::enable_listener`.
    listening_addr: OnceCell<SocketAddr>,
    /// Contains objects used by the protocols implemented by the node.
    pub(crate) protocols: Protocols,
    /// Addresses of connections (inbound and outbound) that are being set up but have not been finalized yet.
    connecting: Mutex<HashSet<SocketAddr>>,
    /// Contains objects related to the node's active connections.
    connections: Connections,
    /// Collects statistics related to the node's peers.
    known_peers: KnownPeers,
    /// Contains the set of currently banned peers.
    banned_peers: BannedPeers,
    /// Collects statistics related to the node itself.
    stats: Stats,
    /// Handles of the node-level background tasks (e.g. the listening task); aborted by `Tcp::shut_down`.
    pub(crate) tasks: Mutex<Vec<JoinHandle<()>>>,
}
112
113impl Tcp {
114    /// Creates a new [`Tcp`] using the given [`Config`].
115    pub fn new(mut config: Config) -> Self {
116        // If there is no pre-configured name, assign a sequential numeric identifier.
117        if config.name.is_none() {
118            config.name = Some(SEQUENTIAL_NODE_ID.fetch_add(1, Relaxed).to_string());
119        }
120
121        // Create a tracing span containing the node's name.
122        let span = crate::helpers::create_span(config.name.as_deref().unwrap());
123
124        // Initialize the Tcp stack.
125        let tcp = Tcp(Arc::new(InnerTcp {
126            span,
127            config,
128            listening_addr: Default::default(),
129            protocols: Default::default(),
130            connecting: Default::default(),
131            connections: Default::default(),
132            known_peers: Default::default(),
133            banned_peers: Default::default(),
134            stats: Stats::new(Instant::now()),
135            tasks: Default::default(),
136        }));
137
138        debug!(parent: tcp.span(), "The node is ready");
139
140        tcp
141    }
142
143    /// Returns the name assigned.
144    #[inline]
145    pub fn name(&self) -> &str {
146        // safe; can be set as None in Config, but receives a default value on Tcp creation
147        self.config.name.as_deref().unwrap()
148    }
149
150    /// Returns a reference to the configuration.
151    #[inline]
152    pub fn config(&self) -> &Config {
153        &self.config
154    }
155
156    /// Returns the listening address; returns an error if Tcp was not configured
157    /// to listen for inbound connections.
158    pub fn listening_addr(&self) -> io::Result<SocketAddr> {
159        self.listening_addr.get().copied().ok_or_else(|| io::ErrorKind::AddrNotAvailable.into())
160    }
161
162    /// Checks whether the provided address is connected.
163    pub fn is_connected(&self, addr: SocketAddr) -> bool {
164        self.connections.is_connected(addr)
165    }
166
167    /// Checks if Tcp is currently setting up a connection with the provided address.
168    pub fn is_connecting(&self, addr: SocketAddr) -> bool {
169        self.connecting.lock().contains(&addr)
170    }
171
172    /// Returns the number of active connections.
173    pub fn num_connected(&self) -> usize {
174        self.connections.num_connected()
175    }
176
177    /// Returns the number of connections that are currently being set up.
178    pub fn num_connecting(&self) -> usize {
179        self.connecting.lock().len()
180    }
181
182    /// Returns a list containing addresses of active connections.
183    pub fn connected_addrs(&self) -> Vec<SocketAddr> {
184        self.connections.addrs()
185    }
186
187    /// Returns a list containing addresses of pending connections.
188    pub fn connecting_addrs(&self) -> Vec<SocketAddr> {
189        self.connecting.lock().iter().copied().collect()
190    }
191
192    /// Returns a reference to the collection of statistics of known peers.
193    #[inline]
194    pub fn known_peers(&self) -> &KnownPeers {
195        &self.known_peers
196    }
197
198    /// Returns a reference to the set of currently banned peers.
199    #[inline]
200    pub fn banned_peers(&self) -> &BannedPeers {
201        &self.banned_peers
202    }
203
204    /// Returns a reference to the statistics.
205    #[inline]
206    pub fn stats(&self) -> &Stats {
207        &self.stats
208    }
209
210    /// Returns the tracing [`Span`] associated with Tcp.
211    #[inline]
212    pub fn span(&self) -> &Span {
213        &self.span
214    }
215
216    /// Gracefully shuts down the stack.
217    pub async fn shut_down(&self) {
218        debug!(parent: self.span(), "Shutting down the TCP stack");
219
220        // Retrieve all tasks.
221        let mut tasks = std::mem::take(&mut *self.tasks.lock()).into_iter();
222
223        // Abort the listening task first.
224        if let Some(listening_task) = tasks.next() {
225            listening_task.abort(); // abort the listening task first
226        }
227        // Disconnect from all connected peers.
228        for addr in self.connected_addrs() {
229            self.disconnect(addr).await;
230        }
231        // Abort all remaining tasks.
232        for handle in tasks {
233            handle.abort();
234        }
235    }
236}
237
238impl Tcp {
239    /// Connects to the provided `SocketAddr`.
240    pub async fn connect(&self, addr: SocketAddr) -> Result<(), ConnectError> {
241        if let Ok(listening_addr) = self.listening_addr() {
242            // TODO(nkls): maybe this first check can be dropped; though it might be best to keep just in case.
243            if addr == listening_addr || self.is_self_connect(addr) {
244                error!(parent: self.span(), "Attempted to self-connect ({addr})");
245                return Err(ConnectError::SelfConnect { address: addr });
246            }
247        }
248
249        if !self.can_add_connection() {
250            error!(parent: self.span(), "Too many connections; refusing to connect to {addr}");
251            return Err(ConnectError::MaximumConnectionsReached { limit: self.config.max_connections });
252        }
253
254        if self.is_connected(addr) {
255            warn!(parent: self.span(), "Already connected to {addr}");
256            return Err(ConnectError::AlreadyConnected { address: addr });
257        }
258
259        if !self.connecting.lock().insert(addr) {
260            warn!(parent: self.span(), "Already connecting to {addr}");
261            return Err(ConnectError::AlreadyConnecting { address: addr });
262        }
263
264        let timeout_duration = Duration::from_millis(self.config().connection_timeout_ms.into());
265
266        // Bind the tcp socket to the configured listener ip if it's set.
267        // Otherwise default to the system's default interface.
268        let res = if let Some(listen_ip) = self.config().listener_ip {
269            timeout(timeout_duration, self.connect_with_specific_interface(listen_ip, addr)).await
270        } else {
271            timeout(timeout_duration, TcpStream::connect(addr)).await
272        };
273
274        let stream = match res {
275            Ok(Ok(stream)) => Ok(stream),
276            Ok(err) => {
277                self.connecting.lock().remove(&addr);
278                err
279            }
280            Err(err) => {
281                self.connecting.lock().remove(&addr);
282                error!("connection timeout error: {}", err);
283                Err(io::ErrorKind::TimedOut.into())
284            }
285        }?;
286
287        let ret = self.adapt_stream(stream, addr, ConnectionSide::Initiator).await;
288
289        if let Err(ref e) = ret {
290            self.connecting.lock().remove(&addr);
291            self.known_peers().register_failure(addr.ip());
292            error!(parent: self.span(), "Unable to initiate a connection with {addr}: {e}");
293        }
294
295        ret.map_err(|err| err.into())
296    }
297
298    async fn connect_with_specific_interface(&self, listen_ip: IpAddr, addr: SocketAddr) -> io::Result<TcpStream> {
299        let sock = if listen_ip.is_ipv4() { TcpSocket::new_v4()? } else { TcpSocket::new_v6()? };
300        // Lock the socket to a specific interface.
301        sock.bind(SocketAddr::new(listen_ip, 0))?;
302        sock.connect(addr).await
303    }
304
305    /// Disconnects from the provided `SocketAddr`.
306    ///
307    /// Returns true if the we were connected to the given address.
308    pub async fn disconnect(&self, addr: SocketAddr) -> bool {
309        // claim the disconnect to avoid duplicate executions, or return early if already claimed
310        if let Some(conn) = self.connections.0.read().get(&addr) {
311            if conn.disconnecting.swap(true, Relaxed) {
312                // valid connection, but someone else is already disconnecting it
313                return false;
314            }
315        } else {
316            // not connected
317            return false;
318        };
319
320        if let Some(handler) = self.protocols.disconnect.get() {
321            let (sender, receiver) = oneshot::channel();
322            handler.trigger((addr, sender));
323            let _ = receiver.await; // can't really fail
324        }
325
326        let conn = self.connections.remove(addr);
327
328        if let Some(ref conn) = conn {
329            debug!(parent: self.span(), "Disconnecting from {}", conn.addr());
330
331            // Shut down the associated tasks of the peer.
332            for task in conn.tasks.iter().rev() {
333                task.abort();
334            }
335
336            debug!(parent: self.span(), "Disconnected from {}", conn.addr());
337        } else {
338            warn!(parent: self.span(), "Failed to disconnect, was not connected to {addr}");
339        }
340
341        conn.is_some()
342    }
343}
344
345impl Tcp {
346    /// Spawns a task that listens for incoming connections.
347    pub async fn enable_listener(&self) -> io::Result<SocketAddr> {
348        // Retrieve the listening IP address, which must be set.
349        let listener_ip =
350            self.config().listener_ip.expect("Tcp::enable_listener was called, but Config::listener_ip is not set");
351
352        // Initialize the TCP listener.
353        let listener = self.create_listener(listener_ip).await?;
354
355        // Discover the port, if it was unspecified.
356        let port = listener.local_addr()?.port();
357
358        // Set the listening IP address.
359        let listening_addr = (listener_ip, port).into();
360        self.listening_addr.set(listening_addr).expect("The node's listener was started more than once");
361
362        // Use a channel to know when the listening task is ready.
363        let (tx, rx) = oneshot::channel();
364
365        let tcp = self.clone();
366        let listening_task = tokio::spawn(async move {
367            trace!(parent: tcp.span(), "Spawned the listening task");
368            tx.send(()).unwrap(); // safe; the channel was just opened
369
370            loop {
371                // Await for a new connection.
372                match listener.accept().await {
373                    Ok((stream, addr)) => tcp.handle_connection(stream, addr),
374                    Err(e) => error!(parent: tcp.span(), "Failed to accept a connection: {e}"),
375                }
376            }
377        });
378        self.tasks.lock().push(listening_task);
379        let _ = rx.await;
380        debug!(parent: self.span(), "Listening on {listening_addr}");
381
382        Ok(listening_addr)
383    }
384
385    /// Creates an instance of `TcpListener` based on the node's configuration.
386    async fn create_listener(&self, listener_ip: IpAddr) -> io::Result<TcpListener> {
387        debug!("Creating a TCP listener on {listener_ip}...");
388        let listener = if let Some(port) = self.config().desired_listening_port {
389            // Construct the desired listening IP address.
390            let desired_listening_addr = SocketAddr::new(listener_ip, port);
391            // If a desired listening port is set, try to bind to it.
392            match TcpListener::bind(desired_listening_addr).await {
393                Ok(listener) => listener,
394                Err(e) => {
395                    if self.config().allow_random_port {
396                        warn!(
397                            parent: self.span(),
398                            "Trying any listening port, as the desired port is unavailable: {e}"
399                        );
400                        let random_available_addr = SocketAddr::new(listener_ip, 0);
401                        TcpListener::bind(random_available_addr).await?
402                    } else {
403                        error!(parent: self.span(), "The desired listening port is unavailable: {e}");
404                        return Err(e);
405                    }
406                }
407            }
408        } else if self.config().allow_random_port {
409            let random_available_addr = SocketAddr::new(listener_ip, 0);
410            TcpListener::bind(random_available_addr).await?
411        } else {
412            panic!("As 'listener_ip' is set, either 'desired_listening_port' or 'allow_random_port' must be set");
413        };
414
415        Ok(listener)
416    }
417
418    /// Handles a new inbound connection.
419    fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) {
420        debug!(parent: self.span(), "Received a connection from {addr}");
421
422        if !self.can_add_connection() || self.is_self_connect(addr) {
423            debug!(parent: self.span(), "Rejecting the connection from {addr}");
424            return;
425        }
426
427        self.connecting.lock().insert(addr);
428
429        let tcp = self.clone();
430        tokio::spawn(async move {
431            if let Err(e) = tcp.adapt_stream(stream, addr, ConnectionSide::Responder).await {
432                tcp.connecting.lock().remove(&addr);
433                tcp.known_peers().register_failure(addr.ip());
434                error!(parent: tcp.span(), "Failed to connect with {addr}: {e}");
435            }
436        });
437    }
438
439    /// Checks if the given IP address is the same as the listening address of this `Tcp`.
440    fn is_self_connect(&self, addr: SocketAddr) -> bool {
441        // SAFETY: if we're opening connections, this should never fail.
442        let listening_addr = self.listening_addr().unwrap();
443
444        match listening_addr.ip().is_loopback() {
445            // If localhost, check the ports, this only works on outbound connections, since we
446            // don't know the ephemeral port a peer might be using if they initiate the connection.
447            true => listening_addr.port() == addr.port(),
448            // If it's not localhost, matching IPs indicate a self-connect in both directions.
449            false => listening_addr.ip() == addr.ip(),
450        }
451    }
452
453    /// Checks whether the `Tcp` can handle an additional connection.
454    fn can_add_connection(&self) -> bool {
455        // Retrieve the number of connected peers.
456        let num_connected = self.num_connected();
457        // Retrieve the maximum number of connected peers.
458        let limit = self.config.max_connections as usize;
459
460        if num_connected >= limit {
461            warn!(parent: self.span(), "Maximum number of active connections ({limit}) reached");
462            false
463        } else if num_connected + self.num_connecting() >= limit {
464            warn!(parent: self.span(), "Maximum number of active & pending connections ({limit}) reached");
465            false
466        } else {
467            true
468        }
469    }
470
471    /// Prepares the freshly acquired connection to handle the protocols the Tcp implements.
472    async fn adapt_stream(&self, stream: TcpStream, peer_addr: SocketAddr, own_side: ConnectionSide) -> io::Result<()> {
473        self.known_peers.add(peer_addr.ip());
474
475        // Register the port seen by the peer.
476        if own_side == ConnectionSide::Initiator {
477            if let Ok(addr) = stream.local_addr() {
478                debug!(
479                    parent: self.span(), "establishing connection with {}; the peer is connected on port {}",
480                    peer_addr, addr.port()
481                );
482            } else {
483                warn!(parent: self.span(), "couldn't determine the peer's port");
484            }
485        }
486
487        let connection = Connection::new(peer_addr, stream, !own_side);
488
489        // Enact the enabled protocols.
490        let mut connection = self.enable_protocols(connection).await?;
491
492        // if Reading is enabled, we'll notify the related task when the connection is fully ready.
493        let conn_ready_tx = connection.readiness_notifier.take();
494
495        self.connections.add(connection);
496        self.connecting.lock().remove(&peer_addr);
497
498        // Send the aforementioned notification so that reading from the socket can commence.
499        if let Some(tx) = conn_ready_tx {
500            let _ = tx.send(());
501        }
502
503        // If enabled, enact OnConnect.
504        if let Some(handler) = self.protocols.on_connect.get() {
505            let (sender, receiver) = oneshot::channel();
506            handler.trigger((peer_addr, sender));
507            let _ = receiver.await; // can't really fail
508        }
509
510        Ok(())
511    }
512
513    /// Enacts the enabled protocols on the provided connection.
514    async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> {
515        /// A helper macro to enable a protocol on a connection.
516        macro_rules! enable_protocol {
517            ($handler_type: ident, $node:expr, $conn: expr) => {
518                if let Some(handler) = $node.protocols.$handler_type.get() {
519                    let (conn_returner, conn_retriever) = oneshot::channel();
520
521                    handler.trigger(($conn, conn_returner));
522
523                    match conn_retriever.await {
524                        Ok(Ok(conn)) => conn,
525                        Err(_) => return Err(io::ErrorKind::BrokenPipe.into()),
526                        Ok(e) => return e,
527                    }
528                } else {
529                    $conn
530                }
531            };
532        }
533
534        let mut conn = enable_protocol!(handshake, self, conn);
535
536        // Split the stream after the handshake (if not done before).
537        if let Some(stream) = conn.stream.take() {
538            let (reader, writer) = split(stream);
539            conn.reader = Some(Box::new(reader));
540            conn.writer = Some(Box::new(writer));
541        }
542
543        let conn = enable_protocol!(reading, self, conn);
544        let conn = enable_protocol!(writing, self, conn);
545
546        Ok(conn)
547    }
548}
549
550impl fmt::Debug for Tcp {
551    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552        write!(f, "The TCP stack config: {:?}", self.config)
553    }
554}
555
// Unit tests for the `Tcp` stack: construction, outbound connects, disconnects,
// connection-limit accounting, inbound handling and stream adaptation.
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        net::{IpAddr, Ipv4Addr},
        str::FromStr,
    };

    // A freshly created stack keeps its config, can start a listener and has no connections.
    #[tokio::test]
    async fn test_new() {
        let tcp = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            max_connections: 200,
            ..Default::default()
        });

        assert_eq!(tcp.config.max_connections, 200);
        assert_eq!(tcp.config.listener_ip, Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));
        assert_eq!(tcp.enable_listener().await.unwrap().ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));

        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
    }

    // Self-connects are rejected, while connecting to another node succeeds.
    #[tokio::test]
    async fn test_connect() {
        let tcp = Tcp::new(Config::default());
        let node_ip = tcp.enable_listener().await.unwrap();

        // Ensure self-connecting is not possible.
        let result = tcp.connect(node_ip).await;
        assert!(matches!(result, Err(ConnectError::SelfConnect { .. })));

        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(node_ip));
        assert!(!tcp.is_connecting(node_ip));

        // Initialize the peer.
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_ip = peer.enable_listener().await.unwrap();

        // Connect to the peer.
        tcp.connect(peer_ip).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }

    // Disconnecting removes the peer; a second disconnect is a no-op returning `false`.
    #[tokio::test]
    async fn test_disconnect() {
        let tcp = Tcp::new(Config::default());
        let _node_ip = tcp.enable_listener().await.unwrap();

        // Initialize the peer.
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_ip = peer.enable_listener().await.unwrap();

        // Connect to the peer.
        tcp.connect(peer_ip).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));

        // Disconnect from the peer.
        let has_disconnected = tcp.disconnect(peer_ip).await;
        assert!(has_disconnected);
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));

        // Ensure disconnecting from the peer a second time is okay.
        let has_disconnected = tcp.disconnect(peer_ip).await;
        assert!(!has_disconnected);
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(!tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }

    // Both active and pending connections count towards `max_connections`.
    #[tokio::test]
    async fn test_can_add_connection() {
        let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });

        // Initialize the peer.
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_ip = peer.enable_listener().await.unwrap();

        assert!(tcp.can_add_connection());

        // Simulate an active connection.
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Initiator));
        assert!(!tcp.can_add_connection());

        // Ensure that we cannot invoke connect() successfully in this case.
        // Use a non-local IP, to ensure it is never equal to the peer IP.
        let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap();
        let result = tcp.connect(another_ip).await;
        assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. })));

        // Remove the active connection.
        tcp.connections.remove(peer_ip);
        assert!(tcp.can_add_connection());

        // Simulate a pending connection.
        tcp.connecting.lock().insert(peer_ip);
        assert!(!tcp.can_add_connection());

        // Ensure that we cannot invoke connect() successfully in this case either.
        let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap();
        let result = tcp.connect(another_ip).await;
        assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. })));

        // Remove the pending connection.
        tcp.connecting.lock().remove(&peer_ip);
        assert!(tcp.can_add_connection());

        // Simulate an active and a pending connection (this case should never occur).
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Responder));
        tcp.connecting.lock().insert(peer_ip);
        assert!(!tcp.can_add_connection());

        // Remove the active and pending connection.
        tcp.connections.remove(peer_ip);
        tcp.connecting.lock().remove(&peer_ip);
        assert!(tcp.can_add_connection());
    }

    // An inbound connection is rejected (and leaves no pending entry) when the stack is full.
    #[tokio::test]
    async fn test_handle_connection() {
        let tcp = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            max_connections: 1,
            ..Default::default()
        });

        // Initialize peer 1.
        let peer1 = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer1_ip = peer1.enable_listener().await.unwrap();

        // Simulate an active connection.
        let stream = TcpStream::connect(peer1_ip).await.unwrap();
        tcp.connections.add(Connection::new(peer1_ip, stream, ConnectionSide::Responder));
        assert!(!tcp.can_add_connection());
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer1_ip));
        assert!(!tcp.is_connecting(peer1_ip));

        // Initialize peer 2.
        let peer2 = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer2_ip = peer2.enable_listener().await.unwrap();

        // Handle the connection.
        let stream = TcpStream::connect(peer2_ip).await.unwrap();
        tcp.handle_connection(stream, peer2_ip);
        assert!(!tcp.can_add_connection());
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer1_ip));
        assert!(!tcp.is_connected(peer2_ip));
        assert!(!tcp.is_connecting(peer1_ip));
        assert!(!tcp.is_connecting(peer2_ip));
    }

    // Adapting a stream moves the peer from the pending set to the active connections.
    #[tokio::test]
    async fn test_adapt_stream() {
        let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });

        // Initialize the peer.
        let peer = Tcp::new(Config {
            listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_ip = peer.enable_listener().await.unwrap();

        // Simulate a pending connection.
        tcp.connecting.lock().insert(peer_ip);
        assert_eq!(tcp.num_connected(), 0);
        assert_eq!(tcp.num_connecting(), 1);
        assert!(!tcp.is_connected(peer_ip));
        assert!(tcp.is_connecting(peer_ip));

        // Simulate a new connection.
        let stream = TcpStream::connect(peer_ip).await.unwrap();
        tcp.adapt_stream(stream, peer_ip, ConnectionSide::Responder).await.unwrap();
        assert_eq!(tcp.num_connected(), 1);
        assert_eq!(tcp.num_connecting(), 0);
        assert!(tcp.is_connected(peer_ip));
        assert!(!tcp.is_connecting(peer_ip));
    }
}