Skip to main content

specmock_runtime/http/
ws_handler.rs

1//! WebSocket handler for AsyncAPI runtime.
2
3use std::sync::Arc;
4
5use axum::{
6    Json,
7    extract::{
8        State,
9        ws::{Message, WebSocket, WebSocketUpgrade},
10    },
11    http::StatusCode,
12    response::IntoResponse,
13};
14use futures_util::StreamExt;
15use serde_json::Value;
16use specmock_core::ValidationIssue;
17use tokio::time::{Duration, Instant};
18
19use super::HttpRuntime;
20use crate::ws::WsOutcome;
21
/// Upper bound on the number of WebSocket messages a single connection may send per second.
const MAX_WS_MESSAGES_PER_SECOND: u32 = 100;

/// Fixed-window message counter used to throttle one WebSocket connection.
#[derive(Debug)]
struct RateLimiter {
    /// Messages admitted since `window_start`.
    message_count: u32,
    /// Moment the current counting window was opened.
    window_start: Instant,
}

impl RateLimiter {
    /// Creates a limiter whose first window opens now and has admitted nothing yet.
    fn new() -> Self {
        let window_start = Instant::now();
        Self { message_count: 0, window_start }
    }

    /// Accounts for one incoming message.
    ///
    /// Returns `true` when the message is admitted and `false` when the current window has
    /// already admitted `MAX_WS_MESSAGES_PER_SECOND` messages. Once at least one second has
    /// passed since the window opened, a new window starts at the current instant and the
    /// message is counted as its first entry.
    fn check_and_update(&mut self) -> bool {
        let now = Instant::now();
        let window_age = now.duration_since(self.window_start);

        if window_age < Duration::from_secs(1) {
            // Still inside the open window: admit only while below the cap. Rejected
            // messages leave the counter untouched.
            let within_limit = self.message_count < MAX_WS_MESSAGES_PER_SECOND;
            if within_limit {
                self.message_count += 1;
            }
            within_limit
        } else {
            // Window expired: restart it with this message as the first one counted.
            self.window_start = now;
            self.message_count = 1;
            true
        }
    }
}
53
54/// Handle WebSocket upgrade requests.
55pub async fn ws_upgrade_handler(
56    ws: WebSocketUpgrade,
57    State(runtime): State<Arc<HttpRuntime>>,
58    uri: axum::http::Uri,
59) -> impl IntoResponse {
60    if runtime.asyncapi.is_none() {
61        return (
62            StatusCode::NOT_FOUND,
63            Json(serde_json::json!({"error":"asyncapi runtime is not configured"})),
64        )
65            .into_response();
66    }
67
68    let pinned_channel = runtime.resolve_ws_channel(uri.path());
69    ws.on_upgrade(move |socket| ws_socket_loop(socket, runtime, pinned_channel)).into_response()
70}
71
72/// WebSocket message processing loop.
73async fn ws_socket_loop(
74    mut socket: WebSocket,
75    runtime: Arc<HttpRuntime>,
76    pinned_channel: Option<String>,
77) {
78    let mut rate_limiter = RateLimiter::new();
79
80    while let Some(next_item) = socket.next().await {
81        let Ok(message) = next_item else {
82            break;
83        };
84
85        let Message::Text(text) = message else {
86            continue;
87        };
88
89        // Check rate limit
90        if !rate_limiter.check_and_update() {
91            let error_response = serde_json::json!({
92                "type": "error",
93                "errors": [{
94                    "instance_pointer": "/",
95                    "schema_pointer": "#",
96                    "keyword": "rate_limit",
97                    "message": format!("rate limit exceeded: {} messages per second", MAX_WS_MESSAGES_PER_SECOND)
98                }]
99            });
100            if socket.send(Message::Text(error_response.to_string().into())).await.is_err() {
101                break;
102            }
103            continue;
104        }
105
106        let outcome = runtime.asyncapi.as_ref().map_or_else(
107            || WsOutcome::Error {
108                errors: vec![ValidationIssue {
109                    instance_pointer: "/".to_owned(),
110                    schema_pointer: "#".to_owned(),
111                    keyword: "runtime".to_owned(),
112                    message: "asyncapi runtime is not configured".to_owned(),
113                }],
114            },
115            |asyncapi| {
116                if let Some(channel) = &pinned_channel {
117                    // Wrap raw payload in an explicit envelope for the pinned channel.
118                    let envelope = match serde_json::from_str::<Value>(&text) {
119                        Ok(payload) => serde_json::json!({"channel": channel, "payload": payload}),
120                        Err(_error) => {
121                            // If the raw text is not valid JSON, pass it through and let
122                            // handle_message produce the parse-error outcome.
123                            return asyncapi.handle_message(&text, runtime.seed);
124                        }
125                    };
126                    asyncapi.handle_message(&envelope.to_string(), runtime.seed)
127                } else {
128                    asyncapi.handle_message(&text, runtime.seed)
129                }
130            },
131        );
132
133        let encoded = match serde_json::to_string(&outcome) {
134            Ok(value) => value,
135            Err(error) => {
136                let fallback = serde_json::json!({
137                    "type": "error",
138                    "errors": [{
139                        "instance_pointer": "/",
140                        "schema_pointer": "#",
141                        "keyword": "json",
142                        "message": format!("failed to encode ws response: {error}")
143                    }]
144                });
145                fallback.to_string()
146            }
147        };
148
149        if socket.send(Message::Text(encoded.into())).await.is_err() {
150            break;
151        }
152    }
153}