1#![allow(dead_code)]
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{UnixListener, UnixStream};
19use tracing::{debug, error, info};
20
21
/// Version stamped on every `AgentRequest` / `AgentResponse`; the client
/// rejects a response whose `version` differs from this value.
pub const PROTOCOL_VERSION: u32 = 1;

/// Largest frame body, in bytes (10 MiB), that the framing code accepts.
/// Frames announcing a larger length are rejected with `MessageTooLarge`.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;
27
28#[derive(Error, Debug)]
30pub enum AgentProtocolError {
31 #[error("Connection failed: {0}")]
32 ConnectionFailed(String),
33
34 #[error("Protocol version mismatch: expected {expected}, got {actual}")]
35 VersionMismatch { expected: u32, actual: u32 },
36
37 #[error("Message too large: {size} bytes (max: {max}")]
38 MessageTooLarge { size: usize, max: usize },
39
40 #[error("Invalid message format: {0}")]
41 InvalidMessage(String),
42
43 #[error("Timeout after {0:?}")]
44 Timeout(Duration),
45
46 #[error("Agent unavailable")]
47 Unavailable,
48
49 #[error("IO error: {0}")]
50 Io(#[from] std::io::Error),
51
52 #[error("Serialization error: {0}")]
53 Serialization(String),
54}
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
58#[serde(rename_all = "snake_case")]
59pub enum EventType {
60 RequestHeaders,
62 RequestBodyChunk,
64 ResponseHeaders,
66 ResponseBodyChunk,
68 RequestComplete,
70}
71
72#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
74#[serde(rename_all = "snake_case")]
75pub enum Decision {
76 Allow,
78 Block {
80 status: u16,
82 body: Option<String>,
84 headers: Option<HashMap<String, String>>,
86 },
87 Redirect {
89 url: String,
91 status: u16,
93 },
94 Challenge {
96 challenge_type: String,
98 params: HashMap<String, String>,
100 },
101}
102
103impl Default for Decision {
104 fn default() -> Self {
105 Self::Allow
106 }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
111#[serde(rename_all = "snake_case")]
112pub enum HeaderOp {
113 Set { name: String, value: String },
115 Add { name: String, value: String },
117 Remove { name: String },
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct RequestMetadata {
124 pub correlation_id: String,
126 pub request_id: String,
128 pub client_ip: String,
130 pub client_port: u16,
132 pub server_name: Option<String>,
134 pub protocol: String,
136 pub tls_version: Option<String>,
138 pub tls_cipher: Option<String>,
140 pub route_id: Option<String>,
142 pub upstream_id: Option<String>,
144 pub timestamp: String,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct RequestHeadersEvent {
151 pub metadata: RequestMetadata,
153 pub method: String,
155 pub uri: String,
157 pub headers: HashMap<String, Vec<String>>,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct RequestBodyChunkEvent {
164 pub correlation_id: String,
166 pub data: String,
168 pub is_last: bool,
170 pub total_size: Option<usize>,
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct ResponseHeadersEvent {
177 pub correlation_id: String,
179 pub status: u16,
181 pub headers: HashMap<String, Vec<String>>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct ResponseBodyChunkEvent {
188 pub correlation_id: String,
190 pub data: String,
192 pub is_last: bool,
194 pub total_size: Option<usize>,
196}
197
198#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct RequestCompleteEvent {
201 pub correlation_id: String,
203 pub status: u16,
205 pub duration_ms: u64,
207 pub request_body_size: usize,
209 pub response_body_size: usize,
211 pub upstream_attempts: u32,
213 pub error: Option<String>,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct AgentRequest {
220 pub version: u32,
222 pub event_type: EventType,
224 pub payload: serde_json::Value,
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct AgentResponse {
231 pub version: u32,
233 pub decision: Decision,
235 #[serde(default)]
237 pub request_headers: Vec<HeaderOp>,
238 #[serde(default)]
240 pub response_headers: Vec<HeaderOp>,
241 #[serde(default)]
243 pub routing_metadata: HashMap<String, String>,
244 #[serde(default)]
246 pub audit: AuditMetadata,
247}
248
249#[derive(Debug, Clone, Default, Serialize, Deserialize)]
251pub struct AuditMetadata {
252 #[serde(default)]
254 pub tags: Vec<String>,
255 #[serde(default)]
257 pub rule_ids: Vec<String>,
258 pub confidence: Option<f32>,
260 #[serde(default)]
262 pub reason_codes: Vec<String>,
263 #[serde(default)]
265 pub custom: HashMap<String, serde_json::Value>,
266}
267
268pub struct AgentClient {
270 id: String,
272 connection: AgentConnection,
274 timeout: Duration,
276 max_retries: u32,
278}
279
280enum AgentConnection {
282 UnixSocket(UnixStream),
283 Grpc(tonic::transport::Channel),
284}
285
286impl AgentClient {
287 pub async fn unix_socket(
289 id: impl Into<String>,
290 path: impl AsRef<std::path::Path>,
291 timeout: Duration,
292 ) -> Result<Self, AgentProtocolError> {
293 let stream = UnixStream::connect(path.as_ref())
294 .await
295 .map_err(|e| AgentProtocolError::ConnectionFailed(e.to_string()))?;
296
297 Ok(Self {
298 id: id.into(),
299 connection: AgentConnection::UnixSocket(stream),
300 timeout,
301 max_retries: 3,
302 })
303 }
304
305 pub async fn send_event(
307 &mut self,
308 event_type: EventType,
309 payload: impl Serialize,
310 ) -> Result<AgentResponse, AgentProtocolError> {
311 let request = AgentRequest {
312 version: PROTOCOL_VERSION,
313 event_type,
314 payload: serde_json::to_value(payload)
315 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
316 };
317
318 let request_bytes = serde_json::to_vec(&request)
320 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
321
322 if request_bytes.len() > MAX_MESSAGE_SIZE {
324 return Err(AgentProtocolError::MessageTooLarge {
325 size: request_bytes.len(),
326 max: MAX_MESSAGE_SIZE,
327 });
328 }
329
330 let response = tokio::time::timeout(self.timeout, async {
332 self.send_raw(&request_bytes).await?;
333 self.receive_raw().await
334 })
335 .await
336 .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
337
338 let agent_response: AgentResponse = serde_json::from_slice(&response)
340 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
341
342 if agent_response.version != PROTOCOL_VERSION {
344 return Err(AgentProtocolError::VersionMismatch {
345 expected: PROTOCOL_VERSION,
346 actual: agent_response.version,
347 });
348 }
349
350 Ok(agent_response)
351 }
352
353 async fn send_raw(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
355 match &mut self.connection {
356 AgentConnection::UnixSocket(stream) => {
357 let len_bytes = (data.len() as u32).to_be_bytes();
359 stream.write_all(&len_bytes).await?;
360 stream.write_all(data).await?;
362 stream.flush().await?;
363 Ok(())
364 }
365 AgentConnection::Grpc(_channel) => {
366 unimplemented!("gRPC transport not yet implemented")
368 }
369 }
370 }
371
372 async fn receive_raw(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
374 match &mut self.connection {
375 AgentConnection::UnixSocket(stream) => {
376 let mut len_bytes = [0u8; 4];
378 stream.read_exact(&mut len_bytes).await?;
379 let message_len = u32::from_be_bytes(len_bytes) as usize;
380
381 if message_len > MAX_MESSAGE_SIZE {
383 return Err(AgentProtocolError::MessageTooLarge {
384 size: message_len,
385 max: MAX_MESSAGE_SIZE,
386 });
387 }
388
389 let mut buffer = vec![0u8; message_len];
391 stream.read_exact(&mut buffer).await?;
392 Ok(buffer)
393 }
394 AgentConnection::Grpc(_channel) => {
395 unimplemented!("gRPC transport not yet implemented")
397 }
398 }
399 }
400
401 pub async fn close(self) -> Result<(), AgentProtocolError> {
403 match self.connection {
404 AgentConnection::UnixSocket(mut stream) => {
405 stream.shutdown().await?;
406 Ok(())
407 }
408 AgentConnection::Grpc(_) => Ok(()),
409 }
410 }
411}
412
413pub struct AgentServer {
415 id: String,
417 socket_path: std::path::PathBuf,
419 handler: Arc<dyn AgentHandler>,
421}
422
423#[async_trait]
425pub trait AgentHandler: Send + Sync {
426 async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
428 AgentResponse::default_allow()
429 }
430
431 async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
433 AgentResponse::default_allow()
434 }
435
436 async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
438 AgentResponse::default_allow()
439 }
440
441 async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
443 AgentResponse::default_allow()
444 }
445
446 async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
448 AgentResponse::default_allow()
449 }
450}
451
452impl AgentResponse {
453 pub fn default_allow() -> Self {
455 Self {
456 version: PROTOCOL_VERSION,
457 decision: Decision::Allow,
458 request_headers: vec![],
459 response_headers: vec![],
460 routing_metadata: HashMap::new(),
461 audit: AuditMetadata::default(),
462 }
463 }
464
465 pub fn block(status: u16, body: Option<String>) -> Self {
467 Self {
468 version: PROTOCOL_VERSION,
469 decision: Decision::Block {
470 status,
471 body,
472 headers: None,
473 },
474 request_headers: vec![],
475 response_headers: vec![],
476 routing_metadata: HashMap::new(),
477 audit: AuditMetadata::default(),
478 }
479 }
480
481 pub fn redirect(url: String, status: u16) -> Self {
483 Self {
484 version: PROTOCOL_VERSION,
485 decision: Decision::Redirect { url, status },
486 request_headers: vec![],
487 response_headers: vec![],
488 routing_metadata: HashMap::new(),
489 audit: AuditMetadata::default(),
490 }
491 }
492
493 pub fn add_request_header(mut self, op: HeaderOp) -> Self {
495 self.request_headers.push(op);
496 self
497 }
498
499 pub fn add_response_header(mut self, op: HeaderOp) -> Self {
501 self.response_headers.push(op);
502 self
503 }
504
505 pub fn with_audit(mut self, audit: AuditMetadata) -> Self {
507 self.audit = audit;
508 self
509 }
510}
511
512impl AgentServer {
513 pub fn new(
515 id: impl Into<String>,
516 socket_path: impl Into<std::path::PathBuf>,
517 handler: Box<dyn AgentHandler>,
518 ) -> Self {
519 Self {
520 id: id.into(),
521 socket_path: socket_path.into(),
522 handler: Arc::from(handler),
523 }
524 }
525
526 pub async fn run(&self) -> Result<(), AgentProtocolError> {
528 if self.socket_path.exists() {
530 std::fs::remove_file(&self.socket_path)?;
531 }
532
533 let listener = UnixListener::bind(&self.socket_path)?;
535
536 info!(
537 "Agent server '{}' listening on {:?}",
538 self.id, self.socket_path
539 );
540
541 loop {
542 match listener.accept().await {
543 Ok((stream, _addr)) => {
544 let handler = Arc::clone(&self.handler);
545 tokio::spawn(async move {
546 if let Err(e) = Self::handle_connection(stream, handler.as_ref()).await {
547 error!("Error handling agent connection: {}", e);
548 }
549 });
550 }
551 Err(e) => {
552 error!("Failed to accept connection: {}", e);
553 }
554 }
555 }
556 }
557
558 async fn handle_connection(
560 mut stream: UnixStream,
561 handler: &dyn AgentHandler,
562 ) -> Result<(), AgentProtocolError> {
563 loop {
564 let mut len_bytes = [0u8; 4];
566 match stream.read_exact(&mut len_bytes).await {
567 Ok(_) => {}
568 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
569 return Ok(());
571 }
572 Err(e) => return Err(e.into()),
573 }
574
575 let message_len = u32::from_be_bytes(len_bytes) as usize;
576
577 if message_len > MAX_MESSAGE_SIZE {
579 return Err(AgentProtocolError::MessageTooLarge {
580 size: message_len,
581 max: MAX_MESSAGE_SIZE,
582 });
583 }
584
585 let mut buffer = vec![0u8; message_len];
587 stream.read_exact(&mut buffer).await?;
588
589 let request: AgentRequest = serde_json::from_slice(&buffer)
591 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
592
593 let response = match request.event_type {
595 EventType::RequestHeaders => {
596 let event: RequestHeadersEvent = serde_json::from_value(request.payload)
597 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
598 handler.on_request_headers(event).await
599 }
600 EventType::RequestBodyChunk => {
601 let event: RequestBodyChunkEvent = serde_json::from_value(request.payload)
602 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
603 handler.on_request_body_chunk(event).await
604 }
605 EventType::ResponseHeaders => {
606 let event: ResponseHeadersEvent = serde_json::from_value(request.payload)
607 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
608 handler.on_response_headers(event).await
609 }
610 EventType::ResponseBodyChunk => {
611 let event: ResponseBodyChunkEvent = serde_json::from_value(request.payload)
612 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
613 handler.on_response_body_chunk(event).await
614 }
615 EventType::RequestComplete => {
616 let event: RequestCompleteEvent = serde_json::from_value(request.payload)
617 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
618 handler.on_request_complete(event).await
619 }
620 };
621
622 let response_bytes = serde_json::to_vec(&response)
624 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
625
626 let len_bytes = (response_bytes.len() as u32).to_be_bytes();
628 stream.write_all(&len_bytes).await?;
629 stream.write_all(&response_bytes).await?;
631 stream.flush().await?;
632 }
633 }
634}
635
/// Example agent that allows every request and tags it with an
/// `X-Echo-Agent` request header (see its `AgentHandler` impl).
pub struct EchoAgent;
638
639#[async_trait]
640impl AgentHandler for EchoAgent {
641 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
642 debug!("Echo agent: request headers for {}", event.metadata.correlation_id);
643
644 AgentResponse::default_allow()
646 .add_request_header(HeaderOp::Set {
647 name: "X-Echo-Agent".to_string(),
648 value: event.metadata.correlation_id.clone(),
649 })
650 .with_audit(AuditMetadata {
651 tags: vec!["echo".to_string()],
652 ..Default::default()
653 })
654 }
655}
656
/// Example agent that blocks requests by URI prefix or exact client IP.
pub struct DenylistAgent {
    /// URI prefixes to reject with 403.
    blocked_paths: Vec<String>,
    /// Client IP strings to reject with 403 (exact string match).
    blocked_ips: Vec<String>,
}
662
663impl DenylistAgent {
664 pub fn new(blocked_paths: Vec<String>, blocked_ips: Vec<String>) -> Self {
665 Self {
666 blocked_paths,
667 blocked_ips,
668 }
669 }
670}
671
672#[async_trait]
673impl AgentHandler for DenylistAgent {
674 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
675 for blocked_path in &self.blocked_paths {
677 if event.uri.starts_with(blocked_path) {
678 return AgentResponse::block(403, Some("Forbidden path".to_string()))
679 .with_audit(AuditMetadata {
680 tags: vec!["denylist".to_string(), "blocked_path".to_string()],
681 reason_codes: vec!["PATH_BLOCKED".to_string()],
682 ..Default::default()
683 });
684 }
685 }
686
687 if self.blocked_ips.contains(&event.metadata.client_ip) {
689 return AgentResponse::block(403, Some("Forbidden IP".to_string()))
690 .with_audit(AuditMetadata {
691 tags: vec!["denylist".to_string(), "blocked_ip".to_string()],
692 reason_codes: vec!["IP_BLOCKED".to_string()],
693 ..Default::default()
694 });
695 }
696
697 AgentResponse::default_allow()
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a `RequestHeadersEvent` for `uri` with fixed test metadata
    /// (correlation id `test-123`, client IP `127.0.0.1`).
    fn sample_request_headers(uri: &str) -> RequestHeadersEvent {
        RequestHeadersEvent {
            metadata: RequestMetadata {
                correlation_id: "test-123".to_string(),
                request_id: "req-456".to_string(),
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_name: Some("example.com".to_string()),
                protocol: "HTTP/1.1".to_string(),
                tls_version: None,
                tls_cipher: None,
                route_id: Some("default".to_string()),
                upstream_id: Some("backend".to_string()),
                timestamp: chrono::Utc::now().to_rfc3339(),
            },
            method: "GET".to_string(),
            uri: uri.to_string(),
            headers: HashMap::new(),
        }
    }

    /// Connects to `socket_path`, retrying while the spawned server task is
    /// still binding. A single fixed sleep is flaky on slow machines.
    async fn connect_with_retry(socket_path: &std::path::Path) -> AgentClient {
        let mut last_err = None;
        for _ in 0..100 {
            match AgentClient::unix_socket("test-client", socket_path, Duration::from_secs(5)).await {
                Ok(client) => return client,
                Err(e) => last_err = Some(e),
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        panic!("agent server did not start listening: {:?}", last_err);
    }

    #[tokio::test]
    async fn test_agent_protocol_echo() {
        let dir = tempdir().unwrap();
        let socket_path = dir.path().join("test.sock");

        let server = AgentServer::new("test-echo", socket_path.clone(), Box::new(EchoAgent));
        let server_handle = tokio::spawn(async move {
            server.run().await.unwrap();
        });

        let mut client = connect_with_retry(&socket_path).await;

        let event = sample_request_headers("/test");
        let response = client
            .send_event(EventType::RequestHeaders, &event)
            .await
            .unwrap();

        assert_eq!(response.decision, Decision::Allow);
        assert_eq!(
            response.request_headers,
            vec![HeaderOp::Set {
                name: "X-Echo-Agent".to_string(),
                value: "test-123".to_string(),
            }]
        );
        assert_eq!(response.audit.tags, vec!["echo".to_string()]);

        client.close().await.unwrap();
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_agent_protocol_denylist() {
        let dir = tempdir().unwrap();
        let socket_path = dir.path().join("denylist.sock");

        let agent = DenylistAgent::new(
            vec!["/admin".to_string()],
            vec!["10.0.0.1".to_string()],
        );
        let server = AgentServer::new("test-denylist", socket_path.clone(), Box::new(agent));
        let server_handle = tokio::spawn(async move {
            server.run().await.unwrap();
        });

        let mut client = connect_with_retry(&socket_path).await;

        // Blocked path.
        let event = sample_request_headers("/admin/secret");
        let response = client
            .send_event(EventType::RequestHeaders, &event)
            .await
            .unwrap();
        match response.decision {
            Decision::Block { status, .. } => assert_eq!(status, 403),
            other => panic!("Expected block decision, got {:?}", other),
        }
        assert_eq!(response.audit.reason_codes, vec!["PATH_BLOCKED".to_string()]);

        // Allowed request on the same connection (multiple frames per connection).
        let event = sample_request_headers("/public");
        let response = client
            .send_event(EventType::RequestHeaders, &event)
            .await
            .unwrap();
        assert_eq!(response.decision, Decision::Allow);

        // Blocked client IP.
        let mut event = sample_request_headers("/public");
        event.metadata.client_ip = "10.0.0.1".to_string();
        let response = client
            .send_event(EventType::RequestHeaders, &event)
            .await
            .unwrap();
        assert!(matches!(response.decision, Decision::Block { status: 403, .. }));
        assert_eq!(response.audit.reason_codes, vec!["IP_BLOCKED".to_string()]);

        client.close().await.unwrap();
        server_handle.abort();
    }
}
833}