// systemprompt_api/services/gateway/upstream.rs

use std::sync::Arc;

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use axum::body::Body;
use axum::response::Response;
use bytes::Bytes;
use futures_util::StreamExt;
use http::header::CONTENT_TYPE;
use systemprompt_models::profile::GatewayRoute;

use super::converter;
use super::models::{AnthropicGatewayRequest, AnthropicGatewayResponse};

/// Borrowed per-request context handed to a [`GatewayUpstream`] implementation.
#[derive(Debug)]
pub struct UpstreamCtx<'a> {
    /// Route configuration: endpoint base URL, extra headers, model mapping.
    pub route: &'a GatewayRoute,
    /// Credential forwarded upstream (sent as `x-api-key` or a bearer token,
    /// depending on the upstream implementation).
    pub api_key: &'a str,
    /// Raw request body as received; forwarded verbatim by pass-through upstreams.
    pub raw_body: Bytes,
    /// Parsed Anthropic-shaped request, used by upstreams that must convert it.
    pub request: &'a AnthropicGatewayRequest,
    /// Whether the client requested a streaming (SSE) response.
    pub is_streaming: bool,
}
23
/// Result of proxying a request upstream: either a fully buffered response
/// body or a live byte stream to relay to the client.
// `BoxStream` has no `Debug` impl, so the derive lint is allowed off.
#[allow(missing_debug_implementations)]
pub enum UpstreamOutcome {
    /// Complete response held in memory.
    Buffered {
        status: http::StatusCode,
        /// Content-type to echo back to the client.
        content_type: String,
        body: Bytes,
    },
    /// Server-sent-events response relayed chunk by chunk.
    Streaming {
        status: http::StatusCode,
        /// Owned stream of body chunks; transport errors surface as `std::io::Error`.
        stream: futures_util::stream::BoxStream<'static, Result<Bytes, std::io::Error>>,
    },
}
36
37impl UpstreamOutcome {
38 pub const fn status(&self) -> http::StatusCode {
39 match self {
40 Self::Buffered { status, .. } | Self::Streaming { status, .. } => *status,
41 }
42 }
43}
44
/// A backend the gateway can proxy requests to.
///
/// Implementations forward the request described by [`UpstreamCtx`] and
/// return either a buffered or streaming [`UpstreamOutcome`].
#[async_trait]
pub trait GatewayUpstream: Send + Sync {
    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome>;
}
49
/// Pass-through upstream for endpoints that speak the Anthropic Messages API
/// natively; the raw request body is forwarded without conversion.
#[derive(Debug, Clone, Copy)]
pub struct AnthropicCompatibleUpstream;
52
53#[async_trait]
54impl GatewayUpstream for AnthropicCompatibleUpstream {
55 async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
56 let client = reqwest::Client::new();
57 let url = format!("{}/messages", ctx.route.endpoint.trim_end_matches('/'));
58
59 let mut req = client
60 .post(&url)
61 .header("x-api-key", ctx.api_key)
62 .header("anthropic-version", "2023-06-01")
63 .header("content-type", "application/json")
64 .body(ctx.raw_body);
65
66 for (name, value) in &ctx.route.extra_headers {
67 req = req.header(name.as_str(), value.as_str());
68 }
69
70 let upstream_response = req
71 .send()
72 .await
73 .map_err(|e| anyhow!("Upstream Anthropic request failed: {e}"))?;
74
75 let status = upstream_response.status();
76 let upstream_content_type = upstream_response
77 .headers()
78 .get(CONTENT_TYPE)
79 .and_then(|v| v.to_str().ok())
80 .unwrap_or("application/json")
81 .to_string();
82
83 if ctx.is_streaming {
84 let stream = upstream_response.bytes_stream().map(|chunk| {
85 chunk.map_err(|e| {
86 tracing::warn!(error = %e, "Anthropic stream chunk error");
87 std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
88 })
89 });
90 return Ok(UpstreamOutcome::Streaming {
91 status,
92 stream: stream.boxed(),
93 });
94 }
95
96 let response_bytes = upstream_response
97 .bytes()
98 .await
99 .map_err(|e| anyhow!("Failed to read Anthropic response: {e}"))?;
100
101 Ok(UpstreamOutcome::Buffered {
102 status,
103 content_type: upstream_content_type,
104 body: response_bytes,
105 })
106 }
107}
108
/// Converting upstream for OpenAI-compatible `/chat/completions` endpoints;
/// requests and responses are translated to and from the Anthropic wire format.
#[derive(Debug, Clone, Copy)]
pub struct OpenAiCompatibleUpstream;
111
112#[async_trait]
113impl GatewayUpstream for OpenAiCompatibleUpstream {
114 async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
115 let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
116 let openai_request = converter::to_openai_request(ctx.request, upstream_model);
117 let model = ctx.request.model.as_str();
118
119 let client = reqwest::Client::new();
120 let url = format!(
121 "{}/chat/completions",
122 ctx.route.endpoint.trim_end_matches('/')
123 );
124
125 let mut req = client
126 .post(&url)
127 .header("authorization", format!("Bearer {}", ctx.api_key))
128 .header("content-type", "application/json")
129 .json(&openai_request);
130
131 for (name, value) in &ctx.route.extra_headers {
132 req = req.header(name.as_str(), value.as_str());
133 }
134
135 let upstream_response = req
136 .send()
137 .await
138 .map_err(|e| anyhow!("Upstream OpenAI-compatible request failed: {e}"))?;
139
140 let status = upstream_response.status();
141
142 if ctx.is_streaming {
143 if !status.is_success() {
144 let err = upstream_response.text().await.unwrap_or_default();
145 return Err(anyhow!("Upstream error {status}: {err}"));
146 }
147
148 let model_str = model.to_string();
149 let stream = upstream_response
150 .bytes_stream()
151 .map(move |chunk| match chunk {
152 Ok(bytes) => Ok(openai_sse_to_anthropic_sse(&bytes, &model_str)),
153 Err(e) => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)),
154 });
155 return Ok(UpstreamOutcome::Streaming {
156 status: http::StatusCode::OK,
157 stream: stream.boxed(),
158 });
159 }
160
161 if !status.is_success() {
162 let err = upstream_response.text().await.unwrap_or_default();
163 return Err(anyhow!("Upstream error {status}: {err}"));
164 }
165
166 let openai_resp: super::models::OpenAiGatewayResponse = upstream_response
167 .json()
168 .await
169 .map_err(|e| anyhow!("Failed to deserialize upstream response: {e}"))?;
170
171 let anthropic_resp: AnthropicGatewayResponse =
172 converter::from_openai_response(openai_resp, model);
173
174 let body_bytes =
175 serde_json::to_vec(&anthropic_resp).map_err(|e| anyhow!("Serialization error: {e}"))?;
176
177 Ok(UpstreamOutcome::Buffered {
178 status: http::StatusCode::OK,
179 content_type: "application/json".to_string(),
180 body: Bytes::from(body_bytes),
181 })
182 }
183}
184
/// Inventory registration record for a [`GatewayUpstream`] implementation.
#[derive(Debug, Clone, Copy)]
pub struct GatewayUpstreamRegistration {
    /// Identifier for this upstream kind — presumably matched against route
    /// configuration; confirm at the lookup site.
    pub tag: &'static str,
    /// Factory producing a shared instance of the upstream.
    pub factory: fn() -> Arc<dyn GatewayUpstream>,
}

// Gather every `GatewayUpstreamRegistration` submitted across the binary.
inventory::collect!(GatewayUpstreamRegistration);
192
193pub fn build_response(outcome: UpstreamOutcome) -> Response<Body> {
194 match outcome {
195 UpstreamOutcome::Buffered {
196 status,
197 content_type,
198 body,
199 } => Response::builder()
200 .status(status)
201 .header(CONTENT_TYPE, content_type)
202 .body(Body::from(body))
203 .unwrap_or_else(|_| Response::new(Body::empty())),
204 UpstreamOutcome::Streaming { status, stream } => Response::builder()
205 .status(status)
206 .header(CONTENT_TYPE, "text/event-stream")
207 .header("cache-control", "no-cache")
208 .header("x-accel-buffering", "no")
209 .body(Body::from_stream(stream))
210 .unwrap_or_else(|_| Response::new(Body::empty())),
211 }
212}
213
214fn openai_sse_to_anthropic_sse(bytes: &Bytes, model: &str) -> Bytes {
215 let text = String::from_utf8_lossy(bytes);
216 let mut output = String::new();
217
218 for line in text.lines() {
219 let Some(data) = line.strip_prefix("data: ") else {
220 continue;
221 };
222 if data.trim() == "[DONE]" {
223 push_sse_frame(&mut output, &serde_json::json!({ "type": "message_stop" }));
224 continue;
225 }
226 let Ok(chunk) = serde_json::from_str::<OpenAiStreamChunk>(data) else {
227 continue;
228 };
229 append_anthropic_frames(&mut output, &chunk, model);
230 }
231
232 Bytes::from(output)
233}
234
/// Translate one parsed OpenAI stream chunk into zero or more Anthropic SSE
/// frames appended to `output`.
///
/// Mapping (all content is treated as a single text block at index 0):
/// - a delta carrying a `role`       -> `message_start` + `content_block_start`
/// - a delta carrying text `content` -> `content_block_delta`
/// - a non-empty `finish_reason`     -> `message_delta` with a stop reason
///
/// Per-chunk token usage is not available here, so usage fields are zeroed.
fn append_anthropic_frames(output: &mut String, chunk: &OpenAiStreamChunk, model: &str) {
    for choice in &chunk.choices {
        // The role typically arrives only on the first delta; use it as the
        // signal to open the Anthropic message and its text block.
        if choice.delta.role.is_some() {
            let id = chunk.id.as_deref().unwrap_or("msg_openai");
            push_sse_frame(
                output,
                &serde_json::json!({
                    "type": "message_start",
                    "message": {
                        "id": id,
                        "type": "message",
                        "role": "assistant",
                        "model": model,
                        "usage": { "input_tokens": 0, "output_tokens": 0 },
                    },
                }),
            );
            push_sse_frame(
                output,
                &serde_json::json!({
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": { "type": "text", "text": "" },
                }),
            );
        }
        if let Some(text) = choice.delta.content.as_deref() {
            if !text.is_empty() {
                push_sse_frame(
                    output,
                    &serde_json::json!({
                        "type": "content_block_delta",
                        "index": 0,
                        "delta": { "type": "text_delta", "text": text },
                    }),
                );
            }
        }
        if let Some(finish) = choice.finish_reason.as_deref() {
            // Some upstreams send the literal string "null"; treat it like absent.
            if !finish.is_empty() && finish != "null" {
                // OpenAI "stop" maps to Anthropic "end_turn"; other reasons
                // pass through verbatim.
                let stop_reason = if finish == "stop" { "end_turn" } else { finish };
                push_sse_frame(
                    output,
                    &serde_json::json!({
                        "type": "message_delta",
                        "delta": { "stop_reason": stop_reason },
                        "usage": { "output_tokens": 0 },
                    }),
                );
            }
        }
    }
}
288
289fn push_sse_frame(output: &mut String, value: &serde_json::Value) {
290 output.push_str("data: ");
291 if let Ok(encoded) = serde_json::to_string(value) {
292 output.push_str(&encoded);
293 }
294 output.push_str("\n\n");
295}
296
/// Minimal view of one OpenAI streaming chunk; fields this module does not
/// use are ignored by serde.
#[derive(serde::Deserialize)]
struct OpenAiStreamChunk {
    /// Upstream message id, reused as the Anthropic message id when present.
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    choices: Vec<OpenAiStreamChoice>,
}
304
/// One choice inside a streaming chunk.
#[derive(serde::Deserialize)]
struct OpenAiStreamChoice {
    /// Incremental payload for this choice.
    #[serde(default)]
    delta: OpenAiStreamDelta,
    /// Set on the final chunk for this choice (e.g. "stop").
    #[serde(default)]
    finish_reason: Option<String>,
}
312
/// Incremental delta payload; the role typically appears only on the first
/// chunk, text content on subsequent ones.
#[derive(serde::Deserialize, Default)]
struct OpenAiStreamDelta {
    #[serde(default)]
    role: Option<String>,
    #[serde(default)]
    content: Option<String>,
}