turn_server/server/provider/
tcp.rs1use std::{net::SocketAddr, task::Poll};
2
3#[cfg(feature = "ssl")]
4use std::sync::Arc;
5
use anyhow::{Context, Result, anyhow};
7use tokio::{
8 io::{AsyncReadExt, AsyncWriteExt},
9 net::{TcpListener, TcpStream},
10};
11
12#[cfg(feature = "ssl")]
13use tokio_rustls::{
14 TlsAcceptor,
15 rustls::{
16 ServerConfig,
17 pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
18 },
19 server::TlsStream,
20};
21
22use crate::{
23 codec::Decoder,
24 server::{
25 memory_pool::{Buffer, MemoryPool},
26 provider::{ProviderServer, ProviderStream, ServerOptions},
27 },
28};
29
30pub enum MaybeSslStream {
31 Base(TcpStream),
32 #[cfg(feature = "ssl")]
33 Ssl(TlsStream<TcpStream>),
34}
35
36impl ProviderStream for MaybeSslStream {
37 async fn read(&mut self) -> Result<Buffer> {
38 let mut buffer = MemoryPool::acquire();
39
40 unsafe {
41 buffer.set_len(4);
42 }
43
44 let size = {
45 if match self {
46 #[cfg(feature = "ssl")]
47 Self::Ssl(stream) => stream.read_exact(&mut buffer[..4]).await?,
48 Self::Base(stream) => stream.read_exact(&mut buffer[..4]).await?,
49 } < 4
50 {
51 return Err(anyhow!("failed to read the first 4 bytes of the message"));
52 }
53
54 Decoder::message_size(&buffer[..4], true)?
55 };
56
57 if size > MemoryPool::MAX_MESSAGE_SIZE {
59 return Err(anyhow!(
60 "message size {} exceeds the maximum allowed size",
61 size
62 ));
63 }
64
65 unsafe {
72 buffer.set_len(size);
73 }
74
75 if match self {
77 #[cfg(feature = "ssl")]
78 Self::Ssl(stream) => stream.read_exact(&mut buffer[4..size]).await?,
79 Self::Base(stream) => stream.read_exact(&mut buffer[4..size]).await?,
80 } < size - 4
81 {
82 return Err(anyhow!("failed to read the full message"));
83 }
84
85 Ok(buffer)
86 }
87
88 async fn write(&mut self, buffer: &[u8]) -> Result<()> {
89 match self {
90 #[cfg(feature = "ssl")]
91 Self::Ssl(stream) => stream.write_all(buffer).await?,
92 Self::Base(stream) => stream.write_all(buffer).await?,
93 }
94
95 Ok(())
96 }
97
98 async fn close(&mut self) {
99 match self {
100 #[cfg(feature = "ssl")]
101 Self::Ssl(stream) => {
102 let _ = stream.shutdown().await;
103 }
104 Self::Base(stream) => {
105 let _ = stream.shutdown().await;
106 }
107 }
108 }
109}
110
111pub struct TcpServer {
112 listener: TcpListener,
113 local_addr: SocketAddr,
114 #[cfg(feature = "ssl")]
115 acceptor: Option<TlsAcceptor>,
116}
117
118impl ProviderServer for TcpServer {
119 type Stream = MaybeSslStream;
120
121 async fn bind(options: &ServerOptions) -> Result<Self> {
122 #[cfg(feature = "ssl")]
123 let acceptor = if let Some(ssl) = &options.ssl {
124 Some(TlsAcceptor::from(Arc::new(
125 ServerConfig::builder()
126 .with_no_client_auth()
127 .with_single_cert(
128 CertificateDer::pem_file_iter(ssl.certificate_chain.clone())?
129 .collect::<Result<Vec<_>, _>>()?,
130 PrivateKeyDer::from_pem_file(ssl.private_key.clone())?,
131 )?,
132 )))
133 } else {
134 None
135 };
136
137 let listener = TcpListener::bind(options.listen).await?;
138 let local_addr = listener.local_addr()?;
139
140 Ok(Self {
141 listener,
142 local_addr,
143 #[cfg(feature = "ssl")]
144 acceptor,
145 })
146 }
147
148 async fn accept(&mut self) -> Result<Poll<(Self::Stream, SocketAddr)>> {
149 let (socket, addr) = self.listener.accept().await?;
150
151 if let Err(e) = socket.set_nodelay(true) {
155 log::warn!("tls socket set nodelay failed!: addr={addr}, err={e}");
156 }
157
158 #[cfg(feature = "ssl")]
159 if let Some(ref acceptor) = self.acceptor {
160 return Ok(if let Ok(socket) = acceptor.accept(socket).await {
161 Poll::Ready((MaybeSslStream::Ssl(socket), addr))
162 } else {
163 Poll::Pending
164 });
165 }
166
167 Ok(Poll::Ready((MaybeSslStream::Base(socket), addr)))
168 }
169
170 fn local_addr(&self) -> Result<SocketAddr> {
171 Ok(self.local_addr)
172 }
173}