systemprompt_api/services/gateway/
service.rs1#![allow(clippy::clone_on_ref_ptr)]
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use axum::body::Body;
6use axum::response::Response;
7use bytes::Bytes;
8use http::HeaderValue;
9use systemprompt_ai::InsertSafetyFinding;
10use systemprompt_ai::repository::AiSafetyFindingRepository;
11use systemprompt_database::DbPool;
12use systemprompt_identifiers::AiRequestId;
13use systemprompt_models::profile::GatewayConfig;
14
15use super::audit::{GatewayAudit, GatewayRequestContext};
16use super::models::AnthropicGatewayRequest;
17use super::policy::{GatewayPolicySpec, PolicyResolver};
18use super::registry::GatewayUpstreamRegistry;
19use super::safety::{HeuristicScanner, SafetyScanner};
20use super::upstream::{UpstreamCtx, UpstreamOutcome, build_response};
21use super::{parse, quota, stream_tap};
22
/// Name of the HTTP response header under which the gateway returns the AI
/// request id (set on every response produced by `GatewayService::dispatch`).
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
24
/// Stateless entry point for the gateway.
///
/// The type carries no data; its functionality is exposed through associated
/// functions such as `dispatch`, which take everything they need as arguments.
#[derive(Clone, Copy, Debug)]
pub struct GatewayService;
27
28impl GatewayService {
29 pub async fn dispatch(
30 config: &GatewayConfig,
31 request: AnthropicGatewayRequest,
32 raw_body: Bytes,
33 ctx: GatewayRequestContext,
34 db: &DbPool,
35 ) -> Result<Response<Body>> {
36 let route = config
37 .find_route(&request.model)
38 .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
39
40 let secrets = systemprompt_config::SecretsBootstrap::get()
41 .map_err(|e| anyhow!("Secrets not available: {e}"))?;
42
43 let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
44 anyhow!(
45 "Gateway API key secret '{}' not configured",
46 route.api_key_secret
47 )
48 })?;
49
50 let upstream = GatewayUpstreamRegistry::global()
51 .get(&route.provider)
52 .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
53
54 let is_streaming = request.stream.unwrap_or(false);
55 let ai_request_id = ctx.ai_request_id.clone();
56
57 tracing::info!(
58 ai_request_id = %ai_request_id,
59 user_id = %ctx.user_id,
60 tenant_id = ctx.tenant_id.as_ref().map_or("-", |t| t.as_str()),
61 model = %request.model,
62 provider = %route.provider,
63 upstream = %route.endpoint,
64 streaming = is_streaming,
65 "Gateway request dispatched"
66 );
67
68 let resolver = PolicyResolver::new(db)?;
69 let policy = resolver.resolve(ctx.tenant_id.as_ref()).await;
70
71 if !policy.model_allowed(&request.model) {
72 tracing::warn!(
73 ai_request_id = %ai_request_id,
74 model = %request.model,
75 "Gateway policy denied: model not in allowed list"
76 );
77 return Err(PolicyDenied(format!(
78 "model '{}' is not permitted by gateway policy",
79 request.model
80 ))
81 .into());
82 }
83
84 let audit = Arc::new(
85 GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
86 );
87
88 if let Err(e) = audit.open(&request, &raw_body).await {
89 tracing::error!(error = %e, "audit open failed — proceeding without audit row");
90 }
91
92 if let Some(decision) = quota::precheck_and_reserve(
93 db,
94 ctx.tenant_id.as_ref(),
95 &ctx.user_id,
96 &policy.quota_windows,
97 )
98 .await?
99 {
100 if !decision.allow {
101 let msg = format!(
102 "quota exceeded for window {}s (used {}/{:?})",
103 decision.window_seconds, decision.state.requests, decision.limit_requests
104 );
105 if let Err(e) = audit.fail(&msg).await {
106 tracing::warn!(error = %e, "quota audit fail failed");
107 }
108 return Err(QuotaExceeded {
109 message: msg,
110 retry_after_seconds: decision.window_seconds,
111 }
112 .into());
113 }
114 }
115
116 run_request_safety_scan(db, &ai_request_id, &request).await;
117
118 let upstream_ctx = UpstreamCtx {
119 route,
120 api_key: upstream_api_key,
121 raw_body,
122 request: &request,
123 is_streaming,
124 };
125
126 let outcome = match upstream.proxy(upstream_ctx).await {
127 Ok(o) => o,
128 Err(e) => {
129 if let Err(audit_err) = audit.fail(&e.to_string()).await {
130 tracing::warn!(error = %audit_err, "upstream audit fail failed");
131 }
132 return Err(e);
133 },
134 };
135
136 let response = finalize(
137 outcome,
138 Arc::clone(&audit),
139 db.clone(),
140 ai_request_id.clone(),
141 policy,
142 )
143 .await;
144 Ok(attach_request_id(response, &ai_request_id))
145 }
146}
147
148#[derive(Debug, thiserror::Error)]
149#[error("{0}")]
150pub struct PolicyDenied(pub String);
151
152#[derive(Debug, thiserror::Error)]
153#[error("{message}")]
154pub struct QuotaExceeded {
155 pub message: String,
156 pub retry_after_seconds: i32,
157}
158
159async fn finalize(
160 outcome: UpstreamOutcome,
161 audit: Arc<GatewayAudit>,
162 db: DbPool,
163 ai_request_id: AiRequestId,
164 policy: GatewayPolicySpec,
165) -> Response<Body> {
166 match outcome {
167 UpstreamOutcome::Buffered {
168 status,
169 content_type,
170 body,
171 served_model,
172 } => {
173 let body_clone = body.clone();
174 let audit_clone = Arc::clone(&audit);
175 let served_model_clone = served_model.clone();
176 tokio::spawn(async move {
177 if status.is_success() {
178 if let Some(model) = served_model_clone.as_deref() {
179 audit_clone.set_served_model(model).await;
180 }
181 let (usage, tool_calls) = parse::extract_from_anthropic_response(&body_clone);
182 if let Err(e) = audit_clone.complete(usage, tool_calls, &body_clone).await {
183 tracing::warn!(error = %e, "buffered audit complete failed");
184 }
185 quota::post_update_tokens(
186 &db,
187 quota::PostUpdateParams {
188 tenant_id: audit_clone.ctx.tenant_id.as_ref(),
189 user_id: &audit_clone.ctx.user_id,
190 windows: &policy.quota_windows,
191 input_tokens: usage.input_tokens,
192 output_tokens: usage.output_tokens,
193 },
194 )
195 .await;
196 run_response_safety_scan(&db, &ai_request_id, &body_clone).await;
197 } else {
198 let err_msg = format!(
199 "upstream status {}: {}",
200 status.as_u16(),
201 String::from_utf8_lossy(&body_clone)
202 );
203 if let Err(e) = audit_clone.fail(&err_msg).await {
204 tracing::warn!(error = %e, "buffered audit fail update failed");
205 }
206 }
207 });
208 build_response(UpstreamOutcome::Buffered {
209 status,
210 content_type,
211 body,
212 served_model,
213 })
214 },
215 UpstreamOutcome::Streaming { status, stream } => {
216 let body = stream_tap::tap(stream, Arc::clone(&audit));
217 Response::builder()
218 .status(status)
219 .header(http::header::CONTENT_TYPE, "text/event-stream")
220 .header("cache-control", "no-cache")
221 .header("x-accel-buffering", "no")
222 .body(body)
223 .unwrap_or_else(|_| Response::new(Body::empty()))
224 },
225 }
226}
227
228async fn run_request_safety_scan(
229 db: &DbPool,
230 ai_request_id: &AiRequestId,
231 request: &AnthropicGatewayRequest,
232) {
233 let scanner = HeuristicScanner;
234 let findings = scanner.scan_request(request).await;
235 if findings.is_empty() {
236 return;
237 }
238 persist_findings(db, ai_request_id, findings).await;
239}
240
241async fn run_response_safety_scan(db: &DbPool, ai_request_id: &AiRequestId, body: &[u8]) {
242 let scanner = HeuristicScanner;
243 let findings = scanner.scan_response_final(body).await;
244 if findings.is_empty() {
245 return;
246 }
247 persist_findings(db, ai_request_id, findings).await;
248}
249
250async fn persist_findings(
251 db: &DbPool,
252 ai_request_id: &AiRequestId,
253 findings: Vec<super::safety::Finding>,
254) {
255 let repo = match AiSafetyFindingRepository::new(db) {
256 Ok(r) => r,
257 Err(e) => {
258 tracing::warn!(error = %e, "safety findings repo init failed");
259 return;
260 },
261 };
262 for f in findings {
263 let params = InsertSafetyFinding {
264 ai_request_id,
265 phase: f.phase,
266 severity: f.severity.as_str(),
267 category: &f.category,
268 scanner: f.scanner,
269 excerpt: f.excerpt.as_deref(),
270 };
271 if let Err(e) = repo.insert(params).await {
272 tracing::warn!(error = %e, "safety finding insert failed");
273 }
274 }
275}
276
277fn attach_request_id(mut response: Response<Body>, id: &AiRequestId) -> Response<Body> {
278 if let Ok(v) = HeaderValue::from_str(id.as_str()) {
279 response.headers_mut().insert(REQUEST_ID_HEADER, v);
280 }
281 response
282}