// snarkos_node_tcp/helpers/connections.rs

// Copyright (c) 2019-2026 Provable Inc.
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Objects associated with connection handling.

18use std::{collections::HashMap, net::SocketAddr, ops::Not, sync::atomic::AtomicBool};
19
20#[cfg(feature = "locktick")]
21use locktick::parking_lot::RwLock;
22#[cfg(not(feature = "locktick"))]
23use parking_lot::RwLock;
24use tokio::{
25    io::{AsyncRead, AsyncWrite},
26    net::TcpStream,
27    sync::oneshot,
28    task::JoinHandle,
29};
30use tracing::*;
31
32#[cfg(doc)]
33use crate::protocols::{Handshake, Reading, Writing};
34
35/// A map of all currently connected addresses to their associated connection.
36#[derive(Default)]
37pub(crate) struct Connections(pub(crate) RwLock<HashMap<SocketAddr, Connection>>);
38
39impl Connections {
40    /// Adds the given connection to the list of active connections.
41    pub(crate) fn add(&self, conn: Connection) {
42        self.0.write().insert(conn.addr, conn);
43    }
44
45    /// Returns `true` if the given address is connected.
46    pub(crate) fn is_connected(&self, addr: SocketAddr) -> bool {
47        self.0.read().contains_key(&addr)
48    }
49
50    /// Removes the connection associated with the given address.
51    pub(crate) fn remove(&self, addr: SocketAddr) -> Option<Connection> {
52        self.0.write().remove(&addr)
53    }
54
55    /// Returns the number of connected addresses.
56    pub(crate) fn num_connected(&self) -> usize {
57        self.0.read().len()
58    }
59
60    /// Returns the list of connected addresses.
61    pub(crate) fn addrs(&self) -> Vec<SocketAddr> {
62        self.0.read().keys().copied().collect()
63    }
64}
65
66/// A helper trait to facilitate trait-objectification of connection readers.
67pub(crate) trait AR: AsyncRead + Unpin + Send + Sync {}
68impl<T: AsyncRead + Unpin + Send + Sync> AR for T {}
69
70/// A helper trait to facilitate trait-objectification of connection writers.
71pub(crate) trait AW: AsyncWrite + Unpin + Send + Sync {}
72impl<T: AsyncWrite + Unpin + Send + Sync> AW for T {}
73
74/// Created for each active connection; used by the protocols to obtain a handle for
75/// reading and writing, and keeps track of tasks that have been spawned for the connection.
76pub struct Connection {
77    /// The address of the connection.
78    addr: SocketAddr,
79    /// The connection's side in relation to Tcp.
80    side: ConnectionSide,
81    /// Available and used only in the [`Handshake`] protocol.
82    pub(crate) stream: Option<TcpStream>,
83    /// Available and used only in the [`Reading`] protocol.
84    pub(crate) reader: Option<Box<dyn AR>>,
85    /// Available and used only in the [`Writing`] protocol.
86    pub(crate) writer: Option<Box<dyn AW>>,
87    /// Used to notify the [`Reading`] protocol that the connection is fully ready.
88    pub(crate) readiness_notifier: Option<oneshot::Sender<()>>,
89    /// Prevents the OnDisconnect hook from being triggered multiple times.
90    pub(crate) disconnecting: AtomicBool,
91    /// Handles to tasks spawned for the connection.
92    pub(crate) tasks: Vec<JoinHandle<()>>,
93    /// The tracing span.
94    pub(crate) span: Span,
95}
96
97impl Connection {
98    /// Creates a [`Connection`] with placeholders for protocol-related objects.
99    pub(crate) fn new(addr: SocketAddr, stream: TcpStream, side: ConnectionSide, span: Span) -> Self {
100        Self {
101            addr,
102            stream: Some(stream),
103            reader: None,
104            writer: None,
105            readiness_notifier: None,
106            disconnecting: Default::default(),
107            side,
108            tasks: Default::default(),
109            span,
110        }
111    }
112
113    /// Returns the address associated with the connection.
114    pub fn addr(&self) -> SocketAddr {
115        self.addr
116    }
117
118    /// Returns `ConnectionSide::Initiator` if the associated peer initiated the connection
119    /// and `ConnectionSide::Responder` if the connection request was initiated by Tcp.
120    pub fn side(&self) -> ConnectionSide {
121        self.side
122    }
123
124    /// Returns the tracing [`Span`] associated with the connection.
125    #[inline]
126    pub const fn span(&self) -> &Span {
127        &self.span
128    }
129}
130
/// Records which end opened a connection and which end accepted it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionSide {
    /// The end that initiated the connection.
    Initiator,
    /// The end that accepted the connection.
    Responder,
}

/// `!side` yields the opposite end of the same connection.
impl Not for ConnectionSide {
    type Output = Self;

    fn not(self) -> Self::Output {
        use ConnectionSide::{Initiator, Responder};

        match self {
            Initiator => Responder,
            Responder => Initiator,
        }
    }
}
150
151impl Drop for Connection {
152    fn drop(&mut self) {
153        for task in self.tasks.iter().rev() {
154            task.abort();
155        }
156    }
157}
158
159pub(crate) fn create_connection_span(addr: SocketAddr, parent: &Span) -> Span {
160    macro_rules! try_span {
161        ($lvl:expr) => {
162            let s = span!(parent: parent, $lvl, "conn", addr = %addr);
163            if !s.is_disabled() {
164                return s;
165            }
166        };
167    }
168    try_span!(Level::TRACE);
169    try_span!(Level::DEBUG);
170    try_span!(Level::INFO);
171    try_span!(Level::WARN);
172    error_span!(parent: parent, "conn", addr = %addr)
173}