Skip to main content

wireguard_netstack/
wireguard.rs

1//! WireGuard tunnel implementation using gotatun.
2//!
3//! This module wraps gotatun's `Tunn` struct and manages the UDP transport
4//! for sending/receiving encrypted WireGuard packets.
5
6use bytes::BytesMut;
7use gotatun::noise::rate_limiter::RateLimiter;
8use gotatun::noise::{Tunn, TunnResult};
9use gotatun::packet::Packet;
10use gotatun::x25519::{PublicKey, StaticSecret};
11use parking_lot::Mutex;
12use zerocopy::IntoBytes;
13use std::net::{Ipv4Addr, SocketAddr};
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::net::UdpSocket;
17use tokio::sync::mpsc;
18
19use crate::error::{Error, Result};
20
/// Configuration for the WireGuard tunnel.
///
/// Note: this struct carries raw private and (optional) preshared key
/// material, so take care not to log or persist it carelessly.
#[derive(Clone)]
pub struct WireGuardConfig {
    /// Our static X25519 private key (32 raw bytes).
    pub private_key: [u8; 32],
    /// Peer's static X25519 public key (32 raw bytes).
    pub peer_public_key: [u8; 32],
    /// Peer's UDP endpoint (IP:port). Encrypted datagrams are sent here, and
    /// datagrams arriving from any other address are ignored.
    pub peer_endpoint: SocketAddr,
    /// Our IPv4 address inside the tunnel.
    pub tunnel_ip: Ipv4Addr,
    /// Optional 32-byte preshared key for additional security.
    pub preshared_key: Option<[u8; 32]>,
    /// Persistent keepalive interval in seconds, passed straight to gotatun.
    ///
    /// `None` disables persistent keepalives. (Upstream boringtun also treats
    /// `Some(0)` as disabled; presumably gotatun does too — confirm.)
    pub keepalive_seconds: Option<u16>,
    /// MTU for the tunnel interface (defaults to 460 if `None`).
    pub mtu: Option<u16>,
}
39
40/// A WireGuard tunnel that encrypts/decrypts IP packets.
41pub struct WireGuardTunnel {
42    /// The underlying gotatun tunnel.
43    tunn: Mutex<Tunn>,
44    /// UDP socket for sending/receiving encrypted packets.
45    udp_socket: Arc<UdpSocket>,
46    /// Peer's endpoint address.
47    peer_endpoint: SocketAddr,
48    /// Our tunnel IP address.
49    tunnel_ip: Ipv4Addr,
50    /// MTU for the tunnel interface.
51    mtu: u16,
52    /// Channel to send received IP packets.
53    incoming_tx: mpsc::Sender<BytesMut>,
54    /// Channel to receive IP packets to send.
55    outgoing_rx: tokio::sync::Mutex<mpsc::Receiver<BytesMut>>,
56    /// Channel to receive incoming IP packets.
57    incoming_rx: Mutex<Option<mpsc::Receiver<BytesMut>>>,
58    /// Channel to send IP packets for encryption.
59    outgoing_tx: mpsc::Sender<BytesMut>,
60}
61
62impl WireGuardTunnel {
63    /// Create a new WireGuard tunnel with the given configuration.
64    pub async fn new(config: WireGuardConfig) -> Result<Arc<Self>> {
65        // Create the cryptographic keys
66        let private_key = StaticSecret::from(config.private_key);
67        let peer_public_key = PublicKey::from(config.peer_public_key);
68
69        // Create the tunnel
70        let tunn = Tunn::new(
71            private_key,
72            peer_public_key,
73            config.preshared_key,
74            config.keepalive_seconds,
75            rand::random::<u32>() >> 8, // Random index
76            Arc::new(RateLimiter::new(&peer_public_key, 0)),
77        );
78
79        // Bind UDP socket to any available port
80        let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
81
82        // Increase socket receive buffer to avoid packet loss
83        let sock_ref = socket2::SockRef::from(&udp_socket);
84        if let Err(e) = sock_ref.set_recv_buffer_size(1024 * 1024) {
85            // 1MB buffer
86            log::warn!("Failed to set UDP recv buffer size: {}", e);
87        }
88        if let Err(e) = sock_ref.set_send_buffer_size(1024 * 1024) {
89            // 1MB buffer
90            log::warn!("Failed to set UDP send buffer size: {}", e);
91        }
92        log::info!("UDP recv buffer size: {:?}", sock_ref.recv_buffer_size());
93        log::info!("UDP send buffer size: {:?}", sock_ref.send_buffer_size());
94
95        log::info!(
96            "WireGuard UDP socket bound to {}",
97            udp_socket.local_addr()?
98        );
99
100        // Create channels for packet communication
101        // incoming: packets received from the tunnel (decrypted)
102        // outgoing: packets to send through the tunnel (to be encrypted)
103        let (incoming_tx, incoming_rx) = mpsc::channel(256);
104        let (outgoing_tx, outgoing_rx) = mpsc::channel(256);
105
106        let tunnel = Arc::new(Self {
107            tunn: Mutex::new(tunn),
108            udp_socket: Arc::new(udp_socket),
109            peer_endpoint: config.peer_endpoint,
110            tunnel_ip: config.tunnel_ip,
111            mtu: config.mtu.unwrap_or(460), // Default MTU
112            incoming_tx,
113            incoming_rx: Mutex::new(Some(incoming_rx)),
114            outgoing_tx,
115            outgoing_rx: tokio::sync::Mutex::new(outgoing_rx),
116        });
117
118        Ok(tunnel)
119    }
120
121    /// Get our tunnel IP address.
122    pub fn tunnel_ip(&self) -> Ipv4Addr {
123        self.tunnel_ip
124    }
125
126    /// Get the MTU for the tunnel.
127    pub fn mtu(&self) -> u16 {
128        self.mtu
129    }
130
131    /// Get the sender for outgoing packets.
132    pub fn outgoing_sender(&self) -> mpsc::Sender<BytesMut> {
133        self.outgoing_tx.clone()
134    }
135
136    /// Get the receiver for incoming packets (takes ownership of the receiver).
137    pub fn take_incoming_receiver(&self) -> Option<mpsc::Receiver<BytesMut>> {
138        self.incoming_rx.lock().take()
139    }
140
141    /// Returns the time elapsed since the last successful WireGuard handshake.
142    ///
143    /// Returns `Some(duration)` if a handshake has completed, or `None` if no
144    /// handshake has occurred yet. This is useful for health-checking the tunnel:
145    /// WireGuard re-handshakes every ~120s on an active session, so a value
146    /// exceeding ~180s typically indicates the tunnel is stale.
147    pub fn time_since_last_handshake(&self) -> Option<Duration> {
148        let tunn = self.tunn.lock();
149        tunn.stats().0
150    }
151
152    /// Initiate the WireGuard handshake.
153    pub async fn initiate_handshake(&self) -> Result<()> {
154        log::info!("Initiating WireGuard handshake...");
155
156        let handshake_init = {
157            let mut tunn = self.tunn.lock();
158            tunn.format_handshake_initiation(false)
159        };
160
161        if let Some(packet) = handshake_init {
162            // Convert Packet<WgHandshakeInit> to bytes
163            let data = packet.as_bytes();
164            self.udp_socket.send_to(data, self.peer_endpoint).await?;
165            log::debug!("Sent handshake initiation ({} bytes)", data.len());
166        }
167
168        Ok(())
169    }
170
171    /// Send an IP packet through the tunnel (encrypts and sends via UDP).
172    pub async fn send_ip_packet(&self, packet: BytesMut) -> Result<()> {
173        let encrypted = {
174            let mut tunn = self.tunn.lock();
175            let pkt = Packet::from_bytes(packet);
176            tunn.handle_outgoing_packet(pkt)
177        };
178
179        if let Some(wg_packet) = encrypted {
180            // Convert WgKind to Packet<[u8]> and get bytes
181            let pkt: Packet = wg_packet.into();
182            let data = pkt.as_bytes();
183            self.udp_socket.send_to(data, self.peer_endpoint).await?;
184            log::trace!("Sent encrypted packet ({} bytes)", data.len());
185        }
186
187        Ok(())
188    }
189
190    /// Process a received UDP packet (decrypts and returns IP packet if any).
191    fn process_incoming_udp(&self, data: &[u8]) -> Option<BytesMut> {
192        let packet = Packet::from_bytes(BytesMut::from(data));
193        let wg_packet = match packet.try_into_wg() {
194            Ok(wg) => wg,
195            Err(_) => {
196                log::warn!("Received non-WireGuard packet");
197                return None;
198            }
199        };
200
201        let mut tunn = self.tunn.lock();
202        match tunn.handle_incoming_packet(wg_packet) {
203            TunnResult::Done => {
204                log::trace!("WG: Packet processed (no output)");
205                None
206            }
207            TunnResult::Err(e) => {
208                log::warn!("WG error: {:?}", e);
209                None
210            }
211            TunnResult::WriteToNetwork(response) => {
212                log::trace!("WG: Sending response packet");
213                // Need to send a response (e.g., handshake response, keepalive)
214                let pkt: Packet = response.into();
215                let data = BytesMut::from(pkt.as_bytes());
216                let socket = self.udp_socket.clone();
217                let endpoint = self.peer_endpoint;
218                tokio::spawn(async move {
219                    if let Err(e) = socket.send_to(&data, endpoint).await {
220                        log::error!("Failed to send response: {}", e);
221                    }
222                });
223
224                // Also try to send any queued packets
225                for queued in tunn.get_queued_packets() {
226                    let pkt: Packet = queued.into();
227                    let data = BytesMut::from(pkt.as_bytes());
228                    let socket = self.udp_socket.clone();
229                    let endpoint = self.peer_endpoint;
230                    tokio::spawn(async move {
231                        if let Err(e) = socket.send_to(&data, endpoint).await {
232                            log::error!("Failed to send queued packet: {}", e);
233                        }
234                    });
235                }
236
237                None
238            }
239            TunnResult::WriteToTunnel(decrypted) => {
240                if decrypted.is_empty() {
241                    log::trace!("WG: Received keepalive");
242                    return None;
243                }
244                let bytes = BytesMut::from(decrypted.as_bytes());
245                log::trace!("WG: Decrypted {} bytes", bytes.len());
246                Some(bytes)
247            }
248        }
249    }
250
251    /// Run the tunnel's receive loop (listens for UDP packets and decrypts them).
252    pub async fn run_receive_loop(self: &Arc<Self>) -> Result<()> {
253        let mut buf = vec![0u8; 65535];
254
255        loop {
256            match self.udp_socket.recv_from(&mut buf).await {
257                Ok((len, from)) => {
258                    if from != self.peer_endpoint {
259                        log::warn!("Received packet from unknown peer: {}", from);
260                        continue;
261                    }
262
263                    log::trace!("Received UDP packet ({} bytes) from {}", len, from);
264
265                    if let Some(ip_packet) = self.process_incoming_udp(&buf[..len]) {
266                        if self.incoming_tx.send(ip_packet).await.is_err() {
267                            log::error!("Incoming channel closed");
268                            break;
269                        }
270                    }
271                }
272                Err(e) => {
273                    log::error!("UDP receive error: {}", e);
274                    break;
275                }
276            }
277        }
278
279        Ok(())
280    }
281
282    /// Run the tunnel's send loop (encrypts and sends IP packets).
283    pub async fn run_send_loop(self: &Arc<Self>) -> Result<()> {
284        let mut outgoing_rx = self.outgoing_rx.lock().await;
285
286        while let Some(packet) = outgoing_rx.recv().await {
287            if let Err(e) = self.send_ip_packet(packet).await {
288                log::error!("Failed to send packet: {}", e);
289            }
290        }
291
292        Ok(())
293    }
294
295    /// Run the tunnel's timer loop (handles keepalives and handshake retries).
296    pub async fn run_timer_loop(self: &Arc<Self>) -> Result<()> {
297        let mut interval = tokio::time::interval(Duration::from_millis(250));
298
299        loop {
300            interval.tick().await;
301
302            let packets_to_send: Vec<Vec<u8>> = {
303                let mut tunn = self.tunn.lock();
304                match tunn.update_timers() {
305                    Ok(Some(packet)) => {
306                        let pkt: Packet = packet.into();
307                        vec![pkt.as_bytes().to_vec()]
308                    }
309                    Ok(None) => vec![],
310                    Err(e) => {
311                        log::trace!("Timer error (may be normal): {:?}", e);
312                        vec![]
313                    }
314                }
315            };
316
317            for packet in packets_to_send {
318                if let Err(e) = self.udp_socket.send_to(&packet, self.peer_endpoint).await {
319                    log::error!("Failed to send timer packet: {}", e);
320                }
321            }
322        }
323    }
324
325    /// Wait for the handshake to complete (with timeout).
326    pub async fn wait_for_handshake(&self, timeout_duration: Duration) -> Result<()> {
327        let start = std::time::Instant::now();
328
329        loop {
330            {
331                let tunn = self.tunn.lock();
332                // Check if we have an active session - time_since_handshake is Some when session is established
333                let (time_since_handshake, _tx_bytes, _rx_bytes, _, _) = tunn.stats();
334                if time_since_handshake.is_some() {
335                    log::info!("WireGuard handshake completed!");
336                    return Ok(());
337                }
338            }
339
340            if start.elapsed() > timeout_duration {
341                return Err(Error::HandshakeTimeout(timeout_duration));
342            }
343
344            tokio::time::sleep(Duration::from_millis(50)).await;
345        }
346    }
347}