Skip to main content

shunt/
forwarder.rs

1use anyhow::{Context, Result};
2use axum::body::Body;
3use axum::http::{HeaderMap, HeaderName, HeaderValue, Response};
4use bytes::Bytes;
5use reqwest::Client;
6use std::str::FromStr;
7use std::time::Instant;
8use tracing::info;
9use uuid::Uuid;
10
11use crate::config::AccountConfig;
12
/// Headers that must never be forwarded in either direction.
///
/// Holds the connection-scoped (hop-by-hop) header fields plus `host` and
/// `content-length`, which are not hop-by-hop in the RFC sense but are listed
/// here so that they are dropped as well. All entries are lowercase ASCII;
/// lookups go through [`is_hop_by_hop`], which compares case-insensitively.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    // Non-standard but widely emitted; RFC 9110 §7.6.1 asks intermediaries to drop it.
    "proxy-connection",
    "te",
    // The registered field name is `Trailer` (singular).
    "trailer",
    // Historical spelling from RFC 2616 §13.5.1; kept so existing behavior is preserved.
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
    "content-length",
];
26
/// Headers the proxy explicitly passes through to upstream.
/// All other client-supplied headers are dropped (allowlist approach, #15).
///
/// `Forwarder::forward` looks each of these names up in the client's header map
/// and copies only those that are present. Because `authorization` and
/// `x-api-key` are not listed, client-supplied credentials are never copied
/// to the upstream request.
const ALLOWED_REQUEST_HEADERS: &[&str] = &[
    "content-type",
    "accept",
    "anthropic-version",
    "anthropic-beta",
    "anthropic-dangerous-direct-browser-access",
    "x-request-id",
    "user-agent",
    // chatgpt.com sentinel token — injected by proxy, pass through
    "openai-sentinel-chat-requirements-token",
];
40
/// Sensitive response headers that upstream must never inject into client responses (#21).
///
/// `Forwarder::forward` skips any upstream response header whose lowercase name
/// appears here, so cookies and CORS grants set by upstream are not relayed to
/// the client. Entries must be lowercase: the check is an exact string match
/// against the lowercased header name.
const BLOCKED_RESPONSE_HEADERS: &[&str] = &[
    "set-cookie",
    "set-cookie2",
    "access-control-allow-origin",
    "access-control-allow-credentials",
    "access-control-allow-methods",
    "access-control-allow-headers",
];
50
51fn is_hop_by_hop(name: &str) -> bool {
52    HOP_BY_HOP.contains(&name.to_ascii_lowercase().as_str())
53}
54
55pub struct Forwarder {
56    client: Client,
57}
58
59impl Forwarder {
60    pub fn new(_base_url: impl Into<String>, timeout_secs: u64) -> Result<Self> {
61        let client = Client::builder()
62            .timeout(std::time::Duration::from_secs(timeout_secs))
63            .redirect(reqwest::redirect::Policy::none())
64            .build()
65            .context("Failed to build HTTP client")?;
66
67        Ok(Self { client })
68    }
69
70    /// Forward a request to the upstream using the given account's OAuth credential.
71    ///
72    /// - `upstream` overrides the base URL for this account (per-provider routing).
73    /// - Strips `Authorization` and `x-api-key` from the client request.
74    /// - Injects `Authorization: Bearer <token>` (live token, may differ from account.credential).
75    /// - Keeps the upstream TCP connection alive for streaming responses.
76    pub async fn forward(
77        &self,
78        upstream: &str,
79        method: &str,
80        path: &str,
81        body: Bytes,
82        client_headers: &HeaderMap,
83        account: &AccountConfig,
84        token: &str,
85    ) -> Result<Response<Body>> {
86        let request_id = &Uuid::new_v4().to_string()[..8];
87        let url = format!("{}{}", upstream, path);
88
89        let mut upstream_headers = reqwest::header::HeaderMap::new();
90
91        // #15: allowlist — only forward explicitly permitted client headers.
92        for &name in ALLOWED_REQUEST_HEADERS {
93            if let Some(value) = client_headers.get(name) {
94                if let Ok(n) = reqwest::header::HeaderName::from_str(name) {
95                    if let Ok(v) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) {
96                        upstream_headers.insert(n, v);
97                    }
98                }
99            }
100        }
101
102        // Inject provider-specific auth headers (Bearer token + any required protocol headers).
103        account.provider.inject_auth_headers(&mut upstream_headers, token)
104            .context("failed to inject auth headers")?;
105
106        let t0 = Instant::now();
107        let upstream_resp = self
108            .client
109            .request(
110                reqwest::Method::from_str(method).context("invalid method")?,
111                &url,
112            )
113            .headers(upstream_headers)
114            .body(body.clone())
115            .send()
116            .await
117            .context("upstream request failed")?;
118
119        let status = upstream_resp.status();
120
121        let mut builder = Response::builder().status(status.as_u16());
122
123        for (name, value) in upstream_resp.headers().iter() {
124            let lower = name.as_str().to_ascii_lowercase();
125            // #21: drop hop-by-hop and sensitive response headers.
126            if is_hop_by_hop(&lower) || BLOCKED_RESPONSE_HEADERS.contains(&lower.as_str()) {
127                continue;
128            }
129            if let (Ok(n), Ok(v)) = (
130                HeaderName::from_str(name.as_str()),
131                HeaderValue::from_bytes(value.as_bytes()),
132            ) {
133                builder = builder.header(n, v);
134            }
135        }
136
137        let body = Body::from_stream(upstream_resp.bytes_stream());
138        Ok(builder.body(body).expect("response builder invariant"))
139    }
140}