Skip to main content

turn_server/server/transport/
mod.rs

1pub mod tcp;
2pub mod udp;
3
4use std::{net::SocketAddr, time::Duration};
5
6use anyhow::Result;
7use bytes::Bytes;
8
9use tokio::time::interval;
10
11use crate::{
12    Service,
13    config::Ssl,
14    server::{Exchanger, PayloadType},
15    service::session::Identifier,
16    statistics::{Statistics, Stats},
17};
18
/// Maximum message size: 4096 (4 KiB — presumably bytes; the constant is not
/// read in this file, so confirm its exact use in the `tcp`/`udp` submodules).
pub const MAX_MESSAGE_SIZE: usize = 4 * 1024;
20
/// The transport protocol a server instance speaks.
///
/// `Server::start` uses it to pick the statistics reporter and to decide
/// whether forwarded channel data must be padded to a 4-byte boundary.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Transport {
    /// UDP transport.
    Udp,
    /// TCP transport; forwarded channel data is padded to a multiple of 4 bytes.
    Tcp,
}
26
27pub trait Socket: Send + 'static {
28    fn read(&mut self) -> impl Future<Output = Option<Bytes>> + Send;
29    fn write(&mut self, buffer: &[u8]) -> impl Future<Output = Result<()>> + Send;
30    fn close(&mut self) -> impl Future<Output = ()> + Send;
31}
32
33#[allow(unused)]
34pub struct ServerOptions {
35    pub transport: Transport,
36    pub idle_timeout: u32,
37    pub listen: SocketAddr,
38    pub external: SocketAddr,
39    pub ssl: Option<Ssl>,
40    pub mtu: usize,
41}
42
43pub trait Server: Sized + Send {
44    type Socket: Socket;
45
46    /// Bind the server to the specified address.
47    fn bind(options: &ServerOptions) -> impl Future<Output = Result<Self>> + Send;
48
49    /// Accept a new connection.
50    fn accept(&mut self) -> impl Future<Output = Option<(Self::Socket, SocketAddr)>> + Send;
51
52    /// Get the local address of the listener.
53    fn local_addr(&self) -> Result<SocketAddr>;
54
55    /// Start the server.
56    fn start(
57        options: ServerOptions,
58        service: Service,
59        statistics: Statistics,
60        exchanger: Exchanger,
61    ) -> impl Future<Output = Result<()>> + Send {
62        let transport = options.transport;
63        let idle_timeout = options.idle_timeout as u64;
64
65        async move {
66            let mut listener = Self::bind(&options).await?;
67            let local_addr = listener.local_addr()?;
68
69            log::info!(
70                "server listening: listen={}, external={}, local addr={local_addr}, transport={transport:?}",
71                options.listen,
72                options.external,
73            );
74
75            while let Some((mut socket, address)) = listener.accept().await {
76                let id = Identifier::new(address, options.external);
77
78                let mut receiver = exchanger.get_receiver(address);
79                let mut router = service.make_router(address, options.external);
80                let reporter = statistics.get_reporter(transport);
81
82                let service = service.clone();
83                let exchanger = exchanger.clone();
84
85                tokio::spawn(async move {
86                    let mut interval = interval(Duration::from_secs(1));
87                    let mut read_delay = 0;
88
89                    loop {
90                        tokio::select! {
91                            Some(buffer) = socket.read() => {
92                                read_delay = 0;
93
94                                if let Ok(res) = router.route(&buffer, address).await
95                                {
96                                    let (ty, bytes, target) = if let Some(it) = res {
97                                        (
98                                            it.method.map(PayloadType::Message).unwrap_or(PayloadType::ChannelData),
99                                            it.bytes,
100                                            it.target,
101                                        )
102                                    } else {
103                                        continue;
104                                    };
105
106                                    if let Some(endpoint) = target.endpoint {
107                                        exchanger.send(&endpoint, ty, Bytes::copy_from_slice(bytes));
108                                    } else {
109                                        if socket.write(bytes).await.is_err() {
110                                            break;
111                                        }
112
113                                        reporter.send(
114                                            &id,
115                                            &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)],
116                                        );
117
118                                        if let PayloadType::Message(method) = ty && method.is_error() {
119                                            reporter.send(&id, &[Stats::ErrorPkts(1)]);
120                                        }
121                                    }
122                                }
123                            }
124                            Some((bytes, method)) = receiver.recv() => {
125                                if socket.write(&bytes).await.is_err() {
126                                    break;
127                                } else {
128                                    reporter.send(&id, &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)]);
129                                }
130
131                                // The channel data needs to be aligned in multiples of 4 in
132                                // tcp. If the channel data is forwarded to tcp, the alignment
133                                // bit needs to be filled, because if the channel data comes
134                                // from udp, it is not guaranteed to be aligned and needs to be
135                                // checked.
136                                if transport == Transport::Tcp && method == PayloadType::ChannelData {
137                                    let pad = bytes.len() % 4;
138                                    if pad > 0 && socket.write(&[0u8; 8][..(4 - pad)]).await.is_err() {
139                                        break;
140                                    }
141                                }
142                            }
143                            _ = interval.tick() => {
144                                read_delay += 1;
145
146                                if read_delay >= idle_timeout {
147                                    break;
148                                }
149                            }
150                            else => {
151                                break;
152                            }
153                        }
154                    }
155
156                    // close the socket
157                    socket.close().await;
158
159                    // When the socket connection is closed, the procedure to close the session is
160                    // process directly once, avoiding the connection being disconnected
161                    // directly without going through the closing
162                    // process.
163                    service.get_session_manager().refresh(&id, 0);
164
165                    exchanger.remove(&address);
166
167                    log::info!(
168                        "socket disconnect: addr={address:?}, interface={local_addr:?}, transport={transport:?}"
169                    );
170                });
171            }
172
173            log::error!("server shutdown: interface={local_addr:?}, transport={transport:?}");
174
175            Ok(())
176        }
177    }
178}