Skip to main content

systemprompt_api/services/proxy/engine/
mod.rs

//! Reverse-proxy engine for MCP and agent backends.
//!
//! [`ProxyEngine`] resolves a service by name, enforces access via the proxy
//! auth boundary, forwards the request to the local backend port, and streams
//! the response back (with SSE keep-alive). For MCP it also maintains the
//! session-identity cache so a session-only follow-up request can be enriched
//! with the identity established on the authenticated initialize call.
8
9mod mcp_session;
10
11use axum::body::Body;
12use axum::extract::{Path, Request, State};
13use axum::http::StatusCode;
14use axum::response::{IntoResponse, Response};
15use std::collections::HashMap;
16use std::sync::Arc;
17use systemprompt_identifiers::AgentName;
18use systemprompt_models::RequestContext;
19use systemprompt_runtime::AppContext;
20use tokio::sync::RwLock;
21
22use super::auth::{AccessValidator, build_mcp_unknown_service_challenge};
23use super::backend::{HeaderInjector, ProxyError, RequestBuilder, ResponseHandler, UrlResolver};
24use super::client::ClientPool;
25use super::resolver::ServiceResolver;
26use mcp_session::SessionCache;
27
/// The kind of backend a proxied request is aimed at.
///
/// [`ProxyEngine::proxy_request`] consults this when service resolution
/// fails: only `Mcp` targets are eligible for the unknown-service challenge.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProxyKind {
    /// An MCP server backend.
    Mcp,
    /// An agent backend.
    Agent,
}
33
34#[derive(Debug)]
35pub struct ProxyTarget<'a> {
36    pub service_name: &'a str,
37    pub path: &'a str,
38    pub kind: ProxyKind,
39}
40
41#[derive(Debug, Clone)]
42pub struct ProxyEngine {
43    client_pool: ClientPool,
44    session_cache: SessionCache,
45}
46
47impl Default for ProxyEngine {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl ProxyEngine {
54    pub fn new() -> Self {
55        Self {
56            client_pool: ClientPool::new(),
57            session_cache: Arc::new(RwLock::new(HashMap::new())),
58        }
59    }
60
61    pub async fn proxy_request(
62        &self,
63        target: ProxyTarget<'_>,
64        request: Request<Body>,
65        ctx: AppContext,
66    ) -> Result<Response<Body>, ProxyError> {
67        let ProxyTarget {
68            service_name,
69            path,
70            kind: proxy_kind,
71        } = target;
72        if request.extensions().get::<RequestContext>().is_none() {
73            tracing::warn!("RequestContext missing from request extensions");
74        }
75
76        let service = match ServiceResolver::resolve(service_name, &ctx).await {
77            Ok(svc) => svc,
78            Err(err) => {
79                if proxy_kind == ProxyKind::Mcp && matches!(err, ProxyError::ServiceNotFound { .. })
80                {
81                    let req_ctx = request.extensions().get::<RequestContext>().cloned();
82                    if let Some(challenge) = build_mcp_unknown_service_challenge(
83                        service_name,
84                        request.headers(),
85                        &ctx,
86                        req_ctx.as_ref(),
87                    ) {
88                        return Err(challenge);
89                    }
90                }
91                return Err(err);
92            },
93        };
94
95        let req_ctx = request.extensions().get::<RequestContext>().cloned();
96        let authenticated_user = AccessValidator::validate(
97            request.headers(),
98            service_name,
99            &service,
100            &ctx,
101            req_ctx.as_ref(),
102        )
103        .await?;
104
105        let backend_url = UrlResolver::build_backend_url("http", "127.0.0.1", service.port, path);
106
107        let method_str = request.method().to_string();
108        let request_headers = request.headers().clone();
109        let mut headers = request_headers.clone();
110        let query = request.uri().query();
111        let full_url = UrlResolver::append_query_params(backend_url, query);
112
113        let mut req_context = req_ctx.clone().ok_or_else(|| ProxyError::MissingContext {
114            message: "Request context required - proxy cannot operate without authentication"
115                .to_owned(),
116        })?;
117
118        if service.module_name == "agent" || service.module_name == "mcp" {
119            req_context = req_context.with_agent_name(AgentName::new(service_name.to_owned()));
120        }
121
122        if service.module_name == "mcp" && req_context.auth_token().as_str().is_empty() {
123            req_context = mcp_session::enrich_with_cached_identity(
124                &self.session_cache,
125                &request_headers,
126                req_context,
127                service_name,
128            )
129            .await;
130        }
131
132        let has_auth_before = headers.get("authorization").is_some();
133        let ctx_has_token = !req_context.auth_token().as_str().is_empty();
134
135        HeaderInjector::inject_context(&mut headers, &req_context);
136
137        let has_auth_after = headers.get("authorization").is_some();
138        tracing::debug!(
139            service = %service_name,
140            has_auth_before = has_auth_before,
141            ctx_has_token = ctx_has_token,
142            has_auth_after = has_auth_after,
143            "Proxy forwarding request"
144        );
145
146        let body = RequestBuilder::extract_body(request.into_body())
147            .await
148            .map_err(|e| ProxyError::BodyExtractionFailed { source: e })?;
149
150        let reqwest_method = RequestBuilder::parse_method(&method_str)
151            .map_err(|reason| ProxyError::InvalidMethod { reason })?;
152
153        let client = self.client_pool.get_default_client();
154
155        let req_builder =
156            RequestBuilder::build_request(&client, reqwest_method, &full_url, &headers, body);
157
158        let response = match req_builder.send().await {
159            Ok(resp) => resp,
160            Err(e) => {
161                tracing::error!(service = %service_name, url = %full_url, error = %e, "Connection failed");
162                return Err(ProxyError::ConnectionFailed {
163                    service: service_name.to_owned(),
164                    url: full_url.clone(),
165                    source: e,
166                });
167            },
168        };
169
170        if service.module_name == "mcp" {
171            mcp_session::handle_mcp_response(mcp_session::McpResponseCtx {
172                cache: &self.session_cache,
173                response: &response,
174                request_headers: &request_headers,
175                req_context: &req_context,
176                authenticated_user: authenticated_user.as_ref(),
177                service_name,
178                method_str: &method_str,
179            })
180            .await;
181        }
182
183        match ResponseHandler::build_response(response) {
184            Ok(resp) => Ok(resp),
185            Err(e) => {
186                tracing::error!(service = %service_name, error = %e, "Failed to build response");
187                Err(ProxyError::InvalidResponse {
188                    service: service_name.to_owned(),
189                    reason: format!("Failed to build response: {e}"),
190                })
191            },
192        }
193    }
194
195    pub async fn handle_mcp_request(
196        &self,
197        path_params: Path<(String,)>,
198        State(ctx): State<AppContext>,
199        request: Request<Body>,
200    ) -> Response<Body> {
201        let Path((service_name,)) = path_params;
202        let target = ProxyTarget {
203            service_name: &service_name,
204            path: "",
205            kind: ProxyKind::Mcp,
206        };
207        match self.proxy_request(target, request, ctx).await {
208            Ok(response) => response,
209            Err(e) => e.into_response(),
210        }
211    }
212
213    pub async fn handle_mcp_request_with_path(
214        &self,
215        path_params: Path<(String, String)>,
216        State(ctx): State<AppContext>,
217        request: Request<Body>,
218    ) -> Response<Body> {
219        let Path((service_name, path)) = path_params;
220        let target = ProxyTarget {
221            service_name: &service_name,
222            path: &path,
223            kind: ProxyKind::Mcp,
224        };
225        match self.proxy_request(target, request, ctx).await {
226            Ok(response) => response,
227            Err(e) => e.into_response(),
228        }
229    }
230
231    pub async fn handle_agent_request(
232        &self,
233        path_params: Path<(String,)>,
234        State(ctx): State<AppContext>,
235        request: Request<Body>,
236    ) -> Result<Response<Body>, StatusCode> {
237        let Path((service_name,)) = path_params;
238        let target = ProxyTarget {
239            service_name: &service_name,
240            path: "",
241            kind: ProxyKind::Agent,
242        };
243        self.proxy_request(target, request, ctx)
244            .await
245            .map_err(|e| e.to_status_code())
246    }
247
248    pub async fn handle_agent_request_with_path(
249        &self,
250        path_params: Path<(String, String)>,
251        State(ctx): State<AppContext>,
252        request: Request<Body>,
253    ) -> Result<Response<Body>, StatusCode> {
254        let Path((service_name, path)) = path_params;
255        let target = ProxyTarget {
256            service_name: &service_name,
257            path: &path,
258            kind: ProxyKind::Agent,
259        };
260        self.proxy_request(target, request, ctx)
261            .await
262            .map_err(|e| e.to_status_code())
263    }
264}