wireguard_netstack/
netstack.rs

1//! Userspace TCP/IP network stack using smoltcp.
2//!
3//! This module provides a TCP/IP stack that runs entirely in userspace,
4//! routing packets through our WireGuard tunnel.
5
6use crate::error::{Error, Result};
7use crate::wireguard::WireGuardTunnel;
8use bytes::BytesMut;
9use parking_lot::Mutex;
10use smoltcp::iface::{Config, Interface, PollResult, SocketHandle, SocketSet};
11use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
12use smoltcp::socket::tcp::{Socket as TcpSocket, SocketBuffer, State as TcpState};
13use smoltcp::time::Instant;
14use smoltcp::wire::{HardwareAddress, IpAddress, IpCidr, Ipv4Address, Ipv4Packet, TcpPacket};
15use std::collections::VecDeque;
16use std::net::{SocketAddr, SocketAddrV4};
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::sync::mpsc;
20
/// MTU for the virtual interface.
/// Some networks drop large UDP packets, especially when WireGuard overhead is added.
/// We use a conservative MTU that results in ~550 byte UDP packets after WireGuard
/// encapsulation (MTU + 40 IP/TCP headers + 48 WG overhead ≈ 548 byte UDP).
/// This works around networks that filter large UDP packets.
const MTU: usize = 460;

/// Size of each TCP socket's receive and transmit buffer (64 KiB − 1;
/// presumably chosen to match the largest unscaled TCP window — verify).
const TCP_BUFFER_SIZE: usize = 65535;
30
/// A virtual network device that sends/receives through the WireGuard tunnel.
///
/// smoltcp drives this device through the `Device` trait: decrypted inbound
/// packets are staged in `rx_queue`, and frames smoltcp emits accumulate in
/// `tx_queue` until they are drained toward WireGuard. Both queues are FIFO.
struct VirtualDevice {
    /// Packets ready to be received by smoltcp (from WireGuard).
    rx_queue: VecDeque<BytesMut>,
    /// Packets ready to be sent (to WireGuard).
    tx_queue: VecDeque<BytesMut>,
}
38
39impl VirtualDevice {
40    fn new() -> Self {
41        Self {
42            rx_queue: VecDeque::new(),
43            tx_queue: VecDeque::new(),
44        }
45    }
46
47    /// Add a packet to the receive queue (from WireGuard).
48    fn push_rx(&mut self, packet: BytesMut) {
49        self.rx_queue.push_back(packet);
50    }
51
52    /// Take all packets from the transmit queue (to send via WireGuard).
53    fn drain_tx(&mut self) -> Vec<BytesMut> {
54        self.tx_queue.drain(..).collect()
55    }
56}
57
58/// RxToken for smoltcp.
59struct VirtualRxToken {
60    buffer: BytesMut,
61}
62
63impl RxToken for VirtualRxToken {
64    fn consume<R, F>(self, f: F) -> R
65    where
66        F: FnOnce(&[u8]) -> R,
67    {
68        f(&self.buffer)
69    }
70}
71
72/// TxToken for smoltcp.
73struct VirtualTxToken<'a> {
74    tx_queue: &'a mut VecDeque<BytesMut>,
75}
76
77impl<'a> TxToken for VirtualTxToken<'a> {
78    fn consume<R, F>(self, len: usize, f: F) -> R
79    where
80        F: FnOnce(&mut [u8]) -> R,
81    {
82        let mut buffer = BytesMut::zeroed(len);
83        let result = f(&mut buffer);
84        self.tx_queue.push_back(buffer);
85        result
86    }
87
88    fn set_meta(&mut self, _meta: smoltcp::phy::PacketMeta) {
89        // No metadata handling needed for virtual device
90    }
91}
92
93impl Device for VirtualDevice {
94    type RxToken<'a> = VirtualRxToken;
95    type TxToken<'a> = VirtualTxToken<'a>;
96
97    fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
98        if let Some(buffer) = self.rx_queue.pop_front() {
99            Some((
100                VirtualRxToken { buffer },
101                VirtualTxToken {
102                    tx_queue: &mut self.tx_queue,
103                },
104            ))
105        } else {
106            None
107        }
108    }
109
110    fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
111        Some(VirtualTxToken {
112            tx_queue: &mut self.tx_queue,
113        })
114    }
115
116    fn capabilities(&self) -> DeviceCapabilities {
117        let mut caps = DeviceCapabilities::default();
118        caps.medium = Medium::Ip;
119        caps.max_transmission_unit = MTU;
120        caps
121    }
122}
123
/// Shared mutable state for the network stack, guarded by a single mutex.
struct NetStackInner {
    // smoltcp interface bound to the tunnel IP.
    interface: Interface,
    // Virtual device bridging packets to/from WireGuard.
    device: VirtualDevice,
    // All TCP sockets owned by this stack.
    sockets: SocketSet<'static>,
}
130
/// A userspace TCP/IP network stack.
///
/// All smoltcp state lives behind one mutex (`inner`); outbound packets are
/// forwarded to the WireGuard tunnel through the `wg_tx` channel.
pub struct NetStack {
    inner: Mutex<NetStackInner>,
    // Used for its tunnel IP when choosing local endpoints.
    wg_tunnel: Arc<WireGuardTunnel>,
    /// Sender to queue packets for transmission through WireGuard.
    wg_tx: mpsc::Sender<BytesMut>,
}
138
139impl NetStack {
140    /// Create a new network stack backed by a WireGuard tunnel.
141    pub fn new(wg_tunnel: Arc<WireGuardTunnel>) -> Arc<Self> {
142        let tunnel_ip = wg_tunnel.tunnel_ip();
143        let wg_tx = wg_tunnel.outgoing_sender();
144
145        // Create the virtual device
146        let mut device = VirtualDevice::new();
147
148        // Create the interface configuration
149        let config = Config::new(HardwareAddress::Ip);
150
151        // Create the interface
152        let mut interface = Interface::new(config, &mut device, Instant::now());
153
154        // Configure the interface with our tunnel IP
155        interface.update_ip_addrs(|addrs| {
156            addrs
157                .push(IpCidr::new(
158                    IpAddress::v4(
159                        tunnel_ip.octets()[0],
160                        tunnel_ip.octets()[1],
161                        tunnel_ip.octets()[2],
162                        tunnel_ip.octets()[3],
163                    ),
164                    32,
165                ))
166                .unwrap();
167        });
168
169        // Set up routing - route everything through this interface
170        interface
171            .routes_mut()
172            .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 0))
173            .unwrap();
174
175        // Create socket set
176        let sockets = SocketSet::new(vec![]);
177
178        let inner = NetStackInner {
179            interface,
180            device,
181            sockets,
182        };
183
184        Arc::new(Self {
185            inner: Mutex::new(inner),
186            wg_tunnel,
187            wg_tx,
188        })
189    }
190
191    /// Create a new TCP socket and return its handle.
192    pub fn create_tcp_socket(&self) -> SocketHandle {
193        let mut inner = self.inner.lock();
194
195        let rx_buffer = SocketBuffer::new(vec![0u8; TCP_BUFFER_SIZE]);
196        let tx_buffer = SocketBuffer::new(vec![0u8; TCP_BUFFER_SIZE]);
197        let socket = TcpSocket::new(rx_buffer, tx_buffer);
198
199        inner.sockets.add(socket)
200    }
201
202    /// Connect a TCP socket to the given address.
203    pub fn connect(&self, handle: SocketHandle, addr: SocketAddr) -> Result<()> {
204        let mut inner = self.inner.lock();
205
206        let local_port = 49152 + (rand::random::<u16>() % 16384);
207        let local_addr = SocketAddrV4::new(self.wg_tunnel.tunnel_ip(), local_port);
208
209        let remote = match addr {
210            SocketAddr::V4(v4) => smoltcp::wire::IpEndpoint::new(
211                IpAddress::v4(
212                    v4.ip().octets()[0],
213                    v4.ip().octets()[1],
214                    v4.ip().octets()[2],
215                    v4.ip().octets()[3],
216                ),
217                v4.port(),
218            ),
219            SocketAddr::V6(_) => return Err(Error::Ipv6NotSupported),
220        };
221
222        let local = smoltcp::wire::IpEndpoint::new(
223            IpAddress::v4(
224                local_addr.ip().octets()[0],
225                local_addr.ip().octets()[1],
226                local_addr.ip().octets()[2],
227                local_addr.ip().octets()[3],
228            ),
229            local_addr.port(),
230        );
231
232        // Use destructuring to avoid split borrow issues
233        let NetStackInner {
234            ref mut interface,
235            ref mut sockets,
236            ..
237        } = *inner;
238        let cx = interface.context();
239        let socket = sockets.get_mut::<TcpSocket>(handle);
240        socket
241            .connect(cx, remote, local)
242            .map_err(|e| Error::TcpConnectGeneric(format!("TCP connect failed: {}", e)))?;
243
244        log::debug!("TCP socket connecting to {} from {}", addr, local_addr);
245
246        Ok(())
247    }
248
249    /// Check if a TCP socket is connected.
250    pub fn is_connected(&self, handle: SocketHandle) -> bool {
251        let inner = self.inner.lock();
252        let socket = inner.sockets.get::<TcpSocket>(handle);
253        socket.state() == TcpState::Established
254    }
255
256    /// Check if a TCP socket can send data.
257    pub fn can_send(&self, handle: SocketHandle) -> bool {
258        let inner = self.inner.lock();
259        let socket = inner.sockets.get::<TcpSocket>(handle);
260        socket.can_send()
261    }
262
263    /// Check if a TCP socket can receive data.
264    pub fn can_recv(&self, handle: SocketHandle) -> bool {
265        let inner = self.inner.lock();
266        let socket = inner.sockets.get::<TcpSocket>(handle);
267        let can = socket.can_recv();
268        let recv_queue = socket.recv_queue();
269        if recv_queue > 0 {
270            log::debug!(
271                "Socket can_recv={}, recv_queue={}, state={:?}",
272                can,
273                recv_queue,
274                socket.state()
275            );
276        }
277        can
278    }
279
280    /// Check if a TCP socket may send data (connection in progress or established).
281    pub fn may_send(&self, handle: SocketHandle) -> bool {
282        let inner = self.inner.lock();
283        let socket = inner.sockets.get::<TcpSocket>(handle);
284        socket.may_send()
285    }
286
287    /// Check if a TCP socket may receive data.
288    pub fn may_recv(&self, handle: SocketHandle) -> bool {
289        let inner = self.inner.lock();
290        let socket = inner.sockets.get::<TcpSocket>(handle);
291        socket.may_recv()
292    }
293
294    /// Get the TCP socket state.
295    pub fn socket_state(&self, handle: SocketHandle) -> TcpState {
296        let inner = self.inner.lock();
297        let socket = inner.sockets.get::<TcpSocket>(handle);
298        socket.state()
299    }
300
301    /// Send data on a TCP socket.
302    pub fn send(&self, handle: SocketHandle, data: &[u8]) -> Result<usize> {
303        let mut inner = self.inner.lock();
304        let socket = inner.sockets.get_mut::<TcpSocket>(handle);
305
306        socket
307            .send_slice(data)
308            .map_err(|e| Error::TcpSend(e.to_string()))
309    }
310
311    /// Receive data from a TCP socket.
312    pub fn recv(&self, handle: SocketHandle, buffer: &mut [u8]) -> Result<usize> {
313        let mut inner = self.inner.lock();
314        let socket = inner.sockets.get_mut::<TcpSocket>(handle);
315
316        socket
317            .recv_slice(buffer)
318            .map_err(|e| Error::TcpRecv(e.to_string()))
319    }
320
321    /// Close a TCP socket.
322    pub fn close(&self, handle: SocketHandle) {
323        let mut inner = self.inner.lock();
324        let socket = inner.sockets.get_mut::<TcpSocket>(handle);
325        socket.close();
326    }
327
328    /// Remove a socket from the socket set.
329    pub fn remove_socket(&self, handle: SocketHandle) {
330        let mut inner = self.inner.lock();
331        inner.sockets.remove(handle);
332    }
333
334    /// Poll the network stack, processing packets and updating socket states.
335    /// Returns true if there was any activity.
336    pub fn poll(&self) -> bool {
337        let mut inner = self.inner.lock();
338
339        let timestamp = Instant::now();
340
341        // Destructure to allow split borrows
342        let NetStackInner {
343            ref mut interface,
344            ref mut device,
345            ref mut sockets,
346        } = *inner;
347
348        // Check if there are packets waiting
349        let rx_queue_len = device.rx_queue.len();
350        if rx_queue_len > 0 {
351            log::trace!("NetStack poll: {} packets in rx_queue", rx_queue_len);
352        }
353
354        // Poll the interface
355        let poll_result = interface.poll(timestamp, device, sockets);
356        let processed = poll_result != PollResult::None;
357
358        if processed {
359            log::trace!("NetStack poll processed packets");
360        }
361
362        // Drain transmitted packets and send through WireGuard
363        let tx_packets = device.drain_tx();
364        let tx_count = tx_packets.len();
365        drop(inner); // Release lock before async operations
366
367        if tx_count > 0 {
368            log::trace!("NetStack poll sending {} packets", tx_count);
369        }
370
371        for packet in tx_packets {
372            // Log outgoing TCP packets at debug level
373            if log::log_enabled!(log::Level::Debug) {
374                if let Ok(ip_packet) = Ipv4Packet::new_checked(&packet) {
375                    let protocol = ip_packet.next_header();
376                    if protocol == smoltcp::wire::IpProtocol::Tcp {
377                        if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
378                            let dst_port = tcp_packet.dst_port();
379                            let payload_len = tcp_packet.payload().len();
380
381                            let mut flags = String::new();
382                            if tcp_packet.syn() {
383                                flags.push_str("SYN ");
384                            }
385                            if tcp_packet.ack() {
386                                flags.push_str("ACK ");
387                            }
388                            if tcp_packet.fin() {
389                                flags.push_str("FIN ");
390                            }
391                            if tcp_packet.rst() {
392                                flags.push_str("RST ");
393                            }
394                            if tcp_packet.psh() {
395                                flags.push_str("PSH ");
396                            }
397
398                            log::debug!(
399                                "TX: {}:{} [{}] {} bytes",
400                                ip_packet.dst_addr(),
401                                dst_port,
402                                flags.trim(),
403                                payload_len
404                            );
405                        }
406                    }
407                }
408            }
409
410            let tx = self.wg_tx.clone();
411            tokio::spawn(async move {
412                if let Err(e) = tx.send(packet).await {
413                    log::error!("Failed to queue packet for WireGuard: {}", e);
414                }
415            });
416        }
417
418        processed
419    }
420
421    /// Push a received packet (from WireGuard) into the network stack.
422    pub fn push_rx_packet(&self, packet: BytesMut) {
423        // Parse and log TCP packet details for debugging
424        if log::log_enabled!(log::Level::Debug) {
425            if let Ok(ip_packet) = Ipv4Packet::new_checked(&packet) {
426                let protocol = ip_packet.next_header();
427                if protocol == smoltcp::wire::IpProtocol::Tcp {
428                    if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
429                        let src_port = tcp_packet.src_port();
430                        let payload_len = tcp_packet.payload().len();
431
432                        let mut flags = String::new();
433                        if tcp_packet.syn() {
434                            flags.push_str("SYN ");
435                        }
436                        if tcp_packet.ack() {
437                            flags.push_str("ACK ");
438                        }
439                        if tcp_packet.fin() {
440                            flags.push_str("FIN ");
441                        }
442                        if tcp_packet.rst() {
443                            flags.push_str("RST ");
444                        }
445                        if tcp_packet.psh() {
446                            flags.push_str("PSH ");
447                        }
448
449                        log::debug!(
450                            "RX: {}:{} [{}] {} bytes",
451                            ip_packet.src_addr(),
452                            src_port,
453                            flags.trim(),
454                            payload_len
455                        );
456                    }
457                }
458            }
459        }
460
461        let mut inner = self.inner.lock();
462        inner.device.push_rx(packet);
463    }
464
465    /// Run the network stack polling loop.
466    pub async fn run_poll_loop(self: &Arc<Self>) -> Result<()> {
467        let mut interval = tokio::time::interval(Duration::from_millis(1));
468
469        loop {
470            interval.tick().await;
471            self.poll();
472        }
473    }
474
475    /// Run the receive loop that takes packets from WireGuard and feeds them to the stack.
476    pub async fn run_rx_loop(self: &Arc<Self>, mut rx: mpsc::Receiver<BytesMut>) -> Result<()> {
477        while let Some(packet) = rx.recv().await {
478            log::debug!("NetStack received packet ({} bytes)", packet.len());
479            self.push_rx_packet(packet);
480            self.poll();
481        }
482
483        Ok(())
484    }
485}
486
/// A TCP connection through our network stack.
///
/// Dropping the connection queues a graceful close (see the `Drop` impl).
pub struct TcpConnection {
    /// The network stack backing this connection.
    pub netstack: Arc<NetStack>,
    /// The socket handle for this connection.
    pub handle: SocketHandle,
}
494
495impl TcpConnection {
496    /// Create a new TCP connection.
497    pub async fn connect(netstack: Arc<NetStack>, addr: SocketAddr) -> Result<Self> {
498        let handle = netstack.create_tcp_socket();
499        netstack.connect(handle, addr)?;
500
501        // Poll until connected or timeout
502        let start = std::time::Instant::now();
503        let timeout = Duration::from_secs(30);
504
505        loop {
506            netstack.poll();
507
508            let state = netstack.socket_state(handle);
509            log::trace!("TCP state: {:?}", state);
510
511            if state == TcpState::Established {
512                log::info!("TCP connection established to {}", addr);
513                return Ok(Self { netstack, handle });
514            }
515
516            if state == TcpState::Closed || state == TcpState::TimeWait {
517                netstack.remove_socket(handle);
518                return Err(Error::TcpConnect {
519                    addr,
520                    message: format!("Connection failed (state: {:?})", state),
521                });
522            }
523
524            if start.elapsed() > timeout {
525                netstack.remove_socket(handle);
526                return Err(Error::TcpTimeout);
527            }
528
529            tokio::time::sleep(Duration::from_millis(1)).await;
530        }
531    }
532
533    /// Read data from the connection.
534    pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
535        let timeout = Duration::from_secs(30);
536        let start = std::time::Instant::now();
537
538        loop {
539            self.netstack.poll();
540
541            if self.netstack.can_recv(self.handle) {
542                match self.netstack.recv(self.handle, buf) {
543                    Ok(n) if n > 0 => return Ok(n),
544                    Ok(_) => {}
545                    Err(e) => return Err(e),
546                }
547            }
548
549            if !self.netstack.may_recv(self.handle) {
550                // Connection closed
551                return Ok(0);
552            }
553
554            if start.elapsed() > timeout {
555                return Err(Error::ReadTimeout);
556            }
557
558            tokio::time::sleep(Duration::from_millis(1)).await;
559        }
560    }
561
562    /// Write data to the connection.
563    pub async fn write(&self, data: &[u8]) -> Result<usize> {
564        let timeout = Duration::from_secs(30);
565        let start = std::time::Instant::now();
566
567        let mut written = 0;
568
569        while written < data.len() {
570            self.netstack.poll();
571
572            if self.netstack.can_send(self.handle) {
573                match self.netstack.send(self.handle, &data[written..]) {
574                    Ok(n) => {
575                        written += n;
576                        log::trace!("Wrote {} bytes (total: {})", n, written);
577                    }
578                    Err(e) => return Err(e),
579                }
580            }
581
582            if !self.netstack.may_send(self.handle) {
583                // Connection closed
584                return Err(Error::ConnectionClosed);
585            }
586
587            if start.elapsed() > timeout {
588                return Err(Error::WriteTimeout);
589            }
590
591            if written < data.len() {
592                tokio::time::sleep(Duration::from_millis(1)).await;
593            }
594        }
595
596        self.netstack.poll();
597        Ok(written)
598    }
599
600    /// Write all data to the connection.
601    pub async fn write_all(&self, data: &[u8]) -> Result<()> {
602        let n = self.write(data).await?;
603        if n != data.len() {
604            return Err(Error::ShortWrite {
605                written: n,
606                expected: data.len(),
607            });
608        }
609        Ok(())
610    }
611
    /// Initiate a graceful shutdown of the connection (queues a FIN).
    pub fn shutdown(&self) {
        self.netstack.close(self.handle);
    }

    /// Get the socket handle backing this connection.
    pub fn handle(&self) -> SocketHandle {
        self.handle
    }
621}
622
impl Drop for TcpConnection {
    fn drop(&mut self) {
        // Best-effort graceful close: queue a FIN for this socket.
        self.netstack.close(self.handle);
        // One synchronous poll so the FIN has a chance to be emitted; a
        // single poll cannot guarantee the close handshake completes.
        self.netstack.poll();
        // NOTE(review): the socket is never removed from the SocketSet, so
        // closed sockets appear to accumulate; removing here would abort the
        // FIN handshake, so a deferred reap of CLOSED sockets seems needed —
        // confirm against the stack's lifecycle expectations.
    }
}