1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use pingora_timeout::timeout;
9use sentinel_agent_protocol::{
10 AgentResponse, EventType, RequestBodyChunkEvent, RequestHeadersEvent, ResponseBodyChunkEvent,
11 ResponseHeadersEvent, WebSocketFrameEvent,
12};
13use sentinel_common::{
14 errors::{SentinelError, SentinelResult},
15 types::CircuitBreakerConfig,
16 CircuitBreaker,
17};
18use sentinel_config::{AgentConfig, FailureMode};
19use tokio::sync::{RwLock, Semaphore};
20use tracing::{debug, error, info, trace, warn};
21
22use super::agent::Agent;
23use super::context::AgentCallContext;
24use super::decision::AgentDecision;
25use super::metrics::AgentMetrics;
26use super::pool::AgentConnectionPool;
27
28pub struct AgentManager {
30 agents: Arc<RwLock<HashMap<String, Arc<Agent>>>>,
32 connection_pools: Arc<RwLock<HashMap<String, Arc<AgentConnectionPool>>>>,
34 circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
36 metrics: Arc<AgentMetrics>,
38 #[allow(dead_code)]
40 max_concurrent_calls: usize,
41 call_semaphore: Arc<Semaphore>,
43}
44
45impl AgentManager {
46 pub async fn new(
48 agents: Vec<AgentConfig>,
49 max_concurrent_calls: usize,
50 ) -> SentinelResult<Self> {
51 info!(
52 agent_count = agents.len(),
53 max_concurrent_calls = max_concurrent_calls,
54 "Creating agent manager"
55 );
56
57 let mut agent_map = HashMap::new();
58 let mut pools = HashMap::new();
59 let mut breakers = HashMap::new();
60
61 for config in agents {
62 debug!(
63 agent_id = %config.id,
64 transport = ?config.transport,
65 timeout_ms = config.timeout_ms,
66 failure_mode = ?config.failure_mode,
67 "Configuring agent"
68 );
69
70 let pool = Arc::new(AgentConnectionPool::new(
71 10, 2, 5, Duration::from_secs(60),
75 ));
76
77 let circuit_breaker = Arc::new(CircuitBreaker::new(
78 config
79 .circuit_breaker
80 .clone()
81 .unwrap_or_else(CircuitBreakerConfig::default),
82 ));
83
84 trace!(
85 agent_id = %config.id,
86 "Creating agent instance"
87 );
88
89 let agent = Arc::new(Agent::new(
90 config.clone(),
91 Arc::clone(&pool),
92 Arc::clone(&circuit_breaker),
93 ));
94
95 agent_map.insert(config.id.clone(), agent);
96 pools.insert(config.id.clone(), pool);
97 breakers.insert(config.id.clone(), circuit_breaker);
98
99 debug!(
100 agent_id = %config.id,
101 "Agent configured successfully"
102 );
103 }
104
105 info!(
106 configured_agents = agent_map.len(),
107 "Agent manager created successfully"
108 );
109
110 Ok(Self {
111 agents: Arc::new(RwLock::new(agent_map)),
112 connection_pools: Arc::new(RwLock::new(pools)),
113 circuit_breakers: Arc::new(RwLock::new(breakers)),
114 metrics: Arc::new(AgentMetrics::default()),
115 max_concurrent_calls,
116 call_semaphore: Arc::new(Semaphore::new(max_concurrent_calls)),
117 })
118 }
119
120 pub async fn process_request_headers(
122 &self,
123 ctx: &AgentCallContext,
124 headers: &HashMap<String, Vec<String>>,
125 route_agents: &[String],
126 ) -> SentinelResult<AgentDecision> {
127 let event = RequestHeadersEvent {
128 metadata: ctx.metadata.clone(),
129 method: headers
130 .get(":method")
131 .and_then(|v| v.first())
132 .unwrap_or(&"GET".to_string())
133 .clone(),
134 uri: headers
135 .get(":path")
136 .and_then(|v| v.first())
137 .unwrap_or(&"/".to_string())
138 .clone(),
139 headers: headers.clone(),
140 };
141
142 self.process_event(EventType::RequestHeaders, &event, route_agents, ctx)
143 .await
144 }
145
146 pub async fn process_request_body(
148 &self,
149 ctx: &AgentCallContext,
150 data: &[u8],
151 is_last: bool,
152 route_agents: &[String],
153 ) -> SentinelResult<AgentDecision> {
154 let max_size = 1024 * 1024; if data.len() > max_size {
157 warn!(
158 correlation_id = %ctx.correlation_id,
159 size = data.len(),
160 "Request body exceeds agent inspection limit"
161 );
162 return Ok(AgentDecision::default_allow());
163 }
164
165 let event = RequestBodyChunkEvent {
166 correlation_id: ctx.correlation_id.to_string(),
167 data: STANDARD.encode(data),
168 is_last,
169 total_size: ctx.request_body.as_ref().map(|b| b.len()),
170 chunk_index: 0, bytes_received: data.len(),
172 };
173
174 self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
175 .await
176 }
177
178 pub async fn process_request_body_streaming(
183 &self,
184 ctx: &AgentCallContext,
185 data: &[u8],
186 is_last: bool,
187 chunk_index: u32,
188 bytes_received: usize,
189 total_size: Option<usize>,
190 route_agents: &[String],
191 ) -> SentinelResult<AgentDecision> {
192 trace!(
193 correlation_id = %ctx.correlation_id,
194 chunk_index = chunk_index,
195 chunk_size = data.len(),
196 bytes_received = bytes_received,
197 is_last = is_last,
198 "Processing streaming request body chunk"
199 );
200
201 let event = RequestBodyChunkEvent {
202 correlation_id: ctx.correlation_id.to_string(),
203 data: STANDARD.encode(data),
204 is_last,
205 total_size,
206 chunk_index,
207 bytes_received,
208 };
209
210 self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
211 .await
212 }
213
214 pub async fn process_response_body_streaming(
216 &self,
217 ctx: &AgentCallContext,
218 data: &[u8],
219 is_last: bool,
220 chunk_index: u32,
221 bytes_sent: usize,
222 total_size: Option<usize>,
223 route_agents: &[String],
224 ) -> SentinelResult<AgentDecision> {
225 trace!(
226 correlation_id = %ctx.correlation_id,
227 chunk_index = chunk_index,
228 chunk_size = data.len(),
229 bytes_sent = bytes_sent,
230 is_last = is_last,
231 "Processing streaming response body chunk"
232 );
233
234 let event = ResponseBodyChunkEvent {
235 correlation_id: ctx.correlation_id.to_string(),
236 data: STANDARD.encode(data),
237 is_last,
238 total_size,
239 chunk_index,
240 bytes_sent,
241 };
242
243 self.process_event(EventType::ResponseBodyChunk, &event, route_agents, ctx)
244 .await
245 }
246
247 pub async fn process_response_headers(
249 &self,
250 ctx: &AgentCallContext,
251 status: u16,
252 headers: &HashMap<String, Vec<String>>,
253 route_agents: &[String],
254 ) -> SentinelResult<AgentDecision> {
255 let event = ResponseHeadersEvent {
256 correlation_id: ctx.correlation_id.to_string(),
257 status,
258 headers: headers.clone(),
259 };
260
261 self.process_event(EventType::ResponseHeaders, &event, route_agents, ctx)
262 .await
263 }
264
265 pub async fn process_websocket_frame(
271 &self,
272 route_id: &str,
273 event: WebSocketFrameEvent,
274 ) -> SentinelResult<AgentResponse> {
275 trace!(
276 correlation_id = %event.correlation_id,
277 route_id = %route_id,
278 frame_index = event.frame_index,
279 opcode = %event.opcode,
280 "Processing WebSocket frame through agents"
281 );
282
283 let agents = self.agents.read().await;
285 let relevant_agents: Vec<_> = agents
286 .values()
287 .filter(|agent| agent.handles_event(EventType::WebSocketFrame))
288 .collect();
289
290 if relevant_agents.is_empty() {
291 trace!(
292 correlation_id = %event.correlation_id,
293 "No agents handle WebSocket frames, allowing"
294 );
295 return Ok(AgentResponse::websocket_allow());
296 }
297
298 debug!(
299 correlation_id = %event.correlation_id,
300 route_id = %route_id,
301 agent_count = relevant_agents.len(),
302 "Processing WebSocket frame through agents"
303 );
304
305 for agent in relevant_agents {
307 if !agent.circuit_breaker().is_closed().await {
309 warn!(
310 agent_id = %agent.id(),
311 correlation_id = %event.correlation_id,
312 failure_mode = ?agent.failure_mode(),
313 "Circuit breaker open, skipping agent for WebSocket frame"
314 );
315
316 if agent.failure_mode() == FailureMode::Closed {
317 debug!(
318 correlation_id = %event.correlation_id,
319 agent_id = %agent.id(),
320 "Closing WebSocket due to circuit breaker (fail-closed mode)"
321 );
322 return Ok(AgentResponse::websocket_close(
323 1011,
324 "Service unavailable".to_string(),
325 ));
326 }
327 continue;
328 }
329
330 let start = Instant::now();
332 let timeout_duration = Duration::from_millis(agent.timeout_ms());
333
334 match timeout(
335 timeout_duration,
336 agent.call_event(EventType::WebSocketFrame, &event),
337 )
338 .await
339 {
340 Ok(Ok(response)) => {
341 let duration = start.elapsed();
342 agent.record_success(duration).await;
343
344 trace!(
345 correlation_id = %event.correlation_id,
346 agent_id = %agent.id(),
347 duration_ms = duration.as_millis(),
348 "WebSocket frame agent call succeeded"
349 );
350
351 if let Some(ref ws_decision) = response.websocket_decision {
353 if !matches!(
354 ws_decision,
355 sentinel_agent_protocol::WebSocketDecision::Allow
356 ) {
357 debug!(
358 correlation_id = %event.correlation_id,
359 agent_id = %agent.id(),
360 decision = ?ws_decision,
361 "Agent returned non-allow WebSocket decision"
362 );
363 return Ok(response);
364 }
365 }
366 }
367 Ok(Err(e)) => {
368 agent.record_failure().await;
369 error!(
370 agent_id = %agent.id(),
371 correlation_id = %event.correlation_id,
372 error = %e,
373 duration_ms = start.elapsed().as_millis(),
374 failure_mode = ?agent.failure_mode(),
375 "WebSocket frame agent call failed"
376 );
377
378 if agent.failure_mode() == FailureMode::Closed {
379 return Ok(AgentResponse::websocket_close(
380 1011,
381 "Agent error".to_string(),
382 ));
383 }
384 }
385 Err(_) => {
386 agent.record_timeout().await;
387 warn!(
388 agent_id = %agent.id(),
389 correlation_id = %event.correlation_id,
390 timeout_ms = agent.timeout_ms(),
391 failure_mode = ?agent.failure_mode(),
392 "WebSocket frame agent call timed out"
393 );
394
395 if agent.failure_mode() == FailureMode::Closed {
396 return Ok(AgentResponse::websocket_close(
397 1011,
398 "Gateway timeout".to_string(),
399 ));
400 }
401 }
402 }
403 }
404
405 Ok(AgentResponse::websocket_allow())
407 }
408
409 async fn process_event<T: serde::Serialize>(
411 &self,
412 event_type: EventType,
413 event: &T,
414 route_agents: &[String],
415 ctx: &AgentCallContext,
416 ) -> SentinelResult<AgentDecision> {
417 trace!(
418 correlation_id = %ctx.correlation_id,
419 event_type = ?event_type,
420 route_agents = ?route_agents,
421 "Starting agent event processing"
422 );
423
424 let agents = self.agents.read().await;
426 let relevant_agents: Vec<_> = route_agents
427 .iter()
428 .filter_map(|id| agents.get(id))
429 .filter(|agent| agent.handles_event(event_type))
430 .collect();
431
432 if relevant_agents.is_empty() {
433 trace!(
434 correlation_id = %ctx.correlation_id,
435 event_type = ?event_type,
436 "No relevant agents for event, allowing request"
437 );
438 return Ok(AgentDecision::default_allow());
439 }
440
441 debug!(
442 correlation_id = %ctx.correlation_id,
443 event_type = ?event_type,
444 agent_count = relevant_agents.len(),
445 agent_ids = ?relevant_agents.iter().map(|a| a.id()).collect::<Vec<_>>(),
446 "Processing event through agents"
447 );
448
449 let mut combined_decision = AgentDecision::default_allow();
451
452 for (agent_index, agent) in relevant_agents.iter().enumerate() {
453 trace!(
454 correlation_id = %ctx.correlation_id,
455 agent_id = %agent.id(),
456 agent_index = agent_index,
457 event_type = ?event_type,
458 "Processing event through agent"
459 );
460
461 trace!(
463 correlation_id = %ctx.correlation_id,
464 agent_id = %agent.id(),
465 "Acquiring agent call semaphore permit"
466 );
467 let _permit = self.call_semaphore.acquire().await.map_err(|_| {
468 error!(
469 correlation_id = %ctx.correlation_id,
470 agent_id = %agent.id(),
471 "Failed to acquire agent call semaphore permit"
472 );
473 SentinelError::Internal {
474 message: "Failed to acquire agent call permit".to_string(),
475 correlation_id: Some(ctx.correlation_id.to_string()),
476 source: None,
477 }
478 })?;
479
480 if !agent.circuit_breaker().is_closed().await {
482 warn!(
483 agent_id = %agent.id(),
484 correlation_id = %ctx.correlation_id,
485 failure_mode = ?agent.failure_mode(),
486 "Circuit breaker open, skipping agent"
487 );
488
489 if agent.failure_mode() == FailureMode::Closed {
491 debug!(
492 correlation_id = %ctx.correlation_id,
493 agent_id = %agent.id(),
494 "Blocking request due to circuit breaker (fail-closed mode)"
495 );
496 return Ok(AgentDecision::block(503, "Service unavailable"));
497 }
498 continue;
499 }
500
501 let start = Instant::now();
503 let timeout_duration = Duration::from_millis(agent.timeout_ms());
504
505 trace!(
506 correlation_id = %ctx.correlation_id,
507 agent_id = %agent.id(),
508 timeout_ms = agent.timeout_ms(),
509 "Calling agent"
510 );
511
512 match timeout(timeout_duration, agent.call_event(event_type, event)).await {
513 Ok(Ok(response)) => {
514 let duration = start.elapsed();
515 agent.record_success(duration).await;
516
517 trace!(
518 correlation_id = %ctx.correlation_id,
519 agent_id = %agent.id(),
520 duration_ms = duration.as_millis(),
521 decision = ?response,
522 "Agent call succeeded"
523 );
524
525 combined_decision.merge(response.into());
527
528 if !combined_decision.is_allow() {
530 debug!(
531 correlation_id = %ctx.correlation_id,
532 agent_id = %agent.id(),
533 decision = ?combined_decision,
534 "Agent returned blocking decision, stopping agent chain"
535 );
536 break;
537 }
538 }
539 Ok(Err(e)) => {
540 agent.record_failure().await;
541 error!(
542 agent_id = %agent.id(),
543 correlation_id = %ctx.correlation_id,
544 error = %e,
545 duration_ms = start.elapsed().as_millis(),
546 failure_mode = ?agent.failure_mode(),
547 "Agent call failed"
548 );
549
550 if agent.failure_mode() == FailureMode::Closed {
551 return Err(e);
552 }
553 }
554 Err(_) => {
555 agent.record_timeout().await;
556 warn!(
557 agent_id = %agent.id(),
558 correlation_id = %ctx.correlation_id,
559 timeout_ms = agent.timeout_ms(),
560 failure_mode = ?agent.failure_mode(),
561 "Agent call timed out"
562 );
563
564 if agent.failure_mode() == FailureMode::Closed {
565 debug!(
566 correlation_id = %ctx.correlation_id,
567 agent_id = %agent.id(),
568 "Blocking request due to timeout (fail-closed mode)"
569 );
570 return Ok(AgentDecision::block(504, "Gateway timeout"));
571 }
572 }
573 }
574 }
575
576 trace!(
577 correlation_id = %ctx.correlation_id,
578 decision = ?combined_decision,
579 agents_processed = relevant_agents.len(),
580 "Agent event processing completed"
581 );
582
583 Ok(combined_decision)
584 }
585
586 pub async fn initialize(&self) -> SentinelResult<()> {
588 let agents = self.agents.read().await;
589
590 info!(agent_count = agents.len(), "Initializing agent connections");
591
592 let mut initialized_count = 0;
593 let mut failed_count = 0;
594
595 for (id, agent) in agents.iter() {
596 debug!(agent_id = %id, "Initializing agent connection");
597 if let Err(e) = agent.initialize().await {
598 error!(
599 agent_id = %id,
600 error = %e,
601 "Failed to initialize agent"
602 );
603 failed_count += 1;
604 } else {
606 trace!(agent_id = %id, "Agent initialized successfully");
607 initialized_count += 1;
608 }
609 }
610
611 info!(
612 initialized = initialized_count,
613 failed = failed_count,
614 total = agents.len(),
615 "Agent initialization complete"
616 );
617
618 Ok(())
619 }
620
621 pub async fn shutdown(&self) {
623 let agents = self.agents.read().await;
624
625 info!(agent_count = agents.len(), "Shutting down agent manager");
626
627 for (id, agent) in agents.iter() {
628 debug!(agent_id = %id, "Shutting down agent");
629 agent.shutdown().await;
630 trace!(agent_id = %id, "Agent shutdown complete");
631 }
632
633 info!("Agent manager shutdown complete");
634 }
635
636 pub fn metrics(&self) -> &AgentMetrics {
638 &self.metrics
639 }
640
641 pub fn get_agents_for_event(&self, event_type: EventType) -> Vec<String> {
646 if let Ok(agents) = self.agents.try_read() {
649 agents
650 .values()
651 .filter(|agent| agent.handles_event(event_type))
652 .map(|agent| agent.id().to_string())
653 .collect()
654 } else {
655 Vec::new()
656 }
657 }
658}