// snarkos_node/bootstrap_client/handshake.rs

// Copyright (c) 2019-2026 Provable Inc.
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16use crate::{
17    BootstrapClient,
18    bft::events::{self, DisconnectReason, Event},
19    bootstrap_client::{codec::BootstrapClientCodec, network::MessageOrEvent},
20    network::{ConnectionMode, NodeType, PeerPoolHandling, log_repo_sha_comparison},
21    router::messages::{self, Message},
22    tcp::{ConnectError, Connection, ConnectionSide, protocols::*},
23};
24use snarkos_node_network::harden_socket;
25use snarkvm::{
26    ledger::narwhal::Data,
27    prelude::{Address, Network, io_error},
28};
29
30use futures_util::sink::SinkExt;
31
32use std::{io, net::SocketAddr};
33use tokio::net::TcpStream;
34use tokio_stream::StreamExt;
35use tokio_util::codec::Framed;
36
/// Identifies which handshake message the bootstrap client is waiting for next.
///
/// Each kind is accepted in both its router (`Message`) and gateway (`Event`) encoding
/// by `expect_handshake_msg!`.
#[derive(Debug)]
enum HandshakeMessageKind {
    /// The challenge request that opens the handshake on an accepted connection.
    ChallengeRequest,
    /// The peer's answer to the challenge request we sent it.
    ChallengeResponse,
}
42
/// Logs and sends a single handshake message (or event) over the framed stream.
///
/// Expands to `$framed.send(..).await`, so it must be used inside an `async` context; it
/// evaluates to the sink's `Result`, which the call sites propagate with `?`.
///
/// `$msg` is evaluated exactly once: it is bound to a local before its name is logged, so a
/// non-trivial argument expression is neither built twice nor has its side effects repeated.
macro_rules! send_msg {
    ($msg:expr, $framed:expr, $peer_addr:expr) => {{
        let message = $msg;
        trace!("Sending '{}' to '{}'", message.name(), $peer_addr);
        $framed.send(message).await
    }};
}
49
/// Reads the next handshake frame from `$framed` and checks that it is of kind `$msg_ty`.
///
/// `$msg_ty` is a `HandshakeMessageKind`. A matching challenge request/response is accepted in
/// either its `Message` (router) or `Event` (gateway) form, and the macro evaluates to the
/// received `MessageOrEvent`. In every other case, the enclosing function returns early with a
/// `ConnectError`:
/// - if the stream ended before a frame arrived, or
/// - if a different message was received.
///
/// Decoding errors from `try_next` are propagated with `?`.
macro_rules! expect_handshake_msg {
    ($msg_ty:expr, $framed:expr, $peer_addr:expr) => {{
        // Pull the next decoded frame; `None` means the peer closed the stream.
        let received = match $framed.try_next().await? {
            Some(received) => received,
            None => {
                return Err(ConnectError::other(format!(
                    "the peer disconnected before sending {:?}, likely due to peer saturation or shutdown",
                    stringify!($msg_ty),
                )));
            }
        };

        // Is the frame the requested kind, in either its router or its gateway encoding?
        let is_expected = match $msg_ty {
            HandshakeMessageKind::ChallengeRequest => matches!(
                received,
                MessageOrEvent::Message(Message::ChallengeRequest(_))
                    | MessageOrEvent::Event(Event::ChallengeRequest(_))
            ),
            HandshakeMessageKind::ChallengeResponse => matches!(
                received,
                MessageOrEvent::Message(Message::ChallengeResponse(_))
                    | MessageOrEvent::Event(Event::ChallengeResponse(_))
            ),
        };

        if !is_expected {
            let msg_name = match received {
                MessageOrEvent::Message(message) => message.name(),
                MessageOrEvent::Event(event) => event.name(),
            };
            return Err(ConnectError::other(format!(
                "'{}' did not follow the handshake protocol: expected {}, got {msg_name}",
                $peer_addr,
                stringify!($msg_ty),
            )));
        }

        trace!("Received a '{}' from '{}'", stringify!($msg_ty), $peer_addr);
        received
    }};
}
97
98#[async_trait]
99impl<N: Network> Handshake for BootstrapClient<N> {
100    async fn perform_handshake(&self, mut connection: Connection) -> Result<Connection, ConnectError> {
101        let peer_addr = connection.addr();
102        let peer_side = connection.side();
103        let stream = self.borrow_stream(&mut connection);
104        // Make the socket more robust.
105        harden_socket(stream)?;
106
107        // We don't know the listening address yet, as we don't initiate connections.
108        let mut listener_addr = if peer_side == ConnectionSide::Initiator {
109            debug!("Received a connection request from '{peer_addr}'");
110            None
111        } else {
112            unreachable!("The boostrapper clients don't initiate connections");
113        };
114
115        // Perform the handshake; we pass on a mutable reference to listener_addr in case the process is broken at any point in time.
116        let handshake_result = if peer_side == ConnectionSide::Responder {
117            unreachable!("The boostrapper clients don't initiate connections");
118        } else {
119            self.handshake_inner_responder(peer_addr, &mut listener_addr, stream).await
120        };
121
122        if let Some(addr) = listener_addr {
123            match handshake_result {
124                Ok((peer_port, peer_aleo_addr, peer_node_type, peer_version, peer_snarkos_sha, connection_mode)) => {
125                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
126                        // Due to only having a single Resolver, the BootstrapClient only adds an Aleo
127                        // address mapping for Gateway-mode connections, as it is only used there, and
128                        // it could otherwise clash with the Router-mode mapping for validators, which
129                        // may connect in both modes at the same time.
130                        let aleo_addr =
131                            if connection_mode == ConnectionMode::Gateway { Some(peer_aleo_addr) } else { None };
132                        self.resolver.write().insert_peer(peer.listener_addr(), peer_addr, aleo_addr);
133                        peer.upgrade_to_connected(
134                            peer_addr,
135                            peer_port,
136                            peer_aleo_addr,
137                            peer_node_type,
138                            peer_version,
139                            peer_snarkos_sha,
140                            connection_mode,
141                        );
142                    }
143                    debug!("Completed the handshake with '{peer_addr}'");
144                }
145                Err(_) => {
146                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
147                        // The peer may only be downgraded if it's a ConnectingPeer.
148                        if peer.is_connecting() {
149                            peer.downgrade_to_candidate(addr);
150                        }
151                    }
152                }
153            }
154        }
155
156        handshake_result.map(|_| connection)
157    }
158}
159
160impl<N: Network> BootstrapClient<N> {
161    /// The connection responder side of the handshake.
162    async fn handshake_inner_responder<'a>(
163        &'a self,
164        peer_addr: SocketAddr,
165        listener_addr: &mut Option<SocketAddr>,
166        stream: &'a mut TcpStream,
167    ) -> Result<(u16, Address<N>, NodeType, u32, Option<[u8; 40]>, ConnectionMode), ConnectError> {
168        // Construct the stream.
169        let mut framed = Framed::new(stream, BootstrapClientCodec::<N>::handshake());
170
171        /* Step 1: Receive the challenge request. */
172
173        // Listen for the challenge request message, which can be either from a regular peer, or a validator.
174        let peer_request = expect_handshake_msg!(HandshakeMessageKind::ChallengeRequest, framed, peer_addr);
175        let (peer_port, peer_nonce, peer_aleo_addr, peer_node_type, peer_version, peer_snarkos_sha, connection_mode) =
176            match peer_request {
177                MessageOrEvent::Message(Message::ChallengeRequest(ref msg)) => (
178                    msg.listener_port,
179                    msg.nonce,
180                    msg.address,
181                    msg.node_type,
182                    msg.version,
183                    msg.snarkos_sha,
184                    ConnectionMode::Router,
185                ),
186                MessageOrEvent::Event(Event::ChallengeRequest(ref msg)) => (
187                    msg.listener_port,
188                    msg.nonce,
189                    msg.address,
190                    NodeType::Validator,
191                    msg.version,
192                    msg.snarkos_sha,
193                    ConnectionMode::Gateway,
194                ),
195                _ => unreachable!(),
196            };
197        debug!("Handshake mode: {connection_mode:?}");
198
199        // Obtain the peer's listening address.
200        *listener_addr = Some(SocketAddr::new(peer_addr.ip(), peer_port));
201
202        // Introduce the peer into the peer pool.
203        self.add_connecting_peer(listener_addr.unwrap())?;
204
205        // Verify the challenge request.
206        if !self.verify_challenge_request(peer_addr, &mut framed, &peer_request).await? {
207            return Err(ConnectError::application(DisconnectReason::InvalidChallengeRequest));
208        };
209
210        /* Step 2: Send the challenge response followed by own challenge request. */
211
212        // Sign the counterparty nonce.
213        let response_nonce: u64 = rand::random();
214        let data = [peer_nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
215        let Ok(our_signature) = self.account.sign_bytes(&data, &mut rand::rng()) else {
216            return Err(ConnectError::other(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
217        };
218
219        // Send the challenge response.
220        if connection_mode == ConnectionMode::Router {
221            let our_response = messages::ChallengeResponse {
222                genesis_header: self.genesis_header,
223                restrictions_id: self.restrictions_id,
224                signature: Data::Object(our_signature),
225                nonce: response_nonce,
226            };
227            let msg = Message::ChallengeResponse::<N>(our_response);
228            send_msg!(msg, framed, peer_addr)?;
229        } else {
230            let our_response = events::ChallengeResponse {
231                restrictions_id: self.restrictions_id,
232                signature: Data::Object(our_signature),
233                nonce: response_nonce,
234            };
235            let msg = Event::ChallengeResponse::<N>(our_response);
236            send_msg!(msg, framed, peer_addr)?;
237        }
238
239        // Sample a random nonce.
240        let our_nonce: u64 = rand::random();
241        // Do not send a snarkOS SHA as the bootstrap client is not aware of height.
242        let snarkos_sha = None;
243        // Send the challenge request.
244        if connection_mode == ConnectionMode::Router {
245            let our_request = messages::ChallengeRequest::new(
246                self.local_ip().port(),
247                NodeType::BootstrapClient,
248                self.account.address(),
249                our_nonce,
250                snarkos_sha,
251            );
252            let msg = Message::ChallengeRequest(our_request);
253            send_msg!(msg, framed, peer_addr)?;
254        } else {
255            let our_request =
256                events::ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce, snarkos_sha);
257            let msg = Event::ChallengeRequest(our_request);
258            send_msg!(msg, framed, peer_addr)?;
259        }
260
261        /* Step 3: Receive the challenge response. */
262
263        // Listen for the challenge response message.
264        let peer_response = expect_handshake_msg!(HandshakeMessageKind::ChallengeResponse, framed, peer_addr);
265        // Verify the challenge response.
266        if !self.verify_challenge_response(peer_addr, peer_aleo_addr, our_nonce, &peer_response).await {
267            if connection_mode == ConnectionMode::Router {
268                let msg = Message::Disconnect::<N>(messages::DisconnectReason::InvalidChallengeResponse.into());
269                send_msg!(msg, framed, peer_addr)?;
270            } else {
271                let msg = Event::Disconnect::<N>(events::DisconnectReason::InvalidChallengeResponse.into());
272                send_msg!(msg, framed, peer_addr)?;
273            }
274            return Err(ConnectError::application(DisconnectReason::InvalidChallengeResponse));
275        }
276
277        Ok((peer_port, peer_aleo_addr, peer_node_type, peer_version, peer_snarkos_sha, connection_mode))
278    }
279
280    async fn verify_challenge_request(
281        &self,
282        peer_addr: SocketAddr,
283        framed: &mut Framed<&mut TcpStream, BootstrapClientCodec<N>>,
284        request: &MessageOrEvent<N>,
285    ) -> io::Result<bool> {
286        match request {
287            MessageOrEvent::Message(Message::ChallengeRequest(msg)) => {
288                log_repo_sha_comparison(peer_addr, &msg.snarkos_sha, Self::OWNER);
289
290                if msg.version < Message::<N>::latest_message_version() {
291                    let msg = Message::Disconnect::<N>(messages::DisconnectReason::OutdatedClientVersion.into());
292                    send_msg!(msg, framed, peer_addr)?;
293                    return Ok(false);
294                }
295
296                // Reject validators that aren't members of the committee.
297                if msg.node_type == NodeType::Validator {
298                    if let Some(current_committee) =
299                        self.get_or_update_committee().await.map_err(|_| io_error("Couldn't load the committee"))?
300                    {
301                        if !current_committee.contains(&msg.address) {
302                            let msg = Message::Disconnect::<N>(messages::DisconnectReason::ProtocolViolation.into());
303                            send_msg!(msg, framed, peer_addr)?;
304                            return Ok(false);
305                        }
306                    }
307                }
308            }
309            MessageOrEvent::Event(Event::ChallengeRequest(msg)) => {
310                log_repo_sha_comparison(peer_addr, &msg.snarkos_sha, Self::OWNER);
311
312                if msg.version < Event::<N>::VERSION {
313                    let msg = Event::Disconnect::<N>(events::DisconnectReason::OutdatedClientVersion.into());
314                    send_msg!(msg, framed, peer_addr)?;
315                    return Ok(false);
316                }
317
318                // Reject validators that aren't members of the committee.
319                if let Some(current_committee) =
320                    self.get_or_update_committee().await.map_err(|_| io_error("Couldn't load the committee"))?
321                {
322                    if !current_committee.contains(&msg.address) {
323                        let msg = Message::Disconnect::<N>(messages::DisconnectReason::ProtocolViolation.into());
324                        send_msg!(msg, framed, peer_addr)?;
325                        return Ok(false);
326                    }
327                }
328            }
329            _ => unreachable!(),
330        }
331
332        Ok(true)
333    }
334
335    async fn verify_challenge_response(
336        &self,
337        peer_addr: SocketAddr,
338        peer_aleo_addr: Address<N>,
339        our_nonce: u64,
340        response: &MessageOrEvent<N>,
341    ) -> bool {
342        let (peer_restrictions_id, peer_signature, peer_nonce) = match response {
343            MessageOrEvent::Message(Message::ChallengeResponse(msg)) => {
344                (msg.restrictions_id, msg.signature.clone(), msg.nonce)
345            }
346            MessageOrEvent::Event(Event::ChallengeResponse(msg)) => {
347                (msg.restrictions_id, msg.signature.clone(), msg.nonce)
348            }
349            _ => unreachable!(),
350        };
351
352        // Verify the restrictions ID.
353        if peer_restrictions_id != self.restrictions_id {
354            warn!("{} Handshake with '{peer_addr}' failed (incorrect restrictions ID)", Self::OWNER);
355            return false;
356        }
357        // Perform the deferred non-blocking deserialization of the signature.
358        let Ok(signature) = peer_signature.deserialize().await else {
359            warn!("{} Handshake with '{peer_addr}' failed (cannot deserialize the signature)", Self::OWNER);
360            return false;
361        };
362        // Verify the signature.
363        if !signature.verify_bytes(&peer_aleo_addr, &[our_nonce.to_le_bytes(), peer_nonce.to_le_bytes()].concat()) {
364            warn!("{} Handshake with '{peer_addr}' failed (invalid signature)", Self::OWNER);
365            return false;
366        }
367
368        true
369    }
370}