Skip to main content

systemprompt_agent/services/agent_orchestration/
database.rs

1use crate::repository::agent_service::AgentServiceRepository;
2use crate::services::agent_orchestration::{
3    process, AgentStatus, OrchestrationError, OrchestrationResult,
4};
5use crate::services::registry::AgentRegistry;
6use systemprompt_models::services::AgentConfig;
7
8#[derive(Debug)]
9pub struct AgentDatabaseService {
10    pub repository: AgentServiceRepository,
11    pub registry: AgentRegistry,
12}
13
14impl AgentDatabaseService {
15    pub async fn new(repository: AgentServiceRepository) -> OrchestrationResult<Self> {
16        let registry = AgentRegistry::new().await.map_err(|e| {
17            OrchestrationError::Database(format!("Failed to load agent registry: {e}"))
18        })?;
19
20        Ok(Self {
21            repository,
22            registry,
23        })
24    }
25
26    pub async fn register_agent(
27        &self,
28        name: &str,
29        pid: u32,
30        port: u16,
31    ) -> OrchestrationResult<String> {
32        self.repository
33            .register_agent(name, pid, port)
34            .await
35            .map_err(|e| OrchestrationError::Database(e.to_string()))
36    }
37
38    pub async fn get_status(&self, agent_name: &str) -> OrchestrationResult<AgentStatus> {
39        let row = self
40            .repository
41            .get_agent_status(agent_name)
42            .await
43            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
44
45        match row {
46            Some(r) => match (r.pid, r.status.as_str()) {
47                (Some(pid), "running") => {
48                    let pid = pid as u32;
49                    if process::process_exists(pid) {
50                        Ok(AgentStatus::Running {
51                            pid,
52                            port: r.port as u16,
53                        })
54                    } else {
55                        self.mark_failed(agent_name, "Process died unexpectedly")
56                            .await?;
57                        Ok(AgentStatus::Failed {
58                            reason: "Process died unexpectedly".to_string(),
59                            last_attempt: None,
60                            retry_count: 0,
61                        })
62                    }
63                },
64                (_, "starting") => Ok(AgentStatus::Failed {
65                    reason: "Agent is starting".to_string(),
66                    last_attempt: None,
67                    retry_count: 0,
68                }),
69                (_, "failed" | "crashed" | "stopped") => {
70                    let error_msg = self
71                        .get_error_message(agent_name)
72                        .await
73                        .unwrap_or_else(|_| "Unknown failure".to_string());
74                    Ok(AgentStatus::Failed {
75                        reason: error_msg,
76                        last_attempt: None,
77                        retry_count: 0,
78                    })
79                },
80                _ => {
81                    self.mark_failed(agent_name, "Invalid database state")
82                        .await?;
83                    Ok(AgentStatus::Failed {
84                        reason: "Invalid database state".to_string(),
85                        last_attempt: None,
86                        retry_count: 0,
87                    })
88                },
89            },
90            None => Ok(AgentStatus::Failed {
91                reason: "No service record found".to_string(),
92                last_attempt: None,
93                retry_count: 0,
94            }),
95        }
96    }
97
98    pub async fn mark_failed(&self, agent_name: &str, _reason: &str) -> OrchestrationResult<()> {
99        self.repository
100            .mark_error(agent_name)
101            .await
102            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
103
104        self.repository
105            .mark_crashed(agent_name)
106            .await
107            .map_err(|e| OrchestrationError::Database(e.to_string()))
108    }
109
110    pub async fn mark_crashed(&self, agent_name: &str) -> OrchestrationResult<()> {
111        self.mark_failed(agent_name, "Process crashed").await
112    }
113
114    pub async fn get_error_message(&self, agent_name: &str) -> OrchestrationResult<String> {
115        let row = self
116            .repository
117            .get_agent_status(agent_name)
118            .await
119            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
120
121        match row {
122            Some(r) => Ok(format!("Status: {}", r.status)),
123            None => Ok("No service record".to_string()),
124        }
125    }
126
127    pub async fn mark_error(&self, agent_name: &str) -> OrchestrationResult<()> {
128        self.repository
129            .mark_error(agent_name)
130            .await
131            .map_err(|e| OrchestrationError::Database(e.to_string()))
132    }
133
134    pub async fn list_running_agents(&self) -> OrchestrationResult<Vec<String>> {
135        let rows = self
136            .repository
137            .list_running_agents()
138            .await
139            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
140
141        Ok(rows.into_iter().map(|row| row.name).collect())
142    }
143
144    pub async fn list_all_agents(&self) -> OrchestrationResult<Vec<(String, AgentStatus)>> {
145        let agent_configs = self.registry.list_agents().await.map_err(|e| {
146            OrchestrationError::Database(format!("Failed to list agents from config: {e}"))
147        })?;
148
149        let mut agents = Vec::new();
150
151        for agent_config in agent_configs {
152            let agent_name = &agent_config.name;
153
154            let status = self.get_status(agent_name).await?;
155
156            agents.push((agent_name.clone(), status));
157        }
158
159        Ok(agents)
160    }
161
162    pub async fn agent_exists(&self, agent_name: &str) -> OrchestrationResult<bool> {
163        self.registry
164            .get_agent(agent_name)
165            .await
166            .map(|_| true)
167            .or_else(|_| Ok(false))
168    }
169
170    pub async fn get_agent_config(&self, agent_name: &str) -> OrchestrationResult<AgentConfig> {
171        let agent_config = self.registry.get_agent(agent_name).await.map_err(|e| {
172            OrchestrationError::AgentNotFound(format!(
173                "Agent {} not found in config: {}",
174                agent_name, e
175            ))
176        })?;
177
178        Ok(agent_config)
179    }
180
181    pub async fn cleanup_orphaned_services(&self) -> OrchestrationResult<u64> {
182        let rows = self
183            .repository
184            .list_running_agent_pids()
185            .await
186            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
187
188        let mut cleaned = 0u64;
189
190        for row in rows {
191            let pid = row.pid as u32;
192            if !process::process_exists(pid) {
193                self.mark_crashed(&row.name).await?;
194                cleaned += 1;
195            }
196        }
197
198        Ok(cleaned)
199    }
200
201    pub async fn remove_agent_service(&self, agent_name: &str) -> OrchestrationResult<()> {
202        self.repository
203            .remove_agent_service(agent_name)
204            .await
205            .map_err(|e| OrchestrationError::Database(e.to_string()))
206    }
207
208    pub async fn update_health_status(
209        &self,
210        agent_name: &str,
211        health_status: &str,
212    ) -> OrchestrationResult<()> {
213        self.repository
214            .update_health_status(agent_name, health_status)
215            .await
216            .map_err(|e| OrchestrationError::Database(e.to_string()))
217    }
218
219    pub async fn update_agent_running(
220        &self,
221        agent_name: &str,
222        pid: u32,
223        port: u16,
224    ) -> OrchestrationResult<String> {
225        self.repository
226            .register_agent(agent_name, pid, port)
227            .await
228            .map_err(|e| OrchestrationError::Database(e.to_string()))
229    }
230
231    pub async fn update_agent_stopped(&self, agent_name: &str) -> OrchestrationResult<()> {
232        self.repository
233            .mark_stopped(agent_name)
234            .await
235            .map_err(|e| OrchestrationError::Database(e.to_string()))
236    }
237
238    pub async fn register_agent_starting(
239        &self,
240        agent_name: &str,
241        pid: u32,
242        port: u16,
243    ) -> OrchestrationResult<String> {
244        self.repository
245            .register_agent_starting(agent_name, pid, port)
246            .await
247            .map_err(|e| OrchestrationError::Database(e.to_string()))
248    }
249
250    pub async fn mark_running(&self, agent_name: &str) -> OrchestrationResult<()> {
251        self.repository
252            .mark_running(agent_name)
253            .await
254            .map_err(|e| OrchestrationError::Database(e.to_string()))
255    }
256
257    pub async fn get_unresponsive_agents(
258        &self,
259        _max_failures: u32,
260    ) -> OrchestrationResult<Vec<(String, Option<u32>)>> {
261        use crate::services::agent_orchestration::monitor::check_a2a_agent_health;
262
263        let agents = self.list_all_agents().await?;
264
265        let mut unresponsive = Vec::new();
266        for (agent_name, status) in agents {
267            if let AgentStatus::Running { pid, port, .. } = status {
268                let is_healthy = check_a2a_agent_health(port, 10).await.unwrap_or(false);
269
270                if !is_healthy {
271                    unresponsive.push((agent_name, Some(pid)));
272                }
273            }
274        }
275
276        Ok(unresponsive)
277    }
278}