1use axum::body::Body;
7use axum::http::{HeaderMap, HeaderValue, StatusCode};
8use axum::response::{IntoResponse, Response};
9use reqwest::Client;
10
/// Upstream response headers that are never copied onto the response sent back to the client.
///
/// Names are lower-case, matching how `HeaderName` stores them.
const STRIP_RESPONSE_HEADERS: &[&str] = &[
    // Hop-by-hop headers (the RFC 2616 §13.5.1 list): they describe the proxy <-> upstream
    // connection, not the message, so the proxy's own HTTP server must set its own.
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    // Framing / representation headers: the body is re-streamed chunk by chunk, so the
    // upstream's length and encoding metadata may no longer describe the bytes we send.
    // NOTE(review): dropping `content-encoding` is only correct if reqwest is decompressing
    // bodies (its `gzip`/`brotli`/`deflate` features); otherwise clients receive compressed
    // bytes with no encoding header — confirm the enabled crate features.
    "content-encoding",
    "content-length",
];
28
/// Client request headers that are not copied onto the upstream request.
///
/// * `host` and `content-length` describe the client's request to this proxy; the upstream
///   request gets its own values (presumably filled in by reqwest/hyper from the target URL and
///   the `String` body).
/// * The remaining entries are hop-by-hop / connection-specific headers (RFC 9110 §7.6.1,
///   RFC 2616 §13.5.1). They apply only to the client <-> proxy connection and must not be
///   forwarded. `trailer` is also meaningless here because the body is re-sent as a buffered
///   `String`, so no trailer fields can follow it.
///
/// Names are lower-case, matching how `HeaderName` stores them.
const STRIP_REQUEST_HEADERS: &[&str] = &[
    "host",
    "content-length",
    "connection",
    "keep-alive",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
31
32pub async fn forward_request(
38 client: &Client,
39 method: &str,
40 target_url: &str,
41 request_headers: &HeaderMap,
42 body: Option<String>,
43) -> Response {
44 let forwarded_headers = forward_headers(request_headers);
45
46 let mut req = match method.to_uppercase().as_str() {
47 "POST" => client.post(target_url),
48 "GET" => client.get(target_url),
49 "PUT" => client.put(target_url),
50 "DELETE" => client.delete(target_url),
51 "PATCH" => client.patch(target_url),
52 _ => client.post(target_url),
53 };
54
55 req = req.headers(forwarded_headers);
56
57 if let Some(body) = body {
58 req = req.body(body);
59 }
60
61 match req.send().await {
62 Ok(upstream) => stream_response(upstream),
63 Err(err) => {
64 tracing::error!("upstream error: {}", err);
65 (
66 StatusCode::BAD_GATEWAY,
67 axum::Json(serde_json::json!({
68 "error": "Bad Gateway",
69 "detail": "Upstream provider unreachable"
70 })),
71 )
72 .into_response()
73 }
74 }
75}
76
77fn stream_response(upstream: reqwest::Response) -> Response {
80 let status = StatusCode::from_u16(upstream.status().as_u16()).unwrap_or(StatusCode::OK);
81
82 let mut response_headers = HeaderMap::new();
83 for (name, value) in upstream.headers() {
84 let name_str = name.as_str().to_lowercase();
85 if STRIP_RESPONSE_HEADERS
86 .iter()
87 .any(|h| h == &name_str.as_str())
88 {
89 continue;
90 }
91 if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
92 response_headers.insert(name.clone(), v);
93 }
94 }
95
96 let body = Body::from_stream(upstream.bytes_stream());
99
100 let mut response = Response::new(body);
101 *response.status_mut() = status;
102 *response.headers_mut() = response_headers;
103 response
104}
105
106fn forward_headers(original: &HeaderMap) -> HeaderMap {
108 let strip: std::collections::HashSet<&str> = STRIP_REQUEST_HEADERS.iter().copied().collect();
109
110 let mut result = HeaderMap::new();
111 for (name, value) in original {
112 let name_lower = name.as_str().to_lowercase();
113 if !strip.contains(name_lower.as_str()) {
114 result.insert(name.clone(), value.clone());
115 }
116 }
117 result
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{header, HeaderName};

    #[test]
    fn forward_headers_strips_host_and_content_length() {
        let mut headers = HeaderMap::new();
        headers.insert(header::HOST, "example.com".parse().unwrap());
        headers.insert(header::CONTENT_LENGTH, "42".parse().unwrap());
        headers.insert(header::AUTHORIZATION, "Bearer sk-test".parse().unwrap());
        headers.insert("x-api-key", "sk-ant-test".parse().unwrap());
        headers.insert("anthropic-version", "2023-06-01".parse().unwrap());

        let result = forward_headers(&headers);

        assert!(result.get(header::HOST).is_none());
        assert!(result.get(header::CONTENT_LENGTH).is_none());
        assert_eq!(result.get(header::AUTHORIZATION).unwrap(), "Bearer sk-test");
        assert_eq!(result.get("x-api-key").unwrap(), "sk-ant-test");
        assert_eq!(result.get("anthropic-version").unwrap(), "2023-06-01");
    }

    #[test]
    fn forward_headers_of_empty_map_is_empty() {
        assert!(forward_headers(&HeaderMap::new()).is_empty());
    }

    #[test]
    fn forward_headers_matches_names_regardless_of_input_case() {
        // `HeaderName::from_bytes` normalises to lower case, so mixed-case input names
        // must still be matched against the lower-case strip list.
        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_bytes(b"Host").unwrap(),
            "example.com".parse().unwrap(),
        );
        headers.insert(
            HeaderName::from_bytes(b"X-Request-Id").unwrap(),
            "abc".parse().unwrap(),
        );

        let result = forward_headers(&headers);

        assert!(result.get(header::HOST).is_none());
        assert_eq!(result.get("x-request-id").unwrap(), "abc");
    }
}