1use crate::reasoning::circuit_breaker::CircuitBreakerRegistry;
8use crate::reasoning::inference::ToolDefinition;
9use crate::reasoning::loop_types::{LoopConfig, Observation, ProposedAction};
10use async_trait::async_trait;
11use futures::stream::{FuturesUnordered, StreamExt};
12use std::time::Duration;
13
14#[async_trait]
16pub trait ActionExecutor: Send + Sync {
17 async fn execute_actions(
22 &self,
23 actions: &[ProposedAction],
24 config: &LoopConfig,
25 circuit_breakers: &CircuitBreakerRegistry,
26 ) -> Vec<Observation>;
27
28 fn tool_definitions(&self) -> Vec<ToolDefinition> {
34 Vec::new()
35 }
36}
37
/// Executor that dispatches tool calls concurrently, bounding each call with a
/// timeout.
///
/// `Debug` and `Clone` are derived so the executor can be logged and cheaply
/// duplicated; its only state is a `Duration`.
#[derive(Debug, Clone)]
pub struct DefaultActionExecutor {
    /// Upper bound on how long a single tool call may run. At execution time
    /// the smaller of this value and the loop's own `tool_timeout` is used.
    tool_timeout: Duration,
}

impl DefaultActionExecutor {
    /// Per-call timeout used by [`Default::default`].
    const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates an executor whose tool calls are limited to `tool_timeout`.
    pub fn new(tool_timeout: Duration) -> Self {
        Self { tool_timeout }
    }
}

impl Default for DefaultActionExecutor {
    /// Builds an executor with a 30-second per-call timeout.
    fn default() -> Self {
        Self::new(Self::DEFAULT_TOOL_TIMEOUT)
    }
}
54
55#[async_trait]
56impl ActionExecutor for DefaultActionExecutor {
57 async fn execute_actions(
58 &self,
59 actions: &[ProposedAction],
60 config: &LoopConfig,
61 circuit_breakers: &CircuitBreakerRegistry,
62 ) -> Vec<Observation> {
63 let tool_calls: Vec<&ProposedAction> = actions
64 .iter()
65 .filter(|a| matches!(a, ProposedAction::ToolCall { .. }))
66 .collect();
67
68 if tool_calls.is_empty() {
69 return Vec::new();
70 }
71
72 let timeout = self.tool_timeout.min(config.tool_timeout);
73
74 let mut futures = FuturesUnordered::new();
76
77 for action in &tool_calls {
78 if let ProposedAction::ToolCall {
79 call_id,
80 name,
81 arguments,
82 } = action
83 {
84 let name = name.clone();
85 let arguments = arguments.clone();
86 let call_id = call_id.clone();
87
88 let cb_result = circuit_breakers.check(&name).await;
90
91 futures.push(async move {
92 if let Err(cb_err) = cb_result {
93 return Observation {
94 source: name,
95 content: format!(
96 "Tool circuit is open: {}. The tool endpoint has been failing and is temporarily disabled.",
97 cb_err
98 ),
99 is_error: true,
100 call_id: Some(call_id),
101 metadata: {
102 let mut m = std::collections::HashMap::new();
103 m.insert("error_type".into(), "circuit_open".into());
104 m
105 },
106 };
107 }
108
109 let result = tokio::time::timeout(timeout, async {
111 execute_tool_call(&name, &arguments).await
114 })
115 .await;
116
117 match result {
118 Ok(Ok(content)) => {
119 Observation::tool_result(&name, content).with_call_id(call_id)
120 }
121 Ok(Err(err)) => {
122 Observation::tool_error(&name, err).with_call_id(call_id)
123 }
124 Err(_) => Observation {
125 source: name.clone(),
126 content: format!(
127 "Tool '{}' timed out after {:?}",
128 name, timeout
129 ),
130 is_error: true,
131 call_id: Some(call_id),
132 metadata: {
133 let mut m = std::collections::HashMap::new();
134 m.insert("error_type".into(), "timeout".into());
135 m
136 },
137 },
138 }
139 });
140 }
141 }
142
143 let mut observations = Vec::with_capacity(tool_calls.len());
145 while let Some(obs) = futures.next().await {
146 let tool_name = obs
148 .metadata
149 .get("tool_name")
150 .cloned()
151 .unwrap_or_else(|| obs.source.clone());
152 if obs.is_error {
153 circuit_breakers.record_failure(&tool_name).await;
154 } else {
155 circuit_breakers.record_success(&tool_name).await;
156 }
157 observations.push(obs);
158 }
159
160 observations
161 }
162}
163
164async fn execute_tool_call(name: &str, arguments: &str) -> Result<String, String> {
168 tracing::debug!("Executing tool '{}' with arguments: {}", name, arguments);
169 Ok(format!(
172 "Tool '{}' executed successfully with arguments: {}",
173 name, arguments
174 ))
175}
176
177pub struct EnforcedActionExecutor {
179 enforcer: std::sync::Arc<dyn crate::integrations::tool_invocation::ToolInvocationEnforcer>,
180}
181
182impl EnforcedActionExecutor {
183 pub fn new(
184 enforcer: std::sync::Arc<dyn crate::integrations::tool_invocation::ToolInvocationEnforcer>,
185 ) -> Self {
186 Self { enforcer }
187 }
188}
189
190#[async_trait]
191impl ActionExecutor for EnforcedActionExecutor {
192 async fn execute_actions(
193 &self,
194 actions: &[ProposedAction],
195 config: &LoopConfig,
196 circuit_breakers: &CircuitBreakerRegistry,
197 ) -> Vec<Observation> {
198 let tool_calls: Vec<&ProposedAction> = actions
199 .iter()
200 .filter(|a| matches!(a, ProposedAction::ToolCall { .. }))
201 .collect();
202
203 if tool_calls.is_empty() {
204 return Vec::new();
205 }
206
207 let timeout = config.tool_timeout;
208 let mut futures = FuturesUnordered::new();
209
210 for action in &tool_calls {
211 if let ProposedAction::ToolCall {
212 call_id,
213 name,
214 arguments,
215 } = action
216 {
217 let name = name.clone();
218 let arguments = arguments.clone();
219 let call_id = call_id.clone();
220 let enforcer = self.enforcer.clone();
221
222 let cb_result = circuit_breakers.check(&name).await;
223
224 futures.push(async move {
225 if let Err(cb_err) = cb_result {
226 return Observation {
227 source: name,
228 content: format!("Tool circuit is open: {}", cb_err),
229 is_error: true,
230 call_id: Some(call_id),
231 metadata: {
232 let mut m = std::collections::HashMap::new();
233 m.insert("error_type".into(), "circuit_open".into());
234 m
235 },
236 };
237 }
238
239 let tool = crate::integrations::mcp::McpTool {
240 name: name.clone(),
241 description: String::new(),
242 schema: serde_json::json!({}),
243 provider: crate::integrations::mcp::ToolProvider {
244 identifier: "reasoning_loop".into(),
245 name: "Reasoning Loop".into(),
246 public_key_url: String::new(),
247 version: None,
248 },
249 verification_status:
250 crate::integrations::mcp::VerificationStatus::Skipped {
251 reason: "Invoked via reasoning loop".into(),
252 },
253 metadata: None,
254 sensitive_params: vec![],
255 };
256
257 let args: serde_json::Value =
258 serde_json::from_str(&arguments).unwrap_or(serde_json::json!({}));
259
260 let context = crate::integrations::tool_invocation::InvocationContext {
261 agent_id: crate::types::AgentId::new(),
262 tool_name: name.clone(),
263 arguments: args,
264 timestamp: chrono::Utc::now(),
265 metadata: std::collections::HashMap::new(),
266 agent_credential: None,
267 };
268
269 match tokio::time::timeout(
270 timeout,
271 enforcer.execute_tool_with_enforcement(&tool, context),
272 )
273 .await
274 {
275 Ok(Ok(result)) => {
276 Observation::tool_result(&name, result.result.to_string())
277 .with_call_id(call_id)
278 }
279 Ok(Err(err)) => {
280 Observation::tool_error(&name, err.to_string()).with_call_id(call_id)
281 }
282 Err(_) => Observation {
283 source: name.clone(),
284 content: format!("Tool '{}' timed out", name),
285 is_error: true,
286 call_id: Some(call_id),
287 metadata: {
288 let mut m = std::collections::HashMap::new();
289 m.insert("error_type".into(), "timeout".into());
290 m
291 },
292 },
293 }
294 });
295 }
296 }
297
298 let mut observations = Vec::with_capacity(tool_calls.len());
299 while let Some(obs) = futures.next().await {
300 let tool_name = obs
301 .metadata
302 .get("tool_name")
303 .cloned()
304 .unwrap_or_else(|| obs.source.clone());
305 if obs.is_error {
306 circuit_breakers.record_failure(&tool_name).await;
307 } else {
308 circuit_breakers.record_success(&tool_name).await;
309 }
310 observations.push(obs);
311 }
312
313 observations
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ProposedAction::ToolCall` from string slices.
    fn tool_call(call_id: &str, name: &str, arguments: &str) -> ProposedAction {
        ProposedAction::ToolCall {
            call_id: call_id.into(),
            name: name.into(),
            arguments: arguments.into(),
        }
    }

    /// Default executor, loop config and breaker registry used by most tests.
    fn default_fixture() -> (DefaultActionExecutor, LoopConfig, CircuitBreakerRegistry) {
        (
            DefaultActionExecutor::default(),
            LoopConfig::default(),
            CircuitBreakerRegistry::default(),
        )
    }

    // An empty action list produces no observations.
    #[tokio::test]
    async fn test_default_executor_no_actions() {
        let (executor, config, breakers) = default_fixture();

        let observations = executor.execute_actions(&[], &config, &breakers).await;

        assert!(observations.is_empty());
    }

    // A single tool call yields one successful observation that carries the
    // tool name and call id.
    #[tokio::test]
    async fn test_default_executor_single_tool() {
        let (executor, config, breakers) = default_fixture();
        let actions = [tool_call("c1", "search", r#"{"q": "test"}"#)];

        let observations = executor
            .execute_actions(&actions, &config, &breakers)
            .await;

        assert_eq!(observations.len(), 1);
        let first = &observations[0];
        assert!(!first.is_error);
        assert_eq!(first.source, "search");
        assert_eq!(first.call_id.as_deref(), Some("c1"));
    }

    // Several tool calls all succeed and finish within the time budget.
    #[tokio::test]
    async fn test_default_executor_parallel_dispatch() {
        let (executor, config, breakers) = default_fixture();

        let mut actions = Vec::new();
        for i in 0..3 {
            actions.push(tool_call(
                &format!("c{}", i),
                &format!("tool_{}", i),
                "{}",
            ));
        }

        let started = std::time::Instant::now();
        let observations = executor
            .execute_actions(&actions, &config, &breakers)
            .await;
        let elapsed = started.elapsed();

        assert_eq!(observations.len(), 3);
        assert!(observations.iter().all(|o| !o.is_error));
        assert!(
            elapsed.as_millis() < 100,
            "Parallel dispatch took {}ms, expected <100ms",
            elapsed.as_millis()
        );
    }

    // Actions other than tool calls produce no observations.
    #[tokio::test]
    async fn test_executor_skips_non_tool_actions() {
        let (executor, config, breakers) = default_fixture();

        let respond = ProposedAction::Respond {
            content: "done".into(),
        };
        let delegate = ProposedAction::Delegate {
            target: "other".into(),
            message: "hi".into(),
        };
        let actions = [respond, delegate];

        let observations = executor
            .execute_actions(&actions, &config, &breakers)
            .await;

        assert!(observations.is_empty());
    }

    // The default executor inherits the trait's empty tool list.
    #[test]
    fn test_default_executor_has_empty_tool_definitions() {
        let definitions = DefaultActionExecutor::default().tool_definitions();
        assert!(definitions.is_empty());
    }

    // Once a tool's breaker has tripped, the call is rejected with a
    // circuit-open error observation.
    #[tokio::test]
    async fn test_executor_circuit_breaker_integration() {
        let executor = DefaultActionExecutor::default();
        let config = LoopConfig::default();
        let breaker_config = crate::reasoning::circuit_breaker::CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breakers = CircuitBreakerRegistry::new(breaker_config);

        // Two failures reach the threshold configured above.
        for _ in 0..2 {
            breakers.record_failure("failing_tool").await;
        }

        let actions = [tool_call("c1", "failing_tool", "{}")];
        let observations = executor
            .execute_actions(&actions, &config, &breakers)
            .await;

        assert_eq!(observations.len(), 1);
        assert!(observations[0].is_error);
        assert!(observations[0].content.contains("circuit is open"));
    }
}