systemprompt_api/services/gateway/
service.rs1#![allow(clippy::clone_on_ref_ptr)]
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use axum::body::Body;
6use axum::response::Response;
7use bytes::Bytes;
8use http::HeaderValue;
9use systemprompt_ai::InsertSafetyFinding;
10use systemprompt_ai::repository::AiSafetyFindingRepository;
11use systemprompt_database::DbPool;
12use systemprompt_identifiers::AiRequestId;
13use systemprompt_models::profile::GatewayConfig;
14
15use super::audit::{GatewayAudit, GatewayRequestContext};
16use super::models::AnthropicGatewayRequest;
17use super::policy::{GatewayPolicySpec, PolicyResolver};
18use super::registry::GatewayUpstreamRegistry;
19use super::safety::{HeuristicScanner, SafetyScanner};
20use super::upstream::{UpstreamCtx, UpstreamOutcome, build_response};
21use super::{parse, quota, stream_tap};
22
/// Name of the HTTP response header in which the gateway returns the AI
/// request id of the dispatched call (see `attach_request_id` in this file).
23pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
24
25#[derive(Debug, Clone, Copy)]
26pub struct GatewayService;
27
28impl GatewayService {
29 pub async fn dispatch(
30 config: &GatewayConfig,
31 request: AnthropicGatewayRequest,
32 raw_body: Bytes,
33 ctx: GatewayRequestContext,
34 db: &DbPool,
35 ) -> Result<Response<Body>> {
36 let route = config
37 .find_route(&request.model)
38 .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
39
40 let secrets = systemprompt_models::SecretsBootstrap::get()
41 .map_err(|e| anyhow!("Secrets not available: {e}"))?;
42
43 let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
44 anyhow!(
45 "Gateway API key secret '{}' not configured",
46 route.api_key_secret
47 )
48 })?;
49
50 let upstream = GatewayUpstreamRegistry::global()
51 .get(&route.provider)
52 .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
53
54 let is_streaming = request.stream.unwrap_or(false);
55 let ai_request_id = ctx.ai_request_id.clone();
56
57 tracing::info!(
58 ai_request_id = %ai_request_id,
59 user_id = %ctx.user_id,
60 tenant_id = ctx.tenant_id.as_ref().map_or("-", |t| t.as_str()),
61 model = %request.model,
62 provider = %route.provider,
63 upstream = %route.endpoint,
64 streaming = is_streaming,
65 "Gateway request dispatched"
66 );
67
68 let resolver = PolicyResolver::new(db)?;
69 let policy = resolver.resolve(ctx.tenant_id.as_ref()).await;
70
71 if !policy.model_allowed(&request.model) {
72 tracing::warn!(
73 ai_request_id = %ai_request_id,
74 model = %request.model,
75 "Gateway policy denied: model not in allowed list"
76 );
77 return Err(PolicyDenied(format!(
78 "model '{}' is not permitted by gateway policy",
79 request.model
80 ))
81 .into());
82 }
83
84 let audit = Arc::new(
85 GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
86 );
87
88 if let Err(e) = audit.open(&raw_body).await {
89 tracing::error!(error = %e, "audit open failed — proceeding without audit row");
90 }
91
92 if let Some(decision) = quota::precheck_and_reserve(
93 db,
94 ctx.tenant_id.as_ref(),
95 &ctx.user_id,
96 &policy.quota_windows,
97 )
98 .await?
99 {
100 if !decision.allow {
101 let msg = format!(
102 "quota exceeded for window {}s (used {}/{:?})",
103 decision.window_seconds, decision.state.requests, decision.limit_requests
104 );
105 let _ = audit.fail(&msg).await;
106 return Err(QuotaExceeded {
107 message: msg,
108 retry_after_seconds: decision.window_seconds,
109 }
110 .into());
111 }
112 }
113
114 run_request_safety_scan(db, &ai_request_id, &request).await;
115
116 let upstream_ctx = UpstreamCtx {
117 route,
118 api_key: upstream_api_key,
119 raw_body,
120 request: &request,
121 is_streaming,
122 };
123
124 let outcome = match upstream.proxy(upstream_ctx).await {
125 Ok(o) => o,
126 Err(e) => {
127 let _ = audit.fail(&e.to_string()).await;
128 return Err(e);
129 },
130 };
131
132 let response = finalize(
133 outcome,
134 Arc::clone(&audit),
135 db.clone(),
136 ai_request_id.clone(),
137 policy,
138 )
139 .await;
140 Ok(attach_request_id(response, &ai_request_id))
141 }
142}
143
/// Error returned by `GatewayService::dispatch` when the resolved gateway
/// policy does not allow the requested model. The wrapped string is the
/// message shown by `Display`.
144#[derive(Debug, thiserror::Error)]
145#[error("{0}")]
146pub struct PolicyDenied(pub String);
147
/// Error returned by `GatewayService::dispatch` when the quota precheck
/// denies the request. `Display` prints `message`.
148#[derive(Debug, thiserror::Error)]
149#[error("{message}")]
150pub struct QuotaExceeded {
 /// Human-readable description of the exceeded quota window.
151 pub message: String,
 /// `dispatch` sets this to the `window_seconds` of the quota decision that
 /// denied the request.
152 pub retry_after_seconds: i32,
153}
154
155async fn finalize(
156 outcome: UpstreamOutcome,
157 audit: Arc<GatewayAudit>,
158 db: DbPool,
159 ai_request_id: AiRequestId,
160 policy: GatewayPolicySpec,
161) -> Response<Body> {
162 match outcome {
163 UpstreamOutcome::Buffered {
164 status,
165 content_type,
166 body,
167 } => {
168 let body_clone = body.clone();
169 let audit_clone = Arc::clone(&audit);
170 tokio::spawn(async move {
171 if status.is_success() {
172 let (usage, tool_calls) = parse::extract_from_anthropic_response(&body_clone);
173 if let Err(e) = audit_clone.complete(usage, tool_calls, &body_clone).await {
174 tracing::warn!(error = %e, "buffered audit complete failed");
175 }
176 quota::post_update_tokens(
177 &db,
178 quota::PostUpdateParams {
179 tenant_id: audit_clone.ctx.tenant_id.as_ref(),
180 user_id: &audit_clone.ctx.user_id,
181 windows: &policy.quota_windows,
182 input_tokens: usage.input_tokens,
183 output_tokens: usage.output_tokens,
184 },
185 )
186 .await;
187 run_response_safety_scan(&db, &ai_request_id, &body_clone).await;
188 } else {
189 let err_msg = format!(
190 "upstream status {}: {}",
191 status.as_u16(),
192 String::from_utf8_lossy(&body_clone)
193 );
194 if let Err(e) = audit_clone.fail(&err_msg).await {
195 tracing::warn!(error = %e, "buffered audit fail update failed");
196 }
197 }
198 });
199 build_response(UpstreamOutcome::Buffered {
200 status,
201 content_type,
202 body,
203 })
204 },
205 UpstreamOutcome::Streaming { status, stream } => {
206 let body = stream_tap::tap(stream, Arc::clone(&audit));
207 Response::builder()
208 .status(status)
209 .header(http::header::CONTENT_TYPE, "text/event-stream")
210 .header("cache-control", "no-cache")
211 .header("x-accel-buffering", "no")
212 .body(body)
213 .unwrap_or_else(|_| Response::new(Body::empty()))
214 },
215 }
216}
217
218async fn run_request_safety_scan(
219 db: &DbPool,
220 ai_request_id: &AiRequestId,
221 request: &AnthropicGatewayRequest,
222) {
223 let scanner = HeuristicScanner;
224 let findings = scanner.scan_request(request).await;
225 if findings.is_empty() {
226 return;
227 }
228 persist_findings(db, ai_request_id, findings).await;
229}
230
231async fn run_response_safety_scan(db: &DbPool, ai_request_id: &AiRequestId, body: &[u8]) {
232 let scanner = HeuristicScanner;
233 let findings = scanner.scan_response_final(body).await;
234 if findings.is_empty() {
235 return;
236 }
237 persist_findings(db, ai_request_id, findings).await;
238}
239
240async fn persist_findings(
241 db: &DbPool,
242 ai_request_id: &AiRequestId,
243 findings: Vec<super::safety::Finding>,
244) {
245 let repo = match AiSafetyFindingRepository::new(db) {
246 Ok(r) => r,
247 Err(e) => {
248 tracing::warn!(error = %e, "safety findings repo init failed");
249 return;
250 },
251 };
252 for f in findings {
253 let params = InsertSafetyFinding {
254 ai_request_id,
255 phase: f.phase,
256 severity: f.severity.as_str(),
257 category: &f.category,
258 scanner: f.scanner,
259 excerpt: f.excerpt.as_deref(),
260 };
261 if let Err(e) = repo.insert(params).await {
262 tracing::warn!(error = %e, "safety finding insert failed");
263 }
264 }
265}
266
267fn attach_request_id(mut response: Response<Body>, id: &AiRequestId) -> Response<Body> {
268 if let Ok(v) = HeaderValue::from_str(id.as_str()) {
269 response.headers_mut().insert(REQUEST_ID_HEADER, v);
270 }
271 response
272}