1use std::collections::HashMap;
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use std::time::Duration;
42
43use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
44use tokio::net::UnixStream;
45use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
46use tracing::{debug, error, info, trace, warn};
47
48use crate::v2::pool::CHANNEL_BUFFER_SIZE;
49use crate::v2::{AgentCapabilities, AgentFeatures, AgentLimits, HealthConfig, PROTOCOL_VERSION_2};
50use crate::{AgentProtocolError, AgentResponse, EventType};
51
52use super::client::{ConfigUpdateCallback, FlowState, MetricsCallback};
53
/// Largest frame accepted on the wire: 16 MiB.
///
/// [`read_message`] compares it against the length prefix (type byte +
/// payload) and [`write_message`] rejects oversized outgoing messages.
pub const MAX_UDS_MESSAGE_SIZE: usize = 16 * 1024 * 1024;
56
/// Payload encoding negotiated during the UDS handshake.
///
/// Serialized as `"json"` / `"msgpack"`; `Json` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum UdsEncoding {
    /// JSON payloads (the default, and always offered by the proxy).
    #[default]
    Json,
    /// MessagePack payloads. Real MessagePack is only produced/consumed when
    /// the `binary-uds` feature is enabled; otherwise `serialize` and
    /// `deserialize` fall back to JSON for this variant.
    #[serde(rename = "msgpack")]
    MessagePack,
}
71
72impl UdsEncoding {
73 #[inline]
77 pub fn serialize<T: serde::Serialize>(&self, value: &T) -> Result<Vec<u8>, AgentProtocolError> {
78 match self {
79 UdsEncoding::Json => {
80 serde_json::to_vec(value)
81 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
82 }
83 #[cfg(feature = "binary-uds")]
84 UdsEncoding::MessagePack => {
85 rmp_serde::to_vec(value)
86 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
87 }
88 #[cfg(not(feature = "binary-uds"))]
89 UdsEncoding::MessagePack => {
90 serde_json::to_vec(value)
92 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
93 }
94 }
95 }
96
97 #[inline]
101 pub fn deserialize<'a, T: serde::Deserialize<'a>>(&self, bytes: &'a [u8]) -> Result<T, AgentProtocolError> {
102 match self {
103 UdsEncoding::Json => {
104 serde_json::from_slice(bytes)
105 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
106 }
107 #[cfg(feature = "binary-uds")]
108 UdsEncoding::MessagePack => {
109 rmp_serde::from_slice(bytes)
110 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
111 }
112 #[cfg(not(feature = "binary-uds"))]
113 UdsEncoding::MessagePack => {
114 serde_json::from_slice(bytes)
116 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
117 }
118 }
119 }
120}
121
/// Frame type byte written right after the 4-byte length prefix of every
/// UDS frame (see `write_message` / `read_message`).
///
/// Decoded with `MessageType::try_from(u8)`, which rejects unknown bytes.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// First frame sent by the proxy on a new connection (JSON payload).
    HandshakeRequest = 0x01,
    /// Agent's reply to the handshake (JSON payload).
    HandshakeResponse = 0x02,

    // Event frames (0x10-0x17).
    RequestHeaders = 0x10,
    RequestBodyChunk = 0x11,
    ResponseHeaders = 0x12,
    ResponseBodyChunk = 0x13,
    RequestComplete = 0x14,
    WebSocketFrame = 0x15,
    GuardrailInspect = 0x16,
    Configure = 0x17,

    /// Agent's answer to an event; matched to the waiting caller by correlation ID.
    AgentResponse = 0x20,

    // Control frames received from the agent and handled by the client's reader task.
    HealthStatus = 0x30,
    MetricsReport = 0x31,
    ConfigUpdateRequest = 0x32,
    FlowControl = 0x33,

    // Connection-management frames: the client sends Cancel and Ping and logs received Pongs.
    Cancel = 0x40,
    Ping = 0x41,
    Pong = 0x42,
}
154
155impl TryFrom<u8> for MessageType {
156 type Error = AgentProtocolError;
157
158 fn try_from(value: u8) -> Result<Self, Self::Error> {
159 match value {
160 0x01 => Ok(MessageType::HandshakeRequest),
161 0x02 => Ok(MessageType::HandshakeResponse),
162 0x10 => Ok(MessageType::RequestHeaders),
163 0x11 => Ok(MessageType::RequestBodyChunk),
164 0x12 => Ok(MessageType::ResponseHeaders),
165 0x13 => Ok(MessageType::ResponseBodyChunk),
166 0x14 => Ok(MessageType::RequestComplete),
167 0x15 => Ok(MessageType::WebSocketFrame),
168 0x16 => Ok(MessageType::GuardrailInspect),
169 0x17 => Ok(MessageType::Configure),
170 0x20 => Ok(MessageType::AgentResponse),
171 0x30 => Ok(MessageType::HealthStatus),
172 0x31 => Ok(MessageType::MetricsReport),
173 0x32 => Ok(MessageType::ConfigUpdateRequest),
174 0x33 => Ok(MessageType::FlowControl),
175 0x40 => Ok(MessageType::Cancel),
176 0x41 => Ok(MessageType::Ping),
177 0x42 => Ok(MessageType::Pong),
178 _ => Err(AgentProtocolError::InvalidMessage(format!(
179 "Unknown message type: 0x{:02x}",
180 value
181 ))),
182 }
183 }
184}
185
/// Handshake request: the first frame the proxy sends on a new connection.
///
/// `AgentClientV2Uds::connect` always encodes it as JSON, since no encoding
/// has been negotiated at that point.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UdsHandshakeRequest {
    /// Protocol versions the proxy can speak.
    pub supported_versions: Vec<u32>,
    /// Identifier of the connecting proxy.
    pub proxy_id: String,
    /// Version string of the connecting proxy.
    pub proxy_version: String,
    /// Optional configuration for the agent (`connect` sends `None`).
    pub config: Option<serde_json::Value>,
    /// Encodings the proxy can use. Omitted from the payload when empty, and
    /// treated as empty when absent on deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_encodings: Vec<UdsEncoding>,
}
198
/// Agent's reply to [`UdsHandshakeRequest`], decoded as JSON by `connect`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UdsHandshakeResponse {
    /// Protocol version reported by the agent.
    pub protocol_version: u32,
    /// Agent identity, supported events, features and limits.
    pub capabilities: UdsCapabilities,
    /// `false` means the agent rejected the handshake.
    pub success: bool,
    /// Rejection reason, used when `success` is `false`.
    pub error: Option<String>,
    /// Encoding selected by the agent. Falls back to `UdsEncoding::default()`
    /// (JSON) when the field is absent (presumably agents that predate
    /// encoding negotiation).
    #[serde(default)]
    pub encoding: UdsEncoding,
}
211
/// Capabilities as they appear on the wire in [`UdsHandshakeResponse`].
///
/// Converted into [`AgentCapabilities`] via `From`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UdsCapabilities {
    /// Agent identifier.
    pub agent_id: String,
    /// Human-readable agent name.
    pub name: String,
    /// Agent version string.
    pub version: String,
    /// Numeric event codes (mapped by `event_type_from_i32`; unknown codes are dropped).
    pub supported_events: Vec<i32>,
    /// Feature flags.
    pub features: UdsFeatures,
    /// Size and concurrency limits.
    pub limits: UdsLimits,
}
222
/// Feature flags advertised by the agent; copied field-for-field into `AgentFeatures`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct UdsFeatures {
    pub streaming_body: bool,
    pub websocket: bool,
    pub guardrails: bool,
    pub config_push: bool,
    pub metrics_export: bool,
    /// Copied verbatim into `AgentFeatures::concurrent_requests`.
    pub concurrent_requests: u32,
    pub cancellation: bool,
    pub flow_control: bool,
    pub health_reporting: bool,
}
236
/// Limits advertised by the agent. Sizes travel as `u64` and are converted to
/// `usize` when building `AgentLimits`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct UdsLimits {
    /// Maximum body size (presumably bytes; confirm against the agent protocol spec).
    pub max_body_size: u64,
    /// Maximum concurrency advertised by the agent.
    pub max_concurrency: u32,
    /// Preferred body chunk size (presumably bytes; confirm).
    pub preferred_chunk_size: u64,
}
244
245impl From<UdsCapabilities> for AgentCapabilities {
246 fn from(caps: UdsCapabilities) -> Self {
247 AgentCapabilities {
248 protocol_version: PROTOCOL_VERSION_2 as u32,
249 agent_id: caps.agent_id,
250 name: caps.name,
251 version: caps.version,
252 supported_events: caps
253 .supported_events
254 .into_iter()
255 .filter_map(event_type_from_i32)
256 .collect(),
257 features: AgentFeatures {
258 streaming_body: caps.features.streaming_body,
259 websocket: caps.features.websocket,
260 guardrails: caps.features.guardrails,
261 config_push: caps.features.config_push,
262 metrics_export: caps.features.metrics_export,
263 concurrent_requests: caps.features.concurrent_requests,
264 cancellation: caps.features.cancellation,
265 flow_control: caps.features.flow_control,
266 health_reporting: caps.features.health_reporting,
267 },
268 limits: AgentLimits {
269 max_body_size: caps.limits.max_body_size as usize,
270 max_concurrency: caps.limits.max_concurrency,
271 preferred_chunk_size: caps.limits.preferred_chunk_size as usize,
272 max_memory: None,
273 max_processing_time_ms: None,
274 },
275 health: HealthConfig::default(),
276 }
277 }
278}
279
280fn event_type_from_i32(value: i32) -> Option<EventType> {
282 match value {
283 0 => Some(EventType::Configure),
284 1 => Some(EventType::RequestHeaders),
285 2 => Some(EventType::RequestBodyChunk),
286 3 => Some(EventType::ResponseHeaders),
287 4 => Some(EventType::ResponseBodyChunk),
288 5 => Some(EventType::RequestComplete),
289 6 => Some(EventType::WebSocketFrame),
290 7 => Some(EventType::GuardrailInspect),
291 _ => None,
292 }
293}
294
295pub struct AgentClientV2Uds {
300 agent_id: String,
302 socket_path: String,
304 timeout: Duration,
306 capabilities: RwLock<Option<AgentCapabilities>>,
308 protocol_version: AtomicU64,
310 encoding: RwLock<UdsEncoding>,
312 pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
314 outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
316 ping_sequence: AtomicU64,
318 connected: RwLock<bool>,
320 flow_state: RwLock<FlowState>,
322 health_state: RwLock<i32>,
324 in_flight: AtomicU64,
326 metrics_callback: Option<MetricsCallback>,
328 config_update_callback: Option<ConfigUpdateCallback>,
330}
331
332impl AgentClientV2Uds {
333 pub async fn new(
335 agent_id: impl Into<String>,
336 socket_path: impl Into<String>,
337 timeout: Duration,
338 ) -> Result<Self, AgentProtocolError> {
339 let agent_id = agent_id.into();
340 let socket_path = socket_path.into();
341
342 debug!(
343 agent_id = %agent_id,
344 socket_path = %socket_path,
345 timeout_ms = timeout.as_millis(),
346 "Creating UDS v2 client"
347 );
348
349 Ok(Self {
350 agent_id,
351 socket_path,
352 timeout,
353 capabilities: RwLock::new(None),
354 protocol_version: AtomicU64::new(0),
355 encoding: RwLock::new(UdsEncoding::Json),
356 pending: Arc::new(Mutex::new(HashMap::new())),
357 outbound_tx: Mutex::new(None),
358 ping_sequence: AtomicU64::new(0),
359 connected: RwLock::new(false),
360 flow_state: RwLock::new(FlowState::Normal),
361 health_state: RwLock::new(1), in_flight: AtomicU64::new(0),
363 metrics_callback: None,
364 config_update_callback: None,
365 })
366 }
367
368 fn supported_encodings() -> Vec<UdsEncoding> {
372 #[cfg(feature = "binary-uds")]
373 {
374 vec![UdsEncoding::MessagePack, UdsEncoding::Json]
375 }
376 #[cfg(not(feature = "binary-uds"))]
377 {
378 vec![UdsEncoding::Json]
379 }
380 }
381
382 pub async fn encoding(&self) -> UdsEncoding {
384 *self.encoding.read().await
385 }
386
387 pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
389 self.metrics_callback = Some(callback);
390 }
391
392 pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
394 self.config_update_callback = Some(callback);
395 }
396
397 pub async fn connect(&self) -> Result<(), AgentProtocolError> {
399 info!(
400 agent_id = %self.agent_id,
401 socket_path = %self.socket_path,
402 "Connecting to agent via UDS v2"
403 );
404
405 let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
407 error!(
408 agent_id = %self.agent_id,
409 socket_path = %self.socket_path,
410 error = %e,
411 "Failed to connect to agent via UDS"
412 );
413 AgentProtocolError::ConnectionFailed(e.to_string())
414 })?;
415
416 let (read_half, write_half) = stream.into_split();
417 let mut reader = BufReader::new(read_half);
418 let mut writer = BufWriter::new(write_half);
419
420 let handshake_req = UdsHandshakeRequest {
422 supported_versions: vec![PROTOCOL_VERSION_2 as u32],
423 proxy_id: "sentinel-proxy".to_string(),
424 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
425 config: None,
426 supported_encodings: Self::supported_encodings(),
427 };
428
429 let payload = serde_json::to_vec(&handshake_req)
431 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
432
433 write_message(&mut writer, MessageType::HandshakeRequest, &payload).await?;
434
435 let (msg_type, response_bytes) = read_message(&mut reader).await?;
437
438 if msg_type != MessageType::HandshakeResponse {
439 return Err(AgentProtocolError::InvalidMessage(format!(
440 "Expected HandshakeResponse, got {:?}",
441 msg_type
442 )));
443 }
444
445 let response: UdsHandshakeResponse = serde_json::from_slice(&response_bytes)
446 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
447
448 if !response.success {
449 return Err(AgentProtocolError::ConnectionFailed(
450 response.error.unwrap_or_else(|| "Unknown handshake error".to_string()),
451 ));
452 }
453
454 let capabilities: AgentCapabilities = response.capabilities.into();
456 *self.capabilities.write().await = Some(capabilities);
457 self.protocol_version
458 .store(response.protocol_version as u64, Ordering::SeqCst);
459
460 let negotiated_encoding = response.encoding;
462 *self.encoding.write().await = negotiated_encoding;
463
464 info!(
465 agent_id = %self.agent_id,
466 protocol_version = response.protocol_version,
467 encoding = ?negotiated_encoding,
468 "UDS v2 handshake successful"
469 );
470
471 let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
473 *self.outbound_tx.lock().await = Some(tx);
474 *self.connected.write().await = true;
475
476 let agent_id_clone = self.agent_id.clone();
478 tokio::spawn(async move {
479 while let Some((msg_type, payload)) = rx.recv().await {
480 if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
481 error!(
482 agent_id = %agent_id_clone,
483 error = %e,
484 "Failed to write message to UDS"
485 );
486 break;
487 }
488 }
489 debug!(agent_id = %agent_id_clone, "UDS writer task ended");
490 });
491
492 let pending = Arc::clone(&self.pending);
494 let agent_id = self.agent_id.clone();
495 let flow_state = Arc::new(RwLock::new(FlowState::Normal));
496 let health_state = Arc::new(RwLock::new(1i32));
497 let flow_state_clone = Arc::clone(&flow_state);
498 let health_state_clone = Arc::clone(&health_state);
499 let metrics_callback = self.metrics_callback.clone();
500 let config_update_callback = self.config_update_callback.clone();
501 let reader_encoding = negotiated_encoding;
503
504 tokio::spawn(async move {
505 loop {
506 match read_message(&mut reader).await {
507 Ok((msg_type, payload)) => {
508 match msg_type {
509 MessageType::AgentResponse => {
510 match reader_encoding.deserialize::<AgentResponse>(&payload) {
511 Ok(response) => {
512 if let Some(sender) = pending.lock().await.remove(&response.audit.custom.get("correlation_id").and_then(|v| v.as_str()).unwrap_or("").to_string()) {
515 let _ = sender.send(response);
516 }
517 }
518 Err(e) => {
519 warn!(
520 agent_id = %agent_id,
521 error = %e,
522 encoding = ?reader_encoding,
523 "Failed to parse agent response"
524 );
525 }
526 }
527 }
528 MessageType::HealthStatus => {
529 #[derive(serde::Deserialize)]
531 struct HealthStatusMsg {
532 state: Option<i64>,
533 }
534 if let Ok(health) = reader_encoding.deserialize::<HealthStatusMsg>(&payload) {
535 if let Some(state) = health.state {
536 *health_state_clone.write().await = state as i32;
537 }
538 }
539 }
540 MessageType::MetricsReport => {
541 if let Some(ref callback) = metrics_callback {
542 if let Ok(report) = reader_encoding.deserialize(&payload) {
543 callback(report);
544 }
545 }
546 }
547 MessageType::FlowControl => {
548 #[derive(serde::Deserialize)]
549 struct FlowControlMsg {
550 action: Option<i64>,
551 }
552 if let Ok(fc) = reader_encoding.deserialize::<FlowControlMsg>(&payload) {
553 let action = fc.action.unwrap_or(0);
554 let new_state = match action {
555 1 => FlowState::Paused,
556 2 => FlowState::Normal,
557 _ => FlowState::Normal,
558 };
559 *flow_state_clone.write().await = new_state;
560 }
561 }
562 MessageType::ConfigUpdateRequest => {
563 if let Some(ref callback) = config_update_callback {
564 if let Ok(request) = reader_encoding.deserialize(&payload) {
565 let _response = callback(agent_id.clone(), request);
566 }
567 }
568 }
569 MessageType::Pong => {
570 trace!(agent_id = %agent_id, "Received pong");
571 }
572 _ => {
573 trace!(
574 agent_id = %agent_id,
575 msg_type = ?msg_type,
576 "Received unhandled message type"
577 );
578 }
579 }
580 }
581 Err(e) => {
582 if !matches!(e, AgentProtocolError::ConnectionClosed) {
583 error!(
584 agent_id = %agent_id,
585 error = %e,
586 "Error reading from UDS"
587 );
588 }
589 break;
590 }
591 }
592 }
593 debug!(agent_id = %agent_id, "UDS reader task ended");
594 });
595
596 Ok(())
597 }
598
599 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
601 self.capabilities.read().await.clone()
602 }
603
604 pub async fn is_connected(&self) -> bool {
606 *self.connected.read().await
607 }
608
609 pub async fn send_request_headers(
611 &self,
612 correlation_id: &str,
613 event: &crate::RequestHeadersEvent,
614 ) -> Result<AgentResponse, AgentProtocolError> {
615 self.send_event(MessageType::RequestHeaders, correlation_id, event).await
616 }
617
618 pub async fn send_request_body_chunk(
620 &self,
621 correlation_id: &str,
622 event: &crate::RequestBodyChunkEvent,
623 ) -> Result<AgentResponse, AgentProtocolError> {
624 self.send_event(MessageType::RequestBodyChunk, correlation_id, event).await
625 }
626
627 pub async fn send_response_headers(
629 &self,
630 correlation_id: &str,
631 event: &crate::ResponseHeadersEvent,
632 ) -> Result<AgentResponse, AgentProtocolError> {
633 self.send_event(MessageType::ResponseHeaders, correlation_id, event).await
634 }
635
636 pub async fn send_response_body_chunk(
638 &self,
639 correlation_id: &str,
640 event: &crate::ResponseBodyChunkEvent,
641 ) -> Result<AgentResponse, AgentProtocolError> {
642 self.send_event(MessageType::ResponseBodyChunk, correlation_id, event).await
643 }
644
645 pub async fn send_request_body_chunk_binary(
659 &self,
660 event: &crate::BinaryRequestBodyChunkEvent,
661 ) -> Result<AgentResponse, AgentProtocolError> {
662 let correlation_id = &event.correlation_id;
663 self.send_binary_body_chunk(
664 MessageType::RequestBodyChunk,
665 correlation_id,
666 &event.data,
667 event.is_last,
668 event.total_size,
669 event.chunk_index,
670 Some(event.bytes_received),
671 None,
672 ).await
673 }
674
675 pub async fn send_response_body_chunk_binary(
680 &self,
681 event: &crate::BinaryResponseBodyChunkEvent,
682 ) -> Result<AgentResponse, AgentProtocolError> {
683 let correlation_id = &event.correlation_id;
684 self.send_binary_body_chunk(
685 MessageType::ResponseBodyChunk,
686 correlation_id,
687 &event.data,
688 event.is_last,
689 event.total_size,
690 event.chunk_index,
691 None,
692 Some(event.bytes_sent),
693 ).await
694 }
695
696 #[allow(clippy::too_many_arguments)]
698 async fn send_binary_body_chunk(
699 &self,
700 msg_type: MessageType,
701 correlation_id: &str,
702 data: &bytes::Bytes,
703 is_last: bool,
704 total_size: Option<usize>,
705 chunk_index: u32,
706 bytes_received: Option<usize>,
707 bytes_sent: Option<usize>,
708 ) -> Result<AgentResponse, AgentProtocolError> {
709 let (tx, rx) = oneshot::channel();
711 self.pending
712 .lock()
713 .await
714 .insert(correlation_id.to_string(), tx);
715
716 let encoding = *self.encoding.read().await;
718
719 let payload_bytes = match encoding {
721 UdsEncoding::Json => {
722 use base64::{engine::general_purpose::STANDARD, Engine as _};
724 let json = serde_json::json!({
725 "correlation_id": correlation_id,
726 "data": STANDARD.encode(data),
727 "is_last": is_last,
728 "total_size": total_size,
729 "chunk_index": chunk_index,
730 "bytes_received": bytes_received,
731 "bytes_sent": bytes_sent,
732 });
733 serde_json::to_vec(&json)
734 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
735 }
736 UdsEncoding::MessagePack => {
737 #[derive(serde::Serialize)]
739 struct BinaryBodyChunk<'a> {
740 correlation_id: &'a str,
741 #[serde(with = "serde_bytes")]
742 data: &'a [u8],
743 is_last: bool,
744 #[serde(skip_serializing_if = "Option::is_none")]
745 total_size: Option<usize>,
746 chunk_index: u32,
747 #[serde(skip_serializing_if = "Option::is_none")]
748 bytes_received: Option<usize>,
749 #[serde(skip_serializing_if = "Option::is_none")]
750 bytes_sent: Option<usize>,
751 }
752 let chunk = BinaryBodyChunk {
753 correlation_id,
754 data: data.as_ref(),
755 is_last,
756 total_size,
757 chunk_index,
758 bytes_received,
759 bytes_sent,
760 };
761 encoding.serialize(&chunk)?
762 }
763 };
764
765 {
767 let outbound = self.outbound_tx.lock().await;
768 if let Some(tx) = outbound.as_ref() {
769 tx.send((msg_type, payload_bytes))
770 .await
771 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
772 } else {
773 return Err(AgentProtocolError::ConnectionClosed);
774 }
775 }
776
777 self.in_flight.fetch_add(1, Ordering::Relaxed);
778
779 let response = tokio::time::timeout(self.timeout, rx)
781 .await
782 .map_err(|_| {
783 self.pending.try_lock().ok().map(|mut p| p.remove(correlation_id));
784 AgentProtocolError::Timeout(self.timeout)
785 })?
786 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
787
788 self.in_flight.fetch_sub(1, Ordering::Relaxed);
789
790 Ok(response)
791 }
792
793 async fn send_event<T: serde::Serialize>(
795 &self,
796 msg_type: MessageType,
797 correlation_id: &str,
798 event: &T,
799 ) -> Result<AgentResponse, AgentProtocolError> {
800 let (tx, rx) = oneshot::channel();
802 self.pending
803 .lock()
804 .await
805 .insert(correlation_id.to_string(), tx);
806
807 let encoding = *self.encoding.read().await;
809
810 let payload_bytes = match encoding {
812 UdsEncoding::Json => {
813 let mut payload = serde_json::to_value(event)
815 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
816 if let Some(obj) = payload.as_object_mut() {
817 obj.insert(
818 "correlation_id".to_string(),
819 serde_json::Value::String(correlation_id.to_string()),
820 );
821 }
822 serde_json::to_vec(&payload)
823 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
824 }
825 UdsEncoding::MessagePack => {
826 #[derive(serde::Serialize)]
828 struct EventWithCorrelation<'a, T: serde::Serialize> {
829 correlation_id: &'a str,
830 #[serde(flatten)]
831 event: &'a T,
832 }
833 let wrapped = EventWithCorrelation {
834 correlation_id,
835 event,
836 };
837 encoding.serialize(&wrapped)?
838 }
839 };
840
841 {
843 let outbound = self.outbound_tx.lock().await;
844 if let Some(tx) = outbound.as_ref() {
845 tx.send((msg_type, payload_bytes))
846 .await
847 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
848 } else {
849 return Err(AgentProtocolError::ConnectionClosed);
850 }
851 }
852
853 self.in_flight.fetch_add(1, Ordering::Relaxed);
854
855 let response = tokio::time::timeout(self.timeout, rx)
857 .await
858 .map_err(|_| {
859 self.pending.try_lock().ok().map(|mut p| p.remove(correlation_id));
860 AgentProtocolError::Timeout(self.timeout)
861 })?
862 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
863
864 self.in_flight.fetch_sub(1, Ordering::Relaxed);
865
866 Ok(response)
867 }
868
869 pub async fn cancel_request(
871 &self,
872 correlation_id: &str,
873 reason: super::client::CancelReason,
874 ) -> Result<(), AgentProtocolError> {
875 let cancel = serde_json::json!({
876 "correlation_id": correlation_id,
877 "reason": reason as i32,
878 "timestamp_ms": now_ms(),
879 });
880
881 let payload = serde_json::to_vec(&cancel)
882 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
883
884 let outbound = self.outbound_tx.lock().await;
885 if let Some(tx) = outbound.as_ref() {
886 tx.send((MessageType::Cancel, payload))
887 .await
888 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
889 }
890
891 self.pending.lock().await.remove(correlation_id);
893
894 Ok(())
895 }
896
897 pub async fn cancel_all(
899 &self,
900 reason: super::client::CancelReason,
901 ) -> Result<usize, AgentProtocolError> {
902 let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
903 let count = pending_ids.len();
904
905 for correlation_id in pending_ids {
906 let _ = self.cancel_request(&correlation_id, reason).await;
907 }
908
909 Ok(count)
910 }
911
912 pub async fn ping(&self) -> Result<(), AgentProtocolError> {
914 let seq = self.ping_sequence.fetch_add(1, Ordering::Relaxed);
915 let ping = serde_json::json!({
916 "sequence": seq,
917 "timestamp_ms": now_ms(),
918 });
919
920 let payload = serde_json::to_vec(&ping)
921 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
922
923 let outbound = self.outbound_tx.lock().await;
924 if let Some(tx) = outbound.as_ref() {
925 tx.send((MessageType::Ping, payload))
926 .await
927 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
928 }
929
930 Ok(())
931 }
932
933 pub async fn close(&self) -> Result<(), AgentProtocolError> {
935 *self.connected.write().await = false;
936 *self.outbound_tx.lock().await = None;
937 Ok(())
938 }
939
940 pub fn in_flight(&self) -> u64 {
942 self.in_flight.load(Ordering::Relaxed)
943 }
944
945 pub fn agent_id(&self) -> &str {
947 &self.agent_id
948 }
949
950 pub async fn is_paused(&self) -> bool {
955 matches!(*self.flow_state.read().await, FlowState::Paused)
956 }
957
958 pub async fn can_accept_requests(&self) -> bool {
962 !self.is_paused().await
963 }
964}
965
966pub async fn write_message<W: AsyncWriteExt + Unpin>(
968 writer: &mut W,
969 msg_type: MessageType,
970 payload: &[u8],
971) -> Result<(), AgentProtocolError> {
972 if payload.len() > MAX_UDS_MESSAGE_SIZE {
973 return Err(AgentProtocolError::MessageTooLarge {
974 size: payload.len(),
975 max: MAX_UDS_MESSAGE_SIZE,
976 });
977 }
978
979 let total_len = (payload.len() + 1) as u32;
981 writer.write_all(&total_len.to_be_bytes()).await?;
982
983 writer.write_all(&[msg_type as u8]).await?;
985
986 writer.write_all(payload).await?;
988 writer.flush().await?;
989
990 Ok(())
991}
992
993pub async fn read_message<R: AsyncReadExt + Unpin>(
995 reader: &mut R,
996) -> Result<(MessageType, Vec<u8>), AgentProtocolError> {
997 let mut len_bytes = [0u8; 4];
999 match reader.read_exact(&mut len_bytes).await {
1000 Ok(_) => {}
1001 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
1002 return Err(AgentProtocolError::ConnectionClosed);
1003 }
1004 Err(e) => return Err(e.into()),
1005 }
1006
1007 let total_len = u32::from_be_bytes(len_bytes) as usize;
1008
1009 if total_len == 0 {
1010 return Err(AgentProtocolError::InvalidMessage(
1011 "Zero-length message".to_string(),
1012 ));
1013 }
1014
1015 if total_len > MAX_UDS_MESSAGE_SIZE {
1016 return Err(AgentProtocolError::MessageTooLarge {
1017 size: total_len,
1018 max: MAX_UDS_MESSAGE_SIZE,
1019 });
1020 }
1021
1022 let mut type_byte = [0u8; 1];
1024 reader.read_exact(&mut type_byte).await?;
1025 let msg_type = MessageType::try_from(type_byte[0])?;
1026
1027 let payload_len = total_len - 1;
1029 let mut payload = vec![0u8; payload_len];
1030 if payload_len > 0 {
1031 reader.read_exact(&mut payload).await?;
1032 }
1033
1034 Ok((msg_type, payload))
1035}
1036
/// Current wall-clock time in milliseconds since the Unix epoch, or 0 if the
/// system clock reads earlier than the epoch.
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
1043
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentProtocolError;

    /// Every `MessageType` variant, so the wire table is covered exhaustively.
    const ALL_MESSAGE_TYPES: [MessageType; 18] = [
        MessageType::HandshakeRequest,
        MessageType::HandshakeResponse,
        MessageType::RequestHeaders,
        MessageType::RequestBodyChunk,
        MessageType::ResponseHeaders,
        MessageType::ResponseBodyChunk,
        MessageType::RequestComplete,
        MessageType::WebSocketFrame,
        MessageType::GuardrailInspect,
        MessageType::Configure,
        MessageType::AgentResponse,
        MessageType::HealthStatus,
        MessageType::MetricsReport,
        MessageType::ConfigUpdateRequest,
        MessageType::FlowControl,
        MessageType::Cancel,
        MessageType::Ping,
        MessageType::Pong,
    ];

    #[test]
    fn test_message_type_roundtrip() {
        for msg_type in ALL_MESSAGE_TYPES {
            let byte = msg_type as u8;
            let parsed = MessageType::try_from(byte).unwrap();
            assert_eq!(parsed, msg_type);
        }
    }

    #[test]
    fn test_invalid_message_type() {
        assert!(MessageType::try_from(0xFF).is_err());
        assert!(MessageType::try_from(0x00).is_err());
    }

    #[test]
    fn test_only_known_bytes_decode() {
        for byte in 0..=u8::MAX {
            let known = ALL_MESSAGE_TYPES.iter().any(|t| *t as u8 == byte);
            assert_eq!(MessageType::try_from(byte).is_ok(), known, "byte 0x{:02x}", byte);
        }
    }

    #[test]
    fn test_handshake_serialization() {
        let req = UdsHandshakeRequest {
            supported_versions: vec![2],
            proxy_id: "test-proxy".to_string(),
            proxy_version: "1.0.0".to_string(),
            config: None,
            supported_encodings: vec![],
        };

        let json = serde_json::to_string(&req).unwrap();
        let parsed: UdsHandshakeRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.supported_versions, vec![2]);
        assert_eq!(parsed.proxy_id, "test-proxy");
    }

    #[test]
    fn test_handshake_response_encoding_defaults_to_json() {
        let json = serde_json::json!({
            "protocol_version": 2,
            "capabilities": {
                "agent_id": "agent",
                "name": "name",
                "version": "1.0.0",
                "supported_events": [1, 2],
                "features": serde_json::to_value(UdsFeatures::default()).unwrap(),
                "limits": serde_json::to_value(UdsLimits::default()).unwrap(),
            },
            "success": true,
            "error": null,
        });

        let parsed: UdsHandshakeResponse = serde_json::from_value(json).unwrap();
        assert_eq!(parsed.encoding, UdsEncoding::Json);
    }

    #[test]
    fn test_uds_encoding_wire_names() {
        assert_eq!(serde_json::to_string(&UdsEncoding::Json).unwrap(), "\"json\"");
        assert_eq!(
            serde_json::to_string(&UdsEncoding::MessagePack).unwrap(),
            "\"msgpack\""
        );
        let parsed: UdsEncoding = serde_json::from_str("\"msgpack\"").unwrap();
        assert_eq!(parsed, UdsEncoding::MessagePack);
    }

    #[tokio::test]
    async fn test_write_read_message() {
        use tokio::io::duplex;

        let (mut client, mut server) = duplex(1024);

        let payload = b"test payload";
        write_message(&mut client, MessageType::Ping, payload)
            .await
            .unwrap();

        let (msg_type, data) = read_message(&mut server).await.unwrap();
        assert_eq!(msg_type, MessageType::Ping);
        assert_eq!(data, payload);
    }

    #[tokio::test]
    async fn test_read_message_rejects_zero_length() {
        use tokio::io::{duplex, AsyncWriteExt};

        let (mut client, mut server) = duplex(64);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();

        let err = read_message(&mut server).await.unwrap_err();
        assert!(matches!(err, AgentProtocolError::InvalidMessage(_)));
    }

    #[tokio::test]
    async fn test_read_message_rejects_oversized_frame() {
        use tokio::io::{duplex, AsyncWriteExt};

        let (mut client, mut server) = duplex(64);
        let len = (MAX_UDS_MESSAGE_SIZE + 1) as u32;
        client.write_all(&len.to_be_bytes()).await.unwrap();

        let err = read_message(&mut server).await.unwrap_err();
        assert!(matches!(err, AgentProtocolError::MessageTooLarge { .. }));
    }

    #[tokio::test]
    async fn test_read_message_reports_clean_eof() {
        use tokio::io::duplex;

        let (client, mut server) = duplex(64);
        drop(client);

        let err = read_message(&mut server).await.unwrap_err();
        assert!(matches!(err, AgentProtocolError::ConnectionClosed));
    }

    #[test]
    fn test_binary_body_chunk_json_serialization() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let data = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let correlation_id = "test-123";

        let json = serde_json::json!({
            "correlation_id": correlation_id,
            "data": STANDARD.encode(&data),
            "is_last": true,
            "total_size": 100usize,
            "chunk_index": 0u32,
            "bytes_received": 100usize,
        });

        let serialized = serde_json::to_vec(&json).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&serialized).unwrap();

        let data_field = parsed["data"].as_str().unwrap();
        let decoded = STANDARD.decode(data_field).unwrap();
        assert_eq!(decoded, data.as_ref());
    }

    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_binary_body_chunk_msgpack_serialization() {
        let data = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let correlation_id = "test-123";

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct BinaryBodyChunk {
            correlation_id: String,
            #[serde(with = "serde_bytes")]
            data: Vec<u8>,
            is_last: bool,
            chunk_index: u32,
        }

        let chunk = BinaryBodyChunk {
            correlation_id: correlation_id.to_string(),
            data: data.to_vec(),
            is_last: true,
            chunk_index: 0,
        };

        let serialized = rmp_serde::to_vec(&chunk).unwrap();

        let parsed: BinaryBodyChunk = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(parsed.correlation_id, correlation_id);
        assert_eq!(parsed.data, data.as_ref());
        assert!(parsed.is_last);

        // Raw bytes in MessagePack should beat base64 inside JSON.
        use base64::Engine as _;
        let json_size = serde_json::to_vec(&serde_json::json!({
            "correlation_id": correlation_id,
            "data": base64::engine::general_purpose::STANDARD.encode(&data),
            "is_last": true,
            "chunk_index": 0u32,
        }))
        .unwrap()
        .len();

        assert!(
            serialized.len() < json_size,
            "MessagePack ({}) should be smaller than JSON+base64 ({})",
            serialized.len(),
            json_size
        );
    }

    #[test]
    fn test_uds_encoding_default() {
        assert_eq!(UdsEncoding::default(), UdsEncoding::Json);
    }

    #[test]
    fn test_uds_encoding_serialize_json() {
        let encoding = UdsEncoding::Json;
        let value = serde_json::json!({"key": "value"});
        let serialized = encoding.serialize(&value).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&serialized).unwrap();
        assert_eq!(parsed, value);
    }

    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_uds_encoding_serialize_msgpack() {
        let encoding = UdsEncoding::MessagePack;
        let value = serde_json::json!({"key": "value"});
        let serialized = encoding.serialize(&value).unwrap();
        let parsed: serde_json::Value = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(parsed, value);
    }
}