sentinel_agent_protocol/
client.rs

//! Agent client for communicating with external agents.
//!
//! Supports two transport mechanisms:
//! - Unix domain sockets (JSON messages, each framed by a 4-byte big-endian length prefix)
//! - gRPC (Protocol Buffers over HTTP/2)

7use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12
13use crate::errors::AgentProtocolError;
14use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
15use crate::protocol::{
16    AgentRequest, AgentResponse, AuditMetadata, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
17    RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent,
18    ResponseHeadersEvent, MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
19};
20
21/// Agent client for communicating with external agents
22pub struct AgentClient {
23    /// Agent ID
24    id: String,
25    /// Connection to agent
26    connection: AgentConnection,
27    /// Timeout for agent calls
28    timeout: Duration,
29    /// Maximum retries
30    #[allow(dead_code)]
31    max_retries: u32,
32}
33
34/// Agent connection type
35enum AgentConnection {
36    UnixSocket(UnixStream),
37    Grpc(AgentProcessorClient<Channel>),
38}
39
40impl AgentClient {
41    /// Create a new Unix socket agent client
42    pub async fn unix_socket(
43        id: impl Into<String>,
44        path: impl AsRef<std::path::Path>,
45        timeout: Duration,
46    ) -> Result<Self, AgentProtocolError> {
47        let stream = UnixStream::connect(path.as_ref())
48            .await
49            .map_err(|e| AgentProtocolError::ConnectionFailed(e.to_string()))?;
50
51        Ok(Self {
52            id: id.into(),
53            connection: AgentConnection::UnixSocket(stream),
54            timeout,
55            max_retries: 3,
56        })
57    }
58
59    /// Create a new gRPC agent client
60    ///
61    /// # Arguments
62    /// * `id` - Agent identifier
63    /// * `address` - gRPC server address (e.g., "http://localhost:50051")
64    /// * `timeout` - Timeout for agent calls
65    pub async fn grpc(
66        id: impl Into<String>,
67        address: impl Into<String>,
68        timeout: Duration,
69    ) -> Result<Self, AgentProtocolError> {
70        let address = address.into();
71        let channel = Channel::from_shared(address.clone())
72            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e)))?
73            .timeout(timeout)
74            .connect()
75            .await
76            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e)))?;
77
78        let client = AgentProcessorClient::new(channel);
79
80        Ok(Self {
81            id: id.into(),
82            connection: AgentConnection::Grpc(client),
83            timeout,
84            max_retries: 3,
85        })
86    }
87
88    /// Get the agent ID
89    #[allow(dead_code)]
90    pub fn id(&self) -> &str {
91        &self.id
92    }
93
94    /// Send an event to the agent and get a response
95    pub async fn send_event(
96        &mut self,
97        event_type: EventType,
98        payload: impl Serialize,
99    ) -> Result<AgentResponse, AgentProtocolError> {
100        match &mut self.connection {
101            AgentConnection::UnixSocket(_) => {
102                self.send_event_unix_socket(event_type, payload).await
103            }
104            AgentConnection::Grpc(_) => {
105                self.send_event_grpc(event_type, payload).await
106            }
107        }
108    }
109
110    /// Send event via Unix socket (length-prefixed JSON)
111    async fn send_event_unix_socket(
112        &mut self,
113        event_type: EventType,
114        payload: impl Serialize,
115    ) -> Result<AgentResponse, AgentProtocolError> {
116        let request = AgentRequest {
117            version: PROTOCOL_VERSION,
118            event_type,
119            payload: serde_json::to_value(payload)
120                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
121        };
122
123        // Serialize request
124        let request_bytes = serde_json::to_vec(&request)
125            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
126
127        // Check message size
128        if request_bytes.len() > MAX_MESSAGE_SIZE {
129            return Err(AgentProtocolError::MessageTooLarge {
130                size: request_bytes.len(),
131                max: MAX_MESSAGE_SIZE,
132            });
133        }
134
135        // Send with timeout
136        let response = tokio::time::timeout(self.timeout, async {
137            self.send_raw_unix(&request_bytes).await?;
138            self.receive_raw_unix().await
139        })
140        .await
141        .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
142
143        // Parse response
144        let agent_response: AgentResponse = serde_json::from_slice(&response)
145            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
146
147        // Verify protocol version
148        if agent_response.version != PROTOCOL_VERSION {
149            return Err(AgentProtocolError::VersionMismatch {
150                expected: PROTOCOL_VERSION,
151                actual: agent_response.version,
152            });
153        }
154
155        Ok(agent_response)
156    }
157
158    /// Send event via gRPC
159    async fn send_event_grpc(
160        &mut self,
161        event_type: EventType,
162        payload: impl Serialize,
163    ) -> Result<AgentResponse, AgentProtocolError> {
164        // Build request first (doesn't need mutable borrow)
165        let grpc_request = Self::build_grpc_request(event_type, payload)?;
166
167        let AgentConnection::Grpc(client) = &mut self.connection else {
168            unreachable!()
169        };
170
171        // Send with timeout
172        let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
173            .await
174            .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
175            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e)))?;
176
177        // Convert gRPC response to internal format
178        Self::convert_grpc_response(response.into_inner())
179    }
180
181    /// Build a gRPC request from internal types
182    fn build_grpc_request(
183        event_type: EventType,
184        payload: impl Serialize,
185    ) -> Result<grpc::AgentRequest, AgentProtocolError> {
186        let payload_json = serde_json::to_value(&payload)
187            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
188
189        let grpc_event_type = match event_type {
190            EventType::RequestHeaders => grpc::EventType::RequestHeaders,
191            EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
192            EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
193            EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
194            EventType::RequestComplete => grpc::EventType::RequestComplete,
195        };
196
197        let event = match event_type {
198            EventType::RequestHeaders => {
199                let event: RequestHeadersEvent = serde_json::from_value(payload_json)
200                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
201                grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
202                    metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
203                    method: event.method,
204                    uri: event.uri,
205                    headers: event.headers.into_iter().map(|(k, v)| {
206                        (k, grpc::HeaderValues { values: v })
207                    }).collect(),
208                })
209            }
210            EventType::RequestBodyChunk => {
211                let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
212                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
213                grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
214                    correlation_id: event.correlation_id,
215                    data: event.data.into_bytes(),
216                    is_last: event.is_last,
217                    total_size: event.total_size.map(|s| s as u64),
218                })
219            }
220            EventType::ResponseHeaders => {
221                let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
222                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
223                grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
224                    correlation_id: event.correlation_id,
225                    status: event.status as u32,
226                    headers: event.headers.into_iter().map(|(k, v)| {
227                        (k, grpc::HeaderValues { values: v })
228                    }).collect(),
229                })
230            }
231            EventType::ResponseBodyChunk => {
232                let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
233                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
234                grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
235                    correlation_id: event.correlation_id,
236                    data: event.data.into_bytes(),
237                    is_last: event.is_last,
238                    total_size: event.total_size.map(|s| s as u64),
239                })
240            }
241            EventType::RequestComplete => {
242                let event: RequestCompleteEvent = serde_json::from_value(payload_json)
243                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
244                grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
245                    correlation_id: event.correlation_id,
246                    status: event.status as u32,
247                    duration_ms: event.duration_ms,
248                    request_body_size: event.request_body_size as u64,
249                    response_body_size: event.response_body_size as u64,
250                    upstream_attempts: event.upstream_attempts,
251                    error: event.error,
252                })
253            }
254        };
255
256        Ok(grpc::AgentRequest {
257            version: PROTOCOL_VERSION,
258            event_type: grpc_event_type as i32,
259            event: Some(event),
260        })
261    }
262
263    /// Convert internal metadata to gRPC format
264    fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
265        grpc::RequestMetadata {
266            correlation_id: metadata.correlation_id.clone(),
267            request_id: metadata.request_id.clone(),
268            client_ip: metadata.client_ip.clone(),
269            client_port: metadata.client_port as u32,
270            server_name: metadata.server_name.clone(),
271            protocol: metadata.protocol.clone(),
272            tls_version: metadata.tls_version.clone(),
273            tls_cipher: metadata.tls_cipher.clone(),
274            route_id: metadata.route_id.clone(),
275            upstream_id: metadata.upstream_id.clone(),
276            timestamp: metadata.timestamp.clone(),
277        }
278    }
279
280    /// Convert gRPC response to internal format
281    fn convert_grpc_response(
282        response: grpc::AgentResponse,
283    ) -> Result<AgentResponse, AgentProtocolError> {
284        let decision = match response.decision {
285            Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
286            Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
287                status: b.status as u16,
288                body: b.body,
289                headers: if b.headers.is_empty() { None } else { Some(b.headers) },
290            },
291            Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
292                url: r.url,
293                status: r.status as u16,
294            },
295            Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
296                challenge_type: c.challenge_type,
297                params: c.params,
298            },
299            None => Decision::Allow, // Default to allow if no decision
300        };
301
302        let request_headers: Vec<HeaderOp> = response.request_headers
303            .into_iter()
304            .filter_map(Self::convert_header_op_from_grpc)
305            .collect();
306
307        let response_headers: Vec<HeaderOp> = response.response_headers
308            .into_iter()
309            .filter_map(Self::convert_header_op_from_grpc)
310            .collect();
311
312        let audit = response.audit.map(|a| AuditMetadata {
313            tags: a.tags,
314            rule_ids: a.rule_ids,
315            confidence: a.confidence,
316            reason_codes: a.reason_codes,
317            custom: a.custom.into_iter().map(|(k, v)| {
318                (k, serde_json::Value::String(v))
319            }).collect(),
320        });
321
322        Ok(AgentResponse {
323            version: response.version,
324            decision,
325            request_headers,
326            response_headers,
327            routing_metadata: response.routing_metadata,
328            audit: audit.unwrap_or_default(),
329        })
330    }
331
332    /// Convert gRPC header operation to internal format
333    fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
334        match op.operation? {
335            grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
336                name: s.name,
337                value: s.value,
338            }),
339            grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
340                name: a.name,
341                value: a.value,
342            }),
343            grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove {
344                name: r.name,
345            }),
346        }
347    }
348
349    /// Send raw bytes to agent (Unix socket only)
350    async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
351        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
352            unreachable!()
353        };
354        // Write message length (4 bytes, big-endian)
355        let len_bytes = (data.len() as u32).to_be_bytes();
356        stream.write_all(&len_bytes).await?;
357        // Write message data
358        stream.write_all(data).await?;
359        stream.flush().await?;
360        Ok(())
361    }
362
363    /// Receive raw bytes from agent (Unix socket only)
364    async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
365        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
366            unreachable!()
367        };
368        // Read message length (4 bytes, big-endian)
369        let mut len_bytes = [0u8; 4];
370        stream.read_exact(&mut len_bytes).await?;
371        let message_len = u32::from_be_bytes(len_bytes) as usize;
372
373        // Check message size
374        if message_len > MAX_MESSAGE_SIZE {
375            return Err(AgentProtocolError::MessageTooLarge {
376                size: message_len,
377                max: MAX_MESSAGE_SIZE,
378            });
379        }
380
381        // Read message data
382        let mut buffer = vec![0u8; message_len];
383        stream.read_exact(&mut buffer).await?;
384        Ok(buffer)
385    }
386
387    /// Close the agent connection
388    pub async fn close(self) -> Result<(), AgentProtocolError> {
389        match self.connection {
390            AgentConnection::UnixSocket(mut stream) => {
391                stream.shutdown().await?;
392                Ok(())
393            }
394            AgentConnection::Grpc(_) => Ok(()), // gRPC channels close automatically
395        }
396    }
397}