Skip to main content

traitclaw_core/traits/
execution_strategy.rs

1//! Tool execution strategies for configurable concurrency.
2//!
3//! By default, tools are executed sequentially. Use [`ParallelStrategy`] for
4//! concurrent execution or [`AdaptiveStrategy`] to let the [`Tracker`] decide.
5//!
6//! [`Tracker`]: crate::traits::tracker::Tracker
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::traits::guard::{Guard, GuardResult};
13use crate::traits::tool::ErasedTool;
14use crate::traits::tracker::Tracker;
15use crate::types::action::Action;
16use crate::types::agent_state::AgentState;
17use crate::types::tool_call::ToolCall;
18
19/// A pending tool call to be executed by a strategy.
20#[derive(Debug, Clone)]
21pub struct PendingToolCall {
22    /// Unique identifier for the tool call.
23    pub id: String,
24    /// Name of the tool to invoke.
25    pub name: String,
26    /// JSON arguments for the tool.
27    pub arguments: serde_json::Value,
28}
29
30impl From<&ToolCall> for PendingToolCall {
31    fn from(tc: &ToolCall) -> Self {
32        Self {
33            id: tc.id.clone(),
34            name: tc.name.clone(),
35            arguments: tc.arguments.clone(),
36        }
37    }
38}
39
/// Outcome of running one [`PendingToolCall`].
///
/// Failures are not a separate variant: guard denials, missing tools and tool
/// errors are all reported as text in `output`.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// The [`PendingToolCall::id`] this result belongs to.
    pub id: String,
    /// Serialized tool output on success, or an error message on failure.
    pub output: String,
}
48
49/// Trait for pluggable tool execution strategies.
50///
51/// Implementations control how a batch of tool calls are executed —
52/// sequentially, in parallel, or with custom logic.
53#[async_trait]
54pub trait ExecutionStrategy: Send + Sync {
55    /// Execute a batch of tool calls and return results.
56    async fn execute_batch(
57        &self,
58        calls: Vec<PendingToolCall>,
59        tools: &[Arc<dyn ErasedTool>],
60        guards: &[Arc<dyn Guard>],
61        state: &AgentState,
62    ) -> Vec<ToolResult>;
63}
64
65// ───────────────────────────── Sequential ─────────────────────────────
66
/// Execute tool calls one at a time, in the order they were given.
///
/// This is the default strategy: each call finishes before the next starts,
/// which keeps side effects ordered and predictable.
///
/// The type carries no state, so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialStrategy;
71
72#[async_trait]
73impl ExecutionStrategy for SequentialStrategy {
74    async fn execute_batch(
75        &self,
76        calls: Vec<PendingToolCall>,
77        tools: &[Arc<dyn ErasedTool>],
78        guards: &[Arc<dyn Guard>],
79        _state: &AgentState,
80    ) -> Vec<ToolResult> {
81        let mut results = Vec::with_capacity(calls.len());
82        for call in calls {
83            let output = execute_single(&call, tools, guards).await;
84            results.push(ToolResult {
85                id: call.id,
86                output,
87            });
88        }
89        results
90    }
91}
92
93// ───────────────────────────── Parallel ─────────────────────────────
94
/// Execute tool calls concurrently, with at most `max_concurrency` running at once.
///
/// Results are still reported in the order the calls were submitted.
#[derive(Debug, Clone)]
pub struct ParallelStrategy {
    /// Maximum number of concurrent tool executions.
    ///
    /// [`ParallelStrategy::new`] raises values below 1 to 1.
    pub max_concurrency: usize,
}

impl ParallelStrategy {
    /// Create a parallel strategy with the given concurrency limit.
    ///
    /// A limit of `0` would never let any call run, so it is raised to `1`.
    #[must_use]
    pub fn new(max_concurrency: usize) -> Self {
        Self {
            max_concurrency: max_concurrency.max(1),
        }
    }
}
110
111#[async_trait]
112impl ExecutionStrategy for ParallelStrategy {
113    async fn execute_batch(
114        &self,
115        calls: Vec<PendingToolCall>,
116        tools: &[Arc<dyn ErasedTool>],
117        guards: &[Arc<dyn Guard>],
118        _state: &AgentState,
119    ) -> Vec<ToolResult> {
120        use tokio::sync::Semaphore;
121
122        let semaphore = Arc::new(Semaphore::new(self.max_concurrency));
123        let tools = Arc::new(tools.to_vec());
124        let guards = Arc::new(guards.to_vec());
125
126        // P3 fix: pre-clone call IDs before moving calls into spawned tasks,
127        // so we can attribute errors correctly if a task panics.
128        let call_ids: Vec<String> = calls.iter().map(|c| c.id.clone()).collect();
129        let mut handles = Vec::with_capacity(calls.len());
130
131        for call in calls {
132            let sem = semaphore.clone();
133            let tools = tools.clone();
134            let guards = guards.clone();
135
136            handles.push(tokio::spawn(async move {
137                let _permit = sem.acquire().await.expect("semaphore closed");
138                let output = execute_single(&call, &tools, &guards).await;
139                ToolResult {
140                    id: call.id,
141                    output,
142                }
143            }));
144        }
145
146        let mut results = Vec::with_capacity(handles.len());
147        for (i, handle) in handles.into_iter().enumerate() {
148            match handle.await {
149                Ok(result) => results.push(result),
150                Err(e) => results.push(ToolResult {
151                    id: call_ids[i].clone(),
152                    output: format!("Error: task panicked: {e}"),
153                }),
154            }
155        }
156        results
157    }
158}
159
160// ───────────────────────────── Adaptive ─────────────────────────────
161
162/// Adaptive strategy that queries [`Tracker::recommended_concurrency()`] to
163/// decide whether to run sequentially or in parallel.
164pub struct AdaptiveStrategy {
165    tracker: Arc<dyn Tracker>,
166}
167
168impl AdaptiveStrategy {
169    /// Create an adaptive strategy that uses the given tracker.
170    #[must_use]
171    pub fn new(tracker: Arc<dyn Tracker>) -> Self {
172        Self { tracker }
173    }
174}
175
176#[async_trait]
177impl ExecutionStrategy for AdaptiveStrategy {
178    async fn execute_batch(
179        &self,
180        calls: Vec<PendingToolCall>,
181        tools: &[Arc<dyn ErasedTool>],
182        guards: &[Arc<dyn Guard>],
183        state: &AgentState,
184    ) -> Vec<ToolResult> {
185        let concurrency = self.tracker.recommended_concurrency(state);
186        if concurrency <= 1 {
187            SequentialStrategy
188                .execute_batch(calls, tools, guards, state)
189                .await
190        } else {
191            ParallelStrategy::new(concurrency)
192                .execute_batch(calls, tools, guards, state)
193                .await
194        }
195    }
196}
197
198// ───────────────────────────── Helpers ─────────────────────────────
199
200/// Execute a single tool call with guard checks.
201async fn execute_single(
202    call: &PendingToolCall,
203    tools: &[Arc<dyn ErasedTool>],
204    guards: &[Arc<dyn Guard>],
205) -> String {
206    let action = Action::ToolCall {
207        name: call.name.clone(),
208        arguments: call.arguments.clone(),
209    };
210
211    // Guard checks — catch_unwind for external guard safety (Story 7.4)
212    for guard in guards {
213        let guard_name = guard.name().to_string();
214        let action_ref = &action;
215        let result =
216            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| guard.check(action_ref)));
217
218        match result {
219            Ok(GuardResult::Allow) => {}
220            Ok(GuardResult::Deny { reason, .. }) => {
221                return format!("Error: Action blocked by guard: {reason}");
222            }
223            Ok(GuardResult::Sanitize { warning, .. }) => {
224                tracing::info!(guard = guard_name.as_str(), "Guard sanitized: {warning}");
225            }
226            Err(_) => {
227                // Guard panicked — default to Deny for safety (P3 code review)
228                tracing::warn!(
229                    guard = guard_name.as_str(),
230                    "Guard panicked — denying action for safety"
231                );
232                return format!("Error: Action blocked — guard '{guard_name}' panicked");
233            }
234        }
235    }
236
237    // Find and execute tool
238    if let Some(tool) = tools.iter().find(|t| t.name() == call.name) {
239        match tool.execute_json(call.arguments.clone()).await {
240            Ok(output) => serde_json::to_string(&output)
241                .unwrap_or_else(|e| format!("Error serializing output: {e}")),
242            Err(e) => format!("Error executing tool: {e}"),
243        }
244    } else {
245        let available: Vec<_> = tools.iter().map(|t| t.name().to_string()).collect();
246        format!(
247            "Error: Tool '{}' not found. Available: {}",
248            call.name,
249            available.join(", ")
250        )
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::guard::NoopGuard;

    /// Minimal tool named `add` that ignores its arguments and always returns
    /// the JSON string `"result"`.
    struct AddTool;

    #[async_trait]
    impl ErasedTool for AddTool {
        fn name(&self) -> &'static str {
            "add"
        }
        fn description(&self) -> &'static str {
            "Adds two numbers"
        }
        fn schema(&self) -> crate::traits::tool::ToolSchema {
            crate::traits::tool::ToolSchema {
                name: "add".into(),
                description: "add".into(),
                parameters: serde_json::json!({}),
            }
        }
        async fn execute_json(
            &self,
            _args: serde_json::Value,
        ) -> std::result::Result<serde_json::Value, crate::Error> {
            Ok(serde_json::json!("result"))
        }
    }

    /// Build `n` calls to the `add` tool with IDs `call-0` through `call-{n-1}`.
    fn make_calls(n: usize) -> Vec<PendingToolCall> {
        (0..n)
            .map(|i| PendingToolCall {
                id: format!("call-{i}"),
                name: "add".into(),
                arguments: serde_json::json!({}),
            })
            .collect()
    }

    #[tokio::test]
    async fn test_sequential_executes_in_order() {
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let state = AgentState::new(crate::types::model_info::ModelTier::Small, 4096);

        let results = strategy
            .execute_batch(make_calls(3), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "call-0");
        assert_eq!(results[1].id, "call-1");
        assert_eq!(results[2].id, "call-2");
        // All should succeed
        for r in &results {
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
        }
    }

    #[tokio::test]
    async fn test_parallel_executes_concurrently() {
        let strategy = ParallelStrategy::new(4);
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let state = AgentState::new(crate::types::model_info::ModelTier::Small, 4096);

        let results = strategy
            .execute_batch(make_calls(5), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 5);
        for (i, r) in results.iter().enumerate() {
            // Results must come back in submission order even if tasks finish out of order.
            assert_eq!(r.id, format!("call-{i}"));
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
        }
    }

    #[tokio::test]
    async fn test_parallel_with_concurrency_one_runs_every_call() {
        let strategy = ParallelStrategy::new(1);
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let state = AgentState::new(crate::types::model_info::ModelTier::Small, 4096);

        let results = strategy
            .execute_batch(make_calls(3), &tools, &guards, &state)
            .await;

        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["call-0", "call-1", "call-2"]);
    }

    #[test]
    fn test_parallel_new_clamps_zero_concurrency() {
        assert_eq!(ParallelStrategy::new(0).max_concurrency, 1);
        assert_eq!(ParallelStrategy::new(3).max_concurrency, 3);
    }

    #[tokio::test]
    async fn test_empty_batch_yields_no_results() {
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let state = AgentState::new(crate::types::model_info::ModelTier::Small, 4096);

        let sequential = SequentialStrategy
            .execute_batch(Vec::new(), &tools, &guards, &state)
            .await;
        assert!(sequential.is_empty());

        let parallel = ParallelStrategy::new(4)
            .execute_batch(Vec::new(), &tools, &guards, &state)
            .await;
        assert!(parallel.is_empty());
    }

    #[tokio::test]
    async fn test_guard_blocks_propagate() {
        use crate::traits::guard::{Guard, GuardResult};

        struct DenyGuard;
        impl Guard for DenyGuard {
            fn name(&self) -> &'static str {
                "deny"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                GuardResult::Deny {
                    reason: "blocked".into(),
                    severity: crate::traits::guard::GuardSeverity::High,
                }
            }
        }

        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(DenyGuard)];

        let state = AgentState::new(crate::types::model_info::ModelTier::Small, 4096);

        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 1);
        assert!(results[0].output.contains("blocked"));
    }

    #[tokio::test]
    async fn test_guard_panic_defaults_to_deny() {
        use crate::traits::guard::{Guard, GuardResult};

        struct PanicGuard;
        impl Guard for PanicGuard {
            fn name(&self) -> &'static str {
                "panic_guard"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                panic!("intentional panic in guard");
            }
        }

        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(PanicGuard)];

        let state = AgentState::new(crate::types::model_info::ModelTier::Small, 4096);

        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 1);
        // P3: panicking guard should deny, not allow
        assert!(
            results[0].output.contains("panicked"),
            "Expected deny on panic, got: {}",
            results[0].output
        );
    }

    #[tokio::test]
    async fn test_tool_not_found_returns_error() {
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![]; // no tools registered
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];

        let state = AgentState::new(crate::types::model_info::ModelTier::Small, 4096);

        let calls = vec![PendingToolCall {
            id: "c1".into(),
            name: "nonexistent".into(),
            arguments: serde_json::json!({}),
        }];

        let results = strategy.execute_batch(calls, &tools, &guards, &state).await;

        assert_eq!(results.len(), 1);
        assert!(
            results[0].output.contains("not found"),
            "Expected 'not found', got: {}",
            results[0].output
        );
    }
}