systemprompt_api/services/gateway/service/
mod.rs1#![expect(
4 clippy::clone_on_ref_ptr,
5 reason = "Arc::clone usage is intentional and ergonomic in this gateway dispatch path"
6)]
7
8mod finalize;
9
10use std::sync::Arc;
11
12use anyhow::{Result, anyhow};
13use axum::body::Body;
14use axum::response::Response;
15use bytes::Bytes;
16use systemprompt_ai::SafetyConfig;
17use systemprompt_database::DbPool;
18use systemprompt_identifiers::AiRequestId;
19use systemprompt_models::profile::{GatewayConfig, ProviderRegistry};
20
21use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
22use super::audit::{GatewayAudit, GatewayRequestContext};
23use super::policy::PolicyResolver;
24use super::protocol::canonical::CanonicalRequest;
25use super::protocol::inbound::InboundAdapter;
26use super::protocol::outbound::OutboundCtx;
27use super::quota;
28use super::registry::GatewayUpstreamRegistry;
29
/// Name of the HTTP header that carries the gateway request id.
///
/// NOTE(review): presumably this is the header written by
/// `finalize::attach_request_id` — confirm in the `finalize` module.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
31
/// Stateless entry point of the gateway.
///
/// The type holds no data; its behaviour is exposed through associated
/// functions such as `GatewayService::dispatch`, which take everything they
/// need as arguments.
#[derive(Debug, Clone, Copy)]
pub struct GatewayService;
34
35#[derive(Debug)]
36pub struct DispatchInputs {
37 pub request: CanonicalRequest,
38 pub raw_body: Bytes,
39 pub ctx: GatewayRequestContext,
40 pub inbound: Arc<dyn InboundAdapter>,
41}
42
43#[derive(Debug, thiserror::Error)]
44pub enum DispatchError {
45 #[error(transparent)]
46 PreAudit(anyhow::Error),
47 #[error(transparent)]
48 Recorded(anyhow::Error),
49}
50
51#[derive(Debug, thiserror::Error)]
52#[error("{0}")]
53pub struct PolicyDenied(pub String);
54
55#[derive(Debug, thiserror::Error)]
56#[error("{message}")]
57pub struct QuotaExceeded {
58 pub message: String,
59 pub retry_after_seconds: i32,
60}
61
62#[derive(Debug, thiserror::Error)]
63#[error("{message}")]
64pub struct SafetyBlocked {
65 pub category: String,
66 pub message: String,
67}
68
69impl GatewayService {
70 pub async fn dispatch(
71 config: &GatewayConfig,
72 registry: &ProviderRegistry,
73 db: &DbPool,
74 inputs: DispatchInputs,
75 ) -> Result<Response<Body>, DispatchError> {
76 let DispatchInputs {
77 request,
78 raw_body,
79 ctx,
80 inbound,
81 } = inputs;
82 if ctx.session_id.is_none() {
83 return Err(DispatchError::PreAudit(anyhow!(
84 "gateway dispatch missing conversation binding (session_id)"
85 )));
86 }
87
88 let ai_request_id = ctx.ai_request_id.clone();
89
90 if !config.is_model_exposed(registry, &request.model) {
91 tracing::warn!(
92 ai_request_id = %ai_request_id,
93 model = %request.model,
94 "Gateway denied: model not exposed by gateway policy or registry"
95 );
96 return Err(DispatchError::PreAudit(
97 PolicyDenied(format!(
98 "model '{}' is not permitted by gateway policy",
99 request.model
100 ))
101 .into(),
102 ));
103 }
104
105 let route = config
106 .resolve_route(registry, &request.model)
107 .ok_or_else(|| {
108 DispatchError::PreAudit(anyhow!(
109 "No gateway route matches model '{}'",
110 request.model
111 ))
112 })?;
113
114 let provider = route.resolve(registry).ok_or_else(|| {
115 DispatchError::PreAudit(anyhow!(
116 "Gateway route '{}' provider '{}' is not declared in profile.providers",
117 route.id.as_str(),
118 route.provider.as_str()
119 ))
120 })?;
121
122 let secrets = systemprompt_config::SecretsBootstrap::get()
123 .map_err(|e| DispatchError::PreAudit(anyhow!("Secrets not available: {e}")))?;
124
125 let upstream_api_key = secrets
126 .get(provider.api_key_secret.as_str())
127 .ok_or_else(|| {
128 DispatchError::PreAudit(anyhow!(
129 "Gateway API key secret '{}' not configured",
130 provider.api_key_secret.as_str()
131 ))
132 })?;
133
134 let upstream = GatewayUpstreamRegistry::global()
135 .get(provider.wire.as_tag())
136 .ok_or_else(|| {
137 DispatchError::PreAudit(anyhow!(
138 "Gateway has no outbound adapter for wire protocol '{}'",
139 provider.wire.as_tag()
140 ))
141 })?;
142
143 let is_streaming = request.stream;
144
145 tracing::info!(
146 ai_request_id = %ai_request_id,
147 user_id = %ctx.user_id,
148 model = %request.model,
149 provider = %route.provider,
150 upstream = %provider.endpoint,
151 wire_protocol = %ctx.wire_protocol,
152 streaming = is_streaming,
153 "Gateway request dispatched"
154 );
155
156 let resolver = PolicyResolver::new(db).map_err(DispatchError::PreAudit)?;
157 let policy = resolver.resolve().await;
158
159 let audit = Arc::new(
160 GatewayAudit::new(db, ctx.clone())
161 .map_err(|e| DispatchError::PreAudit(anyhow!("audit init failed: {e}")))?,
162 );
163
164 if let Err(e) = audit.open(&request, &raw_body).await {
165 tracing::error!(error = %e, "audit open failed — proceeding without audit row");
166 }
167
168 let reservation = quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows)
169 .await
170 .map_err(DispatchError::Recorded)?;
171 if let Some(decision) = reservation {
172 if !decision.allow {
173 let msg = format!(
174 "quota exceeded for window {}s (used {}/{:?})",
175 decision.window_seconds, decision.state.requests, decision.limit_requests
176 );
177 if let Err(e) = audit.fail(&msg).await {
178 tracing::warn!(error = %e, "quota audit fail failed");
179 }
180 return Err(DispatchError::Recorded(
181 QuotaExceeded {
182 message: msg,
183 retry_after_seconds: decision.window_seconds,
184 }
185 .into(),
186 ));
187 }
188 }
189
190 enforce_request_safety(db, &ai_request_id, &request, &policy.safety, &audit).await?;
191
192 let upstream_model = route.effective_upstream_model(&request.model).to_owned();
193 let model_limits = provider.find_model(&upstream_model).map(|m| m.limits);
194 let outbound_ctx = OutboundCtx {
195 route: route.as_ref(),
196 endpoint: &provider.endpoint,
197 api_key: upstream_api_key,
198 request: &request,
199 upstream_model: &upstream_model,
200 model_limits,
201 };
202
203 let outcome = match upstream.send(outbound_ctx).await {
204 Ok(o) => o,
205 Err(e) => {
206 audit_upstream_failure(&audit, upstream.provider_tag(), &request.model, &e).await;
207 return Err(DispatchError::Recorded(e));
208 },
209 };
210
211 let response = finalize(
212 outcome,
213 FinalizeCtx {
214 audit: Arc::clone(&audit),
215 db: db.clone(),
216 ai_request_id: ai_request_id.clone(),
217 policy,
218 inbound,
219 request_model: request.model.clone(),
220 },
221 )
222 .await;
223 Ok(attach_request_id(response, &ai_request_id))
224 }
225}
226
227async fn enforce_request_safety(
228 db: &DbPool,
229 ai_request_id: &AiRequestId,
230 request: &CanonicalRequest,
231 safety: &SafetyConfig,
232 audit: &GatewayAudit,
233) -> Result<(), DispatchError> {
234 let findings = run_request_safety_scan(db, ai_request_id, request, safety).await;
235 let Some(finding) = findings
236 .iter()
237 .find(|f| safety.block_categories.contains(&f.category))
238 else {
239 return Ok(());
240 };
241 let msg = format!(
242 "request blocked by safety policy: category '{}'",
243 finding.category
244 );
245 tracing::warn!(
246 ai_request_id = %ai_request_id,
247 category = %finding.category,
248 scanner = %finding.scanner,
249 "Gateway blocked request by safety policy"
250 );
251 if let Err(e) = audit.fail(&msg).await {
252 tracing::warn!(error = %e, "safety-block audit fail failed");
253 }
254 Err(DispatchError::Recorded(
255 SafetyBlocked {
256 category: finding.category.clone(),
257 message: msg,
258 }
259 .into(),
260 ))
261}
262
263async fn audit_upstream_failure(
264 audit: &GatewayAudit,
265 provider: &str,
266 model: &str,
267 error: &anyhow::Error,
268) {
269 tracing::warn!(
270 provider = %provider,
271 model = %model,
272 error = %error,
273 "gateway upstream call failed"
274 );
275 if let Err(audit_err) = audit.fail(&error.to_string()).await {
276 tracing::warn!(error = %audit_err, "upstream audit fail failed");
277 }
278}