Skip to main content

sentinel_agent_protocol/
client.rs

1//! Agent client for communicating with external agents.
2//!
3//! Supports three transport mechanisms:
4//! - Unix domain sockets (length-prefixed JSON)
5//! - gRPC (Protocol Buffers over HTTP/2, with optional TLS)
6//! - HTTP REST (JSON over HTTP/1.1 or HTTP/2, with optional TLS)
7
8use serde::Serialize;
9use std::path::Path;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::UnixStream;
14use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
15use tracing::{debug, error, trace, warn};
16
17use crate::errors::AgentProtocolError;
18use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
19use crate::protocol::{
20    AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
21    RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
22    ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
23    MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
24};
25
26/// TLS configuration for gRPC agent connections
27#[derive(Debug, Clone, Default)]
28pub struct GrpcTlsConfig {
29    /// Skip certificate verification (DANGEROUS - only for testing)
30    pub insecure_skip_verify: bool,
31    /// CA certificate PEM data for verifying the server
32    pub ca_cert_pem: Option<Vec<u8>>,
33    /// Client certificate PEM data for mTLS
34    pub client_cert_pem: Option<Vec<u8>>,
35    /// Client key PEM data for mTLS
36    pub client_key_pem: Option<Vec<u8>>,
37    /// Domain name to use for TLS SNI and certificate validation
38    pub domain_name: Option<String>,
39}
40
41impl GrpcTlsConfig {
42    /// Create a new TLS config builder
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Load CA certificate from a file
48    pub async fn with_ca_cert_file(
49        mut self,
50        path: impl AsRef<Path>,
51    ) -> Result<Self, std::io::Error> {
52        self.ca_cert_pem = Some(tokio::fs::read(path).await?);
53        Ok(self)
54    }
55
56    /// Set CA certificate from PEM data
57    pub fn with_ca_cert_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
58        self.ca_cert_pem = Some(pem.into());
59        self
60    }
61
62    /// Load client certificate and key from files (for mTLS)
63    pub async fn with_client_cert_files(
64        mut self,
65        cert_path: impl AsRef<Path>,
66        key_path: impl AsRef<Path>,
67    ) -> Result<Self, std::io::Error> {
68        self.client_cert_pem = Some(tokio::fs::read(cert_path).await?);
69        self.client_key_pem = Some(tokio::fs::read(key_path).await?);
70        Ok(self)
71    }
72
73    /// Set client certificate and key from PEM data (for mTLS)
74    pub fn with_client_identity(
75        mut self,
76        cert_pem: impl Into<Vec<u8>>,
77        key_pem: impl Into<Vec<u8>>,
78    ) -> Self {
79        self.client_cert_pem = Some(cert_pem.into());
80        self.client_key_pem = Some(key_pem.into());
81        self
82    }
83
84    /// Set the domain name for TLS SNI and certificate validation
85    pub fn with_domain_name(mut self, domain: impl Into<String>) -> Self {
86        self.domain_name = Some(domain.into());
87        self
88    }
89
90    /// Skip certificate verification (DANGEROUS - only for testing)
91    pub fn with_insecure_skip_verify(mut self) -> Self {
92        self.insecure_skip_verify = true;
93        self
94    }
95}
96
97/// TLS configuration for HTTP agent connections
98#[derive(Debug, Clone, Default)]
99pub struct HttpTlsConfig {
100    /// Skip certificate verification (DANGEROUS - only for testing)
101    pub insecure_skip_verify: bool,
102    /// CA certificate PEM data for verifying the server
103    pub ca_cert_pem: Option<Vec<u8>>,
104    /// Client certificate PEM data for mTLS
105    pub client_cert_pem: Option<Vec<u8>>,
106    /// Client key PEM data for mTLS
107    pub client_key_pem: Option<Vec<u8>>,
108}
109
110impl HttpTlsConfig {
111    /// Create a new TLS config builder
112    pub fn new() -> Self {
113        Self::default()
114    }
115
116    /// Load CA certificate from a file
117    pub async fn with_ca_cert_file(
118        mut self,
119        path: impl AsRef<Path>,
120    ) -> Result<Self, std::io::Error> {
121        self.ca_cert_pem = Some(tokio::fs::read(path).await?);
122        Ok(self)
123    }
124
125    /// Set CA certificate from PEM data
126    pub fn with_ca_cert_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
127        self.ca_cert_pem = Some(pem.into());
128        self
129    }
130
131    /// Load client certificate and key from files (for mTLS)
132    pub async fn with_client_cert_files(
133        mut self,
134        cert_path: impl AsRef<Path>,
135        key_path: impl AsRef<Path>,
136    ) -> Result<Self, std::io::Error> {
137        self.client_cert_pem = Some(tokio::fs::read(cert_path).await?);
138        self.client_key_pem = Some(tokio::fs::read(key_path).await?);
139        Ok(self)
140    }
141
142    /// Set client certificate and key from PEM data (for mTLS)
143    pub fn with_client_identity(
144        mut self,
145        cert_pem: impl Into<Vec<u8>>,
146        key_pem: impl Into<Vec<u8>>,
147    ) -> Self {
148        self.client_cert_pem = Some(cert_pem.into());
149        self.client_key_pem = Some(key_pem.into());
150        self
151    }
152
153    /// Skip certificate verification (DANGEROUS - only for testing)
154    pub fn with_insecure_skip_verify(mut self) -> Self {
155        self.insecure_skip_verify = true;
156        self
157    }
158}
159
/// HTTP connection details shared by an HTTP-transport [`AgentClient`].
///
/// Stored behind an `Arc` in `AgentConnection::Http`, so it can be cloned out of
/// the connection cheaply before a request is sent.
struct HttpConnection {
    /// Preconfigured reqwest client (request timeout and, for TLS clients, the
    /// TLS settings are applied when it is built)
    client: reqwest::Client,
    /// Agent endpoint URL; every event is POSTed to this exact URL
    url: String,
}
167
/// Agent client for communicating with external agents.
///
/// Owns one transport connection (Unix socket, gRPC, or HTTP) plus the timeout
/// applied when sending events over it.
pub struct AgentClient {
    /// Agent ID, used in log fields and returned by `id()`
    id: String,
    /// Transport-specific connection to the agent
    connection: AgentConnection,
    /// Timeout for agent calls (HTTP clients also have it baked into the reqwest client)
    timeout: Duration,
    /// Maximum retries; the `unix_socket`, `grpc`, `grpc_tls`, `http` and `http_tls`
    /// constructors set it to 3
    // NOTE(review): no method in view reads this field (hence `allow(dead_code)`);
    // confirm whether retry logic is planned or the field can be dropped.
    #[allow(dead_code)]
    max_retries: u32,
}
180
/// Transport used by an [`AgentClient`] to reach its agent.
enum AgentConnection {
    /// Connected Unix domain socket; messages use length-prefixed JSON framing
    UnixSocket(UnixStream),
    /// Generated gRPC client over a tonic channel (plaintext or TLS)
    Grpc(AgentProcessorClient<Channel>),
    /// Shared HTTP client and endpoint; the `Arc` lets `send_event` clone it out
    /// before handing it to `send_event_http`
    Http(Arc<HttpConnection>),
}
187
188impl AgentClient {
189    /// Create a new Unix socket agent client
190    pub async fn unix_socket(
191        id: impl Into<String>,
192        path: impl AsRef<std::path::Path>,
193        timeout: Duration,
194    ) -> Result<Self, AgentProtocolError> {
195        let id = id.into();
196        let path = path.as_ref();
197
198        trace!(
199            agent_id = %id,
200            socket_path = %path.display(),
201            timeout_ms = timeout.as_millis() as u64,
202            "Connecting to agent via Unix socket"
203        );
204
205        let stream = UnixStream::connect(path).await.map_err(|e| {
206            error!(
207                agent_id = %id,
208                socket_path = %path.display(),
209                error = %e,
210                "Failed to connect to agent via Unix socket"
211            );
212            AgentProtocolError::ConnectionFailed(e.to_string())
213        })?;
214
215        debug!(
216            agent_id = %id,
217            socket_path = %path.display(),
218            "Connected to agent via Unix socket"
219        );
220
221        Ok(Self {
222            id,
223            connection: AgentConnection::UnixSocket(stream),
224            timeout,
225            max_retries: 3,
226        })
227    }
228
229    /// Create a new gRPC agent client
230    ///
231    /// # Arguments
232    /// * `id` - Agent identifier
233    /// * `address` - gRPC server address (e.g., "http://localhost:50051")
234    /// * `timeout` - Timeout for agent calls
235    pub async fn grpc(
236        id: impl Into<String>,
237        address: impl Into<String>,
238        timeout: Duration,
239    ) -> Result<Self, AgentProtocolError> {
240        let id = id.into();
241        let address = address.into();
242
243        trace!(
244            agent_id = %id,
245            address = %address,
246            timeout_ms = timeout.as_millis() as u64,
247            "Connecting to agent via gRPC"
248        );
249
250        let channel = Channel::from_shared(address.clone())
251            .map_err(|e| {
252                error!(
253                    agent_id = %id,
254                    address = %address,
255                    error = %e,
256                    "Invalid gRPC URI"
257                );
258                AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
259            })?
260            .timeout(timeout)
261            .connect()
262            .await
263            .map_err(|e| {
264                error!(
265                    agent_id = %id,
266                    address = %address,
267                    error = %e,
268                    "Failed to connect to agent via gRPC"
269                );
270                AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
271            })?;
272
273        let client = AgentProcessorClient::new(channel);
274
275        debug!(
276            agent_id = %id,
277            address = %address,
278            "Connected to agent via gRPC"
279        );
280
281        Ok(Self {
282            id,
283            connection: AgentConnection::Grpc(client),
284            timeout,
285            max_retries: 3,
286        })
287    }
288
289    /// Create a new gRPC agent client with TLS
290    ///
291    /// # Arguments
292    /// * `id` - Agent identifier
293    /// * `address` - gRPC server address (e.g., "https://localhost:50051")
294    /// * `timeout` - Timeout for agent calls
295    /// * `tls_config` - TLS configuration
296    pub async fn grpc_tls(
297        id: impl Into<String>,
298        address: impl Into<String>,
299        timeout: Duration,
300        tls_config: GrpcTlsConfig,
301    ) -> Result<Self, AgentProtocolError> {
302        let id = id.into();
303        let address = address.into();
304
305        trace!(
306            agent_id = %id,
307            address = %address,
308            timeout_ms = timeout.as_millis() as u64,
309            has_ca_cert = tls_config.ca_cert_pem.is_some(),
310            has_client_cert = tls_config.client_cert_pem.is_some(),
311            insecure = tls_config.insecure_skip_verify,
312            "Connecting to agent via gRPC with TLS"
313        );
314
315        // Build TLS config
316        let mut client_tls_config = ClientTlsConfig::new();
317
318        // Set domain name for SNI if provided, otherwise extract from address
319        if let Some(domain) = &tls_config.domain_name {
320            client_tls_config = client_tls_config.domain_name(domain.clone());
321        } else {
322            // Try to extract domain from address URL
323            if let Some(domain) = Self::extract_domain(&address) {
324                client_tls_config = client_tls_config.domain_name(domain);
325            }
326        }
327
328        // Add CA certificate if provided
329        if let Some(ca_pem) = &tls_config.ca_cert_pem {
330            let ca_cert = Certificate::from_pem(ca_pem);
331            client_tls_config = client_tls_config.ca_certificate(ca_cert);
332            debug!(
333                agent_id = %id,
334                "Using custom CA certificate for gRPC TLS"
335            );
336        }
337
338        // Add client identity for mTLS if provided
339        if let (Some(cert_pem), Some(key_pem)) =
340            (&tls_config.client_cert_pem, &tls_config.client_key_pem)
341        {
342            let identity = Identity::from_pem(cert_pem, key_pem);
343            client_tls_config = client_tls_config.identity(identity);
344            debug!(
345                agent_id = %id,
346                "Using client certificate for mTLS to gRPC agent"
347            );
348        }
349
350        // Handle insecure skip verify (dangerous - only for testing)
351        if tls_config.insecure_skip_verify {
352            warn!(
353                agent_id = %id,
354                address = %address,
355                "SECURITY WARNING: TLS certificate verification disabled for gRPC agent connection"
356            );
357            // Note: tonic doesn't have a direct "skip verify" option like some other libraries
358            // The best we can do is use a permissive TLS config. For truly insecure connections,
359            // users should use the non-TLS grpc() method instead.
360        }
361
362        // Build channel with TLS
363        let channel = Channel::from_shared(address.clone())
364            .map_err(|e| {
365                error!(
366                    agent_id = %id,
367                    address = %address,
368                    error = %e,
369                    "Invalid gRPC URI"
370                );
371                AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
372            })?
373            .tls_config(client_tls_config)
374            .map_err(|e| {
375                error!(
376                    agent_id = %id,
377                    address = %address,
378                    error = %e,
379                    "Invalid TLS configuration"
380                );
381                AgentProtocolError::ConnectionFailed(format!("TLS config error: {}", e))
382            })?
383            .timeout(timeout)
384            .connect()
385            .await
386            .map_err(|e| {
387                error!(
388                    agent_id = %id,
389                    address = %address,
390                    error = %e,
391                    "Failed to connect to agent via gRPC with TLS"
392                );
393                AgentProtocolError::ConnectionFailed(format!("gRPC TLS connect failed: {}", e))
394            })?;
395
396        let client = AgentProcessorClient::new(channel);
397
398        debug!(
399            agent_id = %id,
400            address = %address,
401            "Connected to agent via gRPC with TLS"
402        );
403
404        Ok(Self {
405            id,
406            connection: AgentConnection::Grpc(client),
407            timeout,
408            max_retries: 3,
409        })
410    }
411
412    /// Extract domain name from a URL for TLS SNI
413    fn extract_domain(address: &str) -> Option<String> {
414        // Try to parse as URL and extract host
415        let address = address.trim();
416
417        // Handle URLs like "https://example.com:443" or "http://example.com:8080"
418        if let Some(rest) = address
419            .strip_prefix("https://")
420            .or_else(|| address.strip_prefix("http://"))
421        {
422            // Split off port and path
423            let host = rest.split(':').next()?.split('/').next()?;
424            if !host.is_empty() {
425                return Some(host.to_string());
426            }
427        }
428
429        None
430    }
431
432    /// Create a new HTTP agent client
433    ///
434    /// # Arguments
435    /// * `id` - Agent identifier
436    /// * `url` - HTTP endpoint URL (e.g., "http://localhost:8080/agent")
437    /// * `timeout` - Timeout for agent calls
438    pub async fn http(
439        id: impl Into<String>,
440        url: impl Into<String>,
441        timeout: Duration,
442    ) -> Result<Self, AgentProtocolError> {
443        let id = id.into();
444        let url = url.into();
445
446        trace!(
447            agent_id = %id,
448            url = %url,
449            timeout_ms = timeout.as_millis() as u64,
450            "Creating HTTP agent client"
451        );
452
453        let client = reqwest::Client::builder()
454            .timeout(timeout)
455            .build()
456            .map_err(|e| {
457                error!(
458                    agent_id = %id,
459                    url = %url,
460                    error = %e,
461                    "Failed to create HTTP client"
462                );
463                AgentProtocolError::ConnectionFailed(format!("HTTP client error: {}", e))
464            })?;
465
466        debug!(
467            agent_id = %id,
468            url = %url,
469            "HTTP agent client created"
470        );
471
472        Ok(Self {
473            id,
474            connection: AgentConnection::Http(Arc::new(HttpConnection { client, url })),
475            timeout,
476            max_retries: 3,
477        })
478    }
479
480    /// Create a new HTTP agent client with TLS
481    ///
482    /// # Arguments
483    /// * `id` - Agent identifier
484    /// * `url` - HTTPS endpoint URL (e.g., `https://agent.internal:8443/agent`)
485    /// * `timeout` - Timeout for agent calls
486    /// * `tls_config` - TLS configuration
487    pub async fn http_tls(
488        id: impl Into<String>,
489        url: impl Into<String>,
490        timeout: Duration,
491        tls_config: HttpTlsConfig,
492    ) -> Result<Self, AgentProtocolError> {
493        let id = id.into();
494        let url = url.into();
495
496        trace!(
497            agent_id = %id,
498            url = %url,
499            timeout_ms = timeout.as_millis() as u64,
500            has_ca_cert = tls_config.ca_cert_pem.is_some(),
501            has_client_cert = tls_config.client_cert_pem.is_some(),
502            insecure = tls_config.insecure_skip_verify,
503            "Creating HTTP agent client with TLS"
504        );
505
506        let mut client_builder = reqwest::Client::builder().timeout(timeout).use_rustls_tls();
507
508        // Add CA certificate if provided
509        if let Some(ca_pem) = &tls_config.ca_cert_pem {
510            let ca_cert = reqwest::Certificate::from_pem(ca_pem).map_err(|e| {
511                error!(
512                    agent_id = %id,
513                    error = %e,
514                    "Failed to parse CA certificate"
515                );
516                AgentProtocolError::ConnectionFailed(format!("Invalid CA certificate: {}", e))
517            })?;
518            client_builder = client_builder.add_root_certificate(ca_cert);
519            debug!(
520                agent_id = %id,
521                "Using custom CA certificate for HTTP TLS"
522            );
523        }
524
525        // Add client identity for mTLS if provided
526        if let (Some(cert_pem), Some(key_pem)) =
527            (&tls_config.client_cert_pem, &tls_config.client_key_pem)
528        {
529            // Combine cert and key into identity PEM
530            let mut identity_pem = cert_pem.clone();
531            identity_pem.extend_from_slice(b"\n");
532            identity_pem.extend_from_slice(key_pem);
533
534            let identity = reqwest::Identity::from_pem(&identity_pem).map_err(|e| {
535                error!(
536                    agent_id = %id,
537                    error = %e,
538                    "Failed to parse client certificate/key"
539                );
540                AgentProtocolError::ConnectionFailed(format!("Invalid client certificate: {}", e))
541            })?;
542            client_builder = client_builder.identity(identity);
543            debug!(
544                agent_id = %id,
545                "Using client certificate for mTLS to HTTP agent"
546            );
547        }
548
549        // Handle insecure skip verify (dangerous - only for testing)
550        if tls_config.insecure_skip_verify {
551            warn!(
552                agent_id = %id,
553                url = %url,
554                "SECURITY WARNING: TLS certificate verification disabled for HTTP agent connection"
555            );
556            client_builder = client_builder.danger_accept_invalid_certs(true);
557        }
558
559        let client = client_builder.build().map_err(|e| {
560            error!(
561                agent_id = %id,
562                url = %url,
563                error = %e,
564                "Failed to create HTTP TLS client"
565            );
566            AgentProtocolError::ConnectionFailed(format!("HTTP TLS client error: {}", e))
567        })?;
568
569        debug!(
570            agent_id = %id,
571            url = %url,
572            "HTTP agent client created with TLS"
573        );
574
575        Ok(Self {
576            id,
577            connection: AgentConnection::Http(Arc::new(HttpConnection { client, url })),
578            timeout,
579            max_retries: 3,
580        })
581    }
582
583    /// Get the agent ID
584    #[allow(dead_code)]
585    pub fn id(&self) -> &str {
586        &self.id
587    }
588
589    /// Send an event to the agent and get a response
590    pub async fn send_event(
591        &mut self,
592        event_type: EventType,
593        payload: impl Serialize,
594    ) -> Result<AgentResponse, AgentProtocolError> {
595        // Clone HTTP connection Arc before match to avoid borrow issues
596        let http_conn = if let AgentConnection::Http(conn) = &self.connection {
597            Some(Arc::clone(conn))
598        } else {
599            None
600        };
601
602        match &mut self.connection {
603            AgentConnection::UnixSocket(_) => {
604                self.send_event_unix_socket(event_type, payload).await
605            }
606            AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
607            AgentConnection::Http(_) => {
608                // Use the cloned Arc
609                self.send_event_http(http_conn.unwrap(), event_type, payload)
610                    .await
611            }
612        }
613    }
614
615    /// Send event via Unix socket (length-prefixed JSON)
616    async fn send_event_unix_socket(
617        &mut self,
618        event_type: EventType,
619        payload: impl Serialize,
620    ) -> Result<AgentResponse, AgentProtocolError> {
621        let request = AgentRequest {
622            version: PROTOCOL_VERSION,
623            event_type,
624            payload: serde_json::to_value(payload)
625                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
626        };
627
628        // Serialize request
629        let request_bytes = serde_json::to_vec(&request)
630            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
631
632        // Check message size
633        if request_bytes.len() > MAX_MESSAGE_SIZE {
634            return Err(AgentProtocolError::MessageTooLarge {
635                size: request_bytes.len(),
636                max: MAX_MESSAGE_SIZE,
637            });
638        }
639
640        // Send with timeout
641        let response = tokio::time::timeout(self.timeout, async {
642            self.send_raw_unix(&request_bytes).await?;
643            self.receive_raw_unix().await
644        })
645        .await
646        .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
647
648        // Parse response
649        let agent_response: AgentResponse = serde_json::from_slice(&response)
650            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
651
652        // Verify protocol version
653        if agent_response.version != PROTOCOL_VERSION {
654            return Err(AgentProtocolError::VersionMismatch {
655                expected: PROTOCOL_VERSION,
656                actual: agent_response.version,
657            });
658        }
659
660        Ok(agent_response)
661    }
662
663    /// Send event via gRPC
664    async fn send_event_grpc(
665        &mut self,
666        event_type: EventType,
667        payload: impl Serialize,
668    ) -> Result<AgentResponse, AgentProtocolError> {
669        // Build request first (doesn't need mutable borrow)
670        let grpc_request = Self::build_grpc_request(event_type, payload)?;
671
672        let AgentConnection::Grpc(client) = &mut self.connection else {
673            return Err(AgentProtocolError::WrongConnectionType(
674                "Expected gRPC connection but found Unix socket".to_string(),
675            ));
676        };
677
678        // Send with timeout
679        let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
680            .await
681            .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
682            .map_err(|e| {
683                AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
684            })?;
685
686        // Convert gRPC response to internal format
687        Self::convert_grpc_response(response.into_inner())
688    }
689
690    /// Send event via HTTP POST (JSON)
691    async fn send_event_http(
692        &self,
693        conn: Arc<HttpConnection>,
694        event_type: EventType,
695        payload: impl Serialize,
696    ) -> Result<AgentResponse, AgentProtocolError> {
697        let request = AgentRequest {
698            version: PROTOCOL_VERSION,
699            event_type,
700            payload: serde_json::to_value(payload)
701                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
702        };
703
704        // Serialize request
705        let request_json = serde_json::to_string(&request)
706            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
707
708        // Check message size
709        if request_json.len() > MAX_MESSAGE_SIZE {
710            return Err(AgentProtocolError::MessageTooLarge {
711                size: request_json.len(),
712                max: MAX_MESSAGE_SIZE,
713            });
714        }
715
716        trace!(
717            agent_id = %self.id,
718            url = %conn.url,
719            event_type = ?event_type,
720            request_size = request_json.len(),
721            "Sending HTTP request to agent"
722        );
723
724        // Send HTTP POST request
725        let response = conn
726            .client
727            .post(&conn.url)
728            .header("Content-Type", "application/json")
729            .header("X-Sentinel-Protocol-Version", PROTOCOL_VERSION.to_string())
730            .body(request_json)
731            .send()
732            .await
733            .map_err(|e| {
734                error!(
735                    agent_id = %self.id,
736                    url = %conn.url,
737                    error = %e,
738                    "HTTP request to agent failed"
739                );
740                if e.is_timeout() {
741                    AgentProtocolError::Timeout(self.timeout)
742                } else if e.is_connect() {
743                    AgentProtocolError::ConnectionFailed(format!("HTTP connect failed: {}", e))
744                } else {
745                    AgentProtocolError::ConnectionFailed(format!("HTTP request failed: {}", e))
746                }
747            })?;
748
749        // Check HTTP status
750        let status = response.status();
751        if !status.is_success() {
752            let body = response.text().await.unwrap_or_default();
753            error!(
754                agent_id = %self.id,
755                url = %conn.url,
756                status = %status,
757                body = %body,
758                "Agent returned HTTP error"
759            );
760            return Err(AgentProtocolError::ConnectionFailed(format!(
761                "HTTP {} from agent: {}",
762                status, body
763            )));
764        }
765
766        // Parse response
767        let response_bytes = response.bytes().await.map_err(|e| {
768            AgentProtocolError::ConnectionFailed(format!("Failed to read response body: {}", e))
769        })?;
770
771        // Check response size
772        if response_bytes.len() > MAX_MESSAGE_SIZE {
773            return Err(AgentProtocolError::MessageTooLarge {
774                size: response_bytes.len(),
775                max: MAX_MESSAGE_SIZE,
776            });
777        }
778
779        let agent_response: AgentResponse =
780            serde_json::from_slice(&response_bytes).map_err(|e| {
781                AgentProtocolError::InvalidMessage(format!("Invalid JSON response: {}", e))
782            })?;
783
784        // Verify protocol version
785        if agent_response.version != PROTOCOL_VERSION {
786            return Err(AgentProtocolError::VersionMismatch {
787                expected: PROTOCOL_VERSION,
788                actual: agent_response.version,
789            });
790        }
791
792        trace!(
793            agent_id = %self.id,
794            decision = ?agent_response.decision,
795            "Received HTTP response from agent"
796        );
797
798        Ok(agent_response)
799    }
800
801    /// Build a gRPC request from internal types
802    fn build_grpc_request(
803        event_type: EventType,
804        payload: impl Serialize,
805    ) -> Result<grpc::AgentRequest, AgentProtocolError> {
806        let payload_json = serde_json::to_value(&payload)
807            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
808
809        let grpc_event_type = match event_type {
810            EventType::Configure => {
811                return Err(AgentProtocolError::Serialization(
812                    "Configure events are not supported via gRPC".to_string(),
813                ))
814            }
815            EventType::RequestHeaders => grpc::EventType::RequestHeaders,
816            EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
817            EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
818            EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
819            EventType::RequestComplete => grpc::EventType::RequestComplete,
820            EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
821            EventType::GuardrailInspect => {
822                return Err(AgentProtocolError::Serialization(
823                    "GuardrailInspect events are not yet supported via gRPC".to_string(),
824                ))
825            }
826        };
827
828        let event = match event_type {
829            EventType::Configure => {
830                return Err(AgentProtocolError::InvalidMessage(
831                    "Configure event should be handled separately".to_string(),
832                ));
833            }
834            EventType::RequestHeaders => {
835                let event: RequestHeadersEvent = serde_json::from_value(payload_json)
836                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
837                grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
838                    metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
839                    method: event.method,
840                    uri: event.uri,
841                    headers: event
842                        .headers
843                        .into_iter()
844                        .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
845                        .collect(),
846                })
847            }
848            EventType::RequestBodyChunk => {
849                let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
850                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
851                grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
852                    correlation_id: event.correlation_id,
853                    data: event.data.into_bytes(),
854                    is_last: event.is_last,
855                    total_size: event.total_size.map(|s| s as u64),
856                    chunk_index: event.chunk_index,
857                    bytes_received: event.bytes_received as u64,
858                })
859            }
860            EventType::ResponseHeaders => {
861                let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
862                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
863                grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
864                    correlation_id: event.correlation_id,
865                    status: event.status as u32,
866                    headers: event
867                        .headers
868                        .into_iter()
869                        .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
870                        .collect(),
871                })
872            }
873            EventType::ResponseBodyChunk => {
874                let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
875                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
876                grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
877                    correlation_id: event.correlation_id,
878                    data: event.data.into_bytes(),
879                    is_last: event.is_last,
880                    total_size: event.total_size.map(|s| s as u64),
881                    chunk_index: event.chunk_index,
882                    bytes_sent: event.bytes_sent as u64,
883                })
884            }
885            EventType::RequestComplete => {
886                let event: RequestCompleteEvent = serde_json::from_value(payload_json)
887                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
888                grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
889                    correlation_id: event.correlation_id,
890                    status: event.status as u32,
891                    duration_ms: event.duration_ms,
892                    request_body_size: event.request_body_size as u64,
893                    response_body_size: event.response_body_size as u64,
894                    upstream_attempts: event.upstream_attempts,
895                    error: event.error,
896                })
897            }
898            EventType::WebSocketFrame => {
899                use base64::{engine::general_purpose::STANDARD, Engine as _};
900                let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
901                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
902                grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
903                    correlation_id: event.correlation_id,
904                    opcode: event.opcode,
905                    data: STANDARD.decode(&event.data).unwrap_or_default(),
906                    client_to_server: event.client_to_server,
907                    frame_index: event.frame_index,
908                    fin: event.fin,
909                    route_id: event.route_id,
910                    client_ip: event.client_ip,
911                })
912            }
913            EventType::GuardrailInspect => {
914                return Err(AgentProtocolError::InvalidMessage(
915                    "GuardrailInspect events are not yet supported via gRPC".to_string(),
916                ));
917            }
918        };
919
920        Ok(grpc::AgentRequest {
921            version: PROTOCOL_VERSION,
922            event_type: grpc_event_type as i32,
923            event: Some(event),
924        })
925    }
926
927    /// Convert internal metadata to gRPC format
928    fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
929        grpc::RequestMetadata {
930            correlation_id: metadata.correlation_id.clone(),
931            request_id: metadata.request_id.clone(),
932            client_ip: metadata.client_ip.clone(),
933            client_port: metadata.client_port as u32,
934            server_name: metadata.server_name.clone(),
935            protocol: metadata.protocol.clone(),
936            tls_version: metadata.tls_version.clone(),
937            tls_cipher: metadata.tls_cipher.clone(),
938            route_id: metadata.route_id.clone(),
939            upstream_id: metadata.upstream_id.clone(),
940            timestamp: metadata.timestamp.clone(),
941            traceparent: metadata.traceparent.clone(),
942        }
943    }
944
945    /// Convert gRPC response to internal format
946    fn convert_grpc_response(
947        response: grpc::AgentResponse,
948    ) -> Result<AgentResponse, AgentProtocolError> {
949        let decision = match response.decision {
950            Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
951            Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
952                status: b.status as u16,
953                body: b.body,
954                headers: if b.headers.is_empty() {
955                    None
956                } else {
957                    Some(b.headers)
958                },
959            },
960            Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
961                url: r.url,
962                status: r.status as u16,
963            },
964            Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
965                challenge_type: c.challenge_type,
966                params: c.params,
967            },
968            None => Decision::Allow, // Default to allow if no decision
969        };
970
971        let request_headers: Vec<HeaderOp> = response
972            .request_headers
973            .into_iter()
974            .filter_map(Self::convert_header_op_from_grpc)
975            .collect();
976
977        let response_headers: Vec<HeaderOp> = response
978            .response_headers
979            .into_iter()
980            .filter_map(Self::convert_header_op_from_grpc)
981            .collect();
982
983        let audit = response.audit.map(|a| AuditMetadata {
984            tags: a.tags,
985            rule_ids: a.rule_ids,
986            confidence: a.confidence,
987            reason_codes: a.reason_codes,
988            custom: a
989                .custom
990                .into_iter()
991                .map(|(k, v)| (k, serde_json::Value::String(v)))
992                .collect(),
993        });
994
995        // Convert body mutations
996        let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
997            data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
998            chunk_index: m.chunk_index,
999        });
1000
1001        let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
1002            data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
1003            chunk_index: m.chunk_index,
1004        });
1005
1006        // Convert WebSocket decision
1007        let websocket_decision = response
1008            .websocket_decision
1009            .map(|ws_decision| match ws_decision {
1010                grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
1011                    WebSocketDecision::Allow
1012                }
1013                grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
1014                    WebSocketDecision::Drop
1015                }
1016                grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
1017                    WebSocketDecision::Close {
1018                        code: c.code as u16,
1019                        reason: c.reason,
1020                    }
1021                }
1022            });
1023
1024        Ok(AgentResponse {
1025            version: response.version,
1026            decision,
1027            request_headers,
1028            response_headers,
1029            routing_metadata: response.routing_metadata,
1030            audit: audit.unwrap_or_default(),
1031            needs_more: response.needs_more,
1032            request_body_mutation,
1033            response_body_mutation,
1034            websocket_decision,
1035        })
1036    }
1037
1038    /// Convert gRPC header operation to internal format
1039    fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
1040        match op.operation? {
1041            grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
1042                name: s.name,
1043                value: s.value,
1044            }),
1045            grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
1046                name: a.name,
1047                value: a.value,
1048            }),
1049            grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
1050        }
1051    }
1052
1053    /// Send raw bytes to agent (Unix socket only)
1054    async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
1055        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
1056            return Err(AgentProtocolError::WrongConnectionType(
1057                "Expected Unix socket connection but found gRPC".to_string(),
1058            ));
1059        };
1060        // Write message length (4 bytes, big-endian)
1061        let len_bytes = (data.len() as u32).to_be_bytes();
1062        stream.write_all(&len_bytes).await?;
1063        // Write message data
1064        stream.write_all(data).await?;
1065        stream.flush().await?;
1066        Ok(())
1067    }
1068
1069    /// Receive raw bytes from agent (Unix socket only)
1070    async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
1071        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
1072            return Err(AgentProtocolError::WrongConnectionType(
1073                "Expected Unix socket connection but found gRPC".to_string(),
1074            ));
1075        };
1076        // Read message length (4 bytes, big-endian)
1077        let mut len_bytes = [0u8; 4];
1078        stream.read_exact(&mut len_bytes).await?;
1079        let message_len = u32::from_be_bytes(len_bytes) as usize;
1080
1081        // Check message size
1082        if message_len > MAX_MESSAGE_SIZE {
1083            return Err(AgentProtocolError::MessageTooLarge {
1084                size: message_len,
1085                max: MAX_MESSAGE_SIZE,
1086            });
1087        }
1088
1089        // Read message data
1090        let mut buffer = vec![0u8; message_len];
1091        stream.read_exact(&mut buffer).await?;
1092        Ok(buffer)
1093    }
1094
1095    /// Close the agent connection
1096    pub async fn close(self) -> Result<(), AgentProtocolError> {
1097        match self.connection {
1098            AgentConnection::UnixSocket(mut stream) => {
1099                stream.shutdown().await?;
1100                Ok(())
1101            }
1102            AgentConnection::Grpc(_) => Ok(()), // gRPC channels close automatically
1103            AgentConnection::Http(_) => Ok(()), // HTTP clients are stateless, no cleanup needed
1104        }
1105    }
1106}
1107
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that each `(url, host)` pair extracts to `Some(host)`.
    fn assert_domains(cases: &[(&str, &str)]) {
        for &(url, host) in cases {
            assert_eq!(
                AgentClient::extract_domain(url),
                Some(host.to_string()),
                "extract_domain({url:?})"
            );
        }
    }

    #[test]
    fn test_extract_domain_https() {
        assert_domains(&[
            ("https://example.com:443", "example.com"),
            ("https://agent.internal:50051", "agent.internal"),
            ("https://localhost:8080/path", "localhost"),
        ]);
    }

    #[test]
    fn test_extract_domain_http() {
        assert_domains(&[
            ("http://example.com:8080", "example.com"),
            ("http://localhost:50051", "localhost"),
        ]);
    }

    #[test]
    fn test_extract_domain_invalid() {
        // Missing scheme, unsupported scheme, and empty input are all rejected.
        for url in ["example.com:443", "tcp://example.com:443", ""] {
            assert_eq!(AgentClient::extract_domain(url), None, "extract_domain({url:?})");
        }
    }

    #[test]
    fn test_grpc_tls_config_builder() {
        let cfg = GrpcTlsConfig::new()
            .with_ca_cert_pem(b"test-ca-cert".to_vec())
            .with_client_identity(b"test-cert".to_vec(), b"test-key".to_vec())
            .with_domain_name("example.com");

        // Verification stays enabled unless explicitly turned off.
        assert!(!cfg.insecure_skip_verify);
        assert!(cfg.ca_cert_pem.is_some());
        assert!(cfg.client_cert_pem.is_some());
        assert!(cfg.client_key_pem.is_some());
        assert_eq!(cfg.domain_name.as_deref(), Some("example.com"));
    }

    #[test]
    fn test_grpc_tls_config_insecure() {
        let cfg = GrpcTlsConfig::new().with_insecure_skip_verify();

        assert!(cfg.ca_cert_pem.is_none());
        assert!(cfg.insecure_skip_verify);
    }

    #[test]
    fn test_http_tls_config_builder() {
        let cfg = HttpTlsConfig::new()
            .with_ca_cert_pem(b"test-ca-cert".to_vec())
            .with_client_identity(b"test-cert".to_vec(), b"test-key".to_vec());

        // Verification stays enabled unless explicitly turned off.
        assert!(!cfg.insecure_skip_verify);
        assert!(cfg.ca_cert_pem.is_some());
        assert!(cfg.client_cert_pem.is_some());
        assert!(cfg.client_key_pem.is_some());
    }

    #[test]
    fn test_http_tls_config_insecure() {
        let cfg = HttpTlsConfig::new().with_insecure_skip_verify();

        assert!(cfg.ca_cert_pem.is_none());
        assert!(cfg.insecure_skip_verify);
    }

    #[tokio::test]
    async fn test_http_client_creation() {
        // Creating the client must not need a reachable server: no connection
        // is made until the first request is sent.
        let created = AgentClient::http(
            "test-agent",
            "http://localhost:9999/agent",
            Duration::from_secs(5),
        )
        .await;

        assert!(created.is_ok());
    }
}