1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use pingora_timeout::timeout;
9use sentinel_agent_protocol::{
10 AgentResponse, EventType, RequestBodyChunkEvent, RequestHeadersEvent, ResponseBodyChunkEvent,
11 ResponseHeadersEvent, WebSocketFrameEvent,
12};
13use sentinel_common::{
14 errors::{SentinelError, SentinelResult},
15 types::CircuitBreakerConfig,
16 CircuitBreaker,
17};
18use sentinel_config::{AgentConfig, FailureMode};
19use tokio::sync::{RwLock, Semaphore};
20use tracing::{debug, error, info, trace, warn};
21
22use super::agent::Agent;
23use super::context::AgentCallContext;
24use super::decision::AgentDecision;
25use super::metrics::AgentMetrics;
26use super::pool::AgentConnectionPool;
27
/// Owns the configured agents together with their per-agent resources and
/// dispatches proxy events (request/response headers, body chunks and
/// WebSocket frames) to them.
pub struct AgentManager {
    /// Agents keyed by their configured id.
    agents: Arc<RwLock<HashMap<String, Arc<Agent>>>>,
    /// Connection pool of each agent, keyed by agent id.
    connection_pools: Arc<RwLock<HashMap<String, Arc<AgentConnectionPool>>>>,
    /// Circuit breaker of each agent, keyed by agent id.
    circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
    /// Metrics handle handed out by `metrics()`.
    metrics: Arc<AgentMetrics>,
    // The value is never read back after construction (the limit is applied
    // through `call_semaphore`), which is why dead-code warnings are silenced.
    #[allow(dead_code)]
    max_concurrent_calls: usize,
    /// Semaphore created with `max_concurrent_calls` permits; a permit is
    /// acquired around each agent call made by the event-processing helpers.
    call_semaphore: Arc<Semaphore>,
}
44
45impl AgentManager {
46 pub async fn new(
48 agents: Vec<AgentConfig>,
49 max_concurrent_calls: usize,
50 ) -> SentinelResult<Self> {
51 info!(
52 agent_count = agents.len(),
53 max_concurrent_calls = max_concurrent_calls,
54 "Creating agent manager"
55 );
56
57 let mut agent_map = HashMap::new();
58 let mut pools = HashMap::new();
59 let mut breakers = HashMap::new();
60
61 for config in agents {
62 debug!(
63 agent_id = %config.id,
64 transport = ?config.transport,
65 timeout_ms = config.timeout_ms,
66 failure_mode = ?config.failure_mode,
67 "Configuring agent"
68 );
69
70 let pool = Arc::new(AgentConnectionPool::new(
71 10, 2, 5, Duration::from_secs(60),
75 ));
76
77 let circuit_breaker = Arc::new(CircuitBreaker::new(
78 config
79 .circuit_breaker
80 .clone()
81 .unwrap_or_else(CircuitBreakerConfig::default),
82 ));
83
84 trace!(
85 agent_id = %config.id,
86 "Creating agent instance"
87 );
88
89 let agent = Arc::new(Agent::new(
90 config.clone(),
91 Arc::clone(&pool),
92 Arc::clone(&circuit_breaker),
93 ));
94
95 agent_map.insert(config.id.clone(), agent);
96 pools.insert(config.id.clone(), pool);
97 breakers.insert(config.id.clone(), circuit_breaker);
98
99 debug!(
100 agent_id = %config.id,
101 "Agent configured successfully"
102 );
103 }
104
105 info!(
106 configured_agents = agent_map.len(),
107 "Agent manager created successfully"
108 );
109
110 Ok(Self {
111 agents: Arc::new(RwLock::new(agent_map)),
112 connection_pools: Arc::new(RwLock::new(pools)),
113 circuit_breakers: Arc::new(RwLock::new(breakers)),
114 metrics: Arc::new(AgentMetrics::default()),
115 max_concurrent_calls,
116 call_semaphore: Arc::new(Semaphore::new(max_concurrent_calls)),
117 })
118 }
119
120 pub async fn process_request_headers(
127 &self,
128 ctx: &AgentCallContext,
129 headers: &HashMap<String, Vec<String>>,
130 route_agents: &[(String, FailureMode)],
131 ) -> SentinelResult<AgentDecision> {
132 let event = RequestHeadersEvent {
133 metadata: ctx.metadata.clone(),
134 method: headers
135 .get(":method")
136 .and_then(|v| v.first())
137 .unwrap_or(&"GET".to_string())
138 .clone(),
139 uri: headers
140 .get(":path")
141 .and_then(|v| v.first())
142 .unwrap_or(&"/".to_string())
143 .clone(),
144 headers: headers.clone(),
145 };
146
147 self.process_event_with_failure_modes(EventType::RequestHeaders, &event, route_agents, ctx)
148 .await
149 }
150
151 pub async fn process_request_body(
153 &self,
154 ctx: &AgentCallContext,
155 data: &[u8],
156 is_last: bool,
157 route_agents: &[String],
158 ) -> SentinelResult<AgentDecision> {
159 let max_size = 1024 * 1024; if data.len() > max_size {
162 warn!(
163 correlation_id = %ctx.correlation_id,
164 size = data.len(),
165 "Request body exceeds agent inspection limit"
166 );
167 return Ok(AgentDecision::default_allow());
168 }
169
170 let event = RequestBodyChunkEvent {
171 correlation_id: ctx.correlation_id.to_string(),
172 data: STANDARD.encode(data),
173 is_last,
174 total_size: ctx.request_body.as_ref().map(|b| b.len()),
175 chunk_index: 0, bytes_received: data.len(),
177 };
178
179 self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
180 .await
181 }
182
183 pub async fn process_request_body_streaming(
188 &self,
189 ctx: &AgentCallContext,
190 data: &[u8],
191 is_last: bool,
192 chunk_index: u32,
193 bytes_received: usize,
194 total_size: Option<usize>,
195 route_agents: &[String],
196 ) -> SentinelResult<AgentDecision> {
197 trace!(
198 correlation_id = %ctx.correlation_id,
199 chunk_index = chunk_index,
200 chunk_size = data.len(),
201 bytes_received = bytes_received,
202 is_last = is_last,
203 "Processing streaming request body chunk"
204 );
205
206 let event = RequestBodyChunkEvent {
207 correlation_id: ctx.correlation_id.to_string(),
208 data: STANDARD.encode(data),
209 is_last,
210 total_size,
211 chunk_index,
212 bytes_received,
213 };
214
215 self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
216 .await
217 }
218
219 pub async fn process_response_body_streaming(
221 &self,
222 ctx: &AgentCallContext,
223 data: &[u8],
224 is_last: bool,
225 chunk_index: u32,
226 bytes_sent: usize,
227 total_size: Option<usize>,
228 route_agents: &[String],
229 ) -> SentinelResult<AgentDecision> {
230 trace!(
231 correlation_id = %ctx.correlation_id,
232 chunk_index = chunk_index,
233 chunk_size = data.len(),
234 bytes_sent = bytes_sent,
235 is_last = is_last,
236 "Processing streaming response body chunk"
237 );
238
239 let event = ResponseBodyChunkEvent {
240 correlation_id: ctx.correlation_id.to_string(),
241 data: STANDARD.encode(data),
242 is_last,
243 total_size,
244 chunk_index,
245 bytes_sent,
246 };
247
248 self.process_event(EventType::ResponseBodyChunk, &event, route_agents, ctx)
249 .await
250 }
251
252 pub async fn process_response_headers(
254 &self,
255 ctx: &AgentCallContext,
256 status: u16,
257 headers: &HashMap<String, Vec<String>>,
258 route_agents: &[String],
259 ) -> SentinelResult<AgentDecision> {
260 let event = ResponseHeadersEvent {
261 correlation_id: ctx.correlation_id.to_string(),
262 status,
263 headers: headers.clone(),
264 };
265
266 self.process_event(EventType::ResponseHeaders, &event, route_agents, ctx)
267 .await
268 }
269
270 pub async fn process_websocket_frame(
276 &self,
277 route_id: &str,
278 event: WebSocketFrameEvent,
279 ) -> SentinelResult<AgentResponse> {
280 trace!(
281 correlation_id = %event.correlation_id,
282 route_id = %route_id,
283 frame_index = event.frame_index,
284 opcode = %event.opcode,
285 "Processing WebSocket frame through agents"
286 );
287
288 let agents = self.agents.read().await;
290 let relevant_agents: Vec<_> = agents
291 .values()
292 .filter(|agent| agent.handles_event(EventType::WebSocketFrame))
293 .collect();
294
295 if relevant_agents.is_empty() {
296 trace!(
297 correlation_id = %event.correlation_id,
298 "No agents handle WebSocket frames, allowing"
299 );
300 return Ok(AgentResponse::websocket_allow());
301 }
302
303 debug!(
304 correlation_id = %event.correlation_id,
305 route_id = %route_id,
306 agent_count = relevant_agents.len(),
307 "Processing WebSocket frame through agents"
308 );
309
310 for agent in relevant_agents {
312 if !agent.circuit_breaker().is_closed().await {
314 warn!(
315 agent_id = %agent.id(),
316 correlation_id = %event.correlation_id,
317 failure_mode = ?agent.failure_mode(),
318 "Circuit breaker open, skipping agent for WebSocket frame"
319 );
320
321 if agent.failure_mode() == FailureMode::Closed {
322 debug!(
323 correlation_id = %event.correlation_id,
324 agent_id = %agent.id(),
325 "Closing WebSocket due to circuit breaker (fail-closed mode)"
326 );
327 return Ok(AgentResponse::websocket_close(
328 1011,
329 "Service unavailable".to_string(),
330 ));
331 }
332 continue;
333 }
334
335 let start = Instant::now();
337 let timeout_duration = Duration::from_millis(agent.timeout_ms());
338
339 match timeout(
340 timeout_duration,
341 agent.call_event(EventType::WebSocketFrame, &event),
342 )
343 .await
344 {
345 Ok(Ok(response)) => {
346 let duration = start.elapsed();
347 agent.record_success(duration).await;
348
349 trace!(
350 correlation_id = %event.correlation_id,
351 agent_id = %agent.id(),
352 duration_ms = duration.as_millis(),
353 "WebSocket frame agent call succeeded"
354 );
355
356 if let Some(ref ws_decision) = response.websocket_decision {
358 if !matches!(
359 ws_decision,
360 sentinel_agent_protocol::WebSocketDecision::Allow
361 ) {
362 debug!(
363 correlation_id = %event.correlation_id,
364 agent_id = %agent.id(),
365 decision = ?ws_decision,
366 "Agent returned non-allow WebSocket decision"
367 );
368 return Ok(response);
369 }
370 }
371 }
372 Ok(Err(e)) => {
373 agent.record_failure().await;
374 error!(
375 agent_id = %agent.id(),
376 correlation_id = %event.correlation_id,
377 error = %e,
378 duration_ms = start.elapsed().as_millis(),
379 failure_mode = ?agent.failure_mode(),
380 "WebSocket frame agent call failed"
381 );
382
383 if agent.failure_mode() == FailureMode::Closed {
384 return Ok(AgentResponse::websocket_close(
385 1011,
386 "Agent error".to_string(),
387 ));
388 }
389 }
390 Err(_) => {
391 agent.record_timeout().await;
392 warn!(
393 agent_id = %agent.id(),
394 correlation_id = %event.correlation_id,
395 timeout_ms = agent.timeout_ms(),
396 failure_mode = ?agent.failure_mode(),
397 "WebSocket frame agent call timed out"
398 );
399
400 if agent.failure_mode() == FailureMode::Closed {
401 return Ok(AgentResponse::websocket_close(
402 1011,
403 "Gateway timeout".to_string(),
404 ));
405 }
406 }
407 }
408 }
409
410 Ok(AgentResponse::websocket_allow())
412 }
413
414 async fn process_event<T: serde::Serialize>(
416 &self,
417 event_type: EventType,
418 event: &T,
419 route_agents: &[String],
420 ctx: &AgentCallContext,
421 ) -> SentinelResult<AgentDecision> {
422 trace!(
423 correlation_id = %ctx.correlation_id,
424 event_type = ?event_type,
425 route_agents = ?route_agents,
426 "Starting agent event processing"
427 );
428
429 let agents = self.agents.read().await;
431 let relevant_agents: Vec<_> = route_agents
432 .iter()
433 .filter_map(|id| agents.get(id))
434 .filter(|agent| agent.handles_event(event_type))
435 .collect();
436
437 if relevant_agents.is_empty() {
438 trace!(
439 correlation_id = %ctx.correlation_id,
440 event_type = ?event_type,
441 "No relevant agents for event, allowing request"
442 );
443 return Ok(AgentDecision::default_allow());
444 }
445
446 debug!(
447 correlation_id = %ctx.correlation_id,
448 event_type = ?event_type,
449 agent_count = relevant_agents.len(),
450 agent_ids = ?relevant_agents.iter().map(|a| a.id()).collect::<Vec<_>>(),
451 "Processing event through agents"
452 );
453
454 let mut combined_decision = AgentDecision::default_allow();
456
457 for (agent_index, agent) in relevant_agents.iter().enumerate() {
458 trace!(
459 correlation_id = %ctx.correlation_id,
460 agent_id = %agent.id(),
461 agent_index = agent_index,
462 event_type = ?event_type,
463 "Processing event through agent"
464 );
465
466 trace!(
468 correlation_id = %ctx.correlation_id,
469 agent_id = %agent.id(),
470 "Acquiring agent call semaphore permit"
471 );
472 let _permit = self.call_semaphore.acquire().await.map_err(|_| {
473 error!(
474 correlation_id = %ctx.correlation_id,
475 agent_id = %agent.id(),
476 "Failed to acquire agent call semaphore permit"
477 );
478 SentinelError::Internal {
479 message: "Failed to acquire agent call permit".to_string(),
480 correlation_id: Some(ctx.correlation_id.to_string()),
481 source: None,
482 }
483 })?;
484
485 if !agent.circuit_breaker().is_closed().await {
487 warn!(
488 agent_id = %agent.id(),
489 correlation_id = %ctx.correlation_id,
490 failure_mode = ?agent.failure_mode(),
491 "Circuit breaker open, skipping agent"
492 );
493
494 if agent.failure_mode() == FailureMode::Closed {
496 debug!(
497 correlation_id = %ctx.correlation_id,
498 agent_id = %agent.id(),
499 "Blocking request due to circuit breaker (fail-closed mode)"
500 );
501 return Ok(AgentDecision::block(503, "Service unavailable"));
502 }
503 continue;
504 }
505
506 let start = Instant::now();
508 let timeout_duration = Duration::from_millis(agent.timeout_ms());
509
510 trace!(
511 correlation_id = %ctx.correlation_id,
512 agent_id = %agent.id(),
513 timeout_ms = agent.timeout_ms(),
514 "Calling agent"
515 );
516
517 match timeout(timeout_duration, agent.call_event(event_type, event)).await {
518 Ok(Ok(response)) => {
519 let duration = start.elapsed();
520 agent.record_success(duration).await;
521
522 trace!(
523 correlation_id = %ctx.correlation_id,
524 agent_id = %agent.id(),
525 duration_ms = duration.as_millis(),
526 decision = ?response,
527 "Agent call succeeded"
528 );
529
530 combined_decision.merge(response.into());
532
533 if !combined_decision.is_allow() {
535 debug!(
536 correlation_id = %ctx.correlation_id,
537 agent_id = %agent.id(),
538 decision = ?combined_decision,
539 "Agent returned blocking decision, stopping agent chain"
540 );
541 break;
542 }
543 }
544 Ok(Err(e)) => {
545 agent.record_failure().await;
546 error!(
547 agent_id = %agent.id(),
548 correlation_id = %ctx.correlation_id,
549 error = %e,
550 duration_ms = start.elapsed().as_millis(),
551 failure_mode = ?agent.failure_mode(),
552 "Agent call failed"
553 );
554
555 if agent.failure_mode() == FailureMode::Closed {
556 return Err(e);
557 }
558 }
559 Err(_) => {
560 agent.record_timeout().await;
561 warn!(
562 agent_id = %agent.id(),
563 correlation_id = %ctx.correlation_id,
564 timeout_ms = agent.timeout_ms(),
565 failure_mode = ?agent.failure_mode(),
566 "Agent call timed out"
567 );
568
569 if agent.failure_mode() == FailureMode::Closed {
570 debug!(
571 correlation_id = %ctx.correlation_id,
572 agent_id = %agent.id(),
573 "Blocking request due to timeout (fail-closed mode)"
574 );
575 return Ok(AgentDecision::block(504, "Gateway timeout"));
576 }
577 }
578 }
579 }
580
581 trace!(
582 correlation_id = %ctx.correlation_id,
583 decision = ?combined_decision,
584 agents_processed = relevant_agents.len(),
585 "Agent event processing completed"
586 );
587
588 Ok(combined_decision)
589 }
590
591 async fn process_event_with_failure_modes<T: serde::Serialize>(
596 &self,
597 event_type: EventType,
598 event: &T,
599 route_agents: &[(String, FailureMode)],
600 ctx: &AgentCallContext,
601 ) -> SentinelResult<AgentDecision> {
602 trace!(
603 correlation_id = %ctx.correlation_id,
604 event_type = ?event_type,
605 route_agents = ?route_agents.iter().map(|(id, _)| id).collect::<Vec<_>>(),
606 "Starting agent event processing with failure modes"
607 );
608
609 let agents = self.agents.read().await;
611 let relevant_agents: Vec<_> = route_agents
612 .iter()
613 .filter_map(|(id, failure_mode)| {
614 agents.get(id).map(|agent| (agent, *failure_mode))
615 })
616 .filter(|(agent, _)| agent.handles_event(event_type))
617 .collect();
618
619 if relevant_agents.is_empty() {
620 trace!(
621 correlation_id = %ctx.correlation_id,
622 event_type = ?event_type,
623 "No relevant agents for event, allowing request"
624 );
625 return Ok(AgentDecision::default_allow());
626 }
627
628 debug!(
629 correlation_id = %ctx.correlation_id,
630 event_type = ?event_type,
631 agent_count = relevant_agents.len(),
632 agent_ids = ?relevant_agents.iter().map(|(a, _)| a.id()).collect::<Vec<_>>(),
633 "Processing event through agents"
634 );
635
636 let mut combined_decision = AgentDecision::default_allow();
638
639 for (agent_index, (agent, filter_failure_mode)) in relevant_agents.iter().enumerate() {
640 trace!(
641 correlation_id = %ctx.correlation_id,
642 agent_id = %agent.id(),
643 agent_index = agent_index,
644 event_type = ?event_type,
645 filter_failure_mode = ?filter_failure_mode,
646 "Processing event through agent with filter failure mode"
647 );
648
649 let _permit = self.call_semaphore.acquire().await.map_err(|_| {
651 error!(
652 correlation_id = %ctx.correlation_id,
653 agent_id = %agent.id(),
654 "Failed to acquire agent call semaphore permit"
655 );
656 SentinelError::Internal {
657 message: "Failed to acquire agent call permit".to_string(),
658 correlation_id: Some(ctx.correlation_id.to_string()),
659 source: None,
660 }
661 })?;
662
663 if !agent.circuit_breaker().is_closed().await {
665 warn!(
666 agent_id = %agent.id(),
667 correlation_id = %ctx.correlation_id,
668 filter_failure_mode = ?filter_failure_mode,
669 "Circuit breaker open, skipping agent"
670 );
671
672 if *filter_failure_mode == FailureMode::Closed {
674 debug!(
675 correlation_id = %ctx.correlation_id,
676 agent_id = %agent.id(),
677 "Blocking request due to circuit breaker (filter fail-closed mode)"
678 );
679 return Ok(AgentDecision::block(503, "Service unavailable"));
680 }
681 continue;
683 }
684
685 let start = Instant::now();
687 let timeout_duration = Duration::from_millis(agent.timeout_ms());
688
689 trace!(
690 correlation_id = %ctx.correlation_id,
691 agent_id = %agent.id(),
692 timeout_ms = agent.timeout_ms(),
693 "Calling agent"
694 );
695
696 match timeout(timeout_duration, agent.call_event(event_type, event)).await {
697 Ok(Ok(response)) => {
698 let duration = start.elapsed();
699 agent.record_success(duration).await;
700
701 trace!(
702 correlation_id = %ctx.correlation_id,
703 agent_id = %agent.id(),
704 duration_ms = duration.as_millis(),
705 decision = ?response,
706 "Agent call succeeded"
707 );
708
709 combined_decision.merge(response.into());
711
712 if !combined_decision.is_allow() {
714 debug!(
715 correlation_id = %ctx.correlation_id,
716 agent_id = %agent.id(),
717 decision = ?combined_decision,
718 "Agent returned blocking decision, stopping agent chain"
719 );
720 break;
721 }
722 }
723 Ok(Err(e)) => {
724 agent.record_failure().await;
725 error!(
726 agent_id = %agent.id(),
727 correlation_id = %ctx.correlation_id,
728 error = %e,
729 duration_ms = start.elapsed().as_millis(),
730 filter_failure_mode = ?filter_failure_mode,
731 "Agent call failed"
732 );
733
734 if *filter_failure_mode == FailureMode::Closed {
736 debug!(
737 correlation_id = %ctx.correlation_id,
738 agent_id = %agent.id(),
739 "Blocking request due to agent failure (filter fail-closed mode)"
740 );
741 return Ok(AgentDecision::block(503, "Agent unavailable"));
742 }
743 debug!(
745 correlation_id = %ctx.correlation_id,
746 agent_id = %agent.id(),
747 "Continuing despite agent failure (filter fail-open mode)"
748 );
749 }
750 Err(_) => {
751 agent.record_timeout().await;
752 warn!(
753 agent_id = %agent.id(),
754 correlation_id = %ctx.correlation_id,
755 timeout_ms = agent.timeout_ms(),
756 filter_failure_mode = ?filter_failure_mode,
757 "Agent call timed out"
758 );
759
760 if *filter_failure_mode == FailureMode::Closed {
762 debug!(
763 correlation_id = %ctx.correlation_id,
764 agent_id = %agent.id(),
765 "Blocking request due to timeout (filter fail-closed mode)"
766 );
767 return Ok(AgentDecision::block(504, "Gateway timeout"));
768 }
769 debug!(
771 correlation_id = %ctx.correlation_id,
772 agent_id = %agent.id(),
773 "Continuing despite timeout (filter fail-open mode)"
774 );
775 }
776 }
777 }
778
779 trace!(
780 correlation_id = %ctx.correlation_id,
781 decision = ?combined_decision,
782 agents_processed = relevant_agents.len(),
783 "Agent event processing with failure modes completed"
784 );
785
786 Ok(combined_decision)
787 }
788
789 pub async fn initialize(&self) -> SentinelResult<()> {
791 let agents = self.agents.read().await;
792
793 info!(agent_count = agents.len(), "Initializing agent connections");
794
795 let mut initialized_count = 0;
796 let mut failed_count = 0;
797
798 for (id, agent) in agents.iter() {
799 debug!(agent_id = %id, "Initializing agent connection");
800 if let Err(e) = agent.initialize().await {
801 error!(
802 agent_id = %id,
803 error = %e,
804 "Failed to initialize agent"
805 );
806 failed_count += 1;
807 } else {
809 trace!(agent_id = %id, "Agent initialized successfully");
810 initialized_count += 1;
811 }
812 }
813
814 info!(
815 initialized = initialized_count,
816 failed = failed_count,
817 total = agents.len(),
818 "Agent initialization complete"
819 );
820
821 Ok(())
822 }
823
824 pub async fn shutdown(&self) {
826 let agents = self.agents.read().await;
827
828 info!(agent_count = agents.len(), "Shutting down agent manager");
829
830 for (id, agent) in agents.iter() {
831 debug!(agent_id = %id, "Shutting down agent");
832 agent.shutdown().await;
833 trace!(agent_id = %id, "Agent shutdown complete");
834 }
835
836 info!("Agent manager shutdown complete");
837 }
838
839 pub fn metrics(&self) -> &AgentMetrics {
841 &self.metrics
842 }
843
844 pub fn get_agents_for_event(&self, event_type: EventType) -> Vec<String> {
849 if let Ok(agents) = self.agents.try_read() {
852 agents
853 .values()
854 .filter(|agent| agent.handles_event(event_type))
855 .map(|agent| agent.id().to_string())
856 .collect()
857 } else {
858 Vec::new()
859 }
860 }
861}