Skip to main content

sentinel_agent_protocol/v2/
uds.rs

//! Unix Domain Socket transport for Agent Protocol v2.
//!
//! This module provides a binary protocol implementation for v2 over UDS,
//! supporting bidirectional streaming with connection multiplexing.
//!
//! # Wire Format
//!
//! All messages use a length-prefixed binary format:
//! ```text
//! +--------+--------+------------------+
//! | Length | Type   | Payload          |
//! | 4 bytes| 1 byte | variable         |
//! | BE u32 | u8     | MessagePack/JSON |
//! +--------+--------+------------------+
//! ```
//!
//! The handshake request and response payloads are always JSON; every later
//! payload uses the encoding the agent selected during the handshake.
//!
//! # Message Types
//!
//! - 0x01: Handshake Request (proxy -> agent)
//! - 0x02: Handshake Response (agent -> proxy)
//! - 0x10: Request Headers Event
//! - 0x11: Request Body Chunk Event
//! - 0x12: Response Headers Event
//! - 0x13: Response Body Chunk Event
//! - 0x14: Request Complete Event
//! - 0x15: WebSocket Frame Event
//! - 0x16: Guardrail Inspect Event
//! - 0x17: Configure Event
//! - 0x20: Agent Response
//! - 0x30: Health Status
//! - 0x31: Metrics Report
//! - 0x32: Config Update Request
//! - 0x33: Flow Control Signal
//! - 0x40: Cancel Request
//! - 0x41: Ping
//! - 0x42: Pong
37
38use std::collections::HashMap;
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use std::time::Duration;
42
43use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
44use tokio::net::UnixStream;
45use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
46use tracing::{debug, error, info, trace, warn};
47
48use crate::v2::pool::CHANNEL_BUFFER_SIZE;
49use crate::v2::{AgentCapabilities, AgentFeatures, AgentLimits, HealthConfig, PROTOCOL_VERSION_2};
50use crate::{AgentProtocolError, AgentResponse, EventType};
51
52use super::client::{ConfigUpdateCallback, FlowState, MetricsCallback};
53
/// Largest message accepted on the UDS transport: 16 MiB (16 × 2^20 bytes).
pub const MAX_UDS_MESSAGE_SIZE: usize = 16 << 20;
56
57/// Payload encoding for UDS transport.
58///
59/// Negotiated during handshake. The proxy sends its supported encodings,
60/// and the agent responds with the chosen encoding.
61#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
62#[serde(rename_all = "lowercase")]
63pub enum UdsEncoding {
64    /// JSON encoding (default, always supported)
65    #[default]
66    Json,
67    /// MessagePack binary encoding (requires `binary-uds` feature)
68    #[serde(rename = "msgpack")]
69    MessagePack,
70}
71
72impl UdsEncoding {
73    /// Serialize a value using this encoding.
74    ///
75    /// Returns the serialized bytes, or an error if serialization fails.
76    #[inline]
77    pub fn serialize<T: serde::Serialize>(&self, value: &T) -> Result<Vec<u8>, AgentProtocolError> {
78        match self {
79            UdsEncoding::Json => serde_json::to_vec(value)
80                .map_err(|e| AgentProtocolError::Serialization(e.to_string())),
81            #[cfg(feature = "binary-uds")]
82            UdsEncoding::MessagePack => rmp_serde::to_vec(value)
83                .map_err(|e| AgentProtocolError::Serialization(e.to_string())),
84            #[cfg(not(feature = "binary-uds"))]
85            UdsEncoding::MessagePack => {
86                // Fall back to JSON if binary-uds feature is not enabled
87                serde_json::to_vec(value)
88                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
89            }
90        }
91    }
92
93    /// Deserialize a value using this encoding.
94    ///
95    /// Returns the deserialized value, or an error if deserialization fails.
96    #[inline]
97    pub fn deserialize<'a, T: serde::Deserialize<'a>>(
98        &self,
99        bytes: &'a [u8],
100    ) -> Result<T, AgentProtocolError> {
101        match self {
102            UdsEncoding::Json => serde_json::from_slice(bytes)
103                .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string())),
104            #[cfg(feature = "binary-uds")]
105            UdsEncoding::MessagePack => rmp_serde::from_slice(bytes)
106                .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string())),
107            #[cfg(not(feature = "binary-uds"))]
108            UdsEncoding::MessagePack => {
109                // Fall back to JSON if binary-uds feature is not enabled
110                serde_json::from_slice(bytes)
111                    .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
112            }
113        }
114    }
115}
116
/// One-byte type tag carried by every frame, right after the length prefix.
///
/// The explicit discriminants are the on-the-wire values.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Proxy -> agent: opening handshake.
    HandshakeRequest = 0x01,
    /// Agent -> proxy: handshake reply.
    HandshakeResponse = 0x02,

    /// Proxy -> agent: request headers.
    RequestHeaders = 0x10,
    /// Proxy -> agent: a chunk of the request body.
    RequestBodyChunk = 0x11,
    /// Proxy -> agent: response headers.
    ResponseHeaders = 0x12,
    /// Proxy -> agent: a chunk of the response body.
    ResponseBodyChunk = 0x13,
    /// Proxy -> agent: request finished.
    RequestComplete = 0x14,
    /// Proxy -> agent: a WebSocket frame.
    WebSocketFrame = 0x15,
    /// Proxy -> agent: guardrail inspection.
    GuardrailInspect = 0x16,
    /// Proxy -> agent: configuration.
    Configure = 0x17,

    /// Agent -> proxy: verdict for an event.
    AgentResponse = 0x20,

    /// Control: agent health status.
    HealthStatus = 0x30,
    /// Control: metrics report.
    MetricsReport = 0x31,
    /// Control: configuration update request.
    ConfigUpdateRequest = 0x32,
    /// Control: flow-control signal.
    FlowControl = 0x33,

    /// Management: cancel a request.
    Cancel = 0x40,
    /// Management: liveness probe.
    Ping = 0x41,
    /// Management: reply to `Ping`.
    Pong = 0x42,
}
149
150impl TryFrom<u8> for MessageType {
151    type Error = AgentProtocolError;
152
153    fn try_from(value: u8) -> Result<Self, Self::Error> {
154        match value {
155            0x01 => Ok(MessageType::HandshakeRequest),
156            0x02 => Ok(MessageType::HandshakeResponse),
157            0x10 => Ok(MessageType::RequestHeaders),
158            0x11 => Ok(MessageType::RequestBodyChunk),
159            0x12 => Ok(MessageType::ResponseHeaders),
160            0x13 => Ok(MessageType::ResponseBodyChunk),
161            0x14 => Ok(MessageType::RequestComplete),
162            0x15 => Ok(MessageType::WebSocketFrame),
163            0x16 => Ok(MessageType::GuardrailInspect),
164            0x17 => Ok(MessageType::Configure),
165            0x20 => Ok(MessageType::AgentResponse),
166            0x30 => Ok(MessageType::HealthStatus),
167            0x31 => Ok(MessageType::MetricsReport),
168            0x32 => Ok(MessageType::ConfigUpdateRequest),
169            0x33 => Ok(MessageType::FlowControl),
170            0x40 => Ok(MessageType::Cancel),
171            0x41 => Ok(MessageType::Ping),
172            0x42 => Ok(MessageType::Pong),
173            _ => Err(AgentProtocolError::InvalidMessage(format!(
174                "Unknown message type: 0x{:02x}",
175                value
176            ))),
177        }
178    }
179}
180
181/// Handshake request sent from proxy to agent over UDS.
182#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
183pub struct UdsHandshakeRequest {
184    pub supported_versions: Vec<u32>,
185    pub proxy_id: String,
186    pub proxy_version: String,
187    pub config: Option<serde_json::Value>,
188    /// Supported payload encodings (in order of preference).
189    /// If empty or missing, only JSON is supported.
190    #[serde(default, skip_serializing_if = "Vec::is_empty")]
191    pub supported_encodings: Vec<UdsEncoding>,
192}
193
194/// Handshake response from agent to proxy over UDS.
195#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
196pub struct UdsHandshakeResponse {
197    pub protocol_version: u32,
198    pub capabilities: UdsCapabilities,
199    pub success: bool,
200    pub error: Option<String>,
201    /// Negotiated encoding for subsequent messages.
202    /// If missing, defaults to JSON for backwards compatibility.
203    #[serde(default)]
204    pub encoding: UdsEncoding,
205}
206
207/// Agent capabilities for UDS protocol.
208#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
209pub struct UdsCapabilities {
210    pub agent_id: String,
211    pub name: String,
212    pub version: String,
213    pub supported_events: Vec<i32>,
214    pub features: UdsFeatures,
215    pub limits: UdsLimits,
216}
217
218/// Agent features.
219#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
220pub struct UdsFeatures {
221    pub streaming_body: bool,
222    pub websocket: bool,
223    pub guardrails: bool,
224    pub config_push: bool,
225    pub metrics_export: bool,
226    pub concurrent_requests: u32,
227    pub cancellation: bool,
228    pub flow_control: bool,
229    pub health_reporting: bool,
230}
231
232/// Agent limits.
233#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
234pub struct UdsLimits {
235    pub max_body_size: u64,
236    pub max_concurrency: u32,
237    pub preferred_chunk_size: u64,
238}
239
240impl From<UdsCapabilities> for AgentCapabilities {
241    fn from(caps: UdsCapabilities) -> Self {
242        AgentCapabilities {
243            protocol_version: PROTOCOL_VERSION_2,
244            agent_id: caps.agent_id,
245            name: caps.name,
246            version: caps.version,
247            supported_events: caps
248                .supported_events
249                .into_iter()
250                .filter_map(event_type_from_i32)
251                .collect(),
252            features: AgentFeatures {
253                streaming_body: caps.features.streaming_body,
254                websocket: caps.features.websocket,
255                guardrails: caps.features.guardrails,
256                config_push: caps.features.config_push,
257                metrics_export: caps.features.metrics_export,
258                concurrent_requests: caps.features.concurrent_requests,
259                cancellation: caps.features.cancellation,
260                flow_control: caps.features.flow_control,
261                health_reporting: caps.features.health_reporting,
262            },
263            limits: AgentLimits {
264                max_body_size: caps.limits.max_body_size as usize,
265                max_concurrency: caps.limits.max_concurrency,
266                preferred_chunk_size: caps.limits.preferred_chunk_size as usize,
267                max_memory: None,
268                max_processing_time_ms: None,
269            },
270            health: HealthConfig::default(),
271        }
272    }
273}
274
275/// Convert i32 to EventType.
276fn event_type_from_i32(value: i32) -> Option<EventType> {
277    match value {
278        0 => Some(EventType::Configure),
279        1 => Some(EventType::RequestHeaders),
280        2 => Some(EventType::RequestBodyChunk),
281        3 => Some(EventType::ResponseHeaders),
282        4 => Some(EventType::ResponseBodyChunk),
283        5 => Some(EventType::RequestComplete),
284        6 => Some(EventType::WebSocketFrame),
285        7 => Some(EventType::GuardrailInspect),
286        _ => None,
287    }
288}
289
290/// v2 agent client over Unix Domain Socket.
291///
292/// This client maintains a single connection and multiplexes multiple requests
293/// over it using correlation IDs, similar to the gRPC client.
294pub struct AgentClientV2Uds {
295    /// Agent identifier
296    agent_id: String,
297    /// Socket path
298    socket_path: String,
299    /// Request timeout
300    timeout: Duration,
301    /// Negotiated capabilities
302    capabilities: RwLock<Option<AgentCapabilities>>,
303    /// Negotiated protocol version
304    protocol_version: AtomicU64,
305    /// Negotiated payload encoding
306    encoding: RwLock<UdsEncoding>,
307    /// Pending requests by correlation ID
308    pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
309    /// Sender for outbound messages
310    #[allow(clippy::type_complexity)]
311    outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
312    /// Sequence counter for pings
313    ping_sequence: AtomicU64,
314    /// Connection state
315    connected: RwLock<bool>,
316    /// Flow control state
317    flow_state: RwLock<FlowState>,
318    /// Last known health state
319    health_state: RwLock<i32>,
320    /// In-flight request count
321    in_flight: AtomicU64,
322    /// Callback for metrics reports
323    metrics_callback: Option<MetricsCallback>,
324    /// Callback for config update requests
325    config_update_callback: Option<ConfigUpdateCallback>,
326}
327
328impl AgentClientV2Uds {
329    /// Create a new UDS v2 client.
330    pub async fn new(
331        agent_id: impl Into<String>,
332        socket_path: impl Into<String>,
333        timeout: Duration,
334    ) -> Result<Self, AgentProtocolError> {
335        let agent_id = agent_id.into();
336        let socket_path = socket_path.into();
337
338        debug!(
339            agent_id = %agent_id,
340            socket_path = %socket_path,
341            timeout_ms = timeout.as_millis(),
342            "Creating UDS v2 client"
343        );
344
345        Ok(Self {
346            agent_id,
347            socket_path,
348            timeout,
349            capabilities: RwLock::new(None),
350            protocol_version: AtomicU64::new(0),
351            encoding: RwLock::new(UdsEncoding::Json),
352            pending: Arc::new(Mutex::new(HashMap::new())),
353            outbound_tx: Mutex::new(None),
354            ping_sequence: AtomicU64::new(0),
355            connected: RwLock::new(false),
356            flow_state: RwLock::new(FlowState::Normal),
357            health_state: RwLock::new(1), // HEALTHY
358            in_flight: AtomicU64::new(0),
359            metrics_callback: None,
360            config_update_callback: None,
361        })
362    }
363
364    /// Returns the list of supported encodings for this client.
365    ///
366    /// When compiled with `binary-uds` feature, MessagePack is preferred.
367    fn supported_encodings() -> Vec<UdsEncoding> {
368        #[cfg(feature = "binary-uds")]
369        {
370            vec![UdsEncoding::MessagePack, UdsEncoding::Json]
371        }
372        #[cfg(not(feature = "binary-uds"))]
373        {
374            vec![UdsEncoding::Json]
375        }
376    }
377
378    /// Get the current negotiated encoding.
379    pub async fn encoding(&self) -> UdsEncoding {
380        *self.encoding.read().await
381    }
382
383    /// Set the metrics callback.
384    pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
385        self.metrics_callback = Some(callback);
386    }
387
388    /// Set the config update callback.
389    pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
390        self.config_update_callback = Some(callback);
391    }
392
393    /// Connect and perform handshake.
394    pub async fn connect(&self) -> Result<(), AgentProtocolError> {
395        info!(
396            agent_id = %self.agent_id,
397            socket_path = %self.socket_path,
398            "Connecting to agent via UDS v2"
399        );
400
401        // Connect to Unix socket
402        let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
403            error!(
404                agent_id = %self.agent_id,
405                socket_path = %self.socket_path,
406                error = %e,
407                "Failed to connect to agent via UDS"
408            );
409            AgentProtocolError::ConnectionFailed(e.to_string())
410        })?;
411
412        let (read_half, write_half) = stream.into_split();
413        let mut reader = BufReader::new(read_half);
414        let mut writer = BufWriter::new(write_half);
415
416        // Send handshake request with supported encodings
417        let handshake_req = UdsHandshakeRequest {
418            supported_versions: vec![PROTOCOL_VERSION_2],
419            proxy_id: "sentinel-proxy".to_string(),
420            proxy_version: env!("CARGO_PKG_VERSION").to_string(),
421            config: None,
422            supported_encodings: Self::supported_encodings(),
423        };
424
425        // Handshake always uses JSON (before encoding is negotiated)
426        let payload = serde_json::to_vec(&handshake_req)
427            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
428
429        write_message(&mut writer, MessageType::HandshakeRequest, &payload).await?;
430
431        // Read handshake response (always JSON)
432        let (msg_type, response_bytes) = read_message(&mut reader).await?;
433
434        if msg_type != MessageType::HandshakeResponse {
435            return Err(AgentProtocolError::InvalidMessage(format!(
436                "Expected HandshakeResponse, got {:?}",
437                msg_type
438            )));
439        }
440
441        let response: UdsHandshakeResponse = serde_json::from_slice(&response_bytes)
442            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
443
444        if !response.success {
445            return Err(AgentProtocolError::ConnectionFailed(
446                response
447                    .error
448                    .unwrap_or_else(|| "Unknown handshake error".to_string()),
449            ));
450        }
451
452        // Store capabilities and negotiated encoding
453        let capabilities: AgentCapabilities = response.capabilities.into();
454        *self.capabilities.write().await = Some(capabilities);
455        self.protocol_version
456            .store(response.protocol_version as u64, Ordering::SeqCst);
457
458        // Store the negotiated encoding for subsequent messages
459        let negotiated_encoding = response.encoding;
460        *self.encoding.write().await = negotiated_encoding;
461
462        info!(
463            agent_id = %self.agent_id,
464            protocol_version = response.protocol_version,
465            encoding = ?negotiated_encoding,
466            "UDS v2 handshake successful"
467        );
468
469        // Create message channel
470        let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
471        *self.outbound_tx.lock().await = Some(tx);
472        *self.connected.write().await = true;
473
474        // Spawn writer task
475        let agent_id_clone = self.agent_id.clone();
476        tokio::spawn(async move {
477            while let Some((msg_type, payload)) = rx.recv().await {
478                if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
479                    error!(
480                        agent_id = %agent_id_clone,
481                        error = %e,
482                        "Failed to write message to UDS"
483                    );
484                    break;
485                }
486            }
487            debug!(agent_id = %agent_id_clone, "UDS writer task ended");
488        });
489
490        // Spawn reader task with the negotiated encoding
491        let pending = Arc::clone(&self.pending);
492        let agent_id = self.agent_id.clone();
493        let flow_state = Arc::new(RwLock::new(FlowState::Normal));
494        let health_state = Arc::new(RwLock::new(1i32));
495        let flow_state_clone = Arc::clone(&flow_state);
496        let health_state_clone = Arc::clone(&health_state);
497        let metrics_callback = self.metrics_callback.clone();
498        let config_update_callback = self.config_update_callback.clone();
499        // Encoding is fixed after handshake, so we can copy it
500        let reader_encoding = negotiated_encoding;
501
502        tokio::spawn(async move {
503            loop {
504                match read_message(&mut reader).await {
505                    Ok((msg_type, payload)) => {
506                        match msg_type {
507                            MessageType::AgentResponse => {
508                                match reader_encoding.deserialize::<AgentResponse>(&payload) {
509                                    Ok(response) => {
510                                        // Extract correlation ID from the response
511                                        // For UDS, we include correlation_id in the response
512                                        if let Some(sender) = pending.lock().await.remove(
513                                            &response
514                                                .audit
515                                                .custom
516                                                .get("correlation_id")
517                                                .and_then(|v| v.as_str())
518                                                .unwrap_or("")
519                                                .to_string(),
520                                        ) {
521                                            let _ = sender.send(response);
522                                        }
523                                    }
524                                    Err(e) => {
525                                        warn!(
526                                            agent_id = %agent_id,
527                                            error = %e,
528                                            encoding = ?reader_encoding,
529                                            "Failed to parse agent response"
530                                        );
531                                    }
532                                }
533                            }
534                            MessageType::HealthStatus => {
535                                // Health status uses a simple struct, try both encodings for robustness
536                                #[derive(serde::Deserialize)]
537                                struct HealthStatusMsg {
538                                    state: Option<i64>,
539                                }
540                                if let Ok(health) =
541                                    reader_encoding.deserialize::<HealthStatusMsg>(&payload)
542                                {
543                                    if let Some(state) = health.state {
544                                        *health_state_clone.write().await = state as i32;
545                                    }
546                                }
547                            }
548                            MessageType::MetricsReport => {
549                                if let Some(ref callback) = metrics_callback {
550                                    if let Ok(report) = reader_encoding.deserialize(&payload) {
551                                        callback(report);
552                                    }
553                                }
554                            }
555                            MessageType::FlowControl => {
556                                #[derive(serde::Deserialize)]
557                                struct FlowControlMsg {
558                                    action: Option<i64>,
559                                }
560                                if let Ok(fc) =
561                                    reader_encoding.deserialize::<FlowControlMsg>(&payload)
562                                {
563                                    let action = fc.action.unwrap_or(0);
564                                    let new_state = match action {
565                                        1 => FlowState::Paused,
566                                        2 => FlowState::Normal,
567                                        _ => FlowState::Normal,
568                                    };
569                                    *flow_state_clone.write().await = new_state;
570                                }
571                            }
572                            MessageType::ConfigUpdateRequest => {
573                                if let Some(ref callback) = config_update_callback {
574                                    if let Ok(request) = reader_encoding.deserialize(&payload) {
575                                        let _response = callback(agent_id.clone(), request);
576                                    }
577                                }
578                            }
579                            MessageType::Pong => {
580                                trace!(agent_id = %agent_id, "Received pong");
581                            }
582                            _ => {
583                                trace!(
584                                    agent_id = %agent_id,
585                                    msg_type = ?msg_type,
586                                    "Received unhandled message type"
587                                );
588                            }
589                        }
590                    }
591                    Err(e) => {
592                        if !matches!(e, AgentProtocolError::ConnectionClosed) {
593                            error!(
594                                agent_id = %agent_id,
595                                error = %e,
596                                "Error reading from UDS"
597                            );
598                        }
599                        break;
600                    }
601                }
602            }
603            debug!(agent_id = %agent_id, "UDS reader task ended");
604        });
605
606        Ok(())
607    }
608
609    /// Get negotiated capabilities.
610    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
611        self.capabilities.read().await.clone()
612    }
613
614    /// Check if connected.
615    pub async fn is_connected(&self) -> bool {
616        *self.connected.read().await
617    }
618
619    /// Send a request headers event.
620    pub async fn send_request_headers(
621        &self,
622        correlation_id: &str,
623        event: &crate::RequestHeadersEvent,
624    ) -> Result<AgentResponse, AgentProtocolError> {
625        self.send_event(MessageType::RequestHeaders, correlation_id, event)
626            .await
627    }
628
629    /// Send a request body chunk event.
630    pub async fn send_request_body_chunk(
631        &self,
632        correlation_id: &str,
633        event: &crate::RequestBodyChunkEvent,
634    ) -> Result<AgentResponse, AgentProtocolError> {
635        self.send_event(MessageType::RequestBodyChunk, correlation_id, event)
636            .await
637    }
638
639    /// Send a response headers event.
640    pub async fn send_response_headers(
641        &self,
642        correlation_id: &str,
643        event: &crate::ResponseHeadersEvent,
644    ) -> Result<AgentResponse, AgentProtocolError> {
645        self.send_event(MessageType::ResponseHeaders, correlation_id, event)
646            .await
647    }
648
649    /// Send a response body chunk event.
650    pub async fn send_response_body_chunk(
651        &self,
652        correlation_id: &str,
653        event: &crate::ResponseBodyChunkEvent,
654    ) -> Result<AgentResponse, AgentProtocolError> {
655        self.send_event(MessageType::ResponseBodyChunk, correlation_id, event)
656            .await
657    }
658
659    /// Send a binary request body chunk event (zero-copy path).
660    ///
661    /// This method avoids base64 encoding when using MessagePack encoding,
662    /// sending raw bytes directly over the wire for better throughput.
663    ///
664    /// # Performance
665    ///
666    /// When MessagePack encoding is negotiated:
667    /// - Bytes are serialized directly (no base64 encode/decode)
668    /// - Reduces CPU usage and latency for large bodies
669    ///
670    /// When JSON encoding is used:
671    /// - Falls back to base64 encoding for JSON compatibility
672    pub async fn send_request_body_chunk_binary(
673        &self,
674        event: &crate::BinaryRequestBodyChunkEvent,
675    ) -> Result<AgentResponse, AgentProtocolError> {
676        let correlation_id = &event.correlation_id;
677        self.send_binary_body_chunk(
678            MessageType::RequestBodyChunk,
679            correlation_id,
680            &event.data,
681            event.is_last,
682            event.total_size,
683            event.chunk_index,
684            Some(event.bytes_received),
685            None,
686        )
687        .await
688    }
689
690    /// Send a binary response body chunk event (zero-copy path).
691    ///
692    /// This method avoids base64 encoding when using MessagePack encoding,
693    /// sending raw bytes directly over the wire for better throughput.
694    pub async fn send_response_body_chunk_binary(
695        &self,
696        event: &crate::BinaryResponseBodyChunkEvent,
697    ) -> Result<AgentResponse, AgentProtocolError> {
698        let correlation_id = &event.correlation_id;
699        self.send_binary_body_chunk(
700            MessageType::ResponseBodyChunk,
701            correlation_id,
702            &event.data,
703            event.is_last,
704            event.total_size,
705            event.chunk_index,
706            None,
707            Some(event.bytes_sent),
708        )
709        .await
710    }
711
712    /// Internal helper to send binary body chunks with encoding-aware serialization.
713    #[allow(clippy::too_many_arguments)]
714    async fn send_binary_body_chunk(
715        &self,
716        msg_type: MessageType,
717        correlation_id: &str,
718        data: &bytes::Bytes,
719        is_last: bool,
720        total_size: Option<usize>,
721        chunk_index: u32,
722        bytes_received: Option<usize>,
723        bytes_sent: Option<usize>,
724    ) -> Result<AgentResponse, AgentProtocolError> {
725        // Create response channel
726        let (tx, rx) = oneshot::channel();
727        self.pending
728            .lock()
729            .await
730            .insert(correlation_id.to_string(), tx);
731
732        // Get the current encoding
733        let encoding = *self.encoding.read().await;
734
735        // Serialize body chunk using encoding-optimized format
736        let payload_bytes = match encoding {
737            UdsEncoding::Json => {
738                // JSON path: must use base64 encoding for binary data
739                use base64::{engine::general_purpose::STANDARD, Engine as _};
740                let json = serde_json::json!({
741                    "correlation_id": correlation_id,
742                    "data": STANDARD.encode(data),
743                    "is_last": is_last,
744                    "total_size": total_size,
745                    "chunk_index": chunk_index,
746                    "bytes_received": bytes_received,
747                    "bytes_sent": bytes_sent,
748                });
749                serde_json::to_vec(&json)
750                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
751            }
752            UdsEncoding::MessagePack => {
753                // MessagePack path: raw bytes via serde_bytes for zero-copy serialization
754                #[derive(serde::Serialize)]
755                struct BinaryBodyChunk<'a> {
756                    correlation_id: &'a str,
757                    #[serde(with = "serde_bytes")]
758                    data: &'a [u8],
759                    is_last: bool,
760                    #[serde(skip_serializing_if = "Option::is_none")]
761                    total_size: Option<usize>,
762                    chunk_index: u32,
763                    #[serde(skip_serializing_if = "Option::is_none")]
764                    bytes_received: Option<usize>,
765                    #[serde(skip_serializing_if = "Option::is_none")]
766                    bytes_sent: Option<usize>,
767                }
768                let chunk = BinaryBodyChunk {
769                    correlation_id,
770                    data: data.as_ref(),
771                    is_last,
772                    total_size,
773                    chunk_index,
774                    bytes_received,
775                    bytes_sent,
776                };
777                encoding.serialize(&chunk)?
778            }
779        };
780
781        // Send message
782        {
783            let outbound = self.outbound_tx.lock().await;
784            if let Some(tx) = outbound.as_ref() {
785                tx.send((msg_type, payload_bytes))
786                    .await
787                    .map_err(|_| AgentProtocolError::ConnectionClosed)?;
788            } else {
789                return Err(AgentProtocolError::ConnectionClosed);
790            }
791        }
792
793        self.in_flight.fetch_add(1, Ordering::Relaxed);
794
795        // Wait for response with timeout
796        let response = tokio::time::timeout(self.timeout, rx)
797            .await
798            .map_err(|_| {
799                self.pending
800                    .try_lock()
801                    .ok()
802                    .map(|mut p| p.remove(correlation_id));
803                AgentProtocolError::Timeout(self.timeout)
804            })?
805            .map_err(|_| AgentProtocolError::ConnectionClosed)?;
806
807        self.in_flight.fetch_sub(1, Ordering::Relaxed);
808
809        Ok(response)
810    }
811
812    /// Send an event and wait for response.
813    async fn send_event<T: serde::Serialize>(
814        &self,
815        msg_type: MessageType,
816        correlation_id: &str,
817        event: &T,
818    ) -> Result<AgentResponse, AgentProtocolError> {
819        // Create response channel
820        let (tx, rx) = oneshot::channel();
821        self.pending
822            .lock()
823            .await
824            .insert(correlation_id.to_string(), tx);
825
826        // Get the current encoding
827        let encoding = *self.encoding.read().await;
828
829        // Serialize event using negotiated encoding
830        let payload_bytes = match encoding {
831            UdsEncoding::Json => {
832                // JSON path: use Value mutation for backwards compatibility
833                let mut payload = serde_json::to_value(event)
834                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
835                if let Some(obj) = payload.as_object_mut() {
836                    obj.insert(
837                        "correlation_id".to_string(),
838                        serde_json::Value::String(correlation_id.to_string()),
839                    );
840                }
841                serde_json::to_vec(&payload)
842                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
843            }
844            UdsEncoding::MessagePack => {
845                // MessagePack path: use wrapper struct for efficient serialization
846                #[derive(serde::Serialize)]
847                struct EventWithCorrelation<'a, T: serde::Serialize> {
848                    correlation_id: &'a str,
849                    #[serde(flatten)]
850                    event: &'a T,
851                }
852                let wrapped = EventWithCorrelation {
853                    correlation_id,
854                    event,
855                };
856                encoding.serialize(&wrapped)?
857            }
858        };
859
860        // Send message
861        {
862            let outbound = self.outbound_tx.lock().await;
863            if let Some(tx) = outbound.as_ref() {
864                tx.send((msg_type, payload_bytes))
865                    .await
866                    .map_err(|_| AgentProtocolError::ConnectionClosed)?;
867            } else {
868                return Err(AgentProtocolError::ConnectionClosed);
869            }
870        }
871
872        self.in_flight.fetch_add(1, Ordering::Relaxed);
873
874        // Wait for response with timeout
875        let response = tokio::time::timeout(self.timeout, rx)
876            .await
877            .map_err(|_| {
878                self.pending
879                    .try_lock()
880                    .ok()
881                    .map(|mut p| p.remove(correlation_id));
882                AgentProtocolError::Timeout(self.timeout)
883            })?
884            .map_err(|_| AgentProtocolError::ConnectionClosed)?;
885
886        self.in_flight.fetch_sub(1, Ordering::Relaxed);
887
888        Ok(response)
889    }
890
891    /// Send a cancel request for a specific correlation ID.
892    pub async fn cancel_request(
893        &self,
894        correlation_id: &str,
895        reason: super::client::CancelReason,
896    ) -> Result<(), AgentProtocolError> {
897        let cancel = serde_json::json!({
898            "correlation_id": correlation_id,
899            "reason": reason as i32,
900            "timestamp_ms": now_ms(),
901        });
902
903        let payload = serde_json::to_vec(&cancel)
904            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
905
906        let outbound = self.outbound_tx.lock().await;
907        if let Some(tx) = outbound.as_ref() {
908            tx.send((MessageType::Cancel, payload))
909                .await
910                .map_err(|_| AgentProtocolError::ConnectionClosed)?;
911        }
912
913        // Remove pending request
914        self.pending.lock().await.remove(correlation_id);
915
916        Ok(())
917    }
918
919    /// Cancel all in-flight requests.
920    pub async fn cancel_all(
921        &self,
922        reason: super::client::CancelReason,
923    ) -> Result<usize, AgentProtocolError> {
924        let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
925        let count = pending_ids.len();
926
927        for correlation_id in pending_ids {
928            let _ = self.cancel_request(&correlation_id, reason).await;
929        }
930
931        Ok(count)
932    }
933
934    /// Send a ping.
935    pub async fn ping(&self) -> Result<(), AgentProtocolError> {
936        let seq = self.ping_sequence.fetch_add(1, Ordering::Relaxed);
937        let ping = serde_json::json!({
938            "sequence": seq,
939            "timestamp_ms": now_ms(),
940        });
941
942        let payload = serde_json::to_vec(&ping)
943            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
944
945        let outbound = self.outbound_tx.lock().await;
946        if let Some(tx) = outbound.as_ref() {
947            tx.send((MessageType::Ping, payload))
948                .await
949                .map_err(|_| AgentProtocolError::ConnectionClosed)?;
950        }
951
952        Ok(())
953    }
954
955    /// Close the connection.
956    pub async fn close(&self) -> Result<(), AgentProtocolError> {
957        *self.connected.write().await = false;
958        *self.outbound_tx.lock().await = None;
959        Ok(())
960    }
961
962    /// Get in-flight request count.
963    pub fn in_flight(&self) -> u64 {
964        self.in_flight.load(Ordering::Relaxed)
965    }
966
967    /// Get agent ID.
968    pub fn agent_id(&self) -> &str {
969        &self.agent_id
970    }
971
972    /// Check if the agent has requested flow control pause.
973    ///
974    /// Returns true if the agent sent a `FlowAction::Pause` signal,
975    /// indicating it cannot accept more requests.
976    pub async fn is_paused(&self) -> bool {
977        matches!(*self.flow_state.read().await, FlowState::Paused)
978    }
979
980    /// Check if the transport can accept new requests.
981    ///
982    /// Returns false if the agent has requested a flow control pause.
983    pub async fn can_accept_requests(&self) -> bool {
984        !self.is_paused().await
985    }
986}
987
988/// Write a message to the stream.
989pub async fn write_message<W: AsyncWriteExt + Unpin>(
990    writer: &mut W,
991    msg_type: MessageType,
992    payload: &[u8],
993) -> Result<(), AgentProtocolError> {
994    if payload.len() > MAX_UDS_MESSAGE_SIZE {
995        return Err(AgentProtocolError::MessageTooLarge {
996            size: payload.len(),
997            max: MAX_UDS_MESSAGE_SIZE,
998        });
999    }
1000
1001    // Write length (4 bytes, big-endian) - includes type byte
1002    let total_len = (payload.len() + 1) as u32;
1003    writer.write_all(&total_len.to_be_bytes()).await?;
1004
1005    // Write message type (1 byte)
1006    writer.write_all(&[msg_type as u8]).await?;
1007
1008    // Write payload
1009    writer.write_all(payload).await?;
1010    writer.flush().await?;
1011
1012    Ok(())
1013}
1014
1015/// Read a message from the stream.
1016pub async fn read_message<R: AsyncReadExt + Unpin>(
1017    reader: &mut R,
1018) -> Result<(MessageType, Vec<u8>), AgentProtocolError> {
1019    // Read length (4 bytes, big-endian)
1020    let mut len_bytes = [0u8; 4];
1021    match reader.read_exact(&mut len_bytes).await {
1022        Ok(_) => {}
1023        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
1024            return Err(AgentProtocolError::ConnectionClosed);
1025        }
1026        Err(e) => return Err(e.into()),
1027    }
1028
1029    let total_len = u32::from_be_bytes(len_bytes) as usize;
1030
1031    if total_len == 0 {
1032        return Err(AgentProtocolError::InvalidMessage(
1033            "Zero-length message".to_string(),
1034        ));
1035    }
1036
1037    if total_len > MAX_UDS_MESSAGE_SIZE {
1038        return Err(AgentProtocolError::MessageTooLarge {
1039            size: total_len,
1040            max: MAX_UDS_MESSAGE_SIZE,
1041        });
1042    }
1043
1044    // Read message type (1 byte)
1045    let mut type_byte = [0u8; 1];
1046    reader.read_exact(&mut type_byte).await?;
1047    let msg_type = MessageType::try_from(type_byte[0])?;
1048
1049    // Read payload
1050    let payload_len = total_len - 1;
1051    let mut payload = vec![0u8; payload_len];
1052    if payload_len > 0 {
1053        reader.read_exact(&mut payload).await?;
1054    }
1055
1056    Ok((msg_type, payload))
1057}
1058
/// Milliseconds since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
1065
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type_roundtrip() {
        let types = [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::RequestHeaders,
            MessageType::AgentResponse,
            MessageType::HealthStatus,
            MessageType::Cancel,
            MessageType::Ping,
            MessageType::Pong,
        ];

        for msg_type in types {
            let byte = msg_type as u8;
            let parsed = MessageType::try_from(byte).unwrap();
            assert_eq!(parsed, msg_type);
        }
    }

    #[test]
    fn test_invalid_message_type() {
        let result = MessageType::try_from(0xFF);
        assert!(result.is_err());
    }

    #[test]
    fn test_handshake_serialization() {
        let req = UdsHandshakeRequest {
            supported_versions: vec![2],
            proxy_id: "test-proxy".to_string(),
            proxy_version: "1.0.0".to_string(),
            config: None,
            supported_encodings: vec![],
        };

        let json = serde_json::to_string(&req).unwrap();
        let parsed: UdsHandshakeRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.supported_versions, vec![2]);
        assert_eq!(parsed.proxy_id, "test-proxy");
    }

    #[tokio::test]
    async fn test_write_read_message() {
        use tokio::io::duplex;

        let (mut client, mut server) = duplex(1024);

        // Write from client
        let payload = b"test payload";
        write_message(&mut client, MessageType::Ping, payload)
            .await
            .unwrap();

        // Read from server
        let (msg_type, data) = read_message(&mut server).await.unwrap();
        assert_eq!(msg_type, MessageType::Ping);
        assert_eq!(data, payload);
    }

    // A peer that goes away between frames must surface as ConnectionClosed,
    // not as a raw I/O error.
    #[tokio::test]
    async fn test_read_message_clean_eof_is_connection_closed() {
        use tokio::io::duplex;

        let (client, mut server) = duplex(64);
        drop(client);

        let result = read_message(&mut server).await;
        assert!(matches!(result, Err(AgentProtocolError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn test_read_message_rejects_zero_length() {
        use tokio::io::{duplex, AsyncWriteExt};

        let (mut client, mut server) = duplex(64);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();

        let result = read_message(&mut server).await;
        assert!(matches!(result, Err(AgentProtocolError::InvalidMessage(_))));
    }

    #[tokio::test]
    async fn test_read_message_rejects_oversized_length() {
        use tokio::io::{duplex, AsyncWriteExt};

        let (mut client, mut server) = duplex(64);
        let too_big = (MAX_UDS_MESSAGE_SIZE + 1) as u32;
        client.write_all(&too_big.to_be_bytes()).await.unwrap();

        let result = read_message(&mut server).await;
        assert!(matches!(
            result,
            Err(AgentProtocolError::MessageTooLarge { size, max })
                if size == MAX_UDS_MESSAGE_SIZE + 1 && max == MAX_UDS_MESSAGE_SIZE
        ));
    }

    #[tokio::test]
    async fn test_write_message_rejects_oversized_payload() {
        use tokio::io::duplex;

        let (mut client, _server) = duplex(64);
        let payload = vec![0u8; MAX_UDS_MESSAGE_SIZE + 1];

        let result = write_message(&mut client, MessageType::Ping, &payload).await;
        assert!(matches!(
            result,
            Err(AgentProtocolError::MessageTooLarge { .. })
        ));
    }

    // Checks the JSON + base64 wire shape used for body chunks; it builds the
    // payload locally rather than going through the transport.
    #[test]
    fn test_binary_body_chunk_json_serialization() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let data = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let correlation_id = "test-123";

        // JSON encoding must use base64
        let json = serde_json::json!({
            "correlation_id": correlation_id,
            "data": STANDARD.encode(&data),
            "is_last": true,
            "total_size": 100usize,
            "chunk_index": 0u32,
            "bytes_received": 100usize,
        });

        let serialized = serde_json::to_vec(&json).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&serialized).unwrap();

        // Verify base64 can be decoded back to original
        let data_field = parsed["data"].as_str().unwrap();
        let decoded = STANDARD.decode(data_field).unwrap();
        assert_eq!(decoded, data.as_ref());
    }

    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_binary_body_chunk_msgpack_serialization() {
        let data = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let correlation_id = "test-123";

        // MessagePack uses serde_bytes for efficient serialization
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct BinaryBodyChunk {
            correlation_id: String,
            #[serde(with = "serde_bytes")]
            data: Vec<u8>,
            is_last: bool,
            chunk_index: u32,
        }

        let chunk = BinaryBodyChunk {
            correlation_id: correlation_id.to_string(),
            data: data.to_vec(),
            is_last: true,
            chunk_index: 0,
        };

        // Serialize with MessagePack
        let serialized = rmp_serde::to_vec(&chunk).unwrap();

        // Deserialize and verify
        let parsed: BinaryBodyChunk = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(parsed.correlation_id, correlation_id);
        assert_eq!(parsed.data, data.as_ref());
        assert!(parsed.is_last);

        // Verify MessagePack is more compact than JSON+base64 for binary data
        use base64::Engine as _;
        let json_size = serde_json::to_vec(&serde_json::json!({
            "correlation_id": correlation_id,
            "data": base64::engine::general_purpose::STANDARD.encode(&data),
            "is_last": true,
            "chunk_index": 0u32,
        }))
        .unwrap()
        .len();

        // MessagePack should be smaller (raw bytes vs base64 ~33% overhead)
        assert!(
            serialized.len() < json_size,
            "MessagePack ({}) should be smaller than JSON+base64 ({})",
            serialized.len(),
            json_size
        );
    }

    #[test]
    fn test_uds_encoding_default() {
        assert_eq!(UdsEncoding::default(), UdsEncoding::Json);
    }

    #[test]
    fn test_uds_encoding_serialize_json() {
        let encoding = UdsEncoding::Json;
        let value = serde_json::json!({"key": "value"});
        let serialized = encoding.serialize(&value).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&serialized).unwrap();
        assert_eq!(parsed, value);
    }

    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_uds_encoding_serialize_msgpack() {
        let encoding = UdsEncoding::MessagePack;
        let value = serde_json::json!({"key": "value"});
        let serialized = encoding.serialize(&value).unwrap();
        // Verify it's valid MessagePack by deserializing
        let parsed: serde_json::Value = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(parsed, value);
    }
}