tox_core/relay/server/
server_ext.rs

/*! Functions for running the TCP relay server: accepting connections from a
`TcpListener`, serving a single `TcpStream` and sending periodic pings
*/
3
4use std::io::Error as IoError;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::time::Duration;
8
9use failure::Fail;
10use futures::{future, FutureExt, TryFutureExt, SinkExt, StreamExt, TryStreamExt};
11use futures::channel::mpsc;
12use tokio::net::{TcpStream, TcpListener};
13use tokio_util::codec::Framed;
14use tokio::time::Error as TimerError;
15
16use tox_crypto::*;
17use crate::relay::codec::{DecodeError, EncodeError, Codec};
18use crate::relay::handshake::make_server_handshake;
19use crate::relay::server::{Client, Server};
20use crate::stats::*;
21
/// Interval of time for Tcp Ping sender: how often `tcp_run` wakes up and
/// calls `Server::send_pings`.
const TCP_PING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum time the TCP handshake may take before `tcp_run_connection` gives
/// up with a timeout error.
const TCP_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Capacity of the bounded channel that carries packets destined for a single
/// client.
const SERVER_CHANNEL_SIZE: usize = 64;
29
/// Error that can happen during server execution.
///
/// This is the error type returned by `tcp_run`; every variant wraps the
/// underlying error that stopped the server.
#[derive(Debug, Fail)]
pub enum ServerRunError {
    /// Incoming IO error: taking the next connection from the listener failed
    #[fail(display = "Incoming IO error: {:?}", error)]
    IncomingError {
        /// IO error
        #[fail(cause)]
        error: IoError
    },
    /// Ping wakeups timer error
    // NOTE(review): nothing in this file constructs this variant — confirm
    // whether it is still needed.
    #[fail(display = "Ping wakeups timer error: {:?}", error)]
    PingWakeupsError {
        /// Timer error
        error: TimerError
    },
    /// Send pings error: `Server::send_pings` returned an error
    #[fail(display = "Send pings error: {:?}", error)]
    SendPingsError {
        /// Send pings error
        #[fail(cause)]
        error: IoError
    },
}
54
55/// Error that can happen during TCP connection execution
56#[derive(Debug, Fail)]
57pub enum ConnectionError {
58    /// Error indicates that we couldn't get peer address
59    #[fail(display = "Failed to get peer address: {}", error)]
60    PeerAddrError {
61        /// Peer address error
62        #[fail(cause)]
63        error: IoError,
64    },
65    /// Sending packet error
66    #[fail(display = "Failed to send TCP packet: {}", error)]
67    SendPacketError {
68        error: EncodeError
69    },
70    /// Decode incoming packet error
71    #[fail(display = "Failed to decode incoming packet: {}", error)]
72    DecodePacketError {
73        error: DecodeError
74    },
75    /// Incoming IO error
76    #[fail(display = "Incoming IO error: {:?}", error)]
77    IncomingError {
78        /// IO error
79        #[fail(cause)]
80        error: IoError
81    },
82    /// Server handshake error
83    #[fail(display = "Server handshake error: {:?}", error)]
84    ServerHandshakeTimeoutError {
85        /// Server handshake error
86        #[fail(cause)]
87        error: tokio::time::Elapsed
88    },
89    #[fail(display = "Server handshake error: {:?}", error)]
90    ServerHandshakeIoError {
91        /// Server handshake error
92        #[fail(cause)]
93        error: IoError,
94    },
95    /// Packet handling error
96    #[fail(display = "Packet handling error: {:?}", error)]
97    PacketHandlingError {
98        /// Packet handling error
99        #[fail(cause)]
100        error: IoError
101    },
102    /// Insert client error
103    #[fail(display = "Packet handling error: {:?}", error)]
104    InsertClientError {
105        /// Insert client error
106        #[fail(cause)]
107        error: IoError
108    },
109
110    #[fail(display = "Packet handling error: {:?}", error)]
111    ShutdownError {
112        /// Insert client error
113        #[fail(cause)]
114        error: IoError
115    },
116}
117
118/// Running TCP ping sender and incoming `TcpStream`. This function uses
119/// `tokio::spawn` inside so it should be executed via tokio to be able to
120/// get tokio default executor.
121pub async fn tcp_run(server: &Server, mut listener: TcpListener, dht_sk: SecretKey, stats: Stats, connections_limit: usize) -> Result<(), ServerRunError> {
122    let connections_count = Arc::new(AtomicUsize::new(0));
123
124    let connections_future = async {
125        listener.incoming()
126            .map_err(|error| ServerRunError::IncomingError { error })
127            .try_for_each(|stream| {
128                if connections_count.load(Ordering::SeqCst) < connections_limit {
129                    connections_count.fetch_add(1, Ordering::SeqCst);
130                    let connections_count_c = connections_count.clone();
131                    let dht_sk = dht_sk.clone();
132                    let stats = stats.clone();
133                    let server = server.clone();
134
135                    tokio::spawn(
136                        async move {
137                            let res = tcp_run_connection(&server, stream, dht_sk, stats).await;
138
139                            if let Err(ref e) = res {
140                                error!("Error while running tcp connection: {:?}", e)
141                            }
142
143                            connections_count_c.fetch_sub(1, Ordering::SeqCst);
144                            res
145                        }
146                    );
147                } else {
148                    trace!("Tcp server has reached the limit of {} connections", connections_limit);
149                }
150
151                future::ok(())
152            }).await
153    };
154
155    let mut wakeups = tokio::time::interval(TCP_PING_INTERVAL);
156    let ping_future = async {
157        while wakeups.next().await.is_some() {
158            trace!("Tcp server ping sender wake up");
159            server.send_pings().await
160                .map_err(|error| ServerRunError::SendPingsError { error })?;
161        }
162
163        Ok(())
164    };
165
166    futures::select! {
167        res = connections_future.fuse() => res,
168        res = ping_future.fuse() => res,
169    }
170}
171
172/// Running TCP server on incoming `TcpStream`
173pub async fn tcp_run_connection(server: &Server, stream: TcpStream, dht_sk: SecretKey, stats: Stats) -> Result<(), ConnectionError> {
174    let addr = match stream.peer_addr() {
175        Ok(addr) => addr,
176        Err(error) => return Err(ConnectionError::PeerAddrError {
177            error
178        }),
179    };
180
181    debug!("A new TCP client connected from {}", addr);
182
183    let fut = tokio::time::timeout(
184        TCP_HANDSHAKE_TIMEOUT,
185        make_server_handshake(stream, dht_sk.clone())
186    );
187    let (stream, channel, client_pk) = match fut.await {
188        Err(error) => Err(
189            ConnectionError::ServerHandshakeTimeoutError { error }
190        ),
191        Ok(Err(error)) => Err(
192            ConnectionError::ServerHandshakeIoError { error }
193        ),
194        Ok(Ok(res)) => Ok(res)
195    }?;
196
197    debug!("Handshake for TCP client {:?} is completed", client_pk);
198
199    let secure_socket = Framed::new(stream, Codec::new(channel, stats));
200    let (mut to_client, from_client) = secure_socket.split();
201    let (to_client_tx, mut to_client_rx) = mpsc::channel(SERVER_CHANNEL_SIZE);
202
203    // processor = for each Packet from client process it
204    let processor = from_client
205        .map_err(|error| ConnectionError::DecodePacketError { error })
206        .try_for_each(|packet| {
207            debug!("Handle {:?} => {:?}", client_pk, packet);
208            server.handle_packet(&client_pk, packet)
209                .map_err(|error| ConnectionError::PacketHandlingError { error } )
210        });
211
212    let writer = async {
213        while let Some(packet) = to_client_rx.next().await {
214            trace!("Sending TCP packet {:?} to {:?}", packet, client_pk);
215            to_client.send(packet).await
216                .map_err(|error| ConnectionError::SendPacketError {
217                    error
218                })?;
219        }
220
221        Ok(())
222    };
223
224    let client = Client::new(
225        to_client_tx,
226        &client_pk,
227        addr.ip(),
228        addr.port()
229    );
230    server.insert(client).await
231        .map_err(|error| ConnectionError::InsertClientError { error })?;
232
233    let r_processing = futures::select! {
234        res = processor.fuse() => res,
235        res = writer.fuse() => res
236    };
237
238    debug!("Shutdown a client with PK {:?}", &client_pk);
239
240    server.shutdown_client(&client_pk, addr.ip(), addr.port())
241        .await
242        .map_err(|error| ConnectionError::ShutdownError { error })?;
243
244    r_processing
245}
246
// Integration-style tests: a real `TcpListener` on the loopback interface is
// served by the functions of this module while a client performs the
// handshake and exchanges packets with it.
#[cfg(test)]
mod tests {
    use super::*;
    use tox_binary_io::*;

    use failure::Error;

    use crate::relay::codec::Codec;
    use crate::relay::handshake::make_client_handshake;
    use tox_packet::relay::{Packet, PingRequest, PongResponse};

    use crate::relay::server::client::*;

    // Serves the first accepted connection with `tcp_run_connection` and checks
    // that a `PingRequest` sent by the client is answered with a `PongResponse`
    // carrying the same `ping_id`.
    #[tokio::test]
    async fn run_connection() {
        crypto_init().unwrap();
        let (client_pk, client_sk) = gen_keypair();
        let (server_pk, server_sk) = gen_keypair();

        // port 0 is requested here; the address actually bound is read back
        // with `local_addr` and used by the client below
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut listener = TcpListener::bind(&addr).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let stats = Stats::new();
        let stats_c = stats.clone();
        let server = async {
            // take the first connection
            let connection = listener.incoming().next().await.unwrap().unwrap();
            tcp_run_connection(&Server::new(), connection, server_sk, stats.clone())
                .map_err(Error::from).await
        };

        let client = async {
            let socket = TcpStream::connect(&addr).map_err(Error::from).await?;
            let (stream, channel) = make_client_handshake(socket, &client_pk, &client_sk, &server_pk)
                .map_err(Error::from).await?;
            let secure_socket = Framed::new(stream, Codec::new(channel, stats_c));
            let (mut to_server, mut from_server) = secure_socket.split();
            let packet = Packet::PingRequest(PingRequest {
                ping_id: 42
            });

            to_server.send(packet).map_err(Error::from).await.unwrap();
            let packet = from_server.next().await.unwrap();

            // the response must echo the `ping_id` of the request
            assert_eq!(packet.unwrap(), Packet::PongResponse(PongResponse {
                ping_id: 42
            }));

            Ok(())
        };

        // the test is over as soon as either side completes; the result of
        // that side is the one checked below
        let result = futures::select!(
            res = server.fuse() => res,
            res = client.fuse() => res,
        );

        assert!(result.is_ok());
    }

    // Runs the whole server via `tcp_run` (connection limit of 1) and checks
    // the ping/pong exchange plus the packets received after the clock has
    // been advanced.
    #[tokio::test]
    async fn run() {
        // pause tokio's clock; it is moved forward explicitly with
        // `tokio::time::advance` further down
        tokio::time::pause();
        crypto_init().unwrap();

        let (client_pk, client_sk) = gen_keypair();
        let (server_pk, server_sk) = gen_keypair();

        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = TcpListener::bind(&addr).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let stats = Stats::new();
        let server = async {
            tcp_run(&Server::new(), listener, server_sk, stats.clone(), 1).await
                .map_err(Error::from)
        };

        let client = async {
            let socket = TcpStream::connect(&addr).map_err(Error::from).await?;
            let (stream, channel) = make_client_handshake(socket, &client_pk, &client_sk, &server_pk)
                .map_err(Error::from).await?;

            let secure_socket = Framed::new(stream, Codec::new(channel, stats.clone()));
            let (mut to_server, mut from_server) = secure_socket.split();
            let packet = Packet::PingRequest(PingRequest {
                ping_id: 42
            });
            to_server.send(packet).map_err(Error::from).await?;

            let packet = from_server.next().await.unwrap();
            assert_eq!(packet.unwrap(), Packet::PongResponse(PongResponse {
                ping_id: 42
            }));
            // Set time when the client should be pinged
            // NOTE(review): `TCP_PING_FREQUENCY` is not defined in this file;
            // presumably it comes from the `crate::relay::server::client::*`
            // glob import — confirm.
            tokio::time::advance(TCP_PING_FREQUENCY + Duration::from_secs(1)).await;
            // every further packet is expected to be a `PingRequest`; the
            // loop ends when the stream from the server ends
            while let Some(packet) = from_server.next().await {
                // check the packet
                // NOTE(review): `unpack!` presumably panics when the packet is
                // not the given variant — confirm in `tox_binary_io`.
                let _ping_packet = unpack!(packet.unwrap(), Packet::PingRequest);
            }

            Ok(())
        };

        // the test is over as soon as either side completes
        let result = futures::select!(
            res = server.fuse() => res,
            res = client.fuse() => res,
        );

        assert!(result.is_ok());
    }
}