snarkos_node_tcp/protocols/handshake.rs
1// Copyright (c) 2019-2026 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{io, time::Duration};
17
18use tokio::{
19 io::{AsyncRead, AsyncWrite, split},
20 net::TcpStream,
21 sync::{mpsc, oneshot},
22 time::timeout,
23};
24use tracing::*;
25
26use crate::{
27 ConnectError,
28 Connection,
29 P2P,
30 protocols::{ProtocolHandler, ReturnableConnection},
31};
32
33/// Can be used to specify and enable network handshakes. Upon establishing a connection, both sides will
34/// need to adhere to the specified handshake rules in order to finalize the connection and be able to send
35/// or receive any messages.
36#[async_trait::async_trait]
37pub trait Handshake: P2P
38where
39 Self: Clone + Send + Sync + 'static,
40{
41 /// The maximum time allowed for a connection to perform a handshake before it is rejected.
42 ///
43 /// The default value is 5s.
44 const TIMEOUT_MS: u64 = 5_000;
45
46 /// Prepares the node to perform specified network handshakes.
47 async fn enable_handshake(&self) {
48 let (from_node_sender, mut from_node_receiver) = mpsc::unbounded_channel::<ReturnableConnection>();
49
50 // use a channel to know when the handshake task is ready
51 let (tx, rx) = oneshot::channel();
52
53 // spawn a background task dedicated to handling the handshakes
54 let self_clone = self.clone();
55 let handshake_task = tokio::spawn(async move {
56 trace!(parent: self_clone.tcp().span(), "spawned the Handshake handler task");
57 tx.send(()).unwrap(); // safe; the channel was just opened
58
59 while let Some((conn, result_sender)) = from_node_receiver.recv().await {
60 let addr = conn.addr();
61
62 let node = self_clone.clone();
63 tokio::spawn(async move {
64 debug!(parent: node.tcp().span(), "shaking hands with {} as the {:?}", addr, !conn.side());
65 let result = timeout(Duration::from_millis(Self::TIMEOUT_MS), node.perform_handshake(conn)).await;
66
67 let ret: io::Result<_> = match result {
68 Ok(Ok(conn)) => {
69 debug!(parent: node.tcp().span(), "successfully handshaken with {addr}");
70 Ok(conn)
71 }
72 Ok(Err(err)) => {
73 debug!(parent: node.tcp().span(), "handshake with {addr} failed: {err}");
74 Err(err.into())
75 }
76 Err(_) => {
77 debug!(parent: node.tcp().span(), "handshake with {addr} timed out");
78 Err(io::ErrorKind::TimedOut.into())
79 }
80 };
81
82 // return the Connection to the Tcp, resuming Tcp::adapt_stream
83 if result_sender.send(ret).is_err() {
84 unreachable!("couldn't return a Connection to the Tcp");
85 }
86 });
87 }
88 });
89 let _ = rx.await;
90 self.tcp().tasks.lock().push(handshake_task);
91
92 // register the Handshake handler with the Tcp
93 let hdl = Box::new(ProtocolHandler(from_node_sender));
94 assert!(self.tcp().protocols.handshake.set(hdl).is_ok(), "the Handshake protocol was enabled more than once!");
95 }
96
97 /// Performs the handshake; temporarily assumes control of the [`Connection`] and returns it if the handshake is
98 /// successful.
99 async fn perform_handshake(&self, conn: Connection) -> Result<Connection, ConnectError>;
100
101 /// Borrows the full connection stream to be used in the implementation of [`Handshake::perform_handshake`].
102 fn borrow_stream<'a>(&self, conn: &'a mut Connection) -> &'a mut TcpStream {
103 conn.stream.as_mut().unwrap()
104 }
105
106 /// Assumes full control of a connection's stream in the implementation of [`Handshake::perform_handshake`], by
107 /// the end of which it *must* be followed by [`Handshake::return_stream`].
108 fn take_stream(&self, conn: &mut Connection) -> TcpStream {
109 conn.stream.take().unwrap()
110 }
111
112 /// This method only needs to be called if [`Handshake::take_stream`] had been called before; it is used to
113 /// return a (potentially modified) stream back to the applicable connection.
114 fn return_stream<T: AsyncRead + AsyncWrite + Send + Sync + 'static>(&self, conn: &mut Connection, stream: T) {
115 let (reader, writer) = split(stream);
116 conn.reader = Some(Box::new(reader));
117 conn.writer = Some(Box::new(writer));
118 }
119}