1use serde::Serialize;
9use std::path::Path;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::UnixStream;
14use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
15use tracing::{debug, error, trace, warn};
16
17use crate::errors::AgentProtocolError;
18use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
19use crate::protocol::{
20 AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
21 RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
22 ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
23 MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
24};
25
/// TLS settings used by `AgentClient::grpc_tls` when connecting to a gRPC agent.
///
/// All PEM material is kept in memory. `Debug` is implemented by hand so that
/// formatting a config (for example in a log line) never prints the client
/// private key; certificates are summarised by length only.
#[derive(Clone, Default)]
pub struct GrpcTlsConfig {
    /// Requests that server certificate verification be skipped.
    ///
    /// NOTE: the gRPC connector only logs a warning when this is set; it does not
    /// configure any verification bypass on the tonic TLS config.
    pub insecure_skip_verify: bool,
    /// PEM-encoded CA certificate used to verify the agent's server certificate.
    pub ca_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client certificate for mTLS (only used together with `client_key_pem`).
    pub client_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client private key for mTLS (only used together with `client_cert_pem`).
    pub client_key_pem: Option<Vec<u8>>,
    /// TLS server name override; when unset, the connector tries to derive it from
    /// the agent address.
    pub domain_name: Option<String>,
}

impl std::fmt::Debug for GrpcTlsConfig {
    /// Formats the config without exposing key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrpcTlsConfig")
            .field("insecure_skip_verify", &self.insecure_skip_verify)
            .field(
                "ca_cert_pem",
                &self
                    .ca_cert_pem
                    .as_ref()
                    .map(|pem| format!("<{} bytes>", pem.len())),
            )
            .field(
                "client_cert_pem",
                &self
                    .client_cert_pem
                    .as_ref()
                    .map(|pem| format!("<{} bytes>", pem.len())),
            )
            // Never print the private key, not even its length.
            .field(
                "client_key_pem",
                &self.client_key_pem.as_ref().map(|_| "<redacted>"),
            )
            .field("domain_name", &self.domain_name)
            .finish()
    }
}
40
41impl GrpcTlsConfig {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub async fn with_ca_cert_file(
49 mut self,
50 path: impl AsRef<Path>,
51 ) -> Result<Self, std::io::Error> {
52 self.ca_cert_pem = Some(tokio::fs::read(path).await?);
53 Ok(self)
54 }
55
56 pub fn with_ca_cert_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
58 self.ca_cert_pem = Some(pem.into());
59 self
60 }
61
62 pub async fn with_client_cert_files(
64 mut self,
65 cert_path: impl AsRef<Path>,
66 key_path: impl AsRef<Path>,
67 ) -> Result<Self, std::io::Error> {
68 self.client_cert_pem = Some(tokio::fs::read(cert_path).await?);
69 self.client_key_pem = Some(tokio::fs::read(key_path).await?);
70 Ok(self)
71 }
72
73 pub fn with_client_identity(
75 mut self,
76 cert_pem: impl Into<Vec<u8>>,
77 key_pem: impl Into<Vec<u8>>,
78 ) -> Self {
79 self.client_cert_pem = Some(cert_pem.into());
80 self.client_key_pem = Some(key_pem.into());
81 self
82 }
83
84 pub fn with_domain_name(mut self, domain: impl Into<String>) -> Self {
86 self.domain_name = Some(domain.into());
87 self
88 }
89
90 pub fn with_insecure_skip_verify(mut self) -> Self {
92 self.insecure_skip_verify = true;
93 self
94 }
95}
96
/// TLS settings used by `AgentClient::http_tls` when connecting to an HTTP agent.
///
/// `Debug` is implemented by hand so that formatting a config never prints the
/// client private key; certificates are summarised by length only.
#[derive(Clone, Default)]
pub struct HttpTlsConfig {
    /// Accept invalid server certificates (passed to reqwest's
    /// `danger_accept_invalid_certs`). Use only for testing.
    pub insecure_skip_verify: bool,
    /// PEM-encoded CA certificate added as an extra trust root.
    pub ca_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client certificate for mTLS (only used together with `client_key_pem`).
    pub client_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client private key for mTLS (only used together with `client_cert_pem`).
    pub client_key_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for HttpTlsConfig {
    /// Formats the config without exposing key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpTlsConfig")
            .field("insecure_skip_verify", &self.insecure_skip_verify)
            .field(
                "ca_cert_pem",
                &self
                    .ca_cert_pem
                    .as_ref()
                    .map(|pem| format!("<{} bytes>", pem.len())),
            )
            .field(
                "client_cert_pem",
                &self
                    .client_cert_pem
                    .as_ref()
                    .map(|pem| format!("<{} bytes>", pem.len())),
            )
            // Never print the private key, not even its length.
            .field(
                "client_key_pem",
                &self.client_key_pem.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
109
110impl HttpTlsConfig {
111 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub async fn with_ca_cert_file(
118 mut self,
119 path: impl AsRef<Path>,
120 ) -> Result<Self, std::io::Error> {
121 self.ca_cert_pem = Some(tokio::fs::read(path).await?);
122 Ok(self)
123 }
124
125 pub fn with_ca_cert_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
127 self.ca_cert_pem = Some(pem.into());
128 self
129 }
130
131 pub async fn with_client_cert_files(
133 mut self,
134 cert_path: impl AsRef<Path>,
135 key_path: impl AsRef<Path>,
136 ) -> Result<Self, std::io::Error> {
137 self.client_cert_pem = Some(tokio::fs::read(cert_path).await?);
138 self.client_key_pem = Some(tokio::fs::read(key_path).await?);
139 Ok(self)
140 }
141
142 pub fn with_client_identity(
144 mut self,
145 cert_pem: impl Into<Vec<u8>>,
146 key_pem: impl Into<Vec<u8>>,
147 ) -> Self {
148 self.client_cert_pem = Some(cert_pem.into());
149 self.client_key_pem = Some(key_pem.into());
150 self
151 }
152
153 pub fn with_insecure_skip_verify(mut self) -> Self {
155 self.insecure_skip_verify = true;
156 self
157 }
158}
159
160struct HttpConnection {
162 client: reqwest::Client,
164 url: String,
166}
167
168pub struct AgentClient {
170 id: String,
172 connection: AgentConnection,
174 timeout: Duration,
176 #[allow(dead_code)]
178 max_retries: u32,
179}
180
181enum AgentConnection {
183 UnixSocket(UnixStream),
184 Grpc(AgentProcessorClient<Channel>),
185 Http(Arc<HttpConnection>),
186}
187
188impl AgentClient {
189 pub async fn unix_socket(
191 id: impl Into<String>,
192 path: impl AsRef<std::path::Path>,
193 timeout: Duration,
194 ) -> Result<Self, AgentProtocolError> {
195 let id = id.into();
196 let path = path.as_ref();
197
198 trace!(
199 agent_id = %id,
200 socket_path = %path.display(),
201 timeout_ms = timeout.as_millis() as u64,
202 "Connecting to agent via Unix socket"
203 );
204
205 let stream = UnixStream::connect(path).await.map_err(|e| {
206 error!(
207 agent_id = %id,
208 socket_path = %path.display(),
209 error = %e,
210 "Failed to connect to agent via Unix socket"
211 );
212 AgentProtocolError::ConnectionFailed(e.to_string())
213 })?;
214
215 debug!(
216 agent_id = %id,
217 socket_path = %path.display(),
218 "Connected to agent via Unix socket"
219 );
220
221 Ok(Self {
222 id,
223 connection: AgentConnection::UnixSocket(stream),
224 timeout,
225 max_retries: 3,
226 })
227 }
228
229 pub async fn grpc(
236 id: impl Into<String>,
237 address: impl Into<String>,
238 timeout: Duration,
239 ) -> Result<Self, AgentProtocolError> {
240 let id = id.into();
241 let address = address.into();
242
243 trace!(
244 agent_id = %id,
245 address = %address,
246 timeout_ms = timeout.as_millis() as u64,
247 "Connecting to agent via gRPC"
248 );
249
250 let channel = Channel::from_shared(address.clone())
251 .map_err(|e| {
252 error!(
253 agent_id = %id,
254 address = %address,
255 error = %e,
256 "Invalid gRPC URI"
257 );
258 AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
259 })?
260 .timeout(timeout)
261 .connect()
262 .await
263 .map_err(|e| {
264 error!(
265 agent_id = %id,
266 address = %address,
267 error = %e,
268 "Failed to connect to agent via gRPC"
269 );
270 AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
271 })?;
272
273 let client = AgentProcessorClient::new(channel);
274
275 debug!(
276 agent_id = %id,
277 address = %address,
278 "Connected to agent via gRPC"
279 );
280
281 Ok(Self {
282 id,
283 connection: AgentConnection::Grpc(client),
284 timeout,
285 max_retries: 3,
286 })
287 }
288
289 pub async fn grpc_tls(
297 id: impl Into<String>,
298 address: impl Into<String>,
299 timeout: Duration,
300 tls_config: GrpcTlsConfig,
301 ) -> Result<Self, AgentProtocolError> {
302 let id = id.into();
303 let address = address.into();
304
305 trace!(
306 agent_id = %id,
307 address = %address,
308 timeout_ms = timeout.as_millis() as u64,
309 has_ca_cert = tls_config.ca_cert_pem.is_some(),
310 has_client_cert = tls_config.client_cert_pem.is_some(),
311 insecure = tls_config.insecure_skip_verify,
312 "Connecting to agent via gRPC with TLS"
313 );
314
315 let mut client_tls_config = ClientTlsConfig::new();
317
318 if let Some(domain) = &tls_config.domain_name {
320 client_tls_config = client_tls_config.domain_name(domain.clone());
321 } else {
322 if let Some(domain) = Self::extract_domain(&address) {
324 client_tls_config = client_tls_config.domain_name(domain);
325 }
326 }
327
328 if let Some(ca_pem) = &tls_config.ca_cert_pem {
330 let ca_cert = Certificate::from_pem(ca_pem);
331 client_tls_config = client_tls_config.ca_certificate(ca_cert);
332 debug!(
333 agent_id = %id,
334 "Using custom CA certificate for gRPC TLS"
335 );
336 }
337
338 if let (Some(cert_pem), Some(key_pem)) =
340 (&tls_config.client_cert_pem, &tls_config.client_key_pem)
341 {
342 let identity = Identity::from_pem(cert_pem, key_pem);
343 client_tls_config = client_tls_config.identity(identity);
344 debug!(
345 agent_id = %id,
346 "Using client certificate for mTLS to gRPC agent"
347 );
348 }
349
350 if tls_config.insecure_skip_verify {
352 warn!(
353 agent_id = %id,
354 address = %address,
355 "SECURITY WARNING: TLS certificate verification disabled for gRPC agent connection"
356 );
357 }
361
362 let channel = Channel::from_shared(address.clone())
364 .map_err(|e| {
365 error!(
366 agent_id = %id,
367 address = %address,
368 error = %e,
369 "Invalid gRPC URI"
370 );
371 AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
372 })?
373 .tls_config(client_tls_config)
374 .map_err(|e| {
375 error!(
376 agent_id = %id,
377 address = %address,
378 error = %e,
379 "Invalid TLS configuration"
380 );
381 AgentProtocolError::ConnectionFailed(format!("TLS config error: {}", e))
382 })?
383 .timeout(timeout)
384 .connect()
385 .await
386 .map_err(|e| {
387 error!(
388 agent_id = %id,
389 address = %address,
390 error = %e,
391 "Failed to connect to agent via gRPC with TLS"
392 );
393 AgentProtocolError::ConnectionFailed(format!("gRPC TLS connect failed: {}", e))
394 })?;
395
396 let client = AgentProcessorClient::new(channel);
397
398 debug!(
399 agent_id = %id,
400 address = %address,
401 "Connected to agent via gRPC with TLS"
402 );
403
404 Ok(Self {
405 id,
406 connection: AgentConnection::Grpc(client),
407 timeout,
408 max_retries: 3,
409 })
410 }
411
412 fn extract_domain(address: &str) -> Option<String> {
414 let address = address.trim();
416
417 if let Some(rest) = address
419 .strip_prefix("https://")
420 .or_else(|| address.strip_prefix("http://"))
421 {
422 let host = rest.split(':').next()?.split('/').next()?;
424 if !host.is_empty() {
425 return Some(host.to_string());
426 }
427 }
428
429 None
430 }
431
432 pub async fn http(
439 id: impl Into<String>,
440 url: impl Into<String>,
441 timeout: Duration,
442 ) -> Result<Self, AgentProtocolError> {
443 let id = id.into();
444 let url = url.into();
445
446 trace!(
447 agent_id = %id,
448 url = %url,
449 timeout_ms = timeout.as_millis() as u64,
450 "Creating HTTP agent client"
451 );
452
453 let client = reqwest::Client::builder()
454 .timeout(timeout)
455 .build()
456 .map_err(|e| {
457 error!(
458 agent_id = %id,
459 url = %url,
460 error = %e,
461 "Failed to create HTTP client"
462 );
463 AgentProtocolError::ConnectionFailed(format!("HTTP client error: {}", e))
464 })?;
465
466 debug!(
467 agent_id = %id,
468 url = %url,
469 "HTTP agent client created"
470 );
471
472 Ok(Self {
473 id,
474 connection: AgentConnection::Http(Arc::new(HttpConnection { client, url })),
475 timeout,
476 max_retries: 3,
477 })
478 }
479
480 pub async fn http_tls(
488 id: impl Into<String>,
489 url: impl Into<String>,
490 timeout: Duration,
491 tls_config: HttpTlsConfig,
492 ) -> Result<Self, AgentProtocolError> {
493 let id = id.into();
494 let url = url.into();
495
496 trace!(
497 agent_id = %id,
498 url = %url,
499 timeout_ms = timeout.as_millis() as u64,
500 has_ca_cert = tls_config.ca_cert_pem.is_some(),
501 has_client_cert = tls_config.client_cert_pem.is_some(),
502 insecure = tls_config.insecure_skip_verify,
503 "Creating HTTP agent client with TLS"
504 );
505
506 let mut client_builder = reqwest::Client::builder().timeout(timeout).use_rustls_tls();
507
508 if let Some(ca_pem) = &tls_config.ca_cert_pem {
510 let ca_cert = reqwest::Certificate::from_pem(ca_pem).map_err(|e| {
511 error!(
512 agent_id = %id,
513 error = %e,
514 "Failed to parse CA certificate"
515 );
516 AgentProtocolError::ConnectionFailed(format!("Invalid CA certificate: {}", e))
517 })?;
518 client_builder = client_builder.add_root_certificate(ca_cert);
519 debug!(
520 agent_id = %id,
521 "Using custom CA certificate for HTTP TLS"
522 );
523 }
524
525 if let (Some(cert_pem), Some(key_pem)) =
527 (&tls_config.client_cert_pem, &tls_config.client_key_pem)
528 {
529 let mut identity_pem = cert_pem.clone();
531 identity_pem.extend_from_slice(b"\n");
532 identity_pem.extend_from_slice(key_pem);
533
534 let identity = reqwest::Identity::from_pem(&identity_pem).map_err(|e| {
535 error!(
536 agent_id = %id,
537 error = %e,
538 "Failed to parse client certificate/key"
539 );
540 AgentProtocolError::ConnectionFailed(format!("Invalid client certificate: {}", e))
541 })?;
542 client_builder = client_builder.identity(identity);
543 debug!(
544 agent_id = %id,
545 "Using client certificate for mTLS to HTTP agent"
546 );
547 }
548
549 if tls_config.insecure_skip_verify {
551 warn!(
552 agent_id = %id,
553 url = %url,
554 "SECURITY WARNING: TLS certificate verification disabled for HTTP agent connection"
555 );
556 client_builder = client_builder.danger_accept_invalid_certs(true);
557 }
558
559 let client = client_builder.build().map_err(|e| {
560 error!(
561 agent_id = %id,
562 url = %url,
563 error = %e,
564 "Failed to create HTTP TLS client"
565 );
566 AgentProtocolError::ConnectionFailed(format!("HTTP TLS client error: {}", e))
567 })?;
568
569 debug!(
570 agent_id = %id,
571 url = %url,
572 "HTTP agent client created with TLS"
573 );
574
575 Ok(Self {
576 id,
577 connection: AgentConnection::Http(Arc::new(HttpConnection { client, url })),
578 timeout,
579 max_retries: 3,
580 })
581 }
582
583 #[allow(dead_code)]
585 pub fn id(&self) -> &str {
586 &self.id
587 }
588
589 pub async fn send_event(
591 &mut self,
592 event_type: EventType,
593 payload: impl Serialize,
594 ) -> Result<AgentResponse, AgentProtocolError> {
595 let http_conn = if let AgentConnection::Http(conn) = &self.connection {
597 Some(Arc::clone(conn))
598 } else {
599 None
600 };
601
602 match &mut self.connection {
603 AgentConnection::UnixSocket(_) => {
604 self.send_event_unix_socket(event_type, payload).await
605 }
606 AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
607 AgentConnection::Http(_) => {
608 self.send_event_http(http_conn.unwrap(), event_type, payload)
610 .await
611 }
612 }
613 }
614
615 async fn send_event_unix_socket(
617 &mut self,
618 event_type: EventType,
619 payload: impl Serialize,
620 ) -> Result<AgentResponse, AgentProtocolError> {
621 let request = AgentRequest {
622 version: PROTOCOL_VERSION,
623 event_type,
624 payload: serde_json::to_value(payload)
625 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
626 };
627
628 let request_bytes = serde_json::to_vec(&request)
630 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
631
632 if request_bytes.len() > MAX_MESSAGE_SIZE {
634 return Err(AgentProtocolError::MessageTooLarge {
635 size: request_bytes.len(),
636 max: MAX_MESSAGE_SIZE,
637 });
638 }
639
640 let response = tokio::time::timeout(self.timeout, async {
642 self.send_raw_unix(&request_bytes).await?;
643 self.receive_raw_unix().await
644 })
645 .await
646 .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
647
648 let agent_response: AgentResponse = serde_json::from_slice(&response)
650 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
651
652 if agent_response.version != PROTOCOL_VERSION {
654 return Err(AgentProtocolError::VersionMismatch {
655 expected: PROTOCOL_VERSION,
656 actual: agent_response.version,
657 });
658 }
659
660 Ok(agent_response)
661 }
662
663 async fn send_event_grpc(
665 &mut self,
666 event_type: EventType,
667 payload: impl Serialize,
668 ) -> Result<AgentResponse, AgentProtocolError> {
669 let grpc_request = Self::build_grpc_request(event_type, payload)?;
671
672 let AgentConnection::Grpc(client) = &mut self.connection else {
673 return Err(AgentProtocolError::WrongConnectionType(
674 "Expected gRPC connection but found Unix socket".to_string(),
675 ));
676 };
677
678 let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
680 .await
681 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
682 .map_err(|e| {
683 AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
684 })?;
685
686 Self::convert_grpc_response(response.into_inner())
688 }
689
690 async fn send_event_http(
692 &self,
693 conn: Arc<HttpConnection>,
694 event_type: EventType,
695 payload: impl Serialize,
696 ) -> Result<AgentResponse, AgentProtocolError> {
697 let request = AgentRequest {
698 version: PROTOCOL_VERSION,
699 event_type,
700 payload: serde_json::to_value(payload)
701 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
702 };
703
704 let request_json = serde_json::to_string(&request)
706 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
707
708 if request_json.len() > MAX_MESSAGE_SIZE {
710 return Err(AgentProtocolError::MessageTooLarge {
711 size: request_json.len(),
712 max: MAX_MESSAGE_SIZE,
713 });
714 }
715
716 trace!(
717 agent_id = %self.id,
718 url = %conn.url,
719 event_type = ?event_type,
720 request_size = request_json.len(),
721 "Sending HTTP request to agent"
722 );
723
724 let response = conn
726 .client
727 .post(&conn.url)
728 .header("Content-Type", "application/json")
729 .header("X-Sentinel-Protocol-Version", PROTOCOL_VERSION.to_string())
730 .body(request_json)
731 .send()
732 .await
733 .map_err(|e| {
734 error!(
735 agent_id = %self.id,
736 url = %conn.url,
737 error = %e,
738 "HTTP request to agent failed"
739 );
740 if e.is_timeout() {
741 AgentProtocolError::Timeout(self.timeout)
742 } else if e.is_connect() {
743 AgentProtocolError::ConnectionFailed(format!("HTTP connect failed: {}", e))
744 } else {
745 AgentProtocolError::ConnectionFailed(format!("HTTP request failed: {}", e))
746 }
747 })?;
748
749 let status = response.status();
751 if !status.is_success() {
752 let body = response.text().await.unwrap_or_default();
753 error!(
754 agent_id = %self.id,
755 url = %conn.url,
756 status = %status,
757 body = %body,
758 "Agent returned HTTP error"
759 );
760 return Err(AgentProtocolError::ConnectionFailed(format!(
761 "HTTP {} from agent: {}",
762 status, body
763 )));
764 }
765
766 let response_bytes = response.bytes().await.map_err(|e| {
768 AgentProtocolError::ConnectionFailed(format!("Failed to read response body: {}", e))
769 })?;
770
771 if response_bytes.len() > MAX_MESSAGE_SIZE {
773 return Err(AgentProtocolError::MessageTooLarge {
774 size: response_bytes.len(),
775 max: MAX_MESSAGE_SIZE,
776 });
777 }
778
779 let agent_response: AgentResponse =
780 serde_json::from_slice(&response_bytes).map_err(|e| {
781 AgentProtocolError::InvalidMessage(format!("Invalid JSON response: {}", e))
782 })?;
783
784 if agent_response.version != PROTOCOL_VERSION {
786 return Err(AgentProtocolError::VersionMismatch {
787 expected: PROTOCOL_VERSION,
788 actual: agent_response.version,
789 });
790 }
791
792 trace!(
793 agent_id = %self.id,
794 decision = ?agent_response.decision,
795 "Received HTTP response from agent"
796 );
797
798 Ok(agent_response)
799 }
800
801 fn build_grpc_request(
803 event_type: EventType,
804 payload: impl Serialize,
805 ) -> Result<grpc::AgentRequest, AgentProtocolError> {
806 let payload_json = serde_json::to_value(&payload)
807 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
808
809 let grpc_event_type = match event_type {
810 EventType::Configure => {
811 return Err(AgentProtocolError::Serialization(
812 "Configure events are not supported via gRPC".to_string(),
813 ))
814 }
815 EventType::RequestHeaders => grpc::EventType::RequestHeaders,
816 EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
817 EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
818 EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
819 EventType::RequestComplete => grpc::EventType::RequestComplete,
820 EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
821 EventType::GuardrailInspect => {
822 return Err(AgentProtocolError::Serialization(
823 "GuardrailInspect events are not yet supported via gRPC".to_string(),
824 ))
825 }
826 };
827
828 let event = match event_type {
829 EventType::Configure => {
830 return Err(AgentProtocolError::InvalidMessage(
831 "Configure event should be handled separately".to_string(),
832 ));
833 }
834 EventType::RequestHeaders => {
835 let event: RequestHeadersEvent = serde_json::from_value(payload_json)
836 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
837 grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
838 metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
839 method: event.method,
840 uri: event.uri,
841 headers: event
842 .headers
843 .into_iter()
844 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
845 .collect(),
846 })
847 }
848 EventType::RequestBodyChunk => {
849 let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
850 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
851 grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
852 correlation_id: event.correlation_id,
853 data: event.data.into_bytes(),
854 is_last: event.is_last,
855 total_size: event.total_size.map(|s| s as u64),
856 chunk_index: event.chunk_index,
857 bytes_received: event.bytes_received as u64,
858 })
859 }
860 EventType::ResponseHeaders => {
861 let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
862 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
863 grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
864 correlation_id: event.correlation_id,
865 status: event.status as u32,
866 headers: event
867 .headers
868 .into_iter()
869 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
870 .collect(),
871 })
872 }
873 EventType::ResponseBodyChunk => {
874 let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
875 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
876 grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
877 correlation_id: event.correlation_id,
878 data: event.data.into_bytes(),
879 is_last: event.is_last,
880 total_size: event.total_size.map(|s| s as u64),
881 chunk_index: event.chunk_index,
882 bytes_sent: event.bytes_sent as u64,
883 })
884 }
885 EventType::RequestComplete => {
886 let event: RequestCompleteEvent = serde_json::from_value(payload_json)
887 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
888 grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
889 correlation_id: event.correlation_id,
890 status: event.status as u32,
891 duration_ms: event.duration_ms,
892 request_body_size: event.request_body_size as u64,
893 response_body_size: event.response_body_size as u64,
894 upstream_attempts: event.upstream_attempts,
895 error: event.error,
896 })
897 }
898 EventType::WebSocketFrame => {
899 use base64::{engine::general_purpose::STANDARD, Engine as _};
900 let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
901 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
902 grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
903 correlation_id: event.correlation_id,
904 opcode: event.opcode,
905 data: STANDARD.decode(&event.data).unwrap_or_default(),
906 client_to_server: event.client_to_server,
907 frame_index: event.frame_index,
908 fin: event.fin,
909 route_id: event.route_id,
910 client_ip: event.client_ip,
911 })
912 }
913 EventType::GuardrailInspect => {
914 return Err(AgentProtocolError::InvalidMessage(
915 "GuardrailInspect events are not yet supported via gRPC".to_string(),
916 ));
917 }
918 };
919
920 Ok(grpc::AgentRequest {
921 version: PROTOCOL_VERSION,
922 event_type: grpc_event_type as i32,
923 event: Some(event),
924 })
925 }
926
927 fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
929 grpc::RequestMetadata {
930 correlation_id: metadata.correlation_id.clone(),
931 request_id: metadata.request_id.clone(),
932 client_ip: metadata.client_ip.clone(),
933 client_port: metadata.client_port as u32,
934 server_name: metadata.server_name.clone(),
935 protocol: metadata.protocol.clone(),
936 tls_version: metadata.tls_version.clone(),
937 tls_cipher: metadata.tls_cipher.clone(),
938 route_id: metadata.route_id.clone(),
939 upstream_id: metadata.upstream_id.clone(),
940 timestamp: metadata.timestamp.clone(),
941 traceparent: metadata.traceparent.clone(),
942 }
943 }
944
945 fn convert_grpc_response(
947 response: grpc::AgentResponse,
948 ) -> Result<AgentResponse, AgentProtocolError> {
949 let decision = match response.decision {
950 Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
951 Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
952 status: b.status as u16,
953 body: b.body,
954 headers: if b.headers.is_empty() {
955 None
956 } else {
957 Some(b.headers)
958 },
959 },
960 Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
961 url: r.url,
962 status: r.status as u16,
963 },
964 Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
965 challenge_type: c.challenge_type,
966 params: c.params,
967 },
968 None => Decision::Allow, };
970
971 let request_headers: Vec<HeaderOp> = response
972 .request_headers
973 .into_iter()
974 .filter_map(Self::convert_header_op_from_grpc)
975 .collect();
976
977 let response_headers: Vec<HeaderOp> = response
978 .response_headers
979 .into_iter()
980 .filter_map(Self::convert_header_op_from_grpc)
981 .collect();
982
983 let audit = response.audit.map(|a| AuditMetadata {
984 tags: a.tags,
985 rule_ids: a.rule_ids,
986 confidence: a.confidence,
987 reason_codes: a.reason_codes,
988 custom: a
989 .custom
990 .into_iter()
991 .map(|(k, v)| (k, serde_json::Value::String(v)))
992 .collect(),
993 });
994
995 let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
997 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
998 chunk_index: m.chunk_index,
999 });
1000
1001 let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
1002 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
1003 chunk_index: m.chunk_index,
1004 });
1005
1006 let websocket_decision = response
1008 .websocket_decision
1009 .map(|ws_decision| match ws_decision {
1010 grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
1011 WebSocketDecision::Allow
1012 }
1013 grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
1014 WebSocketDecision::Drop
1015 }
1016 grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
1017 WebSocketDecision::Close {
1018 code: c.code as u16,
1019 reason: c.reason,
1020 }
1021 }
1022 });
1023
1024 Ok(AgentResponse {
1025 version: response.version,
1026 decision,
1027 request_headers,
1028 response_headers,
1029 routing_metadata: response.routing_metadata,
1030 audit: audit.unwrap_or_default(),
1031 needs_more: response.needs_more,
1032 request_body_mutation,
1033 response_body_mutation,
1034 websocket_decision,
1035 })
1036 }
1037
1038 fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
1040 match op.operation? {
1041 grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
1042 name: s.name,
1043 value: s.value,
1044 }),
1045 grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
1046 name: a.name,
1047 value: a.value,
1048 }),
1049 grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
1050 }
1051 }
1052
1053 async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
1055 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
1056 return Err(AgentProtocolError::WrongConnectionType(
1057 "Expected Unix socket connection but found gRPC".to_string(),
1058 ));
1059 };
1060 let len_bytes = (data.len() as u32).to_be_bytes();
1062 stream.write_all(&len_bytes).await?;
1063 stream.write_all(data).await?;
1065 stream.flush().await?;
1066 Ok(())
1067 }
1068
1069 async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
1071 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
1072 return Err(AgentProtocolError::WrongConnectionType(
1073 "Expected Unix socket connection but found gRPC".to_string(),
1074 ));
1075 };
1076 let mut len_bytes = [0u8; 4];
1078 stream.read_exact(&mut len_bytes).await?;
1079 let message_len = u32::from_be_bytes(len_bytes) as usize;
1080
1081 if message_len > MAX_MESSAGE_SIZE {
1083 return Err(AgentProtocolError::MessageTooLarge {
1084 size: message_len,
1085 max: MAX_MESSAGE_SIZE,
1086 });
1087 }
1088
1089 let mut buffer = vec![0u8; message_len];
1091 stream.read_exact(&mut buffer).await?;
1092 Ok(buffer)
1093 }
1094
1095 pub async fn close(self) -> Result<(), AgentProtocolError> {
1097 match self.connection {
1098 AgentConnection::UnixSocket(mut stream) => {
1099 stream.shutdown().await?;
1100 Ok(())
1101 }
1102 AgentConnection::Grpc(_) => Ok(()), AgentConnection::Http(_) => Ok(()), }
1105 }
1106}
1107
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `AgentClient::extract_domain` returns the expected host for each case.
    fn check_domains(cases: &[(&str, Option<&str>)]) {
        for &(address, expected) in cases {
            assert_eq!(
                AgentClient::extract_domain(address),
                expected.map(String::from),
                "extract_domain({:?})",
                address
            );
        }
    }

    #[test]
    fn test_extract_domain_https() {
        check_domains(&[
            ("https://example.com:443", Some("example.com")),
            ("https://agent.internal:50051", Some("agent.internal")),
            ("https://localhost:8080/path", Some("localhost")),
        ]);
    }

    #[test]
    fn test_extract_domain_http() {
        check_domains(&[
            ("http://example.com:8080", Some("example.com")),
            ("http://localhost:50051", Some("localhost")),
        ]);
    }

    #[test]
    fn test_extract_domain_invalid() {
        check_domains(&[
            ("example.com:443", None),
            ("tcp://example.com:443", None),
            ("", None),
        ]);
    }

    #[test]
    fn test_grpc_tls_config_builder() {
        let cfg = GrpcTlsConfig::new()
            .with_ca_cert_pem(b"test-ca-cert".to_vec())
            .with_client_identity(b"test-cert".to_vec(), b"test-key".to_vec())
            .with_domain_name("example.com");

        assert!(!cfg.insecure_skip_verify);
        assert!(cfg.ca_cert_pem.is_some());
        assert!(cfg.client_cert_pem.is_some());
        assert!(cfg.client_key_pem.is_some());
        assert_eq!(cfg.domain_name.as_deref(), Some("example.com"));
    }

    #[test]
    fn test_grpc_tls_config_insecure() {
        let cfg = GrpcTlsConfig::new().with_insecure_skip_verify();

        assert!(cfg.insecure_skip_verify);
        assert!(cfg.ca_cert_pem.is_none());
    }

    #[test]
    fn test_http_tls_config_builder() {
        let cfg = HttpTlsConfig::new()
            .with_ca_cert_pem(b"test-ca-cert".to_vec())
            .with_client_identity(b"test-cert".to_vec(), b"test-key".to_vec());

        assert!(!cfg.insecure_skip_verify);
        assert!(cfg.ca_cert_pem.is_some());
        assert!(cfg.client_cert_pem.is_some());
        assert!(cfg.client_key_pem.is_some());
    }

    #[test]
    fn test_http_tls_config_insecure() {
        let cfg = HttpTlsConfig::new().with_insecure_skip_verify();

        assert!(cfg.insecure_skip_verify);
        assert!(cfg.ca_cert_pem.is_none());
    }

    #[tokio::test]
    async fn test_http_client_creation() {
        let client = AgentClient::http(
            "test-agent",
            "http://localhost:9999/agent",
            Duration::from_secs(5),
        )
        .await;

        assert!(client.is_ok());
    }
}