// systemprompt_api/services/gateway/upstream.rs
use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use axum::body::Body;
6use axum::response::Response;
7use bytes::Bytes;
8use futures_util::StreamExt;
9use http::header::CONTENT_TYPE;
10use systemprompt_models::profile::GatewayRoute;
11
12use super::converter;
13use super::models::{AnthropicGatewayRequest, AnthropicGatewayResponse};
14
/// Borrowed per-request context handed to a [`GatewayUpstream`] implementation.
#[derive(Debug)]
pub struct UpstreamCtx<'a> {
    /// Route configuration: upstream endpoint, model remapping, extra headers.
    pub route: &'a GatewayRoute,
    /// Credential forwarded to the upstream (`x-api-key` or bearer token,
    /// depending on the upstream flavor).
    pub api_key: &'a str,
    /// The raw, unparsed request body as received from the client.
    pub raw_body: Bytes,
    /// The parsed Anthropic-format request (used for model lookup/conversion).
    pub request: &'a AnthropicGatewayRequest,
    /// Whether the client requested an SSE streaming response.
    pub is_streaming: bool,
}
23
/// Result of proxying a request upstream.
// `Debug` cannot be derived: the streaming variant holds a boxed trait-object
// stream, which has no `Debug` impl.
#[allow(missing_debug_implementations)]
pub enum UpstreamOutcome {
    /// Fully buffered response body.
    Buffered {
        status: http::StatusCode,
        /// Content type reported by the upstream (callers fall back to JSON).
        content_type: String,
        body: Bytes,
        /// Model name the upstream reports having served, when parseable.
        served_model: Option<String>,
    },
    /// Response passed through as a byte stream (Server-Sent Events).
    Streaming {
        status: http::StatusCode,
        stream: futures_util::stream::BoxStream<'static, Result<Bytes, std::io::Error>>,
    },
}
37
38impl UpstreamOutcome {
39 pub const fn status(&self) -> http::StatusCode {
40 match self {
41 Self::Buffered { status, .. } | Self::Streaming { status, .. } => *status,
42 }
43 }
44}
45
/// A backend capable of servicing a gateway request.
#[async_trait]
pub trait GatewayUpstream: Send + Sync {
    /// Forward the request described by `ctx` to the upstream service and
    /// return either a buffered or a streaming outcome.
    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome>;
}
50
/// Upstream that speaks the Anthropic Messages API natively; the client body
/// is forwarded as-is apart from optional model-name rewriting.
#[derive(Debug, Clone, Copy)]
pub struct AnthropicCompatibleUpstream;
53
54#[async_trait]
55impl GatewayUpstream for AnthropicCompatibleUpstream {
56 async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
57 let client = reqwest::Client::new();
58 let url = format!("{}/messages", ctx.route.endpoint.trim_end_matches('/'));
59
60 let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
61 let body_to_send: Bytes = if upstream_model == ctx.request.model {
62 ctx.raw_body
63 } else {
64 super::flatten::rewrite_request_model(ctx.raw_body, upstream_model)?
65 };
66
67 let mut req = client
68 .post(&url)
69 .header("x-api-key", ctx.api_key)
70 .header("anthropic-version", "2023-06-01")
71 .header("content-type", "application/json")
72 .body(body_to_send);
73
74 for (name, value) in &ctx.route.extra_headers {
75 req = req.header(name.as_str(), value.as_str());
76 }
77
78 let upstream_response = req
79 .send()
80 .await
81 .map_err(|e| anyhow!("Upstream Anthropic request failed: {e}"))?;
82
83 let status = upstream_response.status();
84 let upstream_content_type = upstream_response
85 .headers()
86 .get(CONTENT_TYPE)
87 .and_then(|v| v.to_str().ok())
88 .unwrap_or("application/json")
89 .to_string();
90
91 if ctx.is_streaming {
92 let stream = upstream_response.bytes_stream().map(|chunk| {
93 chunk.map_err(|e| {
94 tracing::warn!(error = %e, "Anthropic stream chunk error");
95 std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
96 })
97 });
98 return Ok(UpstreamOutcome::Streaming {
99 status,
100 stream: stream.boxed(),
101 });
102 }
103
104 let response_bytes = upstream_response
105 .bytes()
106 .await
107 .map_err(|e| anyhow!("Failed to read Anthropic response: {e}"))?;
108
109 let served_model = super::flatten::parse_served_model(&response_bytes);
110
111 Ok(UpstreamOutcome::Buffered {
112 status,
113 content_type: upstream_content_type,
114 body: response_bytes,
115 served_model,
116 })
117 }
118}
119
/// Upstream that speaks the OpenAI Chat Completions API; requests and
/// responses are translated to and from the Anthropic format.
#[derive(Debug, Clone, Copy)]
pub struct OpenAiCompatibleUpstream;
122
123#[async_trait]
124impl GatewayUpstream for OpenAiCompatibleUpstream {
125 async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
126 let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
127 let openai_request = converter::to_openai_request(ctx.request, upstream_model);
128 let model = ctx.request.model.as_str();
129
130 let client = reqwest::Client::new();
131 let url = format!(
132 "{}/chat/completions",
133 ctx.route.endpoint.trim_end_matches('/')
134 );
135
136 let mut req = client
137 .post(&url)
138 .header("authorization", format!("Bearer {}", ctx.api_key))
139 .header("content-type", "application/json")
140 .json(&openai_request);
141
142 for (name, value) in &ctx.route.extra_headers {
143 req = req.header(name.as_str(), value.as_str());
144 }
145
146 let upstream_response = req
147 .send()
148 .await
149 .map_err(|e| anyhow!("Upstream OpenAI-compatible request failed: {e}"))?;
150
151 let status = upstream_response.status();
152
153 if ctx.is_streaming {
154 if !status.is_success() {
155 let err = upstream_response.text().await.unwrap_or_default();
156 return Err(anyhow!("Upstream error {status}: {err}"));
157 }
158
159 let model_str = model.to_string();
160 let stream = upstream_response
161 .bytes_stream()
162 .map(move |chunk| match chunk {
163 Ok(bytes) => Ok(openai_sse_to_anthropic_sse(&bytes, &model_str)),
164 Err(e) => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)),
165 });
166 return Ok(UpstreamOutcome::Streaming {
167 status: http::StatusCode::OK,
168 stream: stream.boxed(),
169 });
170 }
171
172 if !status.is_success() {
173 let err = upstream_response.text().await.unwrap_or_default();
174 return Err(anyhow!("Upstream error {status}: {err}"));
175 }
176
177 let openai_resp: super::models::OpenAiGatewayResponse = upstream_response
178 .json()
179 .await
180 .map_err(|e| anyhow!("Failed to deserialize upstream response: {e}"))?;
181
182 let served_model = Some(openai_resp.model.clone());
183
184 let anthropic_resp: AnthropicGatewayResponse =
185 converter::from_openai_response(openai_resp, model);
186
187 let body_bytes =
188 serde_json::to_vec(&anthropic_resp).map_err(|e| anyhow!("Serialization error: {e}"))?;
189
190 Ok(UpstreamOutcome::Buffered {
191 status: http::StatusCode::OK,
192 content_type: "application/json".to_string(),
193 body: Bytes::from(body_bytes),
194 served_model,
195 })
196 }
197}
198
/// Inventory-registered factory for a [`GatewayUpstream`] implementation,
/// selected by its static `tag`.
#[derive(Debug, Clone, Copy)]
pub struct GatewayUpstreamRegistration {
    /// Identifier used to look up this upstream flavor.
    pub tag: &'static str,
    /// Constructs the upstream instance.
    pub factory: fn() -> Arc<dyn GatewayUpstream>,
}

// Gather every registration submitted across the crate graph.
inventory::collect!(GatewayUpstreamRegistration);
206
207pub fn build_response(outcome: UpstreamOutcome) -> Response<Body> {
208 match outcome {
209 UpstreamOutcome::Buffered {
210 status,
211 content_type,
212 body,
213 served_model: _,
214 } => Response::builder()
215 .status(status)
216 .header(CONTENT_TYPE, content_type)
217 .body(Body::from(body))
218 .unwrap_or_else(|_| Response::new(Body::empty())),
219 UpstreamOutcome::Streaming { status, stream } => Response::builder()
220 .status(status)
221 .header(CONTENT_TYPE, "text/event-stream")
222 .header("cache-control", "no-cache")
223 .header("x-accel-buffering", "no")
224 .body(Body::from_stream(stream))
225 .unwrap_or_else(|_| Response::new(Body::empty())),
226 }
227}
228
229fn openai_sse_to_anthropic_sse(bytes: &Bytes, model: &str) -> Bytes {
230 let text = String::from_utf8_lossy(bytes);
231 let mut output = String::new();
232
233 for line in text.lines() {
234 let Some(data) = line.strip_prefix("data: ") else {
235 continue;
236 };
237 if data.trim() == "[DONE]" {
238 push_sse_frame(&mut output, &serde_json::json!({ "type": "message_stop" }));
239 continue;
240 }
241 let Ok(chunk) = serde_json::from_str::<OpenAiStreamChunk>(data) else {
242 continue;
243 };
244 append_anthropic_frames(&mut output, &chunk, model);
245 }
246
247 Bytes::from(output)
248}
249
250fn append_anthropic_frames(output: &mut String, chunk: &OpenAiStreamChunk, model: &str) {
251 for choice in &chunk.choices {
252 if choice.delta.role.is_some() {
253 let id = chunk.id.as_deref().unwrap_or("msg_openai");
254 push_sse_frame(
255 output,
256 &serde_json::json!({
257 "type": "message_start",
258 "message": {
259 "id": id,
260 "type": "message",
261 "role": "assistant",
262 "model": model,
263 "usage": { "input_tokens": 0, "output_tokens": 0 },
264 },
265 }),
266 );
267 push_sse_frame(
268 output,
269 &serde_json::json!({
270 "type": "content_block_start",
271 "index": 0,
272 "content_block": { "type": "text", "text": "" },
273 }),
274 );
275 }
276 if let Some(text) = choice.delta.content.as_deref() {
277 if !text.is_empty() {
278 push_sse_frame(
279 output,
280 &serde_json::json!({
281 "type": "content_block_delta",
282 "index": 0,
283 "delta": { "type": "text_delta", "text": text },
284 }),
285 );
286 }
287 }
288 if let Some(finish) = choice.finish_reason.as_deref() {
289 if !finish.is_empty() && finish != "null" {
290 let stop_reason = if finish == "stop" { "end_turn" } else { finish };
291 push_sse_frame(
292 output,
293 &serde_json::json!({
294 "type": "message_delta",
295 "delta": { "stop_reason": stop_reason },
296 "usage": { "output_tokens": 0 },
297 }),
298 );
299 }
300 }
301 }
302}
303
304fn push_sse_frame(output: &mut String, value: &serde_json::Value) {
305 output.push_str("data: ");
306 if let Ok(encoded) = serde_json::to_string(value) {
307 output.push_str(&encoded);
308 }
309 output.push_str("\n\n");
310}
311
/// Minimal view of an OpenAI streaming chunk (`chat.completion.chunk`);
/// fields not listed here are ignored by serde.
#[derive(serde::Deserialize)]
struct OpenAiStreamChunk {
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    choices: Vec<OpenAiStreamChoice>,
}

/// One choice within a streaming chunk.
#[derive(serde::Deserialize)]
struct OpenAiStreamChoice {
    #[serde(default)]
    delta: OpenAiStreamDelta,
    // Set on the final chunk of a choice (e.g. "stop", "length").
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental delta payload for a choice; `role` appears only on the first
/// chunk of a message.
#[derive(serde::Deserialize, Default)]
struct OpenAiStreamDelta {
    #[serde(default)]
    role: Option<String>,
    #[serde(default)]
    content: Option<String>,
}