sentinel_proxy/websocket/
inspector.rs1use base64::{engine::general_purpose::STANDARD, Engine as _};
7use sentinel_agent_protocol::{WebSocketDecision, WebSocketFrameEvent};
8use sentinel_common::observability::RequestMetrics;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Instant;
12use tracing::{debug, trace, warn};
13
14use super::codec::WebSocketFrame;
15use crate::agents::AgentManager;
16
/// Verdict for a single inspected WebSocket frame.
///
/// Built by [`WebSocketInspector`] from an agent's `WebSocketDecision`.
/// Agent errors, agent timeouts, and responses without a WebSocket decision
/// are all reported as [`InspectionResult::Allow`] (the inspector fails open).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InspectionResult {
    /// The frame was allowed (or no agent verdict was available).
    Allow,
    /// The agent asked for this frame to be dropped.
    Drop,
    /// The agent asked for the connection to be closed with `code` and `reason`.
    Close { code: u16, reason: String },
}
27
28impl From<WebSocketDecision> for InspectionResult {
29 fn from(decision: WebSocketDecision) -> Self {
30 match decision {
31 WebSocketDecision::Allow => InspectionResult::Allow,
32 WebSocketDecision::Drop => InspectionResult::Drop,
33 WebSocketDecision::Close { code, reason } => InspectionResult::Close { code, reason },
34 }
35 }
36}
37
38pub struct WebSocketInspector {
43 agent_manager: Arc<AgentManager>,
45 route_id: String,
47 correlation_id: String,
49 client_ip: String,
51 client_frame_index: AtomicU64,
53 server_frame_index: AtomicU64,
55 timeout_ms: u64,
57 metrics: Option<Arc<RequestMetrics>>,
59}
60
61impl WebSocketInspector {
62 pub fn new(
64 agent_manager: Arc<AgentManager>,
65 route_id: String,
66 correlation_id: String,
67 client_ip: String,
68 timeout_ms: u64,
69 ) -> Self {
70 Self::with_metrics(
71 agent_manager,
72 route_id,
73 correlation_id,
74 client_ip,
75 timeout_ms,
76 None,
77 )
78 }
79
80 pub fn with_metrics(
82 agent_manager: Arc<AgentManager>,
83 route_id: String,
84 correlation_id: String,
85 client_ip: String,
86 timeout_ms: u64,
87 metrics: Option<Arc<RequestMetrics>>,
88 ) -> Self {
89 debug!(
90 route_id = %route_id,
91 correlation_id = %correlation_id,
92 "Creating WebSocket inspector"
93 );
94
95 if let Some(ref m) = metrics {
97 m.record_websocket_connection(&route_id);
98 }
99
100 Self {
101 agent_manager,
102 route_id,
103 correlation_id,
104 client_ip,
105 client_frame_index: AtomicU64::new(0),
106 server_frame_index: AtomicU64::new(0),
107 timeout_ms,
108 metrics,
109 }
110 }
111
112 pub async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
114 let frame_index = self.client_frame_index.fetch_add(1, Ordering::SeqCst);
115
116 trace!(
117 correlation_id = %self.correlation_id,
118 frame_index = frame_index,
119 opcode = ?frame.opcode,
120 "Inspecting client frame"
121 );
122
123 self.inspect_frame(frame, true, frame_index).await
124 }
125
126 pub async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
128 let frame_index = self.server_frame_index.fetch_add(1, Ordering::SeqCst);
129
130 trace!(
131 correlation_id = %self.correlation_id,
132 frame_index = frame_index,
133 opcode = ?frame.opcode,
134 "Inspecting server frame"
135 );
136
137 self.inspect_frame(frame, false, frame_index).await
138 }
139
140 async fn inspect_frame(
142 &self,
143 frame: &WebSocketFrame,
144 client_to_server: bool,
145 frame_index: u64,
146 ) -> InspectionResult {
147 let start = Instant::now();
148 let direction = if client_to_server { "c2s" } else { "s2c" };
149 let opcode = frame.opcode.as_str();
150
151 if let Some(ref metrics) = self.metrics {
153 metrics.record_websocket_frame_size(
154 &self.route_id,
155 direction,
156 opcode,
157 frame.payload.len(),
158 );
159 }
160
161 let event = WebSocketFrameEvent {
162 correlation_id: self.correlation_id.clone(),
163 opcode: opcode.to_string(),
164 data: STANDARD.encode(&frame.payload),
165 client_to_server,
166 frame_index,
167 fin: frame.fin,
168 route_id: Some(self.route_id.clone()),
169 client_ip: self.client_ip.clone(),
170 };
171
172 let result = match tokio::time::timeout(
174 std::time::Duration::from_millis(self.timeout_ms),
175 self.agent_manager
176 .process_websocket_frame(&self.route_id, event),
177 )
178 .await
179 {
180 Ok(Ok(response)) => {
181 if let Some(ws_decision) = response.websocket_decision {
182 let result = InspectionResult::from(ws_decision);
183 trace!(
184 correlation_id = %self.correlation_id,
185 frame_index = frame_index,
186 decision = ?result,
187 "Frame inspection complete"
188 );
189 result
190 } else {
191 InspectionResult::Allow
193 }
194 }
195 Ok(Err(e)) => {
196 warn!(
197 correlation_id = %self.correlation_id,
198 error = %e,
199 "Agent error during frame inspection, allowing frame"
200 );
201 InspectionResult::Allow
203 }
204 Err(_) => {
205 warn!(
206 correlation_id = %self.correlation_id,
207 timeout_ms = self.timeout_ms,
208 "Agent timeout during frame inspection, allowing frame"
209 );
210 InspectionResult::Allow
212 }
213 };
214
215 if let Some(ref metrics) = self.metrics {
217 let duration = start.elapsed();
218 metrics.record_websocket_inspection_duration(&self.route_id, duration);
219
220 let decision_str = match &result {
221 InspectionResult::Allow => "allow",
222 InspectionResult::Drop => "drop",
223 InspectionResult::Close { .. } => "close",
224 };
225 metrics.record_websocket_frame(&self.route_id, direction, opcode, decision_str);
226 }
227
228 result
229 }
230
231 pub fn correlation_id(&self) -> &str {
233 &self.correlation_id
234 }
235
236 pub fn route_id(&self) -> &str {
238 &self.route_id
239 }
240}
241
242pub struct WebSocketInspectorBuilder {
244 agent_manager: Option<Arc<AgentManager>>,
245 route_id: Option<String>,
246 correlation_id: Option<String>,
247 client_ip: Option<String>,
248 timeout_ms: u64,
249 metrics: Option<Arc<RequestMetrics>>,
250}
251
252impl Default for WebSocketInspectorBuilder {
253 fn default() -> Self {
254 Self {
255 agent_manager: None,
256 route_id: None,
257 correlation_id: None,
258 client_ip: None,
259 timeout_ms: 100, metrics: None,
261 }
262 }
263}
264
265impl WebSocketInspectorBuilder {
266 pub fn new() -> Self {
268 Self::default()
269 }
270
271 pub fn agent_manager(mut self, manager: Arc<AgentManager>) -> Self {
273 self.agent_manager = Some(manager);
274 self
275 }
276
277 pub fn route_id(mut self, id: impl Into<String>) -> Self {
279 self.route_id = Some(id.into());
280 self
281 }
282
283 pub fn correlation_id(mut self, id: impl Into<String>) -> Self {
285 self.correlation_id = Some(id.into());
286 self
287 }
288
289 pub fn client_ip(mut self, ip: impl Into<String>) -> Self {
291 self.client_ip = Some(ip.into());
292 self
293 }
294
295 pub fn timeout_ms(mut self, ms: u64) -> Self {
297 self.timeout_ms = ms;
298 self
299 }
300
301 pub fn metrics(mut self, metrics: Arc<RequestMetrics>) -> Self {
303 self.metrics = Some(metrics);
304 self
305 }
306
307 pub fn build(self) -> Option<WebSocketInspector> {
309 Some(WebSocketInspector::with_metrics(
310 self.agent_manager?,
311 self.route_id?,
312 self.correlation_id?,
313 self.client_ip?,
314 self.timeout_ms,
315 self.metrics,
316 ))
317 }
318}