1use anyhow::{Context, Result};
2use axum::body::Body;
3use axum::http::{HeaderMap, HeaderName, HeaderValue, Response};
4use bytes::Bytes;
5use reqwest::Client;
6use std::str::FromStr;
7use std::time::Instant;
8use tracing::info;
9use uuid::Uuid;
10
11use crate::config::AccountConfig;
12
/// Upstream response header names that are never relayed to the client.
///
/// Contents:
/// * The hop-by-hop fields of RFC 7230 §6.1.
/// * `trailer`, the real field name behind RFC 2616's erratum spelling `trailers`. Both are
///   kept because the re-streamed body carries no trailers.
/// * The non-standard but common `proxy-connection`.
/// * `host` and `content-length`, which are end-to-end but are also stripped here.
///   Presumably this lets the server frame the re-streamed body itself; confirm.
///
/// All entries are lowercase ASCII.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
    "content-length",
];
26
/// Allowlist of client request headers copied onto the upstream request.
///
/// Any client header not named here is dropped, e.g. `cookie` or a client-supplied
/// `authorization`. Credentials for the upstream are attached separately from the
/// account configuration. Entries are lowercase.
const ALLOWED_REQUEST_HEADERS: &[&str] = &[
    // Content negotiation.
    "content-type",
    "accept",
    // Anthropic API protocol headers.
    "anthropic-version",
    "anthropic-beta",
    "anthropic-dangerous-direct-browser-access",
    // Correlation / client identification.
    "x-request-id",
    "user-agent",
    // OpenAI-specific header.
    "openai-sentinel-chat-requirements-token",
];
40
/// End-to-end upstream response headers that are still withheld from the client.
///
/// Upstream cookies and the upstream's CORS policy are not passed through.
/// NOTE(review): presumably this proxy applies its own CORS policy elsewhere. Confirm.
const BLOCKED_RESPONSE_HEADERS: &[&str] = &[
    // Cookies issued by the upstream origin.
    "set-cookie",
    "set-cookie2",
    // Upstream CORS policy.
    "access-control-allow-origin",
    "access-control-allow-credentials",
    "access-control-allow-methods",
    "access-control-allow-headers",
];
50
51fn is_hop_by_hop(name: &str) -> bool {
52 HOP_BY_HOP.contains(&name.to_ascii_lowercase().as_str())
53}
54
55pub struct Forwarder {
56 client: Client,
57}
58
59impl Forwarder {
60 pub fn new(_base_url: impl Into<String>, timeout_secs: u64) -> Result<Self> {
61 let client = Client::builder()
62 .timeout(std::time::Duration::from_secs(timeout_secs))
63 .redirect(reqwest::redirect::Policy::none())
64 .build()
65 .context("Failed to build HTTP client")?;
66
67 Ok(Self { client })
68 }
69
70 pub async fn forward(
77 &self,
78 upstream: &str,
79 method: &str,
80 path: &str,
81 body: Bytes,
82 client_headers: &HeaderMap,
83 account: &AccountConfig,
84 token: &str,
85 ) -> Result<Response<Body>> {
86 let request_id = &Uuid::new_v4().to_string()[..8];
87 let url = format!("{}{}", upstream, path);
88
89 let mut upstream_headers = reqwest::header::HeaderMap::new();
90
91 for &name in ALLOWED_REQUEST_HEADERS {
93 if let Some(value) = client_headers.get(name) {
94 if let Ok(n) = reqwest::header::HeaderName::from_str(name) {
95 if let Ok(v) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) {
96 upstream_headers.insert(n, v);
97 }
98 }
99 }
100 }
101
102 account.provider.inject_auth_headers(&mut upstream_headers, token)
104 .context("failed to inject auth headers")?;
105
106 let t0 = Instant::now();
107 let upstream_resp = self
108 .client
109 .request(
110 reqwest::Method::from_str(method).context("invalid method")?,
111 &url,
112 )
113 .headers(upstream_headers)
114 .body(body.clone())
115 .send()
116 .await
117 .context("upstream request failed")?;
118
119 let status = upstream_resp.status();
120
121 let mut builder = Response::builder().status(status.as_u16());
122
123 for (name, value) in upstream_resp.headers().iter() {
124 let lower = name.as_str().to_ascii_lowercase();
125 if is_hop_by_hop(&lower) || BLOCKED_RESPONSE_HEADERS.contains(&lower.as_str()) {
127 continue;
128 }
129 if let (Ok(n), Ok(v)) = (
130 HeaderName::from_str(name.as_str()),
131 HeaderValue::from_bytes(value.as_bytes()),
132 ) {
133 builder = builder.header(n, v);
134 }
135 }
136
137 let body = Body::from_stream(upstream_resp.bytes_stream());
138 Ok(builder.body(body).expect("response builder invariant"))
139 }
140}