1use crate::models::{Agent, AgentRow};
2use anyhow::{Context, Result};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::{AgentId, CategoryId, SourceId};
7
8#[derive(Debug)]
9pub struct AgentRepository {
10 pool: Arc<PgPool>,
11 write_pool: Arc<PgPool>,
12}
13
14impl AgentRepository {
15 pub fn new(db: &DbPool) -> Result<Self> {
16 let pool = db.pool_arc().context("PostgreSQL pool not available")?;
17 let write_pool = db
18 .write_pool_arc()
19 .context("Write PostgreSQL pool not available")?;
20 Ok(Self { pool, write_pool })
21 }
22
23 pub async fn create(&self, agent: &Agent) -> Result<()> {
24 let agent_id_str = agent.id.as_str();
25 let category_id = agent.category_id.as_ref().map(ToString::to_string);
26 let source_id_str = agent.source_id.as_str();
27
28 sqlx::query!(
29 "INSERT INTO agents (agent_id, name, display_name, description, version,
30 system_prompt, enabled, port, endpoint, dev_only, is_primary, is_default,
31 tags, category_id, source_id, provider, model, mcp_servers, skills, card_json)
32 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, \
33 $18, $19, $20)",
34 agent_id_str,
35 agent.name,
36 agent.display_name,
37 agent.description,
38 agent.version,
39 agent.system_prompt,
40 agent.enabled,
41 agent.port,
42 agent.endpoint,
43 agent.dev_only,
44 agent.is_primary,
45 agent.is_default,
46 &agent.tags[..],
47 category_id,
48 source_id_str,
49 agent.provider,
50 agent.model,
51 &agent.mcp_servers[..],
52 &agent.skills[..],
53 agent.card_json
54 )
55 .execute(self.write_pool.as_ref())
56 .await
57 .context(format!("Failed to create agent: {}", agent.name))?;
58
59 Ok(())
60 }
61
62 pub async fn get_by_agent_id(&self, agent_id: &AgentId) -> Result<Option<Agent>> {
63 let agent_id_str = agent_id.as_str();
64
65 let row = sqlx::query_as!(
66 AgentRow,
67 r#"SELECT
68 agent_id as "agent_id!: AgentId",
69 name as "name!",
70 display_name as "display_name!",
71 description as "description!",
72 version as "version!",
73 system_prompt,
74 enabled as "enabled!",
75 port as "port!",
76 endpoint as "endpoint!",
77 dev_only as "dev_only!",
78 is_primary as "is_primary!",
79 is_default as "is_default!",
80 tags,
81 category_id as "category_id?: CategoryId",
82 source_id as "source_id!: SourceId",
83 provider,
84 model,
85 mcp_servers,
86 skills,
87 card_json as "card_json!",
88 created_at as "created_at!",
89 updated_at as "updated_at!"
90 FROM agents WHERE agent_id = $1"#,
91 agent_id_str
92 )
93 .fetch_optional(self.pool.as_ref())
94 .await
95 .context(format!("Failed to get agent by id: {agent_id}"))?;
96
97 Ok(row.map(agent_from_row))
98 }
99
100 pub async fn get_by_name(&self, name: &str) -> Result<Option<Agent>> {
101 let row = sqlx::query_as!(
102 AgentRow,
103 r#"SELECT
104 agent_id as "agent_id!: AgentId",
105 name as "name!",
106 display_name as "display_name!",
107 description as "description!",
108 version as "version!",
109 system_prompt,
110 enabled as "enabled!",
111 port as "port!",
112 endpoint as "endpoint!",
113 dev_only as "dev_only!",
114 is_primary as "is_primary!",
115 is_default as "is_default!",
116 tags,
117 category_id as "category_id?: CategoryId",
118 source_id as "source_id!: SourceId",
119 provider,
120 model,
121 mcp_servers,
122 skills,
123 card_json as "card_json!",
124 created_at as "created_at!",
125 updated_at as "updated_at!"
126 FROM agents WHERE name = $1"#,
127 name
128 )
129 .fetch_optional(self.pool.as_ref())
130 .await
131 .context(format!("Failed to get agent by name: {name}"))?;
132
133 Ok(row.map(agent_from_row))
134 }
135
136 pub async fn list_enabled(&self) -> Result<Vec<Agent>> {
137 let rows = sqlx::query_as!(
138 AgentRow,
139 r#"SELECT
140 agent_id as "agent_id!: AgentId",
141 name as "name!",
142 display_name as "display_name!",
143 description as "description!",
144 version as "version!",
145 system_prompt,
146 enabled as "enabled!",
147 port as "port!",
148 endpoint as "endpoint!",
149 dev_only as "dev_only!",
150 is_primary as "is_primary!",
151 is_default as "is_default!",
152 tags,
153 category_id as "category_id?: CategoryId",
154 source_id as "source_id!: SourceId",
155 provider,
156 model,
157 mcp_servers,
158 skills,
159 card_json as "card_json!",
160 created_at as "created_at!",
161 updated_at as "updated_at!"
162 FROM agents WHERE enabled = true ORDER BY name ASC"#
163 )
164 .fetch_all(self.pool.as_ref())
165 .await
166 .context("Failed to list enabled agents")?;
167
168 Ok(rows.into_iter().map(agent_from_row).collect())
169 }
170
171 pub async fn list_all(&self) -> Result<Vec<Agent>> {
172 let rows = sqlx::query_as!(
173 AgentRow,
174 r#"SELECT
175 agent_id as "agent_id!: AgentId",
176 name as "name!",
177 display_name as "display_name!",
178 description as "description!",
179 version as "version!",
180 system_prompt,
181 enabled as "enabled!",
182 port as "port!",
183 endpoint as "endpoint!",
184 dev_only as "dev_only!",
185 is_primary as "is_primary!",
186 is_default as "is_default!",
187 tags,
188 category_id as "category_id?: CategoryId",
189 source_id as "source_id!: SourceId",
190 provider,
191 model,
192 mcp_servers,
193 skills,
194 card_json as "card_json!",
195 created_at as "created_at!",
196 updated_at as "updated_at!"
197 FROM agents ORDER BY name ASC"#
198 )
199 .fetch_all(self.pool.as_ref())
200 .await
201 .context("Failed to list all agents")?;
202
203 Ok(rows.into_iter().map(agent_from_row).collect())
204 }
205
206 pub async fn update(&self, agent_id: &AgentId, agent: &Agent) -> Result<()> {
207 let agent_id_str = agent_id.as_str();
208
209 sqlx::query!(
210 "UPDATE agents SET name = $1, display_name = $2, description = $3, version = $4,
211 system_prompt = $5, enabled = $6, port = $7, endpoint = $8, dev_only = $9,
212 is_primary = $10, is_default = $11, tags = $12, provider = $13, model = $14,
213 mcp_servers = $15, skills = $16, card_json = $17, updated_at = CURRENT_TIMESTAMP
214 WHERE agent_id = $18",
215 agent.name,
216 agent.display_name,
217 agent.description,
218 agent.version,
219 agent.system_prompt,
220 agent.enabled,
221 agent.port,
222 agent.endpoint,
223 agent.dev_only,
224 agent.is_primary,
225 agent.is_default,
226 &agent.tags[..],
227 agent.provider,
228 agent.model,
229 &agent.mcp_servers[..],
230 &agent.skills[..],
231 agent.card_json,
232 agent_id_str
233 )
234 .execute(self.write_pool.as_ref())
235 .await
236 .context(format!("Failed to update agent: {}", agent.name))?;
237
238 Ok(())
239 }
240
241 pub async fn delete(&self, agent_id: &AgentId) -> Result<()> {
242 let agent_id_str = agent_id.as_str();
243
244 sqlx::query!("DELETE FROM agents WHERE agent_id = $1", agent_id_str)
245 .execute(self.write_pool.as_ref())
246 .await
247 .context(format!("Failed to delete agent: {agent_id}"))?;
248
249 Ok(())
250 }
251}
252
253fn agent_from_row(row: AgentRow) -> Agent {
254 Agent {
255 id: row.agent_id,
256 name: row.name,
257 display_name: row.display_name,
258 description: row.description,
259 version: row.version,
260 system_prompt: row.system_prompt,
261 enabled: row.enabled,
262 port: row.port,
263 endpoint: row.endpoint,
264 dev_only: row.dev_only,
265 is_primary: row.is_primary,
266 is_default: row.is_default,
267 tags: row.tags.unwrap_or_else(Vec::new),
268 category_id: row.category_id,
269 source_id: row.source_id,
270 provider: row.provider,
271 model: row.model,
272 mcp_servers: row.mcp_servers.unwrap_or_else(Vec::new),
273 skills: row.skills.unwrap_or_else(Vec::new),
274 card_json: row.card_json,
275 created_at: row.created_at,
276 updated_at: row.updated_at,
277 }
278}