1use serde::Serialize;
9use std::path::Path;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::UnixStream;
14use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
15use tracing::{debug, error, trace, warn};
16
17use crate::errors::AgentProtocolError;
18use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
19use crate::protocol::{
20 AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
21 RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
22 ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
23 MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
24};
25
/// TLS settings used by `AgentClient::grpc_tls`.
///
/// PEM material is held as raw bytes. `Debug` is implemented by hand (not
/// derived) so that the client private key is never rendered into logs.
#[derive(Clone, Default)]
pub struct GrpcTlsConfig {
    /// Requests that server certificate verification be skipped.
    /// `AgentClient::grpc_tls` does not apply this to the tonic TLS config;
    /// it only logs a warning when the flag is set.
    pub insecure_skip_verify: bool,
    /// PEM-encoded CA certificate used to verify the agent's certificate.
    pub ca_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client certificate for mTLS (used only together with the key).
    pub client_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client private key for mTLS (used only together with the cert).
    pub client_key_pem: Option<Vec<u8>>,
    /// TLS server name override; when `None` it is derived from the address.
    pub domain_name: Option<String>,
}

impl std::fmt::Debug for GrpcTlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GrpcTlsConfig")
            .field("insecure_skip_verify", &self.insecure_skip_verify)
            .field("ca_cert_pem", &self.ca_cert_pem)
            .field("client_cert_pem", &self.client_cert_pem)
            // Private key material must never end up in log output.
            .field(
                "client_key_pem",
                &self.client_key_pem.as_ref().map(|_| "<redacted>"),
            )
            .field("domain_name", &self.domain_name)
            .finish()
    }
}
40
41impl GrpcTlsConfig {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub async fn with_ca_cert_file(mut self, path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
49 self.ca_cert_pem = Some(tokio::fs::read(path).await?);
50 Ok(self)
51 }
52
53 pub fn with_ca_cert_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
55 self.ca_cert_pem = Some(pem.into());
56 self
57 }
58
59 pub async fn with_client_cert_files(
61 mut self,
62 cert_path: impl AsRef<Path>,
63 key_path: impl AsRef<Path>,
64 ) -> Result<Self, std::io::Error> {
65 self.client_cert_pem = Some(tokio::fs::read(cert_path).await?);
66 self.client_key_pem = Some(tokio::fs::read(key_path).await?);
67 Ok(self)
68 }
69
70 pub fn with_client_identity(mut self, cert_pem: impl Into<Vec<u8>>, key_pem: impl Into<Vec<u8>>) -> Self {
72 self.client_cert_pem = Some(cert_pem.into());
73 self.client_key_pem = Some(key_pem.into());
74 self
75 }
76
77 pub fn with_domain_name(mut self, domain: impl Into<String>) -> Self {
79 self.domain_name = Some(domain.into());
80 self
81 }
82
83 pub fn with_insecure_skip_verify(mut self) -> Self {
85 self.insecure_skip_verify = true;
86 self
87 }
88}
89
/// TLS settings used by `AgentClient::http_tls`.
///
/// PEM material is held as raw bytes. `Debug` is implemented by hand (not
/// derived) so that the client private key is never rendered into logs.
#[derive(Clone, Default)]
pub struct HttpTlsConfig {
    /// Accept invalid server certificates (mapped to reqwest's
    /// `danger_accept_invalid_certs` by `AgentClient::http_tls`).
    pub insecure_skip_verify: bool,
    /// PEM-encoded CA certificate added as an extra trust root.
    pub ca_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client certificate for mTLS (used only together with the key).
    pub client_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client private key for mTLS (used only together with the cert).
    pub client_key_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for HttpTlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpTlsConfig")
            .field("insecure_skip_verify", &self.insecure_skip_verify)
            .field("ca_cert_pem", &self.ca_cert_pem)
            .field("client_cert_pem", &self.client_cert_pem)
            // Private key material must never end up in log output.
            .field(
                "client_key_pem",
                &self.client_key_pem.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
102
103impl HttpTlsConfig {
104 pub fn new() -> Self {
106 Self::default()
107 }
108
109 pub async fn with_ca_cert_file(mut self, path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
111 self.ca_cert_pem = Some(tokio::fs::read(path).await?);
112 Ok(self)
113 }
114
115 pub fn with_ca_cert_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
117 self.ca_cert_pem = Some(pem.into());
118 self
119 }
120
121 pub async fn with_client_cert_files(
123 mut self,
124 cert_path: impl AsRef<Path>,
125 key_path: impl AsRef<Path>,
126 ) -> Result<Self, std::io::Error> {
127 self.client_cert_pem = Some(tokio::fs::read(cert_path).await?);
128 self.client_key_pem = Some(tokio::fs::read(key_path).await?);
129 Ok(self)
130 }
131
132 pub fn with_client_identity(mut self, cert_pem: impl Into<Vec<u8>>, key_pem: impl Into<Vec<u8>>) -> Self {
134 self.client_cert_pem = Some(cert_pem.into());
135 self.client_key_pem = Some(key_pem.into());
136 self
137 }
138
139 pub fn with_insecure_skip_verify(mut self) -> Self {
141 self.insecure_skip_verify = true;
142 self
143 }
144}
145
146struct HttpConnection {
148 client: reqwest::Client,
150 url: String,
152}
153
154pub struct AgentClient {
156 id: String,
158 connection: AgentConnection,
160 timeout: Duration,
162 #[allow(dead_code)]
164 max_retries: u32,
165}
166
167enum AgentConnection {
169 UnixSocket(UnixStream),
170 Grpc(AgentProcessorClient<Channel>),
171 Http(Arc<HttpConnection>),
172}
173
174impl AgentClient {
175 pub async fn unix_socket(
177 id: impl Into<String>,
178 path: impl AsRef<std::path::Path>,
179 timeout: Duration,
180 ) -> Result<Self, AgentProtocolError> {
181 let id = id.into();
182 let path = path.as_ref();
183
184 trace!(
185 agent_id = %id,
186 socket_path = %path.display(),
187 timeout_ms = timeout.as_millis() as u64,
188 "Connecting to agent via Unix socket"
189 );
190
191 let stream = UnixStream::connect(path).await.map_err(|e| {
192 error!(
193 agent_id = %id,
194 socket_path = %path.display(),
195 error = %e,
196 "Failed to connect to agent via Unix socket"
197 );
198 AgentProtocolError::ConnectionFailed(e.to_string())
199 })?;
200
201 debug!(
202 agent_id = %id,
203 socket_path = %path.display(),
204 "Connected to agent via Unix socket"
205 );
206
207 Ok(Self {
208 id,
209 connection: AgentConnection::UnixSocket(stream),
210 timeout,
211 max_retries: 3,
212 })
213 }
214
215 pub async fn grpc(
222 id: impl Into<String>,
223 address: impl Into<String>,
224 timeout: Duration,
225 ) -> Result<Self, AgentProtocolError> {
226 let id = id.into();
227 let address = address.into();
228
229 trace!(
230 agent_id = %id,
231 address = %address,
232 timeout_ms = timeout.as_millis() as u64,
233 "Connecting to agent via gRPC"
234 );
235
236 let channel = Channel::from_shared(address.clone())
237 .map_err(|e| {
238 error!(
239 agent_id = %id,
240 address = %address,
241 error = %e,
242 "Invalid gRPC URI"
243 );
244 AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
245 })?
246 .timeout(timeout)
247 .connect()
248 .await
249 .map_err(|e| {
250 error!(
251 agent_id = %id,
252 address = %address,
253 error = %e,
254 "Failed to connect to agent via gRPC"
255 );
256 AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
257 })?;
258
259 let client = AgentProcessorClient::new(channel);
260
261 debug!(
262 agent_id = %id,
263 address = %address,
264 "Connected to agent via gRPC"
265 );
266
267 Ok(Self {
268 id,
269 connection: AgentConnection::Grpc(client),
270 timeout,
271 max_retries: 3,
272 })
273 }
274
275 pub async fn grpc_tls(
283 id: impl Into<String>,
284 address: impl Into<String>,
285 timeout: Duration,
286 tls_config: GrpcTlsConfig,
287 ) -> Result<Self, AgentProtocolError> {
288 let id = id.into();
289 let address = address.into();
290
291 trace!(
292 agent_id = %id,
293 address = %address,
294 timeout_ms = timeout.as_millis() as u64,
295 has_ca_cert = tls_config.ca_cert_pem.is_some(),
296 has_client_cert = tls_config.client_cert_pem.is_some(),
297 insecure = tls_config.insecure_skip_verify,
298 "Connecting to agent via gRPC with TLS"
299 );
300
301 let mut client_tls_config = ClientTlsConfig::new();
303
304 if let Some(domain) = &tls_config.domain_name {
306 client_tls_config = client_tls_config.domain_name(domain.clone());
307 } else {
308 if let Some(domain) = Self::extract_domain(&address) {
310 client_tls_config = client_tls_config.domain_name(domain);
311 }
312 }
313
314 if let Some(ca_pem) = &tls_config.ca_cert_pem {
316 let ca_cert = Certificate::from_pem(ca_pem);
317 client_tls_config = client_tls_config.ca_certificate(ca_cert);
318 debug!(
319 agent_id = %id,
320 "Using custom CA certificate for gRPC TLS"
321 );
322 }
323
324 if let (Some(cert_pem), Some(key_pem)) = (&tls_config.client_cert_pem, &tls_config.client_key_pem) {
326 let identity = Identity::from_pem(cert_pem, key_pem);
327 client_tls_config = client_tls_config.identity(identity);
328 debug!(
329 agent_id = %id,
330 "Using client certificate for mTLS to gRPC agent"
331 );
332 }
333
334 if tls_config.insecure_skip_verify {
336 warn!(
337 agent_id = %id,
338 address = %address,
339 "SECURITY WARNING: TLS certificate verification disabled for gRPC agent connection"
340 );
341 }
345
346 let channel = Channel::from_shared(address.clone())
348 .map_err(|e| {
349 error!(
350 agent_id = %id,
351 address = %address,
352 error = %e,
353 "Invalid gRPC URI"
354 );
355 AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
356 })?
357 .tls_config(client_tls_config)
358 .map_err(|e| {
359 error!(
360 agent_id = %id,
361 address = %address,
362 error = %e,
363 "Invalid TLS configuration"
364 );
365 AgentProtocolError::ConnectionFailed(format!("TLS config error: {}", e))
366 })?
367 .timeout(timeout)
368 .connect()
369 .await
370 .map_err(|e| {
371 error!(
372 agent_id = %id,
373 address = %address,
374 error = %e,
375 "Failed to connect to agent via gRPC with TLS"
376 );
377 AgentProtocolError::ConnectionFailed(format!("gRPC TLS connect failed: {}", e))
378 })?;
379
380 let client = AgentProcessorClient::new(channel);
381
382 debug!(
383 agent_id = %id,
384 address = %address,
385 "Connected to agent via gRPC with TLS"
386 );
387
388 Ok(Self {
389 id,
390 connection: AgentConnection::Grpc(client),
391 timeout,
392 max_retries: 3,
393 })
394 }
395
/// Extracts the host of an `http://` or `https://` URL, for use as the TLS
/// server name.
///
/// Handles ports, paths, query strings, fragments, `user:pass@` userinfo and
/// bracketed IPv6 literals (`https://[::1]:50051` yields `::1`). Returns
/// `None` for other schemes, scheme-less input, or an empty host.
fn extract_domain(address: &str) -> Option<String> {
    let address = address.trim();
    let rest = address
        .strip_prefix("https://")
        .or_else(|| address.strip_prefix("http://"))?;

    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest.split(|c| matches!(c, '/' | '?' | '#')).next()?;
    // Drop any "user:pass@" prefix; the host follows the last '@'.
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, h)| h);

    let host = match host_port.strip_prefix('[') {
        // IPv6 literals contain ':' themselves, so cut at the closing bracket.
        Some(bracketed) => bracketed.split_once(']')?.0,
        None => host_port.split(':').next()?,
    };

    (!host.is_empty()).then(|| host.to_string())
}
412
413 pub async fn http(
420 id: impl Into<String>,
421 url: impl Into<String>,
422 timeout: Duration,
423 ) -> Result<Self, AgentProtocolError> {
424 let id = id.into();
425 let url = url.into();
426
427 trace!(
428 agent_id = %id,
429 url = %url,
430 timeout_ms = timeout.as_millis() as u64,
431 "Creating HTTP agent client"
432 );
433
434 let client = reqwest::Client::builder()
435 .timeout(timeout)
436 .build()
437 .map_err(|e| {
438 error!(
439 agent_id = %id,
440 url = %url,
441 error = %e,
442 "Failed to create HTTP client"
443 );
444 AgentProtocolError::ConnectionFailed(format!("HTTP client error: {}", e))
445 })?;
446
447 debug!(
448 agent_id = %id,
449 url = %url,
450 "HTTP agent client created"
451 );
452
453 Ok(Self {
454 id,
455 connection: AgentConnection::Http(Arc::new(HttpConnection { client, url })),
456 timeout,
457 max_retries: 3,
458 })
459 }
460
461 pub async fn http_tls(
469 id: impl Into<String>,
470 url: impl Into<String>,
471 timeout: Duration,
472 tls_config: HttpTlsConfig,
473 ) -> Result<Self, AgentProtocolError> {
474 let id = id.into();
475 let url = url.into();
476
477 trace!(
478 agent_id = %id,
479 url = %url,
480 timeout_ms = timeout.as_millis() as u64,
481 has_ca_cert = tls_config.ca_cert_pem.is_some(),
482 has_client_cert = tls_config.client_cert_pem.is_some(),
483 insecure = tls_config.insecure_skip_verify,
484 "Creating HTTP agent client with TLS"
485 );
486
487 let mut client_builder = reqwest::Client::builder()
488 .timeout(timeout)
489 .use_rustls_tls();
490
491 if let Some(ca_pem) = &tls_config.ca_cert_pem {
493 let ca_cert = reqwest::Certificate::from_pem(ca_pem).map_err(|e| {
494 error!(
495 agent_id = %id,
496 error = %e,
497 "Failed to parse CA certificate"
498 );
499 AgentProtocolError::ConnectionFailed(format!("Invalid CA certificate: {}", e))
500 })?;
501 client_builder = client_builder.add_root_certificate(ca_cert);
502 debug!(
503 agent_id = %id,
504 "Using custom CA certificate for HTTP TLS"
505 );
506 }
507
508 if let (Some(cert_pem), Some(key_pem)) = (&tls_config.client_cert_pem, &tls_config.client_key_pem) {
510 let mut identity_pem = cert_pem.clone();
512 identity_pem.extend_from_slice(b"\n");
513 identity_pem.extend_from_slice(key_pem);
514
515 let identity = reqwest::Identity::from_pem(&identity_pem).map_err(|e| {
516 error!(
517 agent_id = %id,
518 error = %e,
519 "Failed to parse client certificate/key"
520 );
521 AgentProtocolError::ConnectionFailed(format!("Invalid client certificate: {}", e))
522 })?;
523 client_builder = client_builder.identity(identity);
524 debug!(
525 agent_id = %id,
526 "Using client certificate for mTLS to HTTP agent"
527 );
528 }
529
530 if tls_config.insecure_skip_verify {
532 warn!(
533 agent_id = %id,
534 url = %url,
535 "SECURITY WARNING: TLS certificate verification disabled for HTTP agent connection"
536 );
537 client_builder = client_builder.danger_accept_invalid_certs(true);
538 }
539
540 let client = client_builder.build().map_err(|e| {
541 error!(
542 agent_id = %id,
543 url = %url,
544 error = %e,
545 "Failed to create HTTP TLS client"
546 );
547 AgentProtocolError::ConnectionFailed(format!("HTTP TLS client error: {}", e))
548 })?;
549
550 debug!(
551 agent_id = %id,
552 url = %url,
553 "HTTP agent client created with TLS"
554 );
555
556 Ok(Self {
557 id,
558 connection: AgentConnection::Http(Arc::new(HttpConnection { client, url })),
559 timeout,
560 max_retries: 3,
561 })
562 }
563
564 #[allow(dead_code)]
566 pub fn id(&self) -> &str {
567 &self.id
568 }
569
570 pub async fn send_event(
572 &mut self,
573 event_type: EventType,
574 payload: impl Serialize,
575 ) -> Result<AgentResponse, AgentProtocolError> {
576 let http_conn = if let AgentConnection::Http(conn) = &self.connection {
578 Some(Arc::clone(conn))
579 } else {
580 None
581 };
582
583 match &mut self.connection {
584 AgentConnection::UnixSocket(_) => {
585 self.send_event_unix_socket(event_type, payload).await
586 }
587 AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
588 AgentConnection::Http(_) => {
589 self.send_event_http(http_conn.unwrap(), event_type, payload).await
591 }
592 }
593 }
594
595 async fn send_event_unix_socket(
597 &mut self,
598 event_type: EventType,
599 payload: impl Serialize,
600 ) -> Result<AgentResponse, AgentProtocolError> {
601 let request = AgentRequest {
602 version: PROTOCOL_VERSION,
603 event_type,
604 payload: serde_json::to_value(payload)
605 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
606 };
607
608 let request_bytes = serde_json::to_vec(&request)
610 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
611
612 if request_bytes.len() > MAX_MESSAGE_SIZE {
614 return Err(AgentProtocolError::MessageTooLarge {
615 size: request_bytes.len(),
616 max: MAX_MESSAGE_SIZE,
617 });
618 }
619
620 let response = tokio::time::timeout(self.timeout, async {
622 self.send_raw_unix(&request_bytes).await?;
623 self.receive_raw_unix().await
624 })
625 .await
626 .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
627
628 let agent_response: AgentResponse = serde_json::from_slice(&response)
630 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
631
632 if agent_response.version != PROTOCOL_VERSION {
634 return Err(AgentProtocolError::VersionMismatch {
635 expected: PROTOCOL_VERSION,
636 actual: agent_response.version,
637 });
638 }
639
640 Ok(agent_response)
641 }
642
643 async fn send_event_grpc(
645 &mut self,
646 event_type: EventType,
647 payload: impl Serialize,
648 ) -> Result<AgentResponse, AgentProtocolError> {
649 let grpc_request = Self::build_grpc_request(event_type, payload)?;
651
652 let AgentConnection::Grpc(client) = &mut self.connection else {
653 return Err(AgentProtocolError::WrongConnectionType(
654 "Expected gRPC connection but found Unix socket".to_string()
655 ));
656 };
657
658 let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
660 .await
661 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
662 .map_err(|e| {
663 AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
664 })?;
665
666 Self::convert_grpc_response(response.into_inner())
668 }
669
670 async fn send_event_http(
672 &self,
673 conn: Arc<HttpConnection>,
674 event_type: EventType,
675 payload: impl Serialize,
676 ) -> Result<AgentResponse, AgentProtocolError> {
677 let request = AgentRequest {
678 version: PROTOCOL_VERSION,
679 event_type,
680 payload: serde_json::to_value(payload)
681 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
682 };
683
684 let request_json = serde_json::to_string(&request)
686 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
687
688 if request_json.len() > MAX_MESSAGE_SIZE {
690 return Err(AgentProtocolError::MessageTooLarge {
691 size: request_json.len(),
692 max: MAX_MESSAGE_SIZE,
693 });
694 }
695
696 trace!(
697 agent_id = %self.id,
698 url = %conn.url,
699 event_type = ?event_type,
700 request_size = request_json.len(),
701 "Sending HTTP request to agent"
702 );
703
704 let response = conn
706 .client
707 .post(&conn.url)
708 .header("Content-Type", "application/json")
709 .header("X-Sentinel-Protocol-Version", PROTOCOL_VERSION.to_string())
710 .body(request_json)
711 .send()
712 .await
713 .map_err(|e| {
714 error!(
715 agent_id = %self.id,
716 url = %conn.url,
717 error = %e,
718 "HTTP request to agent failed"
719 );
720 if e.is_timeout() {
721 AgentProtocolError::Timeout(self.timeout)
722 } else if e.is_connect() {
723 AgentProtocolError::ConnectionFailed(format!("HTTP connect failed: {}", e))
724 } else {
725 AgentProtocolError::ConnectionFailed(format!("HTTP request failed: {}", e))
726 }
727 })?;
728
729 let status = response.status();
731 if !status.is_success() {
732 let body = response.text().await.unwrap_or_default();
733 error!(
734 agent_id = %self.id,
735 url = %conn.url,
736 status = %status,
737 body = %body,
738 "Agent returned HTTP error"
739 );
740 return Err(AgentProtocolError::ConnectionFailed(format!(
741 "HTTP {} from agent: {}",
742 status,
743 body
744 )));
745 }
746
747 let response_bytes = response.bytes().await.map_err(|e| {
749 AgentProtocolError::ConnectionFailed(format!("Failed to read response body: {}", e))
750 })?;
751
752 if response_bytes.len() > MAX_MESSAGE_SIZE {
754 return Err(AgentProtocolError::MessageTooLarge {
755 size: response_bytes.len(),
756 max: MAX_MESSAGE_SIZE,
757 });
758 }
759
760 let agent_response: AgentResponse = serde_json::from_slice(&response_bytes)
761 .map_err(|e| AgentProtocolError::InvalidMessage(format!("Invalid JSON response: {}", e)))?;
762
763 if agent_response.version != PROTOCOL_VERSION {
765 return Err(AgentProtocolError::VersionMismatch {
766 expected: PROTOCOL_VERSION,
767 actual: agent_response.version,
768 });
769 }
770
771 trace!(
772 agent_id = %self.id,
773 decision = ?agent_response.decision,
774 "Received HTTP response from agent"
775 );
776
777 Ok(agent_response)
778 }
779
780 fn build_grpc_request(
782 event_type: EventType,
783 payload: impl Serialize,
784 ) -> Result<grpc::AgentRequest, AgentProtocolError> {
785 let payload_json = serde_json::to_value(&payload)
786 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
787
788 let grpc_event_type = match event_type {
789 EventType::Configure => {
790 return Err(AgentProtocolError::Serialization(
791 "Configure events are not supported via gRPC".to_string(),
792 ))
793 }
794 EventType::RequestHeaders => grpc::EventType::RequestHeaders,
795 EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
796 EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
797 EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
798 EventType::RequestComplete => grpc::EventType::RequestComplete,
799 EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
800 EventType::GuardrailInspect => {
801 return Err(AgentProtocolError::Serialization(
802 "GuardrailInspect events are not yet supported via gRPC".to_string(),
803 ))
804 }
805 };
806
807 let event = match event_type {
808 EventType::Configure => {
809 return Err(AgentProtocolError::InvalidMessage(
810 "Configure event should be handled separately".to_string()
811 ));
812 },
813 EventType::RequestHeaders => {
814 let event: RequestHeadersEvent = serde_json::from_value(payload_json)
815 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
816 grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
817 metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
818 method: event.method,
819 uri: event.uri,
820 headers: event
821 .headers
822 .into_iter()
823 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
824 .collect(),
825 })
826 }
827 EventType::RequestBodyChunk => {
828 let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
829 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
830 grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
831 correlation_id: event.correlation_id,
832 data: event.data.into_bytes(),
833 is_last: event.is_last,
834 total_size: event.total_size.map(|s| s as u64),
835 chunk_index: event.chunk_index,
836 bytes_received: event.bytes_received as u64,
837 })
838 }
839 EventType::ResponseHeaders => {
840 let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
841 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
842 grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
843 correlation_id: event.correlation_id,
844 status: event.status as u32,
845 headers: event
846 .headers
847 .into_iter()
848 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
849 .collect(),
850 })
851 }
852 EventType::ResponseBodyChunk => {
853 let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
854 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
855 grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
856 correlation_id: event.correlation_id,
857 data: event.data.into_bytes(),
858 is_last: event.is_last,
859 total_size: event.total_size.map(|s| s as u64),
860 chunk_index: event.chunk_index,
861 bytes_sent: event.bytes_sent as u64,
862 })
863 }
864 EventType::RequestComplete => {
865 let event: RequestCompleteEvent = serde_json::from_value(payload_json)
866 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
867 grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
868 correlation_id: event.correlation_id,
869 status: event.status as u32,
870 duration_ms: event.duration_ms,
871 request_body_size: event.request_body_size as u64,
872 response_body_size: event.response_body_size as u64,
873 upstream_attempts: event.upstream_attempts,
874 error: event.error,
875 })
876 }
877 EventType::WebSocketFrame => {
878 use base64::{engine::general_purpose::STANDARD, Engine as _};
879 let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
880 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
881 grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
882 correlation_id: event.correlation_id,
883 opcode: event.opcode,
884 data: STANDARD.decode(&event.data).unwrap_or_default(),
885 client_to_server: event.client_to_server,
886 frame_index: event.frame_index,
887 fin: event.fin,
888 route_id: event.route_id,
889 client_ip: event.client_ip,
890 })
891 }
892 EventType::GuardrailInspect => {
893 return Err(AgentProtocolError::InvalidMessage(
894 "GuardrailInspect events are not yet supported via gRPC".to_string()
895 ));
896 }
897 };
898
899 Ok(grpc::AgentRequest {
900 version: PROTOCOL_VERSION,
901 event_type: grpc_event_type as i32,
902 event: Some(event),
903 })
904 }
905
906 fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
908 grpc::RequestMetadata {
909 correlation_id: metadata.correlation_id.clone(),
910 request_id: metadata.request_id.clone(),
911 client_ip: metadata.client_ip.clone(),
912 client_port: metadata.client_port as u32,
913 server_name: metadata.server_name.clone(),
914 protocol: metadata.protocol.clone(),
915 tls_version: metadata.tls_version.clone(),
916 tls_cipher: metadata.tls_cipher.clone(),
917 route_id: metadata.route_id.clone(),
918 upstream_id: metadata.upstream_id.clone(),
919 timestamp: metadata.timestamp.clone(),
920 traceparent: metadata.traceparent.clone(),
921 }
922 }
923
924 fn convert_grpc_response(
926 response: grpc::AgentResponse,
927 ) -> Result<AgentResponse, AgentProtocolError> {
928 let decision = match response.decision {
929 Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
930 Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
931 status: b.status as u16,
932 body: b.body,
933 headers: if b.headers.is_empty() {
934 None
935 } else {
936 Some(b.headers)
937 },
938 },
939 Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
940 url: r.url,
941 status: r.status as u16,
942 },
943 Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
944 challenge_type: c.challenge_type,
945 params: c.params,
946 },
947 None => Decision::Allow, };
949
950 let request_headers: Vec<HeaderOp> = response
951 .request_headers
952 .into_iter()
953 .filter_map(Self::convert_header_op_from_grpc)
954 .collect();
955
956 let response_headers: Vec<HeaderOp> = response
957 .response_headers
958 .into_iter()
959 .filter_map(Self::convert_header_op_from_grpc)
960 .collect();
961
962 let audit = response.audit.map(|a| AuditMetadata {
963 tags: a.tags,
964 rule_ids: a.rule_ids,
965 confidence: a.confidence,
966 reason_codes: a.reason_codes,
967 custom: a
968 .custom
969 .into_iter()
970 .map(|(k, v)| (k, serde_json::Value::String(v)))
971 .collect(),
972 });
973
974 let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
976 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
977 chunk_index: m.chunk_index,
978 });
979
980 let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
981 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
982 chunk_index: m.chunk_index,
983 });
984
985 let websocket_decision = response
987 .websocket_decision
988 .map(|ws_decision| match ws_decision {
989 grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
990 WebSocketDecision::Allow
991 }
992 grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
993 WebSocketDecision::Drop
994 }
995 grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
996 WebSocketDecision::Close {
997 code: c.code as u16,
998 reason: c.reason,
999 }
1000 }
1001 });
1002
1003 Ok(AgentResponse {
1004 version: response.version,
1005 decision,
1006 request_headers,
1007 response_headers,
1008 routing_metadata: response.routing_metadata,
1009 audit: audit.unwrap_or_default(),
1010 needs_more: response.needs_more,
1011 request_body_mutation,
1012 response_body_mutation,
1013 websocket_decision,
1014 })
1015 }
1016
1017 fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
1019 match op.operation? {
1020 grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
1021 name: s.name,
1022 value: s.value,
1023 }),
1024 grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
1025 name: a.name,
1026 value: a.value,
1027 }),
1028 grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
1029 }
1030 }
1031
1032 async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
1034 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
1035 return Err(AgentProtocolError::WrongConnectionType(
1036 "Expected Unix socket connection but found gRPC".to_string()
1037 ));
1038 };
1039 let len_bytes = (data.len() as u32).to_be_bytes();
1041 stream.write_all(&len_bytes).await?;
1042 stream.write_all(data).await?;
1044 stream.flush().await?;
1045 Ok(())
1046 }
1047
1048 async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
1050 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
1051 return Err(AgentProtocolError::WrongConnectionType(
1052 "Expected Unix socket connection but found gRPC".to_string()
1053 ));
1054 };
1055 let mut len_bytes = [0u8; 4];
1057 stream.read_exact(&mut len_bytes).await?;
1058 let message_len = u32::from_be_bytes(len_bytes) as usize;
1059
1060 if message_len > MAX_MESSAGE_SIZE {
1062 return Err(AgentProtocolError::MessageTooLarge {
1063 size: message_len,
1064 max: MAX_MESSAGE_SIZE,
1065 });
1066 }
1067
1068 let mut buffer = vec![0u8; message_len];
1070 stream.read_exact(&mut buffer).await?;
1071 Ok(buffer)
1072 }
1073
1074 pub async fn close(self) -> Result<(), AgentProtocolError> {
1076 match self.connection {
1077 AgentConnection::UnixSocket(mut stream) => {
1078 stream.shutdown().await?;
1079 Ok(())
1080 }
1081 AgentConnection::Grpc(_) => Ok(()), AgentConnection::Http(_) => Ok(()), }
1084 }
1085}
1086
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `AgentClient::extract_domain` against `(input, expected)` pairs.
    fn assert_domains(cases: &[(&str, Option<&str>)]) {
        for &(input, expected) in cases {
            assert_eq!(
                AgentClient::extract_domain(input),
                expected.map(String::from),
                "extract_domain({:?})",
                input
            );
        }
    }

    #[test]
    fn test_extract_domain_https() {
        assert_domains(&[
            ("https://example.com:443", Some("example.com")),
            ("https://agent.internal:50051", Some("agent.internal")),
            ("https://localhost:8080/path", Some("localhost")),
        ]);
    }

    #[test]
    fn test_extract_domain_http() {
        assert_domains(&[
            ("http://example.com:8080", Some("example.com")),
            ("http://localhost:50051", Some("localhost")),
        ]);
    }

    #[test]
    fn test_extract_domain_invalid() {
        assert_domains(&[
            ("example.com:443", None),
            ("tcp://example.com:443", None),
            ("", None),
        ]);
    }

    #[test]
    fn test_grpc_tls_config_builder() {
        let cfg = GrpcTlsConfig::new()
            .with_ca_cert_pem(b"test-ca-cert".to_vec())
            .with_client_identity(b"test-cert".to_vec(), b"test-key".to_vec())
            .with_domain_name("example.com");

        assert!(cfg.ca_cert_pem.is_some());
        assert!(cfg.client_cert_pem.is_some());
        assert!(cfg.client_key_pem.is_some());
        assert_eq!(cfg.domain_name.as_deref(), Some("example.com"));
        assert!(!cfg.insecure_skip_verify);
    }

    #[test]
    fn test_grpc_tls_config_insecure() {
        let cfg = GrpcTlsConfig::new().with_insecure_skip_verify();

        assert!(cfg.insecure_skip_verify);
        assert!(cfg.ca_cert_pem.is_none());
    }

    #[test]
    fn test_http_tls_config_builder() {
        let cfg = HttpTlsConfig::new()
            .with_ca_cert_pem(b"test-ca-cert".to_vec())
            .with_client_identity(b"test-cert".to_vec(), b"test-key".to_vec());

        assert!(cfg.ca_cert_pem.is_some());
        assert!(cfg.client_cert_pem.is_some());
        assert!(cfg.client_key_pem.is_some());
        assert!(!cfg.insecure_skip_verify);
    }

    #[test]
    fn test_http_tls_config_insecure() {
        let cfg = HttpTlsConfig::new().with_insecure_skip_verify();

        assert!(cfg.insecure_skip_verify);
        assert!(cfg.ca_cert_pem.is_none());
    }

    #[tokio::test]
    async fn test_http_client_creation() {
        // Building the client performs no network I/O, so no server is needed.
        let outcome = AgentClient::http(
            "test-agent",
            "http://localhost:9999/agent",
            Duration::from_secs(5),
        )
        .await;

        assert!(outcome.is_ok());
    }
}