sentinel_agent_protocol/v2/pool.rs

//! Agent connection pool for Protocol v2.
//!
//! This module provides a production-ready connection pool for managing
//! multiple connections to agents with:
//!
//! - **Connection pooling**: Maintain multiple connections per agent
//! - **Load balancing**: Round-robin, least-connections, or health-based routing
//! - **Health tracking**: Route requests based on agent health
//! - **Automatic reconnection**: Reconnect failed connections
//! - **Graceful shutdown**: Drain connections before closing

use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, info, trace, warn};

use crate::v2::client::{AgentClientV2, CancelReason, ConfigUpdateCallback, MetricsCallback};
use crate::v2::control::ConfigUpdateType;
use crate::v2::observability::{ConfigPusher, ConfigUpdateHandler, MetricsCollector};
use crate::v2::protocol_metrics::ProtocolMetrics;
use crate::v2::reverse::ReverseConnectionClient;
use crate::v2::uds::AgentClientV2Uds;
use crate::v2::AgentCapabilities;
use crate::{
    AgentProtocolError, AgentResponse, RequestBodyChunkEvent, RequestHeadersEvent,
    ResponseBodyChunkEvent, ResponseHeadersEvent,
};

/// Channel buffer size for all transports.
///
/// This is aligned across UDS, gRPC, and reverse connections to ensure
/// consistent backpressure behavior. A smaller buffer (64 vs 1024) means
/// backpressure kicks in earlier, preventing memory buildup under load.
///
/// The value 64 balances:
/// - Small enough to apply backpressure before memory issues arise
/// - Large enough to handle burst traffic without blocking
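///
/// A minimal sketch (using the `AgentPoolConfig` defined below) of
/// overriding this default for burstier workloads:
///
/// ```ignore
/// let config = AgentPoolConfig {
///     channel_buffer_size: 128, // trade memory for burst absorption
///     ..AgentPoolConfig::default()
/// };
/// ```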
pub const CHANNEL_BUFFER_SIZE: usize = 64;

/// Load balancing strategy for the connection pool.
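///
/// # Example
///
/// A hedged sketch of selecting a non-default strategy through
/// `AgentPoolConfig` (defined later in this module):
///
/// ```ignore
/// let config = AgentPoolConfig {
///     load_balance_strategy: LoadBalanceStrategy::LeastConnections,
///     ..AgentPoolConfig::default()
/// };
/// let pool = AgentPool::with_config(config);
/// ```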
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
    /// Round-robin across all healthy connections
    #[default]
    RoundRobin,
    /// Route to connection with fewest in-flight requests
    LeastConnections,
    /// Route based on health score (prefer healthier agents)
    HealthBased,
    /// Random selection
    Random,
}

/// Flow control behavior when an agent signals it cannot accept requests.
///
/// When an agent sends a flow control "pause" signal, this determines
/// whether requests should fail immediately or be allowed through.
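///
/// # Example
///
/// A minimal sketch of opting into `WaitAndRetry` with a short wait window,
/// using the `AgentPoolConfig` fields defined below:
///
/// ```ignore
/// let config = AgentPoolConfig {
///     flow_control_mode: FlowControlMode::WaitAndRetry,
///     flow_control_wait_timeout: Duration::from_millis(50),
///     ..AgentPoolConfig::default()
/// };
/// ```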
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlowControlMode {
    /// Fail requests when agent is paused (default, safer).
    ///
    /// Returns `AgentProtocolError::FlowControlPaused` immediately.
    /// Use this when you want strict backpressure and can handle
    /// the error at the caller (e.g., return 503 to client).
    #[default]
    FailClosed,

    /// Allow requests through even when agent is paused.
    ///
    /// Requests proceed without agent processing. Use this when
    /// agent processing is optional (e.g., logging, analytics)
    /// and you prefer availability over consistency.
    FailOpen,

    /// Queue requests briefly, then fail if still paused.
    ///
    /// Waits up to `flow_control_wait_timeout` for the agent to
    /// resume before failing. Useful for transient pauses.
    WaitAndRetry,
}

/// A sticky session entry tracking connection affinity for long-lived streams.
///
/// Used for WebSocket connections, Server-Sent Events, long-polling, and other
/// streaming scenarios where the same agent connection should be used for the
/// entire stream lifetime.
struct StickySession {
    /// The connection to use for this session
    connection: Arc<PooledConnection>,
    /// Agent ID for this session
    agent_id: String,
    /// When the session was created
    created_at: Instant,
    /// Milliseconds since created_at when the session was last accessed
    /// (stored as an offset to avoid a lock; see `last_accessed()`)
    last_accessed: AtomicU64,
}

impl std::fmt::Debug for StickySession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StickySession")
            .field("agent_id", &self.agent_id)
            .field("created_at", &self.created_at)
            .finish_non_exhaustive()
    }
}

/// Configuration for the agent connection pool.
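///
/// # Example
///
/// A minimal sketch tightening the default timeouts:
///
/// ```ignore
/// let config = AgentPoolConfig {
///     connect_timeout: Duration::from_secs(2),
///     request_timeout: Duration::from_secs(10),
///     ..AgentPoolConfig::default()
/// };
/// ```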
#[derive(Debug, Clone)]
pub struct AgentPoolConfig {
    /// Number of connections to maintain per agent
    pub connections_per_agent: usize,
    /// Load balancing strategy
    pub load_balance_strategy: LoadBalanceStrategy,
    /// Connection timeout
    pub connect_timeout: Duration,
    /// Request timeout
    pub request_timeout: Duration,
    /// Time between reconnection attempts
    pub reconnect_interval: Duration,
    /// Maximum reconnection attempts before marking agent unhealthy
    pub max_reconnect_attempts: usize,
    /// Time to wait for in-flight requests during shutdown
    pub drain_timeout: Duration,
    /// Maximum concurrent requests per connection
    pub max_concurrent_per_connection: usize,
    /// Health check interval
    pub health_check_interval: Duration,
    /// Channel buffer size for all transports.
    ///
    /// Controls backpressure behavior. Smaller values (16-64) apply
    /// backpressure earlier, preventing memory buildup. Larger values
    /// (128-256) handle burst traffic better but use more memory.
    ///
    /// Default: 64
    pub channel_buffer_size: usize,
    /// Flow control behavior when an agent signals it cannot accept requests.
    ///
    /// Default: `FlowControlMode::FailClosed`
    pub flow_control_mode: FlowControlMode,
    /// Timeout for `FlowControlMode::WaitAndRetry` before failing.
    ///
    /// Only used when `flow_control_mode` is `WaitAndRetry`.
    /// Default: 100ms
    pub flow_control_wait_timeout: Duration,
    /// Timeout for sticky sessions before they expire.
    ///
    /// Sticky sessions are used for long-lived streaming connections
    /// (WebSocket, SSE, long-polling) to ensure the same agent connection
    /// is used for the entire stream lifetime.
    ///
    /// Set to None to disable automatic expiry (sessions only cleared explicitly).
    ///
    /// Default: 5 minutes
    pub sticky_session_timeout: Option<Duration>,
}

impl Default for AgentPoolConfig {
    fn default() -> Self {
        Self {
            connections_per_agent: 4,
            load_balance_strategy: LoadBalanceStrategy::RoundRobin,
            connect_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(30),
            reconnect_interval: Duration::from_secs(5),
            max_reconnect_attempts: 3,
            drain_timeout: Duration::from_secs(30),
            max_concurrent_per_connection: 100,
            health_check_interval: Duration::from_secs(10),
            channel_buffer_size: CHANNEL_BUFFER_SIZE,
            flow_control_mode: FlowControlMode::FailClosed,
            flow_control_wait_timeout: Duration::from_millis(100),
            sticky_session_timeout: Some(Duration::from_secs(5 * 60)), // 5 minutes
        }
    }
}

impl StickySession {
    fn new(agent_id: String, connection: Arc<PooledConnection>) -> Self {
        Self {
            connection,
            agent_id,
            created_at: Instant::now(),
            last_accessed: AtomicU64::new(0),
        }
    }

    fn touch(&self) {
        let offset = self.created_at.elapsed().as_millis() as u64;
        self.last_accessed.store(offset, Ordering::Relaxed);
    }

    fn last_accessed(&self) -> Instant {
        let offset_ms = self.last_accessed.load(Ordering::Relaxed);
        self.created_at + Duration::from_millis(offset_ms)
    }

    fn is_expired(&self, timeout: Duration) -> bool {
        self.last_accessed().elapsed() > timeout
    }
}

/// Transport layer for v2 agent connections.
///
/// Supports gRPC, Unix Domain Socket, and reverse connections.
pub enum V2Transport {
    /// gRPC over HTTP/2
    Grpc(AgentClientV2),
    /// Binary protocol over Unix Domain Socket
    Uds(AgentClientV2Uds),
    /// Reverse connection (agent connected to proxy)
    Reverse(ReverseConnectionClient),
}

impl V2Transport {
    /// Check if the transport is connected.
    pub async fn is_connected(&self) -> bool {
        match self {
            V2Transport::Grpc(client) => client.is_connected().await,
            V2Transport::Uds(client) => client.is_connected().await,
            V2Transport::Reverse(client) => client.is_connected().await,
        }
    }

    /// Check if the transport can accept new requests.
    ///
    /// Returns false if the agent has sent a flow control pause signal.
    pub async fn can_accept_requests(&self) -> bool {
        match self {
            V2Transport::Grpc(client) => client.can_accept_requests().await,
            V2Transport::Uds(client) => client.can_accept_requests().await,
            V2Transport::Reverse(client) => client.can_accept_requests().await,
        }
    }

    /// Get negotiated capabilities.
    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
        match self {
            V2Transport::Grpc(client) => client.capabilities().await,
            V2Transport::Uds(client) => client.capabilities().await,
            V2Transport::Reverse(client) => client.capabilities().await,
        }
    }

    /// Send a request headers event.
    pub async fn send_request_headers(
        &self,
        correlation_id: &str,
        event: &RequestHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_request_headers(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_request_headers(correlation_id, event).await,
            V2Transport::Reverse(client) => client.send_request_headers(correlation_id, event).await,
        }
    }

    /// Send a request body chunk event.
    pub async fn send_request_body_chunk(
        &self,
        correlation_id: &str,
        event: &RequestBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_request_body_chunk(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_request_body_chunk(correlation_id, event).await,
            V2Transport::Reverse(client) => client.send_request_body_chunk(correlation_id, event).await,
        }
    }

    /// Send a response headers event.
    pub async fn send_response_headers(
        &self,
        correlation_id: &str,
        event: &ResponseHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_response_headers(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_response_headers(correlation_id, event).await,
            V2Transport::Reverse(client) => client.send_response_headers(correlation_id, event).await,
        }
    }

    /// Send a response body chunk event.
    pub async fn send_response_body_chunk(
        &self,
        correlation_id: &str,
        event: &ResponseBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_response_body_chunk(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_response_body_chunk(correlation_id, event).await,
            V2Transport::Reverse(client) => client.send_response_body_chunk(correlation_id, event).await,
        }
    }

    /// Cancel a specific request.
    pub async fn cancel_request(
        &self,
        correlation_id: &str,
        reason: CancelReason,
    ) -> Result<(), AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.cancel_request(correlation_id, reason).await,
            V2Transport::Uds(client) => client.cancel_request(correlation_id, reason).await,
            V2Transport::Reverse(client) => client.cancel_request(correlation_id, reason).await,
        }
    }

    /// Cancel all in-flight requests.
    pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.cancel_all(reason).await,
            V2Transport::Uds(client) => client.cancel_all(reason).await,
            V2Transport::Reverse(client) => client.cancel_all(reason).await,
        }
    }

    /// Close the transport.
    pub async fn close(&self) -> Result<(), AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.close().await,
            V2Transport::Uds(client) => client.close().await,
            V2Transport::Reverse(client) => client.close().await,
        }
    }

    /// Get agent ID.
    pub fn agent_id(&self) -> &str {
        match self {
            V2Transport::Grpc(client) => client.agent_id(),
            V2Transport::Uds(client) => client.agent_id(),
            V2Transport::Reverse(client) => client.agent_id(),
        }
    }
}

/// A pooled connection to an agent.
struct PooledConnection {
    client: V2Transport,
    created_at: Instant,
    /// Milliseconds since created_at when last used (avoids RwLock in hot path)
    last_used_offset_ms: AtomicU64,
    in_flight: AtomicU64,
    request_count: AtomicU64,
    error_count: AtomicU64,
    consecutive_errors: AtomicU64,
    concurrency_limiter: Semaphore,
    /// Cached health state - updated by background maintenance, read in hot path
    healthy_cached: AtomicBool,
}

impl PooledConnection {
    fn new(client: V2Transport, max_concurrent: usize) -> Self {
        Self {
            client,
            created_at: Instant::now(),
            last_used_offset_ms: AtomicU64::new(0),
            in_flight: AtomicU64::new(0),
            request_count: AtomicU64::new(0),
            error_count: AtomicU64::new(0),
            consecutive_errors: AtomicU64::new(0),
            concurrency_limiter: Semaphore::new(max_concurrent),
            healthy_cached: AtomicBool::new(true), // Assume healthy until proven otherwise
        }
    }

    fn in_flight(&self) -> u64 {
        self.in_flight.load(Ordering::Relaxed)
    }

    fn error_rate(&self) -> f64 {
        let requests = self.request_count.load(Ordering::Relaxed);
        let errors = self.error_count.load(Ordering::Relaxed);
        if requests == 0 {
            0.0
        } else {
            errors as f64 / requests as f64
        }
    }

    /// Fast health check using cached state (no async, no I/O).
    /// Updated by background maintenance task.
    #[inline]
    fn is_healthy_cached(&self) -> bool {
        self.healthy_cached.load(Ordering::Acquire)
    }

    /// Full health check with I/O - only called by maintenance task.
    async fn check_and_update_health(&self) -> bool {
        let connected = self.client.is_connected().await;
        let low_errors = self.consecutive_errors.load(Ordering::Relaxed) < 3;
        let can_accept = self.client.can_accept_requests().await;

        let healthy = connected && low_errors && can_accept;
        self.healthy_cached.store(healthy, Ordering::Release);
        healthy
    }

    /// Record that this connection was just used.
    #[inline]
    fn touch(&self) {
        let offset = self.created_at.elapsed().as_millis() as u64;
        self.last_used_offset_ms.store(offset, Ordering::Relaxed);
    }

    /// Get the last used time.
    fn last_used(&self) -> Instant {
        let offset_ms = self.last_used_offset_ms.load(Ordering::Relaxed);
        self.created_at + Duration::from_millis(offset_ms)
    }
}

/// Statistics for a single agent in the pool.
#[derive(Debug, Clone)]
pub struct AgentPoolStats {
    /// Agent identifier
    pub agent_id: String,
    /// Number of active connections
    pub active_connections: usize,
    /// Number of healthy connections
    pub healthy_connections: usize,
    /// Total in-flight requests across all connections
    pub total_in_flight: u64,
    /// Total requests processed
    pub total_requests: u64,
    /// Total errors
    pub total_errors: u64,
    /// Average error rate
    pub error_rate: f64,
    /// Whether the agent is considered healthy
    pub is_healthy: bool,
}

/// An agent entry in the pool.
struct AgentEntry {
    agent_id: String,
    endpoint: String,
    /// Connections are rarely modified (only on reconnect), so RwLock is acceptable here.
    /// The hot-path reads use try_read() to avoid blocking.
    connections: RwLock<Vec<Arc<PooledConnection>>>,
    capabilities: RwLock<Option<AgentCapabilities>>,
    round_robin_index: AtomicUsize,
    reconnect_attempts: AtomicUsize,
    /// Stored as millis since UNIX_EPOCH to avoid RwLock
    last_reconnect_attempt_ms: AtomicU64,
    /// Cached aggregate health - true if any connection is healthy
    healthy: AtomicBool,
}

impl AgentEntry {
    fn new(agent_id: String, endpoint: String) -> Self {
        Self {
            agent_id,
            endpoint,
            connections: RwLock::new(Vec::new()),
            capabilities: RwLock::new(None),
            round_robin_index: AtomicUsize::new(0),
            reconnect_attempts: AtomicUsize::new(0),
            last_reconnect_attempt_ms: AtomicU64::new(0),
            healthy: AtomicBool::new(true),
        }
    }

    /// Check if enough time has passed since last reconnect attempt.
    fn should_reconnect(&self, interval: Duration) -> bool {
        let last_ms = self.last_reconnect_attempt_ms.load(Ordering::Relaxed);
        if last_ms == 0 {
            return true;
        }
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);
        now_ms.saturating_sub(last_ms) > interval.as_millis() as u64
    }

    /// Record that a reconnect attempt was made.
    fn mark_reconnect_attempt(&self) {
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);
        self.last_reconnect_attempt_ms.store(now_ms, Ordering::Relaxed);
    }
}

/// Agent connection pool.
///
/// Manages multiple connections to multiple agents with load balancing,
/// health tracking, automatic reconnection, and metrics collection.
///
/// # Performance
///
/// Uses `DashMap` for lock-free reads in the hot path. Agent lookup is O(1)
/// without contention. Connection selection uses cached health state to avoid
/// async I/O per request.
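///
/// # Example
///
/// A hedged end-to-end sketch (endpoint and event construction elided):
///
/// ```ignore
/// let pool = AgentPool::new();
/// pool.add_agent("waf-agent", "http://127.0.0.1:50051").await?;
/// let response = pool.send_request_headers("waf-agent", "corr-1", &event).await?;
/// pool.shutdown().await?;
/// ```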
pub struct AgentPool {
    config: AgentPoolConfig,
    /// Lock-free concurrent map for agent lookup.
    /// Reads (select_connection) are lock-free. Writes (add/remove agent) shard-lock.
    agents: DashMap<String, Arc<AgentEntry>>,
    total_requests: AtomicU64,
    total_errors: AtomicU64,
    /// Shared metrics collector for all agents
    metrics_collector: Arc<MetricsCollector>,
    /// Callback used to record metrics from clients
    metrics_callback: MetricsCallback,
    /// Config pusher for distributing config updates to agents
    config_pusher: Arc<ConfigPusher>,
    /// Handler for config update requests from agents
    config_update_handler: Arc<ConfigUpdateHandler>,
    /// Callback used to handle config updates from clients
    config_update_callback: ConfigUpdateCallback,
    /// Protocol-level metrics (proxy-side instrumentation)
    protocol_metrics: Arc<ProtocolMetrics>,
    /// Connection affinity: correlation_id → connection used for headers.
    /// Ensures body chunks go to the same connection as headers for streaming.
    correlation_affinity: DashMap<String, Arc<PooledConnection>>,
    /// Sticky sessions: session_id → session info for long-lived streams.
    /// Used for WebSocket, SSE, and long-polling connections.
    sticky_sessions: DashMap<String, StickySession>,
}

impl AgentPool {
    /// Create a new agent pool with default configuration.
    pub fn new() -> Self {
        Self::with_config(AgentPoolConfig::default())
    }

    /// Create a new agent pool with custom configuration.
    pub fn with_config(config: AgentPoolConfig) -> Self {
        let metrics_collector = Arc::new(MetricsCollector::new());
        let collector_clone = Arc::clone(&metrics_collector);

        // Create a callback that records metrics to the collector
        let metrics_callback: MetricsCallback = Arc::new(move |report| {
            collector_clone.record(&report);
        });

        // Create config pusher and handler
        let config_pusher = Arc::new(ConfigPusher::new());
        let config_update_handler = Arc::new(ConfigUpdateHandler::new());
        let handler_clone = Arc::clone(&config_update_handler);

        // Create a callback that handles config update requests from agents
        let config_update_callback: ConfigUpdateCallback = Arc::new(move |agent_id, request| {
            debug!(
                agent_id = %agent_id,
                request_id = %request.request_id,
                "Processing config update request from agent"
            );
            handler_clone.handle(request)
        });

        Self {
            config,
            agents: DashMap::new(),
            total_requests: AtomicU64::new(0),
            total_errors: AtomicU64::new(0),
            metrics_collector,
            metrics_callback,
            config_pusher,
            config_update_handler,
            config_update_callback,
            protocol_metrics: Arc::new(ProtocolMetrics::new()),
            correlation_affinity: DashMap::new(),
            sticky_sessions: DashMap::new(),
        }
    }

    /// Get the protocol metrics for accessing proxy-side instrumentation.
    pub fn protocol_metrics(&self) -> &ProtocolMetrics {
        &self.protocol_metrics
    }

    /// Get an Arc to the protocol metrics.
    pub fn protocol_metrics_arc(&self) -> Arc<ProtocolMetrics> {
        Arc::clone(&self.protocol_metrics)
    }

    /// Get the metrics collector for accessing aggregated agent metrics.
    pub fn metrics_collector(&self) -> &MetricsCollector {
        &self.metrics_collector
    }

    /// Get an Arc to the metrics collector.
    ///
    /// This is useful for registering the collector with a MetricsManager.
    pub fn metrics_collector_arc(&self) -> Arc<MetricsCollector> {
        Arc::clone(&self.metrics_collector)
    }

    /// Export all agent metrics in Prometheus format.
    pub fn export_prometheus(&self) -> String {
        self.metrics_collector.export_prometheus()
    }

    /// Clear connection affinity for a correlation ID.
    ///
    /// Call this when a request is complete (after receiving final response)
    /// to free the affinity mapping. Not strictly necessary (affinities are
    /// cheap), but good practice for long-running proxies.
    pub fn clear_correlation_affinity(&self, correlation_id: &str) {
        self.correlation_affinity.remove(correlation_id);
    }

    /// Get the number of active correlation affinities.
    ///
    /// This is useful for monitoring and debugging.
    pub fn correlation_affinity_count(&self) -> usize {
        self.correlation_affinity.len()
    }

    // =========================================================================
    // Sticky Sessions
    // =========================================================================

    /// Create a sticky session for a long-lived streaming connection.
    ///
    /// Sticky sessions ensure that all requests for a given session use the
    /// same agent connection. This is essential for:
    /// - WebSocket connections
    /// - Server-Sent Events (SSE)
    /// - Long-polling
    /// - Any streaming scenario requiring connection affinity
    ///
    /// # Arguments
    ///
    /// * `session_id` - A unique identifier for this session (e.g., WebSocket connection ID)
    /// * `agent_id` - The agent to bind this session to
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` if the session was created, or an error if the agent
    /// is not found or has no healthy connections.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // When a WebSocket is established
    /// pool.create_sticky_session("ws-12345", "waf-agent")?;
    ///
    /// // All subsequent messages use the same connection
    /// let (resp, _used_sticky) = pool
    ///     .send_request_headers_with_sticky_session("ws-12345", "waf-agent", corr_id, &event)
    ///     .await?;
    ///
    /// // When the WebSocket closes
    /// pool.clear_sticky_session("ws-12345");
    /// ```
    pub fn create_sticky_session(
        &self,
        session_id: impl Into<String>,
        agent_id: &str,
    ) -> Result<(), AgentProtocolError> {
        let session_id = session_id.into();
        let conn = self.select_connection(agent_id)?;

        let session = StickySession::new(agent_id.to_string(), conn);
        session.touch();

        self.sticky_sessions.insert(session_id.clone(), session);

        debug!(
            session_id = %session_id,
            agent_id = %agent_id,
            "Created sticky session"
        );

        Ok(())
    }

    /// Get the connection for a sticky session (internal use).
    ///
    /// Returns the connection bound to this session, or None if the session
    /// doesn't exist or has expired.
    fn get_sticky_session_conn(&self, session_id: &str) -> Option<Arc<PooledConnection>> {
        let entry = self.sticky_sessions.get(session_id)?;

        // Check expiration if configured
        if let Some(timeout) = self.config.sticky_session_timeout {
            if entry.is_expired(timeout) {
                drop(entry); // Release the reference before removing
                self.sticky_sessions.remove(session_id);
                debug!(session_id = %session_id, "Sticky session expired");
                return None;
            }
        }

        entry.touch();
        Some(Arc::clone(&entry.connection))
    }

    /// Refresh a sticky session, updating its last-accessed time.
    ///
    /// Returns true if the session exists and was refreshed, false otherwise.
    pub fn refresh_sticky_session(&self, session_id: &str) -> bool {
        self.get_sticky_session_conn(session_id).is_some()
    }

    /// Check if a sticky session exists and is valid.
    ///
    /// Note: this shares the lookup path with `refresh_sticky_session`, so it
    /// also updates the session's last-accessed time.
    pub fn has_sticky_session(&self, session_id: &str) -> bool {
        self.get_sticky_session_conn(session_id).is_some()
    }

    /// Clear a sticky session.
    ///
    /// Call this when a long-lived stream ends (WebSocket closed, SSE ended, etc.)
    pub fn clear_sticky_session(&self, session_id: &str) {
        if self.sticky_sessions.remove(session_id).is_some() {
            debug!(session_id = %session_id, "Cleared sticky session");
        }
    }

    /// Get the number of active sticky sessions.
    ///
    /// Useful for monitoring and debugging.
    pub fn sticky_session_count(&self) -> usize {
        self.sticky_sessions.len()
    }

    /// Get the agent ID bound to a sticky session.
    pub fn sticky_session_agent(&self, session_id: &str) -> Option<String> {
        self.sticky_sessions.get(session_id).map(|s| s.agent_id.clone())
    }

    /// Send a request using a sticky session.
    ///
    /// If the session exists and is valid, uses the bound connection.
    /// If the session doesn't exist or has expired, falls back to normal
    /// connection selection.
    ///
    /// # Returns
    ///
    /// A tuple of (response, used_sticky_session).
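    ///
    /// # Example
    ///
    /// A minimal sketch; `corr_id` and `event` are built elsewhere:
    ///
    /// ```ignore
    /// let (response, used_sticky) = pool
    ///     .send_request_headers_with_sticky_session("ws-12345", "waf-agent", corr_id, &event)
    ///     .await?;
    /// if !used_sticky {
    ///     // Session missing or expired; a fresh connection was selected.
    /// }
    /// ```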
    pub async fn send_request_headers_with_sticky_session(
        &self,
        session_id: &str,
        agent_id: &str,
        correlation_id: &str,
        event: &RequestHeadersEvent,
    ) -> Result<(AgentResponse, bool), AgentProtocolError> {
        let start = Instant::now();
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.protocol_metrics.inc_requests();
        self.protocol_metrics.inc_in_flight();

        // Try sticky session first
        let (conn, used_sticky) = if let Some(sticky_conn) = self.get_sticky_session_conn(session_id) {
            (sticky_conn, true)
        } else {
            (self.select_connection(agent_id)?, false)
        };

        // Check flow control
        match self.check_flow_control(&conn, agent_id).await {
            Ok(true) => {}
            Ok(false) => {
                self.protocol_metrics.dec_in_flight();
                return Ok((AgentResponse::default_allow(), used_sticky));
            }
            Err(e) => {
                self.protocol_metrics.dec_in_flight();
                return Err(e);
            }
        }

        // Acquire concurrency permit (acquire() only errors if the semaphore
        // has been closed, e.g. during shutdown)
        let _permit = conn
            .concurrency_limiter
            .acquire()
            .await
            .map_err(|_| {
                self.protocol_metrics.dec_in_flight();
                self.protocol_metrics.inc_connection_errors();
                AgentProtocolError::ConnectionFailed("Concurrency limiter closed".to_string())
            })?;

        conn.in_flight.fetch_add(1, Ordering::Relaxed);
        conn.touch();

        // Store correlation affinity
        self.correlation_affinity.insert(correlation_id.to_string(), Arc::clone(&conn));

        let result = conn.client.send_request_headers(correlation_id, event).await;

        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
        conn.request_count.fetch_add(1, Ordering::Relaxed);
        self.protocol_metrics.dec_in_flight();
        self.protocol_metrics.record_request_duration(start.elapsed());

        match &result {
            Ok(_) => {
                conn.consecutive_errors.store(0, Ordering::Relaxed);
                self.protocol_metrics.inc_responses();
            }
            Err(e) => {
                conn.error_count.fetch_add(1, Ordering::Relaxed);
                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
                self.total_errors.fetch_add(1, Ordering::Relaxed);

                match e {
                    AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
                    AgentProtocolError::ConnectionFailed(_) | AgentProtocolError::ConnectionClosed => {
                        self.protocol_metrics.inc_connection_errors();
                    }
                    AgentProtocolError::Serialization(_) => {
                        self.protocol_metrics.inc_serialization_errors();
                    }
                    _ => {}
                }

                if consecutive >= 3 {
                    conn.healthy_cached.store(false, Ordering::Release);
                }
            }
        }

        result.map(|r| (r, used_sticky))
    }

    /// Clean up expired sticky sessions.
    ///
    /// Called automatically by the maintenance task, but can also be called
    /// manually to immediately reclaim resources.
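    ///
    /// # Example
    ///
    /// A minimal sketch of a manual sweep:
    ///
    /// ```ignore
    /// let removed = pool.cleanup_expired_sessions();
    /// if removed > 0 {
    ///     tracing::debug!(removed, "reclaimed expired sticky sessions");
    /// }
    /// ```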
    pub fn cleanup_expired_sessions(&self) -> usize {
        let Some(timeout) = self.config.sticky_session_timeout else {
            return 0;
        };

        let mut removed = 0;
        self.sticky_sessions.retain(|session_id, session| {
            if session.is_expired(timeout) {
                debug!(session_id = %session_id, "Removing expired sticky session");
                removed += 1;
                false
            } else {
                true
            }
        });

        if removed > 0 {
            trace!(removed = removed, "Cleaned up expired sticky sessions");
        }

        removed
    }

    /// Get the config pusher for pushing configuration updates to agents.
    pub fn config_pusher(&self) -> &ConfigPusher {
        &self.config_pusher
    }

    /// Get the config update handler for processing agent config requests.
    pub fn config_update_handler(&self) -> &ConfigUpdateHandler {
        &self.config_update_handler
    }

    /// Push a configuration update to a specific agent.
    ///
    /// Returns the push ID if the agent supports config push, None otherwise.
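    ///
    /// # Example
    ///
    /// A sketch, assuming `update` is a `ConfigUpdateType` built elsewhere:
    ///
    /// ```ignore
    /// if let Some(push_id) = pool.push_config_to_agent("waf-agent", update) {
    ///     // Later, once the agent confirms it applied the update:
    ///     pool.acknowledge_config_push(&push_id, true, None);
    /// }
    /// ```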
    pub fn push_config_to_agent(&self, agent_id: &str, update_type: ConfigUpdateType) -> Option<String> {
        self.config_pusher.push_to_agent(agent_id, update_type)
    }

    /// Push a configuration update to all agents that support config push.
    ///
    /// Returns the push IDs for each agent that received the update.
    pub fn push_config_to_all(&self, update_type: ConfigUpdateType) -> Vec<String> {
        self.config_pusher.push_to_all(update_type)
    }

    /// Acknowledge a config push by its push ID.
    pub fn acknowledge_config_push(&self, push_id: &str, accepted: bool, error: Option<String>) {
        self.config_pusher.acknowledge(push_id, accepted, error);
    }

    /// Add an agent to the pool.
    ///
    /// This creates the configured number of connections to the agent.
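    ///
    /// # Example
    ///
    /// A hedged sketch, assuming a gRPC agent listening at the given address:
    ///
    /// ```ignore
    /// let pool = AgentPool::new();
    /// pool.add_agent("waf-agent", "http://127.0.0.1:50051").await?;
    /// ```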
    pub async fn add_agent(
        &self,
        agent_id: impl Into<String>,
        endpoint: impl Into<String>,
    ) -> Result<(), AgentProtocolError> {
        let agent_id = agent_id.into();
        let endpoint = endpoint.into();

        info!(agent_id = %agent_id, endpoint = %endpoint, "Adding agent to pool");

        let entry = Arc::new(AgentEntry::new(agent_id.clone(), endpoint.clone()));

        // Create initial connections
        let mut connections = Vec::with_capacity(self.config.connections_per_agent);
        for i in 0..self.config.connections_per_agent {
            match self.create_connection(&agent_id, &endpoint).await {
                Ok(conn) => {
                    connections.push(Arc::new(conn));
                    debug!(
                        agent_id = %agent_id,
                        connection = i,
                        "Created connection"
                    );
                }
                Err(e) => {
                    warn!(
                        agent_id = %agent_id,
                        connection = i,
                        error = %e,
                        "Failed to create connection"
                    );
                    // Continue - we'll try to reconnect later
                }
            }
        }

        if connections.is_empty() {
            return Err(AgentProtocolError::ConnectionFailed(format!(
                "Failed to create any connections to agent {}",
                agent_id
            )));
        }

        // Store capabilities from first successful connection and register with ConfigPusher
        if let Some(conn) = connections.first() {
            if let Some(caps) = conn.client.capabilities().await {
                // Register with ConfigPusher based on capabilities
                let supports_config_push = caps.features.config_push;
                let agent_name = caps.name.clone();
                self.config_pusher.register_agent(
                    &agent_id,
                    &agent_name,
                    supports_config_push,
                );
                debug!(
                    agent_id = %agent_id,
                    supports_config_push = supports_config_push,
                    "Registered agent with ConfigPusher"
                );

                *entry.capabilities.write().await = Some(caps);
            }
        }

        // Log the number of connections actually established, which may be
        // fewer than configured if some attempts failed above.
        let established = connections.len();
        *entry.connections.write().await = connections;
        self.agents.insert(agent_id.clone(), entry);

        info!(
            agent_id = %agent_id,
            connections = established,
            "Agent added to pool"
        );

        Ok(())
    }

    /// Remove an agent from the pool.
    ///
    /// This gracefully closes all connections to the agent.
    pub async fn remove_agent(&self, agent_id: &str) -> Result<(), AgentProtocolError> {
        info!(agent_id = %agent_id, "Removing agent from pool");

        // Unregister from ConfigPusher
        self.config_pusher.unregister_agent(agent_id);

        let (_, entry) = self
            .agents
            .remove(agent_id)
            .ok_or_else(|| AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id)))?;

        // Close all connections
        let connections = entry.connections.read().await;
        for conn in connections.iter() {
            let _ = conn.client.close().await;
        }

        info!(agent_id = %agent_id, "Agent removed from pool");
        Ok(())
    }

    /// Add a reverse connection to the pool.
    ///
    /// This is called by the ReverseConnectionListener when an agent connects.
    /// The connection is wrapped in a V2Transport and added to the agent's
    /// connection pool.
    pub async fn add_reverse_connection(
        &self,
        agent_id: &str,
        client: ReverseConnectionClient,
        capabilities: AgentCapabilities,
    ) -> Result<(), AgentProtocolError> {
        info!(
            agent_id = %agent_id,
            connection_id = %client.connection_id(),
            "Adding reverse connection to pool"
        );

        let transport = V2Transport::Reverse(client);
        let conn = Arc::new(PooledConnection::new(
            transport,
            self.config.max_concurrent_per_connection,
        ));

        // Check if agent already exists. Note: this get-then-insert is not
        // atomic; a concurrent registration for the same agent can race and
        // overwrite the entry created below.
        if let Some(entry) = self.agents.get(agent_id) {
            // Add to existing agent's connections
            let mut connections = entry.connections.write().await;

            // Check connection limit
            if connections.len() >= self.config.connections_per_agent {
                warn!(
                    agent_id = %agent_id,
                    current = connections.len(),
                    max = self.config.connections_per_agent,
                    "Reverse connection rejected: at connection limit"
                );
                return Err(AgentProtocolError::ConnectionFailed(format!(
                    "Agent {} already has maximum connections ({})",
                    agent_id, self.config.connections_per_agent
                )));
            }

            connections.push(conn);
            info!(
                agent_id = %agent_id,
                total_connections = connections.len(),
                "Added reverse connection to existing agent"
            );
        } else {
            // Create new agent entry
            let entry = Arc::new(AgentEntry::new(
                agent_id.to_string(),
                format!("reverse://{}", agent_id),
            ));

            // Register with ConfigPusher
            let supports_config_push = capabilities.features.config_push;
            let agent_name = capabilities.name.clone();
            self.config_pusher.register_agent(
                agent_id,
                &agent_name,
                supports_config_push,
            );
            debug!(
                agent_id = %agent_id,
                supports_config_push = supports_config_push,
                "Registered reverse connection agent with ConfigPusher"
            );

            *entry.capabilities.write().await = Some(capabilities);
            *entry.connections.write().await = vec![conn];
            self.agents.insert(agent_id.to_string(), entry);

            info!(
                agent_id = %agent_id,
                "Created new agent entry for reverse connection"
            );
        }

        Ok(())
    }

    /// Check flow control and handle according to configured mode.
    ///
    /// Returns `Ok(true)` if request should proceed normally.
    /// Returns `Ok(false)` if request should skip agent (FailOpen mode).
    /// Returns `Err` if request should fail (FailClosed or WaitAndRetry timeout).
    async fn check_flow_control(
        &self,
        conn: &PooledConnection,
        agent_id: &str,
    ) -> Result<bool, AgentProtocolError> {
        if conn.client.can_accept_requests().await {
            return Ok(true);
        }

        match self.config.flow_control_mode {
            FlowControlMode::FailClosed => {
                self.protocol_metrics.record_flow_rejection();
                Err(AgentProtocolError::FlowControlPaused {
                    agent_id: agent_id.to_string(),
                })
            }
            FlowControlMode::FailOpen => {
                // Log but allow through
                debug!(agent_id = %agent_id, "Flow control: agent paused, allowing request (fail-open mode)");
                self.protocol_metrics.record_flow_rejection();
                Ok(false) // Signal to skip agent processing
            }
            FlowControlMode::WaitAndRetry => {
                // Wait briefly for agent to resume
                let deadline = Instant::now() + self.config.flow_control_wait_timeout;
                while Instant::now() < deadline {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    if conn.client.can_accept_requests().await {
                        trace!(agent_id = %agent_id, "Flow control: agent resumed after wait");
                        return Ok(true);
                    }
                }
                // Timeout - fail the request
                self.protocol_metrics.record_flow_rejection();
                Err(AgentProtocolError::FlowControlPaused {
                    agent_id: agent_id.to_string(),
                })
            }
        }
    }

    /// Send a request headers event to an agent.
    ///
    /// The pool selects the best connection based on the load balancing strategy.
    ///
    /// # Performance
    ///
    /// This is the hot path. Uses:
    /// - Lock-free agent lookup via `DashMap`
    /// - Cached health state (no async I/O for health check)
    /// - Atomic last_used tracking (no RwLock)
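    ///
    /// # Example
    ///
    /// A minimal sketch; `event` construction is elided:
    ///
    /// ```ignore
    /// let response = pool
    ///     .send_request_headers("waf-agent", &correlation_id, &event)
    ///     .await?;
    /// // Once the full exchange completes, release the affinity mapping:
    /// pool.clear_correlation_affinity(&correlation_id);
    /// ```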
    pub async fn send_request_headers(
        &self,
        agent_id: &str,
        correlation_id: &str,
        event: &RequestHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        let start = Instant::now();
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.protocol_metrics.inc_requests();
        self.protocol_metrics.inc_in_flight();

        let conn = self.select_connection(agent_id)?;

        // Check flow control before sending (respects flow_control_mode config)
        match self.check_flow_control(&conn, agent_id).await {
            Ok(true) => {} // Proceed normally
            Ok(false) => {
                // FailOpen mode: skip agent, return allow response
                self.protocol_metrics.dec_in_flight();
                return Ok(AgentResponse::default_allow());
            }
            Err(e) => {
                self.protocol_metrics.dec_in_flight();
                return Err(e);
            }
        }

        // Acquire concurrency permit (acquire() only errors if the semaphore
        // has been closed, e.g. during shutdown)
        let _permit = conn
            .concurrency_limiter
            .acquire()
            .await
            .map_err(|_| {
                self.protocol_metrics.dec_in_flight();
                self.protocol_metrics.inc_connection_errors();
                AgentProtocolError::ConnectionFailed("Concurrency limiter closed".to_string())
            })?;

        conn.in_flight.fetch_add(1, Ordering::Relaxed);
        conn.touch(); // Atomic, no lock

        // Store connection affinity for body chunk routing
        self.correlation_affinity.insert(correlation_id.to_string(), Arc::clone(&conn));

        let result = conn.client.send_request_headers(correlation_id, event).await;

        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
        conn.request_count.fetch_add(1, Ordering::Relaxed);
        self.protocol_metrics.dec_in_flight();
        self.protocol_metrics.record_request_duration(start.elapsed());

        match &result {
            Ok(_) => {
                conn.consecutive_errors.store(0, Ordering::Relaxed);
                self.protocol_metrics.inc_responses();
            }
            Err(e) => {
                conn.error_count.fetch_add(1, Ordering::Relaxed);
                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
                self.total_errors.fetch_add(1, Ordering::Relaxed);

                // Record error type
                match e {
                    AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
                    AgentProtocolError::ConnectionFailed(_) | AgentProtocolError::ConnectionClosed => {
                        self.protocol_metrics.inc_connection_errors();
                    }
                    AgentProtocolError::Serialization(_) => {
                        self.protocol_metrics.inc_serialization_errors();
                    }
                    _ => {}
                }

                // Mark unhealthy immediately on repeated failures (fast feedback)
                if consecutive >= 3 {
                    conn.healthy_cached.store(false, Ordering::Release);
                    trace!(agent_id = %agent_id, error = %e, "Connection marked unhealthy after consecutive errors");
                }
            }
        }

        result
    }

    /// Send a request body chunk to an agent.
    ///
    /// The pool uses correlation_id to route the chunk to the same connection
    /// that received the request headers (connection affinity).
    pub async fn send_request_body_chunk(
        &self,
        agent_id: &str,
        correlation_id: &str,
        event: &RequestBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        // Try to use affinity (same connection as headers), fall back to selection
        let conn = if let Some(affinity_conn) = self.correlation_affinity.get(correlation_id) {
            Arc::clone(&affinity_conn)
        } else {
            // No affinity found, use normal selection (shouldn't happen in normal flow)
            trace!(correlation_id = %correlation_id, "No affinity found for body chunk, using selection");
            self.select_connection(agent_id)?
        };

        // Check flow control before sending body chunks (critical for backpressure)
        match self.check_flow_control(&conn, agent_id).await {
            Ok(true) => {} // Proceed normally
            Ok(false) => {
                // FailOpen mode: skip agent, return allow response
                return Ok(AgentResponse::default_allow());
            }
            Err(e) => return Err(e),
        }

        let _permit = conn
            .concurrency_limiter
            .acquire()
            .await
            .map_err(|_| AgentProtocolError::ConnectionFailed("Concurrency limiter closed".to_string()))?;

        conn.in_flight.fetch_add(1, Ordering::Relaxed);
        conn.touch();

        let result = conn.client.send_request_body_chunk(correlation_id, event).await;

        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
        conn.request_count.fetch_add(1, Ordering::Relaxed);

        match &result {
            Ok(_) => {
                conn.consecutive_errors.store(0, Ordering::Relaxed);
            }
            Err(_) => {
                conn.error_count.fetch_add(1, Ordering::Relaxed);
                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
                self.total_errors.fetch_add(1, Ordering::Relaxed);
                if consecutive >= 3 {
                    conn.healthy_cached.store(false, Ordering::Release);
                }
            }
        }

        result
    }

    /// Send response headers to an agent.
    ///
    /// Called when upstream response headers are received, allowing the agent
    /// to inspect/modify response headers before they're sent to the client.
    pub async fn send_response_headers(
        &self,
        agent_id: &str,
        correlation_id: &str,
        event: &ResponseHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        let conn = self.select_connection(agent_id)?;

        let _permit = conn
            .concurrency_limiter
            .acquire()
            .await
            .map_err(|_| AgentProtocolError::ConnectionFailed("Concurrency limiter closed".to_string()))?;

        conn.in_flight.fetch_add(1, Ordering::Relaxed);
        conn.touch();

        let result = conn.client.send_response_headers(correlation_id, event).await;

        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
        conn.request_count.fetch_add(1, Ordering::Relaxed);

        match &result {
            Ok(_) => {
                conn.consecutive_errors.store(0, Ordering::Relaxed);
            }
            Err(_) => {
                conn.error_count.fetch_add(1, Ordering::Relaxed);
                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
                self.total_errors.fetch_add(1, Ordering::Relaxed);
                if consecutive >= 3 {
                    conn.healthy_cached.store(false, Ordering::Release);
                }
            }
        }

        result
    }

    /// Send a response body chunk to an agent.
    ///
    /// For streaming response body inspection, chunks are sent sequentially.
    /// The agent can inspect and optionally modify response body data.
    pub async fn send_response_body_chunk(
        &self,
        agent_id: &str,
        correlation_id: &str,
        event: &ResponseBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        let conn = self.select_connection(agent_id)?;

        // Check flow control before sending body chunks (critical for backpressure)
        match self.check_flow_control(&conn, agent_id).await {
            Ok(true) => {} // Proceed normally
            Ok(false) => {
                // FailOpen mode: skip agent, return allow response
                return Ok(AgentResponse::default_allow());
            }
            Err(e) => return Err(e),
        }

        let _permit = conn
            .concurrency_limiter
            .acquire()
            .await
            .map_err(|_| AgentProtocolError::ConnectionFailed("Concurrency limiter closed".to_string()))?;
1340
1341        conn.in_flight.fetch_add(1, Ordering::Relaxed);
1342        conn.touch();
1343
1344        let result = conn.client.send_response_body_chunk(correlation_id, event).await;
1345
1346        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1347        conn.request_count.fetch_add(1, Ordering::Relaxed);
1348
1349        match &result {
1350            Ok(_) => {
1351                conn.consecutive_errors.store(0, Ordering::Relaxed);
1352            }
1353            Err(_) => {
1354                conn.error_count.fetch_add(1, Ordering::Relaxed);
1355                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1356                self.total_errors.fetch_add(1, Ordering::Relaxed);
1357                if consecutive >= 3 {
1358                    conn.healthy_cached.store(false, Ordering::Release);
1359                }
1360            }
1361        }
1362
1363        result
1364    }
1365
1366    /// Cancel a request on all connections for an agent.
1367    pub async fn cancel_request(
1368        &self,
1369        agent_id: &str,
1370        correlation_id: &str,
1371        reason: CancelReason,
1372    ) -> Result<(), AgentProtocolError> {
1373        let entry = self
1374            .agents
1375            .get(agent_id)
1376            .ok_or_else(|| AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id)))?;
1377
1378        let connections = entry.connections.read().await;
1379        for conn in connections.iter() {
1380            let _ = conn.client.cancel_request(correlation_id, reason).await;
1381        }
1382
1383        Ok(())
1384    }
1385
1386    /// Get statistics for all agents in the pool.
    pub async fn stats(&self) -> Vec<AgentPoolStats> {
        let mut stats = Vec::with_capacity(self.agents.len());

        // Iterate over a snapshot of keys so no DashMap shard lock is held
        // across the `.await` points below (same pattern as `run_maintenance`).
        let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();

        for agent_id in agent_ids {
            let Some(entry_ref) = self.agents.get(&agent_id) else {
                continue; // Agent was removed concurrently
            };
            let entry = Arc::clone(entry_ref.value());
            drop(entry_ref);

            let connections = entry.connections.read().await;
            let mut healthy_count = 0;
            let mut total_in_flight = 0;
            let mut total_requests = 0;
            let mut total_errors = 0;

            for conn in connections.iter() {
                // Use cached health for stats (consistent with hot path)
                if conn.is_healthy_cached() {
                    healthy_count += 1;
                }
                total_in_flight += conn.in_flight();
                total_requests += conn.request_count.load(Ordering::Relaxed);
                total_errors += conn.error_count.load(Ordering::Relaxed);
            }

            let error_rate = if total_requests == 0 {
                0.0
            } else {
                total_errors as f64 / total_requests as f64
            };

            stats.push(AgentPoolStats {
                agent_id,
                active_connections: connections.len(),
                healthy_connections: healthy_count,
                total_in_flight,
                total_requests,
                total_errors,
                error_rate,
                is_healthy: entry.healthy.load(Ordering::Acquire),
            });
        }

        stats
    }

    /// Get statistics for a specific agent.
    pub async fn agent_stats(&self, agent_id: &str) -> Option<AgentPoolStats> {
        self.stats()
            .await
            .into_iter()
            .find(|s| s.agent_id == agent_id)
    }

    /// Get the capabilities of an agent.
    pub async fn agent_capabilities(&self, agent_id: &str) -> Option<AgentCapabilities> {
        // Clone the Arc out of the DashMap Ref so the shard lock is released
        // before the `.await` below
        let entry = match self.agents.get(agent_id) {
            Some(entry_ref) => Arc::clone(&*entry_ref),
            None => return None,
        };
        // The read-guard temporary drops at the end of this expression
        entry.capabilities.read().await.clone()
    }

    /// Check if an agent is healthy.
    ///
    /// Uses cached health state for fast, lock-free access.
    pub fn is_agent_healthy(&self, agent_id: &str) -> bool {
        self.agents
            .get(agent_id)
            .map(|e| e.healthy.load(Ordering::Acquire))
            .unwrap_or(false)
    }

    /// Get all agent IDs in the pool.
    pub fn agent_ids(&self) -> Vec<String> {
        self.agents.iter().map(|e| e.key().clone()).collect()
    }

    /// Gracefully shut down the pool.
    ///
    /// This drains all connections and waits for in-flight requests to complete.
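    ///
    /// A shutdown sketch (the drain bound comes from `AgentPoolConfig::drain_timeout`):
    ///
    /// ```ignore
    /// // Stop routing new requests first, then drain and close agent connections.
    /// pool.shutdown().await?;
    /// ```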
    pub async fn shutdown(&self) -> Result<(), AgentProtocolError> {
        info!("Shutting down agent pool");

        // Collect all agents (DashMap has no drain, so remove them one by one)
        let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();

        for agent_id in agent_ids {
            if let Some((_, entry)) = self.agents.remove(&agent_id) {
                debug!(agent_id = %agent_id, "Draining agent connections");

                let connections = entry.connections.read().await;
                for conn in connections.iter() {
                    // Cancel all pending requests
                    let _ = conn.client.cancel_all(CancelReason::ProxyShutdown).await;
                }

                // Wait for in-flight requests to complete
                let drain_deadline = Instant::now() + self.config.drain_timeout;
                loop {
                    let total_in_flight: u64 = connections.iter().map(|c| c.in_flight()).sum();
                    if total_in_flight == 0 {
                        break;
                    }
                    if Instant::now() > drain_deadline {
                        warn!(
                            agent_id = %agent_id,
                            in_flight = total_in_flight,
                            "Drain timeout, forcing close"
                        );
                        break;
                    }
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }

                // Close all connections
                for conn in connections.iter() {
                    let _ = conn.client.close().await;
                }
            }
        }

        info!("Agent pool shutdown complete");
        Ok(())
    }

    /// Run background maintenance tasks.
    ///
    /// This should be spawned as a background task. It handles:
    /// - Health checking (updates cached health state)
    /// - Reconnection of failed connections
    /// - Cleanup of expired sticky sessions
    ///
    /// # Health Check Strategy
    ///
    /// Health is checked here (with I/O) and cached in `PooledConnection::healthy_cached`.
    /// The hot path (`select_connection`) reads the cached value without I/O.
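    ///
    /// A wiring sketch (assuming the pool is shared via `Arc`):
    ///
    /// ```ignore
    /// let pool = Arc::new(AgentPool::with_config(config));
    /// let maintenance = Arc::clone(&pool);
    /// tokio::spawn(async move { maintenance.run_maintenance().await });
    /// ```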
    pub async fn run_maintenance(&self) {
        let mut interval = tokio::time::interval(self.config.health_check_interval);

        loop {
            interval.tick().await;

            // Clean up expired sticky sessions
            self.cleanup_expired_sessions();

            // Iterate without holding a long-lived lock
            let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();

            for agent_id in agent_ids {
                let Some(entry_ref) = self.agents.get(&agent_id) else {
                    continue; // Agent was removed
                };
                let entry = entry_ref.value().clone();
                drop(entry_ref); // Release DashMap ref before async work

                // Check connection health (this does I/O)
                let connections = entry.connections.read().await;
                let mut healthy_count = 0;

                for conn in connections.iter() {
                    // Full health check with I/O, updates cached state
                    if conn.check_and_update_health().await {
                        healthy_count += 1;
                    }
                }

                // Update aggregate agent health status
                let was_healthy = entry.healthy.load(Ordering::Acquire);
                let is_healthy = healthy_count > 0;
                entry.healthy.store(is_healthy, Ordering::Release);

                if was_healthy && !is_healthy {
                    warn!(agent_id = %agent_id, "Agent marked unhealthy");
                } else if !was_healthy && is_healthy {
                    info!(agent_id = %agent_id, "Agent recovered");
                }

                // Try to reconnect failed connections
                if healthy_count < self.config.connections_per_agent
                    && entry.should_reconnect(self.config.reconnect_interval)
                {
                    drop(connections); // Release read lock before reconnect
                    if let Err(e) = self.reconnect_agent(&agent_id, &entry).await {
                        trace!(agent_id = %agent_id, error = %e, "Reconnect failed");
                    }
                }
            }
        }
    }

    // =========================================================================
    // Internal Methods
    // =========================================================================

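    /// Create a new pooled connection, choosing the transport from the shape
    /// of the endpoint: `unix:`-prefixed, absolute-path, or `.sock` endpoints
    /// use UDS (see [`is_uds_endpoint`]); anything else defaults to gRPC.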
    async fn create_connection(
        &self,
        agent_id: &str,
        endpoint: &str,
    ) -> Result<PooledConnection, AgentProtocolError> {
        // Detect transport type from endpoint
        let transport = if is_uds_endpoint(endpoint) {
            // Unix Domain Socket transport
            let socket_path = endpoint
                .strip_prefix("unix:")
                .unwrap_or(endpoint);

            let mut client =
                AgentClientV2Uds::new(agent_id, socket_path, self.config.request_timeout).await?;

            // Set callbacks before connecting
            client.set_metrics_callback(Arc::clone(&self.metrics_callback));
            client.set_config_update_callback(Arc::clone(&self.config_update_callback));

            client.connect().await?;
            V2Transport::Uds(client)
        } else {
            // gRPC transport (default)
            let mut client =
                AgentClientV2::new(agent_id, endpoint, self.config.request_timeout).await?;

            // Set callbacks before connecting
            client.set_metrics_callback(Arc::clone(&self.metrics_callback));
            client.set_config_update_callback(Arc::clone(&self.config_update_callback));

            client.connect().await?;
            V2Transport::Grpc(client)
        };

        Ok(PooledConnection::new(
            transport,
            self.config.max_concurrent_per_connection,
        ))
    }

    /// Select a connection for a request.
    ///
    /// # Performance
    ///
    /// This is the hot path. Optimizations:
    /// - Lock-free agent lookup via `DashMap::get()`
    /// - Uses `try_read()` first, falling back to a blocking read only under
    ///   contention (rare, since writers hold the lock briefly)
    /// - Cached health state (no async I/O)
    /// - No `.await` points, so it can be called from synchronous code
    ///
    /// # Errors
    ///
    /// Returns an error if the agent is not found, it has no connections, or
    /// none of its connections are healthy.
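    ///
    /// The strategy is taken from `AgentPoolConfig::load_balance_strategy`; a
    /// configuration sketch:
    ///
    /// ```ignore
    /// let config = AgentPoolConfig {
    ///     load_balance_strategy: LoadBalanceStrategy::LeastConnections,
    ///     ..Default::default()
    /// };
    /// ```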
    fn select_connection(
        &self,
        agent_id: &str,
    ) -> Result<Arc<PooledConnection>, AgentProtocolError> {
        // Clone the Arc out so the DashMap shard lock is released immediately
        let entry = self
            .agents
            .get(agent_id)
            .map(|e| Arc::clone(&*e))
            .ok_or_else(|| AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id)))?;

        // Try non-blocking read first; fall back to blocking if contended
        let connections_guard = match entry.connections.try_read() {
            Ok(guard) => guard,
            Err(_) => {
                // Blocking fallback. This parks the current thread, which is
                // tolerable only because writers (reconnects) hold the lock
                // briefly; it should be rare in practice.
                trace!(agent_id = %agent_id, "select_connection: blocking on connections lock");
                futures::executor::block_on(entry.connections.read())
            }
        };

        if connections_guard.is_empty() {
            return Err(AgentProtocolError::ConnectionFailed(format!(
                "No connections available for agent {}",
                agent_id
            )));
        }

        // Filter to healthy connections using cached health (no I/O)
        let healthy: Vec<_> = connections_guard
            .iter()
            .filter(|c| c.is_healthy_cached())
            .cloned()
            .collect();

        if healthy.is_empty() {
            return Err(AgentProtocolError::ConnectionFailed(format!(
                "No healthy connections for agent {}",
                agent_id
            )));
        }

        let selected = match self.config.load_balance_strategy {
            LoadBalanceStrategy::RoundRobin => {
                let idx = entry.round_robin_index.fetch_add(1, Ordering::Relaxed);
                healthy[idx % healthy.len()].clone()
            }
            LoadBalanceStrategy::LeastConnections => {
                // `healthy` is non-empty (checked above), so `min_by_key` is Some
                healthy
                    .iter()
                    .min_by_key(|c| c.in_flight())
                    .cloned()
                    .unwrap()
            }
            LoadBalanceStrategy::HealthBased => {
                // Prefer connections with lower error rates
                healthy
                    .iter()
                    .min_by(|a, b| {
                        a.error_rate()
                            .partial_cmp(&b.error_rate())
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .cloned()
                    .unwrap()
            }
            LoadBalanceStrategy::Random => {
                // `RandomState` seeds each instance randomly, so hashing nothing
                // still yields an unpredictable value; this avoids a `rand`
                // dependency for a non-cryptographic pick.
                use std::collections::hash_map::RandomState;
                use std::hash::{BuildHasher, Hasher};
                let idx = RandomState::new().build_hasher().finish() as usize % healthy.len();
                healthy[idx].clone()
            }
        };

        Ok(selected)
    }

    async fn reconnect_agent(
        &self,
        agent_id: &str,
        entry: &AgentEntry,
    ) -> Result<(), AgentProtocolError> {
        entry.mark_reconnect_attempt();
        let attempts = entry.reconnect_attempts.fetch_add(1, Ordering::Relaxed);

        if attempts >= self.config.max_reconnect_attempts {
            debug!(
                agent_id = %agent_id,
                attempts = attempts,
                "Max reconnect attempts reached"
            );
            return Ok(());
        }

        debug!(agent_id = %agent_id, attempt = attempts + 1, "Attempting reconnect");

        match self.create_connection(agent_id, &entry.endpoint).await {
            Ok(conn) => {
                let mut connections = entry.connections.write().await;
                connections.push(Arc::new(conn));
                entry.reconnect_attempts.store(0, Ordering::Relaxed);
                info!(agent_id = %agent_id, "Reconnected successfully");
                Ok(())
            }
            Err(e) => {
                debug!(agent_id = %agent_id, error = %e, "Reconnect failed");
                Err(e)
            }
        }
    }
}

impl Default for AgentPool {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for AgentPool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AgentPool")
            .field("config", &self.config)
            .field("total_requests", &self.total_requests.load(Ordering::Relaxed))
            .field("total_errors", &self.total_errors.load(Ordering::Relaxed))
            .finish()
    }
}

/// Check if an endpoint is a Unix Domain Socket path.
///
/// Returns true for endpoints that:
/// - Start with the "unix:" prefix
/// - Are absolute paths (start with "/")
/// - Have a ".sock" extension
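///
/// A quick illustration (mirroring the unit tests below):
///
/// ```ignore
/// assert!(is_uds_endpoint("unix:/var/run/agent.sock"));
/// assert!(is_uds_endpoint("/tmp/test.sock"));
/// assert!(!is_uds_endpoint("http://localhost:8080"));
/// ```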
fn is_uds_endpoint(endpoint: &str) -> bool {
    endpoint.starts_with("unix:")
        || endpoint.starts_with('/')
        || endpoint.ends_with(".sock")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_default() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.connections_per_agent, 4);
        assert_eq!(config.load_balance_strategy, LoadBalanceStrategy::RoundRobin);
    }

    #[test]
    fn test_load_balance_strategy() {
        assert_eq!(LoadBalanceStrategy::default(), LoadBalanceStrategy::RoundRobin);
    }

    #[test]
    fn test_pool_creation() {
        let pool = AgentPool::new();
        assert_eq!(pool.total_requests.load(Ordering::Relaxed), 0);
        assert_eq!(pool.total_errors.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_pool_with_config() {
        let config = AgentPoolConfig {
            connections_per_agent: 8,
            load_balance_strategy: LoadBalanceStrategy::LeastConnections,
            ..Default::default()
        };
        let pool = AgentPool::with_config(config);
        assert_eq!(pool.config.connections_per_agent, 8);
    }

    #[test]
    fn test_agent_ids_empty() {
        let pool = AgentPool::new();
        assert!(pool.agent_ids().is_empty());
    }

    #[test]
    fn test_is_agent_healthy_not_found() {
        let pool = AgentPool::new();
        assert!(!pool.is_agent_healthy("nonexistent"));
    }

    #[tokio::test]
    async fn test_stats_empty() {
        let pool = AgentPool::new();
        assert!(pool.stats().await.is_empty());
    }

    #[test]
    fn test_is_uds_endpoint() {
        // Unix prefix
        assert!(is_uds_endpoint("unix:/var/run/agent.sock"));
        assert!(is_uds_endpoint("unix:agent.sock"));

        // Absolute path
        assert!(is_uds_endpoint("/var/run/agent.sock"));
        assert!(is_uds_endpoint("/tmp/test.sock"));

        // .sock extension
        assert!(is_uds_endpoint("agent.sock"));

        // Not UDS
        assert!(!is_uds_endpoint("http://localhost:8080"));
        assert!(!is_uds_endpoint("localhost:50051"));
        assert!(!is_uds_endpoint("127.0.0.1:8080"));
    }

    #[test]
    fn test_flow_control_mode_default() {
        assert_eq!(FlowControlMode::default(), FlowControlMode::FailClosed);
    }

    #[test]
    fn test_pool_config_flow_control_defaults() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.channel_buffer_size, CHANNEL_BUFFER_SIZE);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailClosed);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(100));
    }

    #[test]
    fn test_pool_config_custom_flow_control() {
        let config = AgentPoolConfig {
            channel_buffer_size: 128,
            flow_control_mode: FlowControlMode::FailOpen,
            flow_control_wait_timeout: Duration::from_millis(500),
            ..Default::default()
        };
        assert_eq!(config.channel_buffer_size, 128);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailOpen);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(500));
    }

    #[test]
    fn test_pool_config_wait_and_retry() {
        let config = AgentPoolConfig {
            flow_control_mode: FlowControlMode::WaitAndRetry,
            flow_control_wait_timeout: Duration::from_millis(250),
            ..Default::default()
        };
        assert_eq!(config.flow_control_mode, FlowControlMode::WaitAndRetry);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(250));
    }

    #[test]
    fn test_pool_config_sticky_session_default() {
        let config = AgentPoolConfig::default();
        assert_eq!(
            config.sticky_session_timeout,
            Some(Duration::from_secs(5 * 60))
        );
    }

    #[test]
    fn test_pool_config_sticky_session_custom() {
        let config = AgentPoolConfig {
            sticky_session_timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        assert_eq!(config.sticky_session_timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_pool_config_sticky_session_disabled() {
        let config = AgentPoolConfig {
            sticky_session_timeout: None,
            ..Default::default()
        };
        assert!(config.sticky_session_timeout.is_none());
    }

    #[test]
    fn test_sticky_session_count_empty() {
        let pool = AgentPool::new();
        assert_eq!(pool.sticky_session_count(), 0);
    }

    #[test]
    fn test_sticky_session_has_nonexistent() {
        let pool = AgentPool::new();
        assert!(!pool.has_sticky_session("nonexistent"));
    }

    #[test]
    fn test_sticky_session_clear_nonexistent() {
        let pool = AgentPool::new();
        // Should not panic
        pool.clear_sticky_session("nonexistent");
    }

    #[test]
    fn test_cleanup_expired_sessions_empty() {
        let pool = AgentPool::new();
        let removed = pool.cleanup_expired_sessions();
        assert_eq!(removed, 0);
    }
}