// spec_ai_config/config/registry.rs
use std::collections::HashMap;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};

use anyhow::{anyhow, Context, Result};

use super::agent::AgentProfile;
use crate::persistence::Persistence;
/// Policy key under which the name of the active agent is persisted.
const ACTIVE_AGENT_KEY: &str = "active_agent";

10#[derive(Clone)]
12pub struct AgentRegistry {
13 agents: Arc<RwLock<HashMap<String, AgentProfile>>>,
14 active_agent: Arc<RwLock<Option<String>>>,
15 persistence: Persistence,
16}
17
18impl AgentRegistry {
19 pub fn new(agents: HashMap<String, AgentProfile>, persistence: Persistence) -> Self {
21 Self {
22 agents: Arc::new(RwLock::new(agents)),
23 active_agent: Arc::new(RwLock::new(None)),
24 persistence,
25 }
26 }
27
28 pub fn init(&self) -> Result<()> {
30 if let Some(entry) = self.persistence.policy_get(ACTIVE_AGENT_KEY)? {
32 if let Some(agent_name) = entry.value.as_str() {
33 let agents = self.agents.read().unwrap();
35 if agents.contains_key(agent_name) {
36 drop(agents);
37 let mut active = self.active_agent.write().unwrap();
38 *active = Some(agent_name.to_string());
39 }
40 }
43 }
44 Ok(())
45 }
46
47 pub fn set_active(&self, name: &str) -> Result<()> {
49 let agents = self.agents.read().unwrap();
51 if !agents.contains_key(name) {
52 return Err(anyhow!(
53 "Agent '{}' not found. Available agents: {}",
54 name,
55 if agents.is_empty() {
56 "none".to_string()
57 } else {
58 agents
59 .keys()
60 .map(|s| s.as_str())
61 .collect::<Vec<_>>()
62 .join(", ")
63 }
64 ));
65 }
66 drop(agents);
67
68 {
70 let mut active = self.active_agent.write().unwrap();
71 *active = Some(name.to_string());
72 }
73
74 let value = serde_json::json!(name);
76 self.persistence
77 .policy_upsert(ACTIVE_AGENT_KEY, &value)
78 .context("persisting active agent")?;
79
80 Ok(())
81 }
82
83 pub fn active(&self) -> Result<Option<(String, AgentProfile)>> {
85 let active_name = {
86 let active = self.active_agent.read().unwrap();
87 active.clone()
88 };
89
90 if let Some(name) = active_name {
91 let agents = self.agents.read().unwrap();
92 if let Some(profile) = agents.get(&name) {
93 Ok(Some((name.clone(), profile.clone())))
94 } else {
95 Ok(None)
96 }
97 } else {
98 Ok(None)
99 }
100 }
101
102 pub fn active_name(&self) -> Option<String> {
104 let active = self.active_agent.read().unwrap();
105 active.clone()
106 }
107
108 pub fn list(&self) -> Vec<String> {
110 let agents = self.agents.read().unwrap();
111 let mut names: Vec<_> = agents.keys().cloned().collect();
112 names.sort();
113 names
114 }
115
116 pub fn get(&self, name: &str) -> Option<AgentProfile> {
118 let agents = self.agents.read().unwrap();
119 agents.get(name).cloned()
120 }
121
122 pub fn upsert(&self, name: String, profile: AgentProfile) -> Result<()> {
124 profile
125 .validate()
126 .with_context(|| format!("validating agent profile '{}'", name))?;
127
128 let mut agents = self.agents.write().unwrap();
129 agents.insert(name, profile);
130 Ok(())
131 }
132
133 pub fn remove(&self, name: &str) -> Result<()> {
135 let active_name = self.active_name();
137 if active_name.as_deref() == Some(name) {
138 return Err(anyhow!(
139 "Cannot remove '{}' because it is the currently active agent. \
140 Please switch to a different agent first.",
141 name
142 ));
143 }
144
145 let mut agents = self.agents.write().unwrap();
146 if agents.remove(name).is_none() {
147 return Err(anyhow!("Agent '{}' not found", name));
148 }
149
150 Ok(())
151 }
152
153 pub fn exists(&self, name: &str) -> bool {
155 let agents = self.agents.read().unwrap();
156 agents.contains_key(name)
157 }
158
159 pub fn count(&self) -> usize {
161 let agents = self.agents.read().unwrap();
162 agents.len()
163 }
164
165 pub fn persistence(&self) -> &Persistence {
167 &self.persistence
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use tempfile::TempDir;

    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test prompt".to_string()),
            ..Default::default()
        }
    }

    /// Builds a registry backed by the database at `db_path`, containing a test profile for
    /// each name in `names`.
    fn open_registry(db_path: &Path, names: &[&str]) -> AgentRegistry {
        let persistence = Persistence::new(db_path).unwrap();
        let agents: HashMap<String, AgentProfile> = names
            .iter()
            .map(|&name| (name.to_owned(), create_test_profile()))
            .collect();
        AgentRegistry::new(agents, persistence)
    }

    /// Builds a registry whose database lives in a fresh temporary directory.
    ///
    /// Bind the returned `TempDir` to a named variable (e.g. `_dir`, never `_`), otherwise it
    /// is dropped immediately and the directory is deleted while the registry still uses it.
    fn temp_registry(names: &[&str]) -> (TempDir, AgentRegistry) {
        let temp_dir = TempDir::new().unwrap();
        let registry = open_registry(&temp_dir.path().join("test.duckdb"), names);
        (temp_dir, registry)
    }

    #[test]
    fn test_new_registry() {
        let (_dir, registry) = temp_registry(&["agent1"]);
        assert_eq!(registry.count(), 1);
        assert!(registry.exists("agent1"));
    }

    #[test]
    fn test_set_and_get_active() {
        let (_dir, registry) = temp_registry(&["agent1", "agent2"]);
        registry.init().unwrap();

        assert!(registry.active().unwrap().is_none());

        registry.set_active("agent1").unwrap();
        let (name, _profile) = registry
            .active()
            .unwrap()
            .expect("an agent should be active");
        assert_eq!(name, "agent1");
        assert_eq!(registry.active_name(), Some("agent1".to_string()));
    }

    #[test]
    fn test_set_active_nonexistent_agent() {
        let (_dir, registry) = temp_registry(&[]);
        registry.init().unwrap();

        let message = registry.set_active("nonexistent").unwrap_err().to_string();
        assert!(
            message.contains("'nonexistent' not found"),
            "unexpected error: {}",
            message
        );
        assert!(
            message.contains("Available agents: none"),
            "unexpected error: {}",
            message
        );
        assert_eq!(registry.active_name(), None);
    }

    #[test]
    fn test_list_agents() {
        let (_dir, registry) = temp_registry(&["zebra", "alpha", "beta"]);
        assert_eq!(registry.list(), vec!["alpha", "beta", "zebra"]);
    }

    #[test]
    fn test_upsert_and_remove() {
        let (_dir, registry) = temp_registry(&[]);
        registry.init().unwrap();

        registry
            .upsert("new_agent".to_string(), create_test_profile())
            .unwrap();
        assert!(registry.exists("new_agent"));
        assert_eq!(registry.count(), 1);

        registry.remove("new_agent").unwrap();
        assert!(!registry.exists("new_agent"));
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_upsert_replaces_existing_profile() {
        let (_dir, registry) = temp_registry(&["agent1"]);

        let updated = AgentProfile {
            prompt: Some("Updated prompt".to_string()),
            ..Default::default()
        };
        registry.upsert("agent1".to_string(), updated).unwrap();

        assert_eq!(registry.count(), 1);
        let stored = registry.get("agent1").expect("agent1 should still exist");
        assert_eq!(stored.prompt.as_deref(), Some("Updated prompt"));
    }

    #[test]
    fn test_remove_nonexistent_agent() {
        let (_dir, registry) = temp_registry(&["agent1"]);

        assert!(registry.remove("ghost").is_err());
        assert_eq!(registry.count(), 1);
    }

    #[test]
    fn test_cannot_remove_active_agent() {
        let (_dir, registry) = temp_registry(&["agent1"]);
        registry.init().unwrap();
        registry.set_active("agent1").unwrap();

        assert!(registry.remove("agent1").is_err());
        assert!(registry.exists("agent1"));
    }

    #[test]
    fn test_persistence_across_restarts() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.duckdb");

        {
            let registry = open_registry(&db_path, &["agent1"]);
            registry.init().unwrap();
            registry.set_active("agent1").unwrap();
        }

        let registry = open_registry(&db_path, &["agent1"]);
        registry.init().unwrap();
        assert_eq!(registry.active_name(), Some("agent1".to_string()));
    }

    #[test]
    fn test_init_ignores_persisted_agent_that_no_longer_exists() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.duckdb");

        {
            let registry = open_registry(&db_path, &["agent1"]);
            registry.init().unwrap();
            registry.set_active("agent1").unwrap();
        }

        let registry = open_registry(&db_path, &["agent2"]);
        registry.init().unwrap();
        assert_eq!(registry.active_name(), None);
    }
}