sentinel_proxy/websocket/
inspector.rs

1//! WebSocket frame inspector for agent integration.
2//!
3//! Sends WebSocket frames to subscribed agents for inspection and applies
4//! their decisions (allow, drop, or close).
5
6use base64::{engine::general_purpose::STANDARD, Engine as _};
7use sentinel_agent_protocol::{WebSocketDecision, WebSocketFrameEvent};
8use sentinel_common::observability::RequestMetrics;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Instant;
12use tracing::{debug, trace, warn};
13
14use super::codec::WebSocketFrame;
15use crate::agents::AgentManager;
16
/// Outcome of inspecting a single WebSocket frame.
///
/// Converted from an agent's [`WebSocketDecision`] and returned to the caller,
/// which is responsible for applying it to the frame stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InspectionResult {
    /// Forward the frame unchanged.
    Allow,
    /// Discard this frame instead of forwarding it.
    Drop,
    /// Close the WebSocket connection with the given close code and reason.
    Close { code: u16, reason: String },
}
27
28impl From<WebSocketDecision> for InspectionResult {
29    fn from(decision: WebSocketDecision) -> Self {
30        match decision {
31            WebSocketDecision::Allow => InspectionResult::Allow,
32            WebSocketDecision::Drop => InspectionResult::Drop,
33            WebSocketDecision::Close { code, reason } => InspectionResult::Close { code, reason },
34        }
35    }
36}
37
38/// WebSocket frame inspector
39///
40/// Handles bidirectional frame inspection by sending frames to agents
41/// and applying their decisions.
42pub struct WebSocketInspector {
43    /// Agent manager for sending events
44    agent_manager: Arc<AgentManager>,
45    /// Route ID for this connection
46    route_id: String,
47    /// Correlation ID (from the original upgrade request)
48    correlation_id: String,
49    /// Client IP address
50    client_ip: String,
51    /// Frame index counter for client -> server direction
52    client_frame_index: AtomicU64,
53    /// Frame index counter for server -> client direction
54    server_frame_index: AtomicU64,
55    /// Timeout for agent calls in milliseconds
56    timeout_ms: u64,
57    /// Metrics collector
58    metrics: Option<Arc<RequestMetrics>>,
59}
60
61impl WebSocketInspector {
62    /// Create a new WebSocket inspector
63    pub fn new(
64        agent_manager: Arc<AgentManager>,
65        route_id: String,
66        correlation_id: String,
67        client_ip: String,
68        timeout_ms: u64,
69    ) -> Self {
70        Self::with_metrics(
71            agent_manager,
72            route_id,
73            correlation_id,
74            client_ip,
75            timeout_ms,
76            None,
77        )
78    }
79
80    /// Create a new WebSocket inspector with metrics
81    pub fn with_metrics(
82        agent_manager: Arc<AgentManager>,
83        route_id: String,
84        correlation_id: String,
85        client_ip: String,
86        timeout_ms: u64,
87        metrics: Option<Arc<RequestMetrics>>,
88    ) -> Self {
89        debug!(
90            route_id = %route_id,
91            correlation_id = %correlation_id,
92            "Creating WebSocket inspector"
93        );
94
95        // Record the WebSocket connection
96        if let Some(ref m) = metrics {
97            m.record_websocket_connection(&route_id);
98        }
99
100        Self {
101            agent_manager,
102            route_id,
103            correlation_id,
104            client_ip,
105            client_frame_index: AtomicU64::new(0),
106            server_frame_index: AtomicU64::new(0),
107            timeout_ms,
108            metrics,
109        }
110    }
111
112    /// Inspect a frame from client to server
113    pub async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
114        let frame_index = self.client_frame_index.fetch_add(1, Ordering::SeqCst);
115
116        trace!(
117            correlation_id = %self.correlation_id,
118            frame_index = frame_index,
119            opcode = ?frame.opcode,
120            "Inspecting client frame"
121        );
122
123        self.inspect_frame(frame, true, frame_index).await
124    }
125
126    /// Inspect a frame from server to client
127    pub async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
128        let frame_index = self.server_frame_index.fetch_add(1, Ordering::SeqCst);
129
130        trace!(
131            correlation_id = %self.correlation_id,
132            frame_index = frame_index,
133            opcode = ?frame.opcode,
134            "Inspecting server frame"
135        );
136
137        self.inspect_frame(frame, false, frame_index).await
138    }
139
140    /// Internal frame inspection
141    async fn inspect_frame(
142        &self,
143        frame: &WebSocketFrame,
144        client_to_server: bool,
145        frame_index: u64,
146    ) -> InspectionResult {
147        let start = Instant::now();
148        let direction = if client_to_server { "c2s" } else { "s2c" };
149        let opcode = frame.opcode.as_str();
150
151        // Record frame size metric
152        if let Some(ref metrics) = self.metrics {
153            metrics.record_websocket_frame_size(
154                &self.route_id,
155                direction,
156                opcode,
157                frame.payload.len(),
158            );
159        }
160
161        let event = WebSocketFrameEvent {
162            correlation_id: self.correlation_id.clone(),
163            opcode: opcode.to_string(),
164            data: STANDARD.encode(&frame.payload),
165            client_to_server,
166            frame_index,
167            fin: frame.fin,
168            route_id: Some(self.route_id.clone()),
169            client_ip: self.client_ip.clone(),
170        };
171
172        // Send to agent manager for processing
173        let result = match tokio::time::timeout(
174            std::time::Duration::from_millis(self.timeout_ms),
175            self.agent_manager
176                .process_websocket_frame(&self.route_id, event),
177        )
178        .await
179        {
180            Ok(Ok(response)) => {
181                if let Some(ws_decision) = response.websocket_decision {
182                    let result = InspectionResult::from(ws_decision);
183                    trace!(
184                        correlation_id = %self.correlation_id,
185                        frame_index = frame_index,
186                        decision = ?result,
187                        "Frame inspection complete"
188                    );
189                    result
190                } else {
191                    // No WebSocket decision means allow
192                    InspectionResult::Allow
193                }
194            }
195            Ok(Err(e)) => {
196                warn!(
197                    correlation_id = %self.correlation_id,
198                    error = %e,
199                    "Agent error during frame inspection, allowing frame"
200                );
201                // Fail-open: allow frame on agent error
202                InspectionResult::Allow
203            }
204            Err(_) => {
205                warn!(
206                    correlation_id = %self.correlation_id,
207                    timeout_ms = self.timeout_ms,
208                    "Agent timeout during frame inspection, allowing frame"
209                );
210                // Fail-open: allow frame on timeout
211                InspectionResult::Allow
212            }
213        };
214
215        // Record metrics
216        if let Some(ref metrics) = self.metrics {
217            let duration = start.elapsed();
218            metrics.record_websocket_inspection_duration(&self.route_id, duration);
219
220            let decision_str = match &result {
221                InspectionResult::Allow => "allow",
222                InspectionResult::Drop => "drop",
223                InspectionResult::Close { .. } => "close",
224            };
225            metrics.record_websocket_frame(&self.route_id, direction, opcode, decision_str);
226        }
227
228        result
229    }
230
231    /// Get the correlation ID
232    pub fn correlation_id(&self) -> &str {
233        &self.correlation_id
234    }
235
236    /// Get the route ID
237    pub fn route_id(&self) -> &str {
238        &self.route_id
239    }
240}
241
242/// Builder for WebSocketInspector
243pub struct WebSocketInspectorBuilder {
244    agent_manager: Option<Arc<AgentManager>>,
245    route_id: Option<String>,
246    correlation_id: Option<String>,
247    client_ip: Option<String>,
248    timeout_ms: u64,
249    metrics: Option<Arc<RequestMetrics>>,
250}
251
252impl Default for WebSocketInspectorBuilder {
253    fn default() -> Self {
254        Self {
255            agent_manager: None,
256            route_id: None,
257            correlation_id: None,
258            client_ip: None,
259            timeout_ms: 100, // 100ms default timeout
260            metrics: None,
261        }
262    }
263}
264
265impl WebSocketInspectorBuilder {
266    /// Create a new builder
267    pub fn new() -> Self {
268        Self::default()
269    }
270
271    /// Set the agent manager
272    pub fn agent_manager(mut self, manager: Arc<AgentManager>) -> Self {
273        self.agent_manager = Some(manager);
274        self
275    }
276
277    /// Set the route ID
278    pub fn route_id(mut self, id: impl Into<String>) -> Self {
279        self.route_id = Some(id.into());
280        self
281    }
282
283    /// Set the correlation ID
284    pub fn correlation_id(mut self, id: impl Into<String>) -> Self {
285        self.correlation_id = Some(id.into());
286        self
287    }
288
289    /// Set the client IP
290    pub fn client_ip(mut self, ip: impl Into<String>) -> Self {
291        self.client_ip = Some(ip.into());
292        self
293    }
294
295    /// Set the timeout in milliseconds
296    pub fn timeout_ms(mut self, ms: u64) -> Self {
297        self.timeout_ms = ms;
298        self
299    }
300
301    /// Set the metrics collector
302    pub fn metrics(mut self, metrics: Arc<RequestMetrics>) -> Self {
303        self.metrics = Some(metrics);
304        self
305    }
306
307    /// Build the inspector
308    pub fn build(self) -> Option<WebSocketInspector> {
309        Some(WebSocketInspector::with_metrics(
310            self.agent_manager?,
311            self.route_id?,
312            self.correlation_id?,
313            self.client_ip?,
314            self.timeout_ms,
315            self.metrics,
316        ))
317    }
318}