1use std::collections::HashMap;
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use std::time::Duration;
42
43use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
44use tokio::net::UnixStream;
45use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
46use tracing::{debug, error, info, trace, warn};
47
48use crate::v2::pool::CHANNEL_BUFFER_SIZE;
49use crate::v2::{AgentCapabilities, AgentFeatures, AgentLimits, HealthConfig, PROTOCOL_VERSION_2};
50use crate::{AgentProtocolError, AgentResponse, EventType};
51
52use super::client::{ConfigUpdateCallback, FlowState, MetricsCallback};
53
/// Upper bound, in bytes, on UDS frame sizes checked by `write_message` and `read_message`
/// (16 MiB).
pub const MAX_UDS_MESSAGE_SIZE: usize = 16 << 20;
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
62#[serde(rename_all = "lowercase")]
63pub enum UdsEncoding {
64 #[default]
66 Json,
67 #[serde(rename = "msgpack")]
69 MessagePack,
70}
71
72impl UdsEncoding {
73 #[inline]
77 pub fn serialize<T: serde::Serialize>(&self, value: &T) -> Result<Vec<u8>, AgentProtocolError> {
78 match self {
79 UdsEncoding::Json => serde_json::to_vec(value)
80 .map_err(|e| AgentProtocolError::Serialization(e.to_string())),
81 #[cfg(feature = "binary-uds")]
82 UdsEncoding::MessagePack => rmp_serde::to_vec(value)
83 .map_err(|e| AgentProtocolError::Serialization(e.to_string())),
84 #[cfg(not(feature = "binary-uds"))]
85 UdsEncoding::MessagePack => {
86 serde_json::to_vec(value)
88 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
89 }
90 }
91 }
92
93 #[inline]
97 pub fn deserialize<'a, T: serde::Deserialize<'a>>(
98 &self,
99 bytes: &'a [u8],
100 ) -> Result<T, AgentProtocolError> {
101 match self {
102 UdsEncoding::Json => serde_json::from_slice(bytes)
103 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string())),
104 #[cfg(feature = "binary-uds")]
105 UdsEncoding::MessagePack => rmp_serde::from_slice(bytes)
106 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string())),
107 #[cfg(not(feature = "binary-uds"))]
108 UdsEncoding::MessagePack => {
109 serde_json::from_slice(bytes)
111 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
112 }
113 }
114 }
115}
116
117#[repr(u8)]
119#[derive(Debug, Clone, Copy, PartialEq, Eq)]
120pub enum MessageType {
121 HandshakeRequest = 0x01,
123 HandshakeResponse = 0x02,
124
125 RequestHeaders = 0x10,
127 RequestBodyChunk = 0x11,
128 ResponseHeaders = 0x12,
129 ResponseBodyChunk = 0x13,
130 RequestComplete = 0x14,
131 WebSocketFrame = 0x15,
132 GuardrailInspect = 0x16,
133 Configure = 0x17,
134
135 AgentResponse = 0x20,
137
138 HealthStatus = 0x30,
140 MetricsReport = 0x31,
141 ConfigUpdateRequest = 0x32,
142 FlowControl = 0x33,
143
144 Cancel = 0x40,
146 Ping = 0x41,
147 Pong = 0x42,
148}
149
150impl TryFrom<u8> for MessageType {
151 type Error = AgentProtocolError;
152
153 fn try_from(value: u8) -> Result<Self, Self::Error> {
154 match value {
155 0x01 => Ok(MessageType::HandshakeRequest),
156 0x02 => Ok(MessageType::HandshakeResponse),
157 0x10 => Ok(MessageType::RequestHeaders),
158 0x11 => Ok(MessageType::RequestBodyChunk),
159 0x12 => Ok(MessageType::ResponseHeaders),
160 0x13 => Ok(MessageType::ResponseBodyChunk),
161 0x14 => Ok(MessageType::RequestComplete),
162 0x15 => Ok(MessageType::WebSocketFrame),
163 0x16 => Ok(MessageType::GuardrailInspect),
164 0x17 => Ok(MessageType::Configure),
165 0x20 => Ok(MessageType::AgentResponse),
166 0x30 => Ok(MessageType::HealthStatus),
167 0x31 => Ok(MessageType::MetricsReport),
168 0x32 => Ok(MessageType::ConfigUpdateRequest),
169 0x33 => Ok(MessageType::FlowControl),
170 0x40 => Ok(MessageType::Cancel),
171 0x41 => Ok(MessageType::Ping),
172 0x42 => Ok(MessageType::Pong),
173 _ => Err(AgentProtocolError::InvalidMessage(format!(
174 "Unknown message type: 0x{:02x}",
175 value
176 ))),
177 }
178 }
179}
180
/// Handshake request: the first frame the client writes on a new UDS connection.
///
/// `AgentClientV2Uds::connect` always encodes this message as JSON, because no payload
/// encoding has been negotiated at that point.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UdsHandshakeRequest {
    /// Protocol versions offered by the proxy (the client sends `[PROTOCOL_VERSION_2]`).
    pub supported_versions: Vec<u32>,
    /// Identifier of the proxy (the client sends `"sentinel-proxy"`).
    pub proxy_id: String,
    /// Proxy version string (the client sends its crate version).
    pub proxy_version: String,
    /// Optional configuration for the agent (the client currently sends `None`).
    pub config: Option<serde_json::Value>,
    /// Payload encodings the proxy can use after the handshake.
    ///
    /// Omitted from the serialized form when empty, and parsed as empty when absent.
    /// NOTE(review): presumably listed in preference order — confirm against the agent side.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_encodings: Vec<UdsEncoding>,
}
193
/// Handshake response: the first frame read back from the agent.
///
/// `AgentClientV2Uds::connect` decodes it as JSON and fails the connection unless
/// `success` is true.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UdsHandshakeResponse {
    /// Protocol version reported by the agent; stored by the client after connecting.
    pub protocol_version: u32,
    /// Agent capabilities, converted into `AgentCapabilities` by the client.
    pub capabilities: UdsCapabilities,
    /// Whether the agent accepted the handshake.
    pub success: bool,
    /// Rejection reason when `success` is false. The client substitutes
    /// "Unknown handshake error" when this is `None`.
    pub error: Option<String>,
    /// Encoding the client uses for event payloads and for decoding agent messages after
    /// the handshake. Falls back to `UdsEncoding::default()` (JSON) when the field is
    /// missing from the response.
    #[serde(default)]
    pub encoding: UdsEncoding,
}
206
/// Agent capabilities as carried in the handshake response.
///
/// Converted into the transport-independent `AgentCapabilities` through
/// `From<UdsCapabilities> for AgentCapabilities`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UdsCapabilities {
    /// Agent identifier.
    pub agent_id: String,
    /// Human-readable agent name.
    pub name: String,
    /// Agent version string.
    pub version: String,
    /// Supported event types as integer codes. Codes that `event_type_from_i32` does not
    /// recognise are dropped during conversion.
    pub supported_events: Vec<i32>,
    /// Feature flags advertised by the agent.
    pub features: UdsFeatures,
    /// Resource limits advertised by the agent.
    pub limits: UdsLimits,
}
217
/// Feature flags advertised by the agent in the handshake.
///
/// Copied field-for-field into `AgentFeatures` when capabilities are converted.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct UdsFeatures {
    pub streaming_body: bool,
    pub websocket: bool,
    pub guardrails: bool,
    pub config_push: bool,
    pub metrics_export: bool,
    /// A count rather than a flag. NOTE(review): presumably the number of requests the
    /// agent handles concurrently — confirm.
    pub concurrent_requests: u32,
    pub cancellation: bool,
    pub flow_control: bool,
    pub health_reporting: bool,
}
231
/// Resource limits advertised by the agent in the handshake.
///
/// Sizes are `u64` here and converted to `usize` in `AgentLimits`.
/// NOTE(review): sizes are presumably in bytes — confirm against the agent protocol.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct UdsLimits {
    pub max_body_size: u64,
    pub max_concurrency: u32,
    pub preferred_chunk_size: u64,
}
239
240impl From<UdsCapabilities> for AgentCapabilities {
241 fn from(caps: UdsCapabilities) -> Self {
242 AgentCapabilities {
243 protocol_version: PROTOCOL_VERSION_2,
244 agent_id: caps.agent_id,
245 name: caps.name,
246 version: caps.version,
247 supported_events: caps
248 .supported_events
249 .into_iter()
250 .filter_map(event_type_from_i32)
251 .collect(),
252 features: AgentFeatures {
253 streaming_body: caps.features.streaming_body,
254 websocket: caps.features.websocket,
255 guardrails: caps.features.guardrails,
256 config_push: caps.features.config_push,
257 metrics_export: caps.features.metrics_export,
258 concurrent_requests: caps.features.concurrent_requests,
259 cancellation: caps.features.cancellation,
260 flow_control: caps.features.flow_control,
261 health_reporting: caps.features.health_reporting,
262 },
263 limits: AgentLimits {
264 max_body_size: caps.limits.max_body_size as usize,
265 max_concurrency: caps.limits.max_concurrency,
266 preferred_chunk_size: caps.limits.preferred_chunk_size as usize,
267 max_memory: None,
268 max_processing_time_ms: None,
269 },
270 health: HealthConfig::default(),
271 }
272 }
273}
274
275fn event_type_from_i32(value: i32) -> Option<EventType> {
277 match value {
278 0 => Some(EventType::Configure),
279 1 => Some(EventType::RequestHeaders),
280 2 => Some(EventType::RequestBodyChunk),
281 3 => Some(EventType::ResponseHeaders),
282 4 => Some(EventType::ResponseBodyChunk),
283 5 => Some(EventType::RequestComplete),
284 6 => Some(EventType::WebSocketFrame),
285 7 => Some(EventType::GuardrailInspect),
286 _ => None,
287 }
288}
289
290pub struct AgentClientV2Uds {
295 agent_id: String,
297 socket_path: String,
299 timeout: Duration,
301 capabilities: RwLock<Option<AgentCapabilities>>,
303 protocol_version: AtomicU64,
305 encoding: RwLock<UdsEncoding>,
307 pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
309 #[allow(clippy::type_complexity)]
311 outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
312 ping_sequence: AtomicU64,
314 connected: RwLock<bool>,
316 flow_state: RwLock<FlowState>,
318 health_state: RwLock<i32>,
320 in_flight: AtomicU64,
322 metrics_callback: Option<MetricsCallback>,
324 config_update_callback: Option<ConfigUpdateCallback>,
326}
327
328impl AgentClientV2Uds {
329 pub async fn new(
331 agent_id: impl Into<String>,
332 socket_path: impl Into<String>,
333 timeout: Duration,
334 ) -> Result<Self, AgentProtocolError> {
335 let agent_id = agent_id.into();
336 let socket_path = socket_path.into();
337
338 debug!(
339 agent_id = %agent_id,
340 socket_path = %socket_path,
341 timeout_ms = timeout.as_millis(),
342 "Creating UDS v2 client"
343 );
344
345 Ok(Self {
346 agent_id,
347 socket_path,
348 timeout,
349 capabilities: RwLock::new(None),
350 protocol_version: AtomicU64::new(0),
351 encoding: RwLock::new(UdsEncoding::Json),
352 pending: Arc::new(Mutex::new(HashMap::new())),
353 outbound_tx: Mutex::new(None),
354 ping_sequence: AtomicU64::new(0),
355 connected: RwLock::new(false),
356 flow_state: RwLock::new(FlowState::Normal),
357 health_state: RwLock::new(1), in_flight: AtomicU64::new(0),
359 metrics_callback: None,
360 config_update_callback: None,
361 })
362 }
363
364 fn supported_encodings() -> Vec<UdsEncoding> {
368 #[cfg(feature = "binary-uds")]
369 {
370 vec![UdsEncoding::MessagePack, UdsEncoding::Json]
371 }
372 #[cfg(not(feature = "binary-uds"))]
373 {
374 vec![UdsEncoding::Json]
375 }
376 }
377
378 pub async fn encoding(&self) -> UdsEncoding {
380 *self.encoding.read().await
381 }
382
383 pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
385 self.metrics_callback = Some(callback);
386 }
387
388 pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
390 self.config_update_callback = Some(callback);
391 }
392
393 pub async fn connect(&self) -> Result<(), AgentProtocolError> {
395 info!(
396 agent_id = %self.agent_id,
397 socket_path = %self.socket_path,
398 "Connecting to agent via UDS v2"
399 );
400
401 let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
403 error!(
404 agent_id = %self.agent_id,
405 socket_path = %self.socket_path,
406 error = %e,
407 "Failed to connect to agent via UDS"
408 );
409 AgentProtocolError::ConnectionFailed(e.to_string())
410 })?;
411
412 let (read_half, write_half) = stream.into_split();
413 let mut reader = BufReader::new(read_half);
414 let mut writer = BufWriter::new(write_half);
415
416 let handshake_req = UdsHandshakeRequest {
418 supported_versions: vec![PROTOCOL_VERSION_2],
419 proxy_id: "sentinel-proxy".to_string(),
420 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
421 config: None,
422 supported_encodings: Self::supported_encodings(),
423 };
424
425 let payload = serde_json::to_vec(&handshake_req)
427 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
428
429 write_message(&mut writer, MessageType::HandshakeRequest, &payload).await?;
430
431 let (msg_type, response_bytes) = read_message(&mut reader).await?;
433
434 if msg_type != MessageType::HandshakeResponse {
435 return Err(AgentProtocolError::InvalidMessage(format!(
436 "Expected HandshakeResponse, got {:?}",
437 msg_type
438 )));
439 }
440
441 let response: UdsHandshakeResponse = serde_json::from_slice(&response_bytes)
442 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
443
444 if !response.success {
445 return Err(AgentProtocolError::ConnectionFailed(
446 response
447 .error
448 .unwrap_or_else(|| "Unknown handshake error".to_string()),
449 ));
450 }
451
452 let capabilities: AgentCapabilities = response.capabilities.into();
454 *self.capabilities.write().await = Some(capabilities);
455 self.protocol_version
456 .store(response.protocol_version as u64, Ordering::SeqCst);
457
458 let negotiated_encoding = response.encoding;
460 *self.encoding.write().await = negotiated_encoding;
461
462 info!(
463 agent_id = %self.agent_id,
464 protocol_version = response.protocol_version,
465 encoding = ?negotiated_encoding,
466 "UDS v2 handshake successful"
467 );
468
469 let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
471 *self.outbound_tx.lock().await = Some(tx);
472 *self.connected.write().await = true;
473
474 let agent_id_clone = self.agent_id.clone();
476 tokio::spawn(async move {
477 while let Some((msg_type, payload)) = rx.recv().await {
478 if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
479 error!(
480 agent_id = %agent_id_clone,
481 error = %e,
482 "Failed to write message to UDS"
483 );
484 break;
485 }
486 }
487 debug!(agent_id = %agent_id_clone, "UDS writer task ended");
488 });
489
490 let pending = Arc::clone(&self.pending);
492 let agent_id = self.agent_id.clone();
493 let flow_state = Arc::new(RwLock::new(FlowState::Normal));
494 let health_state = Arc::new(RwLock::new(1i32));
495 let flow_state_clone = Arc::clone(&flow_state);
496 let health_state_clone = Arc::clone(&health_state);
497 let metrics_callback = self.metrics_callback.clone();
498 let config_update_callback = self.config_update_callback.clone();
499 let reader_encoding = negotiated_encoding;
501
502 tokio::spawn(async move {
503 loop {
504 match read_message(&mut reader).await {
505 Ok((msg_type, payload)) => {
506 match msg_type {
507 MessageType::AgentResponse => {
508 match reader_encoding.deserialize::<AgentResponse>(&payload) {
509 Ok(response) => {
510 if let Some(sender) = pending.lock().await.remove(
513 &response
514 .audit
515 .custom
516 .get("correlation_id")
517 .and_then(|v| v.as_str())
518 .unwrap_or("")
519 .to_string(),
520 ) {
521 let _ = sender.send(response);
522 }
523 }
524 Err(e) => {
525 warn!(
526 agent_id = %agent_id,
527 error = %e,
528 encoding = ?reader_encoding,
529 "Failed to parse agent response"
530 );
531 }
532 }
533 }
534 MessageType::HealthStatus => {
535 #[derive(serde::Deserialize)]
537 struct HealthStatusMsg {
538 state: Option<i64>,
539 }
540 if let Ok(health) =
541 reader_encoding.deserialize::<HealthStatusMsg>(&payload)
542 {
543 if let Some(state) = health.state {
544 *health_state_clone.write().await = state as i32;
545 }
546 }
547 }
548 MessageType::MetricsReport => {
549 if let Some(ref callback) = metrics_callback {
550 if let Ok(report) = reader_encoding.deserialize(&payload) {
551 callback(report);
552 }
553 }
554 }
555 MessageType::FlowControl => {
556 #[derive(serde::Deserialize)]
557 struct FlowControlMsg {
558 action: Option<i64>,
559 }
560 if let Ok(fc) =
561 reader_encoding.deserialize::<FlowControlMsg>(&payload)
562 {
563 let action = fc.action.unwrap_or(0);
564 let new_state = match action {
565 1 => FlowState::Paused,
566 2 => FlowState::Normal,
567 _ => FlowState::Normal,
568 };
569 *flow_state_clone.write().await = new_state;
570 }
571 }
572 MessageType::ConfigUpdateRequest => {
573 if let Some(ref callback) = config_update_callback {
574 if let Ok(request) = reader_encoding.deserialize(&payload) {
575 let _response = callback(agent_id.clone(), request);
576 }
577 }
578 }
579 MessageType::Pong => {
580 trace!(agent_id = %agent_id, "Received pong");
581 }
582 _ => {
583 trace!(
584 agent_id = %agent_id,
585 msg_type = ?msg_type,
586 "Received unhandled message type"
587 );
588 }
589 }
590 }
591 Err(e) => {
592 if !matches!(e, AgentProtocolError::ConnectionClosed) {
593 error!(
594 agent_id = %agent_id,
595 error = %e,
596 "Error reading from UDS"
597 );
598 }
599 break;
600 }
601 }
602 }
603 debug!(agent_id = %agent_id, "UDS reader task ended");
604 });
605
606 Ok(())
607 }
608
609 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
611 self.capabilities.read().await.clone()
612 }
613
614 pub async fn is_connected(&self) -> bool {
616 *self.connected.read().await
617 }
618
619 pub async fn send_request_headers(
621 &self,
622 correlation_id: &str,
623 event: &crate::RequestHeadersEvent,
624 ) -> Result<AgentResponse, AgentProtocolError> {
625 self.send_event(MessageType::RequestHeaders, correlation_id, event)
626 .await
627 }
628
629 pub async fn send_request_body_chunk(
631 &self,
632 correlation_id: &str,
633 event: &crate::RequestBodyChunkEvent,
634 ) -> Result<AgentResponse, AgentProtocolError> {
635 self.send_event(MessageType::RequestBodyChunk, correlation_id, event)
636 .await
637 }
638
639 pub async fn send_response_headers(
641 &self,
642 correlation_id: &str,
643 event: &crate::ResponseHeadersEvent,
644 ) -> Result<AgentResponse, AgentProtocolError> {
645 self.send_event(MessageType::ResponseHeaders, correlation_id, event)
646 .await
647 }
648
649 pub async fn send_response_body_chunk(
651 &self,
652 correlation_id: &str,
653 event: &crate::ResponseBodyChunkEvent,
654 ) -> Result<AgentResponse, AgentProtocolError> {
655 self.send_event(MessageType::ResponseBodyChunk, correlation_id, event)
656 .await
657 }
658
659 pub async fn send_request_body_chunk_binary(
673 &self,
674 event: &crate::BinaryRequestBodyChunkEvent,
675 ) -> Result<AgentResponse, AgentProtocolError> {
676 let correlation_id = &event.correlation_id;
677 self.send_binary_body_chunk(
678 MessageType::RequestBodyChunk,
679 correlation_id,
680 &event.data,
681 event.is_last,
682 event.total_size,
683 event.chunk_index,
684 Some(event.bytes_received),
685 None,
686 )
687 .await
688 }
689
690 pub async fn send_response_body_chunk_binary(
695 &self,
696 event: &crate::BinaryResponseBodyChunkEvent,
697 ) -> Result<AgentResponse, AgentProtocolError> {
698 let correlation_id = &event.correlation_id;
699 self.send_binary_body_chunk(
700 MessageType::ResponseBodyChunk,
701 correlation_id,
702 &event.data,
703 event.is_last,
704 event.total_size,
705 event.chunk_index,
706 None,
707 Some(event.bytes_sent),
708 )
709 .await
710 }
711
712 #[allow(clippy::too_many_arguments)]
714 async fn send_binary_body_chunk(
715 &self,
716 msg_type: MessageType,
717 correlation_id: &str,
718 data: &bytes::Bytes,
719 is_last: bool,
720 total_size: Option<usize>,
721 chunk_index: u32,
722 bytes_received: Option<usize>,
723 bytes_sent: Option<usize>,
724 ) -> Result<AgentResponse, AgentProtocolError> {
725 let (tx, rx) = oneshot::channel();
727 self.pending
728 .lock()
729 .await
730 .insert(correlation_id.to_string(), tx);
731
732 let encoding = *self.encoding.read().await;
734
735 let payload_bytes = match encoding {
737 UdsEncoding::Json => {
738 use base64::{engine::general_purpose::STANDARD, Engine as _};
740 let json = serde_json::json!({
741 "correlation_id": correlation_id,
742 "data": STANDARD.encode(data),
743 "is_last": is_last,
744 "total_size": total_size,
745 "chunk_index": chunk_index,
746 "bytes_received": bytes_received,
747 "bytes_sent": bytes_sent,
748 });
749 serde_json::to_vec(&json)
750 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
751 }
752 UdsEncoding::MessagePack => {
753 #[derive(serde::Serialize)]
755 struct BinaryBodyChunk<'a> {
756 correlation_id: &'a str,
757 #[serde(with = "serde_bytes")]
758 data: &'a [u8],
759 is_last: bool,
760 #[serde(skip_serializing_if = "Option::is_none")]
761 total_size: Option<usize>,
762 chunk_index: u32,
763 #[serde(skip_serializing_if = "Option::is_none")]
764 bytes_received: Option<usize>,
765 #[serde(skip_serializing_if = "Option::is_none")]
766 bytes_sent: Option<usize>,
767 }
768 let chunk = BinaryBodyChunk {
769 correlation_id,
770 data: data.as_ref(),
771 is_last,
772 total_size,
773 chunk_index,
774 bytes_received,
775 bytes_sent,
776 };
777 encoding.serialize(&chunk)?
778 }
779 };
780
781 {
783 let outbound = self.outbound_tx.lock().await;
784 if let Some(tx) = outbound.as_ref() {
785 tx.send((msg_type, payload_bytes))
786 .await
787 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
788 } else {
789 return Err(AgentProtocolError::ConnectionClosed);
790 }
791 }
792
793 self.in_flight.fetch_add(1, Ordering::Relaxed);
794
795 let response = tokio::time::timeout(self.timeout, rx)
797 .await
798 .map_err(|_| {
799 self.pending
800 .try_lock()
801 .ok()
802 .map(|mut p| p.remove(correlation_id));
803 AgentProtocolError::Timeout(self.timeout)
804 })?
805 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
806
807 self.in_flight.fetch_sub(1, Ordering::Relaxed);
808
809 Ok(response)
810 }
811
812 async fn send_event<T: serde::Serialize>(
814 &self,
815 msg_type: MessageType,
816 correlation_id: &str,
817 event: &T,
818 ) -> Result<AgentResponse, AgentProtocolError> {
819 let (tx, rx) = oneshot::channel();
821 self.pending
822 .lock()
823 .await
824 .insert(correlation_id.to_string(), tx);
825
826 let encoding = *self.encoding.read().await;
828
829 let payload_bytes = match encoding {
831 UdsEncoding::Json => {
832 let mut payload = serde_json::to_value(event)
834 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
835 if let Some(obj) = payload.as_object_mut() {
836 obj.insert(
837 "correlation_id".to_string(),
838 serde_json::Value::String(correlation_id.to_string()),
839 );
840 }
841 serde_json::to_vec(&payload)
842 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
843 }
844 UdsEncoding::MessagePack => {
845 #[derive(serde::Serialize)]
847 struct EventWithCorrelation<'a, T: serde::Serialize> {
848 correlation_id: &'a str,
849 #[serde(flatten)]
850 event: &'a T,
851 }
852 let wrapped = EventWithCorrelation {
853 correlation_id,
854 event,
855 };
856 encoding.serialize(&wrapped)?
857 }
858 };
859
860 {
862 let outbound = self.outbound_tx.lock().await;
863 if let Some(tx) = outbound.as_ref() {
864 tx.send((msg_type, payload_bytes))
865 .await
866 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
867 } else {
868 return Err(AgentProtocolError::ConnectionClosed);
869 }
870 }
871
872 self.in_flight.fetch_add(1, Ordering::Relaxed);
873
874 let response = tokio::time::timeout(self.timeout, rx)
876 .await
877 .map_err(|_| {
878 self.pending
879 .try_lock()
880 .ok()
881 .map(|mut p| p.remove(correlation_id));
882 AgentProtocolError::Timeout(self.timeout)
883 })?
884 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
885
886 self.in_flight.fetch_sub(1, Ordering::Relaxed);
887
888 Ok(response)
889 }
890
891 pub async fn cancel_request(
893 &self,
894 correlation_id: &str,
895 reason: super::client::CancelReason,
896 ) -> Result<(), AgentProtocolError> {
897 let cancel = serde_json::json!({
898 "correlation_id": correlation_id,
899 "reason": reason as i32,
900 "timestamp_ms": now_ms(),
901 });
902
903 let payload = serde_json::to_vec(&cancel)
904 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
905
906 let outbound = self.outbound_tx.lock().await;
907 if let Some(tx) = outbound.as_ref() {
908 tx.send((MessageType::Cancel, payload))
909 .await
910 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
911 }
912
913 self.pending.lock().await.remove(correlation_id);
915
916 Ok(())
917 }
918
919 pub async fn cancel_all(
921 &self,
922 reason: super::client::CancelReason,
923 ) -> Result<usize, AgentProtocolError> {
924 let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
925 let count = pending_ids.len();
926
927 for correlation_id in pending_ids {
928 let _ = self.cancel_request(&correlation_id, reason).await;
929 }
930
931 Ok(count)
932 }
933
934 pub async fn ping(&self) -> Result<(), AgentProtocolError> {
936 let seq = self.ping_sequence.fetch_add(1, Ordering::Relaxed);
937 let ping = serde_json::json!({
938 "sequence": seq,
939 "timestamp_ms": now_ms(),
940 });
941
942 let payload = serde_json::to_vec(&ping)
943 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
944
945 let outbound = self.outbound_tx.lock().await;
946 if let Some(tx) = outbound.as_ref() {
947 tx.send((MessageType::Ping, payload))
948 .await
949 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
950 }
951
952 Ok(())
953 }
954
955 pub async fn close(&self) -> Result<(), AgentProtocolError> {
957 *self.connected.write().await = false;
958 *self.outbound_tx.lock().await = None;
959 Ok(())
960 }
961
962 pub fn in_flight(&self) -> u64 {
964 self.in_flight.load(Ordering::Relaxed)
965 }
966
967 pub fn agent_id(&self) -> &str {
969 &self.agent_id
970 }
971
972 pub async fn is_paused(&self) -> bool {
977 matches!(*self.flow_state.read().await, FlowState::Paused)
978 }
979
980 pub async fn can_accept_requests(&self) -> bool {
984 !self.is_paused().await
985 }
986}
987
988pub async fn write_message<W: AsyncWriteExt + Unpin>(
990 writer: &mut W,
991 msg_type: MessageType,
992 payload: &[u8],
993) -> Result<(), AgentProtocolError> {
994 if payload.len() > MAX_UDS_MESSAGE_SIZE {
995 return Err(AgentProtocolError::MessageTooLarge {
996 size: payload.len(),
997 max: MAX_UDS_MESSAGE_SIZE,
998 });
999 }
1000
1001 let total_len = (payload.len() + 1) as u32;
1003 writer.write_all(&total_len.to_be_bytes()).await?;
1004
1005 writer.write_all(&[msg_type as u8]).await?;
1007
1008 writer.write_all(payload).await?;
1010 writer.flush().await?;
1011
1012 Ok(())
1013}
1014
1015pub async fn read_message<R: AsyncReadExt + Unpin>(
1017 reader: &mut R,
1018) -> Result<(MessageType, Vec<u8>), AgentProtocolError> {
1019 let mut len_bytes = [0u8; 4];
1021 match reader.read_exact(&mut len_bytes).await {
1022 Ok(_) => {}
1023 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
1024 return Err(AgentProtocolError::ConnectionClosed);
1025 }
1026 Err(e) => return Err(e.into()),
1027 }
1028
1029 let total_len = u32::from_be_bytes(len_bytes) as usize;
1030
1031 if total_len == 0 {
1032 return Err(AgentProtocolError::InvalidMessage(
1033 "Zero-length message".to_string(),
1034 ));
1035 }
1036
1037 if total_len > MAX_UDS_MESSAGE_SIZE {
1038 return Err(AgentProtocolError::MessageTooLarge {
1039 size: total_len,
1040 max: MAX_UDS_MESSAGE_SIZE,
1041 });
1042 }
1043
1044 let mut type_byte = [0u8; 1];
1046 reader.read_exact(&mut type_byte).await?;
1047 let msg_type = MessageType::try_from(type_byte[0])?;
1048
1049 let payload_len = total_len - 1;
1051 let mut payload = vec![0u8; payload_len];
1052 if payload_len > 0 {
1053 reader.read_exact(&mut payload).await?;
1054 }
1055
1056 Ok((msg_type, payload))
1057}
1058
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch.
fn now_ms() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
1065
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `MessageType` variant, so the wire-table tests cover all of them.
    const ALL_MESSAGE_TYPES: [MessageType; 18] = [
        MessageType::HandshakeRequest,
        MessageType::HandshakeResponse,
        MessageType::RequestHeaders,
        MessageType::RequestBodyChunk,
        MessageType::ResponseHeaders,
        MessageType::ResponseBodyChunk,
        MessageType::RequestComplete,
        MessageType::WebSocketFrame,
        MessageType::GuardrailInspect,
        MessageType::Configure,
        MessageType::AgentResponse,
        MessageType::HealthStatus,
        MessageType::MetricsReport,
        MessageType::ConfigUpdateRequest,
        MessageType::FlowControl,
        MessageType::Cancel,
        MessageType::Ping,
        MessageType::Pong,
    ];

    #[test]
    fn test_message_type_roundtrip() {
        for msg_type in ALL_MESSAGE_TYPES {
            let byte = msg_type as u8;
            let parsed = MessageType::try_from(byte).unwrap();
            assert_eq!(parsed, msg_type);
        }
    }

    #[test]
    fn test_invalid_message_type() {
        let result = MessageType::try_from(0xFF);
        assert!(result.is_err());
    }

    #[test]
    fn test_only_known_type_bytes_decode() {
        let known: Vec<u8> = ALL_MESSAGE_TYPES.iter().map(|t| *t as u8).collect();
        for byte in 0..=u8::MAX {
            assert_eq!(
                MessageType::try_from(byte).is_ok(),
                known.contains(&byte),
                "type byte 0x{:02x}",
                byte
            );
        }
    }

    #[test]
    fn test_handshake_serialization() {
        let req = UdsHandshakeRequest {
            supported_versions: vec![2],
            proxy_id: "test-proxy".to_string(),
            proxy_version: "1.0.0".to_string(),
            config: None,
            supported_encodings: vec![],
        };

        let json = serde_json::to_string(&req).unwrap();
        let parsed: UdsHandshakeRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.supported_versions, vec![2]);
        assert_eq!(parsed.proxy_id, "test-proxy");
        // Empty encodings are omitted on the wire and default back to empty.
        assert!(!json.contains("supported_encodings"));
        assert!(parsed.supported_encodings.is_empty());
    }

    #[test]
    fn test_handshake_response_encoding_defaults_to_json() {
        let response = UdsHandshakeResponse {
            protocol_version: 2,
            capabilities: UdsCapabilities {
                agent_id: "agent".to_string(),
                name: "Agent".to_string(),
                version: "1.0.0".to_string(),
                supported_events: vec![1],
                features: UdsFeatures::default(),
                limits: UdsLimits::default(),
            },
            success: true,
            error: None,
            encoding: UdsEncoding::MessagePack,
        };

        let mut value = serde_json::to_value(&response).unwrap();
        value.as_object_mut().unwrap().remove("encoding");
        let parsed: UdsHandshakeResponse = serde_json::from_value(value).unwrap();

        assert_eq!(parsed.encoding, UdsEncoding::Json);
    }

    #[test]
    fn test_uds_encoding_wire_names() {
        assert_eq!(serde_json::to_string(&UdsEncoding::Json).unwrap(), "\"json\"");
        assert_eq!(
            serde_json::to_string(&UdsEncoding::MessagePack).unwrap(),
            "\"msgpack\""
        );
        let parsed: UdsEncoding = serde_json::from_str("\"msgpack\"").unwrap();
        assert_eq!(parsed, UdsEncoding::MessagePack);
    }

    #[tokio::test]
    async fn test_write_read_message() {
        use tokio::io::duplex;

        let (mut client, mut server) = duplex(1024);

        let payload = b"test payload";
        write_message(&mut client, MessageType::Ping, payload)
            .await
            .unwrap();

        let (msg_type, data) = read_message(&mut server).await.unwrap();
        assert_eq!(msg_type, MessageType::Ping);
        assert_eq!(data, payload);
    }

    #[tokio::test]
    async fn test_empty_payload_frame() {
        let mut buf = Vec::new();
        write_message(&mut buf, MessageType::Pong, &[]).await.unwrap();
        // Length prefix counts only the type byte.
        assert_eq!(buf, vec![0, 0, 0, 1, MessageType::Pong as u8]);

        let mut reader: &[u8] = &buf;
        let (msg_type, payload) = read_message(&mut reader).await.unwrap();
        assert_eq!(msg_type, MessageType::Pong);
        assert!(payload.is_empty());
    }

    #[tokio::test]
    async fn test_read_message_eof_is_connection_closed() {
        let mut reader: &[u8] = &[];
        assert!(matches!(
            read_message(&mut reader).await,
            Err(AgentProtocolError::ConnectionClosed)
        ));
    }

    #[tokio::test]
    async fn test_read_message_rejects_zero_length() {
        let frame = 0u32.to_be_bytes();
        let mut reader: &[u8] = &frame;
        assert!(matches!(
            read_message(&mut reader).await,
            Err(AgentProtocolError::InvalidMessage(_))
        ));
    }

    #[tokio::test]
    async fn test_read_message_rejects_oversized_frame() {
        let frame = ((MAX_UDS_MESSAGE_SIZE + 1) as u32).to_be_bytes();
        let mut reader: &[u8] = &frame;
        assert!(matches!(
            read_message(&mut reader).await,
            Err(AgentProtocolError::MessageTooLarge { .. })
        ));
    }

    #[tokio::test]
    async fn test_read_message_rejects_unknown_type() {
        let frame = [0u8, 0, 0, 2, 0xFF, 0xAA];
        let mut reader: &[u8] = &frame;
        assert!(matches!(
            read_message(&mut reader).await,
            Err(AgentProtocolError::InvalidMessage(_))
        ));
    }

    #[tokio::test]
    async fn test_write_message_rejects_oversized_payload() {
        let payload = vec![0u8; MAX_UDS_MESSAGE_SIZE + 1];
        let mut sink = Vec::new();
        let result = write_message(&mut sink, MessageType::RequestBodyChunk, &payload).await;
        assert!(matches!(
            result,
            Err(AgentProtocolError::MessageTooLarge { .. })
        ));
        assert!(sink.is_empty(), "nothing may be written for a rejected frame");
    }

    #[test]
    fn test_binary_body_chunk_json_serialization() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let data = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let correlation_id = "test-123";

        let json = serde_json::json!({
            "correlation_id": correlation_id,
            "data": STANDARD.encode(&data),
            "is_last": true,
            "total_size": 100usize,
            "chunk_index": 0u32,
            "bytes_received": 100usize,
        });

        let serialized = serde_json::to_vec(&json).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&serialized).unwrap();

        let data_field = parsed["data"].as_str().unwrap();
        let decoded = STANDARD.decode(data_field).unwrap();
        assert_eq!(decoded, data.as_ref());
    }

    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_binary_body_chunk_msgpack_serialization() {
        let data = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let correlation_id = "test-123";

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct BinaryBodyChunk {
            correlation_id: String,
            #[serde(with = "serde_bytes")]
            data: Vec<u8>,
            is_last: bool,
            chunk_index: u32,
        }

        let chunk = BinaryBodyChunk {
            correlation_id: correlation_id.to_string(),
            data: data.to_vec(),
            is_last: true,
            chunk_index: 0,
        };

        let serialized = rmp_serde::to_vec(&chunk).unwrap();

        let parsed: BinaryBodyChunk = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(parsed.correlation_id, correlation_id);
        assert_eq!(parsed.data, data.as_ref());
        assert!(parsed.is_last);

        // Raw bytes in MessagePack should beat base64 inside JSON.
        use base64::Engine as _;
        let json_size = serde_json::to_vec(&serde_json::json!({
            "correlation_id": correlation_id,
            "data": base64::engine::general_purpose::STANDARD.encode(&data),
            "is_last": true,
            "chunk_index": 0u32,
        }))
        .unwrap()
        .len();

        assert!(
            serialized.len() < json_size,
            "MessagePack ({}) should be smaller than JSON+base64 ({})",
            serialized.len(),
            json_size
        );
    }

    #[test]
    fn test_uds_encoding_default() {
        assert_eq!(UdsEncoding::default(), UdsEncoding::Json);
    }

    #[test]
    fn test_uds_encoding_serialize_json() {
        let encoding = UdsEncoding::Json;
        let value = serde_json::json!({"key": "value"});
        let serialized = encoding.serialize(&value).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&serialized).unwrap();
        assert_eq!(parsed, value);
    }

    #[test]
    fn test_uds_encoding_deserialize_json_error_is_invalid_message() {
        let result = UdsEncoding::Json.deserialize::<serde_json::Value>(b"{not json");
        assert!(matches!(result, Err(AgentProtocolError::InvalidMessage(_))));
    }

    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_uds_encoding_serialize_msgpack() {
        let encoding = UdsEncoding::MessagePack;
        let value = serde_json::json!({"key": "value"});
        let serialized = encoding.serialize(&value).unwrap();
        let parsed: serde_json::Value = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(parsed, value);
    }
}