snarkos_node_router/
handshake.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{
17    ConnectionMode,
18    NodeType,
19    PeerPoolHandling,
20    Router,
21    messages::{ChallengeRequest, ChallengeResponse, DisconnectReason, Message, MessageCodec, MessageTrait},
22};
23use snarkos_node_network::{built_info, log_repo_sha_comparison};
24use snarkos_node_tcp::{ConnectionSide, P2P, Tcp};
25use snarkvm::{
26    ledger::narwhal::Data,
27    prelude::{Address, ConsensusVersion, Field, Network, block::Header, error},
28};
29
30use anyhow::{Result, bail};
31use futures::SinkExt;
32use rand::{Rng, rngs::OsRng};
33use std::{io, net::SocketAddr};
34use tokio::net::TcpStream;
35use tokio_stream::StreamExt;
36use tokio_util::codec::Framed;
37
38impl<N: Network> P2P for Router<N> {
39    /// Returns a reference to the TCP instance.
40    fn tcp(&self) -> &Tcp {
41        &self.tcp
42    }
43}
44
/// A macro unwrapping the expected handshake message or returning an error for unexpected messages.
///
/// It awaits the next frame from `$framed` (read/decode errors are propagated with `?`) and:
/// - evaluates to the inner payload when the frame is a `$msg_ty` message;
/// - returns an error when the peer sent a `Message::Disconnect` instead;
/// - returns an error when any other message arrived;
/// - returns an error when the stream ended before any frame arrived.
///
/// Because it uses `?` and `return Err(...)`, it must be invoked inside an `async` function whose
/// error type matches `error(..)`. The names it references (`Message`, `error`, `trace!`, and the
/// `try_next`/`name` methods) are resolved at the call site, so they must be in scope there.
#[macro_export]
macro_rules! expect_message {
    ($msg_ty:path, $framed:expr, $peer_addr:expr) => {
        match $framed.try_next().await? {
            // Received the expected message, proceed.
            Some($msg_ty(data)) => {
                trace!("Received '{}' from '{}'", data.name(), $peer_addr);
                data
            }
            // Received a disconnect message, abort.
            Some(Message::Disconnect(reason)) => {
                return Err(error(format!("'{}' disconnected: {reason:?}", $peer_addr)))
            }
            // Received an unexpected message, abort.
            Some(unexpected) => {
                return Err(error(format!(
                    "'{}' did not follow the handshake protocol: received {:?} instead of {}",
                    $peer_addr,
                    unexpected.name(),
                    stringify!($msg_ty),
                )))
            }
            // The stream ended without delivering the expected message. Include the peer address so
            // that concurrent handshake failures can be told apart in the logs.
            None => {
                return Err(error(format!(
                    "'{}' disconnected before sending {:?}, likely due to peer saturation or shutdown",
                    $peer_addr,
                    stringify!($msg_ty),
                )))
            }
        }
    };
}
78
79/// Send the given message to the peer.
80async fn send<N: Network>(
81    framed: &mut Framed<&mut TcpStream, MessageCodec<N>>,
82    peer_addr: SocketAddr,
83    message: Message<N>,
84) -> io::Result<()> {
85    trace!("Sending '{}' to '{peer_addr}'", message.name());
86    framed.send(message).await
87}
88
89impl<N: Network> Router<N> {
90    /// Executes the handshake protocol.
91    pub async fn handshake<'a>(
92        &'a self,
93        peer_addr: SocketAddr,
94        stream: &'a mut TcpStream,
95        peer_side: ConnectionSide,
96        genesis_header: Header<N>,
97        restrictions_id: Field<N>,
98    ) -> io::Result<Option<ChallengeRequest<N>>> {
99        // If this is an inbound connection, we log it, but don't know the listening address yet.
100        // Otherwise, we can immediately register the listening address.
101        let mut listener_addr = if peer_side == ConnectionSide::Initiator {
102            debug!("Received a connection request from '{peer_addr}'");
103            None
104        } else {
105            debug!("Shaking hands with '{peer_addr}'...");
106            Some(peer_addr)
107        };
108
109        // Check (or impose) IP-level bans.
110        #[cfg(not(feature = "test"))]
111        if !self.is_dev() && peer_side == ConnectionSide::Initiator {
112            // If the IP is already banned reject the connection.
113            if self.is_ip_banned(peer_addr.ip()) {
114                trace!("Rejected a connection request from banned IP '{}'", peer_addr.ip());
115                return Err(error(format!("'{}' is a banned IP address", peer_addr.ip())));
116            }
117
118            let num_attempts =
119                self.cache.insert_inbound_connection(peer_addr.ip(), Router::<N>::CONNECTION_ATTEMPTS_SINCE_SECS);
120
121            debug!("Number of connection attempts from '{}': {}", peer_addr.ip(), num_attempts);
122            if num_attempts > Router::<N>::MAX_CONNECTION_ATTEMPTS {
123                self.update_ip_ban(peer_addr.ip());
124                trace!("Rejected a consecutive connection request from IP '{}'", peer_addr.ip());
125                return Err(error(format!("'{}' appears to be spamming connections", peer_addr.ip())));
126            }
127        }
128
129        // Perform the handshake; we pass on a mutable reference to listener_addr in case the process is broken at any point in time.
130        let handshake_result = if peer_side == ConnectionSide::Responder {
131            self.handshake_inner_initiator(peer_addr, stream, genesis_header, restrictions_id).await
132        } else {
133            self.handshake_inner_responder(peer_addr, &mut listener_addr, stream, genesis_header, restrictions_id).await
134        };
135
136        if let Some(addr) = listener_addr {
137            match handshake_result {
138                Ok(Some(ref cr)) => {
139                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
140                        self.resolver.write().insert_peer(peer.listener_addr(), peer_addr, Some(cr.address));
141                        peer.upgrade_to_connected(
142                            peer_addr,
143                            cr.listener_port,
144                            cr.address,
145                            cr.node_type,
146                            cr.version,
147                            ConnectionMode::Router,
148                        );
149                    }
150                    #[cfg(feature = "metrics")]
151                    self.update_metrics();
152                    debug!("Completed the handshake with '{peer_addr}'");
153                }
154                Ok(None) => {
155                    return Err(error(format!("Duplicate handshake attempt with '{addr}'")));
156                }
157                Err(_) => {
158                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
159                        // The peer may only be downgraded if it's a ConnectingPeer.
160                        if peer.is_connecting() {
161                            peer.downgrade_to_candidate(addr);
162                        }
163                    }
164                }
165            }
166        }
167
168        handshake_result
169    }
170
171    /// The connection initiator side of the handshake.
172    async fn handshake_inner_initiator<'a>(
173        &'a self,
174        peer_addr: SocketAddr,
175        stream: &'a mut TcpStream,
176        genesis_header: Header<N>,
177        restrictions_id: Field<N>,
178    ) -> io::Result<Option<ChallengeRequest<N>>> {
179        // Introduce the peer into the peer pool.
180        if !self.add_connecting_peer(peer_addr) {
181            // Return early if already being connected to.
182            return Ok(None);
183        }
184
185        // Construct the stream.
186        let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());
187
188        // Initialize an RNG.
189        let rng = &mut OsRng;
190
191        // Determine the snarkOS SHA to send to the peer.
192        let current_block_height = self.ledger.latest_block_height();
193        let consensus_version = N::CONSENSUS_VERSION(current_block_height).unwrap();
194        let snarkos_sha = (consensus_version >= ConsensusVersion::V12)
195            .then(|| built_info::GIT_COMMIT_HASH.unwrap_or_default().into());
196
197        /* Step 1: Send the challenge request. */
198
199        // Sample a random nonce.
200        let our_nonce = rng.r#gen();
201        // Send a challenge request to the peer.
202        let our_request =
203            ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce, snarkos_sha);
204        send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;
205
206        /* Step 2: Receive the peer's challenge response followed by the challenge request. */
207
208        // Listen for the challenge response message.
209        let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
210        // Listen for the challenge request message.
211        let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);
212
213        // Verify the challenge response. If a disconnect reason was returned, send the disconnect message and abort.
214        if let Some(reason) = self
215            .verify_challenge_response(
216                peer_addr,
217                peer_request.address,
218                peer_request.node_type,
219                peer_response,
220                genesis_header,
221                restrictions_id,
222                our_nonce,
223            )
224            .await
225        {
226            send(&mut framed, peer_addr, reason.into()).await?;
227            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
228        }
229        // Verify the challenge request. If a disconnect reason was returned, send the disconnect message and abort.
230        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
231            send(&mut framed, peer_addr, reason.into()).await?;
232            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
233        }
234
235        /* Step 3: Send the challenge response. */
236
237        let response_nonce: u64 = rng.r#gen();
238        let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
239        // Sign the counterparty nonce.
240        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
241            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
242        };
243        // Send the challenge response.
244        let our_response = ChallengeResponse {
245            genesis_header,
246            restrictions_id,
247            signature: Data::Object(our_signature),
248            nonce: response_nonce,
249        };
250        send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;
251
252        Ok(Some(peer_request))
253    }
254
255    /// The connection responder side of the handshake.
256    async fn handshake_inner_responder<'a>(
257        &'a self,
258        peer_addr: SocketAddr,
259        listener_addr: &mut Option<SocketAddr>,
260        stream: &'a mut TcpStream,
261        genesis_header: Header<N>,
262        restrictions_id: Field<N>,
263    ) -> io::Result<Option<ChallengeRequest<N>>> {
264        // Construct the stream.
265        let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());
266
267        /* Step 1: Receive the challenge request. */
268
269        // Listen for the challenge request message.
270        let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);
271
272        // Determine the snarkOS SHA to send to the peer.
273        let current_block_height = self.ledger.latest_block_height();
274        let consensus_version = N::CONSENSUS_VERSION(current_block_height).unwrap();
275        let snarkos_sha = (consensus_version >= ConsensusVersion::V12)
276            .then(|| built_info::GIT_COMMIT_HASH.unwrap_or_default().into());
277
278        // Obtain the peer's listening address.
279        *listener_addr = Some(SocketAddr::new(peer_addr.ip(), peer_request.listener_port));
280        let listener_addr = listener_addr.unwrap();
281
282        // Knowing the peer's listening address, ensure it is allowed to connect.
283        if let Err(forbidden_message) = self.ensure_peer_is_allowed(listener_addr) {
284            return Err(error(format!("{forbidden_message}")));
285        }
286
287        // Introduce the peer into the peer pool.
288        if !self.add_connecting_peer(listener_addr) {
289            // Return early if already being connected to.
290            return Ok(None);
291        }
292
293        // Verify the challenge request. If a disconnect reason was returned, send the disconnect message and abort.
294        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
295            send(&mut framed, peer_addr, reason.into()).await?;
296            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
297        }
298
299        /* Step 2: Send the challenge response followed by own challenge request. */
300
301        // Initialize an RNG.
302        let rng = &mut OsRng;
303
304        // Sign the counterparty nonce.
305        let response_nonce: u64 = rng.r#gen();
306        let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
307        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
308            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
309        };
310        // Send the challenge response.
311        let our_response = ChallengeResponse {
312            genesis_header,
313            restrictions_id,
314            signature: Data::Object(our_signature),
315            nonce: response_nonce,
316        };
317        send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;
318
319        // Sample a random nonce.
320        let our_nonce = rng.r#gen();
321        // Send the challenge request.
322        let our_request =
323            ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce, snarkos_sha);
324        send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;
325
326        /* Step 3: Receive the challenge response. */
327
328        // Listen for the challenge response message.
329        let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
330        // Verify the challenge response. If a disconnect reason was returned, send the disconnect message and abort.
331        if let Some(reason) = self
332            .verify_challenge_response(
333                peer_addr,
334                peer_request.address,
335                peer_request.node_type,
336                peer_response,
337                genesis_header,
338                restrictions_id,
339                our_nonce,
340            )
341            .await
342        {
343            send(&mut framed, peer_addr, reason.into()).await?;
344            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
345        }
346
347        Ok(Some(peer_request))
348    }
349
350    /// Ensure the peer is allowed to connect.
351    fn ensure_peer_is_allowed(&self, listener_addr: SocketAddr) -> Result<()> {
352        // Ensure that it's not a self-connect attempt.
353        if self.is_local_ip(listener_addr) {
354            bail!("Dropping connection request from '{listener_addr}' (attempted to self-connect)");
355        }
356        // Unknown peers are untrusted, so check if `trusted_peers_only` is true.
357        if self.trusted_peers_only() && !self.is_trusted(listener_addr) {
358            bail!("Dropping connection request from '{listener_addr}' (untrusted)");
359        }
360        Ok(())
361    }
362
363    /// Verifies the given challenge request. Returns a disconnect reason if the request is invalid.
364    fn verify_challenge_request(
365        &self,
366        peer_addr: SocketAddr,
367        message: &ChallengeRequest<N>,
368    ) -> Option<DisconnectReason> {
369        // Retrieve the components of the challenge request.
370        let &ChallengeRequest { version, listener_port: _, node_type, address, nonce: _, ref snarkos_sha } = message;
371        log_repo_sha_comparison(peer_addr, snarkos_sha.as_ref(), Self::OWNER);
372
373        // Ensure the message protocol version is not outdated.
374        if !self.is_valid_message_version(version) {
375            warn!("Dropping '{peer_addr}' on version {version} (outdated)");
376            return Some(DisconnectReason::OutdatedClientVersion);
377        }
378
379        // Ensure there are no validators connected with the given Aleo address.
380        if self.node_type() == NodeType::Validator
381            && node_type == NodeType::Validator
382            && self.is_connected_address(address)
383        {
384            warn!("Dropping '{peer_addr}' for being already connected ({address})");
385            return Some(DisconnectReason::NoReasonGiven);
386        }
387
388        None
389    }
390
391    /// Verifies the given challenge response. Returns a disconnect reason if the response is invalid.
392    #[allow(clippy::too_many_arguments)]
393    async fn verify_challenge_response(
394        &self,
395        peer_addr: SocketAddr,
396        peer_address: Address<N>,
397        peer_node_type: NodeType,
398        response: ChallengeResponse<N>,
399        expected_genesis_header: Header<N>,
400        expected_restrictions_id: Field<N>,
401        expected_nonce: u64,
402    ) -> Option<DisconnectReason> {
403        // Retrieve the components of the challenge response.
404        let ChallengeResponse { genesis_header, restrictions_id, signature, nonce } = response;
405
406        // Verify the challenge response, by checking that the block header matches.
407        if genesis_header != expected_genesis_header {
408            warn!("Handshake with '{peer_addr}' failed (incorrect block header)");
409            return Some(DisconnectReason::InvalidChallengeResponse);
410        }
411        // Verify the restrictions ID.
412        if !peer_node_type.is_prover() && !self.node_type.is_prover() && restrictions_id != expected_restrictions_id {
413            warn!("Handshake with '{peer_addr}' failed (incorrect restrictions ID)");
414            return Some(DisconnectReason::InvalidChallengeResponse);
415        }
416        // Perform the deferred non-blocking deserialization of the signature.
417        let Ok(signature) = signature.deserialize().await else {
418            warn!("Handshake with '{peer_addr}' failed (cannot deserialize the signature)");
419            return Some(DisconnectReason::InvalidChallengeResponse);
420        };
421        // Verify the signature.
422        if !signature.verify_bytes(&peer_address, &[expected_nonce.to_le_bytes(), nonce.to_le_bytes()].concat()) {
423            warn!("Handshake with '{peer_addr}' failed (invalid signature)");
424            return Some(DisconnectReason::InvalidChallengeResponse);
425        }
426        None
427    }
428}