Skip to main content

systemprompt_api/services/gateway/service/
mod.rs

1//! Gateway dispatch entry point: route resolution, policy and quota checks,
2//! upstream send, and response finalization.
3#![allow(clippy::clone_on_ref_ptr)]
4
5mod finalize;
6
7use std::sync::Arc;
8
9use anyhow::{Result, anyhow};
10use axum::body::Body;
11use axum::response::Response;
12use bytes::Bytes;
13use systemprompt_database::DbPool;
14use systemprompt_models::profile::GatewayConfig;
15
16use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
17use super::audit::{GatewayAudit, GatewayRequestContext};
18use super::policy::PolicyResolver;
19use super::protocol::canonical::CanonicalRequest;
20use super::protocol::inbound::InboundAdapter;
21use super::protocol::outbound::OutboundCtx;
22use super::quota;
23use super::registry::GatewayUpstreamRegistry;
24
/// HTTP header name carrying the gateway's request id.
///
/// NOTE(review): presumably the header that `attach_request_id` sets on the
/// response returned by `GatewayService::dispatch` — confirm in `finalize`.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
26
/// Field-less handle for the gateway; its functionality is exposed through
/// associated functions such as `GatewayService::dispatch`, which take all
/// of their state as arguments.
#[derive(Debug, Clone, Copy)]
pub struct GatewayService;
29
/// Per-request inputs consumed by `GatewayService::dispatch`.
#[derive(Debug)]
pub struct DispatchInputs {
    /// Request in canonical form; `dispatch` reads its `model` (route lookup,
    /// policy check) and `stream` flag.
    pub request: CanonicalRequest,
    /// Request body bytes, passed to the audit record together with `request`
    /// when the audit is opened.
    pub raw_body: Bytes,
    /// Caller/request context; `dispatch` reads `session_id`, `ai_request_id`,
    /// `user_id` and `wire_protocol` from it.
    pub ctx: GatewayRequestContext,
    /// Inbound protocol adapter, handed on to response finalization.
    pub inbound: Arc<dyn InboundAdapter>,
}
37
/// Error produced by `GatewayService::dispatch` when the resolved gateway
/// policy does not allow the requested model.
///
/// The wrapped string is the human-readable reason and is also the error's
/// `Display` output.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct PolicyDenied(pub String);
41
/// Error produced by `GatewayService::dispatch` when the quota precheck
/// returns a decision that does not allow the request.
///
/// `Display` renders `message` only.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct QuotaExceeded {
    /// Human-readable description of the exceeded quota window.
    pub message: String,
    /// Seconds the caller should wait; `dispatch` sets this to the
    /// `window_seconds` of the denying quota decision.
    // NOTE(review): presumably surfaced as an HTTP `Retry-After` value by the
    // caller — confirm at the call site.
    pub retry_after_seconds: i32,
}
48
49impl GatewayService {
50    pub async fn dispatch(
51        config: &GatewayConfig,
52        db: &DbPool,
53        inputs: DispatchInputs,
54    ) -> Result<Response<Body>> {
55        let DispatchInputs {
56            request,
57            raw_body,
58            ctx,
59            inbound,
60        } = inputs;
61        if ctx.session_id.is_none() {
62            return Err(anyhow!(
63                "gateway dispatch missing conversation binding (session_id)"
64            ));
65        }
66
67        let route = config
68            .find_route(&request.model)
69            .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
70
71        let secrets = systemprompt_config::SecretsBootstrap::get()
72            .map_err(|e| anyhow!("Secrets not available: {e}"))?;
73
74        let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
75            anyhow!(
76                "Gateway API key secret '{}' not configured",
77                route.api_key_secret
78            )
79        })?;
80
81        let upstream = GatewayUpstreamRegistry::global()
82            .get(&route.provider)
83            .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
84
85        let is_streaming = request.stream;
86        let ai_request_id = ctx.ai_request_id.clone();
87
88        tracing::info!(
89            ai_request_id = %ai_request_id,
90            user_id = %ctx.user_id,
91            model = %request.model,
92            provider = %route.provider,
93            upstream = %route.endpoint,
94            wire_protocol = %ctx.wire_protocol,
95            streaming = is_streaming,
96            "Gateway request dispatched"
97        );
98
99        let resolver = PolicyResolver::new(db)?;
100        let policy = resolver.resolve().await;
101
102        if !policy.model_allowed(&request.model) {
103            tracing::warn!(
104                ai_request_id = %ai_request_id,
105                model = %request.model,
106                "Gateway policy denied: model not in allowed list"
107            );
108            return Err(PolicyDenied(format!(
109                "model '{}' is not permitted by gateway policy",
110                request.model
111            ))
112            .into());
113        }
114
115        let audit = Arc::new(
116            GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
117        );
118
119        if let Err(e) = audit.open(&request, &raw_body).await {
120            tracing::error!(error = %e, "audit open failed — proceeding without audit row");
121        }
122
123        if let Some(decision) =
124            quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows).await?
125        {
126            if !decision.allow {
127                let msg = format!(
128                    "quota exceeded for window {}s (used {}/{:?})",
129                    decision.window_seconds, decision.state.requests, decision.limit_requests
130                );
131                if let Err(e) = audit.fail(&msg).await {
132                    tracing::warn!(error = %e, "quota audit fail failed");
133                }
134                return Err(QuotaExceeded {
135                    message: msg,
136                    retry_after_seconds: decision.window_seconds,
137                }
138                .into());
139            }
140        }
141
142        run_request_safety_scan(db, &ai_request_id, &request).await;
143
144        let upstream_model = route.effective_upstream_model(&request.model).to_string();
145        let outbound_ctx = OutboundCtx {
146            route,
147            api_key: upstream_api_key,
148            request: &request,
149            upstream_model: &upstream_model,
150        };
151
152        let outcome = match upstream.send(outbound_ctx).await {
153            Ok(o) => o,
154            Err(e) => {
155                if let Err(audit_err) = audit.fail(&e.to_string()).await {
156                    tracing::warn!(error = %audit_err, "upstream audit fail failed");
157                }
158                return Err(e);
159            },
160        };
161
162        let response = finalize(
163            outcome,
164            FinalizeCtx {
165                audit: Arc::clone(&audit),
166                db: db.clone(),
167                ai_request_id: ai_request_id.clone(),
168                policy,
169                inbound,
170                request_model: request.model.clone(),
171            },
172        )
173        .await;
174        Ok(attach_request_id(response, &ai_request_id))
175    }
176}