spec_ai_config/config/
registry.rs

1use anyhow::{anyhow, Context, Result};
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use super::agent::AgentProfile;
6use crate::persistence::Persistence;
7
/// Key under which the active agent's name is stored in the persistence
/// layer (the key passed to `policy_get` / `policy_upsert` below).
const ACTIVE_AGENT_KEY: &str = "active_agent";
9
/// Registry for managing agent profiles and tracking the active agent.
///
/// The profile map and the active-agent slot each live behind an
/// `Arc<RwLock<..>>`, so clones produced by `#[derive(Clone)]` share the same
/// in-memory state rather than copying it.
#[derive(Clone)]
pub struct AgentRegistry {
    /// Agent profiles keyed by agent name.
    agents: Arc<RwLock<HashMap<String, AgentProfile>>>,
    /// Name of the currently active agent, or `None` when no agent is active.
    active_agent: Arc<RwLock<Option<String>>>,
    /// Persistence handle used to load and store the active agent's name.
    // NOTE(review): presumably a cheap, shared handle when cloned — confirm
    // against `Persistence`'s `Clone` implementation.
    persistence: Persistence,
}
17
18impl AgentRegistry {
19    /// Create a new AgentRegistry with the given agents and persistence
20    pub fn new(agents: HashMap<String, AgentProfile>, persistence: Persistence) -> Self {
21        Self {
22            agents: Arc::new(RwLock::new(agents)),
23            active_agent: Arc::new(RwLock::new(None)),
24            persistence,
25        }
26    }
27
28    /// Initialize the registry by loading the active agent from persistence
29    pub fn init(&self) -> Result<()> {
30        // Load the active agent from persistence if it exists
31        if let Some(entry) = self.persistence.policy_get(ACTIVE_AGENT_KEY)? {
32            if let Some(agent_name) = entry.value.as_str() {
33                // Validate that this agent still exists in the registry
34                let agents = self.agents.read().unwrap();
35                if agents.contains_key(agent_name) {
36                    drop(agents);
37                    let mut active = self.active_agent.write().unwrap();
38                    *active = Some(agent_name.to_string());
39                }
40                // If the persisted agent doesn't exist, we'll leave active as None
41                // and let the caller set a new default
42            }
43        }
44        Ok(())
45    }
46
47    /// Set the active agent profile by name
48    pub fn set_active(&self, name: &str) -> Result<()> {
49        // Verify the agent exists
50        let agents = self.agents.read().unwrap();
51        if !agents.contains_key(name) {
52            return Err(anyhow!(
53                "Agent '{}' not found. Available agents: {}",
54                name,
55                if agents.is_empty() {
56                    "none".to_string()
57                } else {
58                    agents
59                        .keys()
60                        .map(|s| s.as_str())
61                        .collect::<Vec<_>>()
62                        .join(", ")
63                }
64            ));
65        }
66        drop(agents);
67
68        // Update in-memory state
69        {
70            let mut active = self.active_agent.write().unwrap();
71            *active = Some(name.to_string());
72        }
73
74        // Persist to database
75        let value = serde_json::json!(name);
76        self.persistence
77            .policy_upsert(ACTIVE_AGENT_KEY, &value)
78            .context("persisting active agent")?;
79
80        Ok(())
81    }
82
83    /// Get the currently active agent profile
84    pub fn active(&self) -> Result<Option<(String, AgentProfile)>> {
85        let active_name = {
86            let active = self.active_agent.read().unwrap();
87            active.clone()
88        };
89
90        if let Some(name) = active_name {
91            let agents = self.agents.read().unwrap();
92            if let Some(profile) = agents.get(&name) {
93                Ok(Some((name.clone(), profile.clone())))
94            } else {
95                Ok(None)
96            }
97        } else {
98            Ok(None)
99        }
100    }
101
102    /// Get the name of the currently active agent (if any)
103    pub fn active_name(&self) -> Option<String> {
104        let active = self.active_agent.read().unwrap();
105        active.clone()
106    }
107
108    /// List all available agent profiles
109    pub fn list(&self) -> Vec<String> {
110        let agents = self.agents.read().unwrap();
111        let mut names: Vec<_> = agents.keys().cloned().collect();
112        names.sort();
113        names
114    }
115
116    /// Get a specific agent profile by name
117    pub fn get(&self, name: &str) -> Option<AgentProfile> {
118        let agents = self.agents.read().unwrap();
119        agents.get(name).cloned()
120    }
121
122    /// Add or update an agent profile
123    pub fn upsert(&self, name: String, profile: AgentProfile) -> Result<()> {
124        profile
125            .validate()
126            .with_context(|| format!("validating agent profile '{}'", name))?;
127
128        let mut agents = self.agents.write().unwrap();
129        agents.insert(name, profile);
130        Ok(())
131    }
132
133    /// Remove an agent profile
134    pub fn remove(&self, name: &str) -> Result<()> {
135        // Check if this is the active agent
136        let active_name = self.active_name();
137        if active_name.as_deref() == Some(name) {
138            return Err(anyhow!(
139                "Cannot remove '{}' because it is the currently active agent. \
140                 Please switch to a different agent first.",
141                name
142            ));
143        }
144
145        let mut agents = self.agents.write().unwrap();
146        if agents.remove(name).is_none() {
147            return Err(anyhow!("Agent '{}' not found", name));
148        }
149
150        Ok(())
151    }
152
153    /// Check if an agent exists
154    pub fn exists(&self, name: &str) -> bool {
155        let agents = self.agents.read().unwrap();
156        agents.contains_key(name)
157    }
158
159    /// Get the number of registered agents
160    pub fn count(&self) -> usize {
161        let agents = self.agents.read().unwrap();
162        agents.len()
163    }
164
165    /// Get the shared persistence layer
166    pub fn persistence(&self) -> &Persistence {
167        &self.persistence
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a minimal profile with only the prompt set.
    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test prompt".to_string()),
            ..Default::default()
        }
    }

    /// Opens the test database inside `dir`. The caller owns `dir`, which
    /// keeps the directory alive for as long as the test needs it.
    fn open_persistence(dir: &TempDir) -> Persistence {
        let db_path = dir.path().join("test.duckdb");
        Persistence::new(&db_path).unwrap()
    }

    /// Builds a registry holding one test profile per entry in `names`.
    fn registry_with(names: &[&str], persistence: Persistence) -> AgentRegistry {
        let agents: HashMap<String, AgentProfile> = names
            .iter()
            .map(|&name| (name.to_string(), create_test_profile()))
            .collect();
        AgentRegistry::new(agents, persistence)
    }

    #[test]
    fn test_new_registry() {
        let temp_dir = TempDir::new().unwrap();
        let registry = registry_with(&["agent1"], open_persistence(&temp_dir));

        assert_eq!(registry.count(), 1);
        assert!(registry.exists("agent1"));
    }

    #[test]
    fn test_set_and_get_active() {
        let temp_dir = TempDir::new().unwrap();
        let registry = registry_with(&["agent1", "agent2"], open_persistence(&temp_dir));
        registry.init().unwrap();

        // Initially no active agent
        assert!(registry.active().unwrap().is_none());

        // Set active agent
        registry.set_active("agent1").unwrap();
        let active = registry.active().unwrap();
        assert!(active.is_some());
        assert_eq!(active.unwrap().0, "agent1");
        assert_eq!(registry.active_name(), Some("agent1".to_string()));

        // Verify it's persisted by reading the key back from the store,
        // not just from the in-memory state.
        let entry = registry
            .persistence()
            .policy_get(ACTIVE_AGENT_KEY)
            .unwrap()
            .expect("active agent should have been written to persistence");
        assert_eq!(entry.value.as_str(), Some("agent1"));
    }

    #[test]
    fn test_set_active_nonexistent_agent() {
        let temp_dir = TempDir::new().unwrap();
        let registry = registry_with(&[], open_persistence(&temp_dir));
        registry.init().unwrap();

        let result = registry.set_active("nonexistent");
        assert!(result.is_err());

        // A rejected name must not become the active agent.
        assert_eq!(registry.active_name(), None);
    }

    #[test]
    fn test_list_agents() {
        let temp_dir = TempDir::new().unwrap();
        let registry = registry_with(&["zebra", "alpha", "beta"], open_persistence(&temp_dir));

        // Should be sorted
        assert_eq!(registry.list(), vec!["alpha", "beta", "zebra"]);
    }

    #[test]
    fn test_upsert_and_remove() {
        let temp_dir = TempDir::new().unwrap();
        let registry = registry_with(&[], open_persistence(&temp_dir));
        registry.init().unwrap();

        // Add an agent
        registry
            .upsert("new_agent".to_string(), create_test_profile())
            .unwrap();
        assert!(registry.exists("new_agent"));
        assert_eq!(registry.count(), 1);

        // Remove the agent
        registry.remove("new_agent").unwrap();
        assert!(!registry.exists("new_agent"));
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_remove_nonexistent_agent() {
        let temp_dir = TempDir::new().unwrap();
        let registry = registry_with(&["agent1"], open_persistence(&temp_dir));
        registry.init().unwrap();

        assert!(registry.remove("missing").is_err());

        // The failed removal must not disturb existing agents.
        assert_eq!(registry.count(), 1);
    }

    #[test]
    fn test_cannot_remove_active_agent() {
        let temp_dir = TempDir::new().unwrap();
        let registry = registry_with(&["agent1"], open_persistence(&temp_dir));
        registry.init().unwrap();
        registry.set_active("agent1").unwrap();

        // Should not be able to remove the active agent
        let result = registry.remove("agent1");
        assert!(result.is_err());
        assert!(registry.exists("agent1"));
    }

    #[test]
    fn test_persistence_across_restarts() {
        let temp_dir = TempDir::new().unwrap();

        // First session: set active agent
        {
            let registry = registry_with(&["agent1"], open_persistence(&temp_dir));
            registry.init().unwrap();
            registry.set_active("agent1").unwrap();
        }

        // Second session: verify active agent is still set
        {
            let registry = registry_with(&["agent1"], open_persistence(&temp_dir));
            registry.init().unwrap();

            assert_eq!(registry.active_name(), Some("agent1".to_string()));
        }
    }

    #[test]
    fn test_init_ignores_persisted_agent_missing_from_registry() {
        let temp_dir = TempDir::new().unwrap();

        // First session: persist "agent1" as the active agent
        {
            let registry = registry_with(&["agent1"], open_persistence(&temp_dir));
            registry.init().unwrap();
            registry.set_active("agent1").unwrap();
        }

        // Second session: "agent1" is no longer registered, so init must
        // leave the active agent unset instead of restoring a stale name.
        {
            let registry = registry_with(&["agent2"], open_persistence(&temp_dir));
            registry.init().unwrap();

            assert_eq!(registry.active_name(), None);
            assert!(registry.active().unwrap().is_none());
        }
    }
}