1use anyhow::{Context, Result};
2use axum::body::Body;
3use axum::http::{HeaderMap, HeaderName, HeaderValue, Response};
4use bytes::Bytes;
5use reqwest::Client;
6use std::str::FromStr;
7use std::time::Instant;
8use tracing::info;
9use uuid::Uuid;
10
11use crate::config::AccountConfig;
12
/// Lower-case header names that are never copied between the client-facing
/// and upstream connections (in either direction).
///
/// Covers the connection-specific fields of RFC 9110 §7.6.1 plus the legacy
/// `proxy-connection`. `host` and `content-length` are dropped as well so the
/// outgoing message's framing is produced for that message rather than copied
/// (NOTE(review): relies on the HTTP client/server filling them in — confirm).
///
/// `trailer` is the actual header name; `trailers` is the misspelling from
/// RFC 2616's list and is kept so nothing previously filtered leaks through.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
    "content-length",
];
26
/// Lower-case names of headers carrying the *client's* credentials.
///
/// These are never forwarded upstream; the account's own credentials are
/// injected instead.
const CLIENT_AUTH_HEADERS: &[&str] = &[
    "authorization",
    "x-api-key",
];
30
31fn is_hop_by_hop(name: &str) -> bool {
32 HOP_BY_HOP.contains(&name.to_ascii_lowercase().as_str())
33}
34
35fn is_client_auth(name: &str) -> bool {
36 CLIENT_AUTH_HEADERS.contains(&name.to_ascii_lowercase().as_str())
37}
38
39pub struct Forwarder {
40 client: Client,
41}
42
43impl Forwarder {
44 pub fn new(_base_url: impl Into<String>, timeout_secs: u64) -> Result<Self> {
45 let client = Client::builder()
46 .timeout(std::time::Duration::from_secs(timeout_secs))
47 .redirect(reqwest::redirect::Policy::none())
48 .build()
49 .context("Failed to build HTTP client")?;
50
51 Ok(Self { client })
52 }
53
54 pub async fn forward(
61 &self,
62 upstream: &str,
63 method: &str,
64 path: &str,
65 body: Bytes,
66 client_headers: &HeaderMap,
67 account: &AccountConfig,
68 token: &str,
69 ) -> Result<Response<Body>> {
70 let request_id = &Uuid::new_v4().to_string()[..8];
71 let url = format!("{}{}", upstream, path);
72
73 let mut upstream_headers = reqwest::header::HeaderMap::new();
74
75 for (name, value) in client_headers.iter() {
76 let lower = name.as_str().to_ascii_lowercase();
77 if is_hop_by_hop(&lower) || is_client_auth(&lower) {
78 continue;
79 }
80 if let (Ok(n), Ok(v)) = (
81 reqwest::header::HeaderName::from_str(name.as_str()),
82 reqwest::header::HeaderValue::from_bytes(value.as_bytes()),
83 ) {
84 upstream_headers.insert(n, v);
85 }
86 }
87
88 account.provider.inject_auth_headers(&mut upstream_headers, token)
90 .context("failed to inject auth headers")?;
91
92 let t0 = Instant::now();
93 let upstream_resp = self
94 .client
95 .request(
96 reqwest::Method::from_str(method).context("invalid method")?,
97 &url,
98 )
99 .headers(upstream_headers)
100 .body(body)
101 .send()
102 .await
103 .context("upstream request failed")?;
104
105 let latency_ms = t0.elapsed().as_millis();
106 let status = upstream_resp.status();
107
108 info!(
109 request_id = %request_id,
110 account = %account.name,
111 status = status.as_u16(),
112 latency_ms = %latency_ms,
113 path = %path,
114 "request forwarded"
115 );
116
117 let mut builder = Response::builder().status(status.as_u16());
118
119 for (name, value) in upstream_resp.headers().iter() {
120 if !is_hop_by_hop(name.as_str()) {
121 if let (Ok(n), Ok(v)) = (
122 HeaderName::from_str(name.as_str()),
123 HeaderValue::from_bytes(value.as_bytes()),
124 ) {
125 builder = builder.header(n, v);
126 }
127 }
128 }
129
130 let body = Body::from_stream(upstream_resp.bytes_stream());
131 Ok(builder.body(body).expect("response builder invariant"))
132 }
133}