zlayer_proxy/stream/
udp.rs1use dashmap::DashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::net::UdpSocket;
11use tokio::sync::mpsc;
12
13use super::config::DEFAULT_UDP_SESSION_TIMEOUT;
14use super::registry::StreamRegistry;
15
16struct UdpSession {
18 backend: SocketAddr,
20 backend_socket: Arc<UdpSocket>,
22 last_activity: Instant,
24}
25
26pub struct UdpStreamService {
31 registry: Arc<StreamRegistry>,
32 listen_port: u16,
33 session_timeout: Duration,
34}
35
36impl UdpStreamService {
37 #[must_use]
39 pub fn new(
40 registry: Arc<StreamRegistry>,
41 listen_port: u16,
42 session_timeout: Option<Duration>,
43 ) -> Self {
44 Self {
45 registry,
46 listen_port,
47 session_timeout: session_timeout.unwrap_or(DEFAULT_UDP_SESSION_TIMEOUT),
48 }
49 }
50
51 #[must_use]
53 pub fn port(&self) -> u16 {
54 self.listen_port
55 }
56
57 #[must_use]
59 pub fn session_timeout(&self) -> Duration {
60 self.session_timeout
61 }
62
63 #[must_use]
65 pub fn registry(&self) -> &Arc<StreamRegistry> {
66 &self.registry
67 }
68
69 pub async fn run(self: Arc<Self>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
79 let listen_addr = format!("0.0.0.0:{}", self.listen_port);
81 let socket = UdpSocket::bind(&listen_addr).await?;
82
83 tracing::info!(port = self.listen_port, "UDP stream proxy listening");
84
85 self.serve(socket).await
86 }
87
88 #[allow(clippy::too_many_lines)]
101 pub async fn serve(
102 self: Arc<Self>,
103 socket: UdpSocket,
104 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
105 let socket = Arc::new(socket);
106
107 tracing::info!(
108 port = self.listen_port,
109 "UDP stream proxy serving (standalone)"
110 );
111
112 let sessions: Arc<DashMap<SocketAddr, UdpSession>> = Arc::new(DashMap::new());
114
115 let (response_tx, mut response_rx) = mpsc::channel::<(Vec<u8>, SocketAddr)>(4096);
117
118 let socket_for_responses = socket.clone();
120 tokio::spawn(async move {
121 while let Some((data, client_addr)) = response_rx.recv().await {
122 if let Err(e) = socket_for_responses.send_to(&data, client_addr).await {
123 tracing::debug!(
124 error = %e,
125 client = %client_addr,
126 "Failed to send UDP response to client"
127 );
128 }
129 }
130 });
131
132 let sessions_for_cleanup = sessions.clone();
134 let timeout = self.session_timeout;
135 tokio::spawn(async move {
136 let mut interval = tokio::time::interval(Duration::from_secs(10));
137 loop {
138 interval.tick().await;
139 let now = Instant::now();
140 let before = sessions_for_cleanup.len();
141 sessions_for_cleanup
142 .retain(|_, session| now.duration_since(session.last_activity) < timeout);
143 let after = sessions_for_cleanup.len();
144 if before != after {
145 tracing::debug!(
146 removed = before - after,
147 remaining = after,
148 "Cleaned up expired UDP sessions"
149 );
150 }
151 }
152 });
153
154 let mut buf = vec![0u8; 65535];
156 loop {
157 let (len, client_addr) = socket.recv_from(&mut buf).await?;
158 let data = buf[..len].to_vec();
159
160 let session_backend = if let Some(mut existing) = sessions.get_mut(&client_addr) {
162 existing.last_activity = Instant::now();
163 existing.backend
164 } else {
165 let Some(service) = self.registry.resolve_udp(self.listen_port) else {
167 tracing::warn!(
168 port = self.listen_port,
169 client = %client_addr,
170 "No service registered for UDP port"
171 );
172 continue;
173 };
174
175 let Some(backend) = service.select_backend() else {
176 tracing::warn!(
177 port = self.listen_port,
178 service = %service.name,
179 client = %client_addr,
180 "No backends available for UDP service"
181 );
182 continue;
183 };
184
185 let backend_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?);
187 backend_socket.connect(&backend).await?;
188
189 tracing::debug!(
190 port = self.listen_port,
191 service = %service.name,
192 client = %client_addr,
193 backend = %backend,
194 "Created new UDP session"
195 );
196
197 let backend_socket_recv = backend_socket.clone();
199 let response_tx = response_tx.clone();
200 let client = client_addr;
201 let sessions_ref = sessions.clone();
202 tokio::spawn(async move {
203 let mut buf = vec![0u8; 65535];
204 loop {
205 match backend_socket_recv.recv(&mut buf).await {
206 Ok(len) => {
207 if let Some(mut s) = sessions_ref.get_mut(&client) {
209 s.last_activity = Instant::now();
210 }
211 if response_tx
213 .send((buf[..len].to_vec(), client))
214 .await
215 .is_err()
216 {
217 break; }
219 }
220 Err(e) => {
221 tracing::debug!(
222 error = %e,
223 client = %client,
224 "Backend socket receive error"
225 );
226 break;
227 }
228 }
229 }
230 });
231
232 let session = UdpSession {
233 backend,
234 backend_socket,
235 last_activity: Instant::now(),
236 };
237 sessions.insert(client_addr, session);
238 backend
239 };
240
241 if let Some(s) = sessions.get(&client_addr) {
243 if let Err(e) = s.backend_socket.send(&data).await {
244 tracing::debug!(
245 error = %e,
246 client = %client_addr,
247 backend = %session_backend,
248 "Failed to forward UDP packet to backend"
249 );
250 }
251 }
252 }
253 }
254}