sentinel_agent_protocol/v2/
client.rs

1//! Agent client implementation for Protocol v2.
2//!
3//! The v2 client supports bidirectional streaming with connection multiplexing,
4//! allowing multiple concurrent requests over a single connection.
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
11use tonic::transport::Channel;
12use tracing::{debug, info, trace, warn};
13
14use crate::grpc_v2::{self, agent_service_v2_client::AgentServiceV2Client, ProxyToAgent};
15use crate::headers::iter_flat;
16use crate::v2::pool::CHANNEL_BUFFER_SIZE;
17use crate::v2::{AgentCapabilities, PROTOCOL_VERSION_2};
18use crate::{AgentProtocolError, AgentResponse, Decision, EventType, HeaderOp};
19
/// Why an in-flight request is being cancelled.
///
/// The reason travels to the agent as the numeric `reason` field of a cancel message.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CancelReason {
    /// The downstream client went away.
    ClientDisconnect,
    /// The request ran past its deadline.
    Timeout,
    /// Another agent already blocked the request.
    BlockedByAgent,
    /// The upstream connection failed.
    UpstreamError,
    /// The proxy is shutting down.
    ProxyShutdown,
    /// Cancelled explicitly by the caller.
    Manual,
}

impl CancelReason {
    /// Wire code for this reason (1 through 6, in declaration order).
    fn to_grpc(self) -> i32 {
        match self {
            Self::ClientDisconnect => 1,
            Self::Timeout => 2,
            Self::BlockedByAgent => 3,
            Self::UpstreamError => 4,
            Self::ProxyShutdown => 5,
            Self::Manual => 6,
        }
    }
}
49
/// Flow-control state of an agent connection.
///
/// New requests are only admitted while the state is [`FlowState::Normal`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FlowState {
    /// Requests flow normally; this is the initial state.
    #[default]
    Normal,
    /// The agent asked the proxy to stop sending new requests.
    Paused,
    /// In-flight requests may finish, but no new ones are accepted.
    Draining,
}
61
62/// Callback for metrics reports from agents.
63pub type MetricsCallback = Arc<dyn Fn(crate::v2::MetricsReport) + Send + Sync>;
64
65/// Callback for config update requests from agents.
66///
67/// The callback receives the agent ID and the config update request.
68/// It should return a response indicating whether the update was accepted.
69pub type ConfigUpdateCallback =
70    Arc<dyn Fn(String, crate::v2::ConfigUpdateRequest) -> crate::v2::ConfigUpdateResponse + Send + Sync>;
71
72/// v2 agent client with connection multiplexing.
73///
74/// This client maintains a single bidirectional stream and multiplexes
75/// multiple requests over it using correlation IDs.
76///
77/// # Features
78///
79/// - **Connection multiplexing**: Multiple concurrent requests over one connection
80/// - **Cancellation support**: Cancel in-flight requests
81/// - **Flow control**: Backpressure handling when agent is overloaded
82/// - **Health tracking**: Monitor agent health status
83/// - **Metrics collection**: Receive and forward agent metrics
84pub struct AgentClientV2 {
85    /// Agent identifier
86    agent_id: String,
87    /// gRPC channel (for reconnection)
88    channel: Channel,
89    /// Request timeout
90    timeout: Duration,
91    /// Negotiated capabilities
92    capabilities: RwLock<Option<AgentCapabilities>>,
93    /// Negotiated protocol version
94    protocol_version: AtomicU64,
95    /// Pending requests by correlation ID
96    pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
97    /// Sender for outbound messages
98    outbound_tx: Mutex<Option<mpsc::Sender<ProxyToAgent>>>,
99    /// Sequence counter for pings
100    ping_sequence: AtomicU64,
101    /// Connection state
102    connected: RwLock<bool>,
103    /// Flow control state
104    flow_state: RwLock<FlowState>,
105    /// Last known health state
106    health_state: RwLock<i32>,
107    /// In-flight request count
108    in_flight: AtomicU64,
109    /// Callback for metrics reports
110    metrics_callback: Option<MetricsCallback>,
111    /// Callback for config update requests
112    config_update_callback: Option<ConfigUpdateCallback>,
113}
114
115impl AgentClientV2 {
116    /// Create a new v2 client.
117    pub async fn new(
118        agent_id: impl Into<String>,
119        endpoint: impl Into<String>,
120        timeout: Duration,
121    ) -> Result<Self, AgentProtocolError> {
122        let agent_id = agent_id.into();
123        let endpoint = endpoint.into();
124
125        debug!(agent_id = %agent_id, endpoint = %endpoint, "Creating v2 client");
126
127        let channel = Channel::from_shared(endpoint.clone())
128            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Invalid endpoint: {}", e)))?
129            .connect_timeout(timeout)
130            .timeout(timeout)
131            .connect()
132            .await
133            .map_err(|e| {
134                AgentProtocolError::ConnectionFailed(format!("Failed to connect: {}", e))
135            })?;
136
137        Ok(Self {
138            agent_id,
139            channel,
140            timeout,
141            capabilities: RwLock::new(None),
142            protocol_version: AtomicU64::new(1), // Default to v1 until handshake
143            pending: Arc::new(Mutex::new(HashMap::new())),
144            outbound_tx: Mutex::new(None),
145            ping_sequence: AtomicU64::new(0),
146            connected: RwLock::new(false),
147            flow_state: RwLock::new(FlowState::Normal),
148            health_state: RwLock::new(1), // HEALTHY
149            in_flight: AtomicU64::new(0),
150            metrics_callback: None,
151            config_update_callback: None,
152        })
153    }
154
155    /// Set the metrics callback for receiving agent metrics reports.
156    ///
157    /// This callback is invoked whenever the agent sends a metrics report
158    /// through the control stream. The callback should be fast and non-blocking.
159    pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
160        self.metrics_callback = Some(callback);
161    }
162
163    /// Set the config update callback for handling agent config requests.
164    ///
165    /// This callback is invoked whenever the agent sends a config update request
166    /// through the control stream (e.g., requesting a reload, reporting errors).
167    pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
168        self.config_update_callback = Some(callback);
169    }
170
171    /// Connect and perform handshake.
172    pub async fn connect(&self) -> Result<(), AgentProtocolError> {
173        let mut client = AgentServiceV2Client::new(self.channel.clone());
174
175        // Create bidirectional stream
176        let (tx, rx) = mpsc::channel::<ProxyToAgent>(CHANNEL_BUFFER_SIZE);
177        let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
178
179        let response_stream = client
180            .process_stream(rx_stream)
181            .await
182            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream failed: {}", e)))?;
183
184        let mut inbound = response_stream.into_inner();
185
186        // Send handshake
187        let handshake = ProxyToAgent {
188            message: Some(grpc_v2::proxy_to_agent::Message::Handshake(
189                grpc_v2::HandshakeRequest {
190                    supported_versions: vec![PROTOCOL_VERSION_2, 1],
191                    proxy_id: "sentinel-proxy".to_string(),
192                    proxy_version: env!("CARGO_PKG_VERSION").to_string(),
193                    config_json: "{}".to_string(),
194                },
195            )),
196        };
197
198        tx.send(handshake).await.map_err(|e| {
199            AgentProtocolError::ConnectionFailed(format!("Failed to send handshake: {}", e))
200        })?;
201
202        // Wait for handshake response
203        let handshake_resp = tokio::time::timeout(self.timeout, inbound.message())
204            .await
205            .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
206            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream error: {}", e)))?
207            .ok_or_else(|| AgentProtocolError::ConnectionFailed("Empty handshake response".to_string()))?;
208
209        // Process handshake response
210        if let Some(grpc_v2::agent_to_proxy::Message::Handshake(resp)) = handshake_resp.message {
211            if !resp.success {
212                return Err(AgentProtocolError::ConnectionFailed(format!(
213                    "Handshake failed: {}",
214                    resp.error.unwrap_or_default()
215                )));
216            }
217
218            self.protocol_version
219                .store(resp.protocol_version as u64, Ordering::SeqCst);
220
221            if let Some(caps) = resp.capabilities {
222                let capabilities = convert_capabilities_from_grpc(caps);
223                *self.capabilities.write().await = Some(capabilities);
224            }
225
226            info!(
227                agent_id = %self.agent_id,
228                protocol_version = resp.protocol_version,
229                "v2 handshake successful"
230            );
231        } else {
232            return Err(AgentProtocolError::ConnectionFailed(
233                "Invalid handshake response".to_string(),
234            ));
235        }
236
237        // Store outbound sender
238        *self.outbound_tx.lock().await = Some(tx);
239        *self.connected.write().await = true;
240
241        // Spawn background task to handle incoming messages
242        let pending = Arc::clone(&self.pending);
243        let agent_id = self.agent_id.clone();
244        let flow_state = Arc::new(RwLock::new(FlowState::Normal));
245        let health_state = Arc::new(RwLock::new(1i32));
246        let _in_flight = Arc::new(AtomicU64::new(0));
247
248        // Share state with the spawned task
249        let flow_state_clone = Arc::clone(&flow_state);
250        let health_state_clone = Arc::clone(&health_state);
251        let metrics_callback = self.metrics_callback.clone();
252        let config_update_callback = self.config_update_callback.clone();
253
254        tokio::spawn(async move {
255            while let Ok(Some(msg)) = inbound.message().await {
256                match msg.message {
257                    Some(grpc_v2::agent_to_proxy::Message::Response(resp)) => {
258                        let correlation_id = resp.correlation_id.clone();
259                        if let Some(sender) = pending.lock().await.remove(&correlation_id) {
260                            let response = convert_response_from_grpc(resp);
261                            let _ = sender.send(response);
262                        } else {
263                            warn!(
264                                agent_id = %agent_id,
265                                correlation_id = %correlation_id,
266                                "Received response for unknown correlation ID"
267                            );
268                        }
269                    }
270                    Some(grpc_v2::agent_to_proxy::Message::Health(health)) => {
271                        trace!(
272                            agent_id = %agent_id,
273                            state = health.state,
274                            "Received health status"
275                        );
276                        *health_state_clone.write().await = health.state;
277                    }
278                    Some(grpc_v2::agent_to_proxy::Message::Metrics(metrics)) => {
279                        trace!(
280                            agent_id = %agent_id,
281                            counters = metrics.counters.len(),
282                            gauges = metrics.gauges.len(),
283                            histograms = metrics.histograms.len(),
284                            "Received metrics report"
285                        );
286                        if let Some(ref callback) = metrics_callback {
287                            let report = convert_metrics_from_grpc(metrics, &agent_id);
288                            callback(report);
289                        }
290                    }
291                    Some(grpc_v2::agent_to_proxy::Message::FlowControl(fc)) => {
292                        // Handle flow control signals
293                        let new_state = match fc.action {
294                            1 => FlowState::Paused,  // PAUSE
295                            2 => FlowState::Normal,  // RESUME
296                            _ => FlowState::Normal,
297                        };
298                        debug!(
299                            agent_id = %agent_id,
300                            action = fc.action,
301                            correlation_id = ?fc.correlation_id,
302                            "Received flow control signal"
303                        );
304                        *flow_state_clone.write().await = new_state;
305                    }
306                    Some(grpc_v2::agent_to_proxy::Message::Pong(pong)) => {
307                        trace!(
308                            agent_id = %agent_id,
309                            sequence = pong.sequence,
310                            latency_ms = pong.timestamp_ms.saturating_sub(pong.ping_timestamp_ms),
311                            "Received pong"
312                        );
313                    }
314                    Some(grpc_v2::agent_to_proxy::Message::ConfigUpdate(update)) => {
315                        debug!(
316                            agent_id = %agent_id,
317                            request_id = %update.request_id,
318                            "Received config update request from agent"
319                        );
320                        if let Some(ref callback) = config_update_callback {
321                            let request = convert_config_update_from_grpc(update);
322                            let _response = callback(agent_id.clone(), request);
323                            // Note: Response would be sent via control stream if we had one
324                            // For now, the callback handles the request and logs/processes it
325                        }
326                    }
327                    Some(grpc_v2::agent_to_proxy::Message::Log(log_msg)) => {
328                        // Handle log messages from agent
329                        match log_msg.level {
330                            1 => trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent debug log"),
331                            2 => debug!(agent_id = %agent_id, msg = %log_msg.message, "Agent info log"),
332                            3 => warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent warning"),
333                            4 => warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent error"),
334                            _ => trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent log"),
335                        }
336                    }
337                    _ => {}
338                }
339            }
340
341            debug!(agent_id = %agent_id, "Response handler ended");
342        });
343
344        Ok(())
345    }
346
347    /// Send a request headers event and wait for response.
348    pub async fn send_request_headers(
349        &self,
350        correlation_id: &str,
351        event: &crate::RequestHeadersEvent,
352    ) -> Result<AgentResponse, AgentProtocolError> {
353        let msg = ProxyToAgent {
354            message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
355                convert_request_headers_to_grpc(event),
356            )),
357        };
358
359        self.send_and_wait(correlation_id, msg).await
360    }
361
362    /// Send a request body chunk event and wait for response.
363    ///
364    /// For streaming body inspection, chunks are sent sequentially with
365    /// increasing `chunk_index`. The agent responds after processing each chunk.
366    pub async fn send_request_body_chunk(
367        &self,
368        correlation_id: &str,
369        event: &crate::RequestBodyChunkEvent,
370    ) -> Result<AgentResponse, AgentProtocolError> {
371        let msg = ProxyToAgent {
372            message: Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(
373                convert_body_chunk_to_grpc(event),
374            )),
375        };
376
377        self.send_and_wait(correlation_id, msg).await
378    }
379
380    /// Send a response headers event and wait for response.
381    ///
382    /// Called when upstream response headers are received, allowing the agent
383    /// to inspect/modify response headers before they're sent to the client.
384    pub async fn send_response_headers(
385        &self,
386        correlation_id: &str,
387        event: &crate::ResponseHeadersEvent,
388    ) -> Result<AgentResponse, AgentProtocolError> {
389        let msg = ProxyToAgent {
390            message: Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(
391                convert_response_headers_to_grpc(event),
392            )),
393        };
394
395        self.send_and_wait(correlation_id, msg).await
396    }
397
398    /// Send a response body chunk event and wait for response.
399    ///
400    /// For streaming response body inspection, chunks are sent sequentially.
401    /// The agent can inspect and optionally modify response body data.
402    pub async fn send_response_body_chunk(
403        &self,
404        correlation_id: &str,
405        event: &crate::ResponseBodyChunkEvent,
406    ) -> Result<AgentResponse, AgentProtocolError> {
407        let msg = ProxyToAgent {
408            message: Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(
409                convert_response_body_chunk_to_grpc(event),
410            )),
411        };
412
413        self.send_and_wait(correlation_id, msg).await
414    }
415
416    /// Send any event type and wait for response.
417    pub async fn send_event<T: serde::Serialize>(
418        &self,
419        event_type: EventType,
420        event: &T,
421    ) -> Result<AgentResponse, AgentProtocolError> {
422        // For compatibility, extract correlation_id from event
423        let correlation_id = extract_correlation_id(event);
424
425        let msg = match event_type {
426            EventType::RequestHeaders => {
427                if let Ok(e) = serde_json::from_value::<crate::RequestHeadersEvent>(
428                    serde_json::to_value(event).unwrap_or_default(),
429                ) {
430                    ProxyToAgent {
431                        message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
432                            convert_request_headers_to_grpc(&e),
433                        )),
434                    }
435                } else {
436                    return Err(AgentProtocolError::InvalidMessage(
437                        "Failed to convert event".to_string(),
438                    ));
439                }
440            }
441            _ => {
442                // Fall back to v1 for unsupported event types
443                return Err(AgentProtocolError::InvalidMessage(format!(
444                    "Event type {:?} not yet supported in v2 streaming mode",
445                    event_type
446                )));
447            }
448        };
449
450        self.send_and_wait(&correlation_id, msg).await
451    }
452
453    /// Send a message and wait for response.
454    async fn send_and_wait(
455        &self,
456        correlation_id: &str,
457        msg: ProxyToAgent,
458    ) -> Result<AgentResponse, AgentProtocolError> {
459        // Create response channel
460        let (tx, rx) = oneshot::channel();
461
462        // Register pending request
463        self.pending
464            .lock()
465            .await
466            .insert(correlation_id.to_string(), tx);
467
468        // Send message
469        {
470            let outbound = self.outbound_tx.lock().await;
471            if let Some(sender) = outbound.as_ref() {
472                sender.send(msg).await.map_err(|e| {
473                    AgentProtocolError::ConnectionFailed(format!("Send failed: {}", e))
474                })?;
475            } else {
476                return Err(AgentProtocolError::ConnectionFailed(
477                    "Not connected".to_string(),
478                ));
479            }
480        }
481
482        // Wait for response with timeout
483        match tokio::time::timeout(self.timeout, rx).await {
484            Ok(Ok(response)) => Ok(response),
485            Ok(Err(_)) => {
486                self.pending.lock().await.remove(correlation_id);
487                Err(AgentProtocolError::ConnectionFailed(
488                    "Response channel closed".to_string(),
489                ))
490            }
491            Err(_) => {
492                self.pending.lock().await.remove(correlation_id);
493                Err(AgentProtocolError::Timeout(self.timeout))
494            }
495        }
496    }
497
498    /// Send a ping and measure latency.
499    pub async fn ping(&self) -> Result<Duration, AgentProtocolError> {
500        let sequence = self.ping_sequence.fetch_add(1, Ordering::SeqCst);
501        let timestamp_ms = now_ms();
502
503        let msg = ProxyToAgent {
504            message: Some(grpc_v2::proxy_to_agent::Message::Ping(grpc_v2::Ping {
505                sequence,
506                timestamp_ms,
507            })),
508        };
509
510        let outbound = self.outbound_tx.lock().await;
511        if let Some(sender) = outbound.as_ref() {
512            sender
513                .send(msg)
514                .await
515                .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Ping failed: {}", e)))?;
516        }
517
518        // Note: In a full implementation, we'd track pong responses
519        // For now, just return a placeholder
520        Ok(Duration::from_millis(0))
521    }
522
523    /// Get negotiated protocol version.
524    pub fn protocol_version(&self) -> u32 {
525        self.protocol_version.load(Ordering::SeqCst) as u32
526    }
527
528    /// Get agent capabilities.
529    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
530        self.capabilities.read().await.clone()
531    }
532
533    /// Check if client is connected.
534    pub async fn is_connected(&self) -> bool {
535        *self.connected.read().await
536    }
537
538    /// Close the connection.
539    pub async fn close(&self) -> Result<(), AgentProtocolError> {
540        *self.outbound_tx.lock().await = None;
541        *self.connected.write().await = false;
542        Ok(())
543    }
544
545    /// Cancel an in-flight request.
546    ///
547    /// Sends a cancellation message to the agent and removes the request from
548    /// the pending map. The agent should stop processing and clean up resources.
549    pub async fn cancel_request(
550        &self,
551        correlation_id: &str,
552        reason: CancelReason,
553    ) -> Result<(), AgentProtocolError> {
554        // Remove from pending (will cause the waiter to receive an error)
555        self.pending.lock().await.remove(correlation_id);
556
557        // Send cancel message to agent
558        let msg = ProxyToAgent {
559            message: Some(grpc_v2::proxy_to_agent::Message::Cancel(
560                grpc_v2::CancelRequest {
561                    correlation_id: correlation_id.to_string(),
562                    reason: reason.to_grpc(),
563                    timestamp_ms: now_ms(),
564                    blocking_agent_id: None,
565                    manual_reason: None,
566                },
567            )),
568        };
569
570        let outbound = self.outbound_tx.lock().await;
571        if let Some(sender) = outbound.as_ref() {
572            sender.send(msg).await.map_err(|e| {
573                AgentProtocolError::ConnectionFailed(format!("Cancel send failed: {}", e))
574            })?;
575        }
576
577        debug!(
578            agent_id = %self.agent_id,
579            correlation_id = %correlation_id,
580            reason = ?reason,
581            "Cancelled request"
582        );
583
584        Ok(())
585    }
586
587    /// Cancel all in-flight requests.
588    ///
589    /// Used during shutdown or when the upstream connection fails.
590    pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
591        let correlation_ids: Vec<String> = {
592            let pending = self.pending.lock().await;
593            pending.keys().cloned().collect()
594        };
595
596        let count = correlation_ids.len();
597        for cid in correlation_ids {
598            let _ = self.cancel_request(&cid, reason).await;
599        }
600
601        debug!(
602            agent_id = %self.agent_id,
603            count = count,
604            reason = ?reason,
605            "Cancelled all requests"
606        );
607
608        Ok(count)
609    }
610
611    /// Get current flow control state.
612    pub async fn flow_state(&self) -> FlowState {
613        *self.flow_state.read().await
614    }
615
616    /// Check if the agent is accepting new requests.
617    ///
618    /// Returns false if the agent has requested a pause or is draining.
619    pub async fn can_accept_requests(&self) -> bool {
620        matches!(*self.flow_state.read().await, FlowState::Normal)
621    }
622
623    /// Wait for flow control to allow new requests.
624    ///
625    /// If the agent has requested a pause, this will wait until it resumes
626    /// or the timeout expires.
627    pub async fn wait_for_flow_control(&self, timeout: Duration) -> Result<(), AgentProtocolError> {
628        let deadline = tokio::time::Instant::now() + timeout;
629
630        loop {
631            if self.can_accept_requests().await {
632                return Ok(());
633            }
634
635            if tokio::time::Instant::now() >= deadline {
636                return Err(AgentProtocolError::Timeout(timeout));
637            }
638
639            // Poll every 10ms
640            tokio::time::sleep(Duration::from_millis(10)).await;
641        }
642    }
643
644    /// Get current health state.
645    ///
646    /// Returns the numeric health state:
647    /// - 1: Healthy
648    /// - 2: Degraded
649    /// - 3: Draining
650    /// - 4: Unhealthy
651    pub async fn health_state(&self) -> i32 {
652        *self.health_state.read().await
653    }
654
655    /// Check if the agent is healthy.
656    pub async fn is_healthy(&self) -> bool {
657        *self.health_state.read().await == 1
658    }
659
660    /// Get the number of in-flight requests.
661    pub fn in_flight_count(&self) -> u64 {
662        self.in_flight.load(Ordering::Relaxed)
663    }
664
665    // =========================================================================
666    // Control Stream Methods
667    // =========================================================================
668
669    /// Send a configuration update to the agent.
670    pub async fn send_configure(
671        &self,
672        config: serde_json::Value,
673        version: Option<String>,
674    ) -> Result<(), AgentProtocolError> {
675        let msg = ProxyToAgent {
676            message: Some(grpc_v2::proxy_to_agent::Message::Configure(
677                grpc_v2::ConfigureEvent {
678                    config_json: serde_json::to_string(&config).unwrap_or_default(),
679                    config_version: version,
680                    is_initial: false,
681                    timestamp_ms: now_ms(),
682                },
683            )),
684        };
685
686        let outbound = self.outbound_tx.lock().await;
687        if let Some(sender) = outbound.as_ref() {
688            sender.send(msg).await.map_err(|e| {
689                AgentProtocolError::ConnectionFailed(format!("Configure send failed: {}", e))
690            })?;
691        } else {
692            return Err(AgentProtocolError::ConnectionFailed(
693                "Not connected".to_string(),
694            ));
695        }
696
697        debug!(agent_id = %self.agent_id, "Sent configuration update");
698        Ok(())
699    }
700
701    /// Request the agent to shut down.
702    pub async fn send_shutdown(
703        &self,
704        reason: ShutdownReason,
705        grace_period_ms: u64,
706    ) -> Result<(), AgentProtocolError> {
707        info!(
708            agent_id = %self.agent_id,
709            reason = ?reason,
710            grace_period_ms = grace_period_ms,
711            "Requesting agent shutdown"
712        );
713
714        // For shutdown, we should cancel all pending requests first
715        let _ = self.cancel_all(CancelReason::ProxyShutdown).await;
716
717        // Close the connection
718        self.close().await
719    }
720
721    /// Request the agent to drain (stop accepting new requests).
722    pub async fn send_drain(
723        &self,
724        duration_ms: u64,
725        reason: DrainReason,
726    ) -> Result<(), AgentProtocolError> {
727        info!(
728            agent_id = %self.agent_id,
729            duration_ms = duration_ms,
730            reason = ?reason,
731            "Requesting agent drain"
732        );
733
734        // Set flow state to draining
735        *self.flow_state.write().await = FlowState::Draining;
736
737        Ok(())
738    }
739
740    /// Get agent identifier.
741    pub fn agent_id(&self) -> &str {
742        &self.agent_id
743    }
744}
745
/// Why the proxy asks an agent to shut down (passed to `AgentClientV2::send_shutdown`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownReason {
    /// A graceful shutdown was requested.
    Graceful,
    /// An immediate shutdown was requested.
    Immediate,
    /// Shutdown triggered by a configuration reload.
    ConfigReload,
    /// Shutdown triggered by an upgrade.
    Upgrade,
}

/// Why the proxy asks an agent to drain (passed to `AgentClientV2::send_drain`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DrainReason {
    /// Draining for a configuration reload.
    ConfigReload,
    /// Draining for maintenance.
    Maintenance,
    /// Draining because a health check failed.
    HealthCheckFailed,
    /// Draining requested manually.
    Manual,
}
763
764// =============================================================================
765// Conversion Helpers
766// =============================================================================
767
768fn convert_capabilities_from_grpc(caps: grpc_v2::AgentCapabilities) -> AgentCapabilities {
769    use crate::v2::{AgentFeatures, AgentLimits, HealthConfig};
770
771    let features = caps.features.map(|f| AgentFeatures {
772        streaming_body: f.streaming_body,
773        websocket: f.websocket,
774        guardrails: f.guardrails,
775        config_push: f.config_push,
776        metrics_export: f.metrics_export,
777        concurrent_requests: f.concurrent_requests,
778        cancellation: f.cancellation,
779        flow_control: f.flow_control,
780        health_reporting: f.health_reporting,
781    }).unwrap_or_default();
782
783    let limits = caps.limits.map(|l| AgentLimits {
784        max_body_size: l.max_body_size as usize,
785        max_concurrency: l.max_concurrency,
786        preferred_chunk_size: l.preferred_chunk_size as usize,
787        max_memory: l.max_memory.map(|m| m as usize),
788        max_processing_time_ms: l.max_processing_time_ms,
789    }).unwrap_or_default();
790
791    let health = caps.health_config.map(|h| HealthConfig {
792        report_interval_ms: h.report_interval_ms,
793        include_load_metrics: h.include_load_metrics,
794        include_resource_metrics: h.include_resource_metrics,
795    }).unwrap_or_default();
796
797    AgentCapabilities {
798        protocol_version: caps.protocol_version,
799        agent_id: caps.agent_id,
800        name: caps.name,
801        version: caps.version,
802        supported_events: caps.supported_events.into_iter().filter_map(i32_to_event_type).collect(),
803        features,
804        limits,
805        health,
806    }
807}
808
809fn i32_to_event_type(i: i32) -> Option<EventType> {
810    match i {
811        1 => Some(EventType::RequestHeaders),
812        2 => Some(EventType::RequestBodyChunk),
813        3 => Some(EventType::ResponseHeaders),
814        4 => Some(EventType::ResponseBodyChunk),
815        5 => Some(EventType::RequestComplete),
816        6 => Some(EventType::WebSocketFrame),
817        7 => Some(EventType::GuardrailInspect),
818        8 => Some(EventType::Configure),
819        _ => None,
820    }
821}
822
823fn convert_request_headers_to_grpc(event: &crate::RequestHeadersEvent) -> grpc_v2::RequestHeadersEvent {
824    let metadata = Some(grpc_v2::RequestMetadata {
825        correlation_id: event.metadata.correlation_id.clone(),
826        request_id: event.metadata.request_id.clone(),
827        client_ip: event.metadata.client_ip.clone(),
828        client_port: event.metadata.client_port as u32,
829        server_name: event.metadata.server_name.clone(),
830        protocol: event.metadata.protocol.clone(),
831        tls_version: event.metadata.tls_version.clone(),
832        route_id: event.metadata.route_id.clone(),
833        upstream_id: event.metadata.upstream_id.clone(),
834        timestamp_ms: now_ms(),
835        traceparent: event.metadata.traceparent.clone(),
836    });
837
838    // Use iter_flat helper for cleaner iteration over flattened headers
839    let headers: Vec<grpc_v2::Header> = iter_flat(&event.headers)
840        .map(|(name, value)| grpc_v2::Header {
841            name: name.to_string(),
842            value: value.to_string(),
843        })
844        .collect();
845
846    grpc_v2::RequestHeadersEvent {
847        metadata,
848        method: event.method.clone(),
849        uri: event.uri.clone(),
850        http_version: "HTTP/1.1".to_string(),
851        headers,
852    }
853}
854
855fn convert_body_chunk_to_grpc(event: &crate::RequestBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
856    // Convert through binary type to centralize the base64 decode logic
857    let binary: crate::BinaryRequestBodyChunkEvent = event.into();
858    convert_binary_body_chunk_to_grpc(&binary)
859}
860
861/// Convert binary body chunk directly to gRPC (no base64 decode needed).
862///
863/// This is the efficient path for binary transports (UDS binary mode, direct Bytes).
864fn convert_binary_body_chunk_to_grpc(event: &crate::BinaryRequestBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
865    grpc_v2::BodyChunkEvent {
866        correlation_id: event.correlation_id.clone(),
867        chunk_index: event.chunk_index,
868        data: event.data.to_vec(), // Bytes → Vec<u8> (single copy, no decode)
869        is_last: event.is_last,
870        total_size: event.total_size.map(|s| s as u64),
871        bytes_transferred: event.bytes_received as u64,
872        proxy_buffer_available: 0, // Will be set by flow control
873        timestamp_ms: now_ms(),
874    }
875}
876
877fn convert_response_headers_to_grpc(event: &crate::ResponseHeadersEvent) -> grpc_v2::ResponseHeadersEvent {
878    // Use iter_flat helper for cleaner iteration over flattened headers
879    let headers: Vec<grpc_v2::Header> = iter_flat(&event.headers)
880        .map(|(name, value)| grpc_v2::Header {
881            name: name.to_string(),
882            value: value.to_string(),
883        })
884        .collect();
885
886    grpc_v2::ResponseHeadersEvent {
887        correlation_id: event.correlation_id.clone(),
888        status_code: event.status as u32,
889        headers,
890    }
891}
892
893fn convert_response_body_chunk_to_grpc(event: &crate::ResponseBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
894    // Convert through binary type to centralize the base64 decode logic
895    let binary: crate::BinaryResponseBodyChunkEvent = event.into();
896    convert_binary_response_body_chunk_to_grpc(&binary)
897}
898
899/// Convert binary response body chunk directly to gRPC (no base64 decode needed).
900///
901/// This is the efficient path for binary transports (UDS binary mode, direct Bytes).
902fn convert_binary_response_body_chunk_to_grpc(event: &crate::BinaryResponseBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
903    grpc_v2::BodyChunkEvent {
904        correlation_id: event.correlation_id.clone(),
905        chunk_index: event.chunk_index,
906        data: event.data.to_vec(), // Bytes → Vec<u8> (single copy, no decode)
907        is_last: event.is_last,
908        total_size: event.total_size.map(|s| s as u64),
909        bytes_transferred: event.bytes_sent as u64,
910        proxy_buffer_available: 0,
911        timestamp_ms: now_ms(),
912    }
913}
914
915fn convert_response_from_grpc(resp: grpc_v2::AgentResponse) -> AgentResponse {
916    let decision = match resp.decision {
917        Some(grpc_v2::agent_response::Decision::Allow(_)) => Decision::Allow,
918        Some(grpc_v2::agent_response::Decision::Block(b)) => Decision::Block {
919            status: b.status as u16,
920            body: b.body,
921            headers: if b.headers.is_empty() {
922                None
923            } else {
924                Some(b.headers.into_iter().map(|h| (h.name, h.value)).collect())
925            },
926        },
927        Some(grpc_v2::agent_response::Decision::Redirect(r)) => Decision::Redirect {
928            url: r.url,
929            status: r.status as u16,
930        },
931        Some(grpc_v2::agent_response::Decision::Challenge(c)) => Decision::Challenge {
932            challenge_type: c.challenge_type,
933            params: c.params,
934        },
935        None => Decision::Allow,
936    };
937
938    let request_headers: Vec<HeaderOp> = resp
939        .request_headers
940        .into_iter()
941        .filter_map(convert_header_op_from_grpc)
942        .collect();
943
944    let response_headers: Vec<HeaderOp> = resp
945        .response_headers
946        .into_iter()
947        .filter_map(convert_header_op_from_grpc)
948        .collect();
949
950    let audit = resp.audit.map(|a| crate::AuditMetadata {
951        tags: a.tags,
952        rule_ids: a.rule_ids,
953        confidence: a.confidence,
954        reason_codes: a.reason_codes,
955        custom: a.custom.into_iter().map(|(k, v)| (k, serde_json::Value::String(v))).collect(),
956    }).unwrap_or_default();
957
958    AgentResponse {
959        version: PROTOCOL_VERSION_2,
960        decision,
961        request_headers,
962        response_headers,
963        routing_metadata: HashMap::new(),
964        audit,
965        needs_more: resp.needs_more,
966        request_body_mutation: None,
967        response_body_mutation: None,
968        websocket_decision: None,
969    }
970}
971
972fn convert_header_op_from_grpc(op: grpc_v2::HeaderOp) -> Option<HeaderOp> {
973    match op.operation {
974        Some(grpc_v2::header_op::Operation::Set(h)) => Some(HeaderOp::Set {
975            name: h.name,
976            value: h.value,
977        }),
978        Some(grpc_v2::header_op::Operation::Add(h)) => Some(HeaderOp::Add {
979            name: h.name,
980            value: h.value,
981        }),
982        Some(grpc_v2::header_op::Operation::Remove(name)) => Some(HeaderOp::Remove { name }),
983        None => None,
984    }
985}
986
987fn convert_metrics_from_grpc(report: grpc_v2::MetricsReport, agent_id: &str) -> crate::v2::MetricsReport {
988    use crate::v2::metrics::{CounterMetric, GaugeMetric, HistogramBucket, HistogramMetric};
989
990    let counters = report
991        .counters
992        .into_iter()
993        .map(|c| CounterMetric {
994            name: c.name,
995            help: c.help.filter(|s| !s.is_empty()),
996            labels: c.labels,
997            value: c.value,
998        })
999        .collect();
1000
1001    let gauges = report
1002        .gauges
1003        .into_iter()
1004        .map(|g| GaugeMetric {
1005            name: g.name,
1006            help: g.help.filter(|s| !s.is_empty()),
1007            labels: g.labels,
1008            value: g.value,
1009        })
1010        .collect();
1011
1012    let histograms = report
1013        .histograms
1014        .into_iter()
1015        .map(|h| HistogramMetric {
1016            name: h.name,
1017            help: h.help.filter(|s| !s.is_empty()),
1018            labels: h.labels,
1019            sum: h.sum,
1020            count: h.count,
1021            buckets: h
1022                .buckets
1023                .into_iter()
1024                .map(|b| HistogramBucket { le: b.le, count: b.count })
1025                .collect(),
1026        })
1027        .collect();
1028
1029    crate::v2::MetricsReport {
1030        agent_id: agent_id.to_string(),
1031        timestamp_ms: report.timestamp_ms,
1032        interval_ms: report.interval_ms,
1033        counters,
1034        gauges,
1035        histograms,
1036    }
1037}
1038
1039fn convert_config_update_from_grpc(update: grpc_v2::ConfigUpdateRequest) -> crate::v2::ConfigUpdateRequest {
1040    use crate::v2::control::{ConfigUpdateType, RuleDefinition};
1041
1042    let update_type = match update.update_type {
1043        Some(grpc_v2::config_update_request::UpdateType::RequestReload(_)) => {
1044            ConfigUpdateType::RequestReload
1045        }
1046        Some(grpc_v2::config_update_request::UpdateType::RuleUpdate(ru)) => {
1047            ConfigUpdateType::RuleUpdate {
1048                rule_set: ru.rule_set,
1049                rules: ru
1050                    .rules
1051                    .into_iter()
1052                    .map(|r| RuleDefinition {
1053                        id: r.id,
1054                        priority: r.priority,
1055                        definition: serde_json::from_str(&r.definition_json).unwrap_or_default(),
1056                        enabled: r.enabled,
1057                        description: r.description,
1058                        tags: r.tags,
1059                    })
1060                    .collect(),
1061                remove_rules: ru.remove_rules,
1062            }
1063        }
1064        Some(grpc_v2::config_update_request::UpdateType::ListUpdate(lu)) => {
1065            ConfigUpdateType::ListUpdate {
1066                list_id: lu.list_id,
1067                add: lu.add,
1068                remove: lu.remove,
1069            }
1070        }
1071        Some(grpc_v2::config_update_request::UpdateType::RestartRequired(rr)) => {
1072            ConfigUpdateType::RestartRequired {
1073                reason: rr.reason,
1074                grace_period_ms: rr.grace_period_ms,
1075            }
1076        }
1077        Some(grpc_v2::config_update_request::UpdateType::ConfigError(ce)) => {
1078            ConfigUpdateType::ConfigError {
1079                error: ce.error,
1080                field: ce.field,
1081            }
1082        }
1083        None => ConfigUpdateType::RequestReload, // Default
1084    };
1085
1086    crate::v2::ConfigUpdateRequest {
1087        update_type,
1088        request_id: update.request_id,
1089        timestamp_ms: update.timestamp_ms,
1090    }
1091}
1092
1093fn extract_correlation_id<T: serde::Serialize>(event: &T) -> String {
1094    // Try to extract correlation_id from the serialized event
1095    if let Ok(value) = serde_json::to_value(event) {
1096        if let Some(metadata) = value.get("metadata") {
1097            if let Some(cid) = metadata.get("correlation_id").and_then(|v| v.as_str()) {
1098                return cid.to_string();
1099            }
1100        }
1101        if let Some(cid) = value.get("correlation_id").and_then(|v| v.as_str()) {
1102            return cid.to_string();
1103        }
1104    }
1105    uuid::Uuid::new_v4().to_string()
1106}
1107
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        // A clock set before 1970 is reported as 0 rather than as an error.
        Err(_) => 0,
    }
}
1114
#[cfg(test)]
mod tests {
    use super::*;

    /// Every known wire code maps to its variant; anything else maps to `None`.
    #[test]
    fn test_event_type_conversion() {
        let known = [
            (1, EventType::RequestHeaders),
            (2, EventType::RequestBodyChunk),
            (3, EventType::ResponseHeaders),
            (4, EventType::ResponseBodyChunk),
            (5, EventType::RequestComplete),
            (6, EventType::WebSocketFrame),
            (7, EventType::GuardrailInspect),
            (8, EventType::Configure),
        ];
        for (code, expected) in known {
            assert_eq!(i32_to_event_type(code), Some(expected), "code {}", code);
        }

        for unknown in [0, -1, 9, 99] {
            assert_eq!(i32_to_event_type(unknown), None, "code {}", unknown);
        }
    }

    /// A top-level `correlation_id` is picked up.
    #[test]
    fn test_extract_correlation_id() {
        #[derive(serde::Serialize)]
        struct TestEvent {
            correlation_id: String,
        }

        let event = TestEvent {
            correlation_id: "test-123".to_string(),
        };

        assert_eq!(extract_correlation_id(&event), "test-123");
    }

    /// `metadata.correlation_id` takes precedence over a top-level one.
    #[test]
    fn test_extract_correlation_id_prefers_metadata() {
        #[derive(serde::Serialize)]
        struct Metadata {
            correlation_id: String,
        }

        #[derive(serde::Serialize)]
        struct TestEvent {
            metadata: Metadata,
            correlation_id: String,
        }

        let event = TestEvent {
            metadata: Metadata {
                correlation_id: "from-metadata".to_string(),
            },
            correlation_id: "top-level".to_string(),
        };

        assert_eq!(extract_correlation_id(&event), "from-metadata");
    }

    /// A non-string `metadata.correlation_id` is ignored in favour of the top-level id.
    #[test]
    fn test_extract_correlation_id_ignores_non_string_metadata_id() {
        #[derive(serde::Serialize)]
        struct Metadata {
            correlation_id: u32,
        }

        #[derive(serde::Serialize)]
        struct TestEvent {
            metadata: Metadata,
            correlation_id: String,
        }

        let event = TestEvent {
            metadata: Metadata { correlation_id: 7 },
            correlation_id: "top-level".to_string(),
        };

        assert_eq!(extract_correlation_id(&event), "top-level");
    }

    /// Without any correlation id, a fresh random UUID is generated per call.
    #[test]
    fn test_extract_correlation_id_falls_back_to_uuid() {
        #[derive(serde::Serialize)]
        struct TestEvent {
            other: u32,
        }

        let first = extract_correlation_id(&TestEvent { other: 1 });
        let second = extract_correlation_id(&TestEvent { other: 1 });

        assert!(uuid::Uuid::parse_str(&first).is_ok());
        assert!(uuid::Uuid::parse_str(&second).is_ok());
        assert_ne!(first, second);
    }
}