sentinel_agent_protocol/
server.rs

1//! Agent server for implementing external agents.
2
3use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::{UnixListener, UnixStream};
7use tracing::{debug, error, info};
8
9use crate::errors::AgentProtocolError;
10use crate::protocol::{
11    AgentRequest, AgentResponse, AuditMetadata, EventType, HeaderOp, RequestBodyChunkEvent,
12    RequestCompleteEvent, RequestHeadersEvent, ResponseBodyChunkEvent, ResponseHeadersEvent,
13    MAX_MESSAGE_SIZE,
14};
15
16/// Agent server for testing and reference implementations
17pub struct AgentServer {
18    /// Agent ID
19    id: String,
20    /// Unix socket path
21    socket_path: std::path::PathBuf,
22    /// Request handler
23    handler: Arc<dyn AgentHandler>,
24}
25
26/// Trait for implementing agent logic
27#[async_trait]
28pub trait AgentHandler: Send + Sync {
29    /// Handle a request headers event
30    async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
31        AgentResponse::default_allow()
32    }
33
34    /// Handle a request body chunk event
35    async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
36        AgentResponse::default_allow()
37    }
38
39    /// Handle a response headers event
40    async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
41        AgentResponse::default_allow()
42    }
43
44    /// Handle a response body chunk event
45    async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
46        AgentResponse::default_allow()
47    }
48
49    /// Handle a request complete event
50    async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
51        AgentResponse::default_allow()
52    }
53}
54
55impl AgentServer {
56    /// Create a new agent server
57    pub fn new(
58        id: impl Into<String>,
59        socket_path: impl Into<std::path::PathBuf>,
60        handler: Box<dyn AgentHandler>,
61    ) -> Self {
62        Self {
63            id: id.into(),
64            socket_path: socket_path.into(),
65            handler: Arc::from(handler),
66        }
67    }
68
69    /// Start the agent server
70    pub async fn run(&self) -> Result<(), AgentProtocolError> {
71        // Remove existing socket file if it exists
72        if self.socket_path.exists() {
73            std::fs::remove_file(&self.socket_path)?;
74        }
75
76        // Create Unix socket listener
77        let listener = UnixListener::bind(&self.socket_path)?;
78
79        info!(
80            "Agent server '{}' listening on {:?}",
81            self.id, self.socket_path
82        );
83
84        loop {
85            match listener.accept().await {
86                Ok((stream, _addr)) => {
87                    let handler = Arc::clone(&self.handler);
88                    tokio::spawn(async move {
89                        if let Err(e) = Self::handle_connection(stream, handler.as_ref()).await {
90                            error!("Error handling agent connection: {}", e);
91                        }
92                    });
93                }
94                Err(e) => {
95                    error!("Failed to accept connection: {}", e);
96                }
97            }
98        }
99    }
100
101    /// Handle a single connection
102    async fn handle_connection(
103        mut stream: UnixStream,
104        handler: &dyn AgentHandler,
105    ) -> Result<(), AgentProtocolError> {
106        loop {
107            // Read message length
108            let mut len_bytes = [0u8; 4];
109            match stream.read_exact(&mut len_bytes).await {
110                Ok(_) => {}
111                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
112                    // Client disconnected
113                    return Ok(());
114                }
115                Err(e) => return Err(e.into()),
116            }
117
118            let message_len = u32::from_be_bytes(len_bytes) as usize;
119
120            // Check message size
121            if message_len > MAX_MESSAGE_SIZE {
122                return Err(AgentProtocolError::MessageTooLarge {
123                    size: message_len,
124                    max: MAX_MESSAGE_SIZE,
125                });
126            }
127
128            // Read message data
129            let mut buffer = vec![0u8; message_len];
130            stream.read_exact(&mut buffer).await?;
131
132            // Parse request
133            let request: AgentRequest = serde_json::from_slice(&buffer)
134                .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
135
136            // Handle request based on event type
137            let response = match request.event_type {
138                EventType::RequestHeaders => {
139                    let event: RequestHeadersEvent = serde_json::from_value(request.payload)
140                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
141                    handler.on_request_headers(event).await
142                }
143                EventType::RequestBodyChunk => {
144                    let event: RequestBodyChunkEvent = serde_json::from_value(request.payload)
145                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
146                    handler.on_request_body_chunk(event).await
147                }
148                EventType::ResponseHeaders => {
149                    let event: ResponseHeadersEvent = serde_json::from_value(request.payload)
150                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
151                    handler.on_response_headers(event).await
152                }
153                EventType::ResponseBodyChunk => {
154                    let event: ResponseBodyChunkEvent = serde_json::from_value(request.payload)
155                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
156                    handler.on_response_body_chunk(event).await
157                }
158                EventType::RequestComplete => {
159                    let event: RequestCompleteEvent = serde_json::from_value(request.payload)
160                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
161                    handler.on_request_complete(event).await
162                }
163            };
164
165            // Send response
166            let response_bytes = serde_json::to_vec(&response)
167                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
168
169            // Write message length
170            let len_bytes = (response_bytes.len() as u32).to_be_bytes();
171            stream.write_all(&len_bytes).await?;
172            // Write message data
173            stream.write_all(&response_bytes).await?;
174            stream.flush().await?;
175        }
176    }
177}
178
179/// Reference implementation: Echo agent (for testing)
180pub struct EchoAgent;
181
182#[async_trait]
183impl AgentHandler for EchoAgent {
184    async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
185        debug!(
186            "Echo agent: request headers for {}",
187            event.metadata.correlation_id
188        );
189
190        // Echo back correlation ID as a header
191        AgentResponse::default_allow()
192            .add_request_header(HeaderOp::Set {
193                name: "X-Echo-Agent".to_string(),
194                value: event.metadata.correlation_id.clone(),
195            })
196            .with_audit(AuditMetadata {
197                tags: vec!["echo".to_string()],
198                ..Default::default()
199            })
200    }
201}
202
/// Reference agent that rejects requests by URI prefix or by client IP.
pub struct DenylistAgent {
    /// URI prefixes to reject; a request is blocked when its URI starts with
    /// any entry.
    blocked_paths: Vec<String>,
    /// Client IPs to reject, compared by exact string equality.
    blocked_ips: Vec<String>,
}

impl DenylistAgent {
    /// Builds a denylist agent from the given URI prefixes and client IPs.
    ///
    /// Both lists are stored as given, without validation or normalisation.
    pub fn new(blocked_paths: Vec<String>, blocked_ips: Vec<String>) -> Self {
        DenylistAgent {
            blocked_paths,
            blocked_ips,
        }
    }
}
217
218#[async_trait]
219impl AgentHandler for DenylistAgent {
220    async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
221        // Check if path is blocked
222        for blocked_path in &self.blocked_paths {
223            if event.uri.starts_with(blocked_path) {
224                return AgentResponse::block(403, Some("Forbidden path".to_string())).with_audit(
225                    AuditMetadata {
226                        tags: vec!["denylist".to_string(), "blocked_path".to_string()],
227                        reason_codes: vec!["PATH_BLOCKED".to_string()],
228                        ..Default::default()
229                    },
230                );
231            }
232        }
233
234        // Check if IP is blocked
235        if self.blocked_ips.contains(&event.metadata.client_ip) {
236            return AgentResponse::block(403, Some("Forbidden IP".to_string())).with_audit(
237                AuditMetadata {
238                    tags: vec!["denylist".to_string(), "blocked_ip".to_string()],
239                    reason_codes: vec!["IP_BLOCKED".to_string()],
240                    ..Default::default()
241                },
242            );
243        }
244
245        AgentResponse::default_allow()
246    }
247}