1use std::collections::{HashMap, VecDeque};
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use once_cell::sync::Lazy;
9use parking_lot::Mutex;
10use regex::Regex;
11use tokio::io::{AsyncBufReadExt, BufReader};
12use tokio::process::Child;
13use tokio::task::JoinHandle;
14
15use crate::shared_client::DEFAULT_CONNECT_TIMEOUT_SECS;
16use crate::tunnels::errors::TunnelError;
17use crate::urls::{backend_url_api, backend_url_base, join_url};
18
19static URL_RE: Lazy<Regex> =
20 Lazy::new(|| Regex::new(r"https://[a-z0-9-]+\\.trycloudflare\\.com").unwrap());
21
22const CLOUDFLARED_RELEASES: &str = "https://updatecloudflared.com/launcher";
23
24#[derive(Debug)]
25pub struct ManagedProcess {
26 pub child: Child,
27 pub logs: Arc<Mutex<VecDeque<String>>>,
28 stdout_task: Option<JoinHandle<()>>,
29 stderr_task: Option<JoinHandle<()>>,
30}
31
32impl ManagedProcess {
33 async fn stop(&mut self) {
34 let _ = self.child.start_kill();
35 let _ = self.child.wait().await;
36 if let Some(task) = self.stdout_task.take() {
37 task.abort();
38 }
39 if let Some(task) = self.stderr_task.take() {
40 task.abort();
41 }
42 }
43}
44
45static TRACKED: Lazy<Mutex<HashMap<usize, ManagedProcess>>> =
46 Lazy::new(|| Mutex::new(HashMap::new()));
47static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
48
49pub fn track_process(proc: ManagedProcess) -> usize {
50 let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
51 TRACKED.lock().insert(id, proc);
52 id
53}
54
55pub async fn stop_tracked(id: usize) -> Result<(), TunnelError> {
56 let mut guard = TRACKED.lock();
57 if let Some(mut proc) = guard.remove(&id) {
58 proc.stop().await;
59 return Ok(());
60 }
61 Err(TunnelError::process(format!("process id {id} not found")))
62}
63
64pub async fn cleanup_all() {
65 let mut procs = TRACKED.lock();
66 for (_, proc) in procs.iter_mut() {
67 proc.stop().await;
68 }
69 procs.clear();
70}
71
72fn synth_bin_dir() -> Result<PathBuf, TunnelError> {
73 let home = std::env::var("HOME").map_err(|_| TunnelError::config("HOME not set"))?;
74 Ok(Path::new(&home).join(".synth").join("bin"))
75}
76
77pub fn get_cloudflared_path(prefer_system: bool) -> Option<PathBuf> {
78 if let Ok(dir) = synth_bin_dir() {
79 let candidate = dir.join("cloudflared");
80 if candidate.exists() {
81 return Some(candidate);
82 }
83 }
84 if prefer_system {
85 if let Ok(path) = which::which("cloudflared") {
86 return Some(path);
87 }
88 }
89 let common = [
90 PathBuf::from("/usr/local/bin/cloudflared"),
91 PathBuf::from("/opt/homebrew/bin/cloudflared"),
92 PathBuf::from(std::env::var("HOME").ok().unwrap_or_default()).join("bin/cloudflared"),
93 ];
94 for path in common {
95 if path.exists() {
96 return Some(path);
97 }
98 }
99 None
100}
101
102pub async fn ensure_cloudflared_installed(force: bool) -> Result<PathBuf, TunnelError> {
103 if !force {
104 if let Some(path) = get_cloudflared_path(true) {
105 return Ok(path);
106 }
107 }
108 let dir = synth_bin_dir()?;
109 tokio::fs::create_dir_all(&dir)
110 .await
111 .map_err(|e| TunnelError::process(format!("mkdir failed: {e}")))?;
112 let url = resolve_cloudflared_download_url().await?;
113 let tmp = download_file(&url).await?;
114 let target = dir.join("cloudflared");
115 if tmp.extension().and_then(|s| s.to_str()) == Some("gz") {
116 extract_gzip(&tmp, &target)?;
117 } else if tmp.to_string_lossy().ends_with(".tar.gz") {
118 extract_tarball(&tmp, &dir)?;
119 } else {
120 tokio::fs::copy(&tmp, &target)
121 .await
122 .map_err(|e| TunnelError::process(format!("copy failed: {e}")))?;
123 }
124 #[cfg(unix)]
125 {
126 use std::os::unix::fs::PermissionsExt;
127 let _ = tokio::fs::set_permissions(&target, std::fs::Permissions::from_mode(0o755)).await;
128 }
129 Ok(target)
130}
131
132pub async fn require_cloudflared() -> Result<PathBuf, TunnelError> {
133 get_cloudflared_path(true).ok_or_else(|| TunnelError::config("cloudflared not found"))
134}
135
136async fn resolve_cloudflared_download_url() -> Result<String, TunnelError> {
137 let system = std::env::consts::OS;
138 let arch = std::env::consts::ARCH;
139 let platform = match system {
140 "macos" | "darwin" => "macos",
141 "linux" => "linux",
142 "windows" => "windows",
143 _ => {
144 return Err(TunnelError::config(format!(
145 "unsupported platform {system}"
146 )))
147 }
148 };
149 let arch_key = if arch == "aarch64" || arch == "arm64" {
150 "arm64"
151 } else {
152 "amd64"
153 };
154 let url = format!("{CLOUDFLARED_RELEASES}/v1/{platform}/{arch_key}/versions/stable");
155 let resp = reqwest::get(&url)
156 .await
157 .map_err(|e| TunnelError::process(format!("cloudflared metadata fetch failed: {e}")))?;
158 let json: serde_json::Value = resp
159 .json()
160 .await
161 .map_err(|e| TunnelError::process(format!("cloudflared metadata parse failed: {e}")))?;
162 json.get("url")
163 .and_then(|v| v.as_str())
164 .map(|s| s.to_string())
165 .ok_or_else(|| TunnelError::process("cloudflared metadata missing url"))
166}
167
168async fn download_file(url: &str) -> Result<PathBuf, TunnelError> {
169 let resp = reqwest::get(url)
170 .await
171 .map_err(|e| TunnelError::process(format!("download failed: {e}")))?;
172 let bytes = resp
173 .bytes()
174 .await
175 .map_err(|e| TunnelError::process(format!("download bytes failed: {e}")))?;
176 let tmp = std::env::temp_dir().join(format!("cloudflared-{}.tmp", uuid::Uuid::new_v4()));
177 tokio::fs::write(&tmp, bytes)
178 .await
179 .map_err(|e| TunnelError::process(format!("write failed: {e}")))?;
180 Ok(tmp)
181}
182
183fn extract_gzip(src: &Path, target: &Path) -> Result<(), TunnelError> {
184 let input = std::fs::File::open(src).map_err(|e| TunnelError::process(format!("{e}")))?;
185 let mut gz = flate2::read::GzDecoder::new(input);
186 let mut out =
187 std::fs::File::create(target).map_err(|e| TunnelError::process(format!("{e}")))?;
188 std::io::copy(&mut gz, &mut out).map_err(|e| TunnelError::process(format!("{e}")))?;
189 Ok(())
190}
191
192fn extract_tarball(src: &Path, target_dir: &Path) -> Result<(), TunnelError> {
193 let input = std::fs::File::open(src).map_err(|e| TunnelError::process(format!("{e}")))?;
194 let gz = flate2::read::GzDecoder::new(input);
195 let mut archive = tar::Archive::new(gz);
196 archive
197 .unpack(target_dir)
198 .map_err(|e| TunnelError::process(format!("{e}")))?;
199 Ok(())
200}
201
202async fn spawn_process(args: &[String]) -> Result<ManagedProcess, TunnelError> {
203 let mut cmd = tokio::process::Command::new(&args[0]);
204 cmd.args(&args[1..])
205 .stdout(Stdio::piped())
206 .stderr(Stdio::piped());
207 let mut child = cmd
208 .spawn()
209 .map_err(|e| TunnelError::process(e.to_string()))?;
210 let stdout = child.stdout.take();
211 let stderr = child.stderr.take();
212 let logs = Arc::new(Mutex::new(VecDeque::with_capacity(200)));
213 let mut stdout_task = None;
214 let mut stderr_task = None;
215 if let Some(out) = stdout {
216 let logs = logs.clone();
217 stdout_task = Some(tokio::spawn(async move {
218 let mut lines = BufReader::new(out).lines();
219 while let Ok(Some(line)) = lines.next_line().await {
220 push_log(&logs, &line);
221 }
222 }));
223 }
224 if let Some(err) = stderr {
225 let logs = logs.clone();
226 stderr_task = Some(tokio::spawn(async move {
227 let mut lines = BufReader::new(err).lines();
228 while let Ok(Some(line)) = lines.next_line().await {
229 push_log(&logs, &line);
230 }
231 }));
232 }
233 Ok(ManagedProcess {
234 child,
235 logs,
236 stdout_task,
237 stderr_task,
238 })
239}
240
241fn push_log(logs: &Arc<Mutex<VecDeque<String>>>, line: &str) {
242 let mut guard = logs.lock();
243 guard.push_back(line.to_string());
244 if guard.len() > 200 {
245 guard.pop_front();
246 }
247}
248
249pub async fn open_quick_tunnel(
250 port: u16,
251 wait_s: f64,
252) -> Result<(String, ManagedProcess), TunnelError> {
253 let bin = require_cloudflared().await?;
254 let args = vec![
255 bin.to_string_lossy().to_string(),
256 "tunnel".to_string(),
257 "--config".to_string(),
258 "/dev/null".to_string(),
259 "--url".to_string(),
260 format!("http://127.0.0.1:{port}"),
261 ];
262 let mut proc = spawn_process(&args).await?;
263 let deadline = Instant::now() + Duration::from_secs_f64(wait_s);
264 loop {
265 if Instant::now() > deadline {
266 let _ = proc.child.start_kill();
267 return Err(TunnelError::process(
268 "timed out waiting for quick tunnel URL",
269 ));
270 }
271 if let Some(status) = proc.child.try_wait().ok().flatten() {
272 return Err(TunnelError::process(format!(
273 "cloudflared exited early with status {status}"
274 )));
275 }
276 let url = {
277 let logs = proc.logs.lock();
278 logs.iter()
279 .find_map(|line| URL_RE.find(line).map(|m| m.as_str().to_string()))
280 };
281 if let Some(url) = url {
282 return Ok((url, proc));
283 }
284 tokio::time::sleep(Duration::from_millis(50)).await;
285 }
286}
287
288pub async fn open_quick_tunnel_with_dns_verification(
289 port: u16,
290 wait_s: f64,
291 verify_dns: bool,
292 api_key: Option<String>,
293) -> Result<(String, ManagedProcess), TunnelError> {
294 let (url, proc) = open_quick_tunnel(port, wait_s).await?;
295 if verify_dns {
296 verify_tunnel_dns_resolution(&url, "tunnel", 60.0, api_key).await?;
297 }
298 Ok((url, proc))
299}
300
301pub async fn open_managed_tunnel(tunnel_token: &str) -> Result<ManagedProcess, TunnelError> {
302 let bin = require_cloudflared().await?;
303 let args = vec![
304 bin.to_string_lossy().to_string(),
305 "tunnel".to_string(),
306 "run".to_string(),
307 "--token".to_string(),
308 tunnel_token.to_string(),
309 ];
310 spawn_process(&args).await
311}
312
313pub async fn open_managed_tunnel_with_connection_wait(
314 tunnel_token: &str,
315 timeout_seconds: f64,
316) -> Result<ManagedProcess, TunnelError> {
317 let mut proc = open_managed_tunnel(tunnel_token).await?;
318 let deadline = Instant::now() + Duration::from_secs_f64(timeout_seconds);
319 let patterns = [
320 Regex::new("Registered tunnel connection").unwrap(),
321 Regex::new("Connection .* registered").unwrap(),
322 ];
323 loop {
324 if Instant::now() > deadline {
325 let _ = proc.child.start_kill();
326 return Err(TunnelError::connector("cloudflared connection timeout"));
327 }
328 if let Some(status) = proc.child.try_wait().ok().flatten() {
329 return Err(TunnelError::connector(format!(
330 "cloudflared exited early with status {status}"
331 )));
332 }
333 let connected = {
334 let logs = proc.logs.lock();
335 logs.iter()
336 .any(|line| patterns.iter().any(|p| p.is_match(line)))
337 };
338 if connected {
339 return Ok(proc);
340 }
341 tokio::time::sleep(Duration::from_millis(100)).await;
342 }
343}
344
345#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
346pub struct TunnelRotateResponse {
347 pub tunnel_token: String,
348 pub hostname: String,
349 pub access_client_id: Option<String>,
350 pub access_client_secret: Option<String>,
351 pub dns_verified: Option<bool>,
352}
353
/// Reduces a backend URL to its bare base.
///
/// Trailing slashes are dropped first. Then a trailing `/v1` and, after that, a
/// trailing `/api` are stripped (each followed by trimming slashes again). For
/// example, `https://host/api/v1/` becomes `https://host`.
fn normalize_backend_base(url: &str) -> String {
    let mut base = url.trim_end_matches('/');
    for suffix in ["/v1", "/api"].iter() {
        if let Some(stripped) = base.strip_suffix(*suffix) {
            base = stripped.trim_end_matches('/');
        }
    }
    base.to_string()
}
368
369pub async fn rotate_tunnel(
370 api_key: &str,
371 port: u16,
372 backend_url: Option<String>,
373) -> Result<TunnelRotateResponse, TunnelError> {
374 let raw = backend_url.unwrap_or_else(backend_url_base);
375 let base = normalize_backend_base(&raw);
376 let url = join_url(&base, "/api/v1/tunnels/rotate");
377 let client = reqwest::Client::builder()
378 .timeout(Duration::from_secs(180))
379 .pool_max_idle_per_host(20)
380 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
381 .build()
382 .map_err(|e| TunnelError::api(e.to_string()))?;
383 let resp = client
384 .post(url)
385 .header("X-API-Key", api_key)
386 .header("Authorization", format!("Bearer {api_key}"))
387 .json(&serde_json::json!({
388 "local_port": port,
389 "local_host": "127.0.0.1",
390 }))
391 .send()
392 .await
393 .map_err(|e| TunnelError::api(e.to_string()))?;
394 if !resp.status().is_success() {
395 let text = resp.text().await.unwrap_or_default();
396 return Err(TunnelError::api(format!("rotate failed: {}", text)));
397 }
398 let data: TunnelRotateResponse = resp
399 .json()
400 .await
401 .map_err(|e| TunnelError::api(e.to_string()))?;
402 Ok(data)
403}
404
405#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
406pub struct TunnelCreateResponse {
407 pub tunnel_token: String,
408 pub hostname: String,
409 pub access_client_id: Option<String>,
410 pub access_client_secret: Option<String>,
411 pub dns_verified: Option<bool>,
412}
413
414pub async fn create_tunnel(
415 api_key: &str,
416 port: u16,
417 subdomain: Option<String>,
418) -> Result<TunnelCreateResponse, TunnelError> {
419 let url = join_url(&backend_url_api(), "/v1/tunnels/");
420 let client = reqwest::Client::builder()
421 .timeout(Duration::from_secs(180))
422 .pool_max_idle_per_host(20)
423 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
424 .build()
425 .map_err(|e| TunnelError::api(e.to_string()))?;
426 let resp = client
427 .post(url)
428 .header("X-API-Key", api_key)
429 .header("Authorization", format!("Bearer {api_key}"))
430 .json(&serde_json::json!({
431 "subdomain": subdomain.unwrap_or_else(|| format!("tunnel-{port}")),
432 "local_port": port,
433 "local_host": "127.0.0.1",
434 }))
435 .send()
436 .await
437 .map_err(|e| TunnelError::api(e.to_string()))?;
438 if !resp.status().is_success() {
439 let text = resp.text().await.unwrap_or_default();
440 return Err(TunnelError::api(format!("create failed: {}", text)));
441 }
442 let data: TunnelCreateResponse = resp
443 .json()
444 .await
445 .map_err(|e| TunnelError::api(e.to_string()))?;
446 Ok(data)
447}
448
449pub async fn wait_for_health_check(
450 host: &str,
451 port: u16,
452 api_key: Option<String>,
453 timeout: f64,
454) -> Result<(), TunnelError> {
455 let url = format!("http://{host}:{port}/health");
456 let client = reqwest::Client::builder()
457 .timeout(Duration::from_secs(5))
458 .pool_max_idle_per_host(10)
459 .connect_timeout(Duration::from_secs(5))
460 .no_proxy()
461 .build()
462 .map_err(|e| TunnelError::local(e.to_string()))?;
463 let start = Instant::now();
464 let headers = api_key.map(|k| ("X-API-Key", k));
465 while start.elapsed() < Duration::from_secs_f64(timeout) {
466 let mut req = client.get(&url);
467 if let Some((k, v)) = headers.clone() {
468 req = req.header(k, v);
469 }
470 if let Ok(resp) = req.send().await {
471 let status = resp.status().as_u16();
472 if status == 200 || status == 400 {
473 return Ok(());
474 }
475 }
476 tokio::time::sleep(Duration::from_millis(500)).await;
477 }
478 Err(TunnelError::local(format!(
479 "health check failed: {url} not ready after {timeout}s"
480 )))
481}
482
/// Picks an address from `ips`, preferring IPv4.
///
/// Returns the first IPv4 address (it stops consuming `ips` there). Otherwise it
/// returns the first address seen, or `None` if `ips` is empty.
fn prefer_ipv4(ips: impl Iterator<Item = std::net::IpAddr>) -> Option<std::net::IpAddr> {
    let mut first_seen = None;
    for candidate in ips {
        match candidate {
            std::net::IpAddr::V4(_) => return Some(candidate),
            std::net::IpAddr::V6(_) => {
                first_seen.get_or_insert(candidate);
            }
        }
    }
    first_seen
}
497
498pub async fn resolve_hostname_with_explicit_resolvers(
499 hostname: &str,
500) -> Result<std::net::IpAddr, TunnelError> {
501 use trust_dns_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
502 use trust_dns_resolver::TokioAsyncResolver;
503
504 let servers = vec![("1.1.1.1:53", "1.1.1.1"), ("8.8.8.8:53", "8.8.8.8")];
505 for (socket, _) in servers {
506 if let Ok(addr) = socket.parse() {
507 let config = ResolverConfig::from_parts(
508 None,
509 vec![],
510 vec![NameServerConfig {
511 socket_addr: addr,
512 protocol: Protocol::Udp,
513 tls_dns_name: None,
514 trust_negative_responses: false,
515 bind_addr: None,
516 }],
517 );
518 let resolver = TokioAsyncResolver::tokio(config, ResolverOpts::default());
519 if let Ok(lookup) = resolver.lookup_ip(hostname).await {
520 if let Some(ip) = prefer_ipv4(lookup.iter()) {
521 return Ok(ip);
522 }
523 }
524 }
525 }
526 let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
528 let lookup = resolver
529 .lookup_ip(hostname)
530 .await
531 .map_err(|e| TunnelError::dns(e.to_string()))?;
532 prefer_ipv4(lookup.iter()).ok_or_else(|| TunnelError::dns("no ip resolved"))
533}
534
535pub async fn verify_tunnel_dns_resolution(
536 tunnel_url: &str,
537 _name: &str,
538 timeout_seconds: f64,
539 api_key: Option<String>,
540) -> Result<(), TunnelError> {
541 let parsed =
542 url::Url::parse(tunnel_url).map_err(|e| TunnelError::dns(format!("invalid url: {e}")))?;
543 let hostname = parsed
544 .host_str()
545 .ok_or_else(|| TunnelError::dns("missing hostname"))?;
546 if hostname == "localhost" || hostname == "127.0.0.1" {
547 return Ok(());
548 }
549 let deadline = Instant::now() + Duration::from_secs_f64(timeout_seconds);
550 let mut last_err: Option<String> = None;
551 loop {
552 if Instant::now() > deadline {
553 return Err(TunnelError::dns(format!(
554 "dns verification timeout: {} ({:?})",
555 hostname, last_err
556 )));
557 }
558 let ip = resolve_hostname_with_explicit_resolvers(hostname).await?;
559 let port = if parsed.scheme() == "http" { 80 } else { 443 };
560 let builder = reqwest::Client::builder()
561 .timeout(Duration::from_secs(5))
562 .pool_max_idle_per_host(10)
563 .connect_timeout(Duration::from_secs(5))
564 .danger_accept_invalid_certs(true)
565 .resolve(hostname, (ip, port).into());
566 let client = builder
567 .build()
568 .map_err(|e| TunnelError::dns(e.to_string()))?;
569 let mut req = client.get(parsed.clone());
570 if let Some(key) = api_key.clone() {
571 req = req.header("X-API-Key", key);
572 }
573 match req.send().await {
574 Ok(resp) => {
575 let status = resp.status().as_u16();
576 if matches!(status, 200 | 400 | 401 | 403 | 404 | 405 | 502) {
577 return Ok(());
578 }
579 last_err = Some(format!("status {status}"));
580 }
581 Err(e) => {
582 last_err = Some(e.to_string());
583 }
584 }
585 tokio::time::sleep(Duration::from_secs(1)).await;
586 }
587}
588
589pub async fn stop_tunnel(mut proc: ManagedProcess) {
590 let _ = proc.child.start_kill();
591 let _ = proc.child.wait().await;
592}