Skip to main content

systemprompt_api/services/gateway/service/
mod.rs

1//! Gateway dispatch entry point: route resolution, policy and quota checks,
2//! upstream send, and response finalization.
3#![expect(
4    clippy::clone_on_ref_ptr,
5    reason = "Arc::clone usage is intentional and ergonomic in this gateway dispatch path"
6)]
7
8mod finalize;
9
10use std::sync::Arc;
11
12use anyhow::{Result, anyhow};
13use axum::body::Body;
14use axum::response::Response;
15use bytes::Bytes;
16use systemprompt_ai::SafetyConfig;
17use systemprompt_database::DbPool;
18use systemprompt_identifiers::AiRequestId;
19use systemprompt_models::profile::{GatewayConfig, ProviderRegistry};
20
21use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
22use super::audit::{GatewayAudit, GatewayRequestContext};
23use super::policy::PolicyResolver;
24use super::protocol::canonical::CanonicalRequest;
25use super::protocol::inbound::InboundAdapter;
26use super::protocol::outbound::OutboundCtx;
27use super::quota;
28use super::registry::GatewayUpstreamRegistry;
29
/// HTTP header name used to carry the gateway's AI request id.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";

/// Stateless (field-less) handle for gateway dispatch; see
/// [`GatewayService::dispatch`].
#[derive(Clone, Copy, Debug)]
pub struct GatewayService;
34
35#[derive(Debug)]
36pub struct DispatchInputs {
37    pub request: CanonicalRequest,
38    pub raw_body: Bytes,
39    pub ctx: GatewayRequestContext,
40    pub inbound: Arc<dyn InboundAdapter>,
41}
42
43#[derive(Debug, thiserror::Error)]
44pub enum DispatchError {
45    #[error(transparent)]
46    PreAudit(anyhow::Error),
47    #[error(transparent)]
48    Recorded(anyhow::Error),
49}
50
51#[derive(Debug, thiserror::Error)]
52#[error("{0}")]
53pub struct PolicyDenied(pub String);
54
55#[derive(Debug, thiserror::Error)]
56#[error("{message}")]
57pub struct QuotaExceeded {
58    pub message: String,
59    pub retry_after_seconds: i32,
60}
61
62#[derive(Debug, thiserror::Error)]
63#[error("{message}")]
64pub struct SafetyBlocked {
65    pub category: String,
66    pub message: String,
67}
68
69impl GatewayService {
70    pub async fn dispatch(
71        config: &GatewayConfig,
72        registry: &ProviderRegistry,
73        db: &DbPool,
74        inputs: DispatchInputs,
75    ) -> Result<Response<Body>, DispatchError> {
76        let DispatchInputs {
77            request,
78            raw_body,
79            ctx,
80            inbound,
81        } = inputs;
82        if ctx.session_id.is_none() {
83            return Err(DispatchError::PreAudit(anyhow!(
84                "gateway dispatch missing conversation binding (session_id)"
85            )));
86        }
87
88        let ai_request_id = ctx.ai_request_id.clone();
89
90        if !config.is_model_exposed(registry, &request.model) {
91            tracing::warn!(
92                ai_request_id = %ai_request_id,
93                model = %request.model,
94                "Gateway denied: model not exposed by gateway policy or registry"
95            );
96            return Err(DispatchError::PreAudit(
97                PolicyDenied(format!(
98                    "model '{}' is not permitted by gateway policy",
99                    request.model
100                ))
101                .into(),
102            ));
103        }
104
105        let route = config
106            .resolve_route(registry, &request.model)
107            .ok_or_else(|| {
108                DispatchError::PreAudit(anyhow!(
109                    "No gateway route matches model '{}'",
110                    request.model
111                ))
112            })?;
113
114        let provider = route.resolve(registry).ok_or_else(|| {
115            DispatchError::PreAudit(anyhow!(
116                "Gateway route '{}' provider '{}' is not declared in profile.providers",
117                route.id.as_str(),
118                route.provider.as_str()
119            ))
120        })?;
121
122        let secrets = systemprompt_config::SecretsBootstrap::get()
123            .map_err(|e| DispatchError::PreAudit(anyhow!("Secrets not available: {e}")))?;
124
125        let upstream_api_key = secrets
126            .get(provider.api_key_secret.as_str())
127            .ok_or_else(|| {
128                DispatchError::PreAudit(anyhow!(
129                    "Gateway API key secret '{}' not configured",
130                    provider.api_key_secret.as_str()
131                ))
132            })?;
133
134        let upstream = GatewayUpstreamRegistry::global()
135            .get(provider.wire.as_tag())
136            .ok_or_else(|| {
137                DispatchError::PreAudit(anyhow!(
138                    "Gateway has no outbound adapter for wire protocol '{}'",
139                    provider.wire.as_tag()
140                ))
141            })?;
142
143        let is_streaming = request.stream;
144
145        tracing::info!(
146            ai_request_id = %ai_request_id,
147            user_id = %ctx.user_id,
148            model = %request.model,
149            provider = %route.provider,
150            upstream = %provider.endpoint,
151            wire_protocol = %ctx.wire_protocol,
152            streaming = is_streaming,
153            "Gateway request dispatched"
154        );
155
156        let resolver = PolicyResolver::new(db).map_err(DispatchError::PreAudit)?;
157        let policy = resolver.resolve().await;
158
159        let audit = Arc::new(
160            GatewayAudit::new(db, ctx.clone())
161                .map_err(|e| DispatchError::PreAudit(anyhow!("audit init failed: {e}")))?,
162        );
163
164        if let Err(e) = audit.open(&request, &raw_body).await {
165            tracing::error!(error = %e, "audit open failed — proceeding without audit row");
166        }
167
168        let reservation = quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows)
169            .await
170            .map_err(DispatchError::Recorded)?;
171        if let Some(decision) = reservation {
172            if !decision.allow {
173                let msg = format!(
174                    "quota exceeded for window {}s (used {}/{:?})",
175                    decision.window_seconds, decision.state.requests, decision.limit_requests
176                );
177                if let Err(e) = audit.fail(&msg).await {
178                    tracing::warn!(error = %e, "quota audit fail failed");
179                }
180                return Err(DispatchError::Recorded(
181                    QuotaExceeded {
182                        message: msg,
183                        retry_after_seconds: decision.window_seconds,
184                    }
185                    .into(),
186                ));
187            }
188        }
189
190        enforce_request_safety(db, &ai_request_id, &request, &policy.safety, &audit).await?;
191
192        let upstream_model = route.effective_upstream_model(&request.model).to_owned();
193        let model_limits = provider.find_model(&upstream_model).map(|m| m.limits);
194        let outbound_ctx = OutboundCtx {
195            route: route.as_ref(),
196            endpoint: &provider.endpoint,
197            api_key: upstream_api_key,
198            request: &request,
199            upstream_model: &upstream_model,
200            model_limits,
201        };
202
203        let outcome = match upstream.send(outbound_ctx).await {
204            Ok(o) => o,
205            Err(e) => {
206                audit_upstream_failure(&audit, upstream.provider_tag(), &request.model, &e).await;
207                return Err(DispatchError::Recorded(e));
208            },
209        };
210
211        let response = finalize(
212            outcome,
213            FinalizeCtx {
214                audit: Arc::clone(&audit),
215                db: db.clone(),
216                ai_request_id: ai_request_id.clone(),
217                policy,
218                inbound,
219                request_model: request.model.clone(),
220            },
221        )
222        .await;
223        Ok(attach_request_id(response, &ai_request_id))
224    }
225}
226
227async fn enforce_request_safety(
228    db: &DbPool,
229    ai_request_id: &AiRequestId,
230    request: &CanonicalRequest,
231    safety: &SafetyConfig,
232    audit: &GatewayAudit,
233) -> Result<(), DispatchError> {
234    let findings = run_request_safety_scan(db, ai_request_id, request, safety).await;
235    let Some(finding) = findings
236        .iter()
237        .find(|f| safety.block_categories.contains(&f.category))
238    else {
239        return Ok(());
240    };
241    let msg = format!(
242        "request blocked by safety policy: category '{}'",
243        finding.category
244    );
245    tracing::warn!(
246        ai_request_id = %ai_request_id,
247        category = %finding.category,
248        scanner = %finding.scanner,
249        "Gateway blocked request by safety policy"
250    );
251    if let Err(e) = audit.fail(&msg).await {
252        tracing::warn!(error = %e, "safety-block audit fail failed");
253    }
254    Err(DispatchError::Recorded(
255        SafetyBlocked {
256            category: finding.category.clone(),
257            message: msg,
258        }
259        .into(),
260    ))
261}
262
263async fn audit_upstream_failure(
264    audit: &GatewayAudit,
265    provider: &str,
266    model: &str,
267    error: &anyhow::Error,
268) {
269    tracing::warn!(
270        provider = %provider,
271        model = %model,
272        error = %error,
273        "gateway upstream call failed"
274    );
275    if let Err(audit_err) = audit.fail(&error.to_string()).await {
276        tracing::warn!(error = %audit_err, "upstream audit fail failed");
277    }
278}