1use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16use dashmap::DashMap;
17use tokio::sync::{RwLock, Semaphore};
18use tracing::{debug, info, trace, warn};
19
20use crate::v2::client::{AgentClientV2, CancelReason, ConfigUpdateCallback, MetricsCallback};
21use crate::v2::control::ConfigUpdateType;
22use crate::v2::observability::{ConfigPusher, ConfigUpdateHandler, MetricsCollector};
23use crate::v2::protocol_metrics::ProtocolMetrics;
24use crate::v2::reverse::ReverseConnectionClient;
25use crate::v2::uds::AgentClientV2Uds;
26use crate::v2::AgentCapabilities;
27use crate::{
28 AgentProtocolError, AgentResponse, RequestBodyChunkEvent, RequestHeadersEvent,
29 ResponseBodyChunkEvent, ResponseHeadersEvent,
30};
31
/// Default value for [`AgentPoolConfig::channel_buffer_size`].
///
/// NOTE(review): presumably the capacity of per-connection message channels — confirm where
/// `channel_buffer_size` is consumed (not visible in this chunk).
pub const CHANNEL_BUFFER_SIZE: usize = 64;
42
/// Strategy used to choose among an agent's pooled connections.
///
/// NOTE(review): the selection logic lives in `select_connection`, which is not visible here;
/// the per-variant descriptions below are inferred from the variant names — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
    /// Presumably rotates through connections using the agent's round-robin index.
    /// This is the default variant.
    #[default]
    RoundRobin,
    /// Presumably prefers the connection with the fewest in-flight requests.
    LeastConnections,
    /// Presumably prefers connections by cached health / error rate.
    HealthBased,
    /// Presumably picks a connection at random.
    Random,
}
56
/// What the pool does when an agent's connection reports it cannot accept requests
/// (the check performed by `AgentPool::check_flow_control`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlowControlMode {
    /// Fail the send with `AgentProtocolError::FlowControlPaused`. This is the default.
    #[default]
    FailClosed,

    /// Skip the agent: the send methods return `AgentResponse::default_allow()` without
    /// contacting it. The skip is still counted as a flow rejection in the protocol metrics.
    FailOpen,

    /// Poll the connection every 10ms until `AgentPoolConfig::flow_control_wait_timeout`
    /// elapses; if the agent is still paused, fail with `FlowControlPaused`.
    WaitAndRetry,
}
84
/// A session pinned to one pooled connection of one agent.
struct StickySession {
    /// The connection every request in this session is routed to while the session lives.
    connection: Arc<PooledConnection>,
    /// Id of the agent that owns `connection`.
    agent_id: String,
    /// When the session was created; the reference point for `last_accessed`.
    created_at: Instant,
    /// Milliseconds after `created_at` of the most recent access. It is stored as an offset
    /// so that it fits in an atomic and can be updated through `&self`.
    last_accessed: AtomicU64,
}
100
101impl std::fmt::Debug for StickySession {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("StickySession")
104 .field("agent_id", &self.agent_id)
105 .field("created_at", &self.created_at)
106 .finish_non_exhaustive()
107 }
108}
109
/// Tuning knobs for an [`AgentPool`].
#[derive(Debug, Clone)]
pub struct AgentPoolConfig {
    /// Number of connections opened per agent by `add_agent`. It is also the cap on how many
    /// reverse connections one agent may register.
    pub connections_per_agent: usize,
    /// How a connection is chosen among an agent's connections.
    pub load_balance_strategy: LoadBalanceStrategy,
    /// NOTE(review): presumably the timeout for establishing a connection — not used in this chunk.
    pub connect_timeout: Duration,
    /// NOTE(review): presumably the per-request timeout — not used in this chunk.
    pub request_timeout: Duration,
    /// NOTE(review): presumably the minimum spacing between reconnect attempts
    /// (cf. `AgentEntry::should_reconnect`) — confirm at the call site.
    pub reconnect_interval: Duration,
    /// NOTE(review): presumably the reconnect attempt limit — not used in this chunk.
    pub max_reconnect_attempts: usize,
    /// NOTE(review): presumably how long to wait for in-flight work on shutdown — not used here.
    pub drain_timeout: Duration,
    /// Permits in each connection's concurrency semaphore.
    pub max_concurrent_per_connection: usize,
    /// NOTE(review): presumably the period of background health checks — not used here.
    pub health_check_interval: Duration,
    /// Channel buffer size; defaults to [`CHANNEL_BUFFER_SIZE`].
    pub channel_buffer_size: usize,
    /// Behaviour when an agent reports it cannot accept requests.
    pub flow_control_mode: FlowControlMode,
    /// Maximum wait in [`FlowControlMode::WaitAndRetry`] before giving up.
    pub flow_control_wait_timeout: Duration,
    /// Idle time after which a sticky session expires; `None` disables expiry.
    pub sticky_session_timeout: Option<Duration>,
}
159
160impl Default for AgentPoolConfig {
161 fn default() -> Self {
162 Self {
163 connections_per_agent: 4,
164 load_balance_strategy: LoadBalanceStrategy::RoundRobin,
165 connect_timeout: Duration::from_secs(5),
166 request_timeout: Duration::from_secs(30),
167 reconnect_interval: Duration::from_secs(5),
168 max_reconnect_attempts: 3,
169 drain_timeout: Duration::from_secs(30),
170 max_concurrent_per_connection: 100,
171 health_check_interval: Duration::from_secs(10),
172 channel_buffer_size: CHANNEL_BUFFER_SIZE,
173 flow_control_mode: FlowControlMode::FailClosed,
174 flow_control_wait_timeout: Duration::from_millis(100),
175 sticky_session_timeout: Some(Duration::from_secs(5 * 60)), }
177 }
178}
179
180impl StickySession {
181 fn new(agent_id: String, connection: Arc<PooledConnection>) -> Self {
182 Self {
183 connection,
184 agent_id,
185 created_at: Instant::now(),
186 last_accessed: AtomicU64::new(0),
187 }
188 }
189
190 fn touch(&self) {
191 let offset = self.created_at.elapsed().as_millis() as u64;
192 self.last_accessed.store(offset, Ordering::Relaxed);
193 }
194
195 fn last_accessed(&self) -> Instant {
196 let offset_ms = self.last_accessed.load(Ordering::Relaxed);
197 self.created_at + Duration::from_millis(offset_ms)
198 }
199
200 fn is_expired(&self, timeout: Duration) -> bool {
201 self.last_accessed().elapsed() > timeout
202 }
203}
204
/// The concrete client behind a pooled connection.
///
/// All variants expose the same operations, and `impl V2Transport` forwards each call to
/// whichever client is held.
pub enum V2Transport {
    /// Client of type `AgentClientV2` (gRPC transport).
    Grpc(AgentClientV2),
    /// Client of type `AgentClientV2Uds` (Unix domain socket transport).
    Uds(AgentClientV2Uds),
    /// A connection initiated by the agent and registered via
    /// `AgentPool::add_reverse_connection`.
    Reverse(ReverseConnectionClient),
}
216
217impl V2Transport {
218 pub async fn is_connected(&self) -> bool {
220 match self {
221 V2Transport::Grpc(client) => client.is_connected().await,
222 V2Transport::Uds(client) => client.is_connected().await,
223 V2Transport::Reverse(client) => client.is_connected().await,
224 }
225 }
226
227 pub async fn can_accept_requests(&self) -> bool {
231 match self {
232 V2Transport::Grpc(client) => client.can_accept_requests().await,
233 V2Transport::Uds(client) => client.can_accept_requests().await,
234 V2Transport::Reverse(client) => client.can_accept_requests().await,
235 }
236 }
237
238 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
240 match self {
241 V2Transport::Grpc(client) => client.capabilities().await,
242 V2Transport::Uds(client) => client.capabilities().await,
243 V2Transport::Reverse(client) => client.capabilities().await,
244 }
245 }
246
247 pub async fn send_request_headers(
249 &self,
250 correlation_id: &str,
251 event: &RequestHeadersEvent,
252 ) -> Result<AgentResponse, AgentProtocolError> {
253 match self {
254 V2Transport::Grpc(client) => client.send_request_headers(correlation_id, event).await,
255 V2Transport::Uds(client) => client.send_request_headers(correlation_id, event).await,
256 V2Transport::Reverse(client) => {
257 client.send_request_headers(correlation_id, event).await
258 }
259 }
260 }
261
262 pub async fn send_request_body_chunk(
264 &self,
265 correlation_id: &str,
266 event: &RequestBodyChunkEvent,
267 ) -> Result<AgentResponse, AgentProtocolError> {
268 match self {
269 V2Transport::Grpc(client) => {
270 client.send_request_body_chunk(correlation_id, event).await
271 }
272 V2Transport::Uds(client) => client.send_request_body_chunk(correlation_id, event).await,
273 V2Transport::Reverse(client) => {
274 client.send_request_body_chunk(correlation_id, event).await
275 }
276 }
277 }
278
279 pub async fn send_response_headers(
281 &self,
282 correlation_id: &str,
283 event: &ResponseHeadersEvent,
284 ) -> Result<AgentResponse, AgentProtocolError> {
285 match self {
286 V2Transport::Grpc(client) => client.send_response_headers(correlation_id, event).await,
287 V2Transport::Uds(client) => client.send_response_headers(correlation_id, event).await,
288 V2Transport::Reverse(client) => {
289 client.send_response_headers(correlation_id, event).await
290 }
291 }
292 }
293
294 pub async fn send_response_body_chunk(
296 &self,
297 correlation_id: &str,
298 event: &ResponseBodyChunkEvent,
299 ) -> Result<AgentResponse, AgentProtocolError> {
300 match self {
301 V2Transport::Grpc(client) => {
302 client.send_response_body_chunk(correlation_id, event).await
303 }
304 V2Transport::Uds(client) => {
305 client.send_response_body_chunk(correlation_id, event).await
306 }
307 V2Transport::Reverse(client) => {
308 client.send_response_body_chunk(correlation_id, event).await
309 }
310 }
311 }
312
313 pub async fn cancel_request(
315 &self,
316 correlation_id: &str,
317 reason: CancelReason,
318 ) -> Result<(), AgentProtocolError> {
319 match self {
320 V2Transport::Grpc(client) => client.cancel_request(correlation_id, reason).await,
321 V2Transport::Uds(client) => client.cancel_request(correlation_id, reason).await,
322 V2Transport::Reverse(client) => client.cancel_request(correlation_id, reason).await,
323 }
324 }
325
326 pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
328 match self {
329 V2Transport::Grpc(client) => client.cancel_all(reason).await,
330 V2Transport::Uds(client) => client.cancel_all(reason).await,
331 V2Transport::Reverse(client) => client.cancel_all(reason).await,
332 }
333 }
334
335 pub async fn close(&self) -> Result<(), AgentProtocolError> {
337 match self {
338 V2Transport::Grpc(client) => client.close().await,
339 V2Transport::Uds(client) => client.close().await,
340 V2Transport::Reverse(client) => client.close().await,
341 }
342 }
343
344 pub fn agent_id(&self) -> &str {
346 match self {
347 V2Transport::Grpc(client) => client.agent_id(),
348 V2Transport::Uds(client) => client.agent_id(),
349 V2Transport::Reverse(client) => client.agent_id(),
350 }
351 }
352}
353
/// One connection in an agent's pool, plus its usage counters and health state.
struct PooledConnection {
    /// Transport used for all sends on this connection.
    client: V2Transport,
    /// Creation time; the reference point for `last_used_offset_ms`.
    created_at: Instant,
    /// Milliseconds after `created_at` of the most recent use (see `touch`).
    last_used_offset_ms: AtomicU64,
    /// Sends currently in progress on this connection.
    in_flight: AtomicU64,
    /// Completed send attempts, counting both successes and failures.
    request_count: AtomicU64,
    /// Failed send attempts.
    error_count: AtomicU64,
    /// Failures since the last success. At 3 or more the connection is considered unhealthy.
    consecutive_errors: AtomicU64,
    /// Caps concurrent sends; sized from `AgentPoolConfig::max_concurrent_per_connection`.
    concurrency_limiter: Semaphore,
    /// Last computed health. It is updated by `check_and_update_health` and cleared after
    /// 3 consecutive send errors.
    healthy_cached: AtomicBool,
}
368
369impl PooledConnection {
370 fn new(client: V2Transport, max_concurrent: usize) -> Self {
371 Self {
372 client,
373 created_at: Instant::now(),
374 last_used_offset_ms: AtomicU64::new(0),
375 in_flight: AtomicU64::new(0),
376 request_count: AtomicU64::new(0),
377 error_count: AtomicU64::new(0),
378 consecutive_errors: AtomicU64::new(0),
379 concurrency_limiter: Semaphore::new(max_concurrent),
380 healthy_cached: AtomicBool::new(true), }
382 }
383
384 fn in_flight(&self) -> u64 {
385 self.in_flight.load(Ordering::Relaxed)
386 }
387
388 fn error_rate(&self) -> f64 {
389 let requests = self.request_count.load(Ordering::Relaxed);
390 let errors = self.error_count.load(Ordering::Relaxed);
391 if requests == 0 {
392 0.0
393 } else {
394 errors as f64 / requests as f64
395 }
396 }
397
398 #[inline]
401 fn is_healthy_cached(&self) -> bool {
402 self.healthy_cached.load(Ordering::Acquire)
403 }
404
405 async fn check_and_update_health(&self) -> bool {
407 let connected = self.client.is_connected().await;
408 let low_errors = self.consecutive_errors.load(Ordering::Relaxed) < 3;
409 let can_accept = self.client.can_accept_requests().await;
410
411 let healthy = connected && low_errors && can_accept;
412 self.healthy_cached.store(healthy, Ordering::Release);
413 healthy
414 }
415
416 #[inline]
418 fn touch(&self) {
419 let offset = self.created_at.elapsed().as_millis() as u64;
420 self.last_used_offset_ms.store(offset, Ordering::Relaxed);
421 }
422
423 fn last_used(&self) -> Instant {
425 let offset_ms = self.last_used_offset_ms.load(Ordering::Relaxed);
426 self.created_at + Duration::from_millis(offset_ms)
427 }
428}
429
/// Point-in-time statistics for one agent, aggregated over its pooled connections
/// (produced by `AgentPool::stats`).
#[derive(Debug, Clone)]
pub struct AgentPoolStats {
    /// Agent id.
    pub agent_id: String,
    /// Number of connections currently in the agent's pool.
    pub active_connections: usize,
    /// Connections whose cached health flag is set.
    pub healthy_connections: usize,
    /// Sum of in-flight sends across connections.
    pub total_in_flight: u64,
    /// Sum of completed send attempts across connections.
    pub total_requests: u64,
    /// Sum of failed send attempts across connections.
    pub total_errors: u64,
    /// `total_errors / total_requests`, or `0.0` when there were no requests.
    pub error_rate: f64,
    /// The agent-level health flag stored on the agent entry.
    pub is_healthy: bool,
}
450
/// Pool bookkeeping for a single agent.
struct AgentEntry {
    /// Agent id (also the key in `AgentPool::agents`).
    agent_id: String,
    /// Endpoint the agent was added with. Agents created from reverse connections use
    /// `reverse://<agent_id>`.
    endpoint: String,
    /// The agent's pooled connections.
    connections: RwLock<Vec<Arc<PooledConnection>>>,
    /// Capabilities reported by the agent, once known.
    capabilities: RwLock<Option<AgentCapabilities>>,
    /// NOTE(review): presumably the cursor for round-robin selection — used outside this chunk.
    round_robin_index: AtomicUsize,
    /// NOTE(review): presumably counts reconnect attempts — not used in this chunk.
    reconnect_attempts: AtomicUsize,
    /// Millisecond timestamp of the last reconnect attempt; 0 means none has been recorded.
    last_reconnect_attempt_ms: AtomicU64,
    /// Agent-level health flag (starts `true`); reported as `AgentPoolStats::is_healthy`.
    healthy: AtomicBool,
}
466
467impl AgentEntry {
468 fn new(agent_id: String, endpoint: String) -> Self {
469 Self {
470 agent_id,
471 endpoint,
472 connections: RwLock::new(Vec::new()),
473 capabilities: RwLock::new(None),
474 round_robin_index: AtomicUsize::new(0),
475 reconnect_attempts: AtomicUsize::new(0),
476 last_reconnect_attempt_ms: AtomicU64::new(0),
477 healthy: AtomicBool::new(true),
478 }
479 }
480
481 fn should_reconnect(&self, interval: Duration) -> bool {
483 let last_ms = self.last_reconnect_attempt_ms.load(Ordering::Relaxed);
484 if last_ms == 0 {
485 return true;
486 }
487 let now_ms = std::time::SystemTime::now()
488 .duration_since(std::time::UNIX_EPOCH)
489 .map(|d| d.as_millis() as u64)
490 .unwrap_or(0);
491 now_ms.saturating_sub(last_ms) > interval.as_millis() as u64
492 }
493
494 fn mark_reconnect_attempt(&self) {
496 let now_ms = std::time::SystemTime::now()
497 .duration_since(std::time::UNIX_EPOCH)
498 .map(|d| d.as_millis() as u64)
499 .unwrap_or(0);
500 self.last_reconnect_attempt_ms
501 .store(now_ms, Ordering::Relaxed);
502 }
503}
504
/// A pool of connections to one or more agents.
///
/// Provides connection selection, flow control, sticky sessions, correlation affinity,
/// metrics, and config push.
pub struct AgentPool {
    /// Pool configuration.
    config: AgentPoolConfig,
    /// Agents keyed by agent id.
    agents: DashMap<String, Arc<AgentEntry>>,
    /// Pool-wide count of send attempts (all event kinds).
    total_requests: AtomicU64,
    /// Pool-wide count of failed sends.
    total_errors: AtomicU64,
    /// Collector that receives agent metrics reports.
    metrics_collector: Arc<MetricsCollector>,
    /// Callback that records a metrics report into `metrics_collector`.
    metrics_callback: MetricsCallback,
    /// Tracks which agents support config push and dispatches pushes to them.
    config_pusher: Arc<ConfigPusher>,
    /// Handles config-update requests originating from agents.
    config_update_handler: Arc<ConfigUpdateHandler>,
    /// Callback that logs an agent's config-update request and forwards it to
    /// `config_update_handler`.
    config_update_callback: ConfigUpdateCallback,
    /// Protocol-level counters and gauges (requests, in-flight, errors, durations).
    protocol_metrics: Arc<ProtocolMetrics>,
    /// Correlation id -> connection that carried that request's headers. Later events for the
    /// same correlation can reuse it.
    correlation_affinity: DashMap<String, Arc<PooledConnection>>,
    /// Session id -> pinned connection. When `sticky_session_timeout` is set, entries expire
    /// after that much idle time.
    sticky_sessions: DashMap<String, StickySession>,
}
541
542impl AgentPool {
543 pub fn new() -> Self {
545 Self::with_config(AgentPoolConfig::default())
546 }
547
548 pub fn with_config(config: AgentPoolConfig) -> Self {
550 let metrics_collector = Arc::new(MetricsCollector::new());
551 let collector_clone = Arc::clone(&metrics_collector);
552
553 let metrics_callback: MetricsCallback = Arc::new(move |report| {
555 collector_clone.record(&report);
556 });
557
558 let config_pusher = Arc::new(ConfigPusher::new());
560 let config_update_handler = Arc::new(ConfigUpdateHandler::new());
561 let handler_clone = Arc::clone(&config_update_handler);
562
563 let config_update_callback: ConfigUpdateCallback = Arc::new(move |agent_id, request| {
565 debug!(
566 agent_id = %agent_id,
567 request_id = %request.request_id,
568 "Processing config update request from agent"
569 );
570 handler_clone.handle(request)
571 });
572
573 Self {
574 config,
575 agents: DashMap::new(),
576 total_requests: AtomicU64::new(0),
577 total_errors: AtomicU64::new(0),
578 metrics_collector,
579 metrics_callback,
580 config_pusher,
581 config_update_handler,
582 config_update_callback,
583 protocol_metrics: Arc::new(ProtocolMetrics::new()),
584 correlation_affinity: DashMap::new(),
585 sticky_sessions: DashMap::new(),
586 }
587 }
588
589 pub fn protocol_metrics(&self) -> &ProtocolMetrics {
591 &self.protocol_metrics
592 }
593
594 pub fn protocol_metrics_arc(&self) -> Arc<ProtocolMetrics> {
596 Arc::clone(&self.protocol_metrics)
597 }
598
599 pub fn metrics_collector(&self) -> &MetricsCollector {
601 &self.metrics_collector
602 }
603
604 pub fn metrics_collector_arc(&self) -> Arc<MetricsCollector> {
608 Arc::clone(&self.metrics_collector)
609 }
610
611 pub fn export_prometheus(&self) -> String {
613 self.metrics_collector.export_prometheus()
614 }
615
616 pub fn clear_correlation_affinity(&self, correlation_id: &str) {
622 self.correlation_affinity.remove(correlation_id);
623 }
624
625 pub fn correlation_affinity_count(&self) -> usize {
629 self.correlation_affinity.len()
630 }
631
632 pub fn create_sticky_session(
668 &self,
669 session_id: impl Into<String>,
670 agent_id: &str,
671 ) -> Result<(), AgentProtocolError> {
672 let session_id = session_id.into();
673 let conn = self.select_connection(agent_id)?;
674
675 let session = StickySession::new(agent_id.to_string(), conn);
676 session.touch();
677
678 self.sticky_sessions.insert(session_id.clone(), session);
679
680 debug!(
681 session_id = %session_id,
682 agent_id = %agent_id,
683 "Created sticky session"
684 );
685
686 Ok(())
687 }
688
689 fn get_sticky_session_conn(&self, session_id: &str) -> Option<Arc<PooledConnection>> {
694 let entry = self.sticky_sessions.get(session_id)?;
695
696 if let Some(timeout) = self.config.sticky_session_timeout {
698 if entry.is_expired(timeout) {
699 drop(entry); self.sticky_sessions.remove(session_id);
701 debug!(session_id = %session_id, "Sticky session expired");
702 return None;
703 }
704 }
705
706 entry.touch();
707 Some(Arc::clone(&entry.connection))
708 }
709
710 pub fn refresh_sticky_session(&self, session_id: &str) -> bool {
714 self.get_sticky_session_conn(session_id).is_some()
715 }
716
717 pub fn has_sticky_session(&self, session_id: &str) -> bool {
719 self.get_sticky_session_conn(session_id).is_some()
720 }
721
722 pub fn clear_sticky_session(&self, session_id: &str) {
726 if self.sticky_sessions.remove(session_id).is_some() {
727 debug!(session_id = %session_id, "Cleared sticky session");
728 }
729 }
730
731 pub fn sticky_session_count(&self) -> usize {
735 self.sticky_sessions.len()
736 }
737
738 pub fn sticky_session_agent(&self, session_id: &str) -> Option<String> {
740 self.sticky_sessions
741 .get(session_id)
742 .map(|s| s.agent_id.clone())
743 }
744
745 pub async fn send_request_headers_with_sticky_session(
755 &self,
756 session_id: &str,
757 agent_id: &str,
758 correlation_id: &str,
759 event: &RequestHeadersEvent,
760 ) -> Result<(AgentResponse, bool), AgentProtocolError> {
761 let start = Instant::now();
762 self.total_requests.fetch_add(1, Ordering::Relaxed);
763 self.protocol_metrics.inc_requests();
764 self.protocol_metrics.inc_in_flight();
765
766 let (conn, used_sticky) =
768 if let Some(sticky_conn) = self.get_sticky_session_conn(session_id) {
769 (sticky_conn, true)
770 } else {
771 (self.select_connection(agent_id)?, false)
772 };
773
774 match self.check_flow_control(&conn, agent_id).await {
776 Ok(true) => {}
777 Ok(false) => {
778 self.protocol_metrics.dec_in_flight();
779 return Ok((AgentResponse::default_allow(), used_sticky));
780 }
781 Err(e) => {
782 self.protocol_metrics.dec_in_flight();
783 return Err(e);
784 }
785 }
786
787 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
789 self.protocol_metrics.dec_in_flight();
790 self.protocol_metrics.inc_connection_errors();
791 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
792 })?;
793
794 conn.in_flight.fetch_add(1, Ordering::Relaxed);
795 conn.touch();
796
797 self.correlation_affinity
799 .insert(correlation_id.to_string(), Arc::clone(&conn));
800
801 let result = conn
802 .client
803 .send_request_headers(correlation_id, event)
804 .await;
805
806 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
807 conn.request_count.fetch_add(1, Ordering::Relaxed);
808 self.protocol_metrics.dec_in_flight();
809 self.protocol_metrics
810 .record_request_duration(start.elapsed());
811
812 match &result {
813 Ok(_) => {
814 conn.consecutive_errors.store(0, Ordering::Relaxed);
815 self.protocol_metrics.inc_responses();
816 }
817 Err(e) => {
818 conn.error_count.fetch_add(1, Ordering::Relaxed);
819 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
820 self.total_errors.fetch_add(1, Ordering::Relaxed);
821
822 match e {
823 AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
824 AgentProtocolError::ConnectionFailed(_)
825 | AgentProtocolError::ConnectionClosed => {
826 self.protocol_metrics.inc_connection_errors();
827 }
828 AgentProtocolError::Serialization(_) => {
829 self.protocol_metrics.inc_serialization_errors();
830 }
831 _ => {}
832 }
833
834 if consecutive >= 3 {
835 conn.healthy_cached.store(false, Ordering::Release);
836 }
837 }
838 }
839
840 result.map(|r| (r, used_sticky))
841 }
842
843 pub fn cleanup_expired_sessions(&self) -> usize {
848 let Some(timeout) = self.config.sticky_session_timeout else {
849 return 0;
850 };
851
852 let mut removed = 0;
853 self.sticky_sessions.retain(|session_id, session| {
854 if session.is_expired(timeout) {
855 debug!(session_id = %session_id, "Removing expired sticky session");
856 removed += 1;
857 false
858 } else {
859 true
860 }
861 });
862
863 if removed > 0 {
864 trace!(removed = removed, "Cleaned up expired sticky sessions");
865 }
866
867 removed
868 }
869
870 pub fn config_pusher(&self) -> &ConfigPusher {
872 &self.config_pusher
873 }
874
875 pub fn config_update_handler(&self) -> &ConfigUpdateHandler {
877 &self.config_update_handler
878 }
879
880 pub fn push_config_to_agent(
884 &self,
885 agent_id: &str,
886 update_type: ConfigUpdateType,
887 ) -> Option<String> {
888 self.config_pusher.push_to_agent(agent_id, update_type)
889 }
890
891 pub fn push_config_to_all(&self, update_type: ConfigUpdateType) -> Vec<String> {
895 self.config_pusher.push_to_all(update_type)
896 }
897
898 pub fn acknowledge_config_push(&self, push_id: &str, accepted: bool, error: Option<String>) {
900 self.config_pusher.acknowledge(push_id, accepted, error);
901 }
902
903 pub async fn add_agent(
907 &self,
908 agent_id: impl Into<String>,
909 endpoint: impl Into<String>,
910 ) -> Result<(), AgentProtocolError> {
911 let agent_id = agent_id.into();
912 let endpoint = endpoint.into();
913
914 info!(agent_id = %agent_id, endpoint = %endpoint, "Adding agent to pool");
915
916 let entry = Arc::new(AgentEntry::new(agent_id.clone(), endpoint.clone()));
917
918 let mut connections = Vec::with_capacity(self.config.connections_per_agent);
920 for i in 0..self.config.connections_per_agent {
921 match self.create_connection(&agent_id, &endpoint).await {
922 Ok(conn) => {
923 connections.push(Arc::new(conn));
924 debug!(
925 agent_id = %agent_id,
926 connection = i,
927 "Created connection"
928 );
929 }
930 Err(e) => {
931 warn!(
932 agent_id = %agent_id,
933 connection = i,
934 error = %e,
935 "Failed to create connection"
936 );
937 }
939 }
940 }
941
942 if connections.is_empty() {
943 return Err(AgentProtocolError::ConnectionFailed(format!(
944 "Failed to create any connections to agent {}",
945 agent_id
946 )));
947 }
948
949 if let Some(conn) = connections.first() {
951 if let Some(caps) = conn.client.capabilities().await {
952 let supports_config_push = caps.features.config_push;
954 let agent_name = caps.name.clone();
955 self.config_pusher
956 .register_agent(&agent_id, &agent_name, supports_config_push);
957 debug!(
958 agent_id = %agent_id,
959 supports_config_push = supports_config_push,
960 "Registered agent with ConfigPusher"
961 );
962
963 *entry.capabilities.write().await = Some(caps);
964 }
965 }
966
967 *entry.connections.write().await = connections;
968 self.agents.insert(agent_id.clone(), entry);
969
970 info!(
971 agent_id = %agent_id,
972 connections = self.config.connections_per_agent,
973 "Agent added to pool"
974 );
975
976 Ok(())
977 }
978
979 pub async fn remove_agent(&self, agent_id: &str) -> Result<(), AgentProtocolError> {
983 info!(agent_id = %agent_id, "Removing agent from pool");
984
985 self.config_pusher.unregister_agent(agent_id);
987
988 let (_, entry) = self.agents.remove(agent_id).ok_or_else(|| {
989 AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
990 })?;
991
992 let connections = entry.connections.read().await;
994 for conn in connections.iter() {
995 let _ = conn.client.close().await;
996 }
997
998 info!(agent_id = %agent_id, "Agent removed from pool");
999 Ok(())
1000 }
1001
1002 pub async fn add_reverse_connection(
1008 &self,
1009 agent_id: &str,
1010 client: ReverseConnectionClient,
1011 capabilities: AgentCapabilities,
1012 ) -> Result<(), AgentProtocolError> {
1013 info!(
1014 agent_id = %agent_id,
1015 connection_id = %client.connection_id(),
1016 "Adding reverse connection to pool"
1017 );
1018
1019 let transport = V2Transport::Reverse(client);
1020 let conn = Arc::new(PooledConnection::new(
1021 transport,
1022 self.config.max_concurrent_per_connection,
1023 ));
1024
1025 if let Some(entry) = self.agents.get(agent_id) {
1027 let mut connections = entry.connections.write().await;
1029
1030 if connections.len() >= self.config.connections_per_agent {
1032 warn!(
1033 agent_id = %agent_id,
1034 current = connections.len(),
1035 max = self.config.connections_per_agent,
1036 "Reverse connection rejected: at connection limit"
1037 );
1038 return Err(AgentProtocolError::ConnectionFailed(format!(
1039 "Agent {} already has maximum connections ({})",
1040 agent_id, self.config.connections_per_agent
1041 )));
1042 }
1043
1044 connections.push(conn);
1045 info!(
1046 agent_id = %agent_id,
1047 total_connections = connections.len(),
1048 "Added reverse connection to existing agent"
1049 );
1050 } else {
1051 let entry = Arc::new(AgentEntry::new(
1053 agent_id.to_string(),
1054 format!("reverse://{}", agent_id),
1055 ));
1056
1057 let supports_config_push = capabilities.features.config_push;
1059 let agent_name = capabilities.name.clone();
1060 self.config_pusher
1061 .register_agent(agent_id, &agent_name, supports_config_push);
1062 debug!(
1063 agent_id = %agent_id,
1064 supports_config_push = supports_config_push,
1065 "Registered reverse connection agent with ConfigPusher"
1066 );
1067
1068 *entry.capabilities.write().await = Some(capabilities);
1069 *entry.connections.write().await = vec![conn];
1070 self.agents.insert(agent_id.to_string(), entry);
1071
1072 info!(
1073 agent_id = %agent_id,
1074 "Created new agent entry for reverse connection"
1075 );
1076 }
1077
1078 Ok(())
1079 }
1080
1081 async fn check_flow_control(
1087 &self,
1088 conn: &PooledConnection,
1089 agent_id: &str,
1090 ) -> Result<bool, AgentProtocolError> {
1091 if conn.client.can_accept_requests().await {
1092 return Ok(true);
1093 }
1094
1095 match self.config.flow_control_mode {
1096 FlowControlMode::FailClosed => {
1097 self.protocol_metrics.record_flow_rejection();
1098 Err(AgentProtocolError::FlowControlPaused {
1099 agent_id: agent_id.to_string(),
1100 })
1101 }
1102 FlowControlMode::FailOpen => {
1103 debug!(agent_id = %agent_id, "Flow control: agent paused, allowing request (fail-open mode)");
1105 self.protocol_metrics.record_flow_rejection();
1106 Ok(false) }
1108 FlowControlMode::WaitAndRetry => {
1109 let deadline = Instant::now() + self.config.flow_control_wait_timeout;
1111 while Instant::now() < deadline {
1112 tokio::time::sleep(Duration::from_millis(10)).await;
1113 if conn.client.can_accept_requests().await {
1114 trace!(agent_id = %agent_id, "Flow control: agent resumed after wait");
1115 return Ok(true);
1116 }
1117 }
1118 self.protocol_metrics.record_flow_rejection();
1120 Err(AgentProtocolError::FlowControlPaused {
1121 agent_id: agent_id.to_string(),
1122 })
1123 }
1124 }
1125 }
1126
1127 pub async fn send_request_headers(
1138 &self,
1139 agent_id: &str,
1140 correlation_id: &str,
1141 event: &RequestHeadersEvent,
1142 ) -> Result<AgentResponse, AgentProtocolError> {
1143 let start = Instant::now();
1144 self.total_requests.fetch_add(1, Ordering::Relaxed);
1145 self.protocol_metrics.inc_requests();
1146 self.protocol_metrics.inc_in_flight();
1147
1148 let conn = self.select_connection(agent_id)?;
1149
1150 match self.check_flow_control(&conn, agent_id).await {
1152 Ok(true) => {} Ok(false) => {
1154 self.protocol_metrics.dec_in_flight();
1156 return Ok(AgentResponse::default_allow());
1157 }
1158 Err(e) => {
1159 self.protocol_metrics.dec_in_flight();
1160 return Err(e);
1161 }
1162 }
1163
1164 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1166 self.protocol_metrics.dec_in_flight();
1167 self.protocol_metrics.inc_connection_errors();
1168 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1169 })?;
1170
1171 conn.in_flight.fetch_add(1, Ordering::Relaxed);
1172 conn.touch(); self.correlation_affinity
1176 .insert(correlation_id.to_string(), Arc::clone(&conn));
1177
1178 let result = conn
1179 .client
1180 .send_request_headers(correlation_id, event)
1181 .await;
1182
1183 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1184 conn.request_count.fetch_add(1, Ordering::Relaxed);
1185 self.protocol_metrics.dec_in_flight();
1186 self.protocol_metrics
1187 .record_request_duration(start.elapsed());
1188
1189 match &result {
1190 Ok(_) => {
1191 conn.consecutive_errors.store(0, Ordering::Relaxed);
1192 self.protocol_metrics.inc_responses();
1193 }
1194 Err(e) => {
1195 conn.error_count.fetch_add(1, Ordering::Relaxed);
1196 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1197 self.total_errors.fetch_add(1, Ordering::Relaxed);
1198
1199 match e {
1201 AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
1202 AgentProtocolError::ConnectionFailed(_)
1203 | AgentProtocolError::ConnectionClosed => {
1204 self.protocol_metrics.inc_connection_errors();
1205 }
1206 AgentProtocolError::Serialization(_) => {
1207 self.protocol_metrics.inc_serialization_errors();
1208 }
1209 _ => {}
1210 }
1211
1212 if consecutive >= 3 {
1214 conn.healthy_cached.store(false, Ordering::Release);
1215 trace!(agent_id = %agent_id, error = %e, "Connection marked unhealthy after consecutive errors");
1216 }
1217 }
1218 }
1219
1220 result
1221 }
1222
1223 pub async fn send_request_body_chunk(
1228 &self,
1229 agent_id: &str,
1230 correlation_id: &str,
1231 event: &RequestBodyChunkEvent,
1232 ) -> Result<AgentResponse, AgentProtocolError> {
1233 self.total_requests.fetch_add(1, Ordering::Relaxed);
1234
1235 let conn = if let Some(affinity_conn) = self.correlation_affinity.get(correlation_id) {
1237 Arc::clone(&affinity_conn)
1238 } else {
1239 trace!(correlation_id = %correlation_id, "No affinity found for body chunk, using selection");
1241 self.select_connection(agent_id)?
1242 };
1243
1244 match self.check_flow_control(&conn, agent_id).await {
1246 Ok(true) => {} Ok(false) => {
1248 return Ok(AgentResponse::default_allow());
1250 }
1251 Err(e) => return Err(e),
1252 }
1253
1254 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1255 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1256 })?;
1257
1258 conn.in_flight.fetch_add(1, Ordering::Relaxed);
1259 conn.touch();
1260
1261 let result = conn
1262 .client
1263 .send_request_body_chunk(correlation_id, event)
1264 .await;
1265
1266 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1267 conn.request_count.fetch_add(1, Ordering::Relaxed);
1268
1269 match &result {
1270 Ok(_) => {
1271 conn.consecutive_errors.store(0, Ordering::Relaxed);
1272 }
1273 Err(_) => {
1274 conn.error_count.fetch_add(1, Ordering::Relaxed);
1275 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1276 self.total_errors.fetch_add(1, Ordering::Relaxed);
1277 if consecutive >= 3 {
1278 conn.healthy_cached.store(false, Ordering::Release);
1279 }
1280 }
1281 }
1282
1283 result
1284 }
1285
1286 pub async fn send_response_headers(
1291 &self,
1292 agent_id: &str,
1293 correlation_id: &str,
1294 event: &ResponseHeadersEvent,
1295 ) -> Result<AgentResponse, AgentProtocolError> {
1296 self.total_requests.fetch_add(1, Ordering::Relaxed);
1297
1298 let conn = self.select_connection(agent_id)?;
1299
1300 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1301 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1302 })?;
1303
1304 conn.in_flight.fetch_add(1, Ordering::Relaxed);
1305 conn.touch();
1306
1307 let result = conn
1308 .client
1309 .send_response_headers(correlation_id, event)
1310 .await;
1311
1312 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1313 conn.request_count.fetch_add(1, Ordering::Relaxed);
1314
1315 match &result {
1316 Ok(_) => {
1317 conn.consecutive_errors.store(0, Ordering::Relaxed);
1318 }
1319 Err(_) => {
1320 conn.error_count.fetch_add(1, Ordering::Relaxed);
1321 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1322 self.total_errors.fetch_add(1, Ordering::Relaxed);
1323 if consecutive >= 3 {
1324 conn.healthy_cached.store(false, Ordering::Release);
1325 }
1326 }
1327 }
1328
1329 result
1330 }
1331
1332 pub async fn send_response_body_chunk(
1337 &self,
1338 agent_id: &str,
1339 correlation_id: &str,
1340 event: &ResponseBodyChunkEvent,
1341 ) -> Result<AgentResponse, AgentProtocolError> {
1342 self.total_requests.fetch_add(1, Ordering::Relaxed);
1343
1344 let conn = self.select_connection(agent_id)?;
1345
1346 match self.check_flow_control(&conn, agent_id).await {
1348 Ok(true) => {} Ok(false) => {
1350 return Ok(AgentResponse::default_allow());
1352 }
1353 Err(e) => return Err(e),
1354 }
1355
1356 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1357 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1358 })?;
1359
1360 conn.in_flight.fetch_add(1, Ordering::Relaxed);
1361 conn.touch();
1362
1363 let result = conn
1364 .client
1365 .send_response_body_chunk(correlation_id, event)
1366 .await;
1367
1368 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1369 conn.request_count.fetch_add(1, Ordering::Relaxed);
1370
1371 match &result {
1372 Ok(_) => {
1373 conn.consecutive_errors.store(0, Ordering::Relaxed);
1374 }
1375 Err(_) => {
1376 conn.error_count.fetch_add(1, Ordering::Relaxed);
1377 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1378 self.total_errors.fetch_add(1, Ordering::Relaxed);
1379 if consecutive >= 3 {
1380 conn.healthy_cached.store(false, Ordering::Release);
1381 }
1382 }
1383 }
1384
1385 result
1386 }
1387
1388 pub async fn cancel_request(
1390 &self,
1391 agent_id: &str,
1392 correlation_id: &str,
1393 reason: CancelReason,
1394 ) -> Result<(), AgentProtocolError> {
1395 let entry = self.agents.get(agent_id).ok_or_else(|| {
1396 AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
1397 })?;
1398
1399 let connections = entry.connections.read().await;
1400 for conn in connections.iter() {
1401 let _ = conn.client.cancel_request(correlation_id, reason).await;
1402 }
1403
1404 Ok(())
1405 }
1406
1407 pub async fn stats(&self) -> Vec<AgentPoolStats> {
1409 let mut stats = Vec::with_capacity(self.agents.len());
1410
1411 for entry_ref in self.agents.iter() {
1412 let agent_id = entry_ref.key().clone();
1413 let entry = entry_ref.value();
1414
1415 let connections = entry.connections.read().await;
1416 let mut healthy_count = 0;
1417 let mut total_in_flight = 0;
1418 let mut total_requests = 0;
1419 let mut total_errors = 0;
1420
1421 for conn in connections.iter() {
1422 if conn.is_healthy_cached() {
1424 healthy_count += 1;
1425 }
1426 total_in_flight += conn.in_flight();
1427 total_requests += conn.request_count.load(Ordering::Relaxed);
1428 total_errors += conn.error_count.load(Ordering::Relaxed);
1429 }
1430
1431 let error_rate = if total_requests == 0 {
1432 0.0
1433 } else {
1434 total_errors as f64 / total_requests as f64
1435 };
1436
1437 stats.push(AgentPoolStats {
1438 agent_id,
1439 active_connections: connections.len(),
1440 healthy_connections: healthy_count,
1441 total_in_flight,
1442 total_requests,
1443 total_errors,
1444 error_rate,
1445 is_healthy: entry.healthy.load(Ordering::Acquire),
1446 });
1447 }
1448
1449 stats
1450 }
1451
1452 pub async fn agent_stats(&self, agent_id: &str) -> Option<AgentPoolStats> {
1454 self.stats()
1455 .await
1456 .into_iter()
1457 .find(|s| s.agent_id == agent_id)
1458 }
1459
1460 pub async fn agent_capabilities(&self, agent_id: &str) -> Option<AgentCapabilities> {
1462 let entry = match self.agents.get(agent_id) {
1464 Some(entry_ref) => Arc::clone(&*entry_ref),
1465 None => return None,
1466 };
1467 let result = entry.capabilities.read().await.clone();
1469 result
1470 }
1471
1472 pub fn is_agent_healthy(&self, agent_id: &str) -> bool {
1476 self.agents
1477 .get(agent_id)
1478 .map(|e| e.healthy.load(Ordering::Acquire))
1479 .unwrap_or(false)
1480 }
1481
1482 pub fn agent_ids(&self) -> Vec<String> {
1484 self.agents.iter().map(|e| e.key().clone()).collect()
1485 }
1486
1487 pub async fn shutdown(&self) -> Result<(), AgentProtocolError> {
1491 info!("Shutting down agent pool");
1492
1493 let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();
1495
1496 for agent_id in agent_ids {
1497 if let Some((_, entry)) = self.agents.remove(&agent_id) {
1498 debug!(agent_id = %agent_id, "Draining agent connections");
1499
1500 let connections = entry.connections.read().await;
1501 for conn in connections.iter() {
1502 let _ = conn.client.cancel_all(CancelReason::ProxyShutdown).await;
1504 }
1505
1506 let drain_deadline = Instant::now() + self.config.drain_timeout;
1508 loop {
1509 let total_in_flight: u64 = connections.iter().map(|c| c.in_flight()).sum();
1510 if total_in_flight == 0 {
1511 break;
1512 }
1513 if Instant::now() > drain_deadline {
1514 warn!(
1515 agent_id = %agent_id,
1516 in_flight = total_in_flight,
1517 "Drain timeout, forcing close"
1518 );
1519 break;
1520 }
1521 tokio::time::sleep(Duration::from_millis(100)).await;
1522 }
1523
1524 for conn in connections.iter() {
1526 let _ = conn.client.close().await;
1527 }
1528 }
1529 }
1530
1531 info!("Agent pool shutdown complete");
1532 Ok(())
1533 }
1534
1535 pub async fn run_maintenance(&self) {
1547 let mut interval = tokio::time::interval(self.config.health_check_interval);
1548
1549 loop {
1550 interval.tick().await;
1551
1552 self.cleanup_expired_sessions();
1554
1555 let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();
1557
1558 for agent_id in agent_ids {
1559 let Some(entry_ref) = self.agents.get(&agent_id) else {
1560 continue; };
1562 let entry = entry_ref.value().clone();
1563 drop(entry_ref); let connections = entry.connections.read().await;
1567 let mut healthy_count = 0;
1568
1569 for conn in connections.iter() {
1570 if conn.check_and_update_health().await {
1572 healthy_count += 1;
1573 }
1574 }
1575
1576 let was_healthy = entry.healthy.load(Ordering::Acquire);
1578 let is_healthy = healthy_count > 0;
1579 entry.healthy.store(is_healthy, Ordering::Release);
1580
1581 if was_healthy && !is_healthy {
1582 warn!(agent_id = %agent_id, "Agent marked unhealthy");
1583 } else if !was_healthy && is_healthy {
1584 info!(agent_id = %agent_id, "Agent recovered");
1585 }
1586
1587 if healthy_count < self.config.connections_per_agent
1589 && entry.should_reconnect(self.config.reconnect_interval)
1590 {
1591 drop(connections); if let Err(e) = self.reconnect_agent(&agent_id, &entry).await {
1593 trace!(agent_id = %agent_id, error = %e, "Reconnect failed");
1594 }
1595 }
1596 }
1597 }
1598 }
1599
1600 async fn create_connection(
1605 &self,
1606 agent_id: &str,
1607 endpoint: &str,
1608 ) -> Result<PooledConnection, AgentProtocolError> {
1609 let transport = if is_uds_endpoint(endpoint) {
1611 let socket_path = endpoint.strip_prefix("unix:").unwrap_or(endpoint);
1613
1614 let mut client =
1615 AgentClientV2Uds::new(agent_id, socket_path, self.config.request_timeout).await?;
1616
1617 client.set_metrics_callback(Arc::clone(&self.metrics_callback));
1619 client.set_config_update_callback(Arc::clone(&self.config_update_callback));
1620
1621 client.connect().await?;
1622 V2Transport::Uds(client)
1623 } else {
1624 let mut client =
1626 AgentClientV2::new(agent_id, endpoint, self.config.request_timeout).await?;
1627
1628 client.set_metrics_callback(Arc::clone(&self.metrics_callback));
1630 client.set_config_update_callback(Arc::clone(&self.config_update_callback));
1631
1632 client.connect().await?;
1633 V2Transport::Grpc(client)
1634 };
1635
1636 Ok(PooledConnection::new(
1637 transport,
1638 self.config.max_concurrent_per_connection,
1639 ))
1640 }
1641
1642 fn select_connection(
1656 &self,
1657 agent_id: &str,
1658 ) -> Result<Arc<PooledConnection>, AgentProtocolError> {
1659 let entry = self.agents.get(agent_id).ok_or_else(|| {
1660 AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
1661 })?;
1662
1663 let connections_guard = match entry.connections.try_read() {
1665 Ok(guard) => guard,
1666 Err(_) => {
1667 trace!(agent_id = %agent_id, "select_connection: blocking on connections lock");
1669 futures::executor::block_on(entry.connections.read())
1670 }
1671 };
1672
1673 if connections_guard.is_empty() {
1674 return Err(AgentProtocolError::ConnectionFailed(format!(
1675 "No connections available for agent {}",
1676 agent_id
1677 )));
1678 }
1679
1680 let healthy: Vec<_> = connections_guard
1682 .iter()
1683 .filter(|c| c.is_healthy_cached())
1684 .cloned()
1685 .collect();
1686
1687 if healthy.is_empty() {
1688 return Err(AgentProtocolError::ConnectionFailed(format!(
1689 "No healthy connections for agent {}",
1690 agent_id
1691 )));
1692 }
1693
1694 let selected = match self.config.load_balance_strategy {
1695 LoadBalanceStrategy::RoundRobin => {
1696 let idx = entry.round_robin_index.fetch_add(1, Ordering::Relaxed);
1697 healthy[idx % healthy.len()].clone()
1698 }
1699 LoadBalanceStrategy::LeastConnections => healthy
1700 .iter()
1701 .min_by_key(|c| c.in_flight())
1702 .cloned()
1703 .unwrap(),
1704 LoadBalanceStrategy::HealthBased => {
1705 healthy
1707 .iter()
1708 .min_by(|a, b| {
1709 a.error_rate()
1710 .partial_cmp(&b.error_rate())
1711 .unwrap_or(std::cmp::Ordering::Equal)
1712 })
1713 .cloned()
1714 .unwrap()
1715 }
1716 LoadBalanceStrategy::Random => {
1717 use std::collections::hash_map::RandomState;
1718 use std::hash::{BuildHasher, Hasher};
1719 let idx = RandomState::new().build_hasher().finish() as usize % healthy.len();
1720 healthy[idx].clone()
1721 }
1722 };
1723
1724 Ok(selected)
1725 }
1726
1727 async fn reconnect_agent(
1728 &self,
1729 agent_id: &str,
1730 entry: &AgentEntry,
1731 ) -> Result<(), AgentProtocolError> {
1732 entry.mark_reconnect_attempt();
1733 let attempts = entry.reconnect_attempts.fetch_add(1, Ordering::Relaxed);
1734
1735 if attempts >= self.config.max_reconnect_attempts {
1736 debug!(
1737 agent_id = %agent_id,
1738 attempts = attempts,
1739 "Max reconnect attempts reached"
1740 );
1741 return Ok(());
1742 }
1743
1744 debug!(agent_id = %agent_id, attempt = attempts + 1, "Attempting reconnect");
1745
1746 match self.create_connection(agent_id, &entry.endpoint).await {
1747 Ok(conn) => {
1748 let mut connections = entry.connections.write().await;
1749 connections.push(Arc::new(conn));
1750 entry.reconnect_attempts.store(0, Ordering::Relaxed);
1751 info!(agent_id = %agent_id, "Reconnected successfully");
1752 Ok(())
1753 }
1754 Err(e) => {
1755 debug!(agent_id = %agent_id, error = %e, "Reconnect failed");
1756 Err(e)
1757 }
1758 }
1759 }
1760}
1761
1762impl Default for AgentPool {
1763 fn default() -> Self {
1764 Self::new()
1765 }
1766}
1767
1768impl std::fmt::Debug for AgentPool {
1769 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1770 f.debug_struct("AgentPool")
1771 .field("config", &self.config)
1772 .field(
1773 "total_requests",
1774 &self.total_requests.load(Ordering::Relaxed),
1775 )
1776 .field("total_errors", &self.total_errors.load(Ordering::Relaxed))
1777 .finish()
1778 }
1779}
1780
/// Returns `true` if `endpoint` should be reached over a Unix domain socket.
///
/// Any one of these qualifies:
/// - an explicit `unix:` prefix;
/// - an absolute filesystem path (leading `/`);
/// - a `.sock` suffix.
///
/// Everything else is treated as a network (gRPC) endpoint.
fn is_uds_endpoint(endpoint: &str) -> bool {
    let has_unix_scheme = endpoint.starts_with("unix:");
    let is_absolute_path = endpoint.starts_with('/');
    let has_socket_suffix = endpoint.ends_with(".sock");
    has_unix_scheme || is_absolute_path || has_socket_suffix
}
1790
#[cfg(test)]
mod tests {
    //! Unit tests for pool configuration defaults, empty-pool behaviour, and endpoint
    //! classification. None of these tests open real agent connections.

    use super::*;

    #[test]
    fn test_pool_config_default() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.connections_per_agent, 4);
        assert_eq!(
            config.load_balance_strategy,
            LoadBalanceStrategy::RoundRobin
        );
    }

    #[test]
    fn test_load_balance_strategy() {
        assert_eq!(
            LoadBalanceStrategy::default(),
            LoadBalanceStrategy::RoundRobin
        );
    }

    #[test]
    fn test_pool_creation() {
        let pool = AgentPool::new();
        assert_eq!(pool.total_requests.load(Ordering::Relaxed), 0);
        assert_eq!(pool.total_errors.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_pool_with_config() {
        let config = AgentPoolConfig {
            connections_per_agent: 8,
            load_balance_strategy: LoadBalanceStrategy::LeastConnections,
            ..Default::default()
        };
        let pool = AgentPool::with_config(config.clone());
        assert_eq!(pool.config.connections_per_agent, 8);
    }

    #[test]
    fn test_agent_ids_empty() {
        let pool = AgentPool::new();
        assert!(pool.agent_ids().is_empty());
    }

    #[test]
    fn test_is_agent_healthy_not_found() {
        let pool = AgentPool::new();
        assert!(!pool.is_agent_healthy("nonexistent"));
    }

    #[tokio::test]
    async fn test_stats_empty() {
        let pool = AgentPool::new();
        assert!(pool.stats().await.is_empty());
    }

    #[tokio::test]
    async fn test_agent_stats_not_found() {
        let pool = AgentPool::new();
        assert!(pool.agent_stats("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_agent_capabilities_not_found() {
        let pool = AgentPool::new();
        assert!(pool.agent_capabilities("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_cancel_request_unknown_agent() {
        let pool = AgentPool::new();
        let result = pool
            .cancel_request("nonexistent", "corr-1", CancelReason::ProxyShutdown)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_shutdown_empty_pool() {
        let pool = AgentPool::new();
        assert!(pool.shutdown().await.is_ok());
        assert!(pool.agent_ids().is_empty());
    }

    #[test]
    fn test_select_connection_unknown_agent() {
        let pool = AgentPool::new();
        assert!(pool.select_connection("nonexistent").is_err());
    }

    #[test]
    fn test_is_uds_endpoint() {
        // Explicit `unix:` scheme.
        assert!(is_uds_endpoint("unix:/var/run/agent.sock"));
        assert!(is_uds_endpoint("unix:agent.sock"));

        // Absolute paths.
        assert!(is_uds_endpoint("/var/run/agent.sock"));
        assert!(is_uds_endpoint("/tmp/test.sock"));

        // Relative path with a `.sock` suffix.
        assert!(is_uds_endpoint("agent.sock"));

        // Network endpoints and empty input.
        assert!(!is_uds_endpoint("http://localhost:8080"));
        assert!(!is_uds_endpoint("localhost:50051"));
        assert!(!is_uds_endpoint("127.0.0.1:8080"));
        assert!(!is_uds_endpoint(""));
    }

    #[test]
    fn test_flow_control_mode_default() {
        assert_eq!(FlowControlMode::default(), FlowControlMode::FailClosed);
    }

    #[test]
    fn test_pool_config_flow_control_defaults() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.channel_buffer_size, CHANNEL_BUFFER_SIZE);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailClosed);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(100));
    }

    #[test]
    fn test_pool_config_custom_flow_control() {
        let config = AgentPoolConfig {
            channel_buffer_size: 128,
            flow_control_mode: FlowControlMode::FailOpen,
            flow_control_wait_timeout: Duration::from_millis(500),
            ..Default::default()
        };
        assert_eq!(config.channel_buffer_size, 128);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailOpen);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(500));
    }

    #[test]
    fn test_pool_config_wait_and_retry() {
        let config = AgentPoolConfig {
            flow_control_mode: FlowControlMode::WaitAndRetry,
            flow_control_wait_timeout: Duration::from_millis(250),
            ..Default::default()
        };
        assert_eq!(config.flow_control_mode, FlowControlMode::WaitAndRetry);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(250));
    }

    #[test]
    fn test_pool_config_sticky_session_default() {
        let config = AgentPoolConfig::default();
        assert_eq!(
            config.sticky_session_timeout,
            Some(Duration::from_secs(5 * 60))
        );
    }

    #[test]
    fn test_pool_config_sticky_session_custom() {
        let config = AgentPoolConfig {
            sticky_session_timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        assert_eq!(config.sticky_session_timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_pool_config_sticky_session_disabled() {
        let config = AgentPoolConfig {
            sticky_session_timeout: None,
            ..Default::default()
        };
        assert!(config.sticky_session_timeout.is_none());
    }

    #[test]
    fn test_sticky_session_count_empty() {
        let pool = AgentPool::new();
        assert_eq!(pool.sticky_session_count(), 0);
    }

    #[test]
    fn test_sticky_session_has_nonexistent() {
        let pool = AgentPool::new();
        assert!(!pool.has_sticky_session("nonexistent"));
    }

    #[test]
    fn test_sticky_session_clear_nonexistent() {
        let pool = AgentPool::new();
        // Clearing an unknown session must be a harmless no-op.
        pool.clear_sticky_session("nonexistent");
        assert_eq!(pool.sticky_session_count(), 0);
        assert!(!pool.has_sticky_session("nonexistent"));
    }

    #[test]
    fn test_cleanup_expired_sessions_empty() {
        let pool = AgentPool::new();
        let removed = pool.cleanup_expired_sessions();
        assert_eq!(removed, 0);
    }
}