1use async_trait::async_trait;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{UnixListener, UnixStream};
12use tokio_stream::StreamExt;
13use tonic::{Request, Response, Status, Streaming};
14use tracing::{debug, error, info};
15
16use crate::errors::AgentProtocolError;
17use crate::grpc::{self, agent_processor_server::AgentProcessor, agent_processor_server::AgentProcessorServer};
18use crate::protocol::{
19 AgentRequest, AgentResponse, AuditMetadata, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
20 RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent, ResponseHeadersEvent,
21 MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
22};
23
24pub struct AgentServer {
26 id: String,
28 socket_path: std::path::PathBuf,
30 handler: Arc<dyn AgentHandler>,
32}
33
34#[async_trait]
36pub trait AgentHandler: Send + Sync {
37 async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
39 AgentResponse::default_allow()
40 }
41
42 async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
44 AgentResponse::default_allow()
45 }
46
47 async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
49 AgentResponse::default_allow()
50 }
51
52 async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
54 AgentResponse::default_allow()
55 }
56
57 async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
59 AgentResponse::default_allow()
60 }
61}
62
63impl AgentServer {
64 pub fn new(
66 id: impl Into<String>,
67 socket_path: impl Into<std::path::PathBuf>,
68 handler: Box<dyn AgentHandler>,
69 ) -> Self {
70 Self {
71 id: id.into(),
72 socket_path: socket_path.into(),
73 handler: Arc::from(handler),
74 }
75 }
76
77 pub async fn run(&self) -> Result<(), AgentProtocolError> {
79 if self.socket_path.exists() {
81 std::fs::remove_file(&self.socket_path)?;
82 }
83
84 let listener = UnixListener::bind(&self.socket_path)?;
86
87 info!(
88 "Agent server '{}' listening on {:?}",
89 self.id, self.socket_path
90 );
91
92 loop {
93 match listener.accept().await {
94 Ok((stream, _addr)) => {
95 let handler = Arc::clone(&self.handler);
96 tokio::spawn(async move {
97 if let Err(e) = Self::handle_connection(stream, handler.as_ref()).await {
98 error!("Error handling agent connection: {}", e);
99 }
100 });
101 }
102 Err(e) => {
103 error!("Failed to accept connection: {}", e);
104 }
105 }
106 }
107 }
108
109 async fn handle_connection(
111 mut stream: UnixStream,
112 handler: &dyn AgentHandler,
113 ) -> Result<(), AgentProtocolError> {
114 loop {
115 let mut len_bytes = [0u8; 4];
117 match stream.read_exact(&mut len_bytes).await {
118 Ok(_) => {}
119 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
120 return Ok(());
122 }
123 Err(e) => return Err(e.into()),
124 }
125
126 let message_len = u32::from_be_bytes(len_bytes) as usize;
127
128 if message_len > MAX_MESSAGE_SIZE {
130 return Err(AgentProtocolError::MessageTooLarge {
131 size: message_len,
132 max: MAX_MESSAGE_SIZE,
133 });
134 }
135
136 let mut buffer = vec![0u8; message_len];
138 stream.read_exact(&mut buffer).await?;
139
140 let request: AgentRequest = serde_json::from_slice(&buffer)
142 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
143
144 let response = match request.event_type {
146 EventType::RequestHeaders => {
147 let event: RequestHeadersEvent = serde_json::from_value(request.payload)
148 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
149 handler.on_request_headers(event).await
150 }
151 EventType::RequestBodyChunk => {
152 let event: RequestBodyChunkEvent = serde_json::from_value(request.payload)
153 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
154 handler.on_request_body_chunk(event).await
155 }
156 EventType::ResponseHeaders => {
157 let event: ResponseHeadersEvent = serde_json::from_value(request.payload)
158 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
159 handler.on_response_headers(event).await
160 }
161 EventType::ResponseBodyChunk => {
162 let event: ResponseBodyChunkEvent = serde_json::from_value(request.payload)
163 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
164 handler.on_response_body_chunk(event).await
165 }
166 EventType::RequestComplete => {
167 let event: RequestCompleteEvent = serde_json::from_value(request.payload)
168 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
169 handler.on_request_complete(event).await
170 }
171 };
172
173 let response_bytes = serde_json::to_vec(&response)
175 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
176
177 let len_bytes = (response_bytes.len() as u32).to_be_bytes();
179 stream.write_all(&len_bytes).await?;
180 stream.write_all(&response_bytes).await?;
182 stream.flush().await?;
183 }
184 }
185}
186
/// Demonstration agent: allows every request and echoes its correlation id
/// back as a request header.
#[derive(Debug, Clone, Copy, Default)]
pub struct EchoAgent;
189
190#[async_trait]
191impl AgentHandler for EchoAgent {
192 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
193 debug!(
194 "Echo agent: request headers for {}",
195 event.metadata.correlation_id
196 );
197
198 AgentResponse::default_allow()
200 .add_request_header(HeaderOp::Set {
201 name: "X-Echo-Agent".to_string(),
202 value: event.metadata.correlation_id.clone(),
203 })
204 .with_audit(AuditMetadata {
205 tags: vec!["echo".to_string()],
206 ..Default::default()
207 })
208 }
209}
210
/// Agent that blocks requests by URI path prefix or by exact client IP.
///
/// `Default` yields an empty denylist, which allows every request.
#[derive(Debug, Clone, Default)]
pub struct DenylistAgent {
    /// URI prefixes that cause a 403; compared with `str::starts_with`.
    blocked_paths: Vec<String>,
    /// Client IP strings that cause a 403; compared for exact equality.
    blocked_ips: Vec<String>,
}
216
217impl DenylistAgent {
218 pub fn new(blocked_paths: Vec<String>, blocked_ips: Vec<String>) -> Self {
219 Self {
220 blocked_paths,
221 blocked_ips,
222 }
223 }
224}
225
226#[async_trait]
227impl AgentHandler for DenylistAgent {
228 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
229 for blocked_path in &self.blocked_paths {
231 if event.uri.starts_with(blocked_path) {
232 return AgentResponse::block(403, Some("Forbidden path".to_string())).with_audit(
233 AuditMetadata {
234 tags: vec!["denylist".to_string(), "blocked_path".to_string()],
235 reason_codes: vec!["PATH_BLOCKED".to_string()],
236 ..Default::default()
237 },
238 );
239 }
240 }
241
242 if self.blocked_ips.contains(&event.metadata.client_ip) {
244 return AgentResponse::block(403, Some("Forbidden IP".to_string())).with_audit(
245 AuditMetadata {
246 tags: vec!["denylist".to_string(), "blocked_ip".to_string()],
247 reason_codes: vec!["IP_BLOCKED".to_string()],
248 ..Default::default()
249 },
250 );
251 }
252
253 AgentResponse::default_allow()
254 }
255}
256
257pub struct GrpcAgentServer {
263 id: String,
265 handler: Arc<dyn AgentHandler>,
267}
268
269impl GrpcAgentServer {
270 pub fn new(id: impl Into<String>, handler: Box<dyn AgentHandler>) -> Self {
272 Self {
273 id: id.into(),
274 handler: Arc::from(handler),
275 }
276 }
277
278 pub fn into_service(self) -> AgentProcessorServer<GrpcAgentHandler> {
280 AgentProcessorServer::new(GrpcAgentHandler {
281 id: self.id,
282 handler: self.handler,
283 })
284 }
285
286 pub async fn run(self, addr: SocketAddr) -> Result<(), AgentProtocolError> {
288 info!("gRPC agent server '{}' listening on {}", self.id, addr);
289
290 tonic::transport::Server::builder()
291 .add_service(self.into_service())
292 .serve(addr)
293 .await
294 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("gRPC server error: {}", e)))
295 }
296}
297
298pub struct GrpcAgentHandler {
300 id: String,
301 handler: Arc<dyn AgentHandler>,
302}
303
304#[tonic::async_trait]
305impl AgentProcessor for GrpcAgentHandler {
306 async fn process_event(
307 &self,
308 request: Request<grpc::AgentRequest>,
309 ) -> Result<Response<grpc::AgentResponse>, Status> {
310 let grpc_request = request.into_inner();
311
312 let response = match grpc_request.event {
314 Some(grpc::agent_request::Event::RequestHeaders(e)) => {
315 let event = Self::convert_request_headers_from_grpc(e);
316 self.handler.on_request_headers(event).await
317 }
318 Some(grpc::agent_request::Event::RequestBodyChunk(e)) => {
319 let event = Self::convert_request_body_chunk_from_grpc(e);
320 self.handler.on_request_body_chunk(event).await
321 }
322 Some(grpc::agent_request::Event::ResponseHeaders(e)) => {
323 let event = Self::convert_response_headers_from_grpc(e);
324 self.handler.on_response_headers(event).await
325 }
326 Some(grpc::agent_request::Event::ResponseBodyChunk(e)) => {
327 let event = Self::convert_response_body_chunk_from_grpc(e);
328 self.handler.on_response_body_chunk(event).await
329 }
330 Some(grpc::agent_request::Event::RequestComplete(e)) => {
331 let event = Self::convert_request_complete_from_grpc(e);
332 self.handler.on_request_complete(event).await
333 }
334 None => {
335 return Err(Status::invalid_argument("Missing event in request"));
336 }
337 };
338
339 let grpc_response = Self::convert_response_to_grpc(response);
341 Ok(Response::new(grpc_response))
342 }
343
344 async fn process_event_stream(
345 &self,
346 request: Request<Streaming<grpc::AgentRequest>>,
347 ) -> Result<Response<grpc::AgentResponse>, Status> {
348 let mut stream = request.into_inner();
349
350 let mut final_response = AgentResponse::default_allow();
352
353 while let Some(result) = stream.next().await {
354 let grpc_request = result.map_err(|e| Status::internal(format!("Stream error: {}", e)))?;
355
356 let response = match grpc_request.event {
357 Some(grpc::agent_request::Event::RequestHeaders(e)) => {
358 let event = Self::convert_request_headers_from_grpc(e);
359 self.handler.on_request_headers(event).await
360 }
361 Some(grpc::agent_request::Event::RequestBodyChunk(e)) => {
362 let event = Self::convert_request_body_chunk_from_grpc(e);
363 self.handler.on_request_body_chunk(event).await
364 }
365 Some(grpc::agent_request::Event::ResponseHeaders(e)) => {
366 let event = Self::convert_response_headers_from_grpc(e);
367 self.handler.on_response_headers(event).await
368 }
369 Some(grpc::agent_request::Event::ResponseBodyChunk(e)) => {
370 let event = Self::convert_response_body_chunk_from_grpc(e);
371 self.handler.on_response_body_chunk(event).await
372 }
373 Some(grpc::agent_request::Event::RequestComplete(e)) => {
374 let event = Self::convert_request_complete_from_grpc(e);
375 self.handler.on_request_complete(event).await
376 }
377 None => continue,
378 };
379
380 if !matches!(response.decision, Decision::Allow) {
382 final_response = response;
383 break;
384 }
385 final_response = response;
386 }
387
388 let grpc_response = Self::convert_response_to_grpc(final_response);
389 Ok(Response::new(grpc_response))
390 }
391}
392
393impl GrpcAgentHandler {
394 fn convert_request_headers_from_grpc(e: grpc::RequestHeadersEvent) -> RequestHeadersEvent {
396 RequestHeadersEvent {
397 metadata: Self::convert_metadata_from_grpc(e.metadata),
398 method: e.method,
399 uri: e.uri,
400 headers: e.headers.into_iter().map(|(k, v)| (k, v.values)).collect(),
401 }
402 }
403
404 fn convert_request_body_chunk_from_grpc(e: grpc::RequestBodyChunkEvent) -> RequestBodyChunkEvent {
406 RequestBodyChunkEvent {
407 correlation_id: e.correlation_id,
408 data: String::from_utf8_lossy(&e.data).to_string(),
409 is_last: e.is_last,
410 total_size: e.total_size.map(|s| s as usize),
411 }
412 }
413
414 fn convert_response_headers_from_grpc(e: grpc::ResponseHeadersEvent) -> ResponseHeadersEvent {
416 ResponseHeadersEvent {
417 correlation_id: e.correlation_id,
418 status: e.status as u16,
419 headers: e.headers.into_iter().map(|(k, v)| (k, v.values)).collect(),
420 }
421 }
422
423 fn convert_response_body_chunk_from_grpc(e: grpc::ResponseBodyChunkEvent) -> ResponseBodyChunkEvent {
425 ResponseBodyChunkEvent {
426 correlation_id: e.correlation_id,
427 data: String::from_utf8_lossy(&e.data).to_string(),
428 is_last: e.is_last,
429 total_size: e.total_size.map(|s| s as usize),
430 }
431 }
432
433 fn convert_request_complete_from_grpc(e: grpc::RequestCompleteEvent) -> RequestCompleteEvent {
435 RequestCompleteEvent {
436 correlation_id: e.correlation_id,
437 status: e.status as u16,
438 duration_ms: e.duration_ms,
439 request_body_size: e.request_body_size as usize,
440 response_body_size: e.response_body_size as usize,
441 upstream_attempts: e.upstream_attempts,
442 error: e.error,
443 }
444 }
445
446 fn convert_metadata_from_grpc(metadata: Option<grpc::RequestMetadata>) -> RequestMetadata {
448 match metadata {
449 Some(m) => RequestMetadata {
450 correlation_id: m.correlation_id,
451 request_id: m.request_id,
452 client_ip: m.client_ip,
453 client_port: m.client_port as u16,
454 server_name: m.server_name,
455 protocol: m.protocol,
456 tls_version: m.tls_version,
457 tls_cipher: m.tls_cipher,
458 route_id: m.route_id,
459 upstream_id: m.upstream_id,
460 timestamp: m.timestamp,
461 },
462 None => RequestMetadata {
463 correlation_id: String::new(),
464 request_id: String::new(),
465 client_ip: String::new(),
466 client_port: 0,
467 server_name: None,
468 protocol: String::new(),
469 tls_version: None,
470 tls_cipher: None,
471 route_id: None,
472 upstream_id: None,
473 timestamp: String::new(),
474 },
475 }
476 }
477
478 fn convert_response_to_grpc(response: AgentResponse) -> grpc::AgentResponse {
480 let decision = match response.decision {
481 Decision::Allow => Some(grpc::agent_response::Decision::Allow(grpc::AllowDecision {})),
482 Decision::Block { status, body, headers } => {
483 Some(grpc::agent_response::Decision::Block(grpc::BlockDecision {
484 status: status as u32,
485 body,
486 headers: headers.unwrap_or_default(),
487 }))
488 }
489 Decision::Redirect { url, status } => {
490 Some(grpc::agent_response::Decision::Redirect(grpc::RedirectDecision {
491 url,
492 status: status as u32,
493 }))
494 }
495 Decision::Challenge { challenge_type, params } => {
496 Some(grpc::agent_response::Decision::Challenge(grpc::ChallengeDecision {
497 challenge_type,
498 params,
499 }))
500 }
501 };
502
503 let request_headers: Vec<grpc::HeaderOp> = response.request_headers
504 .into_iter()
505 .map(Self::convert_header_op_to_grpc)
506 .collect();
507
508 let response_headers: Vec<grpc::HeaderOp> = response.response_headers
509 .into_iter()
510 .map(Self::convert_header_op_to_grpc)
511 .collect();
512
513 let audit = Some(grpc::AuditMetadata {
514 tags: response.audit.tags,
515 rule_ids: response.audit.rule_ids,
516 confidence: response.audit.confidence,
517 reason_codes: response.audit.reason_codes,
518 custom: response.audit.custom.into_iter().map(|(k, v)| {
519 (k, v.to_string())
520 }).collect(),
521 });
522
523 grpc::AgentResponse {
524 version: PROTOCOL_VERSION,
525 decision,
526 request_headers,
527 response_headers,
528 routing_metadata: response.routing_metadata,
529 audit,
530 }
531 }
532
533 fn convert_header_op_to_grpc(op: HeaderOp) -> grpc::HeaderOp {
535 let operation = match op {
536 HeaderOp::Set { name, value } => {
537 Some(grpc::header_op::Operation::Set(grpc::SetHeader { name, value }))
538 }
539 HeaderOp::Add { name, value } => {
540 Some(grpc::header_op::Operation::Add(grpc::AddHeader { name, value }))
541 }
542 HeaderOp::Remove { name } => {
543 Some(grpc::header_op::Operation::Remove(grpc::RemoveHeader { name }))
544 }
545 };
546 grpc::HeaderOp { operation }
547 }
548}