Skip to main content

systemprompt_api/services/gateway/
service.rs

1#![allow(clippy::clone_on_ref_ptr)]
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use axum::body::Body;
6use axum::response::Response;
7use bytes::Bytes;
8use http::HeaderValue;
9use systemprompt_ai::InsertSafetyFinding;
10use systemprompt_ai::repository::AiSafetyFindingRepository;
11use systemprompt_database::DbPool;
12use systemprompt_identifiers::AiRequestId;
13use systemprompt_models::profile::GatewayConfig;
14
15use super::audit::{GatewayAudit, GatewayRequestContext};
16use super::models::AnthropicGatewayRequest;
17use super::policy::{GatewayPolicySpec, PolicyResolver};
18use super::registry::GatewayUpstreamRegistry;
19use super::safety::{HeuristicScanner, SafetyScanner};
20use super::upstream::{UpstreamCtx, UpstreamOutcome, build_response};
21use super::{parse, quota, stream_tap};
22
/// Name of the response header that carries the gateway's AI request id.
/// `attach_request_id` in this file inserts it on the response returned by
/// `GatewayService::dispatch`.
23pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
24
25#[derive(Debug, Clone, Copy)]
26pub struct GatewayService;
27
28impl GatewayService {
29    pub async fn dispatch(
30        config: &GatewayConfig,
31        request: AnthropicGatewayRequest,
32        raw_body: Bytes,
33        ctx: GatewayRequestContext,
34        db: &DbPool,
35    ) -> Result<Response<Body>> {
36        let route = config
37            .find_route(&request.model)
38            .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
39
40        let secrets = systemprompt_models::SecretsBootstrap::get()
41            .map_err(|e| anyhow!("Secrets not available: {e}"))?;
42
43        let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
44            anyhow!(
45                "Gateway API key secret '{}' not configured",
46                route.api_key_secret
47            )
48        })?;
49
50        let upstream = GatewayUpstreamRegistry::global()
51            .get(&route.provider)
52            .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
53
54        let is_streaming = request.stream.unwrap_or(false);
55        let ai_request_id = ctx.ai_request_id.clone();
56
57        tracing::info!(
58            ai_request_id = %ai_request_id,
59            user_id = %ctx.user_id,
60            tenant_id = ctx.tenant_id.as_ref().map_or("-", |t| t.as_str()),
61            model = %request.model,
62            provider = %route.provider,
63            upstream = %route.endpoint,
64            streaming = is_streaming,
65            "Gateway request dispatched"
66        );
67
68        let resolver = PolicyResolver::new(db)?;
69        let policy = resolver.resolve(ctx.tenant_id.as_ref()).await;
70
71        if !policy.model_allowed(&request.model) {
72            tracing::warn!(
73                ai_request_id = %ai_request_id,
74                model = %request.model,
75                "Gateway policy denied: model not in allowed list"
76            );
77            return Err(PolicyDenied(format!(
78                "model '{}' is not permitted by gateway policy",
79                request.model
80            ))
81            .into());
82        }
83
84        let audit = Arc::new(
85            GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
86        );
87
88        if let Err(e) = audit.open(&request, &raw_body).await {
89            tracing::error!(error = %e, "audit open failed — proceeding without audit row");
90        }
91
92        if let Some(decision) = quota::precheck_and_reserve(
93            db,
94            ctx.tenant_id.as_ref(),
95            &ctx.user_id,
96            &policy.quota_windows,
97        )
98        .await?
99        {
100            if !decision.allow {
101                let msg = format!(
102                    "quota exceeded for window {}s (used {}/{:?})",
103                    decision.window_seconds, decision.state.requests, decision.limit_requests
104                );
105                let _ = audit.fail(&msg).await;
106                return Err(QuotaExceeded {
107                    message: msg,
108                    retry_after_seconds: decision.window_seconds,
109                }
110                .into());
111            }
112        }
113
114        run_request_safety_scan(db, &ai_request_id, &request).await;
115
116        let upstream_ctx = UpstreamCtx {
117            route,
118            api_key: upstream_api_key,
119            raw_body,
120            request: &request,
121            is_streaming,
122        };
123
124        let outcome = match upstream.proxy(upstream_ctx).await {
125            Ok(o) => o,
126            Err(e) => {
127                let _ = audit.fail(&e.to_string()).await;
128                return Err(e);
129            },
130        };
131
132        let response = finalize(
133            outcome,
134            Arc::clone(&audit),
135            db.clone(),
136            ai_request_id.clone(),
137            policy,
138        )
139        .await;
140        Ok(attach_request_id(response, &ai_request_id))
141    }
142}
143
/// Error returned by `GatewayService::dispatch` when the resolved gateway policy
/// does not allow the requested model. The wrapped string is the message shown
/// by `Display`.
144#[derive(Debug, thiserror::Error)]
145#[error("{0}")]
146pub struct PolicyDenied(pub String);
147
/// Error returned by `GatewayService::dispatch` when the quota precheck yields a
/// decision that does not allow the request. `Display` prints `message`.
148#[derive(Debug, thiserror::Error)]
149#[error("{message}")]
150pub struct QuotaExceeded {
    /// Description of the exceeded window (window length, used requests, limit).
151    pub message: String,
    /// Set by `dispatch` to the length, in seconds, of the window that denied the
    /// request.
152    pub retry_after_seconds: i32,
153}
154
155async fn finalize(
156    outcome: UpstreamOutcome,
157    audit: Arc<GatewayAudit>,
158    db: DbPool,
159    ai_request_id: AiRequestId,
160    policy: GatewayPolicySpec,
161) -> Response<Body> {
162    match outcome {
163        UpstreamOutcome::Buffered {
164            status,
165            content_type,
166            body,
167            served_model,
168        } => {
169            let body_clone = body.clone();
170            let audit_clone = Arc::clone(&audit);
171            let served_model_clone = served_model.clone();
172            tokio::spawn(async move {
173                if status.is_success() {
174                    if let Some(model) = served_model_clone.as_deref() {
175                        audit_clone.set_served_model(model).await;
176                    }
177                    let (usage, tool_calls) = parse::extract_from_anthropic_response(&body_clone);
178                    if let Err(e) = audit_clone.complete(usage, tool_calls, &body_clone).await {
179                        tracing::warn!(error = %e, "buffered audit complete failed");
180                    }
181                    quota::post_update_tokens(
182                        &db,
183                        quota::PostUpdateParams {
184                            tenant_id: audit_clone.ctx.tenant_id.as_ref(),
185                            user_id: &audit_clone.ctx.user_id,
186                            windows: &policy.quota_windows,
187                            input_tokens: usage.input_tokens,
188                            output_tokens: usage.output_tokens,
189                        },
190                    )
191                    .await;
192                    run_response_safety_scan(&db, &ai_request_id, &body_clone).await;
193                } else {
194                    let err_msg = format!(
195                        "upstream status {}: {}",
196                        status.as_u16(),
197                        String::from_utf8_lossy(&body_clone)
198                    );
199                    if let Err(e) = audit_clone.fail(&err_msg).await {
200                        tracing::warn!(error = %e, "buffered audit fail update failed");
201                    }
202                }
203            });
204            build_response(UpstreamOutcome::Buffered {
205                status,
206                content_type,
207                body,
208                served_model,
209            })
210        },
211        UpstreamOutcome::Streaming { status, stream } => {
212            let body = stream_tap::tap(stream, Arc::clone(&audit));
213            Response::builder()
214                .status(status)
215                .header(http::header::CONTENT_TYPE, "text/event-stream")
216                .header("cache-control", "no-cache")
217                .header("x-accel-buffering", "no")
218                .body(body)
219                .unwrap_or_else(|_| Response::new(Body::empty()))
220        },
221    }
222}
223
224async fn run_request_safety_scan(
225    db: &DbPool,
226    ai_request_id: &AiRequestId,
227    request: &AnthropicGatewayRequest,
228) {
229    let scanner = HeuristicScanner;
230    let findings = scanner.scan_request(request).await;
231    if findings.is_empty() {
232        return;
233    }
234    persist_findings(db, ai_request_id, findings).await;
235}
236
237async fn run_response_safety_scan(db: &DbPool, ai_request_id: &AiRequestId, body: &[u8]) {
238    let scanner = HeuristicScanner;
239    let findings = scanner.scan_response_final(body).await;
240    if findings.is_empty() {
241        return;
242    }
243    persist_findings(db, ai_request_id, findings).await;
244}
245
246async fn persist_findings(
247    db: &DbPool,
248    ai_request_id: &AiRequestId,
249    findings: Vec<super::safety::Finding>,
250) {
251    let repo = match AiSafetyFindingRepository::new(db) {
252        Ok(r) => r,
253        Err(e) => {
254            tracing::warn!(error = %e, "safety findings repo init failed");
255            return;
256        },
257    };
258    for f in findings {
259        let params = InsertSafetyFinding {
260            ai_request_id,
261            phase: f.phase,
262            severity: f.severity.as_str(),
263            category: &f.category,
264            scanner: f.scanner,
265            excerpt: f.excerpt.as_deref(),
266        };
267        if let Err(e) = repo.insert(params).await {
268            tracing::warn!(error = %e, "safety finding insert failed");
269        }
270    }
271}
272
273fn attach_request_id(mut response: Response<Body>, id: &AiRequestId) -> Response<Body> {
274    if let Ok(v) = HeaderValue::from_str(id.as_str()) {
275        response.headers_mut().insert(REQUEST_ID_HEADER, v);
276    }
277    response
278}