sentinel_agent_protocol/
binary.rs

1//! Binary protocol for Unix Domain Socket transport.
2//!
3//! This module provides a binary framing format for efficient communication
4//! over UDS, eliminating JSON overhead and base64 encoding for body data.
5//!
6//! # Wire Format
7//!
8//! ```text
9//! +----------------+---------------+-------------------+
10//! | Length (4 BE)  | Type (1 byte) | Payload (N bytes) |
11//! +----------------+---------------+-------------------+
12//! ```
13//!
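//! - **Length**: 4-byte big-endian u32 giving the combined length of the type byte and payload (the length field itself is not counted; valid values are 1 through [`MAX_BINARY_MESSAGE_SIZE`])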
15//! - **Type**: 1-byte message type discriminator
16//! - **Payload**: Variable-length payload (format depends on type)
17//!
18//! # Performance Benefits
19//!
20//! - No JSON parsing overhead (~10x faster for small messages)
//! - No base64 encoding for body data (avoids base64's ~33% size inflation, i.e. ~25% fewer bytes on the wire)
22//! - Zero-copy with `bytes::Bytes` where possible
23
24use bytes::{Buf, BufMut, Bytes, BytesMut};
25use std::collections::HashMap;
26use std::io;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28
29use crate::{AgentProtocolError, Decision, HeaderOp};
30
/// Upper bound, in bytes, on a frame's declared length (type byte + payload): 10 MiB.
pub const MAX_BINARY_MESSAGE_SIZE: usize = 10 << 20;
33
/// Discriminator byte that follows the length prefix of every frame and tells
/// the reader how to interpret the payload.
///
/// Values are grouped by range: `0x01`–`0x02` handshake, `0x10`–`0x15`
/// request/response/WebSocket events, `0x20` agent response, `0x30`–`0x31`
/// keepalive, `0x40` cancellation and `0xFF` error.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
    /// Proxy -> agent handshake request.
    HandshakeRequest = 0x01,
    /// Agent -> proxy handshake response.
    HandshakeResponse = 0x02,
    /// Request headers event.
    RequestHeaders = 0x10,
    /// Request body chunk carried as raw bytes (not base64).
    RequestBodyChunk = 0x11,
    /// Response headers event.
    ResponseHeaders = 0x12,
    /// Response body chunk carried as raw bytes (not base64).
    ResponseBodyChunk = 0x13,
    /// Request-complete event.
    RequestComplete = 0x14,
    /// WebSocket frame event.
    WebSocketFrame = 0x15,
    /// Agent verdict for a request.
    AgentResponse = 0x20,
    /// Keepalive ping.
    Ping = 0x30,
    /// Keepalive pong.
    Pong = 0x31,
    /// Cancellation of an in-flight request.
    Cancel = 0x40,
    /// Error report.
    Error = 0xFF,
}
65
66impl TryFrom<u8> for MessageType {
67    type Error = AgentProtocolError;
68
69    fn try_from(value: u8) -> Result<Self, AgentProtocolError> {
70        match value {
71            0x01 => Ok(MessageType::HandshakeRequest),
72            0x02 => Ok(MessageType::HandshakeResponse),
73            0x10 => Ok(MessageType::RequestHeaders),
74            0x11 => Ok(MessageType::RequestBodyChunk),
75            0x12 => Ok(MessageType::ResponseHeaders),
76            0x13 => Ok(MessageType::ResponseBodyChunk),
77            0x14 => Ok(MessageType::RequestComplete),
78            0x15 => Ok(MessageType::WebSocketFrame),
79            0x20 => Ok(MessageType::AgentResponse),
80            0x30 => Ok(MessageType::Ping),
81            0x31 => Ok(MessageType::Pong),
82            0x40 => Ok(MessageType::Cancel),
83            0xFF => Ok(MessageType::Error),
84            _ => Err(AgentProtocolError::InvalidMessage(format!(
85                "Unknown message type: 0x{:02x}",
86                value
87            ))),
88        }
89    }
90}
91
92/// Binary frame with header and payload.
93#[derive(Debug, Clone)]
94pub struct BinaryFrame {
95    pub msg_type: MessageType,
96    pub payload: Bytes,
97}
98
99impl BinaryFrame {
100    /// Create a new binary frame.
101    pub fn new(msg_type: MessageType, payload: impl Into<Bytes>) -> Self {
102        Self {
103            msg_type,
104            payload: payload.into(),
105        }
106    }
107
108    /// Encode frame to bytes.
109    pub fn encode(&self) -> Bytes {
110        let payload_len = self.payload.len();
111        let total_len = 1 + payload_len; // type byte + payload
112
113        let mut buf = BytesMut::with_capacity(4 + total_len);
114        buf.put_u32(total_len as u32);
115        buf.put_u8(self.msg_type as u8);
116        buf.put_slice(&self.payload);
117
118        buf.freeze()
119    }
120
121    /// Decode frame from reader.
122    pub async fn decode<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Self, AgentProtocolError> {
123        // Read length (4 bytes)
124        let mut len_buf = [0u8; 4];
125        reader.read_exact(&mut len_buf).await.map_err(|e| {
126            if e.kind() == io::ErrorKind::UnexpectedEof {
127                AgentProtocolError::ConnectionFailed("Connection closed".to_string())
128            } else {
129                AgentProtocolError::Io(e)
130            }
131        })?;
132        let total_len = u32::from_be_bytes(len_buf) as usize;
133
134        // Validate length
135        if total_len == 0 {
136            return Err(AgentProtocolError::InvalidMessage(
137                "Empty message".to_string(),
138            ));
139        }
140        if total_len > MAX_BINARY_MESSAGE_SIZE {
141            return Err(AgentProtocolError::MessageTooLarge {
142                size: total_len,
143                max: MAX_BINARY_MESSAGE_SIZE,
144            });
145        }
146
147        // Read type byte
148        let mut type_buf = [0u8; 1];
149        reader.read_exact(&mut type_buf).await?;
150        let msg_type = MessageType::try_from(type_buf[0])?;
151
152        // Read payload
153        let payload_len = total_len - 1;
154        let mut payload = BytesMut::with_capacity(payload_len);
155        payload.resize(payload_len, 0);
156        reader.read_exact(&mut payload).await?;
157
158        Ok(Self {
159            msg_type,
160            payload: payload.freeze(),
161        })
162    }
163
164    /// Write frame to writer.
165    pub async fn write<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> Result<(), AgentProtocolError> {
166        let encoded = self.encode();
167        writer.write_all(&encoded).await?;
168        writer.flush().await?;
169        Ok(())
170    }
171}
172
173/// Binary request headers event.
174///
175/// Wire format:
176/// - correlation_id: length-prefixed string
177/// - method: length-prefixed string
178/// - uri: length-prefixed string
179/// - headers: count (u16) + [(name_len, name, value_len, value), ...]
180/// - client_ip: length-prefixed string
181/// - client_port: u16
182#[derive(Debug, Clone)]
183pub struct BinaryRequestHeaders {
184    pub correlation_id: String,
185    pub method: String,
186    pub uri: String,
187    pub headers: HashMap<String, Vec<String>>,
188    pub client_ip: String,
189    pub client_port: u16,
190}
191
192impl BinaryRequestHeaders {
193    /// Encode to bytes.
194    pub fn encode(&self) -> Bytes {
195        let mut buf = BytesMut::with_capacity(256);
196
197        // Correlation ID
198        put_string(&mut buf, &self.correlation_id);
199        // Method
200        put_string(&mut buf, &self.method);
201        // URI
202        put_string(&mut buf, &self.uri);
203
204        // Headers count
205        let header_count: usize = self.headers.values().map(|v| v.len()).sum();
206        buf.put_u16(header_count as u16);
207
208        // Headers (flattened: each value gets its own entry)
209        for (name, values) in &self.headers {
210            for value in values {
211                put_string(&mut buf, name);
212                put_string(&mut buf, value);
213            }
214        }
215
216        // Client IP
217        put_string(&mut buf, &self.client_ip);
218        // Client port
219        buf.put_u16(self.client_port);
220
221        buf.freeze()
222    }
223
224    /// Decode from bytes.
225    pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
226        let correlation_id = get_string(&mut data)?;
227        let method = get_string(&mut data)?;
228        let uri = get_string(&mut data)?;
229
230        // Headers
231        if data.remaining() < 2 {
232            return Err(AgentProtocolError::InvalidMessage(
233                "Missing header count".to_string(),
234            ));
235        }
236        let header_count = data.get_u16() as usize;
237
238        let mut headers: HashMap<String, Vec<String>> = HashMap::new();
239        for _ in 0..header_count {
240            let name = get_string(&mut data)?;
241            let value = get_string(&mut data)?;
242            headers.entry(name).or_default().push(value);
243        }
244
245        let client_ip = get_string(&mut data)?;
246
247        if data.remaining() < 2 {
248            return Err(AgentProtocolError::InvalidMessage(
249                "Missing client port".to_string(),
250            ));
251        }
252        let client_port = data.get_u16();
253
254        Ok(Self {
255            correlation_id,
256            method,
257            uri,
258            headers,
259            client_ip,
260            client_port,
261        })
262    }
263}
264
265/// Binary body chunk event (zero-copy).
266///
267/// Wire format:
268/// - correlation_id: length-prefixed string
269/// - chunk_index: u32
270/// - is_last: u8 (0 or 1)
271/// - data_len: u32
272/// - data: raw bytes (no base64!)
273#[derive(Debug, Clone)]
274pub struct BinaryBodyChunk {
275    pub correlation_id: String,
276    pub chunk_index: u32,
277    pub is_last: bool,
278    pub data: Bytes,
279}
280
281impl BinaryBodyChunk {
282    /// Encode to bytes.
283    pub fn encode(&self) -> Bytes {
284        let mut buf = BytesMut::with_capacity(32 + self.data.len());
285
286        put_string(&mut buf, &self.correlation_id);
287        buf.put_u32(self.chunk_index);
288        buf.put_u8(if self.is_last { 1 } else { 0 });
289        buf.put_u32(self.data.len() as u32);
290        buf.put_slice(&self.data);
291
292        buf.freeze()
293    }
294
295    /// Decode from bytes.
296    pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
297        let correlation_id = get_string(&mut data)?;
298
299        if data.remaining() < 9 {
300            return Err(AgentProtocolError::InvalidMessage(
301                "Missing body chunk fields".to_string(),
302            ));
303        }
304
305        let chunk_index = data.get_u32();
306        let is_last = data.get_u8() != 0;
307        let data_len = data.get_u32() as usize;
308
309        if data.remaining() < data_len {
310            return Err(AgentProtocolError::InvalidMessage(
311                "Body data truncated".to_string(),
312            ));
313        }
314
315        let body_data = data.copy_to_bytes(data_len);
316
317        Ok(Self {
318            correlation_id,
319            chunk_index,
320            is_last,
321            data: body_data,
322        })
323    }
324}
325
326/// Binary agent response.
327///
328/// Wire format:
329/// - correlation_id: length-prefixed string
330/// - decision_type: u8 (0=Allow, 1=Block, 2=Redirect, 3=Challenge)
331/// - decision_data: varies by type
332/// - request_headers_ops: count (u16) + ops
333/// - response_headers_ops: count (u16) + ops
334/// - needs_more: u8
335#[derive(Debug, Clone)]
336pub struct BinaryAgentResponse {
337    pub correlation_id: String,
338    pub decision: Decision,
339    pub request_headers: Vec<HeaderOp>,
340    pub response_headers: Vec<HeaderOp>,
341    pub needs_more: bool,
342}
343
344impl BinaryAgentResponse {
345    /// Encode to bytes.
346    pub fn encode(&self) -> Bytes {
347        let mut buf = BytesMut::with_capacity(128);
348
349        put_string(&mut buf, &self.correlation_id);
350
351        // Decision
352        match &self.decision {
353            Decision::Allow => {
354                buf.put_u8(0);
355            }
356            Decision::Block { status, body, headers } => {
357                buf.put_u8(1);
358                buf.put_u16(*status);
359                put_optional_string(&mut buf, body.as_deref());
360                // Block headers
361                let h_count = headers.as_ref().map(|h| h.len()).unwrap_or(0);
362                buf.put_u16(h_count as u16);
363                if let Some(headers) = headers {
364                    for (k, v) in headers {
365                        put_string(&mut buf, k);
366                        put_string(&mut buf, v);
367                    }
368                }
369            }
370            Decision::Redirect { url, status } => {
371                buf.put_u8(2);
372                put_string(&mut buf, url);
373                buf.put_u16(*status);
374            }
375            Decision::Challenge { challenge_type, params } => {
376                buf.put_u8(3);
377                put_string(&mut buf, challenge_type);
378                buf.put_u16(params.len() as u16);
379                for (k, v) in params {
380                    put_string(&mut buf, k);
381                    put_string(&mut buf, v);
382                }
383            }
384        }
385
386        // Request header ops
387        buf.put_u16(self.request_headers.len() as u16);
388        for op in &self.request_headers {
389            encode_header_op(&mut buf, op);
390        }
391
392        // Response header ops
393        buf.put_u16(self.response_headers.len() as u16);
394        for op in &self.response_headers {
395            encode_header_op(&mut buf, op);
396        }
397
398        // Needs more
399        buf.put_u8(if self.needs_more { 1 } else { 0 });
400
401        buf.freeze()
402    }
403
404    /// Decode from bytes.
405    pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
406        let correlation_id = get_string(&mut data)?;
407
408        if data.remaining() < 1 {
409            return Err(AgentProtocolError::InvalidMessage(
410                "Missing decision type".to_string(),
411            ));
412        }
413
414        let decision_type = data.get_u8();
415        let decision = match decision_type {
416            0 => Decision::Allow,
417            1 => {
418                if data.remaining() < 2 {
419                    return Err(AgentProtocolError::InvalidMessage(
420                        "Missing block status".to_string(),
421                    ));
422                }
423                let status = data.get_u16();
424                let body = get_optional_string(&mut data)?;
425                if data.remaining() < 2 {
426                    return Err(AgentProtocolError::InvalidMessage(
427                        "Missing block headers count".to_string(),
428                    ));
429                }
430                let h_count = data.get_u16() as usize;
431                let headers = if h_count > 0 {
432                    let mut h = HashMap::new();
433                    for _ in 0..h_count {
434                        let k = get_string(&mut data)?;
435                        let v = get_string(&mut data)?;
436                        h.insert(k, v);
437                    }
438                    Some(h)
439                } else {
440                    None
441                };
442                Decision::Block { status, body, headers }
443            }
444            2 => {
445                let url = get_string(&mut data)?;
446                if data.remaining() < 2 {
447                    return Err(AgentProtocolError::InvalidMessage(
448                        "Missing redirect status".to_string(),
449                    ));
450                }
451                let status = data.get_u16();
452                Decision::Redirect { url, status }
453            }
454            3 => {
455                let challenge_type = get_string(&mut data)?;
456                if data.remaining() < 2 {
457                    return Err(AgentProtocolError::InvalidMessage(
458                        "Missing challenge params count".to_string(),
459                    ));
460                }
461                let p_count = data.get_u16() as usize;
462                let mut params = HashMap::new();
463                for _ in 0..p_count {
464                    let k = get_string(&mut data)?;
465                    let v = get_string(&mut data)?;
466                    params.insert(k, v);
467                }
468                Decision::Challenge { challenge_type, params }
469            }
470            _ => {
471                return Err(AgentProtocolError::InvalidMessage(format!(
472                    "Unknown decision type: {}",
473                    decision_type
474                )));
475            }
476        };
477
478        // Request header ops
479        if data.remaining() < 2 {
480            return Err(AgentProtocolError::InvalidMessage(
481                "Missing request headers count".to_string(),
482            ));
483        }
484        let req_h_count = data.get_u16() as usize;
485        let mut request_headers = Vec::with_capacity(req_h_count);
486        for _ in 0..req_h_count {
487            request_headers.push(decode_header_op(&mut data)?);
488        }
489
490        // Response header ops
491        if data.remaining() < 2 {
492            return Err(AgentProtocolError::InvalidMessage(
493                "Missing response headers count".to_string(),
494            ));
495        }
496        let resp_h_count = data.get_u16() as usize;
497        let mut response_headers = Vec::with_capacity(resp_h_count);
498        for _ in 0..resp_h_count {
499            response_headers.push(decode_header_op(&mut data)?);
500        }
501
502        // Needs more
503        if data.remaining() < 1 {
504            return Err(AgentProtocolError::InvalidMessage(
505                "Missing needs_more".to_string(),
506            ));
507        }
508        let needs_more = data.get_u8() != 0;
509
510        Ok(Self {
511            correlation_id,
512            decision,
513            request_headers,
514            response_headers,
515            needs_more,
516        })
517    }
518}
519
520// =============================================================================
521// Helper Functions
522// =============================================================================
523
524fn put_string(buf: &mut BytesMut, s: &str) {
525    let bytes = s.as_bytes();
526    buf.put_u16(bytes.len() as u16);
527    buf.put_slice(bytes);
528}
529
530fn get_string(data: &mut Bytes) -> Result<String, AgentProtocolError> {
531    if data.remaining() < 2 {
532        return Err(AgentProtocolError::InvalidMessage(
533            "Missing string length".to_string(),
534        ));
535    }
536    let len = data.get_u16() as usize;
537    if data.remaining() < len {
538        return Err(AgentProtocolError::InvalidMessage(
539            "String data truncated".to_string(),
540        ));
541    }
542    let bytes = data.copy_to_bytes(len);
543    String::from_utf8(bytes.to_vec())
544        .map_err(|e| AgentProtocolError::InvalidMessage(format!("Invalid UTF-8: {}", e)))
545}
546
547fn put_optional_string(buf: &mut BytesMut, s: Option<&str>) {
548    match s {
549        Some(s) => {
550            buf.put_u8(1);
551            put_string(buf, s);
552        }
553        None => {
554            buf.put_u8(0);
555        }
556    }
557}
558
559fn get_optional_string(data: &mut Bytes) -> Result<Option<String>, AgentProtocolError> {
560    if data.remaining() < 1 {
561        return Err(AgentProtocolError::InvalidMessage(
562            "Missing optional string flag".to_string(),
563        ));
564    }
565    let present = data.get_u8() != 0;
566    if present {
567        get_string(data).map(Some)
568    } else {
569        Ok(None)
570    }
571}
572
573fn encode_header_op(buf: &mut BytesMut, op: &HeaderOp) {
574    match op {
575        HeaderOp::Set { name, value } => {
576            buf.put_u8(0);
577            put_string(buf, name);
578            put_string(buf, value);
579        }
580        HeaderOp::Add { name, value } => {
581            buf.put_u8(1);
582            put_string(buf, name);
583            put_string(buf, value);
584        }
585        HeaderOp::Remove { name } => {
586            buf.put_u8(2);
587            put_string(buf, name);
588        }
589    }
590}
591
592fn decode_header_op(data: &mut Bytes) -> Result<HeaderOp, AgentProtocolError> {
593    if data.remaining() < 1 {
594        return Err(AgentProtocolError::InvalidMessage(
595            "Missing header op type".to_string(),
596        ));
597    }
598    let op_type = data.get_u8();
599    match op_type {
600        0 => {
601            let name = get_string(data)?;
602            let value = get_string(data)?;
603            Ok(HeaderOp::Set { name, value })
604        }
605        1 => {
606            let name = get_string(data)?;
607            let value = get_string(data)?;
608            Ok(HeaderOp::Add { name, value })
609        }
610        2 => {
611            let name = get_string(data)?;
612            Ok(HeaderOp::Remove { name })
613        }
614        _ => Err(AgentProtocolError::InvalidMessage(format!(
615            "Unknown header op type: {}",
616            op_type
617        ))),
618    }
619}
620
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `MessageType` variant must survive a `u8` round-trip.
    #[test]
    fn test_message_type_roundtrip() {
        for t in [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::RequestHeaders,
            MessageType::RequestBodyChunk,
            MessageType::ResponseHeaders,
            MessageType::ResponseBodyChunk,
            MessageType::RequestComplete,
            MessageType::WebSocketFrame,
            MessageType::AgentResponse,
            MessageType::Ping,
            MessageType::Pong,
            MessageType::Cancel,
            MessageType::Error,
        ] {
            let byte = t as u8;
            let decoded = MessageType::try_from(byte).unwrap();
            assert_eq!(t, decoded);
        }
    }

    #[test]
    fn test_message_type_rejects_unknown_bytes() {
        for byte in [0x00u8, 0x03, 0x16, 0x21, 0x41, 0xFE] {
            assert!(matches!(
                MessageType::try_from(byte),
                Err(AgentProtocolError::InvalidMessage(_))
            ));
        }
    }

    #[test]
    fn test_binary_frame_encode_decode() {
        let frame = BinaryFrame::new(MessageType::Ping, Bytes::from_static(b"hello"));
        let encoded = frame.encode();

        // Verify structure
        assert_eq!(encoded.len(), 4 + 1 + 5); // len + type + payload
        assert_eq!(&encoded[0..4], &[0, 0, 0, 6]); // length = 6 (type + payload)
        assert_eq!(encoded[4], MessageType::Ping as u8);
        assert_eq!(&encoded[5..], b"hello");
    }

    #[test]
    fn test_binary_frame_encode_empty_payload() {
        let encoded = BinaryFrame::new(MessageType::Pong, Bytes::new()).encode();
        assert_eq!(&encoded[..], &[0, 0, 0, 1, MessageType::Pong as u8]);
    }

    #[test]
    fn test_binary_request_headers_roundtrip() {
        let headers = BinaryRequestHeaders {
            correlation_id: "req-123".to_string(),
            method: "POST".to_string(),
            uri: "/api/test".to_string(),
            headers: {
                let mut h = HashMap::new();
                h.insert("content-type".to_string(), vec!["application/json".to_string()]);
                h.insert("x-custom".to_string(), vec!["value1".to_string(), "value2".to_string()]);
                h
            },
            client_ip: "192.168.1.1".to_string(),
            client_port: 12345,
        };

        let encoded = headers.encode();
        let decoded = BinaryRequestHeaders::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-123");
        assert_eq!(decoded.method, "POST");
        assert_eq!(decoded.uri, "/api/test");
        assert_eq!(decoded.client_ip, "192.168.1.1");
        assert_eq!(decoded.client_port, 12345);
        assert_eq!(decoded.headers.len(), 2);
        assert_eq!(decoded.headers.get("content-type").unwrap(), &vec!["application/json".to_string()]);
        // Multi-valued headers keep the order of their values.
        assert_eq!(
            decoded.headers.get("x-custom").unwrap(),
            &vec!["value1".to_string(), "value2".to_string()]
        );
    }

    #[test]
    fn test_binary_request_headers_rejects_truncated_input() {
        let headers = BinaryRequestHeaders {
            correlation_id: "req-t".to_string(),
            method: "GET".to_string(),
            uri: "/".to_string(),
            headers: HashMap::new(),
            client_ip: "10.0.0.1".to_string(),
            client_port: 443,
        };
        let encoded = headers.encode();
        // Drop the final byte of the client port.
        let truncated = encoded.slice(..encoded.len() - 1);
        assert!(matches!(
            BinaryRequestHeaders::decode(truncated),
            Err(AgentProtocolError::InvalidMessage(msg)) if msg == "Missing client port"
        ));
    }

    #[test]
    fn test_binary_body_chunk_roundtrip() {
        let chunk = BinaryBodyChunk {
            correlation_id: "req-456".to_string(),
            chunk_index: 2,
            is_last: true,
            data: Bytes::from_static(b"binary data here"),
        };

        let encoded = chunk.encode();
        let decoded = BinaryBodyChunk::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-456");
        assert_eq!(decoded.chunk_index, 2);
        assert!(decoded.is_last);
        assert_eq!(&decoded.data[..], b"binary data here");
    }

    #[test]
    fn test_binary_body_chunk_empty_not_last() {
        let chunk = BinaryBodyChunk {
            correlation_id: "req-0".to_string(),
            chunk_index: 0,
            is_last: false,
            data: Bytes::new(),
        };

        let decoded = BinaryBodyChunk::decode(chunk.encode()).unwrap();

        assert_eq!(decoded.correlation_id, "req-0");
        assert_eq!(decoded.chunk_index, 0);
        assert!(!decoded.is_last);
        assert!(decoded.data.is_empty());
    }

    #[test]
    fn test_binary_body_chunk_rejects_truncated_body() {
        let chunk = BinaryBodyChunk {
            correlation_id: "req-trunc".to_string(),
            chunk_index: 1,
            is_last: false,
            data: Bytes::from_static(b"abc"),
        };
        let encoded = chunk.encode();
        let truncated = encoded.slice(..encoded.len() - 1);
        assert!(matches!(
            BinaryBodyChunk::decode(truncated),
            Err(AgentProtocolError::InvalidMessage(msg)) if msg == "Body data truncated"
        ));
    }

    #[test]
    fn test_binary_agent_response_allow() {
        let response = BinaryAgentResponse {
            correlation_id: "req-789".to_string(),
            decision: Decision::Allow,
            request_headers: vec![HeaderOp::Set {
                name: "X-Added".to_string(),
                value: "true".to_string(),
            }],
            response_headers: vec![],
            needs_more: false,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-789");
        assert!(matches!(decoded.decision, Decision::Allow));
        assert_eq!(decoded.request_headers.len(), 1);
        assert!(matches!(
            &decoded.request_headers[0],
            HeaderOp::Set { name, value } if name == "X-Added" && value == "true"
        ));
        assert!(decoded.response_headers.is_empty());
        assert!(!decoded.needs_more);
    }

    #[test]
    fn test_binary_agent_response_block() {
        let response = BinaryAgentResponse {
            correlation_id: "req-block".to_string(),
            decision: Decision::Block {
                status: 403,
                body: Some("Forbidden".to_string()),
                headers: None,
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-block");
        match decoded.decision {
            Decision::Block { status, body, headers } => {
                assert_eq!(status, 403);
                assert_eq!(body, Some("Forbidden".to_string()));
                assert!(headers.is_none());
            }
            _ => panic!("Expected Block decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_block_with_headers() {
        let response = BinaryAgentResponse {
            correlation_id: "req-block-h".to_string(),
            decision: Decision::Block {
                status: 429,
                body: None,
                headers: Some(HashMap::from([(
                    "retry-after".to_string(),
                    "30".to_string(),
                )])),
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();

        match decoded.decision {
            Decision::Block { status, body, headers } => {
                assert_eq!(status, 429);
                assert!(body.is_none());
                let headers = headers.expect("block headers should round-trip");
                assert_eq!(headers.len(), 1);
                assert_eq!(headers.get("retry-after").map(String::as_str), Some("30"));
            }
            _ => panic!("Expected Block decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_redirect() {
        let response = BinaryAgentResponse {
            correlation_id: "req-redirect".to_string(),
            decision: Decision::Redirect {
                url: "https://example.com/login".to_string(),
                status: 302,
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();

        match decoded.decision {
            Decision::Redirect { url, status } => {
                assert_eq!(url, "https://example.com/login");
                assert_eq!(status, 302);
            }
            _ => panic!("Expected Redirect decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_challenge() {
        let response = BinaryAgentResponse {
            correlation_id: "req-challenge".to_string(),
            decision: Decision::Challenge {
                challenge_type: "captcha".to_string(),
                params: HashMap::from([("site_key".to_string(), "abc".to_string())]),
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();

        match decoded.decision {
            Decision::Challenge { challenge_type, params } => {
                assert_eq!(challenge_type, "captcha");
                assert_eq!(params.len(), 1);
                assert_eq!(params.get("site_key").map(String::as_str), Some("abc"));
            }
            _ => panic!("Expected Challenge decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_header_ops() {
        let response = BinaryAgentResponse {
            correlation_id: "req-ops".to_string(),
            decision: Decision::Allow,
            request_headers: vec![],
            response_headers: vec![
                HeaderOp::Add {
                    name: "X-A".to_string(),
                    value: "1".to_string(),
                },
                HeaderOp::Remove {
                    name: "Server".to_string(),
                },
            ],
            needs_more: true,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();

        assert!(decoded.needs_more);
        assert!(decoded.request_headers.is_empty());
        assert_eq!(decoded.response_headers.len(), 2);
        assert!(matches!(
            &decoded.response_headers[0],
            HeaderOp::Add { name, value } if name == "X-A" && value == "1"
        ));
        assert!(matches!(
            &decoded.response_headers[1],
            HeaderOp::Remove { name } if name == "Server"
        ));
    }

    #[test]
    fn test_binary_agent_response_rejects_missing_needs_more() {
        let response = BinaryAgentResponse {
            correlation_id: "req-short".to_string(),
            decision: Decision::Allow,
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };
        let encoded = response.encode();
        let truncated = encoded.slice(..encoded.len() - 1);
        assert!(matches!(
            BinaryAgentResponse::decode(truncated),
            Err(AgentProtocolError::InvalidMessage(msg)) if msg == "Missing needs_more"
        ));
    }

    #[test]
    fn test_binary_agent_response_rejects_unknown_decision() {
        // correlation id "x" followed by decision type 9.
        let data = Bytes::from_static(&[0, 1, b'x', 9]);
        assert!(matches!(
            BinaryAgentResponse::decode(data),
            Err(AgentProtocolError::InvalidMessage(msg)) if msg == "Unknown decision type: 9"
        ));
    }
}