specmock_runtime/http/
ws_handler.rs1use std::sync::Arc;
4
5use axum::{
6 Json,
7 extract::{
8 State,
9 ws::{Message, WebSocket, WebSocketUpgrade},
10 },
11 http::StatusCode,
12 response::IntoResponse,
13};
14use futures_util::StreamExt;
15use serde_json::Value;
16use specmock_core::ValidationIssue;
17use tokio::time::{Duration, Instant};
18
19use super::HttpRuntime;
20use crate::ws::WsOutcome;
21
/// Upper bound on inbound text frames a single connection may have admitted
/// within one rate-limit window (see [`RateLimiter`]).
const MAX_WS_MESSAGES_PER_SECOND: u32 = 100;

/// Per-connection fixed-window limiter for inbound WebSocket text frames.
///
/// A window opens when the limiter is constructed. The first call made at
/// least one second after the current window began opens a fresh window.
/// Within a window at most [`MAX_WS_MESSAGES_PER_SECOND`] calls are admitted.
#[derive(Debug)]
struct RateLimiter {
    /// Number of calls admitted in the current window.
    message_count: u32,
    /// Moment the current window began.
    window_start: Instant,
}

impl RateLimiter {
    /// Creates a limiter whose (empty) window starts now.
    fn new() -> Self {
        let window_start = Instant::now();
        Self { message_count: 0, window_start }
    }

    /// Records one inbound message and reports whether it is within the limit.
    ///
    /// Returns `true` when the message is admitted (and counted) and `false`
    /// when the current window's quota is already used up. Rejected messages
    /// are not counted.
    fn check_and_update(&mut self) -> bool {
        let now = Instant::now();
        let window_length = Duration::from_secs(1);

        if now.duration_since(self.window_start) >= window_length {
            // The previous window has lapsed; this message is the first of a new one.
            self.window_start = now;
            self.message_count = 1;
            true
        } else if self.message_count < MAX_WS_MESSAGES_PER_SECOND {
            self.message_count += 1;
            true
        } else {
            false
        }
    }
}

54pub async fn ws_upgrade_handler(
56 ws: WebSocketUpgrade,
57 State(runtime): State<Arc<HttpRuntime>>,
58 uri: axum::http::Uri,
59) -> impl IntoResponse {
60 if runtime.asyncapi.is_none() {
61 return (
62 StatusCode::NOT_FOUND,
63 Json(serde_json::json!({"error":"asyncapi runtime is not configured"})),
64 )
65 .into_response();
66 }
67
68 let pinned_channel = runtime.resolve_ws_channel(uri.path());
69 ws.on_upgrade(move |socket| ws_socket_loop(socket, runtime, pinned_channel)).into_response()
70}
71
72async fn ws_socket_loop(
74 mut socket: WebSocket,
75 runtime: Arc<HttpRuntime>,
76 pinned_channel: Option<String>,
77) {
78 let mut rate_limiter = RateLimiter::new();
79
80 while let Some(next_item) = socket.next().await {
81 let Ok(message) = next_item else {
82 break;
83 };
84
85 let Message::Text(text) = message else {
86 continue;
87 };
88
89 if !rate_limiter.check_and_update() {
91 let error_response = serde_json::json!({
92 "type": "error",
93 "errors": [{
94 "instance_pointer": "/",
95 "schema_pointer": "#",
96 "keyword": "rate_limit",
97 "message": format!("rate limit exceeded: {} messages per second", MAX_WS_MESSAGES_PER_SECOND)
98 }]
99 });
100 if socket.send(Message::Text(error_response.to_string().into())).await.is_err() {
101 break;
102 }
103 continue;
104 }
105
106 let outcome = runtime.asyncapi.as_ref().map_or_else(
107 || WsOutcome::Error {
108 errors: vec![ValidationIssue {
109 instance_pointer: "/".to_owned(),
110 schema_pointer: "#".to_owned(),
111 keyword: "runtime".to_owned(),
112 message: "asyncapi runtime is not configured".to_owned(),
113 }],
114 },
115 |asyncapi| {
116 if let Some(channel) = &pinned_channel {
117 let envelope = match serde_json::from_str::<Value>(&text) {
119 Ok(payload) => serde_json::json!({"channel": channel, "payload": payload}),
120 Err(_error) => {
121 return asyncapi.handle_message(&text, runtime.seed);
124 }
125 };
126 asyncapi.handle_message(&envelope.to_string(), runtime.seed)
127 } else {
128 asyncapi.handle_message(&text, runtime.seed)
129 }
130 },
131 );
132
133 let encoded = match serde_json::to_string(&outcome) {
134 Ok(value) => value,
135 Err(error) => {
136 let fallback = serde_json::json!({
137 "type": "error",
138 "errors": [{
139 "instance_pointer": "/",
140 "schema_pointer": "#",
141 "keyword": "json",
142 "message": format!("failed to encode ws response: {error}")
143 }]
144 });
145 fallback.to_string()
146 }
147 };
148
149 if socket.send(Message::Text(encoded.into())).await.is_err() {
150 break;
151 }
152 }
153}