// shift_proxy/forward.rs

//! Forward requests to upstream provider APIs and stream responses back.
//!
//! Handles header forwarding (auth passthrough), hop-by-hop header stripping
//! (RFC 9110 §7.6.1), and transparent SSE/chunked response streaming.

6use axum::body::Body;
7use axum::http::{HeaderMap, HeaderValue, StatusCode};
8use axum::response::{IntoResponse, Response};
9use reqwest::Client;
10
/// Headers stripped from upstream responses before forwarding to the client.
///
/// - `content-encoding` / `content-length`: when reqwest transparently
///   decompresses a response body, these values describe the compressed bytes
///   rather than the stream we forward. Passing them on would make the client
///   decompress a second time or expect the wrong length.
/// - Connection-specific (hop-by-hop) headers named by RFC 9110 §7.6.1:
///   `connection`, `proxy-connection`, `keep-alive`, `te`, `transfer-encoding`
///   and `upgrade`.
/// - The remaining entries of the legacy RFC 2616 §13.5.1 hop-by-hop set:
///   `proxy-authenticate`, `proxy-authorization` and `trailer`.
///
/// Entries must be lowercase because they are compared against lowercase
/// header names.
const STRIP_RESPONSE_HEADERS: &[&str] = &[
    "content-encoding",
    "content-length",
    "transfer-encoding",
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "upgrade",
];

/// Headers stripped from the forwarded request.
///
/// - `host` / `content-length`: the HTTP client recomputes these from the
///   target URL and the body we attach, so the client's values would be stale.
/// - `accept-encoding`: left for reqwest to negotiate. The response path strips
///   `content-encoding` on the assumption that reqwest decoded the body. If the
///   client's encoding list reached the upstream, it could choose an encoding
///   reqwest does not decode, and the client would receive compressed bytes
///   with no label.
/// - Connection-specific (hop-by-hop) headers per RFC 9110 §7.6.1. They describe
///   the client→proxy connection and must not be forwarded.
///
/// Auth headers (`authorization`, `x-api-key`, `x-goog-api-key`, ...) are
/// deliberately absent so they pass through. Entries must be lowercase.
const STRIP_REQUEST_HEADERS: &[&str] = &[
    "host",
    "content-length",
    "accept-encoding",
    "connection",
    "proxy-connection",
    "keep-alive",
    "te",
    "transfer-encoding",
    "upgrade",
];

32/// Forward a request to an upstream URL, streaming the response back.
33///
34/// Auth headers (`authorization`, `x-api-key`, `anthropic-version`, `x-goog-api-key`)
35/// pass through unchanged. The response body is streamed directly — SSE and
36/// chunked responses are not buffered.
37pub async fn forward_request(
38    client: &Client,
39    method: &str,
40    target_url: &str,
41    request_headers: &HeaderMap,
42    body: Option<String>,
43) -> Response {
44    let forwarded_headers = forward_headers(request_headers);
45
46    let mut req = match method.to_uppercase().as_str() {
47        "POST" => client.post(target_url),
48        "GET" => client.get(target_url),
49        "PUT" => client.put(target_url),
50        "DELETE" => client.delete(target_url),
51        "PATCH" => client.patch(target_url),
52        _ => client.post(target_url),
53    };
54
55    req = req.headers(forwarded_headers);
56
57    if let Some(body) = body {
58        req = req.body(body);
59    }
60
61    match req.send().await {
62        Ok(upstream) => stream_response(upstream),
63        Err(err) => {
64            tracing::error!("upstream error: {}", err);
65            (
66                StatusCode::BAD_GATEWAY,
67                axum::Json(serde_json::json!({
68                    "error": "Bad Gateway",
69                    "detail": "Upstream provider unreachable"
70                })),
71            )
72                .into_response()
73        }
74    }
75}
76
77/// Convert a reqwest Response into an axum Response, streaming the body
78/// and stripping hop-by-hop headers.
79fn stream_response(upstream: reqwest::Response) -> Response {
80    let status = StatusCode::from_u16(upstream.status().as_u16()).unwrap_or(StatusCode::OK);
81
82    let mut response_headers = HeaderMap::new();
83    for (name, value) in upstream.headers() {
84        let name_str = name.as_str().to_lowercase();
85        if STRIP_RESPONSE_HEADERS
86            .iter()
87            .any(|h| h == &name_str.as_str())
88        {
89            continue;
90        }
91        if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
92            response_headers.insert(name.clone(), v);
93        }
94    }
95
96    // Stream the response body directly without buffering.
97    // This is critical for SSE (Anthropic/OpenAI streaming) to work correctly.
98    let body = Body::from_stream(upstream.bytes_stream());
99
100    let mut response = Response::new(body);
101    *response.status_mut() = status;
102    *response.headers_mut() = response_headers;
103    response
104}
105
106/// Forward request headers, stripping host/content-length but passing auth through.
107fn forward_headers(original: &HeaderMap) -> HeaderMap {
108    let strip: std::collections::HashSet<&str> = STRIP_REQUEST_HEADERS.iter().copied().collect();
109
110    let mut result = HeaderMap::new();
111    for (name, value) in original {
112        let name_lower = name.as_str().to_lowercase();
113        if !strip.contains(name_lower.as_str()) {
114            result.insert(name.clone(), value.clone());
115        }
116    }
117    result
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::header;

    /// `forward_headers` must drop `host`/`content-length` and leave the
    /// provider auth headers byte-for-byte intact.
    #[test]
    fn forward_headers_strips_host_and_content_length() {
        let incoming = [
            (header::HOST, "example.com"),
            (header::CONTENT_LENGTH, "42"),
            (header::AUTHORIZATION, "Bearer sk-test"),
            (header::HeaderName::from_static("x-api-key"), "sk-ant-test"),
            (header::HeaderName::from_static("anthropic-version"), "2023-06-01"),
        ];
        let mut headers = HeaderMap::new();
        for (name, value) in incoming {
            headers.insert(name, HeaderValue::from_static(value));
        }

        let forwarded = forward_headers(&headers);

        for dropped in [header::HOST, header::CONTENT_LENGTH] {
            assert!(forwarded.get(&dropped).is_none());
        }
        for (name, expected) in [
            ("authorization", "Bearer sk-test"),
            ("x-api-key", "sk-ant-test"),
            ("anthropic-version", "2023-06-01"),
        ] {
            assert_eq!(forwarded.get(name).unwrap(), expected);
        }
    }
}
142}