//! Gateway upstream transports: forward an Anthropic-format request either
//! verbatim to an Anthropic-compatible `/messages` endpoint, or — with
//! request/response conversion — to an OpenAI-compatible `/chat/completions`
//! endpoint, normalizing both into an [`UpstreamOutcome`].
1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use axum::body::Body;
6use axum::response::Response;
7use bytes::Bytes;
8use futures_util::StreamExt;
9use http::header::CONTENT_TYPE;
10use systemprompt_models::profile::GatewayRoute;
11
12use super::converter;
13use super::models::{AnthropicGatewayRequest, AnthropicGatewayResponse};
14
/// Borrowed per-request context handed to a [`GatewayUpstream`] implementation.
#[derive(Debug)]
pub struct UpstreamCtx<'a> {
    /// Route configuration: endpoint, extra headers, model mapping.
    pub route: &'a GatewayRoute,
    /// Upstream credential; sent as `x-api-key` or as a bearer token
    /// depending on the upstream flavor.
    pub api_key: &'a str,
    /// Raw request body as received; forwarded verbatim by the
    /// Anthropic-compatible upstream.
    pub raw_body: Bytes,
    /// Parsed Anthropic-format request; used for conversion by the
    /// OpenAI-compatible upstream.
    pub request: &'a AnthropicGatewayRequest,
    /// True when the client requested a streaming (SSE) response.
    pub is_streaming: bool,
}
23
/// Result of proxying a request upstream: either a fully buffered response
/// or a live byte stream (used for SSE pass-through).
// The boxed stream has no `Debug` impl, hence the allow.
#[allow(missing_debug_implementations)]
pub enum UpstreamOutcome {
    /// Complete response held in memory.
    Buffered {
        status: http::StatusCode,
        content_type: String,
        body: Bytes,
    },
    /// Response forwarded incrementally as chunks arrive from upstream.
    Streaming {
        status: http::StatusCode,
        stream: futures_util::stream::BoxStream<'static, Result<Bytes, std::io::Error>>,
    },
}
36
37impl UpstreamOutcome {
38    pub const fn status(&self) -> http::StatusCode {
39        match self {
40            Self::Buffered { status, .. } | Self::Streaming { status, .. } => *status,
41        }
42    }
43}
44
/// A transport that forwards an Anthropic-format gateway request to some
/// upstream provider and yields a normalized [`UpstreamOutcome`].
#[async_trait]
pub trait GatewayUpstream: Send + Sync {
    /// Proxy the request described by `ctx` to the upstream.
    ///
    /// # Errors
    /// Returns an error when the upstream request cannot be sent, the
    /// response body cannot be read or decoded, or (for converting
    /// upstreams) the upstream reports a non-success status.
    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome>;
}
49
/// Upstream that speaks the Anthropic Messages API natively; the request
/// body and SSE stream are forwarded verbatim.
#[derive(Debug, Clone, Copy)]
pub struct AnthropicCompatibleUpstream;
52
53#[async_trait]
54impl GatewayUpstream for AnthropicCompatibleUpstream {
55    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
56        let client = reqwest::Client::new();
57        let url = format!("{}/messages", ctx.route.endpoint.trim_end_matches('/'));
58
59        let mut req = client
60            .post(&url)
61            .header("x-api-key", ctx.api_key)
62            .header("anthropic-version", "2023-06-01")
63            .header("content-type", "application/json")
64            .body(ctx.raw_body);
65
66        for (name, value) in &ctx.route.extra_headers {
67            req = req.header(name.as_str(), value.as_str());
68        }
69
70        let upstream_response = req
71            .send()
72            .await
73            .map_err(|e| anyhow!("Upstream Anthropic request failed: {e}"))?;
74
75        let status = upstream_response.status();
76        let upstream_content_type = upstream_response
77            .headers()
78            .get(CONTENT_TYPE)
79            .and_then(|v| v.to_str().ok())
80            .unwrap_or("application/json")
81            .to_string();
82
83        if ctx.is_streaming {
84            let stream = upstream_response.bytes_stream().map(|chunk| {
85                chunk.map_err(|e| {
86                    tracing::warn!(error = %e, "Anthropic stream chunk error");
87                    std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
88                })
89            });
90            return Ok(UpstreamOutcome::Streaming {
91                status,
92                stream: stream.boxed(),
93            });
94        }
95
96        let response_bytes = upstream_response
97            .bytes()
98            .await
99            .map_err(|e| anyhow!("Failed to read Anthropic response: {e}"))?;
100
101        Ok(UpstreamOutcome::Buffered {
102            status,
103            content_type: upstream_content_type,
104            body: response_bytes,
105        })
106    }
107}
108
/// Upstream that speaks the OpenAI chat-completions API; requests and
/// responses are converted to/from the Anthropic format.
#[derive(Debug, Clone, Copy)]
pub struct OpenAiCompatibleUpstream;
111
112#[async_trait]
113impl GatewayUpstream for OpenAiCompatibleUpstream {
114    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
115        let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
116        let openai_request = converter::to_openai_request(ctx.request, upstream_model);
117        let model = ctx.request.model.as_str();
118
119        let client = reqwest::Client::new();
120        let url = format!(
121            "{}/chat/completions",
122            ctx.route.endpoint.trim_end_matches('/')
123        );
124
125        let mut req = client
126            .post(&url)
127            .header("authorization", format!("Bearer {}", ctx.api_key))
128            .header("content-type", "application/json")
129            .json(&openai_request);
130
131        for (name, value) in &ctx.route.extra_headers {
132            req = req.header(name.as_str(), value.as_str());
133        }
134
135        let upstream_response = req
136            .send()
137            .await
138            .map_err(|e| anyhow!("Upstream OpenAI-compatible request failed: {e}"))?;
139
140        let status = upstream_response.status();
141
142        if ctx.is_streaming {
143            if !status.is_success() {
144                let err = upstream_response.text().await.unwrap_or_default();
145                return Err(anyhow!("Upstream error {status}: {err}"));
146            }
147
148            let model_str = model.to_string();
149            let stream = upstream_response
150                .bytes_stream()
151                .map(move |chunk| match chunk {
152                    Ok(bytes) => Ok(openai_sse_to_anthropic_sse(&bytes, &model_str)),
153                    Err(e) => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)),
154                });
155            return Ok(UpstreamOutcome::Streaming {
156                status: http::StatusCode::OK,
157                stream: stream.boxed(),
158            });
159        }
160
161        if !status.is_success() {
162            let err = upstream_response.text().await.unwrap_or_default();
163            return Err(anyhow!("Upstream error {status}: {err}"));
164        }
165
166        let openai_resp: super::models::OpenAiGatewayResponse = upstream_response
167            .json()
168            .await
169            .map_err(|e| anyhow!("Failed to deserialize upstream response: {e}"))?;
170
171        let anthropic_resp: AnthropicGatewayResponse =
172            converter::from_openai_response(openai_resp, model);
173
174        let body_bytes =
175            serde_json::to_vec(&anthropic_resp).map_err(|e| anyhow!("Serialization error: {e}"))?;
176
177        Ok(UpstreamOutcome::Buffered {
178            status: http::StatusCode::OK,
179            content_type: "application/json".to_string(),
180            body: Bytes::from(body_bytes),
181        })
182    }
183}
184
/// Inventory registration record mapping a provider `tag` to a factory that
/// constructs the corresponding [`GatewayUpstream`].
#[derive(Debug, Clone, Copy)]
pub struct GatewayUpstreamRegistration {
    /// Stable identifier used to select this upstream.
    pub tag: &'static str,
    /// Factory producing a shared upstream instance.
    pub factory: fn() -> Arc<dyn GatewayUpstream>,
}

// Gather every registration submitted across the crate graph via `inventory`.
inventory::collect!(GatewayUpstreamRegistration);
192
193pub fn build_response(outcome: UpstreamOutcome) -> Response<Body> {
194    match outcome {
195        UpstreamOutcome::Buffered {
196            status,
197            content_type,
198            body,
199        } => Response::builder()
200            .status(status)
201            .header(CONTENT_TYPE, content_type)
202            .body(Body::from(body))
203            .unwrap_or_else(|_| Response::new(Body::empty())),
204        UpstreamOutcome::Streaming { status, stream } => Response::builder()
205            .status(status)
206            .header(CONTENT_TYPE, "text/event-stream")
207            .header("cache-control", "no-cache")
208            .header("x-accel-buffering", "no")
209            .body(Body::from_stream(stream))
210            .unwrap_or_else(|_| Response::new(Body::empty())),
211    }
212}
213
214fn openai_sse_to_anthropic_sse(bytes: &Bytes, model: &str) -> Bytes {
215    let text = String::from_utf8_lossy(bytes);
216    let mut output = String::new();
217
218    for line in text.lines() {
219        let Some(data) = line.strip_prefix("data: ") else {
220            continue;
221        };
222        if data.trim() == "[DONE]" {
223            push_sse_frame(&mut output, &serde_json::json!({ "type": "message_stop" }));
224            continue;
225        }
226        let Ok(chunk) = serde_json::from_str::<OpenAiStreamChunk>(data) else {
227            continue;
228        };
229        append_anthropic_frames(&mut output, &chunk, model);
230    }
231
232    Bytes::from(output)
233}
234
235fn append_anthropic_frames(output: &mut String, chunk: &OpenAiStreamChunk, model: &str) {
236    for choice in &chunk.choices {
237        if choice.delta.role.is_some() {
238            let id = chunk.id.as_deref().unwrap_or("msg_openai");
239            push_sse_frame(
240                output,
241                &serde_json::json!({
242                    "type": "message_start",
243                    "message": {
244                        "id": id,
245                        "type": "message",
246                        "role": "assistant",
247                        "model": model,
248                        "usage": { "input_tokens": 0, "output_tokens": 0 },
249                    },
250                }),
251            );
252            push_sse_frame(
253                output,
254                &serde_json::json!({
255                    "type": "content_block_start",
256                    "index": 0,
257                    "content_block": { "type": "text", "text": "" },
258                }),
259            );
260        }
261        if let Some(text) = choice.delta.content.as_deref() {
262            if !text.is_empty() {
263                push_sse_frame(
264                    output,
265                    &serde_json::json!({
266                        "type": "content_block_delta",
267                        "index": 0,
268                        "delta": { "type": "text_delta", "text": text },
269                    }),
270                );
271            }
272        }
273        if let Some(finish) = choice.finish_reason.as_deref() {
274            if !finish.is_empty() && finish != "null" {
275                let stop_reason = if finish == "stop" { "end_turn" } else { finish };
276                push_sse_frame(
277                    output,
278                    &serde_json::json!({
279                        "type": "message_delta",
280                        "delta": { "stop_reason": stop_reason },
281                        "usage": { "output_tokens": 0 },
282                    }),
283                );
284            }
285        }
286    }
287}
288
289fn push_sse_frame(output: &mut String, value: &serde_json::Value) {
290    output.push_str("data: ");
291    if let Ok(encoded) = serde_json::to_string(value) {
292        output.push_str(&encoded);
293    }
294    output.push_str("\n\n");
295}
296
/// Minimal view of an OpenAI streaming chunk; only the fields needed for
/// conversion to Anthropic SSE frames are deserialized, everything else in
/// the payload is ignored.
#[derive(serde::Deserialize)]
struct OpenAiStreamChunk {
    /// Upstream message id, reused as the Anthropic message id when present.
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    choices: Vec<OpenAiStreamChoice>,
}

/// One streamed choice: its incremental delta plus an optional finish reason.
#[derive(serde::Deserialize)]
struct OpenAiStreamChoice {
    #[serde(default)]
    delta: OpenAiStreamDelta,
    /// OpenAI finish reason (e.g. "stop"); mapped to an Anthropic stop reason.
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental message delta: `role` appears on the first chunk of a
/// message, `content` carries appended text.
#[derive(serde::Deserialize, Default)]
struct OpenAiStreamDelta {
    #[serde(default)]
    role: Option<String>,
    #[serde(default)]
    content: Option<String>,
}