systemprompt_api/services/gateway/service/
mod.rs1#![expect(
4 clippy::clone_on_ref_ptr,
5 reason = "Arc::clone usage is intentional and ergonomic in this gateway dispatch path"
6)]
7
8mod finalize;
9
10use std::sync::Arc;
11
12use anyhow::{Result, anyhow};
13use axum::body::Body;
14use axum::response::Response;
15use bytes::Bytes;
16use systemprompt_database::DbPool;
17use systemprompt_models::profile::{GatewayConfig, ProviderRegistry};
18
19use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
20use super::audit::{GatewayAudit, GatewayRequestContext};
21use super::policy::PolicyResolver;
22use super::protocol::canonical::CanonicalRequest;
23use super::protocol::inbound::InboundAdapter;
24use super::protocol::outbound::OutboundCtx;
25use super::quota;
26use super::registry::GatewayUpstreamRegistry;
27
/// Name of the HTTP header that carries the gateway's request id.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
29
/// Field-less handle that namespaces the gateway dispatch entry point
/// ([`GatewayService::dispatch`]); it carries no state of its own.
#[derive(Clone, Copy, Debug)]
pub struct GatewayService;
32
33#[derive(Debug)]
34pub struct DispatchInputs {
35 pub request: CanonicalRequest,
36 pub raw_body: Bytes,
37 pub ctx: GatewayRequestContext,
38 pub inbound: Arc<dyn InboundAdapter>,
39}
40
41#[derive(Debug, thiserror::Error)]
42#[error("{0}")]
43pub struct PolicyDenied(pub String);
44
45#[derive(Debug, thiserror::Error)]
46#[error("{message}")]
47pub struct QuotaExceeded {
48 pub message: String,
49 pub retry_after_seconds: i32,
50}
51
52#[derive(Debug, thiserror::Error)]
53#[error("{message}")]
54pub struct SafetyBlocked {
55 pub category: String,
56 pub message: String,
57}
58
59impl GatewayService {
60 pub async fn dispatch(
61 config: &GatewayConfig,
62 registry: &ProviderRegistry,
63 db: &DbPool,
64 inputs: DispatchInputs,
65 ) -> Result<Response<Body>> {
66 let DispatchInputs {
67 request,
68 raw_body,
69 ctx,
70 inbound,
71 } = inputs;
72 if ctx.session_id.is_none() {
73 return Err(anyhow!(
74 "gateway dispatch missing conversation binding (session_id)"
75 ));
76 }
77
78 let ai_request_id = ctx.ai_request_id.clone();
79
80 if !config.is_model_exposed(registry, &request.model) {
81 tracing::warn!(
82 ai_request_id = %ai_request_id,
83 model = %request.model,
84 "Gateway denied: model not exposed by gateway policy or registry"
85 );
86 return Err(PolicyDenied(format!(
87 "model '{}' is not permitted by gateway policy",
88 request.model
89 ))
90 .into());
91 }
92
93 let route = config
94 .resolve_route(registry, &request.model)
95 .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
96
97 let provider = route.resolve(registry).ok_or_else(|| {
98 anyhow!(
99 "Gateway route '{}' provider '{}' is not declared in profile.providers",
100 route.id.as_str(),
101 route.provider.as_str()
102 )
103 })?;
104
105 let secrets = systemprompt_config::SecretsBootstrap::get()
106 .map_err(|e| anyhow!("Secrets not available: {e}"))?;
107
108 let upstream_api_key = secrets
109 .get(provider.api_key_secret.as_str())
110 .ok_or_else(|| {
111 anyhow!(
112 "Gateway API key secret '{}' not configured",
113 provider.api_key_secret.as_str()
114 )
115 })?;
116
117 let upstream = GatewayUpstreamRegistry::global()
118 .get(provider.protocol.as_tag())
119 .ok_or_else(|| {
120 anyhow!(
121 "Gateway has no outbound adapter for wire protocol '{}'",
122 provider.protocol.as_tag()
123 )
124 })?;
125
126 let is_streaming = request.stream;
127
128 tracing::info!(
129 ai_request_id = %ai_request_id,
130 user_id = %ctx.user_id,
131 model = %request.model,
132 provider = %route.provider,
133 upstream = %provider.endpoint,
134 wire_protocol = %ctx.wire_protocol,
135 streaming = is_streaming,
136 "Gateway request dispatched"
137 );
138
139 let resolver = PolicyResolver::new(db)?;
140 let policy = resolver.resolve().await;
141
142 let audit = Arc::new(
143 GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
144 );
145
146 if let Err(e) = audit.open(&request, &raw_body).await {
147 tracing::error!(error = %e, "audit open failed — proceeding without audit row");
148 }
149
150 if let Some(decision) =
151 quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows).await?
152 {
153 if !decision.allow {
154 let msg = format!(
155 "quota exceeded for window {}s (used {}/{:?})",
156 decision.window_seconds, decision.state.requests, decision.limit_requests
157 );
158 if let Err(e) = audit.fail(&msg).await {
159 tracing::warn!(error = %e, "quota audit fail failed");
160 }
161 return Err(QuotaExceeded {
162 message: msg,
163 retry_after_seconds: decision.window_seconds,
164 }
165 .into());
166 }
167 }
168
169 let findings = run_request_safety_scan(db, &ai_request_id, &request, &policy.safety).await;
170 if let Some(finding) = findings
171 .iter()
172 .find(|f| policy.safety.block_categories.contains(&f.category))
173 {
174 let msg = format!(
175 "request blocked by safety policy: category '{}'",
176 finding.category
177 );
178 tracing::warn!(
179 ai_request_id = %ai_request_id,
180 category = %finding.category,
181 scanner = %finding.scanner,
182 "Gateway blocked request by safety policy"
183 );
184 if let Err(e) = audit.fail(&msg).await {
185 tracing::warn!(error = %e, "safety-block audit fail failed");
186 }
187 return Err(SafetyBlocked {
188 category: finding.category.clone(),
189 message: msg,
190 }
191 .into());
192 }
193
194 let upstream_model = route.effective_upstream_model(&request.model).to_owned();
195 let model_limits = provider.find_model(&upstream_model).map(|m| m.limits);
196 let outbound_ctx = OutboundCtx {
197 route: route.as_ref(),
198 endpoint: &provider.endpoint,
199 api_key: upstream_api_key,
200 request: &request,
201 upstream_model: &upstream_model,
202 model_limits,
203 };
204
205 let outcome = match upstream.send(outbound_ctx).await {
206 Ok(o) => o,
207 Err(e) => {
208 if let Err(audit_err) = audit.fail(&e.to_string()).await {
209 tracing::warn!(error = %audit_err, "upstream audit fail failed");
210 }
211 return Err(e);
212 },
213 };
214
215 let response = finalize(
216 outcome,
217 FinalizeCtx {
218 audit: Arc::clone(&audit),
219 db: db.clone(),
220 ai_request_id: ai_request_id.clone(),
221 policy,
222 inbound,
223 request_model: request.model.clone(),
224 },
225 )
226 .await;
227 Ok(attach_request_id(response, &ai_request_id))
228 }
229}