Skip to main content

symbi_runtime/reasoning/
executor.rs

1//! Action executor with parallel dispatch
2//!
3//! Executes approved actions concurrently using `FuturesUnordered`,
4//! with per-tool timeouts, circuit breaker integration, and barrier
5//! sync before returning observations.
6
7use crate::reasoning::circuit_breaker::CircuitBreakerRegistry;
8use crate::reasoning::inference::ToolDefinition;
9use crate::reasoning::loop_types::{LoopConfig, Observation, ProposedAction};
10use async_trait::async_trait;
11use futures::stream::{FuturesUnordered, StreamExt};
12use std::time::Duration;
13
14/// Trait for executing proposed actions and producing observations.
15#[async_trait]
16pub trait ActionExecutor: Send + Sync {
17    /// Execute a batch of approved actions, potentially in parallel.
18    ///
19    /// Returns observations from all action results. Circuit breakers
20    /// are checked before each dispatch.
21    async fn execute_actions(
22        &self,
23        actions: &[ProposedAction],
24        config: &LoopConfig,
25        circuit_breakers: &CircuitBreakerRegistry,
26    ) -> Vec<Observation>;
27
28    /// Return tool definitions this executor can handle.
29    ///
30    /// The runner auto-populates `LoopConfig.tool_definitions` from this
31    /// when the config's list is empty. Override in executors that discover
32    /// tools dynamically (e.g., `ComposioToolExecutor`).
33    fn tool_definitions(&self) -> Vec<ToolDefinition> {
34        Vec::new()
35    }
36}
37
/// Default executor that dispatches tool calls in parallel.
///
/// Tool calls are served by the placeholder `execute_tool_call` helper, which
/// echoes its input. This makes the executor suitable for tests and wiring
/// rather than real tool use. Each call is bounded by `tool_timeout`, or by the
/// loop's own `tool_timeout` if that is shorter.
#[derive(Debug, Clone)]
pub struct DefaultActionExecutor {
    /// Upper bound on how long a single tool call may run.
    tool_timeout: Duration,
}

impl DefaultActionExecutor {
    /// Timeout used by [`Default::default`]: 30 seconds.
    pub const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(30);

    /// Create an executor whose tool calls time out after `tool_timeout`.
    #[must_use]
    pub fn new(tool_timeout: Duration) -> Self {
        Self { tool_timeout }
    }
}

impl Default for DefaultActionExecutor {
    /// Executor using [`DefaultActionExecutor::DEFAULT_TOOL_TIMEOUT`].
    fn default() -> Self {
        Self::new(Self::DEFAULT_TOOL_TIMEOUT)
    }
}
54
55#[async_trait]
56impl ActionExecutor for DefaultActionExecutor {
57    async fn execute_actions(
58        &self,
59        actions: &[ProposedAction],
60        config: &LoopConfig,
61        circuit_breakers: &CircuitBreakerRegistry,
62    ) -> Vec<Observation> {
63        let tool_calls: Vec<&ProposedAction> = actions
64            .iter()
65            .filter(|a| matches!(a, ProposedAction::ToolCall { .. }))
66            .collect();
67
68        if tool_calls.is_empty() {
69            return Vec::new();
70        }
71
72        let timeout = self.tool_timeout.min(config.tool_timeout);
73
74        // Dispatch tool calls concurrently
75        let mut futures = FuturesUnordered::new();
76
77        for action in &tool_calls {
78            if let ProposedAction::ToolCall {
79                call_id,
80                name,
81                arguments,
82            } = action
83            {
84                let name = name.clone();
85                let arguments = arguments.clone();
86                let call_id = call_id.clone();
87
88                // Check circuit breaker first
89                let cb_result = circuit_breakers.check(&name).await;
90
91                futures.push(async move {
92                    if let Err(cb_err) = cb_result {
93                        return Observation {
94                            source: name,
95                            content: format!(
96                                "Tool circuit is open: {}. The tool endpoint has been failing and is temporarily disabled.",
97                                cb_err
98                            ),
99                            is_error: true,
100                            call_id: Some(call_id),
101                            metadata: {
102                                let mut m = std::collections::HashMap::new();
103                                m.insert("error_type".into(), "circuit_open".into());
104                                m
105                            },
106                        };
107                    }
108
109                    // Execute the tool call with timeout
110                    let result = tokio::time::timeout(timeout, async {
111                        // In production, this would call the ToolInvocationEnforcer.
112                        // For now, produce an observation indicating the tool was called.
113                        execute_tool_call(&name, &arguments).await
114                    })
115                    .await;
116
117                    match result {
118                        Ok(Ok(content)) => {
119                            Observation::tool_result(&name, content).with_call_id(call_id)
120                        }
121                        Ok(Err(err)) => {
122                            Observation::tool_error(&name, err).with_call_id(call_id)
123                        }
124                        Err(_) => Observation {
125                            source: name.clone(),
126                            content: format!(
127                                "Tool '{}' timed out after {:?}",
128                                name, timeout
129                            ),
130                            is_error: true,
131                            call_id: Some(call_id),
132                            metadata: {
133                                let mut m = std::collections::HashMap::new();
134                                m.insert("error_type".into(), "timeout".into());
135                                m
136                            },
137                        },
138                    }
139                });
140            }
141        }
142
143        // Barrier sync: wait for all tool calls to complete
144        let mut observations = Vec::with_capacity(tool_calls.len());
145        while let Some(obs) = futures.next().await {
146            // Record success/failure in circuit breaker
147            let tool_name = obs
148                .metadata
149                .get("tool_name")
150                .cloned()
151                .unwrap_or_else(|| obs.source.clone());
152            if obs.is_error {
153                circuit_breakers.record_failure(&tool_name).await;
154            } else {
155                circuit_breakers.record_success(&tool_name).await;
156            }
157            observations.push(obs);
158        }
159
160        observations
161    }
162}
163
164/// Execute a single tool call. In production, this delegates to the
165/// ToolInvocationEnforcer → MCP client pipeline. This default implementation
166/// returns the arguments as the "result" for testing purposes.
167async fn execute_tool_call(name: &str, arguments: &str) -> Result<String, String> {
168    tracing::debug!("Executing tool '{}' with arguments: {}", name, arguments);
169    // Production implementation would call through ToolInvocationEnforcer here.
170    // For the reasoning loop infrastructure, we return a placeholder.
171    Ok(format!(
172        "Tool '{}' executed successfully with arguments: {}",
173        name, arguments
174    ))
175}
176
177/// An executor that delegates to a real ToolInvocationEnforcer.
178pub struct EnforcedActionExecutor {
179    enforcer: std::sync::Arc<dyn crate::integrations::tool_invocation::ToolInvocationEnforcer>,
180}
181
182impl EnforcedActionExecutor {
183    pub fn new(
184        enforcer: std::sync::Arc<dyn crate::integrations::tool_invocation::ToolInvocationEnforcer>,
185    ) -> Self {
186        Self { enforcer }
187    }
188}
189
190#[async_trait]
191impl ActionExecutor for EnforcedActionExecutor {
192    async fn execute_actions(
193        &self,
194        actions: &[ProposedAction],
195        config: &LoopConfig,
196        circuit_breakers: &CircuitBreakerRegistry,
197    ) -> Vec<Observation> {
198        let tool_calls: Vec<&ProposedAction> = actions
199            .iter()
200            .filter(|a| matches!(a, ProposedAction::ToolCall { .. }))
201            .collect();
202
203        if tool_calls.is_empty() {
204            return Vec::new();
205        }
206
207        let timeout = config.tool_timeout;
208        let mut futures = FuturesUnordered::new();
209
210        for action in &tool_calls {
211            if let ProposedAction::ToolCall {
212                call_id,
213                name,
214                arguments,
215            } = action
216            {
217                let name = name.clone();
218                let arguments = arguments.clone();
219                let call_id = call_id.clone();
220                let enforcer = self.enforcer.clone();
221
222                let cb_result = circuit_breakers.check(&name).await;
223
224                futures.push(async move {
225                    if let Err(cb_err) = cb_result {
226                        return Observation {
227                            source: name,
228                            content: format!("Tool circuit is open: {}", cb_err),
229                            is_error: true,
230                            call_id: Some(call_id),
231                            metadata: {
232                                let mut m = std::collections::HashMap::new();
233                                m.insert("error_type".into(), "circuit_open".into());
234                                m
235                            },
236                        };
237                    }
238
239                    let tool = crate::integrations::mcp::McpTool {
240                        name: name.clone(),
241                        description: String::new(),
242                        schema: serde_json::json!({}),
243                        provider: crate::integrations::mcp::ToolProvider {
244                            identifier: "reasoning_loop".into(),
245                            name: "Reasoning Loop".into(),
246                            public_key_url: String::new(),
247                            version: None,
248                        },
249                        verification_status:
250                            crate::integrations::mcp::VerificationStatus::Skipped {
251                                reason: "Invoked via reasoning loop".into(),
252                            },
253                        metadata: None,
254                        sensitive_params: vec![],
255                    };
256
257                    let args: serde_json::Value =
258                        serde_json::from_str(&arguments).unwrap_or(serde_json::json!({}));
259
260                    let context = crate::integrations::tool_invocation::InvocationContext {
261                        agent_id: crate::types::AgentId::new(),
262                        tool_name: name.clone(),
263                        arguments: args,
264                        timestamp: chrono::Utc::now(),
265                        metadata: std::collections::HashMap::new(),
266                        agent_credential: None,
267                    };
268
269                    match tokio::time::timeout(
270                        timeout,
271                        enforcer.execute_tool_with_enforcement(&tool, context),
272                    )
273                    .await
274                    {
275                        Ok(Ok(result)) => {
276                            Observation::tool_result(&name, result.result.to_string())
277                                .with_call_id(call_id)
278                        }
279                        Ok(Err(err)) => {
280                            Observation::tool_error(&name, err.to_string()).with_call_id(call_id)
281                        }
282                        Err(_) => Observation {
283                            source: name.clone(),
284                            content: format!("Tool '{}' timed out", name),
285                            is_error: true,
286                            call_id: Some(call_id),
287                            metadata: {
288                                let mut m = std::collections::HashMap::new();
289                                m.insert("error_type".into(), "timeout".into());
290                                m
291                            },
292                        },
293                    }
294                });
295            }
296        }
297
298        let mut observations = Vec::with_capacity(tool_calls.len());
299        while let Some(obs) = futures.next().await {
300            let tool_name = obs
301                .metadata
302                .get("tool_name")
303                .cloned()
304                .unwrap_or_else(|| obs.source.clone());
305            if obs.is_error {
306                circuit_breakers.record_failure(&tool_name).await;
307            } else {
308                circuit_breakers.record_success(&tool_name).await;
309            }
310            observations.push(obs);
311        }
312
313        observations
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Default executor, loop configuration and breaker registry.
    fn fixture() -> (DefaultActionExecutor, LoopConfig, CircuitBreakerRegistry) {
        (
            DefaultActionExecutor::default(),
            LoopConfig::default(),
            CircuitBreakerRegistry::default(),
        )
    }

    /// Build a `ProposedAction::ToolCall` from string slices.
    fn tool_call(call_id: &str, name: &str, arguments: &str) -> ProposedAction {
        ProposedAction::ToolCall {
            call_id: call_id.into(),
            name: name.into(),
            arguments: arguments.into(),
        }
    }

    #[tokio::test]
    async fn test_default_executor_no_actions() {
        let (executor, config, breakers) = fixture();

        let observations = executor.execute_actions(&[], &config, &breakers).await;

        assert!(observations.is_empty());
    }

    #[tokio::test]
    async fn test_default_executor_single_tool() {
        let (executor, config, breakers) = fixture();
        let actions = [tool_call("c1", "search", r#"{"q": "test"}"#)];

        let observations = executor
            .execute_actions(&actions, &config, &breakers)
            .await;

        assert_eq!(observations.len(), 1);
        let only = &observations[0];
        assert!(!only.is_error);
        assert_eq!(only.source, "search");
        assert_eq!(only.call_id.as_deref(), Some("c1"));
    }

    #[tokio::test]
    async fn test_default_executor_parallel_dispatch() {
        let (executor, config, breakers) = fixture();
        let actions: Vec<ProposedAction> = (0..3)
            .map(|i| tool_call(&format!("c{}", i), &format!("tool_{}", i), "{}"))
            .collect();

        let started = std::time::Instant::now();
        let observations = executor
            .execute_actions(&actions, &config, &breakers)
            .await;
        let elapsed_ms = started.elapsed().as_millis();

        assert_eq!(observations.len(), 3);
        // No call should fail.
        assert!(!observations.iter().any(|o| o.is_error));
        // Calls run concurrently, so wall-clock time tracks the slowest call
        // rather than the sum. The placeholder tool returns immediately, so the
        // whole batch should finish well inside 100ms.
        assert!(
            elapsed_ms < 100,
            "Parallel dispatch took {}ms, expected <100ms",
            elapsed_ms
        );
    }

    #[tokio::test]
    async fn test_executor_skips_non_tool_actions() {
        let (executor, config, breakers) = fixture();
        let actions = [
            ProposedAction::Respond {
                content: "done".into(),
            },
            ProposedAction::Delegate {
                target: "other".into(),
                message: "hi".into(),
            },
        ];

        let observations = executor
            .execute_actions(&actions, &config, &breakers)
            .await;

        assert!(observations.is_empty());
    }

    #[test]
    fn test_default_executor_has_empty_tool_definitions() {
        assert!(DefaultActionExecutor::default()
            .tool_definitions()
            .is_empty());
    }

    #[tokio::test]
    async fn test_executor_circuit_breaker_integration() {
        let executor = DefaultActionExecutor::default();
        let config = LoopConfig::default();
        let breakers =
            CircuitBreakerRegistry::new(crate::reasoning::circuit_breaker::CircuitBreakerConfig {
                failure_threshold: 2,
                recovery_timeout: std::time::Duration::from_secs(30),
                half_open_max_calls: 1,
            });

        // Two recorded failures reach the threshold and open the circuit.
        for _ in 0..2 {
            breakers.record_failure("failing_tool").await;
        }

        let actions = [tool_call("c1", "failing_tool", "{}")];
        let observations = executor
            .execute_actions(&actions, &config, &breakers)
            .await;

        assert_eq!(observations.len(), 1);
        let rejected = &observations[0];
        assert!(rejected.is_error);
        assert!(rejected.content.contains("circuit is open"));
    }
}