sentinel_agent_protocol/
client.rs

//! Agent client for communicating with external agents.
//!
//! Supports two transport mechanisms:
//! - Unix domain sockets (length-prefixed JSON)
//! - gRPC (Protocol Buffers over HTTP/2)
6
7use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12use tracing::{debug, error, trace};
13
14use crate::errors::AgentProtocolError;
15use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
16use crate::protocol::{
17    AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
18    RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
19    ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
20    MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
21};
22
23/// Agent client for communicating with external agents
24pub struct AgentClient {
25    /// Agent ID
26    id: String,
27    /// Connection to agent
28    connection: AgentConnection,
29    /// Timeout for agent calls
30    timeout: Duration,
31    /// Maximum retries
32    #[allow(dead_code)]
33    max_retries: u32,
34}
35
36/// Agent connection type
37enum AgentConnection {
38    UnixSocket(UnixStream),
39    Grpc(AgentProcessorClient<Channel>),
40}
41
42impl AgentClient {
43    /// Create a new Unix socket agent client
44    pub async fn unix_socket(
45        id: impl Into<String>,
46        path: impl AsRef<std::path::Path>,
47        timeout: Duration,
48    ) -> Result<Self, AgentProtocolError> {
49        let id = id.into();
50        let path = path.as_ref();
51
52        trace!(
53            agent_id = %id,
54            socket_path = %path.display(),
55            timeout_ms = timeout.as_millis() as u64,
56            "Connecting to agent via Unix socket"
57        );
58
59        let stream = UnixStream::connect(path).await.map_err(|e| {
60            error!(
61                agent_id = %id,
62                socket_path = %path.display(),
63                error = %e,
64                "Failed to connect to agent via Unix socket"
65            );
66            AgentProtocolError::ConnectionFailed(e.to_string())
67        })?;
68
69        debug!(
70            agent_id = %id,
71            socket_path = %path.display(),
72            "Connected to agent via Unix socket"
73        );
74
75        Ok(Self {
76            id,
77            connection: AgentConnection::UnixSocket(stream),
78            timeout,
79            max_retries: 3,
80        })
81    }
82
83    /// Create a new gRPC agent client
84    ///
85    /// # Arguments
86    /// * `id` - Agent identifier
87    /// * `address` - gRPC server address (e.g., "http://localhost:50051")
88    /// * `timeout` - Timeout for agent calls
89    pub async fn grpc(
90        id: impl Into<String>,
91        address: impl Into<String>,
92        timeout: Duration,
93    ) -> Result<Self, AgentProtocolError> {
94        let id = id.into();
95        let address = address.into();
96
97        trace!(
98            agent_id = %id,
99            address = %address,
100            timeout_ms = timeout.as_millis() as u64,
101            "Connecting to agent via gRPC"
102        );
103
104        let channel = Channel::from_shared(address.clone())
105            .map_err(|e| {
106                error!(
107                    agent_id = %id,
108                    address = %address,
109                    error = %e,
110                    "Invalid gRPC URI"
111                );
112                AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
113            })?
114            .timeout(timeout)
115            .connect()
116            .await
117            .map_err(|e| {
118                error!(
119                    agent_id = %id,
120                    address = %address,
121                    error = %e,
122                    "Failed to connect to agent via gRPC"
123                );
124                AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
125            })?;
126
127        let client = AgentProcessorClient::new(channel);
128
129        debug!(
130            agent_id = %id,
131            address = %address,
132            "Connected to agent via gRPC"
133        );
134
135        Ok(Self {
136            id,
137            connection: AgentConnection::Grpc(client),
138            timeout,
139            max_retries: 3,
140        })
141    }
142
143    /// Get the agent ID
144    #[allow(dead_code)]
145    pub fn id(&self) -> &str {
146        &self.id
147    }
148
149    /// Send an event to the agent and get a response
150    pub async fn send_event(
151        &mut self,
152        event_type: EventType,
153        payload: impl Serialize,
154    ) -> Result<AgentResponse, AgentProtocolError> {
155        match &mut self.connection {
156            AgentConnection::UnixSocket(_) => {
157                self.send_event_unix_socket(event_type, payload).await
158            }
159            AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
160        }
161    }
162
163    /// Send event via Unix socket (length-prefixed JSON)
164    async fn send_event_unix_socket(
165        &mut self,
166        event_type: EventType,
167        payload: impl Serialize,
168    ) -> Result<AgentResponse, AgentProtocolError> {
169        let request = AgentRequest {
170            version: PROTOCOL_VERSION,
171            event_type,
172            payload: serde_json::to_value(payload)
173                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
174        };
175
176        // Serialize request
177        let request_bytes = serde_json::to_vec(&request)
178            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
179
180        // Check message size
181        if request_bytes.len() > MAX_MESSAGE_SIZE {
182            return Err(AgentProtocolError::MessageTooLarge {
183                size: request_bytes.len(),
184                max: MAX_MESSAGE_SIZE,
185            });
186        }
187
188        // Send with timeout
189        let response = tokio::time::timeout(self.timeout, async {
190            self.send_raw_unix(&request_bytes).await?;
191            self.receive_raw_unix().await
192        })
193        .await
194        .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
195
196        // Parse response
197        let agent_response: AgentResponse = serde_json::from_slice(&response)
198            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
199
200        // Verify protocol version
201        if agent_response.version != PROTOCOL_VERSION {
202            return Err(AgentProtocolError::VersionMismatch {
203                expected: PROTOCOL_VERSION,
204                actual: agent_response.version,
205            });
206        }
207
208        Ok(agent_response)
209    }
210
211    /// Send event via gRPC
212    async fn send_event_grpc(
213        &mut self,
214        event_type: EventType,
215        payload: impl Serialize,
216    ) -> Result<AgentResponse, AgentProtocolError> {
217        // Build request first (doesn't need mutable borrow)
218        let grpc_request = Self::build_grpc_request(event_type, payload)?;
219
220        let AgentConnection::Grpc(client) = &mut self.connection else {
221            unreachable!()
222        };
223
224        // Send with timeout
225        let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
226            .await
227            .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
228            .map_err(|e| {
229                AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
230            })?;
231
232        // Convert gRPC response to internal format
233        Self::convert_grpc_response(response.into_inner())
234    }
235
236    /// Build a gRPC request from internal types
237    fn build_grpc_request(
238        event_type: EventType,
239        payload: impl Serialize,
240    ) -> Result<grpc::AgentRequest, AgentProtocolError> {
241        let payload_json = serde_json::to_value(&payload)
242            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
243
244        let grpc_event_type = match event_type {
245            EventType::RequestHeaders => grpc::EventType::RequestHeaders,
246            EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
247            EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
248            EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
249            EventType::RequestComplete => grpc::EventType::RequestComplete,
250            EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
251        };
252
253        let event = match event_type {
254            EventType::RequestHeaders => {
255                let event: RequestHeadersEvent = serde_json::from_value(payload_json)
256                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
257                grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
258                    metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
259                    method: event.method,
260                    uri: event.uri,
261                    headers: event
262                        .headers
263                        .into_iter()
264                        .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
265                        .collect(),
266                })
267            }
268            EventType::RequestBodyChunk => {
269                let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
270                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
271                grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
272                    correlation_id: event.correlation_id,
273                    data: event.data.into_bytes(),
274                    is_last: event.is_last,
275                    total_size: event.total_size.map(|s| s as u64),
276                    chunk_index: event.chunk_index,
277                    bytes_received: event.bytes_received as u64,
278                })
279            }
280            EventType::ResponseHeaders => {
281                let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
282                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
283                grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
284                    correlation_id: event.correlation_id,
285                    status: event.status as u32,
286                    headers: event
287                        .headers
288                        .into_iter()
289                        .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
290                        .collect(),
291                })
292            }
293            EventType::ResponseBodyChunk => {
294                let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
295                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
296                grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
297                    correlation_id: event.correlation_id,
298                    data: event.data.into_bytes(),
299                    is_last: event.is_last,
300                    total_size: event.total_size.map(|s| s as u64),
301                    chunk_index: event.chunk_index,
302                    bytes_sent: event.bytes_sent as u64,
303                })
304            }
305            EventType::RequestComplete => {
306                let event: RequestCompleteEvent = serde_json::from_value(payload_json)
307                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
308                grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
309                    correlation_id: event.correlation_id,
310                    status: event.status as u32,
311                    duration_ms: event.duration_ms,
312                    request_body_size: event.request_body_size as u64,
313                    response_body_size: event.response_body_size as u64,
314                    upstream_attempts: event.upstream_attempts,
315                    error: event.error,
316                })
317            }
318            EventType::WebSocketFrame => {
319                use base64::{engine::general_purpose::STANDARD, Engine as _};
320                let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
321                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
322                grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
323                    correlation_id: event.correlation_id,
324                    opcode: event.opcode,
325                    data: STANDARD.decode(&event.data).unwrap_or_default(),
326                    client_to_server: event.client_to_server,
327                    frame_index: event.frame_index,
328                    fin: event.fin,
329                    route_id: event.route_id,
330                    client_ip: event.client_ip,
331                })
332            }
333        };
334
335        Ok(grpc::AgentRequest {
336            version: PROTOCOL_VERSION,
337            event_type: grpc_event_type as i32,
338            event: Some(event),
339        })
340    }
341
342    /// Convert internal metadata to gRPC format
343    fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
344        grpc::RequestMetadata {
345            correlation_id: metadata.correlation_id.clone(),
346            request_id: metadata.request_id.clone(),
347            client_ip: metadata.client_ip.clone(),
348            client_port: metadata.client_port as u32,
349            server_name: metadata.server_name.clone(),
350            protocol: metadata.protocol.clone(),
351            tls_version: metadata.tls_version.clone(),
352            tls_cipher: metadata.tls_cipher.clone(),
353            route_id: metadata.route_id.clone(),
354            upstream_id: metadata.upstream_id.clone(),
355            timestamp: metadata.timestamp.clone(),
356        }
357    }
358
359    /// Convert gRPC response to internal format
360    fn convert_grpc_response(
361        response: grpc::AgentResponse,
362    ) -> Result<AgentResponse, AgentProtocolError> {
363        let decision = match response.decision {
364            Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
365            Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
366                status: b.status as u16,
367                body: b.body,
368                headers: if b.headers.is_empty() {
369                    None
370                } else {
371                    Some(b.headers)
372                },
373            },
374            Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
375                url: r.url,
376                status: r.status as u16,
377            },
378            Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
379                challenge_type: c.challenge_type,
380                params: c.params,
381            },
382            None => Decision::Allow, // Default to allow if no decision
383        };
384
385        let request_headers: Vec<HeaderOp> = response
386            .request_headers
387            .into_iter()
388            .filter_map(Self::convert_header_op_from_grpc)
389            .collect();
390
391        let response_headers: Vec<HeaderOp> = response
392            .response_headers
393            .into_iter()
394            .filter_map(Self::convert_header_op_from_grpc)
395            .collect();
396
397        let audit = response.audit.map(|a| AuditMetadata {
398            tags: a.tags,
399            rule_ids: a.rule_ids,
400            confidence: a.confidence,
401            reason_codes: a.reason_codes,
402            custom: a
403                .custom
404                .into_iter()
405                .map(|(k, v)| (k, serde_json::Value::String(v)))
406                .collect(),
407        });
408
409        // Convert body mutations
410        let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
411            data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
412            chunk_index: m.chunk_index,
413        });
414
415        let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
416            data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
417            chunk_index: m.chunk_index,
418        });
419
420        // Convert WebSocket decision
421        let websocket_decision = response
422            .websocket_decision
423            .map(|ws_decision| match ws_decision {
424                grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
425                    WebSocketDecision::Allow
426                }
427                grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
428                    WebSocketDecision::Drop
429                }
430                grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
431                    WebSocketDecision::Close {
432                        code: c.code as u16,
433                        reason: c.reason,
434                    }
435                }
436            });
437
438        Ok(AgentResponse {
439            version: response.version,
440            decision,
441            request_headers,
442            response_headers,
443            routing_metadata: response.routing_metadata,
444            audit: audit.unwrap_or_default(),
445            needs_more: response.needs_more,
446            request_body_mutation,
447            response_body_mutation,
448            websocket_decision,
449        })
450    }
451
452    /// Convert gRPC header operation to internal format
453    fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
454        match op.operation? {
455            grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
456                name: s.name,
457                value: s.value,
458            }),
459            grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
460                name: a.name,
461                value: a.value,
462            }),
463            grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
464        }
465    }
466
467    /// Send raw bytes to agent (Unix socket only)
468    async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
469        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
470            unreachable!()
471        };
472        // Write message length (4 bytes, big-endian)
473        let len_bytes = (data.len() as u32).to_be_bytes();
474        stream.write_all(&len_bytes).await?;
475        // Write message data
476        stream.write_all(data).await?;
477        stream.flush().await?;
478        Ok(())
479    }
480
481    /// Receive raw bytes from agent (Unix socket only)
482    async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
483        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
484            unreachable!()
485        };
486        // Read message length (4 bytes, big-endian)
487        let mut len_bytes = [0u8; 4];
488        stream.read_exact(&mut len_bytes).await?;
489        let message_len = u32::from_be_bytes(len_bytes) as usize;
490
491        // Check message size
492        if message_len > MAX_MESSAGE_SIZE {
493            return Err(AgentProtocolError::MessageTooLarge {
494                size: message_len,
495                max: MAX_MESSAGE_SIZE,
496            });
497        }
498
499        // Read message data
500        let mut buffer = vec![0u8; message_len];
501        stream.read_exact(&mut buffer).await?;
502        Ok(buffer)
503    }
504
505    /// Close the agent connection
506    pub async fn close(self) -> Result<(), AgentProtocolError> {
507        match self.connection {
508            AgentConnection::UnixSocket(mut stream) => {
509                stream.shutdown().await?;
510                Ok(())
511            }
512            AgentConnection::Grpc(_) => Ok(()), // gRPC channels close automatically
513        }
514    }
515}