Skip to main content

turn_server/server/provider/
tcp.rs

1use std::{net::SocketAddr, task::Poll};
2
3#[cfg(feature = "ssl")]
4use std::sync::Arc;
5
6use anyhow::{Result, anyhow};
7use tokio::{
8    io::{AsyncReadExt, AsyncWriteExt},
9    net::{TcpListener, TcpStream},
10};
11
12#[cfg(feature = "ssl")]
13use tokio_rustls::{
14    TlsAcceptor,
15    rustls::{
16        ServerConfig,
17        pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
18    },
19    server::TlsStream,
20};
21
22use crate::{
23    codec::Decoder,
24    server::{
25        memory_pool::{Buffer, MemoryPool},
26        provider::{ProviderServer, ProviderStream, ServerOptions},
27    },
28};
29
30pub enum MaybeSslStream {
31    Base(TcpStream),
32    #[cfg(feature = "ssl")]
33    Ssl(TlsStream<TcpStream>),
34}
35
36impl ProviderStream for MaybeSslStream {
37    async fn read(&mut self) -> Result<Buffer> {
38        let mut buffer = MemoryPool::acquire();
39
40        unsafe {
41            buffer.set_len(4);
42        }
43
44        let size = {
45            if match self {
46                #[cfg(feature = "ssl")]
47                Self::Ssl(stream) => stream.read_exact(&mut buffer[..4]).await?,
48                Self::Base(stream) => stream.read_exact(&mut buffer[..4]).await?,
49            } < 4
50            {
51                return Err(anyhow!("failed to read the first 4 bytes of the message"));
52            }
53
54            Decoder::message_size(&buffer[..4], true)?
55        };
56
57        // The buffer is resized to the actual size of the message, which is determined by the first 4 bytes of the message.
58        if size > MemoryPool::MAX_MESSAGE_SIZE {
59            return Err(anyhow!(
60                "message size {} exceeds the maximum allowed size",
61                size
62            ));
63        }
64
65        // SAFETY: The buffer is initialized with zeroes and the length is set to
66        // the actual size of the message, which is determined by the first 4
67        // bytes of the message.
68        //
69        // The buffer is not used until it is fully initialized, so it is safe to
70        // set the length after reading the message.
71        unsafe {
72            buffer.set_len(size);
73        }
74
75        // Read the rest of the message based on the size determined by the first 4 bytes.
76        if match self {
77            #[cfg(feature = "ssl")]
78            Self::Ssl(stream) => stream.read_exact(&mut buffer[4..size]).await?,
79            Self::Base(stream) => stream.read_exact(&mut buffer[4..size]).await?,
80        } < size - 4
81        {
82            return Err(anyhow!("failed to read the full message"));
83        }
84
85        Ok(buffer)
86    }
87
88    async fn write(&mut self, buffer: &[u8]) -> Result<()> {
89        match self {
90            #[cfg(feature = "ssl")]
91            Self::Ssl(stream) => stream.write_all(buffer).await?,
92            Self::Base(stream) => stream.write_all(buffer).await?,
93        }
94
95        Ok(())
96    }
97
98    async fn close(&mut self) {
99        match self {
100            #[cfg(feature = "ssl")]
101            Self::Ssl(stream) => {
102                let _ = stream.shutdown().await;
103            }
104            Self::Base(stream) => {
105                let _ = stream.shutdown().await;
106            }
107        }
108    }
109}
110
111pub struct TcpServer {
112    listener: TcpListener,
113    local_addr: SocketAddr,
114    #[cfg(feature = "ssl")]
115    acceptor: Option<TlsAcceptor>,
116}
117
118impl ProviderServer for TcpServer {
119    type Stream = MaybeSslStream;
120
121    async fn bind(options: &ServerOptions) -> Result<Self> {
122        #[cfg(feature = "ssl")]
123        let acceptor = if let Some(ssl) = &options.ssl {
124            Some(TlsAcceptor::from(Arc::new(
125                ServerConfig::builder()
126                    .with_no_client_auth()
127                    .with_single_cert(
128                        CertificateDer::pem_file_iter(ssl.certificate_chain.clone())?
129                            .collect::<Result<Vec<_>, _>>()?,
130                        PrivateKeyDer::from_pem_file(ssl.private_key.clone())?,
131                    )?,
132            )))
133        } else {
134            None
135        };
136
137        let listener = TcpListener::bind(options.listen).await?;
138        let local_addr = listener.local_addr()?;
139
140        Ok(Self {
141            listener,
142            local_addr,
143            #[cfg(feature = "ssl")]
144            acceptor,
145        })
146    }
147
148    async fn accept(&mut self) -> Result<Poll<(Self::Stream, SocketAddr)>> {
149        let (socket, addr) = self.listener.accept().await?;
150
151        // Disable the Nagle algorithm.
152        // because to maintain real-time, any received data should be processed
153        // as soon as possible.
154        if let Err(e) = socket.set_nodelay(true) {
155            log::warn!("tls socket set nodelay failed!: addr={addr}, err={e}");
156        }
157
158        #[cfg(feature = "ssl")]
159        if let Some(ref acceptor) = self.acceptor {
160            return Ok(if let Ok(socket) = acceptor.accept(socket).await {
161                Poll::Ready((MaybeSslStream::Ssl(socket), addr))
162            } else {
163                Poll::Pending
164            });
165        }
166
167        Ok(Poll::Ready((MaybeSslStream::Base(socket), addr)))
168    }
169
170    fn local_addr(&self) -> Result<SocketAddr> {
171        Ok(self.local_addr)
172    }
173}