Skip to main content

vm_rs/network/
port_forward.rs

//! TCP port forwarding from host to VM.
//!
//! When a service publishes ports (e.g., "8080:80" — host port 8080 to VM port 80),
//! we bind a TCP listener on the host (loopback by default) and proxy each accepted
//! connection to the VM's IP and target port.
5
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use tokio::net::{TcpListener, TcpStream};
10use tokio::sync::Notify;
11
12/// A running port forwarder. Proxies TCP connections from a host port to a VM port.
13pub struct PortForwarder {
14    stop: Arc<Notify>,
15    handle: tokio::task::JoinHandle<()>,
16    /// The host address being listened on.
17    pub bind_addr: SocketAddr,
18    /// The host port being listened on.
19    pub host_port: u16,
20    /// The target address (VM IP + port).
21    pub target: SocketAddr,
22}
23
24impl PortForwarder {
25    /// Start forwarding `host_port` on loopback to `target_ip:target_port`.
26    pub async fn start(
27        host_port: u16,
28        target_ip: &str,
29        target_port: u16,
30    ) -> Result<Self, PortForwardError> {
31        Self::start_on("127.0.0.1", host_port, target_ip, target_port).await
32    }
33
34    /// Start forwarding `host_port` on a specific host bind address.
35    pub async fn start_on(
36        bind_ip: &str,
37        host_port: u16,
38        target_ip: &str,
39        target_port: u16,
40    ) -> Result<Self, PortForwardError> {
41        let bind_addr: SocketAddr = format!("{}:{}", bind_ip, host_port)
42            .parse()
43            .map_err(|e| PortForwardError::InvalidBindAddress(format!("{}", e)))?;
44        let target: SocketAddr = format!("{}:{}", target_ip, target_port)
45            .parse()
46            .map_err(|e| PortForwardError::InvalidTarget(format!("{}", e)))?;
47
48        let listener =
49            TcpListener::bind(bind_addr)
50                .await
51                .map_err(|e| PortForwardError::BindFailed {
52                    address: bind_addr,
53                    detail: format!("{}", e),
54                })?;
55
56        tracing::info!(bind = %bind_addr, target = %target, "port forwarder started");
57
58        let stop = Arc::new(Notify::new());
59        let stop_clone = Arc::clone(&stop);
60
61        let handle = tokio::spawn(async move {
62            loop {
63                tokio::select! {
64                    result = listener.accept() => {
65                        match result {
66                            Ok((client, _)) => {
67                                tokio::spawn(async move {
68                                    proxy(client, target).await;
69                                });
70                            }
71                            Err(e) => {
72                                tracing::error!("port forwarder accept error: {}", e);
73                                break;
74                            }
75                        }
76                    }
77                    _ = stop_clone.notified() => break,
78                }
79            }
80        });
81
82        Ok(PortForwarder {
83            stop,
84            handle,
85            bind_addr,
86            host_port,
87            target,
88        })
89    }
90
91    /// Stop forwarding and clean up.
92    pub fn stop(self) {
93        self.stop.notify_one();
94        self.handle.abort();
95    }
96}
97
98/// Proxy TCP traffic bidirectionally between client and server.
99async fn proxy(mut client: TcpStream, target: SocketAddr) {
100    let mut server = match tokio::time::timeout(
101        std::time::Duration::from_secs(5),
102        TcpStream::connect(target),
103    )
104    .await
105    {
106        Ok(Ok(s)) => s,
107        Ok(Err(e)) => {
108            tracing::warn!("port forward connect failed to {}: {}", target, e);
109            return;
110        }
111        Err(_) => {
112            tracing::warn!("port forward connect timeout to {}", target);
113            return;
114        }
115    };
116
117    if let Err(e) = tokio::io::copy_bidirectional(&mut client, &mut server).await {
118        tracing::warn!("port forward proxy error: {}", e);
119    }
120}
121
122/// Port forwarding errors.
123#[derive(Debug, thiserror::Error)]
124pub enum PortForwardError {
125    #[error("invalid bind address: {0}")]
126    InvalidBindAddress(String),
127
128    #[error("invalid target address: {0}")]
129    InvalidTarget(String),
130
131    #[error("cannot bind {address}: {detail}")]
132    BindFailed { address: SocketAddr, detail: String },
133}