1use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
11use tonic::transport::Channel;
12use tracing::{debug, info, trace, warn};
13
14use crate::grpc_v2::{self, agent_service_v2_client::AgentServiceV2Client, ProxyToAgent};
15use crate::headers::iter_flat;
16use crate::v2::pool::CHANNEL_BUFFER_SIZE;
17use crate::v2::{AgentCapabilities, PROTOCOL_VERSION_2};
18use crate::{AgentProtocolError, AgentResponse, Decision, EventType, HeaderOp};
19
/// Reason attached to a cancellation sent to an agent.
///
/// Each variant is transmitted as the integer code produced by
/// `CancelReason::to_grpc`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CancelReason {
    /// Transmitted as code `1`.
    ClientDisconnect,
    /// Transmitted as code `2`.
    Timeout,
    /// Transmitted as code `3`.
    BlockedByAgent,
    /// Transmitted as code `4`.
    UpstreamError,
    /// Transmitted as code `5`.
    ProxyShutdown,
    /// Transmitted as code `6`.
    Manual,
}

impl CancelReason {
    /// Returns the integer code placed in the `reason` field of the gRPC
    /// cancel message.
    fn to_grpc(self) -> i32 {
        match self {
            Self::ClientDisconnect => 1,
            Self::Timeout => 2,
            Self::BlockedByAgent => 3,
            Self::UpstreamError => 4,
            Self::ProxyShutdown => 5,
            Self::Manual => 6,
        }
    }
}
49
/// Flow-control state tracked for an agent connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlowState {
    /// Default state, and the only one for which
    /// `AgentClientV2::can_accept_requests` returns `true`.
    #[default]
    Normal,
    /// Selected when the agent sends a flow-control message with action `1`.
    Paused,
    /// Set by `AgentClientV2::send_drain`.
    Draining,
}
61
/// Shared callback invoked with each metrics report received from an agent.
pub type MetricsCallback = Arc<dyn Fn(crate::v2::MetricsReport) + Send + Sync>;
64
/// Shared callback invoked when an agent sends a config update request.
///
/// It receives the agent ID and the converted request, and returns the
/// response for that request.
pub type ConfigUpdateCallback =
    Arc<dyn Fn(String, crate::v2::ConfigUpdateRequest) -> crate::v2::ConfigUpdateResponse + Send + Sync>;
71
/// Client for a single agent speaking the v2 protocol over a bidirectional
/// gRPC stream.
pub struct AgentClientV2 {
    /// Identifier of the agent; attached to log events and metrics reports.
    agent_id: String,
    /// Transport channel established in `new` and cloned by `connect`.
    channel: Channel,
    /// Timeout used for channel setup, the handshake wait, and each response wait.
    timeout: Duration,
    /// Capabilities reported by the agent in its handshake response, if any.
    capabilities: RwLock<Option<AgentCapabilities>>,
    /// Protocol version; starts at 1 and is overwritten with the version from
    /// the handshake response.
    protocol_version: AtomicU64,
    /// Requests awaiting a response, keyed by correlation ID.
    pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
    /// Sender feeding the outbound stream; `None` until `connect` succeeds and
    /// again after `close`.
    outbound_tx: Mutex<Option<mpsc::Sender<ProxyToAgent>>>,
    /// Sequence number assigned to the next ping.
    ping_sequence: AtomicU64,
    /// Set to `true` by `connect` and to `false` by `close`.
    connected: RwLock<bool>,
    /// Flow-control state returned by `flow_state`.
    /// NOTE(review): in the code visible here only `send_drain` writes this field.
    flow_state: RwLock<FlowState>,
    /// Health code returned by `health_state`; initialised to 1, which
    /// `is_healthy` treats as healthy.
    /// NOTE(review): no write after construction is visible in this file.
    health_state: RwLock<i32>,
    /// Counter returned by `in_flight_count`.
    /// NOTE(review): no code visible here increments or decrements it.
    in_flight: AtomicU64,
    /// Optional metrics callback; cloned into the response handler by `connect`.
    metrics_callback: Option<MetricsCallback>,
    /// Optional config update callback; cloned into the response handler by `connect`.
    config_update_callback: Option<ConfigUpdateCallback>,
}
114
115impl AgentClientV2 {
116 pub async fn new(
118 agent_id: impl Into<String>,
119 endpoint: impl Into<String>,
120 timeout: Duration,
121 ) -> Result<Self, AgentProtocolError> {
122 let agent_id = agent_id.into();
123 let endpoint = endpoint.into();
124
125 debug!(agent_id = %agent_id, endpoint = %endpoint, "Creating v2 client");
126
127 let channel = Channel::from_shared(endpoint.clone())
128 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Invalid endpoint: {}", e)))?
129 .connect_timeout(timeout)
130 .timeout(timeout)
131 .connect()
132 .await
133 .map_err(|e| {
134 AgentProtocolError::ConnectionFailed(format!("Failed to connect: {}", e))
135 })?;
136
137 Ok(Self {
138 agent_id,
139 channel,
140 timeout,
141 capabilities: RwLock::new(None),
142 protocol_version: AtomicU64::new(1), pending: Arc::new(Mutex::new(HashMap::new())),
144 outbound_tx: Mutex::new(None),
145 ping_sequence: AtomicU64::new(0),
146 connected: RwLock::new(false),
147 flow_state: RwLock::new(FlowState::Normal),
148 health_state: RwLock::new(1), in_flight: AtomicU64::new(0),
150 metrics_callback: None,
151 config_update_callback: None,
152 })
153 }
154
155 pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
160 self.metrics_callback = Some(callback);
161 }
162
163 pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
168 self.config_update_callback = Some(callback);
169 }
170
171 pub async fn connect(&self) -> Result<(), AgentProtocolError> {
173 let mut client = AgentServiceV2Client::new(self.channel.clone());
174
175 let (tx, rx) = mpsc::channel::<ProxyToAgent>(CHANNEL_BUFFER_SIZE);
177 let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
178
179 let response_stream = client
180 .process_stream(rx_stream)
181 .await
182 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream failed: {}", e)))?;
183
184 let mut inbound = response_stream.into_inner();
185
186 let handshake = ProxyToAgent {
188 message: Some(grpc_v2::proxy_to_agent::Message::Handshake(
189 grpc_v2::HandshakeRequest {
190 supported_versions: vec![PROTOCOL_VERSION_2, 1],
191 proxy_id: "sentinel-proxy".to_string(),
192 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
193 config_json: "{}".to_string(),
194 },
195 )),
196 };
197
198 tx.send(handshake).await.map_err(|e| {
199 AgentProtocolError::ConnectionFailed(format!("Failed to send handshake: {}", e))
200 })?;
201
202 let handshake_resp = tokio::time::timeout(self.timeout, inbound.message())
204 .await
205 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
206 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream error: {}", e)))?
207 .ok_or_else(|| AgentProtocolError::ConnectionFailed("Empty handshake response".to_string()))?;
208
209 if let Some(grpc_v2::agent_to_proxy::Message::Handshake(resp)) = handshake_resp.message {
211 if !resp.success {
212 return Err(AgentProtocolError::ConnectionFailed(format!(
213 "Handshake failed: {}",
214 resp.error.unwrap_or_default()
215 )));
216 }
217
218 self.protocol_version
219 .store(resp.protocol_version as u64, Ordering::SeqCst);
220
221 if let Some(caps) = resp.capabilities {
222 let capabilities = convert_capabilities_from_grpc(caps);
223 *self.capabilities.write().await = Some(capabilities);
224 }
225
226 info!(
227 agent_id = %self.agent_id,
228 protocol_version = resp.protocol_version,
229 "v2 handshake successful"
230 );
231 } else {
232 return Err(AgentProtocolError::ConnectionFailed(
233 "Invalid handshake response".to_string(),
234 ));
235 }
236
237 *self.outbound_tx.lock().await = Some(tx);
239 *self.connected.write().await = true;
240
241 let pending = Arc::clone(&self.pending);
243 let agent_id = self.agent_id.clone();
244 let flow_state = Arc::new(RwLock::new(FlowState::Normal));
245 let health_state = Arc::new(RwLock::new(1i32));
246 let _in_flight = Arc::new(AtomicU64::new(0));
247
248 let flow_state_clone = Arc::clone(&flow_state);
250 let health_state_clone = Arc::clone(&health_state);
251 let metrics_callback = self.metrics_callback.clone();
252 let config_update_callback = self.config_update_callback.clone();
253
254 tokio::spawn(async move {
255 while let Ok(Some(msg)) = inbound.message().await {
256 match msg.message {
257 Some(grpc_v2::agent_to_proxy::Message::Response(resp)) => {
258 let correlation_id = resp.correlation_id.clone();
259 if let Some(sender) = pending.lock().await.remove(&correlation_id) {
260 let response = convert_response_from_grpc(resp);
261 let _ = sender.send(response);
262 } else {
263 warn!(
264 agent_id = %agent_id,
265 correlation_id = %correlation_id,
266 "Received response for unknown correlation ID"
267 );
268 }
269 }
270 Some(grpc_v2::agent_to_proxy::Message::Health(health)) => {
271 trace!(
272 agent_id = %agent_id,
273 state = health.state,
274 "Received health status"
275 );
276 *health_state_clone.write().await = health.state;
277 }
278 Some(grpc_v2::agent_to_proxy::Message::Metrics(metrics)) => {
279 trace!(
280 agent_id = %agent_id,
281 counters = metrics.counters.len(),
282 gauges = metrics.gauges.len(),
283 histograms = metrics.histograms.len(),
284 "Received metrics report"
285 );
286 if let Some(ref callback) = metrics_callback {
287 let report = convert_metrics_from_grpc(metrics, &agent_id);
288 callback(report);
289 }
290 }
291 Some(grpc_v2::agent_to_proxy::Message::FlowControl(fc)) => {
292 let new_state = match fc.action {
294 1 => FlowState::Paused, 2 => FlowState::Normal, _ => FlowState::Normal,
297 };
298 debug!(
299 agent_id = %agent_id,
300 action = fc.action,
301 correlation_id = ?fc.correlation_id,
302 "Received flow control signal"
303 );
304 *flow_state_clone.write().await = new_state;
305 }
306 Some(grpc_v2::agent_to_proxy::Message::Pong(pong)) => {
307 trace!(
308 agent_id = %agent_id,
309 sequence = pong.sequence,
310 latency_ms = pong.timestamp_ms.saturating_sub(pong.ping_timestamp_ms),
311 "Received pong"
312 );
313 }
314 Some(grpc_v2::agent_to_proxy::Message::ConfigUpdate(update)) => {
315 debug!(
316 agent_id = %agent_id,
317 request_id = %update.request_id,
318 "Received config update request from agent"
319 );
320 if let Some(ref callback) = config_update_callback {
321 let request = convert_config_update_from_grpc(update);
322 let _response = callback(agent_id.clone(), request);
323 }
326 }
327 Some(grpc_v2::agent_to_proxy::Message::Log(log_msg)) => {
328 match log_msg.level {
330 1 => trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent debug log"),
331 2 => debug!(agent_id = %agent_id, msg = %log_msg.message, "Agent info log"),
332 3 => warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent warning"),
333 4 => warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent error"),
334 _ => trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent log"),
335 }
336 }
337 _ => {}
338 }
339 }
340
341 debug!(agent_id = %agent_id, "Response handler ended");
342 });
343
344 Ok(())
345 }
346
347 pub async fn send_request_headers(
349 &self,
350 correlation_id: &str,
351 event: &crate::RequestHeadersEvent,
352 ) -> Result<AgentResponse, AgentProtocolError> {
353 let msg = ProxyToAgent {
354 message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
355 convert_request_headers_to_grpc(event),
356 )),
357 };
358
359 self.send_and_wait(correlation_id, msg).await
360 }
361
362 pub async fn send_request_body_chunk(
367 &self,
368 correlation_id: &str,
369 event: &crate::RequestBodyChunkEvent,
370 ) -> Result<AgentResponse, AgentProtocolError> {
371 let msg = ProxyToAgent {
372 message: Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(
373 convert_body_chunk_to_grpc(event),
374 )),
375 };
376
377 self.send_and_wait(correlation_id, msg).await
378 }
379
380 pub async fn send_response_headers(
385 &self,
386 correlation_id: &str,
387 event: &crate::ResponseHeadersEvent,
388 ) -> Result<AgentResponse, AgentProtocolError> {
389 let msg = ProxyToAgent {
390 message: Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(
391 convert_response_headers_to_grpc(event),
392 )),
393 };
394
395 self.send_and_wait(correlation_id, msg).await
396 }
397
398 pub async fn send_response_body_chunk(
403 &self,
404 correlation_id: &str,
405 event: &crate::ResponseBodyChunkEvent,
406 ) -> Result<AgentResponse, AgentProtocolError> {
407 let msg = ProxyToAgent {
408 message: Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(
409 convert_response_body_chunk_to_grpc(event),
410 )),
411 };
412
413 self.send_and_wait(correlation_id, msg).await
414 }
415
416 pub async fn send_event<T: serde::Serialize>(
418 &self,
419 event_type: EventType,
420 event: &T,
421 ) -> Result<AgentResponse, AgentProtocolError> {
422 let correlation_id = extract_correlation_id(event);
424
425 let msg = match event_type {
426 EventType::RequestHeaders => {
427 if let Ok(e) = serde_json::from_value::<crate::RequestHeadersEvent>(
428 serde_json::to_value(event).unwrap_or_default(),
429 ) {
430 ProxyToAgent {
431 message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
432 convert_request_headers_to_grpc(&e),
433 )),
434 }
435 } else {
436 return Err(AgentProtocolError::InvalidMessage(
437 "Failed to convert event".to_string(),
438 ));
439 }
440 }
441 _ => {
442 return Err(AgentProtocolError::InvalidMessage(format!(
444 "Event type {:?} not yet supported in v2 streaming mode",
445 event_type
446 )));
447 }
448 };
449
450 self.send_and_wait(&correlation_id, msg).await
451 }
452
453 async fn send_and_wait(
455 &self,
456 correlation_id: &str,
457 msg: ProxyToAgent,
458 ) -> Result<AgentResponse, AgentProtocolError> {
459 let (tx, rx) = oneshot::channel();
461
462 self.pending
464 .lock()
465 .await
466 .insert(correlation_id.to_string(), tx);
467
468 {
470 let outbound = self.outbound_tx.lock().await;
471 if let Some(sender) = outbound.as_ref() {
472 sender.send(msg).await.map_err(|e| {
473 AgentProtocolError::ConnectionFailed(format!("Send failed: {}", e))
474 })?;
475 } else {
476 return Err(AgentProtocolError::ConnectionFailed(
477 "Not connected".to_string(),
478 ));
479 }
480 }
481
482 match tokio::time::timeout(self.timeout, rx).await {
484 Ok(Ok(response)) => Ok(response),
485 Ok(Err(_)) => {
486 self.pending.lock().await.remove(correlation_id);
487 Err(AgentProtocolError::ConnectionFailed(
488 "Response channel closed".to_string(),
489 ))
490 }
491 Err(_) => {
492 self.pending.lock().await.remove(correlation_id);
493 Err(AgentProtocolError::Timeout(self.timeout))
494 }
495 }
496 }
497
498 pub async fn ping(&self) -> Result<Duration, AgentProtocolError> {
500 let sequence = self.ping_sequence.fetch_add(1, Ordering::SeqCst);
501 let timestamp_ms = now_ms();
502
503 let msg = ProxyToAgent {
504 message: Some(grpc_v2::proxy_to_agent::Message::Ping(grpc_v2::Ping {
505 sequence,
506 timestamp_ms,
507 })),
508 };
509
510 let outbound = self.outbound_tx.lock().await;
511 if let Some(sender) = outbound.as_ref() {
512 sender
513 .send(msg)
514 .await
515 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Ping failed: {}", e)))?;
516 }
517
518 Ok(Duration::from_millis(0))
521 }
522
523 pub fn protocol_version(&self) -> u32 {
525 self.protocol_version.load(Ordering::SeqCst) as u32
526 }
527
528 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
530 self.capabilities.read().await.clone()
531 }
532
533 pub async fn is_connected(&self) -> bool {
535 *self.connected.read().await
536 }
537
538 pub async fn close(&self) -> Result<(), AgentProtocolError> {
540 *self.outbound_tx.lock().await = None;
541 *self.connected.write().await = false;
542 Ok(())
543 }
544
545 pub async fn cancel_request(
550 &self,
551 correlation_id: &str,
552 reason: CancelReason,
553 ) -> Result<(), AgentProtocolError> {
554 self.pending.lock().await.remove(correlation_id);
556
557 let msg = ProxyToAgent {
559 message: Some(grpc_v2::proxy_to_agent::Message::Cancel(
560 grpc_v2::CancelRequest {
561 correlation_id: correlation_id.to_string(),
562 reason: reason.to_grpc(),
563 timestamp_ms: now_ms(),
564 blocking_agent_id: None,
565 manual_reason: None,
566 },
567 )),
568 };
569
570 let outbound = self.outbound_tx.lock().await;
571 if let Some(sender) = outbound.as_ref() {
572 sender.send(msg).await.map_err(|e| {
573 AgentProtocolError::ConnectionFailed(format!("Cancel send failed: {}", e))
574 })?;
575 }
576
577 debug!(
578 agent_id = %self.agent_id,
579 correlation_id = %correlation_id,
580 reason = ?reason,
581 "Cancelled request"
582 );
583
584 Ok(())
585 }
586
587 pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
591 let correlation_ids: Vec<String> = {
592 let pending = self.pending.lock().await;
593 pending.keys().cloned().collect()
594 };
595
596 let count = correlation_ids.len();
597 for cid in correlation_ids {
598 let _ = self.cancel_request(&cid, reason).await;
599 }
600
601 debug!(
602 agent_id = %self.agent_id,
603 count = count,
604 reason = ?reason,
605 "Cancelled all requests"
606 );
607
608 Ok(count)
609 }
610
611 pub async fn flow_state(&self) -> FlowState {
613 *self.flow_state.read().await
614 }
615
616 pub async fn can_accept_requests(&self) -> bool {
620 matches!(*self.flow_state.read().await, FlowState::Normal)
621 }
622
623 pub async fn wait_for_flow_control(&self, timeout: Duration) -> Result<(), AgentProtocolError> {
628 let deadline = tokio::time::Instant::now() + timeout;
629
630 loop {
631 if self.can_accept_requests().await {
632 return Ok(());
633 }
634
635 if tokio::time::Instant::now() >= deadline {
636 return Err(AgentProtocolError::Timeout(timeout));
637 }
638
639 tokio::time::sleep(Duration::from_millis(10)).await;
641 }
642 }
643
644 pub async fn health_state(&self) -> i32 {
652 *self.health_state.read().await
653 }
654
655 pub async fn is_healthy(&self) -> bool {
657 *self.health_state.read().await == 1
658 }
659
660 pub fn in_flight_count(&self) -> u64 {
662 self.in_flight.load(Ordering::Relaxed)
663 }
664
665 pub async fn send_configure(
671 &self,
672 config: serde_json::Value,
673 version: Option<String>,
674 ) -> Result<(), AgentProtocolError> {
675 let msg = ProxyToAgent {
676 message: Some(grpc_v2::proxy_to_agent::Message::Configure(
677 grpc_v2::ConfigureEvent {
678 config_json: serde_json::to_string(&config).unwrap_or_default(),
679 config_version: version,
680 is_initial: false,
681 timestamp_ms: now_ms(),
682 },
683 )),
684 };
685
686 let outbound = self.outbound_tx.lock().await;
687 if let Some(sender) = outbound.as_ref() {
688 sender.send(msg).await.map_err(|e| {
689 AgentProtocolError::ConnectionFailed(format!("Configure send failed: {}", e))
690 })?;
691 } else {
692 return Err(AgentProtocolError::ConnectionFailed(
693 "Not connected".to_string(),
694 ));
695 }
696
697 debug!(agent_id = %self.agent_id, "Sent configuration update");
698 Ok(())
699 }
700
701 pub async fn send_shutdown(
703 &self,
704 reason: ShutdownReason,
705 grace_period_ms: u64,
706 ) -> Result<(), AgentProtocolError> {
707 info!(
708 agent_id = %self.agent_id,
709 reason = ?reason,
710 grace_period_ms = grace_period_ms,
711 "Requesting agent shutdown"
712 );
713
714 let _ = self.cancel_all(CancelReason::ProxyShutdown).await;
716
717 self.close().await
719 }
720
721 pub async fn send_drain(
723 &self,
724 duration_ms: u64,
725 reason: DrainReason,
726 ) -> Result<(), AgentProtocolError> {
727 info!(
728 agent_id = %self.agent_id,
729 duration_ms = duration_ms,
730 reason = ?reason,
731 "Requesting agent drain"
732 );
733
734 *self.flow_state.write().await = FlowState::Draining;
736
737 Ok(())
738 }
739
740 pub fn agent_id(&self) -> &str {
742 &self.agent_id
743 }
744}
745
/// Reason passed to `AgentClientV2::send_shutdown`, which logs it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownReason {
    Graceful,
    Immediate,
    ConfigReload,
    Upgrade,
}
754
/// Reason passed to `AgentClientV2::send_drain`, which logs it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DrainReason {
    ConfigReload,
    Maintenance,
    HealthCheckFailed,
    Manual,
}
763
764fn convert_capabilities_from_grpc(caps: grpc_v2::AgentCapabilities) -> AgentCapabilities {
769 use crate::v2::{AgentFeatures, AgentLimits, HealthConfig};
770
771 let features = caps.features.map(|f| AgentFeatures {
772 streaming_body: f.streaming_body,
773 websocket: f.websocket,
774 guardrails: f.guardrails,
775 config_push: f.config_push,
776 metrics_export: f.metrics_export,
777 concurrent_requests: f.concurrent_requests,
778 cancellation: f.cancellation,
779 flow_control: f.flow_control,
780 health_reporting: f.health_reporting,
781 }).unwrap_or_default();
782
783 let limits = caps.limits.map(|l| AgentLimits {
784 max_body_size: l.max_body_size as usize,
785 max_concurrency: l.max_concurrency,
786 preferred_chunk_size: l.preferred_chunk_size as usize,
787 max_memory: l.max_memory.map(|m| m as usize),
788 max_processing_time_ms: l.max_processing_time_ms,
789 }).unwrap_or_default();
790
791 let health = caps.health_config.map(|h| HealthConfig {
792 report_interval_ms: h.report_interval_ms,
793 include_load_metrics: h.include_load_metrics,
794 include_resource_metrics: h.include_resource_metrics,
795 }).unwrap_or_default();
796
797 AgentCapabilities {
798 protocol_version: caps.protocol_version,
799 agent_id: caps.agent_id,
800 name: caps.name,
801 version: caps.version,
802 supported_events: caps.supported_events.into_iter().filter_map(i32_to_event_type).collect(),
803 features,
804 limits,
805 health,
806 }
807}
808
809fn i32_to_event_type(i: i32) -> Option<EventType> {
810 match i {
811 1 => Some(EventType::RequestHeaders),
812 2 => Some(EventType::RequestBodyChunk),
813 3 => Some(EventType::ResponseHeaders),
814 4 => Some(EventType::ResponseBodyChunk),
815 5 => Some(EventType::RequestComplete),
816 6 => Some(EventType::WebSocketFrame),
817 7 => Some(EventType::GuardrailInspect),
818 8 => Some(EventType::Configure),
819 _ => None,
820 }
821}
822
823fn convert_request_headers_to_grpc(event: &crate::RequestHeadersEvent) -> grpc_v2::RequestHeadersEvent {
824 let metadata = Some(grpc_v2::RequestMetadata {
825 correlation_id: event.metadata.correlation_id.clone(),
826 request_id: event.metadata.request_id.clone(),
827 client_ip: event.metadata.client_ip.clone(),
828 client_port: event.metadata.client_port as u32,
829 server_name: event.metadata.server_name.clone(),
830 protocol: event.metadata.protocol.clone(),
831 tls_version: event.metadata.tls_version.clone(),
832 route_id: event.metadata.route_id.clone(),
833 upstream_id: event.metadata.upstream_id.clone(),
834 timestamp_ms: now_ms(),
835 traceparent: event.metadata.traceparent.clone(),
836 });
837
838 let headers: Vec<grpc_v2::Header> = iter_flat(&event.headers)
840 .map(|(name, value)| grpc_v2::Header {
841 name: name.to_string(),
842 value: value.to_string(),
843 })
844 .collect();
845
846 grpc_v2::RequestHeadersEvent {
847 metadata,
848 method: event.method.clone(),
849 uri: event.uri.clone(),
850 http_version: "HTTP/1.1".to_string(),
851 headers,
852 }
853}
854
855fn convert_body_chunk_to_grpc(event: &crate::RequestBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
856 let binary: crate::BinaryRequestBodyChunkEvent = event.into();
858 convert_binary_body_chunk_to_grpc(&binary)
859}
860
861fn convert_binary_body_chunk_to_grpc(event: &crate::BinaryRequestBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
865 grpc_v2::BodyChunkEvent {
866 correlation_id: event.correlation_id.clone(),
867 chunk_index: event.chunk_index,
868 data: event.data.to_vec(), is_last: event.is_last,
870 total_size: event.total_size.map(|s| s as u64),
871 bytes_transferred: event.bytes_received as u64,
872 proxy_buffer_available: 0, timestamp_ms: now_ms(),
874 }
875}
876
877fn convert_response_headers_to_grpc(event: &crate::ResponseHeadersEvent) -> grpc_v2::ResponseHeadersEvent {
878 let headers: Vec<grpc_v2::Header> = iter_flat(&event.headers)
880 .map(|(name, value)| grpc_v2::Header {
881 name: name.to_string(),
882 value: value.to_string(),
883 })
884 .collect();
885
886 grpc_v2::ResponseHeadersEvent {
887 correlation_id: event.correlation_id.clone(),
888 status_code: event.status as u32,
889 headers,
890 }
891}
892
893fn convert_response_body_chunk_to_grpc(event: &crate::ResponseBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
894 let binary: crate::BinaryResponseBodyChunkEvent = event.into();
896 convert_binary_response_body_chunk_to_grpc(&binary)
897}
898
899fn convert_binary_response_body_chunk_to_grpc(event: &crate::BinaryResponseBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
903 grpc_v2::BodyChunkEvent {
904 correlation_id: event.correlation_id.clone(),
905 chunk_index: event.chunk_index,
906 data: event.data.to_vec(), is_last: event.is_last,
908 total_size: event.total_size.map(|s| s as u64),
909 bytes_transferred: event.bytes_sent as u64,
910 proxy_buffer_available: 0,
911 timestamp_ms: now_ms(),
912 }
913}
914
915fn convert_response_from_grpc(resp: grpc_v2::AgentResponse) -> AgentResponse {
916 let decision = match resp.decision {
917 Some(grpc_v2::agent_response::Decision::Allow(_)) => Decision::Allow,
918 Some(grpc_v2::agent_response::Decision::Block(b)) => Decision::Block {
919 status: b.status as u16,
920 body: b.body,
921 headers: if b.headers.is_empty() {
922 None
923 } else {
924 Some(b.headers.into_iter().map(|h| (h.name, h.value)).collect())
925 },
926 },
927 Some(grpc_v2::agent_response::Decision::Redirect(r)) => Decision::Redirect {
928 url: r.url,
929 status: r.status as u16,
930 },
931 Some(grpc_v2::agent_response::Decision::Challenge(c)) => Decision::Challenge {
932 challenge_type: c.challenge_type,
933 params: c.params,
934 },
935 None => Decision::Allow,
936 };
937
938 let request_headers: Vec<HeaderOp> = resp
939 .request_headers
940 .into_iter()
941 .filter_map(convert_header_op_from_grpc)
942 .collect();
943
944 let response_headers: Vec<HeaderOp> = resp
945 .response_headers
946 .into_iter()
947 .filter_map(convert_header_op_from_grpc)
948 .collect();
949
950 let audit = resp.audit.map(|a| crate::AuditMetadata {
951 tags: a.tags,
952 rule_ids: a.rule_ids,
953 confidence: a.confidence,
954 reason_codes: a.reason_codes,
955 custom: a.custom.into_iter().map(|(k, v)| (k, serde_json::Value::String(v))).collect(),
956 }).unwrap_or_default();
957
958 AgentResponse {
959 version: PROTOCOL_VERSION_2,
960 decision,
961 request_headers,
962 response_headers,
963 routing_metadata: HashMap::new(),
964 audit,
965 needs_more: resp.needs_more,
966 request_body_mutation: None,
967 response_body_mutation: None,
968 websocket_decision: None,
969 }
970}
971
972fn convert_header_op_from_grpc(op: grpc_v2::HeaderOp) -> Option<HeaderOp> {
973 match op.operation {
974 Some(grpc_v2::header_op::Operation::Set(h)) => Some(HeaderOp::Set {
975 name: h.name,
976 value: h.value,
977 }),
978 Some(grpc_v2::header_op::Operation::Add(h)) => Some(HeaderOp::Add {
979 name: h.name,
980 value: h.value,
981 }),
982 Some(grpc_v2::header_op::Operation::Remove(name)) => Some(HeaderOp::Remove { name }),
983 None => None,
984 }
985}
986
987fn convert_metrics_from_grpc(report: grpc_v2::MetricsReport, agent_id: &str) -> crate::v2::MetricsReport {
988 use crate::v2::metrics::{CounterMetric, GaugeMetric, HistogramBucket, HistogramMetric};
989
990 let counters = report
991 .counters
992 .into_iter()
993 .map(|c| CounterMetric {
994 name: c.name,
995 help: c.help.filter(|s| !s.is_empty()),
996 labels: c.labels,
997 value: c.value,
998 })
999 .collect();
1000
1001 let gauges = report
1002 .gauges
1003 .into_iter()
1004 .map(|g| GaugeMetric {
1005 name: g.name,
1006 help: g.help.filter(|s| !s.is_empty()),
1007 labels: g.labels,
1008 value: g.value,
1009 })
1010 .collect();
1011
1012 let histograms = report
1013 .histograms
1014 .into_iter()
1015 .map(|h| HistogramMetric {
1016 name: h.name,
1017 help: h.help.filter(|s| !s.is_empty()),
1018 labels: h.labels,
1019 sum: h.sum,
1020 count: h.count,
1021 buckets: h
1022 .buckets
1023 .into_iter()
1024 .map(|b| HistogramBucket { le: b.le, count: b.count })
1025 .collect(),
1026 })
1027 .collect();
1028
1029 crate::v2::MetricsReport {
1030 agent_id: agent_id.to_string(),
1031 timestamp_ms: report.timestamp_ms,
1032 interval_ms: report.interval_ms,
1033 counters,
1034 gauges,
1035 histograms,
1036 }
1037}
1038
1039fn convert_config_update_from_grpc(update: grpc_v2::ConfigUpdateRequest) -> crate::v2::ConfigUpdateRequest {
1040 use crate::v2::control::{ConfigUpdateType, RuleDefinition};
1041
1042 let update_type = match update.update_type {
1043 Some(grpc_v2::config_update_request::UpdateType::RequestReload(_)) => {
1044 ConfigUpdateType::RequestReload
1045 }
1046 Some(grpc_v2::config_update_request::UpdateType::RuleUpdate(ru)) => {
1047 ConfigUpdateType::RuleUpdate {
1048 rule_set: ru.rule_set,
1049 rules: ru
1050 .rules
1051 .into_iter()
1052 .map(|r| RuleDefinition {
1053 id: r.id,
1054 priority: r.priority,
1055 definition: serde_json::from_str(&r.definition_json).unwrap_or_default(),
1056 enabled: r.enabled,
1057 description: r.description,
1058 tags: r.tags,
1059 })
1060 .collect(),
1061 remove_rules: ru.remove_rules,
1062 }
1063 }
1064 Some(grpc_v2::config_update_request::UpdateType::ListUpdate(lu)) => {
1065 ConfigUpdateType::ListUpdate {
1066 list_id: lu.list_id,
1067 add: lu.add,
1068 remove: lu.remove,
1069 }
1070 }
1071 Some(grpc_v2::config_update_request::UpdateType::RestartRequired(rr)) => {
1072 ConfigUpdateType::RestartRequired {
1073 reason: rr.reason,
1074 grace_period_ms: rr.grace_period_ms,
1075 }
1076 }
1077 Some(grpc_v2::config_update_request::UpdateType::ConfigError(ce)) => {
1078 ConfigUpdateType::ConfigError {
1079 error: ce.error,
1080 field: ce.field,
1081 }
1082 }
1083 None => ConfigUpdateType::RequestReload, };
1085
1086 crate::v2::ConfigUpdateRequest {
1087 update_type,
1088 request_id: update.request_id,
1089 timestamp_ms: update.timestamp_ms,
1090 }
1091}
1092
1093fn extract_correlation_id<T: serde::Serialize>(event: &T) -> String {
1094 if let Ok(value) = serde_json::to_value(event) {
1096 if let Some(metadata) = value.get("metadata") {
1097 if let Some(cid) = metadata.get("correlation_id").and_then(|v| v.as_str()) {
1098 return cid.to_string();
1099 }
1100 }
1101 if let Some(cid) = value.get("correlation_id").and_then(|v| v.as_str()) {
1102 return cid.to_string();
1103 }
1104 }
1105 uuid::Uuid::new_v4().to_string()
1106}
1107
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1114
#[cfg(test)]
mod tests {
    use super::*;

    /// Known codes map to their event type; an unknown code maps to `None`.
    #[test]
    fn test_event_type_conversion() {
        let cases = [
            (1, Some(EventType::RequestHeaders)),
            (2, Some(EventType::RequestBodyChunk)),
            (99, None),
        ];

        for (code, expected) in cases {
            assert_eq!(i32_to_event_type(code), expected);
        }
    }

    /// A top-level `correlation_id` field is returned as-is.
    #[test]
    fn test_extract_correlation_id() {
        #[derive(serde::Serialize)]
        struct TestEvent {
            correlation_id: String,
        }

        let event = TestEvent {
            correlation_id: String::from("test-123"),
        };

        let extracted = extract_correlation_id(&event);
        assert_eq!(extracted, "test-123");
    }
}