1use crate::queues::WakePipe;
30use crate::virtio_net_log;
31use smoltcp::iface::{Interface, SocketHandle, SocketSet};
32use smoltcp::socket::tcp;
33use smoltcp::wire::IpListenEndpoint;
34use std::collections::{HashMap, HashSet};
35use std::io::{self, Read, Write};
36use std::net::{Ipv4Addr, Shutdown, SocketAddr, TcpStream};
37use std::sync::atomic::{AtomicU8, Ordering};
38use std::sync::mpsc::{self, Receiver, SyncSender, TryRecvError};
39use std::sync::Arc;
40use std::thread;
41use std::time::Duration;
42
/// Size of each smoltcp TCP socket's receive buffer, in bytes.
const TCP_RX_BUFFER_BYTES: usize = 64 * 1024;
/// Size of each smoltcp TCP socket's transmit buffer, in bytes.
const TCP_TX_BUFFER_BYTES: usize = 64 * 1024;
/// Default cap on tracked connections when `TcpRelayTable::new` receives `None`.
const MAX_CONNECTIONS: usize = 256;
/// Capacity, in payload messages (not bytes), of each relay channel direction.
const CHANNEL_CAPACITY: usize = 32;
/// Size of the scratch buffers used to move payloads between the smoltcp socket and the relay.
const RELAY_BUFFER_BYTES: usize = 16 * 1024;
/// Graceful-close polls that may still find host data buffered before the guest socket is reset.
const CLOSE_RETRY_LIMIT: u16 = 64;
/// How long a relay thread sleeps after a loop iteration that made no progress.
const PROXY_IDLE_SLEEP: Duration = Duration::from_millis(10);
/// First gateway source port handed out to published-port connections.
const PUBLISHED_PORT_START: u16 = 49_152;
/// Last (inclusive) gateway source port handed out to published-port connections.
const PUBLISHED_PORT_END: u16 = 65_535;
52
/// Tracks guest-facing smoltcp TCP sockets and the channels that connect each
/// of them to a host-side relay.
pub struct TcpRelayTable {
    /// Tracked connections keyed by their smoltcp socket handle.
    connections: HashMap<SocketHandle, TrackedConnection>,
    /// `(source, destination)` pairs that already have a tracked socket.
    connection_keys: HashSet<(SocketAddr, SocketAddr)>,
    /// Gateway source ports currently reserved by published connections.
    used_published_ports: HashSet<u16>,
    /// Next candidate for the wrap-around published-port allocator.
    next_published_port: u16,
    /// Upper bound on the number of tracked connections.
    max_connections: usize,
}
64
/// Relay-side endpoints for a connection that is ready for a relay, as returned
/// by `TcpRelayTable::take_new_connections`. The fields mirror the parameters
/// of `spawn_tcp_relay`.
pub struct NewTcpConnection {
    /// Destination address of the connection.
    pub destination: SocketAddr,
    /// Host side to relay to: a fresh connect, or an already-accepted stream.
    pub relay_target: RelayTarget,
    /// Receives guest -> host payloads read from the smoltcp socket.
    pub from_smoltcp: Receiver<Vec<u8>>,
    /// Sends host -> guest payloads to be written into the smoltcp socket.
    pub to_smoltcp: SyncSender<Vec<u8>>,
    /// Exit state shared with the table's tracked connection.
    pub exit_state: RelayExitState,
}
82
83#[derive(Debug)]
84struct TrackedConnection {
85 source: SocketAddr,
87 destination: SocketAddr,
88 to_proxy: SyncSender<Vec<u8>>,
90 from_proxy: Receiver<Vec<u8>>,
92 pending_proxy_endpoints: Option<PendingProxyEndpoints>,
94 relay_spawned: bool,
96 buffered_proxy_data: Option<(Vec<u8>, usize)>,
98 close_attempts: u16,
100 exit_state: RelayExitState,
102 reserved_published_port: Option<u16>,
104}
105
106#[derive(Debug)]
107struct PendingProxyEndpoints {
108 from_smoltcp: Receiver<Vec<u8>>,
109 to_smoltcp: SyncSender<Vec<u8>>,
110 relay_target: RelayTarget,
111}
112
113#[derive(Debug)]
115pub enum RelayTarget {
116 Connect(SocketAddr),
118 Attached(TcpStream),
120}
121
122#[derive(Clone, Debug)]
130pub struct RelayExitState {
131 inner: Arc<AtomicU8>,
132}
133
134#[derive(Clone, Copy, Debug, PartialEq, Eq)]
136#[repr(u8)]
137pub enum RelayExitMode {
138 Running = 0,
140 Graceful = 1,
142 Abort = 2,
144}
145
146impl RelayExitState {
147 fn new() -> Self {
148 Self {
149 inner: Arc::new(AtomicU8::new(RelayExitMode::Running as u8)),
150 }
151 }
152
153 fn load(&self) -> RelayExitMode {
154 match self.inner.load(Ordering::Relaxed) {
155 1 => RelayExitMode::Graceful,
156 2 => RelayExitMode::Abort,
157 _ => RelayExitMode::Running,
158 }
159 }
160
161 fn store(&self, mode: RelayExitMode) {
162 self.inner.store(mode as u8, Ordering::Relaxed);
163 }
164}
165
166impl TcpRelayTable {
167 pub fn new(max_connections: Option<usize>) -> Self {
169 Self {
170 connections: HashMap::new(),
171 connection_keys: HashSet::new(),
172 used_published_ports: HashSet::new(),
173 next_published_port: PUBLISHED_PORT_START,
174 max_connections: max_connections.unwrap_or(MAX_CONNECTIONS),
175 }
176 }
177
178 pub fn has_socket_for(&self, source: &SocketAddr, destination: &SocketAddr) -> bool {
180 self.connection_keys.contains(&(*source, *destination))
181 }
182
183 pub fn create_tcp_socket(
200 &mut self,
201 source: SocketAddr,
202 destination: SocketAddr,
203 sockets: &mut SocketSet<'_>,
204 ) -> bool {
205 if self.connections.len() >= self.max_connections {
206 tracing::warn!("dropping TCP connection because the relay table is full");
207 return false;
208 }
209
210 let rx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUFFER_BYTES]);
211 let tx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUFFER_BYTES]);
212 let mut socket = tcp::Socket::new(rx_buffer, tx_buffer);
213 let std::net::IpAddr::V4(destination_ip) = destination.ip() else {
214 return false;
215 };
216
217 let listen_endpoint = IpListenEndpoint {
218 addr: Some(destination_ip.into()),
219 port: destination.port(),
220 };
221 if socket.listen(listen_endpoint).is_err() {
222 return false;
223 }
224
225 let handle = sockets.add(socket);
226
227 let (to_proxy_tx, to_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
228 let (from_proxy_tx, from_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
229 let exit_state = RelayExitState::new();
230
231 self.connection_keys.insert((source, destination));
232 self.connections.insert(
233 handle,
234 TrackedConnection {
235 source,
236 destination,
237 to_proxy: to_proxy_tx,
238 from_proxy: from_proxy_rx,
239 pending_proxy_endpoints: Some(PendingProxyEndpoints {
240 from_smoltcp: to_proxy_rx,
241 to_smoltcp: from_proxy_tx,
242 relay_target: RelayTarget::Connect(destination),
243 }),
244 relay_spawned: false,
245 buffered_proxy_data: None,
246 close_attempts: 0,
247 exit_state,
248 reserved_published_port: None,
249 },
250 );
251
252 true
253 }
254
255 pub fn create_published_socket(
271 &mut self,
272 interface: &mut Interface,
273 gateway_ip: Ipv4Addr,
274 destination: SocketAddr,
275 host_stream: TcpStream,
276 sockets: &mut SocketSet<'_>,
277 ) -> bool {
278 if self.connections.len() >= self.max_connections {
279 tracing::warn!("dropping published TCP connection because the relay table is full");
280 return false;
281 }
282
283 let Some(local_port) = self.allocate_published_port() else {
284 tracing::warn!(
285 "dropping published TCP connection because no gateway source port is available"
286 );
287 return false;
288 };
289
290 let std::net::IpAddr::V4(destination_ip) = destination.ip() else {
291 self.used_published_ports.remove(&local_port);
292 return false;
293 };
294
295 let rx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUFFER_BYTES]);
296 let tx_buffer = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUFFER_BYTES]);
297 let mut socket = tcp::Socket::new(rx_buffer, tx_buffer);
298 let local_endpoint = IpListenEndpoint {
299 addr: Some(gateway_ip.into()),
300 port: local_port,
301 };
302 if socket
303 .connect(
304 interface.context(),
305 (destination_ip, destination.port()),
306 local_endpoint,
307 )
308 .is_err()
309 {
310 self.used_published_ports.remove(&local_port);
311 return false;
312 }
313
314 let handle = sockets.add(socket);
315 let source = SocketAddr::new(std::net::IpAddr::V4(gateway_ip), local_port);
316
317 let (to_proxy_tx, to_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
318 let (from_proxy_tx, from_proxy_rx) = mpsc::sync_channel(CHANNEL_CAPACITY);
319 let exit_state = RelayExitState::new();
320
321 self.connection_keys.insert((source, destination));
322 self.connections.insert(
323 handle,
324 TrackedConnection {
325 source,
326 destination,
327 to_proxy: to_proxy_tx,
328 from_proxy: from_proxy_rx,
329 pending_proxy_endpoints: Some(PendingProxyEndpoints {
330 from_smoltcp: to_proxy_rx,
331 to_smoltcp: from_proxy_tx,
332 relay_target: RelayTarget::Attached(host_stream),
333 }),
334 relay_spawned: false,
335 buffered_proxy_data: None,
336 close_attempts: 0,
337 exit_state,
338 reserved_published_port: Some(local_port),
339 },
340 );
341
342 true
343 }
344
345 pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
354 let mut read_buffer = [0u8; RELAY_BUFFER_BYTES];
355
356 for (&handle, connection) in &mut self.connections {
357 if !connection.relay_spawned {
358 continue;
359 }
360
361 let socket = sockets.get_mut::<tcp::Socket>(handle);
362
363 match connection.exit_state.load() {
364 RelayExitMode::Abort => {
365 socket.abort();
366 continue;
367 }
368 RelayExitMode::Graceful => {
369 flush_proxy_data(socket, connection);
370 if connection.buffered_proxy_data.is_none() {
371 socket.close();
372 } else {
373 connection.close_attempts += 1;
374 if connection.close_attempts >= CLOSE_RETRY_LIMIT {
375 socket.abort();
376 }
377 }
378 continue;
379 }
380 RelayExitMode::Running => {}
381 }
382
383 while socket.can_recv() {
384 match socket.recv_slice(&mut read_buffer) {
385 Ok(bytes_read) if bytes_read > 0 => {
386 let payload = read_buffer[..bytes_read].to_vec();
387 if connection.to_proxy.try_send(payload).is_err() {
388 break;
389 }
390 }
391 _ => break,
392 }
393 }
394
395 flush_proxy_data(socket, connection);
396 }
397 }
398
399 pub fn take_new_connections(&mut self, sockets: &mut SocketSet<'_>) -> Vec<NewTcpConnection> {
405 let mut new_connections = Vec::new();
406
407 for (&handle, connection) in &mut self.connections {
408 if connection.relay_spawned {
409 continue;
410 }
411
412 let socket = sockets.get::<tcp::Socket>(handle);
413 if socket.state() == tcp::State::Established {
414 connection.relay_spawned = true;
415
416 if let Some(endpoints) = connection.pending_proxy_endpoints.take() {
417 new_connections.push(NewTcpConnection {
418 destination: connection.destination,
419 relay_target: endpoints.relay_target,
420 from_smoltcp: endpoints.from_smoltcp,
421 to_smoltcp: endpoints.to_smoltcp,
422 exit_state: connection.exit_state.clone(),
423 });
424 }
425 }
426 }
427
428 new_connections
429 }
430
431 pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
435 let keys = &mut self.connection_keys;
436 let published_ports = &mut self.used_published_ports;
437 self.connections.retain(|&handle, connection| {
438 let socket = sockets.get::<tcp::Socket>(handle);
439 if socket.state() == tcp::State::Closed {
440 keys.remove(&(connection.source, connection.destination));
441 if let Some(port) = connection.reserved_published_port {
442 published_ports.remove(&port);
443 }
444 sockets.remove(handle);
445 false
446 } else {
447 true
448 }
449 });
450 }
451
452 fn allocate_published_port(&mut self) -> Option<u16> {
453 let start = self.next_published_port;
454
455 loop {
456 let candidate = self.next_published_port;
457 self.next_published_port = if candidate == PUBLISHED_PORT_END {
458 PUBLISHED_PORT_START
459 } else {
460 candidate + 1
461 };
462
463 if self.used_published_ports.insert(candidate) {
464 return Some(candidate);
465 }
466
467 if self.next_published_port == start {
468 return None;
469 }
470 }
471 }
472}
473
474pub fn spawn_tcp_relay(
483 destination: SocketAddr,
484 relay_target: RelayTarget,
485 from_smoltcp: Receiver<Vec<u8>>,
486 to_smoltcp: SyncSender<Vec<u8>>,
487 relay_wake: Arc<WakePipe>,
488 exit_state: RelayExitState,
489) {
490 let thread_name = format!("smolvm-tcp-{}", destination.port());
491 virtio_net_log!(
492 "virtio-net: spawning host TCP relay thread destination={} thread={}",
493 destination,
494 thread_name
495 );
496 let _ = thread::Builder::new().name(thread_name).spawn(move || {
497 run_tcp_relay(
498 destination,
499 relay_target,
500 from_smoltcp,
501 to_smoltcp,
502 relay_wake,
503 exit_state,
504 )
505 });
506}
507
508fn run_tcp_relay(
509 destination: SocketAddr,
510 relay_target: RelayTarget,
511 from_smoltcp: Receiver<Vec<u8>>,
512 to_smoltcp: SyncSender<Vec<u8>>,
513 relay_wake: Arc<WakePipe>,
514 exit_state: RelayExitState,
515) {
516 virtio_net_log!(
519 "virtio-net: host TCP relay thread started destination={}",
520 destination
521 );
522 match tcp_relay_loop(
523 destination,
524 relay_target,
525 from_smoltcp,
526 to_smoltcp,
527 relay_wake,
528 ) {
529 Ok(mode) => {
530 virtio_net_log!(
531 "virtio-net: host TCP relay thread exited destination={} mode={:?}",
532 destination,
533 mode
534 );
535 exit_state.store(mode)
536 }
537 Err(err) => {
538 virtio_net_log!(
539 "virtio-net: host TCP relay failed destination={} error={}",
540 destination,
541 err
542 );
543 exit_state.store(RelayExitMode::Abort);
544 }
545 }
546}
547
548fn tcp_relay_loop(
549 destination: SocketAddr,
550 relay_target: RelayTarget,
551 from_smoltcp: Receiver<Vec<u8>>,
552 to_smoltcp: SyncSender<Vec<u8>>,
553 relay_wake: Arc<WakePipe>,
554) -> io::Result<RelayExitMode> {
555 let mut stream = match relay_target {
562 RelayTarget::Connect(destination) => {
563 virtio_net_log!(
564 "virtio-net: connecting host TCP relay socket destination={}",
565 destination
566 );
567 let stream = TcpStream::connect(destination)?;
568 virtio_net_log!(
569 "virtio-net: host TCP relay socket connected destination={}",
570 destination
571 );
572 stream
573 }
574 RelayTarget::Attached(stream) => {
575 virtio_net_log!(
576 "virtio-net: using accepted host TCP socket for published port guest_destination={} peer_addr={:?} local_addr={:?}",
577 destination,
578 stream.peer_addr().ok(),
579 stream.local_addr().ok()
580 );
581 stream
582 }
583 };
584 stream.set_nonblocking(true)?;
585
586 let mut guest_write_closed = false;
587 let mut read_buffer = [0u8; RELAY_BUFFER_BYTES];
588
589 loop {
590 let mut did_work = false;
591
592 loop {
593 match from_smoltcp.try_recv() {
594 Ok(payload) => {
595 stream.write_all(&payload)?;
596 did_work = true;
597 }
598 Err(TryRecvError::Empty) => break,
599 Err(TryRecvError::Disconnected) => {
600 if !guest_write_closed {
604 let _ = stream.shutdown(Shutdown::Write);
605 guest_write_closed = true;
606 }
607 break;
608 }
609 }
610 }
611
612 match stream.read(&mut read_buffer) {
613 Ok(0) => return Ok(RelayExitMode::Graceful),
614 Ok(bytes_read) => {
615 if to_smoltcp.send(read_buffer[..bytes_read].to_vec()).is_err() {
616 return Ok(RelayExitMode::Graceful);
617 }
618 relay_wake.wake();
619 did_work = true;
620 }
621 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
622 Err(err) => return Err(err),
623 }
624
625 if !did_work {
626 thread::sleep(PROXY_IDLE_SLEEP);
627 }
628 }
629}
630
631fn flush_proxy_data(socket: &mut tcp::Socket<'_>, connection: &mut TrackedConnection) {
632 if let Some((data, offset)) = &mut connection.buffered_proxy_data {
636 if socket.can_send() {
637 match socket.send_slice(&data[*offset..]) {
638 Ok(written) => {
639 *offset += written;
640 if *offset >= data.len() {
641 connection.buffered_proxy_data = None;
642 }
643 }
644 Err(_) => return,
645 }
646 } else {
647 return;
648 }
649 }
650
651 while connection.buffered_proxy_data.is_none() {
652 match connection.from_proxy.try_recv() {
653 Ok(payload) => {
654 if socket.can_send() {
655 match socket.send_slice(&payload) {
656 Ok(written) if written < payload.len() => {
657 connection.buffered_proxy_data = Some((payload, written));
658 }
659 Err(_) => {
660 connection.buffered_proxy_data = Some((payload, 0));
661 }
662 _ => {}
663 }
664 } else {
665 connection.buffered_proxy_data = Some((payload, 0));
666 }
667 }
668 Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
669 }
670 }
671}