Skip to main content

rust_web_server/udp_proxy/
mod.rs

1//! Layer-4 UDP proxy.
2//!
3//! [`UdpProxy`] receives UDP datagrams from clients and forwards each one to a
4//! backend server, then relays the backend's reply to the original sender.
5//! This request-reply model covers protocols such as DNS, NTP, and RADIUS.
6//!
7//! Each datagram is handled in its own thread, so the main `bind` loop is
8//! never blocked waiting for a backend reply.
9//!
10//! # Example
11//!
12//! ```rust,no_run
13//! use rust_web_server::udp_proxy::UdpProxy;
14//!
15//! // Forward DNS queries round-robin across two resolvers.
16//! UdpProxy::new(["8.8.8.8:53", "8.8.4.4:53"])
17//!     .reply_timeout_ms(2000)
18//!     .bind("0.0.0.0:53")
19//!     .unwrap();
20//! ```
21
22use std::net::{ToSocketAddrs, UdpSocket};
23use std::sync::{
24    Arc,
25    atomic::{AtomicUsize, Ordering},
26};
27use std::time::Duration;
28
/// Layer-4 (raw UDP) reverse proxy with round-robin load balancing.
///
/// Each client datagram is forwarded to one backend; the backend's reply is
/// delivered back to the originating client address. No session state is kept
/// between datagrams.
///
/// Call [`UdpProxy::bind`] to start. It blocks the calling thread indefinitely.
pub struct UdpProxy {
    /// Backend addresses in `"host:port"` form, visited in round-robin order.
    backends: Vec<String>,
    /// Running datagram counter; its value modulo `backends.len()` selects the backend.
    counter: Arc<AtomicUsize>,
    /// How long a worker waits for the backend's reply; zero means wait indefinitely.
    reply_timeout: Duration,
    /// Size in bytes of the buffers used for client datagrams and backend replies.
    buffer_size: usize,
}

impl UdpProxy {
    /// Create a proxy that distributes datagrams across `backends` in
    /// round-robin order. Each entry must be `"host:port"`.
    ///
    /// Defaults: 5 s reply timeout, 65 536-byte buffers.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        UdpProxy {
            backends: backends.into_iter().map(Into::into).collect(),
            counter: Arc::new(AtomicUsize::new(0)),
            reply_timeout: Duration::from_secs(5),
            buffer_size: 65536,
        }
    }

    /// Override the timeout waiting for a backend reply (default: 5 s).
    ///
    /// Passing `0` disables the timeout: the worker thread handling a datagram
    /// then waits for the backend's reply indefinitely.
    pub fn reply_timeout_ms(mut self, ms: u64) -> Self {
        self.reply_timeout = Duration::from_millis(ms);
        self
    }

    /// Override the per-datagram buffer size (default: 65 536 B).
    ///
    /// The same size is used for the client receive buffer and for the backend
    /// reply buffer.
    pub fn buffer_size(mut self, bytes: usize) -> Self {
        self.buffer_size = bytes;
        self
    }

    /// Bind on `addr` and start forwarding datagrams. Blocks indefinitely.
    ///
    /// Every received datagram is handed to a freshly spawned thread that
    /// forwards it to the next backend and relays the reply, so the receive
    /// loop never waits on a backend.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no backends are configured or if binding the listening
    /// socket on `addr` fails. Once the loop is running, per-datagram failures
    /// are logged to stderr and do not stop the proxy.
    pub fn bind(self, addr: &str) -> Result<(), String> {
        if self.backends.is_empty() {
            return Err("UdpProxy: no backends configured".to_string());
        }
        let socket = UdpSocket::bind(addr)
            .map_err(|e| format!("UdpProxy: bind on {} failed: {}", addr, e))?;
        println!("UdpProxy: listening on {}", addr);

        // One shared handle for all workers instead of duplicating the file
        // descriptor for every datagram.
        let socket = Arc::new(socket);
        // The payload is copied out before the next receive, so a single
        // receive buffer can be reused across iterations.
        let mut buf = vec![0u8; self.buffer_size];

        loop {
            let (n, client_addr) = match socket.recv_from(&mut buf) {
                Ok(v) => v,
                Err(e) => {
                    eprintln!("UdpProxy: recv_from error: {}", e);
                    continue;
                }
            };
            let packet = buf[..n].to_vec();
            let backend_addr = self.pick_backend().to_string();
            let listener = Arc::clone(&socket);
            let timeout = self.reply_timeout;
            let buf_size = self.buffer_size;

            std::thread::spawn(move || {
                Self::relay(&listener, client_addr, &packet, &backend_addr, timeout, buf_size);
            });
        }
    }

    /// Forward one datagram to `backend_addr` and relay the reply (if any) to
    /// `client_addr` through `listener`.
    ///
    /// Failures are logged to stderr; a reply timeout is dropped silently.
    fn relay(
        listener: &UdpSocket,
        client_addr: std::net::SocketAddr,
        packet: &[u8],
        backend_addr: &str,
        timeout: Duration,
        buf_size: usize,
    ) {
        let backend_sock_addr = match backend_addr.to_socket_addrs() {
            Ok(mut addrs) => match addrs.next() {
                Some(resolved) => resolved,
                None => {
                    eprintln!("UdpProxy: no address for {}", backend_addr);
                    return;
                }
            },
            Err(e) => {
                eprintln!("UdpProxy: DNS lookup for {} failed: {}", backend_addr, e);
                return;
            }
        };

        // The local wildcard must be of the same address family as the
        // backend, otherwise sending to an IPv6 backend fails.
        let backend = match UdpSocket::bind(Self::ephemeral_bind_addr(&backend_sock_addr)) {
            Ok(s) => s,
            Err(e) => {
                eprintln!("UdpProxy: ephemeral socket error: {}", e);
                return;
            }
        };

        // Connecting restricts received datagrams to the chosen backend, so
        // an unrelated sender cannot inject a forged reply.
        if let Err(e) = backend.connect(backend_sock_addr) {
            eprintln!("UdpProxy: connect to {} failed: {}", backend_addr, e);
            return;
        }

        // `set_read_timeout` rejects a zero duration; treat zero explicitly as
        // "no timeout" rather than ignoring the error.
        let read_timeout = if timeout.is_zero() { None } else { Some(timeout) };
        if let Err(e) = backend.set_read_timeout(read_timeout) {
            eprintln!("UdpProxy: setting reply timeout failed: {}", e);
            return;
        }

        if let Err(e) = backend.send(packet) {
            eprintln!("UdpProxy: send to {} failed: {}", backend_addr, e);
            return;
        }

        let mut reply = vec![0u8; buf_size];
        match backend.recv(&mut reply) {
            Ok(m) => {
                if let Err(e) = listener.send_to(&reply[..m], client_addr) {
                    eprintln!("UdpProxy: reply to {} failed: {}", client_addr, e);
                }
            }
            // timeout — backend didn't reply in time, drop silently
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                ) => {}
            Err(e) => {
                eprintln!("UdpProxy: backend reply error from {}: {}", backend_addr, e);
            }
        }
    }

    /// Wildcard address (port 0) of the same address family as `backend`.
    fn ephemeral_bind_addr(backend: &std::net::SocketAddr) -> &'static str {
        if backend.is_ipv6() {
            "[::]:0"
        } else {
            "0.0.0.0:0"
        }
    }

    /// Return the next backend in round-robin order.
    ///
    /// Must only be called when `backends` is non-empty (checked by `bind`);
    /// otherwise the modulo would divide by zero.
    fn pick_backend(&self) -> &str {
        let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
        &self.backends[i]
    }
}