Skip to main content

shift_proxy/routes/
passthrough.rs

1//! Catch-all passthrough handler.
2//!
3//! Forwards requests to the upstream provider detected from the request
4//! path. Used for routes not explicitly matched by the provider-specific
5//! handlers (e.g., OpenAI batch endpoints, Anthropic beta paths, GET
6//! /v1/models, etc.).
7
8use crate::forward::forward_request;
9use crate::ProxyState;
10use axum::extract::State;
11use axum::http::{HeaderMap, Method, StatusCode, Uri};
12use axum::response::{IntoResponse, Response};
13
14/// Catch-all handler — detect provider from path and forward unchanged.
15/// Handles all HTTP methods (GET, POST, PUT, PATCH, DELETE).
16pub async fn passthrough_handler(
17    State(state): State<ProxyState>,
18    method: Method,
19    uri: Uri,
20    headers: HeaderMap,
21    body: String,
22) -> Response {
23    let path = uri.path();
24    let provider = detect_provider_from_route(path);
25
26    let base_url = match provider {
27        Some("anthropic") => &state.config.providers.anthropic,
28        Some("openai") => &state.config.providers.openai,
29        Some("google") => &state.config.providers.google,
30        _ => {
31            return (
32                StatusCode::NOT_FOUND,
33                axum::Json(serde_json::json!({
34                    "error": "Unknown route — cannot determine upstream provider"
35                })),
36            )
37                .into_response();
38        }
39    };
40
41    let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
42    let target_url = format!("{}{}{}", base_url, path, query);
43
44    if state.config.verbose {
45        tracing::info!("Passthrough: {} {} → {}{}", method, path, base_url, path);
46    }
47
48    // For methods without a body (GET, HEAD), don't forward one.
49    let has_body = !matches!(method, Method::GET | Method::HEAD);
50    let body = if has_body { Some(body) } else { None };
51
52    forward_request(
53        &state.http_client,
54        method.as_str(),
55        &target_url,
56        &headers,
57        body,
58    )
59    .await
60}
61
/// Map a request path to the name of the upstream provider that serves it.
///
/// The path is compared against an ordered list of prefixes and the first
/// match wins, so the specific rules must stay ahead of the generic `/v1/`
/// fallback (which routes to OpenAI). Returns `None` when no prefix matches.
fn detect_provider_from_route(path: &str) -> Option<&'static str> {
    // (path prefix, provider) pairs, checked top to bottom.
    const ROUTE_PREFIXES: [(&str, &str); 6] = [
        ("/v1/messages", "anthropic"),
        ("/v1/chat/", "openai"),
        ("/v1/embeddings", "openai"),
        ("/v1beta/", "google"),
        ("/v1/models/gemini", "google"),
        // Default to OpenAI for any remaining /v1/* path (most common).
        ("/v1/", "openai"),
    ];

    ROUTE_PREFIXES
        .iter()
        .find(|(prefix, _)| path.starts_with(*prefix))
        .map(|&(_, provider)| provider)
}
77
#[cfg(test)]
mod tests {
    //! Unit tests for `detect_provider_from_route`.
    //!
    //! Besides the per-provider happy paths, these pin the rule ordering:
    //! the Gemini-specific `/v1/models/gemini` prefix must win over the
    //! generic `/v1/` fallback, and paths outside `/v1/` / `/v1beta/` must
    //! not resolve to any provider.

    use super::*;

    #[test]
    fn detect_anthropic() {
        assert_eq!(
            detect_provider_from_route("/v1/messages"),
            Some("anthropic")
        );
        assert_eq!(
            detect_provider_from_route("/v1/messages/batches"),
            Some("anthropic")
        );
        assert_eq!(
            detect_provider_from_route("/v1/messages/count_tokens"),
            Some("anthropic")
        );
    }

    #[test]
    fn detect_openai() {
        assert_eq!(
            detect_provider_from_route("/v1/chat/completions"),
            Some("openai")
        );
        assert_eq!(detect_provider_from_route("/v1/embeddings"), Some("openai"));
    }

    #[test]
    fn detect_openai_fallback_for_other_v1_paths() {
        // Anything under /v1/ without a more specific rule defaults to OpenAI.
        assert_eq!(detect_provider_from_route("/v1/models"), Some("openai"));
        assert_eq!(detect_provider_from_route("/v1/files"), Some("openai"));
        assert_eq!(detect_provider_from_route("/v1/batches"), Some("openai"));
    }

    #[test]
    fn detect_google() {
        assert_eq!(
            detect_provider_from_route("/v1beta/models/gemini-2.5-pro:generateContent"),
            Some("google")
        );
        assert_eq!(detect_provider_from_route("/v1beta/models"), Some("google"));
    }

    #[test]
    fn detect_google_takes_precedence_over_v1_fallback() {
        // Would be claimed by the /v1/ fallback if the rule order changed.
        assert_eq!(
            detect_provider_from_route("/v1/models/gemini-pro"),
            Some("google")
        );
    }

    #[test]
    fn detect_unknown() {
        assert_eq!(detect_provider_from_route("/unknown"), None);
        assert_eq!(detect_provider_from_route("/"), None);
        assert_eq!(detect_provider_from_route(""), None);
        // No trailing slash, so neither the /v1/ nor the /v1beta/ prefix applies.
        assert_eq!(detect_provider_from_route("/v1"), None);
        assert_eq!(detect_provider_from_route("/v1beta"), None);
    }
}