1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use sentinel_agent_protocol::{
9 EventType, RequestBodyChunkEvent, RequestHeadersEvent, ResponseHeadersEvent,
10};
11use sentinel_common::{
12 errors::{SentinelError, SentinelResult},
13 types::CircuitBreakerConfig,
14 CircuitBreaker,
15};
16use sentinel_config::{AgentConfig, FailureMode};
17use tokio::sync::{RwLock, Semaphore};
18use tracing::{debug, error, info, warn};
19
20use super::agent::Agent;
21use super::context::AgentCallContext;
22use super::decision::AgentDecision;
23use super::metrics::AgentMetrics;
24use super::pool::AgentConnectionPool;
25
26pub struct AgentManager {
28 agents: Arc<RwLock<HashMap<String, Arc<Agent>>>>,
30 connection_pools: Arc<RwLock<HashMap<String, Arc<AgentConnectionPool>>>>,
32 circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
34 metrics: Arc<AgentMetrics>,
36 #[allow(dead_code)]
38 max_concurrent_calls: usize,
39 call_semaphore: Arc<Semaphore>,
41}
42
43impl AgentManager {
44 pub async fn new(
46 agents: Vec<AgentConfig>,
47 max_concurrent_calls: usize,
48 ) -> SentinelResult<Self> {
49 let mut agent_map = HashMap::new();
50 let mut pools = HashMap::new();
51 let mut breakers = HashMap::new();
52
53 for config in agents {
54 let pool = Arc::new(AgentConnectionPool::new(
55 10, 2, 5, Duration::from_secs(60),
59 ));
60
61 let circuit_breaker = Arc::new(CircuitBreaker::new(
62 config
63 .circuit_breaker
64 .clone()
65 .unwrap_or_else(CircuitBreakerConfig::default),
66 ));
67
68 let agent = Arc::new(Agent::new(
69 config.clone(),
70 Arc::clone(&pool),
71 Arc::clone(&circuit_breaker),
72 ));
73
74 agent_map.insert(config.id.clone(), agent);
75 pools.insert(config.id.clone(), pool);
76 breakers.insert(config.id.clone(), circuit_breaker);
77 }
78
79 Ok(Self {
80 agents: Arc::new(RwLock::new(agent_map)),
81 connection_pools: Arc::new(RwLock::new(pools)),
82 circuit_breakers: Arc::new(RwLock::new(breakers)),
83 metrics: Arc::new(AgentMetrics::default()),
84 max_concurrent_calls,
85 call_semaphore: Arc::new(Semaphore::new(max_concurrent_calls)),
86 })
87 }
88
89 pub async fn process_request_headers(
91 &self,
92 ctx: &AgentCallContext,
93 headers: &HashMap<String, Vec<String>>,
94 route_agents: &[String],
95 ) -> SentinelResult<AgentDecision> {
96 let event = RequestHeadersEvent {
97 metadata: ctx.metadata.clone(),
98 method: headers
99 .get(":method")
100 .and_then(|v| v.first())
101 .unwrap_or(&"GET".to_string())
102 .clone(),
103 uri: headers
104 .get(":path")
105 .and_then(|v| v.first())
106 .unwrap_or(&"/".to_string())
107 .clone(),
108 headers: headers.clone(),
109 };
110
111 self.process_event(EventType::RequestHeaders, &event, route_agents, ctx)
112 .await
113 }
114
115 pub async fn process_request_body(
117 &self,
118 ctx: &AgentCallContext,
119 data: &[u8],
120 is_last: bool,
121 route_agents: &[String],
122 ) -> SentinelResult<AgentDecision> {
123 let max_size = 1024 * 1024; if data.len() > max_size {
126 warn!(
127 correlation_id = %ctx.correlation_id,
128 size = data.len(),
129 "Request body exceeds agent inspection limit"
130 );
131 return Ok(AgentDecision::default_allow());
132 }
133
134 let event = RequestBodyChunkEvent {
135 correlation_id: ctx.correlation_id.to_string(),
136 data: STANDARD.encode(data),
137 is_last,
138 total_size: ctx.request_body.as_ref().map(|b| b.len()),
139 };
140
141 self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
142 .await
143 }
144
145 pub async fn process_response_headers(
147 &self,
148 ctx: &AgentCallContext,
149 status: u16,
150 headers: &HashMap<String, Vec<String>>,
151 route_agents: &[String],
152 ) -> SentinelResult<AgentDecision> {
153 let event = ResponseHeadersEvent {
154 correlation_id: ctx.correlation_id.to_string(),
155 status,
156 headers: headers.clone(),
157 };
158
159 self.process_event(EventType::ResponseHeaders, &event, route_agents, ctx)
160 .await
161 }
162
163 async fn process_event<T: serde::Serialize>(
165 &self,
166 event_type: EventType,
167 event: &T,
168 route_agents: &[String],
169 ctx: &AgentCallContext,
170 ) -> SentinelResult<AgentDecision> {
171 let agents = self.agents.read().await;
173 let relevant_agents: Vec<_> = route_agents
174 .iter()
175 .filter_map(|id| agents.get(id))
176 .filter(|agent| agent.handles_event(event_type))
177 .collect();
178
179 if relevant_agents.is_empty() {
180 return Ok(AgentDecision::default_allow());
181 }
182
183 debug!(
184 correlation_id = %ctx.correlation_id,
185 event_type = ?event_type,
186 agent_count = relevant_agents.len(),
187 "Processing event through agents"
188 );
189
190 let mut combined_decision = AgentDecision::default_allow();
192
193 for agent in relevant_agents {
194 let _permit = self.call_semaphore.acquire().await.map_err(|_| {
196 SentinelError::Internal {
197 message: "Failed to acquire agent call permit".to_string(),
198 correlation_id: Some(ctx.correlation_id.to_string()),
199 source: None,
200 }
201 })?;
202
203 if !agent.circuit_breaker().is_closed().await {
205 warn!(
206 agent_id = %agent.id(),
207 correlation_id = %ctx.correlation_id,
208 "Circuit breaker open, skipping agent"
209 );
210
211 if agent.failure_mode() == FailureMode::Closed {
213 return Ok(AgentDecision::block(503, "Service unavailable"));
214 }
215 continue;
216 }
217
218 let start = Instant::now();
220 let timeout = Duration::from_millis(agent.timeout_ms());
221
222 match tokio::time::timeout(timeout, agent.call_event(event_type, event)).await {
223 Ok(Ok(response)) => {
224 let duration = start.elapsed();
225 agent.record_success(duration).await;
226
227 combined_decision.merge(response.into());
229
230 if !combined_decision.is_allow() {
232 break;
233 }
234 }
235 Ok(Err(e)) => {
236 agent.record_failure().await;
237 error!(
238 agent_id = %agent.id(),
239 correlation_id = %ctx.correlation_id,
240 error = %e,
241 "Agent call failed"
242 );
243
244 if agent.failure_mode() == FailureMode::Closed {
245 return Err(e);
246 }
247 }
248 Err(_) => {
249 agent.record_timeout().await;
250 warn!(
251 agent_id = %agent.id(),
252 correlation_id = %ctx.correlation_id,
253 timeout_ms = agent.timeout_ms(),
254 "Agent call timed out"
255 );
256
257 if agent.failure_mode() == FailureMode::Closed {
258 return Ok(AgentDecision::block(504, "Gateway timeout"));
259 }
260 }
261 }
262 }
263
264 Ok(combined_decision)
265 }
266
267 pub async fn initialize(&self) -> SentinelResult<()> {
269 let agents = self.agents.read().await;
270
271 for (id, agent) in agents.iter() {
272 info!("Initializing agent: {}", id);
273 if let Err(e) = agent.initialize().await {
274 error!("Failed to initialize agent {}: {}", id, e);
275 }
277 }
278
279 Ok(())
280 }
281
282 pub async fn shutdown(&self) {
284 info!("Shutting down agent manager");
285
286 let agents = self.agents.read().await;
287 for (id, agent) in agents.iter() {
288 debug!("Shutting down agent: {}", id);
289 agent.shutdown().await;
290 }
291 }
292
293 pub fn metrics(&self) -> &AgentMetrics {
295 &self.metrics
296 }
297}