Skip to main content

systemprompt_agent/repository/agent_service/
mod.rs

1use anyhow::{Context, Result};
2use sqlx::PgPool;
3use std::sync::Arc;
4use systemprompt_database::DbPool;
5use systemprompt_traits::RepositoryError;
6
/// One row of the `services` table as returned by
/// [`AgentServiceRepository::get_agent_status`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentServiceRow {
    /// Service name; the conflict key used when registering.
    pub name: String,
    /// OS process id, or `None` when the column is NULL (the stop/crash/error
    /// transitions clear it).
    pub pid: Option<i32>,
    /// Port the service was registered with.
    pub port: i32,
    /// Status string, e.g. `"starting"`, `"running"`, `"stopped"` or `"error"`;
    /// arbitrary values can also be written via `update_health_status`.
    pub status: String,
}
14
/// Name of a service whose status is `'running'`, as returned by
/// [`AgentServiceRepository::list_running_agents`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentServerIdRow {
    /// Service name.
    pub name: String,
}
19
/// Name and process id of a running service with a non-NULL pid, as returned
/// by [`AgentServiceRepository::list_running_agent_pids`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentServerIdPidRow {
    /// Service name.
    pub name: String,
    /// OS process id stored for the service.
    pub pid: i32,
}
25
26#[derive(Debug, Clone)]
27pub struct AgentServiceRepository {
28    pool: Arc<PgPool>,
29    write_pool: Arc<PgPool>,
30}
31
32impl AgentServiceRepository {
33    pub fn new(db: &DbPool) -> Result<Self> {
34        let pool = db.pool_arc().context("PostgreSQL pool not available")?;
35        let write_pool = db
36            .write_pool_arc()
37            .context("Write PostgreSQL pool not available")?;
38        Ok(Self { pool, write_pool })
39    }
40
41    pub async fn register_agent(
42        &self,
43        name: &str,
44        pid: u32,
45        port: u16,
46    ) -> Result<String, RepositoryError> {
47        self.remove_agent_service(name).await?;
48
49        let pool = &self.write_pool;
50        let pid_i32 = pid as i32;
51        let port_i32 = i32::from(port);
52
53        sqlx::query!(
54            "INSERT INTO services (name, module_name, pid, port, status, updated_at)
55             VALUES ($1, 'agent', $2, $3, 'running', CURRENT_TIMESTAMP)
56             ON CONFLICT (name) DO UPDATE SET pid = $2, port = $3, status = 'running', updated_at \
57             = CURRENT_TIMESTAMP",
58            name,
59            pid_i32,
60            port_i32
61        )
62        .execute(pool.as_ref())
63        .await
64        .context("Failed to register agent")
65        .map_err(RepositoryError::Other)?;
66
67        Ok(name.to_string())
68    }
69
70    pub async fn register_agent_starting(
71        &self,
72        name: &str,
73        pid: u32,
74        port: u16,
75    ) -> Result<String, RepositoryError> {
76        self.remove_agent_service(name).await?;
77
78        let pool = &self.write_pool;
79        let pid_i32 = pid as i32;
80        let port_i32 = i32::from(port);
81
82        sqlx::query!(
83            "INSERT INTO services (name, module_name, pid, port, status, updated_at)
84             VALUES ($1, 'agent', $2, $3, 'starting', CURRENT_TIMESTAMP)
85             ON CONFLICT (name) DO UPDATE SET pid = $2, port = $3, status = 'starting', updated_at \
86             = CURRENT_TIMESTAMP",
87            name,
88            pid_i32,
89            port_i32
90        )
91        .execute(pool.as_ref())
92        .await
93        .context("Failed to register agent as starting")
94        .map_err(RepositoryError::Other)?;
95
96        Ok(name.to_string())
97    }
98
99    pub async fn mark_running(&self, agent_name: &str) -> Result<(), RepositoryError> {
100        let pool = &self.write_pool;
101
102        sqlx::query!(
103            "UPDATE services SET status = 'running', updated_at = CURRENT_TIMESTAMP WHERE name = \
104             $1",
105            agent_name
106        )
107        .execute(pool.as_ref())
108        .await
109        .context("Failed to mark agent as running")
110        .map_err(RepositoryError::Other)?;
111
112        Ok(())
113    }
114
115    pub async fn get_agent_status(
116        &self,
117        agent_name: &str,
118    ) -> Result<Option<AgentServiceRow>, RepositoryError> {
119        let pool = &self.pool;
120
121        let row = sqlx::query!(
122            "SELECT name, pid, port, status FROM services WHERE name = $1",
123            agent_name
124        )
125        .fetch_optional(pool.as_ref())
126        .await
127        .context("Failed to get agent status")
128        .map_err(RepositoryError::Other)?;
129
130        Ok(row.map(|r| AgentServiceRow {
131            name: r.name,
132            pid: r.pid,
133            port: r.port,
134            status: r.status,
135        }))
136    }
137
138    pub async fn mark_crashed(&self, agent_name: &str) -> Result<(), RepositoryError> {
139        let pool = &self.write_pool;
140
141        sqlx::query!(
142            "UPDATE services SET status = 'error', pid = NULL, updated_at = CURRENT_TIMESTAMP \
143             WHERE name = $1",
144            agent_name
145        )
146        .execute(pool.as_ref())
147        .await
148        .context("Failed to mark agent as crashed")
149        .map_err(RepositoryError::Other)?;
150
151        Ok(())
152    }
153
154    pub async fn mark_stopped(&self, agent_name: &str) -> Result<(), RepositoryError> {
155        let pool = &self.write_pool;
156
157        sqlx::query!(
158            "UPDATE services SET status = 'stopped', pid = NULL, updated_at = CURRENT_TIMESTAMP \
159             WHERE name = $1",
160            agent_name
161        )
162        .execute(pool.as_ref())
163        .await
164        .context("Failed to mark agent as stopped")
165        .map_err(RepositoryError::Other)?;
166
167        Ok(())
168    }
169
170    pub async fn mark_error(&self, agent_name: &str) -> Result<(), RepositoryError> {
171        let pool = &self.write_pool;
172
173        sqlx::query!(
174            "UPDATE services SET status = 'error', pid = NULL, updated_at = CURRENT_TIMESTAMP \
175             WHERE name = $1",
176            agent_name
177        )
178        .execute(pool.as_ref())
179        .await
180        .context("Failed to mark agent with error")
181        .map_err(RepositoryError::Other)?;
182
183        Ok(())
184    }
185
186    pub async fn list_running_agents(&self) -> Result<Vec<AgentServerIdRow>, RepositoryError> {
187        let pool = &self.pool;
188
189        let rows = sqlx::query!("SELECT name FROM services WHERE status = 'running'")
190            .fetch_all(pool.as_ref())
191            .await
192            .context("Failed to list running agents")
193            .map_err(RepositoryError::Other)?;
194
195        Ok(rows
196            .into_iter()
197            .map(|r| AgentServerIdRow { name: r.name })
198            .collect())
199    }
200
201    pub async fn list_running_agent_pids(
202        &self,
203    ) -> Result<Vec<AgentServerIdPidRow>, RepositoryError> {
204        let pool = &self.pool;
205
206        let rows = sqlx::query!(
207            "SELECT name, pid FROM services WHERE status = 'running' AND pid IS NOT NULL"
208        )
209        .fetch_all(pool.as_ref())
210        .await
211        .context("Failed to list running agent PIDs")
212        .map_err(RepositoryError::Other)?;
213
214        Ok(rows
215            .into_iter()
216            .filter_map(|r| r.pid.map(|pid| AgentServerIdPidRow { name: r.name, pid }))
217            .collect())
218    }
219
220    pub async fn remove_agent_service(&self, agent_name: &str) -> Result<(), RepositoryError> {
221        let pool = &self.write_pool;
222
223        sqlx::query!("DELETE FROM services WHERE name = $1", agent_name)
224            .execute(pool.as_ref())
225            .await
226            .context("Failed to remove agent service")
227            .map_err(RepositoryError::Other)?;
228
229        Ok(())
230    }
231
232    pub async fn update_health_status(
233        &self,
234        agent_name: &str,
235        health_status: &str,
236    ) -> Result<(), RepositoryError> {
237        let pool = &self.write_pool;
238
239        sqlx::query!(
240            "UPDATE services SET status = $1, updated_at = CURRENT_TIMESTAMP WHERE name = $2",
241            health_status,
242            agent_name
243        )
244        .execute(pool.as_ref())
245        .await
246        .context("Failed to update agent health status")
247        .map_err(RepositoryError::Other)?;
248
249        Ok(())
250    }
251}