Skip to main content

systemprompt_api/services/proxy/
engine.rs

1use axum::body::Body;
2use axum::extract::{Path, Request, State};
3use axum::http::StatusCode;
4use axum::response::{IntoResponse, Response};
5use std::collections::HashMap;
6use std::sync::Arc;
7use systemprompt_identifiers::{AgentName, UserId};
8use systemprompt_models::RequestContext;
9use systemprompt_models::auth::{AuthenticatedUser, Permission};
10use systemprompt_runtime::AppContext;
11use tokio::sync::RwLock;
12
13use super::auth::AccessValidator;
14use super::backend::{HeaderInjector, ProxyError, RequestBuilder, ResponseHandler, UrlResolver};
15use super::client::ClientPool;
16use super::resolver::ServiceResolver;
17
18#[derive(Clone, Debug)]
19struct ProxySessionIdentity {
20    user: String,
21    user_type: String,
22    permissions: Vec<Permission>,
23    auth_token: String,
24}
25
26type SessionCache = Arc<RwLock<HashMap<String, ProxySessionIdentity>>>;
27
28#[derive(Debug, Clone)]
29pub struct ProxyEngine {
30    client_pool: ClientPool,
31    session_cache: SessionCache,
32}
33
34impl Default for ProxyEngine {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl ProxyEngine {
41    pub fn new() -> Self {
42        Self {
43            client_pool: ClientPool::new(),
44            session_cache: Arc::new(RwLock::new(HashMap::new())),
45        }
46    }
47
48    pub async fn proxy_request(
49        &self,
50        service_name: &str,
51        path: &str,
52        request: Request<Body>,
53        ctx: AppContext,
54    ) -> Result<Response<Body>, ProxyError> {
55        if request.extensions().get::<RequestContext>().is_none() {
56            tracing::warn!("RequestContext missing from request extensions");
57        }
58
59        let service = ServiceResolver::resolve(service_name, &ctx).await?;
60
61        let req_ctx = request.extensions().get::<RequestContext>().cloned();
62        let authenticated_user = AccessValidator::validate(
63            request.headers(),
64            service_name,
65            &service,
66            &ctx,
67            req_ctx.as_ref(),
68        )
69        .await?;
70
71        let backend_url = UrlResolver::build_backend_url("http", "127.0.0.1", service.port, path);
72
73        let method_str = request.method().to_string();
74        let request_headers = request.headers().clone();
75        let mut headers = request_headers.clone();
76        let query = request.uri().query();
77        let full_url = UrlResolver::append_query_params(backend_url, query);
78
79        let mut req_context = req_ctx.clone().ok_or_else(|| ProxyError::MissingContext {
80            message: "Request context required - proxy cannot operate without authentication"
81                .to_string(),
82        })?;
83
84        if service.module_name == "agent" || service.module_name == "mcp" {
85            req_context = req_context.with_agent_name(AgentName::new(service_name.to_string()));
86        }
87
88        if service.module_name == "mcp" && req_context.auth_token().as_str().is_empty() {
89            if let Some(session_id) = request_headers
90                .get("mcp-session-id")
91                .and_then(|v| v.to_str().ok())
92            {
93                if let Some(identity) = self.session_cache.read().await.get(session_id) {
94                    tracing::info!(
95                        service = %service_name,
96                        session_id = %session_id,
97                        user_id = %identity.user,
98                        "Enriching session-only request with cached identity"
99                    );
100                    req_context = req_context
101                        .with_user_id(UserId::from(identity.user.clone()))
102                        .with_user_type(
103                            identity
104                                .user_type
105                                .parse()
106                                .unwrap_or(systemprompt_models::auth::UserType::Unknown),
107                        )
108                        .with_auth_token(identity.auth_token.clone())
109                        .with_user(AuthenticatedUser::new(
110                            identity.user.parse().unwrap_or(uuid::Uuid::nil()),
111                            String::new(),
112                            String::new(),
113                            identity.permissions.clone(),
114                        ));
115                }
116            }
117        }
118
119        let has_auth_before = headers.get("authorization").is_some();
120        let ctx_has_token = !req_context.auth_token().as_str().is_empty();
121
122        HeaderInjector::inject_context(&mut headers, &req_context);
123
124        let has_auth_after = headers.get("authorization").is_some();
125        tracing::debug!(
126            service = %service_name,
127            has_auth_before = has_auth_before,
128            ctx_has_token = ctx_has_token,
129            has_auth_after = has_auth_after,
130            "Proxy forwarding request"
131        );
132
133        let body = RequestBuilder::extract_body(request.into_body())
134            .await
135            .map_err(|e| ProxyError::BodyExtractionFailed { source: e })?;
136
137        let reqwest_method = RequestBuilder::parse_method(&method_str)
138            .map_err(|reason| ProxyError::InvalidMethod { reason })?;
139
140        let client = self.client_pool.get_default_client();
141
142        let req_builder =
143            RequestBuilder::build_request(&client, reqwest_method, &full_url, &headers, body);
144
145        let response = match req_builder.send().await {
146            Ok(resp) => resp,
147            Err(e) => {
148                tracing::error!(service = %service_name, url = %full_url, error = %e, "Connection failed");
149                return Err(ProxyError::ConnectionFailed {
150                    service: service_name.to_string(),
151                    url: full_url.clone(),
152                    source: e,
153                });
154            },
155        };
156
157        if service.module_name == "mcp" {
158            let resp_status = response.status();
159            let resp_session = response
160                .headers()
161                .get("mcp-session-id")
162                .and_then(|v| v.to_str().ok())
163                .unwrap_or("none");
164            let resp_content_type = response
165                .headers()
166                .get("content-type")
167                .and_then(|v| v.to_str().ok())
168                .unwrap_or("none");
169
170            tracing::info!(
171                service = %service_name,
172                status = %resp_status,
173                resp_session_id = %resp_session,
174                content_type = %resp_content_type,
175                method = %method_str,
176                "MCP backend response"
177            );
178
179            if !resp_status.is_success() {
180                let header_dump: Vec<String> = response
181                    .headers()
182                    .iter()
183                    .map(|(k, v)| format!("{}: {}", k, v.to_str().unwrap_or("?")))
184                    .collect();
185                tracing::error!(
186                    service = %service_name,
187                    status = %resp_status,
188                    headers = ?header_dump,
189                    "MCP backend error response"
190                );
191
192                if resp_status == StatusCode::NOT_FOUND && method_str == "GET" {
193                    if let Some(session_id) = request_headers
194                        .get("mcp-session-id")
195                        .and_then(|v| v.to_str().ok())
196                    {
197                        self.session_cache.write().await.remove(session_id);
198                        tracing::info!(
199                            service = %service_name,
200                            session_id = %session_id,
201                            "Evicted stale proxy session cache on 404 GET"
202                        );
203                    }
204                }
205            }
206
207            if let Some(session_id) = response
208                .headers()
209                .get("mcp-session-id")
210                .and_then(|v| v.to_str().ok())
211            {
212                if let Some(user) = &authenticated_user {
213                    self.session_cache.write().await.insert(
214                        session_id.to_string(),
215                        ProxySessionIdentity {
216                            user: user.id.to_string(),
217                            user_type: req_context.user_type().to_string(),
218                            permissions: user.permissions.clone(),
219                            auth_token: req_context.auth_token().as_str().to_string(),
220                        },
221                    );
222                    tracing::info!(
223                        service = %service_name,
224                        session_id = %session_id,
225                        user_id = %user.id,
226                        "Cached session identity for MCP session"
227                    );
228                }
229            }
230
231            if method_str == "DELETE" {
232                if let Some(session_id) = request_headers
233                    .get("mcp-session-id")
234                    .and_then(|v| v.to_str().ok())
235                {
236                    self.session_cache.write().await.remove(session_id);
237                    tracing::debug!(session_id = %session_id, "Evicted session identity on DELETE");
238                }
239            }
240        }
241
242        match ResponseHandler::build_response(response) {
243            Ok(resp) => Ok(resp),
244            Err(e) => {
245                tracing::error!(service = %service_name, error = %e, "Failed to build response");
246                Err(ProxyError::InvalidResponse {
247                    service: service_name.to_string(),
248                    reason: format!("Failed to build response: {e}"),
249                })
250            },
251        }
252    }
253
254    pub async fn handle_mcp_request(
255        &self,
256        path_params: Path<(String,)>,
257        State(ctx): State<AppContext>,
258        request: Request<Body>,
259    ) -> Response<Body> {
260        let Path((service_name,)) = path_params;
261        match self.proxy_request(&service_name, "", request, ctx).await {
262            Ok(response) => response,
263            Err(e) => e.into_response(),
264        }
265    }
266
267    pub async fn handle_mcp_request_with_path(
268        &self,
269        path_params: Path<(String, String)>,
270        State(ctx): State<AppContext>,
271        request: Request<Body>,
272    ) -> Response<Body> {
273        let Path((service_name, path)) = path_params;
274        match self.proxy_request(&service_name, &path, request, ctx).await {
275            Ok(response) => response,
276            Err(e) => e.into_response(),
277        }
278    }
279
280    pub async fn handle_agent_request(
281        &self,
282        path_params: Path<(String,)>,
283        State(ctx): State<AppContext>,
284        request: Request<Body>,
285    ) -> Result<Response<Body>, StatusCode> {
286        let Path((service_name,)) = path_params;
287        self.proxy_request(&service_name, "", request, ctx)
288            .await
289            .map_err(|e| e.to_status_code())
290    }
291
292    pub async fn handle_agent_request_with_path(
293        &self,
294        path_params: Path<(String, String)>,
295        State(ctx): State<AppContext>,
296        request: Request<Body>,
297    ) -> Result<Response<Body>, StatusCode> {
298        let Path((service_name, path)) = path_params;
299        self.proxy_request(&service_name, &path, request, ctx)
300            .await
301            .map_err(|e| e.to_status_code())
302    }
303}