systemprompt_api/services/proxy/
engine.rs1use axum::body::Body;
2use axum::extract::{Path, Request, State};
3use axum::http::StatusCode;
4use axum::response::{IntoResponse, Response};
5use std::collections::HashMap;
6use std::sync::Arc;
7use systemprompt_identifiers::{AgentName, UserId};
8use systemprompt_models::RequestContext;
9use systemprompt_models::auth::{AuthenticatedUser, Permission};
10use systemprompt_runtime::AppContext;
11use tokio::sync::RwLock;
12
13use super::auth::AccessValidator;
14use super::backend::{HeaderInjector, ProxyError, RequestBuilder, ResponseHandler, UrlResolver};
15use super::client::ClientPool;
16use super::resolver::ServiceResolver;
17
18#[derive(Clone, Debug)]
19struct ProxySessionIdentity {
20 user_id: String,
21 user_type: String,
22 permissions: Vec<Permission>,
23 auth_token: String,
24}
25
26type SessionCache = Arc<RwLock<HashMap<String, ProxySessionIdentity>>>;
27
28#[derive(Debug, Clone)]
29pub struct ProxyEngine {
30 client_pool: ClientPool,
31 session_cache: SessionCache,
32}
33
34impl Default for ProxyEngine {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl ProxyEngine {
41 pub fn new() -> Self {
42 Self {
43 client_pool: ClientPool::new(),
44 session_cache: Arc::new(RwLock::new(HashMap::new())),
45 }
46 }
47
48 pub async fn proxy_request(
49 &self,
50 service_name: &str,
51 path: &str,
52 request: Request<Body>,
53 ctx: AppContext,
54 ) -> Result<Response<Body>, ProxyError> {
55 if request.extensions().get::<RequestContext>().is_none() {
56 tracing::warn!("RequestContext missing from request extensions");
57 }
58
59 let service = ServiceResolver::resolve(service_name, &ctx).await?;
60
61 let req_ctx = request.extensions().get::<RequestContext>().cloned();
62 let authenticated_user = AccessValidator::validate(
63 request.headers(),
64 service_name,
65 &service,
66 &ctx,
67 req_ctx.as_ref(),
68 )
69 .await?;
70
71 let backend_url = UrlResolver::build_backend_url("http", "127.0.0.1", service.port, path);
72
73 let method_str = request.method().to_string();
74 let request_headers = request.headers().clone();
75 let mut headers = request_headers.clone();
76 let query = request.uri().query();
77 let full_url = UrlResolver::append_query_params(backend_url, query);
78
79 let mut req_context = req_ctx.clone().ok_or_else(|| ProxyError::MissingContext {
80 message: "Request context required - proxy cannot operate without authentication"
81 .to_string(),
82 })?;
83
84 if service.module_name == "agent" || service.module_name == "mcp" {
85 req_context = req_context.with_agent_name(AgentName::new(service_name.to_string()));
86 }
87
88 if service.module_name == "mcp" && req_context.auth_token().as_str().is_empty() {
89 if let Some(session_id) = request_headers
90 .get("mcp-session-id")
91 .and_then(|v| v.to_str().ok())
92 {
93 if let Some(identity) = self.session_cache.read().await.get(session_id) {
94 tracing::info!(
95 service = %service_name,
96 session_id = %session_id,
97 user_id = %identity.user_id,
98 "Enriching session-only request with cached identity"
99 );
100 req_context = req_context
101 .with_user_id(UserId::from(identity.user_id.clone()))
102 .with_user_type(
103 identity
104 .user_type
105 .parse()
106 .unwrap_or(systemprompt_models::auth::UserType::Unknown),
107 )
108 .with_auth_token(identity.auth_token.clone())
109 .with_user(AuthenticatedUser::new(
110 identity.user_id.parse().unwrap_or_default(),
111 String::new(),
112 String::new(),
113 identity.permissions.clone(),
114 ));
115 }
116 }
117 }
118
119 let has_auth_before = headers.get("authorization").is_some();
120 let ctx_has_token = !req_context.auth_token().as_str().is_empty();
121
122 HeaderInjector::inject_context(&mut headers, &req_context);
123
124 let has_auth_after = headers.get("authorization").is_some();
125 tracing::debug!(
126 service = %service_name,
127 has_auth_before = has_auth_before,
128 ctx_has_token = ctx_has_token,
129 has_auth_after = has_auth_after,
130 "Proxy forwarding request"
131 );
132
133 let body = RequestBuilder::extract_body(request.into_body())
134 .await
135 .map_err(|e| ProxyError::BodyExtractionFailed { source: e })?;
136
137 let reqwest_method = RequestBuilder::parse_method(&method_str)
138 .map_err(|reason| ProxyError::InvalidMethod { reason })?;
139
140 let client = self.client_pool.get_default_client();
141
142 let req_builder =
143 RequestBuilder::build_request(&client, reqwest_method, &full_url, &headers, body);
144
145 let req_builder = req_builder.map_err(|status| ProxyError::InvalidResponse {
146 service: service_name.to_string(),
147 reason: format!("Failed to build request: {status}"),
148 })?;
149
150 let response = match req_builder.send().await {
151 Ok(resp) => resp,
152 Err(e) => {
153 tracing::error!(service = %service_name, url = %full_url, error = %e, "Connection failed");
154 return Err(ProxyError::ConnectionFailed {
155 service: service_name.to_string(),
156 url: full_url.clone(),
157 source: e,
158 });
159 },
160 };
161
162 if service.module_name == "mcp" {
163 let resp_status = response.status();
164 let resp_session = response
165 .headers()
166 .get("mcp-session-id")
167 .and_then(|v| v.to_str().ok())
168 .unwrap_or("none");
169 let resp_content_type = response
170 .headers()
171 .get("content-type")
172 .and_then(|v| v.to_str().ok())
173 .unwrap_or("none");
174
175 tracing::info!(
176 service = %service_name,
177 status = %resp_status,
178 resp_session_id = %resp_session,
179 content_type = %resp_content_type,
180 method = %method_str,
181 "MCP backend response"
182 );
183
184 if !resp_status.is_success() {
185 let header_dump: Vec<String> = response
186 .headers()
187 .iter()
188 .map(|(k, v)| format!("{}: {}", k, v.to_str().unwrap_or("?")))
189 .collect();
190 tracing::error!(
191 service = %service_name,
192 status = %resp_status,
193 headers = ?header_dump,
194 "MCP backend error response"
195 );
196
197 if resp_status == StatusCode::NOT_FOUND && method_str == "GET" {
198 if let Some(session_id) = request_headers
199 .get("mcp-session-id")
200 .and_then(|v| v.to_str().ok())
201 {
202 self.session_cache.write().await.remove(session_id);
203 tracing::info!(
204 service = %service_name,
205 session_id = %session_id,
206 "Evicted stale proxy session cache on 404 GET"
207 );
208 }
209 }
210 }
211
212 if let Some(session_id) = response
213 .headers()
214 .get("mcp-session-id")
215 .and_then(|v| v.to_str().ok())
216 {
217 if let Some(user) = &authenticated_user {
218 self.session_cache.write().await.insert(
219 session_id.to_string(),
220 ProxySessionIdentity {
221 user_id: user.id.to_string(),
222 user_type: req_context.user_type().to_string(),
223 permissions: user.permissions.clone(),
224 auth_token: req_context.auth_token().as_str().to_string(),
225 },
226 );
227 tracing::info!(
228 service = %service_name,
229 session_id = %session_id,
230 user_id = %user.id,
231 "Cached session identity for MCP session"
232 );
233 }
234 }
235
236 if method_str == "DELETE" {
237 if let Some(session_id) = request_headers
238 .get("mcp-session-id")
239 .and_then(|v| v.to_str().ok())
240 {
241 self.session_cache.write().await.remove(session_id);
242 tracing::debug!(session_id = %session_id, "Evicted session identity on DELETE");
243 }
244 }
245 }
246
247 match ResponseHandler::build_response(response) {
248 Ok(resp) => Ok(resp),
249 Err(e) => {
250 tracing::error!(service = %service_name, error = %e, "Failed to build response");
251 Err(ProxyError::InvalidResponse {
252 service: service_name.to_string(),
253 reason: format!("Failed to build response: {e}"),
254 })
255 },
256 }
257 }
258
259 pub async fn handle_mcp_request(
260 &self,
261 path_params: Path<(String,)>,
262 State(ctx): State<AppContext>,
263 request: Request<Body>,
264 ) -> Response<Body> {
265 let Path((service_name,)) = path_params;
266 match self.proxy_request(&service_name, "", request, ctx).await {
267 Ok(response) => response,
268 Err(e) => e.into_response(),
269 }
270 }
271
272 pub async fn handle_mcp_request_with_path(
273 &self,
274 path_params: Path<(String, String)>,
275 State(ctx): State<AppContext>,
276 request: Request<Body>,
277 ) -> Response<Body> {
278 let Path((service_name, path)) = path_params;
279 match self.proxy_request(&service_name, &path, request, ctx).await {
280 Ok(response) => response,
281 Err(e) => e.into_response(),
282 }
283 }
284
285 pub async fn handle_agent_request(
286 &self,
287 path_params: Path<(String,)>,
288 State(ctx): State<AppContext>,
289 request: Request<Body>,
290 ) -> Result<Response<Body>, StatusCode> {
291 let Path((service_name,)) = path_params;
292 self.proxy_request(&service_name, "", request, ctx)
293 .await
294 .map_err(|e| e.to_status_code())
295 }
296
297 pub async fn handle_agent_request_with_path(
298 &self,
299 path_params: Path<(String, String)>,
300 State(ctx): State<AppContext>,
301 request: Request<Body>,
302 ) -> Result<Response<Body>, StatusCode> {
303 let Path((service_name, path)) = path_params;
304 self.proxy_request(&service_name, &path, request, ctx)
305 .await
306 .map_err(|e| e.to_status_code())
307 }
308}