wireguard_netstack/
wireguard.rs

1//! WireGuard tunnel implementation using gotatun.
2//!
3//! This module wraps gotatun's `Tunn` struct and manages the UDP transport
4//! for sending/receiving encrypted WireGuard packets.
5
6use bytes::BytesMut;
7use gotatun::noise::{Tunn, TunnResult};
8use gotatun::packet::Packet;
9use gotatun::x25519::{PublicKey, StaticSecret};
10use parking_lot::Mutex;
11use zerocopy::IntoBytes;
12use std::net::{Ipv4Addr, SocketAddr};
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::net::UdpSocket;
16use tokio::sync::mpsc;
17
18use crate::error::{Error, Result};
19
/// Configuration for the WireGuard tunnel.
///
/// Only `Clone` is derived. There is deliberately no `Debug` derive here, so key material is not
/// printed by `{:?}` formatting.
/// NOTE(review): the missing `Debug` looks intentional — confirm before adding one.
#[derive(Clone)]
pub struct WireGuardConfig {
    /// Our static private key (32 raw bytes).
    pub private_key: [u8; 32],
    /// The peer's static public key (32 raw bytes).
    pub peer_public_key: [u8; 32],
    /// Peer's UDP endpoint (IP:port). Encrypted packets are sent here and only datagrams
    /// from this address are accepted.
    pub peer_endpoint: SocketAddr,
    /// Our IPv4 address inside the tunnel.
    pub tunnel_ip: Ipv4Addr,
    /// Optional preshared key for additional security; passed unchanged to gotatun.
    pub preshared_key: Option<[u8; 32]>,
    /// Persistent-keepalive interval in seconds, passed unchanged to gotatun's `Tunn::new`.
    ///
    /// NOTE(review): presumably `None` disables keepalives (an earlier doc claimed
    /// "0 = disabled") — confirm the exact `None` / `Some(0)` semantics against gotatun.
    pub keepalive_seconds: Option<u16>,
    /// MTU for the tunnel interface; `WireGuardTunnel::new` falls back to 460 when `None`.
    pub mtu: Option<u16>,
}
38
39/// A WireGuard tunnel that encrypts/decrypts IP packets.
40pub struct WireGuardTunnel {
41    /// The underlying gotatun tunnel.
42    tunn: Mutex<Tunn>,
43    /// UDP socket for sending/receiving encrypted packets.
44    udp_socket: Arc<UdpSocket>,
45    /// Peer's endpoint address.
46    peer_endpoint: SocketAddr,
47    /// Our tunnel IP address.
48    tunnel_ip: Ipv4Addr,
49    /// MTU for the tunnel interface.
50    mtu: u16,
51    /// Channel to send received IP packets.
52    incoming_tx: mpsc::Sender<BytesMut>,
53    /// Channel to receive IP packets to send.
54    outgoing_rx: tokio::sync::Mutex<mpsc::Receiver<BytesMut>>,
55    /// Channel to receive incoming IP packets.
56    incoming_rx: Mutex<Option<mpsc::Receiver<BytesMut>>>,
57    /// Channel to send IP packets for encryption.
58    outgoing_tx: mpsc::Sender<BytesMut>,
59}
60
61impl WireGuardTunnel {
62    /// Create a new WireGuard tunnel with the given configuration.
63    pub async fn new(config: WireGuardConfig) -> Result<Arc<Self>> {
64        // Create the cryptographic keys
65        let private_key = StaticSecret::from(config.private_key);
66        let peer_public_key = PublicKey::from(config.peer_public_key);
67
68        // Create the tunnel
69        let tunn = Tunn::new(
70            private_key,
71            peer_public_key,
72            config.preshared_key,
73            config.keepalive_seconds,
74            rand::random::<u32>() >> 8, // Random index
75            None,                        // No rate limiter for client
76        );
77
78        // Bind UDP socket to any available port
79        let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
80
81        // Increase socket receive buffer to avoid packet loss
82        let sock_ref = socket2::SockRef::from(&udp_socket);
83        if let Err(e) = sock_ref.set_recv_buffer_size(1024 * 1024) {
84            // 1MB buffer
85            log::warn!("Failed to set UDP recv buffer size: {}", e);
86        }
87        if let Err(e) = sock_ref.set_send_buffer_size(1024 * 1024) {
88            // 1MB buffer
89            log::warn!("Failed to set UDP send buffer size: {}", e);
90        }
91        log::info!("UDP recv buffer size: {:?}", sock_ref.recv_buffer_size());
92        log::info!("UDP send buffer size: {:?}", sock_ref.send_buffer_size());
93
94        log::info!(
95            "WireGuard UDP socket bound to {}",
96            udp_socket.local_addr()?
97        );
98
99        // Create channels for packet communication
100        // incoming: packets received from the tunnel (decrypted)
101        // outgoing: packets to send through the tunnel (to be encrypted)
102        let (incoming_tx, incoming_rx) = mpsc::channel(256);
103        let (outgoing_tx, outgoing_rx) = mpsc::channel(256);
104
105        let tunnel = Arc::new(Self {
106            tunn: Mutex::new(tunn),
107            udp_socket: Arc::new(udp_socket),
108            peer_endpoint: config.peer_endpoint,
109            tunnel_ip: config.tunnel_ip,
110            mtu: config.mtu.unwrap_or(460), // Default MTU
111            incoming_tx,
112            incoming_rx: Mutex::new(Some(incoming_rx)),
113            outgoing_tx,
114            outgoing_rx: tokio::sync::Mutex::new(outgoing_rx),
115        });
116
117        Ok(tunnel)
118    }
119
120    /// Get our tunnel IP address.
121    pub fn tunnel_ip(&self) -> Ipv4Addr {
122        self.tunnel_ip
123    }
124
125    /// Get the MTU for the tunnel.
126    pub fn mtu(&self) -> u16 {
127        self.mtu
128    }
129
130    /// Get the sender for outgoing packets.
131    pub fn outgoing_sender(&self) -> mpsc::Sender<BytesMut> {
132        self.outgoing_tx.clone()
133    }
134
135    /// Get the receiver for incoming packets (takes ownership of the receiver).
136    pub fn take_incoming_receiver(&self) -> Option<mpsc::Receiver<BytesMut>> {
137        self.incoming_rx.lock().take()
138    }
139
140    /// Initiate the WireGuard handshake.
141    pub async fn initiate_handshake(&self) -> Result<()> {
142        log::info!("Initiating WireGuard handshake...");
143
144        let handshake_init = {
145            let mut tunn = self.tunn.lock();
146            tunn.format_handshake_initiation(false)
147        };
148
149        if let Some(packet) = handshake_init {
150            // Convert Packet<WgHandshakeInit> to bytes
151            let data = packet.as_bytes();
152            self.udp_socket.send_to(data, self.peer_endpoint).await?;
153            log::debug!("Sent handshake initiation ({} bytes)", data.len());
154        }
155
156        Ok(())
157    }
158
159    /// Send an IP packet through the tunnel (encrypts and sends via UDP).
160    pub async fn send_ip_packet(&self, packet: BytesMut) -> Result<()> {
161        let encrypted = {
162            let mut tunn = self.tunn.lock();
163            let pkt = Packet::from_bytes(packet);
164            tunn.handle_outgoing_packet(pkt)
165        };
166
167        if let Some(wg_packet) = encrypted {
168            // Convert WgKind to Packet<[u8]> and get bytes
169            let pkt: Packet = wg_packet.into();
170            let data = pkt.as_bytes();
171            self.udp_socket.send_to(data, self.peer_endpoint).await?;
172            log::trace!("Sent encrypted packet ({} bytes)", data.len());
173        }
174
175        Ok(())
176    }
177
178    /// Process a received UDP packet (decrypts and returns IP packet if any).
179    fn process_incoming_udp(&self, data: &[u8]) -> Option<BytesMut> {
180        let packet = Packet::from_bytes(BytesMut::from(data));
181        let wg_packet = match packet.try_into_wg() {
182            Ok(wg) => wg,
183            Err(_) => {
184                log::warn!("Received non-WireGuard packet");
185                return None;
186            }
187        };
188
189        let mut tunn = self.tunn.lock();
190        match tunn.handle_incoming_packet(wg_packet) {
191            TunnResult::Done => {
192                log::trace!("WG: Packet processed (no output)");
193                None
194            }
195            TunnResult::Err(e) => {
196                log::warn!("WG error: {:?}", e);
197                None
198            }
199            TunnResult::WriteToNetwork(response) => {
200                log::trace!("WG: Sending response packet");
201                // Need to send a response (e.g., handshake response, keepalive)
202                let pkt: Packet = response.into();
203                let data = BytesMut::from(pkt.as_bytes());
204                let socket = self.udp_socket.clone();
205                let endpoint = self.peer_endpoint;
206                tokio::spawn(async move {
207                    if let Err(e) = socket.send_to(&data, endpoint).await {
208                        log::error!("Failed to send response: {}", e);
209                    }
210                });
211
212                // Also try to send any queued packets
213                while let Some(queued) = tunn.next_queued_packet() {
214                    let pkt: Packet = queued.into();
215                    let data = BytesMut::from(pkt.as_bytes());
216                    let socket = self.udp_socket.clone();
217                    let endpoint = self.peer_endpoint;
218                    tokio::spawn(async move {
219                        if let Err(e) = socket.send_to(&data, endpoint).await {
220                            log::error!("Failed to send queued packet: {}", e);
221                        }
222                    });
223                }
224
225                None
226            }
227            TunnResult::WriteToTunnel(decrypted) => {
228                if decrypted.is_empty() {
229                    log::trace!("WG: Received keepalive");
230                    return None;
231                }
232                let bytes = BytesMut::from(decrypted.as_bytes());
233                log::trace!("WG: Decrypted {} bytes", bytes.len());
234                Some(bytes)
235            }
236        }
237    }
238
239    /// Run the tunnel's receive loop (listens for UDP packets and decrypts them).
240    pub async fn run_receive_loop(self: &Arc<Self>) -> Result<()> {
241        let mut buf = vec![0u8; 65535];
242
243        loop {
244            match self.udp_socket.recv_from(&mut buf).await {
245                Ok((len, from)) => {
246                    if from != self.peer_endpoint {
247                        log::warn!("Received packet from unknown peer: {}", from);
248                        continue;
249                    }
250
251                    log::trace!("Received UDP packet ({} bytes) from {}", len, from);
252
253                    if let Some(ip_packet) = self.process_incoming_udp(&buf[..len]) {
254                        if self.incoming_tx.send(ip_packet).await.is_err() {
255                            log::error!("Incoming channel closed");
256                            break;
257                        }
258                    }
259                }
260                Err(e) => {
261                    log::error!("UDP receive error: {}", e);
262                    break;
263                }
264            }
265        }
266
267        Ok(())
268    }
269
270    /// Run the tunnel's send loop (encrypts and sends IP packets).
271    pub async fn run_send_loop(self: &Arc<Self>) -> Result<()> {
272        let mut outgoing_rx = self.outgoing_rx.lock().await;
273
274        while let Some(packet) = outgoing_rx.recv().await {
275            if let Err(e) = self.send_ip_packet(packet).await {
276                log::error!("Failed to send packet: {}", e);
277            }
278        }
279
280        Ok(())
281    }
282
283    /// Run the tunnel's timer loop (handles keepalives and handshake retries).
284    pub async fn run_timer_loop(self: &Arc<Self>) -> Result<()> {
285        let mut interval = tokio::time::interval(Duration::from_millis(250));
286
287        loop {
288            interval.tick().await;
289
290            let packets_to_send: Vec<Vec<u8>> = {
291                let mut tunn = self.tunn.lock();
292                match tunn.update_timers() {
293                    Ok(Some(packet)) => {
294                        let pkt: Packet = packet.into();
295                        vec![pkt.as_bytes().to_vec()]
296                    }
297                    Ok(None) => vec![],
298                    Err(e) => {
299                        log::trace!("Timer error (may be normal): {:?}", e);
300                        vec![]
301                    }
302                }
303            };
304
305            for packet in packets_to_send {
306                if let Err(e) = self.udp_socket.send_to(&packet, self.peer_endpoint).await {
307                    log::error!("Failed to send timer packet: {}", e);
308                }
309            }
310        }
311    }
312
313    /// Wait for the handshake to complete (with timeout).
314    pub async fn wait_for_handshake(&self, timeout_duration: Duration) -> Result<()> {
315        let start = std::time::Instant::now();
316
317        loop {
318            {
319                let tunn = self.tunn.lock();
320                // Check if we have an active session - time_since_handshake is Some when session is established
321                let (time_since_handshake, _tx_bytes, _rx_bytes, _, _) = tunn.stats();
322                if time_since_handshake.is_some() {
323                    log::info!("WireGuard handshake completed!");
324                    return Ok(());
325                }
326            }
327
328            if start.elapsed() > timeout_duration {
329                return Err(Error::HandshakeTimeout(timeout_duration));
330            }
331
332            tokio::time::sleep(Duration::from_millis(50)).await;
333        }
334    }
335}