Skip to main content

smolvm_network/
tcp_relay.rs

1//! TCP relay support for the virtio-net backend.
2//!
3//! Context
4//! =======
5//!
6//! In the Phase 1 virtio-net design, guest TCP does not flow directly from the
7//! guest to the outside network through the host kernel. Instead, the host-side
8//! smoltcp runtime terminates the guest-visible TCP connection in userspace and
9//! relays payloads to a normal host `TcpStream`.
10//!
11//! Conceptually:
12//!
13//! ```text
14//! guest app
15//!   -> guest kernel TCP
16//!   -> Ethernet frame
17//!   -> smoltcp TCP socket (inside smolvm)
18//!   -> channel
19//!   -> host TcpStream
20//!   -> remote server
21//! ```
22//!
23//! That means:
24//! - the host runtime can observe every guest TCP byte stream on this NIC
25//! - smoltcp owns the guest-facing TCP state machine
26//! - the relay thread owns the host-facing TCP socket
27//! - channels bridge payloads between them
28
29use crate::queues::WakePipe;
30use crate::virtio_net_log;
31use smoltcp::iface::{Interface, SocketHandle, SocketSet};
32use smoltcp::socket::tcp;
33use smoltcp::wire::IpListenEndpoint;
34use std::collections::{HashMap, HashSet};
35use std::io::{self, Read, Write};
36use std::net::{Ipv4Addr, Shutdown, SocketAddr, TcpStream};
37use std::sync::atomic::{AtomicU8, Ordering};
38use std::sync::mpsc::{self, Receiver, SyncSender, TryRecvError};
39use std::sync::Arc;
40use std::thread;
41use std::time::Duration;
42
// Size in bytes of the receive buffer given to each guest-facing smoltcp socket.
const TCP_RX_BUFFER_BYTES: usize = 64 * 1024;
// Size in bytes of the transmit buffer given to each guest-facing smoltcp socket.
const TCP_TX_BUFFER_BYTES: usize = 64 * 1024;
// Default connection limit used when `TcpRelayTable::new` is given `None`.
const MAX_CONNECTIONS: usize = 256;
// Bound (in queued payloads, not bytes) of each per-connection relay channel.
const CHANNEL_CAPACITY: usize = 32;
// Upper bound in bytes on a single payload chunk moved by one relay read.
const RELAY_BUFFER_BYTES: usize = 16 * 1024;
// Number of graceful-close attempts with undelivered host data before the
// guest socket is aborted instead.
const CLOSE_RETRY_LIMIT: u16 = 64;
// How long a relay thread sleeps when neither direction made progress.
const PROXY_IDLE_SLEEP: Duration = Duration::from_millis(10);
// Inclusive range of gateway-side source ports handed out for published
// (host -> guest) connections.
const PUBLISHED_PORT_START: u16 = 49_152;
const PUBLISHED_PORT_END: u16 = 65_535;
52
53/// Track all active guest TCP connections bridged through host sockets.
54///
55/// One entry corresponds to one `(guest source, destination)` tuple. The table
56/// lives in the smoltcp poll thread and owns all guest-facing socket handles.
57pub struct TcpRelayTable {
58    connections: HashMap<SocketHandle, TrackedConnection>,
59    connection_keys: HashSet<(SocketAddr, SocketAddr)>,
60    used_published_ports: HashSet<u16>,
61    next_published_port: u16,
62    max_connections: usize,
63}
64
65/// Newly established guest connection ready for a host relay thread.
66///
67/// The poll loop emits these once the guest-side smoltcp socket reaches
68/// `Established`. At that point we can safely create the host-side relay
69/// thread and give it channel endpoints for payload exchange.
70pub struct NewTcpConnection {
71    /// Destination originally requested by the guest.
72    pub destination: SocketAddr,
73    /// How the host-side relay should be started.
74    pub relay_target: RelayTarget,
75    /// Guest-to-host payloads read from the smoltcp socket.
76    pub from_smoltcp: Receiver<Vec<u8>>,
77    /// Host-to-guest payloads written back into the smoltcp socket.
78    pub to_smoltcp: SyncSender<Vec<u8>>,
79    /// Shared relay exit state.
80    pub exit_state: RelayExitState,
81}
82
// Poll-loop-side bookkeeping for one guest TCP flow and its relay channels.
#[derive(Debug)]
struct TrackedConnection {
    // `source` and `destination` identify the guest-side flow; together they
    // form the key stored in `TcpRelayTable::connection_keys`.
    source: SocketAddr,
    destination: SocketAddr,
    // guest -> host relay payloads (sending half, used by the poll loop)
    to_proxy: SyncSender<Vec<u8>>,
    // host -> guest relay payloads (receiving half, drained by the poll loop)
    from_proxy: Receiver<Vec<u8>>,
    // relay-thread halves of both channels plus the relay target; held here
    // until the guest-side handshake completes, then moved out exactly once
    pending_proxy_endpoints: Option<PendingProxyEndpoints>,
    // set once the connection has been handed out for relay-thread creation
    relay_spawned: bool,
    // host->guest payload not yet fully accepted by smoltcp, together with the
    // number of leading bytes that were already written
    buffered_proxy_data: Option<(Vec<u8>, usize)>,
    // bounded retry count for closing with unsent buffered data
    close_attempts: u16,
    // relay thread termination mode observed by the poll loop
    exit_state: RelayExitState,
    // reserved gateway source port for published inbound connections; `None`
    // for guest-initiated connections
    reserved_published_port: Option<u16>,
}
105
/// Relay-thread side of a connection's channels, parked on the
/// `TrackedConnection` until the guest-facing socket is established.
#[derive(Debug)]
struct PendingProxyEndpoints {
    // receiving half of the guest -> host channel
    from_smoltcp: Receiver<Vec<u8>>,
    // sending half of the host -> guest channel
    to_smoltcp: SyncSender<Vec<u8>>,
    // how the relay thread should obtain its host-side stream
    relay_target: RelayTarget,
}
112
/// How a host-side TCP relay should obtain its remote socket.
///
/// Guest-initiated connections use `Connect`; connections accepted on a
/// published host port use `Attached`.
#[derive(Debug)]
pub enum RelayTarget {
    /// Open a new outbound host `TcpStream` to the destination.
    Connect(SocketAddr),
    /// Use an already-accepted host `TcpStream` from a published port listener.
    Attached(TcpStream),
}
121
/// Host relay termination state shared between the poll loop and the relay thread.
///
/// The relay thread cannot mutate smoltcp sockets directly because those sockets
/// are owned by the poll loop thread. Instead it reports how it finished, and
/// the poll loop interprets that into guest-side socket actions:
/// - `Graceful` -> close guest socket cleanly
/// - `Abort`    -> abort/reset guest socket
///
/// Cloning shares the same underlying flag.
#[derive(Clone, Debug)]
pub struct RelayExitState {
    inner: Arc<AtomicU8>,
}

/// How a host TCP relay thread terminated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum RelayExitMode {
    /// Relay thread is still running.
    Running = 0,
    /// Remote side closed normally; send FIN toward the guest.
    Graceful = 1,
    /// Remote connect or I/O failed; abort the guest TCP socket.
    Abort = 2,
}

impl RelayExitState {
    /// Create a fresh state in the `Running` mode.
    fn new() -> Self {
        Self {
            inner: Arc::new(AtomicU8::new(RelayExitMode::Running as u8)),
        }
    }

    /// Read the current exit mode.
    ///
    /// Uses `Acquire` so that once a non-`Running` mode is observed, everything
    /// the relay thread did before storing it (in particular its final channel
    /// sends) is visible to the caller. Unknown raw values decode as `Running`.
    fn load(&self) -> RelayExitMode {
        const GRACEFUL: u8 = RelayExitMode::Graceful as u8;
        const ABORT: u8 = RelayExitMode::Abort as u8;

        match self.inner.load(Ordering::Acquire) {
            GRACEFUL => RelayExitMode::Graceful,
            ABORT => RelayExitMode::Abort,
            _ => RelayExitMode::Running,
        }
    }

    /// Publish the exit mode.
    ///
    /// Uses `Release` to pair with the `Acquire` in [`RelayExitState::load`].
    fn store(&self, mode: RelayExitMode) {
        self.inner.store(mode as u8, Ordering::Release);
    }
}
165
166impl TcpRelayTable {
167    /// Create a new relay table.
168    pub fn new(max_connections: Option<usize>) -> Self {
169        Self {
170            connections: HashMap::new(),
171            connection_keys: HashSet::new(),
172            used_published_ports: HashSet::new(),
173            next_published_port: PUBLISHED_PORT_START,
174            max_connections: max_connections.unwrap_or(MAX_CONNECTIONS),
175        }
176    }
177
178    /// Whether a relay socket already exists for the same guest source and destination.
179    pub fn has_socket_for(&self, source: &SocketAddr, destination: &SocketAddr) -> bool {
180        self.connection_keys.contains(&(*source, *destination))
181    }
182
183    /// Create a smoltcp TCP socket for a guest SYN.
184    ///
185    /// Why this happens before full ingress processing:
186    /// - when the first guest SYN arrives, smoltcp needs a matching socket to
187    ///   receive it
188    /// - the poll loop therefore pre-creates a listening socket keyed to the
189    ///   destination the guest is trying to reach
190    /// - only after the guest-facing connection reaches `Established` do we
191    ///   spawn the host relay thread
192    ///
193    /// Data path after creation:
194    ///
195    /// ```text
196    /// smoltcp socket --to_proxy channel--> host relay thread
197    /// host relay thread --from_proxy channel--> smoltcp socket
198    /// ```
199    pub fn create_tcp_socket(
200        &mut self,
201        source: SocketAddr,
202        destination: SocketAddr,
203        sockets: &mut SocketSet<'_>,
204    ) -> bool {
205        if self.connections.len() >= self.max_connections {
206            tracing::warn!("dropping TCP connection because the relay table is full");
207            return false;
208        }
209
210        let rx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUFFER_BYTES]);
211        let tx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUFFER_BYTES]);
212        let mut socket = tcp::Socket::new(rx_buffer, tx_buffer);
213        let std::net::IpAddr::V4(destination_ip) = destination.ip() else {
214            return false;
215        };
216
217        let listen_endpoint = IpListenEndpoint {
218            addr: Some(destination_ip.into()),
219            port: destination.port(),
220        };
221        if socket.listen(listen_endpoint).is_err() {
222            return false;
223        }
224
225        let handle = sockets.add(socket);
226
227        let (to_proxy_tx, to_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
228        let (from_proxy_tx, from_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
229        let exit_state = RelayExitState::new();
230
231        self.connection_keys.insert((source, destination));
232        self.connections.insert(
233            handle,
234            TrackedConnection {
235                source,
236                destination,
237                to_proxy: to_proxy_tx,
238                from_proxy: from_proxy_rx,
239                pending_proxy_endpoints: Some(PendingProxyEndpoints {
240                    from_smoltcp: to_proxy_rx,
241                    to_smoltcp: from_proxy_tx,
242                    relay_target: RelayTarget::Connect(destination),
243                }),
244                relay_spawned: false,
245                buffered_proxy_data: None,
246                close_attempts: 0,
247                exit_state,
248                reserved_published_port: None,
249            },
250        );
251
252        true
253    }
254
255    /// Create a guest-facing TCP connection for a published host socket.
256    ///
257    /// This is the host->guest mirror of `create_tcp_socket`:
258    ///
259    /// ```text
260    /// host client connects to published port
261    ///   -> host listener accepts TcpStream
262    ///   -> poll loop creates smoltcp TCP socket from gateway_ip:ephemeral
263    ///      to guest_ip:guest_port
264    ///   -> guest kernel sees a normal inbound TCP connection on guest_port
265    /// ```
266    ///
267    /// The guest-visible source address is the gateway IP, not the original
268    /// host peer address. That keeps the first version simple and matches the
269    /// fact that this runtime is acting as a userspace gateway/proxy.
270    pub fn create_published_socket(
271        &mut self,
272        interface: &mut Interface,
273        gateway_ip: Ipv4Addr,
274        destination: SocketAddr,
275        host_stream: TcpStream,
276        sockets: &mut SocketSet<'_>,
277    ) -> bool {
278        if self.connections.len() >= self.max_connections {
279            tracing::warn!("dropping published TCP connection because the relay table is full");
280            return false;
281        }
282
283        let Some(local_port) = self.allocate_published_port() else {
284            tracing::warn!(
285                "dropping published TCP connection because no gateway source port is available"
286            );
287            return false;
288        };
289
290        let std::net::IpAddr::V4(destination_ip) = destination.ip() else {
291            self.used_published_ports.remove(&local_port);
292            return false;
293        };
294
295        let rx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUFFER_BYTES]);
296        let tx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUFFER_BYTES]);
297        let mut socket = tcp::Socket::new(rx_buffer, tx_buffer);
298        let local_endpoint = IpListenEndpoint {
299            addr: Some(gateway_ip.into()),
300            port: local_port,
301        };
302        if socket
303            .connect(
304                interface.context(),
305                (destination_ip, destination.port()),
306                local_endpoint,
307            )
308            .is_err()
309        {
310            self.used_published_ports.remove(&local_port);
311            return false;
312        }
313
314        let handle = sockets.add(socket);
315        let source = SocketAddr::new(std::net::IpAddr::V4(gateway_ip), local_port);
316
317        let (to_proxy_tx, to_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
318        let (from_proxy_tx, from_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
319        let exit_state = RelayExitState::new();
320
321        self.connection_keys.insert((source, destination));
322        self.connections.insert(
323            handle,
324            TrackedConnection {
325                source,
326                destination,
327                to_proxy: to_proxy_tx,
328                from_proxy: from_proxy_rx,
329                pending_proxy_endpoints: Some(PendingProxyEndpoints {
330                    from_smoltcp: to_proxy_rx,
331                    to_smoltcp: from_proxy_tx,
332                    relay_target: RelayTarget::Attached(host_stream),
333                }),
334                relay_spawned: false,
335                buffered_proxy_data: None,
336                close_attempts: 0,
337                exit_state,
338                reserved_published_port: Some(local_port),
339            },
340        );
341
342        true
343    }
344
345    /// Relay TCP payloads between smoltcp sockets and host relay threads.
346    ///
347    /// This runs in the poll thread. It is responsible for:
348    /// - draining bytes received from the guest-facing smoltcp socket and
349    ///   pushing them toward the host relay thread
350    /// - draining bytes received from the host relay thread and writing them
351    ///   back into the smoltcp socket
352    /// - interpreting relay exit state into guest-side `close()` or `abort()`
353    pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
354        let mut read_buffer = [0u8; RELAY_BUFFER_BYTES];
355
356        for (&handle, connection) in &mut self.connections {
357            if !connection.relay_spawned {
358                continue;
359            }
360
361            let socket = sockets.get_mut::<tcp::Socket>(handle);
362
363            match connection.exit_state.load() {
364                RelayExitMode::Abort => {
365                    socket.abort();
366                    continue;
367                }
368                RelayExitMode::Graceful => {
369                    flush_proxy_data(socket, connection);
370                    if connection.buffered_proxy_data.is_none() {
371                        socket.close();
372                    } else {
373                        connection.close_attempts += 1;
374                        if connection.close_attempts >= CLOSE_RETRY_LIMIT {
375                            socket.abort();
376                        }
377                    }
378                    continue;
379                }
380                RelayExitMode::Running => {}
381            }
382
383            while socket.can_recv() {
384                match socket.recv_slice(&mut read_buffer) {
385                    Ok(bytes_read) if bytes_read > 0 => {
386                        let payload = read_buffer[..bytes_read].to_vec();
387                        if connection.to_proxy.try_send(payload).is_err() {
388                            break;
389                        }
390                    }
391                    _ => break,
392                }
393            }
394
395            flush_proxy_data(socket, connection);
396        }
397    }
398
399    /// Collect connections that reached ESTABLISHED and need a host relay thread.
400    ///
401    /// The separation between `create_tcp_socket` and this method is important:
402    /// the guest TCP handshake is accepted first on the smoltcp side, and only
403    /// once that succeeds do we commit to opening the host-side `TcpStream`.
404    pub fn take_new_connections(&mut self, sockets: &mut SocketSet<'_>) -> Vec<NewTcpConnection> {
405        let mut new_connections = Vec::new();
406
407        for (&handle, connection) in &mut self.connections {
408            if connection.relay_spawned {
409                continue;
410            }
411
412            let socket = sockets.get::<tcp::Socket>(handle);
413            if socket.state() == tcp::State::Established {
414                connection.relay_spawned = true;
415
416                if let Some(endpoints) = connection.pending_proxy_endpoints.take() {
417                    new_connections.push(NewTcpConnection {
418                        destination: connection.destination,
419                        relay_target: endpoints.relay_target,
420                        from_smoltcp: endpoints.from_smoltcp,
421                        to_smoltcp: endpoints.to_smoltcp,
422                        exit_state: connection.exit_state.clone(),
423                    });
424                }
425            }
426        }
427
428        new_connections
429    }
430
431    /// Remove closed sockets and drop their relay endpoints.
432    ///
433    /// This is the final ownership cleanup step for a guest TCP flow.
434    pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
435        let keys = &mut self.connection_keys;
436        let published_ports = &mut self.used_published_ports;
437        self.connections.retain(|&handle, connection| {
438            let socket = sockets.get::<tcp::Socket>(handle);
439            if socket.state() == tcp::State::Closed {
440                keys.remove(&(connection.source, connection.destination));
441                if let Some(port) = connection.reserved_published_port {
442                    published_ports.remove(&port);
443                }
444                sockets.remove(handle);
445                false
446            } else {
447                true
448            }
449        });
450    }
451
452    fn allocate_published_port(&mut self) -> Option<u16> {
453        let start = self.next_published_port;
454
455        loop {
456            let candidate = self.next_published_port;
457            self.next_published_port = if candidate == PUBLISHED_PORT_END {
458                PUBLISHED_PORT_START
459            } else {
460                candidate + 1
461            };
462
463            if self.used_published_ports.insert(candidate) {
464                return Some(candidate);
465            }
466
467            if self.next_published_port == start {
468                return None;
469            }
470        }
471    }
472}
473
474/// Spawn one host TCP relay thread for an established guest connection.
475///
476/// Thread responsibilities:
477/// - connect a host `TcpStream` to the guest-requested destination
478/// - copy bytes guest->host from `from_smoltcp`
479/// - copy bytes host->guest into `to_smoltcp`
480/// - wake the poll loop when host->guest data arrives
481/// - report termination mode through `exit_state`
482pub fn spawn_tcp_relay(
483    destination: SocketAddr,
484    relay_target: RelayTarget,
485    from_smoltcp: Receiver<Vec<u8>>,
486    to_smoltcp: SyncSender<Vec<u8>>,
487    relay_wake: Arc<WakePipe>,
488    exit_state: RelayExitState,
489) {
490    let thread_name = format!("smolvm-tcp-{}", destination.port());
491    virtio_net_log!(
492        "virtio-net: spawning host TCP relay thread destination={} thread={}",
493        destination,
494        thread_name
495    );
496    let _ = thread::Builder::new().name(thread_name).spawn(move || {
497        run_tcp_relay(
498            destination,
499            relay_target,
500            from_smoltcp,
501            to_smoltcp,
502            relay_wake,
503            exit_state,
504        )
505    });
506}
507
508fn run_tcp_relay(
509    destination: SocketAddr,
510    relay_target: RelayTarget,
511    from_smoltcp: Receiver<Vec<u8>>,
512    to_smoltcp: SyncSender<Vec<u8>>,
513    relay_wake: Arc<WakePipe>,
514    exit_state: RelayExitState,
515) {
516    // The relay thread is intentionally isolated from smoltcp internals. Its
517    // contract is just channels in, channels out, and an exit code back.
518    virtio_net_log!(
519        "virtio-net: host TCP relay thread started destination={}",
520        destination
521    );
522    match tcp_relay_loop(
523        destination,
524        relay_target,
525        from_smoltcp,
526        to_smoltcp,
527        relay_wake,
528    ) {
529        Ok(mode) => {
530            virtio_net_log!(
531                "virtio-net: host TCP relay thread exited destination={} mode={:?}",
532                destination,
533                mode
534            );
535            exit_state.store(mode)
536        }
537        Err(err) => {
538            virtio_net_log!(
539                "virtio-net: host TCP relay failed destination={} error={}",
540                destination,
541                err
542            );
543            exit_state.store(RelayExitMode::Abort);
544        }
545    }
546}
547
548fn tcp_relay_loop(
549    destination: SocketAddr,
550    relay_target: RelayTarget,
551    from_smoltcp: Receiver<Vec<u8>>,
552    to_smoltcp: SyncSender<Vec<u8>>,
553    relay_wake: Arc<WakePipe>,
554) -> io::Result<RelayExitMode> {
555    // Host-side flow:
556    //
557    // 1. Connect a normal host TcpStream to the destination.
558    // 2. Non-blockingly drain guest payloads from the channel into the socket.
559    // 3. Non-blockingly read remote payloads from the socket into the channel.
560    // 4. If neither side made progress, sleep briefly to avoid a hot spin loop.
561    let mut stream = match relay_target {
562        RelayTarget::Connect(destination) => {
563            virtio_net_log!(
564                "virtio-net: connecting host TCP relay socket destination={}",
565                destination
566            );
567            let stream = TcpStream::connect(destination)?;
568            virtio_net_log!(
569                "virtio-net: host TCP relay socket connected destination={}",
570                destination
571            );
572            stream
573        }
574        RelayTarget::Attached(stream) => {
575            virtio_net_log!(
576                "virtio-net: using accepted host TCP socket for published port guest_destination={} peer_addr={:?} local_addr={:?}",
577                destination,
578                stream.peer_addr().ok(),
579                stream.local_addr().ok()
580            );
581            stream
582        }
583    };
584    stream.set_nonblocking(true)?;
585
586    let mut guest_write_closed = false;
587    let mut read_buffer = [0u8; RELAY_BUFFER_BYTES];
588
589    loop {
590        let mut did_work = false;
591
592        loop {
593            match from_smoltcp.try_recv() {
594                Ok(payload) => {
595                    stream.write_all(&payload)?;
596                    did_work = true;
597                }
598                Err(TryRecvError::Empty) => break,
599                Err(TryRecvError::Disconnected) => {
600                    // The guest side closed its write half. Mirror that toward
601                    // the remote peer once, then keep reading until the remote
602                    // side closes too.
603                    if !guest_write_closed {
604                        let _ = stream.shutdown(Shutdown::Write);
605                        guest_write_closed = true;
606                    }
607                    break;
608                }
609            }
610        }
611
612        match stream.read(&mut read_buffer) {
613            Ok(0) => return Ok(RelayExitMode::Graceful),
614            Ok(bytes_read) => {
615                if to_smoltcp.send(read_buffer[..bytes_read].to_vec()).is_err() {
616                    return Ok(RelayExitMode::Graceful);
617                }
618                relay_wake.wake();
619                did_work = true;
620            }
621            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
622            Err(err) => return Err(err),
623        }
624
625        if !did_work {
626            thread::sleep(PROXY_IDLE_SLEEP);
627        }
628    }
629}
630
631fn flush_proxy_data(socket: &mut tcp::Socket<'_>, connection: &mut TrackedConnection) {
632    // smoltcp send windows may accept only part of an inbound host payload.
633    // `buffered_proxy_data` remembers the unwritten remainder so the next poll
634    // iteration can continue where it left off instead of dropping bytes.
635    if let Some((data, offset)) = &mut connection.buffered_proxy_data {
636        if socket.can_send() {
637            match socket.send_slice(&data[*offset..]) {
638                Ok(written) => {
639                    *offset += written;
640                    if *offset >= data.len() {
641                        connection.buffered_proxy_data = None;
642                    }
643                }
644                Err(_) => return,
645            }
646        } else {
647            return;
648        }
649    }
650
651    while connection.buffered_proxy_data.is_none() {
652        match connection.from_proxy.try_recv() {
653            Ok(payload) => {
654                if socket.can_send() {
655                    match socket.send_slice(&payload) {
656                        Ok(written) if written < payload.len() => {
657                            connection.buffered_proxy_data = Some((payload, written));
658                        }
659                        Err(_) => {
660                            connection.buffered_proxy_data = Some((payload, 0));
661                        }
662                        _ => {}
663                    }
664                } else {
665                    connection.buffered_proxy_data = Some((payload, 0));
666                }
667            }
668            Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
669        }
670    }
671}