snarkos_node_router/
handshake.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    NodeType,
18    PeerPoolHandling,
19    Router,
20    messages::{ChallengeRequest, ChallengeResponse, DisconnectReason, Message, MessageCodec, MessageTrait},
21};
22use snarkos_node_tcp::{ConnectionSide, P2P, Tcp};
23use snarkvm::{
24    ledger::narwhal::Data,
25    prelude::{Address, Field, Network, block::Header, error},
26};
27
28use anyhow::{Result, bail};
29use futures::SinkExt;
30use rand::{Rng, rngs::OsRng};
31use std::{io, net::SocketAddr};
32use tokio::net::TcpStream;
33use tokio_stream::StreamExt;
34use tokio_util::codec::Framed;
35
/// Implements the `P2P` trait for the router by exposing its `Tcp` instance.
impl<N: Network> P2P for Router<N> {
    /// Returns a reference to the TCP instance.
    ///
    /// This is a plain accessor: it hands out a shared borrow of the router's `tcp` field
    /// and performs no other work.
    fn tcp(&self) -> &Tcp {
        &self.tcp
    }
}
42
/// Awaits the next frame on a handshake stream and unwraps it as the expected message variant.
///
/// Arguments:
/// * `$msg_ty` - path of the expected `Message` variant (e.g. `Message::ChallengeRequest`);
/// * `$framed` - the framed stream; `try_next().await?` is called on it exactly once;
/// * `$peer_addr` - the peer's address, used only for logging and error messages.
///
/// On success the macro evaluates to the payload of the expected variant. In every other case
/// (an explicit disconnect, a different message, or the stream ending) it `return`s an `Err`
/// built with `error(..)` from the *enclosing* function, so it can only be used inside an async
/// function whose error type accepts that value. `Message`, `error` and `trace!` must be
/// resolvable where the macro is invoked.
#[macro_export]
macro_rules! expect_message {
    ($msg_ty:path, $framed:expr, $peer_addr:expr) => {
        match $framed.try_next().await? {
            // The peer sent what the protocol calls for; hand the payload back to the caller.
            Some($msg_ty(payload)) => {
                trace!("Received '{}' from '{}'", payload.name(), $peer_addr);
                payload
            }
            // Everything else terminates the handshake; only the error text differs.
            unexpected => {
                let failure = match unexpected {
                    // The peer explicitly told us it is disconnecting.
                    Some(Message::Disconnect(reason)) => {
                        format!("'{}' disconnected: {reason:?}", $peer_addr)
                    }
                    // The peer sent a message that does not belong at this point of the handshake.
                    Some(other) => format!(
                        "'{}' did not follow the handshake protocol: received {:?} instead of {}",
                        $peer_addr,
                        other.name(),
                        stringify!($msg_ty),
                    ),
                    // The stream ended before any message arrived.
                    None => format!(
                        "the peer disconnected before sending {:?}, likely due to peer saturation or shutdown",
                        stringify!($msg_ty),
                    ),
                };
                return Err(error(failure));
            }
        }
    };
}
76
77/// Send the given message to the peer.
78async fn send<N: Network>(
79    framed: &mut Framed<&mut TcpStream, MessageCodec<N>>,
80    peer_addr: SocketAddr,
81    message: Message<N>,
82) -> io::Result<()> {
83    trace!("Sending '{}' to '{peer_addr}'", message.name());
84    framed.send(message).await
85}
86
87impl<N: Network> Router<N> {
88    /// Executes the handshake protocol.
89    pub async fn handshake<'a>(
90        &'a self,
91        peer_addr: SocketAddr,
92        stream: &'a mut TcpStream,
93        peer_side: ConnectionSide,
94        genesis_header: Header<N>,
95        restrictions_id: Field<N>,
96    ) -> io::Result<Option<ChallengeRequest<N>>> {
97        // If this is an inbound connection, we log it, but don't know the listening address yet.
98        // Otherwise, we can immediately register the listening address.
99        let mut listener_addr = if peer_side == ConnectionSide::Initiator {
100            debug!("Received a connection request from '{peer_addr}'");
101            None
102        } else {
103            debug!("Shaking hands with '{peer_addr}'...");
104            Some(peer_addr)
105        };
106
107        // Check (or impose) IP-level bans.
108        #[cfg(not(feature = "test"))]
109        if !self.is_dev() && peer_side == ConnectionSide::Initiator {
110            // If the IP is already banned reject the connection.
111            if self.is_ip_banned(peer_addr.ip()) {
112                trace!("Rejected a connection request from banned IP '{}'", peer_addr.ip());
113                return Err(error(format!("'{}' is a banned IP address", peer_addr.ip())));
114            }
115
116            let num_attempts =
117                self.cache.insert_inbound_connection(peer_addr.ip(), Router::<N>::CONNECTION_ATTEMPTS_SINCE_SECS);
118
119            debug!("Number of connection attempts from '{}': {}", peer_addr.ip(), num_attempts);
120            if num_attempts > Router::<N>::MAX_CONNECTION_ATTEMPTS {
121                self.update_ip_ban(peer_addr.ip());
122                trace!("Rejected a consecutive connection request from IP '{}'", peer_addr.ip());
123                return Err(error(format!("'{}' appears to be spamming connections", peer_addr.ip())));
124            }
125        }
126
127        // Perform the handshake; we pass on a mutable reference to listener_addr in case the process is broken at any point in time.
128        let handshake_result = if peer_side == ConnectionSide::Responder {
129            self.handshake_inner_initiator(peer_addr, stream, genesis_header, restrictions_id).await
130        } else {
131            self.handshake_inner_responder(peer_addr, &mut listener_addr, stream, genesis_header, restrictions_id).await
132        };
133
134        if let Some(addr) = listener_addr {
135            match handshake_result {
136                Ok(Some(ref cr)) => {
137                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
138                        self.resolver.write().insert_peer(peer.listener_addr(), peer_addr, None);
139                        peer.upgrade_to_connected(peer_addr, cr.listener_port, cr.address, cr.node_type, cr.version);
140                    }
141                    #[cfg(feature = "metrics")]
142                    self.update_metrics();
143                    debug!("Completed the handshake with '{peer_addr}'");
144                }
145                Ok(None) => {
146                    return Err(error("Duplicate handshake attempt with '{addr}'"));
147                }
148                Err(_) => {
149                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
150                        peer.downgrade_to_candidate(addr);
151                    }
152                }
153            }
154        }
155
156        handshake_result
157    }
158
159    /// The connection initiator side of the handshake.
160    async fn handshake_inner_initiator<'a>(
161        &'a self,
162        peer_addr: SocketAddr,
163        stream: &'a mut TcpStream,
164        genesis_header: Header<N>,
165        restrictions_id: Field<N>,
166    ) -> io::Result<Option<ChallengeRequest<N>>> {
167        // Introduce the peer into the peer pool.
168        if !self.add_connecting_peer(peer_addr) {
169            // Return early if already being connected to.
170            return Ok(None);
171        }
172
173        // Construct the stream.
174        let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());
175
176        // Initialize an RNG.
177        let rng = &mut OsRng;
178
179        /* Step 1: Send the challenge request. */
180
181        // Sample a random nonce.
182        let our_nonce = rng.r#gen();
183        // Send a challenge request to the peer.
184        let our_request = ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce);
185        send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;
186
187        /* Step 2: Receive the peer's challenge response followed by the challenge request. */
188
189        // Listen for the challenge response message.
190        let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
191        // Listen for the challenge request message.
192        let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);
193
194        // Verify the challenge response. If a disconnect reason was returned, send the disconnect message and abort.
195        if let Some(reason) = self
196            .verify_challenge_response(
197                peer_addr,
198                peer_request.address,
199                peer_request.node_type,
200                peer_response,
201                genesis_header,
202                restrictions_id,
203                our_nonce,
204            )
205            .await
206        {
207            send(&mut framed, peer_addr, reason.into()).await?;
208            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
209        }
210        // Verify the challenge request. If a disconnect reason was returned, send the disconnect message and abort.
211        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
212            send(&mut framed, peer_addr, reason.into()).await?;
213            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
214        }
215
216        /* Step 3: Send the challenge response. */
217
218        let response_nonce: u64 = rng.r#gen();
219        let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
220        // Sign the counterparty nonce.
221        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
222            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
223        };
224        // Send the challenge response.
225        let our_response = ChallengeResponse {
226            genesis_header,
227            restrictions_id,
228            signature: Data::Object(our_signature),
229            nonce: response_nonce,
230        };
231        send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;
232
233        Ok(Some(peer_request))
234    }
235
236    /// The connection responder side of the handshake.
237    async fn handshake_inner_responder<'a>(
238        &'a self,
239        peer_addr: SocketAddr,
240        listener_addr: &mut Option<SocketAddr>,
241        stream: &'a mut TcpStream,
242        genesis_header: Header<N>,
243        restrictions_id: Field<N>,
244    ) -> io::Result<Option<ChallengeRequest<N>>> {
245        // Construct the stream.
246        let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());
247
248        /* Step 1: Receive the challenge request. */
249
250        // Listen for the challenge request message.
251        let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);
252
253        // Obtain the peer's listening address.
254        *listener_addr = Some(SocketAddr::new(peer_addr.ip(), peer_request.listener_port));
255        let listener_addr = listener_addr.unwrap();
256
257        // Knowing the peer's listening address, ensure it is allowed to connect.
258        if let Err(forbidden_message) = self.ensure_peer_is_allowed(listener_addr) {
259            return Err(error(format!("{forbidden_message}")));
260        }
261
262        // Introduce the peer into the peer pool.
263        if !self.add_connecting_peer(listener_addr) {
264            // Return early if already being connected to.
265            return Ok(None);
266        }
267
268        // Verify the challenge request. If a disconnect reason was returned, send the disconnect message and abort.
269        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
270            send(&mut framed, peer_addr, reason.into()).await?;
271            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
272        }
273
274        /* Step 2: Send the challenge response followed by own challenge request. */
275
276        // Initialize an RNG.
277        let rng = &mut OsRng;
278
279        // Sign the counterparty nonce.
280        let response_nonce: u64 = rng.r#gen();
281        let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
282        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
283            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
284        };
285        // Send the challenge response.
286        let our_response = ChallengeResponse {
287            genesis_header,
288            restrictions_id,
289            signature: Data::Object(our_signature),
290            nonce: response_nonce,
291        };
292        send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;
293
294        // Sample a random nonce.
295        let our_nonce = rng.r#gen();
296        // Send the challenge request.
297        let our_request = ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce);
298        send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;
299
300        /* Step 3: Receive the challenge response. */
301
302        // Listen for the challenge response message.
303        let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
304        // Verify the challenge response. If a disconnect reason was returned, send the disconnect message and abort.
305        if let Some(reason) = self
306            .verify_challenge_response(
307                peer_addr,
308                peer_request.address,
309                peer_request.node_type,
310                peer_response,
311                genesis_header,
312                restrictions_id,
313                our_nonce,
314            )
315            .await
316        {
317            send(&mut framed, peer_addr, reason.into()).await?;
318            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
319        }
320
321        Ok(Some(peer_request))
322    }
323
324    /// Ensure the peer is allowed to connect.
325    fn ensure_peer_is_allowed(&self, listener_addr: SocketAddr) -> Result<()> {
326        // Ensure that it's not a self-connect attempt.
327        if self.is_local_ip(listener_addr) {
328            bail!("Dropping connection request from '{listener_addr}' (attempted to self-connect)");
329        }
330        // Unknown peers are untrusted, so check if `allow_external_peers` is true.
331        if !self.allow_external_peers() && !self.is_trusted(listener_addr) {
332            bail!("Dropping connection request from '{listener_addr}' (untrusted)");
333        }
334        Ok(())
335    }
336
337    /// Verifies the given challenge request. Returns a disconnect reason if the request is invalid.
338    fn verify_challenge_request(
339        &self,
340        peer_addr: SocketAddr,
341        message: &ChallengeRequest<N>,
342    ) -> Option<DisconnectReason> {
343        // Retrieve the components of the challenge request.
344        let &ChallengeRequest { version, listener_port: _, node_type: _, address: _, nonce: _ } = message;
345
346        // Ensure the message protocol version is not outdated.
347        if !self.is_valid_message_version(version) {
348            warn!("Dropping '{peer_addr}' on version {version} (outdated)");
349            return Some(DisconnectReason::OutdatedClientVersion);
350        }
351        None
352    }
353
354    /// Verifies the given challenge response. Returns a disconnect reason if the response is invalid.
355    #[allow(clippy::too_many_arguments)]
356    async fn verify_challenge_response(
357        &self,
358        peer_addr: SocketAddr,
359        peer_address: Address<N>,
360        peer_node_type: NodeType,
361        response: ChallengeResponse<N>,
362        expected_genesis_header: Header<N>,
363        expected_restrictions_id: Field<N>,
364        expected_nonce: u64,
365    ) -> Option<DisconnectReason> {
366        // Retrieve the components of the challenge response.
367        let ChallengeResponse { genesis_header, restrictions_id, signature, nonce } = response;
368
369        // Verify the challenge response, by checking that the block header matches.
370        if genesis_header != expected_genesis_header {
371            warn!("Handshake with '{peer_addr}' failed (incorrect block header)");
372            return Some(DisconnectReason::InvalidChallengeResponse);
373        }
374        // Verify the restrictions ID.
375        if !peer_node_type.is_prover() && !self.node_type.is_prover() && restrictions_id != expected_restrictions_id {
376            warn!("Handshake with '{peer_addr}' failed (incorrect restrictions ID)");
377            return Some(DisconnectReason::InvalidChallengeResponse);
378        }
379        // Perform the deferred non-blocking deserialization of the signature.
380        let Ok(signature) = signature.deserialize().await else {
381            warn!("Handshake with '{peer_addr}' failed (cannot deserialize the signature)");
382            return Some(DisconnectReason::InvalidChallengeResponse);
383        };
384        // Verify the signature.
385        if !signature.verify_bytes(&peer_address, &[expected_nonce.to_le_bytes(), nonce.to_le_bytes()].concat()) {
386            warn!("Handshake with '{peer_addr}' failed (invalid signature)");
387            return Some(DisconnectReason::InvalidChallengeResponse);
388        }
389        None
390    }
391}