soul_core/executor/
direct.rs1use std::sync::Arc;
4
5#[cfg(test)]
6use async_trait::async_trait;
7use tokio::sync::mpsc;
8
9use crate::error::{SoulError, SoulResult};
10use crate::tool::{ToolOutput, ToolRegistry};
11use crate::types::ToolDefinition;
12
13use super::ToolExecutor;
14
15pub struct DirectExecutor {
20 tools: Arc<ToolRegistry>,
21}
22
23impl DirectExecutor {
24 pub fn new(tools: Arc<ToolRegistry>) -> Self {
25 Self { tools }
26 }
27}
28
29#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
30#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
31impl ToolExecutor for DirectExecutor {
32 async fn execute(
33 &self,
34 definition: &ToolDefinition,
35 call_id: &str,
36 arguments: serde_json::Value,
37 partial_tx: Option<mpsc::UnboundedSender<String>>,
38 ) -> SoulResult<ToolOutput> {
39 let tool = self
40 .tools
41 .get(&definition.name)
42 .ok_or_else(|| SoulError::ToolExecution {
43 tool_name: definition.name.clone(),
44 message: format!("Unknown tool: {}", definition.name),
45 })?;
46
47 tool.execute(call_id, arguments, partial_tx).await
48 }
49
50 fn executor_name(&self) -> &str {
51 "direct"
52 }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::Tool;
    use serde_json::json;

    /// Minimal tool named "mock" that always succeeds with a fixed result.
    struct MockTool;

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock"
        }

        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "mock".into(),
                description: "Mock tool".into(),
                input_schema: json!({"type": "object"}),
            }
        }

        async fn execute(
            &self,
            _call_id: &str,
            _arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            Ok(ToolOutput::success("mock result"))
        }
    }

    #[tokio::test]
    async fn delegates_to_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(MockTool));
        let executor = DirectExecutor::new(Arc::new(registry));

        let def = ToolDefinition {
            name: "mock".into(),
            description: "".into(),
            input_schema: json!({}),
        };

        let result = executor.execute(&def, "c1", json!({}), None).await.unwrap();
        assert_eq!(result.content, "mock result");
    }

    #[tokio::test]
    async fn unknown_tool_errors() {
        let registry = ToolRegistry::new();
        let executor = DirectExecutor::new(Arc::new(registry));

        let def = ToolDefinition {
            name: "nonexistent".into(),
            description: "".into(),
            input_schema: json!({}),
        };

        let result = executor.execute(&def, "c1", json!({}), None).await;

        // Pin the specific failure mode: a bare `is_err()` would also pass for
        // unrelated errors. `matches!` avoids requiring `ToolOutput: Debug`.
        assert!(
            matches!(
                &result,
                Err(SoulError::ToolExecution { tool_name, message })
                    if tool_name == "nonexistent" && message.contains("nonexistent")
            ),
            "expected SoulError::ToolExecution naming the unknown tool"
        );
    }

    #[test]
    fn executor_name_is_direct() {
        let registry = ToolRegistry::new();
        let executor = DirectExecutor::new(Arc::new(registry));
        assert_eq!(executor.executor_name(), "direct");
    }

    #[test]
    fn is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<DirectExecutor>();
    }
}