Skip to main content

sentinel_agent_protocol/v2/
client.rs

1//! Agent client implementation for Protocol v2.
2//!
3//! The v2 client supports bidirectional streaming with connection multiplexing,
4//! allowing multiple concurrent requests over a single connection.
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
11use tonic::transport::Channel;
12use tracing::{debug, info, trace, warn};
13
14use crate::grpc_v2::{self, agent_service_v2_client::AgentServiceV2Client, ProxyToAgent};
15use crate::headers::iter_flat;
16use crate::v2::pool::CHANNEL_BUFFER_SIZE;
17use crate::v2::{AgentCapabilities, PROTOCOL_VERSION_2};
18use crate::{AgentProtocolError, AgentResponse, Decision, EventType, HeaderOp};
19
/// Why an in-flight request is being cancelled.
///
/// Each variant maps to a numeric wire code (see `to_grpc`) that is placed in
/// the `reason` field of the cancel message sent to the agent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CancelReason {
    /// The downstream client disconnected.
    ClientDisconnect,
    /// The request ran past its deadline.
    Timeout,
    /// Another agent already blocked the request.
    BlockedByAgent,
    /// The upstream connection failed.
    UpstreamError,
    /// The proxy is shutting down.
    ProxyShutdown,
    /// Cancelled explicitly by a caller.
    Manual,
}

impl CancelReason {
    /// Wire code for this reason: 1 through 6, following declaration order.
    fn to_grpc(self) -> i32 {
        match self {
            Self::ClientDisconnect => 1,
            Self::Timeout => 2,
            Self::BlockedByAgent => 3,
            Self::UpstreamError => 4,
            Self::ProxyShutdown => 5,
            Self::Manual => 6,
        }
    }
}
49
/// Flow-control state of the connection to an agent.
///
/// The `Default` value is [`FlowState::Normal`].
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub enum FlowState {
    /// Requests flow normally.
    #[default]
    Normal,
    /// The agent has asked the proxy to pause.
    Paused,
    /// Finishing in-flight work; no new requests should be sent.
    Draining,
}
61
62/// Callback for metrics reports from agents.
63pub type MetricsCallback = Arc<dyn Fn(crate::v2::MetricsReport) + Send + Sync>;
64
65/// Callback for config update requests from agents.
66///
67/// The callback receives the agent ID and the config update request.
68/// It should return a response indicating whether the update was accepted.
69pub type ConfigUpdateCallback = Arc<
70    dyn Fn(String, crate::v2::ConfigUpdateRequest) -> crate::v2::ConfigUpdateResponse + Send + Sync,
71>;
72
73/// v2 agent client with connection multiplexing.
74///
75/// This client maintains a single bidirectional stream and multiplexes
76/// multiple requests over it using correlation IDs.
77///
78/// # Features
79///
80/// - **Connection multiplexing**: Multiple concurrent requests over one connection
81/// - **Cancellation support**: Cancel in-flight requests
82/// - **Flow control**: Backpressure handling when agent is overloaded
83/// - **Health tracking**: Monitor agent health status
84/// - **Metrics collection**: Receive and forward agent metrics
85pub struct AgentClientV2 {
86    /// Agent identifier
87    agent_id: String,
88    /// gRPC channel (for reconnection)
89    channel: Channel,
90    /// Request timeout
91    timeout: Duration,
92    /// Negotiated capabilities
93    capabilities: RwLock<Option<AgentCapabilities>>,
94    /// Negotiated protocol version
95    protocol_version: AtomicU64,
96    /// Pending requests by correlation ID
97    pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
98    /// Sender for outbound messages
99    outbound_tx: Mutex<Option<mpsc::Sender<ProxyToAgent>>>,
100    /// Sequence counter for pings
101    ping_sequence: AtomicU64,
102    /// Connection state
103    connected: RwLock<bool>,
104    /// Flow control state
105    flow_state: RwLock<FlowState>,
106    /// Last known health state
107    health_state: RwLock<i32>,
108    /// In-flight request count
109    in_flight: AtomicU64,
110    /// Callback for metrics reports
111    metrics_callback: Option<MetricsCallback>,
112    /// Callback for config update requests
113    config_update_callback: Option<ConfigUpdateCallback>,
114}
115
116impl AgentClientV2 {
117    /// Create a new v2 client.
118    pub async fn new(
119        agent_id: impl Into<String>,
120        endpoint: impl Into<String>,
121        timeout: Duration,
122    ) -> Result<Self, AgentProtocolError> {
123        let agent_id = agent_id.into();
124        let endpoint = endpoint.into();
125
126        debug!(agent_id = %agent_id, endpoint = %endpoint, "Creating v2 client");
127
128        let channel = Channel::from_shared(endpoint.clone())
129            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Invalid endpoint: {}", e)))?
130            .connect_timeout(timeout)
131            .timeout(timeout)
132            .connect()
133            .await
134            .map_err(|e| {
135                AgentProtocolError::ConnectionFailed(format!("Failed to connect: {}", e))
136            })?;
137
138        Ok(Self {
139            agent_id,
140            channel,
141            timeout,
142            capabilities: RwLock::new(None),
143            protocol_version: AtomicU64::new(1), // Default to v1 until handshake
144            pending: Arc::new(Mutex::new(HashMap::new())),
145            outbound_tx: Mutex::new(None),
146            ping_sequence: AtomicU64::new(0),
147            connected: RwLock::new(false),
148            flow_state: RwLock::new(FlowState::Normal),
149            health_state: RwLock::new(1), // HEALTHY
150            in_flight: AtomicU64::new(0),
151            metrics_callback: None,
152            config_update_callback: None,
153        })
154    }
155
156    /// Set the metrics callback for receiving agent metrics reports.
157    ///
158    /// This callback is invoked whenever the agent sends a metrics report
159    /// through the control stream. The callback should be fast and non-blocking.
160    pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
161        self.metrics_callback = Some(callback);
162    }
163
164    /// Set the config update callback for handling agent config requests.
165    ///
166    /// This callback is invoked whenever the agent sends a config update request
167    /// through the control stream (e.g., requesting a reload, reporting errors).
168    pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
169        self.config_update_callback = Some(callback);
170    }
171
172    /// Connect and perform handshake.
173    pub async fn connect(&self) -> Result<(), AgentProtocolError> {
174        let mut client = AgentServiceV2Client::new(self.channel.clone());
175
176        // Create bidirectional stream
177        let (tx, rx) = mpsc::channel::<ProxyToAgent>(CHANNEL_BUFFER_SIZE);
178        let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
179
180        let response_stream = client
181            .process_stream(rx_stream)
182            .await
183            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream failed: {}", e)))?;
184
185        let mut inbound = response_stream.into_inner();
186
187        // Send handshake
188        let handshake = ProxyToAgent {
189            message: Some(grpc_v2::proxy_to_agent::Message::Handshake(
190                grpc_v2::HandshakeRequest {
191                    supported_versions: vec![PROTOCOL_VERSION_2, 1],
192                    proxy_id: "sentinel-proxy".to_string(),
193                    proxy_version: env!("CARGO_PKG_VERSION").to_string(),
194                    config_json: "{}".to_string(),
195                },
196            )),
197        };
198
199        tx.send(handshake).await.map_err(|e| {
200            AgentProtocolError::ConnectionFailed(format!("Failed to send handshake: {}", e))
201        })?;
202
203        // Wait for handshake response
204        let handshake_resp = tokio::time::timeout(self.timeout, inbound.message())
205            .await
206            .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
207            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream error: {}", e)))?
208            .ok_or_else(|| {
209                AgentProtocolError::ConnectionFailed("Empty handshake response".to_string())
210            })?;
211
212        // Process handshake response
213        if let Some(grpc_v2::agent_to_proxy::Message::Handshake(resp)) = handshake_resp.message {
214            if !resp.success {
215                return Err(AgentProtocolError::ConnectionFailed(format!(
216                    "Handshake failed: {}",
217                    resp.error.unwrap_or_default()
218                )));
219            }
220
221            self.protocol_version
222                .store(resp.protocol_version as u64, Ordering::SeqCst);
223
224            if let Some(caps) = resp.capabilities {
225                let capabilities = convert_capabilities_from_grpc(caps);
226                *self.capabilities.write().await = Some(capabilities);
227            }
228
229            info!(
230                agent_id = %self.agent_id,
231                protocol_version = resp.protocol_version,
232                "v2 handshake successful"
233            );
234        } else {
235            return Err(AgentProtocolError::ConnectionFailed(
236                "Invalid handshake response".to_string(),
237            ));
238        }
239
240        // Store outbound sender
241        *self.outbound_tx.lock().await = Some(tx);
242        *self.connected.write().await = true;
243
244        // Spawn background task to handle incoming messages
245        let pending = Arc::clone(&self.pending);
246        let agent_id = self.agent_id.clone();
247        let flow_state = Arc::new(RwLock::new(FlowState::Normal));
248        let health_state = Arc::new(RwLock::new(1i32));
249        let _in_flight = Arc::new(AtomicU64::new(0));
250
251        // Share state with the spawned task
252        let flow_state_clone = Arc::clone(&flow_state);
253        let health_state_clone = Arc::clone(&health_state);
254        let metrics_callback = self.metrics_callback.clone();
255        let config_update_callback = self.config_update_callback.clone();
256
257        tokio::spawn(async move {
258            while let Ok(Some(msg)) = inbound.message().await {
259                match msg.message {
260                    Some(grpc_v2::agent_to_proxy::Message::Response(resp)) => {
261                        let correlation_id = resp.correlation_id.clone();
262                        if let Some(sender) = pending.lock().await.remove(&correlation_id) {
263                            let response = convert_response_from_grpc(resp);
264                            let _ = sender.send(response);
265                        } else {
266                            warn!(
267                                agent_id = %agent_id,
268                                correlation_id = %correlation_id,
269                                "Received response for unknown correlation ID"
270                            );
271                        }
272                    }
273                    Some(grpc_v2::agent_to_proxy::Message::Health(health)) => {
274                        trace!(
275                            agent_id = %agent_id,
276                            state = health.state,
277                            "Received health status"
278                        );
279                        *health_state_clone.write().await = health.state;
280                    }
281                    Some(grpc_v2::agent_to_proxy::Message::Metrics(metrics)) => {
282                        trace!(
283                            agent_id = %agent_id,
284                            counters = metrics.counters.len(),
285                            gauges = metrics.gauges.len(),
286                            histograms = metrics.histograms.len(),
287                            "Received metrics report"
288                        );
289                        if let Some(ref callback) = metrics_callback {
290                            let report = convert_metrics_from_grpc(metrics, &agent_id);
291                            callback(report);
292                        }
293                    }
294                    Some(grpc_v2::agent_to_proxy::Message::FlowControl(fc)) => {
295                        // Handle flow control signals
296                        let new_state = match fc.action {
297                            1 => FlowState::Paused, // PAUSE
298                            2 => FlowState::Normal, // RESUME
299                            _ => FlowState::Normal,
300                        };
301                        debug!(
302                            agent_id = %agent_id,
303                            action = fc.action,
304                            correlation_id = ?fc.correlation_id,
305                            "Received flow control signal"
306                        );
307                        *flow_state_clone.write().await = new_state;
308                    }
309                    Some(grpc_v2::agent_to_proxy::Message::Pong(pong)) => {
310                        trace!(
311                            agent_id = %agent_id,
312                            sequence = pong.sequence,
313                            latency_ms = pong.timestamp_ms.saturating_sub(pong.ping_timestamp_ms),
314                            "Received pong"
315                        );
316                    }
317                    Some(grpc_v2::agent_to_proxy::Message::ConfigUpdate(update)) => {
318                        debug!(
319                            agent_id = %agent_id,
320                            request_id = %update.request_id,
321                            "Received config update request from agent"
322                        );
323                        if let Some(ref callback) = config_update_callback {
324                            let request = convert_config_update_from_grpc(update);
325                            let _response = callback(agent_id.clone(), request);
326                            // Note: Response would be sent via control stream if we had one
327                            // For now, the callback handles the request and logs/processes it
328                        }
329                    }
330                    Some(grpc_v2::agent_to_proxy::Message::Log(log_msg)) => {
331                        // Handle log messages from agent
332                        match log_msg.level {
333                            1 => {
334                                trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent debug log")
335                            }
336                            2 => {
337                                debug!(agent_id = %agent_id, msg = %log_msg.message, "Agent info log")
338                            }
339                            3 => {
340                                warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent warning")
341                            }
342                            4 => warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent error"),
343                            _ => trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent log"),
344                        }
345                    }
346                    _ => {}
347                }
348            }
349
350            debug!(agent_id = %agent_id, "Response handler ended");
351        });
352
353        Ok(())
354    }
355
356    /// Send a request headers event and wait for response.
357    pub async fn send_request_headers(
358        &self,
359        correlation_id: &str,
360        event: &crate::RequestHeadersEvent,
361    ) -> Result<AgentResponse, AgentProtocolError> {
362        let msg = ProxyToAgent {
363            message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
364                convert_request_headers_to_grpc(event),
365            )),
366        };
367
368        self.send_and_wait(correlation_id, msg).await
369    }
370
371    /// Send a request body chunk event and wait for response.
372    ///
373    /// For streaming body inspection, chunks are sent sequentially with
374    /// increasing `chunk_index`. The agent responds after processing each chunk.
375    pub async fn send_request_body_chunk(
376        &self,
377        correlation_id: &str,
378        event: &crate::RequestBodyChunkEvent,
379    ) -> Result<AgentResponse, AgentProtocolError> {
380        let msg = ProxyToAgent {
381            message: Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(
382                convert_body_chunk_to_grpc(event),
383            )),
384        };
385
386        self.send_and_wait(correlation_id, msg).await
387    }
388
389    /// Send a response headers event and wait for response.
390    ///
391    /// Called when upstream response headers are received, allowing the agent
392    /// to inspect/modify response headers before they're sent to the client.
393    pub async fn send_response_headers(
394        &self,
395        correlation_id: &str,
396        event: &crate::ResponseHeadersEvent,
397    ) -> Result<AgentResponse, AgentProtocolError> {
398        let msg = ProxyToAgent {
399            message: Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(
400                convert_response_headers_to_grpc(event),
401            )),
402        };
403
404        self.send_and_wait(correlation_id, msg).await
405    }
406
407    /// Send a response body chunk event and wait for response.
408    ///
409    /// For streaming response body inspection, chunks are sent sequentially.
410    /// The agent can inspect and optionally modify response body data.
411    pub async fn send_response_body_chunk(
412        &self,
413        correlation_id: &str,
414        event: &crate::ResponseBodyChunkEvent,
415    ) -> Result<AgentResponse, AgentProtocolError> {
416        let msg = ProxyToAgent {
417            message: Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(
418                convert_response_body_chunk_to_grpc(event),
419            )),
420        };
421
422        self.send_and_wait(correlation_id, msg).await
423    }
424
425    /// Send any event type and wait for response.
426    pub async fn send_event<T: serde::Serialize>(
427        &self,
428        event_type: EventType,
429        event: &T,
430    ) -> Result<AgentResponse, AgentProtocolError> {
431        // For compatibility, extract correlation_id from event
432        let correlation_id = extract_correlation_id(event);
433
434        let msg = match event_type {
435            EventType::RequestHeaders => {
436                if let Ok(e) = serde_json::from_value::<crate::RequestHeadersEvent>(
437                    serde_json::to_value(event).unwrap_or_default(),
438                ) {
439                    ProxyToAgent {
440                        message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
441                            convert_request_headers_to_grpc(&e),
442                        )),
443                    }
444                } else {
445                    return Err(AgentProtocolError::InvalidMessage(
446                        "Failed to convert event".to_string(),
447                    ));
448                }
449            }
450            _ => {
451                // Fall back to v1 for unsupported event types
452                return Err(AgentProtocolError::InvalidMessage(format!(
453                    "Event type {:?} not yet supported in v2 streaming mode",
454                    event_type
455                )));
456            }
457        };
458
459        self.send_and_wait(&correlation_id, msg).await
460    }
461
462    /// Send a message and wait for response.
463    async fn send_and_wait(
464        &self,
465        correlation_id: &str,
466        msg: ProxyToAgent,
467    ) -> Result<AgentResponse, AgentProtocolError> {
468        // Create response channel
469        let (tx, rx) = oneshot::channel();
470
471        // Register pending request
472        self.pending
473            .lock()
474            .await
475            .insert(correlation_id.to_string(), tx);
476
477        // Send message
478        {
479            let outbound = self.outbound_tx.lock().await;
480            if let Some(sender) = outbound.as_ref() {
481                sender.send(msg).await.map_err(|e| {
482                    AgentProtocolError::ConnectionFailed(format!("Send failed: {}", e))
483                })?;
484            } else {
485                return Err(AgentProtocolError::ConnectionFailed(
486                    "Not connected".to_string(),
487                ));
488            }
489        }
490
491        // Wait for response with timeout
492        match tokio::time::timeout(self.timeout, rx).await {
493            Ok(Ok(response)) => Ok(response),
494            Ok(Err(_)) => {
495                self.pending.lock().await.remove(correlation_id);
496                Err(AgentProtocolError::ConnectionFailed(
497                    "Response channel closed".to_string(),
498                ))
499            }
500            Err(_) => {
501                self.pending.lock().await.remove(correlation_id);
502                Err(AgentProtocolError::Timeout(self.timeout))
503            }
504        }
505    }
506
507    /// Send a ping and measure latency.
508    pub async fn ping(&self) -> Result<Duration, AgentProtocolError> {
509        let sequence = self.ping_sequence.fetch_add(1, Ordering::SeqCst);
510        let timestamp_ms = now_ms();
511
512        let msg = ProxyToAgent {
513            message: Some(grpc_v2::proxy_to_agent::Message::Ping(grpc_v2::Ping {
514                sequence,
515                timestamp_ms,
516            })),
517        };
518
519        let outbound = self.outbound_tx.lock().await;
520        if let Some(sender) = outbound.as_ref() {
521            sender
522                .send(msg)
523                .await
524                .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Ping failed: {}", e)))?;
525        }
526
527        // Note: In a full implementation, we'd track pong responses
528        // For now, just return a placeholder
529        Ok(Duration::from_millis(0))
530    }
531
532    /// Get negotiated protocol version.
533    pub fn protocol_version(&self) -> u32 {
534        self.protocol_version.load(Ordering::SeqCst) as u32
535    }
536
537    /// Get agent capabilities.
538    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
539        self.capabilities.read().await.clone()
540    }
541
542    /// Check if client is connected.
543    pub async fn is_connected(&self) -> bool {
544        *self.connected.read().await
545    }
546
547    /// Close the connection.
548    pub async fn close(&self) -> Result<(), AgentProtocolError> {
549        *self.outbound_tx.lock().await = None;
550        *self.connected.write().await = false;
551        Ok(())
552    }
553
554    /// Cancel an in-flight request.
555    ///
556    /// Sends a cancellation message to the agent and removes the request from
557    /// the pending map. The agent should stop processing and clean up resources.
558    pub async fn cancel_request(
559        &self,
560        correlation_id: &str,
561        reason: CancelReason,
562    ) -> Result<(), AgentProtocolError> {
563        // Remove from pending (will cause the waiter to receive an error)
564        self.pending.lock().await.remove(correlation_id);
565
566        // Send cancel message to agent
567        let msg = ProxyToAgent {
568            message: Some(grpc_v2::proxy_to_agent::Message::Cancel(
569                grpc_v2::CancelRequest {
570                    correlation_id: correlation_id.to_string(),
571                    reason: reason.to_grpc(),
572                    timestamp_ms: now_ms(),
573                    blocking_agent_id: None,
574                    manual_reason: None,
575                },
576            )),
577        };
578
579        let outbound = self.outbound_tx.lock().await;
580        if let Some(sender) = outbound.as_ref() {
581            sender.send(msg).await.map_err(|e| {
582                AgentProtocolError::ConnectionFailed(format!("Cancel send failed: {}", e))
583            })?;
584        }
585
586        debug!(
587            agent_id = %self.agent_id,
588            correlation_id = %correlation_id,
589            reason = ?reason,
590            "Cancelled request"
591        );
592
593        Ok(())
594    }
595
596    /// Cancel all in-flight requests.
597    ///
598    /// Used during shutdown or when the upstream connection fails.
599    pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
600        let correlation_ids: Vec<String> = {
601            let pending = self.pending.lock().await;
602            pending.keys().cloned().collect()
603        };
604
605        let count = correlation_ids.len();
606        for cid in correlation_ids {
607            let _ = self.cancel_request(&cid, reason).await;
608        }
609
610        debug!(
611            agent_id = %self.agent_id,
612            count = count,
613            reason = ?reason,
614            "Cancelled all requests"
615        );
616
617        Ok(count)
618    }
619
620    /// Get current flow control state.
621    pub async fn flow_state(&self) -> FlowState {
622        *self.flow_state.read().await
623    }
624
625    /// Check if the agent is accepting new requests.
626    ///
627    /// Returns false if the agent has requested a pause or is draining.
628    pub async fn can_accept_requests(&self) -> bool {
629        matches!(*self.flow_state.read().await, FlowState::Normal)
630    }
631
632    /// Wait for flow control to allow new requests.
633    ///
634    /// If the agent has requested a pause, this will wait until it resumes
635    /// or the timeout expires.
636    pub async fn wait_for_flow_control(&self, timeout: Duration) -> Result<(), AgentProtocolError> {
637        let deadline = tokio::time::Instant::now() + timeout;
638
639        loop {
640            if self.can_accept_requests().await {
641                return Ok(());
642            }
643
644            if tokio::time::Instant::now() >= deadline {
645                return Err(AgentProtocolError::Timeout(timeout));
646            }
647
648            // Poll every 10ms
649            tokio::time::sleep(Duration::from_millis(10)).await;
650        }
651    }
652
653    /// Get current health state.
654    ///
655    /// Returns the numeric health state:
656    /// - 1: Healthy
657    /// - 2: Degraded
658    /// - 3: Draining
659    /// - 4: Unhealthy
660    pub async fn health_state(&self) -> i32 {
661        *self.health_state.read().await
662    }
663
664    /// Check if the agent is healthy.
665    pub async fn is_healthy(&self) -> bool {
666        *self.health_state.read().await == 1
667    }
668
669    /// Get the number of in-flight requests.
670    pub fn in_flight_count(&self) -> u64 {
671        self.in_flight.load(Ordering::Relaxed)
672    }
673
674    // =========================================================================
675    // Control Stream Methods
676    // =========================================================================
677
678    /// Send a configuration update to the agent.
679    pub async fn send_configure(
680        &self,
681        config: serde_json::Value,
682        version: Option<String>,
683    ) -> Result<(), AgentProtocolError> {
684        let msg = ProxyToAgent {
685            message: Some(grpc_v2::proxy_to_agent::Message::Configure(
686                grpc_v2::ConfigureEvent {
687                    config_json: serde_json::to_string(&config).unwrap_or_default(),
688                    config_version: version,
689                    is_initial: false,
690                    timestamp_ms: now_ms(),
691                },
692            )),
693        };
694
695        let outbound = self.outbound_tx.lock().await;
696        if let Some(sender) = outbound.as_ref() {
697            sender.send(msg).await.map_err(|e| {
698                AgentProtocolError::ConnectionFailed(format!("Configure send failed: {}", e))
699            })?;
700        } else {
701            return Err(AgentProtocolError::ConnectionFailed(
702                "Not connected".to_string(),
703            ));
704        }
705
706        debug!(agent_id = %self.agent_id, "Sent configuration update");
707        Ok(())
708    }
709
710    /// Request the agent to shut down.
711    pub async fn send_shutdown(
712        &self,
713        reason: ShutdownReason,
714        grace_period_ms: u64,
715    ) -> Result<(), AgentProtocolError> {
716        info!(
717            agent_id = %self.agent_id,
718            reason = ?reason,
719            grace_period_ms = grace_period_ms,
720            "Requesting agent shutdown"
721        );
722
723        // For shutdown, we should cancel all pending requests first
724        let _ = self.cancel_all(CancelReason::ProxyShutdown).await;
725
726        // Close the connection
727        self.close().await
728    }
729
730    /// Request the agent to drain (stop accepting new requests).
731    pub async fn send_drain(
732        &self,
733        duration_ms: u64,
734        reason: DrainReason,
735    ) -> Result<(), AgentProtocolError> {
736        info!(
737            agent_id = %self.agent_id,
738            duration_ms = duration_ms,
739            reason = ?reason,
740            "Requesting agent drain"
741        );
742
743        // Set flow state to draining
744        *self.flow_state.write().await = FlowState::Draining;
745
746        Ok(())
747    }
748
749    /// Get agent identifier.
750    pub fn agent_id(&self) -> &str {
751        &self.agent_id
752    }
753}
754
/// Reason given when the proxy asks an agent to shut down.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ShutdownReason {
    /// Orderly shutdown.
    Graceful,
    /// Shutdown without waiting.
    Immediate,
    /// Shutdown triggered by a configuration reload.
    ConfigReload,
    /// Shutdown triggered by an upgrade.
    Upgrade,
}

/// Reason given when the proxy asks an agent to drain.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DrainReason {
    /// Drain triggered by a configuration reload.
    ConfigReload,
    /// Drain for maintenance.
    Maintenance,
    /// Drain because a health check failed.
    HealthCheckFailed,
    /// Drain requested explicitly by a caller.
    Manual,
}
772
773// =============================================================================
774// Conversion Helpers
775// =============================================================================
776
777fn convert_capabilities_from_grpc(caps: grpc_v2::AgentCapabilities) -> AgentCapabilities {
778    use crate::v2::{AgentFeatures, AgentLimits, HealthConfig};
779
780    let features = caps
781        .features
782        .map(|f| AgentFeatures {
783            streaming_body: f.streaming_body,
784            websocket: f.websocket,
785            guardrails: f.guardrails,
786            config_push: f.config_push,
787            metrics_export: f.metrics_export,
788            concurrent_requests: f.concurrent_requests,
789            cancellation: f.cancellation,
790            flow_control: f.flow_control,
791            health_reporting: f.health_reporting,
792        })
793        .unwrap_or_default();
794
795    let limits = caps
796        .limits
797        .map(|l| AgentLimits {
798            max_body_size: l.max_body_size as usize,
799            max_concurrency: l.max_concurrency,
800            preferred_chunk_size: l.preferred_chunk_size as usize,
801            max_memory: l.max_memory.map(|m| m as usize),
802            max_processing_time_ms: l.max_processing_time_ms,
803        })
804        .unwrap_or_default();
805
806    let health = caps
807        .health_config
808        .map(|h| HealthConfig {
809            report_interval_ms: h.report_interval_ms,
810            include_load_metrics: h.include_load_metrics,
811            include_resource_metrics: h.include_resource_metrics,
812        })
813        .unwrap_or_default();
814
815    AgentCapabilities {
816        protocol_version: caps.protocol_version,
817        agent_id: caps.agent_id,
818        name: caps.name,
819        version: caps.version,
820        supported_events: caps
821            .supported_events
822            .into_iter()
823            .filter_map(i32_to_event_type)
824            .collect(),
825        features,
826        limits,
827        health,
828    }
829}
830
831fn i32_to_event_type(i: i32) -> Option<EventType> {
832    match i {
833        1 => Some(EventType::RequestHeaders),
834        2 => Some(EventType::RequestBodyChunk),
835        3 => Some(EventType::ResponseHeaders),
836        4 => Some(EventType::ResponseBodyChunk),
837        5 => Some(EventType::RequestComplete),
838        6 => Some(EventType::WebSocketFrame),
839        7 => Some(EventType::GuardrailInspect),
840        8 => Some(EventType::Configure),
841        _ => None,
842    }
843}
844
845fn convert_request_headers_to_grpc(
846    event: &crate::RequestHeadersEvent,
847) -> grpc_v2::RequestHeadersEvent {
848    let metadata = Some(grpc_v2::RequestMetadata {
849        correlation_id: event.metadata.correlation_id.clone(),
850        request_id: event.metadata.request_id.clone(),
851        client_ip: event.metadata.client_ip.clone(),
852        client_port: event.metadata.client_port as u32,
853        server_name: event.metadata.server_name.clone(),
854        protocol: event.metadata.protocol.clone(),
855        tls_version: event.metadata.tls_version.clone(),
856        route_id: event.metadata.route_id.clone(),
857        upstream_id: event.metadata.upstream_id.clone(),
858        timestamp_ms: now_ms(),
859        traceparent: event.metadata.traceparent.clone(),
860    });
861
862    // Use iter_flat helper for cleaner iteration over flattened headers
863    let headers: Vec<grpc_v2::Header> = iter_flat(&event.headers)
864        .map(|(name, value)| grpc_v2::Header {
865            name: name.to_string(),
866            value: value.to_string(),
867        })
868        .collect();
869
870    grpc_v2::RequestHeadersEvent {
871        metadata,
872        method: event.method.clone(),
873        uri: event.uri.clone(),
874        http_version: "HTTP/1.1".to_string(),
875        headers,
876    }
877}
878
879fn convert_body_chunk_to_grpc(event: &crate::RequestBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
880    // Convert through binary type to centralize the base64 decode logic
881    let binary: crate::BinaryRequestBodyChunkEvent = event.into();
882    convert_binary_body_chunk_to_grpc(&binary)
883}
884
885/// Convert binary body chunk directly to gRPC (no base64 decode needed).
886///
887/// This is the efficient path for binary transports (UDS binary mode, direct Bytes).
888fn convert_binary_body_chunk_to_grpc(
889    event: &crate::BinaryRequestBodyChunkEvent,
890) -> grpc_v2::BodyChunkEvent {
891    grpc_v2::BodyChunkEvent {
892        correlation_id: event.correlation_id.clone(),
893        chunk_index: event.chunk_index,
894        data: event.data.to_vec(), // Bytes → Vec<u8> (single copy, no decode)
895        is_last: event.is_last,
896        total_size: event.total_size.map(|s| s as u64),
897        bytes_transferred: event.bytes_received as u64,
898        proxy_buffer_available: 0, // Will be set by flow control
899        timestamp_ms: now_ms(),
900    }
901}
902
903fn convert_response_headers_to_grpc(
904    event: &crate::ResponseHeadersEvent,
905) -> grpc_v2::ResponseHeadersEvent {
906    // Use iter_flat helper for cleaner iteration over flattened headers
907    let headers: Vec<grpc_v2::Header> = iter_flat(&event.headers)
908        .map(|(name, value)| grpc_v2::Header {
909            name: name.to_string(),
910            value: value.to_string(),
911        })
912        .collect();
913
914    grpc_v2::ResponseHeadersEvent {
915        correlation_id: event.correlation_id.clone(),
916        status_code: event.status as u32,
917        headers,
918    }
919}
920
921fn convert_response_body_chunk_to_grpc(
922    event: &crate::ResponseBodyChunkEvent,
923) -> grpc_v2::BodyChunkEvent {
924    // Convert through binary type to centralize the base64 decode logic
925    let binary: crate::BinaryResponseBodyChunkEvent = event.into();
926    convert_binary_response_body_chunk_to_grpc(&binary)
927}
928
929/// Convert binary response body chunk directly to gRPC (no base64 decode needed).
930///
931/// This is the efficient path for binary transports (UDS binary mode, direct Bytes).
932fn convert_binary_response_body_chunk_to_grpc(
933    event: &crate::BinaryResponseBodyChunkEvent,
934) -> grpc_v2::BodyChunkEvent {
935    grpc_v2::BodyChunkEvent {
936        correlation_id: event.correlation_id.clone(),
937        chunk_index: event.chunk_index,
938        data: event.data.to_vec(), // Bytes → Vec<u8> (single copy, no decode)
939        is_last: event.is_last,
940        total_size: event.total_size.map(|s| s as u64),
941        bytes_transferred: event.bytes_sent as u64,
942        proxy_buffer_available: 0,
943        timestamp_ms: now_ms(),
944    }
945}
946
947fn convert_response_from_grpc(resp: grpc_v2::AgentResponse) -> AgentResponse {
948    let decision = match resp.decision {
949        Some(grpc_v2::agent_response::Decision::Allow(_)) => Decision::Allow,
950        Some(grpc_v2::agent_response::Decision::Block(b)) => Decision::Block {
951            status: b.status as u16,
952            body: b.body,
953            headers: if b.headers.is_empty() {
954                None
955            } else {
956                Some(b.headers.into_iter().map(|h| (h.name, h.value)).collect())
957            },
958        },
959        Some(grpc_v2::agent_response::Decision::Redirect(r)) => Decision::Redirect {
960            url: r.url,
961            status: r.status as u16,
962        },
963        Some(grpc_v2::agent_response::Decision::Challenge(c)) => Decision::Challenge {
964            challenge_type: c.challenge_type,
965            params: c.params,
966        },
967        None => Decision::Allow,
968    };
969
970    let request_headers: Vec<HeaderOp> = resp
971        .request_headers
972        .into_iter()
973        .filter_map(convert_header_op_from_grpc)
974        .collect();
975
976    let response_headers: Vec<HeaderOp> = resp
977        .response_headers
978        .into_iter()
979        .filter_map(convert_header_op_from_grpc)
980        .collect();
981
982    let audit = resp
983        .audit
984        .map(|a| crate::AuditMetadata {
985            tags: a.tags,
986            rule_ids: a.rule_ids,
987            confidence: a.confidence,
988            reason_codes: a.reason_codes,
989            custom: a
990                .custom
991                .into_iter()
992                .map(|(k, v)| (k, serde_json::Value::String(v)))
993                .collect(),
994        })
995        .unwrap_or_default();
996
997    AgentResponse {
998        version: PROTOCOL_VERSION_2,
999        decision,
1000        request_headers,
1001        response_headers,
1002        routing_metadata: HashMap::new(),
1003        audit,
1004        needs_more: resp.needs_more,
1005        request_body_mutation: None,
1006        response_body_mutation: None,
1007        websocket_decision: None,
1008    }
1009}
1010
1011fn convert_header_op_from_grpc(op: grpc_v2::HeaderOp) -> Option<HeaderOp> {
1012    match op.operation {
1013        Some(grpc_v2::header_op::Operation::Set(h)) => Some(HeaderOp::Set {
1014            name: h.name,
1015            value: h.value,
1016        }),
1017        Some(grpc_v2::header_op::Operation::Add(h)) => Some(HeaderOp::Add {
1018            name: h.name,
1019            value: h.value,
1020        }),
1021        Some(grpc_v2::header_op::Operation::Remove(name)) => Some(HeaderOp::Remove { name }),
1022        None => None,
1023    }
1024}
1025
1026fn convert_metrics_from_grpc(
1027    report: grpc_v2::MetricsReport,
1028    agent_id: &str,
1029) -> crate::v2::MetricsReport {
1030    use crate::v2::metrics::{CounterMetric, GaugeMetric, HistogramBucket, HistogramMetric};
1031
1032    let counters = report
1033        .counters
1034        .into_iter()
1035        .map(|c| CounterMetric {
1036            name: c.name,
1037            help: c.help.filter(|s| !s.is_empty()),
1038            labels: c.labels,
1039            value: c.value,
1040        })
1041        .collect();
1042
1043    let gauges = report
1044        .gauges
1045        .into_iter()
1046        .map(|g| GaugeMetric {
1047            name: g.name,
1048            help: g.help.filter(|s| !s.is_empty()),
1049            labels: g.labels,
1050            value: g.value,
1051        })
1052        .collect();
1053
1054    let histograms = report
1055        .histograms
1056        .into_iter()
1057        .map(|h| HistogramMetric {
1058            name: h.name,
1059            help: h.help.filter(|s| !s.is_empty()),
1060            labels: h.labels,
1061            sum: h.sum,
1062            count: h.count,
1063            buckets: h
1064                .buckets
1065                .into_iter()
1066                .map(|b| HistogramBucket {
1067                    le: b.le,
1068                    count: b.count,
1069                })
1070                .collect(),
1071        })
1072        .collect();
1073
1074    crate::v2::MetricsReport {
1075        agent_id: agent_id.to_string(),
1076        timestamp_ms: report.timestamp_ms,
1077        interval_ms: report.interval_ms,
1078        counters,
1079        gauges,
1080        histograms,
1081    }
1082}
1083
1084fn convert_config_update_from_grpc(
1085    update: grpc_v2::ConfigUpdateRequest,
1086) -> crate::v2::ConfigUpdateRequest {
1087    use crate::v2::control::{ConfigUpdateType, RuleDefinition};
1088
1089    let update_type = match update.update_type {
1090        Some(grpc_v2::config_update_request::UpdateType::RequestReload(_)) => {
1091            ConfigUpdateType::RequestReload
1092        }
1093        Some(grpc_v2::config_update_request::UpdateType::RuleUpdate(ru)) => {
1094            ConfigUpdateType::RuleUpdate {
1095                rule_set: ru.rule_set,
1096                rules: ru
1097                    .rules
1098                    .into_iter()
1099                    .map(|r| RuleDefinition {
1100                        id: r.id,
1101                        priority: r.priority,
1102                        definition: serde_json::from_str(&r.definition_json).unwrap_or_default(),
1103                        enabled: r.enabled,
1104                        description: r.description,
1105                        tags: r.tags,
1106                    })
1107                    .collect(),
1108                remove_rules: ru.remove_rules,
1109            }
1110        }
1111        Some(grpc_v2::config_update_request::UpdateType::ListUpdate(lu)) => {
1112            ConfigUpdateType::ListUpdate {
1113                list_id: lu.list_id,
1114                add: lu.add,
1115                remove: lu.remove,
1116            }
1117        }
1118        Some(grpc_v2::config_update_request::UpdateType::RestartRequired(rr)) => {
1119            ConfigUpdateType::RestartRequired {
1120                reason: rr.reason,
1121                grace_period_ms: rr.grace_period_ms,
1122            }
1123        }
1124        Some(grpc_v2::config_update_request::UpdateType::ConfigError(ce)) => {
1125            ConfigUpdateType::ConfigError {
1126                error: ce.error,
1127                field: ce.field,
1128            }
1129        }
1130        None => ConfigUpdateType::RequestReload, // Default
1131    };
1132
1133    crate::v2::ConfigUpdateRequest {
1134        update_type,
1135        request_id: update.request_id,
1136        timestamp_ms: update.timestamp_ms,
1137    }
1138}
1139
1140fn extract_correlation_id<T: serde::Serialize>(event: &T) -> String {
1141    // Try to extract correlation_id from the serialized event
1142    if let Ok(value) = serde_json::to_value(event) {
1143        if let Some(metadata) = value.get("metadata") {
1144            if let Some(cid) = metadata.get("correlation_id").and_then(|v| v.as_str()) {
1145                return cid.to_string();
1146            }
1147        }
1148        if let Some(cid) = value.get("correlation_id").and_then(|v| v.as_str()) {
1149            return cid.to_string();
1150        }
1151    }
1152    uuid::Uuid::new_v4().to_string()
1153}
1154
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch, instead of
/// failing.
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
1161
#[cfg(test)]
mod tests {
    use super::*;

    /// Every known wire code maps to its variant; everything else is `None`.
    #[test]
    fn test_event_type_conversion() {
        assert_eq!(i32_to_event_type(1), Some(EventType::RequestHeaders));
        assert_eq!(i32_to_event_type(2), Some(EventType::RequestBodyChunk));
        assert_eq!(i32_to_event_type(3), Some(EventType::ResponseHeaders));
        assert_eq!(i32_to_event_type(4), Some(EventType::ResponseBodyChunk));
        assert_eq!(i32_to_event_type(5), Some(EventType::RequestComplete));
        assert_eq!(i32_to_event_type(6), Some(EventType::WebSocketFrame));
        assert_eq!(i32_to_event_type(7), Some(EventType::GuardrailInspect));
        assert_eq!(i32_to_event_type(8), Some(EventType::Configure));
        assert_eq!(i32_to_event_type(0), None);
        assert_eq!(i32_to_event_type(9), None);
        assert_eq!(i32_to_event_type(-1), None);
        assert_eq!(i32_to_event_type(99), None);
    }

    #[test]
    fn test_extract_correlation_id() {
        #[derive(serde::Serialize)]
        struct TestEvent {
            correlation_id: String,
        }

        let event = TestEvent {
            correlation_id: "test-123".to_string(),
        };

        assert_eq!(extract_correlation_id(&event), "test-123");
    }

    /// `metadata.correlation_id` takes precedence over a top-level field.
    #[test]
    fn test_extract_correlation_id_prefers_metadata() {
        #[derive(serde::Serialize)]
        struct Meta {
            correlation_id: String,
        }
        #[derive(serde::Serialize)]
        struct NestedEvent {
            metadata: Meta,
            correlation_id: String,
        }

        let event = NestedEvent {
            metadata: Meta {
                correlation_id: "from-metadata".to_string(),
            },
            correlation_id: "top-level".to_string(),
        };

        assert_eq!(extract_correlation_id(&event), "from-metadata");
    }

    /// Without any correlation id, a fresh UUID is generated per call.
    #[test]
    fn test_extract_correlation_id_falls_back_to_uuid() {
        #[derive(serde::Serialize)]
        struct NoIdEvent {
            other: u32,
        }

        let first = extract_correlation_id(&NoIdEvent { other: 1 });
        let second = extract_correlation_id(&NoIdEvent { other: 1 });

        assert!(uuid::Uuid::parse_str(&first).is_ok());
        assert_ne!(first, second);
    }

    /// A header op without an operation is dropped.
    #[test]
    fn test_header_op_without_operation_is_none() {
        assert!(convert_header_op_from_grpc(grpc_v2::HeaderOp::default()).is_none());
    }

    /// A response with no decision is treated as allow.
    #[test]
    fn test_missing_decision_defaults_to_allow() {
        let resp = convert_response_from_grpc(grpc_v2::AgentResponse::default());
        assert!(matches!(resp.decision, Decision::Allow));
        assert!(resp.request_headers.is_empty());
        assert!(resp.response_headers.is_empty());
    }

    #[test]
    fn test_now_ms_is_after_2020() {
        assert!(now_ms() > 1_600_000_000_000);
    }
}