// turn_server/server/transport/tcp.rs
tcp.rs1use std::{io::Error, net::SocketAddr};
2
3#[cfg(feature = "ssl")]
4use std::sync::Arc;
5
6use anyhow::Result;
7use bytes::{Bytes, BytesMut};
8use tokio::{
9 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
10 net::{TcpListener as TokioTcpListener, TcpStream},
11 sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel},
12};
13
14#[cfg(feature = "ssl")]
15use tokio_rustls::{
16 TlsAcceptor,
17 rustls::{
18 ServerConfig,
19 pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
20 },
21 server::TlsStream,
22};
23
24use crate::{
25 codec::Decoder,
26 server::transport::{MAX_MESSAGE_SIZE, Server, ServerOptions, Socket},
27};
28
29enum MaybeSslStream {
30 #[cfg(feature = "ssl")]
31 Ssl(Box<TlsStream<TcpStream>>),
32 Base(TcpStream),
33}
34
35impl MaybeSslStream {
36 fn split(self) -> (Reader, Writer) {
37 use tokio::io::split;
38
39 match self {
40 Self::Base(it) => {
41 let (rx, tx) = split(it);
42
43 (Reader::Base(rx), Writer::Base(tx))
44 }
45 #[cfg(feature = "ssl")]
46 Self::Ssl(it) => {
47 let (rx, tx) = split(it);
48
49 (Reader::Ssl(rx), Writer::Ssl(tx))
50 }
51 }
52 }
53}
54
55enum Reader {
56 #[cfg(feature = "ssl")]
57 Ssl(ReadHalf<Box<TlsStream<TcpStream>>>),
58 Base(ReadHalf<TcpStream>),
59}
60
61impl Reader {
62 async fn read_buf(&mut self, buffer: &mut BytesMut) -> Result<usize, Error> {
63 match self {
64 Self::Base(it) => it.read_buf(buffer).await,
65 #[cfg(feature = "ssl")]
66 Self::Ssl(it) => it.read_buf(buffer).await,
67 }
68 }
69}
70
71enum Writer {
72 #[cfg(feature = "ssl")]
73 Ssl(WriteHalf<Box<TlsStream<TcpStream>>>),
74 Base(WriteHalf<TcpStream>),
75}
76
77impl Writer {
78 async fn write_all(&mut self, buffer: &[u8]) -> Result<(), Error> {
79 match self {
80 Self::Base(it) => it.write_all(buffer).await,
81 #[cfg(feature = "ssl")]
82 Self::Ssl(it) => it.write_all(buffer).await,
83 }
84 }
85}
86
87pub struct TcpSocket {
88 writer: Writer,
89 receiver: UnboundedReceiver<Bytes>,
90 close_signal_sender: Sender<()>,
91}
92
93impl TcpSocket {
94 fn new(stream: MaybeSslStream, addr: SocketAddr) -> Self {
95 let (close_signal_sender, mut close_signal_receiver) = channel::<()>(1);
96 let (tx, receiver) = unbounded_channel::<Bytes>();
97 let (mut reader, writer) = stream.split();
98
99 tokio::spawn(async move {
100 let mut buffer = BytesMut::new();
101
102 'a: loop {
103 tokio::select! {
104 Ok(size) = reader.read_buf(&mut buffer) => {
105 if size == 0 {
106 break;
107 }
108
109 if buffer.len() < 4 {
112 continue;
113 }
114
115 if buffer.len() > MAX_MESSAGE_SIZE * 3 {
118 break;
119 }
120
121 loop {
122 if buffer.len() <= 4 {
123 break;
124 }
125
126 let size = match Decoder::message_size(&buffer, true) {
131 Err(_) => break,
132 Ok(size) => {
133 if size > MAX_MESSAGE_SIZE {
134 log::warn!(
135 "tcp message size too large: \
136 size={size}, \
137 max={MAX_MESSAGE_SIZE}, \
138 addr={addr:?}"
139 );
140
141 break 'a;
142 }
143
144 if size > buffer.len() {
145 break;
146 }
147
148 size
149 }
150 };
151
152 if tx.send(buffer.split_to(size).freeze()).is_err() {
153 break 'a;
154 }
155 }
156 }
157 _ = close_signal_receiver.recv() => {
158 break;
159 }
160 else => {
161 break;
162 }
163 }
164 }
165 });
166
167 Self {
168 close_signal_sender,
169 writer,
170 receiver,
171 }
172 }
173}
174
175impl Socket for TcpSocket {
176 async fn read(&mut self) -> Option<Bytes> {
177 self.receiver.recv().await
178 }
179
180 async fn write(&mut self, buffer: &[u8]) -> Result<()> {
181 Ok(self.writer.write_all(buffer).await?)
182 }
183
184 async fn close(&mut self) {
185 self.receiver.close();
186
187 let _ = self.close_signal_sender.send(()).await;
188 }
189}
190
191pub struct TcpServer {
192 socket_receiver: UnboundedReceiver<(TcpSocket, SocketAddr)>,
193 local_addr: SocketAddr,
194}
195
196impl Server for TcpServer {
197 type Socket = TcpSocket;
198
199 async fn bind(options: &ServerOptions) -> Result<Self> {
200 #[cfg(feature = "ssl")]
201 let acceptor = if let Some(ssl) = &options.ssl {
202 Some(TlsAcceptor::from(Arc::new(
203 ServerConfig::builder()
204 .with_no_client_auth()
205 .with_single_cert(
206 CertificateDer::pem_file_iter(ssl.certificate_chain.clone())?
207 .collect::<Result<Vec<_>, _>>()?,
208 PrivateKeyDer::from_pem_file(ssl.private_key.clone())?,
209 )?,
210 )))
211 } else {
212 None
213 };
214
215 let listener = TokioTcpListener::bind(options.listen).await?;
216 let local_addr = listener.local_addr()?;
217
218 let (tx, socket_receiver) = unbounded_channel::<(TcpSocket, SocketAddr)>();
219 tokio::spawn(async move {
220 while let Ok((socket, addr)) = listener.accept().await {
221 if let Err(e) = socket.set_nodelay(true) {
225 log::warn!("tls socket set nodelay failed!: addr={addr}, err={e}");
226 }
227
228 #[cfg(feature = "ssl")]
229 if let Some(acceptor) = acceptor.clone() {
230 let tx = tx.clone();
231
232 tokio::spawn(async move {
233 if let Ok(socket) = acceptor.accept(socket).await {
234 let _ = tx.send((
235 TcpSocket::new(MaybeSslStream::Ssl(socket.into()), addr),
236 addr,
237 ));
238 };
239 });
240
241 continue;
242 }
243
244 if tx
245 .send((TcpSocket::new(MaybeSslStream::Base(socket), addr), addr))
246 .is_err()
247 {
248 break;
249 }
250 }
251 });
252
253 Ok(Self {
254 socket_receiver,
255 local_addr,
256 })
257 }
258
259 async fn accept(&mut self) -> Option<(Self::Socket, SocketAddr)> {
260 self.socket_receiver.recv().await
261 }
262
263 fn local_addr(&self) -> Result<SocketAddr> {
264 Ok(self.local_addr)
265 }
266}