systemprompt_api/services/gateway/upstream.rs

mod sse;

use std::sync::Arc;

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use axum::body::Body;
use axum::response::Response;
use bytes::Bytes;
use futures_util::StreamExt;
use http::header::CONTENT_TYPE;
use systemprompt_models::profile::GatewayRoute;

use super::converter;
use super::models::{AnthropicGatewayRequest, AnthropicGatewayResponse};

/// Borrowed context for a single proxied gateway request.
#[derive(Debug)]
pub struct UpstreamCtx<'a> {
    pub route: &'a GatewayRoute,
    pub api_key: &'a str,
    pub raw_body: Bytes,
    pub request: &'a AnthropicGatewayRequest,
    pub is_streaming: bool,
}

/// Result of a proxied upstream call: either a fully buffered response or a
/// raw byte stream to forward to the client.
#[allow(missing_debug_implementations)]
pub enum UpstreamOutcome {
    Buffered {
        status: http::StatusCode,
        content_type: String,
        body: Bytes,
        served_model: Option<String>,
    },
    Streaming {
        status: http::StatusCode,
        stream: futures_util::stream::BoxStream<'static, Result<Bytes, std::io::Error>>,
    },
}

impl UpstreamOutcome {
    /// Upstream HTTP status, regardless of variant.
    pub const fn status(&self) -> http::StatusCode {
        match self {
            Self::Buffered { status, .. } | Self::Streaming { status, .. } => *status,
        }
    }
}

/// An upstream provider the gateway can proxy Anthropic-style requests to.
#[async_trait]
pub trait GatewayUpstream: Send + Sync {
    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome>;
}

/// Upstream that already speaks the Anthropic Messages API; the request body
/// is forwarded as-is, apart from an optional model rewrite.
#[derive(Debug, Clone, Copy)]
pub struct AnthropicCompatibleUpstream;

#[async_trait]
impl GatewayUpstream for AnthropicCompatibleUpstream {
    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
        let client = reqwest::Client::new();
        let url = format!("{}/messages", ctx.route.endpoint.trim_end_matches('/'));

        // Rewrite the model field only when the route maps it to a different
        // upstream model; otherwise pass the raw body through untouched.
        let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
        let body_to_send: Bytes = if upstream_model == ctx.request.model {
            ctx.raw_body
        } else {
            super::flatten::rewrite_request_model(ctx.raw_body, upstream_model)?
        };

        let mut req = client
            .post(&url)
            .header("x-api-key", ctx.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .body(body_to_send);

        for (name, value) in &ctx.route.extra_headers {
            req = req.header(name.as_str(), value.as_str());
        }

        let upstream_response = req
            .send()
            .await
            .map_err(|e| anyhow!("Upstream Anthropic request failed: {e}"))?;

        let status = upstream_response.status();
        let upstream_content_type = upstream_response
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("application/json")
            .to_string();

        if ctx.is_streaming {
            // SSE responses are already in Anthropic format, so the byte
            // stream is forwarded verbatim.
            let stream = upstream_response.bytes_stream().map(|chunk| {
                chunk.map_err(|e| {
                    tracing::warn!(error = %e, "Anthropic stream chunk error");
                    std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
                })
            });
            return Ok(UpstreamOutcome::Streaming {
                status,
                stream: stream.boxed(),
            });
        }

        let response_bytes = upstream_response
            .bytes()
            .await
            .map_err(|e| anyhow!("Failed to read Anthropic response: {e}"))?;

        let served_model = super::flatten::parse_served_model(&response_bytes);

        Ok(UpstreamOutcome::Buffered {
            status,
            content_type: upstream_content_type,
            body: response_bytes,
            served_model,
        })
    }
}

/// Upstream that speaks the OpenAI chat-completions API; requests and
/// responses are converted to and from the Anthropic wire format.
#[derive(Debug, Clone, Copy)]
pub struct OpenAiCompatibleUpstream;

#[async_trait]
impl GatewayUpstream for OpenAiCompatibleUpstream {
    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
        let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
        let openai_request = converter::to_openai_request(ctx.request, upstream_model);
        let model = ctx.request.model.as_str();

        let client = reqwest::Client::new();
        let url = format!(
            "{}/chat/completions",
            ctx.route.endpoint.trim_end_matches('/')
        );

        let mut req = client
            .post(&url)
            .header("authorization", format!("Bearer {}", ctx.api_key))
            .header("content-type", "application/json")
            .json(&openai_request);

        for (name, value) in &ctx.route.extra_headers {
            req = req.header(name.as_str(), value.as_str());
        }

        let upstream_response = req
            .send()
            .await
            .map_err(|e| anyhow!("Upstream OpenAI-compatible request failed: {e}"))?;

        let status = upstream_response.status();

        if ctx.is_streaming {
            if !status.is_success() {
                let err = upstream_response.text().await.unwrap_or_default();
                return Err(anyhow!("Upstream error {status}: {err}"));
            }

            // Translate OpenAI SSE chunks into Anthropic-style SSE events on
            // the fly, reporting the client-facing model name.
            let model_str = model.to_string();
            let stream = upstream_response
                .bytes_stream()
                .map(move |chunk| match chunk {
                    Ok(bytes) => Ok(sse::openai_sse_to_anthropic_sse(&bytes, &model_str)),
                    Err(e) => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)),
                });
            return Ok(UpstreamOutcome::Streaming {
                status: http::StatusCode::OK,
                stream: stream.boxed(),
            });
        }

        if !status.is_success() {
            let err = upstream_response.text().await.unwrap_or_default();
            return Err(anyhow!("Upstream error {status}: {err}"));
        }

        let openai_resp: super::models::OpenAiGatewayResponse = upstream_response
            .json()
            .await
            .map_err(|e| anyhow!("Failed to deserialize upstream response: {e}"))?;

        let served_model = Some(openai_resp.model.clone());

        let anthropic_resp: AnthropicGatewayResponse =
            converter::from_openai_response(openai_resp, model);

        let body_bytes =
            serde_json::to_vec(&anthropic_resp).map_err(|e| anyhow!("Serialization error: {e}"))?;

        Ok(UpstreamOutcome::Buffered {
            status: http::StatusCode::OK,
            content_type: "application/json".to_string(),
            body: Bytes::from(body_bytes),
            served_model,
        })
    }
}

/// Inventory-collected registration mapping a provider tag to a factory for
/// its upstream implementation.
#[derive(Debug, Clone, Copy)]
pub struct GatewayUpstreamRegistration {
    pub tag: &'static str,
    pub factory: fn() -> Arc<dyn GatewayUpstream>,
}

inventory::collect!(GatewayUpstreamRegistration);
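
// Hedged sketch (not part of the original file): one way a caller might
// resolve a registered upstream by tag from the `inventory` registry
// collected above. Registrations themselves would be submitted elsewhere,
// e.g. `inventory::submit! { GatewayUpstreamRegistration { tag: ..., factory: ... } }`;
// the helper name below is an illustrative assumption.
#[allow(dead_code)]
fn find_upstream_sketch(tag: &str) -> Option<Arc<dyn GatewayUpstream>> {
    inventory::iter::<GatewayUpstreamRegistration>
        .into_iter()
        .find(|registration| registration.tag == tag)
        .map(|registration| (registration.factory)())
}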

/// Convert an [`UpstreamOutcome`] into an axum [`Response`], setting SSE
/// headers for streaming outcomes.
pub fn build_response(outcome: UpstreamOutcome) -> Response<Body> {
    match outcome {
        UpstreamOutcome::Buffered {
            status,
            content_type,
            body,
            served_model: _,
        } => Response::builder()
            .status(status)
            .header(CONTENT_TYPE, content_type)
            .body(Body::from(body))
            .unwrap_or_else(|_| Response::new(Body::empty())),
        UpstreamOutcome::Streaming { status, stream } => Response::builder()
            .status(status)
            .header(CONTENT_TYPE, "text/event-stream")
            .header("cache-control", "no-cache")
            .header("x-accel-buffering", "no")
            .body(Body::from_stream(stream))
            .unwrap_or_else(|_| Response::new(Body::empty())),
    }
}
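
// Hedged usage sketch (not part of the original file): wiring a proxied
// upstream call into an axum response. How the surrounding handler obtains
// the route, API key, raw body, and streaming flag is an assumption here.
#[allow(dead_code)]
async fn proxy_and_respond_sketch(
    upstream: &dyn GatewayUpstream,
    route: &GatewayRoute,
    api_key: &str,
    raw_body: Bytes,
    request: &AnthropicGatewayRequest,
    is_streaming: bool,
) -> Result<Response<Body>> {
    let outcome = upstream
        .proxy(UpstreamCtx {
            route,
            api_key,
            raw_body,
            request,
            is_streaming,
        })
        .await?;
    tracing::debug!(status = %outcome.status(), "gateway upstream outcome");
    Ok(build_response(outcome))
}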