vellaveto_http_proxy/proxy/websocket/mod.rs
1// Copyright 2026 Paolo Vella
2// SPDX-License-Identifier: BUSL-1.1
3//
4// Use of this software is governed by the Business Source License
5// included in the LICENSE-BSL-1.1 file at the root of this repository.
6//
7// Change Date: Three years from the date of publication of this version.
8// Change License: MPL-2.0
9
10//! WebSocket transport for MCP JSON-RPC messages (SEP-1288).
11//!
12//! Implements a WebSocket reverse proxy that sits between MCP clients and
13//! an upstream MCP server. WebSocket messages (text frames) are parsed as
14//! JSON-RPC, classified via `vellaveto_mcp::extractor`, evaluated against
15//! loaded policies, and forwarded to the upstream server.
16//!
17//! Security invariants:
18//! - **Fail-closed**: Unparseable messages close the connection (code 1008).
19//! - **No binary frames**: Only text frames are accepted (code 1003 for binary).
20//! - **Session binding**: Each WS connection is bound to exactly one session.
21//! - **Canonicalization**: Re-serialized JSON forwarded (TOCTOU defense).
22
23use axum::{
24 extract::{
25 ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade},
26 ConnectInfo, State,
27 },
28 http::HeaderMap,
29 response::Response,
30};
31use futures_util::{SinkExt, StreamExt};
32use serde_json::{json, Value};
33use std::net::SocketAddr;
34use std::sync::atomic::{AtomicU64, Ordering};
35use std::sync::Arc;
36use std::time::Duration;
37use tokio::sync::Mutex;
38use vellaveto_mcp::extractor::{self, MessageType};
39use vellaveto_mcp::inspection::{
40 inspect_for_injection, scan_notification_for_secrets, scan_parameters_for_secrets,
41 scan_response_for_secrets, scan_text_for_secrets, scan_tool_descriptions,
42 scan_tool_descriptions_with_scanner,
43};
44use vellaveto_mcp::output_validation::ValidationResult;
45use vellaveto_types::{Action, EvaluationContext, Verdict};
46
47use super::auth::{validate_agent_identity, validate_api_key, validate_oauth};
48use super::call_chain::{
49 check_privilege_escalation, sync_session_call_chain_from_headers, take_tracked_tool_call,
50 track_pending_tool_call,
51};
52use super::origin::validate_origin;
53use super::ProxyState;
54use crate::proxy_metrics::record_dlp_finding;
55
/// Configuration for the WebSocket transport.
///
/// `handle_ws_upgrade` uses `WebSocketConfig::default()` when the proxy state
/// carries no explicit WebSocket configuration. Rate-limit counters and the
/// idle timer are created per connection in `handle_ws_connection`.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Maximum message size in bytes (default: 1 MiB = 1_048_576 bytes).
    /// Passed to axum's `WebSocketUpgrade::max_message_size` at upgrade time.
    pub max_message_size: usize,
    /// Idle timeout in seconds — close connection after no message activity (default: 300s).
    /// SECURITY (FIND-R182-001): True idle timeout that resets on every message.
    pub idle_timeout_secs: u64,
    /// Maximum messages per second per connection for client-to-upstream (default: 100).
    pub message_rate_limit: u32,
    /// Maximum messages per second per connection for upstream-to-client (default: 500).
    /// SECURITY (FIND-R46-WS-003): Rate limits on the upstream→client direction prevent
    /// a malicious upstream from flooding the client with responses.
    pub upstream_rate_limit: u32,
}
71
72impl Default for WebSocketConfig {
73 fn default() -> Self {
74 Self {
75 max_message_size: 1_048_576,
76 idle_timeout_secs: 300,
77 message_rate_limit: 100,
78 upstream_rate_limit: 500,
79 }
80 }
81}
82
/// WebSocket close codes per RFC 6455 (§7.4.1, "Defined Status Codes").
///
/// 1008 — Policy Violation. In this module it is sent for security/policy
/// rejections and also when no upstream can be resolved or connected.
const CLOSE_POLICY_VIOLATION: u16 = 1008;
/// 1003 — Unsupported Data. Per the module docs, used to reject binary frames
/// (only text frames carry JSON-RPC).
const CLOSE_UNSUPPORTED_DATA: u16 = 1003;
/// 1009 — Message Too Big. Close code for oversized messages. Used by axum's `max_message_size`
/// automatically; kept here for documentation and test assertions.
#[cfg(test)]
const CLOSE_MESSAGE_TOO_BIG: u16 = 1009;
/// 1000 — Normal Closure. Sent to the client when the session ends (a relay
/// direction finished or the idle timeout fired).
const CLOSE_NORMAL: u16 = 1000;
91
/// Global WebSocket metrics counters.
///
/// Process-wide total of accepted WebSocket connections. Incremented with
/// saturating arithmetic by `record_ws_connection`; read by the
/// `#[cfg(test)]` accessor `ws_connections_count`.
static WS_CONNECTIONS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Process-wide total of relayed WebSocket messages (all directions).
/// Incremented with saturating arithmetic by `record_ws_message`; read by the
/// `#[cfg(test)]` accessor `ws_messages_count`.
static WS_MESSAGES_TOTAL: AtomicU64 = AtomicU64::new(0);
95
96/// Record WebSocket connection metric.
97fn record_ws_connection() {
98 // SECURITY (FIND-R182-003): Use saturating arithmetic to prevent overflow.
99 // SECURITY (CA-005): SeqCst for security-adjacent metrics counters.
100 let _ = WS_CONNECTIONS_TOTAL.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
101 Some(v.saturating_add(1))
102 });
103 metrics::counter!("vellaveto_ws_connections_total").increment(1);
104}
105
106/// Record WebSocket message metric.
107fn record_ws_message(direction: &str) {
108 // SECURITY (FIND-R182-003): Use saturating arithmetic to prevent overflow.
109 // SECURITY (CA-005): SeqCst for security-adjacent metrics counters.
110 let _ = WS_MESSAGES_TOTAL.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
111 Some(v.saturating_add(1))
112 });
113 metrics::counter!(
114 "vellaveto_ws_messages_total",
115 "direction" => direction.to_string()
116 )
117 .increment(1);
118}
119
/// Test-only accessor: total WebSocket connections recorded so far by
/// `record_ws_connection`.
#[cfg(test)]
pub(crate) fn ws_connections_count() -> u64 {
    let total = &WS_CONNECTIONS_TOTAL;
    total.load(Ordering::SeqCst)
}

/// Test-only accessor: total WebSocket messages (all directions) recorded so
/// far by `record_ws_message`.
#[cfg(test)]
pub(crate) fn ws_messages_count() -> u64 {
    let total = &WS_MESSAGES_TOTAL;
    total.load(Ordering::SeqCst)
}
131
132use vellaveto_types::is_unicode_format_char as is_unicode_format_char_ws;
133
/// Query parameters for the WebSocket upgrade endpoint.
///
/// `deny_unknown_fields` makes deserialization fail if the query string
/// carries any parameter other than `session_id`.
#[derive(Debug, serde::Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct WsQueryParams {
    /// Optional session ID for session resumption.
    ///
    /// `handle_ws_upgrade` discards values that are empty, longer than 128
    /// bytes, or contain control / Unicode format characters. It then calls
    /// `get_or_create` as if no session ID had been supplied.
    #[serde(default)]
    pub session_id: Option<String>,
}
142
143/// Handle WebSocket upgrade request at `/mcp/ws`.
144///
145/// Authenticates the request, validates origin, creates/resumes a session,
146/// and upgrades the HTTP connection to a WebSocket.
147pub async fn handle_ws_upgrade(
148 State(state): State<ProxyState>,
149 ConnectInfo(addr): ConnectInfo<SocketAddr>,
150 headers: HeaderMap,
151 query: axum::extract::Query<WsQueryParams>,
152 ws: WebSocketUpgrade,
153) -> Response {
154 // 1. Validate origin (CSRF / DNS rebinding defense)
155 if let Err(resp) = validate_origin(&headers, &state.bind_addr, &state.allowed_origins) {
156 return resp;
157 }
158
159 // 2. Authenticate before upgrade (API key or OAuth)
160 if let Err(resp) = validate_api_key(&state, &headers) {
161 return resp;
162 }
163
164 // SECURITY (FIND-R53-WS-001): Validate OAuth token at upgrade time.
165 // Parity with HTTP POST (handlers.rs:154) and GET (handlers.rs:2864).
166 // Without this, WS connections bypass token expiry checks.
167 let oauth_claims = match validate_oauth(
168 &state,
169 &headers,
170 "GET",
171 &super::auth::build_effective_request_uri(
172 &headers,
173 state.bind_addr,
174 &axum::http::Uri::from_static("/mcp/ws"),
175 false,
176 ),
177 query.session_id.as_deref(),
178 )
179 .await
180 {
181 Ok(claims) => claims,
182 Err(response) => return response,
183 };
184
185 // SECURITY (FIND-R53-WS-002, FIND-R159-003): Validate agent identity at upgrade time.
186 // Parity with HTTP POST (handlers.rs:160) and GET (handlers.rs:2871).
187 // FIND-R159-003: Identity MUST be stored in session (was previously discarded with `_`).
188 let agent_identity = match validate_agent_identity(&state, &headers).await {
189 Ok(identity) => identity,
190 Err(response) => return response,
191 };
192
193 // SECURITY (FIND-R55-WS-004, FIND-R81-001): Validate session_id length and
194 // control/format characters from query parameter. Parity with HTTP POST/GET
195 // handlers (handlers.rs:154, handlers.rs:2928) which reject control chars.
196 // SECURITY (FIND-R81-WS-001): Also reject Unicode format characters (zero-width,
197 // bidi overrides, BOM) that can bypass string-based security checks.
198 let ws_session_id = query.session_id.as_deref().filter(|id| {
199 !id.is_empty()
200 && id.len() <= 128
201 && !id
202 .chars()
203 .any(|c| c.is_control() || is_unicode_format_char_ws(c))
204 });
205
206 // 3. Get or create session
207 let session_id = state.sessions.get_or_create(ws_session_id);
208
209 // SECURITY (FIND-R53-WS-003): Session ownership binding — prevent session fixation.
210 // Parity with HTTP GET (handlers.rs:2914-2953).
211 if let Some(ref claims) = oauth_claims {
212 if let Some(mut session) = state.sessions.get_mut(&session_id) {
213 match &session.oauth_subject {
214 Some(owner) if owner != &claims.sub => {
215 tracing::warn!(
216 session_id = %session_id,
217 owner = %owner,
218 requester = %claims.sub,
219 "WS upgrade rejected: session owned by different principal"
220 );
221 return axum::response::IntoResponse::into_response((
222 axum::http::StatusCode::FORBIDDEN,
223 axum::Json(json!({
224 "error": "Session belongs to another principal"
225 })),
226 ));
227 }
228 None => {
229 // Bind session to this OAuth subject
230 session.oauth_subject = Some(claims.sub.clone());
231 // SECURITY (FIND-R73-SRV-006): Store token expiry, matching
232 // HTTP POST handler pattern to enforce token lifetime.
233 if claims.exp > 0 {
234 session.token_expires_at = Some(claims.exp);
235 }
236 }
237 _ => {
238 // Already owned by this principal — use earliest expiry
239 // SECURITY (FIND-R73-SRV-006): Parity with HTTP POST handler
240 // (R23-PROXY-6) — prevent long-lived tokens from extending
241 // sessions originally bound to short-lived tokens.
242 if claims.exp > 0 {
243 session.token_expires_at = Some(
244 session
245 .token_expires_at
246 .map_or(claims.exp, |existing| existing.min(claims.exp)),
247 );
248 }
249 }
250 }
251 }
252 }
253
254 // SECURITY (FIND-R159-003): Store agent identity in session — parity with HTTP
255 // POST (handlers.rs:295-298) and GET (handlers.rs:3641-3643). Without this,
256 // ABAC policies referencing agent_identity would evaluate against None for
257 // WebSocket connections, creating a policy bypass.
258 if let Some(ref identity) = agent_identity {
259 if let Some(mut session) = state.sessions.get_mut(&session_id) {
260 session.agent_identity = Some(identity.clone());
261 }
262 }
263
264 // SECURITY (FIND-R46-006): Validate and extract call chain from upgrade headers.
265 // The call chain is synced once during upgrade and reused for all messages
266 // in this WebSocket connection.
267 if let Err(reason) = super::call_chain::validate_call_chain_header(&headers, &state.limits) {
268 tracing::warn!(
269 session_id = %session_id,
270 "WS upgrade rejected: invalid call chain header: {}",
271 reason
272 );
273 return axum::response::IntoResponse::into_response((
274 axum::http::StatusCode::BAD_REQUEST,
275 axum::Json(json!({
276 "error": "Invalid request"
277 })),
278 ));
279 }
280 sync_session_call_chain_from_headers(
281 &state.sessions,
282 &session_id,
283 &headers,
284 state.call_chain_hmac_key.as_ref(),
285 &state.limits,
286 );
287
288 let ws_config = state.ws_config.clone().unwrap_or_default();
289
290 // Phase 28: Extract W3C Trace Context from the HTTP upgrade request headers.
291 // The trace_id is used for correlating all audit entries during this WS session.
292 let trace_ctx = super::trace_propagation::extract_trace_context(&headers);
293 let ws_trace_id = trace_ctx
294 .trace_id
295 .clone()
296 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string().replace('-', ""));
297
298 tracing::info!(
299 session_id = %session_id,
300 trace_id = %ws_trace_id,
301 peer = %addr,
302 "WebSocket upgrade accepted"
303 );
304
305 // 4. Configure and upgrade
306 ws.max_message_size(ws_config.max_message_size)
307 .on_upgrade(move |socket| {
308 handle_ws_connection(socket, state, session_id, ws_config, addr, ws_trace_id)
309 })
310}
311
312/// Handle an established WebSocket connection.
313///
314/// Establishes an upstream WS connection, then relays messages bidirectionally
315/// with policy enforcement on client→upstream messages and DLP/injection
316/// scanning on upstream→client messages.
317async fn handle_ws_connection(
318 client_ws: WebSocket,
319 state: ProxyState,
320 session_id: String,
321 ws_config: WebSocketConfig,
322 peer_addr: SocketAddr,
323 trace_id: String,
324) {
325 record_ws_connection();
326 let start = std::time::Instant::now();
327 tracing::debug!(
328 session_id = %session_id,
329 trace_id = %trace_id,
330 "WebSocket connection established with trace context"
331 );
332
333 // Connect to upstream — use gateway default backend if configured
334 let upstream_url = if let Some(ref gw) = state.gateway {
335 match gw.route("") {
336 Some(d) => convert_to_ws_url(&d.upstream_url),
337 None => {
338 tracing::error!(session_id = %session_id, "No healthy upstream for WebSocket");
339 let (mut client_sink, _) = client_ws.split();
340 let _ = client_sink
341 .send(Message::Close(Some(CloseFrame {
342 code: CLOSE_POLICY_VIOLATION,
343 reason: "No healthy upstream available".into(),
344 })))
345 .await;
346 return;
347 }
348 }
349 } else {
350 convert_to_ws_url(&state.upstream_url)
351 };
352 let upstream_ws = match connect_upstream_ws(&upstream_url).await {
353 Ok(ws) => ws,
354 Err(e) => {
355 tracing::error!(
356 session_id = %session_id,
357 "Failed to connect to upstream WebSocket: {}",
358 e
359 );
360 // Send close frame to client
361 let (mut client_sink, _) = client_ws.split();
362 let _ = client_sink
363 .send(Message::Close(Some(CloseFrame {
364 code: CLOSE_POLICY_VIOLATION,
365 reason: "Upstream connection failed".into(),
366 })))
367 .await;
368 return;
369 }
370 };
371
372 let (client_sink, client_stream) = client_ws.split();
373 let (upstream_sink, upstream_stream) = upstream_ws.split();
374
375 // Wrap sinks in Arc<Mutex> for shared access
376 let client_sink = Arc::new(Mutex::new(client_sink));
377 let upstream_sink = Arc::new(Mutex::new(upstream_sink));
378
379 // Rate limiter state: track messages in the current second window
380 let rate_counter = Arc::new(AtomicU64::new(0));
381 let rate_window_start = Arc::new(std::sync::Mutex::new(std::time::Instant::now()));
382
383 // SECURITY (FIND-R46-WS-003): Separate rate limiter for upstream→client direction
384 let upstream_rate_counter = Arc::new(AtomicU64::new(0));
385 let upstream_rate_window_start = Arc::new(std::sync::Mutex::new(std::time::Instant::now()));
386
387 let idle_timeout = Duration::from_secs(ws_config.idle_timeout_secs);
388
389 // SECURITY (FIND-R182-001): Shared last-activity tracker so idle timeout resets
390 // on every message (true idle detection, not max-lifetime).
391 let last_activity = Arc::new(AtomicU64::new(0));
392 let connection_epoch = std::time::Instant::now();
393
394 // Client → Vellaveto → Upstream relay
395 let client_to_upstream = {
396 let state = state.clone();
397 let session_id = session_id.clone();
398 let client_sink = client_sink.clone();
399 let upstream_sink = upstream_sink.clone();
400 let rate_counter = rate_counter.clone();
401 let rate_window_start = rate_window_start.clone();
402 let ws_config = ws_config.clone();
403 let last_activity = last_activity.clone();
404
405 relay_client_to_upstream(
406 client_stream,
407 client_sink,
408 upstream_sink,
409 state,
410 session_id,
411 ws_config,
412 rate_counter,
413 rate_window_start,
414 last_activity,
415 connection_epoch,
416 )
417 };
418
419 // Upstream → Vellaveto → Client relay
420 let upstream_to_client = {
421 let state = state.clone();
422 let session_id = session_id.clone();
423 let client_sink = client_sink.clone();
424 let ws_config = ws_config.clone();
425 let last_activity = last_activity.clone();
426
427 relay_upstream_to_client(
428 upstream_stream,
429 client_sink,
430 state,
431 session_id,
432 ws_config,
433 upstream_rate_counter,
434 upstream_rate_window_start,
435 last_activity,
436 connection_epoch,
437 )
438 };
439
440 // SECURITY (FIND-R182-001): True idle timeout — check periodically and
441 // close only if no message activity since last check.
442 let idle_check = {
443 let session_id = session_id.clone();
444 let last_activity = last_activity.clone();
445 async move {
446 // Check every 10% of idle timeout (min 1s) for responsive detection.
447 let check_interval = Duration::from_secs((ws_config.idle_timeout_secs / 10).max(1));
448 let mut interval = tokio::time::interval(check_interval);
449 interval.tick().await; // first tick is immediate, skip it
450 loop {
451 interval.tick().await;
452 let last_ms = last_activity.load(Ordering::Relaxed);
453 // SECURITY (FIND-R190-002): Use saturating_sub to prevent underflow
454 // if Relaxed ordering causes a stale last_ms value.
455 let elapsed_since_activity =
456 (connection_epoch.elapsed().as_millis() as u64).saturating_sub(last_ms);
457 if elapsed_since_activity >= idle_timeout.as_millis() as u64 {
458 tracing::info!(
459 session_id = %session_id,
460 idle_secs = elapsed_since_activity / 1000,
461 "WebSocket idle timeout ({}s), closing",
462 ws_config.idle_timeout_secs
463 );
464 break;
465 }
466 }
467 }
468 };
469
470 // Run both relay loops with idle timeout
471 tokio::select! {
472 _ = client_to_upstream => {
473 tracing::debug!(session_id = %session_id, "Client stream ended");
474 }
475 _ = upstream_to_client => {
476 tracing::debug!(session_id = %session_id, "Upstream stream ended");
477 }
478 _ = idle_check => {}
479 }
480
481 // Clean shutdown: close both sides
482 {
483 let mut sink = client_sink.lock().await;
484 let _ = sink
485 .send(Message::Close(Some(CloseFrame {
486 code: CLOSE_NORMAL,
487 reason: "Session ended".into(),
488 })))
489 .await;
490 }
491 {
492 let mut sink = upstream_sink.lock().await;
493 let _ = sink.close().await;
494 }
495
496 let duration = start.elapsed();
497 metrics::histogram!("vellaveto_ws_connection_duration_seconds").record(duration.as_secs_f64());
498
499 tracing::info!(
500 session_id = %session_id,
501 peer = %peer_addr,
502 duration_secs = duration.as_secs(),
503 "WebSocket connection closed"
504 );
505}
506
507/// Relay messages from client to upstream with policy enforcement.
508#[allow(clippy::too_many_arguments)]
509#[allow(deprecated)] // evaluate_action_with_context: migration tracked in FIND-CREATIVE-005
510async fn relay_client_to_upstream(
511 mut client_stream: futures_util::stream::SplitStream<WebSocket>,
512 client_sink: Arc<Mutex<futures_util::stream::SplitSink<WebSocket, Message>>>,
513 upstream_sink: Arc<
514 Mutex<
515 futures_util::stream::SplitSink<
516 tokio_tungstenite::WebSocketStream<
517 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
518 >,
519 tokio_tungstenite::tungstenite::Message,
520 >,
521 >,
522 >,
523 state: ProxyState,
524 session_id: String,
525 ws_config: WebSocketConfig,
526 rate_counter: Arc<AtomicU64>,
527 rate_window_start: Arc<std::sync::Mutex<std::time::Instant>>,
528 last_activity: Arc<AtomicU64>,
529 connection_epoch: std::time::Instant,
530) {
531 while let Some(msg_result) = client_stream.next().await {
532 let msg = match msg_result {
533 Ok(m) => m,
534 Err(e) => {
535 tracing::debug!(session_id = %session_id, "Client WS error: {}", e);
536 break;
537 }
538 };
539
540 // SECURITY (FIND-R182-001): Update last-activity for true idle detection.
541 last_activity.store(
542 connection_epoch.elapsed().as_millis() as u64,
543 Ordering::Relaxed,
544 );
545
546 record_ws_message("client_to_upstream");
547
548 // SECURITY (FIND-R52-WS-003): Per-message OAuth token expiry check.
549 // After WebSocket upgrade, the HTTP auth middleware no longer runs.
550 // A token that expires mid-connection must be rejected to prevent
551 // indefinite access via a long-lived WebSocket.
552 {
553 let token_expired = state
554 .sessions
555 .get_mut(&session_id)
556 .and_then(|s| {
557 s.token_expires_at.map(|exp| {
558 let now = std::time::SystemTime::now()
559 .duration_since(std::time::UNIX_EPOCH)
560 .unwrap_or_default()
561 .as_secs();
562 now >= exp
563 })
564 })
565 .unwrap_or(false);
566 if token_expired {
567 tracing::warn!(
568 session_id = %session_id,
569 "SECURITY: OAuth token expired during WebSocket session, closing"
570 );
571 let error = json!({
572 "jsonrpc": "2.0",
573 "error": {
574 "code": -32001,
575 "message": "Session expired"
576 },
577 "id": null
578 });
579 let error_text = serde_json::to_string(&error)
580 .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","error":{"code":-32001,"message":"Session expired"},"id":null}"#.to_string());
581 let mut sink = client_sink.lock().await;
582 let _ = sink.send(Message::Text(error_text.into())).await;
583 let _ = sink
584 .send(Message::Close(Some(CloseFrame {
585 code: CLOSE_POLICY_VIOLATION,
586 reason: "Token expired".into(),
587 })))
588 .await;
589 break;
590 }
591 }
592
593 match msg {
594 Message::Text(text) => {
595 // Rate limiting
596 if !check_rate_limit(
597 &rate_counter,
598 &rate_window_start,
599 ws_config.message_rate_limit,
600 ) {
601 tracing::warn!(
602 session_id = %session_id,
603 "WebSocket rate limit exceeded, closing"
604 );
605 let mut sink = client_sink.lock().await;
606 let _ = sink
607 .send(Message::Close(Some(CloseFrame {
608 code: CLOSE_POLICY_VIOLATION,
609 reason: "Rate limit exceeded".into(),
610 })))
611 .await;
612 break;
613 }
614
615 // SECURITY (FIND-R46-005): Reject JSON with duplicate keys before parsing.
616 // Prevents parser-disagreement attacks (CVE-2017-12635, CVE-2020-16250)
617 // where the proxy evaluates one key value but upstream sees another.
618 if let Some(dup_key) = vellaveto_mcp::framing::find_duplicate_json_key(&text) {
619 tracing::warn!(
620 session_id = %session_id,
621 "SECURITY: Rejected WS message with duplicate key: \"{}\"",
622 dup_key
623 );
624 let mut sink = client_sink.lock().await;
625 let _ = sink
626 .send(Message::Close(Some(CloseFrame {
627 code: CLOSE_POLICY_VIOLATION,
628 reason: "Duplicate JSON key detected".into(),
629 })))
630 .await;
631 break;
632 }
633
634 // SECURITY (FIND-R53-WS-004): Reject WS messages with control characters.
635 // Parity with HTTP GET event_id validation (handlers.rs:2899).
636 // Control chars in JSON-RPC messages can be used for log injection
637 // or to bypass string-based security checks.
638 if text.chars().any(|c| {
639 // Allow standard JSON whitespace (\t, \n, \r) but reject other
640 // ASCII control chars and Unicode format chars (FIND-R54-011).
641 (c.is_control() && c != '\n' && c != '\r' && c != '\t')
642 || is_unicode_format_char_ws(c)
643 }) {
644 tracing::warn!(
645 session_id = %session_id,
646 "SECURITY: Rejected WS message with control characters"
647 );
648 let error =
649 make_ws_error_response(None, -32600, "Message contains control characters");
650 let mut sink = client_sink.lock().await;
651 let _ = sink.send(Message::Text(error.into())).await;
652 continue;
653 }
654
655 // Parse JSON
656 let parsed: Value = match serde_json::from_str(&text) {
657 Ok(v) => v,
658 Err(_) => {
659 tracing::warn!(
660 session_id = %session_id,
661 "Unparseable JSON in WebSocket text frame, closing (fail-closed)"
662 );
663 let mut sink = client_sink.lock().await;
664 let _ = sink
665 .send(Message::Close(Some(CloseFrame {
666 code: CLOSE_POLICY_VIOLATION,
667 reason: "Invalid JSON".into(),
668 })))
669 .await;
670 break;
671 }
672 };
673
674 // SECURITY (FIND-R46-WS-001): Injection scanning on client→upstream text frames.
675 // The HTTP proxy scans request bodies for injection; the WebSocket proxy must
676 // do the same to maintain security parity. Fail-closed: if injection is detected
677 // and blocking is enabled, deny the message.
678 if !state.injection_disabled {
679 let scannable = extract_scannable_text_from_request(&parsed);
680 if !scannable.is_empty() {
681 let injection_matches: Vec<String> =
682 if let Some(ref scanner) = state.injection_scanner {
683 scanner
684 .inspect(&scannable)
685 .into_iter()
686 .map(|s| s.to_string())
687 .collect()
688 } else {
689 inspect_for_injection(&scannable)
690 .into_iter()
691 .map(|s| s.to_string())
692 .collect()
693 };
694
695 if !injection_matches.is_empty() {
696 tracing::warn!(
697 "SECURITY: Injection in WS client request! Session: {}, Patterns: {:?}",
698 session_id,
699 injection_matches,
700 );
701
702 let verdict = if state.injection_blocking {
703 Verdict::Deny {
704 reason: format!(
705 "WS request injection blocked: {injection_matches:?}"
706 ),
707 }
708 } else {
709 Verdict::Allow
710 };
711
712 let action = Action::new(
713 "vellaveto",
714 "ws_request_injection",
715 json!({
716 "matched_patterns": injection_matches,
717 "session": session_id,
718 "transport": "websocket",
719 "direction": "client_to_upstream",
720 }),
721 );
722 if let Err(e) = state
723 .audit
724 .log_entry(
725 &action,
726 &verdict,
727 json!({
728 "source": "ws_proxy",
729 "event": "ws_request_injection_detected",
730 }),
731 )
732 .await
733 {
734 tracing::warn!("Failed to audit WS request injection: {}", e);
735 }
736
737 if state.injection_blocking {
738 let id = parsed.get("id");
739 let error = make_ws_error_response(
740 id,
741 -32001,
742 "Request blocked: injection detected",
743 );
744 let mut sink = client_sink.lock().await;
745 let _ = sink.send(Message::Text(error.into())).await;
746 continue;
747 }
748 }
749 }
750 }
751
752 // Classify and evaluate
753 let classified = extractor::classify_message(&parsed);
754 match classified {
755 MessageType::ToolCall {
756 ref id,
757 ref tool_name,
758 ref arguments,
759 } => {
760 // SECURITY (FIND-R46-009): Strict tool name validation (MCP 2025-11-25).
761 // When enabled, reject tool names that don't conform to the spec format.
762 if state.streamable_http.strict_tool_name_validation {
763 if let Err(e) = vellaveto_types::validate_mcp_tool_name(tool_name) {
764 tracing::warn!(
765 session_id = %session_id,
766 "SECURITY: Rejecting invalid WS tool name '{}': {}",
767 tool_name,
768 e
769 );
770 let error =
771 make_ws_error_response(Some(id), -32602, "Invalid tool name");
772 let mut sink = client_sink.lock().await;
773 let _ = sink.send(Message::Text(error.into())).await;
774 continue;
775 }
776 }
777
778 let mut action = extractor::extract_action(tool_name, arguments);
779
780 // SECURITY (FIND-R75-002): DNS resolution for IP-based policy evaluation.
781 // Parity with HTTP handler (handlers.rs:717). Without this, policies
782 // using ip_rules are completely bypassed on the WebSocket transport.
783 if state.engine.has_ip_rules() {
784 super::helpers::resolve_domains(&mut action).await;
785 }
786
787 // SECURITY (FIND-R46-006): Call chain validation and privilege escalation check.
788 // Extract X-Upstream-Agents from the initial WS upgrade headers stored in session.
789 // For WebSocket, we sync the call chain once during upgrade and reuse it.
790 let upstream_chain = {
791 let session_ref = state.sessions.get_mut(&session_id);
792 session_ref
793 .map(|s| s.current_call_chain.clone())
794 .unwrap_or_default()
795 };
796 let current_agent_id = {
797 let session_ref = state.sessions.get_mut(&session_id);
798 session_ref.and_then(|s| s.oauth_subject.clone())
799 };
800
801 // SECURITY (FIND-R46-006): Privilege escalation detection.
802 if !upstream_chain.is_empty() {
803 let priv_check = check_privilege_escalation(
804 &state.engine,
805 &state.policies,
806 &action,
807 &upstream_chain,
808 current_agent_id.as_deref(),
809 );
810 if priv_check.escalation_detected {
811 let verdict = Verdict::Deny {
812 reason: format!(
813 "Privilege escalation: agent '{}' would be denied",
814 priv_check
815 .escalating_from_agent
816 .as_deref()
817 .unwrap_or("unknown")
818 ),
819 };
820 if let Err(e) = state
821 .audit
822 .log_entry(
823 &action,
824 &verdict,
825 json!({
826 "source": "ws_proxy",
827 "session": session_id,
828 "transport": "websocket",
829 "event": "privilege_escalation_blocked",
830 "escalating_from_agent": priv_check.escalating_from_agent,
831 "upstream_deny_reason": priv_check.upstream_deny_reason,
832 }),
833 )
834 .await
835 {
836 tracing::warn!(
837 "Failed to audit WS privilege escalation: {}",
838 e
839 );
840 }
841 let error =
842 make_ws_error_response(Some(id), -32001, "Denied by policy");
843 let mut sink = client_sink.lock().await;
844 let _ = sink.send(Message::Text(error.into())).await;
845 continue;
846 }
847 }
848
849 // SECURITY (FIND-R46-007): Rug-pull detection.
850 // Block calls to tools whose annotations changed since initial tools/list.
851 // SECURITY (R240-PROXY-1): Fall back to global registry on session miss.
852 let is_flagged = state
853 .sessions
854 .get_mut(&session_id)
855 .map(|s| s.flagged_tools.contains(tool_name))
856 .unwrap_or_else(|| state.sessions.is_tool_globally_flagged(tool_name));
857 if is_flagged {
858 let verdict = Verdict::Deny {
859 reason: format!(
860 "Tool '{tool_name}' blocked: annotations changed (rug-pull detected)"
861 ),
862 };
863 if let Err(e) = state
864 .audit
865 .log_entry(
866 &action,
867 &verdict,
868 json!({
869 "source": "ws_proxy",
870 "session": session_id,
871 "transport": "websocket",
872 "event": "rug_pull_tool_blocked",
873 "tool": tool_name,
874 }),
875 )
876 .await
877 {
878 tracing::warn!("Failed to audit WS rug-pull block: {}", e);
879 }
880 let error =
881 make_ws_error_response(Some(id), -32001, "Denied by policy");
882 let mut sink = client_sink.lock().await;
883 let _ = sink.send(Message::Text(error.into())).await;
884 continue;
885 }
886
887 // SECURITY (FIND-R52-WS-001): DLP scan parameters for secret exfiltration.
888 // Matches HTTP handler's DLP check to maintain security parity.
889 {
890 let dlp_findings = scan_parameters_for_secrets(arguments);
891 // SECURITY (FIND-R55-WS-001): DLP on request params always blocks,
892 // matching HTTP handler. Previously gated on injection_blocking flag.
893 if !dlp_findings.is_empty() {
894 for finding in &dlp_findings {
895 record_dlp_finding(&finding.pattern_name);
896 }
897 let patterns: Vec<String> = dlp_findings
898 .iter()
899 .map(|f| format!("{} at {}", f.pattern_name, f.location))
900 .collect();
901 let audit_reason = format!(
902 "DLP: secrets detected in tool parameters: {patterns:?}"
903 );
904 tracing::warn!(
905 "SECURITY: DLP blocking WS tool '{}' in session {}: {}",
906 tool_name,
907 session_id,
908 audit_reason
909 );
910 let dlp_action = extractor::extract_action(tool_name, arguments);
911 if let Err(e) = state
912 .audit
913 .log_entry(
914 &dlp_action,
915 &Verdict::Deny {
916 reason: audit_reason,
917 },
918 json!({
919 "source": "ws_proxy",
920 "session": session_id,
921 "transport": "websocket",
922 "event": "dlp_secret_blocked",
923 "tool": tool_name,
924 "findings": patterns,
925 }),
926 )
927 .await
928 {
929 tracing::warn!("Failed to audit WS DLP finding: {}", e);
930 }
931 let error = make_ws_error_response(
932 Some(id),
933 -32001,
934 "Request blocked: security policy violation",
935 );
936 let mut sink = client_sink.lock().await;
937 let _ = sink.send(Message::Text(error.into())).await;
938 continue;
939 }
940 }
941
942 // SECURITY (FIND-R52-WS-002): Memory poisoning detection.
943 // Check if tool call parameters contain replayed response data,
944 // matching the HTTP handler's memory poisoning check.
945 {
946 let poisoning_detected = state
947 .sessions
948 .get_mut(&session_id)
949 .and_then(|session| {
950 let matches =
951 session.memory_tracker.check_parameters(arguments);
952 if !matches.is_empty() {
953 for m in &matches {
954 tracing::warn!(
955 "SECURITY: Memory poisoning detected in WS tool '{}' (session {}): \
956 param '{}' contains replayed data (fingerprint: {})",
957 tool_name,
958 session_id,
959 m.param_location,
960 m.fingerprint
961 );
962 }
963 Some(matches.len())
964 } else {
965 None
966 }
967 });
968 if let Some(match_count) = poisoning_detected {
969 let poison_action = extractor::extract_action(tool_name, arguments);
970 let deny_reason = format!(
971 "Memory poisoning detected: {match_count} replayed data fragment(s) in tool '{tool_name}'"
972 );
973 if let Err(e) = state
974 .audit
975 .log_entry(
976 &poison_action,
977 &Verdict::Deny {
978 reason: deny_reason,
979 },
980 json!({
981 "source": "ws_proxy",
982 "session": session_id,
983 "transport": "websocket",
984 "event": "memory_poisoning_detected",
985 "matches": match_count,
986 "tool": tool_name,
987 }),
988 )
989 .await
990 {
991 tracing::warn!("Failed to audit WS memory poisoning: {}", e);
992 }
993 let error = make_ws_error_response(
994 Some(id),
995 -32001,
996 "Request blocked: security policy violation",
997 );
998 let mut sink = client_sink.lock().await;
999 let _ = sink.send(Message::Text(error.into())).await;
1000 continue;
1001 }
1002 }
1003
1004 // SECURITY (FIND-R46-008): Circuit breaker check.
1005 // If the circuit is open for this tool, reject immediately.
1006 if let Some(ref circuit_breaker) = state.circuit_breaker {
1007 if let Err(reason) = circuit_breaker.can_proceed(tool_name) {
1008 tracing::warn!(
1009 session_id = %session_id,
1010 "SECURITY: WS circuit breaker open for tool '{}': {}",
1011 tool_name,
1012 reason
1013 );
1014 let verdict = Verdict::Deny {
1015 reason: format!("Circuit breaker open: {reason}"),
1016 };
1017 if let Err(e) = state
1018 .audit
1019 .log_entry(
1020 &action,
1021 &verdict,
1022 json!({
1023 "source": "ws_proxy",
1024 "session": session_id,
1025 "transport": "websocket",
1026 "event": "circuit_breaker_rejected",
1027 "tool": tool_name,
1028 }),
1029 )
1030 .await
1031 {
1032 tracing::warn!(
1033 "Failed to audit WS circuit breaker rejection: {}",
1034 e
1035 );
1036 }
1037 let error = make_ws_error_response(
1038 Some(id),
1039 -32001,
1040 "Service temporarily unavailable",
1041 );
1042 let mut sink = client_sink.lock().await;
1043 let _ = sink.send(Message::Text(error.into())).await;
1044 continue;
1045 }
1046 }
1047
1048 // SECURITY (FIND-R46-013): Tool registry trust check.
1049 // If tool registry is configured, check trust level before evaluation.
1050 if let Some(ref registry) = state.tool_registry {
1051 let trust = registry.check_trust_level(tool_name).await;
1052 match trust {
1053 vellaveto_mcp::tool_registry::TrustLevel::Unknown => {
1054 registry.register_unknown(tool_name).await;
1055 let verdict = Verdict::Deny {
1056 reason: "Unknown tool requires approval".to_string(),
1057 };
1058 if let Err(e) = state
1059 .audit
1060 .log_entry(
1061 &action,
1062 &verdict,
1063 json!({
1064 "source": "ws_proxy",
1065 "session": session_id,
1066 "transport": "websocket",
1067 "registry": "unknown_tool",
1068 "tool": tool_name,
1069 }),
1070 )
1071 .await
1072 {
1073 tracing::warn!("Failed to audit WS unknown tool: {}", e);
1074 }
1075 let approval_reason = "Approval required";
1076 let approval_id = create_ws_approval(
1077 &state,
1078 &session_id,
1079 &action,
1080 approval_reason,
1081 )
1082 .await;
1083 let error = make_ws_error_response_with_data(
1084 Some(id),
1085 -32001,
1086 approval_reason,
1087 Some(json!({
1088 "verdict": "require_approval",
1089 "reason": approval_reason,
1090 "approval_id": approval_id,
1091 })),
1092 );
1093 let mut sink = client_sink.lock().await;
1094 let _ = sink.send(Message::Text(error.into())).await;
1095 continue;
1096 }
1097 vellaveto_mcp::tool_registry::TrustLevel::Untrusted {
1098 score: _,
1099 } => {
1100 let verdict = Verdict::Deny {
1101 reason: "Untrusted tool requires approval".to_string(),
1102 };
1103 if let Err(e) = state
1104 .audit
1105 .log_entry(
1106 &action,
1107 &verdict,
1108 json!({
1109 "source": "ws_proxy",
1110 "session": session_id,
1111 "transport": "websocket",
1112 "registry": "untrusted_tool",
1113 "tool": tool_name,
1114 }),
1115 )
1116 .await
1117 {
1118 tracing::warn!("Failed to audit WS untrusted tool: {}", e);
1119 }
1120 let approval_reason = "Approval required";
1121 let approval_id = create_ws_approval(
1122 &state,
1123 &session_id,
1124 &action,
1125 approval_reason,
1126 )
1127 .await;
1128 let error = make_ws_error_response_with_data(
1129 Some(id),
1130 -32001,
1131 approval_reason,
1132 Some(json!({
1133 "verdict": "require_approval",
1134 "reason": approval_reason,
1135 "approval_id": approval_id,
1136 })),
1137 );
1138 let mut sink = client_sink.lock().await;
1139 let _ = sink.send(Message::Text(error.into())).await;
1140 continue;
1141 }
1142 vellaveto_mcp::tool_registry::TrustLevel::Trusted => {
1143 // Trusted — proceed to engine evaluation
1144 }
1145 }
1146 }
1147
1148 // SECURITY (FIND-R130-002): Combine context read, evaluation,
1149 // and session update into a single block holding the DashMap
1150 // shard lock. Without this, concurrent WS connections sharing
1151 // a session can bypass max_calls_in_window by racing: both
1152 // clone the same stale call_counts, both pass evaluation, both
1153 // increment. Matches HTTP handler R19-TOCTOU pattern
1154 // (handlers.rs:725-789).
1155 let (verdict, ctx) = if let Some(mut session) =
1156 state.sessions.get_mut(&session_id)
1157 {
1158 let ctx = EvaluationContext {
1159 timestamp: None,
1160 agent_id: session.oauth_subject.clone(),
1161 agent_identity: session.agent_identity.clone(),
1162 call_counts: session.call_counts.clone(),
1163 previous_actions: session.action_history.iter().cloned().collect(),
1164 call_chain: session.current_call_chain.clone(),
1165 tenant_id: None,
1166 verification_tier: None,
1167 capability_token: None,
1168 session_state: None,
1169 };
1170
1171 let verdict = match state.engine.evaluate_action_with_context(
1172 &action,
1173 &state.policies,
1174 Some(&ctx),
1175 ) {
1176 Ok(v) => v,
1177 Err(e) => {
1178 tracing::error!(
1179 session_id = %session_id,
1180 "Policy evaluation error: {}",
1181 e
1182 );
1183 Verdict::Deny {
1184 reason: format!("Policy evaluation failed: {e}"),
1185 }
1186 }
1187 };
1188
1189 // Atomically update session on Allow while still holding
1190 // the shard lock — prevents TOCTOU bypass of call limits.
1191 if matches!(verdict, Verdict::Allow) {
1192 session.touch();
1193 use crate::proxy::call_chain::{
1194 MAX_ACTION_HISTORY, MAX_CALL_COUNT_TOOLS,
1195 };
1196 if session.call_counts.len() < MAX_CALL_COUNT_TOOLS
1197 || session.call_counts.contains_key(tool_name)
1198 {
1199 let count = session
1200 .call_counts
1201 .entry(tool_name.to_string())
1202 .or_insert(0);
1203 *count = count.saturating_add(1);
1204 }
1205 if session.action_history.len() >= MAX_ACTION_HISTORY {
1206 session.action_history.pop_front();
1207 }
1208 session.action_history.push_back(tool_name.to_string());
1209 }
1210
1211 (verdict, ctx)
1212 } else {
1213 // No session — evaluate without context (fail-closed)
1214 let verdict = match state.engine.evaluate_action_with_context(
1215 &action,
1216 &state.policies,
1217 None,
1218 ) {
1219 Ok(v) => v,
1220 Err(e) => {
1221 tracing::error!(
1222 session_id = %session_id,
1223 "Policy evaluation error: {}",
1224 e
1225 );
1226 Verdict::Deny {
1227 reason: format!("Policy evaluation failed: {e}"),
1228 }
1229 }
1230 };
1231 (verdict, EvaluationContext::default())
1232 };
1233
1234 match verdict {
1235 Verdict::Allow => {
1236 // Phase 21: ABAC refinement — only runs when ABAC engine is configured
1237 if let Some(ref abac) = state.abac_engine {
1238 let principal_id =
1239 ctx.agent_id.as_deref().unwrap_or("anonymous");
1240 let principal_type = ctx.principal_type();
1241 let session_risk = state
1242 .sessions
1243 .get_mut(&session_id)
1244 .and_then(|s| s.risk_score.clone());
1245 let abac_ctx = vellaveto_engine::abac::AbacEvalContext {
1246 eval_ctx: &ctx,
1247 principal_type,
1248 principal_id,
1249 risk_score: session_risk.as_ref(),
1250 };
1251 match abac.evaluate(&action, &abac_ctx) {
1252 vellaveto_engine::abac::AbacDecision::Deny {
1253 policy_id,
1254 reason,
1255 } => {
1256 let deny_verdict = Verdict::Deny {
1257 reason: reason.clone(),
1258 };
1259 if let Err(e) = state
1260 .audit
1261 .log_entry(
1262 &action,
1263 &deny_verdict,
1264 json!({
1265 "source": "ws_proxy",
1266 "session": session_id,
1267 "transport": "websocket",
1268 "event": "abac_deny",
1269 "abac_policy": policy_id,
1270 }),
1271 )
1272 .await
1273 {
1274 tracing::warn!(
1275 "Failed to audit WS ABAC deny: {}",
1276 e
1277 );
1278 }
1279 // SECURITY (FIND-R46-012): Generic message to client;
1280 // detailed reason (ABAC policy_id, reason) is in
1281 // the audit log only.
1282 let error_resp = make_ws_error_response(
1283 Some(id),
1284 -32001,
1285 "Denied by policy",
1286 );
1287 let mut sink = client_sink.lock().await;
1288 let _ =
1289 sink.send(Message::Text(error_resp.into())).await;
1290 continue;
1291 }
1292 vellaveto_engine::abac::AbacDecision::Allow {
1293 policy_id,
1294 } => {
1295 if let Some(ref la) = state.least_agency {
1296 la.record_usage(
1297 principal_id,
1298 &session_id,
1299 &policy_id,
1300 tool_name,
1301 &action.function,
1302 );
1303 }
1304 }
1305 vellaveto_engine::abac::AbacDecision::NoMatch => {
1306 // Fall through — existing Allow stands
1307 }
1308 #[allow(unreachable_patterns)]
1309 // AbacDecision is #[non_exhaustive]
1310 _ => {
1311 // SECURITY (FIND-R74-002): Future variants — fail-closed (deny).
1312 // Must send deny and continue, not fall through to Allow path.
1313 tracing::warn!(
1314 "Unknown AbacDecision variant — fail-closed"
1315 );
1316 let error_resp = make_ws_error_response(
1317 Some(id),
1318 -32001,
1319 "Denied by policy",
1320 );
1321 let mut sink = client_sink.lock().await;
1322 let _ =
1323 sink.send(Message::Text(error_resp.into())).await;
1324 continue;
1325 }
1326 }
1327 }
1328
1329 // SECURITY (FIND-R46-013): Record tool call in registry on Allow
1330 if let Some(ref registry) = state.tool_registry {
1331 registry.record_call(tool_name).await;
1332 }
1333
1334 // NOTE: Session touch + call_counts/action_history
1335 // update already performed inside the TOCTOU-safe
1336 // block above (FIND-R130-002). No separate update here.
1337
1338 // Audit the allow
1339 if let Err(e) = state
1340 .audit
1341 .log_entry(
1342 &action,
1343 &Verdict::Allow,
1344 json!({
1345 "source": "ws_proxy",
1346 "session": session_id,
1347 "transport": "websocket",
1348 }),
1349 )
1350 .await
1351 {
1352 tracing::error!(
1353 "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
1354 e
1355 );
1356 // SECURITY (FIND-CREATIVE-003): Strict audit mode — fail-closed.
1357 // No unaudited security decisions can occur.
1358 if state.audit_strict_mode {
1359 let error = make_ws_error_response(
1360 Some(id),
1361 -32000,
1362 "Audit logging failed — request denied (strict audit mode)",
1363 );
1364 let mut sink = client_sink.lock().await;
1365 let _ = sink.send(Message::Text(error.into())).await;
1366 continue;
1367 }
1368 }
1369
1370 // Canonicalize and forward
1371 let forward_text = if state.canonicalize {
1372 match serde_json::to_string(&parsed) {
1373 Ok(canonical) => canonical,
1374 Err(e) => {
1375 tracing::error!(
1376 "SECURITY: WS canonicalization failed: {}",
1377 e
1378 );
1379 let error_resp = make_ws_error_response(
1380 Some(id),
1381 -32603,
1382 "Internal error",
1383 );
1384 let mut sink = client_sink.lock().await;
1385 let _ =
1386 sink.send(Message::Text(error_resp.into())).await;
1387 continue;
1388 }
1389 }
1390 } else {
1391 text.to_string()
1392 };
1393
1394 // Track request→response mapping for output-schema
1395 // enforcement when upstream omits result._meta.tool.
1396 track_pending_tool_call(
1397 &state.sessions,
1398 &session_id,
1399 id,
1400 tool_name,
1401 );
1402
1403 let mut sink = upstream_sink.lock().await;
1404 if let Err(e) = sink
1405 .send(tokio_tungstenite::tungstenite::Message::Text(
1406 forward_text.into(),
1407 ))
1408 .await
1409 {
1410 tracing::error!(
1411 session_id = %session_id,
1412 "Failed to forward to upstream: {}",
1413 e
1414 );
1415 break;
1416 }
1417 }
1418 Verdict::Deny { ref reason } => {
1419 // Audit the denial with detailed reason
1420 if let Err(e) = state
1421 .audit
1422 .log_entry(
1423 &action,
1424 &verdict,
1425 json!({
1426 "source": "ws_proxy",
1427 "session": session_id,
1428 "transport": "websocket",
1429 }),
1430 )
1431 .await
1432 {
1433 tracing::error!(
1434 "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
1435 e
1436 );
1437 // SECURITY (FIND-CREATIVE-003): Strict audit mode — fail-closed.
1438 if state.audit_strict_mode {
1439 let error = make_ws_error_response(
1440 Some(id),
1441 -32000,
1442 "Audit logging failed — request denied (strict audit mode)",
1443 );
1444 let mut sink = client_sink.lock().await;
1445 let _ = sink.send(Message::Text(error.into())).await;
1446 continue;
1447 }
1448 }
1449
1450 // SECURITY (FIND-R46-012): Generic message to client.
1451 // Detailed reason is in the audit log only.
1452 let _ = reason; // used in audit above
1453 let error =
1454 make_ws_error_response(Some(id), -32001, "Denied by policy");
1455 let mut sink = client_sink.lock().await;
1456 let _ = sink.send(Message::Text(error.into())).await;
1457 }
1458 Verdict::RequireApproval { ref reason, .. } => {
1459 // Treat as deny for audit, but preserve approval semantics.
1460 let deny_reason = format!("Requires approval: {reason}");
1461 if let Err(e) = state
1462 .audit
1463 .log_entry(
1464 &action,
1465 &Verdict::Deny {
1466 reason: deny_reason.clone(),
1467 },
1468 json!({
1469 "source": "ws_proxy",
1470 "session": session_id,
1471 "transport": "websocket",
1472 }),
1473 )
1474 .await
1475 {
1476 tracing::error!(
1477 "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
1478 e
1479 );
1480 // SECURITY (FIND-CREATIVE-003): Strict audit mode — fail-closed.
1481 if state.audit_strict_mode {
1482 let error = make_ws_error_response(
1483 Some(id),
1484 -32000,
1485 "Audit logging failed — request denied (strict audit mode)",
1486 );
1487 let mut sink = client_sink.lock().await;
1488 let _ = sink.send(Message::Text(error.into())).await;
1489 continue;
1490 }
1491 }
1492 let approval_reason = "Approval required";
1493 let approval_id =
1494 create_ws_approval(&state, &session_id, &action, reason).await;
1495 let error = make_ws_error_response_with_data(
1496 Some(id),
1497 -32001,
1498 approval_reason,
1499 Some(json!({
1500 "verdict": "require_approval",
1501 "reason": approval_reason,
1502 "approval_id": approval_id,
1503 })),
1504 );
1505 let mut sink = client_sink.lock().await;
1506 let _ = sink.send(Message::Text(error.into())).await;
1507 }
1508 // Fail-closed: unknown Verdict variants produce Deny
1509 _ => {
1510 let error =
1511 make_ws_error_response(Some(id), -32001, "Denied by policy");
1512 let mut sink = client_sink.lock().await;
1513 let _ = sink.send(Message::Text(error.into())).await;
1514 }
1515 }
1516 }
MessageType::ResourceRead { ref id, ref uri } => {
    // resources/read pipeline. Each stage below either rejects the request
    // (sends a JSON-RPC error to the client and `continue`s to the next
    // frame) or falls through to the next stage, in this order:
    //   1. memory-poisoning check on the URI
    //   2. rug-pull flag check (per-session set, else global registry)
    //   3. build the action (+ DNS resolution when the engine has IP rules)
    //   4. DLP scan of the URI
    //   5. circuit breaker
    //   6. policy evaluation, with session counters updated under the same
    //      session guard
    //   7. ABAC refinement, allow audit, canonicalization, upstream forward
    // A failed upstream write `break`s out of the relay loop.

    // SECURITY (FIND-R74-007): Check for memory poisoning in resource URI.
    // ResourceRead is a likely exfiltration vector: a poisoned tool response
    // says "read this file" and the agent issues resources/read for that URI.
    // Parity with HTTP handler (handlers.rs:1472).
    {
        // Some(n) when the URI matches n previously-recorded response
        // fingerprints; None when clean or when the session is not found.
        let poisoning_detected = state
            .sessions
            .get_mut(&session_id)
            .and_then(|session| {
                let uri_params = json!({"uri": uri});
                let matches =
                    session.memory_tracker.check_parameters(&uri_params);
                if !matches.is_empty() {
                    for m in &matches {
                        tracing::warn!(
                            "SECURITY: Memory poisoning detected in WS resources/read (session {}): \
                             param '{}' contains replayed data (fingerprint: {})",
                            session_id,
                            m.param_location,
                            m.fingerprint
                        );
                    }
                    Some(matches.len())
                } else {
                    None
                }
            });
        if let Some(match_count) = poisoning_detected {
            let poison_action = extractor::extract_resource_action(uri);
            let deny_reason = format!(
                "Memory poisoning detected: {match_count} replayed data fragment(s) in resources/read"
            );
            if let Err(e) = state
                .audit
                .log_entry(
                    &poison_action,
                    &Verdict::Deny {
                        reason: deny_reason.clone(),
                    },
                    json!({
                        "source": "ws_proxy",
                        "session": session_id,
                        "transport": "websocket",
                        "event": "memory_poisoning_detected",
                        "matches": match_count,
                        "uri": uri,
                    }),
                )
                .await
            {
                tracing::warn!(
                    "Failed to audit WS resource memory poisoning: {}",
                    e
                );
            }
            let error = make_ws_error_response(
                Some(id),
                -32001,
                "Request blocked: security policy violation",
            );
            let mut sink = client_sink.lock().await;
            let _ = sink.send(Message::Text(error.into())).await;
            continue;
        }
    }

    // SECURITY (FIND-R115-041): Rug-pull detection for resource URIs.
    // If the upstream server was flagged (annotations changed since initial
    // tools/list), block resource reads from that server.
    // Parity with HTTP handler (handlers.rs:1555).
    // SECURITY (R240-PROXY-1): Fall back to global registry on session miss.
    // NOTE(review): the lookup key here is the resource URI itself, not a
    // server identifier. Confirm that `flagged_tools` / the global registry
    // actually contain resource URIs; if they only hold tool names, this
    // check can never fire.
    {
        let is_flagged = state
            .sessions
            .get_mut(&session_id)
            .map(|s| s.flagged_tools.contains(uri.as_str()))
            .unwrap_or_else(|| {
                state.sessions.is_tool_globally_flagged(uri.as_str())
            });
        if is_flagged {
            let action = extractor::extract_resource_action(uri);
            let verdict = Verdict::Deny {
                reason: format!(
                    "Resource '{uri}' blocked: server flagged by rug-pull detection"
                ),
            };
            if let Err(e) = state
                .audit
                .log_entry(
                    &action,
                    &verdict,
                    json!({
                        "source": "ws_proxy",
                        "session": session_id,
                        "transport": "websocket",
                        "event": "rug_pull_resource_blocked",
                        "uri": uri,
                    }),
                )
                .await
            {
                tracing::warn!(
                    "Failed to audit WS resource rug-pull block: {}",
                    e
                );
            }
            let error =
                make_ws_error_response(Some(id), -32001, "Denied by policy");
            let mut sink = client_sink.lock().await;
            let _ = sink.send(Message::Text(error.into())).await;
            continue;
        }
    }

    // Build action for resource read
    let mut action = extractor::extract_resource_action(uri);

    // SECURITY (FIND-R75-002): DNS resolution for resource reads.
    // Parity with HTTP handler (handlers.rs:1543).
    // Only done when the engine has IP rules, so lookups are skipped otherwise.
    if state.engine.has_ip_rules() {
        super::helpers::resolve_domains(&mut action).await;
    }

    // SECURITY (FIND-R116-004): DLP scan on resource URI.
    // Parity with HTTP handler (handlers.rs:1598).
    {
        let uri_params = json!({"uri": uri});
        let dlp_findings = scan_parameters_for_secrets(&uri_params);
        if !dlp_findings.is_empty() {
            for finding in &dlp_findings {
                record_dlp_finding(&finding.pattern_name);
            }
            // The URI itself is deliberately not logged: it contains the secret.
            tracing::warn!(
                "SECURITY: Secret detected in WS resource URI! Session: {}, URI: [redacted]",
                session_id,
            );
            let audit_verdict = Verdict::Deny {
                reason: "DLP blocked: secret detected in resource URI"
                    .to_string(),
            };
            if let Err(e) = state.audit.log_entry(
                &action, &audit_verdict,
                json!({
                    "source": "ws_proxy", "session": session_id,
                    "transport": "websocket", "event": "resource_uri_dlp_alert",
                }),
            ).await {
                tracing::warn!("Failed to audit WS resource URI DLP: {}", e);
            }
            let error = make_ws_error_response(
                Some(id),
                -32001,
                "Request blocked: security policy violation",
            );
            let mut sink = client_sink.lock().await;
            let _ = sink.send(Message::Text(error.into())).await;
            continue;
        }
    }

    // SECURITY (FIND-R115-042): Circuit breaker check for resource reads.
    // Parity with HTTP handler (handlers.rs:1668) — prevent resource reads
    // from hammering a failing upstream server.
    // NOTE(review): the breaker is keyed by the full URI (`can_proceed(uri)`),
    // so each distinct URI presumably gets its own breaker state. Confirm that
    // the breaker's key space is bounded against client-chosen URIs.
    if let Some(ref circuit_breaker) = state.circuit_breaker {
        if let Err(reason) = circuit_breaker.can_proceed(uri) {
            tracing::warn!(
                "SECURITY: WS circuit breaker open for resource '{}' in session {}: {}",
                uri,
                session_id,
                reason
            );
            let verdict = Verdict::Deny {
                reason: format!("Circuit breaker open: {reason}"),
            };
            if let Err(e) = state
                .audit
                .log_entry(
                    &action,
                    &verdict,
                    json!({
                        "source": "ws_proxy",
                        "session": session_id,
                        "transport": "websocket",
                        "event": "circuit_breaker_rejected",
                        "uri": uri,
                    }),
                )
                .await
            {
                tracing::warn!(
                    "Failed to audit WS resource circuit breaker rejection: {}",
                    e
                );
            }
            let error = make_ws_error_response(
                Some(id),
                -32001,
                "Service temporarily unavailable",
            );
            let mut sink = client_sink.lock().await;
            let _ = sink.send(Message::Text(error.into())).await;
            continue;
        }
    }

    // SECURITY (FIND-R130-002): TOCTOU-safe context+eval+update
    // for resource reads. Matches ToolCall fix above and HTTP
    // handler FIND-R112-002 pattern (handlers.rs:1711-1774).
    // `ctx` is returned alongside the verdict so the ABAC step below
    // evaluates against the same snapshot the policy engine saw.
    let (verdict, ctx) = if let Some(mut session) =
        state.sessions.get_mut(&session_id)
    {
        let ctx = EvaluationContext {
            timestamp: None,
            agent_id: session.oauth_subject.clone(),
            agent_identity: session.agent_identity.clone(),
            call_counts: session.call_counts.clone(),
            previous_actions: session.action_history.iter().cloned().collect(),
            call_chain: session.current_call_chain.clone(),
            tenant_id: None,
            verification_tier: None,
            capability_token: None,
            session_state: None,
        };

        // Evaluation errors are mapped to Deny (fail-closed).
        let verdict = match state.engine.evaluate_action_with_context(
            &action,
            &state.policies,
            Some(&ctx),
        ) {
            Ok(v) => v,
            Err(e) => {
                tracing::error!(
                    session_id = %session_id,
                    "Resource policy evaluation error: {}",
                    e
                );
                Verdict::Deny {
                    reason: format!("Policy evaluation failed: {e}"),
                }
            }
        };

        // Atomically update session on Allow
        if matches!(verdict, Verdict::Allow) {
            session.touch();
            use crate::proxy::call_chain::{
                MAX_ACTION_HISTORY, MAX_CALL_COUNT_TOOLS,
            };
            // Per-resource counter key; the URI is truncated to its first
            // 128 chars so a long URI cannot inflate the key size.
            let resource_key = format!(
                "resources/read:{}",
                uri.chars().take(128).collect::<String>()
            );
            // New keys are only inserted while under MAX_CALL_COUNT_TOOLS;
            // existing keys keep counting.
            if session.call_counts.len() < MAX_CALL_COUNT_TOOLS
                || session.call_counts.contains_key(&resource_key)
            {
                let count =
                    session.call_counts.entry(resource_key).or_insert(0);
                *count = count.saturating_add(1);
            }
            // Bounded history: evict the oldest entry before pushing.
            if session.action_history.len() >= MAX_ACTION_HISTORY {
                session.action_history.pop_front();
            }
            session
                .action_history
                .push_back("resources/read".to_string());
        }

        (verdict, ctx)
    } else {
        // No session: evaluate without context; errors still map to Deny.
        let verdict = match state.engine.evaluate_action_with_context(
            &action,
            &state.policies,
            None,
        ) {
            Ok(v) => v,
            Err(e) => {
                tracing::error!(
                    session_id = %session_id,
                    "Resource policy evaluation error: {}",
                    e
                );
                Verdict::Deny {
                    reason: format!("Policy evaluation failed: {e}"),
                }
            }
        };
        (verdict, EvaluationContext::default())
    };

    match verdict {
        Verdict::Allow => {
            // SECURITY (FIND-R116-002): ABAC refinement for resource reads.
            // Parity with HTTP handler (handlers.rs:1783) and gRPC (service.rs:972).
            if let Some(ref abac) = state.abac_engine {
                let principal_id =
                    ctx.agent_id.as_deref().unwrap_or("anonymous");
                let principal_type = ctx.principal_type();
                let session_risk = state
                    .sessions
                    .get_mut(&session_id)
                    .and_then(|s| s.risk_score.clone());
                let abac_ctx = vellaveto_engine::abac::AbacEvalContext {
                    eval_ctx: &ctx,
                    principal_type,
                    principal_id,
                    risk_score: session_risk.as_ref(),
                };
                match abac.evaluate(&action, &abac_ctx) {
                    vellaveto_engine::abac::AbacDecision::Deny {
                        policy_id,
                        reason,
                    } => {
                        let deny_verdict = Verdict::Deny {
                            reason: reason.clone(),
                        };
                        if let Err(e) = state
                            .audit
                            .log_entry(
                                &action,
                                &deny_verdict,
                                json!({
                                    "source": "ws_proxy",
                                    "session": session_id,
                                    "transport": "websocket",
                                    "event": "abac_deny",
                                    "abac_policy": policy_id,
                                    "uri": uri,
                                }),
                            )
                            .await
                        {
                            tracing::warn!(
                                "Failed to audit WS resource ABAC deny: {}",
                                e
                            );
                        }
                        // Generic message to the client; the ABAC policy id
                        // and reason are recorded in the audit entry only.
                        let error_resp = make_ws_error_response(
                            Some(id),
                            -32001,
                            "Denied by policy",
                        );
                        let mut sink = client_sink.lock().await;
                        let _ =
                            sink.send(Message::Text(error_resp.into())).await;
                        continue;
                    }
                    vellaveto_engine::abac::AbacDecision::Allow {
                        policy_id,
                    } => {
                        // SECURITY (FIND-R192-002): record_usage parity.
                        if let Some(ref la) = state.least_agency {
                            la.record_usage(
                                principal_id,
                                &session_id,
                                &policy_id,
                                uri,
                                &action.function,
                            );
                        }
                    }
                    vellaveto_engine::abac::AbacDecision::NoMatch => {
                        // Fall through — existing Allow stands
                    }
                    #[allow(unreachable_patterns)]
                    _ => {
                        // Future AbacDecision variants: fail-closed (deny).
                        tracing::warn!(
                            "Unknown AbacDecision variant in WS resource_read — fail-closed"
                        );
                        let error_resp = make_ws_error_response(
                            Some(id),
                            -32001,
                            "Denied by policy",
                        );
                        let mut sink = client_sink.lock().await;
                        let _ =
                            sink.send(Message::Text(error_resp.into())).await;
                        continue;
                    }
                }
            }

            // NOTE: Session touch + call_counts/action_history
            // update already performed inside the TOCTOU-safe
            // block above (FIND-R130-002). No separate update here.

            // SECURITY (FIND-R46-WS-004): Audit log allowed resource reads
            if let Err(e) = state
                .audit
                .log_entry(
                    &action,
                    &Verdict::Allow,
                    json!({
                        "source": "ws_proxy",
                        "session": session_id,
                        "transport": "websocket",
                        "resource_uri": uri,
                    }),
                )
                .await
            {
                tracing::error!(
                    "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
                    e
                );
                // SECURITY (FIND-CREATIVE-003): Strict audit mode — fail-closed.
                if state.audit_strict_mode {
                    let error = make_ws_error_response(
                        Some(id),
                        -32000,
                        "Audit logging failed — request denied (strict audit mode)",
                    );
                    let mut sink = client_sink.lock().await;
                    let _ = sink.send(Message::Text(error.into())).await;
                    continue;
                }
            }

            // SECURITY (FIND-R46-011): Fail-closed on canonicalization
            // failure. Do NOT fall back to original text.
            let forward_text = if state.canonicalize {
                match serde_json::to_string(&parsed) {
                    Ok(canonical) => canonical,
                    Err(e) => {
                        tracing::error!(
                            "SECURITY: WS resource canonicalization failed: {}",
                            e
                        );
                        let error_resp = make_ws_error_response(
                            Some(id),
                            -32603,
                            "Internal error",
                        );
                        let mut sink = client_sink.lock().await;
                        let _ =
                            sink.send(Message::Text(error_resp.into())).await;
                        continue;
                    }
                }
            } else {
                text.to_string()
            };
            // An upstream write failure ends the relay loop.
            let mut sink = upstream_sink.lock().await;
            if let Err(e) = sink
                .send(tokio_tungstenite::tungstenite::Message::Text(
                    forward_text.into(),
                ))
                .await
            {
                tracing::error!("Failed to forward resource read: {}", e);
                break;
            }
        }
        // SECURITY (FIND-R116-009): Separate handling for Deny vs RequireApproval
        // with per-verdict audit logging. Parity with gRPC (service.rs:1051-1076).
        Verdict::Deny { ref reason } => {
            if let Err(e) = state
                .audit
                .log_entry(
                    &action,
                    &Verdict::Deny {
                        reason: reason.clone(),
                    },
                    json!({
                        "source": "ws_proxy",
                        "session": session_id,
                        "transport": "websocket",
                        "resource_uri": uri,
                    }),
                )
                .await
            {
                tracing::error!(
                    "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
                    e
                );
                // SECURITY (FIND-CREATIVE-003): Strict audit mode — fail-closed.
                if state.audit_strict_mode {
                    let error = make_ws_error_response(
                        Some(id),
                        -32000,
                        "Audit logging failed — request denied (strict audit mode)",
                    );
                    let mut sink = client_sink.lock().await;
                    let _ = sink.send(Message::Text(error.into())).await;
                    continue;
                }
            }
            // Generic message to the client; the detailed reason is audited only.
            let error =
                make_ws_error_response(Some(id), -32001, "Denied by policy");
            let mut sink = client_sink.lock().await;
            let _ = sink.send(Message::Text(error.into())).await;
        }
        Verdict::RequireApproval { ref reason, .. } => {
            // NOTE(review): unlike the tools/call path, no approval record is
            // created here (no create_ws_approval call) and the client receives a
            // plain "Denied by policy" with no approval_id. A resource read that
            // requires approval therefore cannot be approved through this
            // transport. Confirm that this is intended.
            let deny_reason = format!("Requires approval: {reason}");
            if let Err(e) = state
                .audit
                .log_entry(
                    &action,
                    &Verdict::Deny {
                        reason: deny_reason,
                    },
                    json!({
                        "source": "ws_proxy",
                        "session": session_id,
                        "transport": "websocket",
                        "resource_uri": uri,
                        "event": "require_approval",
                    }),
                )
                .await
            {
                tracing::error!(
                    "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
                    e
                );
                // SECURITY (FIND-CREATIVE-003): Strict audit mode — fail-closed.
                if state.audit_strict_mode {
                    let error = make_ws_error_response(
                        Some(id),
                        -32000,
                        "Audit logging failed — request denied (strict audit mode)",
                    );
                    let mut sink = client_sink.lock().await;
                    let _ = sink.send(Message::Text(error.into())).await;
                    continue;
                }
            }
            let error =
                make_ws_error_response(Some(id), -32001, "Denied by policy");
            let mut sink = client_sink.lock().await;
            let _ = sink.send(Message::Text(error.into())).await;
        }
        #[allow(unreachable_patterns)]
        _ => {
            // SECURITY: Future variants — fail-closed.
            // Note: this branch sends a deny but writes no audit entry.
            tracing::warn!(
                "Unknown Verdict variant in WS resource_read — fail-closed"
            );
            let error =
                make_ws_error_response(Some(id), -32001, "Denied by policy");
            let mut sink = client_sink.lock().await;
            let _ = sink.send(Message::Text(error.into())).await;
        }
    }
}
2063 MessageType::Batch => {
2064 // Reject batches per MCP spec
2065 let error = json!({
2066 "jsonrpc": "2.0",
2067 "error": {
2068 "code": -32600,
2069 "message": "JSON-RPC batch requests are not supported"
2070 },
2071 "id": null
2072 });
2073 let error_text = serde_json::to_string(&error)
2074 .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","error":{"code":-32600,"message":"Batch not supported"},"id":null}"#.to_string());
2075 let mut sink = client_sink.lock().await;
2076 let _ = sink.send(Message::Text(error_text.into())).await;
2077 }
2078 MessageType::Invalid { ref id, ref reason } => {
2079 tracing::warn!(
2080 "Invalid JSON-RPC request in WebSocket transport: {}",
2081 reason
2082 );
2083 let error =
2084 make_ws_error_response(Some(id), -32600, "Invalid JSON-RPC request");
2085 let mut sink = client_sink.lock().await;
2086 let _ = sink.send(Message::Text(error.into())).await;
2087 }
2088 MessageType::SamplingRequest { ref id } => {
2089 // SECURITY (FIND-R74-006): Call inspect_sampling() for full
2090 // verdict (enabled + model filter + tool output check + rate limit),
2091 // matching HTTP handler parity (handlers.rs:1681).
2092 let params = parsed.get("params").cloned().unwrap_or(json!({}));
2093 // SECURITY (FIND-R125-001): Per-session sampling rate limit
2094 // parity with elicitation. Atomically read + increment.
2095 let sampling_verdict = {
2096 let mut session_ref = state.sessions.get_mut(&session_id);
2097 let current_count =
2098 session_ref.as_ref().map(|s| s.sampling_count).unwrap_or(0);
2099 let verdict = vellaveto_mcp::elicitation::inspect_sampling(
2100 ¶ms,
2101 &state.sampling_config,
2102 current_count,
2103 );
2104 if matches!(verdict, vellaveto_mcp::elicitation::SamplingVerdict::Allow)
2105 {
2106 if let Some(ref mut s) = session_ref {
2107 s.sampling_count = s.sampling_count.saturating_add(1);
2108 }
2109 }
2110 verdict
2111 };
2112 match sampling_verdict {
2113 vellaveto_mcp::elicitation::SamplingVerdict::Allow => {
2114 // Forward allowed sampling request
2115 // SECURITY (FIND-R48-001): Fail-closed on canonicalization failure.
2116 // Falling back to original text would create a TOCTOU gap.
2117 let forward_text = if state.canonicalize {
2118 match serde_json::to_string(&parsed) {
2119 Ok(canonical) => canonical,
2120 Err(e) => {
2121 tracing::error!(
2122 "SECURITY: WS sampling canonicalization failed: {}",
2123 e
2124 );
2125 let error_resp = make_ws_error_response(
2126 Some(id),
2127 -32603,
2128 "Internal error",
2129 );
2130 let mut sink = client_sink.lock().await;
2131 let _ =
2132 sink.send(Message::Text(error_resp.into())).await;
2133 continue;
2134 }
2135 }
2136 } else {
2137 text.to_string()
2138 };
2139 let mut sink = upstream_sink.lock().await;
2140 let _ = sink
2141 .send(tokio_tungstenite::tungstenite::Message::Text(
2142 forward_text.into(),
2143 ))
2144 .await;
2145 }
2146 vellaveto_mcp::elicitation::SamplingVerdict::Deny { reason } => {
2147 tracing::warn!(
2148 session_id = %session_id,
2149 "Blocked WS sampling/createMessage: {}",
2150 reason
2151 );
2152 let action = Action::new(
2153 "vellaveto",
2154 "ws_sampling_interception",
2155 json!({
2156 "method": "sampling/createMessage",
2157 "session": session_id,
2158 "transport": "websocket",
2159 "reason": &reason,
2160 }),
2161 );
2162 let verdict = Verdict::Deny {
2163 reason: reason.clone(),
2164 };
2165 if let Err(e) = state
2166 .audit
2167 .log_entry(
2168 &action,
2169 &verdict,
2170 json!({
2171 "source": "ws_proxy",
2172 "event": "ws_sampling_interception",
2173 }),
2174 )
2175 .await
2176 {
2177 tracing::warn!(
2178 "Failed to audit WS sampling interception: {}",
2179 e
2180 );
2181 }
2182 // SECURITY: Generic message to client — detailed reason
2183 // is in the audit log, not leaked to the client.
2184 let error = make_ws_error_response(
2185 Some(id),
2186 -32001,
2187 "sampling/createMessage blocked by policy",
2188 );
2189 let mut sink = client_sink.lock().await;
2190 let _ = sink.send(Message::Text(error.into())).await;
2191 }
2192 }
2193 }
2194 MessageType::TaskRequest {
2195 ref id,
2196 ref task_method,
2197 ref task_id,
2198 } => {
2199 // SECURITY (FIND-R76-001): Memory poisoning detection on task params.
2200 // Parity with HTTP handler (handlers.rs:2027-2084). Agents could
2201 // exfiltrate poisoned data via task management operations.
2202 {
2203 let task_params = parsed.get("params").cloned().unwrap_or(json!({}));
2204 let poisoning_detected = state
2205 .sessions
2206 .get_mut(&session_id)
2207 .and_then(|session| {
2208 let matches =
2209 session.memory_tracker.check_parameters(&task_params);
2210 if !matches.is_empty() {
2211 for m in &matches {
2212 tracing::warn!(
2213 "SECURITY: Memory poisoning detected in WS task '{}' (session {}): \
2214 param '{}' contains replayed data (fingerprint: {})",
2215 task_method,
2216 session_id,
2217 m.param_location,
2218 m.fingerprint
2219 );
2220 }
2221 Some(matches.len())
2222 } else {
2223 None
2224 }
2225 });
2226 if let Some(match_count) = poisoning_detected {
2227 let poison_action =
2228 extractor::extract_task_action(task_method, task_id.as_deref());
2229 let deny_reason = format!(
2230 "Memory poisoning detected: {match_count} replayed data fragment(s) in task '{task_method}'"
2231 );
2232 if let Err(e) = state
2233 .audit
2234 .log_entry(
2235 &poison_action,
2236 &Verdict::Deny {
2237 reason: deny_reason,
2238 },
2239 json!({
2240 "source": "ws_proxy",
2241 "session": session_id,
2242 "transport": "websocket",
2243 "event": "memory_poisoning_detected",
2244 "matches": match_count,
2245 "task_method": task_method,
2246 }),
2247 )
2248 .await
2249 {
2250 tracing::warn!(
2251 "Failed to audit WS task memory poisoning: {}",
2252 e
2253 );
2254 }
2255 let error = make_ws_error_response(
2256 Some(id),
2257 -32001,
2258 "Request blocked: security policy violation",
2259 );
2260 let mut sink = client_sink.lock().await;
2261 let _ = sink.send(Message::Text(error.into())).await;
2262 continue;
2263 }
2264 }
2265
2266 // SECURITY (FIND-R76-001): DLP scan task request parameters.
2267 // Parity with HTTP handler (handlers.rs:2086-2145). Agents could
2268 // embed secrets in task_id or params to exfiltrate them.
2269 {
2270 let task_params = parsed.get("params").cloned().unwrap_or(json!({}));
2271 let dlp_findings = scan_parameters_for_secrets(&task_params);
2272 if !dlp_findings.is_empty() {
2273 for finding in &dlp_findings {
2274 record_dlp_finding(&finding.pattern_name);
2275 }
2276 let patterns: Vec<String> = dlp_findings
2277 .iter()
2278 .map(|f| format!("{} at {}", f.pattern_name, f.location))
2279 .collect();
2280 tracing::warn!(
2281 "SECURITY: DLP blocking WS task '{}' in session {}: {:?}",
2282 task_method,
2283 session_id,
2284 patterns
2285 );
2286 let dlp_action =
2287 extractor::extract_task_action(task_method, task_id.as_deref());
2288 if let Err(e) = state
2289 .audit
2290 .log_entry(
2291 &dlp_action,
2292 &Verdict::Deny {
2293 reason: format!(
2294 "DLP: secrets detected in task request: {patterns:?}"
2295 ),
2296 },
2297 json!({
2298 "source": "ws_proxy",
2299 "session": session_id,
2300 "transport": "websocket",
2301 "event": "dlp_secret_detected_task",
2302 "task_method": task_method,
2303 "findings": patterns,
2304 }),
2305 )
2306 .await
2307 {
2308 tracing::warn!("Failed to audit WS task DLP: {}", e);
2309 }
2310 let error = make_ws_error_response(
2311 Some(id),
2312 -32001,
2313 "Request blocked: security policy violation",
2314 );
2315 let mut sink = client_sink.lock().await;
2316 let _ = sink.send(Message::Text(error.into())).await;
2317 continue;
2318 }
2319 }
2320
2321 // Policy-evaluate task requests (async operations)
2322 let action =
2323 extractor::extract_task_action(task_method, task_id.as_deref());
2324 // SECURITY (FIND-R130-002): TOCTOU-safe context+eval for task
2325 // requests. Context is built inside the DashMap shard lock to
2326 // prevent stale snapshot evaluation races.
2327 // SECURITY (FIND-R190-006): Update session state on Allow
2328 // (touch + call_counts + action_history) while still holding
2329 // the shard lock, matching ToolCall/ResourceRead parity.
2330 let (verdict, task_eval_ctx) = if let Some(mut session) =
2331 state.sessions.get_mut(&session_id)
2332 {
2333 let ctx = EvaluationContext {
2334 timestamp: None,
2335 agent_id: session.oauth_subject.clone(),
2336 agent_identity: session.agent_identity.clone(),
2337 call_counts: session.call_counts.clone(),
2338 previous_actions: session.action_history.iter().cloned().collect(),
2339 call_chain: session.current_call_chain.clone(),
2340 tenant_id: None,
2341 verification_tier: None,
2342 capability_token: None,
2343 session_state: None,
2344 };
2345 let verdict = match state.engine.evaluate_action_with_context(
2346 &action,
2347 &state.policies,
2348 Some(&ctx),
2349 ) {
2350 Ok(v) => v,
2351 Err(e) => {
2352 tracing::error!(
2353 session_id = %session_id,
2354 "Task policy evaluation error: {}", e
2355 );
2356 Verdict::Deny {
2357 reason: format!("Policy evaluation failed: {e}"),
2358 }
2359 }
2360 };
2361
2362 // Update session atomically on Allow
2363 if matches!(verdict, Verdict::Allow) {
2364 session.touch();
2365 use crate::proxy::call_chain::{
2366 MAX_ACTION_HISTORY, MAX_CALL_COUNT_TOOLS,
2367 };
2368 if session.call_counts.len() < MAX_CALL_COUNT_TOOLS
2369 || session.call_counts.contains_key(task_method)
2370 {
2371 let count = session
2372 .call_counts
2373 .entry(task_method.to_string())
2374 .or_insert(0);
2375 *count = count.saturating_add(1);
2376 }
2377 if session.action_history.len() >= MAX_ACTION_HISTORY {
2378 session.action_history.pop_front();
2379 }
2380 session.action_history.push_back(task_method.to_string());
2381 }
2382
2383 (verdict, ctx)
2384 } else {
2385 let verdict = match state.engine.evaluate_action_with_context(
2386 &action,
2387 &state.policies,
2388 None,
2389 ) {
2390 Ok(v) => v,
2391 Err(e) => {
2392 tracing::error!(
2393 session_id = %session_id,
2394 "Task policy evaluation error: {}", e
2395 );
2396 Verdict::Deny {
2397 reason: format!("Policy evaluation failed: {e}"),
2398 }
2399 }
2400 };
2401 (verdict, EvaluationContext::default())
2402 };
2403
2404 match verdict {
2405 Verdict::Allow => {
2406 // SECURITY (FIND-R190-001): ABAC refinement for TaskRequest,
2407 // matching ToolCall/ResourceRead parity.
2408 if let Some(ref abac) = state.abac_engine {
2409 let principal_id =
2410 task_eval_ctx.agent_id.as_deref().unwrap_or("anonymous");
2411 let principal_type = task_eval_ctx.principal_type();
2412 let session_risk = state
2413 .sessions
2414 .get_mut(&session_id)
2415 .and_then(|s| s.risk_score.clone());
2416 let abac_ctx = vellaveto_engine::abac::AbacEvalContext {
2417 eval_ctx: &task_eval_ctx,
2418 principal_type,
2419 principal_id,
2420 risk_score: session_risk.as_ref(),
2421 };
2422 match abac.evaluate(&action, &abac_ctx) {
2423 vellaveto_engine::abac::AbacDecision::Deny {
2424 policy_id,
2425 reason,
2426 } => {
2427 let deny_verdict = Verdict::Deny {
2428 reason: reason.clone(),
2429 };
2430 if let Err(e) = state
2431 .audit
2432 .log_entry(
2433 &action,
2434 &deny_verdict,
2435 json!({
2436 "source": "ws_proxy",
2437 "session": session_id,
2438 "transport": "websocket",
2439 "event": "abac_deny",
2440 "abac_policy": policy_id,
2441 "task_method": task_method,
2442 }),
2443 )
2444 .await
2445 {
2446 tracing::warn!(
2447 "Failed to audit WS task ABAC deny: {}",
2448 e
2449 );
2450 }
2451 let error_resp = make_ws_error_response(
2452 Some(id),
2453 -32001,
2454 "Denied by policy",
2455 );
2456 let mut sink = client_sink.lock().await;
2457 let _ =
2458 sink.send(Message::Text(error_resp.into())).await;
2459 continue;
2460 }
2461 vellaveto_engine::abac::AbacDecision::Allow {
2462 policy_id,
2463 } => {
2464 if let Some(ref la) = state.least_agency {
2465 la.record_usage(
2466 principal_id,
2467 &session_id,
2468 &policy_id,
2469 task_method,
2470 &action.function,
2471 );
2472 }
2473 }
2474 vellaveto_engine::abac::AbacDecision::NoMatch => {
2475 // Fall through — existing Allow stands
2476 }
2477 #[allow(unreachable_patterns)]
2478 _ => {
2479 tracing::warn!(
2480 "Unknown AbacDecision variant — fail-closed"
2481 );
2482 let error_resp = make_ws_error_response(
2483 Some(id),
2484 -32001,
2485 "Denied by policy",
2486 );
2487 let mut sink = client_sink.lock().await;
2488 let _ =
2489 sink.send(Message::Text(error_resp.into())).await;
2490 continue;
2491 }
2492 }
2493 }
2494
2495 if let Err(e) = state
2496 .audit
2497 .log_entry(
2498 &action,
2499 &Verdict::Allow,
2500 json!({
2501 "source": "ws_proxy",
2502 "session": session_id,
2503 "transport": "websocket",
2504 "task_method": task_method,
2505 }),
2506 )
2507 .await
2508 {
2509 tracing::warn!("Failed to audit WS task allow: {}", e);
2510 }
2511 // SECURITY (FIND-R48-001): Fail-closed on canonicalization failure.
2512 let forward_text = if state.canonicalize {
2513 match serde_json::to_string(&parsed) {
2514 Ok(canonical) => canonical,
2515 Err(e) => {
2516 tracing::error!(
2517 "SECURITY: WS task canonicalization failed: {}",
2518 e
2519 );
2520 let error_resp = make_ws_error_response(
2521 Some(id),
2522 -32603,
2523 "Internal error",
2524 );
2525 let mut sink = client_sink.lock().await;
2526 let _ =
2527 sink.send(Message::Text(error_resp.into())).await;
2528 continue;
2529 }
2530 }
2531 } else {
2532 text.to_string()
2533 };
2534 let mut sink = upstream_sink.lock().await;
2535 if let Err(e) = sink
2536 .send(tokio_tungstenite::tungstenite::Message::Text(
2537 forward_text.into(),
2538 ))
2539 .await
2540 {
2541 tracing::error!("Failed to forward task request: {}", e);
2542 break;
2543 }
2544 }
2545 Verdict::Deny { ref reason } => {
2546 if let Err(e) = state
2547 .audit
2548 .log_entry(
2549 &action,
2550 &Verdict::Deny {
2551 reason: reason.clone(),
2552 },
2553 json!({
2554 "source": "ws_proxy",
2555 "session": session_id,
2556 "transport": "websocket",
2557 "task_method": task_method,
2558 }),
2559 )
2560 .await
2561 {
2562 tracing::error!(
2563 "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
2564 e
2565 );
2566 // SECURITY (FIND-R213-002): Strict audit mode — fail-closed.
2567 if state.audit_strict_mode {
2568 let error = make_ws_error_response(
2569 Some(id),
2570 -32000,
2571 "Audit logging failed — request denied (strict audit mode)",
2572 );
2573 let mut sink = client_sink.lock().await;
2574 let _ = sink.send(Message::Text(error.into())).await;
2575 continue;
2576 }
2577 }
2578 // SECURITY (FIND-R55-WS-005): Generic denial message to prevent
2579 // leaking policy names/details. Detailed reason is in audit log.
2580 let denial =
2581 make_ws_error_response(Some(id), -32001, "Denied by policy");
2582 let mut sink = client_sink.lock().await;
2583 let _ = sink.send(Message::Text(denial.into())).await;
2584 }
2585 Verdict::RequireApproval { ref reason, .. } => {
2586 let deny_reason = format!("Requires approval: {reason}");
2587 if let Err(e) = state
2588 .audit
2589 .log_entry(
2590 &action,
2591 &Verdict::Deny {
2592 reason: deny_reason,
2593 },
2594 json!({
2595 "source": "ws_proxy",
2596 "session": session_id,
2597 "transport": "websocket",
2598 "task_method": task_method,
2599 }),
2600 )
2601 .await
2602 {
2603 tracing::error!(
2604 "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
2605 e
2606 );
2607 // SECURITY (FIND-R213-002): Strict audit mode — fail-closed.
2608 if state.audit_strict_mode {
2609 let error = make_ws_error_response(
2610 Some(id),
2611 -32000,
2612 "Audit logging failed — request denied (strict audit mode)",
2613 );
2614 let mut sink = client_sink.lock().await;
2615 let _ = sink.send(Message::Text(error.into())).await;
2616 continue;
2617 }
2618 }
2619 let approval_reason = "Approval required";
2620 let approval_id =
2621 create_ws_approval(&state, &session_id, &action, reason).await;
2622 let denial = make_ws_error_response_with_data(
2623 Some(id),
2624 -32001,
2625 approval_reason,
2626 Some(json!({
2627 "verdict": "require_approval",
2628 "reason": approval_reason,
2629 "approval_id": approval_id,
2630 })),
2631 );
2632 let mut sink = client_sink.lock().await;
2633 let _ = sink.send(Message::Text(denial.into())).await;
2634 }
2635 _ => {
2636 let denial =
2637 make_ws_error_response(Some(id), -32001, "Denied by policy");
2638 let mut sink = client_sink.lock().await;
2639 let _ = sink.send(Message::Text(denial.into())).await;
2640 }
2641 }
2642 }
2643 MessageType::ExtensionMethod {
2644 ref id,
2645 ref extension_id,
2646 ref method,
2647 } => {
2648 // Policy-evaluate extension method calls
2649 let params = parsed.get("params").cloned().unwrap_or(json!({}));
2650
2651 // SECURITY (FIND-R116-001): DLP scan extension method parameters.
2652 // Parity with gRPC handle_extension_method (service.rs:1542).
2653 let dlp_findings = scan_parameters_for_secrets(¶ms);
2654 if !dlp_findings.is_empty() {
2655 for finding in &dlp_findings {
2656 record_dlp_finding(&finding.pattern_name);
2657 }
2658 let patterns: Vec<String> = dlp_findings
2659 .iter()
2660 .map(|f| format!("{}:{}", f.pattern_name, f.location))
2661 .collect();
2662 tracing::warn!(
2663 "SECURITY: Secrets in WS extension method parameters! Session: {}, Extension: {}:{}, Findings: {:?}",
2664 session_id, extension_id, method, patterns,
2665 );
2666 let action =
2667 extractor::extract_extension_action(extension_id, method, ¶ms);
2668 let audit_verdict = Verdict::Deny {
2669 reason: format!(
2670 "DLP blocked: secret detected in extension parameters: {patterns:?}"
2671 ),
2672 };
2673 if let Err(e) = state.audit.log_entry(
2674 &action, &audit_verdict,
2675 json!({
2676 "source": "ws_proxy", "session": session_id, "transport": "websocket",
2677 "event": "ws_extension_parameter_dlp_alert",
2678 "extension_id": extension_id, "method": method, "findings": patterns,
2679 }),
2680 ).await {
2681 tracing::warn!("Failed to audit WS extension parameter DLP: {}", e);
2682 }
2683 let denial =
2684 make_ws_error_response(Some(id), -32001, "Denied by policy");
2685 let mut sink = client_sink.lock().await;
2686 let _ = sink.send(Message::Text(denial.into())).await;
2687 continue;
2688 }
2689
2690 // SECURITY (FIND-R116-001): Memory poisoning detection for extension params.
2691 // Parity with gRPC handle_extension_method (service.rs:1574).
2692 if let Some(session) = state.sessions.get_mut(&session_id) {
2693 let poisoning_matches =
2694 session.memory_tracker.check_parameters(¶ms);
2695 if !poisoning_matches.is_empty() {
2696 for m in &poisoning_matches {
2697 tracing::warn!(
2698 "SECURITY: Memory poisoning in WS extension '{}:{}' (session {}): \
2699 param '{}' replayed data (fingerprint: {})",
2700 extension_id, method, session_id, m.param_location, m.fingerprint
2701 );
2702 }
2703 let action = extractor::extract_extension_action(
2704 extension_id,
2705 method,
2706 ¶ms,
2707 );
2708 let deny_reason = format!(
2709 "Memory poisoning detected: {} replayed data fragment(s) in extension '{}:{}'",
2710 poisoning_matches.len(), extension_id, method
2711 );
2712 if let Err(e) = state.audit.log_entry(
2713 &action,
2714 &Verdict::Deny { reason: deny_reason.clone() },
2715 json!({
2716 "source": "ws_proxy", "session": session_id, "transport": "websocket",
2717 "event": "memory_poisoning_detected",
2718 "matches": poisoning_matches.len(),
2719 "extension_id": extension_id, "method": method,
2720 }),
2721 ).await {
2722 tracing::warn!("Failed to audit WS extension memory poisoning: {}", e);
2723 }
2724 let denial =
2725 make_ws_error_response(Some(id), -32001, "Denied by policy");
2726 let mut sink = client_sink.lock().await;
2727 let _ = sink.send(Message::Text(denial.into())).await;
2728 continue;
2729 }
2730 }
2731
2732 let mut action =
2733 extractor::extract_extension_action(extension_id, method, ¶ms);
2734
2735 // SECURITY (FIND-R118-004): DNS resolution for extension methods.
2736 // Parity with ToolCall (line 710) and ResourceRead (line 1439).
2737 if state.engine.has_ip_rules() {
2738 super::helpers::resolve_domains(&mut action).await;
2739 }
2740
2741 let ext_key = format!("extension:{extension_id}:{method}");
2742
2743 // SECURITY (FIND-R130-002): TOCTOU-safe context+eval+update
2744 // for extension methods. Matches ToolCall/ResourceRead fixes.
2745 let (verdict, ctx) = if let Some(mut session) =
2746 state.sessions.get_mut(&session_id)
2747 {
2748 let ctx = EvaluationContext {
2749 timestamp: None,
2750 agent_id: session.oauth_subject.clone(),
2751 agent_identity: session.agent_identity.clone(),
2752 call_counts: session.call_counts.clone(),
2753 previous_actions: session.action_history.iter().cloned().collect(),
2754 call_chain: session.current_call_chain.clone(),
2755 tenant_id: None,
2756 verification_tier: None,
2757 capability_token: None,
2758 session_state: None,
2759 };
2760
2761 let verdict = match state.engine.evaluate_action_with_context(
2762 &action,
2763 &state.policies,
2764 Some(&ctx),
2765 ) {
2766 Ok(v) => v,
2767 Err(e) => {
2768 tracing::error!(
2769 session_id = %session_id,
2770 "Extension policy evaluation error: {}", e
2771 );
2772 Verdict::Deny {
2773 reason: format!("Policy evaluation failed: {e}"),
2774 }
2775 }
2776 };
2777
2778 // Atomically update session on Allow
2779 if matches!(verdict, Verdict::Allow) {
2780 session.touch();
2781 use crate::proxy::call_chain::{
2782 MAX_ACTION_HISTORY, MAX_CALL_COUNT_TOOLS,
2783 };
2784 if session.call_counts.len() < MAX_CALL_COUNT_TOOLS
2785 || session.call_counts.contains_key(&ext_key)
2786 {
2787 let count =
2788 session.call_counts.entry(ext_key.clone()).or_insert(0);
2789 *count = count.saturating_add(1);
2790 }
2791 if session.action_history.len() >= MAX_ACTION_HISTORY {
2792 session.action_history.pop_front();
2793 }
2794 session.action_history.push_back(ext_key.clone());
2795 }
2796
2797 (verdict, ctx)
2798 } else {
2799 let verdict = match state.engine.evaluate_action_with_context(
2800 &action,
2801 &state.policies,
2802 None,
2803 ) {
2804 Ok(v) => v,
2805 Err(e) => {
2806 tracing::error!(
2807 session_id = %session_id,
2808 "Extension policy evaluation error: {}", e
2809 );
2810 Verdict::Deny {
2811 reason: format!("Policy evaluation failed: {e}"),
2812 }
2813 }
2814 };
2815 (verdict, EvaluationContext::default())
2816 };
2817
2818 match verdict {
2819 Verdict::Allow => {
2820 // SECURITY (FIND-R118-002): ABAC refinement for extension methods.
2821 // Parity with ToolCall (line 1099) and ResourceRead (line 1498).
2822 if let Some(ref abac) = state.abac_engine {
2823 let principal_id =
2824 ctx.agent_id.as_deref().unwrap_or("anonymous");
2825 let principal_type = ctx.principal_type();
2826 let session_risk = state
2827 .sessions
2828 .get_mut(&session_id)
2829 .and_then(|s| s.risk_score.clone());
2830 let abac_ctx = vellaveto_engine::abac::AbacEvalContext {
2831 eval_ctx: &ctx,
2832 principal_type,
2833 principal_id,
2834 risk_score: session_risk.as_ref(),
2835 };
2836 match abac.evaluate(&action, &abac_ctx) {
2837 vellaveto_engine::abac::AbacDecision::Deny {
2838 policy_id,
2839 reason,
2840 } => {
2841 let deny_verdict = Verdict::Deny {
2842 reason: reason.clone(),
2843 };
2844 if let Err(e) = state
2845 .audit
2846 .log_entry(
2847 &action,
2848 &deny_verdict,
2849 json!({
2850 "source": "ws_proxy",
2851 "session": session_id,
2852 "transport": "websocket",
2853 "event": "abac_deny",
2854 "extension_id": extension_id,
2855 "abac_policy": policy_id,
2856 }),
2857 )
2858 .await
2859 {
2860 tracing::warn!(
2861 "Failed to audit WS extension ABAC deny: {}",
2862 e
2863 );
2864 }
2865 let error_resp = make_ws_error_response(
2866 Some(id),
2867 -32001,
2868 "Denied by policy",
2869 );
2870 let mut sink = client_sink.lock().await;
2871 let _ =
2872 sink.send(Message::Text(error_resp.into())).await;
2873 continue;
2874 }
2875 vellaveto_engine::abac::AbacDecision::Allow {
2876 policy_id,
2877 } => {
2878 if let Some(ref la) = state.least_agency {
2879 la.record_usage(
2880 principal_id,
2881 &session_id,
2882 &policy_id,
2883 &ext_key,
2884 method,
2885 );
2886 }
2887 }
2888 vellaveto_engine::abac::AbacDecision::NoMatch => {
2889 // Fall through — existing Allow stands
2890 }
2891 #[allow(unreachable_patterns)]
2892 // AbacDecision is #[non_exhaustive]
2893 _ => {
2894 // SECURITY: Future variants — fail-closed (deny).
2895 tracing::warn!(
2896 "Unknown AbacDecision variant — fail-closed"
2897 );
2898 let error_resp = make_ws_error_response(
2899 Some(id),
2900 -32001,
2901 "Denied by policy",
2902 );
2903 let mut sink = client_sink.lock().await;
2904 let _ =
2905 sink.send(Message::Text(error_resp.into())).await;
2906 continue;
2907 }
2908 }
2909 }
2910
2911 // NOTE: Session touch + call_counts/action_history
2912 // update already performed inside the TOCTOU-safe
2913 // block above (FIND-R130-002). No separate update here.
2914
2915 if let Err(e) = state
2916 .audit
2917 .log_entry(
2918 &action,
2919 &Verdict::Allow,
2920 json!({
2921 "source": "ws_proxy",
2922 "session": session_id,
2923 "transport": "websocket",
2924 "extension_id": extension_id,
2925 }),
2926 )
2927 .await
2928 {
2929 tracing::error!(
2930 "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
2931 e
2932 );
2933 // SECURITY (FIND-R215-007): Strict audit mode — fail-closed.
2934 // Parity with Deny and RequireApproval paths.
2935 if state.audit_strict_mode {
2936 let error = make_ws_error_response(
2937 Some(id),
2938 -32000,
2939 "Audit logging failed — request denied (strict audit mode)",
2940 );
2941 let mut sink = client_sink.lock().await;
2942 let _ = sink.send(Message::Text(error.into())).await;
2943 continue;
2944 }
2945 }
2946 // SECURITY (FIND-R48-001): Fail-closed on canonicalization failure.
2947 let forward_text = if state.canonicalize {
2948 match serde_json::to_string(&parsed) {
2949 Ok(canonical) => canonical,
2950 Err(e) => {
2951 tracing::error!("SECURITY: WS extension canonicalization failed: {}", e);
2952 let error_resp = make_ws_error_response(
2953 Some(id),
2954 -32603,
2955 "Internal error",
2956 );
2957 let mut sink = client_sink.lock().await;
2958 let _ =
2959 sink.send(Message::Text(error_resp.into())).await;
2960 continue;
2961 }
2962 }
2963 } else {
2964 text.to_string()
2965 };
2966 let mut sink = upstream_sink.lock().await;
2967 if let Err(e) = sink
2968 .send(tokio_tungstenite::tungstenite::Message::Text(
2969 forward_text.into(),
2970 ))
2971 .await
2972 {
2973 tracing::error!("Failed to forward extension request: {}", e);
2974 break;
2975 }
2976 }
2977 Verdict::Deny { ref reason } => {
2978 if let Err(e) = state
2979 .audit
2980 .log_entry(
2981 &action,
2982 &Verdict::Deny {
2983 reason: reason.clone(),
2984 },
2985 json!({
2986 "source": "ws_proxy",
2987 "session": session_id,
2988 "transport": "websocket",
2989 "extension_id": extension_id,
2990 }),
2991 )
2992 .await
2993 {
2994 tracing::error!(
2995 "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
2996 e
2997 );
2998 // SECURITY (FIND-R213-002): Strict audit mode — fail-closed.
2999 if state.audit_strict_mode {
3000 let error = make_ws_error_response(
3001 Some(id),
3002 -32000,
3003 "Audit logging failed — request denied (strict audit mode)",
3004 );
3005 let mut sink = client_sink.lock().await;
3006 let _ = sink.send(Message::Text(error.into())).await;
3007 continue;
3008 }
3009 }
3010 // SECURITY (FIND-R213-001): Generic denial message — do not leak
3011 // detailed policy reason to client. Reason is in the audit log.
3012 let _ = reason;
3013 let denial =
3014 make_ws_error_response(Some(id), -32001, "Denied by policy");
3015 let mut sink = client_sink.lock().await;
3016 let _ = sink.send(Message::Text(denial.into())).await;
3017 }
3018 Verdict::RequireApproval { ref reason, .. } => {
3019 let deny_reason = format!("Requires approval: {reason}");
3020 if let Err(e) = state
3021 .audit
3022 .log_entry(
3023 &action,
3024 &Verdict::Deny {
3025 reason: deny_reason,
3026 },
3027 json!({
3028 "source": "ws_proxy",
3029 "session": session_id,
3030 "transport": "websocket",
3031 "extension_id": extension_id,
3032 }),
3033 )
3034 .await
3035 {
3036 tracing::error!(
3037 "AUDIT FAILURE in WS proxy: security decision not recorded: {}",
3038 e
3039 );
3040 // SECURITY (FIND-R213-002): Strict audit mode — fail-closed.
3041 if state.audit_strict_mode {
3042 let error = make_ws_error_response(
3043 Some(id),
3044 -32000,
3045 "Audit logging failed — request denied (strict audit mode)",
3046 );
3047 let mut sink = client_sink.lock().await;
3048 let _ = sink.send(Message::Text(error.into())).await;
3049 continue;
3050 }
3051 }
3052 let approval_reason = "Approval required";
3053 let approval_id =
3054 create_ws_approval(&state, &session_id, &action, reason).await;
3055 let denial = make_ws_error_response_with_data(
3056 Some(id),
3057 -32001,
3058 approval_reason,
3059 Some(json!({
3060 "verdict": "require_approval",
3061 "reason": approval_reason,
3062 "approval_id": approval_id,
3063 })),
3064 );
3065 let mut sink = client_sink.lock().await;
3066 let _ = sink.send(Message::Text(denial.into())).await;
3067 }
3068 _ => {
3069 let denial =
3070 make_ws_error_response(Some(id), -32001, "Denied by policy");
3071 let mut sink = client_sink.lock().await;
3072 let _ = sink.send(Message::Text(denial.into())).await;
3073 }
3074 }
3075 }
3076 MessageType::ElicitationRequest { ref id } => {
3077 // SECURITY (FIND-R46-010): Policy checks for elicitation requests.
3078 // Match the HTTP POST handler's elicitation inspection logic.
3079 let params = parsed.get("params").cloned().unwrap_or(json!({}));
3080 let elicitation_verdict = {
3081 let mut session_ref = state.sessions.get_mut(&session_id);
3082 let current_count = session_ref
3083 .as_ref()
3084 .map(|s| s.elicitation_count)
3085 .unwrap_or(0);
3086 let verdict = vellaveto_mcp::elicitation::inspect_elicitation(
3087 ¶ms,
3088 &state.elicitation_config,
3089 current_count,
3090 );
3091 // Pre-increment while holding the lock to close the TOCTOU gap
3092 if matches!(
3093 verdict,
3094 vellaveto_mcp::elicitation::ElicitationVerdict::Allow
3095 ) {
3096 if let Some(ref mut s) = session_ref {
3097 // SECURITY (FIND-R51-008): Use saturating_add for consistency.
3098 s.elicitation_count = s.elicitation_count.saturating_add(1);
3099 }
3100 }
3101 verdict
3102 };
3103 match elicitation_verdict {
3104 vellaveto_mcp::elicitation::ElicitationVerdict::Allow => {
3105 let action = Action::new(
3106 "vellaveto",
3107 "ws_forward_message",
3108 json!({
3109 "message_type": "elicitation_request",
3110 "session": session_id,
3111 "transport": "websocket",
3112 "direction": "client_to_upstream",
3113 }),
3114 );
3115 if let Err(e) = state
3116 .audit
3117 .log_entry(
3118 &action,
3119 &Verdict::Allow,
3120 json!({
3121 "source": "ws_proxy",
3122 "event": "ws_elicitation_forwarded",
3123 }),
3124 )
3125 .await
3126 {
3127 tracing::warn!("Failed to audit WS elicitation: {}", e);
3128 }
3129
3130 // SECURITY (FIND-R48-001): Fail-closed on canonicalization failure.
3131 let forward_text = if state.canonicalize {
3132 match serde_json::to_string(&parsed) {
3133 Ok(canonical) => canonical,
3134 Err(e) => {
3135 tracing::error!("SECURITY: WS elicitation canonicalization failed: {}", e);
3136 let error_resp = make_ws_error_response(
3137 Some(id),
3138 -32603,
3139 "Internal error",
3140 );
3141 let mut sink = client_sink.lock().await;
3142 let _ =
3143 sink.send(Message::Text(error_resp.into())).await;
3144 continue;
3145 }
3146 }
3147 } else {
3148 text.to_string()
3149 };
3150 let mut sink = upstream_sink.lock().await;
3151 if let Err(e) = sink
3152 .send(tokio_tungstenite::tungstenite::Message::Text(
3153 forward_text.into(),
3154 ))
3155 .await
3156 {
3157 // Rollback pre-incremented count on forward failure
3158 if let Some(mut s) = state.sessions.get_mut(&session_id) {
3159 s.elicitation_count = s.elicitation_count.saturating_sub(1);
3160 }
3161 tracing::error!("Failed to forward elicitation: {}", e);
3162 break;
3163 }
3164 }
3165 vellaveto_mcp::elicitation::ElicitationVerdict::Deny { reason } => {
3166 tracing::warn!(
3167 session_id = %session_id,
3168 "Blocked WS elicitation/create: {}",
3169 reason
3170 );
3171 let action = Action::new(
3172 "vellaveto",
3173 "ws_elicitation_interception",
3174 json!({
3175 "method": "elicitation/create",
3176 "session": session_id,
3177 "transport": "websocket",
3178 "reason": &reason,
3179 }),
3180 );
3181 let verdict = Verdict::Deny {
3182 reason: reason.clone(),
3183 };
3184 if let Err(e) = state
3185 .audit
3186 .log_entry(
3187 &action,
3188 &verdict,
3189 json!({
3190 "source": "ws_proxy",
3191 "event": "ws_elicitation_interception",
3192 }),
3193 )
3194 .await
3195 {
3196 tracing::warn!(
3197 "Failed to audit WS elicitation interception: {}",
3198 e
3199 );
3200 }
3201 // SECURITY (FIND-R46-012, FIND-R55-WS-006): Generic message to client.
3202 let error =
3203 make_ws_error_response(Some(id), -32001, "Denied by policy");
3204 let mut sink = client_sink.lock().await;
3205 let _ = sink.send(Message::Text(error.into())).await;
3206 }
3207 }
3208 }
3209 MessageType::PassThrough | MessageType::ProgressNotification { .. } => {
3210 // SECURITY (FIND-R76-003): DLP scan PassThrough params for secrets.
3211 // Parity with HTTP handler (handlers.rs:1795-1859). Agents could
3212 // exfiltrate secrets via prompts/get, completion/complete, or any
3213 // PassThrough method's parameters.
3214 // SECURITY (FIND-R97-001): Remove method gate — JSON-RPC responses
3215 // (sampling/elicitation replies) have no `method` field but carry
3216 // data in `result`. Parity with stdio proxy FIND-R96-001.
3217 if state.response_dlp_enabled {
3218 let mut dlp_findings = scan_notification_for_secrets(&parsed);
3219 // SECURITY (FIND-R97-001): Also scan `result` field for responses.
3220 if let Some(result_val) = parsed.get("result") {
3221 dlp_findings.extend(scan_parameters_for_secrets(result_val));
3222 }
3223 // SECURITY (FIND-R83-006): Cap combined findings from params+result
3224 // scans to maintain per-scan invariant (1000).
3225 dlp_findings.truncate(1000);
3226 if !dlp_findings.is_empty() {
3227 for finding in &dlp_findings {
3228 record_dlp_finding(&finding.pattern_name);
3229 }
3230 let patterns: Vec<String> = dlp_findings
3231 .iter()
3232 .map(|f| format!("{}:{}", f.pattern_name, f.location))
3233 .collect();
3234 tracing::warn!(
3235 "SECURITY: Secrets in WS passthrough params! Session: {}, Findings: {:?}",
3236 session_id,
3237 patterns
3238 );
3239 let n_action = Action::new(
3240 "vellaveto",
3241 "notification_dlp_scan",
3242 json!({
3243 "findings": patterns,
3244 "session": session_id,
3245 "transport": "websocket",
3246 }),
3247 );
3248 let verdict = if state.response_dlp_blocking {
3249 Verdict::Deny {
3250 reason: format!(
3251 "Notification blocked: secrets detected ({patterns:?})"
3252 ),
3253 }
3254 } else {
3255 Verdict::Allow
3256 };
3257 if let Err(e) = state
3258 .audit
3259 .log_entry(
3260 &n_action,
3261 &verdict,
3262 json!({
3263 "source": "ws_proxy",
3264 "event": "notification_dlp_alert",
3265 "blocked": state.response_dlp_blocking,
3266 }),
3267 )
3268 .await
3269 {
3270 tracing::warn!("Failed to audit WS passthrough DLP: {}", e);
3271 }
3272 if state.response_dlp_blocking {
3273 // Drop the message silently (passthrough has no id to respond to)
3274 continue;
3275 }
3276 }
3277 }
3278
3279 // SECURITY (FIND-R130-001): Injection scanning on PassThrough parameters.
3280 // Parity with HTTP handler (handlers.rs FIND-R112-008) and gRPC handler
3281 // (service.rs FIND-R113-001). An agent could inject prompt injection
3282 // payloads via any PassThrough method's parameters.
3283 if !state.injection_disabled {
3284 let mut inj_parts = Vec::new();
3285 if let Some(params) = parsed.get("params") {
3286 extract_strings_recursive(params, &mut inj_parts, 0);
3287 }
3288 if let Some(result) = parsed.get("result") {
3289 extract_strings_recursive(result, &mut inj_parts, 0);
3290 }
3291 let scannable = inj_parts.join("\n");
3292 if !scannable.is_empty() {
3293 let injection_matches: Vec<String> =
3294 if let Some(ref scanner) = state.injection_scanner {
3295 scanner
3296 .inspect(&scannable)
3297 .into_iter()
3298 .map(|s| s.to_string())
3299 .collect()
3300 } else {
3301 inspect_for_injection(&scannable)
3302 .into_iter()
3303 .map(|s| s.to_string())
3304 .collect()
3305 };
3306
3307 if !injection_matches.is_empty() {
3308 tracing::warn!(
3309 "SECURITY: Injection in WS passthrough params! \
3310 Session: {}, Patterns: {:?}",
3311 session_id,
3312 injection_matches,
3313 );
3314
3315 let verdict = if state.injection_blocking {
3316 Verdict::Deny {
3317 reason: format!(
3318 "WS passthrough injection blocked: {injection_matches:?}"
3319 ),
3320 }
3321 } else {
3322 Verdict::Allow
3323 };
3324
3325 let inj_action = Action::new(
3326 "vellaveto",
3327 "ws_passthrough_injection_scan",
3328 json!({
3329 "matched_patterns": injection_matches,
3330 "session": session_id,
3331 "transport": "websocket",
3332 "direction": "client_to_upstream",
3333 }),
3334 );
3335 if let Err(e) = state
3336 .audit
3337 .log_entry(
3338 &inj_action,
3339 &verdict,
3340 json!({
3341 "source": "ws_proxy",
3342 "event": "ws_passthrough_injection_detected",
3343 "blocking": state.injection_blocking,
3344 }),
3345 )
3346 .await
3347 {
3348 tracing::warn!(
3349 "Failed to audit WS passthrough injection: {}",
3350 e
3351 );
3352 }
3353
3354 if state.injection_blocking {
3355 // Drop the message (passthrough has no id to respond to)
3356 continue;
3357 }
3358 }
3359 }
3360 }
3361
3362 // SECURITY (IMP-R182-009): Memory poisoning check — parity with
3363 // tool calls, resource reads, tasks, and extension methods.
3364 if let Some(mut session) = state.sessions.get_mut(&session_id) {
3365 let params_to_scan = parsed.get("params").cloned().unwrap_or(json!({}));
3366 // SECURITY (IMP-R184-010): Also scan `result` field — parity
3367 // with DLP scan which scans both params and result.
3368 let mut poisoning_matches =
3369 session.memory_tracker.check_parameters(¶ms_to_scan);
3370 if let Some(result_val) = parsed.get("result") {
3371 poisoning_matches
3372 .extend(session.memory_tracker.check_parameters(result_val));
3373 }
3374 if !poisoning_matches.is_empty() {
3375 let method_name = parsed
3376 .get("method")
3377 .and_then(|m| m.as_str())
3378 .unwrap_or("unknown");
3379 for m in &poisoning_matches {
3380 tracing::warn!(
3381 "SECURITY: Memory poisoning in WS passthrough '{}' (session {}): \
3382 param '{}' replayed data (fingerprint: {})",
3383 method_name,
3384 session_id,
3385 m.param_location,
3386 m.fingerprint
3387 );
3388 }
3389 let poison_action = Action::new(
3390 "vellaveto",
3391 "ws_passthrough_memory_poisoning",
3392 json!({
3393 "method": method_name,
3394 "session": session_id,
3395 "matches": poisoning_matches.len(),
3396 "transport": "websocket",
3397 }),
3398 );
3399 if let Err(e) = state
3400 .audit
3401 .log_entry(
3402 &poison_action,
3403 &Verdict::Deny {
3404 reason: format!(
3405 "WS passthrough blocked: memory poisoning ({} matches)",
3406 poisoning_matches.len()
3407 ),
3408 },
3409 json!({
3410 "source": "ws_proxy",
3411 "event": "ws_passthrough_memory_poisoning",
3412 }),
3413 )
3414 .await
3415 {
3416 tracing::warn!(
3417 "Failed to audit WS passthrough memory poisoning: {}",
3418 e
3419 );
3420 }
3421 continue; // Drop the message
3422 }
3423 // Fingerprint for future poisoning detection.
3424 session.memory_tracker.extract_from_value(¶ms_to_scan);
3425 if let Some(result_val) = parsed.get("result") {
3426 session.memory_tracker.extract_from_value(result_val);
3427 }
3428 } else {
3429 // IMP-R186-005: Log when session is missing so the skip is observable.
3430 tracing::warn!(
3431 "Session {} not found for WS passthrough memory poisoning check",
3432 session_id
3433 );
3434 }
3435
3436 // SECURITY (FIND-R46-WS-004): Audit log forwarded passthrough/notification messages
3437 let msg_type = match &classified {
3438 MessageType::ProgressNotification { .. } => "progress_notification",
3439 _ => "passthrough",
3440 };
3441 let action = Action::new(
3442 "vellaveto",
3443 "ws_forward_message",
3444 json!({
3445 "message_type": msg_type,
3446 "session": session_id,
3447 "transport": "websocket",
3448 "direction": "client_to_upstream",
3449 }),
3450 );
3451 if let Err(e) = state
3452 .audit
3453 .log_entry(
3454 &action,
3455 &Verdict::Allow,
3456 json!({
3457 "source": "ws_proxy",
3458 "event": "ws_message_forwarded",
3459 }),
3460 )
3461 .await
3462 {
3463 tracing::warn!("Failed to audit WS passthrough: {}", e);
3464 }
3465
3466 // SECURITY (FIND-R48-001): Fail-closed on canonicalization failure.
3467 let forward_text = if state.canonicalize {
3468 match serde_json::to_string(&parsed) {
3469 Ok(canonical) => canonical,
3470 Err(e) => {
3471 tracing::error!(
3472 "SECURITY: WS passthrough canonicalization failed: {}",
3473 e
3474 );
3475 continue;
3476 }
3477 }
3478 } else {
3479 text.to_string()
3480 };
3481 let mut sink = upstream_sink.lock().await;
3482 if let Err(e) = sink
3483 .send(tokio_tungstenite::tungstenite::Message::Text(
3484 forward_text.into(),
3485 ))
3486 .await
3487 {
3488 tracing::error!("Failed to forward passthrough: {}", e);
3489 break;
3490 }
3491 }
3492 }
3493 }
3494 Message::Binary(_data) => {
3495 // SECURITY: Binary frames not allowed for JSON-RPC
3496 tracing::warn!(
3497 session_id = %session_id,
3498 "Binary WebSocket frame rejected (JSON-RPC is text-only)"
3499 );
3500
3501 // SECURITY (FIND-R46-WS-004): Audit log binary frame rejection
3502 let action = Action::new(
3503 "vellaveto",
3504 "ws_binary_frame_rejected",
3505 json!({
3506 "session": session_id,
3507 "transport": "websocket",
3508 "direction": "client_to_upstream",
3509 }),
3510 );
3511 if let Err(e) = state
3512 .audit
3513 .log_entry(
3514 &action,
3515 &Verdict::Deny {
3516 reason: "Binary frames not supported for JSON-RPC".to_string(),
3517 },
3518 json!({
3519 "source": "ws_proxy",
3520 "event": "ws_binary_frame_rejected",
3521 }),
3522 )
3523 .await
3524 {
3525 tracing::warn!("Failed to audit WS binary frame rejection: {}", e);
3526 }
3527
3528 let mut sink = client_sink.lock().await;
3529 let _ = sink
3530 .send(Message::Close(Some(CloseFrame {
3531 code: CLOSE_UNSUPPORTED_DATA,
3532 reason: "Binary frames not supported".into(),
3533 })))
3534 .await;
3535 break;
3536 }
3537 Message::Ping(data) => {
3538 let mut sink = client_sink.lock().await;
3539 let _ = sink.send(Message::Pong(data)).await;
3540 }
3541 Message::Pong(_) => {
3542 // Ignored
3543 }
3544 Message::Close(_) => {
3545 tracing::debug!(session_id = %session_id, "Client sent close frame");
3546 break;
3547 }
3548 }
3549 }
3550}
3551
3552/// Relay messages from upstream to client with DLP and injection scanning.
3553#[allow(clippy::too_many_arguments)]
3554async fn relay_upstream_to_client(
3555 mut upstream_stream: futures_util::stream::SplitStream<
3556 tokio_tungstenite::WebSocketStream<
3557 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
3558 >,
3559 >,
3560 client_sink: Arc<Mutex<futures_util::stream::SplitSink<WebSocket, Message>>>,
3561 state: ProxyState,
3562 session_id: String,
3563 ws_config: WebSocketConfig,
3564 upstream_rate_counter: Arc<AtomicU64>,
3565 upstream_rate_window_start: Arc<std::sync::Mutex<std::time::Instant>>,
3566 last_activity: Arc<AtomicU64>,
3567 connection_epoch: std::time::Instant,
3568) {
3569 while let Some(msg_result) = upstream_stream.next().await {
3570 let msg = match msg_result {
3571 Ok(m) => m,
3572 Err(e) => {
3573 tracing::debug!(session_id = %session_id, "Upstream WS error: {}", e);
3574 break;
3575 }
3576 };
3577
3578 // SECURITY (FIND-R182-001): Update last-activity for true idle detection.
3579 last_activity.store(
3580 connection_epoch.elapsed().as_millis() as u64,
3581 Ordering::Relaxed,
3582 );
3583
3584 record_ws_message("upstream_to_client");
3585
3586 // SECURITY (FIND-R46-WS-003): Rate limiting on upstream→client direction.
3587 // A malicious or compromised upstream could flood the client with messages.
3588 if !check_rate_limit(
3589 &upstream_rate_counter,
3590 &upstream_rate_window_start,
3591 ws_config.upstream_rate_limit,
3592 ) {
3593 tracing::warn!(
3594 session_id = %session_id,
3595 "WebSocket upstream rate limit exceeded ({}/s), dropping message",
3596 ws_config.upstream_rate_limit,
3597 );
3598
3599 let action = Action::new(
3600 "vellaveto",
3601 "ws_upstream_rate_limit",
3602 json!({
3603 "session": session_id,
3604 "transport": "websocket",
3605 "direction": "upstream_to_client",
3606 "limit": ws_config.upstream_rate_limit,
3607 }),
3608 );
3609 if let Err(e) = state
3610 .audit
3611 .log_entry(
3612 &action,
3613 &Verdict::Deny {
3614 reason: "Upstream rate limit exceeded".to_string(),
3615 },
3616 json!({
3617 "source": "ws_proxy",
3618 "event": "ws_upstream_rate_limit_exceeded",
3619 }),
3620 )
3621 .await
3622 {
3623 tracing::warn!("Failed to audit WS upstream rate limit: {}", e);
3624 }
3625
3626 metrics::counter!(
3627 "vellaveto_ws_upstream_rate_limited_total",
3628 "session" => session_id.clone()
3629 )
3630 .increment(1);
3631
3632 // Drop the message (don't close the connection — upstream flood should not
3633 // disconnect the client, just throttle the flow)
3634 continue;
3635 }
3636
3637 match msg {
3638 tokio_tungstenite::tungstenite::Message::Text(text) => {
3639 // Try to parse for scanning
3640 let forward = if let Ok(json_val) = serde_json::from_str::<Value>(&text) {
3641 // Resolve tracked tool context for response-side schema checks.
3642 let tracked_tool_name =
3643 take_tracked_tool_call(&state.sessions, &session_id, json_val.get("id"));
3644
3645 // SECURITY (FIND-R75-003): Track whether DLP or injection was detected
3646 // (even in log-only mode) to gate memory_tracker.record_response().
3647 // Recording fingerprints from tainted responses would poison the tracker.
3648 let mut dlp_found = false;
3649 let mut injection_found = false;
3650
3651 // DLP scanning on responses
3652 if state.response_dlp_enabled {
3653 let dlp_findings = scan_response_for_secrets(&json_val);
3654 if !dlp_findings.is_empty() {
3655 dlp_found = true;
3656 for finding in &dlp_findings {
3657 record_dlp_finding(&finding.pattern_name);
3658 }
3659
3660 let patterns: Vec<String> = dlp_findings
3661 .iter()
3662 .map(|f| format!("{}:{}", f.pattern_name, f.location))
3663 .collect();
3664
3665 tracing::warn!(
3666 "SECURITY: Secrets in WS response! Session: {}, Findings: {:?}",
3667 session_id,
3668 patterns,
3669 );
3670
3671 let verdict = if state.response_dlp_blocking {
3672 Verdict::Deny {
3673 reason: format!("WS response DLP blocked: {patterns:?}"),
3674 }
3675 } else {
3676 Verdict::Allow
3677 };
3678
3679 let action = Action::new(
3680 "vellaveto",
3681 "ws_response_dlp_scan",
3682 json!({
3683 "findings": patterns,
3684 "session": session_id,
3685 "transport": "websocket",
3686 }),
3687 );
3688 if let Err(e) = state
3689 .audit
3690 .log_entry(
3691 &action,
3692 &verdict,
3693 json!({
3694 "source": "ws_proxy",
3695 "event": "ws_response_dlp_alert",
3696 }),
3697 )
3698 .await
3699 {
3700 tracing::warn!("Failed to audit WS DLP: {}", e);
3701 }
3702
3703 if state.response_dlp_blocking {
3704 // Send error response instead
3705 let id = json_val.get("id");
3706 let error = make_ws_error_response(
3707 id,
3708 -32001,
3709 "Response blocked by DLP policy",
3710 );
3711 let mut sink = client_sink.lock().await;
3712 let _ = sink.send(Message::Text(error.into())).await;
3713 continue;
3714 }
3715 }
3716 }
3717
3718 // Injection scanning
3719 if !state.injection_disabled {
3720 let text_to_scan = extract_scannable_text(&json_val);
3721 if !text_to_scan.is_empty() {
3722 let injection_matches: Vec<String> =
3723 if let Some(ref scanner) = state.injection_scanner {
3724 scanner
3725 .inspect(&text_to_scan)
3726 .into_iter()
3727 .map(|s| s.to_string())
3728 .collect()
3729 } else {
3730 inspect_for_injection(&text_to_scan)
3731 .into_iter()
3732 .map(|s| s.to_string())
3733 .collect()
3734 };
3735
3736 if !injection_matches.is_empty() {
3737 injection_found = true;
3738 tracing::warn!(
3739 "SECURITY: Injection in WS response! Session: {}, Patterns: {:?}",
3740 session_id,
3741 injection_matches,
3742 );
3743
3744 let verdict = if state.injection_blocking {
3745 Verdict::Deny {
3746 reason: format!(
3747 "WS response injection blocked: {injection_matches:?}"
3748 ),
3749 }
3750 } else {
3751 Verdict::Allow
3752 };
3753
3754 let action = Action::new(
3755 "vellaveto",
3756 "ws_response_injection",
3757 json!({
3758 "matched_patterns": injection_matches,
3759 "session": session_id,
3760 "transport": "websocket",
3761 }),
3762 );
3763 if let Err(e) = state
3764 .audit
3765 .log_entry(
3766 &action,
3767 &verdict,
3768 json!({
3769 "source": "ws_proxy",
3770 "event": "ws_injection_detected",
3771 }),
3772 )
3773 .await
3774 {
3775 tracing::warn!("Failed to audit WS injection: {}", e);
3776 }
3777
3778 if state.injection_blocking {
3779 let id = json_val.get("id");
3780 let error = make_ws_error_response(
3781 id,
3782 -32001,
3783 "Response blocked: injection detected",
3784 );
3785 let mut sink = client_sink.lock().await;
3786 let _ = sink.send(Message::Text(error.into())).await;
3787 continue;
3788 }
3789 }
3790 }
3791 }
3792
3793 // SECURITY (FIND-R46-007): Rug-pull detection on tools/list responses.
3794 // Check if this is a response to a tools/list request and extract
3795 // annotations for rug-pull detection.
3796 if json_val.get("result").is_some() {
3797 // Check if result contains "tools" array (tools/list response)
3798 if json_val
3799 .get("result")
3800 .and_then(|r| r.get("tools"))
3801 .and_then(|t| t.as_array())
3802 .is_some()
3803 {
3804 super::helpers::extract_annotations_from_response(
3805 &json_val,
3806 &session_id,
3807 &state.sessions,
3808 &state.audit,
3809 &state.known_tools,
3810 )
3811 .await;
3812
3813 // Verify manifest if configured
3814 if let Some(ref manifest_config) = state.manifest_config {
3815 super::helpers::verify_manifest_from_response(
3816 &json_val,
3817 &session_id,
3818 &state.sessions,
3819 manifest_config,
3820 &state.audit,
3821 )
3822 .await;
3823 }
3824
3825 // SECURITY (FIND-R130-003): Scan tool descriptions for embedded
3826 // injection. Parity with HTTP upstream handler (upstream.rs:648-698).
3827 if !state.injection_disabled {
3828 let desc_findings =
3829 if let Some(ref scanner) = state.injection_scanner {
3830 scan_tool_descriptions_with_scanner(&json_val, scanner)
3831 } else {
3832 scan_tool_descriptions(&json_val)
3833 };
3834 for finding in &desc_findings {
3835 injection_found = true;
3836 tracing::warn!(
3837 "SECURITY: Injection in tool '{}' description! \
3838 Session: {}, Patterns: {:?}",
3839 finding.tool_name,
3840 session_id,
3841 finding.matched_patterns
3842 );
3843 let action = Action::new(
3844 "vellaveto",
3845 "tool_description_injection",
3846 json!({
3847 "tool": finding.tool_name,
3848 "matched_patterns": finding.matched_patterns,
3849 "session": session_id,
3850 "transport": "websocket",
3851 "blocking": state.injection_blocking,
3852 }),
3853 );
3854 if let Err(e) = state
3855 .audit
3856 .log_entry(
3857 &action,
3858 &Verdict::Deny {
3859 reason: format!(
3860 "Tool '{}' description contains injection: {:?}",
3861 finding.tool_name, finding.matched_patterns
3862 ),
3863 },
3864 json!({
3865 "source": "ws_proxy",
3866 "event": "tool_description_injection",
3867 }),
3868 )
3869 .await
3870 {
3871 tracing::warn!(
3872 "Failed to audit WS tool description injection: {}",
3873 e
3874 );
3875 }
3876 }
3877 if !desc_findings.is_empty() && state.injection_blocking {
3878 let id = json_val.get("id");
3879 let error = make_ws_error_response(
3880 id,
3881 -32001,
3882 "Response blocked: suspicious content in tool descriptions",
3883 );
3884 let mut sink = client_sink.lock().await;
3885 let _ = sink.send(Message::Text(error.into())).await;
3886 continue;
3887 }
3888 }
3889 }
3890 }
3891
3892 // SECURITY: Enforce output schema on WS structuredContent.
3893 // SECURITY (FIND-R154-004): Track schema violations for the
3894 // record_response guard below, even in non-blocking mode.
3895 let schema_violation_found = validate_ws_structured_content_response(
3896 &json_val,
3897 &state,
3898 &session_id,
3899 tracked_tool_name.as_deref(),
3900 )
3901 .await;
3902 if schema_violation_found {
3903 let id = json_val.get("id");
3904 let error = make_ws_error_response(
3905 id,
3906 -32001,
3907 "Response blocked: output schema validation failed",
3908 );
3909 let mut sink = client_sink.lock().await;
3910 let _ = sink.send(Message::Text(error.into())).await;
3911 continue;
3912 }
3913
3914 // SECURITY (FIND-R75-003, FIND-R154-004): Record response
3915 // fingerprints for memory poisoning detection. Skip recording
3916 // when DLP, injection, or schema violation was detected (even
3917 // in log-only mode) to avoid poisoning the tracker with tainted
3918 // data. Parity with stdio relay (relay.rs:2919).
3919 if !dlp_found && !injection_found && !schema_violation_found {
3920 if let Some(mut session) = state.sessions.get_mut(&session_id) {
3921 session.memory_tracker.record_response(&json_val);
3922 }
3923 }
3924
3925 // SECURITY (FIND-R46-WS-004): Audit log forwarded upstream→client text messages
3926 {
3927 let msg_type = if json_val.get("result").is_some() {
3928 "response"
3929 } else if json_val.get("error").is_some() {
3930 "error_response"
3931 } else if json_val.get("method").is_some() {
3932 "notification"
3933 } else {
3934 "unknown"
3935 };
3936 let action = Action::new(
3937 "vellaveto",
3938 "ws_forward_upstream_message",
3939 json!({
3940 "message_type": msg_type,
3941 "session": session_id,
3942 "transport": "websocket",
3943 "direction": "upstream_to_client",
3944 }),
3945 );
3946 if let Err(e) = state
3947 .audit
3948 .log_entry(
3949 &action,
3950 &Verdict::Allow,
3951 json!({
3952 "source": "ws_proxy",
3953 "event": "ws_upstream_message_forwarded",
3954 }),
3955 )
3956 .await
3957 {
3958 tracing::warn!("Failed to audit WS upstream message forward: {}", e);
3959 }
3960 }
3961
3962 // SECURITY (FIND-R48-001): Fail-closed on canonicalization failure.
3963 if state.canonicalize {
3964 match serde_json::to_string(&json_val) {
3965 Ok(canonical) => canonical,
3966 Err(e) => {
3967 tracing::error!(
3968 "SECURITY: WS response canonicalization failed: {}",
3969 e
3970 );
3971 continue;
3972 }
3973 }
3974 } else {
3975 text.to_string()
3976 }
3977 } else {
3978 // SECURITY (FIND-R166-001): Non-JSON upstream text must still be
3979 // scanned for DLP/injection before forwarding. A malicious upstream
3980 // could exfiltrate secrets or inject payloads via non-JSON frames.
3981 let text_str: &str = &text;
3982 // SECURITY (FIND-R168-001): DLP scan with audit logging parity.
3983 if state.response_dlp_enabled {
3984 let findings = scan_text_for_secrets(text_str, "ws.upstream.non_json_text");
3985 if !findings.is_empty() {
3986 let patterns: Vec<String> = findings
3987 .iter()
3988 .map(|f| format!("{}:{}", f.pattern_name, f.location))
3989 .collect();
3990 tracing::warn!(
3991 session_id = %session_id,
3992 "DLP: non-JSON upstream text contains sensitive data ({} findings)",
3993 findings.len(),
3994 );
3995 let verdict = if state.response_dlp_blocking {
3996 Verdict::Deny {
3997 reason: format!("WS non-JSON DLP: {patterns:?}"),
3998 }
3999 } else {
4000 Verdict::Allow
4001 };
4002 let action = Action::new(
4003 "vellaveto",
4004 "ws_nonjson_dlp_scan",
4005 json!({ "findings": patterns, "session": session_id, "transport": "websocket" }),
4006 );
4007 // SECURITY (SE-004): Log audit failures instead of silently discarding.
4008 if let Err(e) = state.audit.log_entry(
4009 &action, &verdict,
4010 json!({ "source": "ws_proxy", "event": "ws_nonjson_dlp_alert" }),
4011 ).await {
4012 tracing::error!(
4013 session_id = %session_id,
4014 error = %e,
4015 "AUDIT FAILURE: failed to log ws_nonjson_dlp_alert"
4016 );
4017 }
4018 if state.response_dlp_blocking {
4019 continue;
4020 }
4021 }
4022 }
4023 // SECURITY (FIND-R168-002): Injection scan with log-only mode
4024 // parity. Always log detections; only block when injection_blocking.
4025 {
4026 let alerts: Vec<String> = if let Some(ref scanner) = state.injection_scanner
4027 {
4028 scanner
4029 .inspect(text_str)
4030 .into_iter()
4031 .map(|s| s.to_string())
4032 .collect()
4033 } else {
4034 inspect_for_injection(text_str)
4035 .into_iter()
4036 .map(|s| s.to_string())
4037 .collect()
4038 };
4039 if !alerts.is_empty() {
4040 tracing::warn!(
4041 session_id = %session_id,
4042 "Injection: non-JSON upstream text ({} alerts), blocking={}",
4043 alerts.len(), state.injection_blocking,
4044 );
4045 let verdict = if state.injection_blocking {
4046 Verdict::Deny {
4047 reason: format!(
4048 "WS non-JSON injection: {} alerts",
4049 alerts.len()
4050 ),
4051 }
4052 } else {
4053 Verdict::Allow
4054 };
4055 let action = Action::new(
4056 "vellaveto",
4057 "ws_nonjson_injection_scan",
4058 json!({ "alerts": alerts.len(), "session": session_id, "transport": "websocket" }),
4059 );
4060 // SECURITY (SE-004): Log audit failures instead of silently discarding.
4061 if let Err(e) = state.audit.log_entry(
4062 &action, &verdict,
4063 json!({ "source": "ws_proxy", "event": "ws_nonjson_injection_alert" }),
4064 ).await {
4065 tracing::error!(
4066 session_id = %session_id,
4067 error = %e,
4068 "AUDIT FAILURE: failed to log ws_nonjson_injection_alert"
4069 );
4070 }
4071 if state.injection_blocking {
4072 continue;
4073 }
4074 }
4075 }
4076 text.to_string()
4077 };
4078
4079 let mut sink = client_sink.lock().await;
4080 if let Err(e) = sink.send(Message::Text(forward.into())).await {
4081 tracing::debug!("Failed to send to client: {}", e);
4082 break;
4083 }
4084 }
4085 tokio_tungstenite::tungstenite::Message::Binary(data) => {
4086 // SECURITY (FIND-R46-WS-002): DLP scanning on upstream binary frames.
4087 // Binary from upstream is unusual for JSON-RPC but must be scanned
4088 // before being dropped, to detect and audit secret exfiltration attempts
4089 // via binary frames.
4090 tracing::warn!(
4091 session_id = %session_id,
4092 "Unexpected binary frame from upstream ({} bytes), scanning before drop",
4093 data.len(),
4094 );
4095
4096 // DLP scan the binary data as UTF-8 lossy
4097 if state.response_dlp_enabled {
4098 let text_repr = String::from_utf8_lossy(&data);
4099 if !text_repr.is_empty() {
4100 let dlp_findings = scan_text_for_secrets(&text_repr, "ws_binary_frame");
4101 if !dlp_findings.is_empty() {
4102 for finding in &dlp_findings {
4103 record_dlp_finding(&finding.pattern_name);
4104 }
4105 let patterns: Vec<String> = dlp_findings
4106 .iter()
4107 .map(|f| format!("{}:{}", f.pattern_name, f.location))
4108 .collect();
4109
4110 tracing::warn!(
4111 "SECURITY: Secrets in WS upstream binary frame! Session: {}, Findings: {:?}",
4112 session_id,
4113 patterns,
4114 );
4115
4116 let action = Action::new(
4117 "vellaveto",
4118 "ws_binary_dlp_scan",
4119 json!({
4120 "findings": patterns,
4121 "session": session_id,
4122 "transport": "websocket",
4123 "direction": "upstream_to_client",
4124 "binary_size": data.len(),
4125 }),
4126 );
4127 if let Err(e) = state
4128 .audit
4129 .log_entry(
4130 &action,
4131 &Verdict::Deny {
4132 reason: format!("WS binary frame DLP: {patterns:?}"),
4133 },
4134 json!({
4135 "source": "ws_proxy",
4136 "event": "ws_binary_dlp_alert",
4137 }),
4138 )
4139 .await
4140 {
4141 tracing::warn!("Failed to audit WS binary DLP: {}", e);
4142 }
4143 }
4144 }
4145 }
4146
4147 // SECURITY (FIND-R46-WS-004): Audit log binary frame drop
4148 let action = Action::new(
4149 "vellaveto",
4150 "ws_upstream_binary_dropped",
4151 json!({
4152 "session": session_id,
4153 "transport": "websocket",
4154 "direction": "upstream_to_client",
4155 "binary_size": data.len(),
4156 }),
4157 );
4158 if let Err(e) = state
4159 .audit
4160 .log_entry(
4161 &action,
4162 &Verdict::Deny {
4163 reason: "Binary frames not supported for JSON-RPC".to_string(),
4164 },
4165 json!({
4166 "source": "ws_proxy",
4167 "event": "ws_upstream_binary_dropped",
4168 }),
4169 )
4170 .await
4171 {
4172 tracing::warn!("Failed to audit WS upstream binary drop: {}", e);
4173 }
4174 }
4175 tokio_tungstenite::tungstenite::Message::Ping(data) => {
4176 // Forward ping as pong to upstream (handled by tungstenite)
4177 let _ = data; // tungstenite auto-responds to pings
4178 }
4179 tokio_tungstenite::tungstenite::Message::Pong(_) => {}
4180 tokio_tungstenite::tungstenite::Message::Close(_) => {
4181 tracing::debug!(session_id = %session_id, "Upstream sent close frame");
4182 break;
4183 }
4184 tokio_tungstenite::tungstenite::Message::Frame(_) => {
4185 // Raw frame — ignore
4186 }
4187 }
4188 }
4189}
4190
4191/// Convert an HTTP URL to a WebSocket URL.
4192///
4193/// `http://` → `ws://`, `https://` → `wss://`.
4194///
4195/// SECURITY (FIND-R124-001): Only allows http/https/ws/wss schemes.
4196/// Unknown schemes (ftp://, file://, gopher://) are rejected with a
4197/// warning and fall back to the original URL prefixed with `ws://`
4198/// to maintain fail-closed behavior. This gives parity with scheme
4199/// validation in HTTP and gRPC transports (FIND-R42-015).
4200pub fn convert_to_ws_url(http_url: &str) -> String {
4201 if let Some(rest) = http_url.strip_prefix("https://") {
4202 format!("wss://{rest}")
4203 } else if let Some(rest) = http_url.strip_prefix("http://") {
4204 format!("ws://{rest}")
4205 } else if http_url.starts_with("wss://") || http_url.starts_with("ws://") {
4206 // Already a WebSocket URL — use as-is
4207 http_url.to_string()
4208 } else {
4209 // SECURITY (FIND-R124-001): Reject unknown schemes. Log warning
4210 // and return a URL that will fail to connect safely rather than
4211 // connecting to an unintended scheme (e.g., ftp://, file://).
4212 // SECURITY (FIND-R166-003): Sanitize logged value to prevent log injection
4213 // from URLs with control characters (possible in gateway mode).
4214 tracing::warn!(
4215 "convert_to_ws_url: rejecting URL with unsupported scheme: {}",
4216 vellaveto_types::sanitize_for_log(
4217 http_url.split("://").next().unwrap_or("unknown"),
4218 128,
4219 )
4220 );
4221 // Return invalid URL that will fail at connect_async()
4222 format!("ws://invalid-scheme-rejected.localhost/{}", http_url.len())
4223 }
4224}
4225
4226/// Connect to an upstream WebSocket server.
4227///
4228/// Returns the split WebSocket stream or an error.
4229async fn connect_upstream_ws(
4230 url: &str,
4231) -> Result<
4232 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
4233 String,
4234> {
4235 let connect_timeout = Duration::from_secs(10);
4236 match tokio::time::timeout(connect_timeout, tokio_tungstenite::connect_async(url)).await {
4237 Ok(Ok((ws_stream, _response))) => Ok(ws_stream),
4238 Ok(Err(e)) => Err(format!("WebSocket connection error: {e}")),
4239 Err(_) => Err("WebSocket connection timeout (10s)".to_string()),
4240 }
4241}
4242
4243/// Register output schemas and validate WS response `structuredContent`.
4244///
4245/// Returns true when the response should be blocked.
4246async fn validate_ws_structured_content_response(
4247 json_val: &Value,
4248 state: &ProxyState,
4249 session_id: &str,
4250 tracked_tool_name: Option<&str>,
4251) -> bool {
4252 // Keep WS behavior aligned with HTTP/SSE paths.
4253 state
4254 .output_schema_registry
4255 .register_from_tools_list(json_val);
4256
4257 let Some(result) = json_val.get("result") else {
4258 return false;
4259 };
4260 let Some(structured) = result.get("structuredContent") else {
4261 return false;
4262 };
4263
4264 let meta_tool_name = result
4265 .get("_meta")
4266 .and_then(|m| m.get("tool"))
4267 .and_then(|t| t.as_str());
4268 let tool_name = match (meta_tool_name, tracked_tool_name) {
4269 (Some(meta), Some(tracked)) if !meta.eq_ignore_ascii_case(tracked) => {
4270 tracing::warn!(
4271 "SECURITY: WS structuredContent tool mismatch (meta='{}', tracked='{}'); using tracked tool name",
4272 meta,
4273 tracked
4274 );
4275 tracked
4276 }
4277 (Some(meta), _) => meta,
4278 (None, Some(tracked)) => tracked,
4279 (None, None) => "unknown",
4280 };
4281
4282 match state.output_schema_registry.validate(tool_name, structured) {
4283 ValidationResult::Invalid { violations } => {
4284 tracing::warn!(
4285 "SECURITY: WS structuredContent validation failed for tool '{}': {:?}",
4286 tool_name,
4287 violations
4288 );
4289 let action = Action::new(
4290 "vellaveto",
4291 "output_schema_violation",
4292 json!({
4293 "tool": tool_name,
4294 "violations": violations,
4295 "session": session_id,
4296 "transport": "websocket",
4297 }),
4298 );
4299 if let Err(e) = state
4300 .audit
4301 .log_entry(
4302 &action,
4303 &Verdict::Deny {
4304 reason: format!("WS structuredContent validation failed: {violations:?}"),
4305 },
4306 json!({"source": "ws_proxy", "event": "output_schema_violation_ws"}),
4307 )
4308 .await
4309 {
4310 tracing::warn!("Failed to audit WS output schema violation: {}", e);
4311 }
4312 true
4313 }
4314 ValidationResult::Valid => {
4315 tracing::debug!("WS structuredContent validated for tool '{}'", tool_name);
4316 false
4317 }
4318 ValidationResult::NoSchema => {
4319 tracing::debug!(
4320 "No output schema registered for WS tool '{}', skipping validation",
4321 tool_name
4322 );
4323 false
4324 }
4325 }
4326}
4327
4328// NOTE: build_ws_evaluation_context() was removed in FIND-R130-002 fix.
4329// All callers now build EvaluationContext inline inside the DashMap shard
4330// lock to prevent TOCTOU races on call_counts/action_history.
4331
4332/// Check per-connection rate limit. Returns true if within limit.
4333fn check_rate_limit(
4334 counter: &AtomicU64,
4335 window_start: &std::sync::Mutex<std::time::Instant>,
4336 max_per_sec: u32,
4337) -> bool {
4338 // SECURITY (FIND-R182-006): Fail-closed — zero rate limit blocks all messages.
4339 // Previously returned true (fail-open), which disabled rate limiting entirely.
4340 if max_per_sec == 0 {
4341 return false;
4342 }
4343
4344 let now = std::time::Instant::now();
4345 let mut start = match window_start.lock() {
4346 Ok(guard) => guard,
4347 Err(e) => {
4348 tracing::error!("WS rate limiter mutex poisoned — fail-closed: {}", e);
4349 return false;
4350 }
4351 };
4352
4353 if now.duration_since(*start) >= Duration::from_secs(1) {
4354 // Reset window
4355 *start = now;
4356 // SECURITY (FIND-R55-WS-003): Use SeqCst for security-critical rate limit counter.
4357 counter.store(1, Ordering::SeqCst);
4358 true
4359 } else {
4360 // SECURITY (FIND-R182-003): saturating arithmetic prevents overflow wrap-to-zero.
4361 // SECURITY (FIND-R155-WS-001): Conditional atomic increment — only increment if
4362 // within limit, reject otherwise. This eliminates the TOCTOU gap between
4363 // load()+compare and also prevents counter inflation from rejected requests.
4364 // The closure returns None when limit is reached, causing fetch_update to fail
4365 // without modifying the counter.
4366 let limit = max_per_sec as u64;
4367 match counter.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
4368 if v >= limit {
4369 None // Limit reached — do not increment
4370 } else {
4371 Some(v.saturating_add(1))
4372 }
4373 }) {
4374 Ok(_prev) => true, // Within limit, counter was incremented
4375 Err(_) => false, // Limit exceeded, counter unchanged
4376 }
4377 }
4378}
4379
4380/// Extract scannable text from a JSON-RPC request for injection scanning.
4381///
4382/// SECURITY (FIND-R46-WS-001): Scans tool call arguments, resource URIs,
4383/// and sampling request content for injection payloads in the client→upstream
4384/// direction. Matches the HTTP proxy's request-side injection scanning.
4385fn extract_scannable_text_from_request(json_val: &Value) -> String {
4386 let mut text_parts = Vec::new();
4387
4388 // SECURITY (FIND-R224-002): Recursively scan the entire `params` subtree,
4389 // not just specific fields. Previous narrow extraction missed injection
4390 // payloads in TaskRequest (params.message), ExtensionMethod, and other
4391 // message types with non-standard parameter structures. This gives parity
4392 // with the HTTP handler's `extract_passthrough_text_for_injection` which
4393 // scans all of params recursively.
4394 if let Some(params) = json_val.get("params") {
4395 extract_strings_recursive(params, &mut text_parts, 0);
4396 }
4397
4398 // SECURITY (FIND-R224-006): Also scan `result` field for injection payloads.
4399 // JSON-RPC response messages (sampling/elicitation replies) carry data in
4400 // `result` rather than `params`. Without this, upstream injection payloads
4401 // in response results bypass WS scanning while being caught by HTTP/gRPC.
4402 if let Some(result) = json_val.get("result") {
4403 extract_strings_recursive(result, &mut text_parts, 0);
4404 }
4405
4406 text_parts.join("\n")
4407}
4408
4409/// Recursively extract string values from a JSON value, with depth and count bounds.
4410///
4411/// SECURITY (FIND-R48-007): Added MAX_PARTS to prevent memory amplification
4412/// from messages containing many short strings.
4413fn extract_strings_recursive(val: &Value, parts: &mut Vec<String>, depth: usize) {
4414 // SECURITY (FIND-R154-005): Use depth 32 matching shared MAX_SCAN_DEPTH
4415 // in scanner_base.rs. Previous limit of 10 allowed injection payloads
4416 // nested between depth 11-32 to evade WS scanning while being caught
4417 // by the stdio relay and DLP scanner (both use MAX_SCAN_DEPTH=32).
4418 const MAX_DEPTH: usize = 32;
4419 const MAX_PARTS: usize = 1000;
4420 if depth > MAX_DEPTH || parts.len() >= MAX_PARTS {
4421 return;
4422 }
4423 match val {
4424 Value::String(s) => parts.push(s.clone()),
4425 Value::Array(arr) => {
4426 for item in arr {
4427 extract_strings_recursive(item, parts, depth + 1);
4428 }
4429 }
4430 Value::Object(map) => {
4431 for (key, v) in map {
4432 // SECURITY (FIND-R154-003): Also scan object keys for injection
4433 // payloads. Parity with stdio relay's traverse_json_strings_with_keys.
4434 // Without this, attackers can hide injection in JSON key names.
4435 if parts.len() < MAX_PARTS {
4436 parts.push(key.clone());
4437 }
4438 extract_strings_recursive(v, parts, depth + 1);
4439 }
4440 }
4441 _ => {}
4442 }
4443}
4444
4445/// Extract scannable text from a JSON-RPC response for injection scanning.
4446///
4447/// SECURITY (FIND-R130-004): Delegates to the shared `extract_text_from_result()`
4448/// which covers `resource.text`, `resource.blob` (base64-decoded), `annotations`,
4449/// and `_meta` — all missing from the previous WS-only implementation.
4450fn extract_scannable_text(json_val: &Value) -> String {
4451 let mut text_parts = Vec::new();
4452
4453 // Scan result via shared extraction (covers content[].text, resource.text,
4454 // resource.blob, annotations, instructionsForUser, structuredContent, _meta).
4455 if let Some(result) = json_val.get("result") {
4456 let result_text = super::inspection::extract_text_from_result(result);
4457 if !result_text.is_empty() {
4458 text_parts.push(result_text);
4459 }
4460 }
4461
4462 // Scan error messages (not covered by extract_text_from_result)
4463 if let Some(error) = json_val.get("error") {
4464 if let Some(msg) = error.get("message").and_then(|m| m.as_str()) {
4465 text_parts.push(msg.to_string());
4466 }
4467 // SECURITY (FIND-R168-005): Use as_str() first to avoid wrapping
4468 // string values in JSON quotes. Parity with scanner_base.rs.
4469 if let Some(data) = error.get("data") {
4470 if let Some(s) = data.as_str() {
4471 text_parts.push(s.to_string());
4472 } else {
4473 text_parts.push(data.to_string());
4474 }
4475 }
4476 }
4477
4478 text_parts.join("\n")
4479}
4480
4481/// Create a pending approval for WebSocket-denied actions when an approval store
4482/// is configured. Returns the pending approval ID on success.
4483async fn create_ws_approval(
4484 state: &ProxyState,
4485 session_id: &str,
4486 action: &Action,
4487 reason: &str,
4488) -> Option<String> {
4489 let store = state.approval_store.as_ref()?;
4490 let requested_by = state.sessions.get_mut(session_id).and_then(|session| {
4491 session
4492 .agent_identity
4493 .as_ref()
4494 .and_then(|identity| identity.subject.clone())
4495 .or_else(|| session.oauth_subject.clone())
4496 });
4497 match store
4498 .create(action.clone(), reason.to_string(), requested_by)
4499 .await
4500 {
4501 Ok(id) => Some(id),
4502 Err(e) => {
4503 tracing::error!(
4504 session_id = %session_id,
4505 "Failed to create WebSocket approval (fail-closed): {}",
4506 e
4507 );
4508 None
4509 }
4510 }
4511}
4512
4513/// Build a JSON-RPC error response string for WebSocket with optional `error.data`.
4514fn make_ws_error_response_with_data(
4515 id: Option<&Value>,
4516 code: i64,
4517 message: &str,
4518 data: Option<Value>,
4519) -> String {
4520 let mut error = serde_json::Map::new();
4521 error.insert("code".to_string(), Value::from(code));
4522 error.insert("message".to_string(), Value::from(message));
4523 if let Some(data) = data {
4524 error.insert("data".to_string(), data);
4525 }
4526 let response = json!({
4527 "jsonrpc": "2.0",
4528 "id": id.cloned().unwrap_or(Value::Null),
4529 "error": Value::Object(error),
4530 });
4531 serde_json::to_string(&response).unwrap_or_else(|_| {
4532 format!(r#"{{"jsonrpc":"2.0","error":{{"code":{code},"message":"{message}"}},"id":null}}"#)
4533 })
4534}
4535
4536/// Build a JSON-RPC error response string for WebSocket.
4537fn make_ws_error_response(id: Option<&Value>, code: i64, message: &str) -> String {
4538 make_ws_error_response_with_data(id, code, message, None)
4539}
4540
// Unit tests for this module. `mod tests;` loads them from a separate source
// file, and `#[cfg(test)]` compiles them only in test builds.
#[cfg(test)]
mod tests;