Skip to main content

systemprompt_api/services/gateway/
service.rs

1#![allow(clippy::clone_on_ref_ptr)]
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use axum::body::Body;
6use axum::response::Response;
7use bytes::Bytes;
8use http::HeaderValue;
9use systemprompt_ai::InsertSafetyFinding;
10use systemprompt_ai::repository::AiSafetyFindingRepository;
11use systemprompt_database::DbPool;
12use systemprompt_identifiers::AiRequestId;
13use systemprompt_models::profile::GatewayConfig;
14
15use super::audit::{GatewayAudit, GatewayRequestContext};
16use super::policy::{GatewayPolicySpec, PolicyResolver};
17use super::protocol::canonical::CanonicalRequest;
18use super::protocol::canonical_response::CanonicalResponse;
19use super::protocol::inbound::InboundAdapter;
20use super::protocol::outbound::{OutboundCtx, OutboundOutcome};
21use super::registry::GatewayUpstreamRegistry;
22use super::safety::{HeuristicScanner, SafetyScanner};
23use super::{parse, quota, stream_tap};
24
/// Response header that carries the gateway's AI request id back to the
/// client (see `attach_request_id`).
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
26
/// Stateless entry point for the AI gateway.
///
/// Carries no data; all per-request state is passed into
/// [`GatewayService::dispatch`].
#[derive(Clone, Copy, Debug)]
pub struct GatewayService;
29
30/// Per-request inputs to [`GatewayService::dispatch`].
31///
32/// Bundles `request`, `raw_body`, `ctx`, and `inbound` so that the
33/// surrounding env (`config`, `db`) stays as explicit arguments.
34#[derive(Debug)]
35pub struct DispatchInputs {
36    pub request: CanonicalRequest,
37    pub raw_body: Bytes,
38    pub ctx: GatewayRequestContext,
39    pub inbound: Arc<dyn InboundAdapter>,
40}
41
42impl GatewayService {
43    pub async fn dispatch(
44        config: &GatewayConfig,
45        db: &DbPool,
46        inputs: DispatchInputs,
47    ) -> Result<Response<Body>> {
48        let DispatchInputs {
49            request,
50            raw_body,
51            ctx,
52            inbound,
53        } = inputs;
54        if ctx.session_id.is_none() {
55            return Err(anyhow!(
56                "gateway dispatch missing conversation binding (session_id)"
57            ));
58        }
59
60        let route = config
61            .find_route(&request.model)
62            .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
63
64        let secrets = systemprompt_config::SecretsBootstrap::get()
65            .map_err(|e| anyhow!("Secrets not available: {e}"))?;
66
67        let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
68            anyhow!(
69                "Gateway API key secret '{}' not configured",
70                route.api_key_secret
71            )
72        })?;
73
74        let upstream = GatewayUpstreamRegistry::global()
75            .get(&route.provider)
76            .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
77
78        let is_streaming = request.stream;
79        let ai_request_id = ctx.ai_request_id.clone();
80
81        tracing::info!(
82            ai_request_id = %ai_request_id,
83            user_id = %ctx.user_id,
84            tenant_id = ctx.tenant_id.as_ref().map_or("-", |t| t.as_str()),
85            model = %request.model,
86            provider = %route.provider,
87            upstream = %route.endpoint,
88            wire_protocol = %ctx.wire_protocol,
89            streaming = is_streaming,
90            "Gateway request dispatched"
91        );
92
93        let resolver = PolicyResolver::new(db)?;
94        let policy = resolver.resolve(ctx.tenant_id.as_ref()).await;
95
96        if !policy.model_allowed(&request.model) {
97            tracing::warn!(
98                ai_request_id = %ai_request_id,
99                model = %request.model,
100                "Gateway policy denied: model not in allowed list"
101            );
102            return Err(PolicyDenied(format!(
103                "model '{}' is not permitted by gateway policy",
104                request.model
105            ))
106            .into());
107        }
108
109        let audit = Arc::new(
110            GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
111        );
112
113        if let Err(e) = audit.open(&request, &raw_body).await {
114            tracing::error!(error = %e, "audit open failed — proceeding without audit row");
115        }
116
117        if let Some(decision) = quota::precheck_and_reserve(
118            db,
119            ctx.tenant_id.as_ref(),
120            &ctx.user_id,
121            &policy.quota_windows,
122        )
123        .await?
124        {
125            if !decision.allow {
126                let msg = format!(
127                    "quota exceeded for window {}s (used {}/{:?})",
128                    decision.window_seconds, decision.state.requests, decision.limit_requests
129                );
130                if let Err(e) = audit.fail(&msg).await {
131                    tracing::warn!(error = %e, "quota audit fail failed");
132                }
133                return Err(QuotaExceeded {
134                    message: msg,
135                    retry_after_seconds: decision.window_seconds,
136                }
137                .into());
138            }
139        }
140
141        run_request_safety_scan(db, &ai_request_id, &request).await;
142
143        let upstream_model = route.effective_upstream_model(&request.model).to_string();
144        let outbound_ctx = OutboundCtx {
145            route,
146            api_key: upstream_api_key,
147            request: &request,
148            upstream_model: &upstream_model,
149        };
150
151        let outcome = match upstream.send(outbound_ctx).await {
152            Ok(o) => o,
153            Err(e) => {
154                if let Err(audit_err) = audit.fail(&e.to_string()).await {
155                    tracing::warn!(error = %audit_err, "upstream audit fail failed");
156                }
157                return Err(e);
158            },
159        };
160
161        let response = finalize(
162            outcome,
163            FinalizeCtx {
164                audit: Arc::clone(&audit),
165                db: db.clone(),
166                ai_request_id: ai_request_id.clone(),
167                policy,
168                inbound,
169                request_model: request.model.clone(),
170            },
171        )
172        .await;
173        Ok(attach_request_id(response, &ai_request_id))
174    }
175}
176
177struct FinalizeCtx {
178    audit: Arc<GatewayAudit>,
179    db: DbPool,
180    ai_request_id: AiRequestId,
181    policy: GatewayPolicySpec,
182    inbound: Arc<dyn InboundAdapter>,
183    request_model: String,
184}
185
186#[derive(Debug, thiserror::Error)]
187#[error("{0}")]
188pub struct PolicyDenied(pub String);
189
190#[derive(Debug, thiserror::Error)]
191#[error("{message}")]
192pub struct QuotaExceeded {
193    pub message: String,
194    pub retry_after_seconds: i32,
195}
196
197async fn finalize(outcome: OutboundOutcome, fctx: FinalizeCtx) -> Response<Body> {
198    let FinalizeCtx {
199        audit,
200        db,
201        ai_request_id,
202        policy,
203        inbound,
204        request_model,
205    } = fctx;
206    match outcome {
207        OutboundOutcome::Buffered(canonical) => {
208            let body_bytes = inbound.render_response(&canonical);
209            let audit_clone = Arc::clone(&audit);
210            let body_for_task = body_bytes.clone();
211            tokio::spawn(async move {
212                let canonical_for_task = canonical;
213                let served_model = canonical_for_task.model.clone();
214                if !served_model.is_empty() {
215                    audit_clone.set_served_model(&served_model).await;
216                }
217                let (usage, tool_calls) = parse::extract_from_canonical(&canonical_for_task);
218                if let Err(e) = audit_clone
219                    .complete(usage, tool_calls, &canonical_for_task, &body_for_task)
220                    .await
221                {
222                    tracing::warn!(error = %e, "buffered audit complete failed");
223                }
224                quota::post_update_tokens(
225                    &db,
226                    quota::PostUpdateParams {
227                        tenant_id: audit_clone.ctx.tenant_id.as_ref(),
228                        user_id: &audit_clone.ctx.user_id,
229                        windows: &policy.quota_windows,
230                        input_tokens: usage.input_tokens,
231                        output_tokens: usage.output_tokens,
232                    },
233                )
234                .await;
235                run_response_safety_scan(&db, &ai_request_id, &canonical_for_task).await;
236            });
237            Response::builder()
238                .status(http::StatusCode::OK)
239                .header(http::header::CONTENT_TYPE, "application/json")
240                .body(Body::from(body_bytes))
241                .unwrap_or_else(|_| Response::new(Body::empty()))
242        },
243        OutboundOutcome::Streaming(stream) => {
244            let body = stream_tap::tap(stream, Arc::clone(&inbound), request_model, audit);
245            Response::builder()
246                .status(http::StatusCode::OK)
247                .header(http::header::CONTENT_TYPE, inbound.streaming_content_type())
248                .header("cache-control", "no-cache")
249                .header("x-accel-buffering", "no")
250                .body(body)
251                .unwrap_or_else(|_| Response::new(Body::empty()))
252        },
253    }
254}
255
256async fn run_request_safety_scan(
257    db: &DbPool,
258    ai_request_id: &AiRequestId,
259    request: &CanonicalRequest,
260) {
261    let scanner = HeuristicScanner;
262    let findings = scanner.scan_request(request).await;
263    if findings.is_empty() {
264        return;
265    }
266    persist_findings(db, ai_request_id, findings).await;
267}
268
269async fn run_response_safety_scan(
270    db: &DbPool,
271    ai_request_id: &AiRequestId,
272    response: &CanonicalResponse,
273) {
274    let scanner = HeuristicScanner;
275    let findings = scanner.scan_response_final(response).await;
276    if findings.is_empty() {
277        return;
278    }
279    persist_findings(db, ai_request_id, findings).await;
280}
281
282async fn persist_findings(
283    db: &DbPool,
284    ai_request_id: &AiRequestId,
285    findings: Vec<super::safety::Finding>,
286) {
287    let repo = match AiSafetyFindingRepository::new(db) {
288        Ok(r) => r,
289        Err(e) => {
290            tracing::warn!(error = %e, "safety findings repo init failed");
291            return;
292        },
293    };
294    for f in findings {
295        let params = InsertSafetyFinding {
296            ai_request_id,
297            phase: f.phase,
298            severity: f.severity.as_str(),
299            category: &f.category,
300            scanner: f.scanner,
301            excerpt: f.excerpt.as_deref(),
302        };
303        if let Err(e) = repo.insert(params).await {
304            tracing::warn!(error = %e, "safety finding insert failed");
305        }
306    }
307}
308
309fn attach_request_id(mut response: Response<Body>, id: &AiRequestId) -> Response<Body> {
310    if let Ok(v) = HeaderValue::from_str(id.as_str()) {
311        response.headers_mut().insert(REQUEST_ID_HEADER, v);
312    }
313    response
314}