1use crate::extensions::types::ToolFn;
9use futures::stream::{self, StreamExt};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::path::{Path, PathBuf};
14use std::process::Stdio;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::sync::mpsc;
18
19use crate::agents::AgentDef;
20use crate::commands::spawn::terminal::{find_harness_binary, Harness};
21
22pub struct ExtensionRunner {
28 registered_tools: HashMap<String, ToolFn>,
29}
30
31impl ExtensionRunner {
32 pub fn new() -> Self {
34 ExtensionRunner {
35 registered_tools: HashMap::new(),
36 }
37 }
38
39 pub fn register_tool(&mut self, name: String, tool_fn: ToolFn) {
41 self.registered_tools.insert(name, tool_fn);
42 }
43
44 pub fn execute_tool(&self, name: &str, args: &[Value]) -> Result<Value, ExtensionRunnerError> {
46 let tool_fn = self
47 .registered_tools
48 .get(name)
49 .ok_or_else(|| ExtensionRunnerError::ToolNotFound(name.to_string()))?;
50
51 tool_fn(args).map_err(ExtensionRunnerError::ExecutionError)
52 }
53
54 pub fn has_tool(&self, name: &str) -> bool {
56 self.registered_tools.contains_key(name)
57 }
58
59 pub fn list_tools(&self) -> Vec<String> {
61 self.registered_tools.keys().cloned().collect()
62 }
63
64 pub fn on_tool_call(
77 &self,
78 tool_name: &str,
79 arguments: Value,
80 ) -> Result<ToolCallResult, ExtensionRunnerError> {
81 let args = match arguments {
83 Value::Array(arr) => arr,
84 Value::Object(_) => vec![arguments],
85 Value::Null => vec![],
86 other => vec![other],
87 };
88
89 let result = self.execute_tool(tool_name, &args)?;
90
91 Ok(ToolCallResult {
92 tool_name: tool_name.to_string(),
93 output: result,
94 success: true,
95 })
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct ToolCallResult {
102 pub tool_name: String,
104 pub output: Value,
106 pub success: bool,
108}
109
110impl Default for ExtensionRunner {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116#[derive(Debug, thiserror::Error)]
118pub enum ExtensionRunnerError {
119 #[error("Tool not found: {0}")]
120 ToolNotFound(String),
121
122 #[error("Tool execution error: {0}")]
123 ExecutionError(Box<dyn std::error::Error + Send + Sync>),
124}
125
/// Final outcome of one agent process.
#[derive(Clone, Debug)]
pub struct AgentResult {
    /// Identifier of the task the agent ran for.
    pub task_id: String,
    /// Whether the process exited with a success status.
    pub success: bool,
    /// Exit code, or `None` when it is unavailable (waiting failed, or on Unix the process
    /// was terminated by a signal).
    pub exit_code: Option<i32>,
    /// Captured process output (as assembled by `spawn_agent`).
    pub output: String,
    /// Elapsed wall-clock time in milliseconds.
    pub duration_ms: u64,
}
144
145#[derive(Debug, Clone)]
147pub enum AgentEvent {
148 Started { task_id: String },
150 Output { task_id: String, line: String },
152 Completed { result: AgentResult },
154 SpawnFailed { task_id: String, error: String },
156}
157
158#[derive(Debug, Clone)]
160pub struct SpawnConfig {
161 pub task_id: String,
163 pub prompt: String,
165 pub working_dir: PathBuf,
167 pub harness: Harness,
169 pub model: Option<String>,
171}
172
173pub async fn spawn_agent(
177 config: SpawnConfig,
178 event_tx: mpsc::Sender<AgentEvent>,
179) -> Result<tokio::task::JoinHandle<AgentResult>, anyhow::Error> {
180 let binary_path = find_harness_binary(config.harness)?;
181 let task_id = config.task_id.clone();
182
183 let mut cmd = match config.harness {
185 Harness::Claude => {
186 let mut c = Command::new(binary_path);
187 c.arg(&config.prompt);
188 c.arg("--dangerously-skip-permissions");
189 if let Some(ref model) = config.model {
190 c.arg("--model").arg(model);
191 }
192 c
193 }
194 Harness::OpenCode => {
195 let mut c = Command::new(binary_path);
196 c.arg("run");
197 c.arg("--variant").arg("minimal");
198 if let Some(ref model) = config.model {
199 c.arg("--model").arg(model);
200 }
201 c.arg(&config.prompt);
202 c
203 }
204 };
205
206 cmd.current_dir(&config.working_dir);
208 cmd.env("SCUD_TASK_ID", &config.task_id);
209 cmd.stdout(Stdio::piped());
210 cmd.stderr(Stdio::piped());
211
212 let start_time = std::time::Instant::now();
213
214 let mut child = cmd.spawn().map_err(|e| {
216 anyhow::anyhow!(
217 "Failed to spawn {} for task {}: {}",
218 config.harness.name(),
219 config.task_id,
220 e
221 )
222 })?;
223
224 let _ = event_tx
226 .send(AgentEvent::Started {
227 task_id: task_id.clone(),
228 })
229 .await;
230
231 let stdout = child.stdout.take();
233 let stderr = child.stderr.take();
234 let event_tx_clone = event_tx.clone();
235 let task_id_clone = task_id.clone();
236
237 let handle = tokio::spawn(async move {
238 let mut output_buffer = String::new();
239
240 if let Some(stdout) = stdout {
242 let reader = BufReader::new(stdout);
243 let mut lines = reader.lines();
244 while let Ok(Some(line)) = lines.next_line().await {
245 output_buffer.push_str(&line);
246 output_buffer.push('\n');
247 let _ = event_tx_clone
248 .send(AgentEvent::Output {
249 task_id: task_id_clone.clone(),
250 line: line.clone(),
251 })
252 .await;
253 }
254 }
255
256 if let Some(stderr) = stderr {
258 let reader = BufReader::new(stderr);
259 let mut lines = reader.lines();
260 while let Ok(Some(line)) = lines.next_line().await {
261 output_buffer.push_str("[stderr] ");
262 output_buffer.push_str(&line);
263 output_buffer.push('\n');
264 }
265 }
266
267 let status = child.wait().await;
269 let duration_ms = start_time.elapsed().as_millis() as u64;
270
271 let (success, exit_code) = match status {
272 Ok(s) => (s.success(), s.code()),
273 Err(_) => (false, None),
274 };
275
276 let result = AgentResult {
277 task_id: task_id_clone.clone(),
278 success,
279 exit_code,
280 output: output_buffer,
281 duration_ms,
282 };
283
284 let _ = event_tx_clone
285 .send(AgentEvent::Completed {
286 result: result.clone(),
287 })
288 .await;
289
290 result
291 });
292
293 Ok(handle)
294}
295
296pub fn load_agent_config(
298 agent_type: Option<&str>,
299 default_harness: Harness,
300 default_model: Option<&str>,
301 working_dir: &Path,
302) -> (Harness, Option<String>) {
303 if let Some(agent_name) = agent_type {
304 if let Some(agent_def) = AgentDef::try_load(agent_name, working_dir) {
305 let harness = agent_def.harness().unwrap_or(default_harness);
306 let model = agent_def
307 .model()
308 .map(String::from)
309 .or_else(|| default_model.map(String::from));
310 return (harness, model);
311 }
312 }
313
314 (default_harness, default_model.map(String::from))
316}
317
318pub struct AgentRunner {
320 event_tx: mpsc::Sender<AgentEvent>,
322 event_rx: mpsc::Receiver<AgentEvent>,
324 handles: Vec<tokio::task::JoinHandle<AgentResult>>,
326}
327
328impl AgentRunner {
329 pub fn new(capacity: usize) -> Self {
331 let (event_tx, event_rx) = mpsc::channel(capacity);
332 Self {
333 event_tx,
334 event_rx,
335 handles: Vec::new(),
336 }
337 }
338
339 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
341 self.event_tx.clone()
342 }
343
344 pub async fn spawn(&mut self, config: SpawnConfig) -> anyhow::Result<()> {
346 let handle = spawn_agent(config, self.event_tx.clone()).await?;
347 self.handles.push(handle);
348 Ok(())
349 }
350
351 pub async fn recv_event(&mut self) -> Option<AgentEvent> {
353 self.event_rx.recv().await
354 }
355
356 pub fn try_recv_event(&mut self) -> Option<AgentEvent> {
358 self.event_rx.try_recv().ok()
359 }
360
361 pub async fn wait_all(&mut self) -> Vec<AgentResult> {
363 let handles = std::mem::take(&mut self.handles);
364 let mut results = Vec::new();
365
366 for handle in handles {
367 if let Ok(result) = handle.await {
368 results.push(result);
369 }
370 }
371
372 results
373 }
374
375 pub fn active_count(&self) -> usize {
377 self.handles.iter().filter(|h| !h.is_finished()).count()
378 }
379}
380
381pub async fn map_with_concurrency_limit<T, F, Fut, R>(
413 items: impl IntoIterator<Item = T>,
414 concurrency: usize,
415 f: F,
416) -> Vec<R>
417where
418 F: Fn(T) -> Fut,
419 Fut: Future<Output = R>,
420{
421 stream::iter(items)
422 .map(f)
423 .buffer_unordered(concurrency)
424 .collect()
425 .await
426}
427
428pub async fn map_with_concurrency_limit_ordered<T, F, Fut, R>(
441 items: impl IntoIterator<Item = T>,
442 concurrency: usize,
443 f: F,
444) -> Vec<R>
445where
446 F: Fn(T) -> Fut,
447 Fut: Future<Output = R>,
448{
449 stream::iter(items)
450 .map(f)
451 .buffered(concurrency)
452 .collect()
453 .await
454}
455
456pub async fn spawn_agents_with_limit(
469 configs: impl IntoIterator<Item = SpawnConfig>,
470 concurrency: usize,
471 event_tx: mpsc::Sender<AgentEvent>,
472) -> Vec<Result<AgentResult, anyhow::Error>> {
473 let configs: Vec<_> = configs.into_iter().collect();
474
475 map_with_concurrency_limit(configs, concurrency, |config| {
476 let tx = event_tx.clone();
477 async move {
478 match spawn_agent(config, tx).await {
479 Ok(handle) => handle.await.map_err(|e| anyhow::anyhow!("Join error: {}", e)),
480 Err(e) => Err(e),
481 }
482 }
483 })
484 .await
485}
486
/// Tuning options for `spawn_agents_concurrent`.
#[derive(Clone, Debug)]
pub struct ConcurrentSpawnConfig {
    /// Maximum number of agents running at the same time.
    pub max_concurrent: usize,
    /// Per-agent timeout in milliseconds; `0` means no timeout.
    pub timeout_ms: u64,
    /// Requested stop-on-first-failure behavior.
    ///
    /// NOTE(review): `spawn_agents_concurrent` does not read this flag; every agent runs
    /// regardless of its value.
    pub fail_fast: bool,
}

impl Default for ConcurrentSpawnConfig {
    /// Five concurrent agents, no timeout, `fail_fast` off.
    fn default() -> Self {
        ConcurrentSpawnConfig {
            fail_fast: false,
            timeout_ms: 0,
            max_concurrent: 5,
        }
    }
}
507
508#[derive(Debug)]
510pub struct ConcurrentSpawnResult {
511 pub successes: Vec<AgentResult>,
513 pub failures: Vec<(String, String)>,
515 pub all_succeeded: bool,
517}
518
519pub async fn spawn_agents_concurrent(
532 configs: Vec<SpawnConfig>,
533 spawn_config: ConcurrentSpawnConfig,
534 event_tx: mpsc::Sender<AgentEvent>,
535) -> ConcurrentSpawnResult {
536 let mut successes = Vec::new();
537 let mut failures = Vec::new();
538
539 let results = if spawn_config.timeout_ms > 0 {
540 let timeout_duration = std::time::Duration::from_millis(spawn_config.timeout_ms);
542
543 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
544 let tx = event_tx.clone();
545 let task_id = config.task_id.clone();
546 async move {
547 let result = tokio::time::timeout(timeout_duration, async {
548 match spawn_agent(config, tx).await {
549 Ok(handle) => handle
550 .await
551 .map_err(|e| anyhow::anyhow!("Join error: {}", e)),
552 Err(e) => Err(e),
553 }
554 })
555 .await;
556
557 match result {
558 Ok(Ok(agent_result)) => Ok(agent_result),
559 Ok(Err(e)) => Err((task_id, e.to_string())),
560 Err(_) => Err((task_id, "Timeout".to_string())),
561 }
562 }
563 })
564 .await
565 } else {
566 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
568 let tx = event_tx.clone();
569 let task_id = config.task_id.clone();
570 async move {
571 match spawn_agent(config, tx).await {
572 Ok(handle) => handle
573 .await
574 .map_err(|e| (task_id, format!("Join error: {}", e))),
575 Err(e) => Err((task_id, e.to_string())),
576 }
577 }
578 })
579 .await
580 };
581
582 for result in results {
583 match result {
584 Ok(agent_result) => successes.push(agent_result),
585 Err((task_id, error)) => failures.push((task_id, error)),
586 }
587 }
588
589 let all_succeeded = failures.is_empty();
590
591 ConcurrentSpawnResult {
592 successes,
593 failures,
594 all_succeeded,
595 }
596}
597
598pub async fn spawn_subagent(
614 task_id: String,
615 prompt: String,
616 working_dir: PathBuf,
617 harness: Harness,
618 model: Option<String>,
619) -> Result<AgentResult, anyhow::Error> {
620 let (tx, _rx) = mpsc::channel(10);
622
623 let config = SpawnConfig {
624 task_id,
625 prompt,
626 working_dir,
627 harness,
628 model,
629 };
630
631 let handle = spawn_agent(config, tx).await?;
632 handle
633 .await
634 .map_err(|e| anyhow::anyhow!("Subagent join error: {}", e))
635}
636
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds an `AgentResult` describing a clean (exit code 0) run.
    fn ok_result(task_id: &str, output: &str, duration_ms: u64) -> AgentResult {
        AgentResult {
            task_id: task_id.to_string(),
            success: true,
            exit_code: Some(0),
            output: output.to_string(),
            duration_ms,
        }
    }

    #[test]
    fn test_extension_runner_new() {
        let runner = ExtensionRunner::new();
        assert_eq!(runner.list_tools().len(), 0);
    }

    #[test]
    fn test_agent_result_debug() {
        let result = ok_result("test:1", "test output", 1000);

        assert!(result.success);
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(result.task_id, "test:1");
    }

    #[test]
    fn test_spawn_config_debug() {
        let config = SpawnConfig {
            task_id: "test:1".into(),
            prompt: "do something".into(),
            working_dir: PathBuf::from("/tmp"),
            harness: Harness::Claude,
            model: Some("opus".into()),
        };

        assert_eq!(config.task_id, "test:1");
        assert_eq!(config.harness, Harness::Claude);
    }

    #[tokio::test]
    async fn test_agent_runner_new() {
        let runner = AgentRunner::new(100);
        assert_eq!(runner.active_count(), 0);
    }

    #[test]
    fn test_tool_call_result() {
        let result = ToolCallResult {
            tool_name: String::from("my_tool"),
            output: serde_json::json!({ "key": "value" }),
            success: true,
        };

        assert!(result.success);
        assert_eq!(result.tool_name, "my_tool");
        assert_eq!(result.output["key"], "value");
    }

    #[test]
    fn test_on_tool_call_not_found() {
        let runner = ExtensionRunner::new();

        match runner.on_tool_call("nonexistent", serde_json::json!({})) {
            Err(ExtensionRunnerError::ToolNotFound(name)) => assert_eq!(name, "nonexistent"),
            _ => panic!("Expected ToolNotFound error"),
        }
    }

    #[test]
    fn test_on_tool_call_with_registered_tool() {
        /// Returns its first argument, or `null` when called without arguments.
        fn echo_tool(args: &[Value]) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(args.first().cloned().unwrap_or(Value::Null))
        }

        let mut runner = ExtensionRunner::new();
        runner.register_tool("echo".to_string(), echo_tool);

        let result = runner
            .on_tool_call("echo", serde_json::json!({ "test": 123 }))
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert!(result.success);
        assert_eq!(result.output["test"], 123);
    }

    #[test]
    fn test_on_tool_call_argument_conversion() {
        /// Reports how many positional arguments it received.
        fn count_args(args: &[Value]) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(serde_json::json!(args.len()))
        }

        let mut runner = ExtensionRunner::new();
        runner.register_tool("count".to_string(), count_args);

        // (JSON arguments, expected positional argument count)
        let cases = [
            (serde_json::json!([1, 2, 3]), 3),
            (serde_json::json!({ "a": 1 }), 1),
            (Value::Null, 0),
            (serde_json::json!(42), 1),
        ];
        for (arguments, expected) in cases {
            let result = runner.on_tool_call("count", arguments).unwrap();
            assert_eq!(result.output, expected);
        }
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit() {
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let results = map_with_concurrency_limit(0..10, 3, |n: i32| {
            let in_flight = Arc::clone(&in_flight);
            let peak = Arc::clone(&peak);
            async move {
                let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now, Ordering::SeqCst);

                tokio::time::sleep(Duration::from_millis(10)).await;

                in_flight.fetch_sub(1, Ordering::SeqCst);
                n * 2
            }
        })
        .await;

        assert_eq!(results.len(), 10);

        // Completion order is unspecified, so compare after sorting.
        let mut sorted = results;
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);

        assert!(peak.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit_ordered() {
        // Earlier items sleep longer, yet results must still come back in input order.
        let results =
            map_with_concurrency_limit_ordered(vec![1, 2, 3, 4, 5], 2, |n: i32| async move {
                tokio::time::sleep(Duration::from_millis((5 - n) as u64 * 5)).await;
                n * 10
            })
            .await;

        assert_eq!(results, vec![10, 20, 30, 40, 50]);
    }

    #[test]
    fn test_concurrent_spawn_config_default() {
        let config = ConcurrentSpawnConfig::default();

        assert_eq!(config.max_concurrent, 5);
        assert_eq!(config.timeout_ms, 0);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_concurrent_spawn_result() {
        let result = ConcurrentSpawnResult {
            successes: vec![ok_result("1", "done", 100)],
            failures: Vec::new(),
            all_succeeded: true,
        };

        assert!(result.all_succeeded);
        assert_eq!(result.successes.len(), 1);
        assert!(result.failures.is_empty());
    }

    #[test]
    fn test_concurrent_spawn_result_with_failures() {
        let result = ConcurrentSpawnResult {
            successes: Vec::new(),
            failures: vec![("task1".to_string(), "error msg".to_string())],
            all_succeeded: false,
        };

        assert!(!result.all_succeeded);
        assert!(result.successes.is_empty());
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "task1");
    }
}