Skip to main content

tf_proxy/
lib.rs

1//! tf-proxy: TrustForge enforcement reverse proxy.
2//!
3//! Sits in front of an upstream HTTP service (TLS termination happens at
4//! the proxy's own listener; a TLS upstream is out of scope for this first
5//! cut, in both the buffered and raw-splice paths). For every request it
6//! consults `tf-daemon`'s `/v1/decide` endpoint and either forwards, denies,
7//! or surfaces an approval-required handoff based on the daemon's verdict.
8//!
9//! This crate is structured as a library so that the binary entry point in
10//! `src/main.rs` is a thin wrapper and the proxy logic can be exercised by
11//! integration tests.
12
13use std::convert::Infallible;
14use std::net::SocketAddr;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17
18use http_body_util::{BodyExt, Full};
19use hyper::body::{Bytes, Incoming};
20use hyper::header::{HeaderName, HeaderValue, UPGRADE};
21use hyper::service::service_fn;
22use hyper::{Method, Request, Response, StatusCode, Uri};
23use hyper_util::client::legacy::{connect::HttpConnector, Client};
24use hyper_util::rt::{TokioExecutor, TokioIo};
25use serde::{Deserialize, Serialize};
26use std::io::BufReader;
27use tokio::io::AsyncWriteExt;
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls::TlsAcceptor;
30use tracing::{debug, error, info, warn};
31
/// Operating mode for the proxy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Forward upstream even when the daemon answers `deny` (the decision is
    /// still requested and logged). `approval-required` verdicts still
    /// short-circuit with the 202 approval hand-off.
    ObserveOnly,
    /// Honour the daemon decision: deny becomes 403, approval becomes 202.
    Enforce,
}

impl std::str::FromStr for Mode {
    type Err = String;

    /// Parse a mode name. `observe-only`, `observe_only` and `observe` map to
    /// [`Mode::ObserveOnly`]; `enforce` maps to [`Mode::Enforce`]. Matching is
    /// exact (case-sensitive); anything else yields `"unknown mode: <input>"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const OBSERVE_ALIASES: [&str; 3] = ["observe-only", "observe_only", "observe"];
        if OBSERVE_ALIASES.contains(&s) {
            Ok(Self::ObserveOnly)
        } else if s == "enforce" {
            Ok(Self::Enforce)
        } else {
            Err(format!("unknown mode: {s}"))
        }
    }
}
51
/// Runtime configuration for the proxy server.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Address the proxy listener binds to.
    pub listen: SocketAddr,
    /// Upstream base URL. A trailing `/` is trimmed and the incoming
    /// path-and-query is appended to it when forwarding; the websocket splice
    /// path also parses it for host/port and only speaks plain TCP.
    pub upstream: String,
    /// tf-daemon base URL; `/v1/decide` (decisions) and `/v1/approval/<id>`
    /// (approval `Location`) are appended after trimming a trailing `/`.
    pub daemon: String,
    /// When set, sent to the daemon as the `X-Admin-Token` header.
    pub admin_token: Option<String>,
    /// Profile name; used as the `WWW-Authenticate` realm on deny when no
    /// host can be parsed out of `upstream`.
    pub profile: String,
    /// How daemon verdicts are applied; see [`Mode`].
    pub mode: Mode,
    /// Path to a PEM certificate chain for the listener. Must be set together
    /// with `tls_key` to enable TLS; setting only one is a startup error.
    pub tls_cert: Option<String>,
    /// Path to a PEM private key for the listener; pairs with `tls_cert`.
    pub tls_key: Option<String>,
}
64
/// Decide-request body sent to tf-daemon.
///
/// Serialized with serde_json and POSTed to `<daemon>/v1/decide` by
/// [`call_decide`].
#[derive(Serialize, Debug)]
pub struct DecideRequest<'a> {
    /// Actor identifier; [`call_decide`] currently always sends `None`.
    pub actor: Option<&'a str>,
    /// Credential taken from `Authorization: Bearer …` or a session cookie by
    /// [`extract_host_token`], if one was found.
    pub host_token: Option<String>,
    /// `"jwt"` or `"opaque"` classification of `host_token`; `None` when there
    /// is no token.
    pub host_token_kind: Option<String>,
    /// Action name: `"connect"` for websocket upgrades, otherwise the output
    /// of [`action_for`] (e.g. `GET /api/users` → `get.api.users`).
    pub action: String,
    /// The request path handed to [`call_decide`].
    pub target: String,
    /// Free-form context; [`call_decide`] sends `ip` and `user_agent`.
    pub context: serde_json::Value,
    /// Proxy-generated correlation id of the form `tf-proxy-<nanos>-<counter>`.
    pub trace_id: String,
}
76
/// Decide-response body returned by tf-daemon.
#[derive(Deserialize, Debug, Clone, Default)]
pub struct DecideResponse {
    /// Verdict string. The proxy acts on `allow`, `deny`,
    /// `approval-required`/`approval_required` and `log-only`/`log_only`;
    /// any other value is treated as a deny in enforce mode. This field has
    /// no `#[serde(default)]`, and [`call_decide`] additionally rejects an
    /// empty value.
    pub decision: String,
    /// Human-readable reason, echoed back to the client on deny.
    #[serde(default)]
    pub reason: Option<String>,
    /// Proof identifier, echoed in deny bodies and in deny/log-only logs.
    #[serde(default)]
    pub proof_id: Option<String>,
    /// Pending approval id, used to build the `/v1/approval/<id>` location.
    #[serde(default)]
    pub approval_id: Option<String>,
}
88
89/// Shared state used by every connection handler.
90pub struct ProxyState {
91    pub config: ProxyConfig,
92    pub http: Client<HttpConnector, Full<Bytes>>,
93    counter: AtomicU64,
94    /// OpenTelemetry handle owned by the binary entry point. Set once at
95    /// startup via [`ProxyState::set_otel`]. We use `OnceLock` so the
96    /// `Arc<ProxyState>` we hand to connection tasks does not need to be
97    /// rebuilt after wiring telemetry.
98    otel: std::sync::OnceLock<tf_otel::TfOtelHandle>,
99}
100
101impl ProxyState {
102    pub fn new(config: ProxyConfig) -> Arc<Self> {
103        let http = Client::builder(TokioExecutor::new())
104            .pool_idle_timeout(std::time::Duration::from_secs(30))
105            .build_http();
106        Arc::new(Self {
107            config,
108            http,
109            counter: AtomicU64::new(0),
110            otel: std::sync::OnceLock::new(),
111        })
112    }
113
114    /// Install the process-wide OpenTelemetry handle. Should be called
115    /// at most once during startup, before [`run`] handles any traffic.
116    /// Uses `OnceLock` so this works through an `Arc<Self>`.
117    pub fn set_otel(&self, handle: tf_otel::TfOtelHandle) {
118        let _ = self.otel.set(handle);
119    }
120
121    /// Borrow the OTel handle, if one was installed.
122    pub fn otel(&self) -> Option<&tf_otel::TfOtelHandle> {
123        self.otel.get()
124    }
125
126    fn next_trace_id(&self) -> String {
127        let n = self.counter.fetch_add(1, Ordering::Relaxed);
128        let nanos = std::time::SystemTime::now()
129            .duration_since(std::time::UNIX_EPOCH)
130            .map(|d| d.as_nanos())
131            .unwrap_or(0);
132        format!("tf-proxy-{nanos}-{n}")
133    }
134}
135
136/// Pull a host token out of either an `Authorization: Bearer ...` header or a
137/// session cookie. Returns the token plus a heuristic kind: `"jwt"` if it
138/// looks like a JWT (three dot-separated segments), otherwise `"opaque"`.
139pub fn extract_host_token(headers: &hyper::HeaderMap) -> Option<(String, String)> {
140    if let Some(v) = headers.get(hyper::header::AUTHORIZATION) {
141        if let Ok(s) = v.to_str() {
142            if let Some(rest) = s.strip_prefix("Bearer ") {
143                let token = rest.trim().to_string();
144                if !token.is_empty() {
145                    let kind = classify_token(&token);
146                    return Some((token, kind));
147                }
148            }
149        }
150    }
151    if let Some(v) = headers.get(hyper::header::COOKIE) {
152        if let Ok(s) = v.to_str() {
153            for raw in s.split(';') {
154                let part = raw.trim();
155                for name in ["__session=", "__Secure-next-auth.session-token="] {
156                    if let Some(val) = part.strip_prefix(name) {
157                        let token = val.trim().to_string();
158                        if !token.is_empty() {
159                            let kind = classify_token(&token);
160                            return Some((token, kind));
161                        }
162                    }
163                }
164            }
165        }
166    }
167    None
168}
169
/// Heuristically label a credential: `"jwt"` when it contains exactly two
/// `.` separators and only base64url characters (`A-Z a-z 0-9 - _`) plus
/// dots, otherwise `"opaque"`. The signature is not parsed or verified.
fn classify_token(t: &str) -> String {
    let in_jwt_alphabet = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b'_');
    let looks_like_jwt = t.matches('.').count() == 2 && t.bytes().all(in_jwt_alphabet);
    let kind = if looks_like_jwt { "jwt" } else { "opaque" };
    kind.to_owned()
}
181
182/// Build the `action` string for a request. We split on `/`, drop empty
183/// segments, lowercase the method, and join with `.`.
184pub fn action_for(method: &Method, path: &str) -> String {
185    let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
186    let m = method.as_str().to_ascii_lowercase();
187    if segments.is_empty() {
188        format!("{m}.root")
189    } else {
190        format!("{m}.{}", segments.join("."))
191    }
192}
193
194/// Detect a websocket upgrade request.
195pub fn is_websocket_upgrade(req: &Request<Incoming>) -> bool {
196    req.headers()
197        .get(UPGRADE)
198        .and_then(|v| v.to_str().ok())
199        .map(|s| s.eq_ignore_ascii_case("websocket"))
200        .unwrap_or(false)
201}
202
203/// Call tf-daemon's `/v1/decide`. Returns `Err` when the daemon is
204/// unreachable or returns a malformed body.
205pub async fn call_decide(
206    state: &ProxyState,
207    req_headers: &hyper::HeaderMap,
208    method: &Method,
209    path: &str,
210    client_addr: SocketAddr,
211    is_connect: bool,
212) -> Result<DecideResponse, String> {
213    let (token, kind) = match extract_host_token(req_headers) {
214        Some((t, k)) => (Some(t), Some(k)),
215        None => (None, None),
216    };
217    let action = if is_connect {
218        "connect".to_string()
219    } else {
220        action_for(method, path)
221    };
222    let user_agent = req_headers
223        .get(hyper::header::USER_AGENT)
224        .and_then(|v| v.to_str().ok())
225        .map(|s| s.to_string());
226    let context = serde_json::json!({
227        "ip": client_addr.ip().to_string(),
228        "user_agent": user_agent,
229    });
230    let body = DecideRequest {
231        actor: None,
232        host_token: token,
233        host_token_kind: kind,
234        action,
235        target: path.to_string(),
236        context,
237        trace_id: state.next_trace_id(),
238    };
239    let url = format!("{}/v1/decide", state.config.daemon.trim_end_matches('/'));
240    let payload = serde_json::to_vec(&body).map_err(|e| format!("encode decide body: {e}"))?;
241    let mut rb = Request::builder()
242        .method(Method::POST)
243        .uri(&url)
244        .header("content-type", "application/json");
245    if let Some(t) = state.config.admin_token.as_deref() {
246        rb = rb.header("X-Admin-Token", t);
247    }
248    let req = rb
249        .body(Full::new(Bytes::from(payload)))
250        .map_err(|e| format!("build decide request: {e}"))?;
251    let resp = state
252        .http
253        .request(req)
254        .await
255        .map_err(|e| format!("daemon unreachable: {e}"))?;
256    if !resp.status().is_success() {
257        return Err(format!("daemon status {}", resp.status()));
258    }
259    let bytes = resp
260        .into_body()
261        .collect()
262        .await
263        .map_err(|e| format!("daemon body read: {e}"))?
264        .to_bytes();
265    let txt = String::from_utf8_lossy(&bytes);
266    let decoded: DecideResponse =
267        serde_json::from_slice(&bytes).map_err(|e| format!("daemon malformed body: {e}: {txt}"))?;
268    if decoded.decision.is_empty() {
269        return Err("daemon returned empty decision".to_string());
270    }
271    Ok(decoded)
272}
273
274/// Forward an HTTP request to the upstream service and copy the response
275/// back as a hyper response.
276pub async fn forward_to_upstream(
277    state: &ProxyState,
278    req: Request<Incoming>,
279) -> Result<Response<Full<Bytes>>, String> {
280    let upstream_base = state.config.upstream.trim_end_matches('/').to_string();
281    let path_and_query = req
282        .uri()
283        .path_and_query()
284        .map(|p| p.as_str().to_string())
285        .unwrap_or_else(|| req.uri().path().to_string());
286    let url = format!("{upstream_base}{path_and_query}");
287
288    let (parts, body) = req.into_parts();
289    let body_bytes = body
290        .collect()
291        .await
292        .map_err(|e| format!("read req body: {e}"))?
293        .to_bytes();
294
295    let mut rb = Request::builder().method(parts.method.clone()).uri(&url);
296    for (k, v) in parts.headers.iter() {
297        // Skip hop-by-hop headers and host (the client sets it from the
298        // upstream URI).
299        let name = k.as_str().to_ascii_lowercase();
300        if matches!(
301            name.as_str(),
302            "host"
303                | "connection"
304                | "keep-alive"
305                | "proxy-authenticate"
306                | "proxy-authorization"
307                | "te"
308                | "trailers"
309                | "transfer-encoding"
310                | "upgrade"
311                | "content-length"
312        ) {
313            continue;
314        }
315        rb = rb.header(k, v);
316    }
317    let upstream_req = rb
318        .body(Full::new(body_bytes))
319        .map_err(|e| format!("build upstream request: {e}"))?;
320    let upstream_resp = state
321        .http
322        .request(upstream_req)
323        .await
324        .map_err(|e| format!("upstream error: {e}"))?;
325    let (resp_parts, resp_body) = upstream_resp.into_parts();
326    let mut builder = Response::builder().status(resp_parts.status);
327    for (k, v) in resp_parts.headers.iter() {
328        let name = k.as_str().to_ascii_lowercase();
329        if matches!(
330            name.as_str(),
331            "connection"
332                | "keep-alive"
333                | "proxy-authenticate"
334                | "proxy-authorization"
335                | "te"
336                | "trailers"
337                | "transfer-encoding"
338                | "upgrade"
339                | "content-length"
340        ) {
341            continue;
342        }
343        builder = builder.header(k, v);
344    }
345    let body = resp_body
346        .collect()
347        .await
348        .map_err(|e| format!("upstream body: {e}"))?
349        .to_bytes();
350    builder
351        .body(Full::new(body))
352        .map_err(|e| format!("response build: {e}"))
353}
354
355fn json_response(status: StatusCode, body: serde_json::Value) -> Response<Full<Bytes>> {
356    let bytes = serde_json::to_vec(&body).unwrap_or_else(|_| b"{}".to_vec());
357    Response::builder()
358        .status(status)
359        .header(hyper::header::CONTENT_TYPE, "application/json")
360        .body(Full::new(Bytes::from(bytes)))
361        .expect("static response")
362}
363
364/// Top-level request handler. Returns a hyper response wrapping a buffered
365/// body. Websocket upgrades are handled out of band by the connection driver
366/// (see [`serve_connection`]).
367pub async fn handle_request(
368    state: Arc<ProxyState>,
369    req: Request<Incoming>,
370    client_addr: SocketAddr,
371) -> Result<Response<Full<Bytes>>, Infallible> {
372    let method = req.method().clone();
373    let path = req.uri().path().to_string();
374    let is_ws = is_websocket_upgrade(&req);
375
376    // Span the entire decision lifetime under tf.daemon.decide so the
377    // Grafana trace explorer can pivot on tf.action / tf.decision /
378    // tf.actor_resolved exactly like the TS daemon does.
379    let span = tracing::info_span!(
380        "tf.daemon.decide",
381        otel.name = "tf.daemon.decide",
382        tf.action = %action_for(&method, &path),
383        tf.target = %path,
384        // Filled in once the decision lands. tracing supports
385        // record-after-creation so we don't need to know these up front.
386        tf.decision = tracing::field::Empty,
387        tf.actor_resolved = tracing::field::Empty,
388    );
389    let _enter = span.enter();
390
391    let started = std::time::Instant::now();
392    let decision =
393        match call_decide(&state, req.headers(), &method, &path, client_addr, is_ws).await {
394            Ok(d) => d,
395            Err(e) => {
396                error!(error = %e, "daemon decide failed");
397                return Ok(json_response(
398                    StatusCode::BAD_GATEWAY,
399                    serde_json::json!({"error": "daemon-error", "detail": e}),
400                ));
401            }
402        };
403
404    // Record the outcome on the active span and on the canonical metric
405    // pipeline. Both are fire-and-forget; if telemetry is off they are
406    // cheap no-ops.
407    span.record("tf.decision", decision.decision.as_str());
408    if let Some(otel) = state.otel() {
409        let actor = "unknown";
410        let action = action_for(&method, &path);
411        let elapsed = started.elapsed().as_secs_f64();
412        tf_otel::record_decide(
413            otel.metrics(),
414            &decision.decision,
415            &action,
416            actor,
417            Some(&path),
418            elapsed,
419        );
420    }
421
422    info!(
423        decision = %decision.decision,
424        method = %method,
425        path = %path,
426        mode = ?state.config.mode,
427        "decision"
428    );
429
430    match decision.decision.as_str() {
431        "allow" => match forward_to_upstream(&state, req).await {
432            Ok(r) => Ok(r),
433            Err(e) => {
434                error!(error = %e, "upstream forward failed");
435                Ok(json_response(
436                    StatusCode::BAD_GATEWAY,
437                    serde_json::json!({"error": "upstream-error", "detail": e}),
438                ))
439            }
440        },
441        "deny" => {
442            if state.config.mode == Mode::Enforce {
443                let realm = state
444                    .config
445                    .upstream
446                    .parse::<Uri>()
447                    .ok()
448                    .and_then(|u| u.host().map(|s| s.to_string()))
449                    .unwrap_or_else(|| state.config.profile.clone());
450                let reason = decision.reason.clone().unwrap_or_default();
451                let proof = decision.proof_id.clone().unwrap_or_default();
452                let www_auth = format!("TrustForge realm=\"{realm}\", reason=\"{reason}\"");
453                let body = serde_json::json!({
454                    "error": "deny",
455                    "reason": reason,
456                    "proof_id": proof,
457                });
458                let mut resp = json_response(StatusCode::FORBIDDEN, body);
459                if let Ok(hv) = HeaderValue::from_str(&www_auth) {
460                    resp.headers_mut()
461                        .insert(hyper::header::WWW_AUTHENTICATE, hv);
462                }
463                Ok(resp)
464            } else {
465                warn!(
466                    proof_id = ?decision.proof_id,
467                    reason = ?decision.reason,
468                    "observe-only: forwarding despite deny"
469                );
470                match forward_to_upstream(&state, req).await {
471                    Ok(r) => Ok(r),
472                    Err(e) => Ok(json_response(
473                        StatusCode::BAD_GATEWAY,
474                        serde_json::json!({"error": "upstream-error", "detail": e}),
475                    )),
476                }
477            }
478        }
479        "approval-required" | "approval_required" => {
480            let approval_id = decision.approval_id.clone().unwrap_or_default();
481            let location = format!(
482                "{}/v1/approval/{}",
483                state.config.daemon.trim_end_matches('/'),
484                approval_id
485            );
486            let body = serde_json::json!({
487                "status": "pending",
488                "approval_id": approval_id,
489            });
490            let mut resp = json_response(StatusCode::ACCEPTED, body);
491            if let Ok(hv) = HeaderValue::from_str(&location) {
492                resp.headers_mut().insert(hyper::header::LOCATION, hv);
493            }
494            Ok(resp)
495        }
496        "log-only" | "log_only" => {
497            info!(
498                proof_id = ?decision.proof_id,
499                reason = ?decision.reason,
500                "proof-event log-only forwarding"
501            );
502            match forward_to_upstream(&state, req).await {
503                Ok(r) => Ok(r),
504                Err(e) => Ok(json_response(
505                    StatusCode::BAD_GATEWAY,
506                    serde_json::json!({"error": "upstream-error", "detail": e}),
507                )),
508            }
509        }
510        other => {
511            warn!(decision = %other, "unknown decision; treating as deny");
512            if state.config.mode == Mode::Enforce {
513                Ok(json_response(
514                    StatusCode::FORBIDDEN,
515                    serde_json::json!({"error": "deny", "reason": format!("unknown decision: {other}")}),
516                ))
517            } else {
518                match forward_to_upstream(&state, req).await {
519                    Ok(r) => Ok(r),
520                    Err(e) => Ok(json_response(
521                        StatusCode::BAD_GATEWAY,
522                        serde_json::json!({"error": "upstream-error", "detail": e}),
523                    )),
524                }
525            }
526        }
527    }
528}
529
530/// Drive a single connection. If the request is a websocket upgrade (and the
531/// daemon allows it), we transparently splice the client and upstream TCP
532/// streams together. Otherwise we fall through to [`handle_request`].
533pub async fn serve_connection(state: Arc<ProxyState>, stream: TcpStream, client_addr: SocketAddr) {
534    // We need to peek at the request before deciding between websocket
535    // splice and regular HTTP service. Use hyper with a service_fn that owns
536    // a one-shot signal: when the handler sees a websocket upgrade we let
537    // hyper finish the response (we'll send a 101 directly) and then take
538    // the underlying TCP socket out of the connection.
539    //
540    // To keep the implementation simple and predictable, we read enough of
541    // the first request to see the headers ourselves, decide, and then
542    // either handle it inline as websocket or hand the original bytes to
543    // hyper for normal processing.
544    if let Err(e) = serve_connection_inner(state, stream, client_addr).await {
545        debug!(error = %e, "connection ended");
546    }
547}
548
549async fn serve_connection_inner(
550    state: Arc<ProxyState>,
551    mut stream: TcpStream,
552    client_addr: SocketAddr,
553) -> std::io::Result<()> {
554    // Peek the first chunk to detect a websocket upgrade without consuming.
555    let mut peek = [0u8; 4096];
556    let n = stream.peek(&mut peek).await?;
557    if n == 0 {
558        return Ok(());
559    }
560    let head = &peek[..n];
561    let is_ws = head_looks_like_websocket(head);
562
563    if is_ws {
564        // Parse method+path+headers minimally for the decide call.
565        if let Some((method, path, headers)) = parse_request_head(head) {
566            let m = Method::from_bytes(method.as_bytes()).unwrap_or(Method::GET);
567            match call_decide(&state, &headers, &m, &path, client_addr, true).await {
568                Ok(d) => {
569                    let allow = d.decision == "allow"
570                        || (state.config.mode == Mode::ObserveOnly
571                            && d.decision != "approval-required"
572                            && d.decision != "approval_required");
573                    if allow {
574                        info!(decision = %d.decision, "websocket upgrade allowed");
575                        return splice_to_upstream(&state, stream).await;
576                    } else if d.decision == "approval-required" || d.decision == "approval_required"
577                    {
578                        let approval_id = d.approval_id.unwrap_or_default();
579                        let loc = format!(
580                            "{}/v1/approval/{}",
581                            state.config.daemon.trim_end_matches('/'),
582                            approval_id
583                        );
584                        let body =
585                            format!("{{\"status\":\"pending\",\"approval_id\":\"{approval_id}\"}}");
586                        let resp = format!(
587                            "HTTP/1.1 202 Accepted\r\nLocation: {loc}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
588                            body.len()
589                        );
590                        stream.write_all(resp.as_bytes()).await?;
591                        return Ok(());
592                    } else {
593                        let reason = d.reason.unwrap_or_default();
594                        let body = format!("{{\"error\":\"deny\",\"reason\":\"{reason}\"}}");
595                        let resp = format!(
596                            "HTTP/1.1 403 Forbidden\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
597                            body.len()
598                        );
599                        stream.write_all(resp.as_bytes()).await?;
600                        return Ok(());
601                    }
602                }
603                Err(e) => {
604                    error!(error = %e, "ws decide failed");
605                    let body = format!("{{\"error\":\"daemon-error\",\"detail\":\"{e}\"}}");
606                    let resp = format!(
607                        "HTTP/1.1 502 Bad Gateway\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
608                        body.len()
609                    );
610                    stream.write_all(resp.as_bytes()).await?;
611                    return Ok(());
612                }
613            }
614        }
615    }
616
617    let io = TokioIo::new(stream);
618    let svc = service_fn(move |req: Request<Incoming>| {
619        let state = state.clone();
620        async move { handle_request(state, req, client_addr).await }
621    });
622    if let Err(e) = hyper::server::conn::http1::Builder::new()
623        .serve_connection(io, svc)
624        .await
625    {
626        debug!(error = %e, "hyper serve_connection error");
627    }
628    Ok(())
629}
630
/// Return true if the request head in `buf` carries `Upgrade: websocket`.
///
/// Only the bytes before the first blank line (`\r\n\r\n`) are inspected. If
/// there is no blank line yet, the whole buffer is inspected. The request
/// line is skipped, and header names and values are compared
/// ASCII-case-insensitively after trimming.
///
/// The head is located on raw bytes and decoded lossily. Invalid or
/// truncated UTF-8 elsewhere in the peeked buffer (for example a body chunk
/// cut mid-character) therefore can no longer hide the upgrade header.
fn head_looks_like_websocket(buf: &[u8]) -> bool {
    let head_len = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .unwrap_or(buf.len());
    let head = String::from_utf8_lossy(&buf[..head_len]);
    head.split("\r\n").skip(1).any(|line| {
        line.split_once(':').is_some_and(|(name, value)| {
            name.trim().eq_ignore_ascii_case("upgrade")
                && value.trim().eq_ignore_ascii_case("websocket")
        })
    })
}
652
653fn parse_request_head(buf: &[u8]) -> Option<(String, String, hyper::HeaderMap)> {
654    let s = std::str::from_utf8(buf).ok()?;
655    let head_end = s.find("\r\n\r\n").unwrap_or(s.len());
656    let head = &s[..head_end];
657    let mut lines = head.split("\r\n");
658    let request_line = lines.next()?;
659    let mut parts = request_line.split_whitespace();
660    let method = parts.next()?.to_string();
661    let path = parts.next()?.to_string();
662    let mut headers = hyper::HeaderMap::new();
663    for line in lines {
664        if let Some((n, v)) = line.split_once(':') {
665            if let (Ok(hn), Ok(hv)) = (
666                HeaderName::from_bytes(n.trim().as_bytes()),
667                HeaderValue::from_str(v.trim()),
668            ) {
669                headers.insert(hn, hv);
670            }
671        }
672    }
673    Some((method, path, headers))
674}
675
676async fn splice_to_upstream(state: &ProxyState, mut client: TcpStream) -> std::io::Result<()> {
677    // Connect to upstream and pipe bytes both ways. We assume the upstream
678    // URL is plain http://host:port (TLS upstream is out of scope for this
679    // first cut; reverse-proxy TLS termination happens at the listener).
680    let url = match state.config.upstream.parse::<Uri>() {
681        Ok(u) => u,
682        Err(e) => {
683            error!(error = %e, "bad upstream URL");
684            return Ok(());
685        }
686    };
687    let host = url.host().unwrap_or("127.0.0.1");
688    let port = url.port_u16().unwrap_or(match url.scheme_str() {
689        Some("https") => 443,
690        _ => 80,
691    });
692    let mut upstream = TcpStream::connect((host, port)).await?;
693    // Drain the peeked bytes from the client (we did not consume them
694    // because we used `peek`) and ferry both directions.
695    let _ = tokio::io::copy_bidirectional(&mut client, &mut upstream).await;
696    Ok(())
697}
698
699/// Run the proxy until cancelled. Returns when the listener is dropped.
700pub async fn run(state: Arc<ProxyState>) -> std::io::Result<()> {
701    let listener = TcpListener::bind(state.config.listen).await?;
702    info!(listen = %state.config.listen, upstream = %state.config.upstream, "tf-proxy listening");
703    let tls = build_tls_acceptor(&state.config)?;
704    loop {
705        let (stream, addr) = listener.accept().await?;
706        let s = state.clone();
707        match &tls {
708            Some(acceptor) => {
709                let acceptor = acceptor.clone();
710                tokio::spawn(async move {
711                    let _ = serve_tls(s, acceptor, stream, addr).await;
712                });
713            }
714            None => {
715                tokio::spawn(async move {
716                    serve_connection(s, stream, addr).await;
717                });
718            }
719        }
720    }
721}
722
723fn build_tls_acceptor(cfg: &ProxyConfig) -> std::io::Result<Option<TlsAcceptor>> {
724    match (&cfg.tls_cert, &cfg.tls_key) {
725        (Some(cert_path), Some(key_path)) => {
726            let cert_file = std::fs::File::open(cert_path)?;
727            let key_file = std::fs::File::open(key_path)?;
728            let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
729                rustls_pemfile::certs(&mut BufReader::new(cert_file))
730                    .collect::<Result<Vec<_>, _>>()?;
731            let key =
732                rustls_pemfile::private_key(&mut BufReader::new(key_file))?.ok_or_else(|| {
733                    std::io::Error::new(
734                        std::io::ErrorKind::InvalidData,
735                        "no private key in pem file",
736                    )
737                })?;
738            let cfg = rustls::ServerConfig::builder()
739                .with_no_client_auth()
740                .with_single_cert(certs, key)
741                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
742            Ok(Some(TlsAcceptor::from(Arc::new(cfg))))
743        }
744        (None, None) => Ok(None),
745        _ => Err(std::io::Error::new(
746            std::io::ErrorKind::InvalidInput,
747            "--tls-cert and --tls-key must be provided together",
748        )),
749    }
750}
751
752async fn serve_tls(
753    state: Arc<ProxyState>,
754    acceptor: TlsAcceptor,
755    stream: TcpStream,
756    addr: SocketAddr,
757) -> std::io::Result<()> {
758    let tls_stream = acceptor.accept(stream).await?;
759    let io = TokioIo::new(tls_stream);
760    let svc = service_fn(move |req: Request<Incoming>| {
761        let state = state.clone();
762        async move { handle_request(state, req, addr).await }
763    });
764    if let Err(e) = hyper::server::conn::http1::Builder::new()
765        .serve_connection(io, svc)
766        .await
767    {
768        debug!(error = %e, "tls hyper error");
769    }
770    Ok(())
771}