Skip to main content

turn_server/server/provider/
udp.rs

1use std::{io::ErrorKind, net::SocketAddr, ops::DerefMut, sync::Arc, task::Poll};
2
3use ahash::{HashMap, HashMapExt};
4use anyhow::{Result, anyhow};
5use tokio::{
6    net::UdpSocket,
7    sync::mpsc::{
8        Receiver, Sender, UnboundedReceiver, UnboundedSender, channel, unbounded_channel,
9    },
10};
11
12use crate::server::{
13    memory_pool::{Buffer, MemoryPool},
14    provider::{ProviderServer, ProviderStream, ServerOptions},
15};
16
17pub struct UdpSession {
18    close_signal_sender: UnboundedSender<SocketAddr>,
19    bytes_receiver: Receiver<Buffer>,
20    socket: Arc<UdpSocket>,
21    addr: SocketAddr,
22}
23
24impl ProviderStream for UdpSession {
25    async fn read(&mut self) -> Result<Buffer> {
26        self.bytes_receiver
27            .recv()
28            .await
29            .ok_or_else(|| anyhow!("channel closed"))
30    }
31
32    async fn write(&mut self, buffer: &[u8]) -> Result<()> {
33        if let Err(e) = self.socket.send_to(buffer, self.addr).await {
34            // Note: An error will also be reported when the remote host is
35            // shut down, which is not processed yet, but a
36            // warning will be issued.
37            if e.kind() != ErrorKind::ConnectionReset {
38                return Err(e.into());
39            }
40        }
41
42        Ok(())
43    }
44
45    async fn close(&mut self) {
46        self.bytes_receiver.close();
47
48        let _ = self.close_signal_sender.send(self.addr);
49    }
50}
51
52pub struct UdpServer {
53    receiver: UnboundedReceiver<UdpSession>,
54    socket: Arc<UdpSocket>,
55}
56
57impl ProviderServer for UdpServer {
58    type Stream = UdpSession;
59
60    async fn bind(options: &ServerOptions) -> Result<Self> {
61        let socket = Arc::new(UdpSocket::bind(options.listen).await?);
62        let (socket_sender, socket_receiver) = unbounded_channel::<UdpSession>();
63        let (close_signal_sender, mut close_signal_receiver) = unbounded_channel::<SocketAddr>();
64
65        {
66            let socket = socket.clone();
67
68            tokio::spawn(async move {
69                let mut sockets = HashMap::<SocketAddr, Sender<Buffer>>::with_capacity(1024);
70
71                loop {
72                    let mut buffer = MemoryPool::acquire();
73
74                    tokio::select! {
75                        ret = socket.recv_buf_from(buffer.deref_mut()) => {
76                            let (size, addr) = match ret {
77                                Ok(it) => it,
78                                // Note: An error will also be reported when the remote host is
79                                // shut down, which is not processed yet, but a
80                                // warning will be issued.
81                                Err(e) => {
82                                    if e.kind() != ErrorKind::ConnectionReset {
83                                        log::error!("udp server recv_from error={e}");
84
85                                        break;
86                                    } else {
87                                        continue;
88                                    }
89                                }
90                            };
91
92                            if size < 4 {
93                                continue;
94                            }
95
96                            if let Some(stream) = sockets.get(&addr) {
97                                if stream.try_send(buffer).is_err()
98                                {
99                                    sockets.remove(&addr);
100                                }
101                            } else {
102                                let (tx, bytes_receiver) = channel::<Buffer>(100);
103
104                                // Send the first packet to the new socket
105                                if tx.try_send(buffer).is_err() {
106                                    continue;
107                                }
108
109                                sockets.insert(addr, tx);
110
111                                if socket_sender
112                                    .send(UdpSession {
113                                        close_signal_sender: close_signal_sender.clone(),
114                                        socket: socket.clone(),
115                                        bytes_receiver,
116                                        addr,
117                                    })
118                                    .is_err()
119                                {
120                                    break;
121                                }
122                            }
123                        }
124                        Some(addr) = close_signal_receiver.recv() => {
125                            let _ = sockets.remove(&addr);
126                        }
127                        else => {
128                            break;
129                        }
130                    }
131                }
132            });
133        }
134
135        Ok(Self {
136            receiver: socket_receiver,
137            socket,
138        })
139    }
140
141    async fn accept(&mut self) -> Result<Poll<(UdpSession, SocketAddr)>> {
142        let socket = self
143            .receiver
144            .recv()
145            .await
146            .ok_or_else(|| anyhow!("channel closed"))?;
147
148        let addr = socket.addr;
149
150        Ok(Poll::Ready((socket, addr)))
151    }
152
153    fn local_addr(&self) -> Result<SocketAddr> {
154        Ok(self.socket.local_addr()?)
155    }
156}