vm_rs/network/
port_forward.rs1use std::net::SocketAddr;
7use std::sync::Arc;
8
9use tokio::net::{TcpListener, TcpStream};
10use tokio::sync::Notify;
11
12pub struct PortForwarder {
14 stop: Arc<Notify>,
15 handle: tokio::task::JoinHandle<()>,
16 pub bind_addr: SocketAddr,
18 pub host_port: u16,
20 pub target: SocketAddr,
22}
23
24impl PortForwarder {
25 pub async fn start(
27 host_port: u16,
28 target_ip: &str,
29 target_port: u16,
30 ) -> Result<Self, PortForwardError> {
31 Self::start_on("127.0.0.1", host_port, target_ip, target_port).await
32 }
33
34 pub async fn start_on(
36 bind_ip: &str,
37 host_port: u16,
38 target_ip: &str,
39 target_port: u16,
40 ) -> Result<Self, PortForwardError> {
41 let bind_addr: SocketAddr = format!("{}:{}", bind_ip, host_port)
42 .parse()
43 .map_err(|e| PortForwardError::InvalidBindAddress(format!("{}", e)))?;
44 let target: SocketAddr = format!("{}:{}", target_ip, target_port)
45 .parse()
46 .map_err(|e| PortForwardError::InvalidTarget(format!("{}", e)))?;
47
48 let listener =
49 TcpListener::bind(bind_addr)
50 .await
51 .map_err(|e| PortForwardError::BindFailed {
52 address: bind_addr,
53 detail: format!("{}", e),
54 })?;
55
56 tracing::info!(bind = %bind_addr, target = %target, "port forwarder started");
57
58 let stop = Arc::new(Notify::new());
59 let stop_clone = Arc::clone(&stop);
60
61 let handle = tokio::spawn(async move {
62 loop {
63 tokio::select! {
64 result = listener.accept() => {
65 match result {
66 Ok((client, _)) => {
67 tokio::spawn(async move {
68 proxy(client, target).await;
69 });
70 }
71 Err(e) => {
72 tracing::error!("port forwarder accept error: {}", e);
73 break;
74 }
75 }
76 }
77 _ = stop_clone.notified() => break,
78 }
79 }
80 });
81
82 Ok(PortForwarder {
83 stop,
84 handle,
85 bind_addr,
86 host_port,
87 target,
88 })
89 }
90
91 pub fn stop(self) {
93 self.stop.notify_one();
94 self.handle.abort();
95 }
96}
97
98async fn proxy(mut client: TcpStream, target: SocketAddr) {
100 let mut server = match tokio::time::timeout(
101 std::time::Duration::from_secs(5),
102 TcpStream::connect(target),
103 )
104 .await
105 {
106 Ok(Ok(s)) => s,
107 Ok(Err(e)) => {
108 tracing::warn!("port forward connect failed to {}: {}", target, e);
109 return;
110 }
111 Err(_) => {
112 tracing::warn!("port forward connect timeout to {}", target);
113 return;
114 }
115 };
116
117 if let Err(e) = tokio::io::copy_bidirectional(&mut client, &mut server).await {
118 tracing::warn!("port forward proxy error: {}", e);
119 }
120}
121
122#[derive(Debug, thiserror::Error)]
124pub enum PortForwardError {
125 #[error("invalid bind address: {0}")]
126 InvalidBindAddress(String),
127
128 #[error("invalid target address: {0}")]
129 InvalidTarget(String),
130
131 #[error("cannot bind {address}: {detail}")]
132 BindFailed { address: SocketAddr, detail: String },
133}