1use crate::extensions::types::ToolFn;
9use futures::stream::{self, StreamExt};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::path::{Path, PathBuf};
14use std::process::Stdio;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::sync::mpsc;
18
19use crate::agents::AgentDef;
20use crate::commands::spawn::terminal::{find_harness_binary, Harness};
21
22pub struct ExtensionRunner {
28 registered_tools: HashMap<String, ToolFn>,
29}
30
31impl ExtensionRunner {
32 pub fn new() -> Self {
34 ExtensionRunner {
35 registered_tools: HashMap::new(),
36 }
37 }
38
39 pub fn register_tool(&mut self, name: String, tool_fn: ToolFn) {
41 self.registered_tools.insert(name, tool_fn);
42 }
43
44 pub fn execute_tool(&self, name: &str, args: &[Value]) -> Result<Value, ExtensionRunnerError> {
46 let tool_fn = self
47 .registered_tools
48 .get(name)
49 .ok_or_else(|| ExtensionRunnerError::ToolNotFound(name.to_string()))?;
50
51 tool_fn(args).map_err(ExtensionRunnerError::ExecutionError)
52 }
53
54 pub fn has_tool(&self, name: &str) -> bool {
56 self.registered_tools.contains_key(name)
57 }
58
59 pub fn list_tools(&self) -> Vec<String> {
61 self.registered_tools.keys().cloned().collect()
62 }
63
64 pub fn on_tool_call(
77 &self,
78 tool_name: &str,
79 arguments: Value,
80 ) -> Result<ToolCallResult, ExtensionRunnerError> {
81 let args = match arguments {
83 Value::Array(arr) => arr,
84 Value::Object(_) => vec![arguments],
85 Value::Null => vec![],
86 other => vec![other],
87 };
88
89 let result = self.execute_tool(tool_name, &args)?;
90
91 Ok(ToolCallResult {
92 tool_name: tool_name.to_string(),
93 output: result,
94 success: true,
95 })
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct ToolCallResult {
102 pub tool_name: String,
104 pub output: Value,
106 pub success: bool,
108}
109
110impl Default for ExtensionRunner {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116#[derive(Debug, thiserror::Error)]
118pub enum ExtensionRunnerError {
119 #[error("Tool not found: {0}")]
120 ToolNotFound(String),
121
122 #[error("Tool execution error: {0}")]
123 ExecutionError(Box<dyn std::error::Error + Send + Sync>),
124}
125
/// Final outcome of one agent process started by [`spawn_agent`].
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Task identifier the agent was spawned for.
    pub task_id: String,
    /// Whether the process exited successfully; `false` if waiting on it failed.
    pub success: bool,
    /// Process exit code, or `None` when no code is available (e.g. waiting failed,
    /// or the platform reports no code for the termination).
    pub exit_code: Option<i32>,
    /// Captured stdout lines, followed by stderr lines prefixed with `[stderr] `.
    pub output: String,
    /// Wall-clock time from just before spawning until the process was reaped, in ms.
    pub duration_ms: u64,
}
144
145#[derive(Debug, Clone)]
147pub enum AgentEvent {
148 Started { task_id: String },
150 Output { task_id: String, line: String },
152 Completed { result: AgentResult },
154 SpawnFailed { task_id: String, error: String },
156}
157
158#[derive(Debug, Clone)]
160pub struct SpawnConfig {
161 pub task_id: String,
163 pub prompt: String,
165 pub working_dir: PathBuf,
167 pub harness: Harness,
169 pub model: Option<String>,
171}
172
173pub async fn spawn_agent(
177 config: SpawnConfig,
178 event_tx: mpsc::Sender<AgentEvent>,
179) -> Result<tokio::task::JoinHandle<AgentResult>, anyhow::Error> {
180 let binary_path = find_harness_binary(config.harness)?;
181 let task_id = config.task_id.clone();
182
183 let mut cmd = match config.harness {
185 Harness::Claude => {
186 let mut c = Command::new(binary_path);
187 c.arg(&config.prompt);
188 c.arg("--dangerously-skip-permissions");
189 if let Some(ref model) = config.model {
190 c.arg("--model").arg(model);
191 }
192 c
193 }
194 Harness::OpenCode => {
195 let mut c = Command::new(binary_path);
196 c.arg("run");
197 c.arg("--variant").arg("minimal");
198 if let Some(ref model) = config.model {
199 c.arg("--model").arg(model);
200 }
201 c.arg(&config.prompt);
202 c
203 }
204 Harness::Cursor => {
205 let mut c = Command::new(binary_path);
206 c.arg("-p");
207 if let Some(ref model) = config.model {
208 c.arg("--model").arg(model);
209 }
210 c.arg(&config.prompt);
211 c
212 }
213 Harness::Rho => {
214 let mut c = Command::new(binary_path);
215 c.arg("-p").arg(&config.prompt);
216 c.arg("-C").arg(&config.working_dir);
217 if let Some(ref model) = config.model {
218 c.arg("--model").arg(model);
219 }
220 c
221 }
222 #[cfg(feature = "direct-api")]
223 Harness::DirectApi => {
224 let mut c = Command::new(binary_path);
225 c.arg("agent-exec");
226 c.arg("--prompt").arg(&config.prompt);
227 if let Some(ref model) = config.model {
228 c.arg("--model").arg(model);
229 }
230 c
231 }
232 };
233
234 cmd.current_dir(&config.working_dir);
236 cmd.env("SCUD_TASK_ID", &config.task_id);
237 cmd.stdout(Stdio::piped());
238 cmd.stderr(Stdio::piped());
239
240 let start_time = std::time::Instant::now();
241
242 let mut child = cmd.spawn().map_err(|e| {
244 anyhow::anyhow!(
245 "Failed to spawn {} for task {}: {}",
246 config.harness.name(),
247 config.task_id,
248 e
249 )
250 })?;
251
252 let _ = event_tx
254 .send(AgentEvent::Started {
255 task_id: task_id.clone(),
256 })
257 .await;
258
259 let stdout = child.stdout.take();
261 let stderr = child.stderr.take();
262 let event_tx_clone = event_tx.clone();
263 let task_id_clone = task_id.clone();
264
265 let handle = tokio::spawn(async move {
266 let mut output_buffer = String::new();
267
268 if let Some(stdout) = stdout {
270 let reader = BufReader::new(stdout);
271 let mut lines = reader.lines();
272 while let Ok(Some(line)) = lines.next_line().await {
273 output_buffer.push_str(&line);
274 output_buffer.push('\n');
275 let _ = event_tx_clone
276 .send(AgentEvent::Output {
277 task_id: task_id_clone.clone(),
278 line: line.clone(),
279 })
280 .await;
281 }
282 }
283
284 if let Some(stderr) = stderr {
286 let reader = BufReader::new(stderr);
287 let mut lines = reader.lines();
288 while let Ok(Some(line)) = lines.next_line().await {
289 output_buffer.push_str("[stderr] ");
290 output_buffer.push_str(&line);
291 output_buffer.push('\n');
292 }
293 }
294
295 let status = child.wait().await;
297 let duration_ms = start_time.elapsed().as_millis() as u64;
298
299 let (success, exit_code) = match status {
300 Ok(s) => (s.success(), s.code()),
301 Err(_) => (false, None),
302 };
303
304 let result = AgentResult {
305 task_id: task_id_clone.clone(),
306 success,
307 exit_code,
308 output: output_buffer,
309 duration_ms,
310 };
311
312 let _ = event_tx_clone
313 .send(AgentEvent::Completed {
314 result: result.clone(),
315 })
316 .await;
317
318 result
319 });
320
321 Ok(handle)
322}
323
324pub fn load_agent_config(
326 agent_type: Option<&str>,
327 default_harness: Harness,
328 default_model: Option<&str>,
329 working_dir: &Path,
330) -> (Harness, Option<String>) {
331 if let Some(agent_name) = agent_type {
332 if let Some(agent_def) = AgentDef::try_load(agent_name, working_dir) {
333 let harness = agent_def.harness().unwrap_or(default_harness);
334 let model = agent_def
335 .model()
336 .map(String::from)
337 .or_else(|| default_model.map(String::from));
338 return (harness, model);
339 }
340 }
341
342 (default_harness, default_model.map(String::from))
344}
345
346pub struct AgentRunner {
348 event_tx: mpsc::Sender<AgentEvent>,
350 event_rx: mpsc::Receiver<AgentEvent>,
352 handles: Vec<tokio::task::JoinHandle<AgentResult>>,
354}
355
356impl AgentRunner {
357 pub fn new(capacity: usize) -> Self {
359 let (event_tx, event_rx) = mpsc::channel(capacity);
360 Self {
361 event_tx,
362 event_rx,
363 handles: Vec::new(),
364 }
365 }
366
367 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
369 self.event_tx.clone()
370 }
371
372 pub async fn spawn(&mut self, config: SpawnConfig) -> anyhow::Result<()> {
374 let handle = spawn_agent(config, self.event_tx.clone()).await?;
375 self.handles.push(handle);
376 Ok(())
377 }
378
379 pub async fn recv_event(&mut self) -> Option<AgentEvent> {
381 self.event_rx.recv().await
382 }
383
384 pub fn try_recv_event(&mut self) -> Option<AgentEvent> {
386 self.event_rx.try_recv().ok()
387 }
388
389 pub async fn wait_all(&mut self) -> Vec<AgentResult> {
391 let handles = std::mem::take(&mut self.handles);
392 let mut results = Vec::new();
393
394 for handle in handles {
395 if let Ok(result) = handle.await {
396 results.push(result);
397 }
398 }
399
400 results
401 }
402
403 pub fn active_count(&self) -> usize {
405 self.handles.iter().filter(|h| !h.is_finished()).count()
406 }
407}
408
409pub async fn map_with_concurrency_limit<T, F, Fut, R>(
441 items: impl IntoIterator<Item = T>,
442 concurrency: usize,
443 f: F,
444) -> Vec<R>
445where
446 F: Fn(T) -> Fut,
447 Fut: Future<Output = R>,
448{
449 stream::iter(items)
450 .map(f)
451 .buffer_unordered(concurrency)
452 .collect()
453 .await
454}
455
456pub async fn map_with_concurrency_limit_ordered<T, F, Fut, R>(
469 items: impl IntoIterator<Item = T>,
470 concurrency: usize,
471 f: F,
472) -> Vec<R>
473where
474 F: Fn(T) -> Fut,
475 Fut: Future<Output = R>,
476{
477 stream::iter(items)
478 .map(f)
479 .buffered(concurrency)
480 .collect()
481 .await
482}
483
484pub async fn spawn_agents_with_limit(
497 configs: impl IntoIterator<Item = SpawnConfig>,
498 concurrency: usize,
499 event_tx: mpsc::Sender<AgentEvent>,
500) -> Vec<Result<AgentResult, anyhow::Error>> {
501 let configs: Vec<_> = configs.into_iter().collect();
502
503 map_with_concurrency_limit(configs, concurrency, |config| {
504 let tx = event_tx.clone();
505 async move {
506 match spawn_agent(config, tx).await {
507 Ok(handle) => handle
508 .await
509 .map_err(|e| anyhow::anyhow!("Join error: {}", e)),
510 Err(e) => Err(e),
511 }
512 }
513 })
514 .await
515}
516
/// Tuning knobs for [`spawn_agents_concurrent`].
#[derive(Debug, Clone)]
pub struct ConcurrentSpawnConfig {
    /// Maximum number of agents running at the same time.
    pub max_concurrent: usize,
    /// Per-agent time limit in milliseconds; `0` disables the limit.
    pub timeout_ms: u64,
    /// Intended to stop after the first failure. NOTE(review): not read by
    /// `spawn_agents_concurrent` in this file — confirm whether any caller honours it.
    pub fail_fast: bool,
}
527
528impl Default for ConcurrentSpawnConfig {
529 fn default() -> Self {
530 Self {
531 max_concurrent: 5,
532 timeout_ms: 0,
533 fail_fast: false,
534 }
535 }
536}
537
538#[derive(Debug)]
540pub struct ConcurrentSpawnResult {
541 pub successes: Vec<AgentResult>,
543 pub failures: Vec<(String, String)>,
545 pub all_succeeded: bool,
547}
548
549pub async fn spawn_agents_concurrent(
562 configs: Vec<SpawnConfig>,
563 spawn_config: ConcurrentSpawnConfig,
564 event_tx: mpsc::Sender<AgentEvent>,
565) -> ConcurrentSpawnResult {
566 let mut successes = Vec::new();
567 let mut failures = Vec::new();
568
569 let results = if spawn_config.timeout_ms > 0 {
570 let timeout_duration = std::time::Duration::from_millis(spawn_config.timeout_ms);
572
573 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
574 let tx = event_tx.clone();
575 let task_id = config.task_id.clone();
576 async move {
577 let result = tokio::time::timeout(timeout_duration, async {
578 match spawn_agent(config, tx).await {
579 Ok(handle) => handle
580 .await
581 .map_err(|e| anyhow::anyhow!("Join error: {}", e)),
582 Err(e) => Err(e),
583 }
584 })
585 .await;
586
587 match result {
588 Ok(Ok(agent_result)) => Ok(agent_result),
589 Ok(Err(e)) => Err((task_id, e.to_string())),
590 Err(_) => Err((task_id, "Timeout".to_string())),
591 }
592 }
593 })
594 .await
595 } else {
596 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
598 let tx = event_tx.clone();
599 let task_id = config.task_id.clone();
600 async move {
601 match spawn_agent(config, tx).await {
602 Ok(handle) => handle
603 .await
604 .map_err(|e| (task_id, format!("Join error: {}", e))),
605 Err(e) => Err((task_id, e.to_string())),
606 }
607 }
608 })
609 .await
610 };
611
612 for result in results {
613 match result {
614 Ok(agent_result) => successes.push(agent_result),
615 Err((task_id, error)) => failures.push((task_id, error)),
616 }
617 }
618
619 let all_succeeded = failures.is_empty();
620
621 ConcurrentSpawnResult {
622 successes,
623 failures,
624 all_succeeded,
625 }
626}
627
628pub async fn spawn_subagent(
644 task_id: String,
645 prompt: String,
646 working_dir: PathBuf,
647 harness: Harness,
648 model: Option<String>,
649) -> Result<AgentResult, anyhow::Error> {
650 let (tx, _rx) = mpsc::channel(10);
652
653 let config = SpawnConfig {
654 task_id,
655 prompt,
656 working_dir,
657 harness,
658 model,
659 };
660
661 let handle = spawn_agent(config, tx).await?;
662 handle
663 .await
664 .map_err(|e| anyhow::anyhow!("Subagent join error: {}", e))
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extension_runner_new() {
        let runner = ExtensionRunner::new();
        assert!(runner.list_tools().is_empty());
    }

    #[test]
    fn test_agent_result_debug() {
        let result = AgentResult {
            task_id: "test:1".to_string(),
            success: true,
            exit_code: Some(0),
            output: "test output".to_string(),
            duration_ms: 1000,
        };

        assert!(result.success);
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(result.task_id, "test:1");
    }

    #[test]
    fn test_spawn_config_debug() {
        let config = SpawnConfig {
            task_id: "test:1".to_string(),
            prompt: "do something".to_string(),
            working_dir: PathBuf::from("/tmp"),
            harness: Harness::Claude,
            model: Some("opus".to_string()),
        };

        assert_eq!(config.task_id, "test:1");
        assert_eq!(config.harness, Harness::Claude);
    }

    #[tokio::test]
    async fn test_agent_runner_new() {
        let runner = AgentRunner::new(100);
        assert_eq!(runner.active_count(), 0);
    }

    #[tokio::test]
    async fn test_agent_runner_wait_all_without_agents() {
        let mut runner = AgentRunner::new(4);
        assert!(runner.wait_all().await.is_empty());
        assert!(runner.try_recv_event().is_none());
        assert_eq!(runner.active_count(), 0);
    }

    #[test]
    fn test_tool_call_result() {
        let result = ToolCallResult {
            tool_name: "my_tool".to_string(),
            output: serde_json::json!({"key": "value"}),
            success: true,
        };

        assert_eq!(result.tool_name, "my_tool");
        assert!(result.success);
        assert_eq!(result.output["key"], "value");
    }

    #[test]
    fn test_on_tool_call_not_found() {
        let runner = ExtensionRunner::new();
        let result = runner.on_tool_call("nonexistent", serde_json::json!({}));

        assert!(result.is_err());
        match result {
            Err(ExtensionRunnerError::ToolNotFound(name)) => {
                assert_eq!(name, "nonexistent");
            }
            _ => panic!("Expected ToolNotFound error"),
        }
    }

    #[test]
    fn test_on_tool_call_with_registered_tool() {
        let mut runner = ExtensionRunner::new();

        fn echo_tool(args: &[Value]) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(args.first().cloned().unwrap_or(Value::Null))
        }

        runner.register_tool("echo".to_string(), echo_tool);

        let result = runner
            .on_tool_call("echo", serde_json::json!({"test": 123}))
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert!(result.success);
        assert_eq!(result.output["test"], 123);
    }

    #[test]
    fn test_register_tool_is_listed_and_found() {
        let mut runner = ExtensionRunner::new();

        fn noop(_args: &[Value]) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(Value::Null)
        }

        assert!(!runner.has_tool("noop"));
        runner.register_tool("noop".to_string(), noop);
        assert!(runner.has_tool("noop"));
        assert_eq!(runner.list_tools(), vec!["noop".to_string()]);
    }

    #[test]
    fn test_on_tool_call_propagates_tool_error() {
        let mut runner = ExtensionRunner::new();

        fn failing(_args: &[Value]) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Err("boom".into())
        }

        runner.register_tool("fail".to_string(), failing);

        match runner.on_tool_call("fail", Value::Null) {
            Err(ExtensionRunnerError::ExecutionError(e)) => assert_eq!(e.to_string(), "boom"),
            other => panic!("Expected ExecutionError, got {:?}", other),
        }
    }

    #[test]
    fn test_on_tool_call_argument_conversion() {
        let mut runner = ExtensionRunner::new();

        fn count_args(args: &[Value]) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(serde_json::json!(args.len()))
        }

        runner.register_tool("count".to_string(), count_args);

        let result = runner
            .on_tool_call("count", serde_json::json!([1, 2, 3]))
            .unwrap();
        assert_eq!(result.output, 3);

        let result = runner
            .on_tool_call("count", serde_json::json!({"a": 1}))
            .unwrap();
        assert_eq!(result.output, 1);

        let result = runner.on_tool_call("count", Value::Null).unwrap();
        assert_eq!(result.output, 0);

        let result = runner.on_tool_call("count", serde_json::json!(42)).unwrap();
        assert_eq!(result.output, 1);
    }

    #[test]
    fn test_load_agent_config_without_agent_type_uses_defaults() {
        let (harness, model) =
            load_agent_config(None, Harness::Claude, Some("opus"), Path::new("/tmp"));
        assert_eq!(harness, Harness::Claude);
        assert_eq!(model.as_deref(), Some("opus"));

        let (harness, model) = load_agent_config(None, Harness::OpenCode, None, Path::new("/tmp"));
        assert_eq!(harness, Harness::OpenCode);
        assert_eq!(model, None);
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let items: Vec<i32> = (0..10).collect();
        let counter = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let results = map_with_concurrency_limit(items, 3, |n| {
            let counter = Arc::clone(&counter);
            let max_concurrent = Arc::clone(&max_concurrent);
            async move {
                // Track how many futures are in flight right now.
                let current = counter.fetch_add(1, Ordering::SeqCst) + 1;

                // Record the peak concurrency observed.
                let mut max = max_concurrent.load(Ordering::SeqCst);
                while current > max {
                    match max_concurrent.compare_exchange_weak(
                        max,
                        current,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    ) {
                        Ok(_) => break,
                        Err(new_max) => max = new_max,
                    }
                }

                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                counter.fetch_sub(1, Ordering::SeqCst);

                n * 2
            }
        })
        .await;

        assert_eq!(results.len(), 10);

        // Completion order is unspecified, so compare sorted.
        let mut sorted: Vec<i32> = results;
        sorted.sort();
        assert_eq!(sorted, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);

        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit_empty_input() {
        let results: Vec<i32> =
            map_with_concurrency_limit(Vec::<i32>::new(), 4, |n| async move { n }).await;
        assert!(results.is_empty());

        let ordered: Vec<i32> =
            map_with_concurrency_limit_ordered(Vec::<i32>::new(), 4, |n| async move { n }).await;
        assert!(ordered.is_empty());
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit_ordered() {
        let items: Vec<i32> = vec![1, 2, 3, 4, 5];

        // Earlier items sleep longer, so they finish last; output must still be in input order.
        let results = map_with_concurrency_limit_ordered(items, 2, |n| async move {
            tokio::time::sleep(std::time::Duration::from_millis((5 - n) as u64 * 5)).await;
            n * 10
        })
        .await;

        assert_eq!(results, vec![10, 20, 30, 40, 50]);
    }

    #[tokio::test]
    async fn test_spawn_agents_concurrent_with_no_configs() {
        let (tx, _rx) = mpsc::channel(1);
        let config = ConcurrentSpawnConfig {
            timeout_ms: 1_000,
            ..ConcurrentSpawnConfig::default()
        };

        let result = spawn_agents_concurrent(Vec::new(), config, tx).await;

        assert!(result.all_succeeded);
        assert!(result.successes.is_empty());
        assert!(result.failures.is_empty());
    }

    #[test]
    fn test_concurrent_spawn_config_default() {
        let config = ConcurrentSpawnConfig::default();

        assert_eq!(config.max_concurrent, 5);
        assert_eq!(config.timeout_ms, 0);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_concurrent_spawn_result() {
        let result = ConcurrentSpawnResult {
            successes: vec![AgentResult {
                task_id: "1".to_string(),
                success: true,
                exit_code: Some(0),
                output: "done".to_string(),
                duration_ms: 100,
            }],
            failures: vec![],
            all_succeeded: true,
        };

        assert!(result.all_succeeded);
        assert_eq!(result.successes.len(), 1);
        assert!(result.failures.is_empty());
    }

    #[test]
    fn test_concurrent_spawn_result_with_failures() {
        let result = ConcurrentSpawnResult {
            successes: vec![],
            failures: vec![("task1".to_string(), "error msg".to_string())],
            all_succeeded: false,
        };

        assert!(!result.all_succeeded);
        assert!(result.successes.is_empty());
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "task1");
    }
}