1use crate::extensions::types::ToolFn;
9use futures::stream::{self, StreamExt};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::path::{Path, PathBuf};
14use std::process::Stdio;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::sync::mpsc;
18
19use crate::agents::AgentDef;
20use crate::commands::spawn::terminal::{find_harness_binary, Harness};
21
/// Registry of named tool functions that can be invoked by name with JSON
/// (`serde_json::Value`) arguments.
pub struct ExtensionRunner {
    /// Tool name -> function invoked when that name is called.
    registered_tools: HashMap<String, ToolFn>,
}
30
31impl ExtensionRunner {
32 pub fn new() -> Self {
34 ExtensionRunner {
35 registered_tools: HashMap::new(),
36 }
37 }
38
39 pub fn register_tool(&mut self, name: String, tool_fn: ToolFn) {
41 self.registered_tools.insert(name, tool_fn);
42 }
43
44 pub fn execute_tool(&self, name: &str, args: &[Value]) -> Result<Value, ExtensionRunnerError> {
46 let tool_fn = self
47 .registered_tools
48 .get(name)
49 .ok_or_else(|| ExtensionRunnerError::ToolNotFound(name.to_string()))?;
50
51 tool_fn(args).map_err(ExtensionRunnerError::ExecutionError)
52 }
53
54 pub fn has_tool(&self, name: &str) -> bool {
56 self.registered_tools.contains_key(name)
57 }
58
59 pub fn list_tools(&self) -> Vec<String> {
61 self.registered_tools.keys().cloned().collect()
62 }
63
64 pub fn on_tool_call(
77 &self,
78 tool_name: &str,
79 arguments: Value,
80 ) -> Result<ToolCallResult, ExtensionRunnerError> {
81 let args = match arguments {
83 Value::Array(arr) => arr,
84 Value::Object(_) => vec![arguments],
85 Value::Null => vec![],
86 other => vec![other],
87 };
88
89 let result = self.execute_tool(tool_name, &args)?;
90
91 Ok(ToolCallResult {
92 tool_name: tool_name.to_string(),
93 output: result,
94 success: true,
95 })
96 }
97}
98
/// Outcome of a successful [`ExtensionRunner::on_tool_call`].
#[derive(Debug, Clone)]
pub struct ToolCallResult {
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Value returned by the tool function.
    pub output: Value,
    /// `on_tool_call` always sets this to `true`; failures are returned as `Err`
    /// instead of a result with `success: false`.
    pub success: bool,
}
109
110impl Default for ExtensionRunner {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
/// Errors produced by [`ExtensionRunner`] when dispatching a tool call.
#[derive(Debug, thiserror::Error)]
pub enum ExtensionRunnerError {
    /// No tool is registered under the given name (the payload is that name).
    #[error("Tool not found: {0}")]
    ToolNotFound(String),

    /// The tool ran but returned an error, which is wrapped here unchanged.
    #[error("Tool execution error: {0}")]
    ExecutionError(Box<dyn std::error::Error + Send + Sync>),
}
125
/// Final outcome of one agent process started by [`spawn_agent`].
///
/// Derives `PartialEq`/`Eq` so results can be compared directly, for example in
/// tests or when de-duplicating.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentResult {
    /// Task id taken from the [`SpawnConfig`] that started the agent.
    pub task_id: String,
    /// `true` if the process exited with a success status. `false` if it did not,
    /// or if waiting on it failed.
    pub success: bool,
    /// Process exit code. `None` when waiting failed or the platform reported no
    /// code (e.g. the process was terminated by a signal on Unix).
    pub exit_code: Option<i32>,
    /// Captured output: all stdout lines, then all stderr lines prefixed with
    /// `[stderr] `. Every line is newline-terminated.
    pub output: String,
    /// Wall-clock time in milliseconds from just before spawning until the
    /// process was reaped.
    pub duration_ms: u64,
}

/// Progress notifications emitted while agents run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentEvent {
    /// The agent process was spawned successfully.
    Started { task_id: String },
    /// One line of the agent's stdout, without its trailing newline.
    Output { task_id: String, line: String },
    /// The agent process exited; carries the same result its join handle yields.
    Completed { result: AgentResult },
    /// The agent could not be started. The spawn helpers in this module do not
    /// send this; they return the error instead. Presumably callers emit it
    /// themselves.
    SpawnFailed { task_id: String, error: String },
}
157
/// Everything needed to start one agent with [`spawn_agent`].
#[derive(Debug, Clone)]
pub struct SpawnConfig {
    /// Identifier for the task. It is reported in events and results and exported
    /// to the child process as `SCUD_TASK_ID`.
    pub task_id: String,
    /// Prompt passed to the harness CLI as a command-line argument.
    pub prompt: String,
    /// Current working directory of the spawned process.
    pub working_dir: PathBuf,
    /// Which harness binary to run; this also determines the argument layout.
    pub harness: Harness,
    /// Optional model name, forwarded as `--model <model>` when set.
    pub model: Option<String>,
}
172
173pub async fn spawn_agent(
177 config: SpawnConfig,
178 event_tx: mpsc::Sender<AgentEvent>,
179) -> Result<tokio::task::JoinHandle<AgentResult>, anyhow::Error> {
180 let binary_path = find_harness_binary(config.harness)?;
181 let task_id = config.task_id.clone();
182
183 let mut cmd = match config.harness {
185 Harness::Claude => {
186 let mut c = Command::new(binary_path);
187 c.arg(&config.prompt);
188 c.arg("--dangerously-skip-permissions");
189 if let Some(ref model) = config.model {
190 c.arg("--model").arg(model);
191 }
192 c
193 }
194 Harness::OpenCode => {
195 let mut c = Command::new(binary_path);
196 c.arg("run");
197 c.arg("--variant").arg("minimal");
198 if let Some(ref model) = config.model {
199 c.arg("--model").arg(model);
200 }
201 c.arg(&config.prompt);
202 c
203 }
204 Harness::Cursor => {
205 let mut c = Command::new(binary_path);
206 c.arg("-p");
207 if let Some(ref model) = config.model {
208 c.arg("--model").arg(model);
209 }
210 c.arg(&config.prompt);
211 c
212 }
213 Harness::Rho => {
214 let mut c = Command::new(binary_path);
215 if let Some(ref model) = config.model {
216 c.arg("--model").arg(model);
217 }
218 c.arg(&config.prompt);
219 c
220 }
221 #[cfg(feature = "direct-api")]
222 Harness::DirectApi => {
223 let mut c = Command::new(binary_path);
224 c.arg("agent-exec");
225 c.arg("--prompt").arg(&config.prompt);
226 if let Some(ref model) = config.model {
227 c.arg("--model").arg(model);
228 }
229 c
230 }
231 };
232
233 cmd.current_dir(&config.working_dir);
235 cmd.env("SCUD_TASK_ID", &config.task_id);
236 cmd.stdout(Stdio::piped());
237 cmd.stderr(Stdio::piped());
238
239 let start_time = std::time::Instant::now();
240
241 let mut child = cmd.spawn().map_err(|e| {
243 anyhow::anyhow!(
244 "Failed to spawn {} for task {}: {}",
245 config.harness.name(),
246 config.task_id,
247 e
248 )
249 })?;
250
251 let _ = event_tx
253 .send(AgentEvent::Started {
254 task_id: task_id.clone(),
255 })
256 .await;
257
258 let stdout = child.stdout.take();
260 let stderr = child.stderr.take();
261 let event_tx_clone = event_tx.clone();
262 let task_id_clone = task_id.clone();
263
264 let handle = tokio::spawn(async move {
265 let mut output_buffer = String::new();
266
267 if let Some(stdout) = stdout {
269 let reader = BufReader::new(stdout);
270 let mut lines = reader.lines();
271 while let Ok(Some(line)) = lines.next_line().await {
272 output_buffer.push_str(&line);
273 output_buffer.push('\n');
274 let _ = event_tx_clone
275 .send(AgentEvent::Output {
276 task_id: task_id_clone.clone(),
277 line: line.clone(),
278 })
279 .await;
280 }
281 }
282
283 if let Some(stderr) = stderr {
285 let reader = BufReader::new(stderr);
286 let mut lines = reader.lines();
287 while let Ok(Some(line)) = lines.next_line().await {
288 output_buffer.push_str("[stderr] ");
289 output_buffer.push_str(&line);
290 output_buffer.push('\n');
291 }
292 }
293
294 let status = child.wait().await;
296 let duration_ms = start_time.elapsed().as_millis() as u64;
297
298 let (success, exit_code) = match status {
299 Ok(s) => (s.success(), s.code()),
300 Err(_) => (false, None),
301 };
302
303 let result = AgentResult {
304 task_id: task_id_clone.clone(),
305 success,
306 exit_code,
307 output: output_buffer,
308 duration_ms,
309 };
310
311 let _ = event_tx_clone
312 .send(AgentEvent::Completed {
313 result: result.clone(),
314 })
315 .await;
316
317 result
318 });
319
320 Ok(handle)
321}
322
323pub fn load_agent_config(
325 agent_type: Option<&str>,
326 default_harness: Harness,
327 default_model: Option<&str>,
328 working_dir: &Path,
329) -> (Harness, Option<String>) {
330 if let Some(agent_name) = agent_type {
331 if let Some(agent_def) = AgentDef::try_load(agent_name, working_dir) {
332 let harness = agent_def.harness().unwrap_or(default_harness);
333 let model = agent_def
334 .model()
335 .map(String::from)
336 .or_else(|| default_model.map(String::from));
337 return (harness, model);
338 }
339 }
340
341 (default_harness, default_model.map(String::from))
343}
344
/// Spawns agents that all report into one shared event channel, and keeps their
/// join handles so they can be awaited together.
pub struct AgentRunner {
    /// Sending half handed (cloned) to every agent spawned by this runner.
    event_tx: mpsc::Sender<AgentEvent>,
    /// Receiving half, read through `recv_event` / `try_recv_event`.
    event_rx: mpsc::Receiver<AgentEvent>,
    /// Join handles of agents spawned and not yet collected by `wait_all`.
    handles: Vec<tokio::task::JoinHandle<AgentResult>>,
}
354
355impl AgentRunner {
356 pub fn new(capacity: usize) -> Self {
358 let (event_tx, event_rx) = mpsc::channel(capacity);
359 Self {
360 event_tx,
361 event_rx,
362 handles: Vec::new(),
363 }
364 }
365
366 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
368 self.event_tx.clone()
369 }
370
371 pub async fn spawn(&mut self, config: SpawnConfig) -> anyhow::Result<()> {
373 let handle = spawn_agent(config, self.event_tx.clone()).await?;
374 self.handles.push(handle);
375 Ok(())
376 }
377
378 pub async fn recv_event(&mut self) -> Option<AgentEvent> {
380 self.event_rx.recv().await
381 }
382
383 pub fn try_recv_event(&mut self) -> Option<AgentEvent> {
385 self.event_rx.try_recv().ok()
386 }
387
388 pub async fn wait_all(&mut self) -> Vec<AgentResult> {
390 let handles = std::mem::take(&mut self.handles);
391 let mut results = Vec::new();
392
393 for handle in handles {
394 if let Ok(result) = handle.await {
395 results.push(result);
396 }
397 }
398
399 results
400 }
401
402 pub fn active_count(&self) -> usize {
404 self.handles.iter().filter(|h| !h.is_finished()).count()
405 }
406}
407
408pub async fn map_with_concurrency_limit<T, F, Fut, R>(
440 items: impl IntoIterator<Item = T>,
441 concurrency: usize,
442 f: F,
443) -> Vec<R>
444where
445 F: Fn(T) -> Fut,
446 Fut: Future<Output = R>,
447{
448 stream::iter(items)
449 .map(f)
450 .buffer_unordered(concurrency)
451 .collect()
452 .await
453}
454
455pub async fn map_with_concurrency_limit_ordered<T, F, Fut, R>(
468 items: impl IntoIterator<Item = T>,
469 concurrency: usize,
470 f: F,
471) -> Vec<R>
472where
473 F: Fn(T) -> Fut,
474 Fut: Future<Output = R>,
475{
476 stream::iter(items)
477 .map(f)
478 .buffered(concurrency)
479 .collect()
480 .await
481}
482
483pub async fn spawn_agents_with_limit(
496 configs: impl IntoIterator<Item = SpawnConfig>,
497 concurrency: usize,
498 event_tx: mpsc::Sender<AgentEvent>,
499) -> Vec<Result<AgentResult, anyhow::Error>> {
500 let configs: Vec<_> = configs.into_iter().collect();
501
502 map_with_concurrency_limit(configs, concurrency, |config| {
503 let tx = event_tx.clone();
504 async move {
505 match spawn_agent(config, tx).await {
506 Ok(handle) => handle
507 .await
508 .map_err(|e| anyhow::anyhow!("Join error: {}", e)),
509 Err(e) => Err(e),
510 }
511 }
512 })
513 .await
514}
515
/// Tuning knobs for [`spawn_agents_concurrent`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConcurrentSpawnConfig {
    /// Maximum number of agents in flight at once (the concurrency limit passed
    /// to `map_with_concurrency_limit`).
    pub max_concurrent: usize,
    /// Per-agent timeout in milliseconds; `0` disables the timeout.
    pub timeout_ms: u64,
    /// NOTE: `spawn_agents_concurrent` does not currently consult this flag;
    /// every config is run regardless of earlier failures.
    pub fail_fast: bool,
}

impl Default for ConcurrentSpawnConfig {
    /// Five agents at a time, no timeout, no fail-fast.
    fn default() -> Self {
        Self {
            max_concurrent: 5,
            timeout_ms: 0,
            fail_fast: false,
        }
    }
}
536
/// Aggregated outcome of [`spawn_agents_concurrent`].
#[derive(Debug)]
pub struct ConcurrentSpawnResult {
    /// Agents that ran to completion, in completion order. An agent lands here
    /// even if it exited with a failing status; check [`AgentResult::success`].
    pub successes: Vec<AgentResult>,
    /// `(task_id, error message)` for agents that failed to spawn, whose
    /// monitoring task failed to join, or that hit the timeout.
    pub failures: Vec<(String, String)>,
    /// `true` exactly when `failures` is empty.
    pub all_succeeded: bool,
}
547
548pub async fn spawn_agents_concurrent(
561 configs: Vec<SpawnConfig>,
562 spawn_config: ConcurrentSpawnConfig,
563 event_tx: mpsc::Sender<AgentEvent>,
564) -> ConcurrentSpawnResult {
565 let mut successes = Vec::new();
566 let mut failures = Vec::new();
567
568 let results = if spawn_config.timeout_ms > 0 {
569 let timeout_duration = std::time::Duration::from_millis(spawn_config.timeout_ms);
571
572 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
573 let tx = event_tx.clone();
574 let task_id = config.task_id.clone();
575 async move {
576 let result = tokio::time::timeout(timeout_duration, async {
577 match spawn_agent(config, tx).await {
578 Ok(handle) => handle
579 .await
580 .map_err(|e| anyhow::anyhow!("Join error: {}", e)),
581 Err(e) => Err(e),
582 }
583 })
584 .await;
585
586 match result {
587 Ok(Ok(agent_result)) => Ok(agent_result),
588 Ok(Err(e)) => Err((task_id, e.to_string())),
589 Err(_) => Err((task_id, "Timeout".to_string())),
590 }
591 }
592 })
593 .await
594 } else {
595 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
597 let tx = event_tx.clone();
598 let task_id = config.task_id.clone();
599 async move {
600 match spawn_agent(config, tx).await {
601 Ok(handle) => handle
602 .await
603 .map_err(|e| (task_id, format!("Join error: {}", e))),
604 Err(e) => Err((task_id, e.to_string())),
605 }
606 }
607 })
608 .await
609 };
610
611 for result in results {
612 match result {
613 Ok(agent_result) => successes.push(agent_result),
614 Err((task_id, error)) => failures.push((task_id, error)),
615 }
616 }
617
618 let all_succeeded = failures.is_empty();
619
620 ConcurrentSpawnResult {
621 successes,
622 failures,
623 all_succeeded,
624 }
625}
626
627pub async fn spawn_subagent(
643 task_id: String,
644 prompt: String,
645 working_dir: PathBuf,
646 harness: Harness,
647 model: Option<String>,
648) -> Result<AgentResult, anyhow::Error> {
649 let (tx, _rx) = mpsc::channel(10);
651
652 let config = SpawnConfig {
653 task_id,
654 prompt,
655 working_dir,
656 harness,
657 model,
658 };
659
660 let handle = spawn_agent(config, tx).await?;
661 handle
662 .await
663 .map_err(|e| anyhow::anyhow!("Subagent join error: {}", e))
664}
665
// Unit tests for the extension runner, the agent data types, and the
// concurrency helpers. No test here launches a real harness process.
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed runner has no tools registered.
    #[test]
    fn test_extension_runner_new() {
        let runner = ExtensionRunner::new();
        assert!(runner.list_tools().is_empty());
    }

    // Despite the name, this checks stored field values, not the Debug output.
    #[test]
    fn test_agent_result_debug() {
        let result = AgentResult {
            task_id: "test:1".to_string(),
            success: true,
            exit_code: Some(0),
            output: "test output".to_string(),
            duration_ms: 1000,
        };

        assert!(result.success);
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(result.task_id, "test:1");
    }

    // Despite the name, this checks stored field values, not the Debug output.
    #[test]
    fn test_spawn_config_debug() {
        let config = SpawnConfig {
            task_id: "test:1".to_string(),
            prompt: "do something".to_string(),
            working_dir: PathBuf::from("/tmp"),
            harness: Harness::Claude,
            model: Some("opus".to_string()),
        };

        assert_eq!(config.task_id, "test:1");
        assert_eq!(config.harness, Harness::Claude);
    }

    // A new runner tracks no agents. This needs a tokio runtime only because the
    // runner owns tokio channel halves.
    #[tokio::test]
    async fn test_agent_runner_new() {
        let runner = AgentRunner::new(100);
        assert_eq!(runner.active_count(), 0);
    }

    // ToolCallResult stores its fields as given; `output` is indexable JSON.
    #[test]
    fn test_tool_call_result() {
        let result = ToolCallResult {
            tool_name: "my_tool".to_string(),
            output: serde_json::json!({"key": "value"}),
            success: true,
        };

        assert_eq!(result.tool_name, "my_tool");
        assert!(result.success);
        assert_eq!(result.output["key"], "value");
    }

    // Calling an unregistered tool yields ToolNotFound carrying the requested name.
    #[test]
    fn test_on_tool_call_not_found() {
        let runner = ExtensionRunner::new();
        let result = runner.on_tool_call("nonexistent", serde_json::json!({}));

        assert!(result.is_err());
        match result {
            Err(ExtensionRunnerError::ToolNotFound(name)) => {
                assert_eq!(name, "nonexistent");
            }
            _ => panic!("Expected ToolNotFound error"),
        }
    }

    // A JSON object argument is passed through as the tool's single argument.
    #[test]
    fn test_on_tool_call_with_registered_tool() {
        let mut runner = ExtensionRunner::new();

        // Returns its first argument unchanged (or null when called with none).
        fn echo_tool(args: &[Value]) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(args.first().cloned().unwrap_or(Value::Null))
        }

        runner.register_tool("echo".to_string(), echo_tool);

        let result = runner
            .on_tool_call("echo", serde_json::json!({"test": 123}))
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert!(result.success);
        assert_eq!(result.output["test"], 123);
    }

    // Argument conversion: array -> its elements, object -> one arg,
    // null -> no args, scalar -> one arg.
    #[test]
    fn test_on_tool_call_argument_conversion() {
        let mut runner = ExtensionRunner::new();

        // Reports how many positional arguments it received.
        fn count_args(args: &[Value]) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(serde_json::json!(args.len()))
        }

        runner.register_tool("count".to_string(), count_args);

        let result = runner
            .on_tool_call("count", serde_json::json!([1, 2, 3]))
            .unwrap();
        assert_eq!(result.output, 3);

        let result = runner
            .on_tool_call("count", serde_json::json!({"a": 1}))
            .unwrap();
        assert_eq!(result.output, 1);

        let result = runner.on_tool_call("count", Value::Null).unwrap();
        assert_eq!(result.output, 0);

        let result = runner.on_tool_call("count", serde_json::json!(42)).unwrap();
        assert_eq!(result.output, 1);
    }

    // Every item is mapped, and the number of futures in flight never exceeds
    // the limit.
    #[tokio::test]
    async fn test_map_with_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let items: Vec<i32> = (0..10).collect();
        // `counter`: futures currently between start and finish.
        // `max_concurrent`: highest value of `counter` observed.
        let counter = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let results = map_with_concurrency_limit(items, 3, |n| {
            let counter = Arc::clone(&counter);
            let max_concurrent = Arc::clone(&max_concurrent);
            async move {
                let current = counter.fetch_add(1, Ordering::SeqCst) + 1;

                // Raise the recorded peak to `current` with a CAS loop
                // (equivalent to an atomic fetch_max).
                let mut max = max_concurrent.load(Ordering::SeqCst);
                while current > max {
                    match max_concurrent.compare_exchange_weak(
                        max,
                        current,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    ) {
                        Ok(_) => break,
                        Err(new_max) => max = new_max,
                    }
                }

                // Yield across an await so that several futures overlap.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                counter.fetch_sub(1, Ordering::SeqCst);

                n * 2
            }
        })
        .await;

        assert_eq!(results.len(), 10);

        // Output arrives in completion order, so sort before comparing.
        let mut sorted: Vec<i32> = results;
        sorted.sort();
        assert_eq!(sorted, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);

        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    // Earlier items sleep longer and so finish later, but the ordered variant
    // must still return outputs in input order.
    #[tokio::test]
    async fn test_map_with_concurrency_limit_ordered() {
        let items: Vec<i32> = vec![1, 2, 3, 4, 5];

        let results = map_with_concurrency_limit_ordered(items, 2, |n| async move {
            tokio::time::sleep(std::time::Duration::from_millis((5 - n) as u64 * 5)).await;
            n * 10
        })
        .await;

        assert_eq!(results, vec![10, 20, 30, 40, 50]);
    }

    // Pins the documented defaults of ConcurrentSpawnConfig.
    #[test]
    fn test_concurrent_spawn_config_default() {
        let config = ConcurrentSpawnConfig::default();

        assert_eq!(config.max_concurrent, 5);
        assert_eq!(config.timeout_ms, 0);
        assert!(!config.fail_fast);
    }

    // A hand-built all-success result exposes its fields as given.
    #[test]
    fn test_concurrent_spawn_result() {
        let result = ConcurrentSpawnResult {
            successes: vec![AgentResult {
                task_id: "1".to_string(),
                success: true,
                exit_code: Some(0),
                output: "done".to_string(),
                duration_ms: 100,
            }],
            failures: vec![],
            all_succeeded: true,
        };

        assert!(result.all_succeeded);
        assert_eq!(result.successes.len(), 1);
        assert!(result.failures.is_empty());
    }

    // A hand-built result with one failure keeps the (task_id, message) pair.
    #[test]
    fn test_concurrent_spawn_result_with_failures() {
        let result = ConcurrentSpawnResult {
            successes: vec![],
            failures: vec![("task1".to_string(), "error msg".to_string())],
            all_succeeded: false,
        };

        assert!(!result.all_succeeded);
        assert!(result.successes.is_empty());
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "task1");
    }
}