Skip to main content

snarkos_node_tcp/protocols/
handshake.rs

// Copyright (c) 2019-2026 Provable Inc.
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16use std::{io, time::Duration};
17
18use tokio::{
19    io::{AsyncRead, AsyncWrite, split},
20    net::TcpStream,
21    sync::{mpsc, oneshot},
22    time::timeout,
23};
24use tracing::*;
25
26use crate::{
27    ConnectError,
28    Connection,
29    P2P,
30    protocols::{ProtocolHandler, ReturnableConnection},
31};
32
33/// Can be used to specify and enable network handshakes. Upon establishing a connection, both sides will
34/// need to adhere to the specified handshake rules in order to finalize the connection and be able to send
35/// or receive any messages.
36#[async_trait::async_trait]
37pub trait Handshake: P2P
38where
39    Self: Clone + Send + Sync + 'static,
40{
41    /// The maximum time allowed for a connection to perform a handshake before it is rejected.
42    ///
43    /// The default value is 5s.
44    const TIMEOUT_MS: u64 = 5_000;
45
46    /// Prepares the node to perform specified network handshakes.
47    async fn enable_handshake(&self) {
48        let (from_node_sender, mut from_node_receiver) = mpsc::unbounded_channel::<ReturnableConnection>();
49
50        // use a channel to know when the handshake task is ready
51        let (tx, rx) = oneshot::channel();
52
53        // spawn a background task dedicated to handling the handshakes
54        let self_clone = self.clone();
55        let handshake_task = tokio::spawn(async move {
56            trace!(parent: self_clone.tcp().span(), "spawned the Handshake handler task");
57            tx.send(()).unwrap(); // safe; the channel was just opened
58
59            while let Some((conn, result_sender)) = from_node_receiver.recv().await {
60                let addr = conn.addr();
61
62                let node = self_clone.clone();
63                tokio::spawn(async move {
64                    debug!(parent: node.tcp().span(), "shaking hands with {} as the {:?}", addr, !conn.side());
65                    let result = timeout(Duration::from_millis(Self::TIMEOUT_MS), node.perform_handshake(conn)).await;
66
67                    let ret: io::Result<_> = match result {
68                        Ok(Ok(conn)) => {
69                            debug!(parent: node.tcp().span(), "successfully handshaken with {addr}");
70                            Ok(conn)
71                        }
72                        Ok(Err(err)) => {
73                            debug!(parent: node.tcp().span(), "handshake with {addr} failed: {err}");
74                            Err(err.into())
75                        }
76                        Err(_) => {
77                            debug!(parent: node.tcp().span(), "handshake with {addr} timed out");
78                            Err(io::ErrorKind::TimedOut.into())
79                        }
80                    };
81
82                    // return the Connection to the Tcp, resuming Tcp::adapt_stream
83                    if result_sender.send(ret).is_err() {
84                        unreachable!("couldn't return a Connection to the Tcp");
85                    }
86                });
87            }
88        });
89        let _ = rx.await;
90        self.tcp().tasks.lock().push(handshake_task);
91
92        // register the Handshake handler with the Tcp
93        let hdl = Box::new(ProtocolHandler(from_node_sender));
94        assert!(self.tcp().protocols.handshake.set(hdl).is_ok(), "the Handshake protocol was enabled more than once!");
95    }
96
97    /// Performs the handshake; temporarily assumes control of the [`Connection`] and returns it if the handshake is
98    /// successful.
99    async fn perform_handshake(&self, conn: Connection) -> Result<Connection, ConnectError>;
100
101    /// Borrows the full connection stream to be used in the implementation of [`Handshake::perform_handshake`].
102    fn borrow_stream<'a>(&self, conn: &'a mut Connection) -> &'a mut TcpStream {
103        conn.stream.as_mut().unwrap()
104    }
105
106    /// Assumes full control of a connection's stream in the implementation of [`Handshake::perform_handshake`], by
107    /// the end of which it *must* be followed by [`Handshake::return_stream`].
108    fn take_stream(&self, conn: &mut Connection) -> TcpStream {
109        conn.stream.take().unwrap()
110    }
111
112    /// This method only needs to be called if [`Handshake::take_stream`] had been called before; it is used to
113    /// return a (potentially modified) stream back to the applicable connection.
114    fn return_stream<T: AsyncRead + AsyncWrite + Send + Sync + 'static>(&self, conn: &mut Connection, stream: T) {
115        let (reader, writer) = split(stream);
116        conn.reader = Some(Box::new(reader));
117        conn.writer = Some(Box::new(writer));
118    }
119}