turn_server/server/transport/
udp.rs1use std::{io::ErrorKind, net::SocketAddr, sync::Arc};
2
3use ahash::{HashMap, HashMapExt};
4use anyhow::Result;
5use bytes::{Bytes, BytesMut};
6use tokio::{
7 net::UdpSocket as TokioUdpSocket,
8 sync::mpsc::{
9 Receiver, Sender, UnboundedReceiver, UnboundedSender, channel, unbounded_channel,
10 },
11};
12
13use crate::server::transport::{Server, ServerOptions, Socket};
14
15pub struct UdpSocket {
16 close_signal_sender: UnboundedSender<SocketAddr>,
17 bytes_receiver: Receiver<Bytes>,
18 socket: Arc<TokioUdpSocket>,
19 addr: SocketAddr,
20}
21
22impl Socket for UdpSocket {
23 async fn read(&mut self) -> Option<Bytes> {
24 self.bytes_receiver.recv().await
25 }
26
27 async fn write(&mut self, buffer: &[u8]) -> Result<()> {
28 if let Err(e) = self.socket.send_to(buffer, self.addr).await {
29 if e.kind() != ErrorKind::ConnectionReset {
33 return Err(e.into());
34 }
35 }
36
37 Ok(())
38 }
39
40 async fn close(&mut self) {
41 self.bytes_receiver.close();
42
43 let _ = self.close_signal_sender.send(self.addr);
44 }
45}
46
47pub struct UdpServer {
48 receiver: UnboundedReceiver<(UdpSocket, SocketAddr)>,
49 socket: Arc<TokioUdpSocket>,
50}
51
52impl Server for UdpServer {
53 type Socket = UdpSocket;
54
55 async fn bind(options: &ServerOptions) -> Result<Self> {
56 let socket = Arc::new(TokioUdpSocket::bind(options.listen).await?);
57 let (socket_sender, socket_receiver) = unbounded_channel::<(UdpSocket, SocketAddr)>();
58 let (close_signal_sender, mut close_signal_receiver) = unbounded_channel::<SocketAddr>();
59
60 {
61 let socket = socket.clone();
62
63 let mut buffer = BytesMut::zeroed(options.mtu);
64
65 tokio::spawn(async move {
66 let mut sockets = HashMap::<SocketAddr, Sender<Bytes>>::with_capacity(1024);
67
68 loop {
69 tokio::select! {
70 ret = socket.recv_from(&mut buffer) => {
71 let (size, addr) = match ret {
72 Ok(it) => it,
73 Err(e) => {
77 if e.kind() != ErrorKind::ConnectionReset {
78 log::error!("udp server recv_from error={e}");
79
80 break;
81 } else {
82 continue;
83 }
84 }
85 };
86
87 if size < 4 {
88 continue;
89 }
90
91 if let Some(stream) = sockets.get(&addr) {
92 if stream.try_send(Bytes::copy_from_slice(&buffer[..size])).is_err()
93 {
94 sockets.remove(&addr);
95 }
96 } else {
97 let (tx, bytes_receiver) = channel::<Bytes>(100);
98
99 if tx.try_send(Bytes::copy_from_slice(&buffer[..size])).is_err() {
101 continue;
102 }
103
104 sockets.insert(addr, tx);
105
106 if socket_sender
107 .send((
108 UdpSocket {
109 close_signal_sender: close_signal_sender.clone(),
110 socket: socket.clone(),
111 bytes_receiver,
112 addr,
113 },
114 addr,
115 ))
116 .is_err()
117 {
118 break;
119 }
120 }
121 }
122 Some(addr) = close_signal_receiver.recv() => {
123 let _ = sockets.remove(&addr);
124 }
125 else => {
126 break;
127 }
128 }
129 }
130 });
131 }
132
133 Ok(Self {
134 receiver: socket_receiver,
135 socket,
136 })
137 }
138
139 async fn accept(&mut self) -> Option<(UdpSocket, SocketAddr)> {
140 self.receiver.recv().await
141 }
142
143 fn local_addr(&self) -> Result<SocketAddr> {
144 Ok(self.socket.local_addr()?)
145 }
146}