snarkos_node_tcp/protocols/handshake.rs
1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{io, time::Duration};
17
18use tokio::{
19 io::{AsyncRead, AsyncWrite, split},
20 net::TcpStream,
21 sync::{mpsc, oneshot},
22 time::timeout,
23};
24use tracing::*;
25
26use crate::{
27 Connection,
28 P2P,
29 protocols::{ProtocolHandler, ReturnableConnection},
30};
31
32/// Can be used to specify and enable network handshakes. Upon establishing a connection, both sides will
33/// need to adhere to the specified handshake rules in order to finalize the connection and be able to send
34/// or receive any messages.
35#[async_trait::async_trait]
36pub trait Handshake: P2P
37where
38 Self: Clone + Send + Sync + 'static,
39{
40 /// The maximum time allowed for a connection to perform a handshake before it is rejected.
41 ///
42 /// The default value is 3000ms.
43 const TIMEOUT_MS: u64 = 3_000;
44
45 /// Prepares the node to perform specified network handshakes.
46 async fn enable_handshake(&self) {
47 let (from_node_sender, mut from_node_receiver) = mpsc::unbounded_channel::<ReturnableConnection>();
48
49 // use a channel to know when the handshake task is ready
50 let (tx, rx) = oneshot::channel();
51
52 // spawn a background task dedicated to handling the handshakes
53 let self_clone = self.clone();
54 let handshake_task = tokio::spawn(async move {
55 trace!(parent: self_clone.tcp().span(), "spawned the Handshake handler task");
56 tx.send(()).unwrap(); // safe; the channel was just opened
57
58 while let Some((conn, result_sender)) = from_node_receiver.recv().await {
59 let addr = conn.addr();
60
61 let node = self_clone.clone();
62 tokio::spawn(async move {
63 debug!(parent: node.tcp().span(), "shaking hands with {} as the {:?}", addr, !conn.side());
64 let result = timeout(Duration::from_millis(Self::TIMEOUT_MS), node.perform_handshake(conn)).await;
65
66 let ret = match result {
67 Ok(Ok(conn)) => {
68 debug!(parent: node.tcp().span(), "successfully handshaken with {}", addr);
69 Ok(conn)
70 }
71 Ok(Err(e)) => {
72 error!(parent: node.tcp().span(), "handshake with {} failed: {}", addr, e);
73 Err(e)
74 }
75 Err(_) => {
76 error!(parent: node.tcp().span(), "handshake with {} timed out", addr);
77 Err(io::ErrorKind::TimedOut.into())
78 }
79 };
80
81 // return the Connection to the Tcp, resuming Tcp::adapt_stream
82 if result_sender.send(ret).is_err() {
83 unreachable!("couldn't return a Connection to the Tcp");
84 }
85 });
86 }
87 });
88 let _ = rx.await;
89 self.tcp().tasks.lock().push(handshake_task);
90
91 // register the Handshake handler with the Tcp
92 let hdl = Box::new(ProtocolHandler(from_node_sender));
93 assert!(self.tcp().protocols.handshake.set(hdl).is_ok(), "the Handshake protocol was enabled more than once!");
94 }
95
96 /// Performs the handshake; temporarily assumes control of the [`Connection`] and returns it if the handshake is
97 /// successful.
98 async fn perform_handshake(&self, conn: Connection) -> io::Result<Connection>;
99
100 /// Borrows the full connection stream to be used in the implementation of [`Handshake::perform_handshake`].
101 fn borrow_stream<'a>(&self, conn: &'a mut Connection) -> &'a mut TcpStream {
102 conn.stream.as_mut().unwrap()
103 }
104
105 /// Assumes full control of a connection's stream in the implementation of [`Handshake::perform_handshake`], by
106 /// the end of which it *must* be followed by [`Handshake::return_stream`].
107 fn take_stream(&self, conn: &mut Connection) -> TcpStream {
108 conn.stream.take().unwrap()
109 }
110
111 /// This method only needs to be called if [`Handshake::take_stream`] had been called before; it is used to
112 /// return a (potentially modified) stream back to the applicable connection.
113 fn return_stream<T: AsyncRead + AsyncWrite + Send + Sync + 'static>(&self, conn: &mut Connection, stream: T) {
114 let (reader, writer) = split(stream);
115 conn.reader = Some(Box::new(reader));
116 conn.writer = Some(Box::new(writer));
117 }
118}