// systemprompt_api/services/gateway/upstream.rs
mod sse;
2
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use async_trait::async_trait;
7use axum::body::Body;
8use axum::response::Response;
9use bytes::Bytes;
10use futures_util::StreamExt;
11use http::header::CONTENT_TYPE;
12use systemprompt_models::profile::GatewayRoute;
13
14use super::converter;
15use super::models::{AnthropicGatewayRequest, AnthropicGatewayResponse};
16
/// Borrowed per-request context handed to a [`GatewayUpstream`] implementation.
#[derive(Debug)]
pub struct UpstreamCtx<'a> {
    /// Route describing the upstream endpoint, model mapping, and extra headers.
    pub route: &'a GatewayRoute,
    /// Credential forwarded to the upstream; the header scheme depends on the
    /// implementation (`x-api-key` for Anthropic, `Bearer` for OpenAI-compatible).
    pub api_key: &'a str,
    /// Raw request body bytes, forwarded verbatim when no rewriting is needed.
    pub raw_body: Bytes,
    /// Parsed Anthropic-format request, used for model resolution and conversion.
    pub request: &'a AnthropicGatewayRequest,
    /// Whether the client asked for a streaming (SSE) response.
    pub is_streaming: bool,
}
25
/// Result of proxying a request to an upstream.
#[allow(missing_debug_implementations)]
pub enum UpstreamOutcome {
    /// Whole response body read into memory.
    Buffered {
        status: http::StatusCode,
        /// Content type reported by the upstream (falls back to JSON).
        content_type: String,
        body: Bytes,
        /// Model name the upstream reports having served, when extractable.
        served_model: Option<String>,
    },
    /// Response forwarded incrementally as a byte stream (SSE).
    Streaming {
        status: http::StatusCode,
        stream: futures_util::stream::BoxStream<'static, Result<Bytes, std::io::Error>>,
    },
}
39
40impl UpstreamOutcome {
41 pub const fn status(&self) -> http::StatusCode {
42 match self {
43 Self::Buffered { status, .. } | Self::Streaming { status, .. } => *status,
44 }
45 }
46}
47
/// Proxies an Anthropic-format gateway request to a concrete upstream API.
#[async_trait]
pub trait GatewayUpstream: Send + Sync {
    /// Forwards the request described by `ctx`, returning either a buffered
    /// or streaming outcome.
    ///
    /// # Errors
    /// Returns an error when the upstream request fails or its response
    /// cannot be read or converted.
    async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome>;
}
52
/// Upstream that speaks the Anthropic Messages API natively; requests and
/// responses are passed through with minimal rewriting.
#[derive(Debug, Clone, Copy)]
pub struct AnthropicCompatibleUpstream;
55
56#[async_trait]
57impl GatewayUpstream for AnthropicCompatibleUpstream {
58 async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
59 let client = reqwest::Client::new();
60 let url = format!("{}/messages", ctx.route.endpoint.trim_end_matches('/'));
61
62 let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
63 let body_to_send: Bytes = if upstream_model == ctx.request.model {
64 ctx.raw_body
65 } else {
66 super::flatten::rewrite_request_model(ctx.raw_body, upstream_model)?
67 };
68
69 let mut req = client
70 .post(&url)
71 .header("x-api-key", ctx.api_key)
72 .header("anthropic-version", "2023-06-01")
73 .header("content-type", "application/json")
74 .body(body_to_send);
75
76 for (name, value) in &ctx.route.extra_headers {
77 req = req.header(name.as_str(), value.as_str());
78 }
79
80 let upstream_response = req
81 .send()
82 .await
83 .map_err(|e| anyhow!("Upstream Anthropic request failed: {e}"))?;
84
85 let status = upstream_response.status();
86 let upstream_content_type = upstream_response
87 .headers()
88 .get(CONTENT_TYPE)
89 .and_then(|v| v.to_str().ok())
90 .unwrap_or("application/json")
91 .to_string();
92
93 if ctx.is_streaming {
94 let stream = upstream_response.bytes_stream().map(|chunk| {
95 chunk.map_err(|e| {
96 tracing::warn!(error = %e, "Anthropic stream chunk error");
97 std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
98 })
99 });
100 return Ok(UpstreamOutcome::Streaming {
101 status,
102 stream: stream.boxed(),
103 });
104 }
105
106 let response_bytes = upstream_response
107 .bytes()
108 .await
109 .map_err(|e| anyhow!("Failed to read Anthropic response: {e}"))?;
110
111 let served_model = super::flatten::parse_served_model(&response_bytes);
112
113 Ok(UpstreamOutcome::Buffered {
114 status,
115 content_type: upstream_content_type,
116 body: response_bytes,
117 served_model,
118 })
119 }
120}
121
/// Upstream that speaks the OpenAI Chat Completions API; requests and
/// responses are converted to and from the Anthropic format.
#[derive(Debug, Clone, Copy)]
pub struct OpenAiCompatibleUpstream;
124
125#[async_trait]
126impl GatewayUpstream for OpenAiCompatibleUpstream {
127 async fn proxy(&self, ctx: UpstreamCtx<'_>) -> Result<UpstreamOutcome> {
128 let upstream_model = ctx.route.effective_upstream_model(&ctx.request.model);
129 let openai_request = converter::to_openai_request(ctx.request, upstream_model);
130 let model = ctx.request.model.as_str();
131
132 let client = reqwest::Client::new();
133 let url = format!(
134 "{}/chat/completions",
135 ctx.route.endpoint.trim_end_matches('/')
136 );
137
138 let mut req = client
139 .post(&url)
140 .header("authorization", format!("Bearer {}", ctx.api_key))
141 .header("content-type", "application/json")
142 .json(&openai_request);
143
144 for (name, value) in &ctx.route.extra_headers {
145 req = req.header(name.as_str(), value.as_str());
146 }
147
148 let upstream_response = req
149 .send()
150 .await
151 .map_err(|e| anyhow!("Upstream OpenAI-compatible request failed: {e}"))?;
152
153 let status = upstream_response.status();
154
155 if ctx.is_streaming {
156 if !status.is_success() {
157 let err = upstream_response.text().await.unwrap_or_default();
158 return Err(anyhow!("Upstream error {status}: {err}"));
159 }
160
161 let model_str = model.to_string();
162 let stream = upstream_response
163 .bytes_stream()
164 .map(move |chunk| match chunk {
165 Ok(bytes) => Ok(sse::openai_sse_to_anthropic_sse(&bytes, &model_str)),
166 Err(e) => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)),
167 });
168 return Ok(UpstreamOutcome::Streaming {
169 status: http::StatusCode::OK,
170 stream: stream.boxed(),
171 });
172 }
173
174 if !status.is_success() {
175 let err = upstream_response.text().await.unwrap_or_default();
176 return Err(anyhow!("Upstream error {status}: {err}"));
177 }
178
179 let openai_resp: super::models::OpenAiGatewayResponse = upstream_response
180 .json()
181 .await
182 .map_err(|e| anyhow!("Failed to deserialize upstream response: {e}"))?;
183
184 let served_model = Some(openai_resp.model.clone());
185
186 let anthropic_resp: AnthropicGatewayResponse =
187 converter::from_openai_response(openai_resp, model);
188
189 let body_bytes =
190 serde_json::to_vec(&anthropic_resp).map_err(|e| anyhow!("Serialization error: {e}"))?;
191
192 Ok(UpstreamOutcome::Buffered {
193 status: http::StatusCode::OK,
194 content_type: "application/json".to_string(),
195 body: Bytes::from(body_bytes),
196 served_model,
197 })
198 }
199}
200
/// Inventory registration record pairing a tag with a factory for the
/// corresponding [`GatewayUpstream`] implementation.
#[derive(Debug, Clone, Copy)]
pub struct GatewayUpstreamRegistration {
    // Identifier for this upstream kind; presumably matched against route
    // configuration by the registry consumer — confirm at the call site.
    pub tag: &'static str,
    // Factory producing a shared upstream instance for this tag.
    pub factory: fn() -> Arc<dyn GatewayUpstream>,
}
206
// Gather every `GatewayUpstreamRegistration` submitted across the binary via
// the `inventory` crate's link-time collection.
inventory::collect!(GatewayUpstreamRegistration);
208
209pub fn build_response(outcome: UpstreamOutcome) -> Response<Body> {
210 match outcome {
211 UpstreamOutcome::Buffered {
212 status,
213 content_type,
214 body,
215 served_model: _,
216 } => Response::builder()
217 .status(status)
218 .header(CONTENT_TYPE, content_type)
219 .body(Body::from(body))
220 .unwrap_or_else(|_| Response::new(Body::empty())),
221 UpstreamOutcome::Streaming { status, stream } => Response::builder()
222 .status(status)
223 .header(CONTENT_TYPE, "text/event-stream")
224 .header("cache-control", "no-cache")
225 .header("x-accel-buffering", "no")
226 .body(Body::from_stream(stream))
227 .unwrap_or_else(|_| Response::new(Body::empty())),
228 }
229}