// turn_server/server/transport/mod.rs

1pub mod tcp;
2pub mod udp;
3
4use std::{net::SocketAddr, time::Duration};
5
6use anyhow::Result;
7use bytes::Bytes;
8
9use tokio::time::interval;
10
11use crate::{
12    Service,
13    config::Ssl,
14    server::{Exchanger, PayloadType},
15    service::session::Identifier,
16    statistics::{Statistics, Stats},
17};
18
/// Upper bound used for a single message.
///
/// NOTE(review): presumably a size in bytes used by the transport
/// implementations when sizing read buffers; it is not referenced in this
/// module, so confirm against `tcp` / `udp`.
pub const MAX_MESSAGE_SIZE: usize = 4096;
20
/// The kind of transport a server instance is running on.
///
/// The value is cheap to copy and compare; [`Server::start`] checks it to
/// decide whether forwarded channel data must be padded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Transport {
    /// Datagram transport.
    Udp,
    /// Stream transport.
    Tcp,
}
26
/// A single peer connection handed out by a [`Server`].
///
/// Implementations must be `Send + 'static` because each accepted socket is
/// moved into its own spawned task by [`Server::start`].
pub trait Socket: Send + 'static {
    /// Wait for the next inbound message.
    ///
    /// NOTE(review): `None` presumably means the connection is finished and
    /// no further data will arrive; confirm against the implementations.
    fn read(&mut self) -> impl Future<Output = Option<Bytes>> + Send;
    /// Send `buffer` to the peer, returning an error if the write fails.
    fn write(&mut self, buffer: &[u8]) -> impl Future<Output = Result<()>> + Send;
    /// Close the connection. [`Server::start`] calls this once its
    /// per-connection loop has ended.
    fn close(&mut self) -> impl Future<Output = ()> + Send;
}
32
/// Settings handed to [`Server::bind`] and [`Server::start`].
// NOTE(review): `allow(unused)` is presumably here because not every
// transport reads every field (e.g. `ssl`, `mtu`); confirm.
#[allow(unused)]
pub struct ServerOptions {
    /// Which transport this server runs; logged, and checked when forwarding
    /// channel data so that TCP output gets padded.
    pub transport: Transport,
    /// Number of one-second ticks without an inbound message after which
    /// `Server::start` drops the connection.
    pub idle_timeout: u32,
    /// Listen address; logged at startup. Presumably the address `bind`
    /// uses — confirm against the implementations.
    pub listen: SocketAddr,
    /// Externally visible address; used as the `interface` of each session
    /// identifier and passed to the router.
    pub external: SocketAddr,
    /// Optional SSL configuration. Not read in this module.
    pub ssl: Option<Ssl>,
    /// NOTE(review): presumably the maximum transmission unit for the
    /// transport; not read in this module — confirm.
    pub mtu: usize,
}
42
43pub trait Server: Sized + Send {
44    type Socket: Socket;
45
46    /// Bind the server to the specified address.
47    fn bind(options: &ServerOptions) -> impl Future<Output = Result<Self>> + Send;
48
49    /// Accept a new connection.
50    fn accept(&mut self) -> impl Future<Output = Option<(Self::Socket, SocketAddr)>> + Send;
51
52    /// Get the local address of the listener.
53    fn local_addr(&self) -> Result<SocketAddr>;
54
55    /// Start the server.
56    fn start(
57        options: ServerOptions,
58        service: Service,
59        statistics: Statistics,
60        exchanger: Exchanger,
61    ) -> impl Future<Output = Result<()>> + Send {
62        let transport = options.transport;
63        let idle_timeout = options.idle_timeout as u64;
64
65        async move {
66            let mut listener = Self::bind(&options).await?;
67            let local_addr = listener.local_addr()?;
68
69            log::info!(
70                "server listening: listen={}, external={}, local addr={local_addr}, transport={transport:?}",
71                options.listen,
72                options.external,
73            );
74
75            while let Some((mut socket, address)) = listener.accept().await {
76                let id = Identifier {
77                    interface: options.external,
78                    source: address,
79                };
80
81                let mut receiver = exchanger.get_receiver(address);
82                let mut router = service.make_router(address, options.external);
83                let reporter = statistics.get_reporter();
84
85                let service = service.clone();
86                let exchanger = exchanger.clone();
87
88                tokio::spawn(async move {
89                    let mut interval = interval(Duration::from_secs(1));
90                    let mut read_delay = 0;
91
92                    loop {
93                        tokio::select! {
94                            Some(buffer) = socket.read() => {
95                                read_delay = 0;
96
97                                if let Ok(res) = router.route(&buffer, address).await
98                                {
99                                    let (ty, bytes, target) = if let Some(it) = res {
100                                        (
101                                            it.method.map(PayloadType::Message).unwrap_or(PayloadType::ChannelData),
102                                            it.bytes,
103                                            it.target,
104                                        )
105                                    } else {
106                                        continue;
107                                    };
108
109                                    if let Some(endpoint) = target.endpoint {
110                                        exchanger.send(&endpoint, ty, Bytes::copy_from_slice(bytes));
111                                    } else {
112                                        if socket.write(bytes).await.is_err() {
113                                            break;
114                                        }
115
116                                        reporter.send(
117                                            &id,
118                                            &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)],
119                                        );
120
121                                        if let PayloadType::Message(method) = ty && method.is_error() {
122                                            reporter.send(&id, &[Stats::ErrorPkts(1)]);
123                                        }
124                                    }
125                                }
126                            }
127                            Some((bytes, method)) = receiver.recv() => {
128                                if socket.write(&bytes).await.is_err() {
129                                    break;
130                                } else {
131                                    reporter.send(&id, &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)]);
132                                }
133
134                                // The channel data needs to be aligned in multiples of 4 in
135                                // tcp. If the channel data is forwarded to tcp, the alignment
136                                // bit needs to be filled, because if the channel data comes
137                                // from udp, it is not guaranteed to be aligned and needs to be
138                                // checked.
139                                if transport == Transport::Tcp && method == PayloadType::ChannelData {
140                                    let pad = bytes.len() % 4;
141                                    if pad > 0 && socket.write(&[0u8; 8][..(4 - pad)]).await.is_err() {
142                                        break;
143                                    }
144                                }
145                            }
146                            _ = interval.tick() => {
147                                read_delay += 1;
148
149                                if read_delay >= idle_timeout {
150                                    break;
151                                }
152                            }
153                            else => {
154                                break;
155                            }
156                        }
157                    }
158
159                    // close the socket
160                    socket.close().await;
161
162                    // When the socket connection is closed, the procedure to close the session is
163                    // process directly once, avoiding the connection being disconnected
164                    // directly without going through the closing
165                    // process.
166                    service.get_session_manager().refresh(&id, 0);
167
168                    exchanger.remove(&address);
169
170                    log::info!(
171                        "socket disconnect: addr={address:?}, interface={local_addr:?}, transport={transport:?}"
172                    );
173                });
174            }
175
176            log::error!("server shutdown: interface={local_addr:?}, transport={transport:?}");
177
178            Ok(())
179        }
180    }
181}