snarkos_node/bootstrap_client/handshake.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{
17    BootstrapClient,
18    bft::events::{self, Event},
19    bootstrap_client::{codec::BootstrapClientCodec, network::MessageOrEvent},
20    router::{
21        NodeType,
22        PeerPoolHandling,
23        messages::{self, Message},
24    },
25    tcp::{Connection, ConnectionSide, protocols::*},
26};
27use snarkvm::{
28    ledger::narwhal::Data,
29    prelude::{Address, Network, error},
30};
31
32use futures_util::sink::SinkExt;
33use rand::{Rng, rngs::OsRng};
34use std::{io, net::SocketAddr};
35use tokio::net::TcpStream;
36use tokio_stream::StreamExt;
37use tokio_util::codec::Framed;
38
/// The handshake messages the bootstrap client waits for from a connecting peer.
///
/// Each kind is accepted in both its router (`Message`) and gateway (`Event`) flavour.
#[derive(Debug)]
enum HandshakeMessageKind {
    /// The peer's opening challenge (`ChallengeRequest`).
    ChallengeRequest,
    /// The peer's answer to our challenge (`ChallengeResponse`).
    ChallengeResponse,
}
44
/// Traces and sends one handshake message through `$framed`.
///
/// Expands to the result of the (awaited) send, so callers can apply `?` to it.
macro_rules! send_msg {
    ($msg:expr, $framed:expr, $peer_addr:expr) => {{
        let outgoing = $msg;
        trace!("Sending '{}' to '{}'", outgoing.name(), $peer_addr);
        $framed.send(outgoing).await
    }};
}
51
/// Reads the next handshake frame from `$framed` and evaluates to it, provided it is of the
/// expected [`HandshakeMessageKind`].
///
/// Both the router (`Message`) and the gateway (`Event`) flavour of the message are accepted.
/// If the stream ends first, or a different message arrives, this returns an `Err` from the
/// *enclosing* function. Read errors are propagated with `?`.
macro_rules! expect_handshake_msg {
    ($msg_ty:expr, $framed:expr, $peer_addr:expr) => {{
        // A `None` from the stream means the peer hung up before sending anything useful.
        let message = match $framed.try_next().await? {
            Some(message) => message,
            None => {
                return Err(error(format!(
                    "the peer disconnected before sending {:?}, likely due to peer saturation or shutdown",
                    stringify!($msg_ty),
                )));
            }
        };

        // Decide whether the received frame is the kind of message we were waiting for.
        let is_expected = match $msg_ty {
            HandshakeMessageKind::ChallengeRequest => matches!(
                message,
                MessageOrEvent::Message(Message::ChallengeRequest(_))
                    | MessageOrEvent::Event(Event::ChallengeRequest(_))
            ),
            HandshakeMessageKind::ChallengeResponse => matches!(
                message,
                MessageOrEvent::Message(Message::ChallengeResponse(_))
                    | MessageOrEvent::Event(Event::ChallengeResponse(_))
            ),
        };

        if !is_expected {
            let msg_name = match message {
                MessageOrEvent::Message(message) => message.name(),
                MessageOrEvent::Event(event) => event.name(),
            };
            return Err(error(format!(
                "'{}' did not follow the handshake protocol: expected {}, got {msg_name}",
                $peer_addr,
                stringify!($msg_ty),
            )));
        }

        trace!("Received a '{}' from '{}'", stringify!($msg_ty), $peer_addr);
        message
    }};
}
99
100#[async_trait]
101impl<N: Network> Handshake for BootstrapClient<N> {
102    async fn perform_handshake(&self, mut connection: Connection) -> io::Result<Connection> {
103        let peer_addr = connection.addr();
104        let peer_side = connection.side();
105        let stream = self.borrow_stream(&mut connection);
106
107        // We don't know the listening address yet, as we don't initiate connections.
108        let mut listener_addr = if peer_side == ConnectionSide::Initiator {
109            debug!("Received a connection request from '{peer_addr}'");
110            None
111        } else {
112            unreachable!("The boostrapper clients don't initiate connections");
113        };
114
115        // Perform the handshake; we pass on a mutable reference to listener_addr in case the process is broken at any point in time.
116        let handshake_result = if peer_side == ConnectionSide::Responder {
117            unreachable!("The boostrapper clients don't initiate connections");
118        } else {
119            self.handshake_inner_responder(peer_addr, &mut listener_addr, stream).await
120        };
121
122        if let Some(addr) = listener_addr {
123            match handshake_result {
124                Ok(Some((peer_port, peer_aleo_addr, peer_node_type, peer_version, validator_mode))) => {
125                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
126                        self.resolver.write().insert_peer(
127                            peer.listener_addr(),
128                            peer_addr,
129                            // Only resolve aleo addresses for Gateway connections.
130                            if validator_mode { Some(peer_aleo_addr) } else { None },
131                        );
132                        peer.upgrade_to_connected(peer_addr, peer_port, peer_aleo_addr, peer_node_type, peer_version);
133                    }
134                    debug!("Completed the handshake with '{peer_addr}'");
135                }
136                Ok(None) => {
137                    return Err(error("Duplicate handshake attempt with '{addr}'"));
138                }
139                Err(error) => {
140                    debug!("Handshake with '{peer_addr}' failed: {error}");
141                    if let Some(peer) = self.peer_pool.write().get_mut(&addr) {
142                        peer.downgrade_to_candidate(addr);
143                    }
144                    return Err(error);
145                }
146            }
147        }
148
149        Ok(connection)
150    }
151}
152
153impl<N: Network> BootstrapClient<N> {
154    /// The connection responder side of the handshake.
155    async fn handshake_inner_responder<'a>(
156        &'a self,
157        peer_addr: SocketAddr,
158        listener_addr: &mut Option<SocketAddr>,
159        stream: &'a mut TcpStream,
160    ) -> io::Result<Option<(u16, Address<N>, NodeType, u32, bool)>> {
161        // Construct the stream.
162        let mut framed = Framed::new(stream, BootstrapClientCodec::<N>::handshake());
163
164        /* Step 1: Receive the challenge request. */
165
166        // Listen for the challenge request message, which can be either from a regular peer, or a validator.
167        let peer_request = expect_handshake_msg!(HandshakeMessageKind::ChallengeRequest, framed, peer_addr);
168        let (peer_port, peer_nonce, peer_aleo_addr, peer_node_type, peer_version, validator_mode) = match peer_request {
169            MessageOrEvent::Message(Message::ChallengeRequest(ref msg)) => {
170                (msg.listener_port, msg.nonce, msg.address, msg.node_type, msg.version, false)
171            }
172            MessageOrEvent::Event(Event::ChallengeRequest(ref msg)) => {
173                (msg.listener_port, msg.nonce, msg.address, NodeType::Validator, msg.version, true)
174            }
175            _ => unreachable!(),
176        };
177        debug!("Handshake mode: {}validator", if validator_mode { "" } else { "non-" });
178
179        // Obtain the peer's listening address.
180        *listener_addr = Some(SocketAddr::new(peer_addr.ip(), peer_port));
181        let listener_addr = listener_addr.unwrap();
182
183        // Introduce the peer into the peer pool.
184        if !self.add_connecting_peer(listener_addr) {
185            // Return early if already being connected to.
186            return Ok(None);
187        }
188
189        // Verify the challenge request.
190        if !self.verify_challenge_request(peer_addr, &mut framed, &peer_request).await? {
191            return Err(error(format!("Handshake with '{peer_addr}' failed: invalid challenge request")));
192        };
193
194        /* Step 2: Send the challenge response followed by own challenge request. */
195
196        // Initialize an RNG.
197        let rng = &mut OsRng;
198
199        // Sign the counterparty nonce.
200        let response_nonce: u64 = rng.r#gen();
201        let data = [peer_nonce.to_le_bytes(), response_nonce.to_le_bytes()].concat();
202        let Ok(our_signature) = self.account.sign_bytes(&data, rng) else {
203            return Err(error(format!("Failed to sign the challenge request nonce from '{peer_addr}'")));
204        };
205
206        // Send the challenge response.
207        if !validator_mode {
208            let our_response = messages::ChallengeResponse {
209                genesis_header: self.genesis_header,
210                restrictions_id: self.restrictions_id,
211                signature: Data::Object(our_signature),
212                nonce: response_nonce,
213            };
214            let msg = Message::ChallengeResponse::<N>(our_response);
215            send_msg!(msg, framed, peer_addr)?;
216        } else {
217            let our_response = events::ChallengeResponse {
218                restrictions_id: self.restrictions_id,
219                signature: Data::Object(our_signature),
220                nonce: response_nonce,
221            };
222            let msg = Event::ChallengeResponse::<N>(our_response);
223            send_msg!(msg, framed, peer_addr)?;
224        }
225
226        // Sample a random nonce.
227        let our_nonce: u64 = rng.r#gen();
228        // Send the challenge request.
229        if !validator_mode {
230            let our_request = messages::ChallengeRequest::new(
231                self.local_ip().port(),
232                NodeType::BootstrapClient,
233                self.account.address(),
234                our_nonce,
235            );
236            let msg = Message::ChallengeRequest(our_request);
237            send_msg!(msg, framed, peer_addr)?;
238        } else {
239            let our_request = events::ChallengeRequest::new(self.local_ip().port(), self.account.address(), our_nonce);
240            let msg = Event::ChallengeRequest(our_request);
241            send_msg!(msg, framed, peer_addr)?;
242        }
243
244        /* Step 3: Receive the challenge response. */
245
246        // Listen for the challenge response message.
247        let peer_response = expect_handshake_msg!(HandshakeMessageKind::ChallengeResponse, framed, peer_addr);
248        // Verify the challenge response.
249        if !self.verify_challenge_response(peer_addr, peer_aleo_addr, our_nonce, &peer_response).await {
250            if !validator_mode {
251                let msg = Message::Disconnect::<N>(messages::DisconnectReason::InvalidChallengeResponse.into());
252                send_msg!(msg, framed, peer_addr)?;
253            } else {
254                let msg = Event::Disconnect::<N>(events::DisconnectReason::InvalidChallengeResponse.into());
255                send_msg!(msg, framed, peer_addr)?;
256            }
257            return Err(error(format!("Handshake with '{peer_addr}' failed: invalid challenge response")));
258        }
259
260        Ok(Some((peer_port, peer_aleo_addr, peer_node_type, peer_version, validator_mode)))
261    }
262
263    async fn verify_challenge_request(
264        &self,
265        peer_addr: SocketAddr,
266        framed: &mut Framed<&mut TcpStream, BootstrapClientCodec<N>>,
267        request: &MessageOrEvent<N>,
268    ) -> io::Result<bool> {
269        match request {
270            MessageOrEvent::Message(Message::ChallengeRequest(msg)) => {
271                if msg.version < Message::<N>::latest_message_version() {
272                    let msg = Message::Disconnect::<N>(messages::DisconnectReason::OutdatedClientVersion.into());
273                    send_msg!(msg, framed, peer_addr)?;
274                    return Ok(false);
275                }
276            }
277            MessageOrEvent::Event(Event::ChallengeRequest(msg)) => {
278                if msg.version < Event::<N>::VERSION {
279                    let msg = Event::Disconnect::<N>(events::DisconnectReason::OutdatedClientVersion.into());
280                    send_msg!(msg, framed, peer_addr)?;
281                    return Ok(false);
282                }
283            }
284            _ => unreachable!(),
285        }
286
287        Ok(true)
288    }
289
290    async fn verify_challenge_response(
291        &self,
292        peer_addr: SocketAddr,
293        peer_aleo_addr: Address<N>,
294        our_nonce: u64,
295        response: &MessageOrEvent<N>,
296    ) -> bool {
297        let (peer_restrictions_id, peer_signature, peer_nonce) = match response {
298            MessageOrEvent::Message(Message::ChallengeResponse(msg)) => {
299                (msg.restrictions_id, msg.signature.clone(), msg.nonce)
300            }
301            MessageOrEvent::Event(Event::ChallengeResponse(msg)) => {
302                (msg.restrictions_id, msg.signature.clone(), msg.nonce)
303            }
304            _ => unreachable!(),
305        };
306
307        // Verify the restrictions ID.
308        if peer_restrictions_id != self.restrictions_id {
309            warn!("{} Handshake with '{peer_addr}' failed (incorrect restrictions ID)", Self::OWNER);
310            return false;
311        }
312        // Perform the deferred non-blocking deserialization of the signature.
313        let Ok(signature) = peer_signature.deserialize().await else {
314            warn!("{} Handshake with '{peer_addr}' failed (cannot deserialize the signature)", Self::OWNER);
315            return false;
316        };
317        // Verify the signature.
318        if !signature.verify_bytes(&peer_aleo_addr, &[our_nonce.to_le_bytes(), peer_nonce.to_le_bytes()].concat()) {
319            warn!("{} Handshake with '{peer_addr}' failed (invalid signature)", Self::OWNER);
320            return false;
321        }
322
323        true
324    }
325}