systemprompt_api/services/proxy/
engine.rs1use axum::body::Body;
2use axum::extract::{Path, Request, State};
3use axum::http::StatusCode;
4use axum::response::{IntoResponse, Response};
5use std::collections::HashMap;
6use std::sync::Arc;
7use systemprompt_identifiers::{AgentName, UserId};
8use systemprompt_models::RequestContext;
9use systemprompt_models::auth::{AuthenticatedUser, Permission};
10use systemprompt_runtime::AppContext;
11use tokio::sync::RwLock;
12
13use super::auth::AccessValidator;
14use super::backend::{HeaderInjector, ProxyError, RequestBuilder, ResponseHandler, UrlResolver};
15use super::client::ClientPool;
16use super::resolver::ServiceResolver;
17
18#[derive(Clone, Debug)]
19struct ProxySessionIdentity {
20 user: String,
21 user_type: String,
22 permissions: Vec<Permission>,
23 auth_token: String,
24}
25
26type SessionCache = Arc<RwLock<HashMap<String, ProxySessionIdentity>>>;
27
28#[derive(Debug, Clone)]
29pub struct ProxyEngine {
30 client_pool: ClientPool,
31 session_cache: SessionCache,
32}
33
34impl Default for ProxyEngine {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl ProxyEngine {
41 pub fn new() -> Self {
42 Self {
43 client_pool: ClientPool::new(),
44 session_cache: Arc::new(RwLock::new(HashMap::new())),
45 }
46 }
47
48 pub async fn proxy_request(
49 &self,
50 service_name: &str,
51 path: &str,
52 request: Request<Body>,
53 ctx: AppContext,
54 ) -> Result<Response<Body>, ProxyError> {
55 if request.extensions().get::<RequestContext>().is_none() {
56 tracing::warn!("RequestContext missing from request extensions");
57 }
58
59 let service = ServiceResolver::resolve(service_name, &ctx).await?;
60
61 let req_ctx = request.extensions().get::<RequestContext>().cloned();
62 let authenticated_user = AccessValidator::validate(
63 request.headers(),
64 service_name,
65 &service,
66 &ctx,
67 req_ctx.as_ref(),
68 )
69 .await?;
70
71 let backend_url = UrlResolver::build_backend_url("http", "127.0.0.1", service.port, path);
72
73 let method_str = request.method().to_string();
74 let request_headers = request.headers().clone();
75 let mut headers = request_headers.clone();
76 let query = request.uri().query();
77 let full_url = UrlResolver::append_query_params(backend_url, query);
78
79 let mut req_context = req_ctx.clone().ok_or_else(|| ProxyError::MissingContext {
80 message: "Request context required - proxy cannot operate without authentication"
81 .to_string(),
82 })?;
83
84 if service.module_name == "agent" || service.module_name == "mcp" {
85 req_context = req_context.with_agent_name(AgentName::new(service_name.to_string()));
86 }
87
88 if service.module_name == "mcp" && req_context.auth_token().as_str().is_empty() {
89 if let Some(session_id) = request_headers
90 .get("mcp-session-id")
91 .and_then(|v| v.to_str().ok())
92 {
93 if let Some(identity) = self.session_cache.read().await.get(session_id) {
94 tracing::info!(
95 service = %service_name,
96 session_id = %session_id,
97 user_id = %identity.user,
98 "Enriching session-only request with cached identity"
99 );
100 req_context = req_context
101 .with_user_id(UserId::from(identity.user.clone()))
102 .with_user_type(
103 identity
104 .user_type
105 .parse()
106 .unwrap_or(systemprompt_models::auth::UserType::Unknown),
107 )
108 .with_auth_token(identity.auth_token.clone())
109 .with_user(AuthenticatedUser::new(
110 identity.user.parse().unwrap_or(uuid::Uuid::nil()),
111 String::new(),
112 String::new(),
113 identity.permissions.clone(),
114 ));
115 }
116 }
117 }
118
119 let has_auth_before = headers.get("authorization").is_some();
120 let ctx_has_token = !req_context.auth_token().as_str().is_empty();
121
122 HeaderInjector::inject_context(&mut headers, &req_context);
123
124 let has_auth_after = headers.get("authorization").is_some();
125 tracing::debug!(
126 service = %service_name,
127 has_auth_before = has_auth_before,
128 ctx_has_token = ctx_has_token,
129 has_auth_after = has_auth_after,
130 "Proxy forwarding request"
131 );
132
133 let body = RequestBuilder::extract_body(request.into_body())
134 .await
135 .map_err(|e| ProxyError::BodyExtractionFailed { source: e })?;
136
137 let reqwest_method = RequestBuilder::parse_method(&method_str)
138 .map_err(|reason| ProxyError::InvalidMethod { reason })?;
139
140 let client = self.client_pool.get_default_client();
141
142 let req_builder =
143 RequestBuilder::build_request(&client, reqwest_method, &full_url, &headers, body);
144
145 let response = match req_builder.send().await {
146 Ok(resp) => resp,
147 Err(e) => {
148 tracing::error!(service = %service_name, url = %full_url, error = %e, "Connection failed");
149 return Err(ProxyError::ConnectionFailed {
150 service: service_name.to_string(),
151 url: full_url.clone(),
152 source: e,
153 });
154 },
155 };
156
157 if service.module_name == "mcp" {
158 let resp_status = response.status();
159 let resp_session = response
160 .headers()
161 .get("mcp-session-id")
162 .and_then(|v| v.to_str().ok())
163 .unwrap_or("none");
164 let resp_content_type = response
165 .headers()
166 .get("content-type")
167 .and_then(|v| v.to_str().ok())
168 .unwrap_or("none");
169
170 tracing::info!(
171 service = %service_name,
172 status = %resp_status,
173 resp_session_id = %resp_session,
174 content_type = %resp_content_type,
175 method = %method_str,
176 "MCP backend response"
177 );
178
179 if !resp_status.is_success() {
180 let header_dump: Vec<String> = response
181 .headers()
182 .iter()
183 .map(|(k, v)| format!("{}: {}", k, v.to_str().unwrap_or("?")))
184 .collect();
185 tracing::error!(
186 service = %service_name,
187 status = %resp_status,
188 headers = ?header_dump,
189 "MCP backend error response"
190 );
191
192 if resp_status == StatusCode::NOT_FOUND && method_str == "GET" {
193 if let Some(session_id) = request_headers
194 .get("mcp-session-id")
195 .and_then(|v| v.to_str().ok())
196 {
197 self.session_cache.write().await.remove(session_id);
198 tracing::info!(
199 service = %service_name,
200 session_id = %session_id,
201 "Evicted stale proxy session cache on 404 GET"
202 );
203 }
204 }
205 }
206
207 if let Some(session_id) = response
208 .headers()
209 .get("mcp-session-id")
210 .and_then(|v| v.to_str().ok())
211 {
212 if let Some(user) = &authenticated_user {
213 self.session_cache.write().await.insert(
214 session_id.to_string(),
215 ProxySessionIdentity {
216 user: user.id.to_string(),
217 user_type: req_context.user_type().to_string(),
218 permissions: user.permissions.clone(),
219 auth_token: req_context.auth_token().as_str().to_string(),
220 },
221 );
222 tracing::info!(
223 service = %service_name,
224 session_id = %session_id,
225 user_id = %user.id,
226 "Cached session identity for MCP session"
227 );
228 }
229 }
230
231 if method_str == "DELETE" {
232 if let Some(session_id) = request_headers
233 .get("mcp-session-id")
234 .and_then(|v| v.to_str().ok())
235 {
236 self.session_cache.write().await.remove(session_id);
237 tracing::debug!(session_id = %session_id, "Evicted session identity on DELETE");
238 }
239 }
240 }
241
242 match ResponseHandler::build_response(response) {
243 Ok(resp) => Ok(resp),
244 Err(e) => {
245 tracing::error!(service = %service_name, error = %e, "Failed to build response");
246 Err(ProxyError::InvalidResponse {
247 service: service_name.to_string(),
248 reason: format!("Failed to build response: {e}"),
249 })
250 },
251 }
252 }
253
254 pub async fn handle_mcp_request(
255 &self,
256 path_params: Path<(String,)>,
257 State(ctx): State<AppContext>,
258 request: Request<Body>,
259 ) -> Response<Body> {
260 let Path((service_name,)) = path_params;
261 match self.proxy_request(&service_name, "", request, ctx).await {
262 Ok(response) => response,
263 Err(e) => e.into_response(),
264 }
265 }
266
267 pub async fn handle_mcp_request_with_path(
268 &self,
269 path_params: Path<(String, String)>,
270 State(ctx): State<AppContext>,
271 request: Request<Body>,
272 ) -> Response<Body> {
273 let Path((service_name, path)) = path_params;
274 match self.proxy_request(&service_name, &path, request, ctx).await {
275 Ok(response) => response,
276 Err(e) => e.into_response(),
277 }
278 }
279
280 pub async fn handle_agent_request(
281 &self,
282 path_params: Path<(String,)>,
283 State(ctx): State<AppContext>,
284 request: Request<Body>,
285 ) -> Result<Response<Body>, StatusCode> {
286 let Path((service_name,)) = path_params;
287 self.proxy_request(&service_name, "", request, ctx)
288 .await
289 .map_err(|e| e.to_status_code())
290 }
291
292 pub async fn handle_agent_request_with_path(
293 &self,
294 path_params: Path<(String, String)>,
295 State(ctx): State<AppContext>,
296 request: Request<Body>,
297 ) -> Result<Response<Body>, StatusCode> {
298 let Path((service_name, path)) = path_params;
299 self.proxy_request(&service_name, &path, request, ctx)
300 .await
301 .map_err(|e| e.to_status_code())
302 }
303}