Skip to main content

traitclaw_core/traits/
execution_strategy.rs

1//! Tool execution strategies for configurable concurrency.
2//!
3//! By default, tools are executed sequentially. Use [`ParallelStrategy`] for
4//! concurrent execution or [`AdaptiveStrategy`] to let the [`Tracker`] decide.
5//!
6//! [`Tracker`]: crate::traits::tracker::Tracker
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::traits::guard::{Guard, GuardResult};
13use crate::traits::tool::ErasedTool;
14use crate::traits::tracker::Tracker;
15use crate::types::action::Action;
16use crate::types::agent_state::AgentState;
17use crate::types::tool_call::ToolCall;
18
19/// A pending tool call to be executed by a strategy.
20#[derive(Debug, Clone)]
21pub struct PendingToolCall {
22    /// Unique identifier for the tool call.
23    pub id: String,
24    /// Name of the tool to invoke.
25    pub name: String,
26    /// JSON arguments for the tool.
27    pub arguments: serde_json::Value,
28}
29
30impl From<&ToolCall> for PendingToolCall {
31    fn from(tc: &ToolCall) -> Self {
32        Self {
33            id: tc.id.clone(),
34            name: tc.name.clone(),
35            arguments: tc.arguments.clone(),
36        }
37    }
38}
39
/// Outcome of running one tool call.
///
/// Success and failure share this shape: when execution fails, `output`
/// carries the error message instead of the tool's output.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// Identifier of the tool call that produced this result.
    pub id: String,
    /// Text produced for the call: the tool's output, or an error message if
    /// the call could not be completed.
    pub output: String,
}
48
49/// Trait for pluggable tool execution strategies.
50///
51/// Implementations control how a batch of tool calls are executed —
52/// sequentially, in parallel, or with custom logic.
53#[async_trait]
54pub trait ExecutionStrategy: Send + Sync {
55    /// Execute a batch of tool calls and return results.
56    async fn execute_batch(
57        &self,
58        calls: Vec<PendingToolCall>,
59        tools: &[Arc<dyn ErasedTool>],
60        guards: &[Arc<dyn Guard>],
61        state: &AgentState,
62    ) -> Vec<ToolResult>;
63}
64
65// ───────────────────────────── Sequential ─────────────────────────────
66
/// Strategy that runs tool calls strictly one after another, in input order.
///
/// This is the default choice: simple, predictable, and free of any
/// concurrency between tools.
pub struct SequentialStrategy;
71
72#[async_trait]
73impl ExecutionStrategy for SequentialStrategy {
74    async fn execute_batch(
75        &self,
76        calls: Vec<PendingToolCall>,
77        tools: &[Arc<dyn ErasedTool>],
78        guards: &[Arc<dyn Guard>],
79        _state: &AgentState,
80    ) -> Vec<ToolResult> {
81        let mut results = Vec::with_capacity(calls.len());
82        for call in calls {
83            let output = execute_single(&call, tools, guards).await;
84            results.push(ToolResult {
85                id: call.id,
86                output,
87            });
88        }
89        results
90    }
91}
92
93// ───────────────────────────── Parallel ─────────────────────────────
94
/// Strategy that runs tool calls concurrently, up to a fixed limit.
pub struct ParallelStrategy {
    /// Upper bound on how many tool calls may execute at the same time.
    ///
    /// [`ParallelStrategy::new`] never stores a value below 1, but the field
    /// is public and can be set directly without that clamp.
    pub max_concurrency: usize,
}

impl ParallelStrategy {
    /// Build a parallel strategy limited to `max_concurrency` simultaneous
    /// tool executions.
    ///
    /// A request for `0` is raised to `1`, so the strategy can always make
    /// progress.
    #[must_use]
    pub fn new(max_concurrency: usize) -> Self {
        let limit = std::cmp::max(max_concurrency, 1);
        ParallelStrategy {
            max_concurrency: limit,
        }
    }
}
110
111#[async_trait]
112impl ExecutionStrategy for ParallelStrategy {
113    async fn execute_batch(
114        &self,
115        calls: Vec<PendingToolCall>,
116        tools: &[Arc<dyn ErasedTool>],
117        guards: &[Arc<dyn Guard>],
118        _state: &AgentState,
119    ) -> Vec<ToolResult> {
120        use tokio::sync::Semaphore;
121
122        let semaphore = Arc::new(Semaphore::new(self.max_concurrency));
123        let tools = Arc::new(tools.to_vec());
124        let guards = Arc::new(guards.to_vec());
125
126        // P3 fix: pre-clone call IDs before moving calls into spawned tasks,
127        // so we can attribute errors correctly if a task panics.
128        let call_ids: Vec<String> = calls.iter().map(|c| c.id.clone()).collect();
129        let mut handles = Vec::with_capacity(calls.len());
130
131        for call in calls {
132            let sem = semaphore.clone();
133            let tools = tools.clone();
134            let guards = guards.clone();
135
136            handles.push(tokio::spawn(async move {
137                let _permit = sem.acquire().await.expect("semaphore closed");
138                let output = execute_single(&call, &tools, &guards).await;
139                ToolResult {
140                    id: call.id,
141                    output,
142                }
143            }));
144        }
145
146        let mut results = Vec::with_capacity(handles.len());
147        for (i, handle) in handles.into_iter().enumerate() {
148            match handle.await {
149                Ok(result) => results.push(result),
150                Err(e) => results.push(ToolResult {
151                    id: call_ids[i].clone(),
152                    output: format!("Error: task panicked: {e}"),
153                }),
154            }
155        }
156        results
157    }
158}
159
160// ───────────────────────────── Adaptive ─────────────────────────────
161
162/// Adaptive strategy that queries [`Tracker::recommended_concurrency()`] to
163/// decide whether to run sequentially or in parallel.
164pub struct AdaptiveStrategy {
165    tracker: Arc<dyn Tracker>,
166}
167
168impl AdaptiveStrategy {
169    /// Create an adaptive strategy that uses the given tracker.
170    #[must_use]
171    pub fn new(tracker: Arc<dyn Tracker>) -> Self {
172        Self { tracker }
173    }
174}
175
176#[async_trait]
177impl ExecutionStrategy for AdaptiveStrategy {
178    async fn execute_batch(
179        &self,
180        calls: Vec<PendingToolCall>,
181        tools: &[Arc<dyn ErasedTool>],
182        guards: &[Arc<dyn Guard>],
183        state: &AgentState,
184    ) -> Vec<ToolResult> {
185        let concurrency = self.tracker.recommended_concurrency(state);
186        if concurrency <= 1 {
187            SequentialStrategy
188                .execute_batch(calls, tools, guards, state)
189                .await
190        } else {
191            ParallelStrategy::new(concurrency)
192                .execute_batch(calls, tools, guards, state)
193                .await
194        }
195    }
196}
197
198// ───────────────────────────── Helpers ─────────────────────────────
199
200/// Execute a single tool call with guard checks.
201async fn execute_single(
202    call: &PendingToolCall,
203    tools: &[Arc<dyn ErasedTool>],
204    guards: &[Arc<dyn Guard>],
205) -> String {
206    let action = Action::ToolCall {
207        name: call.name.clone(),
208        arguments: call.arguments.clone(),
209    };
210
211    // Guard checks — catch_unwind for external guard safety (Story 7.4)
212    for guard in guards {
213        let guard_name = guard.name().to_string();
214        let action_ref = &action;
215
216        let guard_span = tracing::info_span!(
217            target: "traitclaw::guard",
218            "guard.check",
219            guard.name = guard_name.as_str(),
220            guard.result = tracing::field::Empty,
221        );
222        let _g = guard_span.enter();
223
224        let result =
225            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| guard.check(action_ref)));
226
227        match result {
228            Ok(GuardResult::Allow) => {
229                guard_span.record("guard.result", "allow");
230            }
231            Ok(GuardResult::Deny { reason, .. }) => {
232                guard_span.record("guard.result", "deny");
233                return format!("Error: Action blocked by guard: {reason}");
234            }
235            Ok(GuardResult::Sanitize { warning, .. }) => {
236                guard_span.record("guard.result", "sanitize");
237                tracing::info!(
238                    target: "traitclaw::guard",
239                    guard = guard_name.as_str(),
240                    "Guard sanitized: {warning}"
241                );
242            }
243            Err(_) => {
244                guard_span.record("guard.result", "panic");
245                // Guard panicked — default to Deny for safety (P3 code review)
246                tracing::warn!(
247                    target: "traitclaw::guard",
248                    guard = guard_name.as_str(),
249                    "Guard panicked — denying action for safety"
250                );
251                return format!("Error: Action blocked — guard '{guard_name}' panicked");
252            }
253        }
254    }
255
256    // Find and execute tool
257    let tool_span = tracing::info_span!(
258        target: "traitclaw::tool",
259        "tool.call",
260        tool.name = call.name.as_str(),
261        tool.success = tracing::field::Empty,
262    );
263    let _t = tool_span.enter();
264
265    if let Some(tool) = tools.iter().find(|t| t.name() == call.name) {
266        match tool.execute_json(call.arguments.clone()).await {
267            Ok(output) => {
268                tool_span.record("tool.success", true);
269                serde_json::to_string(&output)
270                    .unwrap_or_else(|e| format!("Error serializing output: {e}"))
271            }
272            Err(e) => {
273                tool_span.record("tool.success", false);
274                format!("Error executing tool: {e}")
275            }
276        }
277    } else {
278        tool_span.record("tool.success", false);
279        let available: Vec<_> = tools.iter().map(|t| t.name().to_string()).collect();
280        format!(
281            "Error: Tool '{}' not found. Available: {}",
282            call.name,
283            available.join(", ")
284        )
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::guard::NoopGuard;

    /// Minimal tool named `add` that ignores its arguments and always returns
    /// the JSON string `"result"`.
    struct AddTool;

    #[async_trait]
    impl ErasedTool for AddTool {
        fn name(&self) -> &'static str {
            "add"
        }
        fn description(&self) -> &'static str {
            "Adds two numbers"
        }
        fn schema(&self) -> crate::traits::tool::ToolSchema {
            crate::traits::tool::ToolSchema {
                name: "add".into(),
                description: "add".into(),
                parameters: serde_json::json!({}),
            }
        }
        async fn execute_json(
            &self,
            _args: serde_json::Value,
        ) -> std::result::Result<serde_json::Value, crate::Error> {
            Ok(serde_json::json!("result"))
        }
    }

    /// Build `n` calls to the `add` tool with ids `call-0`, `call-1`, ...
    fn make_calls(n: usize) -> Vec<PendingToolCall> {
        (0..n)
            .map(|i| PendingToolCall {
                id: format!("call-{i}"),
                name: "add".into(),
                arguments: serde_json::json!({}),
            })
            .collect()
    }

    /// Tool registry containing only [`AddTool`].
    fn add_tools() -> Vec<Arc<dyn ErasedTool>> {
        vec![Arc::new(AddTool)]
    }

    /// Guard list containing only the no-op guard.
    fn noop_guards() -> Vec<Arc<dyn Guard>> {
        vec![Arc::new(NoopGuard)]
    }

    /// Agent state shared by all tests in this module.
    fn small_state() -> AgentState {
        AgentState::new(crate::types::model_info::ModelTier::Small, 4096)
    }

    #[tokio::test]
    async fn test_sequential_executes_in_order() {
        let strategy = SequentialStrategy;
        let tools = add_tools();
        let guards = noop_guards();
        let state = small_state();

        let results = strategy
            .execute_batch(make_calls(3), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "call-0");
        assert_eq!(results[1].id, "call-1");
        assert_eq!(results[2].id, "call-2");
        // All should succeed
        for r in &results {
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
        }
    }

    #[tokio::test]
    async fn test_parallel_executes_concurrently() {
        let strategy = ParallelStrategy::new(4);
        let tools = add_tools();
        let guards = noop_guards();
        let state = small_state();

        let results = strategy
            .execute_batch(make_calls(5), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 5);
        // Results must line up with the input order even though the calls
        // run as independent tasks.
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.id, format!("call-{i}"));
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
        }
    }

    #[test]
    fn test_parallel_new_clamps_zero_concurrency() {
        assert_eq!(ParallelStrategy::new(0).max_concurrency, 1);
        assert_eq!(ParallelStrategy::new(3).max_concurrency, 3);
    }

    #[tokio::test]
    async fn test_guard_blocks_propagate() {
        /// Guard that denies every action.
        struct DenyGuard;
        impl Guard for DenyGuard {
            fn name(&self) -> &'static str {
                "deny"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                GuardResult::Deny {
                    reason: "blocked".into(),
                    severity: crate::traits::guard::GuardSeverity::High,
                }
            }
        }

        let strategy = SequentialStrategy;
        let tools = add_tools();
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(DenyGuard)];
        let state = small_state();

        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "call-0");
        assert!(results[0].output.contains("blocked"));
    }

    #[tokio::test]
    async fn test_guard_panic_defaults_to_deny() {
        /// Guard whose check always panics.
        struct PanicGuard;
        impl Guard for PanicGuard {
            fn name(&self) -> &'static str {
                "panic_guard"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                panic!("intentional panic in guard");
            }
        }

        let strategy = SequentialStrategy;
        let tools = add_tools();
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(PanicGuard)];
        let state = small_state();

        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state)
            .await;

        assert_eq!(results.len(), 1);
        // P3: panicking guard should deny, not allow
        assert!(
            results[0].output.contains("panicked"),
            "Expected deny on panic, got: {}",
            results[0].output
        );
    }

    #[tokio::test]
    async fn test_tool_not_found_returns_error() {
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![]; // no tools registered
        let guards = noop_guards();
        let state = small_state();

        let calls = vec![PendingToolCall {
            id: "c1".into(),
            name: "nonexistent".into(),
            arguments: serde_json::json!({}),
        }];

        let results = strategy.execute_batch(calls, &tools, &guards, &state).await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "c1");
        assert!(
            results[0].output.contains("not found"),
            "Expected 'not found', got: {}",
            results[0].output
        );
    }
}