safecoin_net_utils/
ip_echo_server.rs1use {
2 crate::{HEADER_LENGTH, IP_ECHO_SERVER_RESPONSE_LENGTH},
3 log::*,
4 serde_derive::{Deserialize, Serialize},
5 solana_sdk::deserialize_utils::default_on_eof,
6 std::{
7 io,
8 net::{IpAddr, SocketAddr},
9 time::Duration,
10 },
11 tokio::{
12 io::{AsyncReadExt, AsyncWriteExt},
13 net::{TcpListener, TcpStream},
14 runtime::{self, Runtime},
15 time::timeout,
16 },
17};
18
19pub type IpEchoServer = Runtime;
20
/// Number of TCP port slots, and separately of UDP port slots, carried by a
/// single `IpEchoServerMessage`.
pub const MAX_PORT_COUNT_PER_MESSAGE: usize = 4;
22
/// Time limit for each socket operation in `process_connection` that is
/// wrapped in `timeout`.
const IO_TIMEOUT: Duration = Duration::from_secs(5);
24
25#[derive(Serialize, Deserialize, Default, Debug)]
26pub(crate) struct IpEchoServerMessage {
27 tcp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], udp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], }
30
31#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
32pub struct IpEchoServerResponse {
33 pub(crate) address: IpAddr,
35 #[serde(deserialize_with = "default_on_eof")]
37 pub(crate) shred_version: Option<u16>,
38}
39
40impl IpEchoServerMessage {
41 pub fn new(tcp_ports: &[u16], udp_ports: &[u16]) -> Self {
42 let mut msg = Self::default();
43 assert!(tcp_ports.len() <= msg.tcp_ports.len());
44 assert!(udp_ports.len() <= msg.udp_ports.len());
45
46 msg.tcp_ports[..tcp_ports.len()].copy_from_slice(tcp_ports);
47 msg.udp_ports[..udp_ports.len()].copy_from_slice(udp_ports);
48 msg
49 }
50}
51
52pub(crate) fn ip_echo_server_request_length() -> usize {
53 const REQUEST_TERMINUS_LENGTH: usize = 1;
54 HEADER_LENGTH
55 + bincode::serialized_size(&IpEchoServerMessage::default()).unwrap() as usize
56 + REQUEST_TERMINUS_LENGTH
57}
58
59async fn process_connection(
60 mut socket: TcpStream,
61 peer_addr: SocketAddr,
62 shred_version: Option<u16>,
63) -> io::Result<()> {
64 info!("connection from {:?}", peer_addr);
65
66 let mut data = vec![0u8; ip_echo_server_request_length()];
67
68 let mut writer = {
69 let (mut reader, writer) = socket.split();
70 let _ = timeout(IO_TIMEOUT, reader.read_exact(&mut data)).await??;
71 writer
72 };
73
74 let request_header: String = data[0..HEADER_LENGTH].iter().map(|b| *b as char).collect();
75 if request_header != "\0\0\0\0" {
76 if request_header == "GET " || request_header == "POST" {
80 timeout(
82 IO_TIMEOUT,
83 writer.write_all(b"HTTP/1.1 400 Bad Request\nContent-length: 0\n\n"),
84 )
85 .await??;
86 return Ok(());
87 }
88 return Err(io::Error::new(
89 io::ErrorKind::Other,
90 format!("Bad request header: {}", request_header),
91 ));
92 }
93
94 let msg =
95 bincode::deserialize::<IpEchoServerMessage>(&data[HEADER_LENGTH..]).map_err(|err| {
96 io::Error::new(
97 io::ErrorKind::Other,
98 format!("Failed to deserialize IpEchoServerMessage: {:?}", err),
99 )
100 })?;
101
102 trace!("request: {:?}", msg);
103
104 match std::net::UdpSocket::bind("0.0.0.0:0") {
106 Ok(udp_socket) => {
107 for udp_port in &msg.udp_ports {
108 if *udp_port != 0 {
109 match udp_socket.send_to(&[0], SocketAddr::from((peer_addr.ip(), *udp_port))) {
110 Ok(_) => debug!("Successful send_to udp/{}", udp_port),
111 Err(err) => info!("Failed to send_to udp/{}: {}", udp_port, err),
112 }
113 }
114 }
115 }
116 Err(err) => {
117 warn!("Failed to bind local udp socket: {}", err);
118 }
119 }
120
121 for tcp_port in &msg.tcp_ports {
123 if *tcp_port != 0 {
124 debug!("Connecting to tcp/{}", tcp_port);
125
126 let mut tcp_stream = timeout(
127 IO_TIMEOUT,
128 TcpStream::connect(&SocketAddr::new(peer_addr.ip(), *tcp_port)),
129 )
130 .await??;
131
132 debug!("Connection established to tcp/{}", *tcp_port);
133 let _ = tcp_stream.shutdown();
134 }
135 }
136 let response = IpEchoServerResponse {
137 address: peer_addr.ip(),
138 shred_version,
139 };
140 let mut bytes = vec![0u8; IP_ECHO_SERVER_RESPONSE_LENGTH];
143 bincode::serialize_into(&mut bytes[HEADER_LENGTH..], &response).unwrap();
144 trace!("response: {:?}", bytes);
145 writer.write_all(&bytes).await
146}
147
148async fn run_echo_server(tcp_listener: std::net::TcpListener, shred_version: Option<u16>) {
149 info!("bound to {:?}", tcp_listener.local_addr().unwrap());
150 let tcp_listener =
151 TcpListener::from_std(tcp_listener).expect("Failed to convert std::TcpListener");
152
153 loop {
154 match tcp_listener.accept().await {
155 Ok((socket, peer_addr)) => {
156 runtime::Handle::current().spawn(async move {
157 if let Err(err) = process_connection(socket, peer_addr, shred_version).await {
158 info!("session failed: {:?}", err);
159 }
160 });
161 }
162 Err(err) => warn!("listener accept failed: {:?}", err),
163 }
164 }
165}
166
167pub fn ip_echo_server(
170 tcp_listener: std::net::TcpListener,
171 shred_version: Option<u16>,
173) -> IpEchoServer {
174 tcp_listener.set_nonblocking(true).unwrap();
175
176 let runtime = Runtime::new().expect("Failed to create Runtime");
177 runtime.spawn(run_echo_server(tcp_listener, shred_version));
178 runtime
179}