Skip to main content

specmock_runtime/http/
proxy.rs

1//! Proxy handler for forwarding requests to upstream servers.
2
3use axum::{
4    body::Body,
5    http::{HeaderMap, Method, StatusCode},
6    response::Response,
7};
8use serde_json::Value;
9use specmock_core::{ValidationIssue, validate::validate_instance};
10
11use super::{HttpRuntime, error_response, header_is_json};
12
/// Incoming request header that is never forwarded, because the body is
/// re-sent from a byte buffer and the HTTP client computes its own length.
const HEADER_CONTENT_LENGTH: &str = "content-length";
/// Incoming request header that is never forwarded, because the proxy sets
/// it from the upstream URL instead.
const HEADER_HOST: &str = "host";
15
16/// Proxy a request to the upstream server and validate the response.
17pub async fn proxy_request(
18    runtime: &HttpRuntime,
19    upstream: &url::Url,
20    method: &Method,
21    uri: &axum::http::Uri,
22    headers: &HeaderMap,
23    body_bytes: &[u8],
24    matched: &super::openapi::MatchedOperation<'_>,
25) -> Response {
26    let target_url = format!(
27        "{}{}{}",
28        upstream.as_str().trim_end_matches('/'),
29        uri.path(),
30        uri.query().map_or_else(String::new, |query| format!("?{query}"))
31    );
32
33    let mut request_builder =
34        runtime.client.request(method.clone(), target_url).body(body_bytes.to_vec());
35    for (name, value) in headers {
36        let lower = name.as_str().to_ascii_lowercase();
37        if lower == HEADER_HOST || lower == HEADER_CONTENT_LENGTH {
38            continue;
39        }
40        request_builder = request_builder.header(name, value);
41    }
42
43    // Set Host header from upstream URL so the proxy target receives the
44    // correct virtual-host identity.
45    if let Some(host) = upstream.host_str() {
46        let host_value = if let Some(port) = upstream.port() {
47            format!("{host}:{port}")
48        } else {
49            host.to_owned()
50        };
51        request_builder = request_builder.header("Host", host_value);
52    }
53
54    let upstream_response = match request_builder.send().await {
55        Ok(response) => response,
56        Err(error) => {
57            return error_response(
58                StatusCode::BAD_GATEWAY,
59                vec![ValidationIssue {
60                    instance_pointer: "/proxy".to_owned(),
61                    schema_pointer: "#".to_owned(),
62                    keyword: "proxy".to_owned(),
63                    message: format!("upstream request failed: {error}"),
64                }],
65            );
66        }
67    };
68
69    let status = upstream_response.status();
70    let response_headers = upstream_response.headers().clone();
71    let response_bytes = match upstream_response.bytes().await {
72        Ok(bytes) => bytes,
73        Err(error) => {
74            return error_response(
75                StatusCode::BAD_GATEWAY,
76                vec![ValidationIssue {
77                    instance_pointer: "/response/body".to_owned(),
78                    schema_pointer: "#".to_owned(),
79                    keyword: "proxy".to_owned(),
80                    message: format!("failed to read upstream response body: {error}"),
81                }],
82            );
83        }
84    };
85
86    if let Some(schema) = matched.operation.response_schema_for_status(status.as_u16()) &&
87        header_is_json(&response_headers)
88    {
89        match serde_json::from_slice::<Value>(&response_bytes) {
90            Ok(response_json) => match validate_instance(schema, &response_json) {
91                Ok(issues) if !issues.is_empty() => {
92                    return error_response(StatusCode::BAD_GATEWAY, issues);
93                }
94                Ok(_issues) => {}
95                Err(error) => {
96                    return error_response(
97                        StatusCode::BAD_GATEWAY,
98                        vec![ValidationIssue {
99                            instance_pointer: "/response".to_owned(),
100                            schema_pointer: "#/responses".to_owned(),
101                            keyword: "schema".to_owned(),
102                            message: error.to_string(),
103                        }],
104                    );
105                }
106            },
107            Err(error) => {
108                return error_response(
109                    StatusCode::BAD_GATEWAY,
110                    vec![ValidationIssue {
111                        instance_pointer: "/response/body".to_owned(),
112                        schema_pointer: "#/responses".to_owned(),
113                        keyword: "json".to_owned(),
114                        message: format!("upstream response is not valid json: {error}"),
115                    }],
116                );
117            }
118        }
119    }
120
121    let mut builder = Response::builder().status(status);
122    if let Some(target_headers) = builder.headers_mut() {
123        for (name, value) in &response_headers {
124            target_headers.append(name, value.clone());
125        }
126    }
127    builder.body(Body::from(response_bytes)).unwrap_or_else(|_error| Response::new(Body::empty()))
128}