Skip to main content

xbp_cli/commands/
cloudflared_access.rs

1use std::net::Ipv4Addr;
2use std::path::PathBuf;
3use std::process::Stdio;
4
5use tokio::net::TcpStream;
6use tokio::process::{Child, Command};
7use tokio::signal;
8use tokio::time::{sleep, Duration, Instant};
9use tracing::debug;
10
/// Maximum time spent probing the local listener before giving up on cloudflared.
const CLOUDFLARED_READY_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Delay between two consecutive readiness probes of the local listener.
const CLOUDFLARED_READY_POLL_INTERVAL: Duration = Duration::from_millis(120);
13
/// User-supplied settings for launching a `cloudflared access tcp` forwarder.
#[derive(Clone, Debug)]
pub struct CloudflaredTcpOptions {
    /// Access-protected hostname, passed to cloudflared as `--hostname`.
    pub hostname: String,
    /// Local `host:port` for the forwarder (`--url`); `None` lets a free loopback port be picked.
    pub listener: Option<String>,
    /// Optional value for `--destination`; blank values are dropped before use.
    pub destination: Option<String>,
    /// Explicit path to the cloudflared executable; `None` means plain `cloudflared`.
    pub binary_path: Option<PathBuf>,
}
21
22pub struct CloudflaredTunnel {
23    child: Child,
24    pub hostname: String,
25    pub listener_addr: String,
26    pub local_port: u16,
27}
28
29impl CloudflaredTunnel {
30    pub async fn start(options: CloudflaredTcpOptions, debug_mode: bool) -> Result<Self, String> {
31        let listener_addr = match options.listener {
32            Some(listener) => listener,
33            None => format!("127.0.0.1:{}", reserve_local_port()?),
34        };
35        let local_port = parse_listener_port(&listener_addr)?;
36        let binary = options
37            .binary_path
38            .unwrap_or_else(|| PathBuf::from("cloudflared"));
39
40        let mut command = Command::new(&binary);
41        command
42            .arg("access")
43            .arg("tcp")
44            .arg("--hostname")
45            .arg(&options.hostname)
46            .arg("--url")
47            .arg(&listener_addr)
48            .stdin(Stdio::null())
49            .stdout(Stdio::null())
50            .stderr(Stdio::inherit());
51
52        if let Some(destination) = options.destination.and_then(normalize_optional_string) {
53            command.arg("--destination").arg(destination);
54        }
55
56        if debug_mode {
57            debug!(
58                "Starting cloudflared => bin: {}, hostname: {}, listen: {}",
59                binary.display(),
60                options.hostname,
61                listener_addr
62            );
63        }
64
65        let mut child = command
66            .spawn()
67            .map_err(|e| format!("Failed to start cloudflared: {}", e))?;
68        wait_for_forwarder(&mut child, &listener_addr).await?;
69
70        Ok(Self {
71            child,
72            hostname: options.hostname,
73            listener_addr,
74            local_port,
75        })
76    }
77
78    pub async fn shutdown(&mut self) {
79        match self.child.try_wait() {
80            Ok(Some(_)) => {}
81            Ok(None) => {
82                let _ = self.child.kill().await;
83                let _ = self.child.wait().await;
84            }
85            Err(_) => {}
86        }
87    }
88}
89
90pub async fn run_cloudflared_tcp(
91    options: CloudflaredTcpOptions,
92    debug_mode: bool,
93) -> Result<(), String> {
94    let mut tunnel = CloudflaredTunnel::start(options, debug_mode).await?;
95
96    println!("cloudflared tcp ready on {}", tunnel.listener_addr);
97    println!("Press Ctrl+C to stop the forwarder.");
98
99    let ctrl_c = signal::ctrl_c().await;
100    tunnel.shutdown().await;
101
102    ctrl_c.map_err(|e| format!("Failed to wait for Ctrl+C: {}", e))?;
103    Ok(())
104}
105
/// Extracts the TCP port from a `host:port` listener address.
///
/// The address is split on its last `:`, so bracketed IPv6 literals such as `[::1]:2222`
/// work. Port `0` is rejected. The readiness probe connects to the configured address
/// verbatim, so an OS-assigned port could never be probed successfully.
///
/// # Errors
///
/// Returns a message when the address has no `:`, the port is not a valid `u16`, or the
/// port is `0`.
fn parse_listener_port(listener_addr: &str) -> Result<u16, String> {
    let (_, raw_port) = listener_addr
        .rsplit_once(':')
        .ok_or_else(|| format!("Invalid listener address `{listener_addr}`"))?;
    let port = raw_port
        .parse::<u16>()
        .map_err(|e| format!("Invalid listener port in `{listener_addr}`: {}", e))?;
    if port == 0 {
        return Err(format!(
            "Invalid listener port in `{listener_addr}`: port must be non-zero"
        ));
    }
    Ok(port)
}
114
/// Asks the OS for a currently free TCP port on `127.0.0.1` and returns its number.
///
/// The temporary listener is dropped when this function returns, which releases the port.
/// Another process could therefore claim it before cloudflared binds to it.
fn reserve_local_port() -> Result<u16, String> {
    let probe = match std::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0)) {
        Ok(socket) => socket,
        Err(e) => return Err(format!("Failed to reserve local port: {}", e)),
    };
    let bound = probe
        .local_addr()
        .map_err(|e| format!("Failed to inspect local listener address: {}", e))?;
    Ok(bound.port())
}
123
124async fn wait_for_forwarder(child: &mut Child, listener_addr: &str) -> Result<(), String> {
125    let deadline = Instant::now() + CLOUDFLARED_READY_TIMEOUT;
126
127    loop {
128        if let Ok(stream) = TcpStream::connect(listener_addr).await {
129            drop(stream);
130            return Ok(());
131        }
132
133        match child.try_wait() {
134            Ok(Some(status)) => {
135                return Err(format!(
136                    "cloudflared exited before opening local tunnel (status: {})",
137                    status
138                ));
139            }
140            Ok(None) => {}
141            Err(err) => return Err(format!("Failed to inspect cloudflared process: {}", err)),
142        }
143
144        if Instant::now() >= deadline {
145            let _ = child.kill().await;
146            let _ = child.wait().await;
147            return Err(format!(
148                "Timed out waiting for cloudflared to open {}",
149                listener_addr
150            ));
151        }
152
153        sleep(CLOUDFLARED_READY_POLL_INTERVAL).await;
154    }
155}
156
/// Trims surrounding whitespace and maps an all-whitespace (or empty) string to `None`.
fn normalize_optional_string(value: String) -> Option<String> {
    let core = value.trim();
    (!core.is_empty()).then(|| core.to_owned())
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain `ip:port` pair yields its numeric port.
    #[test]
    fn parses_listener_port_from_host_port() {
        let port = parse_listener_port("127.0.0.1:2222").expect("well-formed host:port");
        assert_eq!(port, 2222);
    }

    /// An address with no `:port` suffix is an error, not a default.
    #[test]
    fn rejects_invalid_listener_address() {
        let outcome = parse_listener_port("localhost");
        assert!(outcome.is_err(), "address without a port must be rejected");
    }
}