// turn_server/server.rs

1use crate::{
2    config::{Config, Interface},
3    router::Router,
4    statistics::Statistics,
5    turn::{Observer, Service},
6};
7
8use std::net::SocketAddr;
9
10#[allow(unused)]
11struct ServerStartOptions<T> {
12    bind: SocketAddr,
13    external: SocketAddr,
14    service: Service<T>,
15    router: Router,
16    statistics: Statistics,
17}
18
19#[allow(unused)]
20trait Server {
21    async fn start<T>(options: ServerStartOptions<T>) -> Result<(), anyhow::Error>
22    where
23        T: Clone + Observer + 'static;
24}
25
#[cfg(feature = "udp")]
mod udp {
    use super::{Server as ServerExt, ServerStartOptions};
    use crate::{
        statistics::Stats,
        stun::Transport,
        turn::{Observer, ResponseMethod, SessionAddr},
    };

    use std::{io::ErrorKind::ConnectionReset, sync::Arc};

    use tokio::net::UdpSocket;

    /// UDP server for a single configured interface.
    ///
    /// `start` binds one socket and spawns two tasks that share it:
    ///
    /// * a receive task that reads datagrams, hands them to the TURN
    ///   operationer, and either answers directly from this socket or passes
    ///   the result to the router for another interface;
    /// * a forward task that sends out the packets the router delivers for
    ///   this interface.
    pub struct Server;

    impl ServerExt for Server {
        async fn start<T>(
            ServerStartOptions {
                bind,
                external,
                service,
                router,
                statistics,
            }: ServerStartOptions<T>,
        ) -> Result<(), anyhow::Error>
        where
            T: Clone + Observer + 'static,
        {
            let socket = Arc::new(UdpSocket::bind(bind).await?);
            let local_addr = socket.local_addr()?;

            // Receive task.
            {
                let socket = socket.clone();
                let router = router.clone();
                let reporter = statistics.get_reporter(Transport::UDP);
                let mut operationer = service.get_operationer(external, external);

                let mut session_addr = SessionAddr {
                    address: external,
                    interface: external,
                };

                tokio::spawn(async move {
                    let mut buf = vec![0u8; 2048];

                    loop {
                        let (size, addr) = match socket.recv_from(&mut buf).await {
                            Ok(received) => received,
                            // A remote host shutting down can surface here as
                            // ConnectionReset. That concerns a single peer, not the
                            // shared socket, so keep serving everyone else.
                            Err(e) if e.kind() == ConnectionReset => continue,
                            Err(e) => {
                                log::error!("udp socket recv failed: interface={:?}, err={}", local_addr, e);
                                break;
                            }
                        };

                        session_addr.address = addr;
                        reporter.send(&session_addr, &[Stats::ReceivedBytes(size), Stats::ReceivedPkts(1)]);

                        // The smallest message (ChannelData without payload) is 4
                        // bytes; anything shorter cannot be parsed.
                        if size < 4 {
                            continue;
                        }

                        let res = match operationer.route(&buf[..size], addr) {
                            Ok(Some(res)) => res,
                            _ => continue,
                        };

                        let target = res.relay.as_ref().unwrap_or(&addr);

                        // The response belongs to another interface: let the router
                        // deliver it from there.
                        if let Some(ref endpoint) = res.endpoint {
                            router.send(endpoint, res.method, target, res.bytes);
                            continue;
                        }

                        match socket.send_to(res.bytes, target).await {
                            // Only count packets that actually left the socket.
                            Ok(_) => {
                                reporter.send(&session_addr, &[Stats::SendBytes(res.bytes.len()), Stats::SendPkts(1)]);

                                if let ResponseMethod::Stun(method) = res.method {
                                    if method.is_error() {
                                        reporter.send(&session_addr, &[Stats::ErrorPkts(1)]);
                                    }
                                }
                            }
                            Err(e) if e.kind() == ConnectionReset => (),
                            Err(e) => {
                                log::error!("udp socket send failed: interface={:?}, err={}", local_addr, e);
                                break;
                            }
                        }
                    }
                });
            }

            // Forward task: send packets the router routes to this interface.
            tokio::spawn(async move {
                let mut session_addr = SessionAddr {
                    address: external,
                    interface: external,
                };

                let reporter = statistics.get_reporter(Transport::UDP);
                let mut receiver = router.get_receiver(external);
                while let Some((bytes, _, addr)) = receiver.recv().await {
                    session_addr.address = addr;

                    match socket.send_to(&bytes, addr).await {
                        Ok(_) => {
                            reporter.send(&session_addr, &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)]);
                        }
                        Err(e) if e.kind() == ConnectionReset => (),
                        Err(e) => {
                            log::error!("udp socket send failed: interface={:?}, err={}", local_addr, e);
                            break;
                        }
                    }
                }

                router.remove(&external);

                log::error!("udp server close: interface={:?}", local_addr);
            });

            log::info!(
                "turn server listening: bind={}, external={}, transport=UDP",
                bind,
                external,
            );

            Ok(())
        }
    }
}
155
#[cfg(feature = "tcp")]
mod tcp {
    use super::{Server as ServerExt, ServerStartOptions};
    use crate::{
        statistics::Stats,
        stun::{Decoder, Transport},
        turn::{Observer, ResponseMethod, SessionAddr},
    };

    use std::{
        ops::{Deref, DerefMut},
        sync::Arc,
    };

    use tokio::{io::AsyncReadExt, io::AsyncWriteExt, net::TcpListener, sync::Mutex};

    /// Size of each half of the read buffer, and therefore also the largest
    /// message accepted from a TCP peer.
    const BUFFER_SIZE: usize = 2048;

    /// Zero bytes used to pad ChannelData frames up to a 4-byte boundary.
    static ZERO_BYTES: [u8; 8] = [0u8; 8];

    /// A double ("ping-pong") buffer used to frame messages read from a TCP
    /// stream.
    ///
    /// Incoming bytes are appended to the active half until it holds at least
    /// one complete message. [`ExchangeBuffer::split`] then hands out that
    /// message as a slice of the active half, copies whatever follows it to the
    /// start of the idle half, and makes the idle half active. Only the
    /// unconsumed tail is copied, and no allocation happens after construction.
    struct ExchangeBuffer {
        /// The two halves, each paired with the number of valid bytes it holds.
        buffers: [(Vec<u8>, usize); 2],
        /// Index (0 or 1) of the active half.
        index: usize,
    }

    impl Default for ExchangeBuffer {
        fn default() -> Self {
            Self {
                index: 0,
                buffers: [(vec![0u8; BUFFER_SIZE], 0), (vec![0u8; BUFFER_SIZE], 0)],
            }
        }
    }

    impl Deref for ExchangeBuffer {
        type Target = [u8];

        /// The whole backing storage of the active half, including stale bytes
        /// past [`ExchangeBuffer::len`]; callers must bound reads with `len()`.
        fn deref(&self) -> &Self::Target {
            &self.buffers[self.index].0[..]
        }
    }

    impl DerefMut for ExchangeBuffer {
        /// The free space after the valid data, so a read appends instead of
        /// overwriting. Empty when the active half is full.
        fn deref_mut(&mut self) -> &mut Self::Target {
            let len = self.buffers[self.index].1;
            &mut self.buffers[self.index].0[len..]
        }
    }

    impl ExchangeBuffer {
        /// Number of valid bytes in the active half.
        fn len(&self) -> usize {
            self.buffers[self.index].1
        }

        /// Marks `len` more bytes as valid after writing through `DerefMut`.
        /// Unlike `BytesMut`, the cursor does not advance automatically.
        fn advance(&mut self, len: usize) {
            self.buffers[self.index].1 += len;
        }

        /// Takes the first `len` valid bytes out of the buffer.
        ///
        /// Any bytes after them are moved to the start of the other half, which
        /// becomes the active half. The returned slice stays valid until the
        /// buffer is next borrowed mutably.
        ///
        /// # Panics
        ///
        /// Panics if `len` is greater than [`ExchangeBuffer::len`].
        fn split(&mut self, len: usize) -> &[u8] {
            let current = self.index;
            let next = if current == 0 { 1 } else { 0 };
            let current_len = self.buffers[current].1;

            // The length of the separation cannot be greater than the length of the data.
            assert!(len <= current_len);

            // Length of the unconsumed tail that must survive the split.
            let remaining = current_len - len;

            {
                // Borrow both halves mutably at the same time so the tail can be
                // copied from the active half into the idle one without forging a
                // `&mut` out of a shared reference.
                let [first, second] = &mut self.buffers;
                let (active, idle) = if current == 0 { (first, second) } else { (second, first) };

                idle.0[..remaining].copy_from_slice(&active.0[len..current_len]);
                idle.1 = remaining;

                // The old half only backs the returned slice from now on.
                active.1 = 0;
            }

            self.index = next;
            &self.buffers[current].0[..len]
        }
    }

    /// TCP server for a single configured interface.
    ///
    /// Accepts connections on the listener. For every connection it spawns one
    /// task that frames and handles incoming messages, and one task that writes
    /// out packets the router forwards to that connection.
    pub struct Server;

    impl ServerExt for Server {
        async fn start<T>(
            ServerStartOptions {
                bind,
                external,
                service,
                router,
                statistics,
            }: ServerStartOptions<T>,
        ) -> Result<(), anyhow::Error>
        where
            T: Clone + Observer + 'static,
        {
            let listener = TcpListener::bind(bind).await?;
            let local_addr = listener.local_addr()?;

            tokio::spawn(async move {
                // Accept connections until the listener reports an error; that stops
                // this listener only (existing connections keep their own tasks).
                while let Ok((socket, address)) = listener.accept().await {
                    let router = router.clone();
                    let reporter = statistics.get_reporter(Transport::TCP);
                    let mut receiver = router.get_receiver(address);
                    let mut operationer = service.get_operationer(address, external);

                    log::info!("tcp socket accept: addr={:?}, interface={:?}", address, local_addr);

                    // Disable Nagle's algorithm: received data should be answered as
                    // soon as possible rather than coalesced.
                    if let Err(e) = socket.set_nodelay(true) {
                        log::error!("tcp socket set nodelay failed!: addr={}, err={}", address, e);
                    }

                    let session_addr = SessionAddr {
                        interface: external,
                        address,
                    };

                    let (mut reader, writer) = socket.into_split();
                    let writer = Arc::new(Mutex::new(writer));

                    // Writer task: deliver packets the router forwards to this socket.
                    {
                        let writer = writer.clone();
                        let reporter = reporter.clone();
                        tokio::spawn(async move {
                            while let Some((bytes, method, _)) = receiver.recv().await {
                                let mut writer = writer.lock().await;
                                if writer.write_all(bytes.as_slice()).await.is_err() {
                                    break;
                                }

                                reporter.send(&session_addr, &[Stats::SendBytes(bytes.len()), Stats::SendPkts(1)]);

                                // ChannelData must be aligned to a multiple of 4 bytes on
                                // TCP. Data relayed from UDP is not guaranteed to be
                                // aligned, so pad it here when needed.
                                if method == ResponseMethod::ChannelData {
                                    let pad = bytes.len() % 4;
                                    if pad > 0 && writer.write_all(&ZERO_BYTES[..(4 - pad)]).await.is_err() {
                                        break;
                                    }
                                }
                            }
                        });
                    }

                    // Reader task: frame incoming bytes into messages and handle them.
                    let sessions = service.get_sessions();
                    tokio::spawn(async move {
                        let mut buffer = ExchangeBuffer::default();

                        'connection: while let Ok(size) = reader.read(&mut buffer).await {
                            // A zero-length read means the socket has been closed (it
                            // would also happen if the buffer were completely full).
                            if size == 0 {
                                break;
                            }

                            reporter.send(&session_addr, &[Stats::ReceivedBytes(size)]);
                            buffer.advance(size);

                            // Handle every complete message currently buffered. The
                            // smallest message (ChannelData without payload) is 4 bytes.
                            while buffer.len() >= 4 {
                                // On a decode error, or when the message has not fully
                                // arrived yet, go back to reading more data.
                                let size = match Decoder::message_size(&buffer, true) {
                                    Err(_) => break,
                                    Ok(s) => s,
                                };

                                // A message larger than the buffer could never be
                                // completed; drop the connection instead of waiting.
                                if size > BUFFER_SIZE {
                                    break 'connection;
                                }

                                if size > buffer.len() {
                                    break;
                                }

                                reporter.send(&session_addr, &[Stats::ReceivedPkts(1)]);

                                let chunk = buffer.split(size);
                                let res = match operationer.route(chunk, address) {
                                    Err(_) => break 'connection,
                                    Ok(None) => continue,
                                    Ok(Some(res)) => res,
                                };

                                // The response belongs to another interface/connection.
                                if let Some(ref endpoint) = res.endpoint {
                                    router.send(
                                        endpoint,
                                        res.method,
                                        res.relay.as_ref().unwrap_or(&address),
                                        res.bytes,
                                    );
                                    continue;
                                }

                                if writer.lock().await.write_all(res.bytes).await.is_err() {
                                    break 'connection;
                                }

                                reporter.send(&session_addr, &[Stats::SendBytes(res.bytes.len()), Stats::SendPkts(1)]);

                                if let ResponseMethod::Stun(method) = res.method {
                                    if method.is_error() {
                                        reporter.send(&session_addr, &[Stats::ErrorPkts(1)]);
                                    }
                                }
                            }
                        }

                        // The TCP disconnect does not go through the session-closing
                        // procedure by itself, so run it once here by refreshing the
                        // session with 0 (presumably a zero lifetime).
                        sessions.refresh(&session_addr, 0);

                        router.remove(&address);

                        log::info!("tcp socket disconnect: addr={:?}, interface={:?}", address, local_addr);
                    });
                }

                log::error!("tcp server close: interface={:?}", local_addr);
            });

            log::info!(
                "turn server listening: bind={}, external={}, transport=TCP",
                bind,
                external,
            );

            Ok(())
        }
    }
}
450
451/// start turn server.
452///
453/// create a specified number of threads,
454/// each thread processes udp data separately.
455pub async fn start<T>(config: &Config, statistics: &Statistics, service: &Service<T>) -> anyhow::Result<()>
456where
457    T: Clone + Observer + 'static,
458{
459    #[allow(unused)]
460    use crate::config::Transport;
461
462    let router = Router::default();
463    for Interface {
464        transport,
465        external,
466        bind,
467    } in config.turn.interfaces.iter().cloned()
468    {
469        #[allow(unused)]
470        let options = ServerStartOptions {
471            statistics: statistics.clone(),
472            service: service.clone(),
473            router: router.clone(),
474            external,
475            bind,
476        };
477
478        match transport {
479            #[cfg(feature = "udp")]
480            Transport::UDP => udp::Server::start(options).await?,
481            #[cfg(feature = "tcp")]
482            Transport::TCP => tcp::Server::start(options).await?,
483            #[allow(unreachable_patterns)]
484            _ => (),
485        };
486    }
487
488    Ok(())
489}