1use crate::extensions::types::ToolFn;
9use futures::stream::{self, StreamExt};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::path::{Path, PathBuf};
14use std::process::Stdio;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::sync::mpsc;
18
19use crate::agents::AgentDef;
20use crate::commands::spawn::terminal::{find_harness_binary, Harness};
21
22pub struct ExtensionRunner {
28 registered_tools: HashMap<String, ToolFn>,
29}
30
31impl ExtensionRunner {
32 pub fn new() -> Self {
34 ExtensionRunner {
35 registered_tools: HashMap::new(),
36 }
37 }
38
39 pub fn register_tool(&mut self, name: String, tool_fn: ToolFn) {
41 self.registered_tools.insert(name, tool_fn);
42 }
43
44 pub fn execute_tool(&self, name: &str, args: &[Value]) -> Result<Value, ExtensionRunnerError> {
46 let tool_fn = self
47 .registered_tools
48 .get(name)
49 .ok_or_else(|| ExtensionRunnerError::ToolNotFound(name.to_string()))?;
50
51 tool_fn(args).map_err(ExtensionRunnerError::ExecutionError)
52 }
53
54 pub fn has_tool(&self, name: &str) -> bool {
56 self.registered_tools.contains_key(name)
57 }
58
59 pub fn list_tools(&self) -> Vec<String> {
61 self.registered_tools.keys().cloned().collect()
62 }
63
64 pub fn on_tool_call(
77 &self,
78 tool_name: &str,
79 arguments: Value,
80 ) -> Result<ToolCallResult, ExtensionRunnerError> {
81 let args = match arguments {
83 Value::Array(arr) => arr,
84 Value::Object(_) => vec![arguments],
85 Value::Null => vec![],
86 other => vec![other],
87 };
88
89 let result = self.execute_tool(tool_name, &args)?;
90
91 Ok(ToolCallResult {
92 tool_name: tool_name.to_string(),
93 output: result,
94 success: true,
95 })
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct ToolCallResult {
102 pub tool_name: String,
104 pub output: Value,
106 pub success: bool,
108}
109
110impl Default for ExtensionRunner {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116#[derive(Debug, thiserror::Error)]
118pub enum ExtensionRunnerError {
119 #[error("Tool not found: {0}")]
120 ToolNotFound(String),
121
122 #[error("Tool execution error: {0}")]
123 ExecutionError(Box<dyn std::error::Error + Send + Sync>),
124}
125
/// Final outcome of one agent process.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentResult {
    /// Identifier of the task the agent worked on.
    pub task_id: String,
    /// Whether the process exited with a success status.
    pub success: bool,
    /// Process exit code.
    ///
    /// `None` when no code is available, e.g. the process was terminated by a
    /// signal or its status could not be collected.
    pub exit_code: Option<i32>,
    /// Captured process output.
    pub output: String,
    /// Wall-clock run time in milliseconds.
    pub duration_ms: u64,
}
144
145#[derive(Debug, Clone)]
147pub enum AgentEvent {
148 Started { task_id: String },
150 Output { task_id: String, line: String },
152 Completed { result: AgentResult },
154 SpawnFailed { task_id: String, error: String },
156}
157
158#[derive(Debug, Clone)]
160pub struct SpawnConfig {
161 pub task_id: String,
163 pub prompt: String,
165 pub working_dir: PathBuf,
167 pub harness: Harness,
169 pub model: Option<String>,
171}
172
173pub async fn spawn_agent(
177 config: SpawnConfig,
178 event_tx: mpsc::Sender<AgentEvent>,
179) -> Result<tokio::task::JoinHandle<AgentResult>, anyhow::Error> {
180 let binary_path = find_harness_binary(config.harness)?;
181 let task_id = config.task_id.clone();
182
183 let mut cmd = match config.harness {
185 Harness::Claude => {
186 let mut c = Command::new(binary_path);
187 c.arg(&config.prompt);
188 c.arg("--dangerously-skip-permissions");
189 if let Some(ref model) = config.model {
190 c.arg("--model").arg(model);
191 }
192 c
193 }
194 Harness::OpenCode => {
195 let mut c = Command::new(binary_path);
196 c.arg("run");
197 c.arg("--variant").arg("minimal");
198 if let Some(ref model) = config.model {
199 c.arg("--model").arg(model);
200 }
201 c.arg(&config.prompt);
202 c
203 }
204 Harness::Cursor => {
205 let mut c = Command::new(binary_path);
206 c.arg("-p");
207 if let Some(ref model) = config.model {
208 c.arg("--model").arg(model);
209 }
210 c.arg(&config.prompt);
211 c
212 }
213 };
214
215 cmd.current_dir(&config.working_dir);
217 cmd.env("SCUD_TASK_ID", &config.task_id);
218 cmd.stdout(Stdio::piped());
219 cmd.stderr(Stdio::piped());
220
221 let start_time = std::time::Instant::now();
222
223 let mut child = cmd.spawn().map_err(|e| {
225 anyhow::anyhow!(
226 "Failed to spawn {} for task {}: {}",
227 config.harness.name(),
228 config.task_id,
229 e
230 )
231 })?;
232
233 let _ = event_tx
235 .send(AgentEvent::Started {
236 task_id: task_id.clone(),
237 })
238 .await;
239
240 let stdout = child.stdout.take();
242 let stderr = child.stderr.take();
243 let event_tx_clone = event_tx.clone();
244 let task_id_clone = task_id.clone();
245
246 let handle = tokio::spawn(async move {
247 let mut output_buffer = String::new();
248
249 if let Some(stdout) = stdout {
251 let reader = BufReader::new(stdout);
252 let mut lines = reader.lines();
253 while let Ok(Some(line)) = lines.next_line().await {
254 output_buffer.push_str(&line);
255 output_buffer.push('\n');
256 let _ = event_tx_clone
257 .send(AgentEvent::Output {
258 task_id: task_id_clone.clone(),
259 line: line.clone(),
260 })
261 .await;
262 }
263 }
264
265 if let Some(stderr) = stderr {
267 let reader = BufReader::new(stderr);
268 let mut lines = reader.lines();
269 while let Ok(Some(line)) = lines.next_line().await {
270 output_buffer.push_str("[stderr] ");
271 output_buffer.push_str(&line);
272 output_buffer.push('\n');
273 }
274 }
275
276 let status = child.wait().await;
278 let duration_ms = start_time.elapsed().as_millis() as u64;
279
280 let (success, exit_code) = match status {
281 Ok(s) => (s.success(), s.code()),
282 Err(_) => (false, None),
283 };
284
285 let result = AgentResult {
286 task_id: task_id_clone.clone(),
287 success,
288 exit_code,
289 output: output_buffer,
290 duration_ms,
291 };
292
293 let _ = event_tx_clone
294 .send(AgentEvent::Completed {
295 result: result.clone(),
296 })
297 .await;
298
299 result
300 });
301
302 Ok(handle)
303}
304
305pub fn load_agent_config(
307 agent_type: Option<&str>,
308 default_harness: Harness,
309 default_model: Option<&str>,
310 working_dir: &Path,
311) -> (Harness, Option<String>) {
312 if let Some(agent_name) = agent_type {
313 if let Some(agent_def) = AgentDef::try_load(agent_name, working_dir) {
314 let harness = agent_def.harness().unwrap_or(default_harness);
315 let model = agent_def
316 .model()
317 .map(String::from)
318 .or_else(|| default_model.map(String::from));
319 return (harness, model);
320 }
321 }
322
323 (default_harness, default_model.map(String::from))
325}
326
327pub struct AgentRunner {
329 event_tx: mpsc::Sender<AgentEvent>,
331 event_rx: mpsc::Receiver<AgentEvent>,
333 handles: Vec<tokio::task::JoinHandle<AgentResult>>,
335}
336
337impl AgentRunner {
338 pub fn new(capacity: usize) -> Self {
340 let (event_tx, event_rx) = mpsc::channel(capacity);
341 Self {
342 event_tx,
343 event_rx,
344 handles: Vec::new(),
345 }
346 }
347
348 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
350 self.event_tx.clone()
351 }
352
353 pub async fn spawn(&mut self, config: SpawnConfig) -> anyhow::Result<()> {
355 let handle = spawn_agent(config, self.event_tx.clone()).await?;
356 self.handles.push(handle);
357 Ok(())
358 }
359
360 pub async fn recv_event(&mut self) -> Option<AgentEvent> {
362 self.event_rx.recv().await
363 }
364
365 pub fn try_recv_event(&mut self) -> Option<AgentEvent> {
367 self.event_rx.try_recv().ok()
368 }
369
370 pub async fn wait_all(&mut self) -> Vec<AgentResult> {
372 let handles = std::mem::take(&mut self.handles);
373 let mut results = Vec::new();
374
375 for handle in handles {
376 if let Ok(result) = handle.await {
377 results.push(result);
378 }
379 }
380
381 results
382 }
383
384 pub fn active_count(&self) -> usize {
386 self.handles.iter().filter(|h| !h.is_finished()).count()
387 }
388}
389
390pub async fn map_with_concurrency_limit<T, F, Fut, R>(
422 items: impl IntoIterator<Item = T>,
423 concurrency: usize,
424 f: F,
425) -> Vec<R>
426where
427 F: Fn(T) -> Fut,
428 Fut: Future<Output = R>,
429{
430 stream::iter(items)
431 .map(f)
432 .buffer_unordered(concurrency)
433 .collect()
434 .await
435}
436
437pub async fn map_with_concurrency_limit_ordered<T, F, Fut, R>(
450 items: impl IntoIterator<Item = T>,
451 concurrency: usize,
452 f: F,
453) -> Vec<R>
454where
455 F: Fn(T) -> Fut,
456 Fut: Future<Output = R>,
457{
458 stream::iter(items)
459 .map(f)
460 .buffered(concurrency)
461 .collect()
462 .await
463}
464
465pub async fn spawn_agents_with_limit(
478 configs: impl IntoIterator<Item = SpawnConfig>,
479 concurrency: usize,
480 event_tx: mpsc::Sender<AgentEvent>,
481) -> Vec<Result<AgentResult, anyhow::Error>> {
482 let configs: Vec<_> = configs.into_iter().collect();
483
484 map_with_concurrency_limit(configs, concurrency, |config| {
485 let tx = event_tx.clone();
486 async move {
487 match spawn_agent(config, tx).await {
488 Ok(handle) => handle.await.map_err(|e| anyhow::anyhow!("Join error: {}", e)),
489 Err(e) => Err(e),
490 }
491 }
492 })
493 .await
494}
495
/// Tuning knobs for running many agents concurrently.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConcurrentSpawnConfig {
    /// Maximum number of agents allowed to run at the same time.
    pub max_concurrent: usize,
    /// Per-agent time limit in milliseconds; `0` disables the limit.
    pub timeout_ms: u64,
    /// Requests that no new agents be started once one has failed.
    ///
    /// Only takes effect in runners that check it.
    pub fail_fast: bool,
}
506
507impl Default for ConcurrentSpawnConfig {
508 fn default() -> Self {
509 Self {
510 max_concurrent: 5,
511 timeout_ms: 0,
512 fail_fast: false,
513 }
514 }
515}
516
517#[derive(Debug)]
519pub struct ConcurrentSpawnResult {
520 pub successes: Vec<AgentResult>,
522 pub failures: Vec<(String, String)>,
524 pub all_succeeded: bool,
526}
527
528pub async fn spawn_agents_concurrent(
541 configs: Vec<SpawnConfig>,
542 spawn_config: ConcurrentSpawnConfig,
543 event_tx: mpsc::Sender<AgentEvent>,
544) -> ConcurrentSpawnResult {
545 let mut successes = Vec::new();
546 let mut failures = Vec::new();
547
548 let results = if spawn_config.timeout_ms > 0 {
549 let timeout_duration = std::time::Duration::from_millis(spawn_config.timeout_ms);
551
552 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
553 let tx = event_tx.clone();
554 let task_id = config.task_id.clone();
555 async move {
556 let result = tokio::time::timeout(timeout_duration, async {
557 match spawn_agent(config, tx).await {
558 Ok(handle) => handle
559 .await
560 .map_err(|e| anyhow::anyhow!("Join error: {}", e)),
561 Err(e) => Err(e),
562 }
563 })
564 .await;
565
566 match result {
567 Ok(Ok(agent_result)) => Ok(agent_result),
568 Ok(Err(e)) => Err((task_id, e.to_string())),
569 Err(_) => Err((task_id, "Timeout".to_string())),
570 }
571 }
572 })
573 .await
574 } else {
575 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
577 let tx = event_tx.clone();
578 let task_id = config.task_id.clone();
579 async move {
580 match spawn_agent(config, tx).await {
581 Ok(handle) => handle
582 .await
583 .map_err(|e| (task_id, format!("Join error: {}", e))),
584 Err(e) => Err((task_id, e.to_string())),
585 }
586 }
587 })
588 .await
589 };
590
591 for result in results {
592 match result {
593 Ok(agent_result) => successes.push(agent_result),
594 Err((task_id, error)) => failures.push((task_id, error)),
595 }
596 }
597
598 let all_succeeded = failures.is_empty();
599
600 ConcurrentSpawnResult {
601 successes,
602 failures,
603 all_succeeded,
604 }
605}
606
607pub async fn spawn_subagent(
623 task_id: String,
624 prompt: String,
625 working_dir: PathBuf,
626 harness: Harness,
627 model: Option<String>,
628) -> Result<AgentResult, anyhow::Error> {
629 let (tx, _rx) = mpsc::channel(10);
631
632 let config = SpawnConfig {
633 task_id,
634 prompt,
635 working_dir,
636 harness,
637 model,
638 };
639
640 let handle = spawn_agent(config, tx).await?;
641 handle
642 .await
643 .map_err(|e| anyhow::anyhow!("Subagent join error: {}", e))
644}
645
#[cfg(test)]
mod tests {
    //! Unit tests for the extension runner, agent data types and concurrency helpers.

    use super::*;

    #[test]
    fn test_extension_runner_new() {
        let runner = ExtensionRunner::new();
        assert!(runner.list_tools().is_empty());
    }

    #[test]
    fn test_agent_result_debug() {
        let result = AgentResult {
            task_id: "test:1".to_string(),
            success: true,
            exit_code: Some(0),
            output: "test output".to_string(),
            duration_ms: 1000,
        };

        assert!(result.success);
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(result.task_id, "test:1");

        // The test is named for `Debug`, so actually exercise the derived impl.
        let rendered = format!("{:?}", result);
        assert!(rendered.starts_with("AgentResult"));
        assert!(rendered.contains("test:1"));
    }

    #[test]
    fn test_spawn_config_debug() {
        let config = SpawnConfig {
            task_id: "test:1".to_string(),
            prompt: "do something".to_string(),
            working_dir: PathBuf::from("/tmp"),
            harness: Harness::Claude,
            model: Some("opus".to_string()),
        };

        assert_eq!(config.task_id, "test:1");
        assert_eq!(config.harness, Harness::Claude);

        // The test is named for `Debug`, so actually exercise the derived impl.
        let rendered = format!("{:?}", config);
        assert!(rendered.starts_with("SpawnConfig"));
        assert!(rendered.contains("do something"));
    }

    #[tokio::test]
    async fn test_agent_runner_new() {
        let runner = AgentRunner::new(100);
        assert_eq!(runner.active_count(), 0);
    }

    #[test]
    fn test_tool_call_result() {
        let result = ToolCallResult {
            tool_name: "my_tool".to_string(),
            output: serde_json::json!({"key": "value"}),
            success: true,
        };

        assert_eq!(result.tool_name, "my_tool");
        assert!(result.success);
        assert_eq!(result.output["key"], "value");
    }

    #[test]
    fn test_on_tool_call_not_found() {
        let runner = ExtensionRunner::new();
        let result = runner.on_tool_call("nonexistent", serde_json::json!({}));

        assert!(result.is_err());
        match result {
            Err(ExtensionRunnerError::ToolNotFound(name)) => {
                assert_eq!(name, "nonexistent");
            }
            _ => panic!("Expected ToolNotFound error"),
        }
    }

    #[test]
    fn test_on_tool_call_with_registered_tool() {
        let mut runner = ExtensionRunner::new();

        // Returns its first argument unchanged (or null when called without arguments).
        fn echo_tool(
            args: &[Value],
        ) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(args.first().cloned().unwrap_or(Value::Null))
        }

        runner.register_tool("echo".to_string(), echo_tool);

        let result = runner
            .on_tool_call("echo", serde_json::json!({"test": 123}))
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert!(result.success);
        assert_eq!(result.output["test"], 123);
    }

    #[test]
    fn test_on_tool_call_argument_conversion() {
        let mut runner = ExtensionRunner::new();

        // Reports how many positional arguments the call was converted into.
        fn count_args(
            args: &[Value],
        ) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(serde_json::json!(args.len()))
        }

        runner.register_tool("count".to_string(), count_args);

        // Arrays are spread into positional arguments.
        let result = runner
            .on_tool_call("count", serde_json::json!([1, 2, 3]))
            .unwrap();
        assert_eq!(result.output, 3);

        // Objects are passed as a single argument.
        let result = runner
            .on_tool_call("count", serde_json::json!({"a": 1}))
            .unwrap();
        assert_eq!(result.output, 1);

        // Null means "no arguments".
        let result = runner.on_tool_call("count", Value::Null).unwrap();
        assert_eq!(result.output, 0);

        // Scalars are passed as a single argument.
        let result = runner.on_tool_call("count", serde_json::json!(42)).unwrap();
        assert_eq!(result.output, 1);
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let items: Vec<i32> = (0..10).collect();
        let counter = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let results = map_with_concurrency_limit(items, 3, |n| {
            let counter = Arc::clone(&counter);
            let max_concurrent = Arc::clone(&max_concurrent);
            async move {
                // Count this future as in flight.
                let current = counter.fetch_add(1, Ordering::SeqCst) + 1;

                // Record the highest number of simultaneously running futures.
                let mut max = max_concurrent.load(Ordering::SeqCst);
                while current > max {
                    match max_concurrent.compare_exchange_weak(
                        max,
                        current,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    ) {
                        Ok(_) => break,
                        Err(new_max) => max = new_max,
                    }
                }

                // Yield so other futures get a chance to start concurrently.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                counter.fetch_sub(1, Ordering::SeqCst);

                n * 2
            }
        })
        .await;

        assert_eq!(results.len(), 10);

        // Completion order is unspecified, so compare after sorting.
        let mut sorted: Vec<i32> = results;
        sorted.sort();
        assert_eq!(sorted, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);

        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit_ordered() {
        let items: Vec<i32> = vec![1, 2, 3, 4, 5];

        // Earlier items sleep longer, so finishing order is the reverse of input order.
        let results = map_with_concurrency_limit_ordered(items, 2, |n| async move {
            tokio::time::sleep(std::time::Duration::from_millis((5 - n) as u64 * 5)).await;
            n * 10
        })
        .await;

        assert_eq!(results, vec![10, 20, 30, 40, 50]);
    }

    #[test]
    fn test_concurrent_spawn_config_default() {
        let config = ConcurrentSpawnConfig::default();

        assert_eq!(config.max_concurrent, 5);
        assert_eq!(config.timeout_ms, 0);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_concurrent_spawn_result() {
        let result = ConcurrentSpawnResult {
            successes: vec![AgentResult {
                task_id: "1".to_string(),
                success: true,
                exit_code: Some(0),
                output: "done".to_string(),
                duration_ms: 100,
            }],
            failures: vec![],
            all_succeeded: true,
        };

        assert!(result.all_succeeded);
        assert_eq!(result.successes.len(), 1);
        assert!(result.failures.is_empty());
    }

    #[test]
    fn test_concurrent_spawn_result_with_failures() {
        let result = ConcurrentSpawnResult {
            successes: vec![],
            failures: vec![("task1".to_string(), "error msg".to_string())],
            all_succeeded: false,
        };

        assert!(!result.all_succeeded);
        assert!(result.successes.is_empty());
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "task1");
    }
}