Skip to main content

specmock_runtime/http/
mod.rs

1//! HTTP and WebSocket server runtime.
2
3pub mod negotiate;
4pub mod openapi;
5pub mod proxy;
6pub mod router;
7pub mod ws_handler;
8
9use std::{collections::HashMap, net::SocketAddr, sync::Arc};
10
11use axum::{
12    Json, Router,
13    body::{Body, to_bytes},
14    extract::{Request, State},
15    http::{HeaderMap, Method, StatusCode},
16    response::{IntoResponse, Response},
17    routing::get,
18    serve,
19};
20use negotiate::PreferDirectives;
21use openapi::{MatchedOperation, OpenApiRuntime};
22use serde_json::Value;
23use specmock_core::{
24    MockMode, PROBLEM_JSON_CONTENT_TYPE, ProblemDetails, ValidationIssue,
25    faker::generate_json_value,
26};
27use tokio::{net::TcpListener, task::JoinHandle};
28
29use crate::{RuntimeError, ws::AsyncApiRuntime};
30
/// Base of the polynomial byte hash computed by `hash_path_and_method`,
/// whose result is mixed into the per-request mock seed.
const HASH_MULTIPLIER: u64 = 131;
33
34/// Shared runtime state.
35#[derive(Clone)]
36pub struct HttpRuntime {
37    openapi: Option<OpenApiRuntime>,
38    asyncapi: Option<AsyncApiRuntime>,
39    mode: MockMode,
40    upstream: Option<url::Url>,
41    seed: u64,
42    ws_path: String,
43    /// Map from per-channel WS path to channel name.
44    ws_channel_paths: HashMap<String, String>,
45    max_body_size: usize,
46    client: hpx::Client,
47}
48
49impl std::fmt::Debug for HttpRuntime {
50    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        formatter
52            .debug_struct("HttpRuntime")
53            .field("openapi", &self.openapi)
54            .field("asyncapi", &self.asyncapi)
55            .field("mode", &self.mode)
56            .field("upstream", &self.upstream)
57            .field("seed", &self.seed)
58            .field("ws_path", &self.ws_path)
59            .field("ws_channel_paths", &self.ws_channel_paths)
60            .field("max_body_size", &self.max_body_size)
61            .finish_non_exhaustive()
62    }
63}
64
65impl HttpRuntime {
66    /// Resolve the pinned channel name for a WebSocket request path.
67    ///
68    /// Returns `Some(channel_name)` when the path matches a per-channel
69    /// route, or `None` for the catch-all default path.
70    fn resolve_ws_channel(&self, path: &str) -> Option<String> {
71        self.ws_channel_paths.get(path).cloned()
72    }
73
74    /// Build from global server config.
75    pub async fn from_config(config: &crate::ServerConfig) -> Result<Self, RuntimeError> {
76        let openapi = config.openapi_spec.as_deref().map(OpenApiRuntime::from_path).transpose()?;
77        let asyncapi =
78            config.asyncapi_spec.as_deref().map(AsyncApiRuntime::from_path).transpose()?;
79
80        if config.mode == MockMode::Proxy && config.upstream.is_none() {
81            return Err(RuntimeError::Config(
82                "proxy mode requires upstream base URL (--upstream)".to_owned(),
83            ));
84        }
85
86        let upstream = config
87            .upstream
88            .as_ref()
89            .map(|raw| {
90                raw.parse::<url::Url>().map_err(|error| {
91                    RuntimeError::Config(format!("invalid upstream URL '{raw}': {error}"))
92                })
93            })
94            .transpose()?;
95
96        let ws_base = config.ws_path.trim_end_matches('/');
97        let ws_channel_paths: HashMap<String, String> = asyncapi
98            .as_ref()
99            .map(|a| {
100                a.channel_names().into_iter().map(|ch| (format!("{ws_base}/{ch}"), ch)).collect()
101            })
102            .unwrap_or_default();
103
104        Ok(Self {
105            openapi,
106            asyncapi,
107            mode: config.mode,
108            upstream,
109            seed: config.seed,
110            ws_path: config.ws_path.clone(),
111            ws_channel_paths,
112            max_body_size: config.max_body_size,
113            client: hpx::Client::new(),
114        })
115    }
116}
117
118/// Spawn HTTP/WS server.
119pub async fn spawn_http_server(
120    runtime: HttpRuntime,
121    bind_addr: SocketAddr,
122    shutdown: Arc<tokio::sync::Notify>,
123) -> Result<(SocketAddr, JoinHandle<()>), RuntimeError> {
124    let listener = TcpListener::bind(bind_addr).await?;
125    let bound = listener.local_addr()?;
126    let state = Arc::new(runtime);
127
128    let mut app = Router::new().route(&state.ws_path, get(ws_handler::ws_upgrade_handler));
129    for ws_channel_path in state.ws_channel_paths.keys() {
130        app = app.route(ws_channel_path, get(ws_handler::ws_upgrade_handler));
131    }
132    let app = app.fallback(http_fallback_handler).with_state(Arc::clone(&state));
133
134    let task = tokio::spawn(async move {
135        let _ignored = serve(listener, app)
136            .with_graceful_shutdown(async move {
137                shutdown.notified().await;
138            })
139            .await;
140    });
141
142    Ok((bound, task))
143}
144
145async fn http_fallback_handler(
146    State(runtime): State<Arc<HttpRuntime>>,
147    request: Request,
148) -> Response {
149    let method = request.method().clone();
150    let uri = request.uri().clone();
151    let headers = request.headers().clone();
152
153    let body_bytes = match to_bytes(request.into_body(), runtime.max_body_size).await {
154        Ok(bytes) => bytes,
155        Err(_error) => {
156            return problem_response(ProblemDetails::payload_too_large(&format!(
157                "request body exceeds maximum size of {} bytes",
158                runtime.max_body_size
159            )));
160        }
161    };
162
163    let Some(openapi) = &runtime.openapi else {
164        return problem_response(ProblemDetails::not_found("no OpenAPI runtime configured"));
165    };
166
167    let path = uri.path().to_owned();
168    let Some(matched) = openapi.match_operation(&method, &path) else {
169        return problem_response(ProblemDetails::not_found("operation not found"));
170    };
171
172    // Content-Type validation: if operation declares a request body schema and the body
173    // is non-empty, require a JSON-compatible Content-Type.
174    if matched.operation.request_body_schema.is_some() &&
175        !body_bytes.is_empty() &&
176        !header_is_json(&headers)
177    {
178        return problem_response(ProblemDetails::unsupported_media_type(
179            "Content-Type must be application/json for this operation",
180        ));
181    }
182
183    let query_params = parse_query(uri.query());
184    let request_body_json = match parse_optional_json_body(
185        &headers,
186        &body_bytes,
187        matched.operation.request_body_schema.is_some(),
188    ) {
189        Ok(value) => value,
190        Err(issue) => return error_response(StatusCode::BAD_REQUEST, vec![issue]),
191    };
192
193    let validation_issues =
194        validate_http_request(&matched, &query_params, &headers, request_body_json.as_ref());
195    if !validation_issues.is_empty() {
196        return error_response(StatusCode::BAD_REQUEST, validation_issues);
197    }
198
199    if runtime.mode == MockMode::Proxy &&
200        let Some(upstream) = &runtime.upstream
201    {
202        return proxy::proxy_request(
203            runtime.as_ref(),
204            upstream,
205            &method,
206            &uri,
207            &headers,
208            &body_bytes,
209            &matched,
210        )
211        .await;
212    }
213
214    let prefer = PreferDirectives::from_headers(&headers);
215    let seed = runtime.seed ^ hash_path_and_method(&path, &method);
216    let response = match matched.operation.mock_response(seed, &prefer) {
217        Ok(mock_response) => {
218            if let Some(body) = mock_response.body {
219                json_response(
220                    StatusCode::from_u16(mock_response.status).unwrap_or(StatusCode::OK),
221                    &body,
222                )
223            } else {
224                Response::builder()
225                    .status(StatusCode::from_u16(mock_response.status).unwrap_or(StatusCode::OK))
226                    .body(Body::empty())
227                    .unwrap_or_else(|_error| Response::new(Body::empty()))
228            }
229        }
230        Err(error) => {
231            return error_response(
232                StatusCode::INTERNAL_SERVER_ERROR,
233                vec![ValidationIssue {
234                    instance_pointer: "/response".to_owned(),
235                    schema_pointer: "#/responses".to_owned(),
236                    keyword: "response".to_owned(),
237                    message: error.to_string(),
238                }],
239            );
240        }
241    };
242
243    // Fire callbacks asynchronously (fire-and-forget).
244    if !matched.operation.callbacks.is_empty() {
245        for callback in &matched.operation.callbacks {
246            if let Some(url) = openapi::resolve_callback_url(
247                &callback.callback_url_expression,
248                request_body_json.as_ref(),
249            ) {
250                let client = runtime.client.clone();
251                let cb_method = callback.method.clone();
252                let cb_schema = callback.request_body_schema.clone();
253                tokio::spawn(async move {
254                    fire_callback(&client, &url, &cb_method, cb_schema.as_ref(), seed).await;
255                });
256            }
257        }
258    }
259
260    response
261}
262
263fn validate_http_request(
264    matched: &MatchedOperation<'_>,
265    query_params: &HashMap<String, Vec<String>>,
266    headers: &HeaderMap,
267    body_json: Option<&Value>,
268) -> Vec<ValidationIssue> {
269    matched.operation.validate_request(&matched.path_params, query_params, headers, body_json)
270}
271
272/// Fire an outbound callback request. Errors are logged but never propagated.
273async fn fire_callback(
274    client: &hpx::Client,
275    url: &str,
276    method: &Method,
277    schema: Option<&Value>,
278    seed: u64,
279) {
280    let body = schema.and_then(|s| generate_json_value(s, seed).ok());
281    let mut req = client.request(method.clone(), url);
282    if let Some(body) = body {
283        let encoded = serde_json::to_vec(&body).unwrap_or_default();
284        req = req.header("content-type", "application/json").body(encoded);
285    }
286    match req.send().await {
287        Ok(response) => tracing::info!(status = %response.status(), url, "callback fired"),
288        Err(error) => tracing::warn!(%error, url, "callback failed"),
289    }
290}
291
292fn parse_optional_json_body(
293    headers: &HeaderMap,
294    bytes: &[u8],
295    should_parse: bool,
296) -> Result<Option<Value>, ValidationIssue> {
297    if !should_parse || bytes.is_empty() {
298        return Ok(None);
299    }
300    if !header_is_json(headers) {
301        return Ok(None);
302    }
303    serde_json::from_slice::<Value>(bytes).map(Some).map_err(|error| ValidationIssue {
304        instance_pointer: "/body".to_owned(),
305        schema_pointer: "#/requestBody".to_owned(),
306        keyword: "json".to_owned(),
307        message: format!("invalid json request body: {error}"),
308    })
309}
310
311fn parse_query(query: Option<&str>) -> HashMap<String, Vec<String>> {
312    let mut out: HashMap<String, Vec<String>> = HashMap::new();
313    if let Some(raw) = query {
314        for (key, value) in url::form_urlencoded::parse(raw.as_bytes()) {
315            out.entry(key.into_owned()).or_default().push(value.into_owned());
316        }
317    }
318    out
319}
320
321fn header_is_json(headers: &HeaderMap) -> bool {
322    headers
323        .get(axum::http::header::CONTENT_TYPE)
324        .and_then(|value| value.to_str().ok())
325        .is_some_and(|value| value.to_ascii_lowercase().contains("application/json"))
326}
327
328fn error_response(status: StatusCode, issues: Vec<ValidationIssue>) -> Response {
329    let problem = ProblemDetails::validation_error(status.as_u16(), issues);
330    problem_response(problem)
331}
332
333fn problem_response(problem: ProblemDetails) -> Response {
334    let status = StatusCode::from_u16(problem.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
335    let body = serde_json::to_vec(&problem).unwrap_or_default();
336    Response::builder()
337        .status(status)
338        .header(axum::http::header::CONTENT_TYPE, PROBLEM_JSON_CONTENT_TYPE)
339        .body(Body::from(body))
340        .unwrap_or_else(|_| Response::new(Body::empty()))
341}
342
343fn json_response(status: StatusCode, body: &Value) -> Response {
344    (status, Json(body.clone())).into_response()
345}
346
347fn hash_path_and_method(path: &str, method: &Method) -> u64 {
348    let method_hash = method
349        .as_str()
350        .bytes()
351        .fold(0_u64, |acc, byte| acc.wrapping_mul(HASH_MULTIPLIER).wrapping_add(u64::from(byte)));
352    path.bytes().fold(method_hash, |acc, byte| {
353        acc.wrapping_mul(HASH_MULTIPLIER).wrapping_add(u64::from(byte))
354    })
355}