1use bytes::{Buf, BufMut, Bytes, BytesMut};
25use std::collections::HashMap;
26use std::io;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28
29use crate::{AgentProtocolError, Decision, HeaderOp};
30
/// Largest frame length, in bytes, that [`BinaryFrame::decode`] accepts (10 MiB).
///
/// The limit is compared against the value of a frame's length prefix, i.e.
/// the type byte plus the payload.
pub const MAX_BINARY_MESSAGE_SIZE: usize = 10 * 1024 * 1024;
33
/// One-byte tag identifying the kind of message carried by a [`BinaryFrame`].
///
/// The explicit discriminants are the byte values written to and read from
/// the wire; the `TryFrom<u8>` impl is the inverse mapping.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Handshake request (`0x01`).
    HandshakeRequest = 0x01,
    /// Handshake response (`0x02`).
    HandshakeResponse = 0x02,
    /// Request headers (`0x10`).
    RequestHeaders = 0x10,
    /// Request body chunk (`0x11`).
    RequestBodyChunk = 0x11,
    /// Response headers (`0x12`).
    ResponseHeaders = 0x12,
    /// Response body chunk (`0x13`).
    ResponseBodyChunk = 0x13,
    /// Request complete (`0x14`).
    RequestComplete = 0x14,
    /// WebSocket frame (`0x15`).
    WebSocketFrame = 0x15,
    /// Agent response (`0x20`).
    AgentResponse = 0x20,
    /// Ping (`0x30`).
    Ping = 0x30,
    /// Pong (`0x31`).
    Pong = 0x31,
    /// Cancel (`0x40`).
    Cancel = 0x40,
    /// Error (`0xFF`).
    Error = 0xFF,
}
65
66impl TryFrom<u8> for MessageType {
67 type Error = AgentProtocolError;
68
69 fn try_from(value: u8) -> Result<Self, AgentProtocolError> {
70 match value {
71 0x01 => Ok(MessageType::HandshakeRequest),
72 0x02 => Ok(MessageType::HandshakeResponse),
73 0x10 => Ok(MessageType::RequestHeaders),
74 0x11 => Ok(MessageType::RequestBodyChunk),
75 0x12 => Ok(MessageType::ResponseHeaders),
76 0x13 => Ok(MessageType::ResponseBodyChunk),
77 0x14 => Ok(MessageType::RequestComplete),
78 0x15 => Ok(MessageType::WebSocketFrame),
79 0x20 => Ok(MessageType::AgentResponse),
80 0x30 => Ok(MessageType::Ping),
81 0x31 => Ok(MessageType::Pong),
82 0x40 => Ok(MessageType::Cancel),
83 0xFF => Ok(MessageType::Error),
84 _ => Err(AgentProtocolError::InvalidMessage(format!(
85 "Unknown message type: 0x{:02x}",
86 value
87 ))),
88 }
89 }
90}
91
/// A single length-prefixed protocol frame: a message type tag plus an
/// opaque payload.
///
/// On the wire a frame is a 4-byte big-endian length, one type byte and the
/// payload; the length counts the type byte and the payload.
#[derive(Debug, Clone)]
pub struct BinaryFrame {
    /// Tag identifying how `payload` should be interpreted.
    pub msg_type: MessageType,
    /// Raw message body; may be empty.
    pub payload: Bytes,
}
98
99impl BinaryFrame {
100 pub fn new(msg_type: MessageType, payload: impl Into<Bytes>) -> Self {
102 Self {
103 msg_type,
104 payload: payload.into(),
105 }
106 }
107
108 pub fn encode(&self) -> Bytes {
110 let payload_len = self.payload.len();
111 let total_len = 1 + payload_len; let mut buf = BytesMut::with_capacity(4 + total_len);
114 buf.put_u32(total_len as u32);
115 buf.put_u8(self.msg_type as u8);
116 buf.put_slice(&self.payload);
117
118 buf.freeze()
119 }
120
121 pub async fn decode<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Self, AgentProtocolError> {
123 let mut len_buf = [0u8; 4];
125 reader.read_exact(&mut len_buf).await.map_err(|e| {
126 if e.kind() == io::ErrorKind::UnexpectedEof {
127 AgentProtocolError::ConnectionFailed("Connection closed".to_string())
128 } else {
129 AgentProtocolError::Io(e)
130 }
131 })?;
132 let total_len = u32::from_be_bytes(len_buf) as usize;
133
134 if total_len == 0 {
136 return Err(AgentProtocolError::InvalidMessage(
137 "Empty message".to_string(),
138 ));
139 }
140 if total_len > MAX_BINARY_MESSAGE_SIZE {
141 return Err(AgentProtocolError::MessageTooLarge {
142 size: total_len,
143 max: MAX_BINARY_MESSAGE_SIZE,
144 });
145 }
146
147 let mut type_buf = [0u8; 1];
149 reader.read_exact(&mut type_buf).await?;
150 let msg_type = MessageType::try_from(type_buf[0])?;
151
152 let payload_len = total_len - 1;
154 let mut payload = BytesMut::with_capacity(payload_len);
155 payload.resize(payload_len, 0);
156 reader.read_exact(&mut payload).await?;
157
158 Ok(Self {
159 msg_type,
160 payload: payload.freeze(),
161 })
162 }
163
164 pub async fn write<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> Result<(), AgentProtocolError> {
166 let encoded = self.encode();
167 writer.write_all(&encoded).await?;
168 writer.flush().await?;
169 Ok(())
170 }
171}
172
/// Request metadata in its binary wire representation.
///
/// Encoded as three length-prefixed strings (correlation id, method, URI), a
/// `u16` count of header name/value pairs followed by the pairs, the client
/// IP as a length-prefixed string, and the client port as a `u16`.
#[derive(Debug, Clone)]
pub struct BinaryRequestHeaders {
    /// Identifier carried alongside the request; presumably used to match
    /// related messages — confirm against callers.
    pub correlation_id: String,
    /// Request method string.
    pub method: String,
    /// Request URI string.
    pub uri: String,
    /// Header name to list of values; each value is encoded as its own
    /// name/value pair.
    pub headers: HashMap<String, Vec<String>>,
    /// Client address, carried as text.
    pub client_ip: String,
    /// Client port.
    pub client_port: u16,
}
191
192impl BinaryRequestHeaders {
193 pub fn encode(&self) -> Bytes {
195 let mut buf = BytesMut::with_capacity(256);
196
197 put_string(&mut buf, &self.correlation_id);
199 put_string(&mut buf, &self.method);
201 put_string(&mut buf, &self.uri);
203
204 let header_count: usize = self.headers.values().map(|v| v.len()).sum();
206 buf.put_u16(header_count as u16);
207
208 for (name, values) in &self.headers {
210 for value in values {
211 put_string(&mut buf, name);
212 put_string(&mut buf, value);
213 }
214 }
215
216 put_string(&mut buf, &self.client_ip);
218 buf.put_u16(self.client_port);
220
221 buf.freeze()
222 }
223
224 pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
226 let correlation_id = get_string(&mut data)?;
227 let method = get_string(&mut data)?;
228 let uri = get_string(&mut data)?;
229
230 if data.remaining() < 2 {
232 return Err(AgentProtocolError::InvalidMessage(
233 "Missing header count".to_string(),
234 ));
235 }
236 let header_count = data.get_u16() as usize;
237
238 let mut headers: HashMap<String, Vec<String>> = HashMap::new();
239 for _ in 0..header_count {
240 let name = get_string(&mut data)?;
241 let value = get_string(&mut data)?;
242 headers.entry(name).or_default().push(value);
243 }
244
245 let client_ip = get_string(&mut data)?;
246
247 if data.remaining() < 2 {
248 return Err(AgentProtocolError::InvalidMessage(
249 "Missing client port".to_string(),
250 ));
251 }
252 let client_port = data.get_u16();
253
254 Ok(Self {
255 correlation_id,
256 method,
257 uri,
258 headers,
259 client_ip,
260 client_port,
261 })
262 }
263}
264
/// One chunk of body data in its binary wire representation.
///
/// Encoded as the correlation id (length-prefixed string), the chunk index
/// (`u32`), a one-byte last-chunk flag, the data length (`u32`) and the raw
/// data bytes.
#[derive(Debug, Clone)]
pub struct BinaryBodyChunk {
    /// Identifier carried alongside the chunk; presumably ties it to a
    /// request — confirm against callers.
    pub correlation_id: String,
    /// Index of this chunk.
    pub chunk_index: u32,
    /// Flag encoded as a single byte (`1` when set, `0` otherwise).
    pub is_last: bool,
    /// Raw chunk bytes.
    pub data: Bytes,
}
280
281impl BinaryBodyChunk {
282 pub fn encode(&self) -> Bytes {
284 let mut buf = BytesMut::with_capacity(32 + self.data.len());
285
286 put_string(&mut buf, &self.correlation_id);
287 buf.put_u32(self.chunk_index);
288 buf.put_u8(if self.is_last { 1 } else { 0 });
289 buf.put_u32(self.data.len() as u32);
290 buf.put_slice(&self.data);
291
292 buf.freeze()
293 }
294
295 pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
297 let correlation_id = get_string(&mut data)?;
298
299 if data.remaining() < 9 {
300 return Err(AgentProtocolError::InvalidMessage(
301 "Missing body chunk fields".to_string(),
302 ));
303 }
304
305 let chunk_index = data.get_u32();
306 let is_last = data.get_u8() != 0;
307 let data_len = data.get_u32() as usize;
308
309 if data.remaining() < data_len {
310 return Err(AgentProtocolError::InvalidMessage(
311 "Body data truncated".to_string(),
312 ));
313 }
314
315 let body_data = data.copy_to_bytes(data_len);
316
317 Ok(Self {
318 correlation_id,
319 chunk_index,
320 is_last,
321 data: body_data,
322 })
323 }
324}
325
/// An agent's response in its binary wire representation.
///
/// Encoded as the correlation id (length-prefixed string), a one-byte
/// decision tag with decision-specific fields, a `u16`-counted list of
/// request header operations, a `u16`-counted list of response header
/// operations, and a final one-byte `needs_more` flag.
#[derive(Debug, Clone)]
pub struct BinaryAgentResponse {
    /// Identifier carried alongside the response; presumably matches the
    /// originating request — confirm against callers.
    pub correlation_id: String,
    /// The decision; encoded with tag `0` (Allow), `1` (Block),
    /// `2` (Redirect) or `3` (Challenge).
    pub decision: Decision,
    /// Header operations listed first on the wire.
    pub request_headers: Vec<HeaderOp>,
    /// Header operations listed second on the wire.
    pub response_headers: Vec<HeaderOp>,
    /// Flag encoded as the final byte. NOTE(review): presumably signals that
    /// the agent wants further data before deciding — confirm.
    pub needs_more: bool,
}
343
344impl BinaryAgentResponse {
345 pub fn encode(&self) -> Bytes {
347 let mut buf = BytesMut::with_capacity(128);
348
349 put_string(&mut buf, &self.correlation_id);
350
351 match &self.decision {
353 Decision::Allow => {
354 buf.put_u8(0);
355 }
356 Decision::Block { status, body, headers } => {
357 buf.put_u8(1);
358 buf.put_u16(*status);
359 put_optional_string(&mut buf, body.as_deref());
360 let h_count = headers.as_ref().map(|h| h.len()).unwrap_or(0);
362 buf.put_u16(h_count as u16);
363 if let Some(headers) = headers {
364 for (k, v) in headers {
365 put_string(&mut buf, k);
366 put_string(&mut buf, v);
367 }
368 }
369 }
370 Decision::Redirect { url, status } => {
371 buf.put_u8(2);
372 put_string(&mut buf, url);
373 buf.put_u16(*status);
374 }
375 Decision::Challenge { challenge_type, params } => {
376 buf.put_u8(3);
377 put_string(&mut buf, challenge_type);
378 buf.put_u16(params.len() as u16);
379 for (k, v) in params {
380 put_string(&mut buf, k);
381 put_string(&mut buf, v);
382 }
383 }
384 }
385
386 buf.put_u16(self.request_headers.len() as u16);
388 for op in &self.request_headers {
389 encode_header_op(&mut buf, op);
390 }
391
392 buf.put_u16(self.response_headers.len() as u16);
394 for op in &self.response_headers {
395 encode_header_op(&mut buf, op);
396 }
397
398 buf.put_u8(if self.needs_more { 1 } else { 0 });
400
401 buf.freeze()
402 }
403
404 pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
406 let correlation_id = get_string(&mut data)?;
407
408 if data.remaining() < 1 {
409 return Err(AgentProtocolError::InvalidMessage(
410 "Missing decision type".to_string(),
411 ));
412 }
413
414 let decision_type = data.get_u8();
415 let decision = match decision_type {
416 0 => Decision::Allow,
417 1 => {
418 if data.remaining() < 2 {
419 return Err(AgentProtocolError::InvalidMessage(
420 "Missing block status".to_string(),
421 ));
422 }
423 let status = data.get_u16();
424 let body = get_optional_string(&mut data)?;
425 if data.remaining() < 2 {
426 return Err(AgentProtocolError::InvalidMessage(
427 "Missing block headers count".to_string(),
428 ));
429 }
430 let h_count = data.get_u16() as usize;
431 let headers = if h_count > 0 {
432 let mut h = HashMap::new();
433 for _ in 0..h_count {
434 let k = get_string(&mut data)?;
435 let v = get_string(&mut data)?;
436 h.insert(k, v);
437 }
438 Some(h)
439 } else {
440 None
441 };
442 Decision::Block { status, body, headers }
443 }
444 2 => {
445 let url = get_string(&mut data)?;
446 if data.remaining() < 2 {
447 return Err(AgentProtocolError::InvalidMessage(
448 "Missing redirect status".to_string(),
449 ));
450 }
451 let status = data.get_u16();
452 Decision::Redirect { url, status }
453 }
454 3 => {
455 let challenge_type = get_string(&mut data)?;
456 if data.remaining() < 2 {
457 return Err(AgentProtocolError::InvalidMessage(
458 "Missing challenge params count".to_string(),
459 ));
460 }
461 let p_count = data.get_u16() as usize;
462 let mut params = HashMap::new();
463 for _ in 0..p_count {
464 let k = get_string(&mut data)?;
465 let v = get_string(&mut data)?;
466 params.insert(k, v);
467 }
468 Decision::Challenge { challenge_type, params }
469 }
470 _ => {
471 return Err(AgentProtocolError::InvalidMessage(format!(
472 "Unknown decision type: {}",
473 decision_type
474 )));
475 }
476 };
477
478 if data.remaining() < 2 {
480 return Err(AgentProtocolError::InvalidMessage(
481 "Missing request headers count".to_string(),
482 ));
483 }
484 let req_h_count = data.get_u16() as usize;
485 let mut request_headers = Vec::with_capacity(req_h_count);
486 for _ in 0..req_h_count {
487 request_headers.push(decode_header_op(&mut data)?);
488 }
489
490 if data.remaining() < 2 {
492 return Err(AgentProtocolError::InvalidMessage(
493 "Missing response headers count".to_string(),
494 ));
495 }
496 let resp_h_count = data.get_u16() as usize;
497 let mut response_headers = Vec::with_capacity(resp_h_count);
498 for _ in 0..resp_h_count {
499 response_headers.push(decode_header_op(&mut data)?);
500 }
501
502 if data.remaining() < 1 {
504 return Err(AgentProtocolError::InvalidMessage(
505 "Missing needs_more".to_string(),
506 ));
507 }
508 let needs_more = data.get_u8() != 0;
509
510 Ok(Self {
511 correlation_id,
512 decision,
513 request_headers,
514 response_headers,
515 needs_more,
516 })
517 }
518}
519
520fn put_string(buf: &mut BytesMut, s: &str) {
525 let bytes = s.as_bytes();
526 buf.put_u16(bytes.len() as u16);
527 buf.put_slice(bytes);
528}
529
530fn get_string(data: &mut Bytes) -> Result<String, AgentProtocolError> {
531 if data.remaining() < 2 {
532 return Err(AgentProtocolError::InvalidMessage(
533 "Missing string length".to_string(),
534 ));
535 }
536 let len = data.get_u16() as usize;
537 if data.remaining() < len {
538 return Err(AgentProtocolError::InvalidMessage(
539 "String data truncated".to_string(),
540 ));
541 }
542 let bytes = data.copy_to_bytes(len);
543 String::from_utf8(bytes.to_vec())
544 .map_err(|e| AgentProtocolError::InvalidMessage(format!("Invalid UTF-8: {}", e)))
545}
546
547fn put_optional_string(buf: &mut BytesMut, s: Option<&str>) {
548 match s {
549 Some(s) => {
550 buf.put_u8(1);
551 put_string(buf, s);
552 }
553 None => {
554 buf.put_u8(0);
555 }
556 }
557}
558
559fn get_optional_string(data: &mut Bytes) -> Result<Option<String>, AgentProtocolError> {
560 if data.remaining() < 1 {
561 return Err(AgentProtocolError::InvalidMessage(
562 "Missing optional string flag".to_string(),
563 ));
564 }
565 let present = data.get_u8() != 0;
566 if present {
567 get_string(data).map(Some)
568 } else {
569 Ok(None)
570 }
571}
572
573fn encode_header_op(buf: &mut BytesMut, op: &HeaderOp) {
574 match op {
575 HeaderOp::Set { name, value } => {
576 buf.put_u8(0);
577 put_string(buf, name);
578 put_string(buf, value);
579 }
580 HeaderOp::Add { name, value } => {
581 buf.put_u8(1);
582 put_string(buf, name);
583 put_string(buf, value);
584 }
585 HeaderOp::Remove { name } => {
586 buf.put_u8(2);
587 put_string(buf, name);
588 }
589 }
590}
591
592fn decode_header_op(data: &mut Bytes) -> Result<HeaderOp, AgentProtocolError> {
593 if data.remaining() < 1 {
594 return Err(AgentProtocolError::InvalidMessage(
595 "Missing header op type".to_string(),
596 ));
597 }
598 let op_type = data.get_u8();
599 match op_type {
600 0 => {
601 let name = get_string(data)?;
602 let value = get_string(data)?;
603 Ok(HeaderOp::Set { name, value })
604 }
605 1 => {
606 let name = get_string(data)?;
607 let value = get_string(data)?;
608 Ok(HeaderOp::Add { name, value })
609 }
610 2 => {
611 let name = get_string(data)?;
612 Ok(HeaderOp::Remove { name })
613 }
614 _ => Err(AgentProtocolError::InvalidMessage(format!(
615 "Unknown header op type: {}",
616 op_type
617 ))),
618 }
619}
620
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must survive the `as u8` / `try_from` round trip.
    #[test]
    fn test_message_type_roundtrip() {
        for t in [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::RequestHeaders,
            MessageType::RequestBodyChunk,
            MessageType::ResponseHeaders,
            MessageType::ResponseBodyChunk,
            MessageType::RequestComplete,
            MessageType::WebSocketFrame,
            MessageType::AgentResponse,
            MessageType::Ping,
            MessageType::Pong,
            MessageType::Cancel,
            MessageType::Error,
        ] {
            let byte = t as u8;
            let decoded = MessageType::try_from(byte).unwrap();
            assert_eq!(t, decoded);
        }
    }

    /// Bytes that are not a discriminant are rejected.
    #[test]
    fn test_message_type_rejects_unknown_byte() {
        assert!(MessageType::try_from(0x00).is_err());
        assert!(MessageType::try_from(0x99).is_err());
    }

    #[test]
    fn test_binary_frame_encode_decode() {
        let frame = BinaryFrame::new(MessageType::Ping, Bytes::from_static(b"hello"));
        let encoded = frame.encode();

        // 4-byte length prefix + 1 type byte + 5 payload bytes.
        assert_eq!(encoded.len(), 4 + 1 + 5);
        // The length field covers the type byte and the payload.
        assert_eq!(&encoded[0..4], &[0, 0, 0, 6]);
        assert_eq!(encoded[4], MessageType::Ping as u8);
        assert_eq!(&encoded[5..], b"hello");
    }

    #[test]
    fn test_binary_request_headers_roundtrip() {
        let headers = BinaryRequestHeaders {
            correlation_id: "req-123".to_string(),
            method: "POST".to_string(),
            uri: "/api/test".to_string(),
            headers: {
                let mut h = HashMap::new();
                h.insert("content-type".to_string(), vec!["application/json".to_string()]);
                h.insert("x-custom".to_string(), vec!["value1".to_string(), "value2".to_string()]);
                h
            },
            client_ip: "192.168.1.1".to_string(),
            client_port: 12345,
        };

        let encoded = headers.encode();
        let decoded = BinaryRequestHeaders::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-123");
        assert_eq!(decoded.method, "POST");
        assert_eq!(decoded.uri, "/api/test");
        assert_eq!(decoded.client_ip, "192.168.1.1");
        assert_eq!(decoded.client_port, 12345);
        assert_eq!(decoded.headers.len(), 2);
        assert_eq!(decoded.headers.get("content-type").unwrap(), &vec!["application/json".to_string()]);
        // Values of a multi-valued header keep their relative order.
        assert_eq!(
            decoded.headers.get("x-custom").unwrap(),
            &vec!["value1".to_string(), "value2".to_string()]
        );
    }

    #[test]
    fn test_binary_request_headers_rejects_empty_input() {
        assert!(BinaryRequestHeaders::decode(Bytes::new()).is_err());
    }

    #[test]
    fn test_binary_body_chunk_roundtrip() {
        let chunk = BinaryBodyChunk {
            correlation_id: "req-456".to_string(),
            chunk_index: 2,
            is_last: true,
            data: Bytes::from_static(b"binary data here"),
        };

        let encoded = chunk.encode();
        let decoded = BinaryBodyChunk::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-456");
        assert_eq!(decoded.chunk_index, 2);
        assert!(decoded.is_last);
        assert_eq!(&decoded.data[..], b"binary data here");
    }

    #[test]
    fn test_binary_body_chunk_rejects_truncated_input() {
        // An empty correlation id followed by nothing: the fixed fields are missing.
        assert!(BinaryBodyChunk::decode(Bytes::from_static(&[0, 0])).is_err());
    }

    #[test]
    fn test_get_string_rejects_truncated_data() {
        // The prefix declares five bytes but only one follows.
        let mut data = Bytes::from_static(&[0, 5, b'a']);
        assert!(get_string(&mut data).is_err());
    }

    #[test]
    fn test_binary_agent_response_allow() {
        let response = BinaryAgentResponse {
            correlation_id: "req-789".to_string(),
            decision: Decision::Allow,
            request_headers: vec![HeaderOp::Set {
                name: "X-Added".to_string(),
                value: "true".to_string(),
            }],
            response_headers: vec![],
            needs_more: false,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-789");
        assert!(matches!(decoded.decision, Decision::Allow));
        assert_eq!(decoded.request_headers.len(), 1);
        assert!(decoded.response_headers.is_empty());
        assert!(!decoded.needs_more);
    }

    #[test]
    fn test_binary_agent_response_block() {
        let response = BinaryAgentResponse {
            correlation_id: "req-block".to_string(),
            decision: Decision::Block {
                status: 403,
                body: Some("Forbidden".to_string()),
                headers: None,
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-block");
        match decoded.decision {
            Decision::Block { status, body, headers } => {
                assert_eq!(status, 403);
                assert_eq!(body, Some("Forbidden".to_string()));
                assert!(headers.is_none());
            }
            _ => panic!("Expected Block decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_redirect() {
        let response = BinaryAgentResponse {
            correlation_id: "req-redirect".to_string(),
            decision: Decision::Redirect {
                url: "https://example.com/login".to_string(),
                status: 302,
            },
            request_headers: vec![],
            response_headers: vec![HeaderOp::Remove {
                name: "X-Internal".to_string(),
            }],
            needs_more: true,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-redirect");
        match decoded.decision {
            Decision::Redirect { url, status } => {
                assert_eq!(url, "https://example.com/login");
                assert_eq!(status, 302);
            }
            _ => panic!("Expected Redirect decision"),
        }
        assert_eq!(decoded.response_headers.len(), 1);
        assert!(decoded.needs_more);
    }

    #[test]
    fn test_binary_agent_response_rejects_unknown_decision() {
        // Empty correlation id followed by an undefined decision tag.
        assert!(BinaryAgentResponse::decode(Bytes::from_static(&[0, 0, 9])).is_err());
    }
}