systemprompt_api/services/gateway/service/
mod.rs1#![allow(clippy::clone_on_ref_ptr)]
4
5mod finalize;
6
7use std::sync::Arc;
8
9use anyhow::{Result, anyhow};
10use axum::body::Body;
11use axum::response::Response;
12use bytes::Bytes;
13use systemprompt_database::DbPool;
14use systemprompt_models::profile::GatewayConfig;
15
16use self::finalize::{FinalizeCtx, attach_request_id, finalize, run_request_safety_scan};
17use super::audit::{GatewayAudit, GatewayRequestContext};
18use super::policy::PolicyResolver;
19use super::protocol::canonical::CanonicalRequest;
20use super::protocol::inbound::InboundAdapter;
21use super::protocol::outbound::OutboundCtx;
22use super::quota;
23use super::registry::GatewayUpstreamRegistry;
24
/// Name of the HTTP header that carries the gateway request id.
///
/// NOTE(review): presumably this is the header `attach_request_id` sets on the
/// outgoing response — confirm in the `finalize` module.
pub const REQUEST_ID_HEADER: &str = "x-systemprompt-request-id";
26
/// Zero-sized entry point for the gateway.
///
/// The type holds no state; [`GatewayService::dispatch`] is an associated
/// function that receives everything it needs through its arguments.
#[derive(Clone, Copy, Debug)]
pub struct GatewayService;
29
/// Per-request arguments consumed by [`GatewayService::dispatch`].
#[derive(Debug)]
pub struct DispatchInputs {
    /// Canonical form of the incoming request. `dispatch` reads its `model`
    /// and `stream` fields and passes it to the audit, the safety scan and
    /// the upstream.
    pub request: CanonicalRequest,
    /// Request body bytes, handed to `GatewayAudit::open` together with
    /// `request`.
    pub raw_body: Bytes,
    /// Request context. `dispatch` requires `session_id` to be set and reads
    /// `user_id`, `ai_request_id` and `wire_protocol` from it.
    pub ctx: GatewayRequestContext,
    /// Inbound adapter forwarded to `finalize`.
    /// NOTE(review): presumably renders the upstream outcome back into the
    /// caller's wire protocol — confirm against `InboundAdapter`.
    pub inbound: Arc<dyn InboundAdapter>,
}
37
/// Error returned by [`GatewayService::dispatch`] when the resolved gateway
/// policy does not allow the requested model.
///
/// The wrapped string is the complete message; `Display` prints it verbatim.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct PolicyDenied(pub String);
41
/// Error returned by [`GatewayService::dispatch`] when the quota precheck
/// yields a decision whose `allow` flag is false.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
pub struct QuotaExceeded {
    /// Human-readable description; this is what `Display` prints.
    pub message: String,
    /// Seconds the caller should wait before retrying. `dispatch` sets this
    /// to the `window_seconds` of the denying quota decision.
    pub retry_after_seconds: i32,
}
48
49impl GatewayService {
50 pub async fn dispatch(
51 config: &GatewayConfig,
52 db: &DbPool,
53 inputs: DispatchInputs,
54 ) -> Result<Response<Body>> {
55 let DispatchInputs {
56 request,
57 raw_body,
58 ctx,
59 inbound,
60 } = inputs;
61 if ctx.session_id.is_none() {
62 return Err(anyhow!(
63 "gateway dispatch missing conversation binding (session_id)"
64 ));
65 }
66
67 let route = config
68 .find_route(&request.model)
69 .ok_or_else(|| anyhow!("No gateway route matches model '{}'", request.model))?;
70
71 let secrets = systemprompt_config::SecretsBootstrap::get()
72 .map_err(|e| anyhow!("Secrets not available: {e}"))?;
73
74 let upstream_api_key = secrets.get(&route.api_key_secret).ok_or_else(|| {
75 anyhow!(
76 "Gateway API key secret '{}' not configured",
77 route.api_key_secret
78 )
79 })?;
80
81 let upstream = GatewayUpstreamRegistry::global()
82 .get(&route.provider)
83 .ok_or_else(|| anyhow!("Gateway provider '{}' is not registered", route.provider))?;
84
85 let is_streaming = request.stream;
86 let ai_request_id = ctx.ai_request_id.clone();
87
88 tracing::info!(
89 ai_request_id = %ai_request_id,
90 user_id = %ctx.user_id,
91 model = %request.model,
92 provider = %route.provider,
93 upstream = %route.endpoint,
94 wire_protocol = %ctx.wire_protocol,
95 streaming = is_streaming,
96 "Gateway request dispatched"
97 );
98
99 let resolver = PolicyResolver::new(db)?;
100 let policy = resolver.resolve().await;
101
102 if !policy.model_allowed(&request.model) {
103 tracing::warn!(
104 ai_request_id = %ai_request_id,
105 model = %request.model,
106 "Gateway policy denied: model not in allowed list"
107 );
108 return Err(PolicyDenied(format!(
109 "model '{}' is not permitted by gateway policy",
110 request.model
111 ))
112 .into());
113 }
114
115 let audit = Arc::new(
116 GatewayAudit::new(db, ctx.clone()).map_err(|e| anyhow!("audit init failed: {e}"))?,
117 );
118
119 if let Err(e) = audit.open(&request, &raw_body).await {
120 tracing::error!(error = %e, "audit open failed — proceeding without audit row");
121 }
122
123 if let Some(decision) =
124 quota::precheck_and_reserve(db, &ctx.user_id, &policy.quota_windows).await?
125 {
126 if !decision.allow {
127 let msg = format!(
128 "quota exceeded for window {}s (used {}/{:?})",
129 decision.window_seconds, decision.state.requests, decision.limit_requests
130 );
131 if let Err(e) = audit.fail(&msg).await {
132 tracing::warn!(error = %e, "quota audit fail failed");
133 }
134 return Err(QuotaExceeded {
135 message: msg,
136 retry_after_seconds: decision.window_seconds,
137 }
138 .into());
139 }
140 }
141
142 run_request_safety_scan(db, &ai_request_id, &request).await;
143
144 let upstream_model = route.effective_upstream_model(&request.model).to_string();
145 let outbound_ctx = OutboundCtx {
146 route,
147 api_key: upstream_api_key,
148 request: &request,
149 upstream_model: &upstream_model,
150 };
151
152 let outcome = match upstream.send(outbound_ctx).await {
153 Ok(o) => o,
154 Err(e) => {
155 if let Err(audit_err) = audit.fail(&e.to_string()).await {
156 tracing::warn!(error = %audit_err, "upstream audit fail failed");
157 }
158 return Err(e);
159 },
160 };
161
162 let response = finalize(
163 outcome,
164 FinalizeCtx {
165 audit: Arc::clone(&audit),
166 db: db.clone(),
167 ai_request_id: ai_request_id.clone(),
168 policy,
169 inbound,
170 request_model: request.model.clone(),
171 },
172 )
173 .await;
174 Ok(attach_request_id(response, &ai_request_id))
175 }
176}