Skip to main content

synth_ai_core/tunnels/
cloudflared.rs

1use std::collections::{HashMap, VecDeque};
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use once_cell::sync::Lazy;
9use parking_lot::Mutex;
10use regex::Regex;
11use tokio::io::{AsyncBufReadExt, BufReader};
12use tokio::process::Child;
13use tokio::task::JoinHandle;
14
15use crate::shared_client::DEFAULT_CONNECT_TIMEOUT_SECS;
16use crate::tunnels::errors::TunnelError;
17use crate::urls::{backend_url_api, backend_url_base, join_url};
18
19static URL_RE: Lazy<Regex> =
20    Lazy::new(|| Regex::new(r"https://[a-z0-9-]+\\.trycloudflare\\.com").unwrap());
21
/// Base URL of the release-metadata service; `resolve_cloudflared_download_url`
/// appends `/v1/{platform}/{arch}/versions/stable` to it.
const CLOUDFLARED_RELEASES: &str = "https://updatecloudflared.com/launcher";
23
/// A spawned child process together with its captured output.
///
/// Instances are built by `spawn_process`, which pipes the child's stdout and
/// stderr into `logs` through two background tasks.
#[derive(Debug)]
pub struct ManagedProcess {
    /// Handle to the child process.
    pub child: Child,
    /// Recent output lines from both stdout and stderr; `push_log` keeps at
    /// most 200 entries, discarding the oldest first.
    pub logs: Arc<Mutex<VecDeque<String>>>,
    /// Task copying the child's stdout into `logs`, if stdout was captured.
    stdout_task: Option<JoinHandle<()>>,
    /// Task copying the child's stderr into `logs`, if stderr was captured.
    stderr_task: Option<JoinHandle<()>>,
}
31
32impl ManagedProcess {
33    async fn stop(&mut self) {
34        let _ = self.child.start_kill();
35        let _ = self.child.wait().await;
36        if let Some(task) = self.stdout_task.take() {
37            task.abort();
38        }
39        if let Some(task) = self.stderr_task.take() {
40            task.abort();
41        }
42    }
43}
44
/// Registry of processes handed to `track_process`, keyed by the id it returned.
static TRACKED: Lazy<Mutex<HashMap<usize, ManagedProcess>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
/// Source of registry ids; starts at 1 and is incremented for every tracked process.
static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
48
49pub fn track_process(proc: ManagedProcess) -> usize {
50    let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
51    TRACKED.lock().insert(id, proc);
52    id
53}
54
55pub async fn stop_tracked(id: usize) -> Result<(), TunnelError> {
56    let mut guard = TRACKED.lock();
57    if let Some(mut proc) = guard.remove(&id) {
58        proc.stop().await;
59        return Ok(());
60    }
61    Err(TunnelError::process(format!("process id {id} not found")))
62}
63
64pub async fn cleanup_all() {
65    let mut procs = TRACKED.lock();
66    for (_, proc) in procs.iter_mut() {
67        proc.stop().await;
68    }
69    procs.clear();
70}
71
72fn synth_bin_dir() -> Result<PathBuf, TunnelError> {
73    let home = std::env::var("HOME").map_err(|_| TunnelError::config("HOME not set"))?;
74    Ok(Path::new(&home).join(".synth").join("bin"))
75}
76
77pub fn get_cloudflared_path(prefer_system: bool) -> Option<PathBuf> {
78    if let Ok(dir) = synth_bin_dir() {
79        let candidate = dir.join("cloudflared");
80        if candidate.exists() {
81            return Some(candidate);
82        }
83    }
84    if prefer_system {
85        if let Ok(path) = which::which("cloudflared") {
86            return Some(path);
87        }
88    }
89    let common = [
90        PathBuf::from("/usr/local/bin/cloudflared"),
91        PathBuf::from("/opt/homebrew/bin/cloudflared"),
92        PathBuf::from(std::env::var("HOME").ok().unwrap_or_default()).join("bin/cloudflared"),
93    ];
94    for path in common {
95        if path.exists() {
96            return Some(path);
97        }
98    }
99    None
100}
101
102pub async fn ensure_cloudflared_installed(force: bool) -> Result<PathBuf, TunnelError> {
103    if !force {
104        if let Some(path) = get_cloudflared_path(true) {
105            return Ok(path);
106        }
107    }
108    let dir = synth_bin_dir()?;
109    tokio::fs::create_dir_all(&dir)
110        .await
111        .map_err(|e| TunnelError::process(format!("mkdir failed: {e}")))?;
112    let url = resolve_cloudflared_download_url().await?;
113    let tmp = download_file(&url).await?;
114    let target = dir.join("cloudflared");
115    if tmp.extension().and_then(|s| s.to_str()) == Some("gz") {
116        extract_gzip(&tmp, &target)?;
117    } else if tmp.to_string_lossy().ends_with(".tar.gz") {
118        extract_tarball(&tmp, &dir)?;
119    } else {
120        tokio::fs::copy(&tmp, &target)
121            .await
122            .map_err(|e| TunnelError::process(format!("copy failed: {e}")))?;
123    }
124    #[cfg(unix)]
125    {
126        use std::os::unix::fs::PermissionsExt;
127        let _ = tokio::fs::set_permissions(&target, std::fs::Permissions::from_mode(0o755)).await;
128    }
129    Ok(target)
130}
131
132pub async fn require_cloudflared() -> Result<PathBuf, TunnelError> {
133    get_cloudflared_path(true).ok_or_else(|| TunnelError::config("cloudflared not found"))
134}
135
136async fn resolve_cloudflared_download_url() -> Result<String, TunnelError> {
137    let system = std::env::consts::OS;
138    let arch = std::env::consts::ARCH;
139    let platform = match system {
140        "macos" | "darwin" => "macos",
141        "linux" => "linux",
142        "windows" => "windows",
143        _ => {
144            return Err(TunnelError::config(format!(
145                "unsupported platform {system}"
146            )))
147        }
148    };
149    let arch_key = if arch == "aarch64" || arch == "arm64" {
150        "arm64"
151    } else {
152        "amd64"
153    };
154    let url = format!("{CLOUDFLARED_RELEASES}/v1/{platform}/{arch_key}/versions/stable");
155    let resp = reqwest::get(&url)
156        .await
157        .map_err(|e| TunnelError::process(format!("cloudflared metadata fetch failed: {e}")))?;
158    let json: serde_json::Value = resp
159        .json()
160        .await
161        .map_err(|e| TunnelError::process(format!("cloudflared metadata parse failed: {e}")))?;
162    json.get("url")
163        .and_then(|v| v.as_str())
164        .map(|s| s.to_string())
165        .ok_or_else(|| TunnelError::process("cloudflared metadata missing url"))
166}
167
168async fn download_file(url: &str) -> Result<PathBuf, TunnelError> {
169    let resp = reqwest::get(url)
170        .await
171        .map_err(|e| TunnelError::process(format!("download failed: {e}")))?;
172    let bytes = resp
173        .bytes()
174        .await
175        .map_err(|e| TunnelError::process(format!("download bytes failed: {e}")))?;
176    let tmp = std::env::temp_dir().join(format!("cloudflared-{}.tmp", uuid::Uuid::new_v4()));
177    tokio::fs::write(&tmp, bytes)
178        .await
179        .map_err(|e| TunnelError::process(format!("write failed: {e}")))?;
180    Ok(tmp)
181}
182
183fn extract_gzip(src: &Path, target: &Path) -> Result<(), TunnelError> {
184    let input = std::fs::File::open(src).map_err(|e| TunnelError::process(format!("{e}")))?;
185    let mut gz = flate2::read::GzDecoder::new(input);
186    let mut out =
187        std::fs::File::create(target).map_err(|e| TunnelError::process(format!("{e}")))?;
188    std::io::copy(&mut gz, &mut out).map_err(|e| TunnelError::process(format!("{e}")))?;
189    Ok(())
190}
191
192fn extract_tarball(src: &Path, target_dir: &Path) -> Result<(), TunnelError> {
193    let input = std::fs::File::open(src).map_err(|e| TunnelError::process(format!("{e}")))?;
194    let gz = flate2::read::GzDecoder::new(input);
195    let mut archive = tar::Archive::new(gz);
196    archive
197        .unpack(target_dir)
198        .map_err(|e| TunnelError::process(format!("{e}")))?;
199    Ok(())
200}
201
202async fn spawn_process(args: &[String]) -> Result<ManagedProcess, TunnelError> {
203    let mut cmd = tokio::process::Command::new(&args[0]);
204    cmd.args(&args[1..])
205        .stdout(Stdio::piped())
206        .stderr(Stdio::piped());
207    let mut child = cmd
208        .spawn()
209        .map_err(|e| TunnelError::process(e.to_string()))?;
210    let stdout = child.stdout.take();
211    let stderr = child.stderr.take();
212    let logs = Arc::new(Mutex::new(VecDeque::with_capacity(200)));
213    let mut stdout_task = None;
214    let mut stderr_task = None;
215    if let Some(out) = stdout {
216        let logs = logs.clone();
217        stdout_task = Some(tokio::spawn(async move {
218            let mut lines = BufReader::new(out).lines();
219            while let Ok(Some(line)) = lines.next_line().await {
220                push_log(&logs, &line);
221            }
222        }));
223    }
224    if let Some(err) = stderr {
225        let logs = logs.clone();
226        stderr_task = Some(tokio::spawn(async move {
227            let mut lines = BufReader::new(err).lines();
228            while let Ok(Some(line)) = lines.next_line().await {
229                push_log(&logs, &line);
230            }
231        }));
232    }
233    Ok(ManagedProcess {
234        child,
235        logs,
236        stdout_task,
237        stderr_task,
238    })
239}
240
241fn push_log(logs: &Arc<Mutex<VecDeque<String>>>, line: &str) {
242    let mut guard = logs.lock();
243    guard.push_back(line.to_string());
244    if guard.len() > 200 {
245        guard.pop_front();
246    }
247}
248
249pub async fn open_quick_tunnel(
250    port: u16,
251    wait_s: f64,
252) -> Result<(String, ManagedProcess), TunnelError> {
253    let bin = require_cloudflared().await?;
254    let args = vec![
255        bin.to_string_lossy().to_string(),
256        "tunnel".to_string(),
257        "--config".to_string(),
258        "/dev/null".to_string(),
259        "--url".to_string(),
260        format!("http://127.0.0.1:{port}"),
261    ];
262    let mut proc = spawn_process(&args).await?;
263    let deadline = Instant::now() + Duration::from_secs_f64(wait_s);
264    loop {
265        if Instant::now() > deadline {
266            let _ = proc.child.start_kill();
267            return Err(TunnelError::process(
268                "timed out waiting for quick tunnel URL",
269            ));
270        }
271        if let Some(status) = proc.child.try_wait().ok().flatten() {
272            return Err(TunnelError::process(format!(
273                "cloudflared exited early with status {status}"
274            )));
275        }
276        let url = {
277            let logs = proc.logs.lock();
278            logs.iter()
279                .find_map(|line| URL_RE.find(line).map(|m| m.as_str().to_string()))
280        };
281        if let Some(url) = url {
282            return Ok((url, proc));
283        }
284        tokio::time::sleep(Duration::from_millis(50)).await;
285    }
286}
287
288pub async fn open_quick_tunnel_with_dns_verification(
289    port: u16,
290    wait_s: f64,
291    verify_dns: bool,
292    api_key: Option<String>,
293) -> Result<(String, ManagedProcess), TunnelError> {
294    let (url, proc) = open_quick_tunnel(port, wait_s).await?;
295    if verify_dns {
296        verify_tunnel_dns_resolution(&url, "tunnel", 60.0, api_key).await?;
297    }
298    Ok((url, proc))
299}
300
301pub async fn open_managed_tunnel(tunnel_token: &str) -> Result<ManagedProcess, TunnelError> {
302    let bin = require_cloudflared().await?;
303    let args = vec![
304        bin.to_string_lossy().to_string(),
305        "tunnel".to_string(),
306        "run".to_string(),
307        "--token".to_string(),
308        tunnel_token.to_string(),
309    ];
310    spawn_process(&args).await
311}
312
313pub async fn open_managed_tunnel_with_connection_wait(
314    tunnel_token: &str,
315    timeout_seconds: f64,
316) -> Result<ManagedProcess, TunnelError> {
317    let mut proc = open_managed_tunnel(tunnel_token).await?;
318    let deadline = Instant::now() + Duration::from_secs_f64(timeout_seconds);
319    let patterns = [
320        Regex::new("Registered tunnel connection").unwrap(),
321        Regex::new("Connection .* registered").unwrap(),
322    ];
323    loop {
324        if Instant::now() > deadline {
325            let _ = proc.child.start_kill();
326            return Err(TunnelError::connector("cloudflared connection timeout"));
327        }
328        if let Some(status) = proc.child.try_wait().ok().flatten() {
329            return Err(TunnelError::connector(format!(
330                "cloudflared exited early with status {status}"
331            )));
332        }
333        let connected = {
334            let logs = proc.logs.lock();
335            logs.iter()
336                .any(|line| patterns.iter().any(|p| p.is_match(line)))
337        };
338        if connected {
339            return Ok(proc);
340        }
341        tokio::time::sleep(Duration::from_millis(100)).await;
342    }
343}
344
/// JSON body that `rotate_tunnel` deserializes from a successful response.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TunnelRotateResponse {
    /// Token for the tunnel; presumably the value to hand to
    /// `open_managed_tunnel` (confirm against the backend API).
    pub tunnel_token: String,
    /// Hostname reported by the backend for the tunnel.
    pub hostname: String,
    /// Optional access client id, when the backend returns one.
    pub access_client_id: Option<String>,
    /// Optional access client secret, when the backend returns one.
    pub access_client_secret: Option<String>,
    /// Optional flag reported by the backend; presumably whether it already
    /// verified DNS for `hostname` (confirm against the backend API).
    pub dns_verified: Option<bool>,
}
353
/// Strip trailing `/v1` then `/api` from a backend URL so that callers
/// can unconditionally append `/api/v1/…` without doubling path segments.
///
/// Only the path is inspected: for a URL with a scheme, the `scheme://host`
/// part is left untouched, so a host that is literally named `api` or `v1`
/// (e.g. `http://api`) is preserved. Trailing slashes are always removed.
fn normalize_backend_base(url: &str) -> String {
    let trimmed = url.trim_end_matches('/');
    // Byte offset where the path begins: the first `/` after `scheme://`, or
    // the end of the string when the URL has no path. Without a scheme the
    // whole input is treated as a path.
    let path_start = match trimmed.find("://") {
        Some(scheme_end) => {
            let authority_start = scheme_end + 3;
            trimmed[authority_start..]
                .find('/')
                .map_or(trimmed.len(), |offset| authority_start + offset)
        }
        None => 0,
    };
    let (origin, mut path) = trimmed.split_at(path_start);
    if let Some(rest) = path.strip_suffix("/v1") {
        path = rest.trim_end_matches('/');
    }
    if let Some(rest) = path.strip_suffix("/api") {
        path = rest.trim_end_matches('/');
    }
    format!("{origin}{path}")
}
368
369pub async fn rotate_tunnel(
370    api_key: &str,
371    port: u16,
372    backend_url: Option<String>,
373) -> Result<TunnelRotateResponse, TunnelError> {
374    let raw = backend_url.unwrap_or_else(backend_url_base);
375    let base = normalize_backend_base(&raw);
376    let url = join_url(&base, "/api/v1/tunnels/rotate");
377    let client = reqwest::Client::builder()
378        .timeout(Duration::from_secs(180))
379        .pool_max_idle_per_host(20)
380        .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
381        .build()
382        .map_err(|e| TunnelError::api(e.to_string()))?;
383    let resp = client
384        .post(url)
385        .header("X-API-Key", api_key)
386        .header("Authorization", format!("Bearer {api_key}"))
387        .json(&serde_json::json!({
388            "local_port": port,
389            "local_host": "127.0.0.1",
390        }))
391        .send()
392        .await
393        .map_err(|e| TunnelError::api(e.to_string()))?;
394    if !resp.status().is_success() {
395        let text = resp.text().await.unwrap_or_default();
396        return Err(TunnelError::api(format!("rotate failed: {}", text)));
397    }
398    let data: TunnelRotateResponse = resp
399        .json()
400        .await
401        .map_err(|e| TunnelError::api(e.to_string()))?;
402    Ok(data)
403}
404
/// JSON body that `create_tunnel` deserializes from a successful response.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TunnelCreateResponse {
    /// Token for the tunnel; presumably the value to hand to
    /// `open_managed_tunnel` (confirm against the backend API).
    pub tunnel_token: String,
    /// Hostname reported by the backend for the tunnel.
    pub hostname: String,
    /// Optional access client id, when the backend returns one.
    pub access_client_id: Option<String>,
    /// Optional access client secret, when the backend returns one.
    pub access_client_secret: Option<String>,
    /// Optional flag reported by the backend; presumably whether it already
    /// verified DNS for `hostname` (confirm against the backend API).
    pub dns_verified: Option<bool>,
}
413
414pub async fn create_tunnel(
415    api_key: &str,
416    port: u16,
417    subdomain: Option<String>,
418) -> Result<TunnelCreateResponse, TunnelError> {
419    let url = join_url(&backend_url_api(), "/v1/tunnels/");
420    let client = reqwest::Client::builder()
421        .timeout(Duration::from_secs(180))
422        .pool_max_idle_per_host(20)
423        .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
424        .build()
425        .map_err(|e| TunnelError::api(e.to_string()))?;
426    let resp = client
427        .post(url)
428        .header("X-API-Key", api_key)
429        .header("Authorization", format!("Bearer {api_key}"))
430        .json(&serde_json::json!({
431            "subdomain": subdomain.unwrap_or_else(|| format!("tunnel-{port}")),
432            "local_port": port,
433            "local_host": "127.0.0.1",
434        }))
435        .send()
436        .await
437        .map_err(|e| TunnelError::api(e.to_string()))?;
438    if !resp.status().is_success() {
439        let text = resp.text().await.unwrap_or_default();
440        return Err(TunnelError::api(format!("create failed: {}", text)));
441    }
442    let data: TunnelCreateResponse = resp
443        .json()
444        .await
445        .map_err(|e| TunnelError::api(e.to_string()))?;
446    Ok(data)
447}
448
449pub async fn wait_for_health_check(
450    host: &str,
451    port: u16,
452    api_key: Option<String>,
453    timeout: f64,
454) -> Result<(), TunnelError> {
455    let url = format!("http://{host}:{port}/health");
456    let client = reqwest::Client::builder()
457        .timeout(Duration::from_secs(5))
458        .pool_max_idle_per_host(10)
459        .connect_timeout(Duration::from_secs(5))
460        .no_proxy()
461        .build()
462        .map_err(|e| TunnelError::local(e.to_string()))?;
463    let start = Instant::now();
464    let headers = api_key.map(|k| ("X-API-Key", k));
465    while start.elapsed() < Duration::from_secs_f64(timeout) {
466        let mut req = client.get(&url);
467        if let Some((k, v)) = headers.clone() {
468            req = req.header(k, v);
469        }
470        if let Ok(resp) = req.send().await {
471            let status = resp.status().as_u16();
472            if status == 200 || status == 400 {
473                return Ok(());
474            }
475        }
476        tokio::time::sleep(Duration::from_millis(500)).await;
477    }
478    Err(TunnelError::local(format!(
479        "health check failed: {url} not ready after {timeout}s"
480    )))
481}
482
/// Choose one address from a DNS answer: the first IPv4 address if there is
/// one, otherwise the first address of any kind, otherwise `None`.
///
/// IPv4 is favoured because IPv6 may not be routable in all environments.
/// Iteration stops as soon as an IPv4 address is seen.
fn prefer_ipv4(ips: impl Iterator<Item = std::net::IpAddr>) -> Option<std::net::IpAddr> {
    let mut first_seen: Option<std::net::IpAddr> = None;
    for candidate in ips {
        match candidate {
            std::net::IpAddr::V4(_) => return Some(candidate),
            std::net::IpAddr::V6(_) => {
                // Remember only the earliest non-IPv4 address.
                first_seen.get_or_insert(candidate);
            }
        }
    }
    first_seen
}
497
498pub async fn resolve_hostname_with_explicit_resolvers(
499    hostname: &str,
500) -> Result<std::net::IpAddr, TunnelError> {
501    use trust_dns_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
502    use trust_dns_resolver::TokioAsyncResolver;
503
504    let servers = vec![("1.1.1.1:53", "1.1.1.1"), ("8.8.8.8:53", "8.8.8.8")];
505    for (socket, _) in servers {
506        if let Ok(addr) = socket.parse() {
507            let config = ResolverConfig::from_parts(
508                None,
509                vec![],
510                vec![NameServerConfig {
511                    socket_addr: addr,
512                    protocol: Protocol::Udp,
513                    tls_dns_name: None,
514                    trust_negative_responses: false,
515                    bind_addr: None,
516                }],
517            );
518            let resolver = TokioAsyncResolver::tokio(config, ResolverOpts::default());
519            if let Ok(lookup) = resolver.lookup_ip(hostname).await {
520                if let Some(ip) = prefer_ipv4(lookup.iter()) {
521                    return Ok(ip);
522                }
523            }
524        }
525    }
526    // Fallback to system resolver
527    let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
528    let lookup = resolver
529        .lookup_ip(hostname)
530        .await
531        .map_err(|e| TunnelError::dns(e.to_string()))?;
532    prefer_ipv4(lookup.iter()).ok_or_else(|| TunnelError::dns("no ip resolved"))
533}
534
535pub async fn verify_tunnel_dns_resolution(
536    tunnel_url: &str,
537    _name: &str,
538    timeout_seconds: f64,
539    api_key: Option<String>,
540) -> Result<(), TunnelError> {
541    let parsed =
542        url::Url::parse(tunnel_url).map_err(|e| TunnelError::dns(format!("invalid url: {e}")))?;
543    let hostname = parsed
544        .host_str()
545        .ok_or_else(|| TunnelError::dns("missing hostname"))?;
546    if hostname == "localhost" || hostname == "127.0.0.1" {
547        return Ok(());
548    }
549    let deadline = Instant::now() + Duration::from_secs_f64(timeout_seconds);
550    let mut last_err: Option<String> = None;
551    loop {
552        if Instant::now() > deadline {
553            return Err(TunnelError::dns(format!(
554                "dns verification timeout: {} ({:?})",
555                hostname, last_err
556            )));
557        }
558        let ip = resolve_hostname_with_explicit_resolvers(hostname).await?;
559        let port = if parsed.scheme() == "http" { 80 } else { 443 };
560        let builder = reqwest::Client::builder()
561            .timeout(Duration::from_secs(5))
562            .pool_max_idle_per_host(10)
563            .connect_timeout(Duration::from_secs(5))
564            .danger_accept_invalid_certs(true)
565            .resolve(hostname, (ip, port).into());
566        let client = builder
567            .build()
568            .map_err(|e| TunnelError::dns(e.to_string()))?;
569        let mut req = client.get(parsed.clone());
570        if let Some(key) = api_key.clone() {
571            req = req.header("X-API-Key", key);
572        }
573        match req.send().await {
574            Ok(resp) => {
575                let status = resp.status().as_u16();
576                if matches!(status, 200 | 400 | 401 | 403 | 404 | 405 | 502) {
577                    return Ok(());
578                }
579                last_err = Some(format!("status {status}"));
580            }
581            Err(e) => {
582                last_err = Some(e.to_string());
583            }
584        }
585        tokio::time::sleep(Duration::from_secs(1)).await;
586    }
587}
588
589pub async fn stop_tunnel(mut proc: ManagedProcess) {
590    let _ = proc.child.start_kill();
591    let _ = proc.child.wait().await;
592}