Skip to main content

turn_server/server/transport/
tcp.rs

1use std::{io::Error, net::SocketAddr};
2
3#[cfg(feature = "ssl")]
4use std::sync::Arc;
5
6use anyhow::Result;
7use bytes::{Bytes, BytesMut};
8use tokio::{
9    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
10    net::{TcpListener as TokioTcpListener, TcpStream},
11    sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel},
12};
13
14#[cfg(feature = "ssl")]
15use tokio_rustls::{
16    TlsAcceptor,
17    rustls::{
18        ServerConfig,
19        pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
20    },
21    server::TlsStream,
22};
23
24use crate::{
25    codec::Decoder,
26    server::transport::{MAX_MESSAGE_SIZE, Server, ServerOptions, Socket},
27};
28
29enum MaybeSslStream {
30    #[cfg(feature = "ssl")]
31    Ssl(Box<TlsStream<TcpStream>>),
32    Base(TcpStream),
33}
34
35impl MaybeSslStream {
36    fn split(self) -> (Reader, Writer) {
37        use tokio::io::split;
38
39        match self {
40            Self::Base(it) => {
41                let (rx, tx) = split(it);
42
43                (Reader::Base(rx), Writer::Base(tx))
44            }
45            #[cfg(feature = "ssl")]
46            Self::Ssl(it) => {
47                let (rx, tx) = split(it);
48
49                (Reader::Ssl(rx), Writer::Ssl(tx))
50            }
51        }
52    }
53}
54
55enum Reader {
56    #[cfg(feature = "ssl")]
57    Ssl(ReadHalf<Box<TlsStream<TcpStream>>>),
58    Base(ReadHalf<TcpStream>),
59}
60
61impl Reader {
62    async fn read_buf(&mut self, buffer: &mut BytesMut) -> Result<usize, Error> {
63        match self {
64            Self::Base(it) => it.read_buf(buffer).await,
65            #[cfg(feature = "ssl")]
66            Self::Ssl(it) => it.read_buf(buffer).await,
67        }
68    }
69}
70
71enum Writer {
72    #[cfg(feature = "ssl")]
73    Ssl(WriteHalf<Box<TlsStream<TcpStream>>>),
74    Base(WriteHalf<TcpStream>),
75}
76
77impl Writer {
78    async fn write_all(&mut self, buffer: &[u8]) -> Result<(), Error> {
79        match self {
80            Self::Base(it) => it.write_all(buffer).await,
81            #[cfg(feature = "ssl")]
82            Self::Ssl(it) => it.write_all(buffer).await,
83        }
84    }
85}
86
87pub struct TcpSocket {
88    writer: Writer,
89    receiver: UnboundedReceiver<Bytes>,
90    close_signal_sender: Sender<()>,
91}
92
93impl TcpSocket {
94    fn new(stream: MaybeSslStream, addr: SocketAddr) -> Self {
95        let (close_signal_sender, mut close_signal_receiver) = channel::<()>(1);
96        let (tx, receiver) = unbounded_channel::<Bytes>();
97        let (mut reader, writer) = stream.split();
98
99        tokio::spawn(async move {
100            let mut buffer = BytesMut::new();
101
102            'a: loop {
103                tokio::select! {
104                    Ok(size) = reader.read_buf(&mut buffer) => {
105                        if size == 0 {
106                            break;
107                        }
108
109                        // The minimum length of a stun message will not be less
110                        // than 4.
111                        if buffer.len() < 4 {
112                            continue;
113                        }
114
115                        // Limit the maximum length of messages to 2048, this is to prevent buffer
116                        // overflow attacks.
117                        if buffer.len() > MAX_MESSAGE_SIZE * 3 {
118                            break;
119                        }
120
121                        loop {
122                            if buffer.len() <= 4 {
123                                break;
124                            }
125
126                            // Try to get the message length, if the currently
127                            // received data is less than the message length, jump
128                            // out of the current loop and continue to receive more
129                            // data.
130                            let size = match Decoder::message_size(&buffer, true) {
131                                Err(_) => break,
132                                Ok(size) => {
133                                    if size > MAX_MESSAGE_SIZE {
134                                        log::warn!(
135                                            "tcp message size too large: \
136                                                size={size}, \
137                                                max={MAX_MESSAGE_SIZE}, \
138                                                addr={addr:?}"
139                                        );
140
141                                        break 'a;
142                                    }
143
144                                    if size > buffer.len() {
145                                        break;
146                                    }
147
148                                    size
149                                }
150                            };
151
152                            if tx.send(buffer.split_to(size).freeze()).is_err() {
153                                break 'a;
154                            }
155                        }
156                    }
157                    _ = close_signal_receiver.recv() => {
158                        break;
159                    }
160                    else => {
161                        break;
162                    }
163                }
164            }
165        });
166
167        Self {
168            close_signal_sender,
169            writer,
170            receiver,
171        }
172    }
173}
174
175impl Socket for TcpSocket {
176    async fn read(&mut self) -> Option<Bytes> {
177        self.receiver.recv().await
178    }
179
180    async fn write(&mut self, buffer: &[u8]) -> Result<()> {
181        Ok(self.writer.write_all(buffer).await?)
182    }
183
184    async fn close(&mut self) {
185        self.receiver.close();
186
187        let _ = self.close_signal_sender.send(()).await;
188    }
189}
190
191pub struct TcpServer {
192    socket_receiver: UnboundedReceiver<(TcpSocket, SocketAddr)>,
193    local_addr: SocketAddr,
194}
195
196impl Server for TcpServer {
197    type Socket = TcpSocket;
198
199    async fn bind(options: &ServerOptions) -> Result<Self> {
200        #[cfg(feature = "ssl")]
201        let acceptor = if let Some(ssl) = &options.ssl {
202            Some(TlsAcceptor::from(Arc::new(
203                ServerConfig::builder()
204                    .with_no_client_auth()
205                    .with_single_cert(
206                        CertificateDer::pem_file_iter(ssl.certificate_chain.clone())?
207                            .collect::<Result<Vec<_>, _>>()?,
208                        PrivateKeyDer::from_pem_file(ssl.private_key.clone())?,
209                    )?,
210            )))
211        } else {
212            None
213        };
214
215        let listener = TokioTcpListener::bind(options.listen).await?;
216        let local_addr = listener.local_addr()?;
217
218        let (tx, socket_receiver) = unbounded_channel::<(TcpSocket, SocketAddr)>();
219        tokio::spawn(async move {
220            while let Ok((socket, addr)) = listener.accept().await {
221                // Disable the Nagle algorithm.
222                // because to maintain real-time, any received data should be processed
223                // as soon as possible.
224                if let Err(e) = socket.set_nodelay(true) {
225                    log::warn!("tls socket set nodelay failed!: addr={addr}, err={e}");
226                }
227
228                #[cfg(feature = "ssl")]
229                if let Some(acceptor) = acceptor.clone() {
230                    let tx = tx.clone();
231
232                    tokio::spawn(async move {
233                        if let Ok(socket) = acceptor.accept(socket).await {
234                            let _ = tx.send((
235                                TcpSocket::new(MaybeSslStream::Ssl(socket.into()), addr),
236                                addr,
237                            ));
238                        };
239                    });
240
241                    continue;
242                }
243
244                if tx
245                    .send((TcpSocket::new(MaybeSslStream::Base(socket), addr), addr))
246                    .is_err()
247                {
248                    break;
249                }
250            }
251        });
252
253        Ok(Self {
254            socket_receiver,
255            local_addr,
256        })
257    }
258
259    async fn accept(&mut self) -> Option<(Self::Socket, SocketAddr)> {
260        self.socket_receiver.recv().await
261    }
262
263    fn local_addr(&self) -> Result<SocketAddr> {
264        Ok(self.local_addr)
265    }
266}