Skip to main content

systemprompt_database/repository/service.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use sqlx::FromRow;
4
5use crate::DbPool;
6
7#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
8pub struct ServiceConfig {
9    pub name: String,
10    pub module_name: String,
11    pub status: String,
12    pub pid: Option<i32>,
13    pub port: i32,
14    pub binary_mtime: Option<i64>,
15    pub created_at: String,
16    pub updated_at: String,
17}
18
/// Borrowed arguments for inserting (or upserting) a row in `services`.
///
/// `port` is taken as a `u16` and widened to `i32` when written, so any
/// valid TCP/UDP port is accepted without loss.
#[derive(Debug)]
pub struct CreateServiceInput<'s> {
    /// Unique service name; an existing row with this name is overwritten.
    pub name: &'s str,
    /// Owning module name (e.g. `"agent"` or `"mcp"`).
    pub module_name: &'s str,
    /// Initial status string to store.
    pub status: &'s str,
    /// Port the service listens on.
    pub port: u16,
    /// Optional modification time of the service binary.
    pub binary_mtime: Option<i64>,
}
27
28#[derive(Debug, Clone)]
29pub struct ServiceRepository {
30    db_pool: DbPool,
31}
32
33impl ServiceRepository {
34    pub const fn new(db_pool: DbPool) -> Self {
35        Self { db_pool }
36    }
37
38    fn get_pool(&self) -> Result<std::sync::Arc<sqlx::PgPool>> {
39        self.db_pool
40            .pool()
41            .ok_or_else(|| anyhow::anyhow!("No database pool available"))
42    }
43
44    pub async fn get_service_by_name(&self, name: &str) -> Result<Option<ServiceConfig>> {
45        let pool = self.get_pool()?;
46        let row = sqlx::query!(
47            r#"
48            SELECT name, module_name, status, pid, port, binary_mtime,
49                   created_at::text as "created_at!", updated_at::text as "updated_at!"
50            FROM services
51            WHERE name = $1
52            "#,
53            name
54        )
55        .fetch_optional(&*pool)
56        .await?;
57
58        Ok(row.map(|r| ServiceConfig {
59            name: r.name,
60            module_name: r.module_name,
61            status: r.status,
62            pid: r.pid,
63            port: r.port,
64            binary_mtime: r.binary_mtime,
65            created_at: r.created_at,
66            updated_at: r.updated_at,
67        }))
68    }
69
70    pub async fn get_all_agent_service_names(&self) -> Result<Vec<String>> {
71        let pool = self.get_pool()?;
72        let rows = sqlx::query!(
73            r#"
74            SELECT name FROM services WHERE module_name = 'agent'
75            "#
76        )
77        .fetch_all(&*pool)
78        .await?;
79
80        Ok(rows.into_iter().map(|r| r.name).collect())
81    }
82
83    pub async fn get_mcp_services(&self) -> Result<Vec<ServiceConfig>> {
84        let pool = self.get_pool()?;
85        let rows = sqlx::query!(
86            r#"
87            SELECT name, module_name, status, pid, port, binary_mtime,
88                   created_at::text as "created_at!", updated_at::text as "updated_at!"
89            FROM services
90            WHERE module_name = 'mcp'
91            ORDER BY name
92            "#
93        )
94        .fetch_all(&*pool)
95        .await?;
96
97        Ok(rows
98            .into_iter()
99            .map(|r| ServiceConfig {
100                name: r.name,
101                module_name: r.module_name,
102                status: r.status,
103                pid: r.pid,
104                port: r.port,
105                binary_mtime: r.binary_mtime,
106                created_at: r.created_at,
107                updated_at: r.updated_at,
108            })
109            .collect())
110    }
111
112    pub async fn create_service(&self, input: CreateServiceInput<'_>) -> Result<()> {
113        let pool = self.get_pool()?;
114        let port_i32 = i32::from(input.port);
115        sqlx::query!(
116            r#"
117            INSERT INTO services (name, module_name, status, port, binary_mtime)
118            VALUES ($1, $2, $3, $4, $5)
119            ON CONFLICT (name) DO UPDATE SET
120              module_name = EXCLUDED.module_name,
121              status = EXCLUDED.status,
122              port = EXCLUDED.port,
123              binary_mtime = EXCLUDED.binary_mtime,
124              updated_at = CURRENT_TIMESTAMP
125            "#,
126            input.name,
127            input.module_name,
128            input.status,
129            port_i32,
130            input.binary_mtime
131        )
132        .execute(&*pool)
133        .await?;
134        Ok(())
135    }
136
137    pub async fn update_service_status(&self, service_name: &str, status: &str) -> Result<()> {
138        let pool = self.get_pool()?;
139        sqlx::query!(
140            r#"
141            UPDATE services SET status = $1, updated_at = CURRENT_TIMESTAMP WHERE name = $2
142            "#,
143            status,
144            service_name
145        )
146        .execute(&*pool)
147        .await?;
148        Ok(())
149    }
150
151    pub async fn delete_service(&self, service_name: &str) -> Result<()> {
152        let pool = self.get_pool()?;
153        sqlx::query!(
154            r#"
155            DELETE FROM services WHERE name = $1
156            "#,
157            service_name
158        )
159        .execute(&*pool)
160        .await?;
161        Ok(())
162    }
163
164    pub async fn update_service_pid(&self, service_name: &str, pid: i32) -> Result<()> {
165        let pool = self.get_pool()?;
166        sqlx::query!(
167            r#"
168            UPDATE services SET pid = $1, updated_at = CURRENT_TIMESTAMP WHERE name = $2
169            "#,
170            pid,
171            service_name
172        )
173        .execute(&*pool)
174        .await?;
175        Ok(())
176    }
177
178    pub async fn clear_service_pid(&self, service_name: &str) -> Result<()> {
179        let pool = self.get_pool()?;
180        sqlx::query!(
181            r#"
182            UPDATE services SET pid = NULL, updated_at = CURRENT_TIMESTAMP WHERE name = $1
183            "#,
184            service_name
185        )
186        .execute(&*pool)
187        .await?;
188        Ok(())
189    }
190
191    pub async fn get_all_running_services(&self) -> Result<Vec<ServiceConfig>> {
192        let pool = self.get_pool()?;
193        let rows = sqlx::query!(
194            r#"
195            SELECT name, module_name, status, pid, port, binary_mtime,
196                   created_at::text as "created_at!", updated_at::text as "updated_at!"
197            FROM services
198            WHERE status = 'running'
199            ORDER BY name
200            "#
201        )
202        .fetch_all(&*pool)
203        .await?;
204
205        Ok(rows
206            .into_iter()
207            .map(|r| ServiceConfig {
208                name: r.name,
209                module_name: r.module_name,
210                status: r.status,
211                pid: r.pid,
212                port: r.port,
213                binary_mtime: r.binary_mtime,
214                created_at: r.created_at,
215                updated_at: r.updated_at,
216            })
217            .collect())
218    }
219
220    pub async fn count_running_services(&self, module_name: &str) -> Result<usize> {
221        let pool = self.get_pool()?;
222        let row = sqlx::query!(
223            r#"
224            SELECT COUNT(*) as "count!" FROM services WHERE module_name = $1 AND status = 'running'
225            "#,
226            module_name
227        )
228        .fetch_one(&*pool)
229        .await?;
230
231        Ok(usize::try_from(row.count).unwrap_or(0))
232    }
233
234    pub async fn mark_service_crashed(&self, service_name: &str) -> Result<()> {
235        let pool = self.get_pool()?;
236        sqlx::query!(
237            r#"
238            UPDATE services SET status = 'error', pid = NULL, updated_at = CURRENT_TIMESTAMP WHERE name = $1
239            "#,
240            service_name
241        )
242        .execute(&*pool)
243        .await?;
244        Ok(())
245    }
246
247    pub async fn update_service_stopped(&self, service_name: &str) -> Result<()> {
248        let pool = self.get_pool()?;
249        sqlx::query!(
250            r#"
251            UPDATE services
252            SET status = 'stopped', pid = NULL, updated_at = CURRENT_TIMESTAMP
253            WHERE name = $1
254            "#,
255            service_name
256        )
257        .execute(&*pool)
258        .await?;
259        Ok(())
260    }
261
262    pub async fn get_running_services_with_pid(&self) -> Result<Vec<ServiceConfig>> {
263        self.get_all_running_services().await
264    }
265
266    pub async fn get_services_by_type(&self, module_name: &str) -> Result<Vec<ServiceConfig>> {
267        let pool = self.get_pool()?;
268        let rows = sqlx::query!(
269            r#"
270            SELECT name, module_name, status, pid, port, binary_mtime,
271                   created_at::text as "created_at!", updated_at::text as "updated_at!"
272            FROM services
273            WHERE module_name = $1
274            ORDER BY name
275            "#,
276            module_name
277        )
278        .fetch_all(&*pool)
279        .await?;
280
281        Ok(rows
282            .into_iter()
283            .map(|r| ServiceConfig {
284                name: r.name,
285                module_name: r.module_name,
286                status: r.status,
287                pid: r.pid,
288                port: r.port,
289                binary_mtime: r.binary_mtime,
290                created_at: r.created_at,
291                updated_at: r.updated_at,
292            })
293            .collect())
294    }
295
296    pub async fn cleanup_stale_entries(&self) -> Result<u64> {
297        let pool = self.get_pool()?;
298        let result = sqlx::query!(
299            r#"
300            DELETE FROM services
301            WHERE status IN ('error', 'crashed')
302               OR (status = 'running' AND pid IS NULL)
303            "#
304        )
305        .execute(&*pool)
306        .await?;
307        Ok(result.rows_affected())
308    }
309}