sentinel_proxy/agents/
manager.rs

1//! Agent manager for coordinating external processing agents.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use pingora_timeout::timeout;
9use sentinel_agent_protocol::{
10    AgentResponse, EventType, RequestBodyChunkEvent, RequestHeadersEvent, ResponseBodyChunkEvent,
11    ResponseHeadersEvent, WebSocketFrameEvent,
12};
13use sentinel_common::{
14    errors::{SentinelError, SentinelResult},
15    types::CircuitBreakerConfig,
16    CircuitBreaker,
17};
18use sentinel_config::{AgentConfig, FailureMode};
19use tokio::sync::{RwLock, Semaphore};
20use tracing::{debug, error, info, trace, warn};
21
22use super::agent::Agent;
23use super::context::AgentCallContext;
24use super::decision::AgentDecision;
25use super::metrics::AgentMetrics;
26use super::pool::AgentConnectionPool;
27
28/// Agent manager handling all external agents.
29pub struct AgentManager {
30    /// Configured agents
31    agents: Arc<RwLock<HashMap<String, Arc<Agent>>>>,
32    /// Connection pools for agents
33    connection_pools: Arc<RwLock<HashMap<String, Arc<AgentConnectionPool>>>>,
34    /// Circuit breakers per agent
35    circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
36    /// Global agent metrics
37    metrics: Arc<AgentMetrics>,
38    /// Maximum concurrent agent calls
39    #[allow(dead_code)]
40    max_concurrent_calls: usize,
41    /// Global semaphore for agent calls
42    call_semaphore: Arc<Semaphore>,
43}
44
45impl AgentManager {
46    /// Create new agent manager.
47    pub async fn new(
48        agents: Vec<AgentConfig>,
49        max_concurrent_calls: usize,
50    ) -> SentinelResult<Self> {
51        info!(
52            agent_count = agents.len(),
53            max_concurrent_calls = max_concurrent_calls,
54            "Creating agent manager"
55        );
56
57        let mut agent_map = HashMap::new();
58        let mut pools = HashMap::new();
59        let mut breakers = HashMap::new();
60
61        for config in agents {
62            debug!(
63                agent_id = %config.id,
64                transport = ?config.transport,
65                timeout_ms = config.timeout_ms,
66                failure_mode = ?config.failure_mode,
67                "Configuring agent"
68            );
69
70            let pool = Arc::new(AgentConnectionPool::new(
71                10, // max connections
72                2,  // min idle
73                5,  // max idle
74                Duration::from_secs(60),
75            ));
76
77            let circuit_breaker = Arc::new(CircuitBreaker::new(
78                config
79                    .circuit_breaker
80                    .clone()
81                    .unwrap_or_else(CircuitBreakerConfig::default),
82            ));
83
84            trace!(
85                agent_id = %config.id,
86                "Creating agent instance"
87            );
88
89            let agent = Arc::new(Agent::new(
90                config.clone(),
91                Arc::clone(&pool),
92                Arc::clone(&circuit_breaker),
93            ));
94
95            agent_map.insert(config.id.clone(), agent);
96            pools.insert(config.id.clone(), pool);
97            breakers.insert(config.id.clone(), circuit_breaker);
98
99            debug!(
100                agent_id = %config.id,
101                "Agent configured successfully"
102            );
103        }
104
105        info!(
106            configured_agents = agent_map.len(),
107            "Agent manager created successfully"
108        );
109
110        Ok(Self {
111            agents: Arc::new(RwLock::new(agent_map)),
112            connection_pools: Arc::new(RwLock::new(pools)),
113            circuit_breakers: Arc::new(RwLock::new(breakers)),
114            metrics: Arc::new(AgentMetrics::default()),
115            max_concurrent_calls,
116            call_semaphore: Arc::new(Semaphore::new(max_concurrent_calls)),
117        })
118    }
119
120    /// Process request headers through agents.
121    pub async fn process_request_headers(
122        &self,
123        ctx: &AgentCallContext,
124        headers: &HashMap<String, Vec<String>>,
125        route_agents: &[String],
126    ) -> SentinelResult<AgentDecision> {
127        let event = RequestHeadersEvent {
128            metadata: ctx.metadata.clone(),
129            method: headers
130                .get(":method")
131                .and_then(|v| v.first())
132                .unwrap_or(&"GET".to_string())
133                .clone(),
134            uri: headers
135                .get(":path")
136                .and_then(|v| v.first())
137                .unwrap_or(&"/".to_string())
138                .clone(),
139            headers: headers.clone(),
140        };
141
142        self.process_event(EventType::RequestHeaders, &event, route_agents, ctx)
143            .await
144    }
145
146    /// Process request body chunk through agents.
147    pub async fn process_request_body(
148        &self,
149        ctx: &AgentCallContext,
150        data: &[u8],
151        is_last: bool,
152        route_agents: &[String],
153    ) -> SentinelResult<AgentDecision> {
154        // Check body size limits
155        let max_size = 1024 * 1024; // 1MB default
156        if data.len() > max_size {
157            warn!(
158                correlation_id = %ctx.correlation_id,
159                size = data.len(),
160                "Request body exceeds agent inspection limit"
161            );
162            return Ok(AgentDecision::default_allow());
163        }
164
165        let event = RequestBodyChunkEvent {
166            correlation_id: ctx.correlation_id.to_string(),
167            data: STANDARD.encode(data),
168            is_last,
169            total_size: ctx.request_body.as_ref().map(|b| b.len()),
170            chunk_index: 0, // Buffer mode sends entire body as single chunk
171            bytes_received: data.len(),
172        };
173
174        self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
175            .await
176    }
177
178    /// Process a single request body chunk through agents (streaming mode).
179    ///
180    /// Unlike `process_request_body` which is used for buffered mode, this method
181    /// is designed for streaming where chunks are sent individually as they arrive.
182    pub async fn process_request_body_streaming(
183        &self,
184        ctx: &AgentCallContext,
185        data: &[u8],
186        is_last: bool,
187        chunk_index: u32,
188        bytes_received: usize,
189        total_size: Option<usize>,
190        route_agents: &[String],
191    ) -> SentinelResult<AgentDecision> {
192        trace!(
193            correlation_id = %ctx.correlation_id,
194            chunk_index = chunk_index,
195            chunk_size = data.len(),
196            bytes_received = bytes_received,
197            is_last = is_last,
198            "Processing streaming request body chunk"
199        );
200
201        let event = RequestBodyChunkEvent {
202            correlation_id: ctx.correlation_id.to_string(),
203            data: STANDARD.encode(data),
204            is_last,
205            total_size,
206            chunk_index,
207            bytes_received,
208        };
209
210        self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
211            .await
212    }
213
214    /// Process a single response body chunk through agents (streaming mode).
215    pub async fn process_response_body_streaming(
216        &self,
217        ctx: &AgentCallContext,
218        data: &[u8],
219        is_last: bool,
220        chunk_index: u32,
221        bytes_sent: usize,
222        total_size: Option<usize>,
223        route_agents: &[String],
224    ) -> SentinelResult<AgentDecision> {
225        trace!(
226            correlation_id = %ctx.correlation_id,
227            chunk_index = chunk_index,
228            chunk_size = data.len(),
229            bytes_sent = bytes_sent,
230            is_last = is_last,
231            "Processing streaming response body chunk"
232        );
233
234        let event = ResponseBodyChunkEvent {
235            correlation_id: ctx.correlation_id.to_string(),
236            data: STANDARD.encode(data),
237            is_last,
238            total_size,
239            chunk_index,
240            bytes_sent,
241        };
242
243        self.process_event(EventType::ResponseBodyChunk, &event, route_agents, ctx)
244            .await
245    }
246
247    /// Process response headers through agents.
248    pub async fn process_response_headers(
249        &self,
250        ctx: &AgentCallContext,
251        status: u16,
252        headers: &HashMap<String, Vec<String>>,
253        route_agents: &[String],
254    ) -> SentinelResult<AgentDecision> {
255        let event = ResponseHeadersEvent {
256            correlation_id: ctx.correlation_id.to_string(),
257            status,
258            headers: headers.clone(),
259        };
260
261        self.process_event(EventType::ResponseHeaders, &event, route_agents, ctx)
262            .await
263    }
264
265    /// Process a WebSocket frame through agents.
266    ///
267    /// This is used for WebSocket frame inspection after an upgrade.
268    /// Returns the agent response directly to allow the caller to access
269    /// the websocket_decision field.
270    pub async fn process_websocket_frame(
271        &self,
272        route_id: &str,
273        event: WebSocketFrameEvent,
274    ) -> SentinelResult<AgentResponse> {
275        trace!(
276            correlation_id = %event.correlation_id,
277            route_id = %route_id,
278            frame_index = event.frame_index,
279            opcode = %event.opcode,
280            "Processing WebSocket frame through agents"
281        );
282
283        // Get relevant agents for this route that handle WebSocket frames
284        let agents = self.agents.read().await;
285        let relevant_agents: Vec<_> = agents
286            .values()
287            .filter(|agent| agent.handles_event(EventType::WebSocketFrame))
288            .collect();
289
290        if relevant_agents.is_empty() {
291            trace!(
292                correlation_id = %event.correlation_id,
293                "No agents handle WebSocket frames, allowing"
294            );
295            return Ok(AgentResponse::websocket_allow());
296        }
297
298        debug!(
299            correlation_id = %event.correlation_id,
300            route_id = %route_id,
301            agent_count = relevant_agents.len(),
302            "Processing WebSocket frame through agents"
303        );
304
305        // Process through each agent sequentially
306        for agent in relevant_agents {
307            // Check circuit breaker
308            if !agent.circuit_breaker().is_closed().await {
309                warn!(
310                    agent_id = %agent.id(),
311                    correlation_id = %event.correlation_id,
312                    failure_mode = ?agent.failure_mode(),
313                    "Circuit breaker open, skipping agent for WebSocket frame"
314                );
315
316                if agent.failure_mode() == FailureMode::Closed {
317                    debug!(
318                        correlation_id = %event.correlation_id,
319                        agent_id = %agent.id(),
320                        "Closing WebSocket due to circuit breaker (fail-closed mode)"
321                    );
322                    return Ok(AgentResponse::websocket_close(
323                        1011,
324                        "Service unavailable".to_string(),
325                    ));
326                }
327                continue;
328            }
329
330            // Call agent with timeout
331            let start = Instant::now();
332            let timeout_duration = Duration::from_millis(agent.timeout_ms());
333
334            match timeout(
335                timeout_duration,
336                agent.call_event(EventType::WebSocketFrame, &event),
337            )
338            .await
339            {
340                Ok(Ok(response)) => {
341                    let duration = start.elapsed();
342                    agent.record_success(duration).await;
343
344                    trace!(
345                        correlation_id = %event.correlation_id,
346                        agent_id = %agent.id(),
347                        duration_ms = duration.as_millis(),
348                        "WebSocket frame agent call succeeded"
349                    );
350
351                    // If agent returned a WebSocket decision that's not Allow, return immediately
352                    if let Some(ref ws_decision) = response.websocket_decision {
353                        if !matches!(
354                            ws_decision,
355                            sentinel_agent_protocol::WebSocketDecision::Allow
356                        ) {
357                            debug!(
358                                correlation_id = %event.correlation_id,
359                                agent_id = %agent.id(),
360                                decision = ?ws_decision,
361                                "Agent returned non-allow WebSocket decision"
362                            );
363                            return Ok(response);
364                        }
365                    }
366                }
367                Ok(Err(e)) => {
368                    agent.record_failure().await;
369                    error!(
370                        agent_id = %agent.id(),
371                        correlation_id = %event.correlation_id,
372                        error = %e,
373                        duration_ms = start.elapsed().as_millis(),
374                        failure_mode = ?agent.failure_mode(),
375                        "WebSocket frame agent call failed"
376                    );
377
378                    if agent.failure_mode() == FailureMode::Closed {
379                        return Ok(AgentResponse::websocket_close(
380                            1011,
381                            "Agent error".to_string(),
382                        ));
383                    }
384                }
385                Err(_) => {
386                    agent.record_timeout().await;
387                    warn!(
388                        agent_id = %agent.id(),
389                        correlation_id = %event.correlation_id,
390                        timeout_ms = agent.timeout_ms(),
391                        failure_mode = ?agent.failure_mode(),
392                        "WebSocket frame agent call timed out"
393                    );
394
395                    if agent.failure_mode() == FailureMode::Closed {
396                        return Ok(AgentResponse::websocket_close(
397                            1011,
398                            "Gateway timeout".to_string(),
399                        ));
400                    }
401                }
402            }
403        }
404
405        // All agents allowed the frame
406        Ok(AgentResponse::websocket_allow())
407    }
408
409    /// Process an event through relevant agents.
410    async fn process_event<T: serde::Serialize>(
411        &self,
412        event_type: EventType,
413        event: &T,
414        route_agents: &[String],
415        ctx: &AgentCallContext,
416    ) -> SentinelResult<AgentDecision> {
417        trace!(
418            correlation_id = %ctx.correlation_id,
419            event_type = ?event_type,
420            route_agents = ?route_agents,
421            "Starting agent event processing"
422        );
423
424        // Get relevant agents for this route and event type
425        let agents = self.agents.read().await;
426        let relevant_agents: Vec<_> = route_agents
427            .iter()
428            .filter_map(|id| agents.get(id))
429            .filter(|agent| agent.handles_event(event_type))
430            .collect();
431
432        if relevant_agents.is_empty() {
433            trace!(
434                correlation_id = %ctx.correlation_id,
435                event_type = ?event_type,
436                "No relevant agents for event, allowing request"
437            );
438            return Ok(AgentDecision::default_allow());
439        }
440
441        debug!(
442            correlation_id = %ctx.correlation_id,
443            event_type = ?event_type,
444            agent_count = relevant_agents.len(),
445            agent_ids = ?relevant_agents.iter().map(|a| a.id()).collect::<Vec<_>>(),
446            "Processing event through agents"
447        );
448
449        // Process through each agent sequentially
450        let mut combined_decision = AgentDecision::default_allow();
451
452        for (agent_index, agent) in relevant_agents.iter().enumerate() {
453            trace!(
454                correlation_id = %ctx.correlation_id,
455                agent_id = %agent.id(),
456                agent_index = agent_index,
457                event_type = ?event_type,
458                "Processing event through agent"
459            );
460
461            // Acquire semaphore permit
462            trace!(
463                correlation_id = %ctx.correlation_id,
464                agent_id = %agent.id(),
465                "Acquiring agent call semaphore permit"
466            );
467            let _permit = self.call_semaphore.acquire().await.map_err(|_| {
468                error!(
469                    correlation_id = %ctx.correlation_id,
470                    agent_id = %agent.id(),
471                    "Failed to acquire agent call semaphore permit"
472                );
473                SentinelError::Internal {
474                    message: "Failed to acquire agent call permit".to_string(),
475                    correlation_id: Some(ctx.correlation_id.to_string()),
476                    source: None,
477                }
478            })?;
479
480            // Check circuit breaker
481            if !agent.circuit_breaker().is_closed().await {
482                warn!(
483                    agent_id = %agent.id(),
484                    correlation_id = %ctx.correlation_id,
485                    failure_mode = ?agent.failure_mode(),
486                    "Circuit breaker open, skipping agent"
487                );
488
489                // Handle based on failure mode
490                if agent.failure_mode() == FailureMode::Closed {
491                    debug!(
492                        correlation_id = %ctx.correlation_id,
493                        agent_id = %agent.id(),
494                        "Blocking request due to circuit breaker (fail-closed mode)"
495                    );
496                    return Ok(AgentDecision::block(503, "Service unavailable"));
497                }
498                continue;
499            }
500
501            // Call agent with timeout (using pingora-timeout for efficiency)
502            let start = Instant::now();
503            let timeout_duration = Duration::from_millis(agent.timeout_ms());
504
505            trace!(
506                correlation_id = %ctx.correlation_id,
507                agent_id = %agent.id(),
508                timeout_ms = agent.timeout_ms(),
509                "Calling agent"
510            );
511
512            match timeout(timeout_duration, agent.call_event(event_type, event)).await {
513                Ok(Ok(response)) => {
514                    let duration = start.elapsed();
515                    agent.record_success(duration).await;
516
517                    trace!(
518                        correlation_id = %ctx.correlation_id,
519                        agent_id = %agent.id(),
520                        duration_ms = duration.as_millis(),
521                        decision = ?response,
522                        "Agent call succeeded"
523                    );
524
525                    // Merge response into combined decision
526                    combined_decision.merge(response.into());
527
528                    // If decision is to block/redirect/challenge, stop processing
529                    if !combined_decision.is_allow() {
530                        debug!(
531                            correlation_id = %ctx.correlation_id,
532                            agent_id = %agent.id(),
533                            decision = ?combined_decision,
534                            "Agent returned blocking decision, stopping agent chain"
535                        );
536                        break;
537                    }
538                }
539                Ok(Err(e)) => {
540                    agent.record_failure().await;
541                    error!(
542                        agent_id = %agent.id(),
543                        correlation_id = %ctx.correlation_id,
544                        error = %e,
545                        duration_ms = start.elapsed().as_millis(),
546                        failure_mode = ?agent.failure_mode(),
547                        "Agent call failed"
548                    );
549
550                    if agent.failure_mode() == FailureMode::Closed {
551                        return Err(e);
552                    }
553                }
554                Err(_) => {
555                    agent.record_timeout().await;
556                    warn!(
557                        agent_id = %agent.id(),
558                        correlation_id = %ctx.correlation_id,
559                        timeout_ms = agent.timeout_ms(),
560                        failure_mode = ?agent.failure_mode(),
561                        "Agent call timed out"
562                    );
563
564                    if agent.failure_mode() == FailureMode::Closed {
565                        debug!(
566                            correlation_id = %ctx.correlation_id,
567                            agent_id = %agent.id(),
568                            "Blocking request due to timeout (fail-closed mode)"
569                        );
570                        return Ok(AgentDecision::block(504, "Gateway timeout"));
571                    }
572                }
573            }
574        }
575
576        trace!(
577            correlation_id = %ctx.correlation_id,
578            decision = ?combined_decision,
579            agents_processed = relevant_agents.len(),
580            "Agent event processing completed"
581        );
582
583        Ok(combined_decision)
584    }
585
586    /// Initialize agent connections.
587    pub async fn initialize(&self) -> SentinelResult<()> {
588        let agents = self.agents.read().await;
589
590        info!(agent_count = agents.len(), "Initializing agent connections");
591
592        let mut initialized_count = 0;
593        let mut failed_count = 0;
594
595        for (id, agent) in agents.iter() {
596            debug!(agent_id = %id, "Initializing agent connection");
597            if let Err(e) = agent.initialize().await {
598                error!(
599                    agent_id = %id,
600                    error = %e,
601                    "Failed to initialize agent"
602                );
603                failed_count += 1;
604                // Continue with other agents
605            } else {
606                trace!(agent_id = %id, "Agent initialized successfully");
607                initialized_count += 1;
608            }
609        }
610
611        info!(
612            initialized = initialized_count,
613            failed = failed_count,
614            total = agents.len(),
615            "Agent initialization complete"
616        );
617
618        Ok(())
619    }
620
621    /// Shutdown all agents.
622    pub async fn shutdown(&self) {
623        let agents = self.agents.read().await;
624
625        info!(agent_count = agents.len(), "Shutting down agent manager");
626
627        for (id, agent) in agents.iter() {
628            debug!(agent_id = %id, "Shutting down agent");
629            agent.shutdown().await;
630            trace!(agent_id = %id, "Agent shutdown complete");
631        }
632
633        info!("Agent manager shutdown complete");
634    }
635
636    /// Get agent metrics.
637    pub fn metrics(&self) -> &AgentMetrics {
638        &self.metrics
639    }
640
641    /// Get agent IDs that handle a specific event type.
642    ///
643    /// This is useful for pre-filtering agents before making calls,
644    /// e.g., to check if any agents handle WebSocket frames.
645    pub fn get_agents_for_event(&self, event_type: EventType) -> Vec<String> {
646        // Use try_read to avoid blocking - return empty if lock is held
647        // This is acceptable since this is only used for informational purposes
648        if let Ok(agents) = self.agents.try_read() {
649            agents
650                .values()
651                .filter(|agent| agent.handles_event(event_type))
652                .map(|agent| agent.id().to_string())
653                .collect()
654        } else {
655            Vec::new()
656        }
657    }
658}