Skip to main content

systemprompt_api/services/gateway/service/
mod.rs

1//! Gateway dispatch entry point: route resolution, policy and quota checks,
2//! upstream send, and response finalization.
3#![allow(clippy::clone_on_ref_ptr)]
4
5mod finalize;
6
7use std::sync::Arc;
8
9use anyhow::{Result, anyhow};
10use axum::body::Body;
11use axum::response::Response;
12use bytes::Bytes;
13use systemprompt_database::DbPool;
14use systemprompt_models::profile::GatewayConfig;
15
16use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
17use super::audit::{GatewayAudit, GatewayRequestContext};
18use super::policy::PolicyResolver;
19use super::protocol::canonical::CanonicalRequest;
20use super::protocol::inbound::InboundAdapter;
21use super::protocol::outbound::OutboundCtx;
22use super::quota;
23use super::registry::GatewayUpstreamRegistry;
24
/// HTTP header name `x-systemprompt-request-id`.
///
/// NOTE(review): not referenced in this file; presumably the header under
/// which `finalize::attach_request_id` exposes the AI request id on
/// dispatched responses — confirm.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
26
/// Stateless handle for the gateway. It is a unit struct; its behavior lives
/// in associated functions such as [`GatewayService::dispatch`].
#[derive(Debug, Clone, Copy)]
pub struct GatewayService;
29
/// Everything [`GatewayService::dispatch`] needs for a single request,
/// bundled so the call site passes one value.
#[derive(Debug)]
pub struct DispatchInputs {
    /// Request in canonical form. `model` selects the gateway route and is
    /// checked against the policy; `stream` marks streaming requests.
    pub request: CanonicalRequest,
    /// Raw inbound body bytes, passed to the audit row when it is opened.
    pub raw_body: Bytes,
    /// Per-request context. `session_id` must be set, otherwise dispatch fails
    /// before any routing happens.
    pub ctx: GatewayRequestContext,
    /// Inbound protocol adapter, handed to `finalize`.
    /// NOTE(review): presumably used to render the upstream outcome in the
    /// client's wire protocol — confirm against `finalize`.
    pub inbound: Arc<dyn InboundAdapter>,
}
37
/// Error returned by [`GatewayService::dispatch`] when the resolved gateway
/// policy does not allow the requested model. It is raised before the audit
/// row is created. Its `Display` output is the wrapped message verbatim.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct PolicyDenied(pub String);
41
/// Error returned by [`GatewayService::dispatch`] when a quota window rejects
/// the request. Its `Display` output is `message` verbatim.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct QuotaExceeded {
    /// Human-readable description of the rejecting window and its usage.
    pub message: String,
    /// Retry hint in seconds (per the field name). `dispatch` sets this to the
    /// rejecting window's full length.
    /// NOTE(review): the full window length is not necessarily the time
    /// remaining until the window resets — confirm this is the intended hint.
    pub retry_after_seconds: i32,
}
48
49impl GatewayService {
50    pub async fn dispatch(
51        config: &GatewayConfig,
52        db: &DbPool,
53        inputs: DispatchInputs,
54    ) -> Result<Response<Body>> {
55        let DispatchInputs {
56            request,
57            raw_body,
58            ctx,
59            inbound,
60        } = inputs;
61        if ctx.session_id.is_none() {
62            return Err(anyhow!(
63                "gateway dispatch missing conversation binding (session_id)"
64            ));
65        }
66
67        let route = config
68            .find_route(&request.model)
69            .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
70
71        let secrets = systemprompt_config::SecretsBootstrap::get()
72            .map_err(|e| anyhow!("Secrets not available: {e}"))?;
73
74        let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
75            anyhow!(
76                "Gateway API key secret '{}' not configured",
77                route.api_key_secret
78            )
79        })?;
80
81        let upstream = GatewayUpstreamRegistry::global()
82            .get(&route.provider)
83            .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
84
85        let is_streaming = request.stream;
86        let ai_request_id = ctx.ai_request_id.clone();
87
88        tracing::info!(
89            ai_request_id = %ai_request_id,
90            user_id = %ctx.user_id,
91            tenant_id = ctx.tenant_id.as_ref().map_or("-", |t| t.as_str()),
92            model = %request.model,
93            provider = %route.provider,
94            upstream = %route.endpoint,
95            wire_protocol = %ctx.wire_protocol,
96            streaming = is_streaming,
97            "Gateway request dispatched"
98        );
99
100        let resolver = PolicyResolver::new(db)?;
101        let policy = resolver.resolve(ctx.tenant_id.as_ref()).await;
102
103        if !policy.model_allowed(&request.model) {
104            tracing::warn!(
105                ai_request_id = %ai_request_id,
106                model = %request.model,
107                "Gateway policy denied: model not in allowed list"
108            );
109            return Err(PolicyDenied(format!(
110                "model '{}' is not permitted by gateway policy",
111                request.model
112            ))
113            .into());
114        }
115
116        let audit = Arc::new(
117            GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
118        );
119
120        if let Err(e) = audit.open(&request, &raw_body).await {
121            tracing::error!(error = %e, "audit open failed — proceeding without audit row");
122        }
123
124        if let Some(decision) = quota::precheck_and_reserve(
125            db,
126            ctx.tenant_id.as_ref(),
127            &ctx.user_id,
128            &policy.quota_windows,
129        )
130        .await?
131        {
132            if !decision.allow {
133                let msg = format!(
134                    "quota exceeded for window {}s (used {}/{:?})",
135                    decision.window_seconds, decision.state.requests, decision.limit_requests
136                );
137                if let Err(e) = audit.fail(&msg).await {
138                    tracing::warn!(error = %e, "quota audit fail failed");
139                }
140                return Err(QuotaExceeded {
141                    message: msg,
142                    retry_after_seconds: decision.window_seconds,
143                }
144                .into());
145            }
146        }
147
148        run_request_safety_scan(db, &ai_request_id, &request).await;
149
150        let upstream_model = route.effective_upstream_model(&request.model).to_string();
151        let outbound_ctx = OutboundCtx {
152            route,
153            api_key: upstream_api_key,
154            request: &request,
155            upstream_model: &upstream_model,
156        };
157
158        let outcome = match upstream.send(outbound_ctx).await {
159            Ok(o) => o,
160            Err(e) => {
161                if let Err(audit_err) = audit.fail(&e.to_string()).await {
162                    tracing::warn!(error = %audit_err, "upstream audit fail failed");
163                }
164                return Err(e);
165            },
166        };
167
168        let response = finalize(
169            outcome,
170            FinalizeCtx {
171                audit: Arc::clone(&audit),
172                db: db.clone(),
173                ai_request_id: ai_request_id.clone(),
174                policy,
175                inbound,
176                request_model: request.model.clone(),
177            },
178        )
179        .await;
180        Ok(attach_request_id(response, &ai_request_id))
181    }
182}