1use std::io::Error as IoError;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::time::Duration;
8
9use failure::Fail;
10use futures::{future, FutureExt, TryFutureExt, SinkExt, StreamExt, TryStreamExt};
11use futures::channel::mpsc;
12use tokio::net::{TcpStream, TcpListener};
13use tokio_util::codec::Framed;
14use tokio::time::Error as TimerError;
15
16use tox_crypto::*;
17use crate::relay::codec::{DecodeError, EncodeError, Codec};
18use crate::relay::handshake::make_server_handshake;
19use crate::relay::server::{Client, Server};
20use crate::stats::*;
21
/// Period of the timer that wakes up the ping sender in `tcp_run`.
const TCP_PING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum time the server-side handshake in `tcp_run_connection` may take
/// before the connection is rejected with a timeout error.
const TCP_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Capacity of the per-connection channel that carries packets destined for
/// the client.
const SERVER_CHANNEL_SIZE: usize = 64;
29
/// Error returned by `tcp_run` when the TCP relay server stops.
#[derive(Debug, Fail)]
pub enum ServerRunError {
    /// The listener failed while yielding an incoming connection.
    #[fail(display = "Incoming IO error: {:?}", error)]
    IncomingError {
        /// Underlying IO error.
        #[fail(cause)]
        error: IoError
    },
    /// The timer driving ping wakeups failed.
    // NOTE(review): this variant is not constructed anywhere in the code
    // visible in this file, and unlike its siblings the field is not marked
    // `#[fail(cause)]` — confirm whether that is intentional.
    #[fail(display = "Ping wakeups timer error: {:?}", error)]
    PingWakeupsError {
        /// Underlying timer error.
        error: TimerError
    },
    /// `Server::send_pings` returned an error.
    #[fail(display = "Send pings error: {:?}", error)]
    SendPingsError {
        /// Underlying IO error.
        #[fail(cause)]
        error: IoError
    },
}
54
55#[derive(Debug, Fail)]
57pub enum ConnectionError {
58 #[fail(display = "Failed to get peer address: {}", error)]
60 PeerAddrError {
61 #[fail(cause)]
63 error: IoError,
64 },
65 #[fail(display = "Failed to send TCP packet: {}", error)]
67 SendPacketError {
68 error: EncodeError
69 },
70 #[fail(display = "Failed to decode incoming packet: {}", error)]
72 DecodePacketError {
73 error: DecodeError
74 },
75 #[fail(display = "Incoming IO error: {:?}", error)]
77 IncomingError {
78 #[fail(cause)]
80 error: IoError
81 },
82 #[fail(display = "Server handshake error: {:?}", error)]
84 ServerHandshakeTimeoutError {
85 #[fail(cause)]
87 error: tokio::time::Elapsed
88 },
89 #[fail(display = "Server handshake error: {:?}", error)]
90 ServerHandshakeIoError {
91 #[fail(cause)]
93 error: IoError,
94 },
95 #[fail(display = "Packet handling error: {:?}", error)]
97 PacketHandlingError {
98 #[fail(cause)]
100 error: IoError
101 },
102 #[fail(display = "Packet handling error: {:?}", error)]
104 InsertClientError {
105 #[fail(cause)]
107 error: IoError
108 },
109
110 #[fail(display = "Packet handling error: {:?}", error)]
111 ShutdownError {
112 #[fail(cause)]
114 error: IoError
115 },
116}
117
118pub async fn tcp_run(server: &Server, mut listener: TcpListener, dht_sk: SecretKey, stats: Stats, connections_limit: usize) -> Result<(), ServerRunError> {
122 let connections_count = Arc::new(AtomicUsize::new(0));
123
124 let connections_future = async {
125 listener.incoming()
126 .map_err(|error| ServerRunError::IncomingError { error })
127 .try_for_each(|stream| {
128 if connections_count.load(Ordering::SeqCst) < connections_limit {
129 connections_count.fetch_add(1, Ordering::SeqCst);
130 let connections_count_c = connections_count.clone();
131 let dht_sk = dht_sk.clone();
132 let stats = stats.clone();
133 let server = server.clone();
134
135 tokio::spawn(
136 async move {
137 let res = tcp_run_connection(&server, stream, dht_sk, stats).await;
138
139 if let Err(ref e) = res {
140 error!("Error while running tcp connection: {:?}", e)
141 }
142
143 connections_count_c.fetch_sub(1, Ordering::SeqCst);
144 res
145 }
146 );
147 } else {
148 trace!("Tcp server has reached the limit of {} connections", connections_limit);
149 }
150
151 future::ok(())
152 }).await
153 };
154
155 let mut wakeups = tokio::time::interval(TCP_PING_INTERVAL);
156 let ping_future = async {
157 while wakeups.next().await.is_some() {
158 trace!("Tcp server ping sender wake up");
159 server.send_pings().await
160 .map_err(|error| ServerRunError::SendPingsError { error })?;
161 }
162
163 Ok(())
164 };
165
166 futures::select! {
167 res = connections_future.fuse() => res,
168 res = ping_future.fuse() => res,
169 }
170}
171
172pub async fn tcp_run_connection(server: &Server, stream: TcpStream, dht_sk: SecretKey, stats: Stats) -> Result<(), ConnectionError> {
174 let addr = match stream.peer_addr() {
175 Ok(addr) => addr,
176 Err(error) => return Err(ConnectionError::PeerAddrError {
177 error
178 }),
179 };
180
181 debug!("A new TCP client connected from {}", addr);
182
183 let fut = tokio::time::timeout(
184 TCP_HANDSHAKE_TIMEOUT,
185 make_server_handshake(stream, dht_sk.clone())
186 );
187 let (stream, channel, client_pk) = match fut.await {
188 Err(error) => Err(
189 ConnectionError::ServerHandshakeTimeoutError { error }
190 ),
191 Ok(Err(error)) => Err(
192 ConnectionError::ServerHandshakeIoError { error }
193 ),
194 Ok(Ok(res)) => Ok(res)
195 }?;
196
197 debug!("Handshake for TCP client {:?} is completed", client_pk);
198
199 let secure_socket = Framed::new(stream, Codec::new(channel, stats));
200 let (mut to_client, from_client) = secure_socket.split();
201 let (to_client_tx, mut to_client_rx) = mpsc::channel(SERVER_CHANNEL_SIZE);
202
203 let processor = from_client
205 .map_err(|error| ConnectionError::DecodePacketError { error })
206 .try_for_each(|packet| {
207 debug!("Handle {:?} => {:?}", client_pk, packet);
208 server.handle_packet(&client_pk, packet)
209 .map_err(|error| ConnectionError::PacketHandlingError { error } )
210 });
211
212 let writer = async {
213 while let Some(packet) = to_client_rx.next().await {
214 trace!("Sending TCP packet {:?} to {:?}", packet, client_pk);
215 to_client.send(packet).await
216 .map_err(|error| ConnectionError::SendPacketError {
217 error
218 })?;
219 }
220
221 Ok(())
222 };
223
224 let client = Client::new(
225 to_client_tx,
226 &client_pk,
227 addr.ip(),
228 addr.port()
229 );
230 server.insert(client).await
231 .map_err(|error| ConnectionError::InsertClientError { error })?;
232
233 let r_processing = futures::select! {
234 res = processor.fuse() => res,
235 res = writer.fuse() => res
236 };
237
238 debug!("Shutdown a client with PK {:?}", &client_pk);
239
240 server.shutdown_client(&client_pk, addr.ip(), addr.port())
241 .await
242 .map_err(|error| ConnectionError::ShutdownError { error })?;
243
244 r_processing
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use tox_binary_io::*;

    use failure::Error;

    use crate::relay::codec::Codec;
    use crate::relay::handshake::make_client_handshake;
    use tox_packet::relay::{Packet, PingRequest, PongResponse};

    // NOTE(review): `TCP_PING_FREQUENCY` used in `run` is not defined in this
    // file; presumably it comes from this glob import — confirm.
    use crate::relay::server::client::*;

    /// Drives `tcp_run_connection` against a real loopback socket and checks
    /// that a `PingRequest` from the client is answered with a matching
    /// `PongResponse`.
    #[tokio::test]
    async fn run_connection() {
        crypto_init().unwrap();
        let (client_pk, client_sk) = gen_keypair();
        let (server_pk, server_sk) = gen_keypair();

        // Bind to port 0 and read back the address actually assigned.
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut listener = TcpListener::bind(&addr).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let stats = Stats::new();
        let stats_c = stats.clone();
        // Server side: accept exactly one connection and run it.
        let server = async {
            let connection = listener.incoming().next().await.unwrap().unwrap();
            tcp_run_connection(&Server::new(), connection, server_sk, stats.clone())
                .map_err(Error::from).await
        };

        // Client side: connect, handshake, send a ping and expect the pong.
        let client = async {
            let socket = TcpStream::connect(&addr).map_err(Error::from).await?;
            let (stream, channel) = make_client_handshake(socket, &client_pk, &client_sk, &server_pk)
                .map_err(Error::from).await?;
            let secure_socket = Framed::new(stream, Codec::new(channel, stats_c));
            let (mut to_server, mut from_server) = secure_socket.split();
            let packet = Packet::PingRequest(PingRequest {
                ping_id: 42
            });

            to_server.send(packet).map_err(Error::from).await.unwrap();
            let packet = from_server.next().await.unwrap();

            assert_eq!(packet.unwrap(), Packet::PongResponse(PongResponse {
                ping_id: 42
            }));

            Ok(())
        };

        // The test ends with the result of whichever side finishes first.
        let result = futures::select!(
            res = server.fuse() => res,
            res = client.fuse() => res,
        );

        assert!(result.is_ok());
    }

    /// Drives the whole `tcp_run` server with a connection limit of 1 and
    /// checks both the ping/pong exchange and the pings sent by the server.
    #[tokio::test]
    async fn run() {
        // Pause tokio's clock so that it can be advanced manually below.
        tokio::time::pause();
        crypto_init().unwrap();

        let (client_pk, client_sk) = gen_keypair();
        let (server_pk, server_sk) = gen_keypair();

        // Bind to port 0 and read back the address actually assigned.
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = TcpListener::bind(&addr).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let stats = Stats::new();
        let server = async {
            tcp_run(&Server::new(), listener, server_sk, stats.clone(), 1).await
                .map_err(Error::from)
        };

        let client = async {
            let socket = TcpStream::connect(&addr).map_err(Error::from).await?;
            let (stream, channel) = make_client_handshake(socket, &client_pk, &client_sk, &server_pk)
                .map_err(Error::from).await?;

            let secure_socket = Framed::new(stream, Codec::new(channel, stats.clone()));
            let (mut to_server, mut from_server) = secure_socket.split();
            let packet = Packet::PingRequest(PingRequest {
                ping_id: 42
            });
            to_server.send(packet).map_err(Error::from).await?;

            let packet = from_server.next().await.unwrap();
            assert_eq!(packet.unwrap(), Packet::PongResponse(PongResponse {
                ping_id: 42
            }));
            // Move the clock past the point where the server is expected to
            // ping the client.
            tokio::time::advance(TCP_PING_FREQUENCY + Duration::from_secs(1)).await;
            // Every packet received from now on must be a `PingRequest`; the
            // loop ends once the stream yields `None`.
            // NOTE(review): presumably the server drops the unresponsive
            // client, which closes the stream — confirm against `Server`.
            while let Some(packet) = from_server.next().await {
                let _ping_packet = unpack!(packet.unwrap(), Packet::PingRequest);
            }

            Ok(())
        };

        // The test ends with the result of whichever side finishes first.
        let result = futures::select!(
            res = server.fuse() => res,
            res = client.fuse() => res,
        );

        assert!(result.is_ok());
    }
}