sentinel_proxy/websocket/
proxy.rs1use bytes::{Bytes, BytesMut};
22use std::sync::Arc;
23use tokio::sync::Mutex;
24use tracing::{debug, trace, warn};
25
26use super::codec::{WebSocketCodec, WebSocketFrame};
27use super::inspector::{InspectionResult, WebSocketInspector};
28
29#[async_trait::async_trait]
34pub trait FrameInspector: Send + Sync {
35 async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult;
37
38 async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult;
40
41 fn correlation_id(&self) -> &str;
43}
44
45#[async_trait::async_trait]
47impl FrameInspector for WebSocketInspector {
48 async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
49 WebSocketInspector::inspect_client_frame(self, frame).await
50 }
51
52 async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
53 WebSocketInspector::inspect_server_frame(self, frame).await
54 }
55
56 fn correlation_id(&self) -> &str {
57 WebSocketInspector::correlation_id(self)
58 }
59}
60
61pub struct WebSocketHandler<I: FrameInspector = WebSocketInspector> {
66 codec: WebSocketCodec,
68 inspector: Arc<I>,
70 client_buffer: Mutex<BytesMut>,
72 server_buffer: Mutex<BytesMut>,
74 should_close: Mutex<Option<CloseReason>>,
76}
77
/// Why a proxied WebSocket connection is being closed.
///
/// Produced when an inspector returns a close decision, and carried by
/// `ProcessResult::Close`. Derives `PartialEq`/`Eq` so callers can compare
/// reasons directly instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseReason {
    /// WebSocket close status code (e.g. 1000 for normal closure, 1008 for policy violation).
    pub code: u16,
    /// Human-readable close reason text.
    pub reason: String,
}
84
85#[derive(Debug)]
87pub enum ProcessResult {
88 Forward(Option<Bytes>),
90 Close(CloseReason),
92}
93
94impl<I: FrameInspector> WebSocketHandler<I> {
95 pub fn with_inspector(inspector: Arc<I>, max_frame_size: usize) -> Self {
97 debug!(
98 correlation_id = %inspector.correlation_id(),
99 max_frame_size = max_frame_size,
100 "Creating WebSocket handler"
101 );
102
103 Self {
104 codec: WebSocketCodec::new(max_frame_size),
105 inspector,
106 client_buffer: Mutex::new(BytesMut::with_capacity(4096)),
107 server_buffer: Mutex::new(BytesMut::with_capacity(4096)),
108 should_close: Mutex::new(None),
109 }
110 }
111}
112
113impl WebSocketHandler<WebSocketInspector> {
114 pub fn new(inspector: Arc<WebSocketInspector>, max_frame_size: usize) -> Self {
116 Self::with_inspector(inspector, max_frame_size)
117 }
118}
119
120impl<I: FrameInspector> WebSocketHandler<I> {
121 pub async fn process_client_data(&self, data: Option<Bytes>) -> ProcessResult {
125 if let Some(reason) = self.should_close.lock().await.clone() {
127 return ProcessResult::Close(reason);
128 }
129
130 let Some(data) = data else {
131 return ProcessResult::Forward(None);
133 };
134
135 self.process_data(data, true).await
136 }
137
138 pub async fn process_server_data(&self, data: Option<Bytes>) -> ProcessResult {
142 if let Some(reason) = self.should_close.lock().await.clone() {
144 return ProcessResult::Close(reason);
145 }
146
147 let Some(data) = data else {
148 return ProcessResult::Forward(None);
150 };
151
152 self.process_data(data, false).await
153 }
154
155 async fn process_data(&self, data: Bytes, client_to_server: bool) -> ProcessResult {
157 let buffer = if client_to_server {
158 &self.client_buffer
159 } else {
160 &self.server_buffer
161 };
162
163 let mut buf = buffer.lock().await;
164 buf.extend_from_slice(&data);
165
166 let mut output = BytesMut::new();
167 let mut frames_processed = 0;
168 let mut frames_dropped = 0;
169
170 loop {
172 match self.codec.decode_frame(&buf) {
174 Ok(Some((frame, consumed))) => {
175 frames_processed += 1;
176
177 let result = if client_to_server {
179 self.inspector.inspect_client_frame(&frame).await
180 } else {
181 self.inspector.inspect_server_frame(&frame).await
182 };
183
184 match result {
185 InspectionResult::Allow => {
186 output.extend_from_slice(&buf[..consumed]);
188 }
189 InspectionResult::Drop => {
190 frames_dropped += 1;
191 trace!(
192 correlation_id = %self.inspector.correlation_id(),
193 opcode = ?frame.opcode,
194 direction = if client_to_server { "c2s" } else { "s2c" },
195 "Dropping WebSocket frame"
196 );
197 }
199 InspectionResult::Close { code, reason } => {
200 debug!(
201 correlation_id = %self.inspector.correlation_id(),
202 code = code,
203 reason = %reason,
204 "Agent requested WebSocket close"
205 );
206
207 *self.should_close.lock().await = Some(CloseReason {
209 code,
210 reason: reason.clone(),
211 });
212
213 let close_frame = WebSocketFrame::close(code, &reason);
215 if let Ok(encoded) =
216 self.codec.encode_frame(&close_frame, !client_to_server)
217 {
218 output.extend_from_slice(&encoded);
219 }
220
221 let _ = buf.split_to(consumed);
223 return ProcessResult::Close(CloseReason { code, reason });
224 }
225 }
226
227 let _ = buf.split_to(consumed);
229 }
230 Ok(None) => {
231 break;
233 }
234 Err(e) => {
235 warn!(
236 correlation_id = %self.inspector.correlation_id(),
237 error = %e,
238 "WebSocket frame decode error"
239 );
240 output.extend_from_slice(&buf);
243 buf.clear();
244 break;
245 }
246 }
247 }
248
249 if frames_processed > 0 {
250 trace!(
251 correlation_id = %self.inspector.correlation_id(),
252 frames_processed = frames_processed,
253 frames_dropped = frames_dropped,
254 output_len = output.len(),
255 buffer_remaining = buf.len(),
256 direction = if client_to_server { "c2s" } else { "s2c" },
257 "Processed WebSocket frames"
258 );
259 }
260
261 if output.is_empty() && frames_dropped > 0 {
262 ProcessResult::Forward(Some(Bytes::new()))
264 } else if output.is_empty() {
265 ProcessResult::Forward(Some(Bytes::new()))
268 } else {
269 ProcessResult::Forward(Some(output.freeze()))
270 }
271 }
272
273 pub async fn should_close(&self) -> Option<CloseReason> {
275 self.should_close.lock().await.clone()
276 }
277
278 pub fn correlation_id(&self) -> &str {
280 self.inspector.correlation_id()
281 }
282}
283
284pub struct WebSocketHandlerBuilder {
286 inspector: Option<Arc<WebSocketInspector>>,
287 max_frame_size: usize,
288}
289
290impl Default for WebSocketHandlerBuilder {
291 fn default() -> Self {
292 Self {
293 inspector: None,
294 max_frame_size: 1024 * 1024, }
296 }
297}
298
299impl WebSocketHandlerBuilder {
300 pub fn new() -> Self {
302 Self::default()
303 }
304
305 pub fn inspector(mut self, inspector: Arc<WebSocketInspector>) -> Self {
307 self.inspector = Some(inspector);
308 self
309 }
310
311 pub fn max_frame_size(mut self, size: usize) -> Self {
313 self.max_frame_size = size;
314 self
315 }
316
317 pub fn build(self) -> Option<WebSocketHandler> {
319 Some(WebSocketHandler::new(self.inspector?, self.max_frame_size))
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::websocket::codec::Opcode;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Inspector test double: returns a fixed decision per direction and counts
    /// how many frames it was asked to inspect in each direction.
    struct MockInspector {
        /// Decision returned for every client-to-server frame.
        client_decision: InspectionResult,
        /// Decision returned for every server-to-client frame.
        server_decision: InspectionResult,
        /// Client frames inspected so far.
        client_frame_count: AtomicUsize,
        /// Server frames inspected so far.
        server_frame_count: AtomicUsize,
    }

    impl MockInspector {
        fn new(client_decision: InspectionResult, server_decision: InspectionResult) -> Self {
            Self {
                client_decision,
                server_decision,
                client_frame_count: AtomicUsize::new(0),
                server_frame_count: AtomicUsize::new(0),
            }
        }

        fn allowing() -> Self {
            Self::new(InspectionResult::Allow, InspectionResult::Allow)
        }

        fn dropping_client() -> Self {
            Self::new(InspectionResult::Drop, InspectionResult::Allow)
        }

        fn dropping_server() -> Self {
            Self::new(InspectionResult::Allow, InspectionResult::Drop)
        }

        fn closing_client(code: u16, reason: &str) -> Self {
            Self::new(
                InspectionResult::Close {
                    code,
                    reason: reason.to_string(),
                },
                InspectionResult::Allow,
            )
        }

        fn client_frames_inspected(&self) -> usize {
            self.client_frame_count.load(Ordering::SeqCst)
        }

        fn server_frames_inspected(&self) -> usize {
            self.server_frame_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl FrameInspector for MockInspector {
        async fn inspect_client_frame(&self, _frame: &WebSocketFrame) -> InspectionResult {
            self.client_frame_count.fetch_add(1, Ordering::SeqCst);
            self.client_decision.clone()
        }

        async fn inspect_server_frame(&self, _frame: &WebSocketFrame) -> InspectionResult {
            self.server_frame_count.fetch_add(1, Ordering::SeqCst);
            self.server_decision.clone()
        }

        fn correlation_id(&self) -> &str {
            "test-correlation-id"
        }
    }

    /// Encodes a single text frame with the real codec.
    fn make_text_frame(text: &str, masked: bool) -> Bytes {
        let codec = WebSocketCodec::new(1024 * 1024);
        let frame = WebSocketFrame::new(Opcode::Text, text.as_bytes().to_vec());
        Bytes::from(codec.encode_frame(&frame, masked).unwrap())
    }

    /// Joins already-encoded frames into one chunk, as if they arrived in a single read.
    fn coalesce(frames: &[Bytes]) -> Bytes {
        let mut chunk = BytesMut::new();
        for frame in frames {
            chunk.extend_from_slice(frame);
        }
        chunk.freeze()
    }

    /// Unwraps `ProcessResult::Forward(Some(_))`, panicking on anything else.
    fn expect_forwarded(result: ProcessResult) -> Bytes {
        match result {
            ProcessResult::Forward(Some(data)) => data,
            other => panic!("Expected Forward(Some(_)), got {other:?}"),
        }
    }

    #[test]
    fn test_close_reason() {
        let reason = CloseReason {
            code: 1000,
            reason: "Normal closure".to_string(),
        };
        assert_eq!(reason.code, 1000);
        assert_eq!(reason.reason, "Normal closure");
    }

    #[test]
    fn test_builder_defaults() {
        let builder = WebSocketHandlerBuilder::new();
        assert_eq!(builder.max_frame_size, 1024 * 1024);

        let builder = WebSocketHandlerBuilder::new().max_frame_size(4096);
        assert_eq!(builder.max_frame_size, 4096);
    }

    #[test]
    fn test_builder_requires_inspector() {
        assert!(WebSocketHandlerBuilder::new().build().is_none());
        assert!(WebSocketHandlerBuilder::new()
            .max_frame_size(4096)
            .build()
            .is_none());
    }

    #[test]
    fn test_correlation_id_passthrough() {
        let handler = WebSocketHandler::with_inspector(Arc::new(MockInspector::allowing()), 1024);
        assert_eq!(handler.correlation_id(), "test-correlation-id");
    }

    #[tokio::test]
    async fn test_frame_allow() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let frame_data = make_text_frame("Hello", false);
        let forwarded =
            expect_forwarded(handler.process_client_data(Some(frame_data.clone())).await);
        assert_eq!(forwarded, frame_data);

        assert_eq!(inspector.client_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_frame_drop_client() {
        let inspector = Arc::new(MockInspector::dropping_client());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let frame_data = make_text_frame("Hello", false);
        let forwarded = expect_forwarded(handler.process_client_data(Some(frame_data)).await);
        assert!(forwarded.is_empty(), "Dropped frame should produce empty output");

        assert_eq!(inspector.client_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_frame_drop_server() {
        let inspector = Arc::new(MockInspector::dropping_server());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let frame_data = make_text_frame("Server message", false);
        let forwarded = expect_forwarded(handler.process_server_data(Some(frame_data)).await);
        assert!(forwarded.is_empty(), "Dropped frame should produce empty output");

        assert_eq!(inspector.server_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_frame_close() {
        let inspector = Arc::new(MockInspector::closing_client(1008, "Policy violation"));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let frame_data = make_text_frame("Malicious content", false);
        let result = handler.process_client_data(Some(frame_data)).await;

        match result {
            ProcessResult::Close(reason) => {
                assert_eq!(reason.code, 1008);
                assert_eq!(reason.reason, "Policy violation");
            }
            _ => panic!("Expected Close result"),
        }

        assert_eq!(inspector.client_frames_inspected(), 1);

        let result = handler
            .process_client_data(Some(make_text_frame("More data", false)))
            .await;
        match result {
            ProcessResult::Close(_) => {}
            _ => panic!("Expected Close result on subsequent call"),
        }
        assert_eq!(
            inspector.client_frames_inspected(),
            1,
            "No frames should be inspected once a close is pending"
        );
    }

    #[tokio::test]
    async fn test_close_applies_to_both_directions() {
        let inspector = Arc::new(MockInspector::closing_client(1008, "Policy violation"));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let result = handler
            .process_client_data(Some(make_text_frame("Malicious content", false)))
            .await;
        assert!(matches!(result, ProcessResult::Close(_)));

        let pending = handler
            .should_close()
            .await
            .expect("close reason should be recorded");
        assert_eq!(pending.code, 1008);
        assert_eq!(pending.reason, "Policy violation");

        let result = handler
            .process_server_data(Some(make_text_frame("Server reply", false)))
            .await;
        assert!(matches!(result, ProcessResult::Close(_)));
        assert_eq!(
            inspector.server_frames_inspected(),
            0,
            "Server frames should not be inspected after a close"
        );

        let result = handler.process_client_data(None).await;
        assert!(
            matches!(result, ProcessResult::Close(_)),
            "A pending close should take precedence over end of stream"
        );
    }

    #[tokio::test]
    async fn test_sequential_frames_allowed() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let frame1 = make_text_frame("Frame 1", false);
        let forwarded = expect_forwarded(handler.process_client_data(Some(frame1.clone())).await);
        assert_eq!(forwarded, frame1);

        let frame2 = make_text_frame("Frame 2", false);
        let forwarded = expect_forwarded(handler.process_client_data(Some(frame2.clone())).await);
        assert_eq!(forwarded, frame2);

        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_coalesced_frames_allowed() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let chunk = coalesce(&[
            make_text_frame("First", false),
            make_text_frame("Second", false),
        ]);
        let forwarded = expect_forwarded(handler.process_client_data(Some(chunk.clone())).await);
        assert_eq!(forwarded, chunk, "Both frames should be forwarded verbatim");

        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_coalesced_frames_dropped() {
        let inspector = Arc::new(MockInspector::dropping_client());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let chunk = coalesce(&[
            make_text_frame("First", false),
            make_text_frame("Second", false),
        ]);
        let forwarded = expect_forwarded(handler.process_client_data(Some(chunk)).await);
        assert!(forwarded.is_empty(), "Both dropped frames should be removed");

        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_end_of_stream() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector, 1024 * 1024);

        let result = handler.process_client_data(None).await;
        match result {
            ProcessResult::Forward(None) => {}
            _ => panic!("Expected Forward(None) for end of stream"),
        }
    }

    #[tokio::test]
    async fn test_partial_frame_buffering() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let full_frame = make_text_frame("Hello World", false);
        let (part1, part2) = full_frame.split_at(full_frame.len() / 2);

        let forwarded = expect_forwarded(
            handler
                .process_client_data(Some(Bytes::from(part1.to_vec())))
                .await,
        );
        assert!(forwarded.is_empty(), "Partial frame should not produce output");
        assert_eq!(
            inspector.client_frames_inspected(),
            0,
            "Partial frame should not be inspected"
        );

        let forwarded = expect_forwarded(
            handler
                .process_client_data(Some(Bytes::from(part2.to_vec())))
                .await,
        );
        assert_eq!(forwarded, full_frame, "Complete frame should be forwarded");
        assert_eq!(
            inspector.client_frames_inspected(),
            1,
            "Complete frame should be inspected"
        );
    }

    #[tokio::test]
    async fn test_complete_frame_with_trailing_partial() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let first = make_text_frame("Complete", false);
        let second = make_text_frame("Second frame", false);
        let (head, tail) = second.split_at(second.len() / 2);

        let chunk = coalesce(&[first.clone(), Bytes::from(head.to_vec())]);
        let forwarded = expect_forwarded(handler.process_client_data(Some(chunk)).await);
        assert_eq!(forwarded, first, "Only the complete frame should be forwarded");
        assert_eq!(inspector.client_frames_inspected(), 1);

        let forwarded = expect_forwarded(
            handler
                .process_client_data(Some(Bytes::from(tail.to_vec())))
                .await,
        );
        assert_eq!(forwarded, second, "Buffered prefix should be completed and forwarded");
        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_bidirectional_independence() {
        let inspector = Arc::new(MockInspector::new(
            InspectionResult::Drop,
            InspectionResult::Allow,
        ));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);

        let client_frame = make_text_frame("Client", false);
        let forwarded = expect_forwarded(handler.process_client_data(Some(client_frame)).await);
        assert!(forwarded.is_empty());

        let server_frame = make_text_frame("Server", false);
        let original_len = server_frame.len();
        let forwarded = expect_forwarded(handler.process_server_data(Some(server_frame)).await);
        assert_eq!(forwarded.len(), original_len);

        assert_eq!(inspector.client_frames_inspected(), 1);
        assert_eq!(inspector.server_frames_inspected(), 1);
    }
}