// shunt/forwarder.rs

1use anyhow::{Context, Result};
2use axum::body::Body;
3use axum::http::{HeaderMap, HeaderName, HeaderValue, Response};
4use bytes::Bytes;
5use reqwest::Client;
6use std::str::FromStr;
7use std::time::Instant;
8use tracing::info;
9use uuid::Uuid;
10
11use crate::config::AccountConfig;
12
/// Headers that must never be forwarded in either direction.
///
/// These are the hop-by-hop headers (RFC 9110 §7.6.1). Two entries need explaining:
/// - `trailer` is the real header name. `trailers` is also kept because
///   RFC 2616 misspelled it (erratum 4522) and some peers still send that form.
/// - `proxy-connection` is non-standard but widely sent as a `connection` alias.
///
/// `host` and `content-length` are dropped as well, so that the outgoing side
/// can derive them for the new hop.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
    "content-length",
];

/// Auth headers that the proxy manages. They are always stripped from client
/// requests and replaced with the selected account's credential.
const CLIENT_AUTH_HEADERS: &[&str] = &["authorization", "x-api-key"];

/// Returns `true` if `name` is in [`HOP_BY_HOP`] (ASCII case-insensitive).
fn is_hop_by_hop(name: &str) -> bool {
    // Case-insensitive compare avoids allocating a lowercased copy per header.
    HOP_BY_HOP.iter().any(|h| h.eq_ignore_ascii_case(name))
}

/// Returns `true` if `name` is in [`CLIENT_AUTH_HEADERS`] (ASCII case-insensitive).
fn is_client_auth(name: &str) -> bool {
    CLIENT_AUTH_HEADERS.iter().any(|h| h.eq_ignore_ascii_case(name))
}
38
39pub struct Forwarder {
40    client: Client,
41}
42
43impl Forwarder {
44    pub fn new(_base_url: impl Into<String>, timeout_secs: u64) -> Result<Self> {
45        let client = Client::builder()
46            .timeout(std::time::Duration::from_secs(timeout_secs))
47            .redirect(reqwest::redirect::Policy::none())
48            .build()
49            .context("Failed to build HTTP client")?;
50
51        Ok(Self { client })
52    }
53
54    /// Forward a request to the upstream using the given account's OAuth credential.
55    ///
56    /// - `upstream` overrides the base URL for this account (per-provider routing).
57    /// - Strips `Authorization` and `x-api-key` from the client request.
58    /// - Injects `Authorization: Bearer <token>` (live token, may differ from account.credential).
59    /// - Keeps the upstream TCP connection alive for streaming responses.
60    pub async fn forward(
61        &self,
62        upstream: &str,
63        method: &str,
64        path: &str,
65        body: Bytes,
66        client_headers: &HeaderMap,
67        account: &AccountConfig,
68        token: &str,
69    ) -> Result<Response<Body>> {
70        let request_id = &Uuid::new_v4().to_string()[..8];
71        let url = format!("{}{}", upstream, path);
72
73        let mut upstream_headers = reqwest::header::HeaderMap::new();
74
75        for (name, value) in client_headers.iter() {
76            let lower = name.as_str().to_ascii_lowercase();
77            if is_hop_by_hop(&lower) || is_client_auth(&lower) {
78                continue;
79            }
80            if let (Ok(n), Ok(v)) = (
81                reqwest::header::HeaderName::from_str(name.as_str()),
82                reqwest::header::HeaderValue::from_bytes(value.as_bytes()),
83            ) {
84                upstream_headers.insert(n, v);
85            }
86        }
87
88        // Inject provider-specific auth headers (Bearer token + any required protocol headers).
89        account.provider.inject_auth_headers(&mut upstream_headers, token)
90            .context("failed to inject auth headers")?;
91
92        let t0 = Instant::now();
93        let upstream_resp = self
94            .client
95            .request(
96                reqwest::Method::from_str(method).context("invalid method")?,
97                &url,
98            )
99            .headers(upstream_headers)
100            .body(body)
101            .send()
102            .await
103            .context("upstream request failed")?;
104
105        let latency_ms = t0.elapsed().as_millis();
106        let status = upstream_resp.status();
107
108        info!(
109            request_id = %request_id,
110            account = %account.name,
111            status = status.as_u16(),
112            latency_ms = %latency_ms,
113            path = %path,
114            "request forwarded"
115        );
116
117        let mut builder = Response::builder().status(status.as_u16());
118
119        for (name, value) in upstream_resp.headers().iter() {
120            if !is_hop_by_hop(name.as_str()) {
121                if let (Ok(n), Ok(v)) = (
122                    HeaderName::from_str(name.as_str()),
123                    HeaderValue::from_bytes(value.as_bytes()),
124                ) {
125                    builder = builder.header(n, v);
126                }
127            }
128        }
129
130        let body = Body::from_stream(upstream_resp.bytes_stream());
131        Ok(builder.body(body).expect("response builder invariant"))
132    }
133}