sentinel_agent_protocol/
server.rs1use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::{UnixListener, UnixStream};
7use tracing::{debug, error, info};
8
9use crate::errors::AgentProtocolError;
10use crate::protocol::{
11 AgentRequest, AgentResponse, AuditMetadata, EventType, HeaderOp, RequestBodyChunkEvent,
12 RequestCompleteEvent, RequestHeadersEvent, ResponseBodyChunkEvent, ResponseHeadersEvent,
13 MAX_MESSAGE_SIZE,
14};
15
/// Serves an [`AgentHandler`] over a Unix domain socket.
///
/// Construct with [`AgentServer::new`] and start with [`AgentServer::run`].
pub struct AgentServer {
    /// Identifier printed in the "listening" log message emitted by `run`.
    id: String,
    /// Path the Unix listener binds to; `run` removes any existing file here first.
    socket_path: std::path::PathBuf,
    /// Handler shared (via `Arc`) with every spawned connection task.
    handler: Arc<dyn AgentHandler>,
}
25
26#[async_trait]
28pub trait AgentHandler: Send + Sync {
29 async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
31 AgentResponse::default_allow()
32 }
33
34 async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
36 AgentResponse::default_allow()
37 }
38
39 async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
41 AgentResponse::default_allow()
42 }
43
44 async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
46 AgentResponse::default_allow()
47 }
48
49 async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
51 AgentResponse::default_allow()
52 }
53}
54
55impl AgentServer {
56 pub fn new(
58 id: impl Into<String>,
59 socket_path: impl Into<std::path::PathBuf>,
60 handler: Box<dyn AgentHandler>,
61 ) -> Self {
62 Self {
63 id: id.into(),
64 socket_path: socket_path.into(),
65 handler: Arc::from(handler),
66 }
67 }
68
69 pub async fn run(&self) -> Result<(), AgentProtocolError> {
71 if self.socket_path.exists() {
73 std::fs::remove_file(&self.socket_path)?;
74 }
75
76 let listener = UnixListener::bind(&self.socket_path)?;
78
79 info!(
80 "Agent server '{}' listening on {:?}",
81 self.id, self.socket_path
82 );
83
84 loop {
85 match listener.accept().await {
86 Ok((stream, _addr)) => {
87 let handler = Arc::clone(&self.handler);
88 tokio::spawn(async move {
89 if let Err(e) = Self::handle_connection(stream, handler.as_ref()).await {
90 error!("Error handling agent connection: {}", e);
91 }
92 });
93 }
94 Err(e) => {
95 error!("Failed to accept connection: {}", e);
96 }
97 }
98 }
99 }
100
101 async fn handle_connection(
103 mut stream: UnixStream,
104 handler: &dyn AgentHandler,
105 ) -> Result<(), AgentProtocolError> {
106 loop {
107 let mut len_bytes = [0u8; 4];
109 match stream.read_exact(&mut len_bytes).await {
110 Ok(_) => {}
111 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
112 return Ok(());
114 }
115 Err(e) => return Err(e.into()),
116 }
117
118 let message_len = u32::from_be_bytes(len_bytes) as usize;
119
120 if message_len > MAX_MESSAGE_SIZE {
122 return Err(AgentProtocolError::MessageTooLarge {
123 size: message_len,
124 max: MAX_MESSAGE_SIZE,
125 });
126 }
127
128 let mut buffer = vec![0u8; message_len];
130 stream.read_exact(&mut buffer).await?;
131
132 let request: AgentRequest = serde_json::from_slice(&buffer)
134 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
135
136 let response = match request.event_type {
138 EventType::RequestHeaders => {
139 let event: RequestHeadersEvent = serde_json::from_value(request.payload)
140 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
141 handler.on_request_headers(event).await
142 }
143 EventType::RequestBodyChunk => {
144 let event: RequestBodyChunkEvent = serde_json::from_value(request.payload)
145 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
146 handler.on_request_body_chunk(event).await
147 }
148 EventType::ResponseHeaders => {
149 let event: ResponseHeadersEvent = serde_json::from_value(request.payload)
150 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
151 handler.on_response_headers(event).await
152 }
153 EventType::ResponseBodyChunk => {
154 let event: ResponseBodyChunkEvent = serde_json::from_value(request.payload)
155 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
156 handler.on_response_body_chunk(event).await
157 }
158 EventType::RequestComplete => {
159 let event: RequestCompleteEvent = serde_json::from_value(request.payload)
160 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
161 handler.on_request_complete(event).await
162 }
163 };
164
165 let response_bytes = serde_json::to_vec(&response)
167 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
168
169 let len_bytes = (response_bytes.len() as u32).to_be_bytes();
171 stream.write_all(&len_bytes).await?;
172 stream.write_all(&response_bytes).await?;
174 stream.flush().await?;
175 }
176 }
177}
178
179pub struct EchoAgent;
181
182#[async_trait]
183impl AgentHandler for EchoAgent {
184 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
185 debug!(
186 "Echo agent: request headers for {}",
187 event.metadata.correlation_id
188 );
189
190 AgentResponse::default_allow()
192 .add_request_header(HeaderOp::Set {
193 name: "X-Echo-Agent".to_string(),
194 value: event.metadata.correlation_id.clone(),
195 })
196 .with_audit(AuditMetadata {
197 tags: vec!["echo".to_string()],
198 ..Default::default()
199 })
200 }
201}
202
203pub struct DenylistAgent {
205 blocked_paths: Vec<String>,
206 blocked_ips: Vec<String>,
207}
208
209impl DenylistAgent {
210 pub fn new(blocked_paths: Vec<String>, blocked_ips: Vec<String>) -> Self {
211 Self {
212 blocked_paths,
213 blocked_ips,
214 }
215 }
216}
217
218#[async_trait]
219impl AgentHandler for DenylistAgent {
220 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
221 for blocked_path in &self.blocked_paths {
223 if event.uri.starts_with(blocked_path) {
224 return AgentResponse::block(403, Some("Forbidden path".to_string())).with_audit(
225 AuditMetadata {
226 tags: vec!["denylist".to_string(), "blocked_path".to_string()],
227 reason_codes: vec!["PATH_BLOCKED".to_string()],
228 ..Default::default()
229 },
230 );
231 }
232 }
233
234 if self.blocked_ips.contains(&event.metadata.client_ip) {
236 return AgentResponse::block(403, Some("Forbidden IP".to_string())).with_audit(
237 AuditMetadata {
238 tags: vec!["denylist".to_string(), "blocked_ip".to_string()],
239 reason_codes: vec!["IP_BLOCKED".to_string()],
240 ..Default::default()
241 },
242 );
243 }
244
245 AgentResponse::default_allow()
246 }
247}