1use anyhow::Result;
7use std::collections::HashMap;
8use std::path::Path;
9use std::sync::Arc;
10use tokio::sync::mpsc;
11
12use crate::extensions::runner::{AgentEvent, AgentResult};
13use crate::models::Task;
14
15use super::events::OpenCodeEvent;
16use super::manager::{global_manager, OpenCodeManager};
17
/// Identifies one spawned agent: the task it works on, the OpenCode session
/// created for it, and the phase tag the task was spawned under.
#[derive(Clone, Debug)]
pub struct AgentHandle {
    /// Id of the task this agent was spawned for.
    pub task_id: String,
    /// Id of the session created for the task.
    pub session_id: String,
    /// Phase tag the task belongs to.
    pub tag: String,
}
25
26pub struct AgentOrchestrator {
28 manager: Arc<OpenCodeManager>,
29 sessions: HashMap<String, AgentHandle>,
31 task_sessions: HashMap<String, String>,
33 event_tx: mpsc::Sender<AgentEvent>,
35 start_times: HashMap<String, std::time::Instant>,
37}
38
39impl AgentOrchestrator {
40 pub async fn new(event_tx: mpsc::Sender<AgentEvent>) -> Result<Self> {
42 let manager = global_manager();
43 manager.ensure_running().await?;
44
45 Ok(Self {
46 manager,
47 sessions: HashMap::new(),
48 task_sessions: HashMap::new(),
49 event_tx,
50 start_times: HashMap::new(),
51 })
52 }
53
54 pub fn with_manager(manager: Arc<OpenCodeManager>, event_tx: mpsc::Sender<AgentEvent>) -> Self {
56 Self {
57 manager,
58 sessions: HashMap::new(),
59 task_sessions: HashMap::new(),
60 event_tx,
61 start_times: HashMap::new(),
62 }
63 }
64
65 pub async fn spawn_agent(
67 &mut self,
68 task: &Task,
69 tag: &str,
70 prompt: &str,
71 model: Option<(&str, &str)>,
72 ) -> Result<AgentHandle> {
73 let client = self.manager.client();
74
75 let session = client
77 .create_session(&format!("[{}] {}", task.id, task.title))
78 .await?;
79
80 client.send_message(&session.id, prompt, model).await?;
82
83 let handle = AgentHandle {
85 task_id: task.id.clone(),
86 session_id: session.id.clone(),
87 tag: tag.to_string(),
88 };
89
90 self.sessions.insert(session.id.clone(), handle.clone());
92 self.task_sessions
93 .insert(task.id.clone(), session.id.clone());
94 self.start_times
95 .insert(task.id.clone(), std::time::Instant::now());
96
97 let _ = self
99 .event_tx
100 .send(AgentEvent::Started {
101 task_id: task.id.clone(),
102 })
103 .await;
104
105 Ok(handle)
106 }
107
108 pub async fn cancel_agent(&mut self, task_id: &str) -> Result<()> {
110 if let Some(session_id) = self.task_sessions.get(task_id) {
111 self.manager.client().abort_session(session_id).await?;
112 }
113 Ok(())
114 }
115
116 pub fn active_tasks(&self) -> Vec<String> {
118 self.task_sessions.keys().cloned().collect()
119 }
120
121 pub fn is_task_active(&self, task_id: &str) -> bool {
123 self.task_sessions.contains_key(task_id)
124 }
125
126 pub fn active_count(&self) -> usize {
128 self.task_sessions.len()
129 }
130
131 pub async fn process_event(&mut self, event: OpenCodeEvent) -> Option<AgentEvent> {
133 let session_id = event.session_id()?;
134 let handle = self.sessions.get(session_id)?;
135 let task_id = handle.task_id.clone();
136
137 match event {
138 OpenCodeEvent::TextDelta { text, .. } => {
139 let agent_event = AgentEvent::Output {
140 task_id: task_id.clone(),
141 line: text,
142 };
143 let _ = self.event_tx.send(agent_event.clone()).await;
144 Some(agent_event)
145 }
146
147 OpenCodeEvent::ToolStart {
148 tool_name, input, ..
149 } => {
150 let input_summary = summarize_tool_input(&input);
152 let line = format!(">> {} {}", tool_name, input_summary);
153 let agent_event = AgentEvent::Output {
154 task_id: task_id.clone(),
155 line,
156 };
157 let _ = self.event_tx.send(agent_event.clone()).await;
158 Some(agent_event)
159 }
160
161 OpenCodeEvent::ToolResult {
162 tool_name, success, ..
163 } => {
164 let status = if success { "ok" } else { "failed" };
165 let line = format!("<< {} {}", tool_name, status);
166 let agent_event = AgentEvent::Output {
167 task_id: task_id.clone(),
168 line,
169 };
170 let _ = self.event_tx.send(agent_event.clone()).await;
171 Some(agent_event)
172 }
173
174 OpenCodeEvent::MessageComplete { success, .. } => {
175 let duration_ms = self
177 .start_times
178 .get(&task_id)
179 .map(|t| t.elapsed().as_millis() as u64)
180 .unwrap_or(0);
181
182 let result = AgentResult {
183 task_id: task_id.clone(),
184 success,
185 exit_code: if success { Some(0) } else { Some(1) },
186 output: String::new(), duration_ms,
188 };
189
190 if let Some(session_id) = self.task_sessions.remove(&task_id) {
192 self.sessions.remove(&session_id);
193 }
194 self.start_times.remove(&task_id);
195
196 let agent_event = AgentEvent::Completed { result };
197 let _ = self.event_tx.send(agent_event.clone()).await;
198 Some(agent_event)
199 }
200
201 OpenCodeEvent::SessionError { error, .. } => {
202 let duration_ms = self
203 .start_times
204 .get(&task_id)
205 .map(|t| t.elapsed().as_millis() as u64)
206 .unwrap_or(0);
207
208 let result = AgentResult {
209 task_id: task_id.clone(),
210 success: false,
211 exit_code: Some(1),
212 output: error,
213 duration_ms,
214 };
215
216 if let Some(session_id) = self.task_sessions.remove(&task_id) {
218 self.sessions.remove(&session_id);
219 }
220 self.start_times.remove(&task_id);
221
222 let agent_event = AgentEvent::Completed { result };
223 let _ = self.event_tx.send(agent_event.clone()).await;
224 Some(agent_event)
225 }
226
227 _ => None,
228 }
229 }
230
231 pub async fn wait_all(&mut self) -> Vec<AgentResult> {
233 let mut results = Vec::new();
234 let mut event_stream = match self.manager.event_stream().await {
235 Ok(s) => s,
236 Err(_) => return results,
237 };
238
239 while !self.task_sessions.is_empty() {
240 if let Some(event) = event_stream.recv().await {
241 if let Some(AgentEvent::Completed { result }) = self.process_event(event).await {
242 results.push(result);
243 }
244 }
245 }
246
247 results
248 }
249
250 pub async fn cleanup(&mut self) {
252 let client = self.manager.client();
253 for session_id in self.sessions.keys() {
254 let _ = client.delete_session(session_id).await;
255 }
256 self.sessions.clear();
257 self.task_sessions.clear();
258 self.start_times.clear();
259 }
260}
261
262fn summarize_tool_input(input: &serde_json::Value) -> String {
264 match input {
265 serde_json::Value::Object(obj) => {
266 let keys: Vec<&str> = obj.keys().map(|k| k.as_str()).take(3).collect();
268 if keys.is_empty() {
269 "{}".to_string()
270 } else if keys.len() < obj.len() {
271 format!("{{{},...}}", keys.join(", "))
272 } else {
273 format!("{{{}}}", keys.join(", "))
274 }
275 }
276 serde_json::Value::String(s) => {
277 if s.len() > 50 {
278 format!("\"{}...\"", &s[..47])
279 } else {
280 format!("\"{}\"", s)
281 }
282 }
283 serde_json::Value::Null => "".to_string(),
284 other => {
285 let s = other.to_string();
286 if s.len() > 50 {
287 format!("{}...", &s[..47])
288 } else {
289 s
290 }
291 }
292 }
293}
294
295pub async fn execute_wave_server(
297 tasks: &[(Task, String)], working_dir: &Path,
299 model: Option<(&str, &str)>,
300 event_tx: mpsc::Sender<AgentEvent>,
301) -> Result<Vec<AgentResult>> {
302 let mut orchestrator = AgentOrchestrator::new(event_tx).await?;
303
304 for (task, tag) in tasks {
306 let prompt = generate_prompt(task, tag, working_dir);
307 if let Err(e) = orchestrator.spawn_agent(task, tag, &prompt, model).await {
308 eprintln!("Failed to spawn agent for {}: {}", task.id, e);
309 }
310 }
311
312 let results = orchestrator.wait_all().await;
314
315 orchestrator.cleanup().await;
317
318 Ok(results)
319}
320
321fn generate_prompt(task: &Task, tag: &str, working_dir: &Path) -> String {
323 let details = task
324 .details
325 .as_ref()
326 .map(|d| format!("\n\n## Details\n\n{}", d))
327 .unwrap_or_default();
328
329 let test_strategy = task
330 .test_strategy
331 .as_ref()
332 .map(|t| format!("\n\n## Test Strategy\n\n{}", t))
333 .unwrap_or_default();
334
335 format!(
336 r#"You are working on task [{id}] in phase "{tag}".
337
338## Task: {title}
339
340{description}{details}{test_strategy}
341
342## Instructions
343
3441. Implement the task requirements
3452. Test your changes
3463. When complete, run: `scud set-status {id} done --tag {tag}`
347
348Working directory: {working_dir}
349"#,
350 id = task.id,
351 tag = tag,
352 title = task.title,
353 description = task.description,
354 details = details,
355 test_strategy = test_strategy,
356 working_dir = working_dir.display(),
357 )
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task through the public constructor shared by the prompt tests.
    fn make_task(id: &str, title: &str, description: &str) -> Task {
        Task::new(id.to_string(), title.to_string(), description.to_string())
    }

    #[test]
    fn test_generate_prompt() {
        let task = make_task("1", "Test task", "Do something");
        let prompt = generate_prompt(&task, "feature", Path::new("/tmp"));

        for expected in ["[1]", "Test task", "Do something", "feature", "/tmp"].iter() {
            assert!(prompt.contains(*expected));
        }
    }

    #[test]
    fn test_generate_prompt_with_details() {
        let mut task = make_task("2", "Task with details", "Main description");
        task.details = Some("Extra details here".to_string());
        task.test_strategy = Some("Run unit tests".to_string());

        let prompt = generate_prompt(&task, "api", Path::new("/project"));

        let expected = [
            "Extra details here",
            "Run unit tests",
            "## Details",
            "## Test Strategy",
        ];
        for fragment in expected.iter() {
            assert!(prompt.contains(*fragment));
        }
    }

    #[test]
    fn test_summarize_tool_input_object() {
        let summary = summarize_tool_input(&serde_json::json!({"path": "/foo", "content": "bar"}));
        assert!(summary.contains("path"));
        assert!(summary.contains("content"));
    }

    #[test]
    fn test_summarize_tool_input_string() {
        let short = summarize_tool_input(&serde_json::json!("short string"));
        assert_eq!(short, "\"short string\"");

        let long = summarize_tool_input(&serde_json::json!("a".repeat(100)));
        assert!(long.len() < 60);
        assert!(long.ends_with("...\""));
    }

    #[test]
    fn test_summarize_tool_input_null() {
        assert_eq!(summarize_tool_input(&serde_json::Value::Null), "");
    }

    #[test]
    fn test_agent_handle_debug() {
        let rendered = format!(
            "{:?}",
            AgentHandle {
                task_id: "auth:1".to_string(),
                session_id: "sess-123".to_string(),
                tag: "auth".to_string(),
            }
        );
        assert!(rendered.contains("auth:1"));
        assert!(rendered.contains("sess-123"));
    }
}