1use std::{
17 collections::HashSet,
18 fmt,
19 io,
20 net::{IpAddr, SocketAddr},
21 ops::Deref,
22 sync::{
23 Arc,
24 atomic::{AtomicUsize, Ordering::*},
25 },
26 time::{Duration, Instant},
27};
28
29#[cfg(feature = "locktick")]
30use locktick::parking_lot::Mutex;
31use once_cell::sync::OnceCell;
32#[cfg(not(feature = "locktick"))]
33use parking_lot::Mutex;
34use tokio::{
35 io::split,
36 net::{TcpListener, TcpStream},
37 sync::oneshot,
38 task::JoinHandle,
39 time::timeout,
40};
41use tracing::*;
42
43use crate::{
44 BannedPeers,
45 Config,
46 KnownPeers,
47 Stats,
48 connections::{Connection, ConnectionSide, Connections},
49 protocols::{Protocol, Protocols},
50};
51
/// Process-wide counter used to name nodes whose `Config::name` was left unset;
/// every such `Tcp::new` call consumes the next value.
static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0_usize);
54
55#[derive(Clone)]
57pub struct Tcp(Arc<InnerTcp>);
58
59impl Deref for Tcp {
60 type Target = Arc<InnerTcp>;
61
62 fn deref(&self) -> &Self::Target {
63 &self.0
64 }
65}
66
67#[allow(missing_docs)]
69#[derive(thiserror::Error, Debug)]
70pub enum ConnectError {
71 #[error("already reached the maximum number of {limit} connections")]
72 MaximumConnectionsReached { limit: u16 },
73 #[error("already connecting to node at {address:?}")]
74 AlreadyConnecting { address: SocketAddr },
75 #[error("already connected to node at {address:?}")]
76 AlreadyConnected { address: SocketAddr },
77 #[error("attempt to self-connect (at address {address:?}")]
78 SelfConnect { address: SocketAddr },
79 #[error("I/O error: {0}")]
80 IoError(std::io::Error),
81}
82
83impl From<std::io::Error> for ConnectError {
84 fn from(inner: std::io::Error) -> Self {
85 Self::IoError(inner)
86 }
87}
88
89#[doc(hidden)]
90pub struct InnerTcp {
91 span: Span,
93 config: Config,
95 listening_addr: OnceCell<SocketAddr>,
97 pub(crate) protocols: Protocols,
99 connecting: Mutex<HashSet<SocketAddr>>,
101 connections: Connections,
103 known_peers: KnownPeers,
105 banned_peers: BannedPeers,
107 stats: Stats,
109 pub(crate) tasks: Mutex<Vec<JoinHandle<()>>>,
111}
112
113impl Tcp {
114 pub fn new(mut config: Config) -> Self {
116 if config.name.is_none() {
118 config.name = Some(SEQUENTIAL_NODE_ID.fetch_add(1, Relaxed).to_string());
119 }
120
121 let span = crate::helpers::create_span(config.name.as_deref().unwrap());
123
124 let tcp = Tcp(Arc::new(InnerTcp {
126 span,
127 config,
128 listening_addr: Default::default(),
129 protocols: Default::default(),
130 connecting: Default::default(),
131 connections: Default::default(),
132 known_peers: Default::default(),
133 banned_peers: Default::default(),
134 stats: Stats::new(Instant::now()),
135 tasks: Default::default(),
136 }));
137
138 debug!(parent: tcp.span(), "The node is ready");
139
140 tcp
141 }
142
143 #[inline]
145 pub fn name(&self) -> &str {
146 self.config.name.as_deref().unwrap()
148 }
149
150 #[inline]
152 pub fn config(&self) -> &Config {
153 &self.config
154 }
155
156 pub fn listening_addr(&self) -> io::Result<SocketAddr> {
159 self.listening_addr.get().copied().ok_or_else(|| io::ErrorKind::AddrNotAvailable.into())
160 }
161
162 pub fn is_connected(&self, addr: SocketAddr) -> bool {
164 self.connections.is_connected(addr)
165 }
166
167 pub fn is_connecting(&self, addr: SocketAddr) -> bool {
169 self.connecting.lock().contains(&addr)
170 }
171
172 pub fn num_connected(&self) -> usize {
174 self.connections.num_connected()
175 }
176
177 pub fn num_connecting(&self) -> usize {
179 self.connecting.lock().len()
180 }
181
182 pub fn connected_addrs(&self) -> Vec<SocketAddr> {
184 self.connections.addrs()
185 }
186
187 pub fn connecting_addrs(&self) -> Vec<SocketAddr> {
189 self.connecting.lock().iter().copied().collect()
190 }
191
192 #[inline]
194 pub fn known_peers(&self) -> &KnownPeers {
195 &self.known_peers
196 }
197
198 #[inline]
200 pub fn banned_peers(&self) -> &BannedPeers {
201 &self.banned_peers
202 }
203
204 #[inline]
206 pub fn stats(&self) -> &Stats {
207 &self.stats
208 }
209
210 #[inline]
212 pub fn span(&self) -> &Span {
213 &self.span
214 }
215
216 pub async fn shut_down(&self) {
218 debug!(parent: self.span(), "Shutting down the TCP stack");
219
220 let mut tasks = std::mem::take(&mut *self.tasks.lock()).into_iter();
222
223 if let Some(listening_task) = tasks.next() {
225 listening_task.abort(); }
227 for addr in self.connected_addrs() {
229 self.disconnect(addr).await;
230 }
231 for handle in tasks {
233 handle.abort();
234 }
235 }
236}
237
238impl Tcp {
239 pub async fn connect(&self, addr: SocketAddr) -> Result<(), ConnectError> {
241 if let Ok(listening_addr) = self.listening_addr() {
242 if addr == listening_addr || self.is_self_connect(addr) {
244 error!(parent: self.span(), "Attempted to self-connect ({addr})");
245 return Err(ConnectError::SelfConnect { address: addr });
246 }
247 }
248
249 if !self.can_add_connection() {
250 error!(parent: self.span(), "Too many connections; refusing to connect to {addr}");
251 return Err(ConnectError::MaximumConnectionsReached { limit: self.config.max_connections });
252 }
253
254 if self.is_connected(addr) {
255 warn!(parent: self.span(), "Already connected to {addr}");
256 return Err(ConnectError::AlreadyConnected { address: addr });
257 }
258
259 if !self.connecting.lock().insert(addr) {
260 warn!(parent: self.span(), "Already connecting to {addr}");
261 return Err(ConnectError::AlreadyConnecting { address: addr });
262 }
263
264 let timeout_duration = Duration::from_millis(self.config().connection_timeout_ms.into());
265
266 let res = if let Some(listen_ip) = self.config().listener_ip {
269 let sock =
270 if listen_ip.is_ipv4() { tokio::net::TcpSocket::new_v4()? } else { tokio::net::TcpSocket::new_v6()? };
271 sock.bind(SocketAddr::new(listen_ip, 0))?;
272 timeout(timeout_duration, sock.connect(addr)).await
273 } else {
274 timeout(timeout_duration, TcpStream::connect(addr)).await
275 };
276
277 let stream = match res {
278 Ok(Ok(stream)) => Ok(stream),
279 Ok(err) => {
280 self.connecting.lock().remove(&addr);
281 err
282 }
283 Err(err) => {
284 self.connecting.lock().remove(&addr);
285 error!("connection timeout error: {}", err);
286 Err(io::ErrorKind::TimedOut.into())
287 }
288 }?;
289
290 let ret = self.adapt_stream(stream, addr, ConnectionSide::Initiator).await;
291
292 if let Err(ref e) = ret {
293 self.connecting.lock().remove(&addr);
294 self.known_peers().register_failure(addr.ip());
295 error!(parent: self.span(), "Unable to initiate a connection with {addr}: {e}");
296 }
297
298 ret.map_err(|err| err.into())
299 }
300
301 pub async fn disconnect(&self, addr: SocketAddr) -> bool {
305 if let Some(conn) = self.connections.0.read().get(&addr) {
307 if conn.disconnecting.swap(true, Relaxed) {
308 return false;
310 }
311 } else {
312 return false;
314 };
315
316 if let Some(handler) = self.protocols.disconnect.get() {
317 let (sender, receiver) = oneshot::channel();
318 handler.trigger((addr, sender));
319 let _ = receiver.await; }
321
322 let conn = self.connections.remove(addr);
323
324 if let Some(ref conn) = conn {
325 debug!(parent: self.span(), "Disconnecting from {}", conn.addr());
326
327 for task in conn.tasks.iter().rev() {
329 task.abort();
330 }
331
332 debug!(parent: self.span(), "Disconnected from {}", conn.addr());
333 } else {
334 warn!(parent: self.span(), "Failed to disconnect, was not connected to {addr}");
335 }
336
337 conn.is_some()
338 }
339}
340
341impl Tcp {
342 pub async fn enable_listener(&self) -> io::Result<SocketAddr> {
344 let listener_ip =
346 self.config().listener_ip.expect("Tcp::enable_listener was called, but Config::listener_ip is not set");
347
348 let listener = self.create_listener(listener_ip).await?;
350
351 let port = listener.local_addr()?.port();
353
354 let listening_addr = (listener_ip, port).into();
356 self.listening_addr.set(listening_addr).expect("The node's listener was started more than once");
357
358 let (tx, rx) = oneshot::channel();
360
361 let tcp = self.clone();
362 let listening_task = tokio::spawn(async move {
363 trace!(parent: tcp.span(), "Spawned the listening task");
364 tx.send(()).unwrap(); loop {
367 match listener.accept().await {
369 Ok((stream, addr)) => tcp.handle_connection(stream, addr),
370 Err(e) => error!(parent: tcp.span(), "Failed to accept a connection: {e}"),
371 }
372 }
373 });
374 self.tasks.lock().push(listening_task);
375 let _ = rx.await;
376 debug!(parent: self.span(), "Listening on {listening_addr}");
377
378 Ok(listening_addr)
379 }
380
381 async fn create_listener(&self, listener_ip: IpAddr) -> io::Result<TcpListener> {
383 debug!("Creating a TCP listener on {listener_ip}...");
384 let listener = if let Some(port) = self.config().desired_listening_port {
385 let desired_listening_addr = SocketAddr::new(listener_ip, port);
387 match TcpListener::bind(desired_listening_addr).await {
389 Ok(listener) => listener,
390 Err(e) => {
391 if self.config().allow_random_port {
392 warn!(
393 parent: self.span(),
394 "Trying any listening port, as the desired port is unavailable: {e}"
395 );
396 let random_available_addr = SocketAddr::new(listener_ip, 0);
397 TcpListener::bind(random_available_addr).await?
398 } else {
399 error!(parent: self.span(), "The desired listening port is unavailable: {e}");
400 return Err(e);
401 }
402 }
403 }
404 } else if self.config().allow_random_port {
405 let random_available_addr = SocketAddr::new(listener_ip, 0);
406 TcpListener::bind(random_available_addr).await?
407 } else {
408 panic!("As 'listener_ip' is set, either 'desired_listening_port' or 'allow_random_port' must be set");
409 };
410
411 Ok(listener)
412 }
413
414 fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) {
416 debug!(parent: self.span(), "Received a connection from {addr}");
417
418 if !self.can_add_connection() || self.is_self_connect(addr) {
419 debug!(parent: self.span(), "Rejecting the connection from {addr}");
420 return;
421 }
422
423 self.connecting.lock().insert(addr);
424
425 let tcp = self.clone();
426 tokio::spawn(async move {
427 if let Err(e) = tcp.adapt_stream(stream, addr, ConnectionSide::Responder).await {
428 tcp.connecting.lock().remove(&addr);
429 tcp.known_peers().register_failure(addr.ip());
430 error!(parent: tcp.span(), "Failed to connect with {addr}: {e}");
431 }
432 });
433 }
434
435 fn is_self_connect(&self, addr: SocketAddr) -> bool {
437 let listening_addr = self.listening_addr().unwrap();
439
440 match listening_addr.ip().is_loopback() {
441 true => listening_addr.port() == addr.port(),
444 false => listening_addr.ip() == addr.ip(),
446 }
447 }
448
449 fn can_add_connection(&self) -> bool {
451 let num_connected = self.num_connected();
453 let limit = self.config.max_connections as usize;
455
456 if num_connected >= limit {
457 warn!(parent: self.span(), "Maximum number of active connections ({limit}) reached");
458 false
459 } else if num_connected + self.num_connecting() >= limit {
460 warn!(parent: self.span(), "Maximum number of active & pending connections ({limit}) reached");
461 false
462 } else {
463 true
464 }
465 }
466
467 async fn adapt_stream(&self, stream: TcpStream, peer_addr: SocketAddr, own_side: ConnectionSide) -> io::Result<()> {
469 self.known_peers.add(peer_addr.ip());
470
471 if own_side == ConnectionSide::Initiator {
473 if let Ok(addr) = stream.local_addr() {
474 debug!(
475 parent: self.span(), "establishing connection with {}; the peer is connected on port {}",
476 peer_addr, addr.port()
477 );
478 } else {
479 warn!(parent: self.span(), "couldn't determine the peer's port");
480 }
481 }
482
483 let connection = Connection::new(peer_addr, stream, !own_side);
484
485 let mut connection = self.enable_protocols(connection).await?;
487
488 let conn_ready_tx = connection.readiness_notifier.take();
490
491 self.connections.add(connection);
492 self.connecting.lock().remove(&peer_addr);
493
494 if let Some(tx) = conn_ready_tx {
496 let _ = tx.send(());
497 }
498
499 if let Some(handler) = self.protocols.on_connect.get() {
501 let (sender, receiver) = oneshot::channel();
502 handler.trigger((peer_addr, sender));
503 let _ = receiver.await; }
505
506 Ok(())
507 }
508
509 async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> {
511 macro_rules! enable_protocol {
513 ($handler_type: ident, $node:expr, $conn: expr) => {
514 if let Some(handler) = $node.protocols.$handler_type.get() {
515 let (conn_returner, conn_retriever) = oneshot::channel();
516
517 handler.trigger(($conn, conn_returner));
518
519 match conn_retriever.await {
520 Ok(Ok(conn)) => conn,
521 Err(_) => return Err(io::ErrorKind::BrokenPipe.into()),
522 Ok(e) => return e,
523 }
524 } else {
525 $conn
526 }
527 };
528 }
529
530 let mut conn = enable_protocol!(handshake, self, conn);
531
532 if let Some(stream) = conn.stream.take() {
534 let (reader, writer) = split(stream);
535 conn.reader = Some(Box::new(reader));
536 conn.writer = Some(Box::new(writer));
537 }
538
539 let conn = enable_protocol!(reading, self, conn);
540 let conn = enable_protocol!(writing, self, conn);
541
542 Ok(conn)
543 }
544}
545
546impl fmt::Debug for Tcp {
547 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548 write!(f, "The TCP stack config: {:?}", self.config)
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        net::{IpAddr, Ipv4Addr},
        str::FromStr,
    };

    /// The IPv4 loopback address all test nodes listen on.
    const LOCALHOST: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);

    /// Starts a peer that listens on an OS-assigned localhost port and accepts a
    /// single connection, and returns it together with its listening address.
    async fn start_single_slot_peer() -> (Tcp, SocketAddr) {
        let peer = Tcp::new(Config {
            listener_ip: Some(LOCALHOST),
            desired_listening_port: Some(0),
            max_connections: 1,
            ..Default::default()
        });
        let peer_addr = peer.enable_listener().await.unwrap();
        (peer, peer_addr)
    }

    /// Asserts how many connections `tcp` has established and pending.
    #[track_caller]
    fn assert_counts(tcp: &Tcp, connected: usize, connecting: usize) {
        assert_eq!(tcp.num_connected(), connected);
        assert_eq!(tcp.num_connecting(), connecting);
    }

    #[tokio::test]
    async fn test_new() {
        let tcp = Tcp::new(Config { listener_ip: Some(LOCALHOST), max_connections: 200, ..Default::default() });

        assert_eq!(tcp.config.max_connections, 200);
        assert_eq!(tcp.config.listener_ip, Some(LOCALHOST));
        let listening_addr = tcp.enable_listener().await.unwrap();
        assert_eq!(listening_addr.ip(), LOCALHOST);

        assert_counts(&tcp, 0, 0);
    }

    #[tokio::test]
    async fn test_connect() {
        let tcp = Tcp::new(Config::default());
        let own_addr = tcp.enable_listener().await.unwrap();

        // Connecting to our own listener is refused and leaves no trace.
        let result = tcp.connect(own_addr).await;
        assert!(matches!(result, Err(ConnectError::SelfConnect { .. })));

        assert_counts(&tcp, 0, 0);
        assert!(!tcp.is_connected(own_addr));
        assert!(!tcp.is_connecting(own_addr));

        let (_peer, peer_addr) = start_single_slot_peer().await;

        tcp.connect(peer_addr).await.unwrap();
        assert_counts(&tcp, 1, 0);
        assert!(tcp.is_connected(peer_addr));
        assert!(!tcp.is_connecting(peer_addr));
    }

    #[tokio::test]
    async fn test_disconnect() {
        let tcp = Tcp::new(Config::default());
        let _own_addr = tcp.enable_listener().await.unwrap();

        let (_peer, peer_addr) = start_single_slot_peer().await;

        tcp.connect(peer_addr).await.unwrap();
        assert_counts(&tcp, 1, 0);
        assert!(tcp.is_connected(peer_addr));
        assert!(!tcp.is_connecting(peer_addr));

        // The first disconnect removes the connection...
        let disconnected = tcp.disconnect(peer_addr).await;
        assert!(disconnected);
        assert_counts(&tcp, 0, 0);
        assert!(!tcp.is_connected(peer_addr));
        assert!(!tcp.is_connecting(peer_addr));

        // ...and a second one has nothing left to do.
        let disconnected = tcp.disconnect(peer_addr).await;
        assert!(!disconnected);
        assert_counts(&tcp, 0, 0);
        assert!(!tcp.is_connected(peer_addr));
        assert!(!tcp.is_connecting(peer_addr));
    }

    #[tokio::test]
    async fn test_can_add_connection() {
        let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });
        let (_peer, peer_addr) = start_single_slot_peer().await;
        let unreachable_addr = SocketAddr::from_str("1.2.3.4:4242").unwrap();

        assert!(tcp.can_add_connection());

        // An established connection takes the only slot.
        let stream = TcpStream::connect(peer_addr).await.unwrap();
        tcp.connections.add(Connection::new(peer_addr, stream, ConnectionSide::Initiator));
        assert!(!tcp.can_add_connection());

        let result = tcp.connect(unreachable_addr).await;
        assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. })));

        tcp.connections.remove(peer_addr);
        assert!(tcp.can_add_connection());

        // A pending connection takes the only slot as well.
        tcp.connecting.lock().insert(peer_addr);
        assert!(!tcp.can_add_connection());

        let result = tcp.connect(unreachable_addr).await;
        assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. })));

        tcp.connecting.lock().remove(&peer_addr);
        assert!(tcp.can_add_connection());

        // Both at once.
        let stream = TcpStream::connect(peer_addr).await.unwrap();
        tcp.connections.add(Connection::new(peer_addr, stream, ConnectionSide::Responder));
        tcp.connecting.lock().insert(peer_addr);
        assert!(!tcp.can_add_connection());

        tcp.connections.remove(peer_addr);
        tcp.connecting.lock().remove(&peer_addr);
        assert!(tcp.can_add_connection());
    }

    #[tokio::test]
    async fn test_handle_connection() {
        let tcp = Tcp::new(Config { listener_ip: Some(LOCALHOST), max_connections: 1, ..Default::default() });

        // Fill the only slot.
        let (_peer1, peer1_addr) = start_single_slot_peer().await;
        let stream = TcpStream::connect(peer1_addr).await.unwrap();
        tcp.connections.add(Connection::new(peer1_addr, stream, ConnectionSide::Responder));
        assert!(!tcp.can_add_connection());
        assert_counts(&tcp, 1, 0);
        assert!(tcp.is_connected(peer1_addr));
        assert!(!tcp.is_connecting(peer1_addr));

        // An inbound connection arriving while full is rejected outright.
        let (_peer2, peer2_addr) = start_single_slot_peer().await;
        let stream = TcpStream::connect(peer2_addr).await.unwrap();
        tcp.handle_connection(stream, peer2_addr);
        assert!(!tcp.can_add_connection());
        assert_counts(&tcp, 1, 0);
        assert!(tcp.is_connected(peer1_addr));
        assert!(!tcp.is_connected(peer2_addr));
        assert!(!tcp.is_connecting(peer1_addr));
        assert!(!tcp.is_connecting(peer2_addr));
    }

    #[tokio::test]
    async fn test_adapt_stream() {
        let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() });
        let (_peer, peer_addr) = start_single_slot_peer().await;

        // Simulate a pending connection...
        tcp.connecting.lock().insert(peer_addr);
        assert_counts(&tcp, 0, 1);
        assert!(!tcp.is_connected(peer_addr));
        assert!(tcp.is_connecting(peer_addr));

        // ...which adapting the stream promotes to an established one.
        let stream = TcpStream::connect(peer_addr).await.unwrap();
        tcp.adapt_stream(stream, peer_addr, ConnectionSide::Responder).await.unwrap();
        assert_counts(&tcp, 1, 0);
        assert!(tcp.is_connected(peer_addr));
        assert!(!tcp.is_connecting(peer_addr));
    }
}