sentinel_agent_protocol/
client.rs

1//! Agent client for communicating with external agents.
2//!
3//! Supports three transport mechanisms:
4//! - Unix domain sockets (length-prefixed JSON)
5//! - gRPC (Protocol Buffers over HTTP/2, with optional TLS)
6//! - HTTP REST (JSON over HTTP/1.1 or HTTP/2, with optional TLS)
7
8use serde::Serialize;
9use std::path::Path;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::UnixStream;
14use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
15use tracing::{debug, error, trace, warn};
16
17use crate::errors::AgentProtocolError;
18use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
19use crate::protocol::{
20    AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
21    RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
22    ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
23    MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
24};
25
26/// TLS configuration for gRPC agent connections
27#[derive(Debug, Clone, Default)]
28pub struct GrpcTlsConfig {
29    /// Skip certificate verification (DANGEROUS - only for testing)
30    pub insecure_skip_verify: bool,
31    /// CA certificate PEM data for verifying the server
32    pub ca_cert_pem: Option<Vec<u8>>,
33    /// Client certificate PEM data for mTLS
34    pub client_cert_pem: Option<Vec<u8>>,
35    /// Client key PEM data for mTLS
36    pub client_key_pem: Option<Vec<u8>>,
37    /// Domain name to use for TLS SNI and certificate validation
38    pub domain_name: Option<String>,
39}
40
41impl GrpcTlsConfig {
42    /// Create a new TLS config builder
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Load CA certificate from a file
48    pub async fn with_ca_cert_file(mut self, path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
49        self.ca_cert_pem = Some(tokio::fs::read(path).await?);
50        Ok(self)
51    }
52
53    /// Set CA certificate from PEM data
54    pub fn with_ca_cert_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
55        self.ca_cert_pem = Some(pem.into());
56        self
57    }
58
59    /// Load client certificate and key from files (for mTLS)
60    pub async fn with_client_cert_files(
61        mut self,
62        cert_path: impl AsRef<Path>,
63        key_path: impl AsRef<Path>,
64    ) -> Result<Self, std::io::Error> {
65        self.client_cert_pem = Some(tokio::fs::read(cert_path).await?);
66        self.client_key_pem = Some(tokio::fs::read(key_path).await?);
67        Ok(self)
68    }
69
70    /// Set client certificate and key from PEM data (for mTLS)
71    pub fn with_client_identity(mut self, cert_pem: impl Into<Vec<u8>>, key_pem: impl Into<Vec<u8>>) -> Self {
72        self.client_cert_pem = Some(cert_pem.into());
73        self.client_key_pem = Some(key_pem.into());
74        self
75    }
76
77    /// Set the domain name for TLS SNI and certificate validation
78    pub fn with_domain_name(mut self, domain: impl Into<String>) -> Self {
79        self.domain_name = Some(domain.into());
80        self
81    }
82
83    /// Skip certificate verification (DANGEROUS - only for testing)
84    pub fn with_insecure_skip_verify(mut self) -> Self {
85        self.insecure_skip_verify = true;
86        self
87    }
88}
89
90/// TLS configuration for HTTP agent connections
91#[derive(Debug, Clone, Default)]
92pub struct HttpTlsConfig {
93    /// Skip certificate verification (DANGEROUS - only for testing)
94    pub insecure_skip_verify: bool,
95    /// CA certificate PEM data for verifying the server
96    pub ca_cert_pem: Option<Vec<u8>>,
97    /// Client certificate PEM data for mTLS
98    pub client_cert_pem: Option<Vec<u8>>,
99    /// Client key PEM data for mTLS
100    pub client_key_pem: Option<Vec<u8>>,
101}
102
103impl HttpTlsConfig {
104    /// Create a new TLS config builder
105    pub fn new() -> Self {
106        Self::default()
107    }
108
109    /// Load CA certificate from a file
110    pub async fn with_ca_cert_file(mut self, path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
111        self.ca_cert_pem = Some(tokio::fs::read(path).await?);
112        Ok(self)
113    }
114
115    /// Set CA certificate from PEM data
116    pub fn with_ca_cert_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
117        self.ca_cert_pem = Some(pem.into());
118        self
119    }
120
121    /// Load client certificate and key from files (for mTLS)
122    pub async fn with_client_cert_files(
123        mut self,
124        cert_path: impl AsRef<Path>,
125        key_path: impl AsRef<Path>,
126    ) -> Result<Self, std::io::Error> {
127        self.client_cert_pem = Some(tokio::fs::read(cert_path).await?);
128        self.client_key_pem = Some(tokio::fs::read(key_path).await?);
129        Ok(self)
130    }
131
132    /// Set client certificate and key from PEM data (for mTLS)
133    pub fn with_client_identity(mut self, cert_pem: impl Into<Vec<u8>>, key_pem: impl Into<Vec<u8>>) -> Self {
134        self.client_cert_pem = Some(cert_pem.into());
135        self.client_key_pem = Some(key_pem.into());
136        self
137    }
138
139    /// Skip certificate verification (DANGEROUS - only for testing)
140    pub fn with_insecure_skip_verify(mut self) -> Self {
141        self.insecure_skip_verify = true;
142        self
143    }
144}
145
/// HTTP connection details.
///
/// Pairs the `reqwest` client built by the HTTP constructors with the URL
/// that events are POSTed to. Stored behind an `Arc` inside
/// `AgentConnection::Http`.
struct HttpConnection {
    /// HTTP client (carries the timeout and any TLS settings chosen at construction)
    client: reqwest::Client,
    /// Base URL for the agent endpoint; used verbatim as the POST target
    url: String,
}
153
/// Agent client for communicating with external agents.
///
/// Wraps exactly one transport (Unix socket, gRPC, or HTTP) chosen by the
/// constructor that created it.
pub struct AgentClient {
    /// Agent ID (used as the `agent_id` field in log events)
    id: String,
    /// Connection to agent
    connection: AgentConnection,
    /// Timeout for agent calls
    timeout: Duration,
    /// Maximum retries.
    ///
    /// Every constructor in this file sets this to 3; no code in the reviewed
    /// portion of the file reads it, hence the `dead_code` allowance.
    #[allow(dead_code)]
    max_retries: u32,
}
166
/// Agent connection type.
///
/// One variant per supported transport; `AgentClient::send_event` matches on
/// this to pick the transport-specific sender.
enum AgentConnection {
    /// Connected Unix domain socket stream
    UnixSocket(UnixStream),
    /// Generated gRPC client over a tonic `Channel` (plain or TLS)
    Grpc(AgentProcessorClient<Channel>),
    /// Shared HTTP client plus endpoint URL
    Http(Arc<HttpConnection>),
}
173
174impl AgentClient {
175    /// Create a new Unix socket agent client
176    pub async fn unix_socket(
177        id: impl Into<String>,
178        path: impl AsRef<std::path::Path>,
179        timeout: Duration,
180    ) -> Result<Self, AgentProtocolError> {
181        let id = id.into();
182        let path = path.as_ref();
183
184        trace!(
185            agent_id = %id,
186            socket_path = %path.display(),
187            timeout_ms = timeout.as_millis() as u64,
188            "Connecting to agent via Unix socket"
189        );
190
191        let stream = UnixStream::connect(path).await.map_err(|e| {
192            error!(
193                agent_id = %id,
194                socket_path = %path.display(),
195                error = %e,
196                "Failed to connect to agent via Unix socket"
197            );
198            AgentProtocolError::ConnectionFailed(e.to_string())
199        })?;
200
201        debug!(
202            agent_id = %id,
203            socket_path = %path.display(),
204            "Connected to agent via Unix socket"
205        );
206
207        Ok(Self {
208            id,
209            connection: AgentConnection::UnixSocket(stream),
210            timeout,
211            max_retries: 3,
212        })
213    }
214
215    /// Create a new gRPC agent client
216    ///
217    /// # Arguments
218    /// * `id` - Agent identifier
219    /// * `address` - gRPC server address (e.g., "http://localhost:50051")
220    /// * `timeout` - Timeout for agent calls
221    pub async fn grpc(
222        id: impl Into<String>,
223        address: impl Into<String>,
224        timeout: Duration,
225    ) -> Result<Self, AgentProtocolError> {
226        let id = id.into();
227        let address = address.into();
228
229        trace!(
230            agent_id = %id,
231            address = %address,
232            timeout_ms = timeout.as_millis() as u64,
233            "Connecting to agent via gRPC"
234        );
235
236        let channel = Channel::from_shared(address.clone())
237            .map_err(|e| {
238                error!(
239                    agent_id = %id,
240                    address = %address,
241                    error = %e,
242                    "Invalid gRPC URI"
243                );
244                AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
245            })?
246            .timeout(timeout)
247            .connect()
248            .await
249            .map_err(|e| {
250                error!(
251                    agent_id = %id,
252                    address = %address,
253                    error = %e,
254                    "Failed to connect to agent via gRPC"
255                );
256                AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
257            })?;
258
259        let client = AgentProcessorClient::new(channel);
260
261        debug!(
262            agent_id = %id,
263            address = %address,
264            "Connected to agent via gRPC"
265        );
266
267        Ok(Self {
268            id,
269            connection: AgentConnection::Grpc(client),
270            timeout,
271            max_retries: 3,
272        })
273    }
274
275    /// Create a new gRPC agent client with TLS
276    ///
277    /// # Arguments
278    /// * `id` - Agent identifier
279    /// * `address` - gRPC server address (e.g., "https://localhost:50051")
280    /// * `timeout` - Timeout for agent calls
281    /// * `tls_config` - TLS configuration
282    pub async fn grpc_tls(
283        id: impl Into<String>,
284        address: impl Into<String>,
285        timeout: Duration,
286        tls_config: GrpcTlsConfig,
287    ) -> Result<Self, AgentProtocolError> {
288        let id = id.into();
289        let address = address.into();
290
291        trace!(
292            agent_id = %id,
293            address = %address,
294            timeout_ms = timeout.as_millis() as u64,
295            has_ca_cert = tls_config.ca_cert_pem.is_some(),
296            has_client_cert = tls_config.client_cert_pem.is_some(),
297            insecure = tls_config.insecure_skip_verify,
298            "Connecting to agent via gRPC with TLS"
299        );
300
301        // Build TLS config
302        let mut client_tls_config = ClientTlsConfig::new();
303
304        // Set domain name for SNI if provided, otherwise extract from address
305        if let Some(domain) = &tls_config.domain_name {
306            client_tls_config = client_tls_config.domain_name(domain.clone());
307        } else {
308            // Try to extract domain from address URL
309            if let Some(domain) = Self::extract_domain(&address) {
310                client_tls_config = client_tls_config.domain_name(domain);
311            }
312        }
313
314        // Add CA certificate if provided
315        if let Some(ca_pem) = &tls_config.ca_cert_pem {
316            let ca_cert = Certificate::from_pem(ca_pem);
317            client_tls_config = client_tls_config.ca_certificate(ca_cert);
318            debug!(
319                agent_id = %id,
320                "Using custom CA certificate for gRPC TLS"
321            );
322        }
323
324        // Add client identity for mTLS if provided
325        if let (Some(cert_pem), Some(key_pem)) = (&tls_config.client_cert_pem, &tls_config.client_key_pem) {
326            let identity = Identity::from_pem(cert_pem, key_pem);
327            client_tls_config = client_tls_config.identity(identity);
328            debug!(
329                agent_id = %id,
330                "Using client certificate for mTLS to gRPC agent"
331            );
332        }
333
334        // Handle insecure skip verify (dangerous - only for testing)
335        if tls_config.insecure_skip_verify {
336            warn!(
337                agent_id = %id,
338                address = %address,
339                "SECURITY WARNING: TLS certificate verification disabled for gRPC agent connection"
340            );
341            // Note: tonic doesn't have a direct "skip verify" option like some other libraries
342            // The best we can do is use a permissive TLS config. For truly insecure connections,
343            // users should use the non-TLS grpc() method instead.
344        }
345
346        // Build channel with TLS
347        let channel = Channel::from_shared(address.clone())
348            .map_err(|e| {
349                error!(
350                    agent_id = %id,
351                    address = %address,
352                    error = %e,
353                    "Invalid gRPC URI"
354                );
355                AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
356            })?
357            .tls_config(client_tls_config)
358            .map_err(|e| {
359                error!(
360                    agent_id = %id,
361                    address = %address,
362                    error = %e,
363                    "Invalid TLS configuration"
364                );
365                AgentProtocolError::ConnectionFailed(format!("TLS config error: {}", e))
366            })?
367            .timeout(timeout)
368            .connect()
369            .await
370            .map_err(|e| {
371                error!(
372                    agent_id = %id,
373                    address = %address,
374                    error = %e,
375                    "Failed to connect to agent via gRPC with TLS"
376                );
377                AgentProtocolError::ConnectionFailed(format!("gRPC TLS connect failed: {}", e))
378            })?;
379
380        let client = AgentProcessorClient::new(channel);
381
382        debug!(
383            agent_id = %id,
384            address = %address,
385            "Connected to agent via gRPC with TLS"
386        );
387
388        Ok(Self {
389            id,
390            connection: AgentConnection::Grpc(client),
391            timeout,
392            max_retries: 3,
393        })
394    }
395
396    /// Extract domain name from a URL for TLS SNI
397    fn extract_domain(address: &str) -> Option<String> {
398        // Try to parse as URL and extract host
399        let address = address.trim();
400
401        // Handle URLs like "https://example.com:443" or "http://example.com:8080"
402        if let Some(rest) = address.strip_prefix("https://").or_else(|| address.strip_prefix("http://")) {
403            // Split off port and path
404            let host = rest.split(':').next()?.split('/').next()?;
405            if !host.is_empty() {
406                return Some(host.to_string());
407            }
408        }
409
410        None
411    }
412
413    /// Create a new HTTP agent client
414    ///
415    /// # Arguments
416    /// * `id` - Agent identifier
417    /// * `url` - HTTP endpoint URL (e.g., "http://localhost:8080/agent")
418    /// * `timeout` - Timeout for agent calls
419    pub async fn http(
420        id: impl Into<String>,
421        url: impl Into<String>,
422        timeout: Duration,
423    ) -> Result<Self, AgentProtocolError> {
424        let id = id.into();
425        let url = url.into();
426
427        trace!(
428            agent_id = %id,
429            url = %url,
430            timeout_ms = timeout.as_millis() as u64,
431            "Creating HTTP agent client"
432        );
433
434        let client = reqwest::Client::builder()
435            .timeout(timeout)
436            .build()
437            .map_err(|e| {
438                error!(
439                    agent_id = %id,
440                    url = %url,
441                    error = %e,
442                    "Failed to create HTTP client"
443                );
444                AgentProtocolError::ConnectionFailed(format!("HTTP client error: {}", e))
445            })?;
446
447        debug!(
448            agent_id = %id,
449            url = %url,
450            "HTTP agent client created"
451        );
452
453        Ok(Self {
454            id,
455            connection: AgentConnection::Http(Arc::new(HttpConnection { client, url })),
456            timeout,
457            max_retries: 3,
458        })
459    }
460
461    /// Create a new HTTP agent client with TLS
462    ///
463    /// # Arguments
464    /// * `id` - Agent identifier
465    /// * `url` - HTTPS endpoint URL (e.g., "https://agent.internal:8443/agent")
466    /// * `timeout` - Timeout for agent calls
467    /// * `tls_config` - TLS configuration
468    pub async fn http_tls(
469        id: impl Into<String>,
470        url: impl Into<String>,
471        timeout: Duration,
472        tls_config: HttpTlsConfig,
473    ) -> Result<Self, AgentProtocolError> {
474        let id = id.into();
475        let url = url.into();
476
477        trace!(
478            agent_id = %id,
479            url = %url,
480            timeout_ms = timeout.as_millis() as u64,
481            has_ca_cert = tls_config.ca_cert_pem.is_some(),
482            has_client_cert = tls_config.client_cert_pem.is_some(),
483            insecure = tls_config.insecure_skip_verify,
484            "Creating HTTP agent client with TLS"
485        );
486
487        let mut client_builder = reqwest::Client::builder()
488            .timeout(timeout)
489            .use_rustls_tls();
490
491        // Add CA certificate if provided
492        if let Some(ca_pem) = &tls_config.ca_cert_pem {
493            let ca_cert = reqwest::Certificate::from_pem(ca_pem).map_err(|e| {
494                error!(
495                    agent_id = %id,
496                    error = %e,
497                    "Failed to parse CA certificate"
498                );
499                AgentProtocolError::ConnectionFailed(format!("Invalid CA certificate: {}", e))
500            })?;
501            client_builder = client_builder.add_root_certificate(ca_cert);
502            debug!(
503                agent_id = %id,
504                "Using custom CA certificate for HTTP TLS"
505            );
506        }
507
508        // Add client identity for mTLS if provided
509        if let (Some(cert_pem), Some(key_pem)) = (&tls_config.client_cert_pem, &tls_config.client_key_pem) {
510            // Combine cert and key into identity PEM
511            let mut identity_pem = cert_pem.clone();
512            identity_pem.extend_from_slice(b"\n");
513            identity_pem.extend_from_slice(key_pem);
514
515            let identity = reqwest::Identity::from_pem(&identity_pem).map_err(|e| {
516                error!(
517                    agent_id = %id,
518                    error = %e,
519                    "Failed to parse client certificate/key"
520                );
521                AgentProtocolError::ConnectionFailed(format!("Invalid client certificate: {}", e))
522            })?;
523            client_builder = client_builder.identity(identity);
524            debug!(
525                agent_id = %id,
526                "Using client certificate for mTLS to HTTP agent"
527            );
528        }
529
530        // Handle insecure skip verify (dangerous - only for testing)
531        if tls_config.insecure_skip_verify {
532            warn!(
533                agent_id = %id,
534                url = %url,
535                "SECURITY WARNING: TLS certificate verification disabled for HTTP agent connection"
536            );
537            client_builder = client_builder.danger_accept_invalid_certs(true);
538        }
539
540        let client = client_builder.build().map_err(|e| {
541            error!(
542                agent_id = %id,
543                url = %url,
544                error = %e,
545                "Failed to create HTTP TLS client"
546            );
547            AgentProtocolError::ConnectionFailed(format!("HTTP TLS client error: {}", e))
548        })?;
549
550        debug!(
551            agent_id = %id,
552            url = %url,
553            "HTTP agent client created with TLS"
554        );
555
556        Ok(Self {
557            id,
558            connection: AgentConnection::Http(Arc::new(HttpConnection { client, url })),
559            timeout,
560            max_retries: 3,
561        })
562    }
563
564    /// Get the agent ID
565    #[allow(dead_code)]
566    pub fn id(&self) -> &str {
567        &self.id
568    }
569
570    /// Send an event to the agent and get a response
571    pub async fn send_event(
572        &mut self,
573        event_type: EventType,
574        payload: impl Serialize,
575    ) -> Result<AgentResponse, AgentProtocolError> {
576        // Clone HTTP connection Arc before match to avoid borrow issues
577        let http_conn = if let AgentConnection::Http(conn) = &self.connection {
578            Some(Arc::clone(conn))
579        } else {
580            None
581        };
582
583        match &mut self.connection {
584            AgentConnection::UnixSocket(_) => {
585                self.send_event_unix_socket(event_type, payload).await
586            }
587            AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
588            AgentConnection::Http(_) => {
589                // Use the cloned Arc
590                self.send_event_http(http_conn.unwrap(), event_type, payload).await
591            }
592        }
593    }
594
595    /// Send event via Unix socket (length-prefixed JSON)
596    async fn send_event_unix_socket(
597        &mut self,
598        event_type: EventType,
599        payload: impl Serialize,
600    ) -> Result<AgentResponse, AgentProtocolError> {
601        let request = AgentRequest {
602            version: PROTOCOL_VERSION,
603            event_type,
604            payload: serde_json::to_value(payload)
605                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
606        };
607
608        // Serialize request
609        let request_bytes = serde_json::to_vec(&request)
610            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
611
612        // Check message size
613        if request_bytes.len() > MAX_MESSAGE_SIZE {
614            return Err(AgentProtocolError::MessageTooLarge {
615                size: request_bytes.len(),
616                max: MAX_MESSAGE_SIZE,
617            });
618        }
619
620        // Send with timeout
621        let response = tokio::time::timeout(self.timeout, async {
622            self.send_raw_unix(&request_bytes).await?;
623            self.receive_raw_unix().await
624        })
625        .await
626        .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
627
628        // Parse response
629        let agent_response: AgentResponse = serde_json::from_slice(&response)
630            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
631
632        // Verify protocol version
633        if agent_response.version != PROTOCOL_VERSION {
634            return Err(AgentProtocolError::VersionMismatch {
635                expected: PROTOCOL_VERSION,
636                actual: agent_response.version,
637            });
638        }
639
640        Ok(agent_response)
641    }
642
643    /// Send event via gRPC
644    async fn send_event_grpc(
645        &mut self,
646        event_type: EventType,
647        payload: impl Serialize,
648    ) -> Result<AgentResponse, AgentProtocolError> {
649        // Build request first (doesn't need mutable borrow)
650        let grpc_request = Self::build_grpc_request(event_type, payload)?;
651
652        let AgentConnection::Grpc(client) = &mut self.connection else {
653            return Err(AgentProtocolError::WrongConnectionType(
654                "Expected gRPC connection but found Unix socket".to_string()
655            ));
656        };
657
658        // Send with timeout
659        let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
660            .await
661            .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
662            .map_err(|e| {
663                AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
664            })?;
665
666        // Convert gRPC response to internal format
667        Self::convert_grpc_response(response.into_inner())
668    }
669
670    /// Send event via HTTP POST (JSON)
671    async fn send_event_http(
672        &self,
673        conn: Arc<HttpConnection>,
674        event_type: EventType,
675        payload: impl Serialize,
676    ) -> Result<AgentResponse, AgentProtocolError> {
677        let request = AgentRequest {
678            version: PROTOCOL_VERSION,
679            event_type,
680            payload: serde_json::to_value(payload)
681                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
682        };
683
684        // Serialize request
685        let request_json = serde_json::to_string(&request)
686            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
687
688        // Check message size
689        if request_json.len() > MAX_MESSAGE_SIZE {
690            return Err(AgentProtocolError::MessageTooLarge {
691                size: request_json.len(),
692                max: MAX_MESSAGE_SIZE,
693            });
694        }
695
696        trace!(
697            agent_id = %self.id,
698            url = %conn.url,
699            event_type = ?event_type,
700            request_size = request_json.len(),
701            "Sending HTTP request to agent"
702        );
703
704        // Send HTTP POST request
705        let response = conn
706            .client
707            .post(&conn.url)
708            .header("Content-Type", "application/json")
709            .header("X-Sentinel-Protocol-Version", PROTOCOL_VERSION.to_string())
710            .body(request_json)
711            .send()
712            .await
713            .map_err(|e| {
714                error!(
715                    agent_id = %self.id,
716                    url = %conn.url,
717                    error = %e,
718                    "HTTP request to agent failed"
719                );
720                if e.is_timeout() {
721                    AgentProtocolError::Timeout(self.timeout)
722                } else if e.is_connect() {
723                    AgentProtocolError::ConnectionFailed(format!("HTTP connect failed: {}", e))
724                } else {
725                    AgentProtocolError::ConnectionFailed(format!("HTTP request failed: {}", e))
726                }
727            })?;
728
729        // Check HTTP status
730        let status = response.status();
731        if !status.is_success() {
732            let body = response.text().await.unwrap_or_default();
733            error!(
734                agent_id = %self.id,
735                url = %conn.url,
736                status = %status,
737                body = %body,
738                "Agent returned HTTP error"
739            );
740            return Err(AgentProtocolError::ConnectionFailed(format!(
741                "HTTP {} from agent: {}",
742                status,
743                body
744            )));
745        }
746
747        // Parse response
748        let response_bytes = response.bytes().await.map_err(|e| {
749            AgentProtocolError::ConnectionFailed(format!("Failed to read response body: {}", e))
750        })?;
751
752        // Check response size
753        if response_bytes.len() > MAX_MESSAGE_SIZE {
754            return Err(AgentProtocolError::MessageTooLarge {
755                size: response_bytes.len(),
756                max: MAX_MESSAGE_SIZE,
757            });
758        }
759
760        let agent_response: AgentResponse = serde_json::from_slice(&response_bytes)
761            .map_err(|e| AgentProtocolError::InvalidMessage(format!("Invalid JSON response: {}", e)))?;
762
763        // Verify protocol version
764        if agent_response.version != PROTOCOL_VERSION {
765            return Err(AgentProtocolError::VersionMismatch {
766                expected: PROTOCOL_VERSION,
767                actual: agent_response.version,
768            });
769        }
770
771        trace!(
772            agent_id = %self.id,
773            decision = ?agent_response.decision,
774            "Received HTTP response from agent"
775        );
776
777        Ok(agent_response)
778    }
779
780    /// Build a gRPC request from internal types
781    fn build_grpc_request(
782        event_type: EventType,
783        payload: impl Serialize,
784    ) -> Result<grpc::AgentRequest, AgentProtocolError> {
785        let payload_json = serde_json::to_value(&payload)
786            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
787
788        let grpc_event_type = match event_type {
789            EventType::Configure => {
790                return Err(AgentProtocolError::Serialization(
791                    "Configure events are not supported via gRPC".to_string(),
792                ))
793            }
794            EventType::RequestHeaders => grpc::EventType::RequestHeaders,
795            EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
796            EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
797            EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
798            EventType::RequestComplete => grpc::EventType::RequestComplete,
799            EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
800            EventType::GuardrailInspect => {
801                return Err(AgentProtocolError::Serialization(
802                    "GuardrailInspect events are not yet supported via gRPC".to_string(),
803                ))
804            }
805        };
806
807        let event = match event_type {
808            EventType::Configure => {
809                return Err(AgentProtocolError::InvalidMessage(
810                    "Configure event should be handled separately".to_string()
811                ));
812            },
813            EventType::RequestHeaders => {
814                let event: RequestHeadersEvent = serde_json::from_value(payload_json)
815                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
816                grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
817                    metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
818                    method: event.method,
819                    uri: event.uri,
820                    headers: event
821                        .headers
822                        .into_iter()
823                        .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
824                        .collect(),
825                })
826            }
827            EventType::RequestBodyChunk => {
828                let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
829                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
830                grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
831                    correlation_id: event.correlation_id,
832                    data: event.data.into_bytes(),
833                    is_last: event.is_last,
834                    total_size: event.total_size.map(|s| s as u64),
835                    chunk_index: event.chunk_index,
836                    bytes_received: event.bytes_received as u64,
837                })
838            }
839            EventType::ResponseHeaders => {
840                let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
841                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
842                grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
843                    correlation_id: event.correlation_id,
844                    status: event.status as u32,
845                    headers: event
846                        .headers
847                        .into_iter()
848                        .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
849                        .collect(),
850                })
851            }
852            EventType::ResponseBodyChunk => {
853                let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
854                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
855                grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
856                    correlation_id: event.correlation_id,
857                    data: event.data.into_bytes(),
858                    is_last: event.is_last,
859                    total_size: event.total_size.map(|s| s as u64),
860                    chunk_index: event.chunk_index,
861                    bytes_sent: event.bytes_sent as u64,
862                })
863            }
864            EventType::RequestComplete => {
865                let event: RequestCompleteEvent = serde_json::from_value(payload_json)
866                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
867                grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
868                    correlation_id: event.correlation_id,
869                    status: event.status as u32,
870                    duration_ms: event.duration_ms,
871                    request_body_size: event.request_body_size as u64,
872                    response_body_size: event.response_body_size as u64,
873                    upstream_attempts: event.upstream_attempts,
874                    error: event.error,
875                })
876            }
877            EventType::WebSocketFrame => {
878                use base64::{engine::general_purpose::STANDARD, Engine as _};
879                let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
880                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
881                grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
882                    correlation_id: event.correlation_id,
883                    opcode: event.opcode,
884                    data: STANDARD.decode(&event.data).unwrap_or_default(),
885                    client_to_server: event.client_to_server,
886                    frame_index: event.frame_index,
887                    fin: event.fin,
888                    route_id: event.route_id,
889                    client_ip: event.client_ip,
890                })
891            }
892            EventType::GuardrailInspect => {
893                return Err(AgentProtocolError::InvalidMessage(
894                    "GuardrailInspect events are not yet supported via gRPC".to_string()
895                ));
896            }
897        };
898
899        Ok(grpc::AgentRequest {
900            version: PROTOCOL_VERSION,
901            event_type: grpc_event_type as i32,
902            event: Some(event),
903        })
904    }
905
906    /// Convert internal metadata to gRPC format
907    fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
908        grpc::RequestMetadata {
909            correlation_id: metadata.correlation_id.clone(),
910            request_id: metadata.request_id.clone(),
911            client_ip: metadata.client_ip.clone(),
912            client_port: metadata.client_port as u32,
913            server_name: metadata.server_name.clone(),
914            protocol: metadata.protocol.clone(),
915            tls_version: metadata.tls_version.clone(),
916            tls_cipher: metadata.tls_cipher.clone(),
917            route_id: metadata.route_id.clone(),
918            upstream_id: metadata.upstream_id.clone(),
919            timestamp: metadata.timestamp.clone(),
920            traceparent: metadata.traceparent.clone(),
921        }
922    }
923
924    /// Convert gRPC response to internal format
925    fn convert_grpc_response(
926        response: grpc::AgentResponse,
927    ) -> Result<AgentResponse, AgentProtocolError> {
928        let decision = match response.decision {
929            Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
930            Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
931                status: b.status as u16,
932                body: b.body,
933                headers: if b.headers.is_empty() {
934                    None
935                } else {
936                    Some(b.headers)
937                },
938            },
939            Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
940                url: r.url,
941                status: r.status as u16,
942            },
943            Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
944                challenge_type: c.challenge_type,
945                params: c.params,
946            },
947            None => Decision::Allow, // Default to allow if no decision
948        };
949
950        let request_headers: Vec<HeaderOp> = response
951            .request_headers
952            .into_iter()
953            .filter_map(Self::convert_header_op_from_grpc)
954            .collect();
955
956        let response_headers: Vec<HeaderOp> = response
957            .response_headers
958            .into_iter()
959            .filter_map(Self::convert_header_op_from_grpc)
960            .collect();
961
962        let audit = response.audit.map(|a| AuditMetadata {
963            tags: a.tags,
964            rule_ids: a.rule_ids,
965            confidence: a.confidence,
966            reason_codes: a.reason_codes,
967            custom: a
968                .custom
969                .into_iter()
970                .map(|(k, v)| (k, serde_json::Value::String(v)))
971                .collect(),
972        });
973
974        // Convert body mutations
975        let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
976            data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
977            chunk_index: m.chunk_index,
978        });
979
980        let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
981            data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
982            chunk_index: m.chunk_index,
983        });
984
985        // Convert WebSocket decision
986        let websocket_decision = response
987            .websocket_decision
988            .map(|ws_decision| match ws_decision {
989                grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
990                    WebSocketDecision::Allow
991                }
992                grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
993                    WebSocketDecision::Drop
994                }
995                grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
996                    WebSocketDecision::Close {
997                        code: c.code as u16,
998                        reason: c.reason,
999                    }
1000                }
1001            });
1002
1003        Ok(AgentResponse {
1004            version: response.version,
1005            decision,
1006            request_headers,
1007            response_headers,
1008            routing_metadata: response.routing_metadata,
1009            audit: audit.unwrap_or_default(),
1010            needs_more: response.needs_more,
1011            request_body_mutation,
1012            response_body_mutation,
1013            websocket_decision,
1014        })
1015    }
1016
1017    /// Convert gRPC header operation to internal format
1018    fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
1019        match op.operation? {
1020            grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
1021                name: s.name,
1022                value: s.value,
1023            }),
1024            grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
1025                name: a.name,
1026                value: a.value,
1027            }),
1028            grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
1029        }
1030    }
1031
1032    /// Send raw bytes to agent (Unix socket only)
1033    async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
1034        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
1035            return Err(AgentProtocolError::WrongConnectionType(
1036                "Expected Unix socket connection but found gRPC".to_string()
1037            ));
1038        };
1039        // Write message length (4 bytes, big-endian)
1040        let len_bytes = (data.len() as u32).to_be_bytes();
1041        stream.write_all(&len_bytes).await?;
1042        // Write message data
1043        stream.write_all(data).await?;
1044        stream.flush().await?;
1045        Ok(())
1046    }
1047
1048    /// Receive raw bytes from agent (Unix socket only)
1049    async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
1050        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
1051            return Err(AgentProtocolError::WrongConnectionType(
1052                "Expected Unix socket connection but found gRPC".to_string()
1053            ));
1054        };
1055        // Read message length (4 bytes, big-endian)
1056        let mut len_bytes = [0u8; 4];
1057        stream.read_exact(&mut len_bytes).await?;
1058        let message_len = u32::from_be_bytes(len_bytes) as usize;
1059
1060        // Check message size
1061        if message_len > MAX_MESSAGE_SIZE {
1062            return Err(AgentProtocolError::MessageTooLarge {
1063                size: message_len,
1064                max: MAX_MESSAGE_SIZE,
1065            });
1066        }
1067
1068        // Read message data
1069        let mut buffer = vec![0u8; message_len];
1070        stream.read_exact(&mut buffer).await?;
1071        Ok(buffer)
1072    }
1073
1074    /// Close the agent connection
1075    pub async fn close(self) -> Result<(), AgentProtocolError> {
1076        match self.connection {
1077            AgentConnection::UnixSocket(mut stream) => {
1078                stream.shutdown().await?;
1079                Ok(())
1080            }
1081            AgentConnection::Grpc(_) => Ok(()), // gRPC channels close automatically
1082            AgentConnection::Http(_) => Ok(()),  // HTTP clients are stateless, no cleanup needed
1083        }
1084    }
1085}
1086
#[cfg(test)]
mod tests {
    use super::*;

    /// Check that `extract_domain` returns `Some(host)` for every `(url, host)` pair.
    fn assert_domains(cases: &[(&str, &str)]) {
        for &(url, host) in cases {
            assert_eq!(
                AgentClient::extract_domain(url),
                Some(host.to_string()),
                "url: {url}"
            );
        }
    }

    #[test]
    fn test_extract_domain_https() {
        assert_domains(&[
            ("https://example.com:443", "example.com"),
            ("https://agent.internal:50051", "agent.internal"),
            ("https://localhost:8080/path", "localhost"),
        ]);
    }

    #[test]
    fn test_extract_domain_http() {
        assert_domains(&[
            ("http://example.com:8080", "example.com"),
            ("http://localhost:50051", "localhost"),
        ]);
    }

    #[test]
    fn test_extract_domain_invalid() {
        // Missing scheme, unsupported scheme, and empty input all yield None.
        for url in ["example.com:443", "tcp://example.com:443", ""] {
            assert!(AgentClient::extract_domain(url).is_none(), "url: {url}");
        }
    }

    #[test]
    fn test_grpc_tls_config_builder() {
        let tls = GrpcTlsConfig::new()
            .with_ca_cert_pem(b"test-ca-cert".to_vec())
            .with_client_identity(b"test-cert".to_vec(), b"test-key".to_vec())
            .with_domain_name("example.com");

        // Everything that was set is present...
        assert!(tls.ca_cert_pem.is_some());
        assert!(tls.client_cert_pem.is_some());
        assert!(tls.client_key_pem.is_some());
        assert_eq!(tls.domain_name.as_deref(), Some("example.com"));
        // ...and verification stays on unless explicitly disabled.
        assert!(!tls.insecure_skip_verify);
    }

    #[test]
    fn test_grpc_tls_config_insecure() {
        let tls = GrpcTlsConfig::new().with_insecure_skip_verify();

        assert!(tls.insecure_skip_verify);
        assert!(tls.ca_cert_pem.is_none());
    }

    #[test]
    fn test_http_tls_config_builder() {
        let tls = HttpTlsConfig::new()
            .with_ca_cert_pem(b"test-ca-cert".to_vec())
            .with_client_identity(b"test-cert".to_vec(), b"test-key".to_vec());

        assert!(tls.ca_cert_pem.is_some());
        assert!(tls.client_cert_pem.is_some());
        assert!(tls.client_key_pem.is_some());
        assert!(!tls.insecure_skip_verify);
    }

    #[test]
    fn test_http_tls_config_insecure() {
        let tls = HttpTlsConfig::new().with_insecure_skip_verify();

        assert!(tls.insecure_skip_verify);
        assert!(tls.ca_cert_pem.is_none());
    }

    #[tokio::test]
    async fn test_http_client_creation() {
        // Building the client must not require a reachable agent
        // (connection happens on first request).
        let client = AgentClient::http(
            "test-agent",
            "http://localhost:9999/agent",
            Duration::from_secs(5),
        )
        .await;

        assert!(client.is_ok());
    }
}