Skip to main content

systemprompt_api/services/gateway/service/
mod.rs

1//! Gateway dispatch entry point: route resolution, policy and quota checks,
2//! upstream send, and response finalization.
3#![expect(
4    clippy::clone_on_ref_ptr,
5    reason = "Arc::clone usage is intentional and ergonomic in this gateway dispatch path"
6)]
7
8mod finalize;
9
10use std::sync::Arc;
11
12use anyhow::{Result, anyhow};
13use axum::body::Body;
14use axum::response::Response;
15use bytes::Bytes;
16use systemprompt_database::DbPool;
17use systemprompt_models::profile::GatewayConfig;
18
19use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
20use super::audit::{GatewayAudit, GatewayRequestContext};
21use super::policy::PolicyResolver;
22use super::protocol::canonical::CanonicalRequest;
23use super::protocol::inbound::InboundAdapter;
24use super::protocol::outbound::OutboundCtx;
25use super::quota;
26use super::registry::GatewayUpstreamRegistry;
27
/// Name of the header that carries the gateway's AI request id.
///
/// NOTE(review): presumably written onto responses by
/// `finalize::attach_request_id`; that code is not visible here — confirm.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
29
/// Stateless namespace type for the gateway dispatch entry point.
///
/// It carries no data; [`GatewayService::dispatch`] is an associated function
/// that receives everything it needs as arguments.
#[derive(Debug, Clone, Copy)]
pub struct GatewayService;
32
33#[derive(Debug)]
34pub struct DispatchInputs {
35    pub request: CanonicalRequest,
36    pub raw_body: Bytes,
37    pub ctx: GatewayRequestContext,
38    pub inbound: Arc<dyn InboundAdapter>,
39}
40
41#[derive(Debug, thiserror::Error)]
42#[error("{0}")]
43pub struct PolicyDenied(pub String);
44
45#[derive(Debug, thiserror::Error)]
46#[error("{message}")]
47pub struct QuotaExceeded {
48    pub message: String,
49    pub retry_after_seconds: i32,
50}
51
52impl GatewayService {
53    pub async fn dispatch(
54        config: &GatewayConfig,
55        db: &DbPool,
56        inputs: DispatchInputs,
57    ) -> Result<Response<Body>> {
58        let DispatchInputs {
59            request,
60            raw_body,
61            ctx,
62            inbound,
63        } = inputs;
64        if ctx.session_id.is_none() {
65            return Err(anyhow!(
66                "gateway dispatch missing conversation binding (session_id)"
67            ));
68        }
69
70        let ai_request_id = ctx.ai_request_id.clone();
71
72        if !config.is_model_exposed(&request.model) {
73            tracing::warn!(
74                ai_request_id = %ai_request_id,
75                model = %request.model,
76                "Gateway denied: model not in catalog"
77            );
78            return Err(PolicyDenied(format!(
79                "model '{}' is not permitted by gateway policy",
80                request.model
81            ))
82            .into());
83        }
84
85        let route = config
86            .find_route(&request.model)
87            .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
88
89        let catalog = config.catalog.as_ref().ok_or_else(|| {
90            anyhow!("Gateway catalog is not loaded; cannot resolve provider for route")
91        })?;
92        let provider = route.resolve(&catalog.providers).ok_or_else(|| {
93            anyhow!(
94                "Gateway provider '{}' is not declared in the catalog",
95                route.provider.as_str()
96            )
97        })?;
98
99        let secrets = systemprompt_config::SecretsBootstrap::get()
100            .map_err(|e| anyhow!("Secrets not available: {e}"))?;
101
102        let upstream_api_key = secrets
103            .get(provider.api_key_secret.as_str())
104            .ok_or_else(|| {
105                anyhow!(
106                    "Gateway API key secret '{}' not configured",
107                    provider.api_key_secret.as_str()
108                )
109            })?;
110
111        let upstream = GatewayUpstreamRegistry::global()
112            .get(route.provider.as_str())
113            .ok_or_else(|| {
114                anyhow!(
115                    "Gateway provider '{}' is not registered",
116                    route.provider.as_str()
117                )
118            })?;
119
120        let is_streaming = request.stream;
121
122        tracing::info!(
123            ai_request_id = %ai_request_id,
124            user_id = %ctx.user_id,
125            model = %request.model,
126            provider = %route.provider,
127            upstream = %provider.endpoint,
128            wire_protocol = %ctx.wire_protocol,
129            streaming = is_streaming,
130            "Gateway request dispatched"
131        );
132
133        let resolver = PolicyResolver::new(db)?;
134        let policy = resolver.resolve().await;
135
136        let audit = Arc::new(
137            GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
138        );
139
140        if let Err(e) = audit.open(&request, &raw_body).await {
141            tracing::error!(error = %e, "audit open failed — proceeding without audit row");
142        }
143
144        if let Some(decision) =
145            quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows).await?
146        {
147            if !decision.allow {
148                let msg = format!(
149                    "quota exceeded for window {}s (used {}/{:?})",
150                    decision.window_seconds, decision.state.requests, decision.limit_requests
151                );
152                if let Err(e) = audit.fail(&msg).await {
153                    tracing::warn!(error = %e, "quota audit fail failed");
154                }
155                return Err(QuotaExceeded {
156                    message: msg,
157                    retry_after_seconds: decision.window_seconds,
158                }
159                .into());
160            }
161        }
162
163        run_request_safety_scan(db, &ai_request_id, &request).await;
164
165        let upstream_model = route.effective_upstream_model(&request.model).to_owned();
166        let outbound_ctx = OutboundCtx {
167            route,
168            endpoint: &provider.endpoint,
169            api_key: upstream_api_key,
170            request: &request,
171            upstream_model: &upstream_model,
172        };
173
174        let outcome = match upstream.send(outbound_ctx).await {
175            Ok(o) => o,
176            Err(e) => {
177                if let Err(audit_err) = audit.fail(&e.to_string()).await {
178                    tracing::warn!(error = %audit_err, "upstream audit fail failed");
179                }
180                return Err(e);
181            },
182        };
183
184        let response = finalize(
185            outcome,
186            FinalizeCtx {
187                audit: Arc::clone(&audit),
188                db: db.clone(),
189                ai_request_id: ai_request_id.clone(),
190                policy,
191                inbound,
192                request_model: request.model.clone(),
193            },
194        )
195        .await;
196        Ok(attach_request_id(response, &ai_request_id))
197    }
198}