Skip to main content

systemprompt_agent/services/agent_orchestration/
database.rs

1use crate::repository::agent_service::AgentServiceRepository;
2use crate::services::agent_orchestration::{
3    AgentStatus, OrchestrationError, OrchestrationResult, process,
4};
5use crate::services::registry::AgentRegistry;
6use systemprompt_models::services::AgentConfig;
7
8#[derive(Debug)]
9pub struct AgentDatabaseService {
10    pub repository: AgentServiceRepository,
11    pub registry: AgentRegistry,
12}
13
14impl AgentDatabaseService {
15    pub fn new(repository: AgentServiceRepository) -> OrchestrationResult<Self> {
16        let registry = AgentRegistry::new().map_err(|e| {
17            OrchestrationError::Database(format!("Failed to load agent registry: {e}"))
18        })?;
19
20        Ok(Self {
21            repository,
22            registry,
23        })
24    }
25
26    pub async fn register_agent(
27        &self,
28        name: &str,
29        pid: u32,
30        port: u16,
31    ) -> OrchestrationResult<String> {
32        self.repository
33            .register_agent(name, pid, port)
34            .await
35            .map_err(|e| OrchestrationError::Database(e.to_string()))
36    }
37
38    pub async fn get_status(&self, agent_name: &str) -> OrchestrationResult<AgentStatus> {
39        let row = self
40            .repository
41            .get_agent_status(agent_name)
42            .await
43            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
44
45        match row {
46            Some(r) => match (r.pid, r.status.as_str()) {
47                (Some(pid), "running") => {
48                    let pid = pid as u32;
49                    if process::process_exists(pid) {
50                        Ok(AgentStatus::Running {
51                            pid,
52                            port: r.port as u16,
53                        })
54                    } else {
55                        self.mark_failed(agent_name).await?;
56                        Ok(AgentStatus::Failed {
57                            reason: "Process died unexpectedly".to_string(),
58                            last_attempt: None,
59                            retry_count: 0,
60                        })
61                    }
62                },
63                (_, "starting") => Ok(AgentStatus::Failed {
64                    reason: "Agent is starting".to_string(),
65                    last_attempt: None,
66                    retry_count: 0,
67                }),
68                (_, "failed" | "crashed" | "stopped") => {
69                    let error_msg = self
70                        .get_error_message(agent_name)
71                        .await
72                        .unwrap_or_else(|_| "Unknown failure".to_string());
73                    Ok(AgentStatus::Failed {
74                        reason: error_msg,
75                        last_attempt: None,
76                        retry_count: 0,
77                    })
78                },
79                _ => {
80                    self.mark_failed(agent_name).await?;
81                    Ok(AgentStatus::Failed {
82                        reason: "Invalid database state".to_string(),
83                        last_attempt: None,
84                        retry_count: 0,
85                    })
86                },
87            },
88            None => Ok(AgentStatus::Failed {
89                reason: "No service record found".to_string(),
90                last_attempt: None,
91                retry_count: 0,
92            }),
93        }
94    }
95
96    pub async fn mark_failed(&self, agent_name: &str) -> OrchestrationResult<()> {
97        self.repository
98            .mark_error(agent_name)
99            .await
100            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
101
102        self.repository
103            .mark_crashed(agent_name)
104            .await
105            .map_err(|e| OrchestrationError::Database(e.to_string()))
106    }
107
108    pub async fn mark_crashed(&self, agent_name: &str) -> OrchestrationResult<()> {
109        self.mark_failed(agent_name).await
110    }
111
112    pub async fn get_error_message(&self, agent_name: &str) -> OrchestrationResult<String> {
113        let row = self
114            .repository
115            .get_agent_status(agent_name)
116            .await
117            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
118
119        match row {
120            Some(r) => Ok(format!("Status: {}", r.status)),
121            None => Ok("No service record".to_string()),
122        }
123    }
124
125    pub async fn mark_error(&self, agent_name: &str) -> OrchestrationResult<()> {
126        self.repository
127            .mark_error(agent_name)
128            .await
129            .map_err(|e| OrchestrationError::Database(e.to_string()))
130    }
131
132    pub async fn list_running_agents(&self) -> OrchestrationResult<Vec<String>> {
133        let rows = self
134            .repository
135            .list_running_agents()
136            .await
137            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
138
139        Ok(rows.into_iter().map(|row| row.name).collect())
140    }
141
142    pub async fn list_all_agents(&self) -> OrchestrationResult<Vec<(String, AgentStatus)>> {
143        let agent_configs = self.registry.list_agents().await.map_err(|e| {
144            OrchestrationError::Database(format!("Failed to list agents from config: {e}"))
145        })?;
146
147        let mut agents = Vec::new();
148
149        for agent_config in agent_configs {
150            let agent_name = &agent_config.name;
151
152            let status = self.get_status(agent_name).await?;
153
154            agents.push((agent_name.clone(), status));
155        }
156
157        Ok(agents)
158    }
159
160    pub async fn agent_exists(&self, agent_name: &str) -> OrchestrationResult<bool> {
161        self.registry
162            .get_agent(agent_name)
163            .await
164            .map(|_| true)
165            .or_else(|_| Ok(false))
166    }
167
168    pub async fn get_agent_config(&self, agent_name: &str) -> OrchestrationResult<AgentConfig> {
169        let agent_config = self.registry.get_agent(agent_name).await.map_err(|e| {
170            OrchestrationError::AgentNotFound(format!(
171                "Agent {} not found in config: {}",
172                agent_name, e
173            ))
174        })?;
175
176        Ok(agent_config)
177    }
178
179    pub async fn cleanup_orphaned_services(&self) -> OrchestrationResult<u64> {
180        let rows = self
181            .repository
182            .list_running_agent_pids()
183            .await
184            .map_err(|e| OrchestrationError::Database(e.to_string()))?;
185
186        let mut cleaned = 0u64;
187
188        for row in rows {
189            let pid = row.pid as u32;
190            if !process::process_exists(pid) {
191                self.mark_crashed(&row.name).await?;
192                cleaned += 1;
193            }
194        }
195
196        Ok(cleaned)
197    }
198
199    pub async fn remove_agent_service(&self, agent_name: &str) -> OrchestrationResult<()> {
200        self.repository
201            .remove_agent_service(agent_name)
202            .await
203            .map_err(|e| OrchestrationError::Database(e.to_string()))
204    }
205
206    pub async fn update_health_status(
207        &self,
208        agent_name: &str,
209        health_status: &str,
210    ) -> OrchestrationResult<()> {
211        self.repository
212            .update_health_status(agent_name, health_status)
213            .await
214            .map_err(|e| OrchestrationError::Database(e.to_string()))
215    }
216
217    pub async fn update_agent_running(
218        &self,
219        agent_name: &str,
220        pid: u32,
221        port: u16,
222    ) -> OrchestrationResult<String> {
223        self.repository
224            .register_agent(agent_name, pid, port)
225            .await
226            .map_err(|e| OrchestrationError::Database(e.to_string()))
227    }
228
229    pub async fn update_agent_stopped(&self, agent_name: &str) -> OrchestrationResult<()> {
230        self.repository
231            .mark_stopped(agent_name)
232            .await
233            .map_err(|e| OrchestrationError::Database(e.to_string()))
234    }
235
236    pub async fn register_agent_starting(
237        &self,
238        agent_name: &str,
239        pid: u32,
240        port: u16,
241    ) -> OrchestrationResult<String> {
242        self.repository
243            .register_agent_starting(agent_name, pid, port)
244            .await
245            .map_err(|e| OrchestrationError::Database(e.to_string()))
246    }
247
248    pub async fn mark_running(&self, agent_name: &str) -> OrchestrationResult<()> {
249        self.repository
250            .mark_running(agent_name)
251            .await
252            .map_err(|e| OrchestrationError::Database(e.to_string()))
253    }
254
255    pub async fn get_unresponsive_agents(&self) -> OrchestrationResult<Vec<(String, Option<u32>)>> {
256        use crate::services::agent_orchestration::monitor::check_a2a_agent_health;
257
258        let agents = self.list_all_agents().await?;
259
260        let mut unresponsive = Vec::new();
261        for (agent_name, status) in agents {
262            if let AgentStatus::Running { pid, port, .. } = status {
263                let is_healthy = check_a2a_agent_health(port, 10).await.unwrap_or(false);
264
265                if !is_healthy {
266                    unresponsive.push((agent_name, Some(pid)));
267                }
268            }
269        }
270
271        Ok(unresponsive)
272    }
273}