Skip to main content

rustant_core/multi/
spawner.rs

//! Agent spawner — lifecycle management for agents.
//!
//! Manages creating and terminating agents, enforces limits, and tracks
//! parent-child relationships for hierarchical agent spawning.
5
6use super::isolation::{AgentContext, AgentStatus, ResourceLimits};
7use crate::config::SafetyConfig;
8use std::collections::HashMap;
9use std::path::PathBuf;
10use uuid::Uuid;
11
12/// Configuration for multi-agent spawning.
13#[derive(Debug, Clone)]
14pub struct SpawnerConfig {
15    /// Maximum number of concurrent agents.
16    pub max_agents: usize,
17    /// Default window size for agent memory.
18    pub default_window_size: usize,
19    /// Default safety config applied to spawned agents.
20    pub default_safety: SafetyConfig,
21}
22
23impl Default for SpawnerConfig {
24    fn default() -> Self {
25        Self {
26            max_agents: 8,
27            default_window_size: 10,
28            default_safety: SafetyConfig::default(),
29        }
30    }
31}
32
33/// Manages agent lifecycle — spawn, terminate, query.
34pub struct AgentSpawner {
35    config: SpawnerConfig,
36    contexts: HashMap<Uuid, AgentContext>,
37}
38
39impl AgentSpawner {
40    pub fn new(config: SpawnerConfig) -> Self {
41        Self {
42            config,
43            contexts: HashMap::new(),
44        }
45    }
46
47    /// Spawn a new top-level agent. Returns the agent ID, or an error if limit reached.
48    pub fn spawn(&mut self, name: impl Into<String>) -> Result<Uuid, String> {
49        if self.contexts.len() >= self.config.max_agents {
50            return Err(format!(
51                "Agent limit reached (max {})",
52                self.config.max_agents
53            ));
54        }
55
56        let ctx = AgentContext::new(
57            name,
58            self.config.default_window_size,
59            self.config.default_safety.clone(),
60        );
61        let id = ctx.agent_id;
62        self.contexts.insert(id, ctx);
63        Ok(id)
64    }
65
66    /// Spawn a child agent under a parent. Returns the child's ID.
67    pub fn spawn_child(
68        &mut self,
69        name: impl Into<String>,
70        parent_id: Uuid,
71    ) -> Result<Uuid, String> {
72        if !self.contexts.contains_key(&parent_id) {
73            return Err(format!("Parent agent {} not found", parent_id));
74        }
75        if self.contexts.len() >= self.config.max_agents {
76            return Err(format!(
77                "Agent limit reached (max {})",
78                self.config.max_agents
79            ));
80        }
81
82        let ctx = AgentContext::new_child(
83            name,
84            parent_id,
85            self.config.default_window_size,
86            self.config.default_safety.clone(),
87        );
88        let id = ctx.agent_id;
89        self.contexts.insert(id, ctx);
90        Ok(id)
91    }
92
93    /// Terminate an agent and all its children. Returns number of agents removed.
94    pub fn terminate(&mut self, agent_id: Uuid) -> usize {
95        let children = self.children_of(agent_id);
96        let mut count = 0;
97
98        // Recursively terminate children first
99        for child_id in children {
100            count += self.terminate(child_id);
101        }
102
103        if self.contexts.remove(&agent_id).is_some() {
104            count += 1;
105        }
106        count
107    }
108
109    /// Get a reference to an agent's context.
110    pub fn get(&self, agent_id: &Uuid) -> Option<&AgentContext> {
111        self.contexts.get(agent_id)
112    }
113
114    /// Get a mutable reference to an agent's context.
115    pub fn get_mut(&mut self, agent_id: &Uuid) -> Option<&mut AgentContext> {
116        self.contexts.get_mut(agent_id)
117    }
118
119    /// Find all direct children of a given agent.
120    pub fn children_of(&self, parent_id: Uuid) -> Vec<Uuid> {
121        self.contexts
122            .iter()
123            .filter(|(_, ctx)| ctx.parent_id == Some(parent_id))
124            .map(|(id, _)| *id)
125            .collect()
126    }
127
128    /// Total number of active agents.
129    pub fn agent_count(&self) -> usize {
130        self.contexts.len()
131    }
132
133    /// List all active agent IDs.
134    pub fn agent_ids(&self) -> Vec<Uuid> {
135        self.contexts.keys().copied().collect()
136    }
137
138    /// Spawn an agent with custom configuration.
139    pub fn spawn_with_config(
140        &mut self,
141        name: impl Into<String>,
142        workspace_dir: Option<PathBuf>,
143        llm_override: Option<String>,
144        resource_limits: ResourceLimits,
145    ) -> Result<Uuid, String> {
146        if self.contexts.len() >= self.config.max_agents {
147            return Err(format!(
148                "Agent limit reached (max {})",
149                self.config.max_agents
150            ));
151        }
152
153        let mut ctx = AgentContext::new(
154            name,
155            self.config.default_window_size,
156            self.config.default_safety.clone(),
157        );
158        ctx.workspace_dir = workspace_dir;
159        ctx.llm_override = llm_override;
160        ctx.resource_limits = resource_limits;
161        let id = ctx.agent_id;
162        self.contexts.insert(id, ctx);
163        Ok(id)
164    }
165
166    /// Get the status of an agent.
167    pub fn get_status(&self, agent_id: &Uuid) -> Option<AgentStatus> {
168        self.contexts.get(agent_id).map(|ctx| ctx.status)
169    }
170
171    /// Set the status of an agent.
172    pub fn set_status(&mut self, agent_id: &Uuid, status: AgentStatus) -> Result<(), String> {
173        let ctx = self
174            .contexts
175            .get_mut(agent_id)
176            .ok_or_else(|| format!("Agent {} not found", agent_id))?;
177        ctx.status = status;
178        Ok(())
179    }
180
181    /// List all agents with a given status.
182    pub fn list_by_status(&self, status: AgentStatus) -> Vec<Uuid> {
183        self.contexts
184            .iter()
185            .filter(|(_, ctx)| ctx.status == status)
186            .map(|(id, _)| *id)
187            .collect()
188    }
189}
190
191impl Default for AgentSpawner {
192    fn default() -> Self {
193        Self::new(SpawnerConfig::default())
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Resource limits shared by the custom-config tests.
    fn sample_limits() -> ResourceLimits {
        ResourceLimits {
            max_memory_mb: Some(256),
            max_tokens_per_turn: Some(2048),
            max_tool_calls: Some(20),
            max_runtime_secs: Some(120),
        }
    }

    /// A spawner whose config allows at most `max_agents` agents.
    fn limited_spawner(max_agents: usize) -> AgentSpawner {
        AgentSpawner::new(SpawnerConfig {
            max_agents,
            ..Default::default()
        })
    }

    #[test]
    fn test_spawner_new_empty() {
        let spawner = AgentSpawner::default();
        assert_eq!(spawner.agent_count(), 0);
    }

    #[test]
    fn test_spawner_spawn() {
        let mut spawner = AgentSpawner::default();
        let id = spawner.spawn("agent-1").unwrap();
        assert_eq!(spawner.agent_count(), 1);
        let ctx = spawner.get(&id).expect("spawned agent should be stored");
        assert_eq!(ctx.name, "agent-1");
    }

    #[test]
    fn test_spawner_max_limit() {
        let mut spawner = limited_spawner(2);
        spawner.spawn("a1").unwrap();
        spawner.spawn("a2").unwrap();
        let result = spawner.spawn("a3");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("limit reached"));
    }

    #[test]
    fn test_spawn_child_respects_limit() {
        let mut spawner = limited_spawner(1);
        let parent = spawner.spawn("parent").unwrap();
        let result = spawner.spawn_child("child", parent);
        assert!(result.unwrap_err().contains("limit reached"));
        assert_eq!(spawner.agent_count(), 1);
    }

    #[test]
    fn test_spawn_with_config_respects_limit() {
        let mut spawner = limited_spawner(1);
        spawner.spawn("first").unwrap();
        let result = spawner.spawn_with_config("second", None, None, sample_limits());
        assert!(result.unwrap_err().contains("limit reached"));
        assert_eq!(spawner.agent_count(), 1);
    }

    #[test]
    fn test_spawner_spawn_child() {
        let mut spawner = AgentSpawner::default();
        let parent = spawner.spawn("parent").unwrap();
        let child = spawner.spawn_child("child", parent).unwrap();

        assert_eq!(spawner.agent_count(), 2);
        let ctx = spawner.get(&child).expect("child should be stored");
        assert!(ctx.is_child());
        assert_eq!(ctx.parent_id, Some(parent));
    }

    #[test]
    fn test_spawner_spawn_child_missing_parent() {
        let mut spawner = AgentSpawner::default();
        let fake_parent = Uuid::new_v4();
        let result = spawner.spawn_child("orphan", fake_parent);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not found"));
    }

    #[test]
    fn test_spawner_terminate() {
        let mut spawner = AgentSpawner::default();
        let id = spawner.spawn("agent").unwrap();
        assert_eq!(spawner.agent_count(), 1);

        let removed = spawner.terminate(id);
        assert_eq!(removed, 1);
        assert_eq!(spawner.agent_count(), 0);
    }

    #[test]
    fn test_spawner_terminate_cascades_to_children() {
        let mut spawner = AgentSpawner::default();
        let parent = spawner.spawn("parent").unwrap();
        let _child1 = spawner.spawn_child("child1", parent).unwrap();
        let _child2 = spawner.spawn_child("child2", parent).unwrap();
        assert_eq!(spawner.agent_count(), 3);

        let removed = spawner.terminate(parent);
        assert_eq!(removed, 3);
        assert_eq!(spawner.agent_count(), 0);
    }

    #[test]
    fn test_spawner_terminate_cascades_to_grandchildren() {
        let mut spawner = AgentSpawner::default();
        let root = spawner.spawn("root").unwrap();
        let child = spawner.spawn_child("child", root).unwrap();
        let _grandchild = spawner.spawn_child("grandchild", child).unwrap();
        let unrelated = spawner.spawn("unrelated").unwrap();

        assert_eq!(spawner.terminate(root), 3);
        assert_eq!(spawner.agent_count(), 1);
        assert!(spawner.get(&unrelated).is_some());
    }

    #[test]
    fn test_spawner_terminate_unknown_agent() {
        let mut spawner = AgentSpawner::default();
        spawner.spawn("agent").unwrap();

        assert_eq!(spawner.terminate(Uuid::new_v4()), 0);
        assert_eq!(spawner.agent_count(), 1);
    }

    #[test]
    fn test_terminate_frees_capacity() {
        let mut spawner = limited_spawner(1);
        let first = spawner.spawn("first").unwrap();
        assert!(spawner.spawn("second").is_err());

        assert_eq!(spawner.terminate(first), 1);
        assert!(spawner.spawn("second").is_ok());
    }

    #[test]
    fn test_spawner_children_of() {
        let mut spawner = AgentSpawner::default();
        let parent = spawner.spawn("parent").unwrap();
        let c1 = spawner.spawn_child("c1", parent).unwrap();
        let c2 = spawner.spawn_child("c2", parent).unwrap();
        let _other = spawner.spawn("other").unwrap();

        let children = spawner.children_of(parent);
        assert_eq!(children.len(), 2);
        assert!(children.contains(&c1));
        assert!(children.contains(&c2));
    }

    #[test]
    fn test_spawn_with_custom_config() {
        let mut spawner = AgentSpawner::default();
        let id = spawner
            .spawn_with_config(
                "custom",
                Some(PathBuf::from("/tmp/workspace")),
                Some("claude-3-sonnet".into()),
                sample_limits(),
            )
            .unwrap();

        let ctx = spawner.get(&id).unwrap();
        assert_eq!(
            ctx.workspace_dir.as_deref(),
            Some(std::path::Path::new("/tmp/workspace"))
        );
        assert_eq!(ctx.llm_override.as_deref(), Some("claude-3-sonnet"));
        assert_eq!(ctx.resource_limits.max_memory_mb, Some(256));
    }

    #[test]
    fn test_get_set_status() {
        let mut spawner = AgentSpawner::default();
        let id = spawner.spawn("agent").unwrap();

        assert_eq!(spawner.get_status(&id), Some(AgentStatus::Idle));

        spawner.set_status(&id, AgentStatus::Running).unwrap();
        assert_eq!(spawner.get_status(&id), Some(AgentStatus::Running));

        spawner.set_status(&id, AgentStatus::Terminated).unwrap();
        assert_eq!(spawner.get_status(&id), Some(AgentStatus::Terminated));
    }

    #[test]
    fn test_list_by_status() {
        let mut spawner = AgentSpawner::default();
        let a1 = spawner.spawn("a1").unwrap();
        let a2 = spawner.spawn("a2").unwrap();
        let a3 = spawner.spawn("a3").unwrap();

        spawner.set_status(&a1, AgentStatus::Running).unwrap();
        spawner.set_status(&a2, AgentStatus::Running).unwrap();

        let running = spawner.list_by_status(AgentStatus::Running);
        assert_eq!(running.len(), 2);
        assert!(running.contains(&a1));
        assert!(running.contains(&a2));

        let idle = spawner.list_by_status(AgentStatus::Idle);
        assert_eq!(idle, vec![a3]);
    }

    #[test]
    fn test_set_status_unknown_agent() {
        let mut spawner = AgentSpawner::default();
        let fake = Uuid::new_v4();
        let result = spawner.set_status(&fake, AgentStatus::Running);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not found"));
    }
}