Skip to main content

systemprompt_database/repository/
service.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use sqlx::{FromRow, PgPool};
6
7use crate::DbPool;
8
9#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
10pub struct ServiceConfig {
11    pub name: String,
12    pub module_name: String,
13    pub status: String,
14    pub pid: Option<i32>,
15    pub port: i32,
16    pub binary_mtime: Option<i64>,
17    pub created_at: String,
18    pub updated_at: String,
19}
20
/// Borrowed parameters for [`ServiceRepository::create_service`].
///
/// Every field is `Copy`, so the struct is `Copy` as well: callers can keep
/// using an input after handing it to `create_service` (e.g. for a retry or
/// for logging) without rebuilding it.
#[derive(Debug, Clone, Copy)]
pub struct CreateServiceInput<'a> {
    /// Service name; the upsert in `create_service` conflicts on this column.
    pub name: &'a str,
    /// Module the service belongs to.
    pub module_name: &'a str,
    /// Initial status string to store.
    pub status: &'a str,
    /// Port of the service; widened to `i32` before being bound to the query.
    pub port: u16,
    /// Optional value for the `binary_mtime` column; stored as NULL when `None`.
    pub binary_mtime: Option<i64>,
}
29
30#[derive(Debug, Clone)]
31pub struct ServiceRepository {
32    pool: Arc<PgPool>,
33    write_pool: Arc<PgPool>,
34}
35
36impl ServiceRepository {
37    pub fn new(db: &DbPool) -> Result<Self> {
38        let pool = db.pool_arc()?;
39        let write_pool = db.write_pool_arc()?;
40        Ok(Self { pool, write_pool })
41    }
42
43    pub async fn get_service_by_name(&self, name: &str) -> Result<Option<ServiceConfig>> {
44        let row = sqlx::query!(
45            r#"
46            SELECT name, module_name, status, pid, port, binary_mtime,
47                   created_at::text as "created_at!", updated_at::text as "updated_at!"
48            FROM services
49            WHERE name = $1
50            "#,
51            name
52        )
53        .fetch_optional(&*self.pool)
54        .await?;
55
56        Ok(row.map(|r| ServiceConfig {
57            name: r.name,
58            module_name: r.module_name,
59            status: r.status,
60            pid: r.pid,
61            port: r.port,
62            binary_mtime: r.binary_mtime,
63            created_at: r.created_at,
64            updated_at: r.updated_at,
65        }))
66    }
67
68    pub async fn get_all_agent_service_names(&self) -> Result<Vec<String>> {
69        let rows = sqlx::query!(
70            r#"
71            SELECT name FROM services WHERE module_name = 'agent'
72            "#
73        )
74        .fetch_all(&*self.pool)
75        .await?;
76
77        Ok(rows.into_iter().map(|r| r.name).collect())
78    }
79
80    pub async fn get_mcp_services(&self) -> Result<Vec<ServiceConfig>> {
81        let rows = sqlx::query!(
82            r#"
83            SELECT name, module_name, status, pid, port, binary_mtime,
84                   created_at::text as "created_at!", updated_at::text as "updated_at!"
85            FROM services
86            WHERE module_name = 'mcp'
87            ORDER BY name
88            "#
89        )
90        .fetch_all(&*self.pool)
91        .await?;
92
93        Ok(rows
94            .into_iter()
95            .map(|r| ServiceConfig {
96                name: r.name,
97                module_name: r.module_name,
98                status: r.status,
99                pid: r.pid,
100                port: r.port,
101                binary_mtime: r.binary_mtime,
102                created_at: r.created_at,
103                updated_at: r.updated_at,
104            })
105            .collect())
106    }
107
108    pub async fn create_service(&self, input: CreateServiceInput<'_>) -> Result<()> {
109        let port_i32 = i32::from(input.port);
110        sqlx::query!(
111            r#"
112            INSERT INTO services (name, module_name, status, port, binary_mtime)
113            VALUES ($1, $2, $3, $4, $5)
114            ON CONFLICT (name) DO UPDATE SET
115              module_name = EXCLUDED.module_name,
116              status = EXCLUDED.status,
117              port = EXCLUDED.port,
118              binary_mtime = EXCLUDED.binary_mtime,
119              updated_at = CURRENT_TIMESTAMP
120            "#,
121            input.name,
122            input.module_name,
123            input.status,
124            port_i32,
125            input.binary_mtime
126        )
127        .execute(&*self.write_pool)
128        .await?;
129        Ok(())
130    }
131
132    pub async fn update_service_status(&self, service_name: &str, status: &str) -> Result<()> {
133        sqlx::query!(
134            r#"
135            UPDATE services SET status = $1, updated_at = CURRENT_TIMESTAMP WHERE name = $2
136            "#,
137            status,
138            service_name
139        )
140        .execute(&*self.write_pool)
141        .await?;
142        Ok(())
143    }
144
145    pub async fn delete_service(&self, service_name: &str) -> Result<()> {
146        sqlx::query!(
147            r#"
148            DELETE FROM services WHERE name = $1
149            "#,
150            service_name
151        )
152        .execute(&*self.write_pool)
153        .await?;
154        Ok(())
155    }
156
157    pub async fn update_service_pid(&self, service_name: &str, pid: i32) -> Result<()> {
158        sqlx::query!(
159            r#"
160            UPDATE services SET pid = $1, updated_at = CURRENT_TIMESTAMP WHERE name = $2
161            "#,
162            pid,
163            service_name
164        )
165        .execute(&*self.write_pool)
166        .await?;
167        Ok(())
168    }
169
170    pub async fn clear_service_pid(&self, service_name: &str) -> Result<()> {
171        sqlx::query!(
172            r#"
173            UPDATE services SET pid = NULL, updated_at = CURRENT_TIMESTAMP WHERE name = $1
174            "#,
175            service_name
176        )
177        .execute(&*self.write_pool)
178        .await?;
179        Ok(())
180    }
181
182    pub async fn get_all_running_services(&self) -> Result<Vec<ServiceConfig>> {
183        let rows = sqlx::query!(
184            r#"
185            SELECT name, module_name, status, pid, port, binary_mtime,
186                   created_at::text as "created_at!", updated_at::text as "updated_at!"
187            FROM services
188            WHERE status = 'running'
189            ORDER BY name
190            "#
191        )
192        .fetch_all(&*self.pool)
193        .await?;
194
195        Ok(rows
196            .into_iter()
197            .map(|r| ServiceConfig {
198                name: r.name,
199                module_name: r.module_name,
200                status: r.status,
201                pid: r.pid,
202                port: r.port,
203                binary_mtime: r.binary_mtime,
204                created_at: r.created_at,
205                updated_at: r.updated_at,
206            })
207            .collect())
208    }
209
210    pub async fn count_running_services(&self, module_name: &str) -> Result<usize> {
211        let row = sqlx::query!(
212            r#"
213            SELECT COUNT(*) as "count!" FROM services WHERE module_name = $1 AND status = 'running'
214            "#,
215            module_name
216        )
217        .fetch_one(&*self.pool)
218        .await?;
219
220        Ok(usize::try_from(row.count).unwrap_or(0))
221    }
222
223    pub async fn mark_service_crashed(&self, service_name: &str) -> Result<()> {
224        sqlx::query!(
225            r#"
226            UPDATE services SET status = 'error', pid = NULL, updated_at = CURRENT_TIMESTAMP WHERE name = $1
227            "#,
228            service_name
229        )
230        .execute(&*self.write_pool)
231        .await?;
232        Ok(())
233    }
234
235    pub async fn update_service_stopped(&self, service_name: &str) -> Result<()> {
236        sqlx::query!(
237            r#"
238            UPDATE services
239            SET status = 'stopped', pid = NULL, updated_at = CURRENT_TIMESTAMP
240            WHERE name = $1
241            "#,
242            service_name
243        )
244        .execute(&*self.write_pool)
245        .await?;
246        Ok(())
247    }
248
249    pub async fn get_running_services_with_pid(&self) -> Result<Vec<ServiceConfig>> {
250        self.get_all_running_services().await
251    }
252
253    pub async fn get_services_by_type(&self, module_name: &str) -> Result<Vec<ServiceConfig>> {
254        let rows = sqlx::query!(
255            r#"
256            SELECT name, module_name, status, pid, port, binary_mtime,
257                   created_at::text as "created_at!", updated_at::text as "updated_at!"
258            FROM services
259            WHERE module_name = $1
260            ORDER BY name
261            "#,
262            module_name
263        )
264        .fetch_all(&*self.pool)
265        .await?;
266
267        Ok(rows
268            .into_iter()
269            .map(|r| ServiceConfig {
270                name: r.name,
271                module_name: r.module_name,
272                status: r.status,
273                pid: r.pid,
274                port: r.port,
275                binary_mtime: r.binary_mtime,
276                created_at: r.created_at,
277                updated_at: r.updated_at,
278            })
279            .collect())
280    }
281
282    pub async fn cleanup_stale_entries(&self) -> Result<u64> {
283        let result = sqlx::query!(
284            r#"
285            DELETE FROM services
286            WHERE status IN ('error', 'crashed')
287               OR (status = 'running' AND pid IS NULL)
288            "#
289        )
290        .execute(&*self.write_pool)
291        .await?;
292        Ok(result.rows_affected())
293    }
294}