systemprompt_api/services/gateway/service/
mod.rs1#![expect(
4 clippy::clone_on_ref_ptr,
5 reason = "Arc::clone usage is intentional and ergonomic in this gateway dispatch path"
6)]
7
8mod finalize;
9
10use std::sync::Arc;
11
12use anyhow::{Result, anyhow};
13use axum::body::Body;
14use axum::response::Response;
15use bytes::Bytes;
16use systemprompt_database::DbPool;
17use systemprompt_models::profile::{GatewayConfig, ProviderRegistry};
18
19use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
20use super::audit::{GatewayAudit, GatewayRequestContext};
21use super::policy::PolicyResolver;
22use super::protocol::canonical::CanonicalRequest;
23use super::protocol::inbound::InboundAdapter;
24use super::protocol::outbound::OutboundCtx;
25use super::quota;
26use super::registry::GatewayUpstreamRegistry;
27
/// HTTP header name under which the gateway request id is exposed
/// (presumably written by `finalize::attach_request_id` — confirm).
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";

/// Zero-sized handle for the gateway dispatch pipeline; it carries no state and
/// its entry point is the associated `dispatch` function.
#[derive(Clone, Copy, Debug)]
pub struct GatewayService;
32
33#[derive(Debug)]
34pub struct DispatchInputs {
35 pub request: CanonicalRequest,
36 pub raw_body: Bytes,
37 pub ctx: GatewayRequestContext,
38 pub inbound: Arc<dyn InboundAdapter>,
39}
40
41#[derive(Debug, thiserror::Error)]
42#[error("{0}")]
43pub struct PolicyDenied(pub String);
44
45#[derive(Debug, thiserror::Error)]
46#[error("{message}")]
47pub struct QuotaExceeded {
48 pub message: String,
49 pub retry_after_seconds: i32,
50}
51
52impl GatewayService {
53 pub async fn dispatch(
54 config: &GatewayConfig,
55 registry: &ProviderRegistry,
56 db: &DbPool,
57 inputs: DispatchInputs,
58 ) -> Result<Response<Body>> {
59 let DispatchInputs {
60 request,
61 raw_body,
62 ctx,
63 inbound,
64 } = inputs;
65 if ctx.session_id.is_none() {
66 return Err(anyhow!(
67 "gateway dispatch missing conversation binding (session_id)"
68 ));
69 }
70
71 let ai_request_id = ctx.ai_request_id.clone();
72
73 if !config.is_model_exposed(registry, &request.model) {
74 tracing::warn!(
75 ai_request_id = %ai_request_id,
76 model = %request.model,
77 "Gateway denied: model not exposed by gateway policy or registry"
78 );
79 return Err(PolicyDenied(format!(
80 "model '{}' is not permitted by gateway policy",
81 request.model
82 ))
83 .into());
84 }
85
86 let route = config
87 .resolve_route(registry, &request.model)
88 .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
89
90 let provider = route.resolve(registry).ok_or_else(|| {
91 anyhow!(
92 "Gateway route '{}' provider '{}' is not declared in profile.providers",
93 route.id.as_str(),
94 route.provider.as_str()
95 )
96 })?;
97
98 let secrets = systemprompt_config::SecretsBootstrap::get()
99 .map_err(|e| anyhow!("Secrets not available: {e}"))?;
100
101 let upstream_api_key = secrets
102 .get(provider.api_key_secret.as_str())
103 .ok_or_else(|| {
104 anyhow!(
105 "Gateway API key secret '{}' not configured",
106 provider.api_key_secret.as_str()
107 )
108 })?;
109
110 let upstream = GatewayUpstreamRegistry::global()
111 .get(provider.protocol.as_tag())
112 .ok_or_else(|| {
113 anyhow!(
114 "Gateway has no outbound adapter for wire protocol '{}'",
115 provider.protocol.as_tag()
116 )
117 })?;
118
119 let is_streaming = request.stream;
120
121 tracing::info!(
122 ai_request_id = %ai_request_id,
123 user_id = %ctx.user_id,
124 model = %request.model,
125 provider = %route.provider,
126 upstream = %provider.endpoint,
127 wire_protocol = %ctx.wire_protocol,
128 streaming = is_streaming,
129 "Gateway request dispatched"
130 );
131
132 let resolver = PolicyResolver::new(db)?;
133 let policy = resolver.resolve().await;
134
135 let audit = Arc::new(
136 GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
137 );
138
139 if let Err(e) = audit.open(&request, &raw_body).await {
140 tracing::error!(error = %e, "audit open failed — proceeding without audit row");
141 }
142
143 if let Some(decision) =
144 quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows).await?
145 {
146 if !decision.allow {
147 let msg = format!(
148 "quota exceeded for window {}s (used {}/{:?})",
149 decision.window_seconds, decision.state.requests, decision.limit_requests
150 );
151 if let Err(e) = audit.fail(&msg).await {
152 tracing::warn!(error = %e, "quota audit fail failed");
153 }
154 return Err(QuotaExceeded {
155 message: msg,
156 retry_after_seconds: decision.window_seconds,
157 }
158 .into());
159 }
160 }
161
162 run_request_safety_scan(db, &ai_request_id, &request).await;
163
164 let upstream_model = route.effective_upstream_model(&request.model).to_owned();
165 let outbound_ctx = OutboundCtx {
166 route: route.as_ref(),
167 endpoint: &provider.endpoint,
168 api_key: upstream_api_key,
169 request: &request,
170 upstream_model: &upstream_model,
171 };
172
173 let outcome = match upstream.send(outbound_ctx).await {
174 Ok(o) => o,
175 Err(e) => {
176 if let Err(audit_err) = audit.fail(&e.to_string()).await {
177 tracing::warn!(error = %audit_err, "upstream audit fail failed");
178 }
179 return Err(e);
180 },
181 };
182
183 let response = finalize(
184 outcome,
185 FinalizeCtx {
186 audit: Arc::clone(&audit),
187 db: db.clone(),
188 ai_request_id: ai_request_id.clone(),
189 policy,
190 inbound,
191 request_model: request.model.clone(),
192 },
193 )
194 .await;
195 Ok(attach_request_id(response, &ai_request_id))
196 }
197}