systemprompt_api/services/proxy/engine/
mod.rs1mod mcp_session;
10
11use axum::body::Body;
12use axum::extract::{Path, Request, State};
13use axum::http::StatusCode;
14use axum::response::{IntoResponse, Response};
15use std::collections::HashMap;
16use std::sync::Arc;
17use systemprompt_identifiers::AgentName;
18use systemprompt_models::RequestContext;
19use systemprompt_runtime::AppContext;
20use tokio::sync::RwLock;
21
22use super::auth::{AccessValidator, build_mcp_unknown_service_challenge};
23use super::backend::{HeaderInjector, ProxyError, RequestBuilder, ResponseHandler, UrlResolver};
24use super::client::ClientPool;
25use super::resolver::ServiceResolver;
26use mcp_session::SessionCache;
27
/// Route family a proxied request arrived on.
///
/// Chosen by the `ProxyEngine` route handlers; only the `Mcp` family is
/// eligible for the unknown-service challenge when service resolution fails.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProxyKind {
    /// Request came through one of the `handle_mcp_request*` handlers.
    Mcp,
    /// Request came through one of the `handle_agent_request*` handlers.
    Agent,
}
33
34#[derive(Debug)]
35pub struct ProxyTarget<'a> {
36 pub service_name: &'a str,
37 pub path: &'a str,
38 pub kind: ProxyKind,
39}
40
41#[derive(Debug, Clone)]
42pub struct ProxyEngine {
43 client_pool: ClientPool,
44 session_cache: SessionCache,
45}
46
47impl Default for ProxyEngine {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl ProxyEngine {
54 pub fn new() -> Self {
55 Self {
56 client_pool: ClientPool::new(),
57 session_cache: Arc::new(RwLock::new(HashMap::new())),
58 }
59 }
60
61 pub async fn proxy_request(
62 &self,
63 target: ProxyTarget<'_>,
64 request: Request<Body>,
65 ctx: AppContext,
66 ) -> Result<Response<Body>, ProxyError> {
67 let ProxyTarget {
68 service_name,
69 path,
70 kind: proxy_kind,
71 } = target;
72 if request.extensions().get::<RequestContext>().is_none() {
73 tracing::warn!("RequestContext missing from request extensions");
74 }
75
76 let service = match ServiceResolver::resolve(service_name, &ctx).await {
77 Ok(svc) => svc,
78 Err(err) => {
79 if proxy_kind == ProxyKind::Mcp && matches!(err, ProxyError::ServiceNotFound { .. })
80 {
81 let req_ctx = request.extensions().get::<RequestContext>().cloned();
82 if let Some(challenge) = build_mcp_unknown_service_challenge(
83 service_name,
84 request.headers(),
85 &ctx,
86 req_ctx.as_ref(),
87 ) {
88 return Err(challenge);
89 }
90 }
91 return Err(err);
92 },
93 };
94
95 let req_ctx = request.extensions().get::<RequestContext>().cloned();
96 let authenticated_user = AccessValidator::validate(
97 request.headers(),
98 service_name,
99 &service,
100 &ctx,
101 req_ctx.as_ref(),
102 )
103 .await?;
104
105 let backend_url = UrlResolver::build_backend_url("http", "127.0.0.1", service.port, path);
106
107 let method_str = request.method().to_string();
108 let request_headers = request.headers().clone();
109 let mut headers = request_headers.clone();
110 let query = request.uri().query();
111 let full_url = UrlResolver::append_query_params(backend_url, query);
112
113 let mut req_context = req_ctx.clone().ok_or_else(|| ProxyError::MissingContext {
114 message: "Request context required - proxy cannot operate without authentication"
115 .to_owned(),
116 })?;
117
118 if service.module_name == "agent" || service.module_name == "mcp" {
119 req_context = req_context.with_agent_name(AgentName::new(service_name.to_owned()));
120 }
121
122 if service.module_name == "mcp" && req_context.auth_token().as_str().is_empty() {
123 req_context = mcp_session::enrich_with_cached_identity(
124 &self.session_cache,
125 &request_headers,
126 req_context,
127 service_name,
128 )
129 .await;
130 }
131
132 let has_auth_before = headers.get("authorization").is_some();
133 let ctx_has_token = !req_context.auth_token().as_str().is_empty();
134
135 HeaderInjector::inject_context(&mut headers, &req_context);
136
137 let has_auth_after = headers.get("authorization").is_some();
138 tracing::debug!(
139 service = %service_name,
140 has_auth_before = has_auth_before,
141 ctx_has_token = ctx_has_token,
142 has_auth_after = has_auth_after,
143 "Proxy forwarding request"
144 );
145
146 let body = RequestBuilder::extract_body(request.into_body())
147 .await
148 .map_err(|e| ProxyError::BodyExtractionFailed { source: e })?;
149
150 let reqwest_method = RequestBuilder::parse_method(&method_str)
151 .map_err(|reason| ProxyError::InvalidMethod { reason })?;
152
153 let client = self.client_pool.get_default_client();
154
155 let req_builder =
156 RequestBuilder::build_request(&client, reqwest_method, &full_url, &headers, body);
157
158 let response = match req_builder.send().await {
159 Ok(resp) => resp,
160 Err(e) => {
161 tracing::error!(service = %service_name, url = %full_url, error = %e, "Connection failed");
162 return Err(ProxyError::ConnectionFailed {
163 service: service_name.to_owned(),
164 url: full_url.clone(),
165 source: e,
166 });
167 },
168 };
169
170 if service.module_name == "mcp" {
171 mcp_session::handle_mcp_response(mcp_session::McpResponseCtx {
172 cache: &self.session_cache,
173 response: &response,
174 request_headers: &request_headers,
175 req_context: &req_context,
176 authenticated_user: authenticated_user.as_ref(),
177 service_name,
178 method_str: &method_str,
179 })
180 .await;
181 }
182
183 match ResponseHandler::build_response(response) {
184 Ok(resp) => Ok(resp),
185 Err(e) => {
186 tracing::error!(service = %service_name, error = %e, "Failed to build response");
187 Err(ProxyError::InvalidResponse {
188 service: service_name.to_owned(),
189 reason: format!("Failed to build response: {e}"),
190 })
191 },
192 }
193 }
194
195 pub async fn handle_mcp_request(
196 &self,
197 path_params: Path<(String,)>,
198 State(ctx): State<AppContext>,
199 request: Request<Body>,
200 ) -> Response<Body> {
201 let Path((service_name,)) = path_params;
202 let target = ProxyTarget {
203 service_name: &service_name,
204 path: "",
205 kind: ProxyKind::Mcp,
206 };
207 match self.proxy_request(target, request, ctx).await {
208 Ok(response) => response,
209 Err(e) => e.into_response(),
210 }
211 }
212
213 pub async fn handle_mcp_request_with_path(
214 &self,
215 path_params: Path<(String, String)>,
216 State(ctx): State<AppContext>,
217 request: Request<Body>,
218 ) -> Response<Body> {
219 let Path((service_name, path)) = path_params;
220 let target = ProxyTarget {
221 service_name: &service_name,
222 path: &path,
223 kind: ProxyKind::Mcp,
224 };
225 match self.proxy_request(target, request, ctx).await {
226 Ok(response) => response,
227 Err(e) => e.into_response(),
228 }
229 }
230
231 pub async fn handle_agent_request(
232 &self,
233 path_params: Path<(String,)>,
234 State(ctx): State<AppContext>,
235 request: Request<Body>,
236 ) -> Result<Response<Body>, StatusCode> {
237 let Path((service_name,)) = path_params;
238 let target = ProxyTarget {
239 service_name: &service_name,
240 path: "",
241 kind: ProxyKind::Agent,
242 };
243 self.proxy_request(target, request, ctx)
244 .await
245 .map_err(|e| e.to_status_code())
246 }
247
248 pub async fn handle_agent_request_with_path(
249 &self,
250 path_params: Path<(String, String)>,
251 State(ctx): State<AppContext>,
252 request: Request<Body>,
253 ) -> Result<Response<Body>, StatusCode> {
254 let Path((service_name, path)) = path_params;
255 let target = ProxyTarget {
256 service_name: &service_name,
257 path: &path,
258 kind: ProxyKind::Agent,
259 };
260 self.proxy_request(target, request, ctx)
261 .await
262 .map_err(|e| e.to_status_code())
263 }
264}