sentinel_proxy/agents/
manager.rs

1//! Agent manager for coordinating external processing agents.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use sentinel_agent_protocol::{
9    EventType, RequestBodyChunkEvent, RequestHeadersEvent, ResponseHeadersEvent,
10};
11use sentinel_common::{
12    errors::{SentinelError, SentinelResult},
13    types::CircuitBreakerConfig,
14    CircuitBreaker,
15};
16use sentinel_config::{AgentConfig, FailureMode};
17use tokio::sync::{RwLock, Semaphore};
18use tracing::{debug, error, info, trace, warn};
19
20use super::agent::Agent;
21use super::context::AgentCallContext;
22use super::decision::AgentDecision;
23use super::metrics::AgentMetrics;
24use super::pool::AgentConnectionPool;
25
26/// Agent manager handling all external agents.
27pub struct AgentManager {
28    /// Configured agents
29    agents: Arc<RwLock<HashMap<String, Arc<Agent>>>>,
30    /// Connection pools for agents
31    connection_pools: Arc<RwLock<HashMap<String, Arc<AgentConnectionPool>>>>,
32    /// Circuit breakers per agent
33    circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
34    /// Global agent metrics
35    metrics: Arc<AgentMetrics>,
36    /// Maximum concurrent agent calls
37    #[allow(dead_code)]
38    max_concurrent_calls: usize,
39    /// Global semaphore for agent calls
40    call_semaphore: Arc<Semaphore>,
41}
42
43impl AgentManager {
44    /// Create new agent manager.
45    pub async fn new(
46        agents: Vec<AgentConfig>,
47        max_concurrent_calls: usize,
48    ) -> SentinelResult<Self> {
49        info!(
50            agent_count = agents.len(),
51            max_concurrent_calls = max_concurrent_calls,
52            "Creating agent manager"
53        );
54
55        let mut agent_map = HashMap::new();
56        let mut pools = HashMap::new();
57        let mut breakers = HashMap::new();
58
59        for config in agents {
60            debug!(
61                agent_id = %config.id,
62                transport = ?config.transport,
63                timeout_ms = config.timeout_ms,
64                failure_mode = ?config.failure_mode,
65                "Configuring agent"
66            );
67
68            let pool = Arc::new(AgentConnectionPool::new(
69                10, // max connections
70                2,  // min idle
71                5,  // max idle
72                Duration::from_secs(60),
73            ));
74
75            let circuit_breaker = Arc::new(CircuitBreaker::new(
76                config
77                    .circuit_breaker
78                    .clone()
79                    .unwrap_or_else(CircuitBreakerConfig::default),
80            ));
81
82            trace!(
83                agent_id = %config.id,
84                "Creating agent instance"
85            );
86
87            let agent = Arc::new(Agent::new(
88                config.clone(),
89                Arc::clone(&pool),
90                Arc::clone(&circuit_breaker),
91            ));
92
93            agent_map.insert(config.id.clone(), agent);
94            pools.insert(config.id.clone(), pool);
95            breakers.insert(config.id.clone(), circuit_breaker);
96
97            debug!(
98                agent_id = %config.id,
99                "Agent configured successfully"
100            );
101        }
102
103        info!(
104            configured_agents = agent_map.len(),
105            "Agent manager created successfully"
106        );
107
108        Ok(Self {
109            agents: Arc::new(RwLock::new(agent_map)),
110            connection_pools: Arc::new(RwLock::new(pools)),
111            circuit_breakers: Arc::new(RwLock::new(breakers)),
112            metrics: Arc::new(AgentMetrics::default()),
113            max_concurrent_calls,
114            call_semaphore: Arc::new(Semaphore::new(max_concurrent_calls)),
115        })
116    }
117
118    /// Process request headers through agents.
119    pub async fn process_request_headers(
120        &self,
121        ctx: &AgentCallContext,
122        headers: &HashMap<String, Vec<String>>,
123        route_agents: &[String],
124    ) -> SentinelResult<AgentDecision> {
125        let event = RequestHeadersEvent {
126            metadata: ctx.metadata.clone(),
127            method: headers
128                .get(":method")
129                .and_then(|v| v.first())
130                .unwrap_or(&"GET".to_string())
131                .clone(),
132            uri: headers
133                .get(":path")
134                .and_then(|v| v.first())
135                .unwrap_or(&"/".to_string())
136                .clone(),
137            headers: headers.clone(),
138        };
139
140        self.process_event(EventType::RequestHeaders, &event, route_agents, ctx)
141            .await
142    }
143
144    /// Process request body chunk through agents.
145    pub async fn process_request_body(
146        &self,
147        ctx: &AgentCallContext,
148        data: &[u8],
149        is_last: bool,
150        route_agents: &[String],
151    ) -> SentinelResult<AgentDecision> {
152        // Check body size limits
153        let max_size = 1024 * 1024; // 1MB default
154        if data.len() > max_size {
155            warn!(
156                correlation_id = %ctx.correlation_id,
157                size = data.len(),
158                "Request body exceeds agent inspection limit"
159            );
160            return Ok(AgentDecision::default_allow());
161        }
162
163        let event = RequestBodyChunkEvent {
164            correlation_id: ctx.correlation_id.to_string(),
165            data: STANDARD.encode(data),
166            is_last,
167            total_size: ctx.request_body.as_ref().map(|b| b.len()),
168        };
169
170        self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
171            .await
172    }
173
174    /// Process response headers through agents.
175    pub async fn process_response_headers(
176        &self,
177        ctx: &AgentCallContext,
178        status: u16,
179        headers: &HashMap<String, Vec<String>>,
180        route_agents: &[String],
181    ) -> SentinelResult<AgentDecision> {
182        let event = ResponseHeadersEvent {
183            correlation_id: ctx.correlation_id.to_string(),
184            status,
185            headers: headers.clone(),
186        };
187
188        self.process_event(EventType::ResponseHeaders, &event, route_agents, ctx)
189            .await
190    }
191
192    /// Process an event through relevant agents.
193    async fn process_event<T: serde::Serialize>(
194        &self,
195        event_type: EventType,
196        event: &T,
197        route_agents: &[String],
198        ctx: &AgentCallContext,
199    ) -> SentinelResult<AgentDecision> {
200        trace!(
201            correlation_id = %ctx.correlation_id,
202            event_type = ?event_type,
203            route_agents = ?route_agents,
204            "Starting agent event processing"
205        );
206
207        // Get relevant agents for this route and event type
208        let agents = self.agents.read().await;
209        let relevant_agents: Vec<_> = route_agents
210            .iter()
211            .filter_map(|id| agents.get(id))
212            .filter(|agent| agent.handles_event(event_type))
213            .collect();
214
215        if relevant_agents.is_empty() {
216            trace!(
217                correlation_id = %ctx.correlation_id,
218                event_type = ?event_type,
219                "No relevant agents for event, allowing request"
220            );
221            return Ok(AgentDecision::default_allow());
222        }
223
224        debug!(
225            correlation_id = %ctx.correlation_id,
226            event_type = ?event_type,
227            agent_count = relevant_agents.len(),
228            agent_ids = ?relevant_agents.iter().map(|a| a.id()).collect::<Vec<_>>(),
229            "Processing event through agents"
230        );
231
232        // Process through each agent sequentially
233        let mut combined_decision = AgentDecision::default_allow();
234
235        for (agent_index, agent) in relevant_agents.iter().enumerate() {
236            trace!(
237                correlation_id = %ctx.correlation_id,
238                agent_id = %agent.id(),
239                agent_index = agent_index,
240                event_type = ?event_type,
241                "Processing event through agent"
242            );
243
244            // Acquire semaphore permit
245            trace!(
246                correlation_id = %ctx.correlation_id,
247                agent_id = %agent.id(),
248                "Acquiring agent call semaphore permit"
249            );
250            let _permit = self.call_semaphore.acquire().await.map_err(|_| {
251                error!(
252                    correlation_id = %ctx.correlation_id,
253                    agent_id = %agent.id(),
254                    "Failed to acquire agent call semaphore permit"
255                );
256                SentinelError::Internal {
257                    message: "Failed to acquire agent call permit".to_string(),
258                    correlation_id: Some(ctx.correlation_id.to_string()),
259                    source: None,
260                }
261            })?;
262
263            // Check circuit breaker
264            if !agent.circuit_breaker().is_closed().await {
265                warn!(
266                    agent_id = %agent.id(),
267                    correlation_id = %ctx.correlation_id,
268                    failure_mode = ?agent.failure_mode(),
269                    "Circuit breaker open, skipping agent"
270                );
271
272                // Handle based on failure mode
273                if agent.failure_mode() == FailureMode::Closed {
274                    debug!(
275                        correlation_id = %ctx.correlation_id,
276                        agent_id = %agent.id(),
277                        "Blocking request due to circuit breaker (fail-closed mode)"
278                    );
279                    return Ok(AgentDecision::block(503, "Service unavailable"));
280                }
281                continue;
282            }
283
284            // Call agent with timeout
285            let start = Instant::now();
286            let timeout = Duration::from_millis(agent.timeout_ms());
287
288            trace!(
289                correlation_id = %ctx.correlation_id,
290                agent_id = %agent.id(),
291                timeout_ms = agent.timeout_ms(),
292                "Calling agent"
293            );
294
295            match tokio::time::timeout(timeout, agent.call_event(event_type, event)).await {
296                Ok(Ok(response)) => {
297                    let duration = start.elapsed();
298                    agent.record_success(duration).await;
299
300                    trace!(
301                        correlation_id = %ctx.correlation_id,
302                        agent_id = %agent.id(),
303                        duration_ms = duration.as_millis(),
304                        decision = ?response,
305                        "Agent call succeeded"
306                    );
307
308                    // Merge response into combined decision
309                    combined_decision.merge(response.into());
310
311                    // If decision is to block/redirect/challenge, stop processing
312                    if !combined_decision.is_allow() {
313                        debug!(
314                            correlation_id = %ctx.correlation_id,
315                            agent_id = %agent.id(),
316                            decision = ?combined_decision,
317                            "Agent returned blocking decision, stopping agent chain"
318                        );
319                        break;
320                    }
321                }
322                Ok(Err(e)) => {
323                    agent.record_failure().await;
324                    error!(
325                        agent_id = %agent.id(),
326                        correlation_id = %ctx.correlation_id,
327                        error = %e,
328                        duration_ms = start.elapsed().as_millis(),
329                        failure_mode = ?agent.failure_mode(),
330                        "Agent call failed"
331                    );
332
333                    if agent.failure_mode() == FailureMode::Closed {
334                        return Err(e);
335                    }
336                }
337                Err(_) => {
338                    agent.record_timeout().await;
339                    warn!(
340                        agent_id = %agent.id(),
341                        correlation_id = %ctx.correlation_id,
342                        timeout_ms = agent.timeout_ms(),
343                        failure_mode = ?agent.failure_mode(),
344                        "Agent call timed out"
345                    );
346
347                    if agent.failure_mode() == FailureMode::Closed {
348                        debug!(
349                            correlation_id = %ctx.correlation_id,
350                            agent_id = %agent.id(),
351                            "Blocking request due to timeout (fail-closed mode)"
352                        );
353                        return Ok(AgentDecision::block(504, "Gateway timeout"));
354                    }
355                }
356            }
357        }
358
359        trace!(
360            correlation_id = %ctx.correlation_id,
361            decision = ?combined_decision,
362            agents_processed = relevant_agents.len(),
363            "Agent event processing completed"
364        );
365
366        Ok(combined_decision)
367    }
368
369    /// Initialize agent connections.
370    pub async fn initialize(&self) -> SentinelResult<()> {
371        let agents = self.agents.read().await;
372
373        info!(
374            agent_count = agents.len(),
375            "Initializing agent connections"
376        );
377
378        let mut initialized_count = 0;
379        let mut failed_count = 0;
380
381        for (id, agent) in agents.iter() {
382            debug!(agent_id = %id, "Initializing agent connection");
383            if let Err(e) = agent.initialize().await {
384                error!(
385                    agent_id = %id,
386                    error = %e,
387                    "Failed to initialize agent"
388                );
389                failed_count += 1;
390                // Continue with other agents
391            } else {
392                trace!(agent_id = %id, "Agent initialized successfully");
393                initialized_count += 1;
394            }
395        }
396
397        info!(
398            initialized = initialized_count,
399            failed = failed_count,
400            total = agents.len(),
401            "Agent initialization complete"
402        );
403
404        Ok(())
405    }
406
407    /// Shutdown all agents.
408    pub async fn shutdown(&self) {
409        let agents = self.agents.read().await;
410
411        info!(
412            agent_count = agents.len(),
413            "Shutting down agent manager"
414        );
415
416        for (id, agent) in agents.iter() {
417            debug!(agent_id = %id, "Shutting down agent");
418            agent.shutdown().await;
419            trace!(agent_id = %id, "Agent shutdown complete");
420        }
421
422        info!("Agent manager shutdown complete");
423    }
424
425    /// Get agent metrics.
426    pub fn metrics(&self) -> &AgentMetrics {
427        &self.metrics
428    }
429}