Skip to main content

systemprompt_agent/repository/content/
agent.rs

1use crate::models::{Agent, AgentRow};
2use anyhow::{Context, Result};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::{AgentId, CategoryId, SourceId};
7
8#[derive(Debug)]
9pub struct AgentRepository {
10    pool: Arc<PgPool>,
11    write_pool: Arc<PgPool>,
12}
13
14impl AgentRepository {
15    pub fn new(db: &DbPool) -> Result<Self> {
16        let pool = db.pool_arc().context("PostgreSQL pool not available")?;
17        let write_pool = db
18            .write_pool_arc()
19            .context("Write PostgreSQL pool not available")?;
20        Ok(Self { pool, write_pool })
21    }
22
23    pub async fn create(&self, agent: &Agent) -> Result<()> {
24        let agent_id_str = agent.id.as_str();
25        let category_id = agent.category_id.as_ref().map(ToString::to_string);
26        let source_id_str = agent.source_id.as_str();
27
28        sqlx::query!(
29            "INSERT INTO agents (agent_id, name, display_name, description, version,
30             system_prompt, enabled, port, endpoint, dev_only, is_primary, is_default,
31             tags, category_id, source_id, provider, model, mcp_servers, skills, card_json)
32             VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, \
33             $18, $19, $20)",
34            agent_id_str,
35            agent.name,
36            agent.display_name,
37            agent.description,
38            agent.version,
39            agent.system_prompt,
40            agent.enabled,
41            agent.port,
42            agent.endpoint,
43            agent.dev_only,
44            agent.is_primary,
45            agent.is_default,
46            &agent.tags[..],
47            category_id,
48            source_id_str,
49            agent.provider,
50            agent.model,
51            &agent.mcp_servers[..],
52            &agent.skills[..],
53            agent.card_json
54        )
55        .execute(self.write_pool.as_ref())
56        .await
57        .context(format!("Failed to create agent: {}", agent.name))?;
58
59        Ok(())
60    }
61
62    pub async fn get_by_agent_id(&self, agent_id: &AgentId) -> Result<Option<Agent>> {
63        let agent_id_str = agent_id.as_str();
64
65        let row = sqlx::query_as!(
66            AgentRow,
67            r#"SELECT
68                agent_id as "agent_id!: AgentId",
69                name as "name!",
70                display_name as "display_name!",
71                description as "description!",
72                version as "version!",
73                system_prompt,
74                enabled as "enabled!",
75                port as "port!",
76                endpoint as "endpoint!",
77                dev_only as "dev_only!",
78                is_primary as "is_primary!",
79                is_default as "is_default!",
80                tags,
81                category_id as "category_id?: CategoryId",
82                source_id as "source_id!: SourceId",
83                provider,
84                model,
85                mcp_servers,
86                skills,
87                card_json as "card_json!",
88                created_at as "created_at!",
89                updated_at as "updated_at!"
90            FROM agents WHERE agent_id = $1"#,
91            agent_id_str
92        )
93        .fetch_optional(self.pool.as_ref())
94        .await
95        .context(format!("Failed to get agent by id: {agent_id}"))?;
96
97        Ok(row.map(agent_from_row))
98    }
99
100    pub async fn get_by_name(&self, name: &str) -> Result<Option<Agent>> {
101        let row = sqlx::query_as!(
102            AgentRow,
103            r#"SELECT
104                agent_id as "agent_id!: AgentId",
105                name as "name!",
106                display_name as "display_name!",
107                description as "description!",
108                version as "version!",
109                system_prompt,
110                enabled as "enabled!",
111                port as "port!",
112                endpoint as "endpoint!",
113                dev_only as "dev_only!",
114                is_primary as "is_primary!",
115                is_default as "is_default!",
116                tags,
117                category_id as "category_id?: CategoryId",
118                source_id as "source_id!: SourceId",
119                provider,
120                model,
121                mcp_servers,
122                skills,
123                card_json as "card_json!",
124                created_at as "created_at!",
125                updated_at as "updated_at!"
126            FROM agents WHERE name = $1"#,
127            name
128        )
129        .fetch_optional(self.pool.as_ref())
130        .await
131        .context(format!("Failed to get agent by name: {name}"))?;
132
133        Ok(row.map(agent_from_row))
134    }
135
136    pub async fn list_enabled(&self) -> Result<Vec<Agent>> {
137        let rows = sqlx::query_as!(
138            AgentRow,
139            r#"SELECT
140                agent_id as "agent_id!: AgentId",
141                name as "name!",
142                display_name as "display_name!",
143                description as "description!",
144                version as "version!",
145                system_prompt,
146                enabled as "enabled!",
147                port as "port!",
148                endpoint as "endpoint!",
149                dev_only as "dev_only!",
150                is_primary as "is_primary!",
151                is_default as "is_default!",
152                tags,
153                category_id as "category_id?: CategoryId",
154                source_id as "source_id!: SourceId",
155                provider,
156                model,
157                mcp_servers,
158                skills,
159                card_json as "card_json!",
160                created_at as "created_at!",
161                updated_at as "updated_at!"
162            FROM agents WHERE enabled = true ORDER BY name ASC"#
163        )
164        .fetch_all(self.pool.as_ref())
165        .await
166        .context("Failed to list enabled agents")?;
167
168        Ok(rows.into_iter().map(agent_from_row).collect())
169    }
170
171    pub async fn list_all(&self) -> Result<Vec<Agent>> {
172        let rows = sqlx::query_as!(
173            AgentRow,
174            r#"SELECT
175                agent_id as "agent_id!: AgentId",
176                name as "name!",
177                display_name as "display_name!",
178                description as "description!",
179                version as "version!",
180                system_prompt,
181                enabled as "enabled!",
182                port as "port!",
183                endpoint as "endpoint!",
184                dev_only as "dev_only!",
185                is_primary as "is_primary!",
186                is_default as "is_default!",
187                tags,
188                category_id as "category_id?: CategoryId",
189                source_id as "source_id!: SourceId",
190                provider,
191                model,
192                mcp_servers,
193                skills,
194                card_json as "card_json!",
195                created_at as "created_at!",
196                updated_at as "updated_at!"
197            FROM agents ORDER BY name ASC"#
198        )
199        .fetch_all(self.pool.as_ref())
200        .await
201        .context("Failed to list all agents")?;
202
203        Ok(rows.into_iter().map(agent_from_row).collect())
204    }
205
206    pub async fn update(&self, agent_id: &AgentId, agent: &Agent) -> Result<()> {
207        let agent_id_str = agent_id.as_str();
208
209        sqlx::query!(
210            "UPDATE agents SET name = $1, display_name = $2, description = $3, version = $4,
211             system_prompt = $5, enabled = $6, port = $7, endpoint = $8, dev_only = $9,
212             is_primary = $10, is_default = $11, tags = $12, provider = $13, model = $14,
213             mcp_servers = $15, skills = $16, card_json = $17, updated_at = CURRENT_TIMESTAMP
214             WHERE agent_id = $18",
215            agent.name,
216            agent.display_name,
217            agent.description,
218            agent.version,
219            agent.system_prompt,
220            agent.enabled,
221            agent.port,
222            agent.endpoint,
223            agent.dev_only,
224            agent.is_primary,
225            agent.is_default,
226            &agent.tags[..],
227            agent.provider,
228            agent.model,
229            &agent.mcp_servers[..],
230            &agent.skills[..],
231            agent.card_json,
232            agent_id_str
233        )
234        .execute(self.write_pool.as_ref())
235        .await
236        .context(format!("Failed to update agent: {}", agent.name))?;
237
238        Ok(())
239    }
240
241    pub async fn delete(&self, agent_id: &AgentId) -> Result<()> {
242        let agent_id_str = agent_id.as_str();
243
244        sqlx::query!("DELETE FROM agents WHERE agent_id = $1", agent_id_str)
245            .execute(self.write_pool.as_ref())
246            .await
247            .context(format!("Failed to delete agent: {agent_id}"))?;
248
249        Ok(())
250    }
251}
252
253fn agent_from_row(row: AgentRow) -> Agent {
254    Agent {
255        id: row.agent_id,
256        name: row.name,
257        display_name: row.display_name,
258        description: row.description,
259        version: row.version,
260        system_prompt: row.system_prompt,
261        enabled: row.enabled,
262        port: row.port,
263        endpoint: row.endpoint,
264        dev_only: row.dev_only,
265        is_primary: row.is_primary,
266        is_default: row.is_default,
267        tags: row.tags.unwrap_or_else(Vec::new),
268        category_id: row.category_id,
269        source_id: row.source_id,
270        provider: row.provider,
271        model: row.model,
272        mcp_servers: row.mcp_servers.unwrap_or_else(Vec::new),
273        skills: row.skills.unwrap_or_else(Vec::new),
274        card_json: row.card_json,
275        created_at: row.created_at,
276        updated_at: row.updated_at,
277    }
278}