turn_server/server/provider/
udp.rs1use std::{io::ErrorKind, net::SocketAddr, ops::DerefMut, sync::Arc, task::Poll};
2
3use ahash::{HashMap, HashMapExt};
4use anyhow::{Result, anyhow};
5use tokio::{
6 net::UdpSocket,
7 sync::mpsc::{
8 Receiver, Sender, UnboundedReceiver, UnboundedSender, channel, unbounded_channel,
9 },
10};
11
12use crate::server::{
13 memory_pool::{Buffer, MemoryPool},
14 provider::{ProviderServer, ProviderStream, ServerOptions},
15};
16
17pub struct UdpSession {
18 close_signal_sender: UnboundedSender<SocketAddr>,
19 bytes_receiver: Receiver<Buffer>,
20 socket: Arc<UdpSocket>,
21 addr: SocketAddr,
22}
23
24impl ProviderStream for UdpSession {
25 async fn read(&mut self) -> Result<Buffer> {
26 self.bytes_receiver
27 .recv()
28 .await
29 .ok_or_else(|| anyhow!("channel closed"))
30 }
31
32 async fn write(&mut self, buffer: &[u8]) -> Result<()> {
33 if let Err(e) = self.socket.send_to(buffer, self.addr).await {
34 if e.kind() != ErrorKind::ConnectionReset {
38 return Err(e.into());
39 }
40 }
41
42 Ok(())
43 }
44
45 async fn close(&mut self) {
46 self.bytes_receiver.close();
47
48 let _ = self.close_signal_sender.send(self.addr);
49 }
50}
51
52pub struct UdpServer {
53 receiver: UnboundedReceiver<UdpSession>,
54 socket: Arc<UdpSocket>,
55}
56
57impl ProviderServer for UdpServer {
58 type Stream = UdpSession;
59
60 async fn bind(options: &ServerOptions) -> Result<Self> {
61 let socket = Arc::new(UdpSocket::bind(options.listen).await?);
62 let (socket_sender, socket_receiver) = unbounded_channel::<UdpSession>();
63 let (close_signal_sender, mut close_signal_receiver) = unbounded_channel::<SocketAddr>();
64
65 {
66 let socket = socket.clone();
67
68 tokio::spawn(async move {
69 let mut sockets = HashMap::<SocketAddr, Sender<Buffer>>::with_capacity(1024);
70
71 loop {
72 let mut buffer = MemoryPool::acquire();
73
74 tokio::select! {
75 ret = socket.recv_buf_from(buffer.deref_mut()) => {
76 let (size, addr) = match ret {
77 Ok(it) => it,
78 Err(e) => {
82 if e.kind() != ErrorKind::ConnectionReset {
83 log::error!("udp server recv_from error={e}");
84
85 break;
86 } else {
87 continue;
88 }
89 }
90 };
91
92 if size < 4 {
93 continue;
94 }
95
96 if let Some(stream) = sockets.get(&addr) {
97 if stream.try_send(buffer).is_err()
98 {
99 sockets.remove(&addr);
100 }
101 } else {
102 let (tx, bytes_receiver) = channel::<Buffer>(100);
103
104 if tx.try_send(buffer).is_err() {
106 continue;
107 }
108
109 sockets.insert(addr, tx);
110
111 if socket_sender
112 .send(UdpSession {
113 close_signal_sender: close_signal_sender.clone(),
114 socket: socket.clone(),
115 bytes_receiver,
116 addr,
117 })
118 .is_err()
119 {
120 break;
121 }
122 }
123 }
124 Some(addr) = close_signal_receiver.recv() => {
125 let _ = sockets.remove(&addr);
126 }
127 else => {
128 break;
129 }
130 }
131 }
132 });
133 }
134
135 Ok(Self {
136 receiver: socket_receiver,
137 socket,
138 })
139 }
140
141 async fn accept(&mut self) -> Result<Poll<(UdpSession, SocketAddr)>> {
142 let socket = self
143 .receiver
144 .recv()
145 .await
146 .ok_or_else(|| anyhow!("channel closed"))?;
147
148 let addr = socket.addr;
149
150 Ok(Poll::Ready((socket, addr)))
151 }
152
153 fn local_addr(&self) -> Result<SocketAddr> {
154 Ok(self.socket.local_addr()?)
155 }
156}