systemprompt_api/services/proxy/
engine.rs1mod mcp_session;
2
3use axum::body::Body;
4use axum::extract::{Path, Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use std::collections::HashMap;
8use std::sync::Arc;
9use systemprompt_identifiers::AgentName;
10use systemprompt_models::RequestContext;
11use systemprompt_runtime::AppContext;
12use tokio::sync::RwLock;
13
14use super::auth::AccessValidator;
15use super::backend::{HeaderInjector, ProxyError, RequestBuilder, ResponseHandler, UrlResolver};
16use super::client::ClientPool;
17use super::resolver::ServiceResolver;
18use mcp_session::SessionCache;
19
20#[derive(Debug, Clone)]
21pub struct ProxyEngine {
22 client_pool: ClientPool,
23 session_cache: SessionCache,
24}
25
26impl Default for ProxyEngine {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl ProxyEngine {
33 pub fn new() -> Self {
34 Self {
35 client_pool: ClientPool::new(),
36 session_cache: Arc::new(RwLock::new(HashMap::new())),
37 }
38 }
39
40 pub async fn proxy_request(
41 &self,
42 service_name: &str,
43 path: &str,
44 request: Request<Body>,
45 ctx: AppContext,
46 ) -> Result<Response<Body>, ProxyError> {
47 if request.extensions().get::<RequestContext>().is_none() {
48 tracing::warn!("RequestContext missing from request extensions");
49 }
50
51 let service = ServiceResolver::resolve(service_name, &ctx).await?;
52
53 let req_ctx = request.extensions().get::<RequestContext>().cloned();
54 let authenticated_user = AccessValidator::validate(
55 request.headers(),
56 service_name,
57 &service,
58 &ctx,
59 req_ctx.as_ref(),
60 )
61 .await?;
62
63 let backend_url = UrlResolver::build_backend_url("http", "127.0.0.1", service.port, path);
64
65 let method_str = request.method().to_string();
66 let request_headers = request.headers().clone();
67 let mut headers = request_headers.clone();
68 let query = request.uri().query();
69 let full_url = UrlResolver::append_query_params(backend_url, query);
70
71 let mut req_context = req_ctx.clone().ok_or_else(|| ProxyError::MissingContext {
72 message: "Request context required - proxy cannot operate without authentication"
73 .to_string(),
74 })?;
75
76 if service.module_name == "agent" || service.module_name == "mcp" {
77 req_context = req_context.with_agent_name(AgentName::new(service_name.to_string()));
78 }
79
80 if service.module_name == "mcp" && req_context.auth_token().as_str().is_empty() {
81 req_context = mcp_session::enrich_with_cached_identity(
82 &self.session_cache,
83 &request_headers,
84 req_context,
85 service_name,
86 )
87 .await;
88 }
89
90 let has_auth_before = headers.get("authorization").is_some();
91 let ctx_has_token = !req_context.auth_token().as_str().is_empty();
92
93 HeaderInjector::inject_context(&mut headers, &req_context);
94
95 let has_auth_after = headers.get("authorization").is_some();
96 tracing::debug!(
97 service = %service_name,
98 has_auth_before = has_auth_before,
99 ctx_has_token = ctx_has_token,
100 has_auth_after = has_auth_after,
101 "Proxy forwarding request"
102 );
103
104 let body = RequestBuilder::extract_body(request.into_body())
105 .await
106 .map_err(|e| ProxyError::BodyExtractionFailed { source: e })?;
107
108 let reqwest_method = RequestBuilder::parse_method(&method_str)
109 .map_err(|reason| ProxyError::InvalidMethod { reason })?;
110
111 let client = self.client_pool.get_default_client();
112
113 let req_builder =
114 RequestBuilder::build_request(&client, reqwest_method, &full_url, &headers, body);
115
116 let response = match req_builder.send().await {
117 Ok(resp) => resp,
118 Err(e) => {
119 tracing::error!(service = %service_name, url = %full_url, error = %e, "Connection failed");
120 return Err(ProxyError::ConnectionFailed {
121 service: service_name.to_string(),
122 url: full_url.clone(),
123 source: e,
124 });
125 },
126 };
127
128 if service.module_name == "mcp" {
129 mcp_session::handle_mcp_response(mcp_session::McpResponseCtx {
130 cache: &self.session_cache,
131 response: &response,
132 request_headers: &request_headers,
133 req_context: &req_context,
134 authenticated_user: authenticated_user.as_ref(),
135 service_name,
136 method_str: &method_str,
137 })
138 .await;
139 }
140
141 match ResponseHandler::build_response(response) {
142 Ok(resp) => Ok(resp),
143 Err(e) => {
144 tracing::error!(service = %service_name, error = %e, "Failed to build response");
145 Err(ProxyError::InvalidResponse {
146 service: service_name.to_string(),
147 reason: format!("Failed to build response: {e}"),
148 })
149 },
150 }
151 }
152
153 pub async fn handle_mcp_request(
154 &self,
155 path_params: Path<(String,)>,
156 State(ctx): State<AppContext>,
157 request: Request<Body>,
158 ) -> Response<Body> {
159 let Path((service_name,)) = path_params;
160 match self.proxy_request(&service_name, "", request, ctx).await {
161 Ok(response) => response,
162 Err(e) => e.into_response(),
163 }
164 }
165
166 pub async fn handle_mcp_request_with_path(
167 &self,
168 path_params: Path<(String, String)>,
169 State(ctx): State<AppContext>,
170 request: Request<Body>,
171 ) -> Response<Body> {
172 let Path((service_name, path)) = path_params;
173 match self.proxy_request(&service_name, &path, request, ctx).await {
174 Ok(response) => response,
175 Err(e) => e.into_response(),
176 }
177 }
178
179 pub async fn handle_agent_request(
180 &self,
181 path_params: Path<(String,)>,
182 State(ctx): State<AppContext>,
183 request: Request<Body>,
184 ) -> Result<Response<Body>, StatusCode> {
185 let Path((service_name,)) = path_params;
186 self.proxy_request(&service_name, "", request, ctx)
187 .await
188 .map_err(|e| e.to_status_code())
189 }
190
191 pub async fn handle_agent_request_with_path(
192 &self,
193 path_params: Path<(String, String)>,
194 State(ctx): State<AppContext>,
195 request: Request<Body>,
196 ) -> Result<Response<Body>, StatusCode> {
197 let Path((service_name, path)) = path_params;
198 self.proxy_request(&service_name, &path, request, ctx)
199 .await
200 .map_err(|e| e.to_status_code())
201 }
202}