systemprompt_api/services/gateway/
service.rs1#![allow(clippy::clone_on_ref_ptr)]
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use axum::body::Body;
6use axum::response::Response;
7use bytes::Bytes;
8use http::HeaderValue;
9use systemprompt_ai::InsertSafetyFinding;
10use systemprompt_ai::repository::AiSafetyFindingRepository;
11use systemprompt_database::DbPool;
12use systemprompt_identifiers::AiRequestId;
13use systemprompt_models::profile::GatewayConfig;
14
15use super::audit::{GatewayAudit, GatewayRequestContext};
16use super::policy::{GatewayPolicySpec, PolicyResolver};
17use super::protocol::canonical::CanonicalRequest;
18use super::protocol::canonical_response::CanonicalResponse;
19use super::protocol::inbound::InboundAdapter;
20use super::protocol::outbound::{OutboundCtx, OutboundOutcome};
21use super::registry::GatewayUpstreamRegistry;
22use super::safety::{HeuristicScanner, SafetyScanner};
23use super::{parse, quota, stream_tap};
24
/// Name of the response header that carries the request's `AiRequestId`.
///
/// Attached to responses produced by `GatewayService::dispatch` when the id is
/// a valid header value.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
26
/// Stateless entry point of the AI gateway.
///
/// Carries no data; everything a request needs is passed to
/// `GatewayService::dispatch`.
#[derive(Debug, Clone, Copy)]
pub struct GatewayService;
29
/// Per-request inputs to `GatewayService::dispatch`, bundled so the dispatch
/// signature stays small.
#[derive(Debug)]
pub struct DispatchInputs {
    /// The client request in canonical form (model, stream flag, ...).
    pub request: CanonicalRequest,
    /// The original request body bytes; handed to the audit row when it is opened.
    pub raw_body: Bytes,
    /// Request context: request id, user, optional tenant, session binding and
    /// wire protocol.
    pub ctx: GatewayRequestContext,
    /// Adapter for the client's wire protocol; renders buffered responses and
    /// supplies the streaming content type.
    pub inbound: Arc<dyn InboundAdapter>,
}
41
42impl GatewayService {
43 pub async fn dispatch(
44 config: &GatewayConfig,
45 db: &DbPool,
46 inputs: DispatchInputs,
47 ) -> Result<Response<Body>> {
48 let DispatchInputs {
49 request,
50 raw_body,
51 ctx,
52 inbound,
53 } = inputs;
54 if ctx.session_id.is_none() {
55 return Err(anyhow!(
56 "gateway dispatch missing conversation binding (session_id)"
57 ));
58 }
59
60 let route = config
61 .find_route(&request.model)
62 .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
63
64 let secrets = systemprompt_config::SecretsBootstrap::get()
65 .map_err(|e| anyhow!("Secrets not available: {e}"))?;
66
67 let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
68 anyhow!(
69 "Gateway API key secret '{}' not configured",
70 route.api_key_secret
71 )
72 })?;
73
74 let upstream = GatewayUpstreamRegistry::global()
75 .get(&route.provider)
76 .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
77
78 let is_streaming = request.stream;
79 let ai_request_id = ctx.ai_request_id.clone();
80
81 tracing::info!(
82 ai_request_id = %ai_request_id,
83 user_id = %ctx.user_id,
84 tenant_id = ctx.tenant_id.as_ref().map_or("-", |t| t.as_str()),
85 model = %request.model,
86 provider = %route.provider,
87 upstream = %route.endpoint,
88 wire_protocol = %ctx.wire_protocol,
89 streaming = is_streaming,
90 "Gateway request dispatched"
91 );
92
93 let resolver = PolicyResolver::new(db)?;
94 let policy = resolver.resolve(ctx.tenant_id.as_ref()).await;
95
96 if !policy.model_allowed(&request.model) {
97 tracing::warn!(
98 ai_request_id = %ai_request_id,
99 model = %request.model,
100 "Gateway policy denied: model not in allowed list"
101 );
102 return Err(PolicyDenied(format!(
103 "model '{}' is not permitted by gateway policy",
104 request.model
105 ))
106 .into());
107 }
108
109 let audit = Arc::new(
110 GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
111 );
112
113 if let Err(e) = audit.open(&request, &raw_body).await {
114 tracing::error!(error = %e, "audit open failed — proceeding without audit row");
115 }
116
117 if let Some(decision) = quota::precheck_and_reserve(
118 db,
119 ctx.tenant_id.as_ref(),
120 &ctx.user_id,
121 &policy.quota_windows,
122 )
123 .await?
124 {
125 if !decision.allow {
126 let msg = format!(
127 "quota exceeded for window {}s (used {}/{:?})",
128 decision.window_seconds, decision.state.requests, decision.limit_requests
129 );
130 if let Err(e) = audit.fail(&msg).await {
131 tracing::warn!(error = %e, "quota audit fail failed");
132 }
133 return Err(QuotaExceeded {
134 message: msg,
135 retry_after_seconds: decision.window_seconds,
136 }
137 .into());
138 }
139 }
140
141 run_request_safety_scan(db, &ai_request_id, &request).await;
142
143 let upstream_model = route.effective_upstream_model(&request.model).to_string();
144 let outbound_ctx = OutboundCtx {
145 route,
146 api_key: upstream_api_key,
147 request: &request,
148 upstream_model: &upstream_model,
149 };
150
151 let outcome = match upstream.send(outbound_ctx).await {
152 Ok(o) => o,
153 Err(e) => {
154 if let Err(audit_err) = audit.fail(&e.to_string()).await {
155 tracing::warn!(error = %audit_err, "upstream audit fail failed");
156 }
157 return Err(e);
158 },
159 };
160
161 let response = finalize(
162 outcome,
163 FinalizeCtx {
164 audit: Arc::clone(&audit),
165 db: db.clone(),
166 ai_request_id: ai_request_id.clone(),
167 policy,
168 inbound,
169 request_model: request.model.clone(),
170 },
171 )
172 .await;
173 Ok(attach_request_id(response, &ai_request_id))
174 }
175}
176
/// State moved from `GatewayService::dispatch` into `finalize` once the
/// upstream call has returned an outcome.
struct FinalizeCtx {
    /// Audit handle. For buffered outcomes it is shared with the background
    /// completion task; for streaming outcomes it is handed to the stream tap.
    audit: Arc<GatewayAudit>,
    /// Pool used on the buffered path for quota token accounting and for
    /// persisting response safety findings.
    db: DbPool,
    /// Request id under which response safety findings are recorded.
    ai_request_id: AiRequestId,
    /// Resolved tenant policy; its `quota_windows` receive the token usage.
    policy: GatewayPolicySpec,
    /// Adapter that renders the response in the client's wire protocol.
    inbound: Arc<dyn InboundAdapter>,
    /// Model name from the client request; passed to the stream tap.
    request_model: String,
}
185
/// Error returned by `GatewayService::dispatch` when the resolved gateway
/// policy does not permit the requested model.
///
/// The wrapped string is the human-readable reason and is also the `Display`
/// output.
// NOTE(review): presumably public so callers can downcast the `anyhow::Error`
// from `dispatch` and pick an HTTP status — confirm at the call site.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct PolicyDenied(pub String);
189
/// Error returned by `GatewayService::dispatch` when a quota window rejects the
/// request. Its `Display` output is `message`.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct QuotaExceeded {
    /// Human-readable description of the exceeded window and its usage.
    pub message: String,
    /// Filled from the rejecting window's `window_seconds`.
    // NOTE(review): this looks like the full window length rather than the
    // time remaining until the window resets — confirm before relying on it
    // for a precise `Retry-After`.
    pub retry_after_seconds: i32,
}
196
197async fn finalize(outcome: OutboundOutcome, fctx: FinalizeCtx) -> Response<Body> {
198 let FinalizeCtx {
199 audit,
200 db,
201 ai_request_id,
202 policy,
203 inbound,
204 request_model,
205 } = fctx;
206 match outcome {
207 OutboundOutcome::Buffered(canonical) => {
208 let body_bytes = inbound.render_response(&canonical);
209 let audit_clone = Arc::clone(&audit);
210 let body_for_task = body_bytes.clone();
211 tokio::spawn(async move {
212 let canonical_for_task = canonical;
213 let served_model = canonical_for_task.model.clone();
214 if !served_model.is_empty() {
215 audit_clone.set_served_model(&served_model).await;
216 }
217 let (usage, tool_calls) = parse::extract_from_canonical(&canonical_for_task);
218 if let Err(e) = audit_clone
219 .complete(usage, tool_calls, &canonical_for_task, &body_for_task)
220 .await
221 {
222 tracing::warn!(error = %e, "buffered audit complete failed");
223 }
224 quota::post_update_tokens(
225 &db,
226 quota::PostUpdateParams {
227 tenant_id: audit_clone.ctx.tenant_id.as_ref(),
228 user_id: &audit_clone.ctx.user_id,
229 windows: &policy.quota_windows,
230 input_tokens: usage.input_tokens,
231 output_tokens: usage.output_tokens,
232 },
233 )
234 .await;
235 run_response_safety_scan(&db, &ai_request_id, &canonical_for_task).await;
236 });
237 Response::builder()
238 .status(http::StatusCode::OK)
239 .header(http::header::CONTENT_TYPE, "application/json")
240 .body(Body::from(body_bytes))
241 .unwrap_or_else(|_| Response::new(Body::empty()))
242 },
243 OutboundOutcome::Streaming(stream) => {
244 let body = stream_tap::tap(stream, Arc::clone(&inbound), request_model, audit);
245 Response::builder()
246 .status(http::StatusCode::OK)
247 .header(http::header::CONTENT_TYPE, inbound.streaming_content_type())
248 .header("cache-control", "no-cache")
249 .header("x-accel-buffering", "no")
250 .body(body)
251 .unwrap_or_else(|_| Response::new(Body::empty()))
252 },
253 }
254}
255
256async fn run_request_safety_scan(
257 db: &DbPool,
258 ai_request_id: &AiRequestId,
259 request: &CanonicalRequest,
260) {
261 let scanner = HeuristicScanner;
262 let findings = scanner.scan_request(request).await;
263 if findings.is_empty() {
264 return;
265 }
266 persist_findings(db, ai_request_id, findings).await;
267}
268
269async fn run_response_safety_scan(
270 db: &DbPool,
271 ai_request_id: &AiRequestId,
272 response: &CanonicalResponse,
273) {
274 let scanner = HeuristicScanner;
275 let findings = scanner.scan_response_final(response).await;
276 if findings.is_empty() {
277 return;
278 }
279 persist_findings(db, ai_request_id, findings).await;
280}
281
282async fn persist_findings(
283 db: &DbPool,
284 ai_request_id: &AiRequestId,
285 findings: Vec<super::safety::Finding>,
286) {
287 let repo = match AiSafetyFindingRepository::new(db) {
288 Ok(r) => r,
289 Err(e) => {
290 tracing::warn!(error = %e, "safety findings repo init failed");
291 return;
292 },
293 };
294 for f in findings {
295 let params = InsertSafetyFinding {
296 ai_request_id,
297 phase: f.phase,
298 severity: f.severity.as_str(),
299 category: &f.category,
300 scanner: f.scanner,
301 excerpt: f.excerpt.as_deref(),
302 };
303 if let Err(e) = repo.insert(params).await {
304 tracing::warn!(error = %e, "safety finding insert failed");
305 }
306 }
307}
308
309fn attach_request_id(mut response: Response<Body>, id: &AiRequestId) -> Response<Body> {
310 if let Ok(v) = HeaderValue::from_str(id.as_str()) {
311 response.headers_mut().insert(REQUEST_ID_HEADER, v);
312 }
313 response
314}