systemprompt_api/services/gateway/service/
mod.rs1#![allow(clippy::clone_on_ref_ptr)]
4
5mod finalize;
6
7use std::sync::Arc;
8
9use anyhow::{Result, anyhow};
10use axum::body::Body;
11use axum::response::Response;
12use bytes::Bytes;
13use systemprompt_database::DbPool;
14use systemprompt_models::profile::GatewayConfig;
15
16use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
17use super::audit::{GatewayAudit, GatewayRequestContext};
18use super::policy::PolicyResolver;
19use super::protocol::canonical::CanonicalRequest;
20use super::protocol::inbound::InboundAdapter;
21use super::protocol::outbound::OutboundCtx;
22use super::quota;
23use super::registry::GatewayUpstreamRegistry;
24
/// Name of the HTTP header that carries the gateway's per-request id.
///
/// NOTE(review): presumably this is the header `attach_request_id` writes onto
/// dispatched responses — confirm in the `finalize` module.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
26
/// Stateless gateway entry point.
///
/// The type carries no data; all behaviour lives in its associated functions
/// (for example `dispatch`), so it is freely `Copy`.
#[derive(Clone, Copy, Debug)]
pub struct GatewayService;
29
30#[derive(Debug)]
31pub struct DispatchInputs {
32 pub request: CanonicalRequest,
33 pub raw_body: Bytes,
34 pub ctx: GatewayRequestContext,
35 pub inbound: Arc<dyn InboundAdapter>,
36}
37
38#[derive(Debug, thiserror::Error)]
39#[error("{0}")]
40pub struct PolicyDenied(pub String);
41
42#[derive(Debug, thiserror::Error)]
43#[error("{message}")]
44pub struct QuotaExceeded {
45 pub message: String,
46 pub retry_after_seconds: i32,
47}
48
49impl GatewayService {
50 pub async fn dispatch(
51 config: &GatewayConfig,
52 db: &DbPool,
53 inputs: DispatchInputs,
54 ) -> Result<Response<Body>> {
55 let DispatchInputs {
56 request,
57 raw_body,
58 ctx,
59 inbound,
60 } = inputs;
61 if ctx.session_id.is_none() {
62 return Err(anyhow!(
63 "gateway dispatch missing conversation binding (session_id)"
64 ));
65 }
66
67 let route = config
68 .find_route(&request.model)
69 .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
70
71 let secrets = systemprompt_config::SecretsBootstrap::get()
72 .map_err(|e| anyhow!("Secrets not available: {e}"))?;
73
74 let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
75 anyhow!(
76 "Gateway API key secret '{}' not configured",
77 route.api_key_secret
78 )
79 })?;
80
81 let upstream = GatewayUpstreamRegistry::global()
82 .get(&route.provider)
83 .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
84
85 let is_streaming = request.stream;
86 let ai_request_id = ctx.ai_request_id.clone();
87
88 tracing::info!(
89 ai_request_id = %ai_request_id,
90 user_id = %ctx.user_id,
91 tenant_id = ctx.tenant_id.as_ref().map_or("-", |t| t.as_str()),
92 model = %request.model,
93 provider = %route.provider,
94 upstream = %route.endpoint,
95 wire_protocol = %ctx.wire_protocol,
96 streaming = is_streaming,
97 "Gateway request dispatched"
98 );
99
100 let resolver = PolicyResolver::new(db)?;
101 let policy = resolver.resolve(ctx.tenant_id.as_ref()).await;
102
103 if !policy.model_allowed(&request.model) {
104 tracing::warn!(
105 ai_request_id = %ai_request_id,
106 model = %request.model,
107 "Gateway policy denied: model not in allowed list"
108 );
109 return Err(PolicyDenied(format!(
110 "model '{}' is not permitted by gateway policy",
111 request.model
112 ))
113 .into());
114 }
115
116 let audit = Arc::new(
117 GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
118 );
119
120 if let Err(e) = audit.open(&request, &raw_body).await {
121 tracing::error!(error = %e, "audit open failed — proceeding without audit row");
122 }
123
124 if let Some(decision) = quota::precheck_and_reserve(
125 db,
126 ctx.tenant_id.as_ref(),
127 &ctx.user_id,
128 &policy.quota_windows,
129 )
130 .await?
131 {
132 if !decision.allow {
133 let msg = format!(
134 "quota exceeded for window {}s (used {}/{:?})",
135 decision.window_seconds, decision.state.requests, decision.limit_requests
136 );
137 if let Err(e) = audit.fail(&msg).await {
138 tracing::warn!(error = %e, "quota audit fail failed");
139 }
140 return Err(QuotaExceeded {
141 message: msg,
142 retry_after_seconds: decision.window_seconds,
143 }
144 .into());
145 }
146 }
147
148 run_request_safety_scan(db, &ai_request_id, &request).await;
149
150 let upstream_model = route.effective_upstream_model(&request.model).to_string();
151 let outbound_ctx = OutboundCtx {
152 route,
153 api_key: upstream_api_key,
154 request: &request,
155 upstream_model: &upstream_model,
156 };
157
158 let outcome = match upstream.send(outbound_ctx).await {
159 Ok(o) => o,
160 Err(e) => {
161 if let Err(audit_err) = audit.fail(&e.to_string()).await {
162 tracing::warn!(error = %audit_err, "upstream audit fail failed");
163 }
164 return Err(e);
165 },
166 };
167
168 let response = finalize(
169 outcome,
170 FinalizeCtx {
171 audit: Arc::clone(&audit),
172 db: db.clone(),
173 ai_request_id: ai_request_id.clone(),
174 policy,
175 inbound,
176 request_model: request.model.clone(),
177 },
178 )
179 .await;
180 Ok(attach_request_id(response, &ai_request_id))
181 }
182}