// systemprompt_api/services/proxy/engine.rs

1mod mcp_session;
2
3use axum::body::Body;
4use axum::extract::{Path, Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use std::collections::HashMap;
8use std::sync::Arc;
9use systemprompt_identifiers::AgentName;
10use systemprompt_models::RequestContext;
11use systemprompt_runtime::AppContext;
12use tokio::sync::RwLock;
13
14use super::auth::AccessValidator;
15use super::backend::{HeaderInjector, ProxyError, RequestBuilder, ResponseHandler, UrlResolver};
16use super::client::ClientPool;
17use super::resolver::ServiceResolver;
18use mcp_session::SessionCache;
19
/// Reverse proxy that forwards inbound axum requests to service backends
/// listening on `127.0.0.1`, injecting the caller's request context into the
/// forwarded headers.
#[derive(Debug, Clone)]
pub struct ProxyEngine {
    /// Pool supplying the HTTP client used for outgoing backend requests.
    client_pool: ClientPool,
    /// MCP session cache handed to `mcp_session::enrich_with_cached_identity`
    /// and `mcp_session::handle_mcp_response`. It is built as an `Arc`, so
    /// clones of the engine share the same cache.
    session_cache: SessionCache,
}
25
26impl Default for ProxyEngine {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl ProxyEngine {
33    pub fn new() -> Self {
34        Self {
35            client_pool: ClientPool::new(),
36            session_cache: Arc::new(RwLock::new(HashMap::new())),
37        }
38    }
39
40    pub async fn proxy_request(
41        &self,
42        service_name: &str,
43        path: &str,
44        request: Request<Body>,
45        ctx: AppContext,
46    ) -> Result<Response<Body>, ProxyError> {
47        if request.extensions().get::<RequestContext>().is_none() {
48            tracing::warn!("RequestContext missing from request extensions");
49        }
50
51        let service = ServiceResolver::resolve(service_name, &ctx).await?;
52
53        let req_ctx = request.extensions().get::<RequestContext>().cloned();
54        let authenticated_user = AccessValidator::validate(
55            request.headers(),
56            service_name,
57            &service,
58            &ctx,
59            req_ctx.as_ref(),
60        )
61        .await?;
62
63        let backend_url = UrlResolver::build_backend_url("http", "127.0.0.1", service.port, path);
64
65        let method_str = request.method().to_string();
66        let request_headers = request.headers().clone();
67        let mut headers = request_headers.clone();
68        let query = request.uri().query();
69        let full_url = UrlResolver::append_query_params(backend_url, query);
70
71        let mut req_context = req_ctx.clone().ok_or_else(|| ProxyError::MissingContext {
72            message: "Request context required - proxy cannot operate without authentication"
73                .to_string(),
74        })?;
75
76        if service.module_name == "agent" || service.module_name == "mcp" {
77            req_context = req_context.with_agent_name(AgentName::new(service_name.to_string()));
78        }
79
80        if service.module_name == "mcp" && req_context.auth_token().as_str().is_empty() {
81            req_context = mcp_session::enrich_with_cached_identity(
82                &self.session_cache,
83                &request_headers,
84                req_context,
85                service_name,
86            )
87            .await;
88        }
89
90        let has_auth_before = headers.get("authorization").is_some();
91        let ctx_has_token = !req_context.auth_token().as_str().is_empty();
92
93        HeaderInjector::inject_context(&mut headers, &req_context);
94
95        let has_auth_after = headers.get("authorization").is_some();
96        tracing::debug!(
97            service = %service_name,
98            has_auth_before = has_auth_before,
99            ctx_has_token = ctx_has_token,
100            has_auth_after = has_auth_after,
101            "Proxy forwarding request"
102        );
103
104        let body = RequestBuilder::extract_body(request.into_body())
105            .await
106            .map_err(|e| ProxyError::BodyExtractionFailed { source: e })?;
107
108        let reqwest_method = RequestBuilder::parse_method(&method_str)
109            .map_err(|reason| ProxyError::InvalidMethod { reason })?;
110
111        let client = self.client_pool.get_default_client();
112
113        let req_builder =
114            RequestBuilder::build_request(&client, reqwest_method, &full_url, &headers, body);
115
116        let response = match req_builder.send().await {
117            Ok(resp) => resp,
118            Err(e) => {
119                tracing::error!(service = %service_name, url = %full_url, error = %e, "Connection failed");
120                return Err(ProxyError::ConnectionFailed {
121                    service: service_name.to_string(),
122                    url: full_url.clone(),
123                    source: e,
124                });
125            },
126        };
127
128        if service.module_name == "mcp" {
129            mcp_session::handle_mcp_response(mcp_session::McpResponseCtx {
130                cache: &self.session_cache,
131                response: &response,
132                request_headers: &request_headers,
133                req_context: &req_context,
134                authenticated_user: authenticated_user.as_ref(),
135                service_name,
136                method_str: &method_str,
137            })
138            .await;
139        }
140
141        match ResponseHandler::build_response(response) {
142            Ok(resp) => Ok(resp),
143            Err(e) => {
144                tracing::error!(service = %service_name, error = %e, "Failed to build response");
145                Err(ProxyError::InvalidResponse {
146                    service: service_name.to_string(),
147                    reason: format!("Failed to build response: {e}"),
148                })
149            },
150        }
151    }
152
153    pub async fn handle_mcp_request(
154        &self,
155        path_params: Path<(String,)>,
156        State(ctx): State<AppContext>,
157        request: Request<Body>,
158    ) -> Response<Body> {
159        let Path((service_name,)) = path_params;
160        match self.proxy_request(&service_name, "", request, ctx).await {
161            Ok(response) => response,
162            Err(e) => e.into_response(),
163        }
164    }
165
166    pub async fn handle_mcp_request_with_path(
167        &self,
168        path_params: Path<(String, String)>,
169        State(ctx): State<AppContext>,
170        request: Request<Body>,
171    ) -> Response<Body> {
172        let Path((service_name, path)) = path_params;
173        match self.proxy_request(&service_name, &path, request, ctx).await {
174            Ok(response) => response,
175            Err(e) => e.into_response(),
176        }
177    }
178
179    pub async fn handle_agent_request(
180        &self,
181        path_params: Path<(String,)>,
182        State(ctx): State<AppContext>,
183        request: Request<Body>,
184    ) -> Result<Response<Body>, StatusCode> {
185        let Path((service_name,)) = path_params;
186        self.proxy_request(&service_name, "", request, ctx)
187            .await
188            .map_err(|e| e.to_status_code())
189    }
190
191    pub async fn handle_agent_request_with_path(
192        &self,
193        path_params: Path<(String, String)>,
194        State(ctx): State<AppContext>,
195        request: Request<Body>,
196    ) -> Result<Response<Body>, StatusCode> {
197        let Path((service_name, path)) = path_params;
198        self.proxy_request(&service_name, &path, request, ctx)
199            .await
200            .map_err(|e| e.to_status_code())
201    }
202}