// sentinel_agent_protocol/v2/pool.rs
//! Agent connection pool for Protocol v2.
//!
//! This module provides a production-ready connection pool for managing
//! multiple connections to agents with:
//!
//! - **Connection pooling**: Maintain multiple connections per agent
//! - **Load balancing**: Round-robin, least-connections, or health-based routing
//! - **Health tracking**: Route requests based on agent health
//! - **Automatic reconnection**: Reconnect failed connections
//! - **Graceful shutdown**: Drain connections before closing
11
12use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16use dashmap::DashMap;
17use tokio::sync::{RwLock, Semaphore};
18use tracing::{debug, info, trace, warn};
19
20use crate::v2::client::{AgentClientV2, CancelReason, ConfigUpdateCallback, MetricsCallback};
21use crate::v2::control::ConfigUpdateType;
22use crate::v2::observability::{ConfigPusher, ConfigUpdateHandler, MetricsCollector};
23use crate::v2::protocol_metrics::ProtocolMetrics;
24use crate::v2::reverse::ReverseConnectionClient;
25use crate::v2::uds::AgentClientV2Uds;
26use crate::v2::AgentCapabilities;
27use crate::{
28    AgentProtocolError, AgentResponse, RequestBodyChunkEvent, RequestHeadersEvent,
29    ResponseBodyChunkEvent, ResponseHeadersEvent,
30};
31
/// Channel buffer size for all transports.
///
/// This is aligned across UDS, gRPC, and reverse connections to ensure
/// consistent backpressure behavior. A smaller buffer (64 vs 1024) means
/// backpressure kicks in earlier, preventing memory buildup under load.
///
/// The value 64 balances:
/// - Small enough to apply backpressure before memory issues
/// - Large enough to handle burst traffic without blocking
///
/// Used as the default for [`AgentPoolConfig::channel_buffer_size`].
pub const CHANNEL_BUFFER_SIZE: usize = 64;
42
/// Load balancing strategy for the connection pool.
///
/// Selected via [`AgentPoolConfig::load_balance_strategy`]; applies when
/// choosing among multiple connections to the same agent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
    /// Round-robin across all healthy connections (the default).
    #[default]
    RoundRobin,
    /// Route to connection with fewest in-flight requests
    LeastConnections,
    /// Route based on health score (prefer healthier agents)
    HealthBased,
    /// Random selection
    Random,
}
56
/// Flow control behavior when an agent signals it cannot accept requests.
///
/// When an agent sends a flow control "pause" signal, this determines
/// whether requests should fail immediately or be allowed through.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlowControlMode {
    /// Fail requests when agent is paused (default, safer).
    ///
    /// Returns `AgentProtocolError::FlowControlPaused` immediately.
    /// Use this when you want strict backpressure and can handle
    /// the error at the caller (e.g., return 503 to client).
    #[default]
    FailClosed,

    /// Allow requests through even when agent is paused.
    ///
    /// Requests proceed without agent processing. Use this when
    /// agent processing is optional (e.g., logging, analytics)
    /// and you prefer availability over consistency.
    FailOpen,

    /// Queue requests briefly, then fail if still paused.
    ///
    /// Waits up to `flow_control_wait_timeout` for the agent to
    /// resume before failing. Useful for transient pauses.
    WaitAndRetry,
}
84
/// A sticky session entry tracking connection affinity for long-lived streams.
///
/// Used for WebSocket connections, Server-Sent Events, long-polling, and other
/// streaming scenarios where the same agent connection should be used for the
/// entire stream lifetime.
struct StickySession {
    /// The connection to use for this session
    connection: Arc<PooledConnection>,
    /// Agent ID for this session
    agent_id: String,
    /// When the session was created
    created_at: Instant,
    /// When the session was last accessed, stored as milliseconds since
    /// `created_at`. An atomic offset (rather than a locked `Instant`)
    /// lets `touch()` run without any lock on the hot path.
    last_accessed: AtomicU64,
}
100
// Manual Debug impl: `connection` (PooledConnection) does not derive Debug,
// so we print only the identifying fields and mark the rest non-exhaustive.
impl std::fmt::Debug for StickySession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StickySession")
            .field("agent_id", &self.agent_id)
            .field("created_at", &self.created_at)
            .finish_non_exhaustive()
    }
}
109
/// Configuration for the agent connection pool.
///
/// Construct via [`Default`] and override individual fields as needed.
#[derive(Debug, Clone)]
pub struct AgentPoolConfig {
    /// Number of connections to maintain per agent
    pub connections_per_agent: usize,
    /// Load balancing strategy
    pub load_balance_strategy: LoadBalanceStrategy,
    /// Connection timeout
    pub connect_timeout: Duration,
    /// Request timeout
    pub request_timeout: Duration,
    /// Time between reconnection attempts
    pub reconnect_interval: Duration,
    /// Maximum reconnection attempts before marking agent unhealthy
    pub max_reconnect_attempts: usize,
    /// Time to wait for in-flight requests during shutdown
    pub drain_timeout: Duration,
    /// Maximum concurrent requests per connection
    pub max_concurrent_per_connection: usize,
    /// Health check interval
    pub health_check_interval: Duration,
    /// Channel buffer size for all transports.
    ///
    /// Controls backpressure behavior. Smaller values (16-64) apply
    /// backpressure earlier, preventing memory buildup. Larger values
    /// (128-256) handle burst traffic better but use more memory.
    ///
    /// Default: 64
    pub channel_buffer_size: usize,
    /// Flow control behavior when an agent signals it cannot accept requests.
    ///
    /// Default: `FlowControlMode::FailClosed`
    pub flow_control_mode: FlowControlMode,
    /// Timeout for `FlowControlMode::WaitAndRetry` before failing.
    ///
    /// Only used when `flow_control_mode` is `WaitAndRetry`.
    /// Default: 100ms
    pub flow_control_wait_timeout: Duration,
    /// Timeout for sticky sessions before they expire.
    ///
    /// Sticky sessions are used for long-lived streaming connections
    /// (WebSocket, SSE, long-polling) to ensure the same agent connection
    /// is used for the entire stream lifetime.
    ///
    /// Set to None to disable automatic expiry (sessions only cleared explicitly).
    ///
    /// Default: 5 minutes
    pub sticky_session_timeout: Option<Duration>,
}
159
160impl Default for AgentPoolConfig {
161    fn default() -> Self {
162        Self {
163            connections_per_agent: 4,
164            load_balance_strategy: LoadBalanceStrategy::RoundRobin,
165            connect_timeout: Duration::from_secs(5),
166            request_timeout: Duration::from_secs(30),
167            reconnect_interval: Duration::from_secs(5),
168            max_reconnect_attempts: 3,
169            drain_timeout: Duration::from_secs(30),
170            max_concurrent_per_connection: 100,
171            health_check_interval: Duration::from_secs(10),
172            channel_buffer_size: CHANNEL_BUFFER_SIZE,
173            flow_control_mode: FlowControlMode::FailClosed,
174            flow_control_wait_timeout: Duration::from_millis(100),
175            sticky_session_timeout: Some(Duration::from_secs(5 * 60)), // 5 minutes
176        }
177    }
178}
179
impl StickySession {
    /// Create a session bound to `connection`. The last-accessed offset
    /// starts at 0 (i.e. equal to `created_at`); callers touch() right after.
    fn new(agent_id: String, connection: Arc<PooledConnection>) -> Self {
        Self {
            connection,
            agent_id,
            created_at: Instant::now(),
            last_accessed: AtomicU64::new(0),
        }
    }

    /// Record an access: store elapsed-since-creation in milliseconds.
    /// Millisecond truncation is acceptable for expiry bookkeeping.
    fn touch(&self) {
        let offset = self.created_at.elapsed().as_millis() as u64;
        self.last_accessed.store(offset, Ordering::Relaxed);
    }

    /// Reconstruct the last-accessed `Instant` from the stored offset.
    fn last_accessed(&self) -> Instant {
        let offset_ms = self.last_accessed.load(Ordering::Relaxed);
        self.created_at + Duration::from_millis(offset_ms)
    }

    /// True if the session has gone longer than `timeout` without a touch.
    fn is_expired(&self, timeout: Duration) -> bool {
        self.last_accessed().elapsed() > timeout
    }
}
204
/// Transport layer for v2 agent connections.
///
/// Supports gRPC, Unix Domain Socket, and reverse connections.
/// All variants expose the same request/cancel/close surface via
/// the delegating methods on `impl V2Transport`.
pub enum V2Transport {
    /// gRPC over HTTP/2
    Grpc(AgentClientV2),
    /// Binary protocol over Unix Domain Socket
    Uds(AgentClientV2Uds),
    /// Reverse connection (agent connected to proxy)
    Reverse(ReverseConnectionClient),
}
216
// Every method below fans out to the underlying client via an exhaustive
// match, so adding a transport variant forces updating each method.
impl V2Transport {
    /// Check if the transport is connected.
    pub async fn is_connected(&self) -> bool {
        match self {
            V2Transport::Grpc(client) => client.is_connected().await,
            V2Transport::Uds(client) => client.is_connected().await,
            V2Transport::Reverse(client) => client.is_connected().await,
        }
    }

    /// Check if the transport can accept new requests.
    ///
    /// Returns false if the agent has sent a flow control pause signal.
    pub async fn can_accept_requests(&self) -> bool {
        match self {
            V2Transport::Grpc(client) => client.can_accept_requests().await,
            V2Transport::Uds(client) => client.can_accept_requests().await,
            V2Transport::Reverse(client) => client.can_accept_requests().await,
        }
    }

    /// Get negotiated capabilities.
    ///
    /// Returns `None` if negotiation has not completed (exact semantics
    /// depend on the underlying client — see each client's docs).
    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
        match self {
            V2Transport::Grpc(client) => client.capabilities().await,
            V2Transport::Uds(client) => client.capabilities().await,
            V2Transport::Reverse(client) => client.capabilities().await,
        }
    }

    /// Send a request headers event.
    pub async fn send_request_headers(
        &self,
        correlation_id: &str,
        event: &RequestHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_request_headers(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_request_headers(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_request_headers(correlation_id, event).await
            }
        }
    }

    /// Send a request body chunk event.
    pub async fn send_request_body_chunk(
        &self,
        correlation_id: &str,
        event: &RequestBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => {
                client.send_request_body_chunk(correlation_id, event).await
            }
            V2Transport::Uds(client) => client.send_request_body_chunk(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_request_body_chunk(correlation_id, event).await
            }
        }
    }

    /// Send a response headers event.
    pub async fn send_response_headers(
        &self,
        correlation_id: &str,
        event: &ResponseHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_response_headers(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_response_headers(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_response_headers(correlation_id, event).await
            }
        }
    }

    /// Send a response body chunk event.
    pub async fn send_response_body_chunk(
        &self,
        correlation_id: &str,
        event: &ResponseBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
            V2Transport::Uds(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
            V2Transport::Reverse(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
        }
    }

    /// Cancel a specific request.
    pub async fn cancel_request(
        &self,
        correlation_id: &str,
        reason: CancelReason,
    ) -> Result<(), AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.cancel_request(correlation_id, reason).await,
            V2Transport::Uds(client) => client.cancel_request(correlation_id, reason).await,
            V2Transport::Reverse(client) => client.cancel_request(correlation_id, reason).await,
        }
    }

    /// Cancel all in-flight requests.
    ///
    /// Returns the number of requests cancelled.
    pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.cancel_all(reason).await,
            V2Transport::Uds(client) => client.cancel_all(reason).await,
            V2Transport::Reverse(client) => client.cancel_all(reason).await,
        }
    }

    /// Close the transport.
    pub async fn close(&self) -> Result<(), AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.close().await,
            V2Transport::Uds(client) => client.close().await,
            V2Transport::Reverse(client) => client.close().await,
        }
    }

    /// Get agent ID.
    pub fn agent_id(&self) -> &str {
        match self {
            V2Transport::Grpc(client) => client.agent_id(),
            V2Transport::Uds(client) => client.agent_id(),
            V2Transport::Reverse(client) => client.agent_id(),
        }
    }
}
353
/// A pooled connection to an agent.
///
/// All counters are atomics so the request hot path never takes a lock.
struct PooledConnection {
    /// Underlying transport (gRPC, UDS, or reverse).
    client: V2Transport,
    /// Creation time; also the epoch for `last_used_offset_ms`.
    created_at: Instant,
    /// Milliseconds since created_at when last used (avoids RwLock in hot path)
    last_used_offset_ms: AtomicU64,
    /// Requests currently outstanding on this connection.
    in_flight: AtomicU64,
    /// Lifetime request count (denominator for error_rate).
    request_count: AtomicU64,
    /// Lifetime error count (numerator for error_rate).
    error_count: AtomicU64,
    /// Errors since the last success; reset to 0 on any Ok.
    consecutive_errors: AtomicU64,
    /// Caps concurrent requests per connection.
    concurrency_limiter: Semaphore,
    /// Cached health state - updated by background maintenance, read in hot path
    healthy_cached: AtomicBool,
}
368
369impl PooledConnection {
370    fn new(client: V2Transport, max_concurrent: usize) -> Self {
371        Self {
372            client,
373            created_at: Instant::now(),
374            last_used_offset_ms: AtomicU64::new(0),
375            in_flight: AtomicU64::new(0),
376            request_count: AtomicU64::new(0),
377            error_count: AtomicU64::new(0),
378            consecutive_errors: AtomicU64::new(0),
379            concurrency_limiter: Semaphore::new(max_concurrent),
380            healthy_cached: AtomicBool::new(true), // Assume healthy until proven otherwise
381        }
382    }
383
384    fn in_flight(&self) -> u64 {
385        self.in_flight.load(Ordering::Relaxed)
386    }
387
388    fn error_rate(&self) -> f64 {
389        let requests = self.request_count.load(Ordering::Relaxed);
390        let errors = self.error_count.load(Ordering::Relaxed);
391        if requests == 0 {
392            0.0
393        } else {
394            errors as f64 / requests as f64
395        }
396    }
397
398    /// Fast health check using cached state (no async, no I/O).
399    /// Updated by background maintenance task.
400    #[inline]
401    fn is_healthy_cached(&self) -> bool {
402        self.healthy_cached.load(Ordering::Acquire)
403    }
404
405    /// Full health check with I/O - only called by maintenance task.
406    async fn check_and_update_health(&self) -> bool {
407        let connected = self.client.is_connected().await;
408        let low_errors = self.consecutive_errors.load(Ordering::Relaxed) < 3;
409        let can_accept = self.client.can_accept_requests().await;
410
411        let healthy = connected && low_errors && can_accept;
412        self.healthy_cached.store(healthy, Ordering::Release);
413        healthy
414    }
415
416    /// Record that this connection was just used.
417    #[inline]
418    fn touch(&self) {
419        let offset = self.created_at.elapsed().as_millis() as u64;
420        self.last_used_offset_ms.store(offset, Ordering::Relaxed);
421    }
422
423    /// Get the last used time.
424    fn last_used(&self) -> Instant {
425        let offset_ms = self.last_used_offset_ms.load(Ordering::Relaxed);
426        self.created_at + Duration::from_millis(offset_ms)
427    }
428}
429
/// Statistics for a single agent in the pool.
///
/// A point-in-time snapshot; values are aggregated across all of the
/// agent's pooled connections.
#[derive(Debug, Clone)]
pub struct AgentPoolStats {
    /// Agent identifier
    pub agent_id: String,
    /// Number of active connections
    pub active_connections: usize,
    /// Number of healthy connections
    pub healthy_connections: usize,
    /// Total in-flight requests across all connections
    pub total_in_flight: u64,
    /// Total requests processed
    pub total_requests: u64,
    /// Total errors
    pub total_errors: u64,
    /// Average error rate
    pub error_rate: f64,
    /// Whether the agent is considered healthy
    pub is_healthy: bool,
}
450
/// An agent entry in the pool.
struct AgentEntry {
    /// Unique identifier for the agent.
    agent_id: String,
    /// Endpoint used to (re)connect to the agent.
    endpoint: String,
    /// Connections are rarely modified (only on reconnect), so RwLock is acceptable here.
    /// The hot-path reads use try_read() to avoid blocking.
    connections: RwLock<Vec<Arc<PooledConnection>>>,
    /// Capabilities negotiated with the agent, if any.
    capabilities: RwLock<Option<AgentCapabilities>>,
    /// Cursor for the round-robin load-balancing strategy.
    round_robin_index: AtomicUsize,
    /// Consecutive failed reconnect attempts.
    reconnect_attempts: AtomicUsize,
    /// Stored as millis since UNIX_EPOCH to avoid RwLock.
    /// 0 is a sentinel meaning "never attempted".
    last_reconnect_attempt_ms: AtomicU64,
    /// Cached aggregate health - true if any connection is healthy
    healthy: AtomicBool,
}
466
467impl AgentEntry {
468    fn new(agent_id: String, endpoint: String) -> Self {
469        Self {
470            agent_id,
471            endpoint,
472            connections: RwLock::new(Vec::new()),
473            capabilities: RwLock::new(None),
474            round_robin_index: AtomicUsize::new(0),
475            reconnect_attempts: AtomicUsize::new(0),
476            last_reconnect_attempt_ms: AtomicU64::new(0),
477            healthy: AtomicBool::new(true),
478        }
479    }
480
481    /// Check if enough time has passed since last reconnect attempt.
482    fn should_reconnect(&self, interval: Duration) -> bool {
483        let last_ms = self.last_reconnect_attempt_ms.load(Ordering::Relaxed);
484        if last_ms == 0 {
485            return true;
486        }
487        let now_ms = std::time::SystemTime::now()
488            .duration_since(std::time::UNIX_EPOCH)
489            .map(|d| d.as_millis() as u64)
490            .unwrap_or(0);
491        now_ms.saturating_sub(last_ms) > interval.as_millis() as u64
492    }
493
494    /// Record that a reconnect attempt was made.
495    fn mark_reconnect_attempt(&self) {
496        let now_ms = std::time::SystemTime::now()
497            .duration_since(std::time::UNIX_EPOCH)
498            .map(|d| d.as_millis() as u64)
499            .unwrap_or(0);
500        self.last_reconnect_attempt_ms
501            .store(now_ms, Ordering::Relaxed);
502    }
503}
504
/// Agent connection pool.
///
/// Manages multiple connections to multiple agents with load balancing,
/// health tracking, automatic reconnection, and metrics collection.
///
/// # Performance
///
/// Uses `DashMap` for lock-free reads in the hot path. Agent lookup is O(1)
/// without contention. Connection selection uses cached health state to avoid
/// async I/O per request.
pub struct AgentPool {
    /// Pool-wide tunables; fixed after construction.
    config: AgentPoolConfig,
    /// Lock-free concurrent map for agent lookup.
    /// Reads (select_connection) are lock-free. Writes (add/remove agent) shard-lock.
    agents: DashMap<String, Arc<AgentEntry>>,
    /// Lifetime request counter across all agents.
    total_requests: AtomicU64,
    /// Lifetime error counter across all agents.
    total_errors: AtomicU64,
    /// Shared metrics collector for all agents
    metrics_collector: Arc<MetricsCollector>,
    /// Callback used to record metrics from clients
    metrics_callback: MetricsCallback,
    /// Config pusher for distributing config updates to agents
    config_pusher: Arc<ConfigPusher>,
    /// Handler for config update requests from agents
    config_update_handler: Arc<ConfigUpdateHandler>,
    /// Callback used to handle config updates from clients
    config_update_callback: ConfigUpdateCallback,
    /// Protocol-level metrics (proxy-side instrumentation)
    protocol_metrics: Arc<ProtocolMetrics>,
    /// Connection affinity: correlation_id → connection used for headers.
    /// Ensures body chunks go to the same connection as headers for streaming.
    correlation_affinity: DashMap<String, Arc<PooledConnection>>,
    /// Sticky sessions: session_id → session info for long-lived streams.
    /// Used for WebSocket, SSE, and long-polling connections.
    sticky_sessions: DashMap<String, StickySession>,
}
541
542impl AgentPool {
543    /// Create a new agent pool with default configuration.
544    pub fn new() -> Self {
545        Self::with_config(AgentPoolConfig::default())
546    }
547
548    /// Create a new agent pool with custom configuration.
549    pub fn with_config(config: AgentPoolConfig) -> Self {
550        let metrics_collector = Arc::new(MetricsCollector::new());
551        let collector_clone = Arc::clone(&metrics_collector);
552
553        // Create a callback that records metrics to the collector
554        let metrics_callback: MetricsCallback = Arc::new(move |report| {
555            collector_clone.record(&report);
556        });
557
558        // Create config pusher and handler
559        let config_pusher = Arc::new(ConfigPusher::new());
560        let config_update_handler = Arc::new(ConfigUpdateHandler::new());
561        let handler_clone = Arc::clone(&config_update_handler);
562
563        // Create a callback that handles config update requests from agents
564        let config_update_callback: ConfigUpdateCallback = Arc::new(move |agent_id, request| {
565            debug!(
566                agent_id = %agent_id,
567                request_id = %request.request_id,
568                "Processing config update request from agent"
569            );
570            handler_clone.handle(request)
571        });
572
573        Self {
574            config,
575            agents: DashMap::new(),
576            total_requests: AtomicU64::new(0),
577            total_errors: AtomicU64::new(0),
578            metrics_collector,
579            metrics_callback,
580            config_pusher,
581            config_update_handler,
582            config_update_callback,
583            protocol_metrics: Arc::new(ProtocolMetrics::new()),
584            correlation_affinity: DashMap::new(),
585            sticky_sessions: DashMap::new(),
586        }
587    }
588
    /// Get the protocol metrics for accessing proxy-side instrumentation.
    pub fn protocol_metrics(&self) -> &ProtocolMetrics {
        &self.protocol_metrics
    }

    /// Get an Arc to the protocol metrics.
    ///
    /// Useful when the metrics must outlive a borrow of the pool.
    pub fn protocol_metrics_arc(&self) -> Arc<ProtocolMetrics> {
        Arc::clone(&self.protocol_metrics)
    }

    /// Get the metrics collector for accessing aggregated agent metrics.
    pub fn metrics_collector(&self) -> &MetricsCollector {
        &self.metrics_collector
    }

    /// Get an Arc to the metrics collector.
    ///
    /// This is useful for registering the collector with a MetricsManager.
    pub fn metrics_collector_arc(&self) -> Arc<MetricsCollector> {
        Arc::clone(&self.metrics_collector)
    }

    /// Export all agent metrics in Prometheus format.
    ///
    /// Delegates to the shared collector's exporter.
    pub fn export_prometheus(&self) -> String {
        self.metrics_collector.export_prometheus()
    }
615
    /// Clear connection affinity for a correlation ID.
    ///
    /// Call this when a request is complete (after receiving final response)
    /// to free the affinity mapping. Not strictly necessary (affinities are
    /// cheap), but good practice for long-running proxies.
    pub fn clear_correlation_affinity(&self, correlation_id: &str) {
        self.correlation_affinity.remove(correlation_id);
    }

    /// Get the number of active correlation affinities.
    ///
    /// This is useful for monitoring and debugging.
    pub fn correlation_affinity_count(&self) -> usize {
        self.correlation_affinity.len()
    }
631
632    // =========================================================================
633    // Sticky Sessions
634    // =========================================================================
635
    /// Create a sticky session for a long-lived streaming connection.
    ///
    /// Sticky sessions ensure that all requests for a given session use the
    /// same agent connection. This is essential for:
    /// - WebSocket connections
    /// - Server-Sent Events (SSE)
    /// - Long-polling
    /// - Any streaming scenario requiring connection affinity
    ///
    /// # Arguments
    ///
    /// * `session_id` - A unique identifier for this session (e.g., WebSocket connection ID)
    /// * `agent_id` - The agent to bind this session to
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` if the session was created, or an error if the agent
    /// is not found or has no healthy connections.
    ///
    /// Note: creating a session under an existing `session_id` replaces the
    /// previous binding (DashMap insert semantics).
    ///
    /// # Example
    ///
    /// ```ignore
    /// // When a WebSocket is established
    /// pool.create_sticky_session("ws-12345", "waf-agent").await?;
    ///
    /// // All subsequent messages use the same connection
    /// pool.send_request_with_sticky_session("ws-12345", &event).await?;
    ///
    /// // When the WebSocket closes
    /// pool.clear_sticky_session("ws-12345");
    /// ```
    pub fn create_sticky_session(
        &self,
        session_id: impl Into<String>,
        agent_id: &str,
    ) -> Result<(), AgentProtocolError> {
        let session_id = session_id.into();
        // Pick a connection now; the session pins it for its lifetime.
        let conn = self.select_connection(agent_id)?;

        let session = StickySession::new(agent_id.to_string(), conn);
        // Touch immediately so the expiry clock starts at creation time.
        session.touch();

        self.sticky_sessions.insert(session_id.clone(), session);

        debug!(
            session_id = %session_id,
            agent_id = %agent_id,
            "Created sticky session"
        );

        Ok(())
    }
688
    /// Get the connection for a sticky session (internal use).
    ///
    /// Returns the connection bound to this session, or None if the session
    /// doesn't exist or has expired. A live lookup also refreshes the
    /// session's last-accessed time.
    fn get_sticky_session_conn(&self, session_id: &str) -> Option<Arc<PooledConnection>> {
        let entry = self.sticky_sessions.get(session_id)?;

        // Check expiration if configured
        if let Some(timeout) = self.config.sticky_session_timeout {
            if entry.is_expired(timeout) {
                // Release the shard reference before removing — holding a
                // DashMap ref while removing the same key can deadlock.
                drop(entry);
                self.sticky_sessions.remove(session_id);
                debug!(session_id = %session_id, "Sticky session expired");
                return None;
            }
        }

        entry.touch();
        Some(Arc::clone(&entry.connection))
    }
709
    /// Refresh a sticky session, updating its last-accessed time.
    ///
    /// Returns true if the session exists and was refreshed, false otherwise.
    /// (Refreshing an expired session removes it and returns false.)
    pub fn refresh_sticky_session(&self, session_id: &str) -> bool {
        self.get_sticky_session_conn(session_id).is_some()
    }

    /// Check if a sticky session exists and is valid.
    ///
    /// Note: this also refreshes the session's last-accessed time, since it
    /// shares the lookup path with `refresh_sticky_session`.
    pub fn has_sticky_session(&self, session_id: &str) -> bool {
        self.get_sticky_session_conn(session_id).is_some()
    }

    /// Clear a sticky session.
    ///
    /// Call this when a long-lived stream ends (WebSocket closed, SSE ended, etc.)
    pub fn clear_sticky_session(&self, session_id: &str) {
        if self.sticky_sessions.remove(session_id).is_some() {
            debug!(session_id = %session_id, "Cleared sticky session");
        }
    }

    /// Get the number of active sticky sessions.
    ///
    /// Useful for monitoring and debugging. May include sessions that have
    /// expired but not yet been cleaned up.
    pub fn sticky_session_count(&self) -> usize {
        self.sticky_sessions.len()
    }

    /// Get the agent ID bound to a sticky session.
    ///
    /// Does not check expiry and does not refresh the session.
    pub fn sticky_session_agent(&self, session_id: &str) -> Option<String> {
        self.sticky_sessions
            .get(session_id)
            .map(|s| s.agent_id.clone())
    }
744
    /// Send a request using a sticky session.
    ///
    /// If the session exists and is valid, uses the bound connection.
    /// If the session doesn't exist or has expired, falls back to normal
    /// connection selection.
    ///
    /// # Returns
    ///
    /// A tuple of (response, used_sticky_session).
    ///
    /// # Errors
    ///
    /// Propagates connection-selection errors, flow-control errors (per
    /// `config.flow_control_mode`), and transport send errors.
    pub async fn send_request_headers_with_sticky_session(
        &self,
        session_id: &str,
        agent_id: &str,
        correlation_id: &str,
        event: &RequestHeadersEvent,
    ) -> Result<(AgentResponse, bool), AgentProtocolError> {
        let start = Instant::now();
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.protocol_metrics.inc_requests();
        // in_flight is incremented here and must be decremented on every
        // exit path below (fail-open return, error return, and completion).
        self.protocol_metrics.inc_in_flight();

        // Try sticky session first
        let (conn, used_sticky) =
            if let Some(sticky_conn) = self.get_sticky_session_conn(session_id) {
                (sticky_conn, true)
            } else {
                (self.select_connection(agent_id)?, false)
            };

        // Check flow control. Ok(false) means "proceed without the agent"
        // (fail-open), so a default-allow response is returned.
        match self.check_flow_control(&conn, agent_id).await {
            Ok(true) => {}
            Ok(false) => {
                self.protocol_metrics.dec_in_flight();
                return Ok((AgentResponse::default_allow(), used_sticky));
            }
            Err(e) => {
                self.protocol_metrics.dec_in_flight();
                return Err(e);
            }
        }

        // Acquire concurrency permit. acquire() only errors if the
        // semaphore is closed; the map_err closure also rolls back metrics.
        let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
            self.protocol_metrics.dec_in_flight();
            self.protocol_metrics.inc_connection_errors();
            AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
        })?;

        // NOTE(review): if this future is dropped (cancelled) during the
        // send below, conn.in_flight and protocol in_flight are never
        // decremented — consider a drop guard if callers can time out here.
        conn.in_flight.fetch_add(1, Ordering::Relaxed);
        conn.touch();

        // Store correlation affinity so later body chunks for this
        // correlation_id reuse the same connection.
        self.correlation_affinity
            .insert(correlation_id.to_string(), Arc::clone(&conn));

        let result = conn
            .client
            .send_request_headers(correlation_id, event)
            .await;

        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
        conn.request_count.fetch_add(1, Ordering::Relaxed);
        self.protocol_metrics.dec_in_flight();
        self.protocol_metrics
            .record_request_duration(start.elapsed());

        match &result {
            Ok(_) => {
                // Any success resets the consecutive-error streak.
                conn.consecutive_errors.store(0, Ordering::Relaxed);
                self.protocol_metrics.inc_responses();
            }
            Err(e) => {
                conn.error_count.fetch_add(1, Ordering::Relaxed);
                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
                self.total_errors.fetch_add(1, Ordering::Relaxed);

                // Bucket the error into the matching protocol metric.
                match e {
                    AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
                    AgentProtocolError::ConnectionFailed(_)
                    | AgentProtocolError::ConnectionClosed => {
                        self.protocol_metrics.inc_connection_errors();
                    }
                    AgentProtocolError::Serialization(_) => {
                        self.protocol_metrics.inc_serialization_errors();
                    }
                    _ => {}
                }

                // Three consecutive errors marks the cached health false,
                // matching the threshold in check_and_update_health.
                if consecutive >= 3 {
                    conn.healthy_cached.store(false, Ordering::Release);
                }
            }
        }

        result.map(|r| (r, used_sticky))
    }
842
843    /// Clean up expired sticky sessions.
844    ///
845    /// Called automatically by the maintenance task, but can also be called
846    /// manually to immediately reclaim resources.
847    pub fn cleanup_expired_sessions(&self) -> usize {
848        let Some(timeout) = self.config.sticky_session_timeout else {
849            return 0;
850        };
851
852        let mut removed = 0;
853        self.sticky_sessions.retain(|session_id, session| {
854            if session.is_expired(timeout) {
855                debug!(session_id = %session_id, "Removing expired sticky session");
856                removed += 1;
857                false
858            } else {
859                true
860            }
861        });
862
863        if removed > 0 {
864            trace!(removed = removed, "Cleaned up expired sticky sessions");
865        }
866
867        removed
868    }
869
    /// Get the config pusher for pushing configuration updates to agents.
    ///
    /// Exposes the underlying [`ConfigPusher`] for advanced use; the common
    /// operations are also available directly via `push_config_to_agent`,
    /// `push_config_to_all`, and `acknowledge_config_push`.
    pub fn config_pusher(&self) -> &ConfigPusher {
        &self.config_pusher
    }
874
    /// Get the config update handler for processing agent config requests.
    ///
    /// Borrowed accessor; the handler itself is shared pool state.
    pub fn config_update_handler(&self) -> &ConfigUpdateHandler {
        &self.config_update_handler
    }
879
    /// Push a configuration update to a specific agent.
    ///
    /// Returns the push ID if the agent supports config push, None otherwise.
    ///
    /// Thin delegate over [`ConfigPusher::push_to_agent`]; whether the agent
    /// supports push is determined by the capabilities recorded when the
    /// agent was registered.
    pub fn push_config_to_agent(
        &self,
        agent_id: &str,
        update_type: ConfigUpdateType,
    ) -> Option<String> {
        self.config_pusher.push_to_agent(agent_id, update_type)
    }
890
    /// Push a configuration update to all agents that support config push.
    ///
    /// Returns the push IDs for each agent that received the update.
    /// Agents registered without config-push support are skipped by the
    /// underlying [`ConfigPusher`].
    pub fn push_config_to_all(&self, update_type: ConfigUpdateType) -> Vec<String> {
        self.config_pusher.push_to_all(update_type)
    }
897
    /// Acknowledge a config push by its push ID.
    ///
    /// `accepted` records whether the agent applied the update; `error`
    /// carries the agent-reported failure reason when it did not.
    pub fn acknowledge_config_push(&self, push_id: &str, accepted: bool, error: Option<String>) {
        self.config_pusher.acknowledge(push_id, accepted, error);
    }
902
903    /// Add an agent to the pool.
904    ///
905    /// This creates the configured number of connections to the agent.
906    pub async fn add_agent(
907        &self,
908        agent_id: impl Into<String>,
909        endpoint: impl Into<String>,
910    ) -> Result<(), AgentProtocolError> {
911        let agent_id = agent_id.into();
912        let endpoint = endpoint.into();
913
914        info!(agent_id = %agent_id, endpoint = %endpoint, "Adding agent to pool");
915
916        let entry = Arc::new(AgentEntry::new(agent_id.clone(), endpoint.clone()));
917
918        // Create initial connections
919        let mut connections = Vec::with_capacity(self.config.connections_per_agent);
920        for i in 0..self.config.connections_per_agent {
921            match self.create_connection(&agent_id, &endpoint).await {
922                Ok(conn) => {
923                    connections.push(Arc::new(conn));
924                    debug!(
925                        agent_id = %agent_id,
926                        connection = i,
927                        "Created connection"
928                    );
929                }
930                Err(e) => {
931                    warn!(
932                        agent_id = %agent_id,
933                        connection = i,
934                        error = %e,
935                        "Failed to create connection"
936                    );
937                    // Continue - we'll try to reconnect later
938                }
939            }
940        }
941
942        if connections.is_empty() {
943            return Err(AgentProtocolError::ConnectionFailed(format!(
944                "Failed to create any connections to agent {}",
945                agent_id
946            )));
947        }
948
949        // Store capabilities from first successful connection and register with ConfigPusher
950        if let Some(conn) = connections.first() {
951            if let Some(caps) = conn.client.capabilities().await {
952                // Register with ConfigPusher based on capabilities
953                let supports_config_push = caps.features.config_push;
954                let agent_name = caps.name.clone();
955                self.config_pusher
956                    .register_agent(&agent_id, &agent_name, supports_config_push);
957                debug!(
958                    agent_id = %agent_id,
959                    supports_config_push = supports_config_push,
960                    "Registered agent with ConfigPusher"
961                );
962
963                *entry.capabilities.write().await = Some(caps);
964            }
965        }
966
967        *entry.connections.write().await = connections;
968        self.agents.insert(agent_id.clone(), entry);
969
970        info!(
971            agent_id = %agent_id,
972            connections = self.config.connections_per_agent,
973            "Agent added to pool"
974        );
975
976        Ok(())
977    }
978
979    /// Remove an agent from the pool.
980    ///
981    /// This gracefully closes all connections to the agent.
982    pub async fn remove_agent(&self, agent_id: &str) -> Result<(), AgentProtocolError> {
983        info!(agent_id = %agent_id, "Removing agent from pool");
984
985        // Unregister from ConfigPusher
986        self.config_pusher.unregister_agent(agent_id);
987
988        let (_, entry) = self.agents.remove(agent_id).ok_or_else(|| {
989            AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
990        })?;
991
992        // Close all connections
993        let connections = entry.connections.read().await;
994        for conn in connections.iter() {
995            let _ = conn.client.close().await;
996        }
997
998        info!(agent_id = %agent_id, "Agent removed from pool");
999        Ok(())
1000    }
1001
1002    /// Add a reverse connection to the pool.
1003    ///
1004    /// This is called by the ReverseConnectionListener when an agent connects.
1005    /// The connection is wrapped in a V2Transport and added to the agent's
1006    /// connection pool.
1007    pub async fn add_reverse_connection(
1008        &self,
1009        agent_id: &str,
1010        client: ReverseConnectionClient,
1011        capabilities: AgentCapabilities,
1012    ) -> Result<(), AgentProtocolError> {
1013        info!(
1014            agent_id = %agent_id,
1015            connection_id = %client.connection_id(),
1016            "Adding reverse connection to pool"
1017        );
1018
1019        let transport = V2Transport::Reverse(client);
1020        let conn = Arc::new(PooledConnection::new(
1021            transport,
1022            self.config.max_concurrent_per_connection,
1023        ));
1024
1025        // Check if agent already exists (use entry API for atomic check-and-insert)
1026        if let Some(entry) = self.agents.get(agent_id) {
1027            // Add to existing agent's connections
1028            let mut connections = entry.connections.write().await;
1029
1030            // Check connection limit
1031            if connections.len() >= self.config.connections_per_agent {
1032                warn!(
1033                    agent_id = %agent_id,
1034                    current = connections.len(),
1035                    max = self.config.connections_per_agent,
1036                    "Reverse connection rejected: at connection limit"
1037                );
1038                return Err(AgentProtocolError::ConnectionFailed(format!(
1039                    "Agent {} already has maximum connections ({})",
1040                    agent_id, self.config.connections_per_agent
1041                )));
1042            }
1043
1044            connections.push(conn);
1045            info!(
1046                agent_id = %agent_id,
1047                total_connections = connections.len(),
1048                "Added reverse connection to existing agent"
1049            );
1050        } else {
1051            // Create new agent entry
1052            let entry = Arc::new(AgentEntry::new(
1053                agent_id.to_string(),
1054                format!("reverse://{}", agent_id),
1055            ));
1056
1057            // Register with ConfigPusher
1058            let supports_config_push = capabilities.features.config_push;
1059            let agent_name = capabilities.name.clone();
1060            self.config_pusher
1061                .register_agent(agent_id, &agent_name, supports_config_push);
1062            debug!(
1063                agent_id = %agent_id,
1064                supports_config_push = supports_config_push,
1065                "Registered reverse connection agent with ConfigPusher"
1066            );
1067
1068            *entry.capabilities.write().await = Some(capabilities);
1069            *entry.connections.write().await = vec![conn];
1070            self.agents.insert(agent_id.to_string(), entry);
1071
1072            info!(
1073                agent_id = %agent_id,
1074                "Created new agent entry for reverse connection"
1075            );
1076        }
1077
1078        Ok(())
1079    }
1080
1081    /// Check flow control and handle according to configured mode.
1082    ///
1083    /// Returns `Ok(true)` if request should proceed normally.
1084    /// Returns `Ok(false)` if request should skip agent (FailOpen mode).
1085    /// Returns `Err` if request should fail (FailClosed or WaitAndRetry timeout).
1086    async fn check_flow_control(
1087        &self,
1088        conn: &PooledConnection,
1089        agent_id: &str,
1090    ) -> Result<bool, AgentProtocolError> {
1091        if conn.client.can_accept_requests().await {
1092            return Ok(true);
1093        }
1094
1095        match self.config.flow_control_mode {
1096            FlowControlMode::FailClosed => {
1097                self.protocol_metrics.record_flow_rejection();
1098                Err(AgentProtocolError::FlowControlPaused {
1099                    agent_id: agent_id.to_string(),
1100                })
1101            }
1102            FlowControlMode::FailOpen => {
1103                // Log but allow through
1104                debug!(agent_id = %agent_id, "Flow control: agent paused, allowing request (fail-open mode)");
1105                self.protocol_metrics.record_flow_rejection();
1106                Ok(false) // Signal to skip agent processing
1107            }
1108            FlowControlMode::WaitAndRetry => {
1109                // Wait briefly for agent to resume
1110                let deadline = Instant::now() + self.config.flow_control_wait_timeout;
1111                while Instant::now() < deadline {
1112                    tokio::time::sleep(Duration::from_millis(10)).await;
1113                    if conn.client.can_accept_requests().await {
1114                        trace!(agent_id = %agent_id, "Flow control: agent resumed after wait");
1115                        return Ok(true);
1116                    }
1117                }
1118                // Timeout - fail the request
1119                self.protocol_metrics.record_flow_rejection();
1120                Err(AgentProtocolError::FlowControlPaused {
1121                    agent_id: agent_id.to_string(),
1122                })
1123            }
1124        }
1125    }
1126
1127    /// Send a request headers event to an agent.
1128    ///
1129    /// The pool selects the best connection based on the load balancing strategy.
1130    ///
1131    /// # Performance
1132    ///
1133    /// This is the hot path. Uses:
1134    /// - Lock-free agent lookup via `DashMap`
1135    /// - Cached health state (no async I/O for health check)
1136    /// - Atomic last_used tracking (no RwLock)
1137    pub async fn send_request_headers(
1138        &self,
1139        agent_id: &str,
1140        correlation_id: &str,
1141        event: &RequestHeadersEvent,
1142    ) -> Result<AgentResponse, AgentProtocolError> {
1143        let start = Instant::now();
1144        self.total_requests.fetch_add(1, Ordering::Relaxed);
1145        self.protocol_metrics.inc_requests();
1146        self.protocol_metrics.inc_in_flight();
1147
1148        let conn = self.select_connection(agent_id)?;
1149
1150        // Check flow control before sending (respects flow_control_mode config)
1151        match self.check_flow_control(&conn, agent_id).await {
1152            Ok(true) => {} // Proceed normally
1153            Ok(false) => {
1154                // FailOpen mode: skip agent, return allow response
1155                self.protocol_metrics.dec_in_flight();
1156                return Ok(AgentResponse::default_allow());
1157            }
1158            Err(e) => {
1159                self.protocol_metrics.dec_in_flight();
1160                return Err(e);
1161            }
1162        }
1163
1164        // Acquire concurrency permit
1165        let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1166            self.protocol_metrics.dec_in_flight();
1167            self.protocol_metrics.inc_connection_errors();
1168            AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1169        })?;
1170
1171        conn.in_flight.fetch_add(1, Ordering::Relaxed);
1172        conn.touch(); // Atomic, no lock
1173
1174        // Store connection affinity for body chunk routing
1175        self.correlation_affinity
1176            .insert(correlation_id.to_string(), Arc::clone(&conn));
1177
1178        let result = conn
1179            .client
1180            .send_request_headers(correlation_id, event)
1181            .await;
1182
1183        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1184        conn.request_count.fetch_add(1, Ordering::Relaxed);
1185        self.protocol_metrics.dec_in_flight();
1186        self.protocol_metrics
1187            .record_request_duration(start.elapsed());
1188
1189        match &result {
1190            Ok(_) => {
1191                conn.consecutive_errors.store(0, Ordering::Relaxed);
1192                self.protocol_metrics.inc_responses();
1193            }
1194            Err(e) => {
1195                conn.error_count.fetch_add(1, Ordering::Relaxed);
1196                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1197                self.total_errors.fetch_add(1, Ordering::Relaxed);
1198
1199                // Record error type
1200                match e {
1201                    AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
1202                    AgentProtocolError::ConnectionFailed(_)
1203                    | AgentProtocolError::ConnectionClosed => {
1204                        self.protocol_metrics.inc_connection_errors();
1205                    }
1206                    AgentProtocolError::Serialization(_) => {
1207                        self.protocol_metrics.inc_serialization_errors();
1208                    }
1209                    _ => {}
1210                }
1211
1212                // Mark unhealthy immediately on repeated failures (fast feedback)
1213                if consecutive >= 3 {
1214                    conn.healthy_cached.store(false, Ordering::Release);
1215                    trace!(agent_id = %agent_id, error = %e, "Connection marked unhealthy after consecutive errors");
1216                }
1217            }
1218        }
1219
1220        result
1221    }
1222
1223    /// Send a request body chunk to an agent.
1224    ///
1225    /// The pool uses correlation_id to route the chunk to the same connection
1226    /// that received the request headers (connection affinity).
1227    pub async fn send_request_body_chunk(
1228        &self,
1229        agent_id: &str,
1230        correlation_id: &str,
1231        event: &RequestBodyChunkEvent,
1232    ) -> Result<AgentResponse, AgentProtocolError> {
1233        self.total_requests.fetch_add(1, Ordering::Relaxed);
1234
1235        // Try to use affinity (same connection as headers), fall back to selection
1236        let conn = if let Some(affinity_conn) = self.correlation_affinity.get(correlation_id) {
1237            Arc::clone(&affinity_conn)
1238        } else {
1239            // No affinity found, use normal selection (shouldn't happen in normal flow)
1240            trace!(correlation_id = %correlation_id, "No affinity found for body chunk, using selection");
1241            self.select_connection(agent_id)?
1242        };
1243
1244        // Check flow control before sending body chunks (critical for backpressure)
1245        match self.check_flow_control(&conn, agent_id).await {
1246            Ok(true) => {} // Proceed normally
1247            Ok(false) => {
1248                // FailOpen mode: skip agent, return allow response
1249                return Ok(AgentResponse::default_allow());
1250            }
1251            Err(e) => return Err(e),
1252        }
1253
1254        let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1255            AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1256        })?;
1257
1258        conn.in_flight.fetch_add(1, Ordering::Relaxed);
1259        conn.touch();
1260
1261        let result = conn
1262            .client
1263            .send_request_body_chunk(correlation_id, event)
1264            .await;
1265
1266        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1267        conn.request_count.fetch_add(1, Ordering::Relaxed);
1268
1269        match &result {
1270            Ok(_) => {
1271                conn.consecutive_errors.store(0, Ordering::Relaxed);
1272            }
1273            Err(_) => {
1274                conn.error_count.fetch_add(1, Ordering::Relaxed);
1275                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1276                self.total_errors.fetch_add(1, Ordering::Relaxed);
1277                if consecutive >= 3 {
1278                    conn.healthy_cached.store(false, Ordering::Release);
1279                }
1280            }
1281        }
1282
1283        result
1284    }
1285
1286    /// Send response headers to an agent.
1287    ///
1288    /// Called when upstream response headers are received, allowing the agent
1289    /// to inspect/modify response headers before they're sent to the client.
1290    pub async fn send_response_headers(
1291        &self,
1292        agent_id: &str,
1293        correlation_id: &str,
1294        event: &ResponseHeadersEvent,
1295    ) -> Result<AgentResponse, AgentProtocolError> {
1296        self.total_requests.fetch_add(1, Ordering::Relaxed);
1297
1298        let conn = self.select_connection(agent_id)?;
1299
1300        let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1301            AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1302        })?;
1303
1304        conn.in_flight.fetch_add(1, Ordering::Relaxed);
1305        conn.touch();
1306
1307        let result = conn
1308            .client
1309            .send_response_headers(correlation_id, event)
1310            .await;
1311
1312        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1313        conn.request_count.fetch_add(1, Ordering::Relaxed);
1314
1315        match &result {
1316            Ok(_) => {
1317                conn.consecutive_errors.store(0, Ordering::Relaxed);
1318            }
1319            Err(_) => {
1320                conn.error_count.fetch_add(1, Ordering::Relaxed);
1321                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1322                self.total_errors.fetch_add(1, Ordering::Relaxed);
1323                if consecutive >= 3 {
1324                    conn.healthy_cached.store(false, Ordering::Release);
1325                }
1326            }
1327        }
1328
1329        result
1330    }
1331
    /// Send a response body chunk to an agent.
    ///
    /// For streaming response body inspection, chunks are sent sequentially.
    /// The agent can inspect and optionally modify response body data.
    ///
    /// # Errors
    ///
    /// Propagates flow-control rejections (FailClosed / WaitAndRetry
    /// timeout), a closed concurrency limiter, and transport failures.
    ///
    /// NOTE(review): unlike `send_request_body_chunk`, this path does not
    /// consult `correlation_affinity`, so response chunks may land on a
    /// different connection than earlier events with the same
    /// correlation_id — confirm agents correlate purely by correlation_id
    /// across connections.
    pub async fn send_response_body_chunk(
        &self,
        agent_id: &str,
        correlation_id: &str,
        event: &ResponseBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        let conn = self.select_connection(agent_id)?;

        // Check flow control before sending body chunks (critical for backpressure)
        match self.check_flow_control(&conn, agent_id).await {
            Ok(true) => {} // Proceed normally
            Ok(false) => {
                // FailOpen mode: skip agent, return allow response
                return Ok(AgentResponse::default_allow());
            }
            Err(e) => return Err(e),
        }

        // A closed semaphore is reported as a connection failure.
        let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
            AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
        })?;

        conn.in_flight.fetch_add(1, Ordering::Relaxed);
        conn.touch();

        let result = conn
            .client
            .send_response_body_chunk(correlation_id, event)
            .await;

        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
        conn.request_count.fetch_add(1, Ordering::Relaxed);

        match &result {
            Ok(_) => {
                // Success clears the consecutive-error streak.
                conn.consecutive_errors.store(0, Ordering::Relaxed);
            }
            Err(_) => {
                conn.error_count.fetch_add(1, Ordering::Relaxed);
                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
                self.total_errors.fetch_add(1, Ordering::Relaxed);
                // Three consecutive failures flip the cached health flag so
                // the hot-path selector stops routing to this connection.
                if consecutive >= 3 {
                    conn.healthy_cached.store(false, Ordering::Release);
                }
            }
        }

        result
    }
1387
1388    /// Cancel a request on all connections for an agent.
1389    pub async fn cancel_request(
1390        &self,
1391        agent_id: &str,
1392        correlation_id: &str,
1393        reason: CancelReason,
1394    ) -> Result<(), AgentProtocolError> {
1395        let entry = self.agents.get(agent_id).ok_or_else(|| {
1396            AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
1397        })?;
1398
1399        let connections = entry.connections.read().await;
1400        for conn in connections.iter() {
1401            let _ = conn.client.cancel_request(correlation_id, reason).await;
1402        }
1403
1404        Ok(())
1405    }
1406
1407    /// Get statistics for all agents in the pool.
1408    pub async fn stats(&self) -> Vec<AgentPoolStats> {
1409        let mut stats = Vec::with_capacity(self.agents.len());
1410
1411        for entry_ref in self.agents.iter() {
1412            let agent_id = entry_ref.key().clone();
1413            let entry = entry_ref.value();
1414
1415            let connections = entry.connections.read().await;
1416            let mut healthy_count = 0;
1417            let mut total_in_flight = 0;
1418            let mut total_requests = 0;
1419            let mut total_errors = 0;
1420
1421            for conn in connections.iter() {
1422                // Use cached health for stats (consistent with hot path)
1423                if conn.is_healthy_cached() {
1424                    healthy_count += 1;
1425                }
1426                total_in_flight += conn.in_flight();
1427                total_requests += conn.request_count.load(Ordering::Relaxed);
1428                total_errors += conn.error_count.load(Ordering::Relaxed);
1429            }
1430
1431            let error_rate = if total_requests == 0 {
1432                0.0
1433            } else {
1434                total_errors as f64 / total_requests as f64
1435            };
1436
1437            stats.push(AgentPoolStats {
1438                agent_id,
1439                active_connections: connections.len(),
1440                healthy_connections: healthy_count,
1441                total_in_flight,
1442                total_requests,
1443                total_errors,
1444                error_rate,
1445                is_healthy: entry.healthy.load(Ordering::Acquire),
1446            });
1447        }
1448
1449        stats
1450    }
1451
1452    /// Get statistics for a specific agent.
1453    pub async fn agent_stats(&self, agent_id: &str) -> Option<AgentPoolStats> {
1454        self.stats()
1455            .await
1456            .into_iter()
1457            .find(|s| s.agent_id == agent_id)
1458    }
1459
1460    /// Get the capabilities of an agent.
1461    pub async fn agent_capabilities(&self, agent_id: &str) -> Option<AgentCapabilities> {
1462        // Clone the Arc out of the DashMap Ref to avoid lifetime issues
1463        let entry = match self.agents.get(agent_id) {
1464            Some(entry_ref) => Arc::clone(&*entry_ref),
1465            None => return None,
1466        };
1467        // Bind to temp to ensure guard drops before function returns
1468        let result = entry.capabilities.read().await.clone();
1469        result
1470    }
1471
1472    /// Check if an agent is healthy.
1473    ///
1474    /// Uses cached health state for fast, lock-free access.
1475    pub fn is_agent_healthy(&self, agent_id: &str) -> bool {
1476        self.agents
1477            .get(agent_id)
1478            .map(|e| e.healthy.load(Ordering::Acquire))
1479            .unwrap_or(false)
1480    }
1481
1482    /// Get all agent IDs in the pool.
1483    pub fn agent_ids(&self) -> Vec<String> {
1484        self.agents.iter().map(|e| e.key().clone()).collect()
1485    }
1486
    /// Gracefully shut down the pool.
    ///
    /// This drains all connections and waits for in-flight requests to complete.
    ///
    /// Per-agent sequence: cancel pending requests, poll in-flight counts
    /// until zero or `drain_timeout` elapses, then close the transports.
    /// Always returns `Ok(())`; per-connection cancel/close errors are
    /// ignored.
    pub async fn shutdown(&self) -> Result<(), AgentProtocolError> {
        info!("Shutting down agent pool");

        // Collect all agents (DashMap doesn't have drain, so we remove one by one)
        let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();

        for agent_id in agent_ids {
            // remove() returns None if another task already removed the agent.
            if let Some((_, entry)) = self.agents.remove(&agent_id) {
                debug!(agent_id = %agent_id, "Draining agent connections");

                let connections = entry.connections.read().await;
                for conn in connections.iter() {
                    // Cancel all pending requests
                    let _ = conn.client.cancel_all(CancelReason::ProxyShutdown).await;
                }

                // Wait for in-flight requests to complete
                let drain_deadline = Instant::now() + self.config.drain_timeout;
                loop {
                    let total_in_flight: u64 = connections.iter().map(|c| c.in_flight()).sum();
                    if total_in_flight == 0 {
                        break;
                    }
                    if Instant::now() > drain_deadline {
                        warn!(
                            agent_id = %agent_id,
                            in_flight = total_in_flight,
                            "Drain timeout, forcing close"
                        );
                        break;
                    }
                    // Poll every 100ms; the cancellations above should make
                    // the in-flight counts drop quickly.
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }

                // Close all connections
                for conn in connections.iter() {
                    let _ = conn.client.close().await;
                }
            }
        }

        info!("Agent pool shutdown complete");
        Ok(())
    }
1534
    /// Run background maintenance tasks.
    ///
    /// This should be spawned as a background task. It handles:
    /// - Health checking (updates cached health state)
    /// - Reconnection of failed connections
    /// - Cleanup of idle connections
    ///
    /// # Health Check Strategy
    ///
    /// Health is checked here (with I/O) and cached in `PooledConnection::healthy_cached`.
    /// The hot path (`select_connection`) reads the cached value without I/O.
    ///
    /// This function never returns; it ticks on `health_check_interval`.
    pub async fn run_maintenance(&self) {
        let mut interval = tokio::time::interval(self.config.health_check_interval);

        loop {
            interval.tick().await;

            // Clean up expired sticky sessions
            self.cleanup_expired_sessions();

            // Iterate without holding a long-lived lock
            let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();

            for agent_id in agent_ids {
                let Some(entry_ref) = self.agents.get(&agent_id) else {
                    continue; // Agent was removed
                };
                // Clone the Arc so no DashMap shard guard is held across
                // the awaits below.
                let entry = entry_ref.value().clone();
                drop(entry_ref); // Release DashMap ref before async work

                // Check connection health (this does I/O)
                let connections = entry.connections.read().await;
                let mut healthy_count = 0;

                for conn in connections.iter() {
                    // Full health check with I/O, updates cached state
                    if conn.check_and_update_health().await {
                        healthy_count += 1;
                    }
                }

                // Update aggregate agent health status
                let was_healthy = entry.healthy.load(Ordering::Acquire);
                let is_healthy = healthy_count > 0;
                entry.healthy.store(is_healthy, Ordering::Release);

                // Log only on transitions to avoid noisy periodic output.
                if was_healthy && !is_healthy {
                    warn!(agent_id = %agent_id, "Agent marked unhealthy");
                } else if !was_healthy && is_healthy {
                    info!(agent_id = %agent_id, "Agent recovered");
                }

                // Try to reconnect failed connections
                if healthy_count < self.config.connections_per_agent
                    && entry.should_reconnect(self.config.reconnect_interval)
                {
                    drop(connections); // Release read lock before reconnect
                    if let Err(e) = self.reconnect_agent(&agent_id, &entry).await {
                        trace!(agent_id = %agent_id, error = %e, "Reconnect failed");
                    }
                }
            }
        }
    }
1599
1600    // =========================================================================
1601    // Internal Methods
1602    // =========================================================================
1603
1604    async fn create_connection(
1605        &self,
1606        agent_id: &str,
1607        endpoint: &str,
1608    ) -> Result<PooledConnection, AgentProtocolError> {
1609        // Detect transport type from endpoint
1610        let transport = if is_uds_endpoint(endpoint) {
1611            // Unix Domain Socket transport
1612            let socket_path = endpoint.strip_prefix("unix:").unwrap_or(endpoint);
1613
1614            let mut client =
1615                AgentClientV2Uds::new(agent_id, socket_path, self.config.request_timeout).await?;
1616
1617            // Set callbacks before connecting
1618            client.set_metrics_callback(Arc::clone(&self.metrics_callback));
1619            client.set_config_update_callback(Arc::clone(&self.config_update_callback));
1620
1621            client.connect().await?;
1622            V2Transport::Uds(client)
1623        } else {
1624            // gRPC transport (default)
1625            let mut client =
1626                AgentClientV2::new(agent_id, endpoint, self.config.request_timeout).await?;
1627
1628            // Set callbacks before connecting
1629            client.set_metrics_callback(Arc::clone(&self.metrics_callback));
1630            client.set_config_update_callback(Arc::clone(&self.config_update_callback));
1631
1632            client.connect().await?;
1633            V2Transport::Grpc(client)
1634        };
1635
1636        Ok(PooledConnection::new(
1637            transport,
1638            self.config.max_concurrent_per_connection,
1639        ))
1640    }
1641
1642    /// Select a connection for a request.
1643    ///
1644    /// # Performance
1645    ///
1646    /// This is the hot path. Optimizations:
1647    /// - Lock-free agent lookup via `DashMap::get()`
1648    /// - Uses `try_read()` to avoid blocking on connections lock
1649    /// - Cached health state (no async I/O)
1650    /// - All operations are synchronous
1651    ///
1652    /// # Errors
1653    ///
1654    /// Returns error if agent not found, no connections, or no healthy connections.
1655    fn select_connection(
1656        &self,
1657        agent_id: &str,
1658    ) -> Result<Arc<PooledConnection>, AgentProtocolError> {
1659        let entry = self.agents.get(agent_id).ok_or_else(|| {
1660            AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
1661        })?;
1662
1663        // Try non-blocking read first; fall back to blocking if contended
1664        let connections_guard = match entry.connections.try_read() {
1665            Ok(guard) => guard,
1666            Err(_) => {
1667                // Blocking fallback - this should be rare
1668                trace!(agent_id = %agent_id, "select_connection: blocking on connections lock");
1669                futures::executor::block_on(entry.connections.read())
1670            }
1671        };
1672
1673        if connections_guard.is_empty() {
1674            return Err(AgentProtocolError::ConnectionFailed(format!(
1675                "No connections available for agent {}",
1676                agent_id
1677            )));
1678        }
1679
1680        // Filter to healthy connections using cached health (no I/O)
1681        let healthy: Vec<_> = connections_guard
1682            .iter()
1683            .filter(|c| c.is_healthy_cached())
1684            .cloned()
1685            .collect();
1686
1687        if healthy.is_empty() {
1688            return Err(AgentProtocolError::ConnectionFailed(format!(
1689                "No healthy connections for agent {}",
1690                agent_id
1691            )));
1692        }
1693
1694        let selected = match self.config.load_balance_strategy {
1695            LoadBalanceStrategy::RoundRobin => {
1696                let idx = entry.round_robin_index.fetch_add(1, Ordering::Relaxed);
1697                healthy[idx % healthy.len()].clone()
1698            }
1699            LoadBalanceStrategy::LeastConnections => healthy
1700                .iter()
1701                .min_by_key(|c| c.in_flight())
1702                .cloned()
1703                .unwrap(),
1704            LoadBalanceStrategy::HealthBased => {
1705                // Prefer connections with lower error rates
1706                healthy
1707                    .iter()
1708                    .min_by(|a, b| {
1709                        a.error_rate()
1710                            .partial_cmp(&b.error_rate())
1711                            .unwrap_or(std::cmp::Ordering::Equal)
1712                    })
1713                    .cloned()
1714                    .unwrap()
1715            }
1716            LoadBalanceStrategy::Random => {
1717                use std::collections::hash_map::RandomState;
1718                use std::hash::{BuildHasher, Hasher};
1719                let idx = RandomState::new().build_hasher().finish() as usize % healthy.len();
1720                healthy[idx].clone()
1721            }
1722        };
1723
1724        Ok(selected)
1725    }
1726
1727    async fn reconnect_agent(
1728        &self,
1729        agent_id: &str,
1730        entry: &AgentEntry,
1731    ) -> Result<(), AgentProtocolError> {
1732        entry.mark_reconnect_attempt();
1733        let attempts = entry.reconnect_attempts.fetch_add(1, Ordering::Relaxed);
1734
1735        if attempts >= self.config.max_reconnect_attempts {
1736            debug!(
1737                agent_id = %agent_id,
1738                attempts = attempts,
1739                "Max reconnect attempts reached"
1740            );
1741            return Ok(());
1742        }
1743
1744        debug!(agent_id = %agent_id, attempt = attempts + 1, "Attempting reconnect");
1745
1746        match self.create_connection(agent_id, &entry.endpoint).await {
1747            Ok(conn) => {
1748                let mut connections = entry.connections.write().await;
1749                connections.push(Arc::new(conn));
1750                entry.reconnect_attempts.store(0, Ordering::Relaxed);
1751                info!(agent_id = %agent_id, "Reconnected successfully");
1752                Ok(())
1753            }
1754            Err(e) => {
1755                debug!(agent_id = %agent_id, error = %e, "Reconnect failed");
1756                Err(e)
1757            }
1758        }
1759    }
1760}
1761
impl Default for AgentPool {
    fn default() -> Self {
        // Delegate to the primary constructor so `AgentPool::default()` and
        // `AgentPool::new()` can never drift apart.
        Self::new()
    }
}
1767
1768impl std::fmt::Debug for AgentPool {
1769    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1770        f.debug_struct("AgentPool")
1771            .field("config", &self.config)
1772            .field(
1773                "total_requests",
1774                &self.total_requests.load(Ordering::Relaxed),
1775            )
1776            .field("total_errors", &self.total_errors.load(Ordering::Relaxed))
1777            .finish()
1778    }
1779}
1780
/// Check if an endpoint is a Unix Domain Socket path.
///
/// Returns true for endpoints that:
/// - Start with "unix:" prefix
/// - Are absolute paths (start with "/")
/// - Have ".sock" extension
fn is_uds_endpoint(endpoint: &str) -> bool {
    let has_unix_scheme = endpoint.starts_with("unix:");
    let is_absolute_path = endpoint.starts_with('/');
    let has_sock_extension = endpoint.ends_with(".sock");
    has_unix_scheme || is_absolute_path || has_sock_extension
}
1790
#[cfg(test)]
mod tests {
    use super::*;

    // --- Config defaults ---------------------------------------------------

    #[test]
    fn test_pool_config_default() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.connections_per_agent, 4);
        assert_eq!(
            config.load_balance_strategy,
            LoadBalanceStrategy::RoundRobin
        );
    }

    #[test]
    fn test_load_balance_strategy() {
        assert_eq!(
            LoadBalanceStrategy::default(),
            LoadBalanceStrategy::RoundRobin
        );
    }

    // --- Pool construction -------------------------------------------------

    #[test]
    fn test_pool_creation() {
        let pool = AgentPool::new();
        assert_eq!(pool.total_requests.load(Ordering::Relaxed), 0);
        assert_eq!(pool.total_errors.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_pool_with_config() {
        let config = AgentPoolConfig {
            connections_per_agent: 8,
            load_balance_strategy: LoadBalanceStrategy::LeastConnections,
            ..Default::default()
        };
        // Move the config in directly — the redundant `.clone()` here was
        // flagged by clippy (`redundant_clone`); `config` is not used again.
        let pool = AgentPool::with_config(config);
        assert_eq!(pool.config.connections_per_agent, 8);
        assert_eq!(
            pool.config.load_balance_strategy,
            LoadBalanceStrategy::LeastConnections
        );
    }

    #[test]
    fn test_agent_ids_empty() {
        let pool = AgentPool::new();
        assert!(pool.agent_ids().is_empty());
    }

    #[test]
    fn test_is_agent_healthy_not_found() {
        let pool = AgentPool::new();
        assert!(!pool.is_agent_healthy("nonexistent"));
    }

    #[tokio::test]
    async fn test_stats_empty() {
        let pool = AgentPool::new();
        assert!(pool.stats().await.is_empty());
    }

    // --- Endpoint detection ------------------------------------------------

    #[test]
    fn test_is_uds_endpoint() {
        // Unix prefix
        assert!(is_uds_endpoint("unix:/var/run/agent.sock"));
        assert!(is_uds_endpoint("unix:agent.sock"));

        // Absolute path
        assert!(is_uds_endpoint("/var/run/agent.sock"));
        assert!(is_uds_endpoint("/tmp/test.sock"));

        // .sock extension
        assert!(is_uds_endpoint("agent.sock"));

        // Not UDS
        assert!(!is_uds_endpoint("http://localhost:8080"));
        assert!(!is_uds_endpoint("localhost:50051"));
        assert!(!is_uds_endpoint("127.0.0.1:8080"));
    }

    // --- Flow control ------------------------------------------------------

    #[test]
    fn test_flow_control_mode_default() {
        assert_eq!(FlowControlMode::default(), FlowControlMode::FailClosed);
    }

    #[test]
    fn test_pool_config_flow_control_defaults() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.channel_buffer_size, CHANNEL_BUFFER_SIZE);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailClosed);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(100));
    }

    #[test]
    fn test_pool_config_custom_flow_control() {
        let config = AgentPoolConfig {
            channel_buffer_size: 128,
            flow_control_mode: FlowControlMode::FailOpen,
            flow_control_wait_timeout: Duration::from_millis(500),
            ..Default::default()
        };
        assert_eq!(config.channel_buffer_size, 128);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailOpen);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(500));
    }

    #[test]
    fn test_pool_config_wait_and_retry() {
        let config = AgentPoolConfig {
            flow_control_mode: FlowControlMode::WaitAndRetry,
            flow_control_wait_timeout: Duration::from_millis(250),
            ..Default::default()
        };
        assert_eq!(config.flow_control_mode, FlowControlMode::WaitAndRetry);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(250));
    }

    // --- Sticky sessions ---------------------------------------------------

    #[test]
    fn test_pool_config_sticky_session_default() {
        let config = AgentPoolConfig::default();
        assert_eq!(
            config.sticky_session_timeout,
            Some(Duration::from_secs(5 * 60))
        );
    }

    #[test]
    fn test_pool_config_sticky_session_custom() {
        let config = AgentPoolConfig {
            sticky_session_timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        assert_eq!(config.sticky_session_timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_pool_config_sticky_session_disabled() {
        let config = AgentPoolConfig {
            sticky_session_timeout: None,
            ..Default::default()
        };
        assert!(config.sticky_session_timeout.is_none());
    }

    #[test]
    fn test_sticky_session_count_empty() {
        let pool = AgentPool::new();
        assert_eq!(pool.sticky_session_count(), 0);
    }

    #[test]
    fn test_sticky_session_has_nonexistent() {
        let pool = AgentPool::new();
        assert!(!pool.has_sticky_session("nonexistent"));
    }

    #[test]
    fn test_sticky_session_clear_nonexistent() {
        let pool = AgentPool::new();
        // Should not panic
        pool.clear_sticky_session("nonexistent");
    }

    #[test]
    fn test_cleanup_expired_sessions_empty() {
        let pool = AgentPool::new();
        let removed = pool.cleanup_expired_sessions();
        assert_eq!(removed, 0);
    }
}