Skip to main content

systemprompt_agent/repository/agent_service/
mod.rs

1use anyhow::{Context, Result};
2use sqlx::PgPool;
3use std::sync::Arc;
4use systemprompt_database::DbPool;
5use systemprompt_traits::{Repository as RepositoryTrait, RepositoryError};
6
/// One agent's entry in the `services` table, as returned by
/// `AgentServiceRepository::get_agent_status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentServiceRow {
    /// Unique service name (the table's conflict key on registration).
    pub name: String,
    /// OS process id, or `None` when the row has no PID recorded.
    pub pid: Option<i32>,
    /// Port the agent was registered with.
    pub port: i32,
    /// Status string, e.g. `running`, `starting`, `stopped` or `error`.
    pub status: String,
}
14
/// Name of a service whose status is `running`, as listed by
/// `AgentServiceRepository::list_running_agents`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentServerIdRow {
    /// Unique service name.
    pub name: String,
}
19
/// Name and PID of a running service that has a PID recorded, as listed by
/// `AgentServiceRepository::list_running_agent_pids`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentServerIdPidRow {
    /// Unique service name.
    pub name: String,
    /// OS process id stored for the service.
    pub pid: i32,
}
25
26#[derive(Debug, Clone)]
27pub struct AgentServiceRepository {
28    db_pool: DbPool,
29}
30
31impl RepositoryTrait for AgentServiceRepository {
32    type Pool = DbPool;
33    type Error = RepositoryError;
34
35    fn pool(&self) -> &Self::Pool {
36        &self.db_pool
37    }
38}
39
40impl AgentServiceRepository {
41    pub const fn new(db_pool: DbPool) -> Self {
42        Self { db_pool }
43    }
44
45    fn get_pg_pool(&self) -> Result<Arc<PgPool>> {
46        self.db_pool
47            .as_ref()
48            .get_postgres_pool()
49            .context("PostgreSQL pool not available")
50    }
51
52    pub async fn register_agent(
53        &self,
54        name: &str,
55        pid: u32,
56        port: u16,
57    ) -> Result<String, RepositoryError> {
58        self.remove_agent_service(name).await?;
59
60        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
61        let pid_i32 = pid as i32;
62        let port_i32 = i32::from(port);
63
64        sqlx::query!(
65            "INSERT INTO services (name, module_name, pid, port, status, updated_at)
66             VALUES ($1, 'agent', $2, $3, 'running', CURRENT_TIMESTAMP)
67             ON CONFLICT (name) DO UPDATE SET pid = $2, port = $3, status = 'running', updated_at \
68             = CURRENT_TIMESTAMP",
69            name,
70            pid_i32,
71            port_i32
72        )
73        .execute(pool.as_ref())
74        .await
75        .context("Failed to register agent")
76        .map_err(RepositoryError::Other)?;
77
78        Ok(name.to_string())
79    }
80
81    pub async fn register_agent_starting(
82        &self,
83        name: &str,
84        pid: u32,
85        port: u16,
86    ) -> Result<String, RepositoryError> {
87        self.remove_agent_service(name).await?;
88
89        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
90        let pid_i32 = pid as i32;
91        let port_i32 = i32::from(port);
92
93        sqlx::query!(
94            "INSERT INTO services (name, module_name, pid, port, status, updated_at)
95             VALUES ($1, 'agent', $2, $3, 'starting', CURRENT_TIMESTAMP)
96             ON CONFLICT (name) DO UPDATE SET pid = $2, port = $3, status = 'starting', updated_at \
97             = CURRENT_TIMESTAMP",
98            name,
99            pid_i32,
100            port_i32
101        )
102        .execute(pool.as_ref())
103        .await
104        .context("Failed to register agent as starting")
105        .map_err(RepositoryError::Other)?;
106
107        Ok(name.to_string())
108    }
109
110    pub async fn mark_running(&self, agent_name: &str) -> Result<(), RepositoryError> {
111        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
112
113        sqlx::query!(
114            "UPDATE services SET status = 'running', updated_at = CURRENT_TIMESTAMP WHERE name = \
115             $1",
116            agent_name
117        )
118        .execute(pool.as_ref())
119        .await
120        .context("Failed to mark agent as running")
121        .map_err(RepositoryError::Other)?;
122
123        Ok(())
124    }
125
126    pub async fn get_agent_status(
127        &self,
128        agent_name: &str,
129    ) -> Result<Option<AgentServiceRow>, RepositoryError> {
130        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
131
132        let row = sqlx::query!(
133            "SELECT name, pid, port, status FROM services WHERE name = $1",
134            agent_name
135        )
136        .fetch_optional(pool.as_ref())
137        .await
138        .context("Failed to get agent status")
139        .map_err(RepositoryError::Other)?;
140
141        Ok(row.map(|r| AgentServiceRow {
142            name: r.name,
143            pid: r.pid,
144            port: r.port,
145            status: r.status,
146        }))
147    }
148
149    pub async fn mark_crashed(&self, agent_name: &str) -> Result<(), RepositoryError> {
150        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
151
152        sqlx::query!(
153            "UPDATE services SET status = 'error', pid = NULL, updated_at = CURRENT_TIMESTAMP \
154             WHERE name = $1",
155            agent_name
156        )
157        .execute(pool.as_ref())
158        .await
159        .context("Failed to mark agent as crashed")
160        .map_err(RepositoryError::Other)?;
161
162        Ok(())
163    }
164
165    pub async fn mark_stopped(&self, agent_name: &str) -> Result<(), RepositoryError> {
166        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
167
168        sqlx::query!(
169            "UPDATE services SET status = 'stopped', pid = NULL, updated_at = CURRENT_TIMESTAMP \
170             WHERE name = $1",
171            agent_name
172        )
173        .execute(pool.as_ref())
174        .await
175        .context("Failed to mark agent as stopped")
176        .map_err(RepositoryError::Other)?;
177
178        Ok(())
179    }
180
181    pub async fn mark_error(&self, agent_name: &str) -> Result<(), RepositoryError> {
182        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
183
184        sqlx::query!(
185            "UPDATE services SET status = 'error', pid = NULL, updated_at = CURRENT_TIMESTAMP \
186             WHERE name = $1",
187            agent_name
188        )
189        .execute(pool.as_ref())
190        .await
191        .context("Failed to mark agent with error")
192        .map_err(RepositoryError::Other)?;
193
194        Ok(())
195    }
196
197    pub async fn list_running_agents(&self) -> Result<Vec<AgentServerIdRow>, RepositoryError> {
198        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
199
200        let rows = sqlx::query!("SELECT name FROM services WHERE status = 'running'")
201            .fetch_all(pool.as_ref())
202            .await
203            .context("Failed to list running agents")
204            .map_err(RepositoryError::Other)?;
205
206        Ok(rows
207            .into_iter()
208            .map(|r| AgentServerIdRow { name: r.name })
209            .collect())
210    }
211
212    pub async fn list_running_agent_pids(
213        &self,
214    ) -> Result<Vec<AgentServerIdPidRow>, RepositoryError> {
215        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
216
217        let rows = sqlx::query!(
218            "SELECT name, pid FROM services WHERE status = 'running' AND pid IS NOT NULL"
219        )
220        .fetch_all(pool.as_ref())
221        .await
222        .context("Failed to list running agent PIDs")
223        .map_err(RepositoryError::Other)?;
224
225        Ok(rows
226            .into_iter()
227            .filter_map(|r| r.pid.map(|pid| AgentServerIdPidRow { name: r.name, pid }))
228            .collect())
229    }
230
231    pub async fn remove_agent_service(&self, agent_name: &str) -> Result<(), RepositoryError> {
232        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
233
234        sqlx::query!("DELETE FROM services WHERE name = $1", agent_name)
235            .execute(pool.as_ref())
236            .await
237            .context("Failed to remove agent service")
238            .map_err(RepositoryError::Other)?;
239
240        Ok(())
241    }
242
243    pub async fn update_health_status(
244        &self,
245        agent_name: &str,
246        health_status: &str,
247    ) -> Result<(), RepositoryError> {
248        let pool = self.get_pg_pool().map_err(RepositoryError::Other)?;
249
250        sqlx::query!(
251            "UPDATE services SET status = $1, updated_at = CURRENT_TIMESTAMP WHERE name = $2",
252            health_status,
253            agent_name
254        )
255        .execute(pool.as_ref())
256        .await
257        .context("Failed to update agent health status")
258        .map_err(RepositoryError::Other)?;
259
260        Ok(())
261    }
262}