Skip to main content

snarkos_node/bootstrap_client/
handshake.rs

1// Copyright (c) 2019-2026 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    BootstrapClient,
18    bft::events::{self, DisconnectReason, Event},
19    bootstrap_client::{codec::BootstrapClientCodec, network::MessageOrEvent},
20    network::{ConnectionMode, NodeType, PeerPoolHandling, log_repo_sha_comparison},
21    router::messages::{self, Message},
22    tcp::{ConnectError, Connection, ConnectionSide, protocols::*},
23};
24use snarkos_node_network::harden_socket;
25use snarkvm::{
26    ledger::narwhal::Data,
27    prelude::{Address, Network, io_error},
28};
29
30use futures_util::sink::SinkExt;
31use rand::{Rng, rngs::OsRng};
32use std::{io, net::SocketAddr};
33use tokio::net::TcpStream;
34use tokio_stream::StreamExt;
35use tokio_util::codec::Framed;
36
/// The kinds of handshake messages this node waits for from a connecting peer.
///
/// Each kind covers both wire formats handled by `expect_handshake_msg!`: the router
/// `Message` variant and the BFT `Event` variant of the same name.
#[derive(Debug)]
enum HandshakeMessageKind {
    /// The peer's opening challenge request.
    ChallengeRequest,
    /// The peer's answer to the challenge request sent by this node.
    ChallengeResponse,
}
42
/// Logs the name and destination of an outgoing handshake message at trace level, then
/// sends it through the given framed stream.
///
/// The macro awaits the send, so it may only be used in an async context; it evaluates to
/// the result of that send, leaving error handling to the caller.
macro_rules! send_msg {
    ($outgoing:expr, $sink:expr, $recipient:expr) => {{
        trace!("Sending '{}' to '{}'", $outgoing.name(), $recipient);
        $sink.send($outgoing).await
    }};
}
49
/// Awaits the next handshake message from the framed stream and checks that it is of the
/// expected kind, in either its router `Message` or its BFT `Event` form.
///
/// The macro evaluates to the received message. It returns early from the enclosing
/// function with a `ConnectError` when reading fails, when the stream ends before a
/// message arrives, or when the message is of a different kind than expected.
macro_rules! expect_handshake_msg {
    ($expected:expr, $stream:expr, $peer:expr) => {{
        // Read the next message; an exhausted stream means the peer went away.
        let received = match $stream.try_next().await? {
            Some(received) => received,
            None => {
                return Err(ConnectError::other(format!(
                    "the peer disconnected before sending {:?}, likely due to peer saturation or shutdown",
                    stringify!($expected),
                )));
            }
        };

        // Check the received message against the expected kind, in both wire formats.
        let follows_protocol = match $expected {
            HandshakeMessageKind::ChallengeRequest => matches!(
                received,
                MessageOrEvent::Message(Message::ChallengeRequest(_))
                    | MessageOrEvent::Event(Event::ChallengeRequest(_))
            ),
            HandshakeMessageKind::ChallengeResponse => matches!(
                received,
                MessageOrEvent::Message(Message::ChallengeResponse(_))
                    | MessageOrEvent::Event(Event::ChallengeResponse(_))
            ),
        };

        // Anything else is a protocol violation; report what was received instead.
        if !follows_protocol {
            let msg_name = match received {
                MessageOrEvent::Message(message) => message.name(),
                MessageOrEvent::Event(event) => event.name(),
            };
            return Err(ConnectError::other(format!(
                "'{}' did not follow the handshake protocol: expected {}, got {msg_name}",
                $peer,
                stringify!($expected),
            )));
        }

        trace!("Received a '{}' from '{}'", stringify!($expected), $peer);
        received
    }};
}
97
98#[async_trait]
99impl<N: Network> Handshake for BootstrapClient<N> {
100    async fn perform_handshake(&self, mut connection: Connection) -> Result<Connection, ConnectError> {
101        let peer_addr = connection.addr();
102        let peer_side = connection.side();
103        let stream = self.borrow_stream(&mut connection);
104        // Make the socket more robust.
105        harden_socket(stream)?;
106
107        // We don't know the listening address yet, as we don't initiate connections.
108        let mut listener_addr = if peer_side == ConnectionSide::Initiator {
109            debug!("Received a connection request from '{peer_addr}'");
110            None
111        } else {
112            unreachable!("The boostrapper clients don't initiate connections");
113        };
114
115        // Perform the handshake; we pass on a mutable reference to listener_addr in case the process is broken at any point in time.
116        let handshake_result = if peer_side == ConnectionSide::Responder {
117            unreachable!("The boostrapper clients don't initiate connections");
118        } else {
119            self.handshake_inner_responder(peer_addr, &mut listener_addr, stream).await
120        };
121
122        if let Some(addr) = listener_addr {
123            match handshake_result {
124                Ok((peer_port, peer_aleo_addr, peer_node_type, peer_version, peer_snarkos_sha, connection_mode)) => {
125                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
126                        // Due to only having a single Resolver, the BootstrapClient only adds an Aleo
127                        // address mapping for Gateway-mode connections, as it is only used there, and
128                        // it could otherwise clash with the Router-mode mapping for validators, which
129                        // may connect in both modes at the same time.
130                        let aleo_addr =
131                            if connection_mode == ConnectionMode::Gateway { Some(peer_aleo_addr) } else { None };
132                        self.resolver.write().insert_peer(peer.listener_addr(), peer_addr, aleo_addr);
133                        peer.upgrade_to_connected(
134                            peer_addr,
135                            peer_port,
136                            peer_aleo_addr,
137                            peer_node_type,
138                            peer_version,
139                            peer_snarkos_sha,
140                            connection_mode,
141                        );
142                    }
143                    debug!("Completed the handshake with '{peer_addr}'");
144                }
145                Err(_) => {
146                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
147                        // The peer may only be downgraded if it's a ConnectingPeer.
148                        if peer.is_connecting() {
149                            peer.downgrade_to_candidate(addr);
150                        }
151                    }
152                }
153            }
154        }
155
156        handshake_result.map(|_| connection)
157    }
158}
159
160impl<N: Network> BootstrapClient<N> {
161    /// The connection responder side of the handshake.
162    async fn handshake_inner_responder<'a>(
163        &'a self,
164        peer_addr: SocketAddr,
165        listener_addr: &mut Option<SocketAddr>,
166        stream: &'a mut TcpStream,
167    ) -> Result<(u16, Address<N>, NodeType, u32, Option<[u8; 40]>, ConnectionMode), ConnectError> {
168        // Construct the stream.
169        let mut framed = Framed::new(stream, BootstrapClientCodec::<N>::handshake());
170
171        /* Step 1: Receive the challenge request. */
172
173        // Listen for the challenge request message, which can be either from a regular peer, or a validator.
174        let peer_request = expect_handshake_msg!(HandshakeMessageKind::ChallengeRequest, framed, peer_addr);
175        let (peer_port, peer_nonce, peer_aleo_addr, peer_node_type, peer_version, peer_snarkos_sha, connection_mode) =
176            match peer_request {
177                MessageOrEvent::Message(Message::ChallengeRequest(ref msg)) => (
178                    msg.listener_port,
179                    msg.nonce,
180                    msg.address,
181                    msg.node_type,
182                    msg.version,
183                    msg.snarkos_sha,
184                    ConnectionMode::Router,
185                ),
186                MessageOrEvent::Event(Event::ChallengeRequest(ref msg)) => (
187                    msg.listener_port,
188                    msg.nonce,
189                    msg.address,
190                    NodeType::Validator,
191                    msg.version,
192                    msg.snarkos_sha,
193                    ConnectionMode::Gateway,
194                ),
195                _ => unreachable!(),
196            };
197        debug!("Handshake mode: {connection_mode:?}");
198
199        // Obtain the peer's listening address.
200        *listener_addr = Some(SocketAddr::new(peer_addr.ip(), peer_port));
201
202        // Introduce the peer into the peer pool.
203        self.add_connecting_peer(listener_addr.unwrap())?;
204
205        // Verify the challenge request.
206        if !self.verify_challenge_request(peer_addr, &mut framed, &peer_request).await? {
207            return Err(ConnectError::application(DisconnectReason::InvalidChallengeRequest));
208        };
209
210        /* Step 2: Send the challenge response followed by own challenge request. */
211
212        // Initialize an RNG.
213        let rng = &mut OsRng;
214
215        // Sign the counterparty nonce.
216        let response_nonce: u64 = rng.r#gen();
217        let data = [peer_nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
218        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
219            return Err(ConnectError::other(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
220        };
221
222        // Send the challenge response.
223        if connection_mode == ConnectionMode::Router {
224            let our_response = messages::ChallengeResponse {
225                genesis_header: self.genesis_header,
226                restrictions_id: self.restrictions_id,
227                signature: Data::Object(our_signature),
228                nonce: response_nonce,
229            };
230            let msg = Message::ChallengeResponse::<N>(our_response);
231            send_msg!(msg, framed, peer_addr)?;
232        } else {
233            let our_response = events::ChallengeResponse {
234                restrictions_id: self.restrictions_id,
235                signature: Data::Object(our_signature),
236                nonce: response_nonce,
237            };
238            let msg = Event::ChallengeResponse::<N>(our_response);
239            send_msg!(msg, framed, peer_addr)?;
240        }
241
242        // Sample a random nonce.
243        let our_nonce: u64 = rng.r#gen();
244        // Do not send a snarkOS SHA as the bootstrap client is not aware of height.
245        let snarkos_sha = None;
246        // Send the challenge request.
247        if connection_mode == ConnectionMode::Router {
248            let our_request = messages::ChallengeRequest::new(
249                self.local_ip().port(),
250                NodeType::BootstrapClient,
251                self.account.address(),
252                our_nonce,
253                snarkos_sha,
254            );
255            let msg = Message::ChallengeRequest(our_request);
256            send_msg!(msg, framed, peer_addr)?;
257        } else {
258            let our_request =
259                events::ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce, snarkos_sha);
260            let msg = Event::ChallengeRequest(our_request);
261            send_msg!(msg, framed, peer_addr)?;
262        }
263
264        /* Step 3: Receive the challenge response. */
265
266        // Listen for the challenge response message.
267        let peer_response = expect_handshake_msg!(HandshakeMessageKind::ChallengeResponse, framed, peer_addr);
268        // Verify the challenge response.
269        if !self.verify_challenge_response(peer_addr, peer_aleo_addr, our_nonce, &peer_response).await {
270            if connection_mode == ConnectionMode::Router {
271                let msg = Message::Disconnect::<N>(messages::DisconnectReason::InvalidChallengeResponse.into());
272                send_msg!(msg, framed, peer_addr)?;
273            } else {
274                let msg = Event::Disconnect::<N>(events::DisconnectReason::InvalidChallengeResponse.into());
275                send_msg!(msg, framed, peer_addr)?;
276            }
277            return Err(ConnectError::application(DisconnectReason::InvalidChallengeResponse));
278        }
279
280        Ok((peer_port, peer_aleo_addr, peer_node_type, peer_version, peer_snarkos_sha, connection_mode))
281    }
282
283    async fn verify_challenge_request(
284        &self,
285        peer_addr: SocketAddr,
286        framed: &mut Framed<&mut TcpStream, BootstrapClientCodec<N>>,
287        request: &MessageOrEvent<N>,
288    ) -> io::Result<bool> {
289        match request {
290            MessageOrEvent::Message(Message::ChallengeRequest(msg)) => {
291                log_repo_sha_comparison(peer_addr, &msg.snarkos_sha, Self::OWNER);
292
293                if msg.version < Message::<N>::latest_message_version() {
294                    let msg = Message::Disconnect::<N>(messages::DisconnectReason::OutdatedClientVersion.into());
295                    send_msg!(msg, framed, peer_addr)?;
296                    return Ok(false);
297                }
298
299                // Reject validators that aren't members of the committee.
300                if msg.node_type == NodeType::Validator {
301                    if let Some(current_committee) =
302                        self.get_or_update_committee().await.map_err(|_| io_error("Couldn't load the committee"))?
303                    {
304                        if !current_committee.contains(&msg.address) {
305                            let msg = Message::Disconnect::<N>(messages::DisconnectReason::ProtocolViolation.into());
306                            send_msg!(msg, framed, peer_addr)?;
307                            return Ok(false);
308                        }
309                    }
310                }
311            }
312            MessageOrEvent::Event(Event::ChallengeRequest(msg)) => {
313                log_repo_sha_comparison(peer_addr, &msg.snarkos_sha, Self::OWNER);
314
315                if msg.version < Event::<N>::VERSION {
316                    let msg = Event::Disconnect::<N>(events::DisconnectReason::OutdatedClientVersion.into());
317                    send_msg!(msg, framed, peer_addr)?;
318                    return Ok(false);
319                }
320
321                // Reject validators that aren't members of the committee.
322                if let Some(current_committee) =
323                    self.get_or_update_committee().await.map_err(|_| io_error("Couldn't load the committee"))?
324                {
325                    if !current_committee.contains(&msg.address) {
326                        let msg = Message::Disconnect::<N>(messages::DisconnectReason::ProtocolViolation.into());
327                        send_msg!(msg, framed, peer_addr)?;
328                        return Ok(false);
329                    }
330                }
331            }
332            _ => unreachable!(),
333        }
334
335        Ok(true)
336    }
337
338    async fn verify_challenge_response(
339        &self,
340        peer_addr: SocketAddr,
341        peer_aleo_addr: Address<N>,
342        our_nonce: u64,
343        response: &MessageOrEvent<N>,
344    ) -> bool {
345        let (peer_restrictions_id, peer_signature, peer_nonce) = match response {
346            MessageOrEvent::Message(Message::ChallengeResponse(msg)) => {
347                (msg.restrictions_id, msg.signature.clone(), msg.nonce)
348            }
349            MessageOrEvent::Event(Event::ChallengeResponse(msg)) => {
350                (msg.restrictions_id, msg.signature.clone(), msg.nonce)
351            }
352            _ => unreachable!(),
353        };
354
355        // Verify the restrictions ID.
356        if peer_restrictions_id != self.restrictions_id {
357            warn!("{} Handshake with '{peer_addr}' failed (incorrect restrictions ID)", Self::OWNER);
358            return false;
359        }
360        // Perform the deferred non-blocking deserialization of the signature.
361        let Ok(signature) = peer_signature.deserialize().await else {
362            warn!("{} Handshake with '{peer_addr}' failed (cannot deserialize the signature)", Self::OWNER);
363            return false;
364        };
365        // Verify the signature.
366        if !signature.verify_bytes(&peer_aleo_addr, &[our_nonce.to_le_bytes(), peer_nonce.to_le_bytes()].concat()) {
367            warn!("{} Handshake with '{peer_addr}' failed (invalid signature)", Self::OWNER);
368            return false;
369        }
370
371        true
372    }
373}