sentinel_agent_protocol/v2/
uds.rs

1//! Unix Domain Socket transport for Agent Protocol v2.
2//!
3//! This module provides a binary protocol implementation for v2 over UDS,
4//! supporting bidirectional streaming with connection multiplexing.
5//!
6//! # Wire Format
7//!
8//! All messages use a length-prefixed binary format:
9//! ```text
10//! +--------+--------+------------------+
11//! | Length | Type   | Payload          |
12//! | 4 bytes| 1 byte | variable         |
13//! | BE u32 | u8     | MessagePack/JSON |
14//! +--------+--------+------------------+
15//! ```
16//!
17//! # Message Types
18//!
19//! - 0x01: Handshake Request (proxy -> agent)
20//! - 0x02: Handshake Response (agent -> proxy)
21//! - 0x10: Request Headers Event
22//! - 0x11: Request Body Chunk Event
23//! - 0x12: Response Headers Event
24//! - 0x13: Response Body Chunk Event
25//! - 0x14: Request Complete Event
26//! - 0x15: WebSocket Frame Event
27//! - 0x16: Guardrail Inspect Event
28//! - 0x17: Configure Event
29//! - 0x20: Agent Response
30//! - 0x30: Health Status
31//! - 0x31: Metrics Report
32//! - 0x32: Config Update Request
33//! - 0x33: Flow Control Signal
34//! - 0x40: Cancel Request
35//! - 0x41: Ping
36//! - 0x42: Pong
37
38use std::collections::HashMap;
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use std::time::Duration;
42
43use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
44use tokio::net::UnixStream;
45use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
46use tracing::{debug, error, info, trace, warn};
47
48use crate::v2::pool::CHANNEL_BUFFER_SIZE;
49use crate::v2::{AgentCapabilities, AgentFeatures, AgentLimits, HealthConfig, PROTOCOL_VERSION_2};
50use crate::{AgentProtocolError, AgentResponse, EventType};
51
52use super::client::{ConfigUpdateCallback, FlowState, MetricsCallback};
53
/// Upper bound on the size of a single UDS transport message: 16 MiB, in bytes.
pub const MAX_UDS_MESSAGE_SIZE: usize = 16 << 20;
56
57/// Payload encoding for UDS transport.
58///
59/// Negotiated during handshake. The proxy sends its supported encodings,
60/// and the agent responds with the chosen encoding.
61#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
62#[serde(rename_all = "lowercase")]
63pub enum UdsEncoding {
64    /// JSON encoding (default, always supported)
65    #[default]
66    Json,
67    /// MessagePack binary encoding (requires `binary-uds` feature)
68    #[serde(rename = "msgpack")]
69    MessagePack,
70}
71
72impl UdsEncoding {
73    /// Serialize a value using this encoding.
74    ///
75    /// Returns the serialized bytes, or an error if serialization fails.
76    #[inline]
77    pub fn serialize<T: serde::Serialize>(&self, value: &T) -> Result<Vec<u8>, AgentProtocolError> {
78        match self {
79            UdsEncoding::Json => {
80                serde_json::to_vec(value)
81                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
82            }
83            #[cfg(feature = "binary-uds")]
84            UdsEncoding::MessagePack => {
85                rmp_serde::to_vec(value)
86                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
87            }
88            #[cfg(not(feature = "binary-uds"))]
89            UdsEncoding::MessagePack => {
90                // Fall back to JSON if binary-uds feature is not enabled
91                serde_json::to_vec(value)
92                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
93            }
94        }
95    }
96
97    /// Deserialize a value using this encoding.
98    ///
99    /// Returns the deserialized value, or an error if deserialization fails.
100    #[inline]
101    pub fn deserialize<'a, T: serde::Deserialize<'a>>(&self, bytes: &'a [u8]) -> Result<T, AgentProtocolError> {
102        match self {
103            UdsEncoding::Json => {
104                serde_json::from_slice(bytes)
105                    .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
106            }
107            #[cfg(feature = "binary-uds")]
108            UdsEncoding::MessagePack => {
109                rmp_serde::from_slice(bytes)
110                    .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
111            }
112            #[cfg(not(feature = "binary-uds"))]
113            UdsEncoding::MessagePack => {
114                // Fall back to JSON if binary-uds feature is not enabled
115                serde_json::from_slice(bytes)
116                    .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
117            }
118        }
119    }
120}
121
/// One-byte message type tag carried in every frame of the binary protocol.
///
/// The numeric values are part of the wire format and must not change.
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum MessageType {
    // ---- Handshake ----
    /// Handshake request (proxy -> agent).
    HandshakeRequest = 0x01,
    /// Handshake response (agent -> proxy).
    HandshakeResponse = 0x02,

    // ---- Events (proxy -> agent) ----
    /// Request headers event.
    RequestHeaders = 0x10,
    /// Request body chunk event.
    RequestBodyChunk = 0x11,
    /// Response headers event.
    ResponseHeaders = 0x12,
    /// Response body chunk event.
    ResponseBodyChunk = 0x13,
    /// Request complete event.
    RequestComplete = 0x14,
    /// WebSocket frame event.
    WebSocketFrame = 0x15,
    /// Guardrail inspect event.
    GuardrailInspect = 0x16,
    /// Configure event.
    Configure = 0x17,

    // ---- Response (agent -> proxy) ----
    /// Agent response to an event.
    AgentResponse = 0x20,

    // ---- Control messages (bidirectional) ----
    /// Health status.
    HealthStatus = 0x30,
    /// Metrics report.
    MetricsReport = 0x31,
    /// Config update request.
    ConfigUpdateRequest = 0x32,
    /// Flow control signal.
    FlowControl = 0x33,

    // ---- Management ----
    /// Cancel request.
    Cancel = 0x40,
    /// Ping.
    Ping = 0x41,
    /// Pong.
    Pong = 0x42,
}
154
155impl TryFrom<u8> for MessageType {
156    type Error = AgentProtocolError;
157
158    fn try_from(value: u8) -> Result<Self, Self::Error> {
159        match value {
160            0x01 => Ok(MessageType::HandshakeRequest),
161            0x02 => Ok(MessageType::HandshakeResponse),
162            0x10 => Ok(MessageType::RequestHeaders),
163            0x11 => Ok(MessageType::RequestBodyChunk),
164            0x12 => Ok(MessageType::ResponseHeaders),
165            0x13 => Ok(MessageType::ResponseBodyChunk),
166            0x14 => Ok(MessageType::RequestComplete),
167            0x15 => Ok(MessageType::WebSocketFrame),
168            0x16 => Ok(MessageType::GuardrailInspect),
169            0x17 => Ok(MessageType::Configure),
170            0x20 => Ok(MessageType::AgentResponse),
171            0x30 => Ok(MessageType::HealthStatus),
172            0x31 => Ok(MessageType::MetricsReport),
173            0x32 => Ok(MessageType::ConfigUpdateRequest),
174            0x33 => Ok(MessageType::FlowControl),
175            0x40 => Ok(MessageType::Cancel),
176            0x41 => Ok(MessageType::Ping),
177            0x42 => Ok(MessageType::Pong),
178            _ => Err(AgentProtocolError::InvalidMessage(format!(
179                "Unknown message type: 0x{:02x}",
180                value
181            ))),
182        }
183    }
184}
185
186/// Handshake request sent from proxy to agent over UDS.
187#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
188pub struct UdsHandshakeRequest {
189    pub supported_versions: Vec<u32>,
190    pub proxy_id: String,
191    pub proxy_version: String,
192    pub config: Option<serde_json::Value>,
193    /// Supported payload encodings (in order of preference).
194    /// If empty or missing, only JSON is supported.
195    #[serde(default, skip_serializing_if = "Vec::is_empty")]
196    pub supported_encodings: Vec<UdsEncoding>,
197}
198
199/// Handshake response from agent to proxy over UDS.
200#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
201pub struct UdsHandshakeResponse {
202    pub protocol_version: u32,
203    pub capabilities: UdsCapabilities,
204    pub success: bool,
205    pub error: Option<String>,
206    /// Negotiated encoding for subsequent messages.
207    /// If missing, defaults to JSON for backwards compatibility.
208    #[serde(default)]
209    pub encoding: UdsEncoding,
210}
211
212/// Agent capabilities for UDS protocol.
213#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
214pub struct UdsCapabilities {
215    pub agent_id: String,
216    pub name: String,
217    pub version: String,
218    pub supported_events: Vec<i32>,
219    pub features: UdsFeatures,
220    pub limits: UdsLimits,
221}
222
223/// Agent features.
224#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
225pub struct UdsFeatures {
226    pub streaming_body: bool,
227    pub websocket: bool,
228    pub guardrails: bool,
229    pub config_push: bool,
230    pub metrics_export: bool,
231    pub concurrent_requests: u32,
232    pub cancellation: bool,
233    pub flow_control: bool,
234    pub health_reporting: bool,
235}
236
237/// Agent limits.
238#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
239pub struct UdsLimits {
240    pub max_body_size: u64,
241    pub max_concurrency: u32,
242    pub preferred_chunk_size: u64,
243}
244
245impl From<UdsCapabilities> for AgentCapabilities {
246    fn from(caps: UdsCapabilities) -> Self {
247        AgentCapabilities {
248            protocol_version: PROTOCOL_VERSION_2 as u32,
249            agent_id: caps.agent_id,
250            name: caps.name,
251            version: caps.version,
252            supported_events: caps
253                .supported_events
254                .into_iter()
255                .filter_map(event_type_from_i32)
256                .collect(),
257            features: AgentFeatures {
258                streaming_body: caps.features.streaming_body,
259                websocket: caps.features.websocket,
260                guardrails: caps.features.guardrails,
261                config_push: caps.features.config_push,
262                metrics_export: caps.features.metrics_export,
263                concurrent_requests: caps.features.concurrent_requests,
264                cancellation: caps.features.cancellation,
265                flow_control: caps.features.flow_control,
266                health_reporting: caps.features.health_reporting,
267            },
268            limits: AgentLimits {
269                max_body_size: caps.limits.max_body_size as usize,
270                max_concurrency: caps.limits.max_concurrency,
271                preferred_chunk_size: caps.limits.preferred_chunk_size as usize,
272                max_memory: None,
273                max_processing_time_ms: None,
274            },
275            health: HealthConfig::default(),
276        }
277    }
278}
279
280/// Convert i32 to EventType.
281fn event_type_from_i32(value: i32) -> Option<EventType> {
282    match value {
283        0 => Some(EventType::Configure),
284        1 => Some(EventType::RequestHeaders),
285        2 => Some(EventType::RequestBodyChunk),
286        3 => Some(EventType::ResponseHeaders),
287        4 => Some(EventType::ResponseBodyChunk),
288        5 => Some(EventType::RequestComplete),
289        6 => Some(EventType::WebSocketFrame),
290        7 => Some(EventType::GuardrailInspect),
291        _ => None,
292    }
293}
294
295/// v2 agent client over Unix Domain Socket.
296///
297/// This client maintains a single connection and multiplexes multiple requests
298/// over it using correlation IDs, similar to the gRPC client.
299pub struct AgentClientV2Uds {
300    /// Agent identifier
301    agent_id: String,
302    /// Socket path
303    socket_path: String,
304    /// Request timeout
305    timeout: Duration,
306    /// Negotiated capabilities
307    capabilities: RwLock<Option<AgentCapabilities>>,
308    /// Negotiated protocol version
309    protocol_version: AtomicU64,
310    /// Negotiated payload encoding
311    encoding: RwLock<UdsEncoding>,
312    /// Pending requests by correlation ID
313    pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
314    /// Sender for outbound messages
315    outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
316    /// Sequence counter for pings
317    ping_sequence: AtomicU64,
318    /// Connection state
319    connected: RwLock<bool>,
320    /// Flow control state
321    flow_state: RwLock<FlowState>,
322    /// Last known health state
323    health_state: RwLock<i32>,
324    /// In-flight request count
325    in_flight: AtomicU64,
326    /// Callback for metrics reports
327    metrics_callback: Option<MetricsCallback>,
328    /// Callback for config update requests
329    config_update_callback: Option<ConfigUpdateCallback>,
330}
331
332impl AgentClientV2Uds {
333    /// Create a new UDS v2 client.
334    pub async fn new(
335        agent_id: impl Into<String>,
336        socket_path: impl Into<String>,
337        timeout: Duration,
338    ) -> Result<Self, AgentProtocolError> {
339        let agent_id = agent_id.into();
340        let socket_path = socket_path.into();
341
342        debug!(
343            agent_id = %agent_id,
344            socket_path = %socket_path,
345            timeout_ms = timeout.as_millis(),
346            "Creating UDS v2 client"
347        );
348
349        Ok(Self {
350            agent_id,
351            socket_path,
352            timeout,
353            capabilities: RwLock::new(None),
354            protocol_version: AtomicU64::new(0),
355            encoding: RwLock::new(UdsEncoding::Json),
356            pending: Arc::new(Mutex::new(HashMap::new())),
357            outbound_tx: Mutex::new(None),
358            ping_sequence: AtomicU64::new(0),
359            connected: RwLock::new(false),
360            flow_state: RwLock::new(FlowState::Normal),
361            health_state: RwLock::new(1), // HEALTHY
362            in_flight: AtomicU64::new(0),
363            metrics_callback: None,
364            config_update_callback: None,
365        })
366    }
367
368    /// Returns the list of supported encodings for this client.
369    ///
370    /// When compiled with `binary-uds` feature, MessagePack is preferred.
371    fn supported_encodings() -> Vec<UdsEncoding> {
372        #[cfg(feature = "binary-uds")]
373        {
374            vec![UdsEncoding::MessagePack, UdsEncoding::Json]
375        }
376        #[cfg(not(feature = "binary-uds"))]
377        {
378            vec![UdsEncoding::Json]
379        }
380    }
381
382    /// Get the current negotiated encoding.
383    pub async fn encoding(&self) -> UdsEncoding {
384        *self.encoding.read().await
385    }
386
387    /// Set the metrics callback.
388    pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
389        self.metrics_callback = Some(callback);
390    }
391
392    /// Set the config update callback.
393    pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
394        self.config_update_callback = Some(callback);
395    }
396
397    /// Connect and perform handshake.
398    pub async fn connect(&self) -> Result<(), AgentProtocolError> {
399        info!(
400            agent_id = %self.agent_id,
401            socket_path = %self.socket_path,
402            "Connecting to agent via UDS v2"
403        );
404
405        // Connect to Unix socket
406        let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
407            error!(
408                agent_id = %self.agent_id,
409                socket_path = %self.socket_path,
410                error = %e,
411                "Failed to connect to agent via UDS"
412            );
413            AgentProtocolError::ConnectionFailed(e.to_string())
414        })?;
415
416        let (read_half, write_half) = stream.into_split();
417        let mut reader = BufReader::new(read_half);
418        let mut writer = BufWriter::new(write_half);
419
420        // Send handshake request with supported encodings
421        let handshake_req = UdsHandshakeRequest {
422            supported_versions: vec![PROTOCOL_VERSION_2 as u32],
423            proxy_id: "sentinel-proxy".to_string(),
424            proxy_version: env!("CARGO_PKG_VERSION").to_string(),
425            config: None,
426            supported_encodings: Self::supported_encodings(),
427        };
428
429        // Handshake always uses JSON (before encoding is negotiated)
430        let payload = serde_json::to_vec(&handshake_req)
431            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
432
433        write_message(&mut writer, MessageType::HandshakeRequest, &payload).await?;
434
435        // Read handshake response (always JSON)
436        let (msg_type, response_bytes) = read_message(&mut reader).await?;
437
438        if msg_type != MessageType::HandshakeResponse {
439            return Err(AgentProtocolError::InvalidMessage(format!(
440                "Expected HandshakeResponse, got {:?}",
441                msg_type
442            )));
443        }
444
445        let response: UdsHandshakeResponse = serde_json::from_slice(&response_bytes)
446            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
447
448        if !response.success {
449            return Err(AgentProtocolError::ConnectionFailed(
450                response.error.unwrap_or_else(|| "Unknown handshake error".to_string()),
451            ));
452        }
453
454        // Store capabilities and negotiated encoding
455        let capabilities: AgentCapabilities = response.capabilities.into();
456        *self.capabilities.write().await = Some(capabilities);
457        self.protocol_version
458            .store(response.protocol_version as u64, Ordering::SeqCst);
459
460        // Store the negotiated encoding for subsequent messages
461        let negotiated_encoding = response.encoding;
462        *self.encoding.write().await = negotiated_encoding;
463
464        info!(
465            agent_id = %self.agent_id,
466            protocol_version = response.protocol_version,
467            encoding = ?negotiated_encoding,
468            "UDS v2 handshake successful"
469        );
470
471        // Create message channel
472        let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
473        *self.outbound_tx.lock().await = Some(tx);
474        *self.connected.write().await = true;
475
476        // Spawn writer task
477        let agent_id_clone = self.agent_id.clone();
478        tokio::spawn(async move {
479            while let Some((msg_type, payload)) = rx.recv().await {
480                if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
481                    error!(
482                        agent_id = %agent_id_clone,
483                        error = %e,
484                        "Failed to write message to UDS"
485                    );
486                    break;
487                }
488            }
489            debug!(agent_id = %agent_id_clone, "UDS writer task ended");
490        });
491
492        // Spawn reader task with the negotiated encoding
493        let pending = Arc::clone(&self.pending);
494        let agent_id = self.agent_id.clone();
495        let flow_state = Arc::new(RwLock::new(FlowState::Normal));
496        let health_state = Arc::new(RwLock::new(1i32));
497        let flow_state_clone = Arc::clone(&flow_state);
498        let health_state_clone = Arc::clone(&health_state);
499        let metrics_callback = self.metrics_callback.clone();
500        let config_update_callback = self.config_update_callback.clone();
501        // Encoding is fixed after handshake, so we can copy it
502        let reader_encoding = negotiated_encoding;
503
504        tokio::spawn(async move {
505            loop {
506                match read_message(&mut reader).await {
507                    Ok((msg_type, payload)) => {
508                        match msg_type {
509                            MessageType::AgentResponse => {
510                                match reader_encoding.deserialize::<AgentResponse>(&payload) {
511                                    Ok(response) => {
512                                        // Extract correlation ID from the response
513                                        // For UDS, we include correlation_id in the response
514                                        if let Some(sender) = pending.lock().await.remove(&response.audit.custom.get("correlation_id").and_then(|v| v.as_str()).unwrap_or("").to_string()) {
515                                            let _ = sender.send(response);
516                                        }
517                                    }
518                                    Err(e) => {
519                                        warn!(
520                                            agent_id = %agent_id,
521                                            error = %e,
522                                            encoding = ?reader_encoding,
523                                            "Failed to parse agent response"
524                                        );
525                                    }
526                                }
527                            }
528                            MessageType::HealthStatus => {
529                                // Health status uses a simple struct, try both encodings for robustness
530                                #[derive(serde::Deserialize)]
531                                struct HealthStatusMsg {
532                                    state: Option<i64>,
533                                }
534                                if let Ok(health) = reader_encoding.deserialize::<HealthStatusMsg>(&payload) {
535                                    if let Some(state) = health.state {
536                                        *health_state_clone.write().await = state as i32;
537                                    }
538                                }
539                            }
540                            MessageType::MetricsReport => {
541                                if let Some(ref callback) = metrics_callback {
542                                    if let Ok(report) = reader_encoding.deserialize(&payload) {
543                                        callback(report);
544                                    }
545                                }
546                            }
547                            MessageType::FlowControl => {
548                                #[derive(serde::Deserialize)]
549                                struct FlowControlMsg {
550                                    action: Option<i64>,
551                                }
552                                if let Ok(fc) = reader_encoding.deserialize::<FlowControlMsg>(&payload) {
553                                    let action = fc.action.unwrap_or(0);
554                                    let new_state = match action {
555                                        1 => FlowState::Paused,
556                                        2 => FlowState::Normal,
557                                        _ => FlowState::Normal,
558                                    };
559                                    *flow_state_clone.write().await = new_state;
560                                }
561                            }
562                            MessageType::ConfigUpdateRequest => {
563                                if let Some(ref callback) = config_update_callback {
564                                    if let Ok(request) = reader_encoding.deserialize(&payload) {
565                                        let _response = callback(agent_id.clone(), request);
566                                    }
567                                }
568                            }
569                            MessageType::Pong => {
570                                trace!(agent_id = %agent_id, "Received pong");
571                            }
572                            _ => {
573                                trace!(
574                                    agent_id = %agent_id,
575                                    msg_type = ?msg_type,
576                                    "Received unhandled message type"
577                                );
578                            }
579                        }
580                    }
581                    Err(e) => {
582                        if !matches!(e, AgentProtocolError::ConnectionClosed) {
583                            error!(
584                                agent_id = %agent_id,
585                                error = %e,
586                                "Error reading from UDS"
587                            );
588                        }
589                        break;
590                    }
591                }
592            }
593            debug!(agent_id = %agent_id, "UDS reader task ended");
594        });
595
596        Ok(())
597    }
598
599    /// Get negotiated capabilities.
600    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
601        self.capabilities.read().await.clone()
602    }
603
604    /// Check if connected.
605    pub async fn is_connected(&self) -> bool {
606        *self.connected.read().await
607    }
608
609    /// Send a request headers event.
610    pub async fn send_request_headers(
611        &self,
612        correlation_id: &str,
613        event: &crate::RequestHeadersEvent,
614    ) -> Result<AgentResponse, AgentProtocolError> {
615        self.send_event(MessageType::RequestHeaders, correlation_id, event).await
616    }
617
618    /// Send a request body chunk event.
619    pub async fn send_request_body_chunk(
620        &self,
621        correlation_id: &str,
622        event: &crate::RequestBodyChunkEvent,
623    ) -> Result<AgentResponse, AgentProtocolError> {
624        self.send_event(MessageType::RequestBodyChunk, correlation_id, event).await
625    }
626
627    /// Send a response headers event.
628    pub async fn send_response_headers(
629        &self,
630        correlation_id: &str,
631        event: &crate::ResponseHeadersEvent,
632    ) -> Result<AgentResponse, AgentProtocolError> {
633        self.send_event(MessageType::ResponseHeaders, correlation_id, event).await
634    }
635
636    /// Send a response body chunk event.
637    pub async fn send_response_body_chunk(
638        &self,
639        correlation_id: &str,
640        event: &crate::ResponseBodyChunkEvent,
641    ) -> Result<AgentResponse, AgentProtocolError> {
642        self.send_event(MessageType::ResponseBodyChunk, correlation_id, event).await
643    }
644
645    /// Send a binary request body chunk event (zero-copy path).
646    ///
647    /// This method avoids base64 encoding when using MessagePack encoding,
648    /// sending raw bytes directly over the wire for better throughput.
649    ///
650    /// # Performance
651    ///
652    /// When MessagePack encoding is negotiated:
653    /// - Bytes are serialized directly (no base64 encode/decode)
654    /// - Reduces CPU usage and latency for large bodies
655    ///
656    /// When JSON encoding is used:
657    /// - Falls back to base64 encoding for JSON compatibility
658    pub async fn send_request_body_chunk_binary(
659        &self,
660        event: &crate::BinaryRequestBodyChunkEvent,
661    ) -> Result<AgentResponse, AgentProtocolError> {
662        let correlation_id = &event.correlation_id;
663        self.send_binary_body_chunk(
664            MessageType::RequestBodyChunk,
665            correlation_id,
666            &event.data,
667            event.is_last,
668            event.total_size,
669            event.chunk_index,
670            Some(event.bytes_received),
671            None,
672        ).await
673    }
674
675    /// Send a binary response body chunk event (zero-copy path).
676    ///
677    /// This method avoids base64 encoding when using MessagePack encoding,
678    /// sending raw bytes directly over the wire for better throughput.
679    pub async fn send_response_body_chunk_binary(
680        &self,
681        event: &crate::BinaryResponseBodyChunkEvent,
682    ) -> Result<AgentResponse, AgentProtocolError> {
683        let correlation_id = &event.correlation_id;
684        self.send_binary_body_chunk(
685            MessageType::ResponseBodyChunk,
686            correlation_id,
687            &event.data,
688            event.is_last,
689            event.total_size,
690            event.chunk_index,
691            None,
692            Some(event.bytes_sent),
693        ).await
694    }
695
696    /// Internal helper to send binary body chunks with encoding-aware serialization.
697    #[allow(clippy::too_many_arguments)]
698    async fn send_binary_body_chunk(
699        &self,
700        msg_type: MessageType,
701        correlation_id: &str,
702        data: &bytes::Bytes,
703        is_last: bool,
704        total_size: Option<usize>,
705        chunk_index: u32,
706        bytes_received: Option<usize>,
707        bytes_sent: Option<usize>,
708    ) -> Result<AgentResponse, AgentProtocolError> {
709        // Create response channel
710        let (tx, rx) = oneshot::channel();
711        self.pending
712            .lock()
713            .await
714            .insert(correlation_id.to_string(), tx);
715
716        // Get the current encoding
717        let encoding = *self.encoding.read().await;
718
719        // Serialize body chunk using encoding-optimized format
720        let payload_bytes = match encoding {
721            UdsEncoding::Json => {
722                // JSON path: must use base64 encoding for binary data
723                use base64::{engine::general_purpose::STANDARD, Engine as _};
724                let json = serde_json::json!({
725                    "correlation_id": correlation_id,
726                    "data": STANDARD.encode(data),
727                    "is_last": is_last,
728                    "total_size": total_size,
729                    "chunk_index": chunk_index,
730                    "bytes_received": bytes_received,
731                    "bytes_sent": bytes_sent,
732                });
733                serde_json::to_vec(&json)
734                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
735            }
736            UdsEncoding::MessagePack => {
737                // MessagePack path: raw bytes via serde_bytes for zero-copy serialization
738                #[derive(serde::Serialize)]
739                struct BinaryBodyChunk<'a> {
740                    correlation_id: &'a str,
741                    #[serde(with = "serde_bytes")]
742                    data: &'a [u8],
743                    is_last: bool,
744                    #[serde(skip_serializing_if = "Option::is_none")]
745                    total_size: Option<usize>,
746                    chunk_index: u32,
747                    #[serde(skip_serializing_if = "Option::is_none")]
748                    bytes_received: Option<usize>,
749                    #[serde(skip_serializing_if = "Option::is_none")]
750                    bytes_sent: Option<usize>,
751                }
752                let chunk = BinaryBodyChunk {
753                    correlation_id,
754                    data: data.as_ref(),
755                    is_last,
756                    total_size,
757                    chunk_index,
758                    bytes_received,
759                    bytes_sent,
760                };
761                encoding.serialize(&chunk)?
762            }
763        };
764
765        // Send message
766        {
767            let outbound = self.outbound_tx.lock().await;
768            if let Some(tx) = outbound.as_ref() {
769                tx.send((msg_type, payload_bytes))
770                    .await
771                    .map_err(|_| AgentProtocolError::ConnectionClosed)?;
772            } else {
773                return Err(AgentProtocolError::ConnectionClosed);
774            }
775        }
776
777        self.in_flight.fetch_add(1, Ordering::Relaxed);
778
779        // Wait for response with timeout
780        let response = tokio::time::timeout(self.timeout, rx)
781            .await
782            .map_err(|_| {
783                self.pending.try_lock().ok().map(|mut p| p.remove(correlation_id));
784                AgentProtocolError::Timeout(self.timeout)
785            })?
786            .map_err(|_| AgentProtocolError::ConnectionClosed)?;
787
788        self.in_flight.fetch_sub(1, Ordering::Relaxed);
789
790        Ok(response)
791    }
792
793    /// Send an event and wait for response.
794    async fn send_event<T: serde::Serialize>(
795        &self,
796        msg_type: MessageType,
797        correlation_id: &str,
798        event: &T,
799    ) -> Result<AgentResponse, AgentProtocolError> {
800        // Create response channel
801        let (tx, rx) = oneshot::channel();
802        self.pending
803            .lock()
804            .await
805            .insert(correlation_id.to_string(), tx);
806
807        // Get the current encoding
808        let encoding = *self.encoding.read().await;
809
810        // Serialize event using negotiated encoding
811        let payload_bytes = match encoding {
812            UdsEncoding::Json => {
813                // JSON path: use Value mutation for backwards compatibility
814                let mut payload = serde_json::to_value(event)
815                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
816                if let Some(obj) = payload.as_object_mut() {
817                    obj.insert(
818                        "correlation_id".to_string(),
819                        serde_json::Value::String(correlation_id.to_string()),
820                    );
821                }
822                serde_json::to_vec(&payload)
823                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
824            }
825            UdsEncoding::MessagePack => {
826                // MessagePack path: use wrapper struct for efficient serialization
827                #[derive(serde::Serialize)]
828                struct EventWithCorrelation<'a, T: serde::Serialize> {
829                    correlation_id: &'a str,
830                    #[serde(flatten)]
831                    event: &'a T,
832                }
833                let wrapped = EventWithCorrelation {
834                    correlation_id,
835                    event,
836                };
837                encoding.serialize(&wrapped)?
838            }
839        };
840
841        // Send message
842        {
843            let outbound = self.outbound_tx.lock().await;
844            if let Some(tx) = outbound.as_ref() {
845                tx.send((msg_type, payload_bytes))
846                    .await
847                    .map_err(|_| AgentProtocolError::ConnectionClosed)?;
848            } else {
849                return Err(AgentProtocolError::ConnectionClosed);
850            }
851        }
852
853        self.in_flight.fetch_add(1, Ordering::Relaxed);
854
855        // Wait for response with timeout
856        let response = tokio::time::timeout(self.timeout, rx)
857            .await
858            .map_err(|_| {
859                self.pending.try_lock().ok().map(|mut p| p.remove(correlation_id));
860                AgentProtocolError::Timeout(self.timeout)
861            })?
862            .map_err(|_| AgentProtocolError::ConnectionClosed)?;
863
864        self.in_flight.fetch_sub(1, Ordering::Relaxed);
865
866        Ok(response)
867    }
868
869    /// Send a cancel request for a specific correlation ID.
870    pub async fn cancel_request(
871        &self,
872        correlation_id: &str,
873        reason: super::client::CancelReason,
874    ) -> Result<(), AgentProtocolError> {
875        let cancel = serde_json::json!({
876            "correlation_id": correlation_id,
877            "reason": reason as i32,
878            "timestamp_ms": now_ms(),
879        });
880
881        let payload = serde_json::to_vec(&cancel)
882            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
883
884        let outbound = self.outbound_tx.lock().await;
885        if let Some(tx) = outbound.as_ref() {
886            tx.send((MessageType::Cancel, payload))
887                .await
888                .map_err(|_| AgentProtocolError::ConnectionClosed)?;
889        }
890
891        // Remove pending request
892        self.pending.lock().await.remove(correlation_id);
893
894        Ok(())
895    }
896
897    /// Cancel all in-flight requests.
898    pub async fn cancel_all(
899        &self,
900        reason: super::client::CancelReason,
901    ) -> Result<usize, AgentProtocolError> {
902        let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
903        let count = pending_ids.len();
904
905        for correlation_id in pending_ids {
906            let _ = self.cancel_request(&correlation_id, reason).await;
907        }
908
909        Ok(count)
910    }
911
912    /// Send a ping.
913    pub async fn ping(&self) -> Result<(), AgentProtocolError> {
914        let seq = self.ping_sequence.fetch_add(1, Ordering::Relaxed);
915        let ping = serde_json::json!({
916            "sequence": seq,
917            "timestamp_ms": now_ms(),
918        });
919
920        let payload = serde_json::to_vec(&ping)
921            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
922
923        let outbound = self.outbound_tx.lock().await;
924        if let Some(tx) = outbound.as_ref() {
925            tx.send((MessageType::Ping, payload))
926                .await
927                .map_err(|_| AgentProtocolError::ConnectionClosed)?;
928        }
929
930        Ok(())
931    }
932
933    /// Close the connection.
934    pub async fn close(&self) -> Result<(), AgentProtocolError> {
935        *self.connected.write().await = false;
936        *self.outbound_tx.lock().await = None;
937        Ok(())
938    }
939
940    /// Get in-flight request count.
941    pub fn in_flight(&self) -> u64 {
942        self.in_flight.load(Ordering::Relaxed)
943    }
944
945    /// Get agent ID.
946    pub fn agent_id(&self) -> &str {
947        &self.agent_id
948    }
949
950    /// Check if the agent has requested flow control pause.
951    ///
952    /// Returns true if the agent sent a `FlowAction::Pause` signal,
953    /// indicating it cannot accept more requests.
954    pub async fn is_paused(&self) -> bool {
955        matches!(*self.flow_state.read().await, FlowState::Paused)
956    }
957
958    /// Check if the transport can accept new requests.
959    ///
960    /// Returns false if the agent has requested a flow control pause.
961    pub async fn can_accept_requests(&self) -> bool {
962        !self.is_paused().await
963    }
964}
965
966/// Write a message to the stream.
967pub async fn write_message<W: AsyncWriteExt + Unpin>(
968    writer: &mut W,
969    msg_type: MessageType,
970    payload: &[u8],
971) -> Result<(), AgentProtocolError> {
972    if payload.len() > MAX_UDS_MESSAGE_SIZE {
973        return Err(AgentProtocolError::MessageTooLarge {
974            size: payload.len(),
975            max: MAX_UDS_MESSAGE_SIZE,
976        });
977    }
978
979    // Write length (4 bytes, big-endian) - includes type byte
980    let total_len = (payload.len() + 1) as u32;
981    writer.write_all(&total_len.to_be_bytes()).await?;
982
983    // Write message type (1 byte)
984    writer.write_all(&[msg_type as u8]).await?;
985
986    // Write payload
987    writer.write_all(payload).await?;
988    writer.flush().await?;
989
990    Ok(())
991}
992
993/// Read a message from the stream.
994pub async fn read_message<R: AsyncReadExt + Unpin>(
995    reader: &mut R,
996) -> Result<(MessageType, Vec<u8>), AgentProtocolError> {
997    // Read length (4 bytes, big-endian)
998    let mut len_bytes = [0u8; 4];
999    match reader.read_exact(&mut len_bytes).await {
1000        Ok(_) => {}
1001        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
1002            return Err(AgentProtocolError::ConnectionClosed);
1003        }
1004        Err(e) => return Err(e.into()),
1005    }
1006
1007    let total_len = u32::from_be_bytes(len_bytes) as usize;
1008
1009    if total_len == 0 {
1010        return Err(AgentProtocolError::InvalidMessage(
1011            "Zero-length message".to_string(),
1012        ));
1013    }
1014
1015    if total_len > MAX_UDS_MESSAGE_SIZE {
1016        return Err(AgentProtocolError::MessageTooLarge {
1017            size: total_len,
1018            max: MAX_UDS_MESSAGE_SIZE,
1019        });
1020    }
1021
1022    // Read message type (1 byte)
1023    let mut type_byte = [0u8; 1];
1024    reader.read_exact(&mut type_byte).await?;
1025    let msg_type = MessageType::try_from(type_byte[0])?;
1026
1027    // Read payload
1028    let payload_len = total_len - 1;
1029    let mut payload = vec![0u8; payload_len];
1030    if payload_len > 0 {
1031        reader.read_exact(&mut payload).await?;
1032    }
1033
1034    Ok((msg_type, payload))
1035}
1036
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1043
#[cfg(test)]
mod tests {
    use super::*;

    /// A representative set of message types converts to `u8` and back unchanged.
    #[test]
    fn test_message_type_roundtrip() {
        for expected in [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::RequestHeaders,
            MessageType::AgentResponse,
            MessageType::HealthStatus,
            MessageType::Ping,
            MessageType::Pong,
        ] {
            let decoded = MessageType::try_from(expected as u8).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    /// A byte that maps to no message type is rejected.
    #[test]
    fn test_invalid_message_type() {
        assert!(MessageType::try_from(0xFF).is_err());
    }

    /// A handshake request survives a JSON round trip.
    #[test]
    fn test_handshake_serialization() {
        let request = UdsHandshakeRequest {
            supported_versions: vec![2],
            proxy_id: String::from("test-proxy"),
            proxy_version: String::from("1.0.0"),
            config: None,
            supported_encodings: vec![],
        };

        let encoded = serde_json::to_string(&request).unwrap();
        let decoded: UdsHandshakeRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.supported_versions, vec![2]);
        assert_eq!(decoded.proxy_id, "test-proxy");
    }

    /// A frame written with `write_message` is read back intact by `read_message`.
    #[tokio::test]
    async fn test_write_read_message() {
        let (mut proxy_end, mut agent_end) = tokio::io::duplex(1024);

        // Write on one end of the in-memory pipe...
        let body = b"test payload";
        write_message(&mut proxy_end, MessageType::Ping, body)
            .await
            .unwrap();

        // ...and read the same frame from the other end.
        let (kind, received) = read_message(&mut agent_end).await.unwrap();
        assert_eq!(kind, MessageType::Ping);
        assert_eq!(received, body);
    }

    /// Binary data carried as base64 inside JSON decodes back to the original bytes.
    #[test]
    fn test_binary_body_chunk_json_serialization() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let correlation_id = "test-123";
        let data = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");

        // JSON encoding must use base64
        let document = serde_json::json!({
            "correlation_id": correlation_id,
            "data": STANDARD.encode(&data),
            "is_last": true,
            "total_size": 100usize,
            "chunk_index": 0u32,
            "bytes_received": 100usize,
        });

        let wire = serde_json::to_vec(&document).unwrap();
        let reparsed: serde_json::Value = serde_json::from_slice(&wire).unwrap();

        // Verify base64 can be decoded back to original
        let encoded_data = reparsed["data"].as_str().unwrap();
        let round_tripped = STANDARD.decode(encoded_data).unwrap();
        assert_eq!(round_tripped, data.as_ref());
    }

    /// Raw bytes survive a MessagePack round trip and encode smaller than JSON+base64.
    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_binary_body_chunk_msgpack_serialization() {
        use base64::Engine as _;

        // MessagePack uses serde_bytes for efficient serialization
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct BinaryBodyChunk {
            correlation_id: String,
            #[serde(with = "serde_bytes")]
            data: Vec<u8>,
            is_last: bool,
            chunk_index: u32,
        }

        let correlation_id = "test-123";
        let data = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");

        let original = BinaryBodyChunk {
            correlation_id: correlation_id.to_string(),
            data: data.to_vec(),
            is_last: true,
            chunk_index: 0,
        };

        // Serialize with MessagePack, then deserialize and verify
        let packed = rmp_serde::to_vec(&original).unwrap();
        let unpacked: BinaryBodyChunk = rmp_serde::from_slice(&packed).unwrap();
        assert_eq!(unpacked.correlation_id, correlation_id);
        assert_eq!(unpacked.data, data.as_ref());
        assert!(unpacked.is_last);

        // Verify MessagePack is more compact than JSON+base64 for binary data
        let json_equivalent = serde_json::json!({
            "correlation_id": correlation_id,
            "data": base64::engine::general_purpose::STANDARD.encode(&data),
            "is_last": true,
            "chunk_index": 0u32,
        });
        let json_size = serde_json::to_vec(&json_equivalent).unwrap().len();

        // MessagePack should be smaller (raw bytes vs base64 ~33% overhead)
        assert!(
            packed.len() < json_size,
            "MessagePack ({}) should be smaller than JSON+base64 ({})",
            packed.len(),
            json_size
        );
    }

    /// JSON is the default encoding.
    #[test]
    fn test_uds_encoding_default() {
        assert_eq!(UdsEncoding::default(), UdsEncoding::Json);
    }

    /// The JSON encoding produces bytes that parse back to the same value.
    #[test]
    fn test_uds_encoding_serialize_json() {
        let input = serde_json::json!({"key": "value"});
        let bytes = UdsEncoding::Json.serialize(&input).unwrap();
        let output: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(output, input);
    }

    /// The MessagePack encoding produces bytes that parse back to the same value.
    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_uds_encoding_serialize_msgpack() {
        let input = serde_json::json!({"key": "value"});
        let bytes = UdsEncoding::MessagePack.serialize(&input).unwrap();
        // Verify it's valid MessagePack by deserializing
        let output: serde_json::Value = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(output, input);
    }
}