Skip to main content

systemprompt_agent/services/agent_orchestration/
monitor.rs

1use anyhow::Result;
2use std::time::Duration;
3use systemprompt_database::DbPool;
4use tokio::net::TcpStream;
5use tokio::time::timeout;
6
7use crate::services::agent_orchestration::database::AgentDatabaseService;
8use crate::services::agent_orchestration::{OrchestrationResult, process};
9
10#[derive(Debug)]
11pub struct AgentMonitor {
12    db_service: AgentDatabaseService,
13}
14
15impl AgentMonitor {
16    pub fn new(db_pool: &DbPool) -> OrchestrationResult<Self> {
17        use crate::repository::agent_service::AgentServiceRepository;
18
19        let agent_service_repo = AgentServiceRepository::new(db_pool)?;
20        let db_service = AgentDatabaseService::new(agent_service_repo)?;
21
22        Ok(Self { db_service })
23    }
24
25    pub async fn comprehensive_health_check(
26        &self,
27        agent_id: &str,
28    ) -> OrchestrationResult<HealthCheckResult> {
29        let status = self.db_service.get_status(agent_id).await?;
30
31        match status {
32            crate::services::agent_orchestration::AgentStatus::Running { pid, port } => {
33                if !process::process_exists(pid) {
34                    return Ok(HealthCheckResult {
35                        healthy: false,
36                        message: format!("Process {} no longer exists", pid),
37                        response_time_ms: 0,
38                    });
39                }
40
41                match perform_tcp_health_check("127.0.0.1", port).await {
42                    Ok(result) => Ok(result),
43                    Err(e) => Ok(HealthCheckResult {
44                        healthy: false,
45                        message: format!("TCP check failed: {e}"),
46                        response_time_ms: 0,
47                    }),
48                }
49            },
50            crate::services::agent_orchestration::AgentStatus::Failed { .. } => Ok(HealthCheckResult {
51                healthy: false,
52                message: format!("Agent {} not in running state", agent_id),
53                response_time_ms: 0,
54            }),
55        }
56    }
57
58    pub async fn monitor_all_agents(&self) -> OrchestrationResult<MonitoringReport> {
59        let agents = self.db_service.list_all_agents().await?;
60        let mut report = MonitoringReport::new();
61
62        for (agent_id, status) in agents {
63            match status {
64                crate::services::agent_orchestration::AgentStatus::Running { pid, port } => {
65                    if process::process_exists(pid) {
66                        let health_result = perform_tcp_health_check("127.0.0.1", port).await?;
67                        if health_result.healthy {
68                            report.healthy.push(agent_id);
69                        } else {
70                            report.unhealthy.push(agent_id);
71                        }
72                    } else {
73                        self.db_service
74                            .mark_failed(&agent_id, "Process died")
75                            .await?;
76                        report.failed.push(agent_id);
77                    }
78                },
79                crate::services::agent_orchestration::AgentStatus::Failed { .. } => {
80                    report.failed.push(agent_id);
81                },
82            }
83        }
84
85        Ok(report)
86    }
87
88    pub async fn cleanup_unresponsive_agents(&self, max_failures: u32) -> OrchestrationResult<u32> {
89        tracing::debug!("Cleaning up unresponsive agents");
90
91        let unresponsive_agents = self
92            .db_service
93            .get_unresponsive_agents(max_failures)
94            .await?;
95        let mut cleaned_up = 0;
96
97        for (agent_id, pid_opt) in unresponsive_agents {
98            if let Some(pid) = pid_opt {
99                tracing::warn!(agent_id = %agent_id, pid = %pid, "Killing unresponsive agent");
100
101                if process::kill_process(pid) {
102                    self.db_service.mark_crashed(&agent_id).await?;
103                    cleaned_up += 1;
104                    tracing::info!(agent_id = %agent_id, "Cleaned up agent");
105                } else {
106                    tracing::error!(agent_id = %agent_id, pid = %pid, "Failed to kill agent");
107                }
108            }
109        }
110
111        if cleaned_up > 0 {
112            tracing::info!(cleaned_up = %cleaned_up, "Cleaned up unresponsive agents");
113        } else {
114            tracing::debug!("No unresponsive agents found");
115        }
116
117        Ok(cleaned_up)
118    }
119}
120
/// Outcome of a single agent health check.
#[derive(Debug, Clone)]
pub struct HealthCheckResult {
    /// Whether the check considered the agent healthy.
    pub healthy: bool,
    /// Human-readable description of the outcome.
    pub message: String,
    /// Response time in milliseconds, as reported by the check that produced
    /// this result.
    pub response_time_ms: u64,
}
127
/// Agent identifiers grouped by the outcome of a monitoring sweep.
#[derive(Debug, Default)]
pub struct MonitoringReport {
    /// Agents that passed their health check.
    pub healthy: Vec<String>,
    /// Agents that did not pass their health check.
    pub unhealthy: Vec<String>,
    /// Agents recorded as failed.
    pub failed: Vec<String>,
}

impl MonitoringReport {
    /// Creates a report with all three buckets empty.
    pub const fn new() -> Self {
        Self {
            healthy: Vec::new(),
            unhealthy: Vec::new(),
            failed: Vec::new(),
        }
    }

    /// Number of agents across all three buckets.
    pub fn total_agents(&self) -> usize {
        [&self.healthy, &self.unhealthy, &self.failed]
            .iter()
            .map(|bucket| bucket.len())
            .sum()
    }

    /// Share of healthy agents as a percentage in `0.0..=100.0`.
    ///
    /// An empty report yields `0.0` rather than dividing by zero.
    pub fn healthy_percentage(&self) -> f64 {
        match self.total_agents() {
            0 => 0.0,
            total => (self.healthy.len() as f64 / total as f64) * 100.0,
        }
    }
}
163
164pub async fn check_agent_health(agent_id: &str) -> Result<HealthCheckResult> {
165    let port = get_agent_port_simple(agent_id);
166    perform_tcp_health_check("127.0.0.1", port).await
167}
168
169async fn perform_tcp_health_check(host: &str, port: u16) -> Result<HealthCheckResult> {
170    let start = std::time::Instant::now();
171    let address = format!("{host}:{port}");
172
173    tracing::trace!(address = %address, "Attempting TCP health check");
174
175    match timeout(Duration::from_secs(15), TcpStream::connect(&address)).await {
176        Ok(Ok(_)) => {
177            let response_time = start.elapsed().as_millis() as u64;
178            tracing::trace!(address = %address, response_time_ms = %response_time, "Health check passed");
179            Ok(HealthCheckResult {
180                healthy: true,
181                message: "TCP connection successful".to_string(),
182                response_time_ms: response_time,
183            })
184        },
185        Ok(Err(e)) => {
186            tracing::debug!(address = %address, error = %e, "Health check failed - connection error");
187            Ok(HealthCheckResult {
188                healthy: false,
189                message: format!("Connection failed: {e}"),
190                response_time_ms: 0,
191            })
192        },
193        Err(_) => {
194            tracing::debug!(address = %address, "Health check timeout");
195            Ok(HealthCheckResult {
196                healthy: false,
197                message: "Connection timeout".to_string(),
198                response_time_ms: 5000,
199            })
200        },
201    }
202}
203
/// Derives a port in `8000..=8999` from the ASCII digits found in `agent_id`.
///
/// The digits are read in order as one decimal number `n`, and the result is
/// `8000 + (n % 1000)`. An id without any digit maps to `8000`.
///
/// The remainder is folded digit by digit, so digit sequences of any length
/// are handled without overflow instead of falling back to `8000`.
fn get_agent_port_simple(agent_id: &str) -> u16 {
    const BASE_PORT: u16 = 8000;
    const PORT_RANGE: u32 = 1000;

    // (acc * 10 + d) % 1000 keeps `acc` below 1000, so the intermediate value
    // never exceeds 9999 and cannot overflow.
    let offset = agent_id
        .chars()
        .filter_map(|c| c.to_digit(10))
        .fold(0u32, |acc, digit| (acc * 10 + digit) % PORT_RANGE);

    // `offset` is below 1000, so the cast cannot truncate.
    BASE_PORT + offset as u16
}
217
218pub async fn check_agent_responsiveness(agent_id: &str, timeout_secs: u64) -> Result<bool> {
219    let port = get_agent_port_simple(agent_id);
220    let address = format!("127.0.0.1:{port}");
221
222    match timeout(
223        Duration::from_secs(timeout_secs),
224        TcpStream::connect(&address),
225    )
226    .await
227    {
228        Ok(Ok(_)) => {
229            tracing::trace!(agent_id = %agent_id, "Agent is responsive");
230            Ok(true)
231        },
232        Ok(Err(e)) => {
233            tracing::debug!(agent_id = %agent_id, error = %e, "Agent connection failed");
234            Ok(false)
235        },
236        Err(_) => {
237            tracing::debug!(agent_id = %agent_id, timeout_secs = %timeout_secs, "Agent connection timeout");
238            Ok(false)
239        },
240    }
241}
242
243pub async fn check_a2a_agent_health(port: u16, timeout_secs: u64) -> Result<bool> {
244    let url = format!("http://localhost:{}/.well-known/agent-card.json", port);
245
246    let client = reqwest::Client::new();
247    let response = client
248        .get(&url)
249        .timeout(Duration::from_secs(timeout_secs))
250        .send()
251        .await;
252
253    match response {
254        Ok(resp) if resp.status().is_success() => {
255            resp.json::<serde_json::Value>()
256                .await
257                .map_or(Ok(false), |json| {
258                    let is_valid_card = json.get("name").is_some() && json.get("url").is_some();
259                    Ok(is_valid_card)
260                })
261        },
262        Ok(_) | Err(_) => Ok(false),
263    }
264}