Skip to main content

systemprompt_api/services/proxy/
engine.rs

1use axum::body::Body;
2use axum::extract::{Path, Request, State};
3use axum::http::StatusCode;
4use axum::response::{IntoResponse, Response};
5use systemprompt_identifiers::AgentName;
6use systemprompt_models::RequestContext;
7use systemprompt_runtime::AppContext;
8
9use super::auth::AccessValidator;
10use super::backend::{HeaderInjector, ProxyError, RequestBuilder, ResponseHandler, UrlResolver};
11use super::client::ClientPool;
12use super::resolver::ServiceResolver;
13
14#[derive(Debug, Clone)]
15pub struct ProxyEngine {
16    client_pool: ClientPool,
17}
18
19impl Default for ProxyEngine {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl ProxyEngine {
26    pub fn new() -> Self {
27        Self {
28            client_pool: ClientPool::new(),
29        }
30    }
31
32    pub async fn proxy_request(
33        &self,
34        service_name: &str,
35        path: &str,
36        request: Request<Body>,
37        ctx: AppContext,
38    ) -> Result<Response<Body>, ProxyError> {
39        if request.extensions().get::<RequestContext>().is_none() {
40            tracing::warn!("RequestContext missing from request extensions");
41        }
42
43        let service = ServiceResolver::resolve(service_name, &ctx).await?;
44
45        let req_ctx = request.extensions().get::<RequestContext>().cloned();
46        AccessValidator::validate(
47            request.headers(),
48            service_name,
49            &service,
50            &ctx,
51            req_ctx.as_ref(),
52        )
53        .await?;
54
55        let backend_url = UrlResolver::build_backend_url("http", "127.0.0.1", service.port, path);
56
57        let method_str = request.method().to_string();
58        let mut headers = request.headers().clone();
59        let query = request.uri().query();
60        let full_url = UrlResolver::append_query_params(backend_url, query);
61
62        let mut req_context = req_ctx.clone().ok_or_else(|| ProxyError::MissingContext {
63            message: "Request context required - proxy cannot operate without authentication"
64                .to_string(),
65        })?;
66
67        if service.module_name == "agent" || service.module_name == "mcp" {
68            req_context = req_context.with_agent_name(AgentName::new(service_name.to_string()));
69        }
70
71        let has_auth_before = headers.get("authorization").is_some();
72        let ctx_has_token = !req_context.auth_token().as_str().is_empty();
73
74        HeaderInjector::inject_context(&mut headers, &req_context);
75
76        let has_auth_after = headers.get("authorization").is_some();
77        tracing::debug!(
78            service = %service_name,
79            has_auth_before = has_auth_before,
80            ctx_has_token = ctx_has_token,
81            has_auth_after = has_auth_after,
82            "Proxy forwarding request"
83        );
84
85        let body = RequestBuilder::extract_body(request.into_body())
86            .await
87            .map_err(|e| ProxyError::BodyExtractionFailed { source: e })?;
88
89        let reqwest_method = RequestBuilder::parse_method(&method_str)
90            .map_err(|reason| ProxyError::InvalidMethod { reason })?;
91
92        let client = self.client_pool.get_default_client();
93
94        let req_builder =
95            RequestBuilder::build_request(&client, reqwest_method, &full_url, &headers, body);
96
97        let req_builder = req_builder.map_err(|status| ProxyError::InvalidResponse {
98            service: service_name.to_string(),
99            reason: format!("Failed to build request: {status}"),
100        })?;
101
102        let response = match req_builder.send().await {
103            Ok(resp) => resp,
104            Err(e) => {
105                tracing::error!(service = %service_name, url = %full_url, error = %e, "Connection failed");
106                return Err(ProxyError::ConnectionFailed {
107                    service: service_name.to_string(),
108                    url: full_url.clone(),
109                    source: e,
110                });
111            },
112        };
113
114        match ResponseHandler::build_response(response) {
115            Ok(resp) => Ok(resp),
116            Err(e) => {
117                tracing::error!(service = %service_name, error = %e, "Failed to build response");
118                Err(ProxyError::InvalidResponse {
119                    service: service_name.to_string(),
120                    reason: format!("Failed to build response: {e}"),
121                })
122            },
123        }
124    }
125
126    pub async fn handle_mcp_request(
127        &self,
128        path_params: Path<(String,)>,
129        State(ctx): State<AppContext>,
130        request: Request<Body>,
131    ) -> Response<Body> {
132        let Path((service_name,)) = path_params;
133        match self.proxy_request(&service_name, "", request, ctx).await {
134            Ok(response) => response,
135            Err(e) => e.into_response(),
136        }
137    }
138
139    pub async fn handle_mcp_request_with_path(
140        &self,
141        path_params: Path<(String, String)>,
142        State(ctx): State<AppContext>,
143        request: Request<Body>,
144    ) -> Response<Body> {
145        let Path((service_name, path)) = path_params;
146        match self.proxy_request(&service_name, &path, request, ctx).await {
147            Ok(response) => response,
148            Err(e) => e.into_response(),
149        }
150    }
151
152    pub async fn handle_agent_request(
153        &self,
154        path_params: Path<(String,)>,
155        State(ctx): State<AppContext>,
156        request: Request<Body>,
157    ) -> Result<Response<Body>, StatusCode> {
158        let Path((service_name,)) = path_params;
159        self.proxy_request(&service_name, "", request, ctx)
160            .await
161            .map_err(|e| e.to_status_code())
162    }
163
164    pub async fn handle_agent_request_with_path(
165        &self,
166        path_params: Path<(String, String)>,
167        State(ctx): State<AppContext>,
168        request: Request<Body>,
169    ) -> Result<Response<Body>, StatusCode> {
170        let Path((service_name, path)) = path_params;
171        self.proxy_request(&service_name, &path, request, ctx)
172            .await
173            .map_err(|e| e.to_status_code())
174    }
175}