1use std::convert::Infallible;
14use std::net::SocketAddr;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17
18use http_body_util::{BodyExt, Full};
19use hyper::body::{Bytes, Incoming};
20use hyper::header::{HeaderName, HeaderValue, UPGRADE};
21use hyper::service::service_fn;
22use hyper::{Method, Request, Response, StatusCode, Uri};
23use hyper_util::client::legacy::{connect::HttpConnector, Client};
24use hyper_util::rt::{TokioExecutor, TokioIo};
25use serde::{Deserialize, Serialize};
26use std::io::BufReader;
27use tokio::io::AsyncWriteExt;
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls::TlsAcceptor;
30use tracing::{debug, error, info, warn};
31
/// How the proxy reacts to a `deny` decision from the daemon.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Log denials but still forward the request upstream.
    ObserveOnly,
    /// Reject denied requests with `403 Forbidden`.
    Enforce,
}

impl std::str::FromStr for Mode {
    type Err = String;

    /// Parses a mode name. Accepted spellings (case-sensitive):
    /// `observe-only`, `observe_only`, `observe`, and `enforce`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const OBSERVE_ALIASES: [&str; 3] = ["observe-only", "observe_only", "observe"];
        if OBSERVE_ALIASES.contains(&s) {
            Ok(Self::ObserveOnly)
        } else if s == "enforce" {
            Ok(Self::Enforce)
        } else {
            Err(format!("unknown mode: {s}"))
        }
    }
}
51
/// Static configuration for one proxy instance.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Socket address the proxy accepts client connections on.
    pub listen: SocketAddr,
    /// Base URL of the protected upstream; each request's path (and query)
    /// is appended to it after any trailing `/` is removed.
    pub upstream: String,
    /// Base URL of the policy daemon, used for `/v1/decide` calls and for the
    /// `/v1/approval/{id}` location handed to clients.
    pub daemon: String,
    /// When set, sent to the daemon as the `X-Admin-Token` header.
    pub admin_token: Option<String>,
    /// Realm used in the deny challenge when no host can be parsed from
    /// `upstream`.
    pub profile: String,
    /// Whether `deny` decisions are enforced or only observed.
    pub mode: Mode,
    /// Path to a PEM certificate chain; must be set together with `tls_key`
    /// to serve TLS.
    pub tls_cert: Option<String>,
    /// Path to a PEM private key; must be set together with `tls_cert`.
    pub tls_key: Option<String>,
}
64
/// JSON body POSTed to the daemon's `/v1/decide` endpoint by `call_decide`.
/// Field order is the serialization order.
#[derive(Serialize, Debug)]
pub struct DecideRequest<'a> {
    /// Explicit actor identity; `call_decide` always sends `None`.
    pub actor: Option<&'a str>,
    /// Caller's bearer or session-cookie token, if one was found.
    pub host_token: Option<String>,
    /// `"jwt"` or `"opaque"`, as classified by the proxy.
    pub host_token_kind: Option<String>,
    /// `"connect"` for WebSocket upgrades, otherwise `<method>.<path segments>`.
    pub action: String,
    /// Request path: the URI path for hyper-served requests, the raw
    /// request-target for peeked WebSocket heads.
    pub target: String,
    /// Extra request context: client `ip` and `user_agent`.
    pub context: serde_json::Value,
    /// Proxy-generated id of the form `tf-proxy-{unix_nanos}-{seq}`.
    pub trace_id: String,
}
76
/// Daemon answer to a `/v1/decide` call.
#[derive(Deserialize, Debug, Clone, Default)]
pub struct DecideResponse {
    /// Verdict. `handle_request` recognizes `allow`, `deny`,
    /// `approval-required`/`approval_required` and `log-only`/`log_only`;
    /// any other value is handled like `deny`. `call_decide` rejects an
    /// empty value.
    pub decision: String,
    /// Human-readable explanation, surfaced in deny responses and logs.
    #[serde(default)]
    pub reason: Option<String>,
    /// Opaque proof identifier from the daemon; echoed in deny bodies and logged.
    #[serde(default)]
    pub proof_id: Option<String>,
    /// Approval request id, used to build the `Location` returned for
    /// approval-required decisions.
    #[serde(default)]
    pub approval_id: Option<String>,
}
88
/// State shared (behind an `Arc`) by every connection the proxy serves.
pub struct ProxyState {
    /// Proxy configuration.
    pub config: ProxyConfig,
    /// Pooled plain-HTTP client (`HttpConnector`, no TLS) used for both the
    /// daemon decide calls and upstream forwarding.
    pub http: Client<HttpConnector, Full<Bytes>>,
    /// Sequence number mixed into generated trace ids.
    counter: AtomicU64,
    /// OpenTelemetry handle; installed at most once via `set_otel`.
    otel: std::sync::OnceLock<tf_otel::TfOtelHandle>,
}
100
101impl ProxyState {
102 pub fn new(config: ProxyConfig) -> Arc<Self> {
103 let http = Client::builder(TokioExecutor::new())
104 .pool_idle_timeout(std::time::Duration::from_secs(30))
105 .build_http();
106 Arc::new(Self {
107 config,
108 http,
109 counter: AtomicU64::new(0),
110 otel: std::sync::OnceLock::new(),
111 })
112 }
113
114 pub fn set_otel(&self, handle: tf_otel::TfOtelHandle) {
118 let _ = self.otel.set(handle);
119 }
120
121 pub fn otel(&self) -> Option<&tf_otel::TfOtelHandle> {
123 self.otel.get()
124 }
125
126 fn next_trace_id(&self) -> String {
127 let n = self.counter.fetch_add(1, Ordering::Relaxed);
128 let nanos = std::time::SystemTime::now()
129 .duration_since(std::time::UNIX_EPOCH)
130 .map(|d| d.as_nanos())
131 .unwrap_or(0);
132 format!("tf-proxy-{nanos}-{n}")
133 }
134}
135
136pub fn extract_host_token(headers: &hyper::HeaderMap) -> Option<(String, String)> {
140 if let Some(v) = headers.get(hyper::header::AUTHORIZATION) {
141 if let Ok(s) = v.to_str() {
142 if let Some(rest) = s.strip_prefix("Bearer ") {
143 let token = rest.trim().to_string();
144 if !token.is_empty() {
145 let kind = classify_token(&token);
146 return Some((token, kind));
147 }
148 }
149 }
150 }
151 if let Some(v) = headers.get(hyper::header::COOKIE) {
152 if let Ok(s) = v.to_str() {
153 for raw in s.split(';') {
154 let part = raw.trim();
155 for name in ["__session=", "__Secure-next-auth.session-token="] {
156 if let Some(val) = part.strip_prefix(name) {
157 let token = val.trim().to_string();
158 if !token.is_empty() {
159 let kind = classify_token(&token);
160 return Some((token, kind));
161 }
162 }
163 }
164 }
165 }
166 }
167 None
168}
169
/// Classifies a host token as `"jwt"` or `"opaque"`.
///
/// A token counts as a JWT when it has exactly three `.`-separated
/// segments made only of base64url characters (`A-Z a-z 0-9 - _`). The
/// header and payload segments must be non-empty; the signature may be
/// empty, as in unsecured JWTs. Values such as `".."` or `"a..c"` are
/// therefore `"opaque"`.
fn classify_token(t: &str) -> String {
    let is_base64url = |segment: &str| {
        segment
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
    };
    let mut segments = t.split('.');
    let looks_like_jwt = match (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) {
        (Some(header), Some(payload), Some(signature), None) => {
            !header.is_empty()
                && !payload.is_empty()
                && is_base64url(header)
                && is_base64url(payload)
                && is_base64url(signature)
        }
        _ => false,
    };
    let kind = if looks_like_jwt { "jwt" } else { "opaque" };
    kind.to_string()
}
181
182pub fn action_for(method: &Method, path: &str) -> String {
185 let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
186 let m = method.as_str().to_ascii_lowercase();
187 if segments.is_empty() {
188 format!("{m}.root")
189 } else {
190 format!("{m}.{}", segments.join("."))
191 }
192}
193
194pub fn is_websocket_upgrade(req: &Request<Incoming>) -> bool {
196 req.headers()
197 .get(UPGRADE)
198 .and_then(|v| v.to_str().ok())
199 .map(|s| s.eq_ignore_ascii_case("websocket"))
200 .unwrap_or(false)
201}
202
203pub async fn call_decide(
206 state: &ProxyState,
207 req_headers: &hyper::HeaderMap,
208 method: &Method,
209 path: &str,
210 client_addr: SocketAddr,
211 is_connect: bool,
212) -> Result<DecideResponse, String> {
213 let (token, kind) = match extract_host_token(req_headers) {
214 Some((t, k)) => (Some(t), Some(k)),
215 None => (None, None),
216 };
217 let action = if is_connect {
218 "connect".to_string()
219 } else {
220 action_for(method, path)
221 };
222 let user_agent = req_headers
223 .get(hyper::header::USER_AGENT)
224 .and_then(|v| v.to_str().ok())
225 .map(|s| s.to_string());
226 let context = serde_json::json!({
227 "ip": client_addr.ip().to_string(),
228 "user_agent": user_agent,
229 });
230 let body = DecideRequest {
231 actor: None,
232 host_token: token,
233 host_token_kind: kind,
234 action,
235 target: path.to_string(),
236 context,
237 trace_id: state.next_trace_id(),
238 };
239 let url = format!("{}/v1/decide", state.config.daemon.trim_end_matches('/'));
240 let payload = serde_json::to_vec(&body).map_err(|e| format!("encode decide body: {e}"))?;
241 let mut rb = Request::builder()
242 .method(Method::POST)
243 .uri(&url)
244 .header("content-type", "application/json");
245 if let Some(t) = state.config.admin_token.as_deref() {
246 rb = rb.header("X-Admin-Token", t);
247 }
248 let req = rb
249 .body(Full::new(Bytes::from(payload)))
250 .map_err(|e| format!("build decide request: {e}"))?;
251 let resp = state
252 .http
253 .request(req)
254 .await
255 .map_err(|e| format!("daemon unreachable: {e}"))?;
256 if !resp.status().is_success() {
257 return Err(format!("daemon status {}", resp.status()));
258 }
259 let bytes = resp
260 .into_body()
261 .collect()
262 .await
263 .map_err(|e| format!("daemon body read: {e}"))?
264 .to_bytes();
265 let txt = String::from_utf8_lossy(&bytes);
266 let decoded: DecideResponse =
267 serde_json::from_slice(&bytes).map_err(|e| format!("daemon malformed body: {e}: {txt}"))?;
268 if decoded.decision.is_empty() {
269 return Err("daemon returned empty decision".to_string());
270 }
271 Ok(decoded)
272}
273
274pub async fn forward_to_upstream(
277 state: &ProxyState,
278 req: Request<Incoming>,
279) -> Result<Response<Full<Bytes>>, String> {
280 let upstream_base = state.config.upstream.trim_end_matches('/').to_string();
281 let path_and_query = req
282 .uri()
283 .path_and_query()
284 .map(|p| p.as_str().to_string())
285 .unwrap_or_else(|| req.uri().path().to_string());
286 let url = format!("{upstream_base}{path_and_query}");
287
288 let (parts, body) = req.into_parts();
289 let body_bytes = body
290 .collect()
291 .await
292 .map_err(|e| format!("read req body: {e}"))?
293 .to_bytes();
294
295 let mut rb = Request::builder().method(parts.method.clone()).uri(&url);
296 for (k, v) in parts.headers.iter() {
297 let name = k.as_str().to_ascii_lowercase();
300 if matches!(
301 name.as_str(),
302 "host"
303 | "connection"
304 | "keep-alive"
305 | "proxy-authenticate"
306 | "proxy-authorization"
307 | "te"
308 | "trailers"
309 | "transfer-encoding"
310 | "upgrade"
311 | "content-length"
312 ) {
313 continue;
314 }
315 rb = rb.header(k, v);
316 }
317 let upstream_req = rb
318 .body(Full::new(body_bytes))
319 .map_err(|e| format!("build upstream request: {e}"))?;
320 let upstream_resp = state
321 .http
322 .request(upstream_req)
323 .await
324 .map_err(|e| format!("upstream error: {e}"))?;
325 let (resp_parts, resp_body) = upstream_resp.into_parts();
326 let mut builder = Response::builder().status(resp_parts.status);
327 for (k, v) in resp_parts.headers.iter() {
328 let name = k.as_str().to_ascii_lowercase();
329 if matches!(
330 name.as_str(),
331 "connection"
332 | "keep-alive"
333 | "proxy-authenticate"
334 | "proxy-authorization"
335 | "te"
336 | "trailers"
337 | "transfer-encoding"
338 | "upgrade"
339 | "content-length"
340 ) {
341 continue;
342 }
343 builder = builder.header(k, v);
344 }
345 let body = resp_body
346 .collect()
347 .await
348 .map_err(|e| format!("upstream body: {e}"))?
349 .to_bytes();
350 builder
351 .body(Full::new(body))
352 .map_err(|e| format!("response build: {e}"))
353}
354
355fn json_response(status: StatusCode, body: serde_json::Value) -> Response<Full<Bytes>> {
356 let bytes = serde_json::to_vec(&body).unwrap_or_else(|_| b"{}".to_vec());
357 Response::builder()
358 .status(status)
359 .header(hyper::header::CONTENT_TYPE, "application/json")
360 .body(Full::new(Bytes::from(bytes)))
361 .expect("static response")
362}
363
364pub async fn handle_request(
368 state: Arc<ProxyState>,
369 req: Request<Incoming>,
370 client_addr: SocketAddr,
371) -> Result<Response<Full<Bytes>>, Infallible> {
372 let method = req.method().clone();
373 let path = req.uri().path().to_string();
374 let is_ws = is_websocket_upgrade(&req);
375
376 let span = tracing::info_span!(
380 "tf.daemon.decide",
381 otel.name = "tf.daemon.decide",
382 tf.action = %action_for(&method, &path),
383 tf.target = %path,
384 tf.decision = tracing::field::Empty,
387 tf.actor_resolved = tracing::field::Empty,
388 );
389 let _enter = span.enter();
390
391 let started = std::time::Instant::now();
392 let decision =
393 match call_decide(&state, req.headers(), &method, &path, client_addr, is_ws).await {
394 Ok(d) => d,
395 Err(e) => {
396 error!(error = %e, "daemon decide failed");
397 return Ok(json_response(
398 StatusCode::BAD_GATEWAY,
399 serde_json::json!({"error": "daemon-error", "detail": e}),
400 ));
401 }
402 };
403
404 span.record("tf.decision", decision.decision.as_str());
408 if let Some(otel) = state.otel() {
409 let actor = "unknown";
410 let action = action_for(&method, &path);
411 let elapsed = started.elapsed().as_secs_f64();
412 tf_otel::record_decide(
413 otel.metrics(),
414 &decision.decision,
415 &action,
416 actor,
417 Some(&path),
418 elapsed,
419 );
420 }
421
422 info!(
423 decision = %decision.decision,
424 method = %method,
425 path = %path,
426 mode = ?state.config.mode,
427 "decision"
428 );
429
430 match decision.decision.as_str() {
431 "allow" => match forward_to_upstream(&state, req).await {
432 Ok(r) => Ok(r),
433 Err(e) => {
434 error!(error = %e, "upstream forward failed");
435 Ok(json_response(
436 StatusCode::BAD_GATEWAY,
437 serde_json::json!({"error": "upstream-error", "detail": e}),
438 ))
439 }
440 },
441 "deny" => {
442 if state.config.mode == Mode::Enforce {
443 let realm = state
444 .config
445 .upstream
446 .parse::<Uri>()
447 .ok()
448 .and_then(|u| u.host().map(|s| s.to_string()))
449 .unwrap_or_else(|| state.config.profile.clone());
450 let reason = decision.reason.clone().unwrap_or_default();
451 let proof = decision.proof_id.clone().unwrap_or_default();
452 let www_auth = format!("TrustForge realm=\"{realm}\", reason=\"{reason}\"");
453 let body = serde_json::json!({
454 "error": "deny",
455 "reason": reason,
456 "proof_id": proof,
457 });
458 let mut resp = json_response(StatusCode::FORBIDDEN, body);
459 if let Ok(hv) = HeaderValue::from_str(&www_auth) {
460 resp.headers_mut()
461 .insert(hyper::header::WWW_AUTHENTICATE, hv);
462 }
463 Ok(resp)
464 } else {
465 warn!(
466 proof_id = ?decision.proof_id,
467 reason = ?decision.reason,
468 "observe-only: forwarding despite deny"
469 );
470 match forward_to_upstream(&state, req).await {
471 Ok(r) => Ok(r),
472 Err(e) => Ok(json_response(
473 StatusCode::BAD_GATEWAY,
474 serde_json::json!({"error": "upstream-error", "detail": e}),
475 )),
476 }
477 }
478 }
479 "approval-required" | "approval_required" => {
480 let approval_id = decision.approval_id.clone().unwrap_or_default();
481 let location = format!(
482 "{}/v1/approval/{}",
483 state.config.daemon.trim_end_matches('/'),
484 approval_id
485 );
486 let body = serde_json::json!({
487 "status": "pending",
488 "approval_id": approval_id,
489 });
490 let mut resp = json_response(StatusCode::ACCEPTED, body);
491 if let Ok(hv) = HeaderValue::from_str(&location) {
492 resp.headers_mut().insert(hyper::header::LOCATION, hv);
493 }
494 Ok(resp)
495 }
496 "log-only" | "log_only" => {
497 info!(
498 proof_id = ?decision.proof_id,
499 reason = ?decision.reason,
500 "proof-event log-only forwarding"
501 );
502 match forward_to_upstream(&state, req).await {
503 Ok(r) => Ok(r),
504 Err(e) => Ok(json_response(
505 StatusCode::BAD_GATEWAY,
506 serde_json::json!({"error": "upstream-error", "detail": e}),
507 )),
508 }
509 }
510 other => {
511 warn!(decision = %other, "unknown decision; treating as deny");
512 if state.config.mode == Mode::Enforce {
513 Ok(json_response(
514 StatusCode::FORBIDDEN,
515 serde_json::json!({"error": "deny", "reason": format!("unknown decision: {other}")}),
516 ))
517 } else {
518 match forward_to_upstream(&state, req).await {
519 Ok(r) => Ok(r),
520 Err(e) => Ok(json_response(
521 StatusCode::BAD_GATEWAY,
522 serde_json::json!({"error": "upstream-error", "detail": e}),
523 )),
524 }
525 }
526 }
527 }
528}
529
530pub async fn serve_connection(state: Arc<ProxyState>, stream: TcpStream, client_addr: SocketAddr) {
534 if let Err(e) = serve_connection_inner(state, stream, client_addr).await {
545 debug!(error = %e, "connection ended");
546 }
547}
548
549async fn serve_connection_inner(
550 state: Arc<ProxyState>,
551 mut stream: TcpStream,
552 client_addr: SocketAddr,
553) -> std::io::Result<()> {
554 let mut peek = [0u8; 4096];
556 let n = stream.peek(&mut peek).await?;
557 if n == 0 {
558 return Ok(());
559 }
560 let head = &peek[..n];
561 let is_ws = head_looks_like_websocket(head);
562
563 if is_ws {
564 if let Some((method, path, headers)) = parse_request_head(head) {
566 let m = Method::from_bytes(method.as_bytes()).unwrap_or(Method::GET);
567 match call_decide(&state, &headers, &m, &path, client_addr, true).await {
568 Ok(d) => {
569 let allow = d.decision == "allow"
570 || (state.config.mode == Mode::ObserveOnly
571 && d.decision != "approval-required"
572 && d.decision != "approval_required");
573 if allow {
574 info!(decision = %d.decision, "websocket upgrade allowed");
575 return splice_to_upstream(&state, stream).await;
576 } else if d.decision == "approval-required" || d.decision == "approval_required"
577 {
578 let approval_id = d.approval_id.unwrap_or_default();
579 let loc = format!(
580 "{}/v1/approval/{}",
581 state.config.daemon.trim_end_matches('/'),
582 approval_id
583 );
584 let body =
585 format!("{{\"status\":\"pending\",\"approval_id\":\"{approval_id}\"}}");
586 let resp = format!(
587 "HTTP/1.1 202 Accepted\r\nLocation: {loc}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
588 body.len()
589 );
590 stream.write_all(resp.as_bytes()).await?;
591 return Ok(());
592 } else {
593 let reason = d.reason.unwrap_or_default();
594 let body = format!("{{\"error\":\"deny\",\"reason\":\"{reason}\"}}");
595 let resp = format!(
596 "HTTP/1.1 403 Forbidden\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
597 body.len()
598 );
599 stream.write_all(resp.as_bytes()).await?;
600 return Ok(());
601 }
602 }
603 Err(e) => {
604 error!(error = %e, "ws decide failed");
605 let body = format!("{{\"error\":\"daemon-error\",\"detail\":\"{e}\"}}");
606 let resp = format!(
607 "HTTP/1.1 502 Bad Gateway\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
608 body.len()
609 );
610 stream.write_all(resp.as_bytes()).await?;
611 return Ok(());
612 }
613 }
614 }
615 }
616
617 let io = TokioIo::new(stream);
618 let svc = service_fn(move |req: Request<Incoming>| {
619 let state = state.clone();
620 async move { handle_request(state, req, client_addr).await }
621 });
622 if let Err(e) = hyper::server::conn::http1::Builder::new()
623 .serve_connection(io, svc)
624 .await
625 {
626 debug!(error = %e, "hyper serve_connection error");
627 }
628 Ok(())
629}
630
/// Cheap pre-check on peeked bytes: does the request head contain an
/// `Upgrade: websocket` header?
///
/// The head is everything before the first blank line, or the whole buffer
/// if the head is still incomplete. The request line is skipped. Header
/// names and values are trimmed and compared ASCII case-insensitively.
///
/// Only the head has to be valid UTF-8. Body bytes after `\r\n\r\n` are
/// ignored, so binary data peeked along with the head no longer hides the
/// header. A non-UTF-8 head yields `false`, and the connection then takes
/// the regular hyper path.
fn head_looks_like_websocket(buf: &[u8]) -> bool {
    let head_end = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .unwrap_or(buf.len());
    let Ok(head) = std::str::from_utf8(&buf[..head_end]) else {
        return false;
    };
    head.split("\r\n")
        .skip(1)
        .filter_map(|line| line.split_once(':'))
        .any(|(name, value)| {
            name.trim().eq_ignore_ascii_case("upgrade")
                && value.trim().eq_ignore_ascii_case("websocket")
        })
}
652
653fn parse_request_head(buf: &[u8]) -> Option<(String, String, hyper::HeaderMap)> {
654 let s = std::str::from_utf8(buf).ok()?;
655 let head_end = s.find("\r\n\r\n").unwrap_or(s.len());
656 let head = &s[..head_end];
657 let mut lines = head.split("\r\n");
658 let request_line = lines.next()?;
659 let mut parts = request_line.split_whitespace();
660 let method = parts.next()?.to_string();
661 let path = parts.next()?.to_string();
662 let mut headers = hyper::HeaderMap::new();
663 for line in lines {
664 if let Some((n, v)) = line.split_once(':') {
665 if let (Ok(hn), Ok(hv)) = (
666 HeaderName::from_bytes(n.trim().as_bytes()),
667 HeaderValue::from_str(v.trim()),
668 ) {
669 headers.insert(hn, hv);
670 }
671 }
672 }
673 Some((method, path, headers))
674}
675
676async fn splice_to_upstream(state: &ProxyState, mut client: TcpStream) -> std::io::Result<()> {
677 let url = match state.config.upstream.parse::<Uri>() {
681 Ok(u) => u,
682 Err(e) => {
683 error!(error = %e, "bad upstream URL");
684 return Ok(());
685 }
686 };
687 let host = url.host().unwrap_or("127.0.0.1");
688 let port = url.port_u16().unwrap_or(match url.scheme_str() {
689 Some("https") => 443,
690 _ => 80,
691 });
692 let mut upstream = TcpStream::connect((host, port)).await?;
693 let _ = tokio::io::copy_bidirectional(&mut client, &mut upstream).await;
696 Ok(())
697}
698
699pub async fn run(state: Arc<ProxyState>) -> std::io::Result<()> {
701 let listener = TcpListener::bind(state.config.listen).await?;
702 info!(listen = %state.config.listen, upstream = %state.config.upstream, "tf-proxy listening");
703 let tls = build_tls_acceptor(&state.config)?;
704 loop {
705 let (stream, addr) = listener.accept().await?;
706 let s = state.clone();
707 match &tls {
708 Some(acceptor) => {
709 let acceptor = acceptor.clone();
710 tokio::spawn(async move {
711 let _ = serve_tls(s, acceptor, stream, addr).await;
712 });
713 }
714 None => {
715 tokio::spawn(async move {
716 serve_connection(s, stream, addr).await;
717 });
718 }
719 }
720 }
721}
722
723fn build_tls_acceptor(cfg: &ProxyConfig) -> std::io::Result<Option<TlsAcceptor>> {
724 match (&cfg.tls_cert, &cfg.tls_key) {
725 (Some(cert_path), Some(key_path)) => {
726 let cert_file = std::fs::File::open(cert_path)?;
727 let key_file = std::fs::File::open(key_path)?;
728 let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
729 rustls_pemfile::certs(&mut BufReader::new(cert_file))
730 .collect::<Result<Vec<_>, _>>()?;
731 let key =
732 rustls_pemfile::private_key(&mut BufReader::new(key_file))?.ok_or_else(|| {
733 std::io::Error::new(
734 std::io::ErrorKind::InvalidData,
735 "no private key in pem file",
736 )
737 })?;
738 let cfg = rustls::ServerConfig::builder()
739 .with_no_client_auth()
740 .with_single_cert(certs, key)
741 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
742 Ok(Some(TlsAcceptor::from(Arc::new(cfg))))
743 }
744 (None, None) => Ok(None),
745 _ => Err(std::io::Error::new(
746 std::io::ErrorKind::InvalidInput,
747 "--tls-cert and --tls-key must be provided together",
748 )),
749 }
750}
751
752async fn serve_tls(
753 state: Arc<ProxyState>,
754 acceptor: TlsAcceptor,
755 stream: TcpStream,
756 addr: SocketAddr,
757) -> std::io::Result<()> {
758 let tls_stream = acceptor.accept(stream).await?;
759 let io = TokioIo::new(tls_stream);
760 let svc = service_fn(move |req: Request<Incoming>| {
761 let state = state.clone();
762 async move { handle_request(state, req, addr).await }
763 });
764 if let Err(e) = hyper::server::conn::http1::Builder::new()
765 .serve_connection(io, svc)
766 .await
767 {
768 debug!(error = %e, "tls hyper error");
769 }
770 Ok(())
771}