snarkos_node_router/
handshake.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    NodeType,
18    Peer,
19    Router,
20    messages::{ChallengeRequest, ChallengeResponse, DisconnectReason, Message, MessageCodec, MessageTrait},
21};
22use snarkos_node_tcp::{ConnectionSide, P2P, Tcp};
23use snarkvm::{
24    ledger::narwhal::Data,
25    prelude::{Address, Field, Network, block::Header, error},
26};
27
28use anyhow::{Result, bail};
29use futures::SinkExt;
30use rand::{Rng, rngs::OsRng};
31use std::{collections::hash_map::Entry, io, net::SocketAddr};
32use tokio::net::TcpStream;
33use tokio_stream::StreamExt;
34use tokio_util::codec::Framed;
35
36impl<N: Network> P2P for Router<N> {
37    /// Returns a reference to the TCP instance.
38    fn tcp(&self) -> &Tcp {
39        &self.tcp
40    }
41}
42
/// Awaits the next handshake message on `$framed` and unwraps it as the `$msg_ty` variant.
///
/// The macro evaluates to the payload of the expected variant. In every other case it returns
/// early from the enclosing function with an error:
/// - the peer sent a `Message::Disconnect`,
/// - the peer sent a different message,
/// - the stream ended before any message arrived.
///
/// Read errors from the stream are propagated with `?`.
///
/// The expansion is not path-qualified. The calling scope must provide `Message`, `error`,
/// `trace!`, and a `try_next` method on `$framed`.
#[macro_export]
macro_rules! expect_message {
    ($msg_ty:path, $framed:expr, $peer_addr:expr) => {{
        let next_message = $framed.try_next().await?;
        match next_message {
            // The expected message arrived: hand its payload back to the caller.
            Some($msg_ty(payload)) => {
                trace!("Received '{}' from '{}'", payload.name(), $peer_addr);
                payload
            }
            // The peer hung up explicitly.
            Some(Message::Disconnect(reason)) => {
                let failure = format!("'{}' disconnected: {reason:?}", $peer_addr);
                return Err(error(failure));
            }
            // Some other message arrived out of order.
            Some(unexpected) => {
                let failure = format!(
                    "'{}' did not follow the handshake protocol: received {:?} instead of {}",
                    $peer_addr,
                    unexpected.name(),
                    stringify!($msg_ty),
                );
                return Err(error(failure));
            }
            // The stream closed without yielding anything.
            None => {
                let failure = format!("'{}' disconnected before sending {:?}", $peer_addr, stringify!($msg_ty),);
                return Err(error(failure));
            }
        }
    }};
}
73
74/// Send the given message to the peer.
75async fn send<N: Network>(
76    framed: &mut Framed<&mut TcpStream, MessageCodec<N>>,
77    peer_addr: SocketAddr,
78    message: Message<N>,
79) -> io::Result<()> {
80    trace!("Sending '{}' to '{peer_addr}'", message.name());
81    framed.send(message).await
82}
83
84impl<N: Network> Router<N> {
85    /// Executes the handshake protocol.
86    pub async fn handshake<'a>(
87        &'a self,
88        peer_addr: SocketAddr,
89        stream: &'a mut TcpStream,
90        peer_side: ConnectionSide,
91        genesis_header: Header<N>,
92        restrictions_id: Field<N>,
93    ) -> io::Result<(SocketAddr, Framed<&'a mut TcpStream, MessageCodec<N>>)> {
94        // If this is an inbound connection, we log it, but don't know the listening address yet.
95        // Otherwise, we can immediately register the listening address.
96        let mut peer_ip = if peer_side == ConnectionSide::Initiator {
97            debug!("Received a connection request from '{peer_addr}'");
98            None
99        } else {
100            debug!("Connecting to {peer_addr}...");
101            Some(peer_addr)
102        };
103
104        // Check (or impose) IP-level bans.
105        #[cfg(not(any(test)))]
106        if !self.is_dev() && peer_side == ConnectionSide::Initiator {
107            // If the IP is already banned reject the connection.
108            if self.is_ip_banned(peer_addr.ip()) {
109                trace!("Rejected a connection request from banned IP '{}'", peer_addr.ip());
110                return Err(error(format!("'{}' is a banned IP address", peer_addr.ip())));
111            }
112
113            let num_attempts =
114                self.cache.insert_inbound_connection(peer_addr.ip(), Router::<N>::CONNECTION_ATTEMPTS_SINCE_SECS);
115
116            debug!("Number of connection attempts from '{}': {}", peer_addr.ip(), num_attempts);
117            if num_attempts > Router::<N>::MAX_CONNECTION_ATTEMPTS {
118                self.update_ip_ban(peer_addr.ip());
119                trace!("Rejected a consecutive connection request from IP '{}'", peer_addr.ip());
120                return Err(error(format!("'{}' appears to be spamming connections", peer_addr.ip())));
121            }
122        }
123
124        // Perform the handshake; we pass on a mutable reference to peer_ip in case the process is broken at any point in time.
125        let handshake_result = if peer_side == ConnectionSide::Responder {
126            self.handshake_inner_initiator(peer_addr, &mut peer_ip, stream, genesis_header, restrictions_id).await
127        } else {
128            self.handshake_inner_responder(peer_addr, &mut peer_ip, stream, genesis_header, restrictions_id).await
129        };
130
131        if let Some(ip) = peer_ip {
132            if handshake_result.is_err() {
133                // Remove the address from the collection of connecting peers (if the handshake got to the point where it's known).
134                self.connecting_peers.lock().remove(&ip);
135            }
136        }
137
138        handshake_result
139    }
140
141    /// The connection initiator side of the handshake.
142    async fn handshake_inner_initiator<'a>(
143        &'a self,
144        peer_addr: SocketAddr,
145        peer_ip: &mut Option<SocketAddr>,
146        stream: &'a mut TcpStream,
147        genesis_header: Header<N>,
148        restrictions_id: Field<N>,
149    ) -> io::Result<(SocketAddr, Framed<&'a mut TcpStream, MessageCodec<N>>)> {
150        // This value is immediately guaranteed to be present, so it can be unwrapped.
151        let peer_ip = peer_ip.unwrap();
152        // Construct the stream.
153        let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());
154
155        // Initialize an RNG.
156        let rng = &mut OsRng;
157
158        /* Step 1: Send the challenge request. */
159
160        // Sample a random nonce.
161        let our_nonce = rng.gen();
162        // Send a challenge request to the peer.
163        let our_request = ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce);
164        send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;
165
166        /* Step 2: Receive the peer's challenge response followed by the challenge request. */
167
168        // Listen for the challenge response message.
169        let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
170        // Listen for the challenge request message.
171        let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);
172
173        // Verify the challenge response. If a disconnect reason was returned, send the disconnect message and abort.
174        if let Some(reason) = self
175            .verify_challenge_response(
176                peer_addr,
177                peer_request.address,
178                peer_request.node_type,
179                peer_response,
180                genesis_header,
181                restrictions_id,
182                our_nonce,
183            )
184            .await
185        {
186            send(&mut framed, peer_addr, reason.into()).await?;
187            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
188        }
189        // Verify the challenge request. If a disconnect reason was returned, send the disconnect message and abort.
190        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
191            send(&mut framed, peer_addr, reason.into()).await?;
192            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
193        }
194        /* Step 3: Send the challenge response. */
195
196        let response_nonce: u64 = rng.gen();
197        let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
198        // Sign the counterparty nonce.
199        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
200            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
201        };
202        // Send the challenge response.
203        let our_response = ChallengeResponse {
204            genesis_header,
205            restrictions_id,
206            signature: Data::Object(our_signature),
207            nonce: response_nonce,
208        };
209        send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;
210
211        // Finalize the connecting peer information.
212        self.connecting_peers.lock().insert(peer_ip, Some(Peer::new(peer_ip, peer_addr, &peer_request)));
213        // Adds a bidirectional map between the listener address and (ambiguous) peer address.
214        self.resolver.insert_peer(peer_ip, peer_addr);
215
216        Ok((peer_ip, framed))
217    }
218
219    /// The connection responder side of the handshake.
220    async fn handshake_inner_responder<'a>(
221        &'a self,
222        peer_addr: SocketAddr,
223        peer_ip: &mut Option<SocketAddr>,
224        stream: &'a mut TcpStream,
225        genesis_header: Header<N>,
226        restrictions_id: Field<N>,
227    ) -> io::Result<(SocketAddr, Framed<&'a mut TcpStream, MessageCodec<N>>)> {
228        // Construct the stream.
229        let mut framed = Framed::new(stream, MessageCodec::<N>::handshake());
230
231        /* Step 1: Receive the challenge request. */
232
233        // Listen for the challenge request message.
234        let peer_request = expect_message!(Message::ChallengeRequest, framed, peer_addr);
235
236        // Obtain the peer's listening address.
237        *peer_ip = Some(SocketAddr::new(peer_addr.ip(), peer_request.listener_port));
238        let peer_ip = peer_ip.unwrap();
239
240        // Knowing the peer's listening address, ensure it is allowed to connect.
241        if let Err(forbidden_message) = self.ensure_peer_is_allowed(peer_ip) {
242            return Err(error(format!("{forbidden_message}")));
243        }
244        // Verify the challenge request. If a disconnect reason was returned, send the disconnect message and abort.
245        if let Some(reason) = self.verify_challenge_request(peer_addr, &peer_request) {
246            send(&mut framed, peer_addr, reason.into()).await?;
247            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
248        }
249        /* Step 2: Send the challenge response followed by own challenge request. */
250
251        // Initialize an RNG.
252        let rng = &mut OsRng;
253
254        // Sign the counterparty nonce.
255        let response_nonce: u64 = rng.gen();
256        let data = [peer_request.nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
257        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
258            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
259        };
260        // Send the challenge response.
261        let our_response = ChallengeResponse {
262            genesis_header,
263            restrictions_id,
264            signature: Data::Object(our_signature),
265            nonce: response_nonce,
266        };
267        send(&mut framed, peer_addr, Message::ChallengeResponse(our_response)).await?;
268
269        // Sample a random nonce.
270        let our_nonce = rng.gen();
271        // Send the challenge request.
272        let our_request = ChallengeRequest::new(self.local_ip().port(), self.node_type, self.address(), our_nonce);
273        send(&mut framed, peer_addr, Message::ChallengeRequest(our_request)).await?;
274
275        /* Step 3: Receive the challenge response. */
276
277        // Listen for the challenge response message.
278        let peer_response = expect_message!(Message::ChallengeResponse, framed, peer_addr);
279        // Verify the challenge response. If a disconnect reason was returned, send the disconnect message and abort.
280        if let Some(reason) = self
281            .verify_challenge_response(
282                peer_addr,
283                peer_request.address,
284                peer_request.node_type,
285                peer_response,
286                genesis_header,
287                restrictions_id,
288                our_nonce,
289            )
290            .await
291        {
292            send(&mut framed, peer_addr, reason.into()).await?;
293            return Err(error(format!("Dropped '{peer_addr}' for reason: {reason:?}")));
294        }
295
296        // Finalize the connecting peer information.
297        self.connecting_peers.lock().insert(peer_ip, Some(Peer::new(peer_ip, peer_addr, &peer_request)));
298        // Adds a bidirectional map between the listener address and (ambiguous) peer address.
299        self.resolver.insert_peer(peer_ip, peer_addr);
300
301        Ok((peer_ip, framed))
302    }
303
304    /// Ensure the peer is allowed to connect.
305    fn ensure_peer_is_allowed(&self, peer_ip: SocketAddr) -> Result<()> {
306        // Ensure the peer IP is not this node.
307        if self.is_local_ip(&peer_ip) {
308            bail!("Dropping connection request from '{peer_ip}' (attempted to self-connect)")
309        }
310        // Ensure the node is not already connecting to this peer.
311        match self.connecting_peers.lock().entry(peer_ip) {
312            Entry::Vacant(entry) => entry.insert(None),
313            Entry::Occupied(_) => {
314                bail!("Dropping connection request from '{peer_ip}' (already shaking hands as the initiator)")
315            }
316        };
317        // Ensure the node is not already connected to this peer.
318        if self.is_connected(&peer_ip) {
319            bail!("Dropping connection request from '{peer_ip}' (already connected)")
320        }
321        // Ensure either the peer is trusted or `allow_external_peers` is true.
322        if !self.allow_external_peers() && !self.is_trusted(&peer_ip) {
323            bail!("Dropping connection request from '{peer_ip}' (untrusted)")
324        }
325        // Ensure the peer is not restricted.
326        if self.is_restricted(&peer_ip) {
327            bail!("Dropping connection request from '{peer_ip}' (restricted)")
328        }
329        // Ensure the peer is not spamming connection attempts.
330        if !peer_ip.ip().is_loopback() {
331            // Add this connection attempt and retrieve the number of attempts.
332            let num_attempts = self.cache.insert_inbound_connection(peer_ip.ip(), Self::RADIO_SILENCE_IN_SECS as i64);
333            // Ensure the connecting peer has not surpassed the connection attempt limit.
334            if num_attempts > Self::MAXIMUM_CONNECTION_FAILURES {
335                // Restrict the peer.
336                self.insert_restricted_peer(peer_ip);
337                bail!("Dropping connection request from '{peer_ip}' (tried {num_attempts} times)")
338            }
339        }
340        Ok(())
341    }
342
343    /// Verifies the given challenge request. Returns a disconnect reason if the request is invalid.
344    fn verify_challenge_request(
345        &self,
346        peer_addr: SocketAddr,
347        message: &ChallengeRequest<N>,
348    ) -> Option<DisconnectReason> {
349        // Retrieve the components of the challenge request.
350        let &ChallengeRequest { version, listener_port: _, node_type: _, address: _, nonce: _ } = message;
351
352        // Ensure the message protocol version is not outdated.
353        if !self.is_valid_message_version(version) {
354            warn!("Dropping '{peer_addr}' on version {version} (outdated)");
355            return Some(DisconnectReason::OutdatedClientVersion);
356        }
357        None
358    }
359
360    /// Verifies the given challenge response. Returns a disconnect reason if the response is invalid.
361    #[allow(clippy::too_many_arguments)]
362    async fn verify_challenge_response(
363        &self,
364        peer_addr: SocketAddr,
365        peer_address: Address<N>,
366        peer_node_type: NodeType,
367        response: ChallengeResponse<N>,
368        expected_genesis_header: Header<N>,
369        expected_restrictions_id: Field<N>,
370        expected_nonce: u64,
371    ) -> Option<DisconnectReason> {
372        // Retrieve the components of the challenge response.
373        let ChallengeResponse { genesis_header, restrictions_id, signature, nonce } = response;
374
375        // Verify the challenge response, by checking that the block header matches.
376        if genesis_header != expected_genesis_header {
377            warn!("Handshake with '{peer_addr}' failed (incorrect block header)");
378            return Some(DisconnectReason::InvalidChallengeResponse);
379        }
380        // Verify the restrictions ID.
381        if !peer_node_type.is_prover() && !self.node_type.is_prover() && restrictions_id != expected_restrictions_id {
382            warn!("Handshake with '{peer_addr}' failed (incorrect restrictions ID)");
383            return Some(DisconnectReason::InvalidChallengeResponse);
384        }
385        // Perform the deferred non-blocking deserialization of the signature.
386        let Ok(signature) = signature.deserialize().await else {
387            warn!("Handshake with '{peer_addr}' failed (cannot deserialize the signature)");
388            return Some(DisconnectReason::InvalidChallengeResponse);
389        };
390        // Verify the signature.
391        if !signature.verify_bytes(&peer_address, &[expected_nonce.to_le_bytes(), nonce.to_le_bytes()].concat()) {
392            warn!("Handshake with '{peer_addr}' failed (invalid signature)");
393            return Some(DisconnectReason::InvalidChallengeResponse);
394        }
395        None
396    }
397}