1use super::isolation::{AgentContext, AgentStatus, ResourceLimits};
7use crate::config::SafetyConfig;
8use std::collections::HashMap;
9use std::path::PathBuf;
10use uuid::Uuid;
11
12#[derive(Debug, Clone)]
14pub struct SpawnerConfig {
15 pub max_agents: usize,
17 pub default_window_size: usize,
19 pub default_safety: SafetyConfig,
21}
22
23impl Default for SpawnerConfig {
24 fn default() -> Self {
25 Self {
26 max_agents: 8,
27 default_window_size: 10,
28 default_safety: SafetyConfig::default(),
29 }
30 }
31}
32
33pub struct AgentSpawner {
35 config: SpawnerConfig,
36 contexts: HashMap<Uuid, AgentContext>,
37}
38
39impl AgentSpawner {
40 pub fn new(config: SpawnerConfig) -> Self {
41 Self {
42 config,
43 contexts: HashMap::new(),
44 }
45 }
46
47 pub fn spawn(&mut self, name: impl Into<String>) -> Result<Uuid, String> {
49 if self.contexts.len() >= self.config.max_agents {
50 return Err(format!(
51 "Agent limit reached (max {})",
52 self.config.max_agents
53 ));
54 }
55
56 let ctx = AgentContext::new(
57 name,
58 self.config.default_window_size,
59 self.config.default_safety.clone(),
60 );
61 let id = ctx.agent_id;
62 self.contexts.insert(id, ctx);
63 Ok(id)
64 }
65
66 pub fn spawn_child(
68 &mut self,
69 name: impl Into<String>,
70 parent_id: Uuid,
71 ) -> Result<Uuid, String> {
72 if !self.contexts.contains_key(&parent_id) {
73 return Err(format!("Parent agent {} not found", parent_id));
74 }
75 if self.contexts.len() >= self.config.max_agents {
76 return Err(format!(
77 "Agent limit reached (max {})",
78 self.config.max_agents
79 ));
80 }
81
82 let ctx = AgentContext::new_child(
83 name,
84 parent_id,
85 self.config.default_window_size,
86 self.config.default_safety.clone(),
87 );
88 let id = ctx.agent_id;
89 self.contexts.insert(id, ctx);
90 Ok(id)
91 }
92
93 pub fn terminate(&mut self, agent_id: Uuid) -> usize {
95 let children = self.children_of(agent_id);
96 let mut count = 0;
97
98 for child_id in children {
100 count += self.terminate(child_id);
101 }
102
103 if self.contexts.remove(&agent_id).is_some() {
104 count += 1;
105 }
106 count
107 }
108
109 pub fn get(&self, agent_id: &Uuid) -> Option<&AgentContext> {
111 self.contexts.get(agent_id)
112 }
113
114 pub fn get_mut(&mut self, agent_id: &Uuid) -> Option<&mut AgentContext> {
116 self.contexts.get_mut(agent_id)
117 }
118
119 pub fn children_of(&self, parent_id: Uuid) -> Vec<Uuid> {
121 self.contexts
122 .iter()
123 .filter(|(_, ctx)| ctx.parent_id == Some(parent_id))
124 .map(|(id, _)| *id)
125 .collect()
126 }
127
128 pub fn agent_count(&self) -> usize {
130 self.contexts.len()
131 }
132
133 pub fn agent_ids(&self) -> Vec<Uuid> {
135 self.contexts.keys().copied().collect()
136 }
137
138 pub fn spawn_with_config(
140 &mut self,
141 name: impl Into<String>,
142 workspace_dir: Option<PathBuf>,
143 llm_override: Option<String>,
144 resource_limits: ResourceLimits,
145 ) -> Result<Uuid, String> {
146 if self.contexts.len() >= self.config.max_agents {
147 return Err(format!(
148 "Agent limit reached (max {})",
149 self.config.max_agents
150 ));
151 }
152
153 let mut ctx = AgentContext::new(
154 name,
155 self.config.default_window_size,
156 self.config.default_safety.clone(),
157 );
158 ctx.workspace_dir = workspace_dir;
159 ctx.llm_override = llm_override;
160 ctx.resource_limits = resource_limits;
161 let id = ctx.agent_id;
162 self.contexts.insert(id, ctx);
163 Ok(id)
164 }
165
166 pub fn get_status(&self, agent_id: &Uuid) -> Option<AgentStatus> {
168 self.contexts.get(agent_id).map(|ctx| ctx.status)
169 }
170
171 pub fn set_status(&mut self, agent_id: &Uuid, status: AgentStatus) -> Result<(), String> {
173 let ctx = self
174 .contexts
175 .get_mut(agent_id)
176 .ok_or_else(|| format!("Agent {} not found", agent_id))?;
177 ctx.status = status;
178 Ok(())
179 }
180
181 pub fn list_by_status(&self, status: AgentStatus) -> Vec<Uuid> {
183 self.contexts
184 .iter()
185 .filter(|(_, ctx)| ctx.status == status)
186 .map(|(id, _)| *id)
187 .collect()
188 }
189}
190
191impl Default for AgentSpawner {
192 fn default() -> Self {
193 Self::new(SpawnerConfig::default())
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Resource limits with every field unset, for tests that don't care.
    fn no_limits() -> ResourceLimits {
        ResourceLimits {
            max_memory_mb: None,
            max_tokens_per_turn: None,
            max_tool_calls: None,
            max_runtime_secs: None,
        }
    }

    #[test]
    fn test_spawner_new_empty() {
        let spawner = AgentSpawner::default();
        assert_eq!(spawner.agent_count(), 0);
    }

    #[test]
    fn test_spawner_spawn() {
        let mut spawner = AgentSpawner::default();
        let id = spawner.spawn("agent-1").unwrap();
        assert_eq!(spawner.agent_count(), 1);
        assert!(spawner.get(&id).is_some());
        assert_eq!(spawner.get(&id).unwrap().name, "agent-1");
    }

    #[test]
    fn test_spawner_max_limit() {
        let config = SpawnerConfig {
            max_agents: 2,
            ..Default::default()
        };
        let mut spawner = AgentSpawner::new(config);
        spawner.spawn("a1").unwrap();
        spawner.spawn("a2").unwrap();
        let result = spawner.spawn("a3");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("limit reached"));
    }

    #[test]
    fn test_spawner_spawn_child() {
        let mut spawner = AgentSpawner::default();
        let parent = spawner.spawn("parent").unwrap();
        let child = spawner.spawn_child("child", parent).unwrap();

        assert_eq!(spawner.agent_count(), 2);
        assert!(spawner.get(&child).unwrap().is_child());
        assert_eq!(spawner.get(&child).unwrap().parent_id, Some(parent));
    }

    #[test]
    fn test_spawner_spawn_child_missing_parent() {
        let mut spawner = AgentSpawner::default();
        let fake_parent = Uuid::new_v4();
        let result = spawner.spawn_child("orphan", fake_parent);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not found"));
    }

    #[test]
    fn test_spawner_spawn_child_respects_limit() {
        let config = SpawnerConfig {
            max_agents: 1,
            ..Default::default()
        };
        let mut spawner = AgentSpawner::new(config);
        let parent = spawner.spawn("parent").unwrap();
        let result = spawner.spawn_child("child", parent);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("limit reached"));
        assert_eq!(spawner.agent_count(), 1);
    }

    #[test]
    fn test_spawner_terminate() {
        let mut spawner = AgentSpawner::default();
        let id = spawner.spawn("agent").unwrap();
        assert_eq!(spawner.agent_count(), 1);

        let removed = spawner.terminate(id);
        assert_eq!(removed, 1);
        assert_eq!(spawner.agent_count(), 0);
    }

    #[test]
    fn test_spawner_terminate_unknown_agent() {
        let mut spawner = AgentSpawner::default();
        let _kept = spawner.spawn("kept").unwrap();
        assert_eq!(spawner.terminate(Uuid::new_v4()), 0);
        assert_eq!(spawner.agent_count(), 1);
    }

    #[test]
    fn test_spawner_terminate_cascades_to_children() {
        let mut spawner = AgentSpawner::default();
        let parent = spawner.spawn("parent").unwrap();
        let _child1 = spawner.spawn_child("child1", parent).unwrap();
        let _child2 = spawner.spawn_child("child2", parent).unwrap();
        assert_eq!(spawner.agent_count(), 3);

        let removed = spawner.terminate(parent);
        assert_eq!(removed, 3);
        assert_eq!(spawner.agent_count(), 0);
    }

    #[test]
    fn test_spawner_terminate_cascades_to_grandchildren() {
        let mut spawner = AgentSpawner::default();
        let root = spawner.spawn("root").unwrap();
        let child = spawner.spawn_child("child", root).unwrap();
        let _grandchild = spawner.spawn_child("grandchild", child).unwrap();
        let unrelated = spawner.spawn("unrelated").unwrap();

        assert_eq!(spawner.terminate(root), 3);
        assert_eq!(spawner.agent_ids(), vec![unrelated]);
    }

    #[test]
    fn test_spawner_terminate_deep_chain() {
        let config = SpawnerConfig {
            max_agents: 64,
            ..Default::default()
        };
        let mut spawner = AgentSpawner::new(config);
        let root = spawner.spawn("root").unwrap();
        let mut tip = root;
        for _ in 0..63 {
            tip = spawner.spawn_child("link", tip).unwrap();
        }
        assert_eq!(spawner.agent_count(), 64);

        assert_eq!(spawner.terminate(root), 64);
        assert_eq!(spawner.agent_count(), 0);
    }

    #[test]
    fn test_spawner_terminate_child_keeps_parent() {
        let mut spawner = AgentSpawner::default();
        let parent = spawner.spawn("parent").unwrap();
        let child = spawner.spawn_child("child", parent).unwrap();

        assert_eq!(spawner.terminate(child), 1);
        assert!(spawner.get(&parent).is_some());
        assert!(spawner.children_of(parent).is_empty());
    }

    #[test]
    fn test_spawner_children_of() {
        let mut spawner = AgentSpawner::default();
        let parent = spawner.spawn("parent").unwrap();
        let c1 = spawner.spawn_child("c1", parent).unwrap();
        let c2 = spawner.spawn_child("c2", parent).unwrap();
        let _other = spawner.spawn("other").unwrap();

        let children = spawner.children_of(parent);
        assert_eq!(children.len(), 2);
        assert!(children.contains(&c1));
        assert!(children.contains(&c2));
    }

    #[test]
    fn test_spawn_with_custom_config() {
        let mut spawner = AgentSpawner::default();
        let limits = ResourceLimits {
            max_memory_mb: Some(256),
            max_tokens_per_turn: Some(2048),
            max_tool_calls: Some(20),
            max_runtime_secs: Some(120),
        };
        let id = spawner
            .spawn_with_config(
                "custom",
                Some(PathBuf::from("/tmp/workspace")),
                Some("claude-3-sonnet".into()),
                limits,
            )
            .unwrap();

        let ctx = spawner.get(&id).unwrap();
        assert_eq!(
            ctx.workspace_dir.as_deref(),
            Some(std::path::Path::new("/tmp/workspace"))
        );
        assert_eq!(ctx.llm_override.as_deref(), Some("claude-3-sonnet"));
        assert_eq!(ctx.resource_limits.max_memory_mb, Some(256));
    }

    #[test]
    fn test_spawn_with_config_respects_limit() {
        let config = SpawnerConfig {
            max_agents: 1,
            ..Default::default()
        };
        let mut spawner = AgentSpawner::new(config);
        spawner.spawn("first").unwrap();

        let result = spawner.spawn_with_config("second", None, None, no_limits());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("limit reached"));
        assert_eq!(spawner.agent_count(), 1);
    }

    #[test]
    fn test_get_set_status() {
        let mut spawner = AgentSpawner::default();
        let id = spawner.spawn("agent").unwrap();

        assert_eq!(spawner.get_status(&id), Some(AgentStatus::Idle));

        spawner.set_status(&id, AgentStatus::Running).unwrap();
        assert_eq!(spawner.get_status(&id), Some(AgentStatus::Running));

        spawner.set_status(&id, AgentStatus::Terminated).unwrap();
        assert_eq!(spawner.get_status(&id), Some(AgentStatus::Terminated));
    }

    #[test]
    fn test_list_by_status() {
        let mut spawner = AgentSpawner::default();
        let a1 = spawner.spawn("a1").unwrap();
        let a2 = spawner.spawn("a2").unwrap();
        let _a3 = spawner.spawn("a3").unwrap();

        spawner.set_status(&a1, AgentStatus::Running).unwrap();
        spawner.set_status(&a2, AgentStatus::Running).unwrap();

        let running = spawner.list_by_status(AgentStatus::Running);
        assert_eq!(running.len(), 2);
        assert!(running.contains(&a1));
        assert!(running.contains(&a2));

        let idle = spawner.list_by_status(AgentStatus::Idle);
        assert_eq!(idle.len(), 1);
    }

    #[test]
    fn test_set_status_unknown_agent() {
        let mut spawner = AgentSpawner::default();
        let fake = Uuid::new_v4();
        let result = spawner.set_status(&fake, AgentStatus::Running);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not found"));
    }
}