sentinel_agent_protocol/
server.rs

1//! Agent server for implementing external agents.
2//!
3//! Supports two transport mechanisms:
4//! - Unix domain sockets (length-prefixed JSON)
5//! - gRPC (Protocol Buffers over HTTP/2)
6
7use async_trait::async_trait;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{UnixListener, UnixStream};
12use tokio_stream::StreamExt;
13use tonic::{Request, Response, Status, Streaming};
14use tracing::{debug, error, info};
15
16use crate::errors::AgentProtocolError;
17use crate::grpc::{self, agent_processor_server::AgentProcessor, agent_processor_server::AgentProcessorServer};
18use crate::protocol::{
19    AgentRequest, AgentResponse, AuditMetadata, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
20    RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent, ResponseHeadersEvent,
21    MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
22};
23
24/// Agent server for testing and reference implementations
25pub struct AgentServer {
26    /// Agent ID
27    id: String,
28    /// Unix socket path
29    socket_path: std::path::PathBuf,
30    /// Request handler
31    handler: Arc<dyn AgentHandler>,
32}
33
34/// Trait for implementing agent logic
35#[async_trait]
36pub trait AgentHandler: Send + Sync {
37    /// Handle a request headers event
38    async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
39        AgentResponse::default_allow()
40    }
41
42    /// Handle a request body chunk event
43    async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
44        AgentResponse::default_allow()
45    }
46
47    /// Handle a response headers event
48    async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
49        AgentResponse::default_allow()
50    }
51
52    /// Handle a response body chunk event
53    async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
54        AgentResponse::default_allow()
55    }
56
57    /// Handle a request complete event
58    async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
59        AgentResponse::default_allow()
60    }
61}
62
63impl AgentServer {
64    /// Create a new agent server
65    pub fn new(
66        id: impl Into<String>,
67        socket_path: impl Into<std::path::PathBuf>,
68        handler: Box<dyn AgentHandler>,
69    ) -> Self {
70        Self {
71            id: id.into(),
72            socket_path: socket_path.into(),
73            handler: Arc::from(handler),
74        }
75    }
76
77    /// Start the agent server
78    pub async fn run(&self) -> Result<(), AgentProtocolError> {
79        // Remove existing socket file if it exists
80        if self.socket_path.exists() {
81            std::fs::remove_file(&self.socket_path)?;
82        }
83
84        // Create Unix socket listener
85        let listener = UnixListener::bind(&self.socket_path)?;
86
87        info!(
88            "Agent server '{}' listening on {:?}",
89            self.id, self.socket_path
90        );
91
92        loop {
93            match listener.accept().await {
94                Ok((stream, _addr)) => {
95                    let handler = Arc::clone(&self.handler);
96                    tokio::spawn(async move {
97                        if let Err(e) = Self::handle_connection(stream, handler.as_ref()).await {
98                            error!("Error handling agent connection: {}", e);
99                        }
100                    });
101                }
102                Err(e) => {
103                    error!("Failed to accept connection: {}", e);
104                }
105            }
106        }
107    }
108
109    /// Handle a single connection
110    async fn handle_connection(
111        mut stream: UnixStream,
112        handler: &dyn AgentHandler,
113    ) -> Result<(), AgentProtocolError> {
114        loop {
115            // Read message length
116            let mut len_bytes = [0u8; 4];
117            match stream.read_exact(&mut len_bytes).await {
118                Ok(_) => {}
119                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
120                    // Client disconnected
121                    return Ok(());
122                }
123                Err(e) => return Err(e.into()),
124            }
125
126            let message_len = u32::from_be_bytes(len_bytes) as usize;
127
128            // Check message size
129            if message_len > MAX_MESSAGE_SIZE {
130                return Err(AgentProtocolError::MessageTooLarge {
131                    size: message_len,
132                    max: MAX_MESSAGE_SIZE,
133                });
134            }
135
136            // Read message data
137            let mut buffer = vec![0u8; message_len];
138            stream.read_exact(&mut buffer).await?;
139
140            // Parse request
141            let request: AgentRequest = serde_json::from_slice(&buffer)
142                .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
143
144            // Handle request based on event type
145            let response = match request.event_type {
146                EventType::RequestHeaders => {
147                    let event: RequestHeadersEvent = serde_json::from_value(request.payload)
148                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
149                    handler.on_request_headers(event).await
150                }
151                EventType::RequestBodyChunk => {
152                    let event: RequestBodyChunkEvent = serde_json::from_value(request.payload)
153                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
154                    handler.on_request_body_chunk(event).await
155                }
156                EventType::ResponseHeaders => {
157                    let event: ResponseHeadersEvent = serde_json::from_value(request.payload)
158                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
159                    handler.on_response_headers(event).await
160                }
161                EventType::ResponseBodyChunk => {
162                    let event: ResponseBodyChunkEvent = serde_json::from_value(request.payload)
163                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
164                    handler.on_response_body_chunk(event).await
165                }
166                EventType::RequestComplete => {
167                    let event: RequestCompleteEvent = serde_json::from_value(request.payload)
168                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
169                    handler.on_request_complete(event).await
170                }
171            };
172
173            // Send response
174            let response_bytes = serde_json::to_vec(&response)
175                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
176
177            // Write message length
178            let len_bytes = (response_bytes.len() as u32).to_be_bytes();
179            stream.write_all(&len_bytes).await?;
180            // Write message data
181            stream.write_all(&response_bytes).await?;
182            stream.flush().await?;
183        }
184    }
185}
186
187/// Reference implementation: Echo agent (for testing)
188pub struct EchoAgent;
189
190#[async_trait]
191impl AgentHandler for EchoAgent {
192    async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
193        debug!(
194            "Echo agent: request headers for {}",
195            event.metadata.correlation_id
196        );
197
198        // Echo back correlation ID as a header
199        AgentResponse::default_allow()
200            .add_request_header(HeaderOp::Set {
201                name: "X-Echo-Agent".to_string(),
202                value: event.metadata.correlation_id.clone(),
203            })
204            .with_audit(AuditMetadata {
205                tags: vec!["echo".to_string()],
206                ..Default::default()
207            })
208    }
209}
210
211/// Reference implementation: Denylist agent
212pub struct DenylistAgent {
213    blocked_paths: Vec<String>,
214    blocked_ips: Vec<String>,
215}
216
217impl DenylistAgent {
218    pub fn new(blocked_paths: Vec<String>, blocked_ips: Vec<String>) -> Self {
219        Self {
220            blocked_paths,
221            blocked_ips,
222        }
223    }
224}
225
226#[async_trait]
227impl AgentHandler for DenylistAgent {
228    async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
229        // Check if path is blocked
230        for blocked_path in &self.blocked_paths {
231            if event.uri.starts_with(blocked_path) {
232                return AgentResponse::block(403, Some("Forbidden path".to_string())).with_audit(
233                    AuditMetadata {
234                        tags: vec!["denylist".to_string(), "blocked_path".to_string()],
235                        reason_codes: vec!["PATH_BLOCKED".to_string()],
236                        ..Default::default()
237                    },
238                );
239            }
240        }
241
242        // Check if IP is blocked
243        if self.blocked_ips.contains(&event.metadata.client_ip) {
244            return AgentResponse::block(403, Some("Forbidden IP".to_string())).with_audit(
245                AuditMetadata {
246                    tags: vec!["denylist".to_string(), "blocked_ip".to_string()],
247                    reason_codes: vec!["IP_BLOCKED".to_string()],
248                    ..Default::default()
249                },
250            );
251        }
252
253        AgentResponse::default_allow()
254    }
255}
256
257// ============================================================================
258// gRPC Server Implementation
259// ============================================================================
260
261/// gRPC agent server for implementing external agents
262pub struct GrpcAgentServer {
263    /// Agent ID
264    id: String,
265    /// Request handler
266    handler: Arc<dyn AgentHandler>,
267}
268
269impl GrpcAgentServer {
270    /// Create a new gRPC agent server
271    pub fn new(id: impl Into<String>, handler: Box<dyn AgentHandler>) -> Self {
272        Self {
273            id: id.into(),
274            handler: Arc::from(handler),
275        }
276    }
277
278    /// Get the tonic service for this agent
279    pub fn into_service(self) -> AgentProcessorServer<GrpcAgentHandler> {
280        AgentProcessorServer::new(GrpcAgentHandler {
281            id: self.id,
282            handler: self.handler,
283        })
284    }
285
286    /// Start the gRPC server on the given address
287    pub async fn run(self, addr: SocketAddr) -> Result<(), AgentProtocolError> {
288        info!("gRPC agent server '{}' listening on {}", self.id, addr);
289
290        tonic::transport::Server::builder()
291            .add_service(self.into_service())
292            .serve(addr)
293            .await
294            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("gRPC server error: {}", e)))
295    }
296}
297
298/// Internal handler that implements the gRPC AgentProcessor trait
299pub struct GrpcAgentHandler {
300    id: String,
301    handler: Arc<dyn AgentHandler>,
302}
303
304#[tonic::async_trait]
305impl AgentProcessor for GrpcAgentHandler {
306    async fn process_event(
307        &self,
308        request: Request<grpc::AgentRequest>,
309    ) -> Result<Response<grpc::AgentResponse>, Status> {
310        let grpc_request = request.into_inner();
311
312        // Convert gRPC event to internal event and dispatch
313        let response = match grpc_request.event {
314            Some(grpc::agent_request::Event::RequestHeaders(e)) => {
315                let event = Self::convert_request_headers_from_grpc(e);
316                self.handler.on_request_headers(event).await
317            }
318            Some(grpc::agent_request::Event::RequestBodyChunk(e)) => {
319                let event = Self::convert_request_body_chunk_from_grpc(e);
320                self.handler.on_request_body_chunk(event).await
321            }
322            Some(grpc::agent_request::Event::ResponseHeaders(e)) => {
323                let event = Self::convert_response_headers_from_grpc(e);
324                self.handler.on_response_headers(event).await
325            }
326            Some(grpc::agent_request::Event::ResponseBodyChunk(e)) => {
327                let event = Self::convert_response_body_chunk_from_grpc(e);
328                self.handler.on_response_body_chunk(event).await
329            }
330            Some(grpc::agent_request::Event::RequestComplete(e)) => {
331                let event = Self::convert_request_complete_from_grpc(e);
332                self.handler.on_request_complete(event).await
333            }
334            None => {
335                return Err(Status::invalid_argument("Missing event in request"));
336            }
337        };
338
339        // Convert internal response to gRPC response
340        let grpc_response = Self::convert_response_to_grpc(response);
341        Ok(Response::new(grpc_response))
342    }
343
344    async fn process_event_stream(
345        &self,
346        request: Request<Streaming<grpc::AgentRequest>>,
347    ) -> Result<Response<grpc::AgentResponse>, Status> {
348        let mut stream = request.into_inner();
349
350        // Process all events in the stream, returning the final response
351        let mut final_response = AgentResponse::default_allow();
352
353        while let Some(result) = stream.next().await {
354            let grpc_request = result.map_err(|e| Status::internal(format!("Stream error: {}", e)))?;
355
356            let response = match grpc_request.event {
357                Some(grpc::agent_request::Event::RequestHeaders(e)) => {
358                    let event = Self::convert_request_headers_from_grpc(e);
359                    self.handler.on_request_headers(event).await
360                }
361                Some(grpc::agent_request::Event::RequestBodyChunk(e)) => {
362                    let event = Self::convert_request_body_chunk_from_grpc(e);
363                    self.handler.on_request_body_chunk(event).await
364                }
365                Some(grpc::agent_request::Event::ResponseHeaders(e)) => {
366                    let event = Self::convert_response_headers_from_grpc(e);
367                    self.handler.on_response_headers(event).await
368                }
369                Some(grpc::agent_request::Event::ResponseBodyChunk(e)) => {
370                    let event = Self::convert_response_body_chunk_from_grpc(e);
371                    self.handler.on_response_body_chunk(event).await
372                }
373                Some(grpc::agent_request::Event::RequestComplete(e)) => {
374                    let event = Self::convert_request_complete_from_grpc(e);
375                    self.handler.on_request_complete(event).await
376                }
377                None => continue,
378            };
379
380            // If any event results in a block/redirect, that becomes the final response
381            if !matches!(response.decision, Decision::Allow) {
382                final_response = response;
383                break;
384            }
385            final_response = response;
386        }
387
388        let grpc_response = Self::convert_response_to_grpc(final_response);
389        Ok(Response::new(grpc_response))
390    }
391}
392
393impl GrpcAgentHandler {
394    /// Convert gRPC RequestHeadersEvent to internal format
395    fn convert_request_headers_from_grpc(e: grpc::RequestHeadersEvent) -> RequestHeadersEvent {
396        RequestHeadersEvent {
397            metadata: Self::convert_metadata_from_grpc(e.metadata),
398            method: e.method,
399            uri: e.uri,
400            headers: e.headers.into_iter().map(|(k, v)| (k, v.values)).collect(),
401        }
402    }
403
404    /// Convert gRPC RequestBodyChunkEvent to internal format
405    fn convert_request_body_chunk_from_grpc(e: grpc::RequestBodyChunkEvent) -> RequestBodyChunkEvent {
406        RequestBodyChunkEvent {
407            correlation_id: e.correlation_id,
408            data: String::from_utf8_lossy(&e.data).to_string(),
409            is_last: e.is_last,
410            total_size: e.total_size.map(|s| s as usize),
411        }
412    }
413
414    /// Convert gRPC ResponseHeadersEvent to internal format
415    fn convert_response_headers_from_grpc(e: grpc::ResponseHeadersEvent) -> ResponseHeadersEvent {
416        ResponseHeadersEvent {
417            correlation_id: e.correlation_id,
418            status: e.status as u16,
419            headers: e.headers.into_iter().map(|(k, v)| (k, v.values)).collect(),
420        }
421    }
422
423    /// Convert gRPC ResponseBodyChunkEvent to internal format
424    fn convert_response_body_chunk_from_grpc(e: grpc::ResponseBodyChunkEvent) -> ResponseBodyChunkEvent {
425        ResponseBodyChunkEvent {
426            correlation_id: e.correlation_id,
427            data: String::from_utf8_lossy(&e.data).to_string(),
428            is_last: e.is_last,
429            total_size: e.total_size.map(|s| s as usize),
430        }
431    }
432
433    /// Convert gRPC RequestCompleteEvent to internal format
434    fn convert_request_complete_from_grpc(e: grpc::RequestCompleteEvent) -> RequestCompleteEvent {
435        RequestCompleteEvent {
436            correlation_id: e.correlation_id,
437            status: e.status as u16,
438            duration_ms: e.duration_ms,
439            request_body_size: e.request_body_size as usize,
440            response_body_size: e.response_body_size as usize,
441            upstream_attempts: e.upstream_attempts,
442            error: e.error,
443        }
444    }
445
446    /// Convert gRPC metadata to internal format
447    fn convert_metadata_from_grpc(metadata: Option<grpc::RequestMetadata>) -> RequestMetadata {
448        match metadata {
449            Some(m) => RequestMetadata {
450                correlation_id: m.correlation_id,
451                request_id: m.request_id,
452                client_ip: m.client_ip,
453                client_port: m.client_port as u16,
454                server_name: m.server_name,
455                protocol: m.protocol,
456                tls_version: m.tls_version,
457                tls_cipher: m.tls_cipher,
458                route_id: m.route_id,
459                upstream_id: m.upstream_id,
460                timestamp: m.timestamp,
461            },
462            None => RequestMetadata {
463                correlation_id: String::new(),
464                request_id: String::new(),
465                client_ip: String::new(),
466                client_port: 0,
467                server_name: None,
468                protocol: String::new(),
469                tls_version: None,
470                tls_cipher: None,
471                route_id: None,
472                upstream_id: None,
473                timestamp: String::new(),
474            },
475        }
476    }
477
478    /// Convert internal response to gRPC format
479    fn convert_response_to_grpc(response: AgentResponse) -> grpc::AgentResponse {
480        let decision = match response.decision {
481            Decision::Allow => Some(grpc::agent_response::Decision::Allow(grpc::AllowDecision {})),
482            Decision::Block { status, body, headers } => {
483                Some(grpc::agent_response::Decision::Block(grpc::BlockDecision {
484                    status: status as u32,
485                    body,
486                    headers: headers.unwrap_or_default(),
487                }))
488            }
489            Decision::Redirect { url, status } => {
490                Some(grpc::agent_response::Decision::Redirect(grpc::RedirectDecision {
491                    url,
492                    status: status as u32,
493                }))
494            }
495            Decision::Challenge { challenge_type, params } => {
496                Some(grpc::agent_response::Decision::Challenge(grpc::ChallengeDecision {
497                    challenge_type,
498                    params,
499                }))
500            }
501        };
502
503        let request_headers: Vec<grpc::HeaderOp> = response.request_headers
504            .into_iter()
505            .map(Self::convert_header_op_to_grpc)
506            .collect();
507
508        let response_headers: Vec<grpc::HeaderOp> = response.response_headers
509            .into_iter()
510            .map(Self::convert_header_op_to_grpc)
511            .collect();
512
513        let audit = Some(grpc::AuditMetadata {
514            tags: response.audit.tags,
515            rule_ids: response.audit.rule_ids,
516            confidence: response.audit.confidence,
517            reason_codes: response.audit.reason_codes,
518            custom: response.audit.custom.into_iter().map(|(k, v)| {
519                (k, v.to_string())
520            }).collect(),
521        });
522
523        grpc::AgentResponse {
524            version: PROTOCOL_VERSION,
525            decision,
526            request_headers,
527            response_headers,
528            routing_metadata: response.routing_metadata,
529            audit,
530        }
531    }
532
533    /// Convert internal header operation to gRPC format
534    fn convert_header_op_to_grpc(op: HeaderOp) -> grpc::HeaderOp {
535        let operation = match op {
536            HeaderOp::Set { name, value } => {
537                Some(grpc::header_op::Operation::Set(grpc::SetHeader { name, value }))
538            }
539            HeaderOp::Add { name, value } => {
540                Some(grpc::header_op::Operation::Add(grpc::AddHeader { name, value }))
541            }
542            HeaderOp::Remove { name } => {
543                Some(grpc::header_op::Operation::Remove(grpc::RemoveHeader { name }))
544            }
545        };
546        grpc::HeaderOp { operation }
547    }
548}