1pub mod negotiate;
4pub mod openapi;
5pub mod proxy;
6pub mod router;
7pub mod ws_handler;
8
9use std::{collections::HashMap, net::SocketAddr, sync::Arc};
10
11use axum::{
12 Json, Router,
13 body::{Body, to_bytes},
14 extract::{Request, State},
15 http::{HeaderMap, Method, StatusCode},
16 response::{IntoResponse, Response},
17 routing::get,
18 serve,
19};
20use negotiate::PreferDirectives;
21use openapi::{MatchedOperation, OpenApiRuntime};
22use serde_json::Value;
23use specmock_core::{
24 MockMode, PROBLEM_JSON_CONTENT_TYPE, ProblemDetails, ValidationIssue,
25 faker::generate_json_value,
26};
27use tokio::{net::TcpListener, task::JoinHandle};
28
29use crate::{RuntimeError, ws::AsyncApiRuntime};
30
/// Multiplier of the polynomial rolling hash computed by `hash_path_and_method`
/// (`0x83` == 131).
const HASH_MULTIPLIER: u64 = 0x83;
33
34#[derive(Clone)]
36pub struct HttpRuntime {
37 openapi: Option<OpenApiRuntime>,
38 asyncapi: Option<AsyncApiRuntime>,
39 mode: MockMode,
40 upstream: Option<url::Url>,
41 seed: u64,
42 ws_path: String,
43 ws_channel_paths: HashMap<String, String>,
45 max_body_size: usize,
46 client: hpx::Client,
47}
48
49impl std::fmt::Debug for HttpRuntime {
50 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 formatter
52 .debug_struct("HttpRuntime")
53 .field("openapi", &self.openapi)
54 .field("asyncapi", &self.asyncapi)
55 .field("mode", &self.mode)
56 .field("upstream", &self.upstream)
57 .field("seed", &self.seed)
58 .field("ws_path", &self.ws_path)
59 .field("ws_channel_paths", &self.ws_channel_paths)
60 .field("max_body_size", &self.max_body_size)
61 .finish_non_exhaustive()
62 }
63}
64
65impl HttpRuntime {
66 fn resolve_ws_channel(&self, path: &str) -> Option<String> {
71 self.ws_channel_paths.get(path).cloned()
72 }
73
74 pub async fn from_config(config: &crate::ServerConfig) -> Result<Self, RuntimeError> {
76 let openapi = config.openapi_spec.as_deref().map(OpenApiRuntime::from_path).transpose()?;
77 let asyncapi =
78 config.asyncapi_spec.as_deref().map(AsyncApiRuntime::from_path).transpose()?;
79
80 if config.mode == MockMode::Proxy && config.upstream.is_none() {
81 return Err(RuntimeError::Config(
82 "proxy mode requires upstream base URL (--upstream)".to_owned(),
83 ));
84 }
85
86 let upstream = config
87 .upstream
88 .as_ref()
89 .map(|raw| {
90 raw.parse::<url::Url>().map_err(|error| {
91 RuntimeError::Config(format!("invalid upstream URL '{raw}': {error}"))
92 })
93 })
94 .transpose()?;
95
96 let ws_base = config.ws_path.trim_end_matches('/');
97 let ws_channel_paths: HashMap<String, String> = asyncapi
98 .as_ref()
99 .map(|a| {
100 a.channel_names().into_iter().map(|ch| (format!("{ws_base}/{ch}"), ch)).collect()
101 })
102 .unwrap_or_default();
103
104 Ok(Self {
105 openapi,
106 asyncapi,
107 mode: config.mode,
108 upstream,
109 seed: config.seed,
110 ws_path: config.ws_path.clone(),
111 ws_channel_paths,
112 max_body_size: config.max_body_size,
113 client: hpx::Client::new(),
114 })
115 }
116}
117
118pub async fn spawn_http_server(
120 runtime: HttpRuntime,
121 bind_addr: SocketAddr,
122 shutdown: Arc<tokio::sync::Notify>,
123) -> Result<(SocketAddr, JoinHandle<()>), RuntimeError> {
124 let listener = TcpListener::bind(bind_addr).await?;
125 let bound = listener.local_addr()?;
126 let state = Arc::new(runtime);
127
128 let mut app = Router::new().route(&state.ws_path, get(ws_handler::ws_upgrade_handler));
129 for ws_channel_path in state.ws_channel_paths.keys() {
130 app = app.route(ws_channel_path, get(ws_handler::ws_upgrade_handler));
131 }
132 let app = app.fallback(http_fallback_handler).with_state(Arc::clone(&state));
133
134 let task = tokio::spawn(async move {
135 let _ignored = serve(listener, app)
136 .with_graceful_shutdown(async move {
137 shutdown.notified().await;
138 })
139 .await;
140 });
141
142 Ok((bound, task))
143}
144
145async fn http_fallback_handler(
146 State(runtime): State<Arc<HttpRuntime>>,
147 request: Request,
148) -> Response {
149 let method = request.method().clone();
150 let uri = request.uri().clone();
151 let headers = request.headers().clone();
152
153 let body_bytes = match to_bytes(request.into_body(), runtime.max_body_size).await {
154 Ok(bytes) => bytes,
155 Err(_error) => {
156 return problem_response(ProblemDetails::payload_too_large(&format!(
157 "request body exceeds maximum size of {} bytes",
158 runtime.max_body_size
159 )));
160 }
161 };
162
163 let Some(openapi) = &runtime.openapi else {
164 return problem_response(ProblemDetails::not_found("no OpenAPI runtime configured"));
165 };
166
167 let path = uri.path().to_owned();
168 let Some(matched) = openapi.match_operation(&method, &path) else {
169 return problem_response(ProblemDetails::not_found("operation not found"));
170 };
171
172 if matched.operation.request_body_schema.is_some() &&
175 !body_bytes.is_empty() &&
176 !header_is_json(&headers)
177 {
178 return problem_response(ProblemDetails::unsupported_media_type(
179 "Content-Type must be application/json for this operation",
180 ));
181 }
182
183 let query_params = parse_query(uri.query());
184 let request_body_json = match parse_optional_json_body(
185 &headers,
186 &body_bytes,
187 matched.operation.request_body_schema.is_some(),
188 ) {
189 Ok(value) => value,
190 Err(issue) => return error_response(StatusCode::BAD_REQUEST, vec![issue]),
191 };
192
193 let validation_issues =
194 validate_http_request(&matched, &query_params, &headers, request_body_json.as_ref());
195 if !validation_issues.is_empty() {
196 return error_response(StatusCode::BAD_REQUEST, validation_issues);
197 }
198
199 if runtime.mode == MockMode::Proxy &&
200 let Some(upstream) = &runtime.upstream
201 {
202 return proxy::proxy_request(
203 runtime.as_ref(),
204 upstream,
205 &method,
206 &uri,
207 &headers,
208 &body_bytes,
209 &matched,
210 )
211 .await;
212 }
213
214 let prefer = PreferDirectives::from_headers(&headers);
215 let seed = runtime.seed ^ hash_path_and_method(&path, &method);
216 let response = match matched.operation.mock_response(seed, &prefer) {
217 Ok(mock_response) => {
218 if let Some(body) = mock_response.body {
219 json_response(
220 StatusCode::from_u16(mock_response.status).unwrap_or(StatusCode::OK),
221 &body,
222 )
223 } else {
224 Response::builder()
225 .status(StatusCode::from_u16(mock_response.status).unwrap_or(StatusCode::OK))
226 .body(Body::empty())
227 .unwrap_or_else(|_error| Response::new(Body::empty()))
228 }
229 }
230 Err(error) => {
231 return error_response(
232 StatusCode::INTERNAL_SERVER_ERROR,
233 vec![ValidationIssue {
234 instance_pointer: "/response".to_owned(),
235 schema_pointer: "#/responses".to_owned(),
236 keyword: "response".to_owned(),
237 message: error.to_string(),
238 }],
239 );
240 }
241 };
242
243 if !matched.operation.callbacks.is_empty() {
245 for callback in &matched.operation.callbacks {
246 if let Some(url) = openapi::resolve_callback_url(
247 &callback.callback_url_expression,
248 request_body_json.as_ref(),
249 ) {
250 let client = runtime.client.clone();
251 let cb_method = callback.method.clone();
252 let cb_schema = callback.request_body_schema.clone();
253 tokio::spawn(async move {
254 fire_callback(&client, &url, &cb_method, cb_schema.as_ref(), seed).await;
255 });
256 }
257 }
258 }
259
260 response
261}
262
263fn validate_http_request(
264 matched: &MatchedOperation<'_>,
265 query_params: &HashMap<String, Vec<String>>,
266 headers: &HeaderMap,
267 body_json: Option<&Value>,
268) -> Vec<ValidationIssue> {
269 matched.operation.validate_request(&matched.path_params, query_params, headers, body_json)
270}
271
272async fn fire_callback(
274 client: &hpx::Client,
275 url: &str,
276 method: &Method,
277 schema: Option<&Value>,
278 seed: u64,
279) {
280 let body = schema.and_then(|s| generate_json_value(s, seed).ok());
281 let mut req = client.request(method.clone(), url);
282 if let Some(body) = body {
283 let encoded = serde_json::to_vec(&body).unwrap_or_default();
284 req = req.header("content-type", "application/json").body(encoded);
285 }
286 match req.send().await {
287 Ok(response) => tracing::info!(status = %response.status(), url, "callback fired"),
288 Err(error) => tracing::warn!(%error, url, "callback failed"),
289 }
290}
291
292fn parse_optional_json_body(
293 headers: &HeaderMap,
294 bytes: &[u8],
295 should_parse: bool,
296) -> Result<Option<Value>, ValidationIssue> {
297 if !should_parse || bytes.is_empty() {
298 return Ok(None);
299 }
300 if !header_is_json(headers) {
301 return Ok(None);
302 }
303 serde_json::from_slice::<Value>(bytes).map(Some).map_err(|error| ValidationIssue {
304 instance_pointer: "/body".to_owned(),
305 schema_pointer: "#/requestBody".to_owned(),
306 keyword: "json".to_owned(),
307 message: format!("invalid json request body: {error}"),
308 })
309}
310
311fn parse_query(query: Option<&str>) -> HashMap<String, Vec<String>> {
312 let mut out: HashMap<String, Vec<String>> = HashMap::new();
313 if let Some(raw) = query {
314 for (key, value) in url::form_urlencoded::parse(raw.as_bytes()) {
315 out.entry(key.into_owned()).or_default().push(value.into_owned());
316 }
317 }
318 out
319}
320
321fn header_is_json(headers: &HeaderMap) -> bool {
322 headers
323 .get(axum::http::header::CONTENT_TYPE)
324 .and_then(|value| value.to_str().ok())
325 .is_some_and(|value| value.to_ascii_lowercase().contains("application/json"))
326}
327
328fn error_response(status: StatusCode, issues: Vec<ValidationIssue>) -> Response {
329 let problem = ProblemDetails::validation_error(status.as_u16(), issues);
330 problem_response(problem)
331}
332
333fn problem_response(problem: ProblemDetails) -> Response {
334 let status = StatusCode::from_u16(problem.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
335 let body = serde_json::to_vec(&problem).unwrap_or_default();
336 Response::builder()
337 .status(status)
338 .header(axum::http::header::CONTENT_TYPE, PROBLEM_JSON_CONTENT_TYPE)
339 .body(Body::from(body))
340 .unwrap_or_else(|_| Response::new(Body::empty()))
341}
342
343fn json_response(status: StatusCode, body: &Value) -> Response {
344 (status, Json(body.clone())).into_response()
345}
346
347fn hash_path_and_method(path: &str, method: &Method) -> u64 {
348 let method_hash = method
349 .as_str()
350 .bytes()
351 .fold(0_u64, |acc, byte| acc.wrapping_mul(HASH_MULTIPLIER).wrapping_add(u64::from(byte)));
352 path.bytes().fold(method_hash, |acc, byte| {
353 acc.wrapping_mul(HASH_MULTIPLIER).wrapping_add(u64::from(byte))
354 })
355}