//! systemprompt_api/services/gateway/upstream.rs
//!
//! Upstream adapters for the gateway: forward Anthropic-format requests to
//! Anthropic-compatible or OpenAI-compatible providers, buffering or
//! streaming the response back.
use std::sync::{Arc, OnceLock};

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use axum::body::Body;
use axum::response::Response;
use bytes::Bytes;
use futures_util::StreamExt;
use http::header::CONTENT_TYPE;
use systemprompt_models::profile::GatewayRoute;

use super::converter;
use super::models::{AnthropicGatewayRequest, AnthropicGatewayResponse};
14
/// Borrowed, per-request context handed to a [`GatewayUpstream`] implementation.
#[derive(Debug)]
pub struct UpstreamCtx<'a> {
    // Route configuration: endpoint, model mapping, extra headers.
    pub route: &'a GatewayRoute,
    // Credential forwarded to the upstream provider (x-api-key / Bearer).
    pub api_key: &'a str,
    // The original Anthropic-format request body, unparsed.
    pub raw_body: Bytes,
    // Parsed Anthropic-format request; used for model lookup and conversion.
    pub request: &'a AnthropicGatewayRequest,
    // Whether the client asked for an SSE streaming response.
    pub is_streaming: bool,
}
23
/// Result of proxying a request upstream: either a fully buffered response
/// or a live stream of SSE bytes.
#[allow(missing_debug_implementations)] // the boxed stream cannot derive Debug
pub enum UpstreamOutcome {
    /// Entire upstream response held in memory.
    Buffered {
        status: http::StatusCode,
        // Content type reported by (or assumed for) the upstream body.
        content_type: String,
        body: Bytes,
        // Model name the upstream reported serving, when parseable.
        served_model: Option<String>,
    },
    /// Response forwarded as a stream of SSE bytes.
    Streaming {
        status: http::StatusCode,
        stream: futures_util::stream::BoxStream<'static, Result<Bytes, std::io::Error>>,
    },
}
37
38impl UpstreamOutcome {
39    pub const fn status(&self) -> http::StatusCode {
40        match self {
41            Self::Buffered { status, .. } | Self::Streaming { status, .. } => *status,
42        }
43    }
44}
45
/// A provider adapter that forwards an Anthropic-format gateway request
/// upstream and returns the (possibly translated) response.
#[async_trait]
pub trait GatewayUpstream: Send + Sync {
    /// Proxy one request described by `ctx` to the upstream provider.
    ///
    /// # Errors
    /// Returns an error when the upstream request cannot be sent or its
    /// response cannot be read or decoded.
    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome>;
}
50
/// Upstream adapter for providers that speak the Anthropic Messages API
/// natively; the request body is forwarded unchanged unless the route
/// remaps the model name.
#[derive(Debug, Clone, Copy)]
pub struct AnthropicCompatibleUpstream;
53
54#[async_trait]
55impl GatewayUpstream for AnthropicCompatibleUpstream {
56    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
57        let client = reqwest::Client::new();
58        let url = format!("{}/messages", ctx.route.endpoint.trim_end_matches('/'));
59
60        let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
61        let body_to_send: Bytes = if upstream_model == ctx.request.model {
62            ctx.raw_body
63        } else {
64            super::flatten::rewrite_request_model(ctx.raw_body, upstream_model)?
65        };
66
67        let mut req = client
68            .post(&url)
69            .header("x-api-key", ctx.api_key)
70            .header("anthropic-version", "2023-06-01")
71            .header("content-type", "application/json")
72            .body(body_to_send);
73
74        for (name, value) in &ctx.route.extra_headers {
75            req = req.header(name.as_str(), value.as_str());
76        }
77
78        let upstream_response = req
79            .send()
80            .await
81            .map_err(|e| anyhow!("Upstream Anthropic request failed: {e}"))?;
82
83        let status = upstream_response.status();
84        let upstream_content_type = upstream_response
85            .headers()
86            .get(CONTENT_TYPE)
87            .and_then(|v| v.to_str().ok())
88            .unwrap_or("application/json")
89            .to_string();
90
91        if ctx.is_streaming {
92            let stream = upstream_response.bytes_stream().map(|chunk| {
93                chunk.map_err(|e| {
94                    tracing::warn!(error = %e, "Anthropic stream chunk error");
95                    std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
96                })
97            });
98            return Ok(UpstreamOutcome::Streaming {
99                status,
100                stream: stream.boxed(),
101            });
102        }
103
104        let response_bytes = upstream_response
105            .bytes()
106            .await
107            .map_err(|e| anyhow!("Failed to read Anthropic response: {e}"))?;
108
109        let served_model = super::flatten::parse_served_model(&response_bytes);
110
111        Ok(UpstreamOutcome::Buffered {
112            status,
113            content_type: upstream_content_type,
114            body: response_bytes,
115            served_model,
116        })
117    }
118}
119
/// Upstream adapter for OpenAI-compatible providers: the Anthropic-format
/// request is converted to a `chat/completions` request and the response
/// (buffered or streamed) is translated back into Anthropic format.
#[derive(Debug, Clone, Copy)]
pub struct OpenAiCompatibleUpstream;
122
123#[async_trait]
124impl GatewayUpstream for OpenAiCompatibleUpstream {
125    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
126        let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
127        let openai_request = converter::to_openai_request(ctx.request, upstream_model);
128        let model = ctx.request.model.as_str();
129
130        let client = reqwest::Client::new();
131        let url = format!(
132            "{}/chat/completions",
133            ctx.route.endpoint.trim_end_matches('/')
134        );
135
136        let mut req = client
137            .post(&url)
138            .header("authorization", format!("Bearer {}", ctx.api_key))
139            .header("content-type", "application/json")
140            .json(&openai_request);
141
142        for (name, value) in &ctx.route.extra_headers {
143            req = req.header(name.as_str(), value.as_str());
144        }
145
146        let upstream_response = req
147            .send()
148            .await
149            .map_err(|e| anyhow!("Upstream OpenAI-compatible request failed: {e}"))?;
150
151        let status = upstream_response.status();
152
153        if ctx.is_streaming {
154            if !status.is_success() {
155                let err = upstream_response.text().await.unwrap_or_default();
156                return Err(anyhow!("Upstream error {status}: {err}"));
157            }
158
159            let model_str = model.to_string();
160            let stream = upstream_response
161                .bytes_stream()
162                .map(move |chunk| match chunk {
163                    Ok(bytes) => Ok(openai_sse_to_anthropic_sse(&bytes, &model_str)),
164                    Err(e) => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)),
165                });
166            return Ok(UpstreamOutcome::Streaming {
167                status: http::StatusCode::OK,
168                stream: stream.boxed(),
169            });
170        }
171
172        if !status.is_success() {
173            let err = upstream_response.text().await.unwrap_or_default();
174            return Err(anyhow!("Upstream error {status}: {err}"));
175        }
176
177        let openai_resp: super::models::OpenAiGatewayResponse = upstream_response
178            .json()
179            .await
180            .map_err(|e| anyhow!("Failed to deserialize upstream response: {e}"))?;
181
182        let served_model = Some(openai_resp.model.clone());
183
184        let anthropic_resp: AnthropicGatewayResponse =
185            converter::from_openai_response(openai_resp, model);
186
187        let body_bytes =
188            serde_json::to_vec(&anthropic_resp).map_err(|e| anyhow!("Serialization error: {e}"))?;
189
190        Ok(UpstreamOutcome::Buffered {
191            status: http::StatusCode::OK,
192            content_type: "application/json".to_string(),
193            body: Bytes::from(body_bytes),
194            served_model,
195        })
196    }
197}
198
/// Inventory registration record pairing a `tag` with a factory that builds
/// the matching [`GatewayUpstream`] implementation.
#[derive(Debug, Clone, Copy)]
pub struct GatewayUpstreamRegistration {
    // NOTE(review): presumably matched against route/provider configuration
    // when selecting an upstream — confirm at the lookup site.
    pub tag: &'static str,
    // Factory producing a shared, dynamically-dispatched upstream instance.
    pub factory: fn() -> Arc<dyn GatewayUpstream>,
}

// Collect registrations submitted from anywhere in the linked binary.
inventory::collect!(GatewayUpstreamRegistration);
206
207pub fn build_response(outcome: UpstreamOutcome) -> Response<Body> {
208    match outcome {
209        UpstreamOutcome::Buffered {
210            status,
211            content_type,
212            body,
213            served_model: _,
214        } => Response::builder()
215            .status(status)
216            .header(CONTENT_TYPE, content_type)
217            .body(Body::from(body))
218            .unwrap_or_else(|_| Response::new(Body::empty())),
219        UpstreamOutcome::Streaming { status, stream } => Response::builder()
220            .status(status)
221            .header(CONTENT_TYPE, "text/event-stream")
222            .header("cache-control", "no-cache")
223            .header("x-accel-buffering", "no")
224            .body(Body::from_stream(stream))
225            .unwrap_or_else(|_| Response::new(Body::empty())),
226    }
227}
228
229fn openai_sse_to_anthropic_sse(bytes: &Bytes, model: &str) -> Bytes {
230    let text = String::from_utf8_lossy(bytes);
231    let mut output = String::new();
232
233    for line in text.lines() {
234        let Some(data) = line.strip_prefix("data: ") else {
235            continue;
236        };
237        if data.trim() == "[DONE]" {
238            push_sse_frame(&mut output, &serde_json::json!({ "type": "message_stop" }));
239            continue;
240        }
241        let Ok(chunk) = serde_json::from_str::<OpenAiStreamChunk>(data) else {
242            continue;
243        };
244        append_anthropic_frames(&mut output, &chunk, model);
245    }
246
247    Bytes::from(output)
248}
249
250fn append_anthropic_frames(output: &mut String, chunk: &OpenAiStreamChunk, model: &str) {
251    for choice in &chunk.choices {
252        if choice.delta.role.is_some() {
253            let id = chunk.id.as_deref().unwrap_or("msg_openai");
254            push_sse_frame(
255                output,
256                &serde_json::json!({
257                    "type": "message_start",
258                    "message": {
259                        "id": id,
260                        "type": "message",
261                        "role": "assistant",
262                        "model": model,
263                        "usage": { "input_tokens": 0, "output_tokens": 0 },
264                    },
265                }),
266            );
267            push_sse_frame(
268                output,
269                &serde_json::json!({
270                    "type": "content_block_start",
271                    "index": 0,
272                    "content_block": { "type": "text", "text": "" },
273                }),
274            );
275        }
276        if let Some(text) = choice.delta.content.as_deref() {
277            if !text.is_empty() {
278                push_sse_frame(
279                    output,
280                    &serde_json::json!({
281                        "type": "content_block_delta",
282                        "index": 0,
283                        "delta": { "type": "text_delta", "text": text },
284                    }),
285                );
286            }
287        }
288        if let Some(finish) = choice.finish_reason.as_deref() {
289            if !finish.is_empty() && finish != "null" {
290                let stop_reason = if finish == "stop" { "end_turn" } else { finish };
291                push_sse_frame(
292                    output,
293                    &serde_json::json!({
294                        "type": "message_delta",
295                        "delta": { "stop_reason": stop_reason },
296                        "usage": { "output_tokens": 0 },
297                    }),
298                );
299            }
300        }
301    }
302}
303
304fn push_sse_frame(output: &mut String, value: &serde_json::Value) {
305    output.push_str("data: ");
306    if let Ok(encoded) = serde_json::to_string(value) {
307        output.push_str(&encoded);
308    }
309    output.push_str("\n\n");
310}
311
/// Minimal view of an OpenAI `chat.completion.chunk` SSE payload; unknown
/// fields are ignored and missing ones default.
#[derive(serde::Deserialize)]
struct OpenAiStreamChunk {
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    choices: Vec<OpenAiStreamChoice>,
}

/// One streamed choice: an incremental delta plus an optional finish reason.
#[derive(serde::Deserialize)]
struct OpenAiStreamChoice {
    #[serde(default)]
    delta: OpenAiStreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental delta. The presence of `role` is treated as the opening
/// frame of the stream; `content` carries text fragments.
#[derive(serde::Deserialize, Default)]
struct OpenAiStreamDelta {
    #[serde(default)]
    role: Option<String>,
    #[serde(default)]
    content: Option<String>,
}